Skip to content

Commit

Permalink
[NNVM][FRONTEND][TENSORFLOW] bug fix on bilinear and resize op integr…
Browse files Browse the repository at this point in the history
…ation to frontend.

	* strengthen the testcases with random data.
  • Loading branch information
srkreddy1238 committed Jul 17, 2018
1 parent 30c1cd6 commit e21b247
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 222 deletions.
22 changes: 5 additions & 17 deletions nnvm/python/nnvm/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,26 +332,14 @@ def _impl(inputs, attr, params):

def _resize_bilinear():
def _impl(inputs, attr, params):
# Change this when we have corresponding resize bilinear operation.
print("ResizeBilinear:Only NN (nearest neighbor) supported in symetric mode of dimensions")
print("Change this when we have corresponding resize bilinear operation")

# NHWC
input_shape = attr['_input_shapes'][inputs[0]][0]
in_hw = (input_shape[1], input_shape[2])
out_hw = params[inputs[1].list_output_names()[0]]
attr['size'] = attr['_output_shapes'][0][1:3]
inputs.pop(1)
# NHWC
attr['layout'] = 'NHWC'

if in_hw[0] < 0 or in_hw[1] < 0:
scale = 1
else:
# Considering height alone for scale
scale = out_hw[0] / in_hw[0]

return AttrCvt(op_name="upsampling",
ignores=['Tdim', 'align_corners'],
extras={'scale': scale})(inputs, attr)
return AttrCvt(op_name="resize",
ignores=['Tdim'],
extras={'method': "BILINEAR"})(inputs, attr)
return _impl

def _check_numerics():
Expand Down
123 changes: 0 additions & 123 deletions nnvm/python/nnvm/testing/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
"""
import re
import os.path
import numpy as np

# Tensorflow imports
import tensorflow as tf
Expand Down Expand Up @@ -107,52 +106,6 @@ def id_to_string(self, node_id):
return ''
return self.node_lookup[node_id]

def read_normalized_tensor_from_image_file(file_name,
input_height=299,
input_width=299,
input_mean=0,
input_std=255):
""" Preprocessing of image
Parameters
----------
file_name: String
Image filename.
input_height: int
model input height.
input_width: int
model input width
input_mean: int
Mean to be substracted in normalization.
input_std: int
Standard deviation used in normalization.
Returns
-------
np_array: Numpy array
Normalized image data as a numpy array.
"""

input_name = "file_reader"
output_name = "normalized"
file_reader = tf.read_file(file_name, input_name)

image_reader = tf.image.decode_jpeg(file_reader, channels=3,
name='jpeg_reader')
float_caster = tf.cast(image_reader, tf.float32)
dims_expander = tf.expand_dims(float_caster, 0)
resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
tf.InteractiveSession()
np_array = normalized.eval()
return np_array

def get_workload(model_path):
""" Import workload from frozen protobuf
Expand Down Expand Up @@ -181,79 +134,3 @@ def get_workload(model_path):
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
return graph_def

def get_workload_inception_v3():
""" Import Inception V3 workload from frozen protobuf
Parameters
----------
Nothing.
Returns
-------
(normalized, graph_def) : Tuple
normalized is normalized input for graph testing.
graph_def is the tensorflow workload for Inception V3.
"""

repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV3/'
model_path = 'InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb'

image_name = 'elephant-299.jpg'
image_url = os.path.join(repo_base, image_name)
from mxnet.gluon.utils import download
download(image_url, image_name)
normalized = read_normalized_tensor_from_image_file(os.path.join("./", image_name))

return (normalized, get_workload(model_path))

def get_workload_inception_v1():
""" Import Inception V1 workload from frozen protobuf
Parameters
----------
Nothing.
Returns
-------
(image_data, tvm_data, graph_def) : Tuple
image_data is raw encoded image data for TF input.
tvm_data is the decoded image data for TVM input.
graph_def is the tensorflow workload for Inception V1.
"""

repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
model_path = 'InceptionV1/classify_image_graph_def-with_shapes.pb'
image_name = 'elephant-299.jpg'
image_url = os.path.join(repo_base, image_name)

