Skip to content

Commit

Permalink
Fix tying for mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
kristjanvalur committed Oct 5, 2023
1 parent a9d8698 commit 55f43c9
Showing 1 changed file with 23 additions and 23 deletions.
46 changes: 23 additions & 23 deletions tests/resp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import itertools
from contextlib import closing
from typing import Any, Generator, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union

CRNL = b"\r\n"

Expand Down Expand Up @@ -34,11 +34,11 @@ def __init__(self, code: str, value: str) -> None:
def __repr__(self) -> str:
return f"ErrorString({self.code!r}, {super().__repr__()})"

def __str__(self):
def __str__(self) -> str:
return f"{self.code} {super().__str__()}"


class PushData(list):
class PushData(List[Any]):
"""
A special type of list indicating data from a push response
"""
Expand All @@ -47,7 +47,7 @@ def __repr__(self) -> str:
return f"PushData({super().__repr__()})"


class Attribute(dict):
class Attribute(Dict[Any, Any]):
"""
A special type of map indicating data from a attribute response
"""
Expand All @@ -62,7 +62,7 @@ class RespEncoder:
"""

def __init__(
self, protocol: int = 2, encoding: str = "utf-8", errorhander="strict"
self, protocol: int = 2, encoding: str = "utf-8", errorhander: str = "strict"
) -> None:
self.protocol = protocol
self.encoding = encoding
Expand Down Expand Up @@ -248,7 +248,7 @@ def parse(
rest += incoming
string = self.decode_bytes(rest[: (count + 4)])
if string[3] != ":":
raise ValueError(f"Expected colon after hint, got {bulkstr[3]}")
raise ValueError(f"Expected colon after hint, got {string[3]}")
hint = string[:3]
string = string[4 : (count + 4)]
yield VerbatimStr(string, hint), rest[expect:]
Expand Down Expand Up @@ -310,8 +310,8 @@ def parse(
# we decode them automatically
decoded = self.decode_bytes(arg)
assert isinstance(decoded, str)
code, value = decoded.split(" ", 1)
yield ErrorStr(code, value), rest
err, value = decoded.split(" ", 1)
yield ErrorStr(err, value), rest

elif code == b"!": # resp3 error
count = int(arg)
Expand All @@ -323,8 +323,8 @@ def parse(
bulkstr = rest[:count]
decoded = self.decode_bytes(bulkstr)
assert isinstance(decoded, str)
code, value = decoded.split(" ", 1)
yield ErrorStr(code, value), rest[expect:]
err, value = decoded.split(" ", 1)
yield ErrorStr(err, value), rest[expect:]

else:
raise ValueError(f"Unknown opcode '{code.decode()}'")
Expand Down Expand Up @@ -427,26 +427,26 @@ class RespServer:
Accepts RESP commands and returns RESP responses.
"""

handlers = {}
handlers: Dict[str, Callable[..., Any]] = {}

def __init__(self):
def __init__(self) -> None:
self.protocol = 2
self.server_ver = self.get_server_version()
self.auth = []
self.auth: List[Any] = []
self.client_name = ""

# patchable methods for testing

def get_server_version(self):
def get_server_version(self) -> int:
return 6

def on_auth(self, auth):
def on_auth(self, auth: List[Any]) -> None:
pass

def on_setname(self, name):
def on_setname(self, name: str) -> None:
pass

def on_protocol(self, proto):
def on_protocol(self, proto: int) -> None:
pass

def command(self, cmd: Any) -> bytes:
Expand All @@ -466,7 +466,7 @@ def _command(self, cmd: Any) -> Any:

return ErrorStr("ERR", "unknown command {cmd!r}")

def handle_auth(self, args):
def handle_auth(self, args: List[Any]) -> Union[str, ErrorStr]:
self.auth = args[:]
self.on_auth(self.auth)
expect = 2 if self.server_ver >= 6 else 1
Expand All @@ -476,21 +476,21 @@ def handle_auth(self, args):

handlers["AUTH"] = handle_auth

def handle_client(self, args):
def handle_client(self, args: List[Any]) -> Union[str, ErrorStr]:
if args[0] == "SETNAME":
return self.handle_setname(args[1:])
return ErrorStr("ERR", "unknown subcommand or wrong number of arguments")

handlers["CLIENT"] = handle_client

def handle_setname(self, args):
def handle_setname(self, args: List[Any]) -> Union[str, ErrorStr]:
if len(args) != 1:
return ErrorStr("ERR", "wrong number of arguments")
self.client_name = args[0]
self.on_setname(self.client_name)
return "OK"

def handle_hello(self, args):
def handle_hello(self, args: List[Any]) -> Union[ErrorStr, Dict[str, Any]]:
if self.server_ver < 6:
return ErrorStr("ERR", "unknown command 'HELLO'")
proto = self.protocol
Expand All @@ -507,14 +507,14 @@ def handle_hello(self, args):
auth_args = args[:2]
args = args[2:]
res = self.handle_auth(auth_args)
if res != "OK":
if isinstance(res, ErrorStr):
return res
continue
if cmd == "SETNAME":
setname_args = args[:1]
args = args[1:]
res = self.handle_setname(setname_args)
if res != "OK":
if isinstance(res, ErrorStr):
return res
continue
return ErrorStr("ERR", "unknown subcommand or wrong number of arguments")
Expand Down

0 comments on commit 55f43c9

Please sign in to comment.