fixes #179. Adds substitution/evaluation for backend settings
o-smirnov committed Sep 19, 2023
1 parent c3fe211 commit 7a6a610
Showing 10 changed files with 300 additions and 87 deletions.
124 changes: 109 additions & 15 deletions scabha/evaluator.py
@@ -8,14 +8,15 @@
from pyparsing import common
from functools import reduce
import operator
import dataclasses

from .substitutions import SubstitutionError, SubstitutionContext
from .basetypes import Unresolved, UNSET
from .exceptions import *

import typing
from typing import Dict, List, Any

from omegaconf import DictConfig, ListConfig

_parser = None

@@ -487,26 +488,113 @@ def evaluate(self, value, sublocation: List[str] = []):
except Exception as exc:
raise FormulaError(f"{'.'.join(self.location)}: evaluation of '{value}' failed", exc, tb=True)
else:
return self._resolve(value, in_formula=False)
try:
return self._resolve(value, in_formula=False)
except Exception as exc:
raise SubstitutionError(f"{'.'.join(self.location)}: evaluation of '{value}' failed", exc)
finally:
self.location = self.location[:loclen]

def evaluate_object(self, obj: Any,
sublocation = [],
raise_substitution_errors: bool = True,
recursion_level: int = 1,
verbose: bool = False):
# string? evaluate directly and return
if type(obj) is str:
try:
return self.evaluate(obj, sublocation=sublocation)
except AttributeError as err:
if raise_substitution_errors:
raise
return Unresolved(errors=[err])
except SubstitutionError as err:
if raise_substitution_errors:
raise
return Unresolved(errors=[err])

# helper function
def update(value, sloc):
if type(value) is Unresolved:
return value, False
subloc = sublocation + [sloc]
if verbose:
print(f"{subloc}: {value} ...")
new_value = self.evaluate_object(value, raise_substitution_errors=raise_substitution_errors,
recursion_level=recursion_level, verbose=verbose,
sublocation=subloc)
if verbose:
print(f"{subloc}: {value} -> {new_value}")
# UNSET return means delete or revert to default
if new_value is UNSET:
raise SubstitutionError(f"{'.'.join(self.location + subloc)}: UNSET not allowed here")
# compare
if isinstance(value, (dict, DictConfig, list, ListConfig)) or dataclasses.is_dataclass(value):
updated = value is not new_value
else:
updated = value != new_value
return new_value, updated

obj_out = obj
# recurse into containers?
if recursion_level:
recursion_level -= 1
# recurse into dicts
if isinstance(obj, (dict, DictConfig)):
for key, value in obj.items():
new_value, value_updated = update(value, key)
new_key = self.evaluate(key, sublocation=sublocation)
if new_key != key:
value_updated = True
if value_updated:
if obj_out is obj:
obj_out = obj.copy()
if new_key != key:
del obj_out[key]
key = new_key
obj_out[key] = new_value
# recurse into lists
elif isinstance(obj, (list, ListConfig)):
for i, value in enumerate(obj):
new_value, updated = update(value, f"#{i}")
if updated:
if obj_out is obj:
obj_out = obj.copy()
obj_out[i] = new_value
# recurse into dataclasses
elif dataclasses.is_dataclass(obj):
newvals = {}
for fld in dataclasses.fields(obj):
value = getattr(obj, fld.name)
new_value, updated = update(value, fld.name)
if updated:
newvals[fld.name] = new_value
if newvals:
obj_out = dataclasses.replace(obj, **newvals)

return obj_out


def evaluate_dict(self, params: Dict[str, Any],
corresponding_ns: typing.Optional[Dict[str, Any]] = None,
defaults: Dict[str, Any] = {},
sublocation = [],
raise_substitution_errors: bool = True,
verbose: bool =False):
params = params.copy()
recursive: bool = False,
verbose: bool = False):
params_out = params
for name, value in list(params.items()):
if type(value) is not Unresolved:
retry = True
while retry:
retry = False
if verbose: # or type(value) is UNSET:
print(f"{name}: {value} ...")
if type(value) is Unresolved:
continue
#
retry = True
while retry:
retry = False
if verbose: # or type(value) is UNSET:
print(f"{name}: {value} ...")
if type(value) is str:
try:
new_value = self.evaluate(value, sublocation=[name])
new_value = self.evaluate(value, sublocation=sublocation + [name])
except AttributeError as err:
if raise_substitution_errors:
raise
@@ -519,21 +607,27 @@ def evaluate_dict(self, params: Dict[str, Any],
print(f"{name}: {value} -> {new_value}")
# UNSET return means delete or revert to default
if new_value is UNSET:
if params_out is params:
params_out = params.copy()
# if value is in defaults, try to evaluate that instead
if name in defaults and defaults[name] is not UNSET:
value = params[name] = defaults[name]
value = params_out[name] = defaults[name]
if corresponding_ns:
corresponding_ns[name] = str(defaults[name])
retry = True
else:
del params[name]
del params_out[name]
if corresponding_ns and name in corresponding_ns:
del corresponding_ns[name]
elif new_value is not value and new_value != value:
params[name] = new_value
if params_out is params:
params_out = params.copy()
params_out[name] = new_value
if corresponding_ns:
corresponding_ns[name] = new_value
return params
return params_out



if __name__ == "__main__":
pass
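The new evaluate_object method recurses into dicts, lists and dataclasses, and only copies a container when one of its members actually changes. A minimal self-contained sketch of that copy-on-write recursion, using a plain string-substitution callback in place of scabha's Evaluator (names here are illustrative, not scabha's API):

import dataclasses
from typing import Any, Callable

def substitute_recursively(obj: Any, evaluate: Callable[[str], Any], depth: int = 1) -> Any:
    # Strings are always evaluated, whatever the recursion depth.
    if isinstance(obj, str):
        return evaluate(obj)
    if depth <= 0:
        return obj
    # Dicts and lists: copy only if some member changed (copy-on-write).
    if isinstance(obj, dict):
        out = obj
        for key, value in obj.items():
            new_value = substitute_recursively(value, evaluate, depth - 1)
            if new_value is not value:
                if out is obj:
                    out = obj.copy()
                out[key] = new_value
        return out
    if isinstance(obj, list):
        out = obj
        for i, value in enumerate(obj):
            new_value = substitute_recursively(value, evaluate, depth - 1)
            if new_value is not value:
                if out is obj:
                    out = obj.copy()
                out[i] = new_value
        return out
    # Dataclass instances: rebuild via dataclasses.replace() only if some field changed.
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        changes = {}
        for fld in dataclasses.fields(obj):
            value = getattr(obj, fld.name)
            new_value = substitute_recursively(value, evaluate, depth - 1)
            if new_value is not value:
                changes[fld.name] = new_value
        return dataclasses.replace(obj, **changes) if changes else obj
    return obj

# usage: expand "{name}"-style references throughout a nested config
config = {"image": "{registry}/stimela", "volumes": ["{registry}-cache", 1]}
print(substitute_recursively(config, lambda s: s.format(registry="quay.io"), depth=2))
# -> {'image': 'quay.io/stimela', 'volumes': ['quay.io-cache', 1]}

The `out is obj` test is what keeps untouched containers shared rather than copied, the same trick evaluate_object uses with obj_out above.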
10 changes: 10 additions & 0 deletions scabha/validate.py
@@ -42,10 +42,20 @@ def dtype_from_str(dtype_str: str):
def is_file_type(dtype):
return dtype in (File, Directory, MS)


def is_filelist_type(dtype):
return dtype in (List[File], List[Directory], List[MS])


def evaluate_and_substitute_object(obj: Any,
subst: SubstitutionNS,
recursion_level: int = 1,
location: List[str] = []):
with substitutions_from(subst, raise_errors=True) as context:
evaltor = Evaluator(subst, context, location=location, allow_unresolved=False)
return evaltor.evaluate_object(obj, raise_substitution_errors=True, recursion_level=recursion_level)


def evaluate_and_substitute(inputs: Dict[str, Any],
subst: SubstitutionNS,
corresponding_ns: SubstitutionNS,
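The new evaluate_and_substitute_object helper calls evaluate_object with raise_substitution_errors=True; elsewhere that switch lets failures be captured in an Unresolved placeholder instead of aborting the whole evaluation. A minimal sketch of that raise-or-mark policy, using a stand-in Unresolved class rather than the real scabha.basetypes.Unresolved:

from dataclasses import dataclass, field
from typing import Any, Callable, List

@dataclass
class Unresolved:
    # stand-in for scabha.basetypes.Unresolved: remembers why evaluation failed
    errors: List[Exception] = field(default_factory=list)

def evaluate_or_mark(value: str,
                     evaluate: Callable[[str], Any],
                     raise_errors: bool = True) -> Any:
    # Evaluate a string; on failure either re-raise or return an Unresolved marker.
    try:
        return evaluate(value)
    except Exception as err:   # the real code catches AttributeError and SubstitutionError
        if raise_errors:
            raise
        return Unresolved(errors=[err])

# usage: a lenient pass leaves markers behind instead of aborting the whole dict
value = evaluate_or_mark("{missing.key}", lambda s: s.format(), raise_errors=False)
print(type(value).__name__, value.errors)   # Unresolved [KeyError('missing')]

This matches the `if type(value) is Unresolved: continue` guard in evaluate_dict, which skips values that already failed an earlier pass.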
16 changes: 12 additions & 4 deletions stimela/backends/__init__.py
@@ -118,12 +118,20 @@ def _call_backends(backend_opts: StimelaBackendOptions, log: logging.Logger, met
else:
log_exception(exc1, log=log)

initialized = None

def init_backends(backend_opts: StimelaBackendOptions, log: logging.Logger):
return _call_backends(backend_opts, log, "init", "initializing")

def close_backends(backend_opts: StimelaBackendOptions, log: logging.Logger):
return _call_backends(backend_opts, log, "close", "closing")
global initialized
if initialized is None:
initialized = backend_opts
return _call_backends(backend_opts, log, "init", "initializing")

def close_backends(log: logging.Logger):
global initialized
if initialized is not None:
result = _call_backends(initialized, log, "close", "closing")
initialized = None
return result

def cleanup_backends(backend_opts: StimelaBackendOptions, log: logging.Logger):
return _call_backends(backend_opts, log, "cleanup", "cleaning up")
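The reworked init_backends/close_backends pair records the options used at initialization in a module-level `initialized` variable, so close_backends no longer takes options at all: it closes with whatever init_backends was given and becomes a no-op if nothing was initialized. A simplified sketch of that pattern (hypothetical names, without the per-backend _call_backends dispatch):

import logging
from typing import Any, Optional

_initialized_opts: Optional[Any] = None   # options recorded by the first init call

def init_backends(opts: Any, log: logging.Logger) -> None:
    global _initialized_opts
    if _initialized_opts is None:
        _initialized_opts = opts          # remember what we initialized with
    log.info("initializing backends")
    # ... dispatch to per-backend init hooks here ...

def close_backends(log: logging.Logger) -> None:
    global _initialized_opts
    if _initialized_opts is not None:
        log.info("closing backends")
        # ... dispatch to per-backend close hooks using _initialized_opts ...
        _initialized_opts = None          # further close calls become no-ops

# usage
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("backends")
init_backends({"kube": {"namespace": "stimela"}}, log)
close_backends(log)   # closes using the recorded options
close_backends(log)   # no-op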
67 changes: 56 additions & 11 deletions stimela/backends/kube/__init__.py
@@ -9,9 +9,13 @@
import getpass
import logging
import time
import pwd
import grp
import os

import stimela
from scabha.basetypes import EmptyDictDefault, DictDefault, EmptyListDefault, ListDefault
from stimela.exceptions import BackendError

session_id = secrets.token_hex(8)
session_user = getpass.getuser()
@@ -50,7 +54,7 @@ def init(backend: 'stimela.backend.StimelaBackendOptions', log: logging.Logger):

def close(backend: 'stimela.backend.StimelaBackendOptions', log: logging.Logger):
from . import infrastructure
if not AVAILABLE:
if AVAILABLE:
infrastructure.close(backend, log)

def cleanup(backend: 'stimela.backend.StimelaBackendOptions', log: logging.Logger):
@@ -63,7 +67,7 @@ def run(cab: 'stimela.kitchen.cab.Cab', params: Dict[str, Any], fqname: str,
from . import run_kube
return run_kube.run(cab=cab, params=params, fqname=fqname, backend=backend, log=log, subst=subst)

_kube_client = _kube_config = None
_kube_client = _kube_config = _kube_context = None

def get_kube_api(context: Optional[str]=None):
global _kube_client
@@ -72,7 +76,10 @@ def get_kube_api(context: Optional[str]=None):

if _kube_config is None:
_kube_config = True
_kube_context = context
kubernetes.config.load_kube_config(context=context)
elif context != _kube_context:
raise BackendError(f"k8s context has changed (was {_kube_context}, now {context}), this is not permitted")

return core_v1_api.CoreV1Api(), CustomObjectsApi()

@@ -125,7 +132,6 @@ class StartupOptions(object):
report_pvcs: bool = True # report any transient PVCs
cleanup_pvcs: bool = True # cleanup any transient PVCs

context: Optional[str] = None # k8s context -- use default if not given
on_exit: ExitOptions = ExitOptions() # exit/cleanup behaviour options
on_startup: StartupOptions = StartupOptions() # startup behaviour options

@@ -139,14 +145,24 @@ class Volume(object):
provision_timeout: int = 1200 # How long to wait for provisioning before timing out
mount: Optional[str] = None # mount point

from_snapshot: Optional[str] = None # create from snapshot

# Status of PVC at start of session or at start of step:
# must_exist: reuse, error if it doesn't exist
# allow_reuse: reuse if exists, else create
# recreate: delete if exists and recreate
# cant_exist: report an error if it exists, else create
ExistPolicy = Enum("ExisttPolicy", "must_exist allow_reuse recreate cant_exist", module=__name__)
at_start: ExistPolicy = ExistPolicy.allow_reuse
at_step: ExistPolicy = ExistPolicy.allow_reuse

# lifecycle policy
# persist: leave the PVC in place for future re-use
# session: delete at end of stimela run
# step: delete at end of step (only applies to per-step PVCs)
Lifecycle = Enum("Lifecycle", "persist session step", module=__name__)
lifecycle: Lifecycle = Lifecycle.session

reuse: bool = True # if a PVC with that name already exists, reuse it, else error
append_id: bool = True # for session- or step-lifecycle PVCs, append ID to name

def __post_init__ (self):
Expand Down Expand Up @@ -184,8 +200,10 @@ class LocalMount(object):
# infrastructure settings are global and can't be changed per cab or per step
infrastructure: Infrastructure = Infrastructure()

namespace: Optional[str] = None
dask_cluster: Optional[DaskCluster] = None
context: Optional[str] = None # k8s context -- use default if not given -- can't change
namespace: Optional[str] = None # k8s namespace

dask_cluster: Optional[DaskCluster] = None # if set, a DaskJob will be created
service_account: str = "compute-runner"
kubectl_path: str = "kubectl"

@@ -201,23 +219,50 @@ class LocalMount(object):

always_pull_images: bool = False # change to True to repull

debug_mode: bool = False # in debug mode, payload is not run
@dataclass
class DebugOptions(object):
pause_on_start: bool = False # pause instead of running payload
pause_on_cleanup: bool = False # pause before attempting cleanup

debug: DebugOptions = DebugOptions()

job_pod: KubePodSpec = KubePodSpec()

# if >0, events will be collected and reported
verbose_events: int = 0
# format string for reporting kubernetes events, this can include rich markup
verbose_event_format: str = "\[k8s event type: {event.type}, reason: {event.reason}] {event.message}"
verbose_event_format: str = "=NOSUBST('\[k8s event type: {event.type}, reason: {event.reason}] {event.message}')"
verbose_event_colors: Dict[str, str] = DictDefault(
warning="blue", error="yellow", default="grey50")

# user and group IDs -- if None, use local user
uid: Optional[int] = None
gid: Optional[int] = None
@dataclass
class UserInfo(object):
# user and group names and IDs -- if None, use local user
name: Optional[str] = None
group: Optional[str] = None
uid: Optional[int] = None
gid: Optional[int] = None
gecos: Optional[str] = None
home: Optional[str] = None # home dir inside container, default is /home/{user}
home_ramdisk: bool = True # home dir mounted as RAM disk, else local disk
inject_nss: bool = True # inject user info for NSS_WRAPPER

user: UserInfo = UserInfo()

# user-defined set of pod types -- each is a pod spec structure keyed by pod_type
predefined_pod_specs: Dict[str, Dict[str, Any]] = EmptyDictDefault()


KubeBackendSchema = OmegaConf.structured(KubeBackendOptions)

_uid = os.getuid()
_gid = os.getgid()

session_user_info = KubeBackendOptions.UserInfo(
name=session_user,
group=grp.getgrgid(_gid).gr_name,
uid=_uid,
gid=_gid,
home=f"/home/{session_user}",
gecos=pwd.getpwuid(_uid).pw_gecos
)
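The new nested option groups (DebugOptions, UserInfo) get their defaults folded into KubeBackendSchema via OmegaConf.structured, so a user's config only needs to override the fields it cares about and everything else keeps the dataclass defaults. A small illustration of that mechanism, with toy dataclasses standing in for the real KubeBackendOptions:

from dataclasses import dataclass, field
from typing import Optional
from omegaconf import OmegaConf

@dataclass
class UserInfo:                          # toy analogue of KubeBackendOptions.UserInfo
    name: Optional[str] = None
    uid: Optional[int] = None
    inject_nss: bool = True

@dataclass
class ToyKubeOptions:                    # toy analogue of KubeBackendOptions
    namespace: Optional[str] = None
    service_account: str = "compute-runner"
    user: UserInfo = field(default_factory=UserInfo)

schema = OmegaConf.structured(ToyKubeOptions)

# a user config only overrides what it needs; everything else keeps the schema defaults
user_conf = OmegaConf.create({"namespace": "stimela", "user": {"uid": 1000}})
merged = OmegaConf.merge(schema, user_conf)

print(merged.service_account)   # compute-runner  (schema default)
print(merged.user.uid)          # 1000            (user override)
print(merged.user.inject_nss)   # True            (nested default kept)

Values that don't match the declared field types are rejected at merge time, which is the point of building the schema from dataclasses rather than a free-form dict.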
