Skip to content

Commit

Permalink
fix CorrectLayout for softmax & log_softmax (apache#1401)
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhliu authored and sergei-mironov committed Aug 8, 2018
1 parent 9214ccb commit ec2d17c
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
4 changes: 2 additions & 2 deletions nnvm/src/top/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ NNVM_REGISTER_OP(softmax)
.set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutCopyToOut<1, 1>)
.set_support_level(1)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
Expand Down Expand Up @@ -404,7 +404,7 @@ NNVM_REGISTER_OP(log_softmax)
.set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutCopyToOut<1, 1>)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
Expand Down
16 changes: 15 additions & 1 deletion nnvm/tests/python/unittest/test_correct_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import nnvm.graph as graph
from nnvm.compiler import graph_attr

# Level 1
def correct_layout(g, layout=None):
if isinstance(g, nnvm.symbol.Symbol):
g = graph.create(g)
Expand All @@ -19,6 +18,7 @@ def correct_layout(g, layout=None):
return g, ldict


# Level 1
def test_dense():
x = sym.Variable("data", shape=(10, 20))
y = sym.dense(x, units=30, name="fc")
Expand Down Expand Up @@ -169,6 +169,19 @@ def test_flatten():
assert(ldict["y"][0] == "__undef__")


def test_softmax():
x = sym.Variable("x", shape=(10, 20, 10, 10))
y = sym.softmax(x, name="y")
g, ldict = correct_layout(y, "NCHW")
assert(ldict["x"][0] == "NCHW")
assert(ldict["y"][0] == "NCHW")
# second pass will insert layout transform
_, ldict = correct_layout(g, "NCHW16c")
assert(ldict["x"][0] == "NCHW16c")
assert(ldict["x_NCHW"][0] == "NCHW")
assert(ldict["y"][0] == "NCHW")


# Level 2
def test_conv2d():
x = sym.Variable("data", shape=(1, 32, 512, 512))
Expand Down Expand Up @@ -327,6 +340,7 @@ def test_reduce():
test_split()
test_batchnorm()
test_flatten()
test_softmax()
test_conv2d()
test_conv2d_transpose()
test_max_pool2d()
Expand Down

0 comments on commit ec2d17c

Please sign in to comment.