diff --git a/stubs/tensorflow/@tests/stubtest_allowlist.txt b/stubs/tensorflow/@tests/stubtest_allowlist.txt index 5700e283375f..54f489ac7baa 100644 --- a/stubs/tensorflow/@tests/stubtest_allowlist.txt +++ b/stubs/tensorflow/@tests/stubtest_allowlist.txt @@ -85,3 +85,12 @@ tensorflow.io.SparseFeature.__new__ # Metaclass inconsistency. The runtime metaclass is defined from c++ extension and is undocumented. tensorflow.io.TFRecordWriter + +# stubtest does not pass for protobuf generated stubs. +tensorflow.train.Example.* +tensorflow.train.BytesList.* +tensorflow.train.Feature.* +tensorflow.train.FloatList.* +tensorflow.train.Int64List.* +tensorflow.train.ClusterDef.* +tensorflow.train.ServerDef.* diff --git a/stubs/tensorflow/tensorflow/train/__init__.pyi b/stubs/tensorflow/tensorflow/train/__init__.pyi new file mode 100644 index 000000000000..1f42cd14a1a6 --- /dev/null +++ b/stubs/tensorflow/tensorflow/train/__init__.pyi @@ -0,0 +1,74 @@ +from _typeshed import Incomplete +from collections.abc import Callable +from typing import Any, TypeVar +from typing_extensions import Self + +import numpy as np +import tensorflow as tf +from tensorflow.core.example.example_pb2 import Example as Example +from tensorflow.core.example.feature_pb2 import ( + BytesList as BytesList, + Feature as Feature, + Features as Features, + FloatList as FloatList, + Int64List as Int64List, +) +from tensorflow.core.protobuf.cluster_pb2 import ClusterDef as ClusterDef +from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef as ServerDef +from tensorflow.python.trackable.base import Trackable + +class CheckpointOptions: + experimental_io_device: None | str + experimental_enable_async_checkpoint: bool + # Uncomment when the stubs' TF version is updated to 2.15 + # experimental_write_callbacks: None | list[Callable[[str], Any] | Callable[[], Any]] + enable_async: bool + + def __init__( + self, + experimental_io_device: None | str = None, + experimental_enable_async_checkpoint: bool = False, + # Uncomment when the stubs' TF version is updated to 2.15 + # experimental_write_callbacks: None | list[Callable[[str], Any] | Callable[[], Any]] = None, + enable_async: bool = False, + ) -> None: ... + +_T = TypeVar("_T", bound=list[str] | tuple[str] | dict[int, str]) + +class ClusterSpec: + def __init__(self, cluster: dict[str, _T] | ClusterDef | ClusterSpec) -> None: ... + def as_dict(self) -> dict[str, list[str] | tuple[str] | dict[int, str]]: ... + def num_tasks(self, job_name: str) -> int: ... + +class _CheckpointLoadStatus: + def assert_consumed(self) -> Self: ... + def assert_existing_objects_matched(self) -> Self: ... + def assert_nontrivial_match(self) -> Self: ... + def expect_partial(self) -> Self: ... + +class Checkpoint: + def __init__(self, root: Trackable | None = None, **kwargs: Trackable) -> None: ... + def read(self, save_path: str, options: CheckpointOptions | None = None) -> _CheckpointLoadStatus: ... + def restore(self, save_path: str, options: CheckpointOptions | None = None) -> _CheckpointLoadStatus: ... + def save(self, file_prefix: str, options: CheckpointOptions | None = None) -> str: ... + # def sync(self) -> None: ... # Uncomment when the stubs' TF version is updated to 2.15 + def write(self, file_prefix: str, options: CheckpointOptions | None = None) -> str: ... + +class CheckpointManager: + def __init__( + self, + checkpoint: Checkpoint, + directory: str, + max_to_keep: int, + keep_checkpoint_every_n_hours: int | None = None, + checkpoint_name: str = "ckpt", + step_counter: tf.Variable | None = None, + checkpoint_interval: int | None = None, + init_fn: Callable[[], object] | None = None, + ) -> None: ... + def _sweep(self) -> None: ... + +def latest_checkpoint(checkpoint_dir: str, latest_filename: str | None = None) -> str: ... +def load_variable(ckpt_dir_or_file: str, name: str) -> np.ndarray[Any, Any]: ... +def list_variables(ckpt_dir_or_file: str) -> list[tuple[str, list[int]]]: ... +def __getattr__(name: str) -> Incomplete: ... diff --git a/stubs/tensorflow/tensorflow/train/experimental.pyi b/stubs/tensorflow/tensorflow/train/experimental.pyi new file mode 100644 index 000000000000..c7f3c0ded5ec --- /dev/null +++ b/stubs/tensorflow/tensorflow/train/experimental.pyi @@ -0,0 +1,13 @@ +import abc +from _typeshed import Incomplete +from typing_extensions import Self + +from tensorflow.python.trackable.base import Trackable + +class PythonState(Trackable, metaclass=abc.ABCMeta): + @abc.abstractmethod + def serialize(self) -> str: ... + @abc.abstractmethod + def deserialize(self, string_value: str) -> Self: ... + +def __getattr__(name: str) -> Incomplete: ...