Compiler crash when sharding model weights #1030
Good catch, we’ll get back to you soon
From: Corentin Godeau ***@***.***>
Date: Saturday, November 9, 2024 at 8:22 AM
Subject: [aws-neuron/aws-neuron-sdk] Compiler crash when sharding model weights (Issue #1030)
Hi!
I was playing with JAX on Neuron recently and came across a bug that is quite annoying.
When sharding a very simple MLP layer, compilation fails depending on which axis the weights are sharded along.
Here is a snippet of code that demonstrates the issue:
```python
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

# 1-D mesh over all visible devices, with a single axis named 'x'.
mesh = jax.sharding.Mesh(jax.devices(), ('x',))

BATCH_SIZE = 1
SIZE = 4096
HIDDEN_SIZE = 8192
DTYPE = jnp.bfloat16
SHARD_ON_HIDDEN = True

def mlp(x: jax.Array, gate_up_proj: jax.Array, down_proj: jax.Array) -> jax.Array:
    # x: (B, D)
    # gate_up_proj: (D, 2 * H)
    # down_proj: (D, H)
    hidden = jax.lax.dot_general(x, gate_up_proj, (([1], [0]), ([], [])))
    # hidden: (B, 2 * H)
    x1, x2 = jnp.split(hidden, 2, 1)
    hidden = jax.nn.gelu(x1) * x2
    # hidden: (B, H)
    hidden = jax.lax.dot_general(hidden, down_proj, (([1], [1]), ([], [])))
    # hidden: (B, D)
    return hidden

if SHARD_ON_HIDDEN:
    # Shard dim 1 of both weights: 2 * H for gate_up_proj, H for down_proj.
    weight_sharding = jax.sharding.NamedSharding(mesh, P(None, 'x'))
else:
    # Shard dim 0 (D) of both weights.
    weight_sharding = jax.sharding.NamedSharding(mesh, P('x', None))

x = jax.ShapeDtypeStruct(shape=(BATCH_SIZE, SIZE), dtype=DTYPE, sharding=jax.sharding.NamedSharding(mesh, P(None, None)))
gate_up_proj = jax.ShapeDtypeStruct((SIZE, 2 * HIDDEN_SIZE), dtype=DTYPE, sharding=weight_sharding)
down_proj = jax.ShapeDtypeStruct((SIZE, HIDDEN_SIZE), dtype=DTYPE, sharding=weight_sharding)

lowered = jax.jit(mlp).lower(x, gate_up_proj, down_proj)
print(lowered.as_text())
compiled = lowered.compile()
print(compiled.as_text())
```
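The shape comments in `mlp` can be checked without a Neuron device. The sketch below uses a hypothetical helper, `dot_general_shape`, and assumes that `lax.dot_general` with no batch dimensions returns the non-contracted lhs dims followed by the non-contracted rhs dims. It shows that contracting the last axis of `hidden` against axis 1 of `down_proj` requires `down_proj` to be `(D, H)`, which is what the `ShapeDtypeStruct` declares; an `(H, D)` weight would not line up.

```python
def dot_general_shape(lhs_shape, rhs_shape, lhs_contract, rhs_contract):
    """Hypothetical helper: output shape of lax.dot_general with no batch dims.

    Non-contracted lhs dims come first, then non-contracted rhs dims.
    """
    for l, r in zip(lhs_contract, rhs_contract):
        if lhs_shape[l] != rhs_shape[r]:
            raise ValueError(
                f"contracting dims differ: {lhs_shape[l]} vs {rhs_shape[r]}")
    lhs_free = [d for i, d in enumerate(lhs_shape) if i not in lhs_contract]
    rhs_free = [d for i, d in enumerate(rhs_shape) if i not in rhs_contract]
    return tuple(lhs_free + rhs_free)


B, D, H = 1, 4096, 8192  # BATCH_SIZE, SIZE, HIDDEN_SIZE from the repro

hidden = dot_general_shape((B, D), (D, 2 * H), [1], [0])  # (1, 16384)
half = (hidden[0], hidden[1] // 2)                        # each jnp.split half: (1, 8192)
out = dot_general_shape(half, (D, H), [1], [1])           # (1, 4096)
print(hidden, half, out)  # (1, 16384) (1, 8192) (1, 4096)

try:
    dot_general_shape(half, (H, D), [1], [1])  # an (H, D) down_proj
except ValueError as err:
    print("(H, D) down_proj:", err)  # contracting dims differ: 8192 vs 4096
```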
If `SHARD_ON_HIDDEN` is set to `False`, everything works fine. When it is set to `True` (which is what you want for optimal performance), it crashes as soon as more than 2 Neuron cores are used, probably because sharding on the hidden axis introduces transfers over the interconnect.
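To make the "interconnect transfers" guess more concrete, here is a pure-Python sketch (no JAX needed) of which mesh position holds each column of `hidden` under `P(None, 'x')`. It assumes the even, contiguous block layout that `NamedSharding` uses along one axis of a 1-D mesh, and lists which positions have to exchange data for `jax.nn.gelu(x1) * x2`. The helper names are illustrative only.

```python
def column_owner(col, num_cols, num_devices):
    """Mesh position holding column `col` when an axis of size `num_cols`
    is split into equal contiguous blocks across a 1-D mesh."""
    block = num_cols // num_devices
    return col // block


def split_product_pairs(num_cols, num_devices):
    """jnp.split(hidden, 2, 1) gives x1 = cols [0, n/2) and x2 = cols [n/2, n).
    gelu(x1) * x2 pairs column j with column j + n/2; return the distinct
    (owner of the x1 column, owner of the x2 column) pairs."""
    half = num_cols // 2
    return sorted({
        (column_owner(j, num_cols, num_devices),
         column_owner(j + half, num_cols, num_devices))
        for j in range(half)
    })


HIDDEN_SIZE = 8192
for n in (1, 2, 4):
    print(n, split_product_pairs(2 * HIDDEN_SIZE, n))
# 1 [(0, 0)]
# 2 [(0, 1)]
# 4 [(0, 2), (1, 3)]
```

With one core the product is local. With 2 cores it already pairs positions 0 and 1, yet that case works; at 4 cores the pairs become (0, 2) and (1, 3). That may be related to the replica-group errors in the log, but the link is a guess and has not been checked against the HLO Neuron actually emits. Storing the gate and up projections as two separate `(D, H)` weights, each sharded with `P(None, 'x')`, would put column j of both on the same position and keep the product local; whether that sidesteps the runtime error is untested.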
Here is the error:
2024-Nov-09 16:14:06.605467 66954:67403 ERROR ENC:enc_parse_replica_groups [nec_dev 2] replica groups (0/1) does not have myself 2
2024-Nov-09 16:14:06.605522 66954:67403 ERROR TDRV:build_enc_source_target_pairs Failed to parse src_target_pairs on ENC
2024-Nov-09 16:14:06.605546 66954:67403 ERROR TDRV:instr_col_translate_ptc2 Failed to setup pseudo trigger collectives2 instruction
2024-Nov-09 16:14:06.605670 66954:67403 ERROR TDRV:ib_create_one_block failed to translate instructions
2024-Nov-09 16:14:06.605695 66954:67403 ERROR TDRV:ib_create_eib_v2 Failed to create instruction blocks
2024-Nov-09 16:14:06.605715 66954:67403 ERROR TDRV:sequencer_v2_setup_instr_one_eng Failed to allocate eib
2024-Nov-09 16:14:06.605735 66954:67403 ERROR TDRV:kbl_model_add create_engine_refill_rings_v1() error
2024-Nov-09 16:14:06.605981 66954:67403 ERROR NMGR:dlr_kelf_stage Failed to load subgraph
2024-Nov-09 16:14:06.606000 66954:67403 ERROR NMGR:kmgr_load_nn_internal_v2 Failed to stage graph: kelf-0.json to NeuronCore
2024-Nov-09 16:14:06.606012 66954:67403 ERROR NMGR:kmgr_load_nn_post_metrics Failed to load NN: /tmp/tmpsyndo_96/file.neff, err: 4
2024-Nov-09 16:14:06.606025 66954:67403 ERROR NRT:nrt_infodump Neuron runtime information - please include in any support request:
2024-Nov-09 16:14:06.606040 66954:67403 ERROR NRT:nrt_infodump ------------->8------------[ cut here ]------------>8-------------
2024-Nov-09 16:14:06.606062 66954:67403 ERROR NRT:nrt_infodump NRT version: 2.22.14.0 (6e27b8d5b22dea0e0b8375517f4d8a009b6de5a8)
2024-Nov-09 16:14:06.606084 66954:67403 ERROR NRT:nrt_infodump Embedded FW version: 1.12.2.0 (f152b70c827a52701d6b9ee74ec7ff7a15971f7d)
2024-Nov-09 16:14:06.606112 66954:67403 ERROR NRT:nrt_infodump CCOM version: 2.22.26.0- (compat 48)
2024-Nov-09 16:14:06.606134 66954:67403 ERROR NRT:nrt_infodump Instance ID: i-0b19e4a1cf3fd70d9
2024-Nov-09 16:14:06.606156 66954:67403 ERROR NRT:nrt_infodump Cluster ID: N/A
2024-Nov-09 16:14:06.606178 66954:67403 ERROR NRT:nrt_infodump Kernel: Linux 6.8.0-1015-aws #16~22.04.1-Ubuntu SMP Mon Aug 19 19:38:17 UTC 2024
2024-Nov-09 16:14:06.606200 66954:67403 ERROR NRT:nrt_infodump Nodename: ip-172-31-42-39
2024-Nov-09 16:14:06.606254 66954:67403 ERROR NRT:nrt_infodump Driver version: 2.18.12.0
2024-Nov-09 16:14:06.606276 66954:67403 ERROR NRT:nrt_infodump Failure: NRT_RESOURCE in nrt_load()
2024-Nov-09 16:14:06.606298 66954:67403 ERROR NRT:nrt_infodump Visible cores: 0, 1, 2, 3
2024-Nov-09 16:14:06.606318 66954:67403 ERROR NRT:nrt_infodump Environment:
2024-Nov-09 16:14:06.606341 66954:67403 ERROR NRT:nrt_infodump NEURON_CC_FLAGS=--model-type=transformer --auto-cast=none
2024-Nov-09 16:14:06.606362 66954:67403 ERROR NRT:nrt_infodump NEURON_RT_NUM_CORES=4
2024-Nov-09 16:14:06.606382 66954:67403 ERROR NRT:nrt_infodump NEURON_RT_ROOT_COMM_ID=localhost:49255
2024-Nov-09 16:14:06.606401 66954:67403 ERROR NRT:nrt_infodump -------------8<-----------[ cut to here ]-----------8<------------
2024-Nov-09 16:14:06.602248 66954:67406 ERROR ENC:enc_parse_replica_groups [nec_dev 1] replica groups (0/1) does not have myself 1
2024-Nov-09 16:14:06.603586 66954:67404 ERROR ENC:enc_parse_replica_groups [nec_dev 3] replica groups (0/1) does not have myself 3
2024-Nov-09 16:14:06.614656 66954:67406 ERROR TDRV:build_enc_source_target_pairs Failed to parse src_target_pairs on ENC
2024-Nov-09 16:14:06.626936 66954:67404 ERROR TDRV:build_enc_source_target_pairs Failed to parse src_target_pairs on ENC
2024-Nov-09 16:14:06.639079 66954:67406 ERROR TDRV:instr_col_translate_ptc2 Failed to setup pseudo trigger collectives2 instruction
2024-Nov-09 16:14:06.649859 66954:67404 ERROR TDRV:instr_col_translate_ptc2 Failed to setup pseudo trigger collectives2 instruction
2024-Nov-09 16:14:06.662122 66954:67406 ERROR TDRV:ib_create_one_block failed to translate instructions
2024-Nov-09 16:14:06.672477 66954:67404 ERROR TDRV:ib_create_one_block failed to translate instructions
2024-Nov-09 16:14:06.682738 66954:67406 ERROR TDRV:ib_create_eib_v2 Failed to create instruction blocks
2024-Nov-09 16:14:06.692175 66954:67404 ERROR TDRV:ib_create_eib_v2 Failed to create instruction blocks
2024-Nov-09 16:14:06.702800 66954:67406 ERROR TDRV:sequencer_v2_setup_instr_one_eng Failed to allocate eib
2024-Nov-09 16:14:06.712256 66954:67404 ERROR TDRV:sequencer_v2_setup_instr_one_eng Failed to allocate eib
2024-Nov-09 16:14:06.723884 66954:67406 ERROR TDRV:kbl_model_add create_engine_refill_rings_v1() error
2024-Nov-09 16:14:06.751248 66954:67405 ERROR ENC:enc_parse_replica_groups [nec_dev 0] replica groups (0/1) does not have myself 0
2024-Nov-09 16:14:06.751330 66954:67405 ERROR TDRV:build_enc_source_target_pairs Failed to parse src_target_pairs on ENC
2024-Nov-09 16:14:06.751345 66954:67405 ERROR TDRV:instr_col_translate_ptc2 Failed to setup pseudo trigger collectives2 instruction
2024-Nov-09 16:14:06.751929 66954:67405 ERROR TDRV:ib_create_one_block failed to translate instructions
2024-Nov-09 16:14:06.751957 66954:67405 ERROR TDRV:ib_create_eib_v2 Failed to create instruction blocks
2024-Nov-09 16:14:06.751974 66954:67405 ERROR TDRV:sequencer_v2_setup_instr_one_eng Failed to allocate eib
2024-Nov-09 16:14:06.751990 66954:67405 ERROR TDRV:kbl_model_add create_engine_refill_rings_v1() error
2024-Nov-09 16:14:06.753195 66954:67405 ERROR NMGR:dlr_kelf_stage Failed to load subgraph
2024-Nov-09 16:14:06.753230 66954:67405 ERROR NMGR:kmgr_load_nn_internal_v2 Failed to stage graph: kelf-0.json to NeuronCore
2024-Nov-09 16:14:06.753248 66954:67405 ERROR NMGR:kmgr_load_nn_post_metrics Failed to load NN: /tmp/tmpsyndo_96/file.neff, err: 4
2024-Nov-09 16:14:06.735901 66954:67404 ERROR TDRV:kbl_model_add create_engine_refill_rings_v1() error
2024-Nov-09 16:14:06.749448 66954:67406 ERROR NMGR:dlr_kelf_stage Failed to load subgraph
2024-Nov-09 16:14:06.762705 66954:67404 ERROR NMGR:dlr_kelf_stage Failed to load subgraph
2024-Nov-09 16:14:06.775586 66954:67406 ERROR NMGR:kmgr_load_nn_internal_v2 Failed to stage graph: kelf-0.json to NeuronCore
2024-Nov-09 16:14:06.789108 66954:67404 ERROR NMGR:kmgr_load_nn_internal_v2 Failed to stage graph: kelf-0.json to NeuronCore
2024-Nov-09 16:14:06.799686 66954:67406 ERROR NMGR:kmgr_load_nn_post_metrics Failed to load NN: /tmp/tmpsyndo_96/file.neff, err: 4
2024-Nov-09 16:14:06.810033 66954:67404 ERROR NMGR:kmgr_load_nn_post_metrics Failed to load NN: /tmp/tmpsyndo_96/file.neff, err: 4
Environment
* Python 3.10
* Packages:
  * neuronx-cc==2.15.141.0+d3cfc8ca
  * libneuronxla==2.0.4986.0
  * jaxlib==0.4.31
  * jax-neuronx==0.1.1
  * jax==0.4.31
* inf2.48xlarge instance
Thanks for the help!