Skip to content

Commit

Permalink
Implement foreach ... in loop in dvc.yaml (#4734)
Browse files Browse the repository at this point in the history
* Implement foreach...in loop in dvc.yaml

* fix tests
  • Loading branch information
skshetry authored Nov 2, 2020
1 parent 96cdd07 commit e211423
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 20 deletions.
28 changes: 28 additions & 0 deletions dvc/parsing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import os
from collections import defaultdict
from collections.abc import Mapping, Sequence
from copy import deepcopy
from itertools import starmap
from typing import TYPE_CHECKING
Expand All @@ -25,6 +26,10 @@
WDIR_KWD = "wdir"
DEFAULT_PARAMS_FILE = ParamsDependency.DEFAULT_PARAMS_FILE
PARAMS_KWD = "params"
FOREACH_KWD = "foreach"
IN_KWD = "in"

DEFAULT_SENTINEL = object()


class DataResolver:
Expand All @@ -50,6 +55,11 @@ def __init__(self, repo: "Repo", wdir: PathInfo, d: dict):

def _resolve_entry(self, name: str, definition):
    """Resolve one entry of the ``stages`` section.

    A fresh clone of the global context is made for the entry. Entries
    without the ``foreach`` keyword go straight to ``_resolve_stage``;
    entries with it are handed to ``_foreach`` together with the
    accompanying ``in`` definition.
    """
    ctx = Context.clone(self.global_ctx)
    if FOREACH_KWD not in definition:
        return self._resolve_stage(ctx, name, definition)

    # A ``foreach`` entry is expected to always carry its ``in`` part.
    assert IN_KWD in definition
    foreach_data = definition[FOREACH_KWD]
    in_data = definition[IN_KWD]
    return self._foreach(ctx, name, foreach_data, in_data)

def resolve(self):
Expand Down Expand Up @@ -114,3 +124,21 @@ def _resolve_wdir(self, context: Context, wdir: str = None) -> PathInfo:
return self.wdir
wdir = resolve(wdir, context)
return self.wdir / str(wdir)

def _foreach(self, context: Context, name: str, foreach_data, in_data):
    """Expand a ``foreach ... in`` definition into one stage per item.

    ``foreach_data`` is resolved against ``context`` and must yield either
    a sequence or a mapping:

    * sequence -- each element is exposed to the template as ``item`` and
      the generated stage is named ``{name}-{element}``;
    * mapping -- each value is exposed as ``item`` and its key as ``key``;
      the generated stage is named ``{name}-{key}``.

    ``in_data`` is the stage template resolved once per item via
    ``_resolve_stage``; the per-item results are combined with ``join``.

    Raises:
        TypeError: if the resolved ``foreach`` data is neither a mapping
            nor a non-string sequence.
    """

    def each_iter(value, key=DEFAULT_SENTINEL):
        # Every generated stage gets its own context clone so that the
        # ``item``/``key`` values of one iteration cannot leak into another.
        c = Context.clone(context)
        c["item"] = value
        if key is not DEFAULT_SENTINEL:
            c["key"] = key
        # The sentinel (rather than None) marks "no key", so a mapping key
        # that is legitimately falsy or None is still used for the suffix.
        suffix = str(key if key is not DEFAULT_SENTINEL else value)
        return self._resolve_stage(c, f"{name}-{suffix}", in_data)

    iterable = resolve(foreach_data, context)
    # ``str`` (and ``bytes``) are Sequences too; without this exclusion a
    # ``foreach`` that resolves to plain text would silently generate one
    # stage per character instead of being rejected.
    if isinstance(iterable, Sequence) and not isinstance(
        iterable, (str, bytes)
    ):
        gen = (each_iter(v) for v in iterable)
    elif isinstance(iterable, Mapping):
        gen = (each_iter(v, k) for k, v in iterable.items())
    else:
        raise TypeError(
            f"'{FOREACH_KWD}' of stage '{name}' must resolve to a list "
            f"or a dict, got type of {type(iterable)}"
        )
    return join(gen)
42 changes: 22 additions & 20 deletions dvc/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dvc import dependency, output
from dvc.output import CHECKSUMS_SCHEMA, BaseOutput
from dvc.parsing import USE_KWD, VARS_KWD
from dvc.parsing import FOREACH_KWD, IN_KWD, USE_KWD, VARS_KWD
from dvc.stage.params import StageParams

STAGES = "stages"
Expand Down Expand Up @@ -48,26 +48,28 @@

PARAM_PSTAGE_NON_DEFAULT_SCHEMA = {str: [str]}

SINGLE_PIPELINE_STAGE_SCHEMA = {
str: {
StageParams.PARAM_CMD: str,
Optional(StageParams.PARAM_WDIR): str,
Optional(StageParams.PARAM_DEPS): [str],
Optional(StageParams.PARAM_PARAMS): [
Any(str, PARAM_PSTAGE_NON_DEFAULT_SCHEMA)
],
Optional(StageParams.PARAM_FROZEN): bool,
Optional(StageParams.PARAM_META): object,
Optional(StageParams.PARAM_ALWAYS_CHANGED): bool,
Optional(StageParams.PARAM_OUTS): [
Any(str, OUT_PSTAGE_DETAILED_SCHEMA)
],
Optional(StageParams.PARAM_METRICS): [
Any(str, OUT_PSTAGE_DETAILED_SCHEMA)
],
Optional(StageParams.PARAM_PLOTS): [Any(str, PLOT_PSTAGE_SCHEMA)],
}
# Schema of the body of a single stage: the keys a stage may carry and the
# type each value must have. Factored out of SINGLE_PIPELINE_STAGE_SCHEMA so
# the same definition can be reused for the `in` part of a foreach stage.
# NOTE(review): `cmd` is the only key not wrapped in Optional(); whether it is
# actually enforced as required depends on how the schema is compiled, which
# is not visible here.
STAGE_DEFINITION = {
StageParams.PARAM_CMD: str,
Optional(StageParams.PARAM_WDIR): str,
Optional(StageParams.PARAM_DEPS): [str],
# Each params entry is either a plain string or a {str: [str]} mapping.
Optional(StageParams.PARAM_PARAMS): [
Any(str, PARAM_PSTAGE_NON_DEFAULT_SCHEMA)
],
Optional(StageParams.PARAM_FROZEN): bool,
Optional(StageParams.PARAM_META): object,
Optional(StageParams.PARAM_ALWAYS_CHANGED): bool,
# Outs and metrics entries are either a plain string or a detailed mapping.
Optional(StageParams.PARAM_OUTS): [Any(str, OUT_PSTAGE_DETAILED_SCHEMA)],
Optional(StageParams.PARAM_METRICS): [
Any(str, OUT_PSTAGE_DETAILED_SCHEMA)
],
Optional(StageParams.PARAM_PLOTS): [Any(str, PLOT_PSTAGE_SCHEMA)],
}

# Schema of a `foreach ... in` stage: both keys are required. The `foreach`
# value may be a dict, a list, or a string; the `in` value is an ordinary
# stage definition.
FOREACH_IN = {
Required(FOREACH_KWD): Any(dict, list, str),
Required(IN_KWD): STAGE_DEFINITION,
}
# Maps a stage name to either a plain stage definition or a foreach stage.
SINGLE_PIPELINE_STAGE_SCHEMA = {str: Any(STAGE_DEFINITION, FOREACH_IN)}
MULTI_STAGE_SCHEMA = {
STAGES: SINGLE_PIPELINE_STAGE_SCHEMA,
USE_KWD: str,
Expand Down
69 changes: 69 additions & 0 deletions tests/func/test_stage_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,72 @@ def test_with_templated_wdir(tmp_dir, dvc):
}
},
)


