使用 TVM 优化深度学习 GPU 算子:深度卷积示例


高效的深度学习算子是深度学习系统的核心。通常,这些算子难以优化,需要 HPC 专家付出巨大努力。TVM,一个端到端的张量 IR/DSL 堆栈,使这项工作变得容易得多。

这篇博客教你如何在 TVM 的帮助下编写高性能 GPU 算子内核。我们以深度卷积(即 topi.nn.depthwise_conv2d_nchw)为例,并演示了我们如何超越 tensorflow 中已手工优化的 CUDA 内核。在不同的工作负载下,我们的最终版本比 tf-1.2 中优化的内核快 2 倍-4 倍,并且在启用算子融合的情况下快 3 倍-7 倍。以下是在 GTX1080 上测试的结果,滤波器大小 = [1, 256, 3, 3],步幅 = [1, 1],填充 = ‘SAME’

image

深度卷积简介

深度卷积是现代架构(如 Xception [1] 和 MobileNet [2])的重要组成部分。它是降低深度神经网络计算复杂度的有效方法。

image

来源: http://machinethink.net/blog/googles-mobile-net-architecture-on-iphone/

在 TVM 中,深度卷积可以声明为

# padding stage
PaddedInput = tvm.compute(
    (batch, in_channel, height_after_pad, width_after_pad),
    lambda b, c, i, j: tvm.select(
        tvm.all(i >= pad_top, i - pad_top < in_height, j >= pad_left, j - pad_left < in_width),
        Input[b, c, i - pad_top, j - pad_left], tvm.const(0.0)),
    name="PaddedInput")
# depthconv stage
di = tvm.reduce_axis((0, filter_height), name='di')
dj = tvm.reduce_axis((0, filter_width), name='dj')
Output = tvm.compute(
    (batch, out_channel, out_height, out_width),
    lambda b, c, i, j: tvm.sum(
        PaddedInput[b, c/channel_multiplier, i*stride_h + di, j*stride_w + dj] * Filter[c/channel_multiplier, c%channel_multiplier, di, dj],
        axis=[di, dj]),
    name='DepthwiseConv2d')

通用 GPU 优化指南

这部分简要介绍了优化 CUDA 代码时应该了解的三个概念:数据重用、共享内存和 bank 冲突。如果您已经了解它们,则可以跳过这部分。

数据重用

在现代计算架构中,从内存加载数据的成本远高于进行单次浮点计算 [3]。因此,我们总是希望在输入数据加载到寄存器或共享内存(缓存)后重用它们。

深度卷积中有两种形式的数据重用:滤波器重用和输入重用。滤波器重用发生在滤波器在输入通道上滑动并多次计算时。输入重用通过平铺实现,让我们以 3x3 深度卷积为例

image

在没有平铺的情况下,每个线程计算 1 个输出元素并加载 3x3 输入数据。16 个线程总共有 9x16 次加载。

image

通过平铺,每个线程计算 2x2 输出元素并加载 4x4 输入数据。4 个线程总共有 16x4 次加载。

共享内存和 Bank 冲突

共享内存可以看作是 GPU 中的缓存。它位于芯片上,比全局内存快得多。

image

共享内存是按块分配的。常见的做法是将数据从全局内存加载到共享内存中,然后块中的所有线程都从共享内存中读取数据。

共享内存的大小是有限的(通常为 48K),因此我们必须注意共享内存溢出。此外,分配给一个块的共享内存过多会限制每个多处理器中的活动块数。

共享内存的另一个性能问题是 bank 冲突。共享内存被划分为大小相等的内存模块(banks),可以同时访问,但是,如果多个线程访问同一个内存 bank(导致 bank 冲突),则访问将被串行化,从而降低有效带宽。

共享内存 banks 的组织方式使得连续的地址被分配给连续的 banks。为了避免 bank 冲突,最好是连续的线程访问连续的内存地址,如下所示(每种颜色代表一个共享内存 bank)

image

