diff --git a/keras/layers/preprocessing/index_lookup.py b/keras/layers/preprocessing/index_lookup.py index 1bcff31bd81b..51d01f79b5ad 100644 --- a/keras/layers/preprocessing/index_lookup.py +++ b/keras/layers/preprocessing/index_lookup.py @@ -411,11 +411,7 @@ def set_vocabulary(self, vocabulary, idf_weights=None): should_have_oov = (self.num_oov_indices > 0) expected_oov = [self.oov_token] * self.num_oov_indices found_oov = vocabulary[oov_start:token_start] - has_oov = should_have_oov and found_oov == expected_oov - # If we get a numpy array, then has_oov may end up being a numpy array - # instead of a bool. Fix this by collapsing the variable if it's not bool. - if not isinstance(has_oov, bool): - has_oov = any(has_oov) + has_oov = should_have_oov and np.array_equal(found_oov, expected_oov) if all([should_have_mask, has_mask, should_have_oov]) and not has_oov: raise ValueError( diff --git a/keras/layers/preprocessing/index_lookup_test.py b/keras/layers/preprocessing/index_lookup_test.py index f3684f867436..0781f98aeaee 100644 --- a/keras/layers/preprocessing/index_lookup_test.py +++ b/keras/layers/preprocessing/index_lookup_test.py @@ -1535,6 +1535,31 @@ def test_get_vocabulary_no_special_tokens(self): self.assertAllEqual(returned_vocab, ["wind", "and", "fire"]) self.assertAllEqual(layer.vocabulary_size(), 5) + def test_vocab_multi_oov(self): + vocab_data = ["", "[OOV]", "[OOV]", "wind", "and", "fire"] + layer = index_lookup.IndexLookup( + max_tokens=None, + num_oov_indices=2, + mask_token="", + oov_token="[OOV]", + dtype=tf.string) + layer.set_vocabulary(vocab_data) + returned_vocab = layer.get_vocabulary() + self.assertAllEqual(returned_vocab, vocab_data) + + def test_vocab_multi_oov_not_present(self): + vocab_data = ["wind", "and", "fire"] + layer = index_lookup.IndexLookup( + max_tokens=None, + num_oov_indices=10, + mask_token="", + oov_token="[OOV]", + dtype=tf.string) + layer.set_vocabulary(vocab_data) + returned_vocab = layer.get_vocabulary() + self.assertAllEqual(returned_vocab, + [""] + ["[OOV]"] * 10 + ["wind", "and", "fire"]) + def test_vocab_with_max_cap(self): vocab_data = ["", "[OOV]", "wind", "and", "fire"] layer = index_lookup.IndexLookup(