diff --git a/docs/tasks/task_types/ref.rst b/docs/tasks/task_types/ref.rst index 0376d1c3b..e99eb543b 100644 --- a/docs/tasks/task_types/ref.rst +++ b/docs/tasks/task_types/ref.rst @@ -33,3 +33,9 @@ Available task options ---------------------- ``ref`` tasks support all of the :doc:`standard task options <../options>` with the exception of ``use_exec``. + + +Passing arguments +----------------- + +By default any arguments passed to a ref task will be forwarded to the referenced task, allowing it to function as a task alias. If named arguments are configured for the ref task then additional arguments can still be passed to the referenced task after ``--`` on the command line. diff --git a/poethepoet/app.py b/poethepoet/app.py index afb0b6ac6..fed4ae145 100644 --- a/poethepoet/app.py +++ b/poethepoet/app.py @@ -153,7 +153,7 @@ def run_task( if context is None: context = self.get_run_context() try: - return task.run(context=context, extra_args=task.invocation[1:]) + return task.run(context=context) except PoeException as error: self.print_help(error=error) return 1 @@ -175,9 +175,7 @@ def run_task_graph(self, task: "PoeTask") -> Optional[int]: return self.run_task(stage_task, context) try: - task_result = stage_task.run( - context=context, extra_args=stage_task.invocation[1:] - ) + task_result = stage_task.run(context=context) if task_result: raise ExecutionError( f"Task graph aborted after failed task {stage_task.name!r}" diff --git a/poethepoet/task/args.py b/poethepoet/task/args.py index aeaba47cb..ec32e91e2 100644 --- a/poethepoet/task/args.py +++ b/poethepoet/task/args.py @@ -298,13 +298,15 @@ def _get_argument_params(self, arg: ArgParams): return result - def parse(self, extra_args: Sequence[str]): - parsed_args = vars(self.build_parser().parse_args(extra_args)) + def parse(self, args: Sequence[str]) -> Dict[str, str]: + parsed_args = vars(self.build_parser().parse_args(args)) + # Ensure positional args are still exposed by name even if they were parsed with # alternate identifiers for arg in self._args: if isinstance(arg.get("positional"), str): parsed_args[arg["name"]] = parsed_args[arg["positional"]] del parsed_args[arg["positional"]] + # args named with dash case are converted to snake case before being exposed return {name.replace("-", "_"): value for name, value in parsed_args.items()} diff --git a/poethepoet/task/base.py b/poethepoet/task/base.py index 0e4cf3789..850fd5263 100644 --- a/poethepoet/task/base.py +++ b/poethepoet/task/base.py @@ -1,5 +1,6 @@ import re import sys +from os import environ from pathlib import Path from typing import ( TYPE_CHECKING, @@ -10,7 +11,6 @@ List, NamedTuple, Optional, - Sequence, Tuple, Type, Union, @@ -65,7 +65,7 @@ class PoeTask(metaclass=MetaPoeTask): content: TaskContent options: Dict[str, Any] inheritance: TaskInheritance - named_args: Optional[Dict[str, str]] = None + _parsed_args: Optional[Tuple[Dict[str, str], Tuple[str, ...]]] = None __options__: Dict[str, Union[Type, Tuple[Type, ...]]] = {} __content_type__: Type = str @@ -216,42 +216,49 @@ def resolve_task_type( return None - def get_named_arg_values(self, env: "EnvVarsManager") -> Dict[str, str]: - try: - split_index = self.invocation.index("--") - parse_args = self.invocation[1:split_index] - except ValueError: - parse_args = self.invocation[1:] + def get_parsed_arguments( + self, env: "EnvVarsManager" + ) -> Tuple[Dict[str, str], Tuple[str, ...]]: + if self._parsed_args is None: + all_args = self.invocation[1:] - if self.named_args is None: - self.named_args = self._parse_named_args(parse_args, env) - - if not self.named_args: - return {} - - return self.named_args + if args_def := self.options.get("args"): + from .args import PoeTaskArgs - def _parse_named_args( - self, extra_args: Sequence[str], env: "EnvVarsManager" - ) -> Optional[Dict[str, str]]: - if args_def := self.options.get("args"): - from .args import PoeTaskArgs + try: + split_index = all_args.index("--") + option_args = all_args[:split_index] + extra_args = all_args[split_index + 1 :] + except ValueError: + option_args = all_args + extra_args = tuple() + + self._parsed_args = ( + PoeTaskArgs(args_def, self.name, self._ui.program_name, env).parse( + option_args + ), + extra_args, + ) - return PoeTaskArgs(args_def, self.name, self._ui.program_name, env).parse( - extra_args - ) + else: + self._parsed_args = ({}, all_args) - return None + return self._parsed_args def run( self, context: "RunContext", - extra_args: Sequence[str] = tuple(), parent_env: Optional["EnvVarsManager"] = None, ) -> int: """ Run this task """ + + if environ.get("POE_DEBUG"): + task_type_key = self.__key__ # type: ignore[attr-defined] + print(f" * Running {task_type_key}:{self.name}") + print(f" . Invocation {self.invocation!r}") + upstream_invocations = self._get_upstream_invocations(context) if context.dry and upstream_invocations.get("uses", {}): @@ -265,21 +272,23 @@ def run( ) return 0 - return self._handle_run( - context, - extra_args, - context.get_task_env( - parent_env, - self.options.get("envfile"), - self.options.get("env"), - upstream_invocations["uses"], - ), + task_env = context.get_task_env( + parent_env, + self.options.get("envfile"), + self.options.get("env"), + upstream_invocations["uses"], ) + if environ.get("POE_DEBUG"): + named_arg_values, extra_args = self.get_parsed_arguments(task_env) + print(f" . Parsed args {named_arg_values!r}") + print(f" . Extra args {extra_args!r}") + + return self._handle_run(context, task_env) + def _handle_run( self, context: "RunContext", - extra_args: Sequence[str], env: "EnvVarsManager", ) -> int: """ @@ -288,11 +297,7 @@ def _handle_run( """ raise NotImplementedError - def _get_executor( - self, - context: "RunContext", - env: "EnvVarsManager", - ): + def _get_executor(self, context: "RunContext", env: "EnvVarsManager"): return context.get_executor( self.invocation, env, @@ -335,7 +340,7 @@ def _get_upstream_invocations(self, context: "RunContext"): env = context.get_task_env( None, self.options.get("envfile"), self.options.get("env") ) - env.update(self.get_named_arg_values(env)) + env.update(self.get_parsed_arguments(env)[0]) self.__upstream_invocations = { "deps": [ diff --git a/poethepoet/task/cmd.py b/poethepoet/task/cmd.py index 7db949a0e..22eb8a58e 100644 --- a/poethepoet/task/cmd.py +++ b/poethepoet/task/cmd.py @@ -1,5 +1,5 @@ import shlex -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union from ..exceptions import PoeException from .base import PoeTask @@ -25,22 +25,12 @@ class CmdTask(PoeTask): def _handle_run( self, context: "RunContext", - extra_args: Sequence[str], env: "EnvVarsManager", ) -> int: - named_arg_values = self.get_named_arg_values(env) + named_arg_values, extra_args = self.get_parsed_arguments(env) env.update(named_arg_values) - if named_arg_values: - # If named arguments are defined then pass only arguments following a double - # dash token: `--` - try: - split_index = extra_args.index("--") - extra_args = extra_args[split_index + 1 :] - except ValueError: - extra_args = tuple() - - cmd = (*self._resolve_args(context, env), *extra_args) + cmd = (*self._resolve_commandline(context, env), *extra_args) self._print_action(shlex.join(cmd), context.dry) @@ -48,7 +38,7 @@ def _handle_run( cmd, use_exec=self.options.get("use_exec", False) ) - def _resolve_args(self, context: "RunContext", env: "EnvVarsManager"): + def _resolve_commandline(self, context: "RunContext", env: "EnvVarsManager"): from ..helpers.command import parse_poe_cmd, resolve_command_tokens from ..helpers.command.ast_core import ParseError diff --git a/poethepoet/task/expr.py b/poethepoet/task/expr.py index 450e020c2..ae5a5e0f5 100644 --- a/poethepoet/task/expr.py +++ b/poethepoet/task/expr.py @@ -6,7 +6,6 @@ Iterable, Mapping, Optional, - Sequence, Tuple, Type, Union, @@ -38,20 +37,21 @@ class ExprTask(PoeTask): def _handle_run( self, context: "RunContext", - extra_args: Sequence[str], env: "EnvVarsManager", ) -> int: from ..helpers.python import format_class - named_arg_values = self.get_named_arg_values(env) + named_arg_values, extra_args = self.get_parsed_arguments(env) env.update(named_arg_values) + # TODO: do something about extra_args, error? + imports = self.options.get("imports", tuple()) expr, env_values = self.parse_content(named_arg_values, env, imports) argv = [ self.name, - *(env.fill_template(token) for token in extra_args), + *(env.fill_template(token) for token in self.invocation[1:]), ] script = [ diff --git a/poethepoet/task/ref.py b/poethepoet/task/ref.py index 0612ba82e..d637cff89 100644 --- a/poethepoet/task/ref.py +++ b/poethepoet/task/ref.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union from .base import PoeTask, TaskInheritance @@ -21,7 +21,6 @@ class RefTask(PoeTask): def _handle_run( self, context: "RunContext", - extra_args: Sequence[str], env: "EnvVarsManager", ) -> int: """ @@ -29,26 +28,34 @@ def _handle_run( """ import shlex - invocation = tuple(shlex.split(env.fill_template(self.content.strip()))) - extra_args = [*invocation[1:], *extra_args] + named_arg_values, extra_args = self.get_parsed_arguments(env) + env.update(named_arg_values) + + ref_invocation = ( + *( + env.fill_template(token) + for token in shlex.split(env.fill_template(self.content.strip())) + ), + *extra_args, + ) + task = self.from_config( - invocation[0], + ref_invocation[0], self._config, self._ui, - invocation, + invocation=ref_invocation, inheritance=TaskInheritance.from_task(self), ) if task.has_deps(): - return self._run_task_graph(task, context, extra_args, env) + return self._run_task_graph(task, context, env) - return task.run(context=context, extra_args=extra_args, parent_env=env) + return task.run(context=context, parent_env=env) def _run_task_graph( self, task: "PoeTask", context: "RunContext", - extra_args: Sequence[str], env: "EnvVarsManager", ) -> int: from ..exceptions import ExecutionError @@ -60,13 +67,9 @@ def _run_task_graph( for stage_task in stage: if stage_task == task: # The final sink task gets special treatment - return task.run( - context=context, extra_args=extra_args, parent_env=env - ) + return task.run(context=context, parent_env=env) - task_result = stage_task.run( - context=context, extra_args=stage_task.invocation[1:] - ) + task_result = stage_task.run(context=context) if task_result: raise ExecutionError( f"Task graph aborted after failed task {stage_task.name!r}" diff --git a/poethepoet/task/script.py b/poethepoet/task/script.py index 3f4b84730..5ff6dc9ed 100644 --- a/poethepoet/task/script.py +++ b/poethepoet/task/script.py @@ -1,6 +1,6 @@ import re import shlex -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union from ..exceptions import ExpressionParseError from .base import PoeTask @@ -27,20 +27,21 @@ class ScriptTask(PoeTask): def _handle_run( self, context: "RunContext", - extra_args: Sequence[str], env: "EnvVarsManager", ) -> int: from ..helpers.python import format_class - named_arg_values = self.get_named_arg_values(env) + named_arg_values, extra_args = self.get_parsed_arguments(env) env.update(named_arg_values) + # TODO: do something about extra_args, error? + target_module, function_call = self.parse_content(named_arg_values) function_ref = function_call[: function_call.index("(")] argv = [ self.name, - *(env.fill_template(token) for token in extra_args), + *(env.fill_template(token) for token in self.invocation[1:]), ] # TODO: check whether the project really does use src layout, and don't do diff --git a/poethepoet/task/sequence.py b/poethepoet/task/sequence.py index 09999e2c0..805ed1a5e 100644 --- a/poethepoet/task/sequence.py +++ b/poethepoet/task/sequence.py @@ -4,7 +4,6 @@ Dict, List, Optional, - Sequence, Tuple, Type, Union, @@ -69,13 +68,12 @@ def __init__( def _handle_run( self, context: "RunContext", - extra_args: Sequence[str], env: "EnvVarsManager", ) -> int: - named_arg_values = self.get_named_arg_values(env) + named_arg_values, extra_args = self.get_parsed_arguments(env) env.update(named_arg_values) - if not named_arg_values and any(arg.strip() for arg in extra_args): + if not named_arg_values and any(arg.strip() for arg in self.invocation[1:]): raise PoeException(f"Sequence task {self.name!r} does not accept arguments") if len(self.subtasks) > 1: @@ -85,9 +83,7 @@ def _handle_run( ignore_fail = self.options.get("ignore_fail") non_zero_subtasks: List[str] = list() for subtask in self.subtasks: - task_result = subtask.run( - context=context, extra_args=tuple(), parent_env=env - ) + task_result = subtask.run(context=context, parent_env=env) if task_result and not ignore_fail: raise ExecutionError( f"Sequence aborted after failed subtask {subtask.name!r}" diff --git a/poethepoet/task/shell.py b/poethepoet/task/shell.py index f447737f4..4feb1380d 100644 --- a/poethepoet/task/shell.py +++ b/poethepoet/task/shell.py @@ -6,7 +6,6 @@ Dict, List, Optional, - Sequence, Tuple, Type, Union, @@ -34,13 +33,12 @@ class ShellTask(PoeTask): def _handle_run( self, context: "RunContext", - extra_args: Sequence[str], env: "EnvVarsManager", ) -> int: - named_arg_values = self.get_named_arg_values(env) + named_arg_values, extra_args = self.get_parsed_arguments(env) env.update(named_arg_values) - if not named_arg_values and any(arg.strip() for arg in extra_args): + if not named_arg_values and any(arg.strip() for arg in self.invocation[1:]): raise PoeException(f"Shell task {self.name!r} does not accept arguments") interpreter_cmd = self.resolve_interpreter_cmd() diff --git a/poethepoet/task/switch.py b/poethepoet/task/switch.py index 719728c3a..977508dcd 100644 --- a/poethepoet/task/switch.py +++ b/poethepoet/task/switch.py @@ -5,7 +5,6 @@ List, MutableMapping, Optional, - Sequence, Tuple, Type, Union, @@ -94,23 +93,18 @@ def __init__( def _handle_run( self, context: "RunContext", - extra_args: Sequence[str], env: "EnvVarsManager", ) -> int: - named_arg_values = self.get_named_arg_values(env) + named_arg_values, extra_args = self.get_parsed_arguments(env) env.update(named_arg_values) - if not named_arg_values and any(arg.strip() for arg in extra_args): + if not named_arg_values and any(arg.strip() for arg in self.invocation[1:]): raise PoeException(f"Switch task {self.name!r} does not accept arguments") # Indicate on the global context that there are multiple stages to this task context.multistage = True - task_result = self.control_task.run( - context=context, - extra_args=extra_args if self.options.get("args") else tuple(), - parent_env=env, - ) + task_result = self.control_task.run(context=context, parent_env=env) if task_result: raise ExecutionError( f"Switch task {self.name!r} aborted after failed control task" @@ -135,7 +129,7 @@ def _handle_run( f"switch task {self.name!r}." ) - return case_task.run(context=context, extra_args=extra_args, parent_env=env) + return case_task.run(context=context, parent_env=env) @classmethod def _get_case_keys(cls, task_def: Dict[str, Any]) -> List[Any]: diff --git a/tests/fixtures/refs_project/pyproject.toml b/tests/fixtures/refs_project/pyproject.toml index a15407b5f..2c2f3a302 100644 --- a/tests/fixtures/refs_project/pyproject.toml +++ b/tests/fixtures/refs_project/pyproject.toml @@ -10,3 +10,14 @@ ref = "greet lol!" [tool.poe.tasks.greet-dave] ref = "greet-subject --subject dave" + +[tool.poe.tasks.apologize] +cmd = "echo \"I'm sorry ${name}, ${explain}\"" +args = ["name", "explain"] + +[tool.poe.tasks.say-sorry] +ref = "apologize" + +[tool.poe.tasks.sorry-dave] +ref = "apologize --name=Dave --explain='${explain}'" +args = [{ name = "explain", positional = true, multiple = true }] diff --git a/tests/test_ref_task.py b/tests/test_ref_task.py index 15ab36210..82f986109 100644 --- a/tests/test_ref_task.py +++ b/tests/test_ref_task.py @@ -10,3 +10,61 @@ def test_ref_passes_extra_args_in_definition(run_poe_subproc): assert result.capture == "Poe => poe_test_echo hi 'lol!'\n" assert result.stdout == "hi lol!\n" assert result.stderr == "" + + +def test_ref_parses_named_args(run_poe_subproc): + result = run_poe_subproc( + "apologize", "--name=Davey", "--explain=Ah cannae dae that", project="refs" + ) + assert ( + result.capture == """Poe => echo 'I'"'"'m sorry Davey, Ah cannae dae that'\n""" + ) + assert result.stdout == "I'm sorry Davey, Ah cannae dae that\n" + assert result.stderr == "" + + +def test_ref_forwards_arguments_if_none_defined(run_poe_subproc): + result = run_poe_subproc( + "say-sorry", + "--name=Davey", + "--explain=Ah cannae dae that", + "--", + ",", + "anything", + "else?", + project="refs", + ) + assert result.capture == ( + """Poe => echo 'I'"'"'m sorry Davey, Ah cannae dae that'""" + """ , anything 'else?'\n""" + ) + assert result.stdout == "I'm sorry Davey, Ah cannae dae that , anything else?\n" + assert result.stderr == "" + + +def test_ref_forwards_arguments(run_poe_subproc): + result = run_poe_subproc( + "sorry-dave", + "I", + "cant", + "do", + "that", + "--", + "--", + ",", + "anything", + "else?", + project="refs", + ) + assert ( + result.capture + == """Poe => echo 'I'"'"'m sorry Dave, I cant do that' , anything 'else?'\n""" + ) + assert result.stdout == "I'm sorry Dave, I cant do that , anything else?\n" + assert result.stderr == "" + + # Pass extra args including -- if no args defined + result = run_poe_subproc("greet-funny", "OK", project="refs") + assert result.capture == "Poe => poe_test_echo hi 'lol!' OK\n" + assert result.stdout == "hi lol! OK\n" + assert result.stderr == ""