Skip to content

Commit

Permalink
Fix lookup layer oov token check when num_oov_indices > len(vocabulary tokens)
Browse files Browse the repository at this point in the history

This was always broken for numpy vocabulary inputs, and recently broke for lists as well.

PiperOrigin-RevId: 380935015
  • Loading branch information
mattdangerw authored and tensorflower-gardener committed Jun 23, 2021
1 parent 586c4ad commit fa80feb
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 5 deletions.
6 changes: 1 addition & 5 deletions keras/layers/preprocessing/index_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,11 +411,7 @@ def set_vocabulary(self, vocabulary, idf_weights=None):
should_have_oov = (self.num_oov_indices > 0)
expected_oov = [self.oov_token] * self.num_oov_indices
found_oov = vocabulary[oov_start:token_start]
has_oov = should_have_oov and found_oov == expected_oov
# If we get a numpy array, then has_oov may end up being a numpy array
# instead of a bool. Fix this by collapsing the variable if it's not bool.
if not isinstance(has_oov, bool):
has_oov = any(has_oov)
has_oov = should_have_oov and np.array_equal(found_oov, expected_oov)

if all([should_have_mask, has_mask, should_have_oov]) and not has_oov:
raise ValueError(
Expand Down
25 changes: 25 additions & 0 deletions keras/layers/preprocessing/index_lookup_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1535,6 +1535,31 @@ def test_get_vocabulary_no_special_tokens(self):
self.assertAllEqual(returned_vocab, ["wind", "and", "fire"])
self.assertAllEqual(layer.vocabulary_size(), 5)

def test_vocab_multi_oov(self):
  """Round-trips a vocabulary that already lists the mask and both OOV tokens.

  The layer is configured with two OOV indices, and the supplied vocabulary
  starts with the mask token followed by two copies of the OOV token, so
  `get_vocabulary()` is expected to return the input unchanged.
  """
  expected_vocab = ["", "[OOV]", "[OOV]", "wind", "and", "fire"]
  lookup_kwargs = dict(
      max_tokens=None,
      num_oov_indices=2,
      mask_token="",
      oov_token="[OOV]",
      dtype=tf.string)
  lookup = index_lookup.IndexLookup(**lookup_kwargs)
  lookup.set_vocabulary(expected_vocab)
  self.assertAllEqual(lookup.get_vocabulary(), expected_vocab)

def test_vocab_multi_oov_not_present(self):
  """Checks the returned vocabulary when no special tokens are supplied.

  The layer is configured with more OOV indices (10) than there are tokens
  in the supplied vocabulary (3). The vocabulary returned by the layer is
  expected to be the mask token, then ten OOV tokens, then the supplied
  tokens.
  """
  oov_count = 10
  plain_tokens = ["wind", "and", "fire"]
  # Build the expectation up front so it is independent of the layer call.
  expected_vocab = [""] + ["[OOV]"] * oov_count + ["wind", "and", "fire"]

  lookup = index_lookup.IndexLookup(
      max_tokens=None,
      num_oov_indices=oov_count,
      mask_token="",
      oov_token="[OOV]",
      dtype=tf.string)
  lookup.set_vocabulary(plain_tokens)

  actual_vocab = lookup.get_vocabulary()
  self.assertAllEqual(actual_vocab, expected_vocab)

def test_vocab_with_max_cap(self):
vocab_data = ["", "[OOV]", "wind", "and", "fire"]
layer = index_lookup.IndexLookup(
Expand Down

0 comments on commit fa80feb

Please sign in to comment.