From e0ec87d80a8cc085af5c8d95bfee5899e38f9617 Mon Sep 17 00:00:00 2001 From: ziheng Date: Fri, 22 Feb 2019 10:18:56 -0800 Subject: [PATCH] [RELAY/OP] Gradient of relay level1 ops (#2633) --- python/tvm/relay/expr.py | 3 + python/tvm/relay/op/__init__.py | 1 + python/tvm/relay/op/_tensor.py | 18 ------ python/tvm/relay/op/_tensor_grad.py | 79 +++++++++++++++++++++++ python/tvm/relay/op/op.py | 2 +- tests/python/relay/test_op_grad_level1.py | 76 ++++++++++++++++++++++ tests/python/relay/test_op_level1.py | 16 ++--- 7 files changed, 168 insertions(+), 27 deletions(-) create mode 100644 python/tvm/relay/op/_tensor_grad.py create mode 100644 tests/python/relay/test_op_grad_level1.py diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 9257bad7dd58..bd28acc9e4b5 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -51,6 +51,9 @@ def astype(self, dtype): """ return _make.cast(self, dtype) + def __neg__(self): + return _op_make.negative(self) + def __add__(self, other): if isinstance(other, Expr): return _op_make.add(self, other) diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index 13f521dad660..84b0ceef8524 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -18,6 +18,7 @@ # operator registry from . import _tensor +from . import _tensor_grad from . import _transform from . import _reduce from ..expr import Expr diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index d9b5e2e89ce0..39e1f7afbfa2 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -3,25 +3,7 @@ from __future__ import absolute_import import topi from .op import register_compute, register_schedule, register_pattern -from .op import register_gradient from .op import schedule_injective, OpPattern -from .transform import collapse_sum_like -from .tensor import negative - - -def add_grad(orig, grad): - return [collapse_sum_like(grad, orig.args[0]), collapse_sum_like(grad, orig.args[1])] - - -register_gradient("add", add_grad) - - -def subtract_grad(orig, grad): - return [collapse_sum_like(grad, orig.args[0]), - collapse_sum_like(negative(grad), orig.args[1])] - - -register_gradient("subtract", subtract_grad) schedule_broadcast = schedule_injective schedule_elemwise = schedule_injective diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py new file mode 100644 index 000000000000..173e97a00496 --- /dev/null +++ b/python/tvm/relay/op/_tensor_grad.py @@ -0,0 +1,79 @@ +#pylint: disable=invalid-name, unused-argument +"""Backend compiler related feature registration""" +from __future__ import absolute_import +from ..expr import const +from .op import register_gradient +from .transform import collapse_sum_like, where +from .tensor import exp, negative, power, less +from .tensor import zeros_like, ones_like + + +@register_gradient("log") +def log_grad(orig, grad): + """Returns [grad * (1 / x)]""" + x = orig.args[0] + return [grad * ones_like(x) / x] + + +@register_gradient("exp") +def exp_grad(orig, grad): + """Returns [grad * exp(x)]""" + return [grad * exp(orig.args[0])] + + +@register_gradient("sqrt") +def sqrt_grad(orig, grad): + """Returns [grad * 0.5 * (x ^ -0.5)]""" + a = const(0.5) # (TODO) type? + return [grad * a * power(orig.args[0], negative(a))] + + +@register_gradient("sigmoid") +def sigmoid_grad(orig, grad): + """Returns [grad * sigmoid(x) * (1 - sigmoid(x))].""" + return [grad * orig * (ones_like(orig) - orig)] + + +@register_gradient("tanh") +def tanh_grad(orig, grad): + """Returns grad * (1 - tanh(x) * tanh(x)).""" + return [grad * ones_like(orig) - orig * orig] + + +@register_gradient("nn.relu") +def relu_grad(orig, grad): + """Returns grad * (select(x < 0, 0, 1)).""" + x = orig.args[0] + zeros = zeros_like(x) + ones = ones_like(x) + return [where(less(x, zeros), zeros, ones * grad)] + + +@register_gradient("add") +def add_grad(orig, grad): + """Returns [grad, grad]""" + return [collapse_sum_like(grad, orig.args[0]), + collapse_sum_like(grad, orig.args[1])] + + +@register_gradient("subtract") +def subtract_grad(orig, grad): + """Returns [grad, -grad]""" + return [collapse_sum_like(grad, orig.args[0]), + collapse_sum_like(negative(grad), orig.args[1])] + + +@register_gradient("multiply") +def multiply_grad(orig, grad): + """Returns [grad * y, grad * x]""" + x, y = orig.args + return [collapse_sum_like(grad * y, x), + collapse_sum_like(grad * x, y)] + + +@register_gradient("divide") +def divide_grad(orig, grad): + """Returns [grad / y, - grad * (x / y) / y]""" + x, y = orig.args + return [collapse_sum_like(grad / y, x), + collapse_sum_like(- (grad * orig / y), y)] diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index e751a4e5565e..37f1fc1ee2b5 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -168,7 +168,7 @@ def register_pattern(op_name, pattern, level=10): """ return register(op_name, "TOpPattern", pattern, level) -def register_gradient(op_name, fgradient, level=10): +def register_gradient(op_name, fgradient=None, level=10): """Register operator pattern for an op. Parameters diff --git a/tests/python/relay/test_op_grad_level1.py b/tests/python/relay/test_op_grad_level1.py new file mode 100644 index 000000000000..a9d91f757407 --- /dev/null +++ b/tests/python/relay/test_op_grad_level1.py @@ -0,0 +1,76 @@ +import tvm +import numpy as np +from tvm import relay +from tvm.relay.ir_pass import gradient, infer_type +from tvm.relay.testing import ctx_list + +def sigmoid(x): + one = np.ones_like(x) + return one / (one + np.exp(-x)) + +def relu(x): + x_copy = np.copy(x) + np.maximum(x_copy, 0, x_copy) + return x_copy + +def test_unary_op(): + def check_single_op(opfunc, ref): + shape = (10, 4) + dtype = 'float32' + tp = relay.TensorType(shape, dtype) + x = relay.var("x", tp) + y = opfunc(x) + + if ref is not None: + data = np.random.rand(*shape).astype(dtype) + ref_grad = ref(data) + fwd_func = relay.Function([x], y) + bwd_func = infer_type(gradient(fwd_func)) + + for target, ctx in ctx_list(): + intrp = relay.create_executor(ctx=ctx, target=target) + op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data) + np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01) + + for opfunc, ref in [(tvm.relay.log, lambda x: 1 / x), + (tvm.relay.exp, np.exp), + (tvm.relay.sigmoid, lambda x: sigmoid(x) * (1 - sigmoid(x))), + (tvm.relay.tanh, lambda x: 1 - np.tanh(x) * np.tanh(x)), + (tvm.relay.sqrt, lambda x: 0.5 * np.power(x, -0.5)), + (relay.nn.relu, lambda x: np.where(x < 0, np.zeros_like(x), np.ones_like(x)))]: + check_single_op(opfunc, ref) + + +def test_binary_op(): + def inst(vars, sh): + return [vars.get(s, s) for s in sh] + + def check_binary_op(opfunc, ref): + s = (5, 10, 5) + t = relay.TensorType((5, 10, 5)) + x = relay.var("x", t) + y = relay.var("y", t) + z = opfunc(x, y) + + x_data = np.random.rand(*s).astype(t.dtype) + y_data = np.random.rand(*s).astype(t.dtype) + ref_grad0, ref_grad1 = ref(x_data, y_data) + fwd_func = relay.Function([x, y], z) + bwd_func = infer_type(gradient(fwd_func)) + + for target, ctx in ctx_list(): + intrp = relay.create_executor(ctx=ctx, target=target) + op_res, (op_grad0, op_grad1) = intrp.evaluate(bwd_func)(x_data, y_data) + np.testing.assert_allclose(op_grad0.asnumpy(), ref_grad0, rtol=0.01) + np.testing.assert_allclose(op_grad1.asnumpy(), ref_grad1, rtol=0.01) + + for opfunc, ref in [(relay.add, lambda x, y: [np.ones_like(x), np.ones_like(y)]), + (relay.subtract, lambda x, y: [np.ones_like(x), -np.ones_like(y)]), + (relay.multiply, lambda x, y: [y, x]), + (relay.divide, lambda x, y: [1 / y, - x / (y**2)])]: + check_binary_op(opfunc, ref) + + +if __name__ == "__main__": + test_unary_op() + test_binary_op() diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index 6a1662b65170..d29b808be0d1 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -39,11 +39,11 @@ def check_single_op(opfunc, ref): for opfunc, ref in [(tvm.relay.log, np.log), - (tvm.relay.exp, np.exp), - (tvm.relay.sqrt, np.sqrt), - (tvm.relay.sigmoid, sigmoid), - (tvm.relay.tanh, np.tanh), - (relay.nn.relu, relu)]: + (tvm.relay.exp, np.exp), + (tvm.relay.sqrt, np.sqrt), + (tvm.relay.sigmoid, sigmoid), + (tvm.relay.tanh, np.tanh), + (relay.nn.relu, relu)]: check_single_op(opfunc, ref) @@ -84,9 +84,9 @@ def check_binary_op(opfunc, ref): np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01) for opfunc, ref in [(relay.add, np.add), - (relay.subtract, np.subtract), - (relay.multiply, np.multiply), - (relay.divide, np.divide)]: + (relay.subtract, np.subtract), + (relay.multiply, np.multiply), + (relay.divide, np.divide)]: check_binary_op(opfunc, ref)