Skip to content

Commit

Permalink
[Relay][OP] Register topi schedule for Relay fast_exp and fast_tanh (a…
Browse files Browse the repository at this point in the history
…pache#5131)

* register for fast_exp and fast_tanh

* Add unit test for fast math

* Add unit test for op fast math

* Add unit test for op fast math

* Add unit tests to guard registering topi schedule for Relay fast_exp and fast_tanh

* Fix ident

* Fix the indent

* Add fast_tanh in the test_fastmath of topi tests
  • Loading branch information
selo1412 authored and Trevor Morris committed Apr 16, 2020
1 parent e9e9301 commit d52474b
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 0 deletions.
4 changes: 4 additions & 0 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@
register_injective_schedule("left_shift")
register_injective_schedule("shape_of")
register_injective_schedule("ndarray_size")
register_broadcast_schedule("fast_exp")
register_broadcast_schedule("fast_tanh")


# zeros
Expand Down Expand Up @@ -218,3 +220,5 @@ def elemwise_shape_func(attrs, inputs, _):
register_shape_func("negative", False, elemwise_shape_func)
register_shape_func("exp", False, elemwise_shape_func)
register_shape_func("tan", False, elemwise_shape_func)
register_shape_func("fast_exp", False, elemwise_shape_func)
register_shape_func("fast_tanh", False, elemwise_shape_func)
59 changes: 59 additions & 0 deletions tests/python/relay/test_op_fast_math.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import tvm
import tvm.relay as relay
import topi
from tvm import te
from tvm.contrib import graph_runtime


def test_fastmath():
def test_apply(relay_op, name, f_numpy, low, high, step, dtype="float32"):
a_np = np.arange(low, high, step).astype(dtype)
b_np = f_numpy(a_np)

x = relay.var("x", shape=a_np.shape, dtype="float32")
y = relay_op(x)
func = relay.Function([x], y)
mod = tvm.IRModule.from_expr(func)

with relay.build_config(opt_level=3, required_pass=['FastMath']):
graph, lib, params = relay.build(mod, target="llvm", params=None)

# Check that the op related to fast math have been convered to function in lib
func_name = "fused_" + name
assert lib.get_function(func_name)

ctx = tvm.cpu(0)
m = graph_runtime.create(graph, lib, ctx)
# Set inputs
m.set_input('x', tvm.nd.array(a_np, ctx))
m.set_input(**params)
# Execute
m.run()
# Get outputs
tvm_output = m.get_output(0)
tvm.testing.assert_allclose(tvm_output.asnumpy(), b_np,
rtol=1e-5, atol=1e-5)

test_apply(relay.exp, "fast_exp", np.exp, low=-88, high=88, step=0.01)
test_apply(relay.tanh, "fast_tanh", np.tanh, low=-10, high=10, step=0.01)


if __name__ == "__main__":
test_fastmath()
3 changes: 3 additions & 0 deletions topi/tests/python/test_topi_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,9 @@ def check_device(device):
test_apply(topi.fast_exp, "fast_exp", np.exp,
low=-88, high=88,
step = 0.01)
test_apply(topi.fast_tanh, "fast_tanh", np.tanh,
low=-10, high=10,
step = 0.01)

if __name__ == "__main__":
test_util()
Expand Down

0 comments on commit d52474b

Please sign in to comment.