上一篇文章中我们介绍了将PyTorch模型转换到ONNX(Open Neural Network Exchange)格式的一般步骤,这篇文章将开始系统介绍torch格式模型转换成ONNX的原理。
1. PyTorch模型的序列化——TorchScript
TorchScirpt是一种序列化和优化PyTorch模型的方法,会将一个torch.nn.Module
模型转换为torch.jit.ScriptModel
格式的模型,上一篇文章中介绍的PyTorch内部的转换ONNX的接口torch.onnx.export
实际上所传入的就是一个torch.jit.ScriptModel
格式的模型(上一节中之所以直接传入torch.nn.module
,是因为torch.onnx.export
接口会默认将torch.nn.module
隐式的转换为torch.jit.ScriptModel
)。
TorchScript有两种模式:torch.jit.trace
(跟踪模式)、torch.jit.script
(记录模式),在实际应用中,根据需要选择其中一种
1.1 torch.jit.trace
跟踪模式是torch.onnx.export
的默认模式,在使用该模式时,需要传入一个输入对torch模型进行一次推理,并记录整个的计算过程(主要包含用到的算子、输入输出的形状等信息)。即跟踪模式导出的是模型的静态图,无法追踪模型的控制流(条件分支和循环),对于模型中的条件分支,跟踪模式只能记录给定的输入所执行的分支,对于循环控制,也只能记录给定的输入执行的循环次数。
# 跟踪模式导出模型,需要执行推理
model_trace = torch.jit.trace(lm, dummy_input)
torch.onnx.export(model_trace, dummy_input, "./models/model_trace.onnx", output_names=['output'], input_names=['input'])
trace_model_session = ort.InferenceSession("./models/model_trace.onnx")
trace_ort_output = trace_model_session.run(['output'],
{'input': dummy_input.numpy()})
np.allclose(dummy_output.detach().numpy(), trace_ort_output)
1.2 torch.jit.script
记录模式会对给定的torch模型进行解析,真正理解
torch模型的控制流,它导出的是动态图,所以在导出时不需要在torch模型上执行推理。与torch.jit.trace
相比,动态图能够更好的支持带有流程控制语句的模型,但代价是推理速度相较于前者有所降低。在实际的模型部署工作中,尽量使用追踪模式,当追踪模式解决不了实际问题时,再考虑使用记录模式。
# 记录模式导出模型,不需要执行推理
model_script = torch.jit.script(lm)
torch.onnx.export(model_script, dummy_input, "./models/model_script.onnx", output_names=['output'], input_names=['input'])
script_model_session = ort.InferenceSession("./models/model_script.onnx")
script_ort_output = script_model_session.run(['output'],
{'input': dummy_input.numpy()})
np.allclose(dummy_output.detach().numpy(), script_ort_output)
2. torch.onnx.export接口
以下是torch.onnx.export
函数的签名:
def export(model, args, f, export_params=True, verbose=False, training=TrainingMode.EVAL,
input_names=None, output_names=None, aten=False, export_raw_ir=False,
operator_export_type=None, opset_version=None, _retain_param_name=True,
do_constant_folding=True, example_outputs=None, strip_doc_string=True,
dynamic_axes=None, keep_initializers_as_inputs=None, custom_opsets=None,
enable_onnx_checker=True, use_external_data_format=False):
一下是一些常用参数介绍:
-
model
:待转换的TorchScript
模型,也可以直接传入torch.nn.Module
,该接口会将其隐式转换为TorchScript
(torch.jit.trace
模式) -
args
:模型的输入 -
f
:转换后ONNX模型的存储路径 -
export_params
:是否导出模型参数,若为False
,则只导出模型结构,默认为True
,当在不同深度学习框架中传递模型时,可能只需要模型结构,此时可以将此参数设置为False
-
input_names
:输入的变量名,后续推理以及使用推理引擎加载时需要使用,当不设置此参数时,API会自动生成简单的名称 -
output_names
:输出的变量名,作用同上 -
opset_version
:使用的ONNX算子集版本 -
dynamic_axes
:默认情况下,export接口的输入输出的维度、大小都是固定的,但在某些情况下,有些维度的大小是动态的,如:批次大小维度batch_size,此时则需要指定此参数,具体用法将在下一篇文章中介绍。