Skip to content

Commit

Permalink
tensorflow: add tf.linalg module (#11386)
Browse files Browse the repository at this point in the history
  • Loading branch information
hoel-bagard authored Feb 17, 2024
1 parent 2e85a70 commit 955cdf5
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 0 deletions.
1 change: 1 addition & 0 deletions stubs/tensorflow/tensorflow/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ from tensorflow.dtypes import *
from tensorflow.dtypes import DType as DType
from tensorflow.experimental.dtensor import Layout
from tensorflow.keras import losses as losses
from tensorflow.linalg import eye as eye

# Most tf.math functions are exported as tf, but sadly not all are.
from tensorflow.math import (
Expand Down
1 change: 1 addition & 0 deletions stubs/tensorflow/tensorflow/_aliases.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class KerasSerializable2(Protocol):

KerasSerializable: TypeAlias = KerasSerializable1 | KerasSerializable2

Integer: TypeAlias = tf.Tensor | int | IntArray | np.number[Any] # Here tf.Tensor and IntArray are assumed to be 0D.
Slice: TypeAlias = int | slice | None
FloatDataSequence: TypeAlias = Sequence[float] | Sequence[FloatDataSequence]
IntDataSequence: TypeAlias = Sequence[int] | Sequence[IntDataSequence]
Expand Down
52 changes: 52 additions & 0 deletions stubs/tensorflow/tensorflow/linalg.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from _typeshed import Incomplete
from builtins import bool as _bool
from collections.abc import Iterable
from typing import Literal, overload

import tensorflow as tf
from tensorflow import RaggedTensor, Tensor, norm as norm
from tensorflow._aliases import DTypeLike, IntArray, Integer, ScalarTensorCompatible, TensorCompatible
from tensorflow.math import l2_normalize as l2_normalize

@overload
def matmul(
a: TensorCompatible,
b: TensorCompatible,
transpose_a: _bool = False,
transpose_b: _bool = False,
adjoint_a: _bool = False,
adjoint_b: _bool = False,
a_is_sparse: _bool = False,
b_is_sparse: _bool = False,
output_type: DTypeLike | None = None,
name: str | None = None,
) -> Tensor: ...
@overload
def matmul(
a: RaggedTensor,
b: RaggedTensor,
transpose_a: _bool = False,
transpose_b: _bool = False,
adjoint_a: _bool = False,
adjoint_b: _bool = False,
a_is_sparse: _bool = False,
b_is_sparse: _bool = False,
output_type: DTypeLike | None = None,
name: str | None = None,
) -> RaggedTensor: ...
def set_diag(
input: TensorCompatible,
diagonal: TensorCompatible,
name: str | None = "set_diag",
k: int = 0,
align: Literal["RIGHT_LEFT", "RIGHT_RIGHT", "LEFT_LEFT", "LEFT_RIGHT"] = "RIGHT_LEFT",
) -> Tensor: ...
def eye(
num_rows: ScalarTensorCompatible,
num_columns: ScalarTensorCompatible | None = None,
batch_shape: Iterable[int] | IntArray | tf.Tensor | None = None,
dtype: DTypeLike = ...,
name: str | None = None,
) -> Tensor: ...
def band_part(input: TensorCompatible, num_lower: Integer, num_upper: Integer, name: str | None = None) -> Tensor: ...
def __getattr__(name: str) -> Incomplete: ...

0 comments on commit 955cdf5

Please sign in to comment.