diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 977158cd0..3bffc85d6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -41,6 +41,7 @@ jobs: path: tests/requirements.txt - name: Run mypy run: | + pip install -U setuptools pip install -r tests/requirements-mypy.txt mypy - name: Run linter diff --git a/CHANGES/1256.feature b/CHANGES/1256.feature new file mode 100644 index 000000000..1ebf626c7 --- /dev/null +++ b/CHANGES/1256.feature @@ -0,0 +1 @@ +Add auto_close_connection_pool for Redis-created connection pools, not manually created pools diff --git a/aioredis/client.py b/aioredis/client.py index 3b5e091d5..a284ec88f 100644 --- a/aioredis/client.py +++ b/aioredis/client.py @@ -858,8 +858,16 @@ def __init__( health_check_interval: int = 0, client_name: Optional[str] = None, username: Optional[str] = None, + auto_close_connection_pool: bool = True, ): kwargs: Dict[str, Any] + # auto_close_connection_pool only has an effect if connection_pool is + # None. This is a similar feature to the missing __del__ to resolve #1103, + # but it accounts for whether a user wants to manually close the connection + # pool, as a similar feature to ConnectionPool's __del__. + self.auto_close_connection_pool = ( + auto_close_connection_pool if connection_pool is None else False + ) if not connection_pool: kwargs = { "db": db, @@ -1067,11 +1075,22 @@ def __del__(self, _warnings: Any = warnings) -> None: context = {"client": self, "message": self._DEL_MESSAGE} asyncio.get_event_loop().call_exception_handler(context) - async def close(self): + async def close(self, close_connection_pool: Optional[bool] = None) -> None: + """ + Closes Redis client connection + + :param close_connection_pool: decides whether to close the connection pool used + by this Redis client, overriding Redis.auto_close_connection_pool. By default, + let Redis.auto_close_connection_pool decide whether to close the connection pool. + """ conn = self.connection if conn: self.connection = None await self.connection_pool.release(conn) + if close_connection_pool or ( + close_connection_pool is None and self.auto_close_connection_pool + ): + await self.connection_pool.disconnect() # COMMAND EXECUTION AND PROTOCOL PARSING async def execute_command(self, *args, **options): diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 1931e2048..75f89a354 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -14,6 +14,83 @@ pytestmark = pytest.mark.asyncio +class TestRedisAutoReleaseConnectionPool: + @pytest.fixture + async def r(self, create_redis) -> aioredis.Redis: + """This is necessary since r and r2 create ConnectionPools behind the scenes""" + r = await create_redis() + r.auto_close_connection_pool = True + yield r + + @staticmethod + def get_total_connected_connections(pool): + return len(pool._available_connections) + len(pool._in_use_connections) + + @staticmethod + async def create_two_conn(r: aioredis.Redis): + if not r.single_connection_client: # Single already initialized connection + r.connection = await r.connection_pool.get_connection("_") + return await r.connection_pool.get_connection("_") + + @staticmethod + def has_no_connected_connections(pool: aioredis.ConnectionPool): + return not any( + x.is_connected + for x in pool._available_connections + list(pool._in_use_connections) + ) + + async def test_auto_disconnect_redis_created_pool(self, r: aioredis.Redis): + new_conn = await self.create_two_conn(r) + assert new_conn != r.connection + assert self.get_total_connected_connections(r.connection_pool) == 2 + await r.close() + assert self.has_no_connected_connections(r.connection_pool) + + async def test_do_not_auto_disconnect_redis_created_pool(self, r2: aioredis.Redis): + assert r2.auto_close_connection_pool is False, ( + "The connection pool should not be disconnected as a manually created " + "connection pool was passed in in conftest.py" + ) + new_conn = await self.create_two_conn(r2) + assert self.get_total_connected_connections(r2.connection_pool) == 2 + await r2.close() + assert r2.connection_pool._in_use_connections == {new_conn} + assert new_conn.is_connected + assert len(r2.connection_pool._available_connections) == 1 + assert r2.connection_pool._available_connections[0].is_connected + + async def test_auto_release_override_true_manual_created_pool( + self, r: aioredis.Redis + ): + assert r.auto_close_connection_pool is True, "This is from the class fixture" + await self.create_two_conn(r) + await r.close() + assert self.get_total_connected_connections(r.connection_pool) == 2, ( + "The connection pool should not be disconnected as a manually created " + "connection pool was passed in in conftest.py" + ) + assert self.has_no_connected_connections(r.connection_pool) + + @pytest.mark.parametrize("auto_close_conn_pool", [True, False]) + async def test_close_override(self, r: aioredis.Redis, auto_close_conn_pool): + r.auto_close_connection_pool = auto_close_conn_pool + await self.create_two_conn(r) + await r.close(close_connection_pool=True) + assert self.has_no_connected_connections(r.connection_pool) + + @pytest.mark.parametrize("auto_close_conn_pool", [True, False]) + async def test_negate_auto_close_client_pool( + self, r: aioredis.Redis, auto_close_conn_pool + ): + r.auto_close_connection_pool = auto_close_conn_pool + new_conn = await self.create_two_conn(r) + await r.close(close_connection_pool=False) + assert not self.has_no_connected_connections(r.connection_pool) + assert r.connection_pool._in_use_connections == {new_conn} + assert r.connection_pool._available_connections[0].is_connected + assert self.get_total_connected_connections(r.connection_pool) == 2 + + class DummyConnection(Connection): description_format = "DummyConnection<>"