Skip to content

Commit

Permalink
fix mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
xingyaoww committed Mar 24, 2024
1 parent 43440c6 commit ae90a7b
Show file tree
Hide file tree
Showing 10 changed files with 30 additions and 19 deletions.
2 changes: 1 addition & 1 deletion agenthub/codeact_agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __init__(
]
print(colored("===USER:===\n" + instruction, "green"))

def step(self, state: State) -> None:
def step(self, state: State) -> Action:
updated_info = state.updated_info

if updated_info:
Expand Down
7 changes: 2 additions & 5 deletions agenthub/langchains_agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,19 +110,16 @@ def _initialize(self):
if thought.startswith("RUN"):
command = thought.split("RUN ")[1]
d = {"action": "run", "args": {"command": command}}
d = CmdRunAction(command=command)
next_is_output = True

elif thought.startswith("RECALL"):
query = thought.split("RECALL ")[1]
d = {"action": "recall", "args": {"query": query}}
d = AgentRecallAction(query=query)
next_is_output = True

elif thought.startswith("BROWSE"):
url = thought.split("BROWSE ")[1]
d = {"action": "browse", "args": {"url": url}}
d = BrowseURLAction(url=url)
next_is_output = True
else:
d = {"action": "think", "args": {"thought": thought}}
Expand All @@ -142,13 +139,13 @@ def step(self, state: State) -> Action:
if info.error:
d = {"action": "error", "args": {"output": info.content}}
else:
d = {"action": "output", "args": {"output": info.output}}
d = {"action": "output", "args": {"output": info.content}}
# elif isinstance(info, UserMessageObservation):
# d = {"action": "output", "args": {"output": info.message}}
# elif isinstance(info, AgentMessageObservation):
# d = {"action": "output", "args": {"output": info.message}}
elif isinstance(info, BrowserOutputObservation):
d = {"action": "output", "args": {"output": info.output}}
d = {"action": "output", "args": {"output": info.content}}
else:
raise NotImplementedError(f"Unknown observation type: {info}")
self._add_event(d)
Expand Down
8 changes: 4 additions & 4 deletions agenthub/langchains_agent/utils/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def get_chain(template, model_name):
assert (
"OPENAI_API_KEY" in os.environ
), "Please set the OPENAI_API_KEY environment variable to use langchains_agent."
llm = ChatOpenAI(openai_api_key=os.getenv("OPENAI_API_KEY"), model_name=model_name)
llm = ChatOpenAI(openai_api_key=os.getenv("OPENAI_API_KEY"), model_name=model_name) # type: ignore
prompt = PromptTemplate.from_template(template)
llm_chain = LLMChain(prompt=prompt, llm=llm)
return llm_chain
Expand All @@ -128,7 +128,7 @@ def request_action(
task,
thoughts: List[dict],
model_name: str,
background_commands_obs: Mapping[int, CmdOutputObservation] = [],
background_commands_obs: List[CmdOutputObservation] = [],
):
llm_chain = get_chain(ACTION_PROMPT, model_name)
parser = JsonOutputParser(pydantic_object=_ActionDict)
Expand All @@ -146,8 +146,8 @@ def request_action(
bg_commands_message = ""
if len(background_commands_obs) > 0:
bg_commands_message = "The following commands are running in the background:"
for id, command_obs in background_commands_obs.items():
bg_commands_message += f"\n`{id}`: {command_obs.command}"
for command_obs in background_commands_obs:
bg_commands_message += f"\n`{command_obs.command_id}`: {command_obs.command}"
bg_commands_message += "\nYou can end any process by sending a `kill` action with the numerical `id` above."

latest_thought = thoughts[-1]
Expand Down
2 changes: 1 addition & 1 deletion agenthub/langchains_agent/utils/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self):
def add_event(self, event):
doc = Document(
text=json.dumps(event),
doc_id=self.thought_idx,
doc_id=str(self.thought_idx),
extra_info={
"type": event["action"],
"idx": self.thought_idx,
Expand Down
9 changes: 6 additions & 3 deletions opendevin/action/agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING

from opendevin.observation import AgentMessageObservation, Observation
from opendevin.observation import AgentRecallObservation, AgentMessageObservation, Observation
from .base import ExecutableAction, NotExecutableAction
if TYPE_CHECKING:
from opendevin.controller import AgentController
Expand All @@ -11,8 +11,11 @@
class AgentRecallAction(ExecutableAction):
query: str

def run(self, controller: "AgentController") -> AgentMessageObservation:
return AgentMessageObservation(controller.agent.search_memory(self.query))
def run(self, controller: "AgentController") -> AgentRecallObservation:
return AgentRecallObservation(
content="Recalling memories...",
memories=controller.agent.search_memory(self.query)
)


@dataclass
Expand Down
3 changes: 2 additions & 1 deletion opendevin/action/browse.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dataclasses import dataclass
import requests

from dataclasses import dataclass
from opendevin.observation import BrowserOutputObservation

from .base import ExecutableAction
Expand Down
1 change: 0 additions & 1 deletion opendevin/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def reset(self) -> None:
"""
self.instruction = ""
self._complete = False
self._history = []

@classmethod
def register(cls, name: str, agent_cls: Type["Agent"]):
Expand Down
3 changes: 2 additions & 1 deletion opendevin/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse

from typing import Type
from opendevin.agent import Agent
from opendevin.controller import AgentController

Expand Down Expand Up @@ -44,7 +45,7 @@

print(f"Running agent {args.agent_cls} (model: {args.model_name}, directory: {args.directory}) with task: \"{args.task}\"")

AgentCls: Agent = Agent.get_cls(args.agent_cls)
AgentCls: Type[Agent] = Agent.get_cls(args.agent_cls)
agent = AgentCls(instruction=args.task, model_name=args.model_name)

controller = AgentController(
Expand Down
10 changes: 10 additions & 0 deletions opendevin/observation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import List
from dataclasses import dataclass

@dataclass
Expand Down Expand Up @@ -50,3 +51,12 @@ class AgentMessageObservation(Observation):
This data class represents a message sent by the agent.
"""
role: str = "assistant"

@dataclass
class AgentRecallObservation(Observation):
"""
This data class represents a list of memories recalled by the agent.
"""
memories: List[str]
role: str = "assistant"

4 changes: 2 additions & 2 deletions opendevin/state.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Mapping, List
from typing import List

from opendevin.action import (
Action,
Expand All @@ -12,5 +12,5 @@

@dataclass
class State:
background_commands_obs: Mapping[int, CmdOutputObservation]
background_commands_obs: List[CmdOutputObservation]
updated_info: List[Action | Observation]

0 comments on commit ae90a7b

Please sign in to comment.