Skip to content

Commit

Permalink
exp push: improve ui (#9308)
Browse files Browse the repository at this point in the history
* exp push: improve ui

* fix misc suggestions
- show url even when the refs are up-to-date.
- show url at the end
- add Git remote information in up-to-date UI message.

* fix diverged exps message
  • Loading branch information
skshetry authored Apr 6, 2023
1 parent 3312a90 commit 3c570a2
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 36 deletions.
67 changes: 49 additions & 18 deletions dvc/commands/experiments/push.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import logging
from typing import Any, Dict

from dvc.cli import completion
from dvc.cli.command import CmdBase
Expand All @@ -18,34 +19,64 @@ def raise_error_if_all_disabled(self):
"`--rev` or `--all-commits` flag."
)

def run(self):
@staticmethod
def log_result(result: Dict[str, Any], remote: str):
from dvc.utils import humanize

self.raise_error_if_all_disabled()
def join_exps(exps):
return humanize.join([f"[bold]{e}[/]" for e in exps])

result = self.repo.experiments.push(
self.args.git_remote,
self.args.experiment,
all_commits=self.args.all_commits,
rev=self.args.rev,
num=self.args.num,
force=self.args.force,
push_cache=self.args.push_cache,
dvc_remote=self.args.dvc_remote,
jobs=self.args.jobs,
run_cache=self.args.run_cache,
)
if pushed_exps := result.get("pushed"):
if diverged_exps := result.get("diverged"):
exps = join_exps(diverged_exps)
ui.error_write(
f"[yellow]Local experiment {exps} has diverged "
"from remote experiment with the same name.\n"
"To override the remote experiment re-run with '--force'.",
styled=True,
)
if uptodate_exps := result.get("up_to_date"):
exps = join_exps(uptodate_exps)
verb = "are" if len(uptodate_exps) > 1 else "is"
ui.write(
f"Pushed experiment {humanize.join(map(repr, pushed_exps))} "
f"to Git remote {self.args.git_remote!r}."
f"Experiment {exps} {verb} up to date on Git remote {remote!r}.",
styled=True,
)
else:
if pushed_exps := result.get("success"):
exps = join_exps(pushed_exps)
ui.write(f"Pushed experiment {exps} to Git remote {remote!r}.", styled=True)
if not uptodate_exps and not pushed_exps:
ui.write("No experiments to push.")

if uploaded := result.get("uploaded"):
stats = {"uploaded": uploaded}
ui.write(humanize.get_summary(stats.items()))

if project_url := result.get("url"):
ui.write("[yellow]View your experiments at", project_url, styled=True)

def run(self):
from dvc.repo.experiments.push import UploadError

self.raise_error_if_all_disabled()

try:
result = self.repo.experiments.push(
self.args.git_remote,
self.args.experiment,
all_commits=self.args.all_commits,
rev=self.args.rev,
num=self.args.num,
force=self.args.force,
push_cache=self.args.push_cache,
dvc_remote=self.args.dvc_remote,
jobs=self.args.jobs,
run_cache=self.args.run_cache,
)
except UploadError as e:
self.log_result(e.result, self.args.git_remote)
raise

self.log_result(result, self.args.git_remote)
if not self.args.push_cache:
ui.write(
"To push cached outputs",
Expand Down
42 changes: 25 additions & 17 deletions dvc/repo/experiments/push.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
from funcy import compact, group_by
from scmrepo.git.backend.base import SyncStatus

from dvc.exceptions import DvcException
from dvc.repo import locked
from dvc.repo.scm_context import scm_context
from dvc.scm import Git, TqdmGit, iter_revs
from dvc.ui import ui
from dvc.utils import env2bool

from .exceptions import UnresolvedExpNamesError
Expand All @@ -30,6 +30,12 @@
logger = logging.getLogger(__name__)


class UploadError(DvcException):
def __init__(self, msg, result):
self.result = result
super().__init__(msg)


def notify_refs_to_studio(
repo: "Repo", git_remote: str, **refs: List[str]
) -> Optional[str]:
Expand Down Expand Up @@ -104,23 +110,25 @@ def push( # noqa: C901
exp_ref_set.update(ref_info_list)

push_result = _push(repo, git_remote, exp_ref_set, force)
if push_result[SyncStatus.DIVERGED]:
diverged_refs = [ref.name for ref in push_result[SyncStatus.DIVERGED]]
ui.warn(
f"Local experiment '{diverged_refs}' has diverged from remote "
"experiment with the same name. To override the remote experiment "
"re-run with '--force'."
)

refs = {
status.name.lower(): [ref.name for ref in ref_list]
for status, ref_list in push_result.items()
}
result: Dict[str, Any] = {**refs, "uploaded": 0}

pushed_refs_info = (
push_result[SyncStatus.UP_TO_DATE] + push_result[SyncStatus.SUCCESS]
)
if push_cache:
push_cache_ref = (
push_result[SyncStatus.UP_TO_DATE] + push_result[SyncStatus.SUCCESS]
)
_push_cache(repo, push_cache_ref, **kwargs)
try:
result["uploaded"] = _push_cache(repo, pushed_refs_info, **kwargs)
except Exception as exc: # noqa: BLE001, pylint: disable=broad-except
raise UploadError("failed to push cache", result) from exc

refs = push_result[SyncStatus.SUCCESS]
pushed_refs = [str(r) for r in refs]
pushed_refs = [str(r) for r in pushed_refs_info]
url = notify_refs_to_studio(repo, git_remote, pushed=pushed_refs)
return {"pushed": [ref.name for ref in refs], "url": url}
return {**result, "url": url}


def _push(
Expand Down Expand Up @@ -161,10 +169,10 @@ def _push_cache(
dvc_remote: Optional[str] = None,
jobs: Optional[int] = None,
run_cache: bool = False,
):
) -> int:
if isinstance(refs, ExpRefInfo):
refs = [refs]
assert isinstance(repo.scm, Git)
revs = list(exp_commits(repo.scm, refs))
logger.debug("dvc push experiment '%s'", refs)
repo.push(jobs=jobs, remote=dvc_remote, run_cache=run_cache, revs=revs)
return repo.push(jobs=jobs, remote=dvc_remote, run_cache=run_cache, revs=revs)
3 changes: 2 additions & 1 deletion tests/func/experiments/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,9 @@ def test_push_diverged(tmp_dir, scm, dvc, git_upstream, exp_stage):
git_upstream.tmp_dir.scm.set_ref(str(ref_info), remote_rev)

assert dvc.experiments.push(git_upstream.remote, [ref_info.name]) == {
"pushed": [],
"diverged": [ref_info.name],
"url": None,
"uploaded": 0,
}
assert git_upstream.tmp_dir.scm.get_ref(str(ref_info)) == remote_rev

Expand Down

0 comments on commit 3c570a2

Please sign in to comment.