
JAX partitioning error when attempting to run with sequence parallelism factor not a power of 2 #9

exists-forall opened this issue Oct 24, 2023 · 0 comments
I'm trying to run Ring Attention on a machine with 6 A100 GPUs, and I'm finding that setting the sequence parallelism dimension to anything other than a power of 2 crashes the process with a JAX partitioning error.

I'd be grateful for any insight into whether I'm invoking the training script incorrectly, and for any advice on how to work around this issue.

Steps to Reproduce

Consider the following script for invoking llamabpt.train:

#########################################################
### Configuration 1 (runs successfully)               ###
#########################################################
export CUDA_VISIBLE_DEVICES="0,1,2,3"
SEQ_PAR_DIM=4
MAX_SEQ_LEN=131072
#########################################################

#########################################################
### Configuration 2 (CRASHES with partitioning error) ###
#########################################################
# export CUDA_VISIBLE_DEVICES="0,1,2"
# SEQ_PAR_DIM=3
# MAX_SEQ_LEN=98304
#########################################################

# mesh_dim appears to map to the axes (dp, fsdp, tp, sp) named in the error
# below; only the final (sp) axis is varied between the two configurations.
python3 -m llamabpt.train \
  --mesh_dim="1,1,1,${SEQ_PAR_DIM}" \
  --dtype=bf16 \
  --load_llama_config=1b \
  --update_llama_config="{'max_sequence_length': ${MAX_SEQ_LEN}, 'scan_attention': True, 'scan_query_chunk_size': 2048, 'scan_key_chunk_size': 4096, 'remat_attention': 'nothing_saveable', 'scan_mlp': True, 'scan_mlp_chunk_size': 2048, 'remat_mlp': 'nothing_saveable', 'remat_block': 'nothing_saveable', 'scan_layers': True, 'attention_type': 'ring_blockwise', 'param_scan_axis': 0, 'mesh_dim': '1,1,1,${SEQ_PAR_DIM}'}" \
  --total_steps=2 \
  --log_freq=1 \
  --save_model_freq=0 \
  --save_milestone_freq=1000 \
  --tokenizer.vocab_file="${TOKENIZER_PATH}" \
  --optimizer.type=adamw \
  --optimizer.adamw_optimizer.weight_decay=0.1 \
  --optimizer.adamw_optimizer.lr=1.5e-4 \
  --optimizer.adamw_optimizer.end_lr=1.5e-5 \
  --optimizer.adamw_optimizer.lr_warmup_steps=1 \
  --optimizer.adamw_optimizer.lr_decay_steps=10 \
  --train_dataset.type=json \
  --train_dataset.text_processor.fields=text \
  --train_dataset.json_dataset.path="${TRAIN_DATA_PATH}" \
  --train_dataset.json_dataset.seq_length=${MAX_SEQ_LEN} \
  --train_dataset.json_dataset.batch_size=1 \
  --train_dataset.json_dataset.tokenizer_processes=16

For Configuration 1, where the sequence parallelism dimension is 4, the training script runs as expected without errors.

However, when I comment out Configuration 1 and uncomment Configuration 2, where the sequence parallelism dimension is 3, the training script crashes with the following error:

ValueError: One of pjit outputs with pytree key path .params['params']['lm_head']['kernel'] was given the sharding of NamedSharding(mesh={'dp': 1, 'fsdp': 1, 'tp': 1, 'sp': 3}, spec=PartitionSpec(('fsdp', 'sp'), 'tp')), which implies that the global size of its dimension 0 should be divisible by 3, but it is equal to 2048 (full shape: (2048, 32000))

The error occurs during the first call to sharded_init_fn.

I would expect Configuration 2 to run successfully, because the total sequence length (98304 = 3 × 32768) is a multiple of the sequence parallelism dimension (3).

Generalizing to other values of the sequence parallelism dimension, I find that:

  • Setting SEQ_PAR_DIM to either 2 or 4 runs successfully.
  • Setting SEQ_PAR_DIM to either 3 or 6 crashes with a partitioning error.
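
For reference, here is a minimal standalone sketch (my own, not part of llamabpt) that appears to reproduce the same divisibility check outside the training script. It builds the same 1×1×1×3 mesh and requests the same PartitionSpec reported for the lm_head kernel in the error above; the file name and the simulated-device flag are just for illustration:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Run on CPU with three simulated devices, e.g.:
#   XLA_FLAGS=--xla_force_host_platform_device_count=3 python repro.py
devices = np.array(jax.devices()[:3]).reshape(1, 1, 1, 3)
mesh = Mesh(devices, ("dp", "fsdp", "tp", "sp"))

# Same sharding as the lm_head kernel: dimension 0 is split over
# ('fsdp', 'sp') = 1 * 3 = 3 shards, but 2048 % 3 != 0.
sharding = NamedSharding(mesh, PartitionSpec(("fsdp", "sp"), "tp"))
init = jax.jit(lambda: jnp.zeros((2048, 32000)), out_shardings=sharding)
init()  # raises ValueError: ... dimension 0 should be divisible by 3 ...

If that's the right reading, the constraint is on the parameter shapes rather than the sequence length: since the hidden size (2048) is a power of 2, only power-of-2 values of fsdp × sp divide it evenly, which seems consistent with 2 and 4 working while 3 and 6 fail.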