Skip to content

Commit

Permalink
feat: better RMQ AsyncAPI (#879)
Browse files Browse the repository at this point in the history
* fix (#868): remove pydantic  in AsyncAPI schema

* fix: better RMQ AsyncAPI

* chore: bump version
  • Loading branch information
Lancetnik authored Oct 20, 2023
1 parent caa452f commit 6873f1b
Show file tree
Hide file tree
Showing 17 changed files with 290 additions and 77 deletions.
2 changes: 1 addition & 1 deletion faststream/__about__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Simple and fast framework to create message brokers based microservices"""
__version__ = "0.2.7"
__version__ = "0.2.8"


INSTALL_YAML = """
Expand Down
46 changes: 22 additions & 24 deletions faststream/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import json
import os
import sys
from contextlib import suppress
from typing import Any, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union

from fast_depends._compat import PYDANTIC_V2 as PYDANTIC_V2
Expand Down Expand Up @@ -55,35 +54,34 @@ def is_installed(package: str) -> bool:
IS_OPTIMIZED = os.getenv("PYTHONOPTIMIZE", False)


if is_installed("fastapi"):
# NOTE: is_installed somewhen True for some reason
with suppress(ImportError):
from fastapi import __version__ as FASTAPI_VERSION
try:
from fastapi import __version__ as FASTAPI_VERSION

major, minor, _ = map(int, FASTAPI_VERSION.split("."))
FASTAPI_V2 = not (major <= 0 and minor < 100)
HAS_FASTAPI = True

if FASTAPI_V2:
from fastapi._compat import _normalize_errors
from fastapi.exceptions import RequestValidationError
major, minor, _ = map(int, FASTAPI_VERSION.split("."))
FASTAPI_V2 = major > 0 or minor > 100

def raise_fastapi_validation_error(
errors: List[Any], body: AnyDict
) -> Never:
raise RequestValidationError(_normalize_errors(errors), body=body)
if FASTAPI_V2:
from fastapi._compat import _normalize_errors
from fastapi.exceptions import RequestValidationError

else:
from pydantic import ( # type: ignore[assignment] # isort: skip
ValidationError as RequestValidationError,
)
from pydantic import create_model
def raise_fastapi_validation_error(errors: List[Any], body: AnyDict) -> Never:
raise RequestValidationError(_normalize_errors(errors), body=body)

ROUTER_VALIDATION_ERROR_MODEL = create_model("StreamRoute")
else:
from pydantic import ( # type: ignore[assignment] # isort: skip
ValidationError as RequestValidationError,
)
from pydantic import create_model

ROUTER_VALIDATION_ERROR_MODEL = create_model("StreamRoute")

def raise_fastapi_validation_error(errors: List[Any], body: AnyDict) -> Never:
raise RequestValidationError(errors, ROUTER_VALIDATION_ERROR_MODEL) # type: ignore[misc]

def raise_fastapi_validation_error(
errors: List[Any], body: AnyDict
) -> Never:
raise RequestValidationError(errors, ROUTER_VALIDATION_ERROR_MODEL) # type: ignore[misc]
except ImportError:
HAS_FASTAPI = False


JsonSchemaValue = Mapping[str, Any]
Expand Down
6 changes: 3 additions & 3 deletions faststream/asyncapi/generate.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Any, Dict, List, Union

from faststream._compat import is_installed, PYDANTIC_V2
from faststream.types import AnyDict
from faststream._compat import HAS_FASTAPI, PYDANTIC_V2
from faststream.app import FastStream
from faststream.asyncapi.schema import (
Channel,
Expand All @@ -13,8 +12,9 @@
Server,
)
from faststream.constants import ContentTypes
from faststream.types import AnyDict

if is_installed("fastapi"):
if HAS_FASTAPI:
from faststream.broker.fastapi.router import StreamRouter


Expand Down
7 changes: 5 additions & 2 deletions faststream/asyncapi/schema/bindings/amqp.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Literal, Optional
from typing import Literal, Optional

from pydantic import BaseModel, Field

Expand Down Expand Up @@ -92,5 +92,8 @@ class OperationBinding(BaseModel):

cc: Optional[str] = None
ack: bool = True
replyTo: Optional[Dict[str, Any]] = None
replyTo: Optional[str] = None
priority: Optional[int] = None
mandatory: Optional[bool] = None
deliveryMode: Optional[int] = None
bindingVersion: str = "0.2.0"
1 change: 1 addition & 0 deletions faststream/asyncapi/schema/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

if is_installed("email_validator"):
from pydantic import EmailStr

else: # pragma: no cover
# NOTE: EmailStr mock was copied from the FastAPI
# https://github.com/tiangolo/fastapi/blob/master/fastapi/openapi/models.py#24
Expand Down
24 changes: 18 additions & 6 deletions faststream/rabbit/asyncapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,13 @@ class Publisher(LogicPublisher):

@property
def name(self) -> str:
routing = (
self.routing_key
or (self.queue.routing if _is_exchange(self.exchange) else None)
or "_"
)
return (
self.title
or f"{self.queue.name}:{self.exchange.name if self.exchange else '_'}:Publisher"
self.title or f"{routing}:{getattr(self.exchange, 'name', '_')}:Publisher"
)

def schema(self) -> Dict[str, Channel]:
Expand All @@ -45,7 +49,11 @@ def schema(self) -> Dict[str, Channel]:
publish=Operation(
bindings=OperationBinding(
amqp=amqp.OperationBinding(
cc=self.queue.name,
cc=self.routing,
deliveryMode=2 if self.persist else 1,
mandatory=self.mandatory,
replyTo=self.reply_to,
priority=self.priority,
),
)
if _is_exchange(self.exchange)
Expand All @@ -71,8 +79,9 @@ def schema(self) -> Dict[str, Channel]:
durable=self.queue.durable,
exclusive=self.queue.exclusive,
autoDelete=self.queue.auto_delete,
vhost=self.virtual_host,
)
if _is_exchange(self.exchange)
if _is_exchange(self.exchange) and self.queue.name
else None,
"exchange": (
amqp.Exchange(type="default")
Expand All @@ -82,6 +91,7 @@ def schema(self) -> Dict[str, Channel]:
name=self.exchange.name,
durable=self.exchange.durable,
autoDelete=self.exchange.auto_delete,
vhost=self.virtual_host,
)
),
}
Expand All @@ -107,7 +117,7 @@ def schema(self) -> Dict[str, Channel]:

