add a few gradients (#5899)
t-vi authored Jun 23, 2020
1 parent aa84ee2 commit 81e8cfc
Showing 5 changed files with 79 additions and 7 deletions.
57 changes: 57 additions & 0 deletions python/tvm/relay/op/_tensor_grad.py
@@ -270,6 +270,14 @@ def abs_grad(orig, grad):
    return [where(less(x, zeros), -ones * grad, ones * grad)]


@register_gradient("erf")
def erf_grad(orig, grad):
# c_2_div_sqrt_pi = 2.0 / math.sqrt(math.pi)
inp, = orig.args
c_2_div_sqrt_pi = const(1.1283791670955126, dtype=inp.checked_type.dtype)
return [c_2_div_sqrt_pi * exp(- inp * inp) * grad]
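
The constant 1.1283791670955126 is 2 / sqrt(pi), from d/dx erf(x) = 2 / sqrt(pi) * exp(-x^2). As a sanity check outside Relay, a minimal numpy/scipy sketch (illustrative only; assumes scipy is available) compares this analytic derivative against a central finite difference:

import numpy as np
from scipy.special import erf

x = np.linspace(-3.0, 3.0, 101)
eps = 1e-6
numeric = (erf(x + eps) - erf(x - eps)) / (2 * eps)   # central difference of erf
analytic = 2.0 / np.sqrt(np.pi) * np.exp(-x * x)      # the c_2_div_sqrt_pi * exp(-x * x) form
assert np.allclose(numeric, analytic, atol=1e-5)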


@register_gradient("clip")
def clip_grad(orig, grad):
"""Returns grad * (select(x < min || max < x , 0, 1))."""
@@ -479,6 +487,19 @@ def dense_grad(orig, grad):
            collapse_sum_like(_nn.dense(transpose(grad), transpose(data),
                                        units=data.checked_type.shape[1]), weight)]


@register_gradient("nn.batch_matmul")
def batch_matmul_grad(orig, grad):
"""gradient for nn.batch_matmul: in einsum LHS_bik,RHS_bjk->RES_bij
grads: GRAD_OUT_bij,RHS_bjk->GRAD_IN_LHS_bik
GRAD_OUT_bij,LHS_bik->GRAD_IN_RHS_bjk
"""
lhs, rhs = orig.args
return [collapse_sum_like(_nn.batch_matmul(grad, transpose(rhs, [0, 2, 1])), lhs),
collapse_sum_like(_nn.batch_matmul(transpose(grad, [0, 2, 1]),
transpose(lhs, [0, 2, 1])), rhs)]
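
Because nn.batch_matmul treats its second operand as transposed (out[b, i, j] = sum_k lhs[b, i, k] * rhs[b, j, k]), both returned gradients are themselves batch matmuls over explicitly transposed operands. A minimal numpy sketch (illustrative only; shapes chosen arbitrarily) of the einsum rules quoted in the docstring:

import numpy as np

b, i, j, k = 2, 3, 4, 5
lhs = np.random.randn(b, i, k)
rhs = np.random.randn(b, j, k)
grad = np.random.randn(b, i, j)              # upstream gradient, shaped like the output

out = np.einsum("bik,bjk->bij", lhs, rhs)    # forward rule from the docstring

# gradient rules from the docstring
grad_lhs = np.einsum("bij,bjk->bik", grad, rhs)
grad_rhs = np.einsum("bij,bik->bjk", grad, lhs)

# plain batched matmul reproduces them; the Relay code transposes rhs/lhs first
# because nn.batch_matmul transposes its second operand internally
assert np.allclose(grad_lhs, grad @ rhs)
assert np.allclose(grad_rhs, grad.transpose(0, 2, 1) @ lhs)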


@register_gradient("reshape")
def reshape_grad(orig, grad):
"""Gradient of reshape"""
@@ -529,6 +550,42 @@ def sum_grad(orig, grad):
    return [broadcast_to_like(grad, data)]


@register_gradient("mean")
def mean_grad(orig, grad):
"""Returns grad broadcasted to data dims"""
data, axis = orig.args[0], _get_reduce_axis(orig)
shape = data.checked_type.concrete_shape
if axis is None:
axis = list(range(len(data.checked_type.concrete_shape)))
if not orig.attrs.keepdims:
grad = _unreduce_expand(grad, axis)
mult = 1.0
for a in axis:
mult /= shape[a]
return [broadcast_to_like(grad * const(mult, dtype=data.checked_type.dtype), data)]
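
Each input element thus receives the upstream gradient scaled by one over the number of reduced elements. A minimal numpy sketch (illustrative only; shape, axes, and the probed index are arbitrary) checking that scaling against a finite difference:

import numpy as np

data = np.random.randn(4, 2, 3)
axis = (1, 2)
grad_out = np.random.randn(4)                       # upstream gradient, keepdims=False

def loss(d):
    return np.sum(grad_out * np.mean(d, axis=axis))

# analytic gradient: broadcast grad_out and scale by 1/(#reduced elements),
# matching `mult` in mean_grad above
mult = 1.0 / (data.shape[1] * data.shape[2])
analytic = np.broadcast_to(grad_out[:, None, None] * mult, data.shape)

eps = 1e-6
bumped = data.copy()
bumped[1, 0, 2] += eps
numeric = (loss(bumped) - loss(data)) / eps          # finite difference at one element
assert np.allclose(numeric, analytic[1, 0, 2], atol=1e-4)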


@register_gradient("variance")
def variance_grad(orig, grad):
"""Note that we take mean as an argument in the variance node"""
data, data_mean, axis = orig.args[0], orig.args[1], _get_reduce_axis(orig)
shape = data.checked_type.concrete_shape
if axis is None:
axis = list(range(len(data.checked_type.concrete_shape)))
if not orig.attrs.keepdims:
grad = _unreduce_expand(grad, axis)
mult = 2.0
for a in axis:
mult /= shape[a]
return [(grad * const(mult, dtype=data.checked_type.dtype)) * data,
const(-2, dtype=data.checked_type.dtype) * grad * data_mean]
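
The split into a data term and a mean term corresponds to writing var = mean(data^2) - mean^2, so that d var / d data = 2 * data / N and d var / d mean = -2 * mean; chained back through mean = mean(data), this recovers the familiar 2 * (data - mean) / N. A minimal numpy sketch (illustrative only) of that identity:

import numpy as np

x = np.random.randn(6)
mu = x.mean()
n = x.size

# the two partials returned by variance_grad, with the upstream grad taken as 1
d_var_d_x = 2.0 * x / n         # the (grad * mult) * data term
d_var_d_mu = -2.0 * mu          # the -2 * grad * data_mean term

# chain the mean term back through mu = mean(x), where d mu / d x_i = 1 / n
total = d_var_d_x + d_var_d_mu / n
assert np.allclose(total, 2.0 * (x - mu) / n)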


@register_gradient("copy")
def copy_grad(orig, grad):
return [grad]


@register_gradient("nn.cross_entropy")
def cross_entropy_grad(orig, grad):
    x, y = orig.args
1 change: 1 addition & 0 deletions tests/python/relay/test_op_grad_level1.py
@@ -62,6 +62,7 @@ def check_single_op(opfunc, ref):
(tvm.relay.sqrt, lambda x: 0.5 * np.power(x, -0.5)),
(tvm.relay.abs, lambda x: np.where(x < 0, -np.ones_like(x), np.ones_like(x))),
(relay.nn.relu, lambda x: np.where(x < 0, np.zeros_like(x), np.ones_like(x))),
(tvm.relay.erf, lambda x: 2.0 / (np.pi**(0.5)) * np.exp(-x * x)),
(tvm.relay.cos, lambda x: -1.0 * np.sin(x)),
(tvm.relay.sin, lambda x: np.cos(x)),
(tvm.relay.tan, lambda x: 1.0 / (np.cos(x) ** 2)),
6 changes: 6 additions & 0 deletions tests/python/relay/test_op_grad_level10.py
@@ -44,5 +44,11 @@ def test_checkpoint():
    check_grad(relay.Function(inputs, out_single))


def test_batch_matmul_grad():
    x = relay.var("x", shape=(2, 3, 5), dtype="float64")
    y = relay.var("y", shape=(2, 4, 5), dtype="float64")
    check_grad(relay.Function([x, y], relay.op.nn.batch_matmul(x, y)))
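
Because nn.batch_matmul transposes its second operand, the (2, 3, 5) and (2, 4, 5) inputs above yield a (2, 3, 4) output, and check_grad compares the registered gradient against a numerical one. A one-line numpy shape check (illustrative only):

import numpy as np
assert np.einsum("bik,bjk->bij", np.zeros((2, 3, 5)), np.zeros((2, 4, 5))).shape == (2, 3, 4)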


if __name__ == "__main__":
    pytest.main([__file__])
7 changes: 7 additions & 0 deletions tests/python/relay/test_op_grad_level3.py
@@ -64,5 +64,12 @@ def test_cast_grad():
    fwd_func = relay.Function([data], relay.cast(data, "float64"))
    check_grad(fwd_func)


def test_copy_grad():
    data = relay.var("data", relay.TensorType((10, 4), "float64"))
    fwd_func = relay.Function([data], relay.copy(data))
    check_grad(fwd_func)


if __name__ == "__main__":
    pytest.main()
15 changes: 8 additions & 7 deletions tests/python/relay/test_op_grad_level4.py
@@ -19,17 +19,18 @@
from tvm.relay.testing import check_grad


def verify_sum_grad(d_shape, axis=None, keepdims=False, exclude=False):
def verify_reduction_grad(red_fn, d_shape, axis=None, keepdims=False, exclude=False):
data = relay.var("data", relay.TensorType(d_shape, "float32"))
fwd_func = relay.Function([data], relay.sum(data, axis=axis, keepdims=keepdims, exclude=exclude))
fwd_func = relay.Function([data], red_fn(data, axis=axis, keepdims=keepdims, exclude=exclude))
check_grad(fwd_func)


def test_sum_grad():
    verify_sum_grad((4, 2))
    verify_sum_grad((4, 2), axis=-1, keepdims=True)
    verify_sum_grad((4, 2, 1), axis=(1, 2), exclude=True)
    verify_sum_grad((4, 2, 1), axis=1)
def test_reduction_grad():
    for op in (relay.sum, relay.variance, relay.mean):
        verify_reduction_grad(op, (4, 2))
        verify_reduction_grad(op, (4, 2), axis=-1, keepdims=True)
        verify_reduction_grad(op, (4, 2, 1), axis=(1, 2), exclude=True)
        verify_reduction_grad(op, (4, 2, 1), axis=1)


def verify_max_grad(d_shape, axis=None, keepdims=False, exclude=False):
