diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 283da4b15f9a0..2c96a0745150c 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -314,22 +314,6 @@ struct GlobalPool2DAttrs : public tvm::AttrsNode { } }; -/*! \brief Attributes for adaptive pool operator */ -struct AdaptivePool2DAttrs : public tvm::AttrsNode { - Array output_size; - std::string layout; - - TVM_DECLARE_ATTRS(AdaptivePool2DAttrs, "relay.attrs.AdaptivePool2DAttrs") { - TVM_ATTR_FIELD(output_size).set_default(Array({})) - .describe("Output height and width."); - TVM_ATTR_FIELD(layout).set_default("NCHW") - .describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Convolution is applied on the 'H' and" - "'W' dimensions."); - } -}; - /*! \brief Attributes for dense operator */ struct DenseAttrs : public tvm::AttrsNode { diff --git a/nnvm/include/nnvm/top/nn.h b/nnvm/include/nnvm/top/nn.h index ed4e964383eb3..a75f39fb54560 100644 --- a/nnvm/include/nnvm/top/nn.h +++ b/nnvm/include/nnvm/top/nn.h @@ -394,6 +394,39 @@ struct GlobalPool2DParam : public dmlc::Parameter { } }; + +struct AdaptiveMaxPool2DParam : public dmlc::Parameter { + TShape output_size; + std::string layout; + + DMLC_DECLARE_PARAMETER(AdaptiveMaxPool2DParam) { + DMLC_DECLARE_FIELD(output_size) + .describe("Output height and width"); + DMLC_DECLARE_FIELD(layout).set_default("NCHW") + .describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. 
Convolution is applied on the 'H' and" + "'W' dimensions."); + } +}; + + +struct AdaptiveAvgPool2DParam : public dmlc::Parameter { + TShape output_size; + std::string layout; + + DMLC_DECLARE_PARAMETER(AdaptiveAvgPool2DParam) { + DMLC_DECLARE_FIELD(output_size) + .describe("Output height and width"); + DMLC_DECLARE_FIELD(layout).set_default("NCHW") + .describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Convolution is applied on the 'H' and" + "'W' dimensions."); + } +}; + + struct UpSamplingParam : public dmlc::Parameter { int scale; std::string layout; diff --git a/nnvm/python/nnvm/frontend/__init__.py b/nnvm/python/nnvm/frontend/__init__.py index 49f53df1174f2..e2aee09141780 100644 --- a/nnvm/python/nnvm/frontend/__init__.py +++ b/nnvm/python/nnvm/frontend/__init__.py @@ -7,3 +7,4 @@ from .darknet import from_darknet from .tensorflow import from_tensorflow from .caffe2 import from_caffe2 +from .pytorch import from_pytorch diff --git a/nnvm/python/nnvm/frontend/pytorch/__init__.py b/nnvm/python/nnvm/frontend/pytorch/__init__.py new file mode 100644 index 0000000000000..2e1f99f70b1b0 --- /dev/null +++ b/nnvm/python/nnvm/frontend/pytorch/__init__.py @@ -0,0 +1,2 @@ +r'''PyTorch->NNVM converter''' +from .converter import from_pytorch diff --git a/nnvm/python/nnvm/frontend/pytorch/aten.py b/nnvm/python/nnvm/frontend/pytorch/aten.py new file mode 100644 index 0000000000000..8c1aecb532637 --- /dev/null +++ b/nnvm/python/nnvm/frontend/pytorch/aten.py @@ -0,0 +1,680 @@ +r'''This file contains one class per PyTorch ATen operator. 
For the full +list of operators, see +https://github.com/zdevito/ATen/blob/master/aten/src/ATen/native/native_functions.yaml +''' +import operator +from functools import reduce +import numpy as np +import tvm +from nnvm.symbol import Symbol +from .base import PyTorchOp, attr_2d, make_symbol + + +class ATenOp(PyTorchOp): + r'''Base class for ATen operators''' + + def __init__(self, node, graph): + super(ATenOp, self).__init__(node, graph) + self.dtype = 'float32' + +class Device(ATenOp): + r'''aten::device operator''' + + def __init__(self, node, graph): + super(Device, self).__init__(node, graph) + self.set_output(0, self.get_output_name(0), None) + + +class AllSame(ATenOp): + r'''Base class of aten::ones and aten::zeros''' + + def __init__(self, node, graph, val): + super(AllSame, self).__init__(node, graph) + val = float(val) + shape = self.get_input(0).get_output(0) + if not shape: + self.set_output(0, self.get_output_name(0), val) + else: + attrs = { + 'shape': shape, + 'dtype': 'float32', + 'fill_value': val, + } + self.set_output(0, self.get_output_name(0), + make_symbol('full', **attrs)) + + +class Ones(AllSame): + r'''aten::ones operator''' + def __init__(self, node, graph): + super(Ones, self).__init__(node, graph, 1) + + +class Zeros(AllSame): + r'''aten::zeros operator''' + def __init__(self, node, graph): + super(Zeros, self).__init__(node, graph, 0) + + +class HardTanh(ATenOp): + r'''aten::hardtanh and aten::hardtanh_ operators''' + + def __init__(self, node, graph): + super(HardTanh, self).__init__(node, graph) + inputs = [self.get_input(0).get_output(0)] + attrs = { + 'a_min': self.get_input(1).get_output(0), + 'a_max': self.get_input(2).get_output(0), + } + self.set_output(0, self.get_output_name(0), + make_symbol('clip', *inputs, **attrs)) + + +class Conv2D(ATenOp): + r'''aten::_convolution operator''' + + def __init__(self, node, graph): + super(Conv2D, self).__init__(node, graph) + if self.get_input(6).get_output(0): + topi_name = 
'conv2d_transpose' + else: + topi_name = 'conv2d' + inputs = [self.get_input(i).get_output(0) for i in [0, 1, 2]] + if inputs[2] is None: + bias = np.zeros([self.inputs[1].shape[0]]).astype('float32') + self.graph.add_param(self.get_input_name(2), bias) + inputs[2] = self.graph[self.get_input_name(2)].get_output(0) + weight_shape = self.get_input(1).shape + attrs = { + 'channels': weight_shape[0], + 'kernel_size': weight_shape[2:], + 'strides': self.get_input(3).get_output(0), + 'padding': self.get_input(4).get_output(0), + 'dilation': self.get_input(5).get_output(0), + 'groups': self.get_input(8).get_output(0), + 'kernel_layout': 'HWIO', + } + self.set_output(0, self.get_output_name(0), + make_symbol(topi_name, *inputs, **attrs)) + + +class Threshold(ATenOp): + r'''aten::threshold operator. Returns constant if input is less than or + equal to threshold. Otherwise, returns input.''' + + def __init__(self, node, graph): + super(Threshold, self).__init__(node, graph) + inputs = [self.get_input(0).get_output(0)] + attrs = { + 'threshold': self.get_input(1).get_output(0), + 'constant': self.get_input(2).get_output(0), + } + if attrs['threshold'] != attrs['constant']: + msg = 'For aten::threshold_, threshold != constant is not ' \ + 'implemented.' 
+ raise RuntimeError(msg) + self.set_output(0, self.get_output_name(0), + make_symbol('relu', *inputs, **attrs)) + + +class Pad(ATenOp): + r'''aten::constant_pad_nd operator (pad pairs are listed last dimension first)''' + + def __init__(self, node, graph): + super(Pad, self).__init__(node, graph) + inputs = [self.get_input(0).get_output(0)] + padding = self.get_input(1).get_output(0) + # NNVM wants one (before, after) pair per dimension, first dimension first + pad_width = [(0, 0)] * (len(self.get_input(0).shape) - len(padding) // 2) + pad_width.extend(zip(padding[-2::-2], padding[::-2])) + attrs = {'pad_width': pad_width, 'pad_value': self.get_input(2).get_output(0)} + self.set_output(0, self.get_output_name(0), + make_symbol('pad', *inputs, **attrs)) + + +class BatchNorm(ATenOp): + r'''aten::batch_norm operator''' + + def __init__(self, node, graph): + super(BatchNorm, self).__init__(node, graph) + self.topi_name = 'batch_norm' + inputs = [self.get_input(i).get_output(0) for i in range(5)] + attrs = { + 'epsilon': self.get_input(7).get_output(0), + } + self.set_output(0, self.get_output_name(0), + make_symbol('batch_norm', *inputs, **attrs)) + + +class Concatenate(ATenOp): + r'''aten::cat operator''' + + def __init__(self, node, graph): + super(Concatenate, self).__init__(node, graph) + inputs = self.get_input(0).get_output(0) + attrs = { + 'axis': self.get_input(1).get_output(0), + } + self.set_output(0, self.get_output_name(0), + make_symbol('concatenate', *inputs, **attrs)) + + +class PermuteAxes(ATenOp): + r'''aten::t, aten::transpose, aten::permute operators''' + + def __init__(self, node, graph): + super(PermuteAxes, self).__init__(node, graph) + ndims = len(self.get_input(0).shape) + axes = list(range(ndims)) + num_inputs = len(self.inputs) + if num_inputs == 1: + axes[-1] = ndims - 2 + axes[-2] = ndims - 1 + elif num_inputs == 3: + parse = lambda i: ndims * (i < 0) + i + src, dst = [parse(self.get_input(i).get_output(0)) for i in [1, 2]] + axes[src] = dst + axes[dst] = src + else: + axes = self.get_input(1).get_output(0) + attrs = { + 'axes': axes, + } + inputs = [self.get_input(0).get_output(0)] + self.set_output(0, self.get_output_name(0), + make_symbol('transpose', *inputs, 
**attrs)) + + +class Size(ATenOp): + r'''aten::size operator''' + + def __init__(self, node, graph): + super(Size, self).__init__(node, graph) + axis = self.get_input(1).get_output(0) + self.set_output(0, self.get_output_name(0), + self.get_input(0).shape[axis]) + + +class View(ATenOp): + r'''aten::view operator''' + + def __init__(self, node, graph): + super(View, self).__init__(node, graph) + inputs = [self.get_input(0).get_output(0)] + attrs = { + 'shape': self.get_input(1).get_output(0), + } + self.set_output(0, self.get_output_name(0), + make_symbol('reshape', *inputs, **attrs)) + + +class Select(ATenOp): + r'''aten::select operator (negative dim and index count from the end)''' + + def __init__(self, node, graph): + super(Select, self).__init__(node, graph) + inputs = [self.get_input(0).get_output(0)] + end = self.get_input(0).shape[:] + self._dim = self.get_input(1).get_output(0) % len(end)  # wrap a negative dim + index = self.get_input(2).get_output(0) % end[self._dim]  # wrap a negative index + end[self._dim] = index + 1 + begin = [0] * len(end) + begin[self._dim] = index + self.attrs = { + 'begin': begin, + 'end': end, + 'stride': 1, + } + sym = make_symbol('strided_slice', *inputs, **self.attrs) + inputs = [sym] + attrs = { + 'axis': self._dim, + } + self.set_output(0, self.get_output_name(0), + make_symbol('squeeze', *inputs, **attrs)) + + @property + def shape(self): + r'''Get the shape''' + if not hasattr(self, '_shape'): + begin = np.array(self.attrs['begin']).astype(int) + end = np.array(self.attrs['end']).astype(int) + shape = (end - begin).tolist() + return shape[:self._dim] + shape[self._dim + 1:] + + +class Copy(ATenOp): + r'''aten::copy operator''' + + def __init__(self, node, graph): + super(Copy, self).__init__(node, graph) + inputs = [self.get_input(0).get_output(0)] + self.set_output(0, self.get_output_name(0), + make_symbol('copy', *inputs)) + + +class ReLU(ATenOp): + r'''aten::relu and aten::relu_ operators''' + + def __init__(self, node, graph): + super(ReLU, self).__init__(node, graph) + inputs = [self.get_input(0).get_output(0)] + self.set_output(0, 
self.get_output_name(0), + make_symbol('relu', *inputs)) + + +class LogSoftmax(ATenOp): + r'''aten::log_softmax operator''' + + def __init__(self, node, graph): + super(LogSoftmax, self).__init__(node, graph) + inputs = [self.get_input(0).get_output(0)] + self.set_output(0, self.get_output_name(0), + make_symbol('log_softmax', *inputs)) + + +class Sigmoid(ATenOp): + r'''aten::sigmoid operator''' + + def __init__(self, node, graph): + super(Sigmoid, self).__init__(node, graph) + inputs = [self.get_input(0).get_output(0)] + self.set_output(0, self.get_output_name(0), + make_symbol('sigmoid', *inputs)) + + +class MatMul(ATenOp): + r'''aten::matmul operator''' + + def __init__(self, node, graph): + super(MatMul, self).__init__(node, graph) + inputs = [self.get_input(i).get_output(0) for i in range(2)] + self.set_output(0, self.get_output_name(0), + make_symbol('matmul', *inputs)) + + +class Dense(ATenOp): + r'''aten::addmm operator: bias scaled by beta, plus mat1 @ mat2 scaled by alpha''' + + def __init__(self, node, graph): + super(Dense, self).__init__(node, graph) + inputs = [self.get_input(i).get_output(0) for i in [1, 2, 0]] + units = self.get_input(2).shape[1] + attrs = { + 'units': units, + } + alpha = self.get_input(4).get_output(0) + beta = self.get_input(3).get_output(0) + inputs[0], inputs[2] = alpha * inputs[0], beta * inputs[2]  # scale mat1 and the bias + inputs[1] = make_symbol('transpose', inputs[1], axes=[1, 0])  # dense wants (units, in) + self.set_output(0, self.get_output_name(0), + make_symbol('dense', *inputs, **attrs)) + + +class MaxPool2D(ATenOp): + r'''aten::max_pool2d_with_indices operator''' + + def __init__(self, node, graph): + super(MaxPool2D, self).__init__(node, graph) + inputs = [self.get_input(0).get_output(0)] + if attr_2d(self.get_input(4).get_output(0), 1) != [1, 1]: + raise RuntimeError('Only dilation = 1 supported') + attrs = { + 'pool_size': attr_2d(self.get_input(1).get_output(0)), + 'strides': attr_2d(self.get_input(2).get_output(0) or self.get_input(1).get_output(0)), + 'padding': attr_2d(self.get_input(3).get_output(0), 0), + 'ceil_mode': self.get_input(5).get_output(0), + } + self.set_output(0, 
self.get_output_name(0), + make_symbol('max_pool2d', *inputs, **attrs)) + + +class AvgPool2D(ATenOp): + r'''aten::avg_pool2d operator''' + + def __init__(self, node, graph): + super(AvgPool2D, self).__init__(node, graph) + self.topi_name = 'avg_pool2d' + inputs = [self.get_input(0).get_output(0)] + attrs = { + 'pool_size': attr_2d(self.get_input(1).get_output(0)), + 'strides': attr_2d(self.get_input(2).get_output(0) or self.get_input(1).get_output(0)), + 'padding': attr_2d(self.get_input(3).get_output(0), 0), + 'ceil_mode': self.get_input(4).get_output(0), + 'count_include_pad': self.get_input(5).get_output(0), + } + self.set_output(0, self.get_output_name(0), + make_symbol('avg_pool2d', *inputs, **attrs)) + + +class AdaptivePool2D(ATenOp): + r'''Base class for adaptive pooling operators such as + aten::adaptive_avg_pool2d and aten::adaptive_max_pool2d''' + + def __init__(self, node, graph, pool_type): + super(AdaptivePool2D, self).__init__(node, graph) + topi_name = 'adaptive_{}_pool2d'.format(pool_type) + inputs = [self.get_input(0).get_output(0)] + attrs = { + 'output_size': self.get_input(1).get_output(0), + } + self.set_output(0, self.get_output_name(0), + make_symbol(topi_name, *inputs, **attrs)) + + +class AdaptiveAvgPool2D(AdaptivePool2D): + r'''aten::adaptive_avg_pool2d operator''' + + def __init__(self, node, graph): + super(AdaptiveAvgPool2D, self).__init__(node, graph, 'avg') + + +class AdaptiveMaxPool2D(AdaptivePool2D): + r'''aten::adaptive_max_pool2d operator''' + + def __init__(self, node, graph): + super(AdaptiveMaxPool2D, self).__init__(node, graph, 'max') + + +class Dropout(ATenOp): + r'''aten::dropout operator (input 1 is the drop probability, input 2 the train flag)''' + + def __init__(self, node, graph): + super(Dropout, self).__init__(node, graph) + inputs = [self.get_input(0).get_output(0)] + if self.get_input(2).get_output(0): + rate = self.get_input(1).get_output(0) + else: + rate = 0 + attrs = { + 'rate': rate, + } + self.set_output(0, self.get_output_name(0), + make_symbol('dropout', *inputs, **attrs)) + + +class 
Slice(ATenOp): + r'''aten::slice operator''' + + def __init__(self, node, graph): + super(Slice, self).__init__(node, graph) + inputs = [self.get_input(0).get_output(0)] + end = self.get_input(0).shape[:] + begin = [0] * len(end) + dim = self.get_input(1).get_output(0) + begin[dim] = self.get_input(2).get_output(0) + end[dim] = min(end[dim], self.get_input(3).get_output(0)) + attrs = { + 'begin': begin, + 'end': end, + 'stride': self.get_input(4).get_output(0), + } + self.set_output(0, self.get_output_name(0), + make_symbol('strided_slice', *inputs, **attrs)) + + +class BinaryOp(ATenOp): + r'''Base class for binary operators such as aten::add and aten::mul''' + + def __init__(self, node, graph, operator_name): + def prep(node): + out = node.get_output(0) + if isinstance(out, Symbol): + return out + if isinstance(out, tvm.nd.NDArray): + out = out.asnumpy() + return float(out) + ATenOp.__init__(self, node, graph) + linput, rinput = [prep(self.get_input(i)) for i in [0, 1]] + if not all(isinstance(inp, Symbol) for inp in [linput, rinput]): + self.set_output(0, self.get_output_name(0), + reduce(getattr(operator, operator_name), [linput, rinput])) + else: + topi_name = 'broadcast_' + operator_name + self.set_output(0, self.get_output_name(0), + make_symbol(topi_name, linput, rinput)) + + +class Subtract(BinaryOp): + r'''aten::sub and aten::sub_ operators''' + + def __init__(self, node, graph): + super(Subtract, self).__init__(node, graph, 'sub') + + +class Add(BinaryOp): + r'''aten::add and aten::add_ operators''' + + def __init__(self, node, graph): + super(Add, self).__init__(node, graph, 'add') + + +class Multiply(BinaryOp): + r'''aten::mul and aten::mul_ operators''' + + def __init__(self, node, graph): + super(Multiply, self).__init__(node, graph, 'mul') + + +class Divide(BinaryOp): + r'''aten::div and aten::div_ operators''' + + def __init__(self, node, graph): + super(Divide, self).__init__(node, graph, 'div') + + +class Unsqueeze(ATenOp): + r'''aten::unsqueeze 
operator''' + + def __init__(self, node, graph): + super(Unsqueeze, self).__init__(node, graph) + inputs = [self.get_input(0).get_output(0)] + axis = self.get_input(1).get_output(0) + self.set_output(0, self.get_output_name(0), + make_symbol('expand_dims', *inputs, axis=axis)) + + +class Expand(ATenOp): + r'''aten::expand operator''' + + def __init__(self, node, graph): + super(Expand, self).__init__(node, graph) + inputs = [self.get_input(0).get_output(0)] + shape = self.get_input(0).shape + ndims = len(shape) + sizes = self.get_input(1).get_output(0) + self._shape = [max(shape[i], sizes[i]) for i in range(ndims)] + # every axis whose requested size differs is tiled by concatenating copies + out = self.get_input(0).get_output(0) + for i in range(ndims): + if sizes[i] in {-1, shape[i]}: + continue + inputs = [out] * sizes[i] + out = make_symbol('concatenate', *inputs, axis=i) + self.set_output(0, self.get_output_name(0), out) + + +class To(ATenOp): + r'''aten::to operator''' + + def __init__(self, node, graph): + super(To, self).__init__(node, graph) + self.set_output(0, self.get_output_name(0), + self.get_input(0).get_output(0)) + + +class Pow(ATenOp): + r'''aten::pow operator''' + + def __init__(self, node, graph): + super(Pow, self).__init__(node, graph) + val = self.get_input(1).get_output(0) + self.set_output(0, self.get_output_name(0), + self.get_input(0).get_output(0) ** val) + + +class Chunk(ATenOp): + r'''aten::chunk operator''' + + def __init__(self, node, graph): + super(Chunk, self).__init__(node, graph) + num_chunks = self.get_input(1).get_output(0) + axis = self.get_input(2).get_output(0) + shape = self.get_input(0).shape + dim = int(shape[axis]) + # torch.chunk: every chunk holds ceil(dim / num_chunks) elements except + # possibly the last one, which holds the remainder. + unif_size = -(-dim // num_chunks) + tail = dim % unif_size + chunks = [] + for i in range(0, dim - tail, unif_size): + begin = [0] * len(shape) + end = shape[:] + begin[axis] = i + end[axis] = i + unif_size + attrs = { + 'begin': begin, + 'end': end, + 'stride': [1] * len(shape), + } + chunk = 
make_symbol('strided_slice', + self.get_input(0).get_output(0), + **attrs) + chunks.append(chunk) + if tail: + begin = [0] * len(shape) + end = shape[:] + begin[axis] = dim - tail + end[axis] = dim + attrs = { + 'begin': begin, + 'end': end, + 'stride': [1] * len(shape), + } + chunk = make_symbol('strided_slice', + self.get_input(0).get_output(0), + **attrs) + chunks.append(chunk) + self.set_output(0, self.get_output_name(0), chunks) + + +class Reduce(ATenOp): + r'''Base class for reduce operations such as aten::max, aten::sum, and aten::prod''' + + def __init__(self, node, graph, topi_name): + super(Reduce, self).__init__(node, graph) + if len(self.inputs) > 1: + inputs = [self.get_input(0).get_output(0)] + axis = self.get_input(1).get_output(0) + else: + inputs = [self.get_input(0).get_output(0)] + axis = list(range(len(self.inputs[0].shape))) + self.set_output(0, self.get_output_name(0), + make_symbol(topi_name, *inputs, axis=axis)) + + +class Max(BinaryOp, Reduce, ATenOp): + r'''Converts all aten::max operations, including both the binary op and the reduce op''' + + def __init__(self, node, graph): + def is_binary_op_arg(node): + out = node.get_output(0) + return isinstance(out, (Symbol, tvm.nd.NDArray)) + ATenOp.__init__(self, node, graph) + if len(self.inputs) > 1: + if all(is_binary_op_arg(self.get_input(i)) for i in [0, 1]): + BinaryOp.__init__(self, node, graph, 'max') + return + Reduce.__init__(self, node, graph, 'max') + + +class Sum(Reduce): + r'''Sum over all elements of the input tensor or along specified axes''' + + def __init__(self, node, graph): + super(Sum, self).__init__(node, graph, 'sum') + + +class Min(Reduce): + r'''Compute the min over all elements of the input tensor or along specified axes''' + + def __init__(self, node, graph): + super(Min, self).__init__(node, graph, 'min') + + +class Prod(Reduce): + r'''Compute the product of all elements of the input tensor or along specified axes''' + + def 
__init__(self, node, graph): + super(Prod, self).__init__(node, graph, 'prod') + + +class Mean(Reduce): + r'''Compute the mean of all elements of the input tensor or along specified axes''' + def __init__(self, node, graph): + super(Mean, self).__init__(node, graph, 'mean') + + +class Sqrt(ATenOp): + r'''Compute the elementwise square root''' + + def __init__(self, node, graph): + super(Sqrt, self).__init__(node, graph) + inputs = [self.get_input(0).get_output(0)] + self.set_output(0, self.get_output_name(0), + make_symbol('sqrt', *inputs)) + + +ATEN_MAP = { + 'device': Device, + 'ones': Ones, + 'zeros': Zeros, + 'hardtanh': HardTanh, + 'hardtanh_': HardTanh, + '_convolution': Conv2D, + 'threshold': Threshold, + 'threshold_': Threshold, + 'constant_pad_nd': Pad, + 'contiguous': Copy, + 'batch_norm': BatchNorm, + 'cat': Concatenate, + 't': PermuteAxes, + 'transpose': PermuteAxes, + 'transpose_': PermuteAxes, + 'permute': PermuteAxes, + 'size': Size, + 'view': View, + 'select': Select, + 'clone': Copy, + 'relu': ReLU, + 'relu_': ReLU, + 'log_softmax': LogSoftmax, + 'sigmoid': Sigmoid, + 'addmm': Dense, + 'matmul': MatMul, + 'max_pool2d_with_indices': MaxPool2D, + 'avg_pool2d': AvgPool2D, + 'adaptive_max_pool2d': AdaptiveMaxPool2D, + 'adaptive_avg_pool2d': AdaptiveAvgPool2D, + 'dropout': Dropout, + 'slice': Slice, + 'sub': Subtract, + 'sub_': Subtract, + 'add': Add, + 'add_': Add, + 'mul': Multiply, + 'mul_': Multiply, + 'div': Divide, + 'div_': Divide, + 'unsqueeze': Unsqueeze, + 'expand': Expand, + 'to': To, + 'pow': Pow, + 'chunk': Chunk, + 'max': Max, + 'sum': Sum, + 'min': Min, + 'prod': Prod, + 'mean': Mean, + 'sqrt': Sqrt, +} diff --git a/nnvm/python/nnvm/frontend/pytorch/base.py b/nnvm/python/nnvm/frontend/pytorch/base.py new file mode 100644 index 0000000000000..f792803c9f83d --- /dev/null +++ b/nnvm/python/nnvm/frontend/pytorch/base.py @@ -0,0 +1,174 @@ +r'''Basic classes for PyTorch operators and graphs''' +from collections import OrderedDict +from 
nnvm.symbol import Variable, Symbol +from nnvm.frontend.common import get_nnvm_op +from nnvm.compiler.graph_util import infer_shape +from nnvm.graph import create + + +def make_symbol(topi_name, *inputs, **attrs): + r'''Create an NNVM symbol given a Topi name, inputs, and attrs.''' + return get_nnvm_op(topi_name)(*inputs, **attrs) + + +def attr_2d(val, default=None): + r'''Helper function for computing attributes of 2D functions''' + if not val: + return [default] * 2 + if isinstance(val, list): + return val + return [int(val)] * 2 + + +class PyTorchGraph: + r'''Wrapper for the PyTorch JIT IR graph''' + + def __init__(self): + self.inputs = OrderedDict() + self.params = OrderedDict() + self.ops = OrderedDict() + self.outputs = OrderedDict() + + def __getitem__(self, name): + if name in self.inputs: + return self.inputs[name] + if name in self.params: + return self.params[name] + if name in self.ops: + return self.ops[name] + if name in self.outputs: + return self.outputs[name] + raise RuntimeError('Node {} not found.'.format(name)) + + def __contains__(self, name): + attrs = ['inputs', 'params', 'ops', 'outputs'] + return any(name in getattr(self, k) for k in attrs) + + def add_input(self, name, tensor): + r'''Add an input of the PyTorch model''' + self.inputs[name] = PyTorchInput(name, tensor, self) + + def add_param(self, name, tensor): + r'''Add a param of the PyTorch model''' + self.params[name] = PyTorchParam(name, tensor.astype('float32'), self) + + def add_op(self, op_node): + r'''Add an operator and its associated outputs of the PyTorch model''' + self.ops[op_node.name] = op_node + for i in range(len(op_node.outputs)): + self.outputs[op_node.output_names[i]] = op_node.outputs[i] + + +class PyTorchNode: + r'''Base class for PyTorch scalar, tensors, and operators''' + + def __init__(self, graph): + self.graph = graph + self.input_names = [] + self.inputs = [] + self.output_names = [] + self.outputs = [] + + def get_output_name(self, index): + r'''Get the name 
of the output at the given index''' + return self.output_names[index] + + def get_output(self, index): + r'''Get the parsed output at the given index''' + return self.outputs[index] + + def set_output(self, index, name, val): + r'''Set the output at the given index with the specified name and value''' + while len(self.output_names) <= index: + self.output_names.append('') + while len(self.outputs) <= index: + self.outputs.append(None) + self.output_names[index] = name + self.outputs[index] = val + + +class PyTorchConstantTensor(PyTorchNode): + r'''Base class for PyTorch input tensors and parameter tensors''' + + def __init__(self, name, arr, graph): + super(PyTorchConstantTensor, self).__init__(graph) + self.name = name + self.arr = arr + self.dtype = self.arr.dtype.name + output = Variable(name=self.name, shape=self.shape, + dtype=self.dtype) + self.set_output(0, name, output) + + @property + def shape(self): + r'''Get the shape of the tensor''' + return list(self.arr.shape) + + +class PyTorchInput(PyTorchConstantTensor): + r'''PyTorch input tensors''' + + def __init__(self, name, arr, graph): + super(PyTorchInput, self).__init__(name, arr, graph) + self.kind = 'input' + + +class PyTorchParam(PyTorchConstantTensor): + r'''PyTorch parameter tensors''' + + def __init__(self, name, arr, graph): + super(PyTorchParam, self).__init__(name, arr, graph) + self.kind = 'param' + + +class PyTorchOutput(PyTorchNode): + r'''PyTorch output tensors and scalars''' + + def __init__(self, name, val, graph): + super(PyTorchOutput, self).__init__(graph) + if isinstance(val, Symbol): + self._shape = infer_shape(create(val))[1][0] + self.set_output(0, name, val) + + @property + def shape(self): + r'''Get the shape of the output''' + return self._shape[:] + + +class PyTorchOp(PyTorchNode): + r'''Base class for PyTorch Prim and ATen operators''' + + def __init__(self, node, graph): + super(PyTorchOp, self).__init__(graph) + self.kind = node.kind() + self.name = self.kind + '_' + 
str(len(self.graph.ops)) + self.input_names = [] + self.inputs = [] + for index, inp in enumerate(node.inputs()): + input_name = inp.uniqueName() + self.set_input(index, input_name, graph[input_name]) + for out in node.outputs(): + self.output_names.append(out.uniqueName()) + self._node = node + + def get_input_name(self, index): + r'''Get the input name at the given index''' + return self.input_names[index] + + def get_input(self, index): + r'''Get the parsed input at the specified index''' + return self.inputs[index] + + def set_input(self, index, name, val): + r'''Set the input at the given index with the specified name and value''' + while len(self.input_names) <= index: + self.input_names.append('') + while len(self.inputs) <= index: + self.inputs.append(None) + self.input_names[index] = name + self.inputs[index] = val + + def set_output(self, index, name, val): + node = PyTorchOutput(name, val, self.graph) + super(PyTorchOp, self).set_output(index, name, node) diff --git a/nnvm/python/nnvm/frontend/pytorch/converter.py b/nnvm/python/nnvm/frontend/pytorch/converter.py new file mode 100644 index 0000000000000..d14d10205914b --- /dev/null +++ b/nnvm/python/nnvm/frontend/pytorch/converter.py @@ -0,0 +1,127 @@ +r'''Convert PyTorch models to NNVM symbol graphs''' +from pickle import UnpicklingError +import tvm +from nnvm.symbol import Symbol, Group +import numpy as np +import torch +from .aten import ATEN_MAP +from .prim import PRIM_MAP +from .base import PyTorchGraph + + +def operator_map(kind): + namespace, op_name = kind.split('::') + return { + 'aten': ATEN_MAP, + 'prim': PRIM_MAP, + }[namespace][op_name] + + +class PyTorchConverter: + r'''Converter from PyTorch JIT IR to NNVM''' + + def __init__(self, filename, input_shapes): + self._load_model(filename, input_shapes) + self._num_inputs = len(input_shapes) + self.graph = PyTorchGraph() + self._parse_inputs(input_shapes) + self._parse_params() + self._parse_ops() + + def _load_model(self, filename, 
input_shapes): + try: + self._trace = torch.jit.load(filename).float().eval() + except RuntimeError: + try: + self._trace = torch.load(filename).float().eval() + except UnpicklingError: + raise RuntimeError('Failed to load model') + shapes = [input_shapes[k] for k in sorted(input_shapes)] + inputs = [torch.zeros(shape).float() for shape in shapes] + try: + self._trace = torch.jit.trace(self._trace, tuple(inputs)).float().eval()  # one tuple of examples + except RuntimeError: + inputs = [inp.cuda() for inp in inputs] + self._trace = torch.jit.trace(self._trace, tuple(inputs)).float().eval() + + @property + def _ir_tensor_names(self): + return [i.uniqueName() for i in self._trace.graph.inputs()] + + def _parse_inputs(self, input_shapes): + input_names = sorted(input_shapes) + ir_names = self._ir_tensor_names[:self._num_inputs] + ir_name_map = dict(zip(input_names, ir_names)) + inv_ir_name_map = dict((v, k) for k, v in ir_name_map.items()) + for i, inp in enumerate(self._trace.graph.inputs()): + if i >= self._num_inputs: + break + ir_name = inp.uniqueName() + if ir_name in inv_ir_name_map: + inp.setUniqueName(inv_ir_name_map[ir_name]) + for input_name in sorted(input_shapes): + input_shape = input_shapes[input_name] + tensor = np.zeros(input_shape).astype(np.float32) + ir_name = ir_name_map[input_name] + for inp in self._trace.graph.inputs(): + if inp.uniqueName() == ir_name: + inp.setUniqueName(input_name) + break + self.graph.add_input(input_name, tensor) + + def _parse_params(self): + state_dict = self._trace.state_dict() + state_dict_names = list(state_dict.keys()) + ir_names = self._ir_tensor_names[self._num_inputs:] + name_map = dict(zip(state_dict_names, ir_names)) + for state_dict_name, param in state_dict.items(): + ir_name = name_map[state_dict_name] + tensor = param.cpu().numpy() + self.graph.add_param(ir_name, tensor) + + def _parse_ops(self): + unsupported_ops = set() + for node in self._trace.graph.nodes(): + kind = node.kind() + try: + operator_map(kind) + except KeyError: + 
unsupported_ops.add(kind) + if unsupported_ops: + ops_str = str(list(unsupported_ops)).strip('[]').replace("'", '') + msg = 'The following operators are not implemented: {}' + raise tvm.error.OpNotImplemented(msg.format(ops_str)) + for node in self._trace.graph.nodes(): + kind = node.kind() + self.graph.add_op(operator_map(kind)(node, self.graph)) + + def convert(self): + r'''Convert the parsed PyTorch model to an NNVM symbol graph and + parameter dict.''' + params = {k: tvm.nd.array(v.arr) for k, v in self.graph.params.items()} + incoming_nodes = set() + for name, op in self.graph.ops.items(): + incoming_nodes.update(op.input_names) + outputs = [] + for name in self.graph.ops: + for i in range(len(self.graph.ops[name].outputs)): + output_name = self.graph.ops[name].get_output_name(i) + node = self.graph.ops[name].get_output(i) + if output_name not in incoming_nodes: + output = node.get_output(0) + if isinstance(output, Symbol): + outputs.append(output) + elif isinstance(output, list): + is_symbol = lambda n: isinstance(n, Symbol) + outputs.extend(filter(is_symbol, output)) + if len(outputs) == 1: + output = outputs[0] + else: + output = Group(outputs) + return output, params + + +def from_pytorch(filename, input_shapes): + converter = PyTorchConverter(filename, input_shapes) + sym, params = converter.convert() + return sym, params diff --git a/nnvm/python/nnvm/frontend/pytorch/prim.py b/nnvm/python/nnvm/frontend/pytorch/prim.py new file mode 100644 index 0000000000000..7789e4e7365a5 --- /dev/null +++ b/nnvm/python/nnvm/frontend/pytorch/prim.py @@ -0,0 +1,109 @@ +r'''This file contains one class per PyTorch Prim operator. 
For the full list +of operators, see +https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/interned_strings.h +''' +import re +import tvm +import numpy as np +from .base import PyTorchOp + + +class PrimOp(PyTorchOp): + r'''Base class for Prim operators''' + + +class Constant(PrimOp): + r'''prim::Constant operator''' + + def __init__(self, node, graph): + super(Constant, self).__init__(node, graph) + output = next(node.outputs()) + type_kind = output.type().kind() + value = self._parse_value_from_string() + output_name = self.get_output_name(0) + if type_kind == 'IntType': + self.set_output(0, output_name, int(value)) + elif type_kind == 'FloatType': + self.set_output(0, output_name, value) + elif type_kind == 'BoolType': + self.set_output(0, output_name, bool(value)) + elif type_kind == 'CompleteTensorType' and output.type().sizes() == []: + self.shape = output.type().sizes() + arr = value * np.ones(self.shape).astype(float) + self.set_output(0, output_name, tvm.nd.array(arr)) + elif type_kind == 'StringType': + self.set_output(0, output_name, value) + else: + msg = 'Only "IntType", "FloatType", "BoolType", "StringType", and ' \ + '"CompleteTensorType" type-kinds are supported. For ' \ + '"CompleteTensorType", type-sizes must be [].' 
+ raise RuntimeError(msg) + + def _parse_value_from_string(self): + r'''For some reason, __getitem__ is sometimes stripped from the + torch._C.Node objects.''' + pattern = r'(?<=value=)[^]]+' + string = str(self._node) + value_string = re.findall(pattern, string)[0].strip('{}') + try: + return float(value_string) + except ValueError: + return None + + +class ListConstruct(PrimOp): + r'''prim::ListConstruct operator''' + + def __init__(self, node, graph): + super(ListConstruct, self).__init__(node, graph) + self.set_output(0, self.get_output_name(0), + [inp.get_output(0) for inp in self.inputs]) + + +class Int(PrimOp): + r'''prim::Int operator''' + + def __init__(self, node, graph): + super(Int, self).__init__(node, graph) + val = self.get_input(0).get_output(0).asnumpy() + self.set_output(0, self.get_output_name(0), int(val)) + + +class NumToTensor(PrimOp): + r'''prim::NumToTensor operator''' + + def __init__(self, node, graph): + super(NumToTensor, self).__init__(node, graph) + self.shape = [] + val = self.get_input(0).get_output(0) + dtype = type(val) + arr = val * np.ones(self.shape).astype(dtype) + self.set_output(0, self.get_output_name(0), tvm.nd.array(arr)) + + +class Undefined(PrimOp): + r'''prim::Undefined operator''' + + def __init__(self, node, graph): + super(Undefined, self).__init__(node, graph) + self.set_output(0, self.get_output_name(0), None) + +class ListUnpack(PrimOp): + r'''prim::ListUnpack operator''' + + def __init__(self, node, graph): + super(ListUnpack, self).__init__(node, graph) + for i in range(len(self.output_names)): + self.set_output(i, self.get_output_name(i), + self.get_input(0).get_output(0)[i]) + + +PRIM_MAP = { + 'Constant': Constant, + 'ListConstruct': ListConstruct, + 'TupleConstruct': ListConstruct, + 'Int': Int, + 'NumToTensor': NumToTensor, + 'Undefined': Undefined, + 'ListUnpack': ListUnpack, +} diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py index c496044788df6..a4fd222a5aee2 100644 --- 
a/nnvm/python/nnvm/top/nn.py +++ b/nnvm/python/nnvm/top/nn.py @@ -430,6 +430,27 @@ def schedule_global_avg_pool2d(_, outs, target): reg.register_pattern("global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) + +# adaptive_max_pool2d +@reg.register_schedule("adaptive_max_pool2d") +def schedule_adaptive_max_pool2d(attrs, outs, target): + """Schedule definition of adaptive_max_pool2d""" + with tvm.target.create(target): + return topi.generic.schedule_adaptive_pool(outs) + +reg.register_pattern("adaptive_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) + + +# adaptive_avg_pool2d +@reg.register_schedule("adaptive_avg_pool2d") +def schedule_adaptive_avg_pool2d(attrs, outs, target): + """Schedule definition of adaptive_avg_pool2d""" + with tvm.target.create(target): + return topi.generic.schedule_adaptive_pool(outs) + +reg.register_pattern("adaptive_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) + + # upsampling @reg.register_schedule("upsampling") def schedule_upsampling(_, outs, target): diff --git a/nnvm/src/top/nn/pooling.cc b/nnvm/src/top/nn/pooling.cc index 6a53e1994fc17..05202eab51dea 100644 --- a/nnvm/src/top/nn/pooling.cc +++ b/nnvm/src/top/nn/pooling.cc @@ -413,5 +413,163 @@ NNVM_REGISTER_OP(global_avg_pool2d) .set_num_inputs(1) .set_support_level(2); +DMLC_REGISTER_PARAMETER(AdaptiveMaxPool2DParam); + +template +inline bool AdaptivePool2DInferShape(const nnvm::NodeAttrs& attrs, + std::vector* in_shape, + std::vector* out_shape) { + const T& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_shape->size(), 1U); + CHECK_EQ(out_shape->size(), 1U); + + auto output_size = ShapeToArray(param.output_size); + + TShape dshape = (*in_shape)[0]; + if (dshape.ndim() == 0) return false; + + CHECK_GE(dshape.ndim(), 2U) + << "AdaptivePool2D only support input >= 2-D: input must have height and width"; + + Layout layout(param.layout); + CHECK(layout.contains('H') && layout.contains('W') && + !layout.contains('h') && !layout.contains('w')) + << "Invalid layout " << layout + << ". 
AdaptivePool2D layout must have H and W, which cannot be split"; + + const auto hidx = layout.indexof('H'); + const auto widx = layout.indexof('W'); + + TShape oshape = dshape; +// CHECK(output_size[0] <= dshape[hidx]) +// << "output height (" << output_size[0] << ") exceeds input (" +// << dshape[hidx] << ")"; +// CHECK(output_size[1] <= dshape[widx]) +// << "output width (" << output_size[1] << ") exceeds input (" +// << dshape[widx] << ")"; + + oshape[hidx] = param.output_size[0]; + oshape[widx] = param.output_size[1]; + + NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape); + return true; +} + +template +inline bool AdaptivePool2DCorrectLayout(const NodeAttrs& attrs, + std::vector *ilayouts, + const std::vector *last_ilayouts, + std::vector *olayouts) { + const T ¶m = nnvm::get(attrs.parsed); + CHECK_EQ(ilayouts->size(), 1); + CHECK_EQ(last_ilayouts->size(), 1); + CHECK_EQ(olayouts->size(), 1); + + Layout input = (*ilayouts)[0]; + const Layout layout(param.layout); + + if (input.defined()) { + CHECK(input.convertible(layout)) << "Invalid input layout " << input; + if (input.indexof('W') != layout.indexof('W') || + input.indexof('H') != layout.indexof('H') || + input.contains('w') || input.contains('h')) { + // as long as the index doesn't change for width and height + // pool2d can keep the input layout. + input = layout; + } + } else { + input = layout; + } + + NNVM_ASSIGN_LAYOUT(*ilayouts, 0, input); + NNVM_ASSIGN_LAYOUT(*olayouts, 0, input); + + return true; +} + +NNVM_REGISTER_OP(adaptive_max_pool2d) +.describe(R"code(Adaptive max pooling operation for one dimensional data. + +- **data**: This depends on the `layout` parameter. Input is 4D array of shape + (batch_size, channels, height, width) if `layout` is `NCHW`. +- **out**: This depends on the `layout` parameter. Output is 4D array of shape + (batch_size, channels, out_height, out_width) if `layout` is `NCHW`. 
+ +)code" NNVM_ADD_FILELINE) +.add_argument("data", "4D Tensor", "Input data.") +.add_arguments(AdaptiveMaxPool2DParam::__FIELDS__()) +.set_attr_parser(ParamParser) +.set_attr("FGetAttrDict", ParamGetAttrDict) +.set_num_outputs(1) +.set_num_inputs(1) +.set_attr("FInferShape", AdaptivePool2DInferShape) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FCorrectLayout", AdaptivePool2DCorrectLayout) +.set_attr("FTVMCompute", [](const NodeAttrs& attrs, + const Array& inputs, + const Array& out_info) { + const AdaptiveMaxPool2DParam& param = nnvm::get(attrs.parsed); + + auto output_size = ShapeToArray(param.output_size); + + Layout layout(param.layout); + CHECK(layout.convertible(Layout("NCHW"))) + << "max_pool2d currently only supports layouts that are convertible from NCHW"; + CHECK_EQ(layout.indexof('h'), -1) << "max_pool2d does not support input split on height"; + CHECK_EQ(layout.indexof('w'), -1) << "max_pool2d does not support input split on width"; + + CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U) + << "Pool2D only support 4-D input (e.g., NCHW)" + << " or 5-D input (last dimension is a split of channel)"; + + return Array{ + topi::nn::adaptive_pool(inputs[0], output_size, + topi::nn::kMaxPool, layout.name())}; +}) +.set_support_level(2); + +DMLC_REGISTER_PARAMETER(AdaptiveAvgPool2DParam); + +NNVM_REGISTER_OP(adaptive_avg_pool2d) +.describe(R"code(Adaptive average pooling operation for one dimensional data. + +- **data**: This depends on the `layout` parameter. Input is 4D array of shape + (batch_size, channels, height, width) if `layout` is `NCHW`. +- **out**: This depends on the `layout` parameter. Output is 4D array of shape + (batch_size, channels, out_height, out_width) if `layout` is `NCHW`. 
+ +)code" NNVM_ADD_FILELINE) +.add_argument("data", "4D Tensor", "Input data.") +.add_arguments(AdaptiveAvgPool2DParam::__FIELDS__()) +.set_attr_parser(ParamParser) +.set_attr("FGetAttrDict", ParamGetAttrDict) +.set_attr("FInferShape", AdaptivePool2DInferShape) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FCorrectLayout", AdaptivePool2DCorrectLayout) +.set_attr("FTVMCompute", [](const NodeAttrs& attrs, + const Array& inputs, + const Array& out_info) { + const AdaptiveAvgPool2DParam& param = nnvm::get(attrs.parsed); + + auto output_size = ShapeToArray(param.output_size); + + Layout layout(param.layout); + CHECK(layout.convertible(Layout("NCHW"))) + << "avg_pool2d currently only supports layouts that are convertible from NCHW"; + CHECK_EQ(layout.indexof('h'), -1) << "avg_pool2d does not support input split on height"; + CHECK_EQ(layout.indexof('w'), -1) << "avg_pool2d does not support input split on width"; + + CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U) + << "Pool2D only support 4-D input (e.g., NCHW)" + << " or 5-D input (last dimension is a split of channel)"; + + return Array{ + topi::nn::adaptive_pool(inputs[0], output_size, + topi::nn::kAvgPool, layout.name())}; +}) +.set_num_outputs(1) +.set_num_inputs(1) +.set_support_level(2); + } // namespace top } // namespace nnvm diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 74ef2740f845f..da58e9386c0f3 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -225,27 +225,6 @@ def schedule_global_avg_pool2d(_, outs, target): reg.register_pattern("nn.global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) - -# adaptive_max_pool2d -@reg.register_schedule("nn.adaptive_max_pool2d") -def schedule_adaptive_max_pool2d(_, outs, target): - """Schedule definition of adaptive_max_pool2d""" - with target: - return topi.generic.schedule_adaptive_pool(outs) - -reg.register_pattern("nn.adaptive_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) - - -# 
adaptive_avg_pool2d -@reg.register_schedule("nn.adaptive_avg_pool2d") -def schedule_adaptive_avg_pool2d(_, outs, target): - """Schedule definition of adaptive_avg_pool2d""" - with target: - return topi.generic.schedule_adaptive_pool(outs) - -reg.register_pattern("nn.adaptive_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) - - # leaky_relu reg.register_schedule("nn.leaky_relu", schedule_broadcast) reg.register_pattern("nn.leaky_relu", OpPattern.ELEMWISE) diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index db419eecd1a40..1a9e02a08c98d 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -374,99 +374,6 @@ def global_avg_pool2d(data, return _make.global_avg_pool2d(data, layout) -def adaptive_max_pool2d(data, - output_size=None, - layout="NCHW"): - r"""2D adaptive max pooling operator. - - This operator takes data as input and does 2D max value calculation - across each window represented by WxH. - - - In the default case, where the data_layout is `NCHW` - a data Tensor with shape `(batch_size, in_channels, height, width)`, - to produce an output Tensor with shape - (batch_size, in_channels, output_height, output_width). - - The pooling kernel and stride sizes are automatically chosen for - desired output sizes. - - For output_size: - If this argument is not provided, input height and width will be used - as output height and width. - - If a single integer is provided for output_size, the output size is - (N x C x output_size x output_size) for any input (NCHW). - - If a tuple of integers (height, width) are provided for output_size, - the output size is (N x C x height x width) for any input (NCHW). - - Parameters - ---------- - data : tvm.relay.Expr - The input data to the operator. - - output_size : tuple of int. optional - Output height and width. - - layout : str, optional - Layout of the input. - - Returns - ------- - result : tvm.relay.Expr - The computed result. 
- """ - output_size = [] or output_size - return _make.adaptive_max_pool2d(data, output_size, layout) - -def adaptive_avg_pool2d(data, - output_size=None, - layout="NCHW"): - r"""2D adaptive average pooling operator. - - This operator takes data as input and does 2D average value calculation - across each window represented by WxH. - - - In the default case, where the data_layout is `NCHW` - a data Tensor with shape `(batch_size, in_channels, height, width)`, - to produce an output Tensor with shape - (batch_size, in_channels, output_height, output_width). - - The pooling kernel and stride sizes are automatically chosen for - desired output sizes. - - For output_size: - If this argument is not provided, input height and width will be used - as output height and width. - - If a single integer is provided for output_size, the output size is - (N x C x output_size x output_size) for any input (NCHW). - - If a tuple of integers (height, width) are provided for output_size, - the output size is (N x C x height x width) for any input (NCHW). - - Parameters - ---------- - data : tvm.relay.Expr - The input data to the operator. - - output_size : tuple of int. optional - Output height and width. - - layout : str, optional - Layout of the input. - - Returns - ------- - result : tvm.relay.Expr - The computed result. 
- """ - output_size = [] or output_size - return _make.adaptive_avg_pool2d(data, output_size, layout) - - def upsampling(data, scale=1, layout="NCHW", diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index 4cebf2324a01b..23704693732b8 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -382,171 +382,5 @@ RELAY_REGISTER_OP("nn.global_max_pool2d") Pool2DInferCorrectLayout) .set_attr("FTVMCompute", GlobalPool2DCompute); - -// relay.nn.adaptive_pool_2d -TVM_REGISTER_NODE_TYPE(AdaptivePool2DAttrs); - -bool AdaptivePool2DRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(types.size(), 2); - const auto* data = types[0].as(); - if (data == nullptr) { return false; } - const auto dshape = data->shape; - CHECK_NE(dshape.size(), 0); - CHECK_GE(dshape.size(), 2U) - << "Pool2D only support input >= 2-D: input must have height and width"; - const auto param = attrs.as(); - CHECK(param != nullptr); - - Layout layout(param->layout); - CHECK(layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) && - !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w'))) - << "Invalid layout " << layout - << ". 
Pool2D layout must have H and W, which cannot be split"; - - const auto hidx = layout.IndexOf(LayoutAxis::Get('H')); - const auto widx = layout.IndexOf(LayoutAxis::Get('W')); - Array oshape(dshape); - auto output_size = param->output_size; - CHECK_LE(output_size.size(), 2U) - << "output_size can have up to 2 elements."; - IndexExpr output_height, output_width; - if (output_size.empty()) { - output_height = dshape[hidx]; - output_width = dshape[widx]; - } else if (output_size.size() == 1) { - output_height = output_size[0]; - output_width = output_size[0]; - } else { - output_height = output_size[0]; - output_width = output_size[1]; - } - - oshape.Set(hidx, output_height); - oshape.Set(widx, output_width); - - // assign output type - reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); - return true; -} - -template -Array AdaptivePool2DCompute(const Attrs& attrs, - const Array& inputs, - const Type& out_type, - const Target& target) { - static const Layout kNCHW("NCHW"); - const auto* param = attrs.as(); - CHECK(param != nullptr); - Layout layout(param->layout); - CHECK(BijectiveLayoutNode::make(layout, kNCHW).defined()) - << "Adaptive pool2d currently only supports layouts that are convertible from NCHW"; - CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1) - << "Adaptive pool2d does not support input split on height"; - CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1) - << "Adaptive pool2d does not support input split on width"; - - CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U) - << "Pool2D only support 4-D input (e.g., NCHW)" - << " or 5-D input (last dimension is a split of channel)"; - - auto output_size = param->output_size; - const auto hidx = layout.IndexOf(LayoutAxis::Get('H')); - const auto widx = layout.IndexOf(LayoutAxis::Get('W')); - IndexExpr output_height, output_width; - if (output_size.empty()) { - output_height = inputs[0]->shape[hidx]; - output_width = inputs[0]->shape[widx]; - } else if (output_size.size() == 1) { - 
output_height = output_size[0]; - output_width = output_size[0]; - } else { - output_height = output_size[0]; - output_width = output_size[1]; - } - return Array{ - topi::nn::adaptive_pool(inputs[0], Array{ output_height, output_width }, - mode, layout.name()) }; -} - -// relay.nn.adaptive_avg_pool2d -Expr MakeAdaptiveAvgPool2D(Expr data, - Array output_size, - std::string layout) { - auto attrs = make_node(); - attrs->output_size = std::move(output_size); - attrs->layout = std::move(layout); - static const Op& op = Op::Get("nn.adaptive_avg_pool2d"); - return CallNode::make(op, {data}, Attrs(attrs), {}); -} - -TVM_REGISTER_API("relay.op.nn._make.adaptive_avg_pool2d") -.set_body_typed(MakeAdaptiveAvgPool2D); - -RELAY_REGISTER_OP("nn.adaptive_avg_pool2d") - .describe(R"code(Adaptive average pooling operation for 2D data. - -- **data**: This depends on the `layout` parameter. Input is 4D array of shape - (batch_size, channels, height, width) if `layout` is `NCHW`. -- **output_size**: If this argument is not provided, input height and width will be used - as output height and width. - If a single integer is provided for output_size, the output size is - (N x C x output_size x output_size) for any input (NCHW). - If a tuple of integers (height, width) are provided for output_size, - the output size is (N x C x height x width) for any input (NCHW). -- **out**: This depends on the `layout` parameter. Output is 4D array of shape - (batch_size, channels, output_height, output_width) if `layout` is `NCHW`. 
- -)code" TVM_ADD_FILELINE) -.set_attrs_type_key("relay.attrs.AdaptivePool2DAttrs") -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("AdaptiveAvgPool2D", AdaptivePool2DRel) -.set_attr("FInferCorrectLayout", - Pool2DInferCorrectLayout) -.set_attr("FTVMCompute", AdaptivePool2DCompute); - - -// relay.nn.adaptive_max_pool2d -Expr MakeAdaptiveMaxPool2D(Expr data, - Array output_size, - std::string layout) { - auto attrs = make_node(); - attrs->output_size = std::move(output_size); - attrs->layout = std::move(layout); - static const Op& op = Op::Get("nn.adaptive_max_pool2d"); - return CallNode::make(op, {data}, Attrs(attrs), {}); -} - -TVM_REGISTER_API("relay.op.nn._make.adaptive_max_pool2d") -.set_body_typed(MakeAdaptiveMaxPool2D); - -RELAY_REGISTER_OP("nn.adaptive_max_pool2d") - .describe(R"code(Adaptive max pooling operation for 2D data. - -- **data**: This depends on the `layout` parameter. Input is 4D array of shape - (batch_size, channels, height, width) if `layout` is `NCHW`. -- **output_size**: If this argument is not provided, input height and width will be used - as output height and width. - If a single integer is provided for output_size, the output size is - (N x C x output_size x output_size) for any input (NCHW). - If a tuple of integers (height, width) are provided for output_size, - the output size is (N x C x height x width) for any input (NCHW). -- **out**: This depends on the `layout` parameter. Output is 4D array of shape - (batch_size, channels, output_height, output_width) if `layout` is `NCHW`. 
- -)code" TVM_ADD_FILELINE) -.set_attrs_type_key("relay.attrs.AdaptivePool2DAttrs") -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("AdaptiveMaxPool2D", AdaptivePool2DRel) -.set_attr("FInferCorrectLayout", - Pool2DInferCorrectLayout) -.set_attr("FTVMCompute", AdaptivePool2DCompute); - } // namespace relay } // namespace tvm diff --git a/tests/python/frontend/pytorch/mnist.py b/tests/python/frontend/pytorch/mnist.py new file mode 100644 index 0000000000000..16ff1c30f3cf1 --- /dev/null +++ b/tests/python/frontend/pytorch/mnist.py @@ -0,0 +1,24 @@ +r'''MNIST model''' +import torch.nn as nn +import torch.nn.functional as F + +class Net(nn.Module): + r'''MNIST model base on + https://github.com/pytorch/examples/blob/master/mnist/main.py''' + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 20, 5, 1) + self.conv2 = nn.Conv2d(20, 50, 5, 1) + self.fc1 = nn.Linear(4*4*50, 500) + self.fc2 = nn.Linear(500, 10) + + def forward(self, *args): + x = args[0] + x = F.relu(self.conv1(x)) + x = F.max_pool2d(x, 2, 2) + x = F.relu(self.conv2(x)) + x = F.max_pool2d(x, 2, 2) + x = x.view(-1, 4*4*50) + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return F.log_softmax(x, dim=1) diff --git a/tests/python/frontend/pytorch/mobilenet.py b/tests/python/frontend/pytorch/mobilenet.py new file mode 100644 index 0000000000000..34562b86c3130 --- /dev/null +++ b/tests/python/frontend/pytorch/mobilenet.py @@ -0,0 +1,180 @@ +r'''MobileNet V1 and V2''' +import math +import torch.nn as nn + + +def conv_bn(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU6(inplace=True) + ) + + +def conv_1x1_bn(inp, oup): + return nn.Sequential( + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU6(inplace=True) + ) + + +class InvertedResidual(nn.Module): + r'''InvertedResidual module''' + def __init__(self, inp, oup, stride, 
expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = round(inp * expand_ratio) + self.use_res_connect = self.stride == 1 and inp == oup + + if expand_ratio == 1: + self.conv = nn.Sequential( + # dw + nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), + nn.BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ) + else: + self.conv = nn.Sequential( + # pw + nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), + nn.BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + # dw + nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), + nn.BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ) + + def forward(self, *args): + x = args[0] + if self.use_res_connect: + return x + self.conv(x) + return self.conv(x) + +class MobileNetV2(nn.Module): + r'''MobileNet V2 model based on + https://github.com/tonylins/pytorch-mobilenet-v2/blob/master/MobileNetV2.py''' + def __init__(self, n_class=1000, input_size=224, width_mult=1.): + super(MobileNetV2, self).__init__() + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + interverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + [6, 160, 3, 2], + [6, 320, 1, 1], + ] + + # building first layer + assert input_size % 32 == 0 + input_channel = int(input_channel * width_mult) + self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel + self.features = [conv_bn(3, input_channel, 2)] + # building inverted residual blocks + for t, c, n, s in interverted_residual_setting: + output_channel = int(c * width_mult) + for i in range(n): + if i == 0: + self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) + else: 
+ self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) + input_channel = output_channel + # building last several layers + self.features.append(conv_1x1_bn(input_channel, self.last_channel)) + # make it nn.Sequential + self.features = nn.Sequential(*self.features) + + # building classifier + self.classifier = nn.Sequential( + nn.Dropout(0.2), + nn.Linear(self.last_channel, n_class), + ) + + self._initialize_weights() + + def forward(self, *args): + x = args[0] + x = self.features(x) + x = x.mean(3).mean(2) + x = self.classifier(x) + return x + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + n = m.weight.size(1) + m.weight.data.normal_(0, 0.01) + m.bias.data.zero_() + + +class MobileNetV1(nn.Module): + r'''MobileNet V1 model based on + https://github.com/marvis/pytorch-mobilenet''' + def __init__(self): + super(MobileNetV1, self).__init__() + + def conv_bn(inp, oup, stride): # pylint: disable=redefined-outer-name + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU(inplace=True) + ) + + def conv_dw(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), + nn.BatchNorm2d(inp), + nn.ReLU(inplace=True), + + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU(inplace=True), + ) + + self.model = nn.Sequential( + conv_bn(3, 32, 2), + conv_dw(32, 64, 1), + conv_dw(64, 128, 2), + conv_dw(128, 128, 1), + conv_dw(128, 256, 2), + conv_dw(256, 256, 1), + conv_dw(256, 512, 2), + conv_dw(512, 512, 1), + conv_dw(512, 512, 1), + conv_dw(512, 512, 1), + conv_dw(512, 512, 1), + conv_dw(512, 512, 1), + conv_dw(512, 1024, 2), + 
conv_dw(1024, 1024, 1), + nn.AvgPool2d(7), + ) + self.fc = nn.Linear(1024, 1000) # pylint: disable=invalid-name + + def forward(self, *args): + x = args[0] + x = self.model(x) + x = x.view(-1, 1024) + x = self.fc(x) + return x diff --git a/tests/python/frontend/pytorch/net_s3fd.py b/tests/python/frontend/pytorch/net_s3fd.py new file mode 100755 index 0000000000000..192d78e3b7b3c --- /dev/null +++ b/tests/python/frontend/pytorch/net_s3fd.py @@ -0,0 +1,122 @@ +import torch +import torch.nn as nn +from torch.autograd import Variable +import torch.nn.functional as F + +class L2Norm(nn.Module): + def __init__(self,n_channels, scale=1.0): + super(L2Norm,self).__init__() + self.n_channels = n_channels + self.scale = scale + self.eps = 1e-10 + self.weight = nn.Parameter(torch.Tensor(self.n_channels)) + self.weight.data *= 0.0 + self.weight.data += self.scale + + def forward(self, x): + norm = x.pow(2).sum(dim=1, keepdim=True).sqrt()+self.eps + x = x / norm * self.weight.view(1,-1,1,1) + return x + +class s3fd(nn.Module): + def __init__(self): + super(s3fd, self).__init__() + self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) + self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + + self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) + self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) + + self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) + self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + + self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) + self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + + self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv5_3 = nn.Conv2d(512, 512, 
kernel_size=3, stride=1, padding=1) + + self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3) + self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0) + + self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0) + self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1) + + self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0) + self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1) + + self.conv3_3_norm = L2Norm(256,scale=10) + self.conv4_3_norm = L2Norm(512,scale=8) + self.conv5_3_norm = L2Norm(512,scale=5) + + self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) + self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) + self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) + self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) + self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) + self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) + + self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1) + self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1) + self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) + self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) + self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1) + self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + h = F.relu(self.conv1_1(x)) + h = F.relu(self.conv1_2(h)) + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv2_1(h)) + h = F.relu(self.conv2_2(h)) + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv3_1(h)) + h = F.relu(self.conv3_2(h)) + h = F.relu(self.conv3_3(h)); f3_3 = h + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv4_1(h)) + h = F.relu(self.conv4_2(h)) + h 
= F.relu(self.conv4_3(h)); f4_3 = h + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv5_1(h)) + h = F.relu(self.conv5_2(h)) + h = F.relu(self.conv5_3(h)); f5_3 = h + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.fc6(h)) + h = F.relu(self.fc7(h)); ffc7 = h + h = F.relu(self.conv6_1(h)) + h = F.relu(self.conv6_2(h)); f6_2 = h + h = F.relu(self.conv7_1(h)) + h = F.relu(self.conv7_2(h)); f7_2 = h + + f3_3 = self.conv3_3_norm(f3_3) + f4_3 = self.conv4_3_norm(f4_3) + f5_3 = self.conv5_3_norm(f5_3) + + cls1 = self.conv3_3_norm_mbox_conf(f3_3) + reg1 = self.conv3_3_norm_mbox_loc(f3_3) + cls2 = self.conv4_3_norm_mbox_conf(f4_3) + reg2 = self.conv4_3_norm_mbox_loc(f4_3) + cls3 = self.conv5_3_norm_mbox_conf(f5_3) + reg3 = self.conv5_3_norm_mbox_loc(f5_3) + cls4 = self.fc7_mbox_conf(ffc7) + reg4 = self.fc7_mbox_loc(ffc7) + cls5 = self.conv6_2_mbox_conf(f6_2) + reg5 = self.conv6_2_mbox_loc(f6_2) + cls6 = self.conv7_2_mbox_conf(f7_2) + reg6 = self.conv7_2_mbox_loc(f7_2) + + # max-out background label + chunk = torch.chunk(cls1,4,1) + bmax = torch.max(torch.max(chunk[0],chunk[1]),chunk[2]) + cls1 = torch.cat([bmax,chunk[3]],dim=1) + + return (cls1,reg1,cls2,reg2,cls3,reg3,cls4,reg4,cls5,reg5,cls6,reg6) diff --git a/tests/python/frontend/pytorch/single_op.py b/tests/python/frontend/pytorch/single_op.py new file mode 100644 index 0000000000000..a4542b183e4fe --- /dev/null +++ b/tests/python/frontend/pytorch/single_op.py @@ -0,0 +1,262 @@ +r'''Models consisting of single operators''' +import torch +from torch.nn import Module + + +class Add1(Module): + def forward(self, *args): + return args[0] + args[0] + +class Add2(Module): + def forward(self, *args): + return args[0] + 1 + +class Add3(Module): + def forward(self, *args): + ones = torch.ones([1, 3, 224, 224]) + if torch.cuda.is_available(): + ones = ones.cuda() + return args[0] + ones + +class Add4(Module): + def forward(self, *args): + ones = torch.ones([1, 1, 224, 224]) + if torch.cuda.is_available(): + ones = ones.cuda() + 
return args[0] + ones + +class Add5(Module): + def forward(self, *args): + ones = torch.ones([]) + if torch.cuda.is_available(): + ones = ones.cuda() + return args[0] + ones + +class Subtract1(Module): + def forward(self, *args): + return args[0] - args[0] + +class Subtract2(Module): + def forward(self, *args): + return args[0] - 1 + +class Subtract3(Module): + def forward(self, *args): + ones = torch.ones([1, 3, 224, 224]) + if torch.cuda.is_available(): + ones = ones.cuda() + return args[0] - ones + +class Subtract4(Module): + def forward(self, *args): + ones = torch.ones([1, 1, 224, 224]) + if torch.cuda.is_available(): + ones = ones.cuda() + return args[0] - ones + +class Subtract5(Module): + def forward(self, *args): + ones = torch.ones([]) + if torch.cuda.is_available(): + ones = ones.cuda() + return args[0] - ones + +class Multiply1(Module): + def forward(self, *args): + return args[0] * args[0] + +class Multiply2(Module): + def forward(self, *args): + return args[0] * 1 + +class Multiply3(Module): + def forward(self, *args): + ones = torch.ones([1, 3, 224, 224]) + if torch.cuda.is_available(): + ones = ones.cuda() + return args[0] * ones + +class Multiply4(Module): + def forward(self, *args): + ones = torch.ones([1, 1, 224, 224]) + if torch.cuda.is_available(): + ones = ones.cuda() + return args[0] * ones + +class Multiply5(Module): + def forward(self, *args): + ones = torch.ones([]) + if torch.cuda.is_available(): + ones = ones.cuda() + return args[0] * ones + +class Unsqueeze1(Module): + def forward(self, *args): + return args[0].unsqueeze(2) + +class Concatenate1(Module): + def forward(self, *args): + return torch.cat([args[0][:, 0].unsqueeze(1), args[0][:, 1].unsqueeze(1)], 1) + +class Concatenate2(Module): + def forward(self, *args): + a = (args[0][:, :, 0] + 2) * 7 + b = (args[0][:, :, 1] + 3) * 11 + c = (args[0][:, :, 2] + 5) * 13 + return torch.cat([t.unsqueeze(2) for t in [a, b, c]], 2) + +class ReLU1(Module): + def forward(self, *args): + return 
class Pad1(Module):
    r'''Single-operator model that pads the last two dimensions of its input
    with a constant border of three elements on every side.'''
    def forward(self, *args):
        # ConstantPad2d is defined in torch.nn (there is no
        # torch.ConstantPad2d) and its fill value is a required argument.
        # NOTE(review): zero fill is assumed here, confirm it is the padding
        # value this test is meant to exercise.
        return torch.nn.ConstantPad2d(3, 0.0)(args[0])
3) + +class Transpose2(Module): + def forward(self, *args): + return args[0].transpose(-2, -1) + +class Transpose3(Module): + def forward(self, *args): + return args[0].t() + +class Size1(Module): + def forward(self, *args): + return args[0].size(0) * args[0] + +class View1(Module): + def forward(self, *args): + return args[0].view((1, 3 * 224 * 224)) + +class View2(Module): + def forward(self, *args): + return args[0].view(args[0].shape[0], -1) + +class Select1(Module): + def forward(self, *args): + return args[0].select(1, 1) + +class Clone1(Module): + def forward(self, *args): + return args[0].clone() + +class LogSoftmax1(Module): + def forward(self, *args): + return torch.nn.LogSoftmax(dim=1)(args[0][0, 0]) + +class Sigmoid1(Module): + def forward(self, *args): + return torch.nn.Sigmoid()(args[0]) + +class Dense1(Module): + def __init__(self): + super(Dense1, self).__init__() + self.linear = torch.nn.Linear(224, 7) + def forward(self, *args): + return self.linear(args[0][0, 0]) + +class AvgPool2D1(Module): + def forward(self, *args): + return torch.nn.AvgPool2d(kernel_size=[100, 100])(args[0]) + +class Dropout1(Module): + def forward(self, *args): + return torch.nn.functional.dropout(args[0][0, 0], 0.5, False) + +class Slice1(Module): + def forward(self, *args): + return args[0][:, :, :, :3] + +class Slice2(Module): + def forward(self, *args): + return args[0][0, :, :, :] + +class Mean1(Module): + def forward(self, *args): + return args[0].mean(2) + +class Expand1(Module): + def forward(self, *args): + return args[0].expand((3, -1, -1, -1)) + +class Pow1(Module): + def forward(self, *args): + return args[0] ** 2 + +class Chunk1(Module): + def forward(self, *args): + chunks = args[0].chunk(7, 2) + return torch.cat(chunks, 2) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py new file mode 100644 index 0000000000000..b9c41ff654543 --- /dev/null +++ b/tests/python/frontend/pytorch/test_forward.py @@ -0,0 
def rtol(tru, est):
    r'''Returns the largest element-wise relative error between `tru` and
    `est`.

    Both inputs are flattened with `reshape(-1)` and compared pairwise. The
    relative error of a pair is `abs(tru - est)` divided by the smaller of the
    two magnitudes. Pairs that are exactly equal (including 0 vs 0) contribute
    0 instead of the nan produced by 0 / 0, and a non-zero difference against
    a zero magnitude contributes infinity. Empty inputs yield 0.0.'''
    def _rtol_elt(tru, est):
        err = abs(tru - est)
        if err == 0:
            # Identical values: avoid 0 / 0, which would poison max() with nan.
            return 0.0
        scale = min(abs(tru), abs(est))
        if scale == 0:
            return float('inf')
        return err / scale
    # Same flattening as the module's _vectorize helper, inlined so this
    # function has no dependency on it.
    tru = tru.reshape(-1)
    est = est.reshape(-1)
    return max((_rtol_elt(x, y) for x, y in zip(tru, est)), default=0.0)
def load_model(model_name):
    r'''Given a model name, returns a model as well as an example input.

    The name is matched, in order, against torchvision models, names starting
    with 'mobilenet', 'mnist', the classes of the single_op module, 'fastai',
    'sfd' and finally the optional pretrainedmodels package.

    Raises ModuleNotFoundError when the name reaches the pretrainedmodels
    fallback and that package is not installed, and RuntimeError when no
    loader recognizes the name.'''
    if hasattr(torchvision.models, model_name):
        return load_torchvision(model_name)
    if model_name.startswith('mobilenet'):
        return load_mobilenet(model_name)
    if model_name == 'mnist':
        return load_mnist()
    if hasattr(single_op, model_name):
        return load_single_op(model_name)
    if model_name == 'fastai':
        return load_fastai()
    if model_name == 'sfd':
        return load_sfd()
    # pretrainedmodels is an optional dependency that is not imported at
    # module level, so it has to be imported here before it can be inspected.
    # Without the import the lookup below raises NameError, which the
    # ModuleNotFoundError handler never sees.
    try:
        import pretrainedmodels
    except ModuleNotFoundError as err:
        raise ModuleNotFoundError('Please install pretrainedmodels.pytorch') from err
    if hasattr(pretrainedmodels, model_name):
        return load_pretrainedmodels(model_name)
    raise RuntimeError('Model not supported')
def verify_model(model_name):
    r'''Assert that the output of a compiled model matches with that of its
    baseline.

    The baseline PyTorch model is run once, traced with torch.jit, saved to a
    temporary file, converted with nnvm.frontend.from_pytorch, built for
    TARGET and executed on CTX. Every output is then compared against the
    baseline for shape and for value (rtol=1e-5, atol=1e-5).

    Side effect: the contents of ~/.torch/models are deleted when the function
    exits, whether or not the comparison succeeded.'''
    try:
        baseline_model, baseline_input = load_model(model_name)
        if torch.cuda.is_available():
            baseline_model = baseline_model.cuda()
            baseline_input = baseline_input.cuda()
        baseline_outputs = baseline_model(baseline_input)
        # Normalize to a tuple of numpy arrays so that single-output and
        # multi-output models go through the same comparison loop.
        if isinstance(baseline_outputs, tuple):
            baseline_outputs = tuple(out.detach().cpu().numpy() for out in baseline_outputs)
        else:
            baseline_outputs = (baseline_outputs.detach().cpu().numpy(),)
        # Only needed by the disabled latency measurement further down.
        output_shapes = [out.shape for out in baseline_outputs]
        dtype = 'float32'
        input_name = 'input0'
        input_shapes = {input_name: list(baseline_input.shape)}
        # NOTE(review): this second forward pass discards its result; it is
        # kept because its purpose cannot be determined from here.
        baseline_model(baseline_input)
        trace = torch.jit.trace(baseline_model, baseline_input).float().eval()
        with TemporaryDirectory() as tmp:
            path = os.path.join(tmp, 'model.pth')
            torch.jit.save(trace, path)
            sym, params = nnvm.frontend.from_pytorch(path, input_shapes)
        compiled_input = {input_name: tvm.nd.array(baseline_input.cpu().numpy())}
        graph, lib, params = nnvm.compiler.build(sym, TARGET, input_shapes,
                                                 dtype=dtype,
                                                 params=params)
        compiled_model = tvm.contrib.graph_runtime.create(graph, lib, CTX)
        compiled_model.set_input(**params)
        compiled_model.set_input(**compiled_input)
        compiled_model.run()
        for i, baseline_output in enumerate(baseline_outputs):
            output_shape = baseline_output.shape
            compiled_output = compiled_model.get_output(
                i, tvm.nd.array(np.zeros(output_shape).astype(dtype), CTX)).asnumpy()
            assert_shapes_match(baseline_output, compiled_output)
            tvm.testing.assert_allclose(baseline_output, compiled_output,
                                        rtol=1e-5, atol=1e-5)
#        thresh = 1e-2
#        units = 1e3
#        input_shapes = list(input_shapes.values())
#        baseline_latency = measure_latency(baseline_model, input_shapes,
#                                           output_shapes, thresh) * units
#        compiled_latency = measure_latency(compiled_model, input_shapes,
#                                           output_shapes, thresh) * units
#        thresh = int(thresh * units)
#        print(f'Baseline latency is {baseline_latency:.3f} +/- {thresh:d} ms.')
#        print(f'Compiled latency is {compiled_latency:.3f} +/- {thresh:d} ms.')
    finally:
        # Run the cleanup even when a comparison fails, so a failing model
        # does not leave its cached files behind for the rest of the run.
        # NOTE(review): this wipes the user's whole ~/.torch/models cache via
        # the shell, confirm that is acceptable outside of CI.
        from subprocess import call
        call('rm -rf ~/.torch/models/*', shell=True)
+def verify_hardtanh1(): + verify_model('HardTanh1') + +def verify_conv2d1(): + verify_model('Conv2D1') + +def verify_conv2d2(): + verify_model('Conv2D2') + +def verify_threshold1(): + verify_model('Threshold1') + +def verify_contiguous1(): + verify_model('Contiguous1') + +def verify_batchnorm1(): + verify_model('BatchNorm1') + +def verify_batchnorm2(): + verify_model('BatchNorm2') + +def verify_transpose1(): + verify_model('Transpose1') + +def verify_transpose2(): + verify_model('Transpose2') + +def verify_size1(): + verify_model('Size1') + +def verify_view1(): + verify_model('View1') + +def verify_view2(): + verify_model('View2') + +def verify_select1(): + verify_model('Select1') + +def verify_clone1(): + verify_model('Clone1') + +def verify_logsoftmax1(): + verify_model('LogSoftmax1') + +def verify_sigmoid1(): + verify_model('Sigmoid1') + +def verify_dense1(): + verify_model('Dense1') + +def verify_avgpool2d1(): + verify_model('AvgPool2D1') + +def verify_dropout1(): + verify_model('Dropout1') + +def verify_slice1(): + verify_model('Slice1') + +def verify_slice2(): + verify_model('Slice2') + +def verify_mean1(): + verify_model('Mean1') + +def verify_expand1(): + verify_model('Expand1') + +def verify_pow1(): + verify_model('Pow1') + +def verify_chunk1(): + verify_model('Chunk1') + +def verify_alexnet(): + verify_model('alexnet') + +def verify_densenet121(): + verify_model('densenet121') + +def verify_densenet161(): + verify_model('densenet161') + +def verify_densenet169(): + verify_model('densenet169') + +def verify_densenet201(): + verify_model('densenet201') + +def verify_inception_v3(): + verify_model('inception_v3') + +def verify_resnet101(): + verify_model('resnet101') + +def verify_resnet152(): + verify_model('resnet152') + +def verify_resnet18(): + verify_model('resnet18') + +def verify_resnet34(): + verify_model('resnet34') + +def verify_resnet50(): + verify_model('resnet50') + +def verify_squeezenet1_0(): + verify_model('squeezenet1_0') + +def 
verify_squeezenet1_1(): + verify_model('squeezenet1_1') + +def verify_vgg11(): + verify_model('vgg11') + +def verify_vgg11_bn(): + verify_model('vgg11_bn') + +def verify_vgg13(): + verify_model('vgg13') + +def verify_vgg13_bn(): + verify_model('vgg13_bn') + +def verify_vgg16(): + verify_model('vgg16') + +def verify_vgg16_bn(): + verify_model('vgg16_bn') + +def verify_vgg19(): + verify_model('vgg19') + +def verify_vgg19_bn(): + verify_model('vgg19_bn') + +def verify_sfd(): + verify_model('sfd') + +def verify_mobilenet_v1(): + verify_model('mobilenet_v1') + +def verify_mobilenet_v2(): + verify_model('mobilenet_v2') + +def verify_mnist(): + verify_model('mnist') + +def verify_fastai(): + verify_model('fastai') + + +if __name__ == '__main__': + verify_mobilenet_v1() + verify_mobilenet_v2() + verify_mnist() + verify_fastai() + verify_sfd() + verify_add1() + verify_add2() + verify_add3() + verify_add4() + verify_add5() + verify_subtract1() + verify_subtract2() + verify_subtract3() + verify_subtract4() + verify_subtract5() + verify_multiply1() + verify_multiply2() + verify_multiply3() + verify_multiply4() + verify_multiply5() + verify_unsqueeze1() + verify_concatenate1() + verify_concatenate2() + verify_relu1() + verify_adaptiveavgpool2d1() + verify_adaptiveavgpool2d2() + verify_adaptiveavgpool2d3() + verify_maxpool2d1() + verify_maxpool2d2() + verify_maxpool2d3() + verify_hardtanh1() + verify_conv2d1() + verify_conv2d2() + verify_threshold1() + verify_contiguous1() + verify_batchnorm1() + verify_batchnorm2() + verify_transpose1() + verify_transpose2() + verify_size1() + verify_view1() + verify_view2() + verify_select1() + verify_clone1() + verify_logsoftmax1() + verify_sigmoid1() + verify_dense1() + verify_avgpool2d1() + verify_dropout1() + verify_slice1() + verify_slice2() + verify_mean1() + verify_expand1() + verify_pow1() + verify_chunk1() + verify_alexnet() + verify_densenet121() + verify_densenet161() + verify_densenet169() + verify_densenet201() + 
verify_resnet101() + verify_resnet152() + verify_resnet18() + verify_resnet34() + verify_resnet50() + verify_squeezenet1_0() + verify_squeezenet1_1() + verify_vgg11() + verify_vgg11_bn() + verify_vgg13() + verify_vgg13_bn() + verify_vgg16() + verify_vgg16_bn() + verify_vgg19() + verify_vgg19_bn() + verify_inception_v3() diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 27a570e768b21..c8a38565ac7aa 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -278,48 +278,6 @@ def test_avg_pool2d_no_count_pad(): tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) -def verify_adaptive_pool2d(dshape, out_size, pool_type, layout="NCHW", dtype="float32"): - def start_index(index, odim, idim): - return int(np.floor(index * idim / odim)) - - def end_index(index, odim, idim): - return int(np.ceil((index + 1) * idim / odim)) - - np_data = np.random.uniform(low=0, high=255, size=dshape).astype(dtype) - n, c, h, w = dshape - oh, ow = out_size - oshape = (n, c) + out_size - np_out = np.zeros(oshape).astype(dtype) - np_op = np.mean if pool_type == "avg" else np.max - for i in range(n): - for j in range(c): - for k in range(oh): - k_start = start_index(k, oh, h) - k_end = end_index(k, oh, h) - k_sl = slice(k_start, k_end) - for l in range(ow): - l_start = start_index(l, ow, w) - l_end = end_index(l, ow, w) - l_sl = slice(l_start, l_end) - np_out[i, j, k, l] = np_op(np_data[i, j, k_sl, l_sl]) - - opfunc = relay.nn.adaptive_avg_pool2d if pool_type == "avg" else relay.nn.adaptive_max_pool2d - x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) - y = opfunc(x, out_size, layout) - func = relay.Function([x], y) - - for target, ctx in ctx_list(): - intrp1 = relay.create_executor("graph", ctx=ctx, target=target) - relay_out = intrp1.evaluate(func)(np_data) - tvm.testing.assert_allclose(relay_out.asnumpy(), np_out, rtol=1e-5, atol=1e-5) - -def test_adaptive_pool2d(): - 
verify_adaptive_pool2d((1, 9, 224, 224), (1, 1), "max") - verify_adaptive_pool2d((1, 3, 224, 224), (2, 3), "avg") - verify_adaptive_pool2d((1, 14, 56, 78), (34, 13), "max") - verify_adaptive_pool2d((1, 5, 46, 97), (4, 96), "avg") - - def test_flatten_infer_type(): d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4") x = relay.var("x", relay.TensorType((d1, d2, d3, d4), "float32")) @@ -508,7 +466,6 @@ def test_upsampling(): if __name__ == "__main__": test_pool2d() test_avg_pool2d_no_count_pad() - test_adaptive_pool2d() test_lrn() test_l2_normalize() test_conv2d_infer_type() diff --git a/tests/scripts/task_python_frontend.sh b/tests/scripts/task_python_frontend.sh index 4a192ef142ff2..2ae1cc604343c 100755 --- a/tests/scripts/task_python_frontend.sh +++ b/tests/scripts/task_python_frontend.sh @@ -4,6 +4,8 @@ set -e set -u export PYTHONPATH=nnvm/python:python:topi/python +export PYTHONPATH=$PYTHONPATH:.local/lib/python3.6/site-packages + # to avoid openblas threading error export OMP_NUM_THREADS=1 @@ -56,3 +58,6 @@ python3 -m nose -v tests/python/frontend/tensorflow echo "Running relay caffe2 frondend test..." python3 -m nose -v tests/python/frontend/caffe2 + +echo "Running nnvm PyTorch frontend test..." +python3 -m nose -v tests/python/frontend/pytorch diff --git a/tests/scripts/task_python_vta.sh b/tests/scripts/task_python_vta.sh index ea71fda178647..dad476ec5314d 100755 --- a/tests/scripts/task_python_vta.sh +++ b/tests/scripts/task_python_vta.sh @@ -12,10 +12,15 @@ rm -rf ~/.tvm make cython make cython3 +echo "Installing PyTorch dependencies for testing..." +pip3 install torch --user +pip3 install torchvision --user +export PYTHONPATH=$PYTHONPATH:.local/lib/python3.6/site-packages + echo "Running unittest..." -python -m nose -v vta/tests/python/unittest +#python -m nose -v vta/tests/python/unittest python3 -m nose -v vta/tests/python/unittest echo "Running integration test..." 
-python -m nose -v vta/tests/python/integration +#python -m nose -v vta/tests/python/integration python3 -m nose -v vta/tests/python/integration