Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AI-170: Add new broker type: dev_broker #51

Merged
merged 9 commits into from
Oct 15, 2024
104 changes: 104 additions & 0 deletions src/solace_ai_connector/common/messaging/dev_broker_messaging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""This is a simple broker for testing purposes. It allows sending and receiving
messages to/from queues. It supports subscriptions based on topics."""

from typing import Dict, List, Any
import queue
import re
from copy import deepcopy
from .messaging import Messaging


class DevBroker(Messaging):
def __init__(self, broker_properties: dict, flow_lock_manager, flow_kv_store):
super().__init__(broker_properties)
self.flow_lock_manager = flow_lock_manager
self.flow_kv_store = flow_kv_store
self.connected = False
self.subscriptions_lock = self.flow_lock_manager.get_lock("subscriptions")
with self.subscriptions_lock:
self.subscriptions = self.flow_kv_store.get("dev_broker:subscriptions")
if self.subscriptions is None:
self.subscriptions: Dict[str, List[str]] = {}
self.flow_kv_store.set("dev_broker:subscriptions", self.subscriptions)
self.queues = self.flow_kv_store.get("dev_broker:queues")
if self.queues is None:
self.queues: Dict[str, queue.Queue] = {}
self.flow_kv_store.set("dev_broker:queues", self.queues)

def connect(self):
self.connected = True
queue_name = self.broker_properties.get("queue_name")
subscriptions = self.broker_properties.get("subscriptions", [])
if queue_name:
self.queues[queue_name] = queue.Queue()
for subscription in subscriptions:
self.subscribe(subscription["topic"], queue_name)

def disconnect(self):
self.connected = False

def receive_message(self, timeout_ms, queue_name: str):
if not self.connected:
raise RuntimeError("DevBroker is not connected")

try:
return self.queues[queue_name].get(timeout=timeout_ms / 1000)
except queue.Empty:
return None

def send_message(
self,
destination_name: str,
payload: Any,
user_properties: Dict = None,
user_context: Dict = None,
):
if not self.connected:
raise RuntimeError("DevBroker is not connected")

message = {
"payload": payload,
"topic": destination_name,
"user_properties": user_properties or {},
}

matching_queue_names = self._get_matching_queue_names(destination_name)

for queue_name in matching_queue_names:
# Clone the message for each queue to ensure isolation
self.queues[queue_name].put(deepcopy(message))

if user_context and "callback" in user_context:
user_context["callback"](user_context)

def subscribe(self, subscription: str, queue_name: str):
if not self.connected:
raise RuntimeError("DevBroker is not connected")

subscription = self._subscription_to_regex(subscription)

with self.subscriptions_lock:
if queue_name not in self.queues:
self.queues[queue_name] = queue.Queue()
if subscription not in self.subscriptions:
self.subscriptions[subscription] = []
self.subscriptions[subscription].append(queue_name)

def ack_message(self, message):
pass

def _get_matching_queue_names(self, topic: str) -> List[str]:
matching_queue_names = []
with self.subscriptions_lock:
for subscription, queue_names in self.subscriptions.items():
if self._topic_matches(subscription, topic):
matching_queue_names.extend(queue_names)
return list(set(matching_queue_names)) # Remove duplicates

@staticmethod
def _topic_matches(subscription: str, topic: str) -> bool:
return re.match(f"^{subscription}$", topic) is not None

@staticmethod
def _subscription_to_regex(subscription: str) -> str:
return subscription.replace("*", "[^/]+").replace(">", ".*")
20 changes: 10 additions & 10 deletions src/solace_ai_connector/common/messaging/messaging.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# messaging.py - Base class for EDA messaging services
from typing import Any, Dict


class Messaging:
Expand All @@ -11,14 +11,14 @@ def connect(self):
def disconnect(self):
raise NotImplementedError

def receive_message(self, timeout_ms):
def receive_message(self, timeout_ms, queue_id: str):
raise NotImplementedError

# def is_connected(self):
# raise NotImplementedError

# def send_message(self, destination_name: str, message: str):
# raise NotImplementedError

# def subscribe(self, subscription: str, message_handler): #: MessageHandler):
# raise NotImplementedError
def send_message(
self,
destination_name: str,
payload: Any,
user_properties: Dict = None,
user_context: Dict = None,
):
raise NotImplementedError
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
"""Class to build a Messaging Service object"""

from .solace_messaging import SolaceMessaging
from .dev_broker_messaging import DevBroker


# Make a Messaging Service builder - this is a factory for Messaging Service objects
class MessagingServiceBuilder:
def __init__(self):
def __init__(self, flow_lock_manager, flow_kv_store):
self.broker_properties = {}
self.flow_lock_manager = flow_lock_manager
self.flow_kv_store = flow_kv_store

def from_properties(self, broker_properties: dict):
self.broker_properties = broker_properties
Expand All @@ -15,6 +18,10 @@ def from_properties(self, broker_properties: dict):
def build(self):
if self.broker_properties["broker_type"] == "solace":
return SolaceMessaging(self.broker_properties)
elif self.broker_properties["broker_type"] == "dev_broker":
return DevBroker(
self.broker_properties, self.flow_lock_manager, self.flow_kv_store
)

raise ValueError(
f"Unsupported broker type: {self.broker_properties['broker_type']}"
Expand Down
20 changes: 17 additions & 3 deletions src/solace_ai_connector/common/messaging/solace_messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,19 @@ def send_message(
user_context=user_context,
)

def receive_message(self, timeout_ms):
return self.persistent_receivers[0].receive_message(timeout_ms)
def receive_message(self, timeout_ms, queue_id):
broker_message = self.persistent_receivers[0].receive_message(timeout_ms)
if broker_message is None:
return None

# Convert Solace message to dictionary format
return {
"payload": broker_message.get_payload_as_string()
or broker_message.get_payload_as_bytes(),
"topic": broker_message.get_destination_name(),
"user_properties": broker_message.get_properties(),
"_original_message": broker_message, # Keep original message for acknowledgement
}
cyrus2281 marked this conversation as resolved.
Show resolved Hide resolved

def subscribe(
self, subscription: str, persistent_receiver: PersistentMessageReceiver
Expand All @@ -256,4 +267,7 @@ def subscribe(
persistent_receiver.add_subscription(sub)

def ack_message(self, broker_message):
self.persistent_receiver.ack(broker_message)
if "_original_message" in broker_message:
self.persistent_receiver.ack(broker_message["_original_message"])
else:
log.warning("Cannot acknowledge message: original Solace message not found")
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, module_info, **kwargs):
self.broker_properties = self.get_broker_properties()
if self.broker_properties["broker_type"] not in ["test", "test_streaming"]:
self.messaging_service = (
MessagingServiceBuilder()
MessagingServiceBuilder(self.flow_lock_manager, self.flow_kv_store)
.from_properties(self.broker_properties)
.build()
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,18 +110,18 @@ def invoke(self, message, data):
def get_next_message(self, timeout_ms=None):
if timeout_ms is None:
timeout_ms = DEFAULT_TIMEOUT_MS
broker_message = self.messaging_service.receive_message(timeout_ms)
broker_message = self.messaging_service.receive_message(
timeout_ms, self.broker_properties["queue_name"]
)
if not broker_message:
return None
self.current_broker_message = broker_message

payload = broker_message.get_payload_as_string()
if payload is None:
payload = broker_message.get_payload_as_bytes()
payload = broker_message.get("payload")
payload = self.decode_payload(payload)

topic = broker_message.get_destination_name()
user_properties = broker_message.get_properties()
topic = broker_message.get("topic")
user_properties = broker_message.get("user_properties", {})
log.debug(
"Received message from broker: topic=%s, user_properties=%s, payload length=%d",
topic,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,12 @@ def __init__(self, **kwargs):
]
self.test_mode = False

if self.broker_type == "solace":
self.connect()
elif self.broker_type == "test" or self.broker_type == "test_streaming":
if self.broker_type == "test" or self.broker_type == "test_streaming":
self.test_mode = True
self.setup_test_pass_through()
else:
self.connect()

self.start()

def start(self):
Expand All @@ -224,7 +225,9 @@ def start_response_thread(self):
def handle_responses(self):
while not self.stop_signal.is_set():
try:
broker_message = self.messaging_service.receive_message(1000)
broker_message = self.messaging_service.receive_message(
1000, self.reply_queue_name
)
if broker_message:
self.process_response(broker_message)
except Exception as e:
Expand All @@ -248,12 +251,10 @@ def process_response(self, broker_message):
topic = broker_message.get_topic()
user_properties = broker_message.get_user_properties()
else:
payload = broker_message.get_payload_as_string()
if payload is None:
payload = broker_message.get_payload_as_bytes()
payload = broker_message.get("payload")
payload = self.decode_payload(payload)
topic = broker_message.get_destination_name()
user_properties = broker_message.get_properties()
topic = broker_message.get("topic")
user_properties = broker_message.get("user_properties", {})

metadata_json = user_properties.get(
"__solace_ai_connector_broker_request_reply_metadata__"
Expand Down
Loading