From 2981ce30b8c9a1c3d91dda8977d64ee3488d6ce4 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 30 Jan 2023 17:55:55 +0100 Subject: [PATCH] Don't try to close the aiohttp session if connector_owner is False (#382) --- gql/transport/aiohttp.py | 21 +++++++++++++------ tests/test_aiohttp.py | 44 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 6 deletions(-) diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 2b155870..6dc0a409 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -183,12 +183,21 @@ async def close(self) -> None: log.debug("Closing transport") - closed_event = self.create_aiohttp_closed_event(self.session) - await self.session.close() - try: - await asyncio.wait_for(closed_event.wait(), self.ssl_close_timeout) - except asyncio.TimeoutError: - pass + if ( + self.client_session_args + and self.client_session_args.get("connector_owner") is False + ): + + log.debug("connector_owner is False -> not closing connector") + + else: + closed_event = self.create_aiohttp_closed_event(self.session) + await self.session.close() + try: + await asyncio.wait_for(closed_event.wait(), self.ssl_close_timeout) + except asyncio.TimeoutError: + pass + self.session = None async def execute( diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 9a62a65c..27af1438 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -1441,3 +1441,47 @@ async def handler(request): # Checking that there is no space after the colon in the log expected_log = '"query":"query getContinents' assert expected_log in caplog.text + + +@pytest.mark.asyncio +async def test_aiohttp_connector_owner_false(event_loop, aiohttp_server): + from aiohttp import web, TCPConnector + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + connector = TCPConnector() + transport = AIOHTTPTransport( + url=url, + timeout=10, + client_session_args={ + "connector": connector, + "connector_owner": False, + }, + ) + + for _ in range(2): + async with Client(transport=transport) as session: + + query = gql(query1_str) + + # Execute query asynchronously + result = await session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + await connector.close()