Skip to content

Commit

Permalink
tensorflow: add tf.train.CheckpointOptions and other tf.train m…
Browse files Browse the repository at this point in the history
…embers. (#11327)
  • Loading branch information
hoel-bagard authored Feb 1, 2024
1 parent 547cbc7 commit 587e75f
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 0 deletions.
9 changes: 9 additions & 0 deletions stubs/tensorflow/@tests/stubtest_allowlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,12 @@ tensorflow.io.SparseFeature.__new__

# Metaclass inconsistency. The runtime metaclass is defined from c++ extension and is undocumented.
tensorflow.io.TFRecordWriter

# stubtest does not pass for protobuf generated stubs.
tensorflow.train.Example.*
tensorflow.train.BytesList.*
tensorflow.train.Feature.*
tensorflow.train.FloatList.*
tensorflow.train.Int64List.*
tensorflow.train.ClusterDef.*
tensorflow.train.ServerDef.*
74 changes: 74 additions & 0 deletions stubs/tensorflow/tensorflow/train/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from _typeshed import Incomplete
from collections.abc import Callable
from typing import Any, TypeVar
from typing_extensions import Self

import numpy as np
import tensorflow as tf
from tensorflow.core.example.example_pb2 import Example as Example
from tensorflow.core.example.feature_pb2 import (
BytesList as BytesList,
Feature as Feature,
Features as Features,
FloatList as FloatList,
Int64List as Int64List,
)
from tensorflow.core.protobuf.cluster_pb2 import ClusterDef as ClusterDef
from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef as ServerDef
from tensorflow.python.trackable.base import Trackable

class CheckpointOptions:
experimental_io_device: None | str
experimental_enable_async_checkpoint: bool
# Uncomment when the stubs' TF version is updated to 2.15
# experimental_write_callbacks: None | list[Callable[[str], Any] | Callable[[], Any]]
enable_async: bool

def __init__(
self,
experimental_io_device: None | str = None,
experimental_enable_async_checkpoint: bool = False,
# Uncomment when the stubs' TF version is updated to 2.15
# experimental_write_callbacks: None | list[Callable[[str], Any] | Callable[[], Any]] = None,
enable_async: bool = False,
) -> None: ...

_T = TypeVar("_T", bound=list[str] | tuple[str] | dict[int, str])

class ClusterSpec:
def __init__(self, cluster: dict[str, _T] | ClusterDef | ClusterSpec) -> None: ...
def as_dict(self) -> dict[str, list[str] | tuple[str] | dict[int, str]]: ...
def num_tasks(self, job_name: str) -> int: ...

class _CheckpointLoadStatus:
def assert_consumed(self) -> Self: ...
def assert_existing_objects_matched(self) -> Self: ...
def assert_nontrivial_match(self) -> Self: ...
def expect_partial(self) -> Self: ...

class Checkpoint:
def __init__(self, root: Trackable | None = None, **kwargs: Trackable) -> None: ...
def read(self, save_path: str, options: CheckpointOptions | None = None) -> _CheckpointLoadStatus: ...
def restore(self, save_path: str, options: CheckpointOptions | None = None) -> _CheckpointLoadStatus: ...
def save(self, file_prefix: str, options: CheckpointOptions | None = None) -> str: ...
# def sync(self) -> None: ... # Uncomment when the stubs' TF version is updated to 2.15
def write(self, file_prefix: str, options: CheckpointOptions | None = None) -> str: ...

class CheckpointManager:
def __init__(
self,
checkpoint: Checkpoint,
directory: str,
max_to_keep: int,
keep_checkpoint_every_n_hours: int | None = None,
checkpoint_name: str = "ckpt",
step_counter: tf.Variable | None = None,
checkpoint_interval: int | None = None,
init_fn: Callable[[], object] | None = None,
) -> None: ...
def _sweep(self) -> None: ...

def latest_checkpoint(checkpoint_dir: str, latest_filename: str | None = None) -> str: ...
def load_variable(ckpt_dir_or_file: str, name: str) -> np.ndarray[Any, Any]: ...
def list_variables(ckpt_dir_or_file: str) -> list[tuple[str, list[int]]]: ...
def __getattr__(name: str) -> Incomplete: ...
13 changes: 13 additions & 0 deletions stubs/tensorflow/tensorflow/train/experimental.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import abc
from _typeshed import Incomplete
from typing_extensions import Self

from tensorflow.python.trackable.base import Trackable

class PythonState(Trackable, metaclass=abc.ABCMeta):
@abc.abstractmethod
def serialize(self) -> str: ...
@abc.abstractmethod
def deserialize(self, string_value: str) -> Self: ...

def __getattr__(name: str) -> Incomplete: ...

0 comments on commit 587e75f

Please sign in to comment.