diff --git a/src/agents/research_assistant.py b/src/agents/research_assistant.py index 246d8f4..6a4a149 100644 --- a/src/agents/research_assistant.py +++ b/src/agents/research_assistant.py @@ -7,7 +7,7 @@ from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnableSerializable from langgraph.checkpoint.memory import MemorySaver from langgraph.graph import END, MessagesState, StateGraph -from langgraph.managed import IsLastStep +from langgraph.managed import RemainingSteps from langgraph.prebuilt import ToolNode from agents.llama_guard import LlamaGuard, LlamaGuardOutput, SafetyAssessment @@ -22,7 +22,7 @@ class AgentState(MessagesState, total=False): """ safety: LlamaGuardOutput - is_last_step: IsLastStep + remaining_steps: RemainingSteps web_search = DuckDuckGoSearchResults(name="WebSearch") @@ -75,7 +75,7 @@ async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState: if safety_output.safety_assessment == SafetyAssessment.UNSAFE: return {"messages": [format_safety_message(safety_output)], "safety": safety_output} - if state["is_last_step"] and response.tool_calls: + if state["remaining_steps"] < 2 and response.tool_calls: return { "messages": [ AIMessage(