diff --git a/python/nnvm/compiler/build_module.py b/python/nnvm/compiler/build_module.py index 86fa08ec1..797f1af89 100644 --- a/python/nnvm/compiler/build_module.py +++ b/python/nnvm/compiler/build_module.py @@ -270,6 +270,20 @@ def build(graph, target=None, shape=None, dtype="float32", # Apply optimization with target: graph = optimize(graph, shape, dtype, layout) + + # Clear extra params without nodes. + arg_list = [] + graph_idx = _graph.GraphIndex(graph) + for node in graph_idx.nodes: + if node['op'] == 'null': + arg_list.append(node['name']) + + if params: + param_keys = list(params.keys()) + for key in param_keys: + if key not in arg_list: + params.pop(key) + # Precompute prune if params and cfg.pass_enabled("PrecomputePrune"): graph, params = precompute_prune(graph, params) diff --git a/python/nnvm/frontend/__init__.py b/python/nnvm/frontend/__init__.py index 00ed9e51f..80f66c0d3 100644 --- a/python/nnvm/frontend/__init__.py +++ b/python/nnvm/frontend/__init__.py @@ -5,3 +5,4 @@ from .coreml import from_coreml from .keras import from_keras from .darknet import from_darknet +from .tensorflow import from_tensorflow diff --git a/python/nnvm/frontend/tensorflow.py b/python/nnvm/frontend/tensorflow.py new file mode 100644 index 000000000..09b8074c8 --- /dev/null +++ b/python/nnvm/frontend/tensorflow.py @@ -0,0 +1,512 @@ +# pylint: disable=import-self, invalid-name, unused-argument, too-many-nested-blocks, no-else-return, line-too-long +"""TF: Tensorflow frontend.""" +from __future__ import absolute_import as _abs + +# Tensorflow imports +from tensorflow.python.framework import tensor_util +from tensorflow.python.framework import dtypes + +# Numpy support +import numpy as np + +import tvm +from .. import symbol as _sym +from .. import graph as _graph +from .. compiler import graph_util +from .common import get_nnvm_op, AttrConverter as AttrConvert + +__all__ = ['from_tensorflow'] + +class AttrCvt(object): + """A Wrapper to handle some common jobs: + """ + def __init__(self, op_name, transforms=None, + excludes=None, disables=None, ignores=None, + extras=None, custom_check=None): + self._op_name = op_name + self._transforms = transforms if transforms else {} + self._excludes = excludes if excludes else [] + self._disables = disables if disables else [] + self._ignores = ignores if ignores else [] + self._extras = extras if extras else {} + self._custom_check = custom_check + + def __call__(self, inputs, attrs, *args): + self._ignores.append('_output_shapes') + self._ignores.append('T') + self._ignores.append('use_cudnn_on_gpu') + return AttrConvert(self._op_name, self._transforms, self._excludes, + self._disables, self._ignores, self._extras, self._custom_check)(inputs, attrs, *args) + +def _get_input_shapes(attr): + return [tensor_util.TensorShapeProtoToList(shape) for shape in attr['_output_shapes']] + +def _get_pad(input1d, kernel1d, stride1d): + out1d = (input1d + stride1d - 1) // stride1d + pad = np.maximum((out1d - 1) * stride1d + kernel1d - input1d, 0) + pad = pad // 2 + return pad + +def _math_name_picker(surfix): + def _impl(attr): + return 'broadcast_' + surfix + return _impl + +def _dimension_picker(prefix, surfix=''): + def _impl(attr): + kernel = attr['kernel_shape'] + if len(kernel) == 2: + return prefix + '2d' + surfix + else: + raise NotImplementedError("Only 2d kernel supported.") + return _impl + +def _dimension_constraint(): + def _dim_check(attrs): + if len(attrs['kernel_shape']) == 2: + return True + return False + return _dim_check, "Only 2d kernel supported." + +def _infer_channels(inputs, params, transpose=False): + """A hack for getting 'channles' or 'units' since onnx don't provide + these attributes. We check the shape of weights provided to get the number. + """ + g = _graph.create(inputs) + shape_dict = {k: v.shape for k, v in params.items()} + _, out_shapes = graph_util.infer_shape(g, **shape_dict) + channels = out_shapes[0][0] if not transpose else out_shapes[0][1] + return channels + +def _elemwise(name): + def _impl(inputs, attr, *args): + assert len(inputs) == 2, "Math op take 2 inputs, {} given".format(len(inputs)) + op_name = _math_name_picker(name)(attr) + axis = int(attr.get('axis', 0)) + conv_ops = ["conv2d", "conv2d_transpose"] + if op_name == 'broadcast_add' and inputs[0].attr('op_name') in conv_ops: + # TODO: remove hard coded infershape + inputs[1] = _sym.expand_dims(inputs[1], axis=axis, num_newaxis=2) + return get_nnvm_op(op_name)(*inputs) + return _impl + +def _pooling(name): + def _impl(inputs, attr, params): + + attr['data_format'] = attr['data_format'].decode("utf-8") + + if attr['data_format'] == 'NHWC': + attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2]) + if attr['data_format'] == 'NCHW': + attr['kernel_shape'] = (attr['ksize'][2], attr['ksize'][3]) + + # Fix strides + attr['strides'] = (attr['strides'][1], attr['strides'][2]) + + # Fix padding + input_shapes = _get_input_shapes(attr) + attr['padding'] = attr['padding'].decode("utf-8") + + if attr['padding'] == 'VALID': + attr['padding'] = [0, 0] + elif attr['padding'] == 'SAME': + stride_h, stride_w = attr['strides'] + kernel_h, kernel_w = attr['kernel_shape'] + in_h = input_shapes[0][1] + in_w = input_shapes[0][2] + pad_t = _get_pad(in_h, kernel_h, stride_h) + pad_l = _get_pad(in_w, kernel_w, stride_w) + attr['padding'] = [pad_t, pad_l] + else: + raise TypeError("Unsupported padding type : {}".format(attr['padding'])) + + return AttrCvt( + op_name=_dimension_picker(name), + transforms={ + 'kernel_shape':'pool_size', + 'data_format':'layout'}, + ignores=['ksize'], + extras={'ceil_mode': False}, + custom_check=_dimension_constraint())(inputs, attr) + return _impl + +def _conv(): + def _impl(inputs, attr, params): + attr['data_format'] = attr['data_format'].decode("utf-8") + + # Extract kernel shape from params + conv_param_weights = params[inputs[1].list_output_names()[0]] + + if attr['data_format'] == 'NHWC': + attr['kernel_shape'] = (conv_param_weights.shape[0], conv_param_weights.shape[1]) + attr['channels'] = conv_param_weights.shape[3] + if 'dilations' in attr: + attr['dilations'] = (attr['dilations'][0], attr['dilations'][1]) + if attr['data_format'] == 'NCHW': + attr['kernel_shape'] = (conv_param_weights.shape[2], conv_param_weights.shape[3]) + attr['channels'] = conv_param_weights.shape[1] + if 'dilations' in attr: + attr['dilations'] = (attr['dilations'][2], attr['dilations'][3]) + + # Fix strides + attr['strides'] = (attr['strides'][1], attr['strides'][2]) + + # Fix padding + input_shapes = _get_input_shapes(attr) + attr['padding'] = attr['padding'].decode("utf-8") + + if attr['padding'] == 'VALID': + attr['padding'] = [0, 0] + elif attr['padding'] == 'SAME': + stride_h, stride_w = attr['strides'] + kernel_h, kernel_w = attr['kernel_shape'] + in_h = input_shapes[0][1] + in_w = input_shapes[0][2] + pad_t = _get_pad(in_h, kernel_h, stride_h) + pad_l = _get_pad(in_w, kernel_w, stride_w) + attr['padding'] = [pad_t, pad_l] + + else: + raise TypeError("Unsupported padding type : {}".format(attr['padding'])) + + if 'kernel_layout' not in attr: + attr['kernel_layout'] = 'HWIO' if attr['data_format'] == 'NHWC' else 'OIHW' + + return AttrCvt( + op_name=_dimension_picker('conv'), + transforms={ + 'kernel_shape': 'kernel_size', + 'data_format': 'layout', + 'dilations': ('dilation', (0, 0)), + 'group': ('groups', 1)}, + extras={'use_bias': len(inputs) == 3}, + custom_check=_dimension_constraint())(inputs, attr) + return _impl + +def _decode_image(): + def _impl(inputs, attr, params): + # Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer. + return inputs[0] + return _impl + +def _cast(): + def _impl(inputs, attr, params): + # Convert from tensorflow Dtype to str + attr['DstT'] = attr['DstT'].name + return AttrCvt(op_name='cast', transforms={'DstT': 'dtype'}, ignores=['SrcT'])(inputs, attr) + return _impl + +def _expand_dims(): + def _impl(inputs, attr, params): + dim_input = inputs.pop(1) + axis = params[dim_input.list_output_names()[0]] + params.pop(dim_input.list_output_names()[0]) + return AttrCvt(op_name="expand_dims", ignores=['Tdim'], extras={'axis': axis.asnumpy()[0]})(inputs, attr) + return _impl + +def _resize_bilinear(): + def _impl(inputs, attr, params): + # TODO: Making a copy node assuming the input image shape is 299x299 + # Change this when we have corresponding resize bilinear operation. + pop_node = inputs.pop(1) + params.pop(pop_node.list_output_names()[0]) + return AttrCvt(op_name="copy", ignores=['align_corners'])(inputs, attr) + return _impl + +def _check_numerics(): + def _impl(inputs, attr, params): + # TODO: Making a copy node assuming no need to verify + return AttrCvt(op_name="copy", ignores=['message'])(inputs, attr) + return _impl + + +def _matmul(): + def _impl(inputs, attr, params): + channels = _infer_channels(inputs[1], params, not attr['transpose_b']) + if attr['transpose_a']: + inputs[0] = _sym.transpose(inputs[0], axis(1, 0)) + if not attr['transpose_b']: + inputs[1] = _sym.transpose(inputs[1], axes=(1, 0)) + return AttrCvt(op_name="dense", + extras={'use_bias': False, 'units': channels}, + ignores=['transpose_a', 'transpose_b', 'T'])(inputs, attr) + + return _impl + +def _identity(): + def _impl(inputs, attr, params): + # TODO: Tensorflow takes CheckNumerics as + # second argument which we could ignore for time being. + if len(inputs) == 2: + pop_node = inputs.pop(1) + params.pop(pop_node.list_output_names()[0]) + return AttrCvt(op_name="copy", ignores=['T'])(inputs, attr) + return _impl + +def _concat(): + def _impl(inputs, attr, params): + pop_node = inputs.pop(0) + axis = params[pop_node.list_output_names()[0]] + params.pop(pop_node.list_output_names()[0]) + return AttrCvt( + op_name="concatenate", ignores=['N'], + extras={'axis': axis.asnumpy()[0]})(inputs, attr) + return _impl + +def _reshape(): + def _impl(inputs, attr, params): + pop_node = inputs.pop(1) + shape_arg = params[pop_node.list_output_names()[0]] + params.pop(pop_node.list_output_names()[0]) + return AttrCvt( + op_name="reshape", + extras={'shape':tuple(shape_arg.asnumpy())}, + ignores=['Tshape'])(inputs, attr) + return _impl + +def _bias_add(): + def _impl(inputs, attr, params): + return _sym.broadcast_add(inputs[0], inputs[1]) + return _impl + +def _batch_norm(): + def _impl(inputs, attr, params): + # Rearrange inputs from + # (data, moving_mean, moving_variance, beta, gamma) to + # to + # (data, gamma, beta, moving_mean, moving_var) + new_inputs = [inputs[0], inputs[4], inputs[3], inputs[1], inputs[2]] + + return AttrCvt( + op_name='batch_norm', + transforms={'scale_after_normalization':'scale', 'variance_epsilon':'epsilon'}, + extras={'axis': 3}, # Fix axis + disables=['momentum'])(new_inputs, attr) + return _impl + +# compatible operators that do NOT require any conversion. +_identity_list = [] + +# _convert_map defines maps of name to converter functor(callable) +# for 1 to 1 mapping, use Renamer if nothing but name is different +# use AttrCvt if attributes need to be converted +# for 1 to N mapping(composed), use custom callable functions +# for N to 1 mapping, currently not supported(?) +_convert_map = { + 'AvgPool' : _pooling('avg_pool'), + 'BatchNormWithGlobalNormalization' : _batch_norm(), + 'BiasAdd' : _bias_add(), + 'Cast' : _cast(), + 'CheckNumerics' : _check_numerics(), # TODO + 'Concat' : _concat(), + 'Conv2D' : _conv(), + 'DecodeJpeg' : _decode_image(), + 'ExpandDims' : _expand_dims(), + 'Identity' : _identity(), + 'MatMul' : _matmul(), + 'MaxPool' : _pooling('max_pool'), + 'Mul' : _elemwise('mul'), + 'Relu' : AttrCvt('relu'), + 'Reshape' : _reshape(), + 'ResizeBilinear' : _resize_bilinear(), # TODO + 'Softmax' : AttrCvt('softmax', {'axis': ('axis', 1)}), + 'Sub' : _elemwise('sub'), +} + + +class GraphProto(object): + """ TODO: A helper class for handling nnvm graph copying from Tensorflow GraphDef. + Definition: + https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/graph.proto + """ + def __init__(self): + self._nodes = {} + self._params = {} + self._renames = {} + self._replacements = {} + self._output_shapes = {} + self._num_input = 0 + self._num_param = 0 + self._input_node = '' + + def from_tensorflow(self, graph): + """Construct nnvm nodes from tensor flow graph definition - GraphDef. + TODO: Detailed explanation of TF GraphDef parsing. + + Parameters + ---------- + graph : tensorflow graph definition object + The loaded tensorflow GraphDef + + Returns + ------- + sym : nnvm.sym.Symbol + The returned nnvm symbol + params : dict + A dict of name: tvm.nd.array pairs, used as pretrained weights + """ + # Parse throught all nodes and start extracting + # params aka Const nodes + # input nodes : First const node + # normal nodes : other normal nodes + + for node in graph.node: + # Tensor flow doesn't have seperate list for params extraction. + # Operator name 'Const' is treated as a parameter to build NNVM params dict. + if node.op == "Const": + # TODO: Assuming first Const node as Graph Input node + if self._input_node == '': + self._input_node = node.name + self._num_input += 1 + self._nodes[node.name] = _sym.Variable(name=node.name) + else: + # Rest all nodes are Param nodes, lets parse + self._num_param += 1 + for key, value in node.attr.items(): + if key == 'value': + np_array = tensor_util.MakeNdarray(value.tensor) + array_ndim = len(np_array.shape) + if array_ndim == 0: + new_array = np.empty([1], dtype=np_array.dtype) + new_array[0] = np_array + self._params[node.name] = tvm.nd.array(new_array) + else: + self._params[node.name] = tvm.nd.array(np_array) + self._nodes[node.name] = _sym.Variable(name=node.name, shape=self._params[node.name].shape) + else: + if key != 'dtype' and key != '_output_shapes': + raise NotImplementedError("Other attributes for a Const(param) Node {} ? .".format(key)) + if node.name not in self._nodes: + raise NotImplementedError("Some thing Wrong : Const {} couldn't be converted to Param.".format(node.name)) + else: + attr = self._parse_attr(node.attr) + self._output_shapes[node.name] = [tensor_util.TensorShapeProtoToList(shape) for shape in attr['_output_shapes']] + + try: + inputs = [self._nodes[i] for i in node.input] + except KeyError: + # TODO: Need to find clean way to handle dropout of optional nodes like '^CheckNumerics' + print ("Some Exception while inputs list:", node.input, " ignoring...") + + inputs = self._fix_extranodes(node.op, attr, inputs) + + op = self._convert_operator(node.op, inputs, attr) + # TODO: Assuming only one output. + self._nodes[node.name] = op + node_output = op + # TODO: Assume the final node is the output node + out = node_output + return out, self._params + + def _get_attr(self, buf): + """Returns the value of the attr of this buf with the given `name`. + + Args: + buf: attrvalue protobuf. + + Returns: + The value of the attr, as a Python object. + + Raises: + ValueError: If this op does not have an attr with the given `name`. + """ + fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"] + + x = buf + + # Treat an empty oneof value as an empty list. + if not x.WhichOneof("value"): + return [] + if x.HasField("list"): + for f in fields: + if getattr(x.list, f): + if f == "type": + return [dtypes.as_dtype(x) for x in list(getattr(x.list, f))] + else: + return list(getattr(x.list, f)) + return [] + else: + for f in fields: + if x.HasField(f): + if f == "type": + return dtypes.as_dtype(getattr(x, f)) + else: + return getattr(x, f) + assert False, "Unsupported field type in " + str(x) + return [] + + def _parse_attr(self, attr_proto): + """Convert a list of AttributeProto to a dict, with names as keys.""" + attrs = {} + for key, value in attr_proto.items(): + attrs[key] = self._get_attr(value) + + return attrs + + def _convert_operator(self, op_name, inputs, attrs, identity_list=None, convert_map=None): + """Convert from Tensorflow operator to nnvm operator. + The converter must specify conversions explicity for incompatible name, and + apply handlers to operator attributes. + + Parameters + ---------- + op_name : str + Operator name, such as Conv2D, AvgPool + inputs : list of nnvm.Symbol + List of input symbols. + attrs : dict + Dict of operator attributes + identity_list : list + List of operators that don't require conversion + convert_map : dict + Dict of name : callable, where name is the op's name that + require conversion to nnvm, callable are functions which + take attrs and return (new_op_name, new_attrs) + + Returns + ------- + sym : nnvm.Symbol + Converted nnvm Symbol + """ + identity_list = identity_list if identity_list else _identity_list + convert_map = convert_map if convert_map else _convert_map + if op_name in identity_list: + sym = get_nnvm_op(op_name)(*inputs, **attrs) + elif op_name in convert_map: + sym = convert_map[op_name](inputs, attrs, self._params) + else: + raise NotImplementedError("Operator {} not implemented.".format(op_name)) + return sym + + def _fix_extranodes(self, op_name, attr, inputs): + if op_name == "Softmax": + # TODO: require some times flatten of data before it goes to softmax + # Need to relook into this with latest softmax axis support. + op = AttrCvt(op_name='flatten')(inputs, {}) + node_output = op.list_output_names() + for k, i in zip(list(node_output), range(len(node_output))): + self._nodes[k] = op[i] + inputs = [op] + + return inputs + +def from_tensorflow(graph): + """ TODO: Load tensorflow graph which is a python tensorflow graph object into nnvm graph. + The companion parameters will be handled automatically. + + Parameters + ---------- + graph : GraphDef object + Tensorflow GraphDef + + Returns + ------- + sym : nnvm.Symbol + Compatible nnvm symbol + + params : dict of str to tvm.ndarray + Dict of converted parameters stored in tvm.ndarray format + """ + g = GraphProto() + sym, params = g.from_tensorflow(graph) + return sym, params diff --git a/python/nnvm/top/tensor.py b/python/nnvm/top/tensor.py index 1e8688f9f..462a0ec83 100644 --- a/python/nnvm/top/tensor.py +++ b/python/nnvm/top/tensor.py @@ -52,6 +52,15 @@ def _compute(attrs, x, _): reg.register_pattern("copy", OpPattern.ELEMWISE) reg.register_schedule("copy", _fschedule_broadcast) +# cast +@reg.register_compute("cast") +def compute_cast(attrs, inputs, _): + """Compute definition of cast""" + dtype = attrs.get_string("dtype") + return topi.cast(inputs[0], dtype) +reg.register_pattern("cast", OpPattern.ELEMWISE) +reg.register_schedule("cast", _fschedule_broadcast) + # exp reg.register_pattern("exp", OpPattern.ELEMWISE) reg.register_schedule("exp", _fschedule_broadcast) diff --git a/tutorials/from_tensorflow.py b/tutorials/from_tensorflow.py new file mode 100644 index 000000000..0f2d2b76c --- /dev/null +++ b/tutorials/from_tensorflow.py @@ -0,0 +1,206 @@ +""" +Compile Tensorflow Models +==================== +This article is an introductory tutorial to deploy tensorflow models with NNVM. + +For us to begin with, tensorflow module is required to be installed. + +A quick solution is to install tensorlfow from + +https://www.tensorflow.org/install/install_sources +""" + +import nnvm +import tvm +import numpy as np +import os.path +import re + +# Tensorflow imports +import tensorflow as tf +from tensorflow.core.framework import graph_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_util + + +repo_base = 'https://github.com/srkreddy1238/dmlc_data/raw/master/models/tensorflow/InceptionV1/' +img_name = 'elephant-299.jpg' +image_url = os.path.join(repo_base, img_name) +model_name = 'classify_image_graph_def-with_shapes.pb' +model_url = os.path.join(repo_base, model_name) +map_proto = 'imagenet_2012_challenge_label_map_proto.pbtxt' +map_proto_url = os.path.join(repo_base, map_proto) +lable_map = 'imagenet_synset_to_human_label_map.txt' +lable_map_url = os.path.join(repo_base, lable_map) + + +###################################################################### +# Some helper functions + +def _ProcessGraphDefParam(graph_def): + """Type-checks and possibly canonicalizes `graph_def`.""" + if not isinstance(graph_def, graph_pb2.GraphDef): + # `graph_def` could be a dynamically-created message, so try a duck-typed + # approach + try: + old_graph_def = graph_def + graph_def = graph_pb2.GraphDef() + graph_def.MergeFrom(old_graph_def) + except TypeError: + raise TypeError('graph_def must be a GraphDef proto.') + return graph_def + +class NodeLookup(object): + """Converts integer node ID's to human readable labels.""" + + def __init__(self, + label_lookup_path=None, + uid_lookup_path=None): + if not label_lookup_path: + label_lookup_path = os.path.join( + "./", map_proto) + if not uid_lookup_path: + uid_lookup_path = os.path.join( + "./", lable_map) + self.node_lookup = self.load(label_lookup_path, uid_lookup_path) + + def load(self, label_lookup_path, uid_lookup_path): + """Loads a human readable English name for each softmax node. + + Args: + label_lookup_path: string UID to integer node ID. + uid_lookup_path: string UID to human-readable string. + + Returns: + dict from integer node ID to human-readable string. + """ + if not tf.gfile.Exists(uid_lookup_path): + tf.logging.fatal('File does not exist %s', uid_lookup_path) + if not tf.gfile.Exists(label_lookup_path): + tf.logging.fatal('File does not exist %s', label_lookup_path) + + # Loads mapping from string UID to human-readable string + proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines() + uid_to_human = {} + p = re.compile(r'[n\d]*[ \S,]*') + for line in proto_as_ascii_lines: + parsed_items = p.findall(line) + uid = parsed_items[0] + human_string = parsed_items[2] + uid_to_human[uid] = human_string + + # Loads mapping from string UID to integer node ID. + node_id_to_uid = {} + proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines() + for line in proto_as_ascii: + if line.startswith(' target_class:'): + target_class = int(line.split(': ')[1]) + if line.startswith(' target_class_string:'): + target_class_string = line.split(': ')[1] + node_id_to_uid[target_class] = target_class_string[1:-2] + + # Loads the final mapping of integer node ID to human-readable string + node_id_to_name = {} + for key, val in node_id_to_uid.items(): + if val not in uid_to_human: + tf.logging.fatal('Failed to locate: %s', val) + name = uid_to_human[val] + node_id_to_name[key] = name + + return node_id_to_name + + def id_to_string(self, node_id): + if node_id not in self.node_lookup: + return '' + return self.node_lookup[node_id] + +###################################################################### +# Download processed tensorflow model +# --------------------------------------------- +# In this section, we download a pretrained Tensorflow model and classify an image. +from mxnet.gluon.utils import download + +download(image_url, img_name) +download(model_url, model_name) +download(map_proto_url, map_proto) +download(lable_map_url, lable_map) + + +###################################################################### +# Creates graph from saved graph_def.pb. +with tf.gfile.FastGFile(os.path.join( + "./", model_name), 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + graph = tf.import_graph_def(graph_def, name='') + graph_def = _ProcessGraphDefParam(graph_def) + + +###################################################################### +# Decode image +from PIL import Image +image = Image.open(img_name).resize((299, 299)) + +def transform_image(image): + image = np.array(image) + return image + +x = transform_image(image) +print('x', x.shape) + +###################################################################### +# Import the graph to NNVM +# ----------------- +sym, params = nnvm.frontend.from_tensorflow(graph_def) + +###################################################################### +# Now compile the graph through NNVM +import nnvm.compiler +target = 'llvm' +shape_dict = {'DecodeJpeg/contents': x.shape} +dtype_dict = {'DecodeJpeg/contents': 'uint8'} +graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, dtype=dtype_dict, params=params) + + + +###################################################################### +# Save the compilation output. +""" +lib.export_library("imagenet_tensorflow.so") +with open("imagenet_tensorflow.json", "w") as fo: + fo.write(graph.json()) +with open("imagenet_tensorflow.params", "wb") as fo: + fo.write(nnvm.compiler.save_param_dict(params)) +""" + +###################################################################### +# Execute the portable graph on TVM +# --------------------------------- +# Now, we would like to reproduce the same forward computation using TVM. +from tvm.contrib import graph_runtime +ctx = tvm.cpu(0) +dtype = 'uint8' +m = graph_runtime.create(graph, lib, ctx) +# set inputs +m.set_input('DecodeJpeg/contents', tvm.nd.array(x.astype(dtype))) +m.set_input(**params) +# execute +m.run() +# get outputs +tvm_output = m.get_output(0, tvm.nd.empty(((1, 1008)), 'float32')) + + +###################################################################### +# Process the output to human readable +# ------------------------------------ +predictions = tvm_output.asnumpy() +predictions = np.squeeze(predictions) + +# Creates node ID --> English string lookup. +node_lookup = NodeLookup() + +top_k = predictions.argsort()[-10:][::-1] +for node_id in top_k: + human_string = node_lookup.id_to_string(node_id) + score = predictions[node_id] + print('%s (score = %.5f)' % (human_string, score))