tvm.tir

张量级 IR 的命名空间

class tvm.tir.Buffer

TVM 中的符号数据缓冲区。

Buffer 提供了一种在 TVM 中表示数据结构的数据布局特例化的方法。

不要直接构造,请使用 decl_buffer() 代替。 有关更多详细信息,请参阅 decl_buffer() 的文档。

另请参阅

decl_buffer

声明缓冲区

access_ptr(access_mask, ptr_type='handle', content_lanes=1, offset=0, extent=None)

获取缓冲区头部的访问指针。

这是与外部函数交互时获取缓冲区数据指针的推荐方法。

参数:
  • access_mask (int) – 访问模式 MASK。 指示访问将读取还是写入数据内容。

  • ptr_type (str, 可选) – 结果指针的数据类型。 除非我们想将指针强制转换为特定类型,否则不要指定。

  • content_lanes (int, 可选) – 数据类型的通道数。 对于向量类型,此值大于 1。

  • offset (Expr, 可选) – 指针的偏移量。 我们可以使用它从 ptr 的地址偏移元素数量。

  • extent (Expr, 可选) – 指针的范围。

示例

# Get access ptr for read
buffer.access_ptr("r")
# Get access ptr for read/write with bitmask
buffer.access_ptr(Buffer.READ | Buffer.WRITE)
# Get access ptr for read/write with str flag
buffer.access_ptr("rw")
# Get access ptr for read with offset
buffer.access_ptr("r", offset = 100)
# Get access ptr for read with extent
buffer.access_ptr("r", extent = 100)
vload(begin, dtype=None, predicate=None)

生成一个从起始索引加载 dtype 的 Expr。

参数:
  • begin (Array of Expr) – 以 Buffer.dtype 为单位的起始索引

  • dtype (str) – 要加载的数据类型,可以是向量类型,其通道数是 Buffer.dtype 的倍数

  • predicate (Optional[PrimExpr]) – 一个布尔值向量掩码,指示要加载的向量的哪些通道。 掩码的通道数必须等于要加载的通道数。

返回值:

load – 相应的加载表达式。

返回类型:

Expr

vstore(begin, value, predicate=None)

生成一个将值存储到起始索引的 Stmt。

参数:
  • begin (Array of Expr) – 以 Buffer.dtype 为单位的起始索引

  • value (Expr) – 要存储的值。

  • predicate (Optional[PrimExpr]) – 一个布尔值向量掩码,指示要存储的向量的哪些通道。 掩码的通道数必须等于值中的通道数。

返回值:

store – 相应的存储 stmt。

返回类型:

Stmt

scope()

返回与此缓冲区关联的存储范围。 :returns: scope – 与此缓冲区关联的存储范围。 :rtype: str

get_flattened_buffer()

生成一个作为此缓冲区扁平化版本的 Buffer。

返回值:

flattened – 相应的扁平缓冲区。

返回类型:

Buffer

offset_of(indices)

确定扁平缓冲区中提供的索引的偏移量。

参数:

indices (Union[PrimExpr, List[PrimExpr]]) – 原始缓冲区中元素的索引。

返回值:

flattened_indices – 扁平缓冲区中元素的偏移索引。

返回类型:

List[PrimExpr]

tvm.tir.decl_buffer(shape, dtype=None, name='buffer', data=None, strides=None, elem_offset=None, scope='', data_alignment=-1, offset_factor=0, buffer_type='', axis_separators=None, span=None)

声明一个新的符号缓冲区。

通常,缓冲区在 lower 和 build 期间自动创建。 仅当用户想要指定自己的缓冲区布局时才需要这样做。

有关缓冲区用法的详细讨论,请参见下面的注释。

参数:
  • shape (tuple of Expr) – 缓冲区的形状。

  • dtype (str, 可选) – 缓冲区的数据类型。

  • name (str, 可选) – 缓冲区的名称。

  • data (tir.Var, 可选) – 缓冲区中的数据指针。

  • strides (array of Expr) – 缓冲区的步幅。

  • elem_offset (Expr, 可选) – 数组到数据的起始偏移量。 以 dtype 的元素数量表示。

  • scope (str, 可选) – 缓冲区的存储范围,如果不是全局范围。 如果 scope 等于空字符串,则表示它是全局内存。

  • data_alignment (int, 可选) – 数据指针的对齐(以字节为单位)。 如果传递 -1,则对齐将设置为 TVM 的内部默认值。

  • offset_factor (int, 可选) – elem_offset 字段的因子,设置后,elem_offset 必须是 offset_factor 的倍数。 如果传递 0,则对齐将设置为 1。 如果传递非零值,如果 elem_offset 不为 None,我们将为 elem_offset 创建一个 tir.Var。

  • buffer_type (str, 可选, {"", "auto_broadcast"}) – auto_broadcast 缓冲区允许用户实现广播计算,而无需考虑维度大小是否等于 1。 如果维度 j 的形状等于 1,则 TVM 将 buffer[i][j][k] 映射到 buffer[i][0][k]。

  • axis_separators (list of int, 可选) – 如果传递,则为轴组之间分隔符的列表,每个轴组都扁平化为输出轴。 对于平面内存空间,应为 None 或空列表。

  • span (Optional[Span]) – decl_buffer 创建在源代码中的位置。

返回值:

buffer – 创建的缓冲区

返回类型:

tvm.tir.Buffer

注意

缓冲区数据结构反映了 dlpack 中的 DLTensor 结构。 虽然 DLTensor 数据结构非常通用,但通常有助于创建仅处理数据结构特定情况的函数,并使编译后的函数从中受益。

如果用户在构造函数时将 strides 和 elem_offset 作为 None 传递,则该函数将专门用于紧凑且对齐的 DLTensor。 如果用户将完全通用的符号数组传递给 strides,则生成的函数将变为完全通用。

class tvm.tir.DataProducer
class tvm.tir.Layout

布局由大写字母、小写字母和数字组成,其中大写字母表示主轴,相应的小写字母和因子大小表示从属轴。 例如,NCHW16c 可以描述一个 5-D 张量 [batch_size, channel, height, width, channel_block]。 这里从属轴 channel_block=16 是主轴 C(通道)的因子大小。

另请参阅

layout

声明布局

index_of(axis)

获取轴的索引

参数:

axis (str) – 轴名称,需要为 [a-z,A-Z]

返回值:

index – 轴的索引,如果未找到则为 -1。

返回类型:

int

factor_of(axis)

获取从属轴的因子大小。

参数:

axis (str) – 轴名称,需要为 [a-z,A-Z]

返回值:

factor – 轴的从属轴的大小(如果 axis 是主轴),或轴本身的大小(如果 axis 是从属轴)。 如果轴不在布局中,则返回 -1。

返回类型:

int

class tvm.tir.BijectiveLayout

两个布局(src-layout 和 dst-layout)的双射映射。 它提供彼此之间的形状和索引转换。

不要直接构造,请使用 bijective_layout 代替。 有关更多详细信息,请参阅 bijective_layout 的文档。

参数:
  • src_layout (str or Layout) – 源布局。

  • dst_layout (str or Layout) – 目标布局。

另请参阅

bijective_layout

声明布局

forward_index(index)

给定 src-layout 的索引,推断 dst 索引。

参数:

index (Array of Expr) – src-layout 中的索引。

返回值:

dst_index – dst-layout 中推断的索引。

返回类型:

Array of Expr

backward_index(index)

给定 dst-layout 的索引,推断 src 索引。

参数:

index (Array of Expr) – dst-layout 中的索引。

返回值:

src_index – src-layout 中推断的索引。

返回类型:

Array of Expr

forward_shape(shape)

给定 src-layout 的形状,推断 dst 形状。

参数:

shape (Array of Expr) – src-layout 中的形状。

返回值:

dst_shape – dst-layout 中推断的形状。

返回类型:

Array of Expr

backward_shape(shape)

给定 dst-layout 的形状,推断 src 形状。

参数:

shape (Array of Expr) – dst-layout 中的形状。

返回值:

src_shape – src-layout 中推断的形状。

返回类型:

Array of Expr

tvm.tir.bijective_layout(src_layout: str | Layout, dst_layout: str | Layout) BijectiveLayout

创建双射布局映射。

参数:
  • src_layout (str or Layout) – 源布局。

  • dst_layout (str or Layout) – 目标布局。

返回值:

bijective_layout – 创建的双射布局

返回类型:

BijectiveLayout

tvm.tir.layout(layout_str: str, dtype: str = 'int32') Layout

从字符串创建布局节点。

参数:
  • layout_str (str) – 布局表示形式由大写字母、小写字母和数字组成,其中大写字母表示主轴,相应的小写字母和因子大小表示从属轴。 例如,NCHW16c 可以描述一个 5-D 张量 [batch_size, channel, height, width, channel_block]。 这里从属轴 channel_block=16 是主轴 C(通道)的因子大小。

  • dtype (str) – 返回的布局中生成的轴变量的 dtype。 它必须是整数类型。

返回值:

layout – 创建的布局

返回类型:

Layout

class tvm.tir.Var(name: str, dtype: str | Type, span: Span | None = None)

符号变量。

参数:
  • name (str) – 名称

  • dtype (Union[str, ir.Type]) – 数据类型

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.SizeVar(name: str, dtype: str | Type, span: Span | None = None)
用于表示张量索引大小的符号变量

它大于或等于零。

参数:
  • name (str) – 名称

  • dtype (Union[str, ir.Type]) – 数据类型

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.Reduce(combiner: CommReducer, src: List[PrimExpr], rdom: List[IterVar], condition: PrimExpr, value_index: int, init: List[PrimExpr] | None = None, span: Span | None = None)

Reduce 节点。

参数:
  • combiner (CommReducer) – 合并器。

  • src (list of Expr) – 源表达式。

  • rdom (list of IterVar) – 迭代域

  • condition (PrimExpr) – reduce 条件。

  • value_index (int) – 值索引。

  • init (list of Expr) – 输出的初始值。 这可以是 int、float 或 ProducerLoad

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.FloatImm(dtype: str, value: float, span: Span | None = None)

浮点常量。

参数:
  • dtype (str) – 数据类型

  • value (float) – 常数值。

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.IntImm(dtype: str, value: int, span: Span | None = None)

整数常量。

参数:
  • dtype (str) – 数据类型

  • value (int) – 常量值。

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.StringImm(value: str, span: Span | None = None)

字符串常量。

参数:
  • value (str) – 函数的值。

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.Cast(dtype, value, span: Span | None = None)

类型转换表达式。

参数:
  • dtype (str) – 数据类型

  • value (PrimExpr) – 函数的值。

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.Add(a: PrimExpr, b: PrimExpr, span: Span | None = None)

加法节点。

参数:
  • a (PrimExpr) – 左操作数。

  • b (PrimExpr) – 右操作数。

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.Sub(a: PrimExpr, b: PrimExpr, span: Span | None = None)

减法节点。

参数:
  • a (PrimExpr) – 左操作数。

  • b (PrimExpr) – 右操作数。

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.Mul(a: PrimExpr, b: PrimExpr, span: Span | None = None)

乘法节点。

参数:
  • a (PrimExpr) – 左操作数。

  • b (PrimExpr) – 右操作数。

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.Div(a: PrimExpr, b: PrimExpr, span: Span | None = None)

除法节点。

参数:
  • a (PrimExpr) – 左操作数。

  • b (PrimExpr) – 右操作数。

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.Mod(a: PrimExpr, b: PrimExpr, span: Span | None = None)

取模节点。

参数:
  • a (PrimExpr) – 左操作数。

  • b (PrimExpr) – 右操作数。

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.FloorDiv(a: PrimExpr, b: PrimExpr, span: Span | None = None)

向下取整除法节点。

参数:
  • a (PrimExpr) – 左操作数。

  • b (PrimExpr) – 右操作数。

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.FloorMod(a: PrimExpr, b: PrimExpr, span: Span | None = None)

向下取整取模节点。

参数:
  • a (PrimExpr) – 左操作数。

  • b (PrimExpr) – 右操作数。

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.Min(a: PrimExpr, b: PrimExpr, span: Span | None = None)

最小值节点。

参数:
  • a (PrimExpr) – 左操作数。

  • b (PrimExpr) – 右操作数。

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.Max(a: PrimExpr, b: PrimExpr, span: Span | None = None)

最大值节点。

参数:
  • a (PrimExpr) – 左操作数。

  • b (PrimExpr) – 右操作数。

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.EQ(a: PrimExpr, b: PrimExpr, span: Span | None = None)

等于节点。

参数:
  • a (PrimExpr) – 左操作数。

  • b (PrimExpr) – 右操作数。

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.NE(a: PrimExpr, b: PrimExpr, span: Span | None = None)

不等于节点。

参数:
  • a (PrimExpr) – 左操作数。

  • b (PrimExpr) – 右操作数。

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.LT(a: PrimExpr, b: PrimExpr, span: Span | None = None)

小于节点。

参数:
  • a (PrimExpr) – 左操作数。

  • b (PrimExpr) – 右操作数。

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.LE(a: PrimExpr, b: PrimExpr, span: Span | None = None)

小于等于节点。

参数:
  • a (PrimExpr) – 左操作数。

  • b (PrimExpr) – 右操作数。

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.GT(a: PrimExpr, b: PrimExpr, span: Span | None = None)

大于节点。

参数:
  • a (PrimExpr) – 左操作数。

  • b (PrimExpr) – 右操作数。

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.GE(a: PrimExpr, b: PrimExpr, span: Span | None = None)

大于等于节点。

参数:
  • a (PrimExpr) – 左操作数。

  • b (PrimExpr) – 右操作数。

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.And(a: PrimExpr, b: PrimExpr, span: Span | None = None)

逻辑与节点。

参数:
  • a (PrimExpr) – 左操作数。

  • b (PrimExpr) – 右操作数。

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.Or(a: PrimExpr, b: PrimExpr, span: Span | None = None)

逻辑或节点。

参数:
  • a (PrimExpr) – 左操作数。

  • b (PrimExpr) – 右操作数。

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.Not(a: PrimExpr, span: Span | None = None)

逻辑非节点。

参数:
  • a (PrimExpr) – 输入值

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.Select(condition: PrimExpr, true_value: PrimExpr, false_value: PrimExpr, span: Span | None = None)

选择节点。

注意

Select 可能会计算 true_value 和 false_value。如果想要获取仅评估正确分支的条件表达式,请使用 tvm.tir.if_then_else 代替。

参数:
  • condition (PrimExpr) – 条件表达式。

  • true_value (PrimExpr) – 当条件为真时取的值。

  • false_value (PrimExpr) – 当条件为假时取的值。

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.BufferLoad(buffer: Buffer, indices: List[PrimExpr], predicate: PrimExpr | None = None, span: Span | None = None)

Buffer 加载节点。

参数:
  • buffer (Buffer) – 要加载的 buffer。

  • indices (List[PrimExpr]) – 用于加载值的 buffer 索引。

  • span (Optional[Span]) – 此表达式在源代码中的位置。

  • predicate (Optional[PrimExpr]) – 一个布尔值向量掩码,指示要加载的向量的哪些通道。 掩码的通道数必须等于要加载的通道数。

class tvm.tir.ProducerLoad(producer: DataProducer, indices: List[PrimExpr], span: Span | None = None)

Producer 加载节点。

参数:
  • producer (DataProducer) – 要加载的 buffer。

  • indices (List[PrimExpr]) – buffer 索引。

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.Ramp(base: PrimExpr, stride: PrimExpr, lanes: PrimExpr, span: Span | None = None)

Ramp 节点。

参数:
  • base (PrimExpr) – 基表达式。

  • stride (PrimExpr) – ramp 的步长。

  • lanes (PrimExpr) – 表达式的通道数。

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.Broadcast(value: PrimExpr, lanes: PrimExpr, span: Span | None = None)

Broadcast 节点。

参数:
  • value (PrimExpr) – 表达式的值。

  • lanes (PrimExpr) – 表达式的通道数。

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.Shuffle(vectors: List[PrimExpr], indices: List[PrimExpr], span: Span | None = None)

Shuffle 节点。

参数:
  • vectors (List[PrimExpr]) – 向量

  • indices (List[PrimExpr]) – 索引

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.Call(dtype: str, op: Op | str, args: List[PrimExpr], span: Span | None = None)

tir.Call 节点。

参数:
  • dtype (str) – 返回数据类型

  • op (Union[Op, str]) – 要调用的函数,或全局 tvm.Op 的名称

  • args (list of Expr) – 调用的输入参数

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.CallEffectKind

tir.Call 效果的可能种类。

class tvm.tir.Let(var: Var, value: PrimExpr, body: PrimExpr, span: Span | None = None)

Let 节点。

参数:
  • var (tir.Var) – 绑定中的变量。

  • value (PrimExpr) – 要绑定的值。

  • body (PrimExpr) – 主体表达式。

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.IterVar(dom: Range, var: Var | str, iter_type: int, thread_tag: str = '', span: Span | None = None)

表示迭代变量。

IterVar 表示计算中的轴迭代。

参数:
  • dom (Range) – 迭代的域。

  • var (Union[tir.Var, str]) – 用于迭代的内部变量。

  • iter_type (int) – 迭代类型。

  • thread_tag (str) – 线程类型标签。

  • span (Optional[Span]) – 此表达式在源代码中的位置。

另请参阅

te.thread_axis

创建线程轴 IterVar。

te.reduce_axis

创建规约轴 IterVar。

class tvm.tir.CommReducer(lhs: List[Var], rhs: List[Var], result: List[PrimExpr], identity_element: List[PrimExpr], span: Span | None = None)

交换规约运算符

参数:
  • lhs (List[tir.Var]) – 规约器的左侧参数。

  • rhs (List[tir.Var]) – 规约器的右侧参数。

  • result (List[PrimExpr]) – 规约结果。

  • identity_element (List[PrimExpr]) – 单位元素。

  • span (Optional[Span]) – 此表达式在源代码中的位置。

class tvm.tir.Any(span: Span | None = None)

Any 节点。

spanOptional[Span]

此表达式在源代码中的位置。

class tvm.tir.Stmt

所有语句的基类。

class tvm.tir.LetStmt(var: Var, value: PrimExpr, body: Stmt, span: Span | None = None)

LetStmt 节点。

参数:
  • var (tir.Var) – 绑定中的变量。

  • value (PrimExpr) – 要绑定的值。

  • body (Stmt) – 主体语句。

  • span (Optional[Span]) – 语句在源代码中的位置。

class tvm.tir.AssertStmt(condition: PrimExpr, message: PrimExpr, body: Stmt, span: Span | None = None)

AssertStmt 节点。

参数:
  • condition (PrimExpr) – 断言条件。

  • message (PrimExpr) – 错误消息。

  • body (tvm.tir.Stmt) – 主体语句。

  • span (Optional[Span]) – 语句在源代码中的位置。

class tvm.tir.ForKind(value)

for 循环的种类。

注意

