自定义优化

Apache TVM 的主要设计目标之一是能够轻松自定义优化管道,以用于研究或开发目的,并迭代工程优化。在本教程中,我们将

回顾整体流程

../../_static/downloads/tvm_overall_flow.svg

整体流程包括以下步骤

  • 构建或导入模型:构建神经网络模型或从其他框架(例如 PyTorch、ONNX)导入预训练模型,并创建 TVM IRModule,其中包含编译所需的所有信息,包括用于计算图的高级 Relax 函数和用于张量程序的低级 TensorIR 函数。

  • 执行可组合的优化:执行一系列优化转换,例如图优化、张量程序优化和库分发。

  • 构建和通用部署:将优化后的模型构建为可部署到通用运行时的模块,并在不同的设备(例如 CPU、GPU 或其他加速器)上执行它。

import os
import tempfile
import numpy as np
import tvm
from tvm import IRModule, relax
from tvm.relax.frontend import nn

可组合的 IRModule 优化

Apache TVM Unity 提供了一种灵活的方式来优化 IRModule。围绕 IRModule 优化的一切都可以与现有管道组合。请注意,每个优化都可以专注于计算图的一部分,从而实现部分降低或部分优化。

在本教程中,我们将演示如何使用 Apache TVM Unity 优化模型。

准备 Relax 模块

我们首先准备一个 Relax 模块。该模块可以从其他框架导入,使用 NN 模块前端或 TVMScript 构建。这里我们使用一个简单的神经网络模型作为示例。

