-
Notifications
You must be signed in to change notification settings - Fork 4.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Autotp training #6922
base: master
Are you sure you want to change the base?
Autotp training #6922
Conversation
…-precision version before the rebase, but the grad norm differs (display issue)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @inkcherry , @delock
Sorry for the delay. I just left some comments. Thanks
if is_inference_mode: | ||
dist.inference_all_reduce(input, group=group) | ||
else: | ||
dist.all_reduce(input.contiguous(), group=group) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is there any reason for input.contiguous()
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems that adding this makes it safer, potentially helping to avoid discontinuity introduced by transpose/permute
.
FYI: https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/mappings.py#L23
I am not very clear on the implementation detail of inference_all_reduce
, so I have kept the original dist.inference_all_reduce
code path.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If it is already contiguous, .contiguous()
will not launch additional memory copy kernel. Or it won't hurt performance.
@staticmethod | ||
def symbolic(graph, input): | ||
"""Symbolic function for tracing.""" | ||
return dist.all_reduce(input.contiguous(), dist.get_tensor_model_parallel_group()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
similar here, is this contiguous()
necessary?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is consistent with the previous situation.
|
||
@pytest.mark.parametrize("layer_type", ["linear", "linearallreduce"]) | ||
def test(self, layer_type): | ||
tp_size = 4 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we parametrize and test tp_size of both 2 and 4?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the reminder, added
reuse_dist_env = True | ||
|
||
def test_save_original_weight(self): | ||
tp_size = 4 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here, could we parameterize both tp_size 2 and 4?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the reminder, added
return | ||
|
||
if data_parallel_size is None: | ||
data_parallel_size = dist.get_world_size() // tensor_model_parallel_size |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need to consider pipeline_parallel_size?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently, this feature does not support the pipeline and pipeline-related logic will not reach this part. Perhaps we can consider adding pipeline support in the future.
self.tp_config = TPConfig() | ||
self.tp_config.tp_size = tp_size | ||
if tp_size <= 1: | ||
self.tp_config.enabled = False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't see anywhere this flag is used (i.e. there seems no design/code if enabled flag == False)? is this needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for pointing that out. It's not necessary, I was referring to the inference config. I have removed it now.
Returns: | ||
OrderedDict: The consolidated state dictionary if the current process rank is 0, otherwise None. | ||
""" | ||
#TODO: If we use both Zero3 and tensor parallel simultaneously |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also don't see why need to gather weights/params in TP training/inference. If it is only used for re-collecting weights for single point checkpoint write, then you can use our universal checkpoint feature to convert model parallel strategy after training.
tp_size: int = 1 | ||
""" Number of devices to split the model across using tensor parallelism. """ | ||
|
||
tp_grain_size: int = 64 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this argument I also did not see any use case
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The variable is used in the autoTP parser to set tile boundaries to accelerate GEMM.
set_tp_grain_size(config.tensor_parallel.tp_grain_size) |
it has not been activated in training yet, as it requires support for uneven gather. I have added clearer comments for better understanding.
class Yuan_LinearAllreduce(LinearAllreduce): | ||
|
||
#Yuan2 | ||
@torch.no_grad() | ||
def partition(self, params_list): | ||
weight, bias = shard_value_with_share_qk(params_list[0].data, params_list[1], self.tp_index, | ||
self.tp_world_size, False) | ||
params_list[0].data = weight | ||
if bias is not None: | ||
params_list[1].data = bias | ||
|
||
|
||
class Yuan_LinearLayer(LinearLayer): | ||
#Yuan2 | ||
@torch.no_grad() | ||
def partition(self, params_list): | ||
weight, bias = shard_value_with_share_qk(params_list[0].data, params_list[1], self.tp_index, | ||
self.tp_world_size, True) | ||
params_list[0].data = move(weight, get_accelerator().current_device_name()).detach() | ||
if bias is not None: | ||
params_list[1].data = move(bias, get_accelerator().current_device_name()).detach() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is it possible to make an abstraction of partition
method with arguments passed-in for different models? if doing this, we can avoid create 2 new classes (e.g., Yuan_linear & Yuan_linear+allreduce) for every new model structure.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, they currently only have one method. every specific shard logic should have a corresponding reverse gather logic. The current shard method hasn’t implemented the corresponding gather. I think using a class might help reserve a potential placeholder and make the code more consistent.
return new_obj | ||
|
||
|
||
class GatherReplacedLayerParams: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need gather TP params during training or inference?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no, the reason are integrated into the comments above.
@tjruwase @GuanhuaWang Thank you for your review. I’ve added modifications or explanations. Could you take another look? Thanks! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @inkcherry, thanks for contributing. Just a heads-up, all the all_reduce call in domino is supposed to be asynchronous, and current LinearAllreduce
and LinearLayer
need to be updated to work with Domino.
For example, in the LinearAllreduce
, we'd like to get the handle from asynchronous all reduce, and synchronize it later to overlap computation.
The Domino work is still in progress, and it's not finalized yet. So, you don't need to worry about the compatibility with Domino at this point. But one thing you can easily support is the async TP, similar to Megatron here. Maybe it can be your next PR.
Thanks for your help!
FYI @tjruwase @GuanhuaWang @delock @skyshine102 context: #5445
changes/support
gather_16bit_weights_on_model_save=True
in ds config).HF trainer dependency:
transformer: https://github.com/inkcherry/transformers/tree/ds_tp
accelerate: https://github.com/inkcherry/accelerate/tree/ds_tp
I could send them once ds support these api.
Usage:
Users do not need to modify the client code, they only need to configure the settings in the config file to achieve the desired functionality.
Below is an example of code for fine-tuning a LLaMA 2 model (SFT). It supports Zero3/FSDP training and enables TP training by simply adjusting the configuration
https://github.com/inkcherry/stanford_alpaca/commits/tp_demo_1127/
This branch contains three commits, with the last two commits added for quick experiments and logging purposes.
results
loss curve(gbs=16):
zero3(baseline)
tp(this)
zero1 with zero1+tp(zero compatible)
performance(For your reference only.):
zero3(not enabled any acceleration.) : 18GB 2.3s/it
zero1:38GB 1.30s/it
zero1+tp: 24GB 1.66s/it
extension:
I think async-TP/domino .etc. can be implemented by inheriting a class and overriding the fwd/bwd methods. The logic for gather/partition can be reused to achieve this.(please correct me if I am wrong)
Complex sharding can also be achieved through independent partitioning and gathering. Partitioning is mandatory, while gathering is required for training.
TODO:
embedding vocab parallel
Currently, the parallelism for embeddings is primarily based on hidden_dim parallel combined with allreduce. This approach takes advantage of efficient reduction kernels. and it is not forced to use.
In training, however, the more common method is vocab parallelism. Enabling by default can save a certain amount of GPU memory.
thanks for @delock guidance.
I also verified inference with cpu-inference workloads(Optimized Model List in https://github.com/intel/intel-extension-for-pytorch/tree/main).
many thanks for @xuguangxin @ikurtchen @rogerxfeng8 ,@Yejing-Lai ,@ys950902 .etc. Help review and address matters related to inference.