Add observer callbacks to AgentController and move Session to Action/Observation.

- AgentController takes an optional list of callbacks. Each callback is called
  with every Action returned by the agent and with every Observation produced
  by running an action.
- Session.parse_event returns a plain dict; Session maps the "action" field to
  an Action class through ACTION_TYPE_TO_CLASS and answers unknown action
  types with send_error instead of raising KeyError.
- Session.on_agent_event builds an action/args/message dict from each event
  and schedules self.send for it.

NOTE(review): leading indentation and blank context lines were reconstructed
from a whitespace-collapsed copy of this patch, and the index hashes below are
from that copy -- verify the hunks against the target files before applying.

diff --git a/opendevin/controller/__init__.py b/opendevin/controller/__init__.py
index 2a5dba15e854..1801be081445 100644
--- a/opendevin/controller/__init__.py
+++ b/opendevin/controller/__init__.py
@@ -1,5 +1,5 @@
 
-from typing import List
+from typing import List, Callable, Optional
 
 from opendevin.state import State
 from opendevin.agent import Agent
@@ -27,12 +27,15 @@ def __init__(
         agent: Agent,
         workdir: str,
         max_iterations: int = 100,
+        callbacks: Optional[List[Callable]] = None,
     ):
         self.agent = agent
         self.max_iterations = max_iterations
         self.workdir = workdir
         self.command_manager = CommandManager(workdir)
         self.state_updated_info: List[Action | Observation] = []
+        # Default is None (not []) so instances never share one callback list.
+        self.callbacks = callbacks if callbacks is not None else []
 
     def get_current_state(self) -> State:
         # update observations & actions
@@ -58,6 +61,9 @@ async def start_loop(self, task_instruction: str):
             action: Action = self.agent.step(state)
             self.state_updated_info.append(action)
             print("ACTION", action, flush=True)
+            # Let observers see the action returned by the agent.
+            for _callback_fn in self.callbacks:
+                _callback_fn(action)
 
             if isinstance(action, AgentFinishAction):
                 print("FINISHED", flush=True)
@@ -75,6 +81,9 @@ async def start_loop(self, task_instruction: str):
                 observation: Observation = action.run(self)
                 self.state_updated_info.append(observation)
                 print(observation, flush=True)
+                # Observers get the resulting observation, not the action again.
+                for _callback_fn in self.callbacks:
+                    _callback_fn(observation)
             else:
                 print("ACTION NOT EXECUTABLE", flush=True)
 
diff --git a/opendevin/server/session.py b/opendevin/server/session.py
index 95443c5c2a7c..ba60f19ce478 100644
--- a/opendevin/server/session.py
+++ b/opendevin/server/session.py
@@ -1,12 +1,43 @@
 import os
 import asyncio
-from typing import Optional
+from typing import Optional, Dict, Type
 
 from fastapi import WebSocketDisconnect
 
 from opendevin.agent import Agent
 from opendevin.controller import AgentController
-from opendevin.lib.event import Event
+
+from opendevin.action import (
+    Action,
+    CmdRunAction,
+    CmdKillAction,
+    BrowseURLAction,
+    FileReadAction,
+    FileWriteAction,
+    AgentRecallAction,
+    AgentThinkAction,
+    AgentFinishAction,
+)
+from opendevin.observation import (
+    Observation,
+    CmdOutputObservation,
+    UserMessageObservation,
+    AgentMessageObservation,
+    BrowserOutputObservation,
+)
+
+# NOTE: this is a temporary solution - but hopefully we can use Action/Observation throughout the codebase
+ACTION_TYPE_TO_CLASS: Dict[str, Type[Action]] = {
+    "run": CmdRunAction,
+    "kill": CmdKillAction,
+    "browse": BrowseURLAction,
+    "read": FileReadAction,
+    "write": FileWriteAction,
+    "recall": AgentRecallAction,
+    "think": AgentThinkAction,
+    "finish": AgentFinishAction,
+}
+
 
 def parse_event(data):
     if "action" not in data:
@@ -18,7 +49,11 @@ def parse_event(data):
     message = None
     if "message" in data:
         message = data["message"]
-    return Event(action, args, message)
+    return {
+        "action": action,
+        "args": args,
+        "message": message,
+    }
 
 class Session:
     def __init__(self, websocket):
@@ -55,15 +90,21 @@ async def start_listening(self):
                 if event is None:
                     await self.send_error("Invalid event")
                     continue
-                if event.action == "initialize":
+                if event["action"] == "initialize":
                     await self.create_controller(event)
-                elif event.action == "start":
+                elif event["action"] == "start":
                     await self.start_task(event)
                 else:
                     if self.controller is None:
                         await self.send_error("No agent started. Please wait a second...")
                     else:
-                        await self.controller.add_user_event(event)
+                        # Report unknown action types instead of raising KeyError.
+                        action_cls = ACTION_TYPE_TO_CLASS.get(event["action"])
+                        if action_cls is None:
+                            await self.send_error(f"Unknown action: {event['action']}")
+                            continue
+                        action = action_cls(**event["args"])
+                        await self.controller.add_user_action(action)
 
         except WebSocketDisconnect as e:
             self.websocket = None
@@ -83,10 +124,7 @@ async def create_controller(self, start_event=None):
             model = start_event.args["model"]
 
         AgentCls = Agent.get_cls(agent_cls)
-        self.agent = AgentCls(
-            workspace_dir=directory,
-            model_name=model,
-        )
+        self.agent = AgentCls(model_name=model)
         self.controller = AgentController(self.agent, directory, callbacks=[self.on_agent_event])
         await self.send_message("Control loop started")
 
@@ -101,10 +139,20 @@ async def start_task(self, start_event):
             return
         self.agent_task = asyncio.create_task(self.controller.start_loop(task), name="agent loop")
 
-    def on_agent_event(self, event):
-        evt = {
-            "action": event.action,
-            "message": event.get_message(),
-            "args": event.args,
-        }
+    def on_agent_event(self, event: Observation | Action):
+        """Build an action/args/message dict from `event` and schedule `self.send` for it."""
+        # TODO: trying to find a potentially better solution here
+        # but not sure what "message" is for
+        if isinstance(event, Observation):
+            evt = {
+                "action": event.__class__.__name__,
+                "message": event.content,
+                "args": event.__dict__,
+            }
+        else:
+            evt = {
+                "action": event.__class__.__name__,
+                "args": event.__dict__,
+                "message": event.message,
+            }
         asyncio.create_task(self.send(evt), name="send event in callback")