Skip to content

Commit

Permalink
[Bugfix] Recover original layout when alter_layout function return No…
Browse files Browse the repository at this point in the history
…ne (apache#2101)
  • Loading branch information
yzhliu authored and Wei Chen committed Feb 20, 2019
1 parent 410969f commit 8d65463
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 9 deletions.
22 changes: 15 additions & 7 deletions nnvm/src/compiler/alter_op_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ Graph AlterOpLayout(const Graph& src) {

std::vector<std::vector<Layout> > in_layouts_of_node(idx_graph.num_nodes());
std::vector<std::vector<Layout> > out_layouts_of_node(idx_graph.num_nodes());
std::unordered_map<const Node*, uint32_t> new_nodes;
std::unordered_map<const Node*, uint32_t> unchanged_nodes;

if (src.HasAttr("layout")) {
// record layouts so that LayoutTransform pass can fix layouts correctly,
Expand All @@ -56,6 +56,8 @@ Graph AlterOpLayout(const Graph& src) {
const auto& layouts = src.GetAttr<std::vector<Layout> >("layout");
for (uint32_t nid = 0; nid < idx_graph.num_nodes(); ++nid) {
const auto &inode = idx_graph[nid];
// record input layouts for all nodes,
// while replaced nodes will ignore the records here and have undefined input layouts.
std::vector<Layout> in_layout;
for (const auto& e : inode.inputs) {
in_layout.emplace_back(layouts[idx_graph.entry_id(e)]);
Expand All @@ -76,7 +78,8 @@ Graph AlterOpLayout(const Graph& src) {
nnvm::compiler::FTVMAlterOpLayout fn_alter_op_layout =
falter_op_layout.get(n->op(), nullptr);
if (fn_alter_op_layout == nullptr) {
new_nodes[n.get()] = nid;
// will restore the original input layouts later.
unchanged_nodes[n.get()] = nid;
return false;
}

Expand All @@ -102,8 +105,13 @@ Graph AlterOpLayout(const Graph& src) {
Symbol op;
bool do_alter =
fn_alter_op_layout(n->attrs, Symbol::CreateGroup(op_inputs), tensor_infos, &op);
if (do_alter) *ret = op.outputs;
else new_nodes[n.get()] = nid;

if (do_alter) {
*ret = op.outputs;
} else {
// will restore the original input layouts later.
unchanged_nodes[n.get()] = nid;
}
return do_alter;
};

Expand All @@ -115,15 +123,15 @@ Graph AlterOpLayout(const Graph& src) {
std::vector<Layout> ret_layouts(ret_idx.num_node_entries(), Layout::Undef());
for (uint32_t nid = 0; nid < ret_idx.num_nodes(); ++nid) {
const auto& inode = ret_idx[nid];
if (new_nodes.count(inode.source)) {
if (unchanged_nodes.count(inode.source)) {
const std::vector<Layout>& in_layouts =
in_layouts_of_node[new_nodes[inode.source]];
in_layouts_of_node[unchanged_nodes[inode.source]];
for (uint32_t i = 0; i < inode.inputs.size(); ++i) {
const auto& e = inode.inputs[i];
ret_layouts[ret_idx.entry_id(e)] = in_layouts[i];
}
const std::vector<Layout>& out_layouts =
out_layouts_of_node[new_nodes[inode.source]];
out_layouts_of_node[unchanged_nodes[inode.source]];
for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
ret_layouts[ret_idx.entry_id(nid, i)] = out_layouts[i];
}
Expand Down
56 changes: 54 additions & 2 deletions nnvm/tests/python/compiler/test_alter_op_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,61 @@ def alter_conv2d_layout(attrs, inputs, tinfos):

# check copy layouts
for node in ["data", "relu", "flatten", "softmax", "conv_weight"]:
assert(layouts[node] == layouts_origin[node])
assert(layouts["conv_alter"] == layouts_origin["conv"])
assert layouts[node] == layouts_origin[node]
assert layouts["conv_alter"] == layouts_origin["conv"]


def test_consecutive_alter_layout():
data = sym.Variable("data", shape=(1, 32, 512, 512))
pool1 = sym.global_avg_pool2d(data, name="global_avg_pool2d_1", layout="NCHW")
pool2 = sym.global_avg_pool2d(pool1, name="global_avg_pool2d_2", layout="NCHW")
relu = sym.relu(pool2, name="relu")

g = graph.create(relu)
g = g.apply("CorrectLayout")
g = graph_attr.set_dtype_inputs(g, "float32")
g = g.apply(["InferShape", "InferType"])
assert g.json_attr("layout") == ['NCHW', 'NCHW', 'NCHW', 'NCHW']

@reg.register_alter_op_layout("global_avg_pool2d", level=100)
def alter_global_avg_pool2d_layout(attrs, inputs, tinfos):
new_attrs = {k : attrs[k] for k in attrs.keys()}
new_attrs["layout"] = "NCHW16c"
return sym.global_avg_pool2d(inputs[0], **new_attrs)

g = g.apply("AlterOpLayout")

# pool1 get replaced - output layout of pool1 is not recorded
# pool2 get replaced - input layout of pool2 is not recorded
# thus the second entry must be undefined - it can neither recover from pool1's output,
# nor from pool2's input.
assert g.json_attr("layout") == ['NCHW', '__undef__', 'NCHW', 'NCHW']


def test_alter_func_return_none():
data = sym.Variable("data", shape=(1, 32, 512, 512))
pool1 = sym.global_max_pool2d(data, name="pool1", layout="NCHW")
pool2 = sym.global_max_pool2d(pool1, name="pool2", layout="NCHW")
relu = sym.relu(pool2, name="relu")

g = graph.create(relu)
g = g.apply("CorrectLayout")
g = graph_attr.set_dtype_inputs(g, "float32")
g = g.apply(["InferShape", "InferType"])
assert g.json_attr("layout") == ['NCHW', 'NCHW', 'NCHW', 'NCHW']

@reg.register_alter_op_layout("global_max_pool2d", level=100)
def alter_global_max_pool2d_layout(attrs, inputs, tinfos):
return None

g = g.apply("AlterOpLayout")

# alter func return none, nothing get replaced,
# the layouts should remain the same
assert g.json_attr("layout") == ['NCHW', 'NCHW', 'NCHW', 'NCHW']


if __name__ == "__main__":
test_alter_conv2d_layout()
test_consecutive_alter_layout()
test_alter_func_return_none()

0 comments on commit 8d65463

Please sign in to comment.