Skip to content

Commit

Permalink
Make TextVectorization work with list input.
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 379498023
  • Loading branch information
yashk2810 authored and tensorflower-gardener committed Jun 15, 2021
1 parent 0eb1954 commit fe45d47
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
2 changes: 1 addition & 1 deletion keras/layers/preprocessing/text_vectorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def build(self, input_shape):
# expression to evaluate to False instead of True if the shape is undefined;
# the expression needs to evaluate to True in that case.
if self._split is not None:
- if input_shape.ndims > 1 and not input_shape[-1] == 1: # pylint: disable=g-comparison-negation
+ if input_shape and input_shape.ndims > 1 and not input_shape[-1] == 1: # pylint: disable=g-comparison-negation
raise RuntimeError(
"When using TextVectorization to tokenize strings, the innermost "
"dimension of the input array must be 1, got shape "
Expand Down
5 changes: 5 additions & 0 deletions keras/layers/preprocessing/text_vectorization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,11 @@ def test_layer_dimensionality_handling_with_split(self, data, expected):
output = vectorization(tf.ragged.constant(data, inner_shape=(1,)))
self.assertAllEqual(expected, output)

def test_layer_list_input(self):
  """Calls a freshly constructed TextVectorization layer on a Python list."""
  layer = text_vectorization.TextVectorization()
  # The input is a plain list of strings rather than a tensor or array.
  result = layer(["a", "b", "c"])
  # Each string should come back as a single-element row holding index 1.
  # NOTE(review): presumably 1 is the out-of-vocabulary index, since no
  # vocabulary is set on the layer here -- confirm against the layer docs.
  self.assertEqual(result.numpy().tolist(), [[1], [1], [1]])


@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
class TextVectorizationPreprocessingTest(
Expand Down

0 comments on commit fe45d47

Please sign in to comment.