有关共享内存和 bank 冲突的更多详细信息,请参阅 Nvidia 的这篇博客

好的,现在让我们开始在 TVM 中优化深度卷积。

调度优化

内联计算 PaddedInput 以节省内存分配

从第 1 部分可以看出,填充被显式声明为一个单独的阶段。我们内联计算它以避免冗余的内存分配

s = tvm.create_schedule(Output.op)
s[PaddedInput].compute_inline()

将一个大通道划分为更小的块

深度卷积的一种直接调度方法是,一个 cuda 块负责一个输入通道和相应的滤波器,将它们加载到共享内存中,然后进行计算

IS = s.cache_read(PaddedInput, "shared", [DepthwiseConv2d])
FS = s.cache_read(Filter, "shared", [DepthwiseConv2d])
block_y = tvm.thread_axis("blockIdx.y")
block_x = tvm.thread_axis("blockIdx.x")
# bind the dimension of batch (N in NCHW) with block_y
s[Output].bind(Output.op.axis[0], block_y)
# bind the dimension of channel (C in NCHW) with block_x
s[Output].bind(Output.op.axis[1], block_x)

我们测试了 GTX 1080 上 1000 次运行的平均时间成本,并与 tensorflow 中的 depthwise_conv2d 进行了比较。这是结果

输入 滤波器 步幅 tf-1.2 SAME 填充 (us) TVM SAME 填充 (us)
[1, 256, 21, 21] [256, 1, 3, 3] [1, 1] 16.1 9.1
[1, 256, 32, 32] [256, 1, 3, 3] [1, 1] 34.8 14.5
[1, 256, 64, 64] [256, 1, 3, 3] [1, 1] 130.9 98.9
[1, 256, 96, 96] [256, 1, 3, 3] [1, 1] 251.6 387.4

我们可以看到,这种调度在通道尺寸较小(如 21 x 21 或 32 x 32)时表现良好,但是,当通道尺寸增加到大于 64 x 64 时,其性能会严重下降。一个主要原因是分配给一个块的共享内存过多会限制每个多处理器中的活动块数。

我们修改了调度,将一个大通道划分为更小的块。例如,一个通道(64 x 64 或 96 x 96)被划分为 32 x 32 的块,一个 cuda 块负责一个 32 x 32 的块

blocking_h = 32
blocking_w = 32
# split the dimension of height (H in NCHW)
bx1, _ = s[Output].split(Output.op.axis[2], factor=blocking_h)
# split the dimension of width (W in NCHW)
bx2, _ = s[Output].split(Output.op.axis[3], factor=blocking_w)
# assign one 32 x 32 block to one cuda block
by = s[Output].fuse(Output.op.axis[0], Output.op.axis[1])
s[Output].bind(by, block_y)
bx = s[Output].fuse(bx1, bx2)
s[Output].bind(bx, block_x)

这是新的结果

输入 [blocking_h, blocking_w] tf-1.2 SAME 填充 (us) TVM SAME 填充 (us)
[1, 256, 64, 64] [32, 32] 130.9 63.4
[1, 256, 96, 96] [32, 32] 251.6 132.5

我们的分块策略奏效了!对于 64 x 64 通道尺寸,它带来了 1.6 倍的加速(98.9us -> 63.4us);对于 96 x 96 通道尺寸,它带来了 2.9 倍的加速(387.4us -> 132.5us)。

调整线程数参数

如何调度工作负载,例如,在一个 cuda 块的线程之间调度 32x32?直观上,它应该是这样的

num_thread_y = 8
num_thread_x = 8
thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
ty, yi = s[Output].split(h_dim, nparts=num_thread_y)
tx, xi = s[Output].split(w_dim, nparts=num_thread_x)
s[Output].reorder(ty, tx, yi, xi)
s[Output].bind(ty, thread_y)
s[Output].bind(tx, thread_x)

调度中有两个参数:num_thread_ynum_thread_x。如何确定它们的最佳组合?好吧,让我们先做一些实验。以下是滤波器 = [256, 1, 3, 3] 和步幅 = [1, 1] 的结果

