From 7d09962aafb3d6e707b02db54b9e44ca14dd6bbe Mon Sep 17 00:00:00 2001 From: Nat Noordanus Date: Sat, 24 Feb 2024 22:09:58 +0100 Subject: [PATCH] WIP make args work as expected on ref tasks --- poethepoet/app.py | 6 +-- poethepoet/task/args.py | 6 ++- poethepoet/task/base.py | 87 ++++++++++++++++++++----------------- poethepoet/task/cmd.py | 18 ++------ poethepoet/task/expr.py | 8 ++-- poethepoet/task/ref.py | 38 ++++++++-------- poethepoet/task/script.py | 9 ++-- poethepoet/task/sequence.py | 10 ++--- poethepoet/task/shell.py | 6 +-- poethepoet/task/switch.py | 14 ++---- 10 files changed, 95 insertions(+), 107 deletions(-) diff --git a/poethepoet/app.py b/poethepoet/app.py index afb0b6ac6..fed4ae145 100644 --- a/poethepoet/app.py +++ b/poethepoet/app.py @@ -153,7 +153,7 @@ def run_task( if context is None: context = self.get_run_context() try: - return task.run(context=context, extra_args=task.invocation[1:]) + return task.run(context=context) except PoeException as error: self.print_help(error=error) return 1 @@ -175,9 +175,7 @@ def run_task_graph(self, task: "PoeTask") -> Optional[int]: return self.run_task(stage_task, context) try: - task_result = stage_task.run( - context=context, extra_args=stage_task.invocation[1:] - ) + task_result = stage_task.run(context=context) if task_result: raise ExecutionError( f"Task graph aborted after failed task {stage_task.name!r}" diff --git a/poethepoet/task/args.py b/poethepoet/task/args.py index aeaba47cb..ec32e91e2 100644 --- a/poethepoet/task/args.py +++ b/poethepoet/task/args.py @@ -298,13 +298,15 @@ def _get_argument_params(self, arg: ArgParams): return result - def parse(self, extra_args: Sequence[str]): - parsed_args = vars(self.build_parser().parse_args(extra_args)) + def parse(self, args: Sequence[str]) -> Dict[str, str]: + parsed_args = vars(self.build_parser().parse_args(args)) + # Ensure positional args are still exposed by name even if they were parsed with # alternate identifiers for arg in self._args: if isinstance(arg.get("positional"), str): parsed_args[arg["name"]] = parsed_args[arg["positional"]] del parsed_args[arg["positional"]] + # args named with dash case are converted to snake case before being exposed return {name.replace("-", "_"): value for name, value in parsed_args.items()} diff --git a/poethepoet/task/base.py b/poethepoet/task/base.py index 0e4cf3789..850fd5263 100644 --- a/poethepoet/task/base.py +++ b/poethepoet/task/base.py @@ -1,5 +1,6 @@ import re import sys +from os import environ from pathlib import Path from typing import ( TYPE_CHECKING, @@ -10,7 +11,6 @@ List, NamedTuple, Optional, - Sequence, Tuple, Type, Union, @@ -65,7 +65,7 @@ class PoeTask(metaclass=MetaPoeTask): content: TaskContent options: Dict[str, Any] inheritance: TaskInheritance - named_args: Optional[Dict[str, str]] = None + _parsed_args: Optional[Tuple[Dict[str, str], Tuple[str, ...]]] = None __options__: Dict[str, Union[Type, Tuple[Type, ...]]] = {} __content_type__: Type = str @@ -216,42 +216,49 @@ def resolve_task_type( return None - def get_named_arg_values(self, env: "EnvVarsManager") -> Dict[str, str]: - try: - split_index = self.invocation.index("--") - parse_args = self.invocation[1:split_index] - except ValueError: - parse_args = self.invocation[1:] + def get_parsed_arguments( + self, env: "EnvVarsManager" + ) -> Tuple[Dict[str, str], Tuple[str, ...]]: + if self._parsed_args is None: + all_args = self.invocation[1:] - if self.named_args is None: - self.named_args = self._parse_named_args(parse_args, env) - - if not self.named_args: - return {} - - return self.named_args + if args_def := self.options.get("args"): + from .args import PoeTaskArgs - def _parse_named_args( - self, extra_args: Sequence[str], env: "EnvVarsManager" - ) -> Optional[Dict[str, str]]: - if args_def := self.options.get("args"): - from .args import PoeTaskArgs + try: + split_index = all_args.index("--") + option_args = all_args[:split_index] + extra_args = all_args[split_index + 1 :] + except ValueError: + option_args = all_args + extra_args = tuple() + + self._parsed_args = ( + PoeTaskArgs(args_def, self.name, self._ui.program_name, env).parse( + option_args + ), + extra_args, + ) - return PoeTaskArgs(args_def, self.name, self._ui.program_name, env).parse( - extra_args - ) + else: + self._parsed_args = ({}, all_args) - return None + return self._parsed_args def run( self, context: "RunContext", - extra_args: Sequence[str] = tuple(), parent_env: Optional["EnvVarsManager"] = None, ) -> int: """ Run this task """ + + if environ.get("POE_DEBUG"): + task_type_key = self.__key__ # type: ignore[attr-defined] + print(f" * Running {task_type_key}:{self.name}") + print(f" . Invocation {self.invocation!r}") + upstream_invocations = self._get_upstream_invocations(context) if context.dry and upstream_invocations.get("uses", {}): @@ -265,21 +272,23 @@ def run( ) return 0 - return self._handle_run( - context, - extra_args, - context.get_task_env( - parent_env, - self.options.get("envfile"), - self.options.get("env"), - upstream_invocations["uses"], - ), + task_env = context.get_task_env( + parent_env, + self.options.get("envfile"), + self.options.get("env"), + upstream_invocations["uses"], ) + if environ.get("POE_DEBUG"): + named_arg_values, extra_args = self.get_parsed_arguments(task_env) + print(f" . Parsed args {named_arg_values!r}") + print(f" . Extra args {extra_args!r}") + + return self._handle_run(context, task_env) + def _handle_run( self, context: "RunContext", - extra_args: Sequence[str], env: "EnvVarsManager", ) -> int: """ @@ -288,11 +297,7 @@ def _handle_run( """ raise NotImplementedError - def _get_executor( - self, - context: "RunContext", - env: "EnvVarsManager", - ): + def _get_executor(self, context: "RunContext", env: "EnvVarsManager"): return context.get_executor( self.invocation, env, @@ -335,7 +340,7 @@ def _get_upstream_invocations(self, context: "RunContext"): env = context.get_task_env( None, self.options.get("envfile"), self.options.get("env") ) - env.update(self.get_named_arg_values(env)) + env.update(self.get_parsed_arguments(env)[0]) self.__upstream_invocations = { "deps": [ diff --git a/poethepoet/task/cmd.py b/poethepoet/task/cmd.py index 7db949a0e..22eb8a58e 100644 --- a/poethepoet/task/cmd.py +++ b/poethepoet/task/cmd.py @@ -1,5 +1,5 @@ import shlex -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union from ..exceptions import PoeException from .base import PoeTask @@ -25,22 +25,12 @@ class CmdTask(PoeTask): def _handle_run( self, context: "RunContext", - extra_args: Sequence[str], env: "EnvVarsManager", ) -> int: - named_arg_values = self.get_named_arg_values(env) + named_arg_values, extra_args = self.get_parsed_arguments(env) env.update(named_arg_values) - if named_arg_values: - # If named arguments are defined then pass only arguments following a double - # dash token: `--` - try: - split_index = extra_args.index("--") - extra_args = extra_args[split_index + 1 :] - except ValueError: - extra_args = tuple() - - cmd = (*self._resolve_args(context, env), *extra_args) + cmd = (*self._resolve_commandline(context, env), *extra_args) self._print_action(shlex.join(cmd), context.dry) @@ -48,7 +38,7 @@ def _handle_run( cmd, use_exec=self.options.get("use_exec", False) ) - def _resolve_args(self, context: "RunContext", env: "EnvVarsManager"): + def _resolve_commandline(self, context: "RunContext", env: "EnvVarsManager"): from ..helpers.command import parse_poe_cmd, resolve_command_tokens from ..helpers.command.ast_core import ParseError diff --git a/poethepoet/task/expr.py b/poethepoet/task/expr.py index 450e020c2..ae5a5e0f5 100644 --- a/poethepoet/task/expr.py +++ b/poethepoet/task/expr.py @@ -6,7 +6,6 @@ Iterable, Mapping, Optional, - Sequence, Tuple, Type, Union, @@ -38,20 +37,21 @@ class ExprTask(PoeTask): def _handle_run( self, context: "RunContext", - extra_args: Sequence[str], env: "EnvVarsManager", ) -> int: from ..helpers.python import format_class - named_arg_values = self.get_named_arg_values(env) + named_arg_values, extra_args = self.get_parsed_arguments(env) env.update(named_arg_values) + # TODO: do something about extra_args, error? + imports = self.options.get("imports", tuple()) expr, env_values = self.parse_content(named_arg_values, env, imports) argv = [ self.name, - *(env.fill_template(token) for token in extra_args), + *(env.fill_template(token) for token in self.invocation[1:]), ] script = [ diff --git a/poethepoet/task/ref.py b/poethepoet/task/ref.py index 0612ba82e..667b17c98 100644 --- a/poethepoet/task/ref.py +++ b/poethepoet/task/ref.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union from .base import PoeTask, TaskInheritance @@ -21,7 +21,6 @@ class RefTask(PoeTask): def _handle_run( self, context: "RunContext", - extra_args: Sequence[str], env: "EnvVarsManager", ) -> int: """ @@ -29,26 +28,35 @@ def _handle_run( """ import shlex - invocation = tuple(shlex.split(env.fill_template(self.content.strip()))) - extra_args = [*invocation[1:], *extra_args] - task = self.from_config( - invocation[0], + named_arg_values, extra_args = self.get_parsed_arguments(env) + env.update(named_arg_values) + + ref_invocation = [ + env.fill_template(token) + for token in shlex.split(env.fill_template(self.content.strip())) + ] + if not named_arg_values: + ref_invocation.extend(self.invocation[1:]) + else: + ref_invocation.extend(extra_args) + + self.ref_task = self.from_config( + ref_invocation[0], self._config, self._ui, - invocation, + invocation=tuple(ref_invocation), inheritance=TaskInheritance.from_task(self), ) - if task.has_deps(): - return self._run_task_graph(task, context, extra_args, env) + if self.ref_task.has_deps(): + return self._run_task_graph(self.ref_task, context, env) - return task.run(context=context, extra_args=extra_args, parent_env=env) + return self.ref_task.run(context=context, parent_env=env) def _run_task_graph( self, task: "PoeTask", context: "RunContext", - extra_args: Sequence[str], env: "EnvVarsManager", ) -> int: from ..exceptions import ExecutionError @@ -60,13 +68,9 @@ def _run_task_graph( for stage_task in stage: if stage_task == task: # The final sink task gets special treatment - return task.run( - context=context, extra_args=extra_args, parent_env=env - ) + return task.run(context=context, parent_env=env) - task_result = stage_task.run( - context=context, extra_args=stage_task.invocation[1:] - ) + task_result = stage_task.run(context=context) if task_result: raise ExecutionError( f"Task graph aborted after failed task {stage_task.name!r}" diff --git a/poethepoet/task/script.py b/poethepoet/task/script.py index 3f4b84730..5ff6dc9ed 100644 --- a/poethepoet/task/script.py +++ b/poethepoet/task/script.py @@ -1,6 +1,6 @@ import re import shlex -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union from ..exceptions import ExpressionParseError from .base import PoeTask @@ -27,20 +27,21 @@ class ScriptTask(PoeTask): def _handle_run( self, context: "RunContext", - extra_args: Sequence[str], env: "EnvVarsManager", ) -> int: from ..helpers.python import format_class - named_arg_values = self.get_named_arg_values(env) + named_arg_values, extra_args = self.get_parsed_arguments(env) env.update(named_arg_values) + # TODO: do something about extra_args, error? + target_module, function_call = self.parse_content(named_arg_values) function_ref = function_call[: function_call.index("(")] argv = [ self.name, - *(env.fill_template(token) for token in extra_args), + *(env.fill_template(token) for token in self.invocation[1:]), ] # TODO: check whether the project really does use src layout, and don't do diff --git a/poethepoet/task/sequence.py b/poethepoet/task/sequence.py index 09999e2c0..805ed1a5e 100644 --- a/poethepoet/task/sequence.py +++ b/poethepoet/task/sequence.py @@ -4,7 +4,6 @@ Dict, List, Optional, - Sequence, Tuple, Type, Union, @@ -69,13 +68,12 @@ def __init__( def _handle_run( self, context: "RunContext", - extra_args: Sequence[str], env: "EnvVarsManager", ) -> int: - named_arg_values = self.get_named_arg_values(env) + named_arg_values, extra_args = self.get_parsed_arguments(env) env.update(named_arg_values) - if not named_arg_values and any(arg.strip() for arg in extra_args): + if not named_arg_values and any(arg.strip() for arg in self.invocation[1:]): raise PoeException(f"Sequence task {self.name!r} does not accept arguments") if len(self.subtasks) > 1: @@ -85,9 +83,7 @@ def _handle_run( ignore_fail = self.options.get("ignore_fail") non_zero_subtasks: List[str] = list() for subtask in self.subtasks: - task_result = subtask.run( - context=context, extra_args=tuple(), parent_env=env - ) + task_result = subtask.run(context=context, parent_env=env) if task_result and not ignore_fail: raise ExecutionError( f"Sequence aborted after failed subtask {subtask.name!r}" diff --git a/poethepoet/task/shell.py b/poethepoet/task/shell.py index f447737f4..4feb1380d 100644 --- a/poethepoet/task/shell.py +++ b/poethepoet/task/shell.py @@ -6,7 +6,6 @@ Dict, List, Optional, - Sequence, Tuple, Type, Union, @@ -34,13 +33,12 @@ class ShellTask(PoeTask): def _handle_run( self, context: "RunContext", - extra_args: Sequence[str], env: "EnvVarsManager", ) -> int: - named_arg_values = self.get_named_arg_values(env) + named_arg_values, extra_args = self.get_parsed_arguments(env) env.update(named_arg_values) - if not named_arg_values and any(arg.strip() for arg in extra_args): + if not named_arg_values and any(arg.strip() for arg in self.invocation[1:]): raise PoeException(f"Shell task {self.name!r} does not accept arguments") interpreter_cmd = self.resolve_interpreter_cmd() diff --git a/poethepoet/task/switch.py b/poethepoet/task/switch.py index 719728c3a..977508dcd 100644 --- a/poethepoet/task/switch.py +++ b/poethepoet/task/switch.py @@ -5,7 +5,6 @@ List, MutableMapping, Optional, - Sequence, Tuple, Type, Union, @@ -94,23 +93,18 @@ def __init__( def _handle_run( self, context: "RunContext", - extra_args: Sequence[str], env: "EnvVarsManager", ) -> int: - named_arg_values = self.get_named_arg_values(env) + named_arg_values, extra_args = self.get_parsed_arguments(env) env.update(named_arg_values) - if not named_arg_values and any(arg.strip() for arg in extra_args): + if not named_arg_values and any(arg.strip() for arg in self.invocation[1:]): raise PoeException(f"Switch task {self.name!r} does not accept arguments") # Indicate on the global context that there are multiple stages to this task context.multistage = True - task_result = self.control_task.run( - context=context, - extra_args=extra_args if self.options.get("args") else tuple(), - parent_env=env, - ) + task_result = self.control_task.run(context=context, parent_env=env) if task_result: raise ExecutionError( f"Switch task {self.name!r} aborted after failed control task" @@ -135,7 +129,7 @@ def _handle_run( f"switch task {self.name!r}." ) - return case_task.run(context=context, extra_args=extra_args, parent_env=env) + return case_task.run(context=context, parent_env=env) @classmethod def _get_case_keys(cls, task_def: Dict[str, Any]) -> List[Any]: