From 4cc06339c8100d3356ab7f23f544a46faf287e40 Mon Sep 17 00:00:00 2001 From: Gao Date: Sun, 5 Sep 2021 10:06:07 +0800 Subject: [PATCH] Clean up remotes's exps (#6471) * Clean up remotes's exps fix #6006 1. add a new argument `--git-remote` to `dvc exp remove` 2. support remote a special remote exp 3. add tests for it 4. fix #6421 5. add a test for #6421. 6. extract some functions from the `dvc pull`, `dvc push`, `dvc remove` to utils. 7. add a tests for this new util function `resolve_exp_ref` 8. add `__init__.py` to some test package for the pytest fail Co-authored-by: Jorge Orpinel Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Dave Berenbaum --- dvc/command/experiments.py | 9 ++- dvc/repo/experiments/pull.py | 35 ++---------- dvc/repo/experiments/push.py | 36 ++---------- dvc/repo/experiments/remove.py | 69 +++++++++-------------- dvc/repo/experiments/utils.py | 42 ++++++++++++++ dvc/scm/git/backend/dulwich.py | 4 +- tests/dir_helpers.py | 22 ++++++++ tests/func/experiments/conftest.py | 18 ------ tests/func/experiments/test_remove.py | 26 +++++++++ tests/unit/command/test_experiments.py | 17 +++--- tests/unit/repo/experiments/__init__.py | 0 tests/unit/repo/experiments/test_utils.py | 27 +++++++++ tests/unit/stage/__init__.py | 0 13 files changed, 176 insertions(+), 129 deletions(-) create mode 100644 tests/unit/repo/experiments/__init__.py create mode 100644 tests/unit/repo/experiments/test_utils.py create mode 100644 tests/unit/stage/__init__.py diff --git a/dvc/command/experiments.py b/dvc/command/experiments.py index e1fa5d83f1..10c50691f6 100644 --- a/dvc/command/experiments.py +++ b/dvc/command/experiments.py @@ -750,6 +750,7 @@ def run(self): exp_names=self.args.experiment, queue=self.args.queue, clear_all=self.args.all, + remote=self.args.git_remote, ) return 0 @@ -1231,7 +1232,7 @@ def add_parser(subparsers, parent_parser): ) experiments_pull_parser.set_defaults(func=CmdExperimentsPull) - EXPERIMENTS_REMOVE_HELP = "Remove local experiments." + EXPERIMENTS_REMOVE_HELP = "Remove experiments." experiments_remove_parser = experiments_subparsers.add_parser( "remove", parents=[parent_parser], @@ -1249,6 +1250,12 @@ def add_parser(subparsers, parent_parser): action="store_true", help="Remove all committed experiments.", ) + remove_group.add_argument( + "-g", + "--git-remote", + metavar="", + help="Name or URL of the Git remote to remove the experiment from", + ) experiments_remove_parser.add_argument( "experiment", nargs="*", diff --git a/dvc/repo/experiments/pull.py b/dvc/repo/experiments/pull.py index 6a21fafccf..1b7f157e0e 100644 --- a/dvc/repo/experiments/pull.py +++ b/dvc/repo/experiments/pull.py @@ -4,7 +4,7 @@ from dvc.repo import locked from dvc.repo.scm_context import scm_context -from .utils import exp_commits, remote_exp_refs_by_name +from .utils import exp_commits, resolve_exp_ref logger = logging.getLogger(__name__) @@ -14,7 +14,11 @@ def pull( repo, git_remote, exp_name, *args, force=False, pull_cache=False, **kwargs ): - exp_ref = _get_exp_ref(repo, git_remote, exp_name) + exp_ref = resolve_exp_ref(repo.scm, exp_name, git_remote) + if not exp_ref: + raise InvalidArgumentError( + f"Experiment '{exp_name}' does not exist in '{git_remote}'" + ) def on_diverged(refname: str, rev: str) -> bool: if repo.scm.get_ref(refname) == rev: @@ -35,33 +39,6 @@ def on_diverged(refname: str, rev: str) -> bool: _pull_cache(repo, exp_ref, **kwargs) -def _get_exp_ref(repo, git_remote, exp_name): - if exp_name.startswith("refs/"): - return exp_name - - exp_refs = list(remote_exp_refs_by_name(repo.scm, git_remote, exp_name)) - if not exp_refs: - raise InvalidArgumentError( - f"Experiment '{exp_name}' does not exist in '{git_remote}'" - ) - if len(exp_refs) > 1: - cur_rev = repo.scm.get_rev() - for info in exp_refs: - if info.baseline_sha == cur_rev: - return info - msg = [ - ( - f"Ambiguous name '{exp_name}' refers to multiple " - "experiments in '{git_remote}'. Use full refname to pull one " - "of the following:" - ), - "", - ] - msg.extend([f"\t{info}" for info in exp_refs]) - raise InvalidArgumentError("\n".join(msg)) - return exp_refs[0] - - def _pull_cache(repo, exp_ref, dvc_remote=None, jobs=None, run_cache=False): revs = list(exp_commits(repo.scm, [exp_ref])) logger.debug("dvc fetch experiment '%s'", exp_ref) diff --git a/dvc/repo/experiments/push.py b/dvc/repo/experiments/push.py index 7b58682e33..d42c3736b2 100644 --- a/dvc/repo/experiments/push.py +++ b/dvc/repo/experiments/push.py @@ -2,10 +2,9 @@ from dvc.exceptions import DvcException, InvalidArgumentError from dvc.repo import locked -from dvc.repo.experiments.base import ExpRefInfo from dvc.repo.scm_context import scm_context -from .utils import exp_commits, exp_refs_by_name +from .utils import exp_commits, resolve_exp_ref logger = logging.getLogger(__name__) @@ -21,7 +20,11 @@ def push( push_cache=False, **kwargs, ): - exp_ref = _get_exp_ref(repo, exp_name) + exp_ref = resolve_exp_ref(repo.scm, exp_name) + if not exp_ref: + raise InvalidArgumentError( + f"'{exp_name}' is not a valid experiment name" + ) def on_diverged(refname: str, rev: str) -> bool: if repo.scm.get_ref(refname) == rev: @@ -42,33 +45,6 @@ def on_diverged(refname: str, rev: str) -> bool: _push_cache(repo, exp_ref, **kwargs) -def _get_exp_ref(repo, exp_name: str) -> ExpRefInfo: - if exp_name.startswith("refs/"): - return ExpRefInfo.from_ref(exp_name) - - exp_refs = list(exp_refs_by_name(repo.scm, exp_name)) - if not exp_refs: - raise InvalidArgumentError( - f"'{exp_name}' is not a valid experiment name" - ) - if len(exp_refs) > 1: - cur_rev = repo.scm.get_rev() - for info in exp_refs: - if info.baseline_sha == cur_rev: - return info - msg = [ - ( - f"Ambiguous name '{exp_name}' refers to multiple " - "experiments. Use full refname to push one of the " - "following:" - ), - "", - ] - msg.extend([f"\t{info}" for info in exp_refs]) - raise InvalidArgumentError("\n".join(msg)) - return exp_refs[0] - - def _push_cache(repo, exp_ref, dvc_remote=None, jobs=None, run_cache=False): revs = list(exp_commits(repo.scm, [exp_ref])) logger.debug("dvc push experiment '%s'", exp_ref) diff --git a/dvc/repo/experiments/remove.py b/dvc/repo/experiments/remove.py index 2df54728bd..c712572f5d 100644 --- a/dvc/repo/experiments/remove.py +++ b/dvc/repo/experiments/remove.py @@ -6,8 +6,7 @@ from dvc.repo.scm_context import scm_context from dvc.scm.base import RevError -from .base import EXPS_NAMESPACE, ExpRefInfo -from .utils import exp_refs, exp_refs_by_name, remove_exp_refs +from .utils import exp_refs, remove_exp_refs, resolve_exp_ref logger = logging.getLogger(__name__) @@ -19,6 +18,7 @@ def remove( exp_names=None, queue=False, clear_all=False, + remote=None, **kwargs, ): if not any([exp_names, queue, clear_all]): @@ -31,13 +31,7 @@ def remove( removed += _clear_all(repo) if exp_names: - remained = _remove_commited_exps(repo, exp_names) - remained = _remove_queued_exps(repo, remained) - if remained: - raise InvalidArgumentError( - "'{}' is not a valid experiment".format(";".join(remained)) - ) - removed += len(exp_names) - len(remained) + removed += _remove_exp_by_names(repo, remote, exp_names) return removed @@ -67,46 +61,24 @@ def _get_exp_stash_index(repo, ref_or_rev: str) -> Optional[int]: return None -def _get_exp_ref(repo, exp_name: str) -> Optional[ExpRefInfo]: - cur_rev = repo.scm.get_rev() - if exp_name.startswith(EXPS_NAMESPACE): - if repo.scm.get_ref(exp_name): - return ExpRefInfo.from_ref(exp_name) - else: - exp_ref_list = list(exp_refs_by_name(repo.scm, exp_name)) - if exp_ref_list: - return _get_ref(exp_ref_list, exp_name, cur_rev) - return None - - -def _get_ref(ref_infos, name, cur_rev) -> Optional[ExpRefInfo]: - if len(ref_infos) > 1: - for info in ref_infos: - if info.baseline_sha == cur_rev: - return info - msg = [ - ( - f"Ambiguous name '{name}' refers to multiple " - "experiments. Use full refname to remove one of " - "the following:" - ) - ] - msg.extend([f"\t{info}" for info in ref_infos]) - raise InvalidArgumentError("\n".join(msg)) - return ref_infos[0] - - -def _remove_commited_exps(repo, refs: List[str]) -> List[str]: +def _remove_commited_exps( + repo, remote: Optional[str], exp_names: List[str] +) -> List[str]: remain_list = [] remove_list = [] - for ref in refs: - ref_info = _get_exp_ref(repo, ref) + for exp_name in exp_names: + ref_info = resolve_exp_ref(repo.scm, exp_name, remote) + if ref_info: remove_list.append(ref_info) else: - remain_list.append(ref) + remain_list.append(exp_name) if remove_list: - remove_exp_refs(repo.scm, remove_list) + if not remote: + remove_exp_refs(repo.scm, remove_list) + else: + for ref_info in remove_list: + repo.scm.push_refspec(remote, None, str(ref_info)) return remain_list @@ -119,3 +91,14 @@ def _remove_queued_exps(repo, refs_or_revs: List[str]) -> List[str]: else: repo.experiments.stash.drop(stash_index) return remain_list + + +def _remove_exp_by_names(repo, remote, exp_names: List[str]) -> int: + remained = _remove_commited_exps(repo, remote, exp_names) + if not remote: + remained = _remove_queued_exps(repo, remained) + if remained: + raise InvalidArgumentError( + "'{}' is not a valid experiment".format(";".join(remained)) + ) + return len(exp_names) - len(remained) diff --git a/dvc/repo/experiments/utils.py b/dvc/repo/experiments/utils.py index e45bc418a0..c2a8f9cb05 100644 --- a/dvc/repo/experiments/utils.py +++ b/dvc/repo/experiments/utils.py @@ -1,5 +1,6 @@ from typing import Generator, Iterable, Optional, Set +from dvc.exceptions import InvalidArgumentError from dvc.scm.git import Git from .base import ( @@ -115,3 +116,44 @@ def fix_exp_head(scm: "Git", ref: Optional[str]) -> Optional[str]: if name == "HEAD" and scm.get_ref(EXEC_BASELINE): return "".join((EXEC_BASELINE, tail)) return ref + + +def resolve_exp_ref( + scm, exp_name: str, git_remote: Optional[str] = None +) -> Optional[ExpRefInfo]: + if exp_name.startswith("refs/"): + return ExpRefInfo.from_ref(exp_name) + + if git_remote: + exp_ref_list = list(remote_exp_refs_by_name(scm, git_remote, exp_name)) + else: + exp_ref_list = list(exp_refs_by_name(scm, exp_name)) + + if not exp_ref_list: + return None + if len(exp_ref_list) > 1: + cur_rev = scm.get_rev() + for info in exp_ref_list: + if info.baseline_sha == cur_rev: + return info + if git_remote: + msg = [ + ( + f"Ambiguous name '{exp_name}' refers to multiple " + "experiments. Use full refname to push one of the " + "following:" + ), + "", + ] + else: + msg = [ + ( + f"Ambiguous name '{exp_name}' refers to multiple " + f"experiments in '{git_remote}'. Use full refname to pull " + "one of the following:" + ), + "", + ] + msg.extend([f"\t{info}" for info in exp_ref_list]) + raise InvalidArgumentError("\n".join(msg)) + return exp_ref_list[0] diff --git a/dvc/scm/git/backend/dulwich.py b/dvc/scm/git/backend/dulwich.py index 1fbea38aa0..91327206bb 100644 --- a/dvc/scm/git/backend/dulwich.py +++ b/dvc/scm/git/backend/dulwich.py @@ -407,9 +407,11 @@ def push_refspec( ) from exc def update_refs(refs): + from dulwich.objects import ZERO_SHA + new_refs = {} for ref, value in zip(dest_refs, values): - if ref in refs: + if ref in refs and value != ZERO_SHA: local_sha = self.repo.refs[ref] remote_sha = refs[ref] try: diff --git a/tests/dir_helpers.py b/tests/dir_helpers.py index 92609e3cd9..f50ac02fc8 100644 --- a/tests/dir_helpers.py +++ b/tests/dir_helpers.py @@ -67,6 +67,8 @@ "erepo_dir", "git_dir", "git_init", + "git_upstream", + "git_downstream", ] @@ -407,3 +409,23 @@ def git_dir(make_tmp_dir): path = make_tmp_dir("git-erepo", scm=True) path.scm.commit("init repo") return path + + +@pytest.fixture +def git_upstream(tmp_dir, erepo_dir, git_dir, request): + remote = erepo_dir if "dvc" in request.fixturenames else git_dir + url = "file://{}".format(remote.resolve().as_posix()) + tmp_dir.scm.gitpython.repo.create_remote("upstream", url) + remote.remote = "upstream" + remote.url = url + return remote + + +@pytest.fixture +def git_downstream(tmp_dir, erepo_dir, git_dir, request): + remote = erepo_dir if "dvc" in request.fixturenames else git_dir + url = "file://{}".format(tmp_dir.resolve().as_posix()) + remote.scm.gitpython.repo.create_remote("upstream", url) + remote.remote = "upstream" + remote.url = url + return remote diff --git a/tests/func/experiments/conftest.py b/tests/func/experiments/conftest.py index fb592ae152..455aa687c4 100644 --- a/tests/func/experiments/conftest.py +++ b/tests/func/experiments/conftest.py @@ -88,21 +88,3 @@ def checkpoint_stage(tmp_dir, scm, dvc, mocker): scm.commit("init") stage.iterations = DEFAULT_ITERATIONS return stage - - -@pytest.fixture -def git_upstream(tmp_dir, erepo_dir): - url = "file://{}".format(erepo_dir.resolve().as_posix()) - tmp_dir.scm.gitpython.repo.create_remote("upstream", url) - erepo_dir.remote = "upstream" - erepo_dir.url = url - return erepo_dir - - -@pytest.fixture -def git_downstream(tmp_dir, erepo_dir): - url = "file://{}".format(tmp_dir.resolve().as_posix()) - erepo_dir.scm.gitpython.repo.create_remote("upstream", url) - erepo_dir.remote = "upstream" - erepo_dir.url = url - return erepo_dir diff --git a/tests/func/experiments/test_remove.py b/tests/func/experiments/test_remove.py index 909561a080..f240785d44 100644 --- a/tests/func/experiments/test_remove.py +++ b/tests/func/experiments/test_remove.py @@ -90,3 +90,29 @@ def test_remove_all(tmp_dir, scm, dvc, exp_stage): assert len(dvc.experiments.stash) == 2 assert scm.get_ref(str(ref_info2)) is None assert scm.get_ref(str(ref_info1)) is None + + +@pytest.mark.parametrize("use_url", [True, False]) +def test_remove_remote(tmp_dir, scm, dvc, exp_stage, git_upstream, use_url): + remote = git_upstream.url if use_url else git_upstream.remote + + ref_info_list = [] + exp_list = [] + for i in range(3): + results = dvc.experiments.run( + exp_stage.addressing, params=[f"foo={i}"] + ) + exp = first(results) + exp_list.append(exp) + ref_info = first(exp_refs_by_rev(scm, exp)) + ref_info_list.append(ref_info) + dvc.experiments.push(remote, ref_info.name) + assert git_upstream.scm.get_ref(str(ref_info)) == exp + + dvc.experiments.remove( + remote=remote, exp_names=[str(ref_info_list[0]), ref_info_list[1].name] + ) + + assert git_upstream.scm.get_ref(str(ref_info_list[0])) is None + assert git_upstream.scm.get_ref(str(ref_info_list[1])) is None + assert git_upstream.scm.get_ref(str(ref_info_list[2])) == exp_list[2] diff --git a/tests/unit/command/test_experiments.py b/tests/unit/command/test_experiments.py index d8043af94b..d843a49b53 100644 --- a/tests/unit/command/test_experiments.py +++ b/tests/unit/command/test_experiments.py @@ -253,15 +253,17 @@ def test_experiments_pull(dvc, scm, mocker): @pytest.mark.parametrize( - "queue,clear_all", - [(True, False), (False, True)], + "queue,clear_all,remote", + [(True, False, None), (False, True, None), (False, False, True)], ) -def test_experiments_remove(dvc, scm, mocker, queue, clear_all): +def test_experiments_remove(dvc, scm, mocker, queue, clear_all, remote): if queue: - args = "--queue" + args = ["--queue"] if clear_all: - args = "--all" - cli_args = parse_args(["experiments", "remove", args]) + args = ["--all"] + if remote: + args = ["--git-remote", "myremote", "exp-123", "exp-234"] + cli_args = parse_args(["experiments", "remove"] + args) assert cli_args.func == CmdExperimentsRemove cmd = cli_args.func(cli_args) @@ -270,7 +272,8 @@ def test_experiments_remove(dvc, scm, mocker, queue, clear_all): assert cmd.run() == 0 m.assert_called_once_with( cmd.repo, - exp_names=[], + exp_names=["exp-123", "exp-234"] if remote else [], queue=queue, clear_all=clear_all, + remote="myremote" if remote else None, ) diff --git a/tests/unit/repo/experiments/__init__.py b/tests/unit/repo/experiments/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/repo/experiments/test_utils.py b/tests/unit/repo/experiments/test_utils.py new file mode 100644 index 0000000000..b0fd3da808 --- /dev/null +++ b/tests/unit/repo/experiments/test_utils.py @@ -0,0 +1,27 @@ +import pytest + +from dvc.repo.experiments.base import EXPS_NAMESPACE, ExpRefInfo +from dvc.repo.experiments.utils import resolve_exp_ref + + +def commit_exp_ref(tmp_dir, scm, file="foo", contents="foo", name="foo"): + tmp_dir.scm_gen(file, contents, commit="init") + rev = scm.get_rev() + ref = "/".join([EXPS_NAMESPACE, "ab", "c123", name]) + scm.gitpython.set_ref(ref, rev) + return ref, rev + + +@pytest.mark.parametrize("use_url", [True, False]) +@pytest.mark.parametrize("name_only", [True, False]) +def test_resolve_exp_ref(tmp_dir, scm, git_upstream, name_only, use_url): + ref, _ = commit_exp_ref(tmp_dir, scm) + ref_info = resolve_exp_ref(scm, "foo" if name_only else ref) + assert isinstance(ref_info, ExpRefInfo) + assert str(ref_info) == ref + + scm.push_refspec(git_upstream.url, ref, ref) + remote = git_upstream.url if use_url else git_upstream.remote + remote_ref_info = resolve_exp_ref(scm, "foo" if name_only else ref, remote) + assert isinstance(remote_ref_info, ExpRefInfo) + assert str(remote_ref_info) == ref diff --git a/tests/unit/stage/__init__.py b/tests/unit/stage/__init__.py new file mode 100644 index 0000000000..e69de29bb2