ForKind 可以更改循环的控制流语义,需要在所有 TIR 通道中考虑。

class tvm.tir.For(loop_var: Var, min: PrimExpr, extent: PrimExpr, kind: ForKind, body: Stmt, thread_binding: IterVar | None = None, annotations: Mapping[str, Object] | None = None, span: Span | None = None)

For 节点。

参数:
  • loop_var (tir.Var) – 循环变量。

  • min (PrimExpr) – 起始值。

  • extent (PrimExpr) – 循环的长度。

  • kind (ForKind) – for 循环的类型。

  • body (Stmt) – 主体语句。

  • thread_binding (Optional[tir.IterVar]) – 此循环绑定到的线程。仅当 kind 为 ThreadBinding 时有效

  • annotations (Optional[Mapping[str, Object]]) – 附加的注解提示。

  • span (Optional[Span]) – 语句在源代码中的位置。

class tvm.tir.While(condition: PrimExpr, body: Stmt, span: Span | None = None)

While 节点。

参数:
  • condition (PrimExpr) – 终止条件。

  • body (Stmt) – 主体语句。

  • span (Optional[Span]) – 语句在源代码中的位置。

class tvm.tir.BufferStore(buffer: Buffer, value: PrimExpr, indices: List[PrimExpr], predicate: PrimExpr | None = None, span: Span | None = None)

Buffer 存储节点。

参数:
  • buffer (Buffer) – 缓冲区。

  • value (PrimExpr) – 要存储的值。

  • indices (List[PrimExpr]) – 要存储的索引位置。

  • predicate (Optional[PrimExpr]) – 一个布尔值向量掩码,指示要存储的向量的哪些通道。 掩码的通道数必须等于值中的通道数。

  • span (Optional[Span]) – 语句在源代码中的位置。

class tvm.tir.BufferRealize(buffer: Buffer, bounds: List[Range], condition: PrimExpr, body: Stmt, span: Span | None = None)

Buffer 实现节点。

参数:
  • buffer (Buffer) – 缓冲区。

  • bounds (List[Range]) – 要存储的值。

  • condition (PrimExpr) – 实现条件。

  • body (Stmt) – 语句的主体。

  • span (Optional[Span]) – 语句在源代码中的位置。

class tvm.tir.ProducerStore(producer: DataProducer, value: PrimExpr, indices: List[PrimExpr], span: Span | None = None)

ProducerStore 节点。

参数:
  • producer (DataProducer) – 数据生产者。

  • value (PrimExpr) – 要存储的值。

  • indices (list of Expr) – 存储的索引参数。

  • span (Optional[Span]) – 语句在源代码中的位置。

class tvm.tir.Allocate(buffer_var: Var, dtype: str, extents: List[PrimExpr], condition: PrimExpr, body: Stmt, annotations: Mapping[str, Object] | None = None, span: Span | None = None)

Allocate 节点。

参数:
  • buffer_var (tir.Var) – 缓冲区变量。

  • dtype (str) – 缓冲区的数据类型。

  • extents (list of Expr) – allocate 的范围

  • condition (PrimExpr) – 条件。

  • body (Stmt) – 主体语句。

  • annotations (Optional[Mapping[str, Object]]) – 附加的注解提示

  • span (Optional[Span]) – 语句在源代码中的位置。

class tvm.tir.AllocateConst(buffer_var: Var, dtype: str, extents: List[PrimExpr], data_or_idx: NDArray | int, body: Stmt, annotations: Mapping[str, Object] | None = None, span: Span | None = None)

分配常量节点。

参数:
  • buffer_var (tir.Var) – 缓冲区变量。

  • dtype (str) – 缓冲区的数据类型。

  • extents (list of Expr) – allocate 的范围

  • data_or_idx (Union[NDArray, int]) – 如果是 NDArray,这是与常量关联的常量数据。如果是整数,这是指 IRModule 的 “constants” 属性的索引,该属性包含 AllocateConst

  • body (Stmt) – 主体语句。

  • annotations (Optional[Mapping[str, Object]]) – 关于分配的附加注解。

  • span (Optional[Span]) – 语句在源代码中的位置。

class tvm.tir.AttrStmt(node: Object, attr_key: str, value: PrimExpr, body: Stmt, span: Span | None = None)

AttrStmt 节点。

参数:
  • node (Object) – 要注解属性的节点

  • attr_key (str) – 属性类型键。

  • value (PrimExpr) – 属性的值

  • body (Stmt) – 主体语句。

  • span (Optional[Span]) – 语句在源代码中的位置。

class tvm.tir.DeclBuffer(buffer: Buffer, body: Stmt, span: Span | None = None)

DeclBuffer 节点。

参数:
  • buffer (Buffer) – 被声明的缓冲区。

  • body (Stmt) – 要执行的主体语句。

  • span (Optional[Span]) – 此 DeclBuffer 在源代码中的位置。

class tvm.tir.ProducerRealize(producer: DataProducer, bounds: List[Range], condition: PrimExpr, body: Stmt, storage_scope: str = '', span: Span | None = None)

ProducerRealize 节点。

参数:
  • producer (DataProducer) – 数据生产者。

  • bounds (List[Range]) – realize 的边界

  • condition (PrimExpr) – 实现条件。

  • body (Stmt) – realize 主体

  • storage_scope (str) – 与此 realization 关联的存储作用域

  • span (Optional[Span]) – 语句在源代码中的位置。

class tvm.tir.SeqStmt(seq: List[Stmt], span: Span | None = None)

语句序列。

参数:
  • seq (List[Stmt]) – 语句列表

  • span (Optional[Span]) – 语句在源代码中的位置。

class tvm.tir.IfThenElse(condition: PrimExpr, then_case: Stmt, else_case: Stmt | None, span: Span | None = None)

IfThenElse 节点。

参数:
  • condition (PrimExpr) – 条件表达式

  • then_case (Stmt) – 如果条件为真,则执行的语句。

  • else_case (Optional[Stmt]) – 如果条件为假,则执行的语句。

  • span (Optional[Span]) – 语句在源代码中的位置。

class tvm.tir.Evaluate(value: PrimExpr, span: Span | None = None)

Evaluate 节点。

参数:
  • value (PrimExpr) – 要计算的表达式。

  • span (Optional[Span]) – 语句在源代码中的位置。

class tvm.tir.Prefetch(buffer: Buffer, bounds: List[Range], span: Span | None = None)

Prefetch 节点。

参数:
  • buffer (Buffer) – 要预取的缓冲区。

  • bounds (List[Range]) – 要预取的边界。

  • span (Optional[Span]) – 语句在源代码中的位置。

tvm.tir.stmt_seq(*args: PrimExpr | Stmt) SeqStmt

创建语句序列

参数:

*args (Union[PrimExpr, Stmt]) – 要组合为序列的语句列表。

返回值:

stmt – 组合后的语句。

返回类型:

Stmt

tvm.tir.stmt_list(stmt: Stmt) List[Stmt]

从块创建语句列表。

参数:

stmt (Stmt) – 输入语句。

返回值:

stmt_list – 解包后的语句列表

返回类型:

List[Stmt]

class tvm.tir.BufferRegion(buffer: Buffer, region: List[Range])

BufferRegion 节点。

参数:
  • buffer (Buffer) – 缓冲区区域的缓冲区

  • region (List[Range]) – 缓冲区区域的区域数组

class tvm.tir.MatchBufferRegion(buffer: Buffer, source: BufferRegion)

MatchBufferRegion 节点。

参数:
class tvm.tir.Block(iter_vars: List[IterVar], reads: List[BufferRegion], writes: List[BufferRegion], name_hint: str, body: Stmt, init: Stmt | None = None, alloc_buffers: List[Buffer] | None = None, match_buffers: List[MatchBufferRegion] | None = None, annotations: Mapping[str, Object] | None = None, span: Span | None = None)

Block 节点。

参数:
  • iter_vars (List[IterVar]) – 块变量。

  • reads (List[BufferRegion]) – 块的读取缓冲区区域。

  • writes (List[BufferRegion]) – 块的写入缓冲区区域。

  • name_hint (str) – 块的 name_hint。

  • body (Stmt) – 块的主体。

  • init (Optional[Stmt]) – reduction 块的 init 块

  • alloc_buffers (Optional[list[Buffer]]) – 缓冲区分配

  • match_buffers (Optional[List[MatchBufferRegion]]) – 子区域缓冲区匹配

  • annotations (Optional[Mapping[str, Object]]) – 附加的注解提示。

  • span (Optional[Span]) – 此块在源代码中的位置。

class tvm.tir.BlockRealize(iter_values: List[PrimExpr], predicate: PrimExpr | bool, block: Block, span: Span | None = None)

BlockRealize 节点。

参数:
  • iter_values (List[PrimExpr]) – 块变量的绑定值。

  • predicate (Union[PrimExpr, bool]) – 块的谓词。

  • block (Block) – 要 realize 的块

  • span (Optional[Span]) – 此 block_realize 在源代码中的位置。

class tvm.tir.PrimFunc(params, body, ret_type=None, buffer_map=None, attrs=None, span=None)

函数声明表达式。

参数:
with_body(new_body, span=None)

创建一个具有相同签名但新主体的 PrimFunc。

参数:
  • new_body (Stmt) – 新的主体。

  • span (Optional[Span]) – 此 itervar 在源代码中的位置。

返回值:

new_func – 创建的新函数。

返回类型:

PrimFunc

specialize(param_map: Mapping[Var, PrimExpr | Buffer])

特化 PrimFunc 的参数

参数:

param_map (Mapping[tir.Var, Union[PrimExpr, Buffer]]) – 从函数参数到实例的映射