案例 输入 num_thread_y num_thread_x TVM SAME 填充 (us)
1 [1, 256, 32, 32] 8 32 9.7
2 [1, 256, 32, 32] 4 32 8.8
3 [1, 256, 32, 32] 1 32 17.7
4 [1, 256, 32, 32] 32 1 32.5

从以上结果中可以得出许多有趣的观察结果

  • 案例 2 比案例 1 快。在案例 2 中,每个线程计算输出中的 8x1 瓦片,这对应于输入中的 10x3 瓦片。它比案例 1 的 4x1 瓦片具有更好的数据重用。

  • 案例 3 比案例 2 慢。这是因为在案例 3 中,每个线程的工作负载太大,导致本地内存读取成本过高。

  • 案例 4 比案例 3 慢。这是因为 num_thread_x = 32 确保没有 bank 冲突,而 num_thread_y = 32 则不能。

总结我们从以上观察中学到的知识

  • 大瓦片有利于数据重用,但不利于本地内存读取。
  • num_thread_ynum_thread_x 对 bank 冲突的影响是不对称的。
  • 找到 num_thread_ynum_thread_x 的最佳组合是为了在高效的共享内存访问(避免 bank 冲突)、数据重用和本地内存读取之间取得平衡。

非常棘手。那么,我们究竟应该如何找到最佳组合呢?答案是暴力搜索。我们可以将 num_thread_ynum_thread_x 作为参数传递给调度函数,并尝试所有可能的组合以找到最佳组合。这在 TVM 中可以很容易地完成

def schedule_depthwise_conv2d(..., num_thread_y=8, num_thread_x=8):
    num_thread_y = num_thread_y
    num_thread_x = num_thread_x
    do_schedule_as_usual
    return schedule

min_time_cost = inf
for num_thread_y, num_thread_x in all_possible_combinations:
    schedule = schedule_depthwise_conv2d(..., num_thread_y=num_thread_y, num_thread_x=num_thread_x)
    time_cost = test_depthwise_conv2d(..., schedule)
    if time_cost < min_time_cost:
        min_time_cost = time_cost
        optimal_combination = [num_thread_y, num_thread_x]

实际上,它可以被看作是一个简单的自动调度器。

Vthread 和步幅模式

TVM 中引入 Vthread(虚拟线程)是为了支持步幅模式。我们可以这样使用它

num_vthread_y = 2
num_vthread_x = 2
num_thread_y = 8
num_thread_x = 8
thread_vy = tvm.thread_axis((0, num_vthread_y), "vthread", name="vy")
thread_vx = tvm.thread_axis((0, num_vthread_x), "vthread", name="vx")
thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
# split the dimension of height (H in NCHW) twice
tvy, vyi = s[Output].split(h_dim, nparts=num_vthread_y)
ty, yi = s[Output].split(vyi, nparts=num_thread_y)
# split the dimension of width (W in NCHW) twice
tvx, vxi = s[Output].split(w_dim, nparts=num_vthread_x)
tx, xi = s[Output].split(vxi, nparts=num_thread_x)
# bind thread and vthread respectively
s[Output].bind(tvy, thread_vy)
s[Output].bind(tvx, thread_vx)
s[Output].bind(ty, thread_y)
s[Output].bind(tx, thread_x)
s[Output].reorder(tvy, tvx, ty, tx, yi, xi)

让我们打印 IR 以查看 vthread 的作用

