【DAIS】2020-arxiv-DAIS: Automatic Channel Pruning via Differentiable Annealing Indicator Search-论文阅读

DAIS: Automatic Channel Pruning via Differentiable Annealing Indicator Search

2020-arxiv-DAIS: Automatic Channel Pruning via Differentiable Annealing Indicator Search

来源:ChenBong 博客园

  • Institute:Peking University,Didi Chuxing,Northeastern University
  • Author:Yushuo Guan,Kaigui Bian,Zhengping Che,Yanzhi Wang
  • GitHub:/
  • Citation:/

Introduction

image-20201124183013197

给每个 output channel 附加一个辅助参数 α,用于生成该 channel 的 indicator(其实就是mask,0~1)

设计了 3 个训练损失项,分别用于:

  1. indicator 约束,使 indicator 稀疏化(二值化)
  2. FLOPs 约束,约束到目标 FLOPs
  3. Symmetry 约束,专门用于有残差连接的网络,保持残差块的输出通道剪枝率相同

分为 search stage 和 fine-tune stage, search stage 网络逐步收敛到 compact 结构,随后固定网络结构进行fine-tune

借鉴NAS中可微分的方法(DARTS),使用梯度下降更新 channel 的 indicator

设计 annealing function(温度系数 t),使得随着训练(搜索)的进行,indicator 根据FLOPs要求,逐步收敛到onehot(0 or 1),即可得到 pruned model

Motivation

  1. 之前的工作 【CNN-FCF (Li et al. 2019)】,也有使用二值化的 channel indicator 来决定是否剪掉某个 channel,但不可微,需要额外的优化工具(如ADMM)
  2. 之前的自动化通道剪枝【TAS (Dong and Yang 2019)】采用在spernet中搜索 pruned model,但存在候选剪枝子网和 spuernet 之间的 gap 的问题 &&

Contribution

  1. 在搜索阶段,使用 gradient-based bi-level optimization(和DARTs类似),使用梯度下降交替更新网络权重W 和 辅助参数 α
  2. 在搜索阶段,设计 annealing function,使得 indicator 逐步收敛到0/1,得到 pruned model
  3. 在搜索阶段,设计了3种约束

Method

bi-level optimization 优化目标

image-20201124183128721

  • W,α 交替更新
  • 在训练集上更新W,在验证集上更新 α,避免对训练集的过拟合
    • refer:
    • image-20201124193640335

Annealing Indicator 的设计

(l) 层的第 (i) 个 output channel 的 indicator: ( ilde{I}_{l}^{i}, i inleft[1, c_{l} ight])( ilde{I}_{l}^{i}) 的值由 (α^i_l) 决定,且取值范围在 ( ilde{I}_{l}^{i} in [0,1])

简单的归一化设计:

( ilde{I}_{l}^{i}=frac{1}{1+e^{-alpha_{i}^{i}}} qquad (6))

加上退火策略的归一化设计:

( ilde{I}_{l}^{i}=H_{T}left(alpha_{l}^{i} ight)=frac{1}{1+e^{-alpha_{l}^{i} / T}}, quad I_{l}^{i}=lim _{T ightarrow 0} H_{T}left(alpha_{l}^{i} ight) qquad (7))

初始时,温度系数T最大,(T=T_0) ;之后逐渐变小,T采用退火策略 (T=T_0/σ(n)) 逐渐趋于0,n为搜索阶段的epoch数

三种 Regularization

Lasso regularizer

image-20201124185237215

  • indicator 的 (l_1 norm)

Continuous FLOPs estimator regularizer

image-20201124185141872

  • (FLOP_l=(h×w)×(k^2×c_{in})×c_{out})

Symmetry regularizer

  • 对 residual block output channel 的修剪,会导致同一个 stage 的 residual block 的 output channel 不匹配
  • 以往的工作:
    • 要么不对 residual block 的 output channel 进行修剪,
    • 要么修剪后导致同一个 stage 的 residual block 的 output channel 不匹配
      • 直接抛弃残差连接
      • 使用 1×1 卷积重新进行通道对齐
    • (其实还有对 同一个 stage 的 residual block 的 output channel 使用相同的剪枝率的办法,类似于下面的Constrained)
  • 这些替代的办法会打断原有的梯度传播,导致梯度爆炸或梯度消失,导致性能下降,作者做了一个实验来证明

