def matmul(n, m, l):
    """Build the tensor expression for a matrix product C = A B.

    A is declared as an (n, l) placeholder, B as an (l, m) placeholder,
    and C as an (n, m) compute stage that sums A[x, k] * B[k, y] over the
    reduction axis k, which ranges over [0, l).

    Returns the tuple (A, B, C).
    """
    A = te.placeholder((n, l), name='A')
    B = te.placeholder((l, m), name='B')
    k = te.reduce_axis((0, l), name='k')

    # Parameter names x/y are kept exactly as in the original lambda so that
    # anything TVM derives from the compute function's signature is unchanged.
    def dot(x, y):
        return te.sum(A[x, k] * B[k, y], axis=k)

    C = te.compute((n, m), dot, name='C')
    return A, B, C
如果我们不使用 reduce_axis 语法,则上述程序可能会编写成如下形式:
1 2 3 4 5 6 7 8 9 10 11
def tvm_matrix_multi_noreduce():
    """Sketch a matrix multiply written without ``te.reduce_axis``.

    Declares symbolic sizes m, n, l, an (m, l) float32 placeholder A and an
    (l, n) float32 placeholder B, then describes C[i, j] as the sum over k of
    A[i, k] * B[k, j] using a plain Python loop.

    Returns the tuple (A, B, C).
    """
    m, n, l = [te.var(name) for name in ('m', 'n', 'l')]
    A = te.placeholder((m, l), dtype='float32', name='a')
    B = te.placeholder((l, n), dtype='float32', name='b')

    def f(i, j):
        result = 0
        # NOTE(review): `l` is a symbolic te.var here; Python's range()
        # presumably cannot iterate over it, which is why real code uses
        # te.reduce_axis instead. Kept as-is because this block illustrates
        # the non-reduce_axis form -- confirm before running.
        for k in range(0, l):
            # Row i of A times column j of B: B is indexed by (k, j),
            # not (i, j) as the original had it.
            result += A[i, k] * B[k, j]
        # The original returned the misspelled name `reuslt` (NameError).
        return result

    C = te.compute((m, n), f, name='c')
    return A, B, C
# Build the matrix-multiply expression, schedule it, and compile it.
A, B, C = tvm_matrix_multi()
s = te.create_schedule(C.op)
mod = tvm.build(s, [A, B, C])


def get_abc(shape, constructor=None):
    """Return random a, b and an empty c, all with the given shape.

    The NumPy global RNG is seeded with 0 first, so a and b are the same on
    every call. a and b are float32 samples from a normal distribution; c is
    an uninitialised array shaped and typed like a. When `constructor` is
    truthy, each of the three arrays is passed through it before returning.
    """
    np.random.seed(0)
    lhs = np.random.normal(size=shape).astype(np.float32)
    rhs = np.random.normal(size=shape).astype(np.float32)
    out = np.empty_like(lhs)
    if not constructor:
        return lhs, rhs, out
    return tuple(constructor(arr) for arr in (lhs, rhs, out))


# Run the compiled module on 2x2 inputs wrapped as TVM arrays.
a, b, c = get_abc((2, 2), tvm.nd.array)
mod(a, b, c)
def comp(a, b):
    """Combine two partial results of the reduction by multiplying them."""
    return a * b


def init(dtype):
    """Return the starting value of the reduction: the constant 1 of `dtype`."""
    return tvm.tir.const(1, dtype=dtype)


# Custom commutative reducer built from the combiner and its initial value.
product = te.comm_reducer(comp, init)
生成的是一维向量每个元素相乘的操作:
1 2 3 4 5 6 7
# Symbolic extents of the 2-D input and the reduction axis over its columns.
m = te.var('m')
n = te.var('n')
k = te.reduce_axis((0, m), name='k')
A = te.placeholder((n, m), name='a')


# Parameter name `i` is kept exactly as in the original lambda.
def row_product(i):
    return product(A[i, k], axis=k)


# B[i] is the `product` reduction of A[i, k] over k.
B = te.compute((n,), row_product, name='b')
s = te.create_schedule(B.op)
tvm.lower(s, [A, B], simple_mode=True)
从 k 轴上进行规约,就是把二维矩阵 A 的 i 轴上的数据相乘。
具体分析一下:
k 从 0 到 m 递增。当 k=0 时:comp 方法接受两个参数(即 lambda 中的 a 和 b,下文记为 x 和 y),x 为初始值 1,y 为 A[i,0],两者相乘后赋值给 B[i];在第二次循环中,k=1,x 为 B[i],y 为 A[i,1];依此类推。
# x and y are the operands of the reduction; each one is a tuple of
# (index, value).
def fcombine(x, y):
    """Pick the (index, value) pair whose value is larger; x wins ties."""
    x_wins = x[1] >= y[1]
    index = tvm.tir.Select(x_wins, x[0], y[0])
    value = tvm.tir.Select(x_wins, x[1], y[1])
    return index, value
# The identity element also needs to be a tuple, so `fidentity` accepts
# two dtypes as inputs: one for the index and one for the value.
def fidentity(t0, t1):
    """Return the identity pair: index -1 of type t0, minimum value of type t1."""
    index_identity = tvm.tir.const(-1, t0)
    value_identity = tvm.te.min_value(t1)
    return index_identity, value_identity
Comments