diff --git a/fakeredis/aioredis.py b/fakeredis/aioredis.py index cdebce0e..36765eed 100644 --- a/fakeredis/aioredis.py +++ b/fakeredis/aioredis.py @@ -8,11 +8,6 @@ from ._server import FakeBaseConnectionMixin -if sys.version_info >= (3, 8): - from typing import Type, TypedDict -else: - from typing_extensions import Type, TypedDict - if sys.version_info >= (3, 11): from asyncio import timeout as async_timeout else: @@ -171,22 +166,6 @@ def repr_pieces(self): return pieces -class ConnectionKwargs(TypedDict, total=False): - db: Union[str, int] - username: Optional[str] - password: Optional[str] - socket_timeout: Optional[float] - encoding: str - encoding_errors: str - decode_responses: bool - retry_on_timeout: bool - health_check_interval: int - client_name: Optional[str] - server: Optional[_server.FakeServer] - connection_class: Type[redis_async.Connection] - max_connections: Optional[int] - - class FakeRedis(redis_async.Redis): def __init__( self, @@ -205,27 +184,29 @@ def __init__( username: Optional[str] = None, server: Optional[_server.FakeServer] = None, connected: bool = True, + version=(7,), **kwargs, ): if not connection_pool: # Adapted from aioredis - connection_kwargs: ConnectionKwargs = { - "db": db, + connection_kwargs = dict( + db=db, # Ignoring because AUTH is not implemented # 'username', # 'password', - "socket_timeout": socket_timeout, - "encoding": encoding, - "encoding_errors": encoding_errors, - "decode_responses": decode_responses, - "retry_on_timeout": retry_on_timeout, - "health_check_interval": health_check_interval, - "client_name": client_name, - "server": server, - "connected": connected, - "connection_class": FakeConnection, - "max_connections": max_connections, - } + socket_timeout=socket_timeout, + encoding=encoding, + encoding_errors=encoding_errors, + decode_responses=decode_responses, + retry_on_timeout=retry_on_timeout, + health_check_interval=health_check_interval, + client_name=client_name, + server=server, + connected=connected, + connection_class=FakeConnection, + max_connections=max_connections, + version=version, + ) connection_pool = redis_async.ConnectionPool(**connection_kwargs) super().__init__( db=db, diff --git a/test/test_redis_asyncio.py b/test/test_redis_asyncio.py index d4ec22d6..6a8aa6f9 100644 --- a/test/test_redis_asyncio.py +++ b/test/test_redis_asyncio.py @@ -274,6 +274,19 @@ async def test_from_url(): await r1.connection_pool.disconnect() +@pytest.mark.fake +async def test_from_url_with_version(): + r0 = aioredis.FakeRedis.from_url('redis://localhost?db=0', version=(6,)) + r1 = aioredis.FakeRedis.from_url('redis://localhost?db=1', version=(6,)) + # Check that they are indeed different databases + await r0.set('foo', 'a') + await r1.set('foo', 'b') + assert await r0.get('foo') == b'a' + assert await r1.get('foo') == b'b' + await r0.connection_pool.disconnect() + await r1.connection_pool.disconnect() + + @fake_only async def test_from_url_with_server(req_aioredis2, fake_server): r2 = aioredis.FakeRedis.from_url('redis://localhost', server=fake_server)