Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the scale argument to encode_data.py #408

Merged
merged 2 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions scripts/encode_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,27 +96,35 @@ def parse_args(test: ArgList = None) -> argparse.Namespace:
help='''Number of processes to use.
(default: the number of CPUs in the system)''',
default=None)
parser.add_argument(
'--scale',
type=int,
help='''Weight scale for the entries. The value should be a unsigned
integer. (default: 1)''',
default=1)
if test is None:
return parser.parse_args()
else:
return parser.parse_args(test)


def process(i: int, sentence: str, sep_indices: typing.Set[int]) -> str:
def process(i: int, sentence: str, sep_indices: typing.Set[int],
scale: int) -> str:
"""Outputs an encoded line of features from the given index.

Args:
i (int): index
sentence (str): A sentence
sep_indices (typing.Set[int]): A set of separator indices.
scale (int): A weight scale for the entries.
"""
feature = get_feature(sentence[i - 3] if i > 2 else INVALID,
sentence[i - 2] if i > 1 else INVALID, sentence[i - 1],
sentence[i] if i < len(sentence) else INVALID,
sentence[i + 1] if i + 1 < len(sentence) else INVALID,
sentence[i + 2] if i + 2 < len(sentence) else INVALID)
positive = i in sep_indices
line = '\t'.join(['1' if positive else '-1'] + feature)
line = '\t'.join(['%d' % (scale) if positive else '%d' % (-scale)] + feature)
return line


Expand All @@ -142,12 +150,13 @@ def main(test: ArgList = None) -> None:
source_filename: str = args.source_data
entries_filename: str = args.outfile
processes = None if args.processes is None else int(args.processes)
scale: int = args.scale
with open(source_filename, encoding=sys.getdefaultencoding()) as f:
data = f.read()
sentence, sep_indices = normalize_input(data)
with multiprocessing.Pool(processes) as p:
func = functools.partial(
process, sentence=sentence, sep_indices=sep_indices)
process, sentence=sentence, sep_indices=sep_indices, scale=scale)
lines = p.map(func, range(1, len(sentence) + 1))

with open(entries_filename, 'w', encoding=sys.getdefaultencoding()) as f:
Expand Down
27 changes: 19 additions & 8 deletions scripts/tests/test_encode_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,41 +97,52 @@ def test_cmdargs_default(self) -> None:
self.assertEqual(output.source_data, 'source.txt')
self.assertEqual(output.outfile, encode_data.DEFAULT_OUTPUT_FILENAME)
self.assertIsNone(output.processes)
self.assertEqual(output.scale, 1)

def test_cmdargs_with_outfile(self) -> None:
cmdargs = ['source.txt', '-o', 'out.txt']
output = encode_data.parse_args(cmdargs)
self.assertEqual(output.source_data, 'source.txt')
self.assertEqual(output.outfile, 'out.txt')
self.assertIsNone(output.processes)
self.assertEqual(output.scale, 1)

def test_cmdargs_with_processes(self) -> None:
cmdargs = ['source.txt', '--processes', '8']
output = encode_data.parse_args(cmdargs)
self.assertEqual(output.source_data, 'source.txt')
self.assertEqual(output.outfile, encode_data.DEFAULT_OUTPUT_FILENAME)
self.assertEqual(output.processes, 8)
self.assertEqual(output.scale, 1)

def test_cmdargs_with_scale(self) -> None:
cmdargs = ['source.txt', '--scale', '20']
output = encode_data.parse_args(cmdargs)
self.assertEqual(output.source_data, 'source.txt')
self.assertEqual(output.outfile, encode_data.DEFAULT_OUTPUT_FILENAME)
self.assertIsNone(output.processes)
self.assertEqual(output.scale, 20)


class TestProcess(unittest.TestCase):

sentence = '六本木ヒルズでお昼を食べる。'
sep_indices = {7, 10, 13}

def test_on_positive_point(self) -> None:
line = encode_data.process(8, self.sentence, self.sep_indices)
def test_on_negative_point_with_scale(self) -> None:
line = encode_data.process(8, self.sentence, self.sep_indices, 16)
items = line.split('\t')
positive = items[0]
weight = items[0]
features = set(items[1:])
self.assertEqual(positive, '-1')
self.assertEqual(weight, '-16')
self.assertIn('UW2:で', features)

def test_on_negative_point(self) -> None:
line = encode_data.process(7, self.sentence, self.sep_indices)
def test_on_positive_point_with_scale(self) -> None:
line = encode_data.process(7, self.sentence, self.sep_indices, 13)
items = line.split('\t')
positive = items[0]
weight = items[0]
features = set(items[1:])
self.assertEqual(positive, '1')
self.assertEqual(weight, '13')
self.assertIn('UW3:で', features)


Expand Down
Loading