Skip to content

Commit

Permalink
exp init: fixes #7534; simplifies/updates exp init --live
Browse files Browse the repository at this point in the history
  • Loading branch information
dberenbaum committed May 9, 2022
1 parent b14a972 commit bb05b02
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 43 deletions.
7 changes: 2 additions & 5 deletions dvc/commands/experiments/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,13 @@ class CmdExperimentsInit(CmdBase):
DEFAULT_METRICS = "metrics.json"
DEFAULT_PARAMS = "params.yaml"
PLOTS = "plots"
DVCLIVE = "dvclive"
DEFAULTS = {
"code": CODE,
"data": DATA,
"models": MODELS,
"metrics": DEFAULT_METRICS,
"params": DEFAULT_PARAMS,
"plots": PLOTS,
"live": DVCLIVE,
}

def run(self):
Expand Down Expand Up @@ -190,12 +188,11 @@ def add_parser(experiments_subparsers, parent_parser):
)
experiments_init_parser.add_argument(
"--live",
help="Path to log dvclive outputs for your experiments"
f" (default: {CmdExperimentsInit.DVCLIVE})",
help="Path to log dvclive outputs for your experiments",
)
experiments_init_parser.add_argument(
"--type",
choices=["default", "dl"],
choices=["default", "checkpoint"],
default="default",
help="Select type of stage to create (default: %(default)s)",
)
Expand Down
25 changes: 11 additions & 14 deletions dvc/repo/experiments/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
"params": "Path to a [b]parameters[/b] file",
"metrics": "Path to a [b]metrics[/b] file",
"plots": "Path to a [b]plots[/b] file/directory",
"live": "Path to log [b]dvclive[/b] outputs",
}


Expand Down Expand Up @@ -79,15 +78,14 @@ def init_interactive(
defaults: Dict[str, str],
provided: Dict[str, str],
validator: Callable[[str, str], Union[str, Tuple[str, str]]] = None,
live: bool = False,
stream: Optional[TextIO] = None,
) -> Dict[str, str]:
command_prompts = lremove(provided.keys(), ["cmd"])
dependencies_prompts = lremove(provided.keys(), ["code", "data", "params"])
outputs_prompts = lremove(
provided.keys(),
["models"] + (["live"] if live else ["metrics", "plots"]),
)
output_keys = ["models"]
if "live" not in provided:
output_keys.extend(["metrics", "plots"])
outputs_prompts = lremove(provided.keys(), output_keys)

