Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NNVM][TENSORFLOW] bug fix on bilinear and resize op integration in frontend. #1440

Merged
merged 3 commits into from
Jul 18, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 5 additions & 17 deletions nnvm/python/nnvm/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,26 +332,14 @@ def _impl(inputs, attr, params):

def _resize_bilinear():
def _impl(inputs, attr, params):
# Change this when we have corresponding resize bilinear operation.
print("ResizeBilinear:Only NN (nearest neighbor) supported in symetric mode of dimensions")
print("Change this when we have corresponding resize bilinear operation")

# NHWC
input_shape = attr['_input_shapes'][inputs[0]][0]
in_hw = (input_shape[1], input_shape[2])
out_hw = params[inputs[1].list_output_names()[0]]
attr['size'] = attr['_output_shapes'][0][1:3]
inputs.pop(1)
# NHWC
attr['layout'] = 'NHWC'

if in_hw[0] < 0 or in_hw[1] < 0:
scale = 1
else:
# Considering height alone for scale
scale = out_hw[0] / in_hw[0]

return AttrCvt(op_name="upsampling",
ignores=['Tdim', 'align_corners'],
extras={'scale': scale})(inputs, attr)
return AttrCvt(op_name="resize",
ignores=['Tdim'],
extras={'method': "BILINEAR"})(inputs, attr)
return _impl

def _check_numerics():
Expand Down
123 changes: 0 additions & 123 deletions nnvm/python/nnvm/testing/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
"""
import re
import os.path
import numpy as np

# Tensorflow imports
import tensorflow as tf
Expand Down Expand Up @@ -107,52 +106,6 @@ def id_to_string(self, node_id):
return ''
return self.node_lookup[node_id]

def read_normalized_tensor_from_image_file(file_name,
input_height=299,
input_width=299,
input_mean=0,
input_std=255):
""" Preprocessing of image
Parameters
----------

file_name: String
Image filename.

input_height: int
model input height.

input_width: int
model input width

input_mean: int
Mean to be substracted in normalization.

input_std: int
Standard deviation used in normalization.

Returns
-------

np_array: Numpy array
Normalized image data as a numpy array.

"""

input_name = "file_reader"
output_name = "normalized"
file_reader = tf.read_file(file_name, input_name)

image_reader = tf.image.decode_jpeg(file_reader, channels=3,
name='jpeg_reader')
float_caster = tf.cast(image_reader, tf.float32)
dims_expander = tf.expand_dims(float_caster, 0)
resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
tf.InteractiveSession()
np_array = normalized.eval()
return np_array

def get_workload(model_path):
""" Import workload from frozen protobuf

Expand Down Expand Up @@ -181,79 +134,3 @@ def get_workload(model_path):
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
return graph_def

def get_workload_inception_v3():
""" Import Inception V3 workload from frozen protobuf

Parameters
----------
Nothing.

Returns
-------
(normalized, graph_def) : Tuple
normalized is normalized input for graph testing.
graph_def is the tensorflow workload for Inception V3.
"""

repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV3/'
model_path = 'InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb'

image_name = 'elephant-299.jpg'
image_url = os.path.join(repo_base, image_name)
from mxnet.gluon.utils import download
download(image_url, image_name)
normalized = read_normalized_tensor_from_image_file(os.path.join("./", image_name))

return (normalized, get_workload(model_path))

def get_workload_inception_v1():
""" Import Inception V1 workload from frozen protobuf

Parameters
----------
Nothing.

Returns
-------
(image_data, tvm_data, graph_def) : Tuple
image_data is raw encoded image data for TF input.
tvm_data is the decoded image data for TVM input.
graph_def is the tensorflow workload for Inception V1.

"""

repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
model_path = 'InceptionV1/classify_image_graph_def-with_shapes.pb'
image_name = 'elephant-299.jpg'
image_url = os.path.join(repo_base, image_name)

from mxnet.gluon.utils import download
download(image_url, image_name)

