【ICLR2018】Mixup 解读

Mixup 发表于 ICLR 2018,是一种用于数据增强的算法,通过将不同类之间的样本混合,从而实现训练数据的扩充。

算法的实现方式非常简单,如下:

[hat{x}=lambda x_i + (1-lambda) x_j ]

[hat{y}=lambda y_i + (1-lambda) y_j ]

其中,(x_i)(x_j) 表示输入的样本特征, (y_i)(y_j) 表示 one-hot 形式的标签,(lambda) 是由贝塔分布计算出来的混合系数。

官方实现代码如下,非常容易理解:

def mixup_data(x, y, alpha):
    # 计算lambda的值 
    lam = np.random.beta(alpha, alpha)
    # 取得当前 batch 里的样本数量
    batch_size = x.size()[0]
    # 随机排序
    index = torch.randperm(batch_size).cuda()
    # 随机混合 x 样本,并生成对应标签
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    # 按照公式计算损失
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

训练时,调用代码如下:

inputs, targets_a, targets_b, lam = mixup_data(inputs, targets, alpha)
outputs = net(inputs)
loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)

模型理解与思考: Mixup 在虚拟样本中进行训练(在两个随机样本及标签进行插值构建样本),将该方法集成到现有模型只需要几行代码,并且几乎没有计算开销。

我完成了一个在Google Colab 上利用 Mixup 进行CIFAR10 分类的例子,供感兴趣的朋友参考。链接: https://github.com/OUCTheoryGroup/colab_demo/blob/master/11_MixUp_ICLR2018.ipynb

原文地址:https://www.cnblogs.com/gaopursuit/p/14416459.html