Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for Postponed Evaluation of types annotations (PEP-563) #138

Merged
merged 1 commit into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions cyclopts/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from pathlib import Path
from typing import Callable, Dict, Iterable, Iterator, List, Literal, Optional, Tuple, Union

import cyclopts.utils

if sys.version_info < (3, 9):
from typing_extensions import Annotated
else:
Expand Down Expand Up @@ -551,14 +553,14 @@ def parse_known_args(
command = self.help_print
while meta_parent := meta_parent._meta_parent:
command = meta_parent.help_print
bound = inspect.signature(command).bind(tokens, console=console)
bound = cyclopts.utils.signature(command).bind(tokens, console=console)
unused_tokens = []
elif any(flag in tokens for flag in self.version_flags):
# Version
command = self.version_print
while meta_parent := meta_parent._meta_parent:
command = meta_parent.version_print
bound = inspect.signature(command).bind()
bound = cyclopts.utils.signature(command).bind()
unused_tokens = []
else:
try:
Expand Down Expand Up @@ -599,7 +601,7 @@ def parse_known_args(
# Running the application with no arguments and no registered
# ``default_command`` will default to ``help_print``.
command = self.help_print
bound = inspect.signature(command).bind(tokens=tokens, console=console)
bound = cyclopts.utils.signature(command).bind(tokens=tokens, console=console)
unused_tokens = []
except CycloptsError as e:
e.app = command_app
Expand Down
3 changes: 2 additions & 1 deletion cyclopts/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from rich.panel import Panel
from rich.text import Text

import cyclopts.utils
from cyclopts.group import Group
from cyclopts.utils import ParameterDict

Expand Down Expand Up @@ -112,7 +113,7 @@ def __str__(self):
if self.target:
file, lineno = _get_function_info(self.target)
strings.append(f'Function defined in file "{file}", line {lineno}:')
strings.append(f" {self.target.__name__}{inspect.signature(self.target)}")
strings.append(f" {self.target.__name__}{cyclopts.utils.signature(self.target)}")
if self.root_input_tokens is not None:
strings.append(f"Root Input Tokens: {self.root_input_tokens}")
else:
Expand Down
3 changes: 2 additions & 1 deletion cyclopts/help.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from rich.table import Table
from rich.text import Text

import cyclopts.utils
from cyclopts.group import Group
from cyclopts.parameter import Parameter, get_hint_parameter

Expand Down Expand Up @@ -145,7 +146,7 @@ def format_usage(

if app.default_command:
to_show = set()
for parameter in inspect.signature(app.default_command).parameters.values():
for parameter in cyclopts.utils.signature(app.default_command).parameters.values():
if parameter.kind in (parameter.POSITIONAL_ONLY, parameter.VAR_POSITIONAL, parameter.POSITIONAL_OR_KEYWORD):
to_show.add("[ARGS]")
if parameter.kind in (parameter.KEYWORD_ONLY, parameter.VAR_KEYWORD, parameter.POSITIONAL_OR_KEYWORD):
Expand Down
3 changes: 2 additions & 1 deletion cyclopts/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import attrs
from attrs import field, frozen

import cyclopts.utils
from cyclopts._convert import (
AnnotatedType,
convert,
Expand Down Expand Up @@ -179,7 +180,7 @@ def validate_command(f: Callable):
ValueError
Function has naming or parameter/signature inconsistencies.
"""
signature = inspect.signature(f)
signature = cyclopts.utils.signature(f)
for iparam in signature.parameters.values():
get_origin_and_validate(iparam.annotation)
type_, cparam = get_hint_parameter(iparam)
Expand Down
32 changes: 16 additions & 16 deletions cyclopts/resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from docstring_parser import parse as docstring_parse

import cyclopts.utils
from cyclopts.exceptions import DocstringError
from cyclopts.group import Group
from cyclopts.parameter import Parameter, get_hint_parameter
Expand All @@ -24,18 +25,17 @@ def _list_index(lst: List, key: Callable) -> int:
raise ValueError


def _has_unparsed_parameters(f: Callable, *args) -> bool:
signature = inspect.signature(f)
for iparam in signature.parameters.values():
def _has_unparsed_parameters(func_signature: inspect.Signature, *args) -> bool:
for iparam in func_signature.parameters.values():
cparam: Parameter
_, cparam = get_hint_parameter(iparam, *args)

if not cparam.parse:
return True
return False


def _resolve_groups(
f: Callable,
func_signature: inspect.Signature,
app_parameter: Optional[Parameter],
group_arguments: Group,
group_parameters: Group,
Expand All @@ -47,9 +47,7 @@ def _resolve_groups(
resolved_groups = []
iparam_to_groups = ParameterDict()

signature = inspect.signature(f)

for iparam in signature.parameters.values():
for iparam in func_signature.parameters.values():
_, cparam = get_hint_parameter(iparam, app_parameter)

if not cparam.parse:
Expand Down Expand Up @@ -96,11 +94,11 @@ def _resolve_groups(
return resolved_groups, iparam_to_groups


def _resolve_docstring(f) -> ParameterDict:
signature = inspect.signature(f)
f_docstring = docstring_parse(f.__doc__)

def _resolve_docstring(f: Callable, signature: inspect.Signature) -> ParameterDict:
iparam_to_docstring_cparam = ParameterDict()
if f.__doc__ is None:
return iparam_to_docstring_cparam
f_docstring = docstring_parse(f.__doc__)

for dparam in f_docstring.params:
try:
Expand Down Expand Up @@ -156,17 +154,19 @@ def __init__(
group_parameters = Group.create_default_parameters()

self.command = f
signature = inspect.signature(f)
signature = cyclopts.utils.signature(f)
self.name_to_iparam = cast(Dict[str, inspect.Parameter], signature.parameters)

# Get:
# 1. Fully resolved and created Groups.
# 2. A mapping of inspect.Parameter to those Group objects.
self.groups, self.iparam_to_groups = _resolve_groups(f, app_parameter, group_arguments, group_parameters)
self.groups, self.iparam_to_groups = _resolve_groups(
signature, app_parameter, group_arguments, group_parameters
)

# Fully Resolve each Cyclopts Parameter
self.iparam_to_cparam = ParameterDict()
iparam_to_docstring_cparam = _resolve_docstring(f) if parse_docstring else ParameterDict()
iparam_to_docstring_cparam = _resolve_docstring(f, signature) if parse_docstring else ParameterDict()
empty_help_string_parameter = Parameter(help="")
for iparam, groups in self.iparam_to_groups.items():
if iparam.kind in (iparam.POSITIONAL_ONLY, iparam.VAR_POSITIONAL):
Expand All @@ -188,7 +188,7 @@ def __init__(
)[1]
self.iparam_to_cparam[iparam] = cparam

self.bind = signature.bind_partial if _has_unparsed_parameters(f, app_parameter) else signature.bind
self.bind = signature.bind_partial if _has_unparsed_parameters(signature, app_parameter) else signature.bind

# Create a convenient group-to-iparam structure
self.groups_iparams = [
Expand Down
13 changes: 11 additions & 2 deletions cyclopts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,26 @@

_union_types.add(UnionType)

# fmt: off
if sys.version_info >= (3, 10):
def signature(f: Any) -> inspect.Signature:
return inspect.signature(f, eval_str=True)
else:
def signature(f: Any) -> inspect.Signature:
return inspect.signature(f)
# fmt: on


def record_init(target: str):
"""Class decorator that records init argument names as a tuple to ``target``."""

def decorator(cls):
original_init = cls.__init__
signature = inspect.signature(original_init)
function_signature = signature(original_init)

@functools.wraps(original_init)
def new_init(self, *args, **kwargs):
bound = signature.bind(self, *args, **kwargs)
bound = function_signature.bind(self, *args, **kwargs)
original_init(self, *args, **kwargs)
# Circumvent frozen protection.
object.__setattr__(self, target, tuple(k for k, v in bound.arguments.items() if v is not self))
Expand Down
6 changes: 3 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import inspect
from pathlib import Path

import pytest
from rich.console import Console

import cyclopts
import cyclopts.utils
from cyclopts import App, Group, Parameter


Expand All @@ -26,7 +26,7 @@ def default_function_groups():
@pytest.fixture
def assert_parse_args(app):
def inner(f, cmd: str, *args, **kwargs):
signature = inspect.signature(f)
signature = cyclopts.utils.signature(f)
expected_bind = signature.bind(*args, **kwargs)
actual_command, actual_bind = app.parse_args(cmd, print_error=False, exit_on_error=False)
assert actual_command == f
Expand All @@ -38,7 +38,7 @@ def inner(f, cmd: str, *args, **kwargs):
@pytest.fixture
def assert_parse_args_partial(app):
def inner(f, cmd: str, *args, **kwargs):
signature = inspect.signature(f)
signature = cyclopts.utils.signature(f)
expected_bind = signature.bind_partial(*args, **kwargs)
actual_command, actual_bind = app.parse_args(cmd, print_error=False, exit_on_error=False)
assert actual_command == f
Expand Down
23 changes: 22 additions & 1 deletion tests/test_bind_pos_only.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import sys

import pytest

from cyclopts import CoercionError, MissingArgumentError, UnknownOptionError, ValidationError
from cyclopts import MissingArgumentError, UnknownOptionError, ValidationError


@pytest.mark.parametrize(
Expand Down Expand Up @@ -82,3 +84,22 @@ def foo(a: int, b: int, c: int, /, d: int):

with pytest.raises(e):
app.parse_args(cmd_str, print_error=False, exit_on_error=False)


@pytest.mark.skipif(
sys.version_info < (3, 10), reason="https://peps.python.org/pep-0563/ Postponed Evaluation of Annotations"
)
@pytest.mark.parametrize(
"cmd_str",
[
"foo a 2 3 4",
"foo a 2 3 --d 4",
"foo a 2 --d=4 3",
],
)
def test_pos_only_extended_str_type(app, cmd_str, assert_parse_args):
@app.command
def foo(a: "str", b: "int", c: int, /, d: "int"):
pass

assert_parse_args(foo, cmd_str, "a", 2, 3, 4)
25 changes: 25 additions & 0 deletions tests/test_help.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,31 @@ def cmd(
assert actual == expected


@pytest.mark.skipif(
sys.version_info < (3, 10), reason="https://peps.python.org/pep-0563/ Postponed Evaluation of Annotations"
)
def test_help_parameter_string_annotation(capture_format_group_parameters):
def cmd(number: "Annotated[int,Parameter(name=['--number','-n'])]"):
"""Print number.

Args:
number (int): a number to print.
"""
pass

actual = capture_format_group_parameters(cmd)
expected = dedent(
"""\
╭─ Parameters ───────────────────────────────────────────────────────╮
│ * NUMBER,--number -n a number to print. [required] │
╰────────────────────────────────────────────────────────────────────╯
"""
)
print(actual)
print(expected)
assert actual == expected


def test_help_format_group_parameters_choices_literal_set_typing(capture_format_group_parameters):
def cmd(
steps_to_skip: Annotated[
Expand Down
13 changes: 6 additions & 7 deletions tests/test_parameter2cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import inspect
import sys

if sys.version_info < (3, 9):
Expand All @@ -8,14 +7,14 @@

from cyclopts.parameter import Parameter
from cyclopts.resolve import ResolvedCommand
from cyclopts.utils import ParameterDict
from cyclopts.utils import ParameterDict, signature


def test_parameter2cli_positional_or_keyword(default_function_groups):
def foo(a: Annotated[int, Parameter(negative=())]):
pass

a_iparam = list(inspect.signature(foo).parameters.values())[0]
a_iparam = list(signature(foo).parameters.values())[0]
actual = ResolvedCommand(foo, *default_function_groups).parameter2cli
assert actual == ParameterDict({a_iparam: ["--a"]})

Expand All @@ -24,7 +23,7 @@ def test_parameter2cli_positional_only(default_function_groups):
def foo(a: Annotated[int, Parameter(negative=())], /):
pass

a_iparam = list(inspect.signature(foo).parameters.values())[0]
a_iparam = list(signature(foo).parameters.values())[0]
actual = ResolvedCommand(foo, *default_function_groups).parameter2cli
assert actual == ParameterDict({a_iparam: ["A"]})

Expand All @@ -33,7 +32,7 @@ def test_parameter2cli_keyword_only(default_function_groups):
def foo(*, a: Annotated[int, Parameter(negative=())]):
pass

a_iparam = list(inspect.signature(foo).parameters.values())[0]
a_iparam = list(signature(foo).parameters.values())[0]
actual = ResolvedCommand(foo, *default_function_groups).parameter2cli
assert actual == ParameterDict({a_iparam: ["--a"]})

Expand All @@ -42,7 +41,7 @@ def test_parameter2cli_var_keyword(default_function_groups):
def foo(**a: Annotated[int, Parameter(negative=())]):
pass

a_iparam = list(inspect.signature(foo).parameters.values())[0]
a_iparam = list(signature(foo).parameters.values())[0]
actual = ResolvedCommand(foo, *default_function_groups).parameter2cli
assert actual == ParameterDict({a_iparam: ["--a"]})

Expand All @@ -51,6 +50,6 @@ def test_parameter2cli_var_positional(default_function_groups):
def foo(*a: Annotated[int, Parameter(negative=())]):
pass

a_iparam = list(inspect.signature(foo).parameters.values())[0]
a_iparam = list(signature(foo).parameters.values())[0]
actual = ResolvedCommand(foo, *default_function_groups).parameter2cli
assert actual == ParameterDict({a_iparam: ["A"]})
7 changes: 3 additions & 4 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import inspect
from typing import List

import pytest

from cyclopts.utils import ParameterDict
from cyclopts.utils import ParameterDict, signature


@pytest.fixture
Expand All @@ -15,7 +14,7 @@ def test_parameter_dict_immutable(parameter_dict):
def foo(a: int, b: int = 3):
pass

parameters = dict(inspect.signature(foo).parameters)
parameters = dict(signature(foo).parameters)

for name, parameter in parameters.items():
parameter_dict[parameter] = name
Expand All @@ -32,7 +31,7 @@ def test_parameter_dict_mutable(parameter_dict):
def foo(a: int, b: List[int] = []): # noqa: B006
pass

parameters = dict(inspect.signature(foo).parameters)
parameters = dict(signature(foo).parameters)

for name, parameter in parameters.items():
parameter_dict[parameter] = name
Expand Down
Loading