Skip to content

Commit

Permalink
reuse EnvironmentModifier
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Feb 11, 2024
1 parent f1e83a4 commit 29c56f6
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 25 deletions.
11 changes: 4 additions & 7 deletions sisyphus/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,9 @@ def _sis_init(self, args, kwargs, parsed_args):
self._sis_is_finished = False
self._sis_setup_since_restart = False

self._sis_environment = None
if gs.CLEANUP_ENVIRONMENT:
self._sis_environment = tools.EnvironmentModifier()
self._sis_environment.keep(gs.DEFAULT_ENVIRONMENT_KEEP)
self._sis_environment.set(gs.DEFAULT_ENVIRONMENT_SET)
self._sis_environ_updates = {}
self._sis_environment = tools.EnvironmentModifier(cleanup_env=gs.CLEANUP_ENVIRONMENT)
self._sis_environment.keep(gs.DEFAULT_ENVIRONMENT_KEEP)
self._sis_environment.set(gs.DEFAULT_ENVIRONMENT_SET)

if gs.AUTO_SET_JOB_INIT_ATTRIBUTES:
self.set_attrs(parsed_args)
Expand Down Expand Up @@ -1139,7 +1136,7 @@ def update_rqmt(self, task_name, rqmt):

def set_env(self, key: str, value: str):
"""this environment var will be set at job startup"""
self._sis_environ_updates[key] = value
self._sis_environment.set_var(key, value)

def tasks(self) -> Iterator[Task]:
"""
Expand Down
33 changes: 22 additions & 11 deletions sisyphus/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ def try_get(v):


class execute_in_dir(object):

"""Object to be used by the with statement.
All code after the with will be executed in the given directory,
working directory will be changed back after with statement.
Expand All @@ -156,7 +155,6 @@ def __exit__(self, type, value, traceback):


class cache_result(object):

"""decorated to cache the result of a function for x_seconds"""

def __init__(self, cache_time=30, force_update=None, clear_cache=None):
Expand Down Expand Up @@ -481,33 +479,42 @@ class EnvironmentModifier:
A class to cleanup the environment before a job starts
"""

def __init__(self):
def __init__(self, *, cleanup_env: bool = True):
self.cleanup_env = cleanup_env
self.keep_vars = set()
self.set_vars = {}

def keep(self, var):
if type(var) == str:
if isinstance(var, str):
self.keep_vars.add(var)
else:
self.keep_vars.update(var)

def set(self, var, value=None):
if type(var) == dict:
if isinstance(var, dict):
self.set_vars.update(var)
else:
self.set_vars[var] = value

def set_var(self, key: str, value: str, *, allow_env_substitute: bool = False):
if not allow_env_substitute:
# Need to escape $ for string.Template.substitute below.
value = value.replace("$", "$$")
self.set_vars[key] = value

def modify_environment(self):
import os
import string

orig_env = dict(os.environ)
keys = list(os.environ.keys())
for k in keys:
if k not in self.keep_vars:
del os.environ[k]
if self.cleanup_env:
keys = list(os.environ.keys())
for k in keys:
if k not in self.keep_vars:
del os.environ[k]

for k, v in self.set_vars.items():
if type(v) == str:
if isinstance(v, str):
os.environ[k] = string.Template(v).substitute(orig_env)
else:
os.environ[k] = str(v)
Expand All @@ -516,7 +523,11 @@ def modify_environment(self):
logging.debug("environment var %s=%s" % (k, v))

def __repr__(self):
return repr(self.keep_vars) + " " + repr(self.set_vars)
return (
f"cleanup_env={self.cleanup_env} "
+ (f"keep={self.keep_vars!r} " if self.cleanup_env else "")
+ f"set={self.set_vars!r}"
)


class FinishedResultsCache:
Expand Down
8 changes: 1 addition & 7 deletions sisyphus/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,15 +245,9 @@ def worker_helper(args):
gs.active_engine.init_worker(task)

# cleanup environment
if hasattr(task._job, "_sis_environment") and task._job._sis_environment:
if getattr(task._job, "_sis_environment", None):
task._job._sis_environment.modify_environment()

# Maybe update some env vars.
# Use getattr for compatibility with older serialized jobs.
if getattr(task._job, "_sis_environ_updates", None):
for k, v in task._job._sis_environ_updates.items():
os.environ[k] = v

try:
# run task
task.run(task_id, resume_job, logging_thread=logging_thread)
Expand Down

0 comments on commit 29c56f6

Please sign in to comment.