diff --git a/chromadb/db/impl/sqlite_pool.py b/chromadb/db/impl/sqlite_pool.py index 83a3edf104b..5444d71b4bc 100644 --- a/chromadb/db/impl/sqlite_pool.py +++ b/chromadb/db/impl/sqlite_pool.py @@ -1,8 +1,10 @@ import sqlite3 +import weakref from abc import ABC, abstractmethod from typing import Any, Set import threading from overrides import override +from typing_extensions import Annotated class Connection: @@ -70,7 +72,7 @@ class LockPool(Pool): shared cache mode. We use the shared cache mode to allow multiple threads to share a database. """ - _connections: Set[Connection] + _connections: Set[Annotated[weakref.ReferenceType, Connection]] _lock: threading.RLock _connection: threading.local _db_file: str @@ -93,7 +95,7 @@ def connect(self, *args: Any, **kwargs: Any) -> Connection: self, self._db_file, self._is_uri, *args, **kwargs ) self._connection.conn = new_connection - self._connections.add(new_connection) + self._connections.add(weakref.ref(new_connection)) return new_connection @override @@ -106,7 +108,8 @@ def return_to_pool(self, conn: Connection) -> None: @override def close(self) -> None: for conn in self._connections: - conn.close_actual() + if conn() is not None: + conn().close_actual() # type: ignore self._connections.clear() self._connection = threading.local() try: @@ -120,7 +123,7 @@ class PerThreadPool(Pool): extended to do so and block on connect() if the cap is reached. """ - _connections: Set[Connection] + _connections: Set[Annotated[weakref.ReferenceType, Connection]] _lock: threading.Lock _connection: threading.local _db_file: str @@ -143,14 +146,15 @@ def connect(self, *args: Any, **kwargs: Any) -> Connection: ) self._connection.conn = new_connection with self._lock: - self._connections.add(new_connection) + self._connections.add(weakref.ref(new_connection)) return new_connection @override def close(self) -> None: with self._lock: for conn in self._connections: - conn.close_actual() + if conn() is not None: + conn().close_actual() # type: ignore self._connections.clear() self._connection = threading.local() diff --git a/chromadb/test/property/invariants.py b/chromadb/test/property/invariants.py index c992f18f5a0..a7bd2c6cade 100644 --- a/chromadb/test/property/invariants.py +++ b/chromadb/test/property/invariants.py @@ -1,4 +1,9 @@ +import gc import math +from time import sleep + +import psutil + from chromadb.test.property.strategies import NormalizedRecordSet, RecordSet from typing import Callable, Optional, Tuple, Union, List, TypeVar, cast from typing_extensions import Literal @@ -163,6 +168,29 @@ def _exact_distances( return np.argsort(distances).tolist(), distances.tolist() +def fd_not_exceeding_threadpool_size(threadpool_size: int) -> None: + """ + Checks that the open file descriptors are not exceeding the threadpool size + works only for SegmentAPI + """ + current_process = psutil.Process() + open_files = current_process.open_files() + max_retries = 5 + retry_count = 0 + # we probably don't need the below but we keep it to avoid flaky tests. + while ( + len([p.path for p in open_files if "sqlite3" in p.path]) - 1 > threadpool_size + and retry_count < max_retries + ): + gc.collect() # GC to collect the orphaned TLS objects + open_files = current_process.open_files() + retry_count += 1 + sleep(1) + assert ( + len([p.path for p in open_files if "sqlite3" in p.path]) - 1 <= threadpool_size + ) + + def ann_accuracy( collection: Collection, record_set: RecordSet, diff --git a/chromadb/test/test_multithreaded.py b/chromadb/test/test_multithreaded.py index 7cad62a07fe..745f562f4b0 100644 --- a/chromadb/test/test_multithreaded.py +++ b/chromadb/test/test_multithreaded.py @@ -7,6 +7,7 @@ from chromadb.api import ClientAPI import chromadb.test.property.invariants as invariants +from chromadb.api.segment import SegmentAPI from chromadb.test.property.strategies import RecordSet from chromadb.test.property.strategies import test_hnsw_config from chromadb.types import Metadata @@ -193,7 +194,10 @@ def perform_operation( exception = future.exception() if exception is not None: raise exception - + if ( + isinstance(client, SegmentAPI) and client.get_settings().is_persistent is True + ): # we can't check invariants for FastAPI + invariants.fd_not_exceeding_threadpool_size(num_workers) # Check that invariants hold invariants.count(coll, records_set) invariants.ids_match(coll, records_set) diff --git a/requirements_dev.txt b/requirements_dev.txt index 53d311409fe..15333a4ce40 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -6,6 +6,7 @@ hypothesis>=6.103.1 hypothesis[numpy]>=6.103.1 mypy-protobuf pre-commit +psutil pytest pytest-asyncio pytest-xdist