Skip to content

Commit

Permalink
0.6
Browse files Browse the repository at this point in the history
  • Loading branch information
kellerza committed Oct 3, 2023
1 parent 7761e5c commit cef9230
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 48 deletions.
29 changes: 9 additions & 20 deletions .github/workflows/main.yml → .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
- uses: actions/setup-python@v3
with:
python-version: ${{ env.DEFAULT_PYTHON }}
- uses: psf/black@23.1.0
- uses: psf/black@23.9.1

isort:
name: Check isort
Expand Down Expand Up @@ -46,8 +46,7 @@ jobs:
python-version: ${{ env.DEFAULT_PYTHON }}
- name: Install Requirements
run: |
python -m pip install --upgrade pip
pip install ".[redis,tests]"
pip install -e ".[redis,tests]"
- name: Run Pylint
run: |
pylint aiohttp_msal
Expand Down Expand Up @@ -88,7 +87,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10"]
python-version: ["3.9", "3.10", "3.11"]
needs:
- black
- isort
Expand All @@ -104,7 +103,6 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Run tests and generate coverage report
run: |
python -m pip install --upgrade pip
pip install ".[redis,tests]"
pytest --cov=./aiohttp_msal --cov-report=xml
- name: Upload coverage to Codecov
Expand All @@ -120,27 +118,18 @@ jobs:
if: startsWith(github.ref, 'refs/tags')
needs:
- pytest
permissions:
# IMPORTANT: this permission is mandatory for trusted publishing
id-token: write
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ env.DEFAULT_PYTHON }}
uses: actions/setup-python@v4
with:
python-version: ${{ env.DEFAULT_PYTHON }}
- name: Install pypa/build
run: >-
python -m
pip install
build
--user
run: python -m pip install build --user
- name: Build a binary wheel and a source tarball
run: >-
python -m
build
--sdist
--wheel
--outdir dist/
.
run: python -m build --sdist --wheel --outdir dist/ .
- name: Publish distribution 📦 to PyPI
uses: pypa/gh-action-pypi-publish@master
with:
password: ${{ secrets.PYPI_API_TOKEN }}
uses: pypa/gh-action-pypi-publish@release/v1
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/psf/black
rev: 23.3.0
rev: 23.9.1
hooks:
- id: black
args:
Expand All @@ -12,17 +12,17 @@ repos:
args:
- --profile=black
- repo: https://github.com/pycqa/flake8
rev: 6.0.0
rev: 6.1.0
hooks:
- id: flake8
- repo: https://github.com/codespell-project/codespell
rev: v2.2.2
rev: v2.2.6
hooks:
- id: codespell
# exclude: >-
# ^(.*comments_backup\.csv|.*poetry.lock)$
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.991
rev: v1.5.1
hooks:
- id: mypy
files: "aiohttp_msal/.*.py"
Expand Down
16 changes: 7 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,18 @@ Blocking MSAL functions are executed in the executor thread. Should be useful un

