Skip to content

Commit

Permalink
remove extra await when calling ws_connect()
Browse files Browse the repository at this point in the history
The ClientSession.ws_connect() method is synchronous and returns a
_RequestContextManager which takes a coroutine as parameter (here,
ClientSession._ws_connect()).

This context manager is in charge of closing the connection in its __aexit__()
method, so it has to be used with "async with".

However, this context manager can also be awaited as it has an __await__()
method. In this case, it will await the _ws_connect() coroutine. This is what is
done in the current code, but the connection will not be released.

Remove the "await" to return the context manager, so that the user can use it
with "async with", which will properly release resources.

This is the documented way of using ws_connect():
https://docs.aiohttp.org/en/stable/client_quickstart.html#websockets

Signed-off-by: Olivier Matz <[email protected]>
  • Loading branch information
olivier-matz-6wind committed Aug 7, 2024
1 parent 6ff3ed2 commit d7ec65d
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 28 deletions.
54 changes: 27 additions & 27 deletions examples/pod_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ async def main():
async with WsApiClient() as ws_api:
v1_ws = client.CoreV1Api(api_client=ws_api)
exec_command = ['/bin/sh']
ws = await v1_ws.connect_get_namespaced_pod_exec(
websocket = await v1_ws.connect_get_namespaced_pod_exec(
BUSYBOX_POD,
"default",
command=exec_command,
Expand All @@ -117,32 +117,32 @@ async def main():
]
error_data = ""
closed = False
while commands and not closed:
command = commands.pop(0)
stdin_channel_prefix = chr(0)
await ws.send_bytes((stdin_channel_prefix + command).encode("utf-8"))
while True:
try:
msg = await ws.receive(timeout=1)
except asyncio.TimeoutError:
break
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
closed = True
break
channel = msg.data[0]
data = msg.data[1:].decode("utf-8")
if not data:
continue
if channel == STDOUT_CHANNEL:
print(f"stdout: {data}")
elif channel == STDERR_CHANNEL:
print(f"stderr: {data}")
elif channel == ERROR_CHANNEL:
error_data += data
if error_data:
returncode = ws_api.parse_error_data(error_data)
print(f"Exit code: {returncode}")
await ws.close()
async with websocket as ws:
while commands and not closed:
command = commands.pop(0)
stdin_channel_prefix = chr(0)
await ws.send_bytes((stdin_channel_prefix + command).encode("utf-8"))
while True:
try:
msg = await ws.receive(timeout=1)
except asyncio.TimeoutError:
break
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
closed = True
break
channel = msg.data[0]
data = msg.data[1:].decode("utf-8")
if not data:
continue
if channel == STDOUT_CHANNEL:
print(f"stdout: {data}")
elif channel == STDERR_CHANNEL:
print(f"stderr: {data}")
elif channel == ERROR_CHANNEL:
error_data += data
if error_data:
returncode = ws_api.parse_error_data(error_data)
print(f"Exit code: {returncode}")

if __name__ == "__main__":
loop = asyncio.get_event_loop()
Expand Down
2 changes: 1 addition & 1 deletion kubernetes_asyncio/stream/ws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,4 +117,4 @@ async def request(self, method, url, query_params=None, headers=None,

else:

return await self.rest_client.pool_manager.ws_connect(url, headers=headers, heartbeat=self.heartbeat)
return self.rest_client.pool_manager.ws_connect(url, headers=headers, heartbeat=self.heartbeat)

Check warning on line 120 in kubernetes_asyncio/stream/ws_client.py

View check run for this annotation

Codecov / codecov/patch

kubernetes_asyncio/stream/ws_client.py#L120

Added line #L120 was not covered by tests

0 comments on commit d7ec65d

Please sign in to comment.