diff --git a/sisyphus/job.py b/sisyphus/job.py index 5901482..0e2a078 100644 --- a/sisyphus/job.py +++ b/sisyphus/job.py @@ -196,7 +196,6 @@ def __new__(cls: Type[T], *args, **kwargs) -> T: # Init def _sis_init(self, args, kwargs, parsed_args): - for key, arg in parsed_args.items(): if isinstance(arg, Job): logging.warning( @@ -211,6 +210,7 @@ def _sis_init(self, args, kwargs, parsed_args): self._sis_outputs = {} self._sis_keep_value = None self._sis_hold_job = False + self._sis_worker_wrapper = gs.worker_wrapper self._sis_blocks = set() self._sis_kwargs = parsed_args @@ -316,6 +316,7 @@ def __getstate__(self): "current_block", "_sis_cleanable_cache", "_sis_cleaned_or_not_cleanable", + "_sis_worker_wrapper", ]: if key in d: del d[key] diff --git a/sisyphus/task.py b/sisyphus/task.py index e874689..88d91e3 100644 --- a/sisyphus/task.py +++ b/sisyphus/task.py @@ -481,5 +481,8 @@ def get_worker_call(self, task_id=None): call += [gs.CMD_WORKER, os.path.relpath(self.path()), self.name()] if task_id is not None: call.append(str(task_id)) - call = gs.worker_wrapper(getattr(self, "_job", None), self.name(), call) + if hasattr(self, "_job"): + call = self._job._sis_worker_wrapper(self._job, self.name(), call) + else: + call = gs.worker_wrapper(None, self.name(), call) return call