Jax 0.4.27 parallelism errors #412
Comments
Possibly this is an error in equinox, but I wasn't able to exactly replicate it without diffrax; e.g. the following works fine:

```python
import os
import multiprocessing

# Force multiple host-CPU devices so the sharding/pmap paths are exercised.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    multiprocessing.cpu_count()
)
os.environ["EQX_ON_ERROR"] = "nan"

import jax
import jax.numpy as jnp
import jax.experimental.mesh_utils as mesh_utils
import equinox as eqx
import equinox.internal as eqxi
import functools as ft


def f(t, y, theta):
    return jnp.abs(jnp.sin(t)) + theta * y


_inner_loop = jax.named_call(eqxi.while_loop, name="inner-loop")
_outer_loop = jax.named_call(eqxi.while_loop, name="outer-loop")


def solve(init, key):
    def inner_loop_cond(state):
        t, y, _ = state
        return y.squeeze() < 10

    def inner_loop_body(state):
        t, y, theta = state
        dy = f(t, y, theta)
        return (t + 0.1, y + 0.1 * dy, theta)

    def outer_loop_cond(state):
        _, _, _, count = state
        return count < 5

    def outer_loop_body(state):
        t, y, theta, count = state
        y = jax.random.uniform(jax.random.PRNGKey(count), shape=(1,))
        new_t, new_y, _ = inner_while_loop(
            inner_loop_cond, inner_loop_body, (t, y, theta)
        )
        return (new_t, new_y, theta, count + 1)

    inner_while_loop = ft.partial(_inner_loop, kind="lax")
    outer_while_loop = ft.partial(_outer_loop, kind="lax")

    theta = 5.0
    t_initial = 0.0
    y_initial = init
    count_initial = jax.random.randint(key, minval=-2, maxval=2, shape=())
    final_state = outer_while_loop(
        outer_loop_cond, outer_loop_body, (t_initial, y_initial, theta, count_initial)
    )
    return final_state[1]


batch_size = 30
inits = 0.1 * jnp.ones((batch_size, 1))
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)

num_devices = len(jax.devices())
devices = mesh_utils.create_device_mesh((num_devices, 1))
sharding = jax.sharding.PositionalSharding(devices)
replicated = sharding.replicate()

inits_pmap = inits.reshape(num_devices, batch_size // num_devices, *inits.shape[1:])
keys_pmap = keys.reshape(num_devices, batch_size // num_devices, *keys.shape[1:])

x, y = eqx.filter_shard((inits, keys), sharding)

fn = eqx.filter_jit(eqx.filter_vmap(solve))
pmap_fn = eqx.filter_pmap(fn)

# Run both parallelism paths.
print("shard")
_ = fn(x, y).block_until_ready()
print("pmap")
_ = pmap_fn(inits_pmap, keys_pmap).block_until_ready()
```
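For orientation, here is a minimal plain-JAX sketch of the same two parallelism paths the snippet exercises (jit + vmap over a sharded batch, versus pmap over a reshaped batch). The toy function `g`, the forced 8-device CPU setup, and the batch size are illustrative assumptions, not part of the report:

```python
# Minimal sketch only -- the toy function, device count and batch size are
# illustrative assumptions, not taken from the report above.
import os

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
import jax.experimental.mesh_utils as mesh_utils


def g(x):
    return jnp.sin(x) * 2.0


num_devices = len(jax.devices())   # 8 host-CPU devices from the flag above
xs = jnp.arange(32.0)              # batch size divisible by num_devices

# Path 1: sharding -- place the batch across devices, then jit + vmap.
devices = mesh_utils.create_device_mesh((num_devices,))
sharding = jax.sharding.PositionalSharding(devices)
xs_sharded = jax.device_put(xs, sharding)
out_sharded = jax.jit(jax.vmap(g))(xs_sharded)

# Path 2: pmap -- reshape to (num_devices, per_device) and map over axis 0.
xs_pmap = xs.reshape(num_devices, -1)
out_pmap = jax.pmap(jax.vmap(g))(xs_pmap)
```

For purely-array inputs this is roughly what the `eqx.filter_jit`/`eqx.filter_vmap`/`eqx.filter_pmap` wrappers above reduce to; the report is that on jax 0.4.27 both paths hit the same error.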
Issue that the error in the stack trace references: jax-ml/jax#13554
All of the above code works 100% fine on jax 0.4.26, by the way (the sharding path is still slower, but that's a different issue).
Thanks for the report! Looks like an upstream JAX bug. I've opened jax-ml/jax#21116. |
Great, closing! jax 0.4.28 fixed this.
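For anyone stuck in between releases, a minimal sketch of a guard one might have added while the regression was live; the exact check and message are illustrative assumptions, not something from this thread:

```python
# Hypothetical guard (not from this thread): refuse to run on the known-bad release.
import jax

if jax.__version__ == "0.4.27":
    raise RuntimeError(
        "jax 0.4.27 has a parallelism regression affecting these code paths "
        "(see jax-ml/jax#21116); stay on 0.4.26 or upgrade to 0.4.28+."
    )
```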
Original issue description: The latest version of jax seems to break things when you parallelize; both pmap and sharding have the same error. Here is an MVC: I just took the code from #407.
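The actual MVC from #407 is not reproduced in this thread. Purely for orientation, a diffrax solve parallelized in the same two ways looks roughly like the sketch below; the toy vector field, solver choice, and batch size are illustrative assumptions, not the code from #407:

```python
# Orientation-only sketch, NOT the reproducer from #407. The vector field,
# solver and batch size are assumptions.
import os

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import diffrax
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.experimental.mesh_utils as mesh_utils


def solve(y0):
    # Toy ODE dy/dt = -y, integrated with Tsit5 from t=0 to t=1.
    term = diffrax.ODETerm(lambda t, y, args: -y)
    sol = diffrax.diffeqsolve(term, diffrax.Tsit5(), t0=0.0, t1=1.0, dt0=0.1, y0=y0)
    return sol.ys


num_devices = len(jax.devices())
y0s = jnp.linspace(0.1, 1.0, 32)

# Sharding path: jit + vmap over a batch laid out across devices.
devices = mesh_utils.create_device_mesh((num_devices,))
sharding = jax.sharding.PositionalSharding(devices)
fn = eqx.filter_jit(eqx.filter_vmap(solve))
_ = fn(eqx.filter_shard(y0s, sharding))

# pmap path: reshape to (num_devices, per_device) and pmap the same function.
pmap_fn = eqx.filter_pmap(fn)
_ = pmap_fn(y0s.reshape(num_devices, -1))
```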