From 35e4505497bc6fc58a60bae3294eafc30ef84175 Mon Sep 17 00:00:00 2001 From: Pariksheet Pinjari Date: Fri, 6 Apr 2018 10:23:29 +0530 Subject: [PATCH] [FRONTEND] DarkNet Yolo2 Frontend Support (#377) --- nnvm/Makefile | 2 +- nnvm/python/nnvm/frontend/__init__.py | 1 + nnvm/python/nnvm/frontend/darknet.py | 637 ++++++++++++++++++ nnvm/python/nnvm/testing/__init__.py | 2 + nnvm/python/nnvm/testing/darknet.py | 494 ++++++++++++++ nnvm/python/nnvm/testing/yolo2_detection.py | 246 +++++++ nnvm/python/nnvm/top/__init__.py | 1 + nnvm/python/nnvm/top/vision.py | 40 ++ nnvm/src/top/vision/yolo2/region.cc | 35 + nnvm/src/top/vision/yolo2/region.h | 101 +++ nnvm/src/top/vision/yolo2/reorg.cc | 52 ++ nnvm/src/top/vision/yolo2/reorg.h | 110 +++ nnvm/tests/ci_build/Dockerfile.gpu | 3 + .../install/ubuntu_install_darknet.sh | 4 + .../python/frontend/darknet/test_forward.py | 257 +++++++ nnvm/tutorials/from_darknet.py | 227 +++++++ 16 files changed, 2211 insertions(+), 1 deletion(-) create mode 100644 nnvm/python/nnvm/frontend/darknet.py create mode 100644 nnvm/python/nnvm/testing/darknet.py create mode 100644 nnvm/python/nnvm/testing/yolo2_detection.py create mode 100644 nnvm/python/nnvm/top/vision.py create mode 100644 nnvm/src/top/vision/yolo2/region.cc create mode 100644 nnvm/src/top/vision/yolo2/region.h create mode 100644 nnvm/src/top/vision/yolo2/reorg.cc create mode 100644 nnvm/src/top/vision/yolo2/reorg.h create mode 100644 nnvm/tests/ci_build/install/ubuntu_install_darknet.sh create mode 100644 nnvm/tests/python/frontend/darknet/test_forward.py create mode 100644 nnvm/tutorials/from_darknet.py diff --git a/nnvm/Makefile b/nnvm/Makefile index 4779e95b317a..62a4fadad6f0 100644 --- a/nnvm/Makefile +++ b/nnvm/Makefile @@ -56,7 +56,7 @@ endif all: lib/libnnvm.a lib/libnnvm_compiler.$(SHARED_LIBRARY_SUFFIX) SRC = $(wildcard src/*.cc src/c_api/*.cc src/core/*.cc src/pass/*.cc) -SRC_COMPILER = $(wildcard src/top/*/*.cc src/compiler/*.cc src/compiler/*/*.cc) +SRC_COMPILER = 
$(wildcard src/top/*/*.cc wildcard src/top/vision/*/*.cc src/compiler/*.cc src/compiler/*/*.cc) ALL_OBJ = $(patsubst %.cc, build/%.o, $(SRC)) TOP_OBJ = $(patsubst %.cc, build/%.o, $(SRC_COMPILER)) ALL_DEP = $(ALL_OBJ) diff --git a/nnvm/python/nnvm/frontend/__init__.py b/nnvm/python/nnvm/frontend/__init__.py index 100d4115bc3a..00ed9e51fbfc 100644 --- a/nnvm/python/nnvm/frontend/__init__.py +++ b/nnvm/python/nnvm/frontend/__init__.py @@ -4,3 +4,4 @@ from .onnx import from_onnx from .coreml import from_coreml from .keras import from_keras +from .darknet import from_darknet diff --git a/nnvm/python/nnvm/frontend/darknet.py b/nnvm/python/nnvm/frontend/darknet.py new file mode 100644 index 000000000000..413b07d648a4 --- /dev/null +++ b/nnvm/python/nnvm/frontend/darknet.py @@ -0,0 +1,637 @@ +""" +DarkNet symbol frontend. +""" + +from __future__ import absolute_import as _abs +from enum import IntEnum +import numpy as np +import tvm +from .. import symbol as _sym + +class LAYERTYPE(IntEnum): + """Darknet LAYERTYPE Class constant.""" + CONVOLUTIONAL = 0 + DECONVOLUTIONAL = 1 + CONNECTED = 2 + MAXPOOL = 3 + SOFTMAX = 4 + DETECTION = 5 + DROPOUT = 6 + CROP = 7 + ROUTE = 8 + COST = 9 + NORMALIZATION = 10 + AVGPOOL = 11 + LOCAL = 12 + SHORTCUT = 13 + ACTIVE = 14 + RNN = 15 + GRU = 16 + LSTM = 17 + CRNN = 18 + BATCHNORM = 19 + NETWORK = 20 + XNOR = 21 + REGION = 22 + REORG = 23 + BLANK = 24 + +class ACTIVATION(IntEnum): + """Darknet ACTIVATION Class constant.""" + LOGISTIC = 0 + RELU = 1 + RELIE = 2 + LINEAR = 3 + RAMP = 4 + TANH = 5 + PLSE = 6 + LEAKY = 7 + ELU = 8 + LOGGY = 9 + STAIR = 10 + HARDTAN = 11 + LHTAN = 12 + +__all__ = ['from_darknet'] + +def _darknet_get_nnvm_op(op_name): + """Get the nnvm operation from opname, raise error if not supported.""" + op = getattr(_sym, op_name) + if not op: + raise RuntimeError("Not to map op_name {} to nnvm.sym".format(op_name)) + return op + +def _darknet_required_attr(attr, key): + """Check the attribute exists and return if exists, 
if not return error.""" + assert isinstance(attr, dict) + if key not in attr: + raise AttributeError("Required attribute {} not found.".format(key)) + return attr[key] + +def _darknet_raise_not_supported(attr, op='nnvm'): + """Raise error if any operation is not supported.""" + err = "{} is not supported in {}.".format(attr, op) + raise NotImplementedError(err) + +def _darknet_warn_not_used(attr, op='nnvm'): + """Raise warning if any operation not supported.""" + import warnings + err = "{} is ignored in {}.".format(attr, op) + warnings.warn(err) + +def _darknet_parse_tshape(tshape): + """Parse tshape in string.""" + return [int(x.strip()) for x in tshape.strip('()').split(',')] + +def _darknet_parse_bool_str(attr, key, default='False'): + """Parse bool string to boolean.""" + return attr.get(key, default).strip().lower() in \ + ['true', '1', 't', 'y', 'yes'] + +def _darknet_maxpooling(inputs, attrs): + """Process the max pool 2d operation.""" + kernel = _darknet_parse_tshape(_darknet_required_attr(attrs, 'kernel')) + if len(kernel) != 1: + _darknet_raise_not_supported('non-2d kernel', 'pool_2d') + + op_name, new_attrs = 'max_pool2d', {} + strides = int(attrs.get('stride', (1, 1))) + pads = int(attrs.get('pad', (0, 0))) + new_attrs['pool_size'] = [kernel[0], kernel[0]] + new_attrs['strides'] = str((strides, strides)) + new_attrs['padding'] = str((pads, pads)) + + return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None + +def _darknet_avgpooling(inputs, attrs): + """Process the average pool 2d operation.""" + kernel = _darknet_parse_tshape(_darknet_required_attr(attrs, 'kernel')) + if len(kernel) != 1: + _darknet_raise_not_supported('non-2d kernel', 'pool_2d') + + op_name, new_attrs = 'avg_pool2d', {} + strides = int(attrs.get('stride', (1, 1))) + pads = int(attrs.get('pad', (0, 0))) + new_attrs['pool_size'] = [kernel[0], kernel[0]] + new_attrs['strides'] = str((strides, strides)) + new_attrs['padding'] = str((pads, pads)) + + return 
_darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None + +def _darknet_batch_norm(inputs, attrs): + """Process the batchnormalization operation.""" + op_name, new_attrs = 'darknet_batch_norm', {} + new_attrs['axis'] = attrs.get('axis', 1) + new_attrs['epsilon'] = attrs.get('eps', 0.000001) + new_attrs['center'] = True + new_attrs['scale'] = True + return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None + +def _darknet_conv2d(inputs, attrs): + """Process the convolution 2d operation.""" + kernel = _darknet_parse_tshape(_darknet_required_attr(attrs, 'kernel')) + if len(kernel) != 1: + _darknet_raise_not_supported('non 2d kernel', 'conv2d') + layout = attrs.get('layout', 'NCHW') + if layout not in ['NCHW', 'NHWC']: + _darknet_raise_not_supported('layout: ' + layout, 'conv2d') + strides = int(attrs.get('stride', (1, 1))) + pads = int(attrs.get('pad', (0, 0))) + + op_name, new_attrs = 'conv2d', {} + new_attrs['channels'] = _darknet_required_attr(attrs, 'num_filter') + new_attrs['kernel_size'] = [kernel[0], kernel[0]] + new_attrs['strides'] = (strides, strides) + new_attrs['padding'] = (pads, pads) + new_attrs['dilation'] = attrs.get('dilate', (1, 1)) + new_attrs['groups'] = attrs.get('num_group', 1) + new_attrs['layout'] = layout + if attrs.get('use_batchNorm', False) is True: + new_attrs['use_bias'] = False + else: + new_attrs['use_bias'] = True + out_name = {} + sym = _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs) + out_name[0] = sym.list_output_names()[0].replace('_output', '') + + if attrs.get('use_batchNorm', False) is True: + op_name, new_attrs = 'batch_norm', {} + new_attrs['epsilon'] = 0.000001 + sym = _darknet_get_nnvm_op(op_name)(*sym, **new_attrs) + out_name[1] = sym.list_output_names()[0].replace('_output', '') + if 'activation' in attrs: + new_attrs = {} + new_attrs['activation'] = attrs['activation'] + new_attrs['slope'] = 0.1 + sym, _ = _darknet_activations(sym, new_attrs) + return sym, out_name + + +def _darknet_conv2d_transpose(inputs, 
attrs): + """Process the convolution 2d transpose operation.""" + if 'target_shape' in attrs: + _darknet_raise_not_supported('target_shape', 'conv2d_transpose') + kernel = _darknet_parse_tshape(_darknet_required_attr(attrs, 'kernel')) + if len(kernel) != 2: + _darknet_raise_not_supported('non-2d kernel', 'conv2d_transpose') + layout = attrs.get('layout', 'NCHW') + if layout not in ['NCHW', 'NHWC']: + _darknet_raise_not_supported('layout: ' + layout, 'conv2d_transpose') + op_name, new_attrs = 'conv2d_transpose', {} + new_attrs['channels'] = _darknet_required_attr(attrs, 'num_filter') + new_attrs['kernel_size'] = kernel + new_attrs['strides'] = attrs.get('stride', (1, 1)) + new_attrs['output_padding'] = attrs.get('adj', (0, 0)) + new_attrs['padding'] = attrs.get('pad', (0, 0)) + new_attrs['dilation'] = attrs.get('dilate', (1, 1)) + new_attrs['groups'] = attrs.get('num_group', 1) + new_attrs['layout'] = layout + new_attrs['use_bias'] = not _darknet_parse_bool_str(attrs, 'no_bias') + return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None + +def _darknet_shortcut(inputs, attrs): + """Process the shortcut operation.""" + op_name, new_attrs = 'elemwise_add', {} + input_0 = inputs[0] + input_1 = inputs[1] + input_0_channel = int(attrs['out_channel']) + input_1_channel = int(attrs['add_out_channel']) + input_0_size = int(attrs['out_size']) + input_1_size = int(attrs['add_out_size']) + + if input_0_size > input_1_size: + scale = int(input_0_size/input_1_size) + input_1 = _sym.upsampling(input_1, scale=scale, name="_upsampling") + elif input_0_size < input_1_size: + stride = int(input_1_size/input_0_size) + input_1 = _sym.avg_pool2d(input_1, pool_size=(1, 1), + strides=(stride, stride), padding=(0, 0), name="_downsampling") + + if input_0_channel != input_1_channel: + pad_channel = input_0_channel - input_1_channel + input_1 = _sym.pad(input_1, pad_width=((0, 0), (0, pad_channel), (0, 0), (0, 0)), + pad_value=0.) 
+ + new_inputs = _as_list([input_0, input_1]) + sym = _darknet_get_nnvm_op(op_name)(*new_inputs, **new_attrs) + out_name = sym.list_output_names()[0].replace('_output', '') + if 'activation' in attrs: + new_attrs['activation'] = attrs['activation'] + sym, _ = _darknet_activations(sym, new_attrs) + return sym, out_name + +def _darknet_dense(inputs, attrs): + """Process the dense operation.""" + op_name, new_attrs = 'dense', {} + new_attrs['units'] = _darknet_required_attr(attrs, 'num_hidden') + + if attrs.get('use_bias', False) is True: + new_attrs['use_bias'] = True + if attrs.get('use_flatten', False) is True: + inputs[0] = _sym.flatten(inputs[0]) + sym = _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs) + out_name = sym.list_output_names()[0].replace('_output', '') + if 'activation' in attrs: + new_attrs = {} + new_attrs['activation'] = attrs['activation'] + sym, _ = _darknet_activations(sym, new_attrs) + return sym, out_name + +def _darknet_dropout(inputs, attrs): + """Process the dropout operation, its a blank operation.""" + op_name, new_attrs = 'dropout', {} + new_attrs['rate'] = attrs.get('p', 0.5) + return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None + +def _darknet_reshape(inputs, attrs): + """Process the reshape operation.""" + if _darknet_parse_bool_str(attrs, 'reverse'): + _darknet_raise_not_supported('reverse', 'reshape') + op_name, new_attrs = 'reshape', {} + new_attrs['shape'] = _darknet_required_attr(attrs, 'shape') + return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None + +def _darknet_softmax_output(inputs, attrs): + """Process the softmax operation.""" + op_name, new_attrs = 'softmax', {} + if _darknet_parse_bool_str(attrs, 'multi_output'): + new_attrs['axis'] = 1 + + if attrs.get('use_flatten', False) is True: + inputs[0] = _sym.flatten(inputs[0]) + return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None + +def _darknet_route(inputs, attrs): + """Process the route operation, which is equivalent to concat.""" + 
op_name = 'concatenate' + new_attrs = {'axis': attrs.get('dim', 1)} + return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None + +def _darknet_reorg(inputs, attrs): + """Process the reorg operation.""" + op_name, new_attrs = 'yolo2_reorg', {} + if 'stride' in attrs: + new_attrs = {'stride': attrs.get('stride', 1)} + return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None + +def _darknet_region(inputs, attrs): + """Process the region operation.""" + op_name, new_attrs = 'yolo2_region', {} + if 'n' in attrs: + new_attrs['n'] = attrs.get('n', 1) + if 'classes' in attrs: + new_attrs['classes'] = attrs.get('classes', 1) + if 'coords' in attrs: + new_attrs['coords'] = attrs.get('coords', 0) + if 'background' in attrs: + new_attrs['background'] = attrs.get('background', 0) + if 'softmax' in attrs: + new_attrs['softmax'] = attrs.get('softmax', 0) + return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None + +def _darknet_activations(inputs, attrs): + """Process the activation function.""" + act = _darknet_required_attr(attrs, 'activation') + if ACTIVATION.RELU == act: + act_type = 'relu' + elif ACTIVATION.TANH == act: + act_type = 'tanh' + elif ACTIVATION.LINEAR == act: + return inputs, None + elif ACTIVATION.LEAKY == act: + act_type = 'leaky_relu' + else: + _darknet_raise_not_supported('act: ' + act) + + if act_type in ['relu', 'tanh']: + op_name, new_attrs = act_type, {} + sym = _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs) + elif act_type in ['leaky_relu']: + op_name, new_attrs = act_type, {} + new_attrs['alpha'] = attrs.get('slope', 0.1) + sym = _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs) + else: + _darknet_raise_not_supported('act_type: ' + act_type) + return sym, None + +def _darknet_op_not_support(inputs, attrs): + """Raise exception if the operation is not supported.""" + err = "{} is not supported in {}.".format(attrs, inputs) + raise NotImplementedError(err) + +_DARKNET_CONVERT_MAP = { + 'CONVOLUTIONAL' : _darknet_conv2d, + 
'DECONVOLUTIONAL' : _darknet_conv2d_transpose, + 'CONNECTED' : _darknet_dense, + 'MAXPOOL' : _darknet_maxpooling, + 'SOFTMAX' : _darknet_softmax_output, + 'DROPOUT' : _darknet_dropout, + 'AVGPOOL' : _darknet_avgpooling, + 'BATCHNORM' : _darknet_batch_norm, + 'RESHAPE' : _darknet_reshape, + 'ROUTE' : _darknet_route, + 'REORG' : _darknet_reorg, + 'REGION' : _darknet_region, + 'ACTIVATION' : _darknet_activations, + 'SHORTCUT' : _darknet_shortcut, + 'DETECTION' : _darknet_op_not_support, + 'CROP' : _darknet_op_not_support, + 'COST' : _darknet_op_not_support, + 'NORMALIZATION' : _darknet_op_not_support, + 'LOCAL' : _darknet_op_not_support, + 'ACTIVE' : _darknet_op_not_support, + 'RNN' : _darknet_op_not_support, + 'GRU' : _darknet_op_not_support, + 'LSTM' : _darknet_op_not_support, + 'CRNN' : _darknet_op_not_support, + 'NETWORK' : _darknet_op_not_support, + 'XNOR' : _darknet_op_not_support, + 'BLANK' : _darknet_op_not_support, +} + +def _darknet_convert_symbol(op_name, inputs, attrs): + """Convert from darknet op to nnvm op. + The converter must specify some conversions explicitly to + support gluon format ops such as conv2d... + + Parameters + ---------- + op_name : str + Operator name, such as Convolution, Connected, etc + inputs : list of nnvm.Symbol + List of input symbols. 
+ attrs : dict + Dict of operator attributes + + Returns + ------- + out_name : converted out name of operation + sym : nnvm.Symbol + Converted nnvm Symbol + """ + + if op_name in _DARKNET_CONVERT_MAP: + sym, out_name = _DARKNET_CONVERT_MAP[op_name](inputs, attrs) + else: + _darknet_raise_not_supported('Operator: ' + op_name) + if out_name is None: + out_name = sym.list_output_names()[0].replace('_output', '') + return out_name, sym + + +def _as_list(arr): + """Force being a list, ignore if already is.""" + if isinstance(arr, list): + return arr + return [arr] + +def _read_memory_buffer(shape, data, dtype): + length = 1 + for x in shape: + length *= x + data_np = np.zeros(length, dtype=dtype) + for i in range(length): + data_np[i] = data[i] + return data_np.reshape(shape) + +def _get_darknet_layername(layer_type): + """Get the layer name from the darknet enums.""" + return str((LAYERTYPE(layer_type))).replace('LAYERTYPE.', '') + +def _get_convolution_weights(layer, opname, params, dtype): + """Get the convolution layer weights and biases.""" + if layer.nweights == 0: + return + + if (layer.n * layer.c * layer.size * layer.size) != layer.nweights: + raise RuntimeError("layer weights size not matching with n c h w") + + weights = _read_memory_buffer((layer.n, layer.c, layer.size, layer.size), layer.weights, dtype) + + biases = _read_memory_buffer((layer.n, ), layer.biases, dtype) + + k = _get_tvm_params_name(opname[0], 'weight') + params[k] = tvm.nd.array(weights) + + if layer.batch_normalize == 1 and layer.dontloadscales != 1: + _get_batchnorm_weights(layer, opname[1], params, layer.n, dtype) + k = _get_tvm_params_name(opname[1], 'beta') + params[k] = tvm.nd.array(biases) + else: + k = _get_tvm_params_name(opname[0], 'bias') + params[k] = tvm.nd.array(biases) + +def _get_connected_weights(layer, opname, params, dtype): + """Parse the weights and biases for fully connected or dense layer.""" + size = layer.outputs * layer.inputs + if size == 0: + return + + weights = 
_read_memory_buffer((layer.outputs, layer.inputs), layer.weights, dtype) + biases = _read_memory_buffer((layer.outputs, ), layer.biases, dtype) + + k = _get_tvm_params_name(opname, 'weight') + params[k] = tvm.nd.array(weights) + k = _get_tvm_params_name(opname, 'bias') + params[k] = tvm.nd.array(biases) + + if layer.batch_normalize == 1 and layer.dontloadscales != 1: + _get_batchnorm_weights(layer, opname, params, layer.outputs, dtype) + +def _get_batchnorm_weights(layer, opname, params, size, dtype): + """Parse the weights for batchnorm, which includes, scales, moving mean + and moving variances.""" + scales = _read_memory_buffer((size, ), layer.scales, dtype) + rolling_mean = _read_memory_buffer((size, ), layer.rolling_mean, dtype) + rolling_variance = _read_memory_buffer((size, ), layer.rolling_variance, dtype) + + k = _get_tvm_params_name(opname, 'moving_mean') + params[k] = tvm.nd.array(rolling_mean) + k = _get_tvm_params_name(opname, 'moving_var') + params[k] = tvm.nd.array(rolling_variance) + k = _get_tvm_params_name(opname, 'gamma') + params[k] = tvm.nd.array(scales) + +def _get_darknet_attrs(net, layer_num): + """Parse attributes of each layer and return.""" + attr = {} + use_flatten = True + layer = net.layers[layer_num] + op_name = _get_darknet_layername(layer.type) + + if LAYERTYPE.CONVOLUTIONAL == layer.type: + attr.update({'layout' : 'NCHW'}) + attr.update({'pad' : str(layer.pad)}) + attr.update({'num_group' : str(layer.groups)}) + attr.update({'num_filter' : str(layer.n)}) + attr.update({'stride' : str(layer.stride)}) + attr.update({'kernel' : str(layer.size)}) + attr.update({'activation' : (layer.activation)}) + + if layer.nbiases == 0: + attr.update({'use_bias' : False}) + else: + attr.update({'use_bias' : True}) + + if layer.batch_normalize == 1 and layer.dontloadscales != 1: + attr.update({'use_batchNorm' : True}) + attr.update({'use_scales' : True}) + + #elif LAYERTYPE.BATCHNORM == layer.type: + # attr.update({'flatten' : str('True')}) + + elif 
LAYERTYPE.CONNECTED == layer.type: + attr.update({'num_hidden' : str(layer.outputs)}) + attr.update({'activation' : (layer.activation)}) + if layer_num != 0: + layer_prev = net.layers[layer_num - 1] + if (layer_prev.out_h == layer.h and + layer_prev.out_w == layer.w and + layer_prev.out_c == layer.c): + use_flatten = False + attr.update({'use_flatten' : use_flatten}) + if layer.nbiases == 0: + attr.update({'use_bias' : False}) + else: + attr.update({'use_bias' : True}) + if layer.batch_normalize == 1 and layer.dontloadscales != 1: + attr.update({'use_batchNorm' : True}) + attr.update({'use_scales' : True}) + + elif LAYERTYPE.MAXPOOL == layer.type: + attr.update({'pad' : str(layer.pad)}) + attr.update({'stride' : str(layer.stride)}) + attr.update({'kernel' : str(layer.size)}) + + elif LAYERTYPE.AVGPOOL == layer.type: + attr.update({'pad' : str(layer.pad)}) + if layer.stride == 0: + attr.update({'stride' : str(1)}) + else: + attr.update({'stride' : str(layer.stride)}) + if layer.size == 0 and layer.h == layer.w: + attr.update({'kernel' : str(layer.h)}) + else: + attr.update({'kernel' : str(layer.size)}) + + elif LAYERTYPE.DROPOUT == layer.type: + attr.update({'p' : str(layer.probability)}) + + elif LAYERTYPE.SOFTMAX == layer.type: + attr.update({'axis' : 1}) + attr.update({'use_flatten' : True}) + + elif LAYERTYPE.SHORTCUT == layer.type: + add_layer = net.layers[layer.index] + attr.update({'activation' : (layer.activation)}) + attr.update({'out_channel' : (layer.out_c)}) + attr.update({'out_size' : (layer.out_h)}) + attr.update({'add_out_channel' : (add_layer.out_c)}) + attr.update({'add_out_size' : (add_layer.out_h)}) + + elif LAYERTYPE.ROUTE == layer.type: + pass + + elif LAYERTYPE.COST == layer.type: + pass + + elif LAYERTYPE.REORG == layer.type: + attr.update({'stride' : layer.stride}) + + elif LAYERTYPE.REGION == layer.type: + attr.update({'n' : layer.n}) + attr.update({'classes' : layer.classes}) + attr.update({'coords' : layer.coords}) + 
attr.update({'background' : layer.background}) + attr.update({'softmax' : layer.softmax}) + else: + err = "Darknet layer {} is not supported in nnvm.".format(op_name) + raise NotImplementedError(err) + + return op_name, attr + +def _get_tvm_params_name(opname, arg_name): + """Makes the params name for the k,v pair.""" + return opname + '_'+ arg_name + +def _get_darknet_params(layer, opname, tvmparams, dtype='float32'): + """To parse and get the darknet params.""" + if LAYERTYPE.CONVOLUTIONAL == layer.type: + _get_convolution_weights(layer, opname, tvmparams, dtype) + + #elif LAYERTYPE.BATCHNORM == layer.type: + # size = layer.outputs + # _get_batchnorm_weights(layer, opname, tvmparams, size, dtype) + + elif LAYERTYPE.CONNECTED == layer.type: + _get_connected_weights(layer, opname, tvmparams, dtype) + +def _preproc_layer(net, i, sym_array): + """To preprocess each darknet layer, some layer doesnt need processing.""" + layer = net.layers[i] + if i == 0: + name = 'data' + attribute = {} + sym = [_sym.Variable(name, **attribute)] + else: + sym = sym_array[i - 1] + skip_layer = False + + if LAYERTYPE.ROUTE == layer.type: + sym = [] + for j in range(layer.n): + sym.append(sym_array[layer.input_layers[j]]) + if layer.n == 1: + skip_layer = True + + elif LAYERTYPE.COST == layer.type: + skip_layer = True + + elif LAYERTYPE.SHORTCUT == layer.type: + sym = [sym, sym_array[layer.index]] + + elif LAYERTYPE.BLANK == layer.type: + skip_layer = True + + if skip_layer is True: + sym_array[i] = sym + + return skip_layer, sym + +def _from_darknet(net, dtype='float32'): + """To convert the darknet symbol to nnvm symbols.""" + sym_array = {} + tvmparams = {} + for i in range(net.n): + need_skip, sym = _preproc_layer(net, i, sym_array) + if need_skip is True: + continue + op_name, attr = _get_darknet_attrs(net, i) + layer_name, sym = _darknet_convert_symbol(op_name, _as_list(sym), attr) + _get_darknet_params(net.layers[i], layer_name, tvmparams, dtype) + sym_array[i] = sym + + return 
sym, tvmparams + +def from_darknet(net, dtype='float32'): + """Convert from darknet's model into compatible NNVM format. + Reconstruct a nnvm symbol by traversing the darknet input. + + Parameters + ---------- + net : ctype Pointer to network + Darknet parsed symbols + + dtype : str + Datatype of the input net structure, default is float32 + + Returns + ------- + sym : nnvm.Symbol + Compatible nnvm symbol + + params : dict of str to tvm.NDArray + The parameter dict to be used by nnvm + """ + + return _from_darknet(net, dtype) diff --git a/nnvm/python/nnvm/testing/__init__.py b/nnvm/python/nnvm/testing/__init__.py index 84c51a4df41d..19c37f7ac187 100644 --- a/nnvm/python/nnvm/testing/__init__.py +++ b/nnvm/python/nnvm/testing/__init__.py @@ -7,3 +7,5 @@ from . import mlp from . import resnet from . import vgg +from . import darknet +from . import yolo2_detection diff --git a/nnvm/python/nnvm/testing/darknet.py b/nnvm/python/nnvm/testing/darknet.py new file mode 100644 index 000000000000..30b790bb4e6d --- /dev/null +++ b/nnvm/python/nnvm/testing/darknet.py @@ -0,0 +1,494 @@ +# pylint: disable=invalid-name, unused-variable, unused-argument, no-init +""" +Compile DarkNet Models +==================== +DarkNet helper functions for darknet model parsing and image loading. +This functions will not be loaded by default. +These are utility functions used for testing and tutorial file. 
+""" +from __future__ import division +from enum import IntEnum +import math +import numpy as np +import cv2 +from cffi import FFI + +def _resize_image(img, w_in, h_in): + """Resize the image to the given height and width.""" + imc, imh, imw = img.shape + h_in = int(h_in) + w_in = int(w_in) + part = np.zeros((imc, imh, w_in)) + resized = np.zeros((imc, h_in, w_in)) + w_scale = (imw - 1) / (w_in - 1) + h_scale = (imh - 1) / (h_in - 1) + for k in range(imc): + for j in range(imh): + for c in range(w_in): + if c == w_in - 1 or imw == 1: + part[k][j][c] = img[k][j][imw - 1] + else: + fdx, idx = math.modf(c * w_scale) + part[k][j][c] = (1 - fdx) * img[k][j][int(idx)] + \ + fdx * img[k][j][int(idx) + 1] + for k in range(imc): + for j in range(h_in): + fdy, idy = math.modf(j * h_scale) + for c in range(w_in): + resized[k][j][c] = (1 - fdy)*part[k][int(idy)][c] + if (j == h_in - 1) or (imh == 1): + continue + for c in range(w_in): + resized[k][j][c] += fdy * part[k][int(idy) + 1][c] + return resized + +def load_image_color(test_image): + """To load the image using opencv api and do preprocessing.""" + imagex = cv2.imread(test_image) + imagex = np.array(imagex) + imagex = imagex.transpose((2, 0, 1)) + imagex = np.divide(imagex, 255.0) + imagex = np.flip(imagex, 0) + return imagex + +def _letterbox_image(img, w_in, h_in): + """To get the image in boxed format.""" + imc, imh, imw = img.shape + if (w_in / imw) < (h_in / imh): + new_w = w_in + new_h = imh * w_in / imw + else: + new_h = h_in + new_w = imw * h_in/imh + resized = _resize_image(img, new_w, new_h) + boxed = np.full((imc, h_in, w_in), 0.5, dtype=float) + _, resizedh, resizedw = resized.shape + boxed[:, int((h_in - new_h) / 2) + :int((h_in - new_h) / 2) + resizedh, int((w_in - new_w) / 2) + :int((w_in - new_w) / 2) + resizedw] = resized + return boxed + +def load_image(image, resize_width, resize_height): + """Load the image and convert to the darknet model format. 
+ The image processing of darknet is different from normal. + Parameters + ---------- + image : string + The image file name with path + + resize_width : integer + The width to which the image needs to be resized + + resize_height : integer + The height to which the image needs to be resized + + Returns + ------- + img : Float array + Array of processed image + """ + + img = load_image_color(image) + return _letterbox_image(img, resize_width, resize_height) + +class LAYERTYPE(IntEnum): + """Darknet LAYERTYPE Class constant.""" + CONVOLUTIONAL = 0 + DECONVOLUTIONAL = 1 + CONNECTED = 2 + MAXPOOL = 3 + SOFTMAX = 4 + DETECTION = 5 + DROPOUT = 6 + CROP = 7 + ROUTE = 8 + COST = 9 + NORMALIZATION = 10 + AVGPOOL = 11 + LOCAL = 12 + SHORTCUT = 13 + ACTIVE = 14 + RNN = 15 + GRU = 16 + LSTM = 17 + CRNN = 18 + BATCHNORM = 19 + NETWORK = 20 + XNOR = 21 + REGION = 22 + REORG = 23 + BLANK = 24 + +class ACTIVATION(IntEnum): + """Darknet ACTIVATION Class constant.""" + LOGISTIC = 0 + RELU = 1 + RELIE = 2 + LINEAR = 3 + RAMP = 4 + TANH = 5 + PLSE = 6 + LEAKY = 7 + ELU = 8 + LOGGY = 9 + STAIR = 10 + HARDTAN = 11 + LHTAN = 12 + +__darknetffi__ = FFI() + +__darknetffi__.cdef(""" +typedef struct network network; +typedef struct layer layer; + +typedef struct{ + int *leaf; + int n; + int *parent; + int *child; + int *group; + char **name; + + int groups; + int *group_size; + int *group_offset; +} tree; + +typedef enum{ + LOGISTIC, RELU, RELIE, LINEAR, RAMP, TANH, PLSE, LEAKY, ELU, LOGGY, STAIR, HARDTAN, LHTAN +} ACTIVATION; + + +typedef enum { + CONVOLUTIONAL, + DECONVOLUTIONAL, + CONNECTED, + MAXPOOL, + SOFTMAX, + DETECTION, + DROPOUT, + CROP, + ROUTE, + COST, + NORMALIZATION, + AVGPOOL, + LOCAL, + SHORTCUT, + ACTIVE, + RNN, + GRU, + LSTM, + CRNN, + BATCHNORM, + NETWORK, + XNOR, + REGION, + REORG, + BLANK +} LAYERTYPE; + +typedef enum{ + SSE, MASKED, LONE, SEG, SMOOTH +} COSTTYPE; + + +struct layer{ + LAYERTYPE type; + ACTIVATION activation; + COSTTYPE cost_type; + void (*forward); + 
void (*backward); + void (*update); + void (*forward_gpu); + void (*backward_gpu); + void (*update_gpu); + int batch_normalize; + int shortcut; + int batch; + int forced; + int flipped; + int inputs; + int outputs; + int nweights; + int nbiases; + int extra; + int truths; + int h,w,c; + int out_h, out_w, out_c; + int n; + int max_boxes; + int groups; + int size; + int side; + int stride; + int reverse; + int flatten; + int spatial; + int pad; + int sqrt; + int flip; + int index; + int binary; + int xnor; + int steps; + int hidden; + int truth; + float smooth; + float dot; + float angle; + float jitter; + float saturation; + float exposure; + float shift; + float ratio; + float learning_rate_scale; + int softmax; + int classes; + int coords; + int background; + int rescore; + int objectness; + int does_cost; + int joint; + int noadjust; + int reorg; + int log; + int tanh; + + float alpha; + float beta; + float kappa; + + float coord_scale; + float object_scale; + float noobject_scale; + float mask_scale; + float class_scale; + int bias_match; + int random; + float thresh; + int classfix; + int absolute; + + int onlyforward; + int stopbackward; + int dontload; + int dontloadscales; + + float temperature; + float probability; + float scale; + + char * cweights; + int * indexes; + int * input_layers; + int * input_sizes; + int * map; + float * rand; + float * cost; + float * state; + float * prev_state; + float * forgot_state; + float * forgot_delta; + float * state_delta; + float * combine_cpu; + float * combine_delta_cpu; + + float * concat; + float * concat_delta; + + float * binary_weights; + + float * biases; + float * bias_updates; + + float * scales; + float * scale_updates; + + float * weights; + float * weight_updates; + + float * delta; + float * output; + float * squared; + float * norms; + + float * spatial_mean; + float * mean; + float * variance; + + float * mean_delta; + float * variance_delta; + + float * rolling_mean; + float * rolling_variance; + + 
float * x; + float * x_norm; + + float * m; + float * v; + + float * bias_m; + float * bias_v; + float * scale_m; + float * scale_v; + + + float *z_cpu; + float *r_cpu; + float *h_cpu; + float * prev_state_cpu; + + float *temp_cpu; + float *temp2_cpu; + float *temp3_cpu; + + float *dh_cpu; + float *hh_cpu; + float *prev_cell_cpu; + float *cell_cpu; + float *f_cpu; + float *i_cpu; + float *g_cpu; + float *o_cpu; + float *c_cpu; + float *dc_cpu; + + float * binary_input; + + struct layer *input_layer; + struct layer *self_layer; + struct layer *output_layer; + + struct layer *reset_layer; + struct layer *update_layer; + struct layer *state_layer; + + struct layer *input_gate_layer; + struct layer *state_gate_layer; + struct layer *input_save_layer; + struct layer *state_save_layer; + struct layer *input_state_layer; + struct layer *state_state_layer; + + struct layer *input_z_layer; + struct layer *state_z_layer; + + struct layer *input_r_layer; + struct layer *state_r_layer; + + struct layer *input_h_layer; + struct layer *state_h_layer; + + struct layer *wz; + struct layer *uz; + struct layer *wr; + struct layer *ur; + struct layer *wh; + struct layer *uh; + struct layer *uo; + struct layer *wo; + struct layer *uf; + struct layer *wf; + struct layer *ui; + struct layer *wi; + struct layer *ug; + struct layer *wg; + + tree *softmax_tree; + + size_t workspace_size; +}; + + +typedef enum { + CONSTANT, STEP, EXP, POLY, STEPS, SIG, RANDOM +} LEARNINGRATEPOLICY; + +typedef struct network{ + int n; + int batch; + size_t *seen; + int *t; + float epoch; + int subdivisions; + layer *layers; + float *output; + LEARNINGRATEPOLICY policy; + + float learning_rate; + float momentum; + float decay; + float gamma; + float scale; + float power; + int time_steps; + int step; + int max_batches; + float *scales; + int *steps; + int num_steps; + int burn_in; + + int adam; + float B1; + float B2; + float eps; + + int inputs; + int outputs; + int truths; + int notruth; + int h, w, c; + 
int max_crop; + int min_crop; + float max_ratio; + float min_ratio; + int center; + float angle; + float aspect; + float exposure; + float saturation; + float hue; + int random; + + int gpu_index; + tree *hierarchy; + + float *input; + float *truth; + float *delta; + float *workspace; + int train; + int index; + float *cost; +} network; + + +typedef struct { + int w; + int h; + int c; + float *data; +} image; + +network *load_network(char *cfg, char *weights, int clear); +image letterbox_image(image im, int w, int h); +int resize_network(network *net, int w, int h); +void top_predictions(network *net, int n, int *index); +void free_image(image m); +image load_image_color(char *filename, int w, int h); +float *network_predict_image(network *net, image im); +network *make_network(int n); +layer make_convolutional_layer(int batch, int h, int w, int c, int n, int groups, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam); +layer make_connected_layer(int batch, int inputs, int outputs, ACTIVATION activation, int batch_normalize, int adam); +layer make_maxpool_layer(int batch, int h, int w, int c, int size, int stride, int padding); +layer make_avgpool_layer(int batch, int w, int h, int c); +layer make_shortcut_layer(int batch, int index, int w, int h, int c, int w2, int h2, int c2); +layer make_batchnorm_layer(int batch, int w, int h, int c); +layer make_reorg_layer(int batch, int w, int h, int c, int stride, int reverse, int flatten, int extra); +layer make_region_layer(int batch, int w, int h, int n, int classes, int coords); +void free_network(network *net); +""" + ) diff --git a/nnvm/python/nnvm/testing/yolo2_detection.py b/nnvm/python/nnvm/testing/yolo2_detection.py new file mode 100644 index 000000000000..b7744c45cff4 --- /dev/null +++ b/nnvm/python/nnvm/testing/yolo2_detection.py @@ -0,0 +1,246 @@ +# pylint: disable=invalid-name, unused-variable, unused-argument, no-init +""" +Yolo detection boxes helper 
# pylint: disable=invalid-name, unused-variable, unused-argument, no-init
"""
Yolo detection boxes helper functions
====================
DarkNet helper functions for yolo and image loading.
These functions will not be loaded by default.
They are utility functions used by the tests and the tutorial.
"""
from __future__ import division
import math
from collections import namedtuple
import numpy as np

