Pytorch分类问题中的交叉熵损失函数使用

本文主要介绍一下分类问题中损失函数的使用,对于二分类、多分类、多标签这个三个不同的场景,在 Pytorch 中的损失函数使用稍有区别。


损失函数

Softmax

在介绍损失函数前,先介绍一下什么是 Softmax,通常在分类问题中会将 Softmax 搭配 Cross Entropy 一同使用。Softmax 的计算公式定义如下:

$$mathtt{softmax(x_i)={exp(x_i) over {sum_{j} exp(x_j)}}}$$

例如,我们现在有一个数组 [1, 2, 3],这三个数的 Softmax 输出是:

$$mathtt{softmax(1)={exp(1) over exp(1)+exp(2)+exp(3)}=0.09}$$

$$mathtt{softmax(2)={exp(2) over exp(1)+exp(2)+exp(3)}=0.2447}$$

$$mathtt{softmax(3)={exp(3) over exp(3)+exp(2)+exp(3)}=0.6652}$$

所以 Softmax 直白来说就是将原来输出是 [1, 2, 3] 的数组,通过 Softmax 函数作用后映射成范围为 (0, 1) 的值,而这些值的累积和为 1(满足概率性质),那么我们就可以将它理解成概率,在最后选取输出结点的时候,我们就可以选取概率最大(也就是值对应最大的)结点,作为我们的预测目标。

Cross Entropy

对于 Cross Entropy,以下是我见过最喜欢的一个解释:

在机器学习中,P 往往用来表示样本的真实分布,比如 [1, 0, 0] 表示当前样本属于第一类;Q 往往用来表示模型所预测的分布,比如 [0.7, 0.2, 0.1]。这里直观的理解就是,如果用 P 来描述样本,那就非常完美,而用 Q 来描述样本,虽然可以大致描述,但是不是那么的完美,信息量不足,需要额外的一些信息增量才能达到和 P 一样完美的描述。如果我们的 Q 通过反复训练,也能趋近完美地描述样本,那么就不再需要额外的信息增量,这时 Q 等价于 P。

所以,如果按照真实分布 P 来衡量描述一个样本所需的编码长度的期望,即平均编码长度,或称信息熵:

$$mathtt{H(p) = -sum_{i=1}^C p(x_i)log(p(x_i))}$$

如果使用拟合分布 Q 来表示来自真实分布 P 的编码长度的期望,即平均编码长度,或称交叉熵:

$$mathtt{H(p,q)=-sum_{i=1}^C p(x_i)log(q(x_i))}$$

所以 H(p, q) >= H(q) 恒成立。我们把 Q 得到的平均编码长度比 P 得到的平均编码长度多出的 bit 数称为“相对熵”,也叫“KL散度”,用来衡量两个分布的差异:

$$mathtt{D_{KL}(p||q)=H(p,q)-H(p)=sum_{i=1}^C p(x_i)log({p(x_i) over q(x_i)})}$$

在机器学习的分类问题中,我们希望通过训练来缩小模型预测和标签之间的差距,即“KL散度”越小越好,根据上面公式,“KL散度”中 H(p) 项不变,所以在优化过程中我们只需要关注“交叉熵”即可,这就是我们使用“交叉熵”作为损失函数的原因。

交叉熵如何计算?

在做完 Softmax 之后,再计算交叉熵,作为损失函数:

$$mathtt{L(widehat{y}, y) =- sum_{i=1}^C y_i log(widehat{y}\_{i})}$$

这里的 $mathtt{widehat{y}}$ 指的是预测值(Softmax 层的输出)。$mathtt{y}$ 指的是真实值,是一个 one-hot 编码后的 C 维向量。什么是 One-hot Encoding?如果一个样例 x 是类别 i,则它标签 y 的第 i 维的值为 1,其余维的值为 0。例如,x 是类别 2,共 4 类,则其标签 y 的 值为 [0, 1, 0, 0]。

多分类和多标签:

1. 二分类:

表示分类任务中有两个类别。

[Pytorch Example]

2. 多分类:

一个样本属于且只属于多个类中的一个,一个样本只能属于一个类,不同类之间是互斥的。对于多分类问题,我们通常有几种处理方法:1)直接多分类;2)one vs one;3)one vs rest 训练 N 个分类器,测试的时候若仅有一个分类器预测为正的类别则对应的类别标记作为最终分类结果,若有多个分类器预测为正类,则选择置信度最大的类别作为最终分类结果。

3. 多标签:

一个样本可以属于多个类别(或标签),不同类之间是有关联的。

激活函数和损失函数搭配选择:

Problem Actv Func Loss Func
binary sigmoid BCE
Multiclass softmax CE
Multilabel sigmoid BCE


 

Pytorch 中的交叉熵损失函数使用

以下详细介绍在 Pytorch 如何使用交叉熵作为损失函数。

对于分类问题,采用交叉熵或对数似然损失函数,这两种损失函数在大部分情况下并无差别,只是因为推导方式不同,而叫法不同。对数似然损失函数的思想是极大化似然函数,而交叉熵是从信息熵的角度考虑。Pytorch 提供了两个类来计算交叉熵,分别是 nn.CrossEntropyLoss() 和 nn.NLLLoss(),而 Pytorch 中的 nn.CrossEntropyLoss() 函数包含 nn.LogSoftmax() 和 nn.NLLLoss()

交叉熵损失函数常用于分类问题,一般来说有:binary,multiclass 和 multilabel 三种使用情形。

