Skip to content

Commit

Permalink
exp share refactor - ping on exp:remove, do not fail on error (#9248)
Browse files Browse the repository at this point in the history
studio experiments share refactor

1) Pings Studio on exp remove
2) Does not fail on any error during request.post
3) UI improvement on exp remove
  • Loading branch information
skshetry authored and daavoo committed Mar 28, 2023
1 parent 7329ba5 commit 04f9762
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 100 deletions.
10 changes: 7 additions & 3 deletions dvc/commands/experiments/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,22 @@ def check_arguments(self):
)

def run(self):
from dvc.utils import humanize

self.check_arguments()

removed_list = self.repo.experiments.remove(
removed = self.repo.experiments.remove(
exp_names=self.args.experiment,
all_commits=self.args.all_commits,
rev=self.args.rev,
num=self.args.num,
queue=self.args.queue,
git_remote=self.args.git_remote,
)
removed = ",".join(removed_list)
ui.write(f"Removed experiments: {removed}")
if removed:
ui.write(f"Removed experiments: {humanize.join(map(repr, removed))}")
else:
ui.write("No experiments to remove.")

return 0

Expand Down
78 changes: 18 additions & 60 deletions dvc/repo/experiments/push.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import logging
from typing import Iterable, List, Mapping, Optional, Set, Union

from funcy import group_by
from funcy import compact, group_by
from scmrepo.git.backend.base import SyncStatus

from dvc.repo import locked
from dvc.repo.scm_context import scm_context
from dvc.scm import TqdmGit, iter_revs
from dvc.ui import ui
from dvc.utils import env2bool

from .exceptions import UnresolvedExpNamesError
from .refs import ExpRefInfo
Expand All @@ -16,12 +17,24 @@
logger = logging.getLogger(__name__)


STUDIO_URL = "https://studio.iterative.ai"
def notify_refs_to_studio(config, git_remote: str, **refs: List[str]):
token = config.get("studio_token")
refs = compact(refs)
if refs and (token or config["push_exp_to_studio"]) and not env2bool("DVC_TEST"):
from dvc.utils import studio

studio_url = config.get("studio_url")
studio.notify_refs(
git_remote,
default_token=token,
studio_url=studio_url,
**refs, # type: ignore[arg-type]
)


@locked
@scm_context
def push( # noqa: C901, PLR0912
def push( # noqa: C901
repo,
git_remote: str,
exp_names: Union[Iterable[str], str],
Expand All @@ -32,8 +45,6 @@ def push( # noqa: C901, PLR0912
push_cache: bool = False,
**kwargs,
) -> Iterable[str]:
from dvc.utils import env2bool

exp_ref_set: Set["ExpRefInfo"] = set()
if all_commits:
exp_ref_set.update(exp_refs(repo.scm))
Expand Down Expand Up @@ -76,64 +87,11 @@ def push( # noqa: C901, PLR0912
_push_cache(repo, push_cache_ref, **kwargs)

refs = push_result[SyncStatus.SUCCESS]
feature_config = repo.config["feature"]

push_to_studio = (
bool(feature_config.get("studio_token")) or feature_config["push_exp_to_studio"]
)
if refs and push_to_studio and not env2bool("DVC_TEST"):
token, repo_url = get_studio_token_and_repo_url(feature_config)
if token and repo_url:
studio_url = feature_config.get("studio_url")
_notify_studio([str(ref) for ref in refs], repo_url, token, url=studio_url)
pushed_refs = [str(r) for r in refs]
notify_refs_to_studio(repo.config["feature"], git_remote, pushed=pushed_refs)
return [ref.name for ref in refs]


def get_studio_token_and_repo_url(config):
import os

from dvc_studio_client.post_live_metrics import get_studio_repo_url

token = os.getenv("STUDIO_TOKEN") or config.get("studio_token")
if not token:
logger.debug("Studio token not found. Skipping push to Studio.")
repo_url = os.getenv("STUDIO_REPO_URL") or get_studio_repo_url()
if token and not repo_url:
logger.warning(
"Could not detect repository url. "
"Please set STUDIO_REPO_URL environment variable "
"to your remote git repository url. "
)
return token, repo_url


def _notify_studio(
refs: List[str],
repo_url: str,
token: str,
url: Optional[str] = None,
):
if not refs:
return

from urllib.parse import urljoin

import requests
from requests.adapters import HTTPAdapter

endpoint = urljoin(url or STUDIO_URL, "/webhook/dvc")
session = requests.Session()
session.mount(endpoint, HTTPAdapter(max_retries=3))

logger.debug("pushing experiments to Studio (%s)", url)
json = {"repo_url": repo_url, "client": "dvc", "refs": refs}
logger.trace("Sending %s to %s", json, endpoint) # type: ignore[attr-defined]

headers = {"Authorization": f"token {token}"}
resp = session.post(endpoint, json=json, headers=headers, timeout=5)
resp.raise_for_status()


def _push(
repo,
git_remote: str,
Expand Down
17 changes: 9 additions & 8 deletions dvc/repo/experiments/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

@locked
@scm_context
def remove( # noqa: C901
def remove( # noqa: C901, PLR0912
repo: "Repo",
exp_names: Union[None, str, List[str]] = None,
rev: Optional[str] = None,
Expand All @@ -39,13 +39,6 @@ def remove( # noqa: C901
removed.extend(celery_queue.clear(queued=True))

assert isinstance(repo.scm, Git)
if all_commits:
removed.extend(
_remove_commited_exps(
repo.scm, list(exp_refs(repo.scm, git_remote)), git_remote
)
)
return removed

exp_ref_list: List["ExpRefInfo"] = []
queue_entry_list: List["QueueEntry"] = []
Expand All @@ -70,6 +63,9 @@ def remove( # noqa: C901
exp_ref_dict = _resolve_exp_by_baseline(repo, rev, num, git_remote)
removed.extend(exp_ref_dict.keys())
exp_ref_list.extend(exp_ref_dict.values())
elif all_commits:
exp_ref_list.extend(exp_refs(repo.scm, git_remote))
removed = [ref.name for ref in exp_ref_list]

if exp_ref_list:
_remove_commited_exps(repo.scm, exp_ref_list, git_remote)
Expand All @@ -79,6 +75,11 @@ def remove( # noqa: C901

remove_tasks(celery_queue, queue_entry_list)

if git_remote:
from .push import notify_refs_to_studio

removed_refs = [str(r) for r in exp_ref_list]
notify_refs_to_studio(repo.config["feature"], git_remote, removed=removed_refs)
return removed


Expand Down
93 changes: 93 additions & 0 deletions dvc/utils/studio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import logging
import os
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
from urllib.parse import urljoin

from dvc_studio_client.post_live_metrics import get_studio_repo_url
from funcy import compact
from requests import RequestException, Session
from requests.adapters import HTTPAdapter

if TYPE_CHECKING:
from requests import Response

logger = logging.getLogger(__name__)

STUDIO_URL = "https://studio.iterative.ai"
STUDIO_REPO_URL = "STUDIO_REPO_URL"
STUDIO_TOKEN = "STUDIO_TOKEN" # noqa: S105


def get_studio_token_and_repo_url(
default_token: Optional[str] = None,
repo_url_finder: Optional[Callable[[], Optional[str]]] = None,
) -> Tuple[Optional[str], Optional[str]]:
token = os.getenv(STUDIO_TOKEN) or default_token
url_finder = repo_url_finder or get_studio_repo_url
repo_url = os.getenv(STUDIO_REPO_URL) or url_finder()
return token, repo_url


def post(
endpoint: str,
token: str,
data: Dict[str, Any],
url: Optional[str] = STUDIO_URL,
max_retries: int = 3,
timeout: int = 5,
) -> "Response":
endpoint = urljoin(url or STUDIO_URL, endpoint)
session = Session()
session.mount(endpoint, HTTPAdapter(max_retries=max_retries))

logger.trace("Sending %s to %s", data, endpoint) # type: ignore[attr-defined]

headers = {"Authorization": f"token {token}"}
resp = session.post(endpoint, json=data, headers=headers, timeout=timeout)
resp.raise_for_status()
return resp


def notify_refs(
git_remote: str,
default_token: Optional[str] = None,
studio_url: Optional[str] = None,
repo_url_finder: Optional[Callable[[], Optional[str]]] = None,
**refs: List[str],
) -> None:
# TODO: Should we use git_remote to associate with Studio project
# instead of using `git ls-remote` on fallback?
refs = compact(refs)
if not refs:
return

assert git_remote
token, repo_url = get_studio_token_and_repo_url(
default_token=default_token,
repo_url_finder=repo_url_finder,
)
if not token:
logger.debug("Studio token not found.")
return

if not repo_url:
logger.warning(
"Could not detect repository url. "
"Please set %s environment variable "
"to your remote git repository url. ",
STUDIO_REPO_URL,
)
return

logger.debug(
"notifying Studio%s about updated experiments",
f" ({studio_url})" if studio_url else "",
)
data = {"repo_url": repo_url, "client": "dvc", "refs": refs}

try:
post("/webhook/dvc", token=token, data=data, url=studio_url)
except RequestException:
# TODO: handle expected failures and show appropriate message
# TODO: handle unexpected failures and show appropriate message
logger.debug("failed to notify Studio", exc_info=True)
29 changes: 0 additions & 29 deletions tests/unit/repo/experiments/test_push.py

This file was deleted.

44 changes: 44 additions & 0 deletions tests/unit/utils/test_studio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from urllib.parse import urljoin

import pytest
from requests import Response

from dvc.utils.studio import STUDIO_URL, notify_refs


@pytest.mark.parametrize(
"status_code",
[
200, # success
401, # should not fail on client errors
500, # should not fail even on server errors
],
)
def test_notify_refs(mocker, status_code):
response = Response()
response.status_code = status_code

mock_post = mocker.patch("requests.Session.post", return_value=response)

notify_refs(
"origin",
"TOKEN",
repo_url_finder=lambda: "[email protected]:iterative/dvc.git",
pushed=["p1", "p2"],
removed=["r1", "r2"],
)

assert mock_post.called
assert mock_post.call_args == mocker.call(
urljoin(STUDIO_URL, "/webhook/dvc"),
json={
"repo_url": "[email protected]:iterative/dvc.git",
"client": "dvc",
"refs": {
"pushed": ["p1", "p2"],
"removed": ["r1", "r2"],
},
},
headers={"Authorization": "token TOKEN"},
timeout=5,
)

0 comments on commit 04f9762

Please sign in to comment.