Skip to content

Commit

Permalink
experiments: support dvc repro --exp command line params (#4331)
Browse files Browse the repository at this point in the history
* experiments: support passing params values instead of reading from user's workspace

* repro: add --params option for experiments

* support toml params

* update invalid param error message

* update tests
  • Loading branch information
pmrowla authored Aug 6, 2020
1 parent 4759517 commit ab23bcd
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 16 deletions.
8 changes: 8 additions & 0 deletions dvc/command/repro.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def run(self):
queue=self.args.queue,
run_all=self.args.run_all,
jobs=self.args.jobs,
params=self.args.params,
)

if len(stages) == 0:
Expand Down Expand Up @@ -177,6 +178,13 @@ def add_parser(subparsers, parent_parser):
default=False,
help=argparse.SUPPRESS,
)
repro_parser.add_argument(
"--params",
action="append",
default=[],
help="Declare parameter values for an experiment.",
metavar="[<filename>:]<params_list>",
)
repro_parser.add_argument(
"--queue", action="store_true", default=False, help=argparse.SUPPRESS
)
Expand Down
8 changes: 8 additions & 0 deletions dvc/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,14 @@ def __init__(self, path):
)


class TOMLFileCorruptedError(DvcException):
def __init__(self, path):
path = relpath(path)
super().__init__(
f"unable to read: '{path}', TOML file structure is corrupted"
)


class RecursiveAddingWhileUsingFilename(DvcException):
def __init__(self):
super().__init__(
Expand Down
80 changes: 65 additions & 15 deletions dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
import os
import re
import tempfile
from collections import defaultdict
from collections.abc import Mapping
from concurrent.futures import ProcessPoolExecutor, as_completed
from contextlib import contextmanager
from typing import Iterable, Optional

from funcy import cached_property

from dvc.exceptions import DvcException
from dvc.path_info import PathInfo
from dvc.repo.experiments.executor import ExperimentExecutor, LocalExecutor
from dvc.scm.git import Git
from dvc.stage.serialize import to_lockfile
Expand Down Expand Up @@ -139,21 +142,39 @@ def _scm_checkout(self, rev):
logger.debug("Checking out experiment commit '%s'", rev)
self.scm.checkout(rev)

def _stash_exp(self, *args, **kwargs):
def _stash_exp(self, *args, params: Optional[dict] = None, **kwargs):
"""Stash changes from the current (parent) workspace as an experiment.
Args:
params: Optional dictionary of parameter values to be used.
Values take priority over any parameters specified in the
user's workspace.
"""
rev = self.scm.get_rev()

# patch user's workspace into experiments clone
tmp = tempfile.NamedTemporaryFile(delete=False).name
try:
self.repo.scm.repo.git.diff(patch=True, output=tmp)
if os.path.getsize(tmp):
logger.debug("Patching experiment workspace")
self.scm.repo.git.apply(tmp)
else:
elif not params:
# experiment matches original baseline
raise UnchangedExperimentError(rev)
finally:
remove(tmp)

# update experiment params from command line
if params:
self._update_params(params)

# save additional repro command line arguments
self._pack_args(*args, **kwargs)

# save experiment as a stash commit w/message containing baseline rev
# (stash commits are merge commits and do not contain a parent commit
# SHA)
msg = f"{self.STASH_MSG_PREFIX}{rev}"
self.scm.repo.git.stash("push", "-m", msg)
return self.scm.resolve_rev("stash@{0}")
Expand All @@ -166,6 +187,36 @@ def _unpack_args(self, tree=None):
args_file = os.path.join(self.exp_dvc.tmp_dir, self.PACKED_ARGS_FILE)
return ExperimentExecutor.unpack_repro_args(args_file, tree=tree)

def _update_params(self, params: dict):
"""Update experiment params files with the specified values."""
from dvc.utils.toml import dump_toml, parse_toml_for_update
from dvc.utils.yaml import dump_yaml, parse_yaml_for_update

logger.debug("Using experiment params '%s'", params)

# recursive dict update
def _update(dict_, other):
for key, value in other.items():
if isinstance(value, Mapping):
dict_[key] = _update(dict_.get(key, {}), value)
else:
dict_[key] = value
return dict_

loaders = defaultdict(lambda: parse_yaml_for_update)
loaders.update({".toml": parse_toml_for_update})
dumpers = defaultdict(lambda: dump_yaml)
dumpers.update({".toml": dump_toml})

for params_fname in params:
path = PathInfo(self.exp_dvc.root_dir) / params_fname
with self.exp_dvc.tree.open(path, "r") as fobj:
text = fobj.read()
suffix = path.suffix.lower()
data = loaders[suffix](text, path)
_update(data, params[params_fname])
dumpers[suffix](path, data)

def _commit(self, exp_hash, check_exists=True, branch=True):
"""Commit stages as an experiment and return the commit SHA."""
if not self.scm.is_dirty():
Expand Down Expand Up @@ -207,23 +258,19 @@ def reproduce_queued(self, **kwargs):
)
return results

