From 1d35cbda52b7e61f4a3f6dbf2c89847510655634 Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Sat, 24 Aug 2024 00:03:24 +0300 Subject: [PATCH 01/62] subscriber.get_one() --- faststream/rabbit/subscriber/usecase.py | 43 +++++++++++++++++++------ scripts/start_test_env.sh | 2 +- scripts/stop_test_env.sh | 2 +- 3 files changed, 36 insertions(+), 11 deletions(-) diff --git a/faststream/rabbit/subscriber/usecase.py b/faststream/rabbit/subscriber/usecase.py index 67421df2da..f85873cf46 100644 --- a/faststream/rabbit/subscriber/usecase.py +++ b/faststream/rabbit/subscriber/usecase.py @@ -1,3 +1,4 @@ +import asyncio from typing import ( TYPE_CHECKING, Any, @@ -14,14 +15,15 @@ from faststream.broker.publisher.fake import FakePublisher from faststream.broker.subscriber.usecase import SubscriberUsecase from faststream.exceptions import SetupError +from faststream.rabbit.message import RabbitMessage from faststream.rabbit.parser import AioPikaParser from faststream.rabbit.schemas import BaseRMQInformation if TYPE_CHECKING: - from aio_pika import IncomingMessage, RobustQueue + from aio_pika import IncomingMessage, RobustQueue, Message from fast_depends.dependencies import Depends - from faststream.broker.message import StreamMessage + from faststream.broker.message import StreamMessage, gen_cor_id from faststream.broker.types import BrokerMiddleware, CustomCallable from faststream.rabbit.helpers.declarer import RabbitDeclarer from faststream.rabbit.publisher.producer import AioPikaFastProducer @@ -45,6 +47,7 @@ class LogicSubscriber( _consumer_tag: Optional[str] _queue_obj: Optional["RobustQueue"] _producer: Optional["AioPikaFastProducer"] + _prepared: bool def __init__( self, @@ -64,11 +67,11 @@ def __init__( description_: Optional[str], include_in_schema: bool, ) -> None: - parser = AioPikaParser(pattern=queue.path_regex) + self.parser = AioPikaParser(pattern=queue.path_regex) super().__init__( - default_parser=parser.parse_message, - default_decoder=parser.decode_message, + default_parser=self.parser.parse_message, + default_decoder=self.parser.decode_message, # Propagated options no_ack=no_ack, no_reply=no_reply, @@ -94,6 +97,7 @@ def __init__( self.app_id = None self.virtual_host = "" self.declarer = None + self._prepared = False @override def setup( # type: ignore[override] @@ -133,9 +137,7 @@ def setup( # type: ignore[override] _call_decorators=_call_decorators, ) - @override - async def start(self) -> None: - """Starts the consumer for the RabbitMQ queue.""" + async def _prepare(self) -> None: if self.declarer is None: raise SetupError("You should setup subscriber at first.") @@ -156,7 +158,17 @@ async def start(self) -> None: robust=self.queue.robust, ) - self._consumer_tag = await queue.consume( + async def _ensure_prepared(self) -> None: + if not self._prepared: + await self._prepare() + self.prepared = True + + @override + async def start(self) -> None: + """Starts the consumer for the RabbitMQ queue.""" + await self._ensure_prepared() + + self._consumer_tag = await self._queue_obj.consume( # NOTE: aio-pika expects AbstractIncomingMessage, not IncomingMessage self.consume, # type: ignore[arg-type] arguments=self.consume_args, @@ -164,6 +176,19 @@ async def start(self) -> None: await super().start() + async def get_one(self, auto_ack: bool = False) -> "RabbitMessage": + await self._ensure_prepared() + + if self._queue_obj is None: + raise SetupError("You should prepare() subscriber at first.") + + while (message := await self._queue_obj.get(fail=False, no_ack=auto_ack)) is None: + await asyncio.sleep(0) + + parsed_message = await self.parser.parse_message(message) + assert isinstance(parsed_message, RabbitMessage) + return parsed_message + async def close(self) -> None: await super().close() diff --git a/scripts/start_test_env.sh b/scripts/start_test_env.sh index a0ae1627b8..906556db41 100755 --- a/scripts/start_test_env.sh +++ b/scripts/start_test_env.sh @@ -2,4 +2,4 @@ source ./scripts/set_variables.sh -docker-compose -p $DOCKER_COMPOSE_PROJECT -f docs/includes/docker-compose.yaml up -d --no-recreate +docker compose -p $DOCKER_COMPOSE_PROJECT -f docs/includes/docker-compose.yaml up -d --no-recreate diff --git a/scripts/stop_test_env.sh b/scripts/stop_test_env.sh index 5d77186357..76ab4a3ee0 100755 --- a/scripts/stop_test_env.sh +++ b/scripts/stop_test_env.sh @@ -2,4 +2,4 @@ source ./scripts/set_variables.sh -docker-compose -p $DOCKER_COMPOSE_PROJECT -f docs/includes/docker-compose.yaml down +docker compose -p $DOCKER_COMPOSE_PROJECT -f docs/includes/docker-compose.yaml down From 4ba22b675bc12d256f3ec53e0d3372b34fec9d54 Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Sat, 24 Aug 2024 00:35:36 +0300 Subject: [PATCH 02/62] remove _prepare --- faststream/rabbit/subscriber/usecase.py | 41 ++++++++++--------------- 1 file changed, 17 insertions(+), 24 deletions(-) diff --git a/faststream/rabbit/subscriber/usecase.py b/faststream/rabbit/subscriber/usecase.py index f85873cf46..b1ba4be770 100644 --- a/faststream/rabbit/subscriber/usecase.py +++ b/faststream/rabbit/subscriber/usecase.py @@ -67,11 +67,11 @@ def __init__( description_: Optional[str], include_in_schema: bool, ) -> None: - self.parser = AioPikaParser(pattern=queue.path_regex) + parser = AioPikaParser(pattern=queue.path_regex) super().__init__( - default_parser=self.parser.parse_message, - default_decoder=self.parser.decode_message, + default_parser=parser.parse_message, + default_decoder=parser.decode_message, # Propagated options no_ack=no_ack, no_reply=no_reply, @@ -137,16 +137,18 @@ def setup( # type: ignore[override] _call_decorators=_call_decorators, ) - async def _prepare(self) -> None: + @override + async def start(self) -> None: + """Starts the consumer for the RabbitMQ queue.""" if self.declarer is None: raise SetupError("You should setup subscriber at first.") self._queue_obj = queue = await self.declarer.declare_queue(self.queue) if ( - self.exchange is not None - and not queue.passive # queue just getted from RMQ - and self.exchange.name # check Exchange is not default + self.exchange is not None + and not queue.passive # queue just getted from RMQ + and self.exchange.name # check Exchange is not default ): exchange = await self.declarer.declare_exchange(self.exchange) @@ -158,15 +160,8 @@ async def _prepare(self) -> None: robust=self.queue.robust, ) - async def _ensure_prepared(self) -> None: - if not self._prepared: - await self._prepare() - self.prepared = True - - @override - async def start(self) -> None: - """Starts the consumer for the RabbitMQ queue.""" - await self._ensure_prepared() + if not self.calls: + return self._consumer_tag = await self._queue_obj.consume( # NOTE: aio-pika expects AbstractIncomingMessage, not IncomingMessage @@ -176,17 +171,15 @@ async def start(self) -> None: await super().start() - async def get_one(self, auto_ack: bool = False) -> "RabbitMessage": - await self._ensure_prepared() - + async def get_one(self, auto_ack: bool = False) -> "Optional[RabbitMessage]": if self._queue_obj is None: - raise SetupError("You should prepare() subscriber at first.") + raise SetupError("You should start subscriber at first.") + + assert not self.calls - while (message := await self._queue_obj.get(fail=False, no_ack=auto_ack)) is None: - await asyncio.sleep(0) + message = await self._queue_obj.get(no_ack=auto_ack) + parsed_message = await self._default_parser(message) - parsed_message = await self.parser.parse_message(message) - assert isinstance(parsed_message, RabbitMessage) return parsed_message async def close(self) -> None: From 8e71f73f9cc5463ce074f4a6bbe403645b0658ca Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Sat, 24 Aug 2024 00:39:58 +0300 Subject: [PATCH 03/62] ruff satisfied --- faststream/app.py | 2 +- faststream/rabbit/subscriber/usecase.py | 13 ++++++------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/faststream/app.py b/faststream/app.py index ebf71bfb1f..83f852baaf 100644 --- a/faststream/app.py +++ b/faststream/app.py @@ -170,7 +170,7 @@ async def run( tg.start_soon(self._startup, log_level, run_extra_options) # TODO: mv it to event trigger after nats-py fixing - while not self.should_exit: # noqa: ASYNC110 + while not self.should_exit: await anyio.sleep(sleep_time) await self._shutdown(log_level) diff --git a/faststream/rabbit/subscriber/usecase.py b/faststream/rabbit/subscriber/usecase.py index b1ba4be770..857c6ddf71 100644 --- a/faststream/rabbit/subscriber/usecase.py +++ b/faststream/rabbit/subscriber/usecase.py @@ -1,4 +1,3 @@ -import asyncio from typing import ( TYPE_CHECKING, Any, @@ -15,17 +14,17 @@ from faststream.broker.publisher.fake import FakePublisher from faststream.broker.subscriber.usecase import SubscriberUsecase from faststream.exceptions import SetupError -from faststream.rabbit.message import RabbitMessage from faststream.rabbit.parser import AioPikaParser from faststream.rabbit.schemas import BaseRMQInformation if TYPE_CHECKING: - from aio_pika import IncomingMessage, RobustQueue, Message + from aio_pika import IncomingMessage, RobustQueue from fast_depends.dependencies import Depends - from faststream.broker.message import StreamMessage, gen_cor_id + from faststream.broker.message import StreamMessage from faststream.broker.types import BrokerMiddleware, CustomCallable from faststream.rabbit.helpers.declarer import RabbitDeclarer + from faststream.rabbit.message import RabbitMessage from faststream.rabbit.publisher.producer import AioPikaFastProducer from faststream.rabbit.schemas import ( RabbitExchange, @@ -146,9 +145,9 @@ async def start(self) -> None: self._queue_obj = queue = await self.declarer.declare_queue(self.queue) if ( - self.exchange is not None - and not queue.passive # queue just getted from RMQ - and self.exchange.name # check Exchange is not default + self.exchange is not None + and not queue.passive # queue just getted from RMQ + and self.exchange.name # check Exchange is not default ): exchange = await self.declarer.declare_exchange(self.exchange) From d119b66f278eea5924321161e99d2ab414b63784 Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Sat, 24 Aug 2024 00:41:13 +0300 Subject: [PATCH 04/62] fixes --- scripts/start_test_env.sh | 2 +- scripts/stop_test_env.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/start_test_env.sh b/scripts/start_test_env.sh index 906556db41..a0ae1627b8 100755 --- a/scripts/start_test_env.sh +++ b/scripts/start_test_env.sh @@ -2,4 +2,4 @@ source ./scripts/set_variables.sh -docker compose -p $DOCKER_COMPOSE_PROJECT -f docs/includes/docker-compose.yaml up -d --no-recreate +docker-compose -p $DOCKER_COMPOSE_PROJECT -f docs/includes/docker-compose.yaml up -d --no-recreate diff --git a/scripts/stop_test_env.sh b/scripts/stop_test_env.sh index 76ab4a3ee0..5d77186357 100755 --- a/scripts/stop_test_env.sh +++ b/scripts/stop_test_env.sh @@ -2,4 +2,4 @@ source ./scripts/set_variables.sh -docker compose -p $DOCKER_COMPOSE_PROJECT -f docs/includes/docker-compose.yaml down +docker-compose -p $DOCKER_COMPOSE_PROJECT -f docs/includes/docker-compose.yaml down From 907482c7ee2b2f78c12bc8ac9c1f12480bf73a2d Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Sat, 24 Aug 2024 00:41:53 +0300 Subject: [PATCH 05/62] fixes --- faststream/app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/faststream/app.py b/faststream/app.py index 83f852baaf..ebf71bfb1f 100644 --- a/faststream/app.py +++ b/faststream/app.py @@ -170,7 +170,7 @@ async def run( tg.start_soon(self._startup, log_level, run_extra_options) # TODO: mv it to event trigger after nats-py fixing - while not self.should_exit: + while not self.should_exit: # noqa: ASYNC110 await anyio.sleep(sleep_time) await self._shutdown(log_level) From 1be821eb2f85e85958c82769fd158cba08b16ef2 Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Sat, 24 Aug 2024 00:48:08 +0300 Subject: [PATCH 06/62] fixes --- faststream/rabbit/subscriber/usecase.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/faststream/rabbit/subscriber/usecase.py b/faststream/rabbit/subscriber/usecase.py index 857c6ddf71..8e92579287 100644 --- a/faststream/rabbit/subscriber/usecase.py +++ b/faststream/rabbit/subscriber/usecase.py @@ -9,6 +9,7 @@ Union, ) +from aio_pika.abc import TimeoutType from typing_extensions import override from faststream.broker.publisher.fake import FakePublisher @@ -46,7 +47,6 @@ class LogicSubscriber( _consumer_tag: Optional[str] _queue_obj: Optional["RobustQueue"] _producer: Optional["AioPikaFastProducer"] - _prepared: bool def __init__( self, @@ -96,7 +96,6 @@ def __init__( self.app_id = None self.virtual_host = "" self.declarer = None - self._prepared = False @override def setup( # type: ignore[override] @@ -170,15 +169,21 @@ async def start(self) -> None: await super().start() - async def get_one(self, auto_ack: bool = False) -> "Optional[RabbitMessage]": + async def get_one( + self, + no_ack: bool = False, + fail: bool = True, + timeout: TimeoutType = 5, + ) -> "Optional[RabbitMessage]": if self._queue_obj is None: raise SetupError("You should start subscriber at first.") assert not self.calls - message = await self._queue_obj.get(no_ack=auto_ack) + message = await self._queue_obj.get(no_ack=no_ack, fail=fail, timeout=timeout) # type: ignore[call-overload] parsed_message = await self._default_parser(message) + assert isinstance(parsed_message, RabbitMessage) return parsed_message async def close(self) -> None: From 2ce5a5dc2c8eabeae66f6fbbd3a5f4a28ef018a4 Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Sat, 24 Aug 2024 00:49:32 +0300 Subject: [PATCH 07/62] fixes --- faststream/rabbit/subscriber/usecase.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/faststream/rabbit/subscriber/usecase.py b/faststream/rabbit/subscriber/usecase.py index 8e92579287..3bea30d4e9 100644 --- a/faststream/rabbit/subscriber/usecase.py +++ b/faststream/rabbit/subscriber/usecase.py @@ -15,6 +15,8 @@ from faststream.broker.publisher.fake import FakePublisher from faststream.broker.subscriber.usecase import SubscriberUsecase from faststream.exceptions import SetupError +from faststream.rabbit.helpers.declarer import RabbitDeclarer +from faststream.rabbit.message import RabbitMessage from faststream.rabbit.parser import AioPikaParser from faststream.rabbit.schemas import BaseRMQInformation @@ -25,7 +27,6 @@ from faststream.broker.message import StreamMessage from faststream.broker.types import BrokerMiddleware, CustomCallable from faststream.rabbit.helpers.declarer import RabbitDeclarer - from faststream.rabbit.message import RabbitMessage from faststream.rabbit.publisher.producer import AioPikaFastProducer from faststream.rabbit.schemas import ( RabbitExchange, From c439b4a2a622ce1a6fb1afd8da79a03b9fa4268d Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Sat, 24 Aug 2024 13:56:03 +0300 Subject: [PATCH 08/62] Kafka subscriber.get_one() --- faststream/kafka/subscriber/usecase.py | 14 ++++++++++++++ faststream/rabbit/subscriber/usecase.py | 22 +++++++++++----------- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/faststream/kafka/subscriber/usecase.py b/faststream/kafka/subscriber/usecase.py index c1bcfa6511..609923a120 100644 --- a/faststream/kafka/subscriber/usecase.py +++ b/faststream/kafka/subscriber/usecase.py @@ -164,6 +164,9 @@ async def start(self) -> None: await consumer.start() await super().start() + # if not self.calls: + # return + self.task = asyncio.create_task(self._consume()) async def close(self) -> None: @@ -178,6 +181,17 @@ async def close(self) -> None: self.task = None + async def get_one(self) -> "Optional[KafkaMessage]": + assert self.consumer, "You should start subscriber at first." + + assert not self.calls + + message = await self.consumer.getone() + parsed_message = await self._default_parser(message) + + assert isinstance(parsed_message, KafkaMessage) + return parsed_message + def _make_response_publisher( self, message: "StreamMessage[Any]", diff --git a/faststream/rabbit/subscriber/usecase.py b/faststream/rabbit/subscriber/usecase.py index 3bea30d4e9..99841de56f 100644 --- a/faststream/rabbit/subscriber/usecase.py +++ b/faststream/rabbit/subscriber/usecase.py @@ -170,6 +170,17 @@ async def start(self) -> None: await super().start() + async def close(self) -> None: + await super().close() + + if self._queue_obj is not None: + if self._consumer_tag is not None: # pragma: no branch + if not self._queue_obj.channel.is_closed: + await self._queue_obj.cancel(self._consumer_tag) + self._consumer_tag = None + + self._queue_obj = None + async def get_one( self, no_ack: bool = False, @@ -187,17 +198,6 @@ async def get_one( assert isinstance(parsed_message, RabbitMessage) return parsed_message - async def close(self) -> None: - await super().close() - - if self._queue_obj is not None: - if self._consumer_tag is not None: # pragma: no branch - if not self._queue_obj.channel.is_closed: - await self._queue_obj.cancel(self._consumer_tag) - self._consumer_tag = None - - self._queue_obj = None - def _make_response_publisher( self, message: "StreamMessage[Any]", From 2ed83c58c2c2296193bec5a7166a788abe750dc6 Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Sat, 24 Aug 2024 14:04:07 +0300 Subject: [PATCH 09/62] Confluent subscriber.get_one() --- faststream/confluent/subscriber/usecase.py | 15 +++++++++++++++ faststream/kafka/subscriber/usecase.py | 4 ++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/faststream/confluent/subscriber/usecase.py b/faststream/confluent/subscriber/usecase.py index 9fd244ba8e..a1836147c5 100644 --- a/faststream/confluent/subscriber/usecase.py +++ b/faststream/confluent/subscriber/usecase.py @@ -19,6 +19,7 @@ from faststream.broker.publisher.fake import FakePublisher from faststream.broker.subscriber.usecase import SubscriberUsecase from faststream.broker.types import MsgType +from faststream.confluent.message import KafkaMessage from faststream.confluent.parser import AsyncConfluentParser from faststream.confluent.schemas import TopicPartition @@ -152,6 +153,9 @@ async def start(self) -> None: await super().start() + if not self.calls: + return + self.task = asyncio.create_task(self._consume()) async def close(self) -> None: @@ -166,6 +170,17 @@ async def close(self) -> None: self.task = None + async def get_one(self, timeout: float = 0.1) -> "Optional[KafkaMessage]": + assert self.consumer, "You should start subscriber at first." + + assert not self.calls + + message = await self.consumer.getone(timeout=timeout) + parsed_message = await self._default_parser(message) + + assert isinstance(parsed_message, KafkaMessage) + return parsed_message + def _make_response_publisher( self, message: "StreamMessage[Any]", diff --git a/faststream/kafka/subscriber/usecase.py b/faststream/kafka/subscriber/usecase.py index 609923a120..913956527b 100644 --- a/faststream/kafka/subscriber/usecase.py +++ b/faststream/kafka/subscriber/usecase.py @@ -164,8 +164,8 @@ async def start(self) -> None: await consumer.start() await super().start() - # if not self.calls: - # return + if not self.calls: + return self.task = asyncio.create_task(self._consume()) From 9a2efde5501d20efea9a377125ed558781df4f3b Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sat, 24 Aug 2024 19:28:30 +0300 Subject: [PATCH 10/62] refactor: polist RMQ get_one method --- faststream/broker/subscriber/proto.py | 5 ++ faststream/broker/subscriber/usecase.py | 19 +++--- faststream/confluent/subscriber/usecase.py | 2 +- faststream/kafka/subscriber/usecase.py | 2 +- faststream/rabbit/subscriber/usecase.py | 71 +++++++++++++++------- faststream/utils/functions.py | 4 ++ 6 files changed, 69 insertions(+), 34 deletions(-) diff --git a/faststream/broker/subscriber/proto.py b/faststream/broker/subscriber/proto.py index 116d003d48..23db32d9b2 100644 --- a/faststream/broker/subscriber/proto.py +++ b/faststream/broker/subscriber/proto.py @@ -85,6 +85,11 @@ async def consume(self, msg: MsgType) -> Any: ... @abstractmethod async def process_message(self, msg: MsgType) -> Any: ... + @abstractmethod + async def get_one( + self, *, timeout: float = 5.0 + ) -> "Optional[StreamMessage[MsgType]]": ... + @abstractmethod def add_call( self, diff --git a/faststream/broker/subscriber/usecase.py b/faststream/broker/subscriber/usecase.py index 1897826d9a..07bea50860 100644 --- a/faststream/broker/subscriber/usecase.py +++ b/faststream/broker/subscriber/usecase.py @@ -107,8 +107,8 @@ def __init__( """Initialize a new instance of the class.""" self.calls = [] - self._default_parser = default_parser - self._default_decoder = default_decoder + self._parser = default_parser + self._decoder = default_decoder self._no_reply = no_reply # Watcher args self._no_ack = no_ack @@ -163,18 +163,17 @@ def setup( # type: ignore[override] for call in self.calls: if parser := call.item_parser or broker_parser: - async_parser = resolve_custom_func( - to_async(parser), self._default_parser - ) + async_parser = resolve_custom_func(to_async(parser), self._parser) else: - async_parser = self._default_parser + async_parser = self._parser if decoder := call.item_decoder or broker_decoder: - async_decoder = resolve_custom_func( - to_async(decoder), self._default_decoder - ) + async_decoder = resolve_custom_func(to_async(decoder), self._decoder) else: - async_decoder = self._default_decoder + async_decoder = self._decoder + + self._parser = async_parser + self._decoder = async_decoder call.setup( parser=async_parser, diff --git a/faststream/confluent/subscriber/usecase.py b/faststream/confluent/subscriber/usecase.py index a1836147c5..9075e298d6 100644 --- a/faststream/confluent/subscriber/usecase.py +++ b/faststream/confluent/subscriber/usecase.py @@ -176,7 +176,7 @@ async def get_one(self, timeout: float = 0.1) -> "Optional[KafkaMessage]": assert not self.calls message = await self.consumer.getone(timeout=timeout) - parsed_message = await self._default_parser(message) + parsed_message = await self._parser(message) assert isinstance(parsed_message, KafkaMessage) return parsed_message diff --git a/faststream/kafka/subscriber/usecase.py b/faststream/kafka/subscriber/usecase.py index 913956527b..3951e8751b 100644 --- a/faststream/kafka/subscriber/usecase.py +++ b/faststream/kafka/subscriber/usecase.py @@ -187,7 +187,7 @@ async def get_one(self) -> "Optional[KafkaMessage]": assert not self.calls message = await self.consumer.getone() - parsed_message = await self._default_parser(message) + parsed_message = await self._parser(message) assert isinstance(parsed_message, KafkaMessage) return parsed_message diff --git a/faststream/rabbit/subscriber/usecase.py b/faststream/rabbit/subscriber/usecase.py index 99841de56f..07cdadc0d1 100644 --- a/faststream/rabbit/subscriber/usecase.py +++ b/faststream/rabbit/subscriber/usecase.py @@ -1,6 +1,9 @@ +from contextlib import AsyncExitStack +from functools import partial from typing import ( TYPE_CHECKING, Any, + Awaitable, Callable, Dict, Iterable, @@ -9,16 +12,16 @@ Union, ) -from aio_pika.abc import TimeoutType +import anyio from typing_extensions import override from faststream.broker.publisher.fake import FakePublisher from faststream.broker.subscriber.usecase import SubscriberUsecase from faststream.exceptions import SetupError from faststream.rabbit.helpers.declarer import RabbitDeclarer -from faststream.rabbit.message import RabbitMessage from faststream.rabbit.parser import AioPikaParser from faststream.rabbit.schemas import BaseRMQInformation +from faststream.utils.functions import return_input if TYPE_CHECKING: from aio_pika import IncomingMessage, RobustQueue @@ -27,6 +30,7 @@ from faststream.broker.message import StreamMessage from faststream.broker.types import BrokerMiddleware, CustomCallable from faststream.rabbit.helpers.declarer import RabbitDeclarer + from faststream.rabbit.message import RabbitMessage from faststream.rabbit.publisher.producer import AioPikaFastProducer from faststream.rabbit.schemas import ( RabbitExchange, @@ -159,14 +163,12 @@ async def start(self) -> None: robust=self.queue.robust, ) - if not self.calls: - return - - self._consumer_tag = await self._queue_obj.consume( - # NOTE: aio-pika expects AbstractIncomingMessage, not IncomingMessage - self.consume, # type: ignore[arg-type] - arguments=self.consume_args, - ) + if self.calls: + self._consumer_tag = await self._queue_obj.consume( + # NOTE: aio-pika expects AbstractIncomingMessage, not IncomingMessage + self.consume, # type: ignore[arg-type] + arguments=self.consume_args, + ) await super().start() @@ -182,21 +184,46 @@ async def close(self) -> None: self._queue_obj = None async def get_one( - self, - no_ack: bool = False, - fail: bool = True, - timeout: TimeoutType = 5, + self, + *, + timeout: float = 5.0, + no_ack: bool = True, ) -> "Optional[RabbitMessage]": - if self._queue_obj is None: - raise SetupError("You should start subscriber at first.") - - assert not self.calls + assert self._queue_obj, "You should start subscriber at first." # nosec B101 + assert ( # nosec B101 + not self.calls + ), "You can't use `get_one` method if subscriber has registered handlers." + + sleep_interval = timeout / 10 + + raw_message: Optional[IncomingMessage] = None + with anyio.move_on_after(timeout): + while ( # noqa: ASYNC110 + raw_message := await self._queue_obj.get( + fail=False, + no_ack=no_ack, + timeout=timeout, + ) + ) is None: + await anyio.sleep(sleep_interval) + + if raw_message is None: + return None + + async with AsyncExitStack() as stack: + return_msg: Callable[[RabbitMessage], Awaitable[RabbitMessage]] = ( + return_input + ) + for m in self._broker_middlewares: + mid = m(raw_message) + await stack.enter_async_context(mid) + return_msg = partial(mid.consume_scope, return_msg) - message = await self._queue_obj.get(no_ack=no_ack, fail=fail, timeout=timeout) # type: ignore[call-overload] - parsed_message = await self._default_parser(message) + parsed_msg = await self._parser(raw_message) + parsed_msg._decoded_body = await self._decoder(parsed_msg) + return await return_msg(parsed_msg) - assert isinstance(parsed_message, RabbitMessage) - return parsed_message + raise AssertionError("unreachable") def _make_response_publisher( self, diff --git a/faststream/utils/functions.py b/faststream/utils/functions.py index 81b1b06db9..453c70ffc7 100644 --- a/faststream/utils/functions.py +++ b/faststream/utils/functions.py @@ -80,3 +80,7 @@ def drop_response_type( ) -> CallModel[F_Spec, F_Return]: model.response_model = None return model + + +async def return_input(x: Any) -> Any: + return x From bc85f53e6aa6853e8157d3106895f6f9c702eff9 Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Sat, 24 Aug 2024 21:46:22 +0300 Subject: [PATCH 11/62] Small refactoring of get_one --- faststream/rabbit/subscriber/usecase.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/faststream/rabbit/subscriber/usecase.py b/faststream/rabbit/subscriber/usecase.py index 07cdadc0d1..642a2111fb 100644 --- a/faststream/rabbit/subscriber/usecase.py +++ b/faststream/rabbit/subscriber/usecase.py @@ -211,9 +211,8 @@ async def get_one( return None async with AsyncExitStack() as stack: - return_msg: Callable[[RabbitMessage], Awaitable[RabbitMessage]] = ( - return_input - ) + return_msg: Callable[[RabbitMessage], Awaitable[RabbitMessage]] + for m in self._broker_middlewares: mid = m(raw_message) await stack.enter_async_context(mid) @@ -223,8 +222,6 @@ async def get_one( parsed_msg._decoded_body = await self._decoder(parsed_msg) return await return_msg(parsed_msg) - raise AssertionError("unreachable") - def _make_response_publisher( self, message: "StreamMessage[Any]", From e2347be940b33772c8edf7bab682e76b1f916d36 Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Sat, 24 Aug 2024 22:44:09 +0300 Subject: [PATCH 12/62] Rabbit get_one error fix --- faststream/rabbit/subscriber/usecase.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/faststream/rabbit/subscriber/usecase.py b/faststream/rabbit/subscriber/usecase.py index 642a2111fb..223a84c0c6 100644 --- a/faststream/rabbit/subscriber/usecase.py +++ b/faststream/rabbit/subscriber/usecase.py @@ -211,7 +211,9 @@ async def get_one( return None async with AsyncExitStack() as stack: - return_msg: Callable[[RabbitMessage], Awaitable[RabbitMessage]] + return_msg: Callable[[RabbitMessage], Awaitable[RabbitMessage]] = ( + return_input + ) for m in self._broker_middlewares: mid = m(raw_message) From 768f63729f5ad2220c25bcb8f4a5d534bc29130d Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Sat, 24 Aug 2024 22:44:30 +0300 Subject: [PATCH 13/62] Kafka get_one update --- faststream/kafka/subscriber/usecase.py | 33 ++++++++++++++++++++------ 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/faststream/kafka/subscriber/usecase.py b/faststream/kafka/subscriber/usecase.py index 3951e8751b..ecee737c6e 100644 --- a/faststream/kafka/subscriber/usecase.py +++ b/faststream/kafka/subscriber/usecase.py @@ -1,5 +1,7 @@ import asyncio from abc import ABC, abstractmethod +from contextlib import AsyncExitStack +from functools import partial from itertools import chain from typing import ( TYPE_CHECKING, @@ -10,7 +12,7 @@ List, Optional, Sequence, - Tuple, + Tuple, Awaitable, ) import anyio @@ -28,6 +30,7 @@ ) from faststream.kafka.message import KafkaAckableMessage, KafkaMessage from faststream.kafka.parser import AioKafkaBatchParser, AioKafkaParser +from faststream.utils.functions import return_input from faststream.utils.path import compile_path if TYPE_CHECKING: @@ -181,16 +184,32 @@ async def close(self) -> None: self.task = None - async def get_one(self) -> "Optional[KafkaMessage]": + async def get_one(self, timeout: float = 5) -> "Optional[KafkaMessage]": assert self.consumer, "You should start subscriber at first." + assert ( # nosec B101 + not self.calls + ), "You can't use `get_one` method if subscriber has registered handlers." - assert not self.calls + raw_messages = await self.consumer.getmany(timeout_ms=timeout * 1000, max_records=1) - message = await self.consumer.getone() - parsed_message = await self._parser(message) + if not raw_messages: + return None - assert isinstance(parsed_message, KafkaMessage) - return parsed_message + (raw_message,) ,= raw_messages.values() + + async with AsyncExitStack() as stack: + return_msg: Callable[[KafkaMessage], Awaitable[KafkaMessage]] = ( + return_input + ) + + for m in self._broker_middlewares: + mid = m(raw_message) + await stack.enter_async_context(mid) + return_msg = partial(mid.consume_scope, return_msg) + + parsed_msg = await self._parser(raw_message) + parsed_msg._decoded_body = await self._decoder(parsed_msg) + return await return_msg(parsed_msg) def _make_response_publisher( self, From 81b3edf846d7476bce1f1007bd03d3700cfe0f14 Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Sat, 24 Aug 2024 22:47:21 +0300 Subject: [PATCH 14/62] Confluent get_one update --- faststream/confluent/subscriber/usecase.py | 28 ++++++++++++++++------ 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/faststream/confluent/subscriber/usecase.py b/faststream/confluent/subscriber/usecase.py index 9075e298d6..763a0bce52 100644 --- a/faststream/confluent/subscriber/usecase.py +++ b/faststream/confluent/subscriber/usecase.py @@ -1,5 +1,7 @@ import asyncio from abc import ABC, abstractmethod +from contextlib import AsyncExitStack +from functools import partial from typing import ( TYPE_CHECKING, Any, @@ -9,7 +11,7 @@ List, Optional, Sequence, - Tuple, + Tuple, Awaitable, ) import anyio @@ -22,6 +24,7 @@ from faststream.confluent.message import KafkaMessage from faststream.confluent.parser import AsyncConfluentParser from faststream.confluent.schemas import TopicPartition +from faststream.utils.functions import return_input if TYPE_CHECKING: from fast_depends.dependencies import Depends @@ -170,16 +173,27 @@ async def close(self) -> None: self.task = None - async def get_one(self, timeout: float = 0.1) -> "Optional[KafkaMessage]": + async def get_one(self, timeout: float = 5.0) -> "Optional[KafkaMessage]": assert self.consumer, "You should start subscriber at first." + assert ( # nosec B101 + not self.calls + ), "You can't use `get_one` method if subscriber has registered handlers." - assert not self.calls + raw_message = await self.consumer.getone(timeout=timeout) - message = await self.consumer.getone(timeout=timeout) - parsed_message = await self._parser(message) + async with AsyncExitStack() as stack: + return_msg: Callable[[KafkaMessage], Awaitable[KafkaMessage]] = ( + return_input + ) + + for m in self._broker_middlewares: + mid = m(raw_message) + await stack.enter_async_context(mid) + return_msg = partial(mid.consume_scope, return_msg) - assert isinstance(parsed_message, KafkaMessage) - return parsed_message + parsed_msg = await self._parser(raw_message) + parsed_msg._decoded_body = await self._decoder(parsed_msg) + return await return_msg(parsed_msg) def _make_response_publisher( self, From 3d89d0f33f9cefcee68d38c05c1a620265f19faf Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Sun, 25 Aug 2024 13:02:39 +0300 Subject: [PATCH 15/62] Redis channel get_one --- faststream/redis/subscriber/usecase.py | 49 +++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/faststream/redis/subscriber/usecase.py b/faststream/redis/subscriber/usecase.py index dcfcd226f6..b6806acb01 100644 --- a/faststream/redis/subscriber/usecase.py +++ b/faststream/redis/subscriber/usecase.py @@ -1,7 +1,8 @@ import asyncio from abc import abstractmethod -from contextlib import suppress +from contextlib import suppress, AsyncExitStack from copy import deepcopy +from functools import partial from typing import ( TYPE_CHECKING, Any, @@ -30,6 +31,7 @@ DefaultStreamMessage, PubSubMessage, UnifyRedisDict, + RedisMessage ) from faststream.redis.parser import ( RedisBatchListParser, @@ -39,6 +41,7 @@ RedisStreamParser, ) from faststream.redis.schemas import ListSub, PubSub, StreamSub +from faststream.utils.functions import return_input if TYPE_CHECKING: from fast_depends.dependencies import Depends @@ -268,6 +271,9 @@ async def start(self) -> None: else: await psub.subscribe(self.channel.name) + if not self.calls: + return + await super().start(psub) async def close(self) -> None: @@ -278,6 +284,47 @@ async def close(self) -> None: await super().close() + async def get_one( + self, + *, + timeout: float = 5.0, + ignore_subscribe_messages: bool = False, + ) -> "Optional[RedisMessage]": + assert self.subscription, "You should start subscriber at first." # nosec B101 + assert ( # nosec B101 + not self.calls + ), "You can't use `get_one` method if subscriber has registered handlers." + + sleep_interval = timeout / 10 + + raw_message = None + + with anyio.move_on_after(timeout): + while ( # noqa: ASYNC110 + raw_message := await self.subscription.get_message( + timeout=timeout, + ignore_subscribe_messages=ignore_subscribe_messages, + ) + ) is None: + await anyio.sleep(sleep_interval) + + if not raw_message: + return None + + async with AsyncExitStack() as stack: + return_msg: Callable[[RedisMessage], Awaitable[RedisMessage]] = ( + return_input + ) + + for m in self._broker_middlewares: + mid = m(raw_message) + await stack.enter_async_context(mid) + return_msg = partial(mid.consume_scope, return_msg) + + parsed_msg = await self._parser(raw_message) + parsed_msg._decoded_body = await self._decoder(parsed_msg) + return await return_msg(parsed_msg) + async def _get_msgs(self, psub: RPubSub) -> None: raw_msg = await psub.get_message( ignore_subscribe_messages=True, From 9ec0e56ca17a72a8d1a73d040697b0f8c934cf64 Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Sun, 25 Aug 2024 13:47:53 +0300 Subject: [PATCH 16/62] Redis list get_one draft --- faststream/redis/subscriber/usecase.py | 44 ++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/faststream/redis/subscriber/usecase.py b/faststream/redis/subscriber/usecase.py index b6806acb01..21dd408c5a 100644 --- a/faststream/redis/subscriber/usecase.py +++ b/faststream/redis/subscriber/usecase.py @@ -409,6 +409,10 @@ async def start(self) -> None: return assert self._client, "You should setup subscriber at first." # nosec B101 + + if not self.calls: + return None + await super().start(self._client) def add_prefix(self, prefix: str) -> None: @@ -450,6 +454,46 @@ def __init__( include_in_schema=include_in_schema, ) + async def get_one( + self, + *, + timeout: float = 5.0, + ignore_subscribe_messages: bool = False, + ) -> "Optional[RedisMessage]": + # assert self.list_sub, "You should start subscriber at first." # nosec B101 + assert ( # nosec B101 + not self.calls + ), "You can't use `get_one` method if subscriber has registered handlers." + + sleep_interval = timeout / 10 + raw_message = None + + with anyio.move_on_after(timeout): + while ( # noqa: ASYNC110 + raw_message := await self._client.lpop(name=self.list_sub.name) + ) is None: + await anyio.sleep(sleep_interval) + + if not raw_message: + return None + + async with AsyncExitStack() as stack: + return_msg: Callable[[RedisMessage], Awaitable[RedisMessage]] = ( + return_input + ) + + for m in self._broker_middlewares: + mid = m(raw_message) + await stack.enter_async_context(mid) + return_msg = partial(mid.consume_scope, return_msg) + + message = DefaultListMessage( + type="list", + data=raw_message, + channel=self.list_sub.name, + ) + return await return_msg(message) + async def _get_msgs(self, client: "Redis[bytes]") -> None: raw_msg = await client.lpop(name=self.list_sub.name) From 053d2f2f7f270e923b5ae735140e86429ca46379 Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Sun, 25 Aug 2024 13:58:08 +0300 Subject: [PATCH 17/62] Redis batch list get_one draft --- faststream/redis/subscriber/usecase.py | 40 +++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/faststream/redis/subscriber/usecase.py b/faststream/redis/subscriber/usecase.py index 21dd408c5a..10c7ca194d 100644 --- a/faststream/redis/subscriber/usecase.py +++ b/faststream/redis/subscriber/usecase.py @@ -458,7 +458,6 @@ async def get_one( self, *, timeout: float = 5.0, - ignore_subscribe_messages: bool = False, ) -> "Optional[RedisMessage]": # assert self.list_sub, "You should start subscriber at first." # nosec B101 assert ( # nosec B101 @@ -543,6 +542,45 @@ def __init__( include_in_schema=include_in_schema, ) + async def get_one( + self, + *, + timeout: float = 5.0, + ) -> "Optional[RedisMessage]": + # assert self.list_sub, "You should start subscriber at first." # nosec B101 + assert ( # nosec B101 + not self.calls + ), "You can't use `get_one` method if subscriber has registered handlers." + + sleep_interval = timeout / 10 + raw_message = None + + with anyio.move_on_after(timeout): + while ( # noqa: ASYNC110 + raw_message := await self._client.lpop(name=self.list_sub.name) + ) is None: + await anyio.sleep(sleep_interval) + + if not raw_message: + return None + + async with AsyncExitStack() as stack: + return_msg: Callable[[RedisMessage], Awaitable[RedisMessage]] = ( + return_input + ) + + for m in self._broker_middlewares: + mid = m(raw_message) + await stack.enter_async_context(mid) + return_msg = partial(mid.consume_scope, return_msg) + + message = DefaultListMessage( + type="list", + data=raw_message, + channel=self.list_sub.name, + ) + return await return_msg(message) + async def _get_msgs(self, client: "Redis[bytes]") -> None: raw_msgs = await client.lpop( name=self.list_sub.name, From fe4d00041191a5b46b5801feec54b8fcd8d12c00 Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Sun, 25 Aug 2024 15:43:12 +0300 Subject: [PATCH 18/62] Redis channel get_one update and list get_one message decoding --- faststream/redis/subscriber/usecase.py | 34 ++++++++++++++------------ 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/faststream/redis/subscriber/usecase.py b/faststream/redis/subscriber/usecase.py index 10c7ca194d..09436846ac 100644 --- a/faststream/redis/subscriber/usecase.py +++ b/faststream/redis/subscriber/usecase.py @@ -288,7 +288,6 @@ async def get_one( self, *, timeout: float = 5.0, - ignore_subscribe_messages: bool = False, ) -> "Optional[RedisMessage]": assert self.subscription, "You should start subscriber at first." # nosec B101 assert ( # nosec B101 @@ -297,18 +296,13 @@ async def get_one( sleep_interval = timeout / 10 - raw_message = None + message: Optional[PubSubMessage] = None with anyio.move_on_after(timeout): - while ( # noqa: ASYNC110 - raw_message := await self.subscription.get_message( - timeout=timeout, - ignore_subscribe_messages=ignore_subscribe_messages, - ) - ) is None: + while (message := await self._get_message(self.subscription)) is None: await anyio.sleep(sleep_interval) - if not raw_message: + if not message: return None async with AsyncExitStack() as stack: @@ -317,28 +311,34 @@ async def get_one( ) for m in self._broker_middlewares: - mid = m(raw_message) + mid = m(message) await stack.enter_async_context(mid) return_msg = partial(mid.consume_scope, return_msg) - parsed_msg = await self._parser(raw_message) + parsed_msg = await self._parser(message) parsed_msg._decoded_body = await self._decoder(parsed_msg) return await return_msg(parsed_msg) - async def _get_msgs(self, psub: RPubSub) -> None: + async def _get_message(self, psub: RPubSub) -> Optional[PubSubMessage]: raw_msg = await psub.get_message( ignore_subscribe_messages=True, timeout=self.channel.polling_interval, ) if raw_msg: - msg = PubSubMessage( + return PubSubMessage( type=raw_msg["type"], data=raw_msg["data"], channel=raw_msg["channel"].decode(), pattern=raw_msg["pattern"], ) - await self.consume(msg) # type: ignore[arg-type] + + return + + + async def _get_msgs(self, psub: RPubSub) -> None: + msg = await self._get_message(psub) + await self.consume(msg) # type: ignore[arg-type] def add_prefix(self, prefix: str) -> None: new_ch = deepcopy(self.channel) @@ -491,7 +491,11 @@ async def get_one( data=raw_message, channel=self.list_sub.name, ) - return await return_msg(message) + + parsed_message = await self._parser(message) + parsed_message._decoded_body = await self._decoder(parsed_message) + + return await return_msg(parsed_message) async def _get_msgs(self, client: "Redis[bytes]") -> None: raw_msg = await client.lpop(name=self.list_sub.name) From 2cd00ddf7255c26123f7762f49f840eed835e9b9 Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Sun, 25 Aug 2024 16:02:32 +0300 Subject: [PATCH 19/62] Redis list batch get_one message decoding --- faststream/redis/subscriber/usecase.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/faststream/redis/subscriber/usecase.py b/faststream/redis/subscriber/usecase.py index 09436846ac..2ab9c4c1f0 100644 --- a/faststream/redis/subscriber/usecase.py +++ b/faststream/redis/subscriber/usecase.py @@ -583,7 +583,11 @@ async def get_one( data=raw_message, channel=self.list_sub.name, ) - return await return_msg(message) + + parsed_message = await self._parser(message) + parsed_message._decoded_body = await self._decoder(parsed_message) + + return await return_msg(parsed_message) async def _get_msgs(self, client: "Redis[bytes]") -> None: raw_msgs = await client.lpop( From cb9e0d7fb510ba84c89328ed944956b9859f826e Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Sun, 25 Aug 2024 19:02:44 +0300 Subject: [PATCH 20/62] Redis stream get_one --- faststream/redis/subscriber/usecase.py | 54 +++++++++++++++++++++++--- 1 file changed, 49 insertions(+), 5 deletions(-) diff --git a/faststream/redis/subscriber/usecase.py b/faststream/redis/subscriber/usecase.py index 2ab9c4c1f0..864915d111 100644 --- a/faststream/redis/subscriber/usecase.py +++ b/faststream/redis/subscriber/usecase.py @@ -271,9 +271,6 @@ async def start(self) -> None: else: await psub.subscribe(self.channel.name) - if not self.calls: - return - await super().start(psub) async def close(self) -> None: @@ -459,7 +456,7 @@ async def get_one( *, timeout: float = 5.0, ) -> "Optional[RedisMessage]": - # assert self.list_sub, "You should start subscriber at first." # nosec B101 + assert self._client, "You should start subscriber at first." assert ( # nosec B101 not self.calls ), "You can't use `get_one` method if subscriber has registered handlers." @@ -551,7 +548,7 @@ async def get_one( *, timeout: float = 5.0, ) -> "Optional[RedisMessage]": - # assert self.list_sub, "You should start subscriber at first." # nosec B101 + assert self._client, "You should start subscriber at first." assert ( # nosec B101 not self.calls ), "You can't use `get_one` method if subscriber has registered handlers." @@ -663,6 +660,9 @@ async def start(self) -> None: assert self._client, "You should setup subscriber at first." # nosec B101 + if not self.calls: + return + client = self._client self.extra_watcher_options.update( @@ -795,6 +795,50 @@ def __init__( include_in_schema=include_in_schema, ) + async def get_one( + self, + *, + timeout: float = 5.0, + ) -> "Optional[RedisMessage]": + assert self._client, "You should start subscriber at first." # nosec B101 + assert ( # nosec B101 + not self.calls + ), "You can't use `get_one` method if subscriber has registered handlers." + + stream_message = await self._client.xread( + {self.stream_sub.name: self.last_id}, + block=timeout * 100000, + count=1, + ) + + if not stream_message: + return None + + (stream_name, ((message_id, raw_message),)) ,= stream_message + + self.last_id = message_id.decode() + + msg = DefaultStreamMessage( + type="stream", + channel=stream_name.decode(), + message_ids=[message_id], + data=raw_message, + ) + + async with AsyncExitStack() as stack: + return_msg: Callable[[RedisMessage], Awaitable[RedisMessage]] = ( + return_input + ) + + for m in self._broker_middlewares: + mid = m(msg) + await stack.enter_async_context(mid) + return_msg = partial(mid.consume_scope, return_msg) + + parsed_msg = await self._parser(msg) + parsed_msg._decoded_body = await self._decoder(parsed_msg) + return await return_msg(parsed_msg) + async def _get_msgs( self, read: Callable[ From 80c55b8e8975cd7b2ac9c191dcd3b7a72172cf9f Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Sun, 25 Aug 2024 19:05:56 +0300 Subject: [PATCH 21/62] Redis batch stream get_one --- faststream/redis/subscriber/usecase.py | 44 ++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/faststream/redis/subscriber/usecase.py b/faststream/redis/subscriber/usecase.py index 864915d111..0ad082dc4b 100644 --- a/faststream/redis/subscriber/usecase.py +++ b/faststream/redis/subscriber/usecase.py @@ -908,6 +908,50 @@ def __init__( include_in_schema=include_in_schema, ) + async def get_one( + self, + *, + timeout: float = 5.0, + ) -> "Optional[RedisMessage]": + assert self._client, "You should start subscriber at first." # nosec B101 + assert ( # nosec B101 + not self.calls + ), "You can't use `get_one` method if subscriber has registered handlers." + + stream_message = await self._client.xread( + {self.stream_sub.name: self.last_id}, + block=timeout * 100000, + count=1, + ) + + if not stream_message: + return None + + (stream_name, ((message_id, raw_message),)) ,= stream_message + + self.last_id = message_id.decode() + + msg = DefaultStreamMessage( + type="stream", + channel=stream_name.decode(), + message_ids=[message_id], + data=raw_message, + ) + + async with AsyncExitStack() as stack: + return_msg: Callable[[RedisMessage], Awaitable[RedisMessage]] = ( + return_input + ) + + for m in self._broker_middlewares: + mid = m(msg) + await stack.enter_async_context(mid) + return_msg = partial(mid.consume_scope, return_msg) + + parsed_msg = await self._parser(msg) + parsed_msg._decoded_body = await self._decoder(parsed_msg) + return await return_msg(parsed_msg) + async def _get_msgs( self, read: Callable[ From 0dc83189fb2040d06f7e3bdcc747dc189b759e2a Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Sun, 25 Aug 2024 20:53:43 +0300 Subject: [PATCH 22/62] Redis channel get_one fix --- faststream/redis/subscriber/usecase.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/faststream/redis/subscriber/usecase.py b/faststream/redis/subscriber/usecase.py index 0ad082dc4b..5e3b639354 100644 --- a/faststream/redis/subscriber/usecase.py +++ b/faststream/redis/subscriber/usecase.py @@ -271,6 +271,9 @@ async def start(self) -> None: else: await psub.subscribe(self.channel.name) + if not self.calls: + return None + await super().start(psub) async def close(self) -> None: From b269ba6b3dd8a02e42f3f48dd703315aa214bcff Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Sun, 25 Aug 2024 23:57:15 +0300 Subject: [PATCH 23/62] Update brokers start methods --- faststream/confluent/subscriber/usecase.py | 6 ++---- faststream/kafka/subscriber/usecase.py | 6 ++---- faststream/nats/subscriber/usecase.py | 7 +++++++ faststream/redis/subscriber/usecase.py | 17 +++++------------ 4 files changed, 16 insertions(+), 20 deletions(-) diff --git a/faststream/confluent/subscriber/usecase.py b/faststream/confluent/subscriber/usecase.py index 763a0bce52..db4edc79b8 100644 --- a/faststream/confluent/subscriber/usecase.py +++ b/faststream/confluent/subscriber/usecase.py @@ -156,10 +156,8 @@ async def start(self) -> None: await super().start() - if not self.calls: - return - - self.task = asyncio.create_task(self._consume()) + if self.calls: + self.task = asyncio.create_task(self._consume()) async def close(self) -> None: await super().close() diff --git a/faststream/kafka/subscriber/usecase.py b/faststream/kafka/subscriber/usecase.py index ecee737c6e..30cb52f460 100644 --- a/faststream/kafka/subscriber/usecase.py +++ b/faststream/kafka/subscriber/usecase.py @@ -167,10 +167,8 @@ async def start(self) -> None: await consumer.start() await super().start() - if not self.calls: - return - - self.task = asyncio.create_task(self._consume()) + if self.calls: + self.task = asyncio.create_task(self._consume()) async def close(self) -> None: await super().close() diff --git a/faststream/nats/subscriber/usecase.py b/faststream/nats/subscriber/usecase.py index 85c133770c..974c7bd3da 100644 --- a/faststream/nats/subscriber/usecase.py +++ b/faststream/nats/subscriber/usecase.py @@ -155,6 +155,10 @@ async def start(self) -> None: """Create NATS subscription and start consume tasks.""" assert self._connection, NOT_CONNECTED_YET # nosec B101 await super().start() + + if not self.calls: + return None + await self._create_subscription(connection=self._connection) async def close(self) -> None: @@ -236,6 +240,7 @@ def get_routing_hash( class _DefaultSubscriber(LogicSubscriber[MsgType]): + def get_one(self, *args, **kwargs): ... def __init__( self, *, @@ -435,6 +440,8 @@ async def _create_subscription( # type: ignore[override] **self.extra_options, ) + + def get_log_context( self, message: Annotated[ diff --git a/faststream/redis/subscriber/usecase.py b/faststream/redis/subscriber/usecase.py index 5e3b639354..bdcfcb580d 100644 --- a/faststream/redis/subscriber/usecase.py +++ b/faststream/redis/subscriber/usecase.py @@ -160,10 +160,12 @@ async def start( await super().start() start_signal = anyio.Event() - self.task = asyncio.create_task(self._consume(*args, start_signal=start_signal)) - with anyio.fail_after(3.0): - await start_signal.wait() + if self.calls: + self.task = asyncio.create_task(self._consume(*args, start_signal=start_signal)) + + with anyio.fail_after(3.0): + await start_signal.wait() async def _consume(self, *args: Any, start_signal: anyio.Event) -> None: connected = True @@ -271,9 +273,6 @@ async def start(self) -> None: else: await psub.subscribe(self.channel.name) - if not self.calls: - return None - await super().start(psub) async def close(self) -> None: @@ -410,9 +409,6 @@ async def start(self) -> None: assert self._client, "You should setup subscriber at first." # nosec B101 - if not self.calls: - return None - await super().start(self._client) def add_prefix(self, prefix: str) -> None: @@ -663,9 +659,6 @@ async def start(self) -> None: assert self._client, "You should setup subscriber at first." # nosec B101 - if not self.calls: - return - client = self._client self.extra_watcher_options.update( From e4aea956deacbfd9e53d4ddcffed49be60ad8bfb Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Mon, 26 Aug 2024 00:11:43 +0300 Subject: [PATCH 24/62] remove unnecessary code --- faststream/nats/subscriber/usecase.py | 1 - 1 file changed, 1 deletion(-) diff --git a/faststream/nats/subscriber/usecase.py b/faststream/nats/subscriber/usecase.py index 974c7bd3da..98663022e8 100644 --- a/faststream/nats/subscriber/usecase.py +++ b/faststream/nats/subscriber/usecase.py @@ -240,7 +240,6 @@ def get_routing_hash( class _DefaultSubscriber(LogicSubscriber[MsgType]): - def get_one(self, *args, **kwargs): ... def __init__( self, *, From 8ef531ec25a42b94fa9bdc3ffc56538eb5366370 Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Tue, 27 Aug 2024 11:41:22 +0300 Subject: [PATCH 25/62] Nats CoreSubscriber.get_one --- faststream/nats/subscriber/usecase.py | 45 ++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/faststream/nats/subscriber/usecase.py b/faststream/nats/subscriber/usecase.py index 98663022e8..220617a107 100644 --- a/faststream/nats/subscriber/usecase.py +++ b/faststream/nats/subscriber/usecase.py @@ -1,6 +1,7 @@ import asyncio from abc import abstractmethod -from contextlib import suppress +from contextlib import suppress, AsyncExitStack +from functools import partial from typing import ( TYPE_CHECKING, Any, @@ -27,6 +28,7 @@ from faststream.broker.subscriber.usecase import SubscriberUsecase from faststream.broker.types import CustomCallable, MsgType from faststream.exceptions import NOT_CONNECTED_YET +from faststream.nats.message import NatsMessage from faststream.nats.parser import ( BatchParser, JsParser, @@ -41,6 +43,7 @@ ) from faststream.types import AnyDict, LoggerProto, SendableMessage from faststream.utils.context.repository import context +from faststream.utils.functions import return_input if TYPE_CHECKING: from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream @@ -422,6 +425,45 @@ def __init__( include_in_schema=include_in_schema, ) + async def get_one( + self, + *, + timeout: float = 5.0, + ) -> "Optional[NatsMessage]": + assert self._connection + assert ( # nosec B101 + not self.calls + ), "You can't use `get_one` method if subscriber has registered handlers." + + assert self.subscription is None + + self.subscription = await self._connection.subscribe( + subject=self.clear_subject, + queue=self.queue, + **self.extra_options, + ) + + raw_message = None + async for raw_message in self.subscription.messages: + break + + if not raw_message: + return None + + async with AsyncExitStack() as stack: + return_msg: Callable[[NatsMessage], Awaitable[NatsMessage]] = ( + return_input + ) + + for m in self._broker_middlewares: + mid = m(raw_message) + await stack.enter_async_context(mid) + return_msg = partial(mid.consume_scope, return_msg) + + parsed_msg = await self._parser(raw_message) + parsed_msg._decoded_body = await self._decoder(parsed_msg) + return await return_msg(parsed_msg) + @override async def _create_subscription( # type: ignore[override] self, @@ -429,6 +471,7 @@ async def _create_subscription( # type: ignore[override] connection: "Client", ) -> None: """Create NATS subscription and start consume task.""" + if self.subscription: return From e4d7079d9fe2ae8ba616292b35fd399ee38761bf Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Tue, 27 Aug 2024 12:42:11 +0300 Subject: [PATCH 26/62] Nats CoreSubscriber.get_one timeout support --- faststream/nats/subscriber/usecase.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/faststream/nats/subscriber/usecase.py b/faststream/nats/subscriber/usecase.py index 220617a107..f860450f88 100644 --- a/faststream/nats/subscriber/usecase.py +++ b/faststream/nats/subscriber/usecase.py @@ -435,17 +435,18 @@ async def get_one( not self.calls ), "You can't use `get_one` method if subscriber has registered handlers." - assert self.subscription is None - - self.subscription = await self._connection.subscribe( - subject=self.clear_subject, - queue=self.queue, - **self.extra_options, - ) + if self.subscription is None: + self.subscription = await self._connection.subscribe( + subject=self.clear_subject, + queue=self.queue, + **self.extra_options, + ) raw_message = None - async for raw_message in self.subscription.messages: - break + + with anyio.move_on_after(timeout): + async for raw_message in self.subscription.messages: + break if not raw_message: return None From f6d136b7ee78cde53d8bb225562238be99f6442e Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Tue, 27 Aug 2024 15:32:46 +0300 Subject: [PATCH 27/62] Nats PullStreamSubscriber get_one --- faststream/nats/subscriber/usecase.py | 91 ++++++++++++++++++++++++++- 1 file changed, 89 insertions(+), 2 deletions(-) diff --git a/faststream/nats/subscriber/usecase.py b/faststream/nats/subscriber/usecase.py index f860450f88..734a7d600a 100644 --- a/faststream/nats/subscriber/usecase.py +++ b/faststream/nats/subscriber/usecase.py @@ -157,10 +157,11 @@ def clear_subject(self) -> str: async def start(self) -> None: """Create NATS subscription and start consume tasks.""" assert self._connection, NOT_CONNECTED_YET # nosec B101 - await super().start() if not self.calls: - return None + return + + await super().start() await self._create_subscription(connection=self._connection) @@ -613,6 +614,7 @@ def get_log_context( ], ) -> Dict[str, str]: """Log context factory using in `self.consume` scope.""" + return self.build_log_context( message=message, subject=self._resolved_subject_string, @@ -752,6 +754,35 @@ def __init__( include_in_schema=include_in_schema, ) + async def get_one(self, *, timeout: float = 5) -> Optional[NatsMessage]: + if not self.subscription: + await self._create_subscription(connection=self._connection) + + try: + raw_message ,= await self.subscription.fetch( + batch=1, + timeout=timeout, + ) + except TimeoutError: + raw_message = None + + if not raw_message: + return None + + async with AsyncExitStack() as stack: + return_msg: Callable[[NatsMessage], Awaitable[NatsMessage]] = ( + return_input + ) + + for m in self._broker_middlewares: + mid = m(raw_message) + await stack.enter_async_context(mid) + return_msg = partial(mid.consume_scope, return_msg) + + parsed_msg = await self._parser(raw_message) + parsed_msg._decoded_body = await self._decoder(parsed_msg) + return await return_msg(parsed_msg) + @override async def _create_subscription( # type: ignore[override] self, @@ -832,6 +863,36 @@ def __init__( include_in_schema=include_in_schema, ) + async def get_one(self, *, timeout: float = 5) -> Optional[NatsMessage]: + if not self.subscription: + self.subscription = await self._connection.pull_subscribe( + subject=self.clear_subject, + config=self.config, + **self.extra_options, + ) + + raw_message ,= await self.subscription.fetch( + batch=1, + timeout=timeout, + ) + + if not raw_message: + return None + + async with AsyncExitStack() as stack: + return_msg: Callable[[NatsMessage], Awaitable[NatsMessage]] = ( + return_input + ) + + for m in self._broker_middlewares: + mid = m(raw_message) + await stack.enter_async_context(mid) + return_msg = partial(mid.consume_scope, return_msg) + + parsed_msg = await self._parser(raw_message) + parsed_msg._decoded_body = await self._decoder(parsed_msg) + return await return_msg(parsed_msg) + @override async def _create_subscription( # type: ignore[override] self, @@ -901,6 +962,32 @@ def __init__( include_in_schema=include_in_schema, ) + async def get_one(self, *, timeout: float = 5) -> Optional[NatsMessage]: + if not self.subscription: + await self._create_subscription(connection=self._connection) + + raw_message ,= await self.subscription.fetch( + batch=1, + timeout=timeout, + ) + + if not raw_message: + return None + + async with AsyncExitStack() as stack: + return_msg: Callable[[NatsMessage], Awaitable[NatsMessage]] = ( + return_input + ) + + for m in self._broker_middlewares: + mid = m(raw_message) + await stack.enter_async_context(mid) + return_msg = partial(mid.consume_scope, return_msg) + + parsed_msg = await self._parser(raw_message) + parsed_msg._decoded_body = await self._decoder(parsed_msg) + return await return_msg(parsed_msg) + @override async def _create_subscription( # type: ignore[override] self, From 0fa93a69305683473a1469ea6034fc8936958e98 Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Tue, 27 Aug 2024 20:46:24 +0300 Subject: [PATCH 28/62] Nats KeyValueWatchSubscriber get_one prototype --- faststream/nats/subscriber/usecase.py | 65 +++++++++++++++++++++++++-- 1 file changed, 61 insertions(+), 4 deletions(-) diff --git a/faststream/nats/subscriber/usecase.py b/faststream/nats/subscriber/usecase.py index 734a7d600a..50fafe08f6 100644 --- a/faststream/nats/subscriber/usecase.py +++ b/faststream/nats/subscriber/usecase.py @@ -28,7 +28,7 @@ from faststream.broker.subscriber.usecase import SubscriberUsecase from faststream.broker.types import CustomCallable, MsgType from faststream.exceptions import NOT_CONNECTED_YET -from faststream.nats.message import NatsMessage +from faststream.nats.message import NatsMessage, NatsKvMessage from faststream.nats.parser import ( BatchParser, JsParser, @@ -431,7 +431,7 @@ async def get_one( *, timeout: float = 5.0, ) -> "Optional[NatsMessage]": - assert self._connection + assert self._connection, "Please, start() subscriber first" assert ( # nosec B101 not self.calls ), "You can't use `get_one` method if subscriber has registered handlers." @@ -755,8 +755,14 @@ def __init__( ) async def get_one(self, *, timeout: float = 5) -> Optional[NatsMessage]: + assert self._connection, "Please, start() subscriber first" + if not self.subscription: - await self._create_subscription(connection=self._connection) + self.subscription = await self._connection.pull_subscribe( + subject=self.clear_subject, + config=self.config, + **self.extra_options, + ) try: raw_message ,= await self.subscription.fetch( @@ -864,6 +870,8 @@ def __init__( ) async def get_one(self, *, timeout: float = 5) -> Optional[NatsMessage]: + assert self._connection, "Please, start() subscriber first" + if not self.subscription: self.subscription = await self._connection.pull_subscribe( subject=self.clear_subject, @@ -963,8 +971,14 @@ def __init__( ) async def get_one(self, *, timeout: float = 5) -> Optional[NatsMessage]: + assert self._connection, "Please, start() subscriber first" + if not self.subscription: - await self._create_subscription(connection=self._connection) + self.subscription = await self._connection.pull_subscribe( + subject=self.clear_subject, + config=self.config, + **self.extra_options, + ) raw_message ,= await self.subscription.fetch( batch=1, @@ -1056,6 +1070,49 @@ def __init__( include_in_schema=include_in_schema, ) + async def get_one(self, *, timeout: float = 5) -> Optional[NatsKvMessage]: + assert self._connection, "Please, start() subscriber first" + + if not self.subscription: + bucket = await self._connection.create_key_value( + bucket=self.kv_watch.name, + declare=self.kv_watch.declare, + ) + + self.subscription = UnsubscribeAdapter["KeyValue.KeyWatcher"]( + await bucket.watch( + keys=self.clear_subject, + headers_only=self.kv_watch.headers_only, + include_history=self.kv_watch.include_history, + ignore_deletes=self.kv_watch.ignore_deletes, + meta_only=self.kv_watch.meta_only, + # inactive_threshold=self.kv_watch.inactive_threshold + ) + ) + + raw_message = None + sleep_interval = timeout / 10 + with anyio.move_on_after(timeout): + while (raw_message := await self.subscription.obj.updates(timeout)) is None: + await anyio.sleep(sleep_interval) + + if not raw_message: + return None + + async with AsyncExitStack() as stack: + return_msg: Callable[[NatsMessage], Awaitable[NatsMessage]] = ( + return_input + ) + + for m in self._broker_middlewares: + mid = m(raw_message) + await stack.enter_async_context(mid) + return_msg = partial(mid.consume_scope, return_msg) + + parsed_msg = await self._parser(raw_message) + parsed_msg._decoded_body = await self._decoder(parsed_msg) + return await return_msg(parsed_msg) + @override async def _create_subscription( # type: ignore[override] self, From 6748a79f6eb62dd461020b1bdad8d5195f02ae63 Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Tue, 27 Aug 2024 21:02:41 +0300 Subject: [PATCH 29/62] Nats ObjStoreWatchSubscriber get_one prototype --- faststream/nats/subscriber/usecase.py | 39 +++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/faststream/nats/subscriber/usecase.py b/faststream/nats/subscriber/usecase.py index 50fafe08f6..ceebeeb5cc 100644 --- a/faststream/nats/subscriber/usecase.py +++ b/faststream/nats/subscriber/usecase.py @@ -1205,6 +1205,7 @@ def __init__( parser = ObjParser(pattern="") self.obj_watch = obj_watch + self.obj_watch_conn = None super().__init__( subject=subject, @@ -1223,6 +1224,44 @@ def __init__( include_in_schema=include_in_schema, ) + async def get_one(self, *, timeout: float = 5) -> Optional[NatsKvMessage]: + assert self._connection, "Please, start() subscriber first" + + if not self.obj_watch_conn: + self.bucket = await self._connection.create_object_store( + bucket=self.subject, + declare=self.obj_watch.declare, + ) + + self.obj_watch_conn = await self.bucket.watch( + ignore_deletes=self.obj_watch.ignore_deletes, + include_history=self.obj_watch.include_history, + meta_only=self.obj_watch.meta_only, + ) + + raw_message = None + sleep_interval = timeout / 10 + with anyio.move_on_after(timeout): + while (raw_message := await self.obj_watch_conn.updates(timeout)) is None: + await anyio.sleep(sleep_interval) + + if not raw_message: + return None + + async with AsyncExitStack() as stack: + return_msg: Callable[[NatsMessage], Awaitable[NatsMessage]] = ( + return_input + ) + + for m in self._broker_middlewares: + mid = m(raw_message) + await stack.enter_async_context(mid) + return_msg = partial(mid.consume_scope, return_msg) + + parsed_msg = await self._parser(raw_message) + parsed_msg._decoded_body = await self._decoder(parsed_msg) + return await return_msg(parsed_msg) + @override async def _create_subscription( # type: ignore[override] self, From d9e187b6bd5d263f3de6f2034d2cc84ee4d16816 Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Tue, 27 Aug 2024 23:00:30 +0300 Subject: [PATCH 30/62] Add Nats additional get_one methods --- faststream/nats/subscriber/usecase.py | 70 +++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/faststream/nats/subscriber/usecase.py b/faststream/nats/subscriber/usecase.py index ceebeeb5cc..ee8d5d906d 100644 --- a/faststream/nats/subscriber/usecase.py +++ b/faststream/nats/subscriber/usecase.py @@ -626,6 +626,41 @@ def get_log_context( class PushStreamSubscription(_StreamSubscriber): subscription: Optional["JetStreamContext.PushSubscription"] + async def get_one(self, *, timeout: float = 5) -> Optional[NatsMessage]: + assert self._connection, "Please, start() subscriber first" + + if not self.subscription: + self.subscription = await self._connection.pull_subscribe( + subject=self.clear_subject, + config=self.config, + **self.extra_options, + ) + + try: + raw_message ,= await self.subscription.fetch( + batch=1, + timeout=timeout, + ) + except TimeoutError: + raw_message = None + + if not raw_message: + return None + + async with AsyncExitStack() as stack: + return_msg: Callable[[NatsMessage], Awaitable[NatsMessage]] = ( + return_input + ) + + for m in self._broker_middlewares: + mid = m(raw_message) + await stack.enter_async_context(mid) + return_msg = partial(mid.consume_scope, return_msg) + + parsed_msg = await self._parser(raw_message) + parsed_msg._decoded_body = await self._decoder(parsed_msg) + return await return_msg(parsed_msg) + @override async def _create_subscription( # type: ignore[override] self, @@ -689,6 +724,41 @@ def __init__( include_in_schema=include_in_schema, ) + async def get_one(self, *, timeout: float = 5) -> Optional[NatsMessage]: + assert self._connection, "Please, start() subscriber first" + + if not self.subscription: + self.subscription = await self._connection.pull_subscribe( + subject=self.clear_subject, + config=self.config, + **self.extra_options, + ) + + try: + raw_message ,= await self.subscription.fetch( + batch=1, + timeout=timeout, + ) + except TimeoutError: + raw_message = None + + if not raw_message: + return None + + async with AsyncExitStack() as stack: + return_msg: Callable[[NatsMessage], Awaitable[NatsMessage]] = ( + return_input + ) + + for m in self._broker_middlewares: + mid = m(raw_message) + await stack.enter_async_context(mid) + return_msg = partial(mid.consume_scope, return_msg) + + parsed_msg = await self._parser(raw_message) + parsed_msg._decoded_body = await self._decoder(parsed_msg) + return await return_msg(parsed_msg) + @override async def _create_subscription( # type: ignore[override] self, From ebbcef7dee0d04516f78bb1a6b0f4a8635781632 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sat, 24 Aug 2024 19:28:30 +0300 Subject: [PATCH 31/62] refactor: polist RMQ get_one method --- faststream/rabbit/subscriber/usecase.py | 1 - 1 file changed, 1 deletion(-) diff --git a/faststream/rabbit/subscriber/usecase.py b/faststream/rabbit/subscriber/usecase.py index 223a84c0c6..3703d53577 100644 --- a/faststream/rabbit/subscriber/usecase.py +++ b/faststream/rabbit/subscriber/usecase.py @@ -214,7 +214,6 @@ async def get_one( return_msg: Callable[[RabbitMessage], Awaitable[RabbitMessage]] = ( return_input ) - for m in self._broker_middlewares: mid = m(raw_message) await stack.enter_async_context(mid) From fed12ccedd7c016ec3b217d6ccf5456a1a127851 Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Tue, 27 Aug 2024 23:44:45 +0300 Subject: [PATCH 32/62] Rabbit subscriber get_one tests --- tests/brokers/rabbit/test_consume.py | 54 ++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/tests/brokers/rabbit/test_consume.py b/tests/brokers/rabbit/test_consume.py index cd2550429c..3ff4b41e3d 100644 --- a/tests/brokers/rabbit/test_consume.py +++ b/tests/brokers/rabbit/test_consume.py @@ -383,3 +383,57 @@ async def handler(msg: RabbitMessage): m.mock.assert_not_called() assert event.is_set() + + @pytest.mark.asyncio + async def test_get_one( + self, + queue: str, + exchange: RabbitExchange, + ): + broker = self.get_broker(apply_types=True) + subscriber = broker.subscriber(queue, exchange=exchange) + + async with self.patch_broker(broker) as br: + await broker.start() + + message = None + async def set_msg(): + nonlocal message + message = await subscriber.get_one() + + await asyncio.wait( + ( + asyncio.create_task(br.publish(message="test_message", queue=queue, exchange=exchange)), + asyncio.create_task(set_msg()), + ), + timeout=3 + ) + + assert message is not None + assert await message.decode() == "test_message" + + @pytest.mark.asyncio + async def test_get_one_timeout( + self, + queue: str, + exchange: RabbitExchange, + ): + broker = self.get_broker(apply_types=True) + subscriber = broker.subscriber(queue, exchange=exchange) + + async with self.patch_broker(broker) as br: + await broker.start() + + message = object() + async def coro(): + nonlocal message + message = await subscriber.get_one(timeout=1) + + await asyncio.wait( + ( + asyncio.create_task(coro()), + ), + timeout=3 + ) + + assert message is None, message From 7e55e57a565d78817eecb22c0e7191899a71fb39 Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Wed, 28 Aug 2024 09:57:35 +0300 Subject: [PATCH 33/62] Kafka subscriber get_one tests --- tests/brokers/kafka/test_consume.py | 52 ++++++++++++++++++++++++++++ tests/brokers/rabbit/test_consume.py | 2 +- 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/tests/brokers/kafka/test_consume.py b/tests/brokers/kafka/test_consume.py index 7da9f90a5f..d354c966cd 100644 --- a/tests/brokers/kafka/test_consume.py +++ b/tests/brokers/kafka/test_consume.py @@ -312,3 +312,55 @@ async def handler(msg: KafkaMessage): m.mock.assert_not_called() assert event.is_set() + + @pytest.mark.asyncio + async def test_get_one( + self, + queue: str, + ): + broker = self.get_broker(apply_types=True) + subscriber = broker.subscriber(queue) + + async with self.patch_broker(broker) as br: + await broker.start() + + message = None + async def set_msg(): + nonlocal message + message = await subscriber.get_one() + + await asyncio.wait( + ( + asyncio.create_task(br.publish("test_message", queue)), + asyncio.create_task(set_msg()), + ), + timeout=3 + ) + + assert message is not None + assert await message.decode() == "test_message" + + @pytest.mark.asyncio + async def test_get_one_timeout( + self, + queue: str, + ): + broker = self.get_broker(apply_types=True) + subscriber = broker.subscriber(queue) + + async with self.patch_broker(broker) as br: + await broker.start() + + message = object() + async def coro(): + nonlocal message + message = await subscriber.get_one(timeout=1) + + await asyncio.wait( + ( + asyncio.create_task(coro()), + ), + timeout=3 + ) + + assert message is None diff --git a/tests/brokers/rabbit/test_consume.py b/tests/brokers/rabbit/test_consume.py index 3ff4b41e3d..c75e2aaeac 100644 --- a/tests/brokers/rabbit/test_consume.py +++ b/tests/brokers/rabbit/test_consume.py @@ -436,4 +436,4 @@ async def coro(): timeout=3 ) - assert message is None, message + assert message is None From 6d6b13945c03ae8066df2b74802e811e4d0ef8d8 Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Wed, 28 Aug 2024 20:13:02 +0300 Subject: [PATCH 34/62] Confluent subscriber get_one tests --- faststream/confluent/subscriber/usecase.py | 3 ++ tests/brokers/confluent/test_consume.py | 57 ++++++++++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/faststream/confluent/subscriber/usecase.py b/faststream/confluent/subscriber/usecase.py index db4edc79b8..c35c4df014 100644 --- a/faststream/confluent/subscriber/usecase.py +++ b/faststream/confluent/subscriber/usecase.py @@ -179,6 +179,9 @@ async def get_one(self, timeout: float = 5.0) -> "Optional[KafkaMessage]": raw_message = await self.consumer.getone(timeout=timeout) + if not raw_message: + return None + async with AsyncExitStack() as stack: return_msg: Callable[[KafkaMessage], Awaitable[KafkaMessage]] = ( return_input diff --git a/tests/brokers/confluent/test_consume.py b/tests/brokers/confluent/test_consume.py index f3eb5774cd..5e6ced37e0 100644 --- a/tests/brokers/confluent/test_consume.py +++ b/tests/brokers/confluent/test_consume.py @@ -1,4 +1,5 @@ import asyncio +import time from unittest.mock import patch import pytest @@ -322,3 +323,59 @@ async def subscriber_with_auto_commit(m): assert event.is_set() assert event2.is_set() + + @pytest.mark.asyncio + async def test_get_one( + self, + queue: str, + ): + broker = self.get_broker(apply_types=True) + subscriber = broker.subscriber(queue) + + async with self.patch_broker(broker) as br: + await broker.start() + + message = None + async def consume(): + nonlocal message + message = await subscriber.get_one(5) + + async def publish(): + await asyncio.sleep(3) + await br.publish("test_message", queue) + + await asyncio.wait( + ( + asyncio.create_task(consume()), + asyncio.create_task(publish()), + ), + timeout=10 + ) + + assert message is not None + assert await message.decode() == "test_message" + + @pytest.mark.asyncio + async def test_get_one_timeout( + self, + queue: str, + ): + broker = self.get_broker(apply_types=True) + subscriber = broker.subscriber(queue) + + async with self.patch_broker(broker) as br: + await broker.start() + + message = object() + async def coro(): + nonlocal message + message = await subscriber.get_one(timeout=1) + + await asyncio.wait( + ( + asyncio.create_task(coro()), + ), + timeout=3 + ) + + assert message is None From f17ce3db2dd75989946c617ac86e01c94d614015 Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Wed, 28 Aug 2024 23:56:44 +0300 Subject: [PATCH 35/62] Redis subscriber get_one tests --- faststream/confluent/subscriber/usecase.py | 2 +- faststream/kafka/subscriber/usecase.py | 2 +- faststream/redis/subscriber/usecase.py | 8 +- tests/brokers/confluent/test_consume.py | 7 +- tests/brokers/kafka/test_consume.py | 4 +- tests/brokers/rabbit/test_consume.py | 4 +- tests/brokers/redis/test_consume.py | 170 +++++++++++++++++++++ 7 files changed, 184 insertions(+), 13 deletions(-) diff --git a/faststream/confluent/subscriber/usecase.py b/faststream/confluent/subscriber/usecase.py index c35c4df014..24ab4165d5 100644 --- a/faststream/confluent/subscriber/usecase.py +++ b/faststream/confluent/subscriber/usecase.py @@ -171,7 +171,7 @@ async def close(self) -> None: self.task = None - async def get_one(self, timeout: float = 5.0) -> "Optional[KafkaMessage]": + async def get_one(self, *, timeout: float = 5.0) -> "Optional[KafkaMessage]": assert self.consumer, "You should start subscriber at first." assert ( # nosec B101 not self.calls diff --git a/faststream/kafka/subscriber/usecase.py b/faststream/kafka/subscriber/usecase.py index 30cb52f460..569f535d4c 100644 --- a/faststream/kafka/subscriber/usecase.py +++ b/faststream/kafka/subscriber/usecase.py @@ -182,7 +182,7 @@ async def close(self) -> None: self.task = None - async def get_one(self, timeout: float = 5) -> "Optional[KafkaMessage]": + async def get_one(self, *, timeout: float = 5.0,) -> "Optional[KafkaMessage]": assert self.consumer, "You should start subscriber at first." assert ( # nosec B101 not self.calls diff --git a/faststream/redis/subscriber/usecase.py b/faststream/redis/subscriber/usecase.py index bdcfcb580d..071d533797 100644 --- a/faststream/redis/subscriber/usecase.py +++ b/faststream/redis/subscriber/usecase.py @@ -801,9 +801,9 @@ async def get_one( not self.calls ), "You can't use `get_one` method if subscriber has registered handlers." - stream_message = await self._client.xread( + stream_message = await self._client.xread( {self.stream_sub.name: self.last_id}, - block=timeout * 100000, + block=timeout * 1000, count=1, ) @@ -914,9 +914,9 @@ async def get_one( not self.calls ), "You can't use `get_one` method if subscriber has registered handlers." - stream_message = await self._client.xread( + stream_message = await self._client.xread( {self.stream_sub.name: self.last_id}, - block=timeout * 100000, + block=timeout * 1000, count=1, ) diff --git a/tests/brokers/confluent/test_consume.py b/tests/brokers/confluent/test_consume.py index 5e6ced37e0..fa07ff323c 100644 --- a/tests/brokers/confluent/test_consume.py +++ b/tests/brokers/confluent/test_consume.py @@ -328,17 +328,18 @@ async def subscriber_with_auto_commit(m): async def test_get_one( self, queue: str, + event: asyncio.Event, ): broker = self.get_broker(apply_types=True) subscriber = broker.subscriber(queue) async with self.patch_broker(broker) as br: - await broker.start() + await br.start() message = None async def consume(): nonlocal message - message = await subscriber.get_one(5) + message = await subscriber.get_one(timeout=5) async def publish(): await asyncio.sleep(3) @@ -364,7 +365,7 @@ async def test_get_one_timeout( subscriber = broker.subscriber(queue) async with self.patch_broker(broker) as br: - await broker.start() + await br.start() message = object() async def coro(): diff --git a/tests/brokers/kafka/test_consume.py b/tests/brokers/kafka/test_consume.py index d354c966cd..12cd5101d1 100644 --- a/tests/brokers/kafka/test_consume.py +++ b/tests/brokers/kafka/test_consume.py @@ -322,7 +322,7 @@ async def test_get_one( subscriber = broker.subscriber(queue) async with self.patch_broker(broker) as br: - await broker.start() + await br.start() message = None async def set_msg(): @@ -349,7 +349,7 @@ async def test_get_one_timeout( subscriber = broker.subscriber(queue) async with self.patch_broker(broker) as br: - await broker.start() + await br.start() message = object() async def coro(): diff --git a/tests/brokers/rabbit/test_consume.py b/tests/brokers/rabbit/test_consume.py index c75e2aaeac..2c3cf23987 100644 --- a/tests/brokers/rabbit/test_consume.py +++ b/tests/brokers/rabbit/test_consume.py @@ -394,7 +394,7 @@ async def test_get_one( subscriber = broker.subscriber(queue, exchange=exchange) async with self.patch_broker(broker) as br: - await broker.start() + await br.start() message = None async def set_msg(): @@ -422,7 +422,7 @@ async def test_get_one_timeout( subscriber = broker.subscriber(queue, exchange=exchange) async with self.patch_broker(broker) as br: - await broker.start() + await br.start() message = object() async def coro(): diff --git a/tests/brokers/redis/test_consume.py b/tests/brokers/redis/test_consume.py index bda2af699e..2e1be6ff43 100644 --- a/tests/brokers/redis/test_consume.py +++ b/tests/brokers/redis/test_consume.py @@ -92,6 +92,63 @@ async def handler(msg): mock.assert_called_once_with("hello") + @pytest.mark.asyncio + async def test_get_one( + self, + queue: str, + event: asyncio.Event, + ): + broker = self.get_broker(apply_types=True) + subscriber = broker.subscriber(queue) + + async with self.patch_broker(broker) as br: + await br.start() + + message = None + async def consume(): + nonlocal message + message = await subscriber.get_one(timeout=5) + + async def publish(): + await asyncio.sleep(0.5) + await br.publish("test_message", queue) + + await asyncio.wait( + ( + asyncio.create_task(consume()), + asyncio.create_task(publish()), + ), + timeout=10 + ) + + assert message is not None + assert await message.decode() == "test_message" + + @pytest.mark.asyncio + async def test_get_one_timeout( + self, + queue: str, + ): + broker = self.get_broker(apply_types=True) + subscriber = broker.subscriber(queue) + + async with self.patch_broker(broker) as br: + await br.start() + + message = object() + async def coro(): + nonlocal message + message = await subscriber.get_one(timeout=1) + + await asyncio.wait( + ( + asyncio.create_task(coro()), + ), + timeout=3 + ) + + assert message is None + @pytest.mark.redis @pytest.mark.asyncio @@ -311,6 +368,63 @@ async def handler(msg): assert [{1, "hi"}] == [set(r.result()) for r in result] + @pytest.mark.asyncio + async def test_get_one( + self, + queue: str, + event: asyncio.Event, + ): + broker = self.get_broker(apply_types=True) + subscriber = broker.subscriber(list=queue) + + async with self.patch_broker(broker) as br: + await br.start() + + message = None + async def consume(): + nonlocal message + message = await subscriber.get_one(timeout=5) + + async def publish(): + await asyncio.sleep(0.5) + await br.publish("test_message", list=queue) + + await asyncio.wait( + ( + asyncio.create_task(consume()), + asyncio.create_task(publish()), + ), + timeout=10 + ) + + assert message is not None + assert await message.decode() == "test_message" + + @pytest.mark.asyncio + async def test_get_one_timeout( + self, + queue: str, + ): + broker = self.get_broker(apply_types=True) + subscriber = broker.subscriber(list=queue) + + async with self.patch_broker(broker) as br: + await br.start() + + message = object() + async def coro(): + nonlocal message + message = await subscriber.get_one(timeout=1) + + await asyncio.wait( + ( + asyncio.create_task(coro()), + ), + timeout=3 + ) + + assert message is None + @pytest.mark.redis @pytest.mark.asyncio @@ -592,3 +706,59 @@ async def handler(msg: RedisMessage): m.mock.assert_called_once() assert event.is_set() + + @pytest.mark.asyncio + async def test_get_one( + self, + queue: str, + ): + broker = self.get_broker(apply_types=True) + subscriber = broker.subscriber(stream=queue) + + async with self.patch_broker(broker) as br: + await br.start() + + message = None + async def consume(): + nonlocal message + message = await subscriber.get_one(timeout=3) + + async def publish(): + await asyncio.sleep(0.5) + await br.publish("test_message", stream=queue) + + await asyncio.wait( + ( + asyncio.create_task(consume()), + asyncio.create_task(publish()), + ), + timeout=10 + ) + + assert message is not None + assert await message.decode() == "test_message" + + @pytest.mark.asyncio + async def test_get_one_timeout( + self, + queue: str, + ): + broker = self.get_broker(apply_types=True) + subscriber = broker.subscriber(stream=queue) + + async with self.patch_broker(broker) as br: + await br.start() + + message = object() + async def coro(): + nonlocal message + message = await subscriber.get_one(timeout=1) + + await asyncio.wait( + ( + asyncio.create_task(coro()), + ), + timeout=3 + ) + + assert message is None From cdc5b498d4b385ed1a5a7eb133b6d2e95b148660 Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Fri, 30 Aug 2024 13:02:22 +0300 Subject: [PATCH 36/62] Nats core and JS get_one tests --- faststream/nats/subscriber/usecase.py | 1 - tests/brokers/nats/test_consume.py | 113 ++++++++++++++++++++++++++ tests/brokers/redis/test_consume.py | 6 -- 3 files changed, 113 insertions(+), 7 deletions(-) diff --git a/faststream/nats/subscriber/usecase.py b/faststream/nats/subscriber/usecase.py index ee8d5d906d..15e4676ea6 100644 --- a/faststream/nats/subscriber/usecase.py +++ b/faststream/nats/subscriber/usecase.py @@ -633,7 +633,6 @@ async def get_one(self, *, timeout: float = 5) -> Optional[NatsMessage]: self.subscription = await self._connection.pull_subscribe( subject=self.clear_subject, config=self.config, - **self.extra_options, ) try: diff --git a/tests/brokers/nats/test_consume.py b/tests/brokers/nats/test_consume.py index c742fb9e48..2ea6c80226 100644 --- a/tests/brokers/nats/test_consume.py +++ b/tests/brokers/nats/test_consume.py @@ -375,3 +375,116 @@ async def handler(filename: str): assert event.is_set() mock.assert_called_once_with("hello") + + async def test_get_one( + self, + queue: str, + event: asyncio.Event, + ): + broker = self.get_broker(apply_types=True) + subscriber = broker.subscriber(queue) + + async with self.patch_broker(broker) as br: + await br.start() + + message = None + async def consume(): + nonlocal message + message = await subscriber.get_one(timeout=5) + + async def publish(): + await asyncio.sleep(0.5) + await br.publish("test_message", queue) + + await asyncio.wait( + ( + asyncio.create_task(consume()), + asyncio.create_task(publish()), + ), + timeout=10 + ) + + assert message is not None + assert await message.decode() == "test_message" + + async def test_get_one_timeout( + self, + queue: str, + ): + broker = self.get_broker(apply_types=True) + subscriber = broker.subscriber(queue) + + async with self.patch_broker(broker) as br: + await br.start() + + message = object() + async def coro(): + nonlocal message + message = await subscriber.get_one(timeout=1) + + await asyncio.wait( + ( + asyncio.create_task(coro()), + ), + timeout=3 + ) + + assert message is None + + async def test_get_one_js( + self, + queue: str, + event: asyncio.Event, + stream: JStream, + ): + broker = self.get_broker(apply_types=True) + subscriber = broker.subscriber(queue, stream=stream) + + async with self.patch_broker(broker) as br: + await br.start() + + message = None + async def consume(): + nonlocal message + message = await subscriber.get_one(timeout=5) + + async def publish(): + await asyncio.sleep(0.5) + await br.publish("test_message", queue, stream=stream.name) + + await asyncio.wait( + ( + asyncio.create_task(consume()), + asyncio.create_task(publish()), + ), + timeout=10 + ) + + assert message is not None + assert await message.decode() == "test_message" + + async def test_get_one_timeout_js( + self, + queue: str, + stream: JStream, + ): + broker = self.get_broker(apply_types=True) + subscriber = broker.subscriber(queue, stream=stream) + + async with self.patch_broker(broker) as br: + await br.start() + + message = object() + async def coro(): + nonlocal message + message = await subscriber.get_one(timeout=1) + + await asyncio.wait( + ( + asyncio.create_task(coro()), + ), + timeout=3 + ) + + assert message is None + diff --git a/tests/brokers/redis/test_consume.py b/tests/brokers/redis/test_consume.py index 2e1be6ff43..0a8a5d1012 100644 --- a/tests/brokers/redis/test_consume.py +++ b/tests/brokers/redis/test_consume.py @@ -92,7 +92,6 @@ async def handler(msg): mock.assert_called_once_with("hello") - @pytest.mark.asyncio async def test_get_one( self, queue: str, @@ -124,7 +123,6 @@ async def publish(): assert message is not None assert await message.decode() == "test_message" - @pytest.mark.asyncio async def test_get_one_timeout( self, queue: str, @@ -368,7 +366,6 @@ async def handler(msg): assert [{1, "hi"}] == [set(r.result()) for r in result] - @pytest.mark.asyncio async def test_get_one( self, queue: str, @@ -400,7 +397,6 @@ async def publish(): assert message is not None assert await message.decode() == "test_message" - @pytest.mark.asyncio async def test_get_one_timeout( self, queue: str, @@ -707,7 +703,6 @@ async def handler(msg: RedisMessage): assert event.is_set() - @pytest.mark.asyncio async def test_get_one( self, queue: str, @@ -738,7 +733,6 @@ async def publish(): assert message is not None assert await message.decode() == "test_message" - @pytest.mark.asyncio async def test_get_one_timeout( self, queue: str, From 7ae5e511924ee817154e3ad7786781da71d7389a Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Fri, 30 Aug 2024 18:57:11 +0300 Subject: [PATCH 37/62] Nats PoolSubscriber get_one tests --- tests/brokers/nats/test_consume.py | 65 ++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/tests/brokers/nats/test_consume.py b/tests/brokers/nats/test_consume.py index 2ea6c80226..525a30ef88 100644 --- a/tests/brokers/nats/test_consume.py +++ b/tests/brokers/nats/test_consume.py @@ -488,3 +488,68 @@ async def coro(): assert message is None + async def test_get_one_pool( + self, + queue: str, + event: asyncio.Event, + stream: JStream, + ): + broker = self.get_broker(apply_types=True) + subscriber = broker.subscriber( + queue, + stream=stream, + pull_sub=PullSub(1), + ) + + async with self.patch_broker(broker) as br: + await br.start() + + message = None + async def consume(): + nonlocal message + message = await subscriber.get_one(timeout=5) + + async def publish(): + await asyncio.sleep(0.5) + await br.publish("test_message", queue) + + await asyncio.wait( + ( + asyncio.create_task(consume()), + asyncio.create_task(publish()), + ), + timeout=10 + ) + + assert message is not None + assert await message.decode() == "test_message" + + async def test_get_one_pool_timeout( + self, + queue: str, + event: asyncio.Event, + stream: JStream, + ): + broker = self.get_broker(apply_types=True) + subscriber = broker.subscriber( + queue, + stream=stream, + pull_sub=PullSub(1), + ) + + async with self.patch_broker(broker) as br: + await br.start() + + message = object + async def consume(): + nonlocal message + message = await subscriber.get_one(timeout=1) + + await asyncio.wait( + ( + asyncio.create_task(consume()), + ), + timeout=3 + ) + + assert message is None From 837c3eff33ab2c97e0a7b879ba0140efd90190ee Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Sun, 1 Sep 2024 16:38:21 +0300 Subject: [PATCH 38/62] Nats batch pull get_one tests + fixes --- faststream/nats/subscriber/usecase.py | 16 +++--- tests/brokers/nats/test_consume.py | 74 +++++++++++++++++++++++++-- 2 files changed, 78 insertions(+), 12 deletions(-) diff --git a/faststream/nats/subscriber/usecase.py b/faststream/nats/subscriber/usecase.py index 15e4676ea6..c8984aecdb 100644 --- a/faststream/nats/subscriber/usecase.py +++ b/faststream/nats/subscriber/usecase.py @@ -1049,12 +1049,12 @@ async def get_one(self, *, timeout: float = 5) -> Optional[NatsMessage]: **self.extra_options, ) - raw_message ,= await self.subscription.fetch( - batch=1, - timeout=timeout, - ) - - if not raw_message: + try: + raw_messages = await self.subscription.fetch( + batch=1, + timeout=timeout, + ) + except TimeoutError: return None async with AsyncExitStack() as stack: @@ -1063,11 +1063,11 @@ async def get_one(self, *, timeout: float = 5) -> Optional[NatsMessage]: ) for m in self._broker_middlewares: - mid = m(raw_message) + mid = m(raw_messages) await stack.enter_async_context(mid) return_msg = partial(mid.consume_scope, return_msg) - parsed_msg = await self._parser(raw_message) + parsed_msg = await self._parser(raw_messages) parsed_msg._decoded_body = await self._decoder(parsed_msg) return await return_msg(parsed_msg) diff --git a/tests/brokers/nats/test_consume.py b/tests/brokers/nats/test_consume.py index 525a30ef88..5acf1aff78 100644 --- a/tests/brokers/nats/test_consume.py +++ b/tests/brokers/nats/test_consume.py @@ -477,7 +477,7 @@ async def test_get_one_timeout_js( message = object() async def coro(): nonlocal message - message = await subscriber.get_one(timeout=1) + message = await subscriber.get_one(timeout=0.5) await asyncio.wait( ( @@ -488,7 +488,7 @@ async def coro(): assert message is None - async def test_get_one_pool( + async def test_get_one_pull( self, queue: str, event: asyncio.Event, @@ -524,7 +524,7 @@ async def publish(): assert message is not None assert await message.decode() == "test_message" - async def test_get_one_pool_timeout( + async def test_get_one_pull_timeout( self, queue: str, event: asyncio.Event, @@ -543,7 +543,73 @@ async def test_get_one_pool_timeout( message = object async def consume(): nonlocal message - message = await subscriber.get_one(timeout=1) + message = await subscriber.get_one(timeout=0.5) + + await asyncio.wait( + ( + asyncio.create_task(consume()), + ), + timeout=3 + ) + + assert message is None + + async def test_get_one_batch( + self, + queue: str, + event: asyncio.Event, + stream: JStream, + ): + broker = self.get_broker(apply_types=True) + subscriber = broker.subscriber( + queue, + stream=stream, + pull_sub=PullSub(1, batch=True), + ) + + async with self.patch_broker(broker) as br: + await br.start() + + message = None + async def consume(): + nonlocal message + message = await subscriber.get_one(timeout=5) + + async def publish(): + await asyncio.sleep(0.5) + await br.publish("test_message", queue) + + await asyncio.wait( + ( + asyncio.create_task(consume()), + asyncio.create_task(publish()), + ), + timeout=10 + ) + + assert message is not None + assert await message.decode() == ["test_message"] + + async def test_get_one_batch_timeout( + self, + queue: str, + event: asyncio.Event, + stream: JStream, + ): + broker = self.get_broker(apply_types=True) + subscriber = broker.subscriber( + queue, + stream=stream, + pull_sub=PullSub(1, batch=True), + ) + + async with self.patch_broker(broker) as br: + await br.start() + + message = object + async def consume(): + nonlocal message + message = await subscriber.get_one(timeout=0.5) await asyncio.wait( ( From e678e696f874779435ac46a39ffa88534d6a9a24 Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Sun, 1 Sep 2024 19:00:31 +0300 Subject: [PATCH 39/62] Nats get_one with filter test --- tests/brokers/nats/test_consume.py | 35 ++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/brokers/nats/test_consume.py b/tests/brokers/nats/test_consume.py index 5acf1aff78..ae1b12728a 100644 --- a/tests/brokers/nats/test_consume.py +++ b/tests/brokers/nats/test_consume.py @@ -619,3 +619,38 @@ async def consume(): ) assert message is None + + async def test_get_one_with_filter( + self, + queue: str, + event: asyncio.Event, + stream: JStream, + ): + broker = self.get_broker(apply_types=True) + subscriber = broker.subscriber( + config=ConsumerConfig(filter_subjects=[f"{queue}.a"]), + stream=JStream(queue, subjects=[f"{queue}.*"]), + ) + + async with self.patch_broker(broker) as br: + await br.start() + + message = None + async def consume(): + nonlocal message + message = await subscriber.get_one(timeout=5) + + async def publish(): + await asyncio.sleep(0.5) + await br.publish("test_message", f"{queue}.a") + + await asyncio.wait( + ( + asyncio.create_task(consume()), + asyncio.create_task(publish()), + ), + timeout=10 + ) + + assert message is not None + assert await message.decode() == "test_message" From 302e5d31032be779844e7c699876310c8aeb7e5b Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Fri, 6 Sep 2024 22:18:35 +0300 Subject: [PATCH 40/62] Nats CoreSubscriber.get_one small refactoring --- faststream/nats/subscriber/usecase.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/faststream/nats/subscriber/usecase.py b/faststream/nats/subscriber/usecase.py index 816957cf05..7b44598aec 100644 --- a/faststream/nats/subscriber/usecase.py +++ b/faststream/nats/subscriber/usecase.py @@ -443,13 +443,9 @@ async def get_one( **self.extra_options, ) - raw_message = None - - with anyio.move_on_after(timeout): - async for raw_message in self.subscription.messages: - break - - if not raw_message: + try: + raw_message = await self.subscription.next_msg(timeout) + except TimeoutError: return None async with AsyncExitStack() as stack: From 8a6481886cf040d82d37133bcca69b377cfe81b8 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sun, 8 Sep 2024 15:44:34 +0300 Subject: [PATCH 41/62] refactor: polish get_one --- faststream/confluent/subscriber/usecase.py | 13 +- faststream/kafka/subscriber/usecase.py | 21 +- faststream/nats/subscriber/usecase.py | 202 +++++---------- faststream/rabbit/subscriber/usecase.py | 2 +- faststream/redis/subscriber/usecase.py | 275 +++++++-------------- tests/brokers/confluent/test_consume.py | 18 +- tests/brokers/kafka/test_consume.py | 11 +- tests/brokers/nats/test_consume.py | 54 ++-- tests/brokers/rabbit/test_consume.py | 17 +- tests/brokers/redis/test_consume.py | 37 +-- 10 files changed, 222 insertions(+), 428 deletions(-) diff --git a/faststream/confluent/subscriber/usecase.py b/faststream/confluent/subscriber/usecase.py index 24ab4165d5..52b4b2eddb 100644 --- a/faststream/confluent/subscriber/usecase.py +++ b/faststream/confluent/subscriber/usecase.py @@ -5,13 +5,14 @@ from typing import ( TYPE_CHECKING, Any, + Awaitable, Callable, Dict, Iterable, List, Optional, Sequence, - Tuple, Awaitable, + Tuple, ) import anyio @@ -21,7 +22,6 @@ from faststream.broker.publisher.fake import FakePublisher from faststream.broker.subscriber.usecase import SubscriberUsecase from faststream.broker.types import MsgType -from faststream.confluent.message import KafkaMessage from faststream.confluent.parser import AsyncConfluentParser from faststream.confluent.schemas import TopicPartition from faststream.utils.functions import return_input @@ -37,6 +37,7 @@ CustomCallable, ) from faststream.confluent.client import AsyncConfluentConsumer + from faststream.confluent.message import KafkaMessage from faststream.types import AnyDict, Decorator, LoggerProto @@ -172,7 +173,7 @@ async def close(self) -> None: self.task = None async def get_one(self, *, timeout: float = 5.0) -> "Optional[KafkaMessage]": - assert self.consumer, "You should start subscriber at first." + assert self.consumer, "You should start subscriber at first." # nosec B101 assert ( # nosec B101 not self.calls ), "You can't use `get_one` method if subscriber has registered handlers." @@ -183,16 +184,14 @@ async def get_one(self, *, timeout: float = 5.0) -> "Optional[KafkaMessage]": return None async with AsyncExitStack() as stack: - return_msg: Callable[[KafkaMessage], Awaitable[KafkaMessage]] = ( - return_input - ) + return_msg: Callable[[KafkaMessage], Awaitable[KafkaMessage]] = return_input for m in self._broker_middlewares: mid = m(raw_message) await stack.enter_async_context(mid) return_msg = partial(mid.consume_scope, return_msg) - parsed_msg = await self._parser(raw_message) + parsed_msg: KafkaMessage = await self._parser(raw_message) parsed_msg._decoded_body = await self._decoder(parsed_msg) return await return_msg(parsed_msg) diff --git a/faststream/kafka/subscriber/usecase.py b/faststream/kafka/subscriber/usecase.py index 569f535d4c..f79e74b65c 100644 --- a/faststream/kafka/subscriber/usecase.py +++ b/faststream/kafka/subscriber/usecase.py @@ -6,13 +6,14 @@ from typing import ( TYPE_CHECKING, Any, + Awaitable, Callable, Dict, Iterable, List, Optional, Sequence, - Tuple, Awaitable, + Tuple, ) import anyio @@ -182,30 +183,34 @@ async def close(self) -> None: self.task = None - async def get_one(self, *, timeout: float = 5.0,) -> "Optional[KafkaMessage]": + async def get_one( + self, + *, + timeout: float = 5.0, + ) -> "Optional[KafkaMessage]": assert self.consumer, "You should start subscriber at first." assert ( # nosec B101 not self.calls ), "You can't use `get_one` method if subscriber has registered handlers." - raw_messages = await self.consumer.getmany(timeout_ms=timeout * 1000, max_records=1) + raw_messages = await self.consumer.getmany( + timeout_ms=timeout * 1000, max_records=1 + ) if not raw_messages: return None - (raw_message,) ,= raw_messages.values() + ((raw_message,),) = raw_messages.values() async with AsyncExitStack() as stack: - return_msg: Callable[[KafkaMessage], Awaitable[KafkaMessage]] = ( - return_input - ) + return_msg: Callable[[KafkaMessage], Awaitable[KafkaMessage]] = return_input for m in self._broker_middlewares: mid = m(raw_message) await stack.enter_async_context(mid) return_msg = partial(mid.consume_scope, return_msg) - parsed_msg = await self._parser(raw_message) + parsed_msg: KafkaMessage = await self._parser(raw_message) parsed_msg._decoded_body = await self._decoder(parsed_msg) return await return_msg(parsed_msg) diff --git a/faststream/nats/subscriber/usecase.py b/faststream/nats/subscriber/usecase.py index 816957cf05..9520081f34 100644 --- a/faststream/nats/subscriber/usecase.py +++ b/faststream/nats/subscriber/usecase.py @@ -1,6 +1,6 @@ import asyncio from abc import abstractmethod -from contextlib import suppress, AsyncExitStack +from contextlib import AsyncExitStack, suppress from functools import partial from typing import ( TYPE_CHECKING, @@ -28,7 +28,7 @@ from faststream.broker.subscriber.usecase import SubscriberUsecase from faststream.broker.types import CustomCallable, MsgType from faststream.exceptions import NOT_CONNECTED_YET -from faststream.nats.message import NatsMessage, NatsKvMessage +from faststream.nats.message import NatsKvMessage, NatsMessage from faststream.nats.parser import ( BatchParser, JsParser, @@ -158,12 +158,10 @@ async def start(self) -> None: """Create NATS subscription and start consume tasks.""" assert self._connection, NOT_CONNECTED_YET # nosec B101 - if not self.calls: - return - await super().start() - await self._create_subscription(connection=self._connection) + if self.calls: + await self._create_subscription(connection=self._connection) async def close(self) -> None: """Clean up handler subscription, cancel consume task in graceful mode.""" @@ -446,16 +444,15 @@ async def get_one( raw_message = None with anyio.move_on_after(timeout): - async for raw_message in self.subscription.messages: + async for msg in self.subscription.messages: + raw_message = msg break if not raw_message: return None async with AsyncExitStack() as stack: - return_msg: Callable[[NatsMessage], Awaitable[NatsMessage]] = ( - return_input - ) + return_msg: Callable[[NatsMessage], Awaitable[NatsMessage]] = return_input for m in self._broker_middlewares: mid = m(raw_message) @@ -473,7 +470,6 @@ async def _create_subscription( # type: ignore[override] connection: "Client", ) -> None: """Create NATS subscription and start consume task.""" - if self.subscription: return @@ -484,8 +480,6 @@ async def _create_subscription( # type: ignore[override] **self.extra_options, ) - - def get_log_context( self, message: Annotated[ @@ -614,7 +608,6 @@ def get_log_context( ], ) -> Dict[str, str]: """Log context factory using in `self.consume` scope.""" - return self.build_log_context( message=message, subject=self._resolved_subject_string, @@ -622,44 +615,58 @@ def get_log_context( stream=self.stream.name, ) - -class PushStreamSubscription(_StreamSubscriber): - subscription: Optional["JetStreamContext.PushSubscription"] - - async def get_one(self, *, timeout: float = 5) -> Optional[NatsMessage]: + async def get_one( + self, + *, + timeout: float = 5, + ) -> Optional[NatsMessage]: assert self._connection, "Please, start() subscriber first" + assert ( # nosec B101 + not self.calls + ), "You can't use `get_one` method if subscriber has registered handlers." if not self.subscription: + extra_options = { + "pending_bytes_limit": self.extra_options["pending_bytes_limit"], + "pending_msgs_limit": self.extra_options["pending_msgs_limit"], + "durable": self.extra_options["durable"], + "stream": self.extra_options["stream"], + } + if inbox_prefix := self.extra_options.get("inbox_prefix"): + extra_options["inbox_prefix"] = inbox_prefix + self.subscription = await self._connection.pull_subscribe( subject=self.clear_subject, config=self.config, + **extra_options, ) try: - raw_message ,= await self.subscription.fetch( - batch=1, - timeout=timeout, - ) - except TimeoutError: - raw_message = None - - if not raw_message: + raw_message = ( + await self.subscription.fetch( + batch=1, + timeout=timeout, + ) + )[0] + except (TimeoutError, ConnectionClosedError): return None async with AsyncExitStack() as stack: - return_msg: Callable[[NatsMessage], Awaitable[NatsMessage]] = ( - return_input - ) + return_msg: Callable[[NatsMessage], Awaitable[NatsMessage]] = return_input for m in self._broker_middlewares: mid = m(raw_message) await stack.enter_async_context(mid) return_msg = partial(mid.consume_scope, return_msg) - parsed_msg = await self._parser(raw_message) + parsed_msg: NatsMessage = await self._parser(raw_message) parsed_msg._decoded_body = await self._decoder(parsed_msg) return await return_msg(parsed_msg) + +class PushStreamSubscription(_StreamSubscriber): + subscription: Optional["JetStreamContext.PushSubscription"] + @override async def _create_subscription( # type: ignore[override] self, @@ -723,41 +730,6 @@ def __init__( include_in_schema=include_in_schema, ) - async def get_one(self, *, timeout: float = 5) -> Optional[NatsMessage]: - assert self._connection, "Please, start() subscriber first" - - if not self.subscription: - self.subscription = await self._connection.pull_subscribe( - subject=self.clear_subject, - config=self.config, - **self.extra_options, - ) - - try: - raw_message ,= await self.subscription.fetch( - batch=1, - timeout=timeout, - ) - except TimeoutError: - raw_message = None - - if not raw_message: - return None - - async with AsyncExitStack() as stack: - return_msg: Callable[[NatsMessage], Awaitable[NatsMessage]] = ( - return_input - ) - - for m in self._broker_middlewares: - mid = m(raw_message) - await stack.enter_async_context(mid) - return_msg = partial(mid.consume_scope, return_msg) - - parsed_msg = await self._parser(raw_message) - parsed_msg._decoded_body = await self._decoder(parsed_msg) - return await return_msg(parsed_msg) - @override async def _create_subscription( # type: ignore[override] self, @@ -823,41 +795,6 @@ def __init__( include_in_schema=include_in_schema, ) - async def get_one(self, *, timeout: float = 5) -> Optional[NatsMessage]: - assert self._connection, "Please, start() subscriber first" - - if not self.subscription: - self.subscription = await self._connection.pull_subscribe( - subject=self.clear_subject, - config=self.config, - **self.extra_options, - ) - - try: - raw_message ,= await self.subscription.fetch( - batch=1, - timeout=timeout, - ) - except TimeoutError: - raw_message = None - - if not raw_message: - return None - - async with AsyncExitStack() as stack: - return_msg: Callable[[NatsMessage], Awaitable[NatsMessage]] = ( - return_input - ) - - for m in self._broker_middlewares: - mid = m(raw_message) - await stack.enter_async_context(mid) - return_msg = partial(mid.consume_scope, return_msg) - - parsed_msg = await self._parser(raw_message) - parsed_msg._decoded_body = await self._decoder(parsed_msg) - return await return_msg(parsed_msg) - @override async def _create_subscription( # type: ignore[override] self, @@ -938,38 +875,6 @@ def __init__( include_in_schema=include_in_schema, ) - async def get_one(self, *, timeout: float = 5) -> Optional[NatsMessage]: - assert self._connection, "Please, start() subscriber first" - - if not self.subscription: - self.subscription = await self._connection.pull_subscribe( - subject=self.clear_subject, - config=self.config, - **self.extra_options, - ) - - raw_message ,= await self.subscription.fetch( - batch=1, - timeout=timeout, - ) - - if not raw_message: - return None - - async with AsyncExitStack() as stack: - return_msg: Callable[[NatsMessage], Awaitable[NatsMessage]] = ( - return_input - ) - - for m in self._broker_middlewares: - mid = m(raw_message) - await stack.enter_async_context(mid) - return_msg = partial(mid.consume_scope, return_msg) - - parsed_msg = await self._parser(raw_message) - parsed_msg._decoded_body = await self._decoder(parsed_msg) - return await return_msg(parsed_msg) - @override async def _create_subscription( # type: ignore[override] self, @@ -1039,8 +944,15 @@ def __init__( include_in_schema=include_in_schema, ) - async def get_one(self, *, timeout: float = 5) -> Optional[NatsMessage]: + async def get_one( + self, + *, + timeout: float = 5, + ) -> Optional[NatsMessage]: assert self._connection, "Please, start() subscriber first" + assert ( # nosec B101 + not self.calls + ), "You can't use `get_one` method if subscriber has registered handlers." if not self.subscription: self.subscription = await self._connection.pull_subscribe( @@ -1058,9 +970,7 @@ async def get_one(self, *, timeout: float = 5) -> Optional[NatsMessage]: return None async with AsyncExitStack() as stack: - return_msg: Callable[[NatsMessage], Awaitable[NatsMessage]] = ( - return_input - ) + return_msg: Callable[[NatsMessage], Awaitable[NatsMessage]] = return_input for m in self._broker_middlewares: mid = m(raw_messages) @@ -1141,6 +1051,9 @@ def __init__( async def get_one(self, *, timeout: float = 5) -> Optional[NatsKvMessage]: assert self._connection, "Please, start() subscriber first" + assert ( # nosec B101 + not self.calls + ), "You can't use `get_one` method if subscriber has registered handlers." if not self.subscription: bucket = await self._connection.create_key_value( @@ -1162,16 +1075,14 @@ async def get_one(self, *, timeout: float = 5) -> Optional[NatsKvMessage]: raw_message = None sleep_interval = timeout / 10 with anyio.move_on_after(timeout): - while (raw_message := await self.subscription.obj.updates(timeout)) is None: + while (raw_message := await self.subscription.obj.updates(timeout)) is None: # noqa: ASYNC110 await anyio.sleep(sleep_interval) if not raw_message: return None async with AsyncExitStack() as stack: - return_msg: Callable[[NatsMessage], Awaitable[NatsMessage]] = ( - return_input - ) + return_msg: Callable[[NatsMessage], Awaitable[NatsMessage]] = return_input for m in self._broker_middlewares: mid = m(raw_message) @@ -1295,6 +1206,9 @@ def __init__( async def get_one(self, *, timeout: float = 5) -> Optional[NatsKvMessage]: assert self._connection, "Please, start() subscriber first" + assert ( # nosec B101 + not self.calls + ), "You can't use `get_one` method if subscriber has registered handlers." if not self.obj_watch_conn: self.bucket = await self._connection.create_object_store( @@ -1311,16 +1225,14 @@ async def get_one(self, *, timeout: float = 5) -> Optional[NatsKvMessage]: raw_message = None sleep_interval = timeout / 10 with anyio.move_on_after(timeout): - while (raw_message := await self.obj_watch_conn.updates(timeout)) is None: + while (raw_message := await self.obj_watch_conn.updates(timeout)) is None: # noqa: ASYNC110 await anyio.sleep(sleep_interval) if not raw_message: return None async with AsyncExitStack() as stack: - return_msg: Callable[[NatsMessage], Awaitable[NatsMessage]] = ( - return_input - ) + return_msg: Callable[[NatsMessage], Awaitable[NatsMessage]] = return_input for m in self._broker_middlewares: mid = m(raw_message) diff --git a/faststream/rabbit/subscriber/usecase.py b/faststream/rabbit/subscriber/usecase.py index 3703d53577..4ed2434bef 100644 --- a/faststream/rabbit/subscriber/usecase.py +++ b/faststream/rabbit/subscriber/usecase.py @@ -219,7 +219,7 @@ async def get_one( await stack.enter_async_context(mid) return_msg = partial(mid.consume_scope, return_msg) - parsed_msg = await self._parser(raw_message) + parsed_msg: RabbitMessage = await self._parser(raw_message) parsed_msg._decoded_body = await self._decoder(parsed_msg) return await return_msg(parsed_msg) diff --git a/faststream/redis/subscriber/usecase.py b/faststream/redis/subscriber/usecase.py index 071d533797..f9ac81d9e7 100644 --- a/faststream/redis/subscriber/usecase.py +++ b/faststream/redis/subscriber/usecase.py @@ -1,6 +1,6 @@ import asyncio from abc import abstractmethod -from contextlib import suppress, AsyncExitStack +from contextlib import AsyncExitStack, suppress from copy import deepcopy from functools import partial from typing import ( @@ -30,8 +30,10 @@ DefaultListMessage, DefaultStreamMessage, PubSubMessage, + RedisListMessage, + RedisMessage, + RedisStreamMessage, UnifyRedisDict, - RedisMessage ) from faststream.redis.parser import ( RedisBatchListParser, @@ -162,7 +164,9 @@ async def start( start_signal = anyio.Event() if self.calls: - self.task = asyncio.create_task(self._consume(*args, start_signal=start_signal)) + self.task = asyncio.create_task( + self._consume(*args, start_signal=start_signal) + ) with anyio.fail_after(3.0): await start_signal.wait() @@ -298,23 +302,21 @@ async def get_one( message: Optional[PubSubMessage] = None with anyio.move_on_after(timeout): - while (message := await self._get_message(self.subscription)) is None: + while (message := await self._get_message(self.subscription)) is None: # noqa: ASYNC110 await anyio.sleep(sleep_interval) if not message: return None async with AsyncExitStack() as stack: - return_msg: Callable[[RedisMessage], Awaitable[RedisMessage]] = ( - return_input - ) + return_msg: Callable[[RedisMessage], Awaitable[RedisMessage]] = return_input for m in self._broker_middlewares: mid = m(message) await stack.enter_async_context(mid) return_msg = partial(mid.consume_scope, return_msg) - parsed_msg = await self._parser(message) + parsed_msg: RedisMessage = await self._parser(message) parsed_msg._decoded_body = await self._decoder(parsed_msg) return await return_msg(parsed_msg) @@ -334,7 +336,6 @@ async def _get_message(self, psub: RPubSub) -> Optional[PubSubMessage]: return - async def _get_msgs(self, psub: RPubSub) -> None: msg = await self._get_message(psub) await self.consume(msg) # type: ignore[arg-type] @@ -411,6 +412,46 @@ async def start(self) -> None: await super().start(self._client) + async def get_one( + self, + *, + timeout: float = 5.0, + ) -> "Optional[RedisListMessage]": + assert self._client, "You should start subscriber at first." + assert ( # nosec B101 + not self.calls + ), "You can't use `get_one` method if subscriber has registered handlers." + + sleep_interval = timeout / 10 + raw_message = None + + with anyio.move_on_after(timeout): + while ( # noqa: ASYNC110 + raw_message := await self._client.lpop(name=self.list_sub.name) + ) is None: + await anyio.sleep(sleep_interval) + + if not raw_message: + return None + + async with AsyncExitStack() as stack: + return_msg: Callable[[RedisMessage], Awaitable[RedisMessage]] = return_input + + for m in self._broker_middlewares: + mid = m(raw_message) + await stack.enter_async_context(mid) + return_msg = partial(mid.consume_scope, return_msg) + + message = DefaultListMessage( + type="list", + data=raw_message, + channel=self.list_sub.name, + ) + + parsed_message: RedisListMessage = await self._parser(message) + parsed_message._decoded_body = await self._decoder(parsed_message) + return await return_msg(parsed_message) + def add_prefix(self, prefix: str) -> None: new_list = deepcopy(self.list_sub) new_list.name = "".join((prefix, new_list.name)) @@ -450,49 +491,6 @@ def __init__( include_in_schema=include_in_schema, ) - async def get_one( - self, - *, - timeout: float = 5.0, - ) -> "Optional[RedisMessage]": - assert self._client, "You should start subscriber at first." - assert ( # nosec B101 - not self.calls - ), "You can't use `get_one` method if subscriber has registered handlers." - - sleep_interval = timeout / 10 - raw_message = None - - with anyio.move_on_after(timeout): - while ( # noqa: ASYNC110 - raw_message := await self._client.lpop(name=self.list_sub.name) - ) is None: - await anyio.sleep(sleep_interval) - - if not raw_message: - return None - - async with AsyncExitStack() as stack: - return_msg: Callable[[RedisMessage], Awaitable[RedisMessage]] = ( - return_input - ) - - for m in self._broker_middlewares: - mid = m(raw_message) - await stack.enter_async_context(mid) - return_msg = partial(mid.consume_scope, return_msg) - - message = DefaultListMessage( - type="list", - data=raw_message, - channel=self.list_sub.name, - ) - - parsed_message = await self._parser(message) - parsed_message._decoded_body = await self._decoder(parsed_message) - - return await return_msg(parsed_message) - async def _get_msgs(self, client: "Redis[bytes]") -> None: raw_msg = await client.lpop(name=self.list_sub.name) @@ -542,49 +540,6 @@ def __init__( include_in_schema=include_in_schema, ) - async def get_one( - self, - *, - timeout: float = 5.0, - ) -> "Optional[RedisMessage]": - assert self._client, "You should start subscriber at first." - assert ( # nosec B101 - not self.calls - ), "You can't use `get_one` method if subscriber has registered handlers." - - sleep_interval = timeout / 10 - raw_message = None - - with anyio.move_on_after(timeout): - while ( # noqa: ASYNC110 - raw_message := await self._client.lpop(name=self.list_sub.name) - ) is None: - await anyio.sleep(sleep_interval) - - if not raw_message: - return None - - async with AsyncExitStack() as stack: - return_msg: Callable[[RedisMessage], Awaitable[RedisMessage]] = ( - return_input - ) - - for m in self._broker_middlewares: - mid = m(raw_message) - await stack.enter_async_context(mid) - return_msg = partial(mid.consume_scope, return_msg) - - message = DefaultListMessage( - type="list", - data=raw_message, - channel=self.list_sub.name, - ) - - parsed_message = await self._parser(message) - parsed_message._decoded_body = await self._decoder(parsed_message) - - return await return_msg(parsed_message) - async def _get_msgs(self, client: "Redis[bytes]") -> None: raw_msgs = await client.lpop( name=self.list_sub.name, @@ -752,6 +707,48 @@ def read( await super().start(read) + async def get_one( + self, + *, + timeout: float = 5.0, + ) -> "Optional[RedisStreamMessage]": + assert self._client, "You should start subscriber at first." # nosec B101 + assert ( # nosec B101 + not self.calls + ), "You can't use `get_one` method if subscriber has registered handlers." + + stream_message = await self._client.xread( + {self.stream_sub.name: self.last_id}, + block=timeout * 1000, + count=1, + ) + + if not stream_message: + return None + + ((stream_name, ((message_id, raw_message),)),) = stream_message + + self.last_id = message_id.decode() + + msg = DefaultStreamMessage( + type="stream", + channel=stream_name.decode(), + message_ids=[message_id], + data=raw_message, + ) + + async with AsyncExitStack() as stack: + return_msg: Callable[[RedisMessage], Awaitable[RedisMessage]] = return_input + + for m in self._broker_middlewares: + mid = m(msg) + await stack.enter_async_context(mid) + return_msg = partial(mid.consume_scope, return_msg) + + parsed_msg: RedisStreamMessage = await self._parser(msg) + parsed_msg._decoded_body = await self._decoder(parsed_msg) + return await return_msg(parsed_msg) + def add_prefix(self, prefix: str) -> None: new_stream = deepcopy(self.stream_sub) new_stream.name = "".join((prefix, new_stream.name)) @@ -791,50 +788,6 @@ def __init__( include_in_schema=include_in_schema, ) - async def get_one( - self, - *, - timeout: float = 5.0, - ) -> "Optional[RedisMessage]": - assert self._client, "You should start subscriber at first." # nosec B101 - assert ( # nosec B101 - not self.calls - ), "You can't use `get_one` method if subscriber has registered handlers." - - stream_message = await self._client.xread( - {self.stream_sub.name: self.last_id}, - block=timeout * 1000, - count=1, - ) - - if not stream_message: - return None - - (stream_name, ((message_id, raw_message),)) ,= stream_message - - self.last_id = message_id.decode() - - msg = DefaultStreamMessage( - type="stream", - channel=stream_name.decode(), - message_ids=[message_id], - data=raw_message, - ) - - async with AsyncExitStack() as stack: - return_msg: Callable[[RedisMessage], Awaitable[RedisMessage]] = ( - return_input - ) - - for m in self._broker_middlewares: - mid = m(msg) - await stack.enter_async_context(mid) - return_msg = partial(mid.consume_scope, return_msg) - - parsed_msg = await self._parser(msg) - parsed_msg._decoded_body = await self._decoder(parsed_msg) - return await return_msg(parsed_msg) - async def _get_msgs( self, read: Callable[ @@ -904,50 +857,6 @@ def __init__( include_in_schema=include_in_schema, ) - async def get_one( - self, - *, - timeout: float = 5.0, - ) -> "Optional[RedisMessage]": - assert self._client, "You should start subscriber at first." # nosec B101 - assert ( # nosec B101 - not self.calls - ), "You can't use `get_one` method if subscriber has registered handlers." - - stream_message = await self._client.xread( - {self.stream_sub.name: self.last_id}, - block=timeout * 1000, - count=1, - ) - - if not stream_message: - return None - - (stream_name, ((message_id, raw_message),)) ,= stream_message - - self.last_id = message_id.decode() - - msg = DefaultStreamMessage( - type="stream", - channel=stream_name.decode(), - message_ids=[message_id], - data=raw_message, - ) - - async with AsyncExitStack() as stack: - return_msg: Callable[[RedisMessage], Awaitable[RedisMessage]] = ( - return_input - ) - - for m in self._broker_middlewares: - mid = m(msg) - await stack.enter_async_context(mid) - return_msg = partial(mid.consume_scope, return_msg) - - parsed_msg = await self._parser(msg) - parsed_msg._decoded_body = await self._decoder(parsed_msg) - return await return_msg(parsed_msg) - async def _get_msgs( self, read: Callable[ diff --git a/tests/brokers/confluent/test_consume.py b/tests/brokers/confluent/test_consume.py index fa07ff323c..108450882f 100644 --- a/tests/brokers/confluent/test_consume.py +++ b/tests/brokers/confluent/test_consume.py @@ -1,5 +1,4 @@ import asyncio -import time from unittest.mock import patch import pytest @@ -331,18 +330,21 @@ async def test_get_one( event: asyncio.Event, ): broker = self.get_broker(apply_types=True) - subscriber = broker.subscriber(queue) + + args, kwargs = self.get_subscriber_params(queue) + + subscriber = broker.subscriber(*args, **kwargs) async with self.patch_broker(broker) as br: await br.start() message = None + async def consume(): nonlocal message message = await subscriber.get_one(timeout=5) async def publish(): - await asyncio.sleep(3) await br.publish("test_message", queue) await asyncio.wait( @@ -350,7 +352,7 @@ async def publish(): asyncio.create_task(consume()), asyncio.create_task(publish()), ), - timeout=10 + timeout=10, ) assert message is not None @@ -368,15 +370,11 @@ async def test_get_one_timeout( await br.start() message = object() + async def coro(): nonlocal message message = await subscriber.get_one(timeout=1) - await asyncio.wait( - ( - asyncio.create_task(coro()), - ), - timeout=3 - ) + await asyncio.wait((asyncio.create_task(coro()),), timeout=3) assert message is None diff --git a/tests/brokers/kafka/test_consume.py b/tests/brokers/kafka/test_consume.py index 12cd5101d1..57ff6c4e5d 100644 --- a/tests/brokers/kafka/test_consume.py +++ b/tests/brokers/kafka/test_consume.py @@ -325,6 +325,7 @@ async def test_get_one( await br.start() message = None + async def set_msg(): nonlocal message message = await subscriber.get_one() @@ -334,7 +335,7 @@ async def set_msg(): asyncio.create_task(br.publish("test_message", queue)), asyncio.create_task(set_msg()), ), - timeout=3 + timeout=3, ) assert message is not None @@ -352,15 +353,11 @@ async def test_get_one_timeout( await br.start() message = object() + async def coro(): nonlocal message message = await subscriber.get_one(timeout=1) - await asyncio.wait( - ( - asyncio.create_task(coro()), - ), - timeout=3 - ) + await asyncio.wait((asyncio.create_task(coro()),), timeout=3) assert message is None diff --git a/tests/brokers/nats/test_consume.py b/tests/brokers/nats/test_consume.py index bec1dc6f2b..be53e398de 100644 --- a/tests/brokers/nats/test_consume.py +++ b/tests/brokers/nats/test_consume.py @@ -414,12 +414,12 @@ async def test_get_one( await br.start() message = None + async def consume(): nonlocal message message = await subscriber.get_one(timeout=5) async def publish(): - await asyncio.sleep(0.5) await br.publish("test_message", queue) await asyncio.wait( @@ -427,7 +427,7 @@ async def publish(): asyncio.create_task(consume()), asyncio.create_task(publish()), ), - timeout=10 + timeout=10, ) assert message is not None @@ -444,16 +444,12 @@ async def test_get_one_timeout( await br.start() message = object() + async def coro(): nonlocal message message = await subscriber.get_one(timeout=1) - await asyncio.wait( - ( - asyncio.create_task(coro()), - ), - timeout=3 - ) + await asyncio.wait((asyncio.create_task(coro()),), timeout=3) assert message is None @@ -470,12 +466,12 @@ async def test_get_one_js( await br.start() message = None + async def consume(): nonlocal message message = await subscriber.get_one(timeout=5) async def publish(): - await asyncio.sleep(0.5) await br.publish("test_message", queue, stream=stream.name) await asyncio.wait( @@ -483,7 +479,7 @@ async def publish(): asyncio.create_task(consume()), asyncio.create_task(publish()), ), - timeout=10 + timeout=10, ) assert message is not None @@ -501,16 +497,12 @@ async def test_get_one_timeout_js( await br.start() message = object() + async def coro(): nonlocal message message = await subscriber.get_one(timeout=0.5) - await asyncio.wait( - ( - asyncio.create_task(coro()), - ), - timeout=3 - ) + await asyncio.wait((asyncio.create_task(coro()),), timeout=3) assert message is None @@ -531,12 +523,12 @@ async def test_get_one_pull( await br.start() message = None + async def consume(): nonlocal message message = await subscriber.get_one(timeout=5) async def publish(): - await asyncio.sleep(0.5) await br.publish("test_message", queue) await asyncio.wait( @@ -544,7 +536,7 @@ async def publish(): asyncio.create_task(consume()), asyncio.create_task(publish()), ), - timeout=10 + timeout=10, ) assert message is not None @@ -567,16 +559,12 @@ async def test_get_one_pull_timeout( await br.start() message = object + async def consume(): nonlocal message message = await subscriber.get_one(timeout=0.5) - await asyncio.wait( - ( - asyncio.create_task(consume()), - ), - timeout=3 - ) + await asyncio.wait((asyncio.create_task(consume()),), timeout=3) assert message is None @@ -597,12 +585,12 @@ async def test_get_one_batch( await br.start() message = None + async def consume(): nonlocal message message = await subscriber.get_one(timeout=5) async def publish(): - await asyncio.sleep(0.5) await br.publish("test_message", queue) await asyncio.wait( @@ -610,7 +598,7 @@ async def publish(): asyncio.create_task(consume()), asyncio.create_task(publish()), ), - timeout=10 + timeout=10, ) assert message is not None @@ -633,16 +621,12 @@ async def test_get_one_batch_timeout( await br.start() message = object + async def consume(): nonlocal message message = await subscriber.get_one(timeout=0.5) - await asyncio.wait( - ( - asyncio.create_task(consume()), - ), - timeout=3 - ) + await asyncio.wait((asyncio.create_task(consume()),), timeout=3) assert message is None @@ -662,20 +646,20 @@ async def test_get_one_with_filter( await br.start() message = None + async def consume(): nonlocal message message = await subscriber.get_one(timeout=5) async def publish(): - await asyncio.sleep(0.5) await br.publish("test_message", f"{queue}.a") await asyncio.wait( ( - asyncio.create_task(consume()), asyncio.create_task(publish()), + asyncio.create_task(consume()), ), - timeout=10 + timeout=10, ) assert message is not None diff --git a/tests/brokers/rabbit/test_consume.py b/tests/brokers/rabbit/test_consume.py index 2c3cf23987..daa56fc44d 100644 --- a/tests/brokers/rabbit/test_consume.py +++ b/tests/brokers/rabbit/test_consume.py @@ -397,16 +397,21 @@ async def test_get_one( await br.start() message = None + async def set_msg(): nonlocal message message = await subscriber.get_one() await asyncio.wait( ( - asyncio.create_task(br.publish(message="test_message", queue=queue, exchange=exchange)), + asyncio.create_task( + br.publish( + message="test_message", queue=queue, exchange=exchange + ) + ), asyncio.create_task(set_msg()), ), - timeout=3 + timeout=3, ) assert message is not None @@ -425,15 +430,11 @@ async def test_get_one_timeout( await br.start() message = object() + async def coro(): nonlocal message message = await subscriber.get_one(timeout=1) - await asyncio.wait( - ( - asyncio.create_task(coro()), - ), - timeout=3 - ) + await asyncio.wait((asyncio.create_task(coro()),), timeout=3) assert message is None diff --git a/tests/brokers/redis/test_consume.py b/tests/brokers/redis/test_consume.py index 0a8a5d1012..881f58a1f4 100644 --- a/tests/brokers/redis/test_consume.py +++ b/tests/brokers/redis/test_consume.py @@ -104,12 +104,12 @@ async def test_get_one( await br.start() message = None + async def consume(): nonlocal message message = await subscriber.get_one(timeout=5) async def publish(): - await asyncio.sleep(0.5) await br.publish("test_message", queue) await asyncio.wait( @@ -117,7 +117,7 @@ async def publish(): asyncio.create_task(consume()), asyncio.create_task(publish()), ), - timeout=10 + timeout=10, ) assert message is not None @@ -134,16 +134,12 @@ async def test_get_one_timeout( await br.start() message = object() + async def coro(): nonlocal message message = await subscriber.get_one(timeout=1) - await asyncio.wait( - ( - asyncio.create_task(coro()), - ), - timeout=3 - ) + await asyncio.wait((asyncio.create_task(coro()),), timeout=3) assert message is None @@ -378,12 +374,12 @@ async def test_get_one( await br.start() message = None + async def consume(): nonlocal message message = await subscriber.get_one(timeout=5) async def publish(): - await asyncio.sleep(0.5) await br.publish("test_message", list=queue) await asyncio.wait( @@ -391,7 +387,7 @@ async def publish(): asyncio.create_task(consume()), asyncio.create_task(publish()), ), - timeout=10 + timeout=10, ) assert message is not None @@ -408,16 +404,12 @@ async def test_get_one_timeout( await br.start() message = object() + async def coro(): nonlocal message message = await subscriber.get_one(timeout=1) - await asyncio.wait( - ( - asyncio.create_task(coro()), - ), - timeout=3 - ) + await asyncio.wait((asyncio.create_task(coro()),), timeout=3) assert message is None @@ -714,12 +706,13 @@ async def test_get_one( await br.start() message = None + async def consume(): nonlocal message message = await subscriber.get_one(timeout=3) async def publish(): - await asyncio.sleep(0.5) + await asyncio.sleep(0.1) await br.publish("test_message", stream=queue) await asyncio.wait( @@ -727,7 +720,7 @@ async def publish(): asyncio.create_task(consume()), asyncio.create_task(publish()), ), - timeout=10 + timeout=10, ) assert message is not None @@ -744,15 +737,11 @@ async def test_get_one_timeout( await br.start() message = object() + async def coro(): nonlocal message message = await subscriber.get_one(timeout=1) - await asyncio.wait( - ( - asyncio.create_task(coro()), - ), - timeout=3 - ) + await asyncio.wait((asyncio.create_task(coro()),), timeout=3) assert message is None From d52e1fbca8e2c8899bc5439d40d239d61ccf343f Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sun, 8 Sep 2024 16:02:18 +0300 Subject: [PATCH 42/62] lint: fix redis mypy --- faststream/redis/subscriber/usecase.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/faststream/redis/subscriber/usecase.py b/faststream/redis/subscriber/usecase.py index f9ac81d9e7..2c07b7715e 100644 --- a/faststream/redis/subscriber/usecase.py +++ b/faststream/redis/subscriber/usecase.py @@ -287,7 +287,8 @@ async def close(self) -> None: await super().close() - async def get_one( + @override + async def get_one( # type: ignore[override] self, *, timeout: float = 5.0, @@ -312,7 +313,7 @@ async def get_one( return_msg: Callable[[RedisMessage], Awaitable[RedisMessage]] = return_input for m in self._broker_middlewares: - mid = m(message) + mid = m(message) # type: ignore[arg-type] await stack.enter_async_context(mid) return_msg = partial(mid.consume_scope, return_msg) @@ -320,6 +321,8 @@ async def get_one( parsed_msg._decoded_body = await self._decoder(parsed_msg) return await return_msg(parsed_msg) + raise AssertionError("unreachable") + async def _get_message(self, psub: RPubSub) -> Optional[PubSubMessage]: raw_msg = await psub.get_message( ignore_subscribe_messages=True, @@ -334,7 +337,7 @@ async def _get_message(self, psub: RPubSub) -> Optional[PubSubMessage]: pattern=raw_msg["pattern"], ) - return + return None async def _get_msgs(self, psub: RPubSub) -> None: msg = await self._get_message(psub) @@ -412,7 +415,8 @@ async def start(self) -> None: await super().start(self._client) - async def get_one( + @override + async def get_one( # type: ignore[override] self, *, timeout: float = 5.0, @@ -435,7 +439,7 @@ async def get_one( return None async with AsyncExitStack() as stack: - return_msg: Callable[[RedisMessage], Awaitable[RedisMessage]] = return_input + return_msg: Callable[[RedisListMessage], Awaitable[RedisListMessage]] = return_input for m in self._broker_middlewares: mid = m(raw_message) @@ -452,6 +456,8 @@ async def get_one( parsed_message._decoded_body = await self._decoder(parsed_message) return await return_msg(parsed_message) + raise AssertionError("unreachable") + def add_prefix(self, prefix: str) -> None: new_list = deepcopy(self.list_sub) new_list.name = "".join((prefix, new_list.name)) @@ -707,7 +713,8 @@ def read( await super().start(read) - async def get_one( + @override + async def get_one( # type: ignore[override] self, *, timeout: float = 5.0, @@ -738,10 +745,10 @@ async def get_one( ) async with AsyncExitStack() as stack: - return_msg: Callable[[RedisMessage], Awaitable[RedisMessage]] = return_input + return_msg: Callable[[RedisStreamMessage], Awaitable[RedisStreamMessage]] = return_input for m in self._broker_middlewares: - mid = m(msg) + mid = m(msg) # type: ignore[arg-type] await stack.enter_async_context(mid) return_msg = partial(mid.consume_scope, return_msg) @@ -749,6 +756,8 @@ async def get_one( parsed_msg._decoded_body = await self._decoder(parsed_msg) return await return_msg(parsed_msg) + raise AssertionError("unreachable") + def add_prefix(self, prefix: str) -> None: new_stream = deepcopy(self.stream_sub) new_stream.name = "".join((prefix, new_stream.name)) From fbafc62c616d822527a92b7bd1ea7550aa6f844b Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sun, 8 Sep 2024 16:03:26 +0300 Subject: [PATCH 43/62] lint: fix rabbit mypy --- faststream/rabbit/subscriber/usecase.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/faststream/rabbit/subscriber/usecase.py b/faststream/rabbit/subscriber/usecase.py index 4ed2434bef..6bd7d4d1e4 100644 --- a/faststream/rabbit/subscriber/usecase.py +++ b/faststream/rabbit/subscriber/usecase.py @@ -183,6 +183,7 @@ async def close(self) -> None: self._queue_obj = None + @override async def get_one( self, *, @@ -223,6 +224,8 @@ async def get_one( parsed_msg._decoded_body = await self._decoder(parsed_msg) return await return_msg(parsed_msg) + raise AssertionError("unreachable") + def _make_response_publisher( self, message: "StreamMessage[Any]", From 42f898bd92209010eb78bb6cb4dd95c9dc245fbe Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sun, 8 Sep 2024 16:23:41 +0300 Subject: [PATCH 44/62] lint: fix kafka mypy --- faststream/kafka/subscriber/usecase.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/faststream/kafka/subscriber/usecase.py b/faststream/kafka/subscriber/usecase.py index f79e74b65c..6c93a94888 100644 --- a/faststream/kafka/subscriber/usecase.py +++ b/faststream/kafka/subscriber/usecase.py @@ -183,7 +183,8 @@ async def close(self) -> None: self.task = None - async def get_one( + @override + async def get_one( # type: ignore[override] self, *, timeout: float = 5.0, @@ -214,6 +215,8 @@ async def get_one( parsed_msg._decoded_body = await self._decoder(parsed_msg) return await return_msg(parsed_msg) + raise AssertionError("unreachable") + def _make_response_publisher( self, message: "StreamMessage[Any]", From 2715e7940121d1205ea80ac52da2a93725588b59 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sun, 8 Sep 2024 17:27:57 +0300 Subject: [PATCH 45/62] lint: fix confluent mypy --- faststream/confluent/subscriber/usecase.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/faststream/confluent/subscriber/usecase.py b/faststream/confluent/subscriber/usecase.py index 52b4b2eddb..91b7bc6e54 100644 --- a/faststream/confluent/subscriber/usecase.py +++ b/faststream/confluent/subscriber/usecase.py @@ -172,7 +172,12 @@ async def close(self) -> None: self.task = None - async def get_one(self, *, timeout: float = 5.0) -> "Optional[KafkaMessage]": + @override + async def get_one( + self, + *, + timeout: float = 5.0, + ) -> "Optional[StreamMessage[MsgType]]": assert self.consumer, "You should start subscriber at first." # nosec B101 assert ( # nosec B101 not self.calls @@ -184,17 +189,19 @@ async def get_one(self, *, timeout: float = 5.0) -> "Optional[KafkaMessage]": return None async with AsyncExitStack() as stack: - return_msg: Callable[[KafkaMessage], Awaitable[KafkaMessage]] = return_input + return_msg: Callable[[StreamMessage[MsgType]], Awaitable[StreamMessage[MsgType]]] = return_input for m in self._broker_middlewares: - mid = m(raw_message) + mid = m(raw_message) # type: ignore[arg-type] await stack.enter_async_context(mid) return_msg = partial(mid.consume_scope, return_msg) - parsed_msg: KafkaMessage = await self._parser(raw_message) + parsed_msg: StreamMessage[MsgType] = await self._parser(raw_message) parsed_msg._decoded_body = await self._decoder(parsed_msg) return await return_msg(parsed_msg) + raise AssertionError("unreachable") + def _make_response_publisher( self, message: "StreamMessage[Any]", From c80528057c5bbf6332633c65d228946fd0654036 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sun, 8 Sep 2024 17:28:52 +0300 Subject: [PATCH 46/62] lint: fix kafka mypy --- faststream/kafka/subscriber/usecase.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/faststream/kafka/subscriber/usecase.py b/faststream/kafka/subscriber/usecase.py index 6c93a94888..16fe929c40 100644 --- a/faststream/kafka/subscriber/usecase.py +++ b/faststream/kafka/subscriber/usecase.py @@ -184,11 +184,11 @@ async def close(self) -> None: self.task = None @override - async def get_one( # type: ignore[override] + async def get_one( self, *, timeout: float = 5.0, - ) -> "Optional[KafkaMessage]": + ) -> "Optional[StreamMessage[MsgType]]": assert self.consumer, "You should start subscriber at first." assert ( # nosec B101 not self.calls @@ -204,14 +204,16 @@ async def get_one( # type: ignore[override] ((raw_message,),) = raw_messages.values() async with AsyncExitStack() as stack: - return_msg: Callable[[KafkaMessage], Awaitable[KafkaMessage]] = return_input + return_msg: Callable[ + [StreamMessage[MsgType]], Awaitable[StreamMessage[MsgType]] + ] = return_input for m in self._broker_middlewares: mid = m(raw_message) await stack.enter_async_context(mid) return_msg = partial(mid.consume_scope, return_msg) - parsed_msg: KafkaMessage = await self._parser(raw_message) + parsed_msg: StreamMessage[MsgType] = await self._parser(raw_message) parsed_msg._decoded_body = await self._decoder(parsed_msg) return await return_msg(parsed_msg) From 35b71956626b80bd826283ae0402fd8c47c91c99 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sun, 8 Sep 2024 17:51:46 +0300 Subject: [PATCH 47/62] lint: fix nats mypy --- faststream/confluent/subscriber/usecase.py | 5 +- faststream/nats/subscriber/usecase.py | 194 +++++++++++++++------ faststream/redis/subscriber/usecase.py | 8 +- 3 files changed, 145 insertions(+), 62 deletions(-) diff --git a/faststream/confluent/subscriber/usecase.py b/faststream/confluent/subscriber/usecase.py index 91b7bc6e54..30c24f70d4 100644 --- a/faststream/confluent/subscriber/usecase.py +++ b/faststream/confluent/subscriber/usecase.py @@ -37,7 +37,6 @@ CustomCallable, ) from faststream.confluent.client import AsyncConfluentConsumer - from faststream.confluent.message import KafkaMessage from faststream.types import AnyDict, Decorator, LoggerProto @@ -189,7 +188,9 @@ async def get_one( return None async with AsyncExitStack() as stack: - return_msg: Callable[[StreamMessage[MsgType]], Awaitable[StreamMessage[MsgType]]] = return_input + return_msg: Callable[ + [StreamMessage[MsgType]], Awaitable[StreamMessage[MsgType]] + ] = return_input for m in self._broker_middlewares: mid = m(raw_message) # type: ignore[arg-type] diff --git a/faststream/nats/subscriber/usecase.py b/faststream/nats/subscriber/usecase.py index c2f03893f6..2701022590 100644 --- a/faststream/nats/subscriber/usecase.py +++ b/faststream/nats/subscriber/usecase.py @@ -28,7 +28,6 @@ from faststream.broker.subscriber.usecase import SubscriberUsecase from faststream.broker.types import CustomCallable, MsgType from faststream.exceptions import NOT_CONNECTED_YET -from faststream.nats.message import NatsKvMessage, NatsMessage from faststream.nats.parser import ( BatchParser, JsParser, @@ -60,6 +59,7 @@ BrokerMiddleware, ) from faststream.nats.helpers import KVBucketDeclarer, OSBucketDeclarer + from faststream.nats.message import NatsKvMessage, NatsMessage, NatsObjMessage from faststream.nats.schemas import JStream, KvWatch, ObjWatch, PullSub from faststream.types import Decorator @@ -69,7 +69,6 @@ class LogicSubscriber(SubscriberUsecase[MsgType]): subscription: Optional[Unsubscriptable] producer: Optional["ProducerProto"] - _connection: Union["Client", "JetStreamContext", None] def __init__( self, @@ -110,44 +109,10 @@ def __init__( include_in_schema=include_in_schema, ) - self._connection = None + self._connection: Any = None self.subscription = None self.producer = None - @override - def setup( # type: ignore[override] - self, - *, - connection: Union["Client", "JetStreamContext"], - # basic args - logger: Optional["LoggerProto"], - producer: Optional["ProducerProto"], - graceful_timeout: Optional[float], - extra_context: "AnyDict", - # broker options - broker_parser: Optional["CustomCallable"], - broker_decoder: Optional["CustomCallable"], - # dependant args - apply_types: bool, - is_validate: bool, - _get_dependant: Optional[Callable[..., Any]], - _call_decorators: Iterable["Decorator"], - ) -> None: - self._connection = connection - - super().setup( - logger=logger, - producer=producer, - graceful_timeout=graceful_timeout, - extra_context=extra_context, - broker_parser=broker_parser, - broker_decoder=broker_decoder, - apply_types=apply_types, - is_validate=is_validate, - _get_dependant=_get_dependant, - _call_decorators=_call_decorators, - ) - @property def clear_subject(self) -> str: """Compile `test.{name}` to `test.*` subject.""" @@ -381,6 +346,7 @@ async def _put_msg(self, msg: "Msg") -> None: class CoreSubscriber(_DefaultSubscriber["Msg"]): subscription: Optional["Subscription"] + _connection: Optional["Client"] def __init__( self, @@ -424,6 +390,41 @@ def __init__( include_in_schema=include_in_schema, ) + @override + def setup( # type: ignore[override] + self, + *, + connection: "Client", + # basic args + logger: Optional["LoggerProto"], + producer: Optional["ProducerProto"], + graceful_timeout: Optional[float], + extra_context: "AnyDict", + # broker options + broker_parser: Optional["CustomCallable"], + broker_decoder: Optional["CustomCallable"], + # dependant args + apply_types: bool, + is_validate: bool, + _get_dependant: Optional[Callable[..., Any]], + _call_decorators: Iterable["Decorator"], + ) -> None: + self._connection = connection + + super().setup( + logger=logger, + producer=producer, + graceful_timeout=graceful_timeout, + extra_context=extra_context, + broker_parser=broker_parser, + broker_decoder=broker_decoder, + apply_types=apply_types, + is_validate=is_validate, + _get_dependant=_get_dependant, + _call_decorators=_call_decorators, + ) + + @override async def get_one( self, *, @@ -458,6 +459,8 @@ async def get_one( parsed_msg._decoded_body = await self._decoder(parsed_msg) return await return_msg(parsed_msg) + raise AssertionError("unreachable") + @override async def _create_subscription( # type: ignore[override] self, @@ -551,6 +554,9 @@ async def _create_subscription( # type: ignore[override] class _StreamSubscriber(_DefaultSubscriber["Msg"]): + _connection: Optional["JetStreamContext"] + __fetch_sub: Optional["JetStreamContext.PullSubscription"] + def __init__( self, *, @@ -576,6 +582,8 @@ def __init__( self.queue = queue self.stream = stream + self.__fetch_sub = None + super().__init__( subject=subject, config=config, @@ -595,6 +603,40 @@ def __init__( include_in_schema=include_in_schema, ) + @override + def setup( # type: ignore[override] + self, + *, + connection: "JetStreamContext", + # basic args + logger: Optional["LoggerProto"], + producer: Optional["ProducerProto"], + graceful_timeout: Optional[float], + extra_context: "AnyDict", + # broker options + broker_parser: Optional["CustomCallable"], + broker_decoder: Optional["CustomCallable"], + # dependant args + apply_types: bool, + is_validate: bool, + _get_dependant: Optional[Callable[..., Any]], + _call_decorators: Iterable["Decorator"], + ) -> None: + self._connection = connection + + super().setup( + logger=logger, + producer=producer, + graceful_timeout=graceful_timeout, + extra_context=extra_context, + broker_parser=broker_parser, + broker_decoder=broker_decoder, + apply_types=apply_types, + is_validate=is_validate, + _get_dependant=_get_dependant, + _call_decorators=_call_decorators, + ) + def get_log_context( self, message: Annotated[ @@ -610,17 +652,18 @@ def get_log_context( stream=self.stream.name, ) + @override async def get_one( self, *, timeout: float = 5, - ) -> Optional[NatsMessage]: + ) -> Optional["NatsMessage"]: assert self._connection, "Please, start() subscriber first" assert ( # nosec B101 not self.calls ), "You can't use `get_one` method if subscriber has registered handlers." - if not self.subscription: + if not self.__fetch_sub: extra_options = { "pending_bytes_limit": self.extra_options["pending_bytes_limit"], "pending_msgs_limit": self.extra_options["pending_msgs_limit"], @@ -630,7 +673,7 @@ async def get_one( if inbox_prefix := self.extra_options.get("inbox_prefix"): extra_options["inbox_prefix"] = inbox_prefix - self.subscription = await self._connection.pull_subscribe( + self.__fetch_sub = await self._connection.pull_subscribe( subject=self.clear_subject, config=self.config, **extra_options, @@ -638,7 +681,7 @@ async def get_one( try: raw_message = ( - await self.subscription.fetch( + await self.__fetch_sub.fetch( batch=1, timeout=timeout, ) @@ -658,6 +701,8 @@ async def get_one( parsed_msg._decoded_body = await self._decoder(parsed_msg) return await return_msg(parsed_msg) + raise AssertionError("unreachable") + class PushStreamSubscription(_StreamSubscriber): subscription: Optional["JetStreamContext.PushSubscription"] @@ -894,6 +939,7 @@ class BatchPullStreamSubscriber(_TasksMixin, _DefaultSubscriber[List["Msg"]]): """Batch-message consumer class.""" subscription: Optional["JetStreamContext.PullSubscription"] + __fetch_sub: Optional["JetStreamContext.PullSubscription"] def __init__( self, @@ -920,6 +966,8 @@ def __init__( self.stream = stream self.pull_sub = pull_sub + self.__fetch_sub = None + super().__init__( subject=subject, config=config, @@ -939,25 +987,28 @@ def __init__( include_in_schema=include_in_schema, ) + @override async def get_one( self, *, timeout: float = 5, - ) -> Optional[NatsMessage]: + ) -> Optional["NatsMessage"]: assert self._connection, "Please, start() subscriber first" assert ( # nosec B101 not self.calls ), "You can't use `get_one` method if subscriber has registered handlers." - if not self.subscription: - self.subscription = await self._connection.pull_subscribe( + if not self.__fetch_sub: + fetch_sub = self.__fetch_sub = await self._connection.pull_subscribe( subject=self.clear_subject, config=self.config, **self.extra_options, ) + else: + fetch_sub = self.__fetch_sub try: - raw_messages = await self.subscription.fetch( + raw_messages = await fetch_sub.fetch( batch=1, timeout=timeout, ) @@ -976,6 +1027,8 @@ async def get_one( parsed_msg._decoded_body = await self._decoder(parsed_msg) return await return_msg(parsed_msg) + raise AssertionError("unreachable") + @override async def _create_subscription( # type: ignore[override] self, @@ -1044,7 +1097,12 @@ def __init__( include_in_schema=include_in_schema, ) - async def get_one(self, *, timeout: float = 5) -> Optional[NatsKvMessage]: + @override + async def get_one( + self, + *, + timeout: float = 5, + ) -> Optional["NatsKvMessage"]: assert self._connection, "Please, start() subscriber first" assert ( # nosec B101 not self.calls @@ -1063,31 +1121,36 @@ async def get_one(self, *, timeout: float = 5) -> Optional[NatsKvMessage]: include_history=self.kv_watch.include_history, ignore_deletes=self.kv_watch.ignore_deletes, meta_only=self.kv_watch.meta_only, - # inactive_threshold=self.kv_watch.inactive_threshold ) ) raw_message = None sleep_interval = timeout / 10 with anyio.move_on_after(timeout): - while (raw_message := await self.subscription.obj.updates(timeout)) is None: # noqa: ASYNC110 + while ( # noqa: ASYNC110 + raw_message := await self.subscription.obj.updates(timeout) # type: ignore[no-untyped-call] + ) is None: await anyio.sleep(sleep_interval) if not raw_message: return None async with AsyncExitStack() as stack: - return_msg: Callable[[NatsMessage], Awaitable[NatsMessage]] = return_input + return_msg: Callable[[NatsKvMessage], Awaitable[NatsKvMessage]] = ( + return_input + ) for m in self._broker_middlewares: mid = m(raw_message) await stack.enter_async_context(mid) return_msg = partial(mid.consume_scope, return_msg) - parsed_msg = await self._parser(raw_message) + parsed_msg: NatsKvMessage = await self._parser(raw_message) parsed_msg._decoded_body = await self._decoder(parsed_msg) return await return_msg(parsed_msg) + raise AssertionError("unreachable") + @override async def _create_subscription( # type: ignore[override] self, @@ -1109,7 +1172,6 @@ async def _create_subscription( # type: ignore[override] include_history=self.kv_watch.include_history, ignore_deletes=self.kv_watch.ignore_deletes, meta_only=self.kv_watch.meta_only, - # inactive_threshold=self.kv_watch.inactive_threshold ) ) @@ -1163,6 +1225,7 @@ def get_log_context( class ObjStoreWatchSubscriber(_TasksMixin, LogicSubscriber[ObjectInfo]): subscription: Optional["UnsubscribeAdapter[ObjectStore.ObjectWatcher]"] + __fetch_sub: Optional["ObjectStore.ObjectWatcher"] def __init__( self, @@ -1182,6 +1245,8 @@ def __init__( self.obj_watch = obj_watch self.obj_watch_conn = None + self.__fetch_sub = None + super().__init__( subject=subject, config=config, @@ -1199,45 +1264,58 @@ def __init__( include_in_schema=include_in_schema, ) - async def get_one(self, *, timeout: float = 5) -> Optional[NatsKvMessage]: + @override + async def get_one( + self, + *, + timeout: float = 5, + ) -> Optional["NatsObjMessage"]: assert self._connection, "Please, start() subscriber first" assert ( # nosec B101 not self.calls ), "You can't use `get_one` method if subscriber has registered handlers." - if not self.obj_watch_conn: + if not self.__fetch_sub: self.bucket = await self._connection.create_object_store( bucket=self.subject, declare=self.obj_watch.declare, ) - self.obj_watch_conn = await self.bucket.watch( + fetch_sub = self.__fetch_sub = await self.bucket.watch( ignore_deletes=self.obj_watch.ignore_deletes, include_history=self.obj_watch.include_history, meta_only=self.obj_watch.meta_only, ) + else: + fetch_sub = self.__fetch_sub raw_message = None sleep_interval = timeout / 10 with anyio.move_on_after(timeout): - while (raw_message := await self.obj_watch_conn.updates(timeout)) is None: # noqa: ASYNC110 + while ( # noqa: ASYNC110 + raw_message := await fetch_sub.updates(timeout) + ) is None: await anyio.sleep(sleep_interval) if not raw_message: return None async with AsyncExitStack() as stack: - return_msg: Callable[[NatsMessage], Awaitable[NatsMessage]] = return_input + return_msg: Callable[[NatsObjMessage], Awaitable[NatsObjMessage]] = ( + return_input + ) for m in self._broker_middlewares: mid = m(raw_message) await stack.enter_async_context(mid) return_msg = partial(mid.consume_scope, return_msg) - parsed_msg = await self._parser(raw_message) + parsed_msg: NatsObjMessage = await self._parser(raw_message) parsed_msg._decoded_body = await self._decoder(parsed_msg) return await return_msg(parsed_msg) + raise AssertionError("unreachable") + @override async def _create_subscription( # type: ignore[override] self, @@ -1270,7 +1348,7 @@ async def _consume_watch(self) -> None: with suppress(TimeoutError): message = cast( Optional["ObjectInfo"], - await obj_watch.updates(self.obj_watch.timeout), # type: ignore[no-untyped-call] + await obj_watch.updates(self.obj_watch.timeout), ) if message: diff --git a/faststream/redis/subscriber/usecase.py b/faststream/redis/subscriber/usecase.py index 2c07b7715e..e895089809 100644 --- a/faststream/redis/subscriber/usecase.py +++ b/faststream/redis/subscriber/usecase.py @@ -439,7 +439,9 @@ async def get_one( # type: ignore[override] return None async with AsyncExitStack() as stack: - return_msg: Callable[[RedisListMessage], Awaitable[RedisListMessage]] = return_input + return_msg: Callable[[RedisListMessage], Awaitable[RedisListMessage]] = ( + return_input + ) for m in self._broker_middlewares: mid = m(raw_message) @@ -745,7 +747,9 @@ async def get_one( # type: ignore[override] ) async with AsyncExitStack() as stack: - return_msg: Callable[[RedisStreamMessage], Awaitable[RedisStreamMessage]] = return_input + return_msg: Callable[ + [RedisStreamMessage], Awaitable[RedisStreamMessage] + ] = return_input for m in self._broker_middlewares: mid = m(msg) # type: ignore[arg-type] From 9664558c267eabf614ef51027eb5d6b588a78828 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sun, 8 Sep 2024 17:56:56 +0300 Subject: [PATCH 48/62] lint: fix precommit --- faststream/confluent/subscriber/usecase.py | 2 +- faststream/kafka/subscriber/usecase.py | 2 +- faststream/nats/subscriber/usecase.py | 10 +++++----- faststream/redis/subscriber/usecase.py | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/faststream/confluent/subscriber/usecase.py b/faststream/confluent/subscriber/usecase.py index 30c24f70d4..ede6e62a5c 100644 --- a/faststream/confluent/subscriber/usecase.py +++ b/faststream/confluent/subscriber/usecase.py @@ -193,7 +193,7 @@ async def get_one( ] = return_input for m in self._broker_middlewares: - mid = m(raw_message) # type: ignore[arg-type] + mid = m(raw_message) await stack.enter_async_context(mid) return_msg = partial(mid.consume_scope, return_msg) diff --git a/faststream/kafka/subscriber/usecase.py b/faststream/kafka/subscriber/usecase.py index 16fe929c40..f69705e4d0 100644 --- a/faststream/kafka/subscriber/usecase.py +++ b/faststream/kafka/subscriber/usecase.py @@ -189,7 +189,7 @@ async def get_one( *, timeout: float = 5.0, ) -> "Optional[StreamMessage[MsgType]]": - assert self.consumer, "You should start subscriber at first." + assert self.consumer, "You should start subscriber at first." # nosec B101 assert ( # nosec B101 not self.calls ), "You can't use `get_one` method if subscriber has registered handlers." diff --git a/faststream/nats/subscriber/usecase.py b/faststream/nats/subscriber/usecase.py index 2701022590..0edf14bed7 100644 --- a/faststream/nats/subscriber/usecase.py +++ b/faststream/nats/subscriber/usecase.py @@ -430,7 +430,7 @@ async def get_one( *, timeout: float = 5.0, ) -> "Optional[NatsMessage]": - assert self._connection, "Please, start() subscriber first" + assert self._connection, "Please, start() subscriber first" # nosec B101 assert ( # nosec B101 not self.calls ), "You can't use `get_one` method if subscriber has registered handlers." @@ -658,7 +658,7 @@ async def get_one( *, timeout: float = 5, ) -> Optional["NatsMessage"]: - assert self._connection, "Please, start() subscriber first" + assert self._connection, "Please, start() subscriber first" # nosec B101 assert ( # nosec B101 not self.calls ), "You can't use `get_one` method if subscriber has registered handlers." @@ -993,7 +993,7 @@ async def get_one( *, timeout: float = 5, ) -> Optional["NatsMessage"]: - assert self._connection, "Please, start() subscriber first" + assert self._connection, "Please, start() subscriber first" # nosec B101 assert ( # nosec B101 not self.calls ), "You can't use `get_one` method if subscriber has registered handlers." @@ -1103,7 +1103,7 @@ async def get_one( *, timeout: float = 5, ) -> Optional["NatsKvMessage"]: - assert self._connection, "Please, start() subscriber first" + assert self._connection, "Please, start() subscriber first" # nosec B101 assert ( # nosec B101 not self.calls ), "You can't use `get_one` method if subscriber has registered handlers." @@ -1270,7 +1270,7 @@ async def get_one( *, timeout: float = 5, ) -> Optional["NatsObjMessage"]: - assert self._connection, "Please, start() subscriber first" + assert self._connection, "Please, start() subscriber first" # nosec B101 assert ( # nosec B101 not self.calls ), "You can't use `get_one` method if subscriber has registered handlers." diff --git a/faststream/redis/subscriber/usecase.py b/faststream/redis/subscriber/usecase.py index e895089809..f51b8a57f6 100644 --- a/faststream/redis/subscriber/usecase.py +++ b/faststream/redis/subscriber/usecase.py @@ -421,7 +421,7 @@ async def get_one( # type: ignore[override] *, timeout: float = 5.0, ) -> "Optional[RedisListMessage]": - assert self._client, "You should start subscriber at first." + assert self._client, "You should start subscriber at first." # nosec B101 assert ( # nosec B101 not self.calls ), "You can't use `get_one` method if subscriber has registered handlers." From 164f6cbca6b5e02cfa7bab28100758c70cc64529 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sun, 8 Sep 2024 18:44:48 +0300 Subject: [PATCH 49/62] refactor: fix nats unsub --- faststream/nats/subscriber/asyncapi.py | 2 +- faststream/nats/subscriber/usecase.py | 240 ++++++++++++------------ faststream/nats/testing.py | 11 +- tests/brokers/confluent/test_consume.py | 2 +- tests/brokers/kafka/test_consume.py | 2 +- tests/brokers/nats/test_consume.py | 18 +- tests/brokers/rabbit/test_consume.py | 2 +- tests/brokers/redis/test_consume.py | 2 +- 8 files changed, 137 insertions(+), 142 deletions(-) diff --git a/faststream/nats/subscriber/asyncapi.py b/faststream/nats/subscriber/asyncapi.py index ad0edb0bca..402aa0b114 100644 --- a/faststream/nats/subscriber/asyncapi.py +++ b/faststream/nats/subscriber/asyncapi.py @@ -25,7 +25,7 @@ ) -class AsyncAPISubscriber(LogicSubscriber[Any]): +class AsyncAPISubscriber(LogicSubscriber[Any, Any]): """A class to represent a NATS handler.""" def get_name(self) -> str: diff --git a/faststream/nats/subscriber/usecase.py b/faststream/nats/subscriber/usecase.py index 0edf14bed7..b922cfdac1 100644 --- a/faststream/nats/subscriber/usecase.py +++ b/faststream/nats/subscriber/usecase.py @@ -9,10 +9,12 @@ Callable, Coroutine, Dict, + Generic, Iterable, List, Optional, Sequence, + TypeVar, Union, cast, ) @@ -21,7 +23,6 @@ from fast_depends.dependencies import Depends from nats.errors import ConnectionClosedError, TimeoutError from nats.js.api import ConsumerConfig, ObjectInfo -from nats.js.kv import KeyValue from typing_extensions import Annotated, Doc, override from faststream.broker.publisher.fake import FakePublisher @@ -50,6 +51,7 @@ from nats.aio.msg import Msg from nats.aio.subscription import Subscription from nats.js import JetStreamContext + from nats.js.kv import KeyValue from nats.js.object_store import ObjectStore from faststream.broker.message import StreamMessage @@ -64,11 +66,16 @@ from faststream.types import Decorator -class LogicSubscriber(SubscriberUsecase[MsgType]): +ConnectionType = TypeVar("ConnectionType") + + +class LogicSubscriber(Generic[ConnectionType, MsgType], SubscriberUsecase[MsgType]): """A class to represent a NATS handler.""" subscription: Optional[Unsubscriptable] + _fetch_sub: Optional[Unsubscriptable] producer: Optional["ProducerProto"] + _connection: Optional[ConnectionType] def __init__( self, @@ -109,10 +116,45 @@ def __init__( include_in_schema=include_in_schema, ) - self._connection: Any = None + self._connection = None + self._fetch_sub = None self.subscription = None self.producer = None + @override + def setup( # type: ignore[override] + self, + *, + connection: ConnectionType, + # basic args + logger: Optional["LoggerProto"], + producer: Optional["ProducerProto"], + graceful_timeout: Optional[float], + extra_context: "AnyDict", + # broker options + broker_parser: Optional["CustomCallable"], + broker_decoder: Optional["CustomCallable"], + # dependant args + apply_types: bool, + is_validate: bool, + _get_dependant: Optional[Callable[..., Any]], + _call_decorators: Iterable["Decorator"], + ) -> None: + self._connection = connection + + super().setup( + logger=logger, + producer=producer, + graceful_timeout=graceful_timeout, + extra_context=extra_context, + broker_parser=broker_parser, + broker_decoder=broker_decoder, + apply_types=apply_types, + is_validate=is_validate, + _get_dependant=_get_dependant, + _call_decorators=_call_decorators, + ) + @property def clear_subject(self) -> str: """Compile `test.{name}` to `test.*` subject.""" @@ -136,13 +178,15 @@ async def close(self) -> None: await self.subscription.unsubscribe() self.subscription = None + if self._fetch_sub is not None: + await self._fetch_sub.unsubscribe() + self.subscription = None + @abstractmethod async def _create_subscription( self, *, - connection: Union[ - "Client", "JetStreamContext", "KVBucketDeclarer", "OSBucketDeclarer" - ], + connection: ConnectionType, ) -> None: """Create NATS subscription object to consume messages.""" raise NotImplementedError() @@ -206,7 +250,7 @@ def get_routing_hash( return hash(subject) -class _DefaultSubscriber(LogicSubscriber[MsgType]): +class _DefaultSubscriber(LogicSubscriber[ConnectionType, MsgType]): def __init__( self, *, @@ -277,7 +321,7 @@ def get_log_context( ) -class _TasksMixin(LogicSubscriber[Any]): +class _TasksMixin(LogicSubscriber[Any, Any]): def __init__(self, **kwargs: Any) -> None: self.tasks: List[asyncio.Task[Any]] = [] @@ -344,9 +388,9 @@ async def _put_msg(self, msg: "Msg") -> None: await self.send_stream.send(msg) -class CoreSubscriber(_DefaultSubscriber["Msg"]): +class CoreSubscriber(_DefaultSubscriber["Client", "Msg"]): subscription: Optional["Subscription"] - _connection: Optional["Client"] + _fetch_sub: Optional["Subscription"] def __init__( self, @@ -390,40 +434,6 @@ def __init__( include_in_schema=include_in_schema, ) - @override - def setup( # type: ignore[override] - self, - *, - connection: "Client", - # basic args - logger: Optional["LoggerProto"], - producer: Optional["ProducerProto"], - graceful_timeout: Optional[float], - extra_context: "AnyDict", - # broker options - broker_parser: Optional["CustomCallable"], - broker_decoder: Optional["CustomCallable"], - # dependant args - apply_types: bool, - is_validate: bool, - _get_dependant: Optional[Callable[..., Any]], - _call_decorators: Iterable["Decorator"], - ) -> None: - self._connection = connection - - super().setup( - logger=logger, - producer=producer, - graceful_timeout=graceful_timeout, - extra_context=extra_context, - broker_parser=broker_parser, - broker_decoder=broker_decoder, - apply_types=apply_types, - is_validate=is_validate, - _get_dependant=_get_dependant, - _call_decorators=_call_decorators, - ) - @override async def get_one( self, @@ -435,15 +445,17 @@ async def get_one( not self.calls ), "You can't use `get_one` method if subscriber has registered handlers." - if self.subscription is None: - self.subscription = await self._connection.subscribe( + if self._fetch_sub is None: + fetch_sub = self._fetch_sub = await self._connection.subscribe( subject=self.clear_subject, queue=self.queue, **self.extra_options, ) + else: + fetch_sub = self._fetch_sub try: - raw_message = await self.subscription.next_msg(timeout) + raw_message = await fetch_sub.next_msg(timeout=timeout) except TimeoutError: return None @@ -462,7 +474,7 @@ async def get_one( raise AssertionError("unreachable") @override - async def _create_subscription( # type: ignore[override] + async def _create_subscription( self, *, connection: "Client", @@ -493,7 +505,10 @@ def get_log_context( ) -class ConcurrentCoreSubscriber(_ConcurrentMixin, CoreSubscriber): +class ConcurrentCoreSubscriber( + _ConcurrentMixin, + CoreSubscriber, +): def __init__( self, *, @@ -534,7 +549,7 @@ def __init__( ) @override - async def _create_subscription( # type: ignore[override] + async def _create_subscription( self, *, connection: "Client", @@ -553,9 +568,8 @@ async def _create_subscription( # type: ignore[override] ) -class _StreamSubscriber(_DefaultSubscriber["Msg"]): - _connection: Optional["JetStreamContext"] - __fetch_sub: Optional["JetStreamContext.PullSubscription"] +class _StreamSubscriber(_DefaultSubscriber["JetStreamContext", "Msg"]): + _fetch_sub: Optional["JetStreamContext.PullSubscription"] def __init__( self, @@ -582,8 +596,6 @@ def __init__( self.queue = queue self.stream = stream - self.__fetch_sub = None - super().__init__( subject=subject, config=config, @@ -603,40 +615,6 @@ def __init__( include_in_schema=include_in_schema, ) - @override - def setup( # type: ignore[override] - self, - *, - connection: "JetStreamContext", - # basic args - logger: Optional["LoggerProto"], - producer: Optional["ProducerProto"], - graceful_timeout: Optional[float], - extra_context: "AnyDict", - # broker options - broker_parser: Optional["CustomCallable"], - broker_decoder: Optional["CustomCallable"], - # dependant args - apply_types: bool, - is_validate: bool, - _get_dependant: Optional[Callable[..., Any]], - _call_decorators: Iterable["Decorator"], - ) -> None: - self._connection = connection - - super().setup( - logger=logger, - producer=producer, - graceful_timeout=graceful_timeout, - extra_context=extra_context, - broker_parser=broker_parser, - broker_decoder=broker_decoder, - apply_types=apply_types, - is_validate=is_validate, - _get_dependant=_get_dependant, - _call_decorators=_call_decorators, - ) - def get_log_context( self, message: Annotated[ @@ -663,7 +641,7 @@ async def get_one( not self.calls ), "You can't use `get_one` method if subscriber has registered handlers." - if not self.__fetch_sub: + if not self._fetch_sub: extra_options = { "pending_bytes_limit": self.extra_options["pending_bytes_limit"], "pending_msgs_limit": self.extra_options["pending_msgs_limit"], @@ -673,7 +651,7 @@ async def get_one( if inbox_prefix := self.extra_options.get("inbox_prefix"): extra_options["inbox_prefix"] = inbox_prefix - self.__fetch_sub = await self._connection.pull_subscribe( + self._fetch_sub = await self._connection.pull_subscribe( subject=self.clear_subject, config=self.config, **extra_options, @@ -681,7 +659,7 @@ async def get_one( try: raw_message = ( - await self.__fetch_sub.fetch( + await self._fetch_sub.fetch( batch=1, timeout=timeout, ) @@ -708,7 +686,7 @@ class PushStreamSubscription(_StreamSubscriber): subscription: Optional["JetStreamContext.PushSubscription"] @override - async def _create_subscription( # type: ignore[override] + async def _create_subscription( self, *, connection: "JetStreamContext", @@ -726,7 +704,10 @@ async def _create_subscription( # type: ignore[override] ) -class ConcurrentPushStreamSubscriber(_ConcurrentMixin, _StreamSubscriber): +class ConcurrentPushStreamSubscriber( + _ConcurrentMixin, + _StreamSubscriber, +): subscription: Optional["JetStreamContext.PushSubscription"] def __init__( @@ -771,7 +752,7 @@ def __init__( ) @override - async def _create_subscription( # type: ignore[override] + async def _create_subscription( self, *, connection: "JetStreamContext", @@ -791,7 +772,10 @@ async def _create_subscription( # type: ignore[override] ) -class PullStreamSubscriber(_TasksMixin, _StreamSubscriber): +class PullStreamSubscriber( + _TasksMixin, + _StreamSubscriber, +): subscription: Optional["JetStreamContext.PullSubscription"] def __init__( @@ -836,7 +820,7 @@ def __init__( ) @override - async def _create_subscription( # type: ignore[override] + async def _create_subscription( self, *, connection: "JetStreamContext", @@ -873,7 +857,10 @@ async def _consume_pull( tg.start_soon(cb, msg) -class ConcurrentPullStreamSubscriber(_ConcurrentMixin, PullStreamSubscriber): +class ConcurrentPullStreamSubscriber( + _ConcurrentMixin, + PullStreamSubscriber, +): def __init__( self, *, @@ -916,7 +903,7 @@ def __init__( ) @override - async def _create_subscription( # type: ignore[override] + async def _create_subscription( self, *, connection: "JetStreamContext", @@ -935,11 +922,14 @@ async def _create_subscription( # type: ignore[override] self.add_task(self._consume_pull(cb=self._put_msg)) -class BatchPullStreamSubscriber(_TasksMixin, _DefaultSubscriber[List["Msg"]]): +class BatchPullStreamSubscriber( + _TasksMixin, + _DefaultSubscriber["JetStreamContext", List["Msg"]], +): """Batch-message consumer class.""" subscription: Optional["JetStreamContext.PullSubscription"] - __fetch_sub: Optional["JetStreamContext.PullSubscription"] + _fetch_sub: Optional["JetStreamContext.PullSubscription"] def __init__( self, @@ -966,8 +956,6 @@ def __init__( self.stream = stream self.pull_sub = pull_sub - self.__fetch_sub = None - super().__init__( subject=subject, config=config, @@ -998,14 +986,14 @@ async def get_one( not self.calls ), "You can't use `get_one` method if subscriber has registered handlers." - if not self.__fetch_sub: - fetch_sub = self.__fetch_sub = await self._connection.pull_subscribe( + if not self._fetch_sub: + fetch_sub = self._fetch_sub = await self._connection.pull_subscribe( subject=self.clear_subject, config=self.config, **self.extra_options, ) else: - fetch_sub = self.__fetch_sub + fetch_sub = self._fetch_sub try: raw_messages = await fetch_sub.fetch( @@ -1030,7 +1018,7 @@ async def get_one( raise AssertionError("unreachable") @override - async def _create_subscription( # type: ignore[override] + async def _create_subscription( self, *, connection: "JetStreamContext", @@ -1061,8 +1049,12 @@ async def _consume_pull(self) -> None: await self.consume(messages) -class KeyValueWatchSubscriber(_TasksMixin, LogicSubscriber[KeyValue.Entry]): +class KeyValueWatchSubscriber( + _TasksMixin, + LogicSubscriber["KVBucketDeclarer", "KeyValue.Entry"], +): subscription: Optional["UnsubscribeAdapter[KeyValue.KeyWatcher]"] + _fetch_sub: Optional[UnsubscribeAdapter["KeyValue.KeyWatcher"]] def __init__( self, @@ -1108,13 +1100,13 @@ async def get_one( not self.calls ), "You can't use `get_one` method if subscriber has registered handlers." - if not self.subscription: + if not self._fetch_sub: bucket = await self._connection.create_key_value( bucket=self.kv_watch.name, declare=self.kv_watch.declare, ) - self.subscription = UnsubscribeAdapter["KeyValue.KeyWatcher"]( + fetch_sub = self._fetch_sub = UnsubscribeAdapter["KeyValue.KeyWatcher"]( await bucket.watch( keys=self.clear_subject, headers_only=self.kv_watch.headers_only, @@ -1123,12 +1115,14 @@ async def get_one( meta_only=self.kv_watch.meta_only, ) ) + else: + fetch_sub = self._fetch_sub raw_message = None sleep_interval = timeout / 10 with anyio.move_on_after(timeout): while ( # noqa: ASYNC110 - raw_message := await self.subscription.obj.updates(timeout) # type: ignore[no-untyped-call] + raw_message := await fetch_sub.obj.updates(timeout) # type: ignore[no-untyped-call] ) is None: await anyio.sleep(sleep_interval) @@ -1152,7 +1146,7 @@ async def get_one( raise AssertionError("unreachable") @override - async def _create_subscription( # type: ignore[override] + async def _create_subscription( self, *, connection: "KVBucketDeclarer", @@ -1223,9 +1217,12 @@ def get_log_context( OBJECT_STORAGE_CONTEXT_KEY = "__object_storage" -class ObjStoreWatchSubscriber(_TasksMixin, LogicSubscriber[ObjectInfo]): +class ObjStoreWatchSubscriber( + _TasksMixin, + LogicSubscriber["OSBucketDeclarer", ObjectInfo], +): subscription: Optional["UnsubscribeAdapter[ObjectStore.ObjectWatcher]"] - __fetch_sub: Optional["ObjectStore.ObjectWatcher"] + _fetch_sub: Optional[UnsubscribeAdapter["ObjectStore.ObjectWatcher"]] def __init__( self, @@ -1245,8 +1242,6 @@ def __init__( self.obj_watch = obj_watch self.obj_watch_conn = None - self.__fetch_sub = None - super().__init__( subject=subject, config=config, @@ -1275,25 +1270,28 @@ async def get_one( not self.calls ), "You can't use `get_one` method if subscriber has registered handlers." - if not self.__fetch_sub: + if not self._fetch_sub: self.bucket = await self._connection.create_object_store( bucket=self.subject, declare=self.obj_watch.declare, ) - fetch_sub = self.__fetch_sub = await self.bucket.watch( + obj_watch = await self.bucket.watch( ignore_deletes=self.obj_watch.ignore_deletes, include_history=self.obj_watch.include_history, meta_only=self.obj_watch.meta_only, ) + fetch_sub = self._fetch_sub = UnsubscribeAdapter[ + "ObjectStore.ObjectWatcher" + ](obj_watch) else: - fetch_sub = self.__fetch_sub + fetch_sub = self._fetch_sub raw_message = None sleep_interval = timeout / 10 with anyio.move_on_after(timeout): while ( # noqa: ASYNC110 - raw_message := await fetch_sub.updates(timeout) + raw_message := await fetch_sub.obj.updates(timeout) # type: ignore[no-untyped-call] ) is None: await anyio.sleep(sleep_interval) @@ -1317,7 +1315,7 @@ async def get_one( raise AssertionError("unreachable") @override - async def _create_subscription( # type: ignore[override] + async def _create_subscription( self, *, connection: "OSBucketDeclarer", diff --git a/faststream/nats/testing.py b/faststream/nats/testing.py index 011629093b..08d43ed30f 100644 --- a/faststream/nats/testing.py +++ b/faststream/nats/testing.py @@ -30,8 +30,8 @@ class TestNatsBroker(TestBroker[NatsBroker]): def create_publisher_fake_subscriber( broker: NatsBroker, publisher: "AsyncAPIPublisher", - ) -> Tuple["LogicSubscriber[Any]", bool]: - sub: Optional[LogicSubscriber[Any]] = None + ) -> Tuple["LogicSubscriber[Any, Any]", bool]: + sub: Optional[LogicSubscriber[Any, Any]] = None publisher_stream = publisher.stream.name if publisher.stream else None for handler in broker._subscribers.values(): if _is_handler_suitable(handler, publisher.subject, publisher_stream): @@ -144,7 +144,10 @@ async def request( # type: ignore[override] raise SubscriberNotFound async def _execute_handler( - self, msg: Any, subject: str, handler: "LogicSubscriber[Any]" + self, + msg: Any, + subject: str, + handler: "LogicSubscriber[Any, Any]", ) -> "PatchedMessage": result = await handler.process_message(msg) @@ -157,7 +160,7 @@ async def _execute_handler( def _is_handler_suitable( - handler: "LogicSubscriber[Any]", + handler: "LogicSubscriber[Any, Any]", subject: str, stream: Optional[str] = None, ) -> bool: diff --git a/tests/brokers/confluent/test_consume.py b/tests/brokers/confluent/test_consume.py index 108450882f..ac34a71230 100644 --- a/tests/brokers/confluent/test_consume.py +++ b/tests/brokers/confluent/test_consume.py @@ -373,7 +373,7 @@ async def test_get_one_timeout( async def coro(): nonlocal message - message = await subscriber.get_one(timeout=1) + message = await subscriber.get_one(timeout=1e-24) await asyncio.wait((asyncio.create_task(coro()),), timeout=3) diff --git a/tests/brokers/kafka/test_consume.py b/tests/brokers/kafka/test_consume.py index 57ff6c4e5d..c93804861f 100644 --- a/tests/brokers/kafka/test_consume.py +++ b/tests/brokers/kafka/test_consume.py @@ -356,7 +356,7 @@ async def test_get_one_timeout( async def coro(): nonlocal message - message = await subscriber.get_one(timeout=1) + message = await subscriber.get_one(timeout=1e-24) await asyncio.wait((asyncio.create_task(coro()),), timeout=3) diff --git a/tests/brokers/nats/test_consume.py b/tests/brokers/nats/test_consume.py index be53e398de..63bc0a8aa7 100644 --- a/tests/brokers/nats/test_consume.py +++ b/tests/brokers/nats/test_consume.py @@ -447,7 +447,7 @@ async def test_get_one_timeout( async def coro(): nonlocal message - message = await subscriber.get_one(timeout=1) + message = await subscriber.get_one(timeout=1e-24) await asyncio.wait((asyncio.create_task(coro()),), timeout=3) @@ -489,6 +489,7 @@ async def test_get_one_timeout_js( self, queue: str, stream: JStream, + mock, ): broker = self.get_broker(apply_types=True) subscriber = broker.subscriber(queue, stream=stream) @@ -496,15 +497,8 @@ async def test_get_one_timeout_js( async with self.patch_broker(broker) as br: await br.start() - message = object() - - async def coro(): - nonlocal message - message = await subscriber.get_one(timeout=0.5) - - await asyncio.wait((asyncio.create_task(coro()),), timeout=3) - - assert message is None + mock(await subscriber.get_one(timeout=1e-24)) + mock.assert_called_once_with(None) async def test_get_one_pull( self, @@ -562,7 +556,7 @@ async def test_get_one_pull_timeout( async def consume(): nonlocal message - message = await subscriber.get_one(timeout=0.5) + message = await subscriber.get_one(timeout=1e-24) await asyncio.wait((asyncio.create_task(consume()),), timeout=3) @@ -624,7 +618,7 @@ async def test_get_one_batch_timeout( async def consume(): nonlocal message - message = await subscriber.get_one(timeout=0.5) + message = await subscriber.get_one(timeout=1e-24) await asyncio.wait((asyncio.create_task(consume()),), timeout=3) diff --git a/tests/brokers/rabbit/test_consume.py b/tests/brokers/rabbit/test_consume.py index daa56fc44d..7788d9b8f5 100644 --- a/tests/brokers/rabbit/test_consume.py +++ b/tests/brokers/rabbit/test_consume.py @@ -433,7 +433,7 @@ async def test_get_one_timeout( async def coro(): nonlocal message - message = await subscriber.get_one(timeout=1) + message = await subscriber.get_one(timeout=1e-24) await asyncio.wait((asyncio.create_task(coro()),), timeout=3) diff --git a/tests/brokers/redis/test_consume.py b/tests/brokers/redis/test_consume.py index 881f58a1f4..a43cc8bc5b 100644 --- a/tests/brokers/redis/test_consume.py +++ b/tests/brokers/redis/test_consume.py @@ -740,7 +740,7 @@ async def test_get_one_timeout( async def coro(): nonlocal message - message = await subscriber.get_one(timeout=1) + message = await subscriber.get_one(timeout=1e-24) await asyncio.wait((asyncio.create_task(coro()),), timeout=3) From 669d7c17d73aa2aca95af271a3f4db16fd09e7bf Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sun, 8 Sep 2024 18:59:35 +0300 Subject: [PATCH 50/62] fix: correct redis timeout --- faststream/redis/subscriber/usecase.py | 2 +- tests/brokers/redis/test_consume.py | 36 +++++++------------------- 2 files changed, 10 insertions(+), 28 deletions(-) diff --git a/faststream/redis/subscriber/usecase.py b/faststream/redis/subscriber/usecase.py index f51b8a57f6..ebeaf10183 100644 --- a/faststream/redis/subscriber/usecase.py +++ b/faststream/redis/subscriber/usecase.py @@ -728,7 +728,7 @@ async def get_one( # type: ignore[override] stream_message = await self._client.xread( {self.stream_sub.name: self.last_id}, - block=timeout * 1000, + block=int(timeout * 1000) or None, count=1, ) diff --git a/tests/brokers/redis/test_consume.py b/tests/brokers/redis/test_consume.py index a43cc8bc5b..7cc5e22080 100644 --- a/tests/brokers/redis/test_consume.py +++ b/tests/brokers/redis/test_consume.py @@ -126,6 +126,7 @@ async def publish(): async def test_get_one_timeout( self, queue: str, + mock: MagicMock, ): broker = self.get_broker(apply_types=True) subscriber = broker.subscriber(queue) @@ -133,15 +134,8 @@ async def test_get_one_timeout( async with self.patch_broker(broker) as br: await br.start() - message = object() - - async def coro(): - nonlocal message - message = await subscriber.get_one(timeout=1) - - await asyncio.wait((asyncio.create_task(coro()),), timeout=3) - - assert message is None + mock(await subscriber.get_one(timeout=1e-24)) + mock.assert_called_once_with(None) @pytest.mark.redis @@ -396,6 +390,7 @@ async def publish(): async def test_get_one_timeout( self, queue: str, + mock: MagicMock, ): broker = self.get_broker(apply_types=True) subscriber = broker.subscriber(list=queue) @@ -403,15 +398,8 @@ async def test_get_one_timeout( async with self.patch_broker(broker) as br: await br.start() - message = object() - - async def coro(): - nonlocal message - message = await subscriber.get_one(timeout=1) - - await asyncio.wait((asyncio.create_task(coro()),), timeout=3) - - assert message is None + mock(await subscriber.get_one(timeout=1e-24)) + mock.assert_called_once_with(None) @pytest.mark.redis @@ -729,6 +717,7 @@ async def publish(): async def test_get_one_timeout( self, queue: str, + mock: MagicMock, ): broker = self.get_broker(apply_types=True) subscriber = broker.subscriber(stream=queue) @@ -736,12 +725,5 @@ async def test_get_one_timeout( async with self.patch_broker(broker) as br: await br.start() - message = object() - - async def coro(): - nonlocal message - message = await subscriber.get_one(timeout=1e-24) - - await asyncio.wait((asyncio.create_task(coro()),), timeout=3) - - assert message is None + mock(await subscriber.get_one(timeout=1e-24)) + mock.assert_called_once_with(None) From 0a75e12c64ff0fe6f9fad884075d658f2235a46d Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sun, 8 Sep 2024 19:20:48 +0300 Subject: [PATCH 51/62] fix: correct redis channel sub --- faststream/broker/middlewares/exception.py | 4 +++- faststream/redis/subscriber/usecase.py | 10 +++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/faststream/broker/middlewares/exception.py b/faststream/broker/middlewares/exception.py index 2693bfe8bd..cd1b79fe28 100644 --- a/faststream/broker/middlewares/exception.py +++ b/faststream/broker/middlewares/exception.py @@ -28,7 +28,9 @@ from faststream.types import AsyncFuncAny -GeneralExceptionHandler: TypeAlias = Union[Callable[..., None], Callable[..., Awaitable[None]]] +GeneralExceptionHandler: TypeAlias = Union[ + Callable[..., None], Callable[..., Awaitable[None]] +] PublishingExceptionHandler: TypeAlias = Callable[..., "Any"] CastedGeneralExceptionHandler: TypeAlias = Callable[..., Awaitable[None]] diff --git a/faststream/redis/subscriber/usecase.py b/faststream/redis/subscriber/usecase.py index ebeaf10183..ec027269ef 100644 --- a/faststream/redis/subscriber/usecase.py +++ b/faststream/redis/subscriber/usecase.py @@ -1,4 +1,5 @@ import asyncio +import math from abc import abstractmethod from contextlib import AsyncExitStack, suppress from copy import deepcopy @@ -171,6 +172,9 @@ async def start( with anyio.fail_after(3.0): await start_signal.wait() + else: + start_signal.set() + async def _consume(self, *args: Any, start_signal: anyio.Event) -> None: connected = True @@ -340,8 +344,8 @@ async def _get_message(self, psub: RPubSub) -> Optional[PubSubMessage]: return None async def _get_msgs(self, psub: RPubSub) -> None: - msg = await self._get_message(psub) - await self.consume(msg) # type: ignore[arg-type] + if msg := await self._get_message(psub): + await self.consume(msg) # type: ignore[arg-type] def add_prefix(self, prefix: str) -> None: new_ch = deepcopy(self.channel) @@ -728,7 +732,7 @@ async def get_one( # type: ignore[override] stream_message = await self._client.xread( {self.stream_sub.name: self.last_id}, - block=int(timeout * 1000) or None, + block=math.ceil(timeout * 1000), count=1, ) From 1056fd9cde5f138296998bd9bf3b907c2939bc11 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sun, 8 Sep 2024 19:27:18 +0300 Subject: [PATCH 52/62] lint: fix precommit --- faststream/broker/middlewares/exception.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/faststream/broker/middlewares/exception.py b/faststream/broker/middlewares/exception.py index cd1b79fe28..f0325a1788 100644 --- a/faststream/broker/middlewares/exception.py +++ b/faststream/broker/middlewares/exception.py @@ -11,6 +11,7 @@ Tuple, Type, Union, + cast, overload, ) @@ -128,7 +129,12 @@ def __init__( self._handlers: CastedHandlers = [ (IgnoredException, ignore_handler), *( - (exc_type, apply_types(to_async(handler))) + ( + exc_type, + apply_types( + cast(Callable[..., Awaitable[None]], to_async(handler)) + ), + ) for exc_type, handler in (handlers or {}).items() ), ] From f99dc086867b8466ee54701acc19d85286f2b2e5 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sun, 8 Sep 2024 19:41:15 +0300 Subject: [PATCH 53/62] tests: mv get_one tests to basic testcase --- tests/brokers/base/consume.py | 63 +++++++++++++++++++++++++ tests/brokers/confluent/test_consume.py | 56 ---------------------- tests/brokers/kafka/test_consume.py | 49 ------------------- tests/brokers/nats/test_consume.py | 51 -------------------- tests/brokers/rabbit/test_consume.py | 55 --------------------- tests/brokers/redis/test_consume.py | 45 ------------------ 6 files changed, 63 insertions(+), 256 deletions(-) diff --git a/tests/brokers/base/consume.py b/tests/brokers/base/consume.py index 4e7e2621de..936d2bc575 100644 --- a/tests/brokers/base/consume.py +++ b/tests/brokers/base/consume.py @@ -273,6 +273,69 @@ async def subscriber(m): assert event.is_set() + async def test_get_one( + self, + queue: str, + event: asyncio.Event, + mock: MagicMock, + ): + broker = self.get_broker(apply_types=True) + + args, kwargs = self.get_subscriber_params(queue) + subscriber = broker.subscriber(*args, **kwargs) + + async with self.patch_broker(broker) as br: + await br.start() + + async def consume(): + mock(await subscriber.get_one(timeout=self.timeout)) + + async def publish(): + await anyio.sleep(1e-24) + await br.publish("test_message", queue) + + await asyncio.wait( + ( + asyncio.create_task(consume()), + asyncio.create_task(publish()), + ), + timeout=self.timeout, + ) + + mock.assert_called_once() + message = mock.call_args[0][0] + assert message + assert await message.decode() == "test_message" + + async def test_get_one_timeout( + self, + queue: str, + mock: MagicMock, + ): + broker = self.get_broker(apply_types=True) + args, kwargs = self.get_subscriber_params(queue) + subscriber = broker.subscriber(*args, **kwargs) + + async with self.patch_broker(broker) as br: + await br.start() + + mock(await subscriber.get_one(timeout=1e-24)) + mock.assert_called_once_with(None) + + async def test_get_one_conflicts_with_handler(self, queue): + broker = self.get_broker(apply_types=True) + args, kwargs = self.get_subscriber_params(queue) + subscriber = broker.subscriber(*args, **kwargs) + + @subscriber + async def t(): ... + + async with self.patch_broker(broker) as br: + await br.start() + + with pytest.raises(AssertionError): + await subscriber.get_one(timeout=1e-24) + @pytest.mark.asyncio class BrokerRealConsumeTestcase(BrokerConsumeTestcase): diff --git a/tests/brokers/confluent/test_consume.py b/tests/brokers/confluent/test_consume.py index ac34a71230..f3eb5774cd 100644 --- a/tests/brokers/confluent/test_consume.py +++ b/tests/brokers/confluent/test_consume.py @@ -322,59 +322,3 @@ async def subscriber_with_auto_commit(m): assert event.is_set() assert event2.is_set() - - @pytest.mark.asyncio - async def test_get_one( - self, - queue: str, - event: asyncio.Event, - ): - broker = self.get_broker(apply_types=True) - - args, kwargs = self.get_subscriber_params(queue) - - subscriber = broker.subscriber(*args, **kwargs) - - async with self.patch_broker(broker) as br: - await br.start() - - message = None - - async def consume(): - nonlocal message - message = await subscriber.get_one(timeout=5) - - async def publish(): - await br.publish("test_message", queue) - - await asyncio.wait( - ( - asyncio.create_task(consume()), - asyncio.create_task(publish()), - ), - timeout=10, - ) - - assert message is not None - assert await message.decode() == "test_message" - - @pytest.mark.asyncio - async def test_get_one_timeout( - self, - queue: str, - ): - broker = self.get_broker(apply_types=True) - subscriber = broker.subscriber(queue) - - async with self.patch_broker(broker) as br: - await br.start() - - message = object() - - async def coro(): - nonlocal message - message = await subscriber.get_one(timeout=1e-24) - - await asyncio.wait((asyncio.create_task(coro()),), timeout=3) - - assert message is None diff --git a/tests/brokers/kafka/test_consume.py b/tests/brokers/kafka/test_consume.py index c93804861f..7da9f90a5f 100644 --- a/tests/brokers/kafka/test_consume.py +++ b/tests/brokers/kafka/test_consume.py @@ -312,52 +312,3 @@ async def handler(msg: KafkaMessage): m.mock.assert_not_called() assert event.is_set() - - @pytest.mark.asyncio - async def test_get_one( - self, - queue: str, - ): - broker = self.get_broker(apply_types=True) - subscriber = broker.subscriber(queue) - - async with self.patch_broker(broker) as br: - await br.start() - - message = None - - async def set_msg(): - nonlocal message - message = await subscriber.get_one() - - await asyncio.wait( - ( - asyncio.create_task(br.publish("test_message", queue)), - asyncio.create_task(set_msg()), - ), - timeout=3, - ) - - assert message is not None - assert await message.decode() == "test_message" - - @pytest.mark.asyncio - async def test_get_one_timeout( - self, - queue: str, - ): - broker = self.get_broker(apply_types=True) - subscriber = broker.subscriber(queue) - - async with self.patch_broker(broker) as br: - await br.start() - - message = object() - - async def coro(): - nonlocal message - message = await subscriber.get_one(timeout=1e-24) - - await asyncio.wait((asyncio.create_task(coro()),), timeout=3) - - assert message is None diff --git a/tests/brokers/nats/test_consume.py b/tests/brokers/nats/test_consume.py index 63bc0a8aa7..ab7452a415 100644 --- a/tests/brokers/nats/test_consume.py +++ b/tests/brokers/nats/test_consume.py @@ -402,57 +402,6 @@ async def handler(filename: str): assert event.is_set() mock.assert_called_once_with("hello") - async def test_get_one( - self, - queue: str, - event: asyncio.Event, - ): - broker = self.get_broker(apply_types=True) - subscriber = broker.subscriber(queue) - - async with self.patch_broker(broker) as br: - await br.start() - - message = None - - async def consume(): - nonlocal message - message = await subscriber.get_one(timeout=5) - - async def publish(): - await br.publish("test_message", queue) - - await asyncio.wait( - ( - asyncio.create_task(consume()), - asyncio.create_task(publish()), - ), - timeout=10, - ) - - assert message is not None - assert await message.decode() == "test_message" - - async def test_get_one_timeout( - self, - queue: str, - ): - broker = self.get_broker(apply_types=True) - subscriber = broker.subscriber(queue) - - async with self.patch_broker(broker) as br: - await br.start() - - message = object() - - async def coro(): - nonlocal message - message = await subscriber.get_one(timeout=1e-24) - - await asyncio.wait((asyncio.create_task(coro()),), timeout=3) - - assert message is None - async def test_get_one_js( self, queue: str, diff --git a/tests/brokers/rabbit/test_consume.py b/tests/brokers/rabbit/test_consume.py index 7788d9b8f5..cd2550429c 100644 --- a/tests/brokers/rabbit/test_consume.py +++ b/tests/brokers/rabbit/test_consume.py @@ -383,58 +383,3 @@ async def handler(msg: RabbitMessage): m.mock.assert_not_called() assert event.is_set() - - @pytest.mark.asyncio - async def test_get_one( - self, - queue: str, - exchange: RabbitExchange, - ): - broker = self.get_broker(apply_types=True) - subscriber = broker.subscriber(queue, exchange=exchange) - - async with self.patch_broker(broker) as br: - await br.start() - - message = None - - async def set_msg(): - nonlocal message - message = await subscriber.get_one() - - await asyncio.wait( - ( - asyncio.create_task( - br.publish( - message="test_message", queue=queue, exchange=exchange - ) - ), - asyncio.create_task(set_msg()), - ), - timeout=3, - ) - - assert message is not None - assert await message.decode() == "test_message" - - @pytest.mark.asyncio - async def test_get_one_timeout( - self, - queue: str, - exchange: RabbitExchange, - ): - broker = self.get_broker(apply_types=True) - subscriber = broker.subscriber(queue, exchange=exchange) - - async with self.patch_broker(broker) as br: - await br.start() - - message = object() - - async def coro(): - nonlocal message - message = await subscriber.get_one(timeout=1e-24) - - await asyncio.wait((asyncio.create_task(coro()),), timeout=3) - - assert message is None diff --git a/tests/brokers/redis/test_consume.py b/tests/brokers/redis/test_consume.py index 7cc5e22080..467254a62f 100644 --- a/tests/brokers/redis/test_consume.py +++ b/tests/brokers/redis/test_consume.py @@ -92,51 +92,6 @@ async def handler(msg): mock.assert_called_once_with("hello") - async def test_get_one( - self, - queue: str, - event: asyncio.Event, - ): - broker = self.get_broker(apply_types=True) - subscriber = broker.subscriber(queue) - - async with self.patch_broker(broker) as br: - await br.start() - - message = None - - async def consume(): - nonlocal message - message = await subscriber.get_one(timeout=5) - - async def publish(): - await br.publish("test_message", queue) - - await asyncio.wait( - ( - asyncio.create_task(consume()), - asyncio.create_task(publish()), - ), - timeout=10, - ) - - assert message is not None - assert await message.decode() == "test_message" - - async def test_get_one_timeout( - self, - queue: str, - mock: MagicMock, - ): - broker = self.get_broker(apply_types=True) - subscriber = broker.subscriber(queue) - - async with self.patch_broker(broker) as br: - await br.start() - - mock(await subscriber.get_one(timeout=1e-24)) - mock.assert_called_once_with(None) - @pytest.mark.redis @pytest.mark.asyncio From 6602569fa0a1e15c71ce8979d3d7c773be27a8cb Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sun, 8 Sep 2024 20:04:49 +0300 Subject: [PATCH 54/62] tests: mv get_one tests to real testcase --- faststream/broker/subscriber/utils.py | 66 ++++++++++++ faststream/nats/subscriber/usecase.py | 143 +++++++++----------------- tests/brokers/base/consume.py | 34 +++--- 3 files changed, 134 insertions(+), 109 deletions(-) create mode 100644 faststream/broker/subscriber/utils.py diff --git a/faststream/broker/subscriber/utils.py b/faststream/broker/subscriber/utils.py new file mode 100644 index 0000000000..7f2414cbb9 --- /dev/null +++ b/faststream/broker/subscriber/utils.py @@ -0,0 +1,66 @@ +from contextlib import AsyncExitStack +from functools import partial +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Iterable, + Optional, +) + +from typing_extensions import Literal, overload + +from faststream.broker.types import MsgType +from faststream.utils.functions import return_input + +if TYPE_CHECKING: + from faststream.broker.message import StreamMessage + from faststream.broker.types import ( + BrokerMiddleware, + ) + + +@overload +async def process_msg( + msg: Literal[None], + middlewares: Iterable["BrokerMiddleware[MsgType]"], + parser: Callable[[MsgType], Awaitable["StreamMessage[MsgType]"]], + decoder: Callable[["StreamMessage[MsgType]"], "Any"], +) -> None: ... + + +@overload +async def process_msg( + msg: MsgType, + middlewares: Iterable["BrokerMiddleware[MsgType]"], + parser: Callable[[MsgType], Awaitable["StreamMessage[MsgType]"]], + decoder: Callable[["StreamMessage[MsgType]"], "Any"], +) -> "StreamMessage[MsgType]": ... + + +async def process_msg( + msg: Optional[MsgType], + middlewares: Iterable["BrokerMiddleware[MsgType]"], + parser: Callable[[MsgType], Awaitable["StreamMessage[MsgType]"]], + decoder: Callable[["StreamMessage[MsgType]"], "Any"], +) -> Optional["StreamMessage[MsgType]"]: + if msg is None: + return None + + async with AsyncExitStack() as stack: + return_msg: Callable[ + [StreamMessage[MsgType]], + Awaitable[StreamMessage[MsgType]], + ] = return_input + + for m in middlewares: + mid = m(msg) + await stack.enter_async_context(mid) + return_msg = partial(mid.consume_scope, return_msg) + + parsed_msg = await parser(msg) + parsed_msg._decoded_body = await decoder(parsed_msg) + return await return_msg(parsed_msg) + + raise AssertionError("unreachable") diff --git a/faststream/nats/subscriber/usecase.py b/faststream/nats/subscriber/usecase.py index b922cfdac1..589d6bbe2e 100644 --- a/faststream/nats/subscriber/usecase.py +++ b/faststream/nats/subscriber/usecase.py @@ -1,7 +1,6 @@ import asyncio from abc import abstractmethod -from contextlib import AsyncExitStack, suppress -from functools import partial +from contextlib import suppress from typing import ( TYPE_CHECKING, Any, @@ -27,7 +26,8 @@ from faststream.broker.publisher.fake import FakePublisher from faststream.broker.subscriber.usecase import SubscriberUsecase -from faststream.broker.types import CustomCallable, MsgType +from faststream.broker.subscriber.utils import process_msg +from faststream.broker.types import MsgType from faststream.exceptions import NOT_CONNECTED_YET from faststream.nats.parser import ( BatchParser, @@ -41,9 +41,7 @@ UnsubscribeAdapter, Unsubscriptable, ) -from faststream.types import AnyDict, LoggerProto, SendableMessage from faststream.utils.context.repository import context -from faststream.utils.functions import return_input if TYPE_CHECKING: from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream @@ -59,11 +57,12 @@ from faststream.broker.types import ( AsyncCallable, BrokerMiddleware, + CustomCallable, ) from faststream.nats.helpers import KVBucketDeclarer, OSBucketDeclarer from faststream.nats.message import NatsKvMessage, NatsMessage, NatsObjMessage from faststream.nats.schemas import JStream, KvWatch, ObjWatch, PullSub - from faststream.types import Decorator + from faststream.types import AnyDict, Decorator, LoggerProto, SendableMessage ConnectionType = TypeVar("ConnectionType") @@ -82,7 +81,7 @@ def __init__( *, subject: str, config: "ConsumerConfig", - extra_options: Optional[AnyDict], + extra_options: Optional["AnyDict"], # Subscriber args default_parser: "AsyncCallable", default_decoder: "AsyncCallable", @@ -257,7 +256,7 @@ def __init__( subject: str, config: "ConsumerConfig", # default args - extra_options: Optional[AnyDict], + extra_options: Optional["AnyDict"], # Subscriber args default_parser: "AsyncCallable", default_decoder: "AsyncCallable", @@ -399,7 +398,7 @@ def __init__( subject: str, config: "ConsumerConfig", queue: str, - extra_options: Optional[AnyDict], + extra_options: Optional["AnyDict"], # Subscriber args no_ack: bool, no_reply: bool, @@ -459,19 +458,13 @@ async def get_one( except TimeoutError: return None - async with AsyncExitStack() as stack: - return_msg: Callable[[NatsMessage], Awaitable[NatsMessage]] = return_input - - for m in self._broker_middlewares: - mid = m(raw_message) - await stack.enter_async_context(mid) - return_msg = partial(mid.consume_scope, return_msg) - - parsed_msg = await self._parser(raw_message) - parsed_msg._decoded_body = await self._decoder(parsed_msg) - return await return_msg(parsed_msg) - - raise AssertionError("unreachable") + msg: NatsMessage = await process_msg( # type: ignore[assignment] + msg=raw_message, + middlewares=self._broker_middlewares, + parser=self._parser, + decoder=self._decoder, + ) + return msg @override async def _create_subscription( @@ -517,7 +510,7 @@ def __init__( subject: str, config: "ConsumerConfig", queue: str, - extra_options: Optional[AnyDict], + extra_options: Optional["AnyDict"], # Subscriber args no_ack: bool, no_reply: bool, @@ -579,7 +572,7 @@ def __init__( subject: str, config: "ConsumerConfig", queue: str, - extra_options: Optional[AnyDict], + extra_options: Optional["AnyDict"], # Subscriber args no_ack: bool, no_reply: bool, @@ -667,19 +660,13 @@ async def get_one( except (TimeoutError, ConnectionClosedError): return None - async with AsyncExitStack() as stack: - return_msg: Callable[[NatsMessage], Awaitable[NatsMessage]] = return_input - - for m in self._broker_middlewares: - mid = m(raw_message) - await stack.enter_async_context(mid) - return_msg = partial(mid.consume_scope, return_msg) - - parsed_msg: NatsMessage = await self._parser(raw_message) - parsed_msg._decoded_body = await self._decoder(parsed_msg) - return await return_msg(parsed_msg) - - raise AssertionError("unreachable") + msg: NatsMessage = await process_msg( # type: ignore[assignment] + msg=raw_message, + middlewares=self._broker_middlewares, + parser=self._parser, + decoder=self._decoder, + ) + return msg class PushStreamSubscription(_StreamSubscriber): @@ -719,7 +706,7 @@ def __init__( subject: str, config: "ConsumerConfig", queue: str, - extra_options: Optional[AnyDict], + extra_options: Optional["AnyDict"], # Subscriber args no_ack: bool, no_reply: bool, @@ -786,7 +773,7 @@ def __init__( # default args subject: str, config: "ConsumerConfig", - extra_options: Optional[AnyDict], + extra_options: Optional["AnyDict"], # Subscriber args no_ack: bool, no_reply: bool, @@ -838,7 +825,7 @@ async def _create_subscription( async def _consume_pull( self, - cb: Callable[["Msg"], Awaitable[SendableMessage]], + cb: Callable[["Msg"], Awaitable["SendableMessage"]], ) -> None: """Endless task consuming messages using NATS Pull subscriber.""" assert self.subscription # nosec B101 @@ -870,7 +857,7 @@ def __init__( stream: "JStream", subject: str, config: "ConsumerConfig", - extra_options: Optional[AnyDict], + extra_options: Optional["AnyDict"], # Subscriber args no_ack: bool, no_reply: bool, @@ -939,7 +926,7 @@ def __init__( config: "ConsumerConfig", stream: "JStream", pull_sub: "PullSub", - extra_options: Optional[AnyDict], + extra_options: Optional["AnyDict"], # Subscriber args no_ack: bool, no_reply: bool, @@ -996,26 +983,20 @@ async def get_one( fetch_sub = self._fetch_sub try: - raw_messages = await fetch_sub.fetch( + raw_message = await fetch_sub.fetch( batch=1, timeout=timeout, ) except TimeoutError: return None - async with AsyncExitStack() as stack: - return_msg: Callable[[NatsMessage], Awaitable[NatsMessage]] = return_input - - for m in self._broker_middlewares: - mid = m(raw_messages) - await stack.enter_async_context(mid) - return_msg = partial(mid.consume_scope, return_msg) - - parsed_msg = await self._parser(raw_messages) - parsed_msg._decoded_body = await self._decoder(parsed_msg) - return await return_msg(parsed_msg) - - raise AssertionError("unreachable") + msg: NatsMessage = await process_msg( + msg=raw_message, + middlewares=self._broker_middlewares, + parser=self._parser, + decoder=self._decoder, + ) + return msg @override async def _create_subscription( @@ -1126,24 +1107,13 @@ async def get_one( ) is None: await anyio.sleep(sleep_interval) - if not raw_message: - return None - - async with AsyncExitStack() as stack: - return_msg: Callable[[NatsKvMessage], Awaitable[NatsKvMessage]] = ( - return_input - ) - - for m in self._broker_middlewares: - mid = m(raw_message) - await stack.enter_async_context(mid) - return_msg = partial(mid.consume_scope, return_msg) - - parsed_msg: NatsKvMessage = await self._parser(raw_message) - parsed_msg._decoded_body = await self._decoder(parsed_msg) - return await return_msg(parsed_msg) - - raise AssertionError("unreachable") + msg: NatsKvMessage = await process_msg( + msg=raw_message, + middlewares=self._broker_middlewares, + parser=self._parser, + decoder=self._decoder, + ) + return msg @override async def _create_subscription( @@ -1295,24 +1265,13 @@ async def get_one( ) is None: await anyio.sleep(sleep_interval) - if not raw_message: - return None - - async with AsyncExitStack() as stack: - return_msg: Callable[[NatsObjMessage], Awaitable[NatsObjMessage]] = ( - return_input - ) - - for m in self._broker_middlewares: - mid = m(raw_message) - await stack.enter_async_context(mid) - return_msg = partial(mid.consume_scope, return_msg) - - parsed_msg: NatsObjMessage = await self._parser(raw_message) - parsed_msg._decoded_body = await self._decoder(parsed_msg) - return await return_msg(parsed_msg) - - raise AssertionError("unreachable") + msg: NatsObjMessage = await process_msg( + msg=raw_message, + middlewares=self._broker_middlewares, + parser=self._parser, + decoder=self._decoder, + ) + return msg @override async def _create_subscription( diff --git a/tests/brokers/base/consume.py b/tests/brokers/base/consume.py index 936d2bc575..0e7b07b698 100644 --- a/tests/brokers/base/consume.py +++ b/tests/brokers/base/consume.py @@ -273,6 +273,23 @@ async def subscriber(m): assert event.is_set() + async def test_get_one_conflicts_with_handler(self, queue): + broker = self.get_broker(apply_types=True) + args, kwargs = self.get_subscriber_params(queue) + subscriber = broker.subscriber(*args, **kwargs) + + @subscriber + async def t(): ... + + async with self.patch_broker(broker) as br: + await br.start() + + with pytest.raises(AssertionError): + await subscriber.get_one(timeout=1e-24) + + +@pytest.mark.asyncio +class BrokerRealConsumeTestcase(BrokerConsumeTestcase): async def test_get_one( self, queue: str, @@ -322,23 +339,6 @@ async def test_get_one_timeout( mock(await subscriber.get_one(timeout=1e-24)) mock.assert_called_once_with(None) - async def test_get_one_conflicts_with_handler(self, queue): - broker = self.get_broker(apply_types=True) - args, kwargs = self.get_subscriber_params(queue) - subscriber = broker.subscriber(*args, **kwargs) - - @subscriber - async def t(): ... - - async with self.patch_broker(broker) as br: - await br.start() - - with pytest.raises(AssertionError): - await subscriber.get_one(timeout=1e-24) - - -@pytest.mark.asyncio -class BrokerRealConsumeTestcase(BrokerConsumeTestcase): @pytest.mark.slow async def test_stop_consume_exc( self, From 283944be6731dfcfa6b52bc3d51ecb57d67b359a Mon Sep 17 00:00:00 2001 From: Lancetnik Date: Sun, 8 Sep 2024 17:08:30 +0000 Subject: [PATCH 55/62] docs: generate API References --- docs/docs/SUMMARY.md | 2 ++ .../faststream/broker/subscriber/utils/process_msg.md | 11 +++++++++++ 2 files changed, 13 insertions(+) create mode 100644 docs/docs/en/api/faststream/broker/subscriber/utils/process_msg.md diff --git a/docs/docs/SUMMARY.md b/docs/docs/SUMMARY.md index a5e41f0749..009c60fa88 100644 --- a/docs/docs/SUMMARY.md +++ b/docs/docs/SUMMARY.md @@ -414,6 +414,8 @@ search: - [SubscriberProto](api/faststream/broker/subscriber/proto/SubscriberProto.md) - usecase - [SubscriberUsecase](api/faststream/broker/subscriber/usecase/SubscriberUsecase.md) + - utils + - [process_msg](api/faststream/broker/subscriber/utils/process_msg.md) - types - [PublisherMiddleware](api/faststream/broker/types/PublisherMiddleware.md) - utils diff --git a/docs/docs/en/api/faststream/broker/subscriber/utils/process_msg.md b/docs/docs/en/api/faststream/broker/subscriber/utils/process_msg.md new file mode 100644 index 0000000000..cea94447f6 --- /dev/null +++ b/docs/docs/en/api/faststream/broker/subscriber/utils/process_msg.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.broker.subscriber.utils.process_msg From 84530929a89e4282a5febc23de2a3bb6c2168a08 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sun, 8 Sep 2024 20:21:36 +0300 Subject: [PATCH 56/62] refactor: use process_msg everywhere --- faststream/confluent/subscriber/usecase.py | 34 +++------ faststream/kafka/subscriber/usecase.py | 27 ++----- faststream/rabbit/subscriber/usecase.py | 29 ++------ faststream/redis/subscriber/usecase.py | 86 +++++++--------------- 4 files changed, 54 insertions(+), 122 deletions(-) diff --git a/faststream/confluent/subscriber/usecase.py b/faststream/confluent/subscriber/usecase.py index ede6e62a5c..129c11c9ce 100644 --- a/faststream/confluent/subscriber/usecase.py +++ b/faststream/confluent/subscriber/usecase.py @@ -1,11 +1,8 @@ import asyncio from abc import ABC, abstractmethod -from contextlib import AsyncExitStack -from functools import partial from typing import ( TYPE_CHECKING, Any, - Awaitable, Callable, Dict, Iterable, @@ -21,10 +18,10 @@ from faststream.broker.publisher.fake import FakePublisher from faststream.broker.subscriber.usecase import SubscriberUsecase +from faststream.broker.subscriber.utils import process_msg from faststream.broker.types import MsgType from faststream.confluent.parser import AsyncConfluentParser from faststream.confluent.schemas import TopicPartition -from faststream.utils.functions import return_input if TYPE_CHECKING: from fast_depends.dependencies import Depends @@ -172,11 +169,11 @@ async def close(self) -> None: self.task = None @override - async def get_one( + async def get_one( # type: ignore[override] self, *, timeout: float = 5.0, - ) -> "Optional[StreamMessage[MsgType]]": + ) -> "Optional[StreamMessage[Message]]": assert self.consumer, "You should start subscriber at first." # nosec B101 assert ( # nosec B101 not self.calls @@ -184,24 +181,13 @@ async def get_one( raw_message = await self.consumer.getone(timeout=timeout) - if not raw_message: - return None - - async with AsyncExitStack() as stack: - return_msg: Callable[ - [StreamMessage[MsgType]], Awaitable[StreamMessage[MsgType]] - ] = return_input - - for m in self._broker_middlewares: - mid = m(raw_message) - await stack.enter_async_context(mid) - return_msg = partial(mid.consume_scope, return_msg) - - parsed_msg: StreamMessage[MsgType] = await self._parser(raw_message) - parsed_msg._decoded_body = await self._decoder(parsed_msg) - return await return_msg(parsed_msg) - - raise AssertionError("unreachable") + msg = await process_msg( + msg=raw_message, + middlewares=self._broker_middlewares, + parser=self._parser, + decoder=self._decoder, + ) + return msg def _make_response_publisher( self, diff --git a/faststream/kafka/subscriber/usecase.py b/faststream/kafka/subscriber/usecase.py index f69705e4d0..c012c5ab29 100644 --- a/faststream/kafka/subscriber/usecase.py +++ b/faststream/kafka/subscriber/usecase.py @@ -1,12 +1,9 @@ import asyncio from abc import ABC, abstractmethod -from contextlib import AsyncExitStack -from functools import partial from itertools import chain from typing import ( TYPE_CHECKING, Any, - Awaitable, Callable, Dict, Iterable, @@ -23,6 +20,7 @@ from faststream.broker.publisher.fake import FakePublisher from faststream.broker.subscriber.usecase import SubscriberUsecase +from faststream.broker.subscriber.utils import process_msg from faststream.broker.types import ( AsyncCallable, BrokerMiddleware, @@ -31,7 +29,6 @@ ) from faststream.kafka.message import KafkaAckableMessage, KafkaMessage from faststream.kafka.parser import AioKafkaBatchParser, AioKafkaParser -from faststream.utils.functions import return_input from faststream.utils.path import compile_path if TYPE_CHECKING: @@ -203,21 +200,13 @@ async def get_one( ((raw_message,),) = raw_messages.values() - async with AsyncExitStack() as stack: - return_msg: Callable[ - [StreamMessage[MsgType]], Awaitable[StreamMessage[MsgType]] - ] = return_input - - for m in self._broker_middlewares: - mid = m(raw_message) - await stack.enter_async_context(mid) - return_msg = partial(mid.consume_scope, return_msg) - - parsed_msg: StreamMessage[MsgType] = await self._parser(raw_message) - parsed_msg._decoded_body = await self._decoder(parsed_msg) - return await return_msg(parsed_msg) - - raise AssertionError("unreachable") + msg: StreamMessage[MsgType] = await process_msg( + msg=raw_message, + middlewares=self._broker_middlewares, + parser=self._parser, + decoder=self._decoder, + ) + return msg def _make_response_publisher( self, diff --git a/faststream/rabbit/subscriber/usecase.py b/faststream/rabbit/subscriber/usecase.py index 6bd7d4d1e4..66ddffbd3b 100644 --- a/faststream/rabbit/subscriber/usecase.py +++ b/faststream/rabbit/subscriber/usecase.py @@ -1,9 +1,6 @@ -from contextlib import AsyncExitStack -from functools import partial from typing import ( TYPE_CHECKING, Any, - Awaitable, Callable, Dict, Iterable, @@ -17,11 +14,11 @@ from faststream.broker.publisher.fake import FakePublisher from faststream.broker.subscriber.usecase import SubscriberUsecase +from faststream.broker.subscriber.utils import process_msg from faststream.exceptions import SetupError from faststream.rabbit.helpers.declarer import RabbitDeclarer from faststream.rabbit.parser import AioPikaParser from faststream.rabbit.schemas import BaseRMQInformation -from faststream.utils.functions import return_input if TYPE_CHECKING: from aio_pika import IncomingMessage, RobustQueue @@ -208,23 +205,13 @@ async def get_one( ) is None: await anyio.sleep(sleep_interval) - if raw_message is None: - return None - - async with AsyncExitStack() as stack: - return_msg: Callable[[RabbitMessage], Awaitable[RabbitMessage]] = ( - return_input - ) - for m in self._broker_middlewares: - mid = m(raw_message) - await stack.enter_async_context(mid) - return_msg = partial(mid.consume_scope, return_msg) - - parsed_msg: RabbitMessage = await self._parser(raw_message) - parsed_msg._decoded_body = await self._decoder(parsed_msg) - return await return_msg(parsed_msg) - - raise AssertionError("unreachable") + msg: Optional[RabbitMessage] = await process_msg( # type: ignore[assignment] + msg=raw_message, + middlewares=self._broker_middlewares, + parser=self._parser, + decoder=self._decoder, + ) + return msg def _make_response_publisher( self, diff --git a/faststream/redis/subscriber/usecase.py b/faststream/redis/subscriber/usecase.py index ec027269ef..23aaf574fd 100644 --- a/faststream/redis/subscriber/usecase.py +++ b/faststream/redis/subscriber/usecase.py @@ -1,9 +1,8 @@ import asyncio import math from abc import abstractmethod -from contextlib import AsyncExitStack, suppress +from contextlib import suppress from copy import deepcopy -from functools import partial from typing import ( TYPE_CHECKING, Any, @@ -25,6 +24,7 @@ from faststream.broker.publisher.fake import FakePublisher from faststream.broker.subscriber.usecase import SubscriberUsecase +from faststream.broker.subscriber.utils import process_msg from faststream.redis.message import ( BatchListMessage, BatchStreamMessage, @@ -44,7 +44,6 @@ RedisStreamParser, ) from faststream.redis.schemas import ListSub, PubSub, StreamSub -from faststream.utils.functions import return_input if TYPE_CHECKING: from fast_depends.dependencies import Depends @@ -310,22 +309,13 @@ async def get_one( # type: ignore[override] while (message := await self._get_message(self.subscription)) is None: # noqa: ASYNC110 await anyio.sleep(sleep_interval) - if not message: - return None - - async with AsyncExitStack() as stack: - return_msg: Callable[[RedisMessage], Awaitable[RedisMessage]] = return_input - - for m in self._broker_middlewares: - mid = m(message) # type: ignore[arg-type] - await stack.enter_async_context(mid) - return_msg = partial(mid.consume_scope, return_msg) - - parsed_msg: RedisMessage = await self._parser(message) - parsed_msg._decoded_body = await self._decoder(parsed_msg) - return await return_msg(parsed_msg) - - raise AssertionError("unreachable") + msg: Optional[RedisMessage] = await process_msg( # type: ignore[assignment] + msg=message, + middlewares=self._broker_middlewares, # type: ignore[arg-type] + parser=self._parser, + decoder=self._decoder, + ) + return msg async def _get_message(self, psub: RPubSub) -> Optional[PubSubMessage]: raw_msg = await psub.get_message( @@ -442,27 +432,17 @@ async def get_one( # type: ignore[override] if not raw_message: return None - async with AsyncExitStack() as stack: - return_msg: Callable[[RedisListMessage], Awaitable[RedisListMessage]] = ( - return_input - ) - - for m in self._broker_middlewares: - mid = m(raw_message) - await stack.enter_async_context(mid) - return_msg = partial(mid.consume_scope, return_msg) - - message = DefaultListMessage( + msg: RedisListMessage = await process_msg( # type: ignore[assignment] + msg=DefaultListMessage( type="list", data=raw_message, channel=self.list_sub.name, - ) - - parsed_message: RedisListMessage = await self._parser(message) - parsed_message._decoded_body = await self._decoder(parsed_message) - return await return_msg(parsed_message) - - raise AssertionError("unreachable") + ), + middlewares=self._broker_middlewares, # type: ignore[arg-type] + parser=self._parser, + decoder=self._decoder, + ) + return msg def add_prefix(self, prefix: str) -> None: new_list = deepcopy(self.list_sub) @@ -743,28 +723,18 @@ async def get_one( # type: ignore[override] self.last_id = message_id.decode() - msg = DefaultStreamMessage( - type="stream", - channel=stream_name.decode(), - message_ids=[message_id], - data=raw_message, + msg: RedisStreamMessage = await process_msg( # type: ignore[assignment] + msg=DefaultStreamMessage( + type="stream", + channel=stream_name.decode(), + message_ids=[message_id], + data=raw_message, + ), + middlewares=self._broker_middlewares, # type: ignore[arg-type] + parser=self._parser, + decoder=self._decoder, ) - - async with AsyncExitStack() as stack: - return_msg: Callable[ - [RedisStreamMessage], Awaitable[RedisStreamMessage] - ] = return_input - - for m in self._broker_middlewares: - mid = m(msg) # type: ignore[arg-type] - await stack.enter_async_context(mid) - return_msg = partial(mid.consume_scope, return_msg) - - parsed_msg: RedisStreamMessage = await self._parser(msg) - parsed_msg._decoded_body = await self._decoder(parsed_msg) - return await return_msg(parsed_msg) - - raise AssertionError("unreachable") + return msg def add_prefix(self, prefix: str) -> None: new_stream = deepcopy(self.stream_sub) From fd952c886177063f3acd36bc248a222908973441 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sun, 8 Sep 2024 20:27:22 +0300 Subject: [PATCH 57/62] refactor: mv process_msg broker.utils --- docs/docs/SUMMARY.md | 3 +- .../{subscriber => }/utils/process_msg.md | 2 +- faststream/broker/subscriber/utils.py | 66 ------------------- faststream/broker/utils.py | 55 +++++++++++++++- faststream/confluent/subscriber/usecase.py | 2 +- faststream/kafka/subscriber/usecase.py | 2 +- faststream/nats/subscriber/usecase.py | 2 +- faststream/rabbit/subscriber/usecase.py | 2 +- faststream/redis/subscriber/usecase.py | 2 +- 9 files changed, 59 insertions(+), 77 deletions(-) rename docs/docs/en/api/faststream/broker/{subscriber => }/utils/process_msg.md (68%) delete mode 100644 faststream/broker/subscriber/utils.py diff --git a/docs/docs/SUMMARY.md b/docs/docs/SUMMARY.md index 009c60fa88..f5846592cc 100644 --- a/docs/docs/SUMMARY.md +++ b/docs/docs/SUMMARY.md @@ -414,14 +414,13 @@ search: - [SubscriberProto](api/faststream/broker/subscriber/proto/SubscriberProto.md) - usecase - [SubscriberUsecase](api/faststream/broker/subscriber/usecase/SubscriberUsecase.md) - - utils - - [process_msg](api/faststream/broker/subscriber/utils/process_msg.md) - types - [PublisherMiddleware](api/faststream/broker/types/PublisherMiddleware.md) - utils - [MultiLock](api/faststream/broker/utils/MultiLock.md) - [default_filter](api/faststream/broker/utils/default_filter.md) - [get_watcher_context](api/faststream/broker/utils/get_watcher_context.md) + - [process_msg](api/faststream/broker/utils/process_msg.md) - [resolve_custom_func](api/faststream/broker/utils/resolve_custom_func.md) - wrapper - call diff --git a/docs/docs/en/api/faststream/broker/subscriber/utils/process_msg.md b/docs/docs/en/api/faststream/broker/utils/process_msg.md similarity index 68% rename from docs/docs/en/api/faststream/broker/subscriber/utils/process_msg.md rename to docs/docs/en/api/faststream/broker/utils/process_msg.md index cea94447f6..e7ce8aaf99 100644 --- a/docs/docs/en/api/faststream/broker/subscriber/utils/process_msg.md +++ b/docs/docs/en/api/faststream/broker/utils/process_msg.md @@ -8,4 +8,4 @@ search: boost: 0.5 --- -::: faststream.broker.subscriber.utils.process_msg +::: faststream.broker.utils.process_msg diff --git a/faststream/broker/subscriber/utils.py b/faststream/broker/subscriber/utils.py deleted file mode 100644 index 7f2414cbb9..0000000000 --- a/faststream/broker/subscriber/utils.py +++ /dev/null @@ -1,66 +0,0 @@ -from contextlib import AsyncExitStack -from functools import partial -from typing import ( - TYPE_CHECKING, - Any, - Awaitable, - Callable, - Iterable, - Optional, -) - -from typing_extensions import Literal, overload - -from faststream.broker.types import MsgType -from faststream.utils.functions import return_input - -if TYPE_CHECKING: - from faststream.broker.message import StreamMessage - from faststream.broker.types import ( - BrokerMiddleware, - ) - - -@overload -async def process_msg( - msg: Literal[None], - middlewares: Iterable["BrokerMiddleware[MsgType]"], - parser: Callable[[MsgType], Awaitable["StreamMessage[MsgType]"]], - decoder: Callable[["StreamMessage[MsgType]"], "Any"], -) -> None: ... - - -@overload -async def process_msg( - msg: MsgType, - middlewares: Iterable["BrokerMiddleware[MsgType]"], - parser: Callable[[MsgType], Awaitable["StreamMessage[MsgType]"]], - decoder: Callable[["StreamMessage[MsgType]"], "Any"], -) -> "StreamMessage[MsgType]": ... - - -async def process_msg( - msg: Optional[MsgType], - middlewares: Iterable["BrokerMiddleware[MsgType]"], - parser: Callable[[MsgType], Awaitable["StreamMessage[MsgType]"]], - decoder: Callable[["StreamMessage[MsgType]"], "Any"], -) -> Optional["StreamMessage[MsgType]"]: - if msg is None: - return None - - async with AsyncExitStack() as stack: - return_msg: Callable[ - [StreamMessage[MsgType]], - Awaitable[StreamMessage[MsgType]], - ] = return_input - - for m in middlewares: - mid = m(msg) - await stack.enter_async_context(mid) - return_msg = partial(mid.consume_scope, return_msg) - - parsed_msg = await parser(msg) - parsed_msg._decoded_body = await decoder(parsed_msg) - return await return_msg(parsed_msg) - - raise AssertionError("unreachable") diff --git a/faststream/broker/utils.py b/faststream/broker/utils.py index ceda2b6a5b..c12c3fc967 100644 --- a/faststream/broker/utils.py +++ b/faststream/broker/utils.py @@ -1,12 +1,14 @@ import asyncio import inspect -from contextlib import suppress +from contextlib import AsyncExitStack, suppress from functools import partial from typing import ( TYPE_CHECKING, Any, AsyncContextManager, + Awaitable, Callable, + Iterable, Optional, Type, Union, @@ -14,10 +16,11 @@ ) import anyio -from typing_extensions import Self +from typing_extensions import Literal, Self, overload from faststream.broker.acknowledgement_watcher import WatcherContext, get_watcher -from faststream.utils.functions import fake_context, to_async +from faststream.broker.types import MsgType +from faststream.utils.functions import fake_context, return_input, to_async if TYPE_CHECKING: from types import TracebackType @@ -25,12 +28,58 @@ from faststream.broker.message import StreamMessage from faststream.broker.types import ( AsyncCallable, + BrokerMiddleware, CustomCallable, SyncCallable, ) from faststream.types import LoggerProto +@overload +async def process_msg( + msg: Literal[None], + middlewares: Iterable["BrokerMiddleware[MsgType]"], + parser: Callable[[MsgType], Awaitable["StreamMessage[MsgType]"]], + decoder: Callable[["StreamMessage[MsgType]"], "Any"], +) -> None: ... + + +@overload +async def process_msg( + msg: MsgType, + middlewares: Iterable["BrokerMiddleware[MsgType]"], + parser: Callable[[MsgType], Awaitable["StreamMessage[MsgType]"]], + decoder: Callable[["StreamMessage[MsgType]"], "Any"], +) -> "StreamMessage[MsgType]": ... + + +async def process_msg( + msg: Optional[MsgType], + middlewares: Iterable["BrokerMiddleware[MsgType]"], + parser: Callable[[MsgType], Awaitable["StreamMessage[MsgType]"]], + decoder: Callable[["StreamMessage[MsgType]"], "Any"], +) -> Optional["StreamMessage[MsgType]"]: + if msg is None: + return None + + async with AsyncExitStack() as stack: + return_msg: Callable[ + [StreamMessage[MsgType]], + Awaitable[StreamMessage[MsgType]], + ] = return_input + + for m in middlewares: + mid = m(msg) + await stack.enter_async_context(mid) + return_msg = partial(mid.consume_scope, return_msg) + + parsed_msg = await parser(msg) + parsed_msg._decoded_body = await decoder(parsed_msg) + return await return_msg(parsed_msg) + + raise AssertionError("unreachable") + + async def default_filter(msg: "StreamMessage[Any]") -> bool: """A function to filter stream messages.""" return not msg.processed diff --git a/faststream/confluent/subscriber/usecase.py b/faststream/confluent/subscriber/usecase.py index 129c11c9ce..3a3d57d494 100644 --- a/faststream/confluent/subscriber/usecase.py +++ b/faststream/confluent/subscriber/usecase.py @@ -18,8 +18,8 @@ from faststream.broker.publisher.fake import FakePublisher from faststream.broker.subscriber.usecase import SubscriberUsecase -from faststream.broker.subscriber.utils import process_msg from faststream.broker.types import MsgType +from faststream.broker.utils import process_msg from faststream.confluent.parser import AsyncConfluentParser from faststream.confluent.schemas import TopicPartition diff --git a/faststream/kafka/subscriber/usecase.py b/faststream/kafka/subscriber/usecase.py index c012c5ab29..b14e107faf 100644 --- a/faststream/kafka/subscriber/usecase.py +++ b/faststream/kafka/subscriber/usecase.py @@ -20,13 +20,13 @@ from faststream.broker.publisher.fake import FakePublisher from faststream.broker.subscriber.usecase import SubscriberUsecase -from faststream.broker.subscriber.utils import process_msg from faststream.broker.types import ( AsyncCallable, BrokerMiddleware, CustomCallable, MsgType, ) +from faststream.broker.utils import process_msg from faststream.kafka.message import KafkaAckableMessage, KafkaMessage from faststream.kafka.parser import AioKafkaBatchParser, AioKafkaParser from faststream.utils.path import compile_path diff --git a/faststream/nats/subscriber/usecase.py b/faststream/nats/subscriber/usecase.py index 589d6bbe2e..e7b3e0ce01 100644 --- a/faststream/nats/subscriber/usecase.py +++ b/faststream/nats/subscriber/usecase.py @@ -26,8 +26,8 @@ from faststream.broker.publisher.fake import FakePublisher from faststream.broker.subscriber.usecase import SubscriberUsecase -from faststream.broker.subscriber.utils import process_msg from faststream.broker.types import MsgType +from faststream.broker.utils import process_msg from faststream.exceptions import NOT_CONNECTED_YET from faststream.nats.parser import ( BatchParser, diff --git a/faststream/rabbit/subscriber/usecase.py b/faststream/rabbit/subscriber/usecase.py index 66ddffbd3b..5ed53edf24 100644 --- a/faststream/rabbit/subscriber/usecase.py +++ b/faststream/rabbit/subscriber/usecase.py @@ -14,7 +14,7 @@ from faststream.broker.publisher.fake import FakePublisher from faststream.broker.subscriber.usecase import SubscriberUsecase -from faststream.broker.subscriber.utils import process_msg +from faststream.broker.utils import process_msg from faststream.exceptions import SetupError from faststream.rabbit.helpers.declarer import RabbitDeclarer from faststream.rabbit.parser import AioPikaParser diff --git a/faststream/redis/subscriber/usecase.py b/faststream/redis/subscriber/usecase.py index 23aaf574fd..a67095e986 100644 --- a/faststream/redis/subscriber/usecase.py +++ b/faststream/redis/subscriber/usecase.py @@ -24,7 +24,7 @@ from faststream.broker.publisher.fake import FakePublisher from faststream.broker.subscriber.usecase import SubscriberUsecase -from faststream.broker.subscriber.utils import process_msg +from faststream.broker.utils import process_msg from faststream.redis.message import ( BatchListMessage, BatchStreamMessage, From 5d7d5028431f9f66e978f96acf8865336641de26 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sun, 8 Sep 2024 20:30:43 +0300 Subject: [PATCH 58/62] lint: fix mypy --- faststream/confluent/subscriber/usecase.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/faststream/confluent/subscriber/usecase.py b/faststream/confluent/subscriber/usecase.py index 3a3d57d494..3540be9bdf 100644 --- a/faststream/confluent/subscriber/usecase.py +++ b/faststream/confluent/subscriber/usecase.py @@ -169,7 +169,7 @@ async def close(self) -> None: self.task = None @override - async def get_one( # type: ignore[override] + async def get_one( self, *, timeout: float = 5.0, From 584a2f9825dd5c89874d37d2ea5e2306ede50c48 Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Sun, 8 Sep 2024 21:13:29 +0300 Subject: [PATCH 59/62] Nats KV and Obj subscribers get_one tests --- tests/brokers/nats/test_consume.py | 67 ++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/tests/brokers/nats/test_consume.py b/tests/brokers/nats/test_consume.py index ab7452a415..013a872c01 100644 --- a/tests/brokers/nats/test_consume.py +++ b/tests/brokers/nats/test_consume.py @@ -607,3 +607,70 @@ async def publish(): assert message is not None assert await message.decode() == "test_message" + + async def test_get_one_kv( + self, + queue: str, + event: asyncio.Event, + stream: JStream, + ): + broker = self.get_broker(apply_types=True) + subscriber = broker.subscriber(queue, kv_watch=queue + "1") + + async with self.patch_broker(broker) as br: + await br.start() + bucket = await br.key_value(queue + "1") + + message = None + + async def consume(): + nonlocal message + message = await subscriber.get_one(timeout=5) + + async def publish(): + await bucket.put(queue, b"test_message") + + await asyncio.wait( + ( + asyncio.create_task(consume()), + asyncio.create_task(publish()), + ), + timeout=10, + ) + + assert message is not None + assert await message.decode() == b"test_message" + + async def test_get_one_os( + self, + queue: str, + event: asyncio.Event, + stream: JStream, + ): + broker = self.get_broker(apply_types=True) + subscriber = broker.subscriber(queue, obj_watch=True) + + async with self.patch_broker(broker) as br: + await br.start() + bucket = await br.object_storage(queue) + + new_object_id = None + + async def consume(): + nonlocal new_object_id + new_object_event = await subscriber.get_one(timeout=5) + new_object_id = await new_object_event.decode() + + async def publish(): + await bucket.put(queue, b"test_message") + + await asyncio.wait( + ( + asyncio.create_task(consume()), + asyncio.create_task(publish()), + ), + timeout=10, + ) + + new_object = await bucket.get(new_object_id) + assert new_object.data == b"test_message" From 2197c610dc7dbb3276c551007b99dcd354e1e7f5 Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Sun, 8 Sep 2024 21:21:13 +0300 Subject: [PATCH 60/62] Nats KV and Obj subscribers get_one timeout tests --- tests/brokers/nats/test_consume.py | 55 ++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/tests/brokers/nats/test_consume.py b/tests/brokers/nats/test_consume.py index 013a872c01..aa0bbd6370 100644 --- a/tests/brokers/nats/test_consume.py +++ b/tests/brokers/nats/test_consume.py @@ -641,6 +641,33 @@ async def publish(): assert message is not None assert await message.decode() == b"test_message" + async def test_get_one_kv_timeout( + self, + queue: str, + event: asyncio.Event, + stream: JStream, + ): + broker = self.get_broker(apply_types=True) + subscriber = broker.subscriber(queue, kv_watch=queue + "1") + + async with self.patch_broker(broker) as br: + await br.start() + + message = object() + + async def consume(): + nonlocal message + message = await subscriber.get_one(timeout=1e-24) + + await asyncio.wait( + ( + asyncio.create_task(consume()), + ), + timeout=10, + ) + + assert message is None + async def test_get_one_os( self, queue: str, @@ -674,3 +701,31 @@ async def publish(): new_object = await bucket.get(new_object_id) assert new_object.data == b"test_message" + + + async def test_get_one_os_timeout( + self, + queue: str, + event: asyncio.Event, + stream: JStream, + ): + broker = self.get_broker(apply_types=True) + subscriber = broker.subscriber(queue, obj_watch=True) + + async with self.patch_broker(broker) as br: + await br.start() + + new_object_event = object() + + async def consume(): + nonlocal new_object_event + new_object_event = await subscriber.get_one(timeout=1e-24) + + await asyncio.wait( + ( + asyncio.create_task(consume()), + ), + timeout=10, + ) + + assert new_object_event is None From 04b7bd7295720a9ad8b0c8de070ae44d3f9b4172 Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Sun, 8 Sep 2024 21:35:17 +0300 Subject: [PATCH 61/62] format fix --- tests/brokers/nats/test_consume.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/brokers/nats/test_consume.py b/tests/brokers/nats/test_consume.py index aa0bbd6370..c2c89c415c 100644 --- a/tests/brokers/nats/test_consume.py +++ b/tests/brokers/nats/test_consume.py @@ -660,9 +660,7 @@ async def consume(): message = await subscriber.get_one(timeout=1e-24) await asyncio.wait( - ( - asyncio.create_task(consume()), - ), + (asyncio.create_task(consume()),), timeout=10, ) @@ -722,9 +720,7 @@ async def consume(): new_object_event = await subscriber.get_one(timeout=1e-24) await asyncio.wait( - ( - asyncio.create_task(consume()), - ), + (asyncio.create_task(consume()),), timeout=10, ) From 149b22016e9d619fa0874c691982285be97e0986 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sun, 8 Sep 2024 23:05:40 +0300 Subject: [PATCH 62/62] tests: tefactor timeout tests --- tests/brokers/nats/test_consume.py | 55 +++++++----------------------- 1 file changed, 12 insertions(+), 43 deletions(-) diff --git a/tests/brokers/nats/test_consume.py b/tests/brokers/nats/test_consume.py index c2c89c415c..a8b9778e4d 100644 --- a/tests/brokers/nats/test_consume.py +++ b/tests/brokers/nats/test_consume.py @@ -490,6 +490,7 @@ async def test_get_one_pull_timeout( queue: str, event: asyncio.Event, stream: JStream, + mock: Mock, ): broker = self.get_broker(apply_types=True) subscriber = broker.subscriber( @@ -501,15 +502,8 @@ async def test_get_one_pull_timeout( async with self.patch_broker(broker) as br: await br.start() - message = object - - async def consume(): - nonlocal message - message = await subscriber.get_one(timeout=1e-24) - - await asyncio.wait((asyncio.create_task(consume()),), timeout=3) - - assert message is None + mock(await subscriber.get_one(timeout=1e-24)) + mock.assert_called_once_with(None) async def test_get_one_batch( self, @@ -552,6 +546,7 @@ async def test_get_one_batch_timeout( queue: str, event: asyncio.Event, stream: JStream, + mock: Mock, ): broker = self.get_broker(apply_types=True) subscriber = broker.subscriber( @@ -563,15 +558,8 @@ async def test_get_one_batch_timeout( async with self.patch_broker(broker) as br: await br.start() - message = object - - async def consume(): - nonlocal message - message = await subscriber.get_one(timeout=1e-24) - - await asyncio.wait((asyncio.create_task(consume()),), timeout=3) - - assert message is None + mock(await subscriber.get_one(timeout=1e-24)) + mock.assert_called_once_with(None) async def test_get_one_with_filter( self, @@ -646,6 +634,7 @@ async def test_get_one_kv_timeout( queue: str, event: asyncio.Event, stream: JStream, + mock: Mock, ): broker = self.get_broker(apply_types=True) subscriber = broker.subscriber(queue, kv_watch=queue + "1") @@ -653,18 +642,8 @@ async def test_get_one_kv_timeout( async with self.patch_broker(broker) as br: await br.start() - message = object() - - async def consume(): - nonlocal message - message = await subscriber.get_one(timeout=1e-24) - - await asyncio.wait( - (asyncio.create_task(consume()),), - timeout=10, - ) - - assert message is None + mock(await subscriber.get_one(timeout=1e-24)) + mock.assert_called_once_with(None) async def test_get_one_os( self, @@ -700,12 +679,12 @@ async def publish(): new_object = await bucket.get(new_object_id) assert new_object.data == b"test_message" - async def test_get_one_os_timeout( self, queue: str, event: asyncio.Event, stream: JStream, + mock: Mock, ): broker = self.get_broker(apply_types=True) subscriber = broker.subscriber(queue, obj_watch=True) @@ -713,15 +692,5 @@ async def test_get_one_os_timeout( async with self.patch_broker(broker) as br: await br.start() - new_object_event = object() - - async def consume(): - nonlocal new_object_event - new_object_event = await subscriber.get_one(timeout=1e-24) - - await asyncio.wait( - (asyncio.create_task(consume()),), - timeout=10, - ) - - assert new_object_event is None + mock(await subscriber.get_one(timeout=1e-24)) + mock.assert_called_once_with(None)