From 4711ba83b49b4f73576c16daa5d9d62634a8d36b Mon Sep 17 00:00:00 2001 From: Johann Kellerman Date: Mon, 5 Aug 2024 22:26:35 +0200 Subject: [PATCH] ci --- aiohttp_msal/redis_tools.py | 77 +++++++++++++++++++++++++++++++---- aiohttp_msal/settings_base.py | 1 - pyproject.toml | 52 +++++++++++++++++++++++ setup.cfg | 2 +- 4 files changed, 121 insertions(+), 11 deletions(-) diff --git a/aiohttp_msal/redis_tools.py b/aiohttp_msal/redis_tools.py index 4ac28bc..cf66014 100644 --- a/aiohttp_msal/redis_tools.py +++ b/aiohttp_msal/redis_tools.py @@ -10,7 +10,7 @@ from redis.asyncio import Redis, from_url from aiohttp_msal.msal_async import AsyncMSAL -from aiohttp_msal.settings import ENV +from aiohttp_msal.settings import ENV as MENV _LOGGER = logging.getLogger(__name__) @@ -20,15 +20,17 @@ @asynccontextmanager async def get_redis() -> AsyncGenerator[Redis, None]: """Get a Redis connection.""" - if ENV.database: + if MENV.database: _LOGGER.debug("Using redis from environment") - yield ENV.database + yield MENV.database return - _LOGGER.info("Connect to Redis %s", ENV.REDIS) - redis = from_url(ENV.REDIS, decode_responses=True) + _LOGGER.info("Connect to Redis %s", MENV.REDIS) + redis = from_url(MENV.REDIS) # decode_responses=True not allowed aiohttp_session + MENV.database = redis try: yield redis finally: + MENV.database = None # type:ignore await redis.close() @@ -46,10 +48,11 @@ async def session_iter( if match and not all(isinstance(v, str) for v in match.values()): raise ValueError("match values must be strings") async for key in redis.scan_iter( - count=100, match=key_match or f"{ENV.COOKIE_NAME}*" + count=100, match=key_match or f"{MENV.COOKIE_NAME}*" ): + if not isinstance(key, str): + key = key.decode() sval = await redis.get(key) - _LOGGER.debug("Session: %s = %s", key, sval) created, ses = 0, {} try: val = json.loads(sval) # type: ignore @@ -111,11 +114,67 @@ def save_cache(*args: Any) -> None: return AsyncMSAL(session, save_cache=save_cache) -async def get_session(email: str, *, redis: Optional[Redis] = None) -> AsyncMSAL: +async def get_session( + email: str, *, redis: Optional[Redis] = None, scope: str = "" +) -> AsyncMSAL: """Get a session from Redis.""" + cnt = 0 async with AsyncExitStack() as stack: if redis is None: redis = await stack.enter_async_context(get_redis()) async for key, created, session in session_iter(redis, match={"mail": email}): + cnt += 1 + if scope and scope not in str(session.get("token_cache")).lower(): + continue return _session_factory(key, str(created), session) - raise ValueError(f"Session for {email} not found") + msg = f"Session for {email}" + if not scope: + raise ValueError(f"{msg} not found") + raise ValueError(f"{msg} with scope {scope} not found ({cnt} checked)") + + +async def redis_get_json(key: str) -> list | dict | None: + """Get a key from redis.""" + res = await MENV.database.get(key) + if isinstance(res, (str, bytes, bytearray)): + return json.loads(res) + if res is not None: + _LOGGER.warning("Unexpected type for %s: %s", key, type(res)) + return None + + +async def redis_get(key: str) -> str | None: + """Get a key from redis.""" + res = await MENV.database.get(key) + if isinstance(res, str): + return res + if isinstance(res, (bytes, bytearray)): + return res.decode() + if res is not None: + _LOGGER.warning("Unexpected type for %s: %s", key, type(res)) + return None + + +async def redis_set_set(key: str, new_set: set[str]) -> None: + """Set the value of a set in redis.""" + cur_set = set( + s if isinstance(s, str) else s.decode() + for s in await MENV.database.smembers(key) + ) + dif = list(cur_set - new_set) + if dif: + _LOGGER.warning("%s: removing %s", key, dif) + await MENV.database.srem(key, *dif) + + dif = list(new_set - cur_set) + if dif: + _LOGGER.info("%s: adding %s", key, dif) + await MENV.database.sadd(key, *dif) + + +async def redis_scan(match_str: str) -> list[str]: + """Return a list of matching keys.""" + return [ + s if isinstance(s, str) else s.decode() + async for s in MENV.database.scan_iter(match=match_str) + ] diff --git a/aiohttp_msal/settings_base.py b/aiohttp_msal/settings_base.py index 0686d2a..4ec3153 100644 --- a/aiohttp_msal/settings_base.py +++ b/aiohttp_msal/settings_base.py @@ -1,7 +1,6 @@ """Settings Base.""" from __future__ import annotations - import logging import os from pathlib import Path diff --git a/pyproject.toml b/pyproject.toml index 77e9602..aa3b168 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,4 +1,56 @@ + +[tool.black] +line-length = 88 +target-version = ['py36', 'py37', 'py38'] +include = '\.pyi?$' + [tool.ruff] line-length = 121 # pyflakes, pycodestyle, isort +include = ["tests/*.py", "aiohttp_msal/**/*.py"] + +[tool.ruff.lint] select = ["F", "E", "W", "I001"] + +[tool.ruff.lint.flake8-import-conventions] +# Declare the banned `from` imports. +banned-from = ["typing"] + +[tool.ruff.lint.isort] +no-lines-before = ["future", "standard-library"] + +[tool.mypy] +disallow_untyped_defs = true +ignore_missing_imports = true + +# https://stackoverflow.com/questions/64162504/settings-for-pylint-in-setup-cfg-are-not-getting-used +[tool.pylint.'MESSAGES CONTROL'] +max-line-length = 120 +good-names = ["db", "fr", "cr", "k", "i"] +disable = [ + "line-too-long", + "unsubscriptable-object", + "unused-argument", + "too-many-branches", + "too-many-locals", + "too-many-statements", + "too-many-instance-attributes", + "too-few-public-methods", + "R0401", + "R0801", + "wrong-import-order", +] + +[tool.pylint.design] +# limiting the number of returns might discourage +# the use of guard clauses. So we increase the +# allowed number of returns from 6 to 8 +max-returns = 8 +[tool.pytest.ini_options] +pythonpath = [".", "src"] +filterwarnings = "ignore:.+@coroutine.+deprecated.+" +testpaths = "tests" +norecursedirs = [".git", "modules"] +log_cli = true +log_cli_level = "DEBUG" +asyncio_mode = "auto" diff --git a/setup.cfg b/setup.cfg index bcb7977..e335c6e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -36,7 +36,7 @@ redis = aiohttp_session[aioredis]>=2.12 tests = black==24.8.0 - pylint + pylint==3.2.6 flake8 pytest-aiohttp pytest