diff --git a/include/tvm/buffer.h b/include/tvm/buffer.h index ed4ac5ea6a63..1233e9b0b89b 100644 --- a/include/tvm/buffer.h +++ b/include/tvm/buffer.h @@ -36,10 +36,11 @@ namespace tvm { // Internal node container Buffer class BufferNode; -/*! \brief memory access kind */ -enum class AccessMask : int { - kRead = 1, - kWrite = 2 +/*! \brief buffer type */ +enum BufferType : int { + kDefault = 1, + // Maps buffer[i][j][k] -> buffer[i][0][k] if dimension i's shape equals 1. + kAutoBroadcast = 2, }; /*! @@ -129,6 +130,8 @@ class BufferNode : public Node { * elem_offset is guaranteed to be multiple of offset_factor. */ int offset_factor; + /*! \brief buffer type */ + BufferType buffer_type; /*! \brief constructor */ BufferNode() {} @@ -142,6 +145,7 @@ class BufferNode : public Node { v->Visit("scope", &scope); v->Visit("data_alignment", &data_alignment); v->Visit("offset_factor", &offset_factor); + v->Visit("buffer_type", &buffer_type); } /*! \return preferred index type for this buffer node */ @@ -159,7 +163,8 @@ class BufferNode : public Node { std::string name, std::string scope, int data_alignment, - int offset_factor); + int offset_factor, + BufferType buffer_type); static constexpr const char* _type_key = "Buffer"; TVM_DECLARE_NODE_TYPE_INFO(BufferNode, Node); diff --git a/python/tvm/api.py b/python/tvm/api.py index d88f06170543..e4777b6e3964 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -531,7 +531,8 @@ def decl_buffer(shape, elem_offset=None, scope="", data_alignment=-1, - offset_factor=0): + offset_factor=0, + buffer_type=""): """Declare a new symbolic buffer. Normally buffer is created automatically during lower and build. @@ -574,11 +575,39 @@ def decl_buffer(shape, If 0 is pssed, the alignment will be set to 1. if non-zero is passed, we will created a Var for elem_offset if elem_offset is not None. + buffer_type: str, optional, {"", "auto_broadcast"} + auto_broadcast buffer allows one to implement broadcast computation + without considering whether dimension size equals to one. + TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension i's shape equals 1. + Returns ------- buffer : Buffer The created buffer + Example + ------- + Here's an example of how broadcast buffer can be used to define a symbolic broadcast operation, + + .. code-block:: python + + m0, m1, m2 = tvm.var("m0"), tvm.var("m1"), tvm.var("m2") + n0, n1, n2 = tvm.var("n0"), tvm.var("n1"), tvm.var("n2") + o0, o1, o2 = tvm.var("o0"), tvm.var("o1"), tvm.var("o2") + A = tvm.placeholder((m0, m1, m2), name='A') + B = tvm.placeholder((n0, n1, n2), name='B') + C = tvm.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C') + Ab = tvm.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="broadcast") + Bb = tvm.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="broadcast") + s = tvm.create_schedule(C.op) + fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A:Ab, B:Bb}) + ctx = tvm.cpu(0) + a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=(2, 1, 3)).astype(B.dtype), ctx) + c = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), ctx) + fadd(a, b, c) + tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy()) + Note ---- Buffer data structure reflects the DLTensor structure in dlpack. @@ -602,7 +631,7 @@ def decl_buffer(shape, data = var(name, "handle") return _api_internal._Buffer( data, dtype, shape, strides, elem_offset, name, scope, - data_alignment, offset_factor) + data_alignment, offset_factor, buffer_type) def layout(layout_str): """Create a layout node from a string. diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 42d60b85e375..00ac715e8c07 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -207,7 +207,13 @@ TVM_REGISTER_API("Range") }); TVM_REGISTER_API("_Buffer") -.set_body_typed(BufferNode::make); +.set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK_EQ(args.size(), 10); + auto buffer_type = args[9].operator std::string(); + BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault; + *ret = BufferNode::make(args[0], args[1], args[2], args[3], args[4], + args[5], args[6], args[7], args[8], type); + }); TVM_REGISTER_API("_BufferAccessPtr") .set_body_method(&Buffer::access_ptr); diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index 6917200ff920..c1622338174d 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -342,7 +342,7 @@ Buffer BufferWithOffsetAlignment(Array shape, } return BufferNode::make(data, dtype, shape, Array(), elem_offset, name, "", - data_alignment, offset_factor); + data_alignment, offset_factor, kDefault); } void GetBinds(const Array& args, diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc index 3e0615162a8f..573ecffe1b08 100644 --- a/src/lang/buffer.cc +++ b/src/lang/buffer.cc @@ -49,7 +49,8 @@ Buffer decl_buffer(Array shape, Expr(), name, "", - 0, 0); + 0, 0, + kDefault); } // Split the given expression w.r.t the add operator @@ -365,7 +366,8 @@ Buffer Buffer::MakeSlice(Array begins, Array extents) const { n->name + "_slice", n->scope, n->data_alignment, - 0); + 0, + n->buffer_type); } Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr offset) const { @@ -405,7 +407,8 @@ Buffer BufferNode::make(Var data, std::string name, std::string scope, int data_alignment, - int offset_factor) { + int offset_factor, + BufferType buffer_type) { auto n = make_node(); n->data = std::move(data); n->dtype = dtype; @@ -428,6 +431,12 @@ Buffer BufferNode::make(Var data, n->elem_offset = std::move(elem_offset); n->data_alignment = data_alignment; n->offset_factor = offset_factor; + n->buffer_type = buffer_type; + if (n->buffer_type == kAutoBroadcast && n->shape.size() > 0 && n->strides.empty()) { + for (size_t i = 0; i < n->shape.size(); ++i) { + n->strides.push_back(tvm::var("stride")); + } + } return Buffer(n); } diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc index 2822393d3f75..d93d08864438 100644 --- a/src/pass/arg_binder.cc +++ b/src/pass/arg_binder.cc @@ -242,6 +242,21 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, check = IfThenElse::make(Not::make(is_null), check, Stmt()); init_nest_.emplace_back(Block::make(check, Evaluate::make(0))); } + } else if (buffer->buffer_type == kAutoBroadcast) { + Type stype = buffer->DefaultIndexType(); + Expr stride = make_const(stype, 1); + for (size_t i = buffer->shape.size(); i != 0; --i) { + size_t k = i - 1; + std::ostringstream field_name; + field_name << v_strides->name_hint << '[' << k << ']'; + Expr value = cast(buffer->shape[k].type(), + Load::make(tvm_shape_type, v_strides, + IntImm::make(Int(32), k), const_true(1))); + value = tvm::if_then_else(is_null, stride, value); + value = tvm::if_then_else(buffer->shape[k] == 1, 0, value); + Bind_(buffer->strides[k], value, field_name.str(), true); + stride = Simplify(stride * buffer->shape[k]); + } } else { std::ostringstream stride_null_err_msg; stride_null_err_msg << arg_name << ".strides: expected non-null strides."; diff --git a/src/pass/inject_copy_intrin.cc b/src/pass/inject_copy_intrin.cc index a906ee3e5474..8df5fe1f7757 100644 --- a/src/pass/inject_copy_intrin.cc +++ b/src/pass/inject_copy_intrin.cc @@ -160,7 +160,7 @@ class CopyIntrinInjector : public IRMutator { store_strides[loop_var_size], store->buffer_var->name_hint, GetStorageScope(store->buffer_var.get()), - 0, 0); + 0, 0, kDefault); Buffer src = BufferNode::make( Var(load->buffer_var.node_), load->type, @@ -169,7 +169,7 @@ class CopyIntrinInjector : public IRMutator { src_elem_offset, load->buffer_var->name_hint, GetStorageScope(load->buffer_var.get()), - 0, 0); + 0, 0, kDefault); *out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value); CHECK(out->defined()) << "flower function did not return correct stmt"; return true; diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc index 215f6d739732..ff6b41612bf4 100644 --- a/src/pass/storage_flatten.cc +++ b/src/pass/storage_flatten.cc @@ -220,7 +220,7 @@ class StorageFlattener : public IRMutator { Var(key.GetName(), Handle()), op->type, shape, strides, Expr(), key.GetName(), skey.to_string(), - align, 0); + align, 0, kDefault); buf_map_[key] = e; Stmt body = this->Mutate(op->body); diff --git a/tests/python/unittest/test_lang_buffer.py b/tests/python/unittest/test_lang_buffer.py index e0bb0279c09f..bd45eac2358a 100644 --- a/tests/python/unittest/test_lang_buffer.py +++ b/tests/python/unittest/test_lang_buffer.py @@ -16,6 +16,7 @@ # under the License. import tvm from tvm.schedule import Buffer +import numpy as np def test_buffer(): m = tvm.var('m') @@ -108,6 +109,34 @@ def assert_simplified_equal(index_simplified, index_direct): index_direct = A.vload((0, ((k0 % (k1 / s)) / n) * n + ((k0 % (k1 / n)) % n + (k0 % k1)))) assert_simplified_equal(index_simplified, index_direct) +def test_buffer_broadcast(): + m0, m1, m2 = tvm.var("m0"), tvm.var("m1"), tvm.var("m2") + n0, n1, n2 = tvm.var("n0"), tvm.var("n1"), tvm.var("n2") + o0, o1, o2 = tvm.var("o0"), tvm.var("o1"), tvm.var("o2") + + A = tvm.placeholder((m0, m1, m2), name='A') + B = tvm.placeholder((n0, n1, n2), name='B') + + C = tvm.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C') + + Ab = tvm.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast") + Bb = tvm.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast") + s = tvm.create_schedule(C.op) + + def check(): + if not tvm.module.enabled("llvm"): + return + fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A:Ab, B:Bb}) + ctx = tvm.cpu(0) + a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=(2, 1, 1)).astype(B.dtype), ctx) + c = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), ctx) + fadd(a, b, c) + tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy()) + + check() + + if __name__ == "__main__": test_buffer() test_buffer_access_ptr() @@ -115,3 +144,4 @@ def assert_simplified_equal(index_simplified, index_direct): test_buffer_access_ptr_extent() test_buffer_vload() test_buffer_index_merge_mult_mod() + test_buffer_broadcast() diff --git a/topi/include/topi/detail/extern.h b/topi/include/topi/detail/extern.h index ac00e52899fa..667722e465c4 100644 --- a/topi/include/topi/detail/extern.h +++ b/topi/include/topi/detail/extern.h @@ -49,7 +49,7 @@ inline Buffer DeclExternBuffer(Array shape, auto data = var(name, Handle()); auto elem_offset = Expr(); return BufferNode::make(data, dtype, shape, Array(), elem_offset, name, "", - -1, 0); + -1, 0, kDefault); } /*!