diff --git a/examples/websocket/websocket.yaml b/examples/websocket/websocket.yaml new file mode 100644 index 00000000..4efabfaf --- /dev/null +++ b/examples/websocket/websocket.yaml @@ -0,0 +1,42 @@ +--- + # Example configuration for a WebSocket flow + # This flow creates a WebSocket server that echoes messages back to clients + # It also serves an example HTML file for easy testing + + log: + stdout_log_level: INFO + log_file_level: DEBUG + log_file: solace_ai_connector.log + + flows: + - name: websocket_echo + components: + # WebSocket Input + - component_name: websocket_input + component_module: websocket_input + component_config: + listen_port: 5000 + serve_html: true + html_path: "examples/websocket/websocket_example_app.html" + + # Pass Through + - component_name: pass_through + component_module: pass_through + component_config: {} + input_transforms: + - type: copy + source_expression: input.payload + dest_expression: user_data.input:payload + - type: copy + source_expression: input.user_properties:socket_id + dest_expression: user_data.input:socket_id + input_selection: + source_expression: user_data.input + + # WebSocket Output + - component_name: websocket_output + component_module: websocket_output + component_config: + payload_encoding: none + input_selection: + source_expression: previous diff --git a/examples/websocket/websocket_example_app.html b/examples/websocket/websocket_example_app.html new file mode 100644 index 00000000..02d54dbc --- /dev/null +++ b/examples/websocket/websocket_example_app.html @@ -0,0 +1,266 @@ + + + + + + WebSocket Example App + + + + +
+

WebSocket Example App

+

This is a simple app to show how JSON can be sent into a solace-ai-connector flow and how to receive output from it. Just hit connect to connect to the flow, type in some JSON and hit send. Your JSON should be echoed back to you.

+ +
+ + + Disconnected +
+ +
+ + +
+
+ +
+

Received Messages

