Skip to content

Commit

Permalink
Fix(preprocessing): resolve ValueError for ngrams tuple (#19190)
Browse files Browse the repository at this point in the history
* Fix(preprocessing): resolve `ValueError` for `ngrams` tuple

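Resolves a deserialisation issue when `ngrams` is set to a tuple in the `TextVectorization` layer: `from_config` now converts a list-valued `ngrams` entry back into a tuple before constructing the layer. Closes #19180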

* Test(preprocessing): test loading a model with `ngrams` tuple

* Test(preprocessing): run `test_save_load_with_ngrams_flow` with `tensorflow` backend only
  • Loading branch information
mykolaskrynnyk authored Feb 17, 2024
1 parent 7ce3d62 commit 7a0e670
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
4 changes: 4 additions & 0 deletions keras/layers/preprocessing/text_vectorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,10 @@ def from_config(cls, config):
config["split"] = serialization_lib.deserialize_keras_object(
config["split"]
)

if isinstance(config["ngrams"], list):
config["ngrams"] = tuple(config["ngrams"])

return cls(**config)

def set_vocabulary(self, vocabulary, idf_weights=None):
Expand Down
22 changes: 22 additions & 0 deletions keras/layers/preprocessing/text_vectorization_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import os

import numpy as np
import pytest
import tensorflow as tf
from tensorflow import data as tf_data

from keras import Sequential
from keras import backend
from keras import layers
from keras import models
from keras import saving
from keras import testing


Expand Down Expand Up @@ -62,6 +66,24 @@ def test_set_vocabulary(self):
self.assertTrue(backend.is_tensor(output))
self.assertAllClose(output, np.array([[4, 1, 3, 0], [1, 2, 0, 0]]))

@pytest.mark.skipif(
    backend.backend() != "tensorflow", reason="Requires string input dtype"
)
def test_save_load_with_ngrams_flow(self):
    """Check a save/load round trip with `ngrams` given as a tuple.

    A `Sequential` model made of a string `Input` and a
    `TextVectorization(ngrams=(1, 2))` layer is adapted on three sample
    strings, saved to a `.keras` file in the test's temp directory, and
    loaded back. The reloaded model must produce the same output as the
    original on the same samples.

    Skipped unless the active backend is "tensorflow".
    """
    samples = np.array(["foo bar", "bar baz", "baz bada boom"])
    string_input = layers.Input(dtype="string", shape=(1,))
    vectorizer = layers.TextVectorization(ngrams=(1, 2))
    model = Sequential([string_input, vectorizer])

    # Build the vocabulary before the reference output is computed.
    model.layers[0].adapt(samples)
    expected = model(samples)

    # Round-trip through the on-disk `.keras` format.
    save_path = os.path.join(self.get_temp_dir(), "model.keras")
    model.save(save_path)
    restored = saving.load_model(save_path)

    self.assertAllClose(expected, restored(samples))

def test_tf_data_compatibility(self):
max_tokens = 5000
max_len = 4
Expand Down

0 comments on commit 7a0e670

Please sign in to comment.