From 56de65928246305182e9f29d2458f89becf8d3b3 Mon Sep 17 00:00:00 2001 From: Shuhei Iitsuka Date: Fri, 7 Jan 2022 14:37:52 +0900 Subject: [PATCH] Add thres arg to Python CLI (#32) --- budoux/__init__.py | 1 + budoux/main.py | 11 +++++++++-- budoux/parser.py | 5 +++-- tests/test_main.py | 14 ++++++++++++++ 4 files changed, 27 insertions(+), 4 deletions(-) diff --git a/budoux/__init__.py b/budoux/__init__.py index 2e0dd903..0bf1f115 100644 --- a/budoux/__init__.py +++ b/budoux/__init__.py @@ -19,3 +19,4 @@ Parser = parser.Parser load_default_japanese_parser = parser.load_default_japanese_parser +DEFAULT_THRES = parser.DEFAULT_THRES diff --git a/budoux/main.py b/budoux/main.py index fa9f251c..600a787d 100644 --- a/budoux/main.py +++ b/budoux/main.py @@ -98,6 +98,12 @@ def parse_args(test: ArgList = None) -> argparse.Namespace: action="version", version="%(prog)s {}".format(budoux.__version__), ) + parser.add_argument( + "--thres", + type=int, + default=budoux.DEFAULT_THRES, + help="threshold value to separate chunks (default: {})".format( + budoux.DEFAULT_THRES)) if test is not None: return parser.parse_args(test) else: @@ -116,13 +122,14 @@ def _main(test: ArgList = None): inputs = sys.stdin.read() else: inputs = args.text - res = parser.translate_html_string(inputs) + res = parser.translate_html_string(inputs, thres=args.thres) else: if args.text is None: inputs = [v.rstrip() for v in sys.stdin.readlines()] else: inputs = [v.rstrip() for v in args.text.splitlines()] - res = ["\n".join(res) for res in map(parser.parse, inputs)] + outputs = [parser.parse(sentence, thres=args.thres) for sentence in inputs] + res = ["\n".join(res) for res in outputs] ors = "\n" + args.delim + "\n" res = ors.join(res) diff --git a/budoux/parser.py b/budoux/parser.py index 66d3800a..651b58b5 100644 --- a/budoux/parser.py +++ b/budoux/parser.py @@ -22,6 +22,7 @@ MODEL_DIR = os.path.join(os.path.dirname(__file__), 'models') PARENT_CSS_STYLE = 'word-break: keep-all; overflow-wrap: break-word;' +DEFAULT_THRES = 1000 with open(os.path.join(os.path.dirname(__file__), 'skip_nodes.json')) as f: SKIP_NODES: typing.Set[str] = set(json.load(f)) @@ -109,7 +110,7 @@ def __init__(self, model: typing.Dict[str, int]): """ self.model = model - def parse(self, sentence: str, thres: int = 1000): + def parse(self, sentence: str, thres: int = DEFAULT_THRES): """Parses the input sentence and returns a list of semantic chunks. Args: @@ -146,7 +147,7 @@ def parse(self, sentence: str, thres: int = 1000): p3 = p return chunks - def translate_html_string(self, html: str, thres: int = 1000): + def translate_html_string(self, html: str, thres: int = DEFAULT_THRES): """Translates the given HTML string with markups for semantic line breaks. Args: diff --git a/tests/test_main.py b/tests/test_main.py index 88a8c652..bc1766d8 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -106,6 +106,20 @@ def test_cmdargs_multi_html(self): self.assertEqual(cm.exception.code, 2) + def test_cmdargs_thres(self): + cmdargs = ['--thres=0', '今日はとても天気です。'] + output_granular = main._main(cmdargs) + cmdargs = ['--thres=10000000', '今日はとても天気です。'] + output_whole = main._main(cmdargs) + self.assertGreater( + len(output_granular), len(output_whole), + 'Chunks should be more granular when a smaller threshold value is given.' + ) + self.assertEqual( + ''.join(output_granular.split('\n')), ''.join(output_whole.split('\n')), + 'The output sentence should be the same regardless of the threshold value.' + ) + class TestStdin(unittest.TestCase):