Skip to content

Commit

Permalink
Merge branch 'main' into feature/hnsw-max-seq-id-persistence
Browse files Browse the repository at this point in the history
  • Loading branch information
tazarov authored Mar 27, 2024
2 parents 38e686d + 1ce93c7 commit 3d62b49
Show file tree
Hide file tree
Showing 100 changed files with 2,024 additions and 3,310 deletions.
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -128,4 +128,4 @@
"unordered_set": "cpp",
"algorithm": "cpp"
},
}
}
27 changes: 11 additions & 16 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
)

import chromadb.types as t

from typing import Any, Optional, Sequence, Generator, List, cast, Set, Dict
from overrides import override
from uuid import UUID, uuid4
Expand Down Expand Up @@ -123,10 +122,12 @@ def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
name=name,
tenant=tenant,
)

@trace_method("SegmentAPI.get_database", OpenTelemetryGranularity.OPERATION)
@override
def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> t.Database:
return self._sysdb.get_database(name=name, tenant=tenant)

@trace_method("SegmentAPI.create_tenant", OpenTelemetryGranularity.OPERATION)
@override
def create_tenant(self, name: str) -> None:
Expand All @@ -136,6 +137,7 @@ def create_tenant(self, name: str) -> None:
self._sysdb.create_tenant(
name=name,
)

