Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Connection pool FD leak v2 #2014

Merged
merged 5 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions chromadb/db/impl/sqlite_pool.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import sqlite3
import weakref
from abc import ABC, abstractmethod
from typing import Any, Set
import threading
from overrides import override
from typing_extensions import Annotated


class Connection:
Expand Down Expand Up @@ -70,7 +72,7 @@ class LockPool(Pool):
shared cache mode. We use the shared cache mode to allow multiple threads to share a database.
"""

_connections: Set[Connection]
_connections: Set[Annotated[weakref.ReferenceType, Connection]]
_lock: threading.RLock
_connection: threading.local
_db_file: str
Expand All @@ -93,7 +95,7 @@ def connect(self, *args: Any, **kwargs: Any) -> Connection:
self, self._db_file, self._is_uri, *args, **kwargs
)
self._connection.conn = new_connection
self._connections.add(new_connection)
self._connections.add(weakref.ref(new_connection))
return new_connection

@override
Expand All @@ -106,7 +108,8 @@ def return_to_pool(self, conn: Connection) -> None:
@override
def close(self) -> None:
for conn in self._connections:
conn.close_actual()
if conn() is not None:
conn().close_actual() # type: ignore
self._connections.clear()
self._connection = threading.local()
try:
Expand All @@ -120,7 +123,7 @@ class PerThreadPool(Pool):
extended to do so and block on connect() if the cap is reached.
"""

_connections: Set[Connection]
_connections: Set[Annotated[weakref.ReferenceType, Connection]]
_lock: threading.Lock
_connection: threading.local
_db_file: str
Expand All @@ -143,14 +146,15 @@ def connect(self, *args: Any, **kwargs: Any) -> Connection:
)
self._connection.conn = new_connection
with self._lock:
self._connections.add(new_connection)
self._connections.add(weakref.ref(new_connection))
return new_connection

@override
def close(self) -> None:
with self._lock:
for conn in self._connections:
conn.close_actual()
if conn() is not None:
conn().close_actual() # type: ignore
self._connections.clear()
self._connection = threading.local()

Expand Down
28 changes: 28 additions & 0 deletions chromadb/test/property/invariants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import gc
import math
from time import sleep

import psutil

from chromadb.test.property.strategies import NormalizedRecordSet, RecordSet
from typing import Callable, Optional, Tuple, Union, List, TypeVar, cast
from typing_extensions import Literal
Expand Down Expand Up @@ -163,6 +168,29 @@ def _exact_distances(
return np.argsort(distances).tolist(), distances.tolist()


def fd_not_exceeding_threadpool_size(threadpool_size: int) -> None:
"""
Checks that the open file descriptors are not exceeding the threadpool size
works only for SegmentAPI
"""
current_process = psutil.Process()
open_files = current_process.open_files()
max_retries = 5
retry_count = 0
# we probably don't need the below but we keep it to avoid flaky tests.
while (
len([p.path for p in open_files if "sqlite3" in p.path]) - 1 > threadpool_size
and retry_count < max_retries
):
gc.collect() # GC to collect the orphaned TLS objects
open_files = current_process.open_files()
retry_count += 1
sleep(1)
assert (
len([p.path for p in open_files if "sqlite3" in p.path]) - 1 <= threadpool_size
)


def ann_accuracy(
collection: Collection,
record_set: RecordSet,
Expand Down
6 changes: 5 additions & 1 deletion chromadb/test/test_multithreaded.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from chromadb.api import ClientAPI
import chromadb.test.property.invariants as invariants
from chromadb.api.segment import SegmentAPI
from chromadb.test.property.strategies import RecordSet
from chromadb.test.property.strategies import test_hnsw_config
from chromadb.types import Metadata
Expand Down Expand Up @@ -193,7 +194,10 @@ def perform_operation(
exception = future.exception()
if exception is not None:
raise exception

if (
isinstance(client, SegmentAPI) and client.get_settings().is_persistent is True
): # we can't check invariants for FastAPI
invariants.fd_not_exceeding_threadpool_size(num_workers)
# Check that invariants hold
invariants.count(coll, records_set)
invariants.ids_match(coll, records_set)
Expand Down
1 change: 1 addition & 0 deletions requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ hypothesis>=6.103.1
hypothesis[numpy]>=6.103.1
mypy-protobuf
pre-commit
psutil
pytest
pytest-asyncio
pytest-xdist
Expand Down
Loading