Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove a special queued experiments #6393

Merged
merged 10 commits into from
Aug 13, 2021
78 changes: 58 additions & 20 deletions dvc/repo/experiments/remove.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import logging
from typing import List, Optional

from dvc.exceptions import InvalidArgumentError
from dvc.repo import locked
from dvc.repo.scm_context import scm_context
from dvc.scm.base import RevError

from .base import EXPS_NAMESPACE, ExpRefInfo
from .utils import exp_refs_by_name, remove_exp_refs
Expand All @@ -21,32 +23,43 @@ def remove(repo, exp_names=None, queue=False, **kwargs):
removed += len(repo.experiments.stash)
repo.experiments.stash.clear()
if exp_names:
ref_infos = list(_get_exp_refs(repo, exp_names))
remove_exp_refs(repo.scm, ref_infos)
removed += len(ref_infos)
remained = _remove_commited_exps(repo, exp_names)
remained = _remove_queued_exps(repo, remained)
if remained:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A behavior change here, because we had already removed all of the matched experiments, raise an exception here is needless.

raise InvalidArgumentError(
"'{}' is not a valid experiment".format(";".join(remained))
)
removed += len(exp_names) - len(remained)
return removed


def _get_exp_refs(repo, exp_names):
cur_rev = repo.scm.get_rev()
for name in exp_names:
if name.startswith(EXPS_NAMESPACE):
if not repo.scm.get_ref(name):
raise InvalidArgumentError(
f"'{name}' is not a valid experiment name"
)
yield ExpRefInfo.from_ref(name)
else:
def _get_exp_stash_index(repo, ref_or_rev: str) -> Optional[int]:
stash_revs = repo.experiments.stash_revs
for _, ref_info in stash_revs.items():
if ref_info.name == ref_or_rev:
return ref_info.index
try:
rev = repo.scm.resolve_rev(ref_or_rev)
if rev in stash_revs:
return stash_revs.get(rev).index
except RevError:
pass
return None

exp_refs = list(exp_refs_by_name(repo.scm, name))
if not exp_refs:
raise InvalidArgumentError(
f"'{name}' is not a valid experiment name"
)
yield _get_ref(exp_refs, name, cur_rev)

def _get_exp_ref(repo, exp_name: str) -> Optional[ExpRefInfo]:
cur_rev = repo.scm.get_rev()
if exp_name.startswith(EXPS_NAMESPACE):
if repo.scm.get_ref(exp_name):
return ExpRefInfo.from_ref(exp_name)
else:
exp_refs = list(exp_refs_by_name(repo.scm, exp_name))
if exp_refs:
return _get_ref(exp_refs, exp_name, cur_rev)
return None


def _get_ref(ref_infos, name, cur_rev):
def _get_ref(ref_infos, name, cur_rev) -> Optional[ExpRefInfo]:
if len(ref_infos) > 1:
for info in ref_infos:
if info.baseline_sha == cur_rev:
Expand All @@ -61,3 +74,28 @@ def _get_ref(ref_infos, name, cur_rev):
msg.extend([f"\t{info}" for info in ref_infos])
raise InvalidArgumentError("\n".join(msg))
return ref_infos[0]


def _remove_commited_exps(repo, refs: List[str]) -> List[str]:
remain_list = []
remove_list = []
for ref in refs:
ref_info = _get_exp_ref(repo, ref)
if ref_info:
remove_list.append(ref_info)
else:
remain_list.append(ref)
if remove_list:
remove_exp_refs(repo.scm, remove_list)
return remain_list


def _remove_queued_exps(repo, refs_or_revs: List[str]) -> List[str]:
remain_list = []
for ref_or_rev in refs_or_revs:
stash_index = _get_exp_stash_index(repo, ref_or_rev)
if stash_index is None:
remain_list.append(ref_or_rev)
else:
repo.experiments.stash.drop(stash_index)
return remain_list
15 changes: 0 additions & 15 deletions tests/func/experiments/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,21 +576,6 @@ def test_run_metrics(tmp_dir, scm, dvc, exp_stage, mocker):
assert show_mock.called_once()


