From 81e8cfc64e9e2e5ab8c72b67b6d3be4f79758db5 Mon Sep 17 00:00:00 2001
From: Thomas Viehmann
Date: Tue, 23 Jun 2020 23:01:46 +0200
Subject: [PATCH] add a few gradients (#5899)

---
 python/tvm/relay/op/_tensor_grad.py        | 57 ++++++++++++++++++++++
 tests/python/relay/test_op_grad_level1.py  |  1 +
 tests/python/relay/test_op_grad_level10.py |  6 +++
 tests/python/relay/test_op_grad_level3.py  |  7 +++
 tests/python/relay/test_op_grad_level4.py  | 15 +++---
 5 files changed, 79 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index 0deb87a60e34..00ea09771a1f 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -270,6 +270,14 @@ def abs_grad(orig, grad):
     return [where(less(x, zeros), -ones * grad, ones * grad)]
 
 
+@register_gradient("erf")
+def erf_grad(orig, grad):
+    # c_2_div_sqrt_pi = 2.0 / math.sqrt(math.pi)
+    inp, = orig.args
+    c_2_div_sqrt_pi = const(1.1283791670955126, dtype=inp.checked_type.dtype)
+    return [c_2_div_sqrt_pi * exp(- inp * inp) * grad]
+
+
 @register_gradient("clip")
 def clip_grad(orig, grad):
     """Returns grad * (select(x < min || max < x , 0, 1))."""
@@ -479,6 +487,19 @@ def dense_grad(orig, grad):
             collapse_sum_like(_nn.dense(transpose(grad), transpose(data),
                                         units=data.checked_type.shape[1]), weight)]
 
+
+@register_gradient("nn.batch_matmul")
+def batch_matmul_grad(orig, grad):
+    """gradient for nn.batch_matmul: in einsum LHS_bik,RHS_bjk->RES_bij
+       grads: GRAD_OUT_bij,RHS_bjk->GRAD_IN_LHS_bik
+              GRAD_OUT_bij,LHS_bik->GRAD_IN_RHS_bjk
+    """
+    lhs, rhs = orig.args
+    return [collapse_sum_like(_nn.batch_matmul(grad, transpose(rhs, [0, 2, 1])), lhs),
+            collapse_sum_like(_nn.batch_matmul(transpose(grad, [0, 2, 1]),
+                                               transpose(lhs, [0, 2, 1])), rhs)]
+
+
 @register_gradient("reshape")
 def reshape_grad(orig, grad):
     """Gradient of reshape"""
@@ -529,6 +550,42 @@ def sum_grad(orig, grad):
     return [broadcast_to_like(grad, data)]
 
 
+@register_gradient("mean")
+def mean_grad(orig, grad):
+    """Returns grad broadcasted to data dims"""
+    data, axis = orig.args[0], _get_reduce_axis(orig)
+    shape = data.checked_type.concrete_shape
+    if axis is None:
+        axis = list(range(len(data.checked_type.concrete_shape)))
+    if not orig.attrs.keepdims:
+        grad = _unreduce_expand(grad, axis)
+    mult = 1.0
+    for a in axis:
+        mult /= shape[a]
+    return [broadcast_to_like(grad * const(mult, dtype=data.checked_type.dtype), data)]
+
+
+@register_gradient("variance")
+def variance_grad(orig, grad):
+    """Note that we take mean as an argument in the variance node"""
+    data, data_mean, axis = orig.args[0], orig.args[1], _get_reduce_axis(orig)
+    shape = data.checked_type.concrete_shape
+    if axis is None:
+        axis = list(range(len(data.checked_type.concrete_shape)))
+    if not orig.attrs.keepdims:
+        grad = _unreduce_expand(grad, axis)
+    mult = 2.0
+    for a in axis:
+        mult /= shape[a]
+    return [(grad * const(mult, dtype=data.checked_type.dtype)) * data,
+            const(-2, dtype=data.checked_type.dtype) * grad * data_mean]
+
+
+@register_gradient("copy")
+def copy_grad(orig, grad):
+    return [grad]
+
+
 @register_gradient("nn.cross_entropy")
 def cross_entropy_grad(orig, grad):
     x, y = orig.args
diff --git a/tests/python/relay/test_op_grad_level1.py b/tests/python/relay/test_op_grad_level1.py
index 9faf6d903a9c..85506e0f513f 100644
--- a/tests/python/relay/test_op_grad_level1.py
+++ b/tests/python/relay/test_op_grad_level1.py
@@ -62,6 +62,7 @@ def check_single_op(opfunc, ref):
                         (tvm.relay.sqrt, lambda x: 0.5 * np.power(x, -0.5)),
                         (tvm.relay.abs, lambda x: np.where(x < 0, -np.ones_like(x), np.ones_like(x))),
                         (relay.nn.relu, lambda x: np.where(x < 0, np.zeros_like(x), np.ones_like(x))),
+                        (tvm.relay.erf, lambda x: 2.0 / (np.pi**(0.5)) * np.exp(-x * x)),
                         (tvm.relay.cos, lambda x: -1.0 * np.sin(x)),
                         (tvm.relay.sin, lambda x: np.cos(x)),
                         (tvm.relay.tan, lambda x: 1.0 / (np.cos(x) ** 2)),
diff --git a/tests/python/relay/test_op_grad_level10.py b/tests/python/relay/test_op_grad_level10.py
index acf3b75e0cb5..6e6499998047 100644
--- a/tests/python/relay/test_op_grad_level10.py
+++ b/tests/python/relay/test_op_grad_level10.py
@@ -44,5 +44,11 @@ def test_checkpoint():
     check_grad(relay.Function(inputs, out_single))
 
 
+def test_batch_matmul_grad():
+    x = relay.var("x", shape=(2, 3, 5), dtype="float64")
+    y = relay.var("y", shape=(2, 4, 5), dtype="float64")
+    check_grad(relay.Function([x, y], relay.op.nn.batch_matmul(x, y)))
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/relay/test_op_grad_level3.py b/tests/python/relay/test_op_grad_level3.py
index d13687fbec72..b1d0e2540542 100644
--- a/tests/python/relay/test_op_grad_level3.py
+++ b/tests/python/relay/test_op_grad_level3.py
@@ -64,5 +64,12 @@ def test_cast_grad():
     fwd_func = relay.Function([data], relay.cast(data, "float64"))
     check_grad(fwd_func)
 
+
+def test_copy_grad():
+    data = relay.var("data", relay.TensorType((10, 4), "float64"))
+    fwd_func = relay.Function([data], relay.copy(data))
+    check_grad(fwd_func)
+
+
 if __name__ == "__main__":
     pytest.main()
diff --git a/tests/python/relay/test_op_grad_level4.py b/tests/python/relay/test_op_grad_level4.py
index f690a186ea41..956c6af8d5cb 100644
--- a/tests/python/relay/test_op_grad_level4.py
+++ b/tests/python/relay/test_op_grad_level4.py
@@ -19,17 +19,18 @@ from tvm.relay.testing import check_grad
 from tvm.relay.testing import check_grad
 
 
-def verify_sum_grad(d_shape, axis=None, keepdims=False, exclude=False):
+def verify_reduction_grad(red_fn, d_shape, axis=None, keepdims=False, exclude=False):
     data = relay.var("data", relay.TensorType(d_shape, "float32"))
-    fwd_func = relay.Function([data], relay.sum(data, axis=axis, keepdims=keepdims, exclude=exclude))
+    fwd_func = relay.Function([data], red_fn(data, axis=axis, keepdims=keepdims, exclude=exclude))
     check_grad(fwd_func)
 
 
-def test_sum_grad():
-    verify_sum_grad((4, 2))
-    verify_sum_grad((4, 2), axis=-1, keepdims=True)
-    verify_sum_grad((4, 2, 1), axis=(1, 2), exclude=True)
-    verify_sum_grad((4, 2, 1), axis=1)
+def test_reduction_grad():
+    for op in (relay.sum, relay.variance, relay.mean):
+        verify_reduction_grad(op, (4, 2))
+        verify_reduction_grad(op, (4, 2), axis=-1, keepdims=True)
+        verify_reduction_grad(op, (4, 2, 1), axis=(1, 2), exclude=True)
+        verify_reduction_grad(op, (4, 2, 1), axis=1)
 
 
 def verify_max_grad(d_shape, axis=None, keepdims=False, exclude=False):
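
Reviewer note, not part of the patch: erf_grad hard-codes 1.1283791670955126 for 2 / sqrt(pi). The following standalone sketch (my own shapes, tolerances, and use of scipy.special.erf as a reference; none of this is a dependency of the patch) checks the constant and the formula d/dx erf(x) = 2/sqrt(pi) * exp(-x^2) against central finite differences, mirroring the reference lambda added to test_op_grad_level1.py:

import math

import numpy as np
from scipy.special import erf  # reference erf only; an assumption, not required by the patch

# The constant hard-coded in erf_grad is 2 / sqrt(pi) to double precision.
assert abs(1.1283791670955126 - 2.0 / math.sqrt(math.pi)) < 1e-15

# Check d/dx erf(x) = 2/sqrt(pi) * exp(-x^2) against central finite differences.
x = np.linspace(-3.0, 3.0, 101)
eps = 1e-6
numeric = (erf(x + eps) - erf(x - eps)) / (2 * eps)
analytic = 1.1283791670955126 * np.exp(-x * x)
np.testing.assert_allclose(analytic, numeric, rtol=1e-6, atol=1e-8)
print("erf gradient formula matches finite differences")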
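The einsum identities in the batch_matmul_grad docstring (forward RES_bij = LHS_bik * RHS_bjk; backward GRAD_OUT_bij,RHS_bjk -> GRAD_IN_LHS_bik and GRAD_OUT_bij,LHS_bik -> GRAD_IN_RHS_bjk) can also be sanity-checked outside Relay. This NumPy-only sketch uses arbitrary illustrative shapes matching the new test, (2, 3, 5) x (2, 4, 5):

import numpy as np

# Forward nn.batch_matmul:  out[b, i, j] = sum_k lhs[b, i, k] * rhs[b, j, k]
# Backward (as in batch_matmul_grad's docstring):
#   d_lhs[b, i, k] = sum_j grad[b, i, j] * rhs[b, j, k]
#   d_rhs[b, j, k] = sum_i grad[b, i, j] * lhs[b, i, k]
rng = np.random.default_rng(0)
lhs = rng.standard_normal((2, 3, 5))
rhs = rng.standard_normal((2, 4, 5))
grad = rng.standard_normal((2, 3, 4))  # same shape as the forward output

d_lhs = np.einsum("bij,bjk->bik", grad, rhs)
d_rhs = np.einsum("bij,bik->bjk", grad, lhs)


def objective(a, b):
    """Scalar whose gradient w.r.t. a is exactly d_lhs (and w.r.t. b is d_rhs)."""
    return np.sum(grad * np.einsum("bik,bjk->bij", a, b))


eps = 1e-6
num_d_lhs = np.zeros_like(lhs)
for idx in np.ndindex(lhs.shape):
    plus, minus = lhs.copy(), lhs.copy()
    plus[idx] += eps
    minus[idx] -= eps
    num_d_lhs[idx] = (objective(plus, rhs) - objective(minus, rhs)) / (2 * eps)

np.testing.assert_allclose(d_lhs, num_d_lhs, rtol=1e-5, atol=1e-7)
print("batch_matmul LHS gradient matches finite differences")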
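On the variance_grad docstring's note that the mean is passed as an argument to the variance node: the two partials it returns, grad * (2/N) * data w.r.t. data and -2 * grad * mean w.r.t. the mean argument, only recover the familiar 2 * (x - mean) / N rule once the mean partial is further chained through mean_grad's 1/N broadcast. A small NumPy sketch of that composition for a full reduction (axis=None), with arbitrary example values of my own choosing:

import numpy as np

# variance node: var = mean((data - data_mean)**2), with data_mean passed in
# as a second argument. variance_grad returns partials w.r.t. data and w.r.t.
# that mean argument; the latter still has to flow back through mean_grad.
rng = np.random.default_rng(1)
x = rng.standard_normal((4, 2))
n = x.size
g = 0.7  # upstream gradient for the scalar variance output (axis=None)

x_mean = x.mean()
d_data = g * (2.0 / n) * x   # partial w.r.t. data, as returned by variance_grad
d_mean = -2.0 * g * x_mean   # partial w.r.t. the mean argument (a scalar here)

# mean_grad spreads its incoming gradient uniformly as 1/n per input element,
# so the full gradient w.r.t. x is the sum of both paths:
d_total = d_data + d_mean / n

# ...which equals the familiar closed form d var / d x = 2 * (x - mean) / n.
np.testing.assert_allclose(d_total, g * 2.0 * (x - x_mean) / n, atol=1e-12)
print("variance_grad + mean_grad compose to 2 * (x - mean) / n")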