@trace_method("SegmentAPI.get_tenant", OpenTelemetryGranularity.OPERATION)
@override
def get_tenant(self, name: str) -> t.Tenant:
Expand Down Expand Up @@ -374,15 +376,14 @@ def _add(
for r in _records(
t.Operation.ADD,
ids=ids,
collection_id=collection_id,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
uris=uris,
):
self._validate_embedding_record(coll, r)
records_to_submit.append(r)
self._producer.submit_embeddings(coll["topic"], records_to_submit)
self._producer.submit_embeddings(collection_id, records_to_submit)

self._product_telemetry_client.capture(
CollectionAddEvent(
Expand Down Expand Up @@ -417,15 +418,14 @@ def _update(
for r in _records(
t.Operation.UPDATE,
ids=ids,
collection_id=collection_id,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
uris=uris,
):
self._validate_embedding_record(coll, r)
records_to_submit.append(r)
self._producer.submit_embeddings(coll["topic"], records_to_submit)
self._producer.submit_embeddings(collection_id, records_to_submit)

self._product_telemetry_client.capture(
CollectionUpdateEvent(
Expand Down Expand Up @@ -462,15 +462,14 @@ def _upsert(
for r in _records(
t.Operation.UPSERT,
ids=ids,
collection_id=collection_id,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
uris=uris,
):
self._validate_embedding_record(coll, r)
records_to_submit.append(r)
self._producer.submit_embeddings(coll["topic"], records_to_submit)
self._producer.submit_embeddings(collection_id, records_to_submit)

return True

Expand Down Expand Up @@ -632,12 +631,10 @@ def _delete(
return []

records_to_submit = []
for r in _records(
operation=t.Operation.DELETE, ids=ids_to_delete, collection_id=collection_id
):
for r in _records(operation=t.Operation.DELETE, ids=ids_to_delete):
self._validate_embedding_record(coll, r)
records_to_submit.append(r)
self._producer.submit_embeddings(coll["topic"], records_to_submit)
self._producer.submit_embeddings(collection_id, records_to_submit)

self._product_telemetry_client.capture(
CollectionDeleteEvent(
Expand Down Expand Up @@ -803,7 +800,7 @@ def max_batch_size(self) -> int:
# used for channel assignment in the distributed version of the system.
@trace_method("SegmentAPI._validate_embedding_record", OpenTelemetryGranularity.ALL)
def _validate_embedding_record(
self, collection: t.Collection, record: t.SubmitEmbeddingRecord
self, collection: t.Collection, record: t.OperationRecord
) -> None:
"""Validate the dimension of an embedding record before submitting it to the system."""
add_attributes_to_current_span({"collection_id": str(collection["id"])})
Expand Down Expand Up @@ -845,12 +842,11 @@ def _get_collection(self, collection_id: UUID) -> t.Collection:
def _records(
operation: t.Operation,
ids: IDs,
collection_id: UUID,
embeddings: Optional[Embeddings] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> Generator[t.SubmitEmbeddingRecord, None, None]:
) -> Generator[t.OperationRecord, None, None]:
"""Convert parallel lists of embeddings, metadatas and documents to a sequence of
SubmitEmbeddingRecords"""

Expand All @@ -877,13 +873,12 @@ def _records(
else:
metadata = {"chroma:uri": uri}

record = t.SubmitEmbeddingRecord(
record = t.OperationRecord(
id=id,
embedding=embeddings[i] if embeddings else None,
encoding=t.ScalarEncoding.FLOAT32, # Hardcode for now
metadata=metadata,
operation=operation,
collection_id=collection_id,
)
yield record

Expand Down
13 changes: 4 additions & 9 deletions chromadb/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@
"chromadb.ingest.Producer": "chroma_producer_impl",
"chromadb.ingest.Consumer": "chroma_consumer_impl",
"chromadb.quota.QuotaProvider": "chroma_quota_provider_impl",
"chromadb.ingest.CollectionAssignmentPolicy": "chroma_collection_assignment_policy_impl", # noqa
"chromadb.db.system.SysDB": "chroma_sysdb_impl",
"chromadb.segment.SegmentManager": "chroma_segment_manager_impl",
"chromadb.segment.distributed.SegmentDirectory": "chroma_segment_directory_impl",
Expand All @@ -86,9 +85,12 @@
class Settings(BaseSettings): # type: ignore
environment: str = ""

# Legacy config has to be kept around because pydantic will error
# Legacy config that has to be kept around because pydantic will error
# on nonexisting keys
chroma_db_impl: Optional[str] = None
chroma_collection_assignment_policy_impl: str = (
"chromadb.ingest.impl.simple_policy.SimpleAssignmentPolicy"
)
# Can be "chromadb.api.segment.SegmentAPI" or "chromadb.api.fastapi.FastAPI"
chroma_api_impl: str = "chromadb.api.segment.SegmentAPI"
chroma_product_telemetry_impl: str = "chromadb.telemetry.product.posthog.Posthog"
Expand All @@ -109,9 +111,6 @@ class Settings(BaseSettings): # type: ignore
# Distributed architecture specific components
chroma_segment_directory_impl: str = "chromadb.segment.impl.distributed.segment_directory.RendezvousHashSegmentDirectory"
chroma_memberlist_provider_impl: str = "chromadb.segment.impl.distributed.segment_directory.CustomResourceMemberlistProvider"
chroma_collection_assignment_policy_impl: str = (
"chromadb.ingest.impl.simple_policy.SimpleAssignmentPolicy"
)
worker_memberlist_name: str = "query-service-memberlist"
chroma_coordinator_host = "localhost"

Expand Down Expand Up @@ -146,10 +145,6 @@ def empty_str_to_none(cls, v: str) -> Optional[str]:

chroma_server_nofile: Optional[int] = None

pulsar_broker_url: Optional[str] = None
pulsar_admin_port: Optional[int] = 8080
pulsar_broker_port: Optional[int] = 6650

chroma_server_auth_provider: Optional[str] = None

@validator("chroma_server_auth_provider", pre=True, always=True, allow_reuse=True)
Expand Down
21 changes: 0 additions & 21 deletions chromadb/db/impl/grpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
UpdateSegmentRequest,
)
from chromadb.proto.coordinator_pb2_grpc import SysDBStub
from chromadb.telemetry.opentelemetry import OpenTelemetryClient
from chromadb.telemetry.opentelemetry.grpc import OtelInterceptor
from chromadb.types import (
Collection,
Expand Down Expand Up @@ -145,14 +144,12 @@ def get_segments(
id: Optional[UUID] = None,
type: Optional[str] = None,
scope: Optional[SegmentScope] = None,
topic: Optional[str] = None,
collection: Optional[UUID] = None,
) -> Sequence[Segment]:
request = GetSegmentsRequest(
id=id.hex if id else None,
type=type,
scope=to_proto_segment_scope(scope) if scope else None,
topic=topic,
collection=collection.hex if collection else None,
)
response = self._sys_db_stub.GetSegments(request)
Expand All @@ -166,14 +163,9 @@ def get_segments(
def update_segment(
self,
id: UUID,
topic: OptionalArgument[Optional[str]] = Unspecified(),
collection: OptionalArgument[Optional[UUID]] = Unspecified(),
metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
) -> None:
write_topic = None
if topic != Unspecified():
write_topic = cast(Union[str, None], topic)

write_collection = None
if collection != Unspecified():
write_collection = cast(Union[UUID, None], collection)
Expand All @@ -184,17 +176,12 @@ def update_segment(

request = UpdateSegmentRequest(
id=id.hex,
topic=write_topic,
collection=write_collection.hex if write_collection else None,
metadata=to_proto_update_metadata(write_metadata)
if write_metadata
else None,
)

if topic is None:
request.ClearField("topic")
request.reset_topic = True

if collection is None:
request.ClearField("collection")
request.reset_collection = True
Expand Down Expand Up @@ -252,7 +239,6 @@ def delete_collection(
def get_collections(
self,
id: Optional[UUID] = None,
topic: Optional[str] = None,
name: Optional[str] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
Expand All @@ -262,7 +248,6 @@ def get_collections(
# TODO: implement limit and offset in the gRPC service
request = GetCollectionsRequest(
id=id.hex if id else None,
topic=topic,
name=name,
tenant=tenant,
database=database,
Expand All @@ -277,15 +262,10 @@ def get_collections(
def update_collection(
self,
id: UUID,
topic: OptionalArgument[str] = Unspecified(),
name: OptionalArgument[str] = Unspecified(),
dimension: OptionalArgument[Optional[int]] = Unspecified(),
metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
) -> None:
write_topic = None
if topic != Unspecified():
write_topic = cast(str, topic)

write_name = None
if name != Unspecified():
write_name = cast(str, name)
Expand All @@ -300,7 +280,6 @@ def update_collection(

request = UpdateCollectionRequest(
id=id.hex,
topic=write_topic,
name=write_name,
dimension=write_dimension,
metadata=to_proto_update_metadata(write_metadata)
Expand Down
18 changes: 1 addition & 17 deletions chromadb/db/impl/grpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Any, Dict, cast
from uuid import UUID
from overrides import overrides
from chromadb.ingest import CollectionAssignmentPolicy
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Component, System
from chromadb.proto.convert import (
from_proto_metadata,
Expand Down Expand Up @@ -38,7 +37,7 @@
UpdateCollectionRequest,
UpdateCollectionResponse,
UpdateSegmentRequest,
UpdateSegmentResponse
UpdateSegmentResponse,
)
from chromadb.proto.coordinator_pb2_grpc import (
SysDBServicer,
Expand All @@ -55,7 +54,6 @@ class GrpcMockSysDB(SysDBServicer, Component):

_server: grpc.Server
_server_port: int
_assignment_policy: CollectionAssignmentPolicy
_segments: Dict[str, Segment] = {}
_tenants_to_databases_to_collections: Dict[
str, Dict[str, Dict[str, Collection]]
Expand All @@ -64,7 +62,6 @@ class GrpcMockSysDB(SysDBServicer, Component):

def __init__(self, system: System):
self._server_port = system.settings.require("chroma_server_grpc_port")
self._assignment_policy = system.instance(CollectionAssignmentPolicy)
return super().__init__(system)

@overrides
Expand Down Expand Up @@ -203,7 +200,6 @@ def GetSegments(
if request.HasField("scope")
else None
)
target_topic = request.topic if request.HasField("topic") else None
target_collection = (
UUID(hex=request.collection) if request.HasField("collection") else None
)
Expand All @@ -216,8 +212,6 @@ def GetSegments(
continue
if target_scope and segment["scope"] != target_scope:
continue
if target_topic and segment["topic"] != target_topic:
continue
if target_collection and segment["collection"] != target_collection:
continue
found_segments.append(segment)
Expand All @@ -238,10 +232,6 @@ def UpdateSegment(
)
else:
segment = self._segments[id_to_update.hex]
if request.HasField("topic"):
segment["topic"] = request.topic
if request.HasField("reset_topic") and request.reset_topic:
segment["topic"] = None
if request.HasField("collection"):
segment["collection"] = UUID(hex=request.collection)
if request.HasField("reset_collection") and request.reset_collection:
Expand Down Expand Up @@ -326,7 +316,6 @@ def CreateCollection(
name=request.name,
metadata=from_proto_metadata(request.metadata),
dimension=request.dimension,
topic=self._assignment_policy.assign_collection(id),
database=database,
tenant=tenant,
)
Expand Down Expand Up @@ -368,7 +357,6 @@ def GetCollections(
self, request: GetCollectionsRequest, context: grpc.ServicerContext
) -> GetCollectionsResponse:
target_id = UUID(hex=request.id) if request.HasField("id") else None
target_topic = request.topic if request.HasField("topic") else None
target_name = request.name if request.HasField("name") else None

tenant = request.tenant
Expand All @@ -387,8 +375,6 @@ def GetCollections(
for collection in collections.values():
if target_id and collection["id"] != target_id:
continue
if target_topic and collection["topic"] != target_topic:
continue
if target_name and collection["name"] != target_name:
continue
found_collections.append(collection)
Expand Down Expand Up @@ -418,8 +404,6 @@ def UpdateCollection(
)
else:
collection = collections[id_to_update.hex]
if request.HasField("topic"):
collection["topic"] = request.topic
if request.HasField("name"):
collection["name"] = request.name
if request.HasField("dimension"):
Expand Down
Loading

0 comments on commit 3d62b49

Please sign in to comment.