if not tf.gfile.Exists(os.path.join("./", image_name)):
tf.logging.fatal('File does not exist %s', image)
image_data = tf.gfile.FastGFile(os.path.join("./", image_name), 'rb').read()

# TVM doesn't handle decode, hence decode it.
from PIL import Image
tvm_data = Image.open(os.path.join("./", image_name)).resize((299, 299))
tvm_data = np.array(tvm_data)

return (image_data, tvm_data, get_workload(model_path))

def get_workload_mobilenet():
""" Import mobilenet workload from frozen protobuf

Parameters
----------
Nothing.

Returns
-------
graph_def: graphdef
graph_def is the tensorflow workload for mobilenet.

"""

return get_workload("MobilenetV1/mobilenet_v1_1.0_224_frozen-with-shapes.pb")
10 changes: 3 additions & 7 deletions nnvm/src/top/image/resize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,9 @@ inline bool ResizeInferShape(const nnvm::NodeAttrs& attrs,
dshape = ConvertLayout(dshape, param.layout, kNCHW);

TShape oshape = dshape;
if (param.layout == "NCHW") {
oshape[2] = param.size[0];
oshape[3] = param.size[1];
} else {
oshape[1] = param.size[0];
oshape[2] = param.size[1];
}
oshape[2] = param.size[0];
oshape[3] = param.size[1];

oshape = ConvertLayout(oshape, kNCHW, param.layout);
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);

Expand Down
7 changes: 3 additions & 4 deletions nnvm/tests/python/compiler/test_top_level2.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,11 +300,10 @@ def test_upsampling_bilinear():

def test_resize_bilinear():
x = sym.Variable("x")
scale = 2
y = sym.upsampling(x, scale=scale, method="BILINEAR", name="y", layout="NHWC")
y = sym.resize(x, size=(60, 60), method="BILINEAR", name="y", layout="NHWC")
dtype = "float32"
dshape = (1, 32, 32, 4)
oshape = (1, 32*scale, 32*scale, 4)
oshape = (1, 60, 60, 4)
shape_dict = {"x": dshape}
dtype_dict = {"x": dtype}
for target, ctx in ctx_list():
Expand All @@ -314,7 +313,7 @@ def test_resize_bilinear():
data = tvm.nd.array(a_np)
m.run(x=data)
out = m.get_output(0, tvm.nd.empty(oshape, dtype))
b_np = topi.testing.bilinear_resize_python(a_np, (32*scale, 32*scale), "NHWC")
b_np = topi.testing.bilinear_resize_python(a_np, (60, 60), "NHWC")
np.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5, atol=1e-5)

if __name__ == "__main__":
Expand Down
92 changes: 74 additions & 18 deletions nnvm/tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,49 +469,109 @@ def test_forward_multi_input():

sess.close()

#######################################################################
# Resize Bilinear
# ---------------

def _test_resize_bilinear(in_shape, to_shape, align_corners):
""" One iteration of resize bilinear """

data = np.random.uniform(size=in_shape).astype('float32')
shape_data = np.array(to_shape).astype('int32')

with tf.Graph().as_default():
in_data = constant_op.constant(data, shape=data.shape, dtype=data.dtype)
shape_data = constant_op.constant(shape_data, shape=shape_data.shape, dtype=shape_data.dtype)

# pylint: disable=unused-variable
resize_out = tf.image.resize_bilinear(in_data, shape_data, align_corners=align_corners)
# pylint: enable=unused-variable

with tf.Session() as sess:
graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
['ResizeBilinear'],
)

tf_output = run_tf_graph(sess, data,
'Const:0', 'ResizeBilinear:0')

tvm_output = run_tvm_graph(graph_def,
data,
"Const", tf_output.shape, data.dtype)

np.testing.assert_allclose(tf_output, tvm_output, atol=1e-3, rtol=1e-3)

sess.close()

def test_forward_resize_bilinear():
""" Resize Bilinear """

_test_resize_bilinear((4, 16, 32, 32), [50, 50], False)
_test_resize_bilinear((6, 32, 64, 64), [20, 20], True)


