Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix behaviour of async PythonParser to match RedisParser as for issue #2349 #2582

Merged
merged 5 commits into from
Mar 16, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
* Allow data to drain from async PythonParser when reading during a disconnect()
* Add test and fix async HiredisParser when reading during a disconnect() (#2349)
* Use hiredis-py pack_command if available.
* Support `.unlink()` in ClusterPipeline
Expand Down
24 changes: 11 additions & 13 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def decode(self, value: EncodableT, force=False) -> EncodableT:
class BaseParser:
"""Plain Python parsing class"""

__slots__ = "_stream", "_read_size"
__slots__ = "_stream", "_read_size", "_connected"

EXCEPTION_CLASSES: ExceptionMappingT = {
"ERR": {
Expand Down Expand Up @@ -172,6 +172,7 @@ class BaseParser:
def __init__(self, socket_read_size: int):
self._stream: Optional[asyncio.StreamReader] = None
self._read_size = socket_read_size
self._connected = False

def __del__(self):
try:
Expand Down Expand Up @@ -208,7 +209,7 @@ async def read_response(
class PythonParser(BaseParser):
"""Plain Python parsing class"""

__slots__ = BaseParser.__slots__ + ("encoder", "_buffer", "_pos", "_chunks")
__slots__ = ("encoder", "_buffer", "_pos", "_chunks")

def __init__(self, socket_read_size: int):
super().__init__(socket_read_size)
Expand All @@ -226,28 +227,28 @@ def on_connect(self, connection: "Connection"):
self._stream = connection._reader
if self._stream is None:
raise RedisError("Buffer is closed.")

self.encoder = connection.encoder
self._clear()
self._connected = True

def on_disconnect(self):
"""Called when the stream disconnects"""
if self._stream is not None:
self._stream = None
self.encoder = None
self._clear()
self._connected = False

async def can_read_destructive(self) -> bool:
if not self._connected:
raise RedisError("Buffer is closed.")
if self._buffer:
return True
if self._stream is None:
raise RedisError("Buffer is closed.")
try:
async with async_timeout.timeout(0):
return await self._stream.read(1)
except asyncio.TimeoutError:
return False

async def read_response(self, disable_decoding: bool = False):
if not self._connected:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
if self._chunks:
# augment parsing buffer with previously read data
self._buffer += b"".join(self._chunks)
Expand All @@ -261,8 +262,6 @@ async def read_response(self, disable_decoding: bool = False):
async def _read_response(
self, disable_decoding: bool = False
) -> Union[EncodableT, ResponseError, None]:
if not self._stream or not self.encoder:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
raw = await self._readline()
response: Any
byte, response = raw[:1], raw[1:]
Expand Down Expand Up @@ -350,14 +349,13 @@ async def _readline(self) -> bytes:
class HiredisParser(BaseParser):
"""Parser class for connections using Hiredis"""

__slots__ = BaseParser.__slots__ + ("_reader", "_connected")
__slots__ = ("_reader",)

def __init__(self, socket_read_size: int):
if not HIREDIS_AVAILABLE:
raise RedisError("Hiredis is not available.")
super().__init__(socket_read_size=socket_read_size)
self._reader: Optional[hiredis.Reader] = None
self._connected: bool = False

def on_connect(self, connection: "Connection"):
self._stream = connection._reader
Expand Down
2 changes: 0 additions & 2 deletions tests/test_asyncio/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,6 @@ async def test_connection_disconect_race(parser_class):
This test verifies that a read in progress can finish even
if the `disconnect()` method is called.
"""
if parser_class == PythonParser:
pytest.xfail("doesn't work yet with PythonParser")
if parser_class == HiredisParser and not HIREDIS_AVAILABLE:
pytest.skip("Hiredis not available")

Expand Down