示例

我们可以使用符号形状定义 Meta TIR 函数

@T.prim_func
def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None:
    A = T.match_buffer(a, (m, n), "float32")
    B = T.match_buffer(b, (m, n), "float32")

    for i, j in T.grid(m, n):
        with T.block():
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj]

然后我们可以使用给定的形状或缓冲区使其特化。

a, _, m, n = mem_copy.params
func = mem_copy.specialize({a: tir.decl_buffer((16, 16))})
# or
func = mem_copy.specialize({n: 16, m: 16})

特化后的函数

@T.prim_func
def mem_copy_16_16(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float32")
    B = T.match_buffer(b, (16, 16), "float32")

    for i, j in T.grid(16, 16):
        with T.block():
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj]
返回值:

func – 参数特化后的新函数

返回类型:

PrimFunc

class tvm.tir.TensorIntrin(desc, impl)

张量内联函数。

参数:
  • desc (PrimFunc) – 描述计算的函数。

  • impl (PrimFunc) – 执行实现的函数。

static register(name: str, desc: PrimFunc, impl: PrimFunc, override: bool = False)

使用名称注册张量内在函数。

参数:
  • name (str) – 要注册的 TensorIntrin 的名称。

  • desc (PrimFunc) – 描述计算的函数。

  • impl (PrimFunc) – 执行实现的函数。

  • override (bool) – 是否覆盖现有的内在函数。

static get(name: str, allow_missing: bool = False) TensorIntrin | None

通过名称查找张量内在函数。

参数:
  • name (str) – 要查找的 TensorIntrin 的名称。

  • allow_missing (bool) – 是否允许缺少张量内在函数。如果为 False,则在张量内在函数

  • exist. (不存在) 时引发错误。

返回值:

result – 具有指定名称的 TensorIntrin,如果未找到则为 None。

返回类型:

Optional[TensorIntrin]

class tvm.tir.IndexMap(initial_indices, final_indices, inverse_index_map)

从多维索引到另一组多维索引的映射

参数:
  • initial_indices (List[tir.Var]) – 表示重映射之前的索引的变量。

  • final_indices (List[PrimExpr]) – 定义重映射之后的索引的表达式。

  • inverse_index_map (Union[Callable, Optional[IndexMap]]) – 可选的预定义逆索引映射。当定义此项时,IndexMap::Inverse 将返回预定义的逆索引映射。否则,逆索引映射将即时计算。用户有责任确保预定义逆索引映射的正确性。

static from_func(mapping_function: Callable, ndim: int | None = None, inverse_index_map: Callable | IndexMap | None = None, *, index_dtype: str = 'int64')

从函数创建索引映射

参数:
  • mapping_function (Callable) – 从源索引映射到目标索引的函数。该函数应接受 tir.Var 参数,并返回 tir.PrimExprtir.PrimExpr 列表。返回 tir.PrimExpr 等同于返回包含该 tir.PrimExpr 的长度为 1 的列表。

  • ndim (Optional[int]) – 此转换应应用到的缓冲区的维数。如果 mapping_function 使用可变参数 *args,则必须指定 ndim。如果 mapping_function 不使用可变参数,则 ndim 是可选的。

  • inverse_index_map (Union[Callable, Optional[IndexMap]]) – 可选的预定义逆索引映射。当定义此项时,IndexMap::Inverse 将返回预定义的逆索引映射。否则,逆索引映射将即时计算。用户有责任确保预定义逆索引映射的正确性。

返回值:

index_map – 返回表示 mapping_function 的 IndexMap。

返回类型:

IndexMap

static from_func_with_separators(mapping_function: Callable, ndim: int | None = None, inverse_index_map: Callable | IndexMap | None = None, *, index_dtype: str = 'int64')

从函数创建索引映射

参数:
  • mapping_function (Callable) – 从源索引映射到目标索引的函数。该函数应接受 tir.Var 参数,并返回 tir.PrimExpr 或列表。返回列表的每个元素都应该是 tir.PrimExpr 或对象 IndexMap.AXIS_SEPARATOR。返回 tir.PrimExpr 等同于返回包含该 tir.PrimExpr 的长度为 1 的列表。

  • ndim (Optional[int]) – 此转换应应用到的缓冲区的维数。如果 mapping_function 使用可变参数 *args,则必须指定 ndim。如果 mapping_function 不使用可变参数,则 ndim 是可选的。

  • inverse_index_map (Union[Callable, Optional[IndexMap]]) – 可选的预定义逆索引映射。当定义此项时,IndexMap::Inverse 将返回预定义的逆索引映射。否则,逆索引映射将即时计算。用户有责任确保预定义逆索引映射的正确性。

  • index_dtype (str) – 用于映射函数中输入迭代器的默认索引数据类型。

返回值:

ret – 返回一个元组,其第一个元素是表示 mapping_function 的 IndexMap,第二个索引是 IndexMap.AXIS_SEPARATOR 出现的索引列表。

返回类型:

Tuple[IndexMap, List[int]]

is_equivalent_to(other_map: IndexMap) bool

返回索引映射是否等效。

参数:

other_map (IndexMap) – 应与之进行比较的 IndexMap。

返回值:

is_equivalent – 如果两个映射表示相同的转换,则为 True,否则为 False

返回类型:

bool

map_indices(indices: List[PrimExpr]) List[PrimExpr]

将索引映射应用于一组索引

参数:

indices (List[PrimExpr]) – 要映射的索引

返回值:

result – 映射后的索引

返回类型:

List[PrimExpr]

map_shape(shape: List[PrimExpr]) List[PrimExpr]

将索引映射应用于缓冲区形状

参数:

shape (List[PrimExpr]) – 要映射的缓冲区形状

返回值:

result – 映射后的形状

返回类型:

List[PrimExpr]

map_ndarray(arr_src: NDArray) NDArray

应用此索引映射来转换输入 NDArray 的布局

参数:

arr_src (runtime.NDArray) – 要转换的 NDArray

返回值:

arr_dst – 转换后的 NDArray

返回类型:

runtime.NDArray

inverse(shape: List[Range | PrimExpr]) IndexMap

返回映射的逆

如果函数不是双射的,则抛出错误。

参数:

shape (List[Union[Range,PrimExpr]]) – 应在其中确定逆的区域。用于验证映射在此范围内是否是双射的。

返回值:

inverse – 逆

返回类型:

IndexMap

non_surjective_inverse(shape: List[Range | PrimExpr]) Tuple[IndexMap, PrimExpr]

返回映射的逆

可以应用于引入填充的转换。

参数:

shape (List[Union[Range,PrimExpr]]) – 应在其中确定逆的区域。用于确定谓词。

返回值:

result – 逆,以及谓词,对于该谓词,逆映射到输入范围内的有效索引。

返回类型:

Tuple[IndexMap, PrimExpr]

示例

