
Support fine grained activation sharding. #881

Merged Dec 17, 2024 (1 commit)

Conversation

patrick-toulme (Contributor):

Hello, this is a draft PR to support fine-grained activation sharding. All specs default to None, so this PR preserves the existing behavior of the Axlearn codebase.

I am requesting feedback and comments on the approach.

With this PR the user can more closely control how GSPMD partitions activations, giving finer control over the collectives generated by the SPMD partitioner. For example, GSPMD will generate all-to-alls around the gather of the vocab table, but with this PR the user can annotate partition specs so that a reduce-scatter is generated instead.
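
At its core the change threads optional partition specs through layer configs and applies them with a small helper; a minimal sketch of that pattern, assuming JAX's with_sharding_constraint (the helper itself appears further down in this review, and the axis names in the usage comment are illustrative):

```python
from typing import Optional, Sequence

import jax
from jax.sharding import PartitionSpec


def maybe_shard(x: jax.Array, partition_spec: Optional[Sequence[Optional[str]]]) -> jax.Array:
    """Applies a sharding constraint only when a spec is configured; otherwise a no-op."""
    if partition_spec is None:
        return x
    # Note: when traced under jit with a bare PartitionSpec, a mesh must be available in context.
    return jax.lax.with_sharding_constraint(x, PartitionSpec(*partition_spec))


# Illustrative use inside a layer's forward pass (axis names are assumptions):
#   x = maybe_shard(x, (("fsdp", "data"), None, None))
```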

Thank you!

@ruomingp ruomingp (Contributor) left a comment:

Thanks! Some questions about the API...

In axlearn/common/attention.py:
x = self.norm(inputs)
x = maybe_shard(x, cfg.premlp_partition_spec)
Contributor:

Do we need this? I suppose norm usually does not change the partition spec?

Contributor Author:

We want to change the spec after the norm so that the norm's spec does not propagate to MLP linear1.
For sequence parallel we would set premlp_partition_spec = ((fsdp, data), None, None) to force an all-gather on the sequence dimension.
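
Written out as partition specs, a sketch of that resharding (axis names and the [batch, seq, hidden] layout are illustrative assumptions):

```python
from jax.sharding import PartitionSpec as P

# Norm output computed sequence-parallel: sequence dim sharded over "model".
post_norm_spec = P(("fsdp", "data"), "model", None)

# Spec requested for the MLP input: sequence dim left unsharded, so the
# partitioner resolves the mismatch with an all-gather over "model".
premlp_spec = P(("fsdp", "data"), None, None)
```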

Contributor:

Thanks for the explanation. Maybe we can add output_partition_spec to RMSNorm.Config to be consistent with Linear.Config. This is also more flexible as RMSNorm can be used in other places. WDYT?

Contributor Author:

Yes, I actually originally had this! I have revised to take this approach.

# If not None, how to partition pre attention activation values.
preattention_partition_spec: Optional[tuple[Optional[str]]] = None
# If not None, how to partition post attention activation values.
postattention_partition_spec: Optional[tuple[Optional[str]]] = None
Contributor:

How is this different from setting attn_layer.output_linear.param_partition_spec = (fsdp_axis_names, tp_axis_names, None)?

Contributor:

+1. I also have a high-level question: shouldn't the sharding config be auto-derived and optimized by XLA for most of the intermediate activations, based on the weights, outputs, etc.?
I'd like to see a concrete example to understand why explicit sharding is needed here.

Contributor Author:

OpenXLA's ShardingPropagation pass often makes non-optimal propagation decisions. The SPMD partitioner then partitions based on whatever HloSharding was propagated to each HloInstruction, so if the propagation is not optimal you will see collectives such as all-to-alls generated to resolve the sharding conflicts.

An example: if you set the partition spec for premlp or preattention to ((fsdp, data), model, None) to force the computation to run sequence-parallel, the ShardingPropagation pass will propagate that HloSharding down to the matmul, causing an all-to-all collective.

This is less optimal than an all-gather of the sequence dimension before the matmul.
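
To make that concrete, here is a toy sketch (not Axlearn code) that lowers a small matmul both ways and checks which collectives appear in the compiled HLO; it assumes 8 available devices and illustrative axis names, and the exact collectives chosen can vary with the XLA version:

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumes 8 devices; axis names are illustrative.
mesh = Mesh(mesh_utils.create_device_mesh((2, 2, 2)), axis_names=("data", "fsdp", "model"))


def mlp_in(x, w, *, reshard_first: bool):
    # x: [batch, seq, hidden], produced sequence-parallel (seq sharded over "model").
    x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P(("fsdp", "data"), "model", None)))
    if reshard_first:
        # Explicit annotation: un-shard the sequence dim, which the partitioner
        # resolves with an all-gather before the matmul.
        x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P(("fsdp", "data"), None, None)))
    # Without the extra annotation, propagation may push the sequence-sharded
    # spec into the matmul and resolve the conflict with all-to-alls instead.
    return jnp.einsum("bsh,hf->bsf", x, w)


x = jax.ShapeDtypeStruct((8, 128, 512), jnp.float32)
w = jax.ShapeDtypeStruct((512, 2048), jnp.float32)
for reshard_first in (False, True):
    hlo = jax.jit(mlp_in, static_argnames="reshard_first").lower(
        x, w, reshard_first=reshard_first
    ).compile().as_text()
    print(reshard_first, "all-gather" in hlo, "all-to-all" in hlo)
```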


@@ -814,8 +821,14 @@ def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]:
)

def forward(self, x: Tensor) -> Tensor:
cfg = self.config
Contributor:

Same question here: can we handle this inside the XLA compiler directly instead of making it explicit here? Adding this kind of fine-grained sharding spec is a divergence from the GSPMD programming paradigm, imho.

@patrick-toulme patrick-toulme (Contributor Author) commented Dec 14, 2024:

Responded above with a concrete example. The main issues we are trying to solve here are:

  1. We want Norms/Dropouts/Residuals to be computed in sequence parallel, i.e. partitioned along the sequence dimension, to avoid redundant compute across tensor-parallel workers.

  2. OpenXLA's ShardingPropagation often makes non-optimal decisions. The inserted partition specs show up in the HLO as custom-call (Sharding) ops, which guide ShardingPropagation toward better decisions, since it prioritizes propagating user-provided partition specs (see https://github.com/openxla/xla/blob/main/xla/service/sharding_propagation.h). A sketch of how to see these custom calls follows below.

  3. Those non-optimal decisions, such as propagating sequence-level sharding to the left-hand side of the attention QKV projection or the MLP up-projection, cause the SPMD partitioner to insert all-to-alls, collective-permutes, or involuntary full rematerialization.

I am curious why Axlearn already has an activation-level sharding spec in use today for linear2.output_partition_spec. Was it added to also address non-optimal SPMD partitioning or propagation?
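
Following up on item 2, a toy way to see a user annotation materialize as a Sharding custom call is to lower a constrained function and grep the HLO text (again assuming 8 devices and illustrative axis names; the output format varies by JAX/XLA version):

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(mesh_utils.create_device_mesh((2, 2, 2)), axis_names=("data", "fsdp", "model"))


@jax.jit
def f(x):
    # A user-provided annotation: sharding propagation treats it as a
    # high-priority hint, and it shows up as custom-call(Sharding) in the HLO.
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P(("fsdp", "data"), None, None)))


hlo = f.lower(jax.ShapeDtypeStruct((8, 128, 512), jnp.float32)).as_text(dialect="hlo")
print([line for line in hlo.splitlines() if "Sharding" in line])
```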

cfg_layer.self_attention.attention.input_linear = self_attention_input_linear_cfg
cfg_layers = [cfg_layer, cfg_layer]

cfg_layer.self_attention.prenorm_partition_spec = (fsdp_axis_names, tp_axis_names, None,)
Contributor:

Interesting, is this the real use case you are testing nowadays?

Contributor Author:

Yes, this is the real config for sequence parallel. I have confirmed that with these specs and internal Neuron changes we can generate partitioned HLOs that run Norms/Residuals/Dropouts partitioned along the sequence dimension and matmuls partitioned along the model dimension, without unnecessary collectives such as all-to-alls, collective-permutes, or involuntary full rematerialization.

@ruomingp (Contributor):

Thanks. Please re-request review when it's ready.

@patrick-toulme patrick-toulme force-pushed the activation_sharding_pr branch 6 times, most recently from 5134c41 to 7d1d723 on December 16, 2024 at 00:02
@patrick-toulme patrick-toulme changed the title from "WIP - Support fine grained activation sharding." to "Support fine grained activation sharding." on Dec 16, 2024
return x
return with_sharding_constraint(x, PartitionSpec(*partition_spec))


Contributor Author:

pre-commit wanted these blank lines. I ran

pre-commit run -a

@patrick-toulme (Contributor Author):

@kelvin-zou @ruomingp

I have revised the PR based on your comments. I removed "WIP", as the PR is now production-ready. I also removed the Neuron-specific logic; I think that belongs in a separate PR, and this PR should enable hardware-agnostic activation sharding.

I ran pre-commit and pytest and both pass. Please take a look! Thank you.

@ruomingp ruomingp (Contributor) left a comment:

Thanks! Are there tests that can be added for the new logic?

@patrick-toulme (Contributor Author):

@ruomingp What test do you envision? What should I assert? This is effectively just putting

with_sharding_constraint(tensor, PartitionSpec('fsdp', 'model'))

behind a config.

@patrick-toulme (Contributor Author):

@ruomingp I could add a test where I make an RMSNorm and set the config for partition specs and then assert they are the same specs, but I am not sure what benefit that adds. I do not see tests like that for other configs, as that is like a setter/getter test.

@ruomingp (Contributor):

> @ruomingp I could add a test where I make an RMSNorm and set the config for partition specs and then assert they are the same specs, but I am not sure what benefit that adds. I do not see tests like that for other configs, as that is like a setter/getter test.

How about patching with_sharding_constraint to make sure that it's called with the right tensors in forward() when we set partition specs?

@patrick-toulme (Contributor Author):

> @ruomingp I could add a test where I make an RMSNorm and set the config for partition specs and then assert they are the same specs, but I am not sure what benefit that adds. I do not see tests like that for other configs, as that is like a setter/getter test.
>
> How about patching with_sharding_constraint to make sure that it's called with the right tensors in forward() when we set partition specs?

Do you have an example of this?
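
A rough sketch of the kind of test being suggested, using unittest.mock; the toy layer, helper, and patch target below are placeholders rather than the actual Axlearn code:

```python
from unittest import mock

import jax.numpy as jnp
from jax.sharding import PartitionSpec


# Toy stand-ins for the real layer and helper (placeholders, not Axlearn code).
def with_sharding_constraint(x, spec):  # in Axlearn this would be imported in the layer module
    return x


def toy_forward(x, partition_spec):
    # Mirrors the PR's pattern: apply the constraint only when a spec is configured.
    if partition_spec is not None:
        x = with_sharding_constraint(x, PartitionSpec(*partition_spec))
    return x


def test_forward_applies_configured_partition_spec():
    spec = ("fsdp", "model")
    x = jnp.ones((8, 16))
    # Patch the helper where forward() looks it up and check the call arguments.
    with mock.patch(f"{__name__}.with_sharding_constraint", side_effect=lambda t, s: t) as constrain:
        toy_forward(x, spec)
        constrain.assert_called_once()
        _, called_spec = constrain.call_args.args
        assert called_spec == PartitionSpec(*spec)
```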

@patrick-toulme (Contributor Author):

@ruomingp @kelvin-zou I have added tests per @ruomingp's specifications. All tests pass and pre-commit passes as well.

Thank you!

@patrick-toulme patrick-toulme (Contributor Author) commented Dec 17, 2024:

@ruomingp @kelvin-zou Thank you for the approval! Does the PR auto-merge, or is auto-merge blocked on @markblee's review? It says all checks have passed.

@ruomingp ruomingp added this pull request to the merge queue Dec 17, 2024
Merged via the queue into apple:main with commit 01b762e Dec 17, 2024
6 checks passed