# Pillow is imported inside _get_label (its only user) so that the box,
# NMS and drawing helpers stay importable on machines without Pillow.

# Bounding box; (x, y) is the box centre and (w, h) its width and height
# (see _overlap, which derives the edges as x -/+ w/2).
Box = namedtuple('Box', ['x', 'y', 'w', 'h'])


def _entry_index(batch, w, h, outputs, classes, coords, location, entry):
    """Return the flat offset of `entry` for the cell/anchor at `location`.

    `location` encodes ``anchor * (w*h) + cell``; every anchor owns
    ``coords + classes + 1`` planes of ``w*h`` values each.
    """
    cells = w * h
    n, loc = divmod(location, cells)
    return batch*outputs + n*cells*(coords+classes+1) + entry*cells + loc


def _get_region_box(x, biases, n, index, i, j, w, h, stride):
    """Decode one raw prediction starting at `index` into a Box.

    `i`/`j` are the column/row of the grid cell, `w`/`h` the grid size and
    `stride` the distance between consecutive box coordinates in `x`.
    """
    return Box(x=(i + x[index + 0*stride]) / w,
               y=(j + x[index + 1*stride]) / h,
               w=np.exp(x[index + 2*stride]) * biases[2*n] / w,
               h=np.exp(x[index + 3*stride]) * biases[2*n+1] / h)


