Skip to content

Commit

Permalink
WIP make args work as expected on ref tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
nat-n committed Feb 24, 2024
1 parent 43cc75b commit 7d09962
Show file tree
Hide file tree
Showing 10 changed files with 95 additions and 107 deletions.
6 changes: 2 additions & 4 deletions poethepoet/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def run_task(
if context is None:
context = self.get_run_context()
try:
return task.run(context=context, extra_args=task.invocation[1:])
return task.run(context=context)
except PoeException as error:
self.print_help(error=error)
return 1
Expand All @@ -175,9 +175,7 @@ def run_task_graph(self, task: "PoeTask") -> Optional[int]:
return self.run_task(stage_task, context)

try:
task_result = stage_task.run(
context=context, extra_args=stage_task.invocation[1:]
)
task_result = stage_task.run(context=context)
if task_result:
raise ExecutionError(
f"Task graph aborted after failed task {stage_task.name!r}"
Expand Down
6 changes: 4 additions & 2 deletions poethepoet/task/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,13 +298,15 @@ def _get_argument_params(self, arg: ArgParams):

return result

def parse(self, extra_args: Sequence[str]):
parsed_args = vars(self.build_parser().parse_args(extra_args))
def parse(self, args: Sequence[str]) -> Dict[str, str]:
parsed_args = vars(self.build_parser().parse_args(args))

# Ensure positional args are still exposed by name even if they were parsed with
# alternate identifiers
for arg in self._args:
if isinstance(arg.get("positional"), str):
parsed_args[arg["name"]] = parsed_args[arg["positional"]]
del parsed_args[arg["positional"]]

# args named with dash case are converted to snake case before being exposed
return {name.replace("-", "_"): value for name, value in parsed_args.items()}
87 changes: 46 additions & 41 deletions poethepoet/task/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
import sys
from os import environ
from pathlib import Path
from typing import (
TYPE_CHECKING,
Expand All @@ -10,7 +11,6 @@
List,
NamedTuple,
Optional,
Sequence,
Tuple,
Type,
Union,
Expand Down Expand Up @@ -65,7 +65,7 @@ class PoeTask(metaclass=MetaPoeTask):
content: TaskContent
options: Dict[str, Any]
inheritance: TaskInheritance
named_args: Optional[Dict[str, str]] = None
_parsed_args: Optional[Tuple[Dict[str, str], Tuple[str, ...]]] = None

__options__: Dict[str, Union[Type, Tuple[Type, ...]]] = {}
__content_type__: Type = str
Expand Down Expand Up @@ -216,42 +216,49 @@ def resolve_task_type(

return None

def get_named_arg_values(self, env: "EnvVarsManager") -> Dict[str, str]:
try:
split_index = self.invocation.index("--")
parse_args = self.invocation[1:split_index]
except ValueError:
parse_args = self.invocation[1:]
def get_parsed_arguments(
self, env: "EnvVarsManager"
) -> Tuple[Dict[str, str], Tuple[str, ...]]:
if self._parsed_args is None:
all_args = self.invocation[1:]

if self.named_args is None:
self.named_args = self._parse_named_args(parse_args, env)

if not self.named_args:
return {}

return self.named_args
if args_def := self.options.get("args"):
from .args import PoeTaskArgs

def _parse_named_args(
self, extra_args: Sequence[str], env: "EnvVarsManager"
) -> Optional[Dict[str, str]]:
if args_def := self.options.get("args"):
from .args import PoeTaskArgs
try:
split_index = all_args.index("--")
option_args = all_args[:split_index]
extra_args = all_args[split_index + 1 :]
except ValueError:
option_args = all_args
extra_args = tuple()

self._parsed_args = (
PoeTaskArgs(args_def, self.name, self._ui.program_name, env).parse(
option_args
),
extra_args,
)

return PoeTaskArgs(args_def, self.name, self._ui.program_name, env).parse(
extra_args
)
else:
self._parsed_args = ({}, all_args)

return None
return self._parsed_args

def run(
self,
context: "RunContext",
extra_args: Sequence[str] = tuple(),
parent_env: Optional["EnvVarsManager"] = None,
) -> int:
"""
Run this task
"""

if environ.get("POE_DEBUG"):
task_type_key = self.__key__ # type: ignore[attr-defined]
print(f" * Running {task_type_key}:{self.name}")
print(f" . Invocation {self.invocation!r}")

upstream_invocations = self._get_upstream_invocations(context)

if context.dry and upstream_invocations.get("uses", {}):
Expand All @@ -265,21 +272,23 @@ def run(
)
return 0

return self._handle_run(
context,
extra_args,
context.get_task_env(
parent_env,
self.options.get("envfile"),
self.options.get("env"),
upstream_invocations["uses"],
),
task_env = context.get_task_env(
parent_env,
self.options.get("envfile"),
self.options.get("env"),
upstream_invocations["uses"],
)

if environ.get("POE_DEBUG"):
named_arg_values, extra_args = self.get_parsed_arguments(task_env)
print(f" . Parsed args {named_arg_values!r}")
print(f" . Extra args {extra_args!r}")

return self._handle_run(context, task_env)

def _handle_run(
self,
context: "RunContext",
extra_args: Sequence[str],
env: "EnvVarsManager",
) -> int:
"""
Expand All @@ -288,11 +297,7 @@ def _handle_run(
"""
raise NotImplementedError

def _get_executor(
self,
context: "RunContext",
env: "EnvVarsManager",
):
def _get_executor(self, context: "RunContext", env: "EnvVarsManager"):
return context.get_executor(
self.invocation,
env,
Expand Down Expand Up @@ -335,7 +340,7 @@ def _get_upstream_invocations(self, context: "RunContext"):
env = context.get_task_env(
None, self.options.get("envfile"), self.options.get("env")
)
env.update(self.get_named_arg_values(env))
env.update(self.get_parsed_arguments(env)[0])

self.__upstream_invocations = {
"deps": [
Expand Down
18 changes: 4 additions & 14 deletions poethepoet/task/cmd.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import shlex
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Type, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union

from ..exceptions import PoeException
from .base import PoeTask
Expand All @@ -25,30 +25,20 @@ class CmdTask(PoeTask):
def _handle_run(
self,
context: "RunContext",
extra_args: Sequence[str],
env: "EnvVarsManager",
) -> int:
named_arg_values = self.get_named_arg_values(env)
named_arg_values, extra_args = self.get_parsed_arguments(env)
env.update(named_arg_values)

if named_arg_values:
# If named arguments are defined then pass only arguments following a double
# dash token: `--`
try:
split_index = extra_args.index("--")
extra_args = extra_args[split_index + 1 :]
except ValueError:
extra_args = tuple()

cmd = (*self._resolve_args(context, env), *extra_args)
cmd = (*self._resolve_commandline(context, env), *extra_args)

self._print_action(shlex.join(cmd), context.dry)

return self._get_executor(context, env).execute(
cmd, use_exec=self.options.get("use_exec", False)
)

def _resolve_args(self, context: "RunContext", env: "EnvVarsManager"):
def _resolve_commandline(self, context: "RunContext", env: "EnvVarsManager"):
from ..helpers.command import parse_poe_cmd, resolve_command_tokens
from ..helpers.command.ast_core import ParseError

Expand Down
8 changes: 4 additions & 4 deletions poethepoet/task/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
Iterable,
Mapping,
Optional,
Sequence,
Tuple,
Type,
Union,
Expand Down Expand Up @@ -38,20 +37,21 @@ class ExprTask(PoeTask):
def _handle_run(
self,
context: "RunContext",
extra_args: Sequence[str],
env: "EnvVarsManager",
) -> int:
from ..helpers.python import format_class

named_arg_values = self.get_named_arg_values(env)
named_arg_values, extra_args = self.get_parsed_arguments(env)
env.update(named_arg_values)

# TODO: do something about extra_args, error?

imports = self.options.get("imports", tuple())

expr, env_values = self.parse_content(named_arg_values, env, imports)
argv = [
self.name,
*(env.fill_template(token) for token in extra_args),
*(env.fill_template(token) for token in self.invocation[1:]),
]

script = [
Expand Down
38 changes: 21 additions & 17 deletions poethepoet/task/ref.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Type, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union

from .base import PoeTask, TaskInheritance

Expand All @@ -21,34 +21,42 @@ class RefTask(PoeTask):
def _handle_run(
self,
context: "RunContext",
extra_args: Sequence[str],
env: "EnvVarsManager",
) -> int:
"""
Lookup and delegate to the referenced task
"""
import shlex

invocation = tuple(shlex.split(env.fill_template(self.content.strip())))
extra_args = [*invocation[1:], *extra_args]
task = self.from_config(
invocation[0],
named_arg_values, extra_args = self.get_parsed_arguments(env)
env.update(named_arg_values)

ref_invocation = [
env.fill_template(token)
for token in shlex.split(env.fill_template(self.content.strip()))
]
if not named_arg_values:
ref_invocation.extend(self.invocation[1:])
else:
ref_invocation.extend(extra_args)

self.ref_task = self.from_config(
ref_invocation[0],
self._config,
self._ui,
invocation,
invocation=tuple(ref_invocation),
inheritance=TaskInheritance.from_task(self),
)

if task.has_deps():
return self._run_task_graph(task, context, extra_args, env)
if self.ref_task.has_deps():
return self._run_task_graph(self.ref_task, context, env)

return task.run(context=context, extra_args=extra_args, parent_env=env)
return self.ref_task.run(context=context, parent_env=env)

def _run_task_graph(
self,
task: "PoeTask",
context: "RunContext",
extra_args: Sequence[str],
env: "EnvVarsManager",
) -> int:
from ..exceptions import ExecutionError
Expand All @@ -60,13 +68,9 @@ def _run_task_graph(
for stage_task in stage:
if stage_task == task:
# The final sink task gets special treatment
return task.run(
context=context, extra_args=extra_args, parent_env=env
)
return task.run(context=context, parent_env=env)

task_result = stage_task.run(
context=context, extra_args=stage_task.invocation[1:]
)
task_result = stage_task.run(context=context)
if task_result:
raise ExecutionError(
f"Task graph aborted after failed task {stage_task.name!r}"
Expand Down
9 changes: 5 additions & 4 deletions poethepoet/task/script.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import re
import shlex
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Type, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union

from ..exceptions import ExpressionParseError
from .base import PoeTask
Expand All @@ -27,20 +27,21 @@ class ScriptTask(PoeTask):
def _handle_run(
self,
context: "RunContext",
extra_args: Sequence[str],
env: "EnvVarsManager",
) -> int:
from ..helpers.python import format_class

named_arg_values = self.get_named_arg_values(env)
named_arg_values, extra_args = self.get_parsed_arguments(env)
env.update(named_arg_values)

# TODO: do something about extra_args, error?

target_module, function_call = self.parse_content(named_arg_values)
function_ref = function_call[: function_call.index("(")]

argv = [
self.name,
*(env.fill_template(token) for token in extra_args),
*(env.fill_template(token) for token in self.invocation[1:]),
]

# TODO: check whether the project really does use src layout, and don't do
Expand Down
Loading

0 comments on commit 7d09962

Please sign in to comment.