Skip to content

Commit

Permalink
Merge pull request #666 from pyiron/node_status
Browse files Browse the repository at this point in the history
Node status
  • Loading branch information
liamhuber authored May 8, 2023
2 parents db8c670 + 5e833b9 commit 0311eb7
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 2 deletions.
11 changes: 11 additions & 0 deletions pyiron_contrib/workflow/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,13 @@ def ready(self):
return True

def update(self, value):
self._before_update()
self.value = value
self._after_update()

def _before_update(self):
pass

def _after_update(self):
pass

Expand Down Expand Up @@ -246,6 +250,13 @@ def wait_for_update(self):
def ready(self):
return not self.waiting_for_update and super().ready

def _before_update(self):
if self.node.running:
raise RuntimeError(
f"Parent node {self.node.label} of {self.label} is running, so value "
f"cannot be updated."
)

def _after_update(self):
self.waiting_for_update = False
self.node.update()
Expand Down
19 changes: 17 additions & 2 deletions pyiron_contrib/workflow/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,8 @@ def __init__(
workflow: Optional[Workflow] = None,
**kwargs
):
self.running = False
self.failed = False
self.node_function = node_function
self.label = label if label is not None else node_function.__name__

Expand Down Expand Up @@ -457,7 +459,18 @@ def update(self) -> None:
self.run()

def run(self) -> None:
function_output = self.node_function(**self.inputs.to_value_dict())
if self.running:
raise RuntimeError(f"{self.label} is already running")

self.running = True
self.failed = False

try:
function_output = self.node_function(**self.inputs.to_value_dict())
except Exception as e:
self.running = False
self.failed = True
raise e

if len(self.outputs) == 1:
function_output = (function_output,)
Expand All @@ -470,6 +483,8 @@ def run(self) -> None:
for channel_name in self.channels_requiring_update_after_run:
self.inputs[channel_name].wait_for_update()

self.running = False

def __call__(self) -> None:
self.run()

Expand All @@ -480,7 +495,7 @@ def disconnect(self):

@property
def ready(self) -> bool:
return self.inputs.ready
return not (self.running or self.failed) and self.inputs.ready

@property
def connected(self) -> bool:
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/workflow/test_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
class DummyNode:
def __init__(self):
self.foo = [0]
self.running = False
self.label = "node_label"

def update(self):
self.foo.append(self.foo[-1] + 1)
Expand Down Expand Up @@ -122,6 +124,10 @@ def test_update(self):
msg="Value should have been passed downstream"
)

self.ni1.node.running = True
with self.assertRaises(RuntimeError):
self.no.update(42)


class TestSignalChannels(TestCase):
def setUp(self) -> None:
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/workflow/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@


class DummyNode:
def __init__(self):
self.running = False
self.label = "node_label"

def update(self):
pass

Expand Down
40 changes: 40 additions & 0 deletions tests/unit/workflow/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,45 @@ def times_two(y):
msg="Running the upstream node should trigger a run here"
)

def test_statuses(self):
n = Node(plus_one, "p1")
self.assertTrue(n.ready)
self.assertFalse(n.running)
self.assertFalse(n.failed)

# Can't really test "running" until we have a background executor, so fake a bit
n.running = True
with self.assertRaises(RuntimeError):
# Running nodes can't be run
n.run()
n.running = False

n.inputs.x = "Can't be added together with an int"
with self.assertRaises(TypeError):
# The function error should get passed up
n.run()
self.assertFalse(n.ready)
# self.assertFalse(n.running)
self.assertTrue(n.failed)

n.inputs.x = 1
n.update()
self.assertFalse(
n.ready,
msg="Update _checks_ for ready, so should still have failed status"
)
# self.assertFalse(n.running)
self.assertTrue(n.failed)

n.run()
self.assertTrue(
n.ready,
msg="A manual run() call bypasses checks, so readiness should reset"
)
self.assertTrue(n.ready)
# self.assertFalse(n.running)
self.assertFalse(n.failed, msg="Re-running should reset failed status")


@skipUnless(version_info[0] == 3 and version_info[1] >= 10, "Only supported for 3.10+")
class TestFastNode(TestCase):
Expand All @@ -119,6 +158,7 @@ def test_instantiation(self):
with self.assertRaises(ValueError):
missing_defaults_should_fail = FastNode(no_default, "z")


@skipUnless(version_info[0] == 3 and version_info[1] >= 10, "Only supported for 3.10+")
class TestSingleValueNode(TestCase):
def test_instantiation(self):
Expand Down

0 comments on commit 0311eb7

Please sign in to comment.