Skip to content

Commit

Permalink
fix (#1759): cast Enums to str in ConfluentConfig
Browse files Browse the repository at this point in the history
  • Loading branch information
Lancetnik committed Sep 5, 2024
1 parent 1adc2c3 commit 9aea482
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 45 deletions.
3 changes: 2 additions & 1 deletion faststream/confluent/broker/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
AsyncConfluentConsumer,
AsyncConfluentProducer,
)
from faststream.confluent.config import ConfluentFastConfig
from faststream.confluent.publisher.producer import AsyncConfluentFastProducer
from faststream.confluent.schemas.params import ConsumerConnectionParams
from faststream.confluent.security import parse_security
Expand Down Expand Up @@ -394,7 +395,7 @@ def __init__(
)
self.client_id = client_id
self._producer = None
self.config = config
self.config = ConfluentFastConfig(config)

async def _close(
self,
Expand Down
39 changes: 18 additions & 21 deletions faststream/confluent/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from confluent_kafka import Consumer, KafkaError, KafkaException, Message, Producer
from confluent_kafka.admin import AdminClient, NewTopic

from faststream.confluent.config import ConfluentConfig
from faststream.confluent import config as config_module
from faststream.confluent.schemas import TopicPartition
from faststream.exceptions import SetupError
from faststream.log import logger as faststream_logger
Expand All @@ -34,6 +34,7 @@ def __init__(
self,
*,
logger: Optional["LoggerProto"],
config: config_module.ConfluentFastConfig,
bootstrap_servers: Union[str, List[str]] = "localhost",
client_id: Optional[str] = None,
metadata_max_age_ms: int = 300000,
Expand All @@ -53,12 +54,9 @@ def __init__(
sasl_mechanism: Optional[str] = None,
sasl_plain_password: Optional[str] = None,
sasl_plain_username: Optional[str] = None,
config: Optional[ConfluentConfig] = None,
) -> None:
self.logger = logger

self.config: Dict[str, Any] = {} if config is None else dict(config)

if isinstance(bootstrap_servers, Iterable) and not isinstance(
bootstrap_servers, str
):
Expand Down Expand Up @@ -89,18 +87,19 @@ def __init__(
"connections.max.idle.ms": connections_max_idle_ms,
"allow.auto.create.topics": allow_auto_create_topics,
}
self.config = {**self.config, **config_from_params}

final_config = {**config.as_config_dict(), **config_from_params}

if sasl_mechanism in ["PLAIN", "SCRAM-SHA-256", "SCRAM-SHA-512"]:
self.config.update(
final_config.update(
{
"sasl.mechanism": sasl_mechanism,
"sasl.username": sasl_plain_username,
"sasl.password": sasl_plain_password,
}
)

self.producer = Producer(self.config, logger=self.logger)
self.producer = Producer(final_config, logger=self.logger)

async def stop(self) -> None:
"""Stop the Kafka producer and flush remaining messages."""
Expand Down Expand Up @@ -180,6 +179,7 @@ def __init__(
*topics: str,
partitions: Sequence["TopicPartition"],
logger: Optional["LoggerProto"],
config: config_module.ConfluentFastConfig,
bootstrap_servers: Union[str, List[str]] = "localhost",
client_id: Optional[str] = "confluent-kafka-consumer",
group_id: Optional[str] = None,
Expand All @@ -205,18 +205,9 @@ def __init__(
sasl_mechanism: Optional[str] = None,
sasl_plain_password: Optional[str] = None,
sasl_plain_username: Optional[str] = None,
config: Optional[ConfluentConfig] = None,
) -> None:
self.logger = logger

self.config: Dict[str, Any] = {} if config is None else dict(config)

if group_id is None:
group_id = self.config.get("group.id", "faststream-consumer-group")

if group_instance_id is None:
group_instance_id = self.config.get("group.instance.id", None)

if isinstance(bootstrap_servers, Iterable) and not isinstance(
bootstrap_servers, str
):
Expand All @@ -232,13 +223,18 @@ def __init__(
for x in partition_assignment_strategy
]
)

final_config = config.as_config_dict()

config_from_params = {
"allow.auto.create.topics": allow_auto_create_topics,
"topic.metadata.refresh.interval.ms": 1000,
"bootstrap.servers": bootstrap_servers,
"client.id": client_id,
"group.id": group_id,
"group.instance.id": group_instance_id,
"group.id": group_id
or final_config.get("group.id", "faststream-consumer-group"),
"group.instance.id": group_instance_id
or final_config.get("group.instance.id", None),
"fetch.wait.max.ms": fetch_max_wait_ms,
"fetch.max.bytes": fetch_max_bytes,
"fetch.min.bytes": fetch_min_bytes,
Expand All @@ -259,18 +255,19 @@ def __init__(
"isolation.level": isolation_level,
}
self.allow_auto_create_topics = allow_auto_create_topics
self.config = {**self.config, **config_from_params}
final_config.update(config_from_params)

if sasl_mechanism in ["PLAIN", "SCRAM-SHA-256", "SCRAM-SHA-512"]:
self.config.update(
final_config.update(
{
"sasl.mechanism": sasl_mechanism,
"sasl.username": sasl_plain_username,
"sasl.password": sasl_plain_password,
}
)

self.consumer = Consumer(self.config, logger=self.logger)
self.config = final_config
self.consumer = Consumer(self.final_config, logger=self.logger)

@property
def topics_to_create(self) -> List[str]:
Expand Down
78 changes: 55 additions & 23 deletions faststream/confluent/config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from enum import Enum
from typing import Any, Callable
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

from typing_extensions import TypedDict

if TYPE_CHECKING:
from faststream.types import AnyDict

class BuiltinFeatures(Enum):

class BuiltinFeatures(str, Enum):
gzip = "gzip"
snappy = "snappy"
ssl = "ssl"
Expand All @@ -21,7 +24,7 @@ class BuiltinFeatures(Enum):
oidc = "oidc"


class Debug(Enum):
class Debug(str, Enum):
generic = "generic"
broker = "broker"
topic = "topic"
Expand All @@ -44,57 +47,57 @@ class Debug(Enum):
all = "all"


class BrokerAddressFamily(Enum):
class BrokerAddressFamily(str, Enum):
any = "any"
v4 = "v4"
v6 = "v6"


class SecurityProtocol(Enum):
class SecurityProtocol(str, Enum):
plaintext = "plaintext"
ssl = "ssl"
sasl_plaintext = "sasl_plaintext"
sasl_ssl = "sasl_ssl"


class SASLOAUTHBearerMethod(Enum):
class SASLOAUTHBearerMethod(str, Enum):
default = "default"
oidc = "oidc"


class GroupProtocol(Enum):
class GroupProtocol(str, Enum):
classic = "classic"
consumer = "consumer"


class OffsetStoreMethod(Enum):
class OffsetStoreMethod(str, Enum):
none = "none"
file = "file"
broker = "broker"


class IsolationLevel(Enum):
class IsolationLevel(str, Enum):
read_uncommitted = "read_uncommitted"
read_committed = "read_committed"


class CompressionCodec(Enum):
class CompressionCodec(str, Enum):
none = "none"
gzip = "gzip"
snappy = "snappy"
lz4 = "lz4"
zstd = "zstd"


class CompressionType(Enum):
class CompressionType(str, Enum):
none = "none"
gzip = "gzip"
snappy = "snappy"
lz4 = "lz4"
zstd = "zstd"


class ClientDNSLookup(Enum):
class ClientDNSLookup(str, Enum):
use_all_dns_ips = "use_all_dns_ips"
resolve_canonical_bootstrap_servers_only = (
"resolve_canonical_bootstrap_servers_only"
Expand All @@ -104,7 +107,17 @@ class ClientDNSLookup(Enum):
ConfluentConfig = TypedDict(
"ConfluentConfig",
{
"builtin.features": BuiltinFeatures,
"compression.codec": Union[CompressionCodec, str],
"compression.type": Union[CompressionType, str],
"client.dns.lookup": Union[ClientDNSLookup, str],
"offset.store.method": Union[OffsetStoreMethod, str],
"isolation.level": Union[IsolationLevel, str],
"sasl.oauthbearer.method": Union[SASLOAUTHBearerMethod, str],
"security.protocol": Union[SecurityProtocol, str],
"broker.address.family": Union[BrokerAddressFamily, str],
"builtin.features": Union[BuiltinFeatures, str],
"debug": Union[Debug, str],
"group.protocol": Union[GroupProtocol, str],
"client.id": str,
"metadata.broker.list": str,
"bootstrap.servers": str,
Expand All @@ -120,7 +133,6 @@ class ClientDNSLookup(Enum):
"topic.metadata.refresh.sparse": bool,
"topic.metadata.propagation.max.ms": int,
"topic.blacklist": str,
"debug": Debug,
"socket.timeout.ms": int,
"socket.blocking.max.ms": int,
"socket.send.buffer.bytes": int,
Expand All @@ -129,7 +141,6 @@ class ClientDNSLookup(Enum):
"socket.nagle.disable": bool,
"socket.max.fails": int,
"broker.address.ttl": int,
"broker.address.family": BrokerAddressFamily,
"socket.connection.setup.timeout.ms": int,
"connections.max.idle.ms": int,
"reconnect.backoff.jitter.ms": int,
Expand Down Expand Up @@ -160,7 +171,6 @@ class ClientDNSLookup(Enum):
"api.version.fallback.ms": int,
"broker.version.fallback": str,
"allow.auto.create.topics": bool,
"security.protocol": SecurityProtocol,
"ssl.cipher.suites": str,
"ssl.curves.list": str,
"ssl.sigalgs.list": str,
Expand Down Expand Up @@ -197,7 +207,6 @@ class ClientDNSLookup(Enum):
"sasl.oauthbearer.config": str,
"enable.sasl.oauthbearer.unsecure.jwt": bool,
"oauthbearer_token_refresh_cb": Callable[..., Any],
"sasl.oauthbearer.method": SASLOAUTHBearerMethod,
"sasl.oauthbearer.client.id": str,
"sasl.oauthbearer.client.secret": str,
"sasl.oauthbearer.scope": str,
Expand All @@ -211,7 +220,6 @@ class ClientDNSLookup(Enum):
"session.timeout.ms": str,
"heartbeat.interval.ms": str,
"group.protocol.type": str,
"group.protocol": GroupProtocol,
"group.remote.assignor": str,
"coordinator.query.interval.ms": int,
"max.poll.interval.ms": int,
Expand All @@ -227,8 +235,6 @@ class ClientDNSLookup(Enum):
"fetch.max.bytes": int,
"fetch.min.bytes": int,
"fetch.error.backoff.ms": int,
"offset.store.method": OffsetStoreMethod,
"isolation.level": IsolationLevel,
"consume_cb": Callable[..., Any],
"rebalance_cb": Callable[..., Any],
"offset_commit_cb": Callable[..., Any],
Expand All @@ -248,15 +254,41 @@ class ClientDNSLookup(Enum):
"retry.backoff.ms": int,
"retry.backoff.max.ms": int,
"queue.buffering.backpressure.threshold": int,
"compression.codec": CompressionCodec,
"compression.type": CompressionType,
"batch.num.messages": int,
"batch.size": int,
"delivery.report.only.error": bool,
"dr_cb": Callable[..., Any],
"dr_msg_cb": Callable[..., Any],
"sticky.partitioning.linger.ms": int,
"client.dns.lookup": ClientDNSLookup,
},
total=False,
)


class ConfluentFastConfig:
def __init__(self, config: Optional[ConfluentConfig]) -> None:
self.config = config

def as_config_dict(self) -> "AnyDict":
if not self.config:
return {}

data = dict(self.config)

for key, enum in (
("compression.codec", CompressionCodec),
("compression.type", CompressionType),
("client.dns.lookup", ClientDNSLookup),
("offset.store.method", OffsetStoreMethod),
("isolation.level", IsolationLevel),
("sasl.oauthbearer.method", SASLOAUTHBearerMethod),
("security.protocol", SecurityProtocol),
("broker.address.family", BrokerAddressFamily),
("builtin.features", BuiltinFeatures),
("debug", Debug),
("group.protocol", GroupProtocol),
):
if key in data:
data[key] = enum(data[key]).value

return data

0 comments on commit 9aea482

Please sign in to comment.