diff --git a/nnvm/src/top/nn/nn.cc b/nnvm/src/top/nn/nn.cc index ba89e5ceba589..322d77b6d032f 100644 --- a/nnvm/src/top/nn/nn.cc +++ b/nnvm/src/top/nn/nn.cc @@ -345,7 +345,7 @@ NNVM_REGISTER_OP(softmax) .set_num_outputs(1) .set_attr("FInferShape", ElemwiseShape<1, 1>) .set_attr("FInferType", ElemwiseType<1, 1>) -.set_attr("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>) +.set_attr("FCorrectLayout", ElemwiseFixedLayoutCopyToOut<1, 1>) .set_support_level(1) .set_attr( "FTVMCompute", [](const NodeAttrs& attrs, @@ -404,7 +404,7 @@ NNVM_REGISTER_OP(log_softmax) .set_num_outputs(1) .set_attr("FInferShape", ElemwiseShape<1, 1>) .set_attr("FInferType", ElemwiseType<1, 1>) -.set_attr("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>) +.set_attr("FCorrectLayout", ElemwiseFixedLayoutCopyToOut<1, 1>) .set_attr( "FTVMCompute", [](const NodeAttrs& attrs, const Array& inputs, diff --git a/nnvm/tests/python/unittest/test_correct_layout.py b/nnvm/tests/python/unittest/test_correct_layout.py index c428a2f837ac6..6176586284a7e 100644 --- a/nnvm/tests/python/unittest/test_correct_layout.py +++ b/nnvm/tests/python/unittest/test_correct_layout.py @@ -3,7 +3,6 @@ import nnvm.graph as graph from nnvm.compiler import graph_attr -# Level 1 def correct_layout(g, layout=None): if isinstance(g, nnvm.symbol.Symbol): g = graph.create(g) @@ -19,6 +18,7 @@ def correct_layout(g, layout=None): return g, ldict +# Level 1 def test_dense(): x = sym.Variable("data", shape=(10, 20)) y = sym.dense(x, units=30, name="fc") @@ -169,6 +169,19 @@ def test_flatten(): assert(ldict["y"][0] == "__undef__") +def test_softmax(): + x = sym.Variable("x", shape=(10, 20, 10, 10)) + y = sym.softmax(x, name="y") + g, ldict = correct_layout(y, "NCHW") + assert(ldict["x"][0] == "NCHW") + assert(ldict["y"][0] == "NCHW") + # second pass will insert layout transform + _, ldict = correct_layout(g, "NCHW16c") + assert(ldict["x"][0] == "NCHW16c") + assert(ldict["x_NCHW"][0] == "NCHW") + assert(ldict["y"][0] == "NCHW") + + # Level 2 def test_conv2d(): x = sym.Variable("data", shape=(1, 32, 512, 512)) @@ -327,6 +340,7 @@ def test_reduce(): test_split() test_batchnorm() test_flatten() + test_softmax() test_conv2d() test_conv2d_transpose() test_max_pool2d()