Skip to content

Commit

Permalink
Fix ArgBinder assert order (#3794)
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 authored Aug 17, 2019
1 parent b76b627 commit f9d8d06
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/pass/arg_binder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
AssertStmt::make(arith::ComputeReduce<ir::And>(conds, Expr()),
stride_err_msg.str(), Evaluate::make(0));
check = IfThenElse::make(Not::make(is_null), check, Stmt());
init_nest_.emplace_back(Block::make(check, Evaluate::make(0)));
asserts_.emplace_back(Block::make(check, Evaluate::make(0)));
}
} else if (buffer->buffer_type == kAutoBroadcast) {
Type stype = buffer->DefaultIndexType();
Expand Down
8 changes: 8 additions & 0 deletions tests/python/unittest/test_build_lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,13 @@ def test_lower_rfactor():
s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
fapi = tvm.lower(s, [A, B])

def test_dependent_output_shape():
n, m, x = tvm.var('n'), tvm.var('m'), tvm.var('x')
A = tvm.placeholder((n, m))
B = tvm.compute((m, n/x), lambda i, j: A[i,j] , name='B')
s = tvm.create_schedule(B.op)
mod = tvm.build(s, [A, B, x])

if __name__ == "__main__":
test_lower_rfactor()
test_dependent_output_shape()

0 comments on commit f9d8d06

Please sign in to comment.