From eadb72ac4db257cd0c40655a1cef54ea04356e1c Mon Sep 17 00:00:00 2001 From: SRK Reddy Date: Sun, 27 May 2018 22:47:32 +0530 Subject: [PATCH] [Tensor Flow][Test Cases] Test cases updated. --- .../frontend/tensorflow/test_forward.py | 504 +++++++++--------- 1 file changed, 264 insertions(+), 240 deletions(-) diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index c7629a4f6..814e9956c 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -7,88 +7,33 @@ by the script. """ from __future__ import print_function -import os -import sys -import urllib -import requests import numpy as np import nnvm.compiler import tvm import tensorflow as tf -from tensorflow.core.framework import graph_pb2 from tensorflow.python.framework import constant_op from tensorflow.python.ops import nn_ops - - -if sys.version_info >= (3,): - import urllib.request as urllib2 -else: - import urllib2 - -####################################################################### -# Some tensorflow helper functions to handle models -# ------------------------------------------------- -def process_graph_default(graph_def): - """Type-checks and possibly canonicalizes `graph_def`.""" - if not isinstance(graph_def, graph_pb2.GraphDef): - # `graph_def` could be a dynamically-created message, so try a duck-typed - # approach - try: - old_graph_def = graph_def - graph_def = graph_pb2.GraphDef() - graph_def.MergeFrom(old_graph_def) - except TypeError: - raise TypeError('graph_def must be a GraphDef proto.') - return graph_def - - -def load_graph(model_name): - with tf.gfile.FastGFile(model_name, 'rb') as f: - graph_def = tf.GraphDef() - graph_def.ParseFromString(f.read()) - # pylint: disable=unused-variable - graph = tf.import_graph_def(graph_def, name='') - # pylint: enable=unused-variable - graph_def = process_graph_default(graph_def) - return graph_def - -####################################################################### -# File download helper function -# ----------------------------- -def _download(url, path, overwrite=False, sizecompare=False): - ''' Download from internet''' - if os.path.isfile(path) and not overwrite: - if sizecompare: - file_size = os.path.getsize(path) - res_head = requests.head(url) - res_get = requests.get(url, stream=True) - if 'Content-Length' not in res_head.headers: - res_get = urllib2.urlopen(url) - urlfile_size = int(res_get.headers['Content-Length']) - if urlfile_size != file_size: - print("exist file got corrupted, downloading", path, " file freshly") - _download(url, path, True, False) - return - print('File {} exists, skip.'.format(path)) - return - print('Downloading from url {} to {}'.format(url, path)) - # pylint: disable=bare-except - try: - urllib.request.urlretrieve(url, path) - print('') - except: - urllib.urlretrieve(url, path) - # pylint: enable=bare-except +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops ####################################################################### # Generic run functions for TVM & tensorflow # ------------------------------------------ def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype): """ Generic function to compile on nnvm and execute on tvm """ + sym, params = nnvm.frontend.from_tensorflow(graph_def) target = 'llvm' - shape_dict = {input_node: input_data.shape} - dtype_dict = {input_node: input_data.dtype} + if isinstance(input_data, list): + shape_dict = {} + dtype_dict = {} + for i, e in enumerate(input_node): + shape_dict[e] = input_data[i].shape + dtype_dict[e] = input_data[i].dtype + else: + shape_dict = {input_node: input_data.shape} + dtype_dict = {input_node: input_data.dtype} + graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, dtype=dtype_dict, params=params) @@ -96,7 +41,12 @@ def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype) from tvm.contrib import graph_runtime m = graph_runtime.create(graph, lib, ctx) # set inputs - m.set_input(input_node, tvm.nd.array(input_data.astype(input_data.dtype))) + if isinstance(input_data, list): + for i, e in enumerate(input_node): + m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype))) + else: + m.set_input(input_node, tvm.nd.array(input_data.astype(input_data.dtype))) + m.set_input(**params) # execute m.run() @@ -105,158 +55,93 @@ def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype) return tvm_output.asnumpy() def run_tf_graph(sess, input_data, input_node, output_node): - tensor = sess.graph.get_tensor_by_name(output_node) - output_data = sess.run(tensor, {input_node: input_data}) - return output_data - -####################################################################### -# Inception V1 -# ------------ -def inception_v1_tvm(graph_def, image_name): - from PIL import Image - image = Image.open(image_name).resize((299, 299)) - image = np.array(image) - - output = run_tvm_graph(graph_def, image, 'DecodeJpeg/contents', (1, 1008), 'float32') - return np.squeeze(output) - - -def inception_v1_tf(graph_def, image_name): - if not tf.gfile.Exists(image_name): - tf.logging.fatal('File does not exist %s', image) - image_data = tf.gfile.FastGFile(image_name, 'rb').read() - - with tf.Session() as sess: - output = run_tf_graph(sess, image_data, 'DecodeJpeg/contents:0', 'softmax:0') - return np.squeeze(output) + """ Generic function to execute tensor flow """ -def test_forward_inception_v1(): - '''test inception V1 model''' - model_name = 'inception_v1' + tensor = sess.graph.get_tensor_by_name(output_node) - repo = 'https://github.com/srkreddy1238/dmlc_data/raw/master/models/tensorflow/InceptionV1/' - model_name = 'classify_image_graph_def-with_shapes.pb' + if isinstance(input_data, list): + input_dict = {} + for i, e in enumerate(input_node): + input_dict[e] = input_data[i] + else: + input_dict = {input_node: input_data} - model_url = repo + model_name - _download(model_url, model_name) + output_data = sess.run(tensor, input_dict) + return output_data - graph_def = load_graph(model_name) +####################################################################### +# Pooling +# ------- +def test_pooling(input_shape, **kwargs): + """ One iteration of pool operation with given shapes and attributes """ - image_name = 'elephant-299.jpg' - image_url = repo + image_name - _download(image_url, image_name) + x = -np.arange( + np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1 - tf_output = inception_v1_tf(graph_def, image_name) - tvm_output = inception_v1_tvm(graph_def, image_name) + with tf.Graph().as_default(): + in_data = constant_op.constant(x, shape=input_shape, dtype='float32') + # pylint: disable=unused-variable + pool = nn_ops.pool(in_data, **kwargs) + # pylint: enable=unused-variable - np.testing.assert_allclose(tf_output, tvm_output, rtol=2e-2, atol=2e-2) + if kwargs['pooling_type'] == 'MAX': + out_node = 'max_pool' + out_name = 'max_pool:0' + else: + out_node = 'avg_pool' + out_name = 'avg_pool:0' + + with tf.Session() as sess: + graph_def = tf.graph_util.convert_variables_to_constants( + sess, + sess.graph.as_graph_def(add_shapes=True), + [out_node], + ) + + tf_output = run_tf_graph(sess, x, 'Const:0', out_name) + tvm_output = run_tvm_graph(graph_def, x.astype('float32'), + "Const", tf_output.shape, 'float32') + np.testing.assert_allclose(tf_output, tvm_output, atol=1e-3, rtol=1e-3) + + sess.close() + +def test_forward_pooling(): + """ Pooling """ + + test_pooling(input_shape=[2, 9, 10, 2], + window_shape=[1, 1], + padding='SAME', + pooling_type='MAX', + dilation_rate=[1, 1], + strides=[1, 1]) + + test_pooling(input_shape=[2, 9, 10, 2], + window_shape=[1, 1], + padding='SAME', + pooling_type='AVG', + dilation_rate=[1, 1], + strides=[1, 1]) + test_pooling(input_shape=[2, 10, 9, 2], + window_shape=[1, 1], + padding='SAME', + pooling_type='MAX', + dilation_rate=[1, 1], + strides=[1, 1]) + test_pooling(input_shape=[2, 10, 9, 2], + window_shape=[1, 1], + padding='SAME', + pooling_type='AVG', + dilation_rate=[1, 1], + strides=[1, 1]) ####################################################################### # Convolution # ----------- -# Borrowed from tensorflow for test cases. -def get_shrunk_inception_shapes(shrink=10): - """Iterator for smaller versions of convolution shapes in 2015 Inception. - - Relative to inception, each depth value is `depth // shrink`. - - Args: - shrink: Factor to shrink each depth value by relative to Inception. - - Yields: - Tuple (input_size, filter_size, out_size, stride, padding), the convolution - parameters of Inception layers. - """ - input_sizes = [[4, 5, 5, 1248], [4, 8, 8, 384], [4, 8, 8, 384], - [4, 8, 8, 2048], [4, 8, 8, 448], [4, 8, 8, 2048], - [4, 8, 8, 2048], [4, 8, 8, 2048], [4, 8, 8, 1760], - [4, 8, 8, 1760], [4, 8, 8, 1760], [4, 8, 8, 1760], - [4, 17, 17, 192], [4, 17, 17, 192], [4, 17, 17, 1248], - [4, 17, 17, 128], [4, 17, 17, 1248], [4, 17, 17, 224], - [4, 17, 17, 192], [4, 17, 17, 192], [4, 17, 17, 1216], - [4, 17, 17, 1216], [4, 17, 17, 224], [4, 17, 17, 192], - [4, 17, 17, 192], [4, 17, 17, 1152], [4, 17, 17, 1152], - [4, 17, 17, 192], [4, 17, 17, 160], [4, 17, 17, 1152], - [4, 17, 17, 1024], [4, 17, 17, 128], [4, 17, 17, 1024], - [4, 17, 17, 128], [4, 17, 17, 1024], [4, 17, 17, 128], - [4, 17, 17, 768], [4, 17, 17, 128], [4, 17, 17, 128], - [4, 17, 17, 768], [4, 17, 17, 768], [4, 35, 35, 96], - [4, 35, 35, 288], [4, 35, 35, 64], [4, 35, 35, 288], - [4, 35, 35, 256], [4, 35, 35, 48], [4, 35, 35, 256], - [4, 35, 35, 96], [4, 35, 35, 192], [4, 35, 35, 192], - [4, 35, 35, 192], [4, 73, 73, 64], [4, 73, 73, 64], - [4, 147, 147, 24]] - filter_sizes = [[1, 1, 1248, 128], [1, 3, 384, 384], [3, 1, 384, 384], - [1, 1, 2048, 192], [3, 3, 448, 384], [1, 1, 2048, 320], - [1, 1, 2048, 448], [1, 1, 2048, 384], [1, 1, 1760, 384], - [1, 1, 1760, 192], [1, 1, 1760, 448], [1, 1, 1760, 320], - [3, 3, 192, 192], [3, 3, 192, 192], [1, 1, 1248, 192], - [3, 3, 128, 320], [1, 1, 1248, 128], [1, 3, 224, 224], - [3, 1, 192, 256], [1, 3, 192, 256], [1, 1, 1216, 192], - [1, 1, 1216, 96], [3, 1, 224, 224], [3, 3, 192, 224], - [1, 3, 192, 192], [1, 1, 1152, 192], [1, 1, 1152, 128], - [3, 1, 192, 192], [3, 3, 160, 192], [1, 1, 1152, 160], - [1, 1, 1024, 128], [1, 3, 128, 192], [1, 1, 1024, 160], - [3, 1, 128, 192], [1, 1, 1024, 256], [3, 1, 128, 128], - [1, 1, 768, 192], [1, 3, 128, 128], [3, 3, 128, 128], - [1, 1, 768, 128], [1, 1, 768, 320], [3, 3, 96, 96], - [3, 3, 288, 384], [3, 3, 64, 96], [1, 1, 288, 64], - [1, 1, 256, 64], [5, 5, 48, 64], [1, 1, 256, 48], - [3, 3, 96, 96], [1, 1, 192, 32], [1, 1, 192, 64], - [1, 1, 192, 48], [3, 3, 64, 192], [1, 1, 64, 64], - [1, 1, 24, 64]] - out_sizes = [[4, 5, 5, 128], [4, 8, 8, 384], [4, 8, 8, 384], - [4, 8, 8, 192], [4, 8, 8, 384], [4, 8, 8, 320], - [4, 8, 8, 448], [4, 8, 8, 384], [4, 8, 8, 384], - [4, 8, 8, 192], [4, 8, 8, 448], [4, 8, 8, 320], - [4, 8, 8, 192], [4, 17, 17, 192], [4, 17, 17, 192], - [4, 8, 8, 320], [4, 17, 17, 128], [4, 17, 17, 224], - [4, 17, 17, 256], [4, 17, 17, 256], [4, 17, 17, 192], - [4, 17, 17, 96], [4, 17, 17, 224], [4, 17, 17, 224], - [4, 17, 17, 192], [4, 17, 17, 192], [4, 17, 17, 128], - [4, 17, 17, 192], [4, 17, 17, 192], [4, 17, 17, 160], - [4, 17, 17, 128], [4, 17, 17, 192], [4, 17, 17, 160], - [4, 17, 17, 192], [4, 17, 17, 256], [4, 17, 17, 128], - [4, 17, 17, 192], [4, 17, 17, 128], [4, 17, 17, 128], - [4, 17, 17, 128], [4, 17, 17, 320], [4, 17, 17, 96], - [4, 17, 17, 384], [4, 35, 35, 96], [4, 35, 35, 64], - [4, 35, 35, 64], [4, 35, 35, 64], [4, 35, 35, 48], - [4, 35, 35, 96], [4, 35, 35, 32], [4, 35, 35, 64], - [4, 35, 35, 48], [4, 71, 71, 192], [4, 73, 73, 64], - [4, 147, 147, 64]] - strides = [ - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1 - ] - # Shrink sizes to make the test faster - # pylint: disable=invalid-name - for i in input_sizes: - i[3] //= shrink - for f in filter_sizes: - f[2] //= shrink - f[3] //= shrink - for o in out_sizes: - o[3] //= shrink - - VALID = "VALID" - SAME = "SAME" - paddings = [ - SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, - VALID, SAME, SAME, VALID, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, - SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, - SAME, SAME, SAME, SAME, SAME, VALID, VALID, SAME, SAME, SAME, SAME, SAME, - SAME, SAME, SAME, SAME, VALID, VALID, VALID - ] - for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides, - paddings): - yield i, f, o, s, p - # pylint: enable=invalid-name - -def test_convolution_iteration(tensor_in_sizes, filter_in_sizes, - dilations, strides, padding, data_format): +def test_convolution(tensor_in_sizes, filter_in_sizes, + dilations, strides, padding, data_format): """ One iteration of convolution with given shapes and attributes """ + total_size_1 = 1 total_size_2 = 1 for s in tensor_in_sizes: @@ -268,48 +153,187 @@ def test_convolution_iteration(tensor_in_sizes, filter_in_sizes, data_array = [f * 1.0 for f in range(1, total_size_1 + 1)] filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)] - in_data = constant_op.constant(data_array, shape=tensor_in_sizes, dtype='float32') - in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32') - strides = [1] + strides + [1] - dilations = [1] + dilations + [1] + with tf.Graph().as_default(): + in_data = constant_op.constant(data_array, shape=tensor_in_sizes, dtype='float32') + in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32') + strides = [1] + strides + [1] + dilations = [1] + dilations + [1] - # pylint: disable=unused-variable - conv = nn_ops.conv2d(in_data, - in_filter, - strides=strides, - padding=padding, - data_format=data_format) - # pylint: enable=unused-variable + # pylint: disable=unused-variable + conv = nn_ops.conv2d(in_data, + in_filter, + strides=strides, + padding=padding, + data_format=data_format) + # pylint: enable=unused-variable - with tf.Session() as sess: - graph_def = tf.graph_util.convert_variables_to_constants( - sess, - sess.graph.as_graph_def(add_shapes=True), - ['Conv2D'], - ) + with tf.Session() as sess: + graph_def = tf.graph_util.convert_variables_to_constants( + sess, + sess.graph.as_graph_def(add_shapes=True), + ['Conv2D'], + ) - tf_output = run_tf_graph(sess, np.reshape(data_array, tensor_in_sizes), - 'Const:0', 'Conv2D:0') - tvm_output = run_tvm_graph(graph_def, - np.reshape(data_array, tensor_in_sizes).astype('float32'), - "Const", tf_output.shape, 'float32') + tf_output = run_tf_graph(sess, np.reshape(data_array, tensor_in_sizes), + 'Const:0', 'Conv2D:0') + tvm_output = run_tvm_graph(graph_def, + np.reshape(data_array, tensor_in_sizes).astype('float32'), + "Const", tf_output.shape, 'float32') - np.testing.assert_allclose(tf_output, tvm_output, atol=1e-3, rtol=1e-3) + np.testing.assert_allclose(tf_output, tvm_output, atol=1e-3, rtol=1e-3) - sess.close() + sess.close() def test_forward_convolution(): - # pylint: disable=unused-variable - for index, (input_size_, filter_size_, output_size_, stride_, - padding_) in enumerate(get_shrunk_inception_shapes()): - with tf.Graph().as_default(): - test_convolution_iteration(input_size_, filter_size_, [1, 1], - [stride_, stride_], padding_, 'NHWC') - # pylint: enable=unused-variable + test_convolution([4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC') + test_convolution([4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC') + test_convolution([4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NHWC') + test_convolution([4, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NHWC') + +####################################################################### +# Reshape +# ----------- + +def test_reshape(data, out_shape): + """ One iteration of reshape operation with given sata and out shape """ + + with tf.Graph().as_default(): + in_data = constant_op.constant(data, shape=data.shape, dtype=data.dtype) + + # pylint: disable=unused-variable + reshape_out = array_ops.reshape(in_data, out_shape) + # pylint: enable=unused-variable + + with tf.Session() as sess: + graph_def = tf.graph_util.convert_variables_to_constants( + sess, + sess.graph.as_graph_def(add_shapes=True), + ['Reshape'], + ) + + tf_output = run_tf_graph(sess, data, + 'Const:0', 'Reshape:0') + tvm_output = run_tvm_graph(graph_def, + data, + "Const", tf_output.shape, data.dtype) + + np.testing.assert_allclose(tf_output, tvm_output) + + sess.close() + +def test_forward_reshape(): + test_reshape(np.arange(6.0), [2, 3]) + test_reshape(np.arange(6), [-1, 2]) + test_reshape(np.arange(6), [3, -1]) + test_reshape(np.arange(6), [-1]) + +####################################################################### +# Squeeze +# ----------- + +def test_squeeze(data, squeeze_dims=None): + """ One iteration of squeeze """ + + if squeeze_dims is None: + squeeze_dims = [] + + with tf.Graph().as_default(): + in_data = constant_op.constant(data, shape=data.shape, dtype=data.dtype) + + # pylint: disable=unused-variable + if squeeze_dims: + squeeze_out = array_ops.squeeze(in_data, squeeze_dims) + else: + squeeze_out = array_ops.squeeze(in_data) + # pylint: enable=unused-variable + + with tf.Session() as sess: + graph_def = tf.graph_util.convert_variables_to_constants( + sess, + sess.graph.as_graph_def(add_shapes=True), + ['Squeeze'], + ) + + tf_output = run_tf_graph(sess, data, + 'Const:0', 'Squeeze:0') + tvm_output = run_tvm_graph(graph_def, + data, + "Const", tf_output.shape, data.dtype) + + np.testing.assert_allclose(tf_output, tvm_output) + + sess.close() + +def test_forward_squeeze(): + """ Squeeze """ + + # Nothing to squeeze. + test_squeeze(np.arange(2).reshape((2))) + test_squeeze(np.arange(6).reshape((2, 3))) + + # Squeeze the middle element away. + test_squeeze(np.arange(4).reshape((2, 1, 2))) + + # Squeeze on both ends. + test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1))) + + # Positive squeeze dim index. + test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [0]) + test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [2, 4]) + test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [0, 4, 2]) + + # Negative squeeze dim index. + test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-1]) + test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5]) + test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5, -1]) + +####################################################################### +# ConcatV2 +# ----------- + +def test_concat_v2(data, dim): + """ One iteration of ConcatV2 """ + + with tf.Graph().as_default(): + + # pylint: disable=unused-variable + concat_out = gen_array_ops._concat_v2(data, dim) + # pylint: enable=unused-variable + + with tf.Session() as sess: + graph_def = tf.graph_util.convert_variables_to_constants( + sess, + sess.graph.as_graph_def(add_shapes=True), + ['ConcatV2'], + ) + + tf_output = run_tf_graph(sess, data, + ['ConcatV2/values_0:0', 'ConcatV2/values_1:0'], 'ConcatV2:0') + tvm_output = run_tvm_graph(graph_def, + data, + ["ConcatV2/values_0", 'ConcatV2/values_1'], + tf_output.shape, tf_output.dtype) + + np.testing.assert_allclose(tf_output, tvm_output) + + sess.close() + +def test_forward_concat_v2(): + t1 = np.array([]) + t2 = np.array([]) + test_concat_v2([t1, t2], 0) + + t1 = np.array([[1, 2, 3], [4, 5, 6]]) + t2 = np.array([[7, 8, 9], [10, 11, 12]]) + + test_concat_v2([t1, t2], 1) ####################################################################### # Main # ---- if __name__ == '__main__': - test_forward_inception_v1() test_forward_convolution() + test_forward_pooling() + test_forward_reshape() + test_forward_squeeze() + test_forward_concat_v2()