diff --git a/mypy.ini b/mypy.ini index 981e4061c..4c8aacaca 100644 --- a/mypy.ini +++ b/mypy.ini @@ -77,6 +77,9 @@ ignore_errors = True # dependencies. # - Write your own stubs. You don't need to write stubs for the whole library, # only the parts that Karapace is interacting with. +[mypy-accept_types.*] +ignore_missing_imports = True + [mypy-aiokafka.*] ignore_missing_imports = True diff --git a/src/karapace/auth.py b/src/karapace/auth.py index e801bfae9..607ab0ea3 100644 --- a/src/karapace/auth.py +++ b/src/karapace/auth.py @@ -11,7 +11,7 @@ from karapace.config import Config, InvalidConfiguration from karapace.statsd import StatsClient from karapace.utils import json_decode, json_encode -from typing import Final, Protocol +from typing import Protocol from typing_extensions import override, TypedDict from watchfiles import awatch, Change @@ -98,12 +98,12 @@ class AuthData(TypedDict): class AuthenticateProtocol(Protocol): - def authenticate(self, *, username: str, password: str) -> User: + def authenticate(self, *, username: str, password: str) -> User | None: ... class AuthorizeProtocol(Protocol): - def get_user(self, username: str) -> User: + def get_user(self, username: str) -> User | None: ... def check_authorization(self, user: User | None, operation: Operation, resource: str) -> bool: @@ -114,24 +114,24 @@ def check_authorization_any(self, user: User | None, operation: Operation, resou class AuthenticatorAndAuthorizer(AuthenticateProtocol, AuthorizeProtocol): - MUST_AUTHENTICATE: Final[bool] = True + MUST_AUTHENTICATE: bool = True async def close(self) -> None: ... - async def start(self, stats: StatsClient) -> None: + async def start(self, stats: StatsClient) -> None: # pylint: disable=unused-argument ... class NoAuthAndAuthz(AuthenticatorAndAuthorizer): - MUST_AUTHENTICATE: Final[bool] = False + MUST_AUTHENTICATE: bool = False @override - def authenticate(self, *, username: str, password: str) -> User: + def authenticate(self, *, username: str, password: str) -> User | None: return None @override - def get_user(self, username: str) -> User: + def get_user(self, username: str) -> User | None: return None @override @@ -156,7 +156,7 @@ def __init__(self, *, user_db: dict[str, User] | None = None, permissions: list[ self.user_db = user_db or {} self.permissions = permissions or [] - def get_user(self, username: str) -> User: + def get_user(self, username: str) -> User | None: user = self.user_db.get(username) if not user: raise ValueError("No user found") @@ -289,7 +289,7 @@ def _load_authfile(self) -> None: raise InvalidConfiguration("Failed to load auth file") from ex @override - def authenticate(self, *, username: str, password: str) -> User: + def authenticate(self, *, username: str, password: str) -> User | None: user = self.get_user(username) if user is None or not user.compare_password(password): raise AuthenticationError() diff --git a/src/karapace/config.py b/src/karapace/config.py index b8a6b091a..332363d46 100644 --- a/src/karapace/config.py +++ b/src/karapace/config.py @@ -122,7 +122,7 @@ def get_rest_base_uri(self) -> str: def to_env_str(self) -> str: env_lines: list[str] = [] - for key, value in self.dict().items(): + for key, value in self.model_dump().items(): if value is not None: env_lines.append(f"{key.upper()}={value}") return "\n".join(env_lines) diff --git a/src/karapace/forward_client.py b/src/karapace/forward_client.py index 2f6decb70..9d791303c 100644 --- a/src/karapace/forward_client.py +++ b/src/karapace/forward_client.py @@ -3,9 +3,11 @@ See LICENSE for details """ -from fastapi import Request, Response -from fastapi.responses import JSONResponse, PlainTextResponse +from fastapi import HTTPException, Request, status +from karapace.utils import json_decode from karapace.version import __version__ +from pydantic import BaseModel +from typing import overload, TypeVar, Union import aiohttp import async_timeout @@ -14,16 +16,30 @@ LOG = logging.getLogger(__name__) +BaseModelResponse = TypeVar("BaseModelResponse", bound=BaseModel) +SimpleTypeResponse = TypeVar("SimpleTypeResponse", bound=Union[int, list[int]]) + + class ForwardClient: USER_AGENT = f"Karapace/{__version__}" - def __init__(self): + def __init__(self) -> None: self._forward_client: aiohttp.ClientSession | None = None def _get_forward_client(self) -> aiohttp.ClientSession: return aiohttp.ClientSession(headers={"User-Agent": ForwardClient.USER_AGENT}) - async def forward_request_remote(self, *, request: Request, primary_url: str) -> Response: + def _acceptable_response_content_type(self, *, content_type: str) -> bool: + return ( + content_type.startswith("application/") and content_type.endswith("json") + ) or content_type == "application/octet-stream" + + async def _forward_request_remote( + self, + *, + request: Request, + primary_url: str, + ) -> bytes: LOG.info("Forwarding %s request to remote url: %r since we're not the master", request.method, request.url) timeout = 60.0 headers = request.headers.mutablecopy() @@ -37,18 +53,52 @@ async def forward_request_remote(self, *, request: Request, primary_url: str) -> forward_url = f"{forward_url}?{request.url.query}" logging.error(forward_url) - with async_timeout.timeout(timeout): + async with async_timeout.timeout(timeout): body_data = await request.body() async with func(forward_url, headers=headers, data=body_data) as response: - if response.headers.get("Content-Type", "").startswith("application/json"): - return JSONResponse( - content=await response.text(), - status_code=response.status, - headers=response.headers, - ) - else: - return PlainTextResponse( - content=await response.text(), - status_code=response.status, - headers=response.headers, - ) + if self._acceptable_response_content_type(content_type=response.headers.get("Content-Type")): + return await response.text() + LOG.error("Unknown response for forwarded request: %s", response) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={ + "error_code": status.HTTP_500_INTERNAL_SERVER_ERROR, + "message": "Unknown response for forwarded request.", + }, + ) + + @overload + async def forward_request_remote( + self, + *, + request: Request, + primary_url: str, + response_type: type[BaseModelResponse], + ) -> BaseModelResponse: + ... + + @overload + async def forward_request_remote( + self, + *, + request: Request, + primary_url: str, + response_type: type[SimpleTypeResponse], + ) -> SimpleTypeResponse: + ... + + async def forward_request_remote( + self, + *, + request: Request, + primary_url: str, + response_type: type[BaseModelResponse] | type[SimpleTypeResponse], + ) -> BaseModelResponse | SimpleTypeResponse: + body = await self._forward_request_remote(request=request, primary_url=primary_url) + if response_type == int: + return int(body) # type: ignore[return-value] + if response_type == list[int]: + return json_decode(body, assume_type=list[int]) # type: ignore[return-value] + if issubclass(response_type, BaseModel): + return response_type.parse_raw(body) # type: ignore[return-value] + raise ValueError("Did not match any expected type") diff --git a/src/karapace/rapu.py b/src/karapace/rapu.py index 6a45b8ed0..d236671fa 100644 --- a/src/karapace/rapu.py +++ b/src/karapace/rapu.py @@ -174,8 +174,7 @@ def __init__( def _create_aiohttp_application(self, *, config: Config) -> aiohttp.web.Application: if config.http_request_max_size: return aiohttp.web.Application(client_max_size=config.http_request_max_size) - else: - return aiohttp.web.Application() + return aiohttp.web.Application() async def close_by_app(self, app: aiohttp.web.Application) -> None: # pylint: disable=unused-argument await self.close() diff --git a/src/karapace/statsd.py b/src/karapace/statsd.py index 13b0db0a4..3ef4d001c 100644 --- a/src/karapace/statsd.py +++ b/src/karapace/statsd.py @@ -12,7 +12,7 @@ from collections.abc import Iterator from contextlib import contextmanager -from karapace.config import Config +from karapace.config import Config, KarapaceTags from karapace.sentry import get_sentry_client from typing import Any, Final @@ -28,7 +28,7 @@ class StatsClient: def __init__(self, config: Config) -> None: self._dest_addr: Final = (config.statsd_host, config.statsd_port) self._socket: Final = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - self._tags: Final = config.tags or {} + self._tags: Final[KarapaceTags] = config.tags self.sentry_client: Final = get_sentry_client(sentry_config=(config.sentry or None)) @contextmanager diff --git a/src/karapace/typing.py b/src/karapace/typing.py index a205ae9de..7927657eb 100644 --- a/src/karapace/typing.py +++ b/src/karapace/typing.py @@ -5,11 +5,11 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import Mapping, Sequence +from collections.abc import Generator, Mapping, Sequence from enum import Enum, unique from karapace.errors import InvalidVersion from pydantic import ValidationInfo -from typing import Any, ClassVar, NewType, Union +from typing import Any, Callable, ClassVar, NewType, Union from typing_extensions import TypeAlias import functools @@ -38,7 +38,7 @@ class Subject(str): @classmethod # TODO[pydantic]: We couldn't refactor `__get_validators__`, please create the `__get_pydantic_core_schema__` manually. # Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information. - def __get_validators__(cls): + def __get_validators__(cls) -> Generator[Callable[[str, ValidationInfo], str], None, None]: yield cls.validate @classmethod diff --git a/src/schema_registry/factory.py b/src/schema_registry/factory.py index 667bdfc7d..ae104b854 100644 --- a/src/schema_registry/factory.py +++ b/src/schema_registry/factory.py @@ -17,6 +17,7 @@ from schema_registry.http_handlers import setup_exception_handlers from schema_registry.middlewares import setup_middlewares from schema_registry.routers.setup import setup_routers +from typing import AsyncContextManager, Callable import logging @@ -44,12 +45,18 @@ async def karapace_schema_registry_lifespan( stastd.close() -def create_karapace_application(*, config: Config, lifespan: AsyncGenerator[None, None]) -> FastAPI: +def create_karapace_application( + *, + config: Config, + lifespan: Callable[ + [FastAPI, StatsClient, KarapaceSchemaRegistry, AuthenticatorAndAuthorizer], AsyncContextManager[None] + ], +) -> FastAPI: configure_logging(config=config) log_config_without_secrets(config=config) logging.info("Starting Karapace Schema Registry (%s)", karapace_version.__version__) - app = FastAPI(lifespan=lifespan) + app = FastAPI(lifespan=lifespan) # type: ignore[arg-type] setup_routers(app=app) setup_exception_handlers(app=app) setup_middlewares(app=app) diff --git a/src/schema_registry/http_handlers/__init__.py b/src/schema_registry/http_handlers/__init__.py index 93bc853cc..e15ec9565 100644 --- a/src/schema_registry/http_handlers/__init__.py +++ b/src/schema_registry/http_handlers/__init__.py @@ -14,11 +14,11 @@ def setup_exception_handlers(app: FastAPI) -> None: @app.exception_handler(StarletteHTTPException) - async def http_exception_handler(_: StarletteHTTPRequest, exc: StarletteHTTPException): + async def http_exception_handler(_: StarletteHTTPRequest, exc: StarletteHTTPException) -> JSONResponse: return JSONResponse(status_code=exc.status_code, content=exc.detail) @app.exception_handler(RequestValidationError) - async def validation_exception_handler(_: StarletteHTTPRequest, exc: RequestValidationError): + async def validation_exception_handler(_: StarletteHTTPRequest, exc: RequestValidationError) -> JSONResponse: error_code = HTTPStatus.UNPROCESSABLE_ENTITY.value if isinstance(exc, KarapaceValidationError): error_code = exc.error_code diff --git a/src/schema_registry/middlewares/__init__.py b/src/schema_registry/middlewares/__init__.py index b5fb2e125..7c0559687 100644 --- a/src/schema_registry/middlewares/__init__.py +++ b/src/schema_registry/middlewares/__init__.py @@ -3,14 +3,16 @@ See LICENSE for details """ -from fastapi import FastAPI, HTTPException, Request +from collections.abc import Awaitable +from fastapi import FastAPI, HTTPException, Request, Response from fastapi.responses import JSONResponse from karapace.content_type import check_schema_headers +from typing import Callable def setup_middlewares(app: FastAPI) -> None: @app.middleware("http") - async def set_content_types(request: Request, call_next): + async def set_content_types(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: try: response_content_type = check_schema_headers(request) except HTTPException as exc: @@ -25,7 +27,7 @@ async def set_content_types(request: Request, call_next): if request.headers.get("Content-Type") == "application/octet-stream": new_headers = request.headers.mutablecopy() new_headers["Content-Type"] = "application/json" - request._headers = new_headers + request._headers = new_headers # pylint: disable=protected-access request.scope.update(headers=request.headers.raw) response = await call_next(request) diff --git a/src/schema_registry/routers/config.py b/src/schema_registry/routers/config.py index 04bd63545..1c95ac046 100644 --- a/src/schema_registry/routers/config.py +++ b/src/schema_registry/routers/config.py @@ -53,10 +53,11 @@ async def config_put( i_am_primary, primary_url = await schema_registry.get_master() if i_am_primary: return await controller.config_set(compatibility_level_request=compatibility_level_request) - elif not primary_url: + if not primary_url: raise no_primary_url_error() - else: - return await forward_client.forward_request_remote(request=request, primary_url=primary_url) + return await forward_client.forward_request_remote( + request=request, primary_url=primary_url, response_type=CompatibilityResponse + ) @config_router.get("/{subject}") @@ -92,10 +93,11 @@ async def config_set_subject( i_am_primary, primary_url = await schema_registry.get_master() if i_am_primary: return await controller.config_subject_set(subject=subject, compatibility_level_request=compatibility_level_request) - elif not primary_url: + if not primary_url: raise no_primary_url_error() - else: - return await forward_client.forward_request_remote(request=request, primary_url=primary_url) + return await forward_client.forward_request_remote( + request=request, primary_url=primary_url, response_type=CompatibilityResponse + ) @config_router.delete("/{subject}") @@ -115,7 +117,8 @@ async def config_delete_subject( i_am_primary, primary_url = await schema_registry.get_master() if i_am_primary: return await controller.config_subject_delete(subject=subject) - elif not primary_url: + if not primary_url: raise no_primary_url_error() - else: - return await forward_client.forward_request_remote(request=request, primary_url=primary_url) + return await forward_client.forward_request_remote( + request=request, primary_url=primary_url, response_type=CompatibilityResponse + ) diff --git a/src/schema_registry/routers/master_availability.py b/src/schema_registry/routers/master_availability.py index 55e792275..a3783575a 100644 --- a/src/schema_registry/routers/master_availability.py +++ b/src/schema_registry/routers/master_availability.py @@ -4,8 +4,7 @@ """ from dependency_injector.wiring import inject, Provide -from fastapi import APIRouter, Depends, HTTPException, Request, Response, status -from fastapi.responses import JSONResponse +from fastapi import APIRouter, Depends, Request, Response from karapace.config import Config from karapace.forward_client import ForwardClient from karapace.schema_registry import KarapaceSchemaRegistry @@ -13,7 +12,6 @@ from schema_registry.container import SchemaRegistryContainer from typing import Final -import json import logging LOG = logging.getLogger(__name__) @@ -47,7 +45,8 @@ async def master_availability( response.headers.update(NO_CACHE_HEADER) if ( - schema_registry.schema_reader.master_coordinator._sc is not None # pylint: disable=protected-access + schema_registry.schema_reader.master_coordinator is not None + and schema_registry.schema_reader.master_coordinator._sc is not None # pylint: disable=protected-access and schema_registry.schema_reader.master_coordinator._sc.is_master_assigned_to_myself() # pylint: disable=protected-access ): return MasterAvailabilityResponse(master_available=are_we_master) @@ -55,12 +54,6 @@ async def master_availability( if master_url is None or f"{config.advertised_hostname}:{config.advertised_port}" in master_url: return NO_MASTER - forward_response = await forward_client.forward_request_remote(request=request, primary_url=master_url) - if isinstance(response, JSONResponse): - response_json = json.loads(forward_response.body) - return MasterAvailabilityResponse(master_available=response_json["master_availability"]) - - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=forward_response.body, + return await forward_client.forward_request_remote( + request=request, primary_url=master_url, response_type=MasterAvailabilityResponse ) diff --git a/src/schema_registry/routers/metrics.py b/src/schema_registry/routers/metrics.py index 23b4b39f8..897a1d65d 100644 --- a/src/schema_registry/routers/metrics.py +++ b/src/schema_registry/routers/metrics.py @@ -6,7 +6,6 @@ from dependency_injector.wiring import inject, Provide from fastapi import APIRouter, Depends, Response from karapace.instrumentation.prometheus import PrometheusInstrumentation -from pydantic import BaseModel from schema_registry.container import SchemaRegistryContainer metrics_router = APIRouter( @@ -20,5 +19,5 @@ @inject async def metrics( prometheus: PrometheusInstrumentation = Depends(Provide[SchemaRegistryContainer.karapace_container.prometheus]), -) -> BaseModel: +) -> Response: return Response(content=await prometheus.serve_metrics(), media_type=prometheus.CONTENT_TYPE_LATEST) diff --git a/src/schema_registry/routers/mode.py b/src/schema_registry/routers/mode.py index 870a876d2..c139e8e7d 100644 --- a/src/schema_registry/routers/mode.py +++ b/src/schema_registry/routers/mode.py @@ -9,6 +9,7 @@ from karapace.typing import Subject from schema_registry.container import SchemaRegistryContainer from schema_registry.routers.errors import unauthorized +from schema_registry.routers.requests import ModeResponse from schema_registry.schema_registry_apis import KarapaceSchemaRegistryController from schema_registry.user import get_current_user from typing import Annotated @@ -26,7 +27,7 @@ async def mode_get( user: Annotated[User, Depends(get_current_user)], authorizer: AuthenticatorAndAuthorizer = Depends(Provide[SchemaRegistryContainer.karapace_container.authorizer]), controller: KarapaceSchemaRegistryController = Depends(Provide[SchemaRegistryContainer.schema_registry_controller]), -): +) -> ModeResponse: if authorizer and not authorizer.check_authorization(user, Operation.Read, "Config:"): raise unauthorized() @@ -40,7 +41,7 @@ async def mode_get_subject( user: Annotated[User, Depends(get_current_user)], authorizer: AuthenticatorAndAuthorizer = Depends(Provide[SchemaRegistryContainer.karapace_container.authorizer]), controller: KarapaceSchemaRegistryController = Depends(Provide[SchemaRegistryContainer.schema_registry_controller]), -): +) -> ModeResponse: if authorizer and not authorizer.check_authorization(user, Operation.Read, f"Subject:{subject}"): raise unauthorized() diff --git a/src/schema_registry/routers/schemas.py b/src/schema_registry/routers/schemas.py index d7af4cd2b..984c50085 100644 --- a/src/schema_registry/routers/schemas.py +++ b/src/schema_registry/routers/schemas.py @@ -4,7 +4,7 @@ """ from dependency_injector.wiring import inject, Provide -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Query from karapace.auth import AuthenticatorAndAuthorizer, User from schema_registry.container import SchemaRegistryContainer from schema_registry.routers.requests import SchemaListingItem, SchemasResponse, SubjectVersion @@ -44,7 +44,7 @@ async def schemas_get( schema_id: str, # TODO: type to actual type includeSubjects: bool = False, # TODO: include subjects? fetchMaxId: bool = False, # TODO: fetch max id? - format: str = "", + format_serialized: str = Query("", alias="format"), authorizer: AuthenticatorAndAuthorizer = Depends(Provide[SchemaRegistryContainer.karapace_container.authorizer]), controller: KarapaceSchemaRegistryController = Depends(Provide[SchemaRegistryContainer.schema_registry_controller]), ) -> SchemasResponse: @@ -52,7 +52,7 @@ async def schemas_get( schema_id=schema_id, include_subjects=includeSubjects, fetch_max_id=fetchMaxId, - format_serialized=format, + format_serialized=format_serialized, user=user, authorizer=authorizer, ) diff --git a/src/schema_registry/routers/subjects.py b/src/schema_registry/routers/subjects.py index 766329795..4d0a9fe94 100644 --- a/src/schema_registry/routers/subjects.py +++ b/src/schema_registry/routers/subjects.py @@ -83,10 +83,9 @@ async def subjects_subject_delete( i_am_primary, primary_url = await schema_registry.get_master() if i_am_primary: return await controller.subject_delete(subject=subject, permanent=permanent) - elif not primary_url: + if not primary_url: raise no_primary_url_error() - else: - return await forward_client.forward_request_remote(request=request, primary_url=primary_url) + return await forward_client.forward_request_remote(request=request, primary_url=primary_url, response_type=list[int]) @subjects_router.post("/{subject}/versions") @@ -165,10 +164,9 @@ async def subjects_subject_version_delete( i_am_primary, primary_url = await schema_registry.get_master() if i_am_primary: return await controller.subject_version_delete(subject=subject, version=version, permanent=permanent) - elif not primary_url: + if not primary_url: raise no_primary_url_error() - else: - return await forward_client.forward_request_remote(request=request, primary_url=primary_url) + return await forward_client.forward_request_remote(request=request, primary_url=primary_url, response_type=int) @subjects_router.get("/{subject}/versions/{version}/schema") diff --git a/src/schema_registry/schema_registry_apis.py b/src/schema_registry/schema_registry_apis.py index 13f6bb8f2..56cd567b1 100644 --- a/src/schema_registry/schema_registry_apis.py +++ b/src/schema_registry/schema_registry_apis.py @@ -6,7 +6,7 @@ from avro.errors import SchemaParseException from dependency_injector.wiring import inject, Provide -from fastapi import Depends, HTTPException, Request, Response, status +from fastapi import Depends, HTTPException, Request, status from karapace.auth import AuthenticatorAndAuthorizer, Operation, User from karapace.compatibility import CompatibilityModes from karapace.compatibility.jsonschema.checks import is_incompatible @@ -76,22 +76,22 @@ def _add_schema_registry_routes(self) -> None: def _subject_get(self, subject: Subject, include_deleted: bool = False) -> dict[Version, SchemaVersion]: try: schema_versions = self.schema_registry.subject_get(subject, include_deleted) - except SubjectNotFoundException: + except SubjectNotFoundException as exc: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail={ "error_code": SchemaErrorCodes.SUBJECT_NOT_FOUND.value, "message": SchemaErrorMessages.SUBJECT_NOT_FOUND_FMT.value.format(subject=subject), }, - ) - except SchemasNotFoundException: + ) from exc + except SchemasNotFoundException as exc: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail={ "error_code": SchemaErrorCodes.SUBJECT_NOT_FOUND.value, "message": SchemaErrorMessages.SUBJECT_NOT_FOUND_FMT.value.format(subject=subject), }, - ) + ) from exc return schema_versions def _invalid_version(self, version: str | int) -> HTTPException: @@ -117,16 +117,16 @@ async def compatibility_check( """Check for schema compatibility""" try: compatibility_mode = self.schema_registry.get_compatibility_mode(subject=subject) - except ValueError as ex: + except ValueError as exc: # Using INTERNAL_SERVER_ERROR because the subject and configuration # should have been validated before. raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail={ "error_code": SchemaErrorCodes.HTTP_INTERNAL_SERVER_ERROR.value, - "message": str(ex), + "message": str(exc), }, - ) + ) from exc new_schema = self.get_new_schema(schema_request=schema_request) old_schema = self.get_old_schema(subject, Versioner.V(version)) # , content_type) @@ -186,14 +186,14 @@ async def schemas_get( ) -> SchemasResponse: try: parsed_schema_id = SchemaId(int(schema_id)) - except ValueError: + except ValueError as exc: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail={ "error_code": SchemaErrorCodes.HTTP_NOT_FOUND.value, "message": "HTTP 404 Not Found", }, - ) + ) from exc def _has_subject_with_id() -> bool: # Fast path @@ -263,14 +263,14 @@ async def schemas_get_versions( ) -> list[SubjectVersion]: try: schema_id_int = SchemaId(int(schema_id)) - except ValueError: + except ValueError as exc: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail={ "error_code": SchemaErrorCodes.HTTP_NOT_FOUND.value, "message": "HTTP 404 Not Found", }, - ) + ) from exc subject_versions = [] for subject_version in self.schema_registry.get_subject_versions_for_schema(schema_id_int, include_deleted=deleted): @@ -301,14 +301,14 @@ async def config_set( ) -> CompatibilityResponse: try: compatibility_level = CompatibilityModes(compatibility_level_request.compatibility) - except (ValueError, KeyError): + except (ValueError, KeyError) as exc: raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail={ "error_code": SchemaErrorCodes.INVALID_COMPATIBILITY_LEVEL.value, "message": SchemaErrorMessages.INVALID_COMPATIBILITY_LEVEL.value, }, - ) + ) from exc self.schema_registry.send_config_message(compatibility_level=compatibility_level, subject=None) return CompatibilityResponse(compatibility=self.schema_registry.schema_reader.config.compatibility) @@ -354,14 +354,14 @@ async def config_subject_set( ) -> CompatibilityResponse: try: compatibility_level = CompatibilityModes(compatibility_level_request.compatibility) - except (ValueError, KeyError): + except (ValueError, KeyError) as exc: raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail={ "error_code": SchemaErrorCodes.INVALID_COMPATIBILITY_LEVEL.value, "message": "Invalid compatibility level", }, - ) + ) from exc self.schema_registry.send_config_message(compatibility_level=compatibility_level, subject=Subject(subject)) return CompatibilityResponse(compatibility=compatibility_level.value) @@ -400,42 +400,42 @@ async def subject_delete( try: version_list = await self.schema_registry.subject_delete_local(subject=Subject(subject), permanent=permanent) return [version.value for version in version_list] - except (SubjectNotFoundException, SchemasNotFoundException): + except (SubjectNotFoundException, SchemasNotFoundException) as exc: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail={ "error_code": SchemaErrorCodes.SUBJECT_NOT_FOUND.value, "message": SchemaErrorMessages.SUBJECT_NOT_FOUND_FMT.value.format(subject=subject), }, - ) - except SubjectNotSoftDeletedException: + ) from exc + except SubjectNotSoftDeletedException as exc: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail={ "error_code": SchemaErrorCodes.SUBJECT_NOT_SOFT_DELETED.value, "message": f"Subject '{subject}' was not deleted first before being permanently deleted", }, - ) - except SubjectSoftDeletedException: + ) from exc + except SubjectSoftDeletedException as exc: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail={ "error_code": SchemaErrorCodes.SUBJECT_SOFT_DELETED.value, "message": f"Subject '{subject}' was soft deleted.Set permanent=true to delete permanently", }, - ) + ) from exc - except ReferenceExistsException as arg: + except ReferenceExistsException as exc: raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail={ "error_code": SchemaErrorCodes.REFERENCE_EXISTS.value, "message": ( f"One or more references exist to the schema " - f"{{magic=1,keytype=SCHEMA,subject={subject},version={arg.version}}}." + f"{{magic=1,keytype=SCHEMA,subject={subject},version={exc.version}}}." ), }, - ) + ) from exc async def subject_version_get( self, @@ -456,24 +456,24 @@ async def subject_version_get( schemaType=subject_data.get("schemaType", None), compatibility=None, # Do not return compatibility from this endpoint. ) - except (SubjectNotFoundException, SchemasNotFoundException): + except (SubjectNotFoundException, SchemasNotFoundException) as exc: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail={ "error_code": SchemaErrorCodes.SUBJECT_NOT_FOUND.value, "message": SchemaErrorMessages.SUBJECT_NOT_FOUND_FMT.value.format(subject=subject), }, - ) - except VersionNotFoundException: + ) from exc + except VersionNotFoundException as exc: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail={ "error_code": SchemaErrorCodes.VERSION_NOT_FOUND.value, "message": f"Version {version} not found.", }, - ) - except InvalidVersion: - raise self._invalid_version(version) + ) from exc + except InvalidVersion as exc: + raise self._invalid_version(version) from exc async def subject_version_delete( self, @@ -487,23 +487,23 @@ async def subject_version_delete( Subject(subject), Versioner.V(version), permanent ) return resolved_version.value - except (SubjectNotFoundException, SchemasNotFoundException): + except (SubjectNotFoundException, SchemasNotFoundException) as exc: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail={ "error_code": SchemaErrorCodes.SUBJECT_NOT_FOUND.value, "message": SchemaErrorMessages.SUBJECT_NOT_FOUND_FMT.value.format(subject=subject), }, - ) - except VersionNotFoundException: + ) from exc + except VersionNotFoundException as exc: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail={ "error_code": SchemaErrorCodes.VERSION_NOT_FOUND.value, "message": f"Version {version} not found.", }, - ) - except SchemaVersionSoftDeletedException: + ) from exc + except SchemaVersionSoftDeletedException as exc: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail={ @@ -513,8 +513,8 @@ async def subject_version_delete( "Set permanent=true to delete permanently" ), }, - ) - except SchemaVersionNotSoftDeletedException: + ) from exc + except SchemaVersionNotSoftDeletedException as exc: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail={ @@ -523,20 +523,20 @@ async def subject_version_delete( f"Subject '{subject}' Version {version} was not deleted " "first before being permanently deleted" ), }, - ) - except ReferenceExistsException as arg: + ) from exc + except ReferenceExistsException as exc: raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail={ "error_code": SchemaErrorCodes.REFERENCE_EXISTS.value, "message": ( f"One or more references exist to the schema " - f"{{magic=1,keytype=SCHEMA,subject={subject},version={arg.version}}}." + f"{{magic=1,keytype=SCHEMA,subject={subject},version={exc.version}}}." ), }, - ) - except InvalidVersion: - self._invalid_version(version) + ) from exc + except InvalidVersion as exc: + raise self._invalid_version(version) from exc async def subject_version_schema_get( self, @@ -547,24 +547,24 @@ async def subject_version_schema_get( try: subject_data = self.schema_registry.subject_version_get(Subject(subject), Versioner.V(version)) return json.loads(cast(str, subject_data["schema"])) # TODO typing - except InvalidVersion: - raise self._invalid_version(version) - except VersionNotFoundException: + except InvalidVersion as exc: + raise self._invalid_version(version) from exc + except VersionNotFoundException as exc: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail={ "error_code": SchemaErrorCodes.VERSION_NOT_FOUND.value, "message": f"Version {version} not found.", }, - ) - except (SchemasNotFoundException, SubjectNotFoundException): + ) from exc + except (SchemasNotFoundException, SubjectNotFoundException) as exc: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail={ "error_code": SchemaErrorCodes.SUBJECT_NOT_FOUND.value, "message": SchemaErrorMessages.SUBJECT_NOT_FOUND_FMT.value.format(subject=subject), }, - ) + ) from exc async def subject_version_referencedby_get( self, @@ -577,24 +577,24 @@ async def subject_version_referencedby_get( referenced_by = await self.schema_registry.subject_version_referencedby_get( Subject(subject), Versioner.V(version) ) - except (SubjectNotFoundException, SchemasNotFoundException): + except (SubjectNotFoundException, SchemasNotFoundException) as exc: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail={ "error_code": SchemaErrorCodes.SUBJECT_NOT_FOUND.value, "message": SchemaErrorMessages.SUBJECT_NOT_FOUND_FMT.value.format(subject=subject), }, - ) - except VersionNotFoundException: + ) from exc + except VersionNotFoundException as exc: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail={ "error_code": SchemaErrorCodes.VERSION_NOT_FOUND.value, "message": f"Version {version} not found.", }, - ) - except InvalidVersion: - raise self._invalid_version(version) + ) from exc + except InvalidVersion as exc: + raise self._invalid_version(version) from exc return referenced_by @@ -608,14 +608,14 @@ async def subject_versions_list( schema_versions = self.schema_registry.subject_get(Subject(subject), include_deleted=deleted) version_list = [version.value for version in schema_versions] return version_list - except (SubjectNotFoundException, SchemasNotFoundException): + except (SubjectNotFoundException, SchemasNotFoundException) as exc: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail={ "error_code": SchemaErrorCodes.SUBJECT_NOT_FOUND.value, "message": SchemaErrorMessages.SUBJECT_NOT_FOUND_FMT.value.format(subject=subject), }, - ) + ) from exc def _validate_schema_type(self, data: JsonData) -> SchemaType: # TODO: simplify the calling code, this functionality should not be required @@ -631,14 +631,14 @@ def _validate_schema_type(self, data: JsonData) -> SchemaType: schema_type_unparsed = data.get("schemaType", SchemaType.AVRO.value) try: schema_type = SchemaType(schema_type_unparsed) - except ValueError: + except ValueError as exc: raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail={ "error_code": SchemaErrorCodes.HTTP_UNPROCESSABLE_ENTITY.value, "message": f"Invalid schemaType {schema_type_unparsed}", }, - ) + ) from exc return schema_type def _validate_references( @@ -692,14 +692,14 @@ async def subjects_schema_post( ) -> SchemaResponse: try: subject_data = self._subject_get(subject, include_deleted=deleted) - except (SchemasNotFoundException, SubjectNotFoundException): + except (SchemasNotFoundException, SubjectNotFoundException) as exc: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail={ "error_code": SchemaErrorCodes.SUBJECT_NOT_FOUND.value, "message": SchemaErrorMessages.SUBJECT_NOT_FOUND_FMT.value.format(subject=subject), }, - ) + ) from exc references = None new_schema_dependencies = None references = self._validate_references(schema_request) @@ -717,7 +717,7 @@ async def subjects_schema_post( normalize=normalize, use_protobuf_formatter=self.config.use_protobuf_formatter, ) - except InvalidSchema: + except InvalidSchema as exc: LOG.warning("Invalid schema: %r", schema_request.schema_str) raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, @@ -725,8 +725,8 @@ async def subjects_schema_post( "error_code": SchemaErrorCodes.INVALID_SCHEMA.value, "message": f"Error while looking up schema under subject {subject}", }, - ) - except InvalidReferences: + ) from exc + except InvalidReferences as exc: human_error = "Provided references is not valid" raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, @@ -734,7 +734,7 @@ async def subjects_schema_post( "error_code": SchemaErrorCodes.INVALID_SCHEMA.value, "message": f"Invalid {schema_request.schema_type} references. Error: {human_error}", }, - ) + ) from exc # Match schemas based on version from latest to oldest for schema_version in sorted(subject_data.values(), key=lambda item: item.version, reverse=True): @@ -747,11 +747,11 @@ async def subjects_schema_post( dependencies=other_dependencies, normalize=normalize, ) - except InvalidSchema as e: + except InvalidSchema as exc: failed_schema_id = schema_version.schema_id LOG.exception("Existing schema failed to parse. Id: %s", failed_schema_id) self.stats.unexpected_exception( - ex=e, where="Matching existing schemas to posted. Failed schema id: {failed_schema_id}" + ex=exc, where="Matching existing schemas to posted. Failed schema id: {failed_schema_id}" ) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -759,7 +759,7 @@ async def subjects_schema_post( "error_code": SchemaErrorCodes.HTTP_INTERNAL_SERVER_ERROR.value, "message": f"Error while looking up schema under subject {subject}", }, - ) + ) from exc if schema_request.schema_type is SchemaType.JSONSCHEMA: schema_valid = parsed_typed_schema.to_dict() == new_schema.to_dict() @@ -776,8 +776,7 @@ async def subjects_schema_post( schema=parsed_typed_schema.schema_str, schemaType=schema_type, ) - else: - LOG.debug("Schema %r did not match %r", schema_version, parsed_typed_schema) + LOG.debug("Schema %r did not match %r", schema_version, parsed_typed_schema) raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -795,7 +794,7 @@ async def subject_post( normalize: bool, forward_client: ForwardClient, request: Request, - ) -> SchemaIdResponse | Response: + ) -> SchemaIdResponse: LOG.debug("POST with subject: %r, request: %r", subject, schema_request) references = self._validate_references(schema_request=schema_request) @@ -810,10 +809,10 @@ async def subject_post( normalize=normalize, use_protobuf_formatter=self.config.use_protobuf_formatter, ) - except (InvalidReferences, InvalidSchema, InvalidSchemaType) as e: + except (InvalidReferences, InvalidSchema, InvalidSchemaType) as exc: LOG.warning("Invalid schema: %r", schema_request.schema_str, exc_info=True) - if isinstance(e.__cause__, (SchemaParseException, JSONDecodeError, ProtobufUnresolvedDependencyException)): - human_error = f"{e.__cause__.args[0]}" # pylint: disable=no-member + if isinstance(exc.__cause__, (SchemaParseException, JSONDecodeError, ProtobufUnresolvedDependencyException)): + human_error = f"{exc.__cause__.args[0]}" # pylint: disable=no-member else: from_body_schema_str = schema_request.schema_str human_error = ( @@ -825,7 +824,7 @@ async def subject_post( "error_code": SchemaErrorCodes.INVALID_SCHEMA.value, "message": f"Invalid {schema_request.schema_type.value} schema. Error: {human_error}", }, - ) + ) from exc schema_id = self.get_schema_id_if_exists(subject=Subject(subject), schema=new_schema, include_deleted=False) if schema_id is not None: @@ -836,37 +835,39 @@ async def subject_post( try: schema_id = await self.schema_registry.write_new_schema_local(Subject(subject), new_schema, references) return SchemaIdResponse(id=schema_id) - except InvalidSchema as ex: + except InvalidSchema as exc: raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail={ "error_code": SchemaErrorCodes.INVALID_SCHEMA.value, - "message": f"Invalid {schema_request.schema_type.value} schema. Error: {str(ex)}", + "message": f"Invalid {schema_request.schema_type.value} schema. Error: {str(exc)}", }, - ) - except IncompatibleSchema as ex: + ) from exc + except IncompatibleSchema as exc: raise HTTPException( status_code=status.HTTP_409_CONFLICT, detail={ "error_code": SchemaErrorCodes.HTTP_CONFLICT.value, - "message": str(ex), + "message": str(exc), }, - ) - except SchemaTooLargeException: + ) from exc + except SchemaTooLargeException as exc: raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail={ "error_code": SchemaErrorCodes.SCHEMA_TOO_LARGE_ERROR_CODE.value, "message": "Schema is too large", }, - ) + ) from exc except Exception as xx: raise xx elif not primary_url: raise no_primary_url_error() else: - return await forward_client.forward_request_remote(request=request, primary_url=primary_url) + return await forward_client.forward_request_remote( + request=request, primary_url=primary_url, response_type=SchemaIdResponse + ) async def get_global_mode(self) -> ModeResponse: return ModeResponse(mode=str(self.schema_registry.get_global_mode())) @@ -903,14 +904,14 @@ def get_new_schema(self, schema_request: SchemaRequest) -> ValidatedTypedSchema: dependencies=new_schema_dependencies, use_protobuf_formatter=self.config.use_protobuf_formatter, ) - except InvalidSchema: + except InvalidSchema as exc: raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail={ "error_code": SchemaErrorCodes.INVALID_SCHEMA.value, "message": f"Invalid {schema_request.schema_type} schema", }, - ) + ) from exc def get_old_schema(self, subject: Subject, version: Version) -> ParsedTypedSchema: old: JsonObject | None = None @@ -918,14 +919,14 @@ def get_old_schema(self, subject: Subject, version: Version) -> ParsedTypedSchem old = self.schema_registry.subject_version_get(subject=subject, version=version) except InvalidVersion: self._invalid_version(version.value) - except (VersionNotFoundException, SchemasNotFoundException, SubjectNotFoundException): + except (VersionNotFoundException, SchemasNotFoundException, SubjectNotFoundException) as exc: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail={ "error_code": SchemaErrorCodes.VERSION_NOT_FOUND.value, "message": f"Version {version} not found.", }, - ) + ) from exc assert old is not None old_schema_type = self._validate_schema_type(data=old) try: @@ -935,11 +936,11 @@ def get_old_schema(self, subject: Subject, version: Version) -> ParsedTypedSchem old_references, old_dependencies = self.schema_registry.resolve_references(old_references) old_schema = ParsedTypedSchema.parse(old_schema_type, old["schema"], old_references, old_dependencies) return old_schema - except InvalidSchema: + except InvalidSchema as exc: raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail={ "error_code": SchemaErrorCodes.INVALID_SCHEMA.value, "message": f"Found an invalid {old_schema_type} schema registered", }, - ) + ) from exc diff --git a/src/schema_registry/user.py b/src/schema_registry/user.py index b3d6919a2..bab8044ac 100644 --- a/src/schema_registry/user.py +++ b/src/schema_registry/user.py @@ -15,7 +15,7 @@ async def get_current_user( credentials: Annotated[HTTPBasicCredentials, Depends(HTTPBasic(auto_error=False))], authorizer: AuthenticatorAndAuthorizer = Depends(Provide[SchemaRegistryContainer.karapace_container.authorizer]), -) -> User: +) -> User | None: if authorizer.MUST_AUTHENTICATE and not credentials: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, diff --git a/tests/integration/test_master_coordinator.py b/tests/integration/test_master_coordinator.py index b4fe8088f..876a91d66 100644 --- a/tests/integration/test_master_coordinator.py +++ b/tests/integration/test_master_coordinator.py @@ -190,7 +190,6 @@ async def test_no_eligible_master(kafka_servers: KafkaServers) -> None: await mc.close() -@pytest.mark.skip(reason="requires master forwarding to be implemented") async def test_schema_request_forwarding( registry_async_pair, registry_async_retry_client: RetryRestClient, diff --git a/tests/integration/test_schema_registry_auth.py b/tests/integration/test_schema_registry_auth.py index 7624fffb6..89832355f 100644 --- a/tests/integration/test_schema_registry_auth.py +++ b/tests/integration/test_schema_registry_auth.py @@ -19,7 +19,6 @@ import aiohttp import asyncio -import pytest NEW_TOPIC_TIMEOUT = 10 @@ -204,7 +203,6 @@ async def test_sr_ids(registry_async_retry_client_auth: RetryRestClient) -> None assert res.status_code == 200 -@pytest.mark.skip(reason="requires master forwarding to be implemented") async def test_sr_auth_forwarding( registry_async_auth_pair: list[str], registry_async_retry_client_auth: RetryRestClient ) -> None: diff --git a/tests/integration/utils/cluster.py b/tests/integration/utils/cluster.py index 67b2c97ea..e574cefa2 100644 --- a/tests/integration/utils/cluster.py +++ b/tests/integration/utils/cluster.py @@ -81,15 +81,15 @@ async def start_schema_registry_cluster( } process = popen_karapace_all(module="schema_registry", env=env, stdout=logfile, stderr=errfile) stack.callback(stop_process, process) - all_processes.append((process, port)) + all_processes.append((process, port, config.host)) protocol = "http" if config.server_tls_keyfile is None else "https" endpoint = RegistryEndpoint(protocol, config.host, port) description = RegistryDescription(endpoint, schemas_topic) all_registries.append(description) - for process, port in all_processes: - wait_for_port_subprocess(port, process, hostname=config.host, wait_time=120) + for process, port, host in all_processes: + wait_for_port_subprocess(port, process, hostname=host, wait_time=120) yield all_registries diff --git a/tests/unit/test_forwarding_client.py b/tests/unit/test_forwarding_client.py new file mode 100644 index 000000000..744b1e9da --- /dev/null +++ b/tests/unit/test_forwarding_client.py @@ -0,0 +1,153 @@ +""" +karapace - schema registry authentication and authorization tests + +Copyright (c) 2023 Aiven Ltd +See LICENSE for details +""" +from __future__ import annotations + +from dataclasses import dataclass +from fastapi import Request +from fastapi.datastructures import Headers +from karapace.forward_client import ForwardClient +from pydantic import BaseModel +from starlette.datastructures import MutableHeaders +from tests.base_testcase import BaseTestCase +from unittest.mock import AsyncMock, Mock, patch + +import pytest + + +class TestResponse(BaseModel): + number: int + string: str + + +@dataclass +class ContentTypeTestCase(BaseTestCase): + content_type: str + + +@pytest.mark.parametrize( + "testcase", + [ + ContentTypeTestCase(test_name="application/json", content_type="application/json"), + ContentTypeTestCase( + test_name="application/vnd.schemaregistry.v1+json", content_type="application/vnd.schemaregistry.v1+json" + ), + ContentTypeTestCase( + test_name="application/vnd.schemaregistry+json", content_type="application/vnd.schemaregistry+json" + ), + ContentTypeTestCase(test_name="application/octet-stream", content_type="application/octet-stream"), + ], + ids=str, +) +async def test_forward_request_with_basemodel_response(testcase: ContentTypeTestCase) -> None: + forward_client = ForwardClient() + with patch.object(forward_client, "_get_forward_client", autospec=True) as mock_get_forward_client: + mock_request = Mock(spec=Request) + mock_request.method = "GET" + mock_request.headers = Headers() + + mock_aiohttp_session = Mock() + mock_get_forward_client.return_value = mock_aiohttp_session + mock_get_func = Mock() + mock_response_context = AsyncMock + mock_response = AsyncMock() + mock_response_context.call_function = lambda _: mock_response + mock_response.text.return_value = '{"number":10,"string":"important"}' + headers = MutableHeaders() + headers["Content-Type"] = testcase.content_type + mock_response.headers = headers + + async def mock_aenter(_) -> Mock: + return mock_response + + async def mock_aexit(_, __, ___, ____) -> None: + return + + mock_get_func.__aenter__ = mock_aenter + mock_get_func.__aexit__ = mock_aexit + mock_aiohttp_session.get.return_value = mock_get_func + + response = await forward_client.forward_request_remote( + request=mock_request, + primary_url="test-url", + response_type=TestResponse, + ) + + assert response == TestResponse(number=10, string="important") + + +async def test_forward_request_with_integer_list_response() -> None: + forward_client = ForwardClient() + with patch.object(forward_client, "_get_forward_client", autospec=True) as mock_get_forward_client: + mock_request = Mock(spec=Request) + mock_request.method = "GET" + mock_request.headers = Headers() + + mock_aiohttp_session = Mock() + mock_get_forward_client.return_value = mock_aiohttp_session + mock_get_func = Mock() + mock_response_context = AsyncMock + mock_response = AsyncMock() + mock_response_context.call_function = lambda _: mock_response + mock_response.text.return_value = "[1, 2, 3, 10]" + headers = MutableHeaders() + headers["Content-Type"] = "application/json" + mock_response.headers = headers + + async def mock_aenter(_) -> Mock: + return mock_response + + async def mock_aexit(_, __, ___, ____) -> None: + return + + mock_get_func.__aenter__ = mock_aenter + mock_get_func.__aexit__ = mock_aexit + mock_aiohttp_session.get.return_value = mock_get_func + + response = await forward_client.forward_request_remote( + request=mock_request, + primary_url="test-url", + response_type=list[int], + ) + + assert response == [1, 2, 3, 10] + + +async def test_forward_request_with_integer_response() -> None: + forward_client = ForwardClient() + with patch.object(forward_client, "_get_forward_client", autospec=True) as mock_get_forward_client: + mock_request = Mock(spec=Request) + mock_request.method = "GET" + mock_request.headers = Headers() + + mock_aiohttp_session = Mock() + mock_get_forward_client.return_value = mock_aiohttp_session + mock_get_func = Mock() + mock_response_context = AsyncMock + mock_response = AsyncMock() + mock_response_context.call_function = lambda _: mock_response + mock_response.text.return_value = "12" + headers = MutableHeaders() + headers["Content-Type"] = "application/json" + mock_response.headers = headers + + async def mock_aenter(_) -> Mock: + return mock_response + + async def mock_aexit(_, __, ___, ____) -> None: + return + + mock_get_func.__aenter__ = mock_aenter + mock_get_func.__aexit__ = mock_aexit + mock_aiohttp_session.get.return_value = mock_get_func + + response = await forward_client.forward_request_remote( + request=mock_request, + primary_url="test-url", + response_type=int, + ) + + assert response == 12