野路子码农系列(9)利用ONNX加速Pytorch模型推断

最近在做一个文本多分类的模型,非常常规的BERT+finetune的套路,考虑到运行成本,打算GPU训练后用CPU做推断。

在小破本上试了试,发现推断速度异常感人,尤其是序列长度增加之后,一条4-5秒不是梦。

于是只能寻找加速手段,早先听过很多人提到过ONNX,但从来没试过,于是就学习了一下,发现效果还挺不错的,手法其实也很简单,就是有几个小坑。

第1步 - 保存模型

首先得从torch中将模型导出成ONNX格式,可以在cross-validation的eval阶段进行这一步骤:

def eval_fn(data_loader, model, device):
    '此处省略其他代码'
    
    onnx_path = 'inference_model.onnx' # 指定保存路径
    torch.onnx._export(
        model, # BERT fintune model (instance)
        (ids, mask, token_type_ids), # model的输入参数,装入tuple
        onnx_path, # 保存路径
        opset_version=10, # 此处有坑,必须指定≥10,否则会报错
        do_constant_folding=True,
        input_names=['ids', 'mask', 'token_type_ids'], # model输入参数的名称
        output_names=['output'],
        export_params=True,
        dynamic_axes={
            'ids': {0: 'batch_size', 1: 'seq_length'}, # 0, 1分别代表axis 0和axis 1
            'mask': {0: 'batch_size', 1: 'seq_length'},
            'token_type_ids': {0: 'batch_size', 1: 'seq_length'},
            'output': {0: 'batch_size', 1: 'seq_length'}
        } # 用于变长序列(比如dynamic padding)和可能改变batch size的情况
    )
    
    
    return '此处省略返回值'

  

这里需要注意的几个点:

  •  torch自带了导出ONNX的方法,直接用就行
  • 你的模型可以有1个输入参数,也可以有多个,如果有多个,得装在tuple里
  • 相应的input_names要与你的参数一一对应,放在list里
  • opset_version建议设成10,默认不设的话可能会报错(ONNX export of Slice with dynamic inputs)
  • 如果你在data loader里设置了collate func来进行dynamic padding的话(不同batch的文本长度可能不一样),一定要设置dynamic_axes,否则之后加载推断时会出错(因为它会要求你推断时输入的各个维度与你保存ONNX模型时的输入纬度完全一致)。

第2步 - 加载模型与推断

接下来是推断环节,首先别忘了用 pip install onnxpip install onnxruntime 来安装必需的库,之后通过以下代码导入使用:

import onnxruntime as ort

接下来你可以照常写你的dataset和data loader,但需要注意的是,data loader返回的得是numpy.array,而不是torch.tensor(collate_fn里改改就行),否则报错伺候。

然后就是导入模型:

import onnxruntime as ort

onnx_model_path = 'inference_model.onnx' 
session = ort.InferenceSession(onnx_model_path)

再把data loader的输出分别接入对应的三个参数就好了:

session.run(ids, mask, token_type_ids)

%%timeit看一下运行时间(CPU):

4条长度为10的文本

torch:4.77s

torch+ONNX:39.7ms

4条长度为50的文本

torch:21.2s

torch+ONNX:246ms

差不多快了百倍有余,效果相当不错啦。

原文地址:https://www.cnblogs.com/silence-gtx/p/15509545.html