pytorch转onnx常见问题

一、Type Error: Type 'tensor(bool)' of input parameter (121) of operator (ScatterND) in node (ScatterND_128) is invalid

问题
模型转出成功后,用onnxruntime加载,出现不支持参数问题, 这里出现tensor(bool)是因为代码中使用了bool类型的索引

解决措施
索引采用torch.where替代

...
mask = dist < distance
distance[mask] = dist[mask]
...

更改为

distance = torch.where(dist < distance, dist, distance)

二、FAIL : Load model from ./test.onnx failed:Fatal error: ATen is not a registered function/op

问题
模型转出成功后,用onnxruntime加载,出现没有注册的算子

解决措施
torch.onnx.export函数中设置opset_version=12

三、动态输入/输出

有时候输入和输出维度是变化的,这个时候在导出的时候可以添加dynamic_axes参数,并指定哪些参数和维度是动态的。

结果

四、Removing initializer 'bn1.num_batches_tracked'. It is not used by any node and should be removed from the model.

问题
模型转出成功后,用onnxruntime运行出现以上警告

解决措施
对模型进行优化

import onnx
import onnxoptimizer  # pip install onnxoptimizer

onnx_model = onnx.load(onnxfile)
passes = ["extract_constant_to_initializer", "eliminate_unused_initializer"]
optimized_model = onnxoptimizer.optimize(onnx_model, passes)

onnx.save(optimized_model, onnxfile)
原文地址:https://www.cnblogs.com/xiaxuexiaoab/p/15634117.html