Skip to content

Commit

Permalink
Merge pull request #667 from pfebrer/workflow_state
Browse files Browse the repository at this point in the history
ensure that workflows have (and transmit) the right state
  • Loading branch information
zerothi authored Jan 4, 2024
2 parents 25c1b3a + bd0ba27 commit e10b77b
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 1 deletion.
48 changes: 48 additions & 0 deletions src/sisl/nodes/tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

from sisl.nodes import Node, Workflow
from sisl.nodes.node import ConstantNode
from sisl.nodes.utils import traverse_tree_forward


Expand Down Expand Up @@ -186,3 +187,50 @@ def some_multiplication(a, b, c, d, e, f):

assert first_triple_sum.nodes[triple_sum._sum_key]._nupdates == 1
assert first_triple_sum.nodes[f"{triple_sum._sum_key}_1"]._nupdates == 2


def test_outdated(triple_sum):
wf = triple_sum(2, 3, 4)

# Test that the workflow knows it is outdated in all possible situations
assert wf._outdated == True

wf.get()
assert wf._outdated == False

wf.update_inputs(a=3)
assert wf._outdated == True

inp = ConstantNode(3)
wf.update_inputs(a=inp)

wf.get()
assert wf._outdated == False
inp.update_inputs(value=5)
assert wf._outdated == True

# Now test that the workflow lets connected nodes that they are outdated.
out = ConstantNode(wf)

out.get()
assert out._outdated == False

wf.update_inputs(a=4)
assert out._outdated == True


def test_errored(triple_sum):
wf = triple_sum(2, 3, 4)

assert wf._errored == False
wf.get()
assert wf._errored == False

wf.update_inputs(a={"c": "f"})
assert wf._errored == False
with pytest.raises(Exception):
wf.get()
assert wf._errored == True

wf.update_inputs(a={"c": "m"})
assert wf._errored == False
25 changes: 24 additions & 1 deletion src/sisl/nodes/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,22 @@ class Workflow(Node):

network = NetworkDescriptor()

def _set_outdated(self, value: bool):
self.nodes.output._outdated = value

def _get_outdated(self) -> bool:
return self.nodes.output._outdated

_outdated = property(fset=_set_outdated, fget=_get_outdated)

def _receive_output_link(self, node):
super()._receive_output_link(node)
self.nodes.output._receive_output_link(node)

def _receive_output_unlink(self, node):
super()._receive_output_unlink(node)
self.nodes.output._receive_output_unlink(node)

@classmethod
def from_node_tree(cls, output_node: Node, workflow_name: Union[str, None] = None):
"""Creates a workflow class from a node.
Expand Down Expand Up @@ -975,7 +991,12 @@ def get(self):
It will recompute it if necessary.
"""
return self.nodes.output.get()
self._errored = False
try:
return self.nodes.output.get()
except:
self._errored = True
raise

def update_inputs(self, **inputs):
"""Updates the inputs of the workflow."""
Expand All @@ -990,6 +1011,8 @@ def update_inputs(self, **inputs):

self._inputs.update(inputs)

self._receive_outdated()

return self

def _get_output(self):
Expand Down

0 comments on commit e10b77b

Please sign in to comment.