diff --git a/README.md b/README.md index 994a5c1..fab8aec 100644 --- a/README.md +++ b/README.md @@ -29,4 +29,4 @@ Contributions are encouraged! Please read [CONTRIBUTING](CONTRIBUTING.md) for de ## License -See the LICENSE file for details. +See the [LICENSE](LICENSE) file for details. diff --git a/docs/configuration.md b/docs/configuration.md index 9394aa8..81136bb 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -127,7 +127,7 @@ A flow is an instance of a pipeline that processes events in a sequential manner Flows can be communicating together if programmed to do so. For example, a flow can send a message to a broker and another flow can subscribe to the same topic to receive the message. -flows can be spread across multiple configuration files. The connector will merge the flows from all the files and run them together. +Flows can be spread across multiple configuration files. The connector will merge the flows from all the files and run them together. The `flows` section is a list of flow configurations. Each flow configuration is a dictionary with the following keys: @@ -135,6 +135,16 @@ following keys: - `name`: - The unique name of the flow - `components`: A list of component configurations. Check [Component Configuration](#component-configuration) for more details +```yaml + flows: + - name: + components: + - component_name: + - name: + components: + - component_name: +``` + ## Message Data Between each component in a flow, a message is passed. This message is a dictionary that is used to pass data between components within the same flow. The message object has different properties, some are available throughout the whole flow, some only between two immediate components, and some have other characteristics. @@ -153,7 +163,7 @@ This data type is available only after a topic subscription and then it will be - `previous`: The complete output of the previous component in the flow. This can be used to completely forward the output of the previous component as an input to the next component or be modified in the `input_transforms` section of the next component. -- transform specific variables: Some transforms function will add specific variables to the message object that are ONLY accessible in that transform. For example, the [`map` transform](./transforms/map.md) will add `item`, `index`, and `source_list` to the message object or the [`reduce` transform](./transforms/reduce.md) will add `accumulated_value`, `current_value`, and `source_list` to the message object. You can find these details in each transform documentation. +- Transform specific variables: Some transforms function will add specific variables to the message object that are ONLY accessible in that transform. For example, the [`map` transform](./transforms/map.md) will add `item`, `index`, and `source_list` to the message object or the [`reduce` transform](./transforms/reduce.md) will add `accumulated_value`, `current_value`, and `source_list` to the message object. You can find these details in each [transform](transforms/index.md) documentation. ## Expression Syntax @@ -601,4 +611,4 @@ You can find various usecase examples in the [examples directory](../examples/) --- -Checkout [components.md](./components/index.md), [transforms.md](./transforms/index.md), or [tips_and_tricks](tips_and_tricks.md) next. +Checkout [components](./components/index.md), [transforms](./transforms/index.md), or [tips_and_tricks](tips_and_tricks.md) next. diff --git a/docs/getting_started.md b/docs/getting_started.md index 53d847b..3887eb5 100755 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -16,13 +16,14 @@ To get started with creating a solace PubSub+ event broker follow the instructio ### Install the connector -Optionally create a virtual environment: +(Optional) Create a virtual environment: ```sh python3 -m venv env source env/bin/activate ``` +Set up the connector package ```sh pip install solace-ai-connector ``` @@ -53,6 +54,12 @@ export SOLACE_BROKER_PASSWORD=default export SOLACE_BROKER_VPN=default ``` +(Optional) Store the environment variables permanently in ~/.profile file and activate them by: + +```sh +source ~/.profile +``` + Run the connector: ```sh @@ -83,10 +90,9 @@ export OPENAI_API_ENDPOINT= export MODEL_NAME= ``` -Note that if you want to use the default OpenAI endpoint, just delete that line from the openai_chat.yaml file. +Note that if you want to use the default OpenAI endpoint, just delete that line from the langchain_openai_with_history_chat.yaml file. Install the langchain openai dependencies: - ```sh pip install langchain_openai openai ``` @@ -113,7 +119,7 @@ Payload: In the "Try Me!" also subscribe to `demo/joke/subject/response` to see the response -## Installation +## Running From Source Code 1. Clone the repository and enter its directory: @@ -123,7 +129,7 @@ In the "Try Me!" also subscribe to `demo/joke/subject/response` to see the respo cd solace-ai-connector ``` -2. Optionally create a virtual environment: +2. (Optional) Create a virtual environment: ```sh python -m venv .venv @@ -136,11 +142,13 @@ In the "Try Me!" also subscribe to `demo/joke/subject/response` to see the respo pip install -r requirements.txt ``` -## Configuration +### Configuration -1. Edit the example configuration file at the root of the repository: +1. (Optional) Edit the example configuration file at the root of the repository: + ```sh config.yaml + ``` 2. Set up the environment variables that you need for the config.yaml file. The default one requires the following variables: @@ -152,7 +160,7 @@ In the "Try Me!" also subscribe to `demo/joke/subject/response` to see the respo ``` -## Running the AI Event Connector +### Running the AI Event Connector 1. Start the AI Event Connector: @@ -176,4 +184,4 @@ make build --- -Checkout [configuration.md](configuration.md) or [overview.md](overview.md) next \ No newline at end of file +Checkout [configuration](configuration.md) or [overview](overview.md) next \ No newline at end of file diff --git a/examples/websocket/websocket.yaml b/examples/websocket/websocket.yaml new file mode 100644 index 0000000..4efabfa --- /dev/null +++ b/examples/websocket/websocket.yaml @@ -0,0 +1,42 @@ +--- + # Example configuration for a WebSocket flow + # This flow creates a WebSocket server that echoes messages back to clients + # It also serves an example HTML file for easy testing + + log: + stdout_log_level: INFO + log_file_level: DEBUG + log_file: solace_ai_connector.log + + flows: + - name: websocket_echo + components: + # WebSocket Input + - component_name: websocket_input + component_module: websocket_input + component_config: + listen_port: 5000 + serve_html: true + html_path: "examples/websocket/websocket_example_app.html" + + # Pass Through + - component_name: pass_through + component_module: pass_through + component_config: {} + input_transforms: + - type: copy + source_expression: input.payload + dest_expression: user_data.input:payload + - type: copy + source_expression: input.user_properties:socket_id + dest_expression: user_data.input:socket_id + input_selection: + source_expression: user_data.input + + # WebSocket Output + - component_name: websocket_output + component_module: websocket_output + component_config: + payload_encoding: none + input_selection: + source_expression: previous diff --git a/examples/websocket/websocket_example_app.html b/examples/websocket/websocket_example_app.html new file mode 100644 index 0000000..02d54db --- /dev/null +++ b/examples/websocket/websocket_example_app.html @@ -0,0 +1,266 @@ + + + + + + WebSocket Example App + + + + +
+