def _correct_region_boxes(boxes, n, w, h, netw, neth, relative):
    """Map the first `n` boxes from letterboxed network space to the image.

    `w`/`h` is the image size, `netw`/`neth` the network input size. The
    list is updated in place; with a falsy `relative` the boxes are scaled
    to pixel units of the image.
    """
    # Size of the image after it was fitted into the network input.
    if netw/w < neth/h:
        new_w, new_h = netw, (h*netw)/w
    else:
        new_w, new_h = (w*neth/h), neth
    for i in range(n):
        b = boxes[i]
        b = Box(x=(b.x - (netw - new_w)/2/netw) / (new_w/netw),
                y=(b.y - (neth - new_h)/2/neth) / (new_h/neth),
                w=b.w * netw/new_w,
                h=b.h * neth/new_h)
        if not relative:
            b = Box(x=b.x * w, y=b.y * h, w=b.w * w, h=b.h * h)
        boxes[i] = b


def _overlap(x1, w1, x2, w2):
    """Return the signed overlap of two 1-D segments given centre and width."""
    left = max(x1 - w1/2, x2 - w2/2)
    right = min(x1 + w1/2, x2 + w2/2)
    return right - left


def _box_intersection(a, b):
    """Return the intersection area of two boxes (0 when they are disjoint)."""
    w = _overlap(a.x, a.w, b.x, b.w)
    h = _overlap(a.y, a.h, b.y, b.h)
    if w < 0 or h < 0:
        return 0
    return w*h


