Skip to content

Commit

Permalink
Add thres arg to Python CLI (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
tushuhei authored Jan 7, 2022
1 parent 5158e8e commit 56de659
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 4 deletions.
1 change: 1 addition & 0 deletions budoux/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@

Parser = parser.Parser
load_default_japanese_parser = parser.load_default_japanese_parser
DEFAULT_THRES = parser.DEFAULT_THRES
11 changes: 9 additions & 2 deletions budoux/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ def parse_args(test: ArgList = None) -> argparse.Namespace:
action="version",
version="%(prog)s {}".format(budoux.__version__),
)
parser.add_argument(
"--thres",
type=int,
default=budoux.DEFAULT_THRES,
help="threshold value to separate chunks (default: {})".format(
budoux.DEFAULT_THRES))
if test is not None:
return parser.parse_args(test)
else:
Expand All @@ -116,13 +122,14 @@ def _main(test: ArgList = None):
inputs = sys.stdin.read()
else:
inputs = args.text
res = parser.translate_html_string(inputs)
res = parser.translate_html_string(inputs, thres=args.thres)
else:
if args.text is None:
inputs = [v.rstrip() for v in sys.stdin.readlines()]
else:
inputs = [v.rstrip() for v in args.text.splitlines()]
res = ["\n".join(res) for res in map(parser.parse, inputs)]
outputs = [parser.parse(sentence, thres=args.thres) for sentence in inputs]
res = ["\n".join(res) for res in outputs]
ors = "\n" + args.delim + "\n"
res = ors.join(res)

Expand Down
5 changes: 3 additions & 2 deletions budoux/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

MODEL_DIR = os.path.join(os.path.dirname(__file__), 'models')
PARENT_CSS_STYLE = 'word-break: keep-all; overflow-wrap: break-word;'
DEFAULT_THRES = 1000
with open(os.path.join(os.path.dirname(__file__), 'skip_nodes.json')) as f:
SKIP_NODES: typing.Set[str] = set(json.load(f))

Expand Down Expand Up @@ -109,7 +110,7 @@ def __init__(self, model: typing.Dict[str, int]):
"""
self.model = model

def parse(self, sentence: str, thres: int = 1000):
def parse(self, sentence: str, thres: int = DEFAULT_THRES):
"""Parses the input sentence and returns a list of semantic chunks.
Args:
Expand Down Expand Up @@ -146,7 +147,7 @@ def parse(self, sentence: str, thres: int = 1000):
p3 = p
return chunks

def translate_html_string(self, html: str, thres: int = 1000):
def translate_html_string(self, html: str, thres: int = DEFAULT_THRES):
"""Translates the given HTML string with markups for semantic line breaks.
Args:
Expand Down
14 changes: 14 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,20 @@ def test_cmdargs_multi_html(self):

self.assertEqual(cm.exception.code, 2)

def test_cmdargs_thres(self):
cmdargs = ['--thres=0', '今日はとても天気です。']
output_granular = main._main(cmdargs)
cmdargs = ['--thres=10000000', '今日はとても天気です。']
output_whole = main._main(cmdargs)
self.assertGreater(
len(output_granular), len(output_whole),
'Chunks should be more granular when a smaller threshold value is given.'
)
self.assertEqual(
''.join(output_granular.split('\n')), ''.join(output_whole.split('\n')),
'The output sentence should be the same regardless of the threshold value.'
)


class TestStdin(unittest.TestCase):

Expand Down

0 comments on commit 56de659

Please sign in to comment.