Decouple model_context from AssistantAgent #4681

Draft · wants to merge 3 commits into base: main
File 1 of 3: AssistantAgent

@@ -2,15 +2,28 @@
 import json
 import logging
 import warnings
-from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Mapping, Sequence
+from typing import (
+    Any,
+    AsyncGenerator,
+    Awaitable,
+    Callable,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+)
 
 from autogen_core import CancellationToken, FunctionCall
+from autogen_core.model_context import (
+    ChatCompletionContext,
+    UnboundedBufferedChatCompletionContext,
+)
 from autogen_core.models import (
     AssistantMessage,
     ChatCompletionClient,
     FunctionExecutionResult,
     FunctionExecutionResultMessage,
     LLMMessage,
     SystemMessage,
     UserMessage,
 )
@@ -215,12 +228,14 @@ def __init__(
         self,
         name: str,
         model_client: ChatCompletionClient,
+        model_context: Optional[ChatCompletionContext] = None,
         *,
         tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
         handoffs: List[HandoffBase | str] | None = None,
         description: str = "An agent that provides assistance with ability to use tools.",
-        system_message: str
-        | None = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
+        system_message: (
+            str | None
+        ) = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
         reflect_on_tool_use: bool = False,
         tool_call_summary_format: str = "{result}",
     ):
@@ -272,7 +287,8 @@ def __init__(
             raise ValueError(
                 f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
             )
-        self._model_context: List[LLMMessage] = []
+        # An injected context takes precedence; otherwise default to an unbounded buffer.
+        self._model_context = model_context or UnboundedBufferedChatCompletionContext()
         self._reflect_on_tool_use = reflect_on_tool_use
         self._tool_call_summary_format = tool_call_summary_format
         self._is_running = False
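
With this change the agent's history is no longer a private `List[LLMMessage]`; any `ChatCompletionContext` implementation can be injected at construction time, and omitting the argument falls back to the new unbounded buffer. A minimal usage sketch; the `OpenAIChatCompletionClient` import from `autogen_ext` is an assumption for illustration and is not part of this diff:

    from autogen_agentchat.agents import AssistantAgent
    from autogen_core.model_context import BufferedChatCompletionContext
    from autogen_ext.models import OpenAIChatCompletionClient  # assumed client, not in this diff

    model_client = OpenAIChatCompletionClient(model="gpt-4o")

    # Omitting model_context keeps the old behavior: an unbounded message buffer.
    agent = AssistantAgent("assistant", model_client)

    # Injecting a bounded context caps the history sent to the model at 10 messages.
    bounded_agent = AssistantAgent(
        "assistant",
        model_client,
        model_context=BufferedChatCompletionContext(buffer_size=10),
    )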
@@ -297,19 +313,19 @@ async def on_messages_stream(
         for msg in messages:
             if isinstance(msg, MultiModalMessage) and self._model_client.capabilities["vision"] is False:
                 raise ValueError("The model does not support vision.")
-            self._model_context.append(UserMessage(content=msg.content, source=msg.source))
+            await self._model_context.add_message(UserMessage(content=msg.content, source=msg.source))
 
         # Inner messages.
         inner_messages: List[AgentMessage] = []
 
         # Generate an inference result based on the current model context.
-        llm_messages = self._system_messages + self._model_context
+        llm_messages = self._system_messages + await self._model_context.get_messages()
         result = await self._model_client.create(
             llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
         )
 
         # Add the response to the model context.
-        self._model_context.append(AssistantMessage(content=result.content, source=self.name))
+        await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name))
 
         # Check if the response is a string and return it.
         if isinstance(result.content, str):
@@ -331,7 +347,7 @@ async def on_messages_stream(
         results = await asyncio.gather(*[self._execute_tool_call(call, cancellation_token) for call in result.content])
         tool_call_result_msg = ToolCallResultMessage(content=results, source=self.name)
         event_logger.debug(tool_call_result_msg)
-        self._model_context.append(FunctionExecutionResultMessage(content=results))
+        await self._model_context.add_message(FunctionExecutionResultMessage(content=results))
         inner_messages.append(tool_call_result_msg)
         yield tool_call_result_msg

@@ -356,11 +372,11 @@ async def on_messages_stream(

         if self._reflect_on_tool_use:
             # Generate another inference result based on the tool call and result.
-            llm_messages = self._system_messages + self._model_context
+            llm_messages = self._system_messages + await self._model_context.get_messages()
             result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
             assert isinstance(result.content, str)
             # Add the response to the model context.
-            self._model_context.append(AssistantMessage(content=result.content, source=self.name))
+            await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name))
             # Yield the response.
             yield Response(
                 chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage),
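
The streaming path now goes through awaited context calls (`add_message`, `get_messages`) instead of `list.append` and list concatenation, so the agent depends only on the `ChatCompletionContext` protocol. A sketch of a hypothetical custom context showing the surface the agent relies on; the class name and trimming policy are invented for illustration:

    from typing import Any, List, Mapping

    from autogen_core.model_context import ChatCompletionContext
    from autogen_core.models import LLMMessage


    class LastNChatCompletionContext(ChatCompletionContext):
        """Hypothetical context: stores everything, exposes only the last n messages."""

        def __init__(self, n: int) -> None:
            self._n = n
            self._messages: List[LLMMessage] = []

        async def add_message(self, message: LLMMessage) -> None:
            self._messages.append(message)

        async def get_messages(self) -> List[LLMMessage]:
            # Only the tail of the history is handed to the model client.
            return self._messages[-self._n :]

        async def clear(self) -> None:
            self._messages = []

        def save_state(self) -> Mapping[str, Any]:
            return {"messages": list(self._messages)}

        def load_state(self, state: Mapping[str, Any]) -> None:
            self._messages = list(state["messages"])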
@@ -402,14 +418,18 @@ async def _execute_tool_call(

     async def on_reset(self, cancellation_token: CancellationToken) -> None:
         """Reset the assistant agent to its initialization state."""
-        self._model_context.clear()
+        await self._model_context.clear()
 
     async def save_state(self) -> Mapping[str, Any]:
         """Save the current state of the assistant agent."""
-        return AssistantAgentState(llm_messages=self._model_context.copy()).model_dump()
+        current_model_ctx_state = self._model_context.save_state()
+        return AssistantAgentState(llm_messages=current_model_ctx_state["messages"]).model_dump()
 
     async def load_state(self, state: Mapping[str, Any]) -> None:
         """Load the state of the assistant agent"""
         assistant_agent_state = AssistantAgentState.model_validate(state)
-        self._model_context.clear()
-        self._model_context.extend(assistant_agent_state.llm_messages)
+        await self._model_context.clear()
+
+        current_model_ctx_state = dict(self._model_context.save_state())
+        current_model_ctx_state["messages"] = assistant_agent_state.llm_messages
+        self._model_context.load_state(current_model_ctx_state)
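
`save_state` and `load_state` now round-trip through the context's own state rather than copying a raw list. Loading deliberately starts from a fresh `save_state()` snapshot and overwrites only the `"messages"` key, so any extra keys a context implementation keeps are preserved. Continuing the earlier sketch (`agent` and `model_client` as above):

    # The serialized shape is unchanged, so previously saved agent states still load.
    state = await agent.save_state()  # contains the "llm_messages" key

    fresh_agent = AssistantAgent("assistant", model_client)
    await fresh_agent.load_state(state)  # history restored into the new context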
File 2 of 3: autogen_core/model_context/__init__.py

@@ -1,9 +1,13 @@
 from ._buffered_chat_completion_context import BufferedChatCompletionContext
 from ._chat_completion_context import ChatCompletionContext
 from ._head_and_tail_chat_completion_context import HeadAndTailChatCompletionContext
+from ._unbounded_buffered_chat_completion_context import (
+    UnboundedBufferedChatCompletionContext,
+)
 
 __all__ = [
     "ChatCompletionContext",
+    "UnboundedBufferedChatCompletionContext",
     "BufferedChatCompletionContext",
     "HeadAndTailChatCompletionContext",
 ]
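
The new context is re-exported at the package level, so downstream code can import it next to the existing implementations:

    from autogen_core.model_context import (
        BufferedChatCompletionContext,
        UnboundedBufferedChatCompletionContext,
    )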
File 3 of 3 (new file): autogen_core/model_context/_unbounded_buffered_chat_completion_context.py

@@ -0,0 +1,31 @@
+from typing import Any, List, Mapping
+
+from ..models import LLMMessage
+from ._chat_completion_context import ChatCompletionContext
+
+
+class UnboundedBufferedChatCompletionContext(ChatCompletionContext):
+    """An unbounded buffered chat completion context that keeps a view of all the messages."""
+
+    def __init__(self, initial_messages: List[LLMMessage] | None = None) -> None:
+        self._messages: List[LLMMessage] = initial_messages or []
+
+    async def add_message(self, message: LLMMessage) -> None:
+        """Add a message to the memory."""
+        self._messages.append(message)
+
+    async def get_messages(self) -> List[LLMMessage]:
+        """Get all messages in the memory, oldest first."""
+        return self._messages
+
+    async def clear(self) -> None:
+        """Clear the message memory."""
+        self._messages = []
+
+    def save_state(self) -> Mapping[str, Any]:
+        return {
+            "messages": list(self._messages),
+        }
+
+    def load_state(self, state: Mapping[str, Any]) -> None:
+        self._messages = state["messages"]
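
A short self-contained sketch exercising the new class exactly as defined above, including the save/clear/load cycle:

    import asyncio

    from autogen_core.model_context import UnboundedBufferedChatCompletionContext
    from autogen_core.models import UserMessage


    async def main() -> None:
        ctx = UnboundedBufferedChatCompletionContext()
        await ctx.add_message(UserMessage(content="hello", source="user"))
        await ctx.add_message(UserMessage(content="world", source="user"))
        assert len(await ctx.get_messages()) == 2

        # save_state copies the buffer, so clearing does not mutate the snapshot.
        state = ctx.save_state()
        await ctx.clear()
        assert await ctx.get_messages() == []

        ctx.load_state(state)
        assert len(await ctx.get_messages()) == 2


    asyncio.run(main())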