From 5ab95e780800821a23342aa486546255202f6d3a Mon Sep 17 00:00:00 2001 From: Siva Date: Wed, 18 Jul 2018 21:34:48 +0530 Subject: [PATCH] [NNVM][TENSORFLOW] bug fix on bilinear and resize op integration in frontend. (#1440) --- nnvm/python/nnvm/frontend/tensorflow.py | 22 +--- nnvm/python/nnvm/testing/tf.py | 123 ------------------ nnvm/src/top/image/resize.cc | 10 +- nnvm/tests/python/compiler/test_top_level2.py | 7 +- .../frontend/tensorflow/test_forward.py | 92 ++++++++++--- topi/include/topi/image/resize.h | 52 +++++--- .../topi/testing/bilinear_resize_python.py | 60 ++++----- 7 files changed, 144 insertions(+), 222 deletions(-) diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index aa00c51836fb9..31ee0a8b116fc 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -332,26 +332,14 @@ def _impl(inputs, attr, params): def _resize_bilinear(): def _impl(inputs, attr, params): - # Change this when we have corresponding resize bilinear operation. - print("ResizeBilinear:Only NN (nearest neighbor) supported in symetric mode of dimensions") - print("Change this when we have corresponding resize bilinear operation") - - # NHWC - input_shape = attr['_input_shapes'][inputs[0]][0] - in_hw = (input_shape[1], input_shape[2]) - out_hw = params[inputs[1].list_output_names()[0]] + attr['size'] = attr['_output_shapes'][0][1:3] inputs.pop(1) + # NHWC attr['layout'] = 'NHWC' - if in_hw[0] < 0 or in_hw[1] < 0: - scale = 1 - else: - # Considering height alone for scale - scale = out_hw[0] / in_hw[0] - - return AttrCvt(op_name="upsampling", - ignores=['Tdim', 'align_corners'], - extras={'scale': scale})(inputs, attr) + return AttrCvt(op_name="resize", + ignores=['Tdim'], + extras={'method': "BILINEAR"})(inputs, attr) return _impl def _check_numerics(): diff --git a/nnvm/python/nnvm/testing/tf.py b/nnvm/python/nnvm/testing/tf.py index 1762ce56596ba..f5ca0e11d194f 100644 --- a/nnvm/python/nnvm/testing/tf.py +++ b/nnvm/python/nnvm/testing/tf.py @@ -6,7 +6,6 @@ """ import re import os.path -import numpy as np # Tensorflow imports import tensorflow as tf @@ -107,52 +106,6 @@ def id_to_string(self, node_id): return '' return self.node_lookup[node_id] -def read_normalized_tensor_from_image_file(file_name, - input_height=299, - input_width=299, - input_mean=0, - input_std=255): - """ Preprocessing of image - Parameters - ---------- - - file_name: String - Image filename. - - input_height: int - model input height. - - input_width: int - model input width - - input_mean: int - Mean to be substracted in normalization. - - input_std: int - Standard deviation used in normalization. - - Returns - ------- - - np_array: Numpy array - Normalized image data as a numpy array. 
- - """ - - input_name = "file_reader" - output_name = "normalized" - file_reader = tf.read_file(file_name, input_name) - - image_reader = tf.image.decode_jpeg(file_reader, channels=3, - name='jpeg_reader') - float_caster = tf.cast(image_reader, tf.float32) - dims_expander = tf.expand_dims(float_caster, 0) - resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width]) - normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std]) - tf.InteractiveSession() - np_array = normalized.eval() - return np_array - def get_workload(model_path): """ Import workload from frozen protobuf @@ -181,79 +134,3 @@ def get_workload(model_path): graph_def.ParseFromString(f.read()) graph = tf.import_graph_def(graph_def, name='') return graph_def - -def get_workload_inception_v3(): - """ Import Inception V3 workload from frozen protobuf - - Parameters - ---------- - Nothing. - - Returns - ------- - (normalized, graph_def) : Tuple - normalized is normalized input for graph testing. - graph_def is the tensorflow workload for Inception V3. - """ - - repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV3/' - model_path = 'InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb' - - image_name = 'elephant-299.jpg' - image_url = os.path.join(repo_base, image_name) - from mxnet.gluon.utils import download - download(image_url, image_name) - normalized = read_normalized_tensor_from_image_file(os.path.join("./", image_name)) - - return (normalized, get_workload(model_path)) - -def get_workload_inception_v1(): - """ Import Inception V1 workload from frozen protobuf - - Parameters - ---------- - Nothing. - - Returns - ------- - (image_data, tvm_data, graph_def) : Tuple - image_data is raw encoded image data for TF input. - tvm_data is the decoded image data for TVM input. - graph_def is the tensorflow workload for Inception V1. - - """ - - repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/' - model_path = 'InceptionV1/classify_image_graph_def-with_shapes.pb' - image_name = 'elephant-299.jpg' - image_url = os.path.join(repo_base, image_name) - - from mxnet.gluon.utils import download - download(image_url, image_name) - - if not tf.gfile.Exists(os.path.join("./", image_name)): - tf.logging.fatal('File does not exist %s', image) - image_data = tf.gfile.FastGFile(os.path.join("./", image_name), 'rb').read() - - # TVM doesn't handle decode, hence decode it. - from PIL import Image - tvm_data = Image.open(os.path.join("./", image_name)).resize((299, 299)) - tvm_data = np.array(tvm_data) - - return (image_data, tvm_data, get_workload(model_path)) - -def get_workload_mobilenet(): - """ Import mobilenet workload from frozen protobuf - - Parameters - ---------- - Nothing. - - Returns - ------- - graph_def: graphdef - graph_def is the tensorflow workload for mobilenet. 
- - """ - - return get_workload("MobilenetV1/mobilenet_v1_1.0_224_frozen-with-shapes.pb") diff --git a/nnvm/src/top/image/resize.cc b/nnvm/src/top/image/resize.cc index 151296d8ecc48..b89070fe38971 100644 --- a/nnvm/src/top/image/resize.cc +++ b/nnvm/src/top/image/resize.cc @@ -40,13 +40,9 @@ inline bool ResizeInferShape(const nnvm::NodeAttrs& attrs, dshape = ConvertLayout(dshape, param.layout, kNCHW); TShape oshape = dshape; - if (param.layout == "NCHW") { - oshape[2] = param.size[0]; - oshape[3] = param.size[1]; - } else { - oshape[1] = param.size[0]; - oshape[2] = param.size[1]; - } + oshape[2] = param.size[0]; + oshape[3] = param.size[1]; + oshape = ConvertLayout(oshape, kNCHW, param.layout); NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape); diff --git a/nnvm/tests/python/compiler/test_top_level2.py b/nnvm/tests/python/compiler/test_top_level2.py index 19387a5865284..46809b0d658c4 100644 --- a/nnvm/tests/python/compiler/test_top_level2.py +++ b/nnvm/tests/python/compiler/test_top_level2.py @@ -300,11 +300,10 @@ def test_upsampling_bilinear(): def test_resize_bilinear(): x = sym.Variable("x") - scale = 2 - y = sym.upsampling(x, scale=scale, method="BILINEAR", name="y", layout="NHWC") + y = sym.resize(x, size=(60, 60), method="BILINEAR", name="y", layout="NHWC") dtype = "float32" dshape = (1, 32, 32, 4) - oshape = (1, 32*scale, 32*scale, 4) + oshape = (1, 60, 60, 4) shape_dict = {"x": dshape} dtype_dict = {"x": dtype} for target, ctx in ctx_list(): @@ -314,7 +313,7 @@ def test_resize_bilinear(): data = tvm.nd.array(a_np) m.run(x=data) out = m.get_output(0, tvm.nd.empty(oshape, dtype)) - b_np = topi.testing.bilinear_resize_python(a_np, (32*scale, 32*scale), "NHWC") + b_np = topi.testing.bilinear_resize_python(a_np, (60, 60), "NHWC") np.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5, atol=1e-5) if __name__ == "__main__": diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py index 9fcf33bdf558c..3275d62153093 100644 --- a/nnvm/tests/python/frontend/tensorflow/test_forward.py +++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py @@ -469,24 +469,65 @@ def test_forward_multi_input(): sess.close() +####################################################################### +# Resize Bilinear +# --------------- + +def _test_resize_bilinear(in_shape, to_shape, align_corners): + """ One iteration of resize bilinear """ + + data = np.random.uniform(size=in_shape).astype('float32') + shape_data = np.array(to_shape).astype('int32') + + with tf.Graph().as_default(): + in_data = constant_op.constant(data, shape=data.shape, dtype=data.dtype) + shape_data = constant_op.constant(shape_data, shape=shape_data.shape, dtype=shape_data.dtype) + + # pylint: disable=unused-variable + resize_out = tf.image.resize_bilinear(in_data, shape_data, align_corners=align_corners) + # pylint: enable=unused-variable + + with tf.Session() as sess: + graph_def = tf.graph_util.convert_variables_to_constants( + sess, + sess.graph.as_graph_def(add_shapes=True), + ['ResizeBilinear'], + ) + + tf_output = run_tf_graph(sess, data, + 'Const:0', 'ResizeBilinear:0') + + tvm_output = run_tvm_graph(graph_def, + data, + "Const", tf_output.shape, data.dtype) + + np.testing.assert_allclose(tf_output, tvm_output, atol=1e-3, rtol=1e-3) + + sess.close() + +def test_forward_resize_bilinear(): + """ Resize Bilinear """ + + _test_resize_bilinear((4, 16, 32, 32), [50, 50], False) + _test_resize_bilinear((6, 32, 64, 64), [20, 20], True) + + 
####################################################################### # Inception V3 # ------------ def test_forward_inception_v3(): '''test inception V3 model''' with tf.Graph().as_default(): - (data, graph_def) = nnvm.testing.tf.get_workload_inception_v3() + graph_def = nnvm.testing.tf.get_workload('InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb') # Call the utility to import the graph definition into default graph. graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def) - tvm_output = run_tvm_graph(graph_def, data, 'input', (1, 1001), 'float32') + data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32') + with tf.Session() as sess: tf_output = run_tf_graph(sess, data, 'input:0', 'InceptionV3/Predictions/Reshape_1:0') - - top_tvm = np.squeeze(tvm_output).argsort()[-3:][::-1] - top_tf = np.squeeze(tf_output).argsort()[-3:][::-1] - - np.testing.assert_allclose(top_tf, top_tvm, rtol=1e-5, atol=1e-5) + tvm_output = run_tvm_graph(graph_def, data, 'input', tf_output.shape, 'float32') + np.testing.assert_allclose(tf_output, tvm_output, rtol=1e-5, atol=1e-5) ####################################################################### # Inception V1 @@ -494,16 +535,35 @@ def test_forward_inception_v3(): def test_forward_inception_v1(): '''test inception V1 model''' with tf.Graph().as_default(): - (data, tvm_data, graph_def) = nnvm.testing.tf.get_workload_inception_v1() + graph_def = nnvm.testing.tf.get_workload("InceptionV1/classify_image_graph_def-with_shapes.pb") # Call the utility to import the graph definition into default graph. graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def) - tvm_output = run_tvm_graph(graph_def, tvm_data, 'DecodeJpeg/contents', (1, 1008), 'float32') + # Build an image from random data. + from PIL import Image + from tvm.contrib import util + + img_array = np.random.uniform(size=(1, 600, 600, 3)).astype("uint8") + img = Image.frombuffer('RGB', (600, 600), img_array.tostring(), 'raw', 'RGB', 0, 1) + temp = util.tempdir() + img_path = temp.relpath("tf-test.jpg") + img.save(img_path); + + import os.path + if not tf.gfile.Exists(os.path.join(img_path)): + tf.logging.fatal('File does not exist %s', image) + data = tf.gfile.FastGFile(os.path.join(img_path), 'rb').read() + + temp.remove() + # Extract tensorflow decoded image frame for tvm input with tf.Session() as sess: - tf_output = run_tf_graph(sess, data, 'DecodeJpeg/contents:0', 'softmax:0') + tvm_data = run_tf_graph(sess, data, 'DecodeJpeg/contents:0', 'DecodeJpeg:0') - np.testing.assert_allclose(tf_output, tvm_output, rtol=2e-2, atol=2e-2) + with tf.Session() as sess: + tf_output = run_tf_graph(sess, data, 'DecodeJpeg/contents:0', 'softmax:0') + tvm_output = run_tvm_graph(graph_def, tvm_data, 'DecodeJpeg/contents', tf_output.shape, 'float32') + np.testing.assert_allclose(tf_output, tvm_output, rtol=1e-5, atol=1e-5) ####################################################################### # Mobilenet @@ -511,7 +571,7 @@ def test_forward_inception_v1(): def test_forward_mobilenet(): '''test mobilenet model''' with tf.Graph().as_default(): - graph_def = nnvm.testing.tf.get_workload_mobilenet() + graph_def = nnvm.testing.tf.get_workload("MobilenetV1/mobilenet_v1_1.0_224_frozen-with-shapes.pb") # Call the utility to import the graph definition into default graph. 
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def) @@ -520,12 +580,7 @@ def test_forward_mobilenet(): with tf.Session() as sess: tf_output = run_tf_graph(sess, data, 'input:0', out_node + ':0') - - out_shape = tf_output.shape - tvm_output = run_tvm_graph(graph_def, data, 'input', out_shape, 'float32') - top_tvm = np.squeeze(tvm_output).argsort()[-10:][::-1] - top_tf = np.squeeze(tf_output).argsort()[-10:][::-1] - + tvm_output = run_tvm_graph(graph_def, data, 'input', tf_output.shape, 'float32') np.testing.assert_allclose(np.squeeze(tvm_output), np.squeeze(tf_output), rtol=1e-5, atol=1e-5) ####################################################################### @@ -544,3 +599,4 @@ def test_forward_mobilenet(): test_forward_inception_v1() test_forward_mobilenet() test_forward_variable() + test_forward_resize_bilinear() diff --git a/topi/include/topi/image/resize.h b/topi/include/topi/image/resize.h index e00fe15fdf1e4..b6bd51ef0fd20 100644 --- a/topi/include/topi/image/resize.h +++ b/topi/include/topi/image/resize.h @@ -153,7 +153,7 @@ inline Tensor resize_bilinear_nhwc(const Tensor& input, Expr y_ratio; Expr x_ratio; - if (align_corners) { + if (!align_corners) { y_ratio = make_const(Float(32), (static_cast(*in_height) / static_cast(*out_height))); x_ratio = make_const(Float(32), (static_cast(*in_width) / @@ -170,21 +170,31 @@ inline Tensor resize_bilinear_nhwc(const Tensor& input, return compute( out_shape, [&](const Array& indices) { - auto y0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(y_ratio * indices[1])); - auto x0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(x_ratio * indices[2])); + auto in_y = indices[1] * y_ratio; + auto yf = tvm::floor(in_y); + auto yc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_y)); - auto y1 = tvm::select(((y0 + cone) > other_y), other_y, (y0 + cone)); - auto x1 = tvm::select(((x0 + cone) > other_x), other_x, (x0 + cone)); + auto y0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_y)); + auto y1 = tvm::select((yc > other_y), other_y, yc); + auto y_lerp = in_y - yf; - auto h = (y_ratio * indices[1]) - y0; - auto w = (x_ratio * indices[2]) - x0;; + auto in_x = indices[2] * x_ratio; + auto xf = tvm::floor(in_x); + auto xc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_x)); + + auto x0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_x)); + auto x1 = tvm::select((xc > other_x), other_x, xc); + auto x_lerp = in_x - xf; auto A = input(indices[0], y0, x0, indices[3]); auto B = input(indices[0], y0, x1, indices[3]); auto C = input(indices[0], y1, x0, indices[3]); auto D = input(indices[0], y1, x1, indices[3]); - return (A*(cone-w)*(cone-h) + B*(w)*(cone-h) + C*(h)*(cone-w) + D*w*h); + auto top = A + (B - A) * x_lerp; + auto bottom = C + (D - C) * x_lerp; + + return (top + (bottom - top) * y_lerp); }, name, tag); } @@ -220,7 +230,7 @@ inline Tensor resize_bilinear_nchw(const Tensor& input, Expr y_ratio; Expr x_ratio; - if (align_corners) { + if (!align_corners) { y_ratio = make_const(Float(32), (static_cast(*in_height) / static_cast(*out_height))); x_ratio = make_const(Float(32), (static_cast(*in_width) / @@ -237,21 +247,31 @@ inline Tensor resize_bilinear_nchw(const Tensor& input, return compute( out_shape, [&](const Array& indices) { - auto y0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(y_ratio * indices[2])); - auto x0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(x_ratio * indices[3])); + auto in_y = indices[2] * y_ratio; + auto yf = tvm::floor(in_y); + auto yc = 
HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_y)); - auto y1 = tvm::select(((y0 + cone) > other_y), other_y, (y0 + cone)); - auto x1 = tvm::select(((x0 + cone) > other_x), other_x, (x0 + cone)); + auto y0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_y)); + auto y1 = tvm::select((yc > other_y), other_y, yc); + auto y_lerp = in_y - yf; - auto h = (y_ratio * indices[2]) - y0; - auto w = (x_ratio * indices[3]) - x0;; + auto in_x = indices[3] * x_ratio; + auto xf = tvm::floor(in_x); + auto xc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_x)); + + auto x0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_x)); + auto x1 = tvm::select((xc > other_x), other_x, xc); + auto x_lerp = in_x - xf; auto A = input(indices[0], indices[1], y0, x0); auto B = input(indices[0], indices[1], y0, x1); auto C = input(indices[0], indices[1], y1, x0); auto D = input(indices[0], indices[1], y1, x1); - return ((A*(cone-w)*(cone-h)) + (B*(w)*(cone-h)) + (C*(h)*(cone-w)) + (D*w*h)); + auto top = A + (B - A) * x_lerp; + auto bottom = C + (D - C) * x_lerp; + + return (top + (bottom - top) * y_lerp); }, name, tag); } diff --git a/topi/python/topi/testing/bilinear_resize_python.py b/topi/python/topi/testing/bilinear_resize_python.py index de931d7e1f5df..c014b075681aa 100644 --- a/topi/python/topi/testing/bilinear_resize_python.py +++ b/topi/python/topi/testing/bilinear_resize_python.py @@ -3,32 +3,6 @@ import math import numpy as np -def bilinear_weights(height, width, new_h, new_w, align_corners=False): - """ Helper function to generate weights for bilinear scaling """ - - if align_corners: - x_ratio = np.float32(np.float32(width)/np.float32(new_w)) - y_ratio = np.float32(np.float32(height)/np.float32(new_h)) - else: - x_ratio = np.float32(np.float32(width-1)/np.float32(new_w-1)) - y_ratio = np.float32(np.float32(height-1)/np.float32(new_h-1)) - - def _bilinear_interpolation(y, x): - x_coord = math.floor(x_ratio * x) - y_coord = math.floor(y_ratio * y) - x_diff = np.float32((x_ratio * x) - x_coord) - y_diff = np.float32((y_ratio * y) - y_coord) - - return [y_coord, x_coord, y_diff, x_diff] - - # weights to hold (srcx, srcy, x_diff, y_diff) for each out value. 
- weights = np.empty([new_h, new_w, 4], dtype='float32') - - for i in range(new_h): - for j in range(new_w): - weights[i][j] = _bilinear_interpolation(i, j) - return weights - def bilinear_resize_python(image, out_size, layout, align_corners=False): """ Bilinear scaling using python""" (new_h, new_w) = out_size @@ -40,20 +14,32 @@ def bilinear_resize_python(image, out_size, layout, align_corners=False): (batch, channel, h, w) = image.shape scaled_image = np.ones((batch, channel, new_h, new_w)) - weights = bilinear_weights(h, w, new_h, new_w, align_corners) + if align_corners: + height_scale = np.float32(h-1) / np.float32(out_size[0]-1) + width_scale = np.float32(w-1) / np.float32(out_size[1]-1) + else: + height_scale = np.float32(h) / np.float32(out_size[0]) + width_scale = np.float32(w) / np.float32(out_size[1]) for b in range(batch): for i in range(channel): for j in range(new_h): for k in range(new_w): - y0 = int(weights[j][k][0]) - x0 = int(weights[j][k][1]) + in_y = j * height_scale + y0 = math.floor(in_y) + y1 = min(math.ceil(in_y), h - 1) + y_lerp = in_y - y0 + + y0 = int(y0) + y1 = int(y1) - x1 = min((x0+1), (w-1)) - y1 = min((y0+1), (h-1)) + in_x = k * width_scale + x0 = math.floor(in_x) + x1 = min(math.ceil(in_x), w - 1) + x_lerp = in_x - x0 - y_diff = weights[j][k][2] - x_diff = weights[j][k][3] + x0 = int(x0) + x1 = int(x1) if layout == 'NHWC': A = image[b][y0][x0][i] @@ -66,10 +52,10 @@ def bilinear_resize_python(image, out_size, layout, align_corners=False): C = image[b][i][y1][x0] D = image[b][i][y1][x1] - pixel = np.float32((A*(1-x_diff)*(1-y_diff) + - B*(x_diff)*(1-y_diff) + - C*(y_diff)*(1-x_diff) + - D*(x_diff*y_diff))) + top = A + (B - A) * x_lerp + bottom = C + (D - C) * x_lerp + + pixel = np.float32(top + (bottom - top) * y_lerp) if layout == 'NHWC': scaled_image[b][j][k][i] = pixel
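
For readers following the interpolation change above, here is a minimal, self-contained NumPy sketch of the sampling scheme this patch settles on: floor/ceil neighbours clamped to the border, a horizontal lerp on the top and bottom rows, then a vertical lerp between them, with the scale chosen as (in - 1) / (out - 1) when align_corners is set and in / out otherwise. The function name, the NHWC-only layout, and the signature are illustrative additions for this note, not part of the patch.

import numpy as np

def bilinear_resize_nhwc(image, out_size, align_corners=False):
    """Bilinear resize of an (N, H, W, C) array, mirroring the patch's convention:
    scale is (in - 1) / (out - 1) with align_corners, in / out otherwise, and the
    sample point is simply out_index * scale (no half-pixel offset)."""
    image = image.astype("float32")
    batch, in_h, in_w, channel = image.shape
    new_h, new_w = out_size

    if align_corners:
        h_scale = np.float32(in_h - 1) / np.float32(new_h - 1)
        w_scale = np.float32(in_w - 1) / np.float32(new_w - 1)
    else:
        h_scale = np.float32(in_h) / np.float32(new_h)
        w_scale = np.float32(in_w) / np.float32(new_w)

    out = np.empty((batch, new_h, new_w, channel), dtype="float32")
    for j in range(new_h):
        in_y = j * h_scale
        y0 = int(np.floor(in_y))
        y1 = min(int(np.ceil(in_y)), in_h - 1)   # clamp the lower neighbour to the border
        y_lerp = in_y - y0
        for k in range(new_w):
            in_x = k * w_scale
            x0 = int(np.floor(in_x))
            x1 = min(int(np.ceil(in_x)), in_w - 1)
            x_lerp = in_x - x0

            a = image[:, y0, x0, :]              # top-left neighbour
            b = image[:, y0, x1, :]              # top-right neighbour
            c = image[:, y1, x0, :]              # bottom-left neighbour
            d = image[:, y1, x1, :]              # bottom-right neighbour

            top = a + (b - a) * x_lerp           # lerp along x on the top row
            bottom = c + (d - c) * x_lerp        # lerp along x on the bottom row
            out[:, j, k, :] = top + (bottom - top) * y_lerp  # lerp along y
    return out

Called on something like np.random.uniform(size=(1, 32, 32, 4)).astype('float32') with out_size=(60, 60) and align_corners=False, this should track topi.testing.bilinear_resize_python(..., layout='NHWC') introduced above to within the 1e-5 tolerance the tests use.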