Skip to content
/ dvc Public
forked from iterative/dvc

Commit

Permalink
exp: Fix --rev args for ls pull and remove.
Browse files Browse the repository at this point in the history
  • Loading branch information
daavoo committed May 21, 2023
1 parent e089ee0 commit 3ddd4b8
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 8 deletions.
9 changes: 6 additions & 3 deletions dvc/repo/experiments/ls.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from collections import defaultdict
from typing import Optional
from typing import List, Optional, Union

from dvc.repo import locked
from dvc.repo.scm_context import scm_context
Expand All @@ -15,14 +15,17 @@
@scm_context
def ls(
repo,
rev: Optional[str] = None,
rev: Optional[Union[List[str], str]] = None,
all_commits: bool = False,
num: int = 1,
git_remote: Optional[str] = None,
):
rev_set = None
if not all_commits:
revs = iter_revs(repo.scm, [rev or "HEAD"], num)
rev = rev or "HEAD"
if isinstance(rev, str):
rev = [rev]
revs = iter_revs(repo.scm, rev, num)
rev_set = set(revs.keys())
ref_info_dict = exp_refs_by_baseline(repo.scm, rev_set, git_remote)

Expand Down
6 changes: 4 additions & 2 deletions dvc/repo/experiments/pull.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def pull( # noqa: C901
git_remote: str,
exp_names: Union[Iterable[str], str],
all_commits=False,
rev: Optional[str] = None,
rev: Optional[Union[List[str], str]] = None,
num=1,
force: bool = False,
pull_cache: bool = False,
Expand All @@ -49,7 +49,9 @@ def pull( # noqa: C901
raise UnresolvedExpNamesError(unresolved_exp_names)

if rev:
rev_dict = iter_revs(repo.scm, [rev], num)
if isinstance(rev, str):
rev = [rev]
rev_dict = iter_revs(repo.scm, rev, num)
rev_set = set(rev_dict.keys())
ref_info_dict = exp_refs_by_baseline(repo.scm, rev_set, git_remote)
for _, ref_info_list in ref_info_dict.items():
Expand Down
8 changes: 5 additions & 3 deletions dvc/repo/experiments/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
def remove( # noqa: C901, PLR0912
repo: "Repo",
exp_names: Union[None, str, List[str]] = None,
rev: Optional[str] = None,
rev: Optional[Union[List[str], str]] = None,
all_commits: bool = False,
num: int = 1,
queue: bool = False,
Expand Down Expand Up @@ -60,6 +60,8 @@ def remove( # noqa: C901, PLR0912
if remained:
raise UnresolvedExpNamesError(remained, git_remote=git_remote)
elif rev:
if isinstance(rev, str):
rev = [rev]
exp_ref_dict = _resolve_exp_by_baseline(repo, rev, num, git_remote)
removed.extend(exp_ref_dict.keys())
exp_ref_list.extend(exp_ref_dict.values())
Expand All @@ -85,14 +87,14 @@ def remove( # noqa: C901, PLR0912

def _resolve_exp_by_baseline(
repo: "Repo",
rev: str,
rev: List[str],
num: int,
git_remote: Optional[str] = None,
) -> Dict[str, "ExpRefInfo"]:
assert isinstance(repo.scm, Git)

commit_ref_dict: Dict[str, "ExpRefInfo"] = {}
rev_dict = iter_revs(repo.scm, [rev], num)
rev_dict = iter_revs(repo.scm, rev, num)
rev_set = set(rev_dict.keys())
ref_info_dict = exp_refs_by_baseline(repo.scm, rev_set, git_remote)
for _, ref_info_list in ref_info_dict.items():
Expand Down
6 changes: 6 additions & 0 deletions tests/func/experiments/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,12 @@ def test_list(tmp_dir, scm, dvc, exp_stage):
baseline_a[:7]: {ref_info_a.name, ref_info_b.name}
}

exp_list = dvc.experiments.ls(rev=[baseline_a, scm.get_rev()])
assert {key: set(val) for key, val in exp_list.items()} == {
baseline_a[:7]: {ref_info_a.name, ref_info_b.name},
"master": {ref_info_c.name},
}

exp_list = dvc.experiments.ls(all_commits=True)
assert {key: set(val) for key, val in exp_list.items()} == {
baseline_a[:7]: {ref_info_a.name, ref_info_b.name},
Expand Down
24 changes: 24 additions & 0 deletions tests/func/experiments/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,30 @@ def test_pull_args(tmp_dir, scm, dvc, git_downstream, exp_stage, all_, rev, resu
assert git_downstream.tmp_dir.scm.get_ref(str(ref_info3)) == result3


def test_pull_multi_rev(tmp_dir, scm, dvc, git_downstream, exp_stage):
baseline = scm.get_rev()

results = dvc.experiments.run(exp_stage.addressing, params=["foo=1"])
exp1 = first(results)
ref_info1 = first(exp_refs_by_rev(scm, exp1))
results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"])
exp2 = first(results)
ref_info2 = first(exp_refs_by_rev(scm, exp2))

scm.commit("new_baseline")

results = dvc.experiments.run(exp_stage.addressing, params=["foo=3"])
exp3 = first(results)
ref_info3 = first(exp_refs_by_rev(scm, exp3))

downstream_exp = git_downstream.tmp_dir.dvc.experiments
git_downstream.tmp_dir.scm.fetch_refspecs(str(tmp_dir), ["master:master"])
downstream_exp.pull(git_downstream.remote, [], rev=[baseline, scm.get_rev()])
assert git_downstream.tmp_dir.scm.get_ref(str(ref_info1)) == exp1
assert git_downstream.tmp_dir.scm.get_ref(str(ref_info2)) == exp2
assert git_downstream.tmp_dir.scm.get_ref(str(ref_info3)) == exp3


def test_pull_diverged(tmp_dir, scm, dvc, git_downstream, exp_stage):
git_downstream.tmp_dir.scm_gen("foo", "foo", commit="init")
remote_rev = git_downstream.tmp_dir.scm.get_rev()
Expand Down
23 changes: 23 additions & 0 deletions tests/func/experiments/test_remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,26 @@ def test_remove_experiments_by_rev(tmp_dir, scm, dvc, exp_stage):
assert "queue2" in queue_revs
assert scm.get_ref(new_exp_ref) is not None
assert "queue4" in queue_revs


def test_remove_multi_rev(tmp_dir, scm, dvc, exp_stage):
baseline = scm.get_rev()

results = dvc.experiments.run(exp_stage.addressing, params=["foo=1"])
baseline_exp_ref = first(exp_refs_by_rev(scm, first(results)))

dvc.experiments.run(
exp_stage.addressing, params=["foo=2"], queue=True, name="queue2"
)
scm.commit("new_baseline")

results = dvc.experiments.run(exp_stage.addressing, params=["foo=3"])
new_exp_ref = first(exp_refs_by_rev(scm, first(results)))

assert set(dvc.experiments.remove(rev=[baseline, scm.get_rev()])) == {
baseline_exp_ref.name,
new_exp_ref.name,
}

assert scm.get_ref(str(baseline_exp_ref)) is None
assert scm.get_ref(str(new_exp_ref)) is None

0 comments on commit 3ddd4b8

Please sign in to comment.