回答知乎提问:https://www.zhihu.com/question/565420155

最近正好研究了一下这个schedule,顺便简单总结一下,官方给的文档介绍确实比较抽象: https://tvm.apache.org/docs/reference/api/python/tir.html

题主困惑的应该是factor和offset是什么意思,为什么这样能够解决shared memory bank conflict?

第一个问题,可以看看代码,首先是底层的实现(https://github.com/apache/tvm/blob/HEAD/src/tir/transforms/storage_flatten.cc#L480-L481):

PrimExpr stride = make_const(shape[first_dim].dtype(), 1);
for (size_t i = shape.size(); i != 0; --i) {
  size_t dim = i - 1;
  if (dim < avec.size() && avec[dim].align_factor != 0) {
    PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor);
    PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset);
    stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
    stride = bound_analyzer_->Simplify(stride);
  }
  rstrides.push_back(stride);
  stride = stride * shape[dim];
}

显然可以通过图中的公式计算出最后的stride,例如网上能搜到的一个case:

import tvm

n = 1024
factor = 100
offset = 8
dtype = "float32"
A = tvm.te.placeholder((n, n), dtype=dtype, name='A')
k = tvm.te.reduce_axis((0, n), name='k')
B = tvm.te.compute((n,), lambda i: tvm.te.sum(A[i, k], axis=k), name='B')

s = tvm.te.create_schedule(B.op)
AA = s.cache_read(A, "shared", [B])

print(tvm.lower(s, [A, B], simple_mode=True))
print("---------cutting line---------")

s[AA].storage_align(AA.op.axis[0], factor, offset)

print(tvm.lower(s, [A, B], simple_mode=True))

'''
@main = primfn(A_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [1048576], []),
             B: Buffer(B_2: Pointer(float32), float32, [1024], [])}
  buffer_map = {A_1: A, B_1: B}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [1024, 1024], []), B_1: B_3: Buffer(B_2, float32, [1024], [])} {
  allocate(A.shared: Pointer(shared float32), float32, [1048576]), storage_scope = shared {
    for (ax0: int32, 0, 1024) {
      for (ax1: int32, 0, 1024) {
        let cse_var_1: int32 = ((ax0*1024) + ax1)
        A.shared_1: Buffer(A.shared, float32, [1048576], [], scope="shared")[cse_var_1] = A[cse_var_1]
      }
    }
    for (i: int32, 0, 1024) {
      B[i] = 0f32
      for (k: int32, 0, 1024) {
        B[i] = (B[i] + A.shared_1[((i*1024) + k)])
      }
    }
  }
}


---------cutting line---------
@main = primfn(A_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [1048576], []),
             B: Buffer(B_2: Pointer(float32), float32, [1024], [])}
  buffer_map = {A_1: A, B_1: B}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [1024, 1024], []), B_1: B_3: Buffer(B_2, float32, [1024], [])} {
  allocate(A.shared: Pointer(shared float32), float32, [1134592]), storage_scope = shared {
    for (ax0: int32, 0, 1024) {
      for (ax1: int32, 0, 1024) {
        A.shared_1: Buffer(A.shared, float32, [1134592], [], scope="shared")[((ax0*1108) + ax1)] = A[((ax0*1024) + ax1)]
      }
    }
    for (i: int32, 0, 1024) {
      B[i] = 0f32
      for (k: int32, 0, 1024) {
        B[i] = (B[i] + A.shared_1[((i*1108) + k)])
      }
    }
  }
}
'''

用这个公式计算一下:
$$
(100+8-1024%100)% 100 + 1024 = (108-24) + 1024 = 1108
$$
这个公式可以理解为,对于原来给定的一个stride,如1024,首先跟factor对其,如1024对其之后是1100,再补上offset,可以实现一个类似memory zero padding的效果,再tvm的repo里,还可以翻到一些经常用的(并没有,奇怪的用法:

s[CS].storage_align(bb, CS_align - 1, CS_align)

推导一下公式
$$
stride = stride + (C-1+C-(stride%(C-1)))% (C-1)
$$
而在一些情况下, 这里的CS_align等于stride,则stride不变,如果加上一个offset,则需要另外考虑。

第二个问题需要了解一下在gpu矩阵乘法计算中的一种通过加pad的方式解决bank conflict的方法,假设我们都按照cutlass的思路来进行矩阵乘法计算,并且利用tensorcore,以一个简单的warp算m16n16k16的矩阵乘法为例子:

Drawing6

左边图片中白色的部分是一个典型的A矩阵在shared memory里的排布,大小是128*32的矩阵,一次取一个小矩阵在内存的排布,一次使用l ds128指令取八个float16的元素,每个线程访问的bank如下面所示,有一半的bank是没有被访问到的,一种常用的解法是给每一行加PAD,例如右图,每一行加4个bank大小的pad,这样带宽就可以利用满,这样做法的优点是简单,但是缺点也很明显,一是写入shared memory就会有conflict,需要动脑消除一下,二是会增加shared memory的开销,有了这个图示,就可以解决第二个问题了。

回到tvm,如果只用一个storage align schedule,速度可能会快一些,这来源于你解决了wmma::load_matrix_sync引入的shared memory load conflict,但是因为从global memory读入shared memory的shared memory store过程中线程与线程之间多了padding,会导致引入store的conflict。

而且理论上存在解,不需要加padding,控制好每个线程访问的bank让他们不conflict,cutlass里提供了这样的一种解法:

image-20221108192807163

这两种情况显然不能用storage_align解决了,可以用tvm的tensorize schdule和decl_buffer来达到这个目的,这种实现方式也更自由,如这里的代码:

import tvm
from tvm import te

def intrin_load_matrix_to_slb():
    output_shape = (16, 64)
    strides_src = [64, 1]
    strides_dst = [64, 1]

    A = te.placeholder(output_shape, name="A", dtype="float32")
    C = te.compute(output_shape, lambda *i: A(*i), name="C")

    BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope="global", strides=strides_src, data_alignment=64, offset_factor=1)
    BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope="shared",  strides=strides_dst, data_alignment=64, offset_factor=1)

    def intrin_func(ins, outs):
        ib = tvm.tir.ir_builder.create()

        BA = ins[0]
        BC = outs[0]

        tx = te.thread_axis("threadIdx.x")
        ib.scope_attr(tx, "thread_extent", 64)
        index = tx // 1

        for outer in range(0, 16):
                ib.emit(BC.vstore([outer, index], BA.vload([outer, index], "float32")))

        return ib.get()
    return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})

M = 64
N = 64
A = te.placeholder((M, N), dtype="float32", name="A")
B = te.compute((M, N), lambda *i: A(*i), name="B", )
s = te.create_schedule(B.op)
tx = te.thread_axis("threadIdx.x")
AS = s.cache_read(A, "shared", [B])
cx, ci = B.op.axis
cxo, cxi = s[B].split(cx, factor=16)
s[B].reorder(cxo, cxi, ci)
s[B].bind(ci, tx)

s[AS].compute_at(s[B], cxo)
ax, ai = AS.op.axis
# s[AS].storage_align(ax, 63, 64)
s[AS].tensorize(ax, intrin_load_matrix_to_slb())
s[AS].double_buffer()
print(tvm.lower(s, [A, B]))
output:
  @main = primfn(A_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {B: Buffer(B_2: Pointer(float32), float32, [64, 64], []),
             A: Buffer(A_2: Pointer(float32), float32, [64, 64], [])}
  buffer_map = {A_1: A, B_1: B} {
  allocate(A.shared: Pointer(shared float32), float32, [1024]), storage_scope = shared;
  for (i0.outer: int32, 0, 4) {
    attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 64 {
      A.shared[threadIdx.x] = (float32*)A_2[((i0.outer*1024) + threadIdx.x)]
      A.shared[(threadIdx.x + 64)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 64)]
      A.shared[(threadIdx.x + 128)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 128)]
      A.shared[(threadIdx.x + 192)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 192)]
      A.shared[(threadIdx.x + 256)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 256)]
      A.shared[(threadIdx.x + 320)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 320)]
      A.shared[(threadIdx.x + 384)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 384)]
      A.shared[(threadIdx.x + 448)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 448)]
      A.shared[(threadIdx.x + 512)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 512)]
      A.shared[(threadIdx.x + 576)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 576)]
      A.shared[(threadIdx.x + 640)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 640)]
      A.shared[(threadIdx.x + 704)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 704)]
      A.shared[(threadIdx.x + 768)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 768)]
      A.shared[(threadIdx.x + 832)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 832)]
      A.shared[(threadIdx.x + 896)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 896)]
      A.shared[(threadIdx.x + 960)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 960)]
    }
    for (i0.inner: int32, 0, 16) {
      attr [IterVar(threadIdx.x_1: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 64;
      B_2[(((i0.outer*1024) + (i0.inner*64)) + threadIdx.x_1)] = (float32*)A.shared[((i0.inner*64) + threadIdx.x_1)]
    }
  }
}
TVMMLSYS