from mxnet.gluon.utils import download
download(image_url, image_name)

if not tf.gfile.Exists(os.path.join("./", image_name)):
tf.logging.fatal('File does not exist %s', image)
image_data = tf.gfile.FastGFile(os.path.join("./", image_name), 'rb').read()

# TVM doesn't handle decode, hence decode it.
from PIL import Image
tvm_data = Image.open(os.path.join("./", image_name)).resize((299, 299))
tvm_data = np.array(tvm_data)

return (image_data, tvm_data, get_workload(model_path))

def get_workload_mobilenet():
""" Import mobilenet workload from frozen protobuf
Parameters
----------
Nothing.
Returns
-------
graph_def: graphdef
graph_def is the tensorflow workload for mobilenet.
"""

return get_workload("MobilenetV1/mobilenet_v1_1.0_224_frozen-with-shapes.pb")
10 changes: 3 additions & 7 deletions nnvm/src/top/image/resize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,9 @@ inline bool ResizeInferShape(const nnvm::NodeAttrs& attrs,
dshape = ConvertLayout(dshape, param.layout, kNCHW);

TShape oshape = dshape;
if (param.layout == "NCHW") {
oshape[2] = param.size[0];
oshape[3] = param.size[1];
} else {
oshape[1] = param.size[0];
oshape[2] = param.size[1];
}
oshape[2] = param.size[0];
oshape[3] = param.size[1];

oshape = ConvertLayout(oshape, kNCHW, param.layout);
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);

Expand Down
7 changes: 3 additions & 4 deletions nnvm/tests/python/compiler/test_top_level2.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,11 +300,10 @@ def test_upsampling_bilinear():

def test_resize_bilinear():
x = sym.Variable("x")
scale = 2
y = sym.upsampling(x, scale=scale, method="BILINEAR", name="y", layout="NHWC")
y = sym.resize(x, size=(60, 60), method="BILINEAR", name="y", layout="NHWC")
dtype = "float32"
dshape = (1, 32, 32, 4)
oshape = (1, 32*scale, 32*scale, 4)
oshape = (1, 60, 60, 4)
shape_dict = {"x": dshape}
dtype_dict = {"x": dtype}
for target, ctx in ctx_list():
Expand All @@ -314,7 +313,7 @@ def test_resize_bilinear():
data = tvm.nd.array(a_np)
m.run(x=data)
out = m.get_output(0, tvm.nd.empty(oshape, dtype))
b_np = topi.testing.bilinear_resize_python(a_np, (32*scale, 32*scale), "NHWC")
b_np = topi.testing.bilinear_resize_python(a_np, (60, 60), "NHWC")
np.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5, atol=1e-5)

if __name__ == "__main__":
Expand Down
86 changes: 68 additions & 18 deletions nnvm/tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,49 +469,103 @@ def test_forward_multi_input():

sess.close()

#######################################################################
# Resize Bilinear
# ---------------

def _test_resize_bilinear(in_shape, to_shape, align_corners):
""" One iteration of resize bilinear """

data = np.random.uniform(size=in_shape).astype('float32')
shape_data = np.array(to_shape).astype('int32')

with tf.Graph().as_default():
in_data = constant_op.constant(data, shape=data.shape, dtype=data.dtype)
shape_data = constant_op.constant(shape_data, shape=shape_data.shape, dtype=shape_data.dtype)

# pylint: disable=unused-variable
resize_out = tf.image.resize_bilinear(in_data, shape_data, align_corners=align_corners)
# pylint: enable=unused-variable

with tf.Session() as sess:
graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
['ResizeBilinear'],
)

tf_output = run_tf_graph(sess, data,
'Const:0', 'ResizeBilinear:0')

tvm_output = run_tvm_graph(graph_def,
data,
"Const", tf_output.shape, data.dtype)

np.testing.assert_allclose(tf_output, tvm_output, atol=1e-3, rtol=1e-3)

sess.close()

def test_forward_resize_bilinear():
""" Resize Bilinear """

