IRModule
本教程介绍了 Apache TVM Unity 的核心抽象概念,即 IRModule。IRModule 包含了 ML 模型的**整体**,包括计算图、张量程序以及对外部库的潜在调用。
import numpy as np
import tvm
from tvm import relax
创建 IRModule
IRModule 可以通过多种方式初始化。我们在下面演示其中的几种。
import torch
from torch import nn
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program
从现有模型导入
初始化 IRModule 最常见的方式是从现有模型导入。Apache TVM Unity 支持从各种框架导入,例如 PyTorch 和 ONNX。本教程仅演示从 PyTorch 导入的过程。
# Create a dummy model
class TorchModel(nn.Module):
def __init__(self):
super(TorchModel, self).__init__()
self.fc1 = nn.Linear(784, 256)
self.relu1 = nn.ReLU()
self.fc2 = nn.Linear(256, 10)
def forward(self, x):
x = self.fc1(x)
x = self.relu1(x)
x = self.fc2(x)
return x
# Give an example argument to torch.export
example_args = (torch.randn(1, 784, dtype=torch.float32),)
# Convert the model to IRModule
with torch.no_grad():
exported_program = export(TorchModel().eval(), example_args)
mod_from_torch = from_exported_program(
exported_program, keep_params_as_input=True, unwrap_unit_return_tuple=True
)
mod_from_torch, params_from_torch = relax.frontend.detach_params(mod_from_torch)
# Print the IRModule
mod_from_torch.show()
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
@R.function
def main(x: R.Tensor((1, 784), dtype="float32"), p_fc1_weight: R.Tensor((256, 784), dtype="float32"), p_fc1_bias: R.Tensor((256,), dtype="float32"), p_fc2_weight: R.Tensor((10, 256), dtype="float32"), p_fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
lv: R.Tensor((784, 256), dtype="float32") = R.permute_dims(p_fc1_weight, axes=None)
lv1: R.Tensor((1, 256), dtype="float32") = R.matmul(x, lv, out_dtype="float32")
lv2: R.Tensor((1, 256), dtype="float32") = R.add(lv1, p_fc1_bias)
lv3: R.Tensor((1, 256), dtype="float32") = R.nn.relu(lv2)
lv4: R.Tensor((256, 10), dtype="float32") = R.permute_dims(p_fc2_weight, axes=None)
lv5: R.Tensor((1, 10), dtype="float32") = R.matmul(lv3, lv4, out_dtype="float32")
lv6: R.Tensor((1, 10), dtype="float32") = R.add(lv5, p_fc2_bias)
gv: R.Tensor((1, 10), dtype="float32") = lv6
R.output(gv)
return gv
使用 Relax NN 模块编写
Apache TVM Unity 还提供了一组类似 PyTorch 的 API,以帮助用户直接编写 IRModule。
from tvm.relax.frontend import nn
class RelaxModel(nn.Module):
def __init__(self):
super(RelaxModel, self).__init__()
self.fc1 = nn.Linear(784, 256)
self.relu1 = nn.ReLU()
self.fc2 = nn.Linear(256, 10)
def forward(self, x):
x = self.fc1(x)
x = self.relu1(x)
x = self.fc2(x)
return x
mod_from_relax, params_from_relax = RelaxModel().export_tvm(
{"forward": {"x": nn.spec.Tensor((1, 784), "float32")}}
)
mod_from_relax.show()
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
@R.function
def forward(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
permute_dims: R.Tensor((784, 256), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
matmul: R.Tensor((1, 256), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
add: R.Tensor((1, 256), dtype="float32") = R.add(matmul, fc1_bias)
relu: R.Tensor((1, 256), dtype="float32") = R.nn.relu(add)
permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
add1: R.Tensor((1, 10), dtype="float32") = R.add(matmul1, fc2_bias)
gv: R.Tensor((1, 10), dtype="float32") = add1
R.output(gv)
return gv
通过 TVMScript 创建
TVMScript 是一个基于 Python 的 DSL,用于 IRModule。我们能够直接以 TVMScript 语法输出 IRModule,或者,解析 TVMScript 以获得 IRModule。
from tvm.script import ir as I
from tvm.script import relax as R
@I.ir_module
class TVMScriptModule:
@R.function
def main(
x: R.Tensor((1, 784), dtype="float32"),
fc1_weight: R.Tensor((256, 784), dtype="float32"),
fc1_bias: R.Tensor((256,), dtype="float32"),
fc2_weight: R.Tensor((10, 256), dtype="float32"),
fc2_bias: R.Tensor((10,), dtype="float32"),
) -> R.Tensor((1, 10), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
permute_dims = R.permute_dims(fc1_weight, axes=None)
matmul = R.matmul(x, permute_dims, out_dtype="void")
add = R.add(matmul, fc1_bias)
relu = R.nn.relu(add)
permute_dims1 = R.permute_dims(fc2_weight, axes=None)
matmul1 = R.matmul(relu, permute_dims1, out_dtype="void")
add1 = R.add(matmul1, fc2_bias)
gv = add1
R.output(gv)
return gv
mod_from_script = TVMScriptModule
mod_from_script.show()
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
@R.function
def main(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
permute_dims: R.Tensor((784, 256), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
matmul: R.Tensor((1, 256), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
add: R.Tensor((1, 256), dtype="float32") = R.add(matmul, fc1_bias)
relu: R.Tensor((1, 256), dtype="float32") = R.nn.relu(add)
permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
add1: R.Tensor((1, 10), dtype="float32") = R.add(matmul1, fc2_bias)
gv: R.Tensor((1, 10), dtype="float32") = add1
R.output(gv)
return gv
IRModule 的属性
IRModule 是函数的集合,通过 GlobalVars 索引。
mod = mod_from_torch
print(mod.get_global_vars())
[I.GlobalVar("main")]
我们可以通过使用 GlobalVars 或其名称索引来访问 IRModule 中的函数
# from tvm.script import relax as R
@R.function
def main(x: R.Tensor((1, 784), dtype="float32"), p_fc1_weight: R.Tensor((256, 784), dtype="float32"), p_fc1_bias: R.Tensor((256,), dtype="float32"), p_fc2_weight: R.Tensor((10, 256), dtype="float32"), p_fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
lv: R.Tensor((784, 256), dtype="float32") = R.permute_dims(p_fc1_weight, axes=None)
lv1: R.Tensor((1, 256), dtype="float32") = R.matmul(x, lv, out_dtype="float32")
lv2: R.Tensor((1, 256), dtype="float32") = R.add(lv1, p_fc1_bias)
lv3: R.Tensor((1, 256), dtype="float32") = R.nn.relu(lv2)
lv4: R.Tensor((256, 10), dtype="float32") = R.permute_dims(p_fc2_weight, axes=None)
lv5: R.Tensor((1, 10), dtype="float32") = R.matmul(lv3, lv4, out_dtype="float32")
lv6: R.Tensor((1, 10), dtype="float32") = R.add(lv5, p_fc2_bias)
gv: R.Tensor((1, 10), dtype="float32") = lv6
R.output(gv)
return gv
IRModule 的转换
转换是 Apache TVM Unity 的重要组成部分。一个转换接受一个 IRModule 并输出另一个 IRModule。我们可以将一系列转换应用于 IRModule 以获得新的 IRModule。这是优化模型的常用方法。
在本入门教程中,我们仅演示如何将转换应用于 IRModule。有关每个转换的详细信息,请参阅 Transformation API 参考
我们首先将 **LegalizeOps** 转换应用于 IRModule。此转换会将 Relax 模块转换为混合阶段,同一模块中同时包含 Relax 和 TensorIR 函数。同时,Relax 运算符将被转换为 call_tir
。
mod = mod_from_torch
mod = relax.transform.LegalizeOps()(mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@T.prim_func(private=True)
def add(lv1: T.Buffer((T.int64(1), T.int64(256)), "float32"), p_fc1_bias: T.Buffer((T.int64(256),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(256)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(1), T.int64(256)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(lv1[v_ax0, v_ax1], p_fc1_bias[v_ax1])
T.writes(T_add[v_ax0, v_ax1])
T_add[v_ax0, v_ax1] = lv1[v_ax0, v_ax1] + p_fc1_bias[v_ax1]
@T.prim_func(private=True)
def add1(lv5: T.Buffer((T.int64(1), T.int64(10)), "float32"), p_fc2_bias: T.Buffer((T.int64(10),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(10)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(1), T.int64(10)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(lv5[v_ax0, v_ax1], p_fc2_bias[v_ax1])
T.writes(T_add[v_ax0, v_ax1])
T_add[v_ax0, v_ax1] = lv5[v_ax0, v_ax1] + p_fc2_bias[v_ax1]
@T.prim_func(private=True)
def matmul(x: T.Buffer((T.int64(1), T.int64(784)), "float32"), lv: T.Buffer((T.int64(784), T.int64(256)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(256)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for i0, i1, k in T.grid(T.int64(1), T.int64(256), T.int64(784)):
with T.block("matmul"):
v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
T.reads(x[v_i0, v_k], lv[v_k, v_i1])
T.writes(matmul[v_i0, v_i1])
with T.init():
matmul[v_i0, v_i1] = T.float32(0.0)
matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + x[v_i0, v_k] * lv[v_k, v_i1]
@T.prim_func(private=True)
def matmul1(lv3: T.Buffer((T.int64(1), T.int64(256)), "float32"), lv4: T.Buffer((T.int64(256), T.int64(10)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(10)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for i0, i1, k in T.grid(T.int64(1), T.int64(10), T.int64(256)):
with T.block("matmul"):
v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
T.reads(lv3[v_i0, v_k], lv4[v_k, v_i1])
T.writes(matmul[v_i0, v_i1])
with T.init():
matmul[v_i0, v_i1] = T.float32(0.0)
matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + lv3[v_i0, v_k] * lv4[v_k, v_i1]
@T.prim_func(private=True)
def relu(lv2: T.Buffer((T.int64(1), T.int64(256)), "float32"), compute: T.Buffer((T.int64(1), T.int64(256)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for i0, i1 in T.grid(T.int64(1), T.int64(256)):
with T.block("compute"):
v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
T.reads(lv2[v_i0, v_i1])
T.writes(compute[v_i0, v_i1])
compute[v_i0, v_i1] = T.max(lv2[v_i0, v_i1], T.float32(0.0))
@T.prim_func(private=True)
def transpose(p_fc1_weight: T.Buffer((T.int64(256), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(256)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(784), T.int64(256)):
with T.block("T_transpose"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(p_fc1_weight[v_ax1, v_ax0])
T.writes(T_transpose[v_ax0, v_ax1])
T_transpose[v_ax0, v_ax1] = p_fc1_weight[v_ax1, v_ax0]
@T.prim_func(private=True)
def transpose1(p_fc2_weight: T.Buffer((T.int64(10), T.int64(256)), "float32"), T_transpose: T.Buffer((T.int64(256), T.int64(10)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(256), T.int64(10)):
with T.block("T_transpose"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(p_fc2_weight[v_ax1, v_ax0])
T.writes(T_transpose[v_ax0, v_ax1])
T_transpose[v_ax0, v_ax1] = p_fc2_weight[v_ax1, v_ax0]
@R.function
def main(x: R.Tensor((1, 784), dtype="float32"), p_fc1_weight: R.Tensor((256, 784), dtype="float32"), p_fc1_bias: R.Tensor((256,), dtype="float32"), p_fc2_weight: R.Tensor((10, 256), dtype="float32"), p_fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
R.func_attr({"num_input": 1})
cls = Module
with R.dataflow():
lv = R.call_tir(cls.transpose, (p_fc1_weight,), out_sinfo=R.Tensor((784, 256), dtype="float32"))
lv1 = R.call_tir(cls.matmul, (x, lv), out_sinfo=R.Tensor((1, 256), dtype="float32"))
lv2 = R.call_tir(cls.add, (lv1, p_fc1_bias), out_sinfo=R.Tensor((1, 256), dtype="float32"))
lv3 = R.call_tir(cls.relu, (lv2,), out_sinfo=R.Tensor((1, 256), dtype="float32"))
lv4 = R.call_tir(cls.transpose1, (p_fc2_weight,), out_sinfo=R.Tensor((256, 10), dtype="float32"))
lv5 = R.call_tir(cls.matmul1, (lv3, lv4), out_sinfo=R.Tensor((1, 10), dtype="float32"))
lv6 = R.call_tir(cls.add1, (lv5, p_fc2_bias), out_sinfo=R.Tensor((1, 10), dtype="float32"))
gv: R.Tensor((1, 10), dtype="float32") = lv6
R.output(gv)
return gv
转换后,模块内部的函数会多得多。让我们再次打印全局变量。
print(mod.get_global_vars())
[I.GlobalVar("add"), I.GlobalVar("add1"), I.GlobalVar("main"), I.GlobalVar("matmul"), I.GlobalVar("matmul1"), I.GlobalVar("relu"), I.GlobalVar("transpose"), I.GlobalVar("transpose1")]
接下来,Apache TVM Unity 为用户提供了一组默认的转换流水线,以简化转换过程。然后,我们可以将默认流水线应用于模块。默认的 **zero** 流水线包含非常基础的转换,包括
**LegalizeOps**:此转换将 Relax 运算符转换为带有相应 TensorIR 函数的 call_tir 函数。在此转换之后,IRModule 将同时包含 Relax 函数和 TensorIR 函数。
**AnnotateTIROpPattern**:此转换注释 TensorIR 函数的模式,为后续的运算符融合做准备。
**FoldConstant**:此 pass 执行常量折叠,优化涉及常量的运算。
**FuseOps 和 FuseTIR**:这两个 pass 协同工作,根据上一步(AnnotateTIROpPattern)中注释的模式融合运算符。这些 pass 转换 Relax 函数和 TensorIR 函数。
注意
在这里,我们在流程中应用了两次 **LegalizeOps**。第二次是无用的,但无害的。
每个 pass 都可以在流程中重复,因为我们确保 pass 可以处理所有合法的 IRModule 输入。这种设计可以帮助用户构建自己的流水线。
mod = relax.get_pipeline("zero")(mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@T.prim_func(private=True)
def fused_matmul1_add1(lv3: T.Buffer((T.int64(1), T.int64(256)), "float32"), lv4: T.Buffer((T.int64(256), T.int64(10)), "float32"), p_fc2_bias: T.Buffer((T.int64(10),), "float32"), T_add_intermediate: T.Buffer((T.int64(1), T.int64(10)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(10)))
for i0, i1, k in T.grid(T.int64(1), T.int64(10), T.int64(256)):
with T.block("matmul"):
v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
T.reads(lv3[v_i0, v_k], lv4[v_k, v_i1])
T.writes(matmul_intermediate[v_i0, v_i1])
with T.init():
matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + lv3[v_i0, v_k] * lv4[v_k, v_i1]
for ax0, ax1 in T.grid(T.int64(1), T.int64(10)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(matmul_intermediate[v_ax0, v_ax1], p_fc2_bias[v_ax1])
T.writes(T_add_intermediate[v_ax0, v_ax1])
T_add_intermediate[v_ax0, v_ax1] = matmul_intermediate[v_ax0, v_ax1] + p_fc2_bias[v_ax1]
@T.prim_func(private=True)
def fused_matmul_add_relu(x: T.Buffer((T.int64(1), T.int64(784)), "float32"), lv: T.Buffer((T.int64(784), T.int64(256)), "float32"), p_fc1_bias: T.Buffer((T.int64(256),), "float32"), compute_intermediate: T.Buffer((T.int64(1), T.int64(256)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(256)))
T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(256)))
for i0, i1, k in T.grid(T.int64(1), T.int64(256), T.int64(784)):
with T.block("matmul"):
v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
T.reads(x[v_i0, v_k], lv[v_k, v_i1])
T.writes(matmul_intermediate[v_i0, v_i1])
with T.init():
matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + x[v_i0, v_k] * lv[v_k, v_i1]
for ax0, ax1 in T.grid(T.int64(1), T.int64(256)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(matmul_intermediate[v_ax0, v_ax1], p_fc1_bias[v_ax1])
T.writes(T_add_intermediate[v_ax0, v_ax1])
T_add_intermediate[v_ax0, v_ax1] = matmul_intermediate[v_ax0, v_ax1] + p_fc1_bias[v_ax1]
for i0, i1 in T.grid(T.int64(1), T.int64(256)):
with T.block("compute"):
v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
T.reads(T_add_intermediate[v_i0, v_i1])
T.writes(compute_intermediate[v_i0, v_i1])
compute_intermediate[v_i0, v_i1] = T.max(T_add_intermediate[v_i0, v_i1], T.float32(0.0))
@T.prim_func(private=True)
def transpose(p_fc1_weight: T.Buffer((T.int64(256), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(256)), "float32")):
T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(784), T.int64(256)):
with T.block("T_transpose"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(p_fc1_weight[v_ax1, v_ax0])
T.writes(T_transpose[v_ax0, v_ax1])
T_transpose[v_ax0, v_ax1] = p_fc1_weight[v_ax1, v_ax0]
@T.prim_func(private=True)
def transpose1(p_fc2_weight: T.Buffer((T.int64(10), T.int64(256)), "float32"), T_transpose: T.Buffer((T.int64(256), T.int64(10)), "float32")):
T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(256), T.int64(10)):
with T.block("T_transpose"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(p_fc2_weight[v_ax1, v_ax0])
T.writes(T_transpose[v_ax0, v_ax1])
T_transpose[v_ax0, v_ax1] = p_fc2_weight[v_ax1, v_ax0]
@R.function
def main(x: R.Tensor((1, 784), dtype="float32"), p_fc1_weight: R.Tensor((256, 784), dtype="float32"), p_fc1_bias: R.Tensor((256,), dtype="float32"), p_fc2_weight: R.Tensor((10, 256), dtype="float32"), p_fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
R.func_attr({"num_input": 1})
cls = Module
with R.dataflow():
lv = R.call_tir(cls.transpose, (p_fc1_weight,), out_sinfo=R.Tensor((784, 256), dtype="float32"))
lv_1 = R.call_tir(cls.fused_matmul_add_relu, (x, lv, p_fc1_bias), out_sinfo=R.Tensor((1, 256), dtype="float32"))
lv4 = R.call_tir(cls.transpose1, (p_fc2_weight,), out_sinfo=R.Tensor((256, 10), dtype="float32"))
gv = R.call_tir(cls.fused_matmul1_add1, (lv_1, lv4, p_fc2_bias), out_sinfo=R.Tensor((1, 10), dtype="float32"))
R.output(gv)
return gv
通用部署 IRModule
优化之后,我们可以将模型编译成 TVM 运行时模块。值得注意的是,Apache TVM Unity 提供了通用部署的能力,这意味着我们可以将相同的 IRModule 部署在不同的后端上,包括 CPU、GPU 和其他新兴后端。
部署在 CPU 上
我们可以通过将目标指定为 llvm
将 IRModule 部署在 CPU 上。
exec = tvm.compile(mod, target="llvm")
dev = tvm.cpu()
vm = relax.VirtualMachine(exec, dev)
raw_data = np.random.rand(1, 784).astype("float32")
data = tvm.nd.array(raw_data, dev)
cpu_out = vm["main"](data, *params_from_torch["main"]).numpy()
print(cpu_out)
[[-0.02533746 0.02381632 -0.13650313 0.00278021 -0.01219752 -0.14017892
0.10063572 0.02572871 0.10278713 0.2471247 ]]
部署在 GPU 上
除了 CPU 后端,我们还可以将 IRModule 部署在 GPU 上。GPU 需要包含额外信息的程序,例如线程绑定和共享内存分配。我们需要进一步的转换来生成 GPU 程序。
我们使用 DLight
来生成 GPU 程序。在本教程中,我们不会深入探讨 DLight
的细节。
from tvm import dlight as dl
with tvm.target.Target("cuda"):
gpu_mod = dl.ApplyDefaultSchedule(
dl.gpu.Matmul(),
dl.gpu.Fallback(),
)(mod)
现在我们可以在 GPU 上编译 IRModule,与我们在 CPU 上所做的类似。
exec = tvm.compile(gpu_mod, target="cuda")
dev = tvm.device("cuda", 0)
vm = relax.VirtualMachine(exec, dev)
# Need to allocate data and params on GPU device
data = tvm.nd.array(raw_data, dev)
gpu_params = [tvm.nd.array(p, dev) for p in params_from_torch["main"]]
gpu_out = vm["main"](data, *gpu_params).numpy()
print(gpu_out)
# Check the correctness of the results
assert np.allclose(cpu_out, gpu_out, atol=1e-3)
[[-0.02533747 0.02381634 -0.13650307 0.00278023 -0.01219756 -0.14017901
0.10063575 0.02572877 0.10278708 0.2471247 ]]
部署在其他后端
Apache TVM Unity 还支持其他后端,例如不同类型的 GPU(Metal、ROCm、Vulkan 和 OpenCL)、不同类型的 CPU(x86、ARM)以及其他新兴后端(例如,WebAssembly)。部署过程与 GPU 后端类似。