Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Load and Save state in AgentChat #4436

Merged
merged 19 commits into from
Dec 5, 2024
Merged

Load and Save state in AgentChat #4436

merged 19 commits into from
Dec 5, 2024

Conversation

victordibia
Copy link
Collaborator

@victordibia victordibia commented Nov 30, 2024

Loading and Saving State in AgentChat

This PR adds mechanisms to load/save state in AgentChat.

  • Load/Save Agents. Not all agents have state to be saved. Currently only AssistantAgent really keeps state - _model_context. Load and save serves to populate that variable. Current changes adds load_state and save_state ChatAgent (abs methods) and stub for BaseChatAgent. Introduces a BaseState and AssistantAgentState class
assistant_agent = AssistantAgent(
    name="assistant_agent",
    system_message="You are a helpful assistant",
    model_client=OpenAIChatCompletionClient(
        model="gpt-4o-2024-08-06",
    ),
)

result = await assistant_agent.run(task="Write a 3 line poem on lake tangayika")
agent_state = await assistant_agent.save_state() 

load state in a new agent

await new_assistant_agent.load_state(agent_state)
result = await new_assistant_agent.run(task="What was the last line of the previous poem you wrote")
The last line of the poem is:  
"Infinite horizon, nature's grand encore."
  • Load/Save Teams.
    Currently implemented on the BaseGroupChat class with load_state and save_state methods. Introduces a BaseTeamState
    • agent_names: List[str] = field(default_factory=list)
    • termination_state: Optional[BaseState] = field(default=None)
    • agent_states: Dict[str, BaseState] = field(default_factory=dict)
    • manager_state: BaseGroupChatManagerState = field( default_factory=BaseGroupChatManagerState)
    • state_type: str = field(default="BaseTeamState")

load_state will load each participant agent state in agent_states and termination_state. All teams that inherit from BaseGroupChat should benefit from this OOTB

Note

What this does not do:

  • Does not save/load GroupChatManager state yet: GroupChats have GroupChatManager classes that have state but are managed by the runtime. Getting and setting their state e.g., current turn etc is more involved. This implementation does not do this yet, but allows for it in the manager_state member of BaseTeamState.
    This can have some effects - e.g, after you load a team state, it might not know the last speaker to continue from and will begin from the first e.g. in RoundRobinGroupChat
  • Custom Agents have to implement their load/save methods
assistant_agent = AssistantAgent(
    name="assistant_agent",
    system_message="You are a helpful assistant",
    model_client=OpenAIChatCompletionClient(
        model="gpt-4o-2024-08-06",
    ),
)
# Define a team
agent_team = RoundRobinGroupChat([assistant_agent], termination_condition=MaxMessageTermination(max_messages=2))

# Run the team and stream messages to the console
stream = agent_team.run_stream(task="Write a beautiful poem 3-line about lake tangayika")
await Console(stream)
team_state = await agent_team.save_state()
from autogen_agentchat.state import BaseTeamState

# convert team state to dict (save to file if needed)
team_state_dict = vars(team_state)
print(team_state_dict)      

# convert back to team state object
team_state = BaseTeamState(**vars(team_state)) 

# load team state
await agent_team.load_state(team_state)
stream = agent_team.run_stream(task="What was the last line of the poem you wrote?")
await Console(stream)
  • Save / Load Termination
max_termination = MaxMessageTermination(max_messages=2)
text_termination = TextMentionTermination(text="stop")
termination = text_termination | max_termination
termination_state = await termination.save_state()
print(termination_state)
await termination.load_state(termination_state)

Why are these changes needed?

Related issue number

Closes #4100

Checks

@victordibia victordibia marked this pull request as ready for review November 30, 2024 04:55
@ekzhu
Copy link
Collaborator

ekzhu commented Nov 30, 2024

I believe right now for team you can serialize the SingleThreadedAgentRuntime by calling the runtime's save_state and load_state. This requires the BaseGroupChatManager and ChatAgentContainer to implement their save_state and load_state methods. For the latter just call the contained ChatAgent's corresponding methods.

This will obviously change once we allow user to customize the runtime the team uses, and a runtime can be shared across multiple teams. However, this will serve as the driving force for that design.

@victordibia
Copy link
Collaborator Author

I believe right now for team you can serialize the SingleThreadedAgentRuntime by calling the runtime's save_state and load_state
. For the latter just call the contained ChatAgent's corresponding methods.

Not sure I understand what you mean above? Can you add some sample code to show what is meant by this?
Also, how does that add up to a developer experience of being able to do team.load_state() and team.save_state()?

@ekzhu
Copy link
Collaborator

ekzhu commented Nov 30, 2024

Each group chat team carries its own runtime currently:

The SingleThreadedAgentRuntime supports save_state and load_state for all registered agents:

async def save_state(self) -> Mapping[str, Any]:
state: Dict[str, Dict[str, Any]] = {}

So, from BaseGroupChat, you can implement save_state by calling the self._runtime.save_state() to get all the agent's states, including the group chat manager's.

This requires implementing the save_state and load_state of Core agents:

async def save_state(self) -> Mapping[str, Any]:
"""Save the state of the agent. The result must be JSON serializable."""

So this means the ChatAgentContainer and BaseGroupChatManager should implement these methods.

Besides all the agent's state, the team itself also carries some state including the team's id, which is generated in the constructor. This team id should be saved and loaded because it forms the reference to the Core agents in the runtime -- all Core agents' AgentID's key field is the team id.

class BaseGroupChat(Team, ABC):
  ...
  async def save_state(self) -> Mapping[str, Any]:
    if not self._initialized:
        return {"team_id": self._team_id} # If not initialized the only state we have is team id
    if self._is_running:
      raise RuntimeError("The group chat is currently running. It must be stopped before calling save_state()")
    self._is_running = True # Prevent running the team while saving states.
    runtime_state = await self._runtime.save_state()
    self._is_running = False
    return {"runtime_state": runtime_state, "team_id": self._team_id}

There might be other stuff like self._stop_reason etc., but should be trivial to handle.

@victordibia
Copy link
Collaborator Author

victordibia commented Nov 30, 2024

Ah, got it. This PR certainly goes in a diff/wrong direction. Do you want to take a stab at the approach you describe above which seems a lot more straightforward (it is still way too under specified for me to explore atm, similar to the original issue).
I’ll close this for now in favor of that PR which I’ll be happy to review.

@victordibia
Copy link
Collaborator Author

Thanks Eric, I tested, looks good.

@ekzhu ekzhu merged commit 777f2ab into main Dec 5, 2024
44 checks passed
@ekzhu ekzhu deleted the save_load_state_vd branch December 5, 2024 00:14
gagb pushed a commit that referenced this pull request Dec 5, 2024
1. convert dataclass types to pydantic basemodel
2. add save_state and load_state for ChatAgent
3. state types for AgentChat
---------

Co-authored-by: Eric Zhu <[email protected]>
rysweet pushed a commit that referenced this pull request Dec 10, 2024
1. convert dataclass types to pydantic basemodel 
2. add save_state and load_state for ChatAgent
3. state types for AgentChat
---------

Co-authored-by: Eric Zhu <[email protected]>
@lsy641
Copy link

lsy641 commented Dec 10, 2024

Hello. How to use the team save_state feature? Currently, I am not able to use it. I got "AttributeError: 'RoundRobinGroupChat' object has no attribute 'save_state'" in 0.4.0 dev

@ekzhu
Copy link
Collaborator

ekzhu commented Dec 10, 2024

@lsy641 please upgrade to the latest dev release (dev11 as we speak)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

AgentChat agents and teams can be rolled back to snapshots of states
3 participants