diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
index d421d7ccb2ce..ef9ecb2a00c5 100644
--- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
@@ -2,15 +2,27 @@
 import json
 import logging
 import warnings
-from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Mapping, Sequence
+from typing import (
+    Any,
+    AsyncGenerator,
+    Awaitable,
+    Callable,
+    Dict,
+    List,
+    Mapping,
+    Sequence,
+)
 
 from autogen_core import CancellationToken, FunctionCall
+from autogen_core.model_context import (
+    ChatCompletionContext,
+    UnboundedChatCompletionContext,
+)
 from autogen_core.models import (
     AssistantMessage,
     ChatCompletionClient,
     FunctionExecutionResult,
     FunctionExecutionResultMessage,
-    LLMMessage,
     SystemMessage,
     UserMessage,
 )
@@ -87,7 +99,6 @@ class AssistantAgent(BaseChatAgent):
 
     If multiple handoffs are detected, only the first handoff is executed.
 
-
     Args:
         name (str): The name of the agent.
         model_client (ChatCompletionClient): The model client to use for inference.
@@ -96,8 +107,9 @@
             allowing it to transfer to other agents by responding with a :class:`HandoffMessage`.
             The transfer is only executed when the team is in :class:`~autogen_agentchat.teams.Swarm`.
             If a handoff is a string, it should represent the target agent's name.
+        model_context (ChatCompletionContext | None, optional): The model context for storing and retrieving :class:`~autogen_core.models.LLMMessage`. It can be preloaded with initial messages. The initial messages will be cleared when the agent is reset.
         description (str, optional): The description of the agent.
-        system_message (str, optional): The system message for the model.
+        system_message (str, optional): The system message for the model. If provided, it will be prepended to the messages in the model context when making an inference. Set to `None` to disable.
         reflect_on_tool_use (bool, optional): If `True`, the agent will make another model inference using the tool call and result
             to generate a response. If `False`, the tool call result will be returned as the response. Defaults to `False`.
         tool_call_summary_format (str, optional): The format string used to create a tool call summary for every tool call result.
@@ -219,9 +231,11 @@ def __init__(
         *,
         tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
         handoffs: List[HandoffBase | str] | None = None,
+        model_context: ChatCompletionContext | None = None,
         description: str = "An agent that provides assistance with ability to use tools.",
-        system_message: str
-        | None = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
+        system_message: (
+            str | None
+        ) = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
         reflect_on_tool_use: bool = False,
         tool_call_summary_format: str = "{result}",
     ):
@@ -273,7 +287,10 @@ def __init__(
                 raise ValueError(
                     f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
                 )
-        self._model_context: List[LLMMessage] = []
+        if model_context is not None:
+            self._model_context = model_context
+        else:
+            self._model_context = UnboundedChatCompletionContext()
         self._reflect_on_tool_use = reflect_on_tool_use
         self._tool_call_summary_format = tool_call_summary_format
         self._is_running = False
@@ -301,19 +318,19 @@ async def on_messages_stream(
         for msg in messages:
             if isinstance(msg, MultiModalMessage) and self._model_client.capabilities["vision"] is False:
                 raise ValueError("The model does not support vision.")
-            self._model_context.append(UserMessage(content=msg.content, source=msg.source))
+            await self._model_context.add_message(UserMessage(content=msg.content, source=msg.source))
 
         # Inner messages.
         inner_messages: List[AgentEvent | ChatMessage] = []
 
         # Generate an inference result based on the current model context.
-        llm_messages = self._system_messages + self._model_context
+        llm_messages = self._system_messages + await self._model_context.get_messages()
         result = await self._model_client.create(
             llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
         )
 
         # Add the response to the model context.
-        self._model_context.append(AssistantMessage(content=result.content, source=self.name))
+        await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name))
 
         # Check if the response is a string and return it.
         if isinstance(result.content, str):
@@ -335,7 +352,7 @@ async def on_messages_stream(
         results = await asyncio.gather(*[self._execute_tool_call(call, cancellation_token) for call in result.content])
         tool_call_result_msg = ToolCallExecutionEvent(content=results, source=self.name)
         event_logger.debug(tool_call_result_msg)
-        self._model_context.append(FunctionExecutionResultMessage(content=results))
+        await self._model_context.add_message(FunctionExecutionResultMessage(content=results))
         inner_messages.append(tool_call_result_msg)
         yield tool_call_result_msg
 
@@ -360,11 +377,11 @@ async def on_messages_stream(
 
         if self._reflect_on_tool_use:
             # Generate another inference result based on the tool call and result.
-            llm_messages = self._system_messages + self._model_context
+            llm_messages = self._system_messages + await self._model_context.get_messages()
             result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
             assert isinstance(result.content, str)
             # Add the response to the model context.
-            self._model_context.append(AssistantMessage(content=result.content, source=self.name))
+            await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name))
             # Yield the response.
             yield Response(
                 chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage),
                 inner_messages=inner_messages,
             )
@@ -406,14 +423,15 @@ async def _execute_tool_call(
 
     async def on_reset(self, cancellation_token: CancellationToken) -> None:
         """Reset the assistant agent to its initialization state."""
-        self._model_context.clear()
+        await self._model_context.clear()
 
     async def save_state(self) -> Mapping[str, Any]:
         """Save the current state of the assistant agent."""
-        return AssistantAgentState(llm_messages=self._model_context.copy()).model_dump()
+        model_context_state = await self._model_context.save_state()
+        return AssistantAgentState(llm_context=model_context_state).model_dump()
 
     async def load_state(self, state: Mapping[str, Any]) -> None:
         """Load the state of the assistant agent"""
         assistant_agent_state = AssistantAgentState.model_validate(state)
-        self._model_context.clear()
-        self._model_context.extend(assistant_agent_state.llm_messages)
+        # Load the model context state.
+        await self._model_context.load_state(assistant_agent_state.llm_context)
diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/state/_states.py b/python/packages/autogen-agentchat/src/autogen_agentchat/state/_states.py
index ddc57e23d94c..002d5fb472a2 100644
--- a/python/packages/autogen-agentchat/src/autogen_agentchat/state/_states.py
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/state/_states.py
@@ -1,8 +1,5 @@
 from typing import Any, List, Mapping, Optional
 
-from autogen_core.models import (
-    LLMMessage,
-)
 from pydantic import BaseModel, Field
 
 from ..messages import (
@@ -21,7 +18,7 @@ class BaseState(BaseModel):
 
 
 class AssistantAgentState(BaseState):
     """State for an assistant agent."""
 
-    llm_messages: List[LLMMessage] = Field(default_factory=list)
+    llm_context: Mapping[str, Any] = Field(default_factory=lambda: {"messages": []})
     type: str = Field(default="AssistantAgentState")
diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py
index 6d2fe29beb8f..0409b3c468c4 100644
--- a/python/packages/autogen-agentchat/tests/test_group_chat.py
+++ b/python/packages/autogen-agentchat/tests/test_group_chat.py
@@ -239,8 +239,13 @@ async def test_round_robin_group_chat_state() -> None:
     await team2.load_state(state)
     state2 = await team2.save_state()
     assert state == state2
-    assert agent3._model_context == agent1._model_context  # pyright: ignore
-    assert agent4._model_context == agent2._model_context  # pyright: ignore
+
+    agent1_model_ctx_messages = await agent1._model_context.get_messages()  # pyright: ignore
+    agent2_model_ctx_messages = await agent2._model_context.get_messages()  # pyright: ignore
+    agent3_model_ctx_messages = await agent3._model_context.get_messages()  # pyright: ignore
+    agent4_model_ctx_messages = await agent4._model_context.get_messages()  # pyright: ignore
+    assert agent3_model_ctx_messages == agent1_model_ctx_messages
+    assert agent4_model_ctx_messages == agent2_model_ctx_messages
     manager_1 = await team1._runtime.try_get_underlying_agent_instance(  # pyright: ignore
         AgentId("group_chat_manager", team1._team_id),  # pyright: ignore
         RoundRobinGroupChatManager,  # pyright: ignore
     )
@@ -337,7 +342,7 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
     assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"
 
     # Test streaming.
-    tool_use_agent._model_context.clear()  # pyright: ignore
+    await tool_use_agent._model_context.clear()  # pyright: ignore
     mock.reset()
     index = 0
     await team.reset()
@@ -351,7 +356,7 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
         index += 1
 
     # Test Console.
-    tool_use_agent._model_context.clear()  # pyright: ignore
+    await tool_use_agent._model_context.clear()  # pyright: ignore
     mock.reset()
     index = 0
     await team.reset()
@@ -579,8 +584,13 @@ async def test_selector_group_chat_state() -> None:
     await team2.load_state(state)
     state2 = await team2.save_state()
     assert state == state2
-    assert agent3._model_context == agent1._model_context  # pyright: ignore
-    assert agent4._model_context == agent2._model_context  # pyright: ignore
+
+    agent1_model_ctx_messages = await agent1._model_context.get_messages()  # pyright: ignore
+    agent2_model_ctx_messages = await agent2._model_context.get_messages()  # pyright: ignore
+    agent3_model_ctx_messages = await agent3._model_context.get_messages()  # pyright: ignore
+    agent4_model_ctx_messages = await agent4._model_context.get_messages()  # pyright: ignore
+    assert agent3_model_ctx_messages == agent1_model_ctx_messages
+    assert agent4_model_ctx_messages == agent2_model_ctx_messages
     manager_1 = await team1._runtime.try_get_underlying_agent_instance(  # pyright: ignore
         AgentId("group_chat_manager", team1._team_id),  # pyright: ignore
         SelectorGroupChatManager,  # pyright: ignore
     )
@@ -931,7 +941,7 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
     assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"
 
     # Test streaming.
-    agent1._model_context.clear()  # pyright: ignore
+    await agent1._model_context.clear()  # pyright: ignore
     mock.reset()
     index = 0
     await team.reset()
@@ -944,7 +954,7 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
         index += 1
 
     # Test Console
-    agent1._model_context.clear()  # pyright: ignore
+    await agent1._model_context.clear()  # pyright: ignore
     mock.reset()
     index = 0
     await team.reset()
diff --git a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/model-clients.ipynb b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/model-clients.ipynb
index 73998b02e547..cadd0466ab0c 100644
--- a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/model-clients.ipynb
+++ b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/model-clients.ipynb
@@ -454,17 +454,17 @@
     "\n",
     "The above `SimpleAgent` always responds with a fresh context that contains only\n",
     "the system message and the latest user's message.\n",
-    "We can use model context classes from {py:mod}`autogen_core.components.model_context`\n",
+    "We can use model context classes from {py:mod}`autogen_core.model_context`\n",
     "to make the agent \"remember\" previous conversations.\n",
     "A model context supports storage and retrieval of Chat Completion messages.\n",
     "It is always used together with a model client to generate LLM-based responses.\n",
     "\n",
-    "For example, {py:mod}`~autogen_core.components.model_context.BufferedChatCompletionContext`\n",
-    "is a most-recent-used (MRU) context that stores the most recent `buffer_size`\n",
+    "For example, {py:mod}`~autogen_core.model_context.BufferedChatCompletionContext`\n",
+    "is a most-recently-used (MRU) context that stores the most recent `buffer_size`\n",
     "number of messages. This is useful to avoid context overflow in many LLMs.\n",
     "\n",
     "Let's update the previous example to use\n",
-    "{py:mod}`~autogen_core.components.model_context.BufferedChatCompletionContext`."
+    "{py:mod}`~autogen_core.model_context.BufferedChatCompletionContext`."
    ]
   },
   {
@@ -473,7 +473,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from autogen_core.components.model_context import BufferedChatCompletionContext\n",
+    "from autogen_core.model_context import BufferedChatCompletionContext\n",
     "from autogen_core.models import AssistantMessage\n",
     "\n",
     "\n",
@@ -615,7 +615,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.5"
+   "version": "3.12.7"
   }
  },
  "nbformat": 4,
diff --git a/python/packages/autogen-core/samples/common/agents/_chat_completion_agent.py b/python/packages/autogen-core/samples/common/agents/_chat_completion_agent.py
index 246861cb6da8..538175ef4ce3 100644
--- a/python/packages/autogen-core/samples/common/agents/_chat_completion_agent.py
+++ b/python/packages/autogen-core/samples/common/agents/_chat_completion_agent.py
@@ -254,10 +254,10 @@ async def _execute_function(
 
     async def save_state(self) -> Mapping[str, Any]:
         return {
-            "memory": self._model_context.save_state(),
+            "chat_history": await self._model_context.save_state(),
             "system_messages": self._system_messages,
         }
 
     async def load_state(self, state: Mapping[str, Any]) -> None:
-        self._model_context.load_state(state["memory"])
+        await self._model_context.load_state(state["chat_history"])
         self._system_messages = state["system_messages"]
diff --git a/python/packages/autogen-core/samples/common/patterns/_group_chat_manager.py b/python/packages/autogen-core/samples/common/patterns/_group_chat_manager.py
index e8a940beef09..f39e354c9d48 100644
--- a/python/packages/autogen-core/samples/common/patterns/_group_chat_manager.py
+++ b/python/packages/autogen-core/samples/common/patterns/_group_chat_manager.py
@@ -143,10 +143,12 @@ async def on_new_message(self, message: TextMessage | MultiModalMessage, ctx: Me
 
     async def save_state(self) -> Mapping[str, Any]:
         return {
-            "chat_history": self._model_context.save_state(),
+            "chat_history": await self._model_context.save_state(),
             "termination_word": self._termination_word,
         }
 
     async def load_state(self, state: Mapping[str, Any]) -> None:
-        self._model_context.load_state(state["chat_history"])
+        # Load the chat history.
+        await self._model_context.load_state(state["chat_history"])
+        # Load the termination word.
         self._termination_word = state["termination_word"]
diff --git a/python/packages/autogen-core/samples/slow_human_in_loop.py b/python/packages/autogen-core/samples/slow_human_in_loop.py
index 9c4476d06b5c..61ea36fda890 100644
--- a/python/packages/autogen-core/samples/slow_human_in_loop.py
+++ b/python/packages/autogen-core/samples/slow_human_in_loop.py
@@ -114,7 +114,7 @@ async def save_state(self) -> Mapping[str, Any]:
         return state_to_save
 
     async def load_state(self, state: Mapping[str, Any]) -> None:
-        self._model_context.load_state({**state["memory"], "messages": [m for m in state["memory"]["messages"]]})
+        await self._model_context.load_state(state["memory"])
 
 
 class ScheduleMeetingInput(BaseModel):
@@ -200,11 +200,11 @@ async def handle_message(self, message: UserTextMessage, ctx: MessageContext) ->
 
     async def save_state(self) -> Mapping[str, Any]:
         return {
-            "memory": self._model_context.save_state(),
+            "memory": await self._model_context.save_state(),
         }
 
     async def load_state(self, state: Mapping[str, Any]) -> None:
-        self._model_context.load_state({**state["memory"], "messages": [m for m in state["memory"]["messages"]]})
+        await self._model_context.load_state(state["memory"])
 
 
 class NeedsUserInputHandler(DefaultInterventionHandler):
diff --git a/python/packages/autogen-core/src/autogen_core/model_context/__init__.py b/python/packages/autogen-core/src/autogen_core/model_context/__init__.py
index 8431a2e80dfc..0c8c7af5cf08 100644
--- a/python/packages/autogen-core/src/autogen_core/model_context/__init__.py
+++ b/python/packages/autogen-core/src/autogen_core/model_context/__init__.py
@@ -1,9 +1,14 @@
 from ._buffered_chat_completion_context import BufferedChatCompletionContext
-from ._chat_completion_context import ChatCompletionContext
+from ._chat_completion_context import ChatCompletionContext, ChatCompletionContextState
 from ._head_and_tail_chat_completion_context import HeadAndTailChatCompletionContext
+from ._unbounded_chat_completion_context import (
+    UnboundedChatCompletionContext,
+)
 
 __all__ = [
     "ChatCompletionContext",
+    "ChatCompletionContextState",
+    "UnboundedChatCompletionContext",
     "BufferedChatCompletionContext",
     "HeadAndTailChatCompletionContext",
 ]
diff --git a/python/packages/autogen-core/src/autogen_core/model_context/_buffered_chat_completion_context.py b/python/packages/autogen-core/src/autogen_core/model_context/_buffered_chat_completion_context.py
index 15b634fad6ac..f66197246e91 100644
--- a/python/packages/autogen-core/src/autogen_core/model_context/_buffered_chat_completion_context.py
+++ b/python/packages/autogen-core/src/autogen_core/model_context/_buffered_chat_completion_context.py
@@ -1,4 +1,4 @@
-from typing import Any, List, Mapping
+from typing import List
 
 from ..models import FunctionExecutionResultMessage, LLMMessage
 from ._chat_completion_context import ChatCompletionContext
@@ -10,17 +10,15 @@ class BufferedChatCompletionContext(ChatCompletionContext):
 
     Args:
         buffer_size (int): The size of the buffer.
-
+        initial_messages (List[LLMMessage] | None): The initial messages.
""" def __init__(self, buffer_size: int, initial_messages: List[LLMMessage] | None = None) -> None: - self._messages: List[LLMMessage] = initial_messages or [] + super().__init__(initial_messages) + if buffer_size <= 0: + raise ValueError("buffer_size must be greater than 0.") self._buffer_size = buffer_size - async def add_message(self, message: LLMMessage) -> None: - """Add a message to the memory.""" - self._messages.append(message) - async def get_messages(self) -> List[LLMMessage]: """Get at most `buffer_size` recent messages.""" messages = self._messages[-self._buffer_size :] @@ -29,17 +27,3 @@ async def get_messages(self) -> List[LLMMessage]: # Remove the first message from the list. messages = messages[1:] return messages - - async def clear(self) -> None: - """Clear the message memory.""" - self._messages = [] - - def save_state(self) -> Mapping[str, Any]: - return { - "messages": [message for message in self._messages], - "buffer_size": self._buffer_size, - } - - def load_state(self, state: Mapping[str, Any]) -> None: - self._messages = state["messages"] - self._buffer_size = state["buffer_size"] diff --git a/python/packages/autogen-core/src/autogen_core/model_context/_chat_completion_context.py b/python/packages/autogen-core/src/autogen_core/model_context/_chat_completion_context.py index f6cf08c4ba6b..33b1dac7fa18 100644 --- a/python/packages/autogen-core/src/autogen_core/model_context/_chat_completion_context.py +++ b/python/packages/autogen-core/src/autogen_core/model_context/_chat_completion_context.py @@ -1,19 +1,40 @@ -from typing import List, Mapping, Protocol +from abc import ABC, abstractmethod +from typing import Any, List, Mapping + +from pydantic import BaseModel, Field from ..models import LLMMessage -class ChatCompletionContext(Protocol): - """A protocol for defining the interface of a chat completion context. +class ChatCompletionContext(ABC): + """An abstract base class for defining the interface of a chat completion context. A chat completion context lets agents store and retrieve LLM messages. - It can be implemented with different recall strategies.""" + It can be implemented with different recall strategies. + + Args: + initial_messages (List[LLMMessage] | None): The initial messages. + """ + + def __init__(self, initial_messages: List[LLMMessage] | None = None) -> None: + self._messages: List[LLMMessage] = initial_messages or [] - async def add_message(self, message: LLMMessage) -> None: ... + async def add_message(self, message: LLMMessage) -> None: + """Add a message to the context.""" + self._messages.append(message) + @abstractmethod async def get_messages(self) -> List[LLMMessage]: ... - async def clear(self) -> None: ... + async def clear(self) -> None: + """Clear the context.""" + self._messages = [] + + async def save_state(self) -> Mapping[str, Any]: + return ChatCompletionContextState(messages=self._messages).model_dump() + + async def load_state(self, state: Mapping[str, Any]) -> None: + self._messages = ChatCompletionContextState.model_validate(state).messages - def save_state(self) -> Mapping[str, LLMMessage]: ... - def load_state(self, state: Mapping[str, LLMMessage]) -> None: ... 
+
+
+class ChatCompletionContextState(BaseModel):
+    messages: List[LLMMessage] = Field(default_factory=list)
diff --git a/python/packages/autogen-core/src/autogen_core/model_context/_head_and_tail_chat_completion_context.py b/python/packages/autogen-core/src/autogen_core/model_context/_head_and_tail_chat_completion_context.py
index ab50df41626e..2518f456b632 100644
--- a/python/packages/autogen-core/src/autogen_core/model_context/_head_and_tail_chat_completion_context.py
+++ b/python/packages/autogen-core/src/autogen_core/model_context/_head_and_tail_chat_completion_context.py
@@ -1,4 +1,4 @@
-from typing import Any, List, Mapping
+from typing import List
 
 from .._types import FunctionCall
 from ..models import AssistantMessage, FunctionExecutionResultMessage, LLMMessage, UserMessage
@@ -13,17 +13,18 @@ class HeadAndTailChatCompletionContext(ChatCompletionContext):
     Args:
         head_size (int): The size of the head.
         tail_size (int): The size of the tail.
+        initial_messages (List[LLMMessage] | None): The initial messages.
     """
 
-    def __init__(self, head_size: int, tail_size: int) -> None:
-        self._messages: List[LLMMessage] = []
+    def __init__(self, head_size: int, tail_size: int, initial_messages: List[LLMMessage] | None = None) -> None:
+        super().__init__(initial_messages)
+        if head_size <= 0:
+            raise ValueError("head_size must be greater than 0.")
+        if tail_size <= 0:
+            raise ValueError("tail_size must be greater than 0.")
         self._head_size = head_size
         self._tail_size = tail_size
 
-    async def add_message(self, message: LLMMessage) -> None:
-        """Add a message to the memory."""
-        self._messages.append(message)
-
     async def get_messages(self) -> List[LLMMessage]:
-        """Get at most `head_size` recent messages and `tail_size` oldest messages."""
+        """Get at most `head_size` oldest messages and `tail_size` recent messages."""
         head_messages = self._messages[: self._head_size]
@@ -51,21 +52,3 @@ async def get_messages(self) -> List[LLMMessage]:
             placeholder_messages = [UserMessage(content=f"Skipped {num_skipped} messages.", source="System")]
 
         return head_messages + placeholder_messages + tail_messages
-
-    async def clear(self) -> None:
-        """Clear the message memory."""
-        self._messages = []
-
-    def save_state(self) -> Mapping[str, Any]:
-        return {
-            "messages": [message for message in self._messages],
-            "head_size": self._head_size,
-            "tail_size": self._tail_size,
-            "placeholder_message": self._placeholder_message,
-        }
-
-    def load_state(self, state: Mapping[str, Any]) -> None:
-        self._messages = state["messages"]
-        self._head_size = state["head_size"]
-        self._tail_size = state["tail_size"]
-        self._placeholder_message = state["placeholder_message"]
diff --git a/python/packages/autogen-core/src/autogen_core/model_context/_unbounded_chat_completion_context.py b/python/packages/autogen-core/src/autogen_core/model_context/_unbounded_chat_completion_context.py
new file mode 100644
index 000000000000..dff45bfc92d8
--- /dev/null
+++ b/python/packages/autogen-core/src/autogen_core/model_context/_unbounded_chat_completion_context.py
@@ -0,0 +1,12 @@
+from typing import List
+
+from ..models import LLMMessage
+from ._chat_completion_context import ChatCompletionContext
+
+
+class UnboundedChatCompletionContext(ChatCompletionContext):
+    """An unbounded chat completion context that keeps a view of all the messages."""
+
+    async def get_messages(self) -> List[LLMMessage]:
+        """Get all the messages in the context."""
+        return self._messages
diff --git a/python/packages/autogen-core/tests/test_model_context.py b/python/packages/autogen-core/tests/test_model_context.py
index 2fd71574ef87..46f4b6319370 100644
--- a/python/packages/autogen-core/tests/test_model_context.py
+++ b/python/packages/autogen-core/tests/test_model_context.py
@@ -1,7 +1,11 @@
 from typing import List
 
 import pytest
-from autogen_core.model_context import BufferedChatCompletionContext, HeadAndTailChatCompletionContext
+from autogen_core.model_context import (
+    BufferedChatCompletionContext,
+    HeadAndTailChatCompletionContext,
+    UnboundedChatCompletionContext,
+)
 from autogen_core.models import AssistantMessage, LLMMessage, UserMessage
 
 
@@ -26,6 +30,17 @@ async def test_buffered_model_context() -> None:
     retrieved = await model_context.get_messages()
     assert len(retrieved) == 0
 
+    # Test saving and loading state.
+    await model_context.add_message(messages[0])
+    await model_context.add_message(messages[1])
+    state = await model_context.save_state()
+    await model_context.clear()
+    await model_context.load_state(state)
+    retrieved = await model_context.get_messages()
+    assert len(retrieved) == 2
+    assert retrieved[0] == messages[0]
+    assert retrieved[1] == messages[1]
+
 
 @pytest.mark.asyncio
 async def test_head_and_tail_model_context() -> None:
@@ -48,3 +63,44 @@ async def test_head_and_tail_model_context() -> None:
     await model_context.clear()
     retrieved = await model_context.get_messages()
     assert len(retrieved) == 0
+
+    # Test saving and loading state.
+    for msg in messages:
+        await model_context.add_message(msg)
+    state = await model_context.save_state()
+    await model_context.clear()
+    await model_context.load_state(state)
+    retrieved = await model_context.get_messages()
+    assert len(retrieved) == 3
+    assert retrieved[0] == messages[0]
+    assert retrieved[2] == messages[-1]
+
+
+@pytest.mark.asyncio
+async def test_unbounded_model_context() -> None:
+    model_context = UnboundedChatCompletionContext()
+    messages: List[LLMMessage] = [
+        UserMessage(content="Hello!", source="user"),
+        AssistantMessage(content="What can I do for you?", source="assistant"),
+        UserMessage(content="Tell me some fun things to do in Seattle.", source="user"),
+    ]
+    for msg in messages:
+        await model_context.add_message(msg)
+
+    retrieved = await model_context.get_messages()
+    assert len(retrieved) == 3
+    assert retrieved == messages
+
+    await model_context.clear()
+    retrieved = await model_context.get_messages()
+    assert len(retrieved) == 0
+
+    # Test saving and loading state.
+    for msg in messages:
+        await model_context.add_message(msg)
+    state = await model_context.save_state()
+    await model_context.clear()
+    await model_context.load_state(state)
+    retrieved = await model_context.get_messages()
+    assert len(retrieved) == 3
+    assert retrieved == messages
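
Usage sketch for the new `model_context` parameter on `AssistantAgent`. This is a minimal, hypothetical example rather than part of the change itself: the OpenAI client import path and the model name are assumptions that may differ by release.

```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import TextMessage
from autogen_core import CancellationToken
from autogen_core.model_context import BufferedChatCompletionContext
from autogen_ext.models.openai import OpenAIChatCompletionClient  # assumed import path


async def main() -> None:
    # Bound the model's view of the conversation to the 5 most recent messages.
    agent = AssistantAgent(
        name="assistant",
        model_client=OpenAIChatCompletionClient(model="gpt-4o"),  # assumed model name
        model_context=BufferedChatCompletionContext(buffer_size=5),
    )
    response = await agent.on_messages(
        [TextMessage(content="What is the capital of France?", source="user")],
        CancellationToken(),
    )
    print(response.chat_message.content)


asyncio.run(main())
```

When `model_context` is omitted, the agent falls back to `UnboundedChatCompletionContext`, which matches the previous behavior of keeping every message.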
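Because `ChatCompletionContext` is now an ABC rather than a `Protocol`, a custom recall strategy only needs to override `get_messages()`; `add_message()`, `clear()`, `save_state()`, and `load_state()` are inherited from the base class. A sketch under that assumption — the class below is hypothetical, not part of the library:

```python
from typing import List

from autogen_core.model_context import ChatCompletionContext
from autogen_core.models import AssistantMessage, LLMMessage, UserMessage


class UserAssistantOnlyChatCompletionContext(ChatCompletionContext):
    """Hypothetical context whose view hides tool-call plumbing: only user and
    assistant messages are exposed to the model."""

    async def get_messages(self) -> List[LLMMessage]:
        # self._messages is maintained by the inherited add_message()/load_state().
        return [m for m in self._messages if isinstance(m, (UserMessage, AssistantMessage))]
```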
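The save/load contract exercised by the new tests: `save_state()` returns a dump of `ChatCompletionContextState`, i.e. a mapping of the form `{"messages": [...]}`, and `load_state()` validates it back into `LLMMessage` objects. A round-trip sketch:

```python
import asyncio

from autogen_core.model_context import UnboundedChatCompletionContext
from autogen_core.models import UserMessage


async def main() -> None:
    ctx = UnboundedChatCompletionContext()
    await ctx.add_message(UserMessage(content="Hello!", source="user"))

    state = await ctx.save_state()  # e.g. {"messages": [{"content": "Hello!", "source": "user", ...}]}
    await ctx.clear()
    assert await ctx.get_messages() == []

    await ctx.load_state(state)
    assert len(await ctx.get_messages()) == 1


asyncio.run(main())
```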