_test_resize_bilinear((4, 16, 32, 32), [50, 50], False)
_test_resize_bilinear((6, 32, 64, 64), [20, 20], True)


#######################################################################
# Inception V3
# ------------
def test_forward_inception_v3():
'''test inception V3 model'''
with tf.Graph().as_default():
(data, graph_def) = nnvm.testing.tf.get_workload_inception_v3()
graph_def = nnvm.testing.tf.get_workload('InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb')
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)

tvm_output = run_tvm_graph(graph_def, data, 'input', (1, 1001), 'float32')
data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32')

with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'input:0', 'InceptionV3/Predictions/Reshape_1:0')

top_tvm = np.squeeze(tvm_output).argsort()[-3:][::-1]
top_tf = np.squeeze(tf_output).argsort()[-3:][::-1]

np.testing.assert_allclose(top_tf, top_tvm, rtol=1e-5, atol=1e-5)
tvm_output = run_tvm_graph(graph_def, data, 'input', tf_output.shape, 'float32')
np.testing.assert_allclose(tf_output, tvm_output, rtol=1e-5, atol=1e-5)

#######################################################################
# Inception V1
# ------------
def test_forward_inception_v1():
'''test inception V1 model'''
with tf.Graph().as_default():
(data, tvm_data, graph_def) = nnvm.testing.tf.get_workload_inception_v1()
graph_def = nnvm.testing.tf.get_workload("InceptionV1/classify_image_graph_def-with_shapes.pb")
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)

tvm_output = run_tvm_graph(graph_def, tvm_data, 'DecodeJpeg/contents', (1, 1008), 'float32')
# Build an image from random data.
from PIL import Image
img_array = np.random.uniform(size=(1, 600, 600, 3)).astype("uint8")
img = Image.frombuffer('RGB', (600, 600), img_array.tostring(), 'raw', 'RGB', 0, 1)
img.save('tf_test.jpg');

import os.path
if not tf.gfile.Exists(os.path.join('./tf_test.jpg')):
tf.logging.fatal('File does not exist %s', image)
data = tf.gfile.FastGFile(os.path.join("./tf_test.jpg"), 'rb').read()

# Extract tensorflow decoded image frame for tvm input
with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'DecodeJpeg/contents:0', 'softmax:0')
tvm_data = run_tf_graph(sess, data, 'DecodeJpeg/contents:0', 'DecodeJpeg:0')

np.testing.assert_allclose(tf_output, tvm_output, rtol=2e-2, atol=2e-2)
with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'DecodeJpeg/contents:0', 'softmax:0')
tvm_output = run_tvm_graph(graph_def, tvm_data, 'DecodeJpeg/contents', tf_output.shape, 'float32')
np.testing.assert_allclose(tf_output, tvm_output, rtol=1e-5, atol=1e-5)

#######################################################################
# Mobilenet
# ---------
def test_forward_mobilenet():
'''test mobilenet model'''
with tf.Graph().as_default():
graph_def = nnvm.testing.tf.get_workload_mobilenet()
graph_def = nnvm.testing.tf.get_workload("MobilenetV1/mobilenet_v1_1.0_224_frozen-with-shapes.pb")
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)

Expand All @@ -520,12 +574,7 @@ def test_forward_mobilenet():

with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'input:0', out_node + ':0')

out_shape = tf_output.shape
tvm_output = run_tvm_graph(graph_def, data, 'input', out_shape, 'float32')
top_tvm = np.squeeze(tvm_output).argsort()[-10:][::-1]
top_tf = np.squeeze(tf_output).argsort()[-10:][::-1]

tvm_output = run_tvm_graph(graph_def, data, 'input', tf_output.shape, 'float32')
np.testing.assert_allclose(np.squeeze(tvm_output), np.squeeze(tf_output), rtol=1e-5, atol=1e-5)

#######################################################################
Expand All @@ -544,3 +593,4 @@ def test_forward_mobilenet():
test_forward_inception_v1()
test_forward_mobilenet()
test_forward_variable()
test_forward_resize_bilinear()
Loading

0 comments on commit e21b247

Please sign in to comment.