From a348e562e99927374778a0a459d287c93e8e10c6 Mon Sep 17 00:00:00 2001 From: Brian Pugh Date: Fri, 15 Nov 2024 18:39:11 -0500 Subject: [PATCH 1/4] replicate list[bool] bug; issue #249 --- tests/test_bind_list.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/test_bind_list.py b/tests/test_bind_list.py index 1524f010..23794ae9 100644 --- a/tests/test_bind_list.py +++ b/tests/test_bind_list.py @@ -47,6 +47,27 @@ def foo(a: Optional[List[int]] = None): assert_parse_args(foo, "foo") +@pytest.mark.parametrize( + "cmd_expected", + [ + ("", None), + ("--verbose", [True]), + ("--verbose --verbose", [True, True]), + ], +) +def test_keyword_list_of_bool(app, assert_parse_args, cmd_expected): + cmd, expected = cmd_expected + + @app.default + def foo(*, verbose: Optional[list[bool]] = None): + pass + + if expected is None: + assert_parse_args(foo, cmd) + else: + assert_parse_args(foo, cmd, verbose=expected) + + @pytest.mark.parametrize( "cmd", [ From 2fa29fbbf936267315a5daab6d0dd3ae0b0c3146 Mon Sep 17 00:00:00 2001 From: Brian Pugh Date: Fri, 15 Nov 2024 22:19:27 -0500 Subject: [PATCH 2/4] Handle the general case of iterative boolean flags. --- cyclopts/_convert.py | 17 ++++++++++++++--- cyclopts/argument.py | 19 ++++++++++++++----- cyclopts/parameter.py | 31 ++++++++++++++++++++++++++----- tests/test_bind_list.py | 4 ++++ 4 files changed, 58 insertions(+), 13 deletions(-) diff --git a/cyclopts/_convert.py b/cyclopts/_convert.py index 78e3e6f2..d895545c 100644 --- a/cyclopts/_convert.py +++ b/cyclopts/_convert.py @@ -31,13 +31,22 @@ _implicit_iterable_type_mapping: dict[type, type] = { + Iterable: list[str], + Sequence: list[str], + frozenset: frozenset[str], list: list[str], set: set[str], tuple: tuple[str, ...], - dict: dict[str, str], } -ITERABLE_TYPES = {list, set, frozenset, Sequence, Iterable, tuple} +ITERABLE_TYPES = { + Iterable, + Sequence, + frozenset, + list, + set, + tuple, +} NestedCliArgs = dict[str, Union[Sequence[str], "NestedCliArgs"]] @@ -161,7 +170,9 @@ def _convert( origin_type = get_origin(type_) inner_types = [resolve(x) for x in get_args(type_)] - if type_ in _implicit_iterable_type_mapping: + if type_ is dict: + out = convert(dict[str, str], token) + elif type_ in _implicit_iterable_type_mapping: out = convert(_implicit_iterable_type_mapping[type_], token) elif origin_type in (collections.abc.Iterable, collections.abc.Sequence): assert len(inner_types) == 1 diff --git a/cyclopts/argument.py b/cyclopts/argument.py index c50e40ec..a875a9a9 100644 --- a/cyclopts/argument.py +++ b/cyclopts/argument.py @@ -43,7 +43,7 @@ get_field_infos, ) from cyclopts.group import Group -from cyclopts.parameter import Parameter +from cyclopts.parameter import ITERATIVE_BOOL_IMPLICIT_VALUE, Parameter from cyclopts.token import Token from cyclopts.utils import UNSET, ParameterDict, grouper, is_builtin @@ -827,7 +827,10 @@ def _match_name( name = transform(name) if _startswith(term, name): trailing = term[len(name) :] - implicit_value = True if self.hint is bool else None + if self.hint is bool or self.hint in ITERATIVE_BOOL_IMPLICIT_VALUE: + implicit_value = True + else: + implicit_value = None if trailing: if trailing[0] == delimiter: trailing = trailing[1:] @@ -843,7 +846,10 @@ def _match_name( name = transform(name) if term.startswith(name): trailing = term[len(name) :] - implicit_value = (get_origin(self.hint) or self.hint)() + if self.hint in ITERATIVE_BOOL_IMPLICIT_VALUE: + implicit_value = False + else: + implicit_value = (get_origin(self.hint) or self.hint)() if trailing: if trailing[0] == delimiter: trailing = trailing[1:] @@ -917,8 +923,11 @@ def safe_converter(hint, tokens): keyword = {} for token in self.tokens: if token.implicit_value is not UNSET: - assert len(self.tokens) == 1 - return token.implicit_value + if self.hint in ITERATIVE_BOOL_IMPLICIT_VALUE: + return get_origin(self.hint)(x.implicit_value for x in self.tokens) + else: + assert len(self.tokens) == 1 + return token.implicit_value if token.keys: lookup = keyword diff --git a/cyclopts/parameter.py b/cyclopts/parameter.py index b255abdf..2801bc20 100644 --- a/cyclopts/parameter.py +++ b/cyclopts/parameter.py @@ -1,7 +1,7 @@ import inspect from collections.abc import Iterable from functools import partial -from typing import Any, Callable, Optional, Union, cast, get_args, get_origin +from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast, get_args, get_origin import attrs from attrs import field, frozen @@ -18,7 +18,19 @@ to_tuple_converter, ) -_NEGATIVE_FLAG_TYPES = frozenset([bool, *ITERABLE_TYPES]) +ITERATIVE_BOOL_IMPLICIT_VALUE = frozenset( + { + Iterable[bool], + Sequence[bool], + List[bool], + list[bool], + Tuple[bool, ...], + tuple[bool, ...], + } +) + + +_NEGATIVE_FLAG_TYPES = frozenset([bool, *ITERABLE_TYPES, *ITERATIVE_BOOL_IMPLICIT_VALUE]) def _not_hyphen_validator(instance, attribute, values): @@ -162,10 +174,16 @@ def get_negatives(self, type_) -> tuple[str, ...]: if is_union(type_): type_ = next(x for x in get_args(type_) if x is not None) - type_ = get_origin(type_) or type_ + origin = get_origin(type_) - if (self.negative is not None and not self.negative) or type_ not in _NEGATIVE_FLAG_TYPES: + if self.negative is False: return () + if type_ not in _NEGATIVE_FLAG_TYPES: + if origin: + if origin not in _NEGATIVE_FLAG_TYPES: + return () + else: + return () out, user_negatives = [], [] if self.negative: @@ -182,7 +200,10 @@ def get_negatives(self, type_) -> tuple[str, ...]: name = name[2:] name_components = name.split(".") - negative_prefixes = self.negative_bool if type_ is bool else self.negative_iterable + if type_ is bool or type_ in ITERATIVE_BOOL_IMPLICIT_VALUE: + negative_prefixes = self.negative_bool + else: + negative_prefixes = self.negative_iterable name_prefix = ".".join(name_components[:-1]) if name_prefix: name_prefix += "." diff --git a/tests/test_bind_list.py b/tests/test_bind_list.py index 23794ae9..898fee26 100644 --- a/tests/test_bind_list.py +++ b/tests/test_bind_list.py @@ -53,6 +53,10 @@ def foo(a: Optional[List[int]] = None): ("", None), ("--verbose", [True]), ("--verbose --verbose", [True, True]), + ("--verbose --verbose --no-verbose", [True, True, False]), + ("--verbose --verbose=False", [True, False]), + ("--verbose --no-verbose=False", [True, True]), + ("--verbose --verbose=True", [True, True]), ], ) def test_keyword_list_of_bool(app, assert_parse_args, cmd_expected): From 872c17d510089408196dc60fabfbc643d30246db Mon Sep 17 00:00:00 2001 From: Brian Pugh Date: Tue, 19 Nov 2024 09:31:25 -0500 Subject: [PATCH 3/4] get rid of impossible check. --- cyclopts/parameter.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cyclopts/parameter.py b/cyclopts/parameter.py index 2801bc20..5a4e414c 100644 --- a/cyclopts/parameter.py +++ b/cyclopts/parameter.py @@ -97,7 +97,7 @@ def main(foo: Annotated[int, Parameter(name="bar")]): converter=lambda x: cast(tuple[Callable, ...], to_tuple_converter(x)), ) - # This can ONLY ever be a Tuple[str, ...] + # This can ONLY ever be ``None`` or ``Tuple[str, ...]`` negative: Union[None, str, Iterable[str]] = field(default=None, converter=optional_to_tuple_converter) # This can ONLY ever be a Tuple[Union[Group, str], ...] @@ -176,8 +176,6 @@ def get_negatives(self, type_) -> tuple[str, ...]: origin = get_origin(type_) - if self.negative is False: - return () if type_ not in _NEGATIVE_FLAG_TYPES: if origin: if origin not in _NEGATIVE_FLAG_TYPES: From a16ba52254629a9bbcd4143e122820c87b41255d Mon Sep 17 00:00:00 2001 From: Brian Pugh Date: Tue, 19 Nov 2024 09:33:53 -0500 Subject: [PATCH 4/4] inline if/else --- cyclopts/argument.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/cyclopts/argument.py b/cyclopts/argument.py index a875a9a9..996ea890 100644 --- a/cyclopts/argument.py +++ b/cyclopts/argument.py @@ -827,10 +827,7 @@ def _match_name( name = transform(name) if _startswith(term, name): trailing = term[len(name) :] - if self.hint is bool or self.hint in ITERATIVE_BOOL_IMPLICIT_VALUE: - implicit_value = True - else: - implicit_value = None + implicit_value = True if self.hint is bool or self.hint in ITERATIVE_BOOL_IMPLICIT_VALUE else None if trailing: if trailing[0] == delimiter: trailing = trailing[1:]