Skip to content

Commit

Permalink
Project level defaults (#506)
Browse files Browse the repository at this point in the history
* Add parsing of project level defaults

* Support global defaults in live mode

* Support global defaults in batch mode

* Add changelog

* Allow to place project.yml under .neuro folder
  • Loading branch information
romasku authored Jul 8, 2021
1 parent b169d3d commit eedb9e7
Show file tree
Hide file tree
Showing 13 changed files with 622 additions and 121 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.D/506.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Added new sections `defaults`, `images`, `volumes` to the `project.yml` file. They work the same as they do
in `live`/`batch` except they are global -- everything defined in `project.yml` applies to all workflows.
6 changes: 5 additions & 1 deletion neuro_flow/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ class Project(Base):
owner: SimpleOptStrExpr # user name can contain "-"
role: SimpleOptStrExpr

images: Optional[Mapping[str, "Image"]] = field(metadata={"allow_none": True})
volumes: Optional[Mapping[str, "Volume"]] = field(metadata={"allow_none": True})
defaults: Optional["BatchFlowDefaults"] = field(metadata={"allow_none": True})


# There are 'batch' for pipelined mode and 'live' for interactive one
# (while 'batches' are technically just non-interactive jobs).
Expand Down Expand Up @@ -265,7 +269,7 @@ class TaskModuleCall(BaseModuleCall, TaskBase):


@dataclass(frozen=True)
class FlowDefaults(Base):
class FlowDefaults(WithSpecifiedFields, Base):
tags: Optional[BaseExpr[SequenceT]] = field(metadata={"allow_none": True})

env: Optional[BaseExpr[MappingT]] = field(metadata={"allow_none": True})
Expand Down
25 changes: 18 additions & 7 deletions neuro_flow/config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import abc
import aiohttp
import logging
import secrets
import sys
import tarfile
Expand Down Expand Up @@ -45,6 +46,9 @@
from async_generator import asynccontextmanager


log = logging.getLogger(__name__)


@dataclasses.dataclass(frozen=True)
class ActionSpec:
scheme: str
Expand Down Expand Up @@ -203,13 +207,20 @@ def workspace(self) -> LocalPath:

@asynccontextmanager
async def project_stream(self) -> AsyncIterator[Optional[TextIO]]:
for ext in (".yml", ".yaml"):
path = self._workspace / "project"
path = path.with_suffix(ext)
if path.exists():
with path.open() as f:
yield f
return
for dir in (self._config_dir, self._workspace):
for ext in (".yml", ".yaml"):
path = dir / "project"
path = path.with_suffix(ext)
if path.exists():
with path.open() as f:
if dir == self._workspace:
log.warning(
f"Using project yaml file from workspace instead"
f" of config directory {self._config_dir}. Please move "
"it there, reading from workspace will be removed soon."
)
yield f
return
yield None

def flow_path(self, name: str) -> LocalPath:
Expand Down
142 changes: 102 additions & 40 deletions neuro_flow/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,12 @@ async def setup_batch_flow_ctx(
async def setup_defaults_env_tags_ctx(
ctx: WithFlowContext,
ast_defaults: Optional[ast.FlowDefaults],
ast_global_defaults: Optional[ast.FlowDefaults],
) -> Tuple[DefaultsConf, EnvCtx, TagsCtx]:
if ast_defaults is not None and ast_global_defaults is not None:
ast_defaults = await merge_asts(ast_defaults, ast_global_defaults)
elif ast_global_defaults:
ast_defaults = ast_global_defaults
env: EnvCtx
tags: TagsCtx
volumes: List[str]
Expand Down Expand Up @@ -1028,7 +1033,12 @@ async def setup_params_ctx(
async def setup_strategy_ctx(
ctx: RootABC,
ast_defaults: Optional[ast.BatchFlowDefaults],
ast_global_defaults: Optional[ast.BatchFlowDefaults],
) -> StrategyCtx:
if ast_defaults is not None and ast_global_defaults is not None:
ast_defaults = await merge_asts(ast_defaults, ast_global_defaults)
elif ast_global_defaults:
ast_defaults = ast_global_defaults
if ast_defaults is None:
return StrategyCtx()
fail_fast = await ast_defaults.fail_fast.eval(ctx)
Expand Down Expand Up @@ -1127,12 +1137,45 @@ def check_module_call_is_local(action_name: str, call_ast: ast.BaseModuleCall) -
)


class MixinProtocol(Protocol):
class SupportsAstMerge(Protocol):
    """Structural type for AST nodes that can be combined by ``merge_asts``.

    Any node that records which of its fields were explicitly specified
    (as opposed to defaulted) satisfies this protocol.
    """

    @property
    def _specified_fields(self) -> AbstractSet[str]:
        ...


# The merge result keeps the concrete type of the child node being merged into.
_MergeTarget = TypeVar("_MergeTarget", bound=SupportsAstMerge)


async def merge_asts(child: _MergeTarget, parent: SupportsAstMerge) -> _MergeTarget:
    """Overlay the *parent*'s explicitly specified fields onto *child*.

    A field the child never specified (or set to ``None`` where a
    sequence/mapping merge is possible) is copied verbatim from the parent.
    When both sides specify a value, sequence expressions are concatenated
    and mapping expressions are merged; other field types keep the child's
    value.  The ``inherits`` field is never propagated.

    Returns the updated child node (the input objects are not mutated).
    """
    known = {f.name for f in dataclasses.fields(child)}
    for name in parent._specified_fields:
        # Skip 'inherits' deliberately, and any field the child's
        # dataclass does not declare at all.
        if name == "inherits" or name not in known:
            continue
        ours = getattr(child, name)
        theirs = getattr(parent, name)
        mergeable = isinstance(theirs, (BaseSequenceExpr, BaseMappingExpr))
        if name not in child._specified_fields or (ours is None and mergeable):
            # Child did not set this field -- adopt the parent's value and
            # record the field as specified on the result.
            child = replace(
                child,
                **{name: theirs},
                _specified_fields=child._specified_fields | {name},
            )
        elif isinstance(theirs, BaseSequenceExpr):
            assert isinstance(ours, BaseSequenceExpr)
            child = replace(child, **{name: ConcatSequenceExpr(ours, theirs)})
        elif isinstance(theirs, BaseMappingExpr):
            assert isinstance(ours, BaseMappingExpr)
            child = replace(child, **{name: MergeMappingsExpr(ours, theirs)})
    return child


class MixinApplyTarget(Protocol):
@property
def inherits(self) -> Optional[Sequence[StrExpr]]:
Expand All @@ -1147,7 +1190,7 @@ def _specified_fields(self) -> AbstractSet[str]:


async def apply_mixins(
base: _MixinApplyTarget, mixins: Mapping[str, MixinProtocol]
base: _MixinApplyTarget, mixins: Mapping[str, SupportsAstMerge]
) -> _MixinApplyTarget:
if base.inherits is None:
return base
Expand All @@ -1161,31 +1204,7 @@ async def apply_mixins(
start=mixin_expr.start,
end=mixin_expr.end,
)
for field in mixin._specified_fields:
if field == "inherits":
continue # Do not inherit 'inherits' field
field_present = field in base._specified_fields
base_value = getattr(base, field)
mixin_value = getattr(mixin, field)
merge_supported = isinstance(mixin_value, BaseSequenceExpr) or isinstance(
mixin_value, BaseMappingExpr
)
if not field_present or (base_value is None and merge_supported):
base = replace(
base,
**{field: mixin_value},
_specified_fields=base._specified_fields | {field},
)
elif isinstance(mixin_value, BaseSequenceExpr):
assert isinstance(base_value, BaseSequenceExpr)
base = replace(
base, **{field: ConcatSequenceExpr(base_value, mixin_value)}
)
elif isinstance(mixin_value, BaseMappingExpr):
assert isinstance(base_value, BaseMappingExpr)
base = replace(
base, **{field: MergeMappingsExpr(base_value, mixin_value)}
)
base = await merge_asts(base, mixin)
return base


Expand Down Expand Up @@ -1438,6 +1457,7 @@ async def create(
cls, config_loader: ConfigLoader, config_name: str = "live"
) -> "RunningLiveFlow":
ast_flow = await config_loader.fetch_flow(config_name)
ast_project = await config_loader.fetch_project()

assert isinstance(ast_flow, ast.LiveFlow)

Expand All @@ -1451,14 +1471,24 @@ async def create(
)

defaults, env, tags = await setup_defaults_env_tags_ctx(
step_1_ctx, ast_flow.defaults
step_1_ctx, ast_flow.defaults, ast_project.defaults
)

volumes = {
**(await setup_volumes_ctx(step_1_ctx, ast_project.volumes)),
**(await setup_volumes_ctx(step_1_ctx, ast_flow.volumes)),
}

images = {
**(await setup_images_ctx(step_1_ctx, step_1_ctx, ast_project.images)),
**(await setup_images_ctx(step_1_ctx, step_1_ctx, ast_flow.images)),
}

live_ctx = step_1_ctx.to_live_ctx(
env=env,
tags=tags,
volumes=await setup_volumes_ctx(step_1_ctx, ast_flow.volumes),
images=await setup_images_ctx(step_1_ctx, step_1_ctx, ast_flow.images),
volumes=volumes,
images=images,
)

return cls(ast_flow, live_ctx, config_loader, defaults)
Expand Down Expand Up @@ -1963,6 +1993,7 @@ async def create(
local_info: Optional[LocallyPreparedInfo] = None,
) -> "RunningBatchFlow":
ast_flow = await config_loader.fetch_flow(batch)
ast_project = await config_loader.fetch_project()

assert isinstance(ast_flow, ast.BatchFlow)

Expand All @@ -1979,35 +2010,66 @@ async def create(
_client=config_loader.client,
)
if local_info is None:
early_images = await setup_images_early(
step_1_ctx, step_1_ctx, ast_flow.images
)
early_images: Mapping[str, EarlyImageCtx] = {
**(
await setup_images_early(step_1_ctx, step_1_ctx, ast_project.images)
),
**(await setup_images_early(step_1_ctx, step_1_ctx, ast_flow.images)),
}
else:
early_images = local_info.early_images

defaults, env, tags = await setup_defaults_env_tags_ctx(
step_1_ctx, ast_flow.defaults
step_1_ctx, ast_flow.defaults, ast_project.defaults
)

volumes = {
**(await setup_volumes_ctx(step_1_ctx, ast_project.volumes)),
**(await setup_volumes_ctx(step_1_ctx, ast_flow.volumes)),
}

images = {
**(
await setup_images_ctx(
step_1_ctx, step_1_ctx, ast_project.images, early_images
)
),
**(
await setup_images_ctx(
step_1_ctx, step_1_ctx, ast_flow.images, early_images
)
),
}

step_2_ctx = step_1_ctx.to_step_2(
env=env,
tags=tags,
volumes=await setup_volumes_ctx(step_1_ctx, ast_flow.volumes),
images=await setup_images_ctx(
step_1_ctx, step_1_ctx, ast_flow.images, early_images
),
volumes=volumes,
images=images,
)

if ast_project.defaults:
base_cache = await setup_cache(
step_2_ctx,
CacheConf(),
ast_project.defaults.cache,
ast.CacheStrategy.INHERIT,
)
else:
base_cache = CacheConf()

if ast_flow.defaults:
ast_cache = ast_flow.defaults.cache
else:
ast_cache = None
cache_conf = await setup_cache(
step_2_ctx, CacheConf(), ast_cache, ast.CacheStrategy.INHERIT
step_2_ctx, base_cache, ast_cache, ast.CacheStrategy.INHERIT
)

batch_ctx = step_2_ctx.to_batch_ctx(
strategy=await setup_strategy_ctx(step_2_ctx, ast_flow.defaults),
strategy=await setup_strategy_ctx(
step_2_ctx, ast_flow.defaults, ast_project.defaults
),
)

mixins = await setup_mixins(ast_flow.mixins)
Expand Down
Loading

0 comments on commit eedb9e7

Please sign in to comment.