torchScript 认识

什么是 TorchScript?

  • 一种从 PyTorch 代码创建可序列化和可优化模型的方法,在 Python 环境中保存 TorchScript 程序,然后将其应用到没有 Python 依赖项的进程中加载,比如在 C++ 程序中使用
  • 简单来说,TorchScript 软件栈可以将 Python 代码转换成 C++ 代码。TorchScript 软件栈包括两部分:TorchScript(Python)和 LibTorch(C++)。TorchScript 负责将 Python 代码转成一个模型文件,LibTorch 负责解析运行这个模型文件

TorchScript 的两种模式?

  • TorchScript 保存模型有两种模式:trace 模式、script 模式、混合模式
  • Trace 模式:跟踪模型的执行,然后将其路径记录下来。在使用 trace 模式时,需要构造一个符合要求的输入,然后使用 TorchScript tracer 运行一遍,整个运行过程就会被记录下来。不能有 if-else 等控制流,只支持 Tensor 操作。这是因为跟踪的 graph 是静态的,必须执行前确定算子顺序
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
     class Module_0(torch.nn.Module):  
    def __init__(self, N, M):
    super(Module_0, self).__init__()
    self.weight = torch.nn.Parameter(torch.rand(N, M))
    self.linear = torch.nn.Linear(N, M)
    def forward(self, input: torch.Tensor) -> torch.Tensor:
    output = self.weight.mm(input)
    output = self.linear(output)
    return output
    scripted_module = torch.jit.trace(Module_0(2, 3).eval(), (torch.zeros(3, 2)))
    scripted_module.save("Module_0.pt")
  • Script 模式:TorchScript 实现了一个完整的编译器以支持 script 模式。保存模型阶段对应编译器的前端(语法分析、类型检查、中间代码生成)。在保存模型时,TorchScript 编译器解析 Python 代码,并构建代码的 AST(抽象语法树)。不仅支持 if-else 等控制流,还支持非 Tensor 操作,如 List、Tuple、Map 等容器操作
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
     class Module_1(torch.nn.Module):  
    def __init__(self, N, M):
    super(Module_1, self).__init__()
    self.weight = torch.nn.Parameter(torch.rand(N, M))
    self.linear = torch.nn.Linear(N, M)
    def forward(self, input: torch.Tensor, do_linear: bool) -> torch.Tensor:
    output = self.weight.mm(input)
    if do_linear:
    output = self.linear(output)
    return output
    scripted_module = torch.jit.script(Module_1(3, 3).eval())
    scripted_module.save("Module_1.pt")
  • 混合模式:trace 模式和 script 模式各有千秋也各有局限,在使用时将两种模式结合在一起使用可以最大化发挥 TorchScript 的优势。。例如,一个 module 包含控制流,同时也包含一个只有 Tensor 操作的子模型。这种情况下当然可以直接使用 script 模式,但是 script 模式需要对部分变量进行类型标注,比较繁琐。这种情况下就可以仅对上述子模型进行 trace,整体再进行 script
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
     class Module_2(torch.nn.Module):  
    def __init__(self, N, M):
    super(Module_2, self).__init__()
    self.linear = torch.nn.Linear(N, M)
    self.sub_module = torch.jit.trace(Module_0(2, 3).eval(), (torch.zeros(3, 2)))
    def forward(self, input: torch.Tensor, do_linear: bool) -> torch.Tensor:
    output = self.sub_module(input)
    if do_linear:
    output = self.linear(output)
    return output
    scripted_module = torch.jit.script(Module_2(2, 3).eval())

TorchScript 的调试?

  • 禁用 TorchScript(脚本和跟踪):通过在命令行中使用 PYTORCH_JIT=0 禁用了 TorchScript(脚本和跟踪)
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    @torch.jit.script
    def scripted_fn(x : torch.Tensor):
    for i in range(12):
    x = x + x
    return
    def fn(x):
    x = torch.neg(x)
    import pdb; pdb.set_trace()
    return scripted_fn(x)
    traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),))
    traced_fn(torch.rand(3, 4))
    # 启动命令:PYTORCH_JIT=0 python disable_jit_example.py
  • 检查代码:通过 print(xxx.code) 的方式输出代码
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    @torch.jit.script
    def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
    if i < 10:
    rv = rv - 1.0
    else:
    rv = rv + 1.0
    return rv
    print(foo.code)
  • 解释图形:通过 print(xxx.graph) 输出图形
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    @torch.jit.script
    def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
    if i < 10:
    rv = rv - 1.0
    else:
    rv = rv + 1.0
    return rv
    print(foo.graph)
  • 自动跟踪检查:自动捕获跟踪中许多错误
    1
    2
    3
    4
    5
    6
    7
    8
     def loop_in_traced_fn(x):
    result = x[0]
    for i in range(x.size(0)):
    result = result * x[i]
    return result
    inputs = (torch.rand(3, 4, 5),)
    check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]
    traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs)

我想在 GPU 上训练模型并在 CPU 上进行推理。什么是 最佳实践?

  • 首先将模型从 GPU 转换为 CPU,然后保存
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
     cpu_model = gpu_model.cpu()
    sample_input_cpu = sample_input_gpu.cpu()
    traced_cpu = torch.jit.trace(cpu_model, sample_input_cpu)
    torch.jit.save(traced_cpu, "cpu.pt")
    traced_gpu = torch.jit.trace(gpu_model, sample_input_gpu)
    torch.jit.save(traced_gpu, "gpu.pt")
    # ... later, when using the model:
    if use_gpu:
    model = torch.jit.load("gpu.pt")
    else:
    model = torch.jit.load("cpu.pt")
    model(input)

参考:

  1. TorchScript — PyTorch 2.4 documentation
  2. clearhanhui/LearnLibTorch: LibTorch 中文教程。