From 35056a765304c36cf3231fe6de7be1ac4f8f9d62 Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Sun, 28 Jan 2024 19:15:24 +0900 Subject: [PATCH] fix: add MultiHeadAttention's __call__ to the stubtest allowlist; attempt to fix MultiHeadAttention's __call__ overloads --- .../tensorflow/@tests/stubtest_allowlist.txt | 1 + .../tensorflow/keras/layers/__init__.pyi | 27 +++++++++++++------ 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/stubs/tensorflow/@tests/stubtest_allowlist.txt b/stubs/tensorflow/@tests/stubtest_allowlist.txt index 5700e283375f..375e76454a96 100644 --- a/stubs/tensorflow/@tests/stubtest_allowlist.txt +++ b/stubs/tensorflow/@tests/stubtest_allowlist.txt @@ -60,6 +60,7 @@ tensorflow.keras.layers.*.__init__ tensorflow.keras.layers.*.call tensorflow.keras.regularizers.Regularizer.__call__ tensorflow.keras.constraints.Constraint.__call__ +tensorflow.keras.layers.MultiHeadAttention.__call__ # Layer class does good deal of __new__ magic and actually returns one of two different internal # types depending on tensorflow execution mode. This feels like implementation internal. 
diff --git a/stubs/tensorflow/tensorflow/keras/layers/__init__.pyi b/stubs/tensorflow/tensorflow/keras/layers/__init__.pyi index 7d6d20b17dae..24e7c0f7bb4c 100644 --- a/stubs/tensorflow/tensorflow/keras/layers/__init__.pyi +++ b/stubs/tensorflow/tensorflow/keras/layers/__init__.pyi @@ -5,7 +5,7 @@ from typing_extensions import Self, TypeAlias import tensorflow as tf from tensorflow import Tensor, Variable, VariableAggregation, VariableSynchronization -from tensorflow._aliases import AnyArray, TensorLike, TensorCompatible, DTypeLike +from tensorflow._aliases import AnyArray, DTypeLike, TensorCompatible, TensorLike from tensorflow.keras.activations import _Activation from tensorflow.keras.constraints import Constraint from tensorflow.keras.initializers import _Initializer @@ -282,23 +282,34 @@ class MultiHeadAttention(Layer[Any, tf.Tensor]): self, query: tf.Tensor, value: tf.Tensor, - key: tf.Tensor | None = None, - attention_mask: tf.Tensor | None = None, - return_attention_scores: Literal[False] = False, - training: bool = False, - use_causal_mask: bool = False, + key: tf.Tensor | None, + attention_mask: tf.Tensor | None, + return_attention_scores: Literal[False], + training: bool, + use_causal_mask: bool, ) -> tf.Tensor: ... @overload + def __call__( + self, + query: tf.Tensor, + value: tf.Tensor, + key: tf.Tensor | None, + attention_mask: tf.Tensor | None, + return_attention_scores: Literal[True], + training: bool, + use_causal_mask: bool, + ) -> tuple[tf.Tensor, tf.Tensor]: ... + @overload def __call__( self, query: tf.Tensor, value: tf.Tensor, key: tf.Tensor | None = None, attention_mask: tf.Tensor | None = None, - return_attention_scores: Literal[True] = True, + return_attention_scores: bool = False, training: bool = False, use_causal_mask: bool = False, - ) -> tuple[tf.Tensor, tf.Tensor]: ... + ) -> tuple[tf.Tensor, tf.Tensor] | tf.Tensor: ... class GaussianDropout(Layer[tf.Tensor, tf.Tensor]): def __init__(