image-20201124201418721

  • Random 指从 ResNet-110 中随机采样子网,如果 residual block 的input channel 和 output channel 不同,则抛弃该 block 的残差连接
  • Constraint 指从 ResNet-110 采样子网,并确保每个 stage 的 residual block 的output channel 相同(可以保留所有残差连接)

image-20201124185209001

  • 只对有残差连接的网络使用,确保 residual block 的 (c_{in}=c_{out})

Experiments

Setup

Search Stage

  • use (R_{FLOPs}) and (R_{sym}) as the default regularizers
  • Search Epochs
    • CIFAR:Search 100 epochs,按 7:3 划分 train set, val set
    • ImageNet:Search 7508 iterations
  • α initialized: (alpha in mathcal{N}left(1,0.1^{2} ight))
  • (T_0=1, T_n=T_0/σ(n))(σ(n) = 49×n/N_{max}+1)(N_{max}) denotes the total number of training epochs.
  • The weights of (R_{FLOPs}) is 2 and (epsilon) = 0.05.
  • The weights of (R_{sym}) is 0.01 for ResNet-56/110 and 0 otherwise.

Fine-tuning Stage

  • Train Epochs

    • CIFAR:300 epochs
    • ImageNet:120 epochs
  • cos lr

CIFAR-10/100

image-20201124185810730

image-20201124185929546

ImageNet

image-20201124185849020

  • 最后一列加速比是使用PyTorch Mobile在Salaxy S9手机上得到的

Ablation Study

search methods

image-20201124191040891

  • Slimming:使用BN层的值作为 Indicator
  • Random:从 ResNet-110 中随机采样子网,如果 residual block 的input channel 和 output channel 不同,则抛弃该 block 的残差连接
  • Constraint:从 ResNet-110 采样子网,并确保每个 stage 的 residual block 的output channel 相同(可以保留所有残差连接)

The effectiveness of (R_{FLOPs}) and (R_{sym})

image-20201124191040891

  • (w/o R_{FLOPs}) : replaces (R_{FLOPs}) by (R_{lasso})
  • (w/o R_{sym}) :removes the symmetry regularizer
    • image-20201124192107185
    • symmetry regularizer 即约束同一个 stage 的 block 的 output channel 相同

The effectiveness of the annealing-relaxed function

image-20201124191040891

  • w/o annealing​:removing the annealing-relaxed function
    • 移除退火策略,即使用简单的归一化 ( ilde{I}_{l}^{i}=frac{1}{1+e^{-alpha_{i}^{i}}} qquad (6)) ,这时候 ( ilde{I}_{l}^{i}) 不会收敛到 [0~1] ,需要引入阈值,将低于阈值的 output channel(filter)剪掉: ( ilde{I}_{l}^{i}<0.55)

The effectiveness of the bi-level optimization

image-20201124191040891

  • w/o bi-level:同时在训练集上更新 W 和 α

Robustness of DAIS

image-20201124191054009

  • the impact of a shorter training scheme:
    • e50:search 50 epochs
  • the impact of different temperature decay scheme:
    • cosine(σ(n) = 49×(1−cos(frac{π}{2}n/N_{max}))+1)
    • smallT(σ(n) = 99×n/N_{max}+1)

One-shot capability of DAIS

image-20201124191109538

  • 原始方法的FLOPs剪枝率是在 search stage 全程固定的,search 完以后再进行fine-tune
  • 这里改为在 search stage 逐步提高剪枝率?

Recoverable Pruning

image-20201124191123859

  • 被剪掉的channel可能会重新恢复,自我调整能力

Conclusion

Summary

  • 可微分的mask
  • 逐渐 onehot 的mask
  • 连续的 FLOPs 估计

To Read

Reference

原文地址:https://www.cnblogs.com/chenbong/p/14047882.html