diff --git a/dvc/commands/experiments/remove.py b/dvc/commands/experiments/remove.py index ce827f4e88..c9dc4e0c6b 100644 --- a/dvc/commands/experiments/remove.py +++ b/dvc/commands/experiments/remove.py @@ -24,9 +24,11 @@ def check_arguments(self): ) def run(self): + from dvc.utils import humanize + self.check_arguments() - removed_list = self.repo.experiments.remove( + removed = self.repo.experiments.remove( exp_names=self.args.experiment, all_commits=self.args.all_commits, rev=self.args.rev, @@ -34,8 +36,10 @@ def run(self): queue=self.args.queue, git_remote=self.args.git_remote, ) - removed = ",".join(removed_list) - ui.write(f"Removed experiments: {removed}") + if removed: + ui.write(f"Removed experiments: {humanize.join(map(repr, removed))}") + else: + ui.write("No experiments to remove.") return 0 diff --git a/dvc/repo/experiments/push.py b/dvc/repo/experiments/push.py index d36cb7da56..12a7dd32c4 100644 --- a/dvc/repo/experiments/push.py +++ b/dvc/repo/experiments/push.py @@ -1,13 +1,14 @@ import logging from typing import Iterable, List, Mapping, Optional, Set, Union -from funcy import group_by +from funcy import compact, group_by from scmrepo.git.backend.base import SyncStatus from dvc.repo import locked from dvc.repo.scm_context import scm_context from dvc.scm import TqdmGit, iter_revs from dvc.ui import ui +from dvc.utils import env2bool from .exceptions import UnresolvedExpNamesError from .refs import ExpRefInfo @@ -16,12 +17,24 @@ logger = logging.getLogger(__name__) -STUDIO_URL = "https://studio.iterative.ai" +def notify_refs_to_studio(config, git_remote: str, **refs: List[str]): + token = config.get("studio_token") + refs = compact(refs) + if refs and (token or config["push_exp_to_studio"]) and not env2bool("DVC_TEST"): + from dvc.utils import studio + + studio_url = config.get("studio_url") + studio.notify_refs( + git_remote, + default_token=token, + studio_url=studio_url, + **refs, # type: ignore[arg-type] + ) @locked @scm_context -def push( # noqa: C901, PLR0912 +def push( # noqa: C901 repo, git_remote: str, exp_names: Union[Iterable[str], str], @@ -32,8 +45,6 @@ def push( # noqa: C901, PLR0912 push_cache: bool = False, **kwargs, ) -> Iterable[str]: - from dvc.utils import env2bool - exp_ref_set: Set["ExpRefInfo"] = set() if all_commits: exp_ref_set.update(exp_refs(repo.scm)) @@ -76,64 +87,11 @@ def push( # noqa: C901, PLR0912 _push_cache(repo, push_cache_ref, **kwargs) refs = push_result[SyncStatus.SUCCESS] - feature_config = repo.config["feature"] - - push_to_studio = ( - bool(feature_config.get("studio_token")) or feature_config["push_exp_to_studio"] - ) - if refs and push_to_studio and not env2bool("DVC_TEST"): - token, repo_url = get_studio_token_and_repo_url(feature_config) - if token and repo_url: - studio_url = feature_config.get("studio_url") - _notify_studio([str(ref) for ref in refs], repo_url, token, url=studio_url) + pushed_refs = [str(r) for r in refs] + notify_refs_to_studio(repo.config["feature"], git_remote, pushed=pushed_refs) return [ref.name for ref in refs] -def get_studio_token_and_repo_url(config): - import os - - from dvc_studio_client.post_live_metrics import get_studio_repo_url - - token = os.getenv("STUDIO_TOKEN") or config.get("studio_token") - if not token: - logger.debug("Studio token not found. Skipping push to Studio.") - repo_url = os.getenv("STUDIO_REPO_URL") or get_studio_repo_url() - if token and not repo_url: - logger.warning( - "Could not detect repository url. " - "Please set STUDIO_REPO_URL environment variable " - "to your remote git repository url. " - ) - return token, repo_url - - -def _notify_studio( - refs: List[str], - repo_url: str, - token: str, - url: Optional[str] = None, -): - if not refs: - return - - from urllib.parse import urljoin - - import requests - from requests.adapters import HTTPAdapter - - endpoint = urljoin(url or STUDIO_URL, "/webhook/dvc") - session = requests.Session() - session.mount(endpoint, HTTPAdapter(max_retries=3)) - - logger.debug("pushing experiments to Studio (%s)", url) - json = {"repo_url": repo_url, "client": "dvc", "refs": refs} - logger.trace("Sending %s to %s", json, endpoint) # type: ignore[attr-defined] - - headers = {"Authorization": f"token {token}"} - resp = session.post(endpoint, json=json, headers=headers, timeout=5) - resp.raise_for_status() - - def _push( repo, git_remote: str, diff --git a/dvc/repo/experiments/remove.py b/dvc/repo/experiments/remove.py index 8ba6925ecb..27675dc2d1 100644 --- a/dvc/repo/experiments/remove.py +++ b/dvc/repo/experiments/remove.py @@ -21,7 +21,7 @@ @locked @scm_context -def remove( # noqa: C901 +def remove( # noqa: C901, PLR0912 repo: "Repo", exp_names: Union[None, str, List[str]] = None, rev: Optional[str] = None, @@ -39,13 +39,6 @@ def remove( # noqa: C901 removed.extend(celery_queue.clear(queued=True)) assert isinstance(repo.scm, Git) - if all_commits: - removed.extend( - _remove_commited_exps( - repo.scm, list(exp_refs(repo.scm, git_remote)), git_remote - ) - ) - return removed exp_ref_list: List["ExpRefInfo"] = [] queue_entry_list: List["QueueEntry"] = [] @@ -70,6 +63,9 @@ def remove( # noqa: C901 exp_ref_dict = _resolve_exp_by_baseline(repo, rev, num, git_remote) removed.extend(exp_ref_dict.keys()) exp_ref_list.extend(exp_ref_dict.values()) + elif all_commits: + exp_ref_list.extend(exp_refs(repo.scm, git_remote)) + removed = [ref.name for ref in exp_ref_list] if exp_ref_list: _remove_commited_exps(repo.scm, exp_ref_list, git_remote) @@ -79,6 +75,11 @@ def remove( # noqa: C901 remove_tasks(celery_queue, queue_entry_list) + if git_remote: + from .push import notify_refs_to_studio + + removed_refs = [str(r) for r in exp_ref_list] + notify_refs_to_studio(repo.config["feature"], git_remote, removed=removed_refs) return removed diff --git a/dvc/utils/studio.py b/dvc/utils/studio.py new file mode 100644 index 0000000000..ce87f3b1df --- /dev/null +++ b/dvc/utils/studio.py @@ -0,0 +1,93 @@ +import logging +import os +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple +from urllib.parse import urljoin + +from dvc_studio_client.post_live_metrics import get_studio_repo_url +from funcy import compact +from requests import RequestException, Session +from requests.adapters import HTTPAdapter + +if TYPE_CHECKING: + from requests import Response + +logger = logging.getLogger(__name__) + +STUDIO_URL = "https://studio.iterative.ai" +STUDIO_REPO_URL = "STUDIO_REPO_URL" +STUDIO_TOKEN = "STUDIO_TOKEN" # noqa: S105 + + +def get_studio_token_and_repo_url( + default_token: Optional[str] = None, + repo_url_finder: Optional[Callable[[], Optional[str]]] = None, +) -> Tuple[Optional[str], Optional[str]]: + token = os.getenv(STUDIO_TOKEN) or default_token + url_finder = repo_url_finder or get_studio_repo_url + repo_url = os.getenv(STUDIO_REPO_URL) or url_finder() + return token, repo_url + + +def post( + endpoint: str, + token: str, + data: Dict[str, Any], + url: Optional[str] = STUDIO_URL, + max_retries: int = 3, + timeout: int = 5, +) -> "Response": + endpoint = urljoin(url or STUDIO_URL, endpoint) + session = Session() + session.mount(endpoint, HTTPAdapter(max_retries=max_retries)) + + logger.trace("Sending %s to %s", data, endpoint) # type: ignore[attr-defined] + + headers = {"Authorization": f"token {token}"} + resp = session.post(endpoint, json=data, headers=headers, timeout=timeout) + resp.raise_for_status() + return resp + + +def notify_refs( + git_remote: str, + default_token: Optional[str] = None, + studio_url: Optional[str] = None, + repo_url_finder: Optional[Callable[[], Optional[str]]] = None, + **refs: List[str], +) -> None: + # TODO: Should we use git_remote to associate with Studio project + # instead of using `git ls-remote` on fallback? + refs = compact(refs) + if not refs: + return + + assert git_remote + token, repo_url = get_studio_token_and_repo_url( + default_token=default_token, + repo_url_finder=repo_url_finder, + ) + if not token: + logger.debug("Studio token not found.") + return + + if not repo_url: + logger.warning( + "Could not detect repository url. " + "Please set %s environment variable " + "to your remote git repository url. ", + STUDIO_REPO_URL, + ) + return + + logger.debug( + "notifying Studio%s about updated experiments", + f" ({studio_url})" if studio_url else "", + ) + data = {"repo_url": repo_url, "client": "dvc", "refs": refs} + + try: + post("/webhook/dvc", token=token, data=data, url=studio_url) + except RequestException: + # TODO: handle expected failures and show appropriate message + # TODO: handle unexpected failures and show appropriate message + logger.debug("failed to notify Studio", exc_info=True) diff --git a/tests/unit/repo/experiments/test_push.py b/tests/unit/repo/experiments/test_push.py deleted file mode 100644 index f2ccfcd2dd..0000000000 --- a/tests/unit/repo/experiments/test_push.py +++ /dev/null @@ -1,29 +0,0 @@ -from urllib.parse import urljoin - -from requests import Response - -from dvc.repo.experiments.push import STUDIO_URL, _notify_studio - - -def test_notify_studio_for_exp_push(mocker): - valid_response = Response() - valid_response.status_code = 200 - mock_post = mocker.patch("requests.Session.post", return_value=valid_response) - - _notify_studio( - ["ref1", "ref2", "ref3"], - "git@github.com:iterative/dvc.git", - "TOKEN", - ) - - assert mock_post.called - assert mock_post.call_args == mocker.call( - urljoin(STUDIO_URL, "/webhook/dvc"), - json={ - "repo_url": "git@github.com:iterative/dvc.git", - "client": "dvc", - "refs": ["ref1", "ref2", "ref3"], - }, - headers={"Authorization": "token TOKEN"}, - timeout=5, - ) diff --git a/tests/unit/utils/test_studio.py b/tests/unit/utils/test_studio.py new file mode 100644 index 0000000000..df9598764f --- /dev/null +++ b/tests/unit/utils/test_studio.py @@ -0,0 +1,44 @@ +from urllib.parse import urljoin + +import pytest +from requests import Response + +from dvc.utils.studio import STUDIO_URL, notify_refs + + +@pytest.mark.parametrize( + "status_code", + [ + 200, # success + 401, # should not fail on client errors + 500, # should not fail even on server errors + ], +) +def test_notify_refs(mocker, status_code): + response = Response() + response.status_code = status_code + + mock_post = mocker.patch("requests.Session.post", return_value=response) + + notify_refs( + "origin", + "TOKEN", + repo_url_finder=lambda: "git@github.com:iterative/dvc.git", + pushed=["p1", "p2"], + removed=["r1", "r2"], + ) + + assert mock_post.called + assert mock_post.call_args == mocker.call( + urljoin(STUDIO_URL, "/webhook/dvc"), + json={ + "repo_url": "git@github.com:iterative/dvc.git", + "client": "dvc", + "refs": { + "pushed": ["p1", "p2"], + "removed": ["r1", "r2"], + }, + }, + headers={"Authorization": "token TOKEN"}, + timeout=5, + )