Skip to content

Commit

Permalink
fix: add MultiHeadAttention's __call__ to the stubtest allowlist
Browse files Browse the repository at this point in the history
attempt to fix MultiHeadAttention's __call__
  • Loading branch information
hoel-bagard committed Jan 31, 2024
1 parent 6e26f9b commit 35056a7
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
1 change: 1 addition & 0 deletions stubs/tensorflow/@tests/stubtest_allowlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ tensorflow.keras.layers.*.__init__
tensorflow.keras.layers.*.call
tensorflow.keras.regularizers.Regularizer.__call__
tensorflow.keras.constraints.Constraint.__call__
tensorflow.keras.layers.MultiHeadAttention.__call__

# Layer class does good deal of __new__ magic and actually returns one of two different internal
# types depending on tensorflow execution mode. This feels like implementation internal.
Expand Down
27 changes: 19 additions & 8 deletions stubs/tensorflow/tensorflow/keras/layers/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ from typing_extensions import Self, TypeAlias

import tensorflow as tf
from tensorflow import Tensor, Variable, VariableAggregation, VariableSynchronization
from tensorflow._aliases import AnyArray, TensorLike, TensorCompatible, DTypeLike
from tensorflow._aliases import AnyArray, DTypeLike, TensorCompatible, TensorLike
from tensorflow.keras.activations import _Activation
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.initializers import _Initializer
Expand Down Expand Up @@ -282,23 +282,34 @@ class MultiHeadAttention(Layer[Any, tf.Tensor]):
self,
query: tf.Tensor,
value: tf.Tensor,
key: tf.Tensor | None = None,
attention_mask: tf.Tensor | None = None,
return_attention_scores: Literal[False] = False,
training: bool = False,
use_causal_mask: bool = False,
key: tf.Tensor | None,
attention_mask: tf.Tensor | None,
return_attention_scores: Literal[False],
training: bool,
use_causal_mask: bool,
) -> tf.Tensor: ...
@overload
def __call__(
self,
query: tf.Tensor,
value: tf.Tensor,
key: tf.Tensor | None,
attention_mask: tf.Tensor | None,
return_attention_scores: Literal[True],
training: bool,
use_causal_mask: bool,
) -> tuple[tf.Tensor, tf.Tensor]: ...
@overload
def __call__(
self,
query: tf.Tensor,
value: tf.Tensor,
key: tf.Tensor | None = None,
attention_mask: tf.Tensor | None = None,
return_attention_scores: Literal[True] = True,
return_attention_scores: bool = False,
training: bool = False,
use_causal_mask: bool = False,
) -> tuple[tf.Tensor, tf.Tensor]: ...
) -> tuple[tf.Tensor, tf.Tensor] | tf.Tensor: ...

class GaussianDropout(Layer[tf.Tensor, tf.Tensor]):
def __init__(
Expand Down

0 comments on commit 35056a7

Please sign in to comment.