Skip to content

Commit

Permalink
[ENH] Connection pool FD leak v2 (#2014)
Browse files Browse the repository at this point in the history
Closes #1379

## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
	 - Using weakrefs in pools' `connections` set.

## Test plan
*How are these changes tested?*

- [x] Tests pass locally with `pytest` for python, `yarn test` for js

## Documentation Changes
N/A

## Refs

- https://peps.python.org/pep-0567/
  • Loading branch information
tazarov authored Jul 26, 2024
1 parent fbb4ef4 commit 4b2a033
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 7 deletions.
16 changes: 10 additions & 6 deletions chromadb/db/impl/sqlite_pool.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import sqlite3
import weakref
from abc import ABC, abstractmethod
from typing import Any, Set
import threading
from overrides import override
from typing_extensions import Annotated


class Connection:
Expand Down Expand Up @@ -70,7 +72,7 @@ class LockPool(Pool):
shared cache mode. We use the shared cache mode to allow multiple threads to share a database.
"""

_connections: Set[Connection]
_connections: Set[Annotated[weakref.ReferenceType, Connection]]
_lock: threading.RLock
_connection: threading.local
_db_file: str
Expand All @@ -93,7 +95,7 @@ def connect(self, *args: Any, **kwargs: Any) -> Connection:
self, self._db_file, self._is_uri, *args, **kwargs
)
self._connection.conn = new_connection
self._connections.add(new_connection)
self._connections.add(weakref.ref(new_connection))
return new_connection

@override
Expand All @@ -106,7 +108,8 @@ def return_to_pool(self, conn: Connection) -> None:
@override
def close(self) -> None:
for conn in self._connections:
conn.close_actual()
if conn() is not None:
conn().close_actual() # type: ignore
self._connections.clear()
self._connection = threading.local()
try:
Expand All @@ -120,7 +123,7 @@ class PerThreadPool(Pool):
extended to do so and block on connect() if the cap is reached.
"""

_connections: Set[Connection]
_connections: Set[Annotated[weakref.ReferenceType, Connection]]
_lock: threading.Lock
_connection: threading.local
_db_file: str
Expand All @@ -143,14 +146,15 @@ def connect(self, *args: Any, **kwargs: Any) -> Connection:
)
self._connection.conn = new_connection
with self._lock:
self._connections.add(new_connection)
self._connections.add(weakref.ref(new_connection))
return new_connection

@override
def close(self) -> None:
with self._lock:
for conn in self._connections:
conn.close_actual()
if conn() is not None:
conn().close_actual() # type: ignore
self._connections.clear()
self._connection = threading.local()

Expand Down
28 changes: 28 additions & 0 deletions chromadb/test/property/invariants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import gc
import math
from time import sleep

import psutil

from chromadb.test.property.strategies import NormalizedRecordSet, RecordSet
from typing import Callable, Optional, Tuple, Union, List, TypeVar, cast
from typing_extensions import Literal
Expand Down Expand Up @@ -163,6 +168,29 @@ def _exact_distances(
return np.argsort(distances).tolist(), distances.tolist()


def fd_not_exceeding_threadpool_size(threadpool_size: int) -> None:
"""
Checks that the open file descriptors are not exceeding the threadpool size
works only for SegmentAPI
"""
current_process = psutil.Process()
open_files = current_process.open_files()
max_retries = 5
retry_count = 0
# we probably don't need the below but we keep it to avoid flaky tests.
while (
len([p.path for p in open_files if "sqlite3" in p.path]) - 1 > threadpool_size
and retry_count < max_retries
):
gc.collect() # GC to collect the orphaned TLS objects
open_files = current_process.open_files()
retry_count += 1
sleep(1)
assert (
len([p.path for p in open_files if "sqlite3" in p.path]) - 1 <= threadpool_size
)


def ann_accuracy(
collection: Collection,
record_set: RecordSet,
Expand Down
6 changes: 5 additions & 1 deletion chromadb/test/test_multithreaded.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from chromadb.api import ClientAPI
import chromadb.test.property.invariants as invariants
from chromadb.api.segment import SegmentAPI
from chromadb.test.property.strategies import RecordSet
from chromadb.test.property.strategies import test_hnsw_config
from chromadb.types import Metadata
Expand Down Expand Up @@ -193,7 +194,10 @@ def perform_operation(
exception = future.exception()
if exception is not None:
raise exception

if (
isinstance(client, SegmentAPI) and client.get_settings().is_persistent is True
): # we can't check invariants for FastAPI
invariants.fd_not_exceeding_threadpool_size(num_workers)
# Check that invariants hold
invariants.count(coll, records_set)
invariants.ids_match(coll, records_set)
Expand Down
1 change: 1 addition & 0 deletions requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ hypothesis>=6.103.1
hypothesis[numpy]>=6.103.1
mypy-protobuf
pre-commit
psutil
pytest
pytest-asyncio
pytest-xdist
Expand Down

0 comments on commit 4b2a033

Please sign in to comment.