Skip to content

Commit

Permalink
Merge pull request #149 from the-mama-ai/github-test
Browse files Browse the repository at this point in the history
Add unit tests for github auth provider
  • Loading branch information
athornton authored Mar 11, 2024
2 parents 422fad8 + d622b91 commit a9a799d
Show file tree
Hide file tree
Showing 4 changed files with 440 additions and 42 deletions.
107 changes: 65 additions & 42 deletions giftless/auth/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
import dataclasses
import functools
import logging
import math
import os
import threading
from collections.abc import Callable, Mapping, MutableMapping
from contextlib import AbstractContextManager
from contextlib import AbstractContextManager, suppress
from operator import attrgetter, itemgetter
from threading import Condition, Lock, RLock
from typing import Any, cast, overload
from threading import Lock, RLock
from typing import Any, Protocol, cast, overload

import cachetools.keys
import flask
Expand All @@ -23,22 +24,30 @@


# THREAD SAFE CACHING UTILS
class _LockType(AbstractContextManager, Protocol):
"""Generic type for threading.Lock and RLock."""

def acquire(self, blocking: bool = ..., timeout: float = ...) -> bool:
...

def release(self) -> None:
...


@dataclasses.dataclass(kw_only=True)
class SingleCallContext:
"""Thread-safety context for the single_call_method decorator."""

# condition variable blocking a call with particular arguments
cond: Condition = dataclasses.field(default_factory=Condition)
# None - call not started, False - call ongoing, True - call done
# the three states are needed to cover any spurious (pthread-like) wake-ups
call_status: bool | None = None
# reentrant lock guarding a call with particular arguments
rlock: _LockType = dataclasses.field(default_factory=RLock)
start_call: bool = True
result: Any = None
error: BaseException | None = None


def _ensure_lock(
existing_lock: Callable[[Any], AbstractContextManager] | None,
) -> Callable[[Any], AbstractContextManager]:
existing_lock: Callable[[Any], _LockType] | None = None,
) -> Callable[[Any], _LockType]:
if existing_lock is None:
default_lock = RLock()
return lambda _self: default_lock
Expand All @@ -54,7 +63,7 @@ def single_call_method(_method: Callable[..., Any]) -> Callable[..., Any]:
def single_call_method(
*,
key: Callable[..., Any] = cachetools.keys.methodkey,
lock: Callable[[Any], AbstractContextManager] | None = None,
lock: Callable[[Any], _LockType] | None = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
...

Expand All @@ -63,7 +72,7 @@ def single_call_method(
_method: Callable[..., Any] | None = None,
*,
key: Callable[..., Any] = cachetools.keys.methodkey,
lock: Callable[[Any], AbstractContextManager] | None = None,
lock: Callable[[Any], _LockType] | None = None,
) -> Callable[..., Any]:
"""Thread-safe decorator limiting concurrency of an idempotent method call.
When multiple threads concurrently call the decorated method with the same
Expand All @@ -78,7 +87,7 @@ def single_call_method(
It's possible to provide a "getter" callable for the lock guarding the main
call cache, called as 'lock(self)'. There's a built-in lock by default.
Each concurrent call is then guarded by its own lock/conditional variable.
Each concurrent call is then guarded by its own reentrant lock variable.
"""
lock = _ensure_lock(lock)

Expand All @@ -97,35 +106,29 @@ def wrapper(self: Any, *args: tuple, **kwargs: dict) -> Any:
concurrent_calls[k] = ctx = SingleCallContext()
# start locked for the current thread, so the following
# gap won't let other threads populate the result
ctx.cond.acquire()
ctx.rlock.acquire()

with ctx.cond:
if ctx.call_status is None:
# populating the result
ctx.call_status = False
with ctx.rlock:
if ctx.start_call:
ctx.start_call = False
ctx.rlock.release() # unlock the starting lock
try:
result = method(self, *args, **kwargs)
except BaseException as e:
ctx.error = e
raise
finally:
# call is done, cleanup its entry and notify threads
# call is done, cleanup its entry
with lck:
del concurrent_calls[k]
ctx.cond.release() # unlock the starting lock
ctx.cond.notify_all()
ctx.result = result
ctx.call_status = True
return result

else:
# waiting for the result to get populated
while True:
if ctx.error:
raise ctx.error
if ctx.call_status:
return ctx.result
ctx.cond.wait()
# call is done
if ctx.error:
raise ctx.error
return ctx.result

return wrapper

Expand All @@ -138,7 +141,7 @@ def wrapper(self: Any, *args: tuple, **kwargs: dict) -> Any:
def cachedmethod_threadsafe(
cache: Callable[[Any], MutableMapping],
key: Callable[..., Any] = cachetools.keys.methodkey,
lock: Callable[[Any], AbstractContextManager] | None = None,
lock: Callable[[Any], _LockType] | None = None,
) -> Callable[..., Any]:
"""Threadsafe variant of cachetools.cachedmethod."""
lock = _ensure_lock(lock)
Expand Down Expand Up @@ -178,15 +181,14 @@ class Schema(ma.Schema):
token_max_size = ma.fields.Int(
load_default=32, validate=ma.validate.Range(min=0)
)
# the auth cache must have at least one valid slot
auth_max_size = ma.fields.Int(
load_default=32, validate=ma.validate.Range(min=1)
load_default=32, validate=ma.validate.Range(min=0)
)
auth_write_ttl = ma.fields.Float(
load_default=15 * 60.0, validate=ma.validate.Range(min=1.0)
load_default=15 * 60.0, validate=ma.validate.Range(min=0)
)
auth_other_ttl = ma.fields.Float(
load_default=30.0, validate=ma.validate.Range(min=1.0)
load_default=30.0, validate=ma.validate.Range(min=0)
)

@ma.post_load
Expand Down Expand Up @@ -224,8 +226,9 @@ class Schema(ma.Schema):

@ma.post_load
def make_object(
self, data: Mapping[str, Any], **_kwargs: Mapping
self, data: MutableMapping[str, Any], **_kwargs: Mapping
) -> "Config":
data["api_url"] = data["api_url"].rstrip("/")
return Config(**data)

@classmethod
Expand Down Expand Up @@ -262,6 +265,10 @@ def expiration(_key: Any, value: set[Permission], now: float) -> float:
)
return now + ttl

