Skip to content

Commit

Permalink
♻️ webserver: fixes mypy issues in login plugin (#4105)
Browse files Browse the repository at this point in the history
  • Loading branch information
matusdrobuliak66 authored Apr 20, 2023
1 parent 109c9c0 commit 82a734d
Show file tree
Hide file tree
Showing 19 changed files with 74 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,5 @@
"APP_JSONSCHEMA_SPECS_KEY",
"APP_OPENAPI_SPECS_KEY",
"APP_SETTINGS_KEY",
"RQT_USERID_KEY",
)
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from servicelib.rest_constants import RESPONSE_MODEL_POLICY

from . import catalog_client
from ._constants import RQ_PRODUCT_KEY
from ._constants import RQ_PRODUCT_KEY, RQT_USERID_KEY
from ._meta import api_version_prefix
from .catalog_models import (
ServiceInputGet,
Expand All @@ -29,7 +29,7 @@
replace_service_input_outputs,
)
from .catalog_units import can_connect
from .login.decorators import RQT_USERID_KEY, login_required
from .login.decorators import login_required
from .security_decorators import permission_required

logger = logging.getLogger(__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)
from servicelib.aiohttp.typing_extension import Handler
from servicelib.json_serialization import json_dumps
from servicelib.request_keys import RQT_USERID_KEY

from .. import director_v2_api
from .._meta import api_version_prefix
Expand All @@ -21,7 +22,7 @@
DirectorServiceError,
)
from ..director_v2_models import ClusterCreate, ClusterPatch, ClusterPing
from ..login.decorators import RQT_USERID_KEY, login_required
from ..login.decorators import login_required
from ..security_decorators import permission_required

