Skip to content

Commit

Permalink
use tree_map
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Jun 22, 2024
1 parent 57465ca commit 19f8232
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions numpyro/contrib/control_flow/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def body_fn(wrapped_carry, x, prefix=None):
# return early if length = unroll_steps
if length == unroll_steps:
return wrapped_carry, (PytreeTrace({}), y0s)
wrapped_carry = device_put(wrapped_carry)
wrapped_carry = tree_map(device_put, wrapped_carry)
wrapped_carry, (pytree_trace, ys) = lax.scan(
body_fn, wrapped_carry, xs_, length - unroll_steps, reverse
)
Expand Down Expand Up @@ -324,7 +324,7 @@ def body_fn(wrapped_carry, x):

return (i + 1, rng_key, carry), (PytreeTrace(trace), y)

wrapped_carry = device_put((0, rng_key, init))
wrapped_carry = tree_map(device_put, (0, rng_key, init))
last_carry, (pytree_trace, ys) = lax.scan(
body_fn, wrapped_carry, xs, length=length, reverse=reverse
)
Expand Down

0 comments on commit 19f8232

Please sign in to comment.