From e05bbcd38cd88fcb07f1e91443f93322ca0d2bd6 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Mon, 16 Dec 2024 11:07:22 +0100 Subject: [PATCH] for_all_nodes, handle exception/exit Fix #164 --- sisyphus/graph.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/sisyphus/graph.py b/sisyphus/graph.py index 4a4bed4..814e986 100644 --- a/sisyphus/graph.py +++ b/sisyphus/graph.py @@ -545,6 +545,7 @@ def for_all_nodes(self, f, nodes=None, bottom_up=False, *, pool: Optional[Thread pool_lock = threading.Lock() finished_lock = threading.Lock() + stopped_event = threading.Event() if not pool: pool = self.pool @@ -555,6 +556,8 @@ def runner(job): """ sis_id = job._sis_id() with pool_lock: + if stopped_event.is_set(): + return if sis_id not in visited: visited[sis_id] = pool.apply_async( tools.default_handle_exception_interrupt_main_thread(runner_helper), (job,) @@ -564,6 +567,8 @@ def runner_helper(job): """ :param Job job: """ + if stopped_event.is_set(): + return # make sure all inputs are updated job._sis_runnable() nonlocal finished @@ -583,12 +588,17 @@ def runner_helper(job): with finished_lock: finished += 1 - for node in nodes: - runner(node) + try: + for node in nodes: + runner(node) - # Check if all jobs are finished - while len(visited) != finished: - time.sleep(0.1) + # Check if all jobs are finished + while len(visited) != finished: + time.sleep(0.1) + except BaseException: + with pool_lock: + stopped_event.set() + raise # Check again and create output set out = set()