Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for Human Input Mode in AutoGen Studio #3484

Merged
merged 9 commits into from
Sep 8, 2024
180 changes: 91 additions & 89 deletions samples/apps/autogen-studio/autogenstudio/chatmanager.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import asyncio
import os
from datetime import datetime
from queue import Queue
from typing import Any, Dict, List, Optional, Tuple, Union

import websockets
from fastapi import WebSocket, WebSocketDisconnect
from loguru import logger

from .datamodel import Message
from .websocket_connection_manager import WebSocketConnectionManager
from .workflowmanager import WorkflowManager


Expand All @@ -17,15 +16,19 @@ class AutoGenChatManager:
using an automated workflow configuration and message queue.
"""

def __init__(self, message_queue: Queue) -> None:
def __init__(
self, message_queue: Queue, websocket_manager: WebSocketConnectionManager = None, human_input_timeout: int = 180
) -> None:
"""
Initializes the AutoGenChatManager with a message queue.

:param message_queue: A queue to use for sending messages asynchronously.
"""
self.message_queue = message_queue
self.websocket_manager = websocket_manager
self.a_human_input_timeout = human_input_timeout

def send(self, message: str) -> None:
def send(self, message: dict) -> None:
"""
Sends a message by putting it into the message queue.

Expand All @@ -34,6 +37,45 @@ def send(self, message: str) -> None:
if self.message_queue is not None:
self.message_queue.put_nowait(message)

async def a_send(self, message: dict) -> None:
"""
Asynchronously sends a message via the WebSocketManager class

:param message: The message string to be sent.
"""
for connection, socket_client_id in self.websocket_manager.active_connections:
if message["connection_id"] == socket_client_id:
logger.info(
f"Sending message to connection_id: {message['connection_id']}. Connection ID: {socket_client_id}"
)
await self.websocket_manager.send_message(message, connection)
else:
logger.info(
f"Skipping message for connection_id: {message['connection_id']}. Connection ID: {socket_client_id}"
)

async def a_prompt_for_input(self, prompt: dict, timeout: int = 60) -> str:
"""
Sends the user a prompt and waits for a response asynchronously via the WebSocketManager class

:param message: The message string to be sent.
"""

for connection, socket_client_id in self.websocket_manager.active_connections:
if prompt["connection_id"] == socket_client_id:
logger.info(
f"Sending message to connection_id: {prompt['connection_id']}. Connection ID: {socket_client_id}"
)
try:
result = await self.websocket_manager.get_input(prompt, connection, timeout)
return result
except Exception as e:
return f"Error: {e}\nTERMINATE"
else:
logger.info(
f"Skipping message for connection_id: {prompt['connection_id']}. Connection ID: {socket_client_id}"
)

