1-bit LAMB optimizer #970

Merged
merged 26 commits into from Apr 21, 2021

Changes from 18 commits
Commits (26)
755860a
initial commit
conglongli Mar 22, 2021
cb3760c
make build test work
conglongli Mar 22, 2021
0ec4397
unit test
conglongli Mar 22, 2021
7141b15
put scaling coeff inside state
conglongli Mar 28, 2021
830bab2
add unit test
conglongli Mar 28, 2021
e2b0987
add unit test
conglongli Mar 28, 2021
a1fa974
check DeepSpeedEngine.enable_backward_allreduce at CheckOverflow.has_…
conglongli Mar 30, 2021
5daacfd
doc
conglongli Apr 5, 2021
4f9cbe4
resolve merge conflict
conglongli Apr 17, 2021
99e102d
doc
conglongli Apr 17, 2021
691b6ff
var name change
conglongli Apr 17, 2021
642dd5a
Onebitadam + Pipeline Parallel fixes (#972)
sdtblck Apr 17, 2021
4831dca
polish pipeline parallel support
conglongli Apr 17, 2021
c0f176d
pipeline support unit test
conglongli Apr 18, 2021
e04e539
Merge branch 'master' into staging-1bit-lamb
conglongli Apr 19, 2021
ce4c7ea
Merge branch 'master' into staging-1bit-lamb
conglongli Apr 19, 2021
0a2b165
doc fix
conglongli Apr 19, 2021
c470e07
Merge branch 'master' into staging-1bit-lamb
conglongli Apr 19, 2021
6953e94
Merge branch 'master' into staging-1bit-lamb
jeffra Apr 20, 2021
4689223
handle pipeline case correctly
conglongli Apr 20, 2021
ba73f26
Merge branch 'staging-1bit-lamb' of github.com:microsoft/DeepSpeed in…
conglongli Apr 20, 2021
8f5adfe
Merge branch 'master' into staging-1bit-lamb
jeffra Apr 20, 2021
d091006
different way to check pipeline
conglongli Apr 21, 2021
6c64750
Merge branch 'staging-1bit-lamb' of github.com:microsoft/DeepSpeed in…
conglongli Apr 21, 2021
98163e1
add doc
conglongli Apr 21, 2021
08fbc5a
Merge branch 'master' into staging-1bit-lamb
conglongli Apr 21, 2021
7 changes: 4 additions & 3 deletions README.md
@@ -17,7 +17,7 @@ DeepSpeed delivers extreme-scale model training for everyone, from data scientis
* Extreme scale: Using the current generation of GPU clusters with hundreds of devices, 3D parallelism of DeepSpeed can efficiently train deep learning models with trillions of parameters.
* Extremely memory efficient: With just a single GPU, ZeRO-Offload of DeepSpeed can train models with over 10B parameters, 10x bigger than the state of the art, democratizing multi-billion-parameter model training such that many deep learning scientists can explore bigger and better models.
* Extremely long sequence length: Sparse attention of DeepSpeed powers an order-of-magnitude longer input sequence and obtains up to 6x faster execution compared with dense transformers.
* Extremely communication efficient: 3D parallelism improves communication efficiency and allows users to train multi-billion-parameter models 2–7x faster on clusters with limited network bandwidth. 1-bit Adam reduces communication volume by up to 5x while achieving similar convergence efficiency to Adam, allowing for scaling to different types of GPU clusters and networks.
* Extremely communication efficient: 3D parallelism improves communication efficiency and allows users to train multi-billion-parameter models 2–7x faster on clusters with limited network bandwidth. 1-bit Adam/1-bit LAMB reduce communication volume by up to 5x while achieving similar convergence efficiency to Adam/LAMB, allowing for scaling to different types of GPU clusters and networks.

Early adopters of DeepSpeed have already produced
a language model (LM) with over 17B parameters called
@@ -33,6 +33,7 @@ information [here](https://innovation.microsoft.com/en-us/exploring-ai-at-scale)


# News
* [2021/04/20] [1-bit LAMB: up to 4.6x less communication and 2.8x faster training, together with LAMB's convergence speed at large batch sizes](https://www.deepspeed.ai/tutorials/onebit-lamb/)
* [2021/04/19] [ZeRO-Infinity unlocks unprecedented model scale for deep learning training](https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/)
* [Tutorial on how to use different stages of ZeRO](https://www.deepspeed.ai/tutorials/zero/)
* [2021/04/01] [[DeepSpeed on AzureML] Transformers and CIFAR examples are now available on AzureML GitHub](https://github.com/Azure/azureml-examples/tree/main/workflows/train/deepspeed)
@@ -119,7 +120,7 @@ overview](https://www.deepspeed.ai/features/) for descriptions and usage.
* Memory- and compute-efficient sparse kernels
* Support 10x longer sequences than dense
* Flexible support to different sparse structures
* [1-bit Adam](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-blog-post.html)
* [1-bit Adam](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-blog-post.html) and [1-bit LAMB](https://www.deepspeed.ai/tutorials/onebit-lamb/)
* Custom communication collective
* Up to 5x communication volume saving
* [Additional Memory and Bandwidth Optimizations](https://www.deepspeed.ai/features/#additional-memory-and-bandwidth-optimizations)
@@ -192,7 +193,7 @@ Conduct](https://opensource.microsoft.com/codeofconduct/). For more information
4. Jie Ren, Samyam Rajbhandari, Reza Yazdani Aminabadi, Olatunji Ruwase, Shuangyan Yang, Minjia Zhang, Dong Li, Yuxiong He. (2021) ZeRO-Offload: Democratizing Billion-Scale Model Training. [arXiv:2101.06840](https://arxiv.org/abs/2101.06840).
5. Hanlin Tang, Shaoduo Gan, Ammar Ahmad Awan, Samyam Rajbhandari, Conglong Li, Xiangru Lian, Ji Liu, Ce Zhang, Yuxiong He. (2021) 1-bit Adam: Communication Efficient Large-Scale Training with Adam's Convergence Speed. [arXiv:2102.02888](https://arxiv.org/abs/2102.02888).
6. Samyam Rajbhandari, Olatunji Ruwase, Jeff Rasley, Shaden Smith, Yuxiong He. (2021) ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning. [arXiv:2104.07857](https://arxiv.org/abs/2104.07857).

7. Conglong Li, Ammar Ahmad Awan, Hanlin Tang, Samyam Rajbhandari, Yuxiong He. (2021) 1-bit LAMB: Communication Efficient Large-Scale Large-Batch Training with LAMB's Convergence Speed. [arXiv:2104.06069](https://arxiv.org/abs/2104.06069).

# Videos
1. DeepSpeed KDD 2020 Tutorial
20 changes: 14 additions & 6 deletions deepspeed/runtime/comm/nccl.py
@@ -12,8 +12,12 @@


class NcclBackend(object):
def __init__(self):
self.world_group = dist.new_group(ranks=range(dist.get_world_size()))
def __init__(self, mpu=None):
if mpu is None:
self.world_group = dist.new_group(ranks=range(dist.get_world_size()))
else:
self.mpu = mpu
self.world_group = self.mpu.get_data_parallel_group()
self.rank = dist.get_rank(group=self.world_group)
self.size = dist.get_world_size(group=self.world_group)
self.compression_backend = CupyBackend()
@@ -92,9 +96,11 @@ def compressed_allreduce(self,
# communication phase 1
# gather_start = time.time()
# Alltoall for sign
dist.all_to_all_single(recvbuf_sign, torch.stack(sign_list_packed))
dist.all_to_all_single(recvbuf_sign,
torch.stack(sign_list_packed),
group=self.world_group)
# Allgather for scale
dist.all_gather(recvbuf_scale, worker_scale)
dist.all_gather(recvbuf_scale, worker_scale, group=self.world_group)

# gather_end = time.time()

@@ -151,8 +157,10 @@ def compressed_allreduce(self,
]

# Communication Phase 2
dist.all_gather(recvbuf_sign_server, server_sign_packed[0])
dist.all_gather(recvbuf_scale_server, server_scale)
dist.all_gather(recvbuf_sign_server,
server_sign_packed[0],
group=self.world_group)
dist.all_gather(recvbuf_scale_server, server_scale, group=self.world_group)

cupy_server_sign_packed = None

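The constructor change above is what makes the compressed allreduce compatible with model and pipeline parallelism: every collective has to run over the data-parallel group instead of the default world group. A minimal sketch of the selection logic, assuming torch.distributed is already initialized (the helper name is hypothetical; an `mpu` only needs to expose `get_data_parallel_group()`):

```python
import torch.distributed as dist

def select_world_group(mpu=None):
    # No model parallelism: the 1-bit allreduce spans every rank.
    if mpu is None:
        return dist.new_group(ranks=range(dist.get_world_size()))
    # With an mpu, restrict collectives to the data-parallel group so that
    # ranks holding different model partitions never mix their gradients.
    return mpu.get_data_parallel_group()

# Every collective inside compressed_allreduce then names the group explicitly,
# e.g. dist.all_gather(recvbuf_scale, worker_scale, group=world_group).
```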
2 changes: 2 additions & 0 deletions deepspeed/runtime/config.py
@@ -32,11 +32,13 @@
ADAMW_OPTIMIZER = 'adamw'
LAMB_OPTIMIZER = 'lamb'
ONEBIT_ADAM_OPTIMIZER = 'onebitadam'
ONEBIT_LAMB_OPTIMIZER = 'onebitlamb'
DEEPSPEED_OPTIMIZERS = [
ADAM_OPTIMIZER,
ADAMW_OPTIMIZER,
LAMB_OPTIMIZER,
ONEBIT_ADAM_OPTIMIZER,
ONEBIT_LAMB_OPTIMIZER,
]

# extra optimizer parameters for adam/adamw
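With 'onebitlamb' added to DEEPSPEED_OPTIMIZERS, the optimizer becomes selectable from a DeepSpeed config. Below is a hedged sketch of such a config as a Python dict (the same structure can live in a ds_config.json file); only the type string comes from this PR, while the parameter names mirror the 1-bit Adam tutorial and are assumptions, so consult the 1-bit LAMB tutorial for the authoritative list:

```python
# Sketch of a DeepSpeed config selecting the newly registered optimizer.
# Only "type": "onebitlamb" is taken from this PR; the params below are
# assumed from the 1-bit Adam pattern.
ds_config = {
    "train_batch_size": 4096,
    "fp16": {"enabled": True},  # engine warns that convergence is only verified under FP16
    "optimizer": {
        "type": "onebitlamb",
        "params": {
            "lr": 1e-3,
            "freeze_step": 1000,           # warmup steps before compression kicks in (assumed)
            "comm_backend_name": "nccl"    # NCCL path requires torch >= 1.8 (see onebit/adam.py)
        }
    }
}
```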
15 changes: 13 additions & 2 deletions deepspeed/runtime/engine.py
@@ -24,7 +24,7 @@
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.runtime.config import DeepSpeedConfig, DEEPSPEED_OPTIMIZERS, \
ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, \
ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \
TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT

from deepspeed.runtime.dataloader import DeepSpeedDataLoader
@@ -553,7 +553,8 @@ def _do_sanity_check(self):
assert self._is_supported_optimizer(self.optimizer_name()), \
'{} is not a supported DeepSpeed Optimizer'.format(self.optimizer_name())

if self.optimizer_name() == LAMB_OPTIMIZER:
if self.optimizer_name() == LAMB_OPTIMIZER or self.optimizer_name(
) == ONEBIT_LAMB_OPTIMIZER:
assert self.dynamic_loss_scale(), \
'DeepSpeed {} optimizer requires dynamic loss scaling'.format(self.optimizer_name())

@@ -694,6 +695,13 @@ def _configure_basic_optimizer(self, model_parameters):
logger.warning(
f'Currently the convergence of 1-bit Adam is only verified under FP16'
)
elif self.optimizer_name() == ONEBIT_LAMB_OPTIMIZER:
from deepspeed.runtime.fp16.onebit.lamb import OnebitLamb
optimizer = OnebitLamb(model_parameters, self, **optimizer_parameters)
if not self.fp16_enabled():
logger.warning(
f'Currently the convergence of 1-bit Lamb is only verified under FP16'
)
else:
torch_optimizer = getattr(torch.optim, self.optimizer_name())
optimizer = torch_optimizer(model_parameters, **optimizer_parameters)
@@ -710,6 +718,7 @@ def _configure_fp16_optimizer(self, optimizer):
timers = self.timers if self.wall_clock_breakdown() else None
optimizer = FP16_Optimizer(
optimizer,
deepspeed=self,
dynamic_loss_scale=True,
initial_dynamic_scale=initial_dynamic_scale,
dynamic_loss_args=dynamic_loss_args,
@@ -723,6 +732,7 @@ def _configure_fp16_optimizer(self, optimizer):
ranks=[0])
optimizer = FP16_Optimizer(
optimizer,
deepspeed=self,
static_loss_scale=self.loss_scale(),
mpu=self.mpu,
clip_grad=clip_grad,
@@ -732,6 +742,7 @@ def _configure_fp16_optimizer(self, optimizer):
ranks=[0])
optimizer = FP16_UnfusedOptimizer(
optimizer,
deepspeed=self,
static_loss_scale=self.loss_scale(),
dynamic_loss_scale=self.dynamic_loss_scale(),
dynamic_loss_args=dynamic_loss_args,
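For completeness, here is the user-facing path that exercises the branch above: deepspeed.initialize parses the optimizer block, _configure_basic_optimizer constructs OnebitLamb with the engine handle, and the FP16 wrappers now also receive deepspeed=self. A hedged usage sketch, reusing the ds_config dict from the config.py sketch (the programmatic config_params path is one option; the tutorials usually pass a JSON file through argparse instead, and the exact initialize signature may differ across DeepSpeed versions):

```python
import torch
import deepspeed

# Toy model purely for illustration; any torch.nn.Module works.
model = torch.nn.Linear(1024, 1024)

# ds_config is the dict sketched after config.py above ("type": "onebitlamb").
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config_params=ds_config,
)
```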
5 changes: 4 additions & 1 deletion deepspeed/runtime/fp16/fused_optimizer.py
@@ -22,6 +22,7 @@ class FP16_Optimizer(object):
"""
def __init__(self,
init_optimizer,
deepspeed=None,
static_loss_scale=1.0,
dynamic_loss_scale=False,
initial_dynamic_scale=2**32,
@@ -100,7 +101,9 @@ def __init__(self,
self.mpu = mpu

self.overflow = False
self.overflow_checker = CheckOverflow(self.fp16_groups, mpu=self.mpu)
self.overflow_checker = CheckOverflow(self.fp16_groups,
mpu=self.mpu,
deepspeed=deepspeed)
self.initialize_optimizer_states()

def initialize_optimizer_states(self):
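Passing the engine handle (deepspeed=deepspeed) into CheckOverflow matters once a 1-bit optimizer switches off the engine's backward allreduce: gradients are then no longer averaged across ranks before the overflow check, so a local NaN/Inf on one rank would go unnoticed elsewhere. A sketch of the idea, with hypothetical names and a simplified signature (the real CheckOverflow also handles the model-parallel and pipeline cases):

```python
import torch
import torch.distributed as dist

def has_overflow(local_overflow: bool, deepspeed=None, group=None) -> bool:
    # When the engine still averages gradients in backward, an Inf/NaN produced
    # on any rank propagates to all ranks, so the local check is sufficient.
    # When a 1-bit optimizer has disabled that allreduce (compression stage),
    # the overflow flag itself must be combined across the data-parallel group.
    flag = torch.tensor([1.0 if local_overflow else 0.0], device='cuda')
    if deepspeed is not None and not deepspeed.enable_backward_allreduce:
        dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=group)
    return bool(flag.item())
```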
10 changes: 7 additions & 3 deletions deepspeed/runtime/fp16/onebit/adam.py
@@ -94,7 +94,7 @@ def __init__(self,
assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 8, "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend"
assert dist.is_initialized() == True, "Please initialize the torch distributed backend."
from deepspeed.runtime.comm.nccl import NcclBackend
self.comm_backend_handle = NcclBackend()
self.comm_backend_handle = NcclBackend(self.deepspeed.mpu)

elif self.comm_backend_name == 'mpi':
from deepspeed.runtime.comm.mpi import MpiBackend
@@ -254,8 +254,10 @@ def step(self, closure=None, grads=None):

if self.adam_freeze_key is False:
if state['step'] >= self.freeze_step:
print('OnebitAdam - starting compressed communication')
self.adam_freeze_key = True
self.deepspeed.enable_backward_allreduce = False
self.deepspeed.pipeline_enable_backward_allreduce = False

return loss

@@ -277,18 +279,20 @@ def load_state_dict(self, state_dict):
super().load_state_dict(state_dict)
if self.state[self.param_groups[0]['params'][0]]['step'] < self.freeze_step:
if torch.distributed.get_rank() == 0:
print("Checkpoint loaded and 1-bit Adam warmup stage starts/continues.")
print("Checkpoint loaded and OnebitAdam warmup stage starts/continues.")
if self.adam_freeze_key is True:
self.adam_freeze_key = False
self.deepspeed.enable_backward_allreduce = True
self.deepspeed.pipeline_enable_backward_allreduce = True
else:
if torch.distributed.get_rank() == 0:
print(
"Checkpoint loaded and 1-bit Adam compression stage starts/continues."
"Checkpoint loaded and OnebitAdam compression stage starts/continues."
)
if self.adam_freeze_key is False:
self.adam_freeze_key = True
self.deepspeed.enable_backward_allreduce = False
self.deepspeed.pipeline_enable_backward_allreduce = False
# We reset the compression errors when loading checkpoints for 3 reasons:
# 1) The worker and server error at each GPU are distinct, so in current implementation
# only rank 0's errors are saved in the checkpoint. Thus we have to reset the errors.
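The step() and load_state_dict() changes above implement a small state machine: warmup (uncompressed, engine-driven allreduce) until freeze_step, then compression (optimizer-driven 1-bit allreduce), with checkpoint loading re-deriving the stage from the step count. A condensed sketch of that hand-off (the class and method names are illustrative, not DeepSpeed API; only the two engine flag names come from this PR):

```python
# Illustrative sketch of the warmup -> compression hand-off shared by the 1-bit
# optimizers. `engine` stands in for the DeepSpeedEngine handle the optimizer
# receives at construction time.
class CompressionSchedule:
    def __init__(self, engine, freeze_step):
        self.engine = engine
        self.freeze_step = freeze_step
        self.compressing = False

    def maybe_start_compression(self, step):
        # Once the frozen statistics are in place, the optimizer's
        # error-compensated 1-bit allreduce replaces both the regular and the
        # pipeline-engine gradient averaging paths.
        if not self.compressing and step >= self.freeze_step:
            self.compressing = True
            self.engine.enable_backward_allreduce = False
            self.engine.pipeline_enable_backward_allreduce = False

    def resume_from_checkpoint(self, step):
        # Re-derive the stage after loading a checkpoint and hand gradient
        # averaging back to the engine if still in warmup. The real optimizer
        # also resets the worker/server compression errors at this point.
        in_warmup = step < self.freeze_step
        self.compressing = not in_warmup
        self.engine.enable_backward_allreduce = in_warmup
        self.engine.pipeline_enable_backward_allreduce = in_warmup
```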