Skip to content

Commit

Permalink
Fixed failing test for value_and_grad (#6041)
Browse files Browse the repository at this point in the history
  • Loading branch information
vedpatwardhan authored Oct 22, 2022
1 parent 33d9b35 commit 80f3cac
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions ivy/functional/backends/jax/gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ def value_and_grad(func):

def callback_fn(xs):
xs = ivy.nested_map(xs, lambda x: ivy.to_native(x), include_derived=True)
ret = jax.value_and_grad(grad_fn)(xs)
ret = _remove_zeros_and_nones(ret, ret)
return ret
value, grad = jax.value_and_grad(grad_fn)(xs)
grad = _remove_zeros_and_nones(grad, grad)
return ivy.to_ivy(value), ivy.to_ivy(grad)

return callback_fn

Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/backends/torch/gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def autograd_fn(x):
else ivy.to_native(ivy.zeros_like(ivy.to_ivy(x)))
)
grad = ivy.to_ivy(grad)
grad = _remove_zeros_and_nones(grads, grads)
grad = _remove_zeros_and_nones(grad, grad)
return grad

grads = ivy.nested_map(
Expand Down

0 comments on commit 80f3cac

Please sign in to comment.