diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index a311d54..5647fda 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -15,8 +15,9 @@ jobs: runs-on: "ubuntu-latest" strategy: + fail-fast: false matrix: - python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] + python-version: [ "3.8", "3.9", "3.10", "3.11", "3.12"] steps: - uses: "actions/checkout@v3" @@ -37,7 +38,7 @@ jobs: key: python-${{ matrix.python-version }}-pydeps-${{ hashFiles('**/poetry.lock') }} - name: "Install Dependencies" if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' - run: "poetry install --no-interaction --no-root" + run: "poetry install --no-interaction --no-root --all-extras" - name: "Run Lint" run: "make lint" - name: "Run Tests" diff --git a/Makefile b/Makefile index d9a0be0..3b59ab9 100644 --- a/Makefile +++ b/Makefile @@ -8,6 +8,6 @@ benchmark: .PHONY: lint lint: - poetry run mypy --check-untyped-defs --ignore-missing-imports . + poetry run mypy . trace_bench: poetry run python -m benchmarks.trace_bench diff --git a/pyproject.toml b/pyproject.toml index 0d3266f..0c33318 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ classifiers = [ license = "BSD-3-Clause" [tool.poetry.dependencies] -python = "^3.7" +python = ">=3.8,<4.0" typing-extensions = "^4.4.0" theine-core = "^0.4.3" @@ -20,7 +20,7 @@ theine-core = "^0.4.3" pytest = "^7.2.1" pytest-benchmark = "^4.0.0" typing-extensions = "^4.4.0" -mypy = "^1.0.0" +mypy = "1.11.1" django = "^3.2" pytest-django = "^4.5.2" pytest-asyncio = "^0.20.3" @@ -30,6 +30,19 @@ isort = "^5.5.0" py-spy = "^0.3.14" cacheout = "^0.14.1" bounded-zipf = "^1.0.0" +django-stubs = "^5.0.2" + +[tool.mypy] +strict = true +plugins = ["mypy_django_plugin.main", ] +exclude = [ + "benchmarks", + "tests" +] + +[tool.django-stubs] +django_settings_module = 'theine' +strict_settings = false [build-system] requires = ["poetry-core"] diff --git a/tests/adapters/test_django.py b/tests/adapters/test_django.py index 2b1b718..f030fc0 100644 --- a/tests/adapters/test_django.py +++ b/tests/adapters/test_django.py @@ -16,23 +16,23 @@ def cache() -> Iterable[BaseCache]: class TestTheineCache: - def test_settings(self, cache: BaseCache): + def test_settings(self, cache: BaseCache) -> None: assert cache._max_entries == 1000 assert cache.default_timeout == 60 - def test_unicode_keys(self, cache: BaseCache): + def test_unicode_keys(self, cache: BaseCache) -> None: cache.set("ключ", "value") res = cache.get("ключ") assert res == "value" - def test_save_and_integer(self, cache: BaseCache): + def test_save_and_integer(self, cache: BaseCache) -> None: cache.set("test_key", 2) res = cache.get("test_key", "Foo") assert isinstance(res, int) assert res == 2 - def test_save_string(self, cache: BaseCache): + def test_save_string(self, cache: BaseCache) -> None: cache.set("test_key", "hello" * 1000) res = cache.get("test_key") @@ -45,14 +45,14 @@ def test_save_string(self, cache: BaseCache): assert isinstance(res, str) assert res == "2" - def test_save_unicode(self, cache: BaseCache): + def test_save_unicode(self, cache: BaseCache) -> None: cache.set("test_key", "heló") res = cache.get("test_key") assert isinstance(res, str) assert res == "heló" - def test_save_dict(self, cache: BaseCache): + def test_save_dict(self, cache: BaseCache) -> None: now_dt = datetime.datetime.now() test_dict = {"id": 1, "date": now_dt, "name": "Foo"} @@ -64,7 +64,7 @@ def test_save_dict(self, cache: BaseCache): assert res["name"] == "Foo" assert res["date"] == now_dt - def test_save_float(self, cache: BaseCache): + def test_save_float(self, cache: BaseCache) -> None: float_val = 1.345620002 cache.set("test_key", float_val) @@ -73,19 +73,19 @@ def test_save_float(self, cache: BaseCache): assert isinstance(res, float) assert res == float_val - def test_timeout(self, cache: BaseCache): + def test_timeout(self, cache: BaseCache) -> None: cache.set("test_key", 222, timeout=3) time.sleep(4) res = cache.get("test_key") assert res is None - def test_timeout_0(self, cache: BaseCache): + def test_timeout_0(self, cache: BaseCache) -> None: cache.set("test_key", 222, timeout=0) res = cache.get("test_key") assert res is None - def test_timeout_parameter_as_positional_argument(self, cache: BaseCache): + def test_timeout_parameter_as_positional_argument(self, cache: BaseCache) -> None: cache.set("test_key", 222, -1) res = cache.get("test_key") assert res is None @@ -97,7 +97,7 @@ def test_timeout_parameter_as_positional_argument(self, cache: BaseCache): assert res1 == 222 assert res2 is None - def test_timeout_negative(self, cache: BaseCache): + def test_timeout_negative(self, cache: BaseCache) -> None: cache.set("test_key", 222, timeout=-1) res = cache.get("test_key") assert res is None @@ -107,22 +107,19 @@ def test_timeout_negative(self, cache: BaseCache): res = cache.get("test_key") assert res is None - def test_timeout_tiny(self, cache: BaseCache): + def test_timeout_tiny(self, cache: BaseCache) -> None: cache.set("test_key", 222, timeout=0.00001) res = cache.get("test_key") assert res in (None, 222) - def test_set_add(self, cache: BaseCache): + def test_set_add(self, cache: BaseCache) -> None: cache.set("add_key", "Initial value") - res = cache.add("add_key", "New value") - assert res is False + assert cache.add("add_key", "New value") is False - res = cache.get("add_key") - assert res == "Initial value" - res = cache.add("other_key", "New value") - assert res is True + assert cache.get("add_key") == "Initial value" + assert cache.add("other_key", "New value") is True - def test_get_many(self, cache: BaseCache): + def test_get_many(self, cache: BaseCache) -> None: cache.set("a", 1) cache.set("b", 2) cache.set("c", 3) @@ -130,7 +127,7 @@ def test_get_many(self, cache: BaseCache): res = cache.get_many(["a", "b", "c"]) assert res == {"a": 1, "b": 2, "c": 3} - def test_get_many_unicode(self, cache: BaseCache): + def test_get_many_unicode(self, cache: BaseCache) -> None: cache.set("a", "1") cache.set("b", "2") cache.set("c", "3") @@ -138,40 +135,32 @@ def test_get_many_unicode(self, cache: BaseCache): res = cache.get_many(["a", "b", "c"]) assert res == {"a": "1", "b": "2", "c": "3"} - def test_set_many(self, cache: BaseCache): + def test_set_many(self, cache: BaseCache) -> None: cache.set_many({"a": 1, "b": 2, "c": 3}) res = cache.get_many(["a", "b", "c"]) assert res == {"a": 1, "b": 2, "c": 3} - def test_delete(self, cache: BaseCache): + def test_delete(self, cache: BaseCache) -> None: cache.set_many({"a": 1, "b": 2, "c": 3}) - res = cache.delete("a") - assert bool(res) is True - - res = cache.get_many(["a", "b", "c"]) - assert res == {"b": 2, "c": 3} + assert cache.delete("a") is True + assert cache.get_many(["a", "b", "c"]) == {"b": 2, "c": 3} + assert cache.delete("a") is False - res = cache.delete("a") - assert bool(res) is False - - def test_delete_many(self, cache: BaseCache): + def test_delete_many(self, cache: BaseCache) -> None: cache.set_many({"a": 1, "b": 2, "c": 3}) - res = cache.delete_many(["a", "b"]) - res = cache.get_many(["a", "b", "c"]) - assert res == {"c": 3} + cache.delete_many(["a", "b"]) + assert cache.get_many(["a", "b", "c"]) == {"c": 3} - def test_delete_many_generator(self, cache: BaseCache): + def test_delete_many_generator(self, cache: BaseCache) -> None: cache.set_many({"a": 1, "b": 2, "c": 3}) - res = cache.delete_many(key for key in ["a", "b"]) + cache.delete_many(key for key in ["a", "b"]) res = cache.get_many(["a", "b", "c"]) assert res == {"c": 3} - def test_delete_many_empty_generator(self, cache: BaseCache): - res = cache.delete_many(key for key in cast(List[str], [])) - assert bool(res) is False - - def test_incr(self, cache: BaseCache): + def test_delete_many_empty_generator(self, cache: BaseCache) -> None: + cache.delete_many(key for key in cast(List[str], [])) + def test_incr(self, cache: BaseCache) -> None: cache.set("num", 1) cache.incr("num") res = cache.get("num") @@ -198,7 +187,7 @@ def test_incr(self, cache: BaseCache): res = cache.get("num") assert res == 5 - def test_incr_no_timeout(self, cache: BaseCache): + def test_incr_no_timeout(self, cache: BaseCache) -> None: cache.set("num", 1, timeout=None) cache.incr("num") @@ -226,7 +215,7 @@ def test_incr_no_timeout(self, cache: BaseCache): res = cache.get("num") assert res == 5 - def test_get_set_bool(self, cache: BaseCache): + def test_get_set_bool(self, cache: BaseCache) -> None: cache.set("bool", True) res = cache.get("bool") @@ -239,7 +228,7 @@ def test_get_set_bool(self, cache: BaseCache): assert isinstance(res, bool) assert res is False - def test_version(self, cache: BaseCache): + def test_version(self, cache: BaseCache) -> None: cache.set("keytest", 2, version=2) res = cache.get("keytest") assert res is None @@ -247,7 +236,7 @@ def test_version(self, cache: BaseCache): res = cache.get("keytest", version=2) assert res == 2 - def test_incr_version(self, cache: BaseCache): + def test_incr_version(self, cache: BaseCache) -> None: cache.set("keytest", 2) cache.incr_version("keytest") @@ -257,7 +246,7 @@ def test_incr_version(self, cache: BaseCache): res = cache.get("keytest", version=2) assert res == 2 - def test_ttl_incr_version_no_timeout(self, cache: BaseCache): + def test_ttl_incr_version_no_timeout(self, cache: BaseCache) -> None: cache.set("my_key", "hello world!", timeout=None) cache.incr_version("my_key") @@ -266,14 +255,14 @@ def test_ttl_incr_version_no_timeout(self, cache: BaseCache): assert my_value == "hello world!" - def test_touch_zero_timeout(self, cache: BaseCache): + def test_touch_zero_timeout(self, cache: BaseCache) -> None: cache.set("test_key", 222, timeout=10) assert cache.touch("test_key", 0) is True res = cache.get("test_key") assert res is None - def test_touch_positive_timeout(self, cache: BaseCache): + def test_touch_positive_timeout(self, cache: BaseCache) -> None: cache.set("test_key", 222, timeout=10) assert cache.touch("test_key", 2) is True @@ -281,35 +270,35 @@ def test_touch_positive_timeout(self, cache: BaseCache): time.sleep(3) assert cache.get("test_key") is None - def test_touch_negative_timeout(self, cache: BaseCache): + def test_touch_negative_timeout(self, cache: BaseCache) -> None: cache.set("test_key", 222, timeout=10) assert cache.touch("test_key", -1) is True res = cache.get("test_key") assert res is None - def test_touch_missed_key(self, cache: BaseCache): + def test_touch_missed_key(self, cache: BaseCache) -> None: assert cache.touch("test_key_does_not_exist", 1) is False - def test_touch_forever(self, cache: Theine): + def test_touch_forever(self, cache: Theine) -> None: cache.set("test_key", "foo", timeout=1) result = cache.touch("test_key", None) assert result is True time.sleep(2) assert cache.get("test_key") == "foo" - def test_touch_forever_nonexistent(self, cache: BaseCache): + def test_touch_forever_nonexistent(self, cache: BaseCache) -> None: result = cache.touch("test_key_does_not_exist", None) assert result is False - def test_touch_default_timeout(self, cache: BaseCache): + def test_touch_default_timeout(self, cache: BaseCache) -> None: cache.set("test_key", "foo", timeout=1) result = cache.touch("test_key") assert result is True time.sleep(2) assert cache.get("test_key") == "foo" - def test_clear(self, cache: BaseCache): + def test_clear(self, cache: BaseCache) -> None: cache.set("foo", "bar") value_from_cache = cache.get("foo") assert value_from_cache == "bar" diff --git a/tests/test_cache.py b/tests/test_cache.py index 058af2e..0474323 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,19 +1,21 @@ from datetime import timedelta from random import randint from time import sleep -from bounded_zipf import Zipf +from typing import cast import pytest +from bounded_zipf import Zipf # type: ignore[import] +from pytest_asyncio.plugin import SubRequest from theine.theine import Cache, sentinel @pytest.fixture(params=["lru", "tlfu", "clockpro"]) -def policy(request): - return request.param +def policy(request: SubRequest) -> str: + return cast(str, request.param) -def test_set(policy): +def test_set(policy: str) -> None: cache = Cache(policy, 100) for i in range(20): key = f"key:{i}" @@ -36,7 +38,7 @@ def test_set(policy): assert len(cache) == 100 -def test_set_cache_size(policy): +def test_set_cache_size(policy: str) -> None: cache = Cache(policy, 500) for _ in range(100000): i = randint(0, 100000) @@ -44,7 +46,7 @@ def test_set_cache_size(policy): assert len([i for i in cache._cache if i is not sentinel]) == 500 -def test_set_with_ttl(policy): +def test_set_with_ttl(policy: str) -> None: cache = Cache(policy, 500) for i in range(30): key = f"key:{i}" @@ -67,7 +69,7 @@ def test_set_with_ttl(policy): assert f"key:{i}:2" in data -def test_delete(policy): +def test_delete(policy: str) -> None: cache = Cache(policy, 100) for i in range(20): key = f"key:{i}" @@ -79,11 +81,11 @@ def test_delete(policy): class Foo: - def __init__(self, id): + def __init__(self, id: int): self.id = id -def test_hashable_key(policy): +def test_hashable_key(policy: str) -> None: cache = Cache(policy, 100) foos = [Foo(i) for i in range(20)] for foo in foos: @@ -98,7 +100,7 @@ def test_hashable_key(policy): assert cache.key_gen.len() == 19 -def test_set_with_ttl_hashable(policy): +def test_set_with_ttl_hashable(policy: str) -> None: cache = Cache(policy, 500) foos = [Foo(i) for i in range(30)] for i in range(30): @@ -117,7 +119,7 @@ def test_set_with_ttl_hashable(policy): assert cache.key_gen.len() == 0 -def test_ttl_high_workload(policy): +def test_ttl_high_workload(policy: str) -> None: cache = Cache(policy, 500000) for i in range(500000): cache.set((i, 2), i, timedelta(seconds=randint(5, 10))) @@ -132,7 +134,7 @@ def test_ttl_high_workload(policy): assert len(cache.key_gen.hk) == 0 -def test_close_cache(policy): +def test_close_cache(policy: str) -> None: for _ in range(10): cache = Cache(policy, 500) cache.set("foo", "bar", timedelta(seconds=60)) @@ -140,7 +142,7 @@ def test_close_cache(policy): assert cache._maintainer.is_alive() is False -def test_cache_stats(policy): +def test_cache_stats(policy: str) -> None: cache = Cache(policy, 5000) assert cache.max_size == 5000 assert len(cache) == 0 diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 72abf66..f2c965a 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -8,13 +8,13 @@ from unittest.mock import Mock import pytest -from bounded_zipf import Zipf +from bounded_zipf import Zipf # type: ignore[import] from theine import Cache, Memoize @Memoize(Cache("tlfu", 1000), None) -def foo(id: int, m: Mock) -> Dict: +def foo(id: int, m: Mock) -> Dict[str, int]: m(id) return {"id": id} @@ -97,7 +97,7 @@ async def async_foo_auto(self, id: int, m: Mock) -> Dict: return {"id": id} -def test_sync_decorator(): +def test_sync_decorator() -> None: mock = Mock() threads: List[Thread] = [] assert foo.__name__ == "foo" # type: ignore @@ -118,7 +118,7 @@ def assert_id(id: int, m: Mock): assert set(ints) == {0, 1, 2, 3, 4, 5} -def test_sync_decorator_empty(): +def test_sync_decorator_empty() -> None: threads: List[Thread] = [] def assert_id(): @@ -134,7 +134,7 @@ def assert_id(): @pytest.mark.asyncio -async def test_async_decorator(): +async def test_async_decorator() -> None: mock = Mock() assert async_foo.__name__ == "async_foo" # type: ignore @@ -149,7 +149,7 @@ async def assert_id(id: int, m: Mock): assert set(ints) == {0, 1, 2, 3, 4, 5} -def test_instance_method_sync(): +def test_instance_method_sync() -> None: mock = Mock() threads: List[Thread] = [] bar = Bar() @@ -172,7 +172,7 @@ def assert_id(id: int, m: Mock): @pytest.mark.asyncio -async def test_instance_method_async(): +async def test_instance_method_async() -> None: mock = Mock() bar = Bar() assert bar.async_foo.__name__ == "async_foo" # type: ignore @@ -193,8 +193,7 @@ def foo_auto_key(a: int, b: int, c: int = 5) -> Dict: return {"a": a, "b": b, "c": c} -def test_auto_key(): - +def test_auto_key() -> None: tests = [ ([1, 2, 3], {}, (1, 2, 3)), ([1, 2], {}, (1, 2, 5)), @@ -219,8 +218,7 @@ async def async_foo_auto_key(a: int, b: int, c: int = 5) -> Dict: @pytest.mark.asyncio -async def test_auto_key_async(): - +async def test_auto_key_async() -> None: tests = [ ([1, 2, 3], {}, (1, 2, 3)), ([1, 2], {}, (1, 2, 5)), @@ -239,7 +237,7 @@ async def assert_data(args, kwargs, expected): await assert_data(*case) -def test_instance_method_auto_key_sync(): +def test_instance_method_auto_key_sync() -> None: mock = Mock() threads: List[Thread] = [] bar = Bar() @@ -261,7 +259,7 @@ def assert_id(id: int, m: Mock): @pytest.mark.asyncio -async def test_instance_method_auto_key_async(): +async def test_instance_method_auto_key_async() -> None: mock = Mock() bar = Bar() @@ -286,7 +284,7 @@ def _(id: int) -> str: return f"id-{id}" -def test_timeout(): +def test_timeout() -> None: for i in range(30): result = foo_to(i) assert result["id"] == i @@ -308,7 +306,7 @@ def foo_to_auto(id: int, m: Mock) -> Dict: return {"id": id} -def test_timeout_auto_key(): +def test_timeout_auto_key() -> None: mock = Mock() for i in range(30): result = foo_to_auto(i, mock) @@ -327,7 +325,7 @@ def test_timeout_auto_key(): assert foo_to_auto._cache.key_gen.len() == 0 -def test_cache_full_evict(): +def test_cache_full_evict() -> None: mock = Mock() for i in range(30, 1500): result = foo_to_auto(i, mock) @@ -336,7 +334,7 @@ def test_cache_full_evict(): assert foo_to_auto._cache.key_gen.len() == 1000 -def test_cache_full_auto_key_sync_multi(): +def test_cache_full_auto_key_sync_multi() -> None: mock = Mock() threads: List[Thread] = [] @@ -356,11 +354,11 @@ def assert_id(id: int, m: Mock): @Memoize(Cache("tlfu", 1000), timeout=None, lock=True) -def read_auto_key(key: str): +def read_auto_key(key: str) -> str: return key -def assert_read_key(n: int): +def assert_read_key(n: int) -> None: key = f"key:{n}" v = read_auto_key(key) assert v == key @@ -369,7 +367,7 @@ def assert_read_key(n: int): print(".", end="") -def test_cocurrency_load(): +def test_cocurrency_load() -> None: z = Zipf(1.0001, 10, 5000000) with concurrent.futures.ThreadPoolExecutor(max_workers=1000) as executor: for _ in range(200000): diff --git a/theine/adapters/django.py b/theine/adapters/django.py index 902b4ed..e5daee7 100644 --- a/theine/adapters/django.py +++ b/theine/adapters/django.py @@ -1,26 +1,31 @@ from datetime import timedelta from threading import Lock -from typing import Optional, cast +from typing import Any, Callable, Dict, Generic, Optional, Type, TypeVar, Union, cast -from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache +from django.core.cache.backends.base import BaseCache, DEFAULT_TIMEOUT from theine import Cache as Theine from theine.theine import sentinel +KEY_TYPE = Union[str, Callable[..., str]] +VALUE_TYPE = Any +VERSION_TYPE = Optional[int] + class Cache(BaseCache): - def __init__(self, name, params): + def __init__(self, name: str, params: Dict[str, Any]): super().__init__(params) options = params.get("OPTIONS", {}) policy = options.get("POLICY", "tlfu") self.cache = Theine(policy, self._max_entries) - def _timeout_seconds(self, timeout) -> Optional[float]: + def _timeout_seconds(self, timeout: 'Optional[Union[float, DEFAULT_TIMEOUT]]') -> float: if timeout == DEFAULT_TIMEOUT: - return self.default_timeout - return timeout + return cast(float, self.default_timeout) + return cast(float, timeout) - def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None): + def add(self, key: KEY_TYPE, value: VALUE_TYPE, timeout: Optional[float] = DEFAULT_TIMEOUT, + version: VERSION_TYPE = None) -> bool: data = self.get(key, sentinel, version) if data is not sentinel: return False @@ -28,17 +33,17 @@ def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None): self.cache.set( key, value, - timedelta(seconds=cast(float, timeout)) - if timeout is not DEFAULT_TIMEOUT - else None, + timedelta(seconds=cast(float, timeout)) if timeout is not DEFAULT_TIMEOUT else None, ) return True - def get(self, key, default=None, version=None): + def get(self, key: KEY_TYPE, default: Optional[VALUE_TYPE] = None, version: VERSION_TYPE = None) -> Optional[ + VALUE_TYPE]: key = self.make_key(key, version) return self.cache.get(key, default) - def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None): + def set(self, key: KEY_TYPE, value: VALUE_TYPE, timeout: Optional[float] = DEFAULT_TIMEOUT, + version: VERSION_TYPE = None) -> None: to = self._timeout_seconds(timeout) if to is not None and to <= 0: self.delete(key) @@ -50,7 +55,7 @@ def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None): timedelta(seconds=to) if to is not None else None, ) - def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None): + def touch(self, key: KEY_TYPE, timeout: Optional[float] = DEFAULT_TIMEOUT, version: VERSION_TYPE = None) -> bool: data = self.get(key, sentinel, version) if data is sentinel: return False @@ -58,7 +63,7 @@ def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None): if ( timeout is not DEFAULT_TIMEOUT and timeout is not None - and cast(float, timeout) <= 0 + and timeout <= 0 ): self.cache.delete(nkey) return True @@ -67,9 +72,9 @@ def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None): self.cache._access(nkey, timedelta(seconds=to) if to is not None else None) return True - def delete(self, key, version=None): + def delete(self, key: KEY_TYPE, version: VERSION_TYPE = None) -> bool: key = self.make_key(key, version) return self.cache.delete(key) - def clear(self): + def clear(self) -> None: self.cache.clear() diff --git a/theine/theine.py b/theine/theine.py index 9d69369..bb20923 100644 --- a/theine/theine.py +++ b/theine/theine.py @@ -8,53 +8,54 @@ from threading import Event, Thread from typing import ( Any, - Callable, + Awaitable, Callable, Dict, Hashable, List, Optional, - Tuple, + TYPE_CHECKING, Tuple, Type, TypeVar, Union, - cast, -) + cast, ) +from mypy_extensions import KwArg, VarArg from theine_core import ClockProCore, LruCore, TlfuCore from typing_extensions import ParamSpec, Protocol from theine.exceptions import InvalidTTL from theine.models import CacheStats +P = ParamSpec("P") +R = TypeVar("R", covariant=True, bound=Any) +R_A = TypeVar("R_A", covariant=True, bound=Union[Awaitable[Any], Callable[..., Any]]) +if TYPE_CHECKING: + from functools import _Wrapped + sentinel = object() -def KeyGen(): - counter = itertools.count() - hk_map: Dict[Hashable, int] = {} - kh_map: Dict[int, Hashable] = {} +class KeyGen: + def __init__(self) -> None: + self.counter = itertools.count() + self.hk: Dict[Hashable, int] = {} + self.kh: Dict[int, Hashable] = {} - def gen(input: Hashable) -> str: - id = hk_map.get(input, None) + def gen(self, input: Hashable) -> str: + id = self.hk.get(input, None) if id is None: - id = next(counter) - hk_map[input] = id - kh_map[id] = input + id = next(self.counter) + self.hk[input] = id + self.kh[id] = input return f"_auto:{id}" - def _remove(key: str): - h = kh_map.pop(int(key.replace("_auto:", "")), None) + def remove(self, key: str) -> None: + h = self.kh.pop(int(key.replace("_auto:", "")), None) if h is not None: - hk_map.pop(h, None) + self.hk.pop(h, None) - def _len() -> int: - return len(hk_map) - - gen.remove = _remove # type: ignore - gen.len = _len # type: ignore - gen.kh = kh_map # type: ignore - gen.hk = hk_map # type: ignore - return gen + def len(self) -> int: + return len(self.hk) class Core(Protocol): @@ -70,10 +71,10 @@ def remove(self, key: str) -> Optional[int]: def access(self, key: str) -> Optional[int]: ... - def advance(self, cache: List, sentinel: Any, kh: Dict, hk: Dict): + def advance(self, cache: List[Any], sentinel: Any, kh: Dict[int, Hashable], hk: Dict[Hashable, int]) -> None: ... - def clear(self): + def clear(self) -> None: ... def len(self) -> int: @@ -95,10 +96,10 @@ def remove(self, key: str) -> Optional[int]: def access(self, key: str) -> Optional[int]: ... - def advance(self, cache: List, sentinel: Any, kh: Dict, hk: Dict): + def advance(self, cache: List[Any], sentinel: Any, kh: Dict[int, Hashable], hk: Dict[Hashable, int]) -> None: ... - def clear(self): + def clear(self) -> None: ... def len(self) -> int: @@ -111,9 +112,6 @@ def len(self) -> int: "clockpro": ClockProCore, } -P = ParamSpec("P") -R = TypeVar("R", covariant=True) - @dataclass class EventData: @@ -124,12 +122,12 @@ class EventData: # https://github.com/python/cpython/issues/90780 # use event to protect from thundering herd class CachedAwaitable: - def __init__(self, awaitable): + def __init__(self, awaitable: Awaitable[Any]) -> None: self.awaitable = awaitable - self.future = None + self.future: Optional[Awaitable[Any]] = None self.result = sentinel - def __await__(self): + def __await__(self) -> Any: if self.result is not sentinel: return self.result @@ -146,62 +144,59 @@ def __await__(self): class Key: - def __init__(self): + def __init__(self) -> None: self.key: Optional[str] = None self.event = Event() -class Cached(Protocol[P, R]): +class Cached(Protocol[P, R_A]): _cache: "Cache" - def key(self, fn: Callable[P, Hashable]): + def key(self, fn: Callable[P, Hashable]) -> None: ... - def __call__(self, *args, **kwargs) -> R: + def __call__(self, *args: Any, **kwargs: Any) -> R_A: ... def Wrapper( - fn: Callable[P, R], + fn: Callable[P, R_A], timeout: Optional[timedelta], cache: "Cache", - coro: bool, typed: bool, lock: bool, -) -> Cached[P, R]: - +) -> Cached[P, R_A]: _key_func: Optional[Callable[..., Hashable]] = None _events: Dict[Hashable, EventData] = {} - _func: Callable = fn + _func: Callable[P, R_A] = fn _cache: "Cache" = cache - _coro: bool = coro _timeout: Optional[timedelta] = timeout _typed: bool = typed _auto_key: bool = True _lock = lock - def key(fn: Callable[P, Hashable]): + def key(fn: Callable[P, Hashable]) -> None: nonlocal _key_func nonlocal _auto_key _key_func = fn _auto_key = False - def fetch(*args: P.args, **kwargs: P.kwargs) -> R: + def fetch(*args: P.args, **kwargs: P.kwargs) -> R_A: if _auto_key: key = _make_key(args, kwargs, _typed) else: key = _key_func(*args, **kwargs) # type: ignore - if _coro: + if inspect.iscoroutinefunction(fn): result = _cache.get(key, sentinel) if result is sentinel: - result = CachedAwaitable(_func(*args, **kwargs)) + result = CachedAwaitable(cast(Awaitable[Any], _func(*args, **kwargs))) _cache.set(key, result, _timeout) - return cast(R, result) + return cast(R_A, result) data = _cache.get(key, sentinel) if data is not sentinel: - return cast(R, data) + return cast(R_A, data) if _lock: event = EventData(Event(), None) ve = _events.setdefault(key, event) @@ -217,7 +212,7 @@ def fetch(*args: P.args, **kwargs: P.kwargs) -> R: else: result = _func(*args, **kwargs) _cache.set(key, result, _timeout) - return cast(R, result) + return cast(R_A, result) fetch._cache = _cache # type: ignore fetch.key = key # type: ignore @@ -252,9 +247,8 @@ def __init__( self.typed = typed self.lock = lock - def __call__(self, fn: Callable[P, R]) -> Cached[P, R]: - coro = inspect.iscoroutinefunction(fn) - wrapper = Wrapper(fn, self.timeout, self.cache, coro, self.typed, self.lock) + def __call__(self, fn: Callable[P, R_A]) -> '_Wrapped[P, R_A, [VarArg(Any), KwArg(Any)], R_A]': + wrapper = Wrapper(fn, self.timeout, self.cache, self.typed, self.lock) return update_wrapper(wrapper, fn) @@ -300,7 +294,7 @@ def get(self, key: Hashable, default: Any = None) -> Any: elif isinstance(key, int): key_str = f"{key}" else: - key_str = self.key_gen(key) + key_str = self.key_gen.gen(key) auto_key = True index = self.core.access(key_str) @@ -312,14 +306,14 @@ def get(self, key: Hashable, default: Any = None) -> Any: self._hit += 1 return self._cache[index] - def _access(self, key: Hashable, ttl: Optional[timedelta] = None): + def _access(self, key: Hashable, ttl: Optional[timedelta] = None) -> None: key_str = "" if isinstance(key, str): key_str = key elif isinstance(key, int): key_str = f"{key}" else: - key_str = self.key_gen(key) + key_str = self.key_gen.gen(key) ttl_ns = None if ttl is not None: @@ -346,7 +340,7 @@ def set( elif isinstance(key, int): key_str = f"{key}" else: - key_str = self.key_gen(key) + key_str = self.key_gen.gen(key) ttl_ns = None if ttl is not None: @@ -382,7 +376,7 @@ def _set_clockpro( elif isinstance(key, int): key_str = f"{key}" else: - key_str = self.key_gen(key) + key_str = self.key_gen.gen(key) ttl_ns = None if ttl is not None: @@ -416,7 +410,7 @@ def delete(self, key: Hashable) -> bool: elif isinstance(key, int): key_str = f"{key}" else: - key_str = self.key_gen(key) + key_str = self.key_gen.gen(key) self.key_gen.remove(key_str) index = self.core.remove(key_str) @@ -425,7 +419,7 @@ def delete(self, key: Hashable) -> bool: return True return False - def maintenance(self): + def maintenance(self) -> None: """ Remove expired keys. """ @@ -433,15 +427,15 @@ def maintenance(self): self.core.advance(self._cache, sentinel, self.key_gen.kh, self.key_gen.hk) time.sleep(0.5) - def clear(self): + def clear(self) -> None: self.core.clear() self._cache = [sentinel] * len(self._cache) - def close(self): + def close(self) -> None: self._closed = True self._maintainer.join() - def __del__(self): + def __del__(self) -> None: self.clear() self.close()