# size-unlimited proxy cache to ensure at least one successful hit
self._auth_cache_read_proxy: MutableMapping[
Any, set[Permission]
] = cachetools.TTLCache(math.inf, 60.0)
self._auth_cache = cachetools.TLRUCache(cc.auth_max_size, expiration)
self._auth_cache_lock = Lock()

Expand All @@ -280,17 +287,28 @@ def __eq__(self, other: object) -> bool:
def __hash__(self) -> int:
return hash((self.login, self.id))

def permissions(self, org: str, repo: str) -> set[Permission] | None:
def permissions(
self, org: str, repo: str, *, authoritative: bool = False
) -> set[Permission] | None:
key = cachetools.keys.hashkey(org, repo)
with self._auth_cache_lock:
return self._auth_cache.get(key)
if authoritative:
permission = self._auth_cache_read_proxy.pop(key, None)
else:
permission = self._auth_cache_read_proxy.get(key)
if permission is None:
return self._auth_cache.get(key)
if authoritative:
with suppress(ValueError):
self._auth_cache[key] = permission
return permission

def authorize(
self, org: str, repo: str, permissions: set[Permission] | None
) -> None:
key = cachetools.keys.hashkey(org, repo)
with self._auth_cache_lock:
self._auth_cache[key] = (
self._auth_cache_read_proxy[key] = (
permissions if permissions is not None else set()
)

Expand All @@ -301,7 +319,7 @@ def is_authorized(
permission: Permission,
oid: str | None = None,
) -> bool:
permissions = self.permissions(organization, repo)
permissions = self.permissions(organization, repo, authoritative=True)
return permission in permissions if permissions else False

def cache_ttl(self, permissions: set[Permission]) -> float:
Expand Down Expand Up @@ -351,11 +369,12 @@ def _extract_token(self, request: flask.Request) -> str:
return token

def __post_init__(self, request: flask.Request) -> None:
self.org, self.repo = request.path.split("/", maxsplit=3)[1:3]
org_repo_getter = itemgetter("organization", "repo")
self.org, self.repo = org_repo_getter(request.view_args or {})
self.token = self._extract_token(request)

def __init__(self, cfg: Config) -> None:
self._api_url = cfg.api_url.rstrip("/")
self._api_url = cfg.api_url
self._api_headers = {"Accept": "application/vnd.github+json"}
if cfg.api_version:
self._api_headers["X-GitHub-Api-Version"] = cfg.api_version
Expand Down Expand Up @@ -471,8 +490,12 @@ def __call__(self, request: flask.Request) -> Identity | None:
self._authorize(ctx, user)
return user

@property
def api_url(self) -> str:
return self._api_url


def factory(**options: Mapping[str, Any]) -> GithubAuthenticator:
def factory(**options: Any) -> GithubAuthenticator:
"""Build GitHub Authenticator from supplied options."""
config = Config.from_dict(options)
return GithubAuthenticator(config)
1 change: 1 addition & 0 deletions requirements/dev.in
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pytest-mypy
pytest-env
pytest-cov
pytest-vcr
responses

pytz
types-pytz
Expand Down
5 changes: 5 additions & 0 deletions requirements/dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,17 @@ pytz==2023.3.post1
pyyaml==6.0.1
# via
# -c requirements/main.txt
# responses
# vcrpy
recommonmark==0.7.1
# via -r requirements/dev.in
requests==2.31.0
# via
# -c requirements/main.txt
# responses
# sphinx
responses==0.25.0
# via -r requirements/dev.in
rsa==4.9
# via
# -c requirements/main.txt
Expand Down Expand Up @@ -241,6 +245,7 @@ urllib3==2.0.7
# via
# -c requirements/main.txt
# requests
# responses
# types-requests
vcrpy==5.1.0
# via pytest-vcr
Expand Down
Loading

0 comments on commit a9a799d

Please sign in to comment.