通过 DLPack 构建跨框架深度学习编译器
Tensorflow、PyTorch 和 ApacheMxNet 等深度学习框架为快速原型设计和部署深度学习模型提供了强大的工具箱。不幸的是,它们的易用性通常以碎片化为代价:仅在隔离情况下使用每个框架才容易。垂直整合使常见用例的开发变得精简,但偏离常规路径可能会很棘手。
一个支持不佳的场景是在内存中直接从一个框架向另一个框架传递张量,而无需任何数据重复或复制。支持这种用例将使使用者能够串联管线,其中某些算子在一个框架中比另一个框架中得到更好的支持(或更快),从而提高效率。框架之间共享的数据表示也将弥合这一差距,并允许编译器堆栈在为算子生成代码时以单一格式为目标。
DLPack 是张量数据结构的中间内存表示标准。借助 DLPack 作为通用表示,我们可以在传统上只能依赖供应商提供的库的框架编写的脚本中利用 TVM。TVM packed 函数可以对 DLPack 张量进行操作,提供桥接来自 PyTorch 和 MxNet 等框架的张量数据结构的包装器,实现零数据复制。
DLPack 呈现了一种简单、可移植的内存数据结构
/*!
* \brief Plain C Tensor object, does not manage memory.
*/
typedef struct {
/*!
* \brief The opaque data pointer points to the allocated data.
* This will be CUDA device pointer or cl_mem handle in OpenCL.
* This pointer is always aligns to 256 bytes as in CUDA.
*/
void* data;
/*! \brief The device context of the tensor */
DLContext ctx;
/*! \brief Number of dimensions */
int ndim;
/*! \brief The data type of the pointer*/
DLDataType dtype;
/*! \brief The shape of the tensor */
int64_t* shape;
/*!
* \brief strides of the tensor,
* can be NULL, indicating tensor is compact.
*/
int64_t* strides;
/*! \brief The offset in bytes to the beginning pointer to data */
uint64_t byte_offset;
} DLTensor;
作为一个示例,我们在 TVM 中声明和编译一个矩阵乘法算子,并构建一个使用 DLPack 表示的包装器,以允许此算子支持 PyTorch 张量。我们还使用 MxNet 重复此演示。此扩展使机器学习开发人员能够快速将研究代码移植到相对不受支持的硬件平台,而不会牺牲性能。
DLPack 如何提供在框架和 TVM 之间共享的中间包装器的图示
图 1
首先,我们在 PyTorch 中计算参考输出
import torch
x = torch.rand(56,56)
y = torch.rand(56,56)
z = x.mm(y)
然后,我们定义并构建一个 TVM 矩阵乘法算子,使用默认调度
n = tvm.convert(56)
X = tvm.placeholder((n,n), name='X')
Y = tvm.placeholder((n,n), name='Y')
k = tvm.reduce_axis((0, n), name='k')
Z = tvm.compute((n,n), lambda i,j : tvm.sum(X[i,k]*Y[k,j], axis=k))
s = tvm.create_schedule(Z.op)
fmm = tvm.build(s, [X, Y, Z], target_host='llvm', name='fmm')
为了简洁起见,我们不介绍 TVM 大量的调度原语集合,我们可以使用这些原语来优化矩阵乘法。如果您希望在您的硬件设备上快速运行自定义 GEMM 算子,可以在此处找到详细教程。
然后,我们将 TVM 函数转换为支持 PyTorch 张量的函数
from tvm.contrib.dlpack import to_pytorch_func
# fmm is the previously built TVM function (Python function)
# fmm is the wrapped TVM function (Python function)
fmm_pytorch = to_pytorch_func(fmm)
z2 = torch.empty(56,56)
fmm_pytorch(x, y, z2)
np.testing.assert_allclose(z.numpy(), z2.numpy())
并验证结果是否匹配。
我们可以重复相同的示例,但这次使用 MxNet
import mxnet
from tvm.contrib.mxnet import to_mxnet_func
ctx = mxnet.cpu(0)
x = mxnet.nd.uniform(shape=(56,56), ctx=ctx)
y = mxnet.nd.uniform(shape=(56,56), ctx=ctx)
z = mxnet.nd.empty(shape=(56,56), ctx=ctx)
f = tvm.build(s, [X, Y, Z], target_host='llvm', name='f')
f_mxnet = to_mxnet_func(f)
f_mxnet(x, y, z)
np.testing.assert_allclose(z.asnumpy(), x.asnumpy().dot(y.asnumpy()))
PyTorch 示例的幕后原理
由于 TVM 提供了函数来将 dlpack 张量转换为 tvm NDArray
,反之亦然,因此所需要的只是通过包装函数来实现一些语法糖。convert_func
是一个用于使用具有 dlpack 支持的张量的框架的通用转换器,可用于实现方便的转换器,例如 to_pytorch_func
。
def convert_func(tvm_func, tensor_type, to_dlpack_func):
assert callable(tvm_func)
def _wrapper(*args):
args = tuple(ndarray.from_dlpack(to_dlpack_func(arg))\
if isinstance(arg, tensor_type) else arg for arg in args)
return tvm_func(*args)
return _wrapper
def to_pytorch_func(tvm_func):
import torch
import torch.utils.dlpack
return convert_func(tvm_func, torch.Tensor, torch.utils.dlpack.to_dlpack)