From 59165f27cde956c6654641d649c1a34f5cc8a4ac Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Mon, 23 Aug 2021 18:43:55 +0800 Subject: [PATCH 01/18] Clean up remotes's exps fix #6006 1. add a new argument `--git-remote` to `dvc exp remove` 2. add some tests for it --- dvc/command/experiments.py | 7 +++++++ dvc/repo/experiments/remove.py | 15 +++++++++++++-- dvc/scm/git/backend/dulwich.py | 4 +++- tests/func/experiments/test_remove.py | 24 ++++++++++++++++++++++++ tests/unit/command/test_experiments.py | 15 +++++++++------ 5 files changed, 56 insertions(+), 9 deletions(-) diff --git a/dvc/command/experiments.py b/dvc/command/experiments.py index e1fa5d83f1..eafecda500 100644 --- a/dvc/command/experiments.py +++ b/dvc/command/experiments.py @@ -750,6 +750,7 @@ def run(self): exp_names=self.args.experiment, queue=self.args.queue, clear_all=self.args.all, + remote=self.args.git_remote, ) return 0 @@ -1249,6 +1250,12 @@ def add_parser(subparsers, parent_parser): action="store_true", help="Remove all committed experiments.", ) + remove_group.add_argument( + "-r", + "--git-remote", + metavar="", + help="Name of the Git remote to GC all of the experiment branches.", + ) experiments_remove_parser.add_argument( "experiment", nargs="*", diff --git a/dvc/repo/experiments/remove.py b/dvc/repo/experiments/remove.py index 2df54728bd..d14f338fe0 100644 --- a/dvc/repo/experiments/remove.py +++ b/dvc/repo/experiments/remove.py @@ -7,7 +7,7 @@ from dvc.scm.base import RevError from .base import EXPS_NAMESPACE, ExpRefInfo -from .utils import exp_refs, exp_refs_by_name, remove_exp_refs +from .utils import exp_refs, exp_refs_by_name, remote_exp_refs, remove_exp_refs logger = logging.getLogger(__name__) @@ -19,9 +19,10 @@ def remove( exp_names=None, queue=False, clear_all=False, + remote=None, **kwargs, ): - if not any([exp_names, queue, clear_all]): + if not any([exp_names, queue, clear_all, remote]): return 0 removed = 0 @@ -29,6 +30,8 @@ def remove( removed += _clear_stash(repo) if clear_all: removed += _clear_all(repo) + if remote: + removed += _clear_remote(repo, remote) if exp_names: remained = _remove_commited_exps(repo, exp_names) @@ -41,6 +44,14 @@ def remove( return removed +def _clear_remote(repo, remote: str): + ref_infos = list(remote_exp_refs(repo.scm, remote)) + for ref_info in ref_infos: + ref_name = str(ref_info) + repo.scm.push_refspec(remote, None, ref_name) + return len(ref_infos) + + def _clear_stash(repo): removed = len(repo.experiments.stash) repo.experiments.stash.clear() diff --git a/dvc/scm/git/backend/dulwich.py b/dvc/scm/git/backend/dulwich.py index 12e1691426..48a5f87988 100644 --- a/dvc/scm/git/backend/dulwich.py +++ b/dvc/scm/git/backend/dulwich.py @@ -399,9 +399,11 @@ def push_refspec( ) from exc def update_refs(refs): + from dulwich.objects import ZERO_SHA + new_refs = {} for ref, value in zip(dest_refs, values): - if ref in refs: + if ref in refs and value != ZERO_SHA: local_sha = self.repo.refs[ref] remote_sha = refs[ref] try: diff --git a/tests/func/experiments/test_remove.py b/tests/func/experiments/test_remove.py index 909561a080..fb7c7560d9 100644 --- a/tests/func/experiments/test_remove.py +++ b/tests/func/experiments/test_remove.py @@ -90,3 +90,27 @@ def test_remove_all(tmp_dir, scm, dvc, exp_stage): assert len(dvc.experiments.stash) == 2 assert scm.get_ref(str(ref_info2)) is None assert scm.get_ref(str(ref_info1)) is None + + +@pytest.mark.parametrize("use_url", [True, False]) +def test_remove_remote(tmp_dir, scm, dvc, exp_stage, git_upstream, use_url): + remote = git_upstream.url if use_url else git_upstream.remote + + results = dvc.experiments.run(exp_stage.addressing, params=["foo=1"]) + exp1 = first(results) + ref_info1 = first(exp_refs_by_rev(scm, exp1)) + + results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"]) + exp2 = first(results) + ref_info2 = first(exp_refs_by_rev(scm, exp2)) + + dvc.experiments.push(remote, ref_info1.name) + dvc.experiments.push(remote, ref_info2.name) + assert git_upstream.scm.get_ref(str(ref_info1)) == exp1 + assert git_upstream.scm.get_ref(str(ref_info2)) == exp2 + + dvc.experiments.remove(experiments=[ref_info1]) + dvc.experiments.remove(remote=remote) + + assert git_upstream.scm.get_ref(str(ref_info1)) is None + assert git_upstream.scm.get_ref(str(ref_info2)) is None diff --git a/tests/unit/command/test_experiments.py b/tests/unit/command/test_experiments.py index d8043af94b..6c066ca304 100644 --- a/tests/unit/command/test_experiments.py +++ b/tests/unit/command/test_experiments.py @@ -253,15 +253,17 @@ def test_experiments_pull(dvc, scm, mocker): @pytest.mark.parametrize( - "queue,clear_all", - [(True, False), (False, True)], + "queue,clear_all,remote", + [(True, False, None), (False, True, None), (False, False, True)], ) -def test_experiments_remove(dvc, scm, mocker, queue, clear_all): +def test_experiments_remove(dvc, scm, mocker, queue, clear_all, remote): if queue: - args = "--queue" + args = ["--queue"] if clear_all: - args = "--all" - cli_args = parse_args(["experiments", "remove", args]) + args = ["--all"] + if remote: + args = ["--git-remote", "myremote"] + cli_args = parse_args(["experiments", "remove"] + args) assert cli_args.func == CmdExperimentsRemove cmd = cli_args.func(cli_args) @@ -273,4 +275,5 @@ def test_experiments_remove(dvc, scm, mocker, queue, clear_all): exp_names=[], queue=queue, clear_all=clear_all, + remote="myremote" if remote else None, ) From 1f45c5bdb06324d2381ca43bf51b2e7e4faa9c02 Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Sat, 28 Aug 2021 16:05:33 +0800 Subject: [PATCH 02/18] Support remove a special remote exp 1. support remote a special remote exp 2. modify the tests for it 3. fix an issue in dvc exp pull --- dvc/command/experiments.py | 2 +- dvc/repo/experiments/pull.py | 3 -- dvc/repo/experiments/push.py | 3 -- dvc/repo/experiments/remove.py | 65 ++++++++++++++------------ dvc/repo/experiments/utils.py | 4 +- tests/func/experiments/test_remove.py | 32 +++++++------ tests/unit/command/test_experiments.py | 4 +- 7 files changed, 57 insertions(+), 56 deletions(-) diff --git a/dvc/command/experiments.py b/dvc/command/experiments.py index eafecda500..65f68340d2 100644 --- a/dvc/command/experiments.py +++ b/dvc/command/experiments.py @@ -1250,7 +1250,7 @@ def add_parser(subparsers, parent_parser): action="store_true", help="Remove all committed experiments.", ) - remove_group.add_argument( + experiments_remove_parser.add_argument( "-r", "--git-remote", metavar="", diff --git a/dvc/repo/experiments/pull.py b/dvc/repo/experiments/pull.py index 6a21fafccf..73d4833d6f 100644 --- a/dvc/repo/experiments/pull.py +++ b/dvc/repo/experiments/pull.py @@ -36,9 +36,6 @@ def on_diverged(refname: str, rev: str) -> bool: def _get_exp_ref(repo, git_remote, exp_name): - if exp_name.startswith("refs/"): - return exp_name - exp_refs = list(remote_exp_refs_by_name(repo.scm, git_remote, exp_name)) if not exp_refs: raise InvalidArgumentError( diff --git a/dvc/repo/experiments/push.py b/dvc/repo/experiments/push.py index 7b58682e33..337d7affb0 100644 --- a/dvc/repo/experiments/push.py +++ b/dvc/repo/experiments/push.py @@ -43,9 +43,6 @@ def on_diverged(refname: str, rev: str) -> bool: def _get_exp_ref(repo, exp_name: str) -> ExpRefInfo: - if exp_name.startswith("refs/"): - return ExpRefInfo.from_ref(exp_name) - exp_refs = list(exp_refs_by_name(repo.scm, exp_name)) if not exp_refs: raise InvalidArgumentError( diff --git a/dvc/repo/experiments/remove.py b/dvc/repo/experiments/remove.py index d14f338fe0..ca2cd5d28d 100644 --- a/dvc/repo/experiments/remove.py +++ b/dvc/repo/experiments/remove.py @@ -6,8 +6,13 @@ from dvc.repo.scm_context import scm_context from dvc.scm.base import RevError -from .base import EXPS_NAMESPACE, ExpRefInfo -from .utils import exp_refs, exp_refs_by_name, remote_exp_refs, remove_exp_refs +from .base import ExpRefInfo +from .utils import ( + exp_refs, + exp_refs_by_name, + remote_exp_refs_by_name, + remove_exp_refs, +) logger = logging.getLogger(__name__) @@ -22,7 +27,7 @@ def remove( remote=None, **kwargs, ): - if not any([exp_names, queue, clear_all, remote]): + if not any([exp_names, queue, clear_all]): return 0 removed = 0 @@ -30,28 +35,12 @@ def remove( removed += _clear_stash(repo) if clear_all: removed += _clear_all(repo) - if remote: - removed += _clear_remote(repo, remote) if exp_names: - remained = _remove_commited_exps(repo, exp_names) - remained = _remove_queued_exps(repo, remained) - if remained: - raise InvalidArgumentError( - "'{}' is not a valid experiment".format(";".join(remained)) - ) - removed += len(exp_names) - len(remained) + removed += _remove_exp_by_names(repo, remote, exp_names) return removed -def _clear_remote(repo, remote: str): - ref_infos = list(remote_exp_refs(repo.scm, remote)) - for ref_info in ref_infos: - ref_name = str(ref_info) - repo.scm.push_refspec(remote, None, ref_name) - return len(ref_infos) - - def _clear_stash(repo): removed = len(repo.experiments.stash) repo.experiments.stash.clear() @@ -78,15 +67,16 @@ def _get_exp_stash_index(repo, ref_or_rev: str) -> Optional[int]: return None -def _get_exp_ref(repo, exp_name: str) -> Optional[ExpRefInfo]: +def _get_exp_ref(repo, remote: str, exp_name: str) -> Optional[ExpRefInfo]: cur_rev = repo.scm.get_rev() - if exp_name.startswith(EXPS_NAMESPACE): - if repo.scm.get_ref(exp_name): - return ExpRefInfo.from_ref(exp_name) - else: + if not remote: exp_ref_list = list(exp_refs_by_name(repo.scm, exp_name)) - if exp_ref_list: - return _get_ref(exp_ref_list, exp_name, cur_rev) + else: + exp_ref_list = list( + remote_exp_refs_by_name(repo.scm, remote, exp_name) + ) + if exp_ref_list: + return _get_ref(exp_ref_list, exp_name, cur_rev) return None @@ -107,17 +97,21 @@ def _get_ref(ref_infos, name, cur_rev) -> Optional[ExpRefInfo]: return ref_infos[0] -def _remove_commited_exps(repo, refs: List[str]) -> List[str]: +def _remove_commited_exps(repo, remote: str, refs: List[str]) -> List[str]: remain_list = [] remove_list = [] for ref in refs: - ref_info = _get_exp_ref(repo, ref) + ref_info = _get_exp_ref(repo, remote, ref) + if ref_info: remove_list.append(ref_info) else: remain_list.append(ref) if remove_list: - remove_exp_refs(repo.scm, remove_list) + if not remote: + remove_exp_refs(repo.scm, remove_list) + else: + repo.scm.push_refspec(remote, None, str(ref_info)) return remain_list @@ -130,3 +124,14 @@ def _remove_queued_exps(repo, refs_or_revs: List[str]) -> List[str]: else: repo.experiments.stash.drop(stash_index) return remain_list + + +def _remove_exp_by_names(repo, remote, exp_names: List[str]) -> int: + remained = _remove_commited_exps(repo, remote, exp_names) + if not remote: + remained = _remove_queued_exps(repo, remained) + if remained: + raise InvalidArgumentError( + "'{}' is not a valid experiment".format(";".join(remained)) + ) + return len(exp_names) - len(remained) diff --git a/dvc/repo/experiments/utils.py b/dvc/repo/experiments/utils.py index e45bc418a0..9a54f3b347 100644 --- a/dvc/repo/experiments/utils.py +++ b/dvc/repo/experiments/utils.py @@ -33,7 +33,7 @@ def exp_refs_by_name( ) -> Generator["ExpRefInfo", None, None]: """Iterate over all experiment refs matching the specified name.""" for ref_info in exp_refs(scm): - if ref_info.name == name: + if ref_info.name == name or str(ref_info) == name: yield ref_info @@ -63,7 +63,7 @@ def remote_exp_refs_by_name( ) -> Generator["ExpRefInfo", None, None]: """Iterate over all remote experiment refs matching the specified name.""" for ref_info in remote_exp_refs(scm, url): - if ref_info.name == name: + if ref_info.name == name or str(ref_info) == name: yield ref_info diff --git a/tests/func/experiments/test_remove.py b/tests/func/experiments/test_remove.py index fb7c7560d9..0d61aac422 100644 --- a/tests/func/experiments/test_remove.py +++ b/tests/func/experiments/test_remove.py @@ -96,21 +96,23 @@ def test_remove_all(tmp_dir, scm, dvc, exp_stage): def test_remove_remote(tmp_dir, scm, dvc, exp_stage, git_upstream, use_url): remote = git_upstream.url if use_url else git_upstream.remote - results = dvc.experiments.run(exp_stage.addressing, params=["foo=1"]) - exp1 = first(results) - ref_info1 = first(exp_refs_by_rev(scm, exp1)) - - results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"]) - exp2 = first(results) - ref_info2 = first(exp_refs_by_rev(scm, exp2)) + ref_info_list = [] + exp_list = [] + for i in range(3): + results = dvc.experiments.run( + exp_stage.addressing, params=[f"foo={i}"] + ) + exp = first(results) + exp_list.append(exp) + ref_info = first(exp_refs_by_rev(scm, exp)) + ref_info_list.append(ref_info) + dvc.experiments.push(remote, ref_info.name) + assert git_upstream.scm.get_ref(str(ref_info)) == exp - dvc.experiments.push(remote, ref_info1.name) - dvc.experiments.push(remote, ref_info2.name) - assert git_upstream.scm.get_ref(str(ref_info1)) == exp1 - assert git_upstream.scm.get_ref(str(ref_info2)) == exp2 + dvc.experiments.remove(experiments=ref_info_list) - dvc.experiments.remove(experiments=[ref_info1]) - dvc.experiments.remove(remote=remote) + dvc.experiments.remove(remote=remote, experiments=ref_info_list[:2]) - assert git_upstream.scm.get_ref(str(ref_info1)) is None - assert git_upstream.scm.get_ref(str(ref_info2)) is None + assert git_upstream.scm.get_ref(str(ref_info_list[0])) is None + assert git_upstream.scm.get_ref(str(ref_info_list[1])) is None + assert git_upstream.scm.get_ref(str(ref_info_list[2])) == exp_list[2] diff --git a/tests/unit/command/test_experiments.py b/tests/unit/command/test_experiments.py index 6c066ca304..d843a49b53 100644 --- a/tests/unit/command/test_experiments.py +++ b/tests/unit/command/test_experiments.py @@ -262,7 +262,7 @@ def test_experiments_remove(dvc, scm, mocker, queue, clear_all, remote): if clear_all: args = ["--all"] if remote: - args = ["--git-remote", "myremote"] + args = ["--git-remote", "myremote", "exp-123", "exp-234"] cli_args = parse_args(["experiments", "remove"] + args) assert cli_args.func == CmdExperimentsRemove @@ -272,7 +272,7 @@ def test_experiments_remove(dvc, scm, mocker, queue, clear_all, remote): assert cmd.run() == 0 m.assert_called_once_with( cmd.repo, - exp_names=[], + exp_names=["exp-123", "exp-234"] if remote else [], queue=queue, clear_all=clear_all, remote="myremote" if remote else None, From 3e7c2f1b44477db182c319d0128c37e0ee991a8f Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Sat, 28 Aug 2021 16:54:51 +0800 Subject: [PATCH 03/18] reviewed problems and solve failed tests --- dvc/command/experiments.py | 7 ++++--- dvc/repo/experiments/remove.py | 3 ++- tests/func/experiments/test_remove.py | 6 +++--- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/dvc/command/experiments.py b/dvc/command/experiments.py index 65f68340d2..08036ca088 100644 --- a/dvc/command/experiments.py +++ b/dvc/command/experiments.py @@ -1232,7 +1232,7 @@ def add_parser(subparsers, parent_parser): ) experiments_pull_parser.set_defaults(func=CmdExperimentsPull) - EXPERIMENTS_REMOVE_HELP = "Remove local experiments." + EXPERIMENTS_REMOVE_HELP = "Remove experiments." experiments_remove_parser = experiments_subparsers.add_parser( "remove", parents=[parent_parser], @@ -1251,10 +1251,11 @@ def add_parser(subparsers, parent_parser): help="Remove all committed experiments.", ) experiments_remove_parser.add_argument( - "-r", + "-g", "--git-remote", metavar="", - help="Name of the Git remote to GC all of the experiment branches.", + help="Name or URL of the Git remote to delete the experiment " + "references", ) experiments_remove_parser.add_argument( "experiment", diff --git a/dvc/repo/experiments/remove.py b/dvc/repo/experiments/remove.py index ca2cd5d28d..e165339a16 100644 --- a/dvc/repo/experiments/remove.py +++ b/dvc/repo/experiments/remove.py @@ -111,7 +111,8 @@ def _remove_commited_exps(repo, remote: str, refs: List[str]) -> List[str]: if not remote: remove_exp_refs(repo.scm, remove_list) else: - repo.scm.push_refspec(remote, None, str(ref_info)) + for ref_info in remove_list: + repo.scm.push_refspec(remote, None, str(ref_info)) return remain_list diff --git a/tests/func/experiments/test_remove.py b/tests/func/experiments/test_remove.py index 0d61aac422..f240785d44 100644 --- a/tests/func/experiments/test_remove.py +++ b/tests/func/experiments/test_remove.py @@ -109,9 +109,9 @@ def test_remove_remote(tmp_dir, scm, dvc, exp_stage, git_upstream, use_url): dvc.experiments.push(remote, ref_info.name) assert git_upstream.scm.get_ref(str(ref_info)) == exp - dvc.experiments.remove(experiments=ref_info_list) - - dvc.experiments.remove(remote=remote, experiments=ref_info_list[:2]) + dvc.experiments.remove( + remote=remote, exp_names=[str(ref_info_list[0]), ref_info_list[1].name] + ) assert git_upstream.scm.get_ref(str(ref_info_list[0])) is None assert git_upstream.scm.get_ref(str(ref_info_list[1])) is None From 5affb0054e4a929c63a56057201154192a4bcc57 Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Mon, 30 Aug 2021 12:25:20 +0800 Subject: [PATCH 04/18] Return to the old version. --- dvc/repo/experiments/pull.py | 3 +++ dvc/repo/experiments/push.py | 3 +++ dvc/repo/experiments/remove.py | 6 +++++- dvc/repo/experiments/utils.py | 2 +- 4 files changed, 12 insertions(+), 2 deletions(-) diff --git a/dvc/repo/experiments/pull.py b/dvc/repo/experiments/pull.py index 73d4833d6f..6a21fafccf 100644 --- a/dvc/repo/experiments/pull.py +++ b/dvc/repo/experiments/pull.py @@ -36,6 +36,9 @@ def on_diverged(refname: str, rev: str) -> bool: def _get_exp_ref(repo, git_remote, exp_name): + if exp_name.startswith("refs/"): + return exp_name + exp_refs = list(remote_exp_refs_by_name(repo.scm, git_remote, exp_name)) if not exp_refs: raise InvalidArgumentError( diff --git a/dvc/repo/experiments/push.py b/dvc/repo/experiments/push.py index 337d7affb0..7b58682e33 100644 --- a/dvc/repo/experiments/push.py +++ b/dvc/repo/experiments/push.py @@ -43,6 +43,9 @@ def on_diverged(refname: str, rev: str) -> bool: def _get_exp_ref(repo, exp_name: str) -> ExpRefInfo: + if exp_name.startswith("refs/"): + return ExpRefInfo.from_ref(exp_name) + exp_refs = list(exp_refs_by_name(repo.scm, exp_name)) if not exp_refs: raise InvalidArgumentError( diff --git a/dvc/repo/experiments/remove.py b/dvc/repo/experiments/remove.py index e165339a16..0a13aa8783 100644 --- a/dvc/repo/experiments/remove.py +++ b/dvc/repo/experiments/remove.py @@ -6,7 +6,7 @@ from dvc.repo.scm_context import scm_context from dvc.scm.base import RevError -from .base import ExpRefInfo +from .base import EXPS_NAMESPACE, ExpRefInfo from .utils import ( exp_refs, exp_refs_by_name, @@ -68,6 +68,10 @@ def _get_exp_stash_index(repo, ref_or_rev: str) -> Optional[int]: def _get_exp_ref(repo, remote: str, exp_name: str) -> Optional[ExpRefInfo]: + if exp_name.startswith(EXPS_NAMESPACE): + if repo.scm.get_ref(exp_name): + return ExpRefInfo.from_ref(exp_name) + cur_rev = repo.scm.get_rev() if not remote: exp_ref_list = list(exp_refs_by_name(repo.scm, exp_name)) diff --git a/dvc/repo/experiments/utils.py b/dvc/repo/experiments/utils.py index 9a54f3b347..fdb74d230f 100644 --- a/dvc/repo/experiments/utils.py +++ b/dvc/repo/experiments/utils.py @@ -63,7 +63,7 @@ def remote_exp_refs_by_name( ) -> Generator["ExpRefInfo", None, None]: """Iterate over all remote experiment refs matching the specified name.""" for ref_info in remote_exp_refs(scm, url): - if ref_info.name == name or str(ref_info) == name: + if ref_info.name == name: yield ref_info From 0b688c56ab0229b2b92a7a751a9a8b39f360370f Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Mon, 30 Aug 2021 14:30:26 +0800 Subject: [PATCH 05/18] Some reviewed changes --- dvc/repo/experiments/remove.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dvc/repo/experiments/remove.py b/dvc/repo/experiments/remove.py index 0a13aa8783..71588b5332 100644 --- a/dvc/repo/experiments/remove.py +++ b/dvc/repo/experiments/remove.py @@ -67,7 +67,9 @@ def _get_exp_stash_index(repo, ref_or_rev: str) -> Optional[int]: return None -def _get_exp_ref(repo, remote: str, exp_name: str) -> Optional[ExpRefInfo]: +def _get_exp_ref( + repo, remote: Optional[str], exp_name: str +) -> Optional[ExpRefInfo]: if exp_name.startswith(EXPS_NAMESPACE): if repo.scm.get_ref(exp_name): return ExpRefInfo.from_ref(exp_name) @@ -79,6 +81,7 @@ def _get_exp_ref(repo, remote: str, exp_name: str) -> Optional[ExpRefInfo]: exp_ref_list = list( remote_exp_refs_by_name(repo.scm, remote, exp_name) ) + if exp_ref_list: return _get_ref(exp_ref_list, exp_name, cur_rev) return None From fc4b415093ffc97f54363101e5879dc26be1bc98 Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Mon, 30 Aug 2021 16:10:18 +0800 Subject: [PATCH 06/18] Some refactors and fix #6421 1. fix #6421. 2. add a test for it. 3. do some refactors --- dvc/repo/experiments/pull.py | 13 +++----- dvc/repo/experiments/push.py | 13 ++------ dvc/repo/experiments/remove.py | 47 ++++++++------------------- dvc/repo/experiments/utils.py | 26 ++++++++++++++- tests/func/experiments/test_remote.py | 2 +- 5 files changed, 47 insertions(+), 54 deletions(-) diff --git a/dvc/repo/experiments/pull.py b/dvc/repo/experiments/pull.py index 6a21fafccf..11985d9ce4 100644 --- a/dvc/repo/experiments/pull.py +++ b/dvc/repo/experiments/pull.py @@ -4,7 +4,8 @@ from dvc.repo import locked from dvc.repo.scm_context import scm_context -from .utils import exp_commits, remote_exp_refs_by_name +from .base import ExpRefInfo +from .utils import exp_commits, get_exp_ref_list logger = logging.getLogger(__name__) @@ -35,20 +36,14 @@ def on_diverged(refname: str, rev: str) -> bool: _pull_cache(repo, exp_ref, **kwargs) -def _get_exp_ref(repo, git_remote, exp_name): - if exp_name.startswith("refs/"): - return exp_name +def _get_exp_ref(repo, git_remote: str, exp_name: str) -> ExpRefInfo: + exp_refs = get_exp_ref_list(repo, exp_name, git_remote) - exp_refs = list(remote_exp_refs_by_name(repo.scm, git_remote, exp_name)) if not exp_refs: raise InvalidArgumentError( f"Experiment '{exp_name}' does not exist in '{git_remote}'" ) if len(exp_refs) > 1: - cur_rev = repo.scm.get_rev() - for info in exp_refs: - if info.baseline_sha == cur_rev: - return info msg = [ ( f"Ambiguous name '{exp_name}' refers to multiple " diff --git a/dvc/repo/experiments/push.py b/dvc/repo/experiments/push.py index 7b58682e33..0b16bda58f 100644 --- a/dvc/repo/experiments/push.py +++ b/dvc/repo/experiments/push.py @@ -2,10 +2,10 @@ from dvc.exceptions import DvcException, InvalidArgumentError from dvc.repo import locked -from dvc.repo.experiments.base import ExpRefInfo from dvc.repo.scm_context import scm_context -from .utils import exp_commits, exp_refs_by_name +from .base import ExpRefInfo +from .utils import exp_commits, get_exp_ref_list logger = logging.getLogger(__name__) @@ -43,19 +43,12 @@ def on_diverged(refname: str, rev: str) -> bool: def _get_exp_ref(repo, exp_name: str) -> ExpRefInfo: - if exp_name.startswith("refs/"): - return ExpRefInfo.from_ref(exp_name) - - exp_refs = list(exp_refs_by_name(repo.scm, exp_name)) + exp_refs = get_exp_ref_list(repo, exp_name) if not exp_refs: raise InvalidArgumentError( f"'{exp_name}' is not a valid experiment name" ) if len(exp_refs) > 1: - cur_rev = repo.scm.get_rev() - for info in exp_refs: - if info.baseline_sha == cur_rev: - return info msg = [ ( f"Ambiguous name '{exp_name}' refers to multiple " diff --git a/dvc/repo/experiments/remove.py b/dvc/repo/experiments/remove.py index 71588b5332..eab02c2e9a 100644 --- a/dvc/repo/experiments/remove.py +++ b/dvc/repo/experiments/remove.py @@ -6,13 +6,8 @@ from dvc.repo.scm_context import scm_context from dvc.scm.base import RevError -from .base import EXPS_NAMESPACE, ExpRefInfo -from .utils import ( - exp_refs, - exp_refs_by_name, - remote_exp_refs_by_name, - remove_exp_refs, -) +from .base import ExpRefInfo +from .utils import exp_refs, get_exp_ref_list, remove_exp_refs logger = logging.getLogger(__name__) @@ -67,34 +62,18 @@ def _get_exp_stash_index(repo, ref_or_rev: str) -> Optional[int]: return None -def _get_exp_ref( - repo, remote: Optional[str], exp_name: str -) -> Optional[ExpRefInfo]: - if exp_name.startswith(EXPS_NAMESPACE): - if repo.scm.get_ref(exp_name): - return ExpRefInfo.from_ref(exp_name) - - cur_rev = repo.scm.get_rev() - if not remote: - exp_ref_list = list(exp_refs_by_name(repo.scm, exp_name)) - else: - exp_ref_list = list( - remote_exp_refs_by_name(repo.scm, remote, exp_name) - ) - - if exp_ref_list: - return _get_ref(exp_ref_list, exp_name, cur_rev) - return None - - -def _get_ref(ref_infos, name, cur_rev) -> Optional[ExpRefInfo]: +def _get_exp_ref(repo, remote, exp_name) -> Optional[ExpRefInfo]: + ref_infos = get_exp_ref_list(repo, exp_name, remote) + if not ref_infos: + return None if len(ref_infos) > 1: + cur_rev = repo.scm.get_rev() for info in ref_infos: if info.baseline_sha == cur_rev: return info msg = [ ( - f"Ambiguous name '{name}' refers to multiple " + f"Ambiguous name '{exp_name}' refers to multiple " "experiments. Use full refname to remove one of " "the following:" ) @@ -104,16 +83,18 @@ def _get_ref(ref_infos, name, cur_rev) -> Optional[ExpRefInfo]: return ref_infos[0] -def _remove_commited_exps(repo, remote: str, refs: List[str]) -> List[str]: +def _remove_commited_exps( + repo, remote: Optional[str], exp_names: List[str] +) -> List[str]: remain_list = [] remove_list = [] - for ref in refs: - ref_info = _get_exp_ref(repo, remote, ref) + for exp_name in exp_names: + ref_info = _get_exp_ref(repo, remote, exp_name) if ref_info: remove_list.append(ref_info) else: - remain_list.append(ref) + remain_list.append(exp_name) if remove_list: if not remote: remove_exp_refs(repo.scm, remove_list) diff --git a/dvc/repo/experiments/utils.py b/dvc/repo/experiments/utils.py index fdb74d230f..f3725eba24 100644 --- a/dvc/repo/experiments/utils.py +++ b/dvc/repo/experiments/utils.py @@ -1,4 +1,4 @@ -from typing import Generator, Iterable, Optional, Set +from typing import Generator, Iterable, List, Optional, Set from dvc.scm.git import Git @@ -115,3 +115,27 @@ def fix_exp_head(scm: "Git", ref: Optional[str]) -> Optional[str]: if name == "HEAD" and scm.get_ref(EXEC_BASELINE): return "".join((EXEC_BASELINE, tail)) return ref + + +def get_exp_ref_list( + repo, exp_name: str, git_remote: Optional[str] = None +) -> List[ExpRefInfo]: + if exp_name.startswith("refs/"): + return [ExpRefInfo.from_ref(exp_name)] + + if git_remote: + exp_ref_list = list( + remote_exp_refs_by_name(repo.scm, git_remote, exp_name) + ) + else: + exp_ref_list = list(exp_refs_by_name(repo.scm, exp_name)) + + if not exp_ref_list: + return [] + if len(exp_ref_list) > 1: + cur_rev = repo.scm.get_rev() + for info in exp_ref_list: + if info.baseline_sha == cur_rev: + exp_ref_list = [info] + break + return exp_ref_list diff --git a/tests/func/experiments/test_remote.py b/tests/func/experiments/test_remote.py index dffb1970c6..282b07a5f3 100644 --- a/tests/func/experiments/test_remote.py +++ b/tests/func/experiments/test_remote.py @@ -151,7 +151,7 @@ def test_pull(tmp_dir, scm, dvc, git_downstream, exp_stage, use_url): git_downstream.scm.remove_ref(str(ref_info)) - downstream_exp.pull(remote, str(ref_info)) + downstream_exp.pull(remote, str(ref_info), pull_cache=True) assert git_downstream.scm.get_ref(str(ref_info)) == exp From 0936d89f287188e09c768a2b640b57109eb6272d Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Mon, 30 Aug 2021 20:14:33 +0800 Subject: [PATCH 07/18] impove the function in utils 1. rename the function in utils. 2. remove test for #6421. 3. put all error handling into the utils. --- dvc/repo/experiments/pull.py | 30 +++++----------------- dvc/repo/experiments/push.py | 29 +++++---------------- dvc/repo/experiments/remove.py | 26 ++----------------- dvc/repo/experiments/utils.py | 36 +++++++++++++++++++++------ tests/func/experiments/test_remote.py | 2 +- 5 files changed, 43 insertions(+), 80 deletions(-) diff --git a/dvc/repo/experiments/pull.py b/dvc/repo/experiments/pull.py index 11985d9ce4..54ccbee0ca 100644 --- a/dvc/repo/experiments/pull.py +++ b/dvc/repo/experiments/pull.py @@ -4,8 +4,7 @@ from dvc.repo import locked from dvc.repo.scm_context import scm_context -from .base import ExpRefInfo -from .utils import exp_commits, get_exp_ref_list +from .utils import exp_commits, resolve_exp_ref logger = logging.getLogger(__name__) @@ -15,7 +14,11 @@ def pull( repo, git_remote, exp_name, *args, force=False, pull_cache=False, **kwargs ): - exp_ref = _get_exp_ref(repo, git_remote, exp_name) + exp_ref = resolve_exp_ref(repo, exp_name, git_remote) + if not exp_ref: + raise InvalidArgumentError( + f"Experiment '{exp_name}' does not exist in '{git_remote}'" + ) def on_diverged(refname: str, rev: str) -> bool: if repo.scm.get_ref(refname) == rev: @@ -36,27 +39,6 @@ def on_diverged(refname: str, rev: str) -> bool: _pull_cache(repo, exp_ref, **kwargs) -def _get_exp_ref(repo, git_remote: str, exp_name: str) -> ExpRefInfo: - exp_refs = get_exp_ref_list(repo, exp_name, git_remote) - - if not exp_refs: - raise InvalidArgumentError( - f"Experiment '{exp_name}' does not exist in '{git_remote}'" - ) - if len(exp_refs) > 1: - msg = [ - ( - f"Ambiguous name '{exp_name}' refers to multiple " - "experiments in '{git_remote}'. Use full refname to pull one " - "of the following:" - ), - "", - ] - msg.extend([f"\t{info}" for info in exp_refs]) - raise InvalidArgumentError("\n".join(msg)) - return exp_refs[0] - - def _pull_cache(repo, exp_ref, dvc_remote=None, jobs=None, run_cache=False): revs = list(exp_commits(repo.scm, [exp_ref])) logger.debug("dvc fetch experiment '%s'", exp_ref) diff --git a/dvc/repo/experiments/push.py b/dvc/repo/experiments/push.py index 0b16bda58f..cf77bfc4c2 100644 --- a/dvc/repo/experiments/push.py +++ b/dvc/repo/experiments/push.py @@ -4,8 +4,7 @@ from dvc.repo import locked from dvc.repo.scm_context import scm_context -from .base import ExpRefInfo -from .utils import exp_commits, get_exp_ref_list +from .utils import exp_commits, resolve_exp_ref logger = logging.getLogger(__name__) @@ -21,7 +20,11 @@ def push( push_cache=False, **kwargs, ): - exp_ref = _get_exp_ref(repo, exp_name) + exp_ref = resolve_exp_ref(repo, exp_name) + if not exp_ref: + raise InvalidArgumentError( + f"'{exp_name}' is not a valid experiment name" + ) def on_diverged(refname: str, rev: str) -> bool: if repo.scm.get_ref(refname) == rev: @@ -42,26 +45,6 @@ def on_diverged(refname: str, rev: str) -> bool: _push_cache(repo, exp_ref, **kwargs) -def _get_exp_ref(repo, exp_name: str) -> ExpRefInfo: - exp_refs = get_exp_ref_list(repo, exp_name) - if not exp_refs: - raise InvalidArgumentError( - f"'{exp_name}' is not a valid experiment name" - ) - if len(exp_refs) > 1: - msg = [ - ( - f"Ambiguous name '{exp_name}' refers to multiple " - "experiments. Use full refname to push one of the " - "following:" - ), - "", - ] - msg.extend([f"\t{info}" for info in exp_refs]) - raise InvalidArgumentError("\n".join(msg)) - return exp_refs[0] - - def _push_cache(repo, exp_ref, dvc_remote=None, jobs=None, run_cache=False): revs = list(exp_commits(repo.scm, [exp_ref])) logger.debug("dvc push experiment '%s'", exp_ref) diff --git a/dvc/repo/experiments/remove.py b/dvc/repo/experiments/remove.py index eab02c2e9a..082729ff98 100644 --- a/dvc/repo/experiments/remove.py +++ b/dvc/repo/experiments/remove.py @@ -6,8 +6,7 @@ from dvc.repo.scm_context import scm_context from dvc.scm.base import RevError -from .base import ExpRefInfo -from .utils import exp_refs, get_exp_ref_list, remove_exp_refs +from .utils import exp_refs, remove_exp_refs, resolve_exp_ref logger = logging.getLogger(__name__) @@ -62,34 +61,13 @@ def _get_exp_stash_index(repo, ref_or_rev: str) -> Optional[int]: return None -def _get_exp_ref(repo, remote, exp_name) -> Optional[ExpRefInfo]: - ref_infos = get_exp_ref_list(repo, exp_name, remote) - if not ref_infos: - return None - if len(ref_infos) > 1: - cur_rev = repo.scm.get_rev() - for info in ref_infos: - if info.baseline_sha == cur_rev: - return info - msg = [ - ( - f"Ambiguous name '{exp_name}' refers to multiple " - "experiments. Use full refname to remove one of " - "the following:" - ) - ] - msg.extend([f"\t{info}" for info in ref_infos]) - raise InvalidArgumentError("\n".join(msg)) - return ref_infos[0] - - def _remove_commited_exps( repo, remote: Optional[str], exp_names: List[str] ) -> List[str]: remain_list = [] remove_list = [] for exp_name in exp_names: - ref_info = _get_exp_ref(repo, remote, exp_name) + ref_info = resolve_exp_ref(repo, exp_name, remote) if ref_info: remove_list.append(ref_info) diff --git a/dvc/repo/experiments/utils.py b/dvc/repo/experiments/utils.py index f3725eba24..4858de142c 100644 --- a/dvc/repo/experiments/utils.py +++ b/dvc/repo/experiments/utils.py @@ -1,5 +1,6 @@ -from typing import Generator, Iterable, List, Optional, Set +from typing import Generator, Iterable, Optional, Set +from dvc.exceptions import InvalidArgumentError from dvc.scm.git import Git from .base import ( @@ -117,11 +118,11 @@ def fix_exp_head(scm: "Git", ref: Optional[str]) -> Optional[str]: return ref -def get_exp_ref_list( +def resolve_exp_ref( repo, exp_name: str, git_remote: Optional[str] = None -) -> List[ExpRefInfo]: +) -> Optional[ExpRefInfo]: if exp_name.startswith("refs/"): - return [ExpRefInfo.from_ref(exp_name)] + return ExpRefInfo.from_ref(exp_name) if git_remote: exp_ref_list = list( @@ -131,11 +132,30 @@ def get_exp_ref_list( exp_ref_list = list(exp_refs_by_name(repo.scm, exp_name)) if not exp_ref_list: - return [] + return None if len(exp_ref_list) > 1: cur_rev = repo.scm.get_rev() for info in exp_ref_list: if info.baseline_sha == cur_rev: - exp_ref_list = [info] - break - return exp_ref_list + return info + if git_remote: + msg = [ + ( + f"Ambiguous name '{exp_name}' refers to multiple " + "experiments. Use full refname to push one of the " + "following:" + ), + "", + ] + else: + msg = [ + ( + f"Ambiguous name '{exp_name}' refers to multiple " + f"experiments in '{git_remote}'. Use full refname to pull " + "one of the following:" + ), + "", + ] + msg.extend([f"\t{info}" for info in exp_ref_list]) + raise InvalidArgumentError("\n".join(msg)) + return exp_ref_list[0] diff --git a/tests/func/experiments/test_remote.py b/tests/func/experiments/test_remote.py index 282b07a5f3..dffb1970c6 100644 --- a/tests/func/experiments/test_remote.py +++ b/tests/func/experiments/test_remote.py @@ -151,7 +151,7 @@ def test_pull(tmp_dir, scm, dvc, git_downstream, exp_stage, use_url): git_downstream.scm.remove_ref(str(ref_info)) - downstream_exp.pull(remote, str(ref_info), pull_cache=True) + downstream_exp.pull(remote, str(ref_info)) assert git_downstream.scm.get_ref(str(ref_info)) == exp From bc1b3753517997ddc28c3bb8a1a826e8a3611e84 Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Wed, 1 Sep 2021 10:25:10 +0800 Subject: [PATCH 08/18] add a test to the `resolve_exp_ref` related to issue #6421 --- dvc/repo/experiments/pull.py | 2 +- dvc/repo/experiments/push.py | 2 +- dvc/repo/experiments/remove.py | 2 +- dvc/repo/experiments/utils.py | 10 ++++------ tests/func/experiments/test_utils.py | 26 ++++++++++++++++++++++++++ 5 files changed, 33 insertions(+), 9 deletions(-) create mode 100644 tests/func/experiments/test_utils.py diff --git a/dvc/repo/experiments/pull.py b/dvc/repo/experiments/pull.py index 54ccbee0ca..1b7f157e0e 100644 --- a/dvc/repo/experiments/pull.py +++ b/dvc/repo/experiments/pull.py @@ -14,7 +14,7 @@ def pull( repo, git_remote, exp_name, *args, force=False, pull_cache=False, **kwargs ): - exp_ref = resolve_exp_ref(repo, exp_name, git_remote) + exp_ref = resolve_exp_ref(repo.scm, exp_name, git_remote) if not exp_ref: raise InvalidArgumentError( f"Experiment '{exp_name}' does not exist in '{git_remote}'" diff --git a/dvc/repo/experiments/push.py b/dvc/repo/experiments/push.py index cf77bfc4c2..d42c3736b2 100644 --- a/dvc/repo/experiments/push.py +++ b/dvc/repo/experiments/push.py @@ -20,7 +20,7 @@ def push( push_cache=False, **kwargs, ): - exp_ref = resolve_exp_ref(repo, exp_name) + exp_ref = resolve_exp_ref(repo.scm, exp_name) if not exp_ref: raise InvalidArgumentError( f"'{exp_name}' is not a valid experiment name" diff --git a/dvc/repo/experiments/remove.py b/dvc/repo/experiments/remove.py index 082729ff98..c712572f5d 100644 --- a/dvc/repo/experiments/remove.py +++ b/dvc/repo/experiments/remove.py @@ -67,7 +67,7 @@ def _remove_commited_exps( remain_list = [] remove_list = [] for exp_name in exp_names: - ref_info = resolve_exp_ref(repo, exp_name, remote) + ref_info = resolve_exp_ref(repo.scm, exp_name, remote) if ref_info: remove_list.append(ref_info) diff --git a/dvc/repo/experiments/utils.py b/dvc/repo/experiments/utils.py index 4858de142c..8740e7d74c 100644 --- a/dvc/repo/experiments/utils.py +++ b/dvc/repo/experiments/utils.py @@ -119,22 +119,20 @@ def fix_exp_head(scm: "Git", ref: Optional[str]) -> Optional[str]: def resolve_exp_ref( - repo, exp_name: str, git_remote: Optional[str] = None + scm, exp_name: str, git_remote: Optional[str] = None ) -> Optional[ExpRefInfo]: if exp_name.startswith("refs/"): return ExpRefInfo.from_ref(exp_name) if git_remote: - exp_ref_list = list( - remote_exp_refs_by_name(repo.scm, git_remote, exp_name) - ) + exp_ref_list = list(remote_exp_refs_by_name(scm, git_remote, exp_name)) else: - exp_ref_list = list(exp_refs_by_name(repo.scm, exp_name)) + exp_ref_list = list(exp_refs_by_name(scm, exp_name)) if not exp_ref_list: return None if len(exp_ref_list) > 1: - cur_rev = repo.scm.get_rev() + cur_rev = scm.get_rev() for info in exp_ref_list: if info.baseline_sha == cur_rev: return info diff --git a/tests/func/experiments/test_utils.py b/tests/func/experiments/test_utils.py new file mode 100644 index 0000000000..026e8eb9e2 --- /dev/null +++ b/tests/func/experiments/test_utils.py @@ -0,0 +1,26 @@ +import pytest +from funcy import first + +from dvc.repo.experiments.utils import exp_refs_by_rev, resolve_exp_ref + + +@pytest.mark.parametrize( + "full_name, test_remote", [(True, False), (False, True), (False, False)] +) +def test_remove_remote( + tmp_dir, scm, dvc, exp_stage, git_upstream, full_name, test_remote +): + remote = None + results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"]) + exp = first(results) + ref_info = first(exp_refs_by_rev(scm, exp)) + if test_remote: + remote = git_upstream.url + dvc.experiments.push(remote, ref_info.name) + + if full_name: + exp_name = str(ref_info) + else: + exp_name = ref_info.name + + assert resolve_exp_ref(scm, exp_name, remote).name == ref_info.name From e60deabd321309e0e2465195373d94726f1b4b11 Mon Sep 17 00:00:00 2001 From: Gao Date: Wed, 1 Sep 2021 15:17:51 +0800 Subject: [PATCH 09/18] Update dvc/command/experiments.py Co-authored-by: Jorge Orpinel --- dvc/command/experiments.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dvc/command/experiments.py b/dvc/command/experiments.py index 08036ca088..8b6e350776 100644 --- a/dvc/command/experiments.py +++ b/dvc/command/experiments.py @@ -1254,8 +1254,8 @@ def add_parser(subparsers, parent_parser): "-g", "--git-remote", metavar="", - help="Name or URL of the Git remote to delete the experiment " - "references", + help="Name or URL of the Git remote to remove the experiment " + "from", ) experiments_remove_parser.add_argument( "experiment", From 0f66d5043bd33383a8d63006f2c2bea5d5c86d99 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 Sep 2021 07:18:43 +0000 Subject: [PATCH 10/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- dvc/command/experiments.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dvc/command/experiments.py b/dvc/command/experiments.py index 8b6e350776..e1ec5e16e5 100644 --- a/dvc/command/experiments.py +++ b/dvc/command/experiments.py @@ -1254,8 +1254,7 @@ def add_parser(subparsers, parent_parser): "-g", "--git-remote", metavar="", - help="Name or URL of the Git remote to remove the experiment " - "from", + help="Name or URL of the Git remote to remove the experiment " "from", ) experiments_remove_parser.add_argument( "experiment", From 92436ff86d45268900098ec555b97b6300bbd76f Mon Sep 17 00:00:00 2001 From: Gao Date: Thu, 2 Sep 2021 11:44:20 +0800 Subject: [PATCH 11/18] Update dvc/command/experiments.py Co-authored-by: Jorge Orpinel --- dvc/command/experiments.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dvc/command/experiments.py b/dvc/command/experiments.py index e1ec5e16e5..0cb4718d15 100644 --- a/dvc/command/experiments.py +++ b/dvc/command/experiments.py @@ -1254,7 +1254,7 @@ def add_parser(subparsers, parent_parser): "-g", "--git-remote", metavar="", - help="Name or URL of the Git remote to remove the experiment " "from", + help="Name or URL of the Git remote to remove the experiment from", ) experiments_remove_parser.add_argument( "experiment", From a642a06fa37a532f2cfeedf740e07da2b5c5c20d Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Thu, 2 Sep 2021 14:44:45 +0800 Subject: [PATCH 12/18] Rewrite the unit tests --- tests/func/experiments/conftest.py | 4 +-- tests/func/experiments/test_utils.py | 26 ---------------- tests/unit/repo/experiments/test_utils.py | 37 +++++++++++++++++++++++ 3 files changed, 39 insertions(+), 28 deletions(-) delete mode 100644 tests/func/experiments/test_utils.py create mode 100644 tests/unit/repo/experiments/test_utils.py diff --git a/tests/func/experiments/conftest.py b/tests/func/experiments/conftest.py index fb592ae152..9ff08f6d18 100644 --- a/tests/func/experiments/conftest.py +++ b/tests/func/experiments/conftest.py @@ -102,7 +102,7 @@ def git_upstream(tmp_dir, erepo_dir): @pytest.fixture def git_downstream(tmp_dir, erepo_dir): url = "file://{}".format(tmp_dir.resolve().as_posix()) - erepo_dir.scm.gitpython.repo.create_remote("upstream", url) - erepo_dir.remote = "upstream" + erepo_dir.scm.gitpython.repo.create_remote("downstream", url) + erepo_dir.remote = "downstream" erepo_dir.url = url return erepo_dir diff --git a/tests/func/experiments/test_utils.py b/tests/func/experiments/test_utils.py deleted file mode 100644 index 026e8eb9e2..0000000000 --- a/tests/func/experiments/test_utils.py +++ /dev/null @@ -1,26 +0,0 @@ -import pytest -from funcy import first - -from dvc.repo.experiments.utils import exp_refs_by_rev, resolve_exp_ref - - -@pytest.mark.parametrize( - "full_name, test_remote", [(True, False), (False, True), (False, False)] -) -def test_remove_remote( - tmp_dir, scm, dvc, exp_stage, git_upstream, full_name, test_remote -): - remote = None - results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"]) - exp = first(results) - ref_info = first(exp_refs_by_rev(scm, exp)) - if test_remote: - remote = git_upstream.url - dvc.experiments.push(remote, ref_info.name) - - if full_name: - exp_name = str(ref_info) - else: - exp_name = ref_info.name - - assert resolve_exp_ref(scm, exp_name, remote).name == ref_info.name diff --git a/tests/unit/repo/experiments/test_utils.py b/tests/unit/repo/experiments/test_utils.py new file mode 100644 index 0000000000..1d4b432129 --- /dev/null +++ b/tests/unit/repo/experiments/test_utils.py @@ -0,0 +1,37 @@ +import pytest + +from dvc.repo.experiments.base import EXPS_NAMESPACE, ExpRefInfo +from dvc.repo.experiments.utils import resolve_exp_ref + + +@pytest.fixture +def git_upstream(scm, make_tmp_dir, name="origin"): + remote_dir = make_tmp_dir("git-remote", scm=True) + url = f"file://{remote_dir.resolve().as_posix()}" + scm.gitpython.repo.create_remote(name, url) + remote_dir.remote = name + remote_dir.url = url + return remote_dir + + +def commit_exp_ref(tmp_dir, scm, file="foo", contents="foo", name="foo"): + tmp_dir.scm_gen(file, contents, commit="init") + rev = scm.get_rev() + ref = "/".join([EXPS_NAMESPACE, "ab", "c123", name]) + scm.gitpython.set_ref(ref, rev) + return ref, rev + + +@pytest.mark.parametrize("use_url", [True, False]) +@pytest.mark.parametrize("name_only", [True, False]) +def test_resolve_exp_ref(tmp_dir, scm, git_upstream, name_only, use_url): + ref, _ = commit_exp_ref(tmp_dir, scm) + ref_info = resolve_exp_ref(scm, "foo" if name_only else ref) + assert isinstance(ref_info, ExpRefInfo) + assert str(ref_info) == ref + + scm.push_refspec(git_upstream.url, ref, ref) + remote = git_upstream.url if use_url else git_upstream.remote + remote_ref_info = resolve_exp_ref(scm, "foo" if name_only else ref, remote) + assert isinstance(remote_ref_info, ExpRefInfo) + assert str(remote_ref_info) == ref From a5677216e1eedd47f3022e401647418b7812577e Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Thu, 2 Sep 2021 15:59:24 +0800 Subject: [PATCH 13/18] Some fixture move --- tests/conftest.py | 25 +++++++++++++++++++++++ tests/func/experiments/conftest.py | 18 ---------------- tests/unit/repo/experiments/test_utils.py | 10 --------- 3 files changed, 25 insertions(+), 28 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index cc160688f6..033fda3ea3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -155,3 +155,28 @@ def pytest_configure(config): enabled_remotes.discard(remote_name) if enabled: enabled_remotes.add(remote_name) + + +@pytest.fixture +def git_upstream(tmp_dir, erepo_dir, git_dir, request): + if "dvc" in request.fixturenames: + url = "file://{}".format(erepo_dir.resolve().as_posix()) + else: + url = "file://{}".format(git_dir.resolve().as_posix()) + tmp_dir.scm.gitpython.repo.create_remote("upstream", url) + erepo_dir.remote = "upstream" + erepo_dir.url = url + return erepo_dir + + +@pytest.fixture +def git_downstream(tmp_dir, erepo_dir, git_dir, request): + if "dvc" in request.fixturenames: + url = "file://{}".format(erepo_dir.resolve().as_posix()) + else: + url = "file://{}".format(git_dir.resolve().as_posix()) + url = "file://{}".format(tmp_dir.resolve().as_posix()) + erepo_dir.scm.gitpython.repo.create_remote("upstream", url) + erepo_dir.remote = "upstream" + erepo_dir.url = url + return erepo_dir diff --git a/tests/func/experiments/conftest.py b/tests/func/experiments/conftest.py index 9ff08f6d18..455aa687c4 100644 --- a/tests/func/experiments/conftest.py +++ b/tests/func/experiments/conftest.py @@ -88,21 +88,3 @@ def checkpoint_stage(tmp_dir, scm, dvc, mocker): scm.commit("init") stage.iterations = DEFAULT_ITERATIONS return stage - - -@pytest.fixture -def git_upstream(tmp_dir, erepo_dir): - url = "file://{}".format(erepo_dir.resolve().as_posix()) - tmp_dir.scm.gitpython.repo.create_remote("upstream", url) - erepo_dir.remote = "upstream" - erepo_dir.url = url - return erepo_dir - - -@pytest.fixture -def git_downstream(tmp_dir, erepo_dir): - url = "file://{}".format(tmp_dir.resolve().as_posix()) - erepo_dir.scm.gitpython.repo.create_remote("downstream", url) - erepo_dir.remote = "downstream" - erepo_dir.url = url - return erepo_dir diff --git a/tests/unit/repo/experiments/test_utils.py b/tests/unit/repo/experiments/test_utils.py index 1d4b432129..b0fd3da808 100644 --- a/tests/unit/repo/experiments/test_utils.py +++ b/tests/unit/repo/experiments/test_utils.py @@ -4,16 +4,6 @@ from dvc.repo.experiments.utils import resolve_exp_ref -@pytest.fixture -def git_upstream(scm, make_tmp_dir, name="origin"): - remote_dir = make_tmp_dir("git-remote", scm=True) - url = f"file://{remote_dir.resolve().as_posix()}" - scm.gitpython.repo.create_remote(name, url) - remote_dir.remote = name - remote_dir.url = url - return remote_dir - - def commit_exp_ref(tmp_dir, scm, file="foo", contents="foo", name="foo"): tmp_dir.scm_gen(file, contents, commit="init") rev = scm.get_rev() From acddcb34bab378d2a04a5cab37e8c8ada86e87bc Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Thu, 2 Sep 2021 16:03:07 +0800 Subject: [PATCH 14/18] A problem in review --- dvc/repo/experiments/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dvc/repo/experiments/utils.py b/dvc/repo/experiments/utils.py index 8740e7d74c..c2a8f9cb05 100644 --- a/dvc/repo/experiments/utils.py +++ b/dvc/repo/experiments/utils.py @@ -34,7 +34,7 @@ def exp_refs_by_name( ) -> Generator["ExpRefInfo", None, None]: """Iterate over all experiment refs matching the specified name.""" for ref_info in exp_refs(scm): - if ref_info.name == name or str(ref_info) == name: + if ref_info.name == name: yield ref_info From 04c902de22f727367556c8716ffd7a4ea13cfc49 Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Thu, 2 Sep 2021 16:11:16 +0800 Subject: [PATCH 15/18] Move fixtures to dir_helper --- tests/conftest.py | 25 ------------------------- tests/dir_helpers.py | 25 +++++++++++++++++++++++++ 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 033fda3ea3..cc160688f6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -155,28 +155,3 @@ def pytest_configure(config): enabled_remotes.discard(remote_name) if enabled: enabled_remotes.add(remote_name) - - -@pytest.fixture -def git_upstream(tmp_dir, erepo_dir, git_dir, request): - if "dvc" in request.fixturenames: - url = "file://{}".format(erepo_dir.resolve().as_posix()) - else: - url = "file://{}".format(git_dir.resolve().as_posix()) - tmp_dir.scm.gitpython.repo.create_remote("upstream", url) - erepo_dir.remote = "upstream" - erepo_dir.url = url - return erepo_dir - - -@pytest.fixture -def git_downstream(tmp_dir, erepo_dir, git_dir, request): - if "dvc" in request.fixturenames: - url = "file://{}".format(erepo_dir.resolve().as_posix()) - else: - url = "file://{}".format(git_dir.resolve().as_posix()) - url = "file://{}".format(tmp_dir.resolve().as_posix()) - erepo_dir.scm.gitpython.repo.create_remote("upstream", url) - erepo_dir.remote = "upstream" - erepo_dir.url = url - return erepo_dir diff --git a/tests/dir_helpers.py b/tests/dir_helpers.py index 92609e3cd9..26e594a65c 100644 --- a/tests/dir_helpers.py +++ b/tests/dir_helpers.py @@ -407,3 +407,28 @@ def git_dir(make_tmp_dir): path = make_tmp_dir("git-erepo", scm=True) path.scm.commit("init repo") return path + + +@pytest.fixture +def git_upstream(tmp_dir, erepo_dir, git_dir, request): + if "dvc" in request.fixturenames: + url = "file://{}".format(erepo_dir.resolve().as_posix()) + else: + url = "file://{}".format(git_dir.resolve().as_posix()) + tmp_dir.scm.gitpython.repo.create_remote("upstream", url) + erepo_dir.remote = "upstream" + erepo_dir.url = url + return erepo_dir + + +@pytest.fixture +def git_downstream(tmp_dir, erepo_dir, git_dir, request): + if "dvc" in request.fixturenames: + url = "file://{}".format(erepo_dir.resolve().as_posix()) + else: + url = "file://{}".format(git_dir.resolve().as_posix()) + url = "file://{}".format(tmp_dir.resolve().as_posix()) + erepo_dir.scm.gitpython.repo.create_remote("upstream", url) + erepo_dir.remote = "upstream" + erepo_dir.url = url + return erepo_dir From b877a53256db395c6e8f79480eaaad8ba236f9e6 Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Thu, 2 Sep 2021 16:43:00 +0800 Subject: [PATCH 16/18] Bug fix --- tests/dir_helpers.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/tests/dir_helpers.py b/tests/dir_helpers.py index 26e594a65c..f50ac02fc8 100644 --- a/tests/dir_helpers.py +++ b/tests/dir_helpers.py @@ -67,6 +67,8 @@ "erepo_dir", "git_dir", "git_init", + "git_upstream", + "git_downstream", ] @@ -411,24 +413,19 @@ def git_dir(make_tmp_dir): @pytest.fixture def git_upstream(tmp_dir, erepo_dir, git_dir, request): - if "dvc" in request.fixturenames: - url = "file://{}".format(erepo_dir.resolve().as_posix()) - else: - url = "file://{}".format(git_dir.resolve().as_posix()) + remote = erepo_dir if "dvc" in request.fixturenames else git_dir + url = "file://{}".format(remote.resolve().as_posix()) tmp_dir.scm.gitpython.repo.create_remote("upstream", url) - erepo_dir.remote = "upstream" - erepo_dir.url = url - return erepo_dir + remote.remote = "upstream" + remote.url = url + return remote @pytest.fixture def git_downstream(tmp_dir, erepo_dir, git_dir, request): - if "dvc" in request.fixturenames: - url = "file://{}".format(erepo_dir.resolve().as_posix()) - else: - url = "file://{}".format(git_dir.resolve().as_posix()) + remote = erepo_dir if "dvc" in request.fixturenames else git_dir url = "file://{}".format(tmp_dir.resolve().as_posix()) - erepo_dir.scm.gitpython.repo.create_remote("upstream", url) - erepo_dir.remote = "upstream" - erepo_dir.url = url - return erepo_dir + remote.scm.gitpython.repo.create_remote("upstream", url) + remote.remote = "upstream" + remote.url = url + return remote From 61be09b7729eeffba6e6cd82db43abd465de60c1 Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Thu, 2 Sep 2021 21:06:40 +0800 Subject: [PATCH 17/18] Solve the pytest fail ``` import file mismatch: imported module 'test_utils' has this __file__ attribute: /home/runner/work/dvc/dvc/tests/unit/repo/experiments/test_utils.py which is not the same as the test file we want to collect: /home/runner/work/dvc/dvc/tests/unit/stage/test_utils.py HINT: remove __pycache__ / .pyc files and/or use a unique basename for your test file modules ``` --- tests/unit/repo/experiments/__init__.py | 0 tests/unit/stage/__init__.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/unit/repo/experiments/__init__.py create mode 100644 tests/unit/stage/__init__.py diff --git a/tests/unit/repo/experiments/__init__.py b/tests/unit/repo/experiments/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/stage/__init__.py b/tests/unit/stage/__init__.py new file mode 100644 index 0000000000..e69de29bb2 From 95fbb7f9fc05901cd010d4527504e214521759f1 Mon Sep 17 00:00:00 2001 From: Gao Date: Sat, 4 Sep 2021 17:02:45 +0800 Subject: [PATCH 18/18] Update dvc/command/experiments.py Co-authored-by: Dave Berenbaum --- dvc/command/experiments.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dvc/command/experiments.py b/dvc/command/experiments.py index 0cb4718d15..10c50691f6 100644 --- a/dvc/command/experiments.py +++ b/dvc/command/experiments.py @@ -1250,7 +1250,7 @@ def add_parser(subparsers, parent_parser): action="store_true", help="Remove all committed experiments.", ) - experiments_remove_parser.add_argument( + remove_group.add_argument( "-g", "--git-remote", metavar="",