
Fixed test for subword_tokenize
VibhuJawa committed Jan 5, 2022
1 parent 56133b3 commit 62606d3
Showing 1 changed file with 12 additions and 2 deletions.
python/cudf/cudf/tests/test_subword_tokenizer.py (14 changes: 12 additions, 2 deletions)
@@ -124,9 +124,16 @@ def test_text_subword_tokenize(tmpdir):
     content = content + "100\n101\n102\n\n"
     hash_file.write(content)

-    cudf_tokenizer = SubwordTokenizer("voc_hash.txt")
+    cudf_tokenizer = SubwordTokenizer(hash_file)

-    tokens, masks, metadata = cudf_tokenizer(sr, 8, 8)
+    token_d = cudf_tokenizer(
+        sr, 8, 8, add_special_tokens=False, truncation=True
+    )
+    tokens, masks, metadata = (
+        token_d["input_ids"],
+        token_d["attention_mask"],
+        token_d["metadata"],
+    )
     expected_tokens = cupy.asarray(
         [
             2023,
@@ -172,6 +179,7 @@ def test_text_subword_tokenize(tmpdir):
         ],
         dtype=np.uint32,
     )
+    expected_tokens = expected_tokens.reshape(-1, 8)
     assert_eq(expected_tokens, tokens)

     expected_masks = cupy.asarray(
@@ -219,9 +227,11 @@ def test_text_subword_tokenize(tmpdir):
         ],
         dtype=np.uint32,
     )
+    expected_masks = expected_masks.reshape(-1, 8)
     assert_eq(expected_masks, masks)

     expected_metadata = cupy.asarray(
         [0, 0, 3, 1, 0, 3, 2, 0, 3, 3, 0, 1, 4, 0, 1], dtype=np.uint32
     )
+    expected_metadata = expected_metadata.reshape(-1, 3)
     assert_eq(expected_metadata, metadata)
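
For context, here is a minimal sketch of the call pattern the updated test exercises: the tokenizer returns a dict of cupy arrays keyed by "input_ids", "attention_mask", and "metadata" rather than a flat tuple. The input strings and hash-file path below are illustrative placeholders, not values from the test, and assume a working cuDF install plus a prebuilt vocabulary hash file.

    # Sketch of the dict-returning SubwordTokenizer call pattern;
    # "voc_hash.txt" and the example strings are illustrative only.
    import cudf
    from cudf.core.subword_tokenizer import SubwordTokenizer

    sr = cudf.Series(["this is a test", "another sentence"])
    cudf_tokenizer = SubwordTokenizer("voc_hash.txt")

    # max_length=8, max_num_rows=8, passed positionally as in the test
    token_d = cudf_tokenizer(
        sr, 8, 8, add_special_tokens=False, truncation=True
    )
    tokens = token_d["input_ids"]        # cupy array, (num_rows, max_length)
    masks = token_d["attention_mask"]    # cupy array, (num_rows, max_length)
    metadata = token_d["metadata"]       # cupy array, (num_rows, 3)

This also explains the three new reshape(-1, ...) calls in the diff: the dict-returning API hands back 2-D row-major tensors, so the flat 1-D expected arrays in the test must be reshaped before comparison.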