index_map = IndexMap.from_func(lambda i: [i//4, i%4])
inverse_map, predicate = index_map.non_surjective_inverse([14])
assert inverse_map.is_equivalent_to(IndexMap.from_func(lambda j,k: [4*j + k])
print(predicate) # Prints "(axis0==3) && (axis2 >= 2)"
tvm.tir.call_packed_lowered(*args, span=None)

Packed 调用的降级版本。packed 函数的参数可以是 Expr 或 Buffer。当呈现 Expr 时,参数是相应的 POD 类型。当参数是 Buffer 时,相应的 PackedFunc 将接收一个 TVMArrayHandle,其内容在回调期间有效。如果 PackedFunc 是 python 回调,则相应的参数是 NDArray。

参数:
  • args (list of Expr or Buffer.) – 位置参数。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

另请参阅

te.extern

使用 extern 函数调用创建张量。

tvm.tir.call_cpacked_lowered(*args, span=None)

C-packed 调用的降级版本。与 call_packed 相同,只是第一个参数是函数名称(如 call_extern 中所示),最后一个参数是资源句柄。

参数:
  • args (list of Expr or Buffer.) – 位置参数。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

另请参阅

te.extern

使用 extern 函数调用创建张量。

tvm.tir.call_tir(global_var: GlobalVar, *args)

在同一 IRModule 中执行对另一个 PrimFunc 的调用

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.call_packed(*args, span=None)

通过调用外部 packed 函数来构建表达式。

packed 函数的参数可以是 Expr 或 Buffer。当呈现 Expr 时,参数是相应的 POD 类型。

当参数是 Buffer 时,相应的 PackedFunc 将接收一个 TVMArrayHandle,其内容在回调期间有效。如果 PackedFunc 是 python 回调,则相应的参数是 NDArray。

参数:
  • args (list of Expr or Buffer.) – 位置参数。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

另请参阅

te.extern

使用 extern 函数调用创建张量。

tvm.tir.call_cpacked(*args, span=None)

通过调用外部 packed 函数来构建表达式。

与 call_packed 相同,只是第一个参数是函数名称(如 call_extern 中所示),最后一个参数是资源句柄。

参数:
  • args (list of Expr or Buffer.) – 位置参数。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

另请参阅

te.extern

使用 extern 函数调用创建张量。

tvm.tir.call_intrin(dtype, func_name, *args, span=None)

通过调用内在函数来构建表达式。

内在函数可以通过内在函数转换规则使用多种数据类型进行重载。

参数:
  • dtype (str) – 结果的数据类型。

  • func_name (str) – 内在函数名称。

  • args (list) – 位置参数。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.call_pure_extern(dtype, func_name, *args, span=None)

通过调用纯外部函数来构建表达式。

参数:
  • dtype (str) – 结果的数据类型。

  • func_name (str) – 外部函数名称。

  • args (list) – 位置参数。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.call_extern(dtype, func_name, *args, span=None)

通过调用外部函数来构建表达式。

参数:
  • dtype (str) – 结果的数据类型。

  • func_name (str) – 外部函数名称。

  • args (list) – 位置参数。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.call_llvm_intrin(dtype, name, *args, span=None)

通过调用 llvm 内在函数来构建表达式

参数:
  • dtype (str) – 结果的数据类型。

  • name (str) – llvm 内在函数的名称。

  • args (list) – 位置参数。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.call_llvm_pure_intrin(dtype, name, *args, span=None)

通过调用纯 llvm 内在函数来构建表达式

参数:
  • dtype (str) – 结果的数据类型。

  • name (str) – llvm 内在函数的名称。

  • args (list) – 位置参数。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.ret(val, span=None)

创建 tir 返回表达式

参数:
  • val (Expr) – 返回的 tir 表达式,其数据类型为 int、float 或 void 指针。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

ret – 返回表达式

返回类型:

PrimExpr

tvm.tir.all(*args, span=None)
创建新表达式,表示参数中所有条件的交集

参数

参数:
  • args (list) – 符号布尔表达式列表

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

expr – 表达式

返回类型:

Expr

tvm.tir.any(*args, span=None)

创建新表达式,表示参数中所有条件的并集

参数:
  • args (list) – 符号布尔表达式列表

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

expr – 表达式

返回类型:

Expr

tvm.tir.min_value(dtype, span=None)

dtype 的最小值

参数:
  • dtype (str) – 数据类型。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

value – dtype 的最小值。

返回类型:

tvm.Expr

tvm.tir.max_value(dtype: str, span: Span | None = None) Any

dtype 的最大值

参数:
  • dtype (str) – 数据类型。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

value – dtype 的最大值。

返回类型:

tvm.Expr

tvm.tir.trace(args, trace_action='tvm.default_trace_action')

在运行时跟踪张量数据。

trace 函数允许在运行时跟踪特定的张量。跟踪值应作为最后一个参数出现。应指定跟踪操作,默认情况下使用 tvm.default_trace_action。

参数:
  • args (list of Expr or Buffers.) – 位置参数。

  • trace_action (str.) – 跟踪操作的名称。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

另请参阅

tvm.tir.call_packed

创建 packed 函数。

tvm.tir.tvm_check_return(expected, return_unexpected, nested_call)

返回堆栈上新的 dtype[num] :param expected: 预期返回代码。:type expected: int :param return_unexpected: 意外返回代码。:type return_unexpected: int :param nested_call: 要检查返回的调用表达式。:type nested_call: PrimExpr

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.tvm_stack_alloca(dtype_str, num)

返回堆栈上新的 dtype[num]

参数:
  • dtype_str (str) – 数组的数据类型。

  • num (int) – 数组的大小。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.tvm_stack_make_shape(*args)

在堆栈上分配形状元组,返回句柄

参数:

args (int) – 元组形状。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.tvm_stack_make_array(data, shape, strides, ndim, arr_dtype, elem_offset)

在堆栈上分配 NDArray(DLTensor),返回句柄

参数:
  • data (Expr) – 数组的数据。

  • shape (Expr) – 数组的形状。

  • strides (Expr) – 数组的步幅。

  • ndim (Expr) – 数组的维度。

  • arr_dtype (Expr) – 数组的数据类型。

  • elem_offse (Expr) – 数组的元素偏移量。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.tvm_tuple(*value)

在 AttrStmt 的 value 字段中创建元组结构

参数:

value (Expr) – 元组中的值。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.tvm_struct_get(arr, index, field, dtype)

获取数组中的结构体字段值

参数:
  • dtype (str) – 结果的数据类型。

  • arr (StructType*) – 结构体数组。

  • index (int) – 结构体的索引。

  • field (int) – 结构体的字段。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.tvm_struct_set(arr, index, field, value)

在数组的结构体字段中设置值

参数:
  • arr (StructType*) – 结构体数组。

  • index (int) – 结构体的索引。

  • field (int) – 结构体的字段。

  • value (Expr) – 要在字段中设置的值。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.address_of(buffer_load, span=None)

返回缓冲区中元素的地址

参数:
  • buffer_load (BufferLoad) – 缓冲区加载。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.lookup_param(param_name, span=None)

按名称返回参数

参数:
  • param_name (str) – 参数的名称。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.assume(cond=None)

提供一个可用于简化的真语句

参数:

cond (Expr) – 约束条件。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.undef()

返回一个已初始化但任意的值

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.tvm_thread_allreduce(*freduce_args)

在线程块内执行 allreduce。

参数:

freduce_args (Expr) – 参数。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.type_annotation(dtype)

创建类型注释表达式

参数:

dtype (Expr) – 数据类型。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.tvm_access_ptr(ptype, data, offset, extent, rw_mask)

获取带有内存访问模式信息的头部访问地址

参数:
  • ptype (Expr) – 指针的数据类型。

  • data (DType*) – 指针的数据。

  • offset (int) – 指针的偏移量。

  • extent (int) – 指针的范围。

  • rw_mask (int) – 读写掩码。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.tvm_throw_last_error()

抛出 TVMGetLastError()

返回值:

ret – 返回表达式

返回类型:

PrimExpr

tvm.tir.tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout)

用于张量核心加载算子的 TVM intrinsic

参数:
  • fragment (tir.Var) – wmma 片段。

  • m (UIntImm) – wmma 片段的形状。

  • n (UIntImm) – wmma 片段的形状。

  • k (UIntImm) – wmma 片段的形状。

  • index (Expr) – 片段索引。

  • buffer_ptr (Expr) – 片段缓冲区指针。

  • stride (Expr) – 片段步幅。

  • layout (Literal["row_major", "column_major"]) – 片段布局。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout)

用于张量核心存储算子的 TVM intrinsic

参数:
  • fragment (tir.Var) – wmma 片段。

  • m (UIntImm) – wmma 片段的形状。

  • n (UIntImm) – wmma 片段的形状。

  • k (UIntImm) – wmma 片段的形状。

  • index (Expr) – 片段索引。

  • buffer_ptr (Expr) – 片段缓冲区指针。

  • stride (Expr) – 片段步幅。

  • layout (Literal["row_major", "column_major"]) – 片段布局。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c)

用于张量核心 mma_sync 算子的 TVM intrinsic

参数:
  • fragment_d (tir.Var) – wmma 片段 d。

  • index_d (Expr) – 片段 d 索引。

  • fragment_a (tir.Var) – wmma 片段 a。

  • index_a (Expr) – 片段 a 索引。

  • fragment_b (tir.Var) – wmma 片段 b。

  • index_b (Expr) – 片段 b 索引。

  • fragment_c (tir.Var) – wmma 片段 c。

  • index_c (Expr) – 片段 c 索引。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c)

用于张量核心 bmma_sync 算子的 TVM intrinsic

参数:
  • fragment_d (tir.Var) – bwmma 片段 d。

  • index_d (Expr) – 片段 d 索引。

  • fragment_a (tir.Var) – bwmma 片段 a。

  • index_a (Expr) – 片段 a 索引。

  • fragment_b (tir.Var) – bwmma 片段 b。

  • index_b (Expr) – 片段 b 索引。

  • fragment_c (tir.Var) – bwmma 片段 c。

  • index_c (Expr) – 片段 c 索引。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.tvm_fill_fragment(fragment, m, n, k, index, value)

用于张量核心 fill_fragment 算子的 TVM intrinsic

参数:
  • fragment (tir.Var) – wmma 片段

  • m (UIntImm) – wmma 片段的形状。

  • n (UIntImm) – wmma 片段的形状。

  • k (UIntImm) – wmma 片段的形状。

  • index (Expr) – 片段索引。

  • value (Expr) – 要填充到片段中的值。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.ptx_mma(dtype, shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, multiplicand_a, a_index, multiplicand_b, b_index, accumulator, c_index, saturate, operator=None)

用于 ptx 张量核心 mma 指令的 TVM intrinsic https://docs.nvda.net.cn/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma

参数:
  • dtype (str) – 结果的数据类型。

  • shape (str) – mma 片段的形状。

  • A_layout (Literal["row", "col"]) – 乘数片段 A 的布局。

  • B_layout (Literal["row", "col"]) – 乘数片段 B 的布局。

  • A_dtype (str) – 乘数片段 A 的数据类型。

  • B_dtype (str) – 乘数片段 B 的数据类型。

  • C_dtype (str) – 累加器片段 C 的数据类型。

  • multiplicand_a (tir.Var) – 乘数片段 A 变量。

  • a_index (Expr) – 乘数片段 A 的索引。

  • multiplicand_b (tir.Var) – 乘数片段 B 变量。

  • b_index (Expr) – 乘数片段 B 的索引。

  • accumulator (tir.Var) – 累加器片段 C 变量。

  • c_index (Expr) – 累加器片段 C 的索引。

  • saturate (bool) – 输出端的可选饱和。

  • operator (Optional[Literal["xor", "and"]]) – 1 位运算符。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.ptx_mma_sp(dtype, shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, multiplicand_a, a_index, multiplicand_b, b_index, accumulator, c_index, metadata, meta_index, sparse_selector, saturate)

用于稀疏张量核心 ptx 指令的 TVM intrinsic https://docs.nvda.net.cn/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-sparse-mma

