diff --git a/CHANGES.rst b/CHANGES.rst index 0f539e4a..bc52e18a 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -12,9 +12,14 @@ Migration instructions There are a number of backwards-incompatible changes. These points should help with migrating from an older release: +* The ``key_builder`` parameter now expects a callback which accepts 2 strings and returns a string in all cache implementations, making the builders simpler and interchangeable. * The ``key`` parameter has been removed from the ``cached`` decorator. The behaviour can be easily reimplemented with ``key_builder=lambda *a, **kw: "foo"`` * When using the ``key_builder`` parameter in ``@multicached``, the function will now return the original, unmodified keys, only using the transformed keys in the cache (this has always been the documented behaviour, but not the implemented behaviour). * ``BaseSerializer`` is now an ``ABC``, so cannot be instantiated directly. +* If subclassing ``BaseCache`` to implement a custom backend: + + * The cache key type used by the backend must now be specified when inheriting (e.g. ``BaseCache[str]`` typically). + * The ``build_key()`` method must now be defined (this should generally involve calling ``self._str_build_key()`` as a helper). 0.12.0 (2023-01-13) diff --git a/aiocache/__init__.py b/aiocache/__init__.py index 12503ed7..c2b5b765 100644 --- a/aiocache/__init__.py +++ b/aiocache/__init__.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, Type +from typing import Any, Dict, Type from .backends.memory import SimpleMemoryCache from .base import BaseCache @@ -8,7 +8,7 @@ logger = logging.getLogger(__name__) -AIOCACHE_CACHES: Dict[str, Type[BaseCache]] = {SimpleMemoryCache.NAME: SimpleMemoryCache} +AIOCACHE_CACHES: Dict[str, Type[BaseCache[Any]]] = {SimpleMemoryCache.NAME: SimpleMemoryCache} try: import redis diff --git a/aiocache/backends/memcached.py b/aiocache/backends/memcached.py index 8abf56f7..45e1971a 100644 --- a/aiocache/backends/memcached.py +++ b/aiocache/backends/memcached.py @@ -1,4 +1,5 @@ import asyncio +from typing import Optional import aiomcache @@ -6,7 +7,7 @@ from aiocache.serializers import JsonSerializer -class MemcachedBackend(BaseCache): +class MemcachedBackend(BaseCache[bytes]): def __init__(self, endpoint="127.0.0.1", port=11211, pool_size=2, **kwargs): super().__init__(**kwargs) self.endpoint = endpoint @@ -130,7 +131,7 @@ class MemcachedCache(MemcachedBackend): :param serializer: obj derived from :class:`aiocache.serializers.BaseSerializer`. :param plugins: list of :class:`aiocache.plugins.BasePlugin` derived classes. :param namespace: string to use as default prefix for the key used in all operations of - the backend. Default is None + the backend. Default is an empty string, "". :param timeout: int or float in seconds specifying maximum timeout for the operations to last. By default its 5. :param endpoint: str with the endpoint to connect to. Default is 127.0.0.1. @@ -147,8 +148,8 @@ def __init__(self, serializer=None, **kwargs): def parse_uri_path(cls, path): return {} - def _build_key(self, key, namespace=None): - ns_key = super()._build_key(key, namespace=namespace).replace(" ", "_") + def build_key(self, key: str, namespace: Optional[str] = None) -> bytes: + ns_key = self._str_build_key(key, namespace).replace(" ", "_") return str.encode(ns_key) def __repr__(self): # pragma: no cover diff --git a/aiocache/backends/memory.py b/aiocache/backends/memory.py index 7d75d7ee..eafef29e 100644 --- a/aiocache/backends/memory.py +++ b/aiocache/backends/memory.py @@ -1,11 +1,11 @@ import asyncio -from typing import Dict +from typing import Dict, Optional from aiocache.base import BaseCache from aiocache.serializers import NullSerializer -class SimpleMemoryBackend(BaseCache): +class SimpleMemoryBackend(BaseCache[str]): """ Wrapper around dict operations to use it as a cache backend """ @@ -118,7 +118,7 @@ class SimpleMemoryCache(SimpleMemoryBackend): :param serializer: obj derived from :class:`aiocache.serializers.BaseSerializer`. :param plugins: list of :class:`aiocache.plugins.BasePlugin` derived classes. :param namespace: string to use as default prefix for the key used in all operations of - the backend. Default is None. + the backend. Default is an empty string, "". :param timeout: int or float in seconds specifying maximum timeout for the operations to last. By default its 5. """ @@ -131,3 +131,6 @@ def __init__(self, serializer=None, **kwargs): @classmethod def parse_uri_path(cls, path): return {} + + def build_key(self, key: str, namespace: Optional[str] = None) -> str: + return self._str_build_key(key, namespace) diff --git a/aiocache/backends/redis.py b/aiocache/backends/redis.py index b0e181f9..408989a1 100644 --- a/aiocache/backends/redis.py +++ b/aiocache/backends/redis.py @@ -1,17 +1,21 @@ import itertools import warnings +from typing import Any, Callable, Optional, TYPE_CHECKING import redis.asyncio as redis from redis.exceptions import ResponseError as IncrbyException -from aiocache.base import BaseCache, _ensure_key +from aiocache.base import BaseCache from aiocache.serializers import JsonSerializer +if TYPE_CHECKING: # pragma: no cover + from aiocache.serializers import BaseSerializer + _NOT_SET = object() -class RedisBackend(BaseCache): +class RedisBackend(BaseCache[str]): RELEASE_SCRIPT = ( "if redis.call('get',KEYS[1]) == ARGV[1] then" " return redis.call('del',KEYS[1])" @@ -186,7 +190,7 @@ class RedisCache(RedisBackend): :param serializer: obj derived from :class:`aiocache.serializers.BaseSerializer`. :param plugins: list of :class:`aiocache.plugins.BasePlugin` derived classes. :param namespace: string to use as default prefix for the key used in all operations of - the backend. Default is None. + the backend. Default is an empty string, "". :param timeout: int or float in seconds specifying maximum timeout for the operations to last. By default its 5. :param endpoint: str with the endpoint to connect to. Default is "127.0.0.1". @@ -199,8 +203,21 @@ class RedisCache(RedisBackend): NAME = "redis" - def __init__(self, serializer=None, **kwargs): - super().__init__(serializer=serializer or JsonSerializer(), **kwargs) + def __init__( + self, + serializer: Optional["BaseSerializer"] = None, + namespace: str = "", + key_builder: Optional[Callable[[str, str], str]] = None, + **kwargs: Any, + ): + super().__init__( + serializer=serializer or JsonSerializer(), + namespace=namespace, + key_builder=key_builder or ( + lambda key, namespace: f"{namespace}:{key}" if namespace else key + ), + **kwargs, + ) @classmethod def parse_uri_path(cls, path): @@ -218,14 +235,8 @@ def parse_uri_path(cls, path): options["db"] = db return options - def _build_key(self, key, namespace=None): - if namespace is not None: - return "{}{}{}".format( - namespace, ":" if namespace else "", _ensure_key(key)) - if self.namespace is not None: - return "{}{}{}".format( - self.namespace, ":" if self.namespace else "", _ensure_key(key)) - return key - def __repr__(self): # pragma: no cover return "RedisCache ({}:{})".format(self.endpoint, self.port) + + def build_key(self, key: str, namespace: Optional[str] = None) -> str: + return self._str_build_key(key, namespace) diff --git a/aiocache/base.py b/aiocache/base.py index 54e9d18c..ce44dfbe 100644 --- a/aiocache/base.py +++ b/aiocache/base.py @@ -3,16 +3,22 @@ import logging import os import time +from abc import abstractmethod from enum import Enum from types import TracebackType -from typing import Callable, Optional, Set, Type +from typing import Callable, Generic, List, Optional, Set, TYPE_CHECKING, Type, TypeVar -from aiocache import serializers +from aiocache.serializers import StringSerializer + +if TYPE_CHECKING: # pragma: no cover + from aiocache.plugins import BasePlugin + from aiocache.serializers import BaseSerializer logger = logging.getLogger(__name__) SENTINEL = object() +CacheKeyType = TypeVar("CacheKeyType") class API: @@ -87,7 +93,7 @@ async def _plugins(self, *args, **kwargs): return _plugins -class BaseCache: +class BaseCache(Generic[CacheKeyType]): """ Base class that agregates the common logic for the different caches that may exist. Cache related available options are: @@ -97,9 +103,9 @@ class BaseCache: :param plugins: list of :class:`aiocache.plugins.BasePlugin` derived classes. Default is empty list. :param namespace: string to use as default prefix for the key used in all operations of - the backend. Default is None + the backend. Default is an empty string, "". :param key_builder: alternative callable to build the key. Receives the key and the namespace - as params and should return something that can be used as key by the underlying backend. + as params and should return a string that can be used as a key by the underlying backend. :param timeout: int or float in seconds specifying maximum timeout for the operations to last. By default its 5. Use 0 or None if you want to disable it. :param ttl: int the expiration time in seconds to use as a default in all operations of @@ -109,18 +115,22 @@ class BaseCache: NAME: str def __init__( - self, serializer=None, plugins=None, namespace=None, key_builder=None, timeout=5, ttl=None + self, + serializer: Optional["BaseSerializer"] = None, + plugins: Optional[List["BasePlugin"]] = None, + namespace: str = "", + key_builder: Callable[[str, str], str] = lambda key, namespace: f"{namespace}{key}", + timeout: Optional[float] = 5, + ttl: Optional[float] = None, ): - self.timeout = float(timeout) if timeout is not None else timeout - self.namespace = namespace - self.ttl = float(ttl) if ttl is not None else ttl - self.build_key = key_builder or self._build_key + self.timeout = float(timeout) if timeout is not None else None + self.ttl = float(ttl) if ttl is not None else None - self._serializer = None - self.serializer = serializer or serializers.StringSerializer() + self.namespace = namespace + self._build_key = key_builder - self._plugins = None - self.plugins = plugins or [] + self._serializer = serializer or StringSerializer() + self._plugins = plugins or [] @property def serializer(self): @@ -162,9 +172,8 @@ async def add(self, key, value, ttl=SENTINEL, dumps_fn=None, namespace=None, _co - :class:`asyncio.TimeoutError` if it lasts more than self.timeout """ start = time.monotonic() - dumps = dumps_fn or self._serializer.dumps - ns = namespace if namespace is not None else self.namespace - ns_key = self.build_key(key, namespace=ns) + dumps = dumps_fn or self.serializer.dumps + ns_key = self.build_key(key, namespace) await self._add(ns_key, dumps(value), ttl=self._get_ttl(ttl), _conn=_conn) @@ -192,9 +201,8 @@ async def get(self, key, default=None, loads_fn=None, namespace=None, _conn=None :raises: :class:`asyncio.TimeoutError` if it lasts more than self.timeout """ start = time.monotonic() - loads = loads_fn or self._serializer.loads - ns = namespace if namespace is not None else self.namespace - ns_key = self.build_key(key, namespace=ns) + loads = loads_fn or self.serializer.loads + ns_key = self.build_key(key, namespace) value = loads(await self._get(ns_key, encoding=self.serializer.encoding, _conn=_conn)) @@ -224,10 +232,9 @@ async def multi_get(self, keys, loads_fn=None, namespace=None, _conn=None): :raises: :class:`asyncio.TimeoutError` if it lasts more than self.timeout """ start = time.monotonic() - loads = loads_fn or self._serializer.loads - ns = namespace if namespace is not None else self.namespace + loads = loads_fn or self.serializer.loads - ns_keys = [self.build_key(key, namespace=ns) for key in keys] + ns_keys = [self.build_key(key, namespace) for key in keys] values = [ loads(value) for value in await self._multi_get( @@ -269,9 +276,8 @@ async def set( :raises: :class:`asyncio.TimeoutError` if it lasts more than self.timeout """ start = time.monotonic() - dumps = dumps_fn or self._serializer.dumps - ns = namespace if namespace is not None else self.namespace - ns_key = self.build_key(key, namespace=ns) + dumps = dumps_fn or self.serializer.dumps + ns_key = self.build_key(key, namespace) res = await self._set( ns_key, dumps(value), ttl=self._get_ttl(ttl), _cas_token=_cas_token, _conn=_conn @@ -303,12 +309,11 @@ async def multi_set(self, pairs, ttl=SENTINEL, dumps_fn=None, namespace=None, _c :raises: :class:`asyncio.TimeoutError` if it lasts more than self.timeout """ start = time.monotonic() - dumps = dumps_fn or self._serializer.dumps - ns = namespace if namespace is not None else self.namespace + dumps = dumps_fn or self.serializer.dumps tmp_pairs = [] for key, value in pairs: - tmp_pairs.append((self.build_key(key, namespace=ns), dumps(value))) + tmp_pairs.append((self.build_key(key, namespace), dumps(value))) await self._multi_set(tmp_pairs, ttl=self._get_ttl(ttl), _conn=_conn) @@ -339,8 +344,7 @@ async def delete(self, key, namespace=None, _conn=None): :raises: :class:`asyncio.TimeoutError` if it lasts more than self.timeout """ start = time.monotonic() - ns = namespace if namespace is not None else self.namespace - ns_key = self.build_key(key, namespace=ns) + ns_key = self.build_key(key, namespace) ret = await self._delete(ns_key, _conn=_conn) logger.debug("DELETE %s %d (%.4f)s", ns_key, ret, time.monotonic() - start) return ret @@ -364,8 +368,7 @@ async def exists(self, key, namespace=None, _conn=None): :raises: :class:`asyncio.TimeoutError` if it lasts more than self.timeout """ start = time.monotonic() - ns = namespace if namespace is not None else self.namespace - ns_key = self.build_key(key, namespace=ns) + ns_key = self.build_key(key, namespace) ret = await self._exists(ns_key, _conn=_conn) logger.debug("EXISTS %s %d (%.4f)s", ns_key, ret, time.monotonic() - start) return ret @@ -392,8 +395,7 @@ async def increment(self, key, delta=1, namespace=None, _conn=None): :raises: :class:`TypeError` if value is not incrementable """ start = time.monotonic() - ns = namespace if namespace is not None else self.namespace - ns_key = self.build_key(key, namespace=ns) + ns_key = self.build_key(key, namespace) ret = await self._increment(ns_key, delta, _conn=_conn) logger.debug("INCREMENT %s %d (%.4f)s", ns_key, ret, time.monotonic() - start) return ret @@ -418,8 +420,7 @@ async def expire(self, key, ttl, namespace=None, _conn=None): :raises: :class:`asyncio.TimeoutError` if it lasts more than self.timeout """ start = time.monotonic() - ns = namespace if namespace is not None else self.namespace - ns_key = self.build_key(key, namespace=ns) + ns_key = self.build_key(key, namespace) ret = await self._expire(ns_key, ttl, _conn=_conn) logger.debug("EXPIRE %s %d (%.4f)s", ns_key, ret, time.monotonic() - start) return ret @@ -498,12 +499,15 @@ async def close(self, *args, _conn=None, **kwargs): async def _close(self, *args, **kwargs): pass - def _build_key(self, key, namespace=None): - if namespace is not None: - return "{}{}".format(namespace, _ensure_key(key)) - if self.namespace is not None: - return "{}{}".format(self.namespace, _ensure_key(key)) - return key + @abstractmethod + def build_key(self, key: str, namespace: Optional[str] = None) -> CacheKeyType: + raise NotImplementedError() + + def _str_build_key(self, key: str, namespace: Optional[str] = None) -> str: + """Simple key builder that can be used in subclasses for build_key().""" + key_name = key.value if isinstance(key, Enum) else key + ns = self.namespace if namespace is None else namespace + return self._build_key(key_name, ns) def _get_ttl(self, ttl): return ttl if ttl is not SENTINEL else self.ttl @@ -550,12 +554,5 @@ async def _do_inject_conn(self, *args, **kwargs): return _do_inject_conn -def _ensure_key(key): - if isinstance(key, Enum): - return key.value - else: - return key - - for cmd in API.CMDS: setattr(_Conn, cmd.__name__, _Conn._inject_conn(cmd.__name__)) diff --git a/aiocache/decorators.py b/aiocache/decorators.py index 11122cda..803a8796 100644 --- a/aiocache/decorators.py +++ b/aiocache/decorators.py @@ -35,7 +35,7 @@ class cached: :param ttl: int seconds to store the function call. Default is None which means no expiration. :param namespace: string to use as default prefix for the key used in all operations of - the backend. Default is None + the backend. Default is an empty string, "". :param key_builder: Callable that allows to build the function dynamically. It receives the function plus same args and kwargs passed to the function. This behavior is necessarily different than ``BaseCache.build_key()`` @@ -60,7 +60,7 @@ class cached: def __init__( self, ttl=SENTINEL, - namespace=None, + namespace="", key_builder=None, skip_cache_func=lambda x: False, cache=Cache.MEMORY, @@ -144,7 +144,7 @@ def _key_from_args(self, func, args, kwargs): + str(ordered_kwargs) ) - async def get_from_cache(self, key: str): + async def get_from_cache(self, key): try: return await self.cache.get(key) except Exception: @@ -176,7 +176,7 @@ class cached_stampede(cached): :param ttl: int seconds to store the function call. Default is None which means no expiration. :param key_from_attr: str arg or kwarg name from the function to use as a key. :param namespace: string to use as default prefix for the key used in all operations of - the backend. Default is None + the backend. Default is an empty string, "". :param key_builder: Callable that allows to build the function dynamically. It receives the function plus same args and kwargs passed to the function. This behavior is necessarily different than ``BaseCache.build_key()`` @@ -278,7 +278,7 @@ class multi_cached: :param keys_from_attr: name of the arg or kwarg in the decorated callable that contains an iterable that yields the keys returned by the decorated callable. :param namespace: string to use as default prefix for the key used in all operations of - the backend. Default is None + the backend. Default is an empty string, "". :param key_builder: Callable that enables mapping the decorated function's keys to the keys used by the cache. Receives a key from the iterable corresponding to ``keys_from_attr``, the decorated callable, and the positional and keyword arguments @@ -303,7 +303,7 @@ class multi_cached: def __init__( self, keys_from_attr, - namespace=None, + namespace="", key_builder=None, skip_cache_func=lambda k, v: False, ttl=SENTINEL, diff --git a/aiocache/lock.py b/aiocache/lock.py index d7e7e1cb..34e2299c 100644 --- a/aiocache/lock.py +++ b/aiocache/lock.py @@ -1,11 +1,11 @@ import asyncio import uuid -from typing import Any, Dict, Union +from typing import Any, Dict, Generic, Union -from aiocache.base import BaseCache +from aiocache.base import BaseCache, CacheKeyType -class RedLock: +class RedLock(Generic[CacheKeyType]): """ Implementation of `Redlock `_ with a single instance because aiocache is focused on single @@ -62,7 +62,7 @@ class RedLock: _EVENTS: Dict[str, asyncio.Event] = {} - def __init__(self, client: BaseCache, key: str, lease: Union[int, float]): + def __init__(self, client: BaseCache[CacheKeyType], key: str, lease: Union[int, float]): self.client = client self.key = self.client.build_key(key + "-lock") self.lease = lease @@ -96,7 +96,7 @@ async def _release(self): RedLock._EVENTS.pop(self.key).set() -class OptimisticLock: +class OptimisticLock(Generic[CacheKeyType]): """ Implementation of `optimistic lock `_ @@ -133,7 +133,7 @@ class OptimisticLock: If the lock is created with an unexisting key, there will never be conflicts. """ - def __init__(self, client: BaseCache, key: str): + def __init__(self, client: BaseCache[CacheKeyType], key: str): self.client = client self.key = key self.ns_key = self.client.build_key(key) diff --git a/examples/alt_key_builder.py b/examples/alt_key_builder.py index 81dc09ad..de13a9ea 100644 --- a/examples/alt_key_builder.py +++ b/examples/alt_key_builder.py @@ -21,7 +21,7 @@ Args: key (str): undecorated key name - namespace (str, optional): Prefix to add to the key. Defaults to None. + namespace (str, optional): Prefix to add to the key. Defaults to "". Returns: By default, ``cache.build_key()`` returns ``f'{namespace}{sep}{key}'``, @@ -59,21 +59,21 @@ async def demo_key_builders(): # 1. Custom ``key_builder`` for a cache # ------------------------------------- -def ensure_no_spaces(key, namespace=None, replace='_'): +def ensure_no_spaces(key, namespace, replace="_"): """Prefix key with namespace; replace each space with ``replace``""" - aggregate_key = f"{namespace or ''}{key}" + aggregate_key = f"{namespace}{key}" custom_key = aggregate_key.replace(' ', replace) return custom_key -def bytes_key(key, namespace=None): +def bytes_key(key, namespace): """Prefix key with namespace; convert output to bytes""" - aggregate_key = f"{namespace or ''}{key}" + aggregate_key = f"{namespace}{key}" custom_key = aggregate_key.encode() return custom_key -def fixed_key(key, namespace=None): +def fixed_key(key, namespace): """Ignore input, generate a fixed key""" unchanging_key = "universal key" return unchanging_key diff --git a/tests/acceptance/test_decorators.py b/tests/acceptance/test_decorators.py index f9608e56..ad99aca7 100644 --- a/tests/acceptance/test_decorators.py +++ b/tests/acceptance/test_decorators.py @@ -5,8 +5,7 @@ import pytest from aiocache import cached, cached_stampede, multi_cached -from aiocache.base import _ensure_key -from ..utils import Keys +from ..utils import Keys, ensure_key async def return_dict(keys=None): @@ -164,15 +163,15 @@ async def fn(keys): async def test_multi_cached_key_builder(self, cache): def build_key(key, f, self, keys, market="ES"): - return "{}_{}_{}".format(f.__name__, _ensure_key(key), market) + return "{}_{}_{}".format(f.__name__, ensure_key(key), market) @multi_cached(keys_from_attr="keys", key_builder=build_key) async def fn(self, keys, market="ES"): return {Keys.KEY: 1, Keys.KEY_1: 2} await fn("self", keys=[Keys.KEY, Keys.KEY_1]) - assert await cache.exists("fn_" + _ensure_key(Keys.KEY) + "_ES") is True - assert await cache.exists("fn_" + _ensure_key(Keys.KEY_1) + "_ES") is True + assert await cache.exists("fn_" + ensure_key(Keys.KEY) + "_ES") is True + assert await cache.exists("fn_" + ensure_key(Keys.KEY_1) + "_ES") is True async def test_multi_cached_skip_keys(self, cache): @multi_cached(keys_from_attr="keys", skip_cache_func=lambda _, v: v is None) diff --git a/tests/acceptance/test_lock.py b/tests/acceptance/test_lock.py index 665fea90..f151d618 100644 --- a/tests/acceptance/test_lock.py +++ b/tests/acceptance/test_lock.py @@ -200,7 +200,7 @@ def lock(self, cache): async def test_acquire(self, cache, lock): await cache.set(Keys.KEY, "value") async with lock: - assert lock._token == await cache._gets(cache._build_key(Keys.KEY)) + assert lock._token == await cache._gets(cache.build_key(Keys.KEY)) async def test_release_does_nothing(self, lock): assert await lock.__aexit__("exc_type", "exc_value", "traceback") is None diff --git a/tests/ut/backends/test_memcached.py b/tests/ut/backends/test_memcached.py index 7e011618..81cd744c 100644 --- a/tests/ut/backends/test_memcached.py +++ b/tests/ut/backends/test_memcached.py @@ -4,9 +4,9 @@ import pytest from aiocache.backends.memcached import MemcachedBackend, MemcachedCache -from aiocache.base import BaseCache, _ensure_key +from aiocache.base import BaseCache from aiocache.serializers import JsonSerializer -from ...utils import Keys +from ...utils import Keys, ensure_key @pytest.fixture @@ -249,10 +249,10 @@ def test_parse_uri_path(self): @pytest.mark.parametrize( "namespace, expected", - ([None, "test" + _ensure_key(Keys.KEY)], ["", _ensure_key(Keys.KEY)], ["my_ns", "my_ns" + _ensure_key(Keys.KEY)]), # type: ignore[attr-defined] # noqa: B950 + ([None, "test" + ensure_key(Keys.KEY)], ["", ensure_key(Keys.KEY)], ["my_ns", "my_ns" + ensure_key(Keys.KEY)]), # noqa: B950 ) def test_build_key_bytes(self, set_test_namespace, memcached_cache, namespace, expected): - assert memcached_cache.build_key(Keys.KEY, namespace=namespace) == expected.encode() + assert memcached_cache.build_key(Keys.KEY, namespace) == expected.encode() def test_build_key_no_namespace(self, memcached_cache): assert memcached_cache.build_key(Keys.KEY, namespace=None) == Keys.KEY.encode() diff --git a/tests/ut/backends/test_redis.py b/tests/ut/backends/test_redis.py index 2470584a..c6ad755a 100644 --- a/tests/ut/backends/test_redis.py +++ b/tests/ut/backends/test_redis.py @@ -5,9 +5,9 @@ from redis.exceptions import ResponseError from aiocache.backends.redis import RedisBackend, RedisCache -from aiocache.base import BaseCache, _ensure_key +from aiocache.base import BaseCache from aiocache.serializers import JsonSerializer -from ...utils import Keys +from ...utils import Keys, ensure_key @pytest.fixture @@ -253,10 +253,10 @@ def test_parse_uri_path(self, path, expected): @pytest.mark.parametrize( "namespace, expected", - ([None, "test:" + _ensure_key(Keys.KEY)], ["", _ensure_key(Keys.KEY)], ["my_ns", "my_ns:" + _ensure_key(Keys.KEY)]), # noqa: B950 + ([None, "test:" + ensure_key(Keys.KEY)], ["", ensure_key(Keys.KEY)], ["my_ns", "my_ns:" + ensure_key(Keys.KEY)]), # noqa: B950 ) def test_build_key_double_dot(self, set_test_namespace, redis_cache, namespace, expected): - assert redis_cache.build_key(Keys.KEY, namespace=namespace) == expected + assert redis_cache.build_key(Keys.KEY, namespace) == expected def test_build_key_no_namespace(self, redis_cache): assert redis_cache.build_key(Keys.KEY, namespace=None) == Keys.KEY diff --git a/tests/ut/conftest.py b/tests/ut/conftest.py index 0adcd1f4..385294c9 100644 --- a/tests/ut/conftest.py +++ b/tests/ut/conftest.py @@ -29,14 +29,14 @@ def reset_caches(): @pytest.fixture def mock_cache(mocker): - return create_autospec(BaseCache()) + return create_autospec(BaseCache[str]()) @pytest.fixture def mock_base_cache(): """Return BaseCache instance with unimplemented methods mocked out.""" plugin = create_autospec(BasePlugin, instance=True) - cache = BaseCache(timeout=0.002, plugins=(plugin,)) + cache = BaseCache[str](timeout=0.002, plugins=(plugin,)) methods = ("_add", "_get", "_gets", "_set", "_multi_get", "_multi_set", "_delete", "_exists", "_increment", "_expire", "_clear", "_raw", "_close", "_redlock_release", "acquire_conn", "release_conn") @@ -44,12 +44,22 @@ def mock_base_cache(): for f in methods: stack.enter_context(patch.object(cache, f, autospec=True)) stack.enter_context(patch.object(cache, "_serializer", autospec=True)) + stack.enter_context(patch.object(cache, "build_key", cache._str_build_key)) yield cache +@pytest.fixture +def abstract_base_cache(): + # TODO: Is there need for a separate BaseCache[bytes] fixture? + return BaseCache[str]() + + @pytest.fixture def base_cache(): - return BaseCache() + # TODO: Is there need for a separate BaseCache[bytes] fixture? + cache = BaseCache[str]() + cache.build_key = cache._str_build_key + return cache @pytest.fixture diff --git a/tests/ut/test_base.py b/tests/ut/test_base.py index 569b5245..3449dabf 100644 --- a/tests/ut/test_base.py +++ b/tests/ut/test_base.py @@ -4,8 +4,8 @@ import pytest -from aiocache.base import API, BaseCache, _Conn, _ensure_key -from ..utils import Keys +from aiocache.base import API, BaseCache, _Conn +from ..utils import Keys, ensure_key class TestAPI: @@ -137,11 +137,11 @@ async def dummy(self, *args, **kwargs): class TestBaseCache: def test_str_ttl(self): - cache = BaseCache(ttl="1.5") + cache = BaseCache[str](ttl="1.5") assert cache.ttl == 1.5 def test_str_timeout(self): - cache = BaseCache(timeout="1.5") + cache = BaseCache[str](timeout="1.5") assert cache.timeout == 1.5 async def test_add(self, base_cache): @@ -197,49 +197,62 @@ async def test_acquire_conn(self, base_cache): async def test_release_conn(self, base_cache): assert await base_cache.release_conn("mock") is None + def test_abstract_build_key(self, abstract_base_cache): + with pytest.raises(NotImplementedError): + abstract_base_cache.build_key(Keys.KEY) + @pytest.fixture def set_test_namespace(self, base_cache): base_cache.namespace = "test" yield - base_cache.namespace = None + base_cache.namespace = "" + + @pytest.mark.parametrize( + "namespace, expected", + ([None, "None" + ensure_key(Keys.KEY)], ["", ensure_key(Keys.KEY)], ["my_ns", "my_ns" + ensure_key(Keys.KEY)]), # noqa: B950 + ) + def test_str_build_key(self, set_test_namespace, namespace, expected): + # TODO: Runtime check for namespace=None: Raise ValueError or replace with ""? + cache = BaseCache[str](namespace=namespace) + assert cache._str_build_key(Keys.KEY) == expected @pytest.mark.parametrize( "namespace, expected", - ([None, "test" + _ensure_key(Keys.KEY)], ["", _ensure_key(Keys.KEY)], ["my_ns", "my_ns" + _ensure_key(Keys.KEY)]), # type: ignore[attr-defined] # noqa: B950 + ([None, "test" + ensure_key(Keys.KEY)], ["", ensure_key(Keys.KEY)], ["my_ns", "my_ns" + ensure_key(Keys.KEY)]), # noqa: B950 ) def test_build_key(self, set_test_namespace, base_cache, namespace, expected): - assert base_cache.build_key(Keys.KEY, namespace=namespace) == expected + assert base_cache.build_key(Keys.KEY, namespace) == expected + + def patch_str_build_key(self, cache: BaseCache[str]) -> None: + """Implement build_key() on BaseCache[str] as if it were subclassed""" + cache.build_key = cache._str_build_key # type: ignore[assignment] + return def test_alt_build_key(self): - cache = BaseCache(key_builder=lambda key, namespace: "x") + cache = BaseCache[str](key_builder=lambda key, namespace: "x") + self.patch_str_build_key(cache) assert cache.build_key(Keys.KEY, "namespace") == "x" - @pytest.fixture - def alt_base_cache(self, init_namespace="test"): + def alt_build_key(self, key, namespace): """Custom key_builder for cache""" - def build_key(key, namespace=None): - ns = namespace if namespace is not None else "" - sep = ":" if namespace else "" - return f"{ns}{sep}{_ensure_key(key)}" - - cache = BaseCache(key_builder=build_key, namespace=init_namespace) - return cache + sep = ":" if namespace else "" + return f"{namespace}{sep}{ensure_key(key)}" @pytest.mark.parametrize( "namespace, expected", - ([None, _ensure_key(Keys.KEY)], ["", _ensure_key(Keys.KEY)], ["my_ns", "my_ns:" + _ensure_key(Keys.KEY)]), # type: ignore[attr-defined] # noqa: B950 + ([None, "test:" + ensure_key(Keys.KEY)], ["", ensure_key(Keys.KEY)], ["my_ns", "my_ns:" + ensure_key(Keys.KEY)]), # noqa: B950 ) - def test_alt_build_key_override_namespace(self, alt_base_cache, namespace, expected): + def test_alt_build_key_override_namespace(self, namespace, expected): """Custom key_builder overrides namespace of cache""" - cache = alt_base_cache - assert cache.build_key(Keys.KEY, namespace=namespace) == expected + cache = BaseCache[str](key_builder=self.alt_build_key, namespace="test") + self.patch_str_build_key(cache) + assert cache.build_key(Keys.KEY, namespace) == expected @pytest.mark.parametrize( - "init_namespace, expected", - ([None, _ensure_key(Keys.KEY)], ["", _ensure_key(Keys.KEY)], ["test", "test:" + _ensure_key(Keys.KEY)]), # type: ignore[attr-defined] # noqa: B950 + "namespace, expected", + ([None, "None" + ensure_key(Keys.KEY)], ["", ensure_key(Keys.KEY)], ["test", "test:" + ensure_key(Keys.KEY)]), # noqa: B950 ) - async def test_alt_build_key_default_namespace( - self, init_namespace, alt_base_cache, expected): + async def test_alt_build_key_default_namespace(self, namespace, expected): """Custom key_builder for cache with or without namespace specified. Cache member functions that accept a ``namespace`` parameter @@ -251,8 +264,8 @@ async def test_alt_build_key_default_namespace( even when that cache is supplied to a lock or to a decorator using the ``alias`` argument. """ - cache = alt_base_cache - cache.namespace = init_namespace + cache = BaseCache[str](key_builder=self.alt_build_key, namespace=namespace) + self.patch_str_build_key(cache) # Verify that private members are called with the correct ns_key await self._assert_add__alt_build_key_default_namespace(cache, expected) @@ -429,7 +442,7 @@ async def test_get(self, mock_base_cache): await mock_base_cache.get(Keys.KEY) mock_base_cache._get.assert_called_with( - mock_base_cache._build_key(Keys.KEY), encoding=ANY, _conn=ANY + mock_base_cache.build_key(Keys.KEY), encoding=ANY, _conn=ANY ) assert mock_base_cache.plugins[0].pre_get.call_count == 1 assert mock_base_cache.plugins[0].post_get.call_count == 1 @@ -454,7 +467,7 @@ async def test_set(self, mock_base_cache): await mock_base_cache.set(Keys.KEY, "value", ttl=2) mock_base_cache._set.assert_called_with( - mock_base_cache._build_key(Keys.KEY), ANY, ttl=2, _cas_token=None, _conn=ANY + mock_base_cache.build_key(Keys.KEY), ANY, ttl=2, _cas_token=None, _conn=ANY ) assert mock_base_cache.plugins[0].pre_set.call_count == 1 assert mock_base_cache.plugins[0].post_set.call_count == 1 @@ -469,7 +482,7 @@ async def test_add(self, mock_base_cache): mock_base_cache._exists = AsyncMock(return_value=False) await mock_base_cache.add(Keys.KEY, "value", ttl=2) - key = mock_base_cache._build_key(Keys.KEY) + key = mock_base_cache.build_key(Keys.KEY) mock_base_cache._add.assert_called_with(key, ANY, ttl=2, _conn=ANY) assert mock_base_cache.plugins[0].pre_add.call_count == 1 assert mock_base_cache.plugins[0].post_add.call_count == 1 @@ -484,7 +497,7 @@ async def test_mget(self, mock_base_cache): await mock_base_cache.multi_get([Keys.KEY, Keys.KEY_1]) mock_base_cache._multi_get.assert_called_with( - [mock_base_cache._build_key(Keys.KEY), mock_base_cache._build_key(Keys.KEY_1)], + [mock_base_cache.build_key(Keys.KEY), mock_base_cache.build_key(Keys.KEY_1)], encoding=ANY, _conn=ANY, ) @@ -500,8 +513,8 @@ async def test_mget_timeouts(self, mock_base_cache): async def test_mset(self, mock_base_cache): await mock_base_cache.multi_set([[Keys.KEY, "value"], [Keys.KEY_1, "value1"]], ttl=2) - key = mock_base_cache._build_key(Keys.KEY) - key1 = mock_base_cache._build_key(Keys.KEY_1) + key = mock_base_cache.build_key(Keys.KEY) + key1 = mock_base_cache.build_key(Keys.KEY_1) mock_base_cache._multi_set.assert_called_with( [(key, ANY), (key1, ANY)], ttl=2, _conn=ANY) assert mock_base_cache.plugins[0].pre_multi_set.call_count == 1 @@ -516,7 +529,7 @@ async def test_mset_timeouts(self, mock_base_cache): async def test_exists(self, mock_base_cache): await mock_base_cache.exists(Keys.KEY) - mock_base_cache._exists.assert_called_with(mock_base_cache._build_key(Keys.KEY), _conn=ANY) + mock_base_cache._exists.assert_called_with(mock_base_cache.build_key(Keys.KEY), _conn=ANY) assert mock_base_cache.plugins[0].pre_exists.call_count == 1 assert mock_base_cache.plugins[0].post_exists.call_count == 1 @@ -529,7 +542,7 @@ async def test_exists_timeouts(self, mock_base_cache): async def test_increment(self, mock_base_cache): await mock_base_cache.increment(Keys.KEY, 2) - key = mock_base_cache._build_key(Keys.KEY) + key = mock_base_cache.build_key(Keys.KEY) mock_base_cache._increment.assert_called_with(key, 2, _conn=ANY) assert mock_base_cache.plugins[0].pre_increment.call_count == 1 assert mock_base_cache.plugins[0].post_increment.call_count == 1 @@ -543,7 +556,7 @@ async def test_increment_timeouts(self, mock_base_cache): async def test_delete(self, mock_base_cache): await mock_base_cache.delete(Keys.KEY) - mock_base_cache._delete.assert_called_with(mock_base_cache._build_key(Keys.KEY), _conn=ANY) + mock_base_cache._delete.assert_called_with(mock_base_cache.build_key(Keys.KEY), _conn=ANY) assert mock_base_cache.plugins[0].pre_delete.call_count == 1 assert mock_base_cache.plugins[0].post_delete.call_count == 1 @@ -555,7 +568,7 @@ async def test_delete_timeouts(self, mock_base_cache): async def test_expire(self, mock_base_cache): await mock_base_cache.expire(Keys.KEY, 1) - key = mock_base_cache._build_key(Keys.KEY) + key = mock_base_cache.build_key(Keys.KEY) mock_base_cache._expire.assert_called_with(key, 1, _conn=ANY) assert mock_base_cache.plugins[0].pre_expire.call_count == 1 assert mock_base_cache.plugins[0].post_expire.call_count == 1 @@ -568,7 +581,7 @@ async def test_expire_timeouts(self, mock_base_cache): async def test_clear(self, mock_base_cache): await mock_base_cache.clear(Keys.KEY) - mock_base_cache._clear.assert_called_with(mock_base_cache._build_key(Keys.KEY), _conn=ANY) + mock_base_cache._clear.assert_called_with(mock_base_cache.build_key(Keys.KEY), _conn=ANY) assert mock_base_cache.plugins[0].pre_clear.call_count == 1 assert mock_base_cache.plugins[0].post_clear.call_count == 1 @@ -581,7 +594,7 @@ async def test_clear_timeouts(self, mock_base_cache): async def test_raw(self, mock_base_cache): await mock_base_cache.raw("get", Keys.KEY) mock_base_cache._raw.assert_called_with( - "get", mock_base_cache._build_key(Keys.KEY), encoding=ANY, _conn=ANY + "get", mock_base_cache.build_key(Keys.KEY), encoding=ANY, _conn=ANY ) assert mock_base_cache.plugins[0].pre_raw.call_count == 1 assert mock_base_cache.plugins[0].post_raw.call_count == 1 diff --git a/tests/ut/test_lock.py b/tests/ut/test_lock.py index 1bac2bff..a06ea567 100644 --- a/tests/ut/test_lock.py +++ b/tests/ut/test_lock.py @@ -86,7 +86,7 @@ def test_init(self, mock_base_cache, lock): assert lock.client == mock_base_cache assert lock._token is None assert lock.key == Keys.KEY - assert lock.ns_key == mock_base_cache._build_key(Keys.KEY) + assert lock.ns_key == mock_base_cache.build_key(Keys.KEY) async def test_aenter_returns_lock(self, lock): assert await lock.__aenter__() is lock diff --git a/tests/utils.py b/tests/utils.py index ab884c7f..a6f622fe 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -7,3 +7,10 @@ class Keys(str, Enum): KEY_LOCK = Keys.KEY + "-lock" + + +def ensure_key(key): + if isinstance(key, Enum): + return key.value + else: + return key