Skip to content

Commit

Permalink
Merge pull request #1004 from Aiven-Open/jjaakola-aiven-fastapi-lint-…
Browse files Browse the repository at this point in the history
…fixes

Linting fixes
  • Loading branch information
nosahama authored Dec 10, 2024
2 parents 552187c + 5f77f09 commit 8024753
Show file tree
Hide file tree
Showing 22 changed files with 380 additions and 174 deletions.
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ ignore_errors = True
# dependencies.
# - Write your own stubs. You don't need to write stubs for the whole library,
# only the parts that Karapace is interacting with.
[mypy-accept_types.*]
ignore_missing_imports = True

[mypy-aiokafka.*]
ignore_missing_imports = True

Expand Down
20 changes: 10 additions & 10 deletions src/karapace/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from karapace.config import Config, InvalidConfiguration
from karapace.statsd import StatsClient
from karapace.utils import json_decode, json_encode
from typing import Final, Protocol
from typing import Protocol
from typing_extensions import override, TypedDict
from watchfiles import awatch, Change

Expand Down Expand Up @@ -98,12 +98,12 @@ class AuthData(TypedDict):


class AuthenticateProtocol(Protocol):
def authenticate(self, *, username: str, password: str) -> User:
def authenticate(self, *, username: str, password: str) -> User | None:
...


class AuthorizeProtocol(Protocol):
def get_user(self, username: str) -> User:
def get_user(self, username: str) -> User | None:
...

