Dynamic/variable batch size support #1051

Open
ecly opened this issue May 6, 2021 · 17 comments

Comments

@ecly

ecly commented May 6, 2021

For the model I am training, I rely on a custom Sampler that returns variable batch sizes. My task is translation, where, following Attention Is All You Need (2017), I create batches based on the total token count in a batch. Given the variable-length input, this results in batches with varying numbers of examples (an example here being one source/target translation pair).

For regular DDP-based training this worked fine: I simply created a distributed version of this sampler that splits the variable-size batch into sub-batches based on the GPU rank. With DeepSpeed, however, I am forced to provide either train_micro_batch_size_per_gpu or train_batch_size, both of which, as I currently understand them, are expressed in number of examples per batch.

Since the number of examples varies from batch to batch, and I really just want to configure gradient accumulation by batch count rather than batch size, I'm not sure how to achieve this with DeepSpeed's configuration.

Am I misunderstanding the impact of the configuration variables, missing some other configuration, or is this not possible to achieve at the moment?
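
For context, the configuration fields in question relate to each other as train_batch_size = train_micro_batch_size_per_gpu × gradient_accumulation_steps × number of GPUs, and all of them count examples. A minimal sketch of such a config (values purely illustrative):

# Illustrative DeepSpeed config; every batch field counts examples per batch.
# DeepSpeed expects:
#   train_batch_size == train_micro_batch_size_per_gpu * gradient_accumulation_steps * num_gpus
ds_config = {
    "train_micro_batch_size_per_gpu": 8,
    "gradient_accumulation_steps": 4,
    "train_batch_size": 8 * 4 * 2,  # assuming 2 GPUs
}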

@aniruddhakal

aniruddhakal commented Apr 15, 2022

I'm also very interested in knowing more about this. Happy to lend a hand to help make it available sooner!

@tjruwase
Contributor

@ecly, apologies that this request somehow slipped through. I wonder what solution you ended up with and whether you are still interested in DeepSpeed support?

@aniruddhakal, thanks for bumping this back to our attention. Is it okay to wait for @ecly to respond before deciding on next steps?

@ecly
Author

ecly commented Apr 18, 2022

@tjruwase we ended up, for the most part, just using pure DDP in PyTorch. We did have moderate success using Fairscale, which supported variable batch sizes out of the box, but didn't see any benefit from it at our model size of ~200M parameters. The problem is still very much relevant to us, as we'd like to adopt DeepSpeed further, but this is a blocker that makes adoption non-trivial for us.

@wangleiofficial

> @tjruwase we ended up, for the most part, just using pure DDP in PyTorch. We did have moderate success using Fairscale, which supported variable batch sizes out of the box, but didn't see any benefit from it at our model size of ~200M parameters. The problem is still very much relevant to us, as we'd like to adopt DeepSpeed further, but this is a blocker that makes adoption non-trivial for us.

Hello, how do you use dynamic batches in DDP? Can you give an example? Due to the dynamic batch size, the number of batches allocated to different ranks is inconsistent, which prevents the ranks from communicating in DDP. Are there other solutions?

@tjruwase
Contributor

tjruwase commented Jun 9, 2022

@ecly, I am also interested in your dynamic batching in DDP. If you can share some client code with us, that would help us add DeepSpeed support. Thanks!

@wangleiofficial

@tjruwase infinibatch may be a good choice for dynamic batching in DDP. Note that a Dataset with the DistributedSampler may be better than infinibatch for the validation set.

@ecly
Author

ecly commented Jun 13, 2022

Hey @tjruwase and @wangleiofficial

As my original question is getting a bit old, I should probably retest on a newer version of DeepSpeed to confirm it's still the case. Nonetheless, I'll share a few more details below.

The idea is that, when training Transformers on text inputs of varying length (in our case for Machine Translation), we limit batches by the total number of tokens in the batch rather than by the number of examples, so that each minibatch provides a similar amount of learning signal. The MaxTokensBatchSampler we use as a batch_sampler with the PyTorch DataLoader is similar in nature to the one used in fairseq: https://github.com/facebookresearch/fairseq/blob/b5a039c292facba9c73f59ff34621ec131d82341/fairseq/data/data_utils.py#L282

We adapt it for DDP with only a tiny bit of code:

import itertools
from typing import Iterator, List

from torch.utils.data import DistributedSampler

# MaxTokensBatchSampler and TranslationDataset come from our own codebase.
class DistributedMaxTokensBatchSampler(DistributedSampler, MaxTokensBatchSampler):

    def __init__(self, dataset: TranslationDataset, batch_max_tokens: int, **kwargs):
        DistributedSampler.__init__(self, dataset)
        MaxTokensBatchSampler.__init__(self, dataset, batch_max_tokens, **kwargs)

    def __iter__(self) -> Iterator[List[int]]:
        # Each rank takes every num_replicas-th batch, offset by its own rank.
        iterator = MaxTokensBatchSampler.__iter__(self)
        return itertools.islice(iterator, self.rank, None, self.num_replicas)

    def __len__(self):
        return len(self.batches) // self.num_replicas

This approach results in batches with varying numbers of examples but a similar number of tokens per batch. In our experience it works out of the box with DDP.
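
For completeness, a minimal sketch of how such a batch sampler plugs into a DataLoader (the dataset construction and collate function are placeholders, not actual code from our repo):

from torch.utils.data import DataLoader

# Placeholder dataset and collate function, just to show the wiring.
dataset = TranslationDataset(...)
batch_sampler = DistributedMaxTokensBatchSampler(dataset, batch_max_tokens=4096)

# A batch_sampler yields lists of example indices, so batch_size, shuffle and
# sampler must be left unset on the DataLoader.
loader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_translation_pairs)

for batch in loader:
    ...  # variable number of examples per batch, bounded total token count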

@HsunGong

HsunGong commented Mar 1, 2023

Keep tracking this issue

@bm-synth
Contributor

@ecly @aniruddhakal @tjruwase If it's not too late: I started developing dynamic batch sizes and the corresponding LR scaling in DeepSpeed in PR 5237. For the time being, it provides Torch and DeepSpeed implementations for FSDP and DDP parallelism.

@tjruwase
Contributor

@bm-synth, it is not too late. I believe my colleague @conglongli already responded on #5237. Thanks!

@ShomyLiu

@bm-synth Thank you for your contribution to this important feature. Since it's not yet merged, I have a question regarding my specific use case.

I'm working with datasets of fixed but different lengths:

  1. Dataset A: sequence length = 2048, batch size = 8
  2. Dataset B: sequence length = 4096, batch size = 4

Note: The longer sequences use a smaller batch size, but the total token count per batch remains constant across both datasets.
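
(For reference: 8 × 2048 = 16,384 tokens and 4 × 4096 = 16,384 tokens, so both configurations carry the same token budget per batch.)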

My training setup: In each training step, the dataloader selects only one type of data (either A or B) to train across all GPUs.

Given this simplified scenario, would it be appropriate to set micro_batch_size=8 (i.e., the larger of the two batch sizes) in the DeepSpeed configuration?

Additional context:

  • This setup ensures that the total number of tokens processed in each batch is consistent, regardless of which dataset is chosen.
  • I'm particularly interested in understanding if this configuration might cause any issues or if there's a more optimal way to handle varying sequence lengths in this context.

Thank you in advance for your guidance on this matter.

@bm-synth
Contributor

Hi @ShomyLiu thank you. I'll try to finish the PR asap.

In your case, you don't need a "setup" that batches samples yourself. dataloader_for_variable_batch_size will do it for you (just set max_tokens_per_batch=16384 in the call to batch_by_size). If you already have that, fine. Call lr_scheduler_for_variable_batch_size to get your adaptive LR scheduler, which will update the LR at every step, taking the dataloader of your setup into account.

The micro_batch_size is just the reference batch size that the LR in your config refers to. As an example: in your case, if you pick lr_scaling_method=linear, micro_batch_size=8, and lr=1e-3 (i.e. appropriate for dataset A), then when you use dataset B, because its micro batch size is 4 (not 8), your LR will be changed internally to lr=2e-3. That is, micro_batch_size and lr are just the reference pair, and the LR will change per batch based on the new micro batch size.

I would recommend not doing the batching yourself (i.e. the "setup") and instead calling dataloader_for_variable_batch_size to generate a dataloader, to avoid issues.

@ShomyLiu

@bm-synth Thank you for such a detailed explanation and reply. I appreciate your help.
I think I understand the main point for the variable LR.
If I decide to use a constant learning rate and my own custom data loader, it seems the variable micro batches can be used directly as long as the number of gradient accumulation steps is fixed.
From reviewing the DeepSpeed source code, variable micro batches would only affect the gradient accumulation boundaries that trigger gradient reduction.
So if the number of gradient accumulation steps is fixed, none of the issues above arise?
Is my understanding correct?
Thanks again for your thorough response. Your explanation has been very helpful in clarifying these concepts for me.

@bm-synth
Contributor

bm-synth commented Aug 30, 2024

Hi @ShomyLiu.
If you use your own data loader and no adaptive learning rate, you already have everything you need and don't need to use any of this PR's code. There are two guarantees you need to enforce:

  • you need to guarantee that the micro-batches across all dataloaders (in a distributed run) yield samples of the same shape on each iteration. You can use a DistributedSampler to do this: the one that torch provides allocates samples to processes in an interleaved fashion. For example, if you have 3 GPUs with one data loader process each (and two possible batch shapes, SA and SB, for datasets A and B), your DataLoader must generate 3 consecutive batches of the same shape, so that each is picked by a different process. E.g. yielding SA, SA, SA, SB, SB, SB, SA, SA, SA, SA, SA, SA is a valid dataloader sequence, but SA, SB, SA, ... is already wrong, because dataloaders 1 and 3 would pick a mini-batch of shape SA while dataloader 2 would pick a mini-batch of shape SB in the same iteration.
  • related to your question, for gradient accumulation you also need to make sure that all micro-batches within a mini-batch are of the same shape. In your case, if you have e.g. 2 gradient accumulation steps, then your data loaders should now yield 3 × 2 = 6 (data_loaders × acc_steps) consecutive batches of the same shape, e.g. SA, SA, SA, SA, SA, SA, SB, SB, SB, SB, SB, SB is correct, but SA, SA, SA, SB, SB, SB is already wrong, because the second gradient accumulation step would pick batches of shape SB instead of the shape SA of the first (see the sketch below).

Hope this helps
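
A minimal sketch of a batch ordering that satisfies both guarantees (3 ranks, 2 accumulation steps, two batch shapes SA/SB; purely illustrative, not code from the PR):

import random

def shape_synchronized_batches(batches_a, batches_b, world_size=3, acc_steps=2, seed=0):
    """Yield batches so that every run of world_size * acc_steps consecutive
    batches shares one shape; rank r then takes batches r, r + world_size, ..."""
    rng = random.Random(seed)  # the same seed on every rank gives the same shape choices
    group = world_size * acc_steps
    iterators = {"A": iter(batches_a), "B": iter(batches_b)}
    while True:
        shape = rng.choice(["A", "B"])
        try:
            chunk = [next(iterators[shape]) for _ in range(group)]
        except StopIteration:
            return
        # e.g. SA, SA, SA, SA, SA, SA then SB, SB, SB, SB, SB, SB, ...
        yield from chunk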

@ShomyLiu

@bm-synth Thank you for the detailed clarification. My implementation is a bit simpler. I've set up two dataloaders, each based on a DistributedSampler. For each training step, I randomly select one of these dataloaders. By setting a fixed seed, I ensure that all GPU processes choose data with the same shape.
This approach should meet the requirements you mentioned.
Does this sound like a valid approach? It seems to address the key points you raised while being relatively straightforward to implement.
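
For concreteness, a minimal sketch of that per-step selection (names are illustrative; loader_a and loader_b are each assumed to wrap a DistributedSampler):

import random

def mixed_steps(loader_a, loader_b, num_steps, seed=42):
    """Pick one of the two dataloaders per training step; the shared seed keeps
    every rank on the same choice, so all ranks see the same shape each step."""
    rng = random.Random(seed)
    it_a, it_b = iter(loader_a), iter(loader_b)
    for _ in range(num_steps):
        use_a = rng.random() < 0.5
        try:
            yield next(it_a if use_a else it_b)
        except StopIteration:
            # Restart the exhausted loader, i.e. start a new epoch for that dataset.
            if use_a:
                it_a = iter(loader_a)
                yield next(it_a)
            else:
                it_b = iter(loader_b)
                yield next(it_b)
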
Thank you again for your guidance.

@bm-synth
Contributor

bm-synth commented Aug 30, 2024

@ShomyLiu yes, it definitely fulfils both conditions, so all good. Good luck.

@ShomyLiu

@bm-synth Thanks again for your kind patience and guidance
