From f18c5f39d0968b83dc702fa5e6486d10319cb591 Mon Sep 17 00:00:00 2001 From: Shuhei Iitsuka Date: Tue, 12 Dec 2023 09:30:51 +0000 Subject: [PATCH 1/2] Add the scale argument to encode_data.py --- scripts/encode_data.py | 15 ++++++++++++--- scripts/tests/test_encode_data.py | 27 +++++++++++++++++++-------- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/scripts/encode_data.py b/scripts/encode_data.py index fcb5714c..b948fe8d 100644 --- a/scripts/encode_data.py +++ b/scripts/encode_data.py @@ -96,19 +96,27 @@ def parse_args(test: ArgList = None) -> argparse.Namespace: help='''Number of processes to use. (default: the number of CPUs in the system)''', default=None) + parser.add_argument( + '--scale', + type=int, + help='''Weight scale for the entrties. The value should be a unsigned + integer. (default: 1)''', + default=1) if test is None: return parser.parse_args() else: return parser.parse_args(test) -def process(i: int, sentence: str, sep_indices: typing.Set[int]) -> str: +def process(i: int, sentence: str, sep_indices: typing.Set[int], + scale: int) -> str: """Outputs an encoded line of features from the given index. Args: i (int): index sentence (str): A sentence sep_indices (typing.Set[int]): A set of separator indices. + scale (int): A weight scale for the entries. """ feature = get_feature(sentence[i - 3] if i > 2 else INVALID, sentence[i - 2] if i > 1 else INVALID, sentence[i - 1], @@ -116,7 +124,7 @@ def process(i: int, sentence: str, sep_indices: typing.Set[int]) -> str: sentence[i + 1] if i + 1 < len(sentence) else INVALID, sentence[i + 2] if i + 2 < len(sentence) else INVALID) positive = i in sep_indices - line = '\t'.join(['1' if positive else '-1'] + feature) + line = '\t'.join(['%d' % (scale) if positive else '%d' % (-scale)] + feature) return line @@ -142,12 +150,13 @@ def main(test: ArgList = None) -> None: source_filename: str = args.source_data entries_filename: str = args.outfile processes = None if args.processes is None else int(args.processes) + scale: int = args.scale with open(source_filename, encoding=sys.getdefaultencoding()) as f: data = f.read() sentence, sep_indices = normalize_input(data) with multiprocessing.Pool(processes) as p: func = functools.partial( - process, sentence=sentence, sep_indices=sep_indices) + process, sentence=sentence, sep_indices=sep_indices, scale=scale) lines = p.map(func, range(1, len(sentence) + 1)) with open(entries_filename, 'w', encoding=sys.getdefaultencoding()) as f: diff --git a/scripts/tests/test_encode_data.py b/scripts/tests/test_encode_data.py index 2bdc14ed..de376e19 100644 --- a/scripts/tests/test_encode_data.py +++ b/scripts/tests/test_encode_data.py @@ -97,6 +97,7 @@ def test_cmdargs_default(self) -> None: self.assertEqual(output.source_data, 'source.txt') self.assertEqual(output.outfile, encode_data.DEFAULT_OUTPUT_FILENAME) self.assertIsNone(output.processes) + self.assertEqual(output.scale, 1) def test_cmdargs_with_outfile(self) -> None: cmdargs = ['source.txt', '-o', 'out.txt'] @@ -104,6 +105,7 @@ def test_cmdargs_with_outfile(self) -> None: self.assertEqual(output.source_data, 'source.txt') self.assertEqual(output.outfile, 'out.txt') self.assertIsNone(output.processes) + self.assertEqual(output.scale, 1) def test_cmdargs_with_processes(self) -> None: cmdargs = ['source.txt', '--processes', '8'] @@ -111,6 +113,15 @@ def test_cmdargs_with_processes(self) -> None: self.assertEqual(output.source_data, 'source.txt') self.assertEqual(output.outfile, encode_data.DEFAULT_OUTPUT_FILENAME) self.assertEqual(output.processes, 8) + self.assertEqual(output.scale, 1) + + def test_cmdargs_with_scale(self) -> None: + cmdargs = ['source.txt', '--scale', '20'] + output = encode_data.parse_args(cmdargs) + self.assertEqual(output.source_data, 'source.txt') + self.assertEqual(output.outfile, encode_data.DEFAULT_OUTPUT_FILENAME) + self.assertIsNone(output.processes) + self.assertEqual(output.scale, 20) class TestProcess(unittest.TestCase): @@ -118,20 +129,20 @@ class TestProcess(unittest.TestCase): sentence = '六本木ヒルズでお昼を食べる。' sep_indices = {7, 10, 13} - def test_on_positive_point(self) -> None: - line = encode_data.process(8, self.sentence, self.sep_indices) + def test_on_negative_point_with_scale(self) -> None: + line = encode_data.process(8, self.sentence, self.sep_indices, 16) items = line.split('\t') - positive = items[0] + weight = items[0] features = set(items[1:]) - self.assertEqual(positive, '-1') + self.assertEqual(weight, '-16') self.assertIn('UW2:で', features) - def test_on_negative_point(self) -> None: - line = encode_data.process(7, self.sentence, self.sep_indices) + def test_on_positive_point_with_scale(self) -> None: + line = encode_data.process(7, self.sentence, self.sep_indices, 13) items = line.split('\t') - positive = items[0] + weight = items[0] features = set(items[1:]) - self.assertEqual(positive, '1') + self.assertEqual(weight, '13') self.assertIn('UW3:で', features) From 597acf01440d7d26d074b5132751c4683d78187d Mon Sep 17 00:00:00 2001 From: Shuhei Iitsuka Date: Tue, 12 Dec 2023 18:41:51 +0900 Subject: [PATCH 2/2] Fix typo Co-authored-by: Koji Ishii --- scripts/encode_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/encode_data.py b/scripts/encode_data.py index b948fe8d..019dfafd 100644 --- a/scripts/encode_data.py +++ b/scripts/encode_data.py @@ -99,7 +99,7 @@ def parse_args(test: ArgList = None) -> argparse.Namespace: parser.add_argument( '--scale', type=int, - help='''Weight scale for the entrties. The value should be a unsigned + help='''Weight scale for the entries. The value should be a unsigned integer. (default: 1)''', default=1) if test is None: