理解 TensorIR 抽象

TensorIR 是 Apache TVM 中的张量程序抽象,Apache TVM 是标准的机器学习编译框架之一。张量程序抽象的主要目标是描述循环和相关的硬件加速选项,包括线程处理、专用硬件指令的应用以及内存访问。

为了帮助解释,让我们使用以下张量计算序列作为一个激励示例。具体来说,对于两个 128×128 矩阵 AB,让我们执行以下两个步骤的张量计算。

Yi,j=kAi,k×Bk,jCi,j=relu(Yi,j)=max(Yi,j,0)

上述计算类似于神经网络中常见的典型原始张量函数,即带有 relu 激活的线性层。我们使用 TensorIR 来描述上述计算,如下所示。

在我们调用 TensorIR 之前,让我们使用带有 NumPy 的原生 Python 代码来展示计算过程

def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    Y = np.empty((128, 128), dtype="float32")
    for i in range(128):
        for j in range(128):
            for k in range(128):
                if k == 0:
                    Y[i, j] = 0
                Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
    for i in range(128):
        for j in range(128):
            C[i, j] = max(Y[i, j], 0)

考虑到低级 NumPy 示例,现在我们准备介绍 TensorIR。下面的代码块展示了 mm_relu 的 TensorIR 实现。该特定代码是用一种名为 TVMScript 的语言实现的,TVMScript 是一种嵌入在 python AST 中的领域特定方言。

@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def mm_relu(A: T.Buffer((128, 128), "float32"),
                B: T.Buffer((128, 128), "float32"),
                C: T.Buffer((128, 128), "float32")):
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                vk = T.axis.reduce(128, k)
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))

接下来,让我们分析上述 TensorIR 程序中的元素。

函数参数和缓冲区

函数参数对应于 numpy 函数上相同的一组参数。

# TensorIR
def mm_relu(A: T.Buffer((128, 128), "float32"),
            B: T.Buffer((128, 128), "float32"),
            C: T.Buffer((128, 128), "float32")):
    ...
# NumPy
def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    ...

这里 ABC 采用名为 T.Buffer 的类型,该类型带有形状参数 (128, 128) 和数据类型 float32。这些附加信息有助于可能的 MLC 过程生成专门针对形状和数据类型的代码。

类似地,TensorIR 也在中间结果分配中使用缓冲区类型。

# TensorIR
Y = T.alloc_buffer((128, 128), dtype="float32")
# NumPy
Y = np.empty((128, 128), dtype="float32")

循环迭代

循环迭代也存在直接的对应关系。

T.grid 是 TensorIR 中的语法糖,用于我们编写多个嵌套的迭代器。

# TensorIR with `T.grid`
for i, j, k in T.grid(128, 128, 128):
    ...
# TensorIR with `range`
for i in range(128):
    for j in range(128):
        for k in range(128):
            ...
# NumPy
for i in range(128):
    for j in range(128):
        for k in range(128):
            ...

计算块

一个显著的区别在于计算语句:TensorIR 引入了一个额外的结构,称为 T.block

# TensorIR
with T.block("Y"):
    vi = T.axis.spatial(128, i)
    vj = T.axis.spatial(128, j)
    vk = T.axis.reduce(128, k)
    with T.init():
        Y[vi, vj] = T.float32(0)
    Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
# NumPy
vi, vj, vk = i, j, k
if vk == 0:
    Y[vi, vj] = 0
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]

表示 TensorIR 内的基本计算单元。重要的是,块包含比标准 NumPy 代码更多的信息。它包含一组块轴 (vi, vj, vk) 以及围绕它们描绘的计算。

vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j)
vk = T.axis.reduce(128, k)

以上三行以以下语法声明了关于块轴的 关键属性

[block_axis] = T.axis.[axis_type]([axis_range], [mapped_value])

这三行传达了以下细节

  • 它们指定了 vivjvk 的绑定(在本例中,分别绑定到 ijk)。

  • 它们声明了 vivjvk 的原始范围(即 T.axis.spatial(128, i) 中的 128)。

  • 它们声明了迭代器的属性(spatial,reduce)。

块轴属性

让我们更深入地研究块轴的属性。这些属性表示轴与正在进行的计算的关系。该块包含三个轴 vivjvk,同时该块读取缓冲区 A[vi, vk]B[vk, vj] 并写入缓冲区 Y[vi, vj]。严格来说,该块对 Y 执行(归约)更新,我们暂时将其标记为写入,因为我们不需要来自另一个块的 Y 的值。

值得注意的是,对于固定的 vivj 值,计算块在 Y 的空间位置(Y[vi, vj])产生一个点值,该值独立于 Y 中的其他位置(具有不同的 vivj 值)。我们可以将 vivj 称为 空间轴,因为它们直接对应于块写入的缓冲区空间区域的起始位置。参与归约的轴(vk)被指定为 归约轴

为何块中需要额外信息

一个重要的观察是,额外的信息(块轴范围及其属性)使得块在执行独立于外部循环嵌套 i, j, k 的迭代时是 自包含的

块轴信息还提供了额外的属性,这些属性帮助我们验证用于执行计算的外部循环的正确性。例如,上面的代码块将导致错误,因为循环期望大小为 128 的迭代器,但我们仅将其绑定到大小为 127 的 for 循环。

# wrong program due to loop and block iteration mismatch
for i in range(127):
    with T.block("C"):
        vi = T.axis.spatial(128, i)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
        error here due to iterator size mismatch
        ...

块轴绑定的语法糖

在每个块轴直接映射到外部循环迭代器的情况下,我们可以使用 T.axis.remap 在单行中声明块轴。

# SSR means the properties of each axes are "spatial", "spatial", "reduce"
vi, vj, vk = T.axis.remap("SSR", [i, j, k])

这等效于

vi = T.axis.spatial(range_of_i, i)
vj = T.axis.spatial(range_of_j, j)
vk = T.axis.reduce (range_of_k, k)

因此,我们也可以如下编写程序。

@tvm.script.ir_module
class MyModuleWithAxisRemapSugar:
    @T.prim_func
    def mm_relu(A: T.Buffer((128, 128), "float32"),
                B: T.Buffer((128, 128), "float32"),
                C: T.Buffer((128, 128), "float32")):
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))