Push this simple matmul Jax test through.
```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from functools import partial

P = PartitionSpec

# 4x2 device mesh: requires 8 devices.
mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('x', 'y'))

n_m = 8192
n_k = 784
n_n = 8192
batch_size = 1  # currently unused

key = jax.random.key(0)
# Fixed: `layer_sizes` was never defined; split off one key per input instead.
key, *keys = jax.random.split(key, 3)
k1, k2 = keys
act = jax.random.normal(k1, (n_m, n_k))
W = jax.random.normal(k2, (n_k, n_n))

# act is sharded on M over 'x' and on K over 'y'; W is sharded on K over 'y'.
@partial(shard_map, mesh=mesh,
         in_specs=(P('x', 'y'), P('y', None)),
         out_specs=P('x', None))
def matmul_basic(a_block, b_block):
    # Local (n_m/4, n_k/2) @ (n_k/2, n_n) product: a partial sum over K.
    c_partialsum = jnp.dot(a_block, b_block)
    # Reduce the partial sums across the 'y' axis to get the full result rows.
    c_block = jax.lax.psum(c_partialsum, 'y')
    return c_block

lowered_single = jax.jit(matmul_basic).lower(act, W)
print(lowered_single.as_text())
```
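To check the sharded result numerically on a machine without 8 accelerators, one option is to simulate 8 CPU devices and compare against a plain `jnp.dot`. This is a minimal sketch under stated assumptions: the small sizes (`n_m=64`, `n_k=16`, `n_n=32`) are hypothetical stand-ins for the issue's shapes, the CPU device count is forced via `XLA_FLAGS`, and the import falls back between the newer `jax.shard_map` and the older `jax.experimental.shard_map.shard_map` depending on the installed JAX version.

```python
import os

# Assumption: simulate 8 CPU devices so the (4, 2) mesh can be built.
# These environment variables must be set before jax is imported.
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # public API in newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

# Same 4x2 ('x', 'y') mesh as in the issue, built from the first 8 devices.
mesh = Mesh(np.array(jax.devices()[:8]).reshape(4, 2), ("x", "y"))

# Hypothetical small sizes: n_m must divide by 4 ('x'), n_k by 2 ('y').
n_m, n_k, n_n = 64, 16, 32
k_act, k_w = jax.random.split(jax.random.key(0))
act = jax.random.normal(k_act, (n_m, n_k))
W = jax.random.normal(k_w, (n_k, n_n))

@partial(shard_map, mesh=mesh,
         in_specs=(P("x", "y"), P("y", None)), out_specs=P("x", None))
def matmul_basic(a_block, b_block):
    # Local (n_m/4, n_k/2) @ (n_k/2, n_n) product: a partial sum over K.
    c_partialsum = jnp.dot(a_block, b_block)
    # Summing across 'y' completes the contraction over the sharded K axis.
    return jax.lax.psum(c_partialsum, "y")

out = jax.jit(matmul_basic)(act, W)
ref = jnp.dot(act, W)  # unsharded single-device reference
print(out.shape, bool(jnp.allclose(out, ref, atol=1e-4, rtol=1e-4)))
```

A small tolerance is used because the sharded version sums the K dimension in a different order than the single `jnp.dot`, so float32 results can differ in the last bits.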
nsmithtt
wooseokTT