Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Linting fixes #1004

Merged
merged 6 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ ignore_errors = True
# dependencies.
# - Write your own stubs. You don't need to write stubs for the whole library,
# only the parts that Karapace is interacting with.
[mypy-accept_types.*]
ignore_missing_imports = True

[mypy-aiokafka.*]
ignore_missing_imports = True

Expand Down
20 changes: 10 additions & 10 deletions src/karapace/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from karapace.config import Config, InvalidConfiguration
from karapace.statsd import StatsClient
from karapace.utils import json_decode, json_encode
from typing import Final, Protocol
from typing import Protocol
from typing_extensions import override, TypedDict
from watchfiles import awatch, Change

Expand Down Expand Up @@ -98,12 +98,12 @@ class AuthData(TypedDict):


class AuthenticateProtocol(Protocol):
def authenticate(self, *, username: str, password: str) -> User:
def authenticate(self, *, username: str, password: str) -> User | None:
...


class AuthorizeProtocol(Protocol):
def get_user(self, username: str) -> User:
def get_user(self, username: str) -> User | None:
...

def check_authorization(self, user: User | None, operation: Operation, resource: str) -> bool:
Expand All @@ -114,24 +114,24 @@ def check_authorization_any(self, user: User | None, operation: Operation, resou


class AuthenticatorAndAuthorizer(AuthenticateProtocol, AuthorizeProtocol):
MUST_AUTHENTICATE: Final[bool] = True
MUST_AUTHENTICATE: bool = True

async def close(self) -> None:
...

async def start(self, stats: StatsClient) -> None:
async def start(self, stats: StatsClient) -> None: # pylint: disable=unused-argument
...


class NoAuthAndAuthz(AuthenticatorAndAuthorizer):
MUST_AUTHENTICATE: Final[bool] = False
MUST_AUTHENTICATE: bool = False

@override
def authenticate(self, *, username: str, password: str) -> User:
def authenticate(self, *, username: str, password: str) -> User | None:
return None

@override
def get_user(self, username: str) -> User:
def get_user(self, username: str) -> User | None:
return None

@override
Expand All @@ -156,7 +156,7 @@ def __init__(self, *, user_db: dict[str, User] | None = None, permissions: list[
self.user_db = user_db or {}
self.permissions = permissions or []

def get_user(self, username: str) -> User:
def get_user(self, username: str) -> User | None:
user = self.user_db.get(username)
if not user:
raise ValueError("No user found")
Expand Down Expand Up @@ -289,7 +289,7 @@ def _load_authfile(self) -> None:
raise InvalidConfiguration("Failed to load auth file") from ex

@override
def authenticate(self, *, username: str, password: str) -> User:
def authenticate(self, *, username: str, password: str) -> User | None:
user = self.get_user(username)
if user is None or not user.compare_password(password):
raise AuthenticationError()
Expand Down
2 changes: 1 addition & 1 deletion src/karapace/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def get_rest_base_uri(self) -> str:

def to_env_str(self) -> str:
env_lines: list[str] = []
for key, value in self.dict().items():
for key, value in self.model_dump().items():
if value is not None:
env_lines.append(f"{key.upper()}={value}")
return "\n".join(env_lines)
Expand Down
84 changes: 67 additions & 17 deletions src/karapace/forward_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
See LICENSE for details
"""

from fastapi import Request, Response
from fastapi.responses import JSONResponse, PlainTextResponse
from fastapi import HTTPException, Request, status
from karapace.utils import json_decode
from karapace.version import __version__
from pydantic import BaseModel
from typing import overload, TypeVar, Union

import aiohttp
import async_timeout
Expand All @@ -14,16 +16,30 @@
LOG = logging.getLogger(__name__)


BaseModelResponse = TypeVar("BaseModelResponse", bound=BaseModel)
SimpleTypeResponse = TypeVar("SimpleTypeResponse", bound=Union[int, list[int]])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need the Union, maybe | works, will test it later.



class ForwardClient:
USER_AGENT = f"Karapace/{__version__}"

def __init__(self):
def __init__(self) -> None:
self._forward_client: aiohttp.ClientSession | None = None

def _get_forward_client(self) -> aiohttp.ClientSession:
return aiohttp.ClientSession(headers={"User-Agent": ForwardClient.USER_AGENT})

async def forward_request_remote(self, *, request: Request, primary_url: str) -> Response:
def _acceptable_response_content_type(self, *, content_type: str) -> bool:
return (
content_type.startswith("application/") and content_type.endswith("json")
) or content_type == "application/octet-stream"

async def _forward_request_remote(
self,
*,
request: Request,
primary_url: str,
) -> bytes:
LOG.info("Forwarding %s request to remote url: %r since we're not the master", request.method, request.url)
timeout = 60.0
headers = request.headers.mutablecopy()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a piece of code for the authorization header, do we want to bring this back?
i.e.

# auth_header = request.headers.get("Authorization")
# if auth_header is not None:
#    headers["Authorization"] = auth_header

Copy link
Contributor

@nosahama nosahama Dec 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's also this piece of code: func = getattr(self._get_forward_client(), request.method.lower()), this would mean that we generate a new client session every time we want to forward a request. Maybe we move the session initialization to the __init__() function, assign directly to self._forward_client and that will get reused for the whole app runtime and we can even add a close() function that cleans up the client session in the application lifespan.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, probably not in this PR though.

Expand All @@ -37,18 +53,52 @@ async def forward_request_remote(self, *, request: Request, primary_url: str) ->
forward_url = f"{forward_url}?{request.url.query}"
logging.error(forward_url)

with async_timeout.timeout(timeout):
async with async_timeout.timeout(timeout):
body_data = await request.body()
async with func(forward_url, headers=headers, data=body_data) as response:
if response.headers.get("Content-Type", "").startswith("application/json"):
return JSONResponse(
content=await response.text(),
status_code=response.status,
headers=response.headers,
)
else:
return PlainTextResponse(
content=await response.text(),
status_code=response.status,
headers=response.headers,
)
if self._acceptable_response_content_type(content_type=response.headers.get("Content-Type")):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if the server cannot be reached, i.e. there's even no response, etc, or 5xx status.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handling improvement to do later.

return await response.text()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would there be anytime when we expect text/plain?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error cases, I can't think of anything else.

LOG.error("Unknown response for forwarded request: %s", response)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"error_code": status.HTTP_500_INTERNAL_SERVER_ERROR,
"message": "Unknown response for forwarded request.",
},
)

@overload
async def forward_request_remote(
self,
*,
request: Request,
primary_url: str,
response_type: type[BaseModelResponse],
) -> BaseModelResponse:
...

@overload
async def forward_request_remote(
self,
*,
request: Request,
primary_url: str,
response_type: type[SimpleTypeResponse],
) -> SimpleTypeResponse:
...

async def forward_request_remote(
self,
*,
request: Request,
primary_url: str,
response_type: type[BaseModelResponse] | type[SimpleTypeResponse],
) -> BaseModelResponse | SimpleTypeResponse:
body = await self._forward_request_remote(request=request, primary_url=primary_url)
if response_type == int:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not using isinstance()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The response_type is a type not an int object.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yh makes sense, hence the type vars 👍

return int(body) # type: ignore[return-value]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a cast will work to avoid the # type: ignore[return-value], we can cast to the SimpleTypeResponse, but not sure if this works, will test in another PR.

if response_type == list[int]:
return json_decode(body, assume_type=list[int]) # type: ignore[return-value]
if issubclass(response_type, BaseModel):
return response_type.parse_raw(body) # type: ignore[return-value]
raise ValueError("Did not match any expected type")
3 changes: 1 addition & 2 deletions src/karapace/rapu.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,7 @@ def __init__(
def _create_aiohttp_application(self, *, config: Config) -> aiohttp.web.Application:
if config.http_request_max_size:
return aiohttp.web.Application(client_max_size=config.http_request_max_size)
else:
return aiohttp.web.Application()
return aiohttp.web.Application()

async def close_by_app(self, app: aiohttp.web.Application) -> None: # pylint: disable=unused-argument
await self.close()
Expand Down
4 changes: 2 additions & 2 deletions src/karapace/statsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from collections.abc import Iterator
from contextlib import contextmanager
from karapace.config import Config
from karapace.config import Config, KarapaceTags
from karapace.sentry import get_sentry_client
from typing import Any, Final

Expand All @@ -28,7 +28,7 @@ class StatsClient:
def __init__(self, config: Config) -> None:
self._dest_addr: Final = (config.statsd_host, config.statsd_port)
self._socket: Final = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self._tags: Final = config.tags or {}
self._tags: Final[KarapaceTags] = config.tags
self.sentry_client: Final = get_sentry_client(sentry_config=(config.sentry or None))

@contextmanager
Expand Down
6 changes: 3 additions & 3 deletions src/karapace/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from collections.abc import Generator, Mapping, Sequence
from enum import Enum, unique
from karapace.errors import InvalidVersion
from pydantic import ValidationInfo
from typing import Any, ClassVar, NewType, Union
from typing import Any, Callable, ClassVar, NewType, Union
from typing_extensions import TypeAlias

import functools
Expand Down Expand Up @@ -38,7 +38,7 @@ class Subject(str):
@classmethod
# TODO[pydantic]: We couldn't refactor `__get_validators__`, please create the `__get_pydantic_core_schema__` manually.
# Check https://docs.pydantic.dev/latest/migration/#defining-custom-types for more information.
def __get_validators__(cls):
def __get_validators__(cls) -> Generator[Callable[[str, ValidationInfo], str], None, None]:
yield cls.validate

@classmethod
Expand Down
11 changes: 9 additions & 2 deletions src/schema_registry/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from schema_registry.http_handlers import setup_exception_handlers
from schema_registry.middlewares import setup_middlewares
from schema_registry.routers.setup import setup_routers
from typing import AsyncContextManager, Callable

import logging

Expand Down Expand Up @@ -44,12 +45,18 @@ async def karapace_schema_registry_lifespan(
stastd.close()


def create_karapace_application(*, config: Config, lifespan: AsyncGenerator[None, None]) -> FastAPI:
def create_karapace_application(
*,
config: Config,
lifespan: Callable[
[FastAPI, StatsClient, KarapaceSchemaRegistry, AuthenticatorAndAuthorizer], AsyncContextManager[None]
],
) -> FastAPI:
configure_logging(config=config)
log_config_without_secrets(config=config)
logging.info("Starting Karapace Schema Registry (%s)", karapace_version.__version__)

app = FastAPI(lifespan=lifespan)
app = FastAPI(lifespan=lifespan) # type: ignore[arg-type]
setup_routers(app=app)
setup_exception_handlers(app=app)
setup_middlewares(app=app)
Expand Down
4 changes: 2 additions & 2 deletions src/schema_registry/http_handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

def setup_exception_handlers(app: FastAPI) -> None:
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(_: StarletteHTTPRequest, exc: StarletteHTTPException):
async def http_exception_handler(_: StarletteHTTPRequest, exc: StarletteHTTPException) -> JSONResponse:
return JSONResponse(status_code=exc.status_code, content=exc.detail)

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_: StarletteHTTPRequest, exc: RequestValidationError):
async def validation_exception_handler(_: StarletteHTTPRequest, exc: RequestValidationError) -> JSONResponse:
error_code = HTTPStatus.UNPROCESSABLE_ENTITY.value
if isinstance(exc, KarapaceValidationError):
error_code = exc.error_code
Expand Down
8 changes: 5 additions & 3 deletions src/schema_registry/middlewares/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@
See LICENSE for details
"""

from fastapi import FastAPI, HTTPException, Request
from collections.abc import Awaitable
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.responses import JSONResponse
from karapace.content_type import check_schema_headers
from typing import Callable


def setup_middlewares(app: FastAPI) -> None:
@app.middleware("http")
async def set_content_types(request: Request, call_next):
async def set_content_types(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
try:
response_content_type = check_schema_headers(request)
except HTTPException as exc:
Expand All @@ -25,7 +27,7 @@ async def set_content_types(request: Request, call_next):
if request.headers.get("Content-Type") == "application/octet-stream":
new_headers = request.headers.mutablecopy()
new_headers["Content-Type"] = "application/json"
request._headers = new_headers
request._headers = new_headers # pylint: disable=protected-access
request.scope.update(headers=request.headers.raw)

response = await call_next(request)
Expand Down
21 changes: 12 additions & 9 deletions src/schema_registry/routers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,11 @@ async def config_put(
i_am_primary, primary_url = await schema_registry.get_master()
if i_am_primary:
return await controller.config_set(compatibility_level_request=compatibility_level_request)
elif not primary_url:
if not primary_url:
raise no_primary_url_error()
else:
return await forward_client.forward_request_remote(request=request, primary_url=primary_url)
return await forward_client.forward_request_remote(
request=request, primary_url=primary_url, response_type=CompatibilityResponse
)


@config_router.get("/{subject}")
Expand Down Expand Up @@ -92,10 +93,11 @@ async def config_set_subject(
i_am_primary, primary_url = await schema_registry.get_master()
if i_am_primary:
return await controller.config_subject_set(subject=subject, compatibility_level_request=compatibility_level_request)
elif not primary_url:
if not primary_url:
raise no_primary_url_error()
else:
return await forward_client.forward_request_remote(request=request, primary_url=primary_url)
return await forward_client.forward_request_remote(
request=request, primary_url=primary_url, response_type=CompatibilityResponse
)


@config_router.delete("/{subject}")
Expand All @@ -115,7 +117,8 @@ async def config_delete_subject(
i_am_primary, primary_url = await schema_registry.get_master()
if i_am_primary:
return await controller.config_subject_delete(subject=subject)
elif not primary_url:
if not primary_url:
raise no_primary_url_error()
else:
return await forward_client.forward_request_remote(request=request, primary_url=primary_url)
return await forward_client.forward_request_remote(
request=request, primary_url=primary_url, response_type=CompatibilityResponse
)
Loading
Loading