class RelaxModel(nn.Module):
    def __init__(self):
        super(RelaxModel, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 10, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x


input_shape = (1, 784)
mod, params = RelaxModel().export_tvm({"forward": {"x": nn.spec.Tensor(input_shape, "float32")}})
mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((784, 256), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            matmul: R.Tensor((1, 256), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((1, 256), dtype="float32") = R.add(matmul, fc1_bias)
            relu: R.Tensor((1, 256), dtype="float32") = R.nn.relu(add)
            permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
            gv: R.Tensor((1, 10), dtype="float32") = matmul1
            R.output(gv)
        return gv

库分发

我们希望快速尝试针对某些平台(例如 GPU)的库优化变体。我们可以为特定平台和运算符编写特定的分发pass。这里我们演示如何为某些模式分发 CUBLAS 库。

注意

本教程仅演示了针对 CUBLAS 的单个运算符分发,突出了优化管道的灵活性。在实际应用中,我们可以导入多种模式并将它们分发到不同的内核。

# Import cublas pattern
import tvm.relax.backend.cuda.cublas as _cublas


# Define a new pass for CUBLAS dispatch
@tvm.transform.module_pass(opt_level=0, name="CublasDispatch")
class CublasDispatch:
    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        # Check if CUBLAS is enabled
        if not tvm.get_global_func("relax.ext.cublas", True):
            raise Exception("CUBLAS is not enabled.")

        # Get interested patterns
        patterns = [relax.backend.get_pattern("cublas.matmul_transposed_bias_relu")]
        # Note in real-world cases, we usually get all patterns
        # patterns = relax.backend.get_patterns_with_prefix("cublas")

        # Fuse ops by patterns and then run codegen
        mod = relax.transform.FuseOpsByPattern(patterns, annotate_codegen=True)(mod)
        mod = relax.transform.RunCodegen()(mod)
        return mod


mod = CublasDispatch()(mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    I.module_attrs({"external_mods": [metadata["runtime.Module"][0]]})
    @R.function
    def forward(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            lv = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_relu_cublas", (fc1_weight, x, fc1_bias), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(lv, permute_dims1, out_dtype="void")
            gv: R.Tensor((1, 10), dtype="float32") = matmul1
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.

在分发pass之后,我们可以看到第一个 nn.Linearnn.ReLU 被融合并重写为一个 call_dps_packed 函数,该函数调用 CUBLAS 库。值得注意的是,另一部分没有改变,这意味着我们可以选择性地为某些计算分发优化。

自动调优

从前面的示例继续,我们可以使用自动调优进一步优化模型的计算的剩余部分。这里我们演示如何使用 meta-schedule 自动调优模型。

我们可以使用 MetaScheduleTuneTIR pass 来简单地调优模型,而使用 MetaScheduleApplyDatabase pass 将最佳配置应用于模型。调优过程将生成搜索空间,调优模型,后续步骤将最佳配置应用于模型。在运行这些pass之前,我们需要通过 LegalizeOps 将 relax 运算符降低为 TensorIR 函数

注意

为了节省 CI 时间并避免不稳定性,我们在 CI 环境中跳过调优过程。

device = tvm.cuda(0)
target = tvm.target.Target.from_device(device)
if os.getenv("CI", "") != "true":
    trials = 2000
    with target, tempfile.TemporaryDirectory() as tmp_dir:
        mod = tvm.ir.transform.Sequential(
            [
                relax.get_pipeline("zero"),
                relax.transform.MetaScheduleTuneTIR(work_dir=tmp_dir, max_trials_global=trials),
                relax.transform.MetaScheduleApplyDatabase(work_dir=tmp_dir),
            ]
        )(mod)

    mod.show()

DLight 规则

DLight 规则是一组用于调度和优化内核的默认规则。DLight 规则旨在实现快速编译和公平的性能。在某些情况下,例如语言模型,DLight 提供了出色的性能,而对于通用模型,它在性能和编译时间之间取得了平衡。

from tvm import dlight as dl

# Apply DLight rules
with target:
    mod = tvm.ir.transform.Sequential(
        [
            relax.get_pipeline("zero"),
            dl.ApplyDefaultSchedule(  # pylint: disable=not-callable
                dl.gpu.Matmul(),
                dl.gpu.GEMV(),
                dl.gpu.Reduction(),
                dl.gpu.GeneralReduction(),
                dl.gpu.Fallback(),
            ),
        ]
    )(mod)

mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    I.module_attrs({"external_mods": [metadata["runtime.Module"][0]]})
    @T.prim_func(private=True)
    def matmul(lv: T.Buffer((T.int64(1), T.int64(256)), "float32"), permute_dims1: T.Buffer((T.int64(256), T.int64(10)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(10)), "float32")):
        T.func_attr({"op_pattern": 4, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        matmul_rf_local = T.alloc_buffer((T.int64(16), T.int64(1), T.int64(10)), scope="local")
        for ax0_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(T.int64(10), thread="threadIdx.x"):
                for ax1_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                    with T.block("matmul_rf_init"):
                        vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
                        v0 = T.axis.spatial(T.int64(10), ax0_fused_0 * T.int64(10) + ax0_fused_1)
                        T.reads()
                        T.writes(matmul_rf_local[vax1_fused_1, T.int64(0), v0])
                        matmul_rf_local[vax1_fused_1, T.int64(0), v0] = T.float32(0.0)
                    for ax1_fused_0, u in T.grid(T.int64(16), 1):
                        with T.block("matmul_rf_update"):
                            vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
                            v0 = T.axis.spatial(T.int64(10), ax0_fused_0 * T.int64(10) + ax0_fused_1)
                            vax1_fused_0 = T.axis.reduce(T.int64(16), ax1_fused_0)
                            T.reads(matmul_rf_local[vax1_fused_1, T.int64(0), v0], lv[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1], permute_dims1[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0])
                            T.writes(matmul_rf_local[vax1_fused_1, T.int64(0), v0])
                            matmul_rf_local[vax1_fused_1, T.int64(0), v0] = matmul_rf_local[vax1_fused_1, T.int64(0), v0] + lv[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1] * permute_dims1[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0]
            for ax1_fused in T.thread_binding(T.int64(10), thread="threadIdx.x"):
                for ax0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                    with T.block("matmul"):
                        vax1_fused_1, v0 = T.axis.remap("RS", [ax0, ax1_fused])
                        T.reads(matmul_rf_local[vax1_fused_1, T.int64(0), v0])
                        T.writes(matmul[T.int64(0), v0])
                        with T.init():
                            matmul[T.int64(0), v0] = T.float32(0.0)
                        matmul[T.int64(0), v0] = matmul[T.int64(0), v0] + matmul_rf_local[vax1_fused_1, T.int64(0), v0]

    @T.prim_func(private=True)
    def transpose(fc2_weight: T.Buffer((T.int64(10), T.int64(256)), "float32"), T_transpose: T.Buffer((T.int64(256), T.int64(10)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0_ax1_fused_0 in T.thread_binding(T.int64(3), thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("T_transpose"):
                    v0 = T.axis.spatial(T.int64(256), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(10))
                    v1 = T.axis.spatial(T.int64(10), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(10))
                    T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < T.int64(2560))
                    T.reads(fc2_weight[v1, v0])
                    T.writes(T_transpose[v0, v1])
                    T_transpose[v0, v1] = fc2_weight[v1, v0]

    @R.function
    def forward(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            lv = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_relu_cublas", (fc1_weight, x, fc1_bias), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            permute_dims1 = R.call_tir(cls.transpose, (fc2_weight,), out_sinfo=R.Tensor((256, 10), dtype="float32"))
            gv = R.call_tir(cls.matmul, (lv, permute_dims1), out_sinfo=R.Tensor((1, 10), dtype="float32"))
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.

注意

本教程侧重于演示优化管道,而不是将性能推向极限。当前的优化可能不是最佳的。

部署优化后的模型

我们可以构建优化后的模型并将其部署到 TVM 运行时。

ex = tvm.compile(mod, target="cuda")
dev = tvm.device("cuda", 0)
vm = relax.VirtualMachine(ex, dev)
# Need to allocate data and params on GPU device
data = tvm.nd.array(np.random.rand(*input_shape).astype("float32"), dev)
gpu_params = [tvm.nd.array(np.random.rand(*p.shape).astype(p.dtype), dev) for _, p in params]
gpu_out = vm["forward"](data, *gpu_params).numpy()
print(gpu_out)
[[25195.38  26593.332 25156.406 26140.078 26229.059 26686.791 25628.742
  25817.941 25999.223 25453.73 ]]

总结

本教程演示了如何在 Apache TVM 中为 ML 模型自定义优化管道。我们可以轻松地组合优化pass,并为计算图的不同部分自定义优化。优化管道的灵活性使我们能够快速迭代优化并提高模型的性能。

由 Sphinx-Gallery 生成的图库