w关于IGEMM在CUDA Core上的优化,网上基本没有看到开源的实现(甚至连实现不好的版本都没有,可能大家都是直接用的cublas/cutlass来实现吧,这里是我通过手写CUDA代码的方式使用DP4A指令达到sota的性能,不得不说这个东西没有人实现也是有原因的,坑是真的多。

我们以M=K=N 16384的矩阵乘举例。

Digilal DesignEEEE

在之前的两篇文章中,我们分别用TVM的Tensor Expression与TIR Script完成了在Nvidia Cuda Core上的高效的FP32 矩阵乘法,3090-24GB的各种精度在Cuda Core和Tensor Core上的Peak TFLOPS如下表所示:

3090-24GB FP32 FP16 BF16 INT32 INT8 INT4
Cuda Core 35.6 35.6 35.6 17.8 71.2 \
Tensor Core \ 142 / 284* 142 / 284* \ 284 / 568* 568 / 1136*

有意思的是,3090上,FP16的Peak Peformance和FP32是一样的,这一点比较特殊,是因为架构上的改动,一般而言fp16的性能都会是fp32的两倍或者四倍,这个主要是因为20系的gpu把fp32和int32的Cuda Core分开了,从而能同时进行fp32和int32的计算,30系把int32的core又就加上了fp32的计算单元,所以fp32的计算能力翻倍,而cutlass下的16384的gemm。

按照3090上的硬件单元分类,我们还可以探索一些有意思的加速,比如在CUDA Core上使用SIMD指令(DP4A,HFMA2来优化int8、half的性能,

CUDA ProgrammingMLSys

上一篇文章中讲到如何利用cutlass优化gemm的思路,使用tvm tensor expression来实现一个高效的矩阵乘法,这里再探索一下直接从TIR Script把这个东西复现一下,对比一下两者的异同。

TensorIR 今年7月再arxiv上放了一篇preprint,感兴趣的读者可以自行阅读:https://arxiv.org/abs/2207.04296

不过写这篇文章的时候,tvm上游(main)分支的tir与paper里还不是一样,siyuan他们另做了许多改进,估计要等paper中了才会被合并到上游(貌似是在投ASPLOS?所以这里还是以tvm上我们可以实际操作的TensorIR Script为例子,优化的思路则不多讲解,和之前的tensor expression是一样的。

PS: 感觉TIR Script的设计和写法更贴近GPU,比tensor expression更抽象,有亿点点摸不着头脑,不过也比直接从tensor ir来构建一个dag要舒服地多,虽然通过自己瞎理解与实验加上在论坛交流了一下,也算是都摸出来怎么实现,但我相信应该还会有更优雅的写法。

代码还是放在: https://github.com/LeiWang1999/tvm_gpu_gemm

CUDA ProgrammingMLSys

这里记录的是我想从tvm的tensor expression出发,参考一下cutlass efficient gemm的思路,一步一步优化一下GEMM的一些思考和碎碎念,目的是为了理解cutlass优化gemm的思路。

我们使用CUTLASS Profiler来运行一个gemm的运算,并用nsight compute dump下来其运行过程中的一些情况,可以拿到他的一些信息,如grid的大小与block的大小等。比如对于16384的float32类型数据的gemm,cutlass的grid size是(512, 16, 1)-> 8192个block, block size是(256,1,1),一共是2,097,152个线程,因为最后产生C的大小是(16384,16384),所以平均每个thread需要产生128个C的元素,结合这些参数的信息,使用tvm的te进行schedule(其实可以试试tensor ir),最后成功打到了和cublas,cutlass相近的性能。

测试GPU: rtx 3090 24GB

CUDA Version: 11.1

TVM Version: 10.0

代码放在:https://github.com/LeiWang1999/tvm_gpu_gemm

CUDA ProgrammingMLSys