From a4feb3c9733bcac04e63dc485ccc12b98af0824e Mon Sep 17 00:00:00 2001 From: Chayim Date: Wed, 15 Mar 2023 11:35:20 +0200 Subject: [PATCH] Speeding up the protocol parsing (#2596) * speeding up the protocol parser * linting * changes to ease --- redis/asyncio/connection.py | 25 ++++++++++++------------- redis/connection.py | 24 +++++++++++------------- 2 files changed, 23 insertions(+), 26 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index e77fba30da..056998e9e0 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -267,9 +267,6 @@ async def _read_response( response: Any byte, response = raw[:1], raw[1:] - if byte not in (b"-", b"+", b":", b"$", b"*"): - raise InvalidResponse(f"Protocol Error: {raw!r}") - # server returned an error if byte == b"-": response = response.decode("utf-8", errors="replace") @@ -289,22 +286,24 @@ async def _read_response( pass # int value elif byte == b":": - response = int(response) + return int(response) # bulk response + elif byte == b"$" and response == b"-1": + return None elif byte == b"$": - length = int(response) - if length == -1: - return None - response = await self._read(length) + response = await self._read(int(response)) # multi-bulk response + elif byte == b"*" and response == b"-1": + return None elif byte == b"*": - length = int(response) - if length == -1: - return None response = [ - (await self._read_response(disable_decoding)) for _ in range(length) + (await self._read_response(disable_decoding)) + for _ in range(int(response)) # noqa ] - if isinstance(response, bytes) and disable_decoding is False: + else: + raise InvalidResponse(f"Protocol Error: {raw!r}") + + if disable_decoding is False: response = self.encoder.decode(response) return response diff --git a/redis/connection.py b/redis/connection.py index d35980c167..c4a9685f6a 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -358,9 +358,6 @@ def _read_response(self, disable_decoding=False): byte, response = raw[:1], raw[1:] - if byte not in (b"-", b"+", b":", b"$", b"*"): - raise InvalidResponse(f"Protocol Error: {raw!r}") - # server returned an error if byte == b"-": response = response.decode("utf-8", errors="replace") @@ -379,23 +376,24 @@ def _read_response(self, disable_decoding=False): pass # int value elif byte == b":": - response = int(response) + return int(response) # bulk response + elif byte == b"$" and response == b"-1": + return None elif byte == b"$": - length = int(response) - if length == -1: - return None - response = self._buffer.read(length) + response = self._buffer.read(int(response)) # multi-bulk response + elif byte == b"*" and response == b"-1": + return None elif byte == b"*": - length = int(response) - if length == -1: - return None response = [ self._read_response(disable_decoding=disable_decoding) - for i in range(length) + for i in range(int(response)) ] - if isinstance(response, bytes) and disable_decoding is False: + else: + raise InvalidResponse(f"Protocol Error: {raw!r}") + + if disable_decoding is False: response = self.encoder.decode(response) return response