#######################################################################
# Inception V3
# ------------
def test_forward_inception_v3():
'''test inception V3 model'''
with tf.Graph().as_default():
(data, graph_def) = nnvm.testing.tf.get_workload_inception_v3()
graph_def = nnvm.testing.tf.get_workload('InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb')
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)

tvm_output = run_tvm_graph(graph_def, data, 'input', (1, 1001), 'float32')
data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32')

with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'input:0', 'InceptionV3/Predictions/Reshape_1:0')

top_tvm = np.squeeze(tvm_output).argsort()[-3:][::-1]
top_tf = np.squeeze(tf_output).argsort()[-3:][::-1]

np.testing.assert_allclose(top_tf, top_tvm, rtol=1e-5, atol=1e-5)
tvm_output = run_tvm_graph(graph_def, data, 'input', tf_output.shape, 'float32')
np.testing.assert_allclose(tf_output, tvm_output, rtol=1e-5, atol=1e-5)

#######################################################################
# Inception V1
# ------------
def test_forward_inception_v1():
'''test inception V1 model'''
with tf.Graph().as_default():
(data, tvm_data, graph_def) = nnvm.testing.tf.get_workload_inception_v1()
graph_def = nnvm.testing.tf.get_workload("InceptionV1/classify_image_graph_def-with_shapes.pb")
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)

tvm_output = run_tvm_graph(graph_def, tvm_data, 'DecodeJpeg/contents', (1, 1008), 'float32')
# Build an image from random data.
from PIL import Image
from tvm.contrib import util

img_array = np.random.uniform(size=(1, 600, 600, 3)).astype("uint8")
img = Image.frombuffer('RGB', (600, 600), img_array.tostring(), 'raw', 'RGB', 0, 1)
temp = util.tempdir()
img_path = temp.relpath("tf-test.jpg")
img.save(img_path);

import os.path
if not tf.gfile.Exists(os.path.join(img_path)):
tf.logging.fatal('File does not exist %s', image)
data = tf.gfile.FastGFile(os.path.join(img_path), 'rb').read()

temp.remove()

# Extract tensorflow decoded image frame for tvm input
with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'DecodeJpeg/contents:0', 'softmax:0')
tvm_data = run_tf_graph(sess, data, 'DecodeJpeg/contents:0', 'DecodeJpeg:0')

np.testing.assert_allclose(tf_output, tvm_output, rtol=2e-2, atol=2e-2)
with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'DecodeJpeg/contents:0', 'softmax:0')
tvm_output = run_tvm_graph(graph_def, tvm_data, 'DecodeJpeg/contents', tf_output.shape, 'float32')
np.testing.assert_allclose(tf_output, tvm_output, rtol=1e-5, atol=1e-5)

#######################################################################
# Mobilenet
# ---------
def test_forward_mobilenet():
'''test mobilenet model'''
with tf.Graph().as_default():
graph_def = nnvm.testing.tf.get_workload_mobilenet()
graph_def = nnvm.testing.tf.get_workload("MobilenetV1/mobilenet_v1_1.0_224_frozen-with-shapes.pb")
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)

Expand All @@ -520,12 +580,7 @@ def test_forward_mobilenet():

with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'input:0', out_node + ':0')

out_shape = tf_output.shape
tvm_output = run_tvm_graph(graph_def, data, 'input', out_shape, 'float32')
top_tvm = np.squeeze(tvm_output).argsort()[-10:][::-1]
top_tf = np.squeeze(tf_output).argsort()[-10:][::-1]

tvm_output = run_tvm_graph(graph_def, data, 'input', tf_output.shape, 'float32')
np.testing.assert_allclose(np.squeeze(tvm_output), np.squeeze(tf_output), rtol=1e-5, atol=1e-5)

#######################################################################
Expand All @@ -544,3 +599,4 @@ def test_forward_mobilenet():
test_forward_inception_v1()
test_forward_mobilenet()
test_forward_variable()
test_forward_resize_bilinear()
Loading