Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

♻️ Refactor products plugin before extension #4741

Merged
merged 13 commits into from
Sep 13, 2023
14 changes: 7 additions & 7 deletions packages/postgres-database/tests/products/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# pylint: disable=unused-argument


from typing import Callable
from collections.abc import Callable

import pytest
from aiopg.sa.exc import ResourceClosedError
Expand Down Expand Up @@ -38,12 +38,12 @@ async def _make(conn) -> None:
)
.on_conflict_do_update(
index_elements=[products.c.name],
set_=dict(
display_name=f"Product {name.capitalize()}",
short_name=name[:3].lower(),
host_regex=regex,
priority=n,
),
set_={
"display_name": f"Product {name.capitalize()}",
"short_name": name[:3].lower(),
"host_regex": regex,
"priority": n,
},
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@


import json
from collections.abc import Callable
from pathlib import Path
from pprint import pprint
from typing import Callable

import sqlalchemy as sa
from aiopg.sa.engine import Engine
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


import asyncio
from typing import Callable
from collections.abc import Callable

pcrespov marked this conversation as resolved.
Show resolved Hide resolved
import pytest
import sqlalchemy as sa
Expand Down Expand Up @@ -118,10 +118,10 @@ async def _auto_create_products_groups():
)
):
# get or create
product_group_id = await get_or_create_product_group(
return await get_or_create_product_group(
conn, product_name=product_row.name
)
return product_group_id
return None

tasks = [asyncio.create_task(_auto_create_products_groups()) for _ in range(5)]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from aiohttp import web

from .._meta import api_version_prefix
from ..products.plugin import get_product_name
from ..products.api import get_product_name
from ..utils_aiohttp import envelope_json_response
from . import _api
from ._models import Announcement
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from .._meta import API_VTAG
from ..login.decorators import login_required
from ..products.plugin import Product, get_current_product, get_product_template_path
from ..products.api import Product, get_current_product, get_product_template_path
from ..security.decorators import permission_required
from ..utils import get_traceback_string
from ..utils_aiohttp import envelope_json_response
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from .._constants import RQT_USERID_KEY
from .._meta import API_VTAG
from ..login.decorators import login_required
from ..products.plugin import Product, get_current_product
from ..products.api import Product, get_current_product
from ..scicrunch.db import ResearchResourceRepository
from ..scicrunch.errors import InvalidRRID, ScicrunchError
from ..scicrunch.models import ResearchResource, ResourceHit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from servicelib.error_codes import create_error_code
from servicelib.mimetype_constants import MIMETYPE_APPLICATION_JSON

from ..products.plugin import Product, get_current_product
from ..products.api import Product, get_current_product
from ..session.access_policies import session_access_required
from ._2fa import (
create_2fa_code,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from simcore_postgres_database.models.users import UserRole

from .._meta import API_VTAG
from ..products.plugin import Product, get_current_product
from ..products.api import Product, get_current_product
from ..security.api import check_password, forget
from ..session.access_policies import (
on_success_grant_session_access_to,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from servicelib.request_keys import RQT_USERID_KEY

from .._meta import API_VTAG
from ..products.plugin import Product, get_current_product
from ..products.api import Product, get_current_product
from ..security.api import check_password, encrypt_password
from ..utils import HOUR
from ..utils_rate_limiting import global_rate_limit_route
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from simcore_postgres_database.errors import UniqueViolation
from yarl import URL

from ..products.plugin import Product, get_current_product
from ..products.api import Product, get_current_product
from ..security.api import encrypt_password
from ..session.access_policies import session_access_required
from ..utils import MINUTE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .._meta import API_VTAG
from ..groups.api import auto_add_user_to_groups, auto_add_user_to_product_group
from ..invitations.plugin import is_service_invitation_code
from ..products.plugin import Product, get_current_product
from ..products.api import Product, get_current_product
from ..security.api import encrypt_password
from ..session.access_policies import (
on_success_grant_session_access_to,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from ..email.plugin import setup_email
from ..email.settings import get_plugin_settings as get_email_plugin_settings
from ..invitations.plugin import setup_invitations
from ..products.plugin import ProductName, list_products, setup_products
from ..products.api import ProductName, list_products
from ..products.plugin import setup_products
from ..redis import setup_redis
from ..rest.plugin import setup_rest
from . import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from .._resources import webserver_resources
from ..email.utils import AttachmentTuple, send_email_from_template
from ..products.plugin import get_product_template_path
from ..products.api import get_product_template_path

log = logging.getLogger(__name__)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from pathlib import Path

import aiofiles
from aiohttp import web
from models_library.products import ProductName

from .._constants import APP_PRODUCTS_KEY, RQ_PRODUCT_KEY
from .._resources import webserver_resources
from ._db import ProductRepository
from ._events import APP_PRODUCTS_TEMPLATES_DIR_KEY
from ._model import Product


def get_product_name(request: web.Request) -> str:
product_name: str = request[RQ_PRODUCT_KEY]
return product_name


def get_current_product(request: web.Request) -> Product:
"""Returns product associated to current request"""
product_name: ProductName = get_product_name(request)
current_product: Product = request.app[APP_PRODUCTS_KEY][product_name]
return current_product


def list_products(app: web.Application) -> list[Product]:
products: list[Product] = app[APP_PRODUCTS_KEY].values()
return products


#
# helpers for get_product_template_path
#


def _themed(dirname: str, template: str) -> Path:
path: Path = webserver_resources.get_path(f"{Path(dirname) / template}")
return path


async def _get_content(request: web.Request, template_name: str):
repo = ProductRepository(request)
content = await repo.get_template_content(template_name)
if not content:
msg = f"Missing template {template_name} for product"
raise ValueError(msg)
return content


def _safe_get_current_product(request: web.Request) -> Product | None:
try:
product: Product = get_current_product(request)
return product
except KeyError:
return None


async def get_product_template_path(request: web.Request, filename: str) -> Path:
if product := _safe_get_current_product(request):
if template_name := product.get_template_name_for(filename):
template_dir: Path = request.app[APP_PRODUCTS_TEMPLATES_DIR_KEY]
template_path = template_dir / template_name
if not template_path.exists():
# cache
content = await _get_content(request, template_name)
try:
async with aiofiles.open(template_path, "wt") as fh:
await fh.write(content)
except Exception:
# fails to write
if template_path.exists():
template_path.unlink()
raise

return template_path

# check static resources under templates/
if (
template_path := _themed(f"templates/{product.name}", filename)
) and template_path.exists():
return template_path

# If no product or template for product defined, we fall back to common templates
common_template = _themed("templates/common", filename)
if not common_template.exists():
msg = f"{filename} is not part of the templates/common"
raise ValueError(msg)

return common_template
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import AsyncIterator
from collections.abc import AsyncIterator

import sqlalchemy as sa
from aiopg.sa.connection import SAConnection
Expand All @@ -10,7 +10,7 @@
from ..db.models import products
from ._model import Product

log = logging.getLogger(__name__)
_logger = logging.getLogger(__name__)


#
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ._db import iter_products
from ._model import Product

log = logging.getLogger(__name__)
_logger = logging.getLogger(__name__)

APP_PRODUCTS_TEMPLATES_DIR_KEY = f"{__name__}.template_dir"

Expand Down Expand Up @@ -52,7 +52,7 @@ async def auto_create_products_groups(app: web.Application) -> None:
product_group_id = await get_or_create_product_group(
connection, product_name
)
log.debug(
_logger.debug(
"Product with %s has an associated group with %s",
f"{product_name=}",
f"{product_group_id=}",
Expand All @@ -63,7 +63,7 @@ def _set_app_state(
app: web.Application, app_products: dict[str, Product], default_product_name: str
):
app[APP_PRODUCTS_KEY] = app_products
assert default_product_name in app_products.keys() # nosec
assert default_product_name in app_products # nosec
app[f"{APP_PRODUCTS_KEY}_default"] = default_product_name


Expand All @@ -73,23 +73,22 @@ async def load_products_on_startup(app: web.Application):
"""
app_products: dict[str, Product] = {}
engine: Engine = app[APP_DB_ENGINE_KEY]
async with engine.acquire() as conn:
async for row in iter_products(conn):
async with engine.acquire() as connection:
async for row in iter_products(connection):
try:
name = row.name
app_products[name] = Product.from_orm(row)

assert name in FRONTEND_APPS_AVAILABLE # nosec

except ValidationError as err:
raise InvalidConfig(
f"Invalid product configuration in db '{row}':\n {err}"
) from err
except ValidationError as err: # noqa: PERF203
msg = f"Invalid product configuration in db '{row}':\n {err}"
raise InvalidConfig(msg) from err

assert FRONTEND_APP_DEFAULT in app_products.keys() # nosec
assert FRONTEND_APP_DEFAULT in app_products # nosec

default_product_name = await get_default_product_name(conn)
default_product_name = await get_default_product_name(connection)

_set_app_state(app, app_products, default_product_name)

log.debug("Product loaded: %s", [p.name for p in app_products.values()])
_logger.debug("Product loaded: %s", [p.name for p in app_products.values()])
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from .._meta import API_VTAG
from ._model import Product

log = logging.getLogger(__name__)
_logger = logging.getLogger(__name__)


def discover_product_by_hostname(request: web.Request) -> str | None:
def _discover_product_by_hostname(request: web.Request) -> str | None:
products: dict[str, Product] = request.app[APP_PRODUCTS_KEY]
for product in products.values():
if product.host_regex.search(request.host):
Expand All @@ -19,7 +19,7 @@ def discover_product_by_hostname(request: web.Request) -> str | None:
return None


def discover_product_by_request_header(request: web.Request) -> str | None:
def _discover_product_by_request_header(request: web.Request) -> str | None:
requested_product: str | None = request.headers.get(X_PRODUCT_NAME_HEADER)
if requested_product:
for product_name in request.app[APP_PRODUCTS_KEY]:
Expand Down Expand Up @@ -48,8 +48,8 @@ async def discover_product_middleware(request: web.Request, handler: Handler):
or request.path == "/static-frontend-data.json"
):
product_name = (
discover_product_by_request_header(request)
or discover_product_by_hostname(request)
_discover_product_by_request_header(request)
or _discover_product_by_hostname(request)
or _get_app_default_product_name(request)
)
request[RQ_PRODUCT_KEY] = product_name
Expand All @@ -61,7 +61,7 @@ async def discover_product_middleware(request: web.Request, handler: Handler):
or request.path.startswith("/view")
or request.path == "/"
):
product_name = discover_product_by_hostname(
product_name = _discover_product_by_hostname(
request
) or _get_app_default_product_name(request)

Expand Down
Loading