def _box_union(a, b):
    """Return the area of the union of two boxes."""
    return a.w*a.h + b.w*b.h - _box_intersection(a, b)


def _box_iou(a, b):
    """Return intersection-over-union; 0 when both boxes have no area."""
    union = _box_union(a, b)
    if union == 0:
        # Two degenerate boxes: avoid ZeroDivisionError, treat as disjoint.
        return 0
    return _box_intersection(a, b) / union


def get_region_boxes(layer_in, imw, imh, netw, neth, thresh, probs,
                     boxes, relative, tvm_out):
    """Decode the region layer output into boxes and class probabilities.

    Parameters
    ----------
    layer_in : object
        Region layer description; the attributes w, h, n, classes, coords,
        outputs, background, biases and softmax_tree are read.
    imw, imh : int
        Size of the original image.
    netw, neth : int
        Size of the network input.
    thresh : float
        Class probabilities not above this value are stored as 0.
    probs, boxes : list
        Ignored; kept for interface compatibility. Fresh lists are built.
    relative : bool
        Passed to the box correction; falsy means pixel coordinates.
    tvm_out : sequence of float
        Flat output of the region layer.

    Returns
    -------
    boxes : list of Box
    probs : list of list
        One row per box with ``classes + 1`` entries; the last entry holds
        the largest (unthresholded) class score of that box.
    """
    lw = layer_in.w
    lh = layer_in.h
    cells = lw*lh
    total = cells*layer_in.n
    probs = [[0] * (layer_in.classes + 1) for _ in range(total)]
    boxes = [Box(0, 0, 0, 0) for _ in range(total)]

    def entry(location, which):
        return _entry_index(0, lw, lh, layer_in.outputs, layer_in.classes,
                            layer_in.coords, location, which)

    for i in range(cells):
        row = int(i / lw)
        col = int(i % lw)
        for n in range(layer_in.n):
            index = n*cells + i
            obj_index = entry(index, layer_in.coords)
            box_index = entry(index, 0)
            scale = 1 if layer_in.background else tvm_out[obj_index]
            boxes[index] = _get_region_box(tvm_out, layer_in.biases, n, box_index, col,
                                           row, lw, lh, cells)
            if not layer_in.softmax_tree:
                max_element = 0
                for j in range(layer_in.classes):
                    class_index = entry(index, layer_in.coords+1+j)
                    prob = scale*tvm_out[class_index]
                    probs[index][j] = prob if prob > thresh else 0
                    max_element = max(max_element, prob)
                probs[index][layer_in.classes] = max_element

    _correct_region_boxes(boxes, total, imw, imh, netw, neth, relative)
    return boxes, probs


