diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc index 79f2e800ee14..7aafe9d82f76 100644 --- a/src/operator/nn/convolution.cc +++ b/src/operator/nn/convolution.cc @@ -468,7 +468,7 @@ There are other options to tune the performance. else return std::vector{"data", "weight", "bias"}; }) -.set_attr("FListOutputNames", +.set_attr("FListOutputNames", [](const NodeAttrs& attrs) { return std::vector{"output"}; }) diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc index 475b63625166..278c130064fd 100644 --- a/src/operator/nn/fully_connected.cc +++ b/src/operator/nn/fully_connected.cc @@ -267,7 +267,7 @@ This could be used for model inference with `row_sparse` weights trained with `S return std::vector{"data", "weight"}; } }) -.set_attr("FListOutputNames", +.set_attr("FListOutputNames", [](const NodeAttrs& attrs) { return std::vector{"output"}; }) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 2486be04a52f..20cc4b511cc4 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -24,6 +24,7 @@ import itertools from numpy.testing import assert_allclose, assert_array_equal from mxnet.test_utils import * +from mxnet.base import py_str from common import setup_module, with_seed import unittest @@ -5399,6 +5400,30 @@ def f(x, a, b, c): check_numeric_gradient(quad_sym, [data_np], atol=0.001) +def test_op_output_names_monitor(): + def check_name(op_sym, expected_names): + output_names = [] + + def get_output_names_callback(name, arr): + output_names.append(py_str(name)) + + op_exe = op_sym.simple_bind(ctx=mx.current_context(), grad_req='null') + op_exe.set_monitor_callback(get_output_names_callback) + op_exe.forward() + for output_name, expected_name in zip(output_names, expected_names): + assert output_name == expected_name + + data = mx.sym.Variable('data', shape=(10, 3, 10, 10)) + conv_sym = mx.sym.Convolution(data, kernel=(2, 2), num_filter=1, name='conv') + check_name(conv_sym, ['conv_output']) + + fc_sym = mx.sym.FullyConnected(data, num_hidden=10, name='fc') + check_name(fc_sym, ['fc_output']) + + lrn_sym = mx.sym.LRN(data, nsize=1, name='lrn') + check_name(lrn_sym, ['lrn_output', 'lrn_tmp_norm']) + + if __name__ == '__main__': import nose nose.runmodule()