Skip to content

Commit

Permalink
Support global defaults in batch mode
Browse files Browse the repository at this point in the history
  • Loading branch information
romasku committed Jul 7, 2021
1 parent cc84f45 commit f18809b
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 9 deletions.
30 changes: 25 additions & 5 deletions neuro_flow/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -1033,7 +1033,12 @@ async def setup_params_ctx(
async def setup_strategy_ctx(
ctx: RootABC,
ast_defaults: Optional[ast.BatchFlowDefaults],
ast_global_defaults: Optional[ast.BatchFlowDefaults],
) -> StrategyCtx:
if ast_defaults is not None and ast_global_defaults is not None:
ast_defaults = await merge_asts(ast_defaults, ast_global_defaults)
elif ast_global_defaults:
ast_defaults = ast_global_defaults
if ast_defaults is None:
return StrategyCtx()
fail_fast = await ast_defaults.fail_fast.eval(ctx)
Expand Down Expand Up @@ -2005,9 +2010,12 @@ async def create(
_client=config_loader.client,
)
if local_info is None:
early_images = await setup_images_early(
step_1_ctx, step_1_ctx, ast_flow.images
)
early_images: Mapping[str, EarlyImageCtx] = {
**(
await setup_images_early(step_1_ctx, step_1_ctx, ast_project.images)
),
**(await setup_images_early(step_1_ctx, step_1_ctx, ast_flow.images)),
}
else:
early_images = local_info.early_images

Expand Down Expand Up @@ -2040,16 +2048,28 @@ async def create(
images=images,
)

if ast_project.defaults:
base_cache = await setup_cache(
step_2_ctx,
CacheConf(),
ast_project.defaults.cache,
ast.CacheStrategy.INHERIT,
)
else:
base_cache = CacheConf()

if ast_flow.defaults:
ast_cache = ast_flow.defaults.cache
else:
ast_cache = None
cache_conf = await setup_cache(
step_2_ctx, CacheConf(), ast_cache, ast.CacheStrategy.INHERIT
step_2_ctx, base_cache, ast_cache, ast.CacheStrategy.INHERIT
)

batch_ctx = step_2_ctx.to_batch_ctx(
strategy=await setup_strategy_ctx(step_2_ctx, ast_flow.defaults),
strategy=await setup_strategy_ctx(
step_2_ctx, ast_flow.defaults, ast_project.defaults
),
)

mixins = await setup_mixins(ast_flow.mixins)
Expand Down
57 changes: 57 additions & 0 deletions tests/unit/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from yarl import URL

from neuro_flow import ast
from neuro_flow.ast import CacheStrategy
from neuro_flow.config_loader import BatchLocalCL, ConfigLoader, LiveLocalCL
from neuro_flow.context import (
EMPTY_ROOT,
Expand Down Expand Up @@ -1115,3 +1116,59 @@ async def test_batch_task_with_no_image(assets: pathlib.Path, client: Client) ->

finally:
await cl.close()


async def test_early_images_include_globals(
assets: pathlib.Path, client: Client
) -> None:
ws = assets / "with_project_yaml"
config_dir = ConfigDir(
workspace=ws,
config_dir=ws,
)
cl = BatchLocalCL(config_dir, client)
try:
flow = await RunningBatchFlow.create(cl, "batch", "bake-id")
assert flow.early_images["image_a"].ref == "image:banana"
assert flow.early_images["image_a"].context == ws / "dir"
assert flow.early_images["image_a"].dockerfile == ws / "dir/Dockerfile"

assert flow.early_images["image_b"].ref == "image:main"
assert flow.early_images["image_b"].context == ws / "dir"
assert flow.early_images["image_b"].dockerfile == ws / "dir/Dockerfile"

finally:
await cl.close()


async def test_batch_with_project_globals(assets: pathlib.Path, client: Client) -> None:
ws = assets / "with_project_yaml"
config_dir = ConfigDir(
workspace=ws,
config_dir=ws,
)
cl = BatchLocalCL(config_dir, client)
try:
flow = await RunningBatchFlow.create(cl, "batch", "bake-id")
task = await flow.get_task((), "task", needs={}, state={})
assert "tag-a" in task.tags
assert "tag-b" in task.tags
assert task.env["global_a"] == "val-a"
assert task.env["global_b"] == "val-b"
assert task.volumes == [
"storage:common:/mnt/common:rw",
"storage:dir:/var/dir:ro",
]
assert task.workdir == RemotePath("/global/dir")
assert task.life_span == 100800.0
assert task.preset == "cpu-large"
assert task.schedule_timeout == 2157741.0
assert task.image == "image:main"

assert not task.strategy.fail_fast
assert task.strategy.max_parallel == 20
assert task.cache.strategy == CacheStrategy.NONE
assert task.cache.life_span == 9000.0

finally:
await cl.close()
6 changes: 4 additions & 2 deletions tests/unit/test_project_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,9 @@ def test_parse_full(assets: pathlib.Path) -> None:
Pos(0, 0, config_file), Pos(0, 0, config_file), "2h30m"
),
),
fail_fast=OptBoolExpr(Pos(0, 0, config_file), Pos(0, 0, config_file), True),
max_parallel=OptIntExpr(Pos(0, 0, config_file), Pos(0, 0, config_file), 10),
fail_fast=OptBoolExpr(
Pos(0, 0, config_file), Pos(0, 0, config_file), False
),
max_parallel=OptIntExpr(Pos(0, 0, config_file), Pos(0, 0, config_file), 20),
),
)
12 changes: 12 additions & 0 deletions tests/unit/with_project_yaml/batch.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
kind: batch
images:
image_b:
ref: image:main
context: dir
dockerfile: dir/Dockerfile
tasks:
- id: task
image: image:main
bash: echo OK
volumes:
- ${{ volumes.volume_a.ref }}
6 changes: 6 additions & 0 deletions tests/unit/with_project_yaml/live.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: live
jobs:
test:
image: ${{ images.image_a.ref }}
volumes:
- ${{ volumes.volume_a.ref }}
4 changes: 2 additions & 2 deletions tests/unit/with_project_yaml/project.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ defaults:
life_span: 1d4h
preset: cpu-large
schedule_timeout: 24d23h22m21s
fail_fast: true
max_parallel: 10
fail_fast: false
max_parallel: 20
cache:
strategy: none
life_span: 2h30m

0 comments on commit f18809b

Please sign in to comment.