语义分割单通道和多通道输出交叉熵损失函数的计算问题

摘要

本文验证了语义分割任务下,单通道输出和多通道输出时,使用交叉熵计算损失值的细节问题。对比验证了使用简单的函数和自带损失函数的结果,通过验证,进一步加强了对交叉熵的理解。

交叉熵损失函数

交叉熵损失函数的原理和推导过程,可以参考这篇博文,交叉熵的计算公式如下:

[CE(p,q) = -p*log(q) ]

其中 (q) 为预测的概率,(q∈[0,1])(p) 为标签,(p∈{0,1})

而交叉熵损失函数则是利用上式计算每一个分类的交叉熵之和。对于概率,所有分类的概率 (q) 之和满足相加等于1,而对于标签,则需要进行one-hot编码,使得有且只有一个分类的 (p) 为1,其余的分类为0。

单通道输出时的交叉熵损失计算

单通道输出交叉熵损失计算示意图

首先,假设我们研究的是一个二分类语义分割问题。

网络的输入是一个 2×2 的图像,设置 batch_size 为 2,网络输出单通道特征图。网络的标签也是一个 2 ×2 的二进制掩模图(即只有0和1的单通道图像)。

我们在 pytorch 中将其定义:

import torch

# 假设输出一个 [batch_size=2, channel=1, height=2, width=2] 格式的张量 x1
x1 = torch.tensor(
    [[[[ 0.43, -0.25],
        [-0.32, 0.69]]],

        [[[-0.29, 0.37],
          [0.54,  -0.72]]]])

# 假设标签图像为与 x1 同型的张量 y1
y1 = torch.tensor(
    [[[[0., 0.],
        [0., 1.]]],

        [[[0., 0.],
          [1.,  1.]]]])

在进行交叉熵前,首先需要做一个 sigmoid 操作,将数值压缩到0到1之间:

# 根据二进制交叉熵的计算过程
# 首先进行sigmoid计算,然后与标签图像进行二进制交叉熵计算,最后取平均值,即为损失值

# 1. sigmoid
s1 = torch.sigmoid(x1)
s1

'''
out:
tensor([[[[0.6059, 0.4378],
          [0.4207, 0.6660]]],


        [[[0.4280, 0.5915],
          [0.6318, 0.3274]]]]
'''

然后进行交叉熵计算,由于计算的是每个像素的损失值,所以要取个平均值:

# 2.交叉熵计算
loss_cal = -1*(y1*torch.log(s1)+(1-y1)*torch.log(1-s1)) # 此处相当于一个one-hot编码
loss_cal_mean = torch.mean()
loss_cal_mean

'''
out:
tensor(0.6861)
'''

为了验证结果,我们使用 pytorch 自带的二进制交叉熵损失函数计算:

# 使用torch自带的二进制交叉熵计算
loss_bce = torch.nn.BCELoss()(s1,y1)
loss_bce

'''
out:
tensor(0.6861)
'''

当计算损失值前没有进行 sigmoid 操作时,pytorch 还提供了包含这个操作的二进制交叉熵损失函数:

# 使用带sigmoid的二进制交叉熵计算
loss_bce2 = torch.nn.BCEWithLogitsLoss()(x1,y1)
loss_bce2

'''
out:
tensor(0.6861)
'''

可以看到,我们使用了三种方式,计算了交叉熵损失,结果一致。

多通道输出时的交叉熵损失计算

多通道输出交叉熵损失计算示意图

首先,假设我们研究的是一个二分类语义分割问题。

网络的输入是一个 2×2 的图像,设置 batch_size 为 2,网络输出多(二)通道特征图。网络的标签也是一个 2 ×2 的二进制掩模图(即只有0和1的单通道图像)。

我们在 pytorch 中将其定义:

# 假设输出一个[batch_size=2, channel=2, height=2, width=2]格式的张量 x1
x1 = torch.tensor([[[[ 0.3164, -0.1922],
          [ 0.4326, -1.2193]],

         [[ 0.6873,  0.6838],
          [ 0.2244,  0.5615]]],


        [[[-0.2516, -0.8875],
          [-0.6289, -0.1796]],

         [[ 0.0411, -1.7851],
          [-0.3069, -1.0379]]]])

