Skip to content

Commit

Permalink
mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
pcrespov committed May 2, 2023
1 parent e460380 commit 5b41f52
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,17 @@ async def can(
if operation in role_access.check.keys():
check = role_access.check[operation]
try:
is_valid: bool

if inspect.iscoroutinefunction(check):
return await check(context)
return check(context)
is_valid = await check(context)
return is_valid

is_valid = check(context)
return is_valid

except Exception: # pylint: disable=broad-except
_logger.exception(
_logger.debug(
"Check operation '%s', shall not raise [%s]", operation, check
)
return False
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from dataclasses import dataclass, field
from typing import TypedDict

import attr
import sqlalchemy as sa
from aiohttp import web
from aiohttp_security.abc import AbstractAuthorizationPolicy
Expand All @@ -15,7 +15,7 @@
from tenacity import retry

from ..db_models import UserStatus, users
from ._access_model import RoleBasedAccessModel, check_access
from ._access_model import ContextType, RoleBasedAccessModel, check_access

_logger = logging.getLogger(__name__)

Expand All @@ -25,11 +25,11 @@ class _UserIdentity(TypedDict, total=True):
role: UserRole


@attr.s(auto_attribs=True, frozen=True)
@dataclass(frozen=True)
class AuthorizationPolicy(AbstractAuthorizationPolicy):
app: web.Application
access_model: RoleBasedAccessModel
timed_cache: ExpiringDict = attr.ib(
timed_cache: ExpiringDict = field(
init=False, default=ExpiringDict(max_len=100, max_age_seconds=10)
)

Expand All @@ -43,10 +43,10 @@ def engine(self) -> Engine:
# return self.app.config_dict[APP_DB_ENGINE_KEY]
return self.app[APP_DB_ENGINE_KEY]

@retry(**PostgresRetryPolicyUponOperation(log).kwargs)
@retry(**PostgresRetryPolicyUponOperation(_logger).kwargs)
async def _get_active_user_with(self, identity: str) -> _UserIdentity | None:
# NOTE: Keeps a cache for a few seconds. Observed successive streams of this query
user: _UserIdentity | None = self.timed_cache.get(identity)
user: _UserIdentity | None = self.timed_cache.get(identity, None)
if user is None:
async with self.engine.acquire() as conn:
# NOTE: sometimes it raises psycopg2.DatabaseError in #880 and #1160
Expand All @@ -60,7 +60,9 @@ async def _get_active_user_with(self, identity: str) -> _UserIdentity | None:
if row is not None:
assert row["id"] # nosec
assert row["role"] # nosec
self.timed_cache[identity] = user = dict(row.items())
self.timed_cache[identity] = user = _UserIdentity(
id=row.id, role=row.role
)

return user

Expand All @@ -72,13 +74,18 @@ async def authorized_userid(self, identity: str) -> int | None:
"""
# TODO: why users.c.user_login_key!=users.c.email
user: _UserIdentity | None = await self._get_active_user_with(identity)
return user["id"] if user else None

if user is None:
return None

user_id: int = user["id"]
return user_id

async def permits(
self,
identity: str,
permission: str | tuple,
context: dict | None = None,
permission: str,
context: ContextType = None,
) -> bool:
"""Determines whether an identified user has permission
Expand All @@ -96,8 +103,8 @@ async def permits(
return False

user = await self._get_active_user_with(identity)
if user:
role = user.get("role")
return await check_access(self.access_model, role, permission, context)
if user is None:
return False

return False
role = user.get("role")
return await check_access(self.access_model, role, permission, context)

0 comments on commit 5b41f52

Please sign in to comment.