Skip to content

Commit

Permalink
exp init: create output dirs
Browse files Browse the repository at this point in the history
  • Loading branch information
dberenbaum committed May 16, 2022
1 parent 570e3a5 commit cee35b4
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 14 deletions.
18 changes: 15 additions & 3 deletions dvc/commands/experiments/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

if TYPE_CHECKING:
from dvc.dependency import Dependency
from dvc.output import Output
from dvc.stage import PipelineStage


Expand Down Expand Up @@ -61,7 +62,7 @@ def run(self):
}
)

initialized_stage, initialized_deps = init(
initialized_stage, initialized_deps, initialized_out_dirs = init(
self.repo,
name=self.args.name,
type=self.args.type,
Expand All @@ -70,13 +71,18 @@ def run(self):
interactive=self.args.interactive,
force=self.args.force,
)
self._post_init_display(initialized_stage, initialized_deps)
self._post_init_display(
initialized_stage, initialized_deps, initialized_out_dirs
)
if self.args.run:
self.repo.experiments.run(targets=[initialized_stage.addressing])
return 0

def _post_init_display(
self, stage: "PipelineStage", new_deps: List["Dependency"]
self,
stage: "PipelineStage",
new_deps: List["Dependency"],
new_out_dirs: List["Output"],
) -> None:
from dvc.utils import humanize

Expand All @@ -85,6 +91,12 @@ def _post_init_display(
deps_paths = humanize.join(map(path_fmt, new_deps))
ui.write(f"Creating dependencies: {deps_paths}", styled=True)

if new_out_dirs:
out_dirs_paths = humanize.join(map(path_fmt, new_out_dirs))
ui.write(
f"Creating output directories: {out_dirs_paths}", styled=True
)

ui.write(
f"Creating [b]{self.args.name}[/b] stage in [green]dvc.yaml[/]",
styled=True,
Expand Down
22 changes: 20 additions & 2 deletions dvc/repo/experiments/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,23 @@ def init_deps(stage: PipelineStage) -> List["Dependency"]:
return new_deps


def init_out_dirs(stage: PipelineStage) -> List[str]:
from dvc.fs import localfs

new_dirs = []

# create dirs for outputs
for out in stage.outs:
path = out.def_path
if is_file(path):
path = localfs.path.parent(path)
if path and not localfs.exists(path):
localfs.makedirs(path)
new_dirs.append(path)

return new_dirs


def init(
repo: "Repo",
name: str = "train",
Expand All @@ -189,7 +206,7 @@ def init(
interactive: bool = False,
force: bool = False,
stream: Optional[TextIO] = None,
) -> Tuple[PipelineStage, List["Dependency"]]:
) -> Tuple[PipelineStage, List["Dependency"], List[str]]:
from dvc.dvcfile import make_dvcfile

dvcfile = make_dvcfile(repo, "dvc.yaml")
Expand Down Expand Up @@ -253,10 +270,11 @@ def init(

with _disable_logging(), repo.scm_context(autostage=True, quiet=True):
stage.dump(update_lock=False)
initialized_out_dirs = init_out_dirs(stage)
stage.ignore_outs()
initialized_deps = init_deps(stage)
if params:
repo.scm_context.track_file(params)

assert isinstance(stage, PipelineStage)
return stage, initialized_deps
return stage, initialized_deps, initialized_out_dirs
21 changes: 21 additions & 0 deletions tests/func/experiments/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,3 +496,24 @@ def test_init_with_live_and_metrics_plots_provided(
}
assert (tmp_dir / "src").is_dir()
assert (tmp_dir / "data").is_dir()


def test_gen_output_dirs(tmp_dir, dvc):
init(
dvc,
defaults=CmdExperimentsInit.DEFAULTS,
overrides={
"cmd": "cmd",
"models": "models/predict.h5",
"metrics": "eval/scores.json",
"plots": "eval/plots",
"live": "eval/live",
},
)

assert (tmp_dir / "models").is_dir()
assert (tmp_dir / "eval").is_dir()
assert (tmp_dir / "eval/plots").is_dir()
assert (tmp_dir / "eval/live").is_dir()
assert not (tmp_dir / "models/predict.h5").exists()
assert not (tmp_dir / "eval/scores.json").exists()
23 changes: 14 additions & 9 deletions tests/unit/command/test_experiments.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import csv
import pathlib
import textwrap
from datetime import datetime

Expand Down Expand Up @@ -671,7 +672,7 @@ def test_show_experiments_sort_by(capsys, sort_order):
def test_experiments_init(dvc, scm, mocker, capsys, extra_args):
stage = mocker.Mock(outs=[], addressing="train")
m = mocker.patch(
"dvc.repo.experiments.init.init", return_value=(stage, [])
"dvc.repo.experiments.init.init", return_value=(stage, [], [])
)
runner = mocker.patch("dvc.repo.experiments.run.run", return_value=0)
cli_args = parse_args(["exp", "init", *extra_args, "cmd"])
Expand Down Expand Up @@ -707,7 +708,7 @@ def test_experiments_init_config(dvc, scm, mocker):

stage = mocker.Mock(outs=[])
m = mocker.patch(
"dvc.repo.experiments.init.init", return_value=(stage, [])
"dvc.repo.experiments.init.init", return_value=(stage, [], [])
)
cli_args = parse_args(["exp", "init", "cmd"])
cmd = cli_args.func(cli_args)
Expand Down Expand Up @@ -736,7 +737,7 @@ def test_experiments_init_config(dvc, scm, mocker):
def test_experiments_init_explicit(dvc, mocker):
stage = mocker.Mock(outs=[])
m = mocker.patch(
"dvc.repo.experiments.init.init", return_value=(stage, [])
"dvc.repo.experiments.init.init", return_value=(stage, [], [])
)
cli_args = parse_args(["exp", "init", "--explicit", "cmd"])
cmd = cli_args.func(cli_args)
Expand Down Expand Up @@ -770,7 +771,7 @@ def test_experiments_init_cmd_not_required_for_interactive_mode(dvc, mocker):

stage = mocker.Mock(outs=[])
m = mocker.patch(
"dvc.repo.experiments.init.init", return_value=(stage, [])
"dvc.repo.experiments.init.init", return_value=(stage, [], [])
)
assert cmd.run() == 0
assert called_once_with_subset(m, ANY(Repo), interactive=True)
Expand Down Expand Up @@ -826,7 +827,7 @@ def test_experiments_init_extra_args(extra_args, expected_kw, mocker):

stage = mocker.Mock(outs=[])
m = mocker.patch(
"dvc.repo.experiments.init.init", return_value=(stage, [])
"dvc.repo.experiments.init.init", return_value=(stage, [], [])
)
assert cmd.run() == 0
assert called_once_with_subset(m, ANY(Repo), **expected_kw)
Expand All @@ -839,15 +840,18 @@ def test_experiments_init_type_invalid_choice():

@pytest.mark.parametrize("args", [[], ["--run"]])
def test_experiments_init_displays_output_on_no_run(dvc, mocker, capsys, args):
model_dir = pathlib.Path("models")
model_path = str(model_dir / "predict.h5")
stage = dvc.stage.create(
name="train",
cmd=["cmd"],
deps=["code", "data"],
params=["params.yaml"],
outs=["metrics.json", "plots", "models"],
outs=["metrics.json", "plots", model_path],
)
mocker.patch(
"dvc.repo.experiments.init.init", return_value=(stage, stage.deps)
"dvc.repo.experiments.init.init",
return_value=(stage, stage.deps, [model_dir]),
)
mocker.patch("dvc.repo.experiments.run.run", return_value=0)
cli_args = parse_args(["exp", "init", "cmd", *args])
Expand All @@ -856,10 +860,11 @@ def test_experiments_init_displays_output_on_no_run(dvc, mocker, capsys, args):

expected_lines = [
"Creating dependencies: code, data and params.yaml",
"Creating output directories: models",
"Creating train stage in dvc.yaml",
"",
"Ensure your experiment command creates "
"metrics.json, plots and models.",
"Ensure your experiment command creates metrics.json, plots and ",
f"{model_path}.",
]
if not cli_args.run:
expected_lines += [
Expand Down

0 comments on commit cee35b4

Please sign in to comment.