handler_name = (
self._title
or f"{self.queue.name}:{self.exchange.name if self.exchange else '_'}:{self.call_name}"
or f"{self.queue.name}:{getattr(self.exchange, 'name', '_')}:{self.call_name}"
)

return {
Expand All @@ -116,7 +126,7 @@ def schema(self) -> Dict[str, Channel]:
subscribe=Operation(
bindings=OperationBinding(
amqp=amqp.OperationBinding(
cc=self.queue.name,
cc=self.queue.routing,
),
)
if _is_exchange(self.exchange)
Expand All @@ -138,6 +148,7 @@ def schema(self) -> Dict[str, Channel]:
durable=self.queue.durable,
exclusive=self.queue.exclusive,
autoDelete=self.queue.auto_delete,
vhost=self.virtual_host,
)
if _is_exchange(self.exchange)
else None,
Expand All @@ -149,6 +160,7 @@ def schema(self) -> Dict[str, Channel]:
name=self.exchange.name,
durable=self.exchange.durable,
autoDelete=self.exchange.auto_delete,
vhost=self.virtual_host,
)
),
}
Expand Down
62 changes: 40 additions & 22 deletions faststream/rabbit/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __init__(
self,
url: Union[str, URL, None] = "amqp://guest:guest@localhost:5672/",
*,
virtualhost: str = "/",
max_consumers: Optional[int] = None,
protocol: str = "amqp",
protocol_version: Optional[str] = "0.9.1",
Expand All @@ -93,11 +94,21 @@ def __init__(
protocol_version (Optional[str], optional): The protocol version to use (e.g., "0.9.1"). Defaults to "0.9.1".
**kwargs: Additional keyword arguments.
"""
if url is not None:
if not isinstance(url, URL):
url = URL(url)

self.virtual_host = url.path
url = str(url)
else:
self.virtual_host = virtualhost

super().__init__(
url=url,
protocol=protocol,
protocol_version=protocol_version,
security=security,
virtualhost=virtualhost,
**kwargs,
)

Expand Down Expand Up @@ -125,11 +136,8 @@ async def _close(
await self._channel.close()
self._channel = None

if self.declarer is not None:
self.declarer = None

if self._producer is not None:
self._producer = None
self.declarer = None
self._producer = None

if self._connection is not None: # pragma: no branch
await self._connection.close()
Expand Down Expand Up @@ -214,6 +222,10 @@ async def start(self) -> None:
self.declarer
), "Declarer should be initialized in `connect` method"

for publisher in self._publishers.values():
if publisher.exchange is not None:
await self.declare_exchange(publisher.exchange)

for handler in self.handlers.values():
c = self._get_log_context(None, handler.queue, handler.exchange)
self._log(f"`{handler.call_name}` waiting for messages", extra=c)
Expand Down Expand Up @@ -274,6 +286,7 @@ def subscriber( # type: ignore[override]
consume_args=consume_args,
description=description,
title=title,
virtual_host=self.virtual_host,
),
)

Expand Down Expand Up @@ -333,6 +346,7 @@ def publisher( # type: ignore[override]
title: Optional[str] = None,
description: Optional[str] = None,
schema: Optional[Any] = None,
priority: Optional[int] = None,
**message_kwargs: Any,
) -> Publisher:
"""
Expand All @@ -355,25 +369,29 @@ def publisher( # type: ignore[override]
Publisher: A message publisher instance.
"""
q, ex = RabbitQueue.validate(queue), RabbitExchange.validate(exchange)
key = get_routing_hash(q, ex)
publisher = self._publishers.get(
key,
Publisher(
title=title,
queue=q,
exchange=ex,
routing_key=routing_key,
mandatory=mandatory,
immediate=immediate,
timeout=timeout,
persist=persist,
reply_to=reply_to,
message_kwargs=message_kwargs,
_description=description,
_schema=schema,
),

publisher = Publisher(
title=title,
queue=q,
exchange=ex,
routing_key=routing_key,
mandatory=mandatory,
immediate=immediate,
timeout=timeout,
persist=persist,
reply_to=reply_to,
priority=priority,
message_kwargs=message_kwargs,
_description=description,
_schema=schema,
virtual_host=self.virtual_host,
)

key = publisher._get_routing_hash()
publisher = self._publishers.get(key, publisher)
super().publisher(key, publisher)
if self._producer is not None:
publisher._producer = self._producer
return publisher

@override
Expand Down
2 changes: 2 additions & 0 deletions faststream/rabbit/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __init__(
# AsyncAPI information
description: Optional[str] = None,
title: Optional[str] = None,
virtual_host: str = "/",
):
"""Initialize a RabbitMQ consumer.
Expand All @@ -85,6 +86,7 @@ def __init__(

self.queue = queue
self.exchange = exchange
self.virtual_host = virtual_host
self.consume_args = consume_args or {}

self._consumer_tag = None
Expand Down
2 changes: 1 addition & 1 deletion faststream/rabbit/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ async def parse_message(
The above docstring is autogenerated by docstring-gen library (https://docstring-gen.airt.ai)
"""
handler = context.get("handler_")
handler = context.get_local("handler_")
path: AnyDict = {}
path_re: Optional[Pattern[str]]
if handler and (path_re := handler.queue.path_regex):
Expand Down
13 changes: 11 additions & 2 deletions faststream/rabbit/publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from faststream._compat import override
from faststream.rabbit.producer import AioPikaFastProducer
from faststream.rabbit.shared.publisher import ABCPublisher
from faststream.rabbit.shared.schemas import get_routing_hash
from faststream.rabbit.types import AioPikaSendableMessage
from faststream.types import SendableMessage

Expand All @@ -27,6 +28,13 @@ class LogicPublisher(ABCPublisher[IncomingMessage]):

_producer: Optional[AioPikaFastProducer] = field(default=None, init=False)

@property
def routing(self) -> Optional[str]:
return self.routing_key or self.queue.routing

def _get_routing_hash(self) -> int:
return get_routing_hash(self.queue, self.exchange) + hash(self.routing_key)

@override
async def publish( # type: ignore[override]
self,
Expand All @@ -36,6 +44,7 @@ async def publish( # type: ignore[override]
rpc_timeout: Optional[float] = 30.0,
raise_timeout: bool = False,
correlation_id: Optional[str] = None,
priority: Optional[int] = None,
**message_kwargs: Any,
) -> Union[aiormq.abc.ConfirmationFrameType, SendableMessage]:
"""Publish a message.
Expand All @@ -60,9 +69,8 @@ async def publish( # type: ignore[override]
assert self._producer, "Please, setup `_producer` first" # nosec B101
return await self._producer.publish(
message=message,
queue=self.queue,
exchange=self.exchange,
routing_key=self.routing_key,
routing_key=self.routing,
mandatory=self.mandatory,
immediate=self.immediate,
timeout=self.timeout,
Expand All @@ -72,6 +80,7 @@ async def publish( # type: ignore[override]
persist=self.persist,
reply_to=self.reply_to,
correlation_id=correlation_id,
priority=priority or self.priority,
**self.message_kwargs,
**message_kwargs,
)
Loading

0 comments on commit 6873f1b

Please sign in to comment.