Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

exp: Fix --rev args for ls pull and remove. #9483

Merged
merged 1 commit into from
May 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions dvc/repo/experiments/ls.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from collections import defaultdict
from typing import Optional
from typing import List, Optional, Union

from dvc.repo import locked
from dvc.repo.scm_context import scm_context
Expand All @@ -15,14 +15,17 @@
@scm_context
def ls(
repo,
rev: Optional[str] = None,
rev: Optional[Union[List[str], str]] = None,
all_commits: bool = False,
num: int = 1,
git_remote: Optional[str] = None,
):
rev_set = None
if not all_commits:
revs = iter_revs(repo.scm, [rev or "HEAD"], num)
rev = rev or "HEAD"
if isinstance(rev, str):
rev = [rev]
revs = iter_revs(repo.scm, rev, num)
rev_set = set(revs.keys())
ref_info_dict = exp_refs_by_baseline(repo.scm, rev_set, git_remote)

Expand Down
6 changes: 4 additions & 2 deletions dvc/repo/experiments/pull.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def pull( # noqa: C901
git_remote: str,
exp_names: Union[Iterable[str], str],
all_commits=False,
rev: Optional[str] = None,
rev: Optional[Union[List[str], str]] = None,
num=1,
force: bool = False,
pull_cache: bool = False,
Expand All @@ -49,7 +49,9 @@ def pull( # noqa: C901
raise UnresolvedExpNamesError(unresolved_exp_names)

if rev:
rev_dict = iter_revs(repo.scm, [rev], num)
if isinstance(rev, str):
rev = [rev]
rev_dict = iter_revs(repo.scm, rev, num)
rev_set = set(rev_dict.keys())
ref_info_dict = exp_refs_by_baseline(repo.scm, rev_set, git_remote)
for _, ref_info_list in ref_info_dict.items():
Expand Down
8 changes: 5 additions & 3 deletions dvc/repo/experiments/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
def remove( # noqa: C901, PLR0912
repo: "Repo",
exp_names: Union[None, str, List[str]] = None,
rev: Optional[str] = None,
rev: Optional[Union[List[str], str]] = None,
all_commits: bool = False,
num: int = 1,
queue: bool = False,
Expand Down Expand Up @@ -60,6 +60,8 @@ def remove( # noqa: C901, PLR0912
if remained:
raise UnresolvedExpNamesError(remained, git_remote=git_remote)
elif rev:
if isinstance(rev, str):
rev = [rev]
exp_ref_dict = _resolve_exp_by_baseline(repo, rev, num, git_remote)
removed.extend(exp_ref_dict.keys())
exp_ref_list.extend(exp_ref_dict.values())
Expand All @@ -85,14 +87,14 @@ def remove( # noqa: C901, PLR0912

def _resolve_exp_by_baseline(
repo: "Repo",
rev: str,
rev: List[str],
num: int,
git_remote: Optional[str] = None,
) -> Dict[str, "ExpRefInfo"]:
assert isinstance(repo.scm, Git)

commit_ref_dict: Dict[str, "ExpRefInfo"] = {}
rev_dict = iter_revs(repo.scm, [rev], num)
rev_dict = iter_revs(repo.scm, rev, num)
rev_set = set(rev_dict.keys())
ref_info_dict = exp_refs_by_baseline(repo.scm, rev_set, git_remote)
for _, ref_info_list in ref_info_dict.items():
Expand Down
6 changes: 6 additions & 0 deletions tests/func/experiments/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,12 @@ def test_list(tmp_dir, scm, dvc, exp_stage):
baseline_a[:7]: {ref_info_a.name, ref_info_b.name}
}

exp_list = dvc.experiments.ls(rev=[baseline_a, scm.get_rev()])
assert {key: set(val) for key, val in exp_list.items()} == {
baseline_a[:7]: {ref_info_a.name, ref_info_b.name},
"master": {ref_info_c.name},
}

exp_list = dvc.experiments.ls(all_commits=True)
assert {key: set(val) for key, val in exp_list.items()} == {
baseline_a[:7]: {ref_info_a.name, ref_info_b.name},
Expand Down
24 changes: 24 additions & 0 deletions tests/func/experiments/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,30 @@ def test_pull_args(tmp_dir, scm, dvc, git_downstream, exp_stage, all_, rev, resu
assert git_downstream.tmp_dir.scm.get_ref(str(ref_info3)) == result3


def test_pull_multi_rev(tmp_dir, scm, dvc, git_downstream, exp_stage):
baseline = scm.get_rev()

results = dvc.experiments.run(exp_stage.addressing, params=["foo=1"])
exp1 = first(results)
ref_info1 = first(exp_refs_by_rev(scm, exp1))
results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"])
exp2 = first(results)
ref_info2 = first(exp_refs_by_rev(scm, exp2))

scm.commit("new_baseline")

results = dvc.experiments.run(exp_stage.addressing, params=["foo=3"])
exp3 = first(results)
ref_info3 = first(exp_refs_by_rev(scm, exp3))

downstream_exp = git_downstream.tmp_dir.dvc.experiments
git_downstream.tmp_dir.scm.fetch_refspecs(str(tmp_dir), ["master:master"])
downstream_exp.pull(git_downstream.remote, [], rev=[baseline, scm.get_rev()])
assert git_downstream.tmp_dir.scm.get_ref(str(ref_info1)) == exp1
assert git_downstream.tmp_dir.scm.get_ref(str(ref_info2)) == exp2
assert git_downstream.tmp_dir.scm.get_ref(str(ref_info3)) == exp3


def test_pull_diverged(tmp_dir, scm, dvc, git_downstream, exp_stage):
git_downstream.tmp_dir.scm_gen("foo", "foo", commit="init")
remote_rev = git_downstream.tmp_dir.scm.get_rev()
Expand Down
23 changes: 23 additions & 0 deletions tests/func/experiments/test_remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,26 @@ def test_remove_experiments_by_rev(tmp_dir, scm, dvc, exp_stage):
assert "queue2" in queue_revs
assert scm.get_ref(new_exp_ref) is not None
assert "queue4" in queue_revs


def test_remove_multi_rev(tmp_dir, scm, dvc, exp_stage):
baseline = scm.get_rev()

results = dvc.experiments.run(exp_stage.addressing, params=["foo=1"])
baseline_exp_ref = first(exp_refs_by_rev(scm, first(results)))

dvc.experiments.run(
exp_stage.addressing, params=["foo=2"], queue=True, name="queue2"
)
scm.commit("new_baseline")

results = dvc.experiments.run(exp_stage.addressing, params=["foo=3"])
new_exp_ref = first(exp_refs_by_rev(scm, first(results)))

assert set(dvc.experiments.remove(rev=[baseline, scm.get_rev()])) == {
baseline_exp_ref.name,
new_exp_ref.name,
}

assert scm.get_ref(str(baseline_exp_ref)) is None
assert scm.get_ref(str(new_exp_ref)) is None