-
-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
tensorflow
: add tf.train.CheckpointOptions
and other tf.train
m…
…embers. (#11327)
- Loading branch information
1 parent
547cbc7
commit 587e75f
Showing
3 changed files
with
96 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from _typeshed import Incomplete | ||
from collections.abc import Callable | ||
from typing import Any, TypeVar | ||
from typing_extensions import Self | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
from tensorflow.core.example.example_pb2 import Example as Example | ||
from tensorflow.core.example.feature_pb2 import ( | ||
BytesList as BytesList, | ||
Feature as Feature, | ||
Features as Features, | ||
FloatList as FloatList, | ||
Int64List as Int64List, | ||
) | ||
from tensorflow.core.protobuf.cluster_pb2 import ClusterDef as ClusterDef | ||
from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef as ServerDef | ||
from tensorflow.python.trackable.base import Trackable | ||
|
||
class CheckpointOptions: | ||
experimental_io_device: None | str | ||
experimental_enable_async_checkpoint: bool | ||
# Uncomment when the stubs' TF version is updated to 2.15 | ||
# experimental_write_callbacks: None | list[Callable[[str], Any] | Callable[[], Any]] | ||
enable_async: bool | ||
|
||
def __init__( | ||
self, | ||
experimental_io_device: None | str = None, | ||
experimental_enable_async_checkpoint: bool = False, | ||
# Uncomment when the stubs' TF version is updated to 2.15 | ||
# experimental_write_callbacks: None | list[Callable[[str], Any] | Callable[[], Any]] = None, | ||
enable_async: bool = False, | ||
) -> None: ... | ||
|
||
_T = TypeVar("_T", bound=list[str] | tuple[str] | dict[int, str]) | ||
|
||
class ClusterSpec: | ||
def __init__(self, cluster: dict[str, _T] | ClusterDef | ClusterSpec) -> None: ... | ||
def as_dict(self) -> dict[str, list[str] | tuple[str] | dict[int, str]]: ... | ||
def num_tasks(self, job_name: str) -> int: ... | ||
|
||
class _CheckpointLoadStatus: | ||
def assert_consumed(self) -> Self: ... | ||
def assert_existing_objects_matched(self) -> Self: ... | ||
def assert_nontrivial_match(self) -> Self: ... | ||
def expect_partial(self) -> Self: ... | ||
|
||
class Checkpoint: | ||
def __init__(self, root: Trackable | None = None, **kwargs: Trackable) -> None: ... | ||
def read(self, save_path: str, options: CheckpointOptions | None = None) -> _CheckpointLoadStatus: ... | ||
def restore(self, save_path: str, options: CheckpointOptions | None = None) -> _CheckpointLoadStatus: ... | ||
def save(self, file_prefix: str, options: CheckpointOptions | None = None) -> str: ... | ||
# def sync(self) -> None: ... # Uncomment when the stubs' TF version is updated to 2.15 | ||
def write(self, file_prefix: str, options: CheckpointOptions | None = None) -> str: ... | ||
|
||
class CheckpointManager: | ||
def __init__( | ||
self, | ||
checkpoint: Checkpoint, | ||
directory: str, | ||
max_to_keep: int, | ||
keep_checkpoint_every_n_hours: int | None = None, | ||
checkpoint_name: str = "ckpt", | ||
step_counter: tf.Variable | None = None, | ||
checkpoint_interval: int | None = None, | ||
init_fn: Callable[[], object] | None = None, | ||
) -> None: ... | ||
def _sweep(self) -> None: ... | ||
|
||
def latest_checkpoint(checkpoint_dir: str, latest_filename: str | None = None) -> str: ... | ||
def load_variable(ckpt_dir_or_file: str, name: str) -> np.ndarray[Any, Any]: ... | ||
def list_variables(ckpt_dir_or_file: str) -> list[tuple[str, list[int]]]: ... | ||
def __getattr__(name: str) -> Incomplete: ... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import abc | ||
from _typeshed import Incomplete | ||
from typing_extensions import Self | ||
|
||
from tensorflow.python.trackable.base import Trackable | ||
|
||
class PythonState(Trackable, metaclass=abc.ABCMeta): | ||
@abc.abstractmethod | ||
def serialize(self) -> str: ... | ||
@abc.abstractmethod | ||
def deserialize(self, string_value: str) -> Self: ... | ||
|
||
def __getattr__(name: str) -> Incomplete: ... |