diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 24beaa775ffe..42e356b3635c 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -1300,15 +1300,15 @@ def _get_ckpt_name(self, checkpoints_path, tag):
 
     def load_checkpoint(self,
                         load_dir,
-                        tag,
+                        tag=None,
                         load_module_strict=True,
                         load_optimizer_states=True,
                         load_lr_scheduler_states=True):
-        r"""Load training checkpoint
+        """Load training checkpoint
 
         Arguments:
             load_dir: Required. Directory to load the checkpoint from
-            tag: Required. Checkpoint tag used as a unique identifier for the checkpoint. Ex. Global Step.
+            tag: Checkpoint tag used as a unique identifier for checkpoint, if not provided will attempt to load tag in 'latest' file
             load_module_strict: Optional. Boolean to strictly enforce that the keys in state_dict of module and checkpoint match.
             load_optimizer_states: Optional. Boolean to load the training optimizer states from Checkpoint. Ex. ADAM's momentum and variance
             load_lr_scheduler_states: Optional. Boolean to add the learning rate scheduler states from Checkpoint.
@@ -1317,6 +1317,13 @@ def load_checkpoint(self,
             client_state: State dictionary used for loading required training states in the client code.
         """
 
+        if tag is None:
+            latest_path = os.path.join(load_dir, 'latest')
+            assert os.path.isfile(latest_path), f"Unable to find latest file at {latest_path}, if trying to load latest " \
+                "checkpoint please ensure this file exists or pass an explicit checkpoint tag when loading a checkpoint."
+            with open(latest_path, 'r') as fd:
+                tag = fd.read().strip()
+
         load_path, client_states = self._load_checkpoint(load_dir,
                                                          tag,
                                                          load_module_strict=load_module_strict,
@@ -1454,18 +1461,25 @@ def _get_all_zero_checkpoints(self, load_dir, tag):
         )
         return zero_optimizer_sd
 
-    def save_checkpoint(self, save_dir, tag, client_state={}):
+    def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True):
         r"""Save training checkpoint
 
         Arguments:
             save_dir: Required. Directory for saving the checkpoint
-            tag: Required. Checkpoint tag used as a unique identifier for the checkpoint. Ex. Global Step.
+            tag: Optional. Checkpoint tag used as a unique identifier for the checkpoint, global step is used if not provided.
             client_state: Optional. State dictionary used for saving required training states in the client code.
+            save_latest: Optional. Save a file 'latest' pointing to the latest saved checkpoint.
         """
 
         # This is to make sure the checkpoint names are created without collision
         # There seems to be issue creating them in parallel
 
+        # Ensure save_dir directory exists
+        os.makedirs(save_dir, exist_ok=True)
+
+        if tag is None:
+            tag = f"global_step{self.global_steps}"
+
         if self.save_non_zero_checkpoint:
             self._create_checkpoint_file(save_dir, tag, False)
             self._save_checkpoint(save_dir, tag, client_state=client_state)
@@ -1474,6 +1488,11 @@ def save_checkpoint(self, save_dir, tag, client_state={}):
             self._create_zero_checkpoint_files(save_dir, tag)
             self._save_zero_checkpoint(save_dir, tag)
 
+        # Save latest checkpoint tag
+        if save_latest:
+            with open(os.path.join(save_dir, 'latest'), 'w') as fd:
+                fd.write(tag)
+
         return True
 
     def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint):
diff --git a/tests/unit/test_checkpointing.py b/tests/unit/test_checkpointing.py
index 097a581d96d9..1fbcacfa2aa4 100755
--- a/tests/unit/test_checkpointing.py
+++ b/tests/unit/test_checkpointing.py
@@ -128,7 +128,8 @@ def checkpoint_correctness_verification(args,
                                         fp16=True,
                                         train_batch=False,
                                         base_optimizers=[None,
-                                                         None]):
+                                                         None],
+                                        empty_tag=False):
     dtype = torch.half if fp16 else torch.float32
     ds_model = create_deepspeed_model(args=args,
                                       model=models[0],
@@ -153,16 +154,16 @@ def checkpoint_correctness_verification(args,
     trained_model = ds_model
 
     save_folder = os.path.join(tmpdir, 'saved_checkpoint')
-    save_tag = '1'
+    save_tag = None if empty_tag else '1'
 
-    trained_model.save_checkpoint(save_folder, save_tag)
+    trained_model.save_checkpoint(save_folder, tag=save_tag)
 
     loaded_model = create_deepspeed_model(args=args,
                                           model=models[1],
                                           base_optimizer=base_optimizers[1])
 
     loaded_model.load_checkpoint(save_folder,
-                                 save_tag,
+                                 tag=save_tag,
                                  load_optimizer_states=load_optimizer_states,
                                  load_lr_scheduler_states=load_lr_scheduler_states)
 
@@ -704,3 +705,59 @@ def _test_checkpoint_zero_hybrid_optimizer_state(args,
                                         models=models,
                                         optimizers=optimizers,
                                         hidden_dim=hidden_dim)
+
+
+def test_checkpoint_latest(tmpdir):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        }
+    }
+    hidden_dim = 10
+    args = args_from_dict(tmpdir, config_dict)
+    models = [SimpleModel(hidden_dim=hidden_dim) for _ in range(2)]
+
+    @distributed_test(world_size=[1])
+    def _helper(args, models):
+        checkpoint_correctness_verification(args,
+                                            models=models,
+                                            hidden_dim=hidden_dim,
+                                            tmpdir=tmpdir,
+                                            load_optimizer_states=True,
+                                            load_lr_scheduler_states=False,
+                                            fp16=False,
+                                            empty_tag=True)
+
+    _helper(args, models)
+
+
+def test_checkpoint_missing_latest(tmpdir):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        }
+    }
+    hidden_dim = 10
+    args = args_from_dict(tmpdir, config_dict)
+
+    model = SimpleModel(hidden_dim, rank=args.local_rank)
+
+    @distributed_test(world_size=[1])
+    def _helper(args, model, hidden_dim):
+        model, _, _,_ = deepspeed.initialize(args=args,
+                                             model=model,
+                                             model_parameters=model.parameters())
+        with pytest.raises(AssertionError):
+            model.load_checkpoint(tmpdir)
+
+    _helper(args=args, model=model, hidden_dim=hidden_dim)