diff --git a/dvc/commands/experiments/__init__.py b/dvc/commands/experiments/__init__.py index d7e77d81e3..1a1335f05b 100644 --- a/dvc/commands/experiments/__init__.py +++ b/dvc/commands/experiments/__init__.py @@ -8,7 +8,6 @@ diff, exec_run, gc, - init, ls, pull, push, @@ -26,7 +25,6 @@ diff, exec_run, gc, - init, ls, pull, push, diff --git a/dvc/commands/experiments/init.py b/dvc/commands/experiments/init.py deleted file mode 100644 index 1d2791328b..0000000000 --- a/dvc/commands/experiments/init.py +++ /dev/null @@ -1,219 +0,0 @@ -import argparse -import logging -from typing import TYPE_CHECKING, List - -from funcy import compact - -from dvc.cli.command import CmdBase -from dvc.cli.utils import append_doc_link -from dvc.exceptions import InvalidArgumentError -from dvc.ui import ui - -if TYPE_CHECKING: - from dvc.dependency import Dependency - from dvc.stage import PipelineStage - - -logger = logging.getLogger(__name__) - - -class CmdExperimentsInit(CmdBase): - DEFAULT_NAME = "train" - CODE = "src" - DATA = "data" - MODELS = "models" - DEFAULT_METRICS = "metrics.json" - DEFAULT_PARAMS = "params.yaml" - PLOTS = "plots" - DEFAULTS = { - "code": CODE, - "data": DATA, - "models": MODELS, - "metrics": DEFAULT_METRICS, - "params": DEFAULT_PARAMS, - "plots": PLOTS, - } - - def run(self): - from dvc.commands.stage import parse_cmd - - cmd = parse_cmd(self.args.command) - if not self.args.interactive and not cmd: - raise InvalidArgumentError("command is not specified") - - from dvc.repo.experiments.init import init - - defaults = {} - if not self.args.explicit: - config = self.repo.config["exp"] - defaults.update({**self.DEFAULTS, **config}) - - cli_args = compact( - { - "cmd": cmd, - "code": self.args.code, - "data": self.args.data, - "models": self.args.models, - "metrics": self.args.metrics, - "params": self.args.params, - "plots": self.args.plots, - "live": self.args.live, - } - ) - - initialized_stage, initialized_deps, initialized_out_dirs = init( - self.repo, - name=self.args.name, - type=self.args.type, - defaults=defaults, - overrides=cli_args, - interactive=self.args.interactive, - force=self.args.force, - ) - self._post_init_display( - initialized_stage, initialized_deps, initialized_out_dirs - ) - if self.args.run: - self.repo.experiments.run(targets=[initialized_stage.addressing]) - return 0 - - def _post_init_display( - self, - stage: "PipelineStage", - new_deps: List["Dependency"], - new_out_dirs: List[str], - ) -> None: - from dvc.utils import humanize - - path_fmt = "[green]{}[/green]".format - if new_deps: - deps_paths = humanize.join(map(path_fmt, new_deps)) - ui.write(f"Creating dependencies: {deps_paths}", styled=True) - - if new_out_dirs: - out_dirs_paths = humanize.join(map(path_fmt, new_out_dirs)) - ui.write(f"Creating output directories: {out_dirs_paths}", styled=True) - - ui.write( - f"Creating [b]{self.args.name}[/b] stage in [green]dvc.yaml[/]", - styled=True, - ) - if stage.outs or not self.args.run: - # separate the above status-like messages with help/tips section - ui.write(styled=True) - - if stage.outs: - outs_paths = humanize.join(map(path_fmt, stage.outs)) - tips = f"Ensure your experiment command creates {outs_paths}." - ui.write(tips, styled=True) - - if not self.args.run: - ui.write( - 'You can now run your experiment using [b]"dvc exp run"[/].', - styled=True, - ) - else: - # separate between `exp.run` output and `dvc exp init` output - ui.write(styled=True) - - -def add_parser(experiments_subparsers, parent_parser): - EXPERIMENTS_INIT_HELP = "Quickly setup any project to use experiments." - experiments_init_parser = experiments_subparsers.add_parser( - "init", - parents=[parent_parser], - description=append_doc_link(EXPERIMENTS_INIT_HELP, "exp/init"), - formatter_class=argparse.RawDescriptionHelpFormatter, - help=EXPERIMENTS_INIT_HELP, - ) - experiments_init_parser.add_argument( - "command", - nargs=argparse.REMAINDER, - help="Command to execute.", - metavar="command", - ) - experiments_init_parser.add_argument( - "--run", - action="store_true", - help="Run the experiment after initializing it", - ) - experiments_init_parser.add_argument( - "--interactive", - "-i", - action="store_true", - help="Prompt for values that are not provided", - ) - experiments_init_parser.add_argument( - "-f", - "--force", - action="store_true", - default=False, - help="Overwrite existing stage", - ) - experiments_init_parser.add_argument( - "--explicit", - action="store_true", - default=False, - help="Only use the path values explicitly provided", - ) - experiments_init_parser.add_argument( - "--name", - "-n", - help="Name of the stage to create (default: %(default)s)", - default=CmdExperimentsInit.DEFAULT_NAME, - ) - experiments_init_parser.add_argument( - "--code", - help=( - "Path to the source file or directory " - "which your experiments depend" - f" (default: {CmdExperimentsInit.CODE})" - ), - ) - experiments_init_parser.add_argument( - "--data", - help=( - "Path to the data file or directory " - "which your experiments depend" - f" (default: {CmdExperimentsInit.DATA})" - ), - ) - experiments_init_parser.add_argument( - "--models", - help=( - "Path to the model file or directory for your experiments" - f" (default: {CmdExperimentsInit.MODELS})" - ), - ) - experiments_init_parser.add_argument( - "--params", - help=( - "Path to the parameters file for your experiments" - f" (default: {CmdExperimentsInit.DEFAULT_PARAMS})" - ), - ) - experiments_init_parser.add_argument( - "--metrics", - help=( - "Path to the metrics file for your experiments" - f" (default: {CmdExperimentsInit.DEFAULT_METRICS})" - ), - ) - experiments_init_parser.add_argument( - "--plots", - help=( - "Path to the plots file or directory for your experiments" - f" (default: {CmdExperimentsInit.PLOTS})" - ), - ) - experiments_init_parser.add_argument( - "--live", - help="Path to log dvclive outputs for your experiments", - ) - experiments_init_parser.add_argument( - "--type", - choices=["default", "checkpoint"], - default="default", - help="Select type of stage to create (default: %(default)s)", - ) - experiments_init_parser.set_defaults(func=CmdExperimentsInit) diff --git a/dvc/repo/experiments/init.py b/dvc/repo/experiments/init.py deleted file mode 100644 index d3f0b8ee49..0000000000 --- a/dvc/repo/experiments/init.py +++ /dev/null @@ -1,284 +0,0 @@ -import logging -import os -from contextlib import contextmanager -from functools import partial -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Iterable, - List, - Optional, - TextIO, - Tuple, - Union, -) - -from funcy import compact, lremove, lsplit - -from dvc.exceptions import DvcException -from dvc.stage import PipelineStage - -if TYPE_CHECKING: - from dvc.dependency import Dependency - from dvc.dvcfile import ProjectFile, SingleStageFile - from dvc.repo import Repo - -from dvc.ui import ui - -PROMPTS = { - "cmd": "[b]Command[/b] to execute", - "code": "Path to a [b]code[/b] file/directory", - "data": "Path to a [b]data[/b] file/directory", - "models": "Path to a [b]model[/b] file/directory", - "params": "Path to a [b]parameters[/b] file", - "metrics": "Path to a [b]metrics[/b] file", - "plots": "Path to a [b]plots[/b] file/directory", -} - - -def _prompts( - keys: Iterable[str], - defaults: Optional[Dict[str, str]] = None, - validator: Optional[Callable[[str, str], Union[str, Tuple[str, str]]]] = None, - allow_omission: bool = True, - stream: Optional[TextIO] = None, -) -> Dict[str, Optional[str]]: - from dvc.ui.prompt import Prompt - - defaults = defaults or {} - return { - key: Prompt.prompt_( - PROMPTS[key], - console=ui.error_console, - default=defaults.get(key), - validator=partial(validator, key) if validator else None, - allow_omission=allow_omission, - stream=stream, - ) - for key in keys - } - - -@contextmanager -def _disable_logging(highest_level=logging.CRITICAL): - previous_level = logging.root.manager.disable - - logging.disable(highest_level) - - try: - yield - finally: - logging.disable(previous_level) - - -def init_interactive( - defaults: Dict[str, str], - provided: Dict[str, str], - validator: Optional[Callable[[str, str], Union[str, Tuple[str, str]]]] = None, - stream: Optional[TextIO] = None, -) -> Dict[str, str]: - command_prompts = lremove(provided.keys(), ["cmd"]) - dependencies_prompts = lremove(provided.keys(), ["code", "data", "params"]) - output_keys = ["models"] - if "live" not in provided: - output_keys.extend(["metrics", "plots"]) - outputs_prompts = lremove(provided.keys(), output_keys) - - ret: Dict[str, str] = {} - if "cmd" in provided: - ret["cmd"] = provided["cmd"] - - for heading, prompts, allow_omission in ( - ("", command_prompts, False), - ("Enter experiment dependencies.", dependencies_prompts, True), - ("Enter experiment outputs.", outputs_prompts, True), - ): - if prompts and heading: - ui.error_write(heading, styled=True) - response = _prompts( - prompts, - defaults=defaults, - allow_omission=allow_omission, - validator=validator, - stream=stream, - ) - ret.update(compact(response)) - if prompts: - ui.error_write(styled=True) - return ret - - -def _check_stage_exists( - dvcfile: Union["ProjectFile", "SingleStageFile"], - name: str, - force: bool = False, -) -> None: - if not force and dvcfile.exists() and name in dvcfile.stages: - from dvc.stage.exceptions import DuplicateStageName - - hint = "Use '--force' to overwrite." - raise DuplicateStageName(f"Stage '{name}' already exists in 'dvc.yaml'. {hint}") - - -def validate_prompts(repo: "Repo", key: str, value: str) -> Union[Any, Tuple[Any, str]]: - from dvc.ui.prompt import InvalidResponse - - msg_format = "[yellow]'{0}' does not exist, the {1} will be created.[/]" - if key == "params": - from dvc.dependency.param import ( - MissingParamsFile, - ParamsDependency, - ParamsIsADirectoryError, - ) - - assert isinstance(value, str) - try: - ParamsDependency(None, value, repo=repo).validate_filepath() - except MissingParamsFile: - return value, msg_format.format(value, "file") - except ParamsIsADirectoryError: - raise InvalidResponse( # noqa: B904 - f"[prompt.invalid]'{value}' is a directory. " - "Please retry with an existing parameters file." - ) - elif key in ("code", "data") and not os.path.exists(value): - typ = "file" if is_file(value) else "directory" - return value, msg_format.format(value, typ) - return value - - -def is_file(path: str) -> bool: - _, ext = os.path.splitext(path) - return bool(ext) - - -def init_deps(stage: PipelineStage) -> List["Dependency"]: - from funcy import rpartial - - from dvc.dependency import ParamsDependency - from dvc.fs import localfs - - new_deps = [dep for dep in stage.deps if not dep.exists] - params, deps = lsplit(rpartial(isinstance, ParamsDependency), new_deps) - - # always create a file for params, detect file/folder based on extension - # for other dependencies - dirs = [dep.fs_path for dep in deps if not is_file(dep.fs_path)] - files = [dep.fs_path for dep in deps + params if is_file(dep.fs_path)] - for path in dirs: - localfs.makedirs(path) - for path in files: - localfs.makedirs(localfs.path.parent(path), exist_ok=True) - localfs.touch(path) - - return new_deps - - -def init_out_dirs(stage: PipelineStage) -> List[str]: - from dvc.fs import localfs - - new_dirs = [] - - # create dirs for outputs - for out in stage.outs: - path = out.def_path - if is_file(path): - path = localfs.path.parent(path) - if path and not localfs.exists(path): - localfs.makedirs(path) - new_dirs.append(path) - - return new_dirs - - -def init( - repo: "Repo", - name: str = "train", - type: str = "default", # noqa: A002, pylint: disable=redefined-builtin - defaults: Optional[Dict[str, str]] = None, - overrides: Optional[Dict[str, str]] = None, - interactive: bool = False, - force: bool = False, - stream: Optional[TextIO] = None, -) -> Tuple[PipelineStage, List["Dependency"], List[str]]: - from dvc.dvcfile import PROJECT_FILE, load_file - - dvcfile = load_file(repo, PROJECT_FILE) - _check_stage_exists(dvcfile, name, force=force) - - defaults = defaults.copy() if defaults else {} - overrides = overrides.copy() if overrides else {} - - if interactive: - defaults = init_interactive( - validator=partial(validate_prompts, repo), - defaults=defaults, - provided=overrides, - stream=stream, - ) - elif "live" in overrides: - # suppress `metrics`/`plots` if live is selected. - defaults.pop("metrics", None) - defaults.pop("plots", None) - else: - defaults.pop("live", None) # suppress live otherwise - - context: Dict[str, str] = {**defaults, **overrides} - assert "cmd" in context - - params = context.get("params") - if params: - from dvc.dependency.param import ( - MissingParamsFile, - ParamsDependency, - ParamsIsADirectoryError, - ) - - try: - ParamsDependency(None, params, repo=repo).validate_filepath() - except ParamsIsADirectoryError as exc: - raise DvcException(f"{exc}.") # noqa: B904 # swallow cause for display - except MissingParamsFile: - pass - - if type == "checkpoint": - outs_key = "checkpoints" - metrics_key = "metrics_persist_no_cache" - plots_key = "plots_persist_no_cache" - else: - outs_key = "outs" - metrics_key = "metrics_no_cache" - plots_key = "plots_no_cache" - - models = [context.get("models")] - metrics = [context.get("metrics")] - plots = [context.get("plots")] - if live_path := context.pop("live", None): - metrics.append(os.path.join(live_path, "metrics.json")) - plots.append(os.path.join(live_path, "plots")) - - stage = repo.stage.create( - name=name, - cmd=context["cmd"], - deps=compact([context.get("code"), context.get("data")]), - params=[{params: None}] if params else None, - force=force, - **{ - outs_key: compact(models), - metrics_key: compact(metrics), - plots_key: compact(plots), - }, - ) - assert isinstance(stage, PipelineStage) - - with _disable_logging(), repo.scm_context(autostage=True, quiet=True): - stage.dump(update_lock=False) - initialized_out_dirs = init_out_dirs(stage) - stage.ignore_outs() - initialized_deps = init_deps(stage) - if params: - repo.scm_context.track_file(params) - - return stage, initialized_deps, initialized_out_dirs diff --git a/tests/func/experiments/test_init.py b/tests/func/experiments/test_init.py deleted file mode 100644 index 95a2d98e85..0000000000 --- a/tests/func/experiments/test_init.py +++ /dev/null @@ -1,500 +0,0 @@ -import io -import os - -import pytest - -from dvc.cli import main -from dvc.commands.experiments.init import CmdExperimentsInit -from dvc.repo.experiments.init import init -from dvc.stage.exceptions import DuplicateStageName - -# the tests may hang on prompts on failure -pytestmark = pytest.mark.timeout(3, func_only=True) - - -@pytest.mark.timeout(5, func_only=True) -def test_init_simple(tmp_dir, scm, dvc, capsys): - tmp_dir.gen( - { - CmdExperimentsInit.CODE: {"copy.py": ""}, - "data": "data", - "params.yaml": '{"foo": 1}', - "plots": {}, - } - ) - code_path = os.path.join(CmdExperimentsInit.CODE, "copy.py") - script = f"python {code_path}" - - capsys.readouterr() - assert main(["exp", "init", script]) == 0 - out, err = capsys.readouterr() - assert not err - assert ( - "Creating train stage in dvc.yaml\n\n" - "Ensure your experiment command creates metrics.json, plots and models" - '.\nYou can now run your experiment using "dvc exp run".' in out - ) - assert (tmp_dir / "dvc.yaml").parse() == { - "stages": { - "train": { - "cmd": script, - "deps": ["data", "src"], - "metrics": [{"metrics.json": {"cache": False}}], - "outs": ["models"], - "params": [{"params.yaml": None}], - "plots": [{"plots": {"cache": False}}], - } - } - } - assert (tmp_dir / "data").read_text() == "data" - assert (tmp_dir / "src").is_dir() - - -@pytest.mark.parametrize("interactive", [True, False]) -def test_when_stage_already_exists_with_same_name(tmp_dir, dvc, interactive): - (tmp_dir / "dvc.yaml").dump({"stages": {"train": {"cmd": "test"}}}) - with pytest.raises(DuplicateStageName) as exc: - init( - dvc, - interactive=interactive, - overrides={"cmd": "true"}, - defaults=CmdExperimentsInit.DEFAULTS, - ) - assert ( - str(exc.value) - == "Stage 'train' already exists in 'dvc.yaml'. Use '--force' to overwrite." - ) - - -def test_when_stage_force_if_already_exists(tmp_dir, dvc): - (tmp_dir / "params.yaml").dump({"foo": 1}) - (tmp_dir / "dvc.yaml").dump({"stages": {"train": {"cmd": "test"}}}) - init( - dvc, - force=True, - overrides={"cmd": "true"}, - defaults=CmdExperimentsInit.DEFAULTS, - ) - d = (tmp_dir / "dvc.yaml").parse() - assert d["stages"]["train"]["cmd"] == "true" - - -@pytest.mark.parametrize("interactive", [True, False]) -def test_creates_params_file_by_default(tmp_dir, dvc, interactive): - init( - dvc, - interactive=interactive, - defaults=CmdExperimentsInit.DEFAULTS, - overrides={"cmd": "cmd"}, - stream=io.StringIO(""), - ) - - assert (tmp_dir / "params.yaml").is_file() - assert (tmp_dir / "params.yaml").parse() == {} - - -def test_with_a_custom_name(tmp_dir, dvc): - init(dvc, name="custom", overrides={"cmd": "cmd"}) - assert (tmp_dir / "dvc.yaml").parse() == {"stages": {"custom": {"cmd": "cmd"}}} - - -def test_init_with_no_defaults_non_interactive(tmp_dir, scm, dvc): - init(dvc, defaults={}, overrides={"cmd": "python script.py"}) - - assert (tmp_dir / "dvc.yaml").parse() == { - "stages": {"train": {"cmd": "python script.py"}} - } - scm._reset() - assert not (tmp_dir / "dvc.lock").exists() - assert scm.is_tracked("dvc.yaml") - - -@pytest.mark.parametrize( - "extra_overrides, inp", - [ - ({"cmd": "cmd"}, io.StringIO()), - ({}, io.StringIO("cmd")), - ], -) -def test_init_interactive_when_no_path_prompts_need_to_be_asked( - tmp_dir, dvc, extra_overrides, inp -): - """When we pass everything that's required of, it should not prompt us.""" - (tmp_dir / "params.yaml").dump({"foo": 1}) - init( - dvc, - interactive=True, - defaults=CmdExperimentsInit.DEFAULTS, - overrides={**CmdExperimentsInit.DEFAULTS, **extra_overrides}, - stream=inp, - ) - assert (tmp_dir / "dvc.yaml").parse() == { - "stages": { - "train": { - "cmd": "cmd", - "deps": ["data", "src"], - "metrics": [ - {"metrics.json": {"cache": False}}, - ], - "outs": ["models"], - "params": [{"params.yaml": None}], - "plots": [ - {"plots": {"cache": False}}, - ], - } - } - } - assert (tmp_dir / "src").is_dir() - assert (tmp_dir / "data").is_dir() - - -def test_when_params_is_omitted_in_interactive_mode(tmp_dir, scm, dvc): - (tmp_dir / "params.yaml").dump({"foo": 1}) - inp = io.StringIO("python script.py\nscript.py\ndata\nn") - - init(dvc, interactive=True, stream=inp, defaults=CmdExperimentsInit.DEFAULTS) - - assert (tmp_dir / "dvc.yaml").parse() == { - "stages": { - "train": { - "cmd": "python script.py", - "deps": ["data", "script.py"], - "metrics": [{"metrics.json": {"cache": False}}], - "outs": ["models"], - "plots": [{"plots": {"cache": False}}], - } - } - } - assert not (tmp_dir / "dvc.lock").exists() - assert not (tmp_dir / "script.py").read_text() - assert (tmp_dir / "data").is_dir() - scm._reset() - assert scm.is_tracked("dvc.yaml") - assert not scm.is_tracked("params.yaml") - assert scm.is_tracked(".gitignore") - assert scm.is_ignored("models") - - -def test_init_interactive_params_validation(tmp_dir, dvc, capsys): - tmp_dir.gen({"data": {"foo": "foo"}}) - (tmp_dir / "params.yaml").dump({"foo": 1}) - inp = io.StringIO("python script.py\nscript.py\ndata\ndata") - - init(dvc, stream=inp, interactive=True, defaults=CmdExperimentsInit.DEFAULTS) - - assert (tmp_dir / "dvc.yaml").parse() == { - "stages": { - "train": { - "cmd": "python script.py", - "deps": ["data", "script.py"], - "metrics": [{"metrics.json": {"cache": False}}], - "outs": ["models"], - "params": [{"params.yaml": None}], - "plots": [{"plots": {"cache": False}}], - } - } - } - assert not (tmp_dir / "script.py").read_text() - assert (tmp_dir / "data").is_dir() - - out, err = capsys.readouterr() - assert not out - assert ( - "Path to a parameters file [params.yaml, n to omit]: " - "'data' is a directory. " - "Please retry with an existing parameters file.\n" - "Path to a parameters file [params.yaml, n to omit]:" in err - ) - - -def test_init_with_no_defaults_interactive(tmp_dir, dvc): - inp = io.StringIO("script.py\ndata\nn\nmodel\nmetric\nn\n") - init( - dvc, - defaults={}, - overrides={"cmd": "python script.py"}, - interactive=True, - stream=inp, - ) - assert (tmp_dir / "dvc.yaml").parse() == { - "stages": { - "train": { - "cmd": "python script.py", - "deps": ["data", "script.py"], - "metrics": [{"metric": {"cache": False}}], - "outs": ["model"], - } - } - } - assert not (tmp_dir / "script.py").read_text() - assert (tmp_dir / "data").is_dir() - - -@pytest.mark.parametrize( - "interactive, overrides, inp", - [ - (False, {"cmd": "python script.py", "code": "script.py"}, None), - ( - True, - {}, - io.StringIO( - "python script.py\n" - "script.py\n" - "data\n" - "params.yaml\n" - "models\n" - "metrics.json\n" - "plots\n" - "y" - ), - ), - ], - ids=["non-interactive", "interactive"], -) -def test_init_default(tmp_dir, scm, dvc, interactive, overrides, inp, capsys): - (tmp_dir / "params.yaml").dump({"foo": {"bar": 1}}) - - init( - dvc, - interactive=interactive, - defaults=CmdExperimentsInit.DEFAULTS, - overrides=overrides, - stream=inp, - ) - - assert (tmp_dir / "dvc.yaml").parse() == { - "stages": { - "train": { - "cmd": "python script.py", - "deps": ["data", "script.py"], - "metrics": [{"metrics.json": {"cache": False}}], - "outs": ["models"], - "params": [{"params.yaml": None}], - "plots": [{"plots": {"cache": False}}], - } - } - } - assert not (tmp_dir / "dvc.lock").exists() - assert not (tmp_dir / "script.py").read_text() - assert (tmp_dir / "data").is_dir() - scm._reset() - assert scm.is_tracked("dvc.yaml") - assert scm.is_tracked("params.yaml") - assert scm.is_tracked(".gitignore") - assert scm.is_ignored("models") - out, err = capsys.readouterr() - - assert not out - if interactive: - assert "'script.py' does not exist, the file will be created." in err - assert "'data' does not exist, the directory will be created." in err - - -@pytest.mark.timeout(5, func_only=True) -@pytest.mark.parametrize( - "interactive, overrides, inp", - [ - (False, {"cmd": "python script.py", "code": "script.py"}, None), - ( - True, - {}, - io.StringIO("python script.py\nscript.py\ndata\nparams.yaml\nmodels\ny"), - ), - ( - True, - {"cmd": "python script.py"}, - io.StringIO("script.py\ndata\nparams.yaml\nmodels\ny"), - ), - ( - True, - {"cmd": "python script.py", "models": "models"}, - io.StringIO("script.py\ndata\nparams.yaml\ny"), - ), - ], - ids=[ - "non-interactive", - "interactive", - "interactive-cmd-provided", - "interactive-cmd-models-provided", - ], -) -def test_init_interactive_live(tmp_dir, scm, dvc, interactive, overrides, inp, capsys): - overrides["live"] = "dvclive" - - (tmp_dir / "params.yaml").dump({"foo": {"bar": 1}}) - - init( - dvc, - interactive=interactive, - defaults=CmdExperimentsInit.DEFAULTS, - overrides=overrides, - stream=inp, - ) - assert (tmp_dir / "dvc.yaml").parse() == { - "stages": { - "train": { - "cmd": "python script.py", - "deps": ["data", "script.py"], - "metrics": [ - {os.path.join("dvclive", "metrics.json"): {"cache": False}} - ], - "outs": ["models"], - "params": [{"params.yaml": None}], - "plots": [{os.path.join("dvclive", "plots"): {"cache": False}}], - } - } - } - assert not (tmp_dir / "dvc.lock").exists() - assert not (tmp_dir / "script.py").read_text() - assert (tmp_dir / "data").is_dir() - scm._reset() - assert scm.is_tracked("dvc.yaml") - assert scm.is_tracked("params.yaml") - assert scm.is_tracked(".gitignore") - assert scm.is_ignored("models") - - out, err = capsys.readouterr() - - assert not out - if interactive: - assert "'script.py' does not exist, the file will be created." in err - assert "'data' does not exist, the directory will be created." in err - - -@pytest.mark.parametrize( - "interactive, inp", - [ - (False, None), - (True, io.StringIO()), - ], -) -def test_init_with_type_checkpoint_and_models_plots_provided( - tmp_dir, dvc, interactive, inp -): - (tmp_dir / "params.yaml").dump({"foo": 1}) - init( - dvc, - type="checkpoint", - interactive=interactive, - stream=inp, - defaults=CmdExperimentsInit.DEFAULTS, - overrides={"cmd": "cmd", "metrics": "m", "plots": "p"}, - ) - assert (tmp_dir / "dvc.yaml").parse() == { - "stages": { - "train": { - "cmd": "cmd", - "deps": ["data", "src"], - "metrics": [ - {"m": {"cache": False, "persist": True}}, - ], - "outs": [{"models": {"checkpoint": True}}], - "params": [{"params.yaml": None}], - "plots": [ - {"p": {"cache": False, "persist": True}}, - ], - } - } - } - assert (tmp_dir / "src").is_dir() - assert (tmp_dir / "data").is_dir() - - -@pytest.mark.parametrize( - "interactive, inp", - [ - (False, None), - (True, io.StringIO()), - ], -) -def test_init_with_type_default_and_live_provided(tmp_dir, dvc, interactive, inp): - (tmp_dir / "params.yaml").dump({"foo": 1}) - init( - dvc, - interactive=interactive, - stream=inp, - defaults=CmdExperimentsInit.DEFAULTS, - overrides={"cmd": "cmd", "live": "live"}, - ) - assert (tmp_dir / "dvc.yaml").parse() == { - "stages": { - "train": { - "cmd": "cmd", - "deps": ["data", "src"], - "metrics": [ - {os.path.join("live", "metrics.json"): {"cache": False}}, - ], - "outs": ["models"], - "params": [{"params.yaml": None}], - "plots": [ - {os.path.join("live", "plots"): {"cache": False}}, - ], - } - } - } - assert (tmp_dir / "src").is_dir() - assert (tmp_dir / "data").is_dir() - - -@pytest.mark.parametrize( - "interactive, inp", - [ - (False, None), - (True, io.StringIO()), - ], -) -def test_init_with_live_and_metrics_plots_provided(tmp_dir, dvc, interactive, inp): - (tmp_dir / "params.yaml").dump({"foo": 1}) - init( - dvc, - interactive=interactive, - stream=inp, - defaults=CmdExperimentsInit.DEFAULTS, - overrides={ - "cmd": "cmd", - "live": "live", - "metrics": "metrics.json", - "plots": "plots", - }, - ) - assert (tmp_dir / "dvc.yaml").parse() == { - "stages": { - "train": { - "cmd": "cmd", - "deps": ["data", "src"], - "metrics": [ - {os.path.join("live", "metrics.json"): {"cache": False}}, - {"metrics.json": {"cache": False}}, - ], - "outs": ["models"], - "params": [{"params.yaml": None}], - "plots": [ - {os.path.join("live", "plots"): {"cache": False}}, - {"plots": {"cache": False}}, - ], - } - } - } - assert (tmp_dir / "src").is_dir() - assert (tmp_dir / "data").is_dir() - - -def test_gen_output_dirs(tmp_dir, dvc): - init( - dvc, - defaults=CmdExperimentsInit.DEFAULTS, - overrides={ - "cmd": "cmd", - "models": "models/predict.h5", - "metrics": "eval/scores.json", - "plots": "eval/plots", - "live": "eval/live", - }, - ) - - assert (tmp_dir / "models").is_dir() - assert (tmp_dir / "eval").is_dir() - assert (tmp_dir / "eval/plots").is_dir() - assert (tmp_dir / "eval/live").is_dir() - assert not (tmp_dir / "models/predict.h5").exists() - assert not (tmp_dir / "eval/scores.json").exists() diff --git a/tests/unit/command/test_experiments.py b/tests/unit/command/test_experiments.py index 9ca3fd54f7..3d4cb33983 100644 --- a/tests/unit/command/test_experiments.py +++ b/tests/unit/command/test_experiments.py @@ -1,17 +1,15 @@ import csv -import pathlib import textwrap from datetime import datetime import pytest -from dvc.cli import DvcParserError, parse_args +from dvc.cli import parse_args from dvc.commands.experiments.apply import CmdExperimentsApply from dvc.commands.experiments.branch import CmdExperimentsBranch from dvc.commands.experiments.clean import CmdExperimentsClean from dvc.commands.experiments.diff import CmdExperimentsDiff from dvc.commands.experiments.gc import CmdExperimentsGC -from dvc.commands.experiments.init import CmdExperimentsInit from dvc.commands.experiments.ls import CmdExperimentsList from dvc.commands.experiments.pull import CmdExperimentsPull from dvc.commands.experiments.push import CmdExperimentsPush @@ -20,9 +18,6 @@ from dvc.commands.experiments.save import CmdExperimentsSave from dvc.commands.experiments.show import CmdExperimentsShow, show_experiments from dvc.exceptions import InvalidArgumentError -from dvc.repo import Repo -from tests.utils import ANY -from tests.utils.asserts import called_once_with_subset from .test_repro import common_arguments as repro_arguments @@ -694,204 +689,6 @@ def test_show_experiments_sort_by(capsys, sort_order): assert params == (2, 1, 0) -@pytest.mark.parametrize("extra_args", [(), ("--run",)]) -def test_experiments_init(dvc, scm, mocker, capsys, extra_args): - stage = mocker.Mock(outs=[], addressing="train") - m = mocker.patch("dvc.repo.experiments.init.init", return_value=(stage, [], [])) - runner = mocker.patch("dvc.repo.experiments.run.run", return_value=0) - cli_args = parse_args(["exp", "init", *extra_args, "cmd"]) - cmd = cli_args.func(cli_args) - - assert isinstance(cmd, CmdExperimentsInit) - assert cmd.run() == 0 - m.assert_called_once_with( - ANY(Repo), - name="train", - type="default", - defaults={ - "code": "src", - "models": "models", - "data": "data", - "metrics": "metrics.json", - "params": "params.yaml", - "plots": "plots", - }, - overrides={"cmd": "cmd"}, - interactive=False, - force=False, - ) - - if extra_args: - # `parse_args` creates a new `Repo` object - runner.assert_called_once_with(ANY(Repo), targets=["train"]) - - -def test_experiments_init_config(dvc, scm, mocker): - with dvc.config.edit() as conf: - conf["exp"] = {"code": "new_src", "models": "new_models"} - - stage = mocker.Mock(outs=[]) - m = mocker.patch("dvc.repo.experiments.init.init", return_value=(stage, [], [])) - cli_args = parse_args(["exp", "init", "cmd"]) - cmd = cli_args.func(cli_args) - - assert isinstance(cmd, CmdExperimentsInit) - assert cmd.run() == 0 - - m.assert_called_once_with( - ANY(Repo), - name="train", - type="default", - defaults={ - "code": "new_src", - "models": "new_models", - "data": "data", - "metrics": "metrics.json", - "params": "params.yaml", - "plots": "plots", - }, - overrides={"cmd": "cmd"}, - interactive=False, - force=False, - ) - - -def test_experiments_init_explicit(dvc, mocker): - stage = mocker.Mock(outs=[]) - m = mocker.patch("dvc.repo.experiments.init.init", return_value=(stage, [], [])) - cli_args = parse_args(["exp", "init", "--explicit", "cmd"]) - cmd = cli_args.func(cli_args) - - assert cmd.run() == 0 - m.assert_called_once_with( - ANY(Repo), - name="train", - type="default", - defaults={}, - overrides={"cmd": "cmd"}, - interactive=False, - force=False, - ) - - -def test_experiments_init_cmd_required_for_non_interactive_mode(dvc): - cli_args = parse_args(["exp", "init"]) - cmd = cli_args.func(cli_args) - assert isinstance(cmd, CmdExperimentsInit) - - with pytest.raises(InvalidArgumentError) as exc: - cmd.run() - assert str(exc.value) == "command is not specified" - - -def test_experiments_init_cmd_not_required_for_interactive_mode(dvc, mocker): - cli_args = parse_args(["exp", "init", "--interactive"]) - cmd = cli_args.func(cli_args) - assert isinstance(cmd, CmdExperimentsInit) - - stage = mocker.Mock(outs=[]) - m = mocker.patch("dvc.repo.experiments.init.init", return_value=(stage, [], [])) - assert cmd.run() == 0 - assert called_once_with_subset(m, ANY(Repo), interactive=True) - - -@pytest.mark.parametrize( - "extra_args, expected_kw", - [ - (["--type", "default"], {"type": "default", "name": "train"}), - (["--type", "checkpoint"], {"type": "checkpoint", "name": "train"}), - (["--force"], {"force": True, "name": "train"}), - ( - ["--name", "name", "--type", "checkpoint"], - {"name": "name", "type": "checkpoint"}, - ), - ( - [ - "--plots", - "p", - "--models", - "m", - "--code", - "c", - "--metrics", - "m.json", - "--params", - "p.yaml", - "--data", - "d", - "--live", - "live", - ], - { - "name": "train", - "overrides": { - "plots": "p", - "models": "m", - "code": "c", - "metrics": "m.json", - "params": "p.yaml", - "data": "d", - "live": "live", - "cmd": "cmd", - }, - }, - ), - ], -) -def test_experiments_init_extra_args(extra_args, expected_kw, mocker): - cli_args = parse_args(["exp", "init", *extra_args, "cmd"]) - cmd = cli_args.func(cli_args) - assert isinstance(cmd, CmdExperimentsInit) - - stage = mocker.Mock(outs=[]) - m = mocker.patch("dvc.repo.experiments.init.init", return_value=(stage, [], [])) - assert cmd.run() == 0 - assert called_once_with_subset(m, ANY(Repo), **expected_kw) - - -def test_experiments_init_type_invalid_choice(): - with pytest.raises(DvcParserError): - parse_args(["exp", "init", "--type=invalid", "cmd"]) - - -@pytest.mark.parametrize("args", [[], ["--run"]]) -def test_experiments_init_displays_output_on_no_run(dvc, mocker, capsys, args): - model_dir = pathlib.Path("models") - model_path = str(model_dir / "predict.h5") - stage = dvc.stage.create( - name="train", - cmd=["cmd"], - deps=["code", "data"], - params=["params.yaml"], - outs=["metrics.json", "plots", model_path], - ) - mocker.patch( - "dvc.repo.experiments.init.init", - return_value=(stage, stage.deps, [model_dir]), - ) - mocker.patch("dvc.repo.experiments.run.run", return_value=0) - cli_args = parse_args(["exp", "init", "cmd", *args]) - cmd = cli_args.func(cli_args) - assert cmd.run() == 0 - - expected_lines = [ - "Creating dependencies: code, data and params.yaml", - "Creating output directories: models", - "Creating train stage in dvc.yaml", - "", - "Ensure your experiment command creates metrics.json, plots and ", - f"{model_path}.", - ] - if not cli_args.run: - expected_lines += [ - 'You can now run your experiment using "dvc exp run".', - ] - - out, err = capsys.readouterr() - assert not err - assert out.splitlines() == expected_lines - - def test_show_experiments_pcp(tmp_dir, mocker): all_experiments = { "workspace": {