OneBitAdam + Pipeline Parallel fixes (#972)
* Pipe Engine Changes

* OneBitAdam changes

* NCCL backend changes
sdtblck authored Apr 17, 2021
1 parent 691b6ff commit 642dd5a
Showing 3 changed files with 19 additions and 8 deletions.
16 changes: 10 additions & 6 deletions deepspeed/runtime/comm/nccl.py
@@ -12,8 +12,12 @@


 class NcclBackend(object):
-    def __init__(self):
-        self.world_group = dist.new_group(ranks=range(dist.get_world_size()))
+    def __init__(self, mpu=None):
+        if mpu is None:
+            self.world_group = dist.new_group(ranks=range(dist.get_world_size()))
+        else:
+            self.mpu = mpu
+            self.world_group = self.mpu.get_data_parallel_group()
         self.rank = dist.get_rank(group=self.world_group)
         self.size = dist.get_world_size(group=self.world_group)
         self.compression_backend = CupyBackend()
@@ -92,9 +96,9 @@ def compressed_allreduce(self,
         # communication phase 1
         # gather_start = time.time()
         # Alltoall for sign
-        dist.all_to_all_single(recvbuf_sign, torch.stack(sign_list_packed))
+        dist.all_to_all_single(recvbuf_sign, torch.stack(sign_list_packed), group=self.world_group)
         # Allgather for scale
-        dist.all_gather(recvbuf_scale, worker_scale)
+        dist.all_gather(recvbuf_scale, worker_scale, group=self.world_group)

         # gather_end = time.time()

@@ -151,8 +155,8 @@ def compressed_allreduce(self,
         ]

         # Communication Phase 2
-        dist.all_gather(recvbuf_sign_server, server_sign_packed[0])
-        dist.all_gather(recvbuf_scale_server, server_scale)
+        dist.all_gather(recvbuf_sign_server, server_sign_packed[0], group=self.world_group)
+        dist.all_gather(recvbuf_scale_server, server_scale, group=self.world_group)

         cupy_server_sign_packed = None

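Note on the nccl.py change: NcclBackend can now be handed a model-parallel unit (mpu). When one is supplied, the compressed allreduce runs over the mpu's data-parallel group instead of a group spanning every rank, and each collective inside compressed_allreduce is pinned to that group via group=self.world_group. Below is a minimal, DeepSpeed-free sketch of that torch.distributed pattern; the 2-way pipeline layout, tensor shapes, and script name are illustrative assumptions, not part of this commit.

# dp_group_demo.py -- sketch only, not DeepSpeed code.
# Launch with: torchrun --nproc_per_node=4 dp_group_demo.py
# (assumes the world size is divisible by the pipeline size)
import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")  # gloo keeps the demo CPU-only
rank = dist.get_rank()
world_size = dist.get_world_size()

# Hypothetical layout: 2 pipeline stages; ranks holding the same stage form one
# data-parallel group. Every rank must create every group, in the same order.
pipe_parallel_size = 2
dp_group = None
for stage in range(pipe_parallel_size):
    ranks = list(range(stage, world_size, pipe_parallel_size))
    group = dist.new_group(ranks=ranks)
    if rank in ranks:
        dp_group = group

# Without `group=`, this all_gather would span every rank and mix pipeline
# stages; with it, only this stage's data-parallel replicas participate.
scale = torch.tensor([float(rank)])
gathered = [torch.zeros(1) for _ in range(world_size // pipe_parallel_size)]
dist.all_gather(gathered, scale, group=dp_group)
print(f"rank {rank}: gathered {[t.item() for t in gathered]}")

Launched with four processes, each rank prints only the scales gathered from its own data-parallel peers.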
5 changes: 4 additions & 1 deletion deepspeed/runtime/fp16/onebit/adam.py
@@ -94,7 +94,7 @@ def __init__(self,
             assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 8, "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend"
             assert dist.is_initialized() == True, "Please initialize the torch distributed backend."
             from deepspeed.runtime.comm.nccl import NcclBackend
-            self.comm_backend_handle = NcclBackend()
+            self.comm_backend_handle = NcclBackend(self.deepspeed.mpu)

         elif self.comm_backend_name == 'mpi':
             from deepspeed.runtime.comm.mpi import MpiBackend
@@ -254,8 +254,10 @@ def step(self, closure=None, grads=None):

         if self.adam_freeze_key is False:
             if state['step'] >= self.freeze_step:
+                print('OneBitAdam - starting compressed communication')
                 self.adam_freeze_key = True
                 self.deepspeed.enable_backward_allreduce = False
+                self.deepspeed.pipeline_enable_backward_allreduce = False

         return loss

@@ -281,6 +283,7 @@ def load_state_dict(self, state_dict):
             if self.adam_freeze_key is True:
                 self.adam_freeze_key = False
                 self.deepspeed.enable_backward_allreduce = True
+                self.deepspeed.pipeline_enable_backward_allreduce = True
         else:
             if torch.distributed.get_rank() == 0:
                 print(
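Note on the adam.py change: the optimizer now forwards the engine's mpu to NcclBackend and, once freeze_step is reached, also turns off the new pipeline_enable_backward_allreduce flag so the pipeline engine stops reducing gradients itself; load_state_dict restores the flag if a checkpoint drops the optimizer back into the warmup stage. The NCCL path is selected through the optimizer section of the DeepSpeed config; the sketch below shows one such configuration as a plain Python dict. The specific values (batch size, learning rate, freeze_step) are placeholders, not values from this commit.

# Illustrative 1-bit Adam configuration expressed as a Python dict (placeholder
# values). "comm_backend_name": "nccl" selects the NcclBackend patched above,
# and "freeze_step" controls when step() switches to compressed communication.
ds_config = {
    "train_batch_size": 16,
    "fp16": {"enabled": True},
    "optimizer": {
        "type": "OneBitAdam",
        "params": {
            "lr": 1e-4,
            "freeze_step": 400,           # warmup steps before compression kicks in
            "comm_backend_name": "nccl",  # requires torch >= 1.8, per the assert above
            "cuda_aware": False,          # only consulted by the MPI backend
        },
    },
}

print(ds_config["optimizer"]["params"]["comm_backend_name"])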
6 changes: 5 additions & 1 deletion deepspeed/runtime/pipe/engine.py
@@ -54,6 +54,10 @@ def __init__(self, *super_args, **super_kwargs):

         # We schedule the all-reduces, so disable it in super().backward()
         self.enable_backward_allreduce = False
+
+        # used to disable the pipeline all-reduce when used with 1-bit adam
+        self.pipeline_enable_backward_allreduce = True
+
         assert not self.elasticity_enabled(), "Elasticity is not currently supported" \
             " with pipeline parallelism."

@@ -220,7 +224,7 @@ def _exec_reduce_tied_grads(self):

     def _exec_reduce_grads(self):
         self._force_grad_boundary = True
-        if self.is_data_parallel:
+        if self.is_data_parallel and self.pipeline_enable_backward_allreduce:
             self.buffered_allreduce_fallback(
                 elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE)
         self._force_grad_boundary = False
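Note on the engine.py change: the pipeline engine initializes pipeline_enable_backward_allreduce to True, and _exec_reduce_grads now checks it alongside is_data_parallel, so when 1-bit Adam flips the flag off at freeze_step the engine skips buffered_allreduce_fallback and gradient reduction happens only through the optimizer's compressed path. The toy classes below (not the real DeepSpeed classes) sketch that handoff; the names and the freeze_step value are illustrative.

# Toy sketch of the warmup -> compression handoff implemented by this commit:
# after `freeze_step` optimizer steps, the optimizer disables the engine's own
# allreduce paths and takes over gradient communication itself.
class ToyEngine:
    def __init__(self):
        self.enable_backward_allreduce = False          # pipeline engine schedules its own
        self.pipeline_enable_backward_allreduce = True  # new flag gating _exec_reduce_grads

class ToyOneBitAdam:
    def __init__(self, engine, freeze_step=3):
        self.engine = engine
        self.freeze_step = freeze_step
        self.step_count = 0
        self.freeze_key = False  # False = warmup (uncompressed), True = compressed

    def step(self):
        self.step_count += 1
        if not self.freeze_key and self.step_count >= self.freeze_step:
            print("starting compressed communication")
            self.freeze_key = True
            self.engine.enable_backward_allreduce = False
            self.engine.pipeline_enable_backward_allreduce = False

engine = ToyEngine()
opt = ToyOneBitAdam(engine)
for _ in range(4):
    opt.step()
print(engine.pipeline_enable_backward_allreduce)  # False once freeze_step is reached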
