From 3c570a2177a05bd82cc665a87f66e3de285d8b51 Mon Sep 17 00:00:00 2001 From: skshetry <18718008+skshetry@users.noreply.github.com> Date: Thu, 6 Apr 2023 09:54:39 +0545 Subject: [PATCH] exp push: improve ui (#9308) * exp push: improve ui * fix misc suggestions - show url even when the refs are up-to-date. - show url at the end - add Git remote information in up-to-date UI message. * fix diverged exps message --- dvc/commands/experiments/push.py | 67 ++++++++++++++++++++------- dvc/repo/experiments/push.py | 42 ++++++++++------- tests/func/experiments/test_remote.py | 3 +- 3 files changed, 76 insertions(+), 36 deletions(-) diff --git a/dvc/commands/experiments/push.py b/dvc/commands/experiments/push.py index 239971a371..b76ad8d8a0 100644 --- a/dvc/commands/experiments/push.py +++ b/dvc/commands/experiments/push.py @@ -1,5 +1,6 @@ import argparse import logging +from typing import Any, Dict from dvc.cli import completion from dvc.cli.command import CmdBase @@ -18,34 +19,64 @@ def raise_error_if_all_disabled(self): "`--rev` or `--all-commits` flag." ) - def run(self): + @staticmethod + def log_result(result: Dict[str, Any], remote: str): from dvc.utils import humanize - self.raise_error_if_all_disabled() + def join_exps(exps): + return humanize.join([f"[bold]{e}[/]" for e in exps]) - result = self.repo.experiments.push( - self.args.git_remote, - self.args.experiment, - all_commits=self.args.all_commits, - rev=self.args.rev, - num=self.args.num, - force=self.args.force, - push_cache=self.args.push_cache, - dvc_remote=self.args.dvc_remote, - jobs=self.args.jobs, - run_cache=self.args.run_cache, - ) - if pushed_exps := result.get("pushed"): + if diverged_exps := result.get("diverged"): + exps = join_exps(diverged_exps) + ui.error_write( + f"[yellow]Local experiment {exps} has diverged " + "from remote experiment with the same name.\n" + "To override the remote experiment re-run with '--force'.", + styled=True, + ) + if uptodate_exps := result.get("up_to_date"): + exps = join_exps(uptodate_exps) + verb = "are" if len(uptodate_exps) > 1 else "is" ui.write( - f"Pushed experiment {humanize.join(map(repr, pushed_exps))} " - f"to Git remote {self.args.git_remote!r}." + f"Experiment {exps} {verb} up to date on Git remote {remote!r}.", + styled=True, ) - else: + if pushed_exps := result.get("success"): + exps = join_exps(pushed_exps) + ui.write(f"Pushed experiment {exps} to Git remote {remote!r}.", styled=True) + if not uptodate_exps and not pushed_exps: ui.write("No experiments to push.") + if uploaded := result.get("uploaded"): + stats = {"uploaded": uploaded} + ui.write(humanize.get_summary(stats.items())) + if project_url := result.get("url"): ui.write("[yellow]View your experiments at", project_url, styled=True) + def run(self): + from dvc.repo.experiments.push import UploadError + + self.raise_error_if_all_disabled() + + try: + result = self.repo.experiments.push( + self.args.git_remote, + self.args.experiment, + all_commits=self.args.all_commits, + rev=self.args.rev, + num=self.args.num, + force=self.args.force, + push_cache=self.args.push_cache, + dvc_remote=self.args.dvc_remote, + jobs=self.args.jobs, + run_cache=self.args.run_cache, + ) + except UploadError as e: + self.log_result(e.result, self.args.git_remote) + raise + + self.log_result(result, self.args.git_remote) if not self.args.push_cache: ui.write( "To push cached outputs", diff --git a/dvc/repo/experiments/push.py b/dvc/repo/experiments/push.py index 160970c472..5be647f6bc 100644 --- a/dvc/repo/experiments/push.py +++ b/dvc/repo/experiments/push.py @@ -14,10 +14,10 @@ from funcy import compact, group_by from scmrepo.git.backend.base import SyncStatus +from dvc.exceptions import DvcException from dvc.repo import locked from dvc.repo.scm_context import scm_context from dvc.scm import Git, TqdmGit, iter_revs -from dvc.ui import ui from dvc.utils import env2bool from .exceptions import UnresolvedExpNamesError @@ -30,6 +30,12 @@ logger = logging.getLogger(__name__) +class UploadError(DvcException): + def __init__(self, msg, result): + self.result = result + super().__init__(msg) + + def notify_refs_to_studio( repo: "Repo", git_remote: str, **refs: List[str] ) -> Optional[str]: @@ -104,23 +110,25 @@ def push( # noqa: C901 exp_ref_set.update(ref_info_list) push_result = _push(repo, git_remote, exp_ref_set, force) - if push_result[SyncStatus.DIVERGED]: - diverged_refs = [ref.name for ref in push_result[SyncStatus.DIVERGED]] - ui.warn( - f"Local experiment '{diverged_refs}' has diverged from remote " - "experiment with the same name. To override the remote experiment " - "re-run with '--force'." - ) + + refs = { + status.name.lower(): [ref.name for ref in ref_list] + for status, ref_list in push_result.items() + } + result: Dict[str, Any] = {**refs, "uploaded": 0} + + pushed_refs_info = ( + push_result[SyncStatus.UP_TO_DATE] + push_result[SyncStatus.SUCCESS] + ) if push_cache: - push_cache_ref = ( - push_result[SyncStatus.UP_TO_DATE] + push_result[SyncStatus.SUCCESS] - ) - _push_cache(repo, push_cache_ref, **kwargs) + try: + result["uploaded"] = _push_cache(repo, pushed_refs_info, **kwargs) + except Exception as exc: # noqa: BLE001, pylint: disable=broad-except + raise UploadError("failed to push cache", result) from exc - refs = push_result[SyncStatus.SUCCESS] - pushed_refs = [str(r) for r in refs] + pushed_refs = [str(r) for r in pushed_refs_info] url = notify_refs_to_studio(repo, git_remote, pushed=pushed_refs) - return {"pushed": [ref.name for ref in refs], "url": url} + return {**result, "url": url} def _push( @@ -161,10 +169,10 @@ def _push_cache( dvc_remote: Optional[str] = None, jobs: Optional[int] = None, run_cache: bool = False, -): +) -> int: if isinstance(refs, ExpRefInfo): refs = [refs] assert isinstance(repo.scm, Git) revs = list(exp_commits(repo.scm, refs)) logger.debug("dvc push experiment '%s'", refs) - repo.push(jobs=jobs, remote=dvc_remote, run_cache=run_cache, revs=revs) + return repo.push(jobs=jobs, remote=dvc_remote, run_cache=run_cache, revs=revs) diff --git a/tests/func/experiments/test_remote.py b/tests/func/experiments/test_remote.py index c9df12ca04..2281db83f5 100644 --- a/tests/func/experiments/test_remote.py +++ b/tests/func/experiments/test_remote.py @@ -77,8 +77,9 @@ def test_push_diverged(tmp_dir, scm, dvc, git_upstream, exp_stage): git_upstream.tmp_dir.scm.set_ref(str(ref_info), remote_rev) assert dvc.experiments.push(git_upstream.remote, [ref_info.name]) == { - "pushed": [], + "diverged": [ref_info.name], "url": None, + "uploaded": 0, } assert git_upstream.tmp_dir.scm.get_ref(str(ref_info)) == remote_rev