Skip to content

Commit

Permalink
exp init: support interactive and explicit options (#6681)
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry authored Sep 28, 2021
1 parent 8107ec8 commit 4668a35
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 31 deletions.
170 changes: 140 additions & 30 deletions dvc/command/experiments.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import argparse
import logging
import os
from collections import Counter, OrderedDict, defaultdict
from collections import ChainMap, Counter, OrderedDict, defaultdict
from datetime import date, datetime
from fnmatch import fnmatch
from typing import TYPE_CHECKING, Dict, Iterable, Optional

from funcy import lmap
from funcy import compact, lmap, post_processing
from rich.prompt import Prompt

from dvc import prompt
from dvc.command import completion
from dvc.command.base import CmdBase, append_doc_link, fix_subparsers
from dvc.command.metrics import DEFAULT_PRECISION
Expand Down Expand Up @@ -664,7 +664,7 @@ def run(self):
logger.warning(msg)

msg = "Are you sure you want to proceed?"
if not self.args.force and not prompt.confirm(msg):
if not self.args.force and not ui.confirm(msg):
return 1

removed = self.repo.experiments.gc(
Expand Down Expand Up @@ -785,6 +785,50 @@ def run(self):
return 0


class RequiredPrompt(Prompt):
def process_response(self, value: str):
from rich.prompt import InvalidResponse

ret = super().process_response(value)
if not ret:
raise InvalidResponse(
"[prompt.invalid]Response required. Please try again."
)
return ret

def render_default(self, default):
from rich.text import Text

return Text(f"{default!s}", "green")


class SkippablePrompt(RequiredPrompt):
skip_value: str = "n"

def process_response(self, value: str):
ret = super().process_response(value)
return None if ret == self.skip_value else ret

def make_prompt(self, default):
prompt = self.prompt.copy()
prompt.end = ""

prompt.append(" [")
if (
default is not ...
and self.show_default
and isinstance(default, (str, self.response_type))
):
_default = self.render_default(default)
prompt.append(_default)
prompt.append(", ")

prompt.append(f"{self.skip_value} to skip", style="italic")
prompt.append("]")
prompt.append(self.prompt_suffix)
return prompt


class CmdExperimentsInit(CmdBase):
CODE = "src"
DATA = "data"
Expand All @@ -795,45 +839,100 @@ class CmdExperimentsInit(CmdBase):
DVCLIVE = "dvclive"
DEFAULT_NAME = "default"

@post_processing(dict)
def init_interactive(self, defaults=None):
defaults = defaults or {}
prompts = {
"cmd": "[b]Command[/b] to execute",
"code": "Path to a [b]code[/b] file/directory",
"data": "Path to a [b]data[/b] file/directory",
"models": "Path to a [b]model[/b] file/directory",
"metrics": "Path to a [b]metrics[/b] file",
"params": "Path to a [b]parameters[/b] file",
"plots": "Path to a [b]plots[/b] file/directory",
"live": "Path to log [b]dvclive[/b] outputs",
}
message = (
"This command will guide you to set up your first stage in "
"[green]dvc.yaml[/green].\n"
)
ui.error_write(message, styled=True)

for key, prompt in prompts.items():
prompt_cls = RequiredPrompt if key == "cmd" else SkippablePrompt
kwargs = {"default": defaults[key]} if key in defaults else {}
value = prompt_cls.ask(prompt, console=ui.error_console, **kwargs)
yield key, value

def run(self):
from dvc.command.stage import parse_cmd

cmd = parse_cmd(self.args.cmd)
if not cmd:
if not self.args.interactive and not cmd:
raise InvalidArgumentError("command is not specified")
if self.args.interactive:
raise NotImplementedError(
"'-i/--interactive' is not implemented yet."
)
if self.args.explicit:
raise NotImplementedError("'--explicit' is not implemented yet.")
if self.args.template:
raise NotImplementedError("template is not supported yet.")

from dvc.utils.serialize import LOADERS
global_defaults = {
"code": self.CODE,
"data": self.DATA,
"models": self.MODELS,
"metrics": self.DEFAULT_METRICS,
"params": self.DEFAULT_PARAMS,
"plots": self.PLOTS,
}

context = ChainMap()
if not self.args.explicit:
config = {} # TODO
context.maps.extend([config, global_defaults])

code = self.args.code or self.CODE
data = self.args.data or self.DATA
models = self.args.models or self.MODELS
metrics = self.args.metrics or self.DEFAULT_METRICS
params_path = self.args.params or self.DEFAULT_PARAMS
plots = self.args.plots or self.PLOTS
dvclive = self.args.live or self.DVCLIVE
if self.args.interactive:
defaults = context.new_child({"live": self.DVCLIVE})
context = self.init_interactive(defaults=defaults)
else:
d = compact(
{
"cmd": cmd,
"code": self.args.code,
"data": self.args.data,
"models": self.args.models,
"metrics": self.args.metrics,
"params": self.args.params,
"plots": self.args.plots,
"live": self.args.live,
}
)
context = context.new_child(d)

_, ext = os.path.splitext(params_path)
params = list(LOADERS[ext](params_path))
assert "cmd" in context
command = context["cmd"]
code = context.get("code")
data = context.get("data")
models = context.get("models")
metrics = context.get("metrics")
plots = context.get("plots")
live = context.get("live")

params_kv = []
if "params" in context:
from dvc.utils.serialize import LOADERS

path = context["params"]
_, ext = os.path.splitext(path)
params_kv = [{path: list(LOADERS[ext](path))}]

name = self.args.name or self.DEFAULT_NAME
stage = self.repo.stage.add(
name=name,
cmd=cmd,
deps=[code, data],
outs=[models],
params=[{params_path: params}],
metrics_no_cache=[metrics],
plots_no_cache=[plots],
live=dvclive,
force=True,
cmd=command,
deps=compact([code, data]),
outs=compact([models]),
params=params_kv,
metrics_no_cache=compact([metrics]),
plots_no_cache=compact([plots]),
live=live,
force=self.args.force,
)

if self.args.run:
Expand Down Expand Up @@ -1383,7 +1482,18 @@ def add_parser(subparsers, parent_parser):
"--template", help="Stage template to use to fill with provided values"
)
experiments_init_parser.add_argument(
"--explicit", help="Only use the path values explicitly provided"
"-f",
"--force",
action="store_true",
default=False,
help="Overwrite existing stage",
)

experiments_init_parser.add_argument(
"--explicit",
action="store_true",
default=False,
help="Only use the path values explicitly provided",
)
experiments_init_parser.add_argument(
"--name", "-n", help="Name of the stage to create"
Expand Down
1 change: 0 additions & 1 deletion tests/func/experiments/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def test_init(tmp_dir, dvc):
"default": {
"cmd": script,
"deps": ["data", "src"],
"live": {"dvclive": {"html": True, "summary": True}},
"metrics": [{"metrics.json": {"cache": False}}],
"outs": ["models"],
"params": ["foo"],
Expand Down

0 comments on commit 4668a35

Please sign in to comment.