def do_nms_sort(boxes, probs, total, classes, thresh):
    """Per-class non-maximum suppression.

    For every class the boxes are visited by descending score; any later
    box whose IoU with a surviving box exceeds `thresh` gets its score for
    that class set to 0. `probs` is modified in place and returned together
    with `boxes`.
    """
    # The ordering is carried over from class to class, so ties are broken
    # by the previous class' ranking (stable sort).
    order = list(range(total))
    for k in range(classes):
        order = sorted(order, key=lambda idx, k=k: probs[idx][k], reverse=True)
        for pos in range(total):
            keep = order[pos]
            if probs[keep][k] == 0:
                continue
            a = boxes[keep]
            for later in range(pos+1, total):
                other = order[later]
                # Already suppressed boxes need no IoU computation.
                if probs[other][k] != 0 and _box_iou(a, boxes[other]) > thresh:
                    probs[other][k] = 0
    return boxes, probs


def draw_detections(im, num, thresh, boxes, probs, names, classes):
    """Draw a box and a text label on `im` for every detection above `thresh`.

    `im` is a (channel, height, width) array that is modified in place;
    box coordinates are expected relative to the image size.
    """
    for i in range(num):
        hits = [j for j in range(classes) if probs[i][j] > thresh]
        if not hits:
            continue
        category = hits[0]
        # Separate the names so multi-class hits stay readable.
        labelstr = ', '.join(names[j] for j in hits)

        imc, imh, imw = im.shape
        width = int(imh * 0.006)
        offset = category*123457 % classes
        red = _get_color(2, offset, classes)
        green = _get_color(1, offset, classes)
        blue = _get_color(0, offset, classes)
        rgb = [red, green, blue]
        b = boxes[i]
        left = max(int((b.x-b.w/2.)*imw), 0)
        right = min(int((b.x+b.w/2.)*imw), imw-1)
        top = max(int((b.y-b.h/2.)*imh), 0)
        bot = min(int((b.y+b.h/2.)*imh), imh-1)
        _draw_box_width(im, left, top, right, bot, width, red, green, blue)
        label = _get_label(labelstr, rgb)
        _draw_label(im, top + width, left, label, rgb)


def _get_pixel(im, x, y, c):
    """Return the value of channel `c` at column `x`, row `y`."""
    return im[c][y][x]


def _set_pixel(im, x, y, c, val):
    """Set one pixel; coordinates outside the image are silently ignored."""
    if x < 0 or y < 0 or c < 0 or x >= im.shape[2] or y >= im.shape[1] or c >= im.shape[0]:
        return
    im[c][y][x] = val


def _draw_label(im, r, c, label, rgb):
    """Paste the `label` image into `im` near row `r`, column `c`.

    The label is placed above row `r` when it fits, otherwise it starts at
    `r`. Parts falling outside `im` are clipped.
    """
    h = label.shape[1]
    w = label.shape[2]
    if (r - h) >= 0:
        r = r - h

    channels = min(label.shape[0], im.shape[0])
    row_lo, row_hi = max(0, -r), min(h, im.shape[1] - r)
    col_lo, col_hi = max(0, -c), min(w, im.shape[2] - c)
    if row_lo >= row_hi or col_lo >= col_hi:
        return
    im[:channels, r+row_lo:r+row_hi, c+col_lo:c+col_hi] = \
        label[:channels, row_lo:row_hi, col_lo:col_hi]


