Skip to content

Commit

Permalink
exp push: improve ui
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry committed Apr 5, 2023
1 parent cf52efb commit ea5c084
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 34 deletions.
64 changes: 45 additions & 19 deletions dvc/commands/experiments/push.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import logging
from typing import Any, Dict

from dvc.cli import completion
from dvc.cli.command import CmdBase
Expand All @@ -18,34 +19,59 @@ def raise_error_if_all_disabled(self):
"`--rev` or `--all-commits` flag."
)

def run(self):
@staticmethod
def log_result(result: Dict[str, Any], remote: str):
from dvc.utils import humanize

self.raise_error_if_all_disabled()
def join_exps(exps):
return humanize.join([f"[bold]{e}[/]" for e in exps])

result = self.repo.experiments.push(
self.args.git_remote,
self.args.experiment,
all_commits=self.args.all_commits,
rev=self.args.rev,
num=self.args.num,
force=self.args.force,
push_cache=self.args.push_cache,
dvc_remote=self.args.dvc_remote,
jobs=self.args.jobs,
run_cache=self.args.run_cache,
)
if pushed_exps := result.get("pushed"):
ui.write(
f"Pushed experiment {humanize.join(map(repr, pushed_exps))} "
f"to Git remote {self.args.git_remote!r}."
if diverged_exps := result.get("diverged"):
exps = join_exps(diverged_exps)
ui.warn(
f"Local experiment {exps} has diverged "
"from remote experiment with the same name. "
"To override the remote experiment re-run with '--force'."
)
else:
if uptodate_exps := result.get("up_to_date"):
verb = "are" if len(uptodate_exps) > 1 else "is"
exps = join_exps(uptodate_exps)
ui.write(f"Experiment {exps} {verb} up to date.", styled=True)
if pushed_exps := result.get("pushed"):
exps = join_exps(pushed_exps)
ui.write(f"Pushed experiment {exps} to Git remote {remote!r}.", styled=True)
if not uptodate_exps and not pushed_exps:
ui.write("No experiments to push.")

if project_url := result.get("url"):
ui.write("[yellow]View your experiments at", project_url, styled=True)

if uploaded := result.get("uploaded"):
ui.write(humanize.get_summary({"pushed": uploaded}))

def run(self):
from dvc.repo.experiments.push import ExpPushError

self.raise_error_if_all_disabled()

try:
result = self.repo.experiments.push(
self.args.git_remote,
self.args.experiment,
all_commits=self.args.all_commits,
rev=self.args.rev,
num=self.args.num,
force=self.args.force,
push_cache=self.args.push_cache,
dvc_remote=self.args.dvc_remote,
jobs=self.args.jobs,
run_cache=self.args.run_cache,
)
except ExpPushError as e:
self.log_result(e.result, self.args.git_remote)
raise

self.log_result(result, self.args.git_remote)
if not self.args.push_cache:
ui.write(
"To push cached outputs",
Expand Down
37 changes: 23 additions & 14 deletions dvc/repo/experiments/push.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
from funcy import compact, group_by
from scmrepo.git.backend.base import SyncStatus

from dvc.exceptions import DvcException, UploadError
from dvc.repo import locked
from dvc.repo.scm_context import scm_context
from dvc.scm import Git, TqdmGit, iter_revs
from dvc.ui import ui
from dvc.utils import env2bool

from .exceptions import UnresolvedExpNamesError
Expand All @@ -30,6 +30,12 @@
logger = logging.getLogger(__name__)


class ExpPushError(DvcException):
def __init__(self, msg, result):
self.result = result
super().__init__(msg)


def notify_refs_to_studio(
repo: "Repo", git_remote: str, **refs: List[str]
) -> Optional[str]:
Expand Down Expand Up @@ -104,23 +110,26 @@ def push( # noqa: C901
exp_ref_set.update(ref_info_list)

push_result = _push(repo, git_remote, exp_ref_set, force)
if push_result[SyncStatus.DIVERGED]:
diverged_refs = [ref.name for ref in push_result[SyncStatus.DIVERGED]]
ui.warn(
f"Local experiment '{diverged_refs}' has diverged from remote "
"experiment with the same name. To override the remote experiment "
"re-run with '--force'."
)

refs = {
status.name.lower(): [ref.name for ref in ref_list]
for status, ref_list in push_result.items()
}
result: Dict[str, Any] = {**refs, "uploaded": 0}

if push_cache:
push_cache_ref = (
push_result[SyncStatus.UP_TO_DATE] + push_result[SyncStatus.SUCCESS]
)
_push_cache(repo, push_cache_ref, **kwargs)

refs = push_result[SyncStatus.SUCCESS]
pushed_refs = [str(r) for r in refs]
try:
result["uploaded"] = _push_cache(repo, push_cache_ref, **kwargs)
except UploadError as exc:
raise ExpPushError("failed to push cache", result) from exc

pushed_refs = [str(r) for r in push_result[SyncStatus.SUCCESS]]
url = notify_refs_to_studio(repo, git_remote, pushed=pushed_refs)
return {"pushed": [ref.name for ref in refs], "url": url}
return {**result, "url": url}


def _push(
Expand Down Expand Up @@ -161,10 +170,10 @@ def _push_cache(
dvc_remote: Optional[str] = None,
jobs: Optional[int] = None,
run_cache: bool = False,
):
) -> int:
if isinstance(refs, ExpRefInfo):
refs = [refs]
assert isinstance(repo.scm, Git)
revs = list(exp_commits(repo.scm, refs))
logger.debug("dvc push experiment '%s'", refs)
repo.push(jobs=jobs, remote=dvc_remote, run_cache=run_cache, revs=revs)
return repo.push(jobs=jobs, remote=dvc_remote, run_cache=run_cache, revs=revs)
3 changes: 2 additions & 1 deletion tests/func/experiments/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,9 @@ def test_push_diverged(tmp_dir, scm, dvc, git_upstream, exp_stage):
git_upstream.tmp_dir.scm.set_ref(str(ref_info), remote_rev)

assert dvc.experiments.push(git_upstream.remote, [ref_info.name]) == {
"pushed": [],
"diverged": [ref_info.name],
"url": None,
"uploaded": 0,
}
assert git_upstream.tmp_dir.scm.get_ref(str(ref_info)) == remote_rev

Expand Down

0 comments on commit ea5c084

Please sign in to comment.