Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

exp push: improve ui #9308

Merged
merged 3 commits into from
Apr 6, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 46 additions & 19 deletions dvc/commands/experiments/push.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import logging
from typing import Any, Dict

from dvc.cli import completion
from dvc.cli.command import CmdBase
Expand All @@ -18,34 +19,60 @@ def raise_error_if_all_disabled(self):
"`--rev` or `--all-commits` flag."
)

def run(self):
@staticmethod
def log_result(result: Dict[str, Any], remote: str):
from dvc.utils import humanize

self.raise_error_if_all_disabled()
def join_exps(exps):
return humanize.join([f"[bold]{e}[/]" for e in exps])

result = self.repo.experiments.push(
self.args.git_remote,
self.args.experiment,
all_commits=self.args.all_commits,
rev=self.args.rev,
num=self.args.num,
force=self.args.force,
push_cache=self.args.push_cache,
dvc_remote=self.args.dvc_remote,
jobs=self.args.jobs,
run_cache=self.args.run_cache,
)
if pushed_exps := result.get("pushed"):
ui.write(
f"Pushed experiment {humanize.join(map(repr, pushed_exps))} "
f"to Git remote {self.args.git_remote!r}."
if diverged_exps := result.get("diverged"):
exps = join_exps(diverged_exps)
ui.warn(
f"Local experiment {exps} has diverged "
"from remote experiment with the same name. "
"To override the remote experiment re-run with '--force'."
)
else:
if uptodate_exps := result.get("up_to_date"):
skshetry marked this conversation as resolved.
Show resolved Hide resolved
verb = "are" if len(uptodate_exps) > 1 else "is"
exps = join_exps(uptodate_exps)
ui.write(f"Experiment {exps} {verb} up to date.", styled=True)
skshetry marked this conversation as resolved.
Show resolved Hide resolved
if pushed_exps := result.get("success"):
exps = join_exps(pushed_exps)
ui.write(f"Pushed experiment {exps} to Git remote {remote!r}.", styled=True)
if not uptodate_exps and not pushed_exps:
ui.write("No experiments to push.")

if project_url := result.get("url"):
ui.write("[yellow]View your experiments at", project_url, styled=True)
skshetry marked this conversation as resolved.
Show resolved Hide resolved

if uploaded := result.get("uploaded"):
stats = {"uploaded": uploaded}
ui.write(humanize.get_summary(stats.items()))

def run(self):
from dvc.repo.experiments.push import UploadError

self.raise_error_if_all_disabled()

try:
result = self.repo.experiments.push(
self.args.git_remote,
self.args.experiment,
all_commits=self.args.all_commits,
rev=self.args.rev,
num=self.args.num,
force=self.args.force,
push_cache=self.args.push_cache,
dvc_remote=self.args.dvc_remote,
jobs=self.args.jobs,
run_cache=self.args.run_cache,
)
except UploadError as e:
self.log_result(e.result, self.args.git_remote)
raise

self.log_result(result, self.args.git_remote)
if not self.args.push_cache:
ui.write(
"To push cached outputs",
Expand Down
37 changes: 23 additions & 14 deletions dvc/repo/experiments/push.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
from funcy import compact, group_by
from scmrepo.git.backend.base import SyncStatus

from dvc.exceptions import DvcException
from dvc.repo import locked
from dvc.repo.scm_context import scm_context
from dvc.scm import Git, TqdmGit, iter_revs
from dvc.ui import ui
from dvc.utils import env2bool

from .exceptions import UnresolvedExpNamesError
Expand All @@ -30,6 +30,12 @@
logger = logging.getLogger(__name__)


class UploadError(DvcException):
def __init__(self, msg, result):
self.result = result
super().__init__(msg)


def notify_refs_to_studio(
repo: "Repo", git_remote: str, **refs: List[str]
) -> Optional[str]:
Expand Down Expand Up @@ -104,23 +110,26 @@ def push( # noqa: C901
exp_ref_set.update(ref_info_list)

push_result = _push(repo, git_remote, exp_ref_set, force)
if push_result[SyncStatus.DIVERGED]:
skshetry marked this conversation as resolved.
Show resolved Hide resolved
diverged_refs = [ref.name for ref in push_result[SyncStatus.DIVERGED]]
ui.warn(
f"Local experiment '{diverged_refs}' has diverged from remote "
"experiment with the same name. To override the remote experiment "
"re-run with '--force'."
)

refs = {
status.name.lower(): [ref.name for ref in ref_list]
for status, ref_list in push_result.items()
}
result: Dict[str, Any] = {**refs, "uploaded": 0}

if push_cache:
push_cache_ref = (
push_result[SyncStatus.UP_TO_DATE] + push_result[SyncStatus.SUCCESS]
)
_push_cache(repo, push_cache_ref, **kwargs)

refs = push_result[SyncStatus.SUCCESS]
pushed_refs = [str(r) for r in refs]
try:
result["uploaded"] = _push_cache(repo, push_cache_ref, **kwargs)
except Exception as exc: # noqa: BLE001, pylint: disable=broad-except
raise UploadError("failed to push cache", result) from exc

pushed_refs = [str(r) for r in push_result[SyncStatus.SUCCESS]]
url = notify_refs_to_studio(repo, git_remote, pushed=pushed_refs)
return {"pushed": [ref.name for ref in refs], "url": url}
return {**result, "url": url}


def _push(
Expand Down Expand Up @@ -161,10 +170,10 @@ def _push_cache(
dvc_remote: Optional[str] = None,
jobs: Optional[int] = None,
run_cache: bool = False,
):
) -> int:
if isinstance(refs, ExpRefInfo):
refs = [refs]
assert isinstance(repo.scm, Git)
revs = list(exp_commits(repo.scm, refs))
logger.debug("dvc push experiment '%s'", refs)
repo.push(jobs=jobs, remote=dvc_remote, run_cache=run_cache, revs=revs)
return repo.push(jobs=jobs, remote=dvc_remote, run_cache=run_cache, revs=revs)
3 changes: 2 additions & 1 deletion tests/func/experiments/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,9 @@ def test_push_diverged(tmp_dir, scm, dvc, git_upstream, exp_stage):
git_upstream.tmp_dir.scm.set_ref(str(ref_info), remote_rev)

assert dvc.experiments.push(git_upstream.remote, [ref_info.name]) == {
"pushed": [],
"diverged": [ref_info.name],
"url": None,
"uploaded": 0,
}
assert git_upstream.tmp_dir.scm.get_ref(str(ref_info)) == remote_rev

Expand Down