[Codegen] Support broadcast op with symbolic shape #3389

Merged · 10 commits · Jul 2, 2019
9 changes: 8 additions & 1 deletion include/tvm/buffer.h
@@ -129,6 +129,11 @@ class BufferNode : public Node {
* elem_offset is guaranteed to be a multiple of offset_factor.
*/
int offset_factor;
/*!
* \brief buffer type, one of {"", "broadcast"};
* at runtime a broadcast buffer sets stride = 0 for each dimension whose size is 1
*/
std::string buffer_type;
/*! \brief constructor */
BufferNode() {}

@@ -142,6 +147,7 @@
v->Visit("scope", &scope);
v->Visit("data_alignment", &data_alignment);
v->Visit("offset_factor", &offset_factor);
v->Visit("buffer_type", &buffer_type);
}

/*! \return preferred index type for this buffer node */
@@ -159,7 +165,8 @@
std::string name,
std::string scope,
int data_alignment,
int offset_factor);
int offset_factor,
std::string buffer_type);

static constexpr const char* _type_key = "Buffer";
TVM_DECLARE_NODE_TYPE_INFO(BufferNode, Node);
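For intuition: binding stride = 0 to a size-1 dimension is the same trick NumPy uses for broadcasting, so the runtime rule this field enables can be sketched in plain NumPy (an illustrative analogy, not part of this patch):

    import numpy as np
    from numpy.lib.stride_tricks import as_strided

    # A size-1 dimension read with stride 0 repeats the same element,
    # which is exactly the broadcast behavior the arg binder emulates.
    b = np.arange(6, dtype="float32").reshape(2, 1, 3)
    strides = tuple(0 if dim == 1 else s for dim, s in zip(b.shape, b.strides))
    b_bcast = as_strided(b, shape=(2, 4, 3), strides=strides)
    assert np.allclose(b_bcast, np.broadcast_to(b, (2, 4, 3)))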
34 changes: 32 additions & 2 deletions python/tvm/api.py
@@ -530,7 +530,8 @@ def decl_buffer(shape,
elem_offset=None,
scope="",
data_alignment=-1,
offset_factor=0):
offset_factor=0,
buffer_type=""):
"""Declare a new symbolic buffer.

Normally the buffer is created automatically during lower and build.
@@ -573,11 +574,40 @@
If 0 is passed, the alignment will be set to 1.
If non-zero is passed, we will create a Var for elem_offset if elem_offset is not None.

buffer_type: str, optional, {"", "broadcast"}
A broadcast buffer allows one to implement a broadcast computation
without considering whether a dimension's size equals one.
TVM will insert `strides[i] = shape[i] == 1 ? 0 : strides[i]` during arg binding.
See src/pass/arg_binder.cc for reference.
Review comment (Member):
Let us avoid referring to C++ code on the Python side; it would be great if the Python doc could stand by itself.

Reply (Member Author):
addressed
Returns
-------
buffer : Buffer
The created buffer

Example
-------
Here is an example of how a broadcast buffer can be used to define a symbolic broadcast operation:

.. code-block:: python

m0, m1, m2 = tvm.var("m0"), tvm.var("m1"), tvm.var("m2")
n0, n1, n2 = tvm.var("n0"), tvm.var("n1"), tvm.var("n2")
o0, o1, o2 = tvm.var("o0"), tvm.var("o1"), tvm.var("o2")
A = tvm.placeholder((m0, m1, m2), name='A')
B = tvm.placeholder((n0, n1, n2), name='B')
C = tvm.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C')
Ab = tvm.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="broadcast")
Bb = tvm.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="broadcast")
s = tvm.create_schedule(C.op)
fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A:Ab, B:Bb})
ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(2, 1, 3)).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), ctx)
fadd(a, b, c)
tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())

Note
----
Buffer data structure reflects the DLTensor structure in dlpack.
@@ -601,7 +631,7 @@ def decl_buffer(shape,
data = var(name, "handle")
return _api_internal._Buffer(
data, dtype, shape, strides, elem_offset, name, scope,
data_alignment, offset_factor)
data_alignment, offset_factor, buffer_type)

def layout(layout_str):
"""Create a layout node from a string.
2 changes: 1 addition & 1 deletion src/codegen/build_module.cc
@@ -334,7 +334,7 @@ Buffer BufferWithOffsetAlignment(Array<Expr> shape,
}

return BufferNode::make(data, dtype, shape, Array<Expr>(), elem_offset, name, "",
data_alignment, offset_factor);
data_alignment, offset_factor, "");
}

void GetBinds(const Array<Tensor>& args,
16 changes: 13 additions & 3 deletions src/lang/buffer.cc
@@ -48,7 +48,8 @@ Buffer decl_buffer(Array<Expr> shape,
Expr(),
name,
"",
0, 0);
0, 0,
"");
}

// Split the given expression w.r.t the add operator
@@ -364,7 +365,8 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
n->name + "_slice",
n->scope,
n->data_alignment,
0);
0,
n->buffer_type);
}

Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr offset) const {
@@ -404,7 +406,8 @@ Buffer BufferNode::make(Var data,
std::string name,
std::string scope,
int data_alignment,
int offset_factor) {
int offset_factor,
std::string buffer_type) {
auto n = make_node<BufferNode>();
n->data = std::move(data);
n->dtype = dtype;
@@ -427,6 +430,13 @@
n->elem_offset = std::move(elem_offset);
n->data_alignment = data_alignment;
n->offset_factor = offset_factor;
n->buffer_type = std::move(buffer_type);
if (n->buffer_type == "broadcast" && n->shape.size() > 0 && n->strides.empty()) {
for (size_t i = 0; i < n->shape.size() - 1; ++i) {
n->strides.push_back(tvm::var("stride"));
}
n->strides.push_back(make_const(n->shape[0].type(), 1));
}
return Buffer(n);
}

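When a "broadcast" buffer is declared without explicit strides, the constructor above fills in a fresh symbolic stride variable for every dimension except the innermost, whose stride is pinned to the constant 1. A rough Python rendering of that loop (a hypothetical helper, assuming the tvm.var/tvm.const API used elsewhere in this patch):

    import tvm

    def broadcast_default_strides(shape):
        # One fresh symbolic stride var per outer dimension...
        strides = [tvm.var("stride") for _ in range(len(shape) - 1)]
        # ...and a constant unit stride for the innermost dimension,
        # so the arg binder can later rebind each var (or zero it out).
        strides.append(tvm.const(1, "int32"))
        return strides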
16 changes: 16 additions & 0 deletions src/pass/arg_binder.cc
@@ -242,6 +242,22 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
check = IfThenElse::make(Not::make(is_null), check, Stmt());
init_nest_.emplace_back(Block::make(check, Evaluate::make(0)));
}
} else if (buffer->buffer_type == "broadcast") {
Buffer stride_view = buffer.MakeStrideView();
Type stype = buffer->DefaultIndexType();
Expr stride = make_const(stype, 1);
for (size_t i = stride_view->strides.size(); i != 0; --i) {
size_t k = i - 1;
std::ostringstream field_name;
field_name << v_strides->name_hint << '[' << k << ']';
Expr value = cast(buffer->shape[k].type(),
Load::make(tvm_shape_type, v_strides,
IntImm::make(Int(32), k), const_true(1)));
value = tvm::if_then_else(is_null, stride, value);
value = tvm::if_then_else(buffer->shape[k] == 1, 0, value);
Bind_(buffer->strides[k], value, field_name.str(), true);
stride = stride * buffer->shape[k];
}
} else {
std::ostringstream stride_null_err_msg;
stride_null_err_msg << arg_name << ".strides: expected non-null strides.";
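The new branch walks the dimensions from innermost to outermost: it falls back to a compact row-major stride when the incoming DLTensor carries null strides, and forces the stride of any size-1 dimension to 0. The effective rule as a minimal Python sketch, with plain integers standing in for the generated expressions (a hypothetical helper, not the C++ API):

    def bind_broadcast_strides(shape, dl_strides):
        # dl_strides is None when the DLTensor's strides field is null.
        bound = [0] * len(shape)
        compact = 1  # running stride of a densely packed layout
        for k in reversed(range(len(shape))):
            value = compact if dl_strides is None else dl_strides[k]
            bound[k] = 0 if shape[k] == 1 else value
            compact *= shape[k]
        return bound

    # A (2, 1, 3) tensor binds strides [3, 0, 1]: the middle dimension broadcasts.
    assert bind_broadcast_strides((2, 1, 3), None) == [3, 0, 1]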
4 changes: 2 additions & 2 deletions src/pass/inject_copy_intrin.cc
@@ -160,7 +160,7 @@ class CopyIntrinInjector : public IRMutator {
store_strides[loop_var_size],
store->buffer_var->name_hint,
GetStorageScope(store->buffer_var.get()),
0, 0);
0, 0, "");
Buffer src = BufferNode::make(
Var(load->buffer_var.node_),
load->type,
@@ -169,7 +169,7 @@
src_elem_offset,
load->buffer_var->name_hint,
GetStorageScope(load->buffer_var.get()),
0, 0);
0, 0, "");
*out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value);
CHECK(out->defined()) << "flower function did not return correct stmt";
return true;
2 changes: 1 addition & 1 deletion src/pass/storage_flatten.cc
@@ -220,7 +220,7 @@ class StorageFlattener : public IRMutator {
Var(key.GetName(), Handle()),
op->type, shape, strides, Expr(),
key.GetName(), skey.to_string(),
align, 0);
align, 0, "");

buf_map_[key] = e;
Stmt body = this->Mutate(op->body);
30 changes: 30 additions & 0 deletions tests/python/unittest/test_lang_buffer.py
Expand Up @@ -16,6 +16,7 @@
# under the License.
import tvm
from tvm.schedule import Buffer
import numpy as np

def test_buffer():
m = tvm.var('m')
@@ -108,10 +109,39 @@ def assert_simplified_equal(index_simplified, index_direct):
index_direct = A.vload((0, ((k0 % (k1 / s)) / n) * n + ((k0 % (k1 / n)) % n + (k0 % k1))))
assert_simplified_equal(index_simplified, index_direct)

def test_buffer_broadcast():
m0, m1, m2 = tvm.var("m0"), tvm.var("m1"), tvm.var("m2")
n0, n1, n2 = tvm.var("n0"), tvm.var("n1"), tvm.var("n2")
o0, o1, o2 = tvm.var("o0"), tvm.var("o1"), tvm.var("o2")

A = tvm.placeholder((m0, m1, m2), name='A')
B = tvm.placeholder((n0, n1, n2), name='B')

C = tvm.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C')

Ab = tvm.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="broadcast")
Bb = tvm.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="broadcast")
s = tvm.create_schedule(C.op)

def check():
if not tvm.module.enabled("llvm"):
return
fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A:Ab, B:Bb})
ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(2, 1, 3)).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), ctx)
fadd(a, b, c)
tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())

check()


if __name__ == "__main__":
test_buffer()
test_buffer_access_ptr()
test_buffer_access_ptr_offset()
test_buffer_access_ptr_extent()
test_buffer_vload()
test_buffer_index_merge_mult_mod()
test_buffer_broadcast()
2 changes: 1 addition & 1 deletion topi/include/topi/detail/extern.h
@@ -49,7 +49,7 @@ inline Buffer DeclExternBuffer(Array<Expr> shape,
auto data = var(name, Handle());
auto elem_offset = Expr();
return BufferNode::make(data, dtype, shape, Array<Expr>(), elem_offset, name, "",
-1, 0);
-1, 0, "");
}

/*!
Expand Down