diff --git a/examples/kafka/ack_after_process.py b/examples/kafka/ack_after_process.py new file mode 100644 index 0000000000..7a00b7fac7 --- /dev/null +++ b/examples/kafka/ack_after_process.py @@ -0,0 +1,19 @@ +from faststream import FastStream, Logger +from faststream.kafka import KafkaBroker + +broker = KafkaBroker() +app = FastStream(broker) + + +@broker.subscriber( + "test", + group_id="group", + auto_commit=False, +) +async def handler(msg: str, logger: Logger): + logger.info(msg) + + +@app.after_startup +async def test(): + await broker.publish("Hi!", "test") diff --git a/faststream/_compat.py b/faststream/_compat.py index 64da9690a8..63cb278bc5 100644 --- a/faststream/_compat.py +++ b/faststream/_compat.py @@ -10,17 +10,12 @@ ) from fast_depends._compat import FieldInfo from pydantic import BaseModel -from typing_extensions import TypedDict as TypedDict -from typing_extensions import override as override - -# TODO: uncomment with py3.12 release 2023-10-02 -# if sys.version_info < (3, 12): -# from typing_extensions import override as override -# from typing_extensions import TypedDict as TypedDict -# else: -# from typing import override -# from typing import TypedDict as TypedDict +if sys.version_info < (3, 12): + from typing_extensions import TypedDict as TypedDict + from typing_extensions import override as override +else: + from typing import TypedDict as TypedDict if sys.version_info < (3, 11): from typing_extensions import Never as Never diff --git a/faststream/broker/handler.py b/faststream/broker/handler.py index 7ae1cbcc7f..be2bccc61b 100644 --- a/faststream/broker/handler.py +++ b/faststream/broker/handler.py @@ -252,6 +252,8 @@ async def consume(self, msg: MsgType) -> SendableMessage: # type: ignore[overri async with AsyncExitStack() as stack: gl_middlewares: List[BaseMiddleware] = [] + stack.enter_context(context.scope("handler_", self)) + for m in self.global_middlewares: gl_middlewares.append(await stack.enter_async_context(m(msg))) @@ -308,7 +310,7 @@ async def consume(self, msg: MsgType) -> SendableMessage: # type: ignore[overri for m_pub in all_middlewares: result_to_send = ( await pub_stack.enter_async_context( - m_pub.publish_scope(result_msg) + m_pub.publish_scope(result_to_send) ) ) diff --git a/faststream/broker/test.py b/faststream/broker/test.py index c57f1fd5d0..cbfe52074d 100644 --- a/faststream/broker/test.py +++ b/faststream/broker/test.py @@ -1,16 +1,24 @@ -from types import TracebackType -from typing import Any, Dict, Optional, Type +from abc import abstractmethod +from contextlib import asynccontextmanager +from functools import partial +from types import MethodType, TracebackType +from typing import Any, AsyncGenerator, Dict, Generic, Optional, Type, TypeVar +from unittest.mock import AsyncMock import anyio from anyio.abc._tasks import TaskGroup from faststream.app import FastStream from faststream.broker.core.abc import BrokerUsecase +from faststream.broker.core.asyncronous import BrokerAsyncUsecase from faststream.broker.handler import AsyncHandler from faststream.broker.middlewares import CriticalLogMiddleware +from faststream.broker.wrapper import HandlerCallWrapper from faststream.types import SendableMessage, SettingField from faststream.utils.functions import timeout_scope +Broker = TypeVar("Broker", bound=BrokerAsyncUsecase[Any, Any]) + class TestApp: # make sure pytest doesn't try to collect this class as a test class @@ -87,6 +95,114 @@ async def __aexit__( await self._task.__aexit__(None, None, None) +class TestBroker(Generic[Broker]): + # This is set so pytest ignores this class + __test__ = False + + def __init__( + self, + broker: Broker, + with_real: bool = False, + connect_only: bool = False, + ): + self.with_real = with_real + self.broker = broker + self.connect_only = connect_only + + async def __aenter__(self) -> Broker: + self._ctx = self._create_ctx() + return await self._ctx.__aenter__() + + async def __aexit__(self, *args: Any) -> None: + await self._ctx.__aexit__(*args) + + @asynccontextmanager + async def _create_ctx(self) -> AsyncGenerator[Broker, None]: + if not self.with_real: + self._patch_test_broker(self.broker) + else: + self._fake_start(self.broker) + + async with self.broker: + try: + if not self.connect_only: + await self.broker.start() + yield self.broker + finally: + self._fake_close(self.broker) + + @classmethod + def _patch_test_broker(cls, broker: Broker) -> None: + broker.start = AsyncMock(wraps=partial(cls._fake_start, broker)) # type: ignore[method-assign] + broker._connect = MethodType(cls._fake_connect, broker) # type: ignore[method-assign] + broker.close = AsyncMock() # type: ignore[method-assign] + + @classmethod + def _fake_start(cls, broker: Broker, *args: Any, **kwargs: Any) -> None: + for key, p in broker._publishers.items(): + if getattr(p, "_fake_handler", False): + continue + + handler = broker.handlers.get(key) + if handler is not None: + for f, _, _, _, _, _ in handler.calls: + f.mock.side_effect = p.mock + else: + p._fake_handler = True + f = cls.create_publisher_fake_subscriber(broker, p) + p.mock = f.mock + + cls.patch_publisher(broker, p) + + patch_broker_calls(broker) + + @classmethod + def _fake_close( + cls, + broker: Broker, + exc_type: Optional[Type[BaseException]] = None, + exc_val: Optional[BaseException] = None, + exec_tb: Optional[TracebackType] = None, + ) -> None: + broker.middlewares = [ + CriticalLogMiddleware(broker.logger, broker.log_level), + *broker.middlewares, + ] + + for p in broker._publishers.values(): + p.mock.reset_mock() + if getattr(p, "_fake_handler", False): + cls.remove_publisher_fake_subscriber(broker, p) + p._fake_handler = False + p.mock.reset_mock() + + for h in broker.handlers.values(): + for f, _, _, _, _, _ in h.calls: + f.refresh(with_mock=True) + + @staticmethod + @abstractmethod + def create_publisher_fake_subscriber( + broker: Broker, publisher: Any + ) -> HandlerCallWrapper[Any, Any, Any]: + raise NotImplementedError() + + @staticmethod + @abstractmethod + def remove_publisher_fake_subscriber(broker: Broker, publisher: Any) -> None: + raise NotImplementedError() + + @staticmethod + @abstractmethod + async def _fake_connect(broker: Broker, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError() + + @staticmethod + @abstractmethod + def patch_publisher(broker: Broker, publisher: Any) -> None: + raise NotImplementedError() + + def patch_broker_calls(broker: BrokerUsecase[Any, Any]) -> None: """Patch broker calls. diff --git a/faststream/kafka/broker.py b/faststream/kafka/broker.py index d849223720..e7791ae8d3 100644 --- a/faststream/kafka/broker.py +++ b/faststream/kafka/broker.py @@ -267,7 +267,7 @@ def subscriber( # type: ignore[override] "earliest", "none", ] = "latest", - enable_auto_commit: bool = True, + auto_commit: bool = True, auto_commit_interval_ms: int = 5000, check_crcs: bool = True, partition_assignment_strategy: Sequence[AbstractPartitionAssignor] = ( @@ -336,7 +336,7 @@ def subscriber( # type: ignore[override] fetch_min_bytes (int): The minimum number of bytes to fetch. max_partition_fetch_bytes (int): The maximum bytes to fetch for a partition. auto_offset_reset (Literal["latest", "earliest", "none"]): Auto offset reset policy. - enable_auto_commit (bool): Whether to enable auto-commit. + auto_commit (bool): Whether to enable auto-commit. auto_commit_interval_ms (int): Auto-commit interval in milliseconds. check_crcs (bool): Whether to check CRCs. partition_assignment_strategy (Sequence[AbstractPartitionAssignor]): Partition assignment strategy. @@ -367,6 +367,9 @@ def subscriber( # type: ignore[override] self._setup_log_context(topics, group_id) + if not auto_commit and not group_id: + raise ValueError("You should install `group_id` with manual commit mode") + key = Handler.get_routing_hash(topics, group_id) builder = partial( aiokafka.AIOKafkaConsumer, @@ -377,7 +380,7 @@ def subscriber( # type: ignore[override] fetch_min_bytes=fetch_min_bytes, max_partition_fetch_bytes=max_partition_fetch_bytes, auto_offset_reset=auto_offset_reset, - enable_auto_commit=enable_auto_commit, + enable_auto_commit=auto_commit, auto_commit_interval_ms=auto_commit_interval_ms, check_crcs=check_crcs, partition_assignment_strategy=partition_assignment_strategy, @@ -399,6 +402,7 @@ def subscriber( # type: ignore[override] topics=topics, group_id=group_id, ), + is_manual=not auto_commit, group_id=group_id, client_id=self.client_id, builder=builder, diff --git a/faststream/kafka/broker.pyi b/faststream/kafka/broker.pyi index 43f965d9fd..92d003a416 100644 --- a/faststream/kafka/broker.pyi +++ b/faststream/kafka/broker.pyi @@ -208,7 +208,7 @@ class KafkaBroker( "earliest", "none", ] = "latest", - enable_auto_commit: bool = True, + auto_commit: bool = True, auto_commit_interval_ms: int = 5000, check_crcs: bool = True, partition_assignment_strategy: Sequence[AbstractPartitionAssignor] = ( @@ -266,7 +266,7 @@ class KafkaBroker( "earliest", "none", ] = "latest", - enable_auto_commit: bool = True, + auto_commit: bool = True, auto_commit_interval_ms: int = 5000, check_crcs: bool = True, partition_assignment_strategy: Sequence[AbstractPartitionAssignor] = ( diff --git a/faststream/kafka/fastapi.pyi b/faststream/kafka/fastapi.pyi index 1f1ef69edf..8ed036a6a4 100644 --- a/faststream/kafka/fastapi.pyi +++ b/faststream/kafka/fastapi.pyi @@ -151,7 +151,7 @@ class KafkaRouter(StreamRouter[ConsumerRecord]): "earliest", "none", ] = "latest", - enable_auto_commit: bool = True, + auto_commit: bool = True, auto_commit_interval_ms: int = 5000, check_crcs: bool = True, partition_assignment_strategy: Sequence[AbstractPartitionAssignor] = ( @@ -215,7 +215,7 @@ class KafkaRouter(StreamRouter[ConsumerRecord]): "earliest", "none", ] = "latest", - enable_auto_commit: bool = True, + auto_commit: bool = True, auto_commit_interval_ms: int = 5000, check_crcs: bool = True, partition_assignment_strategy: Sequence[AbstractPartitionAssignor] = ( @@ -275,7 +275,7 @@ class KafkaRouter(StreamRouter[ConsumerRecord]): "earliest", "none", ] = "latest", - enable_auto_commit: bool = True, + auto_commit: bool = True, auto_commit_interval_ms: int = 5000, check_crcs: bool = True, partition_assignment_strategy: Sequence[AbstractPartitionAssignor] = ( @@ -337,7 +337,7 @@ class KafkaRouter(StreamRouter[ConsumerRecord]): "earliest", "none", ] = "latest", - enable_auto_commit: bool = True, + auto_commit: bool = True, auto_commit_interval_ms: int = 5000, check_crcs: bool = True, partition_assignment_strategy: Sequence[AbstractPartitionAssignor] = ( diff --git a/faststream/kafka/handler.py b/faststream/kafka/handler.py index ae60c61908..531640125d 100644 --- a/faststream/kafka/handler.py +++ b/faststream/kafka/handler.py @@ -62,6 +62,7 @@ def __init__( group_id: Optional[str] = None, client_id: str = "faststream-" + __version__, builder: Callable[..., AIOKafkaConsumer], + is_manual: bool = False, batch: bool = False, batch_timeout_ms: int = 200, max_records: Optional[int] = None, @@ -101,6 +102,7 @@ def __init__( self.batch = batch self.batch_timeout_ms = batch_timeout_ms self.max_records = max_records + self.is_manual = is_manual self.builder = builder self.task = None diff --git a/faststream/kafka/message.py b/faststream/kafka/message.py index eba1e92dd1..599d7595f9 100644 --- a/faststream/kafka/message.py +++ b/faststream/kafka/message.py @@ -22,6 +22,19 @@ class KafkaMessage(StreamMessage[aiokafka.ConsumerRecord]): Reject the Kafka message. """ + def __init__( + self, + *args: Any, + consumer: aiokafka.AIOKafkaConsumer, + is_manual: bool = False, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + + self.is_manual = is_manual + self.consumer = consumer + self.commited = False + async def ack(self, **kwargs: Any) -> None: """ Acknowledge the Kafka message. @@ -32,7 +45,9 @@ async def ack(self, **kwargs: Any) -> None: Returns: None: This method does not return a value. """ - return None + if self.is_manual and not self.commited: + await self.consumer.commit() + self.commited = True async def nack(self, **kwargs: Any) -> None: """ @@ -44,7 +59,7 @@ async def nack(self, **kwargs: Any) -> None: Returns: None: This method does not return a value. """ - return None + self.commited = True async def reject(self, **kwargs: Any) -> None: """ @@ -56,4 +71,4 @@ async def reject(self, **kwargs: Any) -> None: Returns: None: This method does not return a value. """ - return None + self.commited = True diff --git a/faststream/kafka/parser.py b/faststream/kafka/parser.py index b690c3cf1e..6f33ec5cf6 100644 --- a/faststream/kafka/parser.py +++ b/faststream/kafka/parser.py @@ -7,6 +7,7 @@ from faststream.broker.parsers import decode_message from faststream.kafka.message import KafkaMessage from faststream.types import DecodedMessage +from faststream.utils.context.main import context class AioKafkaParser: @@ -26,6 +27,7 @@ async def parse_message( The above docstring is autogenerated by docstring-gen library (https://docstring-gen.airt.ai) """ headers = {i: j.decode() for i, j in message.headers} + handler = context.get("handler_") return KafkaMessage( body=message.value, headers=headers, @@ -34,6 +36,8 @@ async def parse_message( message_id=f"{message.offset}-{message.timestamp}", correlation_id=headers.get("correlation_id", str(uuid4())), raw_message=message, + consumer=handler.consumer, + is_manual=handler.is_manual, ) @staticmethod @@ -60,6 +64,7 @@ async def parse_message_batch( first = message[0] last = message[-1] headers = {i: j.decode() for i, j in first.headers} + handler = context.get("handler_") return KafkaMessage( body=[m.value for m in message], headers=headers, @@ -68,6 +73,8 @@ async def parse_message_batch( message_id=f"{first.offset}-{last.offset}-{first.timestamp}", correlation_id=headers.get("correlation_id", str(uuid4())), raw_message=message, + consumer=handler.consumer, + is_manual=handler.is_manual, ) @staticmethod diff --git a/faststream/kafka/router.pyi b/faststream/kafka/router.pyi index 267f7767a2..1be076556f 100644 --- a/faststream/kafka/router.pyi +++ b/faststream/kafka/router.pyi @@ -77,7 +77,7 @@ class KafkaRouter(BrokerRouter[str, aiokafka.ConsumerRecord]): "earliest", "none", ] = "latest", - enable_auto_commit: bool = True, + auto_commit: bool = True, auto_commit_interval_ms: int = 5000, check_crcs: bool = True, partition_assignment_strategy: Sequence[AbstractPartitionAssignor] = ( diff --git a/faststream/kafka/shared/router.pyi b/faststream/kafka/shared/router.pyi index 03e41d4a13..11e4dde8a8 100644 --- a/faststream/kafka/shared/router.pyi +++ b/faststream/kafka/shared/router.pyi @@ -31,7 +31,7 @@ class KafkaRoute: "earliest", "none", ] = "latest", - enable_auto_commit: bool = True, + auto_commit: bool = True, auto_commit_interval_ms: int = 5000, check_crcs: bool = True, partition_assignment_strategy: Sequence[AbstractPartitionAssignor] = ( @@ -91,7 +91,7 @@ class KafkaRoute: "earliest", "none", ] = "latest", - enable_auto_commit: bool = True, + auto_commit: bool = True, auto_commit_interval_ms: int = 5000, check_crcs: bool = True, partition_assignment_strategy: Sequence[AbstractPartitionAssignor] = ( diff --git a/faststream/kafka/test.py b/faststream/kafka/test.py index 61f813d6e1..153a9432b7 100644 --- a/faststream/kafka/test.py +++ b/faststream/kafka/test.py @@ -1,173 +1,51 @@ -from contextlib import asynccontextmanager from datetime import datetime -from functools import partial -from types import MethodType, TracebackType -from typing import Any, AsyncGenerator, Dict, Optional, Type -from unittest.mock import AsyncMock +from typing import Any, Dict, Optional from uuid import uuid4 from aiokafka import ConsumerRecord from faststream._compat import override -from faststream.broker.middlewares import CriticalLogMiddleware from faststream.broker.parsers import encode_message -from faststream.broker.test import call_handler, patch_broker_calls +from faststream.broker.test import TestBroker, call_handler +from faststream.broker.wrapper import HandlerCallWrapper +from faststream.kafka.asyncapi import Publisher from faststream.kafka.broker import KafkaBroker +from faststream.kafka.message import KafkaMessage from faststream.kafka.producer import AioKafkaFastProducer from faststream.types import SendableMessage __all__ = ("TestKafkaBroker",) -class TestKafkaBroker: - """ - A context manager for creating a test KafkaBroker instance with optional mocking. - - This class serves as a context manager for creating a KafkaBroker instance for testing purposes. It can either use the - original KafkaBroker instance (if `with_real` is True) or replace certain components with mocks (if `with_real` is - False) to isolate the broker during testing. - - Args: - broker (KafkaBroker): The KafkaBroker instance to be used in testing. - with_real (bool, optional): If True, the original broker is returned; if False, components are replaced with - mock objects. Defaults to False. - - Attributes: - broker (KafkaBroker): The KafkaBroker instance provided for testing. - with_real (bool): A boolean flag indicating whether to use the original broker (True) or replace components with - mocks (False). +class TestKafkaBroker(TestBroker[KafkaBroker]): + @staticmethod + async def _fake_connect(broker: KafkaBroker, *args: Any, **kwargs: Any) -> None: + broker._producer = FakeProducer(broker) - Methods: - __aenter__(self) -> KafkaBroker: - Enter the context and return the KafkaBroker instance. - - __aexit__(self, *args: Any) -> None: - Exit the context. - - Example usage: - - ```python - real_broker = KafkaBroker() - with TestKafkaBroker(real_broker, with_real=True) as broker: - # Use the real KafkaBroker instance for testing. - - with TestKafkaBroker(real_broker, with_real=False) as broker: - # Use a mocked KafkaBroker instance for testing. - """ + @staticmethod + def patch_publisher(broker: KafkaBroker, publisher: Any) -> None: + publisher._producer = broker._producer - # This is set so pytest ignores this class - __test__ = False - - def __init__( - self, + @staticmethod + def create_publisher_fake_subscriber( broker: KafkaBroker, - with_real: bool = False, - connect_only: bool = False, - ): - """ - Initialize a TestKafkaBroker instance. - - Args: - broker (KafkaBroker): The KafkaBroker instance to be used in testing. - with_real (bool, optional): If True, the original broker is returned; if False, components are replaced with - mock objects. Defaults to False. - """ - self.with_real = with_real - self.broker = broker - self.connect_only = connect_only - - @asynccontextmanager - async def _create_ctx(self) -> AsyncGenerator[KafkaBroker, None]: - """ - Create the context for the context manager. - - Yields: - KafkaBroker: The KafkaBroker instance for testing, either with or without mocks. - """ - if not self.with_real: - self.broker.start = AsyncMock(wraps=partial(_fake_start, self.broker)) # type: ignore[method-assign] - self.broker._connect = MethodType(_fake_connect, self.broker) # type: ignore[method-assign] - self.broker.close = AsyncMock() # type: ignore[method-assign] - else: - _fake_start(self.broker) - - async with self.broker: - try: - if not self.connect_only: - await self.broker.start() - yield self.broker - finally: - _fake_close(self.broker) - - async def __aenter__(self) -> KafkaBroker: - """ - Enter the context and return the KafkaBroker instance. - - Returns: - KafkaBroker: The KafkaBroker instance for testing, either with or without mocks. - """ - self._ctx = self._create_ctx() - return await self._ctx.__aenter__() - - async def __aexit__(self, *args: Any) -> None: - """ - Exit the context. - - Args: - *args: Variable-length argument list. - """ - await self._ctx.__aexit__(*args) - - -def build_message( - message: SendableMessage, - topic: str, - partition: Optional[int] = None, - timestamp_ms: Optional[int] = None, - key: Optional[bytes] = None, - headers: Optional[Dict[str, str]] = None, - correlation_id: Optional[str] = None, - *, - reply_to: str = "", -) -> ConsumerRecord: - """ - Build a Kafka ConsumerRecord for a sendable message. - - Args: - message (SendableMessage): The sendable message to be encoded. - topic (str): The Kafka topic for the message. - partition (Optional[int], optional): The Kafka partition for the message. Defaults to None. - timestamp_ms (Optional[int], optional): The message timestamp in milliseconds. Defaults to None. - key (Optional[bytes], optional): The message key. Defaults to None. - headers (Optional[Dict[str, str]], optional): Additional headers for the message. Defaults to None. - correlation_id (Optional[str], optional): The correlation ID for the message. Defaults to None. - reply_to (str, optional): The topic to which responses should be sent. Defaults to "". + publisher: Publisher, + ) -> HandlerCallWrapper[Any, Any, Any]: + @broker.subscriber( # type: ignore[call-overload,misc] + publisher.topic, + batch=publisher.batch, + _raw=True, + ) + def f(msg: KafkaMessage) -> str: + return "" - Returns: - ConsumerRecord: A Kafka ConsumerRecord object. - """ - msg, content_type = encode_message(message) - k = key or b"" - headers = { - "content-type": content_type or "", - "correlation_id": correlation_id or str(uuid4()), - "reply_to": reply_to, - **(headers or {}), - } + return f # type: ignore[no-any-return] - return ConsumerRecord( - value=msg, - topic=topic, - partition=partition or 0, - timestamp=timestamp_ms or int(datetime.now().timestamp()), - timestamp_type=0, - key=k, - serialized_key_size=len(k), - serialized_value_size=len(msg), - checksum=sum(msg), - offset=0, - headers=[(i, j.encode()) for i, j in headers.items()], - ) + @staticmethod + def remove_publisher_fake_subscriber( + broker: KafkaBroker, publisher: Publisher + ) -> None: + broker.handlers.pop(publisher.topic, None) class FakeProducer(AioKafkaFastProducer): @@ -287,88 +165,52 @@ async def publish_batch( return None -async def _fake_connect(self: KafkaBroker, *args: Any, **kwargs: Any) -> None: - """ - Fake connection to Kafka. - - Args: - self (KafkaBroker): The KafkaBroker instance. - *args (Any): Additional arguments. - **kwargs (Any): Additional keyword arguments. - - Returns: - None: This method does not return a value. - """ - self._producer = FakeProducer(self) - - -def _fake_close( - broker: KafkaBroker, - exc_type: Optional[Type[BaseException]] = None, - exc_val: Optional[BaseException] = None, - exec_tb: Optional[TracebackType] = None, -) -> None: - """ - Fake closing of the KafkaBroker. - - Args: - self (KafkaBroker): The KafkaBroker instance. - exc_type (Optional[Type[BaseException]], optional): Exception type. Defaults to None. - exc_val (Optional[BaseException], optional): Exception value. Defaults to None. - exec_tb (Optional[TracebackType], optional): Traceback information. Defaults to None. - - Returns: - None: This method does not return a value. - """ - broker.middlewares = [ - CriticalLogMiddleware(broker.logger, broker.log_level), - *broker.middlewares, - ] - - for p in broker._publishers.values(): - p.mock.reset_mock() - if getattr(p, "_fake_handler", False): - broker.handlers.pop(p.topic, None) - p._fake_handler = False - p.mock.reset_mock() - - for h in broker.handlers.values(): - for f, _, _, _, _, _ in h.calls: - f.refresh(with_mock=True) - - -def _fake_start(broker: KafkaBroker, *args: Any, **kwargs: Any) -> None: +def build_message( + message: SendableMessage, + topic: str, + partition: Optional[int] = None, + timestamp_ms: Optional[int] = None, + key: Optional[bytes] = None, + headers: Optional[Dict[str, str]] = None, + correlation_id: Optional[str] = None, + *, + reply_to: str = "", +) -> ConsumerRecord: """ - Fake starting of the KafkaBroker. + Build a Kafka ConsumerRecord for a sendable message. Args: - self (KafkaBroker): The KafkaBroker instance. - *args (Any): Additional arguments. - **kwargs (Any): Additional keyword arguments. + message (SendableMessage): The sendable message to be encoded. + topic (str): The Kafka topic for the message. + partition (Optional[int], optional): The Kafka partition for the message. Defaults to None. + timestamp_ms (Optional[int], optional): The message timestamp in milliseconds. Defaults to None. + key (Optional[bytes], optional): The message key. Defaults to None. + headers (Optional[Dict[str, str]], optional): Additional headers for the message. Defaults to None. + correlation_id (Optional[str], optional): The correlation ID for the message. Defaults to None. + reply_to (str, optional): The topic to which responses should be sent. Defaults to "". Returns: - None: This method does not return a value. + ConsumerRecord: A Kafka ConsumerRecord object. """ + msg, content_type = encode_message(message) + k = key or b"" + headers = { + "content-type": content_type or "", + "correlation_id": correlation_id or str(uuid4()), + "reply_to": reply_to, + **(headers or {}), + } - for key, p in broker._publishers.items(): - if getattr(p, "_fake_handler", False): - continue - - handler = broker.handlers.get(key) - if handler is not None: - for f, _, _, _, _, _ in handler.calls: - f.mock.side_effect = p.mock - else: - p._fake_handler = True - - @broker.subscriber( # type: ignore[call-overload,misc] - p.topic, batch=p.batch, _raw=True - ) - def f(msg: Any) -> None: - pass - - p.mock = f.mock - - p._producer = broker._producer - - patch_broker_calls(broker) + return ConsumerRecord( + value=msg, + topic=topic, + partition=partition or 0, + timestamp=timestamp_ms or int(datetime.now().timestamp()), + timestamp_type=0, + key=k, + serialized_key_size=len(k), + serialized_value_size=len(msg), + checksum=sum(msg), + offset=0, + headers=[(i, j.encode()) for i, j in headers.items()], + ) diff --git a/faststream/nats/test.py b/faststream/nats/test.py index c23cb7a248..925deba7bd 100644 --- a/faststream/nats/test.py +++ b/faststream/nats/test.py @@ -1,101 +1,47 @@ -from contextlib import asynccontextmanager -from functools import partial from itertools import zip_longest -from types import MethodType, TracebackType -from typing import Any, AsyncGenerator, Dict, Optional, Type, Union -from unittest.mock import AsyncMock +from typing import Any, Dict, Optional, Union from uuid import uuid4 from nats.aio.msg import Msg from faststream._compat import override -from faststream.broker.middlewares import CriticalLogMiddleware from faststream.broker.parsers import encode_message -from faststream.broker.test import call_handler, patch_broker_calls -from faststream.nats.asyncapi import Handler +from faststream.broker.test import TestBroker, call_handler +from faststream.broker.wrapper import HandlerCallWrapper +from faststream.nats.asyncapi import Handler, Publisher from faststream.nats.broker import NatsBroker +from faststream.nats.message import NatsMessage from faststream.nats.producer import NatsFastProducer from faststream.types import SendableMessage __all__ = ("TestNatsBroker",) -class TestNatsBroker: - # This is set so pytest ignores this class - __test__ = False +class TestNatsBroker(TestBroker[NatsBroker]): + @staticmethod + def patch_publisher(broker: NatsBroker, publisher: Any) -> None: + publisher._producer = broker._producer - def __init__( - self, + @staticmethod + def create_publisher_fake_subscriber( broker: NatsBroker, - with_real: bool = False, - connect_only: bool = False, - ): - self.with_real = with_real - self.broker = broker - self.connect_only = connect_only - - @asynccontextmanager - async def _create_ctx(self) -> AsyncGenerator[NatsBroker, None]: - if not self.with_real: - self.broker.start = AsyncMock(wraps=partial(_fake_start, self.broker)) # type: ignore[method-assign] - self.broker._connect = MethodType(_fake_connect, self.broker) # type: ignore[method-assign] - self.broker.close = AsyncMock() # type: ignore[method-assign] - else: - _fake_start(self.broker) - - async with self.broker: - try: - if not self.connect_only: - await self.broker.start() - yield self.broker - finally: - _fake_close(self.broker) - - async def __aenter__(self) -> NatsBroker: - self._ctx = self._create_ctx() - return await self._ctx.__aenter__() - - async def __aexit__(self, *args: Any) -> None: - await self._ctx.__aexit__(*args) - - -class PatchedMessage(Msg): - async def ack(self) -> None: - pass - - async def ack_sync(self, timeout: float = 1) -> "PatchedMessage": - return self - - async def nak(self, delay: Union[int, float, None] = None) -> None: - pass - - async def term(self) -> None: - pass + publisher: Publisher, + ) -> HandlerCallWrapper[Any, Any, Any]: + @broker.subscriber(publisher.subject, _raw=True) + def f(msg: NatsMessage) -> None: + pass - async def in_progress(self) -> None: - pass + return f + @staticmethod + async def _fake_connect(broker: NatsBroker, *args: Any, **kwargs: Any) -> None: + broker._js_producer = broker._producer = FakeProducer(broker) # type: ignore[assignment] -def build_message( - message: SendableMessage, - subject: str, - *, - reply_to: str = "", - correlation_id: Optional[str] = None, - headers: Optional[Dict[str, Any]] = None, -) -> PatchedMessage: - msg, content_type = encode_message(message) - return PatchedMessage( - _client=None, # type: ignore - subject=subject, - reply=reply_to, - data=msg, - headers={ - "content-type": content_type or "", - "correlation_id": correlation_id or str(uuid4()), - **(headers or {}), - }, - ) + @staticmethod + def remove_publisher_fake_subscriber( + broker: NatsBroker, publisher: Publisher + ) -> None: + broker.handlers.pop(Handler.get_routing_hash(publisher.subject), None) class FakeProducer(NatsFastProducer): @@ -160,51 +106,40 @@ async def publish( # type: ignore[override] return None -async def _fake_connect(self: NatsBroker, *args: Any, **kwargs: Any) -> None: - self._js_producer = self._producer = FakeProducer(self) # type: ignore[assignment] - - -def _fake_close( - broker: NatsBroker, - exc_type: Optional[Type[BaseException]] = None, - exc_val: Optional[BaseException] = None, - exec_tb: Optional[TracebackType] = None, -) -> None: - broker.middlewares = [ - CriticalLogMiddleware(broker.logger, broker.log_level), - *broker.middlewares, - ] - - for p in broker._publishers.values(): - p.mock.reset_mock() - if getattr(p, "_fake_handler", False): - broker.handlers.pop(Handler.get_routing_hash(p.subject), None) - p._fake_handler = False - p.mock.reset_mock() - - for h in broker.handlers.values(): - for f, _, _, _, _, _ in h.calls: - f.refresh(with_mock=True) - +def build_message( + message: SendableMessage, + subject: str, + *, + reply_to: str = "", + correlation_id: Optional[str] = None, + headers: Optional[Dict[str, Any]] = None, +) -> "PatchedMessage": + msg, content_type = encode_message(message) + return PatchedMessage( + _client=None, # type: ignore + subject=subject, + reply=reply_to, + data=msg, + headers={ + "content-type": content_type or "", + "correlation_id": correlation_id or str(uuid4()), + **(headers or {}), + }, + ) -def _fake_start(broker: NatsBroker, *args: Any, **kwargs: Any) -> None: - for key, p in broker._publishers.items(): - if getattr(p, "_fake_handler", False): - continue - handler = broker.handlers.get(key) - if handler is not None: - for f, _, _, _, _, _ in handler.calls: - f.mock.side_effect = p.mock - else: - p._fake_handler = True +class PatchedMessage(Msg): + async def ack(self) -> None: + pass - @broker.subscriber(p.subject, _raw=True) - def f(msg: Any) -> None: - pass + async def ack_sync(self, timeout: float = 1) -> "PatchedMessage": + return self - p.mock = f.mock + async def nak(self, delay: Union[int, float, None] = None) -> None: + pass - p._producer = broker._producer + async def term(self) -> None: + pass - patch_broker_calls(broker) + async def in_progress(self) -> None: + pass diff --git a/faststream/rabbit/test.py b/faststream/rabbit/test.py index f8301c9280..024df86c08 100644 --- a/faststream/rabbit/test.py +++ b/faststream/rabbit/test.py @@ -1,8 +1,5 @@ import re -from contextlib import asynccontextmanager -from functools import partial -from types import MethodType, TracebackType -from typing import Any, AsyncGenerator, Optional, Type, Union +from typing import Any, Optional, Union from unittest.mock import AsyncMock from uuid import uuid4 @@ -11,8 +8,9 @@ from pamqp import commands as spec from pamqp.header import ContentHeader -from faststream.broker.middlewares import CriticalLogMiddleware -from faststream.broker.test import call_handler, patch_broker_calls +from faststream.broker.test import TestBroker, call_handler +from faststream.broker.wrapper import HandlerCallWrapper +from faststream.rabbit.asyncapi import Publisher from faststream.rabbit.broker import RabbitBroker from faststream.rabbit.message import RabbitMessage from faststream.rabbit.parser import AioPikaParser @@ -30,110 +28,47 @@ __all__ = ("TestRabbitBroker",) -class TestRabbitBroker: +class TestRabbitBroker(TestBroker[RabbitBroker]): + @classmethod + def _patch_test_broker(cls, broker: RabbitBroker) -> None: + broker._channel = AsyncMock() + broker.declarer = AsyncMock() + super()._patch_test_broker(broker) - """ - A context manager for creating a test RabbitBroker instance with optional mocking. - - This class is designed to be used as a context manager for creating a RabbitBroker instance, optionally replacing some - of its components with mocks for testing purposes. If the `with_real` attribute is set to True, it operates as a - pass-through context manager, returning the original RabbitBroker instance without any modifications. If `with_real` - is set to False, it replaces certain components like the channel, declarer, and start/connect/close methods with mock - objects to isolate the broker for testing. - - Args: - broker (RabbitBroker): The RabbitBroker instance to be used in testing. - with_real (bool, optional): If True, the original broker is returned; if False, components are replaced with - mock objects. Defaults to False. - - Attributes: - broker (RabbitBroker): The RabbitBroker instance provided for testing. - with_real (bool): A boolean flag indicating whether to use the original broker (True) or replace components with - mocks (False). - - Methods: - __aenter__(self) -> RabbitBroker: - Enter the context and return the RabbitBroker instance. - - __aexit__(self, *args: Any) -> None: - Exit the context. - - Example usage: + @staticmethod + async def _fake_connect(broker: RabbitBroker, *args: Any, **kwargs: Any) -> None: + broker._producer = FakeProducer(broker) - ```python - real_broker = RabbitBroker() - with TestRabbitBroker(real_broker, with_real=True) as broker: - # Use the real RabbitBroker instance for testing. + @staticmethod + def patch_publisher(broker: RabbitBroker, publisher: Any) -> None: + publisher._producer = broker._producer - with TestRabbitBroker(real_broker, with_real=False) as broker: - # Use a mocked RabbitBroker instance for testing. - ``` - """ - - # This is set so pytest ignores this class - __test__ = False - - def __init__( - self, + @staticmethod + def create_publisher_fake_subscriber( broker: RabbitBroker, - with_real: bool = False, - connect_only: bool = False, - ): - """ - Initialize a TestRabbitBroker instance. - - Args: - broker (RabbitBroker): The RabbitBroker instance to be used in testing. - with_real (bool, optional): If True, the original broker is returned; if False, components are replaced with - mock objects. Defaults to False. - """ - self.with_real = with_real - self.broker = broker - self.connect_only = connect_only - - @asynccontextmanager - async def _create_ctx(self) -> AsyncGenerator[RabbitBroker, None]: - """ - Create the context for the context manager. - - Yields: - RabbitBroker: The RabbitBroker instance for testing, either with or without mocks. - """ - if not self.with_real: - self.broker._channel = AsyncMock() - self.broker.declarer = AsyncMock() - self.broker.start = AsyncMock(wraps=partial(_fake_start, self.broker)) # type: ignore[method-assign] - self.broker._connect = MethodType(_fake_connect, self.broker) # type: ignore[method-assign] - self.broker.close = AsyncMock() # type: ignore[method-assign] - else: - _fake_start(self.broker) - - async with self.broker: - try: - if not self.connect_only: - await self.broker.start() - yield self.broker - finally: - _fake_close(self.broker) - - async def __aenter__(self) -> RabbitBroker: - """ - Enter the context and return the RabbitBroker instance. - - Returns: - RabbitBroker: The RabbitBroker instance for testing, either with or without mocks. - """ - self._ctx = self._create_ctx() - return await self._ctx.__aenter__() - - async def __aexit__(self, *args: Any) -> None: - """ - Exit the context. - - Args: - *args: Variable-length argument list. - """ - await self._ctx.__aexit__(*args) + publisher: Publisher, + ) -> HandlerCallWrapper[Any, Any, Any]: + @broker.subscriber( + queue=publisher.queue, + exchange=publisher.exchange, + _raw=True, + ) + def f(msg: RabbitMessage) -> str: + return "" + + return f + + @staticmethod + def remove_publisher_fake_subscriber( + broker: RabbitBroker, publisher: Publisher + ) -> None: + broker.handlers.pop( + get_routing_hash( + queue=publisher.queue, + exchange=publisher.exchange, + ), + None, + ) class PatchedMessage(IncomingMessage): @@ -365,85 +300,3 @@ async def publish( return r return None - - -async def _fake_connect(self: RabbitBroker, *args: Any, **kwargs: Any) -> None: - """ - Fake connection method for the RabbitBroker class. - - Args: - self (RabbitBroker): The RabbitBroker instance. - *args (Any): Additional arguments. - **kwargs (Any): Additional keyword arguments. - """ - self._producer = FakeProducer(self) - - -def _fake_close( - broker: RabbitBroker, - exc_type: Optional[Type[BaseException]] = None, - exc_val: Optional[BaseException] = None, - exec_tb: Optional[TracebackType] = None, -) -> None: - """ - Fake close method for the RabbitBroker class. - - Args: - self (RabbitBroker): The RabbitBroker instance. - exc_type (Optional[Type[BaseException]]): The exception type. - exc_val (Optional[BaseException]]): The exception value. - exec_tb (Optional[TracebackType]]): The exception traceback. - """ - broker.middlewares = [ - CriticalLogMiddleware(broker.logger, broker.log_level), - *broker.middlewares, - ] - - for key, p in broker._publishers.items(): - p.mock.reset_mock() - if getattr(p, "_fake_handler", False): - key = get_routing_hash(p.queue, p.exchange) - broker.handlers.pop(key, None) - p._fake_handler = False - p.mock.reset_mock() - - for h in broker.handlers.values(): - for f, _, _, _, _, _ in h.calls: - f.refresh(with_mock=True) - - -def _fake_start(broker: RabbitBroker, *args: Any, **kwargs: Any) -> None: - """ - Fake start method for the RabbitBroker class. - - Args: - self (RabbitBroker): The RabbitBroker instance. - *args (Any): Additional arguments. - **kwargs (Any): Additional keyword arguments. - """ - for key, p in broker._publishers.items(): - if getattr(p, "_fake_handler", False): - continue - - handler = broker.handlers.get(key) - - if handler is not None: - for f, _, _, _, _, _ in handler.calls: - f.mock.side_effect = p.mock - - else: - p._fake_handler = True - - @broker.subscriber( - queue=p.queue, - exchange=p.exchange, - _raw=True, - ) - def f(msg: RabbitMessage) -> str: - return "" - - p.mock = f.mock - - p._producer = broker._producer - - patch_broker_calls(broker) diff --git a/pyproject.toml b/pyproject.toml index 7b7154dd73..979e991965 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -230,6 +230,7 @@ source = [ context = '${CONTEXT}' omit = [ "**/__init__.py", + "tests/mypy/*", ] [tool.coverage.report]