def test_remove(tmp_dir, scm, dvc, exp_stage):
results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"])
exp = first(results)
ref_info = first(exp_refs_by_rev(scm, exp))
dvc.experiments.run(exp_stage.addressing, params=["foo=3"], queue=True)

removed = dvc.experiments.remove([str(ref_info)])
assert removed == 1
assert scm.get_ref(str(ref_info)) is None

removed = dvc.experiments.remove(queue=True)
assert removed == 1
assert len(dvc.experiments.stash) == 0


def test_checkout_targets_deps(tmp_dir, scm, dvc, exp_stage):
from dvc.utils.fs import remove

Expand Down
74 changes: 74 additions & 0 deletions tests/func/experiments/test_remove.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import pytest
from funcy import first

from dvc.exceptions import InvalidArgumentError
from dvc.repo.experiments.utils import exp_refs_by_rev


def test_remove_experiments_by_ref(tmp_dir, scm, dvc, exp_stage, caplog):
queue_length = 3
ref_list = []

dberenbaum marked this conversation as resolved.
Show resolved Hide resolved
for i in range(queue_length):
results = dvc.experiments.run(
exp_stage.addressing, params=[f"foo={i}"]
)
ref_info = first(exp_refs_by_rev(scm, first(results)))
ref_list.append(str(ref_info))

with pytest.raises(InvalidArgumentError):
assert dvc.experiments.remove(ref_list[:2] + ["non-exist"])
dberenbaum marked this conversation as resolved.
Show resolved Hide resolved
assert scm.get_ref(str(ref_list[0])) is None
assert scm.get_ref(str(ref_list[1])) is None
assert scm.get_ref(str(ref_list[2])) is not None


def test_remove_all_queued_experiments(tmp_dir, scm, dvc, exp_stage):
dberenbaum marked this conversation as resolved.
Show resolved Hide resolved
queue_length = 3

for i in range(queue_length):
dvc.experiments.run(
exp_stage.addressing, params=[f"foo={i}"], queue=True
)

results = dvc.experiments.run(
exp_stage.addressing, params=[f"foo={queue_length}"]
)
ref_info = first(exp_refs_by_rev(scm, first(results)))

assert len(dvc.experiments.stash) == queue_length
assert dvc.experiments.remove(queue=True) == queue_length
assert len(dvc.experiments.stash) == 0
assert scm.get_ref(str(ref_info)) is not None


def test_remove_special_queued_experiments(tmp_dir, scm, dvc, exp_stage):
dberenbaum marked this conversation as resolved.
Show resolved Hide resolved
results = dvc.experiments.run(
exp_stage.addressing, params=["foo=1"], queue=True, name="queue1"
)
rev1 = first(results)
results = dvc.experiments.run(
exp_stage.addressing, params=["foo=2"], queue=True, name="queue2"
)
rev2 = first(results)
results = dvc.experiments.run(
exp_stage.addressing, params=["foo=3"], queue=True, name="queue3"
)
rev3 = first(results)
results = dvc.experiments.run(exp_stage.addressing, params=["foo=4"])
ref_info1 = first(exp_refs_by_rev(scm, first(results)))
results = dvc.experiments.run(exp_stage.addressing, params=["foo=5"])
ref_info2 = first(exp_refs_by_rev(scm, first(results)))

assert rev1 in dvc.experiments.stash_revs
assert rev2 in dvc.experiments.stash_revs
assert rev3 in dvc.experiments.stash_revs
assert scm.get_ref(str(ref_info1)) is not None
assert scm.get_ref(str(ref_info2)) is not None

assert dvc.experiments.remove(["queue1", rev2[:5], str(ref_info1)]) == 3
assert rev1 not in dvc.experiments.stash_revs
assert rev2 not in dvc.experiments.stash_revs
assert rev3 in dvc.experiments.stash_revs
assert scm.get_ref(str(ref_info1)) is None
assert scm.get_ref(str(ref_info2)) is not None