Skip to content

Commit

Permalink
forwarding_client: use single session and close on shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
nosahama committed Dec 10, 2024
1 parent d9b5130 commit cadc599
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 128 deletions.
2 changes: 1 addition & 1 deletion GNUmakefile
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ unit-tests: venv/.deps-dev
rm -fr runtime/*

.PHONY: integration-tests
unit-tests: export PYTEST_ARGS ?=
integration-tests: export PYTEST_ARGS ?=
integration-tests: venv/.deps-dev
rm -fr runtime/*
$(PYTHON) -m pytest -s -vvv $(PYTEST_ARGS) tests/integration/
Expand Down
14 changes: 5 additions & 9 deletions src/karapace/forward_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ class ForwardClient:
USER_AGENT = f"Karapace/{__version__}"

def __init__(self) -> None:
self._forward_client: aiohttp.ClientSession | None = None
self._forward_client: aiohttp.ClientSession = aiohttp.ClientSession(headers={"User-Agent": self.USER_AGENT})

def _get_forward_client(self) -> aiohttp.ClientSession:
return aiohttp.ClientSession(headers={"User-Agent": ForwardClient.USER_AGENT})
async def close(self) -> None:
await self._forward_client.close()

def _acceptable_response_content_type(self, *, content_type: str) -> bool:
return (
Expand All @@ -42,11 +42,7 @@ async def _forward_request_remote(
) -> bytes:
LOG.info("Forwarding %s request to remote url: %r since we're not the master", request.method, request.url)
timeout = 60.0
headers = request.headers.mutablecopy()
func = getattr(self._get_forward_client(), request.method.lower())
# auth_header = request.headers.get("Authorization")
# if auth_header is not None:
# headers["Authorization"] = auth_header
func = getattr(self._forward_client, request.method.lower())

forward_url = f"{primary_url}{request.url.path}"
if request.url.query:
Expand All @@ -55,7 +51,7 @@ async def _forward_request_remote(

async with async_timeout.timeout(timeout):
body_data = await request.body()
async with func(forward_url, headers=headers, data=body_data) as response:
async with func(forward_url, headers=request.headers.mutablecopy(), data=body_data) as response:
if self._acceptable_response_content_type(content_type=response.headers.get("Content-Type")):
return await response.text()
LOG.error("Unknown response for forwarded request: %s", response)
Expand Down
14 changes: 7 additions & 7 deletions src/schema_registry/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from karapace import version as karapace_version
from karapace.auth import AuthenticatorAndAuthorizer
from karapace.config import Config
from karapace.forward_client import ForwardClient
from karapace.logging_setup import configure_logging, log_config_without_secrets
from karapace.schema_registry import KarapaceSchemaRegistry
from karapace.statsd import StatsClient
Expand All @@ -26,6 +27,7 @@
@inject
async def karapace_schema_registry_lifespan(
_: FastAPI,
forward_client: ForwardClient = Depends(Provide[SchemaRegistryContainer.karapace_container.forward_client]),
stastd: StatsClient = Depends(Provide[SchemaRegistryContainer.karapace_container.statsd]),
schema_registry: KarapaceSchemaRegistry = Depends(Provide[SchemaRegistryContainer.karapace_container.schema_registry]),
authorizer: AuthenticatorAndAuthorizer = Depends(Provide[SchemaRegistryContainer.karapace_container.authorizer]),
Expand All @@ -37,19 +39,17 @@ async def karapace_schema_registry_lifespan(

yield
finally:
if schema_registry:
await schema_registry.close()
if authorizer:
await authorizer.close()
if stastd:
stastd.close()
await schema_registry.close()
await authorizer.close()
await forward_client.close()
stastd.close()


def create_karapace_application(
*,
config: Config,
lifespan: Callable[
[FastAPI, StatsClient, KarapaceSchemaRegistry, AuthenticatorAndAuthorizer], AsyncContextManager[None]
[FastAPI, ForwardClient, StatsClient, KarapaceSchemaRegistry, AuthenticatorAndAuthorizer], AsyncContextManager[None]
],
) -> FastAPI:
configure_logging(config=config)
Expand Down
225 changes: 114 additions & 111 deletions tests/unit/test_forwarding_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
karapace - schema registry authentication and authorization tests
karapace - master forwarding tests
Copyright (c) 2023 Aiven Ltd
Copyright (c) 2024 Aiven Ltd
See LICENSE for details
"""
from __future__ import annotations
Expand All @@ -15,6 +15,7 @@
from tests.base_testcase import BaseTestCase
from unittest.mock import AsyncMock, Mock, patch

import aiohttp
import pytest


Expand All @@ -28,6 +29,20 @@ class ContentTypeTestCase(BaseTestCase):
content_type: str


@pytest.fixture(name="forward_client")
def fixture_forward_client() -> ForwardClient:
with patch("karapace.forward_client.aiohttp") as mocked_aiohttp:
mocked_aiohttp.ClientSession.return_value = Mock(
spec=aiohttp.ClientSession, headers={"User-Agent": ForwardClient.USER_AGENT}
)
return ForwardClient()


async def test_forward_client_close(forward_client: ForwardClient) -> None:
await forward_client.close()
forward_client._forward_client.close.assert_awaited_once() # pylint: disable=protected-access


@pytest.mark.parametrize(
"testcase",
[
Expand All @@ -42,112 +57,100 @@ class ContentTypeTestCase(BaseTestCase):
],
ids=str,
)
async def test_forward_request_with_basemodel_response(testcase: ContentTypeTestCase) -> None:
forward_client = ForwardClient()
with patch.object(forward_client, "_get_forward_client", autospec=True) as mock_get_forward_client:
mock_request = Mock(spec=Request)
mock_request.method = "GET"
mock_request.headers = Headers()

mock_aiohttp_session = Mock()
mock_get_forward_client.return_value = mock_aiohttp_session
mock_get_func = Mock()
mock_response_context = AsyncMock
mock_response = AsyncMock()
mock_response_context.call_function = lambda _: mock_response
mock_response.text.return_value = '{"number":10,"string":"important"}'
headers = MutableHeaders()
headers["Content-Type"] = testcase.content_type
mock_response.headers = headers

async def mock_aenter(_) -> Mock:
return mock_response

async def mock_aexit(_, __, ___, ____) -> None:
return

mock_get_func.__aenter__ = mock_aenter
mock_get_func.__aexit__ = mock_aexit
mock_aiohttp_session.get.return_value = mock_get_func

response = await forward_client.forward_request_remote(
request=mock_request,
primary_url="test-url",
response_type=TestResponse,
)

assert response == TestResponse(number=10, string="important")


async def test_forward_request_with_integer_list_response() -> None:
forward_client = ForwardClient()
with patch.object(forward_client, "_get_forward_client", autospec=True) as mock_get_forward_client:
mock_request = Mock(spec=Request)
mock_request.method = "GET"
mock_request.headers = Headers()

mock_aiohttp_session = Mock()
mock_get_forward_client.return_value = mock_aiohttp_session
mock_get_func = Mock()
mock_response_context = AsyncMock
mock_response = AsyncMock()
mock_response_context.call_function = lambda _: mock_response
mock_response.text.return_value = "[1, 2, 3, 10]"
headers = MutableHeaders()
headers["Content-Type"] = "application/json"
mock_response.headers = headers

async def mock_aenter(_) -> Mock:
return mock_response

async def mock_aexit(_, __, ___, ____) -> None:
return

mock_get_func.__aenter__ = mock_aenter
mock_get_func.__aexit__ = mock_aexit
mock_aiohttp_session.get.return_value = mock_get_func

response = await forward_client.forward_request_remote(
request=mock_request,
primary_url="test-url",
response_type=list[int],
)

assert response == [1, 2, 3, 10]


async def test_forward_request_with_integer_response() -> None:
forward_client = ForwardClient()
with patch.object(forward_client, "_get_forward_client", autospec=True) as mock_get_forward_client:
mock_request = Mock(spec=Request)
mock_request.method = "GET"
mock_request.headers = Headers()

mock_aiohttp_session = Mock()
mock_get_forward_client.return_value = mock_aiohttp_session
mock_get_func = Mock()
mock_response_context = AsyncMock
mock_response = AsyncMock()
mock_response_context.call_function = lambda _: mock_response
mock_response.text.return_value = "12"
headers = MutableHeaders()
headers["Content-Type"] = "application/json"
mock_response.headers = headers

async def mock_aenter(_) -> Mock:
return mock_response

async def mock_aexit(_, __, ___, ____) -> None:
return

mock_get_func.__aenter__ = mock_aenter
mock_get_func.__aexit__ = mock_aexit
mock_aiohttp_session.get.return_value = mock_get_func

response = await forward_client.forward_request_remote(
request=mock_request,
primary_url="test-url",
response_type=int,
)

assert response == 12
async def test_forward_request_with_basemodel_response(forward_client: ForwardClient, testcase: ContentTypeTestCase) -> None:
mock_request = Mock(spec=Request)
mock_request.method = "GET"
mock_request.headers = Headers()

mock_get_func = Mock()
mock_response_context = AsyncMock
mock_response = AsyncMock()
mock_response_context.call_function = lambda _: mock_response
mock_response.text.return_value = '{"number":10,"string":"important"}'
headers = MutableHeaders()
headers["Content-Type"] = testcase.content_type
mock_response.headers = headers

async def mock_aenter(_) -> Mock:
return mock_response

async def mock_aexit(_, __, ___, ____) -> None:
return

mock_get_func.__aenter__ = mock_aenter
mock_get_func.__aexit__ = mock_aexit
forward_client._forward_client.get.return_value = mock_get_func # pylint: disable=protected-access

response = await forward_client.forward_request_remote(
request=mock_request,
primary_url="test-url",
response_type=TestResponse,
)

assert response == TestResponse(number=10, string="important")


async def test_forward_request_with_integer_list_response(forward_client: ForwardClient) -> None:
mock_request = Mock(spec=Request)
mock_request.method = "GET"
mock_request.headers = Headers()

mock_get_func = Mock()
mock_response_context = AsyncMock
mock_response = AsyncMock()
mock_response_context.call_function = lambda _: mock_response
mock_response.text.return_value = "[1, 2, 3, 10]"
headers = MutableHeaders()
headers["Content-Type"] = "application/json"
mock_response.headers = headers

async def mock_aenter(_) -> Mock:
return mock_response

async def mock_aexit(_, __, ___, ____) -> None:
return

mock_get_func.__aenter__ = mock_aenter
mock_get_func.__aexit__ = mock_aexit
forward_client._forward_client.get.return_value = mock_get_func # pylint: disable=protected-access

response = await forward_client.forward_request_remote(
request=mock_request,
primary_url="test-url",
response_type=list[int],
)

assert response == [1, 2, 3, 10]


async def test_forward_request_with_integer_response(forward_client: ForwardClient) -> None:
mock_request = Mock(spec=Request)
mock_request.method = "GET"
mock_request.headers = Headers()

mock_get_func = Mock()
mock_response_context = AsyncMock
mock_response = AsyncMock()
mock_response_context.call_function = lambda _: mock_response
mock_response.text.return_value = "12"
headers = MutableHeaders()
headers["Content-Type"] = "application/json"
mock_response.headers = headers

async def mock_aenter(_) -> Mock:
return mock_response

async def mock_aexit(_, __, ___, ____) -> None:
return

mock_get_func.__aenter__ = mock_aenter
mock_get_func.__aexit__ = mock_aexit
forward_client._forward_client.get.return_value = mock_get_func # pylint: disable=protected-access

response = await forward_client.forward_request_remote(
request=mock_request,
primary_url="test-url",
response_type=int,
)

assert response == 12

0 comments on commit cadc599

Please sign in to comment.