/* Input = [1, 1, 32, 32], Filter = [1, 1, 3, 3], stride = [1, 1], padding = 'SAME' */
produce DepthwiseConv2d {
  // attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 1
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 1
  // attr [iter_var(threadIdx.y, Range(min=0, extent=8), threadIdx.y)] thread_extent = 8
  // attr [iter_var(threadIdx.x, Range(min=0, extent=8), threadIdx.x)] thread_extent = 8
  for (i.inner.inner.inner, 0, 2) {
    for (j.inner.inner.inner, 0, 2) {
      DepthwiseConv2d[((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner)] = 0.000000f
      DepthwiseConv2d[(((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + 512)] = 0.000000f
      DepthwiseConv2d[(((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + 16)] = 0.000000f
      DepthwiseConv2d[(((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + 528)] = 0.000000f
      for (di, 0, 3) {
        for (dj, 0, 3) {
          DepthwiseConv2d[((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner)] = (DepthwiseConv2d[((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner)] + (tvm_if_then_else(((((((1 - di) - i.inner.inner.inner) <= (((blockIdx.x*16) + threadIdx.y)*2)) && ((((blockIdx.x*16) + threadIdx.y)*2) < ((33 - di) - i.inner.inner.inner))) && (((1 - dj) - j.inner.inner.inner) <= (threadIdx.x*2))) && ((threadIdx.x*2) < ((33 - dj) - j.inner.inner.inner))), Input[(((((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + (di*32)) + dj) + -33)], 0.000000f)*Filter[((di*3) + dj)]))
          DepthwiseConv2d[(((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + 512)] = (DepthwiseConv2d[(((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + 512)] + (tvm_if_then_else(((((((-15 - di) - i.inner.inner.inner) <= (((blockIdx.x*16) + threadIdx.y)*2)) && ((((blockIdx.x*16) + threadIdx.y)*2) < ((17 - di) - i.inner.inner.inner))) && (((1 - dj) - j.inner.inner.inner) <= (threadIdx.x*2))) && ((threadIdx.x*2) < ((33 - dj) - j.inner.inner.inner))), Input[(((((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + (di*32)) + dj) + 479)], 0.000000f)*Filter[((di*3) + dj)]))
          DepthwiseConv2d[(((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + 16)] = (DepthwiseConv2d[(((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + 16)] + (tvm_if_then_else(((((((1 - di) - i.inner.inner.inner) <= (((blockIdx.x*16) + threadIdx.y)*2)) && ((((blockIdx.x*16) + threadIdx.y)*2) < ((33 - di) - i.inner.inner.inner))) && (((-15 - dj) - j.inner.inner.inner) <= (threadIdx.x*2))) && ((threadIdx.x*2) < ((17 - dj) - j.inner.inner.inner))), Input[(((((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + (di*32)) + dj) + -17)], 0.000000f)*Filter[((di*3) + dj)]))
          DepthwiseConv2d[(((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + 528)] = (DepthwiseConv2d[(((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + 528)] + (tvm_if_then_else(((((((-15 - di) - i.inner.inner.inner) <= (((blockIdx.x*16) + threadIdx.y)*2)) && ((((blockIdx.x*16) + threadIdx.y)*2) < ((17 - di) - i.inner.inner.inner))) && (((-15 - dj) - j.inner.inner.inner) <= (threadIdx.x*2))) && ((threadIdx.x*2) < ((17 - dj) - j.inner.inner.inner))), Input[(((((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + (di*32)) + dj) + 495)], 0.000000f)*Filter[((di*3) + dj)]))
        }
      }
    }
  }
}

在没有 vthread 的情况下(仅设置为 1),IR 是

/* Input = [1, 1, 32, 32], Filter = [1, 1, 3, 3], stride = [1, 1], padding = 'SAME' */
produce DepthwiseConv2d {
  // attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 1
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 1
  // attr [iter_var(threadIdx.y, Range(min=0, extent=8), threadIdx.y)] thread_extent = 8
  // attr [iter_var(threadIdx.x, Range(min=0, extent=8), threadIdx.x)] thread_extent = 8
  for (i.inner.inner.inner, 0, 4) {
    for (j.inner.inner.inner, 0, 4) {
      DepthwiseConv2d[((((((((blockIdx.y + blockIdx.x)*8) + threadIdx.y)*32) + threadIdx.x)*4) + (i.inner.inner.inner*32)) + j.inner.inner.inner)] = 0.000000f
      for (di, 0, 3) {
        for (dj, 0, 3) {
          DepthwiseConv2d[((((((((blockIdx.y + blockIdx.x)*8) + threadIdx.y)*32) + threadIdx.x)*4) + (i.inner.inner.inner*32)) + j.inner.inner.inner)] = (DepthwiseConv2d[((((((((blockIdx.y + blockIdx.x)*8) + threadIdx.y)*32) + threadIdx.x)*4) + (i.inner.inner.inner*32)) + j.inner.inner.inner)] + (tvm_if_then_else(((((((1 - di) - i.inner.inner.inner) <= (((blockIdx.x*8) + threadIdx.y)*4)) && ((((blockIdx.x*8) + threadIdx.y)*4) < ((33 - di) - i.inner.inner.inner))) && (((1 - dj) - j.inner.inner.inner) <= (threadIdx.x*4))) && ((threadIdx.x*4) < ((33 - dj) - j.inner.inner.inner))), Input[(((((((((((blockIdx.y + blockIdx.x)*8) + threadIdx.y)*32) + threadIdx.x)*4) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + (di*32)) + dj) + -33)], 0.000000f)*Filter[((di*3) + dj)]))
        }
      }
    }
  }
}

我们可以看到,当 num_vthread_y = 2num_vthread_x = 2 时,32 x 32 通道被划分为四个 16 x 16 的子通道。每个线程一次计算四个输出元素,每个子通道中一个元素。

以下是滤波器 = [256, 1, 3, 3],步幅 = [1, 1],blocking_h = 32,blocking_w = 32 的结果

案例 输入 num_thread_y, num_thread_x num_vthread_y, num_vthread_x TVM SAME 填充 (us)
1 [1, 256, 96, 96] 8, 8 1, 1 132.5
2 [1, 256, 96, 96] 8, 8 1, 4 103.1
3 [1, 256, 96, 96] 4, 32 1, 1 95.9
4 [1, 256, 96, 96] 8, 16 1, 2 90.9

案例 2 比案例 1 快。这是因为在案例 2 中,num_thread_x=8num_vthread_x=4 一起确保连续的线程访问连续的内存地址,从而避免 bank 冲突,如下所示(每种颜色代表一个线程的工作负载)

image

理论上,案例 3 和案例 4 应该一样快,因为它们每个线程的工作负载相同,并且都享受高效的共享内存访问。不知何故,案例 4 只是稍微快一点。

还记得 tensorflow 的速度吗?它是 251.6us,现在 TVM 快了 2.8 倍。387.4 -> 132.5 -> 95.9 -> 90.9,分块帮助最大;调整线程数节省了 37us;vthread 额外节省了 5us。

事实上,对于较大的内核大小或 channel_multiplier(因为更多的滤波器重用),TVM 可以比 tensorflow 快得多

输入 滤波器 步幅 tf-1.2 SAME 填充 (us) TVM SAME 填充 (us) TVM 快多少
[1, 256, 96, 96] [256, 1, 3, 3] [1, 1] 251.6 90.9 2.8倍
[1, 256, 96, 96] [256, 1, 5, 5] [1, 1] 597.6 128.9 4.6倍
[1, 256, 96, 96] [256, 2, 3, 3] [1, 1] 659.9 143.7 4.6倍
[1, 256, 96, 96] [256, 2, 5, 5] [1, 1] 1203.9 170.5 7.1倍

算子融合

我们可以在深度学习中进行的一种典型优化是算子融合,它在单个内核中一起计算多个算子,而无需将中间结果保存回全局内存。TVM 开箱即用地支持这一点。

考虑神经网络中的常见模式:depthwise_conv2d + scale_shift + relu。我们可以通过稍微修改原始调度,将这三个算子融合为一个

DepthwiseConv2d = topi.nn.depthwise_conv2d(Input, Filter, stride, padding)
ScaleShift = topi.nn.scale_shift(DepthwiseConv2d, Scale, Shift)
Relu = topi.nn.relu(ScaleShift)

Output = Relu # is no longer DepthwiseConv2d
s[ScaleShift].compute_inline() # this line fuses ScaleShift, explicitly
s[DepthwiseConv2d].set_scope("local") # this line fuses DepthwiseConv2d, implicitly
schedule(Output) # schedule for Output the same way we schedule for DepthwiseConv2d as discussed above
s[DepthwiseConv2d].compute_at(s[Output], tx) # tx is the inner most axis, bound to threadIdx.x

它生成像这样的 IR

/* Input = [1, 1, 32, 32], Filter = [1, 1, 3, 3], stride = [1, 1], padding = 'SAME' */
produce Relu {
  // attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 1
  // attr [DepthwiseConv2d] storage_scope = "local"
  allocate DepthwiseConv2d[float32 * 1 * 1 * 4 * 4]
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 1
  // attr [iter_var(threadIdx.y, Range(min=0, extent=8), threadIdx.y)] thread_extent = 8
  // attr [iter_var(threadIdx.x, Range(min=0, extent=8), threadIdx.x)] thread_extent = 8
  produce DepthwiseConv2d {
    for (i, 0, 4) {
      for (j, 0, 4) {
        DepthwiseConv2d[((i*4) + j)] = 0.000000f
        for (di, 0, 3) {
          for (dj, 0, 3) {
            DepthwiseConv2d[((i*4) + j)] = (DepthwiseConv2d[((i*4) + j)] + (tvm_if_then_else(((((((1 - di) - i) <= (((blockIdx.x*8) + threadIdx.y)*4)) && ((((blockIdx.x*8) + threadIdx.y)*4) < ((33 - di) - i))) && (((1 - dj) - j) <= (threadIdx.x*4))) && ((threadIdx.x*4) < ((33 - dj) - j))), Input[(((((((((((blockIdx.y + blockIdx.x)*8) + threadIdx.y)*32) + threadIdx.x)*4) + (i*32)) + j) + (di*32)) + dj) + -33)], 0.000000f)*Filter[((di*3) + dj)]))
          }
        }
      }
    }
  }
  for (i2.inner.inner.inner, 0, 4) {
    for (i3.inner.inner.inner, 0, 4) {
      Relu[((((((((blockIdx.y + blockIdx.x)*8) + threadIdx.y)*32) + threadIdx.x)*4) + (i2.inner.inner.inner*32)) + i3.inner.inner.inner)] = max(((DepthwiseConv2d[((i2.inner.inner.inner*4) + i3.inner.inner.inner)]*Scale[0]) + Shift[0]), 0.000000f)
    }
  }
}

我们可以看到,每个线程在将 depthwise_conv2d 的结果写入全局内存之前,都会计算 scale_shiftrelu。融合后的算子与单个 depthwise_conv2d 一样快。以下是输入 = [1, 256, 96, 96],滤波器 = [256, 1, 3, 3],步幅 = [1, 1],填充 = ‘SAME’ 的结果

  • tf-1.2 depthwise_conv2d: 251.6 us
  • tf-1.2 depthwise_conv2d + scale_shift + relu (分离): 419.9 us
  • TVM depthwise_conv2d: 90.9 us
  • TVM depthwise_conv2d + scale_shift + relu (融合): 91.5 us

算子融合的优势是显而易见的。

这还不是终点,TVM 可以以更智能的方式进行算子融合。您可以参考 这里 并阅读下面提供的源代码。

给我看代码

致谢

作者衷心感谢陈天奇先生的有益建议和启发性讨论。

作者简介

Yuwei Hu图森未来 HPC 组的实习生。他在北京航空航天大学获得电气工程学士学位后,正在经历间隔年。

参考文献

[1] Xception: 基于深度可分离卷积的深度学习

[2] MobileNets: 面向移动视觉应用的高效卷积神经网络

[3] 典型 PC 上各种操作的近似计时