diff --git a/deepspeed/pt/deepspeed_config.py b/deepspeed/pt/deepspeed_config.py
index f618124d74c7..b63e96a6d925 100755
--- a/deepspeed/pt/deepspeed_config.py
+++ b/deepspeed/pt/deepspeed_config.py
@@ -185,7 +185,7 @@ def get_optimizer_gradient_clipping(param_dict):
 def get_optimizer_legacy_fusion(param_dict):
     if OPTIMIZER in param_dict.keys() and \
-            LEGACY_FUSION in param_dict[OPTIMIZER].keys():
+        LEGACY_FUSION in param_dict[OPTIMIZER].keys():
         return param_dict[OPTIMIZER][LEGACY_FUSION]
     else:
         return LEGACY_FUSION_DEFAULT
@@ -260,6 +260,19 @@ def get_tensorboard_job_name(param_dict):
         return TENSORBOARD_JOB_NAME_DEFAULT
 
 
+def print_config(config):
+    for arg in sorted(vars(config)):
+        if arg != '_param_dict':
+            arg_object = getattr(config, arg)
+            dots = '.' * (29 - len(arg))
+            if hasattr(arg_object, '__dict__'):
+                logger.info(' {} {} Begin'.format(arg, dots))
+                print_config(arg_object)
+                logger.info(' {} {} End'.format(arg, dots))
+            else:
+                logger.info(' {} {} {}'.format(arg, dots, arg_object))
+
+
 '''Write deepspeed config files by modifying basic templates. Can be used for quicly changing parameters via command line parameters.'''
@@ -340,7 +353,7 @@ def _initialize_params(self, param_dict):
         self.optimizer_name = get_optimizer_name(param_dict)
         if self.optimizer_name is not None and \
-                self.optimizer_name.lower() in DEEPSPEED_OPTIMIZERS:
+            self.optimizer_name.lower() in DEEPSPEED_OPTIMIZERS:
             self.optimizer_name = self.optimizer_name.lower()
 
         self.optimizer_params = get_optimizer_params(param_dict)
@@ -374,9 +387,9 @@
             f'Gradient accumulation steps: {grad_acc} has to be greater than 0'
 
         assert train_batch == micro_batch * grad_acc * self.world_size, \
-            (f'Check batch related parameters. train_batch_size is not equal' ' to micro_batch_per_gpu * gradient_acc_step * world_size' f'{train_batch} != {micro_batch} * {grad_acc} * {self.world_size}')
+            (f'Check batch related parameters. train_batch_size is not equal' ' to micro_batch_per_gpu * gradient_acc_step * world_size' f'{train_batch} != {micro_batch} * {grad_acc} * {self.world_size}')
 
     def _set_batch_related_parameters(self):
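The hunk above only reflows the continuation lines of the batch assertion; the invariant it checks is unchanged. As a rough illustration with made-up numbers (not taken from the patch):

    # train_batch_size must equal micro-batch size * gradient accumulation steps * data-parallel world size
    world_size = 4
    train_micro_batch_size_per_gpu = 8
    gradient_accumulation_steps = 2
    train_batch_size = train_micro_batch_size_per_gpu * gradient_accumulation_steps * world_size
    assert train_batch_size == 64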
@@ -384,29 +397,29 @@
         micro_batch = self.train_micro_batch_size_per_gpu
         grad_acc = self.gradient_accumulation_steps
 
-        #all values are provided nothing needs to be set
+        # all values are provided nothing needs to be set
         if train_batch is not None and \
-                micro_batch is not None and \
-                grad_acc is not None:
+            micro_batch is not None and \
+            grad_acc is not None:
             return
 
-        #global_accumulation_steps needs to be set
+        # global_accumulation_steps needs to be set
         elif train_batch is not None and \
-                micro_batch is not None:
+            micro_batch is not None:
             grad_acc = train_batch // micro_batch
             grad_acc //= self.world_size
             self.gradient_accumulation_steps = grad_acc
 
-        #micro_batch_per_gpu needs to be set
+        # micro_batch_per_gpu needs to be set
         elif train_batch is not None and \
-                grad_acc is not None:
+            grad_acc is not None:
             micro_batch = train_batch // self.world_size
             micro_batch //= grad_acc
             self.train_micro_batch_size_per_gpu = micro_batch
 
-        #train_batch_size needs to be set
+        # train_batch_size needs to be set
         elif micro_batch is not None and \
-                grad_acc is not None:
+            grad_acc is not None:
             train_batch_size = micro_batch * grad_acc
             train_batch_size *= self.world_size
             self.train_batch_size = train_batch_size
@@ -421,7 +434,7 @@ def _set_batch_related_parameters(self):
             self.train_batch_size = micro_batch * self.world_size
             self.gradient_accumulation_steps = 1
 
-        #either none of the three parameters are provided or just gradient_accumulation_step is provided
+        # either none of the three parameters are provided or just gradient_accumulation_step is provided
         else:
             assert False, \
                 'Either train_batch_size or micro_batch_per_gpu needs to be provided'
@@ -441,10 +454,7 @@ def _do_sanity_check(self):
 
     def print(self, name):
         logger.info('{}:'.format(name))
-        for arg in sorted(vars(self)):
-            if arg != '_param_dict':
-                dots = '.' * (29 - len(arg))
-                logger.info(' {} {} {}'.format(arg, dots, getattr(self, arg)))
+        print_config(self)
 
         logger.info(' json = {}'.format(
             json.dumps(self._param_dict,
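With this hunk, DeepSpeedConfig.print() delegates to the recursive print_config helper added earlier in the same file, so nested sub-config objects are printed field by field between Begin/End markers instead of as raw objects. A usage sketch (the config path and constructor arguments are placeholders, not part of the patch):

    from deepspeed.pt.deepspeed_config import DeepSpeedConfig

    cfg = DeepSpeedConfig('ds_config.json')   # placeholder path to a DeepSpeed config file
    cfg.print('DeepSpeedConfig')              # nested objects now appear as '... Begin' / '... End' blocks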
@@ -456,9 +466,11 @@ def print(self, name):
     def _do_error_check(self):
         if self.zero_enabled:
             assert self.fp16_enabled, "DeepSpeedConfig: ZeRO is only supported if fp16 is enabled"
-            assert self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION, "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(MAX_STAGE_ZERO_OPTIMIZATION)
+            assert self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION, "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(
+                MAX_STAGE_ZERO_OPTIMIZATION)
 
-        assert self.train_micro_batch_size_per_gpu, "DeepSpeedConfig: {} is not defined".format(TRAIN_MICRO_BATCH_SIZE_PER_GPU)
+        assert self.train_micro_batch_size_per_gpu, "DeepSpeedConfig: {} is not defined".format(
+            TRAIN_MICRO_BATCH_SIZE_PER_GPU)
 
         assert self.gradient_accumulation_steps, 'DeepSpeedConfig: {} is not defined'.format(
             GRADIENT_ACCUMULATION_STEPS)
diff --git a/deepspeed/pt/deepspeed_light.py b/deepspeed/pt/deepspeed_light.py
index c6e7623b1792..7396d130fa83 100755
--- a/deepspeed/pt/deepspeed_light.py
+++ b/deepspeed/pt/deepspeed_light.py
@@ -311,6 +311,9 @@ def zero_contiguous_gradients(self):
     def zero_load_from_fp32_weights(self):
         return self._config.zero_config.load_from_fp32_weights
 
+    def zero_elastic_checkpoint(self):
+        return self._config.zero_config.elastic_checkpoint
+
     def allgather_size(self):
         return self._config.allgather_size
 
@@ -596,6 +599,7 @@ def _configure_zero_optimizer(self, optimizer):
                 allgather_size=self.zero_allgather_bucket_size(),
                 max_elements_per_comm=self.zero_reduce_bucket_size(),
                 dp_process_group=self.data_parallel_group,
+                elastic_checkpoint=self.zero_elastic_checkpoint(),
                 mpu=self.mpu)
         elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS:
             assert self.gradient_accumulation_steps(
diff --git a/deepspeed/pt/deepspeed_zero_config.py b/deepspeed/pt/deepspeed_zero_config.py
index 4f654d3b8c30..c377b6550f28 100755
--- a/deepspeed/pt/deepspeed_zero_config.py
+++ b/deepspeed/pt/deepspeed_zero_config.py
@@ -62,22 +62,22 @@ ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEPRECATED = 'allgather_size'
 ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS = 'load_from_fp32_weights'
 ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT = True
+ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT = 'elastic_checkpoint'
+ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT_DEFAULT = False
 
 ZERO_OPTIMIZATION_DEFAULT = {
-    ZERO_OPTIMIZATION_STAGE:
-    ZERO_OPTIMIZATION_STAGE_DEFAULT,
+    ZERO_OPTIMIZATION_STAGE: ZERO_OPTIMIZATION_STAGE_DEFAULT,
     ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS:
     ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT,
-    ZERO_OPTIMIZATION_REDUCE_SCATTER:
-    ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT,
-    ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE:
-    ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT,
+    ZERO_OPTIMIZATION_REDUCE_SCATTER: ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT,
+    ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE: ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT,
     ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS:
     ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT,
     ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE:
     ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT,
     ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS:
-    ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT
+    ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT,
+    ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT: ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT_DEFAULT
 }
 
 
@@ -93,6 +93,7 @@ def __init__(self, param_dict):
         self.allgather_bucket_size = None
         self.overlap_comm = None
         self.load_from_fp32_weights = None
+        self.elastic_checkpoint = None
 
         if ZERO_OPTIMIZATION in param_dict.keys():
             zero_config_dict = param_dict[ZERO_OPTIMIZATION]
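The new elastic_checkpoint key is read from the same zero_optimization section of the DeepSpeed config as the other ZeRO options and defaults to False. A config fragment of the shape these constants describe (the non-ZeRO values are illustrative, not defaults introduced by this patch):

    ds_config = {
        "train_micro_batch_size_per_gpu": 8,
        "fp16": {"enabled": True},
        "zero_optimization": {
            "stage": 1,
            "elastic_checkpoint": True
        }
    }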
@@ -157,7 +158,13 @@ def _initialize(self, zero_config_dict):
             zero_config_dict,
             ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE,
             ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT)
+
         self.load_from_fp32_weights = get_scalar_param(
             zero_config_dict,
             ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS,
             ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT)
+
+        self.elastic_checkpoint = get_scalar_param(
+            zero_config_dict,
+            ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT,
+            ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT_DEFAULT)
diff --git a/deepspeed/pt/zero_optimizer_stage1.py b/deepspeed/pt/zero_optimizer_stage1.py
index 9a2ab8763e0f..fe8a19ebbf97 100755
--- a/deepspeed/pt/zero_optimizer_stage1.py
+++ b/deepspeed/pt/zero_optimizer_stage1.py
@@ -19,13 +19,25 @@ def get_alignment_padding(flattened_lean_size, sub_partition_id, sub_partition_s
     return min(sub_partition_size, sub_partition_high_limit - flattened_lean_size)
 
 
-def get_group_alignment_padding(tensor_list, sub_partition_size, sub_partition_count):
+def get_group_alignment_padding(tensor_list,
+                                sub_partition_size,
+                                sub_partition_count,
+                                group_index,
+                                dp_process_group):
     group_paddings = []
     flattened_size = sum([tensor.numel() for tensor in tensor_list])
     for i in range(sub_partition_count):
         padding = get_alignment_padding(flattened_size, i, sub_partition_size)
         group_paddings.append(padding)
 
+    if not dist.is_initialized() or dist.get_rank(group=dp_process_group) == 0:
+        logger.info(f"****Group Padding information {group_index}*****")
+        logger.info(f"tensor_size = {flattened_size}")
+        logger.info(f"sub_partition_size = {sub_partition_size}")
+        logger.info(f"sub_partition_count = {sub_partition_count}")
+        for i, padding in enumerate(group_paddings):
+            logger.info(f"padding[{i}] = {padding}")
+
     return group_paddings
@@ -124,7 +136,8 @@ def __init__(self,
                  all_gather_partitions=True,
                  allgather_size=500000000,
                  clip_grad=0.0,
-                 max_elements_per_comm=5e8):
+                 max_elements_per_comm=5e8,
+                 elastic_checkpoint=False):
 
         if dp_process_group is not None and partition_size is not None:
             raise ValueError("Cannot specify both dp_process_group "
@@ -147,6 +160,9 @@ def __init__(self,
         self.max_elements_per_comm = max_elements_per_comm
         logger.info("max_elements_per_comm={}".format(max_elements_per_comm))
 
+        self.elastic_checkpoint = elastic_checkpoint
+        logger.info(f'ZeRO Elastic Checkpointing: {elastic_checkpoint}')
+
         # param flattened by groups
         self.fp16_groups = []
         self.fp16_groups_flat = []
@@ -243,7 +259,9 @@ def __init__(self,
             sub_partition_paddings = get_group_alignment_padding(
                 tensor_list=self.fp16_groups[i],
                 sub_partition_size=sub_partition_size,
-                sub_partition_count=num_comm_intervals * self.partition_count)
+                sub_partition_count=num_comm_intervals * self.partition_count,
+                group_index=i,
+                dp_process_group=dp_process_group)
             self.group_paddings.append(sub_partition_paddings)
 
             # modify optimizer of have flat master weight
@@ -315,7 +333,8 @@ def get_data_parallel_sub_partitions(tensor,
 
     # Ensure partition alignment was done correctly
     num_sub_partitions = int(total_num_elements // sub_partition_size)
-    assert total_num_elements % sub_partition_size == 0, "{} % {} != 0".format(total_num_elements, sub_partition_size)
+    assert total_num_elements % sub_partition_size == 0, "{} % {} != 0".format(
+        total_num_elements, sub_partition_size)
 
     # Ensure comm interval alignment was done correctly.
     num_comm_intervals = int(num_sub_partitions // world_size)
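get_group_alignment_padding pads each parameter group so that its flattened size fills a whole number of equally sized sub-partitions; the logging added above only reports those paddings per group. A rough worked example, assuming get_alignment_padding pads a sub-partition up to its upper boundary (numbers invented):

    # a flattened group of 10 elements split into 3 sub-partitions of 4 elements each
    flattened_size = 10
    sub_partition_size = 4
    sub_partition_count = 3
    paddings = [0, 0, 2]          # only the last sub-partition needs padding
    assert flattened_size + sum(paddings) == sub_partition_count * sub_partition_size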
@@ -379,14 +398,14 @@ def get_all_sub_partition_info(tensor_list,
     prev_comm_idx = 0
     for iii, tensor in enumerate(tensor_list):
         tensor_size = tensor.numel()
-        #if local_rank == 0:
+        # if local_rank == 0:
         # # logger.info("rank={}, current_index={}, tensor_size={}, tensor-idx={}".format(rank,
         # current_index, tensor_size, iii))
         results_list = _range_check(current_index, all_element_intervals[rank], tensor_size)
         for contained, offset, comm_idx in results_list:
-            #if local_rank == 0:
+            # if local_rank == 0:
             # logger.info("rank={}, contained={}, offset={}, comm_idx={}".format(rank, contained,
             # offset, comm_idx))
             if contained:
@@ -441,7 +460,7 @@ def get_flat_sub_partitions(comm_tensor_list,
             num_elements = tensor.numel()
             tensor_offset = 0
 
-            #we need to offset to get to the right element
+            # we need to offset to get to the right element
             if i == 0 and param_offsets[i] > 0:
                 tensor_offset = param_offsets[i]
                 num_elements = num_elements - tensor_offset
@@ -451,8 +470,8 @@
             if num_elements > (sub_partition_size - current_size):
                 num_elements = sub_partition_size - current_size
 
-            #we need a narrow view of the tensor based on the tensor offset and number of elements that
-            #we need from this tensor
+            # we need a narrow view of the tensor based on the tensor offset and number of elements that
+            # we need from this tensor
             if tensor_offset > 0 or num_elements < tensor.numel():
                 flat_tensor_list.append(tensor.contiguous().view(-1).narrow(
                     0,
@@ -462,12 +481,12 @@
                 flat_tensor_list.append(tensor.to(dtype))
             my_params.append(param)
 
-            #remember offset into partition and #elems for this tensor
+            # remember offset into partition and #elems for this tensor
             my_offsets.append((current_size, num_elements))
 
             current_size = current_size + num_elements
 
-        #this means its the last partition and does not align with the dp boundary. We need to pad before flattening
+        # this means its the last partition and does not align with the dp boundary. We need to pad before flattening
         if current_size < sub_partition_size:
             my_offsets.append((None, None))
             my_params.append(None)
@@ -482,7 +501,7 @@
                     torch.zeros(int(sub_partition_size - current_size), dtype=dtype, device=tensor_list[0].device))
-        partition_params.append(my_params)  #flat_tensor_list)
+        partition_params.append(my_params)  # flat_tensor_list)
         final_param_offsets.append(my_offsets)
         assert len(flat_tensor_list) == len(my_offsets), "{} {}".format(len(flat_tensor_list), len(my_offsets))
         flat_sub_partitions.append(_flatten_dense_tensors(flat_tensor_list))
@@ -500,7 +519,8 @@
     if return_partition_params:
         assert len(flat_sub_partitions) == len(partition_params)
-        assert len(partition_params) == len(final_param_offsets), "{} {}".format(len(partition_params), len(final_param_offsets))
+        assert len(partition_params) == len(final_param_offsets), "{} {}".format(
+            len(partition_params), len(final_param_offsets))
         return flat_sub_partitions, partition_params, final_param_offsets
 
     return flat_sub_partitions
@@ -601,8 +621,8 @@ def step(self, closure=None):
         for i, group in enumerate(self.fp16_groups):
             norm_groups.append(get_grad_norm(group, mpu=self.mpu))
 
-            #RS: update free grads w.r.t. sub partitions
-            #free gradients for all the parameters that are not updated by this process
+            # RS: update free grads w.r.t. sub partitions
+            # free gradients for all the parameters that are not updated by this process
             self.free_grad_in_param_list(self.params_not_local[i])
 
             # create flat gradient partitions for parameters updated by this process
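As the reworked comments in get_flat_sub_partitions describe, a tensor that straddles a sub-partition boundary contributes only a narrowed view of its flattened data rather than a full copy. A minimal standalone illustration of that narrowing (not code from the patch):

    import torch

    tensor = torch.arange(10.)
    tensor_offset, num_elements = 3, 4
    # analogous to tensor.contiguous().view(-1).narrow(0, int(tensor_offset), int(num_elements))
    sub_view = tensor.contiguous().view(-1).narrow(0, tensor_offset, num_elements)
    assert sub_view.tolist() == [3.0, 4.0, 5.0, 6.0]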
@@ -619,20 +639,20 @@ def step(self, closure=None):
             for idx, sub_partition_param in enumerate(self.local_sub_partitions_of_fp32_groups[i]):
                 sub_partition_param.grad = local_grad_sub_partitions[idx]
 
-            #RS: update free grads for sub-partitions
-            #release all the gradient since we have already created a necessary copy in dp_grad_partition
+            # RS: update free grads for sub-partitions
+            # release all the gradient since we have already created a necessary copy in dp_grad_partition
             self.free_grad_in_param_list(
                 self.params_in_rank_sub_partitions[i][partition_id])
 
             local_sub_partitions_grad_groups.append(local_grad_sub_partitions)
 
-        #RS: update unscale/clip with sub partitions
+        # RS: update unscale/clip with sub partitions
         self.unscale_and_clip_grads(local_sub_partitions_grad_groups, norm_groups)
 
         self.optimizer.step()
 
-        #RS: clear our sub partition grads
-        #get rid of the fp32 gradients. Not needed anymore
+        # RS: clear our sub partition grads
+        # get rid of the fp32 gradients. Not needed anymore
         for group in self.local_sub_partitions_of_fp32_groups:
             for idx, sub_partition_param in enumerate(group):
                 sub_partition_param.grad = None
@@ -645,8 +665,8 @@ def step(self, closure=None):
                 local_sub_partition_param_fp16.data.copy_(
                     local_sub_partition_param_fp32.data)
 
-        #RS: all_gather/broadcast sub-partitions in separate comm calls
-        #gather the updated weights from everyone
+        # RS: all_gather/broadcast sub-partitions in separate comm calls
+        # gather the updated weights from everyone
         for fp16_all_sub_partitions in self.parallel_comm_sub_partitioned_fp16_groups:
             for comm_id, sub_partitions in enumerate(fp16_all_sub_partitions):
                 dist.all_gather(sub_partitions,
@@ -767,7 +787,10 @@ def _get_base_optimizer_state(self):
 
         return optimizer_groups_state
 
-    def state_dict(self):
+    def _rigid_state_dict(self):
+        """
+        Returns a dict that can be loaded for continued training with same DP degree
+        """
         """
        Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
        This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
        of the contained Pytorch optimizer.
        Example::
            checkpoint = {}
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            torch.save(checkpoint, "saved.pth")
        """
        state_dict = {}
        state_dict['loss_scaler'] = self.loss_scaler
        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
        state_dict['overflow'] = self.overflow
@@ -782,6 +805,19 @@ def state_dict(self):
         state_dict['loss_scaler'] = self.loss_scaler
         state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
         state_dict['overflow'] = self.overflow
+        state_dict['base_optimizer_state'] = self.optimizer.state_dict()
+        state_dict[
+            'local_sub_partitions_of_fp32_groups'] = self.local_sub_partitions_of_fp32_groups
+        return state_dict
+
+    def _elastic_state_dict(self):
+        """
+        Returns a dict that can be loaded for elastic training with different DP degree
+        """
+        state_dict = {}
+        state_dict['loss_scaler'] = self.loss_scaler
+        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
+        state_dict['overflow'] = self.overflow
         state_dict['base_optimizer_state'] = self._get_base_optimizer_state()
 
         state_dict['zero_stage'] = ZERO_OPTIMIZATION_OPTIMIZER_STATES
@@ -795,6 +831,22 @@ def state_dict(self):
 
         return state_dict
 
+    def state_dict(self):
+        """
+        Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
+        This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict of the contained Pytorch optimizer.
+        Example::
+            checkpoint = {}
+            checkpoint['model'] = model.state_dict()
+            checkpoint['optimizer'] = optimizer.state_dict()
+            torch.save(checkpoint, "saved.pth")
+        """
+        if self.elastic_checkpoint:
+            return self._elastic_state_dict()
+
+        return self._rigid_state_dict()
+
     def _retrieve_group_sub_partition_weights(self, all_partition_fp32_weights):
         partition_id = dist.get_rank(group=self.dp_process_group)
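state_dict() now picks between the two formats defined above. Sketched key layouts, taken from the code in this hunk (values elided; the rigid form stores this rank's own fp32 sub-partitions directly, so it can only be reloaded at the same data-parallel degree):

    rigid_sd = {
        'loss_scaler': ...,
        'dynamic_loss_scale': ...,
        'overflow': ...,
        'base_optimizer_state': ...,                  # self.optimizer.state_dict()
        'local_sub_partitions_of_fp32_groups': ...,   # this rank's fp32 master sub-partitions
    }

    elastic_sd = {
        'loss_scaler': ...,
        'dynamic_loss_scale': ...,
        'overflow': ...,
        'base_optimizer_state': ...,                  # partition-aware view from _get_base_optimizer_state()
        'zero_stage': ...,                            # plus further partitioning metadata not visible in this hunk
    }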
@@ -912,10 +964,23 @@ def _restore_from_fp16_weights(self):
     def refresh_fp32_params(self):
         self._restore_from_fp16_weights()
 
-    def load_state_dict(self,
-                        state_dict_list,
-                        load_optimizer_states=True,
-                        load_from_fp32_weights=False):
+    def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True):
+
+        # I think it should actually be ok to reload the optimizer before the model.
+        self.loss_scaler = state_dict['loss_scaler']
+        self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
+        self.overflow = state_dict['overflow']
+        if load_optimizer_states:
+            self.optimizer.load_state_dict(state_dict['base_optimizer_state'])
+
+        for curr_group, saved_group in zip(self.local_sub_partitions_of_fp32_groups, state_dict['local_sub_partitions_of_fp32_groups']):
+            for curr_param, saved_param in zip(curr_group, saved_group):
+                curr_param.data.copy_(saved_param.data)
+
+    def _elastic_load_state_dict(self,
+                                 state_dict_list,
+                                 load_optimizer_states=True,
+                                 load_from_fp32_weights=False):
         """
         Loads a state_dict created by an earlier call to state_dict().
         If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
@@ -943,3 +1008,31 @@ def load_state_dict(self,
             self._restore_from_fp32_weights(state_dict_list)
         else:
             self._restore_from_fp16_weights()
+
+    def load_state_dict(self,
+                        state_dict_list,
+                        load_optimizer_states=True,
+                        load_from_fp32_weights=False):
+        """
+        Loads a state_dict created by an earlier call to state_dict().
+        If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
+        whose parameters in turn came from ``model``, it is expected that the user
+        will call ``model.load_state_dict()`` before
+        ``fp16_optimizer_instance.load_state_dict()`` is called.
+        Example::
+            model = torch.nn.Linear(D_in, D_out).cuda().half()
+            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
+            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
+            ...
+            checkpoint = torch.load("saved.pth")
+            model.load_state_dict(checkpoint['model'])
+            optimizer.load_state_dict(checkpoint['optimizer'])
+        """
+        if self.elastic_checkpoint:
+            self._elastic_load_state_dict(state_dict_list,
+                                          load_optimizer_states,
+                                          load_from_fp32_weights)
+        else:
+            self._rigid_load_state_dict(
+                state_dict_list[dist.get_rank(group=self.dp_process_group)],
+                load_optimizer_states)
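Net effect of the patch: with "elastic_checkpoint": true under zero_optimization, the stage-1 ZeRO optimizer saves and restores its state in a form that does not depend on the data-parallel degree, so training can resume on a different number of ranks; with the default (False), load_state_dict simply picks this rank's entry from the saved list and the world size must match. A rough engine-level sketch (the engine variable name and checkpoint paths are placeholders; save_checkpoint/load_checkpoint are existing DeepSpeed engine methods, not part of this diff):

    # first run, e.g. 4 data-parallel ranks, ds_config.json contains "elastic_checkpoint": true
    model_engine.save_checkpoint('checkpoints', tag='step1000')

    # later run with a different number of ranks: the elastic path rebuilds each rank's
    # local fp32 sub-partitions from the partition-independent saved state
    model_engine.load_checkpoint('checkpoints', tag='step1000')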