参数:
  • dtype (str) – 结果的数据类型。

  • shape (str) – mma 片段的形状。

  • A_layout (Literal["row", "col"]) – 乘数片段 A 的布局。

  • B_layout (Literal["row", "col"]) – 乘数片段 B 的布局。

  • A_dtype (str) – 乘数片段 A 的数据类型。

  • B_dtype (str) – 乘数片段 B 的数据类型。

  • C_dtype (str) – 累加器片段 C 的数据类型。

  • multiplicand_a (tir.Var) – 乘数片段 A 变量。

  • a_index (Expr) – 乘数片段 A 的索引。

  • multiplicand_b (tir.Var) – 乘数片段 B 变量。

  • b_index (Expr) – 乘数片段 B 的索引。

  • accumulator (tir.Var) – 累加器片段 C 变量。

  • c_index (Expr) – 累加器片段 C 的索引。

  • metadata (Expr) – 操作数的元数据。

  • meta_index (Expr) – 操作数的元数据索引。

  • sparse_selector (Expr) – 稀疏选择器,指示存储元数据的线程。

  • saturate (bool) – 输出端的可选饱和。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride)

用于将 PTX MMA 的结果存储到目标指针的 TVM intrinsic

参数:
  • dtype (str) – 结果的数据类型。

  • m (IntImm) – mma 片段的形状。

  • n (IntImm) – mma 片段的形状。

  • dst_ptr (tir.Var) – 目标指针变量。

  • src_ptr (tir.Var) – 源指针变量。

  • src_offset (Expr) – 源偏移量。

  • dst_stride (tir.Var) – 目标步幅。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.mma_fill(dtype, local_size, local_ptr, offset)

用于将 MMA 累加寄存器归零初始化的 TVM intrinsic

参数:
  • dtype (str) – 结果的数据类型。

  • local_size (IntImm) – 元素的数量。

  • local_ptr (tir.Var) – 目标指针变量。

  • offset (Expr) – 目标偏移量。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset)

用于从共享内存加载 ptx 矩阵的 TVM intrinsic https://docs.nvda.net.cn/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix

参数:
  • dtype (str) – 结果的数据类型。

  • trans (bool) – 矩阵以列优先格式加载。

  • num (IntImm) – 矩阵的数量。

  • type (Literal[".b16"]) – 矩阵的数据类型。

  • local_ptr (tir.Var) – 本地指针变量。

  • local_offset (Expr) – 本地指针的偏移量。

  • smem_ptr (tir.Var) – 共享内存指针变量。

  • smem_offset (Expr) – 共享内存指针的偏移量。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes)

用于使用 cp.async 从全局内存异步复制到共享内存的 TVM intrinsic https://docs.nvda.net.cn/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async

参数:
  • dtype (str) – 结果的数据类型。

  • shared_ptr (tir.Var) – 共享内存指针变量。

  • shared_offset (Expr) – 共享内存指针的偏移量。

  • global_ptr (tir.Var) – 全局内存指针变量。

  • global_offset (Expr) – 全局内存指针的偏移量。

  • bytes (int) – 要复制的数据大小(字节)。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_id)

用于使用 cp.async.bulk 从全局内存异步复制到共享内存的 TVM intrinsic https://docs.nvda.net.cn/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk

参数:
  • dtype (str) – 结果的数据类型。

  • shared_ptr (tir.Var) – 共享内存指针变量。

  • shared_offset (Expr) – 共享内存指针的偏移量。

  • global_ptr (tir.Var) – 全局内存指针变量。

  • global_offset (Expr) – 全局内存指针的偏移量。

  • bytes (int) – 要复制的数据大小(字节)。

  • barrier_id (int) – 屏障共享内存指针的 ID。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.ptx_commit_group()

用于 ptx 异步复制提交的 TVM intrinsic https://docs.nvda.net.cn/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-commit-group

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.ptx_wait_group(num)

用于 ptx 异步复制等待的 TVM intrinsic https://docs.nvda.net.cn/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-wait-group

参数:

num (int) – 要等待的最近未提交的挂起 cp.async 组的数量。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.ptx_cp_async_barrier(barrier_id)

用于使用 cp.async.mbarrier.arrive 的 ptx 异步复制屏障的 TVM intrinsic https://docs.nvda.net.cn/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive

参数:

barrier_id (int) – 屏障共享内存指针的 ID。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.ptx_init_barrier_thread_count(barrier_id, thread_count)

用于使用 mbarrier.init 初始化线程计数的 ptx 屏障的 TVM intrinsic https://docs.nvda.net.cn/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init

参数:
  • barrier_id (int) – 屏障共享内存指针的 ID。

  • thread_count (int) – 预计到达屏障的线程数。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.ptx_arrive_barrier(barrier_id)

用于使用 mbarrier.arrive 的 ptx 屏障到达的 TVM intrinsic https://docs.nvda.net.cn/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive

参数:

barrier_id (int) – 屏障共享内存指针的 ID。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.ptx_arrive_barrier_expect_tx(barrier_id, byte_count)

用于使用 mbarrier.arrive.expect_tx 的带有 expect tx 的 ptx 屏障到达的 TVM intrinsic https://docs.nvda.net.cn/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive https://docs.nvda.net.cn/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-expect-tx-operation

参数:
  • barrier_id (int) – 屏障共享内存指针的 ID。

  • byte_count (int) – 增加 mbarrier 对象的 tx 计数,以跟踪其他异步事务的完成情况。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.ptx_wait_barrier(barrier_id)

用于使用 mbarrier.try_wait 的 ptx 屏障等待的 TVM intrinsic https://docs.nvda.net.cn/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait

参数:

barrier_id (int) – 屏障共享内存指针的 ID。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.create_barriers(barrier_count)

用于创建 N 个屏障的 TVM intrinsic

参数:

barrier_count (int) – 要创建的屏障数量。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.make_filled_simdgroup_matrix(d: Var, index: PrimExpr, value: PrimExpr, col: int = 8, row: int = 8)

创建一个填充的 SIMDGroup 矩阵

参数:
  • d (var) – simdgroup 变量

  • index (PrimExpr) – 矩阵的索引。

  • value (PrimExpr) – 要填充的值。

  • col (int) – 列数。

  • row (int) – 行数。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.simdgroup_load(d: Var, index: PrimExpr, ptr: PrimExpr, stride: PrimExpr, col: int = 8, row: int = 8, transpose_matrix: bool = False)

从设备内存或线程组内存加载数据到 simdgroup

参数:
  • d (var) – simdgroup 变量

  • index (PrimExpr) – 矩阵的索引。

  • ptr (PrimExpr) – 指针。

  • stride (PrimExpr) – 步幅。

  • col (int) – 列数。

  • row (int) – 行数。

  • transpose_matrix (bool) – 是否转置矩阵。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.simdgroup_multiply_accumulate(d: Var, index_d: PrimExpr, a: Var, index_a: PrimExpr, b: Var, index_b: PrimExpr, c: Var, index_c: PrimExpr)

在 simdgroup 中乘法和累加两个矩阵,即 d = a * b + c

参数:
  • d (tir.Var) – 目标矩阵。

  • index_d (PrimExpr) – 目标矩阵的索引。

  • a (tir.Var) – 第一个矩阵。

  • index_a (PrimExpr) – 第一个矩阵的索引。

  • b (tir.Var) – 第二个矩阵。

  • index_b (PrimExpr) – 第二个矩阵的索引。

  • c (tir.Var) – 第三个矩阵。

  • index_c (PrimExpr) – 第三个矩阵的索引。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.simdgroup_store(d: PrimExpr, index: PrimExpr, ptr: PrimExpr, stride: PrimExpr, col: int = 8, row: int = 8, transpose_matrix: bool = False)

将数据从 simdgroup 存储到设备内存或线程组内存

参数:
transpose_matrixbool

是否转置矩阵。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.vectorlow(dtype, vec)

获取向量的低半部分

参数:
  • dtype (str) – 结果的数据类型。

  • vec (list) – 输入向量。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.vectorhigh(dtype, vec)

获取向量的高半部分

参数:
  • dtype (str) – 结果的数据类型。

  • vec (list) – 输入向量。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.vectorcombine(dtype, vec1, vec2)

连接两个向量

参数:
  • vec1 (list) – 输入向量。

  • vec2 (list) – 输入向量。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.infinity(dtype: str, span: Span | None = None) Any

dtype 类型的无穷大值

参数:
  • dtype (str) – 数据类型。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

value – dtype 类型的无穷大值。

返回类型:

tvm.Expr

tvm.tir.reinterpret(dtype, value, span: Span | None = None) Any

dtype 类型的无穷大值

参数:
  • dtype (str) – 数据类型。

  • value (PrimExpr) – 输入值。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

value – dtype 类型的重新解释转换值。

返回类型:

tvm.Expr

tvm.tir.exp(x)

计算输入 x 的指数。

参数:

x (PrimExpr) – 输入参数。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.exp2(x)

计算 2 的 x 次方 (2**x)

参数:

x (PrimExpr) – 输入参数。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.exp10(x)

计算 10 的 x 次方 (10**x)

参数:

x (PrimExpr) – 输入参数。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.log(x)

计算输入 x 的自然对数。

参数:

x (PrimExpr) – 输入参数。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.log2(x)

计算输入 x 的以 2 为底的对数。

参数:

x (PrimExpr) – 输入参数。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.log10(x)

计算输入 x 的以 10 为底的对数。

参数:

x (PrimExpr) – 输入参数。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.log1p(x)

计算输入 x 的 log(x + 1)。

参数:

x (PrimExpr) – 输入参数。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.ldexp(x1, x2)

返回 x1 * (2 ** x2)。

参数:
返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.clz(x)

计算整数 x 的前导零位数。

