diff --git a/neuro_flow/context.py b/neuro_flow/context.py index 3653629c..e6387a5f 100644 --- a/neuro_flow/context.py +++ b/neuro_flow/context.py @@ -1033,7 +1033,12 @@ async def setup_params_ctx( async def setup_strategy_ctx( ctx: RootABC, ast_defaults: Optional[ast.BatchFlowDefaults], + ast_global_defaults: Optional[ast.BatchFlowDefaults], ) -> StrategyCtx: + if ast_defaults is not None and ast_global_defaults is not None: + ast_defaults = await merge_asts(ast_defaults, ast_global_defaults) + elif ast_global_defaults: + ast_defaults = ast_global_defaults if ast_defaults is None: return StrategyCtx() fail_fast = await ast_defaults.fail_fast.eval(ctx) @@ -2005,9 +2010,12 @@ async def create( _client=config_loader.client, ) if local_info is None: - early_images = await setup_images_early( - step_1_ctx, step_1_ctx, ast_flow.images - ) + early_images: Mapping[str, EarlyImageCtx] = { + **( + await setup_images_early(step_1_ctx, step_1_ctx, ast_project.images) + ), + **(await setup_images_early(step_1_ctx, step_1_ctx, ast_flow.images)), + } else: early_images = local_info.early_images @@ -2040,16 +2048,28 @@ async def create( images=images, ) + if ast_project.defaults: + base_cache = await setup_cache( + step_2_ctx, + CacheConf(), + ast_project.defaults.cache, + ast.CacheStrategy.INHERIT, + ) + else: + base_cache = CacheConf() + if ast_flow.defaults: ast_cache = ast_flow.defaults.cache else: ast_cache = None cache_conf = await setup_cache( - step_2_ctx, CacheConf(), ast_cache, ast.CacheStrategy.INHERIT + step_2_ctx, base_cache, ast_cache, ast.CacheStrategy.INHERIT ) batch_ctx = step_2_ctx.to_batch_ctx( - strategy=await setup_strategy_ctx(step_2_ctx, ast_flow.defaults), + strategy=await setup_strategy_ctx( + step_2_ctx, ast_flow.defaults, ast_project.defaults + ), ) mixins = await setup_mixins(ast_flow.mixins) diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py index c3de5937..10e6a11f 100644 --- a/tests/unit/test_context.py +++ b/tests/unit/test_context.py @@ -8,6 +8,7 @@ from yarl import URL from neuro_flow import ast +from neuro_flow.ast import CacheStrategy from neuro_flow.config_loader import BatchLocalCL, ConfigLoader, LiveLocalCL from neuro_flow.context import ( EMPTY_ROOT, @@ -1115,3 +1116,59 @@ async def test_batch_task_with_no_image(assets: pathlib.Path, client: Client) -> finally: await cl.close() + + +async def test_early_images_include_globals( + assets: pathlib.Path, client: Client +) -> None: + ws = assets / "with_project_yaml" + config_dir = ConfigDir( + workspace=ws, + config_dir=ws, + ) + cl = BatchLocalCL(config_dir, client) + try: + flow = await RunningBatchFlow.create(cl, "batch", "bake-id") + assert flow.early_images["image_a"].ref == "image:banana" + assert flow.early_images["image_a"].context == ws / "dir" + assert flow.early_images["image_a"].dockerfile == ws / "dir/Dockerfile" + + assert flow.early_images["image_b"].ref == "image:main" + assert flow.early_images["image_b"].context == ws / "dir" + assert flow.early_images["image_b"].dockerfile == ws / "dir/Dockerfile" + + finally: + await cl.close() + + +async def test_batch_with_project_globals(assets: pathlib.Path, client: Client) -> None: + ws = assets / "with_project_yaml" + config_dir = ConfigDir( + workspace=ws, + config_dir=ws, + ) + cl = BatchLocalCL(config_dir, client) + try: + flow = await RunningBatchFlow.create(cl, "batch", "bake-id") + task = await flow.get_task((), "task", needs={}, state={}) + assert "tag-a" in task.tags + assert "tag-b" in task.tags + assert task.env["global_a"] == "val-a" + assert task.env["global_b"] == "val-b" + assert task.volumes == [ + "storage:common:/mnt/common:rw", + "storage:dir:/var/dir:ro", + ] + assert task.workdir == RemotePath("/global/dir") + assert task.life_span == 100800.0 + assert task.preset == "cpu-large" + assert task.schedule_timeout == 2157741.0 + assert task.image == "image:main" + + assert not task.strategy.fail_fast + assert task.strategy.max_parallel == 20 + assert task.cache.strategy == CacheStrategy.NONE + assert task.cache.life_span == 9000.0 + + finally: + await cl.close() diff --git a/tests/unit/test_project_parser.py b/tests/unit/test_project_parser.py index 08391d30..d4654ef9 100644 --- a/tests/unit/test_project_parser.py +++ b/tests/unit/test_project_parser.py @@ -186,7 +186,9 @@ def test_parse_full(assets: pathlib.Path) -> None: Pos(0, 0, config_file), Pos(0, 0, config_file), "2h30m" ), ), - fail_fast=OptBoolExpr(Pos(0, 0, config_file), Pos(0, 0, config_file), True), - max_parallel=OptIntExpr(Pos(0, 0, config_file), Pos(0, 0, config_file), 10), + fail_fast=OptBoolExpr( + Pos(0, 0, config_file), Pos(0, 0, config_file), False + ), + max_parallel=OptIntExpr(Pos(0, 0, config_file), Pos(0, 0, config_file), 20), ), ) diff --git a/tests/unit/with_project_yaml/batch.yml b/tests/unit/with_project_yaml/batch.yml new file mode 100644 index 00000000..de7e1106 --- /dev/null +++ b/tests/unit/with_project_yaml/batch.yml @@ -0,0 +1,12 @@ +kind: batch +images: + image_b: + ref: image:main + context: dir + dockerfile: dir/Dockerfile +tasks: +- id: task + image: image:main + bash: echo OK + volumes: + - ${{ volumes.volume_a.ref }} diff --git a/tests/unit/with_project_yaml/live.yml b/tests/unit/with_project_yaml/live.yml new file mode 100644 index 00000000..8c78e6c9 --- /dev/null +++ b/tests/unit/with_project_yaml/live.yml @@ -0,0 +1,6 @@ +kind: live +jobs: + test: + image: ${{ images.image_a.ref }} + volumes: + - ${{ volumes.volume_a.ref }} diff --git a/tests/unit/with_project_yaml/project.yml b/tests/unit/with_project_yaml/project.yml index 072c9484..9a72e256 100644 --- a/tests/unit/with_project_yaml/project.yml +++ b/tests/unit/with_project_yaml/project.yml @@ -35,8 +35,8 @@ defaults: life_span: 1d4h preset: cpu-large schedule_timeout: 24d23h22m21s - fail_fast: true - max_parallel: 10 + fail_fast: false + max_parallel: 20 cache: strategy: none life_span: 2h30m