libtorch
pytorch是一个强大的机器学习库,其中集成了很多方法,但从python本身角度讲,它的速度还不够快。用pytorch官网的话说:
虽然对于许多需要动态性和易迭代性的场景来说,Python是一种合适且首选的语言,但在同样的情况下,Python的这些特性恰恰是不利的。它常常应用于生产环境,这是一个低延迟和有严格部署要求的领域,一般选择C++。
打ACM的时候,也是用C++写一些比较恶心的算法居多,于是我也大胆地尝试了用C++加载pytorch模型。
官网教程
下面是官网的教程。
将Pytorch模型转化为Torch Script
Torch Script可以完好的表达pytorch模型,而且也能被C++头文件所理解。有两种方法可以将pytorch模型转换成TorchScript,Tracing和显式注释。有关这两种方法的详细使用请参考Torch Script reference。
Tracing
这种方法需要你给模型传入一个sample input
,它会跟踪在模型的forward
方法中的过程。
import torch
import torchvision
# An instance of your model.
model = torchvision.models.resnet18()
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)
Annotation
这种方法适合模型的forward
方法不是那么显而易见的,而需要一些流程控制的。
class MyModule(torch.nn.Module):
def __init__(self, N, M):
super(MyModule, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
def forward(self, input):
if input.sum() > 0:
output = self.weight.mv(input)
else:
output = self.weight + input
return output
my_module = MyModule(10,20)
sm = torch.jit.script(my_module)
如果你想在模型的forward
方法中使用TorchScrip尚不支持的一些python特性,使用@torch.jit.ignore
来修饰。
导出TorchScript模型
就一句话:
traced_script_module.save("traced_resnet_model.pt")
这样就和python没啥关系了,开始大大方方的搞C++。
在C++中加载Model
一个最简单的C++应用
下面代码保存为example-app.cpp
#include <torch/script.h> // One-stop header.
#include <iostream>
#include <memory>
int main(int argc, const char* argv[]) {
if (argc != 2) {
std::cerr << "usage: example-app <path-to-exported-script-module>
";
return -1;
}
torch::jit::script::Module module;
try {
// Deserialize the ScriptModule from a file using torch::jit::load().
module = torch::jit::load(argv[1]);
}
catch (const c10::Error& e) {
std::cerr << "error loading the model
";
return -1;
}
std::cout << "ok
";
}
在这个最简单的C++程序中,torch::jit::load()
函数用来加载模形,参数为模型文件名,返回torch::jit::script::Module
类,<torch/script.h>
头文件包含了需要的类和方法,下面我们来下载这个文件。
导入依赖
libtorch
文件可以从官网下载,下载好对应版本后解压缩,结构如下
libtorch/
bin/
include/
lib/
share/
CMakeLists.txt
的写法官网也有
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(custom_ops)
find_package(Torch REQUIRED)
add_executable(example-app example-app.cpp)
target_link_libraries(example-app "${TORCH_LIBRARIES}")
set_property(TARGET example-app PROPERTY CXX_STANDARD 14)
项目结构:
example-app/
CMakeLists.txt
example-app.cpp
在bash中执行编译命令(你需要有cmake工具,可以去官网下载,注意版本要和满足CMakeLists.txt
中VERSION
的条件)
mkdir build
cd build
cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
cmake --build . --config Release
一切顺利的话你应该可以在build
文件夹中看到你的程序的可执行文件了。
$ ./example-app
<path_to_model>/traced_resnet_model.pt
ok
运行模型
是时候你去掌握一些libtorch的用法了,模型已经导入成功,使用libtorch中的一些方法,你就可以像在python中一样去跑你的模型了。
// Create a vector of inputs.
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1, 3, 224, 224}));
// Execute the model and turn its output into a tensor.
at::Tensor output = module.forward(inputs).toTensor();
std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '
';
我的实践
我在Ubuntu16CentOS7Windows10上都试了,只有Ubuntu一次成功了。
由于我以前用C++都是写ACM的小代码(with vim),所以调试起来都是g++
到底的,没有用过CMake,简单学习了一下。在工程中,使用C++开发不可能每个单独去编译,这是十分低效的,cmake可以快速的生成开发文件,节省开发者时间。
基本用法网上比较统一,我不太懂其中原理,也没什么新意,就不多说了。
Ubuntu
我在Ubuntu16.04环境下下载了官网的相应版本。然后完全按照流程进行,可以参考这里。
wget https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-1.4.0%2Bcpu.zip
我修改了CMakeLists.txt
,本想这样不用在VScode的CMake工具中再配置什么其他参数,但是后来还是手动cmake的。
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(custom_ops)
set(CMAKE_PREFIX_PATH ${CMAKE_PREFIX_PATH} "/libtorch") # added this, '/libtorch' is where the file unzipped
find_package(Torch REQUIRED)
add_executable(example-app example-app.cpp)
target_link_libraries(example-app "${TORCH_LIBRARIES}")
set_property(TARGET example-app PROPERTY CXX_STANDARD 11)
遇到问题
libtorch的系统问题
之前在Ubuntu下用了一个错误的libtorch,cmake没有问题,make的时候全是错误。
cuda版本问题
我在windows下用的cuda10,然后去Ubuntu的时候还是选的cuda10,实际上那是一台不可用cuda的虚拟机,后来版本换成None就对了。这个错误比较明显,因为它会提示与gpu有关的错误。
cmake版本问题
如大家所见,官网给的要求CMake版本在3.0及以上,我手动把CMakeLists.txt
的版本要求降低到了已有的2.8,报错!升级之后就可以了。
g++版本问题
在我的CentOS7服务器上,yum install gcc-g++
,安装的版本是4.*,也会产生错误,关于g++升级问题,切记千万不要手贱先去去除软连接(即不要随便执行rm /usr/bin/g++ & rm /usr/bin/gcc
)。
官方文档
- The Torch Script reference: https://pytorch.org/docs/master/jit.html
- The PyTorch C++ API documentation: https://pytorch.org/cppdocs/
- The PyTorch Python API documentation: https://pytorch.org/docs/
本文如有描述不当之处,欢迎指出!