Data imbalance handling in MXNet Gluon #17237
-
Hello, my question regarding data imbalance handling in Gluon is as follows: suppose I'm training with 4 GPUs. For each update, my training loop samples 4 batches (one per GPU) and runs forward/backward on them. Using a Gluon Trainer, I can reduce and update gradients on all 4 GPUs. Now, towards the end of an epoch, I only have 2 batches left to process. I sample those 2 batches, send them off to the first two GPUs, and run forward/backward. At this point, only 2 GPUs have non-zero gradients. If I do a Trainer.step(), how does it reduce gradients on all 4 GPUs?
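For concreteness, here is a minimal sketch of the multi-GPU loop described above (toy model and data; shown with `split_and_load` on one global batch rather than four separately sampled batches, so the real code will differ):

```python
import mxnet as mx
from mxnet import autograd, gluon

# Use all available GPUs, falling back to CPU so the sketch still runs.
ctxs = [mx.gpu(i) for i in range(mx.context.num_gpus())] or [mx.cpu()]

# Toy model, loss, and data standing in for the real ones.
net = gluon.nn.Dense(10)
net.initialize(ctx=ctxs)
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

dataset = gluon.data.ArrayDataset(mx.nd.random.uniform(shape=(1000, 20)),
                                  mx.nd.random.randint(0, 10, shape=(1000,)))
loader = gluon.data.DataLoader(dataset, batch_size=32 * len(ctxs))

for data, label in loader:
    # Shard the batch so each context gets one slice;
    # even_split=False lets the last, smaller batch split unevenly.
    data_parts = gluon.utils.split_and_load(data, ctxs, even_split=False)
    label_parts = gluon.utils.split_and_load(label, ctxs, even_split=False)
    with autograd.record():
        losses = [loss_fn(net(x), y) for x, y in zip(data_parts, label_parts)]
    for l in losses:
        l.backward()
    # step() reduces the per-device gradients and applies one update.
    trainer.step(data.shape[0])
```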
Replies: 5 comments
-
One way is to let the Gluon DataLoader handle the last batch with `last_batch='discard'` or `last_batch='rollover'`, so that every GPU processes the same number of samples. https://mxnet.apache.org/api/python/docs/api/gluon/data/index.html#mxnet.gluon.data.DataLoader
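For illustration, the relevant `last_batch` option (dataset and sizes here are just placeholders):

```python
import mxnet as mx
from mxnet import gluon

num_gpus = 4  # illustrative
dataset = gluon.data.ArrayDataset(mx.nd.random.uniform(shape=(1002, 20)),
                                  mx.nd.random.randint(0, 10, shape=(1002,)))

# 'discard' drops the trailing incomplete batch; 'rollover' carries the
# leftover samples over to the next epoch. Either way, every yielded batch
# has exactly batch_size samples and can be split evenly across the GPUs.
loader = gluon.data.DataLoader(dataset, batch_size=32 * num_gpus,
                               last_batch='discard')
```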
-
So that only works when the number of samples being sampled is less than batch_size, but I'm talking about the case where the number of batches being sampled is less than the number of GPUs. Also, we don't have an issue handling data imbalance; I'm trying to understand the internals of how MXNet does it. Today, we sample batches, and if the number of sampled batches is less than the number of GPUs, we simply process the batches on those GPUs and do a trainer.step(), which reduces the gradients correctly and updates the params. I would like to understand how MXNet handles this internally in the PS (parameter server) architecture.
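For reference, a minimal sketch of the tail-of-epoch step described here, reusing the `ctxs`, `net`, `loss_fn`, and `trainer` names from the sketch under the question; `tail_batches` is a hypothetical list holding the two leftover per-GPU batches:

```python
from mxnet import autograd

# Only the first len(tail_batches) contexts get work in this step.
active_ctxs = ctxs[:len(tail_batches)]

with autograd.record():
    losses = []
    for (x, y), ctx in zip(tail_batches, active_ctxs):
        x, y = x.as_in_context(ctx), y.as_in_context(ctx)
        losses.append(loss_fn(net(x), y))
for l in losses:
    l.backward()

# Reduce and update; only the active contexts have freshly computed gradients.
trainer.step(sum(x.shape[0] for x, _ in tail_batches))
```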
-
Thanks for posting the question here. If your current mini-batch is small and GPUs 3 & 4 do not receive even one sample, the gradients on GPUs 3 & 4 will remain the same as they were for the previous iteration. Therefore, in this case, the all-reduced gradient will be based on fresh gradients from GPUs 1 & 2 and stale gradients from GPUs 3 & 4.
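Building on the stale-gradient point, one possible workaround (a sketch only, not built-in MXNet behaviour, reusing the `net`, `ctxs`, and `trainer` names from the earlier sketches, with `num_samples_processed` as a placeholder for the sample count of the partial step) is to zero the gradient buffers on the idle GPUs before calling `step()`, so the reduction averages fresh gradients with zeros rather than with last iteration's values:

```python
# Contexts that received no batch in this step (here: the last two GPUs).
idle_ctxs = ctxs[2:]

for param in net.collect_params().values():
    if param.grad_req != 'null':
        for ctx in idle_ctxs:
            # Overwrite the stale gradient buffer with zeros on the idle device.
            param.grad(ctx)[:] = 0

# step() sums gradients across devices and rescales by 1/batch_size, so pass
# the number of samples actually processed in this partial step.
# Trainer.step also exposes ignore_stale_grad for parameters whose gradients
# were not refreshed by backward; whether it is needed depends on the setup.
trainer.step(num_samples_processed)
```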
-
I recently updated the split sampler in gluonnlp so that the number of samples for each worker will always be the same (with `even_size=True`). https://gluon-nlp.mxnet.io/master/api/modules/data.html?highlight=splitsampler#gluonnlp.data.SplitSampler This somewhat avoids the imbalanced data batch problem. If it is useful I can upstream the sampler to mxnet, too.
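A rough usage sketch of the sampler mentioned above (toy dataset; the exact `SplitSampler` signature may differ between gluonnlp versions):

```python
import mxnet as mx
import gluonnlp as nlp
from mxnet import gluon

# Toy dataset standing in for the real one; 1002 samples do not divide
# evenly over 4 workers, which is where even_size=True matters.
dataset = gluon.data.ArrayDataset(mx.nd.random.uniform(shape=(1002, 20)))

num_workers = 4  # e.g. one data shard per GPU (illustrative)
samplers = [nlp.data.SplitSampler(len(dataset), num_parts=num_workers,
                                  part_index=i, even_size=True)
            for i in range(num_workers)]

# One DataLoader per worker; each yields the same number of batches per
# epoch, so no device is left without data at the end of an epoch.
loaders = [gluon.data.DataLoader(dataset, batch_size=32, sampler=s)
           for s in samplers]
```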