tvm.transform
跨 IR 变体的通用 Pass 基础设施。
- tvm.transform.ApplyPassToFunction(transform: Pass, func_name_regex: str, error_if_no_function_matches_regex: bool = False) Pass
用于将 Pass 应用于 IRModule 中特定函数的实用工具
TVM 在降低过程的所有阶段都使用 IRModule 到 IRModule 的转换。当手动编写优化模型,或对 IRModule 中的特定内核执行优化时,这些转换可能很有用。此实用工具允许将 Pass 应用于指定函数,而不会更改模块中的其他函数。
- 参数:
- 返回:
new_transform – 修改后的 IRModule 到 IRModule 的 Pass。
- 返回类型:
- class tvm.transform.ModulePass
在 tvm.IRModule 上工作的 Pass。用户无需直接与此类交互。相反,模块 Pass 应该通过 module_pass 创建,因为 module_pass API 的设计足够灵活,可以处理以不同方式创建模块 Pass。此外,模块 Pass 的所有成员都可以从基类访问。同样的规则也适用于 FunctionPass。
- class tvm.transform.Pass
所有 Pass 的基类。这里的所有方法都只是后端实现的简单包装器。它们被定义为方便用户与基类交互。
- property info
获取 Pass 元数据。
- class tvm.transform.PassContext(opt_level=2, required_pass=None, disabled_pass=None, instruments=None, config=None, trace=None, trace_stack=None, make_traceable=None, num_evals=0, tuning_api_database=None)
TVM 优化/分析运行的基础。每个 Pass 上下文都包含许多辅助信息,用于帮助优化 Pass。此类信息包括错误报告器,用于记录优化期间的错误等。
- opt_levelOptional[int]
此 Pass 的优化级别。
- required_passOptional[Union[List[str], Set[str], Tuple[str]]]
某个 Pass 所需的 Pass 列表。
- disabled_passOptional[Union[List[str], Set[str], Tuple[str]]]
禁用的 Pass 列表。
- instrumentsOptional[Sequence[PassInstrument]]
Pass 工具实现的列表。
- configOptional[Dict[str, Object]]
特定 Pass 的其他配置。
- trace: Optional[relax.tuning.Trace]
跟踪模式的初始跟踪。
- trace_stack: Optional[List[relax.tuning_api.Trace]]
跟踪模式的初始跟踪堆栈。
- make_traceable: Optional[List[str]]
要使其可跟踪的 Pass 列表。
- num_evals: int
管道中进行的初始评估次数。
tuning_api_database: Optional[relax.tuning_api.JSONDatabase]
- override_instruments(instruments)
覆盖此 PassContext 中的工具。
如果存在现有工具,则会调用它们的
exit_pass_ctx
回调。然后切换到新工具并调用新的enter_pass_ctx
回调。- instrumentsSequence[PassInstrument]
Pass 工具实现的列表。
- static current()
返回当前的 Pass 上下文。
- push_trace(trace)
将跟踪推入堆栈。
- pop_trace(return_current=True)
从堆栈中弹出最顶层的跟踪。 :returns: Trace :rtype: Optional[relax.tuning.Trace]
- get_trace_stack()
获取当前的跟踪堆栈。
- get_trace_stack_size()
获取当前堆栈的大小。
- get_current_trace()
获取堆栈顶部的跟踪。
- get_tuning_api_database()
获取调优 API 数据库。
- class tvm.transform.PassInfo(opt_level, name, required=None, traceable=False)
该类包含 Pass 所需的元数据。它是运行优化或分析所需信息的容器。当需要更多元数据时,可以通过添加新成员来扩展此类。
- tvm.transform.PrintIR(header='', show_meta_data=False)
一个特殊的跟踪 Pass,用于打印标题和 IR。
- class tvm.transform.Sequential(passes=None, opt_level=0, name='sequential', required=None, traceable=False)
一个在 Pass 对象序列上工作的 Pass。可以使用此类按顺序执行多个 Pass。
请注意,用户还可以提供一系列他们不想在运行顺序 Pass 时应用的 Pass。Pass 依赖关系也将在后端解析。
- tvm.transform.module_pass(pass_func=None, opt_level=None, name=None, required=None, traceable=False)
装饰模块 Pass。
当提供 pass_func 时,此函数返回一个回调。否则,它充当装饰器函数。
pass_func 也可以是具有 transform_module 方法的类类型。此函数将使用 transform_module 作为 Pass 函数创建一个装饰的 ModulePass。
- 参数:
pass_func (Optional[Callable[(Module, PassContext) ->Module]]) – 转换函数或类。
opt_level (int) – 此模块 Pass 的优化级别。
name (Optional[str]) – 模块 Pass 的名称。名称可以为空。在这种情况下,优化函数的名称将用作 Pass 名称。
required (Optional[List[str]]) – 模块 Pass 依赖的 Pass 列表。
traceable (Boolean) – 指示模块 Pass 是否可跟踪的布尔变量
- 返回:
create_module_pass – 如果未提供 pass_func,则将返回装饰器,否则返回装饰结果。返回的装饰器有两种行为,具体取决于输入:当我们装饰 Pass 函数时,将返回一个新的 ModulePass。当我们装饰类类型时,将返回一个新的 ModulePass 类。
- 返回类型:
Union[Callable, ModulePass]
示例
以下代码块装饰一个模块 Pass 类。
@tvm.ir.transform.module_pass class CustomPipeline: def __init__(self, enable_fold): self.enable_fold = enable_fold self.const_fold = relax.transform.FoldConstant() def transform_module(self, mod, ctx): if self.enable_fold: mod = self.const_fold(mod, ctx) return mod # create an instance of customized pipeline pipeline = CustomPipeline(enable_fold=False) assert isinstance(pipeline, transform.ModulePass) # run the pipeline. output_module = pipeline(input_module)
以下代码通过装饰用户定义的转换函数来创建模块 Pass。
@tvm.ir.transform.module_pass(opt_level=2) def transform(mod, ctx): return relax.transform.FoldConstant(mod) module_pass = transform assert isinstance(module_pass, transform.ModulePass) assert module_pass.info.opt_level == 2 # Given a module m, the optimization could be invoked as the follwoing: updated_mod = module_pass(m) # Now a function abs should be added to the module m.