Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

exp share refactor - ping on exp:remove, do not fail on error #9248

Merged
merged 1 commit into from
Mar 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions dvc/commands/experiments/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,22 @@ def check_arguments(self):
)

def run(self):
from dvc.utils import humanize

self.check_arguments()

removed_list = self.repo.experiments.remove(
removed = self.repo.experiments.remove(
exp_names=self.args.experiment,
all_commits=self.args.all_commits,
rev=self.args.rev,
num=self.args.num,
queue=self.args.queue,
git_remote=self.args.git_remote,
)
removed = ",".join(removed_list)
ui.write(f"Removed experiments: {removed}")
if removed:
ui.write(f"Removed experiments: {humanize.join(map(repr, removed))}")
else:
ui.write("No experiments to remove.")

return 0

Expand Down
78 changes: 18 additions & 60 deletions dvc/repo/experiments/push.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import logging
from typing import Iterable, List, Mapping, Optional, Set, Union

from funcy import group_by
from funcy import compact, group_by
from scmrepo.git.backend.base import SyncStatus

from dvc.repo import locked
from dvc.repo.scm_context import scm_context
from dvc.scm import TqdmGit, iter_revs
from dvc.ui import ui
from dvc.utils import env2bool

from .exceptions import UnresolvedExpNamesError
from .refs import ExpRefInfo
Expand All @@ -16,12 +17,24 @@
logger = logging.getLogger(__name__)


STUDIO_URL = "https://studio.iterative.ai"
def notify_refs_to_studio(config, git_remote: str, **refs: List[str]):
token = config.get("studio_token")
refs = compact(refs)
if refs and (token or config["push_exp_to_studio"]) and not env2bool("DVC_TEST"):
from dvc.utils import studio

studio_url = config.get("studio_url")
studio.notify_refs(
git_remote,
default_token=token,
studio_url=studio_url,
**refs, # type: ignore[arg-type]
)


@locked
@scm_context
def push( # noqa: C901, PLR0912
def push( # noqa: C901
repo,
git_remote: str,
exp_names: Union[Iterable[str], str],
Expand All @@ -32,8 +45,6 @@ def push( # noqa: C901, PLR0912
push_cache: bool = False,
**kwargs,
) -> Iterable[str]:
from dvc.utils import env2bool

exp_ref_set: Set["ExpRefInfo"] = set()
if all_commits:
exp_ref_set.update(exp_refs(repo.scm))
Expand Down Expand Up @@ -76,64 +87,11 @@ def push( # noqa: C901, PLR0912
_push_cache(repo, push_cache_ref, **kwargs)

refs = push_result[SyncStatus.SUCCESS]
feature_config = repo.config["feature"]

push_to_studio = (
bool(feature_config.get("studio_token")) or feature_config["push_exp_to_studio"]
)
if refs and push_to_studio and not env2bool("DVC_TEST"):
token, repo_url = get_studio_token_and_repo_url(feature_config)
if token and repo_url:
studio_url = feature_config.get("studio_url")
_notify_studio([str(ref) for ref in refs], repo_url, token, url=studio_url)
pushed_refs = [str(r) for r in refs]
notify_refs_to_studio(repo.config["feature"], git_remote, pushed=pushed_refs)
return [ref.name for ref in refs]


def get_studio_token_and_repo_url(config):
import os

from dvc_studio_client.post_live_metrics import get_studio_repo_url

token = os.getenv("STUDIO_TOKEN") or config.get("studio_token")
if not token:
logger.debug("Studio token not found. Skipping push to Studio.")
repo_url = os.getenv("STUDIO_REPO_URL") or get_studio_repo_url()
if token and not repo_url:
logger.warning(
"Could not detect repository url. "
"Please set STUDIO_REPO_URL environment variable "
"to your remote git repository url. "
)
return token, repo_url


def _notify_studio(
refs: List[str],
repo_url: str,
token: str,
url: Optional[str] = None,
):
if not refs:
return

from urllib.parse import urljoin

import requests
from requests.adapters import HTTPAdapter

endpoint = urljoin(url or STUDIO_URL, "/webhook/dvc")
session = requests.Session()
session.mount(endpoint, HTTPAdapter(max_retries=3))

logger.debug("pushing experiments to Studio (%s)", url)
json = {"repo_url": repo_url, "client": "dvc", "refs": refs}
logger.trace("Sending %s to %s", json, endpoint) # type: ignore[attr-defined]

headers = {"Authorization": f"token {token}"}
resp = session.post(endpoint, json=json, headers=headers, timeout=5)
resp.raise_for_status()


def _push(
repo,
git_remote: str,
Expand Down
17 changes: 9 additions & 8 deletions dvc/repo/experiments/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

@locked
@scm_context
def remove( # noqa: C901
def remove( # noqa: C901, PLR0912
repo: "Repo",
exp_names: Union[None, str, List[str]] = None,
rev: Optional[str] = None,
Expand All @@ -39,13 +39,6 @@ def remove( # noqa: C901
removed.extend(celery_queue.clear(queued=True))

assert isinstance(repo.scm, Git)
if all_commits:
removed.extend(
_remove_commited_exps(
repo.scm, list(exp_refs(repo.scm, git_remote)), git_remote
)
)
return removed

exp_ref_list: List["ExpRefInfo"] = []
queue_entry_list: List["QueueEntry"] = []
Expand All @@ -70,6 +63,9 @@ def remove( # noqa: C901
exp_ref_dict = _resolve_exp_by_baseline(repo, rev, num, git_remote)
removed.extend(exp_ref_dict.keys())
exp_ref_list.extend(exp_ref_dict.values())
elif all_commits:
exp_ref_list.extend(exp_refs(repo.scm, git_remote))
removed = [ref.name for ref in exp_ref_list]

if exp_ref_list:
_remove_commited_exps(repo.scm, exp_ref_list, git_remote)
Expand All @@ -79,6 +75,11 @@ def remove( # noqa: C901

remove_tasks(celery_queue, queue_entry_list)

if git_remote:
from .push import notify_refs_to_studio

removed_refs = [str(r) for r in exp_ref_list]
notify_refs_to_studio(repo.config["feature"], git_remote, removed=removed_refs)
return removed


Expand Down
93 changes: 93 additions & 0 deletions dvc/utils/studio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import logging
import os
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
from urllib.parse import urljoin

from dvc_studio_client.post_live_metrics import get_studio_repo_url
from funcy import compact
from requests import RequestException, Session
from requests.adapters import HTTPAdapter

if TYPE_CHECKING:
from requests import Response

logger = logging.getLogger(__name__)

STUDIO_URL = "https://studio.iterative.ai"
STUDIO_REPO_URL = "STUDIO_REPO_URL"
STUDIO_TOKEN = "STUDIO_TOKEN" # noqa: S105


def get_studio_token_and_repo_url(
default_token: Optional[str] = None,
repo_url_finder: Optional[Callable[[], Optional[str]]] = None,
) -> Tuple[Optional[str], Optional[str]]:
token = os.getenv(STUDIO_TOKEN) or default_token
url_finder = repo_url_finder or get_studio_repo_url
repo_url = os.getenv(STUDIO_REPO_URL) or url_finder()
return token, repo_url


def post(
endpoint: str,
token: str,
data: Dict[str, Any],
url: Optional[str] = STUDIO_URL,
max_retries: int = 3,
timeout: int = 5,
) -> "Response":
endpoint = urljoin(url or STUDIO_URL, endpoint)
session = Session()
session.mount(endpoint, HTTPAdapter(max_retries=max_retries))

logger.trace("Sending %s to %s", data, endpoint) # type: ignore[attr-defined]

headers = {"Authorization": f"token {token}"}
resp = session.post(endpoint, json=data, headers=headers, timeout=timeout)
resp.raise_for_status()
return resp


def notify_refs(
git_remote: str,
default_token: Optional[str] = None,
studio_url: Optional[str] = None,
repo_url_finder: Optional[Callable[[], Optional[str]]] = None,
**refs: List[str],
) -> None:
# TODO: Should we use git_remote to associate with Studio project
# instead of using `git ls-remote` on fallback?
Comment on lines +58 to +59
Copy link
Member Author

@skshetry skshetry Mar 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dberenbaum, a bit of an edge case here.
What project/repository should we associate on dvc exp push origin? The origin remote url, or where the upstream remote was set to (during git push —set-upstream)?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it's an explicit remote, I think it should be the origin remote url. If the user wants to use a different remote, then they shouldn't specify origin, right?

WRT using the upstream remote, there's some very old discussion in #6332 (comment) and #6427. I agree it would be nice to have some way to use the upstream remote, but I think it's not that simple since we aren't pushing a branch and it's not a blocker.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I meant it in terms of a Studio project. We send remote url (aka repo_url) to Studio.
At the moment, we are sending —set-upstream url instead of the remote that is passed as argument on exp-push or exp-remove.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking about it more, it does not make sense to pass a different repo_url if you have pushed to a separate remote. I am changing it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, makes sense, thanks @skshetry!

refs = compact(refs)
if not refs:
return

assert git_remote
token, repo_url = get_studio_token_and_repo_url(
default_token=default_token,
repo_url_finder=repo_url_finder,
)
if not token:
logger.debug("Studio token not found.")
return

if not repo_url:
logger.warning(
"Could not detect repository url. "
"Please set %s environment variable "
"to your remote git repository url. ",
STUDIO_REPO_URL,
)
return

logger.debug(
"notifying Studio%s about updated experiments",
f" ({studio_url})" if studio_url else "",
)
data = {"repo_url": repo_url, "client": "dvc", "refs": refs}

try:
post("/webhook/dvc", token=token, data=data, url=studio_url)
except RequestException:
# TODO: handle expected failures and show appropriate message
# TODO: handle unexpected failures and show appropriate message
logger.debug("failed to notify Studio", exc_info=True)
29 changes: 0 additions & 29 deletions tests/unit/repo/experiments/test_push.py

This file was deleted.

44 changes: 44 additions & 0 deletions tests/unit/utils/test_studio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from urllib.parse import urljoin

import pytest
from requests import Response

from dvc.utils.studio import STUDIO_URL, notify_refs


@pytest.mark.parametrize(
"status_code",
[
200, # success
401, # should not fail on client errors
500, # should not fail even on server errors
],
)
def test_notify_refs(mocker, status_code):
response = Response()
response.status_code = status_code

mock_post = mocker.patch("requests.Session.post", return_value=response)

notify_refs(
"origin",
"TOKEN",
repo_url_finder=lambda: "[email protected]:iterative/dvc.git",
pushed=["p1", "p2"],
removed=["r1", "r2"],
)

assert mock_post.called
assert mock_post.call_args == mocker.call(
urljoin(STUDIO_URL, "/webhook/dvc"),
json={
"repo_url": "[email protected]:iterative/dvc.git",
"client": "dvc",
"refs": {
"pushed": ["p1", "p2"],
"removed": ["r1", "r2"],
},
},
headers={"Authorization": "token TOKEN"},
timeout=5,
)