From be74f0685c21cb8e6fc318a171676645c2f6ab6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E0=AE=AE=E0=AE=A9=E0=AF=8B=E0=AE=9C=E0=AF=8D=E0=AE=95?= =?UTF-8?q?=E0=AF=81=E0=AE=AE=E0=AE=BE=E0=AE=B0=E0=AF=8D=20=E0=AE=AA?= =?UTF-8?q?=E0=AE=B4=E0=AE=A9=E0=AE=BF=E0=AE=9A=E0=AF=8D=E0=AE=9A=E0=AE=BE?= =?UTF-8?q?=E0=AE=AE=E0=AE=BF?= Date: Tue, 13 Aug 2024 13:59:28 +0530 Subject: [PATCH] Feat: Regenerate message --- frontend/src/components/chat/ChatInterface.tsx | 16 ++++++++++++++-- frontend/src/services/chatService.ts | 9 +++++++++ frontend/src/state/chatSlice.ts | 6 +++++- frontend/src/types/ActionType.tsx | 3 +++ opendevin/controller/agent_controller.py | 8 ++++++-- opendevin/core/schema/action.py | 4 ++++ opendevin/events/action/__init__.py | 3 ++- opendevin/events/action/message.py | 7 +++++++ opendevin/events/serialization/action.py | 3 ++- opendevin/events/stream.py | 9 +++++++++ 10 files changed, 61 insertions(+), 7 deletions(-) diff --git a/frontend/src/components/chat/ChatInterface.tsx b/frontend/src/components/chat/ChatInterface.tsx index ac95f66524c..eb135ff8548 100644 --- a/frontend/src/components/chat/ChatInterface.tsx +++ b/frontend/src/components/chat/ChatInterface.tsx @@ -11,12 +11,13 @@ import Chat from "./Chat"; import TypingIndicator from "./TypingIndicator"; import { RootState } from "#/store"; import AgentState from "#/types/AgentState"; -import { sendChatMessage } from "#/services/chatService"; -import { addUserMessage, addAssistantMessage } from "#/state/chatSlice"; +import { sendChatMessage, regenerateLastMessage } from "#/services/chatService"; +import { addUserMessage, addAssistantMessage, removeLastAssistantMessage } from "#/state/chatSlice"; import { I18nKey } from "#/i18n/declaration"; import { useScrollToBottom } from "#/hooks/useScrollToBottom"; import FeedbackModal from "../modals/feedback/FeedbackModal"; import beep from "#/utils/beep"; +import { FaSyncAlt } from "react-icons/fa"; interface ScrollButtonProps { onClick: () => void; @@ -91,6 +92,12 @@ function ChatInterface() { ); }; + + const handleRegenerateClick = () => { + dispatch(removeLastAssistantMessage()); + regenerateLastMessage(); + }; + const scrollRef = useRef(null); const { scrollDomToBottom, onChatBodyScroll, hitBottom } = @@ -184,6 +191,11 @@ function ChatInterface() { icon={} label="" /> + } + label="" + /> )} diff --git a/frontend/src/services/chatService.ts b/frontend/src/services/chatService.ts index d857fb603f4..dc44701cbfd 100644 --- a/frontend/src/services/chatService.ts +++ b/frontend/src/services/chatService.ts @@ -9,3 +9,12 @@ export function sendChatMessage(message: string, images_urls: string[]): void { const eventString = JSON.stringify(event); Session.send(eventString); } + +export function regenerateLastMessage(): void { + const event = { + action: ActionType.REGENERATE, + args: {}, + }; + const eventString = JSON.stringify(event); + Session.send(eventString); +} diff --git a/frontend/src/state/chatSlice.ts b/frontend/src/state/chatSlice.ts index a1b01fa7768..d83048e871a 100644 --- a/frontend/src/state/chatSlice.ts +++ b/frontend/src/state/chatSlice.ts @@ -34,9 +34,13 @@ export const chatSlice = createSlice({ clearMessages(state) { state.messages = []; }, + + removeLastAssistantMessage(state) { + state.messages.pop(); + }, }, }); -export const { addUserMessage, addAssistantMessage, clearMessages } = +export const { addUserMessage, addAssistantMessage, clearMessages, removeLastAssistantMessage } = chatSlice.actions; export default chatSlice.reducer; diff --git a/frontend/src/types/ActionType.tsx b/frontend/src/types/ActionType.tsx index a8d469b1cb4..b658b3824c2 100644 --- a/frontend/src/types/ActionType.tsx +++ b/frontend/src/types/ActionType.tsx @@ -5,6 +5,9 @@ enum ActionType { // Represents a message from the user or agent. MESSAGE = "message", + // Regenerates the last message from the agent. + REGENERATE = "regenerate", + // Reads the contents of a file. READ = "read", diff --git a/opendevin/controller/agent_controller.py b/opendevin/controller/agent_controller.py index 533322d5ff8..431fa6ced99 100644 --- a/opendevin/controller/agent_controller.py +++ b/opendevin/controller/agent_controller.py @@ -28,6 +28,7 @@ MessageAction, ModifyTaskAction, NullAction, + RegenerateAction, ) from opendevin.events.event import Event from opendevin.events.observation import ( @@ -160,14 +161,17 @@ async def _start_step_loop(self): async def on_event(self, event: Event): if isinstance(event, ChangeAgentStateAction): await self.set_agent_state_to(event.agent_state) # type: ignore + elif isinstance(event, RegenerateAction): + logger.info(event, extra={'msg_type': 'ACTION'}) + self.event_stream.remove_latest_event() + await self.set_agent_state_to(AgentState.RUNNING) elif isinstance(event, MessageAction): if event.source == EventSource.USER: logger.info( event, extra={'msg_type': 'ACTION', 'event_source': EventSource.USER}, ) - if self.get_agent_state() != AgentState.RUNNING: - await self.set_agent_state_to(AgentState.RUNNING) + await self.set_agent_state_to(AgentState.RUNNING) elif event.source == EventSource.AGENT and event.wait_for_response: await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT) elif isinstance(event, AgentDelegateAction): diff --git a/opendevin/core/schema/action.py b/opendevin/core/schema/action.py index b2cd267e21f..cee4ad85017 100644 --- a/opendevin/core/schema/action.py +++ b/opendevin/core/schema/action.py @@ -12,6 +12,10 @@ class ActionTypeSchema(BaseModel): """Represents a message. """ + REGENERATE: str = Field(default='regenerate') + """Regenerates the message. + """ + START: str = Field(default='start') """Starts a new development task OR send chat from the user. Only sent by the client. """ diff --git a/opendevin/events/action/__init__.py b/opendevin/events/action/__init__.py index 3a4baacb22a..0e9bac76b67 100644 --- a/opendevin/events/action/__init__.py +++ b/opendevin/events/action/__init__.py @@ -10,7 +10,7 @@ from .commands import CmdRunAction, IPythonRunCellAction from .empty import NullAction from .files import FileReadAction, FileWriteAction -from .message import MessageAction +from .message import MessageAction, RegenerateAction from .tasks import AddTaskAction, ModifyTaskAction __all__ = [ @@ -30,5 +30,6 @@ 'ChangeAgentStateAction', 'IPythonRunCellAction', 'MessageAction', + 'RegenerateAction', 'ActionConfirmationStatus', ] diff --git a/opendevin/events/action/message.py b/opendevin/events/action/message.py index b235dd8687b..d640fbb9d19 100644 --- a/opendevin/events/action/message.py +++ b/opendevin/events/action/message.py @@ -23,3 +23,10 @@ def __str__(self) -> str: for url in self.images_urls: ret += f'\nIMAGE_URL: {url}' return ret + + +class RegenerateAction(Action): + action = ActionType.REGENERATE + + def __str__(self) -> str: + return f'**RegenerateAction** (source={self.source})\n' diff --git a/opendevin/events/serialization/action.py b/opendevin/events/serialization/action.py index 3f7a8265af3..3364d33e78f 100644 --- a/opendevin/events/serialization/action.py +++ b/opendevin/events/serialization/action.py @@ -14,7 +14,7 @@ ) from opendevin.events.action.empty import NullAction from opendevin.events.action.files import FileReadAction, FileWriteAction -from opendevin.events.action.message import MessageAction +from opendevin.events.action.message import MessageAction, RegenerateAction from opendevin.events.action.tasks import AddTaskAction, ModifyTaskAction actions = ( @@ -32,6 +32,7 @@ ModifyTaskAction, ChangeAgentStateAction, MessageAction, + RegenerateAction, AgentSummarizeAction, ) diff --git a/opendevin/events/stream.py b/opendevin/events/stream.py index 054ca40af28..543077e2b30 100644 --- a/opendevin/events/stream.py +++ b/opendevin/events/stream.py @@ -141,6 +141,15 @@ def add_event(self, event: Event, source: EventSource): callback = stack[-1] asyncio.create_task(callback(event)) + def remove_latest_event(self): + # Remove NullObservation, RegenerateAction, AgentStateChangedObservation, NullObservation and the previous Action + for _ in range(5): + logger.debug(f'Removing latest event id={self._cur_id - 1}') + logger.debug(f'Removing event: {self.get_latest_event()}') + + self.file_store.delete(self._get_filename_for_id(self._cur_id - 1)) + self._cur_id -= 1 + def filtered_events_by_source(self, source: EventSource): for event in self.get_events(): if event.source == source: