Relax Creation

This tutorial demonstrates how to create Relax functions and programs. We will cover the various ways to define Relax functions, including using TVMScript, the Relax NNModule API, and the Block Builder API.

Create Relax Programs Using TVMScript

TVMScript is a domain-specific language for representing the Apache TVM intermediate representation (IR). It is a Python dialect that can be used to define an IRModule, which contains both TensorIR and Relax functions.

In this section, we will show how to use TVMScript to define a simple MLP model with only high-level Relax operators.

from tvm import relax, topi
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T


@I.ir_module
class RelaxModule:
    @R.function
    def forward(
        data: R.Tensor(("n", 784), dtype="float32"),
        w0: R.Tensor((128, 784), dtype="float32"),
        b0: R.Tensor((128,), dtype="float32"),
        w1: R.Tensor((10, 128), dtype="float32"),
        b1: R.Tensor((10,), dtype="float32"),
    ) -> R.Tensor(("n", 10), dtype="float32"):
        with R.dataflow():
            lv0 = R.matmul(data, R.permute_dims(w0)) + b0
            lv1 = R.nn.relu(lv0)
            lv2 = R.matmul(lv1, R.permute_dims(w1)) + b1
            R.output(lv2)
        return lv2


RelaxModule.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def forward(data: R.Tensor(("n", 784), dtype="float32"), w0: R.Tensor((128, 784), dtype="float32"), b0: R.Tensor((128,), dtype="float32"), w1: R.Tensor((10, 128), dtype="float32"), b1: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        with R.dataflow():
            lv: R.Tensor((784, 128), dtype="float32") = R.permute_dims(w0, axes=None)
            lv1: R.Tensor((n, 128), dtype="float32") = R.matmul(data, lv, out_dtype="void")
            lv0: R.Tensor((n, 128), dtype="float32") = R.add(lv1, b0)
            lv1_1: R.Tensor((n, 128), dtype="float32") = R.nn.relu(lv0)
            lv4: R.Tensor((128, 10), dtype="float32") = R.permute_dims(w1, axes=None)
            lv5: R.Tensor((n, 10), dtype="float32") = R.matmul(lv1_1, lv4, out_dtype="void")
            lv2: R.Tensor((n, 10), dtype="float32") = R.add(lv5, b1)
            R.output(lv2)
        return lv2
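
Since an IRModule maps global symbols to functions, the parsed Relax function can be retrieved and inspected directly. The following is a minimal, illustrative sketch; the expected values in the comments are inferred from the function signature above and are not part of the tutorial output.

# A minimal sketch: fetch the parsed Relax function from the IRModule and inspect it.
forward_fn = RelaxModule["forward"]
print(type(forward_fn))  # expected: <class 'tvm.relax.expr.Function'>
print(len(forward_fn.params))  # expected: 5 (data, w0, b0, w1, b1)
print(forward_fn.ret_struct_info)  # expected: R.Tensor((n, 10), dtype="float32")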

Relax is not only a graph-level IR, but also supports cross-level representations and transformations. Specifically, we can call TensorIR functions directly from within a Relax function.

@I.ir_module
class RelaxModuleWithTIR:
    @T.prim_func
    def relu(x: T.handle, y: T.handle):
        n, m = T.int64(), T.int64()
        X = T.match_buffer(x, (n, m), "float32")
        Y = T.match_buffer(y, (n, m), "float32")
        for i, j in T.grid(n, m):
            with T.block("relu"):
                vi, vj = T.axis.remap("SS", [i, j])
                Y[vi, vj] = T.max(X[vi, vj], T.float32(0))

    @R.function
    def forward(
        data: R.Tensor(("n", 784), dtype="float32"),
        w0: R.Tensor((128, 784), dtype="float32"),
        b0: R.Tensor((128,), dtype="float32"),
        w1: R.Tensor((10, 128), dtype="float32"),
        b1: R.Tensor((10,), dtype="float32"),
    ) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        cls = RelaxModuleWithTIR
        with R.dataflow():
            lv0 = R.matmul(data, R.permute_dims(w0)) + b0
            lv1 = R.call_tir(cls.relu, lv0, R.Tensor((n, 128), dtype="float32"))
            lv2 = R.matmul(lv1, R.permute_dims(w1)) + b1
            R.output(lv2)
        return lv2


RelaxModuleWithTIR.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func
    def relu(x: T.handle, y: T.handle):
        n, m = T.int64(), T.int64()
        X = T.match_buffer(x, (n, m))
        Y = T.match_buffer(y, (n, m))
        # with T.block("root"):
        for i, j in T.grid(n, m):
            with T.block("relu"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(X[vi, vj])
                T.writes(Y[vi, vj])
                Y[vi, vj] = T.max(X[vi, vj], T.float32(0.0))

    @R.function
    def forward(data: R.Tensor(("n", 784), dtype="float32"), w0: R.Tensor((128, 784), dtype="float32"), b0: R.Tensor((128,), dtype="float32"), w1: R.Tensor((10, 128), dtype="float32"), b1: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        cls = Module
        with R.dataflow():
            lv: R.Tensor((784, 128), dtype="float32") = R.permute_dims(w0, axes=None)
            lv1: R.Tensor((n, 128), dtype="float32") = R.matmul(data, lv, out_dtype="void")
            lv0: R.Tensor((n, 128), dtype="float32") = R.add(lv1, b0)
            lv1_1 = R.call_tir(cls.relu, (lv0,), out_sinfo=R.Tensor((n, 128), dtype="float32"))
            lv4: R.Tensor((128, 10), dtype="float32") = R.permute_dims(w1, axes=None)
            lv5: R.Tensor((n, 10), dtype="float32") = R.matmul(lv1_1, lv4, out_dtype="void")
            lv2: R.Tensor((n, 10), dtype="float32") = R.add(lv5, b1)
            R.output(lv2)
        return lv2

Note

You may notice that the printed output differs from the TVMScript code we wrote. This is because the IRModule is printed in a standard, normalized format, while syntactic sugar is accepted on the input side.

For example, we can combine multiple operators into one line, such as

lv0 = R.matmul(data, R.permute_dims(w0)) + b0

However, a normalized expression requires exactly one operation per binding. Therefore, the printed output differs from the TVMScript code we wrote, as in

lv: R.Tensor((784, 128), dtype="float32") = R.permute_dims(w0, axes=None)
lv1: R.Tensor((n, 128), dtype="float32") = R.matmul(data, lv, out_dtype="void")
lv0: R.Tensor((n, 128), dtype="float32") = R.add(lv1, b0)
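
To see this normalization directly, here is a small self-contained sketch; the module name SugarExample is only for illustration and reuses the imports from the first example.

@I.ir_module
class SugarExample:
    @R.function
    def main(
        x: R.Tensor((2, 2), dtype="float32"), y: R.Tensor((2, 2), dtype="float32")
    ) -> R.Tensor((2, 2), dtype="float32"):
        with R.dataflow():
            # Two operators combined on one line, relying on the input-side sugar.
            z = R.add(x, y) + y
            R.output(z)
        return z


SugarExample.show()  # the printed form splits `z` into two separate R.add bindings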

Create Relax Programs Using NNModule API

Besides TVMScript, we also provide a PyTorch-like API for defining neural networks. It is designed to be more intuitive and easier to use than TVMScript.

In this section, we will show how to define the same MLP model using the Relax NNModule API.

from tvm.relax.frontend import nn


class NNModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x

After we define the NNModule, we can export it to a TVM IRModule via export_tvm.

mod, params = NNModule().export_tvm({"forward": {"x": nn.spec.Tensor(("n", 784), "float32")}})
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((784, 128), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            matmul: R.Tensor((n, 128), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((n, 128), dtype="float32") = R.add(matmul, fc1_bias)
            relu: R.Tensor((n, 128), dtype="float32") = R.nn.relu(add)
            permute_dims1: R.Tensor((128, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((n, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
            add1: R.Tensor((n, 10), dtype="float32") = R.add(matmul1, fc2_bias)
            gv: R.Tensor((n, 10), dtype="float32") = add1
            R.output(gv)
        return gv
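
Besides the IRModule, export_tvm also returns the model parameters. As a quick sketch, assuming the (name, parameter) pair layout returned by the frontend, they can be inspected as follows:

# A minimal sketch, assuming `params` is a list of (name, nn.Parameter) pairs.
for name, param in params:
    print(name, param.shape, param.dtype)
# e.g. fc1.weight (128, 784) float32
#      fc1.bias (128,) float32
#      ...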

We can also insert customized function calls into the NNModule, such as Tensor Expression (TE) computations, TensorIR functions, or other TVM packed functions.

@T.prim_func
def tir_linear(x: T.handle, w: T.handle, b: T.handle, z: T.handle):
    M, N, K = T.int64(), T.int64(), T.int64()
    X = T.match_buffer(x, (M, K), "float32")
    W = T.match_buffer(w, (N, K), "float32")
    B = T.match_buffer(b, (N,), "float32")
    Z = T.match_buffer(z, (M, N), "float32")
    for i, j, k in T.grid(M, N, K):
        with T.block("linear"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                Z[vi, vj] = 0
            Z[vi, vj] = Z[vi, vj] + X[vi, vk] * W[vj, vk]
    for i, j in T.grid(M, N):
        with T.block("add"):
            vi, vj = T.axis.remap("SS", [i, j])
            Z[vi, vj] = Z[vi, vj] + B[vj]


class NNModuleWithTIR(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        n = x.shape[0]
        # We can call external functions using nn.extern
        x = nn.extern(
            "env.linear",
            [x, self.fc1.weight, self.fc1.bias],
            out=nn.Tensor.placeholder((n, 128), "float32"),
        )
        # We can also call TensorIR via Tensor Expression API in TOPI
        x = nn.tensor_expr_op(topi.nn.relu, "relu", [x])
        # We can also call TensorIR functions directly via nn.tensor_ir_op
        x = nn.tensor_ir_op(
            tir_linear,
            "tir_linear",
            [x, self.fc2.weight, self.fc2.bias],
            out=nn.Tensor.placeholder((n, 10), "float32"),
        )
        return x


mod, params = NNModuleWithTIR().export_tvm(
    {"forward": {"x": nn.spec.Tensor(("n", 784), "float32")}}
)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def relu(var_env_linear: T.handle, var_compute: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        env_linear = T.match_buffer(var_env_linear, (n, T.int64(128)))
        compute = T.match_buffer(var_compute, (n, T.int64(128)))
        # with T.block("root"):
        for i0, i1 in T.grid(n, T.int64(128)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(env_linear[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.max(env_linear[v_i0, v_i1], T.float32(0.0))

    @T.prim_func
    def tir_linear(x: T.handle, w: T.handle, b: T.handle, z: T.handle):
        M, K = T.int64(), T.int64()
        X = T.match_buffer(x, (M, K))
        N = T.int64()
        W = T.match_buffer(w, (N, K))
        B = T.match_buffer(b, (N,))
        Z = T.match_buffer(z, (M, N))
        # with T.block("root"):
        for i, j, k in T.grid(M, N, K):
            with T.block("linear"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                T.reads(X[vi, vk], W[vj, vk])
                T.writes(Z[vi, vj])
                with T.init():
                    Z[vi, vj] = T.float32(0.0)
                Z[vi, vj] = Z[vi, vj] + X[vi, vk] * W[vj, vk]
        for i, j in T.grid(M, N):
            with T.block("add"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(Z[vi, vj], B[vj])
                T.writes(Z[vi, vj])
                Z[vi, vj] = Z[vi, vj] + B[vj]

    @R.function
    def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            env_linear = R.call_dps_packed("env.linear", (x, fc1_weight, fc1_bias), out_sinfo=R.Tensor((n, 128), dtype="float32"))
            lv = R.call_tir(cls.relu, (env_linear,), out_sinfo=R.Tensor((n, 128), dtype="float32"))
            lv1 = R.call_tir(cls.tir_linear, (lv, fc2_weight, fc2_bias), out_sinfo=R.Tensor((n, 10), dtype="float32"))
            gv: R.Tensor((n, 10), dtype="float32") = lv1
            R.output(gv)
        return gv
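
The module above calls an external packed function named env.linear through R.call_dps_packed. Such a function is not defined in this tutorial; purely as an illustration, a runtime implementation could be registered roughly as below, assuming the destination-passing-style convention (inputs first, pre-allocated output last) and a NumPy-based computation.

# An illustrative sketch only: register a NumPy-backed implementation for "env.linear".
import tvm


@tvm.register_func("env.linear", override=True)
def env_linear(x, w, b, out):
    # x: (n, 784), w: (128, 784), b: (128,), out: (n, 128); all float32 NDArrays.
    result = x.numpy() @ w.numpy().T + b.numpy()
    out.copyfrom(result.astype("float32"))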

Create Relax Programs Using Block Builder API

In addition to the above APIs, we also provide a Block Builder API for creating Relax programs. It is an IR builder API that is more low-level and is widely used in TVM internals, e.g., for writing customized passes.

bb = relax.BlockBuilder()
n = T.int64()
x = relax.Var("x", R.Tensor((n, 784), "float32"))
fc1_weight = relax.Var("fc1_weight", R.Tensor((128, 784), "float32"))
fc1_bias = relax.Var("fc1_bias", R.Tensor((128,), "float32"))
fc2_weight = relax.Var("fc2_weight", R.Tensor((10, 128), "float32"))
fc2_bias = relax.Var("fc2_bias", R.Tensor((10,), "float32"))
with bb.function("forward", [x, fc1_weight, fc1_bias, fc2_weight, fc2_bias]):
    with bb.dataflow():
        lv0 = bb.emit(relax.op.matmul(x, relax.op.permute_dims(fc1_weight)) + fc1_bias)
        lv1 = bb.emit(relax.op.nn.relu(lv0))
        gv = bb.emit(relax.op.matmul(lv1, relax.op.permute_dims(fc2_weight)) + fc2_bias)
        bb.emit_output(gv)
    bb.emit_func_output(gv)

mod = bb.get()
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor(("v", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("v", 10), dtype="float32"):
        v = T.int64()
        with R.dataflow():
            lv: R.Tensor((784, 128), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            lv1: R.Tensor((v, 128), dtype="float32") = R.matmul(x, lv, out_dtype="void")
            lv2: R.Tensor((v, 128), dtype="float32") = R.add(lv1, fc1_bias)
            lv3: R.Tensor((v, 128), dtype="float32") = R.nn.relu(lv2)
            lv4: R.Tensor((128, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            lv5: R.Tensor((v, 10), dtype="float32") = R.matmul(lv3, lv4, out_dtype="void")
            lv6: R.Tensor((v, 10), dtype="float32") = R.add(lv5, fc2_bias)
            gv: R.Tensor((v, 10), dtype="float32") = lv6
            R.output(gv)
        return lv6

Also, the Block Builder API supports building cross-level IRModules that contain Relax functions, TensorIR functions, and other TVM packed functions.

bb = relax.BlockBuilder()
with bb.function("forward", [x, fc1_weight, fc1_bias, fc2_weight, fc2_bias]):
    with bb.dataflow():
        lv0 = bb.emit(
            relax.call_dps_packed(
                "env.linear",
                [x, fc1_weight, fc1_bias],
                out_sinfo=relax.TensorStructInfo((n, 128), "float32"),
            )
        )
        lv1 = bb.emit_te(topi.nn.relu, lv0)
        tir_gv = bb.add_func(tir_linear, "tir_linear")
        gv = bb.emit(
            relax.call_tir(
                tir_gv,
                [lv1, fc2_weight, fc2_bias],
                out_sinfo=relax.TensorStructInfo((n, 10), "float32"),
            )
        )
        bb.emit_output(gv)
    bb.emit_func_output(gv)
mod = bb.get()
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def relu(var_lv: T.handle, var_compute: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        v = T.int64()
        lv = T.match_buffer(var_lv, (v, T.int64(128)))
        compute = T.match_buffer(var_compute, (v, T.int64(128)))
        # with T.block("root"):
        for i0, i1 in T.grid(v, T.int64(128)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(lv[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.max(lv[v_i0, v_i1], T.float32(0.0))

    @T.prim_func
    def tir_linear(x: T.handle, w: T.handle, b: T.handle, z: T.handle):
        M, K = T.int64(), T.int64()
        X = T.match_buffer(x, (M, K))
        N = T.int64()
        W = T.match_buffer(w, (N, K))
        B = T.match_buffer(b, (N,))
        Z = T.match_buffer(z, (M, N))
        # with T.block("root"):
        for i, j, k in T.grid(M, N, K):
            with T.block("linear"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                T.reads(X[vi, vk], W[vj, vk])
                T.writes(Z[vi, vj])
                with T.init():
                    Z[vi, vj] = T.float32(0.0)
                Z[vi, vj] = Z[vi, vj] + X[vi, vk] * W[vj, vk]
        for i, j in T.grid(M, N):
            with T.block("add"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(Z[vi, vj], B[vj])
                T.writes(Z[vi, vj])
                Z[vi, vj] = Z[vi, vj] + B[vj]

    @R.function
    def forward(x: R.Tensor(("v", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("v", 10), dtype="float32"):
        v = T.int64()
        cls = Module
        with R.dataflow():
            lv = R.call_dps_packed("env.linear", (x, fc1_weight, fc1_bias), out_sinfo=R.Tensor((v, 128), dtype="float32"))
            lv1 = R.call_tir(cls.relu, (lv,), out_sinfo=R.Tensor((v, 128), dtype="float32"))
            lv2 = R.call_tir(cls.tir_linear, (lv1, fc2_weight, fc2_bias), out_sinfo=R.Tensor((v, 10), dtype="float32"))
            gv: R.Tensor((v, 10), dtype="float32") = lv2
            R.output(gv)
        return lv2

Note that the Block Builder API is not as user-friendly as the APIs above, but it is the lowest-level API and works closely with the IR definition. We recommend the above APIs for users who only want to define and transform an ML model; for those who want to build more complex transformations, the Block Builder API is the more flexible choice.

Summary

This tutorial has demonstrated how to create Relax programs for different use cases using TVMScript, the NNModule API, and the Block Builder API.
