Skip to content

Commit

Permalink
Merge pull request #584 from NicolasGensollen/add-graph-checksums-if-…
Browse files Browse the repository at this point in the history
…missing

[ENH] Add `_graph_checksums` to `input_spec` if missing
  • Loading branch information
satra authored Sep 28, 2022
2 parents 9053ba7 + a7374aa commit 8c8a79c
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 31 deletions.
90 changes: 60 additions & 30 deletions pydra/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,64 @@ def _reset(self):
task._reset()


def _sanitize_input_spec(
input_spec: ty.Union[SpecInfo, ty.List[str]],
wf_name: str,
) -> SpecInfo:
"""Makes sure the provided input specifications are valid.
If the input specification is a list of strings, this will
build a proper SpecInfo object out of it.
Parameters
----------
input_spec : SpecInfo or List[str]
Input specification to be sanitized.
wf_name : str
The name of the workflow for which the input specifications
are sanitized.
Returns
-------
input_spec : SpecInfo
Sanitized input specifications.
Raises
------
ValueError
If provided `input_spec` is None.
"""
graph_checksum_input = ("_graph_checksums", ty.Any)
if input_spec:
if isinstance(input_spec, SpecInfo):
if not any([x == BaseSpec for x in input_spec.bases]):
raise ValueError("Provided SpecInfo must have BaseSpec as it's base.")
if "_graph_checksums" not in {f[0] for f in input_spec.fields}:
input_spec.fields.insert(0, graph_checksum_input)
return input_spec
else:
return SpecInfo(
name="Inputs",
fields=[graph_checksum_input]
+ [
(
nm,
attr.ib(
type=ty.Any,
metadata={
"help_string": f"{nm} input from {wf_name} workflow"
},
),
)
for nm in input_spec
],
bases=(BaseSpec,),
)
else:
raise ValueError(f"Empty input_spec provided to Workflow {wf_name}.")


class Workflow(TaskBase):
"""A composite task with structure of computational graph."""

Expand All @@ -806,7 +864,7 @@ def __init__(
audit_flags: AuditFlag = AuditFlag.NONE,
cache_dir=None,
cache_locations=None,
input_spec: ty.Optional[ty.Union[ty.List[ty.Text], SpecInfo, BaseSpec]] = None,
input_spec: ty.Optional[ty.Union[ty.List[ty.Text], SpecInfo]] = None,
cont_dim=None,
messenger_args=None,
messengers=None,
Expand Down Expand Up @@ -842,35 +900,7 @@ def __init__(
TODO
"""
if input_spec:
if isinstance(input_spec, BaseSpec):
self.input_spec = input_spec
elif isinstance(input_spec, SpecInfo):
if not any([x == BaseSpec for x in input_spec.bases]):
raise ValueError(
"Provided SpecInfo must have BaseSpec as it's base."
)
self.input_spec = input_spec
else:
self.input_spec = SpecInfo(
name="Inputs",
fields=[("_graph_checksums", ty.Any)]
+ [
(
nm,
attr.ib(
type=ty.Any,
metadata={
"help_string": f"{nm} input from {name} workflow"
},
),
)
for nm in input_spec
],
bases=(BaseSpec,),
)
else:
raise ValueError("Empty input_spec provided to Workflow")
self.input_spec = _sanitize_input_spec(input_spec, name)

self.output_spec = output_spec

Expand Down
2 changes: 1 addition & 1 deletion pydra/engine/tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_wf_specinfo_input_spec():
name="workflow",
input_spec=input_spec,
)
for x in ["a", "b"]:
for x in ["a", "b", "_graph_checksums"]:
assert hasattr(wf.inputs, x)
assert wf.inputs.a == ""
assert wf.inputs.b == {"foo": 1, "bar": False}
Expand Down

0 comments on commit 8c8a79c

Please sign in to comment.