# TVM

This answers a Zhihu question: https://www.zhihu.com/question/565420155

I happened to study this schedule primitive (`storage_align`) recently, so here is a quick summary. The official documentation is indeed rather abstract: https://tvm.apache.org/docs/reference/api/python/tir.html

What the asker is likely confused about is what `factor` and `offset` mean, and why they can resolve shared memory bank conflicts.

For the first question, we can look at the code, starting with the underlying implementation (https://github.com/apache/tvm/blob/HEAD/src/tir/transforms/storage_flatten.cc#L480-L481):

```cpp
// Strides are computed from the innermost dimension outwards. When a
// dimension carries a storage_align annotation (align_factor != 0),
// the running stride is padded so that stride % factor == offset.
PrimExpr stride = make_const(shape[first_dim].dtype(), 1);
for (size_t i = shape.size(); i != 0; --i) {
  size_t dim = i - 1;
  if (dim < avec.size() && avec[dim].align_factor != 0) {
    PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor);
    PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset);
    // Round stride up to the next value congruent to offset (mod factor).
    stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
    stride = bound_analyzer_->Simplify(stride);
  }
  rstrides.push_back(stride);
  stride = stride * shape[dim];
}
```
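
So the transform pads the stride of an aligned axis up to the next value satisfying `stride % factor == offset`. That also answers the second question: with a suitable factor/offset pair, consecutive rows of a shared-memory tile start in different banks, so a warp reading down a column no longer hits the same bank. A small sketch of the arithmetic (the bank parameters and the `AA` stage name are illustrative assumptions, not from the post):

```python
def padded_stride(stride: int, factor: int, offset: int) -> int:
    # Mirror of the C++ update:
    #   stride + indexmod(factor + offset - indexmod(stride, factor), factor)
    return stride + (factor + offset - stride % factor) % factor

# NVIDIA shared memory has 32 banks of 4-byte words. A 32x32 float32
# tile with the natural row stride of 32 maps every row start to bank 0,
# so a warp reading one column causes a 32-way bank conflict.
print(padded_stride(32, 32, 1))  # 33 -> row t now starts in bank t % 32

# In a TE schedule the same request looks like (stage/axis names assumed):
#   s[AA].storage_align(AA.op.axis[0], 32, 1)
# which pads AA's row stride so that stride % 32 == 1.
```

With the padded stride of 33, thread `t` reading element `[t][j]` touches bank `(t + j) % 32`, which is distinct for every thread in the warp, so the column read becomes conflict-free.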


# First: reduce_axis

`tvm.reduce_axis` is our first encounter with a reduce-related operation.

As an example, let's implement a matrix multiplication (adapted from tvm.d2l.ai). The computation is:

$$C_{i,j} = \sum_{k} A_{i,k} \, B_{k,j}$$
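
The post excerpt is truncated at this point; below is a minimal sketch of the definition it is building toward, following tvm.d2l.ai's conventions (in current TVM the API lives under `tvm.te`; the names `n`, `A`, `B`, `k`, `C` and the size 1024 are assumptions):

```python
from tvm import te

n = 1024
A = te.placeholder((n, n), name="A")
B = te.placeholder((n, n), name="B")
# k is the reduction axis: it ranges over [0, n) and is summed out,
# so it does not appear in the shape of the output C.
k = te.reduce_axis((0, n), name="k")
C = te.compute((n, n), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")
```

Unlike the spatial axes `i` and `j`, a reduce axis may only appear inside a reduction such as `te.sum`, and it is the handle that later schedule primitives (`split`, `rfactor`, and so on) use to reorganize the reduction.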

