diff --git a/juju/client/connection.py b/juju/client/connection.py index ea285a14c..03eb50961 100644 --- a/juju/client/connection.py +++ b/juju/client/connection.py @@ -17,7 +17,6 @@ import websockets from dateutil.parser import parse from typing_extensions import Self, TypeAlias, overload -from websockets.protocol import State from juju import errors, jasyncio, tag, utils from juju.client import client @@ -93,7 +92,7 @@ def status(self): and connection._receiver_task.cancelled() ) - if stopped or connection._ws.state is not State.OPEN: + if stopped or connection._ws.state is not websockets.protocol.State.OPEN: return self.ERROR # everything is fine! diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index a7cc808cb..9ed876466 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -7,6 +7,7 @@ from unittest import mock import pytest +import websockets from websockets.exceptions import ConnectionClosed from juju.client.connection import Connection @@ -17,8 +18,7 @@ class WebsocketMock: def __init__(self, responses): super().__init__() self.responses = deque(responses) - self.open = True - self.closed = False + self.state = websockets.protocol.State.OPEN async def send(self, message): pass @@ -30,8 +30,7 @@ async def recv(self): return json.dumps(self.responses.popleft()) async def close(self): - self.open = False - self.closed = True + pass async def test_out_of_order():