Skip to content

Commit

Permalink
ci
Browse files Browse the repository at this point in the history
  • Loading branch information
kellerza committed Aug 5, 2024
1 parent 54a7f70 commit 4711ba8
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 11 deletions.
77 changes: 68 additions & 9 deletions aiohttp_msal/redis_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from redis.asyncio import Redis, from_url

from aiohttp_msal.msal_async import AsyncMSAL
from aiohttp_msal.settings import ENV
from aiohttp_msal.settings import ENV as MENV

_LOGGER = logging.getLogger(__name__)

Expand All @@ -20,15 +20,17 @@
@asynccontextmanager
async def get_redis() -> AsyncGenerator[Redis, None]:
"""Get a Redis connection."""
if ENV.database:
if MENV.database:
_LOGGER.debug("Using redis from environment")
yield ENV.database
yield MENV.database
return
_LOGGER.info("Connect to Redis %s", ENV.REDIS)
redis = from_url(ENV.REDIS, decode_responses=True)
_LOGGER.info("Connect to Redis %s", MENV.REDIS)
redis = from_url(MENV.REDIS) # decode_responses=True not allowed aiohttp_session
MENV.database = redis
try:
yield redis
finally:
MENV.database = None # type:ignore
await redis.close()


Expand All @@ -46,10 +48,11 @@ async def session_iter(
if match and not all(isinstance(v, str) for v in match.values()):
raise ValueError("match values must be strings")
async for key in redis.scan_iter(
count=100, match=key_match or f"{ENV.COOKIE_NAME}*"
count=100, match=key_match or f"{MENV.COOKIE_NAME}*"
):
if not isinstance(key, str):
key = key.decode()
sval = await redis.get(key)
_LOGGER.debug("Session: %s = %s", key, sval)
created, ses = 0, {}
try:
val = json.loads(sval) # type: ignore
Expand Down Expand Up @@ -111,11 +114,67 @@ def save_cache(*args: Any) -> None:
return AsyncMSAL(session, save_cache=save_cache)


async def get_session(email: str, *, redis: Optional[Redis] = None) -> AsyncMSAL:
async def get_session(
email: str, *, redis: Optional[Redis] = None, scope: str = ""
) -> AsyncMSAL:
"""Get a session from Redis."""
cnt = 0
async with AsyncExitStack() as stack:
if redis is None:
redis = await stack.enter_async_context(get_redis())
async for key, created, session in session_iter(redis, match={"mail": email}):
cnt += 1
if scope and scope not in str(session.get("token_cache")).lower():
continue
return _session_factory(key, str(created), session)
raise ValueError(f"Session for {email} not found")
msg = f"Session for {email}"
if not scope:
raise ValueError(f"{msg} not found")
raise ValueError(f"{msg} with scope {scope} not found ({cnt} checked)")


async def redis_get_json(key: str) -> list | dict | None:
"""Get a key from redis."""
res = await MENV.database.get(key)
if isinstance(res, (str, bytes, bytearray)):
return json.loads(res)
if res is not None:
_LOGGER.warning("Unexpected type for %s: %s", key, type(res))
return None


async def redis_get(key: str) -> str | None:
"""Get a key from redis."""
res = await MENV.database.get(key)
if isinstance(res, str):
return res
if isinstance(res, (bytes, bytearray)):
return res.decode()
if res is not None:
_LOGGER.warning("Unexpected type for %s: %s", key, type(res))
return None


async def redis_set_set(key: str, new_set: set[str]) -> None:
"""Set the value of a set in redis."""
cur_set = set(
s if isinstance(s, str) else s.decode()
for s in await MENV.database.smembers(key)
)
dif = list(cur_set - new_set)
if dif:
_LOGGER.warning("%s: removing %s", key, dif)
await MENV.database.srem(key, *dif)

dif = list(new_set - cur_set)
if dif:
_LOGGER.info("%s: adding %s", key, dif)
await MENV.database.sadd(key, *dif)


async def redis_scan(match_str: str) -> list[str]:
"""Return a list of matching keys."""
return [
s if isinstance(s, str) else s.decode()
async for s in MENV.database.scan_iter(match=match_str)
]
1 change: 0 additions & 1 deletion aiohttp_msal/settings_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Settings Base."""

from __future__ import annotations

import logging
import os
from pathlib import Path
Expand Down
52 changes: 52 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,56 @@

[tool.black]
line-length = 88
target-version = ['py36', 'py37', 'py38']
include = '\.pyi?$'

[tool.ruff]
line-length = 121
# pyflakes, pycodestyle, isort
include = ["tests/*.py", "aiohttp_msal/**/*.py"]

[tool.ruff.lint]
select = ["F", "E", "W", "I001"]

[tool.ruff.lint.flake8-import-conventions]
# Declare the banned `from` imports.
banned-from = ["typing"]

[tool.ruff.lint.isort]
no-lines-before = ["future", "standard-library"]

[tool.mypy]
disallow_untyped_defs = true
ignore_missing_imports = true

# https://stackoverflow.com/questions/64162504/settings-for-pylint-in-setup-cfg-are-not-getting-used
[tool.pylint.'MESSAGES CONTROL']
max-line-length = 120
good-names = ["db", "fr", "cr", "k", "i"]
disable = [
"line-too-long",
"unsubscriptable-object",
"unused-argument",
"too-many-branches",
"too-many-locals",
"too-many-statements",
"too-many-instance-attributes",
"too-few-public-methods",
"R0401",
"R0801",
"wrong-import-order",
]

[tool.pylint.design]
# limiting the number of returns might discourage
# the use of guard clauses. So we increase the
# allowed number of returns from 6 to 8
max-returns = 8
[tool.pytest.ini_options]
pythonpath = [".", "src"]
filterwarnings = "ignore:.+@coroutine.+deprecated.+"
testpaths = "tests"
norecursedirs = [".git", "modules"]
log_cli = true
log_cli_level = "DEBUG"
asyncio_mode = "auto"
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ redis =
aiohttp_session[aioredis]>=2.12
tests =
black==24.8.0
pylint
pylint==3.2.6
flake8
pytest-aiohttp
pytest
Expand Down

0 comments on commit 4711ba8

Please sign in to comment.