Skip to content

Commit

Permalink
attempt to fix session
Browse files Browse the repository at this point in the history
  • Loading branch information
xingyaoww committed Mar 24, 2024
1 parent fb10f2d commit c405f21
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 17 deletions.
8 changes: 7 additions & 1 deletion opendevin/controller/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

from typing import List
from typing import List, Callable

from opendevin.state import State
from opendevin.agent import Agent
Expand Down Expand Up @@ -27,12 +27,14 @@ def __init__(
agent: Agent,
workdir: str,
max_iterations: int = 100,
callbacks: List[Callable] = [],
):
self.agent = agent
self.max_iterations = max_iterations
self.workdir = workdir
self.command_manager = CommandManager(workdir)
self.state_updated_info: List[Action | Observation] = []
self.callbacks = callbacks

def get_current_state(self) -> State:
# update observations & actions
Expand All @@ -58,6 +60,8 @@ async def start_loop(self, task_instruction: str):
action: Action = self.agent.step(state)
self.state_updated_info.append(action)
print("ACTION", action, flush=True)
for _callback_fn in self.callbacks:
_callback_fn(action)

if isinstance(action, AgentFinishAction):
print("FINISHED", flush=True)
Expand All @@ -75,6 +79,8 @@ async def start_loop(self, task_instruction: str):
observation: Observation = action.run(self)
self.state_updated_info.append(observation)
print(observation, flush=True)
for _callback_fn in self.callbacks:
_callback_fn(action)
else:
print("ACTION NOT EXECUTABLE", flush=True)

Expand Down
76 changes: 60 additions & 16 deletions opendevin/server/session.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,43 @@
import os
import asyncio
from typing import Optional
from typing import Optional, Dict, Type

from fastapi import WebSocketDisconnect

from opendevin.agent import Agent
from opendevin.controller import AgentController
from opendevin.lib.event import Event

from opendevin.action import (
Action,
CmdRunAction,
CmdKillAction,
BrowseURLAction,
FileReadAction,
FileWriteAction,
AgentRecallAction,
AgentThinkAction,
AgentFinishAction,
)
from opendevin.observation import (
Observation,
CmdOutputObservation,
UserMessageObservation,
AgentMessageObservation,
BrowserOutputObservation,
)

# NOTE: this is a temporary solution - but hopefully we can use Action/Observation throughout the codebase
ACTION_TYPE_TO_CLASS: Dict[str, Type[Action]] = {
"run": CmdRunAction,
"kill": CmdKillAction,
"browse": BrowseURLAction,
"read": FileReadAction,
"write": FileWriteAction,
"recall": AgentRecallAction,
"think": AgentThinkAction,
"finish": AgentFinishAction,
}


def parse_event(data):
if "action" not in data:
Expand All @@ -18,7 +49,11 @@ def parse_event(data):
message = None
if "message" in data:
message = data["message"]
return Event(action, args, message)
return {
"action": action,
"args": args,
"message": message,
}

class Session:
def __init__(self, websocket):
Expand Down Expand Up @@ -55,15 +90,17 @@ async def start_listening(self):
if event is None:
await self.send_error("Invalid event")
continue
if event.action == "initialize":
if event["action"] == "initialize":
await self.create_controller(event)
elif event.action == "start":
elif event["action"] == "start":
await self.start_task(event)
else:
if self.controller is None:
await self.send_error("No agent started. Please wait a second...")
else:
await self.controller.add_user_event(event)
action_cls = ACTION_TYPE_TO_CLASS[event["action"]]
action = action_cls(**event["args"])
await self.controller.add_user_action(action)

except WebSocketDisconnect as e:
self.websocket = None
Expand All @@ -83,10 +120,7 @@ async def create_controller(self, start_event=None):
model = start_event.args["model"]

AgentCls = Agent.get_cls(agent_cls)
self.agent = AgentCls(
workspace_dir=directory,
model_name=model,
)
self.agent = AgentCls(model_name=model)
self.controller = AgentController(self.agent, directory, callbacks=[self.on_agent_event])
await self.send_message("Control loop started")

Expand All @@ -101,10 +135,20 @@ async def start_task(self, start_event):
return
self.agent_task = asyncio.create_task(self.controller.start_loop(task), name="agent loop")

def on_agent_event(self, event):
evt = {
"action": event.action,
"message": event.get_message(),
"args": event.args,
}
def on_agent_event(self, event: Observation | Action):

# TODO: trying to find a potentially better solution here
# but not sure what "message" is for
if isinstance(event, Observation):
evt = {
"action": event.__class__.__name__,
"message": event.content,
"args": event.__dict__,
}
else:
evt = {
"action": event.__class__.__name__,
"args": event.__dict__,
"message": event.message,
}
asyncio.create_task(self.send(evt), name="send event in callback")

0 comments on commit c405f21

Please sign in to comment.