处理样本不平衡LOSS—Focal Loss

0 前言

Focal Loss是为了处理样本不平衡问题而提出的,经时间验证,在多种任务上,效果还是不错的。在理解Focal Loss前,需要先深刻理一下交叉熵损失,和带权重的交叉熵损失。然后我们从样本权重的角度出发,理解Focal Loss是如何分配样本权重的。Focal是动词Focus的形容词形式,那么它究竟Focus在什么地方呢?(详细的代码请看Gitee)。

1 交叉熵

1.1 交叉熵损失(Cross Entropy Loss)

(N)个样本,输入一个(C)分类器,得到的输出为(Xin mathcal{R}^{N imes C}),它共有(C)类;其中某个样本的输出记为(xin mathcal{R}^{1 imes C}),即(x[j])(X)的某个行向量,那么某个交叉熵损失可以写为如下公式

[ ext{loss}left( x, ext{class} ight) =-log left( frac{exp left( xleft[ ext{class} ight] ight)}{sum_j{expleft( xleft[ j ight] ight)}} ight) =-xleft[ ext{class} ight] +log left( sum_j{expleft( xleft[ j ight] ight)} ight) ag{1-1} ]

其中( ext{class}in [0, C))是这个样本的类标签,如果给出了类标签的权重向量(Win mathcal{R}^{1 imes C}),那么带权重的交叉熵损失可以更改为如下公式

[operatorname{loss}(x, ext {class})=W[ ext {class}]left(-x[ ext {class}]+log left(sum_{j} exp (x[j]) ight) ight) ag{1-2} ]

最终对这个(N)个样本的损失求和或者求平均

[ell = egin{cases} sum_{i}^{N}{ ext{loss}(x^{(i)}, ext{class}^{(i)})}& ext{, sum}\ dfrac{1}{N}sum_{i}^{N}{ ext{loss}(x^{(i)}, ext{class}^{(i)})}& ext{, mean} end{cases} ag{1-3} ]

这个就是我们平时经常用到的交叉熵损失了。

1.2 二分类交叉熵损失(Binary Cross Entropy Loss)

上面所提到的交叉熵损失是适用于多分类(二分类及以上)的,但是它的公式看起来似乎与我们平时在书上或论文中看到的不一样,一般我们常见的交叉熵损失公式如下:

[l = -ylog{hat{y}}-(1-y)log{(1-hat{y})} ]

这是一个典型的二分类交叉熵损失,其中(yin{0, 1})表示标签值,(hat{y}in[0, 1])表示分类模型的类别1预测值。上面这个公式是一个综合的公式,它等价于:

[l = egin{cases} -log{hat{y}_0} &y=0 \ -log{hat{y}_1} &y=1 end{cases}; quad ext{where}quad hat{y}_0+hat{y}_1 = 1 ]

其中(hat{y}_0, hat{y}_1)是二分类模型输出的2个伪概率值

例:如果二分类模型是神经网络,且最后一层为: 2个神经元+Softmax,那么(hat{y}_0, hat{y}_1)就对应着这两个神经元的输出值。当然它也可以带上类别的权重。

同样地,有(N)个样本,输入一个2分类器,得到的输出为(Xin mathcal{R}^{N imes 2}),再经过Softmax函数,(hat{Y}=sigma(X)in mathcal{R}^{N imes 2}),标签为(Yin mathcal{R}^{N imes 2}),每个样本的二分类损失记为(l^{(i)}, i=0,1,2,cdots,N),最终对这个(N)个样本的损失求和或者求平均

[ell = egin{cases} sum_{i}^{N}l^{(i)}& ext{, sum}\ dfrac{1}{N}sum_{i}^{N}l^{(i)}& ext{, mean} end{cases}; l^{(i)} = -y^{(i)}log{hat{y}^{(i)}}-(1-y^{(i)})log{(1-hat{y}^{(i)})} ]

注:如果一次只训练一个样本,即(N=1),那么上面带类别权重的损失中的权重是无效的。因为权重是相对的,某一个样本的权重大,那么必然需要有另一个样本的权重小,这样才能体现出这一批样本中某些样本的重要性。(N=1)时,已没有权重的概念,它是唯一的,也是最重要的。(N=1),或者说batch_size=1这种情况在训练视频文章数据时,是会常出现的。由于我们显示/内存的限制,而视频/文章数据又比较大,一次只能训练一个样本,此时我们就需要注意权重的问题了。

2 Focal Loss

2.1 基本思想

一般来讲,Focal Loss(以下简称FL)[1]是为解决样本不平衡的问题,但是更准确地讲,它是为解决难分类样本(Hard Example)易分类样本(Easy Example)的不平衡问题。对于样本不平衡,其实通过上面的带权重的交叉熵损失便可以一定程度上解决这个问题,但是在实际问题中,以权重来解决样本不平衡问题的效果不够理想,此时我们应当思考,表面上我们的样本不平衡,但实质上导致效果不好的原因也许并不是简单地因为样本不平衡,而是因为样本中存在一些Hard Example,同时存在许多Easy Example,Easy Example虽然容易被分类器分辨,损失较小,但是由于其数量大,它们累积起来依然于大于Hard Example的Loss值,因此我们需要给Hard Example较大的权重,而Easy Example较小的权重

那么什么叫Hard Example,什么叫Easy Example呢?看下面的图就知道了。

fig2-1 fig2-2 fig2-3 fig2-4
图2-1 Hard Example 图2-2 Easy Example1 图2-3 Easy Example2 图2-4 Example Space

假设,我们的任务是训练一个分类器,分类出人和马,对于上面的三张图,图2-2和图2-3应该是非常容易判断出来的,但是图2-1就是不那么容易了,它即有人的特征,又有马的特征,非常容易混淆。这种样本虽然在数据集中出现的频率可能并不高,但是想要提高分类器的性能,需要着力解决这种样本分类问题。

