Add some missing keras layers
hoel-bagard committed Jan 28, 2024
1 parent d3b45a4 commit 17f66ee
Showing 1 changed file with 161 additions and 4 deletions.
165 changes: 161 additions & 4 deletions stubs/tensorflow/tensorflow/keras/layers.pyi
@@ -1,15 +1,16 @@
 from _typeshed import Incomplete
-from collections.abc import Callable, Iterable, Sequence
-from typing import Any, Generic, TypeVar, overload
+from collections.abc import Callable, Iterable, Mapping, Sequence
+from typing import Any, Generic, Literal, TypeVar, overload
 from typing_extensions import Self, TypeAlias
 
 import tensorflow as tf
 from tensorflow import Tensor, Variable, VariableAggregation, VariableSynchronization, _TensorCompatible
-from tensorflow._aliases import AnyArray
+from tensorflow._aliases import AnyArray, TensorLike
 from tensorflow.keras.activations import _Activation
 from tensorflow.keras.constraints import Constraint
 from tensorflow.keras.initializers import _Initializer
-from tensorflow.keras.regularizers import _Regularizer
+from tensorflow.keras.regularizers import Regularizer, _Regularizer
 from tensorflow.python.feature_column.feature_column_v2 import DenseColumn, SequenceDenseColumn
 
 _InputT = TypeVar("_InputT", contravariant=True)
 _OutputT = TypeVar("_OutputT", covariant=True)
@@ -194,4 +195,160 @@ class Embedding(Layer[tf.Tensor, tf.Tensor]):
name: str | None = None,
) -> None: ...

class Conv2D(Layer[tf.Tensor, tf.Tensor]):
def __init__(
self,
filters: int,
kernel_size: int | tuple[int, int],
strides: int | tuple[int, int] = (1, 1),
padding: Literal["valid", "same"] = "valid",
data_format: None | Literal["channels_last", "channels_first"] = None,
dilation_rate: int | tuple[int, int] = (1, 1),
groups: int = 1,
activation: _Activation = None,
use_bias: bool = True,
kernel_initializer: _Initializer = "glorot_uniform",
bias_initializer: _Initializer = "zeros",
kernel_regularizer: _Regularizer = None,
bias_regularizer: _Regularizer = None,
activity_regularizer: _Regularizer = None,
kernel_constraint: _Constraint = None,
bias_constraint: _Constraint = None,
trainable: bool = True,
dtype: _LayerDtype = None,
dynamic: bool = False,
name: str | None = None,
) -> None: ...
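
# A minimal usage sketch of the Conv2D signature above; the shapes and
# hyperparameter values are illustrative, not part of the stub:
#   conv = Conv2D(filters=32, kernel_size=3, padding="same", activation="relu")
#   out = conv(tf.zeros([1, 28, 28, 3]))  # Layer[tf.Tensor, tf.Tensor] types this as tf.Tensor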

class Identity(Layer[tf.Tensor, tf.Tensor]):
def __init__(
self, trainable: bool = True, dtype: _LayerDtype = None, dynamic: bool = False, name: str | None = None
) -> None: ...

class LayerNormalization(Layer[tf.Tensor, tf.Tensor]):
def __init__(
self,
axis: int = -1,
epsilon: float = 0.001,
center: bool = True,
scale: bool = True,
beta_initializer: _Initializer = "zeros",
gamma_initializer: _Initializer = "ones",
beta_regularizer: _Regularizer = None,
gamma_regularizer: _Regularizer = None,
beta_constraint: _Constraint = None,
gamma_constraint: _Constraint = None,
trainable: bool = True,
dtype: _LayerDtype = None,
dynamic: bool = False,
name: str | None = None,
) -> None: ...
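
# Illustrative use of the signature above (input shape assumed for the example):
#   norm = LayerNormalization(axis=-1, epsilon=0.001)
#   out = norm(tf.zeros([2, 8]))  # tf.Tensor in, tf.Tensor out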

class _IndexLookup(Layer[TensorLike, TensorLike]):
@overload
def __call__(self, inputs: tf.Tensor) -> tf.Tensor: ...
@overload
def __call__(self, inputs: tf.SparseTensor) -> tf.SparseTensor: ...
@overload
def __call__(self, inputs: tf.RaggedTensor) -> tf.RaggedTensor: ...
def vocabulary_size(self) -> int: ...
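
# The overloads above keep __call__ type-preserving: each concrete tensor kind
# maps to itself instead of widening to TensorLike. A sketch, using the
# StringLookup subclass defined below (vocabulary values are made up):
#   lookup = StringLookup(vocabulary=["a", "b"])
#   dense = lookup(tf.constant(["a", "b"]))       # inferred as tf.Tensor
#   ragged = lookup(tf.ragged.constant([["a"]]))  # inferred as tf.RaggedTensor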

class StringLookup(_IndexLookup):
def __init__(
self,
max_tokens: int | None = None,
num_oov_indices: int = 1,
mask_token: str | None = None,
oov_token: str = "[UNK]",
vocabulary: str | None | _TensorCompatible = None,
idf_weights: _TensorCompatible | None = None,
encoding: str = "utf-8",
invert: bool = False,
output_mode: Literal["int", "count", "multi_hot", "one_hot", "tf_idf"] = "int",
sparse: bool = False,
pad_to_max_tokens: bool = False,
) -> None: ...
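
# Hedged construction example (vocabulary and inputs are illustrative):
#   lookup = StringLookup(vocabulary=["red", "green", "blue"], num_oov_indices=1)
#   ids = lookup(tf.constant(["green", "violet"]))  # "violet" falls into the OOV index
#   size = lookup.vocabulary_size()                 # int, inherited from _IndexLookup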

class IntegerLookup(_IndexLookup):
def __init__(
self,
max_tokens: int | None = None,
num_oov_indices: int = 1,
mask_token: int | None = None,
oov_token: int = -1,
vocabulary: str | None | _TensorCompatible = None,
vocabulary_dtype: Literal["int64", "int32"] = "int64",
idf_weights: _TensorCompatible | None = None,
invert: bool = False,
output_mode: Literal["int", "count", "multi_hot", "one_hot", "tf_idf"] = "int",
sparse: bool = False,
pad_to_max_tokens: bool = False,
) -> None: ...
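
# Same pattern as StringLookup, over an integer vocabulary (values assumed):
#   lookup = IntegerLookup(vocabulary=[10, 20, 30], oov_token=-1)
#   ids = lookup(tf.constant([20, 99]))  # unseen 99 maps to the OOV bucket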

class DenseFeatures(Layer[Mapping[str, TensorLike], tf.Tensor]):
def __init__(
self,
feature_columns: Sequence[DenseColumn | SequenceDenseColumn],
trainable: bool = True,
dtype: _LayerDtype = None,
dynamic: bool = False,
name: str | None = None,
) -> None: ...
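
# Sketch of the Mapping[str, TensorLike] -> tf.Tensor contract; the feature
# column below is illustrative (tf.feature_column is deprecated in TF 2.x):
#   col = tf.feature_column.numeric_column("price")
#   layer = DenseFeatures([col])
#   out = layer({"price": tf.constant([[1.0], [2.0]])})  # tf.Tensor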

class MultiHeadAttention(Layer[Any, tf.Tensor]):
def __init__(
self,
num_heads: int,
key_dim: int | None,
value_dim: int | None = None,
dropout: float = 0.0,
use_bias: bool = True,
output_shape: tuple[int, ...] | None = None,
attention_axes: tuple[int, ...] | None = None,
kernel_initializer: _Initializer = "glorot_uniform",
bias_initializer: _Initializer = "zeros",
kernel_regularizer: Regularizer | None = None,
bias_regularizer: _Regularizer | None = None,
activity_regularizer: _Regularizer | None = None,
kernel_constraint: _Constraint | None = None,
bias_constraint: _Constraint | None = None,
trainable: bool = True,
dtype: _LayerDtype = None,
dynamic: bool = False,
name: str | None = None,
) -> None: ...
@overload
def __call__(
self,
query: tf.Tensor,
value: tf.Tensor,
key: tf.Tensor | None = None,
attention_mask: tf.Tensor | None = None,
return_attention_scores: Literal[False] = False,
training: bool = False,
use_causal_mask: bool = False,
) -> tf.Tensor: ...
@overload
def __call__(
self,
query: tf.Tensor,
value: tf.Tensor,
key: tf.Tensor | None = None,
attention_mask: tf.Tensor | None = None,
return_attention_scores: Literal[True] = True,
training: bool = False,
use_causal_mask: bool = False,
) -> tuple[tf.Tensor, tf.Tensor]: ...
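
# The Literal overloads above let a type checker pick the return type from
# return_attention_scores (shapes and sizes are illustrative):
#   mha = MultiHeadAttention(num_heads=2, key_dim=16)
#   q = tf.zeros([1, 4, 16])
#   v = tf.zeros([1, 4, 16])
#   out = mha(q, v)                                        # tf.Tensor
#   out, scores = mha(q, v, return_attention_scores=True)  # tuple[tf.Tensor, tf.Tensor]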

class GaussianDropout(Layer[tf.Tensor, tf.Tensor]):
def __init__(
self,
rate: float,
seed: int | None = None,
trainable: bool = True,
dtype: _LayerDtype = None,
dynamic: bool = False,
name: str | None = None,
) -> None: ...
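
# Minimal sketch (rate is illustrative). The layer applies multiplicative,
# 1-centered Gaussian noise, and only while training:
#   drop = GaussianDropout(rate=0.1)
#   out = drop(tf.zeros([2, 4]))  # tf.Tensor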

def __getattr__(name: str) -> Incomplete: ...
