Skip to content

Commit

Permalink
Fix list output names typo (apache#10300)
Browse files Browse the repository at this point in the history
* Add unit test

* Add unit test
  • Loading branch information
reminisce authored and piiswrong committed Mar 28, 2018
1 parent cd6b503 commit c2a517b
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/operator/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ There are other options to tune the performance.
else
return std::vector<std::string>{"data", "weight", "bias"};
})
.set_attr<nnvm::FListInputNames>("FListOutputNames",
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"output"};
})
Expand Down
2 changes: 1 addition & 1 deletion src/operator/nn/fully_connected.cc
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ This could be used for model inference with `row_sparse` weights trained with `S
return std::vector<std::string>{"data", "weight"};
}
})
.set_attr<nnvm::FListInputNames>("FListOutputNames",
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"output"};
})
Expand Down
25 changes: 25 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import itertools
from numpy.testing import assert_allclose, assert_array_equal
from mxnet.test_utils import *
from mxnet.base import py_str
from common import setup_module, with_seed
import unittest

Expand Down Expand Up @@ -5399,6 +5400,30 @@ def f(x, a, b, c):
check_numeric_gradient(quad_sym, [data_np], atol=0.001)


def test_op_output_names_monitor():
def check_name(op_sym, expected_names):
output_names = []

def get_output_names_callback(name, arr):
output_names.append(py_str(name))

op_exe = op_sym.simple_bind(ctx=mx.current_context(), grad_req='null')
op_exe.set_monitor_callback(get_output_names_callback)
op_exe.forward()
for output_name, expected_name in zip(output_names, expected_names):
assert output_name == expected_name

data = mx.sym.Variable('data', shape=(10, 3, 10, 10))
conv_sym = mx.sym.Convolution(data, kernel=(2, 2), num_filter=1, name='conv')
check_name(conv_sym, ['conv_output'])

fc_sym = mx.sym.FullyConnected(data, num_hidden=10, name='fc')
check_name(fc_sym, ['fc_output'])

lrn_sym = mx.sym.LRN(data, nsize=1, name='lrn')
check_name(lrn_sym, ['lrn_output', 'lrn_tmp_norm'])


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit c2a517b

Please sign in to comment.