Skip to content

Commit

Permalink
remove extra await when calling ws_connect()
Browse files Browse the repository at this point in the history
The ClientSession.ws_connect() method is synchronous and returns a
_RequestContextManager which takes a coroutine as parameter (here,
ClientSession._ws_connect()).

This context manager is in charge of closing the connection in its __aexit__()
method, so it has to be used with "async with".

However, this context manager can also be awaited as it has an __await__()
method. In this case, it will await the _ws_connect() coroutine. This is what is
done in the current code, but the connection will not be released.

Remove the "await" to return the context manager, so that the user can use it
with "async with", which will properly release resources.

This is the documented way of using ws_connect():
https://docs.aiohttp.org/en/stable/client_quickstart.html#websockets

Signed-off-by: Olivier Matz <[email protected]>
  • Loading branch information
olivier-matz-6wind committed Aug 9, 2024
1 parent 9cb98d0 commit c3a3128
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 28 deletions.
54 changes: 27 additions & 27 deletions examples/pod_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ async def main():
async with WsApiClient() as ws_api:
v1_ws = client.CoreV1Api(api_client=ws_api)
exec_command = ['/bin/sh']
ws = await v1_ws.connect_get_namespaced_pod_exec(
websocket = await v1_ws.connect_get_namespaced_pod_exec(
BUSYBOX_POD,
"default",
command=exec_command,
Expand All @@ -116,32 +116,32 @@ async def main():
]
error_data = ""
closed = False
while commands and not closed:
command = commands.pop(0)
stdin_channel_prefix = chr(0)
await ws.send_bytes((stdin_channel_prefix + command).encode("utf-8"))
while True:
try:
msg = await ws.receive(timeout=1)
except asyncio.TimeoutError:
break
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
closed = True
break
channel = msg.data[0]
data = msg.data[1:].decode("utf-8")
if not data:
continue
if channel == STDOUT_CHANNEL:
print(f"stdout: {data}")
elif channel == STDERR_CHANNEL:
print(f"stderr: {data}")
elif channel == ERROR_CHANNEL:
error_data += data
if error_data:
returncode = ws_api.parse_error_data(error_data)
print(f"Exit code: {returncode}")
await ws.close()
async with websocket as ws:
while commands and not closed:
command = commands.pop(0)
stdin_channel_prefix = chr(0)
await ws.send_bytes((stdin_channel_prefix + command).encode("utf-8"))
while True:
try:
msg = await ws.receive(timeout=1)
except asyncio.TimeoutError:
break
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
closed = True
break
channel = msg.data[0]
data = msg.data[1:].decode("utf-8")
if not data:
continue
if channel == STDOUT_CHANNEL:
print(f"stdout: {data}")
elif channel == STDERR_CHANNEL:
print(f"stderr: {data}")
elif channel == ERROR_CHANNEL:
error_data += data
if error_data:
returncode = ws_api.parse_error_data(error_data)
print(f"Exit code: {returncode}")

if __name__ == "__main__":
loop = asyncio.get_event_loop()
Expand Down
2 changes: 1 addition & 1 deletion kubernetes_asyncio/stream/ws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,4 @@ async def request(self, method, url, query_params=None, headers=None,

else:

return await self.rest_client.pool_manager.ws_connect(url, headers=headers, heartbeat=self.heartbeat)
return self.rest_client.pool_manager.ws_connect(url, headers=headers, heartbeat=self.heartbeat)

Check warning on line 111 in kubernetes_asyncio/stream/ws_client.py

View check run for this annotation

Codecov / codecov/patch

kubernetes_asyncio/stream/ws_client.py#L111

Added line #L111 was not covered by tests

0 comments on commit c3a3128

Please sign in to comment.