Tested with MSAL Python 1.21.0 onward - [MSAL Python docs](https://github.com/AzureAD/microsoft-authentication-library-for-python)


## AsycMSAL class

The AsyncMSAL class wraps the behavior in the following example app
https://github.com/Azure-Samples/ms-identity-python-webapp/blob/master/app.py#L76
<https://github.com/Azure-Samples/ms-identity-python-webapp/blob/master/app.py#L76>

It is responsible to manage tokens & token refreshes and as a client to retrieve data using these tokens.

### Acquire the token

Firstly you should get the tokens via OAuth

1. `initiate_auth_code_flow` [referernce](https://msal-python.readthedocs.io/en/latest/#msal.PublicClientApplication.initiate_auth_code_flow)
1. `initiate_auth_code_flow` [referernce](https://msal-python.readthedocs.io/en/latest/#msal.PublicClientApplication.initiate_auth_code_flow)

The caller is expected to:
1. somehow store this content, typically inside the current session of the server,
Expand All @@ -28,8 +27,7 @@ Firstly you should get the tokens via OAuth

**Step 1** and part of **Step 3** is stored by this class in the aiohttp_session

2. `acquire_token_by_auth_code_flow` [referernce](https://msal-python.readthedocs.io/en/latest/#msal.PublicClientApplication.initiate_auth_code_flow)

2. `acquire_token_by_auth_code_flow` [referernce](https://msal-python.readthedocs.io/en/latest/#msal.PublicClientApplication.initiate_auth_code_flow)

### Use the token

Expand All @@ -42,11 +40,11 @@ async with aiomsal.get("https://graph.microsoft.com/v1.0/me") as res:
res = await res.json()
```

# Example web server
## Example web server

Complete routes can be found in [routes.py](./aiohttp_msal/routes.py)

## Start the login process
### Start the login process

```python
@ROUTES.get("/user/login")
Expand All @@ -61,7 +59,7 @@ async def user_login(request: web.Request) -> web.Response:
return web.HTTPFound(redir)
```

## Acquire the token after being redirected back to the server
### Acquire the token after being redirected back to the server

```python
@ROUTES.post(URI_USER_AUTHORIZED)
Expand All @@ -78,7 +76,7 @@ async def user_authorized(request: web.Request) -> web.Response:

- `@ROUTES.get("/user/photo")`

Serve the user's photo from his Microsoft profile
Serve the user's photo from their Microsoft profile

- `get_user_info`

Expand Down
2 changes: 1 addition & 1 deletion aiohttp_msal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

_LOGGER = logging.getLogger(__name__)

VERSION = "0.5.5"
VERSION = "0.6"


def msal_session(*args: Callable[[AsyncMSAL], Union[Any, Awaitable[Any]]]) -> Callable:
Expand Down
41 changes: 30 additions & 11 deletions aiohttp_msal/msal_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import asyncio
import json
from functools import partial, wraps
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Union

from aiohttp import web
from aiohttp.client import ClientResponse, ClientSession, _RequestContextManager
Expand Down Expand Up @@ -96,11 +96,20 @@ class AsyncMSAL:
_app: ConfidentialClientApplication = None
_clientsession: ClientSession = None # type: ignore

def __init__(self, session: Session):
"""Init the class."""
def __init__(
self,
session: Union[Session, dict[str, str]],
save_cache: Optional[Callable[[Union[Session, dict[str, str]]], None]] = None,
):
"""Init the class.
**save_token_cache** will be called if the token cache changes. Optional.
Not required when the session parameter is an aiohttp_session.Session."""
self.session = session
if not isinstance(session, Session):
raise ValueError(f"session required {session}")
if save_cache:
self.save_token_cache = save_cache
if not isinstance(session, (Session, dict)):
raise ValueError(f"session or dict-like object required {session}")

@property
def token_cache(self) -> SerializableTokenCache:
Expand Down Expand Up @@ -134,11 +143,13 @@ def _save_token_cache(self) -> None:
"""Save the token cache if it changed."""
if self.token_cache.has_state_changed:
self.session[TOKEN_CACHE] = self.token_cache.serialize()
if hasattr(self, "save_token_cache"):
self.save_token_cache(self.token_cache)

def build_auth_code_flow(self, redirect_uri: str) -> str:
"""First step - Start the flow."""
self.session[TOKEN_CACHE] = None
self.session[USER_EMAIL] = None
self.session[TOKEN_CACHE] = None # type: ignore
self.session[USER_EMAIL] = None # type: ignore
self.session[FLOW_CACHE] = res = self.app.initiate_auth_code_flow(
MY_SCOPE,
redirect_uri=redirect_uri,
Expand All @@ -149,8 +160,7 @@ def build_auth_code_flow(self, redirect_uri: str) -> str:
# https://msal-python.readthedocs.io/en/latest/#msal.ClientApplication.initiate_auth_code_flow
return str(res["auth_uri"])

@async_wrap
def async_acquire_token_by_auth_code_flow(self, auth_response: Any) -> None:
def acquire_token_by_auth_code_flow(self, auth_response: Any) -> None:
"""Second step - Acquire token."""
# Assume we have it in the cache (added by /login)
# will raise keryerror if no cache
Expand All @@ -165,8 +175,13 @@ def async_acquire_token_by_auth_code_flow(self, auth_response: Any) -> None:
"preferred_username"
)

@async_wrap
def async_get_token(self) -> Optional[dict[str, Any]]:
async def async_acquire_token_by_auth_code_flow(self, auth_response: Any) -> None:
"""Second step - Acquire token, async version."""
await asyncio.get_event_loop().run_in_executor(
None, self.acquire_token_by_auth_code_flow, auth_response
)

def get_token(self) -> Optional[dict[str, Any]]:
"""Acquire a token based on username."""
accounts = self.app.get_accounts()
if accounts:
Expand All @@ -175,6 +190,10 @@ def async_get_token(self) -> Optional[dict[str, Any]]:
return result
return None

async def async_get_token(self) -> Optional[dict[str, Any]]:
"""Acquire a token based on username."""
return await asyncio.get_event_loop().run_in_executor(None, self.get_token)

async def request(self, method: str, url: str, **kwargs: Any) -> ClientResponse:
"""Make a request to url using an oauth session.
Expand Down
82 changes: 82 additions & 0 deletions aiohttp_msal/redis_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Redis tools for sessions."""
import asyncio
import json
import logging
import time
from typing import AsyncGenerator, Optional

from redis.asyncio import Redis, from_url

from aiohttp_msal.msal_async import AsyncMSAL
from aiohttp_msal.settings import ENV

_LOGGER = logging.getLogger(__name__)

SES_KEYS = ("mail", "name", "m_mail", "m_name")


def get_redis() -> Redis:
"""Get a Redis connection."""
_LOGGER.info("Connect to Redis %s", ENV.REDIS)
ENV.database = from_url(ENV.REDIS) # pylint: disable=no-member
return ENV.database


async def iter_redis(
redis: Redis, *, clean: bool = False, match: Optional[dict[str, str]] = None
) -> AsyncGenerator[tuple[str, str, dict], None]:
"""Iterate over the Redis keys to find a specific session."""
async for key in redis.scan_iter(count=100, match=f"{ENV.COOKIE_NAME}*"):
sval = await redis.get(key)
if not isinstance(sval, str):
if clean:
await redis.delete(key)
continue
val = json.loads(sval)
ses = val.get("session")
created = val.get("created")
if clean and not ses or not created:
await redis.delete(key)
continue
if match:
for mkey, mval in match.items():
if mval not in ses[mkey]:
continue
created = val.get("created") or "0"
session = val.get("session") or {}
yield key, created, session


async def clean_redis(redis: Redis, max_age: int = 90) -> None:
"""Clear session entries older than max_age days."""
expire = int(time.time() - max_age * 24 * 60 * 60)
async for key, created, ses in iter_redis(redis, clean=True):
for key in SES_KEYS:
if not ses.get(key):
await redis.delete(key)
continue
if int(created) < expire:
await redis.delete(key)


async def get_session(red: Redis, email: str) -> AsyncMSAL:
"""Get a session from Redis."""
async for key, created, session in iter_redis(red, match={"mail": email}):

async def _save_cache(_: dict) -> None:
"""Save the token cache to Redis."""
rd2 = get_redis()
try:
await rd2.set(key, json.dumps({"created": created, "session": session}))
finally:
await rd2.close()

def save_cache(ses: dict) -> None:
"""Save the token cache to Redis."""
try:
asyncio.get_event_loop().create_task(_save_cache(ses))
except RuntimeError:
asyncio.run(_save_cache(ses))

return AsyncMSAL(session, save_cache=save_cache)
raise ValueError(f"Session for {email} not found")
6 changes: 3 additions & 3 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[metadata]
name = aiohttp_msal
version = attr: aiohttp_msal.VERSION
description = Helper Library to use MSAL with aiohttp
description = Helper Library to use the Microsoft Authentication Library (MSAL) with aiohttp
long_description = file: README.md
long_description_content_type = text/markdown
url = https://github.com/kellerza/aiohttp_msal
Expand All @@ -27,7 +27,7 @@ python_requires = >=3.9
include_package_data = True
tests_requires = file: requirements_test.txt
install_requires =
msal>=1.21.0
msal>=1.24.1
aiohttp_session>=2.12
aiohttp>=3.8
zip_safe = true
Expand All @@ -36,7 +36,7 @@ zip_safe = true
redis =
aiohttp_session[aioredis]>=2.12
tests =
black==23.3.0
black==23.9.1
pylint
flake8
pytest-aiohttp
Expand Down
1 change: 1 addition & 0 deletions tests/test_init.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
"""Init."""
import aiohttp_msal # noqa
import aiohttp_msal.routes # noqa

0 comments on commit cef9230

Please sign in to comment.