Dynamic/variable batch size support #1051
Comments
I'm also super interested in knowing more about this. Happy to lend a hand if it helps make this available faster! |
@ecly, apologies that this request somehow slipped through. I wonder what solution you ended up with and whether you are still interested in DeepSpeed support. @aniruddhakal, thanks for bumping this back to our attention. Is it okay to wait for @ecly to respond before deciding next steps? |
@tjruwase we ended up for the most part just using pure DDP in PyTorch. We did have moderate success using Fairscale which supported the variable batch sizes out of the box, but didn't find any value from doing so with our model size of ~200M parameters. The problem is very much still relevant to us, as we'd still like to further adopt DeepSpeed, but this is a blocker that makes it non-trivial for us to adopt. |
Hello, how do you use dynamic batching in DDP? Can you give an example? Because of the dynamic batch size, the number of batches allocated to different ranks is inconsistent, which is what prevents the ranks from communicating in DDP. Are there other solutions? |
@ecly, I am also interested in your dynamic batch in ddp. If you can share some client code with us, that would help with DeepSpeed support. Thanks! |
@tjruwase infinibatch may be a good choice for dynamic batching in DDP. Note that a plain Dataset with a DistributedSampler may be better than infinibatch for the validation set. |
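To illustrate the validation-set suggestion above, here is a minimal sketch of a plain Dataset paired with torch's DistributedSampler; the dataset, batch size, and the hard-coded two-rank setup are assumptions for the example, not code from this thread.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Stand-in validation set; in practice this would be the real eval Dataset.
val_dataset = TensorDataset(torch.arange(1000).unsqueeze(1))

# DistributedSampler gives every rank a disjoint, equally sized shard, so all
# ranks stay in lockstep for any collective ops performed during evaluation.
# num_replicas/rank are hard-coded here only so the sketch runs standalone;
# normally they are inferred from the initialized process group.
val_sampler = DistributedSampler(val_dataset, num_replicas=2, rank=0, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=32, sampler=val_sampler)

for (batch,) in val_loader:
    pass  # run evaluation on `batch`
```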
Hey @tjruwase and @wangleiofficial. As my original question is starting to be a bit old, I think I need to retest on a newer version of DeepSpeed to confirm that it's still the case, but nonetheless I'll share a bit more detail below. The idea is that we limit batches by the total number of tokens they contain (to provide a similar amount of learning signal in each minibatch), rather than by the number of examples, in the context of training Transformers on text inputs of varying length (in our case for Machine Translation). We adapt our `MaxTokensBatchSampler` for DDP with only a tiny bit of code:

```python
import itertools
from typing import Iterator, List

from torch.utils.data.distributed import DistributedSampler

# MaxTokensBatchSampler and TranslationDataset are our own project classes.
class DistributedMaxTokensBatchSampler(DistributedSampler, MaxTokensBatchSampler):
    def __init__(self, dataset: TranslationDataset, batch_max_tokens: int, **kwargs):
        DistributedSampler.__init__(self, dataset)
        MaxTokensBatchSampler.__init__(self, dataset, batch_max_tokens, **kwargs)

    def __iter__(self) -> Iterator[List[int]]:
        # Each rank keeps every num_replicas-th batch, offset by its own rank.
        iterator = MaxTokensBatchSampler.__iter__(self)
        return itertools.islice(iterator, self.rank, None, self.num_replicas)

    def __len__(self):
        return len(self.batches) // self.num_replicas
```

This approach can effectively result in batches with different numbers of examples, but which are similarly sized in terms of the number of tokens. In our experience it works out of the box with DDP. |
Keep tracking this issue |
@ecly @aniruddhakal @tjruwase If it's not too late: I started developing dynamic batch sizes and the corresponding LR scaling in DeepSpeed in PR 5237. For the time being, there are Torch and DeepSpeed implementations for FSDP and DDP parallelism. |
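As a rough illustration of why LR scaling accompanies a dynamic batch size (larger batches generally tolerate proportionally larger learning rates), here is a tiny sketch of the linear-scaling rule; this is only a conceptual example and is not the interface implemented in the PR above.

```python
def scale_lr(base_lr: float, batch_tokens: int, ref_batch_tokens: int) -> float:
    """Linear scaling rule: scale the LR by the ratio of the current batch's
    token count to a reference batch size (a square-root rule is a common
    alternative for adaptive optimizers)."""
    return base_lr * (batch_tokens / ref_batch_tokens)


# A batch with twice the reference token count gets twice the base LR.
print(scale_lr(1e-4, batch_tokens=8192, ref_batch_tokens=4096))  # 2e-4
```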
@bm-synth, it is not too late. I believe my colleague @conglongli already responded on #5237. Thanks! |
@bm-synth Thank you for your contribution to this important feature. Since it's not yet merged, I have a question regarding my specific use case. I'm working with datasets of fixed but different lengths:
Note: the longer sequences use a smaller batch size, but the total token count per batch remains constant across both datasets. My training setup: in each training step, the dataloader selects only one type of data (either A or B) to train across all GPUs. Given this simplified scenario, would it be appropriate to set …? Additional context:
Thank you in advance for your guidance on this matter. |
Hi @ShomyLiu, thank you. I'll try to finish the PR ASAP. In your case, you don't need a "setup" that automatically batches samples. I would recommend that you not do the batching yourself (i.e., the "setup") and instead call … |
@bm-synth Thank you for such a detailed explanation and reply. I appreciate your help. |
Hi @ShomyLiu.
Hope this helps |
@bm-synth Thank you for the detailed clarification. My implementation is a bit simpler. I've set up two dataloaders, each based on a DistributedSampler. For each training step, I randomly select one of these dataloaders. By setting a fixed seed, I ensure that all GPU processes choose data with the same shape. |
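A minimal sketch of that two-dataloader pattern might look as follows; the datasets, shapes, batch sizes, and seed are placeholders, and the two-rank sampler arguments are hard-coded only so the snippet runs standalone.

```python
import random

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Stand-ins for the two fixed-length datasets: A has short sequences, B long ones.
# Batch sizes are chosen so tokens per batch match (32 * 128 == 8 * 512).
dataset_a = TensorDataset(torch.zeros(1024, 128))
dataset_b = TensorDataset(torch.zeros(256, 512))

loaders = {
    "a": DataLoader(dataset_a, batch_size=32,
                    sampler=DistributedSampler(dataset_a, num_replicas=2, rank=0)),
    "b": DataLoader(dataset_b, batch_size=8,
                    sampler=DistributedSampler(dataset_b, num_replicas=2, rank=0)),
}
iters = {name: iter(dl) for name, dl in loaders.items()}

# A fixed seed makes every rank draw the same sequence of dataset choices,
# so all ranks see the same tensor shapes at every step.
rng = random.Random(1234)

for step in range(100):
    name = rng.choice(sorted(loaders))
    try:
        (batch,) = next(iters[name])
    except StopIteration:
        iters[name] = iter(loaders[name])  # restart the exhausted loader
        (batch,) = next(iters[name])
    # ... forward/backward on `batch` ...
```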
@ShomyLiu Yes, it definitely fulfils both conditions, so all good. Good luck. |
@bm-synth Thanks again for your kind patience and guidance |
For the model I am training, I rely on a custom Sampler that returns variable batch sizes. My task is translation, where, following Attention Is All You Need (2017), I create batches based on the total token count in a batch. Given the variable-length input, this results in batches with varying numbers of examples (an example here being one source/target translation pair).
For regular DDP-based training this worked fine: I simply created a distributed version of this sampler that splits each variable-size batch into sub-batches based on the GPU rank. For DeepSpeed, however, I am forced to provide either `train_micro_batch_size_per_gpu` or `train_batch_size`, both of which, as I currently understand them, are based on the number of examples in the batch. As the number of examples in my batches varies, and I just want to configure gradient accumulation based on batch count rather than batch size, I'm not sure how to achieve this with DeepSpeed's configuration.
Am I misunderstanding the impact of the configuration variables, missing some other configuration, or is this not possible to achieve at the moment?
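For reference, these are the batch-related config fields in question, shown here as a Python dict with placeholder values. DeepSpeed requires the global batch size to equal the per-GPU micro-batch size times the gradient-accumulation steps times the world size, which is exactly the fixed-examples-per-batch assumption that clashes with a tokens-per-batch sampler.

```python
# Placeholder values; only the relationship between the fields matters here.
# DeepSpeed enforces:
#   train_batch_size == train_micro_batch_size_per_gpu
#                       * gradient_accumulation_steps * world_size
# i.e. every field counts *examples*, not tokens.
ds_config = {
    "train_micro_batch_size_per_gpu": 8,
    "gradient_accumulation_steps": 4,
    # "train_batch_size": 64,  # alternative: specify the global size directly
}
```

This dict (or an equivalent JSON file) is what gets handed to DeepSpeed at initialization.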