From b50a94a85fb571e40e834b981beec6bcdb0364de Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Sun, 18 Feb 2024 21:07:03 +0900 Subject: [PATCH 1/4] move GradientTape to autodiff, add ForwardAccumulator --- stubs/tensorflow/tensorflow/__init__.pyi | 60 ++--------------------- stubs/tensorflow/tensorflow/autodiff.pyi | 61 ++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 57 deletions(-) create mode 100644 stubs/tensorflow/tensorflow/autodiff.pyi diff --git a/stubs/tensorflow/tensorflow/__init__.pyi b/stubs/tensorflow/tensorflow/__init__.pyi index 7e70bc61ab17..308690fe7556 100644 --- a/stubs/tensorflow/tensorflow/__init__.pyi +++ b/stubs/tensorflow/tensorflow/__init__.pyi @@ -1,7 +1,7 @@ from _typeshed import Incomplete, Unused from abc import ABC, ABCMeta, abstractmethod from builtins import bool as _bool -from collections.abc import Callable, Generator, Iterable, Iterator, Mapping, Sequence +from collections.abc import Callable, Iterable, Iterator, Sequence from contextlib import contextmanager from enum import Enum from types import TracebackType @@ -18,18 +18,8 @@ from tensorflow import ( keras as keras, math as math, ) -from tensorflow._aliases import ( - AnyArray, - ContainerGradients, - ContainerTensors, - ContainerTensorsLike, - DTypeLike, - Gradients, - ShapeLike, - Slice, - TensorCompatible, - TensorLike, -) +from tensorflow._aliases import AnyArray, DTypeLike, ShapeLike, Slice, TensorCompatible +from tensorflow.autodiff import GradientTape as GradientTape from tensorflow.core.protobuf import struct_pb2 # Explicit import of DType is covered by the wildcard, but @@ -302,50 +292,6 @@ class UnconnectedGradients(Enum): NONE = "none" ZERO = "zero" -class GradientTape: - def __init__(self, persistent: _bool = False, watch_accessed_variables: _bool = True) -> None: ... - def __enter__(self) -> Self: ... - def __exit__(self, typ: type[BaseException] | None, value: BaseException | None, traceback: TracebackType | None) -> None: ... - # Higher kinded types would be nice here and these overloads are a way to simulate some of them. - @overload - def gradient( - self, - target: ContainerTensors, - sources: TensorLike, - output_gradients: list[Tensor] | None = None, - unconnected_gradients: UnconnectedGradients = ..., - ) -> Gradients: ... - @overload - def gradient( - self, - target: ContainerTensors, - sources: Sequence[Tensor], - output_gradients: list[Tensor] | None = None, - unconnected_gradients: UnconnectedGradients = ..., - ) -> list[Gradients]: ... - @overload - def gradient( - self, - target: ContainerTensors, - sources: Mapping[str, Tensor], - output_gradients: list[Tensor] | None = None, - unconnected_gradients: UnconnectedGradients = ..., - ) -> dict[str, Gradients]: ... - @overload - def gradient( - self, - target: ContainerTensors, - sources: ContainerTensors, - output_gradients: list[Tensor] | None = None, - unconnected_gradients: UnconnectedGradients = ..., - ) -> ContainerGradients: ... - @contextmanager - def stop_recording(self) -> Generator[None, None, None]: ... - def reset(self) -> None: ... - def watch(self, tensor: ContainerTensorsLike) -> None: ... - def watched_variables(self) -> tuple[Variable, ...]: ... - def __getattr__(self, name: str) -> Incomplete: ... - _SpecProto = TypeVar("_SpecProto", bound=Message) class TypeSpec(ABC, Generic[_SpecProto]): diff --git a/stubs/tensorflow/tensorflow/autodiff.pyi b/stubs/tensorflow/tensorflow/autodiff.pyi new file mode 100644 index 000000000000..5a920a19f3c1 --- /dev/null +++ b/stubs/tensorflow/tensorflow/autodiff.pyi @@ -0,0 +1,61 @@ +from _typeshed import Incomplete +from builtins import bool as _bool +from collections.abc import Generator, Mapping, Sequence +from contextlib import contextmanager +from types import TracebackType +from typing import overload +from typing_extensions import Self + +import tensorflow as tf +from tensorflow import Tensor, UnconnectedGradients, Variable +from tensorflow._aliases import ContainerGradients, ContainerTensors, ContainerTensorsLike, Gradients, TensorLike + +class ForwardAccumulator: + def __init__(self, primals: Tensor, tangents: Tensor) -> None: ... + def jvp( + self, primals: Tensor, unconnected_gradients: tf.UnconnectedGradients = tf.UnconnectedGradients.NONE # noqa: Y011 + ): ... + +class GradientTape: + def __init__(self, persistent: _bool = False, watch_accessed_variables: _bool = True) -> None: ... + def __enter__(self) -> Self: ... + def __exit__(self, typ: type[BaseException] | None, value: BaseException | None, traceback: TracebackType | None) -> None: ... + # Higher kinded types would be nice here and these overloads are a way to simulate some of them. + @overload + def gradient( + self, + target: ContainerTensors, + sources: TensorLike, + output_gradients: list[Tensor] | None = None, + unconnected_gradients: UnconnectedGradients = ..., + ) -> Gradients: ... + @overload + def gradient( + self, + target: ContainerTensors, + sources: Sequence[Tensor], + output_gradients: list[Tensor] | None = None, + unconnected_gradients: UnconnectedGradients = ..., + ) -> list[Gradients]: ... + @overload + def gradient( + self, + target: ContainerTensors, + sources: Mapping[str, Tensor], + output_gradients: list[Tensor] | None = None, + unconnected_gradients: UnconnectedGradients = ..., + ) -> dict[str, Gradients]: ... + @overload + def gradient( + self, + target: ContainerTensors, + sources: ContainerTensors, + output_gradients: list[Tensor] | None = None, + unconnected_gradients: UnconnectedGradients = ..., + ) -> ContainerGradients: ... + @contextmanager + def stop_recording(self) -> Generator[None, None, None]: ... + def reset(self) -> None: ... + def watch(self, tensor: ContainerTensorsLike) -> None: ... + def watched_variables(self) -> tuple[Variable, ...]: ... + def __getattr__(self, name: str) -> Incomplete: ... From b1572a92ce7b67f0558f1048a6e170f7dcd515f3 Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Sun, 18 Feb 2024 21:48:14 +0900 Subject: [PATCH 2/4] add tensorflow.autodiff.GradientTape.__getattr__ to allowlist --- stubs/tensorflow/@tests/stubtest_allowlist.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/stubs/tensorflow/@tests/stubtest_allowlist.txt b/stubs/tensorflow/@tests/stubtest_allowlist.txt index 57cd5b91a724..224ec46db4e0 100644 --- a/stubs/tensorflow/@tests/stubtest_allowlist.txt +++ b/stubs/tensorflow/@tests/stubtest_allowlist.txt @@ -20,6 +20,7 @@ tensorflow.python.feature_column.feature_column_v2.SharedEmbeddingColumnCreator. tensorflow.GradientTape.__getattr__ tensorflow.data.Dataset.__getattr__ tensorflow.experimental.Optional.__getattr__ +tensorflow.autodiff.GradientTape.__getattr__ # The Tensor methods below were removed in 2.14, however they are still defined for the # internal subclasses that are used at runtime/in practice. From b56bd08775613d026925a499bc4d3bd3b285049c Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Sun, 18 Feb 2024 21:49:19 +0900 Subject: [PATCH 3/4] add missing return type --- stubs/tensorflow/tensorflow/autodiff.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stubs/tensorflow/tensorflow/autodiff.pyi b/stubs/tensorflow/tensorflow/autodiff.pyi index 5a920a19f3c1..802fb3ea42c2 100644 --- a/stubs/tensorflow/tensorflow/autodiff.pyi +++ b/stubs/tensorflow/tensorflow/autodiff.pyi @@ -14,7 +14,7 @@ class ForwardAccumulator: def __init__(self, primals: Tensor, tangents: Tensor) -> None: ... def jvp( self, primals: Tensor, unconnected_gradients: tf.UnconnectedGradients = tf.UnconnectedGradients.NONE # noqa: Y011 - ): ... + ) -> Tensor | None: ... class GradientTape: def __init__(self, persistent: _bool = False, watch_accessed_variables: _bool = True) -> None: ... From f3f3cc9b0d9193a82ef4f5c4d9a3c600c586fe07 Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Thu, 29 Feb 2024 18:47:23 +0900 Subject: [PATCH 4/4] fix: add __enter__ and __exit__ to ForwardAccumulator --- stubs/tensorflow/tensorflow/autodiff.pyi | 2 ++ 1 file changed, 2 insertions(+) diff --git a/stubs/tensorflow/tensorflow/autodiff.pyi b/stubs/tensorflow/tensorflow/autodiff.pyi index 802fb3ea42c2..61077d6f1309 100644 --- a/stubs/tensorflow/tensorflow/autodiff.pyi +++ b/stubs/tensorflow/tensorflow/autodiff.pyi @@ -15,6 +15,8 @@ class ForwardAccumulator: def jvp( self, primals: Tensor, unconnected_gradients: tf.UnconnectedGradients = tf.UnconnectedGradients.NONE # noqa: Y011 ) -> Tensor | None: ... + def __enter__(self) -> Self: ... + def __exit__(self, typ: type[BaseException] | None, value: BaseException | None, traceback: TracebackType | None) -> None: ... class GradientTape: def __init__(self, persistent: _bool = False, watch_accessed_variables: _bool = True) -> None: ...