From bb05b0262de1820482dca2d1d71544532dd4aed8 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Thu, 5 May 2022 12:42:43 -0400 Subject: [PATCH] exp init: fixes #7534; simplifies/updates exp init --live --- dvc/commands/experiments/init.py | 7 +-- dvc/repo/experiments/init.py | 25 ++++------ tests/func/experiments/test_init.py | 69 +++++++++++++++++++------- tests/unit/command/test_experiments.py | 8 ++- 4 files changed, 66 insertions(+), 43 deletions(-) diff --git a/dvc/commands/experiments/init.py b/dvc/commands/experiments/init.py index aca7626593..8632d55eb3 100644 --- a/dvc/commands/experiments/init.py +++ b/dvc/commands/experiments/init.py @@ -25,7 +25,6 @@ class CmdExperimentsInit(CmdBase): DEFAULT_METRICS = "metrics.json" DEFAULT_PARAMS = "params.yaml" PLOTS = "plots" - DVCLIVE = "dvclive" DEFAULTS = { "code": CODE, "data": DATA, @@ -33,7 +32,6 @@ class CmdExperimentsInit(CmdBase): "metrics": DEFAULT_METRICS, "params": DEFAULT_PARAMS, "plots": PLOTS, - "live": DVCLIVE, } def run(self): @@ -190,12 +188,11 @@ def add_parser(experiments_subparsers, parent_parser): ) experiments_init_parser.add_argument( "--live", - help="Path to log dvclive outputs for your experiments" - f" (default: {CmdExperimentsInit.DVCLIVE})", + help="Path to log dvclive outputs for your experiments", ) experiments_init_parser.add_argument( "--type", - choices=["default", "dl"], + choices=["default", "checkpoint"], default="default", help="Select type of stage to create (default: %(default)s)", ) diff --git a/dvc/repo/experiments/init.py b/dvc/repo/experiments/init.py index fe2227fade..a1961f4d8c 100644 --- a/dvc/repo/experiments/init.py +++ b/dvc/repo/experiments/init.py @@ -36,7 +36,6 @@ "params": "Path to a [b]parameters[/b] file", "metrics": "Path to a [b]metrics[/b] file", "plots": "Path to a [b]plots[/b] file/directory", - "live": "Path to log [b]dvclive[/b] outputs", } @@ -79,15 +78,14 @@ def init_interactive( defaults: Dict[str, str], provided: Dict[str, str], validator: Callable[[str, str], Union[str, Tuple[str, str]]] = None, - live: bool = False, stream: Optional[TextIO] = None, ) -> Dict[str, str]: command_prompts = lremove(provided.keys(), ["cmd"]) dependencies_prompts = lremove(provided.keys(), ["code", "data", "params"]) - outputs_prompts = lremove( - provided.keys(), - ["models"] + (["live"] if live else ["metrics", "plots"]), - ) + output_keys = ["models"] + if "live" not in provided: + output_keys.extend(["metrics", "plots"]) + outputs_prompts = lremove(provided.keys(), output_keys) ret: Dict[str, str] = {} if "cmd" in provided: @@ -200,21 +198,16 @@ def init( defaults = defaults.copy() if defaults else {} overrides = overrides.copy() if overrides else {} - with_live = type == "dl" - if interactive: defaults = init_interactive( validator=partial(validate_prompts, repo), defaults=defaults, - live=with_live, provided=overrides, stream=stream, ) else: - if with_live: - # suppress `metrics`/`plots` if live is selected, unless - # it is also provided via overrides/cli. - # This makes output to be a checkpoint as well. + if "live" in overrides: + # suppress `metrics`/`plots` if live is selected. defaults.pop("metrics", None) defaults.pop("plots", None) else: @@ -251,7 +244,11 @@ def init( metrics_no_cache=compact([context.get("metrics"), live_metrics]), plots_no_cache=compact([context.get("plots"), live_plots]), force=force, - **{"checkpoints" if with_live else "outs": compact([models])}, + **{ + "checkpoints" + if type == "checkpoint" + else "outs": compact([models]) + }, ) with _disable_logging(), repo.scm_context(autostage=True, quiet=True): diff --git a/tests/func/experiments/test_init.py b/tests/func/experiments/test_init.py index 4281518284..87bcedfba6 100644 --- a/tests/func/experiments/test_init.py +++ b/tests/func/experiments/test_init.py @@ -19,7 +19,6 @@ def test_init_simple(tmp_dir, scm, dvc, capsys): CmdExperimentsInit.CODE: {"copy.py": ""}, "data": "data", "params.yaml": '{"foo": 1}', - "dvclive": {}, "plots": {}, } ) @@ -137,13 +136,11 @@ def test_init_interactive_when_no_path_prompts_need_to_be_asked( "cmd": "cmd", "deps": ["data", "src"], "metrics": [ - {"dvclive.json": {"cache": False}}, {"metrics.json": {"cache": False}}, ], "outs": ["models"], "params": [{"params.yaml": None}], "plots": [ - {os.path.join("dvclive", "scalars"): {"cache": False}}, {"plots": {"cache": False}}, ], } @@ -313,26 +310,18 @@ def test_init_default(tmp_dir, scm, dvc, interactive, overrides, inp, capsys): "data\n" "params.yaml\n" "models\n" - "dvclive\n" "y" ), ), ( True, {"cmd": "python script.py"}, - io.StringIO( - "script.py\n" - "data\n" - "params.yaml\n" - "models\n" - "dvclive\n" - "y" - ), + io.StringIO("script.py\n" "data\n" "params.yaml\n" "models\n" "y"), ), ( True, {"cmd": "python script.py", "models": "models"}, - io.StringIO("script.py\ndata\nparams.yaml\ndvclive\ny"), + io.StringIO("script.py\ndata\nparams.yaml\ny"), ), ], ids=[ @@ -345,11 +334,12 @@ def test_init_default(tmp_dir, scm, dvc, interactive, overrides, inp, capsys): def test_init_interactive_live( tmp_dir, scm, dvc, interactive, overrides, inp, capsys ): + overrides["live"] = "dvclive" + (tmp_dir / "params.yaml").dump({"foo": {"bar": 1}}) init( dvc, - type="dl", interactive=interactive, defaults=CmdExperimentsInit.DEFAULTS, overrides=overrides, @@ -361,7 +351,7 @@ def test_init_interactive_live( "cmd": "python script.py", "deps": ["data", "script.py"], "metrics": [{"dvclive.json": {"cache": False}}], - "outs": [{"models": {"checkpoint": True}}], + "outs": ["models"], "params": [{"params.yaml": None}], "plots": [ {os.path.join("dvclive", "scalars"): {"cache": False}} @@ -393,13 +383,13 @@ def test_init_interactive_live( (True, io.StringIO()), ], ) -def test_init_with_type_live_and_models_plots_provided( +def test_init_with_type_checkpoint_and_models_plots_provided( tmp_dir, dvc, interactive, inp ): (tmp_dir / "params.yaml").dump({"foo": 1}) init( dvc, - type="dl", + type="checkpoint", interactive=interactive, stream=inp, defaults=CmdExperimentsInit.DEFAULTS, @@ -411,13 +401,11 @@ def test_init_with_type_live_and_models_plots_provided( "cmd": "cmd", "deps": ["data", "src"], "metrics": [ - {"dvclive.json": {"cache": False}}, {"m": {"cache": False}}, ], "outs": [{"models": {"checkpoint": True}}], "params": [{"params.yaml": None}], "plots": [ - {os.path.join("dvclive", "scalars"): {"cache": False}}, {"p": {"cache": False}}, ], } @@ -445,6 +433,49 @@ def test_init_with_type_default_and_live_provided( defaults=CmdExperimentsInit.DEFAULTS, overrides={"cmd": "cmd", "live": "live"}, ) + assert (tmp_dir / "dvc.yaml").parse() == { + "stages": { + "train": { + "cmd": "cmd", + "deps": ["data", "src"], + "metrics": [ + {"live.json": {"cache": False}}, + ], + "outs": ["models"], + "params": [{"params.yaml": None}], + "plots": [ + {os.path.join("live", "scalars"): {"cache": False}}, + ], + } + } + } + assert (tmp_dir / "src").is_dir() + assert (tmp_dir / "data").is_dir() + + +@pytest.mark.parametrize( + "interactive, inp", + [ + (False, None), + (True, io.StringIO()), + ], +) +def test_init_with_live_and_metrics_plots_provided( + tmp_dir, dvc, interactive, inp +): + (tmp_dir / "params.yaml").dump({"foo": 1}) + init( + dvc, + interactive=interactive, + stream=inp, + defaults=CmdExperimentsInit.DEFAULTS, + overrides={ + "cmd": "cmd", + "live": "live", + "metrics": "metrics.json", + "plots": "plots", + }, + ) assert (tmp_dir / "dvc.yaml").parse() == { "stages": { "train": { diff --git a/tests/unit/command/test_experiments.py b/tests/unit/command/test_experiments.py index 6bc5828450..5bd8f1d793 100644 --- a/tests/unit/command/test_experiments.py +++ b/tests/unit/command/test_experiments.py @@ -690,7 +690,6 @@ def test_experiments_init(dvc, scm, mocker, capsys, extra_args): "metrics": "metrics.json", "params": "params.yaml", "plots": "plots", - "live": "dvclive", }, overrides={"cmd": "cmd"}, interactive=False, @@ -727,7 +726,6 @@ def test_experiments_init_config(dvc, scm, mocker): "metrics": "metrics.json", "params": "params.yaml", "plots": "plots", - "live": "dvclive", }, overrides={"cmd": "cmd"}, interactive=False, @@ -782,11 +780,11 @@ def test_experiments_init_cmd_not_required_for_interactive_mode(dvc, mocker): "extra_args, expected_kw", [ (["--type", "default"], {"type": "default", "name": "train"}), - (["--type", "dl"], {"type": "dl", "name": "train"}), + (["--type", "checkpoint"], {"type": "checkpoint", "name": "train"}), (["--force"], {"force": True, "name": "train"}), ( - ["--name", "name", "--type", "dl"], - {"name": "name", "type": "dl"}, + ["--name", "name", "--type", "checkpoint"], + {"name": "name", "type": "checkpoint"}, ), ( [