Transformation

In this section, we will dive into the transformation of Relax programs. Transformations are one of the key ingredients of the compilation flow, used both for optimization and for integration with hardware backends.

Let us first create a simple Relax program, as we did in the previous section.

import tvm
from tvm import IRModule, relax
from tvm.relax.frontend import nn


class NNModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x


origin_mod, params = NNModule().export_tvm(
    {"forward": {"x": nn.spec.Tensor(("n", 784), "float32")}}
)
origin_mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((784, 128), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            matmul: R.Tensor((n, 128), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((n, 128), dtype="float32") = R.add(matmul, fc1_bias)
            relu: R.Tensor((n, 128), dtype="float32") = R.nn.relu(add)
            permute_dims1: R.Tensor((128, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((n, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
            add1: R.Tensor((n, 10), dtype="float32") = R.add(matmul1, fc2_bias)
            gv: R.Tensor((n, 10), dtype="float32") = add1
            R.output(gv)
        return gv
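
Before applying any transformation, it helps to take a quick look at what export_tvm returned besides the module: params is a list of (name, parameter) pairs for the fc1/fc2 weights and biases. A minimal sketch, assuming nn.Parameter exposes shape and dtype as in recent TVM releases:

for name, param in params:
    print(name, param.shape, param.dtype)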

Apply Transformations

Passes are the main way to apply transformations to a program. As a first step, let us apply the built-in pass LegalizeOps to lower the high-level operators into low-level operators.

mod = tvm.relax.transform.LegalizeOps()(origin_mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(var_matmul: T.handle, fc1_bias: T.Buffer((T.int64(128),), "float32"), var_T_add: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        matmul = T.match_buffer(var_matmul, (n, T.int64(128)))
        T_add = T.match_buffer(var_T_add, (n, T.int64(128)))
        # with T.block("root"):
        for ax0, ax1 in T.grid(n, T.int64(128)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(matmul[v_ax0, v_ax1], fc1_bias[v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = matmul[v_ax0, v_ax1] + fc1_bias[v_ax1]

    @T.prim_func(private=True)
    def add1(var_matmul1: T.handle, fc2_bias: T.Buffer((T.int64(10),), "float32"), var_T_add: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        matmul1 = T.match_buffer(var_matmul1, (n, T.int64(10)))
        T_add = T.match_buffer(var_T_add, (n, T.int64(10)))
        # with T.block("root"):
        for ax0, ax1 in T.grid(n, T.int64(10)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(matmul1[v_ax0, v_ax1], fc2_bias[v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = matmul1[v_ax0, v_ax1] + fc2_bias[v_ax1]

    @T.prim_func(private=True)
    def matmul(var_x: T.handle, permute_dims: T.Buffer((T.int64(784), T.int64(128)), "float32"), var_matmul: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        x = T.match_buffer(var_x, (n, T.int64(784)))
        matmul = T.match_buffer(var_matmul, (n, T.int64(128)))
        # with T.block("root"):
        for i0, i1, k in T.grid(n, T.int64(128), T.int64(784)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(x[v_i0, v_k], permute_dims[v_k, v_i1])
                T.writes(matmul[v_i0, v_i1])
                with T.init():
                    matmul[v_i0, v_i1] = T.float32(0.0)
                matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + x[v_i0, v_k] * permute_dims[v_k, v_i1]

    @T.prim_func(private=True)
    def matmul1(var_relu: T.handle, permute_dims1: T.Buffer((T.int64(128), T.int64(10)), "float32"), var_matmul: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        relu = T.match_buffer(var_relu, (n, T.int64(128)))
        matmul = T.match_buffer(var_matmul, (n, T.int64(10)))
        # with T.block("root"):
        for i0, i1, k in T.grid(n, T.int64(10), T.int64(128)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(relu[v_i0, v_k], permute_dims1[v_k, v_i1])
                T.writes(matmul[v_i0, v_i1])
                with T.init():
                    matmul[v_i0, v_i1] = T.float32(0.0)
                matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + relu[v_i0, v_k] * permute_dims1[v_k, v_i1]

    @T.prim_func(private=True)
    def relu(var_add: T.handle, var_compute: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        add = T.match_buffer(var_add, (n, T.int64(128)))
        compute = T.match_buffer(var_compute, (n, T.int64(128)))
        # with T.block("root"):
        for i0, i1 in T.grid(n, T.int64(128)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(add[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.max(add[v_i0, v_i1], T.float32(0.0))

    @T.prim_func(private=True)
    def transpose(fc1_weight: T.Buffer((T.int64(128), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(128)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(784), T.int64(128)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(fc1_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = fc1_weight[v_ax1, v_ax0]

    @T.prim_func(private=True)
    def transpose1(fc2_weight: T.Buffer((T.int64(10), T.int64(128)), "float32"), T_transpose: T.Buffer((T.int64(128), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(128), T.int64(10)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(fc2_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = fc2_weight[v_ax1, v_ax0]

    @R.function
    def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            permute_dims = R.call_tir(cls.transpose, (fc1_weight,), out_sinfo=R.Tensor((784, 128), dtype="float32"))
            matmul = R.call_tir(cls.matmul, (x, permute_dims), out_sinfo=R.Tensor((n, 128), dtype="float32"))
            add = R.call_tir(cls.add, (matmul, fc1_bias), out_sinfo=R.Tensor((n, 128), dtype="float32"))
            relu = R.call_tir(cls.relu, (add,), out_sinfo=R.Tensor((n, 128), dtype="float32"))
            permute_dims1 = R.call_tir(cls.transpose1, (fc2_weight,), out_sinfo=R.Tensor((128, 10), dtype="float32"))
            matmul1 = R.call_tir(cls.matmul1, (relu, permute_dims1), out_sinfo=R.Tensor((n, 10), dtype="float32"))
            add1 = R.call_tir(cls.add1, (matmul1, fc2_bias), out_sinfo=R.Tensor((n, 10), dtype="float32"))
            gv: R.Tensor((n, 10), dtype="float32") = add1
            R.output(gv)
        return gv

From the output we can see that the high-level operators (i.e. relax.op) in the program have been replaced by calls to their corresponding low-level TIR functions (i.e. relax.call_tir).
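
As a side note, a pass can also be applied under an explicit PassContext, which controls options such as the optimization level, and the result can be sanity-checked with the well-formedness analysis. A minimal sketch, assuming the default options:

with tvm.transform.PassContext(opt_level=3):
    legalized = tvm.relax.transform.LegalizeOps()(origin_mod)
assert tvm.relax.analysis.well_formed(legalized)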

Next, let us apply operator fusion, an optimization widely used in ML compilers. Note that in Relax, fusion is accomplished by a collaboration of several passes, which we apply in sequence.

mod = tvm.ir.transform.Sequential(
    [
        tvm.relax.transform.AnnotateTIROpPattern(),
        tvm.relax.transform.FuseOps(),
        tvm.relax.transform.FuseTIR(),
    ]
)(mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def fused_matmul1_add1(p_relu: T.handle, permute_dims1: T.Buffer((T.int64(128), T.int64(10)), "float32"), fc2_bias: T.Buffer((T.int64(10),), "float32"), p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        relu = T.match_buffer(p_relu, (n, T.int64(128)))
        T_add_intermediate = T.match_buffer(p_output0, (n, T.int64(10)))
        # with T.block("root"):
        matmul_intermediate = T.alloc_buffer((n, T.int64(10)))
        for i0, i1, k in T.grid(n, T.int64(10), T.int64(128)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(relu[v_i0, v_k], permute_dims1[v_k, v_i1])
                T.writes(matmul_intermediate[v_i0, v_i1])
                with T.init():
                    matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
                matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + relu[v_i0, v_k] * permute_dims1[v_k, v_i1]
        for ax0, ax1 in T.grid(n, T.int64(10)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(matmul_intermediate[v_ax0, v_ax1], fc2_bias[v_ax1])
                T.writes(T_add_intermediate[v_ax0, v_ax1])
                T_add_intermediate[v_ax0, v_ax1] = matmul_intermediate[v_ax0, v_ax1] + fc2_bias[v_ax1]

    @T.prim_func(private=True)
    def fused_matmul_add_relu(p_x: T.handle, permute_dims: T.Buffer((T.int64(784), T.int64(128)), "float32"), fc1_bias: T.Buffer((T.int64(128),), "float32"), p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        x = T.match_buffer(p_x, (n, T.int64(784)))
        compute_intermediate = T.match_buffer(p_output0, (n, T.int64(128)))
        # with T.block("root"):
        matmul_intermediate = T.alloc_buffer((n, T.int64(128)))
        T_add_intermediate = T.alloc_buffer((n, T.int64(128)))
        for i0, i1, k in T.grid(n, T.int64(128), T.int64(784)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(x[v_i0, v_k], permute_dims[v_k, v_i1])
                T.writes(matmul_intermediate[v_i0, v_i1])
                with T.init():
                    matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
                matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + x[v_i0, v_k] * permute_dims[v_k, v_i1]
        for ax0, ax1 in T.grid(n, T.int64(128)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(matmul_intermediate[v_ax0, v_ax1], fc1_bias[v_ax1])
                T.writes(T_add_intermediate[v_ax0, v_ax1])
                T_add_intermediate[v_ax0, v_ax1] = matmul_intermediate[v_ax0, v_ax1] + fc1_bias[v_ax1]
        for i0, i1 in T.grid(n, T.int64(128)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(T_add_intermediate[v_i0, v_i1])
                T.writes(compute_intermediate[v_i0, v_i1])
                compute_intermediate[v_i0, v_i1] = T.max(T_add_intermediate[v_i0, v_i1], T.float32(0.0))

    @T.prim_func(private=True)
    def transpose(fc1_weight: T.Buffer((T.int64(128), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(128)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(784), T.int64(128)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(fc1_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = fc1_weight[v_ax1, v_ax0]

    @T.prim_func(private=True)
    def transpose1(fc2_weight: T.Buffer((T.int64(10), T.int64(128)), "float32"), T_transpose: T.Buffer((T.int64(128), T.int64(10)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(128), T.int64(10)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(fc2_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = fc2_weight[v_ax1, v_ax0]

    @R.function
    def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            permute_dims = R.call_tir(cls.transpose, (fc1_weight,), out_sinfo=R.Tensor((784, 128), dtype="float32"))
            lv = R.call_tir(cls.fused_matmul_add_relu, (x, permute_dims, fc1_bias), out_sinfo=R.Tensor((n, 128), dtype="float32"))
            permute_dims1 = R.call_tir(cls.transpose1, (fc2_weight,), out_sinfo=R.Tensor((128, 10), dtype="float32"))
            gv = R.call_tir(cls.fused_matmul1_add1, (lv, permute_dims1, fc2_bias), out_sinfo=R.Tensor((n, 10), dtype="float32"))
            R.output(gv)
        return gv

As a result, we can see that the matmul, add, and relu operators have been fused into one kernel (i.e. one call_tir).
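
With the fused kernels in place, the module can be compiled and executed end to end. The sketch below assumes an "llvm"-enabled TVM build and fills the exported parameters with random data purely to exercise the kernels; it is not part of the original tutorial flow:

import numpy as np

ex = tvm.relax.build(mod, target="llvm")
vm = tvm.relax.VirtualMachine(ex, tvm.cpu())

# Random parameter values in the same order as the forward() signature.
rand_params = [
    tvm.nd.array(np.random.rand(*[int(s) for s in param.shape]).astype("float32"))
    for _, param in params
]
x_nd = tvm.nd.array(np.random.rand(3, 784).astype("float32"))
out = vm["forward"](x_nd, *rand_params)
print(out.numpy().shape)  # expected: (3, 10)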

For the full list of built-in passes, please refer to relax.transform.
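
As a quick illustration of composing other built-in passes, the sketch below chains constant folding and dead-code elimination; FoldConstant and DeadCodeElimination are assumed to be available in tvm.relax.transform, as in recent TVM releases:

optimized = tvm.ir.transform.Sequential(
    [
        tvm.relax.transform.FoldConstant(),
        tvm.relax.transform.DeadCodeElimination(entry_functions=["forward"]),
    ]
)(mod)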

Custom Passes

We can also define our own passes. As an example, let us rewrite the relu operator into the gelu operator.

First, we need to write a Relax IR mutator to perform the rewriting.

from tvm.relax.expr_functor import PyExprMutator, mutator


@mutator
class ReluRewriter(PyExprMutator):
    def __init__(self, mod):
        super().__init__(mod)

    def visit_call_(self, call: relax.Call) -> relax.Expr:
        # visit the relax.Call expr, and only handle the case when op is relax.nn.relu
        if call.op.name == "relax.nn.relu":
            return relax.op.nn.gelu(call.args[0])

        return super().visit_call_(call)

Then we can write a pass that applies the mutator to the whole module.

@tvm.transform.module_pass(opt_level=0, name="ReluToGelu")
class ReluToGelu:  # pylint: disable=too-few-public-methods
    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        """IRModule-level transformation"""
        rewriter = ReluRewriter(mod)
        for g_var, func in mod.functions_items():
            if isinstance(func, relax.Function):
                func = rewriter.visit_expr(func)
                rewriter.builder_.update_func(g_var, func)
        return rewriter.builder_.get()


mod = ReluToGelu()(origin_mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((784, 128), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            matmul: R.Tensor((n, 128), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((n, 128), dtype="float32") = R.add(matmul, fc1_bias)
            relu: R.Tensor((n, 128), dtype="float32") = R.nn.gelu(add)
            permute_dims1: R.Tensor((128, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((n, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
            add1: R.Tensor((n, 10), dtype="float32") = R.add(matmul1, fc2_bias)
            gv: R.Tensor((n, 10), dtype="float32") = add1
            R.output(gv)
        return gv

The printed output shows that the relax.nn.relu operator has been rewritten into the relax.nn.gelu operator.

For details about mutators, please refer to relax.expr_functor.PyExprMutator.
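
For simple operator-for-operator rewrites like this one, the dataflow pattern language offers a lighter-weight alternative to a full mutator. The sketch below assumes the pattern-rewriting helpers in tvm.relax.dpl (wildcard, is_op, rewrite_call) and applies the same relu-to-gelu rewrite to a single function:

from tvm.relax.dpl import is_op, rewrite_call, wildcard

inp = wildcard()
relu_pattern = is_op("relax.nn.relu")(inp)

def to_gelu(_expr, matches):
    # `matches` maps each sub-pattern to the expression it captured.
    return relax.op.nn.gelu(matches[inp])

rewritten = rewrite_call(relu_pattern, to_gelu, origin_mod["forward"])
rewritten.show()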

Summary

In this section, we have shown how to apply transformations to a Relax program, and how to define and apply custom transformations.
