diff --git a/topi/python/topi/cuda/batch_matmul.py b/topi/python/topi/cuda/batch_matmul.py index c973f1d8ea34..b5dd802ad1e9 100644 --- a/topi/python/topi/cuda/batch_matmul.py +++ b/topi/python/topi/cuda/batch_matmul.py @@ -38,6 +38,7 @@ def schedule_batch_matmul(outs): s: Schedule The computation schedule for the op. """ + outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs s = tvm.create_schedule([x.op for x in outs]) def _schedule(op): @@ -49,6 +50,9 @@ def _schedule(op): BB = s.cache_read(B, "shared", [C]) BL = s.cache_read(BB, "local", [C]) CC = s.cache_write(C, "local") + if op not in s.outputs: + s[C].compute_inline() + C = s.outputs[0].output(0) b, y, x = s[C].op.axis y_bn = get_max_power2_factor(M, 64)