Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Project level defaults #506

Merged
merged 5 commits into from
Jul 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.D/506.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Added new sections `defaults`, `images`, `volumes` to the `project.yml` file. The work the same as the do
in `live`/`batch` except they are global -- everything defined in `project.yml` applies to all workflows.
6 changes: 5 additions & 1 deletion neuro_flow/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ class Project(Base):
owner: SimpleOptStrExpr # user name can contain "-"
role: SimpleOptStrExpr

images: Optional[Mapping[str, "Image"]] = field(metadata={"allow_none": True})
volumes: Optional[Mapping[str, "Volume"]] = field(metadata={"allow_none": True})
defaults: Optional["BatchFlowDefaults"] = field(metadata={"allow_none": True})


# There are 'batch' for pipelined mode and 'live' for interactive one
# (while 'batches' are technically just non-interactive jobs.
Expand Down Expand Up @@ -265,7 +269,7 @@ class TaskModuleCall(BaseModuleCall, TaskBase):


@dataclass(frozen=True)
class FlowDefaults(Base):
class FlowDefaults(WithSpecifiedFields, Base):
tags: Optional[BaseExpr[SequenceT]] = field(metadata={"allow_none": True})

env: Optional[BaseExpr[MappingT]] = field(metadata={"allow_none": True})
Expand Down
25 changes: 18 additions & 7 deletions neuro_flow/config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import abc
import aiohttp
import logging
import secrets
import sys
import tarfile
Expand Down Expand Up @@ -45,6 +46,9 @@
from async_generator import asynccontextmanager


log = logging.getLogger(__name__)


@dataclasses.dataclass(frozen=True)
class ActionSpec:
scheme: str
Expand Down Expand Up @@ -203,13 +207,20 @@ def workspace(self) -> LocalPath:

@asynccontextmanager
async def project_stream(self) -> AsyncIterator[Optional[TextIO]]:
for ext in (".yml", ".yaml"):
path = self._workspace / "project"
path = path.with_suffix(ext)
if path.exists():
with path.open() as f:
yield f
return
for dir in (self._config_dir, self._workspace):
for ext in (".yml", ".yaml"):
path = dir / "project"
path = path.with_suffix(ext)
if path.exists():
with path.open() as f:
if dir == self._workspace:
log.warning(
f"Using project yaml file from workspace instead"
f" of config directory {self._config_dir}. Please move "
"it there, reading from workspace will be removed soon."
)
yield f
return
yield None

