From 1d7723bbef879570489708cedbadfe2899b62fc9 Mon Sep 17 00:00:00 2001 From: SWu Date: Fri, 13 Dec 2019 15:09:56 -0500 Subject: [PATCH] Fix bias_add gradient (#4516) * Fix bias_add gradient A change caused collapse_sum_like to reject implicit dimension broadcasting for bias_add gradient, so switch to explicit sum reduction on the non-bias axis dimensions. * Lint fix --- python/tvm/relay/op/_tensor_grad.py | 4 ++-- tests/python/relay/test_op_grad_level1.py | 15 +++++++++++---- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index d55cad7c7a2d..944e51e636f5 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -379,9 +379,9 @@ def log_softmax_grad(orig, grad): @register_gradient("nn.bias_add") def bias_add_grad(orig, grad): """Returns gradient of bias_add""" - data, bias = orig.args + data = orig.args[0] return [collapse_sum_like(grad, data), - collapse_sum_like(grad, bias)] + _sum(grad, orig.attrs.axis, keepdims=False, exclude=True)] @register_gradient("nn.dense") diff --git a/tests/python/relay/test_op_grad_level1.py b/tests/python/relay/test_op_grad_level1.py index 114bda0eccd5..3be62a3170fb 100644 --- a/tests/python/relay/test_op_grad_level1.py +++ b/tests/python/relay/test_op_grad_level1.py @@ -110,12 +110,19 @@ def test_log_softmax_grad(): check_grad(fwd_func, scale=1) -def test_bias_add_grad(): - data = relay.var("data", relay.TensorType((1, 16), "float32")) - bias = relay.var("bias", relay.TensorType((16,), "float32")) - fwd_func = relay.Function([data, bias], relay.nn.bias_add(data, bias)) +def verify_bias_add(d_shape, b_shape, axis=1): + data = relay.var("data", relay.TensorType(d_shape, "float32")) + bias = relay.var("bias", relay.TensorType(b_shape, "float32")) + fwd_func = relay.Function([data, bias], relay.nn.bias_add(data, bias, axis=axis)) check_grad(fwd_func) +def test_bias_add_grad(): + verify_bias_add((1, 16), (16,)) + verify_bias_add((1, 8, 2, 2), (8,)) + verify_bias_add((1, 2, 2, 8), (8,), 3) + verify_bias_add((4, 8), (8,)) + + if __name__ == "__main__": pytest.main([__file__])