ret: Dict[str, str] = {}
if "cmd" in provided:
Expand Down Expand Up @@ -200,21 +198,16 @@ def init(
defaults = defaults.copy() if defaults else {}
overrides = overrides.copy() if overrides else {}

with_live = type == "dl"

if interactive:
defaults = init_interactive(
validator=partial(validate_prompts, repo),
defaults=defaults,
live=with_live,
provided=overrides,
stream=stream,
)
else:
if with_live:
# suppress `metrics`/`plots` if live is selected, unless
# it is also provided via overrides/cli.
# This makes output to be a checkpoint as well.
if "live" in overrides:
# suppress `metrics`/`plots` if live is selected.
defaults.pop("metrics", None)
defaults.pop("plots", None)
else:
Expand Down Expand Up @@ -251,7 +244,11 @@ def init(
metrics_no_cache=compact([context.get("metrics"), live_metrics]),
plots_no_cache=compact([context.get("plots"), live_plots]),
force=force,
**{"checkpoints" if with_live else "outs": compact([models])},
**{
"checkpoints"
if type == "checkpoint"
else "outs": compact([models])
},
)

with _disable_logging(), repo.scm_context(autostage=True, quiet=True):
Expand Down
69 changes: 50 additions & 19 deletions tests/func/experiments/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ def test_init_simple(tmp_dir, scm, dvc, capsys):
CmdExperimentsInit.CODE: {"copy.py": ""},
"data": "data",
"params.yaml": '{"foo": 1}',
"dvclive": {},
"plots": {},
}
)
Expand Down Expand Up @@ -137,13 +136,11 @@ def test_init_interactive_when_no_path_prompts_need_to_be_asked(
"cmd": "cmd",
"deps": ["data", "src"],
"metrics": [
{"dvclive.json": {"cache": False}},
{"metrics.json": {"cache": False}},
],
"outs": ["models"],
"params": [{"params.yaml": None}],
"plots": [
{os.path.join("dvclive", "scalars"): {"cache": False}},
{"plots": {"cache": False}},
],
}
Expand Down Expand Up @@ -313,26 +310,18 @@ def test_init_default(tmp_dir, scm, dvc, interactive, overrides, inp, capsys):
"data\n"
"params.yaml\n"
"models\n"
"dvclive\n"
"y"
),
),
(
True,
{"cmd": "python script.py"},
io.StringIO(
"script.py\n"
"data\n"
"params.yaml\n"
"models\n"
"dvclive\n"
"y"
),
io.StringIO("script.py\n" "data\n" "params.yaml\n" "models\n" "y"),
),
(
True,
{"cmd": "python script.py", "models": "models"},
io.StringIO("script.py\ndata\nparams.yaml\ndvclive\ny"),
io.StringIO("script.py\ndata\nparams.yaml\ny"),
),
],
ids=[
Expand All @@ -345,11 +334,12 @@ def test_init_default(tmp_dir, scm, dvc, interactive, overrides, inp, capsys):
def test_init_interactive_live(
tmp_dir, scm, dvc, interactive, overrides, inp, capsys
):
overrides["live"] = "dvclive"

(tmp_dir / "params.yaml").dump({"foo": {"bar": 1}})

init(
dvc,
type="dl",
interactive=interactive,
defaults=CmdExperimentsInit.DEFAULTS,
overrides=overrides,
Expand All @@ -361,7 +351,7 @@ def test_init_interactive_live(
"cmd": "python script.py",
"deps": ["data", "script.py"],
"metrics": [{"dvclive.json": {"cache": False}}],
"outs": [{"models": {"checkpoint": True}}],
"outs": ["models"],
"params": [{"params.yaml": None}],
"plots": [
{os.path.join("dvclive", "scalars"): {"cache": False}}
Expand Down Expand Up @@ -393,13 +383,13 @@ def test_init_interactive_live(
(True, io.StringIO()),
],
)
def test_init_with_type_live_and_models_plots_provided(
def test_init_with_type_checkpoint_and_models_plots_provided(
tmp_dir, dvc, interactive, inp
):
(tmp_dir / "params.yaml").dump({"foo": 1})
init(
dvc,
type="dl",
type="checkpoint",
interactive=interactive,
stream=inp,
defaults=CmdExperimentsInit.DEFAULTS,
Expand All @@ -411,13 +401,11 @@ def test_init_with_type_live_and_models_plots_provided(
"cmd": "cmd",
"deps": ["data", "src"],
"metrics": [
{"dvclive.json": {"cache": False}},
{"m": {"cache": False}},
],
"outs": [{"models": {"checkpoint": True}}],
"params": [{"params.yaml": None}],
"plots": [
{os.path.join("dvclive", "scalars"): {"cache": False}},
{"p": {"cache": False}},
],
}
Expand Down Expand Up @@ -445,6 +433,49 @@ def test_init_with_type_default_and_live_provided(
defaults=CmdExperimentsInit.DEFAULTS,
overrides={"cmd": "cmd", "live": "live"},
)
assert (tmp_dir / "dvc.yaml").parse() == {
"stages": {
"train": {
"cmd": "cmd",
"deps": ["data", "src"],
"metrics": [
{"live.json": {"cache": False}},
],
"outs": ["models"],
"params": [{"params.yaml": None}],
"plots": [
{os.path.join("live", "scalars"): {"cache": False}},
],
}
}
}
assert (tmp_dir / "src").is_dir()
assert (tmp_dir / "data").is_dir()


@pytest.mark.parametrize(
"interactive, inp",
[
(False, None),
(True, io.StringIO()),
],
)
def test_init_with_live_and_metrics_plots_provided(
tmp_dir, dvc, interactive, inp
):
(tmp_dir / "params.yaml").dump({"foo": 1})
init(
dvc,
interactive=interactive,
stream=inp,
defaults=CmdExperimentsInit.DEFAULTS,
overrides={
"cmd": "cmd",
"live": "live",
"metrics": "metrics.json",
"plots": "plots",
},
)
assert (tmp_dir / "dvc.yaml").parse() == {
"stages": {
"train": {
Expand Down
8 changes: 3 additions & 5 deletions tests/unit/command/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,6 @@ def test_experiments_init(dvc, scm, mocker, capsys, extra_args):
"metrics": "metrics.json",
"params": "params.yaml",
"plots": "plots",
"live": "dvclive",
},
overrides={"cmd": "cmd"},
interactive=False,
Expand Down Expand Up @@ -727,7 +726,6 @@ def test_experiments_init_config(dvc, scm, mocker):
"metrics": "metrics.json",
"params": "params.yaml",
"plots": "plots",
"live": "dvclive",
},
overrides={"cmd": "cmd"},
interactive=False,
Expand Down Expand Up @@ -782,11 +780,11 @@ def test_experiments_init_cmd_not_required_for_interactive_mode(dvc, mocker):
"extra_args, expected_kw",
[
(["--type", "default"], {"type": "default", "name": "train"}),
(["--type", "dl"], {"type": "dl", "name": "train"}),
(["--type", "checkpoint"], {"type": "checkpoint", "name": "train"}),
(["--force"], {"force": True, "name": "train"}),
(
["--name", "name", "--type", "dl"],
{"name": "name", "type": "dl"},
["--name", "name", "--type", "checkpoint"],
{"name": "name", "type": "checkpoint"},
),
(
[
Expand Down

0 comments on commit bb05b02

Please sign in to comment.