Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mixins in project #562

Merged
merged 6 commits into from
Sep 16, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.D/560.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Renamed `inherits` yaml property to `mixins`.
31 changes: 27 additions & 4 deletions neuro_flow/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ class Project(Base):
images: Optional[Mapping[str, "Image"]] = field(metadata={"allow_none": True})
volumes: Optional[Mapping[str, "Volume"]] = field(metadata={"allow_none": True})
defaults: Optional["BatchFlowDefaults"] = field(metadata={"allow_none": True})
mixins: Optional[Mapping[str, "ExecUnitMixin"]] = field(
metadata={"allow_none": True}
)


# There are 'batch' for pipelined mode and 'live' for interactive one
Expand Down Expand Up @@ -95,6 +98,26 @@ class Image(Base):
force_rebuild: OptBoolExpr


@dataclass(frozen=True)
class ExecUnitMixin(WithSpecifiedFields, Base):
title: OptStrExpr # Autocalculated if not passed explicitly
name: OptStrExpr
image: OptStrExpr
preset: OptStrExpr
schedule_timeout: OptTimeDeltaExpr
entrypoint: OptStrExpr
cmd: OptStrExpr
workdir: OptRemotePathExpr
env: Optional[BaseExpr[MappingT]] = field(metadata={"allow_none": True})
volumes: Optional[BaseExpr[SequenceT]] = field(metadata={"allow_none": True})
tags: Optional[BaseExpr[SequenceT]] = field(metadata={"allow_none": True})
life_span: OptTimeDeltaExpr
http_port: OptIntExpr
http_auth: OptBoolExpr
pass_config: OptBoolExpr
mixins: Optional[Sequence[StrExpr]] = field(metadata={"allow_none": True})


@dataclass(frozen=True)
class ExecUnit(Base):
title: OptStrExpr # Autocalculated if not passed explicitly
Expand Down Expand Up @@ -170,7 +193,7 @@ class JobMixin(WithSpecifiedFields, Base):
port_forward: Optional[BaseExpr[SequenceT]] = field(metadata={"allow_none": True})
multi: SimpleOptBoolExpr
params: Optional[Mapping[str, Param]] = field(metadata={"allow_none": True})
inherits: Optional[Sequence[StrExpr]] = field(metadata={"allow_none": True})
mixins: Optional[Sequence[StrExpr]] = field(metadata={"allow_none": True})


@dataclass(frozen=True)
Expand All @@ -181,7 +204,7 @@ class Job(ExecUnit, WithSpecifiedFields, JobBase):
browse: OptBoolExpr
port_forward: Optional[BaseExpr[SequenceT]] = field(metadata={"allow_none": True})
multi: SimpleOptBoolExpr
inherits: Optional[Sequence[StrExpr]] = field(metadata={"allow_none": True})
mixins: Optional[Sequence[StrExpr]] = field(metadata={"allow_none": True})


class NeedsLevel(enum.Enum):
Expand Down Expand Up @@ -209,7 +232,7 @@ class TaskBase(Base):

@dataclass(frozen=True)
class Task(ExecUnit, WithSpecifiedFields, TaskBase):
inherits: Optional[Sequence[StrExpr]] = field(metadata={"allow_none": True})
mixins: Optional[Sequence[StrExpr]] = field(metadata={"allow_none": True})


@dataclass(frozen=True)
Expand All @@ -233,7 +256,7 @@ class TaskMixin(WithSpecifiedFields, Base):
strategy: Optional[Strategy] = field(metadata={"allow_none": True})
enable: EnableExpr = field(metadata={"default_expr": "${{ success() }}"})
cache: Optional[Cache] = field(metadata={"allow_none": True})
inherits: Optional[Sequence[StrExpr]] = field(metadata={"allow_none": True})
mixins: Optional[Sequence[StrExpr]] = field(metadata={"allow_none": True})


