Skip to content

Commit

Permalink
[ENH] FE talks to logservice (chroma-core#1793)
Browse files Browse the repository at this point in the history
## Description of changes
https://linear.app/trychroma/issue/CHR-242/fe-talks-to-log-service
- FE talks to logservice

## Test plan
*How are these changes tested?*

- [ ] test_logservice
  • Loading branch information
weiligu authored Mar 1, 2024
1 parent e1ad5f9 commit 091e466
Show file tree
Hide file tree
Showing 13 changed files with 391 additions and 19 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/chroma-cluster-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ jobs:
platform: ['16core-64gb-ubuntu-latest']
testfile: ["chromadb/test/db/test_system.py",
"chromadb/test/ingest/test_producer_consumer.py",
"chromadb/test/segment/distributed/test_memberlist_provider.py"]
"chromadb/test/segment/distributed/test_memberlist_provider.py",
"chromadb/test/test_logservice.py"]
runs-on: ${{ matrix.platform }}
steps:
- name: Checkout
Expand Down Expand Up @@ -65,4 +66,4 @@ jobs:
- name: Start Tilt
run: tilt ci
- name: Test
run: bin/cluster-test.sh bash -c 'cd go && go test -timeout 30s -run ^TestNodeWatcher$ github.com/chroma/chroma-coordinator/internal/memberlist_manager'
run: bin/cluster-test.sh bash -c 'cd go && go test -timeout 30s -run ^TestNodeWatcher$ github.com/chroma/chroma-coordinator/internal/memberlist_manager'
9 changes: 6 additions & 3 deletions Tiltfile
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ k8s_resource(
'coordinator-serviceaccount-rolebinding:RoleBinding',
'coordinator-worker-memberlist-binding:clusterrolebinding',

'logservice-serviceaccount:serviceaccount',

'worker-serviceaccount:serviceaccount',
'worker-serviceaccount-rolebinding:RoleBinding',
'worker-memberlist-readerwriter:ClusterRole',
Expand All @@ -65,14 +67,15 @@ k8s_resource(
k8s_resource('postgres', resource_deps=['k8s_setup'], labels=["infrastructure"])
k8s_resource('pulsar', resource_deps=['k8s_setup'], labels=["infrastructure"], port_forwards=['6650:6650', '8080:8080'])
k8s_resource('migration', resource_deps=['postgres'], labels=["infrastructure"])
k8s_resource('logservice', resource_deps=['migration'], labels=["chroma"])
k8s_resource('frontend-server', resource_deps=['pulsar'],labels=["chroma"], port_forwards=8000 )
k8s_resource('logservice', resource_deps=['migration'], labels=["chroma"], port_forwards='50052:50051')
k8s_resource('frontend-server', resource_deps=['logservice'],labels=["chroma"], port_forwards=8000 )
k8s_resource('coordinator', resource_deps=['pulsar', 'frontend-server', 'migration'], labels=["chroma"], port_forwards=50051)
k8s_resource('worker', resource_deps=['coordinator'],labels=["chroma"])

# Extra stuff to make debugging and testing easier
k8s_yaml([
'k8s/test/coordinator_service.yaml',
'k8s/test/logservice_service.yaml',
'k8s/test/minio.yaml',
'k8s/test/pulsar_service.yaml',
'k8s/test/worker_service.yaml',
Expand All @@ -90,4 +93,4 @@ k8s_resource(
)

# Local S3
k8s_resource('minio-deployment', resource_deps=['k8s_setup'], labels=["debug"], port_forwards=9000)
k8s_resource('minio-deployment', resource_deps=['k8s_setup'], labels=["debug"], port_forwards=9000)
1 change: 1 addition & 0 deletions bin/cluster-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ echo "Pulsar Broker is running at port $PULSAR_BROKER_URL"
echo "Chroma Coordinator is running at port $CHROMA_COORDINATOR_HOST"

kubectl -n chroma port-forward svc/coordinator-lb 50051:50051 &
kubectl -n chroma port-forward svc/logservice-lb 50052:50051 &
kubectl -n chroma port-forward svc/pulsar-lb 6650:6650 &
kubectl -n chroma port-forward svc/pulsar-lb 8080:8080 &
kubectl -n chroma port-forward svc/frontend-server 8000:8000 &
Expand Down
17 changes: 13 additions & 4 deletions chromadb/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,13 @@
"chromadb.segment.SegmentManager": "chroma_segment_manager_impl",
"chromadb.segment.distributed.SegmentDirectory": "chroma_segment_directory_impl",
"chromadb.segment.distributed.MemberlistProvider": "chroma_memberlist_provider_impl",
"chromadb.rate_limiting.RateLimitingProvider": "chroma_rate_limiting_provider_impl"
"chromadb.rate_limiting.RateLimitingProvider": "chroma_rate_limiting_provider_impl",
}

DEFAULT_TENANT = "default_tenant"
DEFAULT_DATABASE = "default_database"


class Settings(BaseSettings): # type: ignore
environment: str = ""

Expand All @@ -101,8 +102,10 @@ class Settings(BaseSettings): # type: ignore
chroma_segment_manager_impl: str = (
"chromadb.segment.impl.manager.local.LocalSegmentManager"
)
chroma_quota_provider_impl:Optional[str] = None
chroma_rate_limiting_provider_impl:Optional[str] = None

chroma_quota_provider_impl: Optional[str] = None
chroma_rate_limiting_provider_impl: Optional[str] = None

# Distributed architecture specific components
chroma_segment_directory_impl: str = "chromadb.segment.impl.distributed.segment_directory.RendezvousHashSegmentDirectory"
chroma_memberlist_provider_impl: str = "chromadb.segment.impl.distributed.segment_directory.CustomResourceMemberlistProvider"
Expand All @@ -112,6 +115,9 @@ class Settings(BaseSettings): # type: ignore
worker_memberlist_name: str = "worker-memberlist"
chroma_coordinator_host = "localhost"

chroma_logservice_host = "localhost"
chroma_logservice_port = 50052

tenant_id: str = "default"
topic_namespace: str = "default"

Expand Down Expand Up @@ -320,7 +326,10 @@ def __init__(self, settings: Settings):
if settings[key] is not None:
raise ValueError(LEGACY_ERROR)

if settings["chroma_segment_cache_policy"] is not None and settings["chroma_segment_cache_policy"] != "LRU":
if (
settings["chroma_segment_cache_policy"] is not None
and settings["chroma_segment_cache_policy"] != "LRU"
):
logger.error(
f"Failed to set chroma_segment_cache_policy: Only LRU is available."
)
Expand Down
171 changes: 171 additions & 0 deletions chromadb/logservice/logservice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
import sys

import grpc

from chromadb.ingest import (
Producer,
Consumer,
ConsumerCallbackFn,
)
from chromadb.proto.chroma_pb2 import (
SubmitEmbeddingRecord as ProtoSubmitEmbeddingRecord,
)
from chromadb.proto.convert import to_proto_submit
from chromadb.proto.logservice_pb2 import PushLogsRequest, PullLogsRequest
from chromadb.proto.logservice_pb2_grpc import LogServiceStub
from chromadb.types import (
SubmitEmbeddingRecord,
SeqId,
)
from chromadb.config import System
from chromadb.telemetry.opentelemetry import (
OpenTelemetryClient,
OpenTelemetryGranularity,
trace_method,
)
from overrides import override
from typing import Sequence, Optional, Dict, cast
from uuid import UUID
import logging

logger = logging.getLogger(__name__)


class LogService(Producer, Consumer):
"""
Distributed Chroma Log Service
"""

_log_service_stub: LogServiceStub
_channel: grpc.Channel
_log_service_url: str
_log_service_port: int

def __init__(self, system: System):
self._log_service_url = system.settings.require("chroma_logservice_host")
self._log_service_port = system.settings.require("chroma_logservice_port")
self._opentelemetry_client = system.require(OpenTelemetryClient)
super().__init__(system)

@trace_method("LogService.start", OpenTelemetryGranularity.ALL)
@override
def start(self) -> None:
self._channel = grpc.insecure_channel(
f"{self._log_service_url}:{self._log_service_port}"
)
self._log_service_stub = LogServiceStub(self._channel) # type: ignore
super().start()

@trace_method("LogService.stop", OpenTelemetryGranularity.ALL)
@override
def stop(self) -> None:
self._channel.close()
super().stop()

@trace_method("LogService.reset_state", OpenTelemetryGranularity.ALL)
@override
def reset_state(self) -> None:
super().reset_state()

@override
def create_topic(self, topic_name: str) -> None:
raise NotImplementedError("Not implemented")

@trace_method("LogService.delete_topic", OpenTelemetryGranularity.ALL)
@override
def delete_topic(self, topic_name: str) -> None:
raise NotImplementedError("Not implemented")

@trace_method("LogService.submit_embedding", OpenTelemetryGranularity.ALL)
@override
def submit_embedding(
self, topic_name: str, embedding: SubmitEmbeddingRecord
) -> SeqId:
if not self._running:
raise RuntimeError("Component not running")

return self.submit_embeddings(topic_name, [embedding])[0] # type: ignore

@trace_method("LogService.submit_embeddings", OpenTelemetryGranularity.ALL)
@override
def submit_embeddings(
self, topic_name: str, embeddings: Sequence[SubmitEmbeddingRecord]
) -> Sequence[SeqId]:
logger.info(f"Submitting {len(embeddings)} embeddings to {topic_name}")

if not self._running:
raise RuntimeError("Component not running")

if len(embeddings) == 0:
return []

# push records to the log service
collection_id_to_embeddings: Dict[UUID, list[SubmitEmbeddingRecord]] = {}
for embedding in embeddings:
collection_id = cast(UUID, embedding.get("collection_id"))
if collection_id is None:
raise ValueError("collection_id is required")
if collection_id not in collection_id_to_embeddings:
collection_id_to_embeddings[collection_id] = []
collection_id_to_embeddings[collection_id].append(embedding)

counts = []
for collection_id, records in collection_id_to_embeddings.items():
protos_to_submit = [to_proto_submit(record) for record in records]
counts.append(
self.push_logs(
collection_id,
cast(Sequence[SubmitEmbeddingRecord], protos_to_submit),
)
)

return counts

@trace_method("LogService.subscribe", OpenTelemetryGranularity.ALL)
@override
def subscribe(
self,
topic_name: str,
consume_fn: ConsumerCallbackFn,
start: Optional[SeqId] = None,
end: Optional[SeqId] = None,
id: Optional[UUID] = None,
) -> UUID:
logger.info(f"Subscribing to {topic_name}, noop for logservice")
return UUID(int=0)

@trace_method("LogService.unsubscribe", OpenTelemetryGranularity.ALL)
@override
def unsubscribe(self, subscription_id: UUID) -> None:
logger.info(f"Unsubscribing from {subscription_id}, noop for logservice")

@override
def min_seqid(self) -> SeqId:
return 0

@override
def max_seqid(self) -> SeqId:
return sys.maxsize

@property
@override
def max_batch_size(self) -> int:
return sys.maxsize

def push_logs(
self, collection_id: UUID, records: Sequence[SubmitEmbeddingRecord]
) -> int:
request = PushLogsRequest(collection_id=str(collection_id), records=records)
response = self._log_service_stub.PushLogs(request)
return response.record_count # type: ignore

def pull_logs(
self, collection_id: UUID, start_id: int, batch_size: int
) -> Sequence[ProtoSubmitEmbeddingRecord]:
request = PullLogsRequest(
collection_id=str(collection_id),
start_from_id=start_id,
batch_size=batch_size,
)
response = self._log_service_stub.PullLogs(request)
return response.records # type: ignore
2 changes: 2 additions & 0 deletions chromadb/test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ def basic_http_client() -> Generator[System, None, None]:
settings = Settings(
chroma_api_impl="chromadb.api.fastapi.FastAPI",
chroma_server_http_port=8000,
chroma_server_host="localhost",
allow_reset=True,
)
system = System(settings)
Expand Down Expand Up @@ -468,6 +469,7 @@ def system_wrong_auth(
def system(request: pytest.FixtureRequest) -> Generator[ServerAPI, None, None]:
yield next(request.param())


@pytest.fixture(scope="module", params=system_fixtures_ssl())
def system_ssl(request: pytest.FixtureRequest) -> Generator[ServerAPI, None, None]:
yield next(request.param())
Expand Down
Loading

0 comments on commit 091e466

Please sign in to comment.