Skip to content

Commit

Permalink
exp init: change ui prompts structure and output
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry committed Mar 28, 2022
1 parent 81eee05 commit 37c2023
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 118 deletions.
63 changes: 39 additions & 24 deletions dvc/commands/experiments/init.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import logging
from typing import TYPE_CHECKING, List

from funcy import compact

Expand All @@ -8,6 +9,11 @@
from dvc.exceptions import InvalidArgumentError
from dvc.ui import ui

if TYPE_CHECKING:
from dvc.dependency import Dependency
from dvc.stage import PipelineStage


logger = logging.getLogger(__name__)


Expand All @@ -29,7 +35,6 @@ class CmdExperimentsInit(CmdBase):
"plots": PLOTS,
"live": DVCLIVE,
}
EXP_LINK = "https://s.dvc.org/g/exp/run"

def run(self):
from dvc.commands.stage import parse_cmd
Expand Down Expand Up @@ -58,7 +63,7 @@ def run(self):
}
)

initialized_stage = init(
initialized_stage, initialized_deps = init(
self.repo,
name=self.args.name,
type=self.args.type,
Expand All @@ -67,32 +72,42 @@ def run(self):
interactive=self.args.interactive,
force=self.args.force,
)
self._post_init_display(initialized_stage, initialized_deps)
if self.args.run:
self.repo.experiments.run(targets=[initialized_stage.addressing])
return 0

text = ui.rich_text.assemble(
"Created ",
(self.args.name, "bright_blue"),
" stage in ",
("dvc.yaml", "green"),
".",
def _post_init_display(
self, stage: "PipelineStage", new_deps: List["Dependency"]
) -> None:
from dvc.utils import humanize

path_fmt = "[green]{0}[/green]".format
if new_deps:
deps_paths = humanize.join(map(path_fmt, new_deps))
ui.write(f"Creating dependencies: {deps_paths}", styled=True)

ui.write(
f"Creating [b]{self.args.name}[/b] stage in [green]dvc.yaml[/]",
styled=True,
)
if not self.args.run:
text.append_text(
ui.rich_text.assemble(
" To run, use ",
('"dvc exp run"', "green"),
".\nSee ",
(self.EXP_LINK, "repr.url"),
".",
)
)
if stage.outs or not self.args.run:
# separate the above status-like messages with help/tips section
ui.write(styled=True)

ui.write(text, styled=True)
if self.args.run:
return self.repo.experiments.run(
targets=[initialized_stage.addressing]
)
if stage.outs:
outs_paths = humanize.join(map(path_fmt, stage.outs))
tips = f"Ensure your experiment command creates {outs_paths}."
ui.write(tips, styled=True)

return 0
if not self.args.run:
ui.write(
'You can now run your experiment using [b]"dvc exp run"[/].',
styled=True,
)
else:
# separate between `exp.run` output and `dvc exp init` output
ui.write(styled=True)


def add_parser(experiments_subparsers, parent_parser):
Expand Down
93 changes: 30 additions & 63 deletions dvc/repo/experiments/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Callable,
Dict,
Iterable,
List,
Optional,
TextIO,
Tuple,
Expand All @@ -19,12 +20,11 @@
from dvc.exceptions import DvcException
from dvc.stage import PipelineStage
from dvc.types import OptStr
from dvc.utils import humanize

if TYPE_CHECKING:
from dvc.repo import Repo
from dvc.dvcfile import DVCFile
from rich.tree import Tree
from dvc.dependency import Dependency

from dvc.ui import ui

Expand Down Expand Up @@ -75,70 +75,41 @@ def _disable_logging(highest_level=logging.CRITICAL):
logging.disable(previous_level)


def build_workspace_tree(workspace: Dict[str, str], label: str) -> "Tree":
from rich.tree import Tree

tree = Tree(label, highlight=True)
for value in sorted(workspace.values()):
tree.add(f"[green]{value}[/green]")
return tree


def display_workspace_tree(workspace: Dict[str, str], label: str) -> None:
d = workspace.copy()
d.pop("cmd", None)

if d:
ui.write(build_workspace_tree(d, label), styled=True)
ui.write(styled=True)


PIPELINE_FILE_LINK = "https://s.dvc.org/g/pipeline-files"


def init_interactive(
name: str,
defaults: Dict[str, str],
provided: Dict[str, str],
validator: Callable[[str, str], Union[str, Tuple[str, str]]] = None,
live: bool = False,
stream: Optional[TextIO] = None,
) -> Dict[str, str]:
command = provided.pop("cmd", None)
prompts = lremove(provided.keys(), ["code", "data", "models", "params"])
prompts.extend(
lremove(provided.keys(), ["live"] if live else ["metrics", "plots"])
command_prompts = lremove(provided.keys(), ["cmd"])
dependencies_prompts = lremove(provided.keys(), ["code", "data", "params"])
outputs_prompts = lremove(
provided.keys(),
["models"] + (["live"] if live else ["metrics", "plots"]),
)

ret: Dict[str, str] = {}
if prompts or command:
ui.error_write(
f"This command will guide you to set up a [bright_blue]{name}[/]",
"stage in [green]dvc.yaml[/].",
f"\nSee [repr.url]{PIPELINE_FILE_LINK}[/].\n",
styled=True,
)

if command:
ret["cmd"] = command
else:
ret.update(
compact(_prompts(["cmd"], allow_omission=False, stream=stream))
if "cmd" in provided:
ret["cmd"] = provided["cmd"]

for heading, prompts, allow_omission in (
("", command_prompts, False),
("Enter experiment dependencies.", dependencies_prompts, True),
("Enter experiment outputs.", outputs_prompts, True),
):
if prompts and heading:
ui.error_write(heading, styled=True)
response = _prompts(
prompts,
defaults=defaults,
allow_omission=allow_omission,
validator=validator,
stream=stream,
)
ret.update(compact(response))
if prompts:
ui.error_write(styled=True)

if prompts:
ui.error_write(
"Enter the paths for dependencies and outputs of the command.",
styled=True,
)
ret.update(
compact(
_prompts(prompts, defaults, validator=validator, stream=stream)
)
)
ui.error_write(styled=True)
return ret


Expand Down Expand Up @@ -189,7 +160,7 @@ def is_file(path: str) -> bool:
return bool(ext)


def init_deps(stage: PipelineStage) -> None:
def init_deps(stage: PipelineStage) -> List["Dependency"]:
from funcy import rpartial

from dvc.dependency import ParamsDependency
Expand All @@ -198,10 +169,6 @@ def init_deps(stage: PipelineStage) -> None:
new_deps = [dep for dep in stage.deps if not dep.exists]
params, deps = lsplit(rpartial(isinstance, ParamsDependency), new_deps)

if new_deps:
paths = map("[green]{0}[/]".format, new_deps)
ui.write(f"Created {humanize.join(paths)}.", styled=True)

# always create a file for params, detect file/folder based on extension
# for other dependencies
dirs = [dep.fs_path for dep in deps if not is_file(dep.fs_path)]
Expand All @@ -213,6 +180,8 @@ def init_deps(stage: PipelineStage) -> None:
with localfs.open(path, "w", encoding="utf-8"):
pass

return new_deps


def init(
repo: "Repo",
Expand All @@ -223,7 +192,7 @@ def init(
interactive: bool = False,
force: bool = False,
stream: Optional[TextIO] = None,
) -> PipelineStage:
) -> Tuple[PipelineStage, List["Dependency"]]:
from dvc.dvcfile import make_dvcfile

dvcfile = make_dvcfile(repo, "dvc.yaml")
Expand All @@ -236,7 +205,6 @@ def init(

if interactive:
defaults = init_interactive(
name,
validator=partial(validate_prompts, repo),
defaults=defaults,
live=with_live,
Expand Down Expand Up @@ -288,10 +256,9 @@ def init(
with _disable_logging(), repo.scm_context(autostage=True, quiet=True):
stage.dump(update_lock=False)
stage.ignore_outs()
display_workspace_tree(context, "Using experiment project structure:")
init_deps(stage)
initialized_deps = init_deps(stage)
if params:
repo.scm_context.track_file(params)

assert isinstance(stage, PipelineStage)
return stage
return stage, initialized_deps
Loading

0 comments on commit 37c2023

Please sign in to comment.