-
-
Notifications
You must be signed in to change notification settings - Fork 1.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
tensorflow
: Add missing members to the tensorflow.keras.layers
module.
#11333
Merged
JelleZijlstra
merged 27 commits into
python:main
from
hoel-bagard:hoel/add_tf_keras_layers
Mar 13, 2024
Merged
Changes from 21 commits
Commits
Show all changes
27 commits
Select commit
Hold shift + click to select a range
ca822c2
Add some missing keras layers
hoel-bagard 6e26f9b
Move some layers to keras.layers.preprocessing
hoel-bagard 35056a7
fix: add MultiHeadAttention's __call__ to the stubtest allowlist
hoel-bagard 6e83aa6
remove MultiHeadAttention...
hoel-bagard cd8632a
Revert "remove MultiHeadAttention..."
hoel-bagard 755b0ab
type ignore the override
hoel-bagard d543957
Merge branch 'main' into hoel/add_tf_keras_layers
hoel-bagard 9eda7f3
try to fix mypy crash.
hoel-bagard 2c41a03
add modules to allowlist due to cursed imports.
hoel-bagard 4bf4ad8
Merge branch 'main' into hoel/add_tf_keras_layers
hoel-bagard 976b9c8
remove tensorflow.keras.layers.MultiHeadAttention.__call__ from allow…
hoel-bagard 37f1b7c
Merge branch 'main' into hoel/add_tf_keras_layers
hoel-bagard a7739fb
test
hoel-bagard f5f74c2
Merge branch 'main' into hoel/add_tf_keras_layers
JelleZijlstra 03c3e9b
Revert "test"
hoel-bagard aa05ac7
Merge branch 'main' into hoel/add_tf_keras_layers
JelleZijlstra 7d0343f
fix: tuple -> Iterable
hoel-bagard 37c668b
add PreprocessingLayer methods/overloads
hoel-bagard 4228893
fix PreprocessingLayer typing
hoel-bagard 31134a1
make IndexLookup private
hoel-bagard aaa3738
silence/ignore mypy error.
hoel-bagard ac3b558
fix: make PreprocessingLayer's is_adapted into a property.
hoel-bagard bc7926a
Merge branch 'main' into hoel/add_tf_keras_layers
JelleZijlstra 05ce2ed
Merge branch 'main' into hoel/add_tf_keras_layers
rchen152 189e918
Merge branch 'main' into hoel/add_tf_keras_layers
hoel-bagard 849279e
try to fix pytype issue
hoel-bagard 3d1cfbc
merge with main
hoel-bagard File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
26 changes: 26 additions & 0 deletions
26
stubs/tensorflow/tensorflow/keras/layers/experimental/preprocessing.pyi
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import abc | ||
from typing import overload | ||
|
||
import tensorflow as tf | ||
from tensorflow._aliases import AnyArray, DataSequence, Float, Integer, TensorCompatible, TensorLike | ||
from tensorflow.keras.layers import Layer | ||
|
||
class PreprocessingLayer(Layer[TensorLike, TensorLike], metaclass=abc.ABCMeta): | ||
is_adapted: bool | ||
@overload # type: ignore | ||
def __call__(self, inputs: tf.Tensor, *, training: bool = False, mask: TensorCompatible | None = None) -> tf.Tensor: ... | ||
@overload | ||
def __call__( | ||
self, inputs: tf.SparseTensor, *, training: bool = False, mask: TensorCompatible | None = None | ||
) -> tf.SparseTensor: ... | ||
@overload | ||
def __call__( | ||
self, inputs: tf.RaggedTensor, *, training: bool = False, mask: TensorCompatible | None = None | ||
) -> tf.RaggedTensor: ... | ||
def adapt( | ||
self, | ||
data: tf.data.Dataset[TensorLike] | AnyArray | DataSequence, | ||
batch_size: Integer | None = None, | ||
steps: Float | None = None, | ||
) -> None: ... | ||
def compile(self, run_eagerly: bool | None = None, steps_per_execution: Integer | None = None) -> None: ... |
36 changes: 36 additions & 0 deletions
36
stubs/tensorflow/tensorflow/keras/layers/preprocessing/__init__.pyi
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from typing import Literal | ||
|
||
from tensorflow._aliases import TensorCompatible | ||
from tensorflow.keras.layers.preprocessing.index_lookup import _IndexLookup | ||
|
||
class StringLookup(_IndexLookup): | ||
def __init__( | ||
self, | ||
max_tokens: int | None = None, | ||
num_oov_indices: int = 1, | ||
mask_token: str | None = None, | ||
oov_token: str = "[UNK]", | ||
vocabulary: str | None | TensorCompatible = None, | ||
idf_weights: TensorCompatible | None = None, | ||
encoding: str = "utf-8", | ||
invert: bool = False, | ||
output_mode: Literal["int", "count", "multi_hot", "one_hot", "tf_idf"] = "int", | ||
sparse: bool = False, | ||
pad_to_max_tokens: bool = False, | ||
) -> None: ... | ||
|
||
class IntegerLookup(_IndexLookup): | ||
def __init__( | ||
self, | ||
max_tokens: int | None = None, | ||
num_oov_indices: int = 1, | ||
mask_token: int | None = None, | ||
oov_token: int = -1, | ||
vocabulary: str | None | TensorCompatible = None, | ||
vocabulary_dtype: Literal["int64", "int32"] = "int64", | ||
idf_weights: TensorCompatible | None = None, | ||
invert: bool = False, | ||
output_mode: Literal["int", "count", "multi_hot", "one_hot", "tf_idf"] = "int", | ||
sparse: bool = False, | ||
pad_to_max_tokens: bool = False, | ||
) -> None: ... |
9 changes: 9 additions & 0 deletions
9
stubs/tensorflow/tensorflow/keras/layers/preprocessing/index_lookup.pyi
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from _typeshed import Incomplete | ||
|
||
import tensorflow as tf | ||
from tensorflow.keras.layers.experimental.preprocessing import PreprocessingLayer | ||
|
||
class _IndexLookup(PreprocessingLayer): | ||
def compute_output_signature(self, input_spec: Incomplete) -> tf.TensorSpec: ... | ||
def get_vocabulary(self, include_special_tokens: bool = True) -> list[Incomplete]: ... | ||
def vocabulary_size(self) -> int: ... |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One optional improvement (feel free to leave out of scope), there's comment near here that some of these signatures can be simplified with Unpack TypedDict (pep 692). The upstream stubs actually do use that technique which saves on repetition of dtype/trainable/name/dynamic/etc. I see that mypy is checked in this issue.
Can we start using 692 in typeshed now (has it been used yet?) @JelleZijlstra
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Still waiting for pytype as you can see in #9710