diff --git a/olmo/checkpoint.py b/olmo/checkpoint.py
index faae84979..8985512b3 100644
--- a/olmo/checkpoint.py
+++ b/olmo/checkpoint.py
@@ -522,7 +522,8 @@ def restore_checkpoint(
         *,
         local_cache: Optional[PathOrStr] = None,
         load_optimizer_state: bool = True,
-    ) -> Dict[str, Any]:
+        load_trainer_state: bool = True,
+    ) -> Optional[Dict[str, Any]]:
         """
         Restores a checkpoint to the model and optimizer. Returns the remaining trainer state.
         """
@@ -678,7 +679,8 @@ def restore_checkpoint(
         *,
         local_cache: Optional[PathOrStr] = None,
         load_optimizer_state: bool = True,
-    ) -> Dict[str, Any]:
+        load_trainer_state: bool = True,
+    ) -> Optional[Dict[str, Any]]:
         with FSDP.state_dict_type(
             fsdp_model,
             state_dict_type=StateDictType.FULL_STATE_DICT,
@@ -751,11 +753,13 @@ def restore_checkpoint(
             del optim_state_dict_to_load

         # Load other state.
-        try:
-            trainer_state = load_state_dict(load_path, "train.pt", local_cache=local_cache)
-        except FileNotFoundError:
-            # for backwards compatibility
-            trainer_state = load_state_dict(load_path, "other.pt", local_cache=local_cache)
+        trainer_state = None
+        if load_trainer_state:
+            try:
+                trainer_state = load_state_dict(load_path, "train.pt", local_cache=local_cache)
+            except FileNotFoundError:
+                # for backwards compatibility
+                trainer_state = load_state_dict(load_path, "other.pt", local_cache=local_cache)
         barrier()
         return trainer_state

@@ -872,7 +876,8 @@ def restore_checkpoint(
         *,
         local_cache: Optional[PathOrStr] = None,
         load_optimizer_state: bool = True,
-    ) -> Dict[str, Any]:
+        load_trainer_state: bool = True,
+    ) -> Optional[Dict[str, Any]]:
         # Load model and optimizer state in place.
         log.info("Loading model and optimizer state...")
         load_fsdp_model_and_optim_state(
@@ -885,14 +890,16 @@ def restore_checkpoint(

         # Load trainer state dict.
         log.info("Loading trainer state...")
-        try:
-            trainer_state = load_state_dict(
-                load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache
-            )
-        except FileNotFoundError:
-            # Fall back to rank 0 train state.
-            # This can happen when we're restoring a checkpoint with a different world size.
-            trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
+        trainer_state = None
+        if load_trainer_state:
+            try:
+                trainer_state = load_state_dict(
+                    load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache
+                )
+            except FileNotFoundError:
+                # Fall back to rank 0 train state.
+                # This can happen when we're restoring a checkpoint with a different world size.
+                trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
         barrier()
         return trainer_state

@@ -949,6 +956,7 @@ def restore_checkpoint(
         *,
         local_cache: Optional[PathOrStr] = None,
         load_optimizer_state: bool = True,
+        load_trainer_state: bool = True,
     ) -> Dict[str, Any]:
         with FSDP.state_dict_type(
             fsdp_model,
@@ -1562,7 +1570,8 @@ def restore_checkpoint(
         *,
         local_cache: Optional[PathOrStr] = None,
         load_optimizer_state: bool = True,
-    ) -> Dict[str, Any]:
+        load_trainer_state: bool = True,
+    ) -> Optional[Dict[str, Any]]:
         # Load metadata and make sure checkpoint is compatible.
         metadata = self._load_metadata(load_path, local_cache=local_cache)
         assert metadata.world_size == get_world_size()
@@ -1599,7 +1608,9 @@ def restore_checkpoint(

         # Load local trainer state.
         log.info("Loading local trainer state...")
-        trainer_state = load_state_dict(load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache)
+        trainer_state = None
+        if load_trainer_state:
+            trainer_state = load_state_dict(load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache)
         barrier()
         return trainer_state

@@ -1868,7 +1879,8 @@ def restore_checkpoint(
         *,
         local_cache: Optional[PathOrStr] = None,
         load_optimizer_state: bool = True,
-    ) -> Dict[str, Any]:
+        load_trainer_state: bool = True,
+    ) -> Optional[Dict[str, Any]]:
         from olmo_core.distributed.checkpoint import (  # type: ignore
             load_model_and_optim_state,
         )
@@ -1877,14 +1889,16 @@ def restore_checkpoint(
         load_model_and_optim_state(load_path, fsdp_model, optim if load_optimizer_state else None)

         log.info("Loading trainer state...")
-        try:
-            trainer_state = load_state_dict(
-                load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache
-            )
-        except FileNotFoundError:
-            # Fall back to rank 0 train state.
-            # This can happen when we're restoring a checkpoint with a different world size.
-            trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
+        trainer_state = None
+        if load_trainer_state:
+            try:
+                trainer_state = load_state_dict(
+                    load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache
+                )
+            except FileNotFoundError:
+                # Fall back to rank 0 train state.
+                # This can happen when we're restoring a checkpoint with a different world size.
+                trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
         barrier()
         return trainer_state

diff --git a/olmo/train.py b/olmo/train.py
index 94cf56f7a..bc241395b 100644
--- a/olmo/train.py
+++ b/olmo/train.py
@@ -504,6 +504,7 @@ def restore_sharded_checkpoint(
             self.optim,
             local_cache=local_cache,
             load_optimizer_state=load_optimizer_state,
+            load_trainer_state=load_trainer_state,
         )
         if load_trainer_state:
            self.load_trainer_state_dict(trainer_state)