diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index 0cd2efb703cf..d3d707b820a5 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -25,7 +25,7 @@
 from . import nn as _nn
 from .op import register_gradient
 from .reduce import sum as _sum
-from .tensor import cos, exp, less, negative, ones_like, power, sin, zeros_like
+from .tensor import cos, equal, exp, less, negative, ones_like, power, sin, zeros_like
 from .transform import (
     broadcast_to_like,
     collapse_sum_like,
@@ -269,6 +269,18 @@ def conv2d_grad(orig, grad):
     return [backward_data, backward_weight]
 
 
+@register_gradient("max")
+def max_grad(orig, grad):
+    """Returns the gradient of max"""
+    # Only support axis=0, since broadcasting orig to x behaves incorrectly
+    x, axis = orig.args[0], orig.attrs.axis
+    assert axis is not None and len(axis) == 1 and int(axis[0]) == 0
+    orig = broadcast_to_like(orig, x)
+    grad = broadcast_to_like(grad, x)
+    indicators = cast_like(equal(orig, x), grad)
+    return [indicators * grad]
+
+
 @register_gradient("nn.softmax")
 def softmax_grad(orig, grad):
     """Gradient of softmax"""
diff --git a/tests/python/relay/test_op_grad_level4.py b/tests/python/relay/test_op_grad_level4.py
index 5db1d9391a9a..3c799b88aadc 100644
--- a/tests/python/relay/test_op_grad_level4.py
+++ b/tests/python/relay/test_op_grad_level4.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
 
+import pytest
 from tvm import relay
 from tvm.relay.testing import check_grad
@@ -30,6 +31,16 @@ def test_sum_grad():
     verify_sum_grad((4, 2, 1), axis=(1, 2), exclude=True)
 
 
+def test_max_grad():
+    s = (5, 10)
+    t = relay.TensorType(s)
+    x = relay.var("x", t)
+    axis = 0
+    z = relay.max(x, axis)
+
+    fwd_func = relay.Function([x], z)
+    check_grad(fwd_func, eps=1e-7, rtol=1)
+
 if __name__ == "__main__":
-    test_sum_grad()
+    pytest.main()
 
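
For reviewers who want to sanity-check the indicator trick outside of Relay, here is a minimal NumPy sketch of the same rule (`max_grad_reference` and the shapes are illustrative, not part of the patch): the gradient of `max` routes the upstream gradient to exactly those elements that equal the reduced maximum, which is what `cast_like(equal(orig, x), grad)` computes above.

```python
import numpy as np

def max_grad_reference(x, upstream_grad, axis=0):
    # Forward maximum, kept with a singleton axis so it broadcasts back
    # to x's shape (the role played by broadcast_to_like in the patch).
    max_val = np.max(x, axis=axis, keepdims=True)
    # 1.0 where an element equals the maximum, 0.0 elsewhere; the NumPy
    # analogue of cast_like(equal(orig, x), grad).
    indicators = (x == max_val).astype(upstream_grad.dtype)
    # Broadcast the upstream gradient over the reduced axis and mask it.
    return indicators * np.expand_dims(upstream_grad, axis)

x = np.random.randn(5, 10).astype("float32")
g = np.ones(10, dtype="float32")  # upstream gradient, shape (10,)
dx = max_grad_reference(x, g, axis=0)
# With a unique maximum per column (almost surely true for random
# floats), each column's gradient lands on its arg-max element.
assert np.allclose(dx.sum(axis=0), g)
```

Note that ties are resolved by giving every tied maximum the full upstream gradient rather than splitting it; together with the non-smoothness of `max`, that is presumably why the new test compares against numerical differences with the loose `rtol=1`.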