diff --git a/topi/python/topi/x86/batch_matmul.py b/topi/python/topi/x86/batch_matmul.py index b505cbfabb55..25b49d12400d 100644 --- a/topi/python/topi/x86/batch_matmul.py +++ b/topi/python/topi/x86/batch_matmul.py @@ -92,33 +92,40 @@ def schedule_batch_matmul(cfg, outs): def _callback(op): if "batch_matmul" in op.tag: C = op.output(0) - A, B = s[C].op.input_tensors + A, B = op.input_tensors _, M, K = get_const_tuple(A.shape) _, _, N = get_const_tuple(C.shape) + if op not in s.outputs: + s[C].compute_inline() + O = outs[0] + else: + O = C + + CC = s.cache_write(C, "global") + # create tuning space cfg.define_split("tile_y", M, num_outputs=2) cfg.define_split("tile_x", N, num_outputs=2) cfg.define_split("tile_k", K, num_outputs=2) - k, = s[C].op.reduce_axis - - ko, ki = cfg["tile_k"].apply(s, C, k) - CC = s.rfactor(C, ki) - - b, y, x = s[C].op.axis - yo, yi = cfg["tile_y"].apply(s, C, y) - xo, xi = cfg["tile_x"].apply(s, C, x) - s[C].reorder(b, yo, xo, yi, xi) - bxyo = s[C].fuse(b, yo, xo) - s[C].parallel(bxyo) - s[C].fuse(yi, xi) - - s[CC].compute_at(s[C], bxyo) - _, _, y, x = s[CC].op.axis - s[CC].fuse(y, x) - s[CC].vectorize(s[CC].op.axis[0]) - s[C].pragma(bxyo, 'auto_unroll_max_step', 16) + b, y, x = s[O].op.axis + yo, yi = cfg["tile_y"].apply(s, O, y) + xo, xi = cfg["tile_x"].apply(s, O, x) + s[O].reorder(b, yo, xo, yi, xi) + bxyo = s[O].fuse(b, yo, xo) + s[O].parallel(bxyo) + + s[CC].compute_at(s[O], bxyo) + k, = s[CC].op.reduce_axis + ko, ki = cfg["tile_k"].apply(s, CC, k) + + Crf = s.rfactor(CC, ki) + s[Crf].compute_at(s[CC], s[CC].op.axis[0]) + _, _, y, x = s[Crf].op.axis + s[Crf].fuse(y, x) + s[Crf].vectorize(s[Crf].op.axis[0]) + s[O].pragma(bxyo, 'auto_unroll_max_step', 16) traverse_inline(s, outs[0].op, _callback) return s