diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 579d34f..6119678 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,20 +1,15 @@ repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.0.292 + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix] - repo: https://github.com/psf/black rev: 23.9.1 hooks: - id: black args: - --quiet - - repo: https://github.com/PyCQA/isort - rev: 5.12.0 - hooks: - - id: isort - args: - - --profile=black - - repo: https://github.com/pycqa/flake8 - rev: 6.1.0 - hooks: - - id: flake8 - repo: https://github.com/codespell-project/codespell rev: v2.2.6 hooks: diff --git a/README.md b/README.md index db6c769..0fe8342 100644 --- a/README.md +++ b/README.md @@ -90,23 +90,13 @@ async def user_authorized(request: web.Request) -> web.Response: ```python from aiohttp_msal import ENV, AsyncMSAL -from aiohttp_msal.redis_tools import clean_redis, get_redis, get_session - -async def get_async_msal(email: str) -> AsyncMSAL: - """Clean redis and get a session.""" - red = get_redis() - try: - return await get_session(red, email) - finally: - await red.close() - +from aiohttp_msal.redis_tools import get_session def main() # Uses the redis.asyncio driver to retrieve the current token # Will update the token_cache if a RefreshToken was used - ases = asyncio.run(get_async_msal(MYEMAIL)) + ases = asyncio.run(get_session(MYEMAIL)) client = GraphClient(ases.get_token) # ... # use the Graphclient - # ... 
``` diff --git a/aiohttp_msal/__init__.py b/aiohttp_msal/__init__.py index da0c315..458cf77 100644 --- a/aiohttp_msal/__init__.py +++ b/aiohttp_msal/__init__.py @@ -13,7 +13,7 @@ _LOGGER = logging.getLogger(__name__) -VERSION = "0.6.3" +VERSION = "0.6.4" def msal_session(*args: Callable[[AsyncMSAL], Union[Any, Awaitable[Any]]]) -> Callable: diff --git a/aiohttp_msal/msal_async.py b/aiohttp_msal/msal_async.py index 0e8e2cc..1c669a9 100644 --- a/aiohttp_msal/msal_async.py +++ b/aiohttp_msal/msal_async.py @@ -71,11 +71,12 @@ class AsyncMSAL: https://msal-python.readthedocs.io/en/latest/#msal.ClientApplication.initiate_auth_code_flow The caller is expected to: - 1.somehow store this content, typically inside the current session of the server, - 2.guide the end user (i.e. resource owner) to visit that auth_uri, - typically with a redirect - 3.and then relay this dict and subsequent auth response to - acquire_token_by_auth_code_flow(). + 1. somehow store this content, typically inside the current session of the + server, + 2. guide the end user (i.e. resource owner) to visit that auth_uri, + typically with a redirect + 3. and then relay this dict and subsequent auth response to + acquire_token_by_auth_code_flow(). [1. and part of 3.] 
is stored by this class in the aiohttp_session diff --git a/aiohttp_msal/redis_tools.py b/aiohttp_msal/redis_tools.py index 338608c..b55e5cf 100644 --- a/aiohttp_msal/redis_tools.py +++ b/aiohttp_msal/redis_tools.py @@ -3,6 +3,7 @@ import json import logging import time +from contextlib import AsyncExitStack, asynccontextmanager from typing import Any, AsyncGenerator, Optional from redis.asyncio import Redis, from_url @@ -15,56 +16,81 @@ SES_KEYS = ("mail", "name", "m_mail", "m_name") -def get_redis() -> Redis: +@asynccontextmanager +async def get_redis() -> AsyncGenerator[Redis, None]: """Get a Redis connection.""" + if ENV.database: + _LOGGER.debug("Using redis from environment") + yield ENV.database + return _LOGGER.info("Connect to Redis %s", ENV.REDIS) - ENV.database = from_url(ENV.REDIS) # pylint: disable=no-member - return ENV.database + redis = from_url(ENV.REDIS) + try: + yield redis + finally: + await redis.close() + + +async def session_iter( + redis: Redis, + *, + match: Optional[dict[str, str]] = None, + key_match: Optional[str] = None, +) -> AsyncGenerator[tuple[str, int, dict[str, Any]], None]: + """Iterate over the Redis keys to find a specific session. + + match: Filter based on session content (i.e. mail/name) + key_match: Filter the Redis keys. 
Defaults to ENV.COOKIE_NAME
_LOGGER.debug("No sessions removed (%s total)", keep) def _session_factory(key: str, created: str, session: dict) -> AsyncMSAL: - """Create a session with a save callback.""" + """Create an AsyncMSAL session.