From e2114233984e98bb76d0b5cea8ccedf4da31b32d Mon Sep 17 00:00:00 2001 From: Saugat Pachhai Date: Mon, 2 Nov 2020 18:56:23 +0545 Subject: [PATCH] Implement foreach ... in loop in dvc.yaml (#4734) * Implement foreach...in loop in dvc.yaml * fix tests --- dvc/parsing/__init__.py | 28 +++++++++++++ dvc/schema.py | 42 ++++++++++--------- tests/func/test_stage_resolver.py | 69 +++++++++++++++++++++++++++++++ 3 files changed, 119 insertions(+), 20 deletions(-) diff --git a/dvc/parsing/__init__.py b/dvc/parsing/__init__.py index b173de4384..7b11d34734 100644 --- a/dvc/parsing/__init__.py +++ b/dvc/parsing/__init__.py @@ -1,6 +1,7 @@ import logging import os from collections import defaultdict +from collections.abc import Mapping, Sequence from copy import deepcopy from itertools import starmap from typing import TYPE_CHECKING @@ -25,6 +26,10 @@ WDIR_KWD = "wdir" DEFAULT_PARAMS_FILE = ParamsDependency.DEFAULT_PARAMS_FILE PARAMS_KWD = "params" +FOREACH_KWD = "foreach" +IN_KWD = "in" + +DEFAULT_SENTINEL = object() class DataResolver: @@ -50,6 +55,11 @@ def __init__(self, repo: "Repo", wdir: PathInfo, d: dict): def _resolve_entry(self, name: str, definition): context = Context.clone(self.global_ctx) + if FOREACH_KWD in definition: + assert IN_KWD in definition + return self._foreach( + context, name, definition[FOREACH_KWD], definition[IN_KWD] + ) return self._resolve_stage(context, name, definition) def resolve(self): @@ -114,3 +124,21 @@ def _resolve_wdir(self, context: Context, wdir: str = None) -> PathInfo: return self.wdir wdir = resolve(wdir, context) return self.wdir / str(wdir) + + def _foreach(self, context: Context, name: str, foreach_data, in_data): + def each_iter(value, key=DEFAULT_SENTINEL): + c = Context.clone(context) + c["item"] = value + if key is not DEFAULT_SENTINEL: + c["key"] = key + suffix = str(key if key is not DEFAULT_SENTINEL else value) + return self._resolve_stage(c, f"{name}-{suffix}", in_data) + + iterable = resolve(foreach_data, context) + if isinstance(iterable, Sequence): + gen = (each_iter(v) for v in iterable) + elif isinstance(iterable, Mapping): + gen = (each_iter(v, k) for k, v in iterable.items()) + else: + raise Exception(f"got type of {type(iterable)}") + return join(gen) diff --git a/dvc/schema.py b/dvc/schema.py index 28f14ab7f3..c27c54d48e 100644 --- a/dvc/schema.py +++ b/dvc/schema.py @@ -2,7 +2,7 @@ from dvc import dependency, output from dvc.output import CHECKSUMS_SCHEMA, BaseOutput -from dvc.parsing import USE_KWD, VARS_KWD +from dvc.parsing import FOREACH_KWD, IN_KWD, USE_KWD, VARS_KWD from dvc.stage.params import StageParams STAGES = "stages" @@ -48,26 +48,28 @@ PARAM_PSTAGE_NON_DEFAULT_SCHEMA = {str: [str]} -SINGLE_PIPELINE_STAGE_SCHEMA = { - str: { - StageParams.PARAM_CMD: str, - Optional(StageParams.PARAM_WDIR): str, - Optional(StageParams.PARAM_DEPS): [str], - Optional(StageParams.PARAM_PARAMS): [ - Any(str, PARAM_PSTAGE_NON_DEFAULT_SCHEMA) - ], - Optional(StageParams.PARAM_FROZEN): bool, - Optional(StageParams.PARAM_META): object, - Optional(StageParams.PARAM_ALWAYS_CHANGED): bool, - Optional(StageParams.PARAM_OUTS): [ - Any(str, OUT_PSTAGE_DETAILED_SCHEMA) - ], - Optional(StageParams.PARAM_METRICS): [ - Any(str, OUT_PSTAGE_DETAILED_SCHEMA) - ], - Optional(StageParams.PARAM_PLOTS): [Any(str, PLOT_PSTAGE_SCHEMA)], - } +STAGE_DEFINITION = { + StageParams.PARAM_CMD: str, + Optional(StageParams.PARAM_WDIR): str, + Optional(StageParams.PARAM_DEPS): [str], + Optional(StageParams.PARAM_PARAMS): [ + Any(str, PARAM_PSTAGE_NON_DEFAULT_SCHEMA) + ], + Optional(StageParams.PARAM_FROZEN): bool, + Optional(StageParams.PARAM_META): object, + Optional(StageParams.PARAM_ALWAYS_CHANGED): bool, + Optional(StageParams.PARAM_OUTS): [Any(str, OUT_PSTAGE_DETAILED_SCHEMA)], + Optional(StageParams.PARAM_METRICS): [ + Any(str, OUT_PSTAGE_DETAILED_SCHEMA) + ], + Optional(StageParams.PARAM_PLOTS): [Any(str, PLOT_PSTAGE_SCHEMA)], } + +FOREACH_IN = { + Required(FOREACH_KWD): Any(dict, list, str), + Required(IN_KWD): STAGE_DEFINITION, +} +SINGLE_PIPELINE_STAGE_SCHEMA = {str: Any(STAGE_DEFINITION, FOREACH_IN)} MULTI_STAGE_SCHEMA = { STAGES: SINGLE_PIPELINE_STAGE_SCHEMA, USE_KWD: str, diff --git a/tests/func/test_stage_resolver.py b/tests/func/test_stage_resolver.py index 2c36aac9d1..c4b9176b74 100644 --- a/tests/func/test_stage_resolver.py +++ b/tests/func/test_stage_resolver.py @@ -233,3 +233,72 @@ def test_with_templated_wdir(tmp_dir, dvc): } }, ) + + +def test_simple_foreach_loop(tmp_dir, dvc): + iterable = ["foo", "bar", "baz"] + d = { + "stages": { + "build": { + "foreach": iterable, + "in": {"cmd": "python script.py ${item}"}, + } + } + } + + resolver = DataResolver(dvc, PathInfo(str(tmp_dir)), d) + assert resolver.resolve() == { + "stages": { + f"build-{item}": {"cmd": f"python script.py {item}"} + for item in iterable + } + } + + +def test_foreach_loop_dict(tmp_dir, dvc): + iterable = {"models": {"us": {"thresh": 10}, "gb": {"thresh": 15}}} + d = { + "stages": { + "build": { + "foreach": iterable["models"], + "in": {"cmd": "python script.py ${item.thresh}"}, + } + } + } + + resolver = DataResolver(dvc, PathInfo(str(tmp_dir)), d) + assert resolver.resolve() == { + "stages": { + f"build-{key}": {"cmd": f"python script.py {item['thresh']}"} + for key, item in iterable["models"].items() + } + } + + +def test_foreach_loop_templatized(tmp_dir, dvc): + params = {"models": {"us": {"thresh": 10}}} + vars_ = {"models": {"gb": {"thresh": 15}}} + dump_yaml(tmp_dir / DEFAULT_PARAMS_FILE, params) + d = { + "vars": vars_, + "stages": { + "build": { + "foreach": "${models}", + "in": {"cmd": "python script.py --thresh ${item.thresh}"}, + } + }, + } + + resolver = DataResolver(dvc, PathInfo(str(tmp_dir)), d) + assert_stage_equal( + resolver.resolve(), + { + "stages": { + "build-gb": {"cmd": "python script.py --thresh 15"}, + "build-us": { + "cmd": "python script.py --thresh 10", + "params": ["models.us.thresh"], + }, + } + }, + )