experiments: support dvc repro --exp command line params #4331

Merged · 5 commits · Aug 6, 2020

8 changes: 8 additions & 0 deletions dvc/command/repro.py
@@ -44,6 +44,7 @@ def run(self):
queue=self.args.queue,
run_all=self.args.run_all,
jobs=self.args.jobs,
params=self.args.params,
)

if len(stages) == 0:
@@ -177,6 +178,13 @@ def add_parser(subparsers, parent_parser):
default=False,
help=argparse.SUPPRESS,
)
repro_parser.add_argument(
"--params",
action="append",
default=[],
help="Declare parameter values for an experiment.",
metavar="[<filename>:]<params_list>",
)
repro_parser.add_argument(
"--queue", action="store_true", default=False, help=argparse.SUPPRESS
)
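For context, `--params` uses argparse's `append` action with an empty-list default, so repeating the flag accumulates each raw `[<filename>:]<params_list>` string into a list that `run()` forwards to `reproduce()`. A minimal standalone sketch of that behavior (plain argparse with illustrative values, not the full DVC parser):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--params",
    action="append",
    default=[],
    metavar="[<filename>:]<params_list>",
    help="Declare parameter values for an experiment.",
)

args = parser.parse_args(
    [
        "--params", "train.lr=0.01,train.epochs=30",
        "--params", "params2.yaml:model.layers=5",
    ]
)
print(args.params)
# ['train.lr=0.01,train.epochs=30', 'params2.yaml:model.layers=5']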
8 changes: 8 additions & 0 deletions dvc/exceptions.py
@@ -194,6 +194,14 @@ def __init__(self, path):
)


class TOMLFileCorruptedError(DvcException):
def __init__(self, path):
path = relpath(path)
super().__init__(
f"unable to read: '{path}', TOML file structure is corrupted"
)


class RecursiveAddingWhileUsingFilename(DvcException):
def __init__(self):
super().__init__(
80 changes: 65 additions & 15 deletions dvc/repo/experiments/__init__.py
@@ -2,13 +2,16 @@
import os
import re
import tempfile
from collections import defaultdict
from collections.abc import Mapping
from concurrent.futures import ProcessPoolExecutor, as_completed
from contextlib import contextmanager
from typing import Iterable, Optional

from funcy import cached_property

from dvc.exceptions import DvcException
from dvc.path_info import PathInfo
from dvc.repo.experiments.executor import ExperimentExecutor, LocalExecutor
from dvc.scm.git import Git
from dvc.stage.serialize import to_lockfile
@@ -139,21 +142,39 @@ def _scm_checkout(self, rev):
logger.debug("Checking out experiment commit '%s'", rev)
self.scm.checkout(rev)

def _stash_exp(self, *args, **kwargs):
def _stash_exp(self, *args, params: Optional[dict] = None, **kwargs):
"""Stash changes from the current (parent) workspace as an experiment.

Args:
params: Optional dictionary of parameter values to be used.
Values take priority over any parameters specified in the
user's workspace.
"""
rev = self.scm.get_rev()

# patch user's workspace into experiments clone
tmp = tempfile.NamedTemporaryFile(delete=False).name
try:
self.repo.scm.repo.git.diff(patch=True, output=tmp)
if os.path.getsize(tmp):
logger.debug("Patching experiment workspace")
self.scm.repo.git.apply(tmp)
else:
elif not params:
# experiment matches original baseline
raise UnchangedExperimentError(rev)
finally:
remove(tmp)

# update experiment params from command line
if params:
self._update_params(params)

# save additional repro command line arguments
self._pack_args(*args, **kwargs)

# save experiment as a stash commit w/message containing baseline rev
# (stash commits are merge commits and do not contain a parent commit
# SHA)
msg = f"{self.STASH_MSG_PREFIX}{rev}"
self.scm.repo.git.stash("push", "-m", msg)
return self.scm.resolve_rev("stash@{0}")
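The comment above is the key design point: a stash commit is a merge commit, so the baseline rev cannot be recovered from a parent SHA and is instead carried in the stash message. A hedged sketch of building and parsing such a message (the prefix value and regex are assumptions for illustration; the real STASH_MSG_PREFIX constant is defined on this class):

import re

# assumed prefix for illustration only; DVC defines the real value as a class attribute
STASH_MSG_PREFIX = "dvc-exp:"
STASH_MSG_RE = re.compile(rf"^{re.escape(STASH_MSG_PREFIX)}(?P<baseline>[0-9a-f]{{40}})$")

baseline_rev = "9c1185a5c5e9fc54612808977ee8f548b2258d31"
msg = f"{STASH_MSG_PREFIX}{baseline_rev}"  # what gets passed to `git stash push -m`

match = STASH_MSG_RE.match(msg)
assert match is not None
print(match.group("baseline"))  # recovers the baseline rev from the stash entry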
@@ -166,6 +187,36 @@ def _unpack_args(self, tree=None):
args_file = os.path.join(self.exp_dvc.tmp_dir, self.PACKED_ARGS_FILE)
return ExperimentExecutor.unpack_repro_args(args_file, tree=tree)

def _update_params(self, params: dict):
"""Update experiment params files with the specified values."""
from dvc.utils.toml import dump_toml, parse_toml_for_update
from dvc.utils.yaml import dump_yaml, parse_yaml_for_update

logger.debug("Using experiment params '%s'", params)

# recursive dict update
def _update(dict_, other):
for key, value in other.items():
if isinstance(value, Mapping):
dict_[key] = _update(dict_.get(key, {}), value)
else:
dict_[key] = value
return dict_

loaders = defaultdict(lambda: parse_yaml_for_update)
loaders.update({".toml": parse_toml_for_update})
dumpers = defaultdict(lambda: dump_yaml)
dumpers.update({".toml": dump_toml})

for params_fname in params:
path = PathInfo(self.exp_dvc.root_dir) / params_fname
with self.exp_dvc.tree.open(path, "r") as fobj:
text = fobj.read()
suffix = path.suffix.lower()
data = loaders[suffix](text, path)
_update(data, params[params_fname])
dumpers[suffix](path, data)

def _commit(self, exp_hash, check_exists=True, branch=True):
"""Commit stages as an experiment and return the commit SHA."""
if not self.scm.is_dirty():
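For reference, the recursive `_update` helper above means command-line values win only for the keys they name; sibling keys loaded from the params file survive the merge. A self-contained copy with a worked example (the parameter names are illustrative):

from collections.abc import Mapping

def _update(dict_, other):
    # merge `other` into `dict_`: nested mappings are merged key-by-key,
    # scalar values from `other` overwrite existing ones
    for key, value in other.items():
        if isinstance(value, Mapping):
            dict_[key] = _update(dict_.get(key, {}), value)
        else:
            dict_[key] = value
    return dict_

workspace_params = {"train": {"lr": 0.001, "epochs": 10}, "seed": 42}
cli_overrides = {"train": {"lr": 0.01}}
print(_update(workspace_params, cli_overrides))
# {'train': {'lr': 0.01, 'epochs': 10}, 'seed': 42}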
@@ -207,23 +258,19 @@ def reproduce_queued(self, **kwargs):
)
return results

def new(self, *args, workspace=True, **kwargs):
def new(self, *args, **kwargs):
"""Create a new experiment.

Experiment will be reproduced and checked out into the user's
workspace.
"""
rev = self.repo.scm.get_rev()
self._scm_checkout(rev)
if workspace:
try:
stash_rev = self._stash_exp(*args, **kwargs)
except UnchangedExperimentError as exc:
logger.info("Reproducing existing experiment '%s'.", rev[:7])
raise exc
else:
# configure params via command line here
pass
try:
stash_rev = self._stash_exp(*args, **kwargs)
except UnchangedExperimentError as exc:
logger.info("Reproducing existing experiment '%s'.", rev[:7])
raise exc
logger.debug(
"Stashed experiment '%s' for future execution.", stash_rev[:7]
)
@@ -365,8 +412,10 @@ def checkout_exp(self, rev):
tmp = tempfile.NamedTemporaryFile(delete=False).name
self.scm.repo.head.commit.diff("HEAD~1", patch=True, output=tmp)

logger.debug("Stashing workspace changes.")
self.repo.scm.repo.git.stash("push")
dirty = self.repo.scm.is_dirty()
if dirty:
logger.debug("Stashing workspace changes.")
self.repo.scm.repo.git.stash("push")

try:
if os.path.getsize(tmp):
@@ -379,7 +428,8 @@
raise DvcException("failed to apply experiment changes.")
finally:
remove(tmp)
self._unstash_workspace()
if dirty:
self._unstash_workspace()

if need_checkout:
dvc_checkout(self.repo)
29 changes: 28 additions & 1 deletion dvc/repo/reproduce.py
@@ -60,7 +60,7 @@ def reproduce(
recursive=False,
pipeline=False,
all_pipelines=False,
**kwargs
**kwargs,
):
from dvc.utils import parse_target

@@ -71,6 +71,7 @@
)

experiment = kwargs.pop("experiment", False)
params = _parse_params(kwargs.pop("params", []))
queue = kwargs.pop("queue", False)
run_all = kwargs.pop("run_all", False)
jobs = kwargs.pop("jobs", 1)
@@ -81,6 +82,7 @@
target=target,
recursive=recursive,
all_pipelines=all_pipelines,
params=params,
queue=queue,
run_all=run_all,
jobs=jobs,
@@ -116,6 +118,31 @@ def reproduce(
return _reproduce_stages(active_graph, targets, **kwargs)


def _parse_params(path_params):
from flatten_json import unflatten
from yaml import safe_load, YAMLError
from dvc.dependency.param import ParamsDependency

ret = {}
for path_param in path_params:
path, _, params_str = path_param.rpartition(":")
# remove empty strings from params, e.g. when invoked with `-p "file1:"`
params = {}
for param_str in filter(bool, params_str.split(",")):
try:
# interpret value strings using YAML rules
key, value = param_str.split("=")
params[key] = safe_load(value)
except (ValueError, YAMLError):
raise InvalidArgumentError(
f"Invalid param/value pair '{param_str}'"
)
if not path:
path = ParamsDependency.DEFAULT_PARAMS_FILE
ret[path] = unflatten(params, ".")
return ret


def _reproduce_experiments(repo, run_all=False, jobs=1, **kwargs):
if run_all:
return repo.experiments.reproduce_queued(jobs=jobs)
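To make the parsing concrete, a hedged walk-through of a single --params value, assuming flatten_json and PyYAML are available. The filename and keys are illustrative; when no `<filename>:` prefix is given, the code above falls back to ParamsDependency.DEFAULT_PARAMS_FILE:

from flatten_json import unflatten
from yaml import safe_load

path_param = "params.yaml:train.lr=0.01,train.epochs=30,use_gpu=true"
path, _, params_str = path_param.rpartition(":")  # path == "params.yaml"

params = {}
for param_str in filter(bool, params_str.split(",")):
    # value strings are interpreted with YAML rules: "0.01" -> float, "true" -> bool
    key, value = param_str.split("=")
    params[key] = safe_load(value)

print(unflatten(params, "."))
# {'train': {'lr': 0.01, 'epochs': 30}, 'use_gpu': True}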
21 changes: 21 additions & 0 deletions dvc/utils/toml.py
@@ -0,0 +1,21 @@
import toml

from dvc.exceptions import TOMLFileCorruptedError


def parse_toml_for_update(text, path):
"""Parses text into Python structure.

NOTE: Python toml package does not currently use ordered dicts, so
keys may be re-ordered between load/dump, but this function will at
least preserve comments.
"""
try:
return toml.loads(text, decoder=toml.TomlPreserveCommentDecoder())
except toml.TomlDecodeError as exc:
raise TOMLFileCorruptedError(path) from exc


def dump_toml(path, data):
with open(path, "w", encoding="utf-8") as fobj:
toml.dump(data, fobj, encoder=toml.TomlPreserveCommentEncoder())
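A hedged round-trip sketch for the helpers above, using the toml package (>= 0.10) directly. As the docstring notes, key order may change, but comments in the source text are carried through a load/edit/dump cycle (a comment attached to a key that gets overwritten would be lost):

import toml

text = "[train]\nlr = 0.001  # learning rate\nepochs = 10\n"

data = toml.loads(text, decoder=toml.TomlPreserveCommentDecoder())
data["train"]["epochs"] = 30  # update an uncommented key
print(toml.dumps(data, encoder=toml.TomlPreserveCommentEncoder()))
# the inline "# learning rate" comment on the untouched key is kept in the output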
1 change: 1 addition & 0 deletions tests/unit/command/test_repro.py
@@ -15,6 +15,7 @@
"recursive": False,
"force_downstream": False,
"experiment": False,
"params": [],
"queue": False,
"run_all": False,
"jobs": None,