Make elastic checkpointing optional (#526)
tjruwase authored Nov 14, 2020
1 parent fba92f1 commit 9c577ee
Showing 4 changed files with 171 additions and 55 deletions.
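For context, the snippet below is a minimal sketch (not part of this commit) of how the new elastic_checkpoint flag would be enabled in a DeepSpeed config dict; the key name and its False default come from the constants added in deepspeed_zero_config.py below, while the surrounding keys and values are illustrative assumptions.

ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 1,
    "fp16": {"enabled": True},  # ZeRO requires fp16 per the assertion in _do_error_check
    "zero_optimization": {
        "stage": 1,
        # New in this commit: elastic checkpointing is now opt-in and defaults to False.
        "elastic_checkpoint": True,
    },
}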
54 changes: 33 additions & 21 deletions deepspeed/pt/deepspeed_config.py
@@ -185,7 +185,7 @@ def get_optimizer_gradient_clipping(param_dict):

def get_optimizer_legacy_fusion(param_dict):
if OPTIMIZER in param_dict.keys() and \
LEGACY_FUSION in param_dict[OPTIMIZER].keys():
LEGACY_FUSION in param_dict[OPTIMIZER].keys():
return param_dict[OPTIMIZER][LEGACY_FUSION]
else:
return LEGACY_FUSION_DEFAULT
@@ -260,6 +260,19 @@ def get_tensorboard_job_name(param_dict):
return TENSORBOARD_JOB_NAME_DEFAULT


def print_config(config):
for arg in sorted(vars(config)):
if arg != '_param_dict':
arg_object = getattr(config, arg)
dots = '.' * (29 - len(arg))
if hasattr(arg_object, '__dict__'):
logger.info(' {} {} Begin'.format(arg, dots))
print_config(arg_object)
logger.info(' {} {} End'.format(arg, dots))
else:
logger.info(' {} {} {}'.format(arg, dots, arg_object))


'''Write deepspeed config files by modifying basic templates.
Can be used for quickly changing parameters via command line arguments.'''

@@ -340,7 +353,7 @@ def _initialize_params(self, param_dict):

self.optimizer_name = get_optimizer_name(param_dict)
if self.optimizer_name is not None and \
self.optimizer_name.lower() in DEEPSPEED_OPTIMIZERS:
self.optimizer_name.lower() in DEEPSPEED_OPTIMIZERS:
self.optimizer_name = self.optimizer_name.lower()

self.optimizer_params = get_optimizer_params(param_dict)
@@ -374,39 +387,39 @@ def _batch_assertion(self):
f'Gradient accumulation steps: {grad_acc} has to be greater than 0'

assert train_batch == micro_batch * grad_acc * self.world_size, \
(f'Check batch related parameters. train_batch_size is not equal'
' to micro_batch_per_gpu * gradient_acc_step * world_size'
f'{train_batch} != {micro_batch} * {grad_acc} * {self.world_size}')
(f'Check batch related parameters. train_batch_size is not equal'
' to micro_batch_per_gpu * gradient_acc_step * world_size'
f'{train_batch} != {micro_batch} * {grad_acc} * {self.world_size}')

def _set_batch_related_parameters(self):

train_batch = self.train_batch_size
micro_batch = self.train_micro_batch_size_per_gpu
grad_acc = self.gradient_accumulation_steps

#all values are provided nothing needs to be set
# all values are provided nothing needs to be set
if train_batch is not None and \
micro_batch is not None and \
grad_acc is not None:
micro_batch is not None and \
grad_acc is not None:
return

#global_accumulation_steps needs to be set
# global_accumulation_steps needs to be set
elif train_batch is not None and \
micro_batch is not None:
micro_batch is not None:
grad_acc = train_batch // micro_batch
grad_acc //= self.world_size
self.gradient_accumulation_steps = grad_acc

#micro_batch_per_gpu needs to be set
# micro_batch_per_gpu needs to be set
elif train_batch is not None and \
grad_acc is not None:
grad_acc is not None:
micro_batch = train_batch // self.world_size
micro_batch //= grad_acc
self.train_micro_batch_size_per_gpu = micro_batch

#train_batch_size needs to be set
# train_batch_size needs to be set
elif micro_batch is not None and \
grad_acc is not None:
grad_acc is not None:
train_batch_size = micro_batch * grad_acc
train_batch_size *= self.world_size
self.train_batch_size = train_batch_size
@@ -421,7 +434,7 @@ def _set_batch_related_parameters(self):
self.train_batch_size = micro_batch * self.world_size
self.gradient_accumulation_steps = 1

#either none of the three parameters are provided or just gradient_accumulation_step is provided
# either none of the three parameters are provided or just gradient_accumulation_step is provided
else:
assert False, \
'Either train_batch_size or micro_batch_per_gpu needs to be provided'
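
As a standalone illustration (not part of the diff) of the relation asserted above, train_batch_size == micro_batch_per_gpu * gradient_accumulation_steps * world_size, the sketch below derives gradient_accumulation_steps the same way _set_batch_related_parameters does; the numbers are arbitrary.

world_size = 8
train_batch_size = 256
train_micro_batch_size_per_gpu = 4

# Derive gradient_accumulation_steps exactly as the branch above does:
gradient_accumulation_steps = train_batch_size // train_micro_batch_size_per_gpu
gradient_accumulation_steps //= world_size

assert train_batch_size == \
    train_micro_batch_size_per_gpu * gradient_accumulation_steps * world_size
print(gradient_accumulation_steps)  # 8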
@@ -441,10 +454,7 @@ def _do_sanity_check(self):

def print(self, name):
logger.info('{}:'.format(name))
for arg in sorted(vars(self)):
if arg != '_param_dict':
dots = '.' * (29 - len(arg))
logger.info(' {} {} {}'.format(arg, dots, getattr(self, arg)))
print_config(self)

logger.info(' json = {}'.format(
json.dumps(self._param_dict,
@@ -456,9 +466,11 @@ def print(self, name):
def _do_error_check(self):
if self.zero_enabled:
assert self.fp16_enabled, "DeepSpeedConfig: ZeRO is only supported if fp16 is enabled"
assert self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION, "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(MAX_STAGE_ZERO_OPTIMIZATION)
assert self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION, "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(
MAX_STAGE_ZERO_OPTIMIZATION)

assert self.train_micro_batch_size_per_gpu, "DeepSpeedConfig: {} is not defined".format(TRAIN_MICRO_BATCH_SIZE_PER_GPU)
assert self.train_micro_batch_size_per_gpu, "DeepSpeedConfig: {} is not defined".format(
TRAIN_MICRO_BATCH_SIZE_PER_GPU)

assert self.gradient_accumulation_steps, 'DeepSpeedConfig: {} is not defined'.format(
GRADIENT_ACCUMULATION_STEPS)
4 changes: 4 additions & 0 deletions deepspeed/pt/deepspeed_light.py
@@ -311,6 +311,9 @@ def zero_contiguous_gradients(self):
def zero_load_from_fp32_weights(self):
return self._config.zero_config.load_from_fp32_weights

def zero_elastic_checkpoint(self):
return self._config.zero_config.elastic_checkpoint

def allgather_size(self):
return self._config.allgather_size

@@ -596,6 +599,7 @@ def _configure_zero_optimizer(self, optimizer):
allgather_size=self.zero_allgather_bucket_size(),
max_elements_per_comm=self.zero_reduce_bucket_size(),
dp_process_group=self.data_parallel_group,
elastic_checkpoint=self.zero_elastic_checkpoint(),
mpu=self.mpu)
elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS:
assert self.gradient_accumulation_steps(
21 changes: 14 additions & 7 deletions deepspeed/pt/deepspeed_zero_config.py
@@ -62,22 +62,22 @@
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEPRECATED = 'allgather_size'
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS = 'load_from_fp32_weights'
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT = True
ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT = 'elastic_checkpoint'
ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT_DEFAULT = False

ZERO_OPTIMIZATION_DEFAULT = {
ZERO_OPTIMIZATION_STAGE:
ZERO_OPTIMIZATION_STAGE_DEFAULT,
ZERO_OPTIMIZATION_STAGE: ZERO_OPTIMIZATION_STAGE_DEFAULT,
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS:
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_SCATTER:
ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE:
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_SCATTER: ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE: ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS:
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT,
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE:
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS:
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT,
ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT: ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT_DEFAULT
}


@@ -93,6 +93,7 @@ def __init__(self, param_dict):
self.allgather_bucket_size = None
self.overlap_comm = None
self.load_from_fp32_weights = None
self.elastic_checkpoint = None

if ZERO_OPTIMIZATION in param_dict.keys():
zero_config_dict = param_dict[ZERO_OPTIMIZATION]
@@ -157,7 +158,13 @@ def _initialize(self, zero_config_dict):
zero_config_dict,
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE,
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT)

self.load_from_fp32_weights = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS,
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT)

self.elastic_checkpoint = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT,
ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT_DEFAULT)
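
A hedged usage sketch of the new option: the module path and class name (DeepSpeedZeroConfig) are assumed from the file shown above and are not confirmed by this diff, and 'zero_optimization' is the standard DeepSpeed config section name that the ZERO_OPTIMIZATION constant refers to.

from deepspeed.pt.deepspeed_zero_config import DeepSpeedZeroConfig

# Explicitly enable the new flag via the 'elastic_checkpoint' key.
zero_cfg = DeepSpeedZeroConfig({'zero_optimization': {'elastic_checkpoint': True}})
print(zero_cfg.elastic_checkpoint)  # True

# Omitting the key falls back to ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT_DEFAULT.
default_cfg = DeepSpeedZeroConfig({'zero_optimization': {}})
print(default_cfg.elastic_checkpoint)  # False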
