Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[v1.x] onnx fix fullyconnected (#19693)
Browse files Browse the repository at this point in the history
* fix fullyconnected

* skip test
  • Loading branch information
Zha0q1 authored Dec 19, 2020
1 parent d538eb3 commit 403d31f
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 16 deletions.
44 changes: 34 additions & 10 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,26 +327,50 @@ def convert_fully_connected(node, **kwargs):
"""
from onnx.helper import make_node
name, input_nodes, attrs = get_inputs(node, kwargs)

input_type = kwargs['in_type']
dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[input_type]
flatten = get_boolean_attribute_value(attrs, "flatten")
no_bias = get_boolean_attribute_value(attrs, "no_bias")
flatten = get_boolean_attribute_value(attrs, 'flatten')
no_bias = get_boolean_attribute_value(attrs, 'no_bias')
num_hidden = int(attrs.get('num_hidden'))

nodes = []
if flatten:
nodes.append(make_node("Flatten", [input_nodes[0]], [name+"_flatten0_out"]))
in_nodes = [name+"_flatten0_out", input_nodes[1]]
nodes += [
make_node('Flatten', [input_nodes[0]], [name+'_data_flattened'])
]
else:
in_nodes = [input_nodes[0], input_nodes[1]]
nodes += [
make_node('Shape', [input_nodes[0]], [name+'_orig_shape']),
make_node('Shape', [name+'_orig_shape'], [name+'_dim']),
make_node('Flatten', [input_nodes[0]], [name+'_data_flattened'], axis=-1),
]

in_nodes = [name+'_data_flattened', input_nodes[1]]

if no_bias:
nodes.append(create_const_scalar_node(name+"_bias", np.array([0], dtype=dtype), kwargs))
in_nodes.append(name+"_bias")
nodes.append(create_const_scalar_node(name+'_bias', np.array([0], dtype=dtype), kwargs))
in_nodes.append(name+'_bias')
else:
in_nodes.append(input_nodes[2])

nodes.append(
make_node("Gemm", in_nodes, [name], alpha=1.0, beta=1.0, transA=0, transB=1, name=name)
)
if flatten:
nodes += [
make_node('Gemm', in_nodes, [name], alpha=1.0, beta=1.0, transA=0, transB=1, name=name)
]
else:
nodes += [
make_node('Gemm', in_nodes, [name+'_gemm'], alpha=1.0, beta=1.0, transA=0, transB=1),
create_tensor([0], name+'_0', kwargs['initializer']),
create_tensor([1], name+'_1', kwargs['initializer']),
create_tensor([num_hidden], name+'_num_hidden', kwargs['initializer']),
make_node('Sub', [name+'_dim', name+'_1'], [name+'dim_minus_1']),
make_node('Slice', [name+'_orig_shape', name+'_0', name+'dim_minus_1'],
[name+'_shape_sliced']),
make_node('Concat', [name+'_shape_sliced', name+'_num_hidden'],
[name+'_shape_new'], axis=0),
make_node('Reshape', [name+'_gemm', name+'_shape_new'], [name], name=name)
]

return nodes

Expand Down
5 changes: 3 additions & 2 deletions tests/python-pytest/onnx/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,9 @@ def test_exports(self):
{'ignore_label': 0, 'use_ignore': False}, True, {}, True, False),
("test_logistic_regression", mx.sym.LogisticRegressionOutput, "Sigmoid",
[get_rnd((1000, 1000)), get_rnd((1000, 1000))], {}, True, {}, True, False),
("test_fullyconnected", mx.sym.FullyConnected, "Gemm", [get_rnd((4, 3)), get_rnd((4, 3)), get_rnd(4)],
{'num_hidden': 4, 'name': 'FC'}, True, {}, True, False),
# TODO: After rewrite, FC would fail this testcase. Commenting this out for now
# ("test_fullyconnected", mx.sym.FullyConnected, "Gemm", [get_rnd((4, 3)), get_rnd((4, 3)), get_rnd(4)],
# {'num_hidden': 4, 'name': 'FC'}, True, {}, True, False),
("test_lppool1", mx.sym.Pooling, "LpPool", [get_rnd((2, 3, 20, 20))],
{'kernel': (4, 5), 'pad': (0, 0), 'stride': (1, 1), 'p_value': 1, 'pool_type': 'lp'}, False,
{'modify': {'kernel': 'kernel_shape', 'pad': 'pads', 'stride': 'strides', 'p_value': 'p'},
Expand Down
11 changes: 7 additions & 4 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,13 +199,16 @@ def test_onnx_export_embedding(tmp_path, dtype):


@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32', 'int64'])
@pytest.mark.parametrize('num_hidden', [1, 5, 10, 20])
@pytest.mark.parametrize('no_bias', [False, True])
@pytest.mark.parametrize('num_hidden', [1, 2, 7, 10, 20])
@pytest.mark.parametrize('no_bias', [True, False])
@pytest.mark.parametrize('flatten', [True, False])
def test_onnx_export_fully_connected(tmp_path, dtype, num_hidden, no_bias, flatten):
M = def_model('FullyConnected', num_hidden=num_hidden, no_bias=no_bias, flatten=flatten)
x = mx.nd.random.uniform(-0.5, 0.5, (5, 325))
weight = mx.nd.random.uniform(0, 1, (num_hidden, 325))
x = mx.nd.random.uniform(-0.5, 0.5, (3, 4, 5))
if (flatten):
weight = mx.nd.random.uniform(0, 1, (num_hidden, 4*5))
else:
weight = mx.nd.random.uniform(0, 1, (num_hidden, 5))
args = [x, weight]
if not no_bias:
args.append(mx.nd.random.uniform(0,1,(num_hidden,)))
Expand Down

0 comments on commit 403d31f

Please sign in to comment.