Skip to content

Commit

Permalink
More robust aliasing
Browse files Browse the repository at this point in the history
  • Loading branch information
lukemartinlogan committed Feb 10, 2024
1 parent 22e674d commit dc7edcb
Showing 1 changed file with 70 additions and 70 deletions.
140 changes: 70 additions & 70 deletions jarvis_util/util/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,52 @@
from tabulate import tabulate


class PatternTree:
def __init__(self):
self.pattern = {}

def add_menu(self, menu):
alias_to = None
for alias_str, alias_toks in menu['aliases']:
alias_to = self._add_menu(menu, self.pattern, alias_toks, alias_to)

def _add_menu(self, menu, pattern, toks, alias_to):
tok = toks[0]
if tok not in pattern:
if len(toks) == 1 and alias_to is not None:
pattern[tok] = alias_to
else:
pattern[tok] = {}
if len(toks) == 1:
pattern[tok]['__menu'] = menu
return pattern[tok]
self._add_menu(menu, pattern[tok], toks[1:], alias_to)

def get_default_menu(self):
if '' in self.pattern:
return self.pattern['']['__menu']
return None

def match_pattern(self, toks):
return self._match_pattern(toks, self.pattern, 0)

def _match_pattern(self, toks, pattern, depth, last_match=(0, None)):
if len(toks) == 0:
return last_match
tok = toks[0]
if tok in pattern:
if '__menu' in pattern[tok]:
last_match = (depth + 1, pattern[tok])
last_match = self._match_pattern(
toks[1:], pattern[tok], depth + 1, last_match)
return last_match
def __hash__(self):
return self.hash

def __eq__(self, other):
return other.alias == self.alias


class ArgParse(ABC):
"""
A class for parsing command line arguments.
Expand All @@ -35,7 +81,7 @@ def __init__(self, args=None, exit_on_fail=True, **custom_info):
self.error = None
self.exit_on_fail = exit_on_fail
self.custom_info = custom_info
self.menus = []
self.menus = PatternTree()
self.vars = {}
self.remainder = []
self.remainder_kv = {}
Expand Down Expand Up @@ -136,7 +182,7 @@ def add_menu(self, name=None, msg=None,
'is_cmd': is_cmd,
'aliases': full_aliases
}
self.menus.append(menu)
self.menus.add_menu(menu)
self.menu = menu

@staticmethod
Expand Down Expand Up @@ -254,42 +300,6 @@ def default_kwargs(menu_args):
kwargs[arg['name']] = None
return kwargs

def _match_aliases(self, menus, i, arg_i):
"""
Convert the alias "arg_i" into the original menu name
:param menus: The list of menus being traversed
:param i: The position within the menu name (name_toks) we are comparing
against arg_i
:param arg_i: The argument at offset i in the input args (self.args)
:return: A list of names
"""
matches = set()
for a, b, menu in menus:
for alias_str, alias_toks in menu['aliases']:
if i < len(alias_toks) and alias_toks[i] == arg_i:
matches.add(menu['name_toks'][i])
break
return list(matches)

def _match_menus(self, menus, i, arg_is):
"""
Find the subset of menus where the ith parameter of the menu's
name matches
:param menus:
:param i:
:param arg_is:
:return:
"""
matches = {}
for a, b, menu in menus:
for alias_str, alias_toks in menu['aliases']:
if i < len(alias_toks) and alias_toks[i] in arg_is:
matches[alias_str] = (alias_str, alias_toks, menu)
break
return list(matches.values())

def _parse_menu(self):
"""
Determine which menu is used in the CLI.
Expand All @@ -300,32 +310,21 @@ def _parse_menu(self):
"""
# Identify the menu we are currently under
self.menu = None
matches = list([(alias_str, alias_toks, menu)
for menu in self.menus
for alias_str, alias_toks in menu['aliases']])
i = 0
# Iteratively filter out matches
while i < len(self.args):
alias_matches = self._match_aliases(matches, i, self.args[i])
new_matches = self._match_menus(matches, i, alias_matches)
if len(new_matches) == 1:
self.menu = new_matches[0][2]
menu_alias = (new_matches[0][0], new_matches[0][1])
self.args = self.args[len(menu_alias[1]):]
break
if len(new_matches) == 0:
break
matches = new_matches
i += 1
# Filter out matches
depth, menus = self.menus.match_pattern(self.args)
# If there was nothing remotely close, try default menu
if i == 0 and self.menu is None:
for menu in self.menus:
if menu['name_str'] == '':
self.menu = menu
if menus is None:
self.menu = self.menus.get_default_menu()
else:
self.menu = menus['__menu']
self.args = self.args[depth:]
# If there was nothing at all, error now
if self.menu is None or not self.menu['is_cmd']:
if i > 0:
matches.sort(key=lambda x: len(x[1]))
if depth > 0:
matches = [self.menu]
matches += [pattern['__menu']
for tok, pattern in menus.items()
if tok != '__menu' and '__menu' in pattern]
else:
matches = []
self._invalid_menu(matches)
Expand Down Expand Up @@ -515,10 +514,6 @@ def _get_opt_name(self, opt_name):
return opt_name

def _invalid_menu(self, matches):
if len(matches) == 0:
menu_name = ''
else:
menu_name = matches[0][0]
self._print_error('', matches=matches)

def _invalid_choice(self, opt_name, arg):
Expand All @@ -542,7 +537,8 @@ def _print_menu_error(self, msg):

def _print_error(self, msg,
matches=None):
print(f'{msg}')
if len(msg):
print(f'{msg}')
self._print_help(matches)
if self.exit_on_fail:
sys.exit(1)
Expand All @@ -552,7 +548,7 @@ def _print_error(self, msg,
def _print_help(self,
matches=None):
self.needed_help = True
if self.menu is not None:
if matches is None:
self._print_menu_help()
else:
self._print_menus(matches)
Expand All @@ -565,12 +561,16 @@ def _print_menus(self, matches):
:return:
"""
if len(matches) == 0:
for menu in self.menus:
menus = []
if len(self.menus.pattern):
menus = [pattern['__menu']
for tok, pattern in self.menus.pattern.items()
if tok != '__menu' and '__menu' in pattern]
for menu in menus:
self.menu = menu
self._print_menu_help(True, max_len=1)
else:
menus = list({menu['name_str']:menu for alias_str, alias_toks, menu in matches}.values())
for menu in menus:
for menu in matches:
self.menu = menu
self._print_menu_help(True)

Expand Down

0 comments on commit dc7edcb

Please sign in to comment.