diff --git a/poethepoet/config/partition.py b/poethepoet/config/partition.py index 2ce2994fc..5d7360803 100644 --- a/poethepoet/config/partition.py +++ b/poethepoet/config/partition.py @@ -6,11 +6,13 @@ Optional, Sequence, Type, + TypedDict, Union, ) from ..exceptions import ConfigValidationError from ..options import NoValue, PoeOptions +from .primitives import EmptyDict, EnvDefault KNOWN_SHELL_INTERPRETERS = ( "posix", @@ -24,6 +26,14 @@ ) +class IncludeItem(TypedDict): + path: str + cwd: str + + +IncludeItem.__optional_keys__ = frozenset({"cwd"}) + + class ConfigPartition: options: PoeOptions full_config: Mapping[str, Any] @@ -74,9 +84,6 @@ def get(self, key: str, default: Any = NoValue): return self.options.get(key, default) -EmptyDict: Mapping = MappingProxyType({}) - - class ProjectConfig(ConfigPartition): is_primary = True @@ -88,10 +95,10 @@ class ConfigOptions(PoeOptions): default_task_type: str = "cmd" default_array_task_type: str = "sequence" default_array_item_task_type: str = "ref" - env: Mapping[str, str] = EmptyDict + env: Mapping[str, str | EnvDefault] = EmptyDict envfile: Union[str, Sequence[str]] = tuple() executor: Mapping[str, str] = MappingProxyType({"type": "auto"}) - include: Sequence[str] = tuple() + include: Union[str, Sequence[str], Sequence[IncludeItem]] = tuple() poetry_command: str = "poe" poetry_hooks: Mapping[str, str] = EmptyDict shell_interpreter: Union[str, Sequence[str]] = "posix" diff --git a/poethepoet/config/primitives.py b/poethepoet/config/primitives.py new file mode 100644 index 000000000..daeff634c --- /dev/null +++ b/poethepoet/config/primitives.py @@ -0,0 +1,8 @@ +from types import MappingProxyType +from typing import Mapping, TypedDict + +EmptyDict: Mapping = MappingProxyType({}) + + +class EnvDefault(TypedDict): + default: str diff --git a/poethepoet/options.py b/poethepoet/options/__init__.py similarity index 59% rename from poethepoet/options.py rename to poethepoet/options/__init__.py index ab7ebfd30..654e66691 100644 --- a/poethepoet/options.py +++ b/poethepoet/options/__init__.py @@ -1,22 +1,15 @@ -import collections +from __future__ import annotations + from keyword import iskeyword from typing import ( Any, - Dict, - List, - Literal, Mapping, - MutableMapping, - Optional, Sequence, - Tuple, - Type, - Union, - get_args, - get_origin, + get_type_hints, ) -from .exceptions import ConfigValidationError +from ..exceptions import ConfigValidationError +from .annotations import TypeAnnotation NoValue = object() @@ -26,7 +19,7 @@ class PoeOptions: A special kind of config object that parses options ... """ - __annotations: Dict[str, Type] + __annotations: dict[str, TypeAnnotation] def __init__(self, **options: Any): for key in self.get_fields(): @@ -61,13 +54,13 @@ def __getattr__(self, name: str): @classmethod def parse( cls, - source: Union[Mapping[str, Any], list], + source: Mapping[str, Any] | list, strict: bool = True, extra_keys: Sequence[str] = tuple(), ): config_keys = { - key[:-1] if key.endswith("_") and iskeyword(key[:-1]) else key: vtype - for key, vtype in cls.get_fields().items() + key[:-1] if key.endswith("_") and iskeyword(key[:-1]) else key: type_ + for key, type_ in cls.get_fields().items() } if strict: for index, item in enumerate(cls.normalize(source, strict)): @@ -110,29 +103,8 @@ def _parse_value( return value_type.parse(value, strict=strict) if strict: - expected_type: Union[Type, Tuple[Type, ...]] = cls._type_of(value_type) - if not isinstance(value, expected_type): - # Try format expected_type nicely in the error message - if not isinstance(expected_type, tuple): - expected_type = (expected_type,) - formatted_type = " | ".join( - type_.__name__ for type_ in expected_type if type_ is not type(None) - ) - raise ConfigValidationError( - f"Option {key!r} should have a value of type: {formatted_type}", - index=index, - ) - - annotation = cls.get_annotation(key) - if get_origin(annotation) is Literal: - allowed_values = get_args(annotation) - if value not in allowed_values: - raise ConfigValidationError( - f"Option {key!r} must be one of {allowed_values!r}", - index=index, - ) - - # TODO: validate list/dict contents + for error_msg in value_type.validate((key,), value): + raise ConfigValidationError(error_msg, index=index) return value @@ -171,28 +143,18 @@ def get(self, key: str, default: Any = NoValue) -> Any: if default is NoValue: # Fallback to getting getting the zero value for the type of this attribute # e.g. 0, False, empty list, empty dict, etc - return self.__get_zero_value(key) + annotation = self.get_fields().get(self._resolve_key(key)) + assert annotation + return annotation.zero_value() return default - def __get_zero_value(self, key: str): - type_of_attr = self.type_of(key) - if isinstance(type_of_attr, tuple): - if type(None) in type_of_attr: - # Optional types default to None - return None - type_of_attr = type_of_attr[0] - assert type_of_attr - return type_of_attr() - def __is_optional(self, key: str): - # TODO: precache optional options keys? - type_of_attr = self.type_of(key) - if isinstance(type_of_attr, tuple): - return type(None) in type_of_attr - return False + annotation = self.get_fields().get(self._resolve_key(key)) + assert annotation + return annotation.is_optional - def update(self, options_dict: Dict[str, Any]): + def update(self, options_dict: dict[str, Any]): new_options_dict = {} for key in self.get_fields().keys(): if key in options_dict: @@ -200,14 +162,6 @@ def update(self, options_dict: Dict[str, Any]): elif hasattr(self, key): new_options_dict[key] = getattr(self, key) - @classmethod - def type_of(cls, key: str) -> Optional[Union[Type, Tuple[Type, ...]]]: - return cls._type_of(cls.get_annotation(key)) - - @classmethod - def get_annotation(cls, key: str) -> Optional[Type]: - return cls.get_fields().get(cls._resolve_key(key)) - @classmethod def _resolve_key(cls, key: str) -> str: """ @@ -219,40 +173,7 @@ def _resolve_key(cls, key: str) -> str: return key @classmethod - def _type_of(cls, annotation: Any) -> Union[Type, Tuple[Type, ...]]: - if get_origin(annotation) is Union: - result: List[Type] = [] - for component in get_args(annotation): - component_type = cls._type_of(component) - if isinstance(component_type, tuple): - result.extend(component_type) - else: - result.append(component_type) - return tuple(result) - - if get_origin(annotation) in ( - dict, - Mapping, - MutableMapping, - collections.abc.Mapping, - collections.abc.MutableMapping, - ): - return dict - - if get_origin(annotation) in ( - list, - Sequence, - collections.abc.Sequence, - ): - return list - - if get_origin(annotation) is Literal: - return tuple({type(arg) for arg in get_args(annotation)}) - - return annotation - - @classmethod - def get_fields(cls) -> Dict[str, Any]: + def get_fields(cls) -> dict[str, TypeAnnotation]: """ Recent python versions removed inheritance for __annotations__ so we have to implement it explicitly @@ -260,10 +181,11 @@ def get_fields(cls) -> Dict[str, Any]: if not hasattr(cls, "__annotations"): annotations = {} for base_cls in cls.__bases__: - annotations.update(base_cls.__annotations__) - annotations.update(cls.__annotations__) + annotations.update(get_type_hints(base_cls)) + annotations.update(get_type_hints(cls)) + cls.__annotations = { - key: type_ + key: TypeAnnotation.parse(type_) for key, type_ in annotations.items() if not key.startswith("_") } diff --git a/poethepoet/options/annotations.py b/poethepoet/options/annotations.py new file mode 100644 index 000000000..c56f4dc74 --- /dev/null +++ b/poethepoet/options/annotations.py @@ -0,0 +1,298 @@ +from __future__ import annotations + +import collections +import sys +import types +from typing import ( + Any, + Iterator, + Literal, + Mapping, + MutableMapping, + Sequence, + Union, + get_args, + get_origin, + get_type_hints, +) + + +class TypeAnnotation: + """ + This class and its descendants provide a convenient model for parsing and + enforcing pythonic type annotations for PoeOptions. + """ + + @staticmethod + def parse(annotation: Any): + origin = get_origin(annotation) + + if annotation in (str, int, float, bool): + return PrimativeType(annotation) + + elif annotation is dict or origin in ( + dict, + Mapping, + MutableMapping, + collections.abc.Mapping, + collections.abc.MutableMapping, + ): + return DictType(annotation) + + elif annotation is list or origin in ( + list, + tuple, + Sequence, + collections.abc.Sequence, + ): + return ListType(annotation) + + elif origin is Literal: + return LiteralType(annotation) + + elif origin in (Union, types.UnionType): + return UnionType(annotation) + + elif annotation is Any: + return AnyType(annotation) + + elif annotation in (None, type(None)): + return NoneType(annotation) + + elif _is_typeddict(annotation): + return TypedDictType(annotation) + + raise ValueError(f"Cannot parse TypeAnnotation for annotation: {annotation}") + + def __init__(self, annotation: Any): + self._annotation = annotation + + @property + def is_optional(self) -> bool: + return False + + def validate(self, path: tuple[str | int, ...], value: Any) -> Iterator[str]: + raise NotImplementedError + + def zero_value(self): + return None + + @staticmethod + def _format_path(path: tuple[str | int, ...]): + return "".join( + (f"[{part}]" if isinstance(part, int) else f".{part}") for part in path + ).strip(".") + + +class DictType(TypeAnnotation): + def __init__(self, annotation: Any): + super().__init__(annotation) + if args := get_args(annotation): + assert args[0] == str + self._value_type = TypeAnnotation.parse(get_args(annotation)[1]) + else: + self._value_type = AnyType() + + def __str__(self): + if isinstance(self._value_type, AnyType): + return "dict" + return f"dict[str, {self._value_type}]" + + def zero_value(self): + return {} + + def validate(self, path: tuple[str | int, ...], value: Any) -> Iterator[str]: + if not isinstance(value, dict): + yield f"Option {self._format_path(path)!r} must be a dict" + + if isinstance(self._value_type, AnyType): + return + + # We assume dict keys can only be strings so no need to check them + for key, dict_value in value.items(): + yield from self._value_type.validate((*path, key), dict_value) + + +class TypedDictType(TypeAnnotation): + def __init__(self, annotation: Any): + super().__init__(annotation) + self._schema = { + key: TypeAnnotation.parse(type_) + for key, type_ in get_type_hints(annotation).items() + } + self._optional_keys: frozenset[str] = getattr( + annotation, "__optional_keys__", frozenset() + ) + + def __str__(self): + return ( + "dict(" + + ", ".join(f"{key}: {value}" for key, value in self._schema.items()) + + ")" + ) + + def zero_value(self): + return {} + + def validate(self, path: tuple[str | int, ...], value: Any) -> Iterator[str]: + if not isinstance(value, dict): + yield f"Option {self._format_path(path)!r} must be a dict" + + for key, value_type in self._schema.items(): + if key not in value: + if key not in self._optional_keys: + yield ( + f"Option {self._format_path(path)!r} " + f"missing required key: {key}" + ) + continue + yield from value_type.validate((*path, key), value[key]) + + for key in set(value) - set(self._schema): + yield f"Option {self._format_path(path)!r} contains unexpected key: {key}" + + +class ListType(TypeAnnotation): + def __init__(self, annotation: Any): + super().__init__(annotation) + self._type = get_origin(annotation) or (tuple if annotation is tuple else list) + if args := get_args(annotation): + self._value_type = TypeAnnotation.parse(args[0]) + if self._type is tuple: + assert ( + args[1] is ... + ), "ListType only accepts tuples with any length type" + else: + self._value_type = AnyType() + + def __str__(self): + # Even if the type is tuple, only use list for error reporting etc + if isinstance(self._value_type, AnyType): + return "list" + return f"list[{self._value_type}]" + + def zero_value(self): + return [] + + def validate(self, path: tuple[str | int, ...], value: Any) -> Iterator[str]: + if not isinstance(value, (list, tuple)): + yield f"Option {self._format_path(path)!r} must be a list" + + if isinstance(self._value_type, AnyType): + return + + for idx, item in enumerate(value): + yield from self._value_type.validate((*path, idx), item) + + +class LiteralType(TypeAnnotation): + def __init__(self, annotation: Any): + super().__init__(annotation) + self._values = get_args(annotation) + + def __str__(self): + return " | ".join( + repr(type_) for type_ in self._values if type_ is not type(None) + ) + + def zero_value(self): + return self._values[0] + + def validate(self, path: tuple[str | int, ...], value: Any) -> Iterator[str]: + if value not in self._values: + yield f"Option {self._format_path(path)!r} must be one of {self._values!r}" + + +class UnionType(TypeAnnotation): + def __init__(self, annotation: Any): + super().__init__(annotation) + self._value_types = tuple( + TypeAnnotation.parse(arg) for arg in get_args(annotation) + ) + + @property + def is_optional(self) -> bool: + return any(isinstance(value_type, NoneType) for value_type in self._value_types) + + def __str__(self): + return " | ".join( + { + str(type_) + for type_ in self._value_types + if not isinstance(type_, NoneType) + } + ) + + def zero_value(self): + if type(None) in self._value_types: + return None + return self._value_types[0] + + def validate(self, path: tuple[str | int, ...], value: Any) -> Iterator[str]: + if len(self._value_types) == 2: + # In case this is a simple optional type then just validate the wrapped type + # This results in more specific validation errors + if isinstance(self._value_types[1], NoneType): + yield from self._value_types[0].validate(path, value) + return + elif isinstance(self._value_types[0], NoneType): + yield from self._value_types[1].validate(path, value) + return + + for value_type in self._value_types: + errors = next(value_type.validate(path, value), None) + if errors is None: + break + else: + yield ( + f"Option {self._format_path(path)!r} must have a value of type: {self}" + ) + + +class AnyType(TypeAnnotation): + def __init__(self, annotation: Any = Any): + super().__init__(annotation) + + def __str__(self): + return "Any" + + def validate(self, path: tuple[str | int, ...], value: Any) -> Iterator[str]: + if False: + yield "" + return + + +class NoneType(TypeAnnotation): + def __init__(self, annotation: Any = type(None)): + super().__init__(annotation) + + def __str__(self): + return "None" + + def validate(self, path: tuple[str | int, ...], value: Any) -> Iterator[str]: + if value is not None: + # this should probably never happen + yield f"Option {self._format_path(path)!r} must be None" + + +class PrimativeType(TypeAnnotation): + def __str__(self): + return self._annotation.__name__ + + def zero_value(self): + return self._annotation() + + def validate(self, path: tuple[str | int, ...], value: Any) -> Iterator[str]: + if not isinstance(value, self._annotation): + yield ( + f"Option {self._format_path(path)!r} must have a value of type: {self}" + ) + + +def _is_typeddict(value: Any): + import typing + + if not sys.version_info >= (3, 10): + return typing.is_typeddict(value) + else: + return isinstance(value, typing._TypedDictMeta) # type: ignore[attr-defined] diff --git a/poethepoet/task/args.py b/poethepoet/task/args.py index 5f70d9b22..4d2c2f2b5 100644 --- a/poethepoet/task/args.py +++ b/poethepoet/task/args.py @@ -1,15 +1,11 @@ +from __future__ import annotations + from typing import ( TYPE_CHECKING, Any, - Dict, - List, Literal, Mapping, - Optional, Sequence, - Tuple, - Type, - Union, ) if TYPE_CHECKING: @@ -20,20 +16,10 @@ from ..exceptions import ConfigValidationError from ..options import PoeOptions -ArgParams = Dict[str, Any] -ArgsDef = Union[List[str], List[ArgParams], Dict[str, ArgParams]] - -arg_param_schema: Dict[str, Union[Type, Tuple[Type, ...]]] = { - "default": (str, int, float, bool), - "help": str, - "name": str, - "options": (list, tuple), - "positional": (bool, str), - "required": bool, - "type": str, - "multiple": (bool, int), -} -arg_types: Dict[str, Type] = { +ArgParams = dict[str, Any] +ArgsDef = list[str] | list[ArgParams] | dict[str, ArgParams] + +arg_types: dict[str, type] = { "string": str, "float": float, "integer": int, @@ -42,14 +28,14 @@ class ArgSpec(PoeOptions): - default: Optional[Union[str, int, float, bool]] = None + default: str | int | float | bool | None = None help: str = "" name: str - options: Union[list, tuple] - positional: Union[bool, str] = False + options: Sequence[str] + positional: bool | str = False required: bool = False type: Literal["string", "float", "integer", "boolean"] = "string" - multiple: Union[bool, int] = False + multiple: bool | int = False @classmethod def normalize(cls, args_def: ArgsDef, strict: bool = True): @@ -95,7 +81,7 @@ def normalize(cls, args_def: ArgsDef, strict: bool = True): @classmethod def parse( cls, - source: Union[Mapping[str, Any], list], + source: Mapping[str, Any] | list, strict: bool = True, extra_keys: Sequence[str] = tuple(), ): @@ -126,7 +112,7 @@ def parse( @staticmethod def _get_arg_options_list( - arg: ArgParams, name: Optional[str] = None, strict: bool = True + arg: ArgParams, name: str | None = None, strict: bool = True ): position = arg.get("positional", False) name = name or arg.get("name") @@ -139,7 +125,7 @@ def _get_arg_options_list( if isinstance(position, str): return [position] return [name] - return tuple(arg.get("options", (f"--{name}",))) + return tuple(arg.get("options", [f"--{name}"])) def validate(self): try: @@ -186,7 +172,7 @@ def _validate(self): class PoeTaskArgs: - _args: Tuple[ArgSpec, ...] + _args: tuple[ArgSpec, ...] def __init__(self, args_def: ArgsDef, task_name: str): self._task_name = task_name @@ -201,8 +187,8 @@ def _parse_args_def(self, args_def: ArgsDef): @classmethod def get_help_content( - cls, args_def: Optional[ArgsDef], task_name: str, suppress_errors: bool = False - ) -> List[Tuple[Tuple[str, ...], str, str]]: + cls, args_def: ArgsDef | None, task_name: str, suppress_errors: bool = False + ) -> list[tuple[tuple[str, ...], str, str]]: if args_def is None: return [] @@ -242,9 +228,7 @@ def _enrich_config_error( error.context = f"Invalid argument {arg_ref!r} declared" error.task_name = task_name - def build_parser( - self, env: "EnvVarsManager", program_name: str - ) -> "ArgumentParser": + def build_parser(self, env: EnvVarsManager, program_name: str) -> ArgumentParser: import argparse parser = argparse.ArgumentParser( @@ -259,7 +243,7 @@ def build_parser( ) return parser - def _get_argument_params(self, arg: ArgSpec, env: "EnvVarsManager"): + def _get_argument_params(self, arg: ArgSpec, env: EnvVarsManager): default = arg.get("default") if isinstance(default, str): default = env.fill_template(default) @@ -295,7 +279,7 @@ def _get_argument_params(self, arg: ArgSpec, env: "EnvVarsManager"): return result - def parse(self, args: Sequence[str], env: "EnvVarsManager", program_name: str): + def parse(self, args: Sequence[str], env: EnvVarsManager, program_name: str): parsed_args = vars(self.build_parser(env, program_name).parse_args(args)) # Ensure positional args are still exposed by name even if they were parsed with # alternate identifiers diff --git a/poethepoet/task/base.py b/poethepoet/task/base.py index 1557f8677..e9c46b109 100644 --- a/poethepoet/task/base.py +++ b/poethepoet/task/base.py @@ -18,6 +18,7 @@ Union, ) +from ..config.primitives import EmptyDict, EnvDefault from ..exceptions import ConfigValidationError, PoeException from ..options import PoeOptions @@ -173,11 +174,11 @@ class TaskOptions(PoeOptions): capture_stdout: Optional[str] = None cwd: Optional[str] = None deps: Optional[Sequence[str]] = None - env: Optional[dict] = None - envfile: Optional[Union[str, list]] = None + env: Mapping[str, str | EnvDefault] = EmptyDict + envfile: Union[str, Sequence[str]] = tuple() executor: Optional[dict] = None help: Optional[str] = None - uses: Optional[dict] = None + uses: Optional[Mapping[str, str]] = None def validate(self): """ diff --git a/poethepoet/task/shell.py b/poethepoet/task/shell.py index 2e2752541..ae43e2a4a 100644 --- a/poethepoet/task/shell.py +++ b/poethepoet/task/shell.py @@ -4,6 +4,7 @@ TYPE_CHECKING, List, Optional, + Sequence, Tuple, Union, ) @@ -26,7 +27,7 @@ class ShellTask(PoeTask): __key__ = "shell" class TaskOptions(PoeTask.TaskOptions): - interpreter: Optional[Union[str, list]] = None + interpreter: Optional[Union[str, Sequence[str]]] = None def validate(self): super().validate()