logger = logging.getLogger(__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
from servicelib.aiohttp.rest_responses import create_error_response, get_http_error
from servicelib.json_serialization import json_dumps
from servicelib.mimetype_constants import MIMETYPE_APPLICATION_JSON
from servicelib.request_keys import RQT_USERID_KEY

from ._constants import RQ_PRODUCT_KEY
from ._meta import api_version_prefix as VTAG
from .director_v2_abc import get_project_run_policy
from .director_v2_core_computations import ComputationsApi
from .director_v2_exceptions import DirectorServiceError
from .login.decorators import RQT_USERID_KEY, login_required
from .login.decorators import login_required
from .security_decorators import permission_required
from .version_control.db import CommitID

Expand Down
17 changes: 9 additions & 8 deletions services/web/server/src/simcore_service_webserver/login/_2fa.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,13 @@

import asyncio
import logging
from typing import Optional

from aiohttp import web
from pydantic import BaseModel, Field
from servicelib.logging_utils import log_decorator
from servicelib.utils_secrets import generate_passcode
from settings_library.twilio import TwilioSettings
from simcore_postgres_database.models.users import UserNameConverter
from simcore_postgres_database.models.users import FullNameTuple, UserNameConverter
from twilio.rest import Client

from ..redis import get_redis_validation_code_client
Expand All @@ -26,8 +25,8 @@


def _get_human_readable_first_name(user_name: str) -> str:
full_name = UserNameConverter.get_full_name(user_name)
first_name = full_name.first_name.strip()[:20] # security strip
full_name: FullNameTuple = UserNameConverter.get_full_name(user_name)
first_name: str = full_name.first_name.strip()[:20] # security strip
return first_name.capitalize()


Expand All @@ -49,7 +48,8 @@ async def _do_create_2fa_code(
*,
expiration_seconds: int,
) -> str:
hash_key, code = user_email, generate_passcode()
hash_key: str = user_email
code: str = generate_passcode()
await redis_client.set(hash_key, value=code, ex=expiration_seconds)
return code

Expand All @@ -59,7 +59,7 @@ async def create_2fa_code(
) -> str:
"""Saves 2FA code with an expiration time, i.e. a finite Time-To-Live (TTL)"""
redis_client = get_redis_validation_code_client(app)
code = await _do_create_2fa_code(
code: str = await _do_create_2fa_code(
redis_client=redis_client,
user_email=user_email,
expiration_seconds=expiration_in_seconds,
Expand All @@ -68,11 +68,12 @@ async def create_2fa_code(


@log_decorator(log, level=logging.DEBUG)
async def get_2fa_code(app: web.Application, user_email: str) -> Optional[str]:
async def get_2fa_code(app: web.Application, user_email: str) -> str | None:
"""Returns 2FA code for user or None if it does not exist (e.g. expired or never set)"""
redis_client = get_redis_validation_code_client(app)
hash_key = user_email
return await redis_client.get(hash_key)
hash_value: str | None = await redis_client.get(hash_key)
return hash_value


@log_decorator(log, level=logging.DEBUG)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import logging
from contextlib import contextmanager
from datetime import datetime
from typing import Iterator, Literal, Optional
from typing import Iterator, Literal, cast

from aiohttp import web
from models_library.basic_types import IdInt
Expand All @@ -22,6 +22,7 @@
validator,
)
from servicelib.mimetype_constants import MIMETYPE_APPLICATION_JSON
from simcore_postgres_database.models.confirmations import ConfirmationAction
from yarl import URL

from ..invitations import (
Expand All @@ -32,14 +33,13 @@
validate_invitation_url,
)
from ._confirmation import (
ConfirmationAction,
get_expiration_date,
is_confirmation_expired,
validate_confirmation_code,
)
from ._constants import MSG_EMAIL_EXISTS, MSG_INVITATIONS_CONTACT_SUFFIX
from .settings import LoginOptions
from .storage import AsyncpgStorage, ConfirmationTokenDict
from .storage import AsyncpgStorage, BaseConfirmationTokenDict, ConfirmationTokenDict
from .utils import CONFIRMATION_PENDING

log = logging.getLogger(__name__)
Expand All @@ -51,14 +51,14 @@ class ConfirmationTokenInfoDict(ConfirmationTokenDict):


class InvitationData(BaseModel):
issuer: Optional[str] = Field(
issuer: str | None = Field(
None,
description="Who has issued this invitation? (e.g. an email or a uid)",
)
guest: Optional[str] = Field(
guest: str | None = Field(
None, description="Reference tag for this invitation", deprecated=True
)
trial_account_days: Optional[PositiveInt] = Field(
trial_account_days: PositiveInt | None = Field(
None,
description="If set, this invitation will activate a trial account."
"Sets the number of days from creation until the account expires",
Expand All @@ -77,7 +77,7 @@ def ensure_enum(cls, v):
return ConfirmationAction(v)


ACTION_TO_DATA_TYPE: dict[ConfirmationAction, Optional[type]] = {
ACTION_TO_DATA_TYPE: dict[ConfirmationAction, type | None] = {
ConfirmationAction.INVITATION: InvitationData,
ConfirmationAction.REGISTRATION: None,
}
Expand Down Expand Up @@ -131,9 +131,9 @@ async def create_invitation_token(
db: AsyncpgStorage,
*,
user_id: IdInt,
user_email: Optional[LowerCaseEmailStr] = None,
tag: Optional[str] = None,
trial_days: Optional[PositiveInt] = None,
user_email: LowerCaseEmailStr | None = None,
tag: str | None = None,
trial_days: PositiveInt | None = None,
) -> ConfirmationTokenDict:
"""Creates an invitation token for a guest to register in the platform and returns
Expand Down Expand Up @@ -167,7 +167,7 @@ def _invitations_request_context(invitation_code: str) -> Iterator[URL]:
"""
try:
url = get_invitation_url(
confirmation=ConfirmationTokenDict(
confirmation=BaseConfirmationTokenDict(
code=invitation_code, action=ConfirmationAction.INVITATION.name
),
origin=URL("https://dummyhost.com:8000"),
Expand Down Expand Up @@ -237,7 +237,7 @@ async def check_and_consume_invitation(
if confirmation_token := await validate_confirmation_code(invitation_code, db, cfg):
try:
invitation = _InvitationValidator.parse_obj(confirmation_token)
return invitation.data
return cast(InvitationData, invitation.data)

except ValidationError as err:
log.warning(
Expand All @@ -262,7 +262,7 @@ async def check_and_consume_invitation(


def get_invitation_url(
confirmation: ConfirmationTokenDict, origin: Optional[URL] = None
confirmation: BaseConfirmationTokenDict, origin: URL | None = None
) -> URL:
"""Creates a URL to invite a user for registration
Expand Down Expand Up @@ -291,7 +291,7 @@ def get_confirmation_info(
Extends ConfirmationTokenDict by adding extra info and
deserializing action's data entry
"""
info = ConfirmationTokenInfoDict(**confirmation)
info: ConfirmationTokenInfoDict = ConfirmationTokenInfoDict(**confirmation)

action = ConfirmationAction(confirmation["action"])
if (data_type := ACTION_TO_DATA_TYPE[action]) and (data := confirmation["data"]):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import uuid as uuidlib
from copy import deepcopy
from datetime import timedelta
from typing import Optional

import simcore_postgres_database.webserver_models as orm
import sqlalchemy as sa
Expand All @@ -14,12 +13,13 @@
from servicelib.aiohttp.application_keys import APP_DB_ENGINE_KEY
from servicelib.aiohttp.requests_validation import parse_request_body_as
from servicelib.mimetype_constants import MIMETYPE_APPLICATION_JSON
from servicelib.request_keys import RQT_USERID_KEY
from servicelib.rest_constants import RESPONSE_MODEL_POLICY
from simcore_postgres_database.errors import DatabaseError
from sqlalchemy.sql import func

from ..security_api import check_permission
from .decorators import RQT_USERID_KEY, login_required
from .decorators import login_required
from .utils import get_random_string

log = logging.getLogger(__name__)
Expand All @@ -32,7 +32,7 @@

class ApiKeyCreate(BaseModel):
display_name: str = Field(..., min_length=3)
expiration: Optional[timedelta] = Field(
expiration: timedelta | None = Field(
None,
description="Time delta from creation time to expiration. If None, then it does not expire.",
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import sys
from datetime import datetime
from typing import Optional

import typer
from simcore_postgres_database.models.confirmations import ConfirmationAction
Expand All @@ -13,7 +12,7 @@
def invitations(
base_url: str,
issuer_email: str,
trial_days: Optional[int] = None,
trial_days: int | None = None,
user_id: int = 1,
num_codes: int = 15,
code_length: int = 30,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Final, Optional
from typing import Final

from aiohttp import web
from aiohttp.web import RouteTableDef
Expand All @@ -9,6 +9,7 @@
from servicelib.error_codes import create_error_code
from servicelib.logging_utils import log_context
from servicelib.mimetype_constants import MIMETYPE_APPLICATION_JSON
from servicelib.request_keys import RQT_USERID_KEY
from simcore_postgres_database.models.users import UserRole

from .._meta import API_VTAG
Expand Down Expand Up @@ -38,7 +39,7 @@
)
from ._models import InputSchema
from ._security import login_granted_response
from .decorators import RQT_USERID_KEY, login_required
from .decorators import login_required
from .settings import LoginSettingsForProduct, get_plugin_settings
from .storage import AsyncpgStorage, get_plugin_storage
from .utils import (
Expand All @@ -62,8 +63,8 @@ class LoginBody(InputSchema):

class CodePageParams(BaseModel):
message: str
retry_2fa_after: Optional[PositiveInt] = None
next_url: Optional[str] = None
retry_2fa_after: PositiveInt | None = None
next_url: str | None = None


class LoginNextPage(NextPage[CodePageParams]):
Expand Down Expand Up @@ -240,7 +241,7 @@ async def login_2fa(request: web.Request):


class LogoutBody(InputSchema):
client_session_id: Optional[str] = Field(
client_session_id: str | None = Field(
None, example="5ac57685-c40f-448f-8711-70be1936fd63"
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pydantic import SecretStr, validator
from servicelib.aiohttp.requests_validation import parse_request_body_as
from servicelib.mimetype_constants import MIMETYPE_APPLICATION_JSON
from servicelib.request_keys import RQT_USERID_KEY

from .._meta import API_VTAG
from ..products import Product, get_current_product
Expand All @@ -23,7 +24,7 @@
MSG_WRONG_PASSWORD,
)
from ._models import InputSchema, create_password_match_validator
from .decorators import RQT_USERID_KEY, login_required
from .decorators import login_required
from .settings import LoginOptions, get_plugin_options
from .storage import AsyncpgStorage, get_plugin_storage
from .utils import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ async def phone_confirmation(request: web.Request):
)

db: AsyncpgStorage = get_plugin_storage(request.app)
product: Product = get_current_product(request)

if not settings.LOGIN_2FA_REQUIRED:
raise web.HTTPServiceUnavailable(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import asyncio
import json
import logging
from typing import Optional

import asyncpg
from aiohttp import web
from pydantic import ValidationError
from servicelib.aiohttp.application_setup import ModuleCategory, app_module_setup
from settings_library.email import SMTPSettings
from settings_library.postgres import PostgresSettings

from .._constants import (
APP_OPENAPI_SPECS_KEY,
Expand All @@ -15,10 +16,8 @@
INDEX_RESOURCE_NAME,
)
from ..db import setup_db
from ..db_settings import PostgresSettings
from ..db_settings import get_plugin_settings as get_db_plugin_settings
from ..email import setup_email
from ..email_settings import SMTPSettings
from ..email_settings import get_plugin_settings as get_email_plugin_settings
from ..invitations import setup_invitations
from ..products import ProductName, list_products, setup_products
Expand Down Expand Up @@ -83,7 +82,7 @@ async def _resolve_login_settings_per_product(app: web.Application):
for the login plugin. Note that product settings override app settings.
"""
# app plugin settings
app_login_settings: Optional[LoginSettings]
app_login_settings: LoginSettings | None
login_settings_per_product: dict[ProductName, LoginSettingsForProduct] = {}

if app_login_settings := app[APP_SETTINGS_KEY].WEBSERVER_LOGIN:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def include_path(tuple_object):
"list_api_keys": api_keys_handlers.list_api_keys,
}

routes = map_handlers_with_operations(
routes: list[web.RouteDef] = map_handlers_with_operations(
handlers_map,
filter(include_path, iter_path_operations(validated_specs)),
strict=True,
Expand Down
Loading

0 comments on commit 82a734d

Please sign in to comment.