diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc
index ff4c77accf073..8268fc4e1aed5 100644
--- a/src/pass/arg_binder.cc
+++ b/src/pass/arg_binder.cc
@@ -239,7 +239,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
             AssertStmt::make(arith::ComputeReduce<ir::And>(conds, Expr()),
                              stride_err_msg.str(), Evaluate::make(0));
         check = IfThenElse::make(Not::make(is_null), check, Stmt());
-        init_nest_.emplace_back(Block::make(check, Evaluate::make(0)));
+        asserts_.emplace_back(Block::make(check, Evaluate::make(0)));
       }
     } else if (buffer->buffer_type == kAutoBroadcast) {
       Type stype = buffer->DefaultIndexType();
diff --git a/tests/python/unittest/test_build_lower.py b/tests/python/unittest/test_build_lower.py
index 8600fc2b46288..082b85f35b42a 100644
--- a/tests/python/unittest/test_build_lower.py
+++ b/tests/python/unittest/test_build_lower.py
@@ -32,5 +32,13 @@ def test_lower_rfactor():
     s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
     fapi = tvm.lower(s, [A, B])
 
+def test_dependent_output_shape():
+    n, m, x = tvm.var('n'), tvm.var('m'), tvm.var('x')
+    A = tvm.placeholder((n, m))
+    B = tvm.compute((m, n/x), lambda i, j: A[i, j], name='B')
+    s = tvm.create_schedule(B.op)
+    mod = tvm.build(s, [A, B, x])
+
 if __name__ == "__main__":
     test_lower_rfactor()
+    test_dependent_output_shape()
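
For context, a minimal sketch (not part of the patch) of the scenario the new test exercises: building and then calling a function whose output shape depends on an extra scalar argument `x` that is only bound at call time. The concrete shapes (4, 3), the value x=2, and the `llvm` target are illustrative assumptions, not taken from the test itself.

```python
# Minimal sketch, assuming the same TVM 0.x API used in the new test.
# Shapes, the value x=2, and the llvm target are illustrative choices.
import numpy as np
import tvm

n, m, x = tvm.var('n'), tvm.var('m'), tvm.var('x')
A = tvm.placeholder((n, m), name='A')
# B's shape depends on x, which cannot be deduced from any tensor argument
# and is therefore passed explicitly as the last parameter.
B = tvm.compute((m, n/x), lambda i, j: A[i, j], name='B')
s = tvm.create_schedule(B.op)
f = tvm.build(s, [A, B, x], target='llvm')

# n and m are bound from A's shape; x is supplied as a plain integer,
# so B's expected shape evaluates to (m, n/x) = (3, 4/2) = (3, 2).
a = tvm.nd.array(np.random.rand(4, 3).astype('float32'))
b = tvm.nd.array(np.zeros((3, 2), dtype='float32'))
f(a, b, 2)
```

Emitting the strides compactness check into `asserts_` (after all argument bindings) rather than `init_nest_` is what allows the check above to reference `x`, which is bound later in the argument list than the buffers whose shapes depend on it.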