在要求低延迟、脚本部署的生存环境,C++是更被青睐的语言。在之前版本的PyTorch
提供的都是Python
语言的接口,PyTorch1.0版本的发布带来了C++的前端,为生产环境的部署带来了极大的方便。目前PyTorch1.3.0的稳定版本已经可以在官网下载了,本文的所有代码均运行在 cuda 10.1, cudnn 7 LibTorch1.3.0环境下。
1. 将Pytorch Model
转换为Torch Script
首先,C++能够理解的模型不是原生的Python版本的模型,而是需要通过Torch Script
编译和序列化Python模型,C++调用序列化后的模型进行预测。
官方给出了两种将PyTorch
模型转换为Torch Script
的方法,第一种是通过torch.jit.trace
方法,该方法缺点是在forward()
中不能有复杂的条件控制,优点是操作比较方便;第二种是通过torch.jit.script
方法进行转换,该方法允许forward()
中有条件控制。
1.1 通过torch.jit.trace
方法转换
我们先新建一个简单的线性模型
1 | class MyModel(nn.Module): |
接下来我们仅需一行代码即可完成转换
1 | B, N, M = 64, 32, 8 |
traced_script_module
即为转换好的C++需要的模型。可以调用save()
函数保存模型。
1 | traced_script_module.save("model.pt") |
1.2 使用torch.jit.script
方法
在1.0.0
的版本中,该方法需要重新写一个类,这个类继承torch.jit.ScriptModule
,同时对需要使用的方法添加@torch.jit.script_method
装饰器,较为麻烦。在最新的1.3.0
版本中已经简化模型仍然继承的是torch.nn.Module
,且不需要使用装饰器,直接使用torch.jit.script
方法即可,非常方便。
模型代码不变:
1 | class MyModel(nn.Module): |
保存模型之前需要调用torch.jit.script
将模型转换为ScriptModule
:
1 | B, N, M = 64, 32, 8 |
2. C++模型加载和预测
PyTorch的C++接口官方包名为LibTorch
,可以在官网下载,无需编译即可使用。
2.1 模型加载
模型加载代码很简单,官网给的实例如下
1 |
|
首先在c++代码中需要引入torch/script.h
头文件,然后加载模型即可。
2.2 模型预测
模型预测部分涉及到模型加载、输入数据转换、预测、输出数据获取等几个部分。代码中均有详细注释。
1 |
|
3. 编译
官方推荐使用CMake
进行编译,使用非常简单。我们下载的libtorch
目录结构如下
1 | libtorch |
lib/
保存的是动态链接库文件include/
保存的是头文件share/
保存的是一些CMake
必备的配置
在c++代码目录下新建CMakeLists.txt
,填入下面信息即可
1 | cmake_minimum_required(VERSION 3.0 FATAL_ERROR) |
其中myapp
是要编译生成的可执行文件的名字,eval.cpp
是需要编译的源文件,即2.2小节的代码。
运行下面命令
1 | mkdir build |
然后,在当前目录下可以看到example-app
的可执行文件,运行命令即可调用保存好的模型
1 | ./myapp model.pt |
4. 一些解决的问题
4.1 不使用CMake
编译
如果我们不想使用CMake
进行编译,则需要在Makefile
中添加对应的头文件路径、需要链接的库文件。
1 | target = myapp |
5. 其他
在之前 1.0.0
版本中,如果模型中使用了LSTM等循环网络结构,且使用c++加载模型,且在GPU上预测,在调用to()
方法将模型放置到显存中时,init hidden state
并不会自动被放入显存。这个bug在python中也存在,但是在python中可以使用重写父类的to()
或者cuda()
和cpu()
的方法解决,然而,在c++中并不可以。一种非常蠢的方法是在Python中用hard code指定好init hidden state
的显卡位置,c++调用时也使用该显卡进行预测,这样就可以使用GPU计算了。目前在1.3.0
中经过测试,如果默认不指定 init hidden state
,可以正常运行,目测已经fix该bug。
另外,如果在模型中使用了LSTM
等循环网络结构,且使用c++在GPU进行预测,代码退出时会报core dump。目前在1.3.0
中经过测试,退出也不存在core dump问题,目测已经fix该bug。
最近一直在研究使用c++版本的PyTorch进行强化学习训练,发现相比较TensorFlow来说,PyTorch的C++接口更加简单易懂,且与Python版本相似度很高,基本可以满足日常的模型训练需要。所以,在使用c++时,也可以不使用上述的ScriptModule
进行模型加载,可以直接使用c++进行模型存储,再通过c++加载模型完成预测。
不过,PyTorch的C++文档比起Python文档,基本和没有文档一样,可以通过查看相应的头文件了解相关接口含义。大家感兴趣可以试一试。