+
+ +
+
+ + + + + diff --git a/src/solace_ai_connector/common/messaging/dev_broker_messaging.py b/src/solace_ai_connector/common/messaging/dev_broker_messaging.py new file mode 100644 index 00000000..0d10cd6b --- /dev/null +++ b/src/solace_ai_connector/common/messaging/dev_broker_messaging.py @@ -0,0 +1,104 @@ +"""This is a simple broker for testing purposes. It allows sending and receiving +messages to/from queues. It supports subscriptions based on topics.""" + +from typing import Dict, List, Any +import queue +import re +from copy import deepcopy +from .messaging import Messaging + + +class DevBroker(Messaging): + def __init__(self, broker_properties: dict, flow_lock_manager, flow_kv_store): + super().__init__(broker_properties) + self.flow_lock_manager = flow_lock_manager + self.flow_kv_store = flow_kv_store + self.connected = False + self.subscriptions_lock = self.flow_lock_manager.get_lock("subscriptions") + with self.subscriptions_lock: + self.subscriptions = self.flow_kv_store.get("dev_broker:subscriptions") + if self.subscriptions is None: + self.subscriptions: Dict[str, List[str]] = {} + self.flow_kv_store.set("dev_broker:subscriptions", self.subscriptions) + self.queues = self.flow_kv_store.get("dev_broker:queues") + if self.queues is None: + self.queues: Dict[str, queue.Queue] = {} + self.flow_kv_store.set("dev_broker:queues", self.queues) + + def connect(self): + self.connected = True + queue_name = self.broker_properties.get("queue_name") + subscriptions = self.broker_properties.get("subscriptions", []) + if queue_name: + self.queues[queue_name] = queue.Queue() + for subscription in subscriptions: + self.subscribe(subscription["topic"], queue_name) + + def disconnect(self): + self.connected = False + + def receive_message(self, timeout_ms, queue_name: str): + if not self.connected: + raise RuntimeError("DevBroker is not connected") + + try: + return self.queues[queue_name].get(timeout=timeout_ms / 1000) + except queue.Empty: + return None + + def send_message( + self, + destination_name: str, + payload: Any, + user_properties: Dict = None, + user_context: Dict = None, + ): + if not self.connected: + raise RuntimeError("DevBroker is not connected") + + message = { + "payload": payload, + "topic": destination_name, + "user_properties": user_properties or {}, + } + + matching_queue_names = self._get_matching_queue_names(destination_name) + + for queue_name in matching_queue_names: + # Clone the message for each queue to ensure isolation + self.queues[queue_name].put(deepcopy(message)) + + if user_context and "callback" in user_context: + user_context["callback"](user_context) + + def subscribe(self, subscription: str, queue_name: str): + if not self.connected: + raise RuntimeError("DevBroker is not connected") + + subscription = self._subscription_to_regex(subscription) + + with self.subscriptions_lock: + if queue_name not in self.queues: + self.queues[queue_name] = queue.Queue() + if subscription not in self.subscriptions: + self.subscriptions[subscription] = [] + self.subscriptions[subscription].append(queue_name) + + def ack_message(self, message): + pass + + def _get_matching_queue_names(self, topic: str) -> List[str]: + matching_queue_names = [] + with self.subscriptions_lock: + for subscription, queue_names in self.subscriptions.items(): + if self._topic_matches(subscription, topic): + matching_queue_names.extend(queue_names) + return list(set(matching_queue_names)) # Remove duplicates + + @staticmethod + def _topic_matches(subscription: str, topic: str) -> bool: + return re.match(f"^{subscription}$", topic) is not None + + @staticmethod + def _subscription_to_regex(subscription: str) -> str: + return subscription.replace("*", "[^/]+").replace(">", ".*") diff --git a/src/solace_ai_connector/common/messaging/messaging.py b/src/solace_ai_connector/common/messaging/messaging.py index 08448635..5847effd 100644 --- a/src/solace_ai_connector/common/messaging/messaging.py +++ b/src/solace_ai_connector/common/messaging/messaging.py @@ -1,4 +1,4 @@ -# messaging.py - Base class for EDA messaging services +from typing import Any, Dict class Messaging: @@ -11,14 +11,14 @@ def connect(self): def disconnect(self): raise NotImplementedError - def receive_message(self, timeout_ms): + def receive_message(self, timeout_ms, queue_id: str): raise NotImplementedError - # def is_connected(self): - # raise NotImplementedError - - # def send_message(self, destination_name: str, message: str): - # raise NotImplementedError - - # def subscribe(self, subscription: str, message_handler): #: MessageHandler): - # raise NotImplementedError + def send_message( + self, + destination_name: str, + payload: Any, + user_properties: Dict = None, + user_context: Dict = None, + ): + raise NotImplementedError diff --git a/src/solace_ai_connector/common/messaging/messaging_builder.py b/src/solace_ai_connector/common/messaging/messaging_builder.py index 423d2465..826cdd45 100644 --- a/src/solace_ai_connector/common/messaging/messaging_builder.py +++ b/src/solace_ai_connector/common/messaging/messaging_builder.py @@ -1,12 +1,15 @@ """Class to build a Messaging Service object""" from .solace_messaging import SolaceMessaging +from .dev_broker_messaging import DevBroker # Make a Messaging Service builder - this is a factory for Messaging Service objects class MessagingServiceBuilder: - def __init__(self): + def __init__(self, flow_lock_manager, flow_kv_store): self.broker_properties = {} + self.flow_lock_manager = flow_lock_manager + self.flow_kv_store = flow_kv_store def from_properties(self, broker_properties: dict): self.broker_properties = broker_properties @@ -15,6 +18,10 @@ def from_properties(self, broker_properties: dict): def build(self): if self.broker_properties["broker_type"] == "solace": return SolaceMessaging(self.broker_properties) + elif self.broker_properties["broker_type"] == "dev_broker": + return DevBroker( + self.broker_properties, self.flow_lock_manager, self.flow_kv_store + ) raise ValueError( f"Unsupported broker type: {self.broker_properties['broker_type']}" diff --git a/src/solace_ai_connector/common/messaging/solace_messaging.py b/src/solace_ai_connector/common/messaging/solace_messaging.py index ed330914..4b03e7a5 100644 --- a/src/solace_ai_connector/common/messaging/solace_messaging.py +++ b/src/solace_ai_connector/common/messaging/solace_messaging.py @@ -246,8 +246,19 @@ def send_message( user_context=user_context, ) - def receive_message(self, timeout_ms): - return self.persistent_receivers[0].receive_message(timeout_ms) + def receive_message(self, timeout_ms, queue_id): + broker_message = self.persistent_receivers[0].receive_message(timeout_ms) + if broker_message is None: + return None + + # Convert Solace message to dictionary format + return { + "payload": broker_message.get_payload_as_string() + or broker_message.get_payload_as_bytes(), + "topic": broker_message.get_destination_name(), + "user_properties": broker_message.get_properties(), + "_original_message": broker_message, # Keep original message for acknowledgement + } def subscribe( self, subscription: str, persistent_receiver: PersistentMessageReceiver @@ -256,4 +267,7 @@ def subscribe( persistent_receiver.add_subscription(sub) def ack_message(self, broker_message): - self.persistent_receiver.ack(broker_message) + if "_original_message" in broker_message: + self.persistent_receiver.ack(broker_message["_original_message"]) + else: + log.warning("Cannot acknowledge message: original Solace message not found") diff --git a/src/solace_ai_connector/common/utils.py b/src/solace_ai_connector/common/utils.py index 4004e873..e53fb610 100644 --- a/src/solace_ai_connector/common/utils.py +++ b/src/solace_ai_connector/common/utils.py @@ -7,6 +7,11 @@ import builtins import subprocess import types +import base64 +import gzip +import json +import yaml + from .log import log @@ -134,8 +139,13 @@ def import_module(module, base_path=None, component_package=None): ) else: return importlib.import_module(full_name) - except ModuleNotFoundError: - pass + except ModuleNotFoundError as e: + name = str(e.name) + if ( + name != "solace_ai_connector" + and name.split(".")[-1] != full_name.split(".")[-1] + ): + raise e except Exception as e: raise ImportError( f"Module load error for {full_name}: {e}" @@ -335,3 +345,43 @@ def ensure_slash_on_start(string): if not string.startswith("/"): return "/" + string return string + + +def encode_payload(payload, encoding, payload_format): + # First, format the payload + if payload_format == "json": + formatted_payload = json.dumps(payload) + elif payload_format == "yaml": + formatted_payload = yaml.dump(payload) + elif isinstance(payload, bytes) or isinstance(payload, bytearray): + formatted_payload = payload + else: + formatted_payload = str(payload) + + # Then, encode the formatted payload + if encoding == "utf-8": + return formatted_payload.encode("utf-8") + elif encoding == "base64": + return base64.b64encode(formatted_payload.encode("utf-8")) + elif encoding == "gzip": + return gzip.compress(formatted_payload.encode("utf-8")) + else: + return formatted_payload + + +def decode_payload(payload, encoding, payload_format): + if encoding == "base64": + payload = base64.b64decode(payload) + elif encoding == "gzip": + payload = gzip.decompress(payload) + elif encoding == "utf-8" and ( + isinstance(payload, bytes) or isinstance(payload, bytearray) + ): + payload = payload.decode("utf-8") + + if payload_format == "json": + payload = json.loads(payload) + elif payload_format == "yaml": + payload = yaml.safe_load(payload) + + return payload diff --git a/src/solace_ai_connector/components/component_base.py b/src/solace_ai_connector/components/component_base.py index f059c064..f7c8c419 100644 --- a/src/solace_ai_connector/components/component_base.py +++ b/src/solace_ai_connector/components/component_base.py @@ -12,7 +12,7 @@ from ..common.event import Event, EventType from ..flow.request_response_flow_controller import RequestResponseFlowController -DEFAULT_QUEUE_TIMEOUT_MS = 200 +DEFAULT_QUEUE_TIMEOUT_MS = 1000 DEFAULT_QUEUE_MAX_DEPTH = 5 @@ -68,23 +68,29 @@ def run(self): try: event = self.get_next_event() if event is not None: - if self.trace_queue: - self.trace_event(event) - self.process_event(event) + self.process_event_with_tracing(event) except AssertionError as e: raise e except Exception as e: - log.error( - "%sComponent has crashed: %s\n%s", - self.log_identifier, - e, - traceback.format_exc(), - ) - if self.error_queue: - self.handle_error(e, event) + self.handle_component_error(e, event) self.stop_component() + def process_event_with_tracing(self, event): + if self.trace_queue: + self.trace_event(event) + self.process_event(event) + + def handle_component_error(self, e, event): + log.error( + "%sComponent has crashed: %s\n%s", + self.log_identifier, + e, + traceback.format_exc(), + ) + if self.error_queue: + self.handle_error(e, event) + def get_next_event(self): # Check if there is a get_next_message defined by a # component that inherits from this class - this is diff --git a/src/solace_ai_connector/components/inputs_outputs/broker_base.py b/src/solace_ai_connector/components/inputs_outputs/broker_base.py index fac4207d..c312740b 100644 --- a/src/solace_ai_connector/components/inputs_outputs/broker_base.py +++ b/src/solace_ai_connector/components/inputs_outputs/broker_base.py @@ -1,17 +1,13 @@ """Base class for broker input/output components for the Solace AI Event Connector""" -import base64 -import gzip -import json -import yaml import uuid from abc import abstractmethod -# from solace_ai_connector.common.log import log from ..component_base import ComponentBase from ...common.message import Message from ...common.messaging.messaging_builder import MessagingServiceBuilder +from ...common.utils import encode_payload, decode_payload # TBD - at the moment, there is no connection sharing supported. It should be possible # to share a connection between multiple components and even flows. The changes @@ -39,7 +35,7 @@ def __init__(self, module_info, **kwargs): self.broker_properties = self.get_broker_properties() if self.broker_properties["broker_type"] not in ["test", "test_streaming"]: self.messaging_service = ( - MessagingServiceBuilder() + MessagingServiceBuilder(self.flow_lock_manager, self.flow_kv_store) .from_properties(self.broker_properties) .build() ) @@ -68,51 +64,12 @@ def stop_component(self): def decode_payload(self, payload): encoding = self.get_config("payload_encoding") payload_format = self.get_config("payload_format") - if encoding == "base64": - payload = base64.b64decode(payload) - elif encoding == "gzip": - payload = gzip.decompress(payload) - elif encoding == "utf-8" and ( - isinstance(payload, bytes) or isinstance(payload, bytearray) - ): - payload = payload.decode("utf-8") - if payload_format == "json": - payload = json.loads(payload) - elif payload_format == "yaml": - payload = yaml.safe_load(payload) - return payload + return decode_payload(payload, encoding, payload_format) def encode_payload(self, payload): encoding = self.get_config("payload_encoding") payload_format = self.get_config("payload_format") - if encoding == "utf-8": - if payload_format == "json": - return json.dumps(payload).encode("utf-8") - elif payload_format == "yaml": - return yaml.dump(payload).encode("utf-8") - else: - return str(payload).encode("utf-8") - elif encoding == "base64": - if payload_format == "json": - return base64.b64encode(json.dumps(payload).encode("utf-8")) - elif payload_format == "yaml": - return base64.b64encode(yaml.dump(payload).encode("utf-8")) - else: - return base64.b64encode(str(payload).encode("utf-8")) - elif encoding == "gzip": - if payload_format == "json": - return gzip.compress(json.dumps(payload).encode("utf-8")) - elif payload_format == "yaml": - return gzip.compress(yaml.dump(payload).encode("utf-8")) - else: - return gzip.compress(str(payload).encode("utf-8")) - else: - if payload_format == "json": - return json.dumps(payload) - elif payload_format == "yaml": - return yaml.dump(payload) - else: - return str(payload) + return encode_payload(payload, encoding, payload_format) def get_egress_topic(self, message: Message): pass diff --git a/src/solace_ai_connector/components/inputs_outputs/broker_input.py b/src/solace_ai_connector/components/inputs_outputs/broker_input.py index 3aabd8d7..2d277cb0 100644 --- a/src/solace_ai_connector/components/inputs_outputs/broker_input.py +++ b/src/solace_ai_connector/components/inputs_outputs/broker_input.py @@ -110,16 +110,18 @@ def invoke(self, message, data): def get_next_message(self, timeout_ms=None): if timeout_ms is None: timeout_ms = DEFAULT_TIMEOUT_MS - broker_message = self.messaging_service.receive_message(timeout_ms) + broker_message = self.messaging_service.receive_message( + timeout_ms, self.broker_properties["queue_name"] + ) if not broker_message: return None self.current_broker_message = broker_message - payload = broker_message.get_payload_as_string() - topic = broker_message.get_destination_name() - if payload is None: - payload = broker_message.get_payload_as_bytes() + + payload = broker_message.get("payload") payload = self.decode_payload(payload) - user_properties = broker_message.get_properties() + + topic = broker_message.get("topic") + user_properties = broker_message.get("user_properties", {}) log.debug( "Received message from broker: topic=%s, user_properties=%s, payload length=%d", topic, diff --git a/src/solace_ai_connector/components/inputs_outputs/broker_request_response.py b/src/solace_ai_connector/components/inputs_outputs/broker_request_response.py index cb217b96..4c33ddb4 100644 --- a/src/solace_ai_connector/components/inputs_outputs/broker_request_response.py +++ b/src/solace_ai_connector/components/inputs_outputs/broker_request_response.py @@ -193,11 +193,12 @@ def __init__(self, **kwargs): ] self.test_mode = False - if self.broker_type == "solace": - self.connect() - elif self.broker_type == "test" or self.broker_type == "test_streaming": + if self.broker_type == "test" or self.broker_type == "test_streaming": self.test_mode = True self.setup_test_pass_through() + else: + self.connect() + self.start() def start(self): @@ -224,7 +225,9 @@ def start_response_thread(self): def handle_responses(self): while not self.stop_signal.is_set(): try: - broker_message = self.messaging_service.receive_message(1000) + broker_message = self.messaging_service.receive_message( + 1000, self.reply_queue_name + ) if broker_message: self.process_response(broker_message) except Exception as e: @@ -248,12 +251,10 @@ def process_response(self, broker_message): topic = broker_message.get_topic() user_properties = broker_message.get_user_properties() else: - payload = broker_message.get_payload_as_string() - if payload is None: - payload = broker_message.get_payload_as_bytes() + payload = broker_message.get("payload") payload = self.decode_payload(payload) - topic = broker_message.get_destination_name() - user_properties = broker_message.get_properties() + topic = broker_message.get("topic") + user_properties = broker_message.get("user_properties", {}) metadata_json = user_properties.get( "__solace_ai_connector_broker_request_reply_metadata__" diff --git a/src/solace_ai_connector/components/inputs_outputs/websocket_base.py b/src/solace_ai_connector/components/inputs_outputs/websocket_base.py new file mode 100644 index 00000000..ffe1ad92 --- /dev/null +++ b/src/solace_ai_connector/components/inputs_outputs/websocket_base.py @@ -0,0 +1,143 @@ +"""Base class for WebSocket components.""" + +from abc import ABC, abstractmethod +from flask import Flask, send_file, request +from flask_socketio import SocketIO +import logging +from ...common.log import log +from ..component_base import ComponentBase +import copy +from flask.logging import default_handler + +base_info = { + "config_parameters": [ + { + "name": "listen_port", + "type": "int", + "required": False, + "description": "Port to listen on (optional)", + }, + { + "name": "serve_html", + "type": "bool", + "required": False, + "description": "Serve the example HTML file", + "default": False, + }, + { + "name": "html_path", + "type": "string", + "required": False, + "description": "Path to the HTML file to serve", + "default": "examples/websocket/websocket_example_app.html", + }, + { + "name": "payload_encoding", + "required": False, + "description": "Encoding for the payload (utf-8, base64, gzip, none)", + "default": "none", + }, + { + "name": "payload_format", + "required": False, + "description": "Format for the payload (json, yaml, text)", + "default": "json", + }, + ], +} + + +class WebsocketBase(ComponentBase, ABC): + def __init__(self, info, **kwargs): + super().__init__(info, **kwargs) + self.listen_port = self.get_config("listen_port") + self.serve_html = self.get_config("serve_html", False) + self.html_path = self.get_config("html_path", "") + self.sockets = {} + self.app = None + self.socketio = None + + if self.listen_port: + self.setup_websocket_server() + + def setup_websocket_server(self): + self.app = Flask(__name__) + + # Enable Flask debugging + self.app.debug = False + + # Set up Flask logging + # self.app.logger.setLevel(logging.DEBUG) + # self.app.logger.addHandler(default_handler) + + # Enable SocketIO logging + # logging.getLogger("socketio").setLevel(logging.DEBUG) + # logging.getLogger("engineio").setLevel(logging.DEBUG) + + self.socketio = SocketIO( + self.app, cors_allowed_origins="*", logger=False, engineio_logger=False + ) + self.setup_websocket() + + if self.serve_html: + self.setup_html_route() + + def setup_html_route(self): + @self.app.route("/") + def serve_html(): + return send_file(self.html_path) + + def setup_websocket(self): + @self.socketio.on("connect") + def handle_connect(): + socket_id = request.sid + self.sockets[socket_id] = self.socketio + self.kv_store_set("websocket_connections", self.sockets) + log.info("New WebSocket connection established. Socket ID: %s", socket_id) + return socket_id + + @self.socketio.on("disconnect") + def handle_disconnect(): + socket_id = request.sid + if socket_id in self.sockets: + del self.sockets[socket_id] + self.kv_store_set("websocket_connections", self.sockets) + log.info("WebSocket connection closed. Socket ID: %s", socket_id) + + def run_server(self): + if self.socketio: + self.socketio.run( + self.app, port=self.listen_port, debug=False, use_reloader=False + ) + + def stop_server(self): + if self.socketio: + self.socketio.stop() + if self.app: + func = request.environ.get("werkzeug.server.shutdown") + if func is None: + raise RuntimeError("Not running with the Werkzeug Server") + func() + + def get_sockets(self): + if not self.sockets: + self.sockets = self.kv_store_get("websocket_connections") or {} + return self.sockets + + def send_to_socket(self, socket_id, payload): + sockets = self.get_sockets() + if socket_id == "*": + for socket in sockets.values(): + socket.emit("message", payload) + log.debug("Message sent to all WebSocket connections") + elif socket_id in sockets: + sockets[socket_id].emit("message", payload) + log.debug("Message sent to WebSocket connection %s", socket_id) + else: + log.error("No active connection found for socket_id: %s", socket_id) + return False + return True + + @abstractmethod + def invoke(self, message, data): + pass diff --git a/src/solace_ai_connector/components/inputs_outputs/websocket_input.py b/src/solace_ai_connector/components/inputs_outputs/websocket_input.py new file mode 100644 index 00000000..6b806172 --- /dev/null +++ b/src/solace_ai_connector/components/inputs_outputs/websocket_input.py @@ -0,0 +1,84 @@ +"""This component receives messages from a websocket connection and sends them to the next component in the flow.""" + +import json +import os +import copy + +from flask import request +from ...common.log import log +from ...common.message import Message +from ...common.event import Event, EventType +from ...common.utils import decode_payload +from .websocket_base import WebsocketBase, base_info + + +# Merge base_info into info +info = copy.deepcopy(base_info) +info.update( + { + "class_name": "WebsocketInput", + "description": "Listen for incoming messages on a websocket connection.", + "output_schema": { + "type": "object", + "properties": { + "payload": { + "type": "object", + "description": "The decoded JSON payload received from the WebSocket", + }, + }, + "required": ["payload"], + }, + } +) + + +class WebsocketInput(WebsocketBase): + def __init__(self, **kwargs): + super().__init__(info, **kwargs) + self.payload_encoding = self.get_config("payload_encoding") + self.payload_format = self.get_config("payload_format") + + if not self.listen_port: + raise ValueError("listen_port is required for WebsocketInput") + + if not os.path.isabs(self.html_path): + self.html_path = os.path.join(os.getcwd(), self.html_path) + + self.setup_message_handler() + + def setup_message_handler(self): + @self.socketio.on("message") + def handle_message(data): + try: + decoded_payload = decode_payload( + data, self.payload_encoding, self.payload_format + ) + socket_id = request.sid + message = Message( + payload=decoded_payload, user_properties={"socket_id": socket_id} + ) + event = Event(EventType.MESSAGE, message) + self.process_event_with_tracing(event) + except json.JSONDecodeError: + log.error("Received invalid payload: %s", data) + except AssertionError as e: + raise e + except Exception as e: + self.handle_component_error(e, event) + + def run(self): + self.run_server() + + def stop_component(self): + self.stop_server() + + def invoke(self, message, data): + try: + return { + "payload": message.get_payload(), + "topic": message.get_topic(), + "user_properties": message.get_user_properties(), + } + except Exception as e: + log.error("Error processing WebSocket message: %s", str(e)) + return None diff --git a/src/solace_ai_connector/components/inputs_outputs/websocket_output.py b/src/solace_ai_connector/components/inputs_outputs/websocket_output.py new file mode 100644 index 00000000..d7320645 --- /dev/null +++ b/src/solace_ai_connector/components/inputs_outputs/websocket_output.py @@ -0,0 +1,73 @@ +"""This component sends messages to a websocket connection.""" + +import copy +import threading +from ...common.log import log +from ...common.utils import encode_payload +from .websocket_base import WebsocketBase, base_info + +info = copy.deepcopy(base_info) +info.update( + { + "class_name": "WebsocketOutput", + "description": "Send messages to a websocket connection.", + "input_schema": { + "type": "object", + "properties": { + "payload": { + "type": "object", + "description": "The payload to be sent via WebSocket", + }, + "socket_id": { + "type": "string", + "description": "Identifier for the WebSocket connection", + }, + }, + "required": ["payload", "user_properties"], + }, + } +) + + +class WebsocketOutput(WebsocketBase): + def __init__(self, **kwargs): + super().__init__(info, **kwargs) + self.payload_encoding = self.get_config("payload_encoding") + self.payload_format = self.get_config("payload_format") + self.server_thread = None + + def run(self): + if self.listen_port: + self.server_thread = threading.Thread(target=self.run_server) + self.server_thread.start() + super().run() + + def stop_component(self): + self.stop_server() + if self.server_thread: + self.server_thread.join() + + def invoke(self, message, data): + try: + payload = data.get("payload") + socket_id = data.get("socket_id") + + if not socket_id: + log.error("No socket_id provided") + self.discard_current_message() + return None + + encoded_payload = encode_payload( + payload, self.payload_encoding, self.payload_format + ) + + if not self.send_to_socket(socket_id, encoded_payload): + self.discard_current_message() + return None + + except Exception as e: + log.error("Error sending message via WebSocket: %s", str(e)) + self.discard_current_message() + return None + + return data