diff --git a/mypy.ini b/mypy.ini index 981e4061c..4c8aacaca 100644 --- a/mypy.ini +++ b/mypy.ini @@ -77,6 +77,9 @@ ignore_errors = True # dependencies. # - Write your own stubs. You don't need to write stubs for the whole library, # only the parts that Karapace is interacting with. +[mypy-accept_types.*] +ignore_missing_imports = True + [mypy-aiokafka.*] ignore_missing_imports = True diff --git a/src/karapace/auth.py b/src/karapace/auth.py index 0e7e3096a..607ab0ea3 100644 --- a/src/karapace/auth.py +++ b/src/karapace/auth.py @@ -11,7 +11,7 @@ from karapace.config import Config, InvalidConfiguration from karapace.statsd import StatsClient from karapace.utils import json_decode, json_encode -from typing import Final, Protocol +from typing import Protocol from typing_extensions import override, TypedDict from watchfiles import awatch, Change @@ -98,12 +98,12 @@ class AuthData(TypedDict): class AuthenticateProtocol(Protocol): - def authenticate(self, *, username: str, password: str) -> User: + def authenticate(self, *, username: str, password: str) -> User | None: ... class AuthorizeProtocol(Protocol): - def get_user(self, username: str) -> User: + def get_user(self, username: str) -> User | None: ... def check_authorization(self, user: User | None, operation: Operation, resource: str) -> bool: @@ -114,7 +114,7 @@ def check_authorization_any(self, user: User | None, operation: Operation, resou class AuthenticatorAndAuthorizer(AuthenticateProtocol, AuthorizeProtocol): - MUST_AUTHENTICATE: Final[bool] = True + MUST_AUTHENTICATE: bool = True async def close(self) -> None: ... @@ -124,14 +124,14 @@ async def start(self, stats: StatsClient) -> None: # pylint: disable=unused-arg class NoAuthAndAuthz(AuthenticatorAndAuthorizer): - MUST_AUTHENTICATE: Final[bool] = False + MUST_AUTHENTICATE: bool = False @override - def authenticate(self, *, username: str, password: str) -> User: + def authenticate(self, *, username: str, password: str) -> User | None: return None @override - def get_user(self, username: str) -> User: + def get_user(self, username: str) -> User | None: return None @override @@ -156,7 +156,7 @@ def __init__(self, *, user_db: dict[str, User] | None = None, permissions: list[ self.user_db = user_db or {} self.permissions = permissions or [] - def get_user(self, username: str) -> User: + def get_user(self, username: str) -> User | None: user = self.user_db.get(username) if not user: raise ValueError("No user found") @@ -289,7 +289,7 @@ def _load_authfile(self) -> None: raise InvalidConfiguration("Failed to load auth file") from ex @override - def authenticate(self, *, username: str, password: str) -> User: + def authenticate(self, *, username: str, password: str) -> User | None: user = self.get_user(username) if user is None or not user.compare_password(password): raise AuthenticationError() diff --git a/src/karapace/forward_client.py b/src/karapace/forward_client.py index c24420b4f..0e33ddca8 100644 --- a/src/karapace/forward_client.py +++ b/src/karapace/forward_client.py @@ -3,9 +3,11 @@ See LICENSE for details """ -from fastapi import Request, Response -from fastapi.responses import JSONResponse, PlainTextResponse +from fastapi import HTTPException, Request, status +from karapace.utils import json_decode from karapace.version import __version__ +from pydantic import BaseModel +from typing import overload, TypeVar, Union import aiohttp import async_timeout @@ -14,16 +16,30 @@ LOG = logging.getLogger(__name__) +T = TypeVar("T", bound=BaseModel) +P = TypeVar("P", bound=Union[int, list[int]]) + + class ForwardClient: USER_AGENT = f"Karapace/{__version__}" - def __init__(self): + def __init__(self) -> None: self._forward_client: aiohttp.ClientSession | None = None def _get_forward_client(self) -> aiohttp.ClientSession: return aiohttp.ClientSession(headers={"User-Agent": ForwardClient.USER_AGENT}) - async def forward_request_remote(self, *, request: Request, primary_url: str) -> Response: + def _acceptable_response_content_type(self, *, content_type: str) -> bool: + return ( + content_type.startswith("application/") and content_type.endswith("json") + ) or content_type == "application/octet-stream" + + async def _forward_request_remote( + self, + *, + request: Request, + primary_url: str, + ) -> bytes: LOG.info("Forwarding %s request to remote url: %r since we're not the master", request.method, request.url) timeout = 60.0 headers = request.headers.mutablecopy() @@ -37,17 +53,34 @@ async def forward_request_remote(self, *, request: Request, primary_url: str) -> forward_url = f"{forward_url}?{request.url.query}" logging.error(forward_url) - with async_timeout.timeout(timeout): + async with async_timeout.timeout(timeout): body_data = await request.body() async with func(forward_url, headers=headers, data=body_data) as response: - if response.headers.get("Content-Type", "").startswith("application/json"): - return JSONResponse( - content=await response.text(), - status_code=response.status, - headers=response.headers, - ) - return PlainTextResponse( - content=await response.text(), - status_code=response.status, - headers=response.headers, + if self._acceptable_response_content_type(content_type=response.headers.get("Content-Type")): + return await response.text() + LOG.error("Unknown response for forwarded request: %s", response) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={ + "error_code": status.HTTP_500_INTERNAL_SERVER_ERROR, + "message": "Unknown response for forwarded request.", + }, ) + + @overload + async def forward_request_remote(self, *, request: Request, primary_url: str, response_type: type[T]) -> T: + ... + + @overload + async def forward_request_remote(self, *, request: Request, primary_url: str, response_type: type[P]) -> P: + ... + + async def forward_request_remote(self, *, request: Request, primary_url: str, response_type: type[T] | type[P]) -> T | P: + body = await self._forward_request_remote(request=request, primary_url=primary_url) + if response_type == int: + return int(body) # type: ignore[return-value] + if response_type == list[int]: + return json_decode(body, assume_type=list[int]) # type: ignore[return-value] + if issubclass(response_type, BaseModel): + return response_type.model_validate_json(body) # type: ignore[return-value] + raise ValueError("Did not match any expected type") diff --git a/src/karapace/statsd.py b/src/karapace/statsd.py index 13b0db0a4..3ef4d001c 100644 --- a/src/karapace/statsd.py +++ b/src/karapace/statsd.py @@ -12,7 +12,7 @@ from collections.abc import Iterator from contextlib import contextmanager -from karapace.config import Config +from karapace.config import Config, KarapaceTags from karapace.sentry import get_sentry_client from typing import Any, Final @@ -28,7 +28,7 @@ class StatsClient: def __init__(self, config: Config) -> None: self._dest_addr: Final = (config.statsd_host, config.statsd_port) self._socket: Final = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - self._tags: Final = config.tags or {} + self._tags: Final[KarapaceTags] = config.tags self.sentry_client: Final = get_sentry_client(sentry_config=(config.sentry or None)) @contextmanager diff --git a/src/karapace/typing.py b/src/karapace/typing.py index a205ae9de..7927657eb 100644 --- a/src/karapace/typing.py +++ b/src/karapace/typing.py @@ -5,11 +5,11 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import Mapping, Sequence +from collections.abc import Generator, Mapping, Sequence from enum import Enum, unique from karapace.errors import InvalidVersion from pydantic import ValidationInfo -from typing import Any, ClassVar, NewType, Union +from typing import Any, Callable, ClassVar, NewType, Union from typing_extensions import TypeAlias import functools @@ -38,7 +38,7 @@ class Subject(str): @classmethod # TODO[pydantic]: We couldn't refactor `__get_validators__`, please create the `__get_pydantic_core_schema__` manually. # Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information. - def __get_validators__(cls): + def __get_validators__(cls) -> Generator[Callable[[str, ValidationInfo], str], None, None]: yield cls.validate @classmethod diff --git a/src/schema_registry/factory.py b/src/schema_registry/factory.py index 667bdfc7d..ae104b854 100644 --- a/src/schema_registry/factory.py +++ b/src/schema_registry/factory.py @@ -17,6 +17,7 @@ from schema_registry.http_handlers import setup_exception_handlers from schema_registry.middlewares import setup_middlewares from schema_registry.routers.setup import setup_routers +from typing import AsyncContextManager, Callable import logging @@ -44,12 +45,18 @@ async def karapace_schema_registry_lifespan( stastd.close() -def create_karapace_application(*, config: Config, lifespan: AsyncGenerator[None, None]) -> FastAPI: +def create_karapace_application( + *, + config: Config, + lifespan: Callable[ + [FastAPI, StatsClient, KarapaceSchemaRegistry, AuthenticatorAndAuthorizer], AsyncContextManager[None] + ], +) -> FastAPI: configure_logging(config=config) log_config_without_secrets(config=config) logging.info("Starting Karapace Schema Registry (%s)", karapace_version.__version__) - app = FastAPI(lifespan=lifespan) + app = FastAPI(lifespan=lifespan) # type: ignore[arg-type] setup_routers(app=app) setup_exception_handlers(app=app) setup_middlewares(app=app) diff --git a/src/schema_registry/http_handlers/__init__.py b/src/schema_registry/http_handlers/__init__.py index 93bc853cc..e15ec9565 100644 --- a/src/schema_registry/http_handlers/__init__.py +++ b/src/schema_registry/http_handlers/__init__.py @@ -14,11 +14,11 @@ def setup_exception_handlers(app: FastAPI) -> None: @app.exception_handler(StarletteHTTPException) - async def http_exception_handler(_: StarletteHTTPRequest, exc: StarletteHTTPException): + async def http_exception_handler(_: StarletteHTTPRequest, exc: StarletteHTTPException) -> JSONResponse: return JSONResponse(status_code=exc.status_code, content=exc.detail) @app.exception_handler(RequestValidationError) - async def validation_exception_handler(_: StarletteHTTPRequest, exc: RequestValidationError): + async def validation_exception_handler(_: StarletteHTTPRequest, exc: RequestValidationError) -> JSONResponse: error_code = HTTPStatus.UNPROCESSABLE_ENTITY.value if isinstance(exc, KarapaceValidationError): error_code = exc.error_code diff --git a/src/schema_registry/middlewares/__init__.py b/src/schema_registry/middlewares/__init__.py index 952f8abde..7c0559687 100644 --- a/src/schema_registry/middlewares/__init__.py +++ b/src/schema_registry/middlewares/__init__.py @@ -3,14 +3,16 @@ See LICENSE for details """ -from fastapi import FastAPI, HTTPException, Request +from collections.abc import Awaitable +from fastapi import FastAPI, HTTPException, Request, Response from fastapi.responses import JSONResponse from karapace.content_type import check_schema_headers +from typing import Callable def setup_middlewares(app: FastAPI) -> None: @app.middleware("http") - async def set_content_types(request: Request, call_next): + async def set_content_types(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: try: response_content_type = check_schema_headers(request) except HTTPException as exc: diff --git a/src/schema_registry/routers/config.py b/src/schema_registry/routers/config.py index 9e5543e71..1c95ac046 100644 --- a/src/schema_registry/routers/config.py +++ b/src/schema_registry/routers/config.py @@ -55,7 +55,9 @@ async def config_put( return await controller.config_set(compatibility_level_request=compatibility_level_request) if not primary_url: raise no_primary_url_error() - return await forward_client.forward_request_remote(request=request, primary_url=primary_url) + return await forward_client.forward_request_remote( + request=request, primary_url=primary_url, response_type=CompatibilityResponse + ) @config_router.get("/{subject}") @@ -93,7 +95,9 @@ async def config_set_subject( return await controller.config_subject_set(subject=subject, compatibility_level_request=compatibility_level_request) if not primary_url: raise no_primary_url_error() - return await forward_client.forward_request_remote(request=request, primary_url=primary_url) + return await forward_client.forward_request_remote( + request=request, primary_url=primary_url, response_type=CompatibilityResponse + ) @config_router.delete("/{subject}") @@ -115,4 +119,6 @@ async def config_delete_subject( return await controller.config_subject_delete(subject=subject) if not primary_url: raise no_primary_url_error() - return await forward_client.forward_request_remote(request=request, primary_url=primary_url) + return await forward_client.forward_request_remote( + request=request, primary_url=primary_url, response_type=CompatibilityResponse + ) diff --git a/src/schema_registry/routers/master_availability.py b/src/schema_registry/routers/master_availability.py index 55e792275..a3783575a 100644 --- a/src/schema_registry/routers/master_availability.py +++ b/src/schema_registry/routers/master_availability.py @@ -4,8 +4,7 @@ """ from dependency_injector.wiring import inject, Provide -from fastapi import APIRouter, Depends, HTTPException, Request, Response, status -from fastapi.responses import JSONResponse +from fastapi import APIRouter, Depends, Request, Response from karapace.config import Config from karapace.forward_client import ForwardClient from karapace.schema_registry import KarapaceSchemaRegistry @@ -13,7 +12,6 @@ from schema_registry.container import SchemaRegistryContainer from typing import Final -import json import logging LOG = logging.getLogger(__name__) @@ -47,7 +45,8 @@ async def master_availability( response.headers.update(NO_CACHE_HEADER) if ( - schema_registry.schema_reader.master_coordinator._sc is not None # pylint: disable=protected-access + schema_registry.schema_reader.master_coordinator is not None + and schema_registry.schema_reader.master_coordinator._sc is not None # pylint: disable=protected-access and schema_registry.schema_reader.master_coordinator._sc.is_master_assigned_to_myself() # pylint: disable=protected-access ): return MasterAvailabilityResponse(master_available=are_we_master) @@ -55,12 +54,6 @@ async def master_availability( if master_url is None or f"{config.advertised_hostname}:{config.advertised_port}" in master_url: return NO_MASTER - forward_response = await forward_client.forward_request_remote(request=request, primary_url=master_url) - if isinstance(response, JSONResponse): - response_json = json.loads(forward_response.body) - return MasterAvailabilityResponse(master_available=response_json["master_availability"]) - - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=forward_response.body, + return await forward_client.forward_request_remote( + request=request, primary_url=master_url, response_type=MasterAvailabilityResponse ) diff --git a/src/schema_registry/routers/metrics.py b/src/schema_registry/routers/metrics.py index 23b4b39f8..897a1d65d 100644 --- a/src/schema_registry/routers/metrics.py +++ b/src/schema_registry/routers/metrics.py @@ -6,7 +6,6 @@ from dependency_injector.wiring import inject, Provide from fastapi import APIRouter, Depends, Response from karapace.instrumentation.prometheus import PrometheusInstrumentation -from pydantic import BaseModel from schema_registry.container import SchemaRegistryContainer metrics_router = APIRouter( @@ -20,5 +19,5 @@ @inject async def metrics( prometheus: PrometheusInstrumentation = Depends(Provide[SchemaRegistryContainer.karapace_container.prometheus]), -) -> BaseModel: +) -> Response: return Response(content=await prometheus.serve_metrics(), media_type=prometheus.CONTENT_TYPE_LATEST) diff --git a/src/schema_registry/routers/mode.py b/src/schema_registry/routers/mode.py index 870a876d2..c139e8e7d 100644 --- a/src/schema_registry/routers/mode.py +++ b/src/schema_registry/routers/mode.py @@ -9,6 +9,7 @@ from karapace.typing import Subject from schema_registry.container import SchemaRegistryContainer from schema_registry.routers.errors import unauthorized +from schema_registry.routers.requests import ModeResponse from schema_registry.schema_registry_apis import KarapaceSchemaRegistryController from schema_registry.user import get_current_user from typing import Annotated @@ -26,7 +27,7 @@ async def mode_get( user: Annotated[User, Depends(get_current_user)], authorizer: AuthenticatorAndAuthorizer = Depends(Provide[SchemaRegistryContainer.karapace_container.authorizer]), controller: KarapaceSchemaRegistryController = Depends(Provide[SchemaRegistryContainer.schema_registry_controller]), -): +) -> ModeResponse: if authorizer and not authorizer.check_authorization(user, Operation.Read, "Config:"): raise unauthorized() @@ -40,7 +41,7 @@ async def mode_get_subject( user: Annotated[User, Depends(get_current_user)], authorizer: AuthenticatorAndAuthorizer = Depends(Provide[SchemaRegistryContainer.karapace_container.authorizer]), controller: KarapaceSchemaRegistryController = Depends(Provide[SchemaRegistryContainer.schema_registry_controller]), -): +) -> ModeResponse: if authorizer and not authorizer.check_authorization(user, Operation.Read, f"Subject:{subject}"): raise unauthorized() diff --git a/src/schema_registry/routers/subjects.py b/src/schema_registry/routers/subjects.py index a1964cc7a..4d0a9fe94 100644 --- a/src/schema_registry/routers/subjects.py +++ b/src/schema_registry/routers/subjects.py @@ -85,7 +85,7 @@ async def subjects_subject_delete( return await controller.subject_delete(subject=subject, permanent=permanent) if not primary_url: raise no_primary_url_error() - return await forward_client.forward_request_remote(request=request, primary_url=primary_url) + return await forward_client.forward_request_remote(request=request, primary_url=primary_url, response_type=list[int]) @subjects_router.post("/{subject}/versions") @@ -166,7 +166,7 @@ async def subjects_subject_version_delete( return await controller.subject_version_delete(subject=subject, version=version, permanent=permanent) if not primary_url: raise no_primary_url_error() - return await forward_client.forward_request_remote(request=request, primary_url=primary_url) + return await forward_client.forward_request_remote(request=request, primary_url=primary_url, response_type=int) @subjects_router.get("/{subject}/versions/{version}/schema") diff --git a/src/schema_registry/schema_registry_apis.py b/src/schema_registry/schema_registry_apis.py index bb2a12e73..56cd567b1 100644 --- a/src/schema_registry/schema_registry_apis.py +++ b/src/schema_registry/schema_registry_apis.py @@ -6,7 +6,7 @@ from avro.errors import SchemaParseException from dependency_injector.wiring import inject, Provide -from fastapi import Depends, HTTPException, Request, Response, status +from fastapi import Depends, HTTPException, Request, status from karapace.auth import AuthenticatorAndAuthorizer, Operation, User from karapace.compatibility import CompatibilityModes from karapace.compatibility.jsonschema.checks import is_incompatible @@ -794,7 +794,7 @@ async def subject_post( normalize: bool, forward_client: ForwardClient, request: Request, - ) -> SchemaIdResponse | Response: + ) -> SchemaIdResponse: LOG.debug("POST with subject: %r, request: %r", subject, schema_request) references = self._validate_references(schema_request=schema_request) @@ -865,7 +865,9 @@ async def subject_post( elif not primary_url: raise no_primary_url_error() else: - return await forward_client.forward_request_remote(request=request, primary_url=primary_url) + return await forward_client.forward_request_remote( + request=request, primary_url=primary_url, response_type=SchemaIdResponse + ) async def get_global_mode(self) -> ModeResponse: return ModeResponse(mode=str(self.schema_registry.get_global_mode())) diff --git a/src/schema_registry/user.py b/src/schema_registry/user.py index b3d6919a2..bab8044ac 100644 --- a/src/schema_registry/user.py +++ b/src/schema_registry/user.py @@ -15,7 +15,7 @@ async def get_current_user( credentials: Annotated[HTTPBasicCredentials, Depends(HTTPBasic(auto_error=False))], authorizer: AuthenticatorAndAuthorizer = Depends(Provide[SchemaRegistryContainer.karapace_container.authorizer]), -) -> User: +) -> User | None: if authorizer.MUST_AUTHENTICATE and not credentials: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, diff --git a/tests/integration/test_master_coordinator.py b/tests/integration/test_master_coordinator.py index b4fe8088f..876a91d66 100644 --- a/tests/integration/test_master_coordinator.py +++ b/tests/integration/test_master_coordinator.py @@ -190,7 +190,6 @@ async def test_no_eligible_master(kafka_servers: KafkaServers) -> None: await mc.close() -@pytest.mark.skip(reason="requires master forwarding to be implemented") async def test_schema_request_forwarding( registry_async_pair, registry_async_retry_client: RetryRestClient, diff --git a/tests/integration/test_schema_registry_auth.py b/tests/integration/test_schema_registry_auth.py index 7624fffb6..89832355f 100644 --- a/tests/integration/test_schema_registry_auth.py +++ b/tests/integration/test_schema_registry_auth.py @@ -19,7 +19,6 @@ import aiohttp import asyncio -import pytest NEW_TOPIC_TIMEOUT = 10 @@ -204,7 +203,6 @@ async def test_sr_ids(registry_async_retry_client_auth: RetryRestClient) -> None assert res.status_code == 200 -@pytest.mark.skip(reason="requires master forwarding to be implemented") async def test_sr_auth_forwarding( registry_async_auth_pair: list[str], registry_async_retry_client_auth: RetryRestClient ) -> None: