Fix the flat element offset of compact (stride-less) buffers that have a
non-zero elem_offset.

Previously ElemOffset seeded its accumulator with elem_offset + index[0], so
every following `base * shape[i] + index[i]` step also multiplied elem_offset
by shape[i]. The offset is now folded from the indices alone and elem_offset
is added exactly once afterwards. test_buffer_vload pins this down: an (m, n)
buffer with elem_offset=100 loaded at [2, 3] must simplify to n * 2 + 103.

NOTE(review): leading indentation of the hunk lines below was reconstructed
from block nesting (2 spaces for C++, 4 for Python) -- confirm against the
target tree before applying.

diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc
index cb3194f8eb1d..69967c55a7ff 100644
--- a/src/lang/buffer.cc
+++ b/src/lang/buffer.cc
@@ -226,16 +226,12 @@ inline Expr ElemOffset(const BufferNode* n, Array index) {
   Expr base = n->elem_offset;
   if (n->strides.size() == 0) {
     CHECK_EQ(n->shape.size(), index.size());
-    if (n->shape.size() != 0) {
-      if (is_zero(base)) {
-        base = index[0];
-      } else {
-        base = base + index[0];
+    if (index.size() > 0) {
+      Expr offset = index[0];
+      for (size_t i = 1; i < index.size(); ++i) {
+        offset = MergeMulMod(offset * n->shape[i] + index[i]);
       }
-    }
-    base = MergeMulMod(base);
-    for (size_t i = 1; i < index.size(); ++i) {
-      base = MergeMulMod(base * n->shape[i] + index[i]);
+      base = base + offset;
     }
   } else {
     CHECK_EQ(n->strides.size(), index.size());
diff --git a/tests/python/unittest/test_lang_buffer.py b/tests/python/unittest/test_lang_buffer.py
index a5a8f5d065a6..51f1e3abb7e9 100644
--- a/tests/python/unittest/test_lang_buffer.py
+++ b/tests/python/unittest/test_lang_buffer.py
@@ -41,6 +41,14 @@ def test_buffer_access_ptr_offset():
     assert tvm.ir_pass.Equal(offset, tvm.call_extern('int32', "test_call", 200 + v))
     assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
 
+def test_buffer_vload():
+    m = tvm.var('m')
+    n = tvm.var('n')
+    Ab = tvm.decl_buffer((m, n), tvm.float32, elem_offset=100)
+    load = Ab.vload([2, 3])
+    offset = tvm.ir_pass.Simplify(load.index)
+    assert tvm.ir_pass.Equal(offset, n * 2 + 103)
+
 def test_buffer_index_merge_mult_mod():
     m = tvm.var('m')
     n = tvm.var('n')
@@ -76,4 +84,5 @@ def assert_simplified_equal(index_simplified, index_direct):
     test_buffer()
     test_buffer_access_ptr()
     test_buffer_access_ptr_offset()
+    test_buffer_vload()
     test_buffer_index_merge_mult_mod()