diff --git a/karapace/kafka_rest_apis/__init__.py b/karapace/kafka_rest_apis/__init__.py index 329b5361d..5daf09885 100644 --- a/karapace/kafka_rest_apis/__init__.py +++ b/karapace/kafka_rest_apis/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from binascii import Error as B64DecodeError from collections import namedtuple from confluent_kafka.error import KafkaException @@ -36,7 +38,7 @@ ) from karapace.typing import NameStrategy, SchemaId, Subject, SubjectType from karapace.utils import convert_to_int, json_encode -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, TypedDict import asyncio import base64 @@ -66,10 +68,10 @@ def __init__(self, config: Config) -> None: super().__init__(config=config) self._add_kafka_rest_routes() self.serializer = SchemaRegistrySerializer(config=config) - self.proxies: Dict[str, "UserRestProxy"] = {} + self.proxies: dict[str, UserRestProxy] = {} self._proxy_lock = asyncio.Lock() log.info("REST proxy starting with (delegated authorization=%s)", self.config.get("rest_authorization", False)) - self._idle_proxy_janitor_task: Optional[asyncio.Task] = None + self._idle_proxy_janitor_task: asyncio.Task | None = None async def close(self) -> None: log.info("Closing REST proxy application") @@ -416,32 +418,56 @@ async def topic_publish(self, topic: str, content_type: str, *, request: HTTPReq await proxy.topic_publish(topic, content_type, request=request) +class _ReplicaMetadata(TypedDict): + broker: int + leader: bool + in_sync: bool + + +class _PartitionMetadata(TypedDict): + partition: int + leader: int + replicas: list[_ReplicaMetadata] + + +class _TopicMetadata(TypedDict): + partitions: list[_PartitionMetadata] + + +class _ClusterMetadata(TypedDict): + topics: dict[str, _TopicMetadata] + brokers: list[int] + + class UserRestProxy: def __init__( self, config: Config, kafka_timeout: int, serializer: SchemaRegistrySerializer, - auth_expiry: Optional[datetime.datetime] = None, + auth_expiry: datetime.datetime | None = None, + verify_connection: bool = True, ): self.config = config self.kafka_timeout = kafka_timeout self.serializer = serializer self._cluster_metadata = None self._cluster_metadata_complete = False - self._metadata_birth = None + # birth of all the metadata (when the request was requiring all the metadata available in the cluster) + self._global_metadata_birth: float | None = None + self._cluster_metadata_topic_birth: dict[str, float] = {} self.metadata_max_age = self.config["admin_metadata_max_age"] self.admin_client = None self.admin_lock = asyncio.Lock() self.metadata_cache = None self.topic_schema_cache = TopicSchemaCache() self.consumer_manager = ConsumerManager(config=config, deserializer=self.serializer) - self.init_admin_client() + self.init_admin_client(verify_connection) self._last_used = time.monotonic() self._auth_expiry = auth_expiry self._async_producer_lock = asyncio.Lock() - self._async_producer: Optional[AsyncKafkaProducer] = None + self._async_producer: AsyncKafkaProducer | None = None self.naming_strategy = NameStrategy(self.config["name_strategy"]) def __str__(self) -> str: @@ -607,28 +633,72 @@ async def get_topic_config(self, topic: str) -> dict: async with self.admin_lock: return self.admin_client.get_topic_config(topic) - async def cluster_metadata(self, topics: Optional[List[str]] = None) -> dict: - async with self.admin_lock: - if self._metadata_birth is None or time.monotonic() - self._metadata_birth > self.metadata_max_age: - self._cluster_metadata = None + def is_global_metadata_old(self) -> bool: + return ( + self._global_metadata_birth is None or (time.monotonic() - self._global_metadata_birth) > self.metadata_max_age + ) - if self._cluster_metadata: - # Return from metadata only if all queried topics have cached metadata - if topics is None: - if self._cluster_metadata_complete: - return self._cluster_metadata - elif all(topic in self._cluster_metadata["topics"] for topic in topics): - return { - **self._cluster_metadata, - "topics": {topic: self._cluster_metadata["topics"][topic] for topic in topics}, - } + def is_metadata_of_topics_old(self, topics: list[str]) -> bool: + # Return from metadata only if all queried topics have cached metadata + + if self._cluster_metadata_topic_birth is None: + return True + + are_all_topic_queried_at_least_once = all(topic in self._cluster_metadata_topic_birth for topic in topics) + + if not are_all_topic_queried_at_least_once: + return True + oldest_requested_topic_udpate_timestamp = min(self._cluster_metadata_topic_birth[topic] for topic in topics) + return ( + are_all_topic_queried_at_least_once + and (time.monotonic() - oldest_requested_topic_udpate_timestamp) > self.metadata_max_age + ) + + def _update_all_metadata(self) -> _ClusterMetadata: + if not self.is_global_metadata_old() and self._cluster_metadata_complete: + return self._cluster_metadata + + metadata_birth = time.monotonic() + metadata = self.admin_client.cluster_metadata(None) + for topic in metadata["topics"]: + self._cluster_metadata_topic_birth[topic] = metadata_birth + + self._global_metadata_birth = metadata_birth + self._cluster_metadata = metadata + self._cluster_metadata_complete = True + return metadata + + def _empty_cluster_metadata_cache(self) -> _ClusterMetadata: + return {"topics": {}, "brokers": []} + + def _update_metadata_for_topics(self, topics: list[str]) -> _ClusterMetadata: + if not self.is_metadata_of_topics_old(topics): + return { + **self._cluster_metadata, + "topics": {topic: self._cluster_metadata["topics"][topic] for topic in topics}, + } + + metadata_birth = time.monotonic() + metadata = self.admin_client.cluster_metadata(topics) + + if self._cluster_metadata is None: + self._cluster_metadata = self._empty_cluster_metadata_cache() + + for topic in metadata["topics"]: + self._cluster_metadata_topic_birth[topic] = metadata_birth + self._cluster_metadata["topics"][topic] = metadata["topics"][topic] + + self._cluster_metadata_complete = False + return metadata + + async def cluster_metadata(self, topics: list[str] | None = None) -> _ClusterMetadata: + async with self.admin_lock: try: - metadata_birth = time.monotonic() - metadata = self.admin_client.cluster_metadata(topics) - self._metadata_birth = metadata_birth - self._cluster_metadata = metadata - self._cluster_metadata_complete = topics is None + if topics is None: + metadata = self._update_all_metadata() + else: + metadata = self._update_metadata_for_topics(topics) except KafkaException: log.warning("Could not refresh cluster metadata") KafkaRest.r( @@ -641,7 +711,7 @@ async def cluster_metadata(self, topics: Optional[List[str]] = None) -> dict: ) return metadata - def init_admin_client(self): + def init_admin_client(self, verify_connection: bool = True) -> KafkaAdminClient: for retry in [True, True, False]: try: self.admin_client = KafkaAdminClient( @@ -652,6 +722,7 @@ def init_admin_client(self): ssl_keyfile=self.config["ssl_keyfile"], metadata_max_age_ms=self.config["metadata_max_age_ms"], connections_max_idle_ms=self.config["connections_max_idle_ms"], + verify_connection=verify_connection, **get_kafka_client_auth_parameters_from_config(self.config), ) break @@ -675,7 +746,7 @@ async def aclose(self) -> None: self.admin_client = None self.consumer_manager = None - async def publish(self, topic: str, partition_id: Optional[str], content_type: str, request: HTTPRequest) -> None: + async def publish(self, topic: str, partition_id: str | None, content_type: str, request: HTTPRequest) -> None: """ :raises NoBrokersAvailable: :raises AuthenticationFailedError: @@ -797,7 +868,7 @@ async def get_schema_id( :raises InvalidSchema: """ log.debug("[resolve schema id] Retrieving schema id for %r", data) - schema_id: Union[SchemaId, None] = ( + schema_id: SchemaId | None = ( SchemaId(int(data[f"{subject_type}_schema_id"])) if f"{subject_type}_schema_id" in data else None ) schema_str = data.get(f"{subject_type}_schema") @@ -817,7 +888,7 @@ async def get_schema_id( schema_id = await self._query_schema_id_from_cache_or_registry(parsed_schema, schema_str, subject_name) else: - def subject_not_included(schema: TypedSchema, subjects: List[Subject]) -> bool: + def subject_not_included(schema: TypedSchema, subjects: list[Subject]) -> bool: subject = get_subject_name(topic, schema, subject_type, self.naming_strategy) return subject not in subjects @@ -832,8 +903,8 @@ def subject_not_included(schema: TypedSchema, subjects: List[Subject]) -> bool: return schema_id async def _query_schema_and_subjects( - self, schema_id: SchemaId, *, need_new_call: Optional[Callable[[TypedSchema, List[Subject]], bool]] - ) -> Tuple[TypedSchema, List[Subject]]: + self, schema_id: SchemaId, *, need_new_call: Callable[[TypedSchema, list[Subject]], bool] | None + ) -> tuple[TypedSchema, list[Subject]]: try: return await self.serializer.get_schema_for_id(schema_id, need_new_call=need_new_call) except SchemaRetrievalError as schema_error: @@ -924,10 +995,10 @@ async def _prepare_records( content_type: str, data: dict, ser_format: str, - key_schema_id: Optional[int], - value_schema_id: Optional[int], - default_partition: Optional[int] = None, - ) -> List[Tuple]: + key_schema_id: int | None, + value_schema_id: int | None, + default_partition: int | None = None, + ) -> list[tuple]: prepared_records = [] for record in data["records"]: key = record.get("key") @@ -978,8 +1049,8 @@ async def serialize( self, content_type: str, obj=None, - ser_format: Optional[str] = None, - schema_id: Optional[int] = None, + ser_format: str | None = None, + schema_id: int | None = None, ) -> bytes: if not obj: return b"" @@ -1003,7 +1074,7 @@ async def serialize( return await self.schema_serialize(obj, schema_id) raise FormatError(f"Unknown format: {ser_format}") - async def schema_serialize(self, obj: dict, schema_id: Optional[int]) -> bytes: + async def schema_serialize(self, obj: dict, schema_id: int | None) -> bytes: schema, _ = await self.serializer.get_schema_for_id(schema_id) bytes_ = await self.serializer.serialize(schema, obj) return bytes_ @@ -1066,7 +1137,7 @@ async def validate_publish_request_format(self, data: dict, formats: dict, conte sub_code=RESTErrorCodes.INVALID_DATA.value, ) - async def produce_messages(self, *, topic: str, prepared_records: List) -> List: + async def produce_messages(self, *, topic: str, prepared_records: list) -> list: """ :raises NoBrokersAvailable: :raises AuthenticationFailedError: diff --git a/tests/unit/kafka_rest_apis/__init__.py b/tests/unit/kafka_rest_apis/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/kafka_rest_apis/test_rest_proxy_cache.py b/tests/unit/kafka_rest_apis/test_rest_proxy_cache.py new file mode 100644 index 000000000..f49cfbf63 --- /dev/null +++ b/tests/unit/kafka_rest_apis/test_rest_proxy_cache.py @@ -0,0 +1,253 @@ +# pylint: disable=protected-access +""" +Copyright (c) 2024 Aiven Ltd +See LICENSE for details +""" +from karapace.config import DEFAULTS +from karapace.kafka_rest_apis import UserRestProxy +from karapace.serialization import SchemaRegistrySerializer +from unittest.mock import patch + +import copy + + +def user_rest_proxy(max_age_metadata: int = 5) -> UserRestProxy: + configs = {**DEFAULTS, **{"admin_metadata_max_age": max_age_metadata}} + serializer = SchemaRegistrySerializer(configs) + return UserRestProxy(configs, 1, serializer, auth_expiry=None, verify_connection=False) + + +EMPTY_REPLY = { + "topics": {}, + "brokers": [], +} + +TOPIC_REQUEST = { + "topics": { + "topic_a": { + "partitions": [ + { + "partition": 0, + "leader": 10, + "replicas": [ + {"broker": 10, "leader": True, "in_sync": True}, + ], + } + ] + } + }, + "brokers": [10], +} + +ALL_TOPIC_REQUEST = { + "topics": { + "topic_a": { + "partitions": [ + { + "partition": 0, + "leader": 69, + "replicas": [ + {"broker": 69, "leader": True, "in_sync": True}, + {"broker": 67, "leader": False, "in_sync": True}, + ], + } + ] + }, + "topic_b": { + "partitions": [ + { + "partition": 0, + "leader": 66, + "replicas": [ + {"broker": 69, "leader": False, "in_sync": True}, + {"broker": 67, "leader": False, "in_sync": False}, + {"broker": 66, "leader": True, "in_sync": True}, + {"broker": 65, "leader": False, "in_sync": True}, + ], + } + ] + }, + "__consumer_offsets": { + "partitions": [ + { + "partition": 0, + "leader": 69, + "replicas": [ + {"broker": 69, "leader": True, "in_sync": True}, + {"broker": 68, "leader": False, "in_sync": True}, + {"broker": 67, "leader": False, "in_sync": True}, + ], + }, + { + "partition": 1, + "leader": 67, + "replicas": [ + {"broker": 67, "leader": True, "in_sync": True}, + {"broker": 68, "leader": False, "in_sync": True}, + {"broker": 69, "leader": False, "in_sync": True}, + ], + }, + { + "partition": 2, + "leader": 67, + "replicas": [ + {"broker": 67, "leader": True, "in_sync": True}, + {"broker": 69, "leader": False, "in_sync": True}, + {"broker": 68, "leader": False, "in_sync": True}, + ], + }, + ] + }, + }, + "brokers": [68, 64, 66, 65, 69, 67], +} + + +async def test_cache_is_evicted_after_expiration_global_initially() -> None: + proxy = user_rest_proxy() + with patch( + "karapace.kafka.admin.KafkaAdminClient.cluster_metadata", return_value=EMPTY_REPLY + ) as mocked_cluster_metadata: + await proxy.cluster_metadata(None) + mocked_cluster_metadata.assert_called_once_with(None) # "initially the metadata are always old" + + +async def test_cache_is_evicted_after_expiration_global() -> None: + proxy = user_rest_proxy(max_age_metadata=10) + proxy._global_metadata_birth = 0 + with patch( + "karapace.kafka.admin.KafkaAdminClient.cluster_metadata", return_value=EMPTY_REPLY + ) as mocked_cluster_metadata: + with patch("time.monotonic", return_value=11): + await proxy.cluster_metadata(None) + mocked_cluster_metadata.assert_called_once_with(None) # "metadata old require a refresh" + + +async def test_global_cache_is_used_for_single_topic() -> None: + proxy = user_rest_proxy(max_age_metadata=10) + proxy._global_metadata_birth = 0 + with patch( + "karapace.kafka.admin.KafkaAdminClient.cluster_metadata", return_value=ALL_TOPIC_REQUEST + ) as mocked_cluster_metadata: + with patch("time.monotonic", return_value=11): + await proxy.cluster_metadata(None) + await proxy.cluster_metadata(None) + await proxy.cluster_metadata(None) + + mocked_cluster_metadata.assert_called_once_with(None) # "calling multiple times should be cached" + + assert proxy._global_metadata_birth == 11 + assert proxy._cluster_metadata_topic_birth == {"topic_a": 11, "topic_b": 11, "__consumer_offsets": 11} + + with patch( + "karapace.kafka.admin.KafkaAdminClient.cluster_metadata", return_value=ALL_TOPIC_REQUEST + ) as mocked_cluster_metadata: + with patch("time.monotonic", return_value=14): + await proxy.cluster_metadata(["topic_a", "topic_b"]) + + assert ( + mocked_cluster_metadata.call_count == 0 + ), "the result should still be cached since we marked it as ready at time 11 and we are at 14" + + +async def test_cache_is_evicted_if_one_topic_is_expired() -> None: + proxy = user_rest_proxy(max_age_metadata=10) + proxy._global_metadata_birth = 0 + with patch( + "karapace.kafka.admin.KafkaAdminClient.cluster_metadata", return_value=ALL_TOPIC_REQUEST + ) as mocked_cluster_metadata: + with patch("time.monotonic", return_value=11): + await proxy.cluster_metadata(None) + + proxy._cluster_metadata_topic_birth = {"topic_a": 11, "topic_b": 1, "__consumer_offsets": 11} + + with patch( + "karapace.kafka.admin.KafkaAdminClient.cluster_metadata", return_value=ALL_TOPIC_REQUEST + ) as mocked_cluster_metadata: + with patch("time.monotonic", return_value=14): + await proxy.cluster_metadata(["topic_a", "topic_b"]) + + assert mocked_cluster_metadata.call_count == 1, "topic_b should be evicted" + + +async def test_cache_is_evicted_if_a_topic_was_never_queries() -> None: + proxy = user_rest_proxy(max_age_metadata=10) + proxy._global_metadata_birth = 0 + with patch( + "karapace.kafka.admin.KafkaAdminClient.cluster_metadata", return_value=ALL_TOPIC_REQUEST + ) as mocked_cluster_metadata: + with patch("time.monotonic", return_value=11): + await proxy.cluster_metadata(None) + + proxy._cluster_metadata_topic_birth = {"topic_a": 11, "__consumer_offsets": 11} + + with patch( + "karapace.kafka.admin.KafkaAdminClient.cluster_metadata", return_value=ALL_TOPIC_REQUEST + ) as mocked_cluster_metadata: + with patch("time.monotonic", return_value=14): + await proxy.cluster_metadata(["topic_a", "topic_b"]) + + assert mocked_cluster_metadata.call_count == 1, "topic_b is not present in the cache, should call the refresh" + + +async def test_cache_is_used_if_topic_requested_is_updated() -> None: + proxy = user_rest_proxy(max_age_metadata=10) + proxy._global_metadata_birth = 0 + with patch( + "karapace.kafka.admin.KafkaAdminClient.cluster_metadata", return_value=TOPIC_REQUEST + ) as mocked_cluster_metadata: + with patch("time.monotonic", return_value=11): + await proxy.cluster_metadata(None) + + with patch( + "karapace.kafka.admin.KafkaAdminClient.cluster_metadata", return_value=ALL_TOPIC_REQUEST + ) as mocked_cluster_metadata: + with patch("time.monotonic", return_value=14): + await proxy.cluster_metadata(["topic_a"]) + + assert mocked_cluster_metadata.call_count == 0, "topic_a cache its present, should be used" + + +async def test_update_global_cache() -> None: + proxy = user_rest_proxy(max_age_metadata=10) + proxy._global_metadata_birth = 0 + with patch( + "karapace.kafka.admin.KafkaAdminClient.cluster_metadata", return_value=TOPIC_REQUEST + ) as mocked_cluster_metadata: + with patch("time.monotonic", return_value=11): + await proxy.cluster_metadata(None) + + assert mocked_cluster_metadata.call_count == 1, "should call the server for the first time" + + with patch( + "karapace.kafka.admin.KafkaAdminClient.cluster_metadata", return_value=TOPIC_REQUEST + ) as mocked_cluster_metadata: + with patch("time.monotonic", return_value=21): + await proxy.cluster_metadata(None) + + assert mocked_cluster_metadata.call_count == 0, "should call the server since the cache its expired" + + +async def test_update_topic_cache_do_not_evict_all_the_global_cache() -> None: + proxy = user_rest_proxy(max_age_metadata=10) + proxy._global_metadata_birth = 0 + proxy._cluster_metadata = ALL_TOPIC_REQUEST + proxy._cluster_metadata_topic_birth = {"topic_a": 0, "topic_b": 200, "__consumer_offsets": 200} + + with patch( + "karapace.kafka.admin.KafkaAdminClient.cluster_metadata", return_value=TOPIC_REQUEST + ) as mocked_cluster_metadata: + with patch("time.monotonic", return_value=208): + res = await proxy.cluster_metadata(["topic_a"]) + + assert res == TOPIC_REQUEST + + assert proxy._cluster_metadata_topic_birth == {"topic_a": 208, "topic_b": 200, "__consumer_offsets": 200} + + expected_metadata = copy.deepcopy(ALL_TOPIC_REQUEST) + expected_metadata["topics"]["topic_a"] = TOPIC_REQUEST["topics"]["topic_a"] + assert proxy._cluster_metadata == expected_metadata + + assert ( + mocked_cluster_metadata.call_count == 1 + ), "we should call the server since the previous time of caching for the topic_a was 0"