# Reference: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss
import torch
from torch.nn import CrossEntropyLoss
# Toy dimensions for the CrossEntropyLoss walkthrough.
BATCH_SIZE, MAX_SEQ_LENGTH = 2, 3
# The last axis carries one score per vocabulary entry, so the "embedding"
# width and the vocabulary size are deliberately the same number.
VOCAB_SIZE = 5
EMBEDDING_SIZE = VOCAB_SIZE
# Random float scores shaped (BATCH_SIZE, MAX_SEQ_LENGTH, EMBEDDING_SIZE).
# Despite the name, these are per-class scores fed to the loss, not integer ids.
input_ids = torch.randn(BATCH_SIZE, MAX_SEQ_LENGTH, EMBEDDING_SIZE)
input_ids
tensor([[[-1.2207e-01, 9.4908e-01, 9.6387e-02, 1.4286e+00, 1.9476e+00],
[ 6.7663e-01, 8.2334e-04, -2.6130e+00, 3.3753e-02, -3.8190e-01],
[-2.3500e-01, 1.7386e-01, -9.2278e-01, 6.7210e-01, -5.9908e-01]],
[[-7.6696e-01, 1.0766e+00, -3.9904e-01, 1.3112e-01, 1.4053e-02],
[ 1.1152e+00, 1.4323e+00, 8.4845e-01, 8.5321e-01, -2.1357e-01],
[ 5.5852e-01, 1.7036e-01, 4.6033e-01, 1.2075e+00, -9.5198e-01]]])
# One integer class index per token. `high` is exclusive in torch.randint, so
# bounding it by VOCAB_SIZE samples every valid class 0..VOCAB_SIZE-1; the
# previous literal 4 could never produce class 4 and was not tied to VOCAB_SIZE.
labels = torch.randint(low=0, high=VOCAB_SIZE, size=(BATCH_SIZE, MAX_SEQ_LENGTH))
labels
tensor([[2, 0, 2],
[2, 0, 3]])
loss_fct = CrossEntropyLoss()
# Shape contract excerpted from the torch.nn.CrossEntropyLoss docstring. It is a
# raw string so the LaTeX backslashes (\geq, \leq, \text) stay literal instead
# of being parsed as escape sequences (`\t` would otherwise become a TAB).
r"""
- Input: :math:`(N, C)` where `C = number of classes`, or
:math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
in the case of `K`-dimensional loss.
- Target: :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`, or
:math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of
K-dimensional loss.
"""
# NOTE: the input is a vector of unnormalized per-class scores, flattened here to
# (N, C) with N = BATCH_SIZE * MAX_SEQ_LENGTH and C = EMBEDDING_SIZE. The loss
# applies log_softmax followed by nll_loss internally (it does not take an
# argmax). The target holds integer class indices, flattened to (N,).
loss_fct(input=input_ids.view(-1, EMBEDDING_SIZE), target=labels.view(-1))
tensor(1.7686)
# EMBEDDING_SIZE doubles as the number of classes, i.e. the vocabulary size
# VOCAB_SIZE, so every label value must stay within [0, EMBEDDING_SIZE - 1].
# Here labels are drawn from [0, 10) so that some of them can exceed that range.
labels = torch.randint(0, 10, (BATCH_SIZE, MAX_SEQ_LENGTH))
labels
tensor([[6, 2, 5],
[7, 7, 1]])
# Out of bounds: any target value >= EMBEDDING_SIZE makes the loss raise
# IndexError ("Target ... is out of bounds"), as the traceback below shows.
loss_fct(input_ids.view(-1, EMBEDDING_SIZE), labels.view(-1))
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-37-cbf2ac1ed94f> in <module>
1 # 越界
----> 2 loss_fct(input=input_ids.view(-1, EMBEDDING_SIZE), target=labels.view(-1))
/opt/conda/envs/blog/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
/opt/conda/envs/blog/lib/python3.8/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
959
960 def forward(self, input: Tensor, target: Tensor) -> Tensor:
--> 961 return F.cross_entropy(input, target, weight=self.weight,
962 ignore_index=self.ignore_index, reduction=self.reduction)
963
/opt/conda/envs/blog/lib/python3.8/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
2466 if size_average is not None or reduce is not None:
2467 reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2468 return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
2469
2470
/opt/conda/envs/blog/lib/python3.8/site-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
2262 .format(input.size(0), target.size(0)))
2263 if dim == 2:
-> 2264 ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
2265 elif dim == 4:
2266 ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
IndexError: Target 6 is out of bounds.