Skip to content

Commit

Permalink
Factory: also parse command-line options
Browse files Browse the repository at this point in the history
Co-authored-by: Michael R. Crusoe <[email protected]>
  • Loading branch information
suecharo and mr-c committed Oct 24, 2024
1 parent e3f6cf7 commit 6be8e4d
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 16 deletions.
87 changes: 74 additions & 13 deletions cwltool/factory.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
"""Wrap a CWL document as a callable Python object."""

import argparse
import functools
import os
import sys
from typing import Any, Optional, Union

from . import load_tool
from .context import LoadingContext, RuntimeContext
from .argparser import arg_parser
from .context import LoadingContext, RuntimeContext, getdefault
from .errors import WorkflowException
from .executors import JobExecutor, SingleJobExecutor
from .main import find_default_container
from .process import Process
from .utils import CWLObjectType
from .resolver import tool_resolver
from .secrets import SecretStore
from .utils import DEFAULT_TMP_PREFIX, CWLObjectType


class WorkflowStatus(Exception):
Expand All @@ -25,11 +34,15 @@ def __init__(self, t: Process, factory: "Factory") -> None:
self.t = t
self.factory = factory

def __call__(self, **kwargs):
# type: (**Any) -> Union[str, Optional[CWLObjectType]]
runtime_context = self.factory.runtime_context.copy()
runtime_context.basedir = os.getcwd()
out, status = self.factory.executor(self.t, kwargs, runtime_context)
def __call__(self, **kwargs: Any) -> Union[str, Optional[CWLObjectType]]:
"""
Execute the process.
:raise WorkflowStatus: If the result is not a success.
"""
if not self.factory.runtime_context.basedir:
self.factory.runtime_context.basedir = os.getcwd()
out, status = self.factory.executor(self.t, kwargs, self.factory.runtime_context)
if status != "success":
raise WorkflowStatus(out, status)
else:
Expand All @@ -47,18 +60,24 @@ def __init__(
executor: Optional[JobExecutor] = None,
loading_context: Optional[LoadingContext] = None,
runtime_context: Optional[RuntimeContext] = None,
argsl: Optional[list[str]] = None,
args: Optional[argparse.Namespace] = None,
) -> None:
"""Create a CWL Process factory from a CWL document."""
if argsl is not None:
args = arg_parser().parse_args(argsl)
if executor is None:
executor = SingleJobExecutor()
self.executor = executor
self.executor: JobExecutor = SingleJobExecutor()
else:
self.executor = executor
if runtime_context is None:
self.runtime_context = RuntimeContext()
self.runtime_context = RuntimeContext(vars(args) if args else {})
self._fix_runtime_context()
else:
self.runtime_context = runtime_context
if loading_context is None:
self.loading_context = LoadingContext()
self.loading_context.singularity = self.runtime_context.singularity
self.loading_context.podman = self.runtime_context.podman
self.loading_context = LoadingContext(vars(args) if args else {})
self._fix_loading_context(self.runtime_context)
else:
self.loading_context = loading_context

Expand All @@ -68,3 +87,45 @@ def make(self, cwl: Union[str, dict[str, Any]]) -> Callable:
if isinstance(load, int):
raise WorkflowException("Error loading tool")
return Callable(load, self)

def _fix_loading_context(self, runtime_context: RuntimeContext) -> None:
self.loading_context.resolver = getdefault(self.loading_context.resolver, tool_resolver)
self.loading_context.singularity = runtime_context.singularity
self.loading_context.podman = runtime_context.podman

def _fix_runtime_context(self) -> None:
self.runtime_context.basedir = os.getcwd()
self.runtime_context.find_default_container = functools.partial(
find_default_container, default_container=None, use_biocontainers=None
)

if sys.platform == "darwin":
default_mac_path = "/private/tmp/docker_tmp"
if self.runtimeContext.tmp_outdir_prefix == DEFAULT_TMP_PREFIX:
self.runtimeContext.tmp_outdir_prefix = default_mac_path

for dirprefix in ("tmpdir_prefix", "tmp_outdir_prefix", "cachedir"):
if (
getattr(self.runtime_context, dirprefix)
and getattr(self.runtime_context, dirprefix) != DEFAULT_TMP_PREFIX
):
sl = (
"/"
if getattr(self.runtime_context, dirprefix).endswith("/")
or dirprefix == "cachedir"
else ""
)
setattr(
self.runtime_context,
dirprefix,
os.path.abspath(getattr(self.runtime_context, dirprefix)) + sl,
)
if not os.path.exists(os.path.dirname(getattr(self.runtime_context, dirprefix))):
try:
os.makedirs(os.path.dirname(getattr(self.runtime_context, dirprefix)))
except Exception as e:
print("Failed to create directory: %s", e)

self.runtime_context.secret_store = getdefault(
self.runtime_context.secret_store, SecretStore()
)
2 changes: 1 addition & 1 deletion tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_replace_default_stdout_stderr() -> None:
runtime_context = RuntimeContext()
runtime_context.default_stdout = subprocess.DEVNULL # type: ignore
runtime_context.default_stderr = subprocess.DEVNULL # type: ignore
factory = Factory(None, None, runtime_context)
factory = Factory(runtime_context=runtime_context)
echo = factory.make(get_data("tests/echo.cwl"))

assert echo(inp="foo") == {"out": "foo\n"}
Expand Down
4 changes: 2 additions & 2 deletions tests/test_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_sequential_workflow(tmp_path: Path) -> None:
runtime_context = RuntimeContext()
runtime_context.outdir = str(tmp_path)
runtime_context.select_resources = executor.select_resources
factory = Factory(executor, None, runtime_context)
factory = Factory(executor=executor, runtime_context=runtime_context)
echo = factory.make(get_data(test_file))
file_contents = {"class": "File", "location": get_data("tests/wf/whale.txt")}
assert echo(file1=file_contents) == {"count_output": 16}
Expand All @@ -25,7 +25,7 @@ def test_sequential_workflow(tmp_path: Path) -> None:
def test_scattered_workflow() -> None:
test_file = "tests/wf/scatter-wf4.cwl"
job_file = "tests/wf/scatter-job2.json"
factory = Factory(MultithreadedJobExecutor())
factory = Factory(executor=MultithreadedJobExecutor())
echo = factory.make(get_data(test_file))
with open(get_data(job_file)) as job:
assert echo(**json.load(job)) == {"out": ["foo one three", "foo two four"]}

0 comments on commit 6be8e4d

Please sign in to comment.