WebSocket Example App

+

This is a simple app to show how JSON can be sent into a solace-ai-connector flow and how to receive output from it. Just hit connect to connect to the flow, type in some JSON and hit send. Your JSON should be echoed back to you.

+ +
+ + + Disconnected +
+ +
+ + +
+
+ +
+

Received Messages

+
+ +
+
+ + + + + diff --git a/src/solace_ai_connector/common/messaging/dev_broker_messaging.py b/src/solace_ai_connector/common/messaging/dev_broker_messaging.py new file mode 100644 index 0000000..0d10cd6 --- /dev/null +++ b/src/solace_ai_connector/common/messaging/dev_broker_messaging.py @@ -0,0 +1,104 @@ +"""This is a simple broker for testing purposes. It allows sending and receiving +messages to/from queues. It supports subscriptions based on topics.""" + +from typing import Dict, List, Any +import queue +import re +from copy import deepcopy +from .messaging import Messaging + + +class DevBroker(Messaging): + def __init__(self, broker_properties: dict, flow_lock_manager, flow_kv_store): + super().__init__(broker_properties) + self.flow_lock_manager = flow_lock_manager + self.flow_kv_store = flow_kv_store + self.connected = False + self.subscriptions_lock = self.flow_lock_manager.get_lock("subscriptions") + with self.subscriptions_lock: + self.subscriptions = self.flow_kv_store.get("dev_broker:subscriptions") + if self.subscriptions is None: + self.subscriptions: Dict[str, List[str]] = {} + self.flow_kv_store.set("dev_broker:subscriptions", self.subscriptions) + self.queues = self.flow_kv_store.get("dev_broker:queues") + if self.queues is None: + self.queues: Dict[str, queue.Queue] = {} + self.flow_kv_store.set("dev_broker:queues", self.queues) + + def connect(self): + self.connected = True + queue_name = self.broker_properties.get("queue_name") + subscriptions = self.broker_properties.get("subscriptions", []) + if queue_name: + self.queues[queue_name] = queue.Queue() + for subscription in subscriptions: + self.subscribe(subscription["topic"], queue_name) + + def disconnect(self): + self.connected = False + + def receive_message(self, timeout_ms, queue_name: str): + if not self.connected: + raise RuntimeError("DevBroker is not connected") + + try: + return self.queues[queue_name].get(timeout=timeout_ms / 1000) + except queue.Empty: + return None + + def send_message( + self, + destination_name: str, + payload: Any, + user_properties: Dict = None, + user_context: Dict = None, + ): + if not self.connected: + raise RuntimeError("DevBroker is not connected") + + message = { + "payload": payload, + "topic": destination_name, + "user_properties": user_properties or {}, + } + + matching_queue_names = self._get_matching_queue_names(destination_name) + + for queue_name in matching_queue_names: + # Clone the message for each queue to ensure isolation + self.queues[queue_name].put(deepcopy(message)) + + if user_context and "callback" in user_context: + user_context["callback"](user_context) + + def subscribe(self, subscription: str, queue_name: str): + if not self.connected: + raise RuntimeError("DevBroker is not connected") + + subscription = self._subscription_to_regex(subscription) + + with self.subscriptions_lock: + if queue_name not in self.queues: + self.queues[queue_name] = queue.Queue() + if subscription not in self.subscriptions: + self.subscriptions[subscription] = [] + self.subscriptions[subscription].append(queue_name) + + def ack_message(self, message): + pass + + def _get_matching_queue_names(self, topic: str) -> List[str]: + matching_queue_names = [] + with self.subscriptions_lock: + for subscription, queue_names in self.subscriptions.items(): + if self._topic_matches(subscription, topic): + matching_queue_names.extend(queue_names) + return list(set(matching_queue_names)) # Remove duplicates + + @staticmethod + def _topic_matches(subscription: str, topic: str) -> bool: + return re.match(f"^{subscription}$", topic) is not None + + @staticmethod + def _subscription_to_regex(subscription: str) -> str: + return subscription.replace("*", "[^/]+").replace(">", ".*") diff --git a/src/solace_ai_connector/common/messaging/messaging.py b/src/solace_ai_connector/common/messaging/messaging.py index 0844863..5847eff 100644 --- a/src/solace_ai_connector/common/messaging/messaging.py +++ b/src/solace_ai_connector/common/messaging/messaging.py @@ -1,4 +1,4 @@ -# messaging.py - Base class for EDA messaging services +from typing import Any, Dict class Messaging: @@ -11,14 +11,14 @@ def connect(self): def disconnect(self): raise NotImplementedError - def receive_message(self, timeout_ms): + def receive_message(self, timeout_ms, queue_id: str): raise NotImplementedError - # def is_connected(self): - # raise NotImplementedError - - # def send_message(self, destination_name: str, message: str): - # raise NotImplementedError - - # def subscribe(self, subscription: str, message_handler): #: MessageHandler): - # raise NotImplementedError + def send_message( + self, + destination_name: str, + payload: Any, + user_properties: Dict = None, + user_context: Dict = None, + ): + raise NotImplementedError diff --git a/src/solace_ai_connector/common/messaging/messaging_builder.py b/src/solace_ai_connector/common/messaging/messaging_builder.py index 423d246..826cdd4 100644 --- a/src/solace_ai_connector/common/messaging/messaging_builder.py +++ b/src/solace_ai_connector/common/messaging/messaging_builder.py @@ -1,12 +1,15 @@ """Class to build a Messaging Service object""" from .solace_messaging import SolaceMessaging +from .dev_broker_messaging import DevBroker # Make a Messaging Service builder - this is a factory for Messaging Service objects class MessagingServiceBuilder: - def __init__(self): + def __init__(self, flow_lock_manager, flow_kv_store): self.broker_properties = {} + self.flow_lock_manager = flow_lock_manager + self.flow_kv_store = flow_kv_store def from_properties(self, broker_properties: dict): self.broker_properties = broker_properties @@ -15,6 +18,10 @@ def from_properties(self, broker_properties: dict): def build(self): if self.broker_properties["broker_type"] == "solace": return SolaceMessaging(self.broker_properties) + elif self.broker_properties["broker_type"] == "dev_broker": + return DevBroker( + self.broker_properties, self.flow_lock_manager, self.flow_kv_store + ) raise ValueError( f"Unsupported broker type: {self.broker_properties['broker_type']}" diff --git a/src/solace_ai_connector/common/messaging/solace_messaging.py b/src/solace_ai_connector/common/messaging/solace_messaging.py index ed33091..4b03e7a 100644 --- a/src/solace_ai_connector/common/messaging/solace_messaging.py +++ b/src/solace_ai_connector/common/messaging/solace_messaging.py @@ -246,8 +246,19 @@ def send_message( user_context=user_context, ) - def receive_message(self, timeout_ms): - return self.persistent_receivers[0].receive_message(timeout_ms) + def receive_message(self, timeout_ms, queue_id): + broker_message = self.persistent_receivers[0].receive_message(timeout_ms) + if broker_message is None: + return None + + # Convert Solace message to dictionary format + return { + "payload": broker_message.get_payload_as_string() + or broker_message.get_payload_as_bytes(), + "topic": broker_message.get_destination_name(), + "user_properties": broker_message.get_properties(), + "_original_message": broker_message, # Keep original message for acknowledgement + } def subscribe( self, subscription: str, persistent_receiver: PersistentMessageReceiver @@ -256,4 +267,7 @@ def subscribe( persistent_receiver.add_subscription(sub) def ack_message(self, broker_message): - self.persistent_receiver.ack(broker_message) + if "_original_message" in broker_message: + self.persistent_receiver.ack(broker_message["_original_message"]) + else: + log.warning("Cannot acknowledge message: original Solace message not found") diff --git a/src/solace_ai_connector/common/utils.py b/src/solace_ai_connector/common/utils.py index 54d6852..9050bdc 100755 --- a/src/solace_ai_connector/common/utils.py +++ b/src/solace_ai_connector/common/utils.py @@ -7,6 +7,11 @@ import builtins import subprocess import types +import base64 +import gzip +import json +import yaml + from .log import log @@ -136,8 +141,13 @@ def import_module(module, base_path=None, component_package=None): ) else: return importlib.import_module(full_name) - except ModuleNotFoundError: - pass + except ModuleNotFoundError as e: + name = str(e.name) + if ( + name != "solace_ai_connector" + and name.split(".")[-1] != full_name.split(".")[-1] + ): + raise e except Exception as e: raise ImportError( f"Module load error for {full_name}: {e}" @@ -337,3 +347,45 @@ def ensure_slash_on_start(string): if not string.startswith("/"): return "/" + string return string + + +def encode_payload(payload, encoding, payload_format): + # First, format the payload + if payload_format == "json": + formatted_payload = json.dumps(payload) + elif payload_format == "yaml": + formatted_payload = yaml.dump(payload) + elif isinstance(payload, bytes) or isinstance(payload, bytearray): + formatted_payload = payload + else: + formatted_payload = str(payload) + + # Then, encode the formatted payload + if encoding == "utf-8": + return formatted_payload.encode("utf-8") + elif encoding == "base64": + return base64.b64encode(formatted_payload.encode("utf-8")) + elif encoding == "gzip": + return gzip.compress(formatted_payload.encode("utf-8")) + else: + return formatted_payload + + +def decode_payload(payload, encoding, payload_format): + if encoding == "base64": + payload = base64.b64decode(payload) + elif encoding == "gzip": + payload = gzip.decompress(payload) + elif encoding == "utf-8" and ( + isinstance(payload, bytes) or isinstance(payload, bytearray) + ): + payload = payload.decode("utf-8") + elif encoding == "unicode_escape": + payload = payload.decode('unicode_escape') + + if payload_format == "json": + payload = json.loads(payload) + elif payload_format == "yaml": + payload = yaml.safe_load(payload) + + return payload diff --git a/src/solace_ai_connector/components/component_base.py b/src/solace_ai_connector/components/component_base.py index f059c06..f7c8c41 100644 --- a/src/solace_ai_connector/components/component_base.py +++ b/src/solace_ai_connector/components/component_base.py @@ -12,7 +12,7 @@ from ..common.event import Event, EventType from ..flow.request_response_flow_controller import RequestResponseFlowController -DEFAULT_QUEUE_TIMEOUT_MS = 200 +DEFAULT_QUEUE_TIMEOUT_MS = 1000 DEFAULT_QUEUE_MAX_DEPTH = 5 @@ -68,23 +68,29 @@ def run(self): try: event = self.get_next_event() if event is not None: - if self.trace_queue: - self.trace_event(event) - self.process_event(event) + self.process_event_with_tracing(event) except AssertionError as e: raise e except Exception as e: - log.error( - "%sComponent has crashed: %s\n%s", - self.log_identifier, - e, - traceback.format_exc(), - ) - if self.error_queue: - self.handle_error(e, event) + self.handle_component_error(e, event) self.stop_component() + def process_event_with_tracing(self, event): + if self.trace_queue: + self.trace_event(event) + self.process_event(event) + + def handle_component_error(self, e, event): + log.error( + "%sComponent has crashed: %s\n%s", + self.log_identifier, + e, + traceback.format_exc(), + ) + if self.error_queue: + self.handle_error(e, event) + def get_next_event(self): # Check if there is a get_next_message defined by a # component that inherits from this class - this is diff --git a/src/solace_ai_connector/components/general/langchain/langchain_chat_model_with_history.py b/src/solace_ai_connector/components/general/langchain/langchain_chat_model_with_history.py index b559ee4..9885c68 100644 --- a/src/solace_ai_connector/components/general/langchain/langchain_chat_model_with_history.py +++ b/src/solace_ai_connector/components/general/langchain/langchain_chat_model_with_history.py @@ -213,7 +213,7 @@ def invoke_model( True, ) - result = namedtuple("Result", ["content", "uuid"])( + result = namedtuple("Result", ["content", "response_uuid"])( aggregate_result, response_uuid ) @@ -233,7 +233,7 @@ def send_streaming_message( message = Message( payload={ "chunk": chunk, - "aggregate_result": aggregate_result, + "content": aggregate_result, "response_uuid": response_uuid, "first_chunk": first_chunk, "last_chunk": last_chunk, diff --git a/src/solace_ai_connector/components/general/openai/openai_chat_model_base.py b/src/solace_ai_connector/components/general/openai/openai_chat_model_base.py index dbd763d..d0a53d5 100755 --- a/src/solace_ai_connector/components/general/openai/openai_chat_model_base.py +++ b/src/solace_ai_connector/components/general/openai/openai_chat_model_base.py @@ -106,7 +106,27 @@ "content": { "type": "string", "description": "The generated response from the model", - } + }, + "chunk": { + "type": "string", + "description": "The current chunk of the response", + }, + "response_uuid": { + "type": "string", + "description": "The UUID of the response", + }, + "first_chunk": { + "type": "boolean", + "description": "Whether this is the first chunk of the response", + }, + "last_chunk": { + "type": "boolean", + "description": "Whether this is the last chunk of the response", + }, + "streaming": { + "type": "boolean", + "description": "Whether this is a streaming response", + }, }, "required": ["content"], }, @@ -221,7 +241,7 @@ def invoke_stream(self, client, message, messages): return { "content": aggregate_result, "chunk": current_batch, - "uuid": response_uuid, + "response_uuid": response_uuid, "first_chunk": first_chunk, "last_chunk": True, "streaming": True, @@ -237,7 +257,7 @@ def invoke_stream(self, client, message, messages): True, ) - return {"content": aggregate_result, "uuid": response_uuid} + return {"content": aggregate_result, "response_uuid": response_uuid} def send_streaming_message( self, @@ -251,10 +271,11 @@ def send_streaming_message( message = Message( payload={ "chunk": chunk, - "aggregate_result": aggregate_result, + "content": aggregate_result, "response_uuid": response_uuid, "first_chunk": first_chunk, "last_chunk": last_chunk, + "streaming": True, }, user_properties=input_message.get_user_properties(), ) @@ -272,18 +293,19 @@ def send_to_next_component( message = Message( payload={ "chunk": chunk, - "aggregate_result": aggregate_result, + "content": aggregate_result, "response_uuid": response_uuid, "first_chunk": first_chunk, "last_chunk": last_chunk, + "streaming": True, }, user_properties=input_message.get_user_properties(), ) result = { - "content": aggregate_result, "chunk": chunk, - "uuid": response_uuid, + "content": aggregate_result, + "response_uuid": response_uuid, "first_chunk": first_chunk, "last_chunk": last_chunk, "streaming": True, diff --git a/src/solace_ai_connector/components/inputs_outputs/broker_base.py b/src/solace_ai_connector/components/inputs_outputs/broker_base.py index a641427..c312740 100644 --- a/src/solace_ai_connector/components/inputs_outputs/broker_base.py +++ b/src/solace_ai_connector/components/inputs_outputs/broker_base.py @@ -1,17 +1,13 @@ """Base class for broker input/output components for the Solace AI Event Connector""" -import base64 -import gzip -import json -import yaml import uuid from abc import abstractmethod -# from solace_ai_connector.common.log import log from ..component_base import ComponentBase from ...common.message import Message from ...common.messaging.messaging_builder import MessagingServiceBuilder +from ...common.utils import encode_payload, decode_payload # TBD - at the moment, there is no connection sharing supported. It should be possible # to share a connection between multiple components and even flows. The changes @@ -39,7 +35,7 @@ def __init__(self, module_info, **kwargs): self.broker_properties = self.get_broker_properties() if self.broker_properties["broker_type"] not in ["test", "test_streaming"]: self.messaging_service = ( - MessagingServiceBuilder() + MessagingServiceBuilder(self.flow_lock_manager, self.flow_kv_store) .from_properties(self.broker_properties) .build() ) @@ -68,54 +64,12 @@ def stop_component(self): def decode_payload(self, payload): encoding = self.get_config("payload_encoding") payload_format = self.get_config("payload_format") - if encoding == "base64": - payload = base64.b64decode(payload) - elif encoding == "gzip": - payload = gzip.decompress(payload) - elif encoding == "utf-8" and ( - isinstance(payload, bytes) or isinstance(payload, bytearray) - ): - payload = payload.decode("utf-8") - elif encoding == "unicode_escape": - payload = payload.decode('unicode_escape') - - if payload_format == "json": - payload = json.loads(payload) - elif payload_format == "yaml": - payload = yaml.safe_load(payload) - return payload + return decode_payload(payload, encoding, payload_format) def encode_payload(self, payload): encoding = self.get_config("payload_encoding") payload_format = self.get_config("payload_format") - if encoding == "utf-8": - if payload_format == "json": - return json.dumps(payload).encode("utf-8") - elif payload_format == "yaml": - return yaml.dump(payload).encode("utf-8") - else: - return str(payload).encode("utf-8") - elif encoding == "base64": - if payload_format == "json": - return base64.b64encode(json.dumps(payload).encode("utf-8")) - elif payload_format == "yaml": - return base64.b64encode(yaml.dump(payload).encode("utf-8")) - else: - return base64.b64encode(str(payload).encode("utf-8")) - elif encoding == "gzip": - if payload_format == "json": - return gzip.compress(json.dumps(payload).encode("utf-8")) - elif payload_format == "yaml": - return gzip.compress(yaml.dump(payload).encode("utf-8")) - else: - return gzip.compress(str(payload).encode("utf-8")) - else: - if payload_format == "json": - return json.dumps(payload) - elif payload_format == "yaml": - return yaml.dump(payload) - else: - return str(payload) + return encode_payload(payload, encoding, payload_format) def get_egress_topic(self, message: Message): pass diff --git a/src/solace_ai_connector/components/inputs_outputs/broker_input.py b/src/solace_ai_connector/components/inputs_outputs/broker_input.py index 3aabd8d..2d277cb 100644 --- a/src/solace_ai_connector/components/inputs_outputs/broker_input.py +++ b/src/solace_ai_connector/components/inputs_outputs/broker_input.py @@ -110,16 +110,18 @@ def invoke(self, message, data): def get_next_message(self, timeout_ms=None): if timeout_ms is None: timeout_ms = DEFAULT_TIMEOUT_MS - broker_message = self.messaging_service.receive_message(timeout_ms) + broker_message = self.messaging_service.receive_message( + timeout_ms, self.broker_properties["queue_name"] + ) if not broker_message: return None self.current_broker_message = broker_message - payload = broker_message.get_payload_as_string() - topic = broker_message.get_destination_name() - if payload is None: - payload = broker_message.get_payload_as_bytes() + + payload = broker_message.get("payload") payload = self.decode_payload(payload) - user_properties = broker_message.get_properties() + + topic = broker_message.get("topic") + user_properties = broker_message.get("user_properties", {}) log.debug( "Received message from broker: topic=%s, user_properties=%s, payload length=%d", topic, diff --git a/src/solace_ai_connector/components/inputs_outputs/broker_request_response.py b/src/solace_ai_connector/components/inputs_outputs/broker_request_response.py index cb217b9..4c33ddb 100644 --- a/src/solace_ai_connector/components/inputs_outputs/broker_request_response.py +++ b/src/solace_ai_connector/components/inputs_outputs/broker_request_response.py @@ -193,11 +193,12 @@ def __init__(self, **kwargs): ] self.test_mode = False - if self.broker_type == "solace": - self.connect() - elif self.broker_type == "test" or self.broker_type == "test_streaming": + if self.broker_type == "test" or self.broker_type == "test_streaming": self.test_mode = True self.setup_test_pass_through() + else: + self.connect() + self.start() def start(self): @@ -224,7 +225,9 @@ def start_response_thread(self): def handle_responses(self): while not self.stop_signal.is_set(): try: - broker_message = self.messaging_service.receive_message(1000) + broker_message = self.messaging_service.receive_message( + 1000, self.reply_queue_name + ) if broker_message: self.process_response(broker_message) except Exception as e: @@ -248,12 +251,10 @@ def process_response(self, broker_message): topic = broker_message.get_topic() user_properties = broker_message.get_user_properties() else: - payload = broker_message.get_payload_as_string() - if payload is None: - payload = broker_message.get_payload_as_bytes() + payload = broker_message.get("payload") payload = self.decode_payload(payload) - topic = broker_message.get_destination_name() - user_properties = broker_message.get_properties() + topic = broker_message.get("topic") + user_properties = broker_message.get("user_properties", {}) metadata_json = user_properties.get( "__solace_ai_connector_broker_request_reply_metadata__" diff --git a/src/solace_ai_connector/components/inputs_outputs/websocket_base.py b/src/solace_ai_connector/components/inputs_outputs/websocket_base.py new file mode 100644 index 0000000..ffe1ad9 --- /dev/null +++ b/src/solace_ai_connector/components/inputs_outputs/websocket_base.py @@ -0,0 +1,143 @@ +"""Base class for WebSocket components.""" + +from abc import ABC, abstractmethod +from flask import Flask, send_file, request +from flask_socketio import SocketIO +import logging +from ...common.log import log +from ..component_base import ComponentBase +import copy +from flask.logging import default_handler + +base_info = { + "config_parameters": [ + { + "name": "listen_port", + "type": "int", + "required": False, + "description": "Port to listen on (optional)", + }, + { + "name": "serve_html", + "type": "bool", + "required": False, + "description": "Serve the example HTML file", + "default": False, + }, + { + "name": "html_path", + "type": "string", + "required": False, + "description": "Path to the HTML file to serve", + "default": "examples/websocket/websocket_example_app.html", + }, + { + "name": "payload_encoding", + "required": False, + "description": "Encoding for the payload (utf-8, base64, gzip, none)", + "default": "none", + }, + { + "name": "payload_format", + "required": False, + "description": "Format for the payload (json, yaml, text)", + "default": "json", + }, + ], +} + + +class WebsocketBase(ComponentBase, ABC): + def __init__(self, info, **kwargs): + super().__init__(info, **kwargs) + self.listen_port = self.get_config("listen_port") + self.serve_html = self.get_config("serve_html", False) + self.html_path = self.get_config("html_path", "") + self.sockets = {} + self.app = None + self.socketio = None + + if self.listen_port: + self.setup_websocket_server() + + def setup_websocket_server(self): + self.app = Flask(__name__) + + # Enable Flask debugging + self.app.debug = False + + # Set up Flask logging + # self.app.logger.setLevel(logging.DEBUG) + # self.app.logger.addHandler(default_handler) + + # Enable SocketIO logging + # logging.getLogger("socketio").setLevel(logging.DEBUG) + # logging.getLogger("engineio").setLevel(logging.DEBUG) + + self.socketio = SocketIO( + self.app, cors_allowed_origins="*", logger=False, engineio_logger=False + ) + self.setup_websocket() + + if self.serve_html: + self.setup_html_route() + + def setup_html_route(self): + @self.app.route("/") + def serve_html(): + return send_file(self.html_path) + + def setup_websocket(self): + @self.socketio.on("connect") + def handle_connect(): + socket_id = request.sid + self.sockets[socket_id] = self.socketio + self.kv_store_set("websocket_connections", self.sockets) + log.info("New WebSocket connection established. Socket ID: %s", socket_id) + return socket_id + + @self.socketio.on("disconnect") + def handle_disconnect(): + socket_id = request.sid + if socket_id in self.sockets: + del self.sockets[socket_id] + self.kv_store_set("websocket_connections", self.sockets) + log.info("WebSocket connection closed. Socket ID: %s", socket_id) + + def run_server(self): + if self.socketio: + self.socketio.run( + self.app, port=self.listen_port, debug=False, use_reloader=False + ) + + def stop_server(self): + if self.socketio: + self.socketio.stop() + if self.app: + func = request.environ.get("werkzeug.server.shutdown") + if func is None: + raise RuntimeError("Not running with the Werkzeug Server") + func() + + def get_sockets(self): + if not self.sockets: + self.sockets = self.kv_store_get("websocket_connections") or {} + return self.sockets + + def send_to_socket(self, socket_id, payload): + sockets = self.get_sockets() + if socket_id == "*": + for socket in sockets.values(): + socket.emit("message", payload) + log.debug("Message sent to all WebSocket connections") + elif socket_id in sockets: + sockets[socket_id].emit("message", payload) + log.debug("Message sent to WebSocket connection %s", socket_id) + else: + log.error("No active connection found for socket_id: %s", socket_id) + return False + return True + + @abstractmethod + def invoke(self, message, data): + pass diff --git a/src/solace_ai_connector/components/inputs_outputs/websocket_input.py b/src/solace_ai_connector/components/inputs_outputs/websocket_input.py new file mode 100644 index 0000000..6b80617 --- /dev/null +++ b/src/solace_ai_connector/components/inputs_outputs/websocket_input.py @@ -0,0 +1,84 @@ +"""This component receives messages from a websocket connection and sends them to the next component in the flow.""" + +import json +import os +import copy + +from flask import request +from ...common.log import log +from ...common.message import Message +from ...common.event import Event, EventType +from ...common.utils import decode_payload +from .websocket_base import WebsocketBase, base_info + + +# Merge base_info into info +info = copy.deepcopy(base_info) +info.update( + { + "class_name": "WebsocketInput", + "description": "Listen for incoming messages on a websocket connection.", + "output_schema": { + "type": "object", + "properties": { + "payload": { + "type": "object", + "description": "The decoded JSON payload received from the WebSocket", + }, + }, + "required": ["payload"], + }, + } +) + + +class WebsocketInput(WebsocketBase): + def __init__(self, **kwargs): + super().__init__(info, **kwargs) + self.payload_encoding = self.get_config("payload_encoding") + self.payload_format = self.get_config("payload_format") + + if not self.listen_port: + raise ValueError("listen_port is required for WebsocketInput") + + if not os.path.isabs(self.html_path): + self.html_path = os.path.join(os.getcwd(), self.html_path) + + self.setup_message_handler() + + def setup_message_handler(self): + @self.socketio.on("message") + def handle_message(data): + try: + decoded_payload = decode_payload( + data, self.payload_encoding, self.payload_format + ) + socket_id = request.sid + message = Message( + payload=decoded_payload, user_properties={"socket_id": socket_id} + ) + event = Event(EventType.MESSAGE, message) + self.process_event_with_tracing(event) + except json.JSONDecodeError: + log.error("Received invalid payload: %s", data) + except AssertionError as e: + raise e + except Exception as e: + self.handle_component_error(e, event) + + def run(self): + self.run_server() + + def stop_component(self): + self.stop_server() + + def invoke(self, message, data): + try: + return { + "payload": message.get_payload(), + "topic": message.get_topic(), + "user_properties": message.get_user_properties(), + } + except Exception as e: + log.error("Error processing WebSocket message: %s", str(e)) + return None diff --git a/src/solace_ai_connector/components/inputs_outputs/websocket_output.py b/src/solace_ai_connector/components/inputs_outputs/websocket_output.py new file mode 100644 index 0000000..d732064 --- /dev/null +++ b/src/solace_ai_connector/components/inputs_outputs/websocket_output.py @@ -0,0 +1,73 @@ +"""This component sends messages to a websocket connection.""" + +import copy +import threading +from ...common.log import log +from ...common.utils import encode_payload +from .websocket_base import WebsocketBase, base_info + +info = copy.deepcopy(base_info) +info.update( + { + "class_name": "WebsocketOutput", + "description": "Send messages to a websocket connection.", + "input_schema": { + "type": "object", + "properties": { + "payload": { + "type": "object", + "description": "The payload to be sent via WebSocket", + }, + "socket_id": { + "type": "string", + "description": "Identifier for the WebSocket connection", + }, + }, + "required": ["payload", "user_properties"], + }, + } +) + + +class WebsocketOutput(WebsocketBase): + def __init__(self, **kwargs): + super().__init__(info, **kwargs) + self.payload_encoding = self.get_config("payload_encoding") + self.payload_format = self.get_config("payload_format") + self.server_thread = None + + def run(self): + if self.listen_port: + self.server_thread = threading.Thread(target=self.run_server) + self.server_thread.start() + super().run() + + def stop_component(self): + self.stop_server() + if self.server_thread: + self.server_thread.join() + + def invoke(self, message, data): + try: + payload = data.get("payload") + socket_id = data.get("socket_id") + + if not socket_id: + log.error("No socket_id provided") + self.discard_current_message() + return None + + encoded_payload = encode_payload( + payload, self.payload_encoding, self.payload_format + ) + + if not self.send_to_socket(socket_id, encoded_payload): + self.discard_current_message() + return None + + except Exception as e: + log.error("Error sending message via WebSocket: %s", str(e)) + self.discard_current_message() + return None + + return data diff --git a/src/solace_ai_connector/main.py b/src/solace_ai_connector/main.py index a62cf00..7bf43cf 100644 --- a/src/solace_ai_connector/main.py +++ b/src/solace_ai_connector/main.py @@ -2,15 +2,18 @@ import sys import re import yaml +from pathlib import Path from .solace_ai_connector import SolaceAiConnector def load_config(file): """Load configuration from a YAML file.""" try: - # Load the YAML file as a string - with open(file, "r", encoding="utf8") as f: - yaml_str = f.read() + # Get the directory of the current file + file_dir = os.path.dirname(os.path.abspath(file)) + + # Load the YAML file as a string, processing includes + yaml_str = process_includes(file, file_dir) # Substitute the environment variables using os.environ yaml_str = expandvars_with_defaults(yaml_str) @@ -23,6 +26,31 @@ def load_config(file): sys.exit(1) +def process_includes(file_path, base_dir): + """Process #include directives in the given file.""" + with open(file_path, "r", encoding="utf8") as f: + content = f.read() + + def include_repl(match): + indent = match.group(1) # Capture the leading spaces + indent = indent.replace("\n", "") # Remove newlines + include_path = match.group(2).strip("'\"") + full_path = os.path.join(base_dir, include_path) + if not os.path.exists(full_path): + raise FileNotFoundError(f"Included file not found: {full_path}") + included_content = process_includes(full_path, os.path.dirname(full_path)) + # Indent each line of the included content + indented_content = "\n".join( + indent + line for line in included_content.splitlines() + ) + return indented_content + + include_pattern = re.compile( + r'^(\s*)!include\s+(["\']?[^"\s\']+)["\']?', re.MULTILINE + ) + return include_pattern.sub(include_repl, content) + + def expandvars_with_defaults(text): """Expand environment variables with support for default values. Supported syntax: ${VAR_NAME} or ${VAR_NAME, default_value}"""