diff --git a/jarvis_util/util/argparse.py b/jarvis_util/util/argparse.py index 59c587f..175d1d3 100644 --- a/jarvis_util/util/argparse.py +++ b/jarvis_util/util/argparse.py @@ -10,6 +10,52 @@ from tabulate import tabulate +class PatternTree: + def __init__(self): + self.pattern = {} + + def add_menu(self, menu): + alias_to = None + for alias_str, alias_toks in menu['aliases']: + alias_to = self._add_menu(menu, self.pattern, alias_toks, alias_to) + + def _add_menu(self, menu, pattern, toks, alias_to): + tok = toks[0] + if tok not in pattern: + if len(toks) == 1 and alias_to is not None: + pattern[tok] = alias_to + else: + pattern[tok] = {} + if len(toks) == 1: + pattern[tok]['__menu'] = menu + return pattern[tok] + self._add_menu(menu, pattern[tok], toks[1:], alias_to) + + def get_default_menu(self): + if '' in self.pattern: + return self.pattern['']['__menu'] + return None + + def match_pattern(self, toks): + return self._match_pattern(toks, self.pattern, 0) + + def _match_pattern(self, toks, pattern, depth, last_match=(0, None)): + if len(toks) == 0: + return last_match + tok = toks[0] + if tok in pattern: + if '__menu' in pattern[tok]: + last_match = (depth + 1, pattern[tok]) + last_match = self._match_pattern( + toks[1:], pattern[tok], depth + 1, last_match) + return last_match + def __hash__(self): + return self.hash + + def __eq__(self, other): + return other.alias == self.alias + + class ArgParse(ABC): """ A class for parsing command line arguments. @@ -35,7 +81,7 @@ def __init__(self, args=None, exit_on_fail=True, **custom_info): self.error = None self.exit_on_fail = exit_on_fail self.custom_info = custom_info - self.menus = [] + self.menus = PatternTree() self.vars = {} self.remainder = [] self.remainder_kv = {} @@ -136,7 +182,7 @@ def add_menu(self, name=None, msg=None, 'is_cmd': is_cmd, 'aliases': full_aliases } - self.menus.append(menu) + self.menus.add_menu(menu) self.menu = menu @staticmethod @@ -254,42 +300,6 @@ def default_kwargs(menu_args): kwargs[arg['name']] = None return kwargs - def _match_aliases(self, menus, i, arg_i): - """ - Convert the alias "arg_i" into the original menu name - - :param menus: The list of menus being traversed - :param i: The position within the menu name (name_toks) we are comparing - against arg_i - :param arg_i: The argument at offset i in the input args (self.args) - :return: A list of names - """ - matches = set() - for a, b, menu in menus: - for alias_str, alias_toks in menu['aliases']: - if i < len(alias_toks) and alias_toks[i] == arg_i: - matches.add(menu['name_toks'][i]) - break - return list(matches) - - def _match_menus(self, menus, i, arg_is): - """ - Find the subset of menus where the ith parameter of the menu's - name matches - - :param menus: - :param i: - :param arg_is: - :return: - """ - matches = {} - for a, b, menu in menus: - for alias_str, alias_toks in menu['aliases']: - if i < len(alias_toks) and alias_toks[i] in arg_is: - matches[alias_str] = (alias_str, alias_toks, menu) - break - return list(matches.values()) - def _parse_menu(self): """ Determine which menu is used in the CLI. @@ -300,32 +310,21 @@ def _parse_menu(self): """ # Identify the menu we are currently under self.menu = None - matches = list([(alias_str, alias_toks, menu) - for menu in self.menus - for alias_str, alias_toks in menu['aliases']]) - i = 0 - # Iteratively filter out matches - while i < len(self.args): - alias_matches = self._match_aliases(matches, i, self.args[i]) - new_matches = self._match_menus(matches, i, alias_matches) - if len(new_matches) == 1: - self.menu = new_matches[0][2] - menu_alias = (new_matches[0][0], new_matches[0][1]) - self.args = self.args[len(menu_alias[1]):] - break - if len(new_matches) == 0: - break - matches = new_matches - i += 1 + # Filter out matches + depth, menus = self.menus.match_pattern(self.args) # If there was nothing remotely close, try default menu - if i == 0 and self.menu is None: - for menu in self.menus: - if menu['name_str'] == '': - self.menu = menu + if menus is None: + self.menu = self.menus.get_default_menu() + else: + self.menu = menus['__menu'] + self.args = self.args[depth:] # If there was nothing at all, error now if self.menu is None or not self.menu['is_cmd']: - if i > 0: - matches.sort(key=lambda x: len(x[1])) + if depth > 0: + matches = [self.menu] + matches += [pattern['__menu'] + for tok, pattern in menus.items() + if tok != '__menu' and '__menu' in pattern] else: matches = [] self._invalid_menu(matches) @@ -515,10 +514,6 @@ def _get_opt_name(self, opt_name): return opt_name def _invalid_menu(self, matches): - if len(matches) == 0: - menu_name = '' - else: - menu_name = matches[0][0] self._print_error('', matches=matches) def _invalid_choice(self, opt_name, arg): @@ -542,7 +537,8 @@ def _print_menu_error(self, msg): def _print_error(self, msg, matches=None): - print(f'{msg}') + if len(msg): + print(f'{msg}') self._print_help(matches) if self.exit_on_fail: sys.exit(1) @@ -552,7 +548,7 @@ def _print_error(self, msg, def _print_help(self, matches=None): self.needed_help = True - if self.menu is not None: + if matches is None: self._print_menu_help() else: self._print_menus(matches) @@ -565,12 +561,16 @@ def _print_menus(self, matches): :return: """ if len(matches) == 0: - for menu in self.menus: + menus = [] + if len(self.menus.pattern): + menus = [pattern['__menu'] + for tok, pattern in self.menus.pattern.items() + if tok != '__menu' and '__menu' in pattern] + for menu in menus: self.menu = menu self._print_menu_help(True, max_len=1) else: - menus = list({menu['name_str']:menu for alias_str, alias_toks, menu in matches}.values()) - for menu in menus: + for menu in matches: self.menu = menu self._print_menu_help(True)