参数:

x (PrimExpr) – 32 位或 64 位输入整数。如果输入为 0,则结果未定义。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.sin(x)

计算输入 x 的正弦值。

参数:

x (PrimExpr) – 输入参数。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.sinh(x)

计算输入 x 的双曲正弦值。

参数:

x (PrimExpr) – 输入参数。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.asin(x)

计算输入 x 的反正弦值。

参数:

x (PrimExpr) – 输入参数。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.asinh(x)

计算输入 x 的反双曲正弦值。

参数:

x (PrimExpr) – 输入参数。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.cos(x)

计算输入 x 的余弦值。

参数:

x (PrimExpr) – 输入参数。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.cosh(x)

计算输入 x 的双曲余弦值。

参数:

x (PrimExpr) – 输入参数。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.acos(x)

计算输入 x 的反余弦值。

参数:

x (PrimExpr) – 输入参数。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.acosh(x)

计算输入 x 的反余弦值。

参数:

x (PrimExpr) – 输入参数。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.tan(x)

计算输入 x 的正切值。

参数:

x (PrimExpr) – 输入参数。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.tanh(x)

计算输入 x 的双曲正切值。

参数:

x (PrimExpr) – 输入参数。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.atan(x)

计算输入 x 的反正切值。

参数:

x (PrimExpr) – 输入参数。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.atan2(x1, x2)

计算 arctan2(x1, x2)。

参数:
返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.atanh(x)

计算输入 x 的反双曲正切值。

参数:

x (PrimExpr) – 输入参数。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.bitwise_and(x, y, span=None)

计算两个值的按位与

参数:
  • x (PrimExpr) – 左操作数

  • y (PrimExpr) – 右操作数

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

res – 结果。

返回类型:

PrimExpr

tvm.tir.bitwise_not(x, span=None)

计算输入值的按位非

参数:
  • x (PrimExpr) – 输入操作数

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

res – 结果。

返回类型:

PrimExpr

tvm.tir.bitwise_or(x, y, span=None)

计算两个值的按位或

参数:
  • x (PrimExpr) – 左操作数

  • y (PrimExpr) – 右操作数

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

res – 结果。

返回类型:

PrimExpr

tvm.tir.bitwise_xor(x, y, span=None)

计算两个值的按位异或

参数:
  • x (PrimExpr) – 左操作数

  • y (PrimExpr) – 右操作数

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

res – 结果。

返回类型:

PrimExpr

tvm.tir.erf(x)

计算输入 x 的高斯误差函数。

参数:

x (PrimExpr) – 输入参数。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.sigmoid(x)

快速获取 sigmoid 函数

参数:

x (PrimExpr) – 输入参数。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.sqrt(x)

计算输入 x 的平方根。

参数:

x (PrimExpr) – 输入参数。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.rsqrt(x)

计算输入 x 的平方根的倒数。

参数:

x (PrimExpr) – 输入参数。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.floor(x: PrimExprWithOp, span=None)

计算浮点数输入 x 的向下取整值。

参数:
  • x (PrimExpr) – 输入参数。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.ceil(x, span=None)

计算浮点数输入 x 的向上取整值。

参数:
  • x (PrimExpr) – 输入参数。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.hypot(x1, x2)

等价于 sqrt(x1**2 + x2**2),逐元素计算。

参数:
返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.trunc(x, span=None)

获取输入的截断值。

标量 x 的截断值是最接近零的整数 i,它比 x 更接近零。

参数:
  • x (PrimExpr) – 输入参数。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.abs(x, span=None)

逐元素获取输入的绝对值。

参数:
  • x (PrimExpr) – 输入参数。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.round(x, span=None)

将数组元素四舍五入到最接近的整数。

参数:
  • x (PrimExpr) – 输入参数。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.nextafter(x1, x2)

返回 x1 之后朝向 x2 的下一个浮点值。

参数:
返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.nearbyint(x, span=None)

将数组元素四舍五入到最接近的整数。此内部函数使用 llvm.nearbyint 而不是 llvm.round,前者更快,但结果与 te.round 不同。值得注意的是,nearbyint 根据舍入模式进行舍入,而 te.round (llvm.round) 忽略该模式。有关两者之间差异的更多信息,请参阅: https://cppreference.cn/w/cpp/numeric/math/round https://cppreference.cn/w/cpp/numeric/math/nearbyint

参数:
  • x (PrimExpr) – 输入参数。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.power(x, y, span=None)

x 的 y 次方

参数:
  • x (PrimExpr) – 输入参数。

  • y (PrimExpr) – 指数

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

z – 结果。

返回类型:

PrimExpr

tvm.tir.pow(x, y, span=None)

x 的 y 次方

参数:
  • x (PrimExpr) – 输入参数。

  • y (PrimExpr) – 指数

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

z – 结果。

返回类型:

PrimExpr

tvm.tir.popcount(x)

计算输入 x 中设置的位数。

参数:

x (PrimExpr) – 输入参数。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.fmod(x, y)

返回 x 除以 y 的余数,其符号与 x 相同。

参数:
返回值:

z – 结果。

返回类型:

PrimExpr

tvm.tir.if_then_else(cond, t, f, span=None)

条件选择表达式。

参数:
  • cond (PrimExpr) – 条件

  • t (PrimExpr) – 如果 cond 为真,则为结果表达式。

  • f (PrimExpr) – 如果 cond 为假,则为结果表达式。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

result – 条件表达式的结果。

返回类型:

节点

注意

与 Select 不同,if_then_else 不会执行不满足条件的分支。您可以使用它来防止越界访问。与 Select 不同,如果向量中的某些通道具有不同的条件,则 if_then_else 无法向量化。

tvm.tir.likely(cond, span=None)

将条件标记为 likely (很可能)。

参数:
  • cond (PrimExpr) – 输入参数。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

y – 标记的表达式。

返回类型:

PrimExpr

tvm.tir.isnan(x, span=None)

检查输入值是否为 Nan。

参数:
  • x (PrimExpr) – 输入参数。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.isnullptr(x, span=None)

检查输入值是否为 nullptr。

参数:
  • x (PrimExpr) – 输入参数。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.isfinite(x, span=None)

检查输入值是否为有限值。

参数:
  • x (PrimExpr) – 输入参数。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.isinf(x, span=None)

检查输入值是否为无穷大值。

参数:
  • x (PrimExpr) – 输入参数。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.copysign(x1, x2)

逐元素地将 x1 的符号更改为 x2 的符号。

参数:
返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.div(a, b, span=None)

按照 C/C++ 语义计算 a / b。

参数:
  • a (PrimExpr) – 左侧操作数,已知为非负数。

  • b (PrimExpr) – 右侧操作数,已知为非负数。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

res – 结果表达式。

返回类型:

PrimExpr

注意

当操作数是整数时,返回 truncdiv(a, b, span)。

tvm.tir.indexdiv(a, b, span=None)

计算 floor(a / b),其中 a 和 b 是非负数。

参数:
  • a (PrimExpr) – 左侧操作数,已知为非负数。

  • b (PrimExpr) – 右侧操作数,已知为非负数。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

res – 结果表达式。

返回类型:

PrimExpr

注意

使用此函数来分割非负索引。此函数可能会利用操作数的非负性。

tvm.tir.indexmod(a, b, span=None)

计算 indexdiv 的余数。a 和 b 是非负数。

参数:
  • a (PrimExpr) – 左侧操作数,已知为非负数。

  • b (PrimExpr) – 右侧操作数,已知为非负数。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

res – 结果表达式。

返回类型:

PrimExpr

注意

使用此函数来分割非负索引。此函数可能会利用操作数的非负性。

tvm.tir.truncdiv(a, b, span=None)

计算两个表达式的 truncdiv。

参数:
  • a (PrimExpr) – 左侧操作数

  • b (PrimExpr) – 右侧操作数

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

res – 结果表达式。

返回类型:

PrimExpr

注意

这是 C 语言中的默认整数除法行为。

tvm.tir.truncmod(a, b, span=None)

计算两个表达式的 truncmod。

参数:
  • a (PrimExpr) – 左侧操作数

  • b (PrimExpr) – 右侧操作数

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

res – 结果表达式。

返回类型:

PrimExpr

注意

这是 C 语言中的默认整数除法行为。

tvm.tir.floordiv(a, b, span=None)

计算两个表达式的 floordiv。

参数:
  • a (PrimExpr) – 左侧操作数

  • b (PrimExpr) – 右侧操作数

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

res – 结果表达式。

返回类型:

PrimExpr

tvm.tir.floormod(a, b, span=None)

计算两个表达式的 floormod。

参数:
  • a (PrimExpr) – 左侧操作数

  • b (PrimExpr) – 右侧操作数

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

res – 结果表达式。

返回类型:

PrimExpr

tvm.tir.ceildiv(lhs, rhs, span=None)

通用 ceildiv 运算符。

参数:
  • lhs (object) – 左侧操作数。

  • rhs (object) – 右侧操作数。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

op – ceildiv 运算的结果 Expr。

返回类型:

tvm.Expr

tvm.tir.comm_reducer(fcombine, fidentity, name='reduce')

为规约创建可交换的 reducer。

参数:
  • fcombine (function(Expr -> Expr -> Expr)) – 一个二元函数,接受两个 Expr 作为输入并返回一个 Expr。

  • fidentity (function(str -> Expr)) – 一个函数,接受一个类型字符串作为输入并返回一个 const Expr。

返回值:

reducer – 一个函数,用于在轴上创建规约表达式。有两种使用方法:

  1. 接受 (expr, axis, where) 以在指定的轴上生成 Reduce Expr;

  2. 简单地将其与多个 Expr 一起使用。

返回类型:

函数

示例

n = te.var("n")
m = te.var("m")
mysum = te.comm_reducer(lambda x, y: x+y,
    lambda t: tvm.tir.const(0, dtype=t), name="mysum")
A = te.placeholder((n, m), name="A")
k = te.reduce_axis((0, m), name="k")
B = te.compute((n,), lambda i: mysum(A[i, k], axis=k), name="B")
tvm.tir.min(expr, axis, where=None, init=None, *args)

在轴上创建最小值表达式。

参数:
  • expr (PrimExpr) – 源表达式。

  • axis (IterVar) – 规约 IterVar 轴

  • where (optional, Expr) – 规约的过滤谓词。

返回值:

value – 结果值。

返回类型:

PrimExpr

示例

m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")

# there are two way to use this min reducer:
# mode 1, accept (expr, axis, where) to produce an Reduce Expr
# tvm.min represents tvm.te.min or tvm.tir.min.
B = te.compute((m,), lambda i: tvm.min(A[i, k], axis=k), name="B")

# mode 2, simply use it with multiple Exprs:
min_res = tvm.min(m, n)
tvm.tir.max(expr, axis, where=None, init=None, *args)

在轴上创建最大值表达式。

参数:
  • expr (PrimExpr) – 源表达式。

  • axis (IterVar) – 规约 IterVar 轴

  • where (optional, Expr) – 规约的过滤谓词。

返回值:

value – 结果值。

返回类型:

PrimExpr

示例

m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")

# there are two way to use this max reducer:
# mode 1, accept (expr, axis, where) to produce an Reduce Expr
# tvm.max represents tvm.te.max or tvm.tir.max.
B = te.compute((m,), lambda i: tvm.max(A[i, k], axis=k), name="B")

# mode 2, simply use it with multiple Exprs:
max_res = tvm.max(m, n)
tvm.tir.sum(expr, axis, where=None, init=None, *args)

在轴上创建总和表达式。

参数:
  • expr (PrimExpr) – 源表达式。

  • axis (IterVar) – 规约 IterVar 轴

  • where (optional, Expr) – 规约的过滤谓词。

返回值:

value – 结果值。

返回类型:

PrimExpr

示例

m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")

# there are two way to use this sum reducer:
# mode 1, accept (expr, axis, where) to produce an Reduce Expr
# tvm.sum represents tvm.te.sum or tvm.tir.sum.
B = te.compute((m,), lambda i: tvm.sum(A[i, k], axis=k), name="B")

# mode 2, simply use it with multiple Exprs:
sum_res = tvm.sum(m, n)
tvm.tir.q_multiply_shift(x, y, q, s)

执行两个 Q-number x 和 y 之间的乘法,然后进行右移 s。数学表达式为

out = round(x*y*2^-s)

更多关于 Q-number 的信息请参考这里:https://en.wikipedia.org/wiki/Q_(number_format) 舍入规则是四舍五入到最接近的值,半值向上舍入(即,round(x.1) = x 且 round (x.5) = x+1)

参数:
  • x (PrimExpr) – 第一个 Q-number

  • y (PrimExpr) – 第二个 Q-number

  • q (PrimExpr) – x 和 y 中的小数位数。需要 > 0

  • s (PrimExpr) – 整数移位

返回值:

y – 结果。

返回类型:

PrimExpr

tvm.tir.q_multiply_shift_per_axis(x: PrimExpr, y: PrimExpr, ls: PrimExpr, rs: PrimExpr, q: IntImm, is_lshift_required: IntImm, is_rshift_required: IntImm)

执行两个 Q-number x 和 y 之间的乘法

参数:
  • x (PrimExpr) – 第一个 Q-number。

  • y (PrimExpr) – 第二个 Q-number。

  • ls (PrimExpr) – 整数左移。

  • rs (PrimExpr) – 整数右移。

  • q (IntImm) – x 和 y 中的小数位数。需要 > 0。

  • is_lshift_required (IntImm) – 是否需要进行左移。

  • is_rshift_required (IntImm) – 是否需要进行右移。

返回值:

z – 结果。

返回类型:

PrimExpr

tvm.tir.shift_left(x, y, span=None)

返回 x 左移 y 位的结果。

参数:
返回值:

z – 结果。

返回类型:

PrimExpr

tvm.tir.shift_right(x, y, span=None)

返回 x 右移 y 位的结果。

参数:
返回值:

z – 结果。

返回类型:

PrimExpr

tvm.tir.TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint)

用于分配临时工作空间的后端函数

参数:
  • device_type (int) – 将分配空间的设备类型。

  • device_id (int) – 将分配空间的设备 ID。

  • nbytes (int) – 请求的空间大小。

  • dtype_code_hint (int) – 数组元素的类型代码。仅在某些后端(如 OpenGL)中使用。

  • dtype_bits_hint (int) – 数组元素的类型位数。仅在某些后端(如 OpenGL)中使用。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.TVMBackendFreeWorkspace(device_type, device_id, ptr)

用于释放临时工作空间的后端函数。

参数:
  • device_type (int) – 将分配空间的设备类型。

  • device_id (int) – 将分配空间的设备 ID。

  • ptr (tir.Var) – 分配空间指针的结果。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.start_profile_intrinsic(id)

启动 profile intrinsic。 :param id: intrinsic ID。 :type id: int

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.end_profile_intrinsic(id)

结束 profile intrinsic。 :param id: intrinsic ID。 :type id: int

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.vscale()

获取目标的 vscale 值。它将被降低为 llvm.vscale intrinsic (https://llvm.net.cn/docs/LangRef.html#llvm-vscale-intrinsic) :returns: **call** – 调用 vscale intrinsic 的 tir.Call :rtype: PrimExpr

tvm.tir.get_active_lane_mask(dtype, base, limit)

根据上限 (limit) 和当前值 (base) 计算谓词掩码。

它将被降低为 llvm.get.active.lane.mask intrinsic。 (https://llvm.net.cn/docs/LangRef.html#llvm-get-active-lane-mask-intrinsics)

参数:
  • dtype (str) – 结果的数据类型。

  • base (PrimExpr) – 表示 base 的表达式。

  • limit (PrimExpr) – 表示 limit 的表达式。

tvm.tir.get_vscale_expr(dtype: str | DataType, min_size: int = 128) PrimExpr

创建一个数据类型相关的可伸缩表达式。

参数:
  • dtype (Union[str, tvm.DataType]) – 元素数据类型。

  • min_size (int) – 可伸缩向量的最小大小(以位为单位)。

tvm.tir.dp4a(vec1, vec2, acc=0)

两个 int8x4 向量的点积,并添加一个可选的累加器

参数:
  • vec1 (int8x4) – 输入向量。

  • vec2 (int8x4) – 输入向量。

  • acc (int32) – 累加器。

返回值:

call – 调用表达式。

返回类型:

PrimExpr

tvm.tir.ignore_loop_partition(predicate) PrimExpr

注解一个谓词,使其不被视为循环分区的目标条件。

参数:

predicate (PrimExpr) – 注解的谓词表达式。

tvm.tir.add(lhs, rhs, span=None)

通用加法运算符。

参数:
  • lhs (object) – 左侧操作数。

  • rhs (object) – 右侧操作数。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

op – 加法运算的结果 Expr。

返回类型:

tvm.Expr

tvm.tir.subtract(lhs, rhs, span=None)

通用减法运算符。

参数:
  • lhs (object) – 左侧操作数。

  • rhs (object) – 右侧操作数。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

op – 减法运算的结果 Expr。

返回类型:

tvm.Expr

tvm.tir.multiply(lhs, rhs, span=None)

通用乘法运算符。

参数:
  • lhs (object) – 左侧操作数。

  • rhs (object) – 右侧操作数。

  • span (Optional[Span]) – 此运算符在源代码中的位置。

返回值:

op – 乘法运算的结果 Expr。

返回类型:

tvm.Expr

class tvm.tir.BlockDependenceInfo(mod: IRModule | PrimFunc)

一个帮助使用两个核心对象 BlockScope 和 StmtSRef 构建和查询块级别依赖关系的对象

公开的数据结构有:1) sref2scope:从 sref 到其对应的 BlockScope 的映射 2) stmt2ref:从块到对应的 StmtSRef 的映射

请注意,此对象不存储循环的 SRef,因为目的只是为了公开块级别的依赖关系。这提供了一个优势,即给定块 sref 的作用域块(父块)可以直接作为 sref->parent 访问

get_sref(block: Block) StmtSRef | None

返回指向该块的对应 sref

参数:

stmt (Block) – 要检索 sref 的块

返回值:

sref – 对应的 sref

返回类型:

StmtSRef

get_block_scope(block_sref: StmtSRef) BlockScope

获取与块 sref 对应的 BlockScope

参数:

block_sref (StmtSRef) – 要检索的块 sref

返回值:

scope – 对应的 BlockScope

返回类型:

StmtSRef

tvm.tir.build(mod: PrimFunc | IRModule, target: str | Target | None = None, pipeline: None | str | Pass = 'default')

构建具有签名的函数,为与目标信息耦合的设备生成代码。

参数:
返回值:

一个结合了主机和设备代码的模块。

返回类型:

tvm.runtime.Module

tvm.tir.get_tir_pipeline(name: str = 'default', **kwargs) Pass

按名称获取预构建的 pipeline

参数:

name (Optional[str]) – pipeline 的名称

tvm.tir.get_default_tir_pipeline(target: Target) Pass

获取给定目标的默认 TIR pipeline。