diff --git a/python/tvm/contrib/image.py b/python/tvm/contrib/image.py
new file mode 100644
index 0000000000000..acac195ac9895
--- /dev/null
+++ b/python/tvm/contrib/image.py
@@ -0,0 +1,35 @@
+"""Common image utilities"""
+from __future__ import absolute_import as _abs
+from __future__ import division
+import math
+import numpy as np
+
+
+def bilinear_weights(image, new_h, new_w, layout):
+    """Helper function to generate weights for bilinear scaling"""
+
+    if layout == "NHWC":
+        (height, width) = image.shape[1:3]
+    elif layout == "NCHW":
+        (height, width) = image.shape[2:]
+    else:
+        raise NotImplementedError(
+            'Layout {} not supported'.format(layout))
+
+    x_ratio = (width-1)/new_w
+    y_ratio = (height-1)/new_h
+
+    def _bilinear_interpolation(y, x):
+        x_coord = math.floor(x_ratio * x)
+        y_coord = math.floor(y_ratio * y)
+        x_diff = (x_ratio * x) - x_coord
+        y_diff = (y_ratio * y) - y_coord
+
+        return [y_coord, x_coord, y_diff, x_diff]
+
+    weights = np.empty([new_h, new_w, 4], dtype='float32')
+
+    for i in range(new_h):
+        for j in range(new_w):
+            weights[i][j] = _bilinear_interpolation(i, j)
+    return weights
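For reference, the tensor generated by `bilinear_weights` has shape `[new_h, new_w, 4]`, one `[y, x, y_diff, x_diff]` record per output pixel. A minimal sketch, assuming this patch is applied (the tiny NHWC image is purely illustrative):

```python
import numpy as np
from tvm.contrib.image import bilinear_weights

image = np.zeros((1, 4, 4, 3), dtype='float32')   # dummy NHWC input
weights = bilinear_weights(image, 8, 8, 'NHWC')   # scale 4x4 -> 8x8

print(weights.shape)                              # (8, 8, 4)
y, x, y_diff, x_diff = weights[5][3]
# (y, x) is the top-left corner of the 2x2 source patch to blend;
# (y_diff, x_diff) are the fractional offsets used as blend factors.
```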
diff --git a/topi/include/topi/nn/scale.h b/topi/include/topi/nn/scale.h
new file mode 100644
index 0000000000000..37a81ea0098f0
--- /dev/null
+++ b/topi/include/topi/nn/scale.h
@@ -0,0 +1,258 @@
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file topi/nn/scale.h
+ * \brief Scale op constructors
+ */
+#ifndef TOPI_NN_SCALE_H_
+#define TOPI_NN_SCALE_H_
+
+#include <algorithm>
+#include <iterator>
+#include <string>
+#include <vector>
+
+#include "topi/tags.h"
+#include "topi/detail/ravel_unravel.h"
+#include "topi/detail/constant_utils.h"
+#include "tvm/tvm.h"
+
+namespace topi {
+namespace nn {
+using namespace tvm;
+
+/*!
+* \brief Resize given tensor to given shape using nearest neighbour for NHWC
+*
+* \param inputs The input tensor array.
+* \param shape Output shape to scale to.
+* \param name Name of the operation
+* \param tag The tag to mark the operation
+*
+* \return A Tensor resized to given shape
+*/
+inline Tensor scale_nn_nhwc(const Array<Tensor>& inputs,
+                            Array<Expr> shape,
+                            std::string name = "tensor",
+                            std::string tag = kInjective) {
+  Array<Expr> out_shape;
+  out_shape.push_back(inputs[0]->shape[0]);
+  out_shape.push_back(shape[0]);
+  out_shape.push_back(shape[1]);
+  out_shape.push_back(inputs[0]->shape[3]);
+
+  Expr h_scale = shape[0] / inputs[0]->shape[1];
+  Expr w_scale = shape[1] / inputs[0]->shape[2];
+
+  return compute(
+    out_shape, [&](const Array<Var>& indices) {
+      Array<Expr> idx;
+      idx.push_back(indices[0]);
+      idx.push_back(indices[1] / h_scale);
+      idx.push_back(indices[2] / w_scale);
+      idx.push_back(indices[3]);
+
+      return inputs[0](idx);
+    }, name, tag);
+}
+
+/*!
+* \brief Resize given tensor to given shape using nearest neighbour for NCHW
+*
+* \param inputs The input tensor array.
+* \param shape Output shape to scale to.
+* \param name Name of the operation
+* \param tag The tag to mark the operation
+*
+* \return A Tensor resized to given shape
+*/
+inline Tensor scale_nn_nchw(const Array<Tensor>& inputs,
+                            Array<Expr> shape,
+                            std::string name = "tensor",
+                            std::string tag = kInjective) {
+  Array<Expr> out_shape;
+  out_shape.push_back(inputs[0]->shape[0]);
+  out_shape.push_back(inputs[0]->shape[1]);
+  out_shape.push_back(shape[0]);
+  out_shape.push_back(shape[1]);
+
+  Expr h_scale = shape[0] / inputs[0]->shape[2];
+  Expr w_scale = shape[1] / inputs[0]->shape[3];
+
+  return compute(
+    out_shape, [&](const Array<Var>& indices) {
+      Array<Expr> idx;
+      idx.push_back(indices[0]);
+      idx.push_back(indices[1]);
+      idx.push_back(indices[2] / h_scale);
+      idx.push_back(indices[3] / w_scale);
+
+      return inputs[0](idx);
+    }, name, tag);
+}
+
+/*!
+* \brief Resize given tensor to given shape using nearest neighbour
+*
+* \param inputs The input tensor array.
+* \param shape Output shape to scale to.
+* \param layout Input layout
+* \param name Name of the operation
+* \param tag The tag to mark the operation
+*
+* \return A Tensor resized to given shape
+*/
+inline Tensor scale_nn(const Array<Tensor>& inputs,
+                       Array<Expr> shape,
+                       std::string layout = "NCHW",
+                       std::string name = "tensor",
+                       std::string tag = kInjective) {
+  if (layout == "NHWC") {
+    return scale_nn_nhwc(inputs, shape, name, tag);
+  } else {
+    return scale_nn_nchw(inputs, shape, name, tag);
+  }
+}
+
+/*!
+* \brief Resize given tensor to given shape using bilinear interpolation for NHWC
+*
+* \param inputs The input tensor array.
+* \param shape Output shape to scale to.
+* \param name Name of the operation
+* \param tag The tag to mark the operation
+*
+* \return A Tensor resized to given shape
+*/
+inline Tensor scale_bilinear_nhwc(const Array<Tensor>& inputs,
+                                  Array<Expr> shape,
+                                  std::string name = "tensor",
+                                  std::string tag = kInjective) {
+  Array<Expr> out_shape;
+  out_shape.push_back(inputs[0]->shape[0]);
+  out_shape.push_back(shape[0]);
+  out_shape.push_back(shape[1]);
+  out_shape.push_back(inputs[0]->shape[3]);
+
+  Array<Expr> split_ind;
+  split_ind.push_back(make_const(UInt(32), 2));
+
+  Array<Tensor> weights = split(inputs[1], split_ind, 2);
+
+  Tensor coords = cast(weights[0], Int(32));
+
+  Expr cone = make_const(UInt(32), 1);
+
+  return compute(
+    out_shape, [&](const Array<Var>& indices) {
+      auto y1 = coords(indices[1], indices[2], 0);
+      auto x1 = coords(indices[1], indices[2], 1);
+      auto h = weights[1](indices[1], indices[2], 0);
+      auto w = weights[1](indices[1], indices[2], 1);
+
+      auto A = inputs[0](indices[0], y1, x1, indices[3]);
+      auto B = inputs[0](indices[0], y1, x1+cone, indices[3]);
+      auto C = inputs[0](indices[0], y1+cone, x1, indices[3]);
+      auto D = inputs[0](indices[0], y1+cone, x1+cone, indices[3]);
+
+      return (A*(cone-w)*(cone-h) + B*(w)*(cone-h) + C*(h)*(cone-w) + D*w*h);
+    }, name, tag);
+}
+
+/*!
+* \brief Resize given tensor to given shape using bilinear interpolation for NCHW
+*
+* \param inputs The input tensor array.
+* \param shape Output shape to scale to.
+* \param name Name of the operation
+* \param tag The tag to mark the operation
+*
+* \return A Tensor resized to given shape
+*/
+inline Tensor scale_bilinear_nchw(const Array<Tensor>& inputs,
+                                  Array<Expr> shape,
+                                  std::string name = "tensor",
+                                  std::string tag = kInjective) {
+  Array<Expr> out_shape;
+  out_shape.push_back(inputs[0]->shape[0]);
+  out_shape.push_back(inputs[0]->shape[1]);
+  out_shape.push_back(shape[0]);
+  out_shape.push_back(shape[1]);
+
+  Array<Expr> split_ind;
+  split_ind.push_back(make_const(UInt(32), 2));
+
+  Array<Tensor> weights = split(inputs[1], split_ind, 2);
+  Tensor coords = cast(weights[0], Int(32));
+
+  return compute(
+    out_shape, [&](const Array<Var>& indices) {
+      auto y1 = coords(indices[2], indices[3], 0);
+      auto x1 = coords(indices[2], indices[3], 1);
+      auto h = weights[1](indices[2], indices[3], 0);
+      auto w = weights[1](indices[2], indices[3], 1);
+
+      auto A = inputs[0](indices[0], indices[1], y1, x1);
+      auto B = inputs[0](indices[0], indices[1], y1, x1+1);
+      auto C = inputs[0](indices[0], indices[1], y1+1, x1);
+      auto D = inputs[0](indices[0], indices[1], y1+1, x1+1);
+
+      return (A*(1-w)*(1-h) + B*(w)*(1-h) + C*(h)*(1-w) + D*w*h);
+    }, name, tag);
+}
+
+/*!
+* \brief Resize given tensor to given shape using bilinear interpolation
+*
+* \param inputs The input tensor array.
+* \param shape Output shape to scale to.
+* \param layout Input layout
+* \param name Name of the operation
+* \param tag The tag to mark the operation
+*
+* \return A Tensor resized to given shape
+*/
+inline Tensor scale_bilinear(const Array<Tensor>& inputs,
+                             Array<Expr> shape,
+                             std::string layout = "NCHW",
+                             std::string name = "tensor",
+                             std::string tag = kInjective) {
+  Tensor ret;
+
+  if (layout == "NHWC") {
+    ret = scale_bilinear_nhwc(inputs, shape, name, tag);
+  } else {
+    ret = scale_bilinear_nchw(inputs, shape, name, tag);
+  }
+
+  return cast(ret, inputs[0]->dtype);
+}
+
+/*!
+* \brief Resize given tensor to given shape
+*
+* \param inputs The input tensor array.
+* Bilinear will have 2 inputs, the second being the weights.
+* \param shape Output shape to scale to.
+* \param layout Input layout
+* \param mode Algorithm to use (NN / BILINEAR)
+* \param name Name of the operation
+* \param tag The tag to mark the operation
+*
+* \return A Tensor resized to given shape
+*/
+inline Tensor scale(const Array<Tensor>& inputs,
+                    Array<Expr> shape,
+                    std::string layout = "NCHW",
+                    std::string mode = "BILINEAR",
+                    std::string name = "tensor",
+                    std::string tag = kInjective) {
+  if (mode == "NN") {
+    return scale_nn(inputs, shape, layout, name, tag);
+  } else {
+    return scale_bilinear(inputs, shape, layout, name, tag);
+  }
+}
+
+}  // namespace nn
+}  // namespace topi
+#endif  // TOPI_NN_SCALE_H_
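Both bilinear kernels above blend the four neighbouring source pixels A, B, C, D with the fractional offsets h (y_diff) and w (x_diff). A worked example of the same arithmetic in plain Python, with made-up values:

```python
# A = in[y1, x1], B = in[y1, x1+1], C = in[y1+1, x1], D = in[y1+1, x1+1],
# matching the naming in the kernels above.
A, B, C, D = 10.0, 20.0, 30.0, 40.0
h, w = 0.25, 0.5   # y_diff and x_diff for this output pixel

pixel = A*(1 - w)*(1 - h) + B*w*(1 - h) + C*h*(1 - w) + D*w*h
print(pixel)  # 20.0: the top row blends to 15, the bottom row to 35,
              # and the two rows are then mixed 0.75 / 0.25.
```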
diff --git a/topi/python/topi/nn/__init__.py b/topi/python/topi/nn/__init__.py
index 056d1a76339a2..a276e8aaf0315 100644
--- a/topi/python/topi/nn/__init__.py
+++ b/topi/python/topi/nn/__init__.py
@@ -15,5 +15,6 @@
 from .conv2d_transpose import *
 from .bnn import *
 from .upsampling import *
+from .bilinear_scale import *
 from .local_response_norm import *
 from .l2_norm import *
diff --git a/topi/python/topi/nn/bilinear_scale.py b/topi/python/topi/nn/bilinear_scale.py
new file mode 100644
index 0000000000000..9bf0a05f6c387
--- /dev/null
+++ b/topi/python/topi/nn/bilinear_scale.py
@@ -0,0 +1,31 @@
+"""TVM operator bilinear scaling compute."""
+from __future__ import absolute_import
+import topi
+
+
+def bilinear_scale(data, weights, out_size, layout="NCHW"):
+    """Perform bilinear scaling on the data.
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        4-D with shape [batch, channel, in_height, in_width]
+        or [batch, in_height, in_width, channel]
+
+    weights : tvm.Tensor
+        3-D with shape [out_height, out_width, 4] holding [y, x, y_diff, x_diff];
+        the helper tvm.contrib.image.bilinear_weights can generate this.
+
+    out_size : Tuple
+        Tuple of (out_height, out_width)
+
+    layout : string
+        either "NCHW" or "NHWC"
+
+    Returns
+    -------
+    output : tvm.Tensor
+        4-D with shape [batch, channel, out_height, out_width]
+        or [batch, out_height, out_width, channel]
+    """
+    return topi.cpp.nn.scale([data, weights], out_size, layout, "BILINEAR")
diff --git a/topi/python/topi/nn/upsampling.py b/topi/python/topi/nn/upsampling.py
index 9297eb4ad06b1..69dae8b894823 100644
--- a/topi/python/topi/nn/upsampling.py
+++ b/topi/python/topi/nn/upsampling.py
@@ -1,7 +1,6 @@
 """TVM operator upsampling compute."""
 from __future__ import absolute_import
-import tvm
-from .. import util
+import topi
 
 
 def upsampling(data, scale, layout="NCHW"):
@@ -27,54 +26,12 @@
        or [batch, in_height*scale, in_width*scale, channel]
     """
+
     if layout == "NCHW":
-        return upsampling_nchw(data, scale)
+        out_shape = (data.shape[2] * scale, data.shape[3] * scale)
     elif layout == "NHWC":
-        return upsampling_nhwc(data, scale)
+        out_shape = (data.shape[1] * scale, data.shape[2] * scale)
     else:
         raise ValueError("not support this layout {} yet".format(layout))
-
-
-def upsampling_nchw(data, scale):
-    """Perform nearest neighor upsampling on NCHW layout input.
-
-    Parameters
-    ----------
-    data : tvm.Tensor
-        4-D with shape [batch, channel, in_height, in_width]
-
-    scale: int
-        upsampling scaling factor
-
-    Returns
-    -------
-    output : tvm.Tensor
-        4-D with shape [batch, channel, in_height*scale, in_width*scale]
-    """
-    batch, channel, height, width = data.shape
-    out_height = util.simplify(height * scale)
-    out_width = util.simplify(width * scale)
-
-    return tvm.compute((batch, channel, out_height, out_width), \
-        lambda n, c, h, w: data[n, c, h/scale, w/scale])
-
-
-def upsampling_nhwc(data, scale):
-    """Perform nearest neighor upsampling on NHWC layout input.
-
-    Parameters
-    ----------
-    data : tvm.Tensor
-        4-D with shape [batch, in_height, in_width, channel]
-
-    scale: int
-        upsampling scaling factor
-
-    """
-
-    batch, height, width, channel = data.shape
-    out_height = util.simplify(height * scale)
-    out_width = util.simplify(width * scale)
-
-    return tvm.compute((batch, out_height, out_width, channel), \
-        lambda n, h, w, c: data[n, h/scale, w/scale, c])
+    return topi.cpp.nn.scale([data], out_shape, layout, "NN")
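With this change upsampling lowers to the C++ scale op in "NN" mode, whose kernel computes out[h, w] = in[h // scale, w // scale] by integer Expr division (h_scale and w_scale are the out/in ratios). A rough numpy sketch of that indexing, my illustration rather than part of the patch:

```python
import numpy as np

scale = 2
data = np.arange(9, dtype='float32').reshape(3, 3)

# out[h, w] = in[h // scale, w // scale], as in the C++ kernel ...
idx = np.arange(3 * scale) // scale          # [0, 0, 1, 1, 2, 2]
by_indexing = data[idx][:, idx]

# ... which is exactly repeat-based nearest-neighbour upsampling.
by_repeat = data.repeat(scale, axis=0).repeat(scale, axis=1)
assert (by_indexing == by_repeat).all()
```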
diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py
index 66a865724dc32..78ffcbac29f3e 100644
--- a/topi/python/topi/testing/__init__.py
+++ b/topi/python/topi/testing/__init__.py
@@ -12,6 +12,7 @@
 from .dilate_python import dilate_python
 from .softmax_python import softmax_python, log_softmax_python
 from .upsampling_python import upsampling_python
+from .bilinear_scale_python import bilinear_scale_python
 from .reorg_python import reorg_python
 from .region_python import region_python
 from .shortcut_python import shortcut_python
diff --git a/topi/python/topi/testing/bilinear_scale_python.py b/topi/python/topi/testing/bilinear_scale_python.py
new file mode 100644
index 0000000000000..8452499116b82
--- /dev/null
+++ b/topi/python/topi/testing/bilinear_scale_python.py
@@ -0,0 +1,47 @@
+# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
+"""Bilinear Scale in python"""
+import numpy as np
+
+def bilinear_scale_python(image, weights, out_size, layout):
+    """Bilinear scaling using python"""
+    (new_h, new_w) = out_size
+
+    if layout == 'NHWC':
+        (batch, h, w, channel) = image.shape
+        scaled_image = np.ones((batch, new_h, new_w, channel))
+    else:
+        (batch, channel, h, w) = image.shape
+        scaled_image = np.ones((batch, channel, new_h, new_w))
+
+    for b in range(batch):
+        for i in range(channel):
+            for j in range(new_h):
+                for k in range(new_w):
+                    y1 = int(weights[j][k][0])
+                    x1 = int(weights[j][k][1])
+
+                    y_diff = weights[j][k][2]
+                    x_diff = weights[j][k][3]
+
+                    if layout == 'NHWC':
+                        A = image[b][y1][x1][i]
+                        B = image[b][y1][x1+1][i]
+                        C = image[b][y1+1][x1][i]
+                        D = image[b][y1+1][x1+1][i]
+                    else:
+                        A = image[b][i][y1][x1]
+                        B = image[b][i][y1][x1+1]
+                        C = image[b][i][y1+1][x1]
+                        D = image[b][i][y1+1][x1+1]
+
+                    pixel = (A*(1-x_diff)*(1-y_diff) +
+                             B*(x_diff)*(1-y_diff) +
+                             C*(y_diff)*(1-x_diff) +
+                             D*(x_diff*y_diff))
+
+                    if layout == 'NHWC':
+                        scaled_image[b][j][k][i] = pixel
+                    else:
+                        scaled_image[b][i][j][k] = pixel
+
+    return scaled_image
diff --git a/topi/python/topi/testing/upsampling_python.py b/topi/python/topi/testing/upsampling_python.py
index 328c7a5a0bc15..cd97099883b2d 100644
--- a/topi/python/topi/testing/upsampling_python.py
+++ b/topi/python/topi/testing/upsampling_python.py
@@ -3,13 +3,25 @@
 import numpy as np
 
 def upsample_nearest(arr, scale):
+    """Populate the array by scale factor"""
     return arr.repeat(scale, axis=0).repeat(scale, axis=1)
 
-def upsampling_python(data, scale):
+def upsampling_python(data, scale, layout):
+    """Python version of scaling using nearest neighbour"""
     ishape = data.shape
-    oshape = (ishape[0], ishape[1], ishape[2]*scale, ishape[3]*scale)
-    output_np = np.zeros(oshape, dtype=data.dtype)
-    for b in range(oshape[0]):
-        for c in range(oshape[1]):
-            output_np[b, c, :, :] = upsample_nearest(data[b, c, :, :], scale)
-    return output_np
+    if layout == 'NCHW':
+        oshape = (ishape[0], ishape[1], ishape[2]*scale, ishape[3]*scale)
+        output_np = np.zeros(oshape, dtype=data.dtype)
+        for b in range(oshape[0]):
+            for c in range(oshape[1]):
+                output_np[b, c, :, :] = upsample_nearest(data[b, c, :, :], scale)
+        return output_np
+    elif layout == 'NHWC':
+        oshape = (ishape[0], ishape[1]*scale, ishape[2]*scale, ishape[3])
+        output_np = np.zeros(oshape, dtype=data.dtype)
+        for b in range(oshape[0]):
+            for c in range(oshape[3]):
+                output_np[b, :, :, c] = upsample_nearest(data[b, :, :, c], scale)
+        return output_np
+    else:
+        raise ValueError("not support this layout {} yet".format(layout))
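As a quick sanity check of the [y, x, y_diff, x_diff] convention shared by the weight generator and this reference: the four blend factors sum to 1, so scaling a constant image must return the same constant. A sketch assuming the patch is applied:

```python
import numpy as np
from tvm.contrib.image import bilinear_weights
from topi.testing import bilinear_scale_python

image = np.full((1, 3, 4, 4), 7.0, dtype='float32')   # constant NCHW image
weights = bilinear_weights(image, 6, 6, 'NCHW')
out = bilinear_scale_python(image, weights, (6, 6), 'NCHW')

# (1-x)(1-y) + x(1-y) + y(1-x) + xy == 1, so every output pixel is again 7.0.
np.testing.assert_allclose(out, np.full((1, 3, 6, 6), 7.0))
```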
diff --git a/topi/src/topi.cc b/topi/src/topi.cc
index f4feafb043f15..f4faf7e756156 100644
--- a/topi/src/topi.cc
+++ b/topi/src/topi.cc
@@ -23,6 +23,7 @@
 #include
 #include
 #include
+#include <topi/nn/scale.h>
 #include
 #include
 
@@ -280,6 +281,12 @@ TVM_REGISTER_GLOBAL("topi.take")
   }
   });
 
+/* Ops from nn/scale.h */
+TVM_REGISTER_GLOBAL("topi.nn.scale")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+  *rv = nn::scale(args[0], args[1], args[2], args[3]);
+  });
+
 /* Ops from nn/batch_norm.h */
 TVM_REGISTER_GLOBAL("topi.nn.batch_norm_inference")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
diff --git a/topi/tests/python/test_topi_bilinear_scale.py b/topi/tests/python/test_topi_bilinear_scale.py
new file mode 100644
index 0000000000000..54ce7111da79d
--- /dev/null
+++ b/topi/tests/python/test_topi_bilinear_scale.py
@@ -0,0 +1,60 @@
+"""Test code for bilinear scale """
+import numpy as np
+import tvm
+import topi
+import topi.testing
+import math
+from tvm.contrib.image import bilinear_weights
+
+def verify_bilinear_scale(batch, in_channel, in_height, in_width, out_height, out_width, layout='NCHW'):
+
+
+    if layout == 'NCHW':
+        A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A', dtype='uint8')
+        dtype = A.dtype
+        out_shape = (batch, in_channel, out_height, out_width)
+        a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype(dtype)
+    elif layout == 'NHWC':
+        A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A', dtype='uint8')
+        dtype = A.dtype
+        out_shape = (batch, out_height, out_width, in_channel)
+        a_np = np.random.uniform(size=(batch, in_height, in_width, in_channel)).astype(dtype)
+    else:
+        raise NotImplementedError(
+            'Layout {} not supported'.format(layout))
+
+    W = tvm.placeholder((out_height, out_width, 4), name='W')
+    weights = bilinear_weights(a_np, out_height, out_width, layout)
+
+    B = topi.nn.bilinear_scale(A, W, (out_height, out_width), layout=layout)
+
+    b_np = topi.testing.bilinear_scale_python(a_np, weights, (out_height, out_width), layout)
+
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            s = topi.generic.schedule_injective(B)
+        a = tvm.nd.array(a_np, ctx)
+        w = tvm.nd.array(weights, ctx)
+        b = tvm.nd.array(np.zeros(out_shape, dtype=dtype), ctx)
+        f = tvm.build(s, [A, W, B], device)
+        f(a, w, b)
+
+        np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1)
+
+
+    for device in ['llvm', 'cuda', 'vulkan']:
+        check_device(device)
+
+def test_bilinear_scale():
+    verify_bilinear_scale(4, 16, 32, 32, 50, 50)
+    verify_bilinear_scale(6, 32, 64, 64, 20, 20)
+    verify_bilinear_scale(4, 16, 32, 32, 50, 50, "NHWC")
+    verify_bilinear_scale(6, 32, 64, 64, 20, 20, "NHWC")
+
+if __name__ == "__main__":
+    test_bilinear_scale()
diff --git a/topi/tests/python/test_topi_upsampling.py b/topi/tests/python/test_topi_upsampling.py
index 7421dd4151e6d..e07a1a8a13af4 100644
--- a/topi/tests/python/test_topi_upsampling.py
+++ b/topi/tests/python/test_topi_upsampling.py
@@ -5,14 +5,26 @@
 import topi.testing
 import math
 
-def verify_upsampling(batch, in_channel, in_height, in_width, scale):
-    A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
-    B = topi.nn.upsampling(A, scale)
-    out_shape = (batch, in_channel, in_height*scale, in_width*scale)
-    dtype = A.dtype
+def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCHW'):
 
-    a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype(dtype)
-    b_np = topi.testing.upsampling_python(a_np, scale)
+
+    if layout == 'NCHW':
+        A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
+        dtype = A.dtype
+        out_shape = (batch, in_channel, in_height*scale, in_width*scale)
+        a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype(dtype)
+    elif layout == 'NHWC':
+        A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A')
+        dtype = A.dtype
+        out_shape = (batch, in_height*scale, in_width*scale, in_channel)
+        a_np = np.random.uniform(size=(batch, in_height, in_width, in_channel)).astype(dtype)
+    else:
+        raise NotImplementedError(
+            'Layout {} not supported'.format(layout))
+
+    B = topi.nn.upsampling(A, scale, layout=layout)
+
+    b_np = topi.testing.upsampling_python(a_np, scale, layout)
 
     def check_device(device):
         ctx = tvm.context(device, 0)
@@ -29,12 +41,15 @@ def check_device(device):
 
     np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
+
     for device in ['llvm', 'cuda', 'vulkan']:
         check_device(device)
 
 def test_upsampling():
     verify_upsampling(8, 16, 32, 32, 2)
     verify_upsampling(12, 32, 64, 64, 3)
+    verify_upsampling(8, 16, 32, 32, 2, "NHWC")
+    verify_upsampling(12, 32, 64, 64, 3, "NHWC")
 
 if __name__ == "__main__":
     test_upsampling()
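For completeness, a condensed usage sketch of the new Python entry point, mirroring the tests above but restricted to float32 on llvm; the names and shapes here are illustrative only:

```python
import numpy as np
import tvm
import topi
from tvm.contrib.image import bilinear_weights

a_np = np.random.uniform(size=(1, 3, 32, 32)).astype('float32')
w_np = bilinear_weights(a_np, 64, 64, 'NCHW')       # host-side weight table

A = tvm.placeholder(a_np.shape, name='A', dtype='float32')
W = tvm.placeholder(w_np.shape, name='W')
B = topi.nn.bilinear_scale(A, W, (64, 64), layout='NCHW')

with tvm.target.create('llvm'):
    s = topi.generic.schedule_injective(B)
f = tvm.build(s, [A, W, B], 'llvm')

ctx = tvm.cpu(0)
out = tvm.nd.array(np.zeros((1, 3, 64, 64), dtype='float32'), ctx)
f(tvm.nd.array(a_np, ctx), tvm.nd.array(w_np, ctx), out)
```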