@dataclass(frozen=True)
Expand Down
59 changes: 32 additions & 27 deletions neuro_flow/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,7 +1149,7 @@ def _specified_fields(self) -> AbstractSet[str]:
async def merge_asts(child: _MergeTarget, parent: SupportsAstMerge) -> _MergeTarget:
child_fields = {f.name for f in dataclasses.fields(child)}
for field in parent._specified_fields:
if field == "inherits" or field not in child_fields:
if field == "mixins" or field not in child_fields:
continue
field_present = field in child._specified_fields
child_value = getattr(child, field)
Expand Down Expand Up @@ -1178,7 +1178,7 @@ async def merge_asts(child: _MergeTarget, parent: SupportsAstMerge) -> _MergeTar

class MixinApplyTarget(Protocol):
@property
def inherits(self) -> Optional[Sequence[StrExpr]]:
def mixins(self) -> Optional[Sequence[StrExpr]]:
...

@property
Expand All @@ -1192,9 +1192,9 @@ def _specified_fields(self) -> AbstractSet[str]:
async def apply_mixins(
base: _MixinApplyTarget, mixins: Mapping[str, SupportsAstMerge]
) -> _MixinApplyTarget:
if base.inherits is None:
if base.mixins is None:
return base
for mixin_expr in reversed(base.inherits):
for mixin_expr in reversed(base.mixins):
mixin_name = await mixin_expr.eval(EMPTY_ROOT)
try:
mixin = mixins[mixin_name]
Expand All @@ -1215,10 +1215,8 @@ async def setup_mixins(
return {}
graph: Dict[str, Dict[str, int]] = {}
for mixin_name, mixin in raw_mixins.items():
inherits = mixin.inherits or []
graph[mixin_name] = {
await dep_expr.eval(EMPTY_ROOT): 1 for dep_expr in inherits
}
mixins = mixin.mixins or []
graph[mixin_name] = {await dep_expr.eval(EMPTY_ROOT): 1 for dep_expr in mixins}
topo = ColoredTopoSorter(graph)
result: Dict[str, _MixinApplyTarget] = {}
while not topo.is_all_colored(1):
Expand All @@ -1232,19 +1230,21 @@ class RunningLiveFlow:
_ast_flow: ast.LiveFlow
_ctx: LiveContext
_cl: ConfigLoader
_mixins: Optional[Mapping[str, ast.JobMixin]] = None
_mixins: Mapping[str, SupportsAstMerge]

def __init__(
self,
ast_flow: ast.LiveFlow,
ctx: LiveContext,
config_loader: ConfigLoader,
defaults: DefaultsConf,
mixins: Mapping[str, SupportsAstMerge],
):
self._ast_flow = ast_flow
self._ctx = ctx
self._cl = config_loader
self._defaults = defaults
self._mixins = mixins

@property
def job_ids(self) -> Iterable[str]:
Expand Down Expand Up @@ -1274,18 +1274,13 @@ async def is_multi(self, job_id: str) -> bool:
# Simple shortcut
return (await self.get_meta(job_id)).multi

async def get_mixins(self) -> Mapping[str, ast.JobMixin]:
if self._mixins is None:
self._mixins = await setup_mixins(self._ast_flow.mixins)
return self._mixins

async def _get_job_ast(
self, job_id: str
) -> Union[ast.Job, ast.JobActionCall, ast.JobModuleCall]:
try:
base = self._ast_flow.jobs[job_id]
if isinstance(base, ast.Job):
base = await apply_mixins(base, await self.get_mixins())
base = await apply_mixins(base, self._mixins)
return base
except KeyError:
raise UnknownJob(job_id)
Expand Down Expand Up @@ -1491,7 +1486,13 @@ async def create(
images=images,
)

return cls(ast_flow, live_ctx, config_loader, defaults)
raw_mixins: Mapping[str, MixinApplyTarget] = {
**(ast_project.mixins or {}),
**(ast_flow.mixins or {}),
}
mixins = await setup_mixins(raw_mixins)

return cls(ast_flow, live_ctx, config_loader, defaults, mixins)


_T = TypeVar("_T", bound=BaseBatchContext, covariant=True)
Expand All @@ -1514,7 +1515,7 @@ def graph(self) -> Mapping[str, Mapping[str, ast.NeedsLevel]]:

@property
@abstractmethod
def mixins(self) -> Optional[Mapping[str, ast.TaskMixin]]:
def mixins(self) -> Optional[Mapping[str, SupportsAstMerge]]:
pass

@property
Expand Down Expand Up @@ -1647,7 +1648,7 @@ def __init__(
config_loader: ConfigLoader,
action: ast.BatchAction,
parent_ctx_class: Type[RootABC],
mixins: Optional[Mapping[str, ast.TaskMixin]],
mixins: Optional[Mapping[str, SupportsAstMerge]],
):
super().__init__(ctx, tasks, config_loader)
self._action = action
Expand All @@ -1660,7 +1661,7 @@ def early_images(self) -> Mapping[str, EarlyImageCtx]:
return self._early_images

@property
def mixins(self) -> Optional[Mapping[str, ast.TaskMixin]]:
def mixins(self) -> Optional[Mapping[str, SupportsAstMerge]]:
return self._mixins

def get_image_ast(self, image_id: str) -> ast.Image:
Expand Down Expand Up @@ -1942,7 +1943,7 @@ def __init__(
local_info: Optional[LocallyPreparedInfo],
ast_flow: ast.BatchFlow,
ast_project: ast.Project,
mixins: Optional[Mapping[str, ast.TaskMixin]],
mixins: Optional[Mapping[str, SupportsAstMerge]],
):
super().__init__(
ctx,
Expand All @@ -1969,7 +1970,7 @@ def get_image_ast(self, image_id: str) -> ast.Image:
raise

@property
def mixins(self) -> Optional[Mapping[str, ast.TaskMixin]]:
def mixins(self) -> Optional[Mapping[str, SupportsAstMerge]]:
return self._mixins

@property
Expand Down Expand Up @@ -2079,7 +2080,11 @@ async def create(
),
)

mixins = await setup_mixins(ast_flow.mixins)
raw_mixins: Mapping[str, MixinApplyTarget] = {
**(ast_project.mixins or {}),
**(ast_flow.mixins or {}),
}
mixins = await setup_mixins(raw_mixins)
tasks = await TaskGraphBuilder(
batch_ctx, config_loader, cache_conf, ast_flow.tasks, mixins
).build()
Expand Down Expand Up @@ -2109,7 +2114,7 @@ def __init__(
action: ast.BatchAction,
bake_id: str,
local_info: Optional[LocallyPreparedInfo],
mixins: Optional[Mapping[str, ast.TaskMixin]],
mixins: Optional[Mapping[str, SupportsAstMerge]],
):
super().__init__(
flow_ctx,
Expand All @@ -2130,7 +2135,7 @@ def get_image_ast(self, image_id: str) -> ast.Image:
return self._action.images[image_id]

@property
def mixins(self) -> Optional[Mapping[str, ast.TaskMixin]]:
def mixins(self) -> Optional[Mapping[str, SupportsAstMerge]]:
return self._mixins

async def calc_outputs(self, task_results: NeedsCtx) -> DepCtx:
Expand Down Expand Up @@ -2162,7 +2167,7 @@ async def create(
bake_id: str,
local_info: Optional[LocallyPreparedInfo],
defaults: DefaultsConf = DefaultsConf(),
mixins: Optional[Mapping[str, ast.TaskMixin]] = None,
mixins: Optional[Mapping[str, SupportsAstMerge]] = None,
) -> "RunningBatchActionFlow":
step_1_ctx = BatchActionContextStep1(
inputs=inputs,
Expand Down Expand Up @@ -2503,7 +2508,7 @@ def __init__(
self,
config_loader: ConfigLoader,
ast_tasks: Sequence[Union[ast.Task, ast.TaskActionCall, ast.TaskModuleCall]],
mixins: Optional[Mapping[str, ast.TaskMixin]],
mixins: Optional[Mapping[str, SupportsAstMerge]],
):
self._cl = config_loader
self._ast_tasks = ast_tasks
Expand Down Expand Up @@ -2696,7 +2701,7 @@ def __init__(
config_loader: ConfigLoader,
default_cache: CacheConf,
ast_tasks: Sequence[Union[ast.Task, ast.TaskActionCall, ast.TaskModuleCall]],
mixins: Optional[Mapping[str, ast.TaskMixin]],
mixins: Optional[Mapping[str, SupportsAstMerge]],
):
super().__init__(config_loader, ast_tasks, mixins)
self._ctx = ctx
Expand Down
37 changes: 25 additions & 12 deletions neuro_flow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,11 +475,18 @@ async def eval(self, root: RootABC) -> LiteralT:
return self.val


def literal(toktype: str) -> Parser:
def f(tok: Token) -> Any:
def make_toktype_predicate(toktype: str) -> Callable[[Token], bool]:
def _predicate(token: Token) -> bool:
return token.type == toktype

return _predicate


def literal(toktype: str) -> "Parser[Token, Literal]":
def f(tok: Token) -> Literal:
return Literal(tok.start, tok.end, literal_eval(tok.value))

return some(lambda tok: tok.type == toktype) >> f
return some(make_toktype_predicate(toktype)) >> f


class Getter(Entity):
Expand Down Expand Up @@ -756,7 +763,7 @@ def make_list(args: Tuple[Item, List[Item]]) -> ListMaker:
return ListMaker(lst[0].start, lst[-1].end, lst)


def make_empty_list(args: Tuple[Item, Item]) -> ListMaker:
def make_empty_list(args: Tuple[Token, Token]) -> ListMaker:
return ListMaker(args[0].start, args[1].end, [])


Expand All @@ -778,23 +785,29 @@ def child_items(self) -> Iterable["Item"]:
yield entry[1]


def make_dict(args: Tuple[Item, Item, List[Tuple[Item, Item]]]) -> DictMaker:
def make_dict(
args: Union[Tuple[Item, Item, List[Tuple[Item, Item]]], Tuple[Item, Item]]
) -> DictMaker:
lst = [(args[0], args[1])]
if len(args) > 2:
lst += args[2]
lst += args[2] # type: ignore
return DictMaker(lst[0][0].start, lst[-1][1].end, lst)


def make_empty_dict(args: Tuple[Item, Item]) -> DictMaker:
def make_empty_dict(args: Tuple[Token, Token]) -> DictMaker:
return DictMaker(args[0].start, args[0].end, [])


def a(value: str) -> Parser:
def a(value: str) -> "Parser[Token, Token]":
"""Eq(a) -> Parser(a, a)

Returns a parser that parses a token that is equal to the value value.
"""
return some(lambda t: t.value == value).named(f'(a "{value}")')

def _is_value_eq(token: Token) -> bool:
return token.value == value

return some(_is_value_eq).named(f'(a "{value}")')


DOT: Final = skip(a("."))
Expand Down Expand Up @@ -837,7 +850,7 @@ def a(value: str) -> Parser:

LITERAL: Final = NONE | BOOL | REAL | INT | STR

NAME: Final = some(lambda tok: tok.type == "NAME")
NAME: Final = some(make_toktype_predicate("NAME"))

LIST_MAKER: Final = forward_decl()

Expand Down Expand Up @@ -937,13 +950,13 @@ def a(value: str) -> Parser:
+ maybe(COMMA)
+ RBRACE
)
>> make_dict
>> make_dict # type: ignore
| (a("{") + a("}")) >> make_empty_dict
)

TMPL: Final = (OPEN_TMPL + EXPR + CLOSE_TMPL) | (OPEN_TMPL2 + EXPR + CLOSE_TMPL2)

TEXT: Final = some(lambda tok: tok.type == "TEXT") >> make_text
TEXT: Final = some(make_toktype_predicate("TEXT")) >> make_text

PARSER: Final = oneplus(TMPL | TEXT) + skip(finished)

Expand Down
Loading