Skip to content

Commit

Permalink
[RELAY/OP] Gradient of relay level1 ops (apache#2633)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZihengJiang authored and wweic committed Mar 12, 2019
1 parent 1fe20ac commit c388b9c
Show file tree
Hide file tree
Showing 7 changed files with 168 additions and 27 deletions.
3 changes: 3 additions & 0 deletions python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def astype(self, dtype):
"""
return _make.cast(self, dtype)

def __neg__(self):
return _op_make.negative(self)

def __add__(self, other):
if isinstance(other, Expr):
return _op_make.add(self, other)
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

# operator registry
from . import _tensor
from . import _tensor_grad
from . import _transform
from . import _reduce
from ..expr import Expr
Expand Down
18 changes: 0 additions & 18 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,7 @@
from __future__ import absolute_import
import topi
from .op import register_compute, register_schedule, register_pattern
from .op import register_gradient
from .op import schedule_injective, OpPattern
from .transform import collapse_sum_like
from .tensor import negative


def add_grad(orig, grad):
return [collapse_sum_like(grad, orig.args[0]), collapse_sum_like(grad, orig.args[1])]


register_gradient("add", add_grad)


def subtract_grad(orig, grad):
return [collapse_sum_like(grad, orig.args[0]),
collapse_sum_like(negative(grad), orig.args[1])]


register_gradient("subtract", subtract_grad)

schedule_broadcast = schedule_injective
schedule_elemwise = schedule_injective
Expand Down
79 changes: 79 additions & 0 deletions python/tvm/relay/op/_tensor_grad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
from __future__ import absolute_import
from ..expr import const
from .op import register_gradient
from .transform import collapse_sum_like, where
from .tensor import exp, negative, power, less
from .tensor import zeros_like, ones_like


@register_gradient("log")
def log_grad(orig, grad):
"""Returns [grad * (1 / x)]"""
x = orig.args[0]
return [grad * ones_like(x) / x]


@register_gradient("exp")
def exp_grad(orig, grad):
"""Returns [grad * exp(x)]"""
return [grad * exp(orig.args[0])]


@register_gradient("sqrt")
def sqrt_grad(orig, grad):
"""Returns [grad * 0.5 * (x ^ -0.5)]"""
a = const(0.5) # (TODO) type?
return [grad * a * power(orig.args[0], negative(a))]


@register_gradient("sigmoid")
def sigmoid_grad(orig, grad):
"""Returns [grad * sigmoid(x) * (1 - sigmoid(x))]."""
return [grad * orig * (ones_like(orig) - orig)]


@register_gradient("tanh")
def tanh_grad(orig, grad):
"""Returns grad * (1 - tanh(x) * tanh(x))."""
return [grad * ones_like(orig) - orig * orig]


@register_gradient("nn.relu")
def relu_grad(orig, grad):
"""Returns grad * (select(x < 0, 0, 1))."""
x = orig.args[0]
zeros = zeros_like(x)
ones = ones_like(x)
return [where(less(x, zeros), zeros, ones * grad)]


@register_gradient("add")
def add_grad(orig, grad):
"""Returns [grad, grad]"""
return [collapse_sum_like(grad, orig.args[0]),
collapse_sum_like(grad, orig.args[1])]


@register_gradient("subtract")
def subtract_grad(orig, grad):
"""Returns [grad, -grad]"""
return [collapse_sum_like(grad, orig.args[0]),
collapse_sum_like(negative(grad), orig.args[1])]


@register_gradient("multiply")
def multiply_grad(orig, grad):
"""Returns [grad * y, grad * x]"""
x, y = orig.args
return [collapse_sum_like(grad * y, x),
collapse_sum_like(grad * x, y)]


@register_gradient("divide")
def divide_grad(orig, grad):
"""Returns [grad / y, - grad * (x / y) / y]"""
x, y = orig.args
return [collapse_sum_like(grad / y, x),
collapse_sum_like(- (grad * orig / y), y)]
2 changes: 1 addition & 1 deletion python/tvm/relay/op/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def register_pattern(op_name, pattern, level=10):
"""
return register(op_name, "TOpPattern", pattern, level)

def register_gradient(op_name, fgradient, level=10):
def register_gradient(op_name, fgradient=None, level=10):
"""Register operator pattern for an op.
Parameters
Expand Down
76 changes: 76 additions & 0 deletions tests/python/relay/test_op_grad_level1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import tvm
import numpy as np
from tvm import relay
from tvm.relay.ir_pass import gradient, infer_type
from tvm.relay.testing import ctx_list

def sigmoid(x):
one = np.ones_like(x)
return one / (one + np.exp(-x))

def relu(x):
x_copy = np.copy(x)
np.maximum(x_copy, 0, x_copy)
return x_copy

def test_unary_op():
def check_single_op(opfunc, ref):
shape = (10, 4)
dtype = 'float32'
tp = relay.TensorType(shape, dtype)
x = relay.var("x", tp)
y = opfunc(x)

if ref is not None:
data = np.random.rand(*shape).astype(dtype)
ref_grad = ref(data)
fwd_func = relay.Function([x], y)
bwd_func = infer_type(gradient(fwd_func))

for target, ctx in ctx_list():
intrp = relay.create_executor(ctx=ctx, target=target)
op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)

for opfunc, ref in [(tvm.relay.log, lambda x: 1 / x),
(tvm.relay.exp, np.exp),
(tvm.relay.sigmoid, lambda x: sigmoid(x) * (1 - sigmoid(x))),
(tvm.relay.tanh, lambda x: 1 - np.tanh(x) * np.tanh(x)),
(tvm.relay.sqrt, lambda x: 0.5 * np.power(x, -0.5)),
(relay.nn.relu, lambda x: np.where(x < 0, np.zeros_like(x), np.ones_like(x)))]:
check_single_op(opfunc, ref)


def test_binary_op():
def inst(vars, sh):
return [vars.get(s, s) for s in sh]

def check_binary_op(opfunc, ref):
s = (5, 10, 5)
t = relay.TensorType((5, 10, 5))
x = relay.var("x", t)
y = relay.var("y", t)
z = opfunc(x, y)

x_data = np.random.rand(*s).astype(t.dtype)
y_data = np.random.rand(*s).astype(t.dtype)
ref_grad0, ref_grad1 = ref(x_data, y_data)
fwd_func = relay.Function([x, y], z)
bwd_func = infer_type(gradient(fwd_func))

for target, ctx in ctx_list():
intrp = relay.create_executor(ctx=ctx, target=target)
op_res, (op_grad0, op_grad1) = intrp.evaluate(bwd_func)(x_data, y_data)
np.testing.assert_allclose(op_grad0.asnumpy(), ref_grad0, rtol=0.01)
np.testing.assert_allclose(op_grad1.asnumpy(), ref_grad1, rtol=0.01)

for opfunc, ref in [(relay.add, lambda x, y: [np.ones_like(x), np.ones_like(y)]),
(relay.subtract, lambda x, y: [np.ones_like(x), -np.ones_like(y)]),
(relay.multiply, lambda x, y: [y, x]),
(relay.divide, lambda x, y: [1 / y, - x / (y**2)])]:
check_binary_op(opfunc, ref)


if __name__ == "__main__":
test_unary_op()
test_binary_op()
16 changes: 8 additions & 8 deletions tests/python/relay/test_op_level1.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ def check_single_op(opfunc, ref):


for opfunc, ref in [(tvm.relay.log, np.log),
(tvm.relay.exp, np.exp),
(tvm.relay.sqrt, np.sqrt),
(tvm.relay.sigmoid, sigmoid),
(tvm.relay.tanh, np.tanh),
(relay.nn.relu, relu)]:
(tvm.relay.exp, np.exp),
(tvm.relay.sqrt, np.sqrt),
(tvm.relay.sigmoid, sigmoid),
(tvm.relay.tanh, np.tanh),
(relay.nn.relu, relu)]:
check_single_op(opfunc, ref)


Expand Down Expand Up @@ -84,9 +84,9 @@ def check_binary_op(opfunc, ref):
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)

for opfunc, ref in [(relay.add, np.add),
(relay.subtract, np.subtract),
(relay.multiply, np.multiply),
(relay.divide, np.divide)]:
(relay.subtract, np.subtract),
(relay.multiply, np.multiply),
(relay.divide, np.divide)]:
check_binary_op(opfunc, ref)


Expand Down

0 comments on commit c388b9c

Please sign in to comment.