Skip to content

Commit

Permalink
0.6.4 redistools
Browse files Browse the repository at this point in the history
  • Loading branch information
kellerza committed Oct 9, 2023
1 parent 75ec3c5 commit 7f804f5
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 67 deletions.
15 changes: 5 additions & 10 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.292
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/psf/black
rev: 23.9.1
hooks:
- id: black
args:
- --quiet
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
- id: isort
args:
- --profile=black
- repo: https://github.com/pycqa/flake8
rev: 6.1.0
hooks:
- id: flake8
- repo: https://github.com/codespell-project/codespell
rev: v2.2.6
hooks:
Expand Down
14 changes: 2 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,23 +90,13 @@ async def user_authorized(request: web.Request) -> web.Response:

```python
from aiohttp_msal import ENV, AsyncMSAL
from aiohttp_msal.redis_tools import clean_redis, get_redis, get_session

async def get_async_msal(email: str) -> AsyncMSAL:
"""Clean redis and get a session."""
red = get_redis()
try:
return await get_session(red, email)
finally:
await red.close()

from aiohttp_msal.redis_tools import get_session

def main():
# Uses the redis.asyncio driver to retrieve the current token
# Will update the token_cache if a RefreshToken was used
ases = asyncio.run(get_async_msal(MYEMAIL))
ases = asyncio.run(get_session(MYEMAIL))
client = GraphClient(ases.get_token)
# ...
# use the Graphclient
# ...
```
2 changes: 1 addition & 1 deletion aiohttp_msal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

_LOGGER = logging.getLogger(__name__)

VERSION = "0.6.3"
VERSION = "0.6.4"


def msal_session(*args: Callable[[AsyncMSAL], Union[Any, Awaitable[Any]]]) -> Callable:
Expand Down
11 changes: 6 additions & 5 deletions aiohttp_msal/msal_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,12 @@ class AsyncMSAL:
https://msal-python.readthedocs.io/en/latest/#msal.ClientApplication.initiate_auth_code_flow
The caller is expected to:
1.somehow store this content, typically inside the current session of the server,
2.guide the end user (i.e. resource owner) to visit that auth_uri,
typically with a redirect
3.and then relay this dict and subsequent auth response to
acquire_token_by_auth_code_flow().
1. somehow store this content, typically inside the current session of the
server,
2. guide the end user (i.e. resource owner) to visit that auth_uri,
typically with a redirect
3. and then relay this dict and subsequent auth response to
acquire_token_by_auth_code_flow().
[1. and part of 3.] is stored by this class in the aiohttp_session
Expand Down
107 changes: 68 additions & 39 deletions aiohttp_msal/redis_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import logging
import time
from contextlib import AsyncExitStack, asynccontextmanager
from typing import Any, AsyncGenerator, Optional

from redis.asyncio import Redis, from_url
Expand All @@ -15,56 +16,81 @@
SES_KEYS = ("mail", "name", "m_mail", "m_name")


def get_redis() -> Redis:
@asynccontextmanager
async def get_redis() -> AsyncGenerator[Redis, None]:
"""Get a Redis connection."""
if ENV.database:
_LOGGER.debug("Using redis from environment")
yield ENV.database
return
_LOGGER.info("Connect to Redis %s", ENV.REDIS)
ENV.database = from_url(ENV.REDIS) # pylint: disable=no-member
return ENV.database
redis = from_url(ENV.REDIS)
try:
yield redis
finally:
await redis.close()


async def session_iter(
    redis: Redis,
    *,
    match: Optional[dict[str, str]] = None,
    key_match: Optional[str] = None,
) -> AsyncGenerator[tuple[str, int, dict[str, Any]], None]:
    """Iterate over the Redis keys, yielding matching sessions.

    match: Filter based on session content (i.e. mail/name); every
        supplied term must be contained in the session value.
    key_match: Filter the Redis keys. Defaults to ENV.cookie_name.

    Yields (redis key, created timestamp, session dict). Entries that
    cannot be parsed are yielded with created=0 and an empty session.
    """
    pattern = key_match if key_match else f"{ENV.COOKIE_NAME}*"
    async for key in redis.scan_iter(count=100, match=pattern):
        raw = await redis.get(key)
        # Best effort parse: malformed entries fall back to (0, {}).
        created = 0
        ses: dict[str, Any] = {}
        try:
            payload = json.loads(raw)  # type: ignore
            created = int(payload["created"])
            ses = payload["session"]
        except Exception:  # pylint: disable=broad-except
            pass
        # Skip sessions that do not satisfy every supplied match term.
        if match and not all(k in ses and v in ses[k] for k, v in match.items()):
            continue
        yield key, created, ses


