From fb0dc00f88996dc0c220bacbea5e2c66a7a4542d Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Fri, 18 Feb 2022 13:21:09 +0500 Subject: [PATCH 01/29] bf16 updates --- .gitignore | 1 + deepspeed/runtime/bf16_optimizer.py | 160 ++++++++++++++++++++++++++++ deepspeed/runtime/config.py | 1 - deepspeed/runtime/engine.py | 109 +++++++++++++------ deepspeed/runtime/pipe/engine.py | 31 +++++- deepspeed/runtime/utils.py | 19 ++++ 6 files changed, 284 insertions(+), 37 deletions(-) create mode 100644 deepspeed/runtime/bf16_optimizer.py diff --git a/.gitignore b/.gitignore index 84340857f802..ab364ad8a7e7 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ *.swp *.log deepspeed/git_version_info_installed.py +__pycache__ # Build + installation data build/ diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py new file mode 100644 index 000000000000..db7e3558b027 --- /dev/null +++ b/deepspeed/runtime/bf16_optimizer.py @@ -0,0 +1,160 @@ +import torch +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + +from deepspeed.runtime.utils import get_grad_norm, clip_gradients + + +class BF16_Optimizer: + def __init__(self, + init_optimizer, + mpu=None, + clip_grad=0.0, + norm_type=2, + timers=None): + super().__init__() + self.timers = timers + self.optimizer = init_optimizer + self.clip_grad = clip_grad + self.norm_type = norm_type + self.mpu = mpu + + # Build BF16/FP32 groups + self.bf16_groups = [] + self.fp32_groups = [] + for i, param_group in enumerate(self.optimizer.param_groups): + # grab the original list + self.bf16_groups.append(param_group['params']) + + fp32_group = [p.clone().float().detach() for p in param_group['params']] + for p in fp32_group: + p.requires_grad = True + + # Ensure model parallel attributes are carried over + for lp, hp in zip(param_group['params'], fp32_group): + if hasattr(lp, 'model_parallel'): + hp.model_parallel = lp.model_parallel + if hasattr(lp, '_pipe_replicated'): + hp._pipe_replicated = lp._pipe_replicated + + self.fp32_groups.append(fp32_group) + param_group['params'] = self.fp32_groups[i] + + self.initialize_optimizer_states() + + def initialize_optimizer_states(self): + """Take an optimizer step with zero-valued gradients to allocate internal + optimizer state. + + This helps prevent memory fragmentation by allocating optimizer state at the + beginning of training instead of after activations have been allocated. + """ + for group in self.fp32_groups: + for param in group: + param.grad = torch.zeros(param.size(), + device=param.device, + dtype=param.dtype) + + self.optimizer.step() + self.clear_hp_grads() + + @torch.no_grad() + def step(self, closure=None): + if closure is not None: + raise NotImplementedError(f'{self.__class__} does not support closure.') + + params = self.get_fp32_params(filter_nograd=True) + all_groups_norm = get_grad_norm(parameters=params, + mpu=self.mpu, + norm_type=self.norm_type) + self._global_grad_norm = all_groups_norm + + assert all_groups_norm > 0. 
+ if self.clip_grad > 0.: + clip_gradients(parameters=params, + max_norm=self.clip_grad, + mpu=self.mpu, + global_grad_norm=all_groups_norm) + + self.optimizer.step() + + self.clear_hp_grads() + self.update_lp_params() + + def get_fp32_params(self, filter_nograd=False): + params = [] + for group in self.fp32_groups: + for param in group: + if filter_nograd and param.grad is not None: + params.append(param) + return params + + def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs): + """Perform a backward pass and copy the low-precision gradients to the + high-precision copy. + + We copy/accumulate to the high-precision grads now to prevent accumulating in the + bf16 grads after successive backward() calls (i.e., grad accumulation steps > 1) + + The low-precision grads are deallocated during this procedure. + """ + self.clear_lp_grads() + loss.backward(**bwd_kwargs) + + if update_hp_grads: + self.update_hp_grads(clear_lp_grads=clear_lp_grads) + + @torch.no_grad() + def update_hp_grads(self, clear_lp_grads=False): + for i, group in enumerate(self.bf16_groups): + for lp, hp in zip(group, self.fp32_groups[i]): + if lp.grad is None: + continue + + data_type = hp.dtype + + if hp.grad is None: + hp.grad = lp.grad.to(data_type) + # give the model parameter access to the hp grad as well + lp._hp_grad = hp.grad + else: + hp.grad.data.add_(lp.grad.data.to(data_type)) + + # clear gradients + if clear_lp_grads: + lp.grad = None + + def update_lp_params(self): + for i, group in enumerate(self.bf16_groups): + for lp, hp in zip(group, self.fp32_groups[i]): + lp.data.copy_(hp.data.to(lp.dtype)) + + def clear_hp_grads(self): + for group in self.fp32_groups: + for param in group: + param.grad = None + + def clear_lp_grads(self): + for group in self.bf16_groups: + for param in group: + param.grad = None + + def state_dict(self): + state_dict = {} + state_dict['optimizer_state_dict'] = self.optimizer.state_dict() + state_dict['fp32_groups'] = self.fp32_groups + state_dict['clip_grad'] = self.clip_grad + return state_dict + + def load_state_dict(self, state_dict, load_optimizer_states=True): + if load_optimizer_states: + self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) + self.clip_grad = state_dict['clip_grad'] + + for i in range(len(self.fp32_groups)): + for current, saved in zip(self.fp32_groups[i], state_dict['fp32_groups'][i]): + current.data.copy_(saved.data) + + @property + def param_groups(self): + """Forward the wrapped optimizer's parameters.""" + return self.optimizer.param_groups \ No newline at end of file diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 009f17557489..b4f4749ebf10 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -899,7 +899,6 @@ def _initialize_params(self, param_dict): self.fp16_enabled = get_fp16_enabled(param_dict) self.bfloat16_enabled = get_bfloat16_enabled(param_dict) assert not (self.fp16_enabled and self.bfloat16_enabled), 'bfloat16 and fp16 modes cannot be simultaneously enabled' - assert not (self.bfloat16_enabled and (self.zero_optimization_stage not in {1, 2, 3})), f'bfloat16 mode is only enabled for Zero 1,2,3 currently. 
got {self.zero_optimization_stage}' self.fp16_master_weights_and_gradients = get_fp16_master_weights_and_grads_enabled( param_dict) self.amp_enabled = get_amp_enabled(param_dict) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index e496589fba0d..44dcb20d1dc3 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -32,6 +32,7 @@ ) from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer +from deepspeed.runtime.bf16_optimizer import BF16_Optimizer from deepspeed.runtime.config import DeepSpeedConfig, DEEPSPEED_OPTIMIZERS, \ ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \ @@ -1002,6 +1003,8 @@ def _configure_distributed_model(self, model): hasattr(param, 'ds_id') for param in self.module.parameters()): self.__check_params(self.module, torch.bfloat16) + if self.zero_optimization_stage() == 0 and not self.pipeline_parallelism: + raise NotImplementedError("BF16 support is not yet implemented when not running ZeRO") self.module.bfloat16() else: self.__check_params(self.module, torch.float) @@ -1155,6 +1158,8 @@ def _configure_optimizer(self, client_optimizer, model_parameters): # TODO: maybe need to broadcast experts differently? elif self.fp16_enabled(): self.optimizer = self._configure_fp16_optimizer(basic_optimizer) + elif self.bfloat16_enabled(): + self.optimizer = self._configure_bf16_optimizer(basic_optimizer) else: self.optimizer = basic_optimizer log_dist("DeepSpeed Final Optimizer = {}".format(self.optimizer_name()), @@ -1324,6 +1329,25 @@ def _configure_fp16_optimizer(self, optimizer): return optimizer + def _configure_bf16_optimizer(self, optimizer): + clip_grad = self.gradient_clipping() + if APEX_INSTALLED: + fused_opts = (apex.optimizers.FusedAdam, FusedAdam) + else: + fused_opts = FusedAdam + if isinstance(optimizer, fused_opts): + if self.global_rank == 0: + logger.info('Creating unfused BF16 optimizer') + timers = self.timers if self.wall_clock_breakdown() else None + optimizer = BF16_Optimizer(optimizer, + mpu=self.mpu, + clip_grad=clip_grad, + timers=timers) + else: + raise NotImplementedError('BF16 requires a fused optimizer for now.') + + return optimizer + def _configure_zero_optimizer(self, optimizer): zero_stage = self.zero_optimization_stage() log_dist('Creating fp16 ZeRO stage {} optimizer'.format(zero_stage), ranks=[0]) @@ -1731,8 +1755,20 @@ def backward(self, loss, allreduce_gradients=True, release_loss=False): self._start_timers(self.engine_timers.backward_reduce_timers) - if self.enable_backward_allreduce: - self.allreduce_gradients() + if allreduce_gradients and self.enable_backward_allreduce: + if self.bfloat16_enabled(): + # Make our own list of gradients from the optimizer's FP32 grads + grads = [] + for param_group in self.optimizer.fp32_groups: + for param in param_group: + assert param.grad is not None + assert param.grad.dtype == torch.float32 + grads.append(param.grad.data) + print(f'rank={self.global_rank} {len(grads)=}') + self.buffered_allreduce_fallback(grads=grads) + else: + # Traditional code path that allreduces the module parameter grads + self.allreduce_gradients() self._stop_timers(self.engine_timers.backward_reduce_timers) @@ -1799,8 +1835,8 @@ def clip_fp32_gradients(self): def _take_model_step(self, lr_kwargs, block_eigenvalue={}): if self.gradient_clipping() > 0.0: - if not (self.fp16_enabled() or self.amp_enabled() - or self.zero_optimization()): + 
if not (self.fp16_enabled() or self.bfloat16_enabled() + or self.amp_enabled() or self.zero_optimization()): self.clip_fp32_gradients() elif self.amp_enabled(): # AMP's recommended way of doing clipping @@ -1832,7 +1868,7 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}): if (not self.zero_optimization() and not self.fp16_enabled() and not self.amp_enabled()): self.zero_grad() - else: + elif not self.bfloat16_enabled(): self.optimizer.zero_grad() report_progress = self.global_rank == 0 if self.global_rank else True @@ -2126,42 +2162,45 @@ def allreduce_no_retain(self, bucket, dp_group, numel_per_bucket=500000000): self.allreduce_and_copy(small_bucket, dp_group) def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000): - grads = [] - expert_grads = {} - if self.has_moe_layers: - for key in self.expert_data_parallel_group.keys(): - expert_grads[key] = [] + if grads is not None: + assert not self.has_moe_layers, "attempting to reduce grads in unsupported way w.r.t. MoE" + else: + grads = [] + expert_grads = {} + if self.has_moe_layers: + for key in self.expert_data_parallel_group.keys(): + expert_grads[key] = [] - for param_name, param in self.module.named_parameters(): - if hasattr(param, 'allreduce') and not param.allreduce: - is_moe_param = True - else: - is_moe_param = False - if param.grad is None: - # In cases where there is an imbalance of empty grads across - # ranks we must create empty grads, this will ensure that every - # rank is reducing the same size. In some cases it may make - # sense in the future to support the ability to average not - # w.r.t. world size but with a different value. - param.grad = torch.zeros(param.size(), - dtype=param.dtype, - device=param.device) - if is_moe_param: - expert_grads[param.group_name].append(param.grad.data) + for param_name, param in self.module.named_parameters(): + if hasattr(param, 'allreduce') and not param.allreduce: + is_moe_param = True else: - grads.append(param.grad.data) - else: - grad_data = param.grad.data - if param_name in self.sparse_tensor_module_names or grad_data.is_sparse: + is_moe_param = False + if param.grad is None: + # In cases where there is an imbalance of empty grads across + # ranks we must create empty grads, this will ensure that every + # rank is reducing the same size. In some cases it may make + # sense in the future to support the ability to average not + # w.r.t. world size but with a different value. 
+ param.grad = torch.zeros(param.size(), + dtype=param.dtype, + device=param.device) if is_moe_param: - expert_grads[param.group_name].append(SparseTensor(grad_data)) + expert_grads[param.group_name].append(param.grad.data) else: - grads.append(SparseTensor(grad_data)) + grads.append(param.grad.data) else: - if is_moe_param: - expert_grads[param.group_name].append(grad_data) + grad_data = param.grad.data + if param_name in self.sparse_tensor_module_names or grad_data.is_sparse: + if is_moe_param: + expert_grads[param.group_name].append(SparseTensor(grad_data)) + else: + grads.append(SparseTensor(grad_data)) else: - grads.append(grad_data) + if is_moe_param: + expert_grads[param.group_name].append(grad_data) + else: + grads.append(grad_data) split_buckets = split_half_float_double_sparse(grads) for _, bucket_tuple in enumerate(split_buckets): diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 258a175ff76f..49562b6cc543 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -246,9 +246,30 @@ def _exec_reduce_tied_grads(self): def _exec_reduce_grads(self): self._force_grad_boundary = True if self.pipeline_enable_backward_allreduce: - self.allreduce_gradients(bucket_size=MEMORY_OPT_ALLREDUCE_SIZE) + if self.bfloat16_enabled(): + if self.zero_optimization_stage() == 0: + self._bf16_reduce_grads() + else: + assert self.zero_optimization_stage() == 1, "only bf16 + z1 are supported" + raise NotImplementedError() + else: + self.allreduce_gradients(bucket_size=MEMORY_OPT_ALLREDUCE_SIZE) self._force_grad_boundary = False + def _bf16_reduce_grads(self): + # Make our own list of gradients from the optimizer's FP32 grads + grads = [] + for param_group in self.optimizer.fp32_groups: + for param in param_group: + if param.grad is None: + continue + assert param.grad is not None + assert param.grad.dtype == torch.float32 + grads.append(param.grad.data) + self.buffered_allreduce_fallback( + grads=grads, + elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE) + def _reserve_pipe_buffers(self, num_buffers): """Ensure that each pipeline buffer has at least ``num_buffers`` slots. @@ -726,6 +747,10 @@ def _exec_backward_pass(self, buffer_id): part_grad = None #print(f'RANK={self.global_rank} BEFORE-BWD restored grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}') + if self.bfloat16_enabled() and not self.is_last_stage(): + # manually call because we don't call optimizer.backward() + self.optimizer.clear_lp_grads() + # This handles either a single tensor or tuple of tensors. if isinstance(outputs, tuple): out_tensors = [t for t in outputs if t.is_floating_point()] @@ -734,6 +759,10 @@ def _exec_backward_pass(self, buffer_id): else: torch.autograd.backward(tensors=(outputs, ), grad_tensors=(grad_tensors, )) + if self.bfloat16_enabled() and not self.is_last_stage(): + # manually call because we don't call optimizer.backward() + self.optimizer.update_hp_grads(clear_lp_grads=False) + # Free up the memory from the output of forward() self.pipe_buffers['output_tensors'][buffer_id] = None self.pipe_buffers['outputs'][buffer_id] = None diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 2decd12ffc9a..2193d9ced3a6 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -867,3 +867,22 @@ def get_only_unique_item(items): unique_item, = item_set return unique_item + + +def clip_gradients(parameters, max_norm=1.0, global_grad_norm=None, mpu=None, eps=1e-6): + """Clip the gradient of a list of parameters. 
+ Args: + parameters: List of parameters whose .grad will be clipped. + global_grad_norm (float, optional): Precomputed gradient norm. Defaults to None. + mpu (optional): model parallelism unit. Defaults to None. + eps (float, optional): epsilon value added to grad norm. Defaults to 1e-6 + Returns: + float: the global gradient norm + """ + if global_grad_norm is None: + global_grad_norm = get_grad_norm(parameters, mpu=mpu) + clip_coef = max_norm / (global_grad_norm + eps) + if clip_coef < 1: + for p in parameters: + p.grad.detach().mul_(clip_coef) + return global_grad_norm \ No newline at end of file From 6eb4f1fa98e9dcd788f18c30235fbf886949d5bc Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Mon, 21 Feb 2022 23:12:47 +0500 Subject: [PATCH 02/29] Got bf16 working --- deepspeed/runtime/bf16_optimizer.py | 44 ++++++++++++++++++-- deepspeed/runtime/engine.py | 45 ++++++-------------- deepspeed/runtime/fp16/fused_optimizer.py | 50 +++++++++++++++-------- deepspeed/runtime/pipe/engine.py | 5 +-- deepspeed/runtime/utils.py | 18 +++++++- deepspeed/runtime/zero/stage_1_and_2.py | 21 ++++------ 6 files changed, 113 insertions(+), 70 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index db7e3558b027..8b8464a27d18 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -1,7 +1,8 @@ import torch -from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +import torch.distributed as dist +from deepspeed.ops.op_builder import UtilsBuilder -from deepspeed.runtime.utils import get_grad_norm, clip_gradients +from deepspeed.runtime.utils import (get_grad_norm, clip_gradients, align_dense_tensors) class BF16_Optimizer: @@ -10,6 +11,7 @@ def __init__(self, mpu=None, clip_grad=0.0, norm_type=2, + dp_process_group=None, timers=None): super().__init__() self.timers = timers @@ -17,13 +19,41 @@ def __init__(self, self.clip_grad = clip_grad self.norm_type = norm_type self.mpu = mpu + self.dp_process_group = dp_process_group + + self.real_dp_process_group = [ + dp_process_group for i in range(len(self.optimizer.param_groups)) + ] + + # Load pre-built or JIT compile (un)flatten ops + util_ops = UtilsBuilder().load() + self.flatten = util_ops.flatten + self.unflatten = util_ops.unflatten + + #align nccl all-gather send buffers to 4-bye boundary + self.nccl_start_alignment_factor = 2 # 4-byte alignment/sizeof(fp16) = 2 # Build BF16/FP32 groups self.bf16_groups = [] + self.bf16_groups_flat = [] self.fp32_groups = [] + self.single_partition_of_fp32_groups = [] + data_parallel_world_size = dist.get_world_size(group=self.dp_process_group) + for i, param_group in enumerate(self.optimizer.param_groups): # grab the original list self.bf16_groups.append(param_group['params']) + self.bf16_groups_flat.append( + self._flatten_dense_tensors_aligned( + self.bf16_groups[i], + self.nccl_start_alignment_factor * + dist.get_world_size(group=self.real_dp_process_group[i]))) + + # self.single_partition_of_fp32_groups.append(self.bf16_groups_flat[i].clone().float().detach()) + # self.single_partition_of_fp32_groups[i].requires_grad = True + + # param_group['params'] = [self.single_partition_of_fp32_groups[i]] + # self._unflatten_dense_tensors(self.bf16_groups[i], self.bf16_groups_flat[i]) fp32_group = [p.clone().float().detach() for p in param_group['params']] for p in fp32_group: @@ -57,6 +87,14 @@ def initialize_optimizer_states(self): self.optimizer.step() self.clear_hp_grads() + def _unflatten_dense_tensors(self, 
tensor_list, flattened_tensors): + updated_params = self.unflatten(flattened_tensors, tensor_list) + for p, q in zip(tensor_list, updated_params): + p.data = q.data + + def _flatten_dense_tensors_aligned(self, tensor_list, alignment): + return self.flatten(align_dense_tensors(tensor_list, alignment)) + @torch.no_grad() def step(self, closure=None): if closure is not None: @@ -157,4 +195,4 @@ def load_state_dict(self, state_dict, load_optimizer_states=True): @property def param_groups(self): """Forward the wrapped optimizer's parameters.""" - return self.optimizer.param_groups \ No newline at end of file + return self.optimizer.param_groups diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 44dcb20d1dc3..b7060fbc334b 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -437,19 +437,6 @@ def set_train_batch_size(self, train_batch_size): self._config.train_batch_size = train_batch_size self._config.gradient_accumulation_steps = new_gas - def get_global_grad_norm(self) -> float: - """Return the 2-norm of all gradients. If there is model parallelism, - the norm will be global. - - The computed norm will be cached and reused until the next step() pass. - .. note:: - In the presence of model parallelism, this is a collective call - and acts as a barrier among ``mpu.get_model_parallel_group()``. - Returns: - float: norm - """ - return self._global_grad_norm - def checkpoint_tag_validation_enabled(self): return self._config.checkpoint_tag_validation_enabled @@ -1004,7 +991,8 @@ def _configure_distributed_model(self, model): 'ds_id') for param in self.module.parameters()): self.__check_params(self.module, torch.bfloat16) if self.zero_optimization_stage() == 0 and not self.pipeline_parallelism: - raise NotImplementedError("BF16 support is not yet implemented when not running ZeRO") + raise NotImplementedError( + "BF16 support is not yet implemented when not running ZeRO") self.module.bfloat16() else: self.__check_params(self.module, torch.float) @@ -1342,6 +1330,7 @@ def _configure_bf16_optimizer(self, optimizer): optimizer = BF16_Optimizer(optimizer, mpu=self.mpu, clip_grad=clip_grad, + dp_process_group=self.data_parallel_group, timers=timers) else: raise NotImplementedError('BF16 requires a fused optimizer for now.') @@ -1745,6 +1734,8 @@ def backward(self, loss, allreduce_gradients=True, release_loss=False): self.optimizer.backward(loss, create_graph=True, retain_graph=True) else: self.optimizer.backward(loss) + elif self.bfloat16_enabled(): + self.optimizer.backward(loss) else: if self.eigenvalue_enabled(): loss.backward(create_graph=True, retain_graph=True) @@ -1756,19 +1747,8 @@ def backward(self, loss, allreduce_gradients=True, release_loss=False): self._start_timers(self.engine_timers.backward_reduce_timers) if allreduce_gradients and self.enable_backward_allreduce: - if self.bfloat16_enabled(): - # Make our own list of gradients from the optimizer's FP32 grads - grads = [] - for param_group in self.optimizer.fp32_groups: - for param in param_group: - assert param.grad is not None - assert param.grad.dtype == torch.float32 - grads.append(param.grad.data) - print(f'rank={self.global_rank} {len(grads)=}') - self.buffered_allreduce_fallback(grads=grads) - else: - # Traditional code path that allreduces the module parameter grads - self.allreduce_gradients() + # Traditional code path that allreduces the module parameter grads + self.allreduce_gradients() self._stop_timers(self.engine_timers.backward_reduce_timers) @@ -1835,8 +1815,8 @@ def 
clip_fp32_gradients(self): def _take_model_step(self, lr_kwargs, block_eigenvalue={}): if self.gradient_clipping() > 0.0: - if not (self.fp16_enabled() or self.bfloat16_enabled() - or self.amp_enabled() or self.zero_optimization()): + if not (self.fp16_enabled() or self.bfloat16_enabled() or self.amp_enabled() + or self.zero_optimization()): self.clip_fp32_gradients() elif self.amp_enabled(): # AMP's recommended way of doing clipping @@ -2183,8 +2163,8 @@ def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000) # sense in the future to support the ability to average not # w.r.t. world size but with a different value. param.grad = torch.zeros(param.size(), - dtype=param.dtype, - device=param.device) + dtype=param.dtype, + device=param.device) if is_moe_param: expert_grads[param.group_name].append(param.grad.data) else: @@ -2193,7 +2173,8 @@ def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000) grad_data = param.grad.data if param_name in self.sparse_tensor_module_names or grad_data.is_sparse: if is_moe_param: - expert_grads[param.group_name].append(SparseTensor(grad_data)) + expert_grads[param.group_name].append( + SparseTensor(grad_data)) else: grads.append(SparseTensor(grad_data)) else: diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py index 86ffc5ab92c0..fdba94c55af7 100755 --- a/deepspeed/runtime/fp16/fused_optimizer.py +++ b/deepspeed/runtime/fp16/fused_optimizer.py @@ -32,11 +32,13 @@ def __init__(self, mpu=None, clip_grad=0.0, fused_adam_legacy=False, + has_moe_layers=False, timers=None): self.fused_adam_legacy = fused_adam_legacy self.timers = timers self.deepspeed = deepspeed + self.has_moe_layers = has_moe_layers self.using_pipeline = self.deepspeed.pipeline_parallelism if not torch.cuda.is_available: raise SystemError("Cannot use fp16 without CUDA.") @@ -167,11 +169,15 @@ def step_fused_adam(self, closure=None): self.cur_scale)) return self.overflow - self._global_grad_norm = get_global_norm(norm_list=norm_groups) + scaled_grad_norm = get_global_norm(norm_list=norm_groups) combined_scale = self.unscale_and_clip_grads(grads_groups_flat, - self._global_grad_norm, + scaled_grad_norm, apply_scale=False) + + # Stash unscaled gradient norm + self._global_grad_norm = scaled_global_grad_norm / self.cur_scale + # norm is in fact norm*cur_scale self.optimizer.step(grads=[[g] for g in grads_groups_flat], output_params=[[p] for p in self.fp16_groups_flat], @@ -258,26 +264,19 @@ def step(self, closure=None): self.start_timers([COMPUTE_NORM]) all_groups_norm = get_grad_norm(self.fp32_groups_flat, mpu=self.mpu) - #all_groups_norm_old = all_groups_norm - # Need to allreduce (avg) the norms across different ranks because moe params will not be synced during allreduce - if self.using_pipeline: - pg = self.deepspeed.mpu.get_data_parallel_group() - else: - pg = groups.get_data_parallel_group() - scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=pg)) - scaled_norm_tensor = torch.tensor(scaled_norm, - device=self.fp32_groups_flat[i].device, - dtype=torch.float) - dist.all_reduce(scaled_norm_tensor, group=pg) - all_groups_norm = scaled_norm_tensor.item() - #print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {torch.distributed.get_rank()}") self.stop_timers([COMPUTE_NORM]) - self._global_grad_norm = get_global_norm(norm_list=[all_groups_norm]) + if self.has_moe_layers: + scaled_global_grad_norm = self._get_norm_with_moe_layers(all_groups_norm) + else: + 
scaled_global_grad_norm = get_global_norm(norm_list=[all_groups_norm]) + + # Stash unscaled gradient norm + self._global_grad_norm = scaled_global_grad_norm / self.cur_scale self.start_timers([UNSCALE_AND_CLIP]) - self.unscale_and_clip_grads(grads_groups_flat, self._global_grad_norm) + self.unscale_and_clip_grads(grads_groups_flat, scaled_global_grad_norm) self.stop_timers([UNSCALE_AND_CLIP]) self.start_timers([BASIC_STEP]) @@ -302,6 +301,23 @@ def step(self, closure=None): return self.overflow + def _get_norm_with_moe_layers(self, all_groups_norm): + total_norm = get_global_norm(norm_list=[all_groups_norm]) + #all_groups_norm_old = all_groups_norm + # Need to allreduce (avg) the norms across different ranks because moe params will not be synced during allreduce + if self.using_pipeline: + pg = self.deepspeed.mpu.get_data_parallel_group() + else: + pg = groups.get_data_parallel_group() + scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=pg)) + scaled_norm_tensor = torch.tensor(scaled_norm, + device=self.fp32_groups_flat[i].device, + dtype=torch.float) + dist.all_reduce(scaled_norm_tensor, group=pg) + all_groups_norm = scaled_norm_tensor.item() + #print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {torch.distributed.get_rank()}") + return all_groups_norm + def unscale_and_clip_grads(self, grad_groups_flat, total_norm, apply_scale=True): # compute combined scale factor for this group combined_scale = self.cur_scale diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 49562b6cc543..fe10064ddc25 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -266,9 +266,8 @@ def _bf16_reduce_grads(self): assert param.grad is not None assert param.grad.dtype == torch.float32 grads.append(param.grad.data) - self.buffered_allreduce_fallback( - grads=grads, - elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE) + self.buffered_allreduce_fallback(grads=grads, + elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE) def _reserve_pipe_buffers(self, num_buffers): """Ensure that each pipeline buffer has at least ``num_buffers`` slots. 
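A note on the data flow these first two patches establish: BF16_Optimizer.backward() runs the usual loss.backward() on the bf16 module, then folds the resulting bf16 gradients into persistent fp32 master gradients (update_hp_grads), and it is those fp32 buffers, not the module's bf16 grads, that the engine all-reduces and the wrapped optimizer steps on. Below is a minimal, self-contained sketch of that accumulation idea; it is not part of the patch, and the tensor names are illustrative rather than the optimizer's real attributes.

import torch

bf16_param = torch.nn.Parameter(torch.randn(4, dtype=torch.bfloat16))
fp32_master = bf16_param.detach().float()           # fp32 copy the inner optimizer would step on
fp32_master_grad = torch.zeros_like(fp32_master)    # persistent fp32 gradient buffer

for _ in range(2):                                   # e.g. gradient_accumulation_steps = 2
    (bf16_param * 2).sum().backward()                # produces a bf16 .grad on the parameter
    fp32_master_grad.add_(bf16_param.grad.float())   # accumulate in fp32, not in bf16
    bf16_param.grad = None                           # drop the low-precision grad right away

Accumulating in fp32 is what lets grad-accumulation steps greater than one avoid the precision loss of repeatedly adding into a bf16 buffer.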
diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 2193d9ced3a6..ea0ea9eb2098 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -885,4 +885,20 @@ def clip_gradients(parameters, max_norm=1.0, global_grad_norm=None, mpu=None, ep if clip_coef < 1: for p in parameters: p.grad.detach().mul_(clip_coef) - return global_grad_norm \ No newline at end of file + return global_grad_norm + + +def align_dense_tensors(tensor_list, alignment): + num_elements = sum(t.numel() for t in tensor_list) + remaining = num_elements % alignment + + if remaining: + elements_to_add = alignment - remaining + pad_tensor = torch.zeros(elements_to_add, + device=tensor_list[0].device, + dtype=tensor_list[0].dtype) + padded_tensor_list = tensor_list + [pad_tensor] + else: + padded_tensor_list = tensor_list + + return padded_tensor_list diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 59020684842b..6b7dd5d594ba 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -9,7 +9,12 @@ from packaging import version as pkg_version from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler -from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank, get_global_norm, see_memory_usage, is_model_parallel_parameter +from deepspeed.runtime.utils import (bwc_tensor_model_parallel_rank, + get_global_norm, + see_memory_usage, + is_model_parallel_parameter, + align_dense_tensors) + from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS from deepspeed.runtime.zero.offload_constants import OFFLOAD_CPU_DEVICE, OFFLOAD_OPTIMIZER from deepspeed.ops.adam import DeepSpeedCPUAdam @@ -791,19 +796,7 @@ def report_ipg_memory_usage(self, tag, param_elems): # create a flat tensor aligned at the alignment boundary def flatten_dense_tensors_aligned(self, tensor_list, alignment): - num_elements = sum(t.numel() for t in tensor_list) - remaining = num_elements % alignment - - if remaining: - elements_to_add = alignment - remaining - pad_tensor = torch.zeros(elements_to_add, - device=tensor_list[0].device, - dtype=tensor_list[0].dtype) - padded_tensor_list = tensor_list + [pad_tensor] - else: - padded_tensor_list = tensor_list - - return self.flatten(padded_tensor_list) + return self.flatten(align_dense_tensors(tensor_list, alignment)) ############### Independent Partition Gradient ######################## def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): From a3d3576e0855cd760d6e6c7a574aa11d0edad5fc Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Thu, 24 Feb 2022 00:07:00 +0500 Subject: [PATCH 03/29] fp32 reduction; flattened tensors --- deepspeed/runtime/bf16_optimizer.py | 145 ++++++++++++++++-------- deepspeed/runtime/constants.py | 5 + deepspeed/runtime/engine.py | 4 + deepspeed/runtime/utils.py | 78 ++++++++++++- deepspeed/runtime/zero/stage_1_and_2.py | 4 +- 5 files changed, 184 insertions(+), 52 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 8b8464a27d18..e6139946a4b2 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -1,8 +1,13 @@ import torch import torch.distributed as dist +from deepspeed.runtime.constants import PIPE_REPLICATED from deepspeed.ops.op_builder import UtilsBuilder -from deepspeed.runtime.utils import (get_grad_norm, clip_gradients, align_dense_tensors) +from deepspeed.runtime.utils import (get_global_norm_of_tensors, + 
clip_tensors_by_global_norm, + get_grad_norm, + clip_gradients, + align_dense_tensors) class BF16_Optimizer: @@ -20,10 +25,7 @@ def __init__(self, self.norm_type = norm_type self.mpu = mpu self.dp_process_group = dp_process_group - - self.real_dp_process_group = [ - dp_process_group for i in range(len(self.optimizer.param_groups)) - ] + self.dp_rank = dist.get_rank(group=self.dp_process_group) # Load pre-built or JIT compile (un)flatten ops util_ops = UtilsBuilder().load() @@ -36,40 +38,67 @@ def __init__(self, # Build BF16/FP32 groups self.bf16_groups = [] self.bf16_groups_flat = [] + # TODO: Need to only track fp32 params of this partition self.fp32_groups = [] + self.fp32_groups_flat = [] self.single_partition_of_fp32_groups = [] - data_parallel_world_size = dist.get_world_size(group=self.dp_process_group) + self.fp32_groups_gradients = [] + self.fp32_groups_gradients_flat = [] + + dp_world_size = dist.get_world_size(group=self.dp_process_group) for i, param_group in enumerate(self.optimizer.param_groups): # grab the original list self.bf16_groups.append(param_group['params']) + + # create flat bf16 params self.bf16_groups_flat.append( self._flatten_dense_tensors_aligned( self.bf16_groups[i], - self.nccl_start_alignment_factor * - dist.get_world_size(group=self.real_dp_process_group[i]))) - - # self.single_partition_of_fp32_groups.append(self.bf16_groups_flat[i].clone().float().detach()) - # self.single_partition_of_fp32_groups[i].requires_grad = True - - # param_group['params'] = [self.single_partition_of_fp32_groups[i]] - # self._unflatten_dense_tensors(self.bf16_groups[i], self.bf16_groups_flat[i]) - - fp32_group = [p.clone().float().detach() for p in param_group['params']] - for p in fp32_group: - p.requires_grad = True - - # Ensure model parallel attributes are carried over - for lp, hp in zip(param_group['params'], fp32_group): - if hasattr(lp, 'model_parallel'): - hp.model_parallel = lp.model_parallel - if hasattr(lp, '_pipe_replicated'): - hp._pipe_replicated = lp._pipe_replicated - - self.fp32_groups.append(fp32_group) - param_group['params'] = self.fp32_groups[i] + self.nccl_start_alignment_factor * dp_world_size)) + + # Make bf16 params point to flat tensor storage + self._update_storage_to_flattened_tensor( + tensor_list=self.bf16_groups[i], + flat_tensor=self.bf16_groups_flat[i]) + + # create flat fp32 params + self.fp32_groups_flat.append( + self.bf16_groups_flat[i].clone().float().detach()) + self.fp32_groups_flat[i].requires_grad = True + + num_elem_list = [t.numel() for t in self.bf16_groups[i]] + + # create fp32 params using flat tensor storage + fp32_group_params = self._split_flat_tensor( + flat_tensor=self.fp32_groups_flat[i], + num_elem_list=num_elem_list) + self._propagate_attributes(src_tensor_list=self.bf16_groups[i], + dst_tensor_list=fp32_group_params) + self.fp32_groups.append(fp32_group_params) + + # create fp32 gradients + self.fp32_groups_gradients_flat.append( + torch.zeros_like(self.fp32_groups_flat[i])) + fp32_gradients = self._split_flat_tensor( + flat_tensor=self.fp32_groups_gradients_flat[i], + num_elem_list=num_elem_list) + self.fp32_groups_gradients.append(fp32_gradients) + + # create fp32 partition from flat tensor storage + assert self.fp32_groups_flat[i].numel() % dp_world_size == 0, \ + f'group {i} flat tensor size {self.fp32_groups_flat[i].numel()} not divisible by DP world size {dp_world_size}' + + partition_size = self.fp32_groups_flat[i].numel() // dp_world_size + self.single_partition_of_fp32_groups.append( + 
torch.narrow(self.fp32_groups_flat[i], + 0, + self.dp_rank * partition_size, + partition_size)) + param_group['params'] = [self.single_partition_of_fp32_groups[i]] self.initialize_optimizer_states() + self._init_hp_grads() def initialize_optimizer_states(self): """Take an optimizer step with zero-valued gradients to allocate internal @@ -78,17 +107,32 @@ def initialize_optimizer_states(self): This helps prevent memory fragmentation by allocating optimizer state at the beginning of training instead of after activations have been allocated. """ - for group in self.fp32_groups: - for param in group: - param.grad = torch.zeros(param.size(), - device=param.device, - dtype=param.dtype) + for i, single_partition in enumerate(self.single_partition_of_fp32_groups): + single_partition.grad = self.fp32_groups_gradients_flat[i] self.optimizer.step() self.clear_hp_grads() - def _unflatten_dense_tensors(self, tensor_list, flattened_tensors): - updated_params = self.unflatten(flattened_tensors, tensor_list) + def _propagate_attributes(self, src_tensor_list, dst_tensor_list): + for src_tensor, dst_tensor in zip(src_tensor_list, dst_tensor_list): + if hasattr(src_tensor, 'model_parallel'): + dst_tensor.model_parallel = src_tensor.model_parallel + if hasattr(src_tensor, PIPE_REPLICATED): + dst_tensor.ds_pipe_replicated = src_tensor.ds_pipe_replicated + + def _split_flat_tensor(self, flat_tensor, num_elem_list): + assert sum(num_elem_list) <= flat_tensor.numel() + tensor_list = [] + offset = 0 + for num_elem in num_elem_list: + dense_tensor = torch.narrow(flat_tensor, 0, offset, num_elem) + tensor_list.append(dense_tensor) + offset += num_elem + + return tensor_list + + def _update_storage_to_flattened_tensor(self, tensor_list, flat_tensor): + updated_params = self.unflatten(flat_tensor, tensor_list) for p, q in zip(tensor_list, updated_params): p.data = q.data @@ -144,32 +188,39 @@ def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwarg @torch.no_grad() def update_hp_grads(self, clear_lp_grads=False): for i, group in enumerate(self.bf16_groups): - for lp, hp in zip(group, self.fp32_groups[i]): + for j, (lp, hp) in enumerate(zip(group, self.fp32_groups[i])): if lp.grad is None: continue - data_type = hp.dtype + assert hp.grad is not None, \ + f'high precision param has no gradient, param_id = {id(hp)} group_info = [{i}][{j}]' - if hp.grad is None: - hp.grad = lp.grad.to(data_type) - # give the model parameter access to the hp grad as well - lp._hp_grad = hp.grad - else: - hp.grad.data.add_(lp.grad.data.to(data_type)) + hp.grad.data.add_(lp.grad.data.to(hp.dtype).view(hp.shape)) + lp._hp_grad = hp.grad # clear gradients if clear_lp_grads: lp.grad = None + @torch.no_grad() + def get_grads_for_reduction(self): + return self.fp32_groups_gradients_flat + + @torch.no_grad() def update_lp_params(self): for i, group in enumerate(self.bf16_groups): for lp, hp in zip(group, self.fp32_groups[i]): - lp.data.copy_(hp.data.to(lp.dtype)) + lp.data.copy_(hp.data.to(lp.dtype).view(lp.shape)) + + @torch.no_grad() + def _init_hp_grads(self): + for i, group in enumerate(self.bf16_groups): + for j, (lp, hp) in enumerate(zip(group, self.fp32_groups[i])): + hp.grad = self.fp32_groups_gradients[i][j] def clear_hp_grads(self): - for group in self.fp32_groups: - for param in group: - param.grad = None + for flat_gradients in self.fp32_groups_gradients_flat: + flat_gradients.zero_() def clear_lp_grads(self): for group in self.bf16_groups: diff --git a/deepspeed/runtime/constants.py 
b/deepspeed/runtime/constants.py index 2d16f39433c3..ee2e51c6109f 100755 --- a/deepspeed/runtime/constants.py +++ b/deepspeed/runtime/constants.py @@ -446,3 +446,8 @@ class ValidationMode: ''' DATALOADER_DROP_LAST = "dataloader_drop_last" DATALOADER_DROP_LAST_DEFAULT = False + +######################################### +# PIPELINE PARALLELISM +######################################### +PIPE_REPLICATED = 'ds_pipe_replicated' diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index b7060fbc334b..4d16a31d8163 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1674,6 +1674,10 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): if self.zero_optimization_stage() == ZERO_OPTIMIZATION_OPTIMIZER_STATES: self.optimizer.reduce_gradients( pipeline_parallel=self.pipeline_parallelism) + elif self.bfloat16_enabled() and self.pipeline_parallelism: + self.buffered_allreduce_fallback( + grads=self.optimizer.get_grads_for_reduction(), + elements_per_buffer=bucket_size) else: self.buffered_allreduce_fallback(elements_per_buffer=bucket_size) diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index ea0ea9eb2098..4aaf5b59e8de 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -6,6 +6,7 @@ Helper functions and classes from multiple sources. ''' +from collections.abc import Iterable from deepspeed.moe.utils import is_moe_param, split_params_into_shared_and_expert_params import os import psutil @@ -19,6 +20,7 @@ import torch.distributed as dist from deepspeed.utils import groups, logger +from deepspeed.runtime.constants import PIPE_REPLICATED from numpy import prod # pt-1.9 deprecations @@ -423,7 +425,7 @@ def get_grad_norm(parameters, norm_type=2, mpu=None): tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu) for p in parameters: # Pipeline parallelism may replicate parameters. Avoid multi-counting. - if hasattr(p, 'ds_pipe_replicated') and p.ds_pipe_replicated: + if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated: continue # Filter to avoid over-counting replicated tensors from tensor @@ -469,7 +471,7 @@ def get_grad_zeros(parameters, mpu=None): tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu) for p in parameters: # Pipeline parallelism may replicate parameters. Avoid multi-counting. - if hasattr(p, 'ds_pipe_replicated') and p.ds_pipe_replicated: + if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated: continue # Filter to avoid over-counting replicated tensors from tensor @@ -526,7 +528,7 @@ def get_weight_norm(parameters, norm_type=2, mpu=None): tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu) for p in parameters: # Pipeline parallelism may replicate parameters. Avoid multi-counting. - if hasattr(p, 'ds_pipe_replicated') and p.ds_pipe_replicated: + if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated: continue # Filter to avoid over-counting replicated tensors from tensor @@ -888,6 +890,76 @@ def clip_gradients(parameters, max_norm=1.0, global_grad_norm=None, mpu=None, ep return global_grad_norm +def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None): + """Get norm of an iterable of tensors. + + This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and + added functionality to handle model parallel parameters. Taken from Nvidia Megatron. + + Arguments: + input_tensors (Iterable[Tensor]): an iterable of Tensors will have norm computed + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. 
+ + Returns: + Total norm of the tensors (viewed as a single vector). + """ + + assert isinstance(input_tensors, Iterable), f'expected Iterable type not {type(input_tensors)}' + assert all([torch.is_tensor(t) for t in input_tensors]), f'expected list of only tensors' + + norm_type = float(norm_type) + if norm_type == inf: + total_norm = max(t.data.abs().max() for t in input_tensors) + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + if mpu is not None: + torch.distributed.all_reduce(total_norm_cuda, + op=torch.distributed.ReduceOp.MAX, + group=mpu.get_model_parallel_group()) + total_norm = total_norm_cuda[0].item() + else: + total_norm = sum( + [t.data.float().norm(norm_type).item()**norm_type for t in input_tensors]) + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + if mpu is not None: + torch.distributed.all_reduce(total_norm_cuda, + op=torch.distributed.ReduceOp.SUM, + group=mpu.get_model_parallel_group()) + total_norm = total_norm_cuda[0].item()**(1. / norm_type) + + if total_norm == float( + 'inf') or total_norm == -float('inf') or total_norm != total_norm: + total_norm = -1 + + return total_norm + + +def clip_tensors_by_global_norm(input_tensors, + max_norm=1.0, + global_norm=None, + mpu=None, + eps=1e-6): + """Clip list of tensors by global norm. + Args: + input_tensors: List of tensors to be clipped + global_grad_norm (float, optional): Precomputed norm. Defaults to None. + mpu (optional): model parallelism unit. Defaults to None. + eps (float, optional): epsilon value added to grad norm. Defaults to 1e-6 + Returns: + float: the global norm + """ + if global_norm is None: + global_norm = get_global_norm_of_tensors(input_tensors, mpu=mpu) + + clip_coef = max_norm / (global_grad_norm + eps) + + if clip_coef < 1: + for t in input_tensors: + t.detach().mul_(clip_coef) + + return global_norm + + def align_dense_tensors(tensor_list, alignment): num_elements = sum(t.numel() for t in tensor_list) remaining = num_elements % alignment diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 6b7dd5d594ba..b48ab6281d54 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -1107,7 +1107,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): norm_type = 2.0 for p in params: # Pipeline parallelism may replicate parameters. Avoid multi-counting. - if hasattr(p, 'ds_pipe_replicated') and p.ds_pipe_replicated: + if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated: continue if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): @@ -1519,7 +1519,7 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2): # logger.info(f"Total Norm beginning {total_norm}") for g, p in zip(gradients, params): # Pipeline parallelism may replicate parameters. Avoid multi-counting. 
- if hasattr(p, 'ds_pipe_replicated') and p.ds_pipe_replicated: + if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated: continue if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): param_norm = g.data.double().norm(2) From 6f5ebc37af0331fa743877555f917b544643ef4f Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Thu, 24 Feb 2022 07:45:10 +0500 Subject: [PATCH 04/29] bf16+zero_stage_1 first cut --- deepspeed/runtime/bf16_optimizer.py | 111 +++++++++++++++++------- deepspeed/runtime/engine.py | 12 +-- deepspeed/runtime/utils.py | 41 +++++++++ deepspeed/runtime/zero/stage_1_and_2.py | 77 ++++++++-------- deepspeed/runtime/zero/utils.py | 8 +- 5 files changed, 173 insertions(+), 76 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index e6139946a4b2..5aa41ac1eb20 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -7,7 +7,9 @@ clip_tensors_by_global_norm, get_grad_norm, clip_gradients, - align_dense_tensors) + align_dense_tensors, + all_gather_dp_groups, + see_memory_usage) class BF16_Optimizer: @@ -16,16 +18,22 @@ def __init__(self, mpu=None, clip_grad=0.0, norm_type=2, + allgather_bucket_size=5000000000, dp_process_group=None, timers=None): super().__init__() + see_memory_usage('begin bf16_optimizer', force=True) self.timers = timers self.optimizer = init_optimizer self.clip_grad = clip_grad self.norm_type = norm_type self.mpu = mpu + self.allgather_bucket_size = int(allgather_bucket_size) self.dp_process_group = dp_process_group self.dp_rank = dist.get_rank(group=self.dp_process_group) + self.real_dp_process_group = [ + dp_process_group for i in range(len(self.optimizer.param_groups)) + ] # Load pre-built or JIT compile (un)flatten ops util_ops = UtilsBuilder().load() @@ -38,16 +46,26 @@ def __init__(self, # Build BF16/FP32 groups self.bf16_groups = [] self.bf16_groups_flat = [] + self.bf16_partitioned_groups = [] + # TODO: Need to only track fp32 params of this partition self.fp32_groups = [] self.fp32_groups_flat = [] - self.single_partition_of_fp32_groups = [] + self.fp32_groups_flat_partition = [] + + # Maintain different fp32 gradients views for convenience self.fp32_groups_gradients = [] self.fp32_groups_gradients_flat = [] + self.fp32_groups_actual_gradients_flat = [] + self.fp32_groups_gradient_flat_partition = [] dp_world_size = dist.get_world_size(group=self.dp_process_group) for i, param_group in enumerate(self.optimizer.param_groups): + see_memory_usage(f'before initializing group {i}', force=True) + + partition_id = dist.get_rank(group=self.real_dp_process_group[i]) + # grab the original list self.bf16_groups.append(param_group['params']) @@ -57,6 +75,16 @@ def __init__(self, self.bf16_groups[i], self.nccl_start_alignment_factor * dp_world_size)) + # divide flat weights into equal sized partitions + partition_size = self.bf16_groups_flat[i].numel() // dp_world_size + bf16_dp_partitions = [ + self.bf16_groups_flat[i].narrow(0, + dp_index * partition_size, + partition_size) + for dp_index in range(dp_world_size) + ] + self.bf16_partitioned_groups.append(bf16_dp_partitions) + # Make bf16 params point to flat tensor storage self._update_storage_to_flattened_tensor( tensor_list=self.bf16_groups[i], @@ -80,25 +108,46 @@ def __init__(self, # create fp32 gradients self.fp32_groups_gradients_flat.append( torch.zeros_like(self.fp32_groups_flat[i])) + fp32_gradients = self._split_flat_tensor( flat_tensor=self.fp32_groups_gradients_flat[i], num_elem_list=num_elem_list) + 
self.fp32_groups_gradients.append(fp32_gradients) + # flat tensor corresponding to actual fp32 gradients + length_without_padding = sum(num_elem_list) + self.fp32_groups_actual_gradients_flat.append( + torch.narrow(self.fp32_groups_gradients_flat[i], + 0, + 0, + length_without_padding)) + + # flat tensor corresponding to gradient partition + self.fp32_groups_gradient_flat_partition.append( + torch.narrow(self.fp32_groups_gradients_flat[i], + 0, + partition_id * partition_size, + partition_size)) + # create fp32 partition from flat tensor storage assert self.fp32_groups_flat[i].numel() % dp_world_size == 0, \ f'group {i} flat tensor size {self.fp32_groups_flat[i].numel()} not divisible by DP world size {dp_world_size}' - partition_size = self.fp32_groups_flat[i].numel() // dp_world_size - self.single_partition_of_fp32_groups.append( + self.fp32_groups_flat_partition.append( torch.narrow(self.fp32_groups_flat[i], 0, self.dp_rank * partition_size, partition_size)) - param_group['params'] = [self.single_partition_of_fp32_groups[i]] + param_group['params'] = [self.fp32_groups_flat_partition[i]] + see_memory_usage(f'after initializing group {i}', force=True) + + see_memory_usage('before initialize_optimizer', force=True) self.initialize_optimizer_states() - self._init_hp_grads() + see_memory_usage('end initialize_optimizer', force=True) + + see_memory_usage('end bf16_optimizer', force=True) def initialize_optimizer_states(self): """Take an optimizer step with zero-valued gradients to allocate internal @@ -107,8 +156,8 @@ def initialize_optimizer_states(self): This helps prevent memory fragmentation by allocating optimizer state at the beginning of training instead of after activations have been allocated. """ - for i, single_partition in enumerate(self.single_partition_of_fp32_groups): - single_partition.grad = self.fp32_groups_gradients_flat[i] + for param_partition, grad_partition in zip(self.fp32_groups_flat_partition, self.fp32_groups_gradient_flat_partition): + param_partition.grad = grad_partition self.optimizer.step() self.clear_hp_grads() @@ -144,10 +193,10 @@ def step(self, closure=None): if closure is not None: raise NotImplementedError(f'{self.__class__} does not support closure.') - params = self.get_fp32_params(filter_nograd=True) - all_groups_norm = get_grad_norm(parameters=params, - mpu=self.mpu, - norm_type=self.norm_type) + all_groups_norm = get_global_norm_of_tensors( + input_tensors=self.get_grads_for_norm(), + mpu=self.mpu, + norm_type=self.norm_type) self._global_grad_norm = all_groups_norm assert all_groups_norm > 0. 
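The surrounding hunks partition each padded flat fp32 buffer across the data-parallel group, so every rank keeps only 1/dp_world_size of the master weights, gradients, and optimizer state, in the ZeRO stage 1 style. The short sketch below (not from the patch; the sizes and rank are made up for illustration) shows the zero-copy view trick used for that partitioning.

import torch

dp_world_size, dp_rank = 4, 1                  # assumed data-parallel layout
flat = torch.zeros(20)                         # already padded so numel divides evenly
partition_size = flat.numel() // dp_world_size
my_partition = flat.narrow(0, dp_rank * partition_size, partition_size)

my_partition.add_(1.0)                         # narrow() returns a view, storage is shared
assert torch.equal(flat[5:10], torch.ones(5))

Because the partitions are views rather than copies, updating a rank's partition in place updates the flat buffer that the bf16 parameters alias, which is what makes the later all-gather of partitions sufficient to refresh the model weights.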
@@ -159,16 +208,14 @@ def step(self, closure=None): self.optimizer.step() - self.clear_hp_grads() self.update_lp_params() - def get_fp32_params(self, filter_nograd=False): - params = [] - for group in self.fp32_groups: - for param in group: - if filter_nograd and param.grad is not None: - params.append(param) - return params + all_gather_dp_groups(partitioned_param_groups=self.bf16_partitioned_groups, + dp_process_group=self.real_dp_process_group, + start_alignment_factor=self.nccl_start_alignment_factor, + allgather_bucket_size=self.allgather_bucket_size) + + self.clear_hp_grads() def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs): """Perform a backward pass and copy the low-precision gradients to the @@ -188,15 +235,16 @@ def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwarg @torch.no_grad() def update_hp_grads(self, clear_lp_grads=False): for i, group in enumerate(self.bf16_groups): - for j, (lp, hp) in enumerate(zip(group, self.fp32_groups[i])): + for j, lp in enumerate(group): if lp.grad is None: continue - assert hp.grad is not None, \ - f'high precision param has no gradient, param_id = {id(hp)} group_info = [{i}][{j}]' + hp_grad = self.fp32_groups_gradients[i][j] + assert hp_grad is not None, \ + f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{i}][{j}]' - hp.grad.data.add_(lp.grad.data.to(hp.dtype).view(hp.shape)) - lp._hp_grad = hp.grad + hp_grad.data.add_(lp.grad.data.to(hp_grad.dtype).view(hp_grad.shape)) + lp._hp_grad = hp_grad # clear gradients if clear_lp_grads: @@ -207,16 +255,15 @@ def get_grads_for_reduction(self): return self.fp32_groups_gradients_flat @torch.no_grad() - def update_lp_params(self): - for i, group in enumerate(self.bf16_groups): - for lp, hp in zip(group, self.fp32_groups[i]): - lp.data.copy_(hp.data.to(lp.dtype).view(lp.shape)) + def get_grads_for_norm(self): + return self.fp32_groups_actual_gradients_flat @torch.no_grad() - def _init_hp_grads(self): + def update_lp_params(self): for i, group in enumerate(self.bf16_groups): - for j, (lp, hp) in enumerate(zip(group, self.fp32_groups[i])): - hp.grad = self.fp32_groups_gradients[i][j] + partition_id = dist.get_rank(group=self.real_dp_process_group[i]) + for bf16_partitions, fp32_partition in zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition): + bf16_partitions[partition_id].data.copy_(fp32_partition.data) def clear_hp_grads(self): for flat_gradients in self.fp32_groups_gradients_flat: diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 4d16a31d8163..e1fc8342e792 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1327,11 +1327,13 @@ def _configure_bf16_optimizer(self, optimizer): if self.global_rank == 0: logger.info('Creating unfused BF16 optimizer') timers = self.timers if self.wall_clock_breakdown() else None - optimizer = BF16_Optimizer(optimizer, - mpu=self.mpu, - clip_grad=clip_grad, - dp_process_group=self.data_parallel_group, - timers=timers) + optimizer = BF16_Optimizer( + optimizer, + mpu=self.mpu, + clip_grad=clip_grad, + allgather_bucket_size=self.zero_allgather_bucket_size(), + dp_process_group=self.data_parallel_group, + timers=timers) else: raise NotImplementedError('BF16 requires a fused optimizer for now.') diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 4aaf5b59e8de..ebbf5b72a4ec 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -974,3 +974,44 @@ def align_dense_tensors(tensor_list, 
alignment): padded_tensor_list = tensor_list return padded_tensor_list + + +def all_gather_dp_groups(partitioned_param_groups, + dp_process_group, + start_alignment_factor, + allgather_bucket_size): + for group_id, partitioned_params in enumerate(partitioned_param_groups): + # Sequential AllGather Best of both worlds + partition_id = dist.get_rank(group=dp_process_group[group_id]) + dp_world_size = dist.get_world_size(group=dp_process_group[group_id]) + + num_shards = max( + 1, + partitioned_params[partition_id].numel() * dp_world_size // + allgather_bucket_size) + + shard_size = partitioned_params[partition_id].numel() // num_shards + + # Enforce nccl/rccl alignment of start location of each shard + shard_size = shard_size - (shard_size % start_alignment_factor) + + num_elements = shard_size + + assert shard_size * num_shards <= partitioned_params[partition_id].numel() + + for shard_id in range(num_shards): + + if shard_id == (num_shards - 1): + num_elements = partitioned_params[partition_id].numel( + ) - shard_id * shard_size + + shard_list = [] + for dp_id in range(dp_world_size): + curr_shard = partitioned_params[dp_id].narrow(0, + shard_id * shard_size, + num_elements).detach() + shard_list.append(curr_shard) + + dist.all_gather(shard_list, + shard_list[partition_id], + dp_process_group[group_id]) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index b48ab6281d54..b58d3ed32ad6 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -13,7 +13,8 @@ get_global_norm, see_memory_usage, is_model_parallel_parameter, - align_dense_tensors) + align_dense_tensors, + all_gather_dp_groups) from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS from deepspeed.runtime.zero.offload_constants import OFFLOAD_CPU_DEVICE, OFFLOAD_OPTIMIZER @@ -1740,41 +1741,47 @@ def step(self, closure=None): self.start_timers([OPTIMIZER_ALLGATHER]) # gather the updated weights from everyone - for group_id, partitioned_params in enumerate(self.parallel_partitioned_bit16_groups): + all_gather_dp_groups( + partitioned_param_groups=self.parallel_partitioned_bit16_groups, + dp_process_group=self.real_dp_process_group, + start_alignment_factor=self.nccl_start_alignment_factor, + allgather_bucket_size=self.allgather_bucket_size) - # Sequential AllGather Best of both worlds - dp_world_size = dist.get_world_size( - group=self.real_dp_process_group[group_id]) - num_shards = max( - 1, - partitioned_params[partition_id].numel() * dp_world_size // - self.allgather_bucket_size) - - shard_size = partitioned_params[partition_id].numel() // num_shards - - # Enforce nccl/rccl alignment of start location of each shard - shard_size = shard_size - (shard_size % self.nccl_start_alignment_factor) - - num_elements = shard_size - - assert shard_size * num_shards <= partitioned_params[partition_id].numel() - - for shard_id in range(num_shards): - - if shard_id == (num_shards - 1): - num_elements = partitioned_params[partition_id].numel( - ) - shard_id * shard_size - - shard_list = [] - for dp_id in range(dp_world_size): - curr_shard = partitioned_params[dp_id].narrow( - 0, - shard_id * shard_size, - num_elements).detach() - shard_list.append(curr_shard) - dist.all_gather(shard_list, - shard_list[partition_id], - group=self.real_dp_process_group[group_id]) + # for group_id, partitioned_params in enumerate(self.parallel_partitioned_bit16_groups): + # + # # Sequential AllGather Best of both worlds + # dp_world_size = dist.get_world_size( + # 
group=self.real_dp_process_group[group_id]) + # num_shards = max( + # 1, + # partitioned_params[partition_id].numel() * dp_world_size // + # self.allgather_bucket_size) + # + # shard_size = partitioned_params[partition_id].numel() // num_shards + # + # # Enforce nccl/rccl alignment of start location of each shard + # shard_size = shard_size - (shard_size % self.nccl_start_alignment_factor) + # + # num_elements = shard_size + # + # assert shard_size * num_shards <= partitioned_params[partition_id].numel() + # + # for shard_id in range(num_shards): + # + # if shard_id == (num_shards - 1): + # num_elements = partitioned_params[partition_id].numel( + # ) - shard_id * shard_size + # + # shard_list = [] + # for dp_id in range(dp_world_size): + # curr_shard = partitioned_params[dp_id].narrow( + # 0, + # shard_id * shard_size, + # num_elements).detach() + # shard_list.append(curr_shard) + # dist.all_gather(shard_list, + # shard_list[partition_id], + # group=self.real_dp_process_group[group_id]) self.stop_timers([OPTIMIZER_ALLGATHER]) # TODO: we probably don't need this? just to be safe diff --git a/deepspeed/runtime/zero/utils.py b/deepspeed/runtime/zero/utils.py index 0017213a9941..fff8ebd31ef3 100755 --- a/deepspeed/runtime/zero/utils.py +++ b/deepspeed/runtime/zero/utils.py @@ -27,6 +27,10 @@ def _initialize_parameter_parallel_groups(parameter_parallel_size=None): return my_group +class ZeRORuntimeException(Exception): + pass + + ZERO_SUPPORTED_OPTIMIZERS = [ torch.optim.Adam, torch.optim.AdamW, @@ -81,7 +85,3 @@ def assert_ints_same_as_other_ranks(ints: List[int]) -> None: if ints != rank0_ints: raise RuntimeError(f"disagreement between rank0 and rank{dist.get_rank()}: " f"rank0: {rank0_ints}, rank{dist.get_rank()}: {ints}") - - -class ZeRORuntimeException(Exception): - pass From 819abe2a7d8ba72c950da11890f565dc3a05385e Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Thu, 24 Feb 2022 23:21:11 +0500 Subject: [PATCH 05/29] finish zero_stage 1 sharding --- deepspeed/runtime/bf16_optimizer.py | 87 ++++++++++++++--------------- deepspeed/runtime/pipe/engine.py | 16 +++--- deepspeed/runtime/utils.py | 12 ++-- 3 files changed, 57 insertions(+), 58 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 5aa41ac1eb20..42f4e58d8a05 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -9,6 +9,8 @@ clip_gradients, align_dense_tensors, all_gather_dp_groups, + bwc_tensor_model_parallel_rank, + is_model_parallel_parameter, see_memory_usage) @@ -48,9 +50,6 @@ def __init__(self, self.bf16_groups_flat = [] self.bf16_partitioned_groups = [] - # TODO: Need to only track fp32 params of this partition - self.fp32_groups = [] - self.fp32_groups_flat = [] self.fp32_groups_flat_partition = [] # Maintain different fp32 gradients views for convenience @@ -58,6 +57,7 @@ def __init__(self, self.fp32_groups_gradients_flat = [] self.fp32_groups_actual_gradients_flat = [] self.fp32_groups_gradient_flat_partition = [] + self.fp32_groups_has_gradients = [] dp_world_size = dist.get_world_size(group=self.dp_process_group) @@ -75,6 +75,11 @@ def __init__(self, self.bf16_groups[i], self.nccl_start_alignment_factor * dp_world_size)) + # Make bf16 params point to flat tensor storage + self._update_storage_to_flattened_tensor( + tensor_list=self.bf16_groups[i], + flat_tensor=self.bf16_groups_flat[i]) + # divide flat weights into equal sized partitions partition_size = self.bf16_groups_flat[i].numel() // dp_world_size bf16_dp_partitions = [ 
@@ -85,37 +90,25 @@ def __init__(self, ] self.bf16_partitioned_groups.append(bf16_dp_partitions) - # Make bf16 params point to flat tensor storage - self._update_storage_to_flattened_tensor( - tensor_list=self.bf16_groups[i], - flat_tensor=self.bf16_groups_flat[i]) - - # create flat fp32 params - self.fp32_groups_flat.append( - self.bf16_groups_flat[i].clone().float().detach()) - self.fp32_groups_flat[i].requires_grad = True + # create fp32 params partition + self.fp32_groups_flat_partition.append( + bf16_dp_partitions[partition_id].clone().float().detach()) + self.fp32_groups_flat_partition[i].requires_grad = True num_elem_list = [t.numel() for t in self.bf16_groups[i]] - # create fp32 params using flat tensor storage - fp32_group_params = self._split_flat_tensor( - flat_tensor=self.fp32_groups_flat[i], - num_elem_list=num_elem_list) - self._propagate_attributes(src_tensor_list=self.bf16_groups[i], - dst_tensor_list=fp32_group_params) - self.fp32_groups.append(fp32_group_params) - # create fp32 gradients self.fp32_groups_gradients_flat.append( - torch.zeros_like(self.fp32_groups_flat[i])) + torch.zeros_like(self.bf16_groups_flat[i], + dtype=torch.float32)) + # track individual fp32 gradients for entire model fp32_gradients = self._split_flat_tensor( flat_tensor=self.fp32_groups_gradients_flat[i], num_elem_list=num_elem_list) - self.fp32_groups_gradients.append(fp32_gradients) - # flat tensor corresponding to actual fp32 gradients + # flat tensor corresponding to actual fp32 gradients (i.e., minus alignment padding) length_without_padding = sum(num_elem_list) self.fp32_groups_actual_gradients_flat.append( torch.narrow(self.fp32_groups_gradients_flat[i], @@ -130,17 +123,12 @@ def __init__(self, partition_id * partition_size, partition_size)) - # create fp32 partition from flat tensor storage - assert self.fp32_groups_flat[i].numel() % dp_world_size == 0, \ - f'group {i} flat tensor size {self.fp32_groups_flat[i].numel()} not divisible by DP world size {dp_world_size}' - - self.fp32_groups_flat_partition.append( - torch.narrow(self.fp32_groups_flat[i], - 0, - self.dp_rank * partition_size, - partition_size)) + # track fp32 gradient updates + self.fp32_groups_has_gradients.append([False] * len(self.bf16_groups[i])) + # update optimizer param groups to reference fp32 params partition param_group['params'] = [self.fp32_groups_flat_partition[i]] + see_memory_usage(f'after initializing group {i}', force=True) see_memory_usage('before initialize_optimizer', force=True) @@ -160,14 +148,8 @@ def initialize_optimizer_states(self): param_partition.grad = grad_partition self.optimizer.step() - self.clear_hp_grads() - def _propagate_attributes(self, src_tensor_list, dst_tensor_list): - for src_tensor, dst_tensor in zip(src_tensor_list, dst_tensor_list): - if hasattr(src_tensor, 'model_parallel'): - dst_tensor.model_parallel = src_tensor.model_parallel - if hasattr(src_tensor, PIPE_REPLICATED): - dst_tensor.ds_pipe_replicated = src_tensor.ds_pipe_replicated + self.clear_hp_grads() def _split_flat_tensor(self, flat_tensor, num_elem_list): assert sum(num_elem_list) <= flat_tensor.numel() @@ -245,6 +227,7 @@ def update_hp_grads(self, clear_lp_grads=False): hp_grad.data.add_(lp.grad.data.to(hp_grad.dtype).view(hp_grad.shape)) lp._hp_grad = hp_grad + self.fp32_groups_has_gradients[i][j] = True # clear gradients if clear_lp_grads: @@ -256,7 +239,23 @@ def get_grads_for_reduction(self): @torch.no_grad() def get_grads_for_norm(self): - return self.fp32_groups_actual_gradients_flat + grads = [] + tensor_mp_rank = 
bwc_tensor_model_parallel_rank(mpu=self.mpu) + for i, group in enumerate(self.bf16_groups): + for j, lp in enumerate(group): + if hasattr(lp, PIPE_REPLICATED) and lp.ds_pipe_replicated: + continue + + if (tensor_mp_rank > 0) and not is_model_parallel_parameter(lp): + continue + + if not self.fp32_groups_has_gradients[i][j]: + continue + + # TODO: Only include gradients in this partition + grads.append(self.fp32_groups_gradients[i][j]) + + return grads @torch.no_grad() def update_lp_params(self): @@ -269,6 +268,9 @@ def clear_hp_grads(self): for flat_gradients in self.fp32_groups_gradients_flat: flat_gradients.zero_() + for group in self.fp32_groups_has_gradients: + group = [False] * len(group) + def clear_lp_grads(self): for group in self.bf16_groups: for param in group: @@ -277,7 +279,6 @@ def clear_lp_grads(self): def state_dict(self): state_dict = {} state_dict['optimizer_state_dict'] = self.optimizer.state_dict() - state_dict['fp32_groups'] = self.fp32_groups state_dict['clip_grad'] = self.clip_grad return state_dict @@ -286,10 +287,6 @@ def load_state_dict(self, state_dict, load_optimizer_states=True): self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) self.clip_grad = state_dict['clip_grad'] - for i in range(len(self.fp32_groups)): - for current, saved in zip(self.fp32_groups[i], state_dict['fp32_groups'][i]): - current.data.copy_(saved.data) - @property def param_groups(self): """Forward the wrapped optimizer's parameters.""" diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index fe10064ddc25..60bbb62335d3 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -259,14 +259,14 @@ def _exec_reduce_grads(self): def _bf16_reduce_grads(self): # Make our own list of gradients from the optimizer's FP32 grads grads = [] - for param_group in self.optimizer.fp32_groups: - for param in param_group: - if param.grad is None: - continue - assert param.grad is not None - assert param.grad.dtype == torch.float32 - grads.append(param.grad.data) - self.buffered_allreduce_fallback(grads=grads, + # for param_group in self.optimizer.fp32_groups: + # for param in param_group: + # if param.grad is None: + # continue + # assert param.grad is not None + # assert param.grad.dtype == torch.float32 + # grads.append(param.grad.data) + self.buffered_allreduce_fallback(grads=self.optimizer.get_grads_for_reduction(), elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE) def _reserve_pipe_buffers(self, num_buffers): diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index ebbf5b72a4ec..7642301e289a 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -72,7 +72,13 @@ def set_random_seed(seed): def is_model_parallel_parameter(p) -> bool: - return hasattr(p, 'model_parallel') and p.model_parallel + if hasattr(p, 'model_parallel') and p.model_parallel: + return True + + if hasattr(p, 'tensor_model_parallel') and p.tensor_model_parallel: + return True + + return False def bwc_tensor_model_parallel_rank(mpu=None): @@ -554,10 +560,6 @@ def get_weight_norm(parameters, norm_type=2, mpu=None): return total_norm -def is_model_parallel_parameter(p): - return hasattr(p, 'model_parallel') and p.model_parallel - - def prefix_sum_inc(weights): """ Compute an inclusive prefix sum. 
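
The core bookkeeping introduced in the patch above is: flatten each bf16 param group into one aligned buffer, split that buffer into equal data-parallel partitions, keep an fp32 master copy of only this rank's partition, and pre-allocate a flat fp32 gradient buffer with per-parameter views for accumulation. A minimal single-process sketch of the same idea, with made-up sizes and rank, and a simplified padding helper standing in for align_dense_tensors:

    import torch
    from torch._utils import _flatten_dense_tensors

    def pad_to_alignment(tensor_list, alignment):
        # Append a zero filler so the flattened length divides evenly by `alignment`.
        total = sum(t.numel() for t in tensor_list)
        remainder = total % alignment
        if remainder == 0:
            return tensor_list
        filler = torch.zeros(alignment - remainder, dtype=tensor_list[0].dtype)
        return tensor_list + [filler]

    dp_world_size = 4            # assumed data-parallel world size
    partition_id = 1             # assumed rank of this process in the DP group
    alignment = 2 * dp_world_size  # stands in for nccl_start_alignment_factor * dp_world_size

    bf16_params = [torch.randn(5, dtype=torch.bfloat16),
                   torch.randn(7, dtype=torch.bfloat16)]

    # One flat bf16 buffer per param group, padded so it splits evenly across ranks.
    flat = _flatten_dense_tensors(pad_to_alignment(bf16_params, alignment))
    partition_size = flat.numel() // dp_world_size
    bf16_partitions = [flat.narrow(0, r * partition_size, partition_size)
                       for r in range(dp_world_size)]

    # fp32 master copy of only this rank's shard; this is what the wrapped
    # optimizer actually updates.
    fp32_partition = bf16_partitions[partition_id].clone().float().detach()
    fp32_partition.requires_grad = True

    # Flat fp32 gradient buffer for the whole group plus per-parameter views,
    # so gradients can be accumulated in fp32 across micro-batches.
    fp32_grads_flat = torch.zeros_like(flat, dtype=torch.float32)
    offset, fp32_grad_views = 0, []
    for p in bf16_params:
        fp32_grad_views.append(
            fp32_grads_flat.narrow(0, offset, p.numel()).view_as(p))
        offset += p.numel()

The wrapped optimizer then only ever sees fp32_partition, which is why the patch rewrites param_group['params'] to [self.fp32_groups_flat_partition[i]].
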
From e48035b794af6e3bacaff817d686d770ac449747 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Sun, 27 Feb 2022 09:04:24 +0500 Subject: [PATCH 06/29] Matching fp16 with debugging codes --- deepspeed/runtime/bf16_optimizer.py | 46 ++++++++++++++++++++++- deepspeed/runtime/engine.py | 7 ++-- deepspeed/runtime/fp16/fused_optimizer.py | 38 ++++++++++++++++++- deepspeed/runtime/pipe/engine.py | 4 +- 4 files changed, 87 insertions(+), 8 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 42f4e58d8a05..358652b1c84d 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -12,8 +12,8 @@ bwc_tensor_model_parallel_rank, is_model_parallel_parameter, see_memory_usage) - - +import os +DUMP_FILE = os.environ.get('BIT16_DUMP_FILE', os.path.join('/tmp', 'bf16_debug.txt')) class BF16_Optimizer: def __init__(self, init_optimizer, @@ -58,6 +58,9 @@ def __init__(self, self.fp32_groups_actual_gradients_flat = [] self.fp32_groups_gradient_flat_partition = [] self.fp32_groups_has_gradients = [] + self.step_count = 0 + self.backward_count = 0 + self.fp = open(DUMP_FILE, 'w') if dist.get_rank() == 0 else None dp_world_size = dist.get_world_size(group=self.dp_process_group) @@ -137,6 +140,7 @@ def __init__(self, see_memory_usage('end bf16_optimizer', force=True) + def initialize_optimizer_states(self): """Take an optimizer step with zero-valued gradients to allocate internal optimizer state. @@ -170,6 +174,22 @@ def _update_storage_to_flattened_tensor(self, tensor_list, flat_tensor): def _flatten_dense_tensors_aligned(self, tensor_list, alignment): return self.flatten(align_dense_tensors(tensor_list, alignment)) + + def _dump_tensors(self, tag, tensor_list, print_all=False): + if self.fp is None: + return + self.fp.write(f"rank {dist.get_rank()} - dump {tag} \n") + for i, t in enumerate(tensor_list): + if torch.is_tensor(t): + if print_all: + value = t + else: + value = t[:1] if t.numel() > 1 else t + else: + value = t + self.fp.write(f"rank {dist.get_rank()} - {i} = {value} \n") + + @torch.no_grad() def step(self, closure=None): if closure is not None: @@ -188,6 +208,12 @@ def step(self, closure=None): mpu=self.mpu, global_grad_norm=all_groups_norm) + + #self._dump_tensors(f'hp grads before step {self.step_count}', [self.fp32_groups_gradients[0][0]]) + #self._dump_tensors(f'hp norm before step {self.step_count}', [self._global_grad_norm]) + #self._dump_tensors(f'lp weights before step {self.step_count}', [self.bf16_groups[0][0][0]]) + + self.optimizer.step() self.update_lp_params() @@ -197,7 +223,10 @@ def step(self, closure=None): start_alignment_factor=self.nccl_start_alignment_factor, allgather_bucket_size=self.allgather_bucket_size) + #self._dump_tensors(f'lp weights after step {self.step_count}', [self.bf16_groups[0][0][0]]) + self.clear_hp_grads() + self.step_count += 1 def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs): """Perform a backward pass and copy the low-precision gradients to the @@ -210,18 +239,28 @@ def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwarg """ self.clear_lp_grads() loss.backward(**bwd_kwargs) + self._dump_tensors(f'loss in backward {self.backward_count}', [loss]) + self._dump_tensors(f'lp grads in backward {self.backward_count}', [self.bf16_groups[0][0].grad[0]]) if update_hp_grads: + #self._dump_tensors(f'hp grads before update {self.backward_count}', [self.fp32_groups_gradients[0][0]]) 
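
update_hp_grads on the next line is where bf16 gradients are folded into the pre-allocated fp32 views; the essential operation, shown here on stand-in tensors rather than the optimizer's real state, is an in-place fp32 accumulate followed by freeing the bf16 grad:

    import torch

    # Illustrative stand-ins for one bf16 parameter and its fp32 gradient view.
    lp = torch.nn.Parameter(torch.randn(4, dtype=torch.bfloat16))
    lp.grad = torch.randn(4, dtype=torch.bfloat16)
    hp_grad = torch.zeros(4, dtype=torch.float32)

    # Accumulate (not overwrite) so gradient-accumulation steps > 1 sum in fp32,
    # then drop the low-precision grad to release its memory.
    hp_grad.add_(lp.grad.to(hp_grad.dtype).view(hp_grad.shape))
    lp._hp_grad = hp_grad   # give the bf16 param access to its fp32 grad, as in the patch
    lp.grad = None

Accumulating rather than copying is what lets gradient accumulation happen in fp32 instead of bf16.
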
self.update_hp_grads(clear_lp_grads=clear_lp_grads) + #self._dump_tensors(f'hp grads after update {self.backward_count}', [self.fp32_groups_gradients[0][0]]) + + self.backward_count += 1 @torch.no_grad() def update_hp_grads(self, clear_lp_grads=False): + lp_dtypes = [] + hp_dtypes = [] for i, group in enumerate(self.bf16_groups): for j, lp in enumerate(group): if lp.grad is None: continue hp_grad = self.fp32_groups_gradients[i][j] + hp_dtypes.append(hp_grad.dtype) + lp_dtypes.append(lp.grad.dtype) assert hp_grad is not None, \ f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{i}][{j}]' @@ -233,6 +272,9 @@ def update_hp_grads(self, clear_lp_grads=False): if clear_lp_grads: lp.grad = None +# print(f'dtype in accum in step {self.step_count} lp = {lp_dtypes}') +# print(f'dtype in accum in step {self.step_count} hp = {hp_dtypes}') + @torch.no_grad() def get_grads_for_reduction(self): return self.fp32_groups_gradients_flat diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index e1fc8342e792..6fce09f3827b 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1663,6 +1663,9 @@ def print_forward_breakdown(self, fwd_time): @instrument_w_nvtx def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): + assert not (self.bfloat16_enabled() and self.pipeline_parallelism), \ + f'allreduce_gradients() is not valid when bfloat+pipeline_parallelism is enabled' + # Pass (PP) gas boundary flag to optimizer (required for zero) self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary( ) @@ -1676,10 +1679,6 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): if self.zero_optimization_stage() == ZERO_OPTIMIZATION_OPTIMIZER_STATES: self.optimizer.reduce_gradients( pipeline_parallel=self.pipeline_parallelism) - elif self.bfloat16_enabled() and self.pipeline_parallelism: - self.buffered_allreduce_fallback( - grads=self.optimizer.get_grads_for_reduction(), - elements_per_buffer=bucket_size) else: self.buffered_allreduce_fallback(elements_per_buffer=bucket_size) diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py index fdba94c55af7..f8be5e4de264 100755 --- a/deepspeed/runtime/fp16/fused_optimizer.py +++ b/deepspeed/runtime/fp16/fused_optimizer.py @@ -14,6 +14,8 @@ from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT import torch.distributed as dist +import os +DUMP_FILE = os.environ.get('BIT16_DUMP_FILE', os.path.join('/tmp', 'fp16_debug.txt')) class FP16_Optimizer(object): """ @@ -95,6 +97,9 @@ def __init__(self, self.clip_grad = clip_grad self.norm_type = 2 + self.step_count = 0 + self.backward_count = 0 + self.fp = open(DUMP_FILE, 'w') if dist.get_rank() == 0 else None TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MINOR = int(torch.__version__.split('.')[1]) @@ -261,6 +266,8 @@ def step(self, closure=None): self.fp32_groups_flat[i].grad = grads_groups_flat[i] + #self._dump_tensors(f'hp grads before step {self.step_count}', [self.fp32_groups_flat[0].grad[0]]) + self.start_timers([COMPUTE_NORM]) all_groups_norm = get_grad_norm(self.fp32_groups_flat, mpu=self.mpu) @@ -279,6 +286,9 @@ def step(self, closure=None): self.unscale_and_clip_grads(grads_groups_flat, scaled_global_grad_norm) self.stop_timers([UNSCALE_AND_CLIP]) + #self._dump_tensors(f'hp norm before step {self.step_count}', [self._global_grad_norm]) + #self._dump_tensors(f'lp weights before step {self.step_count}', [self.fp16_groups[0][0][0]]) + 
self.start_timers([BASIC_STEP]) self.optimizer.step() self.stop_timers([BASIC_STEP]) @@ -297,8 +307,12 @@ def step(self, closure=None): self.stop_timers([UPDATE_FP16]) + #self._dump_tensors(f'lp weights after step {self.step_count}', [self.fp16_groups[0][0][0]]) + self.log_timers(STEP_TIMERS) + self.step_count += 1 + return self.overflow def _get_norm_with_moe_layers(self, all_groups_norm): @@ -333,6 +347,22 @@ def unscale_and_clip_grads(self, grad_groups_flat, total_norm, apply_scale=True) return combined_scale + + def _dump_tensors(self, tag, tensor_list, print_all=False): + if self.fp is None: + return + self.fp.write(f"rank {dist.get_rank()} - dump {tag} \n") + for i, t in enumerate(tensor_list): + if torch.is_tensor(t): + if print_all: + value = t + else: + value = t[:1] if t.numel() > 1 else t + else: + value = t + self.fp.write(f"rank {dist.get_rank()} - {i} = {value} \n") + + def backward(self, loss, create_graph=False, retain_graph=False): """ :attr:`backward` performs the following steps: @@ -340,11 +370,17 @@ def backward(self, loss, create_graph=False, retain_graph=False): 1. fp32_loss = loss.float() 2. scaled_loss = fp32_loss*loss_scale 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves - """ + """ + scaled_loss = (loss.float()) * self.cur_scale + self._dump_tensors(f'scaled_loss in backward {self.backward_count}', [scaled_loss]) scaled_loss.backward(create_graph=create_graph, retain_graph=retain_graph) + self._dump_tensors(f'lp grads after backward {self.backward_count}', [self.fp16_groups[0][0].grad[0]]) + + self.backward_count += 1 + def _update_scale(self, skip): if self.dynamic_loss_scale: prev_scale = self.cur_scale diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 60bbb62335d3..e920d73d759a 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -648,7 +648,9 @@ def _exec_forward_pass(self, buffer_id): # tensor changes across batches self._zero_grads(inputs) + self.optimizer._dump_tensors(f'inputs', [t[:][:10] for t in inputs[:2]], print_all=True) outputs = super().forward(inputs) + self.optimizer._dump_tensors(f'outputs', [outputs], print_all=True) # Partition the outputs if we are not the last stage if self.is_pipe_partitioned and not self.is_last_stage(): @@ -680,7 +682,7 @@ def _exec_forward_pass(self, buffer_id): if self.is_last_stage(): if self._compute_loss and self.loss_model is not None: labels = self.pipe_buffers['labels'][buffer_id] - self.loss = self.loss_model(outputs, labels) + self.loss = self.loss_model(outputs, labels, self) else: # Some models just return loss from forward() self.loss = outputs From 8245053977592a0a9c437a489d62a977376ff8bf Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Sun, 27 Feb 2022 09:38:23 +0500 Subject: [PATCH 07/29] Matching loss with fp16 --- deepspeed/runtime/bf16_optimizer.py | 46 ++--------------------- deepspeed/runtime/fp16/fused_optimizer.py | 35 +---------------- deepspeed/runtime/pipe/engine.py | 4 +- 3 files changed, 5 insertions(+), 80 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 358652b1c84d..fea33a5f4f3a 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -12,8 +12,8 @@ bwc_tensor_model_parallel_rank, is_model_parallel_parameter, see_memory_usage) -import os -DUMP_FILE = os.environ.get('BIT16_DUMP_FILE', os.path.join('/tmp', 'bf16_debug.txt')) + + class BF16_Optimizer: def 
__init__(self, init_optimizer, @@ -59,8 +59,6 @@ def __init__(self, self.fp32_groups_gradient_flat_partition = [] self.fp32_groups_has_gradients = [] self.step_count = 0 - self.backward_count = 0 - self.fp = open(DUMP_FILE, 'w') if dist.get_rank() == 0 else None dp_world_size = dist.get_world_size(group=self.dp_process_group) @@ -140,7 +138,6 @@ def __init__(self, see_memory_usage('end bf16_optimizer', force=True) - def initialize_optimizer_states(self): """Take an optimizer step with zero-valued gradients to allocate internal optimizer state. @@ -174,22 +171,6 @@ def _update_storage_to_flattened_tensor(self, tensor_list, flat_tensor): def _flatten_dense_tensors_aligned(self, tensor_list, alignment): return self.flatten(align_dense_tensors(tensor_list, alignment)) - - def _dump_tensors(self, tag, tensor_list, print_all=False): - if self.fp is None: - return - self.fp.write(f"rank {dist.get_rank()} - dump {tag} \n") - for i, t in enumerate(tensor_list): - if torch.is_tensor(t): - if print_all: - value = t - else: - value = t[:1] if t.numel() > 1 else t - else: - value = t - self.fp.write(f"rank {dist.get_rank()} - {i} = {value} \n") - - @torch.no_grad() def step(self, closure=None): if closure is not None: @@ -208,12 +189,6 @@ def step(self, closure=None): mpu=self.mpu, global_grad_norm=all_groups_norm) - - #self._dump_tensors(f'hp grads before step {self.step_count}', [self.fp32_groups_gradients[0][0]]) - #self._dump_tensors(f'hp norm before step {self.step_count}', [self._global_grad_norm]) - #self._dump_tensors(f'lp weights before step {self.step_count}', [self.bf16_groups[0][0][0]]) - - self.optimizer.step() self.update_lp_params() @@ -223,8 +198,6 @@ def step(self, closure=None): start_alignment_factor=self.nccl_start_alignment_factor, allgather_bucket_size=self.allgather_bucket_size) - #self._dump_tensors(f'lp weights after step {self.step_count}', [self.bf16_groups[0][0][0]]) - self.clear_hp_grads() self.step_count += 1 @@ -239,28 +212,18 @@ def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwarg """ self.clear_lp_grads() loss.backward(**bwd_kwargs) - self._dump_tensors(f'loss in backward {self.backward_count}', [loss]) - self._dump_tensors(f'lp grads in backward {self.backward_count}', [self.bf16_groups[0][0].grad[0]]) if update_hp_grads: - #self._dump_tensors(f'hp grads before update {self.backward_count}', [self.fp32_groups_gradients[0][0]]) self.update_hp_grads(clear_lp_grads=clear_lp_grads) - #self._dump_tensors(f'hp grads after update {self.backward_count}', [self.fp32_groups_gradients[0][0]]) - - self.backward_count += 1 @torch.no_grad() def update_hp_grads(self, clear_lp_grads=False): - lp_dtypes = [] - hp_dtypes = [] for i, group in enumerate(self.bf16_groups): for j, lp in enumerate(group): if lp.grad is None: continue hp_grad = self.fp32_groups_gradients[i][j] - hp_dtypes.append(hp_grad.dtype) - lp_dtypes.append(lp.grad.dtype) assert hp_grad is not None, \ f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{i}][{j}]' @@ -272,9 +235,6 @@ def update_hp_grads(self, clear_lp_grads=False): if clear_lp_grads: lp.grad = None -# print(f'dtype in accum in step {self.step_count} lp = {lp_dtypes}') -# print(f'dtype in accum in step {self.step_count} hp = {hp_dtypes}') - @torch.no_grad() def get_grads_for_reduction(self): return self.fp32_groups_gradients_flat @@ -294,7 +254,6 @@ def get_grads_for_norm(self): if not self.fp32_groups_has_gradients[i][j]: continue - # TODO: Only include gradients in this partition 
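
The guards above decide which gradients may contribute to the global norm; restated as a stand-alone predicate under the same assumptions (names follow the patch, and is_model_parallel_parameter is the helper from deepspeed/runtime/utils.py):

    from deepspeed.runtime.utils import is_model_parallel_parameter

    def include_in_norm(lp, has_grad, tensor_mp_rank):
        # Tied weights replicated across pipeline stages are reduced elsewhere,
        # so counting them here would double-count their contribution.
        if getattr(lp, 'ds_pipe_replicated', False):
            return False
        # On tensor-parallel ranks > 0, only genuinely partitioned parameters
        # count; replicated parameters are already covered by rank 0.
        if tensor_mp_rank > 0 and not is_model_parallel_parameter(lp):
            return False
        # Only parameters that actually received a gradient this step.
        return has_grad
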
grads.append(self.fp32_groups_gradients[i][j]) return grads @@ -319,6 +278,7 @@ def clear_lp_grads(self): param.grad = None def state_dict(self): + # TODO capture all training state for checkpointing state_dict = {} state_dict['optimizer_state_dict'] = self.optimizer.state_dict() state_dict['clip_grad'] = self.clip_grad diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py index f8be5e4de264..7944a4b8b925 100755 --- a/deepspeed/runtime/fp16/fused_optimizer.py +++ b/deepspeed/runtime/fp16/fused_optimizer.py @@ -14,8 +14,6 @@ from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT import torch.distributed as dist -import os -DUMP_FILE = os.environ.get('BIT16_DUMP_FILE', os.path.join('/tmp', 'fp16_debug.txt')) class FP16_Optimizer(object): """ @@ -98,8 +96,6 @@ def __init__(self, self.clip_grad = clip_grad self.norm_type = 2 self.step_count = 0 - self.backward_count = 0 - self.fp = open(DUMP_FILE, 'w') if dist.get_rank() == 0 else None TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MINOR = int(torch.__version__.split('.')[1]) @@ -266,8 +262,6 @@ def step(self, closure=None): self.fp32_groups_flat[i].grad = grads_groups_flat[i] - #self._dump_tensors(f'hp grads before step {self.step_count}', [self.fp32_groups_flat[0].grad[0]]) - self.start_timers([COMPUTE_NORM]) all_groups_norm = get_grad_norm(self.fp32_groups_flat, mpu=self.mpu) @@ -286,9 +280,6 @@ def step(self, closure=None): self.unscale_and_clip_grads(grads_groups_flat, scaled_global_grad_norm) self.stop_timers([UNSCALE_AND_CLIP]) - #self._dump_tensors(f'hp norm before step {self.step_count}', [self._global_grad_norm]) - #self._dump_tensors(f'lp weights before step {self.step_count}', [self.fp16_groups[0][0][0]]) - self.start_timers([BASIC_STEP]) self.optimizer.step() self.stop_timers([BASIC_STEP]) @@ -307,8 +298,6 @@ def step(self, closure=None): self.stop_timers([UPDATE_FP16]) - #self._dump_tensors(f'lp weights after step {self.step_count}', [self.fp16_groups[0][0][0]]) - self.log_timers(STEP_TIMERS) self.step_count += 1 @@ -347,22 +336,6 @@ def unscale_and_clip_grads(self, grad_groups_flat, total_norm, apply_scale=True) return combined_scale - - def _dump_tensors(self, tag, tensor_list, print_all=False): - if self.fp is None: - return - self.fp.write(f"rank {dist.get_rank()} - dump {tag} \n") - for i, t in enumerate(tensor_list): - if torch.is_tensor(t): - if print_all: - value = t - else: - value = t[:1] if t.numel() > 1 else t - else: - value = t - self.fp.write(f"rank {dist.get_rank()} - {i} = {value} \n") - - def backward(self, loss, create_graph=False, retain_graph=False): """ :attr:`backward` performs the following steps: @@ -370,17 +343,11 @@ def backward(self, loss, create_graph=False, retain_graph=False): 1. fp32_loss = loss.float() 2. scaled_loss = fp32_loss*loss_scale 3. 
scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves - """ + """ scaled_loss = (loss.float()) * self.cur_scale - self._dump_tensors(f'scaled_loss in backward {self.backward_count}', [scaled_loss]) - scaled_loss.backward(create_graph=create_graph, retain_graph=retain_graph) - self._dump_tensors(f'lp grads after backward {self.backward_count}', [self.fp16_groups[0][0].grad[0]]) - - self.backward_count += 1 - def _update_scale(self, skip): if self.dynamic_loss_scale: prev_scale = self.cur_scale diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index e920d73d759a..60bbb62335d3 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -648,9 +648,7 @@ def _exec_forward_pass(self, buffer_id): # tensor changes across batches self._zero_grads(inputs) - self.optimizer._dump_tensors(f'inputs', [t[:][:10] for t in inputs[:2]], print_all=True) outputs = super().forward(inputs) - self.optimizer._dump_tensors(f'outputs', [outputs], print_all=True) # Partition the outputs if we are not the last stage if self.is_pipe_partitioned and not self.is_last_stage(): @@ -682,7 +680,7 @@ def _exec_forward_pass(self, buffer_id): if self.is_last_stage(): if self._compute_loss and self.loss_model is not None: labels = self.pipe_buffers['labels'][buffer_id] - self.loss = self.loss_model(outputs, labels, self) + self.loss = self.loss_model(outputs, labels) else: # Some models just return loss from forward() self.loss = outputs From 15293139f7c25a9a7ddbcefcd8bb1d84f4052b2e Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Mon, 28 Feb 2022 20:01:06 +0000 Subject: [PATCH 08/29] Fix gradient clipping --- deepspeed/runtime/bf16_optimizer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index fea33a5f4f3a..f25f1ae780e2 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -184,10 +184,10 @@ def step(self, closure=None): assert all_groups_norm > 0. 
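
The hunk below replaces clip_gradients with clip_tensors_by_global_norm, which rescales the already-accumulated fp32 gradient views in place using the norm computed just above. The intended arithmetic is ordinary global-norm clipping; a self-contained sketch of that calculation (not the DeepSpeed helper itself):

    import torch

    def clip_by_global_norm(tensors, max_norm, global_norm=None, eps=1e-6):
        # Compute the combined 2-norm over all tensors unless the caller already has it.
        if global_norm is None:
            global_norm = torch.norm(
                torch.stack([torch.norm(t.detach().float()) for t in tensors]))
        clip_coef = max_norm / (global_norm + eps)
        if clip_coef < 1:
            # Scale every tensor in place by the same coefficient.
            for t in tensors:
                t.detach().mul_(clip_coef)
        return global_norm

    grads = [torch.randn(8), torch.randn(3)]
    clip_by_global_norm(grads, max_norm=1.0)

With clip_coef = max_norm / (global_norm + eps), gradients are only touched when the coefficient falls below one, so the common un-clipped path stays a no-op.
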
if self.clip_grad > 0.: - clip_gradients(parameters=params, - max_norm=self.clip_grad, - mpu=self.mpu, - global_grad_norm=all_groups_norm) + clip_tensors_by_global_norm(input_tensors=self.get_grads_for_norm(), + max_norm=self.clip_grad, + mpu=self.mpu, + global_grad_norm=all_groups_norm) self.optimizer.step() From 27e5b9564912f2e615f2927ea267d4e6aa1c7cd0 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Tue, 1 Mar 2022 11:04:28 +0500 Subject: [PATCH 09/29] bf16 gradient clipping fix bf16 checkpoint save/load --- deepspeed/checkpoint/constants.py | 2 + deepspeed/runtime/bf16_optimizer.py | 71 +++++++++++++++++++++---- deepspeed/runtime/engine.py | 24 +++++---- deepspeed/runtime/utils.py | 4 +- deepspeed/runtime/zero/stage_1_and_2.py | 37 +------------ 5 files changed, 81 insertions(+), 57 deletions(-) diff --git a/deepspeed/checkpoint/constants.py b/deepspeed/checkpoint/constants.py index 0162bf6f27d3..3905afa6fe97 100644 --- a/deepspeed/checkpoint/constants.py +++ b/deepspeed/checkpoint/constants.py @@ -11,9 +11,11 @@ BASE_OPTIMIZER_STATE = 'base_optimizer_state' SINGLE_PARTITION_OF_FP32_GROUPS = "single_partition_of_fp32_groups" +GROUPS_PADDING = 'groups_padding' PARTITION_COUNT = 'partition_count' ZERO_STAGE = 'zero_stage' +CLIP_GRAD = 'clip_gradient' ######################################### # Module checkpoint keys diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index f25f1ae780e2..696fd3929ebb 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -2,7 +2,9 @@ import torch.distributed as dist from deepspeed.runtime.constants import PIPE_REPLICATED from deepspeed.ops.op_builder import UtilsBuilder +from packaging import version as pkg_version +from deepspeed.git_version_info import version from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, get_grad_norm, @@ -13,6 +15,13 @@ is_model_parallel_parameter, see_memory_usage) +from deepspeed.checkpoint.constants import (DS_VERSION, + PARTITION_COUNT, + BASE_OPTIMIZER_STATE, + SINGLE_PARTITION_OF_FP32_GROUPS, + CLIP_GRAD, + GROUPS_PADDING) + class BF16_Optimizer: def __init__(self, @@ -36,6 +45,10 @@ def __init__(self, self.real_dp_process_group = [ dp_process_group for i in range(len(self.optimizer.param_groups)) ] + dp_world_size = dist.get_world_size(group=self.dp_process_group) + self.partition_count = [ + dp_world_size for i in range(len(self.optimizer.param_groups)) + ] # Load pre-built or JIT compile (un)flatten ops util_ops = UtilsBuilder().load() @@ -58,9 +71,9 @@ def __init__(self, self.fp32_groups_actual_gradients_flat = [] self.fp32_groups_gradient_flat_partition = [] self.fp32_groups_has_gradients = [] - self.step_count = 0 - dp_world_size = dist.get_world_size(group=self.dp_process_group) + self.step_count = 0 + self.groups_padding = [] for i, param_group in enumerate(self.optimizer.param_groups): see_memory_usage(f'before initializing group {i}', force=True) @@ -127,6 +140,15 @@ def __init__(self, # track fp32 gradient updates self.fp32_groups_has_gradients.append([False] * len(self.bf16_groups[i])) + # Record padding required for alignment + if partition_id == dist.get_world_size( + group=self.real_dp_process_group[i]) - 1: + padding = self.bf16_groups_flat[i].numel() - length_without_padding + else: + padding = 0 + + self.groups_padding.append(padding) + # update optimizer param groups to reference fp32 params partition param_group['params'] = [self.fp32_groups_flat_partition[i]] @@ -186,8 +208,8 @@ def 
step(self, closure=None): if self.clip_grad > 0.: clip_tensors_by_global_norm(input_tensors=self.get_grads_for_norm(), max_norm=self.clip_grad, - mpu=self.mpu, - global_grad_norm=all_groups_norm) + global_norm=all_groups_norm, + mpu=self.mpu) self.optimizer.step() @@ -278,18 +300,47 @@ def clear_lp_grads(self): param.grad = None def state_dict(self): - # TODO capture all training state for checkpointing state_dict = {} - state_dict['optimizer_state_dict'] = self.optimizer.state_dict() - state_dict['clip_grad'] = self.clip_grad + state_dict[CLIP_GRAD] = self.clip_grad + state_dict[BASE_OPTIMIZER_STATE] = self.optimizer.state_dict() + state_dict[SINGLE_PARTITION_OF_FP32_GROUPS] = self.fp32_groups_flat_partition + state_dict[GROUPS_PADDING] = self.groups_padding + state_dict[PARTITION_COUNT] = self.partition_count + state_dict[DS_VERSION] = version + return state_dict - def load_state_dict(self, state_dict, load_optimizer_states=True): + def load_state_dict(self, + state_dict_list, + load_optimizer_states=True, + load_from_fp32_weights=False): + dp_rank = dist.get_rank(group=self.dp_process_group) + current_rank_sd = state_dict_list[dp_rank] + + ckpt_version = current_rank_sd.get(DS_VERSION, False) + assert ckpt_version, f"Empty ds_version in checkpoint, not clear how to proceed" + ckpt_version = pkg_version.parse(ckpt_version) + + self.clip_grad = current_rank_sd[CLIP_GRAD] + if load_optimizer_states: - self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) - self.clip_grad = state_dict['clip_grad'] + self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE]) + + if load_from_fp32_weights: + for current, saved in zip(self.fp32_groups_flat_partition, current_rank_sd[SINGLE_PARTITION_OF_FP32_GROUPS]): + src_tensor = _get_padded_tensor(saved, current.numel()) + current.data.copy_(src_tensor.data) @property def param_groups(self): """Forward the wrapped optimizer's parameters.""" return self.optimizer.param_groups + + +def _get_padded_tensor(src_tensor, size): + if src_tensor.numel() >= size: + return src_tensor + padded_tensor = torch.zeros(size, dtype=src_tensor.dtype, device=src_tensor.device) + slice_tensor = torch.narrow(padded_tensor, 0, 0, src_tensor.numel()) + slice_tensor.data.copy_(src_tensor.data) + return padded_tensor diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 6fce09f3827b..a2ebe7b7b681 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -807,7 +807,7 @@ def _configure_checkpointing(self, dist_init_required): self.save_non_zero_checkpoint = ( dp_rank == 0) or self.zero_optimization_partition_weights() - if self.zero_optimization(): + if self.zero_optimization() or self.bfloat16_enabled(): param_rank = torch.distributed.get_rank( group=self.optimizer.dp_process_group) @@ -2370,7 +2370,8 @@ def load_module_state_dict(self, state_dict, strict=True): self.module.load_state_dict(state_dict, strict=strict) def _get_rank_zero_ckpt_name(self, checkpoints_path, tag, mp_rank, dp_rank): - filename = "zero_pp_rank_{}".format(dp_rank) + filename = "bf16_zero_pp_rank_{}".format( + dp_rank) if self.bfloat16_enabled() else "zero_pp_rank_{}".format(dp_rank) zero_ckpt_name = os.path.join( checkpoints_path, str(tag), @@ -2495,7 +2496,8 @@ def load_checkpoint(self, load_lr_scheduler_states=load_lr_scheduler_states, load_module_only=load_module_only) - if self.zero_optimization() and load_path is not None: + load_zero_checkpoint = self.zero_optimization() or self.bfloat16_enabled() + if load_zero_checkpoint and load_path 
is not None: success = self._load_zero_checkpoint( load_dir, tag, @@ -2567,8 +2569,9 @@ def _load_checkpoint(self, else: optim_checkpoint = checkpoint - if load_optimizer_states and self.optimizer is not None and not self.zero_optimization( - ): + has_zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled( + ) + if load_optimizer_states and self.optimizer is not None and not has_zero_optimizer_state: if self.fp16_enabled(): self.optimizer.load_state_dict( optim_checkpoint['optimizer'], @@ -2964,13 +2967,13 @@ def _save_checkpoint(self, save_dir, tag, client_state={}): # module_state_dict() and uses this path to save the model. module_state_dict() # then instead just returns None. self._curr_ckpt_path = os.path.join(save_dir, tag) - + zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled() state = dict(module=self.module_state_dict(), buffer_names=self._get_buffer_names(), optimizer=self.optimizer.state_dict() - if self.optimizer and not self.zero_optimization() else None, + if self.optimizer and not zero_optimizer_state else None, param_shapes=self._get_zero_param_shapes() - if self.optimizer and self.zero_optimization() else None, + if self.optimizer and zero_optimizer_state else None, lr_scheduler=self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None, sparse_tensor_module_names=self.sparse_tensor_module_names, @@ -3028,6 +3031,8 @@ def _get_zero_param_shapes(self): # if we don't use it, we get parameters ordered incorrectly if hasattr(self.optimizer, "round_robin_bit16_groups"): bit16_groups = self.optimizer.round_robin_bit16_groups + elif self.bfloat16_enabled() and not self.zero_optimization(): + bit16_groups = self.optimizer.bf16_groups else: bit16_groups = self.optimizer.bit16_groups if self.zero_optimization_stage( ) == 2 else self.optimizer.fp16_groups @@ -3068,7 +3073,8 @@ def _save_zero_checkpoint(self, save_path, tag): torch.save(zero_sd, zero_checkpoint_name) if self.global_rank == 0: self._copy_recovery_script(save_path) - logger.info('zero checkpoint saved {}'.format(zero_checkpoint_name)) + ckpt_type = 'zero' if self.zero_optimization() else 'bfl6_zero' + logger.info(f'{ckpt_type} checkpoint saved {zero_checkpoint_name}') def _zero3_consolidated_16bit_state_dict(self): """ diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 7642301e289a..3bb7559f357c 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -944,7 +944,7 @@ def clip_tensors_by_global_norm(input_tensors, """Clip list of tensors by global norm. Args: input_tensors: List of tensors to be clipped - global_grad_norm (float, optional): Precomputed norm. Defaults to None. + global_norm (float, optional): Precomputed norm. Defaults to None. mpu (optional): model parallelism unit. Defaults to None. eps (float, optional): epsilon value added to grad norm. 
Defaults to 1e-6 Returns: @@ -953,7 +953,7 @@ def clip_tensors_by_global_norm(input_tensors, if global_norm is None: global_norm = get_global_norm_of_tensors(input_tensors, mpu=mpu) - clip_coef = max_norm / (global_grad_norm + eps) + clip_coef = max_norm / (global_norm + eps) if clip_coef < 1: for t in input_tensors: diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index b58d3ed32ad6..15ee39eb5cb4 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -23,7 +23,7 @@ from deepspeed.utils import logger from deepspeed.moe.utils import is_moe_param from deepspeed.git_version_info import version - +from deepspeed.runtime.constants import PIPE_REPLICATED from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, SINGLE_PARTITION_OF_FP32_GROUPS, @@ -1747,41 +1747,6 @@ def step(self, closure=None): start_alignment_factor=self.nccl_start_alignment_factor, allgather_bucket_size=self.allgather_bucket_size) - # for group_id, partitioned_params in enumerate(self.parallel_partitioned_bit16_groups): - # - # # Sequential AllGather Best of both worlds - # dp_world_size = dist.get_world_size( - # group=self.real_dp_process_group[group_id]) - # num_shards = max( - # 1, - # partitioned_params[partition_id].numel() * dp_world_size // - # self.allgather_bucket_size) - # - # shard_size = partitioned_params[partition_id].numel() // num_shards - # - # # Enforce nccl/rccl alignment of start location of each shard - # shard_size = shard_size - (shard_size % self.nccl_start_alignment_factor) - # - # num_elements = shard_size - # - # assert shard_size * num_shards <= partitioned_params[partition_id].numel() - # - # for shard_id in range(num_shards): - # - # if shard_id == (num_shards - 1): - # num_elements = partitioned_params[partition_id].numel( - # ) - shard_id * shard_size - # - # shard_list = [] - # for dp_id in range(dp_world_size): - # curr_shard = partitioned_params[dp_id].narrow( - # 0, - # shard_id * shard_size, - # num_elements).detach() - # shard_list.append(curr_shard) - # dist.all_gather(shard_list, - # shard_list[partition_id], - # group=self.real_dp_process_group[group_id]) self.stop_timers([OPTIMIZER_ALLGATHER]) # TODO: we probably don't need this? 
just to be safe From f4977024900f115a044cf82811388e1e024afcca Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Wed, 2 Mar 2022 03:55:11 +0500 Subject: [PATCH 10/29] Unscale grad norm --- deepspeed/runtime/zero/stage_1_and_2.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 15ee39eb5cb4..6f518f2f77e6 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -1707,8 +1707,13 @@ def step(self, closure=None): if self.has_moe_layers: self._average_expert_grad_norms(norm_groups) - self._global_grad_norm = get_global_norm(norm_list=norm_groups) - self.unscale_and_clip_grads(single_partition_grad_groups, self._global_grad_norm) + scaled_global_grad_norm = get_global_norm(norm_list=norm_groups) + self.unscale_and_clip_grads(single_partition_grad_groups, + scaled_global_grad_norm) + + # Stash unscaled gradient norm + self._global_grad_norm = scaled_global_grad_norm / self.loss_scale + self.stop_timers([OPTIMIZER_GRADIENTS]) self.start_timers([OPTIMIZER_STEP]) From 0ad7c7d328fddfffeee4020112c84cd4ba367e73 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Wed, 2 Mar 2022 05:14:49 +0500 Subject: [PATCH 11/29] Fix grad norm scaling --- deepspeed/runtime/zero/stage3.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index a118dad15a0e..fef2a9daa25a 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -2614,7 +2614,10 @@ def step(self, closure=None): return norm_groups = self._get_norm_groups() - self._global_grad_norm = get_global_norm(norm_list=norm_groups) + scaled_global_grad_norm = get_global_norm(norm_list=norm_groups) + + # Stash unscaled gradient norm + self._global_grad_norm = scaled_global_grad_norm / self.loss_scale timer_names = set() @@ -2628,7 +2631,7 @@ def step(self, closure=None): self._prepare_sub_group(sub_group_id, timer_names) #scale the fp32 gradients - self.unscale_and_clip_grads(sub_group_id, self._global_grad_norm) + self.unscale_and_clip_grads(sub_group_id, scaled_global_grad_norm) #apply the optimizer step on the sub group and copy fp32 parameters to fp16 self._optimizer_step(sub_group_id) From b81d862feac4a176984c92cec9bf8a7dba9df26c Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Fri, 4 Mar 2022 07:52:32 +0500 Subject: [PATCH 12/29] Enable loading fp16_zero_1 into bf16_zero_1 engine and vice versa --- deepspeed/checkpoint/constants.py | 2 +- deepspeed/runtime/bf16_optimizer.py | 10 ++++ deepspeed/runtime/engine.py | 75 ++++++++++++++++--------- deepspeed/runtime/zero/stage_1_and_2.py | 10 +++- 4 files changed, 67 insertions(+), 30 deletions(-) diff --git a/deepspeed/checkpoint/constants.py b/deepspeed/checkpoint/constants.py index 3905afa6fe97..f45777025db7 100644 --- a/deepspeed/checkpoint/constants.py +++ b/deepspeed/checkpoint/constants.py @@ -15,7 +15,7 @@ PARTITION_COUNT = 'partition_count' ZERO_STAGE = 'zero_stage' -CLIP_GRAD = 'clip_gradient' +CLIP_GRAD = 'clip_grad' ######################################### # Module checkpoint keys diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 696fd3929ebb..d8ce41305a47 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -310,6 +310,16 @@ def state_dict(self): return state_dict + # Restore base optimizer fp32 weights bfloat16 weights + def 
_restore_from_bit16_weights(self): + for i, group in enumerate(self.bf16_groups): + partition_id = dist.get_rank(group=self.real_dp_process_group[i]) + for bf16_partitions, fp32_partition in zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition): + fp32_partition.data.copy_(bf16_partitions[partition_id].data) + + def refresh_fp32_params(self): + self._restore_from_bit16_weights() + def load_state_dict(self, state_dict_list, load_optimizer_states=True, diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index a2ebe7b7b681..c5c07066a64c 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -41,7 +41,7 @@ from deepspeed.runtime.dataloader import DeepSpeedDataLoader from deepspeed.runtime.constants import \ ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \ - PLD_THETA, PLD_GAMMA + PLD_THETA, PLD_GAMMA, BFLOAT16, FP16 from deepspeed.runtime.zero.constants import \ ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS, ZERO_OPTIMIZATION_WEIGHTS from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT @@ -2369,20 +2369,32 @@ def load_moe_state_dict(checkpoint_path, def load_module_state_dict(self, state_dict, strict=True): self.module.load_state_dict(state_dict, strict=strict) - def _get_rank_zero_ckpt_name(self, checkpoints_path, tag, mp_rank, dp_rank): - filename = "bf16_zero_pp_rank_{}".format( - dp_rank) if self.bfloat16_enabled() else "zero_pp_rank_{}".format(dp_rank) + def _get_zero_ckpt_prefix(self, dp_rank, bf16_mode): + return f'bf16_zero_pp_rank_{dp_rank}' if bf16_mode else f'zero_pp_rank_{dp_rank}' + + def _get_rank_zero_ckpt_name(self, + checkpoints_path, + tag, + mp_rank, + dp_rank, + bf16_mode): + file_prefix = self._get_zero_ckpt_prefix(dp_rank, bf16_mode=bf16_mode) zero_ckpt_name = os.path.join( checkpoints_path, str(tag), - filename + "_mp_rank_{:02d}".format(mp_rank) + "_optim_states.pt", + file_prefix + "_mp_rank_{:02d}".format(mp_rank) + "_optim_states.pt", ) return zero_ckpt_name def _get_zero_ckpt_name(self, checkpoints_path, tag): mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank() pp_rank = torch.distributed.get_rank(group=self.optimizer.dp_process_group) - return self._get_rank_zero_ckpt_name(checkpoints_path, tag, mp_rank, pp_rank) + bf16_mode = self.bfloat16_enabled() + return self._get_rank_zero_ckpt_name(checkpoints_path, + tag, + mp_rank, + pp_rank, + bf16_mode) def _get_ckpt_name(self, checkpoints_path, tag, mp_placeholder=None): if mp_placeholder is not None: @@ -2670,41 +2682,31 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True): ) return True - def _get_mp_rank_zero_checkpoint_names(self, load_dir, tag, mp_rank, dp_world_size): + def _get_mp_rank_zero_checkpoint_names(self, + load_dir, + tag, + mp_rank, + dp_world_size, + bf16_mode): zero_ckpt_names = [] for dp_rank in range(dp_world_size): ckpt_name = self._get_rank_zero_ckpt_name(checkpoints_path=load_dir, tag=tag, mp_rank=mp_rank, - dp_rank=dp_rank) + dp_rank=dp_rank, + bf16_mode=bf16_mode) zero_ckpt_names.append(ckpt_name) return zero_ckpt_names - def _get_all_zero_checkpoint_names(self, - load_dir, - tag, - mp_world_size, - dp_world_size): - zero_ckpt_names = [] - for mp_rank in range(mp_world_size): - mp_rank_ckpt_names = self._get_mp_rank_zero_checkpoint_names( - load_dir=load_dir, - tag=tag, - mp_rank=mp_rank, - dp_world_size=dp_world_size) - zero_ckpt_names += mp_rank_ckpt_names - - return zero_ckpt_names - - def _get_all_zero_checkpoints(self, load_dir, tag): + def 
_get_all_zero_checkpoint_names(self, load_dir, tag, bf16_mode): mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank() zero_ckpt_names = self._get_mp_rank_zero_checkpoint_names( load_dir=load_dir, tag=tag, mp_rank=mp_rank, dp_world_size=self.loaded_checkpoint_dp_world_size, - ) + bf16_mode=bf16_mode) invalid_zero_ckpt_paths = [] for i, ckpt_name in enumerate(zero_ckpt_names): if not os.path.exists(ckpt_name): @@ -2723,6 +2725,9 @@ def _get_all_zero_checkpoints(self, load_dir, tag): ) return None + return zero_ckpt_names + + def _get_all_zero_checkpoint_state_dicts(self, zero_ckpt_names): zero_sd_list = [] for i, ckpt_name in enumerate(zero_ckpt_names): _state = None @@ -2740,6 +2745,24 @@ def _get_all_zero_checkpoints(self, load_dir, tag): ) return zero_optimizer_sd + def _get_all_zero_checkpoints(self, load_dir, tag): + for bf16_mode in [self.bfloat16_enabled(), not self.bfloat16_enabled()]: + zero_ckpt_names = self._get_all_zero_checkpoint_names( + load_dir, + tag, + bf16_mode) + if zero_ckpt_names is not None: + # Warn if loading checkpoint of different bit16 type + if bf16_mode is not self.bfloat16_enabled(): + checkpoint_bit16 = BFLOAT16 if bf16_mode else FP16 + engine_bit16 = BFLOAT16 if self.bfloat16_enabled() else FP16 + logger.warn( + f'Loading {checkpoint_bit16} zero checkpoints into {engine_bit16} training engine' + ) + return self._get_all_zero_checkpoint_state_dicts(zero_ckpt_names) + + return None + def _checkpoint_tag_validation(self, tag): if self.checkpoint_tag_validation_enabled(): s_hash = hashlib.sha1(tag.encode()) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 6f518f2f77e6..9a35b1abd2fb 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -28,6 +28,7 @@ PARTITION_COUNT, SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE, + CLIP_GRAD, ZERO_STAGE) # Toggle this to true to enable correctness test @@ -1971,6 +1972,7 @@ def state_dict(self): state_dict['loss_scaler'] = self.loss_scaler state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale state_dict['overflow'] = self.overflow + state_dict[CLIP_GRAD] = self.clip_grad if self.elastic_checkpoint: state_dict[BASE_OPTIMIZER_STATE] = self._get_base_optimizer_state() @@ -2125,9 +2127,11 @@ def load_state_dict(self, # I think it should actually be ok to reload the optimizer before the model. 
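
The probing logic above is what makes fp16-ZeRO-1 and bf16-ZeRO-1 checkpoints interchangeable: the engine first looks for files matching its own precision's naming scheme and, failing that, falls back to the other flavor and warns. A condensed, file-system-only sketch of that search, with the per-rank error reporting simplified away:

    import os

    def zero_ckpt_names(load_dir, tag, mp_rank, dp_world_size, bf16_mode):
        # bf16 engines write bf16_zero_pp_rank_*; fp16/fp32 ZeRO engines write zero_pp_rank_*.
        prefix = 'bf16_zero_pp_rank_{}' if bf16_mode else 'zero_pp_rank_{}'
        return [os.path.join(load_dir, str(tag),
                             prefix.format(dp_rank) +
                             '_mp_rank_{:02d}_optim_states.pt'.format(mp_rank))
                for dp_rank in range(dp_world_size)]

    def find_zero_ckpts(load_dir, tag, mp_rank, dp_world_size, engine_is_bf16):
        # Try the engine's own precision first, then fall back to the other one.
        for bf16_mode in (engine_is_bf16, not engine_is_bf16):
            names = zero_ckpt_names(load_dir, tag, mp_rank, dp_world_size, bf16_mode)
            if all(os.path.exists(n) for n in names):
                return names, bf16_mode
        return None, None

Loading the other flavor only works because the load_state_dict paths in this patch and the next also tolerate keys that just one of the two formats writes, such as loss-scaler state.
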
dp_rank = dist.get_rank(group=self.dp_process_group) current_rank_sd = state_dict_list[dp_rank] - self.loss_scaler = current_rank_sd['loss_scaler'] - self.dynamic_loss_scale = current_rank_sd['dynamic_loss_scale'] - self.overflow = current_rank_sd['overflow'] + self.loss_scaler = current_rank_sd.get('loss_scaler', self.loss_scaler) + self.dynamic_loss_scale = current_rank_sd.get('dynamic_loss_scale', + self.dynamic_loss_scale) + self.overflow = current_rank_sd.get('overflow', self.overflow) + self.clip_grad = current_rank_sd[CLIP_GRAD] ckpt_version = current_rank_sd.get(DS_VERSION, False) assert ckpt_version, f"Empty ds_version in checkpoint, not clear how to proceed" From 35ea3808faec4c40a1c94d3acecb30afc9c700c6 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Sat, 5 Mar 2022 01:03:08 +0500 Subject: [PATCH 13/29] Fix clip_grad key error --- deepspeed/runtime/bf16_optimizer.py | 2 +- deepspeed/runtime/fp16/fused_optimizer.py | 6 +++--- deepspeed/runtime/zero/stage_1_and_2.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index d8ce41305a47..c6d082e06676 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -331,7 +331,7 @@ def load_state_dict(self, assert ckpt_version, f"Empty ds_version in checkpoint, not clear how to proceed" ckpt_version = pkg_version.parse(ckpt_version) - self.clip_grad = current_rank_sd[CLIP_GRAD] + self.clip_grad = current_rank_sd.get(CLIP_GRAD, self.clip_grad) if load_optimizer_states: self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE]) diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py index 7944a4b8b925..44ce6a1a7d93 100755 --- a/deepspeed/runtime/fp16/fused_optimizer.py +++ b/deepspeed/runtime/fp16/fused_optimizer.py @@ -11,7 +11,7 @@ from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow, get_weight_norm from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE from deepspeed.utils import groups, logger, log_dist -from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT +from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, CLIP_GRAD import torch.distributed as dist @@ -418,7 +418,7 @@ def state_dict(self): state_dict['scale_window'] = self.scale_window state_dict[OPTIMIZER_STATE_DICT] = self.optimizer.state_dict() state_dict['fp32_groups_flat'] = self.fp32_groups_flat - state_dict['clip_grad'] = self.clip_grad + state_dict[CLIP_GRAD] = self.clip_grad return state_dict # Refresh fp32 master params from fp16 copies @@ -452,7 +452,7 @@ def load_state_dict(self, state_dict, load_optimizer_states=True): self.scale_window = state_dict['scale_window'] if load_optimizer_states: self.optimizer.load_state_dict(state_dict[OPTIMIZER_STATE_DICT]) - self.clip_grad = state_dict['clip_grad'] + self.clip_grad = state_dict[CLIP_GRAD] # At this point, the optimizer's references to the model's fp32 parameters are up to date. # The optimizer's hyperparameters and internal buffers are also up to date. 
# However, the fp32 master copies of the model's fp16 params stored by the optimizer are still diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 9a35b1abd2fb..b48e32eb8391 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -2131,7 +2131,7 @@ def load_state_dict(self, self.dynamic_loss_scale = current_rank_sd.get('dynamic_loss_scale', self.dynamic_loss_scale) self.overflow = current_rank_sd.get('overflow', self.overflow) - self.clip_grad = current_rank_sd[CLIP_GRAD] + self.clip_grad = current_rank_sd.get(CLIP_GRAD, self.clip_grad) ckpt_version = current_rank_sd.get(DS_VERSION, False) assert ckpt_version, f"Empty ds_version in checkpoint, not clear how to proceed" From 37011a92bad42b07c2cb742751873ef7073d84b8 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Sat, 5 Mar 2022 22:06:00 +0500 Subject: [PATCH 14/29] Reduce tied weight gradients --- deepspeed/runtime/engine.py | 2 +- deepspeed/runtime/pipe/engine.py | 15 +++++++-------- deepspeed/runtime/pipe/module.py | 7 +++++++ 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index c5c07066a64c..ce09a71aa809 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -3096,7 +3096,7 @@ def _save_zero_checkpoint(self, save_path, tag): torch.save(zero_sd, zero_checkpoint_name) if self.global_rank == 0: self._copy_recovery_script(save_path) - ckpt_type = 'zero' if self.zero_optimization() else 'bfl6_zero' + ckpt_type = 'zero' if self.zero_optimization() else 'bf16_zero' logger.info(f'{ckpt_type} checkpoint saved {zero_checkpoint_name}') def _zero3_consolidated_16bit_state_dict(self): diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 60bbb62335d3..f325e56701dd 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -241,7 +241,13 @@ def _exec_reduce_tied_grads(self): # (see https://github.com/EleutherAI/gpt-neox/issues/62#issuecomment-761471944) if self.zero_optimization_partition_gradients(): self.optimizer.overlapping_partition_gradients_reduce_epilogue() - self.module.allreduce_tied_weight_gradients() + + if self.bfloat16_enabled(): + weight_group_list = self.module.get_tied_weights_and_groups() + for weight, group in weight_group_list: + dist.all_reduce(weight._hp_grad, group=group) + else: + self.module.allreduce_tied_weight_gradients() def _exec_reduce_grads(self): self._force_grad_boundary = True @@ -259,13 +265,6 @@ def _exec_reduce_grads(self): def _bf16_reduce_grads(self): # Make our own list of gradients from the optimizer's FP32 grads grads = [] - # for param_group in self.optimizer.fp32_groups: - # for param in param_group: - # if param.grad is None: - # continue - # assert param.grad is not None - # assert param.grad.dtype == torch.float32 - # grads.append(param.grad.data) self.buffered_allreduce_fallback(grads=self.optimizer.get_grads_for_reduction(), elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE) diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py index c1b82028673d..3efbc62e2b8d 100644 --- a/deepspeed/runtime/pipe/module.py +++ b/deepspeed/runtime/pipe/module.py @@ -419,6 +419,13 @@ def allreduce_tied_weight_gradients(self): weight = getattr(self.tied_modules[key], comm['weight_attr']) dist.all_reduce(weight.grad, group=comm['group']) + def get_tied_weights_and_groups(self): + weight_group_list = [] + for key, comm in self.tied_comms.items(): + 
weight = getattr(self.tied_modules[key], comm['weight_attr']) + weight_group_list.append((weight, comm['group'])) + return weight_group_list + def _synchronize_tied_weights(self): for key, comm in self.tied_comms.items(): dist.broadcast( From 61d51fd62141ddb51b629b785af256fac407e048 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Sun, 6 Mar 2022 01:48:32 +0500 Subject: [PATCH 15/29] Fix grad norm for moe --- deepspeed/runtime/engine.py | 2 ++ deepspeed/runtime/fp16/fused_optimizer.py | 10 +++++----- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 468cbb26736e..959b921f68d1 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1235,6 +1235,7 @@ def _configure_fp16_optimizer(self, optimizer): clip_grad=clip_grad, fused_adam_legacy=self.optimizer_legacy_fusion(), timers=timers, + has_moe_layers=self.has_moe_layers, ) else: log_dist( @@ -1249,6 +1250,7 @@ def _configure_fp16_optimizer(self, optimizer): mpu=self.mpu, clip_grad=clip_grad, fused_adam_legacy=self.optimizer_legacy_fusion(), + has_moe_layers=self.has_moe_layers, ) else: log_dist("Creating fp16 unfused optimizer with dynamic loss scale", diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py index 7ad09b2c4643..dc52552aebba 100755 --- a/deepspeed/runtime/fp16/fused_optimizer.py +++ b/deepspeed/runtime/fp16/fused_optimizer.py @@ -269,9 +269,9 @@ def step(self, closure=None): self.stop_timers([COMPUTE_NORM]) if self.has_moe_layers: - scaled_global_grad_norm = self._get_norm_with_moe_layers(all_groups_norm) - else: - scaled_global_grad_norm = get_global_norm(norm_list=[all_groups_norm]) + all_groups_norm = self._get_norm_with_moe_layers(all_groups_norm) + + scaled_global_grad_norm = get_global_norm(norm_list=[all_groups_norm]) # Stash unscaled gradient norm self._global_grad_norm = scaled_global_grad_norm / self.cur_scale @@ -310,10 +310,10 @@ def _get_norm_with_moe_layers(self, all_groups_norm): if self.using_pipeline: pg = self.deepspeed.mpu.get_data_parallel_group() else: - pg = groups.get_data_parallel_group() + pg = groups._get_data_parallel_group() scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=pg)) scaled_norm_tensor = torch.tensor(scaled_norm, - device=self.fp32_groups_flat[i].device, + device=self.fp32_groups_flat[0].device, dtype=torch.float) dist.all_reduce(scaled_norm_tensor, group=pg) all_groups_norm = scaled_norm_tensor.item() From de3616caba6a9634157b6d985d7b5bd30ef9184c Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Wed, 9 Mar 2022 03:58:07 +0500 Subject: [PATCH 16/29] Reduce specified gradients --- deepspeed/runtime/engine.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 151261340f69..fb59b9b0fbb8 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2137,7 +2137,11 @@ def _reduce_expert_gradients(self, expert_grads, elements_per_buffer): numel_per_bucket=elements_per_buffer) def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000): - non_expert_grads, expert_grads = self._get_gradients_for_reduction() + if grads is None: + non_expert_grads, expert_grads = self._get_gradients_for_reduction() + else: + assert not self.has_moe_layers, "attempting to reduce grads in unsupported way w.r.t. 
MoE" + non_expert_grads = grads self._reduce_non_expert_gradients(non_expert_grads, elements_per_buffer) From ab61edb02a137d91b61bd416b4e8d3eb287b0eba Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Wed, 9 Mar 2022 18:55:16 +0500 Subject: [PATCH 17/29] Use O(n) instead of O(n^2) --- deepspeed/runtime/bf16_optimizer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index c6d082e06676..9e8351323580 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -282,10 +282,9 @@ def get_grads_for_norm(self): @torch.no_grad() def update_lp_params(self): - for i, group in enumerate(self.bf16_groups): + for i, (bf16_partitions, fp32_partition) in enumerate(zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition)): partition_id = dist.get_rank(group=self.real_dp_process_group[i]) - for bf16_partitions, fp32_partition in zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition): - bf16_partitions[partition_id].data.copy_(fp32_partition.data) + bf16_partitions[partition_id].data.copy_(fp32_partition.data) def clear_hp_grads(self): for flat_gradients in self.fp32_groups_gradients_flat: From b7d64fd78610a0d6b8464f63bf87c7300453e007 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Wed, 9 Mar 2022 19:55:01 +0500 Subject: [PATCH 18/29] Remove optimizer restriction for bf16 --- deepspeed/runtime/engine.py | 28 +++++++++++----------------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 5783022bcede..39240a3b6a79 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1270,23 +1270,17 @@ def _configure_fp16_optimizer(self, optimizer): def _configure_bf16_optimizer(self, optimizer): clip_grad = self.gradient_clipping() - if APEX_INSTALLED: - fused_opts = (apex.optimizers.FusedAdam, FusedAdam) - else: - fused_opts = FusedAdam - if isinstance(optimizer, fused_opts): - if self.global_rank == 0: - logger.info('Creating unfused BF16 optimizer') - timers = self.timers if self.wall_clock_breakdown() else None - optimizer = BF16_Optimizer( - optimizer, - mpu=self.mpu, - clip_grad=clip_grad, - allgather_bucket_size=self.zero_allgather_bucket_size(), - dp_process_group=self.data_parallel_group, - timers=timers) - else: - raise NotImplementedError('BF16 requires a fused optimizer for now.') + + if self.global_rank == 0: + logger.info('Creating unfused BF16 optimizer') + timers = self.timers if self.wall_clock_breakdown() else None + optimizer = BF16_Optimizer( + optimizer, + mpu=self.mpu, + clip_grad=clip_grad, + allgather_bucket_size=self.zero_allgather_bucket_size(), + dp_process_group=self.data_parallel_group, + timers=timers) return optimizer From 19198688a3e66f5b5a1ab3017ac742b2a0038175 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Fri, 11 Mar 2022 01:49:03 +0500 Subject: [PATCH 19/29] Link bf16 & fp32 params --- deepspeed/runtime/bf16_optimizer.py | 32 +++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 9e8351323580..094859bd79fa 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -109,6 +109,12 @@ def __init__(self, bf16_dp_partitions[partition_id].clone().float().detach()) self.fp32_groups_flat_partition[i].requires_grad = True + # Link bf16 and fp32 params in partition + 
self._link_hp_params(self.bf16_groups[i], + self.fp32_groups_flat_partition[i], + partition_id * partition_size, + partition_size) + num_elem_list = [t.numel() for t in self.bf16_groups[i]] # create fp32 gradients @@ -160,6 +166,32 @@ def __init__(self, see_memory_usage('end bf16_optimizer', force=True) + def _link_hp_params(self, + lp_param_list, + flat_hp_partition, + partition_start, + partition_size): + current_offset = 0 + partition_end = partition_start + partition_size + for lp_param in lp_param_list: + # lp_param does not overlap with partition if either is true + # 1) current_offset >= partition_end, i.e belongs to later partition + # 2) current_offset + lp_param.numel() < partition_start, i.e., belongs to earlier partition + tensor_end = current_offset + lp_param.numel() + if (current_offset >= partition_end) or (tensor_end < partition_start): + lp_param._hp_param = None + continue + + narrow_offset = max(current_offset, partition_start) - partition_start + narrow_elem = min(tensor_end, + partition_end) - max(current_offset, + partition_start) + + hp_param = flat_hp_partition.narrow(0, narrow_offset, narrow_elem) + + lp_param._hp_param = hp_param + current_offset += lp_param.numel() + def initialize_optimizer_states(self): """Take an optimizer step with zero-valued gradients to allocate internal optimizer state. From 77b649d160c1cd86f33415e2a7deab50c45fba16 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Fri, 11 Mar 2022 09:20:43 +0500 Subject: [PATCH 20/29] Clip gradients of last stage tied weights --- deepspeed/runtime/bf16_optimizer.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 094859bd79fa..08daa1ffd7ab 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -238,10 +238,11 @@ def step(self, closure=None): assert all_groups_norm > 0. 
if self.clip_grad > 0.: - clip_tensors_by_global_norm(input_tensors=self.get_grads_for_norm(), - max_norm=self.clip_grad, - global_norm=all_groups_norm, - mpu=self.mpu) + clip_tensors_by_global_norm( + input_tensors=self.get_grads_for_norm(for_clipping=True), + max_norm=self.clip_grad, + global_norm=all_groups_norm, + mpu=self.mpu) self.optimizer.step() @@ -294,13 +295,14 @@ def get_grads_for_reduction(self): return self.fp32_groups_gradients_flat @torch.no_grad() - def get_grads_for_norm(self): + def get_grads_for_norm(self, for_clipping=False): grads = [] tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu) for i, group in enumerate(self.bf16_groups): for j, lp in enumerate(group): - if hasattr(lp, PIPE_REPLICATED) and lp.ds_pipe_replicated: - continue + if not for_clipping: + if hasattr(lp, PIPE_REPLICATED) and lp.ds_pipe_replicated: + continue if (tensor_mp_rank > 0) and not is_model_parallel_parameter(lp): continue From 2aa612a6e823fe12a2fa1f224299a78cb3c47a17 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Thu, 17 Mar 2022 18:01:00 +0500 Subject: [PATCH 21/29] Simplify tied weights reduction logic --- deepspeed/runtime/engine.py | 3 ++- deepspeed/runtime/pipe/engine.py | 10 ++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index cbfacb140a6a..71d72bd39530 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -971,7 +971,8 @@ def _configure_distributed_model(self, model): self.__check_params(self.module, torch.bfloat16) if self.zero_optimization_stage() == 0 and not self.pipeline_parallelism: raise NotImplementedError( - "BF16 support is not yet implemented when not running ZeRO") + "When not running ZeRO, BF16 training support is only supported for Pipeline parallelism" + ) self.module.bfloat16() else: self.__check_params(self.module, torch.float) diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index f325e56701dd..2a96ef897d01 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -242,12 +242,10 @@ def _exec_reduce_tied_grads(self): if self.zero_optimization_partition_gradients(): self.optimizer.overlapping_partition_gradients_reduce_epilogue() - if self.bfloat16_enabled(): - weight_group_list = self.module.get_tied_weights_and_groups() - for weight, group in weight_group_list: - dist.all_reduce(weight._hp_grad, group=group) - else: - self.module.allreduce_tied_weight_gradients() + weight_group_list = self.module.get_tied_weights_and_groups() + for weight, group in weight_group_list: + grad = weight._hp_grad if self.bfloat16_enabled() else weight.grad + dist.all_reduce(grad, group=group) def _exec_reduce_grads(self): self._force_grad_boundary = True From e24814a10de04ce280efe2adb027b023e3336493 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Thu, 24 Mar 2022 21:17:17 +0000 Subject: [PATCH 22/29] Also clip all tp rank parameters --- deepspeed/runtime/bf16_optimizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 08daa1ffd7ab..861f6adaa408 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -304,8 +304,8 @@ def get_grads_for_norm(self, for_clipping=False): if hasattr(lp, PIPE_REPLICATED) and lp.ds_pipe_replicated: continue - if (tensor_mp_rank > 0) and not is_model_parallel_parameter(lp): - continue + if not (tensor_mp_rank == 0 or 
is_model_parallel_parameter(p)): + continue if not self.fp32_groups_has_gradients[i][j]: continue From 20697bc4a949b0bd376a5e2155e62c8bd2074904 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Sat, 26 Mar 2022 02:02:48 +0500 Subject: [PATCH 23/29] lp to hp mapping --- deepspeed/runtime/bf16_optimizer.py | 221 +++++++++++++++++++++++++--- 1 file changed, 201 insertions(+), 20 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 861f6adaa408..38d2dedd487e 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -110,6 +110,7 @@ def __init__(self, self.fp32_groups_flat_partition[i].requires_grad = True # Link bf16 and fp32 params in partition + # TODO: Make this configurable self._link_hp_params(self.bf16_groups[i], self.fp32_groups_flat_partition[i], partition_id * partition_size, @@ -166,31 +167,128 @@ def __init__(self, see_memory_usage('end bf16_optimizer', force=True) + def _init_lp_to_hp_mapping(self, lp_param_list, partition_start, partition_size): + current_offset = 0 + param_and_offset_list = [] + partition_end = partition_start + partition_size + for lp_param in lp_param_list: + lp_param._hp_mapping = None + # lp_param overlaps with partition if both are true + # 1) current_offset < partition_end, + # 2) current_offset + lp_param.numel() >= partition_start + lp_param_end = current_offset + lp_param.numel() + if current_offset < partition_end and lp_param_end >= partition_start: + param_and_offset_list.append((lp_param, current_offset)) + current_offset += lp_param.numel() + + return param_and_offset_list, + + +# def _link_hp_params_0(self, +# lp_param_list, +# flat_hp_partition, +# partition_start, +# partition_size): +# +# from dataclasses import dataclass +# @dataclass +# class fragment_address: +# numel: int +# start: int +# end: int +# +# current_offset = 0 +# partition_end = partition_start + partition_size +# for lp_param in lp_param_list: +# # lp_param does not overlap with partition if either is true +# # 1) current_offset >= partition_end, i.e belongs to later partition +# # 2) current_offset + lp_param.numel() < partition_start, i.e., belongs to earlier partition +# lp_param_end = current_offset + lp_param.numel() +# if (current_offset >= partition_end) or (lp_param_end < partition_start): +# lp_param._hp_param = None +# continue +# +# lp_fragment_start = max(current_offset, partition_start) +# lp_fragment_end = min(lp_param_end, partition_end) +# lp_fragment_elem = lp_fragment_end - lp_fragment_start +# +# hp_fragment_start = lp_fragment_start - partition_start +# hp_frag_address = fragment_address( +# numel=lp_fragment_elem, +# start=hp_fragment_start, +# end=lp_fragment_end +# ) +# lp_param._hp_frag_address = hp_frag_address +# lp_param._hp_fragment = flat_hp_partition.narrow(0, hp_narrow_offset, lp_fragment_elem) +# +# if current_offset >= partition_start: +# lp_param._lp_frag_start = 0 +# else: +# lp_param._hp_copy_offset = partition_start - current_offset +# +# lp_param._lp_frag_address = fragment_address( +# numel=lp_fragment_elem, +# start= +# ) +# current_offset += lp_param.numel() + def _link_hp_params(self, lp_param_list, flat_hp_partition, partition_start, partition_size): - current_offset = 0 - partition_end = partition_start + partition_size - for lp_param in lp_param_list: - # lp_param does not overlap with partition if either is true - # 1) current_offset >= partition_end, i.e belongs to later partition - # 2) current_offset + lp_param.numel() < partition_start, 
i.e., belongs to earlier partition - tensor_end = current_offset + lp_param.numel() - if (current_offset >= partition_end) or (tensor_end < partition_start): - lp_param._hp_param = None - continue - - narrow_offset = max(current_offset, partition_start) - partition_start - narrow_elem = min(tensor_end, - partition_end) - max(current_offset, - partition_start) - - hp_param = flat_hp_partition.narrow(0, narrow_offset, narrow_elem) - - lp_param._hp_param = hp_param - current_offset += lp_param.numel() + + local_lp_param_and_offset = self._init_lp_to_hp_mapping( + lp_param_list, + partition_start, + partition_size) + + from dataclasses import dataclass + + @dataclass + class fragment_address: + numel: int + start: int + + @dataclass + class tensor_fragment: + lp_fragment: torch.Tensor + lp_fragment_address: fragment_address + hp_fragment: torch.Tensor + hp_fragment_address: fragment_address + + def update_hp(self): + self.hp_fragment.data.copy_(self.lp_fragment.data) + + def update_lp(self): + self.lp_fragment.data.copy_(self.hp_fragment.data) + + hp_end = partition_start + partition_size + for lp_param, lp_start in local_lp_param_and_offset: + lp_end = lp_param.numel() + lp_start + hp_start = partition_start + + fragment_start = max(lp_start, hp_start) + fragment_end = min(lp_end, hp_end) + assert fragment_start < fragment_end, \ + f'fragment start {fragment_start} should be < fragment_end {fragment_end}' + + fragment_numel = fragment_end - fragment_start + hp_frag_address = fragment_address(start=fragment_start - hp_start, + numel=fragment_numel) + hp_fragment_tensor = flat_hp_partition.narrow(0, + hp_frag_address.start, + hp_frag_address.numel) + + lp_frag_address = fragment_address(start=fragment_start - lp_start, + numel=fragment_numel) + lp_fragment_tensor = lp_param.flatten().narrow(0, + lp_frag_address.start, + lp_frag_address.numel) + lp_param._hp_mapping = tensor_fragment(lp_fragment=lp_fragment_tensor, + lp_fragment_address=lp_frag_address, + hp_fragment=hp_fragment_tensor, + hp_fragment_address=hp_frag_address) def initialize_optimizer_states(self): """Take an optimizer step with zero-valued gradients to allocate internal @@ -387,3 +485,86 @@ def _get_padded_tensor(src_tensor, size): slice_tensor = torch.narrow(padded_tensor, 0, 0, src_tensor.numel()) slice_tensor.data.copy_(src_tensor.data) return padded_tensor + + +''' +Logic for lp_param to hp_param mapping + +lp lp0 lp1 lp2 lp3 lp4 <------- indices/names +lp [ ][ ][ ][ ][ ] <-------- tensors +flat_lp [ ] <-------- flat lp params +flat_hp [ ] <------------------ flat hp partition on current rank +full_hp [ ] <------- full flat hp params + + +lp2 + full numel = 16 + lp_frag + numel = 12 + frag_start = 3 + frag_end = 15 + hp_frag + numel = 12 + frag_start = 0 + frag_end = 11 + + hp_frag.copy_(lp_frag) + + +lp3: + full numel = 4 + lp_frag + numel = 4 + start = 0 + end = 3 + hp_frag + numel = 4 + start = 12 + end = 15 + + +lp4: + full numel = 12 + lp_frag + numel = 4 + start = 0 + end = 3 + hp_frag + numel = 4 + start = 16 + end = 19 + + + +Visual depiction of above +lp { } +flat_lp [ ] +flat_hp ( ) + + +flat_lp [ { ( } ) ] + lx hx ly hy + ly-hx + + +lp { } +flat_lp [ ] +flat_hp ( ) + + +flat_lp [ ( { ) } ] + hx lx hy ly + hy-lx + +lp { } +flat_lp [ ] +flat_hp ( ) + + +flat_lp [ ( { } ) ] + hx lx ly hy + ly-lx + +lp -> (lx, hy) +flat_hp -> (hx, hy) +''' From 4e8f7fff9a1575f16191e16d2d161af4e6b52b51 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Tue, 29 Mar 2022 00:12:33 +0500 Subject: [PATCH 24/29] Link lp/hp/optim state; 
Refresh links after checkpoint load --- deepspeed/runtime/bf16_optimizer.py | 54 ++++++++++++++++++++++------- 1 file changed, 42 insertions(+), 12 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 38d2dedd487e..0e5d77bedd42 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -109,13 +109,6 @@ def __init__(self, bf16_dp_partitions[partition_id].clone().float().detach()) self.fp32_groups_flat_partition[i].requires_grad = True - # Link bf16 and fp32 params in partition - # TODO: Make this configurable - self._link_hp_params(self.bf16_groups[i], - self.fp32_groups_flat_partition[i], - partition_id * partition_size, - partition_size) - num_elem_list = [t.numel() for t in self.bf16_groups[i]] # create fp32 gradients @@ -165,8 +158,23 @@ def __init__(self, self.initialize_optimizer_states() see_memory_usage('end initialize_optimizer', force=True) + # Need optimizer states initialized before linking lp to optimizer state + self._link_all_hp_params() + see_memory_usage('end bf16_optimizer', force=True) + def _link_all_hp_params(self): + dp_world_size = dist.get_world_size(group=self.dp_process_group) + for i, param_group in enumerate(self.optimizer.param_groups): + # Link bf16 and fp32 params in partition + # TODO: Make this configurable + partition_id = dist.get_rank(group=self.real_dp_process_group[i]) + partition_size = self.bf16_groups_flat[i].numel() // dp_world_size + self._link_hp_params(self.bf16_groups[i], + self.fp32_groups_flat_partition[i], + partition_id * partition_size, + partition_size) + def _init_lp_to_hp_mapping(self, lp_param_list, partition_start, partition_size): current_offset = 0 param_and_offset_list = [] @@ -177,11 +185,11 @@ def _init_lp_to_hp_mapping(self, lp_param_list, partition_start, partition_size) # 1) current_offset < partition_end, # 2) current_offset + lp_param.numel() >= partition_start lp_param_end = current_offset + lp_param.numel() - if current_offset < partition_end and lp_param_end >= partition_start: + if current_offset < partition_end and lp_param_end > partition_start: param_and_offset_list.append((lp_param, current_offset)) current_offset += lp_param.numel() - return param_and_offset_list, + return param_and_offset_list # def _link_hp_params_0(self, @@ -237,7 +245,6 @@ def _link_hp_params(self, flat_hp_partition, partition_start, partition_size): - local_lp_param_and_offset = self._init_lp_to_hp_mapping( lp_param_list, partition_start, @@ -256,6 +263,7 @@ class tensor_fragment: lp_fragment_address: fragment_address hp_fragment: torch.Tensor hp_fragment_address: fragment_address + optim_fragment: {} def update_hp(self): self.hp_fragment.data.copy_(self.lp_fragment.data) @@ -263,6 +271,12 @@ def update_hp(self): def update_lp(self): self.lp_fragment.data.copy_(self.hp_fragment.data) + def get_optim_state(self, key): + if key in self.optim_fragment: + return self.optim_fragment[key] + else: + raise ValueError(f'{key} not found in optimizer state fragment') + hp_end = partition_start + partition_size for lp_param, lp_start in local_lp_param_and_offset: lp_end = lp_param.numel() + lp_start @@ -270,6 +284,9 @@ def update_lp(self): fragment_start = max(lp_start, hp_start) fragment_end = min(lp_end, hp_end) + print( + f'{self.dp_rank=} {lp_start=} {lp_end-lp_start=} {hp_start=} {hp_end-hp_start=} {fragment_start=} {fragment_end-fragment_start=}' + ) assert fragment_start < fragment_end, \ f'fragment start {fragment_start} should be < fragment_end {fragment_end}' 
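A minimal, self-contained sketch of the overlap arithmetic used by _link_hp_params: a flattened parameter's range is intersected with the flat fp32 partition owned by the current rank to obtain the lp and hp fragment offsets. The helper name compute_fragment and the concrete numbers are illustrative only and are not part of the patch.

def compute_fragment(lp_start, lp_numel, hp_start, hp_numel):
    # Intersect the param's flat range [lp_start, lp_end) with the rank's
    # partition range [hp_start, hp_end).
    lp_end = lp_start + lp_numel
    hp_end = hp_start + hp_numel
    frag_start = max(lp_start, hp_start)
    frag_end = min(lp_end, hp_end)
    if frag_start >= frag_end:
        return None  # the param does not overlap this partition
    frag_numel = frag_end - frag_start
    # Offsets into the flattened lp tensor and into the flat hp partition.
    return frag_start - lp_start, frag_start - hp_start, frag_numel

# Example: a 16-element parameter whose flat range starts at offset 0, against a
# 12-element partition starting at offset 3 -> the fragment covers 12 elements,
# beginning at offset 3 of the param and offset 0 of the partition.
assert compute_fragment(0, 16, 3, 12) == (3, 0, 12)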
@@ -280,15 +297,26 @@ def update_lp(self): hp_frag_address.start, hp_frag_address.numel) + optim_fragment = { + key: value.narrow(0, + hp_frag_address.start, + hp_frag_address.numel) + for key, + value in self.optimizer.state[flat_hp_partition].items() + if torch.is_tensor(value) + } + lp_frag_address = fragment_address(start=fragment_start - lp_start, numel=fragment_numel) lp_fragment_tensor = lp_param.flatten().narrow(0, lp_frag_address.start, lp_frag_address.numel) + lp_param._hp_mapping = tensor_fragment(lp_fragment=lp_fragment_tensor, lp_fragment_address=lp_frag_address, hp_fragment=hp_fragment_tensor, - hp_fragment_address=hp_frag_address) + hp_fragment_address=hp_frag_address, + optim_fragment=optim_fragment) def initialize_optimizer_states(self): """Take an optimizer step with zero-valued gradients to allocate internal @@ -402,7 +430,7 @@ def get_grads_for_norm(self, for_clipping=False): if hasattr(lp, PIPE_REPLICATED) and lp.ds_pipe_replicated: continue - if not (tensor_mp_rank == 0 or is_model_parallel_parameter(p)): + if not (tensor_mp_rank == 0 or is_model_parallel_parameter(lp)): continue if not self.fp32_groups_has_gradients[i][j]: @@ -472,6 +500,8 @@ def load_state_dict(self, src_tensor = _get_padded_tensor(saved, current.numel()) current.data.copy_(src_tensor.data) + self._link_all_hp_params() + @property def param_groups(self): """Forward the wrapped optimizer's parameters.""" From 5481b8648042fe6cbef0c9e800b6d15917ad0acb Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Tue, 29 Mar 2022 00:27:48 +0500 Subject: [PATCH 25/29] Remove debug print --- deepspeed/runtime/bf16_optimizer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 0e5d77bedd42..50fcf296cf69 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -284,9 +284,9 @@ def get_optim_state(self, key): fragment_start = max(lp_start, hp_start) fragment_end = min(lp_end, hp_end) - print( - f'{self.dp_rank=} {lp_start=} {lp_end-lp_start=} {hp_start=} {hp_end-hp_start=} {fragment_start=} {fragment_end-fragment_start=}' - ) +# print( +# f'{self.dp_rank=} {lp_start=} {lp_end-lp_start=} {hp_start=} {hp_end-hp_start=} {fragment_start=} {fragment_end-fragment_start=}' +# ) assert fragment_start < fragment_end, \ f'fragment start {fragment_start} should be < fragment_end {fragment_end}' From d911e672248c99c82993a331b79c635e8ea7cfc5 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Tue, 29 Mar 2022 00:29:09 +0500 Subject: [PATCH 26/29] Remove debug print --- deepspeed/runtime/bf16_optimizer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 50fcf296cf69..f7831a28c97f 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -284,9 +284,9 @@ def get_optim_state(self, key): fragment_start = max(lp_start, hp_start) fragment_end = min(lp_end, hp_end) -# print( -# f'{self.dp_rank=} {lp_start=} {lp_end-lp_start=} {hp_start=} {hp_end-hp_start=} {fragment_start=} {fragment_end-fragment_start=}' -# ) + # print( + # f'{self.dp_rank=} {lp_start=} {lp_end-lp_start=} {hp_start=} {hp_end-hp_start=} {fragment_start=} {fragment_end-fragment_start=}' + # ) assert fragment_start < fragment_end, \ f'fragment start {fragment_start} should be < fragment_end {fragment_end}' From 144f652768a980b7a4f1e1104ae7872189742c99 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase 
Date: Wed, 30 Mar 2022 05:12:08 +0500 Subject: [PATCH 27/29] Simplify zero_grad logic --- deepspeed/runtime/engine.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 71d72bd39530..22dac70f3eae 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1804,11 +1804,16 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}): ) # zero grad in basic optimizer could be unreliable and may not exhibit # the behaviour that we want - if (not self.zero_optimization() and not self.fp16_enabled() - and not self.amp_enabled()): - self.zero_grad() - elif not self.bfloat16_enabled(): + if self.bfloat16_enabled(): + # TODO: Temporary until bf16_optimizer and zero_optimizer are integrated + if self.zero_optimization(): + self.optimizer.zero_grad() + else: + pass + elif self.zero_optimization() or self.fp16_enabled() or self.amp_enabled(): self.optimizer.zero_grad() + else: + self.zero_grad() report_progress = self.global_rank == 0 if self.global_rank else True From bb70816fc9bdda9b77b69ff361f511dd1097b3ae Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Wed, 30 Mar 2022 06:37:21 +0500 Subject: [PATCH 28/29] fp32 accessors --- deepspeed/runtime/bf16_optimizer.py | 144 ++++++++++++---------------- 1 file changed, 63 insertions(+), 81 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index f7831a28c97f..0bd173aceb66 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -22,6 +22,55 @@ CLIP_GRAD, GROUPS_PADDING) +import types + +from dataclasses import dataclass + + +@dataclass +class fragment_address: + numel: int + start: int + + +@dataclass +class tensor_fragment: + lp_fragment: torch.Tensor + lp_fragment_address: fragment_address + hp_fragment: torch.Tensor + hp_fragment_address: fragment_address + optim_fragment: {} + + def update_hp(self): + self.hp_fragment.data.copy_(self.lp_fragment.data) + + def update_lp(self): + self.lp_fragment.data.copy_(self.hp_fragment.data) + + def get_optim_state_fragment(self, key): + if key in self.optim_fragment: + return self.optim_fragment[key] + else: + raise ValueError(f'{key} not found in optimizer state fragment') + + +def get_full_hp_param(self, optim_state_key=None): + reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten() + if self._hp_mapping is not None: + lp_frag_address = self._hp_mapping.lp_fragment_address + reduce_fragment = torch.narrow(reduce_buffer, + 0, + lp_frag_address.start, + lp_frag_address.numel) + if optim_state_key is None: + hp_fragment = self._hp_mapping.hp_fragment + else: + hp_fragment = self._hp_mapping.get_optim_state_fragment(optim_state_key) + + reduce_fragment.data.copy_(hp_fragment.data) + torch.distributed.all_reduce(reduce_buffer, group=self._dp_group) + return reduce_buffer.reshape_as(self) + class BF16_Optimizer: def __init__(self, @@ -173,14 +222,21 @@ def _link_all_hp_params(self): self._link_hp_params(self.bf16_groups[i], self.fp32_groups_flat_partition[i], partition_id * partition_size, - partition_size) - - def _init_lp_to_hp_mapping(self, lp_param_list, partition_start, partition_size): + partition_size, + self.real_dp_process_group[i]) + + def _init_lp_to_hp_mapping(self, + lp_param_list, + partition_start, + partition_size, + dp_group): current_offset = 0 param_and_offset_list = [] partition_end = partition_start + partition_size for lp_param in lp_param_list: lp_param._hp_mapping = None + 
lp_param._dp_group = dp_group + lp_param.get_full_hp_param = types.MethodType(get_full_hp_param, lp_param) # lp_param overlaps with partition if both are true # 1) current_offset < partition_end, # 2) current_offset + lp_param.numel() >= partition_start @@ -191,91 +247,17 @@ def _init_lp_to_hp_mapping(self, lp_param_list, partition_start, partition_size) return param_and_offset_list - -# def _link_hp_params_0(self, -# lp_param_list, -# flat_hp_partition, -# partition_start, -# partition_size): -# -# from dataclasses import dataclass -# @dataclass -# class fragment_address: -# numel: int -# start: int -# end: int -# -# current_offset = 0 -# partition_end = partition_start + partition_size -# for lp_param in lp_param_list: -# # lp_param does not overlap with partition if either is true -# # 1) current_offset >= partition_end, i.e belongs to later partition -# # 2) current_offset + lp_param.numel() < partition_start, i.e., belongs to earlier partition -# lp_param_end = current_offset + lp_param.numel() -# if (current_offset >= partition_end) or (lp_param_end < partition_start): -# lp_param._hp_param = None -# continue -# -# lp_fragment_start = max(current_offset, partition_start) -# lp_fragment_end = min(lp_param_end, partition_end) -# lp_fragment_elem = lp_fragment_end - lp_fragment_start -# -# hp_fragment_start = lp_fragment_start - partition_start -# hp_frag_address = fragment_address( -# numel=lp_fragment_elem, -# start=hp_fragment_start, -# end=lp_fragment_end -# ) -# lp_param._hp_frag_address = hp_frag_address -# lp_param._hp_fragment = flat_hp_partition.narrow(0, hp_narrow_offset, lp_fragment_elem) -# -# if current_offset >= partition_start: -# lp_param._lp_frag_start = 0 -# else: -# lp_param._hp_copy_offset = partition_start - current_offset -# -# lp_param._lp_frag_address = fragment_address( -# numel=lp_fragment_elem, -# start= -# ) -# current_offset += lp_param.numel() - def _link_hp_params(self, lp_param_list, flat_hp_partition, partition_start, - partition_size): + partition_size, + dp_group): local_lp_param_and_offset = self._init_lp_to_hp_mapping( lp_param_list, partition_start, - partition_size) - - from dataclasses import dataclass - - @dataclass - class fragment_address: - numel: int - start: int - - @dataclass - class tensor_fragment: - lp_fragment: torch.Tensor - lp_fragment_address: fragment_address - hp_fragment: torch.Tensor - hp_fragment_address: fragment_address - optim_fragment: {} - - def update_hp(self): - self.hp_fragment.data.copy_(self.lp_fragment.data) - - def update_lp(self): - self.lp_fragment.data.copy_(self.hp_fragment.data) - - def get_optim_state(self, key): - if key in self.optim_fragment: - return self.optim_fragment[key] - else: - raise ValueError(f'{key} not found in optimizer state fragment') + partition_size, + dp_group) hp_end = partition_start + partition_size for lp_param, lp_start in local_lp_param_and_offset: From fa4ff11dd83948812128495af0b1d39d1b14cfc1 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Fri, 1 Apr 2022 00:44:25 +0500 Subject: [PATCH 29/29] Fix update bug --- deepspeed/runtime/bf16_optimizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 0bd173aceb66..43c23f0b77f1 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -432,8 +432,8 @@ def clear_hp_grads(self): for flat_gradients in self.fp32_groups_gradients_flat: flat_gradients.zero_() - for group in self.fp32_groups_has_gradients: - 
group = [False] * len(group) + for i, group in enumerate(self.fp32_groups_gradients): + self.fp32_groups_has_gradients[i] = [False] * len(group) def clear_lp_grads(self): for group in self.bf16_groups:
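For reference, a hedged sketch of how the accessors added in the later patches could be exercised from user code once training is set up. The name engine stands for an assumed DeepSpeed engine whose optimizer is the BF16_Optimizer, and the 'exp_avg' key assumes an Adam-style base optimizer; every data-parallel rank must run the same loop, since each get_full_hp_param() call performs an all_reduce over the DP group.

import torch

bf16_opt = engine.optimizer  # assumed: engine wraps the BF16_Optimizer shown above
for group in bf16_opt.bf16_groups:
    for lp_param in group:
        # Full fp32 master weight, reassembled from the per-rank flat partitions.
        hp_weight = lp_param.get_full_hp_param()
        assert hp_weight.shape == lp_param.shape
        assert hp_weight.dtype == torch.float32
        # Matching optimizer-state tensor; 'exp_avg' assumes an Adam-style optimizer.
        exp_avg = lp_param.get_full_hp_param(optim_state_key='exp_avg')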