diff --git a/numpyro/contrib/control_flow/scan.py b/numpyro/contrib/control_flow/scan.py index 4bd2143a0..6b657b494 100644 --- a/numpyro/contrib/control_flow/scan.py +++ b/numpyro/contrib/control_flow/scan.py @@ -224,7 +224,7 @@ def body_fn(wrapped_carry, x, prefix=None): # return early if length = unroll_steps if length == unroll_steps: return wrapped_carry, (PytreeTrace({}), y0s) - wrapped_carry = device_put(wrapped_carry) + wrapped_carry = tree_map(device_put, wrapped_carry) wrapped_carry, (pytree_trace, ys) = lax.scan( body_fn, wrapped_carry, xs_, length - unroll_steps, reverse ) @@ -324,7 +324,7 @@ def body_fn(wrapped_carry, x): return (i + 1, rng_key, carry), (PytreeTrace(trace), y) - wrapped_carry = device_put((0, rng_key, init)) + wrapped_carry = tree_map(device_put, (0, rng_key, init)) last_carry, (pytree_trace, ys) = lax.scan( body_fn, wrapped_carry, xs, length=length, reverse=reverse )