在要求低延迟、脚本部署的生存环境,C++是更被青睐的语言。在之前版本的PyTorch提供的都是Python语言的接口,PyTorch1.0版本的发布带来了C++的前端,为生产环境的部署带来了极大的方便。目前PyTorch1.3.0的稳定版本已经可以在官网下载了,本文的所有代码均运行在 cuda 10.1, cudnn 7 LibTorch1.3.0环境下。
1. 将Pytorch Model转换为Torch Script
首先,C++能够理解的模型不是原生的Python版本的模型,而是需要通过Torch Script编译和序列化Python模型,C++调用序列化后的模型进行预测。
官方给出了两种将PyTorch模型转换为Torch Script的方法,第一种是通过torch.jit.trace方法,该方法缺点是在forward()中不能有复杂的条件控制,优点是操作比较方便;第二种是通过torch.jit.script方法进行转换,该方法允许forward()中有条件控制。
1.1 通过torch.jit.trace方法转换
我们先新建一个简单的线性模型
1 | class MyModel(nn.Module): |
接下来我们仅需一行代码即可完成转换
1 | B, N, M = 64, 32, 8 |
traced_script_module即为转换好的C++需要的模型。可以调用save()函数保存模型。
1 | traced_script_module.save("model.pt") |
1.2 使用torch.jit.script方法
在1.0.0的版本中,该方法需要重新写一个类,这个类继承torch.jit.ScriptModule,同时对需要使用的方法添加@torch.jit.script_method装饰器,较为麻烦。在最新的1.3.0版本中已经简化模型仍然继承的是torch.nn.Module,且不需要使用装饰器,直接使用torch.jit.script方法即可,非常方便。
模型代码不变:
1 | class MyModel(nn.Module): |
保存模型之前需要调用torch.jit.script将模型转换为ScriptModule:
1 | B, N, M = 64, 32, 8 |
2. C++模型加载和预测
PyTorch的C++接口官方包名为LibTorch,可以在官网下载,无需编译即可使用。
2.1 模型加载
模型加载代码很简单,官网给的实例如下
1 |
|
首先在c++代码中需要引入torch/script.h头文件,然后加载模型即可。
2.2 模型预测
模型预测部分涉及到模型加载、输入数据转换、预测、输出数据获取等几个部分。代码中均有详细注释。
1 |
|
3. 编译
官方推荐使用CMake进行编译,使用非常简单。我们下载的libtorch目录结构如下
1 | libtorch |
lib/保存的是动态链接库文件include/保存的是头文件share/保存的是一些CMake必备的配置
在c++代码目录下新建CMakeLists.txt,填入下面信息即可
1 | cmake_minimum_required(VERSION 3.0 FATAL_ERROR) |
其中myapp是要编译生成的可执行文件的名字,eval.cpp是需要编译的源文件,即2.2小节的代码。
运行下面命令
1 | mkdir build |
然后,在当前目录下可以看到example-app的可执行文件,运行命令即可调用保存好的模型
1 | ./myapp model.pt |
4. 一些解决的问题
4.1 不使用CMake编译
如果我们不想使用CMake进行编译,则需要在Makefile中添加对应的头文件路径、需要链接的库文件。
1 | target = myapp |
5. 其他
在之前 1.0.0版本中,如果模型中使用了LSTM等循环网络结构,且使用c++加载模型,且在GPU上预测,在调用to()方法将模型放置到显存中时,init hidden state并不会自动被放入显存。这个bug在python中也存在,但是在python中可以使用重写父类的to()或者cuda()和cpu()的方法解决,然而,在c++中并不可以。一种非常蠢的方法是在Python中用hard code指定好init hidden state的显卡位置,c++调用时也使用该显卡进行预测,这样就可以使用GPU计算了。目前在1.3.0中经过测试,如果默认不指定 init hidden state,可以正常运行,目测已经fix该bug。
另外,如果在模型中使用了LSTM等循环网络结构,且使用c++在GPU进行预测,代码退出时会报core dump。目前在1.3.0中经过测试,退出也不存在core dump问题,目测已经fix该bug。
最近一直在研究使用c++版本的PyTorch进行强化学习训练,发现相比较TensorFlow来说,PyTorch的C++接口更加简单易懂,且与Python版本相似度很高,基本可以满足日常的模型训练需要。所以,在使用c++时,也可以不使用上述的ScriptModule进行模型加载,可以直接使用c++进行模型存储,再通过c++加载模型完成预测。
不过,PyTorch的C++文档比起Python文档,基本和没有文档一样,可以通过查看相应的头文件了解相关接口含义。大家感兴趣可以试一试。