# 假设标签图像为与x1同型,然后去掉channel的张量 y1 (注意两点,channel没了,格式为LongTensor)
y1 = torch.LongTensor([[[0., 1.],
         [1., 0.]],

        [[1., 1.],
         [0., 1.]]])

在进行交叉熵前,首先需要做一个 softmax 操作,将数值压缩到0到1之间,且使得各通道之间的数值之和为1:

# 1.softmax
s1 = torch.softmax(x1,dim=1)
s1

'''
out:
tensor([[[[0.4083, 0.2940],
          [0.5519, 0.1442]],

         [[0.5917, 0.7060],
          [0.4481, 0.8558]]],


        [[[0.4273, 0.7105],
          [0.4202, 0.7023]],

         [[0.5727, 0.2895],
          [0.5798, 0.2977]]]])
'''

对于标签图,由于其张量的形状与网络输出张量不一样,因此需要做一个one-hot转换,什么是one-hot?请看这篇博文

# 2.one-hot
y1_one_hot = torch.zeros_like(x1).scatter_(dim=1,index=y1.unsqueeze(dim=1),src=torch.ones_like(x1))
y1_one_hot

'''
out:
tensor([[[[1., 0.],
          [0., 1.]],

         [[0., 1.],
          [1., 0.]]],


        [[[0., 0.],
          [1., 0.]],

         [[1., 1.],
          [0., 1.]]]])
'''

这里需要重点理解这个scatter_函数,他起到的作用十分关键,one-hot 转换时,其实可以理解为将一个同型的全1矩阵中的元素,有选择性的复制到全0矩阵中的过程,这里的选择依据就是我们的标签图,它决定了哪个位置和通道上的元素取值为 1 。在scatter_ 函数中,dim 决定了用于确定我们在哪个维度上开始定位要建立联系的元素,index是我们选择的依据。

按照交叉熵定义,继续计算:

# 交叉熵计算
loss_cal = -1 *(y1_one_hot * torch.log(s1)) 
loss_cal_mean = loss_cal.sum(dim=1).mean() # 在batch维度下计算每个样本的交叉熵
loss_cal_mean

'''
out:
tensor(0.9823)
'''

我们也可以使用 pytorch 自带的交叉熵损失函数计算:

loss_ce = torch.nn.CrossEntropyLoss()(x1,y1)
loss_ce

'''
tensor(0.9823)
'''

可以看到,两种方式结果一样。

结论

  • 交叉熵本质上将一群对象择其一进行研究,自然就变成一个二进制问题,即是这个对象或不是这个对象,然后将标签与概率融进公式中,计算损失值。对于每一个对象都可以计算一个损失值,求个平均值就是最后这个群体的损失值了。

  • 不论是sigmoid或者softmax,我们都是在有目的将数据规整到0到1之间,从而形成一个概率值,sigmoid针对的是二分类问题,因此算出一个概率,另一个用一减去就到了。多分类问题,由于最后会输出对应数量的值,softmax 能够将这些值转换到0到1,并满足加起来等于1,这样的话,当我们只研究其中一个类的概率时,其他类的概率自然就是用1减去这个类的概率了,又回到了二分类问题。

  • 对于二分类语义分割问题,其实采用上述两种方式都是可以的。

参考资料

[1] pytorch中的 scatter_()函数使用和详解

[2] pytorch交叉熵使用方法

[3] pytorch损失函数之nn.BCELoss()(为什么用交叉熵作为损失函数)

[4] pytorch损失函数之nn.CrossEntropyLoss()、nn.NLLLoss()

[5] PyTorch中名不符实的损失函数

[6] Pytorch中Softmax、Log_Softmax、NLLLoss以及CrossEntropyLoss的关系与区别详解

[7] 二分类问题,应该选择sigmoid还是softmax?

原文地址:https://www.cnblogs.com/gshang/p/13887133.html