Skip to content

Commit

Permalink
feat:ACL commands (#338)
Browse files Browse the repository at this point in the history
  • Loading branch information
cunla authored Dec 25, 2024
1 parent 831f47b commit 0a9ff56
Show file tree
Hide file tree
Showing 26 changed files with 1,316 additions and 166 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ jobs:
max-parallel: 8
fail-fast: false
matrix:
redis-image: [ "redis:6.2.14", "redis:7.0.15", "redis:7.4.0" ]
python-version: [ "3.9", "3.11", "3.12", "3.13" ]
redis-image: [ "redis:6.2.16", "redis:7.4.1" ]
python-version: [ "3.9", "3.12", "3.13" ]
redis-py: [ "4.3.6", "4.6.0", "5.0.8", "5.2.1", "5.3.0b3" ]
include:
- python-version: "3.12"
Expand All @@ -66,12 +66,12 @@ jobs:
extra: "lua"
hypothesis: true
- python-version: "3.12"
redis-image: "redis/redis-stack-server:6.2.6-v15"
redis-image: "redis/redis-stack-server:6.2.6-v17"
redis-py: "5.2.1"
extra: "json, bf, lua, cf"
hypothesis: true
- python-version: "3.12"
redis-image: "redis/redis-stack-server:7.4.0-v0"
redis-image: "redis/redis-stack-server:7.4.0-v1"
redis-py: "5.2.1"
extra: "json, bf, lua, cf"
coverage: true
Expand Down
7 changes: 7 additions & 0 deletions docs/about/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,16 @@ toc_depth: 2

### 🚀 Features

- ACL commands support #338
- Add support disable_decoding in async read_response #349
- Implement support for `SADDEX`, using a new set implementation with support for expiring members #350

### 🧰 Maintenance

- Remove end of life python 3.8 from test matrix
- Add python 3.13 to test matrix
- Improve documentation for Dragonfly/Valkey support

## v2.26.2

### 🐛 Bug Fixes
Expand Down
87 changes: 45 additions & 42 deletions fakeredis/_basefakesocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ def __init__(self, server: "FakeServer", db: int, *args: Any, **kwargs: Any) ->
self._in_transaction: bool
self._pubsub: int
self._transaction_failed: bool
self._current_user: bytes = b"default"
self._client_info: bytes = kwargs.pop("client_info", b"")

@property
def version(self) -> Tuple[int, ...]:
Expand Down Expand Up @@ -181,6 +183,49 @@ def _parse_commands(self) -> Generator[None, Any, None]:
buf = buf[length + 2 :] # +2 to skip the CRLF
self._process_command(fields)

def _process_command(self, fields: List[bytes]) -> None:
if not fields:
return
result: Any
cmd, cmd_arguments = _extract_command(fields)
try:
func, sig = self._name_to_func(cmd)
self._server.acl.validate_command(self._current_user, self._client_info, fields) # ACL check
with self._server.lock:
# Clean out old connections
while True:
try:
weak_sock = self._server.closed_sockets.pop()
except IndexError:
break
else:
sock = weak_sock()
if sock:
sock._cleanup(self._server)
now = time.time()
for db in self._server.dbs.values():
db.time = now
sig.check_arity(cmd_arguments, self.version)
if self._transaction is not None and msgs.FLAG_TRANSACTION not in sig.flags:
self._transaction.append((func, sig, cmd_arguments))
result = QUEUED
else:
result = self._run_command(func, sig, cmd_arguments, False)
except SimpleError as exc:
if self._transaction is not None:
# TODO: should not apply if the exception is from _run_command
# e.g. watch inside multi
self._transaction_failed = True
if cmd == "exec" and exc.value.startswith("ERR "):
exc.value = "EXECABORT Transaction discarded because of: " + exc.value[4:]
self._transaction = None
self._transaction_failed = False
self._clear_watches()
result = exc
result = self._decode_result(result)
if not isinstance(result, NoResponse):
self.put_response(result)

def _run_command(
self, func: Optional[Callable[[Any], Any]], sig: Signature, args: List[Any], from_script: bool
) -> Any:
Expand Down Expand Up @@ -263,48 +308,6 @@ def sendall(self, data: AnyStr) -> None:
data = data.encode("ascii") # type: ignore
self._parser.send(data)

def _process_command(self, fields: List[bytes]) -> None:
if not fields:
return
result: Any
cmd, cmd_arguments = _extract_command(fields)
try:
func, sig = self._name_to_func(cmd)
with self._server.lock:
# Clean out old connections
while True:
try:
weak_sock = self._server.closed_sockets.pop()
except IndexError:
break
else:
sock = weak_sock()
if sock:
sock._cleanup(self._server)
now = time.time()
for db in self._server.dbs.values():
db.time = now
sig.check_arity(cmd_arguments, self.version)
if self._transaction is not None and msgs.FLAG_TRANSACTION not in sig.flags:
self._transaction.append((func, sig, cmd_arguments))
result = QUEUED
else:
result = self._run_command(func, sig, cmd_arguments, False)
except SimpleError as exc:
if self._transaction is not None:
# TODO: should not apply if the exception is from _run_command
# e.g. watch inside multi
self._transaction_failed = True
if cmd == "exec" and exc.value.startswith("ERR "):
exc.value = "EXECABORT Transaction discarded because of: " + exc.value[4:]
self._transaction = None
self._transaction_failed = False
self._clear_watches()
result = exc
result = self._decode_result(result)
if not isinstance(result, NoResponse):
self.put_response(result)

def _scan(self, keys, cursor, *args):
"""This is the basis of most of the ``scan`` methods.
Expand Down
23 changes: 12 additions & 11 deletions fakeredis/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@ def connect(self) -> None:
def _connect(self) -> FakeSocket:
if not self._server.connected:
raise redis.ConnectionError(msgs.CONNECTION_ERROR_MSG)
return FakeSocket(self._server, db=self.db, lua_modules=self._lua_modules)
return FakeSocket(
self._server,
db=self.db,
lua_modules=self._lua_modules,
client_info=b"id=3 addr=127.0.0.1:57275 laddr=127.0.0.1:6379 fd=8 name= age=16 idle=0 flags=N db=0 sub=0 psub=0 ssub=0 multi=-1 qbuf=48 qbuf-free=16842 argv-mem=25 multi-mem=0 rbs=1024 rbp=0 obl=0 oll=0 omem=0 tot-mem=18737 events=r cmd=auth user=default redir=-1 resp=2",
)

def can_read(self, timeout: Optional[float] = 0) -> bool:
if not self._server.connected:
Expand Down Expand Up @@ -62,7 +67,7 @@ def read_response(self, **kwargs: Any) -> Any: # type: ignore
raise redis.ConnectionError(msgs.CONNECTION_ERROR_MSG)
else:
response = self._sock.responses.get()
if isinstance(response, redis.ResponseError):
if isinstance(response, (redis.ResponseError, redis.AuthenticationError)):
raise response
if kwargs.get("disable_decoding", False):
return response
Expand Down Expand Up @@ -100,6 +105,7 @@ def __init__(
for ind, p in enumerate(parameters)
if p.default != inspect.Parameter.empty
}
kwds["server"] = server
if not kwds.get("connection_pool", None):
charset = kwds.get("charset", None)
errors = kwds.get("errors", None)
Expand All @@ -114,9 +120,8 @@ def __init__(
"host",
"port",
"db",
# Ignoring because AUTH is not implemented
# 'username',
# 'password',
"username",
"password",
"socket_timeout",
"encoding",
"encoding_errors",
Expand All @@ -126,10 +131,10 @@ def __init__(
"health_check_interval",
"client_name",
"connected",
"server",
}
connection_kwargs = {
"connection_class": FakeConnection,
"server": server,
"version": version,
"server_type": server_type,
"lua_modules": lua_modules,
Expand All @@ -150,11 +155,7 @@ def from_url(cls, *args: Any, **kwargs: Any) -> Self:
pool = redis.ConnectionPool.from_url(*args, **kwargs)
# Now override how it creates connections
pool.connection_class = FakeConnection
# Using username and password fails since AUTH is not implemented.
# https://github.com/cunla/fakeredis-py/issues/9
pool.connection_kwargs.pop("username", None)
pool.connection_kwargs.pop("password", None)
return cls(connection_pool=pool)
return cls(connection_pool=pool, *args, **kwargs)


class FakeStrictRedis(FakeRedisMixin, redis.StrictRedis): # type: ignore
Expand Down
8 changes: 6 additions & 2 deletions fakeredis/_fakesocket.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Set
from typing import Optional, Set, Any

from fakeredis.commands_mixins import (
BitmapCommandsMixin,
Expand All @@ -14,6 +14,7 @@
TransactionsCommandsMixin,
SetCommandsMixin,
StreamsCommandsMixin,
AclCommandsMixin,
)
from fakeredis.stack import (
JSONCommandsMixin,
Expand Down Expand Up @@ -54,11 +55,14 @@ class FakeSocket(
TDigestCommandsMixin,
TimeSeriesCommandsMixin,
DragonflyCommandsMixin,
AclCommandsMixin,
):
def __init__(
self,
server: "FakeServer",
db: int,
lua_modules: Optional[Set[str]] = None, # noqa: F821
*args: Any,
**kwargs,
) -> None:
super(FakeSocket, self).__init__(server, db, lua_modules=lua_modules)
super(FakeSocket, self).__init__(server, db, *args, lua_modules=lua_modules, **kwargs)
9 changes: 9 additions & 0 deletions fakeredis/_msgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@
)
INVALID_OVERFLOW_TYPE = "ERR Invalid OVERFLOW type specified"

# ACL specific errors
AUTH_FAILURE = "WRONGPASS invalid username-password pair or user is disabled."

# TDigest error messages
TDIGEST_KEY_EXISTS = "T-Digest: key already exists"
TDIGEST_KEY_NOT_EXISTS = "T-Digest: key does not exist"
Expand Down Expand Up @@ -118,6 +121,12 @@
TIMESERIES_BAD_FILTER_EXPRESSION = "TSDB: failed parsing labels"
HEXPIRE_NUMFIELDS_DIFFERENT = "The `numfields` parameter must match the number of arguments"

MISSING_ACLFILE_CONFIG = "ERR This Redis instance is not configured to use an ACL file. You may want to specify users via the ACL SETUSER command and then issue a CONFIG REWRITE (assuming you have a Redis configuration file set) in order to store users in the Redis configuration."

NO_PERMISSION_ERROR = "NOPERM User {} has no permissions to run the '{}' command"
NO_PERMISSION_KEY_ERROR = "NOPERM No permissions to access a key"
NO_PERMISSION_CHANNEL_ERROR = "NOPERM No permissions to access a channel"

# Command flags
FLAG_NO_SCRIPT = "s" # Command not allowed in scripts
FLAG_LEAVE_EMPTY_VAL = "v"
Expand Down
37 changes: 32 additions & 5 deletions fakeredis/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing_extensions import Literal


from fakeredis.model import AccessControlList
from fakeredis._helpers import Database, FakeSelector

LOGGER = logging.getLogger("fakeredis")
Expand All @@ -31,10 +32,30 @@ def _create_version(v: VersionType) -> Tuple[int, ...]:
return v


def _version_to_str(v: VersionType) -> str:
if isinstance(v, tuple):
return ".".join(str(x) for x in v)
return str(v)


class FakeServer:
_servers_map: Dict[str, "FakeServer"] = dict()

def __init__(self, version: VersionType = (7,), server_type: ServerType = "redis") -> None:
def __init__(
self,
version: VersionType = (7,),
server_type: ServerType = "redis",
config: Dict[bytes, bytes] = None,
) -> None:
"""Initialize a new FakeServer instance.
:param version: The version of the server (e.g. 6, 7.4, "7.4.1", can also be a tuple)
:param server_type: The type of server (redis, dragonfly, valkey)
:param config: A dictionary of configuration options.
Configuration options:
- `requirepass`: The password required to authenticate to the server.
- `aclfile`: The path to the ACL file.
"""
self.lock = threading.Lock()
self.dbs: Dict[int, Database] = defaultdict(lambda: Database(self.lock))
# Maps channel/pattern to a weak set of sockets
Expand All @@ -49,14 +70,20 @@ def __init__(self, version: VersionType = (7,), server_type: ServerType = "redis
if server_type not in ("redis", "dragonfly", "valkey"):
raise ValueError(f"Unsupported server type: {server_type}")
self.server_type: str = server_type
self.config: Dict[bytes, bytes] = config or dict()
self.acl: AccessControlList = AccessControlList()

@staticmethod
def get_server(key: str, version: VersionType, server_type: str) -> "FakeServer":
return FakeServer._servers_map.setdefault(key, FakeServer(version=version, server_type=server_type))
def get_server(key: str, version: VersionType, server_type: ServerType) -> "FakeServer":
if key not in FakeServer._servers_map:
FakeServer._servers_map[key] = FakeServer(version=version, server_type=server_type)
return FakeServer._servers_map[key]


class FakeBaseConnectionMixin(object):
def __init__(self, *args: Any, version: VersionType = (7, 0), server_type: str = "redis", **kwargs: Any) -> None:
def __init__(
self, *args: Any, version: VersionType = (7, 0), server_type: ServerType = "redis", **kwargs: Any
) -> None:
self.client_name: Optional[str] = None
self.server_key: str
self._sock = None
Expand All @@ -71,7 +98,7 @@ def __init__(self, *args: Any, version: VersionType = (7, 0), server_type: str =
else:
host, port = kwargs.get("host"), kwargs.get("port")
self.server_key = f"{host}:{port}"
self.server_key += f":{server_type}:v{version}"
self.server_key += f":{server_type}:v{_version_to_str(version)[0]}"
self._server = FakeServer.get_server(self.server_key, server_type=server_type, version=version)
self._server.connected = connected
super().__init__(*args, **kwargs)
1 change: 1 addition & 0 deletions fakeredis/_tcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from fakeredis._server import ServerType

LOGGER = logging.getLogger("fakeredis")
LOGGER.setLevel(logging.DEBUG)


def to_bytes(value) -> bytes:
Expand Down
2 changes: 1 addition & 1 deletion fakeredis/commands.json

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions fakeredis/commands_mixins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any

from .acl_mixin import AclCommandsMixin
from .bitmap_mixin import BitmapCommandsMixin
from .connection_mixin import ConnectionCommandsMixin
from .generic_mixin import GenericCommandsMixin
Expand All @@ -11,6 +12,7 @@
from .set_mixin import SetCommandsMixin
from .streams_mixin import StreamsCommandsMixin
from .string_mixin import StringCommandsMixin
from .transactions_mixin import TransactionsCommandsMixin

try:
from .scripting_mixin import ScriptingCommandsMixin
Expand All @@ -22,8 +24,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
super(ScriptingCommandsMixin, self).__init__(*args, **kwargs) # type: ignore


from .transactions_mixin import TransactionsCommandsMixin

__all__ = [
"BitmapCommandsMixin",
"ConnectionCommandsMixin",
Expand All @@ -38,4 +38,5 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
"SetCommandsMixin",
"StreamsCommandsMixin",
"StringCommandsMixin",
"AclCommandsMixin",
]
Loading

0 comments on commit 0a9ff56

Please sign in to comment.