提出Hard Example和Easy Example后,可以将样本空间划分为如图2-4所示的样本空间。其中纵轴为多数类样本(Majority Class)少数类样本(Minority Class),上面的带权重的交叉熵损失只能解决Majority Class和Minority Class的样本不平衡问题,并没有考虑Hard Example和Easy Example的问题,Focal Loss的提出就是为解决这个难易样本的分类问题。

2.2 Focal Loss解决方案

要解决难易样本的分类问题,首先就需要找出Hard Example和Easy Example。这对于神经网络来说,应该是一件比较容易的事情。如图2-6所示,这是一个5分类的网络,神经网络的最后一层输出时,加上一个Softmax或者Sigmoid就会得到输出的伪概率值,代表着模型预测的每个类别的概率,

fig2-5 fig2-6
图2-6 Easy Example Classifier Output 图2-7 Hard Example Classifier Output

图2-6中,样本标签为1,分类器输出值最大的为第1个神经元(以0开始计数),这刚好预测准确,而且其输出值2也比其它神经元的输出值要大不少,因此可以认为这是一个易分类样本(Easy Example);图2-7的样本标签是3,分类器输出值最大的为第4个神经元,并且这几个神经元的输出值都相差不大,神经网络无法准确判断这个样本的类别,所以可以认为这是一个难分类样本(Hard Example)。其实说白了,判断Easy/Hard Example的方法就是看分类网络的最后的输出值。如果网络预测准确,且其概率较大,那么这是一个Easy Example,如果网络输出的概率较小,这是一个Hard Example。下面用数学公式严谨地表达来Focal Loss的表达式。

令一个(C)类分类器的输出为(oldsymbol{y}in mathcal{R}^{C imes 1}),定义函数(f)将输出(oldsymbol{y})转为伪概率值(oldsymbol{p}=f(oldsymbol{y})),当前样本的类标签为(t),记(p_t=oldsymbol{p}[t]),它表示分类器预测为(t)类的概率值,再结合上面的交叉熵损失,定义Focal Loss为:

[ ext{FL} = -(1-p_t)log(p_t) ag{2-1} ]

这实质就是交叉熵损失前加了一个权重,只不过这个权重有点不一样的来头。为了更好地控制前面权重的大小,可以给前面的权重系数添加一个指数(gamma),那么更改式(2-1):

[ ext{FL} = -(1-p_t)^gammalog(p_t) ag{2-2} ]

其中(gamma)一值取值为2就好,(gamma)取值为0时与交叉熵损失等价,(gamma)越大,就越抑制Easy Example的损失,相对就会越放大Hard Example的损失。同时为解决样本类别不平衡的问题,可以再给式(2-2)添加一个类别的权重(alpha_t)(这个类别权重上面的交叉熵损失已经实现):

[ ext{FL} = -alpha_t(1-p_t)^gammalog(p_t) ag{2-3} ]

到这里,Focal Loss理论就结束了,非常简单,但是有效。

3 Focal Loss实现(Pytorch)

3.1 交叉熵损失实现(numpy)

为了更好的理解Focal Loss的实现,先理解交叉熵损失的实现,我这里用numpy简单地实现了一下交叉熵损失。

import numpy as np

def cross_entropy(output, target):
    out_exp = np.exp(output)
    out_cls = np.array([out_exp[i, t] for i, t in enumerate(target)])
    ce = -np.log(out_cls / out_exp.sum(1))
    return ce

代码中第5行,可能稍微有点难以理解,它不过是为了找出标签对应的输出值。比如第2个样本的标签值为3,那它分类器的输出应当选择第2行,第3列的值。

3.2 Focal Loss实现

下面的代码的1012行:依据输出,计算概率,再将其转为`focal_weight`;1516行,将类权重和focal_weight添加到交叉熵损失,得到最终的focal_loss;18~21行,实现meansum两种reduction方法,注意求平均不是简单的直接平均,而是加权平均

class FocalLoss(nn.Module):
    def __init__(self, gamma=2, weight=None, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.weight = weight
        self.reduction = reduction

    def forward(self, output, target):
        # convert output to pseudo probability
        out_target = torch.stack([output[i, t] for i, t in enumerate(target)])
        probs = torch.sigmoid(out_target)
        focal_weight = torch.pow(1-probs, self.gamma)

        # add focal weight to cross entropy
        ce_loss = F.cross_entropy(output, target, weight=self.weight, reduction='none')
        focal_loss = focal_weight * ce_loss

        if self.reduction == 'mean':
            focal_loss = (focal_loss/focal_weight.sum()).sum()
        elif self.reduction == 'sum':
            focal_loss = focal_loss.sum()

        return focal_loss

注:上面实现中,output的维度应当满足output.dim==2,并且其形状为(batch_size, C),且target.max()<C

总结

Focal Loss从2017年提出至今,该论文已有2000多引用,足以说明其有效性。其实从本质上讲,它也只不过是给样本重新分配权重,它相对类别权重的分配方法,只不过是将样本空间进行更为细致的划分,从图2-4很容易理解,类别权重的方法,只是将样本空间划分为蓝色线上下两个部分,而加入难易样本的划分,又可以将空间划分为左右两个部分,如此,样本空间便被划分4个部分,这样更加细致。其实借助于这个思想,我们是否可以根据不同任务的需求,更加细致划分我们的样本空间,然后再相应的分配不同的权重呢?

参考文献

[1] Lin, T.-Y., Goyal, P., Girshick, R., He, K., & Dollár, P. (2017). Focal loss for dense object detection. Paper presented at the Proceedings of the IEEE international conference on computer vision.

原文地址:https://www.cnblogs.com/endlesscoding/p/12155588.html