Binary Case

二分类问题我们一般直接使用 nn.BCELoss(pred, target) 或者 nn.BCEWithLogitsLoss(pred, target),区别是后者无需对 pred 输出手动做 sigmoid,其余都一样。

pred:[N] 或者 [N, *],where THE STAR means any number of additional dimensions,也就是 pred 可以是任意 shape 的 Tensor;

target:shape 和 pred 一样即可。

In case of binary semantic segmentation:

pred:[N, 1, H, W];

target:[N, 1, H, W];

import torch
import torch.nn as nn
import torch.nn.functional as F


N, C = 16, 1

input = torch.randn(N, 1, 224, 224)    # inupt: [16, 1, 224, 224]
conv = nn.Conv2d(1, C, kernel_size=(3, 3), padding=1)
pred = conv(input)    # pred: [16, 1, 224, 224]

target = torch.empty(N, 1, 224, 224, dtype=torch.float).random_(0, 2)    # target: [16, 1, 224, 224]

loss = F.binary_cross_entropy(F.sigmoid(pred), target)

In case of binary image classification:

pred:[N, 2] 或者 [N];

target:[N, 2] 或者 [N];

import torch
import torch.nn as nn
import torch.nn.functional as F


N, C = 16, 2

input = torch.randn(N, 1, 224, 224)    # input: [16, 1, 224, 224]
conv = nn.Conv2d(1, 64, kernel_size=(3, 3), padding=1)
fc = nn.Linear(64*224*224, C)

pred = conv(input)
pred = fc(pred.view(N, 64*224*224))    # pred: [16, 2]

target = torch.empty(N, 2, dtype=torch.float).random_(0, 2)    # target: [16, 2]

loss = F.binary_cross_entropy(F.sigmoid(pred), target)

PS:

(1)对于 nn.BCELoss(pred, target),pred 和 target 的数据类型都是 float;

(2)关于添加 weight:

class_weight = torch.autograd.Variable(torch.FloatTensor([1, 10])).cuda()    # 这里正例比较少,因此权重要大一些
weight = class_weight[input_mask.long()]
loss = F.binary_cross_entropy(F.sigmoid(pred), target, weight=weight)

Multi-Class Case

这时我们直接使用 nn.CrossEntropyLoss(pred, target) 或者 nn.LogSoftmax() + nn.NLLLoss()

 

pred:[N, C] 或者 [N, C, H, W],其中 C 代表输出通道数,也就是一共有多少类;

target:[N] 或者 [N, H, W],每个数的数值 0 <= target[i] <= C-1;

 

In case of multi-class semantic segmentation:

pred:[N, C, H, W];

target:[N, H, W];

例如,现在我们有三种类别,Batch Size 为 16,那么 pred 的 shape 为 [16, 3, H, W],target 的 shape 为 [16, H, W],target 值的范围为 0 <= target[i] <= 2。

import torch
import torch.nn as nn
import torch.nn.functional as F


N, C = 16, 3
criterion = nn.NLLLoss()
m = nn.LogSoftmax(dim=1)

input = torch.randn(N, 1, 224, 224)    # [16, 1, 224, 224]
conv = nn.Conv2d(1, C, kernel_size=(3, 3), padding=1)
pred = conv(input)    # [16, 3, 224, 224]

# each element in target has to have 0 <= value < C
target = torch.empty(N, 224, 224, dtype=torch.long).random_(0, C)    # [16, 224, 224]

loss = criterion(m(pred), target)

In case of multi-class image classification:

pred:[N, C];

target:[N];

例如,我们现在有三种类别,Batch Size 为 16,那么 pred 的 shape 为 [16, 3],target 的 shape 为 [16],target 的范围为 0 <= target[i] <= 2。

import torch
import torch.nn as nn
import torch.nn.functional as F


N, C = 16, 3
m = nn.LogSoftmax(dim=1)
criterion = nn.NLLLoss()

input = torch.randn(N, 1, 224, 224)    # [16, 1, 224, 224]
conv = nn.Conv2d(1, 64, kernel_size=(3, 3), padding=1)
fc = nn.Linear(64*224*224, C)

pred = conv(input)
pred = pred.contiguous().view(-1)
pred = pred.view(N, 64*224*224)
pred = fc(pred)    # [16, 3]

target = torch.empty(N, dtype=torch.long).random_(0, 3)    # [16]

loss = criterion(m(pred), target)

PS:

(1)  nn.CrossEntropyLoss() 和 nn.LogSoftmax() + nn.NLLLoss() 等效:

import torch
import torch.nn as nn
import torch.nn.functional as F


# ---- nn.CrossEntropyLoss() ----
criterion = nn.CrossEntropyLoss()
pred = torch.Tensor([[-0.7715, -0.6205,-0.2562]])
target = torch.tensor([2])
loss = criterion(pred, target)

# ---- nn.NLLLoss() + nn.LogSoftmax(dim=1) ----
criterion = nn.NLLLoss()
m = nn.LogSoftmax(dim=1)
pred = torch.Tensor([[-0.7715, -0.6205,-0.2562]])
target = torch.tensor([2])
loss = criterion(m(pred), target)

print(loss)

可见计算结果都是 0.8294;

(2) target 的数据类型是 long;
(3) 参数 reduce,size_average  和 reduction  的使用请参考:Ref1 和 Ref2

原文地址:https://www.cnblogs.com/hmlovetech/p/14515622.html