Skip to content

Commit

Permalink
Update unit tests for the encoding script (#92)
Browse files Browse the repository at this point in the history
  • Loading branch information
tushuhei authored Nov 14, 2022
1 parent fd28bd7 commit 23119c7
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 25 deletions.
17 changes: 9 additions & 8 deletions scripts/encode_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,17 @@ def process(i: int, sentence: str, sep_indices: typing.Set[int]) -> str:
return line


def read_source_file(filename: str) -> typing.Tuple[str, typing.Set[int]]:
"""Reads the sentence and separator indices from the source file.
def normalize_input(data: str) -> typing.Tuple[str, typing.Set[int]]:
"""Normalizes the input to one line with separators.
Args:
filename (str): A source file path.
data(str): Source input
Returns:
typing.Tuple[str, typing.Set[int]]: A tuple of the sentence and the separator indices.
typing.Tuple[str, typing.Set[int]]: A tuple of the sentence and the
separator indices.
"""
with open(filename, encoding=sys.getdefaultencoding()) as f:
data = f.read().replace('\n', utils.SEP)
chunks = data.strip().split(utils.SEP)
chunks = data.replace('\n', utils.SEP).strip().split(utils.SEP)
chunk_lengths = [len(chunk) for chunk in chunks]
sep_indices = set(itertools.accumulate(chunk_lengths, lambda x, y: x + y))
sentence = ''.join(chunks)
Expand All @@ -99,7 +98,9 @@ def main(test: ArgList = None) -> None:
source_filename: str = args.source_data
entries_filename: str = args.outfile
processes = None if args.processes is None else int(args.processes)
sentence, sep_indices = read_source_file(source_filename)
with open(source_filename, encoding=sys.getdefaultencoding()) as f:
data = f.read()
sentence, sep_indices = normalize_input(data)
with multiprocessing.Pool(processes) as p:
func = functools.partial(
process, sentence=sentence, sep_indices=sep_indices)
Expand Down
36 changes: 19 additions & 17 deletions tests/test_encode_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,20 +88,22 @@ def test_on_negative_point(self) -> None:
self.assertIn('UW3:で', features)


class TestReadSourceFile(unittest.TestCase):

ENTRIES_FILE_PATH = os.path.abspath(
os.path.join(os.path.dirname(__file__), 'entries_test.txt'))

def setUp(self) -> None:
with open(
self.ENTRIES_FILE_PATH, 'w', encoding=sys.getdefaultencoding()) as f:
f.write(f'これは{utils.SEP}美しい{utils.SEP}ペンです。\n今日は{utils.SEP}晴天です。')

def test_read_source_file(self) -> None:
sentence, sep_indices = encode_data.read_source_file(self.ENTRIES_FILE_PATH)
self.assertEqual(sentence, 'これは美しいペンです。今日は晴天です。')
self.assertEqual(sep_indices, {3, 6, 11, 14, 19})

def tearDown(self) -> None:
os.remove(self.ENTRIES_FILE_PATH)
class TestNormalizeInput(unittest.TestCase):

def test_standard_input(self) -> None:
source = f'ABC{utils.SEP}DE{utils.SEP}FGHI'
sentence, sep_indices = encode_data.normalize_input(source)
self.assertEqual(sentence, 'ABCDEFGHI')
self.assertEqual(sep_indices, {3, 5, 9})

def test_with_linebreaks(self) -> None:
source = f'AB\nCDE{utils.SEP}FG'
sentence, sep_indices = encode_data.normalize_input(source)
self.assertEqual(sentence, 'ABCDEFG')
self.assertEqual(sep_indices, {2, 5, 7})

def test_doubled_seps(self) -> None:
source = f'ABC{utils.SEP}{utils.SEP}DE\n\nFG'
sentence, sep_indices = encode_data.normalize_input(source)
self.assertEqual(sentence, 'ABCDEFG')
self.assertEqual(sep_indices, {3, 5, 7})

0 comments on commit 23119c7

Please sign in to comment.