def flow_path(self, name: str) -> LocalPath:
Expand Down
142 changes: 102 additions & 40 deletions neuro_flow/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,12 @@ async def setup_batch_flow_ctx(
async def setup_defaults_env_tags_ctx(
ctx: WithFlowContext,
ast_defaults: Optional[ast.FlowDefaults],
ast_global_defaults: Optional[ast.FlowDefaults],
) -> Tuple[DefaultsConf, EnvCtx, TagsCtx]:
if ast_defaults is not None and ast_global_defaults is not None:
ast_defaults = await merge_asts(ast_defaults, ast_global_defaults)
elif ast_global_defaults:
ast_defaults = ast_global_defaults
env: EnvCtx
tags: TagsCtx
volumes: List[str]
Expand Down Expand Up @@ -1028,7 +1033,12 @@ async def setup_params_ctx(
async def setup_strategy_ctx(
ctx: RootABC,
ast_defaults: Optional[ast.BatchFlowDefaults],
ast_global_defaults: Optional[ast.BatchFlowDefaults],
) -> StrategyCtx:
if ast_defaults is not None and ast_global_defaults is not None:
ast_defaults = await merge_asts(ast_defaults, ast_global_defaults)
elif ast_global_defaults:
ast_defaults = ast_global_defaults
if ast_defaults is None:
return StrategyCtx()
fail_fast = await ast_defaults.fail_fast.eval(ctx)
Expand Down Expand Up @@ -1127,12 +1137,45 @@ def check_module_call_is_local(action_name: str, call_ast: ast.BaseModuleCall) -
)


class MixinProtocol(Protocol):
class SupportsAstMerge(Protocol):
@property
def _specified_fields(self) -> AbstractSet[str]:
...


_MergeTarget = TypeVar("_MergeTarget", bound=SupportsAstMerge)


async def merge_asts(child: _MergeTarget, parent: SupportsAstMerge) -> _MergeTarget:
child_fields = {f.name for f in dataclasses.fields(child)}
for field in parent._specified_fields:
if field == "inherits" or field not in child_fields:
continue
field_present = field in child._specified_fields
child_value = getattr(child, field)
parent_value = getattr(parent, field)
merge_supported = isinstance(parent_value, BaseSequenceExpr) or isinstance(
parent_value, BaseMappingExpr
)
if not field_present or (child_value is None and merge_supported):
child = replace(
child,
**{field: parent_value},
_specified_fields=child._specified_fields | {field},
)
elif isinstance(parent_value, BaseSequenceExpr):
assert isinstance(child_value, BaseSequenceExpr)
child = replace(
child, **{field: ConcatSequenceExpr(child_value, parent_value)}
)
elif isinstance(parent_value, BaseMappingExpr):
assert isinstance(child_value, BaseMappingExpr)
child = replace(
child, **{field: MergeMappingsExpr(child_value, parent_value)}
)
return child


class MixinApplyTarget(Protocol):
@property
def inherits(self) -> Optional[Sequence[StrExpr]]:
Expand All @@ -1147,7 +1190,7 @@ def _specified_fields(self) -> AbstractSet[str]:


async def apply_mixins(
base: _MixinApplyTarget, mixins: Mapping[str, MixinProtocol]
base: _MixinApplyTarget, mixins: Mapping[str, SupportsAstMerge]
) -> _MixinApplyTarget:
if base.inherits is None:
return base
Expand All @@ -1161,31 +1204,7 @@ async def apply_mixins(
start=mixin_expr.start,
end=mixin_expr.end,
)
for field in mixin._specified_fields:
if field == "inherits":
continue # Do not inherit 'inherits' field
field_present = field in base._specified_fields
base_value = getattr(base, field)
mixin_value = getattr(mixin, field)
merge_supported = isinstance(mixin_value, BaseSequenceExpr) or isinstance(
mixin_value, BaseMappingExpr
)
if not field_present or (base_value is None and merge_supported):
base = replace(
base,
**{field: mixin_value},
_specified_fields=base._specified_fields | {field},
)
elif isinstance(mixin_value, BaseSequenceExpr):
assert isinstance(base_value, BaseSequenceExpr)
base = replace(
base, **{field: ConcatSequenceExpr(base_value, mixin_value)}
)
elif isinstance(mixin_value, BaseMappingExpr):
assert isinstance(base_value, BaseMappingExpr)
base = replace(
base, **{field: MergeMappingsExpr(base_value, mixin_value)}
)
base = await merge_asts(base, mixin)
return base


Expand Down Expand Up @@ -1438,6 +1457,7 @@ async def create(
cls, config_loader: ConfigLoader, config_name: str = "live"
) -> "RunningLiveFlow":
ast_flow = await config_loader.fetch_flow(config_name)
ast_project = await config_loader.fetch_project()

assert isinstance(ast_flow, ast.LiveFlow)

Expand All @@ -1451,14 +1471,24 @@ async def create(
)

defaults, env, tags = await setup_defaults_env_tags_ctx(
step_1_ctx, ast_flow.defaults
step_1_ctx, ast_flow.defaults, ast_project.defaults
)

volumes = {
**(await setup_volumes_ctx(step_1_ctx, ast_project.volumes)),
**(await setup_volumes_ctx(step_1_ctx, ast_flow.volumes)),
}

images = {
**(await setup_images_ctx(step_1_ctx, step_1_ctx, ast_project.images)),
**(await setup_images_ctx(step_1_ctx, step_1_ctx, ast_flow.images)),
}

live_ctx = step_1_ctx.to_live_ctx(
env=env,
tags=tags,
volumes=await setup_volumes_ctx(step_1_ctx, ast_flow.volumes),
images=await setup_images_ctx(step_1_ctx, step_1_ctx, ast_flow.images),
volumes=volumes,
images=images,
)

return cls(ast_flow, live_ctx, config_loader, defaults)
Expand Down Expand Up @@ -1963,6 +1993,7 @@ async def create(
local_info: Optional[LocallyPreparedInfo] = None,
) -> "RunningBatchFlow":
ast_flow = await config_loader.fetch_flow(batch)
ast_project = await config_loader.fetch_project()

assert isinstance(ast_flow, ast.BatchFlow)

Expand All @@ -1979,35 +2010,66 @@ async def create(
_client=config_loader.client,
)
if local_info is None:
early_images = await setup_images_early(
step_1_ctx, step_1_ctx, ast_flow.images
)
early_images: Mapping[str, EarlyImageCtx] = {
**(
await setup_images_early(step_1_ctx, step_1_ctx, ast_project.images)
),
**(await setup_images_early(step_1_ctx, step_1_ctx, ast_flow.images)),
}
else:
early_images = local_info.early_images

defaults, env, tags = await setup_defaults_env_tags_ctx(
step_1_ctx, ast_flow.defaults
step_1_ctx, ast_flow.defaults, ast_project.defaults
)

volumes = {
**(await setup_volumes_ctx(step_1_ctx, ast_project.volumes)),
**(await setup_volumes_ctx(step_1_ctx, ast_flow.volumes)),
}

images = {
**(
await setup_images_ctx(
step_1_ctx, step_1_ctx, ast_project.images, early_images
)
),
**(
await setup_images_ctx(
step_1_ctx, step_1_ctx, ast_flow.images, early_images
)
),
}

step_2_ctx = step_1_ctx.to_step_2(
env=env,
tags=tags,
volumes=await setup_volumes_ctx(step_1_ctx, ast_flow.volumes),
images=await setup_images_ctx(
step_1_ctx, step_1_ctx, ast_flow.images, early_images
),
volumes=volumes,
images=images,
)

if ast_project.defaults:
base_cache = await setup_cache(
step_2_ctx,
CacheConf(),
ast_project.defaults.cache,
ast.CacheStrategy.INHERIT,
)
else:
base_cache = CacheConf()

if ast_flow.defaults:
ast_cache = ast_flow.defaults.cache
else:
ast_cache = None
cache_conf = await setup_cache(
step_2_ctx, CacheConf(), ast_cache, ast.CacheStrategy.INHERIT
step_2_ctx, base_cache, ast_cache, ast.CacheStrategy.INHERIT
)

batch_ctx = step_2_ctx.to_batch_ctx(
strategy=await setup_strategy_ctx(step_2_ctx, ast_flow.defaults),
strategy=await setup_strategy_ctx(
step_2_ctx, ast_flow.defaults, ast_project.defaults
),
)

mixins = await setup_mixins(ast_flow.mixins)
Expand Down
Loading