async def iter_redis(
    redis: Redis, *, clean: bool = False, match: Optional[dict[str, str]] = None
) -> AsyncGenerator[tuple[str, str, dict], None]:
    """Iterate over the Redis keys to find a specific session.

    clean: When True, delete malformed or empty session entries as we go.
    match: Filter based on session content (i.e. mail/name); every
        supplied term must be contained in the session value.

    Yields (redis key, created timestamp as str, session dict).
    """
    async for key in redis.scan_iter(count=100, match=f"{ENV.COOKIE_NAME}*"):
        sval = await redis.get(key)
        if not isinstance(sval, (str, bytes, bytearray)):
            if clean:
                await redis.delete(key)
            continue
        val = json.loads(sval)
        ses = val.get("session") or {}
        created = val.get("created")
        # FIX: the original `clean and not ses or not created` parsed as
        # `(clean and not ses) or (not created)`, deleting entries that lack a
        # created timestamp even when clean=False. Only delete when cleaning.
        if clean and (not ses or not created):
            await redis.delete(key)
            continue
        # FIX: guard `k in ses` so a session missing a match key is skipped
        # instead of raising KeyError.
        if match and not all(k in ses and v in ses[k] for k, v in match.items()):
            continue
        yield key, created or "0", ses


async def clean_redis(redis: Redis, max_age: int = 90) -> None:
async def session_clean(
redis: Redis, *, max_age: int = 90, expected_keys: Optional[dict] = None
) -> None:
"""Clear session entries older than max_age days."""
rem, keep = 0, 0
expire = int(time.time() - max_age * 24 * 60 * 60)
async for key, created, ses in iter_redis(redis, clean=True):
for key in SES_KEYS:
if not ses.get(key):
try:
async for key, created, ses in session_iter(redis):
all_keys = all(sk in ses for sk in (expected_keys or SES_KEYS))
if created < expire or not all_keys:
rem += 1
await redis.delete(key)
continue
if int(created) < expire:
await redis.delete(key)
else:
keep += 1
finally:
if rem:
_LOGGER.info("Sessions removed: %s (%s total)", rem, keep)
else:
_LOGGER.debug("No sessions removed (%s total)", keep)


def _session_factory(key: str, created: str, session: dict) -> AsyncMSAL:
"""Create a session with a save callback."""
"""Create a AsyncMSAL session.
When get_token refreshes the token retrieved from Redis, the save_cache callback
will be responsible to update the cache in Redis."""

async def async_save_cache(_: dict) -> None:
"""Save the token cache to Redis."""
rd2 = get_redis()
try:
async with get_redis() as rd2:
await rd2.set(key, json.dumps({"created": created, "session": session}))
finally:
await rd2.close()

def save_cache(*args: Any) -> None:
"""Save the token cache to Redis."""
Expand All @@ -76,8 +102,11 @@ def save_cache(*args: Any) -> None:
return AsyncMSAL(session, save_cache=save_cache)


async def get_session(red: Redis, email: str) -> AsyncMSAL:
async def get_session(email: str, *, redis: Optional[Redis] = None) -> AsyncMSAL:
"""Get a session from Redis."""
async for key, created, session in iter_redis(red, match={"mail": email}):
return _session_factory(key, created, session)
async with AsyncExitStack() as stack:
if redis is None:
redis = await stack.enter_async_context(get_redis())
async for key, created, session in session_iter(redis, match={"mail": email}):
return _session_factory(key, str(created), session)
raise ValueError(f"Session for {email} not found")
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[tool.ruff]
line-length = 88
# pyflakes, pycodestyle, isort
select = ["F", "E", "W", "I001"]

0 comments on commit 7f804f5

Please sign in to comment.