def new(self, *args, workspace=True, **kwargs):
def new(self, *args, **kwargs):
"""Create a new experiment.
Experiment will be reproduced and checked out into the user's
workspace.
"""
rev = self.repo.scm.get_rev()
self._scm_checkout(rev)
if workspace:
try:
stash_rev = self._stash_exp(*args, **kwargs)
except UnchangedExperimentError as exc:
logger.info("Reproducing existing experiment '%s'.", rev[:7])
raise exc
else:
# configure params via command line here
pass
try:
stash_rev = self._stash_exp(*args, **kwargs)
except UnchangedExperimentError as exc:
logger.info("Reproducing existing experiment '%s'.", rev[:7])
raise exc
logger.debug(
"Stashed experiment '%s' for future execution.", stash_rev[:7]
)
Expand Down Expand Up @@ -365,8 +412,10 @@ def checkout_exp(self, rev):
tmp = tempfile.NamedTemporaryFile(delete=False).name
self.scm.repo.head.commit.diff("HEAD~1", patch=True, output=tmp)

logger.debug("Stashing workspace changes.")
self.repo.scm.repo.git.stash("push")
dirty = self.repo.scm.is_dirty()
if dirty:
logger.debug("Stashing workspace changes.")
self.repo.scm.repo.git.stash("push")

try:
if os.path.getsize(tmp):
Expand All @@ -379,7 +428,8 @@ def checkout_exp(self, rev):
raise DvcException("failed to apply experiment changes.")
finally:
remove(tmp)
self._unstash_workspace()
if dirty:
self._unstash_workspace()

if need_checkout:
dvc_checkout(self.repo)
Expand Down
29 changes: 28 additions & 1 deletion dvc/repo/reproduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def reproduce(
recursive=False,
pipeline=False,
all_pipelines=False,
**kwargs
**kwargs,
):
from dvc.utils import parse_target

Expand All @@ -71,6 +71,7 @@ def reproduce(
)

experiment = kwargs.pop("experiment", False)
params = _parse_params(kwargs.pop("params", []))
queue = kwargs.pop("queue", False)
run_all = kwargs.pop("run_all", False)
jobs = kwargs.pop("jobs", 1)
Expand All @@ -81,6 +82,7 @@ def reproduce(
target=target,
recursive=recursive,
all_pipelines=all_pipelines,
params=params,
queue=queue,
run_all=run_all,
jobs=jobs,
Expand Down Expand Up @@ -116,6 +118,31 @@ def reproduce(
return _reproduce_stages(active_graph, targets, **kwargs)


def _parse_params(path_params):
from flatten_json import unflatten
from yaml import safe_load, YAMLError
from dvc.dependency.param import ParamsDependency

ret = {}
for path_param in path_params:
path, _, params_str = path_param.rpartition(":")
# remove empty strings from params, on condition such as `-p "file1:"`
params = {}
for param_str in filter(bool, params_str.split(",")):
try:
# interpret value strings using YAML rules
key, value = param_str.split("=")
params[key] = safe_load(value)
except (ValueError, YAMLError):
raise InvalidArgumentError(
f"Invalid param/value pair '{param_str}'"
)
if not path:
path = ParamsDependency.DEFAULT_PARAMS_FILE
ret[path] = unflatten(params, ".")
return ret


def _reproduce_experiments(repo, run_all=False, jobs=1, **kwargs):
if run_all:
return repo.experiments.reproduce_queued(jobs=jobs)
Expand Down
21 changes: 21 additions & 0 deletions dvc/utils/toml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import toml

from dvc.exceptions import TOMLFileCorruptedError


def parse_toml_for_update(text, path):
"""Parses text into Python structure.
NOTE: Python toml package does not currently use ordered dicts, so
keys may be re-ordered between load/dump, but this function will at
least preserve comments.
"""
try:
return toml.loads(text, decoder=toml.TomlPreserveCommentDecoder())
except toml.TomlDecodeError as exc:
raise TOMLFileCorruptedError(path) from exc


def dump_toml(path, data):
with open(path, "w", encoding="utf-8") as fobj:
toml.dump(data, fobj, encoder=toml.TomlPreserveCommentEncoder())
1 change: 1 addition & 0 deletions tests/unit/command/test_repro.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"recursive": False,
"force_downstream": False,
"experiment": False,
"params": [],
"queue": False,
"run_all": False,
"jobs": None,
Expand Down

0 comments on commit ab23bcd

Please sign in to comment.