def _get_label(labelstr, rgb):
    """Render `labelstr` in black on an `rgb` background.

    Returns a (channel, height, width) float array with values in [0, 1].
    Requires Pillow and an ``arial.ttf`` font that Pillow can locate.
    """
    from PIL import Image, ImageDraw, ImageFont

    font = ImageFont.truetype("arial.ttf", 25)
    probe = ImageDraw.Draw(Image.new('RGB', (1, 1)))
    if hasattr(probe, 'textsize'):
        width, height = probe.textsize(labelstr, font=font)
    else:
        # textsize is gone from recent Pillow releases; textbbox replaces it.
        _, _, width, height = probe.textbbox((0, 0), labelstr, font=font)
    img = Image.new('RGB', (width, height), color=(int(rgb[0]*255), int(rgb[1]*255),
                                                   int(rgb[2]*255)))
    ImageDraw.Draw(img).text((0, 0), labelstr, fill="black", font=font)
    # true_divide: plain divide floors integer arrays on Python 2 NumPy.
    return np.true_divide(np.asarray(img), 255).transpose(2, 0, 1)


def _get_color(c, x, max_value):
    """Return channel `c` of a colour interpolated from a 6-entry palette.

    `x / max_value` selects the position along the palette.
    """
    c = int(c)
    colors = [[1, 0, 1], [0, 0, 1], [0, 1, 1], [0, 1, 0], [1, 1, 0], [1, 0, 0]]
    ratio = (float(x)/float(max_value)) * 5
    i = int(math.floor(ratio))
    j = int(math.ceil(ratio))
    ratio -= i
    return (1-ratio) * colors[i][c] + ratio*colors[j][c]


def _draw_box(im, x1, y1, x2, y2, r, g, b):
    """Draw a one pixel wide rectangle outline on `im` (in place).

    All four coordinates are clamped to the image, so boxes that extend
    beyond it are clipped instead of raising IndexError. Both end points
    are inclusive, which closes the bottom-right corner.
    """
    ac, ah, aw = im.shape
    x1 = min(max(int(x1), 0), aw - 1)
    x2 = min(max(int(x2), 0), aw - 1)
    y1 = min(max(int(y1), 0), ah - 1)
    y2 = min(max(int(y2), 0), ah - 1)

    for channel, value in enumerate((r, g, b)):
        im[channel, y1, x1:x2+1] = value
        im[channel, y2, x1:x2+1] = value
        im[channel, y1:y2+1, x1] = value
        im[channel, y1:y2+1, x2] = value


def _draw_box_width(im, x1, y1, x2, y2, w, r, g, b):
    """Draw a rectangle outline `w` pixels thick by nesting thin boxes."""
    for i in range(int(w)):
        _draw_box(im, x1+i, y1+i, x2-i, y2-i, r, g, b)
# pylint: disable=invalid-name, unused-argument
"""Definition of vision ops"""
from __future__ import absolute_import

import topi
import tvm
from . import registry as reg
from .registry import OpPattern


@reg.register_compute("yolo2_reorg")
def compute_reorg(attrs, inputs, _):
    """Compute definition of reorg: forwards the data tensor and the
    integer "stride" attribute to topi."""
    data = inputs[0]
    stride = attrs.get_int("stride")
    return topi.vision.reorg(data, stride)


@reg.register_schedule("yolo2_reorg")
def schedule_reorg(attrs, outs, target):
    """Schedule definition of reorg: generic injective schedule built
    inside the scope of `target`."""
    with tvm.target.create(target):
        schedule = topi.generic.schedule_injective(outs)
    return schedule


reg.register_pattern("yolo2_reorg", OpPattern.INJECTIVE)


@reg.register_compute("yolo2_region")
def compute_region(attrs, inputs, _):
    """Compute definition of region: reads the integer attributes n,
    classes, coords, background and softmax and forwards them to topi."""
    data = inputs[0]
    # Order matters: it is the positional order expected by the topi call.
    int_args = [attrs.get_int(key)
                for key in ("n", "classes", "coords", "background", "softmax")]
    return topi.vision.yolo2.region(data, *int_args)


@reg.register_schedule("yolo2_region")
def schedule_region(attrs, outs, target):
    """Schedule definition of region: generic region schedule built inside
    the scope of `target`."""
    with tvm.target.create(target):
        schedule = topi.generic.vision.schedule_region(outs)
    return schedule


reg.register_pattern("yolo2_region", OpPattern.OPAQUE)
+ */ +#include +#include +#include +#include +#include "../../op_common.h" +#include "region.h" + +namespace nnvm { +namespace top { + +NNVM_REGISTER_OP(yolo2_region) +.describe(R"code(Region layer +)code" NNVM_ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(1) +.set_support_level(5) +.add_argument("data", "Tensor", "Input data") +.set_attr("FInferType", RegionType<1, 1>) +.set_attr("FInferShape", RegionShape<1, 1>) +.set_attr( + "FInplaceOption", + [](const NodeAttrs &attrs) { + return std::vector>{{0, 0}, {1, 0}}; + }) +.set_attr("FGradient", [](const NodePtr &n, + const std::vector &ograds) { + return std::vector{ograds[0], ograds[0]}; +}); +} // namespace top +} // namespace nnvm diff --git a/nnvm/src/top/vision/yolo2/region.h b/nnvm/src/top/vision/yolo2/region.h new file mode 100644 index 000000000000..cc816eab6ae1 --- /dev/null +++ b/nnvm/src/top/vision/yolo2/region.h @@ -0,0 +1,101 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file region.h + */ +#ifndef NNVM_TOP_VISION_YOLO2_REGION_H_ +#define NNVM_TOP_VISION_YOLO2_REGION_H_ + +#include +#include +#include +#include +#include + +namespace nnvm { +namespace top { + +template +inline bool RegionAttr(const nnvm::NodeAttrs &attrs, + std::vector *in_attrs, + std::vector *out_attrs, + const AttrType &none) { + AttrType dattr = none; + size_t in_size = in_attrs->size(); + size_t out_size = out_attrs->size(); + if (n_in != -1) { + in_size = static_cast(n_in); + } + if (n_out != -1) { + out_size = static_cast(n_out); + } + + auto deduce = [&](std::vector *vec, size_t size, const char *name) { + for (size_t i = 0; i < size; ++i) { + if (i == 0) + CHECK(assign(&dattr, (*vec)[i])) + << "Incompatible attr in node " << attrs.name << " at " << i + << "-th " << name << ": " + << "expected " << attr_string(dattr) << ", got " + << attr_string((*vec)[i]); + } + }; + deduce(in_attrs, in_size, "input"); + + auto write = [&](std::vector *vec, size_t size, const char *name) { + for (size_t i = 0; i < size; ++i) { + 
CHECK(assign(&(*vec)[i], dattr)) + << "Incompatible attr in node " << attrs.name << " at " << i << "-th " + << name << ": " + << "expected " << attr_string(dattr) << ", got " + << attr_string((*vec)[i]); + } + }; + write(out_attrs, out_size, "output"); + + if (is_none(dattr)) { + return false; + } + return true; +} + +template +inline bool RegionShape(const NodeAttrs &attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + if (n_in != -1) { + CHECK_EQ(in_attrs->size(), static_cast(n_in)) + << " in operator " << attrs.name; + } + if (n_out != -1) { + CHECK_EQ(out_attrs->size(), static_cast(n_out)) + << " in operator " << attrs.name; + } + return RegionAttr( + attrs, in_attrs, out_attrs, TShape()); +} + +template +inline bool RegionType(const NodeAttrs &attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + if (n_in != -1) { + CHECK_EQ(in_attrs->size(), static_cast(n_in)) + << " in operator " << attrs.name; + } + if (n_out != -1) { + CHECK_EQ(out_attrs->size(), static_cast(n_out)) + << " in operator " << attrs.name; + } + return RegionAttr( + attrs, in_attrs, out_attrs, -1); +} +} // namespace top +} // namespace nnvm +#endif // NNVM_TOP_VISION_YOLO2_REGION_H_ diff --git a/nnvm/src/top/vision/yolo2/reorg.cc b/nnvm/src/top/vision/yolo2/reorg.cc new file mode 100644 index 000000000000..e58940eb25dd --- /dev/null +++ b/nnvm/src/top/vision/yolo2/reorg.cc @@ -0,0 +1,52 @@ +/*! 
+ * Copyright (c) 2018 by Contributors + * \file reorg.cc + */ +#include +#include +#include +#include +#include "../../op_common.h" +#include "../../elemwise_op_common.h" +#include "reorg.h" + +namespace nnvm { +namespace top { + +// reorg +DMLC_REGISTER_PARAMETER(ReorgParam); + +inline bool ReorgInferShape(const nnvm::NodeAttrs &attrs, + std::vector *in_shape, + std::vector *out_shape) { + const ReorgParam ¶m = nnvm::get(attrs.parsed); + TShape dshape = in_shape->at(0); + if (dshape.ndim() == 0) + return false; + NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 0, dshape); + CHECK_EQ(dshape.ndim(), 4) << "Input data should be 4D"; + CHECK_GT(param.stride, 0U) << "Stride value cannot be 0"; + TShape oshape({dshape[0], 0, 0, 0}); + oshape[1] = dshape[1] * param.stride * param.stride; + oshape[2] = dshape[2] / param.stride; + oshape[3] = dshape[3] / param.stride; + NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape); + return true; +} + +NNVM_REGISTER_OP(yolo2_reorg) +.describe(R"(Perform reorg operation on input array based on the stride value. +- **data**: Input is 4D array of shape (batch_size, channels, in_height, in_width). +- **out**: Output is 4D array of shape (batch_size, channels/(stride*stride), in_height*stride, in_width*stride). +)" NNVM_ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(1) +.set_support_level(5) +.add_argument("data", "Tensor", "Data input to reorganize") +.set_attr_parser(ParamParser) +.add_arguments(ReorgParam::__FIELDS__()) +.set_attr("FGetAttrDict", ParamGetAttrDict) +.set_attr("FInferType", ElemwiseType<-1, 1>) +.set_attr("FInferShape", ReorgInferShape); +} // namespace top +} // namespace nnvm diff --git a/nnvm/src/top/vision/yolo2/reorg.h b/nnvm/src/top/vision/yolo2/reorg.h new file mode 100644 index 000000000000..87e0510e2781 --- /dev/null +++ b/nnvm/src/top/vision/yolo2/reorg.h @@ -0,0 +1,110 @@ +/*! 
+ * Copyright (c) 2018 by Contributors + * \file reorg.h + */ +#ifndef NNVM_TOP_VISION_YOLO2_REORG_H_ +#define NNVM_TOP_VISION_YOLO2_REORG_H_ + +#include +#include +#include +#include +#include + +namespace nnvm { +namespace top { + +template +inline bool ReorgAttr(const nnvm::NodeAttrs &attrs, + std::vector *in_attrs, + std::vector *out_attrs, + const AttrType &none) { + AttrType dattr = none; + size_t in_size = in_attrs->size(); + size_t out_size = out_attrs->size(); + if (n_in != -1) { + in_size = static_cast(n_in); + } + if (n_out != -1) { + out_size = static_cast(n_out); + } + + auto deduce = [&](std::vector *vec, size_t size, const char *name) { + for (size_t i = 0; i < size; ++i) { + if (i == 0) { + CHECK(assign(&dattr, (*vec)[i])) + << "Incompatible attr in node " << attrs.name << " at " << i + << "-th " << name << ": " + << "expected " << attr_string(dattr) << ", got " + << attr_string((*vec)[i]); + } + } + }; + deduce(in_attrs, in_size, "input"); + + auto write = [&](std::vector *vec, size_t size, const char *name) { + for (size_t i = 0; i < size; ++i) { + CHECK(assign(&(*vec)[i], dattr)) + << "Incompatible attr in node " << attrs.name << " at " << i << "-th " + << name << ": " + << "expected " << attr_string(dattr) << ", got " + << attr_string((*vec)[i]); + } + }; + write(out_attrs, out_size, "output"); + + if (is_none(dattr)) { + return false; + } + return true; +} + +template +inline bool ReorgShape(const NodeAttrs &attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + if (n_in != -1) { + CHECK_EQ(in_attrs->size(), static_cast(n_in)) + << " in operator " << attrs.name; + } + if (n_out != -1) { + CHECK_EQ(out_attrs->size(), static_cast(n_out)) + << " in operator " << attrs.name; + } + return ReorgAttr( + attrs, in_attrs, out_attrs, TShape()); +} + +template +inline bool ReorgType(const NodeAttrs &attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + if (n_in != -1) { + CHECK_EQ(in_attrs->size(), static_cast(n_in)) + << " in operator " << 
attrs.name; + } + if (n_out != -1) { + CHECK_EQ(out_attrs->size(), static_cast(n_out)) + << " in operator " << attrs.name; + } + return ReorgAttr( + attrs, in_attrs, out_attrs, -1); +} + +struct ReorgParam : public dmlc::Parameter { + int stride; + + DMLC_DECLARE_PARAMETER(ReorgParam) { + DMLC_DECLARE_FIELD(stride).set_default(1).describe("Stride value"); + } +}; +} // namespace top +} // namespace nnvm +#endif // NNVM_TOP_VISION_YOLO2_REORG_H_ diff --git a/nnvm/tests/ci_build/Dockerfile.gpu b/nnvm/tests/ci_build/Dockerfile.gpu index bde32322ceaa..2ee5ed04e91f 100644 --- a/nnvm/tests/ci_build/Dockerfile.gpu +++ b/nnvm/tests/ci_build/Dockerfile.gpu @@ -41,6 +41,9 @@ RUN bash /install/ubuntu_install_coreml.sh COPY install/ubuntu_install_keras.sh /install/ubuntu_install_keras.sh RUN bash /install/ubuntu_install_keras.sh +COPY install/ubuntu_install_darknet.sh /install/ubuntu_install_darknet.sh +RUN bash /install/ubuntu_install_darknet.sh + RUN pip install Pillow # Environment variables diff --git a/nnvm/tests/ci_build/install/ubuntu_install_darknet.sh b/nnvm/tests/ci_build/install/ubuntu_install_darknet.sh new file mode 100644 index 000000000000..f5e0c2791d80 --- /dev/null +++ b/nnvm/tests/ci_build/install/ubuntu_install_darknet.sh @@ -0,0 +1,4 @@ +#install the necessary dependancies, cffi, opencv +wget 'https://github.com/siju-samuel/darknet/blob/master/lib/libdarknet.so?raw=true' -O libdarknet.so +pip2 install opencv-python cffi +pip3 install opencv-python cffi diff --git a/nnvm/tests/python/frontend/darknet/test_forward.py b/nnvm/tests/python/frontend/darknet/test_forward.py new file mode 100644 index 000000000000..ad28c49c014d --- /dev/null +++ b/nnvm/tests/python/frontend/darknet/test_forward.py @@ -0,0 +1,257 @@ +""" +Compile Darknet Models +===================== +This article is a test script to test darknet models with NNVM. +All the required models and libraries will be downloaded from the internet +by the script. 
"""
Compile Darknet Models
=====================
This article is a test script to test darknet models with NNVM.
All the required models and libraries will be downloaded from the internet
by the script.
"""
import os
import requests
import numpy as np
from nnvm import frontend
from nnvm.testing.darknet import __darknetffi__
import nnvm.compiler
import tvm
import sys
import urllib
if sys.version_info >= (3,):
    import urllib.request as urllib2
    from urllib.request import urlretrieve as _urlretrieve
