tensorflow: add tensorflow.autodiff #11442

Merged: 4 commits, Feb 29, 2024
1 change: 1 addition & 0 deletions stubs/tensorflow/@tests/stubtest_allowlist.txt
@@ -20,6 +20,7 @@ tensorflow.python.feature_column.feature_column_v2.SharedEmbeddingColumnCreator.
tensorflow.GradientTape.__getattr__
tensorflow.data.Dataset.__getattr__
tensorflow.experimental.Optional.__getattr__
+tensorflow.autodiff.GradientTape.__getattr__

# The Tensor methods below were removed in 2.14, however they are still defined for the
# internal subclasses that are used at runtime/in practice.
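Why the new allowlist entry: the stub gives tensorflow.autodiff.GradientTape a catch-all __getattr__ to mark it as intentionally incomplete, but the runtime class does not appear to define that method, so stubtest would report the stub-only attribute unless it is allowlisted. A minimal sketch of the runtime fact this suppresses, assuming a standard TensorFlow 2.x install:

# Hypothetical check, not part of the PR: the runtime class defines no
# __getattr__ of its own, which is what stubtest compares the stub against.
import tensorflow as tf

print("__getattr__" in vars(tf.autodiff.GradientTape))  # expected: False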
60 changes: 3 additions & 57 deletions stubs/tensorflow/tensorflow/__init__.pyi
@@ -1,7 +1,7 @@
from _typeshed import Incomplete, Unused
from abc import ABC, ABCMeta, abstractmethod
from builtins import bool as _bool
-from collections.abc import Callable, Generator, Iterable, Iterator, Mapping, Sequence
+from collections.abc import Callable, Iterable, Iterator, Sequence
from contextlib import contextmanager
from enum import Enum
from types import TracebackType
@@ -18,18 +18,8 @@ from tensorflow import (
    keras as keras,
    math as math,
)
-from tensorflow._aliases import (
-    AnyArray,
-    ContainerGradients,
-    ContainerTensors,
-    ContainerTensorsLike,
-    DTypeLike,
-    Gradients,
-    ShapeLike,
-    Slice,
-    TensorCompatible,
-    TensorLike,
-)
+from tensorflow._aliases import AnyArray, DTypeLike, ShapeLike, Slice, TensorCompatible
+from tensorflow.autodiff import GradientTape as GradientTape
from tensorflow.core.protobuf import struct_pb2

# Explicit import of DType is covered by the wildcard, but
@@ -302,50 +292,6 @@ class UnconnectedGradients(Enum):
    NONE = "none"
    ZERO = "zero"

-class GradientTape:
-    def __init__(self, persistent: _bool = False, watch_accessed_variables: _bool = True) -> None: ...
-    def __enter__(self) -> Self: ...
-    def __exit__(self, typ: type[BaseException] | None, value: BaseException | None, traceback: TracebackType | None) -> None: ...
-    # Higher kinded types would be nice here and these overloads are a way to simulate some of them.
-    @overload
-    def gradient(
-        self,
-        target: ContainerTensors,
-        sources: TensorLike,
-        output_gradients: list[Tensor] | None = None,
-        unconnected_gradients: UnconnectedGradients = ...,
-    ) -> Gradients: ...
-    @overload
-    def gradient(
-        self,
-        target: ContainerTensors,
-        sources: Sequence[Tensor],
-        output_gradients: list[Tensor] | None = None,
-        unconnected_gradients: UnconnectedGradients = ...,
-    ) -> list[Gradients]: ...
-    @overload
-    def gradient(
-        self,
-        target: ContainerTensors,
-        sources: Mapping[str, Tensor],
-        output_gradients: list[Tensor] | None = None,
-        unconnected_gradients: UnconnectedGradients = ...,
-    ) -> dict[str, Gradients]: ...
-    @overload
-    def gradient(
-        self,
-        target: ContainerTensors,
-        sources: ContainerTensors,
-        output_gradients: list[Tensor] | None = None,
-        unconnected_gradients: UnconnectedGradients = ...,
-    ) -> ContainerGradients: ...
-    @contextmanager
-    def stop_recording(self) -> Generator[None, None, None]: ...
-    def reset(self) -> None: ...
-    def watch(self, tensor: ContainerTensorsLike) -> None: ...
-    def watched_variables(self) -> tuple[Variable, ...]: ...
-    def __getattr__(self, name: str) -> Incomplete: ...
-
_SpecProto = TypeVar("_SpecProto", bound=Message)

class TypeSpec(ABC, Generic[_SpecProto]):
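The deleted class moves verbatim into the new tensorflow/autodiff.pyi below, and the re-export (from tensorflow.autodiff import GradientTape as GradientTape) keeps tf.GradientTape usable from the top-level namespace. A small usage sketch, assuming a standard TensorFlow 2.x install, showing that both names should resolve to the same runtime class:

# Hedged check plus a typical use of the re-exported name; the identity
# test reflects how TensorFlow aliases the class at runtime.
import tensorflow as tf

print(tf.GradientTape is tf.autodiff.GradientTape)  # expected: True

x = tf.Variable(3.0)
with tf.GradientTape() as tape:
    y = x * x
print(tape.gradient(y, x))  # expected: tf.Tensor(6.0, shape=(), dtype=float32)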
63 changes: 63 additions & 0 deletions stubs/tensorflow/tensorflow/autodiff.pyi
@@ -0,0 +1,63 @@
from _typeshed import Incomplete
from builtins import bool as _bool
from collections.abc import Generator, Mapping, Sequence
from contextlib import contextmanager
from types import TracebackType
from typing import overload
from typing_extensions import Self

import tensorflow as tf
from tensorflow import Tensor, UnconnectedGradients, Variable
from tensorflow._aliases import ContainerGradients, ContainerTensors, ContainerTensorsLike, Gradients, TensorLike

class ForwardAccumulator:
    def __init__(self, primals: Tensor, tangents: Tensor) -> None: ...
    def jvp(
        self, primals: Tensor, unconnected_gradients: tf.UnconnectedGradients = tf.UnconnectedGradients.NONE  # noqa: Y011
    ) -> Tensor | None: ...
    def __enter__(self) -> Self: ...
    def __exit__(self, typ: type[BaseException] | None, value: BaseException | None, traceback: TracebackType | None) -> None: ...

class GradientTape:
    def __init__(self, persistent: _bool = False, watch_accessed_variables: _bool = True) -> None: ...
    def __enter__(self) -> Self: ...
    def __exit__(self, typ: type[BaseException] | None, value: BaseException | None, traceback: TracebackType | None) -> None: ...
    # Higher kinded types would be nice here and these overloads are a way to simulate some of them.
    @overload
    def gradient(
        self,
        target: ContainerTensors,
        sources: TensorLike,
        output_gradients: list[Tensor] | None = None,
        unconnected_gradients: UnconnectedGradients = ...,
    ) -> Gradients: ...
    @overload
    def gradient(
        self,
        target: ContainerTensors,
        sources: Sequence[Tensor],
        output_gradients: list[Tensor] | None = None,
        unconnected_gradients: UnconnectedGradients = ...,
    ) -> list[Gradients]: ...
    @overload
    def gradient(
        self,
        target: ContainerTensors,
        sources: Mapping[str, Tensor],
        output_gradients: list[Tensor] | None = None,
        unconnected_gradients: UnconnectedGradients = ...,
    ) -> dict[str, Gradients]: ...
    @overload
    def gradient(
        self,
        target: ContainerTensors,
        sources: ContainerTensors,
        output_gradients: list[Tensor] | None = None,
        unconnected_gradients: UnconnectedGradients = ...,
    ) -> ContainerGradients: ...
    @contextmanager
    def stop_recording(self) -> Generator[None, None, None]: ...
    def reset(self) -> None: ...
    def watch(self, tensor: ContainerTensorsLike) -> None: ...
    def watched_variables(self) -> tuple[Variable, ...]: ...
    def __getattr__(self, name: str) -> Incomplete: ...
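The gradient() overloads encode how the return container mirrors the sources container, which is what the comment about higher-kinded types refers to. A short usage sketch, assuming a standard TensorFlow 2.x install (persistent=True is needed because gradient() is called more than once on the same tape):

# Each call matches a different overload: single source -> single gradient,
# sequence -> list of gradients, mapping -> dict of gradients.
import tensorflow as tf

x = tf.Variable(2.0)
y = tf.Variable(5.0)

with tf.GradientTape(persistent=True) as tape:
    loss = x * x + y

print(tape.gradient(loss, x))                  # a single Tensor
print(tape.gradient(loss, [x, y]))             # a list of Tensors
print(tape.gradient(loss, {"x": x, "y": y}))   # a dict of Tensors

# ForwardAccumulator is forward-mode autodiff: jvp() returns the
# Jacobian-vector product of a result computed under the accumulator.
with tf.autodiff.ForwardAccumulator(primals=x, tangents=tf.constant(1.0)) as acc:
    z = x * x
print(acc.jvp(z))  # expected: tf.Tensor(4.0, ...) since d(x*x)/dx = 4 at x = 2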