Skip to content

Commit

Permalink
fix: misc typing fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jjaakola-aiven committed Dec 9, 2024
1 parent 0d6074c commit d7e9684
Show file tree
Hide file tree
Showing 17 changed files with 106 additions and 63 deletions.
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ ignore_errors = True
# dependencies.
# - Write your own stubs. You don't need to write stubs for the whole library,
# only the parts that Karapace is interacting with.
[mypy-accept_types.*]
ignore_missing_imports = True

[mypy-aiokafka.*]
ignore_missing_imports = True

Expand Down
18 changes: 9 additions & 9 deletions src/karapace/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from karapace.config import Config, InvalidConfiguration
from karapace.statsd import StatsClient
from karapace.utils import json_decode, json_encode
from typing import Final, Protocol
from typing import Protocol
from typing_extensions import override, TypedDict
from watchfiles import awatch, Change

Expand Down Expand Up @@ -98,12 +98,12 @@ class AuthData(TypedDict):


class AuthenticateProtocol(Protocol):
def authenticate(self, *, username: str, password: str) -> User:
def authenticate(self, *, username: str, password: str) -> User | None:
...


class AuthorizeProtocol(Protocol):
def get_user(self, username: str) -> User:
def get_user(self, username: str) -> User | None:
...

def check_authorization(self, user: User | None, operation: Operation, resource: str) -> bool:
Expand All @@ -114,7 +114,7 @@ def check_authorization_any(self, user: User | None, operation: Operation, resou


class AuthenticatorAndAuthorizer(AuthenticateProtocol, AuthorizeProtocol):
MUST_AUTHENTICATE: Final[bool] = True
MUST_AUTHENTICATE: bool = True

async def close(self) -> None:
...
Expand All @@ -124,14 +124,14 @@ async def start(self, stats: StatsClient) -> None: # pylint: disable=unused-arg


class NoAuthAndAuthz(AuthenticatorAndAuthorizer):
MUST_AUTHENTICATE: Final[bool] = False
MUST_AUTHENTICATE: bool = False

@override
def authenticate(self, *, username: str, password: str) -> User:
def authenticate(self, *, username: str, password: str) -> User | None:
return None

@override
def get_user(self, username: str) -> User:
def get_user(self, username: str) -> User | None:
return None

@override
Expand All @@ -156,7 +156,7 @@ def __init__(self, *, user_db: dict[str, User] | None = None, permissions: list[
self.user_db = user_db or {}
self.permissions = permissions or []

def get_user(self, username: str) -> User:
def get_user(self, username: str) -> User | None:
user = self.user_db.get(username)
if not user:
raise ValueError("No user found")
Expand Down Expand Up @@ -289,7 +289,7 @@ def _load_authfile(self) -> None:
raise InvalidConfiguration("Failed to load auth file") from ex

@override
def authenticate(self, *, username: str, password: str) -> User:
def authenticate(self, *, username: str, password: str) -> User | None:
user = self.get_user(username)
if user is None or not user.compare_password(password):
raise AuthenticationError()
Expand Down
63 changes: 48 additions & 15 deletions src/karapace/forward_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
See LICENSE for details
"""

from fastapi import Request, Response
from fastapi.responses import JSONResponse, PlainTextResponse
from fastapi import HTTPException, Request, status
from karapace.utils import json_decode
from karapace.version import __version__
from pydantic import BaseModel
from typing import overload, TypeVar, Union

import aiohttp
import async_timeout
Expand All @@ -14,16 +16,30 @@
LOG = logging.getLogger(__name__)


T = TypeVar("T", bound=BaseModel)
P = TypeVar("P", bound=Union[int, list[int]])


class ForwardClient:
USER_AGENT = f"Karapace/{__version__}"

def __init__(self):
def __init__(self) -> None:
self._forward_client: aiohttp.ClientSession | None = None

def _get_forward_client(self) -> aiohttp.ClientSession:
return aiohttp.ClientSession(headers={"User-Agent": ForwardClient.USER_AGENT})

async def forward_request_remote(self, *, request: Request, primary_url: str) -> Response:
def _acceptable_response_content_type(self, *, content_type: str) -> bool:
return (
content_type.startswith("application/") and content_type.endswith("json")
) or content_type == "application/octet-stream"

async def _forward_request_remote(
self,
*,
request: Request,
primary_url: str,
) -> bytes:
LOG.info("Forwarding %s request to remote url: %r since we're not the master", request.method, request.url)
timeout = 60.0
headers = request.headers.mutablecopy()
Expand All @@ -37,17 +53,34 @@ async def forward_request_remote(self, *, request: Request, primary_url: str) ->
forward_url = f"{forward_url}?{request.url.query}"
logging.error(forward_url)

with async_timeout.timeout(timeout):
async with async_timeout.timeout(timeout):
body_data = await request.body()
async with func(forward_url, headers=headers, data=body_data) as response:
if response.headers.get("Content-Type", "").startswith("application/json"):
return JSONResponse(
content=await response.text(),
status_code=response.status,
headers=response.headers,
)
return PlainTextResponse(
content=await response.text(),
status_code=response.status,
headers=response.headers,
if self._acceptable_response_content_type(content_type=response.headers.get("Content-Type")):
return await response.text()
LOG.error("Unknown response for forwarded request: %s", response)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"error_code": status.HTTP_500_INTERNAL_SERVER_ERROR,
"message": "Unknown response for forwarded request.",
},
)

@overload
async def forward_request_remote(self, *, request: Request, primary_url: str, response_type: type[T]) -> T:
...

@overload
async def forward_request_remote(self, *, request: Request, primary_url: str, response_type: type[P]) -> P:
...

async def forward_request_remote(self, *, request: Request, primary_url: str, response_type: type[T] | type[P]) -> T | P:
body = await self._forward_request_remote(request=request, primary_url=primary_url)
if response_type == int:
return int(body) # type: ignore[return-value]
if response_type == list[int]:
return json_decode(body, assume_type=list[int]) # type: ignore[return-value]
if issubclass(response_type, BaseModel):
return response_type.model_validate_json(body) # type: ignore[return-value]
raise ValueError("Did not match any expected type")
4 changes: 2 additions & 2 deletions src/karapace/statsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from collections.abc import Iterator
from contextlib import contextmanager
from karapace.config import Config
from karapace.config import Config, KarapaceTags
from karapace.sentry import get_sentry_client
from typing import Any, Final

Expand All @@ -28,7 +28,7 @@ class StatsClient:
def __init__(self, config: Config) -> None:
self._dest_addr: Final = (config.statsd_host, config.statsd_port)
self._socket: Final = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self._tags: Final = config.tags or {}
self._tags: Final[KarapaceTags] = config.tags
self.sentry_client: Final = get_sentry_client(sentry_config=(config.sentry or None))

@contextmanager
Expand Down
6 changes: 3 additions & 3 deletions src/karapace/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from collections.abc import Generator, Mapping, Sequence
from enum import Enum, unique
from karapace.errors import InvalidVersion
from pydantic import ValidationInfo
from typing import Any, ClassVar, NewType, Union
from typing import Any, Callable, ClassVar, NewType, Union
from typing_extensions import TypeAlias

import functools
Expand Down Expand Up @@ -38,7 +38,7 @@ class Subject(str):
@classmethod
# TODO[pydantic]: We couldn't refactor `__get_validators__`, please create the `__get_pydantic_core_schema__` manually.
# Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information.
def __get_validators__(cls):
def __get_validators__(cls) -> Generator[Callable[[str, ValidationInfo], str], None, None]:
yield cls.validate

@classmethod
Expand Down
11 changes: 9 additions & 2 deletions src/schema_registry/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from schema_registry.http_handlers import setup_exception_handlers
from schema_registry.middlewares import setup_middlewares
from schema_registry.routers.setup import setup_routers
from typing import AsyncContextManager, Callable

import logging

Expand Down Expand Up @@ -44,12 +45,18 @@ async def karapace_schema_registry_lifespan(
stastd.close()


def create_karapace_application(*, config: Config, lifespan: AsyncGenerator[None, None]) -> FastAPI:
def create_karapace_application(
*,
config: Config,
lifespan: Callable[
[FastAPI, StatsClient, KarapaceSchemaRegistry, AuthenticatorAndAuthorizer], AsyncContextManager[None]
],
) -> FastAPI:
configure_logging(config=config)
log_config_without_secrets(config=config)
logging.info("Starting Karapace Schema Registry (%s)", karapace_version.__version__)

app = FastAPI(lifespan=lifespan)
app = FastAPI(lifespan=lifespan) # type: ignore[arg-type]
setup_routers(app=app)
setup_exception_handlers(app=app)
setup_middlewares(app=app)
Expand Down
4 changes: 2 additions & 2 deletions src/schema_registry/http_handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

def setup_exception_handlers(app: FastAPI) -> None:
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(_: StarletteHTTPRequest, exc: StarletteHTTPException):
async def http_exception_handler(_: StarletteHTTPRequest, exc: StarletteHTTPException) -> JSONResponse:
return JSONResponse(status_code=exc.status_code, content=exc.detail)

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_: StarletteHTTPRequest, exc: RequestValidationError):
async def validation_exception_handler(_: StarletteHTTPRequest, exc: RequestValidationError) -> JSONResponse:
error_code = HTTPStatus.UNPROCESSABLE_ENTITY.value
if isinstance(exc, KarapaceValidationError):
error_code = exc.error_code
Expand Down
6 changes: 4 additions & 2 deletions src/schema_registry/middlewares/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@
See LICENSE for details
"""

from fastapi import FastAPI, HTTPException, Request
from collections.abc import Awaitable
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.responses import JSONResponse
from karapace.content_type import check_schema_headers
from typing import Callable


def setup_middlewares(app: FastAPI) -> None:
@app.middleware("http")
async def set_content_types(request: Request, call_next):
async def set_content_types(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
try:
response_content_type = check_schema_headers(request)
except HTTPException as exc:
Expand Down
12 changes: 9 additions & 3 deletions src/schema_registry/routers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ async def config_put(
return await controller.config_set(compatibility_level_request=compatibility_level_request)
if not primary_url:
raise no_primary_url_error()
return await forward_client.forward_request_remote(request=request, primary_url=primary_url)
return await forward_client.forward_request_remote(
request=request, primary_url=primary_url, response_type=CompatibilityResponse
)


@config_router.get("/{subject}")
Expand Down Expand Up @@ -93,7 +95,9 @@ async def config_set_subject(
return await controller.config_subject_set(subject=subject, compatibility_level_request=compatibility_level_request)
if not primary_url:
raise no_primary_url_error()
return await forward_client.forward_request_remote(request=request, primary_url=primary_url)
return await forward_client.forward_request_remote(
request=request, primary_url=primary_url, response_type=CompatibilityResponse
)


@config_router.delete("/{subject}")
Expand All @@ -115,4 +119,6 @@ async def config_delete_subject(
return await controller.config_subject_delete(subject=subject)
if not primary_url:
raise no_primary_url_error()
return await forward_client.forward_request_remote(request=request, primary_url=primary_url)
return await forward_client.forward_request_remote(
request=request, primary_url=primary_url, response_type=CompatibilityResponse
)
17 changes: 5 additions & 12 deletions src/schema_registry/routers/master_availability.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,14 @@
"""

from dependency_injector.wiring import inject, Provide
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.responses import JSONResponse
from fastapi import APIRouter, Depends, Request, Response
from karapace.config import Config
from karapace.forward_client import ForwardClient
from karapace.schema_registry import KarapaceSchemaRegistry
from pydantic import BaseModel
from schema_registry.container import SchemaRegistryContainer
from typing import Final

import json
import logging

LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -47,20 +45,15 @@ async def master_availability(
response.headers.update(NO_CACHE_HEADER)

if (
schema_registry.schema_reader.master_coordinator._sc is not None # pylint: disable=protected-access
schema_registry.schema_reader.master_coordinator is not None
and schema_registry.schema_reader.master_coordinator._sc is not None # pylint: disable=protected-access
and schema_registry.schema_reader.master_coordinator._sc.is_master_assigned_to_myself() # pylint: disable=protected-access
):
return MasterAvailabilityResponse(master_available=are_we_master)

if master_url is None or f"{config.advertised_hostname}:{config.advertised_port}" in master_url:
return NO_MASTER

forward_response = await forward_client.forward_request_remote(request=request, primary_url=master_url)
if isinstance(response, JSONResponse):
response_json = json.loads(forward_response.body)
return MasterAvailabilityResponse(master_available=response_json["master_availability"])

raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=forward_response.body,
return await forward_client.forward_request_remote(
request=request, primary_url=master_url, response_type=MasterAvailabilityResponse
)
3 changes: 1 addition & 2 deletions src/schema_registry/routers/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from dependency_injector.wiring import inject, Provide
from fastapi import APIRouter, Depends, Response
from karapace.instrumentation.prometheus import PrometheusInstrumentation
from pydantic import BaseModel
from schema_registry.container import SchemaRegistryContainer

metrics_router = APIRouter(
Expand All @@ -20,5 +19,5 @@
@inject
async def metrics(
prometheus: PrometheusInstrumentation = Depends(Provide[SchemaRegistryContainer.karapace_container.prometheus]),
) -> BaseModel:
) -> Response:
return Response(content=await prometheus.serve_metrics(), media_type=prometheus.CONTENT_TYPE_LATEST)
5 changes: 3 additions & 2 deletions src/schema_registry/routers/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from karapace.typing import Subject
from schema_registry.container import SchemaRegistryContainer
from schema_registry.routers.errors import unauthorized
from schema_registry.routers.requests import ModeResponse
from schema_registry.schema_registry_apis import KarapaceSchemaRegistryController
from schema_registry.user import get_current_user
from typing import Annotated
Expand All @@ -26,7 +27,7 @@ async def mode_get(
user: Annotated[User, Depends(get_current_user)],
authorizer: AuthenticatorAndAuthorizer = Depends(Provide[SchemaRegistryContainer.karapace_container.authorizer]),
controller: KarapaceSchemaRegistryController = Depends(Provide[SchemaRegistryContainer.schema_registry_controller]),
):
) -> ModeResponse:
if authorizer and not authorizer.check_authorization(user, Operation.Read, "Config:"):
raise unauthorized()

Expand All @@ -40,7 +41,7 @@ async def mode_get_subject(
user: Annotated[User, Depends(get_current_user)],
authorizer: AuthenticatorAndAuthorizer = Depends(Provide[SchemaRegistryContainer.karapace_container.authorizer]),
controller: KarapaceSchemaRegistryController = Depends(Provide[SchemaRegistryContainer.schema_registry_controller]),
):
) -> ModeResponse:
if authorizer and not authorizer.check_authorization(user, Operation.Read, f"Subject:{subject}"):
raise unauthorized()

Expand Down
Loading

0 comments on commit d7e9684

Please sign in to comment.