else:
    import urllib2
    from urllib import urlretrieve as _urlretrieve


def _remote_size(url):
    '''Return the Content-Length reported for url as an int.'''
    res_head = requests.head(url)
    if 'Content-Length' in res_head.headers:
        res_get = requests.get(url, stream=True)
    else:
        res_get = urllib2.urlopen(url)
    try:
        return int(res_get.headers['Content-Length'])
    finally:
        # Only the headers are needed; release the connection.
        res_get.close()


def _download(url, path, overwrite=False, sizecompare=False):
    ''' Download from internet

    An existing file is kept unless overwrite is set, or sizecompare is set
    and the local size differs from the size reported by the server.
    '''
    if os.path.isfile(path) and not overwrite:
        if sizecompare and _remote_size(url) != os.path.getsize(path):
            print("exist file got corrupted, downloading", path, " file freshly")
            _download(url, path, True, False)
            return
        print('File {} exists, skip.'.format(path))
        return
    print('Downloading from url {} to {}'.format(url, path))
    # The retrieve function is picked by interpreter version at import time;
    # a failing download now surfaces its own error instead of being retried
    # through a function that does not exist on Python 3.
    _urlretrieve(url, path)
    print('')


DARKNET_LIB = 'libdarknet.so'
DARKNETLIB_URL = 'https://github.com/siju-samuel/darknet/blob/master/lib/' \
                 + DARKNET_LIB + '?raw=true'
_download(DARKNETLIB_URL, DARKNET_LIB)
LIB = __darknetffi__.dlopen('./' + DARKNET_LIB)


def test_forward(net):
    '''Test network with given input image on both darknet and tvm'''
    def get_darknet_output(net, img):
        return LIB.network_predict_image(net, img)

    def get_tvm_output(net, img):
        '''Compute TVM output'''
        dtype = 'float32'
        batch_size = 1
        sym, params = frontend.darknet.from_darknet(net, dtype)
        # img.data is read sequentially in (channel, row, column) order.
        flat = [img.data[i] for i in range(img.c * img.h * img.w)]
        data = np.array(flat, dtype).reshape([batch_size, img.c, img.h, img.w])

        target = 'llvm'
        shape_dict = {'data': data.shape}
        #with nnvm.compiler.build_config(opt_level=2):
        graph, library, params = nnvm.compiler.build(sym, target, shape_dict, dtype, params=params)
        ######################################################################
        # Execute on TVM
        # ---------------
        # The process is no different from other examples.
        from tvm.contrib import graph_runtime
        ctx = tvm.cpu(0)
        m = graph_runtime.create(graph, library, ctx)
        # set inputs
        m.set_input('data', tvm.nd.array(data.astype(dtype)))
        m.set_input(**params)
        m.run()
        # get outputs
        out_shape = (net.outputs,)
        tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
        return tvm_out

    test_image = 'dog.jpg'
    img_url = 'https://github.com/siju-samuel/darknet/blob/master/data/' + test_image +'?raw=true'
    _download(img_url, test_image)
    raw = LIB.load_image_color(test_image.encode('utf-8'), 0, 0)
    img = LIB.letterbox_image(raw, net.w, net.h)
    try:
        darknet_output = get_darknet_output(net, img)
        # Copy out of the native buffer before anything else touches the net.
        darknet_out = np.zeros(net.outputs, dtype='float32')
        for i in range(net.outputs):
            darknet_out[i] = darknet_output[i]
        tvm_out = get_tvm_output(net, img)
    finally:
        # Both images are allocated by darknet and owned by the caller.
        LIB.free_image(img)
        LIB.free_image(raw)
    np.testing.assert_allclose(darknet_out, tvm_out, rtol=1e-3, atol=1e-3)


def _test_pretrained(model_name):
    '''Download cfg/weights of model_name, compare darknet and tvm outputs.'''
    cfg_name = model_name + '.cfg'
    weights_name = model_name + '.weights'
    cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
    weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
    _download(cfg_url, cfg_name)
    _download(weights_url, weights_name)
    net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
    try:
        test_forward(net)
    finally:
        LIB.free_network(net)


def _test_layers(layers, size):
    '''Build a network from layers with a size x size input and test it.'''
    net = LIB.make_network(len(layers))
    try:
        for idx, layer in enumerate(layers):
            net.layers[idx] = layer
        net.w = net.h = size
        LIB.resize_network(net, size, size)
        test_forward(net)
    finally:
        # Free the network even when the comparison fails.
        LIB.free_network(net)


def test_forward_extraction():
    '''test extraction model'''
    _test_pretrained('extraction')


def test_forward_alexnet():
    '''test alexnet model'''
    _test_pretrained('alexnet')


def test_forward_resnet50():
    '''test resnet50 model'''
    _test_pretrained('resnet50')


def test_forward_yolo():
    '''test yolo model'''
    _test_pretrained('yolo')


def test_forward_convolutional():
    '''test convolutional layer'''
    layer = LIB.make_convolutional_layer(1, 224, 224, 3, 32, 1, 3, 2, 0, 1, 0, 0, 0, 0)
    _test_layers([layer], 224)


def test_forward_dense():
    '''test fully connected layer'''
    layer = LIB.make_connected_layer(1, 75, 20, 1, 0, 0)
    _test_layers([layer], 5)


def test_forward_maxpooling():
    '''test maxpooling layer'''
    layer = LIB.make_maxpool_layer(1, 224, 224, 3, 2, 2, 0)
    _test_layers([layer], 224)


def test_forward_avgpooling():
    '''test average pooling layer'''
    layer = LIB.make_avgpool_layer(1, 224, 224, 3)
    _test_layers([layer], 224)


def test_forward_batch_norm():
    '''test batch normalization layer'''
    layer = LIB.make_convolutional_layer(1, 224, 224, 3, 32, 1, 3, 2, 0, 1, 1, 0, 0, 0)
    for i in range(32):
        layer.rolling_mean[i] = np.random.rand(1)
        layer.rolling_variance[i] = np.random.rand(1)
    _test_layers([layer], 224)


def test_forward_shortcut():
    '''test shortcut layer'''
    layer_1 = LIB.make_convolutional_layer(1, 224, 224, 3, 32, 1, 3, 2, 0, 1, 0, 0, 0, 0)
    layer_2 = LIB.make_convolutional_layer(1, 111, 111, 32, 32, 1, 1, 1, 0, 1, 0, 0, 0, 0)
    layer_3 = LIB.make_shortcut_layer(1, 0, 111, 111, 32, 111, 111, 32)
    layer_3.activation = 1
    _test_layers([layer_1, layer_2, layer_3], 224)


def test_forward_reorg():
    '''test reorg layer'''
    layer_1 = LIB.make_convolutional_layer(1, 222, 222, 3, 32, 1, 3, 2, 0, 1, 0, 0, 0, 0)
    layer_2 = LIB.make_reorg_layer(1, 110, 110, 32, 2, 0, 0, 0)
    _test_layers([layer_1, layer_2], 222)


def test_forward_region():
    '''test region layer'''
    layer_1 = LIB.make_convolutional_layer(1, 224, 224, 3, 8, 1, 3, 2, 0, 1, 0, 0, 0, 0)
    layer_2 = LIB.make_region_layer(1, 111, 111, 2, 2, 1)
    layer_2.softmax = 1
    _test_layers([layer_1, layer_2], 224)


if __name__ == '__main__':
    test_forward_resnet50()
    test_forward_alexnet()
    test_forward_extraction()
    test_forward_yolo()
    test_forward_convolutional()
    test_forward_maxpooling()
    test_forward_avgpooling()
    test_forward_batch_norm()
    test_forward_shortcut()
    test_forward_dense()
    test_forward_reorg()
    test_forward_region()
+# Supported models alexnet, resnet50, resnet152, extraction, yolo +###################################################################### +model_name = 'yolo' +test_image = 'dog.jpg' +target = 'llvm' +ctx = tvm.cpu(0) +###################################################################### + +def dlProgress(count, block_size, total_size): + """Show the download progress.""" + global start_time + if count == 0: + start_time = time.time() + return + duration = time.time() - start_time + progress_size = int(count * block_size) + speed = int(progress_size / (1024 * duration)) + percent = int(count * block_size * 100 / total_size) + sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" % + (percent, progress_size / (1024 * 1024), speed, duration)) + sys.stdout.flush() + +def download(url, path, overwrite=False, sizecompare=False): + """Downloads the file from the internet. + Set the input options correctly to overwrite or do the size comparison + + Parameters + ---------- + url : str + Operator name, such as Convolution, Connected, etc + path : str + List of input symbols. 
+ overwrite : dict + Dict of operator attributes + sizecompare : dict + Dict of operator attributes + + Returns + ------- + out_name : converted out name of operation + sym : nnvm.Symbol + Converted nnvm Symbol + """ + if os.path.isfile(path) and not overwrite: + if (sizecompare): + fileSize = os.path.getsize(path) + resHead = requests.head(url) + resGet = requests.get(url,stream=True) + if 'Content-Length' not in resHead.headers : + resGet = urllib2.urlopen(url) + urlFileSize = int(resGet.headers['Content-Length']) + if urlFileSize != fileSize: + print ("exist file got corrupted, downloading", path , " file freshly") + download(url, path, True, False) + return + print('File {} exists, skip.'.format(path)) + return + print('Downloading from url {} to {}'.format(url, path)) + try: + urllib.request.urlretrieve(url, path, reporthook=dlProgress) + print('') + except: + urllib.urlretrieve(url, path, reporthook=dlProgress) + +###################################################################### +# Prepare cfg and weights file +# Pretrained model available https://pjreddie.com/darknet/imagenet/ +# -------------------------------------------------------------------- +# Download cfg and weights file first time. + +cfg_name = model_name + '.cfg' +weights_name = model_name + '.weights' +cfg_url = 'https://github.com/siju-samuel/darknet/blob/master/cfg/' + \ + cfg_name + '?raw=true' +weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true' + +download(cfg_url, cfg_name) +download(weights_url, weights_name) + +###################################################################### +# Download and Load darknet library +# --------------------------------- + +darknet_lib = 'libdarknet.so' +darknetlib_url = 'https://github.com/siju-samuel/darknet/blob/master/lib/' + \ + darknet_lib + '?raw=true' +download(darknetlib_url, darknet_lib) + +#if the file doesnt exist, then exit normally. 
if not os.path.isfile('./' + darknet_lib):
    # sys.exit is used because the bare ``exit`` helper is only injected
    # by the site module and is not guaranteed to exist.
    sys.exit(0)

darknet_lib = __darknetffi__.dlopen('./' + darknet_lib)
cfg = "./" + str(cfg_name)
weights = "./" + str(weights_name)
net = darknet_lib.load_network(cfg.encode('utf-8'), weights.encode('utf-8'), 0)
dtype = 'float32'
batch_size = 1
print("Converting darknet to nnvm symbols...")
sym, params = nnvm.frontend.darknet.from_darknet(net, dtype)

######################################################################
# Compile the model on NNVM
# --------------------------------------------------------------------
# compile the model
data = np.empty([batch_size, net.c, net.h, net.w], dtype)
shape = {'data': data.shape}
print("Compiling the model...")
with nnvm.compiler.build_config(opt_level=2):
    graph, lib, params = nnvm.compiler.build(sym, target, shape, dtype, params)

#####################################################################
# Save the json
# --------------------------------------------------------------------
def save_lib():
    """Save the graph, params and .so to the current directory.

    All three files share the prefix ``nnvm_darknet_<model_name>`` and
    are written from the module-level ``lib``, ``graph`` and ``params``.
    """
    print("Saving the compiled output...")
    path_name = 'nnvm_darknet_' + model_name
    path_lib = path_name + '_deploy_lib.so'
    lib.export_library(path_lib)
    # Use the same "_deploy_" separator as the library file name.
    with open(path_name + "_deploy_graph.json", "w") as fo:
        fo.write(graph.json())
    with open(path_name + "_deploy_param.params", "wb") as fo:
        fo.write(nnvm.compiler.save_param_dict(params))
#save_lib()

######################################################################
# Load a test image
# --------------------------------------------------------------------
print("Loading the test image...")
img_url = 'https://github.com/siju-samuel/darknet/blob/master/data/' + \
            test_image + '?raw=true'
download(img_url, test_image)

data = nnvm.testing.darknet.load_image(test_image, net.w, net.h)

######################################################################
# Execute on TVM
# --------------------------------------------------------------------
# The process is no different from other examples.
from tvm.contrib import graph_runtime

m = graph_runtime.create(graph, lib, ctx)

# set inputs
m.set_input('data', tvm.nd.array(data.astype(dtype)))
m.set_input(**params)
# execute
print("Running the test image...")

m.run()
# get outputs
out_shape = (net.outputs,)
tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()

# do the detection and bring up the bounding boxes
thresh = 0.24
hier_thresh = 0.5
img = nnvm.testing.darknet.load_image_color(test_image)
_, im_h, im_w = img.shape
probs = []
boxes = []
# The last layer of the network is used as the region layer.
region_layer = net.layers[net.n - 1]
boxes, probs = nnvm.testing.yolo2_detection.get_region_boxes(
    region_layer, im_w, im_h, net.w, net.h,
    thresh, probs, boxes, 1, tvm_out)

boxes, probs = nnvm.testing.yolo2_detection.do_nms_sort(
    boxes, probs,
    region_layer.w*region_layer.h*region_layer.n, region_layer.classes, 0.3)

coco_name = 'coco.names'
coco_url = 'https://github.com/siju-samuel/darknet/blob/master/data/' + coco_name + '?raw=true'
font_name = 'arial.ttf'
font_url = 'https://github.com/siju-samuel/darknet/blob/master/data/' + font_name + '?raw=true'
download(coco_url, coco_name)
download(font_url, font_name)

with open(coco_name) as f:
    content = f.readlines()

names = [x.strip() for x in content]

nnvm.testing.yolo2_detection.draw_detections(
    img, region_layer.w*region_layer.h*region_layer.n,
    thresh, boxes, probs, names, region_layer.classes)
# img is transposed with axes (1, 2, 0) before display.
plt.imshow(img.transpose(1, 2, 0))
plt.show()