-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
271 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
""" | ||
``ApeGeargs`` <-> ``func_argparse`` integration. | ||
This module customizes the ``func_argparse`` | ||
argparser-generator to generate an ``ApeGeargs`` argparser. | ||
To activate, simply import ``apegears.func_argparse`` instead of ``func_argparse``. | ||
""" | ||
|
||
import collections | ||
|
||
import func_argparse | ||
from func_argparse import ( | ||
ArgparserGenerator as _ArgparserGenerator, | ||
ArgumentSpec as _ArgumentSpec, | ||
_is_option_type, _GenericAlias) | ||
|
||
from .parser import ArgumentParser | ||
|
||
|
||
################################################################################ | ||
# Definition of the our custom ArgumentParser generator | ||
|
||
class ApegearsArgumentSpec(_ArgumentSpec): | ||
|
||
def __init__(self, adder_name, *flags, **kwargs): | ||
self.adder_name = adder_name | ||
super().__init__(*flags, **kwargs) | ||
|
||
def add_to_parser(self, parser): | ||
adder = getattr(parser, self.adder_name) | ||
adder(*self.flags, **self.kwargs) | ||
|
||
|
||
class ApegearsGenerator(_ArgparserGenerator): | ||
|
||
ArgParser = ArgumentParser | ||
|
||
def _gen_param_arguments(self, arg_name, arg_type, doc, default, has_default, prefix): | ||
|
||
a = arg_name | ||
t = arg_type | ||
|
||
flags = [a] | ||
if prefix is not None and prefix != a: | ||
flags = [prefix] + flags | ||
|
||
kwargs = dict( | ||
help=doc, | ||
) | ||
|
||
if t is bool: | ||
adder = 'add_flag' | ||
|
||
else: | ||
|
||
required = not has_default | ||
if required and _is_option_type(t): | ||
t = t.__args__[0] | ||
required = False | ||
if not has_default: | ||
default = None | ||
has_default = True | ||
|
||
kwargs['required'] = required | ||
if has_default: | ||
kwargs['default'] = default | ||
|
||
adder = 'add_optional' | ||
|
||
# try list option | ||
elem_t = _get_list_contained_type(t) | ||
if elem_t is not None: | ||
adder = 'add_list' | ||
kwargs.update(type=_get_type(elem_t), required=False) | ||
|
||
else: | ||
# try dict option | ||
ktvt = _get_dict_contained_types(t) | ||
if ktvt is not None: | ||
adder = 'add_dict' | ||
kt, vt = ktvt | ||
kwargs.update(key_type=_get_type(kt), type=_get_type(vt), required=False) | ||
|
||
else: | ||
kwargs['type'] = _get_type(t) | ||
|
||
yield ApegearsArgumentSpec(adder, *flags, **kwargs) | ||
|
||
|
||
def _get_type(t): | ||
# all supported types are already directly supported by our ArgumentParser | ||
return t | ||
|
||
|
||
def _get_list_contained_type(t): | ||
if not isinstance(t, _GenericAlias): | ||
return None | ||
if t.__origin__ not in (list, collections.abc.Sequence): | ||
return None | ||
contained = t.__args__[0] | ||
assert isinstance(contained, type) | ||
return contained | ||
|
||
|
||
def _get_dict_contained_types(t): | ||
if not isinstance(t, _GenericAlias): | ||
return None | ||
if t.__origin__ not in (dict, collections.abc.Mapping): | ||
return None | ||
kt, vt = t.__args__ | ||
assert isinstance(kt, type) | ||
assert isinstance(vt, type) | ||
return kt, vt | ||
|
||
|
||
################################################################################ | ||
# bootstrapping | ||
|
||
# activate our custom generator: | ||
func_argparse.set_default_generator(ApegearsGenerator) | ||
|
||
# make any name importable from here, so users can change any line like | ||
# `from func_argparse import ...` | ||
# to | ||
# `from apegears.func_argparse import ...` | ||
from func_argparse import * | ||
|
||
|
||
################################################################################ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# for integration with various argument-parser generators | ||
func_argparse |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
""" | ||
Unit-tests for integration of the argparser with func_argparse generator. | ||
""" | ||
|
||
import unittest | ||
import datetime | ||
from enum import Enum | ||
from typing import List, Dict, Union | ||
from collections import OrderedDict | ||
|
||
from apegears.func_argparse import func_argparser, make_single_main | ||
|
||
|
||
################################################################################ | ||
|
||
class Type1: | ||
|
||
def __init__(self, val): | ||
self.val = val | ||
|
||
@classmethod | ||
def from_string(cls, x): | ||
return cls(float(x)) | ||
|
||
|
||
Type1.__argparse__ = dict( | ||
from_string=Type1.from_string, | ||
default='-1', | ||
names=['type1', 't'], | ||
help='a Type1 object', | ||
) | ||
|
||
|
||
class Enum1(Enum): | ||
foo = 1 | ||
bar = 22 | ||
coo = 333 | ||
|
||
|
||
################################################################################ | ||
|
||
class FuncArgparseTest(unittest.TestCase): | ||
""" | ||
Tests integration with the func_argparse generator. | ||
""" | ||
|
||
################################################################################ | ||
|
||
def test_basic(self): | ||
def foo(x: int, pretty: bool, z: float = 2.5): | ||
pass | ||
|
||
p = func_argparser(foo) | ||
|
||
def P(args): | ||
return p.parse_args(args.split()) | ||
|
||
self.assertEqual(P('-x 7').x, 7) | ||
self.assertEqual(P('-x 7').z, 2.5) | ||
self.assertEqual(P('-x 7').pretty, False) | ||
self.assertEqual(P('-x 7 --pretty').pretty, True) | ||
self.assertEqual(P('-x 7 -p').pretty, True) | ||
self.assertEqual(P('-x 7 --no-pretty').pretty, False) | ||
self.assertEqual(P('-x 7 -z 9.5').z, 9.5) | ||
self.assertRaises(SystemExit, P, '') # x is required | ||
self.assertRaises(SystemExit, P, '-x aaa') # not an int value | ||
|
||
def test_list_and_enum(self): | ||
def foo(x: List[Enum1]): | ||
pass | ||
|
||
p = func_argparser(foo) | ||
|
||
def P(args): | ||
return p.parse_args(args.split()) | ||
|
||
self.assertEqual(P('').x, []) | ||
self.assertEqual(P('-x bar coo').x, [Enum1.bar, Enum1.coo]) | ||
self.assertEqual(P('-x bar -x coo').x, [Enum1.bar, Enum1.coo]) | ||
self.assertRaises(SystemExit, P, '-x aaa') # not an Enum1 value | ||
|
||
def test_dict_and_standard_type(self): | ||
def foo(x: Dict[int, datetime.date]): | ||
pass | ||
|
||
p = func_argparser(foo) | ||
|
||
def P(args): | ||
return p.parse_args(args.split()) | ||
|
||
d1 = datetime.date(2001, 2, 3) | ||
d2 = datetime.date(2004, 5, 6) | ||
|
||
self.assertEqual(P('').x, OrderedDict()) | ||
self.assertEqual(P('-x 1=%s 2=%s' % (d1, d2)).x, OrderedDict([(1, d1), (2, d2)])) | ||
self.assertEqual(P('-x 2=%s' % d2).x, OrderedDict([(2, d2)])) | ||
self.assertRaises(SystemExit, P, '-x qqq=%s' % d1) # not an int key | ||
self.assertRaises(SystemExit, P, '-x 5=200') # not a date value | ||
|
||
def test_custom_type(self): | ||
def foo(x: Type1 = None): | ||
pass | ||
|
||
p = func_argparser(foo) | ||
|
||
def P(args): | ||
return p.parse_args(args.split()) | ||
|
||
self.assertEqual(P('').x, None) | ||
self.assertEqual(P('-x 4.5').x.val, 4.5) | ||
|
||
def test_union_with_none(self): | ||
def foo(x: Union[Type1, None]): | ||
pass | ||
|
||
p = func_argparser(foo) | ||
|
||
def P(args): | ||
return p.parse_args(args.split()) | ||
|
||
self.assertEqual(P('').x, None) | ||
self.assertEqual(P('-x 4.5').x.val, 4.5) | ||
|
||
def test_main(self): | ||
|
||
def foo(x: int, pretty: bool, z: float = 2.5): | ||
return dict(locals()) | ||
|
||
res = make_single_main(foo)('-x 5'.split()) | ||
self.assertEqual(res, dict(x=5, pretty=False, z=2.5)) | ||
res = make_single_main(foo)('-x 5 --pretty -z 3.5'.split()) | ||
self.assertEqual(res, dict(x=5, pretty=True, z=3.5)) | ||
|
||
|
||
################################################################################ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters