Skip to content

Commit

Permalink
fix buffer elem_offset calculation (apache#1762)
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhliu authored and tqchen committed Sep 24, 2018
1 parent 934c60b commit ee5550f
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 9 deletions.
14 changes: 5 additions & 9 deletions src/lang/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -226,16 +226,12 @@ inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
Expr base = n->elem_offset;
if (n->strides.size() == 0) {
CHECK_EQ(n->shape.size(), index.size());
if (n->shape.size() != 0) {
if (is_zero(base)) {
base = index[0];
} else {
base = base + index[0];
if (index.size() > 0) {
Expr offset = index[0];
for (size_t i = 1; i < index.size(); ++i) {
offset = MergeMulMod(offset * n->shape[i] + index[i]);
}
}
base = MergeMulMod(base);
for (size_t i = 1; i < index.size(); ++i) {
base = MergeMulMod(base * n->shape[i] + index[i]);
base = base + offset;
}
} else {
CHECK_EQ(n->strides.size(), index.size());
Expand Down
9 changes: 9 additions & 0 deletions tests/python/unittest/test_lang_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ def test_buffer_access_ptr_offset():
assert tvm.ir_pass.Equal(offset, tvm.call_extern('int32', "test_call", 200 + v))
assert aptr.args[4].value == Buffer.READ | Buffer.WRITE

def test_buffer_vload():
m = tvm.var('m')
n = tvm.var('n')
Ab = tvm.decl_buffer((m, n), tvm.float32, elem_offset=100)
load = Ab.vload([2, 3])
offset = tvm.ir_pass.Simplify(load.index)
assert tvm.ir_pass.Equal(offset, n * 2 + 103)

def test_buffer_index_merge_mult_mod():
m = tvm.var('m')
n = tvm.var('n')
Expand Down Expand Up @@ -76,4 +84,5 @@ def assert_simplified_equal(index_simplified, index_direct):
test_buffer()
test_buffer_access_ptr()
test_buffer_access_ptr_offset()
test_buffer_vload()
test_buffer_index_merge_mult_mod()

0 comments on commit ee5550f

Please sign in to comment.