Skip to content

Commit

Permalink
Tcp server (#331)
Browse files Browse the repository at this point in the history
  • Loading branch information
cunla authored Sep 27, 2024
1 parent af4170f commit a279dd0
Show file tree
Hide file tree
Showing 26 changed files with 377 additions and 99 deletions.
9 changes: 8 additions & 1 deletion docs/about/changelog.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
---
title: Change log
description: Change log of all fakeredis releases
description: Changelog of all fakeredis releases
tags:
- changelog
- release-notes
toc_depth: 2
---

## v2.25.0
Expand All @@ -10,6 +14,9 @@ description: Change log of all fakeredis releases
- Implement support for hash expiration related commands @j00bar #328
- `HEXPIRE`, `HEXPIREAT`, `HEXPIRETIME`, `HPERSIST`, `HPEXPIRE`, `HPEXPIREAT`, `HPEXPIRETIME`, `HPTTL`, `HTTL`,
- Implement support for `SORT_RO` #325, `EXPIRETIME` #323, and `PEXPIRETIME` #324
- Support for creating a tcp server listening to multiple clients
- Testing against valkey 8.0 #333
- Improve documentation #332

## v2.24.1

Expand Down
36 changes: 35 additions & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
## fakeredis: A python implementation of redis server
---
toc:
toc_depth: 3
---

fakeredis: A python implementation of redis server
=================================================


FakeRedis is a pure-Python implementation of the Redis key-value store.

Expand Down Expand Up @@ -30,6 +37,27 @@ pip install fakeredis[probabilistic,json] ## Support for RedisJSON and BloomFil

## How to Use

### Start a server on a thread

It is possible to start a server on a thread and use it as a connect to it as you would a real redis server.

```python
from threading import Thread
from fakeredis import TcpFakeServer

server_address = ("127.0.0.1", 6379)
server = TcpFakeServer(server_address)
t = Thread(target=server.serve_forever, daemon=True)
t.start()

import redis

r = redis.Redis(host=server_address[0], port=server_address[1])
r.set("foo", "bar")
assert r.get("foo") == b"bar"

```

### Use as a pytest fixture

```python
Expand Down Expand Up @@ -196,11 +224,13 @@ from fastapi import Depends, FastAPI

app = FastAPI()


async def get_redis() -> AsyncIterator[redis.Redis]:
# Code to handle creating a redis connection goes here, for example
async with redis.from_url("redis://localhost:6379") as client: # type: ignore[no-untyped-call]
yield client


@app.get("/")
async def root(redis_client: Annotated[redis.Redis, Depends(get_redis)]) -> Any:
# Code that does something with redis goes here, for example:
Expand All @@ -223,20 +253,24 @@ from redis import asyncio as redis

from main import app, get_redis


@pytest_asyncio.fixture
async def redis_client() -> AsyncIterator[redis.Redis]:
async with fakeredis.FakeAsyncRedis() as client:
yield client


@pytest_asyncio.fixture
async def app_client(redis_client: redis.Redis) -> AsyncIterator[httpx.AsyncClient]:
async def get_redis_override() -> redis.Redis:
return redis_client

transport = httpx.ASGITransport(app=app) # type: ignore[arg-type] # https://github.com/encode/httpx/issues/3111
async with httpx.AsyncClient(transport=transport, base_url="http://test") as app_client:
with mock.patch.dict(app.dependency_overrides, {get_redis: get_redis_override}):
yield app_client


@pytest.mark.asyncio
async def test_app(app_client: httpx.AsyncClient) -> None:
response = await app_client.get("/")
Expand Down
20 changes: 20 additions & 0 deletions docs/overrides/partials/toc-item.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
<li class="md-nav__item">
<a href="{{ toc_item.url }}" class="md-nav__link">
<span class="md-ellipsis">
{{ toc_item.title }}
</span>
</a>

<!-- Table of contents list -->
{% if toc_item.children %}
<nav class="md-nav" aria-label="{{ toc_item.title | striptags }}">
<ul class="md-nav__list">
{% for toc_item in toc_item.children %}
{% if not page.meta.toc_depth or toc_item.level <= page.meta.toc_depth %}
{% include "partials/toc-item.html" %}
{% endif %}
{% endfor %}
</ul>
</nav>
{% endif %}
</li>
12 changes: 12 additions & 0 deletions fakeredis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import sys

from ._connection import (
FakeRedis,
FakeStrictRedis,
Expand All @@ -9,6 +11,15 @@
FakeConnection as FakeAsyncConnection,
)

if sys.version_info >= (3, 11):
from ._tcp_server import TcpFakeServer
else:

class TcpFakeServer:
def __init__(self, *args, **kwargs):
raise NotImplementedError("TcpFakeServer is only available in Python 3.11+")


try:
from importlib import metadata
except ImportError: # for Python < 3.8
Expand All @@ -28,4 +39,5 @@
"FakeConnection",
"FakeAsyncRedis",
"FakeAsyncConnection",
"TcpFakeServer",
]
9 changes: 8 additions & 1 deletion fakeredis/_basefakesocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,20 @@ def __init__(self, server: "FakeServer", db: int, *args: Any, **kwargs: Any) ->
self._paused = False
self._parser = self._parse_commands()
self._parser.send(None)
self.version = server.version
# Assigned elsewhere
self._transaction: Optional[List[Any]]
self._in_transaction: bool
self._pubsub: int
self._transaction_failed: bool

@property
def version(self) -> Tuple[int, ...]:
return self._server.version

@property
def server_type(self) -> str:
return self._server.server_type

def put_response(self, msg: Any) -> None:
"""Put a response message into the queue of responses.
Expand Down
5 changes: 5 additions & 0 deletions fakeredis/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(
*args: Any,
server: Optional[FakeServer] = None,
version: VersionType = (7,),
server_type: str = "redis",
lua_modules: Optional[Set[str]] = None,
**kwargs: Any,
) -> None:
Expand Down Expand Up @@ -130,18 +131,22 @@ def __init__(
"connection_class": FakeConnection,
"server": server,
"version": version,
"server_type": server_type,
"lua_modules": lua_modules,
}
connection_kwargs.update({arg: kwds[arg] for arg in conn_pool_args if arg in kwds})
kwds["connection_pool"] = redis.connection.ConnectionPool(**connection_kwargs) # type: ignore
kwds.pop("server", None)
kwds.pop("connected", None)
kwds.pop("version", None)
kwds.pop("server_type", None)
kwds.pop("lua_modules", None)
super().__init__(**kwds)

@classmethod
def from_url(cls, *args: Any, **kwargs: Any) -> Self:
kwargs.setdefault("version", "7.4")
kwargs.setdefault("server_type", "redis")
pool = redis.ConnectionPool.from_url(*args, **kwargs)
# Now override how it creates connections
pool.connection_class = FakeConnection
Expand Down
3 changes: 3 additions & 0 deletions fakeredis/_msgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
LOG_INVALID_DEBUG_LEVEL_MSG = "ERR Invalid debug level."
LUA_COMMAND_ARG_MSG6 = "ERR Lua redis() command arguments must be strings or integers"
LUA_COMMAND_ARG_MSG = "ERR Lua redis lib command arguments must be strings or integers"
VALKEY_LUA_COMMAND_ARG_MSG = "Command arguments must be strings or integers script: {}"
LUA_WRONG_NUMBER_ARGS_MSG = "ERR wrong number or type of arguments"
SCRIPT_ERROR_MSG = "ERR Error running script (call to f_{}): @user_script:?: {}"
RESTORE_KEY_EXISTS = "BUSYKEY Target key name already exists."
Expand Down Expand Up @@ -100,9 +101,11 @@
TIMESERIES_KEY_EXISTS = "TSDB: key already exists"
TIMESERIES_INVALID_DUPLICATE_POLICY = "TSDB: Unknown DUPLICATE_POLICY"
TIMESERIES_KEY_DOES_NOT_EXIST = "TSDB: the key does not exist"
TIMESERIES_RULE_DOES_NOT_EXIST = "TSDB: compaction rule does not exist"
TIMESERIES_RULE_EXISTS = "TSDB: the destination key already has a src rule"
TIMESERIES_BAD_AGGREGATION_TYPE = "TSDB: Unknown aggregation type"
TIMESERIES_INVALID_TIMESTAMP = "TSDB: invalid timestamp"
TIMESERIES_BAD_TIMESTAMP = "TSDB: Couldn't parse alignTimestamp"
TIMESERIES_TIMESTAMP_OLDER_THAN_RETENTION = "TSDB: Timestamp is older than retention"
TIMESERIES_TIMESTAMP_LOWER_THAN_MAX_V7 = (
"TSDB: timestamp must be equal to or higher than the maximum existing timestamp"
Expand Down
17 changes: 10 additions & 7 deletions fakeredis/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _create_version(v: VersionType) -> Tuple[int, ...]:
class FakeServer:
_servers_map: Dict[str, "FakeServer"] = dict()

def __init__(self, version: VersionType = (7,)):
def __init__(self, version: VersionType = (7,), server_type: str = "redis") -> None:
self.lock = threading.Lock()
self.dbs: Dict[int, Database] = defaultdict(lambda: Database(self.lock))
# Maps channel/pattern to a weak set of sockets
Expand All @@ -43,15 +43,18 @@ def __init__(self, version: VersionType = (7,)):
self.connected = True
# List of weakrefs to sockets that are being closed lazily
self.closed_sockets: List[Any] = []
self.version = _create_version(version)
self.version: Tuple[int, ...] = _create_version(version)
if server_type not in ("redis", "dragonfly", "valkey"):
raise ValueError(f"Unsupported server type: {server_type}")
self.server_type: str = server_type

@staticmethod
def get_server(key: str, version: VersionType) -> "FakeServer":
return FakeServer._servers_map.setdefault(key, FakeServer(version=version))
def get_server(key: str, version: VersionType, server_type: str) -> "FakeServer":
return FakeServer._servers_map.setdefault(key, FakeServer(version=version, server_type=server_type))


class FakeBaseConnectionMixin(object):
def __init__(self, *args: Any, version: VersionType = (7, 0), **kwargs: Any) -> None:
def __init__(self, *args: Any, version: VersionType, server_type: str, **kwargs: Any) -> None:
self.client_name: Optional[str] = None
self.server_key: str
self._sock = None
Expand All @@ -66,7 +69,7 @@ def __init__(self, *args: Any, version: VersionType = (7, 0), **kwargs: Any) ->
else:
host, port = kwargs.get("host"), kwargs.get("port")
self.server_key = f"{host}:{port}"
self.server_key += f":v{version}"
self._server = FakeServer.get_server(self.server_key, version=version)
self.server_key += f":{server_type}:v{version}"
self._server = FakeServer.get_server(self.server_key, server_type=server_type, version=version)
self._server.connected = connected
super().__init__(*args, **kwargs)
127 changes: 127 additions & 0 deletions fakeredis/_tcp_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import logging
from dataclasses import dataclass
from itertools import count
from socketserver import ThreadingTCPServer, StreamRequestHandler
from typing import BinaryIO, Dict, Tuple

from fakeredis import FakeRedis
from fakeredis import FakeServer

LOGGER = logging.getLogger("fakeredis")


def to_bytes(value) -> bytes:
if isinstance(value, bytes):
return value
return str(value).encode()


@dataclass
class Client:
connection: FakeRedis
client_address: int


@dataclass
class Reader:
reader: BinaryIO

def load_array(self, length: int):
array = [None] * length
for i in range(length):
array[i] = self.load()
return array

def load(self):
line = self.reader.readline().strip()
match line[0:1], line[1:]:
case b"*", length:
return self.load_array(int(length))
case b"$", length:
bulk_string = self.reader.read(int(length) + 2).strip()
if len(bulk_string) != int(length):
raise ValueError()
return bulk_string
case b":", value:
return int(value)
case b"+", value:
return value
case b"-", value:
return Exception(value)
case _:
return None


@dataclass
class Writer:
writer: BinaryIO

def dump(self, value, dump_bulk=False):
if isinstance(value, int):
self.writer.write(f":{value}\r\n".encode())
elif isinstance(value, (str, bytes)):
value = to_bytes(value)
if dump_bulk or b"\r" in value or b"\n" in value:
self.writer.write(b"$" + str(len(value)).encode() + b"\r\n" + value + b"\r\n")
else:
self.writer.write(b"+" + value + b"\r\n")
elif isinstance(value, (list, set)):
self.writer.write(f"*{len(value)}\r\n".encode())
for item in value:
self.dump(item, dump_bulk=True)
elif value is None:
self.writer.write("$-1\r\n".encode())
elif isinstance(value, Exception):
self.writer.write(f"-{value.args[0]}\r\n".encode())


class TCPFakeRequestHandler(StreamRequestHandler):

def setup(self) -> None:
super().setup()
if self.client_address in self.server.clients:
self.current_client = self.server.clients[self.client_address]
else:
self.current_client = Client(
connection=FakeRedis(server=self.server.fake_server),
client_address=self.client_address,
)
self.reader = Reader(self.rfile)
self.writer = Writer(self.wfile)
self.server.clients[self.client_address] = self.current_client

def handle(self):
while True:
try:
self.data = self.reader.load()
LOGGER.debug(f">>> {self.client_address[0]}: {self.data}")
res = self.current_client.connection.execute_command(*self.data)
LOGGER.debug(f"<<< {self.client_address[0]}: {res}")
self.writer.dump(res)
except Exception as e:
LOGGER.debug(f"!!! {self.client_address[0]}: {e}")
self.writer.dump(e)
break

def finish(self) -> None:
del self.server.clients[self.current_client.client_address]
super().finish()


class TcpFakeServer(ThreadingTCPServer):
def __init__(
self,
server_address: Tuple[str | bytes | bytearray, int],
bind_and_activate: bool = True,
server_type: str = "redis",
server_version: Tuple[int, ...] = (7, 4),
):
super().__init__(server_address, TCPFakeRequestHandler, bind_and_activate)
self.fake_server = FakeServer(server_type=server_type, version=server_version)
self.client_ids = count(0)
self.clients: Dict[int, FakeRedis] = dict()


if __name__ == "__main__":
server = TcpFakeServer(("localhost", 19000))
server.serve_forever()
Loading

0 comments on commit a279dd0

Please sign in to comment.