Skip to content

Commit

Permalink
Support python310 pipe unions.
Browse files Browse the repository at this point in the history
  • Loading branch information
BrianPugh committed Jan 17, 2024
1 parent 0e62cf5 commit 7018180
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 2 deletions.
11 changes: 9 additions & 2 deletions cyclopts/_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@
else:
from typing import Annotated # pragma: no cover

_union_types = set()
_union_types.add(Union)
if sys.version_info >= (3, 10):
from types import UnionType

_union_types.add(UnionType)

from cyclopts.exceptions import CoercionError

if TYPE_CHECKING:
Expand Down Expand Up @@ -93,7 +100,7 @@ def _convert(type_, element, converter=None):
if origin_type is collections.abc.Iterable:
assert len(inner_types) == 1
return pconvert(List[inner_types[0]], element) # pyright: ignore[reportGeneralTypeIssues]
elif origin_type is Union:
elif origin_type in _union_types:
for t in inner_types:
if t is NoneType:
continue
Expand Down Expand Up @@ -174,7 +181,7 @@ def resolve_optional(type_: Type) -> Type:
# Python will automatically flatten out nested unions when possible.
# So we don't need to loop over resolution.

if get_origin(type_) is not Union:
if get_origin(type_) not in _union_types:
return type_

non_none_types = [t for t in get_args(type_) if t is not NoneType]
Expand Down
30 changes: 30 additions & 0 deletions tests/test_bind_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,36 @@ def foo(a: Optional[int] = None):
assert_parse_args(foo, cmd_str, 1)


@pytest.mark.skipif(sys.version_info < (3, 10), reason="Pipe Typing Syntax")
@pytest.mark.parametrize(
"cmd_str",
[
"foo 1",
"foo --a=1",
"foo --a 1",
],
)
@pytest.mark.parametrize("annotated", [False, True])
def test_optional_nonrequired_implicit_coercion_python310_syntax(app, cmd_str, annotated, assert_parse_args):
"""
For a union without an explicit coercion, the first non-None type annotation
should be used. In this case, it's ``int``.
"""
if annotated:

@app.command
def foo(a: Annotated[int | None, Parameter(help="help for a")] = None):
pass

else:

@app.command
def foo(a: int | None = None):
pass

assert_parse_args(foo, cmd_str, 1)


@pytest.mark.parametrize(
"cmd_str",
[
Expand Down

0 comments on commit 7018180

Please sign in to comment.