Skip to content

Commit

Permalink
exp init: create params file by default
Browse files Browse the repository at this point in the history
Note that this won't add support for tracking the params file,
it'll only create the params file. Some changes on error and
prompts are made.
  • Loading branch information
skshetry committed Mar 21, 2022
1 parent a3bab48 commit 9d96ddf
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 54 deletions.
21 changes: 4 additions & 17 deletions dvc/dependency/param.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import errno
import logging
import os
import typing
Expand Down Expand Up @@ -133,29 +132,17 @@ def status(self):

def validate_filepath(self):
if not self.exists:
raise FileNotFoundError(
errno.ENOENT, os.strerror(errno.ENOENT), str(self)
)
raise MissingParamsFile(f"Parameters file '{self}' does not exist")
if self.isdir():
raise IsADirectoryError(
errno.EISDIR, os.strerror(errno.EISDIR), str(self)
raise ParamsIsADirectoryError(
f"'{self}' is a directory, expected a parameters file"
)

def read_file(self):
_, ext = os.path.splitext(self.fs_path)
loader = LOADERS[ext]

try:
self.validate_filepath()
except FileNotFoundError as exc:
raise MissingParamsFile(
f"Parameters file '{self}' does not exist"
) from exc
except IsADirectoryError as exc:
raise ParamsIsADirectoryError(
f"'{self}' is a directory, expected a parameters file"
) from exc

self.validate_filepath()
try:
return loader(self.fs_path, fs=self.repo.fs)
except ParseError as exc:
Expand Down
42 changes: 19 additions & 23 deletions dvc/repo/experiments/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,32 +159,28 @@ def validate_prompts(
) -> Union[Any, Tuple[Any, str]]:
from dvc.ui.prompt import InvalidResponse

msg_format = "[yellow]'{0}' does not exist, the {1} will be created.[/]"
if key == "params":
import errno

from dvc.dependency.param import ParamsDependency
from dvc.dependency.param import (
MissingParamsFile,
ParamsDependency,
ParamsIsADirectoryError,
)

assert isinstance(value, str)
msg_format = (
"[prompt.invalid]'{0}' {1}. "
"Please retry with an existing parameters file."
)
try:
ParamsDependency(None, value, repo=repo).validate_filepath()
except (IsADirectoryError, FileNotFoundError) as e:
suffices = {
errno.EISDIR: "is a directory",
errno.ENOENT: "does not exist",
}
raise InvalidResponse(msg_format.format(value, suffices[e.errno]))
except MissingParamsFile:
return value, msg_format.format(value, "file")
except ParamsIsADirectoryError:
raise InvalidResponse(
f"[prompt.invalid]'{value}' is a directory. "
"Please retry with an existing parameters file."
)
elif key in ("code", "data"):
if not os.path.exists(value):
typ = "file" if is_file(value) else "directory"
return (
value,
f"[yellow]'{value}' does not exist, "
f"the {typ} will be created. ",
)
return value, msg_format.format(value, typ)
return value


Expand Down Expand Up @@ -260,7 +256,6 @@ def init(
context: Dict[str, str] = {**defaults, **overrides}
assert "cmd" in context

params_kv = []
params = context.get("params")
if params:
from dvc.dependency.param import (
Expand All @@ -270,18 +265,19 @@ def init(
)

try:
params_d = ParamsDependency(None, params, repo=repo).read_file()
except (MissingParamsFile, ParamsIsADirectoryError) as exc:
ParamsDependency(None, params, repo=repo).validate_filepath()
except ParamsIsADirectoryError as exc:
raise DvcException(f"{exc}.") # swallow cause for display
params_kv.append({params: list(params_d.keys())})
except MissingParamsFile:
pass

checkpoint_out = bool(context.get("live"))
models = context.get("models")
stage = repo.stage.create(
name=name,
cmd=context["cmd"],
deps=compact([context.get("code"), context.get("data")]),
params=params_kv,
params=[{params: None}] if params else None,
metrics_no_cache=compact([context.get("metrics")]),
plots_no_cache=compact([context.get("plots")]),
live=context.get("live"),
Expand Down
41 changes: 27 additions & 14 deletions tests/func/experiments/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_init_simple(tmp_dir, scm, dvc, capsys):
"deps": ["data", "src"],
"metrics": [{"metrics.json": {"cache": False}}],
"outs": ["models"],
"params": ["foo"],
"params": [{"params.yaml": None}],
"plots": [{"plots": {"cache": False}}],
}
}
Expand Down Expand Up @@ -76,6 +76,24 @@ def test_when_stage_force_if_already_exists(tmp_dir, dvc):
assert d["stages"]["train"]["cmd"] == "true"


@pytest.mark.parametrize("interactive", [True, False])
def test_creates_params_file_by_default(tmp_dir, dvc, interactive, capsys):
init(
dvc,
interactive=interactive,
defaults=CmdExperimentsInit.DEFAULTS,
overrides={"cmd": "cmd"},
stream=io.StringIO(""),
)

assert (tmp_dir / "params.yaml").is_file()
assert (tmp_dir / "params.yaml").parse() == {}
out, err = capsys.readouterr()
assert "Created src, data and params.yaml." in out
if interactive:
assert "'params.yaml' does not exist, the file will be created." in err


def test_with_a_custom_name(tmp_dir, dvc):
init(dvc, name="custom", overrides={"cmd": "cmd"})
assert (tmp_dir / "dvc.yaml").parse() == {
Expand Down Expand Up @@ -111,7 +129,7 @@ def test_init_interactive_when_no_path_prompts_need_to_be_asked(
interactive=True,
defaults=CmdExperimentsInit.DEFAULTS,
overrides={**CmdExperimentsInit.DEFAULTS, **extra_overrides},
stream=inp, # we still need to confirm
stream=inp,
)
assert (tmp_dir / "dvc.yaml").parse() == {
"stages": {
Expand All @@ -123,7 +141,7 @@ def test_init_interactive_when_no_path_prompts_need_to_be_asked(
# we specify `live` through `overrides`,
# so it creates checkpoint-based output.
"outs": [{"models": {"checkpoint": True}}],
"params": ["foo"],
"params": [{"params.yaml": None}],
"plots": [{"plots": {"cache": False}}],
}
}
Expand Down Expand Up @@ -164,9 +182,7 @@ def test_when_params_is_omitted_in_interactive_mode(tmp_dir, scm, dvc):
def test_init_interactive_params_validation(tmp_dir, dvc, capsys):
tmp_dir.gen({"data": {"foo": "foo"}})
(tmp_dir / "params.yaml").dump({"foo": 1})
inp = io.StringIO(
"python script.py\nscript.py\ndata\nmodels\nparams.json\ndata\n"
)
inp = io.StringIO("python script.py\nscript.py\ndata\nmodels\ndata\n")

init(
dvc, stream=inp, interactive=True, defaults=CmdExperimentsInit.DEFAULTS
Expand All @@ -179,7 +195,7 @@ def test_init_interactive_params_validation(tmp_dir, dvc, capsys):
"deps": ["data", "script.py"],
"metrics": [{"metrics.json": {"cache": False}}],
"outs": ["models"],
"params": ["foo"],
"params": [{"params.yaml": None}],
"plots": [{"plots": {"cache": False}}],
}
}
Expand All @@ -189,9 +205,6 @@ def test_init_interactive_params_validation(tmp_dir, dvc, capsys):

out, err = capsys.readouterr()
assert (
"Path to a parameters file [params.yaml, n to omit]: "
"'params.json' does not exist. "
"Please retry with an existing parameters file.\n"
"Path to a parameters file [params.yaml, n to omit]: "
"'data' is a directory. "
"Please retry with an existing parameters file.\n"
Expand Down Expand Up @@ -262,7 +275,7 @@ def test_init_default(tmp_dir, scm, dvc, interactive, overrides, inp, capsys):
"deps": ["data", "script.py"],
"metrics": [{"metrics.json": {"cache": False}}],
"outs": ["models"],
"params": ["foo"],
"params": [{"params.yaml": None}],
"plots": [{"plots": {"cache": False}}],
}
}
Expand Down Expand Up @@ -347,7 +360,7 @@ def test_init_interactive_live(
"deps": ["data", "script.py"],
"live": {"dvclive": {"html": True, "summary": True}},
"outs": [{"models": {"checkpoint": True}}],
"params": ["foo"],
"params": [{"params.yaml": None}],
}
}
}
Expand Down Expand Up @@ -396,7 +409,7 @@ def test_init_with_type_live_and_models_plots_provided(
"live": {"dvclive": {"html": True, "summary": True}},
"metrics": [{"m": {"cache": False}}],
"outs": [{"models": {"checkpoint": True}}],
"params": ["foo"],
"params": [{"params.yaml": None}],
"plots": [{"p": {"cache": False}}],
}
}
Expand Down Expand Up @@ -431,7 +444,7 @@ def test_init_with_type_default_and_live_provided(
"live": {"live": {"html": True, "summary": True}},
"metrics": [{"metrics.json": {"cache": False}}],
"outs": [{"models": {"checkpoint": True}}],
"params": ["foo"],
"params": [{"params.yaml": None}],
"plots": [{"plots": {"cache": False}}],
}
}
Expand Down

0 comments on commit 9d96ddf

Please sign in to comment.