diff --git a/azure-pipelines.yml b/azure-pipelines.yml index d57b335ff818..244b55068962 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -12,7 +12,7 @@ jobs: cuda.version: '10.0' pytorch.version: '1.2' torchvision.version: '0.4.0' - runmodeltests: true + runmodeltests: false #PyTorch15-CUDA101: # python.version: '3.7' # cuda.version: '10.1' @@ -40,6 +40,7 @@ jobs: conda install -q --yes pip conda install -q --yes gxx_linux-64 if [[ $(cuda.version) != "10.2" ]]; then conda install --yes -c conda-forge cudatoolkit-dev=$(cuda.version) ; fi + echo "PATH=$PATH, LD_LIBRARY_PATH=$LD_LIBRARY_PATH" displayName: 'Setup environment python=$(python.version) pytorch=$(pytorch.version) cuda=$(cuda.version)' # Manually install torch/torchvision first to enforce versioning. diff --git a/basic_install_test.py b/basic_install_test.py index 7207fe0319c6..dfc0326e13dd 100644 --- a/basic_install_test.py +++ b/basic_install_test.py @@ -1,15 +1,19 @@ import torch +import warnings import importlib try: - import deepspeed as ds + import deepspeed print("deepspeed successfully imported") except ImportError as err: raise err +print(f"torch install path: {torch.__path__}") print(f"torch version: {torch.__version__}") - -print(f"deepspeed info: {ds.__version__}, {ds.__git_hash__}, {ds.__git_branch__}") +print(f"deepspeed install path: {deepspeed.__path__}") +print( + f"deepspeed info: {deepspeed.__version__}, {deepspeed.__git_hash__}, {deepspeed.__git_branch__}" +) try: apex_C = importlib.import_module('apex_C') @@ -17,12 +21,6 @@ except Exception as err: raise err -try: - fused_lamb = importlib.import_module('deepspeed.ops.lamb.fused_lamb_cuda') - print('deepspeed fused lamb kernels successfully installed') -except Exception as err: - raise err - try: from apex.optimizers import FP16_Optimizer print("using old-style apex") @@ -30,8 +28,19 @@ print("using new-style apex") try: - ds_transformer = importlib.import_module( - 'deepspeed.ops.transformer.transformer_cuda') - print('deepspeed transformer kernels successfully installed') + importlib.import_module('deepspeed.ops.lamb.fused_lamb_cuda') + print('deepspeed lamb successfully installed.') except Exception as err: - raise err + warnings.warn("deepspeed lamb is NOT installed.") + +try: + importlib.import_module('deepspeed.ops.transformer.transformer_cuda') + print('deepspeed transformer kernels successfully installed.') +except Exception as err: + warnings.warn('deepspeed transformer kernels are NOT installed.') + +try: + importlib.import_module('deepspeed.ops.sparse_attention.cpp_utils') + print('deepspeed sparse attention successfully installed.') +except ImportError: + warnings.warn('deepspeed sparse attention is NOT installed.') diff --git a/deepspeed/launcher/constants.py b/deepspeed/launcher/constants.py new file mode 100644 index 000000000000..f384d58b2c52 --- /dev/null +++ b/deepspeed/launcher/constants.py @@ -0,0 +1,14 @@ +# Copyright 2020 The Microsoft DeepSpeed Team + +############################################# +# Torch distributed constants +############################################# +TORCH_DISTRIBUTED_DEFAULT_PORT = 29500 + +PDSH_LAUNCHER = 'pdsh' +PDSH_MAX_FAN_OUT = 1024 + +OPENMPI_LAUNCHER = 'openmpi' + +MVAPICH_LAUNCHER = 'mvapich' +MVAPICH_TMP_HOSTFILE = '/tmp/deepspeed_mvapich_hostfile' diff --git a/deepspeed/launcher/launch.py b/deepspeed/launcher/launch.py index b59a13d33bb6..205aee2d6ac4 100755 --- a/deepspeed/launcher/launch.py +++ b/deepspeed/launcher/launch.py @@ -1,5 +1,11 @@ +# Copyright 2020 The Microsoft DeepSpeed Team """ -Copyright 2020 The Microsoft DeepSpeed Team: deepspeed@microsoft.com +DeepSpeed launcher, this is similar to torch.distributed.launch but supports +additional features such as abitrary gpu exclusion. + +deepspeed.launcher.launch is intended to be run on a single worker node and +will spawn several worker sub-processes depending on how many devices/ranks +are on the worker. """ import sys @@ -10,7 +16,8 @@ from collections import defaultdict from argparse import ArgumentParser, REMAINDER -from deepspeed.utils import logger +from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT +from ..utils import logger def parse_args(): @@ -32,7 +39,7 @@ def parse_args(): " single node multi-proc training, the" " --master_addr can simply be 127.0.0.1") parser.add_argument("--master_port", - default=29500, + default=TORCH_DISTRIBUTED_DEFAULT_PORT, type=int, help="Master node (rank 0)'s free port that needs to " "be used for communication during distributed " diff --git a/deepspeed/launcher/multinode_runner.py b/deepspeed/launcher/multinode_runner.py new file mode 100644 index 000000000000..a45cc6a56505 --- /dev/null +++ b/deepspeed/launcher/multinode_runner.py @@ -0,0 +1,189 @@ +import os +import sys +import shutil +import subprocess +import warnings +from abc import ABC, abstractmethod + +from ..utils import logger +from .constants import PDSH_MAX_FAN_OUT, MVAPICH_TMP_HOSTFILE + + +class MultiNodeRunner(ABC): + def __init__(self, args, world_info_base64): + self.args = args + self.user_arguments = self.parse_user_args() + self.user_script = args.user_script + self.world_info_base64 = world_info_base64 + self.exports = {} + + @abstractmethod + def backend_exists(self): + pass + + @abstractmethod + def get_cmd(self, environment, active_resources): + pass + + def add_export(self, key, var): + self.exports[key.strip()] = var.strip() + + def parse_user_args(self): + return self.args.user_args + + +class PDSHRunner(MultiNodeRunner): + def __init__(self, args, world_info_base64): + super().__init__(args, world_info_base64) + + def backend_exists(self): + return shutil.which('pdsh') + + def parse_user_args(self): + return list( + map(lambda x: x if x.startswith("-") else "'{}'".format(x), + self.args.user_args)) + + def get_cmd(self, environment, active_resources): + environment['PDSH_RCMD_TYPE'] = 'ssh' + + active_workers = ",".join(active_resources.keys()) + logger.info("Running on the following workers: %s" % active_workers) + + # PDSH flags for max node fan out and specific hosts to launch on + # See https://linux.die.net/man/1/pdsh for flag details + pdsh_cmd_args = ['pdsh', '-f', str(PDSH_MAX_FAN_OUT), '-w', active_workers] + + exports = "" + for key, val in self.exports.items(): + exports += "export {}={}; ".format(key, val) + + deepspeed_launch = [ + exports, + "cd {};".format(os.path.abspath('.')), + sys.executable, + "-u", + "-m", + "deepspeed.launcher.launch", + '--world_info={}'.format(self.world_info_base64), + "--node_rank=%n", + "--master_addr={}".format(self.args.master_addr), + "--master_port={}".format(self.args.master_port) + ] + + return pdsh_cmd_args + deepspeed_launch + [self.user_script + ] + self.user_arguments + + +class OpenMPIRunner(MultiNodeRunner): + def __init__(self, args, world_info_base64, resource_pool): + super().__init__(args, world_info_base64) + self.resource_pool = resource_pool + self.add_export('UCX_TLS', 'tcp') + + def backend_exists(self): + #TODO: if IB is available we should suggestion mvapich + return shutil.which('ompi_info') + + def get_cmd(self, environment, active_resources): + #TODO: Allow for include/exclude at node-level but not gpu-level + assert self.args.include == "" and self.args.exclude == "", 'openmpi backend does not support worker include/exclusion' + assert self.args.num_nodes == -1 and self.args.num_gpus == -1, 'openmpi backend does not support limiting num nodes/gpus' + total_process_count = sum(self.resource_pool.values()) + + mpirun_cmd = [ + 'mpirun', + '-n', + f'{total_process_count}', + '-hostfile', + f'{self.args.hostfile}', + '--mca', + 'btl', + '^openib', + '--mca', + 'btl_tcp_if_include', + 'eth0', + ] + + export_cmd = [] + for k, v in self.exports.items(): + export_cmd += ['-x', f'{k}={v}'] + + python_exec = [sys.executable, "-u"] + + return mpirun_cmd + export_cmd + python_exec + [self.user_script + ] + self.user_arguments + + +class MVAPICHRunner(MultiNodeRunner): + def __init__(self, args, world_info_base64, resource_pool): + super().__init__(args, world_info_base64) + self.resource_pool = resource_pool + + # Disable the CMA kernel module, not available on Ubuntu systems + self.add_export('MV2_SMP_USE_CMA', '0') + + # If we fail this will output more verbose logging + self.add_export('MV2_DEBUG_SHOW_BACKTRACE', '1') + + # Enabled cuda-aware communication + self.add_export('MV2_USE_CUDA', '1') + + # Support deep learning frameworks: http://hidl.cse.ohio-state.edu/userguide/horovod/ + self.add_export('MV2_SUPPORT_DL', '1') + + # Support MPI_THREAD_MULTIPLE + self.add_export('MV2_ENABLE_AFFINITY', '0') + + # Performance tuning flags for allgather + self.add_export('MV2_INTER_ALLGATHER_TUNING', '5') + self.add_export('MV2_CUDA_USE_NAIVE', '0') + + def backend_exists(self): + #TODO: if IB is available we should suggestion mvapich + mpiname_exists = shutil.which('mpiname') + exists = False + if not mpiname_exists: + warnings.warn("mpiname does not exist, mvapich is not installed properly") + else: + results = subprocess.check_output('mpiname', shell=True) + mpiname_results = results.decode('utf-8').strip() + if "MVAPICH2-GDR" in mpiname_results: + exists = True + else: + warnings.warn( + f"Expected MVAPICH2-GDR as return for mpiname but received {mpiname_results}" + ) + return exists + + def get_cmd(self, environment, active_resources): + #TODO: Allow for include/exclude at node-level but not gpu-level + assert self.args.include == "" and self.args.exclude == "", 'mvapich backend does not support worker include/exclusion' + assert self.args.num_nodes == -1 and self.args.num_gpus == -1, 'mvapich backend does not support limiting num nodes/gpus' + devices_per_node = self.resource_pool.values() + total_process_count = sum(devices_per_node) + process_per_node = list(devices_per_node)[0] + assert all([n == process_per_node for n in devices_per_node]), "mvapich requires same number of devices per node" + + with open(MVAPICH_TMP_HOSTFILE, 'w') as fd: + for host in self.resource_pool.keys(): + fd.write(f'{host}\n') + + mpirun_cmd = [ + 'mpirun', + '-np', + f'{total_process_count}', + '-ppn', + f'{process_per_node}', + '--hostfile', + f'{MVAPICH_TMP_HOSTFILE}', + ] + + export_cmd = [] + for k, v in self.exports.items(): + export_cmd += ['-env', f'{k}={v}'] + + python_exec = [sys.executable, "-u"] + + return mpirun_cmd + export_cmd + python_exec + [self.user_script + ] + self.user_arguments diff --git a/deepspeed/launcher/runner.py b/deepspeed/launcher/runner.py index 565083aa7feb..9479bb63758c 100755 --- a/deepspeed/launcher/runner.py +++ b/deepspeed/launcher/runner.py @@ -1,5 +1,9 @@ +# Copyright 2020 The Microsoft DeepSpeed Team """ -Copyright 2020 The Microsoft DeepSpeed Team +DeepSpeed runner is the main front-end to launching multi-worker +training jobs with DeepSpeed. By default this uses pdsh to parallel +ssh into multiple worker nodes and launch all the neccisary processes +per rank for training. """ import os @@ -14,11 +18,13 @@ import torch.cuda -from deepspeed.runtime.constants import TORCH_DISTRIBUTED_DEFAULT_PORT -from deepspeed.utils import logger +from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner +from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT, \ + PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER +from ..utils import logger DLTS_HOSTFILE = "/job/hostfile" -EXPORT_ENVS = ["NCCL", "PYTHON"] +EXPORT_ENVS = ["NCCL", "PYTHON", "MV2", 'UCX'] DEEPSPEED_ENVIRONMENT_NAME = ".deepspeed_env" DEEPSPEED_ENVIRONMENT_PATHS = [os.path.expanduser("~"), '.'] PDSH_MAX_FAN_OUT = 1024 @@ -62,12 +68,20 @@ def parse_args(args=None): resources except slot 0 on worker-1. ''') - parser.add_argument("--num_nodes", type=int, default=-1, help="") + parser.add_argument("--num_nodes", + type=int, + default=-1, + help="Total number of worker nodes to run on, this will use " + "the top N hosts from the given hostfile.") - parser.add_argument("--num_gpus", type=int, default=-1, help="") + parser.add_argument("--num_gpus", + type=int, + default=-1, + help="Max number of GPUs to use on each node, will use " + "[0:N) GPU ids on each node.") parser.add_argument("--master_port", - default=int(TORCH_DISTRIBUTED_DEFAULT_PORT), + default=TORCH_DISTRIBUTED_DEFAULT_PORT, type=int, help="(optional) Port used by PyTorch distributed for " "communication during training.") @@ -78,6 +92,18 @@ def parse_args(args=None): help="(optional) IP address of node 0, will be " "inferred via 'hostname -I' if not specified.") + parser.add_argument("--launcher", + default=PDSH_LAUNCHER, + type=str, + help="(optional) choose launcher backend for multi-node" + "training. Options currently include PDSH, OpenMPI, MVAPICH.") + + parser.add_argument("--launcher_args", + default="", + type=str, + help="(optional) pass launcher specific arguments as a " + "single quoted argument.") + parser.add_argument("user_script", type=str, help="User script to launch, followed by any required " @@ -292,17 +318,18 @@ def main(args=None): ] cmd = deepspeed_launch + [args.user_script] + args.user_args else: - env['PDSH_RCMD_TYPE'] = 'ssh' - - active_workers = ",".join(active_resources.keys()) - logger.info("Running on the following workers: %s" % active_workers) - - # PDSH flags for max node fan out and specific hosts to launch on - # See https://linux.die.net/man/1/pdsh for flag details - pdsh_cmd_args = ['pdsh', '-f', str(PDSH_MAX_FAN_OUT), '-w', active_workers] + args.launcher = args.launcher.lower() + if args.launcher == PDSH_LAUNCHER: + runner = PDSHRunner(args, world_info_base64) + elif args.launcher == OPENMPI_LAUNCHER: + runner = OpenMPIRunner(args, world_info_base64, resource_pool) + elif args.launcher == MVAPICH_LAUNCHER: + runner = MVAPICHRunner(args, world_info_base64, resource_pool) + else: + raise NotImplementedError(f"Unknown launcher {args.launcher}") - num_nodes = len(active_resources.keys()) - num_gpus_per_node = None + if not runner.backend_exists(): + raise RuntimeError(f"launcher '{args.launcher}' not installed.") curr_path = os.path.abspath('.') if 'PYTHONPATH' in env: @@ -312,33 +339,20 @@ def main(args=None): exports = "" for var in env.keys(): - if any(map(lambda name: var.startswith(name), EXPORT_ENVS)): - exports += "export {}={}; ".format(var, env[var]) + if any([var.startswith(name) for name in EXPORT_ENVS]): + runner.add_export(var, env[var]) for environ_path in DEEPSPEED_ENVIRONMENT_PATHS: environ_file = os.path.join(environ_path, DEEPSPEED_ENVIRONMENT_NAME) if os.path.isfile(environ_file): with open(environ_file, 'r') as fd: for var in fd.readlines(): - exports += "export {}; ".format(var.strip()) + key, val = var.split('=') + runner.add_export(key, val) - deepspeed_launch = [ - exports, - "cd {};".format(curr_path), - sys.executable, - "-u", - "-m", - "deepspeed.launcher.launch", - '--world_info={}'.format(world_info_base64), - "--node_rank=%n", - "--master_addr={}".format(args.master_addr), - "--master_port={}".format(args.master_port) - ] - user_args = list( - map(lambda x: x if x.startswith("-") else "'{}'".format(x), - args.user_args)) - cmd = pdsh_cmd_args + deepspeed_launch + [args.user_script] + user_args - logger.info("cmd={}".format(cmd)) + cmd = runner.get_cmd(env, active_resources) + + logger.info("cmd = {}".format(' '.join(cmd))) result = subprocess.Popen(cmd, env=env) result.wait() diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 26c10f3b0e35..754e780a2eea 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -14,9 +14,10 @@ from deepspeed.utils import logger TENSOR_CORE_ALIGN_SIZE = 8 +ONEBIT_ADAM_OPTIMIZER = 'onebitadam' ADAM_OPTIMIZER = 'adam' LAMB_OPTIMIZER = 'lamb' -DEEPSPEED_OPTIMIZERS = [ADAM_OPTIMIZER, LAMB_OPTIMIZER] +DEEPSPEED_OPTIMIZERS = [ADAM_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER] def get_amp_enabled(param_dict): diff --git a/deepspeed/runtime/custom_collectives.py b/deepspeed/runtime/custom_collectives.py new file mode 100644 index 000000000000..cb77edcaf60d --- /dev/null +++ b/deepspeed/runtime/custom_collectives.py @@ -0,0 +1,154 @@ +''' +Copyright 2019 The Microsoft DeepSpeed Team +''' + +from mpi4py import MPI +import numpy as np +import cupy + + +def my_igather(rank, size, comm, sendbuf, recbuf, root): + req = [] + if rank == root: + for idx in range(size): + if idx != rank: + req.append(comm.Irecv(recbuf[idx], source=idx)) + else: + recbuf[rank] = sendbuf + else: + req.append(comm.Isend(sendbuf, dest=root)) + return req + + +def gather_cuda(rank, + world_size, + comm, + cupy_sign_list_packed, + cupy_recvbuf_sign, + cupy_worker_scale, + cupy_recvbuf_scale): + # We do in-place operations on cupy buffers so we do not return any buffers + requests = [] + for idx in range(world_size): + req_sign = my_igather(rank, + world_size, + comm, + cupy_sign_list_packed[idx], + cupy_recvbuf_sign, + root=idx) + requests += req_sign + + for idx in range(world_size): + req_scale = my_igather(rank, + world_size, + comm, + cupy_worker_scale, + cupy_recvbuf_scale, + root=idx) + requests += req_scale + + MPI.Request.Waitall(requests) + + +def gather_host(rank, + world_size, + comm, + cupy_sign_list_packed, + cupy_recvbuf_sign, + cupy_worker_scale, + cupy_recvbuf_scale): + # In-place operations are not possible for newly created cupy arrays + # so we need to return the new buffers + numpy_recvbuf_sign = np.zeros([world_size, + cupy_sign_list_packed[rank].size], + dtype=cupy_sign_list_packed[0].dtype) + numpy_recvbuf_scale = np.zeros([world_size, 1], dtype=cupy_worker_scale.dtype) + + # 1. convert from cupy to numpy + numpy_sign_list_packed = cupy_sign_list_packed + + for idx in range(world_size): + numpy_sign_list_packed[idx] = cupy.asnumpy(cupy_sign_list_packed[idx]) + + numpy_worker_scale = cupy.asnumpy(cupy_worker_scale) + numpy_recvbuf_scale = cupy.asnumpy(cupy_recvbuf_scale) + + cupy.cuda.get_current_stream().synchronize() + + # 2. use numpy buffers for communication + requests = [] + + for idx in range(world_size): + req_sign = my_igather(rank, + world_size, + comm, + numpy_sign_list_packed[idx], + numpy_recvbuf_sign, + root=idx) + requests += req_sign + + for idx in range(world_size): + req_scale = my_igather(rank, + world_size, + comm, + numpy_worker_scale, + numpy_recvbuf_scale, + root=idx) + requests += req_scale + + MPI.Request.Waitall(requests) + + # 3. Convert back from numpy to cupy + cupy_recvbuf_sign = cupy.asarray(numpy_recvbuf_sign) + for idx in range(world_size): + cupy_sign_list_packed[idx] = cupy.asarray(numpy_sign_list_packed[idx]) + + cupy_worker_scale = cupy.asarray(numpy_worker_scale) + cupy_recvbuf_scale = cupy.asarray(numpy_recvbuf_scale) + cupy.cuda.get_current_stream().synchronize() + + return cupy_sign_list_packed, cupy_recvbuf_sign, cupy_worker_scale, cupy_recvbuf_scale + + +def allgather_cuda(comm, + cupy_server_sign_packed, + cupy_recvbuf_sign_server, + cupy_server_scale, + cupy_recvbuf_scale_server): + comm.Allgather(cupy_server_sign_packed, cupy_recvbuf_sign_server) + comm.Allgather(cupy_server_scale, cupy_recvbuf_scale_server) + + +def allgather_host(comm, + cupy_server_sign_packed, + cupy_recvbuf_sign_server, + cupy_server_scale, + cupy_recvbuf_scale_server): + + # 1. Convert cupy to numpy + numpy_recvbuf_sign_server = np.zeros([comm.Get_size(), + cupy_server_sign_packed.size], + dtype=cupy_server_sign_packed.dtype) + numpy_recvbuf_scale_server = np.zeros([comm.Get_size(), + 1], + dtype=cupy_server_scale.dtype) + + numpy_server_sign_packed = cupy.asnumpy(cupy_server_sign_packed) + numpy_recvbuf_sign_server = cupy.asnumpy(cupy_recvbuf_sign_server) + numpy_server_scale = cupy.asnumpy(cupy_server_scale) + numpy_recvbuf_scale_server = cupy.asnumpy(cupy_recvbuf_scale_server) + cupy.cuda.get_current_stream().synchronize() + + # 2. Communicate numpy buffers + comm.Allgather(numpy_server_sign_packed, numpy_recvbuf_sign_server) + comm.Allgather(numpy_server_scale, numpy_recvbuf_scale_server) + comm.Barrier() + + # 3. Convert numpy back to cupy + cupy_server_sign_packed = cupy.asarray(numpy_server_sign_packed) + cupy_recvbuf_sign_server = cupy.asarray(numpy_recvbuf_sign_server) + cupy_server_scale = cupy.asarray(numpy_server_scale) + cupy_recvbuf_scale_server = cupy.asarray(numpy_recvbuf_scale_server) + cupy.cuda.get_current_stream().synchronize() + + return cupy_server_sign_packed, cupy_recvbuf_sign_server, cupy_server_scale, cupy_recvbuf_scale_server diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 0a02bf6dd7d0..b7decf991b58 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -18,7 +18,8 @@ from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer from deepspeed.runtime.config import DeepSpeedConfig, \ - ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_OPTIMIZERS + ADAM_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, DEEPSPEED_OPTIMIZERS + from deepspeed.runtime.dataloader import DeepSpeedDataLoader from deepspeed.runtime.constants import \ ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \ @@ -27,8 +28,6 @@ from deepspeed.runtime.csr_tensor import CSRTensor import deepspeed.runtime.lr_schedules as lr_schedules -from deepspeed.ops.lamb import FusedLamb - from deepspeed.utils import logger from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer @@ -122,6 +121,7 @@ def __init__(self, self.config_params = config_params self.loaded_checkpoint_mp_world_size = None self.loaded_checkpoint_dp_world_size = None + self.enable_backward_allreduce = True if dist_init_required is None: dist_init_required = not dist.is_initialized() @@ -527,6 +527,7 @@ def _configure_optimizer(self, client_optimizer, model_parameters): def _configure_basic_optimizer(self, model_parameters): optimizer_parameters = self.optimizer_params() + # print(optimizer_parameters.keys()) if 'max_grad_norm' in optimizer_parameters.keys(): raise ValueError( "'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details" @@ -535,7 +536,11 @@ def _configure_basic_optimizer(self, model_parameters): from apex.optimizers.fused_adam import FusedAdam optimizer = FusedAdam(model_parameters, **optimizer_parameters) elif self.optimizer_name() == LAMB_OPTIMIZER: + from deepspeed.ops.lamb import FusedLamb optimizer = FusedLamb(model_parameters, **optimizer_parameters) + elif self.optimizer_name() == ONEBIT_ADAM_OPTIMIZER: + from deepspeed.runtime.fp16.onebit_adam import OnebitAdam + optimizer = OnebitAdam(model_parameters, self, **optimizer_parameters) else: torch_optimizer = getattr(torch.optim, self.optimizer_name()) optimizer = torch_optimizer(model_parameters, **optimizer_parameters) @@ -545,7 +550,8 @@ def _configure_fp16_optimizer(self, optimizer): initial_dynamic_scale = self.initial_dynamic_scale() dynamic_loss_args = self.dynamic_loss_scale_args() clip_grad = self.gradient_clipping() - if self.optimizer_name() == ADAM_OPTIMIZER: + if self.optimizer_name() == ADAM_OPTIMIZER or self.optimizer_name( + ) == ONEBIT_ADAM_OPTIMIZER: if self.dynamic_loss_scale(): logger.info('Creating fp16 optimizer with dynamic loss scale') timers = self.timers if self.wall_clock_breakdown() else None @@ -734,7 +740,7 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): else: self.buffered_allreduce_fallback(elements_per_buffer=bucket_size) - def backward(self, loss, allreduce_gradients=True): + def backward(self, loss, allreduce_gradients=True, release_loss=False): r"""Execute backward pass on the loss Arguments: @@ -796,7 +802,7 @@ def backward(self, loss, allreduce_gradients=True): self.timers('backward_allreduce_microstep').start() self.timers('backward_allreduce').start() - if allreduce_gradients: + if allreduce_gradients and self.enable_backward_allreduce: self.allreduce_gradients() if self.wall_clock_breakdown(): @@ -805,6 +811,10 @@ def backward(self, loss, allreduce_gradients=True): self.timers('backward').stop() self.timers('backward_microstep').stop() + if release_loss: + # loss.data = None + pass + return loss def is_gradient_accumulation_boundary(self): diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py index 3c5f66b21fb5..98cb6b1d1402 100755 --- a/deepspeed/runtime/fp16/fused_optimizer.py +++ b/deepspeed/runtime/fp16/fused_optimizer.py @@ -101,6 +101,20 @@ def __init__(self, self.overflow = False self.overflow_checker = CheckOverflow(self.fp16_groups, mpu=self.mpu) + self.initialize_optimizer_states() + + def initialize_optimizer_states(self): + for i, group in enumerate(self.fp16_groups): + self.fp32_groups_flat[i].grad = torch.zeros( + self.fp32_groups_flat[i].size(), + device=self.fp32_groups_flat[i].device) + + self.optimizer.step() + + for i, group in enumerate(self.fp16_groups): + self.fp32_groups_flat[i].grad = None + + return def zero_grad(self, set_grads_to_None=True): """ @@ -204,6 +218,9 @@ def step(self, closure=None): if p.grad is None else p.grad.to(data_type) for p in group ])) + for p in group: + p.grad = None + self.fp32_groups_flat[i].grad = grads_groups_flat[i] self.start_timers([COMPUTE_NORM]) @@ -223,6 +240,7 @@ def step(self, closure=None): "scale: {}, reducing to {}".format(prev_scale, self.cur_scale)) self.log_timers(OVERFLOW_TIMERS) + grads_groups_flat = None return self.overflow self.start_timers([UNSCALE_AND_CLIP]) diff --git a/deepspeed/runtime/fp16/onebit_adam.py b/deepspeed/runtime/fp16/onebit_adam.py new file mode 100644 index 000000000000..c6566c28777b --- /dev/null +++ b/deepspeed/runtime/fp16/onebit_adam.py @@ -0,0 +1,374 @@ +''' +Copyright 2020 The Microsoft DeepSpeed Team +''' +import types +import torch +import importlib +import numpy as np +import time +import cupy +from torch.utils.dlpack import to_dlpack +from torch.utils.dlpack import from_dlpack +from deepspeed.utils.logging import logger + +from mpi4py import MPI +from deepspeed.runtime.custom_collectives import gather_cuda, gather_host, allgather_cuda, allgather_host + + +class OnebitAdam(torch.optim.Optimizer): + """Implements the 1-bit Adam algorithm. Currently GPU-only. + For usage example please see, TODO DeepSpeed Tutorial + It has been proposed in APMSqueeze (https://arxiv.org/abs/2008.11343) + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + freeze_step (int, optional): Number of steps for warmup (uncompressed) + stage before we start using compressed communication. (default 100000) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + max_coeff(float, optional): maximum value of the lamb coefficient (default: 10.0) + min_coeff(float, optional): minimum value of the lamb coefficient (default: 0.01) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) NOT SUPPORTED in 1-bit Adam! + eps_inside_sqrt (boolean, optional): in the 'update parameters' step, + adds eps to the bias-corrected second moment estimate before + evaluating square root instead of adding it to the square root of + second moment estimate as in the original paper. (default: False) + cuda_aware (boolean, required): Set True if the underlying MPI implementation + supports CUDA-Aware communication. (default: False) + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + def __init__(self, + params, + deepspeed=None, + lr=1e-3, + freeze_step=100000, + bias_correction=True, + betas=(0.9, + 0.999), + eps=1e-8, + eps_inside_sqrt=False, + weight_decay=0., + max_grad_norm=0., + amsgrad=False, + cuda_aware=False): + + if amsgrad: + raise RuntimeError('1-bit Adam does not support the AMSGrad variant.') + defaults = dict(lr=lr, + bias_correction=bias_correction, + betas=betas, + eps=eps, + weight_decay=weight_decay, + max_grad_norm=max_grad_norm) + + super(OnebitAdam, self).__init__(params, defaults) + from mpi4py import MPI + self.eps_mode = 0 if eps_inside_sqrt else 1 + + self.comm = MPI.COMM_WORLD + self.rank = self.comm.Get_rank() + self.size = self.comm.Get_size() + self.comm_time = 0.0 + self.step_time = 0.0 + self.ave_step = 1 + self.bk_time = 0.0 + self.divider = int(self.size * 8 / np.gcd(self.size, 8)) + self.deepspeed = deepspeed + self.adam_freeze_key = False + self.initialize = False + self.freeze_step = freeze_step + self.cuda_aware = cuda_aware + + def torch2cupy(self, tensor): + return cupy.fromDlpack(to_dlpack(tensor)) + + def cupy2torch(self, cupy_tensor): + return from_dlpack(cupy_tensor.toDlpack()) + + def compress_by_chunk(self, cupy_bool_tensor, num_chunks): + packed_sign = cupy.packbits(cupy_bool_tensor) + sign_list_packed = cupy.split(packed_sign, num_chunks) + cupy.cuda.get_current_stream().synchronize() + return sign_list_packed + + def Compressed_Allreduce(self, + buffer_m: torch.tensor, + worker_error, + server_error, + rank, + world_size, + comm, + local_rank): + + all_start_time = time.time() + original_size = buffer_m.numel() + cupy.cuda.Device(local_rank).use() + + if torch.numel(buffer_m) != torch.numel(worker_error): + empty_tensor = torch.zeros(torch.numel(worker_error) - torch.numel(buffer_m), + device=buffer_m.device) + buffer_m = torch.cat([buffer_m, empty_tensor]) + + buffer_m.add_(worker_error) + worker_scale = torch.norm(buffer_m) / np.sqrt(torch.numel(buffer_m)) + sign_buffer_m = buffer_m.sign().add_(1).bool() + sign_buffer_m = sign_buffer_m.float() + sign_buffer_m.add_(-0.5).mul_(2.0) + worker_error.set_((buffer_m - worker_scale * sign_buffer_m)) + sign_buffer_m = None + + compensated_buffer_m = buffer_m + compensated_buffer_m.sign_() + compensated_buffer_m = compensated_buffer_m.add_(1).bool() + cupy_worker_scale = self.torch2cupy(worker_scale) + cupy_compensated_buffer_m = self.torch2cupy(compensated_buffer_m) + compensated_buffer_m = None + + cupy_sign_list_packed = self.compress_by_chunk(cupy_compensated_buffer_m, + world_size) + cupy_compensated_buffer_m = None + + cupy_recvbuf_sign = cupy.zeros([world_size, + cupy_sign_list_packed[rank].size], + dtype=cupy_sign_list_packed[0].dtype) + cupy_recvbuf_scale = cupy.zeros([world_size, 1], dtype=cupy_worker_scale.dtype) + + # Communication Phase 1 + gather_start = time.time() + if self.cuda_aware: + gather_cuda(rank, + world_size, + comm, + cupy_sign_list_packed, + cupy_recvbuf_sign, + cupy_worker_scale, + cupy_recvbuf_scale) + else: + cupy_sign_list_packed, cupy_recvbuf_sign, cupy_worker_scale, cupy_recvbuf_scale = gather_host(rank, + world_size, + comm, + cupy_sign_list_packed, + cupy_recvbuf_sign, + cupy_worker_scale, + cupy_recvbuf_scale) + gather_end = time.time() + + cupy_unpacked_sign = (cupy.unpackbits(cupy_recvbuf_sign.flatten())).reshape( + world_size, + -1) + cupy_recvbuf_sign = None + unpacked_sign = self.cupy2torch(cupy_unpacked_sign).float() + cupy_unpacked_sign = None + unpacked_sign = unpacked_sign.add_(-0.5).mul_(2.0) + worker_scale = self.cupy2torch(cupy_recvbuf_scale).mul_(1 / world_size) + compensated_server_m = unpacked_sign.mul_(worker_scale).sum(0) + unpacked_sign = None + + compensated_server_m.add_(server_error) + server_scale = torch.norm(compensated_server_m) / np.sqrt( + compensated_server_m.numel()) + sign_server_m = compensated_server_m.sign().add_(1).bool() + sign_server_m = sign_server_m.float() + sign_server_m.add_(-0.5).mul_(2.0) + server_error.set_(compensated_server_m - server_scale * sign_server_m) + sign_server_m = None + + compensated_server_m.sign_() + compensated_server_m = compensated_server_m.add_(1).bool() + cupy_server_scale = self.torch2cupy(server_scale) + cupy_compensated_server_m = self.torch2cupy(compensated_server_m) + compensated_server_m = None + + cupy_server_sign_packed = self.compress_by_chunk(cupy_compensated_server_m, 1) + + cupy_recvbuf_sign_server = cupy.zeros( + [world_size, + cupy_server_sign_packed[0].size], + dtype=cupy_sign_list_packed[0].dtype) + cupy_recvbuf_scale_server = cupy.zeros([world_size, + 1], + dtype=cupy_worker_scale.dtype) + + # Communication Phase 2 + if self.cuda_aware: + allgather_cuda(comm, + cupy_server_sign_packed[0], + cupy_recvbuf_sign_server, + cupy_server_scale, + cupy_recvbuf_scale_server) + else: + cupy_server_sign_packed[0], cupy_recvbuf_sign_server, cupy_server_scale, cupy_recvbuf_scale_server = allgather_host(comm, + cupy_server_sign_packed[0], + cupy_recvbuf_sign_server, + cupy_server_scale, + cupy_recvbuf_scale_server) + + cupy_server_unpacked_sign = (cupy.unpackbits( + cupy_recvbuf_sign_server.flatten())).reshape(world_size, + -1) + cupy_recvbuf_sign_server = None + + server_unpacked_sign = self.cupy2torch(cupy_server_unpacked_sign) + cupy_server_unpacked_sign = None + + server_unpacked_sign = server_unpacked_sign.float().add_(-0.5).mul_(2.0) + server_scale = self.cupy2torch(cupy_recvbuf_scale_server) + buffer_m = server_unpacked_sign.mul_(server_scale).flatten()[0:original_size] + + return buffer_m + + def step(self, closure=None, grads=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + grads (list of tensors, optional): weight gradient to use for the + optimizer update. If gradients have type torch.half, parameters + are expected to be in type torch.float. (default: None) + output params (list of tensors, optional): A reduced recision copy + of the updated weights written out in addition to the regular + updated weights. Have to be of same type as gradients. (default: None) + scale (float, optional): factor to divide gradient tensor values + by before applying to weights. (default: 1) + """ + loss = None + if closure is not None: + loss = closure() + + gather_time = 0 + allgather_time = 0 + all_time = 0 + + if self.adam_freeze_key is False: + v_diff_buffer = 0.0 + + if grads is None: + grads_group = [None] * len(self.param_groups) + # backward compatibility + # assuming a list/generator of parameter means single group + elif isinstance(grads, types.GeneratorType): + grads_group = [grads] + elif type(grads[0]) != list: + grads_group = [grads] + else: + grads_group = grads + + for group, grads_this_group in zip(self.param_groups, grads_group): + if grads_this_group is None: + grads_this_group = [None] * len(group['params']) + + bias_correction = 1 if group['bias_correction'] else 0 + + for p, grad in zip(group['params'], grads_this_group): + if p.grad is None and grad is None: + continue + if grad is None: + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError( + 'FusedAdam does not support sparse gradients, please consider SparseAdam instead' + ) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p.data) + + state['tensor_size'] = torch.numel(p.data) + state['corrected_tensor_size'] = state['tensor_size'] + + if state['tensor_size'] % (self.size * self.divider) != 0: + state['corrected_tensor_size'] += ((self.size * self.divider) - + (state['tensor_size'] % + (self.size * self.divider))) + state['server_chunk_size'] = state[ + 'corrected_tensor_size'] // self.size + + if not self.initialize or (self.adam_freeze_key + and 'worker_error' not in state.keys()): + torch.cuda.empty_cache() + state['worker_error'] = torch.zeros(state['corrected_tensor_size'], + device=p.device) + state['server_error'] = torch.zeros(state['server_chunk_size'], + device=p.device) + torch.cuda.empty_cache() + self.adam_freeze_key = True + if not self.initialize and torch.distributed.get_rank() == 0: + print("Cupy Buffers Initialized Successfully.") + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + if self.adam_freeze_key is False: + exp_avg.mul_(beta1).add_(1 - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + grad = None + if self.initialize: + update = exp_avg / (exp_avg_sq.sqrt() + group['eps']) + + else: + if 'non_freeze' in group.keys() and group['non_freeze'] is True: + dist.all_reduce(grad) + grad.mul_(1 / dist.get_world_size()) + exp_avg.mul_(beta1).add(1 - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + grad = None + else: + if self.initialize is True: + exp_avg.mul_(beta1).add_(1 - beta1, grad) + grad = None + + if self.size > 1: + exp_avg.set_( + self.Compressed_Allreduce(exp_avg, + state['worker_error'], + state['server_error'], + self.rank, + self.size, + self.comm, + self.deepspeed.local_rank)) + if self.initialize: + update = exp_avg / (exp_avg_sq.sqrt() + group['eps']) + + if self.initialize: + if group['weight_decay'] > 0.0: + update += group['weight_decay'] * p.data + with torch.no_grad(): + p.add_(-group['lr'] * update) + + if not self.initialize: + print('Pop out errors', flush=True) + state.pop('worker_error') + state.pop('server_error') + + if not self.initialize: + self.adam_freeze_key = False + self.initialize = True + print( + f"Finished the initialization step at rant {torch.distributed.get_rank()}" + ) + return loss + + if self.adam_freeze_key is False: + if state['step'] >= self.freeze_step: + self.adam_freeze_key = True + self.deepspeed.enable_backward_allreduce = False + + return loss diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index d9d0781aa4fe..b424487717e3 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -8,6 +8,7 @@ import torch from torch._six import inf +import torch.distributed as dist from deepspeed.utils import logger @@ -23,7 +24,8 @@ def __init__(self, param_groups=None, mpu=None, zero_reduce_scatter=False): for param in group: self.params.append(param) - def check_using_norm(self, norm_group): + def check_using_norm(self, norm_group, reduce_overflow=True): + #TODO: I don't think reduce_overflow is needed if mpu is None overflow = -1 in norm_group if self.mpu is not None: @@ -32,6 +34,11 @@ def check_using_norm(self, norm_group): op=torch.distributed.ReduceOp.MAX, group=self.mpu.get_model_parallel_group()) overflow = overflow_gpu[0].item() + elif reduce_overflow: + cuda_overflow = torch.cuda.FloatTensor([overflow]) + dist.all_reduce(cuda_overflow, op=torch.distributed.ReduceOp.MAX) + dist.barrier() + overflow = cuda_overflow[0].item() return bool(overflow) diff --git a/docs/_tutorials/getting-started.md b/docs/_tutorials/getting-started.md index b8ecc8027fb9..c62eef569a1d 100644 --- a/docs/_tutorials/getting-started.md +++ b/docs/_tutorials/getting-started.md @@ -9,8 +9,7 @@ date: 2020-05-15 * Please see our [Azure tutorial](/tutorials/azure/) to get started with DeepSpeed on Azure! * If you're not on Azure, we recommend using our docker image via `docker pull deepspeed/deepspeed:latest` which contains a pre-installed version of DeepSpeed and all the necessary dependencies. -* If you want to install DeepSpeed manually, we provide an install script -* `install.sh` to help install on a local machine or across an entire cluster. +* If you want to install DeepSpeed manually, we provide an install script `install.sh` to help install on a local machine or across an entire cluster. ## Writing DeepSpeed Models DeepSpeed model training is accomplished using the DeepSpeed engine. The engine diff --git a/docs/_tutorials/onebit-adam.md b/docs/_tutorials/onebit-adam.md new file mode 100644 index 000000000000..c36ffa614233 --- /dev/null +++ b/docs/_tutorials/onebit-adam.md @@ -0,0 +1,234 @@ +--- +title: "1-bit Adam: Up to 5x less communication volume and up to 2x faster training" +--- + +In this tutorial, we are going to introduce the 1-bit Adam optimizer in DeepSpeed. 1-bit Adam can improve model training speed on communication-constrained clusters, especially for communication-intensive large models by reducing the overall communication volume by up to 5x. + +To illustrate the benefits and usage of 1-bit Adam optimizer in DeepSpeed, we use the following two training tasks as examples: + +1. BingBertSQuAD Fine-tuning +2. BERT Pre-training + +For more details on these tasks, please refer to the tutorial posts on [BingBertSQuAD Fine-tuning](https://www.deepspeed.ai/tutorials/bert-finetuning/) and [BERT Pre-training](https://www.deepspeed.ai/tutorials/bert-pretraining/). + +## Overview + +If you don't already have a copy of the DeepSpeed repository, please clone in +now and checkout the DeepSpeedExamples submodule that contains the BingBertSQuAD and BERT Pre-training examples. + +```shell +git clone https://github.com/microsoft/DeepSpeed +cd DeepSpeed +git submodule update --init --recursive +cd DeepSpeedExamples/ +``` +## Pre-requisites for 1-bit Adam + +1-bit Adam uses advanced communication schemes that are not yet supported by PyTorch distributed and NCCL. We rely on Message Passing Interface (MPI) for these advanced communication primitives. + +We package the necessary dependencies in the DeepSpeed docker images. However, if you are using a different build system, please install MPI and mpi4py on your system. We have tested CUDA-Aware MPI communication using the [MVAPICH2-GDR](http://mvapich.cse.ohio-state.edu/userguide/gdr/) library. However, any CUDA-Aware communication library including [OpenMPI](https://www.open-mpi.org/) should work fine with these examples. + +An example launch command for 1-bit Adam using the `deepspeed` launcher is as follows: + +```shell +deepspeed --launcher=[mvapich|openmpi] script.py +``` + +Alternatively, the standard mpirun launcher can also be used as follows: + +```shell +mpirun -np [#processes] -ppn [#GPUs on each node] -hostfile [hostfile] [MPI flags] bash [training_script.sh] +``` + +### Configuration +The 1-bit Adam feature can be used by setting the optimizer configuration options as follows. An example json config file is shown below. + +```json +{ + "train_batch_size": 4096, + "train_micro_batch_size_per_gpu": 64, + "optimizer": { + "type": "OneBitAdam", + "params": { + "lr": 2e-4, + "freeze_step": 400, + "cuda_aware": true + } + }, + "fp16": { + "enabled": true, + } +} +``` +Please note two new parameters `freeze_step` and `cuda_aware` that have been added to support the 1-bit Adam feature. + +`cuda_aware` is used to indicate that the underlying MPI library support CUDA-Aware communication. +This feature is only supported on systems with InfiniBand interconnect and a CUDA-Aware MPI library like [MVAPICH2-GDR](http://mvapich.cse.ohio-state.edu/userguide/gdr/) or OpenMPI built with CUDA-Aware support. Setting `cuda_aware` to False will allow training on Ethernet based systems. However, the communication will happen using sender as well as receiver side memory copies between CPU and GPU buffers before and after communication. + +`freeze_step` is the number of warm up steps before 1-bit compression gets applied to the communication. In order to determine the number of warm up steps, one strategy is to set 15-25% of the total training steps for a given model. If it provides the desired outcome, one can try to extract more performance by reducing the steps systematically. In future, we plan to introduce a threshold that can automatically search and decide for the number of warm up steps for different models. The examples below have been tuned for the number of warm up steps. The `freeze_step` parameter has already been set to the best number we found in the corresponding run scripts. + +## 1. BingBertSQuAD fine-tuning with 1-bit Adam + +* Download the SQuAD dataset: + * Training set: [train-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json) + * Validation set: [dev-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json) +* Download the HuggingFace checkpoint and config files: + * [bert-large-uncased-whole-word-masking](https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin) + * [bert json config](https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json) + +You can also use a pre-trained BERT model checkpoint from either DeepSpeed, [HuggingFace](https://github.com/huggingface/transformers), or [TensorFlow](https://github.com/google-research/bert#pre-trained-models) to run the fine-tuning. + +### 1.1 Running BingBertSQuAD with DeepSpeed and 1-bit Adam + +The main part of training is done in `nvidia_run_squad_deepspeed.py`, which has +already been modified to use DeepSpeed. The `run_squad_deepspeed.sh` script +helps to invoke training and setup several different hyperparameters relevant +to the training process. + +- **DeepSpeed-enabled:** Start training with DeepSpeed by providing the following 4 arguments to this script: + +```shell +bash run_squad_deepspeed.sh ` +``` + +The first argument is the number of GPUs to train with, second argument is the path to the pre-training checkpoint, third is the path to training and validation sets (e.g., train-v1.1.json), and fourth is path to an output folder where the results will be saved. This script will invoke `nvidia_run_squad_deepspeed.py`. + +- **DeepSpeed with 1-bit Adam enabled:** In order to run with 1-bit Adam feature enabled, the same script (`nvidia_run_squad_deepspeed.py`) can be used but there are two options for launching this properly: 1) Launch using deepspeed launcher and 2) Launch with mpirun. + +To enable the 1-bit compressed training, 1-bit Adam uses an MPI library (E.g. MVAPICH2-GDR, OpenMPI, etc.) as the communication backend, which means that we can use `mpirun` to launchg the training job. However, our user-friendly launcher called `deepspeed` has been enhanced to launch MPI jobs as well. + +### Launch with deepspeed + +The following helper script in the DeepSpeedExamples/BingBertSQuAD will launch the training without the need for setting any `mpirun` parameters. + +```shell +bash run_squad_deepspeed_onebitadam.sh +``` + +### Launch with mpirun + +Alternatively, we show how the standard `mpirun` launcher can be used for launching the fine-tuning job. + +```shell +mpirun -np [#processes] -ppn [#GPUs on each node] -hostfile [hostfile] [MPI flags] bash run_squad_deepspeed_onebitadam.sh +``` +For example, in order to use 32 GPUs (4GPUs/node, 8 nodes in total), with the support of InfiniBand, you can use the `mpirun` launcher packaged with the MVAPICH2 library. Please run the folowing command: + +```shell +mpirun -np 32 -ppn 4 -hostfile hosts -env MV2_USE_CUDA=1 -env MV2_SUPPORT_DL=1 -env MV2_ENABLE_AFFINITY=0 -env MV2_SMP_USE_CMA=0 bash run_squad_deepspeed_onebitadam.sh +``` + +### 1.2 Configuration for BingBertSQuAD with DeepSpeed and 1-bit Adam enabled + +The `deepspeed_bsz96_onebit_config.json` file gives the user the ability to specify DeepSpeed +options in terms of batch size, micro batch size, optimizer, learning rate, and other parameters. +When running the `nvidia_run_squad_deepspeed.py`, in addition to the +`--deepspeed` flag to enable DeepSpeed, the appropriate DeepSpeed configuration +file must be specified using `--deepspeed_config deepspeed_bsz96_config.json`. + +Table 1 shows the fine-tuning configuration we used in our experiments. + +| Parameters | Value | +| ------------------------------ | ---------------------| +| Total batch size | 96 | +| Train micro batch size per GPU | 3 | +| Optimizer | **OnebitAdam** | +| Learning rate | 3e-5 | +| Sequence-length | 384 | +| Weight-decay | 0.0 | +| Epoch count | 2 | +| **freeze_step** | 400 | +| **cuda_aware** | True | + +Table 1. Fine-tuning configuration + +### 1.3 Results for BingBertSQuAD Fine-tuning + +The results are summarized in the table below. The total batch size is set to 96 and training is conducted +on 32 GPUs for 2 epochs. A set of parameters (seeds and learning rates) were tried and the best ones were selected. +We fixed the learning rate to 3e-5. The table below shows the F1 and the EM scores we achieved that are on-par or better than the [HuggingFace results](https://github.com/huggingface/transformers/tree/master/examples/question-answering). + +| Case | Model | Precision | EM | F1 | +| ----------- | ------------------------------------- | --------- | ----- | ----- | +| HuggingFace | [Bert-large-uncased-whole-word-masking](https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin) | FP16 | 87.26 | 93.32 | + +**Note:** For more details about loading checkpoint, argument parsing, initialization, forward pass, backward pass, weight update and evaluation, please refer to the [BingBertSQuAD Fine-tuning](https://www.deepspeed.ai/tutorials/bert-finetuning/) tutorial. + + +## 2. BERT Pre-training with 1-bit Adam +For data downloading and pre-processing, please refer to [BERT Pre-training](https://www.deepspeed.ai/tutorials/bert-pretraining/) posts +for more details. + +### 2.1 Running Pre-training with DeepSpeed and 1-bit Adam + +The main part of training is done in `deepspeed_train.py`, which has +already been modified to use DeepSpeed. The `ds_train_bert_onebitadam_bsz4k_seq128.sh` and `ds_train_bert_bsz64k_seq128.sh` are the + shell scripts that +help to invoke training and setup several different hyperparameters relevant +to the training process. + +- **DeepSpeed-enabled:** Start training with DeepSpeed by running the command below: + +```shell +bash ds_train_bert_bsz64k_seq128.sh +``` + +- **DeepSpeed with 1-bit Adam enabled:** In order to run with 1-bit Adam feature enabled, the same script (`deepspeed_train.py`) can be used but there are two options for launching this properly: + +### Launch with deepspeed + +As discussed for BingBertSQuAD fine-tuning, we can simply use the `deepspeed` launcher to launch our BERT pre-training jobs as follows. + +```shell +bash ds_train_bert_onebitadam_bsz4k_seq128.sh +``` + +### Launch with mpirun + +Alternatively, use the following command to launch using `mpirun`. + +```shell +mpirun -np [#processes] -ppn [#GPUs on each node] -hostfile [hostfile] [MPI flags] bash ds_train_bert_onebitadam_bsz4k_seq128.sh +``` + +For example, in order to use 32 GPUs (4GPUs/node, 8 nodes in total), with the support of InfiniBand, you can use MVAPICH2 as the launcher and run the following command: +```shell +mpirun -np 32 -ppn 4 -hostfile hosts -env MV2_USE_CUDA=1 -env MV2_SUPPORT_DL=1 -env MV2_ENABLE_AFFINITY=0 -env MV2_SMP_USE_CMA=0 bash ds_train_bert_onebitadam_bsz4k_seq128.sh +``` + +### 2.2 Configuration for BingBertSQuAD with DeepSpeed and 1-bit Adam enabled + +The `deepspeed_bsz4k_onebit_config_seq128.json` file gives the user the ability to specify DeepSpeed +options in terms of batch size, micro batch size, optimizer, learning rate, and other parameters. + +Below is the DeepSpeed configuration file for running BERT-large pre-training with sequence length of 128. +```json +{ + "train_batch_size": 4096, + "train_micro_batch_size_per_gpu": 64, + "steps_per_print": 1000, + "optimizer": { + "type": "Adam", + "params": { + "lr": 2e-4, + "max_grad_norm": 1.0, + "weight_decay": 0.01, + "bias_correction": false, + "freeze_step": 23000, + "cuda_aware": true + } + }, + "fp16": { + "enabled": true, + "loss_scale": 0, + "initial_scale_power": 16 + } +} +``` +Notice that for BERT-base training (sequence length 128), the suggested freeze_step is 16000. For the rest of the pre-training using sequence 512, we suggest to use a freeze_step of 1500. + +### 2.3 Results for BERT pre-training + +Using 1-bit Adam, we are able to achieve significantly higher througput compared to the original Adam optimizer. We note that increase training speed during the compressed stage enables overall training speedup of up to 3.5x on Ethernet based systems where communication bandwidth is significantly limited. However, we are able to achieve up to 1.7x overall speedup even for the 40 Gigabit InfiniBand QDR based system. Furthermore, it is important to highlight that we are able to achieve feasible BERT pre-training using 1-bit Adam on a significantly smaller batch size of 4k compared to 32k and 64k for the LAMB optimizer. + +Graphs to be added from the blog post ... diff --git a/install.sh b/install.sh index 3dae80b4033b..b1417596263e 100755 --- a/install.sh +++ b/install.sh @@ -239,5 +239,5 @@ else pdsh -w $hosts "python $tmp_wheel_path/basic_install_test.py" echo "Installation is successful" fi - pdsh -w $hosts "if [ -d $tmp_wheel_path ]; then rm $tmp_wheel_path/*.whl $tmp_wheel_path/basic_install_test.py $tmp_wheel_path/requirements.txt; rmdir $tmp_wheel_path; fi" + pdsh -w $hosts "if [ -d $tmp_wheel_path ]; then rm $tmp_wheel_path/*.whl $tmp_wheel_path/basic_install_test.py; rmdir $tmp_wheel_path; fi" fi diff --git a/requirements/requirements-1bit-adam.txt b/requirements/requirements-1bit-adam.txt new file mode 100644 index 000000000000..66c5ba0468f8 --- /dev/null +++ b/requirements/requirements-1bit-adam.txt @@ -0,0 +1 @@ +mpi4py diff --git a/setup.py b/setup.py index 3d57e8791da8..36e6fc0dfa05 100755 --- a/setup.py +++ b/setup.py @@ -27,6 +27,11 @@ def fetch_requirements(path): dev_requires = fetch_requirements('requirements/requirements-dev.txt') sparse_attn_requires = fetch_requirements('requirements/requirements-sparse-attn.txt') +onebit_adam_requires = fetch_requirements('requirements/requirements-1bit-adam.txt') +if torch.cuda.is_available(): + onebit_adam_requires.append(f"cupy-cuda{torch.version.cuda.replace('.','')[:3]}") +install_requires += onebit_adam_requires + # Build environment variables for custom builds DS_BUILD_LAMB_MASK = 1 DS_BUILD_TRANSFORMER_MASK = 10 @@ -227,7 +232,7 @@ def command_exists(cmd): description='DeepSpeed library', author='DeepSpeed Team', author_email='deepspeed@microsoft.com', - url='http://aka.ms/deepspeed', + url='http://deepspeed.ai', install_requires=install_requires, packages=find_packages(exclude=["docker", "third_party", diff --git a/tests/onebitadam/test_com_reduce_cuda.py b/tests/onebitadam/test_com_reduce_cuda.py new file mode 100644 index 000000000000..a5a87ce67232 --- /dev/null +++ b/tests/onebitadam/test_com_reduce_cuda.py @@ -0,0 +1,86 @@ +from mpi4py import MPI +import time +import torch +import torch.distributed as dist +import numpy as np +import deepspeed +from deepspeed.runtime.fp16.onebit_adam import OnebitAdam + +comm = MPI.COMM_WORLD +size = comm.Get_size() +rank = comm.Get_rank() + +#TODO: Detect the hostname we are running on automatically +torch.distributed.init_process_group(backend='nccl', + init_method='tcp://worker-1:2245', + world_size=size, + rank=rank) + +dummy_model = [torch.nn.Parameter(torch.ones(10))] + +# Set cuda_aware to True to use CUDA buffers for communication +dummy_optim = OnebitAdam(dummy_model, cuda_aware=True) + +device = torch.device('cuda', rank % torch.cuda.device_count()) + + +def torch_sim(a): + a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0) + scale = a.norm() / np.sqrt(a.numel()) + a_compressed = scale * a_sign + a_sign = None + worker_error = a - a_compressed + dist.all_reduce(a_compressed) + a_compressed.mul_(1 / dist.get_world_size()) + a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(2.0) + a_list = torch.chunk(a_compressed, chunks=dist.get_world_size()) + server_scale = [chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list] + a_sign_list = torch.chunk(a_server_sign, dist.get_world_size()) + a_server_compressed = torch.cat( + [server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())]) + rank = dist.get_rank() + server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank] + torch.cuda.synchronize() + torch.distributed.barrier() + return a_server_compressed, worker_error, server_error + + +tensor_size = 100 * 2**20 +server_size = int(tensor_size / size) +if tensor_size % (8 * size) != 0: + right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size))) +else: + right_tensor_size = tensor_size +right_server_size = right_tensor_size // size +# Adding bias to the initialization of the gradient we are communicating +# In order to get rid of the case where some elements in the gradient are too small +a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank +worker_error = torch.zeros(right_tensor_size, device=device) +server_error = torch.zeros(right_server_size, device=device) +a_torch, worker_error_torch, server_error_torch = torch_sim(a) +torch.cuda.empty_cache() +local_rank = rank % torch.cuda.device_count() +a_after = dummy_optim.Compressed_Allreduce(a, + worker_error, + server_error, + rank, + size, + comm, + local_rank) +threshold = 1e-6 +magnitude_threshold = 1e-6 +diff_mask = (a_after - a_torch) > threshold +diff_server_mask = torch.chunk(diff_mask, size)[rank] +mpi_server = torch.chunk(a_after, size)[rank] + server_error +torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch + +# If the number in the compensated_server_m is too small (e.g 1e-8), then calling sign() might be problematic +# The test would skip those numbers that are too small in compensated_server_m +if torch.sum(diff_server_mask) == 0: + print('Successfully passed the test for 1bit Adam at Rank {}'.format(rank)) +else: + check_mag_mask = mpi_server[diff_mask] > magnitude_threshold + if torch.sum(check_mag_mask) == 0: + print('Successfully passed the test for 1bit Adam at Rank {}'.format(rank)) + else: + print('Fails at {} of positions'.format(torch.sum(check_mag_mask))) diff --git a/tests/onebitadam/test_com_reduce_host.py b/tests/onebitadam/test_com_reduce_host.py new file mode 100644 index 000000000000..1507abc44f24 --- /dev/null +++ b/tests/onebitadam/test_com_reduce_host.py @@ -0,0 +1,86 @@ +from mpi4py import MPI +import time +import torch +import torch.distributed as dist +import numpy as np +import deepspeed +from deepspeed.runtime.fp16.onebit_adam import OnebitAdam + +comm = MPI.COMM_WORLD +size = comm.Get_size() +rank = comm.Get_rank() + +#TODO: Detect the hostname we are running on automatically +torch.distributed.init_process_group(backend='nccl', + init_method='tcp://worker-1:2245', + world_size=size, + rank=rank) + +dummy_model = [torch.nn.Parameter(torch.ones(10))] + +# Set cuda_aware to False to use host buffers for communication +dummy_optim = OnebitAdam(dummy_model, cuda_aware=False) + +device = torch.device('cuda', rank % torch.cuda.device_count()) + + +def torch_sim(a): + a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0) + scale = a.norm() / np.sqrt(a.numel()) + a_compressed = scale * a_sign + a_sign = None + worker_error = a - a_compressed + dist.all_reduce(a_compressed) + a_compressed.mul_(1 / dist.get_world_size()) + a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(2.0) + a_list = torch.chunk(a_compressed, chunks=dist.get_world_size()) + server_scale = [chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list] + a_sign_list = torch.chunk(a_server_sign, dist.get_world_size()) + a_server_compressed = torch.cat( + [server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())]) + rank = dist.get_rank() + server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank] + torch.cuda.synchronize() + torch.distributed.barrier() + return a_server_compressed, worker_error, server_error + + +tensor_size = 100 * 2**20 +server_size = int(tensor_size / size) +if tensor_size % (8 * size) != 0: + right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size))) +else: + right_tensor_size = tensor_size +right_server_size = right_tensor_size // size +# Adding bias to the initialization of the gradient we are communicating +# In order to get rid of the case where some elements in the gradient are too small +a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank +worker_error = torch.zeros(right_tensor_size, device=device) +server_error = torch.zeros(right_server_size, device=device) +a_torch, worker_error_torch, server_error_torch = torch_sim(a) +torch.cuda.empty_cache() +local_rank = rank % torch.cuda.device_count() +a_after = dummy_optim.Compressed_Allreduce(a, + worker_error, + server_error, + rank, + size, + comm, + local_rank) +threshold = 1e-6 +magnitude_threshold = 1e-6 +diff_mask = (a_after - a_torch) > threshold +diff_server_mask = torch.chunk(diff_mask, size)[rank] +mpi_server = torch.chunk(a_after, size)[rank] + server_error +torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch + +# If the number in the compensated_server_m is too small (e.g 1e-8), then calling sign() might be problematic +# The test would skip those numbers that are too small in compensated_server_m +if torch.sum(diff_server_mask) == 0: + print('Successfully passed the test for 1bit Adam at Rank {}'.format(rank)) +else: + check_mag_mask = mpi_server[diff_mask] > magnitude_threshold + if torch.sum(check_mag_mask) == 0: + print('Successfully passed the test for 1bit Adam at Rank {}'.format(rank)) + else: + print('Fails at {} of positions'.format(torch.sum(check_mag_mask))) diff --git a/tests/onebitadam/test_server_error.py b/tests/onebitadam/test_server_error.py new file mode 100644 index 000000000000..075145f84915 --- /dev/null +++ b/tests/onebitadam/test_server_error.py @@ -0,0 +1,87 @@ +from mpi4py import MPI +import time +import torch +import torch.distributed as dist +import numpy as np +import deepspeed +from deepspeed.runtime.fp16.onebit_adam import OnebitAdam + +comm = MPI.COMM_WORLD +size = comm.Get_size() +rank = comm.Get_rank() + +torch.distributed.init_process_group(backend='nccl', + init_method='tcp://worker-0:2245', + world_size=size, + rank=rank) + +dummy_model = [torch.nn.Parameter(torch.ones(10))] +dummy_optim = OnebitAdam(dummy_model, cuda_aware=False) + +device = torch.device('cuda', rank % torch.cuda.device_count()) + + +def torch_sim(a): + a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0) + scale = a.norm() / np.sqrt(a.numel()) + a_compressed = scale * a_sign + a_sign = None + worker_error = a - a_compressed + dist.all_reduce(a_compressed) + a_compressed.mul_(1 / dist.get_world_size()) + a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(2.0) + a_list = torch.chunk(a_compressed, chunks=dist.get_world_size()) + server_scale = [chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list] + a_sign_list = torch.chunk(a_server_sign, dist.get_world_size()) + a_server_compressed = torch.cat( + [server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())]) + rank = dist.get_rank() + server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank] + torch.cuda.synchronize() + torch.distributed.barrier() + return a_server_compressed, worker_error, server_error + + +# Input Tensor size +tensor_size = 100 * 2**20 + +server_size = int(tensor_size / size) +if tensor_size % (8 * size) != 0: + right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size))) +else: + right_tensor_size = tensor_size + +right_server_size = right_tensor_size // size + +# The -0.5 is required for avoiding sign flips/errors +a = torch.rand(tensor_size, device=device) - 0.5 + +worker_error = torch.zeros(right_tensor_size, device=device) +server_error = torch.zeros(right_server_size, device=device) +a_torch, worker_error_torch, server_error_torch = torch_sim(a) +torch.cuda.empty_cache() +local_rank = rank % torch.cuda.device_count() + +# Test the 1-bit Adam optimizer +a_after = dummy_optim.Compressed_Allreduce(a, + worker_error, + server_error, + rank, + size, + comm, + local_rank) + +# If the error is below the threshold, it is acceptable for training +threshold = 1e-6 + +diff_pos = ((a_after - a_torch) > threshold) + +if rank == 0: + before_diff = torch.chunk(a_after - a_torch, + size)[rank] + server_error - server_error_torch + if torch.norm(before_diff) / torch.norm(torch.chunk(a_after, + size)[rank]) < threshold: + print('Successfully passed the test') + else: + print('The difference for the tensor before allgather is {}'.format( + torch.norm(before_diff)))