Add 'latest' checkpoint save/load support (microsoft#569)
jeffra authored Dec 2, 2020
1 parent 7a75f8b commit 845921b
Showing 2 changed files with 85 additions and 9 deletions.
29 changes: 24 additions & 5 deletions deepspeed/runtime/engine.py
@@ -1300,15 +1300,15 @@ def _get_ckpt_name(self, checkpoints_path, tag):

     def load_checkpoint(self,
                         load_dir,
-                        tag,
+                        tag=None,
                         load_module_strict=True,
                         load_optimizer_states=True,
                         load_lr_scheduler_states=True):
-        r"""Load training checkpoint
+        """Load training checkpoint
         Arguments:
             load_dir: Required. Directory to load the checkpoint from
-            tag: Required. Checkpoint tag used as a unique identifier for the checkpoint. Ex. Global Step.
+            tag: Checkpoint tag used as a unique identifier for checkpoint, if not provided will attempt to load tag in 'latest' file
             load_module_strict: Optional. Boolean to strictly enforce that the keys in state_dict of module and checkpoint match.
             load_optimizer_states: Optional. Boolean to load the training optimizer states from Checkpoint. Ex. ADAM's momentum and variance
             load_lr_scheduler_states: Optional. Boolean to add the learning rate scheduler states from Checkpoint.
@@ -1317,6 +1317,13 @@ def load_checkpoint(self,
             client_state: State dictionary used for loading required training states in the client code.
         """

+        if tag is None:
+            latest_path = os.path.join(load_dir, 'latest')
+            assert os.path.isfile(latest_path), f"Unable to find latest file at {latest_path}, if trying to load latest " \
+                "checkpoint please ensure this file exists or pass an explicit checkpoint tag when loading a checkpoint."
+            with open(latest_path, 'r') as fd:
+                tag = fd.read().strip()
+
         load_path, client_states = self._load_checkpoint(load_dir,
                                                          tag,
                                                          load_module_strict=load_module_strict,
@@ -1454,18 +1461,25 @@ def _get_all_zero_checkpoints(self, load_dir, tag):
         )
         return zero_optimizer_sd

-    def save_checkpoint(self, save_dir, tag, client_state={}):
+    def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True):
         r"""Save training checkpoint
         Arguments:
             save_dir: Required. Directory for saving the checkpoint
-            tag: Required. Checkpoint tag used as a unique identifier for the checkpoint. Ex. Global Step.
+            tag: Optional. Checkpoint tag used as a unique identifier for the checkpoint, global step is used if not provided.
             client_state: Optional. State dictionary used for saving required training states in the client code.
+            save_latest: Optional. Save a file 'latest' pointing to the latest saved checkpoint.
         """

         # This is to make sure the checkpoint names are created without collision
         # There seems to be issue creating them in parallel

+        # Ensure save_dir directory exists
+        os.makedirs(save_dir, exist_ok=True)
+
+        if tag is None:
+            tag = f"global_step{self.global_steps}"
+
         if self.save_non_zero_checkpoint:
             self._create_checkpoint_file(save_dir, tag, False)
             self._save_checkpoint(save_dir, tag, client_state=client_state)
@@ -1474,6 +1488,11 @@ def save_checkpoint(self, save_dir, tag, client_state={}):
             self._create_zero_checkpoint_files(save_dir, tag)
             self._save_zero_checkpoint(save_dir, tag)

+        # Save latest checkpoint tag
+        if save_latest:
+            with open(os.path.join(save_dir, 'latest'), 'w') as fd:
+                fd.write(tag)
+
         return True

     def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint):
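With these changes, both methods can be called without a tag. A minimal usage sketch (not part of the diff; `model_engine` and `ckpt_dir` are illustrative names, with `model_engine` assumed to be the engine returned by `deepspeed.initialize`):

    # Save without an explicit tag: the engine tags the checkpoint with the
    # current global step and writes a 'latest' file next to it.
    model_engine.save_checkpoint(ckpt_dir)

    # Load without a tag: the engine reads the 'latest' file to decide which
    # checkpoint to restore.
    load_path, client_state = model_engine.load_checkpoint(ckpt_dir)

    # An explicit tag still works; writing the 'latest' file can be disabled.
    model_engine.save_checkpoint(ckpt_dir, tag="my_tag", save_latest=False)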
65 changes: 61 additions & 4 deletions tests/unit/test_checkpointing.py
@@ -128,7 +128,8 @@ def checkpoint_correctness_verification(args,
                                         fp16=True,
                                         train_batch=False,
                                         base_optimizers=[None,
-                                                         None]):
+                                                         None],
+                                        empty_tag=False):
     dtype = torch.half if fp16 else torch.float32
     ds_model = create_deepspeed_model(args=args,
                                       model=models[0],
@@ -153,16 +154,16 @@ def checkpoint_correctness_verification(args,
     trained_model = ds_model

     save_folder = os.path.join(tmpdir, 'saved_checkpoint')
-    save_tag = '1'
+    save_tag = None if empty_tag else '1'

-    trained_model.save_checkpoint(save_folder, save_tag)
+    trained_model.save_checkpoint(save_folder, tag=save_tag)

     loaded_model = create_deepspeed_model(args=args,
                                           model=models[1],
                                           base_optimizer=base_optimizers[1])

     loaded_model.load_checkpoint(save_folder,
-                                 save_tag,
+                                 tag=save_tag,
                                  load_optimizer_states=load_optimizer_states,
                                  load_lr_scheduler_states=load_lr_scheduler_states)
@@ -704,3 +705,59 @@ def _test_checkpoint_zero_hybrid_optimizer_state(args,
                                         models=models,
                                         optimizers=optimizers,
                                         hidden_dim=hidden_dim)
+
+
+def test_checkpoint_latest(tmpdir):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        }
+    }
+    hidden_dim = 10
+    args = args_from_dict(tmpdir, config_dict)
+    models = [SimpleModel(hidden_dim=hidden_dim) for _ in range(2)]
+
+    @distributed_test(world_size=[1])
+    def _helper(args, models):
+        checkpoint_correctness_verification(args,
+                                            models=models,
+                                            hidden_dim=hidden_dim,
+                                            tmpdir=tmpdir,
+                                            load_optimizer_states=True,
+                                            load_lr_scheduler_states=False,
+                                            fp16=False,
+                                            empty_tag=True)
+
+    _helper(args, models)
+
+
+def test_checkpoint_missing_latest(tmpdir):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        }
+    }
+    hidden_dim = 10
+    args = args_from_dict(tmpdir, config_dict)
+
+    model = SimpleModel(hidden_dim, rank=args.local_rank)
+
+    @distributed_test(world_size=[1])
+    def _helper(args, model, hidden_dim):
+        model, _, _, _ = deepspeed.initialize(args=args,
+                                              model=model,
+                                              model_parameters=model.parameters())
+        with pytest.raises(AssertionError):
+            model.load_checkpoint(tmpdir)
+
+    _helper(args=args, model=model, hidden_dim=hidden_dim)
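As test_checkpoint_missing_latest shows, calling load_checkpoint with no tag raises an AssertionError when the 'latest' file is absent. A sketch of a resume pattern a caller might use to avoid that on a first run (illustrative only, not part of this commit; `ckpt_dir` and `model_engine` as in the sketch above):

    import os

    latest_file = os.path.join(ckpt_dir, 'latest')
    if os.path.isfile(latest_file):
        # A previous save wrote 'latest', so resume from it.
        load_path, client_state = model_engine.load_checkpoint(ckpt_dir)
    else:
        # First run: nothing to resume, start training from scratch.
        load_path, client_state = None, {}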
