转换
在本节中,我们将介绍编译流程的主要组成部分 - 原始张量函数的转换。
在上一节中,我们给出了一个如何使用 TensorIR 编写 mm_relu
的示例。在实践中,实现相同功能可以有多种方法,并且每种实现方式都可能导致不同的性能。
注意
本教程主要说明 TensorIR 转换的应用,而不是深入研究优化技术。
首先,让我们看一下上一节中 mm_relu
的实现
import tvm
from tvm.script import ir as I
from tvm.script import tir as T
@I.ir_module
class MyModule:
@T.prim_func
def main(
A: T.Buffer((128, 128), "float32"),
B: T.Buffer((128, 128), "float32"),
C: T.Buffer((128, 128), "float32"),
):
T.func_attr({"tir.noalias": T.bool(True)})
Y = T.alloc_buffer((128, 128))
for i, j, k in T.grid(128, 128, 128):
with T.block("Y"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
Y[vi, vj] = T.float32(0)
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
在我们转换函数之前,让我们首先评估原始实现的性能。
import numpy as np
a_np = np.random.uniform(size=(128, 128)).astype("float32")
b_np = np.random.uniform(size=(128, 128)).astype("float32")
c_np = a_np @ b_np
a_nd = tvm.nd.array(a_np)
b_nd = tvm.nd.array(b_np)
c_nd = tvm.nd.array(np.zeros((128, 128), dtype="float32"))
def evaluate(mod: tvm.IRModule):
lib = tvm.tir.build(mod, target="llvm")
# check correctness
lib(a_nd, b_nd, c_nd)
np.testing.assert_allclose(c_nd.numpy(), c_np, rtol=1e-5)
# evaluate performance
f_timer = lib.time_evaluator("main", tvm.cpu())
print(f_timer(a_nd, b_nd, c_nd))
evaluate(MyModule)
Execution time summary:
mean (ms) median (ms) max (ms) min (ms) std (ms)
2.3102 2.3102 2.3102 2.3102 0.0000
初始化调度
我们通过建立 Schedule 辅助类来启动代码转换过程,使用提供的 MyModule 作为输入。
sch = tvm.tir.Schedule(MyModule)
循环分块
随后,我们执行必要的操作以获取对块 Y 及其关联循环的引用。
我们现在继续执行转换。最初的修改涉及将循环 j
拆分为两个单独的循环,其中内部循环的长度为 4。至关重要的是要理解转换过程是程序性的;因此,两次无意执行该块将产生错误,指出变量 j
不存在。
可以检查转换的结果,因为它保留在 sch.mod
中。
sch.mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
Y = T.alloc_buffer((128, 128))
for i, j_0, j_1, k in T.grid(128, 16, 8, 128):
with T.block("Y"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 8 + j_1)
vk = T.axis.reduce(128, k)
T.reads(A[vi, vk], B[vk, vj])
T.writes(Y[vi, vj])
with T.init():
Y[vi, vj] = T.float32(0.0)
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
T.reads(Y[vi, vj])
T.writes(C[vi, vj])
C[vi, vj] = T.max(Y[vi, vj], T.float32(0.0))
在初始转换阶段之后,已经生成了两个补充循环 j_0
和 j_1
,它们的范围分别为 32 和 4。后续操作涉及重新排序这两个循环。
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
Y = T.alloc_buffer((128, 128))
for i, j_0, k, j_1 in T.grid(128, 16, 128, 8):
with T.block("Y"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 8 + j_1)
vk = T.axis.reduce(128, k)
T.reads(A[vi, vk], B[vk, vj])
T.writes(Y[vi, vj])
with T.init():
Y[vi, vj] = T.float32(0.0)
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
T.reads(Y[vi, vj])
T.writes(C[vi, vj])
C[vi, vj] = T.max(Y[vi, vj], T.float32(0.0))
Execution time summary:
mean (ms) median (ms) max (ms) min (ms) std (ms)
0.8779 0.8779 0.8779 0.8779 0.0000
利用局部性
随后,我们将执行两个额外的转换步骤以实现不同的变体。首先,我们使用一个名为 reverse_compute_at 的原语,将块 C 重新定位到 Y 的内部循环。
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
Y = T.alloc_buffer((128, 128))
for i, j_0 in T.grid(128, 16):
for k, j_1 in T.grid(128, 8):
with T.block("Y"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 8 + j_1)
vk = T.axis.reduce(128, k)
T.reads(A[vi, vk], B[vk, vj])
T.writes(Y[vi, vj])
with T.init():
Y[vi, vj] = T.float32(0.0)
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for ax0 in range(8):
with T.block("C"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 8 + ax0)
T.reads(Y[vi, vj])
T.writes(C[vi, vj])
C[vi, vj] = T.max(Y[vi, vj], T.float32(0.0))
重写归约
到目前为止,归约初始化和更新步骤一直保持在单个块体中。这种合并形式有助于循环转换,因为初始化和更新的外部循环 i
、j
通常需要保持同步。
在循环转换之后,我们可以通过 decompose_reduction 原语将 Y 元素的初始化与归约更新分开。
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
Y = T.alloc_buffer((128, 128))
for i, j_0 in T.grid(128, 16):
for j_1_init in range(8):
with T.block("Y_init"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 8 + j_1_init)
T.reads()
T.writes(Y[vi, vj])
Y[vi, vj] = T.float32(0.0)
for k, j_1 in T.grid(128, 8):
with T.block("Y_update"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 8 + j_1)
vk = T.axis.reduce(128, k)
T.reads(Y[vi, vj], A[vi, vk], B[vk, vj])
T.writes(Y[vi, vj])
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for ax0 in range(8):
with T.block("C"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 8 + ax0)
T.reads(Y[vi, vj])
T.writes(C[vi, vj])
C[vi, vj] = T.max(Y[vi, vj], T.float32(0.0))
Execution time summary:
mean (ms) median (ms) max (ms) min (ms) std (ms)
0.3304 0.3304 0.3304 0.3304 0.0000
追踪转换
TensorIR 调度是一种过程式语言,转换以逐步方式执行。我们可以通过打印调度或调度的历史记录来追踪转换。
我们已经通过打印 sch.mod
查看了调度。我们还可以通过 sch.trace
打印调度的历史记录。
sch.trace.show()
# from tvm import tir
def apply_trace(sch: tir.Schedule) -> None:
b0 = sch.get_block(name="Y", func_name="main")
l1, l2, l3 = sch.get_loops(block=b0)
l4, l5 = sch.split(loop=l2, factors=[None, 8], preserve_unit_iters=True, disable_predication=False)
sch.reorder(l4, l3, l5)
b6 = sch.get_block(name="C", func_name="main")
sch.reverse_compute_at(block=b6, loop=l4, preserve_unit_loops=False, index=-1)
b7 = sch.decompose_reduction(block=b0, loop=l3)
或者,我们可以结合历史追踪输出 IRModule。
sch.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
Y = T.alloc_buffer((128, 128))
for i, j_0 in T.grid(128, 16):
for j_1_init in range(8):
with T.block("Y_init"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 8 + j_1_init)
T.reads()
T.writes(Y[vi, vj])
Y[vi, vj] = T.float32(0.0)
for k, j_1 in T.grid(128, 8):
with T.block("Y_update"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 8 + j_1)
vk = T.axis.reduce(128, k)
T.reads(Y[vi, vj], A[vi, vk], B[vk, vj])
T.writes(Y[vi, vj])
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for ax0 in range(8):
with T.block("C"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j_0 * 8 + ax0)
T.reads(Y[vi, vj])
T.writes(C[vi, vj])
C[vi, vj] = T.max(Y[vi, vj], T.float32(0.0))
# from tvm import tir
def apply_trace(sch: tir.Schedule) -> None:
b0 = sch.get_block(name="Y", func_name="main")
l1, l2, l3 = sch.get_loops(block=b0)
l4, l5 = sch.split(loop=l2, factors=[None, 8], preserve_unit_iters=True, disable_predication=False)
sch.reorder(l4, l3, l5)
b6 = sch.get_block(name="C", func_name="main")
sch.reverse_compute_at(block=b6, loop=l4, preserve_unit_loops=False, index=-1)
b7 = sch.decompose_reduction(block=b0, loop=l3)