Skip to content

Commit

Permalink
Merge pull request #573 from DagsHub/ls-api-integrations
Browse files Browse the repository at this point in the history
LS API Integrations
  • Loading branch information
kbolashev authored Jan 7, 2025
2 parents 0f51a70 + 7fd5803 commit e3418d7
Show file tree
Hide file tree
Showing 8 changed files with 171 additions and 58 deletions.
2 changes: 2 additions & 0 deletions dagshub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .upload.wrapper import upload_files
from . import notebook
from .repo_bucket import get_repo_bucket_client
from .ls_client import get_label_studio_client
from . import storage

__all__ = [
Expand All @@ -13,5 +14,6 @@
upload_files.__name__,
notebook.save_notebook.__name__,
get_repo_bucket_client.__name__,
get_label_studio_client.__name__,
storage.__name__,
]
91 changes: 90 additions & 1 deletion dagshub/common/api/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from os import PathLike
from pathlib import Path, PurePosixPath
import rich.progress
import json

from dagshub.common.api.responses import (
RepoAPIResponse,
Expand All @@ -11,14 +12,19 @@
ContentAPIEntry,
StorageContentAPIResult,
)
from dagshub.data_engine.model.errors import LSInitializingError
from dagshub.common.download import download_files
from dagshub.common.rich_util import get_rich_progress
from dagshub.common.util import multi_urljoin
from functools import partial

from pydantic import BaseModel

from tenacity import retry, wait_fixed, stop_after_attempt, retry_if_exception_type

from functools import cached_property

from typing import Optional, Tuple, Any, List, Union
from typing import Optional, Tuple, Any, List, Union, Dict

import dacite

Expand Down Expand Up @@ -50,6 +56,11 @@ class PathNotFoundError(Exception):
pass


class LabelStudioProject(BaseModel):
    """
    Name and ID of a single Label Studio annotation project of a repository.
    """

    # Title of the project
    project_name: str
    # ID of the project, kept as a string
    project_id: str


