Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

forwarding_client: use single session and close on shutdown #1006

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion GNUmakefile
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ unit-tests: venv/.deps-dev
rm -fr runtime/*

.PHONY: integration-tests
unit-tests: export PYTEST_ARGS ?=
integration-tests: export PYTEST_ARGS ?=
integration-tests: venv/.deps-dev
rm -fr runtime/*
$(PYTHON) -m pytest -s -vvv $(PYTEST_ARGS) tests/integration/
Expand Down
14 changes: 5 additions & 9 deletions src/karapace/forward_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ class ForwardClient:
USER_AGENT = f"Karapace/{__version__}"

def __init__(self) -> None:
self._forward_client: aiohttp.ClientSession | None = None
self._forward_client: aiohttp.ClientSession = aiohttp.ClientSession(headers={"User-Agent": self.USER_AGENT})

def _get_forward_client(self) -> aiohttp.ClientSession:
return aiohttp.ClientSession(headers={"User-Agent": ForwardClient.USER_AGENT})
async def close(self) -> None:
await self._forward_client.close()

def _acceptable_response_content_type(self, *, content_type: str) -> bool:
return (
Expand All @@ -42,11 +42,7 @@ async def _forward_request_remote(
) -> bytes:
LOG.info("Forwarding %s request to remote url: %r since we're not the master", request.method, request.url)
timeout = 60.0
headers = request.headers.mutablecopy()
func = getattr(self._get_forward_client(), request.method.lower())
# auth_header = request.headers.get("Authorization")
# if auth_header is not None:
# headers["Authorization"] = auth_header
func = getattr(self._forward_client, request.method.lower())

forward_url = f"{primary_url}{request.url.path}"
if request.url.query:
Expand All @@ -55,7 +51,7 @@ async def _forward_request_remote(

async with async_timeout.timeout(timeout):
body_data = await request.body()
async with func(forward_url, headers=headers, data=body_data) as response:
async with func(forward_url, headers=request.headers.mutablecopy(), data=body_data) as response:
if self._acceptable_response_content_type(content_type=response.headers.get("Content-Type")):
return await response.text()
LOG.error("Unknown response for forwarded request: %s", response)
Expand Down
14 changes: 7 additions & 7 deletions src/schema_registry/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from karapace import version as karapace_version
from karapace.auth import AuthenticatorAndAuthorizer
from karapace.config import Config
from karapace.forward_client import ForwardClient
from karapace.logging_setup import configure_logging, log_config_without_secrets
from karapace.schema_registry import KarapaceSchemaRegistry
from karapace.statsd import StatsClient
Expand All @@ -26,6 +27,7 @@
@inject
async def karapace_schema_registry_lifespan(
_: FastAPI,
forward_client: ForwardClient = Depends(Provide[SchemaRegistryContainer.karapace_container.forward_client]),
stastd: StatsClient = Depends(Provide[SchemaRegistryContainer.karapace_container.statsd]),
schema_registry: KarapaceSchemaRegistry = Depends(Provide[SchemaRegistryContainer.karapace_container.schema_registry]),
authorizer: AuthenticatorAndAuthorizer = Depends(Provide[SchemaRegistryContainer.karapace_container.authorizer]),
Expand All @@ -37,19 +39,17 @@ async def karapace_schema_registry_lifespan(

yield
finally:
if schema_registry:
await schema_registry.close()
if authorizer:
await authorizer.close()
if stastd:
stastd.close()
await schema_registry.close()
await authorizer.close()
await forward_client.close()
stastd.close()


def create_karapace_application(
*,
config: Config,
lifespan: Callable[
[FastAPI, StatsClient, KarapaceSchemaRegistry, AuthenticatorAndAuthorizer], AsyncContextManager[None]
[FastAPI, ForwardClient, StatsClient, KarapaceSchemaRegistry, AuthenticatorAndAuthorizer], AsyncContextManager[None]
],
) -> FastAPI:
configure_logging(config=config)
Expand Down
225 changes: 114 additions & 111 deletions tests/unit/test_forwarding_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
karapace - schema registry authentication and authorization tests
karapace - master forwarding tests

Copyright (c) 2023 Aiven Ltd
Copyright (c) 2024 Aiven Ltd
See LICENSE for details
"""
from __future__ import annotations
Expand All @@ -15,6 +15,7 @@
from tests.base_testcase import BaseTestCase
from unittest.mock import AsyncMock, Mock, patch

import aiohttp
import pytest


Expand All @@ -28,6 +29,20 @@ class ContentTypeTestCase(BaseTestCase):
content_type: str


@pytest.fixture(name="forward_client")
def fixture_forward_client() -> ForwardClient:
with patch("karapace.forward_client.aiohttp") as mocked_aiohttp:
mocked_aiohttp.ClientSession.return_value = Mock(
spec=aiohttp.ClientSession, headers={"User-Agent": ForwardClient.USER_AGENT}
)
return ForwardClient()


async def test_forward_client_close(forward_client: ForwardClient) -> None:
await forward_client.close()
forward_client._forward_client.close.assert_awaited_once() # pylint: disable=protected-access


@pytest.mark.parametrize(
"testcase",
[
Expand All @@ -42,112 +57,100 @@ class ContentTypeTestCase(BaseTestCase):
],
ids=str,
)
async def test_forward_request_with_basemodel_response(testcase: ContentTypeTestCase) -> None:
nosahama marked this conversation as resolved.
Show resolved Hide resolved
forward_client = ForwardClient()
with patch.object(forward_client, "_get_forward_client", autospec=True) as mock_get_forward_client:
mock_request = Mock(spec=Request)
mock_request.method = "GET"
mock_request.headers = Headers()

mock_aiohttp_session = Mock()
mock_get_forward_client.return_value = mock_aiohttp_session
mock_get_func = Mock()
mock_response_context = AsyncMock
mock_response = AsyncMock()
mock_response_context.call_function = lambda _: mock_response
mock_response.text.return_value = '{"number":10,"string":"important"}'
headers = MutableHeaders()
headers["Content-Type"] = testcase.content_type
mock_response.headers = headers

async def mock_aenter(_) -> Mock:
return mock_response

async def mock_aexit(_, __, ___, ____) -> None:
return

mock_get_func.__aenter__ = mock_aenter
mock_get_func.__aexit__ = mock_aexit
mock_aiohttp_session.get.return_value = mock_get_func

response = await forward_client.forward_request_remote(
request=mock_request,
primary_url="test-url",
response_type=TestResponse,
)

assert response == TestResponse(number=10, string="important")


async def test_forward_request_with_integer_list_response() -> None:
forward_client = ForwardClient()
with patch.object(forward_client, "_get_forward_client", autospec=True) as mock_get_forward_client:
mock_request = Mock(spec=Request)
mock_request.method = "GET"
mock_request.headers = Headers()

mock_aiohttp_session = Mock()
mock_get_forward_client.return_value = mock_aiohttp_session
mock_get_func = Mock()
mock_response_context = AsyncMock
mock_response = AsyncMock()
mock_response_context.call_function = lambda _: mock_response
mock_response.text.return_value = "[1, 2, 3, 10]"
headers = MutableHeaders()
headers["Content-Type"] = "application/json"
mock_response.headers = headers

async def mock_aenter(_) -> Mock:
return mock_response

async def mock_aexit(_, __, ___, ____) -> None:
return

mock_get_func.__aenter__ = mock_aenter
mock_get_func.__aexit__ = mock_aexit
mock_aiohttp_session.get.return_value = mock_get_func

response = await forward_client.forward_request_remote(
request=mock_request,
primary_url="test-url",
response_type=list[int],
)

assert response == [1, 2, 3, 10]


async def test_forward_request_with_integer_response() -> None:
forward_client = ForwardClient()
with patch.object(forward_client, "_get_forward_client", autospec=True) as mock_get_forward_client:
mock_request = Mock(spec=Request)
mock_request.method = "GET"
mock_request.headers = Headers()

mock_aiohttp_session = Mock()
mock_get_forward_client.return_value = mock_aiohttp_session
mock_get_func = Mock()
mock_response_context = AsyncMock
mock_response = AsyncMock()
mock_response_context.call_function = lambda _: mock_response
mock_response.text.return_value = "12"
headers = MutableHeaders()
headers["Content-Type"] = "application/json"
mock_response.headers = headers

async def mock_aenter(_) -> Mock:
return mock_response

async def mock_aexit(_, __, ___, ____) -> None:
return

mock_get_func.__aenter__ = mock_aenter
mock_get_func.__aexit__ = mock_aexit
mock_aiohttp_session.get.return_value = mock_get_func

response = await forward_client.forward_request_remote(
request=mock_request,
primary_url="test-url",
response_type=int,
)

assert response == 12
async def test_forward_request_with_basemodel_response(forward_client: ForwardClient, testcase: ContentTypeTestCase) -> None:
mock_request = Mock(spec=Request)
mock_request.method = "GET"
mock_request.headers = Headers()

mock_get_func = Mock()
mock_response_context = AsyncMock
mock_response = AsyncMock()
mock_response_context.call_function = lambda _: mock_response
mock_response.text.return_value = '{"number":10,"string":"important"}'
headers = MutableHeaders()
headers["Content-Type"] = testcase.content_type
mock_response.headers = headers

async def mock_aenter(_) -> Mock:
return mock_response

async def mock_aexit(_, __, ___, ____) -> None:
return

mock_get_func.__aenter__ = mock_aenter
mock_get_func.__aexit__ = mock_aexit
forward_client._forward_client.get.return_value = mock_get_func # pylint: disable=protected-access

response = await forward_client.forward_request_remote(
request=mock_request,
primary_url="test-url",
response_type=TestResponse,
)

assert response == TestResponse(number=10, string="important")


async def test_forward_request_with_integer_list_response(forward_client: ForwardClient) -> None:
mock_request = Mock(spec=Request)
mock_request.method = "GET"
mock_request.headers = Headers()

mock_get_func = Mock()
mock_response_context = AsyncMock
mock_response = AsyncMock()
mock_response_context.call_function = lambda _: mock_response
mock_response.text.return_value = "[1, 2, 3, 10]"
headers = MutableHeaders()
headers["Content-Type"] = "application/json"
mock_response.headers = headers

async def mock_aenter(_) -> Mock:
return mock_response

async def mock_aexit(_, __, ___, ____) -> None:
return

mock_get_func.__aenter__ = mock_aenter
mock_get_func.__aexit__ = mock_aexit
forward_client._forward_client.get.return_value = mock_get_func # pylint: disable=protected-access

response = await forward_client.forward_request_remote(
request=mock_request,
primary_url="test-url",
response_type=list[int],
)

assert response == [1, 2, 3, 10]


async def test_forward_request_with_integer_response(forward_client: ForwardClient) -> None:
mock_request = Mock(spec=Request)
mock_request.method = "GET"
mock_request.headers = Headers()

mock_get_func = Mock()
mock_response_context = AsyncMock
mock_response = AsyncMock()
mock_response_context.call_function = lambda _: mock_response
mock_response.text.return_value = "12"
headers = MutableHeaders()
headers["Content-Type"] = "application/json"
mock_response.headers = headers

async def mock_aenter(_) -> Mock:
return mock_response

async def mock_aexit(_, __, ___, ____) -> None:
return

mock_get_func.__aenter__ = mock_aenter
mock_get_func.__aexit__ = mock_aexit
forward_client._forward_client.get.return_value = mock_get_func # pylint: disable=protected-access

response = await forward_client.forward_request_remote(
request=mock_request,
primary_url="test-url",
response_type=int,
)

assert response == 12
Loading