Skip to content

Commit

Permalink
[ENH] Add quota component and test for static (#1720)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - New functionality
- Add quota check, it will be use to be able to rate limit, apply static
check to payload etc.


## Test plan
*How are these changes tested?*

- [ ] Tests pass locally with `pytest`, added unit test

---------

Co-authored-by: nicolas <[email protected]>
  • Loading branch information
nicolasgere and nicolas authored Feb 16, 2024
1 parent da68516 commit e6ceeee
Show file tree
Hide file tree
Showing 7 changed files with 198 additions and 2 deletions.
6 changes: 5 additions & 1 deletion chromadb/api/segment.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from chromadb.api import ServerAPI
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
from chromadb.db.system import SysDB
from chromadb.quota import QuotaEnforcer
from chromadb.segment import SegmentManager, MetadataReader, VectorReader
from chromadb.telemetry.opentelemetry import (
add_attributes_to_current_span,
Expand Down Expand Up @@ -58,7 +59,6 @@
import logging
import re


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -101,6 +101,7 @@ def __init__(self, system: System):
self._settings = system.settings
self._sysdb = self.require(SysDB)
self._manager = self.require(SegmentManager)
self._quota = self.require(QuotaEnforcer)
self._product_telemetry_client = self.require(ProductTelemetryClient)
self._opentelemetry_client = self.require(OpenTelemetryClient)
self._producer = self.require(Producer)
Expand Down Expand Up @@ -356,6 +357,7 @@ def _add(
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
self._quota.static_check(metadatas, documents, embeddings, collection_id)
coll = self._get_collection(collection_id)
self._manager.hint_use_collection(collection_id, t.Operation.ADD)
validate_batch(
Expand Down Expand Up @@ -398,6 +400,7 @@ def _update(
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
self._quota.static_check(metadatas, documents, embeddings, collection_id)
coll = self._get_collection(collection_id)
self._manager.hint_use_collection(collection_id, t.Operation.UPDATE)
validate_batch(
Expand Down Expand Up @@ -442,6 +445,7 @@ def _upsert(
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
self._quota.static_check(metadatas, documents, embeddings, collection_id)
coll = self._get_collection(collection_id)
self._manager.hint_use_collection(collection_id, t.Operation.UPSERT)
validate_batch(
Expand Down
3 changes: 3 additions & 0 deletions chromadb/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,13 @@
"chromadb.telemetry.product.ProductTelemetryClient": "chroma_product_telemetry_impl",
"chromadb.ingest.Producer": "chroma_producer_impl",
"chromadb.ingest.Consumer": "chroma_consumer_impl",
"chromadb.quota.QuotaProvider": "chroma_quota_provider_impl",
"chromadb.ingest.CollectionAssignmentPolicy": "chroma_collection_assignment_policy_impl", # noqa
"chromadb.db.system.SysDB": "chroma_sysdb_impl",
"chromadb.segment.SegmentManager": "chroma_segment_manager_impl",
"chromadb.segment.distributed.SegmentDirectory": "chroma_segment_directory_impl",
"chromadb.segment.distributed.MemberlistProvider": "chroma_memberlist_provider_impl",

}

DEFAULT_TENANT = "default_tenant"
Expand All @@ -99,6 +101,7 @@ class Settings(BaseSettings): # type: ignore
chroma_segment_manager_impl: str = (
"chromadb.segment.impl.manager.local.LocalSegmentManager"
)
chroma_quota_provider_impl:Optional[str] = None

# Distributed architecture specific components
chroma_segment_directory_impl: str = "chromadb.segment.impl.distributed.segment_directory.RendezvousHashSegmentDirectory"
Expand Down
90 changes: 90 additions & 0 deletions chromadb/quota/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from abc import abstractmethod
from enum import Enum
from typing import Optional, Literal

from chromadb import Documents, Embeddings
from chromadb.api import Metadatas
from chromadb.config import (
Component,
System,
)


class Resource(Enum):
METADATA_KEY_LENGTH = "METADATA_KEY_LENGTH"
METADATA_VALUE_LENGTH = "METADATA_VALUE_LENGTH"
DOCUMENT_SIZE = "DOCUMENT_SIZE"
EMBEDDINGS_DIMENSION = "EMBEDDINGS_DIMENSION"


class QuotaError(Exception):
def __init__(self, resource: Resource, quota: int, actual: int):
super().__init__(f"quota error. resource: {resource} quota: {quota} actual: {actual}")
self.quota = quota
self.actual = actual
self.resource = resource

class QuotaProvider(Component):
"""
Retrieves quotas for resources within a system.
Methods:
get_for_subject(resource, subject=None, tier=None):
Returns the quota for a given resource, optionally considering the tier and subject.
"""
def __init__(self, system: System) -> None:
super().__init__(system)
self.system = system

@abstractmethod
def get_for_subject(self, resource: Resource, subject: Optional[str] = None, tier: Optional[str] = None) -> \
Optional[int]:
pass


class QuotaEnforcer(Component):
"""
Enforces quota restrictions on various resources using quota provider.
Methods:
static_check(metadatas=None, documents=None, embeddings=None, collection_id=None):
Performs static checks against quotas for metadatas, documents, and embeddings. Raises QuotaError if limits are exceeded.
"""
def __init__(self, system: System) -> None:
super().__init__(system)
self.should_enforce = False
if system.settings.chroma_quota_provider_impl:
self._quota_provider = system.require(QuotaProvider)
self.should_enforce = True
self.system = system

def static_check(self, metadatas: Optional[Metadatas] = None, documents: Optional[Documents] = None,
embeddings: Optional[Embeddings] = None, collection_id: Optional[str] = None):
if not self.should_enforce:
return
metadata_key_length_quota = self._quota_provider.get_for_subject(resource=Resource.METADATA_KEY_LENGTH,
subject=collection_id)
metadata_value_length_quota = self._quota_provider.get_for_subject(resource=Resource.METADATA_VALUE_LENGTH,
subject=collection_id)
if metadatas and (metadata_key_length_quota or metadata_key_length_quota):
for metadata in metadatas:
for key in metadata:
if metadata_key_length_quota and len(key) > metadata_key_length_quota:
raise QuotaError(resource=Resource.METADATA_KEY_LENGTH, actual=len(key),
quota=metadata_key_length_quota)
if metadata_value_length_quota and isinstance(metadata[key], str) and len(
metadata[key]) > metadata_value_length_quota:
raise QuotaError(resource=Resource.METADATA_VALUE_LENGTH, actual=len(metadata[key]),
quota=metadata_value_length_quota)
document_size_quota = self._quota_provider.get_for_subject(resource=Resource.DOCUMENT_SIZE, subject=collection_id)
if document_size_quota and documents:
for document in documents:
if len(document) > document_size_quota:
raise QuotaError(resource=Resource.DOCUMENT_SIZE, actual=len(document), quota=document_size_quota)
embedding_dimension_quota = self._quota_provider.get_for_subject(resource=Resource.EMBEDDINGS_DIMENSION,
subject=collection_id)
if embedding_dimension_quota and embeddings:
for embedding in embeddings:
if len(embedding) > embedding_dimension_quota:
raise QuotaError(resource=Resource.EMBEDDINGS_DIMENSION, actual=len(embedding),
quota=embedding_dimension_quota)
14 changes: 14 additions & 0 deletions chromadb/quota/test_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import Optional

from overrides import overrides

from chromadb.quota import QuotaProvider, Resource


class QuotaProviderForTest(QuotaProvider):
def __init__(self, system) -> None:
super().__init__(system)

@overrides
def get_for_subject(self, resource: Resource, subject: Optional[str] = "", tier: Optional[str] = "") -> Optional[int]:
pass
8 changes: 8 additions & 0 deletions chromadb/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
InvalidDimensionException,
InvalidHTTPVersion,
)
from chromadb.quota import QuotaError
from chromadb.server.fastapi.types import (
AddEmbedding,
CreateDatabase,
Expand Down Expand Up @@ -140,6 +141,7 @@ def __init__(self, settings: Settings):
allow_origins=settings.chroma_server_cors_allow_origins,
allow_methods=["*"],
)
self._app.add_exception_handler(QuotaError, self.quota_exception_handler)

self._app.on_event("shutdown")(self.shutdown)

Expand Down Expand Up @@ -291,6 +293,12 @@ def app(self) -> fastapi.FastAPI:
def root(self) -> Dict[str, int]:
return {"nanosecond heartbeat": self._api.heartbeat()}

async def quota_exception_handler(request: Request, exc: QuotaError):
return JSONResponse(
status_code=429,
content={"message": f"quota error. resource: {exc.resource} quota: {exc.quota} actual: {exc.actual}"},
)

def heartbeat(self) -> Dict[str, int]:
return self.root()

Expand Down
1 change: 0 additions & 1 deletion chromadb/test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,6 @@ def system_wrong_auth(
def system(request: pytest.FixtureRequest) -> Generator[ServerAPI, None, None]:
yield next(request.param())


@pytest.fixture(scope="module", params=system_fixtures_ssl())
def system_ssl(request: pytest.FixtureRequest) -> Generator[ServerAPI, None, None]:
yield next(request.param())
Expand Down
78 changes: 78 additions & 0 deletions chromadb/test/quota/test_static_quota_enforcer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import random
import string
from typing import Optional, List, Tuple, Any
from unittest.mock import patch

from chromadb.config import System, Settings
from chromadb.quota import QuotaEnforcer, Resource
import pytest


def generate_random_string(size: int) -> str:
return ''.join(random.choices(string.ascii_letters + string.digits, k=size))

def mock_get_for_subject(self, resource: Resource, subject: Optional[str] = "", tier: Optional[str] = "") -> Optional[
int]:
"""Mock function to simulate quota retrieval."""
return 10


def run_static_checks(enforcer: QuotaEnforcer, test_cases: List[Tuple[Any, Optional[str]]], data_key: str):
"""Generalized function to run static checks on different types of data."""
for test_case in test_cases:
data, expected_error = test_case if len(test_case) == 2 else (test_case[0], None)
args = {data_key: [data]}
if expected_error:
with pytest.raises(Exception) as exc_info:
enforcer.static_check(**args)
assert expected_error in str(exc_info.value.resource)
else:
enforcer.static_check(**args)



@pytest.fixture(scope="module")
def enforcer() -> QuotaEnforcer:
settings = Settings(
chroma_quota_provider_impl = "chromadb.quota.test_provider.QuotaProviderForTest"
)
system = System(settings)
return system.require(QuotaEnforcer)

@patch('chromadb.quota.test_provider.QuotaProviderForTest.get_for_subject', mock_get_for_subject)
def test_static_enforcer_metadata(enforcer):
test_cases = [
({generate_random_string(20): generate_random_string(5)}, "METADATA_KEY_LENGTH"),
({generate_random_string(5): generate_random_string(5)}, None),
({generate_random_string(5): generate_random_string(20)}, "METADATA_VALUE_LENGTH"),
({generate_random_string(5): generate_random_string(5)}, None)
]
run_static_checks(enforcer, test_cases, 'metadatas')


@patch('chromadb.quota.test_provider.QuotaProviderForTest.get_for_subject', mock_get_for_subject)
def test_static_enforcer_documents(enforcer):
test_cases = [
(generate_random_string(20), "DOCUMENT_SIZE"),
(generate_random_string(5), None)
]
run_static_checks(enforcer, test_cases, 'documents')

@patch('chromadb.quota.test_provider.QuotaProviderForTest.get_for_subject', mock_get_for_subject)
def test_static_enforcer_embeddings(enforcer):
test_cases = [
(random.sample(range(1, 101), 100), "EMBEDDINGS_DIMENSION"),
(random.sample(range(1, 101), 5), None)
]
run_static_checks(enforcer, test_cases, 'embeddings')

# Should not raise an error if no quota provider is present
def test_enforcer_without_quota_provider():
test_cases = [
(random.sample(range(1, 101), 1), None),
(random.sample(range(1, 101), 5), None)
]
settings = Settings()
system = System(settings)
enforcer = system.require(QuotaEnforcer)
run_static_checks(enforcer, test_cases, 'embeddings')

0 comments on commit e6ceeee

Please sign in to comment.