class RepoAPI:
def __init__(self, repo: str, host: Optional[str] = None, auth: Optional[Any] = None):
"""
Expand All @@ -68,6 +79,15 @@ def __init__(self, repo: str, host: Optional[str] = None, auth: Optional[Any] =
else:
self.auth = auth

@retry(retry=retry_if_exception_type(LSInitializingError), wait=wait_fixed(3), stop=stop_after_attempt(5))
def _tenacious_ls_request(self, *args, **kwargs):
    """
    Send a request to the Label Studio API, retrying while Label Studio is still starting up.

    All arguments are forwarded unchanged to ``_http_request``.
    The call is retried only on LSInitializingError, with a fixed wait of 3 between attempts
    and at most 5 attempts in total.

    Raises:
        LSInitializingError: the response body starts with an HTML doctype instead of an API response.
        RuntimeError: the response status code is not 2xx.

    Returns:
        The response object of the successful request.
    """
    # Use the instance-level helper so the repo's auth is attached when the caller didn't pass one
    res = self._http_request(*args, **kwargs)
    # An HTML page instead of an API payload is treated as "not initialized yet" and triggers a retry
    if res.text.startswith("<!DOCTYPE html>"):
        raise LSInitializingError()
    elif res.status_code // 100 != 2:
        raise RuntimeError(f"Process failed! Server Response: {res.text}")
    return res

def _http_request(self, method, url, **kwargs):
if "auth" not in kwargs:
kwargs["auth"] = self.auth
Expand Down Expand Up @@ -142,6 +162,66 @@ def get_connected_storages(self) -> List[StorageAPIEntry]:

return [dacite.from_dict(StorageAPIEntry, storage_entry) for storage_entry in res.json()]

def list_annotation_projects(self) -> Dict[str, LabelStudioProject]:
    """
    List the annotation projects that belong to the repository.

    Returns:
        Mapping of project title to the matching ``LabelStudioProject`` entry.
    """
    projects_url = multi_urljoin(self.label_studio_api_url(), "projects")
    auth_headers = {"Authorization": f"Bearer {dagshub.auth.get_token()}"}
    response = self._tenacious_ls_request("GET", projects_url, headers=auth_headers)

    by_title: Dict[str, LabelStudioProject] = {}
    for entry in response.json()["results"]:
        title = entry["title"]
        # The ID is stored as a string
        by_title[title] = LabelStudioProject(project_name=entry["title"], project_id=str(entry["id"]))
    return by_title

def add_annotation_project(self, project_name: str, config: Optional[str] = None) -> None:
    """
    Create a new annotation project in the repository.

    Args:
        project_name: Title of the project to create.
        config: Labelling config of the project. Sent as ``label_config`` only when it is not None.

    Raises:
        ValueError: A project with this title already exists.
    """
    existing = self.list_annotation_projects()
    if project_name in existing:
        raise ValueError(f"{project_name} already exists!")

    # Build the request body; the config is left out entirely when none was provided
    payload = {"title": project_name}
    if config is not None:
        payload["label_config"] = config

    projects_url = multi_urljoin(self.label_studio_api_url(), "projects")
    request_headers = {"Authorization": f"Bearer {dagshub.auth.get_token()}", "Content-Type": "application/json"}
    self._tenacious_ls_request("POST", projects_url, headers=request_headers, data=json.dumps(payload))

def update_label_studio_project_config(self, project_name: str, config: str) -> None:
    """
    Replace the labelling config of an existing annotation project.

    Args:
        project_name: Title of the project to update.
        config: New labelling config, sent as ``label_config``.

    Raises:
        ValueError: No project with this title exists.
    """
    known = self.list_annotation_projects()
    if project_name not in known:
        raise ValueError(f"{project_name} doesn't exist!")

    # The project is addressed by its ID, looked up from the title
    project_url = multi_urljoin(self.label_studio_api_url(), "projects", known[project_name].project_id)
    request_headers = {"Authorization": f"Bearer {dagshub.auth.get_token()}", "Content-Type": "application/json"}
    body = json.dumps({"title": project_name, "label_config": config})
    self._tenacious_ls_request("PATCH", project_url, headers=request_headers, data=body)

def add_autolabelling_endpoint(self, project_name: str, endpoint: str) -> None:
    """
    Register an ML backend endpoint for auto-labelling on an annotation project.

    Args:
        project_name: Title of the project to attach the endpoint to.
        endpoint: URL of the endpoint serving the ML backend.

    Raises:
        ValueError: No project with this title exists.
    """
    known = self.list_annotation_projects()
    if project_name not in known:
        raise ValueError(f"{project_name} not in projects. Available project names: {list(known.keys())}")

    ml_url = multi_urljoin(self.label_studio_api_url(), "ml")
    request_headers = {"Authorization": f"Bearer {dagshub.auth.get_token()}", "Content-Type": "application/json"}
    # The backend is linked to the project through the project's ID
    body = json.dumps({"url": endpoint, "project": known[project_name].project_id})
    self._tenacious_ls_request("POST", ml_url, headers=request_headers, data=body)

def list_path(self, path: str, revision: Optional[str] = None, include_size: bool = False) -> List[ContentAPIEntry]:
"""
List contents of a repository directory
Expand Down Expand Up @@ -550,6 +630,15 @@ def storage_api_url(self) -> str:
"""
return multi_urljoin(self.repo_api_url, "storage")

def label_studio_api_url(self) -> str:
    """
    Base URL of the Label Studio API of this repository
    Format: https://dagshub.com/<repo-owner>/<repo-name>/annotations/de/api/
    :meta private:
    """
    api_path = "annotations/de/api/"
    return multi_urljoin(self.host, self.full_name, api_path)

def repo_bucket_api_url(self) -> str:
"""
Endpoint URL for getting access to the S3-compatible repo bucket
Expand Down
63 changes: 6 additions & 57 deletions dagshub/data_engine/model/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,12 @@
import threading
import time
import uuid
import requests
import webbrowser
from contextlib import contextmanager
from dataclasses import dataclass, field
from os import PathLike
from pathlib import Path
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union, Set, ContextManager, Tuple, Literal, Callable
from tenacity import retry, wait_fixed, stop_after_attempt, retry_if_exception_type


