Skip to content

Commit

Permalink
[COREML]multiple output support, reshape, split ops added (apache#6296)
Browse files Browse the repository at this point in the history
* [COREML]multiple output support, reshape, split ops added

* Review comments addressed
  • Loading branch information
siju-samuel authored and Trevor Morris committed Sep 2, 2020
1 parent c3ddeca commit 53f9da7
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 10 deletions.
40 changes: 30 additions & 10 deletions python/tvm/relay/frontend/coreml.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,14 @@ def _ReduceLayerParams(op, inexpr, etab):
raise tvm.error.OpAttributeUnImplemented(msg.format(mode))


def _ReshapeLayerParams(op, inexpr, etab):
return _op.reshape(inexpr, op.targetShape)


def _SplitLayerParams(op, inexpr, etab):
return _op.split(inexpr, op.nOutputs, axis=-3)


_convert_map = {
'NeuralNetworkMeanImage': _NeuralNetworkMeanImage,
'NeuralNetworkImageScaler': _NeuralNetworkImageScaler,
Expand All @@ -435,6 +443,8 @@ def _ReduceLayerParams(op, inexpr, etab):
'MinLayerParams': _MinLayerParams,
'UnaryFunctionLayerParams': _UnaryFunctionLayerParams,
'ReduceLayerParams': _ReduceLayerParams,
'ReshapeLayerParams': _ReshapeLayerParams,
'SplitLayerParams': _SplitLayerParams,
}

# SAME padding: https://www.tensorflow.org/api_guides/python/nn
Expand Down Expand Up @@ -464,7 +474,7 @@ def get_pad_value(data, kernel, stride):
return pad_before, pad_after


def coreml_op_to_relay(op, inname, outname, etab):
def coreml_op_to_relay(op, inname, outnames, etab):
"""Convert coreml layer to a Relay expression and update the expression table.
Parameters
Expand All @@ -474,7 +484,7 @@ def coreml_op_to_relay(op, inname, outname, etab):
inname : str or list of str
Name of the input Relay expression.
outname : str
outnames : str or list of str
Name of the output Relay expression.
etab : relay.frontend.common.ExprTable
Expand All @@ -488,9 +498,17 @@ def coreml_op_to_relay(op, inname, outname, etab):
insym = etab.get_expr(inname)
else:
insym = [etab.get_expr(i) for i in inname]
ret = _convert_map[classname](op, insym, etab)
if outname:
etab.set_expr(outname, ret, force_override=True)
outs = _convert_map[classname](op, insym, etab)

if outnames:
if isinstance(outnames, _base.string_types) or len(outnames) == 1:
outname = outnames if isinstance(outnames, _base.string_types) else outnames[0]
etab.set_expr(outname, outs, force_override=True)
else:
# the number of ouputs from model op and tvm relay must be same
assert len(outnames) == len(outs)
for outname, out in zip(outnames, outs):
etab.set_expr(outname, out, force_override=True)


def from_coreml(model, shape=None):
Expand Down Expand Up @@ -550,16 +568,18 @@ def from_coreml(model, shape=None):
for l in cc.layers:
layertype = l.WhichOneof('layer')
layerop = getattr(l, layertype)
assert len(l.output) == 1
if len(l.input) == 1:
coreml_op_to_relay(layerop, l.input[0], l.output[0], etab)
coreml_op_to_relay(layerop, l.input[0], l.output, etab)
else:
coreml_op_to_relay(layerop, list(l.input), l.output[0], etab)
coreml_op_to_relay(layerop, list(l.input), l.output, etab)

outexpr = [etab.get_expr(o.name) if o.name in etab.exprs else _expr.var(o.name)
for o in spec.description.output]
# for now return first output
outexpr = outexpr[0]

# check there are multiple outputs in the model and all are there in etab
multi_out = all([bool(o.name in etab.exprs) for o in spec.description.output])
outexpr = _expr.Tuple(outexpr) if multi_out else outexpr[0]

func = _function.Function(analysis.free_vars(outexpr), outexpr)
params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()}
return IRModule.from_expr(func), params
64 changes: 64 additions & 0 deletions tests/python/frontend/coreml/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,68 @@ def _verify_reduce(input_dim, mode, axis, ref_func, dtype='float32'):
_verify_reduce(dshape, "argmax", axis, np.argmax, dtype='int32')


def verify_reshape(input_dim, target_shape, mode):
dtype = 'float32'

a_np = np.random.uniform(-100.0, 100.0, size=input_dim).astype(dtype)
ref_val = np.reshape(a_np, target_shape)

inputs = [('input', datatypes.Array(*input_dim))]
output = [('output', datatypes.Array(*ref_val.shape))]
builder = NeuralNetworkBuilder(inputs, output)
builder.add_reshape(name="reshape",
input_name='input',
output_name='output',
target_shape=target_shape,
mode=mode)

model = cm.models.MLModel(builder.spec)
for target, ctx in ctx_list():
out = run_tvm_graph(model, target, ctx, [a_np],
['input'], ref_val.shape, dtype)
tvm.testing.assert_allclose(out, ref_val, rtol=1e-5)


def test_forward_reshape():
for mode in [0, 1]:
verify_reshape((20,), (1, 2, 2, 5), mode)
verify_reshape((1, 3, 20, 20), (1, 12, 10, 10), mode)


def verify_split(input_dim, nOutputs):
dtype = 'float32'

a_np = np.random.uniform(-100.0, 100.0, size=input_dim).astype(dtype)
ref_val = np.split(a_np, nOutputs, axis=-3)

inputs = [('input', datatypes.Array(*input_dim))]

output_names = []
outputs = []
output_shapes = []
for i, out in enumerate(ref_val):
output_name = "output" + str(i)
output_names = output_names + [output_name]
outputs = outputs + [(output_name, datatypes.Array(*out.shape))]
output_shapes = output_shapes + [out.shape]

builder = NeuralNetworkBuilder(inputs, outputs)
builder.add_split(name="split",
input_name='input',
output_names=output_names)

model = cm.models.MLModel(builder.spec)
for target, ctx in ctx_list():
out = run_tvm_graph(model, target, ctx, [a_np],
['input'], output_shapes, [dtype] * len(output_shapes))
tvm.testing.assert_allclose(out, ref_val, rtol=1e-5)


def test_forward_split():
verify_split((1, 4, 4, 4,), 2)
verify_split((1, 3, 30, 20,), 3)


def verify_image_scaler(input_dim, blue_bias=0.0, green_bias=0.0, red_bias=0.0, image_scale=1.0):
dtype = 'float32'
a_np = np.random.uniform(size=input_dim).astype(dtype)
Expand Down Expand Up @@ -664,6 +726,8 @@ def test_forward_convolution():
test_forward_min()
test_forward_unary()
test_forward_reduce()
test_forward_reshape()
test_forward_split()
test_mobilenet_checkonly()
test_resnet50_checkonly()
test_forward_image_scaler()
Expand Down

0 comments on commit 53f9da7

Please sign in to comment.