From 9ad1ce187a6d9680ab9feb15abeff81fec6cfcd4 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Fri, 28 Jun 2019 18:36:58 +0000 Subject: [PATCH] fix --- topi/python/topi/cuda/batch_matmul.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/topi/python/topi/cuda/batch_matmul.py b/topi/python/topi/cuda/batch_matmul.py index c973f1d8ea34..b5dd802ad1e9 100644 --- a/topi/python/topi/cuda/batch_matmul.py +++ b/topi/python/topi/cuda/batch_matmul.py @@ -38,6 +38,7 @@ def schedule_batch_matmul(outs): s: Schedule The computation schedule for the op. """ + outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs s = tvm.create_schedule([x.op for x in outs]) def _schedule(op): @@ -49,6 +50,9 @@ def _schedule(op): BB = s.cache_read(B, "shared", [C]) BL = s.cache_read(BB, "local", [C]) CC = s.cache_write(C, "local") + if op not in s.outputs: + s[C].compute_inline() + C = s.outputs[0].output(0) b, y, x = s[C].op.axis y_bn = get_max_power2_factor(M, 64)