将 TVM 集成到 PyTorch 中
随着 TVM 不断展示深度学习执行效率的提升,PyTorch 显然可以从直接利用编译器堆栈中获益。PyTorch 的一个主要原则是提供无缝且强大的集成,而不会妨碍用户。为此,PyTorch 现在有了一个官方的基于 TVM 的后端,torch_tvm。
用法很简单
import torch_tvm
torch_tvm.enable()
就是这样!然后,PyTorch 将尝试在其 JIT 编译过程中将所有可以转换的算子转换为已知的 Relay 算子。
背景
与许多其他 ML 框架不同,PyTorch 公开了一个 eager-execution(即时执行)编程接口。这种编程风格避免了基于图的元编程,而专注于以 Pythonic 的方式直接操作 n 维数组(张量)。因此,该框架最初非常适合模型的实验和开发,但不适合自动性能优化或部署。为了利用优化编译器技术,最近 PyTorch 引入了一些重大更改来解决这个问题。
PyTorch 1.0 引入了 PyTorch IR,这是一种特定于 PyTorch 的模型中间表示,类似于 Relay。PyTorch 程序可以通过模型追踪(记录模型的执行过程)或 TorchScript(Python 的一个子集)转换为 IR。新的 TVM 后端将 PyTorch IR 降低为 Relay,并且能够在用户几乎无需参与的情况下透明地提高 PyTorch 的性能。
集成和结果
为了支持 Relay,PyTorch JIT 添加了两个功能:自定义转换 pass 和自定义子图解释器。
当启用 torch_tvm
时,可以转换为 Relay Expr
的 PyTorch IR 子图将被标记为 Relay 兼容。由于 PyTorch IR 并不总是包含形状信息,因此在调用之前,任何子图都无法以有用的方式编译。
在用户调用期间,PyTorch JIT 运行时将确定输入形状信息,并使用新的 Relay C++ 构建系统编译先前标记的子图。编译会根据输入形状进行缓存,以便后续运行。更多详细信息可以在 README 中找到。
torch_tvm
建立了一个持续的基准测试系统,该系统正在监控 ResNet18 在 CPU 上的性能。开箱即用的 TVM 为各种 ResNet 模型提供了超过默认 PyTorch JIT 后端两倍的性能。下图详细说明了在 AWS c5n.4xlarge 实例上使用 16 个线程实现的每秒迭代次数(越大越好)。
这些结果非常令人鼓舞,该项目将继续专注于提高更多模型上的 CPU 推理速度。
未来的工作
目前,PyTorch JIT 做了大量工作来查找其 IR 的纯函数子集以馈送到 Relay。这避免了将别名和控制流信息映射到 Relay 的需要,但并非必要。将更多的 PyTorch IR 映射到 Relay 可能会带来性能提升,这是该项目的一个目标。PyTorch IR 正在快速变化,因为它正在开发中,因此必须谨慎地进行此操作。
将进行更多工作,以确保 PyTorch 和 TVM 代码之间的交接是高效的。这包括统一线程模型、分配器以及减少与将输入复制到 TVM 相关的开销。
教程
如果您已经编写了 PyTorch 模型,最简单的入门方法是使用 torch.jit.trace
,如下所示
import torch_tvm
from your_model import model, inputs
torch_tvm.enable(opt_level=3)
iters = 100
warmup = 10
# Ensure your model is in eval mode and also turn off gradients.
with torch.no_grad():
# Use tuned parameters for better performance.
with autotvm.apply_history_best("test/autotvm_tuning.log"):
# This is where all the compilation happens.
trace_tvm = torch.jit.trace(model, inputs)
# Warmup
for _ in range(warmup):
_ = trace_tvm(*inputs)
# Benchmark
start = time.time()
for _ in range(iters):
_ = trace_tvm(*inputs)
tvm_time = time.time() - start
print("Took {}s to run {} iters".format(tvm_time, iters))
此代码的大部分来自 benchmarks.py。请注意,用于 AVX2 LLVM 编译的调优参数位于 repo 的 test/
文件夹中。
如果您更喜欢直接使用 Relay,可以直接从 PyTorch 函数中提取表达式,可以通过(隐式)追踪或 TorchScript
def add(a, b, c):
return a + b + c
# via tracing
relay_graph = torch_tvm.to_relay(add, inputs)
@torch.jit.script
def mul(a, b, c):
return a * b * c
# via script
relay_graph = torch_tvm.to_relay(mul, inputs)