Skip to content

Commit

Permalink
Clean up remotes's exps (#6471)
Browse files Browse the repository at this point in the history
* Clean up remotes's exps

fix #6006
1. add a new argument `--git-remote` to `dvc exp remove`
2. support remote a special remote exp
3. add tests for it
4. fix  #6421
5. add a test for #6421.
6. extract some functions from the `dvc pull`, `dvc push`, `dvc remove` to utils.
7. add a tests for this new util function `resolve_exp_ref`
8. add `__init__.py` to some test package for the pytest fail

Co-authored-by: Jorge Orpinel <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Dave Berenbaum <[email protected]>
  • Loading branch information
4 people authored Sep 5, 2021
1 parent b99cd0b commit 4cc0633
Show file tree
Hide file tree
Showing 13 changed files with 176 additions and 129 deletions.
9 changes: 8 additions & 1 deletion dvc/command/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,7 @@ def run(self):
exp_names=self.args.experiment,
queue=self.args.queue,
clear_all=self.args.all,
remote=self.args.git_remote,
)

return 0
Expand Down Expand Up @@ -1231,7 +1232,7 @@ def add_parser(subparsers, parent_parser):
)
experiments_pull_parser.set_defaults(func=CmdExperimentsPull)

EXPERIMENTS_REMOVE_HELP = "Remove local experiments."
EXPERIMENTS_REMOVE_HELP = "Remove experiments."
experiments_remove_parser = experiments_subparsers.add_parser(
"remove",
parents=[parent_parser],
Expand All @@ -1249,6 +1250,12 @@ def add_parser(subparsers, parent_parser):
action="store_true",
help="Remove all committed experiments.",
)
remove_group.add_argument(
"-g",
"--git-remote",
metavar="<git_remote>",
help="Name or URL of the Git remote to remove the experiment from",
)
experiments_remove_parser.add_argument(
"experiment",
nargs="*",
Expand Down
35 changes: 6 additions & 29 deletions dvc/repo/experiments/pull.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dvc.repo import locked
from dvc.repo.scm_context import scm_context

from .utils import exp_commits, remote_exp_refs_by_name
from .utils import exp_commits, resolve_exp_ref

logger = logging.getLogger(__name__)

Expand All @@ -14,7 +14,11 @@
def pull(
repo, git_remote, exp_name, *args, force=False, pull_cache=False, **kwargs
):
exp_ref = _get_exp_ref(repo, git_remote, exp_name)
exp_ref = resolve_exp_ref(repo.scm, exp_name, git_remote)
if not exp_ref:
raise InvalidArgumentError(
f"Experiment '{exp_name}' does not exist in '{git_remote}'"
)

def on_diverged(refname: str, rev: str) -> bool:
if repo.scm.get_ref(refname) == rev:
Expand All @@ -35,33 +39,6 @@ def on_diverged(refname: str, rev: str) -> bool:
_pull_cache(repo, exp_ref, **kwargs)


def _get_exp_ref(repo, git_remote, exp_name):
if exp_name.startswith("refs/"):
return exp_name

exp_refs = list(remote_exp_refs_by_name(repo.scm, git_remote, exp_name))
if not exp_refs:
raise InvalidArgumentError(
f"Experiment '{exp_name}' does not exist in '{git_remote}'"
)
if len(exp_refs) > 1:
cur_rev = repo.scm.get_rev()
for info in exp_refs:
if info.baseline_sha == cur_rev:
return info
msg = [
(
f"Ambiguous name '{exp_name}' refers to multiple "
"experiments in '{git_remote}'. Use full refname to pull one "
"of the following:"
),
"",
]
msg.extend([f"\t{info}" for info in exp_refs])
raise InvalidArgumentError("\n".join(msg))
return exp_refs[0]


def _pull_cache(repo, exp_ref, dvc_remote=None, jobs=None, run_cache=False):
revs = list(exp_commits(repo.scm, [exp_ref]))
logger.debug("dvc fetch experiment '%s'", exp_ref)
Expand Down
36 changes: 6 additions & 30 deletions dvc/repo/experiments/push.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

from dvc.exceptions import DvcException, InvalidArgumentError
from dvc.repo import locked
from dvc.repo.experiments.base import ExpRefInfo
from dvc.repo.scm_context import scm_context

from .utils import exp_commits, exp_refs_by_name
from .utils import exp_commits, resolve_exp_ref

logger = logging.getLogger(__name__)

Expand All @@ -21,7 +20,11 @@ def push(
push_cache=False,
**kwargs,
):
exp_ref = _get_exp_ref(repo, exp_name)
exp_ref = resolve_exp_ref(repo.scm, exp_name)
if not exp_ref:
raise InvalidArgumentError(
f"'{exp_name}' is not a valid experiment name"
)

def on_diverged(refname: str, rev: str) -> bool:
if repo.scm.get_ref(refname) == rev:
Expand All @@ -42,33 +45,6 @@ def on_diverged(refname: str, rev: str) -> bool:
_push_cache(repo, exp_ref, **kwargs)


def _get_exp_ref(repo, exp_name: str) -> ExpRefInfo:
if exp_name.startswith("refs/"):
return ExpRefInfo.from_ref(exp_name)

exp_refs = list(exp_refs_by_name(repo.scm, exp_name))
if not exp_refs:
raise InvalidArgumentError(
f"'{exp_name}' is not a valid experiment name"
)
if len(exp_refs) > 1:
cur_rev = repo.scm.get_rev()
for info in exp_refs:
if info.baseline_sha == cur_rev:
return info
msg = [
(
f"Ambiguous name '{exp_name}' refers to multiple "
"experiments. Use full refname to push one of the "
"following:"
),
"",
]
msg.extend([f"\t{info}" for info in exp_refs])
raise InvalidArgumentError("\n".join(msg))
return exp_refs[0]


def _push_cache(repo, exp_ref, dvc_remote=None, jobs=None, run_cache=False):
revs = list(exp_commits(repo.scm, [exp_ref]))
logger.debug("dvc push experiment '%s'", exp_ref)
Expand Down
69 changes: 26 additions & 43 deletions dvc/repo/experiments/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from dvc.repo.scm_context import scm_context
from dvc.scm.base import RevError

from .base import EXPS_NAMESPACE, ExpRefInfo
from .utils import exp_refs, exp_refs_by_name, remove_exp_refs
from .utils import exp_refs, remove_exp_refs, resolve_exp_ref

logger = logging.getLogger(__name__)

Expand All @@ -19,6 +18,7 @@ def remove(
exp_names=None,
queue=False,
clear_all=False,
remote=None,
**kwargs,
):
if not any([exp_names, queue, clear_all]):
Expand All @@ -31,13 +31,7 @@ def remove(
removed += _clear_all(repo)

if exp_names:
remained = _remove_commited_exps(repo, exp_names)
remained = _remove_queued_exps(repo, remained)
if remained:
raise InvalidArgumentError(
"'{}' is not a valid experiment".format(";".join(remained))
)
removed += len(exp_names) - len(remained)
removed += _remove_exp_by_names(repo, remote, exp_names)
return removed


Expand Down Expand Up @@ -67,46 +61,24 @@ def _get_exp_stash_index(repo, ref_or_rev: str) -> Optional[int]:
return None


def _get_exp_ref(repo, exp_name: str) -> Optional[ExpRefInfo]:
cur_rev = repo.scm.get_rev()
if exp_name.startswith(EXPS_NAMESPACE):
if repo.scm.get_ref(exp_name):
return ExpRefInfo.from_ref(exp_name)
else:
exp_ref_list = list(exp_refs_by_name(repo.scm, exp_name))
if exp_ref_list:
return _get_ref(exp_ref_list, exp_name, cur_rev)
return None


def _get_ref(ref_infos, name, cur_rev) -> Optional[ExpRefInfo]:
if len(ref_infos) > 1:
for info in ref_infos:
if info.baseline_sha == cur_rev:
return info
msg = [
(
f"Ambiguous name '{name}' refers to multiple "
"experiments. Use full refname to remove one of "
"the following:"
)
]
msg.extend([f"\t{info}" for info in ref_infos])
raise InvalidArgumentError("\n".join(msg))
return ref_infos[0]


def _remove_commited_exps(repo, refs: List[str]) -> List[str]:
def _remove_commited_exps(
repo, remote: Optional[str], exp_names: List[str]
) -> List[str]:
remain_list = []
remove_list = []
for ref in refs:
ref_info = _get_exp_ref(repo, ref)
for exp_name in exp_names:
ref_info = resolve_exp_ref(repo.scm, exp_name, remote)

if ref_info:
remove_list.append(ref_info)
else:
remain_list.append(ref)
remain_list.append(exp_name)
if remove_list:
remove_exp_refs(repo.scm, remove_list)
if not remote:
remove_exp_refs(repo.scm, remove_list)
else:
for ref_info in remove_list:
repo.scm.push_refspec(remote, None, str(ref_info))
return remain_list


Expand All @@ -119,3 +91,14 @@ def _remove_queued_exps(repo, refs_or_revs: List[str]) -> List[str]:
else:
repo.experiments.stash.drop(stash_index)
return remain_list


def _remove_exp_by_names(repo, remote, exp_names: List[str]) -> int:
remained = _remove_commited_exps(repo, remote, exp_names)
if not remote:
remained = _remove_queued_exps(repo, remained)
if remained:
raise InvalidArgumentError(
"'{}' is not a valid experiment".format(";".join(remained))
)
return len(exp_names) - len(remained)
42 changes: 42 additions & 0 deletions dvc/repo/experiments/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Generator, Iterable, Optional, Set

from dvc.exceptions import InvalidArgumentError
from dvc.scm.git import Git

from .base import (
Expand Down Expand Up @@ -115,3 +116,44 @@ def fix_exp_head(scm: "Git", ref: Optional[str]) -> Optional[str]:
if name == "HEAD" and scm.get_ref(EXEC_BASELINE):
return "".join((EXEC_BASELINE, tail))
return ref


def resolve_exp_ref(
scm, exp_name: str, git_remote: Optional[str] = None
) -> Optional[ExpRefInfo]:
if exp_name.startswith("refs/"):
return ExpRefInfo.from_ref(exp_name)

if git_remote:
exp_ref_list = list(remote_exp_refs_by_name(scm, git_remote, exp_name))
else:
exp_ref_list = list(exp_refs_by_name(scm, exp_name))

if not exp_ref_list:
return None
if len(exp_ref_list) > 1:
cur_rev = scm.get_rev()
for info in exp_ref_list:
if info.baseline_sha == cur_rev:
return info
if git_remote:
msg = [
(
f"Ambiguous name '{exp_name}' refers to multiple "
"experiments. Use full refname to push one of the "
"following:"
),
"",
]
else:
msg = [
(
f"Ambiguous name '{exp_name}' refers to multiple "
f"experiments in '{git_remote}'. Use full refname to pull "
"one of the following:"
),
"",
]
msg.extend([f"\t{info}" for info in exp_ref_list])
raise InvalidArgumentError("\n".join(msg))
return exp_ref_list[0]
4 changes: 3 additions & 1 deletion dvc/scm/git/backend/dulwich.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,9 +407,11 @@ def push_refspec(
) from exc

def update_refs(refs):
from dulwich.objects import ZERO_SHA

new_refs = {}
for ref, value in zip(dest_refs, values):
if ref in refs:
if ref in refs and value != ZERO_SHA:
local_sha = self.repo.refs[ref]
remote_sha = refs[ref]
try:
Expand Down
22 changes: 22 additions & 0 deletions tests/dir_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@
"erepo_dir",
"git_dir",
"git_init",
"git_upstream",
"git_downstream",
]


Expand Down Expand Up @@ -407,3 +409,23 @@ def git_dir(make_tmp_dir):
path = make_tmp_dir("git-erepo", scm=True)
path.scm.commit("init repo")
return path


@pytest.fixture
def git_upstream(tmp_dir, erepo_dir, git_dir, request):
remote = erepo_dir if "dvc" in request.fixturenames else git_dir
url = "file://{}".format(remote.resolve().as_posix())
tmp_dir.scm.gitpython.repo.create_remote("upstream", url)
remote.remote = "upstream"
remote.url = url
return remote


@pytest.fixture
def git_downstream(tmp_dir, erepo_dir, git_dir, request):
remote = erepo_dir if "dvc" in request.fixturenames else git_dir
url = "file://{}".format(tmp_dir.resolve().as_posix())
remote.scm.gitpython.repo.create_remote("upstream", url)
remote.remote = "upstream"
remote.url = url
return remote
18 changes: 0 additions & 18 deletions tests/func/experiments/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,21 +88,3 @@ def checkpoint_stage(tmp_dir, scm, dvc, mocker):
scm.commit("init")
stage.iterations = DEFAULT_ITERATIONS
return stage


@pytest.fixture
def git_upstream(tmp_dir, erepo_dir):
url = "file://{}".format(erepo_dir.resolve().as_posix())
tmp_dir.scm.gitpython.repo.create_remote("upstream", url)
erepo_dir.remote = "upstream"
erepo_dir.url = url
return erepo_dir


@pytest.fixture
def git_downstream(tmp_dir, erepo_dir):
url = "file://{}".format(tmp_dir.resolve().as_posix())
erepo_dir.scm.gitpython.repo.create_remote("upstream", url)
erepo_dir.remote = "upstream"
erepo_dir.url = url
return erepo_dir
Loading

0 comments on commit 4cc0633

Please sign in to comment.