
Jax 0.4.27 parallelism errors #412

Closed
lockwo opened this issue May 7, 2024 · 5 comments

@lockwo (Contributor) commented May 7, 2024

The latest version of JAX seems to break things when you parallelize; both pmap and sharding raise the same error. Here is a minimal reproducible example:

import os
import multiprocessing

# Expose one XLA "device" per CPU core so sharding/pmap can run on a single host.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    multiprocessing.cpu_count()
)
os.environ['EQX_ON_ERROR'] = 'nan'
import jax
import jax.numpy as jnp
import jax.experimental.mesh_utils as mesh_utils
import equinox as eqx
from diffrax import *

def f(t, y, args):
    return jnp.sin(t) + args["theta"] * y

t0 = 0.
t1 = 0.04
dt0 = 0.02
diffusion_shape = jax.ShapeDtypeStruct((1,), "float32")
solver, cont = Heun(), PIDController(1e-3, 1e-6)
ts = jnp.linspace(t0, t1, 100)

def solve(init, key, args):
    vf = ODETerm(f)
    terms = vf
    ts = jnp.linspace(t0, t1, 100)
    saving = SaveAt(ts=ts)
    sol = diffeqsolve(
        terms,
        solver,
        y0=init,
        t0=t0,
        t1=t1,
        dt0=dt0,
        args=args,
        saveat=saving,
        stepsize_controller=cont,
    )
    return sol.ys

batch_size = 30
inits = 0.1 * jnp.ones((batch_size, 1))
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)
args = {"theta": 0.1}

num_devices = len(jax.devices())
devices = mesh_utils.create_device_mesh((num_devices, 1))
sharding = jax.sharding.PositionalSharding(devices)
replicated = sharding.replicate()

inits_pmap = inits.reshape(num_devices, batch_size // num_devices, *inits.shape[1:])
keys_pmap = keys.reshape(num_devices, batch_size // num_devices, *keys.shape[1:])

args_shard = eqx.filter_shard(args, replicated)
x, y = eqx.filter_shard((inits, keys), sharding)
fn = eqx.filter_jit(eqx.filter_vmap(solve, in_axes=(0, 0, None)))
print('shard')
_ = fn(x, y, args_shard).block_until_ready()
print('pmap')
_ = eqx.filter_pmap(fn, in_axes=(0, 0, None))(inits_pmap, keys_pmap, args).block_until_ready()
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
[<ipython-input-14-d7a89459f8a5>](https://localhost:8080/#) in <cell line: 59>()
     57 fn = eqx.filter_jit(eqx.filter_vmap(solve, in_axes=(0, 0, None)))
     58 print('shard')
---> 59 _ = fn(x, y, args_shard).block_until_ready()
     60 print('pmap')
     61 _ = eqx.filter_pmap(fn, in_axes=(0, 0, None))(inits_pmap, keys_pmap, args).block_until_ready()

    [... skipping hidden 20 frame]

11 frames
[<ipython-input-14-d7a89459f8a5>](https://localhost:8080/#) in solve(init, key, args)
     27     ts = jnp.linspace(t0, t1, 100)
     28     saving = SaveAt(ts=ts)
---> 29     sol = diffeqsolve(
     30         terms,
     31         solver,

    [... skipping hidden 15 frame]

[/usr/local/lib/python3.10/dist-packages/diffrax/_integrate.py](https://localhost:8080/#) in diffeqsolve(terms, solver, t0, t1, dt0, y0, args, saveat, stepsize_controller, adjoint, discrete_terminating_event, max_steps, throw, solver_state, controller_state, made_jump)
    914     #
    915 
--> 916     final_state, aux_stats = adjoint.loop(
    917         args=args,
    918         terms=terms,

    [... skipping hidden 1 frame]

[/usr/local/lib/python3.10/dist-packages/diffrax/_adjoint.py](https://localhost:8080/#) in loop(***failed resolving arguments***)
    286             )
    287             msg = None
--> 288         final_state = self._loop(
    289             terms=terms,
    290             saveat=saveat,

[/usr/local/lib/python3.10/dist-packages/diffrax/_integrate.py](https://localhost:8080/#) in loop(solver, stepsize_controller, discrete_terminating_event, saveat, t0, t1, dt0, max_steps, terms, args, init_state, inner_while_loop, outer_while_loop)
    437     static_made_jump = init_state.made_jump
    438     static_result = init_state.result
--> 439     _, traced_jump, traced_result = eqx.filter_eval_shape(body_fun_aux, init_state)
    440     if traced_jump:
    441         static_made_jump = None

    [... skipping hidden 14 frame]

[/usr/local/lib/python3.10/dist-packages/diffrax/_integrate.py](https://localhost:8080/#) in body_fun_aux(state)
    238         #
    239 
--> 240         (y, y_error, dense_info, solver_state, solver_result) = solver.step(
    241             terms,
    242             state.tprev,

    [... skipping hidden 1 frame]

[/usr/local/lib/python3.10/dist-packages/diffrax/_solver/runge_kutta.py](https://localhost:8080/#) in step(***failed resolving arguments***)
   1147         #     "triangular computations" (every stage depends on all previous stages)
   1148         #     without spurious copies.
-> 1149         final_val = eqxi.while_loop(
   1150             cond_stage,
   1151             rk_stage,

[/usr/local/lib/python3.10/dist-packages/equinox/internal/_loop/loop.py](https://localhost:8080/#) in while_loop(***failed resolving arguments***)
    105     elif kind == "checkpointed":
    106         del kind, base
--> 107         return checkpointed_while_loop(
    108             cond_fun,
    109             body_fun,

[/usr/local/lib/python3.10/dist-packages/equinox/internal/_loop/checkpointed.py](https://localhost:8080/#) in checkpointed_while_loop(***failed resolving arguments***)
    247     body_fun_ = filter_closure_convert(body_fun_, init_val_)
    248     vjp_arg = (init_val_, body_fun_)
--> 249     final_val_ = _checkpointed_while_loop(
    250         vjp_arg, cond_fun_, checkpoints, buffers_, max_steps
    251     )

    [... skipping hidden 8 frame]

[/usr/local/lib/python3.10/dist-packages/equinox/internal/_loop/checkpointed.py](https://localhost:8080/#) in _checkpointed_while_loop(***failed resolving arguments***)
    268     _body_fun = lambda x: body_fun(x)  # hashable wrapper; JAX issue #13554
    269     while_loop = jax.named_call(lax.while_loop, name="checkpointed-no-vjp")
--> 270     return while_loop(cond_fun, _body_fun, init_val)
    271 
    272 

[/usr/lib/python3.10/contextlib.py](https://localhost:8080/#) in inner(*args, **kwds)
     77         def inner(*args, **kwds):
     78             with self._recreate_cm():
---> 79                 return func(*args, **kwds)
     80         return inner
     81 

    [... skipping hidden 9 frame]

[/usr/local/lib/python3.10/dist-packages/equinox/internal/_loop/checkpointed.py](https://localhost:8080/#) in <lambda>(x)
    266     del checkpoints, buffers, max_steps
    267     init_val, body_fun = vjp_arg
--> 268     _body_fun = lambda x: body_fun(x)  # hashable wrapper; JAX issue #13554
    269     while_loop = jax.named_call(lax.while_loop, name="checkpointed-no-vjp")
    270     return while_loop(cond_fun, _body_fun, init_val)

    [... skipping hidden 1 frame]

[/usr/local/lib/python3.10/dist-packages/jax/_src/core.py](https://localhost:8080/#) in eval_jaxpr(jaxpr, consts, propagate_source_info, *args)
    447   env: dict[Var, Any] = {}
    448   map(write, jaxpr.constvars, consts)
--> 449   map(write, jaxpr.invars, args)
    450   lu = last_used(jaxpr)
    451   for eqn in jaxpr.eqns:

ValueError: safe_map() argument 2 is shorter than argument 1

I just took the code from #407.

@lockwo (Contributor, Author) commented May 7, 2024

It's possible this is an error in Equinox, but I wasn't able to exactly replicate it without Diffrax; e.g. the following works fine:

import os
import multiprocessing

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    multiprocessing.cpu_count()
)
os.environ['EQX_ON_ERROR'] = 'nan'
import jax
import jax.numpy as jnp
import jax.experimental.mesh_utils as mesh_utils
import equinox as eqx
import equinox.internal as eqxi
import functools as ft 

def f(t, y, theta):
    return jnp.abs(jnp.sin(t)) + theta * y

_inner_loop = jax.named_call(eqxi.while_loop, name="inner-loop")
_outer_loop = jax.named_call(eqxi.while_loop, name="outer-loop")

def solve(init, key):
    def inner_loop_cond(state):
        t, y, _ = state
        return y.squeeze() < 10

    def inner_loop_body(state):
        t, y, theta = state
        dy = f(t, y, theta)
        return (t + 0.1, y + 0.1 * dy, theta)
    
    def outer_loop_cond(state):
        _, _, _, count = state
        return count < 5
    
    def outer_loop_body(state):
        t, y, theta, count = state
        y = jax.random.uniform(jax.random.PRNGKey(count), shape=(1,))
        new_t, new_y, _ = inner_while_loop(inner_loop_cond, inner_loop_body, (t, y, theta))
        return (new_t, new_y, theta, count + 1)

    inner_while_loop = ft.partial(_inner_loop, kind="lax")
    outer_while_loop = ft.partial(_outer_loop, kind="lax")
    theta = 5.0
    t_initial = 0.0
    y_initial = init
    count_initial = jax.random.randint(key, minval=-2, maxval=2, shape=())
    final_state = outer_while_loop(outer_loop_cond, outer_loop_body, (t_initial, y_initial, theta, count_initial))
    return final_state[1]


batch_size = 30
inits = 0.1 * jnp.ones((batch_size, 1))
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)

num_devices = len(jax.devices())
devices = mesh_utils.create_device_mesh((num_devices, 1))
sharding = jax.sharding.PositionalSharding(devices)
replicated = sharding.replicate()

inits_pmap = inits.reshape(num_devices, batch_size // num_devices, *inits.shape[1:])
keys_pmap = keys.reshape(num_devices, batch_size // num_devices, *keys.shape[1:])

x, y = eqx.filter_shard((inits, keys), sharding)
fn = eqx.filter_jit(eqx.filter_vmap(solve))
pmap_fn = eqx.filter_pmap(fn)
print('shard')
_ = fn(x, y).block_until_ready()
print('pmap')
_ = pmap_fn(inits_pmap, keys_pmap).block_until_ready()

@lockwo (Contributor, Author) commented May 7, 2024

The JAX issue that the error references in the stack trace: jax-ml/jax#13554

@lockwo (Contributor, Author) commented May 7, 2024

All of the above code works fine in JAX 0.4.26, by the way (the sharding version is still slower, but that's a different issue).
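
As a stopgap, a minimal sketch (assuming the regression is specific to 0.4.27) that fails fast rather than hitting the safe_map ValueError deep inside the solve:

import jax

# Stopgap sketch: the thread reports 0.4.26 as working and 0.4.27 as broken,
# so refuse to run the pmap/sharding path on the affected version.
if jax.__version__ == "0.4.27":
    raise RuntimeError(
        "jax 0.4.27 breaks pmap/sharding of diffeqsolve (diffrax #412); "
        "pin jax==0.4.26 until a fixed release is out."
    )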

lockwo changed the title from "Jax 0.4.27 errors" to "Jax 0.4.27 parallelism errors" on May 7, 2024
@patrick-kidger (Owner) commented
Thanks for the report! Looks like an upstream JAX bug. I've opened jax-ml/jax#21116.

@lockwo (Contributor, Author) commented May 10, 2024

Great, closing! Fixed in 0.4.28.
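
For completeness, a small version-guard sketch (packaging is an assumption here; any version comparison would do) that requires the fixed release before enabling the sharded/pmapped path:

import jax
from packaging.version import Version  # assumption: packaging is installed

# The regression appeared in 0.4.27 and the fix shipped in 0.4.28,
# so require the fixed release before running the parallel solve.
assert Version(jax.__version__) >= Version("0.4.28"), (
    f"jax {jax.__version__} installed; upgrade to >= 0.4.28 "
    "(see jax-ml/jax#21116) before using pmap/sharding with diffeqsolve."
)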

lockwo closed this as completed on May 10, 2024