Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix list output names typo #10300

Merged
merged 2 commits into from
Mar 28, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/operator/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ There are other options to tune the performance.
else
return std::vector<std::string>{"data", "weight", "bias"};
})
.set_attr<nnvm::FListInputNames>("FListOutputNames",
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"output"};
})
Expand Down
2 changes: 1 addition & 1 deletion src/operator/nn/fully_connected.cc
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ This could be used for model inference with `row_sparse` weights trained with `S
return std::vector<std::string>{"data", "weight"};
}
})
.set_attr<nnvm::FListInputNames>("FListOutputNames",
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"output"};
})
Expand Down
25 changes: 25 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import itertools
from numpy.testing import assert_allclose, assert_array_equal
from mxnet.test_utils import *
from mxnet.base import py_str
from common import setup_module, with_seed
import unittest

Expand Down Expand Up @@ -5399,6 +5400,30 @@ def f(x, a, b, c):
check_numeric_gradient(quad_sym, [data_np], atol=0.001)


def test_op_output_names_monitor():
def check_name(op_sym, expected_names):
output_names = []

def get_output_names_callback(name, arr):
output_names.append(py_str(name))

op_exe = op_sym.simple_bind(ctx=mx.current_context(), grad_req='null')
op_exe.set_monitor_callback(get_output_names_callback)
op_exe.forward()
for output_name, expected_name in zip(output_names, expected_names):
assert output_name == expected_name

data = mx.sym.Variable('data', shape=(10, 3, 10, 10))
conv_sym = mx.sym.Convolution(data, kernel=(2, 2), num_filter=1, name='conv')
check_name(conv_sym, ['conv_output'])

fc_sym = mx.sym.FullyConnected(data, num_hidden=10, name='fc')
check_name(fc_sym, ['fc_output'])

lrn_sym = mx.sym.LRN(data, nsize=1, name='lrn')
check_name(lrn_sym, ['lrn_output', 'lrn_tmp_norm'])


if __name__ == '__main__':
import nose
nose.runmodule()