From d52474ba06e09c4cd300e063a53d33541d02be74 Mon Sep 17 00:00:00 2001
From: Selo1412 <62495665+selo1412@users.noreply.github.com>
Date: Fri, 27 Mar 2020 01:30:38 +0800
Subject: [PATCH] [Relay][OP] Register topi schedule for Relay fast_exp and
 fast_tanh (#5131)

* register for fast_exp and fast_tanh

* Add unit test for fast math

* Add unit test for op fast math

* Add unit test for op fast math

* Add unit tests to guard registering topi schedule for Relay fast_exp and fast_tanh

* Fix ident

* Fix the indent

* Add fast_tanh in the test_fastmath of topi tests
---
 python/tvm/relay/op/_tensor.py          |  4 ++
 tests/python/relay/test_op_fast_math.py | 59 +++++++++++++++++++++++++
 topi/tests/python/test_topi_math.py     |  3 ++
 3 files changed, 66 insertions(+)
 create mode 100644 tests/python/relay/test_op_fast_math.py

diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py
index 4b5eadaaf31c..eb355015f815 100644
--- a/python/tvm/relay/op/_tensor.py
+++ b/python/tvm/relay/op/_tensor.py
@@ -74,6 +74,8 @@
 register_injective_schedule("left_shift")
 register_injective_schedule("shape_of")
 register_injective_schedule("ndarray_size")
+register_broadcast_schedule("fast_exp")
+register_broadcast_schedule("fast_tanh")
 
 
 # zeros
@@ -218,3 +220,5 @@ def elemwise_shape_func(attrs, inputs, _):
 register_shape_func("negative", False, elemwise_shape_func)
 register_shape_func("exp", False, elemwise_shape_func)
 register_shape_func("tan", False, elemwise_shape_func)
+register_shape_func("fast_exp", False, elemwise_shape_func)
+register_shape_func("fast_tanh", False, elemwise_shape_func)
diff --git a/tests/python/relay/test_op_fast_math.py b/tests/python/relay/test_op_fast_math.py
new file mode 100644
index 000000000000..1d661c380af0
--- /dev/null
+++ b/tests/python/relay/test_op_fast_math.py
@@ -0,0 +1,59 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import tvm
+import tvm.relay as relay
+import topi
+from tvm import te
+from tvm.contrib import graph_runtime
+
+
+def test_fastmath():
+    def test_apply(relay_op, name, f_numpy, low, high, step, dtype="float32"):
+        a_np = np.arange(low, high, step).astype(dtype)
+        b_np = f_numpy(a_np)
+        x = relay.var("x", shape=a_np.shape, dtype=dtype)
+        y = relay_op(x)
+        func = relay.Function([x], y)
+        mod = tvm.IRModule.from_expr(func)
+
+        with relay.build_config(opt_level=3, required_pass=['FastMath']):
+            graph, lib, params = relay.build(mod, target="llvm", params=None)
+
+        # Check that the fast math op has been converted to a function in lib
+        func_name = "fused_" + name
+        assert lib.get_function(func_name)
+
+        ctx = tvm.cpu(0)
+        m = graph_runtime.create(graph, lib, ctx)
+        # Set inputs
+        m.set_input('x', tvm.nd.array(a_np, ctx))
+        m.set_input(**params)
+        # Execute
+        m.run()
+        # Get outputs
+        tvm_output = m.get_output(0)
+        tvm.testing.assert_allclose(tvm_output.asnumpy(), b_np,
+                                    rtol=1e-5, atol=1e-5)
+
+    test_apply(relay.exp, "fast_exp", np.exp, low=-88, high=88, step=0.01)
+    test_apply(relay.tanh, "fast_tanh", np.tanh, low=-10, high=10, step=0.01)
+
+
+if __name__ == "__main__":
+    test_fastmath()
diff --git a/topi/tests/python/test_topi_math.py b/topi/tests/python/test_topi_math.py
index a8a56ef73de3..94b78a9a3ebe 100644
--- a/topi/tests/python/test_topi_math.py
+++ b/topi/tests/python/test_topi_math.py
@@ -240,6 +240,9 @@ def check_device(device):
     test_apply(topi.fast_exp, "fast_exp", np.exp,
                low=-88, high=88,
                step = 0.01)
+    test_apply(topi.fast_tanh, "fast_tanh", np.tanh,
+               low=-10, high=10,
+               step = 0.01)
 
 if __name__ == "__main__":
     test_util()
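
Note (reviewer sketch, not part of the patch): the registrations above attach the generic broadcast schedule to the Relay fast_exp and fast_tanh ops, whose computes already live in topi. A minimal standalone example of driving topi.fast_exp directly through the te workflow, assuming the tvm/topi Python API of this era, looks like the following:

    # Reviewer sketch, not part of this patch; assumes the era's te/topi API.
    import numpy as np
    import tvm
    import topi
    from tvm import te

    # Declare a 1-D computation using the fast polynomial approximation of exp.
    A = te.placeholder((1024,), name="A", dtype="float32")
    B = topi.fast_exp(A)

    # Lower and build with a default schedule on CPU.
    s = te.create_schedule(B.op)
    f = tvm.build(s, [A, B], target="llvm")

    ctx = tvm.cpu(0)
    a = tvm.nd.array(np.random.uniform(-10, 10, size=1024).astype("float32"), ctx)
    b = tvm.nd.array(np.zeros(1024, dtype="float32"), ctx)
    f(a, b)

    # fast_exp trades a little accuracy for speed, so compare with loose tolerances.
    np.testing.assert_allclose(b.asnumpy(), np.exp(a.asnumpy()), rtol=1e-5, atol=1e-5)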