import rich.progress
Expand Down Expand Up @@ -45,7 +43,6 @@
DatasetFieldComparisonError,
FieldNotFoundError,
DatasetNotFoundError,
LSInitializingError,
)
from dagshub.data_engine.model.metadata import (
validate_uploading_metadata,
Expand Down Expand Up @@ -1109,7 +1106,6 @@ def _encode_query_for_frontend(self) -> str:
def fields(self) -> List[MetadataFieldSchema]:
return self.source.metadata_fields

@retry(retry=retry_if_exception_type(LSInitializingError), wait=wait_fixed(3), stop=stop_after_attempt(5))
async def add_annotation_model_from_config(self, config, project_name, ngrok_authtoken, port=9090):
"""
Initialize a LS backend for ML annotation using a preset configuration.
Expand All @@ -1121,38 +1117,15 @@ async def add_annotation_model_from_config(self, config, project_name, ngrok_aut
ngrok_authtoken: uses ngrok to forward local connection
port: (optional, default: 9090) port on which orchestrator is hosted
"""
ls_api_endpoint = multi_urljoin(self.source.repoApi.host, self.source.repo, "annotations/de/api/projects")

res = requests.get(
ls_api_endpoint,
headers={"Authorization": f"Bearer {dagshub.auth.get_token()}"},
)
if res.text.startswith("<!DOCTYPE html>"):
raise LSInitializingError()
elif res.status_code // 100 != 2:
raise ValueError(f"Adding backend failed! Response: {res.text}")
projects = {project["title"]: str(project["id"]) for project in res.json()["results"]}
projects = self.source.repoApi.list_annotation_projects()

if project_name not in projects:
res = requests.post(
ls_api_endpoint,
headers={"Authorization": f"Bearer {dagshub.auth.get_token()}", "Content-Type": "application/json"},
data=json.dumps({"title": project_name, "label_config": config.pop("label_config")}),
)
if res.status_code // 100 != 2:
raise ValueError(f"Adding backend failed! Response: {res.text}")
self.source.repoApi.add_annotation_project(project_name, config.pop("label_config"))
else:
res = requests.patch(
multi_urljoin(ls_api_endpoint, projects[project_name]),
headers={"Authorization": f"Bearer {dagshub.auth.get_token()}", "Content-Type": "application/json"},
data=json.dumps({"label_config": config.pop("label_config")}),
)
if res.status_code // 100 != 2:
raise ValueError(f"Adding backend failed! Response: {res.text}")
self.source.repoApi.update_label_studio_project_config(project_name, config.pop("label_config"))

await self.add_annotation_model(**config, port=port, project_name=project_name, ngrok_authtoken=ngrok_authtoken)

@retry(retry=retry_if_exception_type(LSInitializingError), wait=wait_fixed(3), stop=stop_after_attempt(5))
async def add_annotation_model(
self,
repo: str,
Expand Down Expand Up @@ -1186,7 +1159,8 @@ def fn_encoder(fn):
raise ValueError("As `ngrok_authtoken` is not specified, project will have to be added manually.")
with get_rich_progress() as progress:
task = progress.add_task("Initializing LS Model...", total=1)
res = requests.post(
res = http_request(
"POST",
f"{LS_ORCHESTRATOR_URL}:{port}/configure",
headers={"Content-Type": "application/json"},
json=json.dumps(
Expand Down Expand Up @@ -1217,32 +1191,7 @@ def fn_encoder(fn):
progress.update(task, advance=1, description="Configured any necessary forwarding")

if project_name:
ls_api_endpoint = multi_urljoin(self.source.repoApi.host, self.source.repo, "annotations/de/api")
res = requests.get(
multi_urljoin(ls_api_endpoint, "projects"),
headers={"Authorization": f"Bearer {dagshub.auth.get_token()}"},
)
if res.text.startswith("<!DOCTYPE html>"):
raise LSInitializingError()
elif res.status_code // 100 != 2:
raise ValueError(f"Adding backend failed! Response: {res.text}")
projects = {project["title"]: str(project["id"]) for project in res.json()["results"]}

if project_name not in projects:
raise ValueError(
f"{project_name} not in projects. Available project names: {list(projects.keys())}"
)

res = requests.post(
multi_urljoin(ls_api_endpoint, "ml"),
headers={"Authorization": f"Bearer {dagshub.auth.get_token()}", "Content-Type": "application/json"},
data=json.dumps({"url": endpoint, "project": projects[project_name]}),
)
if res.status_code // 100 == 2:
progress.update(task, advance=1, description="Added model to LS backend")
print("Backend added successfully!")
else:
raise ValueError(f"Adding backend failed! Response: {res.text}")
self.source.repoApi.add_autolabelling_endpoint(project_name, endpoint)
else:
progress.update(task, advance=1, description="Added model to LS backend")
print(f"Connection Established! Add LS endpoint: {endpoint} to your project.")
Expand Down
1 change: 1 addition & 0 deletions dagshub/data_engine/model/query_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ def predict_with_mlflow_model(
Args:
repo: repository to extract the model from
name: name of the model in the repository's MLflow registry.
host: address of the DagsHub instance with the repo to load the model from.
Set it if the model is hosted on a different DagsHub instance than the datasource.
Expand Down
65 changes: 65 additions & 0 deletions dagshub/ls_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from tenacity import retry, wait_fixed, stop_after_attempt, retry_if_exception_type
from json import JSONDecodeError
from typing import Optional
import importlib.util
import semver

from dagshub.common.api.repo import RepoAPI
from dagshub.auth import get_token
from dagshub.common import config


def _use_legacy_client():
    """
    Decide whether the legacy Label Studio client has to be used.

    https://github.com/HumanSignal/label-studio/releases/tag/1.13.0,
    https://github.com/HumanSignal/label-studio/pull/5961;
    introduces breaking changes; anyone using SDK < 1.0 should use the legacy client.
    :meta experimental:
    """
    import label_studio_sdk

    # compare() yields 1 when the first version is the greater one, i.e. the installed SDK is older than 1.0.0
    comparison = semver.compare("1.0.0", label_studio_sdk.__version__)
    return comparison == 1


@retry(retry=retry_if_exception_type(JSONDecodeError), wait=wait_fixed(3), stop=stop_after_attempt(5))
def get_label_studio_client(
    repo: str, legacy_client: Optional[bool] = None, host: Optional[str] = None, token: Optional[str] = None
):
    """
    Creates a Label Studio SDK client object to interact with the Label Studio instance
    associated with the repository.

    Depending on ``legacy_client`` the returned object is either a ``label_studio_sdk.Client``
    (https://labelstud.io/guide/sdk) or a ``label_studio_sdk.client.LabelStudio``
    (https://api.labelstud.io/api-reference/introduction/getting-started).

    Args:
        repo: Name of the repo in the format of ``username/reponame``
        legacy_client: if True, returns the older legacy LabelStudio Client.
            If None, it is decided from the installed ``label_studio_sdk`` version.
        host: URL of the hosted DagsHub instance. If not set, the configured host is used.
        token: (optional, default: None) uses programmatically specified token.
            If not provided, the token is obtained with ``get_token`` for the host.

    Returns:
        `label_studio_sdk.Client` / `label_studio_sdk.client.LabelStudio` object
    """
    sdk_installed = importlib.util.find_spec("label_studio_sdk") is not None
    if not sdk_installed:
        raise ModuleNotFoundError("Could not import module label_studio_sdk. Make sure to pip install label_studio_sdk")

    if legacy_client is None:
        legacy_client = _use_legacy_client()

    host = host or config.host

    # The two client flavours live in different places and name their URL argument differently
    if legacy_client:
        from label_studio_sdk import Client as LabelStudio

        url_argument = "url"
    else:
        from label_studio_sdk.client import LabelStudio

        url_argument = "base_url"

    repo_api = RepoAPI(repo, host=host)
    # Drop the last 4 characters of the API URL to get the address the client expects
    client_url = repo_api.label_studio_api_url()[:-4]
    api_key = get_token(host=host) if token is None else token

    return LabelStudio(**{url_argument: client_url, "api_key": api_key})
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ run ``pip install dagshub[jupyter]`` to install additional dependencies to enhan
reference/file_downloading
reference/file_uploading
reference/repo_bucket
reference/ls_client
reference/loading_models
reference/repo_api
reference/notebook
Expand Down
5 changes: 5 additions & 0 deletions docs/source/reference/ls_client.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
LS Client Documentation
=======================

.. automodule:: dagshub.ls_client
    :members:
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def get_version(rel_path: str) -> str:
"pathvalidate>=3.0.0",
"python-dateutil",
"boto3",
"semver",
"dagshub-annotation-converter>=0.1.0",
]

Expand Down

0 comments on commit e3418d7

Please sign in to comment.