Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up remotes's exps #6471

Merged
merged 18 commits into from
Sep 5, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion dvc/command/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,7 @@ def run(self):
exp_names=self.args.experiment,
queue=self.args.queue,
clear_all=self.args.all,
remote=self.args.git_remote,
)

return 0
Expand Down Expand Up @@ -1231,7 +1232,7 @@ def add_parser(subparsers, parent_parser):
)
experiments_pull_parser.set_defaults(func=CmdExperimentsPull)

EXPERIMENTS_REMOVE_HELP = "Remove local experiments."
EXPERIMENTS_REMOVE_HELP = "Remove experiments."
experiments_remove_parser = experiments_subparsers.add_parser(
"remove",
parents=[parent_parser],
Expand All @@ -1249,6 +1250,13 @@ def add_parser(subparsers, parent_parser):
action="store_true",
help="Remove all committed experiments.",
)
experiments_remove_parser.add_argument(
karajan1001 marked this conversation as resolved.
Show resolved Hide resolved
"-g",
"--git-remote",
metavar="<git_remote>",
karajan1001 marked this conversation as resolved.
Show resolved Hide resolved
help="Name or URL of the Git remote to delete the experiment "
"references",
karajan1001 marked this conversation as resolved.
Show resolved Hide resolved
)
experiments_remove_parser.add_argument(
"experiment",
nargs="*",
Expand Down
3 changes: 0 additions & 3 deletions dvc/repo/experiments/pull.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,6 @@ def on_diverged(refname: str, rev: str) -> bool:


def _get_exp_ref(repo, git_remote, exp_name):
if exp_name.startswith("refs/"):
return exp_name

karajan1001 marked this conversation as resolved.
Show resolved Hide resolved
exp_refs = list(remote_exp_refs_by_name(repo.scm, git_remote, exp_name))
if not exp_refs:
raise InvalidArgumentError(
Expand Down
3 changes: 0 additions & 3 deletions dvc/repo/experiments/push.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,6 @@ def on_diverged(refname: str, rev: str) -> bool:


def _get_exp_ref(repo, exp_name: str) -> ExpRefInfo:
if exp_name.startswith("refs/"):
return ExpRefInfo.from_ref(exp_name)

exp_refs = list(exp_refs_by_name(repo.scm, exp_name))
if not exp_refs:
raise InvalidArgumentError(
Expand Down
55 changes: 36 additions & 19 deletions dvc/repo/experiments/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,13 @@
from dvc.repo.scm_context import scm_context
from dvc.scm.base import RevError

from .base import EXPS_NAMESPACE, ExpRefInfo
from .utils import exp_refs, exp_refs_by_name, remove_exp_refs
from .base import ExpRefInfo
from .utils import (
exp_refs,
exp_refs_by_name,
remote_exp_refs_by_name,
remove_exp_refs,
)

logger = logging.getLogger(__name__)

Expand All @@ -19,6 +24,7 @@ def remove(
exp_names=None,
queue=False,
clear_all=False,
remote=None,
**kwargs,
):
if not any([exp_names, queue, clear_all]):
Expand All @@ -31,13 +37,7 @@ def remove(
removed += _clear_all(repo)

if exp_names:
remained = _remove_commited_exps(repo, exp_names)
remained = _remove_queued_exps(repo, remained)
if remained:
raise InvalidArgumentError(
"'{}' is not a valid experiment".format(";".join(remained))
)
removed += len(exp_names) - len(remained)
removed += _remove_exp_by_names(repo, remote, exp_names)
return removed


Expand Down Expand Up @@ -67,15 +67,16 @@ def _get_exp_stash_index(repo, ref_or_rev: str) -> Optional[int]:
return None


def _get_exp_ref(repo, exp_name: str) -> Optional[ExpRefInfo]:
def _get_exp_ref(repo, remote: str, exp_name: str) -> Optional[ExpRefInfo]:
cur_rev = repo.scm.get_rev()
if exp_name.startswith(EXPS_NAMESPACE):
if repo.scm.get_ref(exp_name):
return ExpRefInfo.from_ref(exp_name)
else:
if not remote:
exp_ref_list = list(exp_refs_by_name(repo.scm, exp_name))
if exp_ref_list:
return _get_ref(exp_ref_list, exp_name, cur_rev)
else:
exp_ref_list = list(
remote_exp_refs_by_name(repo.scm, remote, exp_name)
)
if exp_ref_list:
return _get_ref(exp_ref_list, exp_name, cur_rev)
return None


Expand All @@ -96,17 +97,22 @@ def _get_ref(ref_infos, name, cur_rev) -> Optional[ExpRefInfo]:
return ref_infos[0]


def _remove_commited_exps(repo, refs: List[str]) -> List[str]:
def _remove_commited_exps(repo, remote: str, refs: List[str]) -> List[str]:
karajan1001 marked this conversation as resolved.
Show resolved Hide resolved
remain_list = []
remove_list = []
for ref in refs:
ref_info = _get_exp_ref(repo, ref)
ref_info = _get_exp_ref(repo, remote, ref)

if ref_info:
remove_list.append(ref_info)
else:
remain_list.append(ref)
if remove_list:
remove_exp_refs(repo.scm, remove_list)
if not remote:
remove_exp_refs(repo.scm, remove_list)
else:
for ref_info in remove_list:
repo.scm.push_refspec(remote, None, str(ref_info))
return remain_list


Expand All @@ -119,3 +125,14 @@ def _remove_queued_exps(repo, refs_or_revs: List[str]) -> List[str]:
else:
repo.experiments.stash.drop(stash_index)
return remain_list


def _remove_exp_by_names(repo, remote, exp_names: List[str]) -> int:
remained = _remove_commited_exps(repo, remote, exp_names)
if not remote:
remained = _remove_queued_exps(repo, remained)
if remained:
raise InvalidArgumentError(
"'{}' is not a valid experiment".format(";".join(remained))
)
return len(exp_names) - len(remained)
4 changes: 2 additions & 2 deletions dvc/repo/experiments/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def exp_refs_by_name(
) -> Generator["ExpRefInfo", None, None]:
"""Iterate over all experiment refs matching the specified name."""
for ref_info in exp_refs(scm):
if ref_info.name == name:
if ref_info.name == name or str(ref_info) == name:
karajan1001 marked this conversation as resolved.
Show resolved Hide resolved
yield ref_info


Expand Down Expand Up @@ -63,7 +63,7 @@ def remote_exp_refs_by_name(
) -> Generator["ExpRefInfo", None, None]:
"""Iterate over all remote experiment refs matching the specified name."""
for ref_info in remote_exp_refs(scm, url):
if ref_info.name == name:
if ref_info.name == name or str(ref_info) == name:
yield ref_info


Expand Down
4 changes: 3 additions & 1 deletion dvc/scm/git/backend/dulwich.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,9 +399,11 @@ def push_refspec(
) from exc

def update_refs(refs):
from dulwich.objects import ZERO_SHA

new_refs = {}
for ref, value in zip(dest_refs, values):
if ref in refs:
if ref in refs and value != ZERO_SHA:
karajan1001 marked this conversation as resolved.
Show resolved Hide resolved
local_sha = self.repo.refs[ref]
remote_sha = refs[ref]
try:
Expand Down
26 changes: 26 additions & 0 deletions tests/func/experiments/test_remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,29 @@ def test_remove_all(tmp_dir, scm, dvc, exp_stage):
assert len(dvc.experiments.stash) == 2
assert scm.get_ref(str(ref_info2)) is None
assert scm.get_ref(str(ref_info1)) is None


@pytest.mark.parametrize("use_url", [True, False])
def test_remove_remote(tmp_dir, scm, dvc, exp_stage, git_upstream, use_url):
remote = git_upstream.url if use_url else git_upstream.remote

ref_info_list = []
exp_list = []
for i in range(3):
results = dvc.experiments.run(
exp_stage.addressing, params=[f"foo={i}"]
)
exp = first(results)
exp_list.append(exp)
ref_info = first(exp_refs_by_rev(scm, exp))
ref_info_list.append(ref_info)
dvc.experiments.push(remote, ref_info.name)
assert git_upstream.scm.get_ref(str(ref_info)) == exp

dvc.experiments.remove(
remote=remote, exp_names=[str(ref_info_list[0]), ref_info_list[1].name]
)

assert git_upstream.scm.get_ref(str(ref_info_list[0])) is None
assert git_upstream.scm.get_ref(str(ref_info_list[1])) is None
assert git_upstream.scm.get_ref(str(ref_info_list[2])) == exp_list[2]
17 changes: 10 additions & 7 deletions tests/unit/command/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,15 +253,17 @@ def test_experiments_pull(dvc, scm, mocker):


@pytest.mark.parametrize(
"queue,clear_all",
[(True, False), (False, True)],
"queue,clear_all,remote",
[(True, False, None), (False, True, None), (False, False, True)],
)
def test_experiments_remove(dvc, scm, mocker, queue, clear_all):
def test_experiments_remove(dvc, scm, mocker, queue, clear_all, remote):
if queue:
args = "--queue"
args = ["--queue"]
if clear_all:
args = "--all"
cli_args = parse_args(["experiments", "remove", args])
args = ["--all"]
if remote:
args = ["--git-remote", "myremote", "exp-123", "exp-234"]
cli_args = parse_args(["experiments", "remove"] + args)
assert cli_args.func == CmdExperimentsRemove

cmd = cli_args.func(cli_args)
Expand All @@ -270,7 +272,8 @@ def test_experiments_remove(dvc, scm, mocker, queue, clear_all):
assert cmd.run() == 0
m.assert_called_once_with(
cmd.repo,
exp_names=[],
exp_names=["exp-123", "exp-234"] if remote else [],
queue=queue,
clear_all=clear_all,
remote="myremote" if remote else None,
)