Fix bias_add gradient (#4516)
* Fix bias_add gradient

A change caused collapse_sum_like to reject implicit dimension
broadcasting for bias_add gradient, so switch to explicit sum reduction
on the non-bias axis dimensions.

* Lint fix
SWu authored and vinx13 committed Dec 13, 2019
1 parent 6e085b4 commit f10944c
Showing 2 changed files with 13 additions and 6 deletions.
python/tvm/relay/op/_tensor_grad.py (4 changes: 2 additions & 2 deletions)

@@ -379,9 +379,9 @@ def log_softmax_grad(orig, grad):
 @register_gradient("nn.bias_add")
 def bias_add_grad(orig, grad):
     """Returns gradient of bias_add"""
-    data, bias = orig.args
+    data = orig.args[0]
     return [collapse_sum_like(grad, data),
-            collapse_sum_like(grad, bias)]
+            _sum(grad, orig.attrs.axis, keepdims=False, exclude=True)]
 
 
 @register_gradient("nn.dense")
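For context: collapse_sum_like(grad, bias) reduced the incoming gradient down to the shape of bias by summing out the broadcast dimensions, and _sum(grad, orig.attrs.axis, keepdims=False, exclude=True) produces the same result by summing grad over every axis except the bias axis. A minimal NumPy sketch of that equivalence (the shape and axis below are illustrative, chosen to match one of the new test cases):

import numpy as np

# Illustrative NCHW-style gradient with the bias on axis 1.
grad = np.random.randn(1, 8, 2, 2).astype("float32")
bias_axis = 1

# Old behaviour: collapse the gradient to the bias shape (8,) by
# summing out every broadcast dimension.
collapsed = grad.sum(axis=(0, 2, 3))

# New behaviour: an explicit reduction over all axes except the bias
# axis (what exclude=True expresses in Relay's sum), keepdims=False.
reduce_axes = tuple(i for i in range(grad.ndim) if i != bias_axis)
explicit = grad.sum(axis=reduce_axes, keepdims=False)

assert np.allclose(collapsed, explicit)  # both yield the (8,) bias gradient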
tests/python/relay/test_op_grad_level1.py (15 changes: 11 additions & 4 deletions)

@@ -110,12 +110,19 @@ def test_log_softmax_grad():
     check_grad(fwd_func, scale=1)
 
 
-def test_bias_add_grad():
-    data = relay.var("data", relay.TensorType((1, 16), "float32"))
-    bias = relay.var("bias", relay.TensorType((16,), "float32"))
-    fwd_func = relay.Function([data, bias], relay.nn.bias_add(data, bias))
+def verify_bias_add(d_shape, b_shape, axis=1):
+    data = relay.var("data", relay.TensorType(d_shape, "float32"))
+    bias = relay.var("bias", relay.TensorType(b_shape, "float32"))
+    fwd_func = relay.Function([data, bias], relay.nn.bias_add(data, bias, axis=axis))
     check_grad(fwd_func)
 
 
+def test_bias_add_grad():
+    verify_bias_add((1, 16), (16,))
+    verify_bias_add((1, 8, 2, 2), (8,))
+    verify_bias_add((1, 2, 2, 8), (8,), 3)
+    verify_bias_add((4, 8), (8,))
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
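The new test cases cover a 2D input, NCHW- and NHWC-style 4D inputs (axis 1 and axis 3), and a batch larger than one. As a point of reference for what those cases compute, nn.bias_add broadcasts the bias along the given axis, so its analytic gradients are the identity for data and the reduction shown above for bias; a hedged NumPy sketch of the axis=3 case (values and names are illustrative, not taken from check_grad):

import numpy as np

# Forward semantics of nn.bias_add with axis=3 (NHWC-style layout):
# the bias is broadcast along the last dimension.
data = np.random.randn(1, 2, 2, 8).astype("float32")
bias = np.random.randn(8).astype("float32")
out = data + bias.reshape((1, 1, 1, 8))

# For an upstream gradient `grad` of the same shape as `out`, the data
# gradient is `grad` unchanged and the bias gradient is the sum over
# every non-bias axis (0, 1, 2 here), matching bias_add_grad above.
grad = np.ones_like(out)
data_grad = grad
bias_grad = grad.sum(axis=(0, 1, 2))
assert bias_grad.shape == bias.shape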
