Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up remotes's exps #6471

Merged
merged 18 commits into from
Sep 5, 2021
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion dvc/command/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,7 @@ def run(self):
exp_names=self.args.experiment,
queue=self.args.queue,
clear_all=self.args.all,
remote=self.args.git_remote,
)

return 0
Expand Down Expand Up @@ -1231,7 +1232,7 @@ def add_parser(subparsers, parent_parser):
)
experiments_pull_parser.set_defaults(func=CmdExperimentsPull)

EXPERIMENTS_REMOVE_HELP = "Remove local experiments."
EXPERIMENTS_REMOVE_HELP = "Remove experiments."
experiments_remove_parser = experiments_subparsers.add_parser(
"remove",
parents=[parent_parser],
Expand All @@ -1249,6 +1250,12 @@ def add_parser(subparsers, parent_parser):
action="store_true",
help="Remove all committed experiments.",
)
experiments_remove_parser.add_argument(
karajan1001 marked this conversation as resolved.
Show resolved Hide resolved
"-g",
"--git-remote",
metavar="<git_remote>",
karajan1001 marked this conversation as resolved.
Show resolved Hide resolved
help="Name or URL of the Git remote to remove the experiment from",
)
experiments_remove_parser.add_argument(
"experiment",
nargs="*",
Expand Down
35 changes: 6 additions & 29 deletions dvc/repo/experiments/pull.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dvc.repo import locked
from dvc.repo.scm_context import scm_context

from .utils import exp_commits, remote_exp_refs_by_name
from .utils import exp_commits, resolve_exp_ref

logger = logging.getLogger(__name__)

Expand All @@ -14,7 +14,11 @@
def pull(
repo, git_remote, exp_name, *args, force=False, pull_cache=False, **kwargs
):
exp_ref = _get_exp_ref(repo, git_remote, exp_name)
exp_ref = resolve_exp_ref(repo.scm, exp_name, git_remote)
if not exp_ref:
raise InvalidArgumentError(
f"Experiment '{exp_name}' does not exist in '{git_remote}'"
)

def on_diverged(refname: str, rev: str) -> bool:
if repo.scm.get_ref(refname) == rev:
Expand All @@ -35,33 +39,6 @@ def on_diverged(refname: str, rev: str) -> bool:
_pull_cache(repo, exp_ref, **kwargs)


def _get_exp_ref(repo, git_remote, exp_name):
if exp_name.startswith("refs/"):
return exp_name

karajan1001 marked this conversation as resolved.
Show resolved Hide resolved
exp_refs = list(remote_exp_refs_by_name(repo.scm, git_remote, exp_name))
if not exp_refs:
raise InvalidArgumentError(
f"Experiment '{exp_name}' does not exist in '{git_remote}'"
)
if len(exp_refs) > 1:
cur_rev = repo.scm.get_rev()
for info in exp_refs:
if info.baseline_sha == cur_rev:
return info
msg = [
(
f"Ambiguous name '{exp_name}' refers to multiple "
"experiments in '{git_remote}'. Use full refname to pull one "
"of the following:"
),
"",
]
msg.extend([f"\t{info}" for info in exp_refs])
raise InvalidArgumentError("\n".join(msg))
return exp_refs[0]


def _pull_cache(repo, exp_ref, dvc_remote=None, jobs=None, run_cache=False):
revs = list(exp_commits(repo.scm, [exp_ref]))
logger.debug("dvc fetch experiment '%s'", exp_ref)
Expand Down
36 changes: 6 additions & 30 deletions dvc/repo/experiments/push.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

from dvc.exceptions import DvcException, InvalidArgumentError
from dvc.repo import locked
from dvc.repo.experiments.base import ExpRefInfo
from dvc.repo.scm_context import scm_context

from .utils import exp_commits, exp_refs_by_name
from .utils import exp_commits, resolve_exp_ref

logger = logging.getLogger(__name__)

Expand All @@ -21,7 +20,11 @@ def push(
push_cache=False,
**kwargs,
):
exp_ref = _get_exp_ref(repo, exp_name)
exp_ref = resolve_exp_ref(repo.scm, exp_name)
if not exp_ref:
raise InvalidArgumentError(
f"'{exp_name}' is not a valid experiment name"
)

def on_diverged(refname: str, rev: str) -> bool:
if repo.scm.get_ref(refname) == rev:
Expand All @@ -42,33 +45,6 @@ def on_diverged(refname: str, rev: str) -> bool:
_push_cache(repo, exp_ref, **kwargs)


def _get_exp_ref(repo, exp_name: str) -> ExpRefInfo:
if exp_name.startswith("refs/"):
return ExpRefInfo.from_ref(exp_name)

exp_refs = list(exp_refs_by_name(repo.scm, exp_name))
if not exp_refs:
raise InvalidArgumentError(
f"'{exp_name}' is not a valid experiment name"
)
if len(exp_refs) > 1:
cur_rev = repo.scm.get_rev()
for info in exp_refs:
if info.baseline_sha == cur_rev:
return info
msg = [
(
f"Ambiguous name '{exp_name}' refers to multiple "
"experiments. Use full refname to push one of the "
"following:"
),
"",
]
msg.extend([f"\t{info}" for info in exp_refs])
raise InvalidArgumentError("\n".join(msg))
return exp_refs[0]


def _push_cache(repo, exp_ref, dvc_remote=None, jobs=None, run_cache=False):
revs = list(exp_commits(repo.scm, [exp_ref]))
logger.debug("dvc push experiment '%s'", exp_ref)
Expand Down
69 changes: 26 additions & 43 deletions dvc/repo/experiments/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from dvc.repo.scm_context import scm_context
from dvc.scm.base import RevError

from .base import EXPS_NAMESPACE, ExpRefInfo
from .utils import exp_refs, exp_refs_by_name, remove_exp_refs
from .utils import exp_refs, remove_exp_refs, resolve_exp_ref

logger = logging.getLogger(__name__)

Expand All @@ -19,6 +18,7 @@ def remove(
exp_names=None,
queue=False,
clear_all=False,
remote=None,
**kwargs,
):
if not any([exp_names, queue, clear_all]):
Expand All @@ -31,13 +31,7 @@ def remove(
removed += _clear_all(repo)

if exp_names:
remained = _remove_commited_exps(repo, exp_names)
remained = _remove_queued_exps(repo, remained)
if remained:
raise InvalidArgumentError(
"'{}' is not a valid experiment".format(";".join(remained))
)
removed += len(exp_names) - len(remained)
removed += _remove_exp_by_names(repo, remote, exp_names)
return removed


Expand Down Expand Up @@ -67,46 +61,24 @@ def _get_exp_stash_index(repo, ref_or_rev: str) -> Optional[int]:
return None


def _get_exp_ref(repo, exp_name: str) -> Optional[ExpRefInfo]:
cur_rev = repo.scm.get_rev()
if exp_name.startswith(EXPS_NAMESPACE):
if repo.scm.get_ref(exp_name):
return ExpRefInfo.from_ref(exp_name)
else:
exp_ref_list = list(exp_refs_by_name(repo.scm, exp_name))
if exp_ref_list:
return _get_ref(exp_ref_list, exp_name, cur_rev)
return None


def _get_ref(ref_infos, name, cur_rev) -> Optional[ExpRefInfo]:
if len(ref_infos) > 1:
for info in ref_infos:
if info.baseline_sha == cur_rev:
return info
msg = [
(
f"Ambiguous name '{name}' refers to multiple "
"experiments. Use full refname to remove one of "
"the following:"
)
]
msg.extend([f"\t{info}" for info in ref_infos])
raise InvalidArgumentError("\n".join(msg))
return ref_infos[0]


def _remove_commited_exps(repo, refs: List[str]) -> List[str]:
def _remove_commited_exps(
repo, remote: Optional[str], exp_names: List[str]
) -> List[str]:
remain_list = []
remove_list = []
for ref in refs:
ref_info = _get_exp_ref(repo, ref)
for exp_name in exp_names:
ref_info = resolve_exp_ref(repo.scm, exp_name, remote)

if ref_info:
remove_list.append(ref_info)
else:
remain_list.append(ref)
remain_list.append(exp_name)
if remove_list:
remove_exp_refs(repo.scm, remove_list)
if not remote:
remove_exp_refs(repo.scm, remove_list)
else:
for ref_info in remove_list:
repo.scm.push_refspec(remote, None, str(ref_info))
return remain_list


Expand All @@ -119,3 +91,14 @@ def _remove_queued_exps(repo, refs_or_revs: List[str]) -> List[str]:
else:
repo.experiments.stash.drop(stash_index)
return remain_list


def _remove_exp_by_names(repo, remote, exp_names: List[str]) -> int:
remained = _remove_commited_exps(repo, remote, exp_names)
if not remote:
remained = _remove_queued_exps(repo, remained)
if remained:
raise InvalidArgumentError(
"'{}' is not a valid experiment".format(";".join(remained))
)
return len(exp_names) - len(remained)
44 changes: 43 additions & 1 deletion dvc/repo/experiments/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Generator, Iterable, Optional, Set

from dvc.exceptions import InvalidArgumentError
from dvc.scm.git import Git

from .base import (
Expand Down Expand Up @@ -33,7 +34,7 @@ def exp_refs_by_name(
) -> Generator["ExpRefInfo", None, None]:
"""Iterate over all experiment refs matching the specified name."""
for ref_info in exp_refs(scm):
if ref_info.name == name:
if ref_info.name == name or str(ref_info) == name:
karajan1001 marked this conversation as resolved.
Show resolved Hide resolved
yield ref_info


Expand Down Expand Up @@ -115,3 +116,44 @@ def fix_exp_head(scm: "Git", ref: Optional[str]) -> Optional[str]:
if name == "HEAD" and scm.get_ref(EXEC_BASELINE):
return "".join((EXEC_BASELINE, tail))
return ref


def resolve_exp_ref(
scm, exp_name: str, git_remote: Optional[str] = None
) -> Optional[ExpRefInfo]:
if exp_name.startswith("refs/"):
return ExpRefInfo.from_ref(exp_name)

if git_remote:
exp_ref_list = list(remote_exp_refs_by_name(scm, git_remote, exp_name))
else:
exp_ref_list = list(exp_refs_by_name(scm, exp_name))

if not exp_ref_list:
return None
if len(exp_ref_list) > 1:
cur_rev = scm.get_rev()
for info in exp_ref_list:
if info.baseline_sha == cur_rev:
return info
if git_remote:
msg = [
(
f"Ambiguous name '{exp_name}' refers to multiple "
"experiments. Use full refname to push one of the "
"following:"
),
"",
]
else:
msg = [
(
f"Ambiguous name '{exp_name}' refers to multiple "
f"experiments in '{git_remote}'. Use full refname to pull "
"one of the following:"
),
"",
]
msg.extend([f"\t{info}" for info in exp_ref_list])
raise InvalidArgumentError("\n".join(msg))
return exp_ref_list[0]
4 changes: 3 additions & 1 deletion dvc/scm/git/backend/dulwich.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,9 +399,11 @@ def push_refspec(
) from exc

def update_refs(refs):
from dulwich.objects import ZERO_SHA

new_refs = {}
for ref, value in zip(dest_refs, values):
if ref in refs:
if ref in refs and value != ZERO_SHA:
karajan1001 marked this conversation as resolved.
Show resolved Hide resolved
local_sha = self.repo.refs[ref]
remote_sha = refs[ref]
try:
Expand Down
4 changes: 2 additions & 2 deletions tests/func/experiments/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def git_upstream(tmp_dir, erepo_dir):
@pytest.fixture
def git_downstream(tmp_dir, erepo_dir):
url = "file://{}".format(tmp_dir.resolve().as_posix())
erepo_dir.scm.gitpython.repo.create_remote("upstream", url)
erepo_dir.remote = "upstream"
erepo_dir.scm.gitpython.repo.create_remote("downstream", url)
erepo_dir.remote = "downstream"
karajan1001 marked this conversation as resolved.
Show resolved Hide resolved
erepo_dir.url = url
return erepo_dir
26 changes: 26 additions & 0 deletions tests/func/experiments/test_remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,29 @@ def test_remove_all(tmp_dir, scm, dvc, exp_stage):
assert len(dvc.experiments.stash) == 2
assert scm.get_ref(str(ref_info2)) is None
assert scm.get_ref(str(ref_info1)) is None


@pytest.mark.parametrize("use_url", [True, False])
def test_remove_remote(tmp_dir, scm, dvc, exp_stage, git_upstream, use_url):
remote = git_upstream.url if use_url else git_upstream.remote

ref_info_list = []
exp_list = []
for i in range(3):
results = dvc.experiments.run(
exp_stage.addressing, params=[f"foo={i}"]
)
exp = first(results)
exp_list.append(exp)
ref_info = first(exp_refs_by_rev(scm, exp))
ref_info_list.append(ref_info)
dvc.experiments.push(remote, ref_info.name)
assert git_upstream.scm.get_ref(str(ref_info)) == exp

dvc.experiments.remove(
remote=remote, exp_names=[str(ref_info_list[0]), ref_info_list[1].name]
)

assert git_upstream.scm.get_ref(str(ref_info_list[0])) is None
assert git_upstream.scm.get_ref(str(ref_info_list[1])) is None
assert git_upstream.scm.get_ref(str(ref_info_list[2])) == exp_list[2]
17 changes: 10 additions & 7 deletions tests/unit/command/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,15 +253,17 @@ def test_experiments_pull(dvc, scm, mocker):


@pytest.mark.parametrize(
"queue,clear_all",
[(True, False), (False, True)],
"queue,clear_all,remote",
[(True, False, None), (False, True, None), (False, False, True)],
)
def test_experiments_remove(dvc, scm, mocker, queue, clear_all):
def test_experiments_remove(dvc, scm, mocker, queue, clear_all, remote):
if queue:
args = "--queue"
args = ["--queue"]
if clear_all:
args = "--all"
cli_args = parse_args(["experiments", "remove", args])
args = ["--all"]
if remote:
args = ["--git-remote", "myremote", "exp-123", "exp-234"]
cli_args = parse_args(["experiments", "remove"] + args)
assert cli_args.func == CmdExperimentsRemove

cmd = cli_args.func(cli_args)
Expand All @@ -270,7 +272,8 @@ def test_experiments_remove(dvc, scm, mocker, queue, clear_all):
assert cmd.run() == 0
m.assert_called_once_with(
cmd.repo,
exp_names=[],
exp_names=["exp-123", "exp-234"] if remote else [],
queue=queue,
clear_all=clear_all,
remote="myremote" if remote else None,
)
Loading