diff --git a/test/integration/test_sqlite.py b/test/integration/test_sqlite.py index 4badcb1..d86315d 100644 --- a/test/integration/test_sqlite.py +++ b/test/integration/test_sqlite.py @@ -1,5 +1,6 @@ import asyncio import os +from contextlib import asynccontextmanager from tempfile import gettempdir from unittest.mock import MagicMock, patch @@ -53,17 +54,22 @@ async def test_concurrent_bulk_commit(self, mock_sqlite): mock_connection = AsyncMock() mock_sqlite.connect = AsyncMock(return_value=mock_connection) - async with self.init_cache() as cache: + @asynccontextmanager + async def bulk_commit_ctx(): + async with self.init_cache() as cache: + + async def bulk_commit_items(n_items): + async with cache.bulk_commit(): + for i in range(n_items): + await cache.write(f'key_{n_items}_{i}', f'value_{i}') - async def bulk_commit_items(n_items): - async with cache.bulk_commit(): - for i in range(n_items): - await cache.write(f'key_{n_items}_{i}', f'value_{i}') + yield bulk_commit_items - assert mock_connection.commit.call_count == 1 - tasks = [asyncio.create_task(bulk_commit_items(n)) for n in [10, 100, 1000, 10000]] - await asyncio.gather(*tasks) - assert mock_connection.commit.call_count == 5 + async with bulk_commit_ctx() as bulk_commit_items: + assert mock_connection.commit.call_count == 1 + tasks = [asyncio.create_task(bulk_commit_items(n)) for n in [10, 100, 1000, 10000]] + await asyncio.gather(*tasks) + assert mock_connection.commit.call_count == 5 async def test_fast_save(self): async with self.init_cache(index=1, fast_save=True) as cache_1, self.init_cache(