Sharpness-Aware Minimization for Efficiently Improving Generalization

Foret P., Kleiner A., Mobahi H., Neyshabur B. Sharpness-aware minimization for efficiently improving generalization. In International Conference on Learning Representations.

在训练的时候对权重加扰动能增强泛化性.

主要内容

如上图所示, 一般的训练方法虽然能够收敛到一个不错的局部最优点, 但是往往这个局部最优点附近是非常不光滑的, 即对权重(w)添加微小的扰动(w+epsilon) 可能就会导致不好的结果, 作者认为这与模型的泛化性有很大关系(实际上已有别的文章提出这一观点).

作者给出如下的理论分析:

在满足一定条件下有

[L_{mathscr{D}} (w) le max_{|epsilon |_2 le ho} L_{mathcal{S}} (w + epsilon) + h(|w|_2^2/ ho^2). ]

其中(h)是一个严格单调递增函数, (L_{mathcal{S}})是在训练集(mathcal{S})上的损失,

[L_{mathscr{D}}(w) = mathbb{E}_{(x, y) sim mathscr{D}} [l(x, y;w)]. ]

如果把(h(|w|_2^2/ ho^2))看成(lambda |w|_2^2)(即常用的weight decay), 我们的目标函数可以认为是

[min_w L_{mathcal{S}}^{SAM} (w) + lambda |w|_2^2, ]

[L_{mathcal{S}}^{SAM}(w) := max_{|epsilon |_p le ho} L_{mathcal{S}} (w + epsilon), ]

注: 这里(|cdot |_p)而并不仅限于(|cdot |_2).

采用近似的方法求解上面的问题(就和对抗样本一样):

[epsilon^* (w) := mathop{arg max} limits_{|epsilon|_ple ho} L_{mathcal{S}}(w + epsilon) approx mathop{arg max} limits_{|epsilon|_ple ho} L_{mathcal{S}}(w) + epsilon^T abla_w L_{mathcal{S}}(w) = mathop{arg max} limits_{|epsilon|_ple ho} epsilon^T abla_w L_{mathcal{S}}(w). ]

就是一个对偶范数的问题.

虽然(epsilon^*(w))实际上是和(w)有关的, 但是在实际中只是当初普通的量带入, 这样就不用计算二阶导数了, 即

[ abla_w L_{mathcal{S}}^{SAM}(w) approx abla_w L_{mathcal{S}}(w) |_{w + hat{epsilon}(w)}. ]

实验结果非常好, 不仅能够提高普通的正确率, 在标签受到污染的情况下也能有很好的鲁棒性.

代码

原文代码

原文地址:https://www.cnblogs.com/MTandHJ/p/14955482.html