def chat(
self,
message: Message,
Expand Down Expand Up @@ -72,6 +114,7 @@ def chat(
history=history,
work_dir=work_dir,
send_message_function=self.send,
a_send_message_function=self.a_send,
connection_id=connection_id,
)

Expand All @@ -82,96 +125,55 @@ def chat(
result_message.session_id = message.session_id
return result_message


class WebSocketConnectionManager:
"""
Manages WebSocket connections including sending, broadcasting, and managing the lifecycle of connections.
"""

def __init__(
async def a_chat(
self,
active_connections: List[Tuple[WebSocket, str]] = None,
active_connections_lock: asyncio.Lock = None,
) -> None:
"""
Initializes WebSocketConnectionManager with an optional list of active WebSocket connections.

:param active_connections: A list of tuples, each containing a WebSocket object and its corresponding client_id.
"""
if active_connections is None:
active_connections = []
self.active_connections_lock = active_connections_lock
self.active_connections: List[Tuple[WebSocket, str]] = active_connections

async def connect(self, websocket: WebSocket, client_id: str) -> None:
message: Message,
history: List[Dict[str, Any]],
workflow: Any = None,
connection_id: Optional[str] = None,
user_dir: Optional[str] = None,
**kwargs,
) -> Message:
"""
Accepts a new WebSocket connection and appends it to the active connections list.
Processes an incoming message according to the agent's workflow configuration
and generates a response.

:param websocket: The WebSocket instance representing a client connection.
:param client_id: A string representing the unique identifier of the client.
:param message: An instance of `Message` representing an incoming message.
:param history: A list of dictionaries, each representing a past interaction.
:param flow_config: An instance of `AgentWorkFlowConfig`. If None, defaults to a standard configuration.
:param connection_id: An optional connection identifier.
:param kwargs: Additional keyword arguments.
:return: An instance of `Message` representing a response.
"""
await websocket.accept()
async with self.active_connections_lock:
self.active_connections.append((websocket, client_id))
print(f"New Connection: {client_id}, Total: {len(self.active_connections)}")

async def disconnect(self, websocket: WebSocket) -> None:
"""
Disconnects and removes a WebSocket connection from the active connections list.
# create a working director for workflow based on user_dir/session_id/time_hash
work_dir = os.path.join(
user_dir,
str(message.session_id),
datetime.now().strftime("%Y%m%d_%H-%M-%S"),
)
os.makedirs(work_dir, exist_ok=True)

:param websocket: The WebSocket instance to remove.
"""
async with self.active_connections_lock:
try:
self.active_connections = [conn for conn in self.active_connections if conn[0] != websocket]
print(f"Connection Closed. Total: {len(self.active_connections)}")
except ValueError:
print("Error: WebSocket connection not found")

async def disconnect_all(self) -> None:
"""
Disconnects all active WebSocket connections.
"""
for connection, _ in self.active_connections[:]:
await self.disconnect(connection)
# if no flow config is provided, use the default
if workflow is None:
raise ValueError("Workflow must be specified")

async def send_message(self, message: Union[Dict, str], websocket: WebSocket) -> None:
"""
Sends a JSON message to a single WebSocket connection.
workflow_manager = WorkflowManager(
workflow=workflow,
history=history,
work_dir=work_dir,
send_message_function=self.send,
a_send_message_function=self.a_send,
a_human_input_function=self.a_prompt_for_input,
a_human_input_timeout=self.a_human_input_timeout,
connection_id=connection_id,
)

:param message: A JSON serializable dictionary containing the message to send.
:param websocket: The WebSocket instance through which to send the message.
"""
try:
async with self.active_connections_lock:
await websocket.send_json(message)
except WebSocketDisconnect:
print("Error: Tried to send a message to a closed WebSocket")
await self.disconnect(websocket)
except websockets.exceptions.ConnectionClosedOK:
print("Error: WebSocket connection closed normally")
await self.disconnect(websocket)
except Exception as e:
print(f"Error in sending message: {str(e)}", message)
await self.disconnect(websocket)

async def broadcast(self, message: Dict) -> None:
"""
Broadcasts a JSON message to all active WebSocket connections.
message_text = message.content.strip()
result_message: Message = await workflow_manager.a_run(
message=f"{message_text}", clear_history=False, history=history
)

:param message: A JSON serializable dictionary containing the message to broadcast.
"""
# Create a message dictionary with the desired format
message_dict = {"message": message}

for connection, _ in self.active_connections[:]:
try:
if connection.client_state == websockets.protocol.State.OPEN:
# Call send_message method with the message dictionary and current WebSocket connection
await self.send_message(message_dict, connection)
else:
print("Error: WebSocket connection is closed")
await self.disconnect(connection)
except (WebSocketDisconnect, websockets.exceptions.ConnectionClosedOK) as e:
print(f"Error: WebSocket disconnected or closed({str(e)})")
await self.disconnect(connection)
result_message.user_id = message.user_id
result_message.session_id = message.session_id
return result_message
8 changes: 8 additions & 0 deletions samples/apps/autogen-studio/autogenstudio/database/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,13 @@ def init_db_samples(dbmanager: Any):
model="gpt-4-1106-preview", description="OpenAI GPT-4 model", user_id="[email protected]", api_type="open_ai"
)

anthropic_sonnet_model = Model(
model="claude-3-5-sonnet-20240620",
description="Anthropic's Claude 3.5 Sonnet model",
api_type="anthropic",
user_id="[email protected]",
)

# skills
generate_pdf_skill = Skill(
name="generate_and_save_pdf",
Expand Down Expand Up @@ -303,6 +310,7 @@ def init_db_samples(dbmanager: Any):
session.add(google_gemini_model)
session.add(azure_model)
session.add(gpt_4_model)
session.add(anthropic_sonnet_model)
session.add(generate_image_skill)
session.add(generate_pdf_skill)
session.add(user_proxy)
Expand Down
2 changes: 1 addition & 1 deletion samples/apps/autogen-studio/autogenstudio/version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
VERSION = "0.1.4"
VERSION = "0.1.6"
__version__ = VERSION
APP_NAME = "autogenstudio"
13 changes: 10 additions & 3 deletions samples/apps/autogen-studio/autogenstudio/web/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
from loguru import logger
from openai import OpenAIError

from ..chatmanager import AutoGenChatManager, WebSocketConnectionManager
from ..chatmanager import AutoGenChatManager
from ..database import workflow_from_id
from ..database.dbmanager import DBManager
from ..datamodel import Agent, Message, Model, Response, Session, Skill, Workflow
from ..profiler import Profiler
from ..utils import check_and_cast_datetime_fields, init_app_folders, md5_hash, test_model
from ..version import VERSION
from ..websocket_connection_manager import WebSocketConnectionManager

profiler = Profiler()
managers = {"chat": None} # manage calls to autogen
Expand Down Expand Up @@ -64,11 +65,17 @@ def message_handler():
database_engine_uri = folders["database_engine_uri"]
dbmanager = DBManager(engine_uri=database_engine_uri)

HUMAN_INPUT_TIMEOUT_SECONDS = 180


@asynccontextmanager
async def lifespan(app: FastAPI):
print("***** App started *****")
managers["chat"] = AutoGenChatManager(message_queue=message_queue)
managers["chat"] = AutoGenChatManager(
message_queue=message_queue,
websocket_manager=websocket_manager,
human_input_timeout=HUMAN_INPUT_TIMEOUT_SECONDS,
)
dbmanager.create_db_and_tables()

yield
Expand Down Expand Up @@ -449,7 +456,7 @@ async def run_session_workflow(message: Message, session_id: int, workflow_id: i
user_dir = os.path.join(folders["files_static_root"], "user", md5_hash(message.user_id))
os.makedirs(user_dir, exist_ok=True)
workflow = workflow_from_id(workflow_id, dbmanager=dbmanager)
agent_response: Message = managers["chat"].chat(
agent_response: Message = await managers["chat"].a_chat(
message=message,
history=user_message_history,
user_dir=user_dir,
Expand Down
Loading
Loading