Skip to content

Commit

Permalink
try removing overload
Browse files Browse the repository at this point in the history
  • Loading branch information
hoel-bagard committed Jan 31, 2024
1 parent b17c74a commit f9685a2
Showing 1 changed file with 23 additions and 23 deletions.
46 changes: 23 additions & 23 deletions stubs/tensorflow/tensorflow/keras/layers/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -277,29 +277,29 @@ class MultiHeadAttention(Layer[Any, tf.Tensor]):
dynamic: bool = False,
name: str | None = None,
) -> None: ...
@overload
def __call__(
self,
query: tf.Tensor,
value: tf.Tensor,
key: tf.Tensor | None,
attention_mask: tf.Tensor | None,
return_attention_scores: Literal[False],
training: bool,
use_causal_mask: bool,
) -> tf.Tensor: ...
@overload
def __call__(
self,
query: tf.Tensor,
value: tf.Tensor,
key: tf.Tensor | None,
attention_mask: tf.Tensor | None,
return_attention_scores: Literal[True],
training: bool,
use_causal_mask: bool,
) -> tuple[tf.Tensor, tf.Tensor]: ...
@overload
# @overload
# def __call__(
# self,
# query: tf.Tensor,
# value: tf.Tensor,
# key: tf.Tensor | None,
# attention_mask: tf.Tensor | None,
# return_attention_scores: Literal[False],
# training: bool,
# use_causal_mask: bool,
# ) -> tf.Tensor: ...
# @overload
# def __call__(
# self,
# query: tf.Tensor,
# value: tf.Tensor,
# key: tf.Tensor | None,
# attention_mask: tf.Tensor | None,
# return_attention_scores: Literal[True],
# training: bool,
# use_causal_mask: bool,
# ) -> tuple[tf.Tensor, tf.Tensor]: ...
# @overload
def __call__(
self,
query: tf.Tensor,
Expand Down

0 comments on commit f9685a2

Please sign in to comment.