diff --git a/questionpy_server/api/routes/_files.py b/questionpy_server/api/routes/_files.py index a3c2bf82..8664e343 100644 --- a/questionpy_server/api/routes/_files.py +++ b/questionpy_server/api/routes/_files.py @@ -6,7 +6,7 @@ from aiohttp import web from aiohttp.web_exceptions import HTTPNotImplemented -from questionpy_server.decorators import ensure_package_and_question_state_exist +from questionpy_server.decorators import ensure_package from questionpy_server.package import Package from questionpy_server.worker.runtime.package_location import ZipPackageLocation @@ -18,7 +18,7 @@ @file_routes.post(r"/packages/{package_hash}/file/{namespace}/{short_name}/{path:static/.*}") # type: ignore[arg-type] -@ensure_package_and_question_state_exist +@ensure_package async def post_attempt_start(request: web.Request, package: Package) -> web.Response: qpy_server: QPyServer = request.app["qpy_server_app"] namespace = request.match_info["namespace"] diff --git a/questionpy_server/api/routes/_packages.py b/questionpy_server/api/routes/_packages.py index 4232cc89..c6a52844 100644 --- a/questionpy_server/api/routes/_packages.py +++ b/questionpy_server/api/routes/_packages.py @@ -9,7 +9,7 @@ from questionpy_common.environment import RequestUser from questionpy_server.api.models import QuestionCreateArguments, QuestionEditFormResponse, RequestBaseData -from questionpy_server.decorators import ensure_package_and_question_state_exist +from questionpy_server.decorators import ensure_package, ensure_package_and_question_state_exist from questionpy_server.package import Package from questionpy_server.web import json_response from questionpy_server.worker.runtime.package_location import ZipPackageLocation @@ -84,7 +84,7 @@ async def post_question_migrate(_request: web.Request) -> web.Response: @package_routes.post(r"/package-extract-info") # type: ignore[arg-type] -@ensure_package_and_question_state_exist +@ensure_package async def package_extract_info(_request: web.Request, package: 
Package) -> web.Response: """Get package information.""" return json_response(data=package.get_info(), status=201) diff --git a/questionpy_server/cache.py b/questionpy_server/cache.py index ffd4d407..2f26e472 100644 --- a/questionpy_server/cache.py +++ b/questionpy_server/cache.py @@ -19,9 +19,14 @@ class File(NamedTuple): size: int -class SizeError(Exception): - def __init__(self, message: str = "", max_size: int = 0, actual_size: int = 0): - super().__init__(message) +class CacheItemTooLargeError(Exception): + def __init__(self, key: str, actual_size: int, max_size: int): + readable_actual = ByteSize(actual_size).human_readable() + readable_max = ByteSize(max_size).human_readable() + super().__init__( + f"Unable to cache item '{key}' with size '{readable_actual}' because it exceeds the maximum " + f"allowed size of '{readable_max}'" + ) self.max_size = max_size self.actual_size = actual_size @@ -146,12 +151,7 @@ async def put(self, key: str, value: bytes) -> Path: if size > self.max_size: # If we allowed this, the loop at the end would remove all items from the dictionary, # so we raise an error to allow exceptions for this case. - msg = f"Item itself exceeds maximum allowed size of {ByteSize(self.max_size).human_readable()}" - raise SizeError( - msg, - max_size=self.max_size, - actual_size=size, - ) + raise CacheItemTooLargeError(key, size, self.max_size) async with self._lock: # Save the bytes on filesystem. 
diff --git a/questionpy_server/collector/package_collection.py b/questionpy_server/collector/package_collection.py index 3e2fa39e..79c9e865 100644 --- a/questionpy_server/collector/package_collection.py +++ b/questionpy_server/collector/package_collection.py @@ -88,6 +88,9 @@ def get(self, package_hash: str) -> "Package": Returns: path to the package + + Raises: + FileNotFoundError: if no package with the given hash exists """ # Check if package was indexed if package := self._indexer.get_by_hash(package_hash): diff --git a/questionpy_server/decorators.py b/questionpy_server/decorators.py index ea4e73d9..8d668341 100644 --- a/questionpy_server/decorators.py +++ b/questionpy_server/decorators.py @@ -1,23 +1,187 @@ import functools -from collections.abc import Callable -from typing import TYPE_CHECKING, Any, cast, get_type_hints +import inspect +from collections.abc import Awaitable, Callable +from functools import wraps +from inspect import Parameter +from json import JSONDecodeError +from typing import TYPE_CHECKING, Concatenate, NamedTuple, ParamSpec, TypeAlias -from aiohttp.abc import Request +import aiohttp.typedefs +from aiohttp import BodyPartReader, web from aiohttp.log import web_logger -from aiohttp.web_exceptions import HTTPBadRequest, HTTPNotFound, HTTPUnsupportedMediaType +from aiohttp.web_exceptions import HTTPBadRequest +from pydantic import ValidationError +from questionpy_common import constants from questionpy_server.api.models import MainBaseModel, NotFoundStatus, NotFoundStatusWhat -from questionpy_server.types import RouteHandler -from questionpy_server.web import create_model_from_json, get_or_save_package, parse_form_data +from questionpy_server.cache import CacheItemTooLargeError +from questionpy_server.package import Package +from questionpy_server.types import M +from questionpy_server.web import ( + HashContainer, + read_part, +) if TYPE_CHECKING: from questionpy_server.app import QPyServer +_P = ParamSpec("_P") +_PrettyHandlerFunc: 
TypeAlias = Callable[Concatenate[web.Request, _P], Awaitable[web.StreamResponse]] -# TODO: refactor to reduce complexity -def ensure_package_and_question_state_exist( # noqa: C901 - _func: RouteHandler | None = None, -) -> RouteHandler | Callable[[RouteHandler], RouteHandler]: + +class _RequestBodyParts(NamedTuple): + main: bytes | None + package: HashContainer | None + question_state: bytes | None + + +def _get_main_body_param(handler: _PrettyHandlerFunc, signature: inspect.Signature) -> inspect.Parameter | None: + candidates = [ + param + for param in signature.parameters.values() + if isinstance(param.annotation, type) and issubclass(param.annotation, MainBaseModel) + ] + + if not candidates: + # Handler doesn't use the main body. + return None + + if len(candidates) > 1: + msg = f"Handler function '{handler.__name__}' ambiguously takes multiple MainBaseModel parameters" + raise TypeError(msg) + + return candidates[0] + + +def _get_question_state_param(handler: _PrettyHandlerFunc, signature: inspect.Signature) -> inspect.Parameter | None: + if "question_state" not in signature.parameters: + return None + + candidate = signature.parameters["question_state"] + + if candidate.annotation is not Parameter.empty and not issubclass(bytes, candidate.annotation): + msg = f"Parameter 'question_state' of handler function '{handler.__name__}' must have type 'bytes'" + raise TypeError(msg) + + return candidate + + +def _get_package_param(handler: _PrettyHandlerFunc, signature: inspect.Signature) -> inspect.Parameter | None: + candidates = [param for param in signature.parameters.values() if param.annotation is Package] + + if not candidates: + # Handler doesn't use the package. 
+ return None + + if len(candidates) > 1: + msg = f"Handler function '{handler.__name__}' ambiguously takes multiple Package parameters" + raise TypeError(msg) + + return candidates[0] + + +_PARTS_REQUEST_KEY = "qpy-request-parts" + + +async def _read_body_parts(request: web.Request) -> _RequestBodyParts: + parts: _RequestBodyParts = request.get(_PARTS_REQUEST_KEY, None) + if parts: + return parts + + if not request.body_exists: + # No body sent at all. + parts = _RequestBodyParts(None, None, None) + elif request.content_type == "multipart/form-data": + # Multiple parts. + parts = await parse_form_data(request) + elif request.content_type == "application/json": + # Just the main body part. + parts = _RequestBodyParts(await request.read(), None, None) + else: + msg = ( + f"Wrong content type, expected multipart/form-data, application/json or no body, got " + f"'{request.content_type}'" + ) + web_logger.info(msg) + raise web.HTTPUnsupportedMediaType(reason=msg) + + request[_PARTS_REQUEST_KEY] = parts + return parts + + +class MainBodyMissingError(web.HTTPBadRequest): + def __init__(self) -> None: + super().__init__(text="The main body is required but was not provided.") + # TODO: Log necessary? + web_logger.warning(self.text) + + +class PackageMissingError(web.HTTPNotFound): + def __init__(self) -> None: + super().__init__( + text=NotFoundStatus(what=NotFoundStatusWhat.PACKAGE).model_dump_json(), content_type="application/json" + ) + + +class PackageHashMismatchError(web.HTTPBadRequest): + def __init__(self, from_uri: str, from_body: str) -> None: + super().__init__( + text=f"The request URI specifies a package with hash '{from_uri}', but the sent package has " + f"a hash of '{from_body}'." + ) + web_logger.warning(self.text) + + +class QuestionStateMissingError(web.HTTPBadRequest): + def __init__(self) -> None: + super().__init__(text="A question state part is required but was not provided.") + # TODO: Log necessary? 
+ web_logger.warning(self.text) + + +def ensure_package(handler: _PrettyHandlerFunc) -> aiohttp.typedefs.Handler: + """Decorator ensuring that the package needed by the handler is present.""" + signature = inspect.signature(handler) + package_param = _get_package_param(handler, signature) + + if not package_param: + msg = f"Handler '{handler.__name__}' doesn't have a package param but is decorated with ensure_package." + raise TypeError(msg) + + @wraps(handler) + async def wrapper(request: web.Request, *args: _P.args, **kwargs: _P.kwargs) -> web.StreamResponse: + server: QPyServer = request.app["qpy_server_app"] + + uri_package_hash: str | None = request.match_info.get("package_hash", None) + parts = await _read_body_parts(request) + + if parts.package and uri_package_hash and uri_package_hash != parts.package.hash: + raise PackageHashMismatchError(uri_package_hash, parts.package.hash) + + package = None + if uri_package_hash: + try: + package = server.package_collection.get(uri_package_hash) + except FileNotFoundError: + package = None + + if not package and parts.package: + try: + package = await server.package_collection.put(parts.package) + except CacheItemTooLargeError as e: + raise web.HTTPRequestEntityTooLarge(max_size=e.max_size, actual_size=e.actual_size, text=str(e)) from e + + if not package: + raise PackageMissingError + + kwargs[package_param.name] = package # type: ignore[union-attr] # we narrowed package_param earlier + + return await handler(request, *args, **kwargs) + + return wrapper + + +def ensure_package_and_question_state_exist(handler: _PrettyHandlerFunc) -> aiohttp.typedefs.Handler: """Decorator that ensures package and question state exist. 
Ensures that the package and question state exist (if needed by func) and that the json corresponds to the model @@ -27,81 +191,76 @@ def ensure_package_and_question_state_exist( # noqa: C901 * func may want an argument named 'data' (with a subclass of MainBaseModel) * func may want an argument named 'question_state' (bytes or bytes | None) * every func wants a package with an argument named 'package' + """ + signature = inspect.signature(handler) + + main_body_param = _get_main_body_param(handler, signature) + question_state_param = _get_question_state_param(handler, signature) + # A question state is only mandatory if the handler annotates it as plain 'bytes'. + require_question_state = question_state_param is not None and question_state_param.annotation is bytes + + # Resolving and injecting the package is delegated to ensure_package. + with_package = ensure_package(handler) + + @functools.wraps(handler) + async def wrapper(request: web.Request, *args: _P.args, **kwargs: _P.kwargs) -> web.StreamResponse: + parts = await _read_body_parts(request) + + if main_body_param: + if parts.main is None: + raise MainBodyMissingError + + kwargs[main_body_param.name] = validate_from_http(parts.main, main_body_param.annotation) + + # Check if func wants a question state and if it is provided. + if question_state_param: + if require_question_state and parts.question_state is None: + raise QuestionStateMissingError + kwargs["question_state"] = parts.question_state + + return await with_package(request, *args, **kwargs) + + return wrapper + + +async def parse_form_data(request: web.Request) -> _RequestBodyParts: + """Parses a multipart/form-data request. Args: - _func (Optional[RouteHandler]): Control parameter; allows using the decorator with or without arguments. - If this decorator is used with any arguments, this will always be the decorated function itself. (Default - value = None) + request (Request): The request to be parsed. 
+ + Returns: tuple of main field, package, and question state """ + server: QPyServer = request.app["qpy_server_app"] + main = package = question_state = None + + reader = await request.multipart() + while part := await reader.next(): + if not isinstance(part, BodyPartReader): + continue + + if part.name == "main": + main = await read_part(part, server.settings.webservice.max_main_size, calculate_hash=False) + elif part.name == "package": + package = await read_part(part, server.settings.webservice.max_package_size, calculate_hash=True) + elif part.name == "question_state": + question_state = await read_part(part, constants.MAX_QUESTION_STATE_SIZE, calculate_hash=False) + + return _RequestBodyParts(main, package, question_state) - def decorator(function: RouteHandler) -> RouteHandler: # noqa: C901 - """Internal decorator function.""" - type_hints = get_type_hints(function) - question_state_type = type_hints.get("question_state") - takes_question_state = question_state_type is not None - require_question_state = question_state_type is bytes - main_part_json_model: type[MainBaseModel] | None = type_hints.get("data") - - if main_part_json_model and not issubclass(main_part_json_model, MainBaseModel): - msg = f"Parameter 'data' of function {function.__name__} has unexpected type." 
- raise TypeError(msg) - - @functools.wraps(function) - async def wrapper(request: Request, *args: Any, **kwargs: Any) -> Any: # noqa: C901 - """Wrapper around the actual function call.""" - server: QPyServer = request.app["qpy_server_app"] - package_hash: str = request.match_info.get("package_hash", "") - - if not request.body_exists: - main, sent_package, sent_question_state = None, None, None - elif request.content_type == "multipart/form-data": - main, sent_package, sent_question_state = await parse_form_data(request) - elif request.content_type == "application/json": - main, sent_package, sent_question_state = await request.read(), None, None - else: - web_logger.info("Wrong content type, multipart/form-data expected, got %s", request.content_type) - raise HTTPUnsupportedMediaType - - if main_part_json_model: - if main is None: - msg = "Multipart/form-data field 'main' is not set" - web_logger.warning(msg) - raise HTTPBadRequest(text=msg) - - model = create_model_from_json(main.decode(), main_part_json_model) - kwargs["data"] = model - - # Check if func wants a question state and if it is provided. - if takes_question_state: - if require_question_state and sent_question_state is None: - msg = "Multipart/form-data field 'question_state' is not set" - web_logger.warning(msg) - raise HTTPBadRequest(text=msg) - kwargs["question_state"] = sent_question_state - - # Check if a package is provided and if it matches the optional hash given in the URL. 
- if sent_package and package_hash and package_hash != sent_package.hash: - msg = f"Package hash does not match: {package_hash} != {sent_package.hash}" - web_logger.warning(msg) - raise HTTPBadRequest(text=msg) - - package = await get_or_save_package(server.package_collection, package_hash, sent_package) - if package is None: - if package_hash: - raise HTTPNotFound( - text=NotFoundStatus(what=NotFoundStatusWhat.PACKAGE).model_dump_json(), - content_type="application/json", - ) - - msg = "No package found in multipart/form-data" - web_logger.warning(msg) - raise HTTPBadRequest(text=msg) - - kwargs["package"] = package - return await function(request, *args, **kwargs) - - return cast(RouteHandler, wrapper) - - if _func is None: - return decorator - return decorator(_func) + +def validate_from_http(raw_body: str | bytes, param_class: type[M]) -> M: + """Validates the given json which was presumably an HTTP body to the given Pydantic model. + + Args: + raw_body: raw json body + param_class: the [pydantic.BaseModel][] subclass to validate to + """ + try: + return param_class.model_validate_json(raw_body) + except ValidationError as error: + web_logger.warning("JSON does not match model: %s", error) + raise HTTPBadRequest from error + except JSONDecodeError as error: + web_logger.warning("Invalid JSON in request") + raise HTTPBadRequest from error diff --git a/questionpy_server/repository/__init__.py b/questionpy_server/repository/__init__.py index bb294130..d64ffb6a 100644 --- a/questionpy_server/repository/__init__.py +++ b/questionpy_server/repository/__init__.py @@ -7,7 +7,7 @@ from gzip import decompress from urllib.parse import urljoin -from questionpy_server.cache import FileLimitLRU, SizeError +from questionpy_server.cache import CacheItemTooLargeError, FileLimitLRU from questionpy_server.repository.helper import download from questionpy_server.repository.models import RepoMeta, RepoPackage, RepoPackageIndex from questionpy_server.utils.logger import URLAdapter 
@@ -52,7 +52,7 @@ async def get_packages(self, meta: RepoMeta) -> dict[str, RepoPackage]: raw_index_zip = await download(self._url_index, size=meta.size, expected_hash=meta.sha256) try: await self._cache.put(meta.sha256, raw_index_zip) - except SizeError: + except CacheItemTooLargeError: self._log.warning("Package index is too big to be cached.") raw_index = decompress(raw_index_zip) diff --git a/questionpy_server/web.py b/questionpy_server/web.py index 49faa8be..ec54dd9a 100644 --- a/questionpy_server/web.py +++ b/questionpy_server/web.py @@ -5,33 +5,23 @@ from collections.abc import Sequence from hashlib import sha256 from io import BytesIO -from json import JSONDecodeError, loads -from typing import TYPE_CHECKING, Literal, NamedTuple, overload +from typing import Literal, NamedTuple, overload from aiohttp import BodyPartReader -from aiohttp.abc import Request from aiohttp.log import web_logger -from aiohttp.web_exceptions import HTTPBadRequest, HTTPRequestEntityTooLarge +from aiohttp.web_exceptions import HTTPRequestEntityTooLarge from aiohttp.web_response import Response -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel -from questionpy_common import constants from questionpy_common.constants import KiB -from questionpy_server.cache import SizeError -from questionpy_server.collector import PackageCollection -from questionpy_server.package import Package -from questionpy_server.types import M - -if TYPE_CHECKING: - from questionpy_server.app import QPyServer def json_response(data: Sequence[BaseModel] | BaseModel, status: int = 200) -> Response: """Creates a json response from a single BaseModel or a list of BaseModels. Args: - data (Union[Sequence[BaseModel]): A BaseModel or a list of BaseModels. - status (int): The HTTP status code. + data: A BaseModel or a list of BaseModels. + status: The HTTP status code. Returns: Response: A response object. 
@@ -42,29 +32,6 @@ def json_response(data: Sequence[BaseModel] | BaseModel, status: int = 200) -> R return Response(text=data.model_dump_json(), status=status, content_type="application/json") -def create_model_from_json(json: object | str, param_class: type[M]) -> M: - """Creates a BaseModel from an object. - - Args: - json (Union[object, str]): object containing the parsed json - param_class (type[M]): BaseModel class - - Returns: - M: BaseModel - """ - try: - if isinstance(json, str): - json = loads(json) - model = param_class.model_validate(json) - except ValidationError as error: - web_logger.warning("JSON does not match model: %s", error) - raise HTTPBadRequest from error - except JSONDecodeError as error: - web_logger.warning("Invalid JSON in request") - raise HTTPBadRequest from error - return model - - class HashContainer(NamedTuple): data: bytes hash: str @@ -112,57 +79,3 @@ async def read_part(part: BodyPartReader, max_size: int, *, calculate_hash: bool if calculate_hash: return HashContainer(data=buffer.getvalue(), hash=hash_object.hexdigest()) return buffer.getvalue() - - -async def parse_form_data(request: Request) -> tuple[bytes | None, HashContainer | None, bytes | None]: - """Parses a multipart/form-data request. - - Args: - request (Request): The request to be parsed. 
- - Returns: - tuple of main field, package, and question state - """ - server: QPyServer = request.app["qpy_server_app"] - main = package = question_state = None - - reader = await request.multipart() - while part := await reader.next(): - if not isinstance(part, BodyPartReader): - continue - - if part.name == "main": - main = await read_part(part, server.settings.webservice.max_main_size, calculate_hash=False) - elif part.name == "package": - package = await read_part(part, server.settings.webservice.max_package_size, calculate_hash=True) - elif part.name == "question_state": - question_state = await read_part(part, constants.MAX_QUESTION_STATE_SIZE, calculate_hash=False) - - return main, package, question_state - - -async def get_or_save_package( - collection: PackageCollection, hash_value: str, container: HashContainer | None -) -> Package | None: - """Gets a package from or saves it in the package collection. - - Args: - collection (PackageCollection): package collection - hash_value (str): The hash of the package. 
- container (Optional[HashContainer]): container with the package data and its hash - - Returns: - package - """ - try: - if not container: - package = collection.get(hash_value) - else: - package = await collection.put(container) - except SizeError as error: - raise HTTPRequestEntityTooLarge( - max_size=error.max_size, actual_size=error.actual_size, text=str(error) - ) from error - except FileNotFoundError: - return None - return package diff --git a/tests/questionpy_server/repository/test_repository.py b/tests/questionpy_server/repository/test_repository.py index b5813271..d0a3846c 100644 --- a/tests/questionpy_server/repository/test_repository.py +++ b/tests/questionpy_server/repository/test_repository.py @@ -11,7 +11,7 @@ from _pytest.tmpdir import TempPathFactory from questionpy_common.constants import KiB -from questionpy_server.cache import FileLimitLRU, SizeError +from questionpy_server.cache import CacheItemTooLargeError, FileLimitLRU from questionpy_server.repository import RepoMeta, RepoPackage, RepoPackageIndex, Repository from questionpy_server.utils.manifest import ComparableManifest from tests.test_data.factories import ManifestFactory, RepoMetaFactory, RepoPackageVersionsFactory @@ -119,7 +119,7 @@ async def test_log_warning_when_package_index_is_too_big_for_cache( with ( patch("questionpy_server.repository.download") as mock_download, - patch.object(cache, "put", side_effect=SizeError), + patch.object(cache, "put", side_effect=CacheItemTooLargeError("index", 2, 1)), ): parsed = package_index.model_dump_json() mock_download.return_value = compress(parsed.encode()) diff --git a/tests/questionpy_server/test_cache.py b/tests/questionpy_server/test_cache.py index 29952d30..9e4f9742 100644 --- a/tests/questionpy_server/test_cache.py +++ b/tests/questionpy_server/test_cache.py @@ -11,7 +11,7 @@ import pytest from _pytest.tmpdir import TempPathFactory -from questionpy_server.cache import FileLimitLRU, SizeError +from questionpy_server.cache import CacheItemTooLargeError, 
FileLimitLRU @dataclass @@ -186,7 +186,7 @@ async def test_put(cache: FileLimitLRU, settings: Settings) -> None: assert get_file_count(settings.cache.directory) == settings.items.num_of_items # Content size is bigger than cache size. - with pytest.raises(SizeError): + with pytest.raises(CacheItemTooLargeError): await cache.put("new", b"." * (settings.cache.size + 1)) # Replace existing file.