Pass 基础设施
Relax 和 TVM IR 都包含一系列优化 pass,这些 pass 改进了模型的性能指标,例如平均推理、内存占用或特定设备的功耗。存在一套标准的优化以及机器学习特定的优化,包括常量折叠、死代码消除、运算符布局更改、运算符融合、缓冲区处理和循环转换等。这些 pass 中的每一个都构建为 ir 到 ir 的转换,使用在遍历期间和/或之前收集的分析结果。
然而,随着 TVM 的快速发展,对更系统和高效的方式来管理这些 pass 的需求变得越来越明显。此外,一个通用的框架,用于管理跨 TVM 堆栈不同层(例如 Relax 和 tir)的 pass,为开发人员快速原型化和将实现的 pass 插入系统铺平了道路。
本文档描述了这样一种基础设施的设计,该基础设施利用了生产编译器用于管理优化 pass 的方式以及现代深度学习框架用于构建层的方式。
例如,许多现有的生产编译器,如 GCC 和 LLVM,都采用 pass 管理器来有效地管理 pass 的执行。最初,管理 pass 很简单,因为 pass 的数量很少,但成熟的编译器将包含数百个单独的 pass。通常,外部用户希望正确调度自定义 pass,而无需修改单个手工制作的 pass 顺序。
类似地,现代深度学习框架,如 Pytorch 和 MXNet Gluon,也倾向于通过 Sequential 和 Block 分别启用 pass 风格的层构建方案。借助这些构造,这些现代框架能够方便地将模块/层添加到其容器中,并轻松构建神经网络。
TVM pass 基础设施的设计很大程度上受到了 LLVM 中使用的分层 pass 管理器和流行的深度学习框架中使用的块式容器的启发。pass 基础设施的主要目标包括
实现更好的程序化优化编排。这允许用户灵活地自定义和构建自己的优化 pipeline。
提供一种用户友好的方式来调试优化 pass。
减轻开发人员手动和分别解决 pass 之间依赖关系的负担。
简化开发人员实现新 pass 的过程。例如,我们允许用户在 Python 中实现一个 pass,并让 pass 基础设施操纵其执行。
设计
我们专注于用户的易扩展性,使用户可以快速添加新 pass 而不会损失向后兼容性。该设计包含后端和前端。前者实现了 pass 基础设施的主要逻辑。后者提供了简单的 API 供用户交互,即允许用户快速创建自己的优化 pipeline。
C++ 后端
我们提供了一个 PassInfo
对象来包含 pass 所需的基本信息。name
是 pass 名称,opt_level
指示将在哪个优化级别启用该 pass,required
表示执行某个 pass 所需的 pass(有关更多详细信息,请参阅 include/tvm/ir/transform.h)。例如,在注册 pass 期间(稍后将介绍),pass 开发人员可以指定 pass 的名称、将执行 pass 的优化级别和/或所需的 pass。opt_level
可用于帮助 pass 基础设施识别在用户提供的优化级别下运行时是否需要执行某个 pass。required
字段可供 pass 基础设施用于解决 pass 依赖关系。
class PassInfoNode : public Object {
String name;
int opt_level;
Array<String> required;
};
PassContext
PassContext
携带了优化 pass 的有用信息。例如,它包含错误报告系统,因此优化作者可以提供关于优化失败原因的诊断信息。PassContext
也被设计为替换旧的 BuildConfig
,后者用于帮助用户配置编译选项,包括优化级别和所需/禁用的 pass 等。例如,我们可能有一个配置,它执行 opt_level=3
的所有 pass,并使用 PassContext
提供的 disabled_pass=xx
禁用某些 pass。现在我们可以全局收集 opt_level=3
的所有 pass,并排除禁用 pass 列表中的那些。PassContext
还提供了一种检测所有 pass 的方法。请参阅 Pass 检测 部分。
此类旨在方便用户编写 Python with
语法,以便在特定配置下执行优化。此外,用户可以通过 PassContext::Current()
以线程安全的方式获取特定程序范围内可用的上下文,因为线程局部存储 PassContextThreadLocalStore
用于保存创建的 pass 上下文对象。稍后将提供示例,展示我们如何使用 C++ 和 Python API 创建使用 pass 上下文的编译 pipeline。
class PassContextNode : public Object {
public:
int opt_level{2};
tvm::Array<tvm::Expr> required_pass;
tvm::Array<tvm::Expr> disabled_pass;
mutable Optional<DiagnosticContext> diag_ctx;
Map<String, ObjectRef> config;
Array<instrument::PassInstrument> instruments;
};
class PassContext : public NodeRef {
public:
TVM_DLL static PassContext Create();
TVM_DLL static PassContext Current();
TVM_DLL void InstrumentEnterPassContext();
TVM_DLL void InstrumentExitPassContext();
TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const;
TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const;
/* Other fields are omitted. */
private:
// The entry of a pass context scope.
TVM_DLL void EnterWithScope();
// The exit of a pass context scope.
TVM_DLL void ExitWithScope();
// Classes to get the Python `with` like syntax.
friend class tvm::With<PassContext>;
};
struct PassContextThreadLocalEntry {
/*! \brief The default pass context. */
PassContext default_context;
/*! \brief The current pass context. */
std::stack<PassContext> context_stack;
PassContextThreadLocalEntry() {
default_context = PassContext(make_node<PassContextNode>());
}
};
/*! \brief The thread-local store to hold the pass context. */
typedef dmlc::ThreadLocalStore<PassContextThreadLocalEntry>
PassContextThreadLocalStore;
Pass 构造
pass 基础设施以分层方式设计,它可以工作在 Relax/tir 程序的不同的粒度级别。引入了一个纯虚类 PassNode
作为不同优化 pass 的基类。此类包含几个虚方法,这些方法必须在模块、函数或 pass 序列的级别上由子类实现。
class PassNode : Object {
virtual PassInfo Info() const = 0;
virtual Module operator()(const IRModule& mod
const PassContext& pass_ctx) const = 0;
};
functor 展示了 pass 必须如何实现,即它始终在特定上下文下的 IRModule
上工作。所有 pass 都以 Module
到 Module
的方式设计。因此,由 pass 基础设施管理的优化将始终更新整个模块。
已经创建了几个子类来实现不同类型的优化 pass,例如,函数级 pass、模块级 pass 和顺序 pass。每个子类本身都可以充当 pass 管理器。例如,它们可以收集所需的 pass 并执行它们,或者基于给定的元数据构建依赖关系图。它们的完整定义可以在 src/ir/transform.cc 中找到。
模块级 Pass
模块级 pass 主要用于全局和过程间优化 (IPO),这类似于 LLVM 中使用的模块 pass。Relax 中一些需要模块全局视图的典型 pass,例如 A-normal form 转换和 lambda 提升等,都属于此集合。在此级别,用户甚至可以在模块中添加和/或删除函数。请注意,所有 pass
class ModulePassNode : PassNode {
PassInfo pass_info;
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func;
Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
// Other members/methods are omitted
};
pass_info
维护模块级 pass 所需的信息。pass_func
勾勒出真正的优化。例如,我们可能需要在模块上执行死代码消除。我们可以在 pass_func
中实现该算法,并让它在模块上运行。然后它将删除死代码,包括模块中未使用的函数。请注意,此字段设计为 packed function,这使得可以在 C++ 和 Python 中实现优化。
函数级 Pass
函数级 pass 用于为给定的 Relax/tir 模块实现各种函数内级别的优化。它一次从模块的函数列表中获取一个函数进行优化,并产生一个重写的 Relax Function
或 tir PrimFunc
。大多数 pass 可以归为这一类,例如 Relax 中的公共子表达式消除和推理简化,以及 tir 中的向量化和展平存储等。
请注意,此级别 pass 的范围是 Relax 函数或 tir 原始函数。因此,我们不能通过这些 pass 添加或删除函数,因为它们不知道全局信息。
class FunctionPassNode : PassNode {
PassInfo pass_info;
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func;
Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
bool SkipFunction(const Function& func) const;
// Other members/methods are omitted...
};
pass_info
与我们在模块 pass 中描述的相同。pass_func
接受一个函数进行优化,它也需要一个模块,因为我们可能会使用它来报告错误。函数可以使用 “SkipOptimization” 注释,以便在优化期间忽略它。
顺序 Pass
SequentialPass
类似于 Pytorch nn.Sequential
,它包含一系列用于执行的 pass。
class SequentialPassNode : PassNode {
PassInfo pass_info;
// Passes need to be executed.
Array<Pass> passes;
bool PassEnabled(const PassInfo& info) const;
Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
};
以下代码展示了如何调用顺序 pass 中的各个 pass。本质上,我们使用添加到 pass 列表的顺序,按顺序执行顺序 pass 中的每个 pass。
Module SequentialNode::operator()(const Module& module,
const PassContext& pass_ctx) const {
Module mod = module;
for (const Pass& pass : passes) {
ICHECK(pass.defined()) << "Found undefined pass for optimization.";
const PassInfo& pass_info = pass->Info();
if (!PassEnabled(pass_info)) continue;
for (const auto& it : pass_info->required) {
const auto* name = it.as<tvm::ir::StringImm>();
ICHECK(name);
mod = GetPass(name->value)(mod, pass_ctx);
}
mod = pass(mod, pass_ctx);
}
return mod;
}
在调用 pass 时,我们首先检查是否启用了此 pass。这首先通过检查用户是否显式禁用了该 pass 来完成,然后检查用户是否将其指定为必需的 pass。如果仍然不确定是否启用了此 pass,则将检查其 opt_level
。仅当其优化级别不低于 pass 上下文中的配置优化级别时,此 pass 才会被启用并因此执行。
要执行 pass,我们需要首先使用 pass 名称在 TVM packed function 注册表中检索已注册的 pass。这是可能的,因为每个 pass 都注册了一个 API 端点,我们稍后将展示。
Pass GetPass(const std::string& pass_name) {
using tvm::runtime::Registry;
std::string fpass_name = "relax.transform." + pass_name;
const auto* f = Registry::Get(fpass_name);
ICHECK(f != nullptr) << "Cannot find " << fpass_name
<< "to create the pass " << pass_name;
return (*f)();
}
提供了一些辅助函数来创建每种类型的上述 pass。这些辅助函数也暴露给 Python 前端,供用户有利地使用 Python API 创建特定的 pass 对象。
Pass CreateFunctionPass(
const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
int opt_level,
String name,
Array<String> required);
Pass CreatePrimFuncPass(
const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
int opt_level,
String name,
Array<String> required);
Pass CreateModulePass(
const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
int opt_level,
String name,
Array<String> required);
Pass Sequential(tvm::Array<Pass> passes, PassInfo pass_info);
Pass 注册
我们已经介绍了不同级别 pass 的概念以及用于编译的上下文。看看用户可以多么容易地注册一个 pass 会很有趣。让我们以常量折叠为例。此 pass 已经实现,用于折叠 Relax 函数中的常量(在 src/relax/transforms/fold_constant.cc 中找到)。
提供了一个 API 来执行 Expr
到 Expr
的转换。
Expr FoldConstant(const Expr& expr);
为了将此 pass 注册到 pass 基础设施,我们首先需要确定将在哪个级别执行此 pass。由于常量折叠发生在单个函数上,我们应该直观地通过 CreateFunctionPass
为其创建一个 FunctionPass
。pass_func
作为 packed function 返回,它在 IRModule 中的每个函数上调用 Expr
到 Expr
API。{}
表示此 pass 不需要任何先决条件。否则,pass 开发人员必须识别并列出它们。
同时,一个 pass API 端点以名称 "relax.transform.FoldConstant
注册。因此,此 pass 成为注册表中的一个条目,C++(例如上面的 GetPass
)和 Python 都可以在需要时访问它。
namespace transform {
Pass FoldConstant() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) { return ConstantFolder::Fold(f, m); };
return CreateFunctionPass(pass_func, 0, "FoldConstant", {});
}
TVM_REGISTER_GLOBAL("relax.transform.FoldConstant")
.set_body_typed(FoldConstant);
} // namespace transform
为了允许其他 C++ 模块应用此 pass,我们在 include/tvm/relax/transform.h 中声明了一个自由函数,如下所示
TVM_DLL Pass FoldConstant();
Pass 检测
Pass 检测是一种分析 pass 本身的机制。例如,我们可以使用该基础设施来了解 pass 需要多少时间和内存,或者 pass 如何转换 IR 模块。
我们在 PassContext
的生命周期中引入了四个检测点。
TVM_DLL void InstrumentEnterPassContext();
TVM_DLL void InstrumentExitPassContext();
TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const;
TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const;
InstrumentEnterPassContext
在进入 PassContext
实例的范围时立即调用。
InstrumentExitPassContext
在离开 PassContext
范围时或在 pass 执行期间发生异常时调用。当检测被 tvm.transform.PassContext
中的 override_instruments
覆盖时,也会调用此方法。请参阅 覆盖当前 PassContext 中的检测。
InstrumentBeforePass
在执行之前调用。InstrumentAfterPass
在执行后调用(如果应运行该 pass)。行为类似于
if (pass_ctx.InstrumentBeforePass(ir_module, pass_info)) {
new_ir_module = run_pass(ir_module, pass_ctx);
pass_ctx.InstrumentAfterPass(new_ir_module, pass_info);
return new_ir_module;
}
PassInstrument
接口允许您在上述四种方法中运行任意代码。可以将多个 PassInstrument
实例注册到单个 PassContext
中。PassInstrument
实例按照传递给 PassContext
的 instruments
参数的顺序依次调用。
PassInstrument
提供以下接口
namespace instrument {
class PassInstrumentNode : public Object {
public:
String name;
virtual void EnterPassContext() const = 0;
virtual void ExitPassContext() const = 0;
virtual bool ShouldRun(const IRModule& mod, const transform::PassInfo& info) const = 0;
virtual void RunBeforePass(const IRModule& mod, const transform::PassInfo& info) const = 0;
virtual void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const = 0;
/* Other fields are omitted. */
};
class PassInstrument : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(PassInstrument, ObjectRef, PassInstrumentNode);
};
} // namespace instrument
Python 前端提供了快速实现 PassInstrument
的方法。请参阅 Pass 检测。
在 PassContext
中,PassInstrument
实例的调用顺序如下
with PassContext(instruments=[pi]) # pi = a PassInstrument implementation.
pi.EnterPassContext()
if pi.ShouldRun(Pass1):
pi.RunBeforePass()
Pass1()
pi.RunAfterPass()
if pi.ShouldRun(Pass2):
pi.RunBeforePass()
Pass2()
pi.RunAfterPass()
pi.ExitPassContext()
以下是 PassInstrument
接口和 PassContext
方法之间关系的简要介绍。有关更多详细信息,请参阅 (src/ir/transform.cc)。
InstrumentEnterPassContext
EnterPassContext()
按照传递给PassContext
的instruments
的顺序执行。当引发异常时,
PassContext
通过清除所有已注册的PassInstrument
实例来禁用 pass 检测。然后
PassContext
执行每个成功完成EnterPassContext()
的PassInstrument
实例的ExitPassContext()
方法例如,如果
PassInstrument
A、B 和 C 注册到一个PassContext
,并且 A 完成了EnterPassContext()
,而 B 抛出了异常,则 C 永远不会执行;执行 A 的ExitPassContext()
。
InstrumentExitPassContext
每个
PassInstrument
实例的ExitPassContext()
按照传递给PassContext
的instruments
的顺序执行。当发生异常时,
instruments
将被清除。在抛出异常的实例之后注册的
PassInstrument
实例不执行ExitPassContext
。
InstrumentBeforePass
如果 pass 未列为必需的 pass,则执行
ShouldRun
。如果 pass 未被
ShouldRun
阻止,则RunBeforePass
按照instruments
的顺序执行。请注意,
InstrumentBeforePass
返回一个布尔值,指示是否应运行该 pass。当发生异常时,它会立即抛出。我们依靠 Python 上下文管理器来安全地退出
PassContext
(这意味着将运行每个检测的ExitPassContext
。对于 C++,请参阅 include/tvm/support/with.h。)
InstrumentAfterPass
RunAfterPass
按照传递给PassContext
的instruments
的顺序执行。当发生异常时,它会立即抛出。我们依靠 Python 上下文管理器或
With
类 (include/tvm/support/with.h) 来安全地退出PassContext
内置检测
有几个内置检测。标有 *TODO* 的那些尚未实现。
PassTimingInstrument(请参阅 src/ir/instrument.cc)
分析 pass 的执行时间。
PrintIRBefore(TODO)
在 pass 转换 IR 模块之前打印它。
tvm.transform.PrintIR()
如果我们将其插入到 pass 周围,也可以达到此目的。但是,使用PassInstrument
,我们无需修改 pass 序列。
PrintAfter(TODO)
在 pass 转换 IR 模块之后打印它。
Python 前端
前端侧只需要一些简单的 API。例如,我们可以为用户提供以下 API 来创建和执行 pass(完整实现在 python/tvm/relax/transform/transform.py 和 python/tvm/ir/transform.py 中提供)。后端接收信息并决定应该使用哪个函数来创建 Pass 对象。
PassContext
Python 前端为 PassContext
提供了一个包装器,以通过覆盖 __enter__
和 __exit__
来启用 with
语法。提供了一个 current
静态方法,供用户获取在特定范围内使用的上下文。
@tvm._ffi.register_object("transform.PassContext")
class PassContext(tvm.runtime.Object):
def __enter__(self):
_transform.EnterPassContext(self)
return self
def __exit__(self, ptype, value, trace, config):
_transform.ExitPassContext(self)
@staticmethod
def current():
"""Return the current pass context."""
return _transform.GetCurrentPassContext()
PassContext
用于配置编译选项,包括优化级别和所需/禁用的 pass。它还可以接受配置字典,以便不同的 pass 可以方便地获取传递的数据,例如回退设备信息和循环展开的步长/深度等。为了启用获取所需的配置,必须通过 TVM_REGISTER_PASS_CONFIG_OPTION
注册键。例如,以下内容供循环展开 pass 使用
TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig);
有关更多详细信息,请参阅 src/tir/transforms/unroll_loop.cc。
Pass 检测
可以通过在实现以下方法的类上使用 pass_instrument
装饰器 (python/tvm/ir/instrument.py) 来实现 PassInstrument
。请注意,建议使用 pass_instrument
装饰器来实现 PassInstrument
,而不是覆盖或子类化。
enter_pass_ctx
此方法在进入
PassContext
时运行。
exit_pass_ctx
此方法在退出
PassContext
时运行。
should_run
此方法在 pass 执行之前运行,返回一个布尔值,指示是否应运行该 pass。
run_before_pass
如果应运行 pass,则此方法在 pass 执行之前立即运行。
run_after_pass
此方法在 pass 执行完毕后立即运行。
PassInstrument
实例可以通过 tvm.transform.PassContext
中的 instruments
参数注册。
使用 pass 检测 教程提供了有关如何使用 Python API 实现 PassInstrument
的示例。
覆盖当前 PassContext 中的检测
提供了 override_instruments
方法来覆盖当前 PassContext
的 instruments
。例如,如果 pass 在没有显式创建新的 PassContext
的情况下运行,仍然可以通过以下方式将 PassInstrument
注册到全局 PassContext
中
cur_pass_ctx = tvm.transform.PassContext.current()
# override PassInstrument instances
cur_pass_ctx.override_instruments([pass_inst])
mod = pass_seq(mod)
result = pass_inst.get_result()
请注意,当调用 override_instruments
时,将调用旧的 PassInstrument
实例的 exit_pass_ctx
方法。然后调用新的 PassInstrument
的 enter_pass_ctx
方法。