半精度训练pytorch+Apex

想起一个关于运维的段子:很多问题可以通过重启解决,想说算法工(diao)程(bao)师(xia)的很多问题可以通过换版本解决。

起因是白嫖到一个tensorflow的架子跑bert,自己花一上午时间搞定了单机多卡训练,之后花了两个下午也没有搞定半精度,症状是不报错,但是显存不降,速度不涨(32G v100)。于是开始怀念我熟悉的pytorch+apex,又断断续续花了两天多的时间把整个训练框架用pytorch实现了一遍,基于huggingface的transformers。

看到单卡loss正常下降就开始了多卡+apex半精度,结果发现fp16O1虽然显存降了,速度却比fp32还要慢2倍多,期间也参考了下其他人遇到的问题,最终怀疑了一下是不是自己的pytorch版本太老,pytorch版本从1.1.0切换到1.5.1,重新编译apex,果然速度上来了....前后版本如下(右侧是正常的,fp16O1速度是左侧版本的10倍),python版本都是3.7.6:

apex按readme quick start安装即可,可能需要指定加载路径。

export PYTHONPATH=/你的apex路径/:$PYTHONPATH

apex半精度训练可以参考这里transformers里面已经调用的很好了,不必自己改什么。

原文地址:https://www.cnblogs.com/zhengmeisong/p/13448015.html