diff --git a/python/tvm/relay/frontend/caffe2.py b/python/tvm/relay/frontend/caffe2.py index f4fcd9237ed3..8a5803fddb3e 100644 --- a/python/tvm/relay/frontend/caffe2.py +++ b/python/tvm/relay/frontend/caffe2.py @@ -172,6 +172,12 @@ class Add(Elemwise): name = 'add' +class Mul(Elemwise): + """ Operator converter for Mul. + """ + name = 'multiply' + + class Pool(Caffe2OpConverter): """ A helper class for pool op converters. """ @@ -233,6 +239,33 @@ def _impl(cls, inputs, args, params): return out +class ConvTranspose(Caffe2OpConverter): + """ Operator converter for ConvTranspose. + """ + + @classmethod + def _impl(cls, inputs, args, params): + # get number of channels + channels = infer_channels(inputs[1], True) + args['channels'] = channels + _clean_up_pool_args(args) + out = AttrCvt( + op_name=dimension_picker('conv', '_transpose'), + transforms={ + 'kernel_shape': 'kernel_size', + 'pads': ('padding', (0, 0), revert_caffe2_pad), + 'dilations': ('dilation', (1, 1)), + 'order': ('data_layout', ("NCHW"), lambda x: x if isinstance(x, str) else x.decode('UTF-8')), + }, + excludes=[], + ignores=_caffe2_internal_args, + custom_check=dimension_constraint())(inputs[:2], args, params) + use_bias = len(inputs) == 3 + if use_bias: + out = _op.nn.bias_add(out, inputs[2]) + return out + + class Concat(Caffe2OpConverter): """ Operator converter for Concat. """ @@ -353,12 +386,14 @@ def _get_convert_map(): # caffe2 common operators 'Add': Add.get_converter(), 'Sum': Sum.get_converter(), + 'Mul': Mul.get_converter(), 'Softmax': Softmax.get_converter(), # nn 'AveragePool': AveragePool.get_converter(), 'MaxPool': MaxPool.get_converter(), 'Conv': Conv.get_converter(), + 'ConvTranspose': ConvTranspose.get_converter(), 'Concat': Concat.get_converter(), 'FC': FC.get_converter(), 'SpatialBN': SpatialBN.get_converter(),