From 34cd66294aba868bf7427fab6e9f244fabef751b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Thu, 16 Sep 2021 16:36:53 +0545 Subject: [PATCH 1/5] exp init: refactor, add basic template support The template is not usable in packages yet, nor will it be initialized during 'dvc init'. However, custom templates are supported if the user adds them in .dvc/stages or, within dvc's codebase, in resources/stages. --- dvc/command/experiments.py | 75 +++++++---------- dvc/repo/experiments/init.py | 121 ++++++++++++++++++++++++++++ requirements/default.txt | 1 + resources/stages/default.yaml | 17 ++++ resources/stages/live.yaml | 15 ++++ tests/func/experiments/test_init.py | 36 +++++++-- 6 files changed, 212 insertions(+), 53 deletions(-) create mode 100644 dvc/repo/experiments/init.py create mode 100644 resources/stages/default.yaml create mode 100644 resources/stages/live.yaml diff --git a/dvc/command/experiments.py b/dvc/command/experiments.py index e699832a4f..067a17cf3f 100644 --- a/dvc/command/experiments.py +++ b/dvc/command/experiments.py @@ -1,6 +1,5 @@ import argparse import logging -import os from collections import Counter, OrderedDict, defaultdict from datetime import date, datetime from fnmatch import fnmatch @@ -791,58 +790,36 @@ def run(self): class CmdExperimentsInit(CmdBase): - CODE = "src" - DATA = "data" - MODELS = "models" - DEFAULT_METRICS = "metrics.json" - DEFAULT_PARAMS = "params.yaml" - PLOTS = "plots" - DVCLIVE = "dvclive" - DEFAULT_NAME = "default" - def run(self): from dvc.command.stage import parse_cmd + from dvc.repo.experiments.init import init cmd = parse_cmd(self.args.cmd) if not cmd: raise InvalidArgumentError("command is not specified") - if self.args.interactive: - raise NotImplementedError( - "'-i/--interactive' is not implemented yet." 
- ) - if self.args.explicit: - raise NotImplementedError("'--explicit' is not implemented yet.") - if self.args.template: - raise NotImplementedError("template is not supported yet.") - - from dvc.utils.serialize import LOADERS - - code = self.args.code or self.CODE - data = self.args.data or self.DATA - models = self.args.models or self.MODELS - metrics = self.args.metrics or self.DEFAULT_METRICS - params_path = self.args.params or self.DEFAULT_PARAMS - plots = self.args.plots or self.PLOTS - dvclive = self.args.live or self.DVCLIVE - - _, ext = os.path.splitext(params_path) - params = list(LOADERS[ext](params_path)) - - name = self.args.name or self.DEFAULT_NAME - stage = self.repo.stage.add( - name=name, - cmd=cmd, - deps=[code, data], - outs=[models], - params=[{params_path: params}], - metrics_no_cache=[metrics], - plots_no_cache=[plots], - live=dvclive, - force=True, - ) + data = { + "cmd": cmd, + "code": self.args.code, + "data": self.args.data, + "models": self.args.models, + "metrics": self.args.metrics, + "params": self.args.params, + "plots": self.args.plots, + "live": self.args.live, + } + + initialized_stage = init( + self.repo, + data, + template_name=self.args.template_name, + interactive=self.args.interactive, + explicit=self.args.explicit, + ) if self.args.run: - return self.repo.experiments.run(targets=[stage.addressing]) + return self.repo.experiments.run( + targets=[initialized_stage.addressing] + ) return 0 @@ -1385,10 +1362,14 @@ def add_parser(subparsers, parent_parser): help="Prompt for values that are not provided", ) experiments_init_parser.add_argument( - "--template", help="Stage template to use to fill with provided values" + "--template", + dest="template_name", + help="Stage template to use to fill with provided values", ) experiments_init_parser.add_argument( - "--explicit", help="Only use the path values explicitly provided" + "--explicit", + action="store_true", + help="Only use the path values explicitly provided", ) 
experiments_init_parser.add_argument( "--name", "-n", help="Name of the stage to create" diff --git a/dvc/repo/experiments/init.py b/dvc/repo/experiments/init.py new file mode 100644 index 0000000000..130d813f2a --- /dev/null +++ b/dvc/repo/experiments/init.py @@ -0,0 +1,121 @@ +import dataclasses +import os +from collections import ChainMap +from pathlib import Path +from typing import TYPE_CHECKING, Callable, Dict, Optional + +from funcy import compact +from voluptuous import MultipleInvalid, Schema + +from dvc.exceptions import DvcException +from dvc.schema import STAGE_DEFINITION + +if TYPE_CHECKING: + from jinja2 import BaseLoader + + from dvc.repo import Repo + + +DEFAULT_TEMPLATE = "default" + + +@dataclasses.dataclass +class TemplateDefaults: + code: str = "src" + data: str = "data" + models: str = "models" + metrics: str = "metrics.json" + params: str = "params.yaml" + plots: str = "plots" + live: str = "dvclive" + + +DEFAULT_VALUES = dataclasses.asdict(TemplateDefaults()) +STAGE_SCHEMA = Schema(STAGE_DEFINITION) + + +def get_loader(repo: "Repo") -> "BaseLoader": + from jinja2 import ChoiceLoader, FileSystemLoader + + default_path = Path(__file__).parents[3] / "resources" / "stages" + return ChoiceLoader( + [ + # not initialized yet + FileSystemLoader(Path(repo.dvc_dir) / "stages"), + # won't work for other packages + FileSystemLoader(default_path), + ] + ) + + +def init( + repo: "Repo", + data: Dict[str, Optional[object]], + template_name: str = None, + interactive: bool = False, + explicit: bool = False, + template_loader: Callable[["Repo"], "BaseLoader"] = get_loader, + force: bool = False, +): + from jinja2 import Environment + + from dvc.dvcfile import make_dvcfile + from dvc.stage import check_circular_dependency, check_duplicated_arguments + from dvc.stage.loader import StageLoader + from dvc.utils.serialize import LOADERS, parse_yaml_for_update + + data = compact(data) # remove None values + loader = template_loader(repo) + environment = 
Environment(loader=loader) + name = template_name or DEFAULT_TEMPLATE + + dvcfile = make_dvcfile(repo, "dvc.yaml") + if not force and dvcfile.exists() and name in dvcfile.stages: + raise DvcException(f"stage '{name}' already exists.") + + template = environment.get_template(f"{name}.yaml") + context = ChainMap(data) + if interactive: + # TODO: interactive requires us to check for variables present + # in the template and, adapt our prompts accordingly. + raise NotImplementedError("'-i/--interactive' is not supported yet.") + if not explicit: + context.maps.append(DEFAULT_VALUES) + else: + # TODO: explicit requires us to check for undefined variables. + raise NotImplementedError("'--explicit' is not implemented yet.") + + assert "params" in context + # See https://github.com/iterative/dvc/issues/6605 for the support + # for depending on all params of a file. + param_path = str(context["params"]) + _, ext = os.path.splitext(param_path) + param_names = list(LOADERS[ext](param_path)) + + # render, parse yaml and then validate schema + rendered = template.render(**context, param_names=param_names) + template_path = os.path.relpath(template.filename) + data = parse_yaml_for_update(rendered, template_path) + try: + validated = STAGE_SCHEMA(data) + except MultipleInvalid as exc: + raise DvcException( + f"template '{template_path}'" + "failed schema validation while rendering" + ) from exc + + stage = StageLoader.load_stage(dvcfile, name, validated) + # ensure correctness, similar to what we have in `repo.stage.add` + check_circular_dependency(stage) + check_duplicated_arguments(stage) + new_index = repo.index.add(stage) + new_index.check_graph() + + with repo.scm.track_file_changes(config=repo.config): + # note that we are not dumping the "template" as-is + # we are dumping a stage data, which is processed + # so formatting-wise, it may look different. 
+ stage.dump(update_lock=False) + stage.ignore_outs() + + return stage diff --git a/requirements/default.txt b/requirements/default.txt index 7fa2d2c5f1..733222ea5c 100644 --- a/requirements/default.txt +++ b/requirements/default.txt @@ -44,3 +44,4 @@ typing_extensions>=3.10.0.2 fsspec[http]>=2021.8.1 aiohttp-retry==2.4.5 diskcache>=5.2.1 +jinja2>=2.11.3 diff --git a/resources/stages/default.yaml b/resources/stages/default.yaml new file mode 100644 index 0000000000..8422940ef1 --- /dev/null +++ b/resources/stages/default.yaml @@ -0,0 +1,17 @@ +cmd: {{ cmd }} +deps: +- {{ code }} +- {{ data }} +params: +- {{ params }}: + {% for p in param_names %} + - {{ p }} + {% endfor %} +outs: +- {{ models }} +metrics: +- {{ metrics }}: + cache: false +plots: +- {{ plots }}: + cache: false diff --git a/resources/stages/live.yaml b/resources/stages/live.yaml new file mode 100644 index 0000000000..188c243f76 --- /dev/null +++ b/resources/stages/live.yaml @@ -0,0 +1,15 @@ +cmd: {{ cmd }} +deps: +- {{ code }} +- {{ data }} +params: +- {{ params }}: + {% for p in param_names %} + - {{ p }} + {% endfor %} +outs: +- {{ models }} +live: + {{ live }}: + summary: true + html: true diff --git a/tests/func/experiments/test_init.py b/tests/func/experiments/test_init.py index d55451d8f5..919b93e208 100644 --- a/tests/func/experiments/test_init.py +++ b/tests/func/experiments/test_init.py @@ -1,30 +1,27 @@ import os -from dvc.command.experiments import CmdExperimentsInit from dvc.main import main -from dvc.utils.serialize import load_yaml def test_init(tmp_dir, dvc): tmp_dir.gen( { - CmdExperimentsInit.CODE: {"copy.py": ""}, + "src": {"copy.py": ""}, "data": "data", "params.yaml": '{"foo": 1}', "dvclive": {}, "plots": {}, } ) - code_path = os.path.join(CmdExperimentsInit.CODE, "copy.py") + code_path = os.path.join("src", "copy.py") script = f"python {code_path}" assert main(["exp", "init", script]) == 0 - assert load_yaml(tmp_dir / "dvc.yaml") == { + assert (tmp_dir / "dvc.yaml").parse() == { 
"stages": { "default": { "cmd": script, "deps": ["data", "src"], - "live": {"dvclive": {"html": True, "summary": True}}, "metrics": [{"metrics.json": {"cache": False}}], "outs": ["models"], "params": ["foo"], @@ -32,3 +29,30 @@ def test_init(tmp_dir, dvc): } } } + + +def test_init_live(tmp_dir, dvc): + tmp_dir.gen( + { + "src": {"copy.py": ""}, + "data": "data", + "params.yaml": '{"foo": 1}', + "dvclive": {}, + "plots": {}, + } + ) + code_path = os.path.join("src", "copy.py") + script = f"python {code_path}" + + assert main(["exp", "init", "--template", "live", script]) == 0 + assert (tmp_dir / "dvc.yaml").parse() == { + "stages": { + "live": { + "cmd": script, + "deps": ["data", "src"], + "outs": ["models"], + "params": ["foo"], + "live": {"dvclive": {"html": True, "summary": True}}, + } + } + } From 33efc62ba144542f8b178cac301fda88d1b41297 Mon Sep 17 00:00:00 2001 From: Saugat Pachhai Date: Thu, 16 Sep 2021 17:01:44 +0545 Subject: [PATCH 2/5] Update dvc/repo/experiments/init.py --- dvc/repo/experiments/init.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dvc/repo/experiments/init.py b/dvc/repo/experiments/init.py index 130d813f2a..77a1d825c3 100644 --- a/dvc/repo/experiments/init.py +++ b/dvc/repo/experiments/init.py @@ -100,7 +100,7 @@ def init( validated = STAGE_SCHEMA(data) except MultipleInvalid as exc: raise DvcException( - f"template '{template_path}'" + f"template '{template_path}' " "failed schema validation while rendering" ) from exc From 076aad42f6985c5c2f26209f072d3a6efb5acfa5 Mon Sep 17 00:00:00 2001 From: Saugat Pachhai Date: Thu, 16 Sep 2021 17:58:00 +0545 Subject: [PATCH 3/5] Apply suggestions from code review --- dvc/repo/experiments/init.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dvc/repo/experiments/init.py b/dvc/repo/experiments/init.py index 77a1d825c3..bc421c49b6 100644 --- a/dvc/repo/experiments/init.py +++ b/dvc/repo/experiments/init.py @@ -59,7 +59,7 @@ def init( ): from jinja2 import 
Environment - from dvc.dvcfile import make_dvcfile + from dvc.utils import relpath from dvc.stage import check_circular_dependency, check_duplicated_arguments from dvc.stage.loader import StageLoader from dvc.utils.serialize import LOADERS, parse_yaml_for_update @@ -94,7 +94,7 @@ def init( # render, parse yaml and then validate schema rendered = template.render(**context, param_names=param_names) - template_path = os.path.relpath(template.filename) + template_path = relpath(template.filename) data = parse_yaml_for_update(rendered, template_path) try: validated = STAGE_SCHEMA(data) From 7a861e4868c4a6b856067c0f2515f04c311e4b1e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 16 Sep 2021 12:14:10 +0000 Subject: [PATCH 4/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- dvc/repo/experiments/init.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dvc/repo/experiments/init.py b/dvc/repo/experiments/init.py index bc421c49b6..877d1ddc81 100644 --- a/dvc/repo/experiments/init.py +++ b/dvc/repo/experiments/init.py @@ -59,9 +59,9 @@ def init( ): from jinja2 import Environment - from dvc.utils import relpath from dvc.stage import check_circular_dependency, check_duplicated_arguments from dvc.stage.loader import StageLoader + from dvc.utils import relpath from dvc.utils.serialize import LOADERS, parse_yaml_for_update data = compact(data) # remove None values From f15c892cc2d7fbba56d81bb2b4739f6d4b63895c Mon Sep 17 00:00:00 2001 From: Saugat Pachhai Date: Thu, 16 Sep 2021 18:03:02 +0545 Subject: [PATCH 5/5] Update dvc/repo/experiments/init.py --- dvc/repo/experiments/init.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dvc/repo/experiments/init.py b/dvc/repo/experiments/init.py index 877d1ddc81..17abe25cb6 100644 --- a/dvc/repo/experiments/init.py +++ b/dvc/repo/experiments/init.py @@ -59,6 +59,7 @@ def init( ): from jinja2 import 
Environment + from dvc.dvcfile import make_dvcfile from dvc.stage import check_circular_dependency, check_duplicated_arguments from dvc.stage.loader import StageLoader from dvc.utils import relpath