diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 56aa1d6dcaf8..f79b63778430 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -81,7 +81,7 @@ def _dim_check(attrs): def _get_param(params, input_node): if isinstance(input_node, _expr.Constant): return np.atleast_1d(input_node.data.asnumpy()) - return params.pop(input_node.name_hint).asnumpy() + return params[input_node.name_hint].asnumpy() def _get_num_param(params, input_node): return _get_param(params, input_node).item() diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index b1efe4a8c26f..13db8fd0dce7 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -3058,6 +3058,21 @@ def test_forward_add_n(): _test_forward_add_n(in4) _test_forward_add_n(in5) +####################################################################### +# Sharing params case +# ---------------------- + + +def test_sharing_node(): + """Test the sharing params case.""" + np_data = np.random.uniform(size=(2,2,2)).astype('float32') + with tf.Graph().as_default(): + in_data = tf.placeholder(tf.float32, shape=(2, 2, 2), name='in_data') + axis = tf.constant([-1], dtype=tf.int32, name='axis') + mean0 = tf.reduce_mean(in_data, axis=axis, keepdims=False, name='mean0') + mean1 = tf.reduce_mean(in_data, axis=axis, keepdims=False, name='mean1') + out = tf.add(mean0, mean1, name='out') + compare_tf_with_tvm([np_data], ['in_data:0'], 'out:0') ####################################################################### # Unravel Index @@ -3311,3 +3326,6 @@ def test_forward_isfinite(): # Internal misc. ops test_read_variable_op() + + # Sharing params case using Mean ops + test_sharing_node()