def test_simple_foreach_loop(tmp_dir, dvc):
    """A list under ``foreach`` yields one ``build-<item>`` stage per element."""
    items = ["foo", "bar", "baz"]
    template = {"cmd": "python script.py ${item}"}
    data = {"stages": {"build": {"foreach": items, "in": template}}}

    resolved = DataResolver(dvc, PathInfo(str(tmp_dir)), data).resolve()

    expected_stages = {}
    for item in items:
        expected_stages[f"build-{item}"] = {"cmd": f"python script.py {item}"}
    assert resolved == {"stages": expected_stages}


def test_foreach_loop_dict(tmp_dir, dvc):
    """A dict under ``foreach`` names stages by key and exposes values as ``item``."""
    iterable = {"models": {"us": {"thresh": 10}, "gb": {"thresh": 15}}}
    models = iterable["models"]
    template = {"cmd": "python script.py ${item.thresh}"}
    data = {"stages": {"build": {"foreach": models, "in": template}}}

    resolved = DataResolver(dvc, PathInfo(str(tmp_dir)), data).resolve()

    expected_stages = {}
    for key, item in models.items():
        expected_stages[f"build-{key}"] = {
            "cmd": f"python script.py {item['thresh']}"
        }
    assert resolved == {"stages": expected_stages}


def test_foreach_loop_templatized(tmp_dir, dvc):
    """``foreach`` given as a ``${...}`` template is expanded over its value.

    ``models.us`` is written to the default params file and ``models.gb`` is
    supplied through ``vars``; only the ``us`` stage is expected to carry a
    ``params`` entry.
    """
    dump_yaml(
        tmp_dir / DEFAULT_PARAMS_FILE, {"models": {"us": {"thresh": 10}}}
    )
    foreach_stage = {
        "foreach": "${models}",
        "in": {"cmd": "python script.py --thresh ${item.thresh}"},
    }
    data = {
        "vars": {"models": {"gb": {"thresh": 15}}},
        "stages": {"build": foreach_stage},
    }
    expected = {
        "stages": {
            "build-gb": {"cmd": "python script.py --thresh 15"},
            "build-us": {
                "cmd": "python script.py --thresh 10",
                "params": ["models.us.thresh"],
            },
        }
    }

    resolver = DataResolver(dvc, PathInfo(str(tmp_dir)), data)
    assert_stage_equal(resolver.resolve(), expected)

0 comments on commit e211423

Please sign in to comment.