From 012de1ca9e18cb35a1f267541d921070bc23e9ac Mon Sep 17 00:00:00 2001 From: Hiroyuki Makino Date: Sat, 2 Feb 2019 15:31:20 +0900 Subject: [PATCH] [Relay][Frontend] Caffe2 Support (#2507) * [Relay][Frontend] Add Caffe2 Support * [Relay][Frontend] Add Caffe2 Support (fix unsed import) * [Relay][Frontend] Add Caffe2 Support (fix caffe2 model import) * [Relay][Frontend] Add Caffe2 Support (fix model install and reflect code reviews) * [Relay][Frontend] Add Caffe2 Support (fix caffe2 model import) * [Relay][Frontend] Add Caffe2 Support (fix caffe2 model import) * [Relay][Frontend] Add Caffe2 Support (fix caffe2 model import) * [Relay][Frontend] Add Caffe2 Support (fix caffe2 frontend import) * [Relay][Frontend] Add Caffe2 Support (rename function name in test_forward) * [Relay][Frontend] Add Caffe2 Support (fix caffe2 model import) * [Relay][Frontend] Add Caffe2 Support (fix caffe2 model import) * [Doc] Caffe2 frontend tutorial * [Doc] Caffe2 frontend tutorial * [Doc] Caffe2 frontend tutorial * [Relay][Frontend] Add Caffe2 Support (remove unsed file) --- docker/Dockerfile.ci_gpu | 3 + docker/install/ubuntu_install_caffe2.sh | 3 + python/tvm/relay/frontend/__init__.py | 1 + python/tvm/relay/frontend/caffe2.py | 565 ++++++++++++++++++ .../frontend/caffe2/model_zoo/__init__.py | 29 + .../frontend/caffe2/model_zoo/squeezenet.py | 132 ++++ tests/python/frontend/caffe2/test_forward.py | 87 +++ tests/python/frontend/caffe2/test_graph.py | 21 + tests/scripts/task_python_frontend.sh | 4 + tutorials/frontend/from_caffe2.py | 130 ++++ 10 files changed, 975 insertions(+) create mode 100644 docker/install/ubuntu_install_caffe2.sh create mode 100755 python/tvm/relay/frontend/caffe2.py create mode 100644 tests/python/frontend/caffe2/model_zoo/__init__.py create mode 100644 tests/python/frontend/caffe2/model_zoo/squeezenet.py create mode 100644 tests/python/frontend/caffe2/test_forward.py create mode 100755 tests/python/frontend/caffe2/test_graph.py create mode 100644 
tutorials/frontend/from_caffe2.py diff --git a/docker/Dockerfile.ci_gpu b/docker/Dockerfile.ci_gpu index fa15113289d0..6a599b1e3917 100644 --- a/docker/Dockerfile.ci_gpu +++ b/docker/Dockerfile.ci_gpu @@ -67,6 +67,9 @@ RUN bash /install/ubuntu_install_onnx.sh COPY install/ubuntu_install_tflite.sh /install/ubuntu_install_tflite.sh RUN bash /install/ubuntu_install_tflite.sh +COPY install/ubuntu_install_caffe2.sh /install/ubuntu_install_caffe2.sh +RUN bash /install/ubuntu_install_caffe2.sh + RUN pip3 install Pillow COPY install/ubuntu_install_vulkan.sh /install/ubuntu_install_vulkan.sh diff --git a/docker/install/ubuntu_install_caffe2.sh b/docker/install/ubuntu_install_caffe2.sh new file mode 100644 index 000000000000..5fe827927e87 --- /dev/null +++ b/docker/install/ubuntu_install_caffe2.sh @@ -0,0 +1,3 @@ +python3 -m caffe2.python.models.download -i -f squeezenet +python3 -m caffe2.python.models.download -i -f resnet50 +python3 -m caffe2.python.models.download -i -f vgg19 diff --git a/python/tvm/relay/frontend/__init__.py b/python/tvm/relay/frontend/__init__.py index e8917bcdb598..d582e02e5cc7 100644 --- a/python/tvm/relay/frontend/__init__.py +++ b/python/tvm/relay/frontend/__init__.py @@ -12,3 +12,4 @@ from .onnx import from_onnx from .tflite import from_tflite from .coreml import from_coreml +from .caffe2 import from_caffe2 diff --git a/python/tvm/relay/frontend/caffe2.py b/python/tvm/relay/frontend/caffe2.py new file mode 100755 index 000000000000..69d3c3642cfe --- /dev/null +++ b/python/tvm/relay/frontend/caffe2.py @@ -0,0 +1,565 @@ +# pylint: disable=import-self, invalid-name, line-too-long, unused-argument +"""Caffe2 frontend""" +from __future__ import absolute_import as _abs +from .. import ir_pass +from .. import expr as _expr +from .. import op as _op +from ... 
def revert_caffe2_pad(pads):
    """Reduce Caffe2 padding to the 2-element (height, width) form.

    ``_clean_up_pool_args`` in this file produces either a 2-element pad
    (from a scalar ``pad``) or a 4-element pad ordered
    ``[pad_t, pad_l, pad_b, pad_r]``.  The converters pass the result of this
    function on as the ``padding`` attribute, whose default is ``(0, 0)``.

    Parameters
    ----------
    pads : list or tuple of int
        Either ``(pad_h, pad_w)`` or ``(pad_t, pad_l, pad_b, pad_r)``.

    Returns
    -------
    list or tuple of int
        The 2-element padding, of the same sequence type as the input.

    Raises
    ------
    ValueError
        If ``pads`` has neither 2 nor 4 elements, or if a 4-element pad is
        asymmetric (top != bottom or left != right).  Such a pad cannot be
        expressed with two values; truncating it would silently change the
        output of the converted model.
    """
    if len(pads) == 2:
        return pads
    if len(pads) == 4:
        leading, trailing = pads[:2], pads[2:]
        # Compare as lists so that list/tuple/protobuf containers all work.
        if list(leading) != list(trailing):
            raise ValueError("Invalid caffe2 type padding: {}".format(pads))
        return leading
    raise ValueError("Invalid caffe2 type padding: {}".format(pads))
elif 'pad' in args: + args['pads'] = [args['pad'], args['pad']] + args.pop('pad') + + if 'dilation_h' in args and 'dilation_w' in args: + assert 'dilation' not in args and 'dilations' not in args + args['dilations'] = [args['dilation_h'], args['dilation_w']] + args.pop('dilation_h') + args.pop('dilation_w') + elif 'dilation' in args: + args['dilations'] = [args['dilation'], args['dilation']] + args.pop('dilation') + + return args + + +class Caffe2OpConverter(object): + """ A helper class for holding Caffe2 op converters. + """ + + @classmethod + def get_converter(cls): + """ Get converter. + + :return: converter, which should be `_impl`. + """ + + if hasattr(cls, '_impl'): + return getattr(cls, '_impl') + else: + raise NotImplementedError('{} not implemented'.format( + cls.__name__)) + + +_caffe2_internal_args = [ + # nnpack args + 'algo', + 'convolution_transform_strategy', + 'float16_compute', + 'shared_buffer', + + # training args + 'init_params', + 'cudnn_exhaustive_search', + 'exhaustive_search', + + # training args + 'adj', + 'hwgq', + + # args that we don't care + 'legacy_pad', +] + + +class Elemwise(Caffe2OpConverter): + """ A helper class for elemwise op converters. + """ + name = '' + @classmethod + def _math_name_picker(cls, suffix): + + def _impl(attr): + if attr.get('broadcast', 0): + return 'broadcast_' + suffix + return 'elemwise_' + suffix + + return _impl + + @classmethod + def _impl(cls, inputs, args, params): + assert len(inputs) == 2, "Math op take 2 inputs, {} given".format( + len(inputs)) + op_name = cls._math_name_picker(cls.name)(args) + axis = int(args.get('axis', 0)) + conv_ops = ["conv2d", "conv2d_transpose"] + if op_name == 'broadcast_add' and inputs[0].attr('op_name') in conv_ops: + # TODO(zhreshold): remove hard coded infershape + inputs[1] = _op.expand_dims(inputs[1], axis=axis, num_newaxis=2) + return get_relay_op(op_name)(*inputs) + + +class Add(Elemwise): + """ Operator converter for Add. 
+ """ + name = 'add' + + +class Pool(Caffe2OpConverter): + """ A helper class for pool op converters. + """ + + name = '' + @classmethod + def _impl(cls, inputs, args, params): + _clean_up_pool_args(args) + if 'global_pooling' in args and args['global_pooling'] == 1: + op_name = dimension_picker('global_' + cls.name) + return get_relay_op(op_name(args))(*inputs) + + return AttrCvt( + op_name=dimension_picker(cls.name), + transforms={ + 'kernel_shape': 'pool_size', + 'pads': ('padding', (0, 0), revert_caffe2_pad), + 'strides': 'strides', + }, + ignores=['dilations', 'order', 'legacy_pad', 'global_pooling'], + extras={'ceil_mode': False}, + custom_check=dimension_constraint())(inputs, args, params) + + +class AveragePool(Pool): + name = 'avg_pool' + + +class MaxPool(Pool): + name = 'max_pool' + + +class Conv(Caffe2OpConverter): + """ Operator converter for Conv. + """ + + @classmethod + def _impl(cls, inputs, args, params): + # get number of channels + channels = infer_channels(inputs[1]) + args['channels'] = channels + _clean_up_pool_args(args) + out = AttrCvt( + op_name=dimension_picker('conv'), + transforms={ + 'group': ('groups', 1), + 'kernel_shape': 'kernel_size', + 'pads': ('padding', (0, 0), revert_caffe2_pad), + 'strides': 'strides', + 'dilations': ('dilation', (1, 1)), + 'order': ('data_layout', ("NCHW"), lambda x: x if isinstance(x, str) else x.decode('UTF-8')), + }, + excludes=[], + ignores=[], + custom_check=dimension_constraint())(inputs[:2], args, params) + use_bias = len(inputs) == 3 + if use_bias: + out = _op.nn.bias_add(out, inputs[2]) + return out + + +class Concat(Caffe2OpConverter): + """ Operator converter for Concat. 
class Sum(Caffe2OpConverter):
    """ Operator converter for Sum.

    Folds all inputs into one expression with repeated ``_op.add`` calls,
    left to right: ``add(add(in0, in1), in2)`` and so on.
    """

    @classmethod
    def _impl(cls, inputs, args, params):
        """Return the element-wise sum of ``inputs``.

        Parameters
        ----------
        inputs : list
            Input expressions; at least one is required.  The list is not
            modified.
        args : dict
            Operator attributes (unused).
        params : dict
            Graph parameters (unused).

        Raises
        ------
        ValueError
            If ``inputs`` is empty.
        """
        if not inputs:
            raise ValueError("Sum op takes at least 1 input, 0 given")

        # Accumulate in a local instead of writing partial sums back into the
        # caller's list.
        total = inputs[0]
        for term in inputs[1:]:
            total = _op.add(total, term)
        return total
+ """ + + @classmethod + def _impl(cls, inputs, args, params): + # set default value when axis is not set in the model + if 'axis' not in args: + args['axis'] = 1 + return AttrCvt('softmax', transforms={'axis': ('axis', args['axis'])})(inputs, args, params) + + +class FC(Caffe2OpConverter): + """ Operator converter for FC. + """ + + @classmethod + def _impl(cls, inputs, args, params): + inputs[0] = _op.nn.batch_flatten(inputs[0]) + units = infer_channels(inputs[1]) + res = _op.nn.dense(inputs[0], inputs[1], units=units) + use_bias = len(inputs) == 3 + if use_bias: + res = _op.nn.bias_add(res, inputs[2]) + return res + + +class SpatialBN(Caffe2OpConverter): + """ Operator converter for SpatialBN. + """ + + @classmethod + def _impl(cls, inputs, args, params): + return AttrCvt( + op_name='batch_norm', + disables=['momentum'], + ignores=[ + 'order', 'spatial', 'is_test', 'consumed_inputs', 'num_batches' + ])(inputs, args, params) + + +# compatible operators that do NOT require any conversion. +_identity_list = [] + +# _convert_map defines maps of name to converter functor(callable) +# for 1 to 1 mapping, use Renamer if nothing but name is different +# use AttrCvt if attributes need to be converted +# for 1 to N mapping(composed), use custom callable functions +# for N to 1 mapping, currently not supported(?) 
class Caffe2NetDef(object):
    """A helper class that converts a Caffe2 (init_net, predict_net) pair to Relay.

    Definition: https://github.com/pytorch/pytorch/blob/master/caffe2/proto/caffe2.proto

    Parameters
    ----------
    shape : dict of str to tuple, or None
        Shapes of the graph inputs keyed by blob name.  ``None`` is treated as
        an empty dict; an input without an entry gets the shape ``()``.
    dtype : str or dict of str to str
        Either a single dtype used for every non-parameter input, or a dict
        keyed by blob name.  Blobs missing from the dict fall back to
        ``"float32"``.
    """

    def __init__(self, shape, dtype):
        self._nodes = {}             # blob name -> Relay expression
        self._params = {}            # blob name -> tvm.nd.array (weights)
        self._visited_nodes = set()  # blobs currently being resolved
        self._ops = {}               # output blob name -> producing caffe2 op
        # The public ``from_caffe2`` entry point defaults ``shape`` to None;
        # normalise it so the membership tests in ``from_caffe2`` work.
        self._shape = {} if shape is None else shape
        self._dtype = dtype

    def from_caffe2(self, init_net, predict_net):
        """Construct Relay expression from caffe2 graph.

        Parameters
        ----------
        init_net : protobuf object
        predict_net : protobuf object

        Returns
        -------
        func : tvm.relay.expr.Function
            Compatible relay function
        params : dict
            A dict of name: tvm.nd.array pairs, used as pretrained weights
        """
        from caffe2.python import workspace
        # Running init_net materialises the weights as workspace blobs.
        workspace.RunNetOnce(init_net)

        # Input: the first input of the first op is taken to be the data blob.
        input_name = predict_net.op[0].input[0]

        # Params: every workspace blob consumed by some op, except the input.
        self._params = {}
        used_blobs = set()
        for c2_op in predict_net.op:
            for i in c2_op.input:
                used_blobs.add(i)
        for blob in workspace.Blobs():
            if blob in used_blobs and blob != input_name:
                self._params[blob] = _nd.array(workspace.FetchBlob(blob))

        # Variables: one Relay var per external input.
        self._nodes = {}
        for blob in predict_net.external_input:
            if blob in self._params:
                self._nodes[blob] = new_var(blob,
                                            shape=self._params[blob].shape,
                                            dtype=self._params[blob].dtype)
            else:
                shape = self._shape[blob] if blob in self._shape else ()
                if isinstance(self._dtype, dict) and blob in self._dtype:
                    dtype = str(self._dtype[blob])
                elif isinstance(self._dtype, str):
                    dtype = self._dtype
                else:
                    dtype = "float32"
                self._nodes[blob] = new_var(blob, shape=shape, dtype=dtype)

        # Ops: remember which op produces each blob so that _get_node can
        # resolve inputs that have not been converted yet.
        for c2_op in predict_net.op:
            for blob in c2_op.output:
                self._ops[blob] = c2_op

        for c2_op in predict_net.op:
            self._process_op(c2_op)

        # Outputs
        out = []
        for blob in predict_net.external_output:
            out.append(self._nodes[blob])

        if len(out) > 1:
            outputs = _expr.Tuple(out)
        else:
            outputs = out[0]

        func = _expr.Function(ir_pass.free_vars(outputs), outputs)

        return func, self._params

    def _get_node(self, blob):
        """Get the Symbol of blob and detect cyclic dependency in the graph.

        Raises
        ------
        KeyError
            If ``blob`` is neither an already known node nor the output of
            any operator in the net.
        """
        if blob in self._nodes:
            return self._nodes[blob]

        assert blob not in self._visited_nodes, 'Cyclic dependency in the graph (in {})'.format(
            blob)
        self._visited_nodes.add(blob)

        if blob not in self._ops:
            raise KeyError(
                "Blob {} is neither a graph input nor produced by any operator".format(blob))

        self._process_op(self._ops[blob])
        return self._nodes[blob]

    def _process_op(self, c2_op):
        """Convert one caffe2 op and record its output expression(s) by blob name."""
        op_type = c2_op.type
        args = self._parse_arg(c2_op.arg)
        inputs = [self._get_node(i) for i in c2_op.input]
        tvm_op = self._convert_operator(op_type, inputs, args)

        if not isinstance(tvm_op, _expr.TupleWrapper):
            self._nodes[c2_op.output[0]] = tvm_op
        else:
            # Multi-output op: pair the i-th output blob with the i-th field.
            for i, k in zip(range(len(tvm_op)), c2_op.output):
                self._nodes[k] = tvm_op[i]

    def _parse_arg(self, arg):
        """Convert a list of Argument to a dict, with names as keys.

        Scalar fields (``f``, ``i``, ``s``) are stored as-is; repeated fields
        (``floats``, ``ints``, ``strings``) are stored as tuples.  Byte strings
        are not decoded here.

        Raises
        ------
        NotImplementedError
            If an argument carries a nested net (``n`` or ``nets``).
        ValueError
            If an argument has none of the supported fields set.
        """
        args = {}
        for a in arg:
            for f in ['f', 'i', 's']:
                if a.HasField(f):
                    args[a.name] = getattr(a, f)
            for f in ['floats', 'ints', 'strings']:
                if list(getattr(a, f)):
                    assert a.name not in args, "Only one type of attr is allowed"
                    args[a.name] = tuple(getattr(a, f))
            for f in ['n']:
                if a.HasField(f):
                    raise NotImplementedError(
                        "Field {} is not supported in relay.".format(f))
            for f in ['nets']:
                if list(getattr(a, f)):
                    raise NotImplementedError(
                        "Field {} is not supported in relay.".format(f))
            if a.name not in args:
                raise ValueError("Cannot parse attribute: \n{}\n.".format(a))
        return args

    def _convert_operator(self,
                          op_type,
                          inputs,
                          args,
                          identity_list=None,
                          convert_map=None):
        """Convert from Caffe2 operator to Relay operator.
        The converter must specify conversions explicitly for incompatible name, and
        apply handlers to operator attributes.

        Parameters
        ----------
        op_type : str
            Operator name, such as Convolution, FullyConnected
        inputs : list
            List of input expressions.
        args : dict
            Dict of operator attributes
        identity_list : list
            List of operators that don't require conversion
        convert_map : dict
            Dict of name : callable, where name is the op's name that
            require conversion to relay, callable are functions which
            are called as ``callable(inputs, args, params)`` and return
            the converted expression

        Returns
        -------
        func : converted Relay expression

        Raises
        ------
        NotImplementedError
            If ``op_type`` is in neither ``identity_list`` nor ``convert_map``.
        """
        identity_list = identity_list if identity_list else _identity_list
        convert_map = convert_map if convert_map else _get_convert_map()
        if op_type in identity_list:
            func = get_relay_op(op_type)(*inputs, **args)
        elif op_type in convert_map:
            # TODO: add a sanitizing step to convert all byte strings in args
            # to strings; today each converter has to handle bytes itself.
            func = convert_map[op_type](inputs, args, self._params)
        else:
            raise NotImplementedError(
                "Operator {} not implemented.".format(op_type))
        return func
+ + Parameters + ---------- + init_net : protobuf object + Caffe2 NetDef containing the weights + + predict_net : protobuf object + Caffe2 NetDef containing the graph + + shape : dict of str to tuple + The input shape to the graph + + dtype : str or dict of str to str + The input types to the graph + + Returns + ------- + sym : tvm.relay.expr.Function + Compatible relay function + + params : dict of str to tvm.ndarray + Dict of converted parameters stored in tvm.ndarray format + """ + + caffe2 = Caffe2NetDef(shape, dtype) + return caffe2.from_caffe2(init_net, predict_net) diff --git a/tests/python/frontend/caffe2/model_zoo/__init__.py b/tests/python/frontend/caffe2/model_zoo/__init__.py new file mode 100644 index 000000000000..18e74add8428 --- /dev/null +++ b/tests/python/frontend/caffe2/model_zoo/__init__.py @@ -0,0 +1,29 @@ +"""Store for caffe2 examples and common models.""" +from __future__ import absolute_import as _abs +import os +import sys +import importlib +from . import squeezenet +from caffe2.python.models.download import ModelDownloader + +models = [ + 'squeezenet', + 'resnet50', + 'vgg19', +] + +mf = ModelDownloader() + +class Model: + def __init__(self, model_name): + self.init_net, self.predict_net, self.value_info = mf.get_c2_model(model_name) + +for model in models: + try: + locals()['c2_' + model] = importlib.import_module('caffe2.python.models.' + model) + except ImportError: + locals()['c2_' + model] = Model(model) + +# squeezenet +def relay_squeezenet(): + return squeezenet.get_workload() diff --git a/tests/python/frontend/caffe2/model_zoo/squeezenet.py b/tests/python/frontend/caffe2/model_zoo/squeezenet.py new file mode 100644 index 000000000000..74ade8989d05 --- /dev/null +++ b/tests/python/frontend/caffe2/model_zoo/squeezenet.py @@ -0,0 +1,132 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=unused-argument + +""" +Symbol of SqueezeNet + +Reference: +Iandola, Forrest N., et al. +"Squeezenet: Alexnet-level accuracy with 50x fewer parameters and< 0.5 mb model size." (2016). +""" + +from tvm import relay +from tvm.relay.testing import create_workload + +# Helpers +def _make_fire(net, squeeze_channels, expand1x1_channels, expand3x3_channels, prefix=""): + net = _make_fire_conv(net, squeeze_channels, 1, 0, "%s/squeeze1x1" % prefix) + + left = _make_fire_conv(net, expand1x1_channels, 1, 0, "%s/expand1x1" % prefix) + right = _make_fire_conv(net, expand3x3_channels, 3, 1, "%s/expand3x3" % prefix) + # NOTE : Assume NCHW layout here + net = relay.concatenate((left, right), axis=1) + return net + + +def _make_fire_conv(net, channels, kernel_size, padding=0, prefix=""): + net = relay.nn.conv2d(net, relay.var("%s_weight" % prefix), + channels=channels, + kernel_size=(kernel_size, kernel_size), + padding=(padding, padding)) + net = relay.nn.bias_add(net, relay.var("%s_bias" % prefix)) + net = relay.nn.relu(net) + return net + + +# Net +def get_net(batch_size, image_shape, num_classes, dtype): + """Get symbol of SqueezeNet + + Parameters + ---------- + batch_size : int + The batch size used in the model + + image_shape : tuple + The input image shape + 
+ num_classes: int + The number of classification results + + dtype : str + The data type + + """ + data_shape = (batch_size,) + image_shape + net = relay.var("data", shape=data_shape, dtype=dtype) + net = relay.nn.conv2d(net, relay.var("conv1_weight"), + channels=64, + kernel_size=(3, 3), + strides=(2, 2), + padding=(0, 0)) + net = relay.nn.bias_add(net, relay.var("conv1_bias")) + net = relay.nn.relu(net) + net = relay.nn.max_pool2d(net, pool_size=(3, 3), strides=(2, 2)) + net = _make_fire(net, 16, 64, 64, 'fire2') + net = _make_fire(net, 16, 64, 64, "fire3") + net = relay.nn.max_pool2d(net, pool_size=(3, 3), strides=(2, 2)) + net = _make_fire(net, 32, 128, 128, "fire4") + net = _make_fire(net, 32, 128, 128, "fire5") + net = relay.nn.max_pool2d(net, pool_size=(3, 3), strides=(2, 2)) + net = _make_fire(net, 48, 192, 192, "fire6") + net = _make_fire(net, 48, 192, 192, "fire7") + net = _make_fire(net, 64, 256, 256, "fire8") + net = _make_fire(net, 64, 256, 256, "fire9") + net = relay.nn.dropout(net, rate=0.5) + net = relay.nn.conv2d(net, relay.var('conv10_weight'), channels=num_classes, kernel_size=(1, 1)) + net = relay.nn.bias_add(net, relay.var("conv10_bias")) + net = relay.nn.relu(net) + net = relay.nn.global_avg_pool2d(net) + net = relay.nn.softmax(net, axis=1) + args = relay.ir_pass.free_vars(net) + return relay.Function(args, net) + + +def get_workload(batch_size=1, + image_shape=(3, 224, 224), + num_classes=1000, + dtype="float32"): + """Get benchmark workload for SqueezeNet + + Parameters + ---------- + batch_size : int, optional + The batch size used in the model + + num_classes : int, optional + Number of classes + + image_shape : tuple, optional + The input image shape + + dtype : str, optional + The data type + + Returns + ------- + net : relay.Function + The computational graph + + params : dict of str to NDArray + The parameters. 
+ """ + + net = get_net(batch_size, image_shape, num_classes, dtype) + return create_workload(net) diff --git a/tests/python/frontend/caffe2/test_forward.py b/tests/python/frontend/caffe2/test_forward.py new file mode 100644 index 000000000000..655e9bc2bab5 --- /dev/null +++ b/tests/python/frontend/caffe2/test_forward.py @@ -0,0 +1,87 @@ +import numpy as np +import tvm +from tvm.contrib import graph_runtime +from tvm.relay.testing.config import ctx_list +from tvm import relay +from model_zoo import c2_squeezenet, c2_resnet50, c2_vgg19 +from caffe2.python import workspace + + +def get_tvm_output(model, + input_data, + target, + ctx, + output_shape, + output_dtype='float32'): + """ Generic function to execute and get tvm output""" + # supporting multiple inputs in caffe2 in a bit tricky, + # because the input names can appear at the beginning or end of model.predict_net.external_input + assert isinstance(input_data, np.ndarray) + + # here we use the first input blob to the first op to get the input name + input_names = model.predict_net.op[0].input[0] + shape_dict = {input_names: input_data.shape} + dtype_dict = {input_names: input_data.dtype} + func, params = relay.frontend.from_caffe2(model.init_net, model.predict_net, shape_dict, dtype_dict) + with relay.build_config(opt_level=3): + graph, lib, params = relay.build(func, target, params=params) + + m = graph_runtime.create(graph, lib, ctx) + + # set inputs + m.set_input(input_names, tvm.nd.array(input_data.astype(input_data.dtype))) + m.set_input(**params) + + # execute + m.run() + + # get outputs + if isinstance(output_shape, list) and isinstance(output_dtype, list): + tvm_output_list = [] + for i, s in enumerate(output_shape): + tvm_output = m.get_output(i, tvm.nd.empty((s), output_dtype[i])) + tvm_output_list.append(tvm_output.asnumpy()) + return tvm_output_list + else: + tvm_output = m.get_output(0, tvm.nd.empty((output_shape), + output_dtype)) + return tvm_output.asnumpy() + + +def get_caffe2_output(model, x, 
dtype='float32'): + workspace.RunNetOnce(model.init_net) + + input_blob = model.predict_net.op[0].input[0] + workspace.FeedBlob(input_blob, x.astype(dtype)) + workspace.RunNetOnce(model.predict_net) + + output_blob = model.predict_net.external_output[0] + c2_output = workspace.FetchBlob(output_blob) + return c2_output + + +def verify_caffe2_forward_impl(model, data_shape, out_shape): + dtype = 'float32' + data = np.random.uniform(size=data_shape).astype(dtype) + c2_out = get_caffe2_output(model, data, dtype) + for target, ctx in ctx_list(): + tvm_out = get_tvm_output(model, data, target, ctx, out_shape, dtype) + tvm.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5) + + +def test_forward_squeezenet1_1(): + verify_caffe2_forward_impl(c2_squeezenet, (1, 3, 224, 224), (1, 1000, 1, 1)) + + +def test_forward_resnet50(): + verify_caffe2_forward_impl(c2_resnet50, (1, 3, 224, 224), (1, 1000)) + + +def test_forward_vgg19(): + verify_caffe2_forward_impl(c2_vgg19, (1, 3, 224, 224), (1, 1000)) + + +if __name__ == '__main__': + test_forward_squeezenet1_1() + test_forward_resnet50() + test_forward_vgg19() diff --git a/tests/python/frontend/caffe2/test_graph.py b/tests/python/frontend/caffe2/test_graph.py new file mode 100755 index 000000000000..ebcbf5b51770 --- /dev/null +++ b/tests/python/frontend/caffe2/test_graph.py @@ -0,0 +1,21 @@ +"""Test graph equality of caffe2 models.""" +from tvm import relay +from model_zoo import c2_squeezenet, relay_squeezenet + + +def compare_graph(f1, f2): + f1 = relay.ir_pass.infer_type(f1) + f2 = relay.ir_pass.infer_type(f2) + assert relay.ir_pass.alpha_equal(f1, f2) + + +def test_squeeze_net(): + shape_dict = {'data': (1, 3, 224, 224)} + dtype_dict = {'data': 'float32'} + from_c2_func, _ = relay.frontend.from_caffe2(c2_squeezenet.init_net, c2_squeezenet.predict_net, shape_dict, dtype_dict) + relay_func, _ = relay_squeezenet() + compare_graph(from_c2_func, relay_func) + + +if __name__ == '__main__': + test_squeeze_net() diff --git 
def download(url, path, overwrite=False):
    """Download ``url`` to ``path`` unless the file is already there.

    Parameters
    ----------
    url : str
        Location to fetch.
    path : str
        Destination file name.
    overwrite : bool, optional
        When False (default) an existing file at ``path`` is kept and nothing
        is downloaded.

    Any error raised by ``urlretrieve`` (bad URL, network failure, ...)
    propagates to the caller.
    """
    import os
    if os.path.isfile(path) and not overwrite:
        print('File {} exists, skip.'.format(path))
        return
    print('Downloading from url {} to {}'.format(url, path))
    # Only the import differs between Python 2 and 3.  Keep the download
    # itself outside the try block so that a real download error is not
    # replaced by an unrelated error from the fallback path.
    try:
        from urllib.request import urlretrieve
    except ImportError:
        from urllib import urlretrieve
    urlretrieve(url, path)
+from caffe2.python.models.download import ModelDownloader +mf = ModelDownloader() + +class Model: + def __init__(self, model_name): + self.init_net, self.predict_net, self.value_info = mf.get_c2_model(model_name) + +resnet50 = Model('resnet50') + +###################################################################### +# Load a test image +# ------------------ +# A single cat dominates the examples! +from PIL import Image +from matplotlib import pyplot as plt +import numpy as np +img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true' +download(img_url, 'cat.png') +img = Image.open('cat.png').resize((224, 224)) +plt.imshow(img) +plt.show() +# input preprocess +def transform_image(image): + image = np.array(image) - np.array([123., 117., 104.]) + image /= np.array([58.395, 57.12, 57.375]) + image = image.transpose((2, 0, 1)) + image = image[np.newaxis, :].astype('float32') + return image + +data = transform_image(img) + +###################################################################### +# Compile the model on Relay +# -------------------------- + +# Caffe2 input tensor name, shape and type +input_name = resnet50.predict_net.op[0].input[0] +shape_dict = {input_name: data.shape} +dtype_dict = {input_name: data.dtype} + +# parse Caffe2 model and convert into Relay computation graph +from tvm import relay +func, params = relay.frontend.from_caffe2(resnet50.init_net, resnet50.predict_net, shape_dict, dtype_dict) + +# compile the model +# target x86 cpu +target = 'llvm' +with relay.build_config(opt_level=3): + graph, lib, params = relay.build(func, target, params=params) + +###################################################################### +# Execute on TVM +# --------------- +# The process is no different from other examples. 
+import tvm +from tvm.contrib import graph_runtime +# context x86 cpu, use tvm.gpu(0) if you run on GPU +ctx = tvm.cpu(0) +# create a runtime executor module +m = graph_runtime.create(graph, lib, ctx) +# set inputs +m.set_input(input_name, tvm.nd.array(data.astype('float32'))) +# set related params +m.set_input(**params) +# execute +m.run() +# get outputs +tvm_out = m.get_output(0) +top1_tvm = np.argmax(tvm_out.asnumpy()[0]) + +##################################################################### +# Look up synset name +# ------------------- +# Look up prediction top 1 index in 1000 class synset. +from caffe2.python import workspace +synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/', + '4d0b62f3d01426887599d4f7ede23ee5/raw/', + '596b27d23537e5a1b5751d2b0481ef172f58b539/', + 'imagenet1000_clsid_to_human.txt']) +synset_name = 'synset.txt' +download(synset_url, synset_name) +with open(synset_name) as f: + synset = eval(f.read()) +print('Relay top-1 id: {}, class name: {}'.format(top1_tvm, synset[top1_tvm])) +# confirm correctness with caffe2 output +p = workspace.Predictor(resnet50.init_net, resnet50.predict_net) +caffe2_out = p.run({input_name: data}) +top1_caffe2 = np.argmax(caffe2_out) +print('Caffe2 top-1 id: {}, class name: {}'.format(top1_caffe2, synset[top1_caffe2]))