Support fine grained activation sharding. #881
Conversation
Thanks! Some questions about the API...
axlearn/common/attention.py
Outdated
x = self.norm(inputs)
x = maybe_shard(x, cfg.premlp_partition_spec)
Do we need this? I suppose norm usually does not change the partition spec?
We want to change the spec post-norm so that the norm's spec does not propagate to MLP linear1.
So for sequence parallel we would set premlp_partition_spec = ((fsdp, data), None, None) to force an all-gather on the sequence dimension.
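For illustration only (not the PR's exact code), a minimal sketch of a maybe_shard-style helper and the two specs described above; the helper and the axis names are assumptions drawn from this thread, and applying the constraint requires a JAX mesh context so the PartitionSpec can be resolved.

# Sketch only: apply a sharding constraint when a spec is configured.
from typing import Optional

import jax
from jax.sharding import PartitionSpec


def maybe_shard(x: jax.Array, partition_spec: Optional[tuple]) -> jax.Array:
    if partition_spec is None:
        return x
    return jax.lax.with_sharding_constraint(x, PartitionSpec(*partition_spec))


# Pre-norm (sequence parallel): sequence dimension sharded over the tensor-parallel axis.
prenorm_partition_spec = (("fsdp", "data"), "model", None)
# Pre-MLP: drop the sequence sharding so an all-gather of the sequence dimension is
# inserted before linear1, rather than the norm's spec propagating into the matmul.
premlp_partition_spec = (("fsdp", "data"), None, None)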
Thanks for the explanation. Maybe we can add output_partition_spec to RMSNorm.Config to be consistent with Linear.Config. This is also more flexible, as RMSNorm can be used in other places. WDYT?
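A rough sketch of that suggestion, using a plain dataclass and function rather than AXLearn's config machinery; only the field name follows the comment above, everything else is an assumption.

# Sketch only: an RMSNorm-style config carrying an output_partition_spec.
import dataclasses
from typing import Optional

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec


@dataclasses.dataclass
class RMSNormConfig:
    eps: float = 1e-6
    # If not None, how to partition the normalized output, e.g. (("fsdp", "data"), None, None).
    output_partition_spec: Optional[tuple] = None


def rms_norm(cfg: RMSNormConfig, x: jax.Array, scale: jax.Array) -> jax.Array:
    out = x * jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + cfg.eps) * scale
    if cfg.output_partition_spec is None:
        return out
    # Constrain the output so downstream layers do not inherit the norm's sharding.
    return jax.lax.with_sharding_constraint(out, PartitionSpec(*cfg.output_partition_spec))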
Yes, I actually originally had this! I have revised to take this approach.
axlearn/common/attention.py
Outdated
# If not None, how to partition pre attention activation values.
preattention_partition_spec: Optional[tuple[Optional[str]]] = None
# If not None, how to partition post attention activation values.
postattention_partition_spec: Optional[tuple[Optional[str]]] = None
How is this different from setting attn_layer.output_linear.param_partition_spec = (fsdp_axis_names, tp_axis_names, None) (axlearn/common/attention.py, line 3319 in 2134a25)?
+1. I also have a high-level question: shouldn't the sharding config be auto-derived and optimized by XLA for most of the intermediate activations, based on weights, outputs, etc.?
I'd like to see a concrete example to understand why explicit sharding is needed here.
OpenXLA ShardingPropagation often makes non-optimal propagation decisions. The SPMD partitioner then partitions based on whatever HloSharding was propagated to each HloInstruction, so if the propagation is not optimal you will see collectives such as all-to-alls generated to resolve the sharding conflicts.
For example, if you set the partition spec for premlp or preattention to ((fsdp, data), model, None) to force the computation to run in sequence parallel, the ShardingPropagation pass will propagate that HloSharding down to the matmul, causing an all-to-all collective.
This is less optimal than an all-gather of the sequence dimension before the matmul.
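A minimal, self-contained sketch of that placement (not AXLearn code). Assumptions: a 1-D "model" mesh, toy shapes, and a sequence length divisible by the device count. The first constraint keeps the norm sequence-parallel; the second drops the sequence sharding right before the matmul so an all-gather, rather than an all-to-all, resolves the specs.

# Sketch only: sequence-parallel norm, then an explicit un-sharded spec before the projection.
import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("model",))


def norm_then_project(x, w1):
    # Sequence parallel: shard the sequence dimension over the "model" axis.
    x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, PartitionSpec(None, "model", None)))
    x = x * jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + 1e-6)  # RMS-norm-like op
    # Pre-MLP/pre-attention spec: un-shard the sequence dimension before the matmul.
    x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, PartitionSpec(None, None, None)))
    return jnp.einsum("bsd,df->bsf", x, w1)


x = jnp.ones((2, 8, 16))
w1 = jnp.ones((16, 32))
y = jax.jit(norm_then_project)(x, w1)  # runs on however many devices are available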
@@ -814,8 +821,14 @@ def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]:
        )

    def forward(self, x: Tensor) -> Tensor:
        cfg = self.config
Same question here: can we handle this inside the XLA compiler directly instead of making it explicit here? Adding this kind of fine-grained sharding spec is a divergence from the GSPMD programming paradigm, imho.
Responded above with a concrete example. The main issues we are trying to solve here are:
- We want Norms/Dropouts/Residuals to be computed in sequence parallel, i.e. partitioned along the sequence dimension, to avoid redundant compute across tensor-parallel workers.
- OpenXLA ShardingPropagation often makes non-optimal decisions. By inserting these partition specs, they show up in the HLO as custom-call (Sharding); this guides ShardingPropagation toward better decisions, since it prioritizes propagating user-provided partition specs (https://github.com/openxla/xla/blob/main/xla/service/sharding_propagation.h). See the sketch after this list.
- Those non-optimal decisions, such as propagating sequence-level sharding to the left-hand side of the attention QKV projection or the MLP up projection, cause the SPMD partitioner to insert all-to-alls, collective-permutes, or involuntary full rematerialization.
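A tiny sketch (toy shapes; the 1-D "model" mesh is an assumption) of inspecting how a user-provided constraint lands in the lowered module; it is typically recorded as a @Sharding custom-call, which is what ShardingPropagation prioritizes.

# Sketch only: lower a constrained function and inspect the module text.
import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("model",))


def constrained(x):
    # User-provided spec: first dimension sharded over the "model" axis.
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, PartitionSpec("model", None)))


print(jax.jit(constrained).lower(jnp.ones((8, 4))).as_text())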
I am curious why Axlearn has an activation-level sharding spec in use today for linear2.output_partition_spec.
Was this added to also address non-optimal SPMD partitioning or propagation?
axlearn/common/attention_test.py
Outdated
cfg_layer.self_attention.attention.input_linear = self_attention_input_linear_cfg
cfg_layers = [cfg_layer, cfg_layer]

cfg_layer.self_attention.prenorm_partition_spec = (fsdp_axis_names, tp_axis_names, None,)
Interesting. Is this the real use case you are testing nowadays?
Yes, this is the real config for sequence parallel. I have confirmed that with these specs and internal Neuron changes we can generate partitioned HLOs where Norms/Residuals/Dropouts are partitioned along the sequence dimension and matmuls along the model dimension, without any unnecessary collectives such as all-to-alls, collective-permutes, or involuntary full rematerialization.
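An illustrative stand-in for that kind of config (SimpleNamespace plays the role of the layer config; the field paths and axis names are assumptions based on the snippets in this thread, not verbatim AXLearn config):

# Illustrative stand-in only (not real AXLearn config objects).
from types import SimpleNamespace

fsdp_axis_names = ("fsdp", "data")
tp_axis_names = "model"

cfg_layer = SimpleNamespace(self_attention=SimpleNamespace(), feed_forward=SimpleNamespace())
# Norms/Residuals/Dropouts run sequence-parallel: sequence dim sharded on the TP axis.
cfg_layer.self_attention.prenorm_partition_spec = (fsdp_axis_names, tp_axis_names, None)
# Matmul inputs drop the sequence sharding so only an all-gather is inserted.
cfg_layer.feed_forward.premlp_partition_spec = (fsdp_axis_names, None, None)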
Thanks. Please re-request review when it's ready.
return x
return with_sharding_constraint(x, PartitionSpec(*partition_spec))


pre-commit wanted these blank lines. I ran pre-commit run -a.
I have revised the PR based on your comments. I removed "WIP", as the PR is now production-ready. I removed the Neuron-specific logic; I think that should be in a separate PR, and this PR should enable hardware-agnostic activation sharding. I ran pre-commit and pytest and both pass. Please take a look! Thank you.
Thanks! Are there tests that can be added for the new logic?
@ruomingp What test do you envision? What should I assert? This is effectively just putting with_sharding_constraint(tensor, PartitionSpec('fsdp', 'model')) behind a config.
@ruomingp I could add a test where I create an RMSNorm, set the partition specs in the config, and then assert they are the same specs, but I am not sure what benefit that adds. I do not see tests like that for other configs, as that would be a setter/getter test.
How about patching
Do you have an example of this?
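For reference, one shape such a patch-based test could take (a sketch, not the PR's actual tests; it assumes the patching suggestion refers to the sharding-constraint helper, and uses a local maybe_shard stand-in rather than the real layer):

# Sketch only: patch jax.lax.with_sharding_constraint and assert the configured
# spec was requested. A real test would patch wherever the layer imports it from.
from unittest import mock

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec


def maybe_shard(x, partition_spec):  # stand-in for the PR's helper
    if partition_spec is None:
        return x
    return jax.lax.with_sharding_constraint(x, PartitionSpec(*partition_spec))


def test_spec_is_applied():
    spec = (("fsdp", "data"), None, None)
    x = jnp.ones((2, 8, 16))
    with mock.patch("jax.lax.with_sharding_constraint", side_effect=lambda t, s: t) as m:
        maybe_shard(x, spec)
    m.assert_called_once_with(x, PartitionSpec(*spec))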
@ruomingp @kelvin-zou Thank you!
@ruomingp @kelvin-zou Thank you for the approval!!
Hello, this is a draft PR to support fine-grained activation sharding. All specs default to None, so this PR maintains all existing behavior of the Axlearn codebase.
I am requesting feedback and comments on the approach.
With this PR the user can control more closely how GSPMD partitions activations, allowing more control over the collectives generated by the SPMD partitioner. For example, GSPMD will generate all-to-alls around the gather operation on the vocab table, but with this PR the user can annotate partition specs to generate a reduce-scatter instead.
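As an illustration of that vocab-table case, a sketch under assumptions (1-D "model" mesh, toy sizes, vocab size divisible by the device count; not AXLearn code): shard the table over the vocab dimension and pin the gathered activations to a sequence-parallel spec, which nudges the partitioner toward a reduce-scatter rather than all-to-alls around the gather.

# Sketch only: annotate the embedding lookup output instead of leaving it to propagation.
import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("model",))


def embed(tokens, table):
    # Vocab-parallel table: vocab dimension sharded over the "model" axis.
    table = jax.lax.with_sharding_constraint(table, NamedSharding(mesh, PartitionSpec("model", None)))
    out = jnp.take(table, tokens, axis=0)  # [batch, seq, emb] lookup
    # Annotate the lookup output: sequence-parallel rather than replicated.
    return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, PartitionSpec(None, "model", None)))


tokens = jnp.zeros((2, 8), dtype=jnp.int32)
table = jnp.ones((32000, 16))
print(jax.jit(embed).lower(tokens, table).as_text())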
Thank you!