[Improvements] Manage segment cache and memory (chroma-core#1670)
Add configuration to limit the memory used by the segment cache.
nicolasgere authored Jan 31, 2024
1 parent a370684 commit e5751fd
Showing 10 changed files with 323 additions and 25 deletions.
13 changes: 12 additions & 1 deletion chromadb/config.py
@@ -80,7 +80,6 @@
DEFAULT_TENANT = "default_tenant"
DEFAULT_DATABASE = "default_database"


class Settings(BaseSettings): # type: ignore
environment: str = ""

@@ -116,6 +115,9 @@ class Settings(BaseSettings): # type: ignore
is_persistent: bool = False
persist_directory: str = "./chroma"

chroma_memory_limit_bytes: int = 0
chroma_segment_cache_policy: Optional[str] = None

chroma_server_host: Optional[str] = None
chroma_server_headers: Optional[Dict[str, str]] = None
chroma_server_http_port: Optional[str] = None
@@ -313,6 +315,15 @@ def __init__(self, settings: Settings):
if settings[key] is not None:
raise ValueError(LEGACY_ERROR)

if settings["chroma_segment_cache_policy"] is not None and settings["chroma_segment_cache_policy"] != "LRU":
logger.error(
f"Failed to set chroma_segment_cache_policy: Only LRU is available."
)
if settings["chroma_memory_limit_bytes"] == 0:
logger.error(
f"Failed to set chroma_segment_cache_policy: chroma_memory_limit_bytes is require."
)

# Apply the nofile limit if set
if settings["chroma_server_nofile"] is not None:
if platform.system() != "Windows":
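The two new settings only take effect together: the LRU policy is ignored unless a positive memory limit is also set. A minimal sketch of enabling the LRU segment cache from client code (the byte limit is illustrative):

import chromadb
from chromadb.config import Settings

# Cap the vector segment cache at roughly 1 GB; least-recently-used
# collections are evicted once the limit is exceeded.
settings = Settings(
    is_persistent=True,
    persist_directory="./chroma",
    chroma_segment_cache_policy="LRU",
    chroma_memory_limit_bytes=1_000_000_000,
)
client = chromadb.Client(settings)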
Empty file.
Empty file.
Empty file.
104 changes: 104 additions & 0 deletions chromadb/segment/impl/manager/cache/cache.py
@@ -0,0 +1,104 @@
import uuid
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional

from overrides import override

from chromadb.types import Segment


class SegmentCache(ABC):
    @abstractmethod
    def get(self, key: uuid.UUID) -> Optional[Segment]:
        pass

    @abstractmethod
    def pop(self, key: uuid.UUID) -> Optional[Segment]:
        pass

    @abstractmethod
    def set(self, key: uuid.UUID, value: Segment) -> None:
        pass

    @abstractmethod
    def reset(self) -> None:
        pass


class BasicCache(SegmentCache):
    """An unbounded cache that simply stores every segment it is given."""

    def __init__(self) -> None:
        self.cache: Dict[uuid.UUID, Segment] = {}

    @override
    def get(self, key: uuid.UUID) -> Optional[Segment]:
        return self.cache.get(key)

    @override
    def pop(self, key: uuid.UUID) -> Optional[Segment]:
        return self.cache.pop(key, None)

    @override
    def set(self, key: uuid.UUID, value: Segment) -> None:
        self.cache[key] = value

    @override
    def reset(self) -> None:
        self.cache = {}


class SegmentLRUCache(BasicCache):
    """A simple LRU cache implementation that handles objects with dynamic sizes.

    The size of each object is determined by a user-provided size function, and
    an optional callback is invoked for every evicted entry."""

    def __init__(
        self,
        capacity: int,
        size_func: Callable[[uuid.UUID], int],
        callback: Optional[Callable[[uuid.UUID, Segment], Any]] = None,
    ) -> None:
        self.capacity = capacity
        self.size_func = size_func
        self.cache: Dict[uuid.UUID, Segment] = {}
        self.history: List[uuid.UUID] = []
        self.callback = callback

    def _upsert_key(self, key: uuid.UUID) -> None:
        # Move the key to the most-recently-used position.
        if key in self.history:
            self.history.remove(key)
        self.history.append(key)

    @override
    def get(self, key: uuid.UUID) -> Optional[Segment]:
        self._upsert_key(key)
        return self.cache.get(key)

    @override
    def pop(self, key: uuid.UUID) -> Optional[Segment]:
        if key in self.history:
            self.history.remove(key)
        return self.cache.pop(key, None)

    @override
    def set(self, key: uuid.UUID, value: Segment) -> None:
        if key in self.cache:
            return
        item_size = self.size_func(key)
        key_sizes = {k: self.size_func(k) for k in self.cache}
        total_size = sum(key_sizes.values())
        index = 0
        # Evict least-recently-used items until the new item fits within capacity.
        while total_size + item_size > self.capacity and len(self.history) > index:
            key_delete = self.history[index]
            if key_delete in self.cache:
                if self.callback is not None:
                    self.callback(key_delete, self.cache[key_delete])
                del self.cache[key_delete]
                total_size -= key_sizes[key_delete]
            index += 1

        self.cache[key] = value
        self._upsert_key(key)

    @override
    def reset(self) -> None:
        self.cache = {}
        self.history = []
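A standalone illustration of the eviction behavior (not part of the commit; sizes, size_of and evicted are invented for the example, and the plain dicts stand in for real Segment records):

import uuid

from chromadb.segment.impl.manager.cache.cache import SegmentLRUCache

sizes = {}

def size_of(key: uuid.UUID) -> int:
    return sizes[key]

evicted = []
cache = SegmentLRUCache(
    capacity=100,
    size_func=size_of,
    callback=lambda key, segment: evicted.append(key),
)

a, b, c = uuid.uuid4(), uuid.uuid4(), uuid.uuid4()
sizes[a], sizes[b], sizes[c] = 40, 30, 50

cache.set(a, {"id": a})  # cached total: 40
cache.set(b, {"id": b})  # cached total: 70
cache.get(a)             # touch "a" so "b" becomes least recently used
cache.set(c, {"id": c})  # 70 + 50 > capacity, so "b" is evicted first
assert evicted == [b]
assert cache.get(b) is None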
74 changes: 52 additions & 22 deletions chromadb/segment/impl/manager/local.py
@@ -7,6 +7,10 @@
VectorReader,
S,
)
import logging
from chromadb.segment.impl.manager.cache.cache import SegmentLRUCache, BasicCache, SegmentCache
import os

from chromadb.config import System, get_class
from chromadb.db.system import SysDB
from overrides import override
@@ -21,24 +25,23 @@
from chromadb.types import Collection, Operation, Segment, SegmentScope, Metadata
from typing import Dict, Type, Sequence, Optional, cast
from uuid import UUID, uuid4
from collections import defaultdict
import platform

from chromadb.utils.lru_cache import LRUCache
from chromadb.utils.directory import get_directory_size


if platform.system() != "Windows":
import resource
elif platform.system() == "Windows":
import ctypes


SEGMENT_TYPE_IMPLS = {
SegmentType.SQLITE: "chromadb.segment.impl.metadata.sqlite.SqliteMetadataSegment",
SegmentType.HNSW_LOCAL_MEMORY: "chromadb.segment.impl.vector.local_hnsw.LocalHnswSegment",
SegmentType.HNSW_LOCAL_PERSISTED: "chromadb.segment.impl.vector.local_persistent_hnsw.PersistentLocalHnswSegment",
}


class LocalSegmentManager(SegmentManager):
_sysdb: SysDB
_system: System
@@ -47,9 +50,6 @@ class LocalSegmentManager(SegmentManager):
_vector_instances_file_handle_cache: LRUCache[
UUID, PersistentLocalHnswSegment
] # LRU cache to manage file handles across vector segment instances
_segment_cache: Dict[
UUID, Dict[SegmentScope, Segment]
] # Tracks which segments are loaded for a given collection
_vector_segment_type: SegmentType = SegmentType.HNSW_LOCAL_MEMORY
_lock: Lock
_max_file_handles: int
@@ -59,8 +59,17 @@ def __init__(self, system: System):
self._sysdb = self.require(SysDB)
self._system = system
self._opentelemetry_client = system.require(OpenTelemetryClient)
self.logger = logging.getLogger(__name__)
self._instances = {}
self._segment_cache = defaultdict(dict)
self.segment_cache: Dict[SegmentScope, SegmentCache] = {
    SegmentScope.METADATA: BasicCache()
}
if (
    system.settings.chroma_segment_cache_policy == "LRU"
    and system.settings.chroma_memory_limit_bytes > 0
):
    self.segment_cache[SegmentScope.VECTOR] = SegmentLRUCache(
        capacity=system.settings.chroma_memory_limit_bytes,
        callback=lambda k, v: self.callback_cache_evict(v),
        size_func=lambda k: self._get_segment_disk_size(k),
    )
else:
    self.segment_cache[SegmentScope.VECTOR] = BasicCache()
self._lock = Lock()

# TODO: prototyping with distributed segment for now, but this should be a configurable option
@@ -72,13 +81,21 @@
else:
self._max_file_handles = ctypes.windll.msvcrt._getmaxstdio() # type: ignore
segment_limit = (
    self._max_file_handles
    // PersistentLocalHnswSegment.get_file_handle_count()
)
self._vector_instances_file_handle_cache = LRUCache(
segment_limit, callback=lambda _, v: v.close_persistent_index()
)

def callback_cache_evict(self, segment: Segment) -> None:
    # Stop the evicted segment's instance and drop it so its memory is released.
    collection_id = segment["collection"]
    self.logger.info(f"LRU cache evict collection {collection_id}")
    instance = self._instance(segment)
    instance.stop()
    del self._instances[segment["id"]]

@override
def start(self) -> None:
for instance in self._instances.values():
@@ -97,7 +114,7 @@ def reset_state(self) -> None:
instance.stop()
instance.reset_state()
self._instances = {}
self._segment_cache = defaultdict(dict)
self.segment_cache[SegmentScope.VECTOR].reset()
super().reset_state()

@trace_method(
Expand Down Expand Up @@ -130,16 +147,31 @@ def delete_segments(self, collection_id: UUID) -> Sequence[UUID]:
instance = self.get_segment(collection_id, MetadataReader)
instance.delete()
del self._instances[segment["id"]]
if collection_id in self._segment_cache:
if segment["scope"] in self._segment_cache[collection_id]:
del self._segment_cache[collection_id][segment["scope"]]
del self._segment_cache[collection_id]
if segment["scope"] is SegmentScope.VECTOR:
self.segment_cache[SegmentScope.VECTOR].pop(collection_id)
if segment["scope"] is SegmentScope.METADATA:
self.segment_cache[SegmentScope.METADATA].pop(collection_id)
return [s["id"] for s in segments]

@trace_method(
"LocalSegmentManager.get_segment",
OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
)
def _get_segment_disk_size(self, collection_id: UUID) -> int:
    segments = self._sysdb.get_segments(
        collection=collection_id, scope=SegmentScope.VECTOR
    )
    if len(segments) == 0:
        return 0
    # With the local segment manager (single-server Chroma), a collection
    # always has exactly one vector segment.
    size = get_directory_size(
        os.path.join(
            self._system.settings.require("persist_directory"),
            str(segments[0]["id"]),
        )
    )
    return size

def _get_segment_sysdb(self, collection_id: UUID, scope: SegmentScope) -> Segment:
    segments = self._sysdb.get_segments(collection=collection_id, scope=scope)
    known_types = set([k.value for k in SEGMENT_TYPE_IMPLS.keys()])
    # Get the first segment of a known type
    segment = next(filter(lambda s: s["type"] in known_types, segments))
    return segment

@override
def get_segment(self, collection_id: UUID, type: Type[S]) -> S:
if type == MetadataReader:
Expand All @@ -149,17 +181,15 @@ def get_segment(self, collection_id: UUID, type: Type[S]) -> S:
else:
raise ValueError(f"Invalid segment type: {type}")

if scope not in self._segment_cache[collection_id]:
segments = self._sysdb.get_segments(collection=collection_id, scope=scope)
known_types = set([k.value for k in SEGMENT_TYPE_IMPLS.keys()])
# Get the first segment of a known type
segment = next(filter(lambda s: s["type"] in known_types, segments))
self._segment_cache[collection_id][scope] = segment
segment = self.segment_cache[scope].get(collection_id)
if segment is None:
segment = self._get_segment_sysdb(collection_id, scope)
self.segment_cache[scope].set(collection_id, segment)

# Instances must be atomically created, so we use a lock to ensure that only one thread
# creates the instance.
with self._lock:
instance = self._instance(self._segment_cache[collection_id][scope])
instance = self._instance(segment)
return cast(S, instance)

@trace_method(
Expand Down Expand Up @@ -208,5 +238,5 @@ def _segment(type: SegmentType, scope: SegmentScope, collection: Collection) ->
scope=scope,
topic=collection["topic"],
collection=collection["id"],
metadata=metadata,
metadata=metadata
)
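One helper this file relies on, get_directory_size from chromadb.utils.directory, is not shown in this diff. A plausible sketch of such a helper, assuming it simply walks the directory tree and sums file sizes (the real implementation may differ):

import os

def get_directory_size(directory: str) -> int:
    # Hypothetical sketch: return the total size in bytes of all files
    # under `directory`.
    total = 0
    for root, _, files in os.walk(directory):
        for name in files:
            path = os.path.join(root, name)
            # Skip broken symlinks and files that disappear mid-walk.
            if os.path.isfile(path):
                total += os.path.getsize(path)
    return total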
2 changes: 1 addition & 1 deletion chromadb/test/db/test_system.py
@@ -721,7 +721,7 @@ def test_update_segment(sysdb: SysDB) -> None:
scope=SegmentScope.VECTOR,
topic="test_topic_a",
collection=sample_collections[0]["id"],
metadata=metadata,
metadata=metadata
)

sysdb.reset_state()
6 changes: 5 additions & 1 deletion chromadb/test/property/strategies.py
@@ -3,6 +3,7 @@
import hypothesis.strategies as st
from typing import Any, Optional, List, Dict, Union, cast
from typing_extensions import TypedDict
import uuid
import numpy as np
import numpy.typing as npt
import chromadb.api.types as types
@@ -237,16 +238,17 @@ def embedding_function_strategy(
@dataclass
class Collection:
name: str
id: uuid.UUID
metadata: Optional[types.Metadata]
dimension: int
dtype: npt.DTypeLike
topic: str
known_metadata_keys: types.Metadata
known_document_keywords: List[str]
has_documents: bool = False
has_embeddings: bool = False
embedding_function: Optional[types.EmbeddingFunction[Embeddable]] = None


@st.composite
def collections(
draw: st.DrawFn,
@@ -309,7 +311,9 @@ def collections(
embedding_function = draw(embedding_function_strategy(dimension, dtype))

return Collection(
id=uuid.uuid4(),
name=name,
topic="topic",
metadata=metadata,
dimension=dimension,
dtype=dtype,