Skip to content

Commit

Permalink
[Lang] Fix undef BijectiveLayout and add scalar layout support (#3105)
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhliu authored and tqchen committed Apr 29, 2019
1 parent 73f87ae commit 9d002e8
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 17 deletions.
6 changes: 4 additions & 2 deletions include/tvm/data_layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,13 @@ class Layout;
// Internal node container Buffer
class LayoutNode : public Node {
public:
/*! \brief string representation of layout */
/*! \brief string representation of layout, "" for scalar. */
std::string name;
/*! \brief specify each axis of the layout,
* in which the variable name is the name of the axis.
* The IterVar's extent indicates the size of the axis,
* it is a variable for a primal axis, but a constant for a subordinate axis.
* Empty for scalar's layout.
*/
Array<IterVar> axes;

Expand All @@ -122,6 +123,7 @@ class LayoutNode : public Node {
* For example, NCHW16c can describe a 5-D tensor of
* [batch_size, channel, height, width, channel_block].
* Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).
 * Layout for a scalar is well-defined; in that case both its name and its axes have size 0.
*/
class Layout : public NodeRef {
public:
Expand Down Expand Up @@ -175,7 +177,7 @@ class Layout : public NodeRef {
* that starts at dimension \p pos and spans \p len dimensions
* (or until the end of the layout, whichever comes first).
* \param pos The start position.
* \param len The length of the sub-layout.
 * \param len The length of the sub-layout. If 0, the layout of a scalar is returned.
* \return A newly constructed Layout object.
*/
Layout SubLayout(size_t pos, size_t len) const;
Expand Down
9 changes: 8 additions & 1 deletion src/lang/data_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,14 @@ Layout::Layout(const Array<IterVar>& axes) {
}

Layout::Layout(const std::string& name) { // NOLINT(*)
if (name.empty() || name == "__undef__") return;
if (name == "__undef__") return;

node_ = make_node<LayoutNode>();
LayoutNode *node = operator->();
node->name = name;

if (name.empty()) return; // scalar

// parse layout string
int32_t factor = 0;
for (char c : name) {
Expand Down Expand Up @@ -146,6 +148,7 @@ Layout LayoutNode::make(const std::string& layout) {

Layout Layout::SubLayout(size_t pos, size_t len) const {
if (!defined() || pos > ndim()) return Layout::Undef();
if (len == 0) return Layout(Array<IterVar>());
if (pos + len > ndim()) len = ndim() - pos;
Array<IterVar> new_layout;
const auto axes = operator->()->axes;
Expand Down Expand Up @@ -195,6 +198,10 @@ int32_t Layout::FactorOf(const LayoutAxis& axis) const {
inline bool GetStoreRule(Array<Expr>* rule,
const Layout& src_layout,
const Layout& dst_layout) {
if (!src_layout.defined() || src_layout.name().empty() ||
!dst_layout.defined() || dst_layout.name().empty()) {
return false;
}
for (size_t i = 0; i < dst_layout.ndim(); ++i) {
const auto& store_axis = dst_layout[i];
const IterVar& store_axis_impl = dst_layout->axes[i];
Expand Down
12 changes: 8 additions & 4 deletions src/relay/pass/alter_op_layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,19 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
if (old_in_shapes[defined_idx].size() >= old_in_shapes[undef_idx].size()) {
layouts.Set(undef_idx,
layouts[defined_idx].SubLayout(
old_in_shapes[defined_idx].size() - old_in_shapes[undef_idx].size(),
old_in_shapes[undef_idx].size()));
return Array<Array<Layout> > {layouts, {layouts[defined_idx]}};
old_in_shapes[defined_idx].size() - old_in_shapes[undef_idx].size(),
old_in_shapes[undef_idx].size()));
return Array<Array<Layout> >{layouts, {layouts[defined_idx]}};
} else {
// Only the tensor with the smaller number of dimensions is known,
// so the final broadcasted output cannot be inferred.
// Fail (return undefined layouts) in this case.
return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
return Array<Array<Layout> >{{Layout::Undef()}, {Layout::Undef()}};
}
} else if (layouts[0].defined() && layouts[1].defined() &&
(layouts[0].ndim() == 0 || layouts[1].ndim() == 0)) {
int scalar = layouts[0].ndim() == 0 ? 0 : 1;
return Array<Array<Layout> >{layouts, {layouts[1-scalar]}};
} else {
// try to broadcast the tensors to the larger dimension
int large_idx = layouts[0].ndim_primal() >= layouts[1].ndim_primal() ? 0 : 1;
Expand Down
20 changes: 10 additions & 10 deletions tests/python/relay/test_pass_alter_op_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def expected():
b = expected()
b = infer_type(b)

assert(alpha_equal(a, b))
assert alpha_equal(a, b), "Actual = \n" + str(a)


def test_alter_return_none():
Expand All @@ -81,7 +81,7 @@ def alter_conv2d(attrs, inputs, tinfos):

b = before()
b = infer_type(b)
assert(alpha_equal(a, b))
assert alpha_equal(a, b), "Actual = \n" + str(a)
assert(called[0])


Expand Down Expand Up @@ -147,7 +147,7 @@ def expected():
b = expected()
b = infer_type(b)

assert(alpha_equal(a, b))
assert alpha_equal(a, b), "Actual = \n" + str(a)


def test_alter_layout_dual_path():
Expand Down Expand Up @@ -213,7 +213,7 @@ def expected():
b = expected()
b = infer_type(b)

assert(alpha_equal(a, b))
assert alpha_equal(a, b), "Actual = \n" + str(a)

def test_alter_layout_resnet():
"""Test alternating the layout of a residual block
Expand Down Expand Up @@ -273,7 +273,7 @@ def expected():
b = expected()
b = infer_type(b)

assert(alpha_equal(a, b))
assert alpha_equal(a, b), "Actual = \n" + str(a)


def test_alter_layout_broadcast_op():
Expand Down Expand Up @@ -323,7 +323,7 @@ def expected():
b = expected()
b = infer_type(b)

assert(alpha_equal(a, b))
assert alpha_equal(a, b), "Actual = \n" + str(a)

def test_alter_layout_scalar():
"""Test alternating the layout of a conv2d.
Expand Down Expand Up @@ -370,7 +370,7 @@ def expected():
b = expected()
b = infer_type(b)

assert(alpha_equal(a, b))
assert alpha_equal(a, b), "Actual = \n" + str(a)

def test_alter_layout_concatenate():
""" """
Expand Down Expand Up @@ -425,7 +425,7 @@ def expected():
b = expected()
b = infer_type(b)

assert(alpha_equal(a, b))
assert alpha_equal(a, b), "Actual = \n" + str(a)


def test_alter_layout_nchw_upsamping_op():
Expand Down Expand Up @@ -469,7 +469,7 @@ def expected():
b = expected()
b = infer_type(b)

assert(alpha_equal(a, b))
assert alpha_equal(a, b), "Actual = \n" + str(a)


def test_alter_layout_strided_slice():
Expand Down Expand Up @@ -511,7 +511,7 @@ def expected():
b = expected()
b = infer_type(b)

assert(alpha_equal(a, b))
assert alpha_equal(a, b), "Actual = \n" + str(a)


if __name__ == "__main__":
Expand Down
6 changes: 6 additions & 0 deletions tests/python/unittest/test_lang_data_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ def test_layout():
def test_bilayout_convertible():
# not convertible
assert tvm.bijective_layout("NCHW", "ABCD") is None
assert tvm.bijective_layout("__undef__", "NCHW") is None
assert tvm.bijective_layout("NCHW", "__undef__") is None
assert tvm.bijective_layout("__undef__", "__undef__") is None
assert tvm.bijective_layout("", "NCHW") is None
assert tvm.bijective_layout("NCHW", "") is None
assert tvm.bijective_layout("", "") is None
# convertible
assert tvm.bijective_layout("NCHW", "NCHW16c") is not None

Expand Down

0 comments on commit 9d002e8

Please sign in to comment.