1-bit Adam v2 #817

Merged Mar 16, 2021 · 55 commits
Changes from 49 commits

Commits
a6dba72  NCCL-based 1-bit Adam + Code Refactor for Comm. Backends (#594) (awan-10, Dec 14, 2020)
7840085  Merge branch 'master' into staging-1bit-nccl-v2 (conglongli, Dec 23, 2020)
6dbdd98  Revert "Merge branch 'master' into staging-1bit-nccl-v2" (conglongli, Dec 23, 2020)
9712f10  Revert "Revert "Merge branch 'master' into staging-1bit-nccl-v2"" (conglongli, Dec 28, 2020)
4d1f4f0  Merge branch 'master' into staging-1bit-nccl-v2 (conglongli, Dec 28, 2020)
d8a23c9  comm optimization + 1-bit lamb (conglongli, Dec 30, 2020)
89e1936  Saving/debugging commit. (awan-10, Feb 5, 2021)
a1bbf78  finalizing 1-bit lamb (conglongli, Feb 14, 2021)
db0ca76  finalizing 1-bit lamb (conglongli, Feb 16, 2021)
07deab8  add momentum mask and chkpt handling for 1-bit adam (conglongli, Feb 16, 2021)
625f475  Merge remote-tracking branch 'origin/staging-1bit-nccl-v2' into stagi… (awan-10, Feb 19, 2021)
d55fddb  Cleanup and modify nccl test to be runnable with deepspeed launcher. (awan-10, Feb 19, 2021)
5b1cacb  Merge branch 'master' into staging-1bit-nccl-v2 (awan-10, Feb 22, 2021)
8cbc212  Fix format. (awan-10, Feb 22, 2021)
ff8c871  fix formatting again. (awan-10, Feb 22, 2021)
c17041f  make test runnable without mpi4py (awan-10, Feb 22, 2021)
5e01a30  Add dist.alltoall and dist.allgather instead of custom functions. (awan-10, Feb 22, 2021)
97a5557  remove debug prints. (awan-10, Feb 22, 2021)
e3e1e39  formatting and renaming (conglongli, Mar 1, 2021)
d5b9dcc  renaming (conglongli, Mar 1, 2021)
3d66a8a  renaming (conglongli, Mar 1, 2021)
b042467  add unit test, fix existing tests (conglongli, Mar 2, 2021)
ab3521d  Merge branch 'master' into staging-1bit-adam-v2 (awan-10, Mar 2, 2021)
9fa5166  skip unit test when torch < 1.8 (conglongli, Mar 3, 2021)
65d7ec5  revert 1-bit lamb (conglongli, Mar 3, 2021)
8376a40  flatten momentum when dimension is more than 1 (conglongli, Mar 3, 2021)
6a19f29  add warning message for 1-bit adam under fp32 (conglongli, Mar 3, 2021)
819043d  improve version check (conglongli, Mar 4, 2021)
a6943be  add fp32 test (conglongli, Mar 4, 2021)
2042b29  Merge remote-tracking branch 'origin' into staging-1bit-adam-v2 (conglongli, Mar 4, 2021)
66a8c93  1-bit adam doc (conglongli, Mar 4, 2021)
fb329a9  fix file name (conglongli, Mar 4, 2021)
0b3c1d7  doc fix (conglongli, Mar 4, 2021)
0bffa9b  torch 1.8 is released (conglongli, Mar 5, 2021)
294c2d6  doc fix (conglongli, Mar 5, 2021)
003981a  fix tests (conglongli, Mar 8, 2021)
3f42b3a  Merge branch 'master' into staging-1bit-adam-v2 (conglongli, Mar 8, 2021)
bbd6143  Merge branch 'master' into staging-1bit-adam-v2 (conglongli, Mar 8, 2021)
877f8d7  update news (conglongli, Mar 8, 2021)
f861465  Merge branch 'master' into staging-1bit-adam-v2 (conglongli, Mar 8, 2021)
2ed029e  add doc for momentum mask (conglongli, Mar 9, 2021)
c6e7cf7  Merge branch 'staging-1bit-adam-v2' of github.com:microsoft/DeepSpeed… (conglongli, Mar 9, 2021)
3b53c90  fix checkpoing handling, add unit test (conglongli, Mar 12, 2021)
4240729  checkpoint handling doc (conglongli, Mar 12, 2021)
968a53f  doc final cleanup (conglongli, Mar 12, 2021)
bbec300  Merge branch 'master' into staging-1bit-adam-v2 (conglongli, Mar 12, 2021)
e28a99e  Merge branch 'master' into staging-1bit-adam-v2 (jeffra, Mar 16, 2021)
1221aec  bump dates (conglongli, Mar 16, 2021)
535b5ba  Merge branch 'staging-1bit-adam-v2' of github.com:microsoft/DeepSpeed… (conglongli, Mar 16, 2021)
8cfd2b7  update tests (conglongli, Mar 16, 2021)
38ff08a  url change (conglongli, Mar 16, 2021)
de03656  doc fix (conglongli, Mar 16, 2021)
5957bce  fix test (conglongli, Mar 16, 2021)
ef51ac6  doc update (conglongli, Mar 16, 2021)
7c08b34  Merge branch 'master' into staging-1bit-adam-v2 (conglongli, Mar 16, 2021)
1 change: 1 addition & 0 deletions README.md
@@ -31,6 +31,7 @@ information [here](https://innovation.microsoft.com/en-us/exploring-ai-at-scale)


# News
+* [2021/03/16] [1-bit Adam v2: NCCL-based implementation and more](https://www.deepspeed.ai/tutorials/onebit-adam/)
* [2021/03/08] [ZeRO-3 Offload: Scale your models to trillion parameters without code changes while leveraging both CPUs & GPUs](https://www.deepspeed.ai/news/2021/03/07/zero3-offload.html)
* [2020/11/12] [Simplified install, JIT compiled ops, PyPI releases, and reduced dependencies](#installation)
* [2020/11/10] [Efficient and robust compressed training through progressive layer dropping](https://www.deepspeed.ai/news/2020/10/28/progressive-layer-dropping-news.html)
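The tutorial linked in the new entry above turns the optimizer on through the DeepSpeed config. Below is a minimal sketch, not part of this diff: the field names (`OneBitAdam`, `freeze_step`, `cuda_aware`, `comm_backend_name`) follow the 1-bit Adam tutorial, while every value shown is illustrative rather than tuned.

```python
# Hypothetical DeepSpeed config sketch for the new optimizer; values are
# illustrative only. Pass this dict (or the equivalent JSON file) to
# deepspeed.initialize().
ds_config = {
    "train_batch_size": 32,
    "fp16": {"enabled": True},  # 1-bit Adam's compression targets fp16 training
    "optimizer": {
        "type": "OneBitAdam",
        "params": {
            "lr": 4e-4,
            "freeze_step": 23000,         # warmup steps before compression starts
            "cuda_aware": False,          # True only with a CUDA-aware MPI build
            "comm_backend_name": "nccl",  # "nccl" (needs torch >= 1.8) or "mpi"
        },
    },
}
```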
Empty file.
290 changes: 290 additions & 0 deletions deepspeed/runtime/comm/mpi.py
@@ -0,0 +1,290 @@
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''

import torch
import cupy
import time
import numpy as np
from mpi4py import MPI

from deepspeed.runtime.compression.cupy import CupyBackend


class MpiBackend(object):
    def __init__(self, cuda_aware):
        self.comm = MPI.COMM_WORLD
        self.rank = self.comm.Get_rank()
        self.size = self.comm.Get_size()
        self.cuda_aware = cuda_aware
        self.compression_backend = CupyBackend()

    def my_igather(self, rank, size, comm, sendbuf, recbuf, root):
        # Nonblocking gather to `root`: the root posts an Irecv for every other
        # rank and copies its own contribution in place; non-roots post a
        # single Isend. The caller waits on the returned request list.
        req = []
        if rank == root:
            for idx in range(size):
                if idx != rank:
                    req.append(comm.Irecv(recbuf[idx], source=idx))
                else:
                    recbuf[rank] = sendbuf
        else:
            req.append(comm.Isend(sendbuf, dest=root))
        return req

    def gather_cuda(self,
                    rank,
                    world_size,
                    comm,
                    cupy_sign_list_packed,
                    cupy_recvbuf_sign,
                    cupy_worker_scale,
                    cupy_recvbuf_scale):
        # We do in-place operations on cupy buffers so we do not return any buffers
        requests = []
        for idx in range(world_size):
            req_sign = self.my_igather(rank,
                                       world_size,
                                       comm,
                                       cupy_sign_list_packed[idx],
                                       cupy_recvbuf_sign,
                                       root=idx)
            requests += req_sign

        for idx in range(world_size):
            req_scale = self.my_igather(rank,
                                        world_size,
                                        comm,
                                        cupy_worker_scale,
                                        cupy_recvbuf_scale,
                                        root=idx)
            requests += req_scale

        MPI.Request.Waitall(requests)

    def gather_host(self,
                    rank,
                    world_size,
                    comm,
                    cupy_sign_list_packed,
                    cupy_recvbuf_sign,
                    cupy_worker_scale,
                    cupy_recvbuf_scale):

        # In-place operations are not possible for newly created cupy arrays
        # so we need to return the new buffers
        numpy_recvbuf_sign = np.zeros([world_size,
                                       cupy_sign_list_packed[rank].size],
                                      dtype=cupy_sign_list_packed[0].dtype)
        numpy_recvbuf_scale = np.zeros([world_size, 1], dtype=cupy_worker_scale.dtype)

        # 1. convert from cupy to numpy
        numpy_sign_list_packed = cupy_sign_list_packed

        for idx in range(world_size):
            numpy_sign_list_packed[idx] = cupy.asnumpy(cupy_sign_list_packed[idx])

        numpy_worker_scale = cupy.asnumpy(cupy_worker_scale)
        numpy_recvbuf_scale = cupy.asnumpy(cupy_recvbuf_scale)

        cupy.cuda.get_current_stream().synchronize()

        # 2. use numpy buffers for communication
        requests = []

        for idx in range(world_size):
            req_sign = self.my_igather(rank,
                                       world_size,
                                       comm,
                                       numpy_sign_list_packed[idx],
                                       numpy_recvbuf_sign,
                                       root=idx)
            requests += req_sign

        for idx in range(world_size):
            req_scale = self.my_igather(rank,
                                        world_size,
                                        comm,
                                        numpy_worker_scale,
                                        numpy_recvbuf_scale,
                                        root=idx)
            requests += req_scale

        MPI.Request.Waitall(requests)

        # 3. Convert back from numpy to cupy
        cupy_recvbuf_sign = cupy.asarray(numpy_recvbuf_sign)
        for idx in range(world_size):
            cupy_sign_list_packed[idx] = cupy.asarray(numpy_sign_list_packed[idx])

        cupy_worker_scale = cupy.asarray(numpy_worker_scale)
        cupy_recvbuf_scale = cupy.asarray(numpy_recvbuf_scale)
        cupy.cuda.get_current_stream().synchronize()

        return cupy_sign_list_packed, cupy_recvbuf_sign, cupy_worker_scale, cupy_recvbuf_scale

    def allgather_cuda(self,
                       comm,
                       cupy_server_sign_packed,
                       cupy_recvbuf_sign_server,
                       cupy_server_scale,
                       cupy_recvbuf_scale_server):
        comm.Allgather(cupy_server_sign_packed, cupy_recvbuf_sign_server)
        comm.Allgather(cupy_server_scale, cupy_recvbuf_scale_server)

    def allgather_host(self,
                       comm,
                       cupy_server_sign_packed,
                       cupy_recvbuf_sign_server,
                       cupy_server_scale,
                       cupy_recvbuf_scale_server):

        # 1. Convert cupy to numpy
        numpy_recvbuf_sign_server = np.zeros(
            [comm.Get_size(),
             cupy_server_sign_packed.size],
            dtype=cupy_server_sign_packed.dtype)
        numpy_recvbuf_scale_server = np.zeros([comm.Get_size(),
                                               1],
                                              dtype=cupy_server_scale.dtype)

        numpy_server_sign_packed = cupy.asnumpy(cupy_server_sign_packed)
        numpy_recvbuf_sign_server = cupy.asnumpy(cupy_recvbuf_sign_server)
        numpy_server_scale = cupy.asnumpy(cupy_server_scale)
        numpy_recvbuf_scale_server = cupy.asnumpy(cupy_recvbuf_scale_server)
        cupy.cuda.get_current_stream().synchronize()

        # 2. Communicate numpy buffers
        comm.Allgather(numpy_server_sign_packed, numpy_recvbuf_sign_server)
        comm.Allgather(numpy_server_scale, numpy_recvbuf_scale_server)
        comm.Barrier()

        # 3. Convert numpy back to cupy
        cupy_server_sign_packed = cupy.asarray(numpy_server_sign_packed)
        cupy_recvbuf_sign_server = cupy.asarray(numpy_recvbuf_sign_server)
        cupy_server_scale = cupy.asarray(numpy_server_scale)
        cupy_recvbuf_scale_server = cupy.asarray(numpy_recvbuf_scale_server)
        cupy.cuda.get_current_stream().synchronize()

        return cupy_server_sign_packed, cupy_recvbuf_sign_server, cupy_server_scale, cupy_recvbuf_scale_server

    def compressed_allreduce(self,
                             buffer_m: torch.Tensor,
                             worker_error,
                             server_error,
                             local_rank):
        # 1-bit compression with error feedback: each rank compresses its
        # error-compensated buffer to signs plus one scale, chunk i of the
        # packed signs is gathered at rank i (phase 1), each rank averages and
        # re-compresses its chunk as the "server", and the compressed server
        # results are allgathered (phase 2).
        all_start_time = time.time()
        original_shape = buffer_m.size()
        if len(original_shape) > 1:
            buffer_m = torch.flatten(buffer_m)
        original_size = buffer_m.numel()
        worker_error_size = worker_error.numel()
        cupy.cuda.Device(local_rank).use()

        if original_size != worker_error_size:
            empty_tensor = torch.zeros(worker_error_size - original_size,
                                       device=buffer_m.device)
            buffer_m = torch.cat([buffer_m, empty_tensor])

        buffer_m.add_(worker_error)
        worker_scale = torch.norm(buffer_m) / np.sqrt(torch.numel(buffer_m))
        worker_error.set_(buffer_m - worker_scale *
                          buffer_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))

        cupy_sign_list_packed = self.compression_backend.compress_by_chunk(
            self.compression_backend.torch2cupy(buffer_m.sign_().add_(1).bool()),
            self.size)
        cupy_worker_scale = self.compression_backend.torch2cupy(worker_scale)

        cupy_recvbuf_sign = cupy.zeros(
            [self.size,
             cupy_sign_list_packed[self.rank].size],
            dtype=cupy_sign_list_packed[0].dtype)
        cupy_recvbuf_scale = cupy.zeros([self.size, 1], dtype=cupy_worker_scale.dtype)

        # Communication Phase 1
        gather_start = time.time()
        if self.cuda_aware:
            self.gather_cuda(self.rank,
                             self.size,
                             self.comm,
                             cupy_sign_list_packed,
                             cupy_recvbuf_sign,
                             cupy_worker_scale,
                             cupy_recvbuf_scale)
        else:
            _, cupy_recvbuf_sign, _, cupy_recvbuf_scale = self.gather_host(
                self.rank,
                self.size,
                self.comm,
                cupy_sign_list_packed,
                cupy_recvbuf_sign,
                cupy_worker_scale,
                cupy_recvbuf_scale)
        gather_end = time.time()

        # cupy_sign_list_packed, cupy_worker_scale, worker_scale = None, None, None
        cupy_sign_list_packed = None

        compensated_server_m = self.compression_backend.cupy2torch(
            (cupy.unpackbits(cupy_recvbuf_sign.flatten())).reshape(
                self.size,
                -1)).float().add_(-0.5).mul_(2.0).mul_(
                    self.compression_backend.cupy2torch(cupy_recvbuf_scale).mul_(
                        1 / self.size)).sum(0)
        compensated_server_m.add_(server_error)
        server_scale = torch.norm(compensated_server_m) / np.sqrt(
            compensated_server_m.numel())
        server_error.set_(
            compensated_server_m - server_scale *
            compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))

        cupy_server_scale = self.compression_backend.torch2cupy(server_scale)

        cupy_server_sign_packed = self.compression_backend.compress_by_chunk(
            self.compression_backend.torch2cupy(
                compensated_server_m.sign_().add_(1).bool()),
            1)
        compensated_server_m = None

        cupy_recvbuf_sign_server = cupy.zeros(
            [self.size,
             cupy_server_sign_packed[0].size],
            dtype=cupy_recvbuf_sign.dtype)
        cupy_recvbuf_scale_server = cupy.zeros([self.size,
                                                1],
                                               dtype=cupy_recvbuf_scale.dtype)
        # cupy_recvbuf_sign, cupy_recvbuf_scale = None, None
        cupy_recvbuf_sign = None

        # Communication Phase 2
        if self.cuda_aware:
            self.allgather_cuda(self.comm,
                                cupy_server_sign_packed[0],
                                cupy_recvbuf_sign_server,
                                cupy_server_scale,
                                cupy_recvbuf_scale_server)
        else:
            _, cupy_recvbuf_sign_server, _, cupy_recvbuf_scale_server = self.allgather_host(
                self.comm,
                cupy_server_sign_packed[0],
                cupy_recvbuf_sign_server,
                cupy_server_scale,
                cupy_recvbuf_scale_server)

        # cupy_server_sign_packed, cupy_server_scale, server_scale = None, None, None
        cupy_server_sign_packed = None

        buffer_m.data.copy_(
            self.compression_backend.cupy2torch(
                (cupy.unpackbits(cupy_recvbuf_sign_server.flatten())).reshape(
                    self.size,
                    -1)).float().add_(-0.5).mul_(2.0).mul_(
                        self.compression_backend.cupy2torch(
                            cupy_recvbuf_scale_server)).flatten().data)
        if original_size != worker_error_size:
            buffer_m = buffer_m[0:original_size]
        if len(original_shape) > 1:
            buffer_m = buffer_m.reshape(original_shape)

        # cupy_recvbuf_sign_server, cupy_recvbuf_scale_server = None, None

        return buffer_m
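For orientation, here is a hypothetical driver for the backend above. It is not part of the PR's diff: the script name, the buffer sizing, and the rank-to-GPU mapping are assumptions, derived from how `compressed_allreduce` pads `buffer_m` up to `worker_error`'s size and splits the packed sign bits into `self.size` chunks, one per acting "server" rank.

```python
# Hypothetical driver for MpiBackend.compressed_allreduce; not part of this
# diff. Assumes mpi4py and cupy are installed, one GPU per rank, and a launch
# such as: mpirun -np 2 python driver.py
import math

import torch

from deepspeed.runtime.comm.mpi import MpiBackend

backend = MpiBackend(cuda_aware=False)  # False: stage through host numpy buffers
local_rank = backend.rank               # assumption: rank == local GPU index
device = torch.device('cuda', local_rank)
grad = torch.randn(10000, device=device)  # stand-in for a fused gradient buffer

# compressed_allreduce pads buffer_m to worker_error's size, packs the signs
# into bytes, and splits them into `size` equal chunks, so size the error
# buffers to a multiple of 8 * world_size bits.
chunk = math.ceil(grad.numel() / (backend.size * 8)) * 8
worker_error = torch.zeros(chunk * backend.size, device=device)
server_error = torch.zeros(chunk, device=device)

# Returns the 1-bit-compressed average of `grad` across ranks.
avg = backend.compressed_allreduce(grad, worker_error, server_error, local_rank)
```

Note that `worker_error` and `server_error` must persist across calls: the residual each compression discards is stored there and compensated on the next step, which is what makes the error-feedback scheme converge.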