Push JAX test through #924

Open
nsmithtt opened this issue Oct 16, 2024 · 0 comments · Fixed by #1351 or #1432 · May be fixed by #1716
@nsmithtt

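Push this simple matmul JAX test through. It shards a matmul across a 4x2 device mesh with `shard_map`, reduces the partial sums with `psum`, and prints the lowered module text.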

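import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from functools import partial


P = PartitionSpec

# 4x2 mesh with axes 'x' and 'y'; this requires 8 devices. To try it on a
# CPU-only host, set XLA_FLAGS=--xla_force_host_platform_device_count=8
# before importing jax.
mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('x', 'y'))

n_m = 8192
n_k = 784
n_n = 8192
batch_size = 1  # not used below

# Use independent keys for the activations and the weights.
key = jax.random.key(0)
key, k_act, k_w = jax.random.split(key, 3)
act = jax.random.normal(k_act, (n_m, n_k))
W = jax.random.normal(k_w, (n_k, n_n))

# Each device holds an (n_m/4, n_k/2) block of `act` and an (n_k/2, n_n) block
# of `W` and multiplies them over its slice of K. psum over 'y' adds those
# partial products, so the output is sharded over 'x' and replicated over 'y'.
@partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),
         out_specs=P('x', None))
def matmul_basic(a_block, b_block):
  c_partialsum = jnp.dot(a_block, b_block)
  c_block = jax.lax.psum(c_partialsum, 'y')
  return c_block

# Lower (without executing) and print the module text.
lowered_single = jax.jit(matmul_basic).lower(act, W)
print(lowered_single.as_text())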
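To sanity-check what `matmul_basic` should compute without needing 8 devices, here is a minimal NumPy-only sketch that emulates the per-device work on the 4x2 mesh: device (i, j) multiplies its `act` block by its `W` block, and the `psum` over `'y'` is modeled as a sum over j. The helpers `block_shapes` and `emulate_matmul_basic` are hypothetical names used only for illustration; this reproduces the math of `shard_map` + `psum`, not how JAX actually executes it.

```python
import numpy as np

def block_shapes(n_m, n_k, n_n, mesh_shape):
    """Per-device block shapes for in_specs=(P('x', 'y'), P('y', None))
    and out_specs=P('x', None) on a mesh of shape (x, y)."""
    x, y = mesh_shape
    assert n_m % x == 0 and n_k % y == 0, "sharded dims must divide evenly"
    a_block = (n_m // x, n_k // y)  # rows split over 'x', K split over 'y'
    b_block = (n_k // y, n_n)       # K split over 'y', columns unsplit
    c_block = (n_m // x, n_n)       # rows split over 'x', replicated over 'y'
    return a_block, b_block, c_block

def emulate_matmul_basic(a, b, mesh_shape):
    """Emulate matmul_basic: each (i, j) device computes a partial product,
    the psum over 'y' adds partials across j, and the 'x' blocks are
    concatenated back into the global result."""
    x, y = mesh_shape
    a_rows = np.split(a, x, axis=0)   # row blocks along 'x'
    b_parts = np.split(b, y, axis=0)  # device (i, j) holds b_parts[j]
    out_rows = []
    for i in range(x):
        a_parts = np.split(a_rows[i], y, axis=1)  # device (i, j) holds a_parts[j]
        partials = [a_parts[j] @ b_parts[j] for j in range(y)]
        out_rows.append(sum(partials))            # psum over 'y'
    return np.concatenate(out_rows, axis=0)       # out_specs=P('x', None)

rng = np.random.default_rng(0)
a = rng.standard_normal((16, 8))
b = rng.standard_normal((8, 12))
print(np.allclose(emulate_matmul_basic(a, b, (4, 2)), a @ b))  # True
print(block_shapes(8192, 784, 8192, (4, 2)))
# ((2048, 392), (392, 8192), (2048, 8192))
```

With the sizes in this issue, each device multiplies a (2048, 392) block by a (392, 8192) block, and after the reduction over `'y'` every device in a given `'x'` row holds the same (2048, 8192) output block.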