From a918327818686e2a642f5e80d74a4ea8503b64bc Mon Sep 17 00:00:00 2001 From: Adam Cogdell Date: Mon, 18 Nov 2024 20:56:46 -0800 Subject: [PATCH] Use new metadata in `CheckpointManager`. PiperOrigin-RevId: 697859685 --- CHANGELOG.md | 2 + .../checkpoint/abstract_checkpoint_manager.py | 19 ++- .../orbax/checkpoint/checkpoint_manager.py | 108 +++++++++++++----- .../emergency/checkpoint_manager.py | 4 +- .../replicator_checkpoint_manager.py | 7 +- 5 files changed, 103 insertions(+), 37 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 73ea6e4c..508dac44 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 packages, namely `orbax-checkpoint` and `orbax-export`. Imports are unchanged, and still of the form `import orbax.checkpoint` or `import orbax.export`. - Finer scoped jax.monitoring calls on the save path. +- `CheckpointManager.metadata()` now accepts a `step` parameter. If provided, it will return `StepMetadata`, and will otherwise return `RootMetadata`. +- `CompositeCheckpointHandler.metadata()` now returns `StepMetadata`. 
## [0.1.7] - 2022-03-29 diff --git a/checkpoint/orbax/checkpoint/abstract_checkpoint_manager.py b/checkpoint/orbax/checkpoint/abstract_checkpoint_manager.py index e493f30c..07442bc3 100644 --- a/checkpoint/orbax/checkpoint/abstract_checkpoint_manager.py +++ b/checkpoint/orbax/checkpoint/abstract_checkpoint_manager.py @@ -19,6 +19,7 @@ from etils import epath from orbax.checkpoint import args as args_lib +from orbax.checkpoint._src.metadata import checkpoint PyTree = Any SaveParams = Mapping[str, Any] @@ -290,8 +291,22 @@ def item_metadata( """ @abc.abstractmethod - def metadata(self) -> Mapping[str, Any]: - """Returns CheckpointManager level metadata if present, empty otherwise.""" + def metadata( + self, step: int | None = None, + ) -> checkpoint.StepMetadata | checkpoint.RootMetadata: + """Returns `StepMetadata` for the given step, or `RootMetadata` if None. + + If step is specified, only return `StepMetadata` for that step. + Otherwise, return `RootMetadata`. + + Args: + step: Step for which to retrieve `StepMetadata`. If None, returns + `RootMetadata`. + + Returns: + Metadata for the specified step (`StepMetadata`), or all steps + (`RootMetadata`). 
+ """ @abc.abstractmethod def metrics(self, step: int) -> Optional[PyTree]: diff --git a/checkpoint/orbax/checkpoint/checkpoint_manager.py b/checkpoint/orbax/checkpoint/checkpoint_manager.py index 606435fe..4f988a27 100644 --- a/checkpoint/orbax/checkpoint/checkpoint_manager.py +++ b/checkpoint/orbax/checkpoint/checkpoint_manager.py @@ -22,7 +22,7 @@ import threading import time import typing -from typing import Any, Callable, Container, Iterable, List, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Container, Iterable, List, Mapping, Optional, Sequence, Tuple, Type, Union, overload from absl import logging from etils import epath @@ -44,6 +44,7 @@ from orbax.checkpoint._src.handlers import proto_checkpoint_handler from orbax.checkpoint._src.metadata import checkpoint from orbax.checkpoint._src.metadata import root_metadata_serialization +from orbax.checkpoint._src.metadata import step_metadata_serialization from orbax.checkpoint._src.multihost import multihost from orbax.checkpoint._src.path import atomicity_types from orbax.checkpoint._src.path import deleter @@ -64,6 +65,9 @@ AbstractCheckpointManager = ( abstract_checkpoint_manager.AbstractCheckpointManager ) +StepMetadata = checkpoint.StepMetadata +RootMetadata = checkpoint.RootMetadata +ItemMetadata = checkpoint.CompositeItemMetadata | checkpoint.SingleItemMetadata AsyncCheckpointer = async_checkpointer.AsyncCheckpointer Checkpointer = checkpointer_lib.Checkpointer JsonCheckpointHandler = json_checkpoint_handler.JsonCheckpointHandler @@ -709,11 +713,14 @@ def __init__( self._metadata_dir = self.directory / METADATA_ITEM_NAME if self._options.read_only and not self._metadata_dir.exists(): - self._metadata = {} if metadata is None else metadata + custom_metadata = {} if metadata is None else dict(metadata) else: - self._metadata = None + custom_metadata = None + self._root_metadata = RootMetadata( + custom=custom_metadata, + ) - self._maybe_save_metadata(metadata) + 
self._maybe_save_root_metadata(metadata) # TODO: b/359854428 - Move Finalize biz logic to a separate class/module. self._finalize_thread_lock = threading.Lock() @@ -1194,6 +1201,13 @@ def save( args_dict['metrics'] = args_lib.JsonSave(metrics) args = args_lib.Composite(**args_dict) + step_metadata = StepMetadata( + metrics=dict(metrics) if metrics is not None else None, + performance_metrics=step_stats, + custom=self._root_metadata.custom, + ) + logging.info('step_metadata: %s', step_metadata) + save_directory = self._get_write_step_directory(step, self.directory) # If a folder for the step to save exists and is not finalized, remove the # existing folder. @@ -1230,7 +1244,9 @@ def save( '[process=%s] Saving checkpoint at step %d', process_index, step ) step_stats.checkpointer_blocking_start_time = time.time() - self._checkpointer.save(save_directory, args=args) + self._checkpointer.save( + save_directory, args=args, step_metadata=step_metadata + ) step_stats.checkpointer_blocking_duration_secs = ( time.time() - step_stats.checkpointer_blocking_start_time ) @@ -1298,7 +1314,9 @@ def save( self._logger.log_entry(dataclasses.asdict(step_stats)) return True - def _maybe_get_default_item(self, composite_result: args_lib.Composite): + def _maybe_get_default_item( + self, composite_result: args_lib.Composite + ) -> Union[Any, args_lib.Composite]: if self._default_item: if DEFAULT_ITEM_NAME not in composite_result: raise ValueError( @@ -1379,7 +1397,9 @@ def restore( return self._maybe_get_default_item(restored) - def item_metadata(self, step: int) -> Union[Any, args_lib.Composite]: + def item_metadata( + self, step: int + ) -> Union[Any, args_lib.Composite, ItemMetadata]: """Retrieves metadata for all known items. 
Note that metadata will only be returned for items that can actually be @@ -1394,18 +1414,14 @@ def item_metadata(self, step: int) -> Union[Any, args_lib.Composite]: Either metadata for the item itself, if in default-item mode, or a Composite of metadata for each item. """ - assert isinstance(self._checkpointer.handler, CompositeCheckpointHandler) read_step_directory = self._get_read_step_directory(step, self.directory) - - result = self._checkpointer.metadata(read_step_directory) - if isinstance(result, checkpoint.StepMetadata): - result = result.item_metadata if self._default_item is None: self._default_item = _determine_default_item_mode_from_directory( read_step_directory ) - return self._maybe_get_default_item(result) + return self._maybe_get_default_item(self.metadata(step).item_metadata) + # TODO(b/370812224): Deprecate in favor of StepMetadata.metrics def metrics(self, step: int) -> Optional[PyTree]: if self._track_best: try: @@ -1504,21 +1520,18 @@ def _metadata_file_path(self, legacy: bool = False) -> epath.Path: self._metadata_dir, legacy=legacy ) - def _maybe_save_metadata(self, metadata: Mapping[str, Any]): + def _maybe_save_root_metadata(self, custom_metadata: Mapping[str, Any]): """Saves CheckpointManager level metadata, skips if already present.""" if self._options.save_root_metadata: - logging.info('Saving root metadata') - if (metadata is not None and + if (custom_metadata is not None and not self._options.read_only and utils.is_primary_host(self._multiprocessing_options.primary_host)): - logging.info('Creating metadata directory') self._metadata_dir.mkdir(parents=True, exist_ok=True) file_path = self._metadata_file_path() if not file_path.exists(): # May have been created by a previous run. 
- logging.info('Writing root metadata') - metadata_to_save = checkpoint.RootMetadata( - custom=dict(metadata), - ) + metadata_to_save = self._root_metadata + if custom_metadata is not None: + metadata_to_save.custom = dict(custom_metadata) self._blocking_metadata_store.write( file_path, serialize_root_metadata(metadata_to_save) ) @@ -1531,9 +1544,28 @@ def _maybe_save_metadata(self, metadata: Mapping[str, Any]): processes=self._multiprocessing_options.active_processes, ) - def metadata(self) -> Mapping[str, Any]: - """See superclass documentation.""" - if self._metadata is None: + def _get_step_metadata(self, step: int) -> StepMetadata: + infos = [info for info in self._checkpoints if info.step == step] + if not infos or len(infos) > 1: + metrics = None + else: + metrics = infos[0].metrics + + step_metadata = self._checkpointer.metadata( + self._get_read_step_directory(step, self.directory), + ) + if metrics is not None: + validated_metrics = step_metadata_serialization.deserialize( + {}, metrics=dict(metrics) + ).metrics + step_metadata = dataclasses.replace( + step_metadata, + metrics=validated_metrics, + ) + return step_metadata + + def _get_root_metadata(self) -> RootMetadata: + if self._root_metadata.custom is None: if self._metadata_dir.exists(): file_path = self._metadata_file_path() if not file_path.exists(): @@ -1542,14 +1574,28 @@ def metadata(self) -> Mapping[str, Any]: self._metadata_dir) file_path = self._metadata_file_path(legacy=True) serialized_metadata = self._blocking_metadata_store.read(file_path) - self._metadata = deserialize_root_metadata(serialized_metadata).custom - if self._metadata is None: - raise FileNotFoundError( - f'Failed to read metadata from {file_path}.' 
- ) + if serialized_metadata is None: + raise IOError(f'Failed to read metadata from {file_path}') + self._root_metadata = root_metadata_serialization.deserialize( + serialized_metadata + ) else: - self._metadata = {} - return self._metadata + self._root_metadata.custom = {} + return self._root_metadata + + @overload + def metadata(self, step: None = None) -> RootMetadata: + ... + + @overload + def metadata(self, step: int) -> StepMetadata: + ... + + def metadata(self, step: int | None = None) -> RootMetadata | StepMetadata: + """See superclass documentation.""" + if step is not None: + return self._get_step_metadata(step) + return self._get_root_metadata() def _sort_checkpoints_by_metrics( self, checkpoints: List[CheckpointInfo] diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py b/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py index 697cbbd2..7bbc1d4e 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py @@ -67,6 +67,8 @@ get_present_and_missing_chunks = ( local_checkpoint_data_debugging.get_present_and_missing_chunks ) +RootMetadata = checkpoint_manager.RootMetadata +StepMetadata = checkpoint_manager.StepMetadata _PRIMARY_REPLICA_ID = 0 _SECONDARY_REPLICA_ID = 1 @@ -1303,7 +1305,7 @@ def item_metadata(self, step: int) -> Any: 'Item metadata not yet implemented for emergency.CheckpointManager.' ) - def metadata(self) -> dict[str, Any]: + def metadata(self, step: int | None = None) -> RootMetadata | StepMetadata: """Returns CheckpointManager level metadata if present, empty otherwise.""" raise NotImplementedError( 'Metadata not yet implemented for emergency.CheckpointManager.' 
diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager.py b/checkpoint/orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager.py index f91d77b6..9ddae2c0 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager.py @@ -18,7 +18,6 @@ subject to change without notice. """ -from collections.abc import Mapping import dataclasses from typing import Any, Callable, Iterable, Sequence from absl import logging @@ -42,6 +41,8 @@ handler_registration.DefaultCheckpointHandlerRegistry ) PyTreeCheckpointHandler = pytree_checkpoint_handler.PyTreeCheckpointHandler +RootMetadata = checkpoint_manager.RootMetadata +StepMetadata = checkpoint_manager.StepMetadata _UNNAMED_ITEM_NAME = 'state' @@ -319,8 +320,8 @@ def restore( def item_metadata(self, step: int) -> Any: return self._impl.item_metadata(step) - def metadata(self) -> Mapping[str, Any]: - return self._impl.metadata() + def metadata(self, step: int | None = None) -> RootMetadata | StepMetadata: + return self._impl.metadata(step) def metrics(self, step: int) -> PyTree | None: raise NotImplementedError()