Commit 546bf07

Fix calling dataset.map on unbatched data with TextVectorization
When output_sequence_length is set, we would error out when the static shape
contains None on any dimension. We can fix that by using the dynamic shape
to calculate padding for output_sequence_length.

PiperOrigin-RevId: 382169202
mattdangerw authored and tensorflower-gardener committed Jun 29, 2021
1 parent 0ba7292 commit 546bf07
Showing 2 changed files with 49 additions and 48 deletions.
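
To see the failure this commit fixes, map the layer over an unbatched tf.data pipeline. Below is a minimal repro sketch; the tf.keras.layers.experimental.preprocessing export path is an assumption for this era of the codebase (the tests below use the internal text_vectorization module directly), and the expected values mirror the new test added in this commit.

import tensorflow as tf

# Repro sketch: before this commit, the map() call below raised an error
# whenever output_sequence_length was set, because inside dataset.map the
# layer is traced with a static shape containing None, and the old code
# computed padding from that static shape.
layer = tf.keras.layers.experimental.preprocessing.TextVectorization(
    output_sequence_length=3)
layer.adapt(tf.data.Dataset.from_tensor_slices(
    ["two two two", "two three three", "three four four five"]))

input_ds = tf.data.Dataset.from_tensor_slices(["two three", "four five"])
out = input_ds.map(layer)  # errored before this fix; now pads each element
print(list(out.as_numpy_iterator()))  # [array([2, 3, 0]), array([4, 5, 0])]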
45 changes: 24 additions & 21 deletions keras/layers/preprocessing/text_vectorization.py
@@ -335,10 +335,6 @@ def compute_output_signature(self, input_spec):
     return tf.TensorSpec(shape=output_shape, dtype=output_dtype)
 
   def update_state(self, data):
-    if isinstance(data, (list, tuple, np.ndarray)):
-      data = tf.convert_to_tensor(data)
-    if data.shape.rank == 1:
-      data = tf.expand_dims(data, axis=-1)
     self._index_lookup_layer.update_state(self._preprocess(data))
 
   def finalize_state(self):
@@ -485,23 +481,30 @@ def call(self, inputs):
       return inputs
 
     lookup_data = self._index_lookup_layer(inputs)
-    if self._output_mode == INT:
-
-      # Maybe trim the output (NOOP if self._output_sequence_length is None).
-      output_tensor = lookup_data[..., :self._output_sequence_length]
-
-      output_shape = output_tensor.shape.as_list()
-      output_shape[-1] = self._output_sequence_length
-
-      # If it is a ragged tensor, convert it to dense with correct shape.
-      if tf_utils.is_ragged(output_tensor):
-        return output_tensor.to_tensor(default_value=0, shape=output_shape)
-
-      if self._output_sequence_length is None:
-        return output_tensor
-
-      padding, _ = tf.required_space_to_batch_paddings(
-          output_tensor.shape, output_shape)
-      return tf.pad(output_tensor, padding)
+
+    # For any non-int output, we can return directly from the underlying layer.
+    if self._output_mode is not INT:
+      return lookup_data
+
+    # If we have a ragged tensor, we can pad during the conversion to dense.
+    if tf_utils.is_ragged(lookup_data):
+      shape = lookup_data.shape.as_list()
+      # If output_sequence_length is None, to_tensor will pad the last
+      # dimension to the bounding shape of the ragged dimension.
+      shape[-1] = self._output_sequence_length
+      return lookup_data.to_tensor(default_value=0, shape=shape)
+
+    # If we have a dense tensor, we need to pad/trim directly.
+    if self._output_sequence_length is not None:
+      # Maybe trim the output.
+      lookup_data = lookup_data[..., :self._output_sequence_length]
+
+      # Maybe pad the output. We need to use the dynamic shape here, since
+      # required_space_to_batch_paddings errors out when the static shape
+      # contains a None dimension.
+      shape = tf.shape(lookup_data)
+      padded_shape = tf.concat((shape[:-1], [self._output_sequence_length]), 0)
+      padding, _ = tf.required_space_to_batch_paddings(shape, padded_shape)
+      return tf.pad(lookup_data, padding)
 
     return lookup_data
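
The crux of the fix above is swapping the static shape (lookup_data.shape, which contains None while tracing a dataset.map) for the dynamic shape (tf.shape(lookup_data)). Below is a standalone sketch of the same pad-to-length trick; pad_to_len and OUTPUT_LEN are illustrative names, not part of the layer.

import tensorflow as tf

OUTPUT_LEN = 3  # stands in for self._output_sequence_length

@tf.function(input_signature=[tf.TensorSpec([None], tf.int64)])
def pad_to_len(tokens):
  tokens = tokens[..., :OUTPUT_LEN]  # maybe trim, as in call() above
  # While tracing, tokens.shape is TensorShape([None]); passing it to
  # required_space_to_batch_paddings would fail. tf.shape(tokens) is a
  # runtime tensor, so it is always fully specified.
  shape = tf.shape(tokens)
  padded_shape = tf.concat((shape[:-1], [OUTPUT_LEN]), 0)
  # Using the target shape as the block shape pads each dimension up to a
  # multiple of the target size; after trimming, the last dimension lands
  # exactly on OUTPUT_LEN.
  padding, _ = tf.required_space_to_batch_paddings(shape, padded_shape)
  return tf.pad(tokens, padding)

print(pad_to_len(tf.constant([2, 3], tf.int64)))  # tf.Tensor([2 3 0], shape=(3,), dtype=int64)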
52 changes: 25 additions & 27 deletions keras/layers/preprocessing/text_vectorization_test.py
@@ -318,12 +318,10 @@ def test_scalar_input_int_mode_no_len_limit(self):
     layer = text_vectorization.TextVectorization()
     layer.adapt(vocab_data)
     out = layer(input_data)
-    if tf.executing_eagerly():
-      self.assertAllClose(out.numpy(), [2, 3, 4, 5, 5, 4, 2, 1])
+    self.assertAllClose(out.numpy(), [2, 3, 4, 5, 5, 4, 2, 1])
     layer.set_vocabulary(["earth", "wind", "and", "fire"])
     out = layer(input_data)
-    if tf.executing_eagerly():
-      self.assertAllClose(out.numpy(), [2, 3, 4, 5, 5, 4, 2, 1])
+    self.assertAllClose(out.numpy(), [2, 3, 4, 5, 5, 4, 2, 1])
 
   def test_scalar_input_int_mode_trim_to_len_limit(self):
     vocab_data = [
@@ -333,12 +331,10 @@ def test_scalar_input_int_mode_trim_to_len_limit(self):
     layer = text_vectorization.TextVectorization(output_sequence_length=3)
     layer.adapt(vocab_data)
     out = layer(input_data)
-    if tf.executing_eagerly():
-      self.assertAllClose(out.numpy(), [2, 3, 4])
+    self.assertAllClose(out.numpy(), [2, 3, 4])
     layer.set_vocabulary(["earth", "wind", "and", "fire"])
     out = layer(input_data)
-    if tf.executing_eagerly():
-      self.assertAllClose(out.numpy(), [2, 3, 4])
+    self.assertAllClose(out.numpy(), [2, 3, 4])
 
   def test_scalar_input_int_pad_to_len_limit(self):
     vocab_data = [
@@ -348,25 +344,21 @@ def test_scalar_input_int_pad_to_len_limit(self):
     layer = text_vectorization.TextVectorization(output_sequence_length=10)
     layer.adapt(vocab_data)
     out = layer(input_data)
-    if tf.executing_eagerly():
-      self.assertAllClose(out.numpy(), [2, 3, 4, 5, 5, 4, 2, 1, 0, 0])
+    self.assertAllClose(out.numpy(), [2, 3, 4, 5, 5, 4, 2, 1, 0, 0])
     layer.set_vocabulary(["earth", "wind", "and", "fire"])
     out = layer(input_data)
-    if tf.executing_eagerly():
-      self.assertAllClose(out.numpy(), [2, 3, 4, 5, 5, 4, 2, 1, 0, 0])
+    self.assertAllClose(out.numpy(), [2, 3, 4, 5, 5, 4, 2, 1, 0, 0])
 
   def test_list_inputs_1d(self):
     vocab_data = ["two two two", "two three three", "three four four five"]
     input_data = ["two three", "four five"]
     layer = text_vectorization.TextVectorization()
     layer.adapt(vocab_data)
     out = layer(input_data)
-    if tf.executing_eagerly():
-      self.assertAllClose(out.numpy(), [[2, 3], [4, 5]])
+    self.assertAllClose(out.numpy(), [[2, 3], [4, 5]])
     layer.set_vocabulary(["two", "three", "four", "five"])
     out = layer(input_data)
-    if tf.executing_eagerly():
-      self.assertAllClose(out.numpy(), [[2, 3], [4, 5]])
+    self.assertAllClose(out.numpy(), [[2, 3], [4, 5]])
 
   def test_tensor_inputs(self):
     vocab_data = tf.constant(
@@ -375,12 +367,10 @@ def test_tensor_inputs(self):
     layer = text_vectorization.TextVectorization()
     layer.adapt(vocab_data)
     out = layer(input_data)
-    if tf.executing_eagerly():
-      self.assertAllClose(out.numpy(), [[2, 3], [4, 5]])
+    self.assertAllClose(out.numpy(), [[2, 3], [4, 5]])
     layer.set_vocabulary(["two", "three", "four", "five"])
     out = layer(input_data)
-    if tf.executing_eagerly():
-      self.assertAllClose(out.numpy(), [[2, 3], [4, 5]])
+    self.assertAllClose(out.numpy(), [[2, 3], [4, 5]])
 
   def test_list_inputs_2d(self):
     vocab_data = [
@@ -389,22 +379,30 @@ def test_list_inputs_2d(self):
     layer = text_vectorization.TextVectorization()
     layer.adapt(vocab_data)
     out = layer(input_data)
-    if tf.executing_eagerly():
-      self.assertAllClose(out.numpy(), [[2, 3], [4, 5]])
+    self.assertAllClose(out.numpy(), [[2, 3], [4, 5]])
     layer.set_vocabulary(["two", "three", "four", "five"])
     out = layer(input_data)
-    if tf.executing_eagerly():
-      self.assertAllClose(out.numpy(), [[2, 3], [4, 5]])
+    self.assertAllClose(out.numpy(), [[2, 3], [4, 5]])
 
   def test_dataset_of_single_strings(self):
     vocab_data = ["two two two", "two three three", "three four four five"]
     input_data = ["two three", "four five"]
     vocab_ds = tf.data.Dataset.from_tensor_slices(vocab_data)  # unbatched
+    input_ds = tf.data.Dataset.from_tensor_slices(input_data)  # unbatched
     layer = text_vectorization.TextVectorization()
     layer.adapt(vocab_ds)
-    out = layer(input_data)
-    if tf.executing_eagerly():
-      self.assertAllClose(out.numpy(), [[2, 3], [4, 5]])
+    out = input_ds.map(layer)
+    self.assertAllClose(list(out.as_numpy_iterator()), [[2, 3], [4, 5]])
+
+  def test_dataset_of_single_strings_with_output_sequence(self):
+    vocab_data = ["two two two", "two three three", "three four four five"]
+    input_data = ["two three", "four five"]
+    vocab_ds = tf.data.Dataset.from_tensor_slices(vocab_data)  # unbatched
+    input_ds = tf.data.Dataset.from_tensor_slices(input_data)  # unbatched
+    layer = text_vectorization.TextVectorization(output_sequence_length=3)
+    layer.adapt(vocab_ds)
+    out = input_ds.map(layer)
+    self.assertAllClose(list(out.as_numpy_iterator()), [[2, 3, 0], [4, 5, 0]])
 
   @parameterized.named_parameters(
       {
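
With the fix in place, a map-then-batch pipeline also composes, since every mapped element now has the same static length. A usage sketch, assuming a layer built like the one in the new test (output_sequence_length=3, adapted on vocab_ds):

input_ds = tf.data.Dataset.from_tensor_slices(["two three", "four five"])
batched = input_ds.map(layer).batch(2)  # every element has shape (3,)
for batch in batched:
  print(batch)  # tf.Tensor([[2 3 0] [4 5 0]], shape=(2, 3), dtype=int64)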
