Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for list[bool] and similar (list of flags). #251

Merged
merged 4 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions cyclopts/_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,22 @@


_implicit_iterable_type_mapping: dict[type, type] = {
Iterable: list[str],
Sequence: list[str],
frozenset: frozenset[str],
list: list[str],
set: set[str],
tuple: tuple[str, ...],
dict: dict[str, str],
}

ITERABLE_TYPES = {list, set, frozenset, Sequence, Iterable, tuple}
ITERABLE_TYPES = {
Iterable,
Sequence,
frozenset,
list,
set,
tuple,
}

NestedCliArgs = dict[str, Union[Sequence[str], "NestedCliArgs"]]

Expand Down Expand Up @@ -161,7 +170,9 @@
origin_type = get_origin(type_)
inner_types = [resolve(x) for x in get_args(type_)]

if type_ in _implicit_iterable_type_mapping:
if type_ is dict:
out = convert(dict[str, str], token)

Check warning on line 174 in cyclopts/_convert.py

View check run for this annotation

Codecov / codecov/patch

cyclopts/_convert.py#L174

Added line #L174 was not covered by tests
elif type_ in _implicit_iterable_type_mapping:
out = convert(_implicit_iterable_type_mapping[type_], token)
elif origin_type in (collections.abc.Iterable, collections.abc.Sequence):
assert len(inner_types) == 1
Expand Down
16 changes: 11 additions & 5 deletions cyclopts/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
get_field_infos,
)
from cyclopts.group import Group
from cyclopts.parameter import Parameter
from cyclopts.parameter import ITERATIVE_BOOL_IMPLICIT_VALUE, Parameter
from cyclopts.token import Token
from cyclopts.utils import UNSET, ParameterDict, grouper, is_builtin

Expand Down Expand Up @@ -827,7 +827,7 @@ def _match_name(
name = transform(name)
if _startswith(term, name):
trailing = term[len(name) :]
implicit_value = True if self.hint is bool else None
implicit_value = True if self.hint is bool or self.hint in ITERATIVE_BOOL_IMPLICIT_VALUE else None
if trailing:
if trailing[0] == delimiter:
trailing = trailing[1:]
Expand All @@ -843,7 +843,10 @@ def _match_name(
name = transform(name)
if term.startswith(name):
trailing = term[len(name) :]
implicit_value = (get_origin(self.hint) or self.hint)()
if self.hint in ITERATIVE_BOOL_IMPLICIT_VALUE:
implicit_value = False
else:
implicit_value = (get_origin(self.hint) or self.hint)()
if trailing:
if trailing[0] == delimiter:
trailing = trailing[1:]
Expand Down Expand Up @@ -917,8 +920,11 @@ def safe_converter(hint, tokens):
keyword = {}
for token in self.tokens:
if token.implicit_value is not UNSET:
assert len(self.tokens) == 1
return token.implicit_value
if self.hint in ITERATIVE_BOOL_IMPLICIT_VALUE:
return get_origin(self.hint)(x.implicit_value for x in self.tokens)
else:
assert len(self.tokens) == 1
return token.implicit_value

if token.keys:
lookup = keyword
Expand Down
33 changes: 26 additions & 7 deletions cyclopts/parameter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
from collections.abc import Iterable
from functools import partial
from typing import Any, Callable, Optional, Union, cast, get_args, get_origin
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast, get_args, get_origin

import attrs
from attrs import field, frozen
Expand All @@ -18,7 +18,19 @@
to_tuple_converter,
)

_NEGATIVE_FLAG_TYPES = frozenset([bool, *ITERABLE_TYPES])
ITERATIVE_BOOL_IMPLICIT_VALUE = frozenset(
{
Iterable[bool],
Sequence[bool],
List[bool],
list[bool],
Tuple[bool, ...],
tuple[bool, ...],
}
)


_NEGATIVE_FLAG_TYPES = frozenset([bool, *ITERABLE_TYPES, *ITERATIVE_BOOL_IMPLICIT_VALUE])


def _not_hyphen_validator(instance, attribute, values):
Expand Down Expand Up @@ -85,7 +97,7 @@ def main(foo: Annotated[int, Parameter(name="bar")]):
converter=lambda x: cast(tuple[Callable, ...], to_tuple_converter(x)),
)

# This can ONLY ever be a Tuple[str, ...]
# This can ONLY ever be ``None`` or ``Tuple[str, ...]``
negative: Union[None, str, Iterable[str]] = field(default=None, converter=optional_to_tuple_converter)

# This can ONLY ever be a Tuple[Union[Group, str], ...]
Expand Down Expand Up @@ -162,10 +174,14 @@ def get_negatives(self, type_) -> tuple[str, ...]:
if is_union(type_):
type_ = next(x for x in get_args(type_) if x is not None)

type_ = get_origin(type_) or type_
origin = get_origin(type_)

if (self.negative is not None and not self.negative) or type_ not in _NEGATIVE_FLAG_TYPES:
return ()
if type_ not in _NEGATIVE_FLAG_TYPES:
if origin:
if origin not in _NEGATIVE_FLAG_TYPES:
return ()
else:
return ()

out, user_negatives = [], []
if self.negative:
Expand All @@ -182,7 +198,10 @@ def get_negatives(self, type_) -> tuple[str, ...]:
name = name[2:]
name_components = name.split(".")

negative_prefixes = self.negative_bool if type_ is bool else self.negative_iterable
if type_ is bool or type_ in ITERATIVE_BOOL_IMPLICIT_VALUE:
negative_prefixes = self.negative_bool
else:
negative_prefixes = self.negative_iterable
name_prefix = ".".join(name_components[:-1])
if name_prefix:
name_prefix += "."
Expand Down
25 changes: 25 additions & 0 deletions tests/test_bind_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,31 @@ def foo(a: Optional[List[int]] = None):
assert_parse_args(foo, "foo")


@pytest.mark.parametrize(
"cmd_expected",
[
("", None),
("--verbose", [True]),
("--verbose --verbose", [True, True]),
("--verbose --verbose --no-verbose", [True, True, False]),
("--verbose --verbose=False", [True, False]),
("--verbose --no-verbose=False", [True, True]),
("--verbose --verbose=True", [True, True]),
],
)
def test_keyword_list_of_bool(app, assert_parse_args, cmd_expected):
cmd, expected = cmd_expected

@app.default
def foo(*, verbose: Optional[list[bool]] = None):
pass

if expected is None:
assert_parse_args(foo, cmd)
else:
assert_parse_args(foo, cmd, verbose=expected)


@pytest.mark.parametrize(
"cmd",
[
Expand Down