Skip to content

Commit

Permalink
Merge pull request #189 from requests-cache/sqlite-nocontext-autoclose
Browse files Browse the repository at this point in the history
Close SQLite connection if session is deleted and thread is still running
  • Loading branch information
JWCook authored Oct 12, 2023
2 parents 4d896b6 + 52f3442 commit f1cce5a
Show file tree
Hide file tree
Showing 8 changed files with 75 additions and 29 deletions.
14 changes: 14 additions & 0 deletions aiohttp_client_cache/backends/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,20 @@ async def _init_db(self):
)
return self._connection

def __del__(self):
"""If the aiosqlite connection is still open when this object is deleted, force its thread
to close by emptying its internal queue and setting its ``_running`` flag to ``False``.
This is basically a last resort to avoid hanging the application if this backend is used
without the CachedSession contextmanager.
Note: Since this uses internal attributes, it has the potential to break in future versions
of aiosqlite.
"""
if self._connection is not None:
self._connection._tx.queue.clear()
self._connection._running = False
self._connection = None

@asynccontextmanager
async def bulk_commit(self):
"""Contextmanager to more efficiently write a large number of records at once
Expand Down
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
def test(session):
"""Run tests for a specific python version"""
test_paths = session.posargs or [UNIT_TESTS]
session.install('.', 'pytest', 'pytest-xdist', 'requests-mock', 'timeout-decorator')
session.install('.', 'pytest', 'pytest-aiohttp', 'pytest-asyncio', 'pytest-xdist')

cmd = f'pytest -rs {XDIST_ARGS}'
session.run(*cmd.split(' '), *test_paths)
Expand Down
11 changes: 7 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ docs = ["furo", "linkify-it-py", "markdown-it-py", "myst-parser", "python

[tool.poetry.dev-dependencies]
# For unit + integration tests
async-timeout = ">=4.0"
brotli = ">=1.0"
pytest = ">=6.2"
pytest-aiohttp = "^0.3"
Expand Down
23 changes: 21 additions & 2 deletions test/integration/base_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from uuid import uuid4

import pytest
from async_timeout import timeout
from itsdangerous.exc import BadSignature
from itsdangerous.serializer import Serializer

Expand Down Expand Up @@ -35,13 +36,17 @@ class BaseBackendTest:

@asynccontextmanager
async def init_session(self, clear=True, **kwargs) -> AsyncIterator[CachedSession]:
session = await self._init_session(clear=clear, **kwargs)
async with session:
yield session

async def _init_session(self, clear=True, **kwargs) -> CachedSession:
kwargs.setdefault('allowed_methods', ALL_METHODS)
cache = self.backend_class(CACHE_NAME, **self.init_kwargs, **kwargs)
if clear:
await cache.clear()

async with CachedSession(cache=cache, **self.init_kwargs, **kwargs) as session:
yield session
return CachedSession(cache=cache, **self.init_kwargs, **kwargs)

@pytest.mark.parametrize('method', HTTPBIN_METHODS)
@pytest.mark.parametrize('field', ['params', 'data', 'json'])
Expand Down Expand Up @@ -100,6 +105,20 @@ async def get_url(mysession, url):
responses = await asyncio.gather(*tasks)
assert all([r.from_cache is True for r in responses])

async def test_without_contextmanager(self):
"""Test that the cache backend can be safely used without the CachedSession contextmanager.
An "unclosed ClientSession" warning is expected here, however.
"""
# Timeout to avoid hanging if the test fails
async with timeout(5.0):
session = await self._init_session()
await session.get(httpbin('get'))
del session

session = await self._init_session(clear=False)
r = await session.get(httpbin('get'))
assert r.from_cache is True

async def test_request__expire_after(self):
async with self.init_session() as session:
await session.get(httpbin('get'), expire_after=1)
Expand Down
8 changes: 8 additions & 0 deletions test/integration/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ async def test_content_reset(self):
content_2 = await cached_response_2.read()
assert content_1 == content_2 == original_content

async def test_without_contextmanager(self):
"""Test that the cache backend can be safely used without the CachedSession contextmanager.
An "unclosed ClientSession" warning is expected here, however.
"""
session = await self._init_session()
await session.get(httpbin('get'))
del session

# Serialization tests don't apply to in-memory cache
async def test_serializer__pickle(self):
pass
Expand Down
24 changes: 15 additions & 9 deletions test/integration/test_sqlite.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import os
from contextlib import asynccontextmanager
from tempfile import gettempdir
from unittest.mock import MagicMock, patch

Expand Down Expand Up @@ -53,17 +54,22 @@ async def test_concurrent_bulk_commit(self, mock_sqlite):
mock_connection = AsyncMock()
mock_sqlite.connect = AsyncMock(return_value=mock_connection)

async with self.init_cache() as cache:
@asynccontextmanager
async def bulk_commit_ctx():
async with self.init_cache() as cache:

async def bulk_commit_items(n_items):
async with cache.bulk_commit():
for i in range(n_items):
await cache.write(f'key_{n_items}_{i}', f'value_{i}')

async def bulk_commit_items(n_items):
async with cache.bulk_commit():
for i in range(n_items):
await cache.write(f'key_{n_items}_{i}', f'value_{i}')
yield bulk_commit_items

assert mock_connection.commit.call_count == 1
tasks = [asyncio.create_task(bulk_commit_items(n)) for n in [10, 100, 1000, 10000]]
await asyncio.gather(*tasks)
assert mock_connection.commit.call_count == 5
async with bulk_commit_ctx() as bulk_commit_items:
assert mock_connection.commit.call_count == 1
tasks = [asyncio.create_task(bulk_commit_items(n)) for n in [10, 100, 1000, 10000]]
await asyncio.gather(*tasks)
assert mock_connection.commit.call_count == 5

async def test_fast_save(self):
async with self.init_cache(index=1, fast_save=True) as cache_1, self.init_cache(
Expand Down
21 changes: 8 additions & 13 deletions test/unit/test_base_backend.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
import pickle
from sys import version_info
from unittest.mock import MagicMock, patch

import pytest

from aiohttp_client_cache import CachedResponse
from aiohttp_client_cache.backends import CacheBackend, DictCache, get_placeholder_backend
from test.conftest import skip_37

TEST_URL = 'https://test.com'

pytestmark = pytest.mark.asyncio
skip_py37 = pytest.mark.skipif(
version_info < (3, 8), reason='Test requires AsyncMock from python 3.8+'
)


def get_mock_response(**kwargs):
response_kwargs = {
Expand Down Expand Up @@ -71,7 +66,7 @@ async def test_get_response__cache_miss(mock_delete):
mock_delete.assert_not_called()


@skip_py37
@skip_37
@patch.object(CacheBackend, 'delete')
@patch.object(CacheBackend, 'is_cacheable', return_value=False)
async def test_get_response__cache_expired(mock_is_cacheable, mock_delete):
Expand All @@ -84,7 +79,7 @@ async def test_get_response__cache_expired(mock_is_cacheable, mock_delete):
mock_delete.assert_called_with('request-key')


@skip_py37
@skip_37
@pytest.mark.parametrize('error_type', [AttributeError, KeyError, TypeError, pickle.PickleError])
@patch.object(CacheBackend, 'delete')
@patch.object(DictCache, 'read')
Expand All @@ -99,7 +94,7 @@ async def test_get_response__cache_invalid(mock_read, mock_delete, error_type):
mock_delete.assert_not_called()


@skip_py37
@skip_37
@patch.object(DictCache, 'read', return_value=object())
async def test_get_response__quiet_serde_error(mock_read):
"""Test for a quiet deserialization error in which no errors are raised but attributes are
Expand All @@ -113,7 +108,7 @@ async def test_get_response__quiet_serde_error(mock_read):
assert response is None


@skip_py37
@skip_37
async def test_save_response():
cache = CacheBackend()
mock_response = get_mock_response()
Expand All @@ -126,7 +121,7 @@ async def test_save_response():
assert await cache.redirects.read(redirect_key) == 'key'


@skip_py37
@skip_37
async def test_save_response__manual_save():
"""Manually save a response with no cache key provided"""
cache = CacheBackend()
Expand Down Expand Up @@ -193,7 +188,7 @@ async def test_has_url():
assert not await cache.has_url('https://test.com/some_other_path')


@skip_py37
@skip_37
@patch('aiohttp_client_cache.backends.base.create_key')
async def test_create_key(mock_create_key):
"""Actual logic is in cache_keys module; just test to make sure it gets called correctly"""
Expand Down Expand Up @@ -244,7 +239,7 @@ async def test_is_cacheable(method, status, disabled, expired, filter_return, ex
assert await cache.is_cacheable(mock_response) is expected_result


@skip_py37
@skip_37
@pytest.mark.parametrize(
'method, status, disabled, expired, body, expected_result',
[
Expand Down

0 comments on commit f1cce5a

Please sign in to comment.