Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
drewkim committed Oct 16, 2024
1 parent 682e31e commit 80d5ade
Show file tree
Hide file tree
Showing 12 changed files with 70 additions and 66 deletions.
2 changes: 1 addition & 1 deletion chromadb/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def get_max_batch_size(self) -> int:
pass

@abstractmethod
def resolve_tenant_and_databases(self) -> UserIdentity:
def get_user_identity(self) -> UserIdentity:
"""Resolve the tenant and databases for the client. Returns the default
values if can't be resolved.
Expand Down
2 changes: 1 addition & 1 deletion chromadb/api/async_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ async def get_max_batch_size(self) -> int:
pass

@abstractmethod
async def resolve_tenant_and_databases(self) -> UserIdentity:
async def get_user_identity(self) -> UserIdentity:
"""Resolve the tenant and databases for the client. Returns the default
values if can't be resolved.
Expand Down
15 changes: 7 additions & 8 deletions chromadb/api/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,13 @@ async def create(
# Get the root system component we want to interact with
self._server = self._system.instance(AsyncServerAPI)

user_identity = await self.resolve_tenant_and_databases()
user_identity = await self.get_user_identity()

maybe_tenant, maybe_database = SharedSystemClient.maybe_set_tenant_and_database(
settings.chroma_overwrite_singleton_tenant_database_access_from_auth,
tenant,
database,
user_identity.tenant,
user_identity.databases,
user_identity,
overwrite_singleton_tenant_database_access_from_auth=settings.chroma_overwrite_singleton_tenant_database_access_from_auth,
tenant=tenant,
database=database,
)
if maybe_tenant:
self.tenant = maybe_tenant
Expand Down Expand Up @@ -107,8 +106,8 @@ def from_system(
)

@override
async def resolve_tenant_and_databases(self) -> UserIdentity:
return await self._server.resolve_tenant_and_databases()
async def get_user_identity(self) -> UserIdentity:
return await self._server.get_user_identity()

@override
async def set_tenant(self, tenant: str, database: str = DEFAULT_DATABASE) -> None:
Expand Down
6 changes: 2 additions & 4 deletions chromadb/api/async_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,9 @@ async def get_tenant(self, name: str) -> Tenant:

return Tenant(name=resp_json["name"])

@trace_method(
"AsyncFastAPI.resolve_tenant_and_databases", OpenTelemetryGranularity.OPERATION
)
@trace_method("AsyncFastAPI.get_user_identity", OpenTelemetryGranularity.OPERATION)
@override
async def resolve_tenant_and_databases(self) -> UserIdentity:
async def get_user_identity(self) -> UserIdentity:
return UserIdentity(**(await self._make_request("get", "/auth/identity")))

@trace_method("AsyncFastAPI.list_collections", OpenTelemetryGranularity.OPERATION)
Expand Down
19 changes: 9 additions & 10 deletions chromadb/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,13 @@ def __init__(
# Get the root system component we want to interact with
self._server = self._system.instance(ServerAPI)

user_identity = self.resolve_tenant_and_databases()
user_identity = self.get_user_identity()

maybe_tenant, maybe_database = SharedSystemClient.maybe_set_tenant_and_database(
settings.chroma_overwrite_singleton_tenant_database_access_from_auth,
tenant,
database,
user_identity.tenant,
user_identity.databases,
user_identity,
overwrite_singleton_tenant_database_access_from_auth=settings.chroma_overwrite_singleton_tenant_database_access_from_auth,
tenant=tenant,
database=database,
)
if maybe_tenant:
self.tenant = maybe_tenant
Expand Down Expand Up @@ -96,18 +95,18 @@ def from_system(
# endregion

@override
def resolve_tenant_and_databases(self) -> UserIdentity:
def get_user_identity(self) -> UserIdentity:
try:
return self._server.resolve_tenant_and_databases()
return self._server.get_user_identity()
except httpx.ConnectError:
raise ValueError(
"Could not connect to a Chroma server. Are you sure it is running?"
)
# Propagate ChromaErrors
except ChromaError as e:
raise e
except Exception:
raise ValueError("Could not resolve tenant and database.")
except Exception as e:
raise ValueError(str(e))

# region BaseAPI Methods
# Note - we could do this in less verbose ways, but they break type checking
Expand Down
6 changes: 2 additions & 4 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,9 @@ def get_tenant(self, name: str) -> Tenant:
resp_json = self._make_request("get", "/tenants/" + name)
return Tenant(name=resp_json["name"])

@trace_method(
"FastAPI.resolve_tenant_and_databases", OpenTelemetryGranularity.OPERATION
)
@trace_method("FastAPI.get_user_identity", OpenTelemetryGranularity.OPERATION)
@override
def resolve_tenant_and_databases(self) -> UserIdentity:
def get_user_identity(self) -> UserIdentity:
return UserIdentity(**self._make_request("get", "/auth/identity"))

@trace_method("FastAPI.list_collections", OpenTelemetryGranularity.OPERATION)
Expand Down
2 changes: 1 addition & 1 deletion chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def create_tenant(self, name: str) -> None:
)

@override
def resolve_tenant_and_databases(self) -> UserIdentity:
def get_user_identity(self) -> UserIdentity:
return UserIdentity(
user_id="",
tenant=DEFAULT_TENANT,
Expand Down
16 changes: 8 additions & 8 deletions chromadb/api/shared_system_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import ClassVar, Dict, Optional, Tuple, List
from typing import ClassVar, Dict, Optional, Tuple
import uuid

from chromadb.auth import UserIdentity
from chromadb.api import ServerAPI
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT
from chromadb.config import Settings, System
Expand Down Expand Up @@ -95,9 +96,8 @@ def _submit_client_start_event(self) -> None:

@staticmethod
def _singleton_tenant_database_if_applicable(
user_identity: UserIdentity,
overwrite_singleton_tenant_database_access_from_auth: bool,
user_tenant: Optional[str],
user_databases: Optional[List[str]],
) -> Tuple[Optional[str], Optional[str]]:
"""
If settings.chroma_overwrite_singleton_tenant_database_access_from_auth
Expand All @@ -118,6 +118,8 @@ def _singleton_tenant_database_if_applicable(
return None, None
tenant = None
database = None
user_tenant = user_identity.tenant
user_databases = user_identity.databases
if user_tenant and user_tenant != "*":
tenant = user_tenant
if user_databases and len(user_databases) == 1 and user_databases[0] != "*":
Expand All @@ -126,19 +128,17 @@ def _singleton_tenant_database_if_applicable(

@staticmethod
def maybe_set_tenant_and_database(
user_identity: UserIdentity,
overwrite_singleton_tenant_database_access_from_auth: bool,
tenant: Optional[str] = None,
database: Optional[str] = None,
user_tenant: Optional[str] = None,
user_databases: Optional[List[str]] = None,
) -> Tuple[Optional[str], Optional[str]]:
(
new_tenant,
new_database,
) = SharedSystemClient._singleton_tenant_database_if_applicable(
overwrite_singleton_tenant_database_access_from_auth,
user_tenant,
user_databases,
user_identity=user_identity,
overwrite_singleton_tenant_database_access_from_auth=overwrite_singleton_tenant_database_access_from_auth,
)

if (not tenant or tenant == DEFAULT_TENANT) and new_tenant:
Expand Down
17 changes: 13 additions & 4 deletions chromadb/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import inspect
import logging
from abc import ABC
from enum import Enum
from graphlib import TopologicalSorter
from typing import Optional, List, Any, Dict, Set, Iterable, Union
from typing import Type, TypeVar, cast
Expand Down Expand Up @@ -91,6 +92,11 @@
DEFAULT_DATABASE = "default_database"


class APIVersion(str, Enum):
V1 = "/api/v1"
V2 = "/api/v2"


class Settings(BaseSettings): # type: ignore
# ==============
# Generic config
Expand Down Expand Up @@ -124,7 +130,7 @@ def empty_str_to_none(cls, v: str) -> Optional[str]:
chroma_server_ssl_enabled: Optional[bool] = False

chroma_server_ssl_verify: Optional[Union[bool, str]] = None
chroma_server_api_default_path: Optional[str] = "/api/v2"
chroma_server_api_default_path: Optional[str] = APIVersion.V2
# eg ["http://localhost:3000"]
chroma_server_cors_allow_origins: List[str] = []

Expand Down Expand Up @@ -162,9 +168,12 @@ def empty_str_to_none(cls, v: str) -> Optional[str]:
# ================

chroma_server_auth_ignore_paths: Dict[str, List[str]] = {
"/api/v2": ["GET"],
"/api/v2/heartbeat": ["GET"],
"/api/v2/version": ["GET"],
f"{APIVersion.V2}": ["GET"],
f"{APIVersion.V2}/heartbeat": ["GET"],
f"{APIVersion.V2}/version": ["GET"],
f"{APIVersion.V1}": ["GET"],
f"{APIVersion.V1}/heartbeat": ["GET"],
f"{APIVersion.V1}/version": ["GET"],
}
# Overwrite singleton tenant and database access from the auth provider
# if applicable. See chromadb/server/fastapi/__init__.py's
Expand Down
44 changes: 23 additions & 21 deletions chromadb/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,17 @@ def __init__(self, settings: Settings):

self.router = ChromaAPIRouter()

self.setup_v2_routes()
self.setup_v1_routes()

self._app.include_router(self.router)

use_route_names_as_operation_ids(self._app)
instrument_fastapi(self._app)
telemetry_client = self._system.instance(ProductTelemetryClient)
telemetry_client.capture(ServerStartEvent())

def setup_v2_routes(self) -> None:
self.router.add_api_route("/api/v2", self.root, methods=["GET"])
self.router.add_api_route("/api/v2/reset", self.reset, methods=["POST"])
self.router.add_api_route("/api/v2/version", self.version, methods=["GET"])
Expand All @@ -206,7 +217,7 @@ def __init__(self, settings: Settings):

self.router.add_api_route(
"/api/v2/auth/identity",
self.resolve_tenant_and_databases,
self.get_user_identity,
methods=["GET"],
response_model=None,
)
Expand Down Expand Up @@ -320,9 +331,10 @@ def __init__(self, settings: Settings):
response_model=None,
)

# ======================================================================
def setup_v1_routes(self) -> None:
# =====================================================================
# OLD ROUTES FOR BACKWARDS COMPATIBILITY — WILL BE REMOVED
# ======================================================================
# =====================================================================

self.router.add_api_route("/api/v1", self.root, methods=["GET"])
self.router.add_api_route("/api/v1/reset", self.reset, methods=["POST"])
Expand Down Expand Up @@ -441,14 +453,7 @@ def __init__(self, settings: Settings):
response_model=None,
)

# ======================================================================

self._app.include_router(self.router)

use_route_names_as_operation_ids(self._app)
instrument_fastapi(self._app)
telemetry_client = self._system.instance(ProductTelemetryClient)
telemetry_client.capture(ServerStartEvent())
# =====================================================================

def shutdown(self) -> None:
self._system.stop()
Expand Down Expand Up @@ -532,10 +537,8 @@ def auth_request(
)
return

@trace_method(
"FastAPI.resolve_tenant_and_databases", OpenTelemetryGranularity.OPERATION
)
async def resolve_tenant_and_databases(
@trace_method("FastAPI.get_user_identity", OpenTelemetryGranularity.OPERATION)
async def get_user_identity(
self,
request: Request,
) -> UserIdentity:
Expand All @@ -544,8 +547,7 @@ async def resolve_tenant_and_databases(
user_id="", tenant=DEFAULT_TENANT, databases=[DEFAULT_DATABASE]
)

identity = self.authn_provider.authenticate_or_raise(dict(request.headers))
return cast(UserIdentity, await to_thread.run_sync(lambda: identity))
return self.authn_provider.authenticate_or_raise(dict(request.headers))

@trace_method("FastAPI.create_database", OpenTelemetryGranularity.OPERATION)
async def create_database(
Expand Down Expand Up @@ -1156,9 +1158,9 @@ def process_pre_flight_checks() -> Dict[str, Any]:
),
)

# ==========================================================================
# OLD CONTROLLERS FOR BACKWARD COMPATIBILITY — WILL BE REMOVED
# ==========================================================================
# =========================================================================
# OLD V1 FUNCTIONS FOR BACKWARD COMPATIBILITY — WILL BE REMOVED
# =========================================================================

@trace_method(
"auth_and_get_tenant_and_database_for_request",
Expand Down Expand Up @@ -1841,4 +1843,4 @@ def process_query(request: Request, raw_body: bytes) -> QueryResult:

return nnresult

# ==========================================================================
# =========================================================================
1 change: 1 addition & 0 deletions chromadb/telemetry/product/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"chroma_api_impl",
"is_persistent",
"chroma_server_ssl_enabled",
"chroma_server_api_default_path",
]


Expand Down
6 changes: 2 additions & 4 deletions chromadb/test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,11 @@ def http_api_factory(
) -> Generator[HttpAPIFactory, None, None]:
if request.param == "sync_client":
with patch("chromadb.api.client.Client._validate_tenant_database"):
with patch("chromadb.api.client.Client.resolve_tenant_and_databases"):
with patch("chromadb.api.client.Client.get_user_identity"):
yield chromadb.HttpClient
else:
with patch("chromadb.api.async_client.AsyncClient._validate_tenant_database"):
with patch(
"chromadb.api.async_client.AsyncClient.resolve_tenant_and_databases"
):
with patch("chromadb.api.async_client.AsyncClient.get_user_identity"):

def factory(*args: Any, **kwargs: Any) -> Any:
cls = asyncio.get_event_loop().run_until_complete(
Expand Down

0 comments on commit 80d5ade

Please sign in to comment.