理解 TensorIR 抽象
TensorIR 是 Apache TVM 中的张量程序抽象,Apache TVM 是标准的机器学习编译框架之一。张量程序抽象的主要目标是描述循环和相关的硬件加速选项,包括线程处理、专用硬件指令的应用以及内存访问。
为了帮助解释,让我们使用以下张量计算序列作为一个激励示例。具体来说,对于两个 A
和 B
,让我们执行以下两个步骤的张量计算。
上述计算类似于神经网络中常见的典型原始张量函数,即带有 relu 激活的线性层。我们使用 TensorIR 来描述上述计算,如下所示。
在我们调用 TensorIR 之前,让我们使用带有 NumPy 的原生 Python 代码来展示计算过程
def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
Y = np.empty((128, 128), dtype="float32")
for i in range(128):
for j in range(128):
for k in range(128):
if k == 0:
Y[i, j] = 0
Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
for i in range(128):
for j in range(128):
C[i, j] = max(Y[i, j], 0)
考虑到低级 NumPy 示例,现在我们准备介绍 TensorIR。下面的代码块展示了 mm_relu
的 TensorIR 实现。该特定代码是用一种名为 TVMScript 的语言实现的,TVMScript 是一种嵌入在 python AST 中的领域特定方言。
@tvm.script.ir_module
class MyModule:
@T.prim_func
def mm_relu(A: T.Buffer((128, 128), "float32"),
B: T.Buffer((128, 128), "float32"),
C: T.Buffer((128, 128), "float32")):
Y = T.alloc_buffer((128, 128), dtype="float32")
for i, j, k in T.grid(128, 128, 128):
with T.block("Y"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j)
vk = T.axis.reduce(128, k)
with T.init():
Y[vi, vj] = T.float32(0)
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for i, j in T.grid(128, 128):
with T.block("C"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j)
C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
接下来,让我们分析上述 TensorIR 程序中的元素。
函数参数和缓冲区
函数参数对应于 numpy 函数上相同的一组参数。
# TensorIR
def mm_relu(A: T.Buffer((128, 128), "float32"),
B: T.Buffer((128, 128), "float32"),
C: T.Buffer((128, 128), "float32")):
...
# NumPy
def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
...
这里 A
、B
和 C
采用名为 T.Buffer
的类型,该类型带有形状参数 (128, 128)
和数据类型 float32
。这些附加信息有助于可能的 MLC 过程生成专门针对形状和数据类型的代码。
类似地,TensorIR 也在中间结果分配中使用缓冲区类型。
# TensorIR
Y = T.alloc_buffer((128, 128), dtype="float32")
# NumPy
Y = np.empty((128, 128), dtype="float32")
循环迭代
循环迭代也存在直接的对应关系。
T.grid
是 TensorIR 中的语法糖,用于我们编写多个嵌套的迭代器。
# TensorIR with `T.grid`
for i, j, k in T.grid(128, 128, 128):
...
# TensorIR with `range`
for i in range(128):
for j in range(128):
for k in range(128):
...
# NumPy
for i in range(128):
for j in range(128):
for k in range(128):
...
计算块
一个显著的区别在于计算语句:TensorIR 引入了一个额外的结构,称为 T.block
。
# TensorIR
with T.block("Y"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j)
vk = T.axis.reduce(128, k)
with T.init():
Y[vi, vj] = T.float32(0)
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
# NumPy
vi, vj, vk = i, j, k
if vk == 0:
Y[vi, vj] = 0
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
块 表示 TensorIR 内的基本计算单元。重要的是,块包含比标准 NumPy 代码更多的信息。它包含一组块轴 (vi, vj, vk)
以及围绕它们描绘的计算。
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j)
vk = T.axis.reduce(128, k)
以上三行以以下语法声明了关于块轴的 关键属性。
[block_axis] = T.axis.[axis_type]([axis_range], [mapped_value])
这三行传达了以下细节
它们指定了
vi
、vj
、vk
的绑定(在本例中,分别绑定到i
、j
、k
)。它们声明了
vi
、vj
、vk
的原始范围(即T.axis.spatial(128, i)
中的 128)。它们声明了迭代器的属性(spatial,reduce)。
块轴属性
让我们更深入地研究块轴的属性。这些属性表示轴与正在进行的计算的关系。该块包含三个轴 vi
、vj
和 vk
,同时该块读取缓冲区 A[vi, vk]
、B[vk, vj]
并写入缓冲区 Y[vi, vj]
。严格来说,该块对 Y 执行(归约)更新,我们暂时将其标记为写入,因为我们不需要来自另一个块的 Y 的值。
值得注意的是,对于固定的 vi
和 vj
值,计算块在 Y
的空间位置(Y[vi, vj]
)产生一个点值,该值独立于 Y
中的其他位置(具有不同的 vi
、vj
值)。我们可以将 vi
、vj
称为 空间轴,因为它们直接对应于块写入的缓冲区空间区域的起始位置。参与归约的轴(vk
)被指定为 归约轴。
为何块中需要额外信息
一个重要的观察是,额外的信息(块轴范围及其属性)使得块在执行独立于外部循环嵌套 i, j, k
的迭代时是 自包含的。
块轴信息还提供了额外的属性,这些属性帮助我们验证用于执行计算的外部循环的正确性。例如,上面的代码块将导致错误,因为循环期望大小为 128 的迭代器,但我们仅将其绑定到大小为 127 的 for 循环。
# wrong program due to loop and block iteration mismatch
for i in range(127):
with T.block("C"):
vi = T.axis.spatial(128, i)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
error here due to iterator size mismatch
...
块轴绑定的语法糖
在每个块轴直接映射到外部循环迭代器的情况下,我们可以使用 T.axis.remap
在单行中声明块轴。
# SSR means the properties of each axes are "spatial", "spatial", "reduce"
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
这等效于
vi = T.axis.spatial(range_of_i, i)
vj = T.axis.spatial(range_of_j, j)
vk = T.axis.reduce (range_of_k, k)
因此,我们也可以如下编写程序。
@tvm.script.ir_module
class MyModuleWithAxisRemapSugar:
@T.prim_func
def mm_relu(A: T.Buffer((128, 128), "float32"),
B: T.Buffer((128, 128), "float32"),
C: T.Buffer((128, 128), "float32")):
Y = T.alloc_buffer((128, 128), dtype="float32")
for i, j, k in T.grid(128, 128, 128):
with T.block("Y"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
Y[vi, vj] = T.float32(0)
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = T.max(Y[vi, vj], T.float32(0))