def check_authorization(self, user: User | None, operation: Operation, resource: str) -> bool:
Expand All @@ -114,24 +114,24 @@ def check_authorization_any(self, user: User | None, operation: Operation, resou


class AuthenticatorAndAuthorizer(AuthenticateProtocol, AuthorizeProtocol):
MUST_AUTHENTICATE: Final[bool] = True
MUST_AUTHENTICATE: bool = True

async def close(self) -> None:
...

async def start(self, stats: StatsClient) -> None:
async def start(self, stats: StatsClient) -> None: # pylint: disable=unused-argument
...


class NoAuthAndAuthz(AuthenticatorAndAuthorizer):
MUST_AUTHENTICATE: Final[bool] = False
MUST_AUTHENTICATE: bool = False

@override
def authenticate(self, *, username: str, password: str) -> User:
def authenticate(self, *, username: str, password: str) -> User | None:
return None

@override
def get_user(self, username: str) -> User:
def get_user(self, username: str) -> User | None:
return None

@override
Expand All @@ -156,7 +156,7 @@ def __init__(self, *, user_db: dict[str, User] | None = None, permissions: list[
self.user_db = user_db or {}
self.permissions = permissions or []

def get_user(self, username: str) -> User:
def get_user(self, username: str) -> User | None:
user = self.user_db.get(username)
if not user:
raise ValueError("No user found")
Expand Down Expand Up @@ -289,7 +289,7 @@ def _load_authfile(self) -> None:
raise InvalidConfiguration("Failed to load auth file") from ex

@override
def authenticate(self, *, username: str, password: str) -> User:
def authenticate(self, *, username: str, password: str) -> User | None:
user = self.get_user(username)
if user is None or not user.compare_password(password):
raise AuthenticationError()
Expand Down
2 changes: 1 addition & 1 deletion src/karapace/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def get_rest_base_uri(self) -> str:

def to_env_str(self) -> str:
env_lines: list[str] = []
for key, value in self.dict().items():
for key, value in self.model_dump().items():
if value is not None:
env_lines.append(f"{key.upper()}={value}")
return "\n".join(env_lines)
Expand Down
84 changes: 67 additions & 17 deletions src/karapace/forward_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
See LICENSE for details
"""

from fastapi import Request, Response
from fastapi.responses import JSONResponse, PlainTextResponse
from fastapi import HTTPException, Request, status
from karapace.utils import json_decode
from karapace.version import __version__
from pydantic import BaseModel
from typing import overload, TypeVar, Union

import aiohttp
import async_timeout
Expand All @@ -14,16 +16,30 @@
LOG = logging.getLogger(__name__)


BaseModelResponse = TypeVar("BaseModelResponse", bound=BaseModel)
SimpleTypeResponse = TypeVar("SimpleTypeResponse", bound=Union[int, list[int]])


class ForwardClient:
USER_AGENT = f"Karapace/{__version__}"

def __init__(self):
def __init__(self) -> None:
self._forward_client: aiohttp.ClientSession | None = None

def _get_forward_client(self) -> aiohttp.ClientSession:
return aiohttp.ClientSession(headers={"User-Agent": ForwardClient.USER_AGENT})

async def forward_request_remote(self, *, request: Request, primary_url: str) -> Response:
def _acceptable_response_content_type(self, *, content_type: str) -> bool:
return (
content_type.startswith("application/") and content_type.endswith("json")
) or content_type == "application/octet-stream"

async def _forward_request_remote(
self,
*,
request: Request,
primary_url: str,
) -> bytes:
LOG.info("Forwarding %s request to remote url: %r since we're not the master", request.method, request.url)
timeout = 60.0
headers = request.headers.mutablecopy()
Expand All @@ -37,18 +53,52 @@ async def forward_request_remote(self, *, request: Request, primary_url: str) ->
forward_url = f"{forward_url}?{request.url.query}"
logging.error(forward_url)

with async_timeout.timeout(timeout):
async with async_timeout.timeout(timeout):
body_data = await request.body()
async with func(forward_url, headers=headers, data=body_data) as response:
if response.headers.get("Content-Type", "").startswith("application/json"):
return JSONResponse(
content=await response.text(),
status_code=response.status,
headers=response.headers,
)
else:
return PlainTextResponse(
content=await response.text(),
status_code=response.status,
headers=response.headers,
)
if self._acceptable_response_content_type(content_type=response.headers.get("Content-Type")):
return await response.text()
LOG.error("Unknown response for forwarded request: %s", response)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"error_code": status.HTTP_500_INTERNAL_SERVER_ERROR,
"message": "Unknown response for forwarded request.",
},
)

@overload
async def forward_request_remote(
self,
*,
request: Request,
primary_url: str,
response_type: type[BaseModelResponse],
) -> BaseModelResponse:
...

@overload
async def forward_request_remote(
self,
*,
request: Request,
primary_url: str,
response_type: type[SimpleTypeResponse],
) -> SimpleTypeResponse:
...

async def forward_request_remote(
self,
*,
request: Request,
primary_url: str,
response_type: type[BaseModelResponse] | type[SimpleTypeResponse],
) -> BaseModelResponse | SimpleTypeResponse:
body = await self._forward_request_remote(request=request, primary_url=primary_url)
if response_type == int:
return int(body) # type: ignore[return-value]
if response_type == list[int]:
return json_decode(body, assume_type=list[int]) # type: ignore[return-value]
if issubclass(response_type, BaseModel):
return response_type.parse_raw(body) # type: ignore[return-value]
raise ValueError("Did not match any expected type")
3 changes: 1 addition & 2 deletions src/karapace/rapu.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,7 @@ def __init__(
def _create_aiohttp_application(self, *, config: Config) -> aiohttp.web.Application:
if config.http_request_max_size:
return aiohttp.web.Application(client_max_size=config.http_request_max_size)
else:
return aiohttp.web.Application()
return aiohttp.web.Application()

async def close_by_app(self, app: aiohttp.web.Application) -> None: # pylint: disable=unused-argument
await self.close()
Expand Down
4 changes: 2 additions & 2 deletions src/karapace/statsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from collections.abc import Iterator
from contextlib import contextmanager
from karapace.config import Config
from karapace.config import Config, KarapaceTags
from karapace.sentry import get_sentry_client
from typing import Any, Final

Expand All @@ -28,7 +28,7 @@ class StatsClient:
def __init__(self, config: Config) -> None:
self._dest_addr: Final = (config.statsd_host, config.statsd_port)
self._socket: Final = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self._tags: Final = config.tags or {}
self._tags: Final[KarapaceTags] = config.tags
self.sentry_client: Final = get_sentry_client(sentry_config=(config.sentry or None))

@contextmanager
Expand Down
6 changes: 3 additions & 3 deletions src/karapace/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from collections.abc import Generator, Mapping, Sequence
from enum import Enum, unique
from karapace.errors import InvalidVersion
from pydantic import ValidationInfo
from typing import Any, ClassVar, NewType, Union
from typing import Any, Callable, ClassVar, NewType, Union
from typing_extensions import TypeAlias

import functools
Expand Down Expand Up @@ -38,7 +38,7 @@ class Subject(str):
@classmethod
# TODO[pydantic]: We couldn't refactor `__get_validators__`, please create the `__get_pydantic_core_schema__` manually.
# Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information.
def __get_validators__(cls):
def __get_validators__(cls) -> Generator[Callable[[str, ValidationInfo], str], None, None]:
yield cls.validate

@classmethod
Expand Down
11 changes: 9 additions & 2 deletions src/schema_registry/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from schema_registry.http_handlers import setup_exception_handlers
from schema_registry.middlewares import setup_middlewares
from schema_registry.routers.setup import setup_routers
from typing import AsyncContextManager, Callable

import logging

Expand Down Expand Up @@ -44,12 +45,18 @@ async def karapace_schema_registry_lifespan(
stastd.close()


def create_karapace_application(*, config: Config, lifespan: AsyncGenerator[None, None]) -> FastAPI:
def create_karapace_application(
*,
config: Config,
lifespan: Callable[
[FastAPI, StatsClient, KarapaceSchemaRegistry, AuthenticatorAndAuthorizer], AsyncContextManager[None]
],
) -> FastAPI:
configure_logging(config=config)
log_config_without_secrets(config=config)
logging.info("Starting Karapace Schema Registry (%s)", karapace_version.__version__)

app = FastAPI(lifespan=lifespan)
app = FastAPI(lifespan=lifespan) # type: ignore[arg-type]
setup_routers(app=app)
setup_exception_handlers(app=app)
setup_middlewares(app=app)
Expand Down
4 changes: 2 additions & 2 deletions src/schema_registry/http_handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

def setup_exception_handlers(app: FastAPI) -> None:
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(_: StarletteHTTPRequest, exc: StarletteHTTPException):
async def http_exception_handler(_: StarletteHTTPRequest, exc: StarletteHTTPException) -> JSONResponse:
return JSONResponse(status_code=exc.status_code, content=exc.detail)

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_: StarletteHTTPRequest, exc: RequestValidationError):
async def validation_exception_handler(_: StarletteHTTPRequest, exc: RequestValidationError) -> JSONResponse:
error_code = HTTPStatus.UNPROCESSABLE_ENTITY.value
if isinstance(exc, KarapaceValidationError):
error_code = exc.error_code
Expand Down
8 changes: 5 additions & 3 deletions src/schema_registry/middlewares/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@
See LICENSE for details
"""

from fastapi import FastAPI, HTTPException, Request
from collections.abc import Awaitable
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.responses import JSONResponse
from karapace.content_type import check_schema_headers
from typing import Callable


def setup_middlewares(app: FastAPI) -> None:
@app.middleware("http")
async def set_content_types(request: Request, call_next):
async def set_content_types(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
try:
response_content_type = check_schema_headers(request)
except HTTPException as exc:
Expand All @@ -25,7 +27,7 @@ async def set_content_types(request: Request, call_next):
if request.headers.get("Content-Type") == "application/octet-stream":
new_headers = request.headers.mutablecopy()
new_headers["Content-Type"] = "application/json"
request._headers = new_headers
request._headers = new_headers # pylint: disable=protected-access
request.scope.update(headers=request.headers.raw)

response = await call_next(request)
Expand Down
21 changes: 12 additions & 9 deletions src/schema_registry/routers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,11 @@ async def config_put(
i_am_primary, primary_url = await schema_registry.get_master()
if i_am_primary:
return await controller.config_set(compatibility_level_request=compatibility_level_request)
elif not primary_url:
if not primary_url:
raise no_primary_url_error()
else:
return await forward_client.forward_request_remote(request=request, primary_url=primary_url)
return await forward_client.forward_request_remote(
request=request, primary_url=primary_url, response_type=CompatibilityResponse
)


@config_router.get("/{subject}")
Expand Down Expand Up @@ -92,10 +93,11 @@ async def config_set_subject(
i_am_primary, primary_url = await schema_registry.get_master()
if i_am_primary:
return await controller.config_subject_set(subject=subject, compatibility_level_request=compatibility_level_request)
elif not primary_url:
if not primary_url:
raise no_primary_url_error()
else:
return await forward_client.forward_request_remote(request=request, primary_url=primary_url)
return await forward_client.forward_request_remote(
request=request, primary_url=primary_url, response_type=CompatibilityResponse
)


@config_router.delete("/{subject}")
Expand All @@ -115,7 +117,8 @@ async def config_delete_subject(
i_am_primary, primary_url = await schema_registry.get_master()
if i_am_primary:
return await controller.config_subject_delete(subject=subject)
elif not primary_url:
if not primary_url:
raise no_primary_url_error()
else:
return await forward_client.forward_request_remote(request=request, primary_url=primary_url)
return await forward_client.forward_request_remote(
request=request, primary_url=primary_url, response_type=CompatibilityResponse
)
Loading

0 comments on commit 8024753

Please sign in to comment.