diff --git a/nnvm/src/compiler/alter_op_layout.cc b/nnvm/src/compiler/alter_op_layout.cc
index fc68423f4233..f70b4f1b625b 100644
--- a/nnvm/src/compiler/alter_op_layout.cc
+++ b/nnvm/src/compiler/alter_op_layout.cc
@@ -46,7 +46,7 @@ Graph AlterOpLayout(const Graph& src) {
   std::vector<std::vector<Layout> > in_layouts_of_node(idx_graph.num_nodes());
   std::vector<std::vector<Layout> > out_layouts_of_node(idx_graph.num_nodes());
 
-  std::unordered_map<const Node*, uint32_t> new_nodes;
+  std::unordered_map<const Node*, uint32_t> unchanged_nodes;
 
   if (src.HasAttr("layout")) {
     // record layouts so that LayoutTransform pass can fix layouts correctly,
@@ -56,6 +56,8 @@ Graph AlterOpLayout(const Graph& src) {
     const auto& layouts = src.GetAttr<std::vector<Layout> >("layout");
     for (uint32_t nid = 0; nid < idx_graph.num_nodes(); ++nid) {
       const auto &inode = idx_graph[nid];
+      // record input layouts for all nodes,
+      // while replaced nodes will ignore the records here and have undefined input layouts.
       std::vector<Layout> in_layout;
       for (const auto& e : inode.inputs) {
         in_layout.emplace_back(layouts[idx_graph.entry_id(e)]);
@@ -76,7 +78,8 @@ Graph AlterOpLayout(const Graph& src) {
     nnvm::compiler::FTVMAlterOpLayout fn_alter_op_layout =
       falter_op_layout.get(n->op(), nullptr);
     if (fn_alter_op_layout == nullptr) {
-      new_nodes[n.get()] = nid;
+      // will restore the original input layouts later.
+      unchanged_nodes[n.get()] = nid;
       return false;
     }
 
@@ -102,8 +105,13 @@ Graph AlterOpLayout(const Graph& src) {
     Symbol op;
     bool do_alter =
       fn_alter_op_layout(n->attrs, Symbol::CreateGroup(op_inputs), tensor_infos, &op);
-    if (do_alter) *ret = op.outputs;
-    else new_nodes[n.get()] = nid;
+
+    if (do_alter) {
+      *ret = op.outputs;
+    } else {
+      // will restore the original input layouts later.
+      unchanged_nodes[n.get()] = nid;
+    }
     return do_alter;
   };
 
@@ -115,15 +123,15 @@ Graph AlterOpLayout(const Graph& src) {
   std::vector<Layout> ret_layouts(ret_idx.num_node_entries(), Layout::Undef());
   for (uint32_t nid = 0; nid < ret_idx.num_nodes(); ++nid) {
     const auto& inode = ret_idx[nid];
-    if (new_nodes.count(inode.source)) {
+    if (unchanged_nodes.count(inode.source)) {
       const std::vector<Layout>& in_layouts =
-        in_layouts_of_node[new_nodes[inode.source]];
+        in_layouts_of_node[unchanged_nodes[inode.source]];
       for (uint32_t i = 0; i < inode.inputs.size(); ++i) {
         const auto& e = inode.inputs[i];
         ret_layouts[ret_idx.entry_id(e)] = in_layouts[i];
       }
       const std::vector<Layout>& out_layouts =
-        out_layouts_of_node[new_nodes[inode.source]];
+        out_layouts_of_node[unchanged_nodes[inode.source]];
       for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
         ret_layouts[ret_idx.entry_id(nid, i)] = out_layouts[i];
       }
diff --git a/nnvm/tests/python/compiler/test_alter_op_layout.py b/nnvm/tests/python/compiler/test_alter_op_layout.py
index 0fbf5ad3b479..cc3df61a28c7 100644
--- a/nnvm/tests/python/compiler/test_alter_op_layout.py
+++ b/nnvm/tests/python/compiler/test_alter_op_layout.py
@@ -45,9 +45,61 @@ def alter_conv2d_layout(attrs, inputs, tinfos):
 
     # check copy layouts
     for node in ["data", "relu", "flatten", "softmax", "conv_weight"]:
-        assert(layouts[node] == layouts_origin[node])
-    assert(layouts["conv_alter"] == layouts_origin["conv"])
+        assert layouts[node] == layouts_origin[node]
+    assert layouts["conv_alter"] == layouts_origin["conv"]
+
+
+def test_consecutive_alter_layout():
+    data = sym.Variable("data", shape=(1, 32, 512, 512))
+    pool1 = sym.global_avg_pool2d(data, name="global_avg_pool2d_1", layout="NCHW")
+    pool2 = sym.global_avg_pool2d(pool1, name="global_avg_pool2d_2", layout="NCHW")
+    relu = sym.relu(pool2, name="relu")
+
+    g = graph.create(relu)
+    g = g.apply("CorrectLayout")
+    g = graph_attr.set_dtype_inputs(g, "float32")
+    g = g.apply(["InferShape", "InferType"])
+    assert g.json_attr("layout") == ['NCHW', 'NCHW', 'NCHW', 'NCHW']
+
+    @reg.register_alter_op_layout("global_avg_pool2d", level=100)
+    def alter_global_avg_pool2d_layout(attrs, inputs, tinfos):
+        new_attrs = {k : attrs[k] for k in attrs.keys()}
+        new_attrs["layout"] = "NCHW16c"
+        return sym.global_avg_pool2d(inputs[0], **new_attrs)
+
+    g = g.apply("AlterOpLayout")
+
+    # pool1 get replaced - output layout of pool1 is not recorded
+    # pool2 get replaced - input layout of pool2 is not recorded
+    # thus the second entry must be undefined - it can neither recover from pool1's output,
+    # nor from pool2's input.
+    assert g.json_attr("layout") == ['NCHW', '__undef__', 'NCHW', 'NCHW']
+
+
+def test_alter_func_return_none():
+    data = sym.Variable("data", shape=(1, 32, 512, 512))
+    pool1 = sym.global_max_pool2d(data, name="pool1", layout="NCHW")
+    pool2 = sym.global_max_pool2d(pool1, name="pool2", layout="NCHW")
+    relu = sym.relu(pool2, name="relu")
+
+    g = graph.create(relu)
+    g = g.apply("CorrectLayout")
+    g = graph_attr.set_dtype_inputs(g, "float32")
+    g = g.apply(["InferShape", "InferType"])
+    assert g.json_attr("layout") == ['NCHW', 'NCHW', 'NCHW', 'NCHW']
+
+    @reg.register_alter_op_layout("global_max_pool2d", level=100)
+    def alter_global_max_pool2d_layout(attrs, inputs, tinfos):
+        return None
+
+    g = g.apply("AlterOpLayout")
+
+    # alter func return none, nothing get replaced,
+    # the layouts should remain the same
+    assert g.json_attr("layout") == ['NCHW', 'NCHW', 'NCHW', 'NCHW']
 
 
 if __name__ == "__main__":
     test_alter_conv2d_layout()
+    test_consecutive_alter_layout()
+    test_alter_func_return_none()