From f75d7789ad6b167e97dafa427b7409ef969ac75d Mon Sep 17 00:00:00 2001 From: Zhi <5145158+zhiics@users.noreply.github.com> Date: Wed, 12 Aug 2020 01:12:56 -0700 Subject: [PATCH] [TOPI] Fix reduction (#6250) --- python/tvm/topi/cuda/reduction.py | 2 ++ tests/python/relay/test_pass_fuse_ops.py | 29 ++++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/python/tvm/topi/cuda/reduction.py b/python/tvm/topi/cuda/reduction.py index 38e30867b791..664ea441141b 100644 --- a/python/tvm/topi/cuda/reduction.py +++ b/python/tvm/topi/cuda/reduction.py @@ -139,6 +139,8 @@ def traverse_after_reduce(operator): for tensor in input_tensors: if tensor.op not in scheduled_ops: traverse_before_reduce(tensor.op) + elif isinstance(operator, tvm.te.PlaceholderOp): + pass else: raise RuntimeError("Unsupported operator: %s" % operator.tag) diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index f4369c1f1d90..1727429e74de 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -694,6 +694,34 @@ def expected(): assert tvm.ir.structural_equal(m["main"], after) +def test_fuse_bcast_reduce_scalar(): + """Test fusion case with broadcast and reduction involving scalar""" + + def before(): + x = relay.var("x", shape=(), dtype="int32") + less = relay.less(x, relay.const(10, dtype="int32")) + z = relay.min(less) + return relay.Function([x], z) + + def expected(): + p0 = relay.var("p0", shape=(), dtype="int32") + less = relay.less(p0, relay.const(10, dtype="int32")) + z0 = relay.min(less) + f0 = relay.Function([p0], z0) + f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + + x = relay.var("x", shape=(), dtype="int32") + f = relay.Call(f0, [x]) + return relay.Function([x], f) + + orig = before() + m = fuse2(tvm.IRModule.from_expr(orig)) + for tgt, _ in tvm.relay.testing.config.ctx_list(): + relay.build(m, tgt) + after = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(m["main"], after) + + if __name__ == "__main__": test_fuse_simple() test_conv2d_fuse() @@ -712,3 +740,4 @@ def expected(): test_fuse_max() test_fuse_take() test_fuse_gather_nd() + test_fuse_bcast_reduce_scalar()