From 16d48df8eefdf1118f0123c027854c8e25e823f0 Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Sat, 27 Jan 2024 21:55:04 +0900 Subject: [PATCH 01/21] add CheckpointOptions --- .../tensorflow/tensorflow/train/__init__.pyi | 26 ++++++++++++++++ .../tensorflow/train/experimental.pyi | 30 +++++++++++++++++++ 2 files changed, 56 insertions(+) create mode 100644 stubs/tensorflow/tensorflow/train/__init__.pyi create mode 100644 stubs/tensorflow/tensorflow/train/experimental.pyi diff --git a/stubs/tensorflow/tensorflow/train/__init__.pyi b/stubs/tensorflow/tensorflow/train/__init__.pyi new file mode 100644 index 000000000000..d1e72b97f022 --- /dev/null +++ b/stubs/tensorflow/tensorflow/train/__init__.pyi @@ -0,0 +1,26 @@ +from _typeshed import Incomplete +from collections.abc import Callable +from typing import Any, Self + +import tensorflow as tf + +class CheckpointOptions: + experimental_io_device: None | str + experimental_enable_async_checkpoint: bool + experimental_write_callbacks: None | list[Callable[[str], Any] | Callable[[], Any]] + enable_async: bool + experimental_skip_slot_variables: bool + experimental_sharding_callback: tf.train.experimental.ShardingCallback | None = None, + + def __init__( + self, + experimental_io_device: None | str = None, + experimental_enable_async_checkpoint: bool = False, + experimental_write_callbacks: None | list[Callable[[str], Any] | Callable[[], Any]] = None, + enable_async: bool = False, + experimental_skip_slot_variables: bool = False, + experimental_sharding_callback: tf.train.experimental.ShardingCallback | None = None, + ) -> None: ... + def __copy__(self) -> Self: ... + +def __getattr__(name: str) -> Incomplete: ... diff --git a/stubs/tensorflow/tensorflow/train/experimental.pyi b/stubs/tensorflow/tensorflow/train/experimental.pyi new file mode 100644 index 000000000000..def97c1018ac --- /dev/null +++ b/stubs/tensorflow/tensorflow/train/experimental.pyi @@ -0,0 +1,30 @@ +import abc +import dataclasses +from _typeshed import Incomplete +from typing import Sequence + +from tensorflow import Tensor, TensorShape +from tensorflow.dtypes import DType +from tensorflow.python.trackable.base import Trackable + +@dataclasses.dataclass(frozen=True) +class ShardableTensor: + _tensor_save_spec: Incomplete # saveable_object.SaveSpec + tensor: Tensor + dtype: DType + device: Incomplete # device_lib.DeviceSpec + name: str + shape: TensorShape + slice_spec: Incomplete # variables.Variable.SaveSliceInfo + checkpoint_key: str + trackable: Trackable + def __hash__(self) -> int: ... + def __repr__(self) -> str: ... + +class ShardingCallback(abc.ABC): + def description(self) -> str: ... + @abc.abstractmethod + def __call__(self, shardable_tensors: Sequence[ShardableTensor]) -> Sequence[Incomplete]: ... # Sequence[TensorSliceDict] + def __hash__(self) -> int: ... + +def __getattr__(name: str) -> Incomplete: ... From 31a3c6f25202fe6f350e887b032a04890722f005 Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Sat, 27 Jan 2024 22:11:19 +0900 Subject: [PATCH 02/21] add other tf.train members Add members already in https://github.com/hmc-cs-mdrissi/tensorflow_stubs/blob/main/stubs/tensorflow/train/__init__.pyi --- .../tensorflow/tensorflow/train/__init__.pyi | 75 +++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/stubs/tensorflow/tensorflow/train/__init__.pyi b/stubs/tensorflow/tensorflow/train/__init__.pyi index d1e72b97f022..bae2c68ddb2c 100644 --- a/stubs/tensorflow/tensorflow/train/__init__.pyi +++ b/stubs/tensorflow/tensorflow/train/__init__.pyi @@ -1,8 +1,17 @@ from _typeshed import Incomplete +from typing import Any, Callable, TypeVar +from typing_extensions import Self + +from google.protobuf.message import Message + +import numpy as np + +import tensorflow as tf from collections.abc import Callable from typing import Any, Self import tensorflow as tf +from tensorflow.python.trackable.base import Trackable class CheckpointOptions: experimental_io_device: None | str @@ -23,4 +32,70 @@ class CheckpointOptions: ) -> None: ... def __copy__(self) -> Self: ... +class Example(Message): + features: Features + +class Features(Message): + feature: dict[str, Feature] + +class Feature(Message): + float_list: FloatList + int64_list: Int64List + bytes_list: BytesList + +class FloatList(Message): + value: list[float] + +class Int64List(Message): + value: list[int] + +class BytesList(Message): + value: list[bytes] + +class ServerDef(Message): ... +class ClusterDef(Message): ... + +_T = TypeVar("_T", bound=list[str] | tuple[str] | dict[int, str]) + +class ClusterSpec: + def __init__(self, cluster: dict[str, _T] | ClusterDef | ClusterSpec) -> None: ... + def as_dict(self) -> dict[str, list[str] | tuple[str] | dict[int, str]]: ... + def num_tasks(self, job_name: str) -> int: ... + +class CheckpointReader: + def get_variable_to_shape_map(self) -> dict[str, list[int]]: ... + def get_variable_to_dtype_map(self) -> dict[str, tf.DType]: ... + def get_tensor(self, name: str) -> np.ndarray[Any, Any] | Any: ... + +class _CheckpointLoadStatus: + def assert_consumed(self) -> Self: ... + def assert_existing_objects_matched(self) -> Self: ... + def assert_nontrivial_match(self) -> Self: ... + def expect_partial(self) -> Self: ... + +class Checkpoint: + def __init__(self, root: Trackable | None = None, **kwargs: Trackable): ... + def write(self, file_prefix: str, options: CheckpointOptions | None = None) -> str: ... + def read(self, file_prefix: str, options: CheckpointOptions | None = None) -> _CheckpointLoadStatus: ... + def restore(self, file_prefix: str, options: CheckpointOptions | None = None) -> _CheckpointLoadStatus: ... + +class CheckpointManager: + def __init__( + self, + checkpoint: Checkpoint, + directory: str, + max_to_keep: int, + keep_checkpoint_every_n_hours: int | None = None, + checkpoint_name: str = "ckpt", + step_counter: tf.Variable | None = None, + checkpoint_interval: int | None = None, + init_fn: Callable[[], object] | None = None, + ): ... + def _sweep(self) -> None: ... + +def latest_checkpoint(checkpoint_dir: str, latest_filename: str | None = None) -> str: ... +def load_checkpoint(ckpt_dir_or_file: str) -> CheckpointReader: ... +def load_variable(ckpt_dir_or_file: str, name: str) -> np.ndarray[Any, Any]: ... +def list_variables(ckpt_dir_or_file: str) -> list[tuple[str, list[int]]]: ... + def __getattr__(name: str) -> Incomplete: ... From abaad59c3336f8e1b0301362573af6c8a99e2e63 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 27 Jan 2024 13:13:29 +0000 Subject: [PATCH 03/21] [pre-commit.ci] auto fixes from pre-commit.com hooks --- .../tensorflow/tensorflow/train/__init__.pyi | 14 ++++-------- .../tensorflow/train/experimental.pyi | 22 +++++++++---------- 2 files changed, 15 insertions(+), 21 deletions(-) diff --git a/stubs/tensorflow/tensorflow/train/__init__.pyi b/stubs/tensorflow/tensorflow/train/__init__.pyi index bae2c68ddb2c..820a0b118027 100644 --- a/stubs/tensorflow/tensorflow/train/__init__.pyi +++ b/stubs/tensorflow/tensorflow/train/__init__.pyi @@ -1,16 +1,11 @@ from _typeshed import Incomplete -from typing import Any, Callable, TypeVar +from collections.abc import Callable +from typing import Any, Callable, Self, TypeVar from typing_extensions import Self -from google.protobuf.message import Message - import numpy as np - -import tensorflow as tf -from collections.abc import Callable -from typing import Any, Self - import tensorflow as tf +from google.protobuf.message import Message from tensorflow.python.trackable.base import Trackable class CheckpointOptions: @@ -19,7 +14,7 @@ class CheckpointOptions: experimental_write_callbacks: None | list[Callable[[str], Any] | Callable[[], Any]] enable_async: bool experimental_skip_slot_variables: bool - experimental_sharding_callback: tf.train.experimental.ShardingCallback | None = None, + experimental_sharding_callback: tf.train.experimental.ShardingCallback | None = (None,) def __init__( self, @@ -97,5 +92,4 @@ def latest_checkpoint(checkpoint_dir: str, latest_filename: str | None = None) - def load_checkpoint(ckpt_dir_or_file: str) -> CheckpointReader: ... def load_variable(ckpt_dir_or_file: str, name: str) -> np.ndarray[Any, Any]: ... def list_variables(ckpt_dir_or_file: str) -> list[tuple[str, list[int]]]: ... - def __getattr__(name: str) -> Incomplete: ... diff --git a/stubs/tensorflow/tensorflow/train/experimental.pyi b/stubs/tensorflow/tensorflow/train/experimental.pyi index def97c1018ac..371daf819910 100644 --- a/stubs/tensorflow/tensorflow/train/experimental.pyi +++ b/stubs/tensorflow/tensorflow/train/experimental.pyi @@ -9,17 +9,17 @@ from tensorflow.python.trackable.base import Trackable @dataclasses.dataclass(frozen=True) class ShardableTensor: - _tensor_save_spec: Incomplete # saveable_object.SaveSpec - tensor: Tensor - dtype: DType - device: Incomplete # device_lib.DeviceSpec - name: str - shape: TensorShape - slice_spec: Incomplete # variables.Variable.SaveSliceInfo - checkpoint_key: str - trackable: Trackable - def __hash__(self) -> int: ... - def __repr__(self) -> str: ... + _tensor_save_spec: Incomplete # saveable_object.SaveSpec + tensor: Tensor + dtype: DType + device: Incomplete # device_lib.DeviceSpec + name: str + shape: TensorShape + slice_spec: Incomplete # variables.Variable.SaveSliceInfo + checkpoint_key: str + trackable: Trackable + def __hash__(self) -> int: ... + def __repr__(self) -> str: ... class ShardingCallback(abc.ABC): def description(self) -> str: ... From 2a97826bc4fbb3b03c2396b877a7d84e9067dd9e Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Sat, 27 Jan 2024 22:17:04 +0900 Subject: [PATCH 04/21] fix CI errors --- stubs/tensorflow/tensorflow/train/__init__.pyi | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stubs/tensorflow/tensorflow/train/__init__.pyi b/stubs/tensorflow/tensorflow/train/__init__.pyi index 820a0b118027..98dca7fb28cd 100644 --- a/stubs/tensorflow/tensorflow/train/__init__.pyi +++ b/stubs/tensorflow/tensorflow/train/__init__.pyi @@ -1,6 +1,6 @@ from _typeshed import Incomplete from collections.abc import Callable -from typing import Any, Callable, Self, TypeVar +from typing import Any, Callable, TypeVar from typing_extensions import Self import numpy as np @@ -14,7 +14,7 @@ class CheckpointOptions: experimental_write_callbacks: None | list[Callable[[str], Any] | Callable[[], Any]] enable_async: bool experimental_skip_slot_variables: bool - experimental_sharding_callback: tf.train.experimental.ShardingCallback | None = (None,) + experimental_sharding_callback: tf.train.experimental.ShardingCallback | None = None def __init__( self, From 6fed7b9790a931d6ff49662b1193982b8bf9a5b9 Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Sat, 27 Jan 2024 22:31:50 +0900 Subject: [PATCH 05/21] fix CI errors --- .../tensorflow/tensorflow/train/__init__.pyi | 43 ++++++++++--------- .../tensorflow/train/experimental.pyi | 43 ++++++++----------- 2 files changed, 41 insertions(+), 45 deletions(-) diff --git a/stubs/tensorflow/tensorflow/train/__init__.pyi b/stubs/tensorflow/tensorflow/train/__init__.pyi index 98dca7fb28cd..c2b522c869ec 100644 --- a/stubs/tensorflow/tensorflow/train/__init__.pyi +++ b/stubs/tensorflow/tensorflow/train/__init__.pyi @@ -1,11 +1,11 @@ from _typeshed import Incomplete from collections.abc import Callable -from typing import Any, Callable, TypeVar +from typing import Any, Callable, TypeVar, final from typing_extensions import Self import numpy as np import tensorflow as tf -from google.protobuf.message import Message +from google.protobuf.message import MessageMeta from tensorflow.python.trackable.base import Trackable class CheckpointOptions: @@ -13,8 +13,8 @@ class CheckpointOptions: experimental_enable_async_checkpoint: bool experimental_write_callbacks: None | list[Callable[[str], Any] | Callable[[], Any]] enable_async: bool - experimental_skip_slot_variables: bool - experimental_sharding_callback: tf.train.experimental.ShardingCallback | None = None + # experimental_skip_slot_variables: bool + # experimental_sharding_callback: tf.train.experimental.ShardingCallback | None = None def __init__( self, @@ -22,33 +22,42 @@ class CheckpointOptions: experimental_enable_async_checkpoint: bool = False, experimental_write_callbacks: None | list[Callable[[str], Any] | Callable[[], Any]] = None, enable_async: bool = False, - experimental_skip_slot_variables: bool = False, - experimental_sharding_callback: tf.train.experimental.ShardingCallback | None = None, + # experimental_skip_slot_variables: bool = False, + # experimental_sharding_callback: tf.train.experimental.ShardingCallback | None = None, ) -> None: ... def __copy__(self) -> Self: ... -class Example(Message): +@final +class Example(MessageMeta): features: Features -class Features(Message): +@final +class Features(MessageMeta): feature: dict[str, Feature] -class Feature(Message): +@final +class Feature(MessageMeta): float_list: FloatList int64_list: Int64List bytes_list: BytesList -class FloatList(Message): +@final +class FloatList(MessageMeta): value: list[float] -class Int64List(Message): +@final +class Int64List(MessageMeta): value: list[int] -class BytesList(Message): +@final +class BytesList(MessageMeta): value: list[bytes] -class ServerDef(Message): ... -class ClusterDef(Message): ... +@final +class ServerDef(MessageMeta): ... + +@final +class ClusterDef(MessageMeta): ... _T = TypeVar("_T", bound=list[str] | tuple[str] | dict[int, str]) @@ -57,11 +66,6 @@ class ClusterSpec: def as_dict(self) -> dict[str, list[str] | tuple[str] | dict[int, str]]: ... def num_tasks(self, job_name: str) -> int: ... -class CheckpointReader: - def get_variable_to_shape_map(self) -> dict[str, list[int]]: ... - def get_variable_to_dtype_map(self) -> dict[str, tf.DType]: ... - def get_tensor(self, name: str) -> np.ndarray[Any, Any] | Any: ... - class _CheckpointLoadStatus: def assert_consumed(self) -> Self: ... def assert_existing_objects_matched(self) -> Self: ... @@ -89,7 +93,6 @@ class CheckpointManager: def _sweep(self) -> None: ... def latest_checkpoint(checkpoint_dir: str, latest_filename: str | None = None) -> str: ... -def load_checkpoint(ckpt_dir_or_file: str) -> CheckpointReader: ... def load_variable(ckpt_dir_or_file: str, name: str) -> np.ndarray[Any, Any]: ... def list_variables(ckpt_dir_or_file: str) -> list[tuple[str, list[int]]]: ... def __getattr__(name: str) -> Incomplete: ... diff --git a/stubs/tensorflow/tensorflow/train/experimental.pyi b/stubs/tensorflow/tensorflow/train/experimental.pyi index 371daf819910..827d7eb5ec6f 100644 --- a/stubs/tensorflow/tensorflow/train/experimental.pyi +++ b/stubs/tensorflow/tensorflow/train/experimental.pyi @@ -1,30 +1,23 @@ -import abc -import dataclasses from _typeshed import Incomplete -from typing import Sequence -from tensorflow import Tensor, TensorShape -from tensorflow.dtypes import DType -from tensorflow.python.trackable.base import Trackable +# @dataclasses.dataclass(frozen=True) +# class ShardableTensor: +# _tensor_save_spec: Incomplete # saveable_object.SaveSpec +# tensor: Tensor +# dtype: DType +# device: Incomplete # device_lib.DeviceSpec +# name: str +# shape: TensorShape +# slice_spec: Incomplete # variables.Variable.SaveSliceInfo +# checkpoint_key: str +# trackable: Trackable +# def __hash__(self) -> int: ... +# def __repr__(self) -> str: ... -@dataclasses.dataclass(frozen=True) -class ShardableTensor: - _tensor_save_spec: Incomplete # saveable_object.SaveSpec - tensor: Tensor - dtype: DType - device: Incomplete # device_lib.DeviceSpec - name: str - shape: TensorShape - slice_spec: Incomplete # variables.Variable.SaveSliceInfo - checkpoint_key: str - trackable: Trackable - def __hash__(self) -> int: ... - def __repr__(self) -> str: ... - -class ShardingCallback(abc.ABC): - def description(self) -> str: ... - @abc.abstractmethod - def __call__(self, shardable_tensors: Sequence[ShardableTensor]) -> Sequence[Incomplete]: ... # Sequence[TensorSliceDict] - def __hash__(self) -> int: ... +# class ShardingCallback(abc.ABC): +# def description(self) -> str: ... +# @abc.abstractmethod +# def __call__(self, shardable_tensors: Sequence[ShardableTensor]) -> Sequence[Incomplete]: ... # Sequence[TensorSliceDict] +# def __hash__(self) -> int: ... def __getattr__(name: str) -> Incomplete: ... From bbaa5477556fe7ba4f3e9918bb5b5f48dbb507f0 Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Sat, 27 Jan 2024 22:41:13 +0900 Subject: [PATCH 06/21] MessageMeta -> Message --- stubs/tensorflow/tensorflow/train/__init__.pyi | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/stubs/tensorflow/tensorflow/train/__init__.pyi b/stubs/tensorflow/tensorflow/train/__init__.pyi index c2b522c869ec..547b6e9b8d07 100644 --- a/stubs/tensorflow/tensorflow/train/__init__.pyi +++ b/stubs/tensorflow/tensorflow/train/__init__.pyi @@ -5,7 +5,7 @@ from typing_extensions import Self import numpy as np import tensorflow as tf -from google.protobuf.message import MessageMeta +from google.protobuf.message import Message from tensorflow.python.trackable.base import Trackable class CheckpointOptions: @@ -28,36 +28,36 @@ class CheckpointOptions: def __copy__(self) -> Self: ... @final -class Example(MessageMeta): +class Example(Message): features: Features @final -class Features(MessageMeta): +class Features(Message): feature: dict[str, Feature] @final -class Feature(MessageMeta): +class Feature(Message): float_list: FloatList int64_list: Int64List bytes_list: BytesList @final -class FloatList(MessageMeta): +class FloatList(Message): value: list[float] @final -class Int64List(MessageMeta): +class Int64List(Message): value: list[int] @final -class BytesList(MessageMeta): +class BytesList(Message): value: list[bytes] @final -class ServerDef(MessageMeta): ... +class ServerDef(Message): ... @final -class ClusterDef(MessageMeta): ... +class ClusterDef(Message): ... _T = TypeVar("_T", bound=list[str] | tuple[str] | dict[int, str]) From 789b1150c06d44d1ee063880bf7523c1601ad098 Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Sat, 27 Jan 2024 22:52:01 +0900 Subject: [PATCH 07/21] fix CI errors --- stubs/tensorflow/tensorflow/train/__init__.pyi | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/stubs/tensorflow/tensorflow/train/__init__.pyi b/stubs/tensorflow/tensorflow/train/__init__.pyi index 547b6e9b8d07..42aa9faa7570 100644 --- a/stubs/tensorflow/tensorflow/train/__init__.pyi +++ b/stubs/tensorflow/tensorflow/train/__init__.pyi @@ -1,6 +1,6 @@ from _typeshed import Incomplete from collections.abc import Callable -from typing import Any, Callable, TypeVar, final +from typing import Any, TypeVar, final from typing_extensions import Self import numpy as np @@ -73,7 +73,7 @@ class _CheckpointLoadStatus: def expect_partial(self) -> Self: ... class Checkpoint: - def __init__(self, root: Trackable | None = None, **kwargs: Trackable): ... + def __init__(self, root: Trackable | None = None, **kwargs: Trackable) -> None: ... def write(self, file_prefix: str, options: CheckpointOptions | None = None) -> str: ... def read(self, file_prefix: str, options: CheckpointOptions | None = None) -> _CheckpointLoadStatus: ... def restore(self, file_prefix: str, options: CheckpointOptions | None = None) -> _CheckpointLoadStatus: ... @@ -89,7 +89,7 @@ class CheckpointManager: step_counter: tf.Variable | None = None, checkpoint_interval: int | None = None, init_fn: Callable[[], object] | None = None, - ): ... + ) -> None: ... def _sweep(self) -> None: ... def latest_checkpoint(checkpoint_dir: str, latest_filename: str | None = None) -> str: ... From 9d381e438cfb11214b1c526c4795e9248f8c5b42 Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Sat, 27 Jan 2024 22:57:28 +0900 Subject: [PATCH 08/21] Message -> MessageMeta --- stubs/tensorflow/tensorflow/train/__init__.pyi | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/stubs/tensorflow/tensorflow/train/__init__.pyi b/stubs/tensorflow/tensorflow/train/__init__.pyi index 42aa9faa7570..f2f8cfc2a9e9 100644 --- a/stubs/tensorflow/tensorflow/train/__init__.pyi +++ b/stubs/tensorflow/tensorflow/train/__init__.pyi @@ -5,7 +5,7 @@ from typing_extensions import Self import numpy as np import tensorflow as tf -from google.protobuf.message import Message +from google._upb._message import MessageMeta from tensorflow.python.trackable.base import Trackable class CheckpointOptions: @@ -28,36 +28,36 @@ class CheckpointOptions: def __copy__(self) -> Self: ... @final -class Example(Message): +class Example(MessageMeta): features: Features @final -class Features(Message): +class Features(MessageMeta): feature: dict[str, Feature] @final -class Feature(Message): +class Feature(MessageMeta): float_list: FloatList int64_list: Int64List bytes_list: BytesList @final -class FloatList(Message): +class FloatList(MessageMeta): value: list[float] @final -class Int64List(Message): +class Int64List(MessageMeta): value: list[int] @final -class BytesList(Message): +class BytesList(MessageMeta): value: list[bytes] @final -class ServerDef(Message): ... +class ServerDef(MessageMeta): ... @final -class ClusterDef(Message): ... +class ClusterDef(MessageMeta): ... _T = TypeVar("_T", bound=list[str] | tuple[str] | dict[int, str]) From 56e6bcf64941c6c95f3042ae2133ab5b3f9acf4e Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Sat, 27 Jan 2024 23:37:44 +0900 Subject: [PATCH 09/21] MessageMeta -> GeneratedProtocolMessageType --- stubs/tensorflow/tensorflow/train/__init__.pyi | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/stubs/tensorflow/tensorflow/train/__init__.pyi b/stubs/tensorflow/tensorflow/train/__init__.pyi index f2f8cfc2a9e9..96210058dffc 100644 --- a/stubs/tensorflow/tensorflow/train/__init__.pyi +++ b/stubs/tensorflow/tensorflow/train/__init__.pyi @@ -5,7 +5,7 @@ from typing_extensions import Self import numpy as np import tensorflow as tf -from google._upb._message import MessageMeta +from google.protobuf.pyext.cpp_message import GeneratedProtocolMessageType from tensorflow.python.trackable.base import Trackable class CheckpointOptions: @@ -28,36 +28,36 @@ class CheckpointOptions: def __copy__(self) -> Self: ... @final -class Example(MessageMeta): +class Example(GeneratedProtocolMessageType): features: Features @final -class Features(MessageMeta): +class Features(GeneratedProtocolMessageType): feature: dict[str, Feature] @final -class Feature(MessageMeta): +class Feature(GeneratedProtocolMessageType): float_list: FloatList int64_list: Int64List bytes_list: BytesList @final -class FloatList(MessageMeta): +class FloatList(GeneratedProtocolMessageType): value: list[float] @final -class Int64List(MessageMeta): +class Int64List(GeneratedProtocolMessageType): value: list[int] @final -class BytesList(MessageMeta): +class BytesList(GeneratedProtocolMessageType): value: list[bytes] @final -class ServerDef(MessageMeta): ... +class ServerDef(GeneratedProtocolMessageType): ... @final -class ClusterDef(MessageMeta): ... +class ClusterDef(GeneratedProtocolMessageType): ... _T = TypeVar("_T", bound=list[str] | tuple[str] | dict[int, str]) From 0f7548cdc1d5286b6ce63fff6a4a8a2145d86981 Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Sun, 28 Jan 2024 00:05:58 +0900 Subject: [PATCH 10/21] Add protobuff message type (incomplete). --- stubs/protobuf/google/protobuf/pyext/cpp_message.pyi | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 stubs/protobuf/google/protobuf/pyext/cpp_message.pyi diff --git a/stubs/protobuf/google/protobuf/pyext/cpp_message.pyi b/stubs/protobuf/google/protobuf/pyext/cpp_message.pyi new file mode 100644 index 000000000000..c85aa53a20b0 --- /dev/null +++ b/stubs/protobuf/google/protobuf/pyext/cpp_message.pyi @@ -0,0 +1,6 @@ +from _typeshed import Incomplete + +class GeneratedProtocolMessageType(Incomplete): ... + + +def __getattr__(name: str) -> Incomplete: ... From 450c020d081ecdd48c5589652a9e22ba05f2a14a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 27 Jan 2024 15:08:19 +0000 Subject: [PATCH 11/21] [pre-commit.ci] auto fixes from pre-commit.com hooks --- stubs/protobuf/google/protobuf/pyext/cpp_message.pyi | 1 - 1 file changed, 1 deletion(-) diff --git a/stubs/protobuf/google/protobuf/pyext/cpp_message.pyi b/stubs/protobuf/google/protobuf/pyext/cpp_message.pyi index c85aa53a20b0..af07cf3a4917 100644 --- a/stubs/protobuf/google/protobuf/pyext/cpp_message.pyi +++ b/stubs/protobuf/google/protobuf/pyext/cpp_message.pyi @@ -2,5 +2,4 @@ from _typeshed import Incomplete class GeneratedProtocolMessageType(Incomplete): ... - def __getattr__(name: str) -> Incomplete: ... From 72411155f40f1076fcf0e206bd0a63d5c1192e3c Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Sun, 28 Jan 2024 00:39:30 +0900 Subject: [PATCH 12/21] fix Checkpoint's methods --- stubs/tensorflow/tensorflow/train/__init__.pyi | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/stubs/tensorflow/tensorflow/train/__init__.pyi b/stubs/tensorflow/tensorflow/train/__init__.pyi index 96210058dffc..45652c8788a0 100644 --- a/stubs/tensorflow/tensorflow/train/__init__.pyi +++ b/stubs/tensorflow/tensorflow/train/__init__.pyi @@ -25,7 +25,7 @@ class CheckpointOptions: # experimental_skip_slot_variables: bool = False, # experimental_sharding_callback: tf.train.experimental.ShardingCallback | None = None, ) -> None: ... - def __copy__(self) -> Self: ... + # def __copy__(self) -> Self: ... @final class Example(GeneratedProtocolMessageType): @@ -74,9 +74,11 @@ class _CheckpointLoadStatus: class Checkpoint: def __init__(self, root: Trackable | None = None, **kwargs: Trackable) -> None: ... + def read(self, save_path: str, options: CheckpointOptions | None = None) -> _CheckpointLoadStatus: ... + def restore(self, save_path: str, options: CheckpointOptions | None = None) -> _CheckpointLoadStatus: ... + def save(self, file_prefix: str, options: CheckpointOptions | None = None) -> str: ... + def sync(self) -> None: ... def write(self, file_prefix: str, options: CheckpointOptions | None = None) -> str: ... - def read(self, file_prefix: str, options: CheckpointOptions | None = None) -> _CheckpointLoadStatus: ... - def restore(self, file_prefix: str, options: CheckpointOptions | None = None) -> _CheckpointLoadStatus: ... class CheckpointManager: def __init__( From bf6c853d173e78bb48335a946483e2ca74e92519 Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Sun, 28 Jan 2024 00:48:46 +0900 Subject: [PATCH 13/21] comment out method and arg not present at runtime. --- stubs/tensorflow/tensorflow/train/__init__.pyi | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/stubs/tensorflow/tensorflow/train/__init__.pyi b/stubs/tensorflow/tensorflow/train/__init__.pyi index 45652c8788a0..a73322b2f929 100644 --- a/stubs/tensorflow/tensorflow/train/__init__.pyi +++ b/stubs/tensorflow/tensorflow/train/__init__.pyi @@ -11,7 +11,7 @@ from tensorflow.python.trackable.base import Trackable class CheckpointOptions: experimental_io_device: None | str experimental_enable_async_checkpoint: bool - experimental_write_callbacks: None | list[Callable[[str], Any] | Callable[[], Any]] + # experimental_write_callbacks: None | list[Callable[[str], Any] | Callable[[], Any]] enable_async: bool # experimental_skip_slot_variables: bool # experimental_sharding_callback: tf.train.experimental.ShardingCallback | None = None @@ -20,7 +20,7 @@ class CheckpointOptions: self, experimental_io_device: None | str = None, experimental_enable_async_checkpoint: bool = False, - experimental_write_callbacks: None | list[Callable[[str], Any] | Callable[[], Any]] = None, + # experimental_write_callbacks: None | list[Callable[[str], Any] | Callable[[], Any]] = None, enable_async: bool = False, # experimental_skip_slot_variables: bool = False, # experimental_sharding_callback: tf.train.experimental.ShardingCallback | None = None, @@ -77,7 +77,7 @@ class Checkpoint: def read(self, save_path: str, options: CheckpointOptions | None = None) -> _CheckpointLoadStatus: ... def restore(self, save_path: str, options: CheckpointOptions | None = None) -> _CheckpointLoadStatus: ... def save(self, file_prefix: str, options: CheckpointOptions | None = None) -> str: ... - def sync(self) -> None: ... + # def sync(self) -> None: ... def write(self, file_prefix: str, options: CheckpointOptions | None = None) -> str: ... class CheckpointManager: From 66fcf637f9b0f26d39c9bd54083eb8b1dfc7b75d Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Sun, 28 Jan 2024 11:30:38 +0900 Subject: [PATCH 14/21] Move messages to core/example --- .../google/protobuf/pyext/cpp_message.pyi | 5 --- .../tensorflow/core/example/example_pb2.pyi | 29 ++++++++++++ .../tensorflow/tensorflow/train/__init__.pyi | 44 +++++-------------- 3 files changed, 39 insertions(+), 39 deletions(-) delete mode 100644 stubs/protobuf/google/protobuf/pyext/cpp_message.pyi diff --git a/stubs/protobuf/google/protobuf/pyext/cpp_message.pyi b/stubs/protobuf/google/protobuf/pyext/cpp_message.pyi deleted file mode 100644 index af07cf3a4917..000000000000 --- a/stubs/protobuf/google/protobuf/pyext/cpp_message.pyi +++ /dev/null @@ -1,5 +0,0 @@ -from _typeshed import Incomplete - -class GeneratedProtocolMessageType(Incomplete): ... - -def __getattr__(name: str) -> Incomplete: ... diff --git a/stubs/tensorflow/tensorflow/core/example/example_pb2.pyi b/stubs/tensorflow/tensorflow/core/example/example_pb2.pyi index cd4bcd511a7e..caa641d8efe0 100644 --- a/stubs/tensorflow/tensorflow/core/example/example_pb2.pyi +++ b/stubs/tensorflow/tensorflow/core/example/example_pb2.pyi @@ -335,4 +335,33 @@ class SequenceExample(google.protobuf.message.Message): def HasField(self, field_name: typing_extensions.Literal["context", b"context", "feature_lists", b"feature_lists"]) -> builtins.bool: ... def ClearField(self, field_name: typing_extensions.Literal["context", b"context", "feature_lists", b"feature_lists"]) -> None: ... + +@typing_extensions.final +class Features(google.protobuf.message.Message): + feature: dict[str, Feature] + +@typing_extensions.final +class Feature(google.protobuf.message.Message): + float_list: FloatList + int64_list: Int64List + bytes_list: BytesList + +@typing_extensions.final +class FloatList(google.protobuf.message.Message): + value: list[float] + +@typing_extensions.final +class Int64List(google.protobuf.message.Message): + value: list[int] + +@typing_extensions.final +class BytesList(google.protobuf.message.Message): + value: list[bytes] + +@typing_extensions.final +class ServerDef(google.protobuf.message.Message): ... + +@typing_extensions.final +class ClusterDef(google.protobuf.message.Message): ... + global___SequenceExample = SequenceExample diff --git a/stubs/tensorflow/tensorflow/train/__init__.pyi b/stubs/tensorflow/tensorflow/train/__init__.pyi index a73322b2f929..753ffccc437b 100644 --- a/stubs/tensorflow/tensorflow/train/__init__.pyi +++ b/stubs/tensorflow/tensorflow/train/__init__.pyi @@ -1,12 +1,20 @@ from _typeshed import Incomplete from collections.abc import Callable -from typing import Any, TypeVar, final +from typing import Any, TypeVar from typing_extensions import Self import numpy as np import tensorflow as tf -from google.protobuf.pyext.cpp_message import GeneratedProtocolMessageType from tensorflow.python.trackable.base import Trackable +from tensorflow.core.example.example_pb2 import( + Features as Features, + Feature as Feature, + FloatList as FloatList, + Int64List as Int64List, + BytesList as BytesList, + ServerDef as ServerDef, + ClusterDef as ClusterDef +) class CheckpointOptions: experimental_io_device: None | str @@ -27,38 +35,6 @@ class CheckpointOptions: ) -> None: ... # def __copy__(self) -> Self: ... -@final -class Example(GeneratedProtocolMessageType): - features: Features - -@final -class Features(GeneratedProtocolMessageType): - feature: dict[str, Feature] - -@final -class Feature(GeneratedProtocolMessageType): - float_list: FloatList - int64_list: Int64List - bytes_list: BytesList - -@final -class FloatList(GeneratedProtocolMessageType): - value: list[float] - -@final -class Int64List(GeneratedProtocolMessageType): - value: list[int] - -@final -class BytesList(GeneratedProtocolMessageType): - value: list[bytes] - -@final -class ServerDef(GeneratedProtocolMessageType): ... - -@final -class ClusterDef(GeneratedProtocolMessageType): ... - _T = TypeVar("_T", bound=list[str] | tuple[str] | dict[int, str]) class ClusterSpec: From 8b1ba15efd0e775ad0150c7fe3e9ec62d0ac87f6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 28 Jan 2024 02:31:21 +0000 Subject: [PATCH 15/21] [pre-commit.ci] auto fixes from pre-commit.com hooks --- stubs/tensorflow/tensorflow/train/__init__.pyi | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/stubs/tensorflow/tensorflow/train/__init__.pyi b/stubs/tensorflow/tensorflow/train/__init__.pyi index 753ffccc437b..5e68e2de13f5 100644 --- a/stubs/tensorflow/tensorflow/train/__init__.pyi +++ b/stubs/tensorflow/tensorflow/train/__init__.pyi @@ -5,16 +5,16 @@ from typing_extensions import Self import numpy as np import tensorflow as tf -from tensorflow.python.trackable.base import Trackable -from tensorflow.core.example.example_pb2 import( - Features as Features, +from tensorflow.core.example.example_pb2 import ( + BytesList as BytesList, + ClusterDef as ClusterDef, Feature as Feature, + Features as Features, FloatList as FloatList, Int64List as Int64List, - BytesList as BytesList, ServerDef as ServerDef, - ClusterDef as ClusterDef ) +from tensorflow.python.trackable.base import Trackable class CheckpointOptions: experimental_io_device: None | str From 6783a3e2ef6c3b0e87198faa4af9c6ed5ee1f11c Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Sun, 28 Jan 2024 11:39:45 +0900 Subject: [PATCH 16/21] Fix message import origin. --- .../tensorflow/core/example/example_pb2.pyi | 29 ------------------- .../tensorflow/tensorflow/train/__init__.pyi | 6 ++-- 2 files changed, 3 insertions(+), 32 deletions(-) diff --git a/stubs/tensorflow/tensorflow/core/example/example_pb2.pyi b/stubs/tensorflow/tensorflow/core/example/example_pb2.pyi index caa641d8efe0..cd4bcd511a7e 100644 --- a/stubs/tensorflow/tensorflow/core/example/example_pb2.pyi +++ b/stubs/tensorflow/tensorflow/core/example/example_pb2.pyi @@ -335,33 +335,4 @@ class SequenceExample(google.protobuf.message.Message): def HasField(self, field_name: typing_extensions.Literal["context", b"context", "feature_lists", b"feature_lists"]) -> builtins.bool: ... def ClearField(self, field_name: typing_extensions.Literal["context", b"context", "feature_lists", b"feature_lists"]) -> None: ... - -@typing_extensions.final -class Features(google.protobuf.message.Message): - feature: dict[str, Feature] - -@typing_extensions.final -class Feature(google.protobuf.message.Message): - float_list: FloatList - int64_list: Int64List - bytes_list: BytesList - -@typing_extensions.final -class FloatList(google.protobuf.message.Message): - value: list[float] - -@typing_extensions.final -class Int64List(google.protobuf.message.Message): - value: list[int] - -@typing_extensions.final -class BytesList(google.protobuf.message.Message): - value: list[bytes] - -@typing_extensions.final -class ServerDef(google.protobuf.message.Message): ... - -@typing_extensions.final -class ClusterDef(google.protobuf.message.Message): ... - global___SequenceExample = SequenceExample diff --git a/stubs/tensorflow/tensorflow/train/__init__.pyi b/stubs/tensorflow/tensorflow/train/__init__.pyi index 5e68e2de13f5..6e8f2fab0400 100644 --- a/stubs/tensorflow/tensorflow/train/__init__.pyi +++ b/stubs/tensorflow/tensorflow/train/__init__.pyi @@ -5,15 +5,15 @@ from typing_extensions import Self import numpy as np import tensorflow as tf -from tensorflow.core.example.example_pb2 import ( +from tensorflow.core.example.feature_pb2 import ( BytesList as BytesList, - ClusterDef as ClusterDef, Feature as Feature, Features as Features, FloatList as FloatList, Int64List as Int64List, - ServerDef as ServerDef, ) +from tensorflow.core.protobuf.cluster_pb2 import ClusterDef as ClusterDef +from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef as ServerDef from tensorflow.python.trackable.base import Trackable class CheckpointOptions: From 8d6965637403108db6fb1885b64d90d35ebc7621 Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Sun, 28 Jan 2024 11:45:04 +0900 Subject: [PATCH 17/21] Re-add Example --- stubs/tensorflow/tensorflow/train/__init__.pyi | 1 + 1 file changed, 1 insertion(+) diff --git a/stubs/tensorflow/tensorflow/train/__init__.pyi b/stubs/tensorflow/tensorflow/train/__init__.pyi index 6e8f2fab0400..76df53743ec5 100644 --- a/stubs/tensorflow/tensorflow/train/__init__.pyi +++ b/stubs/tensorflow/tensorflow/train/__init__.pyi @@ -5,6 +5,7 @@ from typing_extensions import Self import numpy as np import tensorflow as tf +from tensorflow.core.example.example_pb2 import Example as Example from tensorflow.core.example.feature_pb2 import ( BytesList as BytesList, Feature as Feature, From 794e76935ddf3b5bb6957b99095f432e0706aca4 Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Sun, 28 Jan 2024 17:20:00 +0900 Subject: [PATCH 18/21] fix: disable stubtest errors for protobuf types. See: https://github.com/python/typeshed/pull/11327#discussion_r1468737332 --- stubs/tensorflow/@tests/stubtest_allowlist.txt | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/stubs/tensorflow/@tests/stubtest_allowlist.txt b/stubs/tensorflow/@tests/stubtest_allowlist.txt index 5700e283375f..54f489ac7baa 100644 --- a/stubs/tensorflow/@tests/stubtest_allowlist.txt +++ b/stubs/tensorflow/@tests/stubtest_allowlist.txt @@ -85,3 +85,12 @@ tensorflow.io.SparseFeature.__new__ # Metaclass inconsistency. The runtime metaclass is defined from c++ extension and is undocumented. tensorflow.io.TFRecordWriter + +# stubtest does not pass for protobuf generated stubs. +tensorflow.train.Example.* +tensorflow.train.BytesList.* +tensorflow.train.Feature.* +tensorflow.train.FloatList.* +tensorflow.train.Int64List.* +tensorflow.train.ClusterDef.* +tensorflow.train.ServerDef.* From 080b1558d13f7f6a8b30a2522a3b5b31088fbe0f Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Thu, 1 Feb 2024 08:25:08 +0900 Subject: [PATCH 19/21] Remove ShardableTensor and ShardingCallback --- .../tensorflow/train/experimental.pyi | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/stubs/tensorflow/tensorflow/train/experimental.pyi b/stubs/tensorflow/tensorflow/train/experimental.pyi index 827d7eb5ec6f..0f6820f054ea 100644 --- a/stubs/tensorflow/tensorflow/train/experimental.pyi +++ b/stubs/tensorflow/tensorflow/train/experimental.pyi @@ -1,23 +1,3 @@ from _typeshed import Incomplete -# @dataclasses.dataclass(frozen=True) -# class ShardableTensor: -# _tensor_save_spec: Incomplete # saveable_object.SaveSpec -# tensor: Tensor -# dtype: DType -# device: Incomplete # device_lib.DeviceSpec -# name: str -# shape: TensorShape -# slice_spec: Incomplete # variables.Variable.SaveSliceInfo -# checkpoint_key: str -# trackable: Trackable -# def __hash__(self) -> int: ... -# def __repr__(self) -> str: ... - -# class ShardingCallback(abc.ABC): -# def description(self) -> str: ... -# @abc.abstractmethod -# def __call__(self, shardable_tensors: Sequence[ShardableTensor]) -> Sequence[Incomplete]: ... # Sequence[TensorSliceDict] -# def __hash__(self) -> int: ... - def __getattr__(name: str) -> Incomplete: ... From 55321571ac7ea222076919e635774f11b9108fce Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Thu, 1 Feb 2024 08:25:29 +0900 Subject: [PATCH 20/21] Add PythonState --- stubs/tensorflow/tensorflow/train/experimental.pyi | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/stubs/tensorflow/tensorflow/train/experimental.pyi b/stubs/tensorflow/tensorflow/train/experimental.pyi index 0f6820f054ea..c7f3c0ded5ec 100644 --- a/stubs/tensorflow/tensorflow/train/experimental.pyi +++ b/stubs/tensorflow/tensorflow/train/experimental.pyi @@ -1,3 +1,13 @@ +import abc from _typeshed import Incomplete +from typing_extensions import Self + +from tensorflow.python.trackable.base import Trackable + +class PythonState(Trackable, metaclass=abc.ABCMeta): + @abc.abstractmethod + def serialize(self) -> str: ... + @abc.abstractmethod + def deserialize(self, string_value: str) -> Self: ... def __getattr__(name: str) -> Incomplete: ... From 2ad4be9a27d8e2027d72153b09dde9bc9a8bba38 Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Thu, 1 Feb 2024 23:59:08 +0900 Subject: [PATCH 21/21] remove commented out code, or add comment to explain its presence. --- stubs/tensorflow/tensorflow/train/__init__.pyi | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/stubs/tensorflow/tensorflow/train/__init__.pyi b/stubs/tensorflow/tensorflow/train/__init__.pyi index 76df53743ec5..1f42cd14a1a6 100644 --- a/stubs/tensorflow/tensorflow/train/__init__.pyi +++ b/stubs/tensorflow/tensorflow/train/__init__.pyi @@ -20,21 +20,18 @@ from tensorflow.python.trackable.base import Trackable class CheckpointOptions: experimental_io_device: None | str experimental_enable_async_checkpoint: bool + # Uncomment when the stubs' TF version is updated to 2.15 # experimental_write_callbacks: None | list[Callable[[str], Any] | Callable[[], Any]] enable_async: bool - # experimental_skip_slot_variables: bool - # experimental_sharding_callback: tf.train.experimental.ShardingCallback | None = None def __init__( self, experimental_io_device: None | str = None, experimental_enable_async_checkpoint: bool = False, + # Uncomment when the stubs' TF version is updated to 2.15 # experimental_write_callbacks: None | list[Callable[[str], Any] | Callable[[], Any]] = None, enable_async: bool = False, - # experimental_skip_slot_variables: bool = False, - # experimental_sharding_callback: tf.train.experimental.ShardingCallback | None = None, ) -> None: ... - # def __copy__(self) -> Self: ... _T = TypeVar("_T", bound=list[str] | tuple[str] | dict[int, str]) @@ -54,7 +51,7 @@ class Checkpoint: def read(self, save_path: str, options: CheckpointOptions | None = None) -> _CheckpointLoadStatus: ... def restore(self, save_path: str, options: CheckpointOptions | None = None) -> _CheckpointLoadStatus: ... def save(self, file_prefix: str, options: CheckpointOptions | None = None) -> str: ... - # def sync(self) -> None: ... + # def sync(self) -> None: ... # Uncomment when the stubs' TF version is updated to 2.15 def write(self, file_prefix: str, options: CheckpointOptions | None = None) -> str: ... class CheckpointManager: