-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ENH] Add quota component and test for static (#1720)
## Description of changes *Summarize the changes made by this PR.* - New functionality - Add a quota check; it will be used to rate limit, apply static checks to payloads, etc. ## Test plan *How are these changes tested?* - [ ] Tests pass locally with `pytest`; unit tests added --------- Co-authored-by: nicolas <[email protected]>
- Loading branch information
1 parent
da68516
commit e6ceeee
Showing
7 changed files
with
198 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from abc import abstractmethod | ||
from enum import Enum | ||
from typing import Optional, Literal | ||
|
||
from chromadb import Documents, Embeddings | ||
from chromadb.api import Metadatas | ||
from chromadb.config import ( | ||
Component, | ||
System, | ||
) | ||
|
||
|
||
class Resource(Enum):
    """Kinds of resources that a quota can be attached to.

    Every member's value is the same string as its name.
    """

    # Length of a single metadata key.
    METADATA_KEY_LENGTH = "METADATA_KEY_LENGTH"
    # Length of a single string-typed metadata value.
    METADATA_VALUE_LENGTH = "METADATA_VALUE_LENGTH"
    # Length of a single document.
    DOCUMENT_SIZE = "DOCUMENT_SIZE"
    # Number of components in a single embedding.
    EMBEDDINGS_DIMENSION = "EMBEDDINGS_DIMENSION"
|
||
|
||
class QuotaError(Exception):
    """Raised when a value exceeds the quota configured for a resource.

    The violated resource, the configured quota and the offending size are
    kept as attributes so callers can inspect them without parsing the message.
    """

    def __init__(self, resource: Resource, quota: int, actual: int) -> None:
        """Build the error.

        Args:
            resource: The resource whose quota was exceeded.
            quota: The configured limit.
            actual: The observed size that exceeded ``quota``.
        """
        super().__init__(f"quota error. resource: {resource} quota: {quota} actual: {actual}")
        self.quota = quota
        self.actual = actual
        self.resource = resource
|
||
class QuotaProvider(Component):
    """
    Retrieves quotas for resources within a system.

    Methods:
        get_for_subject(resource, subject=None, tier=None):
            Returns the quota for a given resource, optionally considering the tier and subject.
    """

    def __init__(self, system: System) -> None:
        super().__init__(system)
        self.system = system

    @abstractmethod
    def get_for_subject(
        self,
        resource: Resource,
        subject: Optional[str] = None,
        tier: Optional[str] = None,
    ) -> Optional[int]:
        """Return the quota for ``resource``.

        Args:
            resource: The resource the quota applies to.
            subject: Optional identifier the quota is scoped to.
            tier: Optional tier the quota is scoped to.

        Returns:
            The quota as an integer, or ``None``.
        """
        pass
|
||
|
||
class QuotaEnforcer(Component):
    """
    Enforces quota restrictions on various resources using quota provider.

    Enforcement is active only when ``chroma_quota_provider_impl`` is set in the
    system settings; otherwise ``static_check`` is a no-op.

    Methods:
        static_check(metadatas=None, documents=None, embeddings=None, collection_id=None):
            Performs static checks against quotas for metadatas, documents, and embeddings.
            Raises QuotaError if limits are exceeded.
    """

    def __init__(self, system: System) -> None:
        super().__init__(system)
        self.should_enforce = False
        # A provider is only resolved when one is configured, so
        # ``_quota_provider`` exists only when ``should_enforce`` is True.
        if system.settings.chroma_quota_provider_impl:
            self._quota_provider = system.require(QuotaProvider)
            self.should_enforce = True
        self.system = system

    def static_check(
        self,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        embeddings: Optional[Embeddings] = None,
        collection_id: Optional[str] = None,
    ) -> None:
        """Validate the payload against the quotas returned by the provider.

        A falsy quota (``None`` or ``0``) disables the corresponding check.

        Args:
            metadatas: Metadata mappings; key lengths and string value lengths are checked.
            documents: Documents; each document's length is checked.
            embeddings: Embeddings; each embedding's length is checked.
            collection_id: Passed to the provider as the quota subject.

        Raises:
            QuotaError: On the first value found that exceeds its quota.
        """
        if not self.should_enforce:
            return

        metadata_key_length_quota = self._quota_provider.get_for_subject(
            resource=Resource.METADATA_KEY_LENGTH, subject=collection_id
        )
        metadata_value_length_quota = self._quota_provider.get_for_subject(
            resource=Resource.METADATA_VALUE_LENGTH, subject=collection_id
        )
        # Scan metadata when EITHER quota is set. The original tested the key
        # quota twice, which skipped value checks when only the value quota was set.
        if metadatas and (metadata_key_length_quota or metadata_value_length_quota):
            for metadata in metadatas:
                for key, value in metadata.items():
                    if metadata_key_length_quota and len(key) > metadata_key_length_quota:
                        raise QuotaError(
                            resource=Resource.METADATA_KEY_LENGTH,
                            actual=len(key),
                            quota=metadata_key_length_quota,
                        )
                    # Only string values have a length that is checked.
                    if (
                        metadata_value_length_quota
                        and isinstance(value, str)
                        and len(value) > metadata_value_length_quota
                    ):
                        raise QuotaError(
                            resource=Resource.METADATA_VALUE_LENGTH,
                            actual=len(value),
                            quota=metadata_value_length_quota,
                        )

        document_size_quota = self._quota_provider.get_for_subject(
            resource=Resource.DOCUMENT_SIZE, subject=collection_id
        )
        if document_size_quota and documents:
            for document in documents:
                if len(document) > document_size_quota:
                    raise QuotaError(
                        resource=Resource.DOCUMENT_SIZE,
                        actual=len(document),
                        quota=document_size_quota,
                    )

        embedding_dimension_quota = self._quota_provider.get_for_subject(
            resource=Resource.EMBEDDINGS_DIMENSION, subject=collection_id
        )
        if embedding_dimension_quota and embeddings:
            for embedding in embeddings:
                if len(embedding) > embedding_dimension_quota:
                    raise QuotaError(
                        resource=Resource.EMBEDDINGS_DIMENSION,
                        actual=len(embedding),
                        quota=embedding_dimension_quota,
                    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from typing import Optional | ||
|
||
from overrides import overrides | ||
|
||
from chromadb.quota import QuotaProvider, Resource | ||
|
||
|
||
class QuotaProviderForTest(QuotaProvider):
    """Minimal concrete ``QuotaProvider`` whose lookup always returns ``None``."""

    def __init__(self, system) -> None:
        super().__init__(system)

    @overrides
    def get_for_subject(
        self, resource: Resource, subject: Optional[str] = "", tier: Optional[str] = ""
    ) -> Optional[int]:
        """Return ``None`` for every resource, subject and tier."""
        return None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import random | ||
import string | ||
from typing import Optional, List, Tuple, Any | ||
from unittest.mock import patch | ||
|
||
from chromadb.config import System, Settings | ||
from chromadb.quota import QuotaEnforcer, Resource | ||
import pytest | ||
|
||
|
||
def generate_random_string(size: int) -> str:
    """Return a random string of ``size`` ASCII letters and digits."""
    return ''.join(random.choices(string.ascii_letters + string.digits, k=size))
|
||
def mock_get_for_subject(
    self, resource: Resource, subject: Optional[str] = "", tier: Optional[str] = ""
) -> Optional[int]:
    """Mock function to simulate quota retrieval.

    Returns a fixed quota of 10 regardless of resource, subject or tier.
    """
    return 10
|
||
|
||
def run_static_checks(enforcer: QuotaEnforcer, test_cases: List[Tuple[Any, Optional[str]]], data_key: str):
    """Generalized function to run static checks on different types of data.

    Args:
        enforcer: The enforcer whose ``static_check`` is exercised.
        test_cases: Tuples of ``(data, expected_error)``. ``expected_error`` is a
            substring expected in the string form of the raised error's
            ``resource``; a falsy value means no error is expected. One-element
            tuples are treated as "no error expected".
        data_key: Keyword argument of ``static_check`` under which each datum is
            passed, wrapped in a single-element list.
    """
    for test_case in test_cases:
        data, expected_error = test_case if len(test_case) == 2 else (test_case[0], None)
        args = {data_key: [data]}
        if expected_error:
            with pytest.raises(Exception) as exc_info:
                enforcer.static_check(**args)
            # Must sit after the ``with`` block: inside it, the assertion would
            # be skipped once ``static_check`` raises.
            assert expected_error in str(exc_info.value.resource)
        else:
            enforcer.static_check(**args)
|
||
|
||
|
||
@pytest.fixture(scope="module")
def enforcer() -> QuotaEnforcer:
    """Module-scoped ``QuotaEnforcer`` configured with the test quota provider."""
    settings = Settings(
        chroma_quota_provider_impl="chromadb.quota.test_provider.QuotaProviderForTest"
    )
    system = System(settings)
    return system.require(QuotaEnforcer)
|
||
@patch('chromadb.quota.test_provider.QuotaProviderForTest.get_for_subject', mock_get_for_subject)
def test_static_enforcer_metadata(enforcer):
    """Over-long metadata keys and values are rejected; short ones pass (quota mocked to 10)."""
    test_cases = [
        ({generate_random_string(20): generate_random_string(5)}, "METADATA_KEY_LENGTH"),
        ({generate_random_string(5): generate_random_string(5)}, None),
        ({generate_random_string(5): generate_random_string(20)}, "METADATA_VALUE_LENGTH"),
        ({generate_random_string(5): generate_random_string(5)}, None),
    ]
    run_static_checks(enforcer, test_cases, 'metadatas')
|
||
|
||
@patch('chromadb.quota.test_provider.QuotaProviderForTest.get_for_subject', mock_get_for_subject)
def test_static_enforcer_documents(enforcer):
    """A document longer than the mocked quota of 10 is rejected; a shorter one passes."""
    test_cases = [
        (generate_random_string(20), "DOCUMENT_SIZE"),
        (generate_random_string(5), None),
    ]
    run_static_checks(enforcer, test_cases, 'documents')
|
||
@patch('chromadb.quota.test_provider.QuotaProviderForTest.get_for_subject', mock_get_for_subject)
def test_static_enforcer_embeddings(enforcer):
    """An embedding with more than 10 components is rejected; a 5-component one passes."""
    test_cases = [
        (random.sample(range(1, 101), 100), "EMBEDDINGS_DIMENSION"),
        (random.sample(range(1, 101), 5), None),
    ]
    run_static_checks(enforcer, test_cases, 'embeddings')
|
||
# Should not raise an error if no quota provider is present
def test_enforcer_without_quota_provider():
    """With default settings (no provider configured), static checks never raise."""
    test_cases = [
        (random.sample(range(1, 101), 1), None),
        (random.sample(range(1, 101), 5), None),
    ]
    settings = Settings()
    system = System(settings)
    enforcer = system.require(QuotaEnforcer)
    run_static_checks(enforcer, test_cases, 'embeddings')