From 1a43d70d8dfd7b15a2530139b56344d4d5019df5 Mon Sep 17 00:00:00 2001 From: daavoo Date: Sun, 21 May 2023 12:36:40 +0200 Subject: [PATCH] exp: Fix --rev args for ls pull and remove. Closes #9471 --- dvc/repo/experiments/ls.py | 9 +++++--- dvc/repo/experiments/pull.py | 6 ++++-- dvc/repo/experiments/remove.py | 8 +++++--- tests/func/experiments/test_experiments.py | 6 ++++++ tests/func/experiments/test_remote.py | 24 ++++++++++++++++++++++ tests/func/experiments/test_remove.py | 23 +++++++++++++++++++++ 6 files changed, 68 insertions(+), 8 deletions(-) diff --git a/dvc/repo/experiments/ls.py b/dvc/repo/experiments/ls.py index b3f0a6fb4f..c488d19211 100644 --- a/dvc/repo/experiments/ls.py +++ b/dvc/repo/experiments/ls.py @@ -1,6 +1,6 @@ import logging from collections import defaultdict -from typing import Optional +from typing import List, Optional, Union from dvc.repo import locked from dvc.repo.scm_context import scm_context @@ -15,14 +15,17 @@ @scm_context def ls( repo, - rev: Optional[str] = None, + rev: Optional[Union[List[str], str]] = None, all_commits: bool = False, num: int = 1, git_remote: Optional[str] = None, ): rev_set = None if not all_commits: - revs = iter_revs(repo.scm, [rev or "HEAD"], num) + rev = rev or "HEAD" + if isinstance(rev, str): + rev = [rev] + revs = iter_revs(repo.scm, rev, num) rev_set = set(revs.keys()) ref_info_dict = exp_refs_by_baseline(repo.scm, rev_set, git_remote) diff --git a/dvc/repo/experiments/pull.py b/dvc/repo/experiments/pull.py index ca76eea18f..2d4624c4f8 100644 --- a/dvc/repo/experiments/pull.py +++ b/dvc/repo/experiments/pull.py @@ -23,7 +23,7 @@ def pull( # noqa: C901 git_remote: str, exp_names: Union[Iterable[str], str], all_commits=False, - rev: Optional[str] = None, + rev: Optional[Union[List[str], str]] = None, num=1, force: bool = False, pull_cache: bool = False, @@ -49,7 +49,9 @@ def pull( # noqa: C901 raise UnresolvedExpNamesError(unresolved_exp_names) if rev: - rev_dict = iter_revs(repo.scm, [rev], num) + if isinstance(rev, str): + rev = [rev] + rev_dict = iter_revs(repo.scm, rev, num) rev_set = set(rev_dict.keys()) ref_info_dict = exp_refs_by_baseline(repo.scm, rev_set, git_remote) for _, ref_info_list in ref_info_dict.items(): diff --git a/dvc/repo/experiments/remove.py b/dvc/repo/experiments/remove.py index a708f663fd..76996e83a6 100644 --- a/dvc/repo/experiments/remove.py +++ b/dvc/repo/experiments/remove.py @@ -24,7 +24,7 @@ def remove( # noqa: C901, PLR0912 repo: "Repo", exp_names: Union[None, str, List[str]] = None, - rev: Optional[str] = None, + rev: Optional[Union[List[str], str]] = None, all_commits: bool = False, num: int = 1, queue: bool = False, @@ -60,6 +60,8 @@ def remove( # noqa: C901, PLR0912 if remained: raise UnresolvedExpNamesError(remained, git_remote=git_remote) elif rev: + if isinstance(rev, str): + rev = [rev] exp_ref_dict = _resolve_exp_by_baseline(repo, rev, num, git_remote) removed.extend(exp_ref_dict.keys()) exp_ref_list.extend(exp_ref_dict.values()) @@ -85,14 +87,14 @@ def remove( # noqa: C901, PLR0912 def _resolve_exp_by_baseline( repo: "Repo", - rev: str, + rev: List[str], num: int, git_remote: Optional[str] = None, ) -> Dict[str, "ExpRefInfo"]: assert isinstance(repo.scm, Git) commit_ref_dict: Dict[str, "ExpRefInfo"] = {} - rev_dict = iter_revs(repo.scm, [rev], num) + rev_dict = iter_revs(repo.scm, rev, num) rev_set = set(rev_dict.keys()) ref_info_dict = exp_refs_by_baseline(repo.scm, rev_set, git_remote) for _, ref_info_list in ref_info_dict.items(): diff --git a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py index 2d5ab04883..7fd34c788b 100644 --- a/tests/func/experiments/test_experiments.py +++ b/tests/func/experiments/test_experiments.py @@ -369,6 +369,12 @@ def test_list(tmp_dir, scm, dvc, exp_stage): baseline_a[:7]: {ref_info_a.name, ref_info_b.name} } + exp_list = dvc.experiments.ls(rev=[baseline_a, scm.get_rev()]) + assert {key: set(val) for key, val in exp_list.items()} == { + baseline_a[:7]: {ref_info_a.name, ref_info_b.name}, + "master": {ref_info_c.name}, + } + exp_list = dvc.experiments.ls(all_commits=True) assert {key: set(val) for key, val in exp_list.items()} == { baseline_a[:7]: {ref_info_a.name, ref_info_b.name}, diff --git a/tests/func/experiments/test_remote.py b/tests/func/experiments/test_remote.py index 3639312258..2f09721f35 100644 --- a/tests/func/experiments/test_remote.py +++ b/tests/func/experiments/test_remote.py @@ -251,6 +251,30 @@ def test_pull_args(tmp_dir, scm, dvc, git_downstream, exp_stage, all_, rev, resu assert git_downstream.tmp_dir.scm.get_ref(str(ref_info3)) == result3 +def test_pull_multi_rev(tmp_dir, scm, dvc, git_downstream, exp_stage): + baseline = scm.get_rev() + + results = dvc.experiments.run(exp_stage.addressing, params=["foo=1"]) + exp1 = first(results) + ref_info1 = first(exp_refs_by_rev(scm, exp1)) + results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"]) + exp2 = first(results) + ref_info2 = first(exp_refs_by_rev(scm, exp2)) + + scm.commit("new_baseline") + + results = dvc.experiments.run(exp_stage.addressing, params=["foo=3"]) + exp3 = first(results) + ref_info3 = first(exp_refs_by_rev(scm, exp3)) + + downstream_exp = git_downstream.tmp_dir.dvc.experiments + git_downstream.tmp_dir.scm.fetch_refspecs(str(tmp_dir), ["master:master"]) + downstream_exp.pull(git_downstream.remote, [], rev=[baseline, scm.get_rev()]) + assert git_downstream.tmp_dir.scm.get_ref(str(ref_info1)) == exp1 + assert git_downstream.tmp_dir.scm.get_ref(str(ref_info2)) == exp2 + assert git_downstream.tmp_dir.scm.get_ref(str(ref_info3)) == exp3 + + def test_pull_diverged(tmp_dir, scm, dvc, git_downstream, exp_stage): git_downstream.tmp_dir.scm_gen("foo", "foo", commit="init") remote_rev = git_downstream.tmp_dir.scm.get_rev() diff --git a/tests/func/experiments/test_remove.py b/tests/func/experiments/test_remove.py index 51d2c9e241..60d653d1e8 100644 --- a/tests/func/experiments/test_remove.py +++ b/tests/func/experiments/test_remove.py @@ -156,3 +156,26 @@ def test_remove_experiments_by_rev(tmp_dir, scm, dvc, exp_stage): assert "queue2" in queue_revs assert scm.get_ref(new_exp_ref) is not None assert "queue4" in queue_revs + + +def test_remove_multi_rev(tmp_dir, scm, dvc, exp_stage): + baseline = scm.get_rev() + + results = dvc.experiments.run(exp_stage.addressing, params=["foo=1"]) + baseline_exp_ref = first(exp_refs_by_rev(scm, first(results))) + + dvc.experiments.run( + exp_stage.addressing, params=["foo=2"], queue=True, name="queue2" + ) + scm.commit("new_baseline") + + results = dvc.experiments.run(exp_stage.addressing, params=["foo=3"]) + new_exp_ref = first(exp_refs_by_rev(scm, first(results))) + + assert set(dvc.experiments.remove(rev=[baseline, scm.get_rev()])) == { + baseline_exp_ref.name, + new_exp_ref.name, + } + + assert scm.get_ref(str(baseline_exp_ref)) is None + assert scm.get_ref(str(new_exp_ref)) is None