Simplify cache namespace and key encoding logic for v1 (#670)
Co-authored-by: Sam Bull <[email protected]>
padraic-shafer and Dreamsorcerer authored Feb 27, 2023
1 parent b2036b3 commit 2df2056
Showing 17 changed files with 194 additions and 148 deletions.
5 changes: 5 additions & 0 deletions CHANGES.rst
@@ -12,9 +12,14 @@ Migration instructions

There are a number of backwards-incompatible changes. These points should help with migrating from an older release:

* The ``key_builder`` parameter now expects a callback which accepts 2 strings and returns a string in all cache implementations, making the builders simpler and interchangeable.
* The ``key`` parameter has been removed from the ``cached`` decorator. The behaviour can be easily reimplemented with ``key_builder=lambda *a, **kw: "foo"``
* When using the ``key_builder`` parameter in ``@multicached``, the function will now return the original, unmodified keys, only using the transformed keys in the cache (this has always been the documented behaviour, but not the implemented behaviour).
* ``BaseSerializer`` is now an ``ABC``, so cannot be instantiated directly.
* If subclassing ``BaseCache`` to implement a custom backend:

* The cache key type used by the backend must now be specified when inheriting (e.g. ``BaseCache[str]`` typically).
* The ``build_key()`` method must now be defined (this should generally involve calling ``self._str_build_key()`` as a helper); see the sketch below.
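
For illustration, here is a minimal sketch of such a subclass under the new API (the backend name and the dict-based storage are hypothetical; only ``BaseCache[str]``, ``build_key()`` and ``_str_build_key()`` come from this release)::

    from typing import Optional

    from aiocache.base import BaseCache


    class DictBackend(BaseCache[str]):
        """Hypothetical in-process backend whose cache keys are plain strings."""

        NAME = "dict"

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self._store = {}

        def build_key(self, key: str, namespace: Optional[str] = None) -> str:
            # Apply the configured key_builder and the default namespace.
            return self._str_build_key(key, namespace)

A real backend would also implement the storage coroutines (``_get()``, ``_set()`` and friends), which are unchanged in the hunks shown here.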


0.12.0 (2023-01-13)
4 changes: 2 additions & 2 deletions aiocache/__init__.py
@@ -1,5 +1,5 @@
import logging
from typing import Dict, Type
from typing import Any, Dict, Type

from .backends.memory import SimpleMemoryCache
from .base import BaseCache
@@ -8,7 +8,7 @@

logger = logging.getLogger(__name__)

AIOCACHE_CACHES: Dict[str, Type[BaseCache]] = {SimpleMemoryCache.NAME: SimpleMemoryCache}
AIOCACHE_CACHES: Dict[str, Type[BaseCache[Any]]] = {SimpleMemoryCache.NAME: SimpleMemoryCache}

try:
import redis
9 changes: 5 additions & 4 deletions aiocache/backends/memcached.py
@@ -1,12 +1,13 @@
import asyncio
from typing import Optional

import aiomcache

from aiocache.base import BaseCache
from aiocache.serializers import JsonSerializer


class MemcachedBackend(BaseCache):
class MemcachedBackend(BaseCache[bytes]):
def __init__(self, endpoint="127.0.0.1", port=11211, pool_size=2, **kwargs):
super().__init__(**kwargs)
self.endpoint = endpoint
@@ -130,7 +131,7 @@ class MemcachedCache(MemcachedBackend):
:param serializer: obj derived from :class:`aiocache.serializers.BaseSerializer`.
:param plugins: list of :class:`aiocache.plugins.BasePlugin` derived classes.
:param namespace: string to use as default prefix for the key used in all operations of
the backend. Default is None
the backend. Default is an empty string, "".
:param timeout: int or float in seconds specifying maximum timeout for the operations to last.
By default its 5.
:param endpoint: str with the endpoint to connect to. Default is 127.0.0.1.
@@ -147,8 +148,8 @@ def __init__(self, serializer=None, **kwargs):
def parse_uri_path(cls, path):
return {}

def _build_key(self, key, namespace=None):
ns_key = super()._build_key(key, namespace=namespace).replace(" ", "_")
def build_key(self, key: str, namespace: Optional[str] = None) -> bytes:
ns_key = self._str_build_key(key, namespace).replace(" ", "_")
return str.encode(ns_key)

def __repr__(self): # pragma: no cover
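
To make the hunk above concrete: ``MemcachedCache.build_key()`` now builds the namespaced key as a string via ``_str_build_key()``, replaces spaces with underscores, and encodes the result to ``bytes`` for aiomcache. A rough sketch of the resulting behaviour, assuming the base default ``key_builder`` (``f"{namespace}{key}"``) and an installed ``aiomcache``:

    from aiocache.backends.memcached import MemcachedCache

    cache = MemcachedCache(namespace="main")

    # The default key_builder prepends the namespace, spaces become
    # underscores, and the key is byte-encoded for the memcached client.
    assert cache.build_key("my key") == b"mainmy_key"

    # An explicit namespace argument overrides the instance default.
    assert cache.build_key("my key", namespace="other") == b"othermy_key"
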
9 changes: 6 additions & 3 deletions aiocache/backends/memory.py
@@ -1,11 +1,11 @@
import asyncio
from typing import Dict
from typing import Dict, Optional

from aiocache.base import BaseCache
from aiocache.serializers import NullSerializer


class SimpleMemoryBackend(BaseCache):
class SimpleMemoryBackend(BaseCache[str]):
"""
Wrapper around dict operations to use it as a cache backend
"""
@@ -118,7 +118,7 @@ class SimpleMemoryCache(SimpleMemoryBackend):
:param serializer: obj derived from :class:`aiocache.serializers.BaseSerializer`.
:param plugins: list of :class:`aiocache.plugins.BasePlugin` derived classes.
:param namespace: string to use as default prefix for the key used in all operations of
the backend. Default is None.
the backend. Default is an empty string, "".
:param timeout: int or float in seconds specifying maximum timeout for the operations to last.
By default its 5.
"""
@@ -131,3 +131,6 @@ def __init__(self, serializer=None, **kwargs):
@classmethod
def parse_uri_path(cls, path):
return {}

def build_key(self, key: str, namespace: Optional[str] = None) -> str:
return self._str_build_key(key, namespace)
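
Note that ``SimpleMemoryCache`` keeps the base default ``key_builder`` (``f"{namespace}{key}"``), so, unlike the Redis backend below, no separator is inserted between namespace and key. A quick sketch of the assumed behaviour:

    from aiocache import SimpleMemoryCache

    cache = SimpleMemoryCache(namespace="main")
    assert cache.build_key("foo") == "mainfoo"                      # no separator by default
    assert cache.build_key("foo", namespace="other") == "otherfoo"  # explicit namespace wins
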
39 changes: 25 additions & 14 deletions aiocache/backends/redis.py
@@ -1,17 +1,21 @@
import itertools
import warnings
from typing import Any, Callable, Optional, TYPE_CHECKING

import redis.asyncio as redis
from redis.exceptions import ResponseError as IncrbyException

from aiocache.base import BaseCache, _ensure_key
from aiocache.base import BaseCache
from aiocache.serializers import JsonSerializer

if TYPE_CHECKING: # pragma: no cover
from aiocache.serializers import BaseSerializer


_NOT_SET = object()


class RedisBackend(BaseCache):
class RedisBackend(BaseCache[str]):
RELEASE_SCRIPT = (
"if redis.call('get',KEYS[1]) == ARGV[1] then"
" return redis.call('del',KEYS[1])"
@@ -186,7 +190,7 @@ class RedisCache(RedisBackend):
:param serializer: obj derived from :class:`aiocache.serializers.BaseSerializer`.
:param plugins: list of :class:`aiocache.plugins.BasePlugin` derived classes.
:param namespace: string to use as default prefix for the key used in all operations of
the backend. Default is None.
the backend. Default is an empty string, "".
:param timeout: int or float in seconds specifying maximum timeout for the operations to last.
By default its 5.
:param endpoint: str with the endpoint to connect to. Default is "127.0.0.1".
@@ -199,8 +203,21 @@ class RedisCache(RedisBackend):

NAME = "redis"

def __init__(self, serializer=None, **kwargs):
super().__init__(serializer=serializer or JsonSerializer(), **kwargs)
def __init__(
self,
serializer: Optional["BaseSerializer"] = None,
namespace: str = "",
key_builder: Optional[Callable[[str, str], str]] = None,
**kwargs: Any,
):
super().__init__(
serializer=serializer or JsonSerializer(),
namespace=namespace,
key_builder=key_builder or (
lambda key, namespace: f"{namespace}:{key}" if namespace else key
),
**kwargs,
)

@classmethod
def parse_uri_path(cls, path):
@@ -218,14 +235,8 @@ def parse_uri_path(cls, path):
options["db"] = db
return options

def _build_key(self, key, namespace=None):
if namespace is not None:
return "{}{}{}".format(
namespace, ":" if namespace else "", _ensure_key(key))
if self.namespace is not None:
return "{}{}{}".format(
self.namespace, ":" if self.namespace else "", _ensure_key(key))
return key

def __repr__(self): # pragma: no cover
return "RedisCache ({}:{})".format(self.endpoint, self.port)

def build_key(self, key: str, namespace: Optional[str] = None) -> str:
return self._str_build_key(key, namespace)
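
The net effect of the new ``RedisCache.__init__()`` is that the colon-separated key format is now supplied as a default ``key_builder`` instead of an overridden ``_build_key()``. A hedged sketch of the behaviour this gives (assuming the ``redis`` package is installed; no server connection is needed just to build keys):

    from aiocache.backends.redis import RedisCache

    cache = RedisCache(namespace="main")
    assert cache.build_key("foo") == "main:foo"   # colon separator when a namespace is set

    bare = RedisCache()                           # namespace now defaults to ""
    assert bare.build_key("foo") == "foo"         # no leading colon for the empty namespace
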
97 changes: 47 additions & 50 deletions aiocache/base.py
@@ -3,16 +3,22 @@
import logging
import os
import time
from abc import abstractmethod
from enum import Enum
from types import TracebackType
from typing import Callable, Optional, Set, Type
from typing import Callable, Generic, List, Optional, Set, TYPE_CHECKING, Type, TypeVar

from aiocache import serializers
from aiocache.serializers import StringSerializer

if TYPE_CHECKING: # pragma: no cover
from aiocache.plugins import BasePlugin
from aiocache.serializers import BaseSerializer


logger = logging.getLogger(__name__)

SENTINEL = object()
CacheKeyType = TypeVar("CacheKeyType")


class API:
@@ -87,7 +93,7 @@ async def _plugins(self, *args, **kwargs):
return _plugins


class BaseCache:
class BaseCache(Generic[CacheKeyType]):
"""
Base class that agregates the common logic for the different caches that may exist. Cache
related available options are:
@@ -97,9 +103,9 @@ class BaseCache:
:param plugins: list of :class:`aiocache.plugins.BasePlugin` derived classes. Default is empty
list.
:param namespace: string to use as default prefix for the key used in all operations of
the backend. Default is None
the backend. Default is an empty string, "".
:param key_builder: alternative callable to build the key. Receives the key and the namespace
as params and should return something that can be used as key by the underlying backend.
as params and should return a string that can be used as a key by the underlying backend.
:param timeout: int or float in seconds specifying maximum timeout for the operations to last.
By default its 5. Use 0 or None if you want to disable it.
:param ttl: int the expiration time in seconds to use as a default in all operations of
@@ -109,18 +115,22 @@
NAME: str

def __init__(
self, serializer=None, plugins=None, namespace=None, key_builder=None, timeout=5, ttl=None
self,
serializer: Optional["BaseSerializer"] = None,
plugins: Optional[List["BasePlugin"]] = None,
namespace: str = "",
key_builder: Callable[[str, str], str] = lambda key, namespace: f"{namespace}{key}",
timeout: Optional[float] = 5,
ttl: Optional[float] = None,
):
self.timeout = float(timeout) if timeout is not None else timeout
self.namespace = namespace
self.ttl = float(ttl) if ttl is not None else ttl
self.build_key = key_builder or self._build_key
self.timeout = float(timeout) if timeout is not None else None
self.ttl = float(ttl) if ttl is not None else None

self._serializer = None
self.serializer = serializer or serializers.StringSerializer()
self.namespace = namespace
self._build_key = key_builder

self._plugins = None
self.plugins = plugins or []
self._serializer = serializer or StringSerializer()
self._plugins = plugins or []

@property
def serializer(self):
@@ -162,9 +172,8 @@ async def add(self, key, value, ttl=SENTINEL, dumps_fn=None, namespace=None, _co
- :class:`asyncio.TimeoutError` if it lasts more than self.timeout
"""
start = time.monotonic()
dumps = dumps_fn or self._serializer.dumps
ns = namespace if namespace is not None else self.namespace
ns_key = self.build_key(key, namespace=ns)
dumps = dumps_fn or self.serializer.dumps
ns_key = self.build_key(key, namespace)

await self._add(ns_key, dumps(value), ttl=self._get_ttl(ttl), _conn=_conn)

@@ -192,9 +201,8 @@ async def get(self, key, default=None, loads_fn=None, namespace=None, _conn=None
:raises: :class:`asyncio.TimeoutError` if it lasts more than self.timeout
"""
start = time.monotonic()
loads = loads_fn or self._serializer.loads
ns = namespace if namespace is not None else self.namespace
ns_key = self.build_key(key, namespace=ns)
loads = loads_fn or self.serializer.loads
ns_key = self.build_key(key, namespace)

value = loads(await self._get(ns_key, encoding=self.serializer.encoding, _conn=_conn))

@@ -224,10 +232,9 @@ async def multi_get(self, keys, loads_fn=None, namespace=None, _conn=None):
:raises: :class:`asyncio.TimeoutError` if it lasts more than self.timeout
"""
start = time.monotonic()
loads = loads_fn or self._serializer.loads
ns = namespace if namespace is not None else self.namespace
loads = loads_fn or self.serializer.loads

ns_keys = [self.build_key(key, namespace=ns) for key in keys]
ns_keys = [self.build_key(key, namespace) for key in keys]
values = [
loads(value)
for value in await self._multi_get(
@@ -269,9 +276,8 @@ async def set(
:raises: :class:`asyncio.TimeoutError` if it lasts more than self.timeout
"""
start = time.monotonic()
dumps = dumps_fn or self._serializer.dumps
ns = namespace if namespace is not None else self.namespace
ns_key = self.build_key(key, namespace=ns)
dumps = dumps_fn or self.serializer.dumps
ns_key = self.build_key(key, namespace)

res = await self._set(
ns_key, dumps(value), ttl=self._get_ttl(ttl), _cas_token=_cas_token, _conn=_conn
@@ -303,12 +309,11 @@ async def multi_set(self, pairs, ttl=SENTINEL, dumps_fn=None, namespace=None, _c
:raises: :class:`asyncio.TimeoutError` if it lasts more than self.timeout
"""
start = time.monotonic()
dumps = dumps_fn or self._serializer.dumps
ns = namespace if namespace is not None else self.namespace
dumps = dumps_fn or self.serializer.dumps

tmp_pairs = []
for key, value in pairs:
tmp_pairs.append((self.build_key(key, namespace=ns), dumps(value)))
tmp_pairs.append((self.build_key(key, namespace), dumps(value)))

await self._multi_set(tmp_pairs, ttl=self._get_ttl(ttl), _conn=_conn)

@@ -339,8 +344,7 @@ async def delete(self, key, namespace=None, _conn=None):
:raises: :class:`asyncio.TimeoutError` if it lasts more than self.timeout
"""
start = time.monotonic()
ns = namespace if namespace is not None else self.namespace
ns_key = self.build_key(key, namespace=ns)
ns_key = self.build_key(key, namespace)
ret = await self._delete(ns_key, _conn=_conn)
logger.debug("DELETE %s %d (%.4f)s", ns_key, ret, time.monotonic() - start)
return ret
@@ -364,8 +368,7 @@ async def exists(self, key, namespace=None, _conn=None):
:raises: :class:`asyncio.TimeoutError` if it lasts more than self.timeout
"""
start = time.monotonic()
ns = namespace if namespace is not None else self.namespace
ns_key = self.build_key(key, namespace=ns)
ns_key = self.build_key(key, namespace)
ret = await self._exists(ns_key, _conn=_conn)
logger.debug("EXISTS %s %d (%.4f)s", ns_key, ret, time.monotonic() - start)
return ret
@@ -392,8 +395,7 @@ async def increment(self, key, delta=1, namespace=None, _conn=None):
:raises: :class:`TypeError` if value is not incrementable
"""
start = time.monotonic()
ns = namespace if namespace is not None else self.namespace
ns_key = self.build_key(key, namespace=ns)
ns_key = self.build_key(key, namespace)
ret = await self._increment(ns_key, delta, _conn=_conn)
logger.debug("INCREMENT %s %d (%.4f)s", ns_key, ret, time.monotonic() - start)
return ret
@@ -418,8 +420,7 @@ async def expire(self, key, ttl, namespace=None, _conn=None):
:raises: :class:`asyncio.TimeoutError` if it lasts more than self.timeout
"""
start = time.monotonic()
ns = namespace if namespace is not None else self.namespace
ns_key = self.build_key(key, namespace=ns)
ns_key = self.build_key(key, namespace)
ret = await self._expire(ns_key, ttl, _conn=_conn)
logger.debug("EXPIRE %s %d (%.4f)s", ns_key, ret, time.monotonic() - start)
return ret
@@ -498,12 +499,15 @@ async def close(self, *args, _conn=None, **kwargs):
async def _close(self, *args, **kwargs):
pass

def _build_key(self, key, namespace=None):
if namespace is not None:
return "{}{}".format(namespace, _ensure_key(key))
if self.namespace is not None:
return "{}{}".format(self.namespace, _ensure_key(key))
return key
@abstractmethod
def build_key(self, key: str, namespace: Optional[str] = None) -> CacheKeyType:
raise NotImplementedError()

def _str_build_key(self, key: str, namespace: Optional[str] = None) -> str:
"""Simple key builder that can be used in subclasses for build_key()."""
key_name = key.value if isinstance(key, Enum) else key
ns = self.namespace if namespace is None else namespace
return self._build_key(key_name, ns)

def _get_ttl(self, ttl):
return ttl if ttl is not SENTINEL else self.ttl
@@ -550,12 +554,5 @@ async def _do_inject_conn(self, *args, **kwargs):
return _do_inject_conn


def _ensure_key(key):
if isinstance(key, Enum):
return key.value
else:
return key


for cmd in API.CMDS:
setattr(_Conn, cmd.__name__, _Conn._inject_conn(cmd.__name__))
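
Since ``key_builder`` is now uniformly a ``Callable[[str, str], str]``, a custom builder can be shared across backends, with each backend handling any further encoding itself (for example the byte-encoding in the memcached backend above). A sketch using ``SimpleMemoryCache``; the builder below is hypothetical:

    from aiocache import SimpleMemoryCache


    def versioned_key(key: str, namespace: str) -> str:
        # Hypothetical builder: tag every key with a schema version and the namespace.
        return f"v1:{namespace}:{key}" if namespace else f"v1:{key}"


    cache = SimpleMemoryCache(namespace="users", key_builder=versioned_key)
    assert cache.build_key("42") == "v1:users:42"
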
(The remaining changed files are not shown.)
