From d4f5909c0785070af60a7967b4d851e37e352d18 Mon Sep 17 00:00:00 2001 From: lfengad Date: Thu, 5 Mar 2020 15:58:31 +0800 Subject: [PATCH] Add BN support with run-time mean and variance calculation --- python/tvm/relay/frontend/tensorflow.py | 9 ++- .../tensorflow/test_bn_trainingmod.py | 61 +++++++++++++++++++ 2 files changed, 69 insertions(+), 1 deletion(-) create mode 100644 tests/python/frontend/tensorflow/test_bn_trainingmod.py diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 14d2418da7100..2cbbc019db347 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -887,7 +887,14 @@ def _impl(inputs, attr, params): if 'U' in attr: need_cast = True inputs[0] = _op.cast(inputs[0], dtype=attr['U'].name) - + # Check if mean and variance are empty + # If so, replace them with Mean and Variance Ops + # For run-time calculation + moving_mean_shape = [int(n) for n in inputs[3].type_annotation.shape] + moving_variance_shape = [int(n) for n in inputs[4].type_annotation.shape] + if (moving_mean_shape[0] == 0 and moving_variance_shape[0] == 0): + inputs[3] = _op.mean(inputs[0], axis=axis, keepdims=False, exclude=True) + inputs[4] = _op.variance(inputs[0], axis=axis, keepdims=False, exclude=True) out = AttrCvt(op_name='batch_norm', transforms={'scale_after_normalization':'scale', 'variance_epsilon':'epsilon'}, diff --git a/tests/python/frontend/tensorflow/test_bn_trainingmod.py b/tests/python/frontend/tensorflow/test_bn_trainingmod.py new file mode 100644 index 0000000000000..babacfa470a12 --- /dev/null +++ b/tests/python/frontend/tensorflow/test_bn_trainingmod.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +BatchNorm without given mean and variance given testcases +==================== +This article is a test script to test fused_batch_norm operators in TensorFlow frontend when mean and variance are not given. +""" +import tvm +import numpy as np +import tensorflow as tf +from tvm import relay +from tensorflow.python.framework import graph_util + +def test_fusedbatchnorm(): + g=tf.Graph() + with g.as_default(): + input_tensor = tf.placeholder(tf.float32,shape=(1, 12, 12, 32), name='input') + alpha = tf.constant(np.random.rand(32,), dtype=tf.float32, name='alpha') + beta = tf.constant(np.random.rand(32,), dtype=tf.float32, name='beta') + bn = tf.nn.fused_batch_norm(x=input_tensor, offset=beta, scale=alpha, name='bn') + out = tf.identity(bn[0], name='sum') + data = np.random.rand(1, 12, 12, 32) + with tf.Session(graph=out.graph) as sess: + sess.run([tf.global_variables_initializer()]) + tf_out = sess.run(out, feed_dict={input_tensor:data}) + constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['sum']) + + + layout = None + target = 'llvm' + ctx=tvm.cpu(0) + mod, params = relay.frontend.from_tensorflow(constant_graph, layout=layout, outputs=['sum']) + with relay.build_config(opt_level=3): + graph, lib, params = relay.build(mod, + target=target, + target_host = target, + params=params) + from tvm.contrib import graph_runtime + m = graph_runtime.create(graph, lib, ctx) + m.set_input(**params) + m.set_input('input', data) + m.run() + tvm_out=m.get_output(0) + tvm.testing.assert_allclose(tvm_out.asnumpy(), tf_out.astype(tvm_out.dtype), rtol=1e-3) + +if __name__ == "__main__": + test_fusedbatchnorm()