Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Subgraph state is not inserted into the persistence DB. #2142

Open
5 tasks done
jhachirag7 opened this issue Oct 19, 2024 · 3 comments
Open
5 tasks done

Subgraph state is not inserted into the persistence DB. #2142

jhachirag7 opened this issue Oct 19, 2024 · 3 comments

Comments

@jhachirag7
Copy link

jhachirag7 commented Oct 19, 2024

Checked other resources

  • I added a very descriptive title to this issue.
  • I searched the LangGraph/LangChain documentation with the integrated search.
  • I used the GitHub search to find a similar question and didn't find it.
  • I am sure that this is a bug in LangGraph/LangChain rather than my code.
  • I am sure this is better as an issue rather than a GitHub discussion, since this is a LangGraph bug and not a design question.

Example Code

Subgraph class:


import functools
from typing import Annotated, Sequence, TypedDict, Literal
from common.schema import  ResearchSchema
from langchain_core.messages import (
    BaseMessage,
    HumanMessage,
    AIMessage
)
from common.constants import RESEARCH_COLLECTION
from common.firestore_db import update_research_chat, get_doc
from langgraph.checkpoint.postgres import PostgresSaver
from psycopg_pool import ConnectionPool
import os
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
from langgraph.graph import END, StateGraph, START, add_messages, MessagesState
from langgraph.prebuilt import ToolNode
from common.utils import convert_message_to_dict
from tools import ToolSaveHandler
from uuid import uuid4
import ast
from common.firestore_db import update_multiagent_research_chat


class NeoState(MessagesState):
    """Graph state for the generator/reviewer loop.

    Extends ``MessagesState`` (the ``messages`` list) with the node that wrote
    the last update and the remaining review budget.
    """
    sender: str  # "generator" or "reviewer"; the call_tool edge routes back to this node
    limit:int  # remaining review rounds; the reviewer decrements it, and the router ends after the generator once it is 0

class Agent:
    """One LLM participant (generator or reviewer) in the NeoleadsWorkflow loop.

    The ``prompt | llm.bind_tools(tools)`` chain is built once in ``__init__``
    and stored in ``self.agent``; ``agent_node`` is the LangGraph node callable
    that runs it against the graph state.
    """

    def __init__(self, llm, tools, inputs, history, feedback_limit, user_query, system_message: str, simple=False, agent_type=True):
        """Store configuration and build the chain for this agent's role.

        Args:
            llm: Chat model; must support ``bind_tools``.
            tools: Tools bound to the model; their ``.name`` values are exposed
                to the prompt as the ``tool_names`` partial.
            inputs: Template inputs (stored; not read by this class).
            history: Prior chat messages fed to the ``history`` placeholder
                (generator prompts only).
            feedback_limit: Review budget used when the state has no ``limit``.
            user_query: Original user query, repeated to the generator after
                every review.
            system_message: Role instructions inserted into the system prompt.
            simple: Generator only; use ``system_message`` itself as the
                system prompt instead of the default wrapper.
            agent_type: ``True`` builds the generator chain, ``False`` the
                reviewer chain.
        """
        self.llm = llm
        self.tools = tools
        self.system_message = system_message
        self.simple = simple
        self.inputs = inputs
        self.user_query = user_query
        self.history = history
        self.feedback_limit = feedback_limit
        # Build only the chain this agent actually uses (previously the reviewer
        # chain was always built and then discarded for generators).
        if agent_type:
            self.agent = self.create_generator_agent()
        else:
            self.agent = self.create_reviewer_agent()

    def create_reviewer_agent(self):
        """Build the reviewer chain: critique-oriented system prompt + conversation.

        ``system_message`` is bound as a partial because the template contains
        ``{system_message}`` and the graph state passed to ``invoke`` has no such
        key. Without it, the reviewer fails with a missing-variable error on its
        first call.
        """
        prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    """
1. Follow Instructions:
- Adhere to any specific instructions provided by the user, below are user instructions that are essential for producing the correct feedback or critique response.:
```{system_message}``` 
- To ensure the output aligns with the user's requirements. These guidelines are essential for producing the correct feedback or critique response.

Your primary goal is to ensure that the final answer is accurate, well-supported, and complete. Collaborate efficiently with the other assistants, provide valuable critiques, and guide the team toward the best possible outcome.
                    """,
                ),

                MessagesPlaceholder(variable_name="messages"),
            ]
        )
        prompt = prompt.partial(
            system_message=self.system_message,
            tool_names=", ".join(tool.name for tool in self.tools),
        )
        return prompt | self.llm.bind_tools(self.tools)

    def create_generator_agent(self):
        """Build the generator chain: system prompt + chat history + conversation.

        In ``simple`` mode the system prompt is ``self.system_message`` itself.
        NOTE(review): it is parsed as an f-string template there, so literal
        braces in it would be treated as variables. Otherwise the message is
        embedded in the default generator instructions via ``{system_message}``.
        """
        if self.simple:
            system_template = self.system_message
        else:
            system_template = """


**1. Follow Instructions:**
- Adhere to any specific instructions provided by the user, below are user instructions that are essential for producing the correct response.:
```{system_message}``` 
- To ensure the output aligns with the user's requirements. These guidelines are essential for producing the correct response.

Your sole purpose is to generate effective, accurate, and well-constructed responses based on the user's query and always do the work according to the instructions, whether using tools or relying on knowledge. Stay within your role, continue generating content, and contribute toward the final solution.
                        """
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_template),
                MessagesPlaceholder(variable_name="history"),
                MessagesPlaceholder(variable_name="messages"),
            ]
        )
        prompt = prompt.partial(
            history=self.history,
            system_message=self.system_message,
            tool_names=", ".join(tool.name for tool in self.tools),
        )
        return prompt | self.llm.bind_tools(self.tools)

    def agent_node(self, state:NeoState, name):
        """Run this agent's chain on ``state`` and return the state update.

        Args:
            state: Current graph state; ``limit`` (if present) is the remaining
                review budget, otherwise ``self.feedback_limit`` is used.
            name: Node name, ``"generator"`` or ``"reviewer"``.

        Returns:
            A dict with ``messages``, ``sender`` and ``limit``.

            - If the model requested tools, only its raw message is returned so
              ``call_tool`` can execute it.
            - Otherwise the reply is re-tagged with ``name`` and followed by a
              HumanMessage hand-off prompt for the other agent. The hand-off
              carries an extra ``limit`` attribute (the pre-decrement budget),
              which the workflow router reads.
            - The reviewer additionally decrements ``limit`` by one.
        """
        result = self.agent.invoke(state)
        cnt = state['limit'] if 'limit' in state else self.feedback_limit

        if result.tool_calls:
            # Leave the tool-calling message untouched so ToolNode can run it.
            return {"messages": [result], "sender": name, "limit": cnt}

        result = AIMessage(**result.dict(exclude={"type", "name"}), name=name)
        # Hand-off message: same payload as the reply, but as a HumanMessage with
        # its own id (so it does not share the reply's id) and the budget attached.
        post_fill = HumanMessage(**result.dict(exclude={"type", "name"}), name=name, limit=cnt)
        post_fill.id = str(uuid4())
        if name == 'reviewer':
            post_fill.content = f"""Generator has to generate and compile the user query by incorporating feedback.
Given any feedback you need to always generate the user query response.
original user query:\n {self.user_query}"""
            cnt = cnt - 1
        else:
            post_fill.content = "Reviewer has to review the responses generated by Generator agent and provide feedback."
        return {
            "messages": [result, post_fill],
            "sender": name,
            "limit": cnt,
        }

class NeoleadsWorkflow:
    """Generator/reviewer collaboration compiled as a LangGraph graph.

    ``generator`` drafts an answer and ``reviewer`` critiques it; they alternate
    until the generator runs with a remaining budget of 0. Tool calls go to
    ``call_tool`` and then back to the sender. Calls to tools named in
    ``approval_tool_names`` go to ``human_review_node``, which the graph is
    interrupted before. The compiled graph is stored in ``self.workflow``.
    """

    def __init__(self, inputs, generator_system_message, reviewer_system_message, user_prompt, tools, llm, feedback_limit, chatHistory, config, chat, research, callbacks=None, approval_tool_names=None, simple=False):
        self.inputs = inputs
        self.generator_system_message = self.convert_prompt(generator_system_message, inputs, False)
        self.reviewer_system_message = self.convert_prompt(reviewer_system_message, inputs, False)
        self.simple = simple
        self.user_query = self.convert_prompt(user_prompt, inputs, False)
        self.user_prompt = self.convert_prompt(user_prompt, inputs=inputs)
        self.DB_URI = os.getenv('DB_URI')
        self.connection_kwargs = {
            "autocommit": True,
            "prepare_threshold": 0,
        }
        self.chat = chat
        self.research = research
        # None defaults replace shared mutable [] defaults; each instance gets its own list.
        self.approval_tool_names = approval_tool_names if approval_tool_names is not None else []
        self.tools = tools
        self.config = config
        self.callbacks = callbacks if callbacks is not None else []
        self.history = chatHistory
        self.snapshot = None
        self.feedback_limit = feedback_limit
        self.llm = llm
        self.intermediate_steps = []
        self.final_response = ""
        self.workflow = self.create_workflow()

    def convert_prompt(self, prompt, inputs, tick=True):
        """Render ``prompt`` as a PromptTemplate filled with ``inputs``.

        When ``tick`` is true, the prompt is first wrapped with the
        generator/reviewer role preamble and suffix.
        """
        if tick:
            prompt = "Generator will only generate or answer user query:\n" + prompt + "\n\n Reviewer will only review and give feedback:\n"
        template = PromptTemplate(template=prompt, input_variables=list(inputs))
        return template.invoke(inputs).text

    def human_review_node(self, state:NeoState):
        """No-op node; it exists only as the ``interrupt_before`` pause point for tool approval."""
        pass

    def route_after_human(self, state:NeoState):
        """After human review: run the tool if the last message is still the AI tool call.

        NOTE(review): the ``"continue"`` branch is returned as a raw destination
        (``create_workflow`` gives this edge no path map), but no node is named
        ``"continue"``. Map it to the intended node (e.g. the sender) before
        relying on the rejection path.
        """
        if isinstance(state["messages"][-1], AIMessage):
            return "call_tool"
        return "continue"

    def create_workflow(self):
        """Build and compile the generator/reviewer graph (interrupts before ``human_review_node``)."""
        generator_agent = Agent(self.llm, self.tools, self.inputs, self.history, self.feedback_limit, self.user_query, self.generator_system_message, self.simple)
        reviewer_agent = Agent(self.llm, self.tools, self.inputs, self.history, self.feedback_limit, self.user_query, self.reviewer_system_message, agent_type=False)
        tool_node = ToolNode(self.tools)

        workflow = StateGraph(NeoState)

        workflow.add_node("generator", functools.partial(generator_agent.agent_node, name="generator"))
        workflow.add_node("reviewer", functools.partial(reviewer_agent.agent_node, name="reviewer"))
        workflow.add_node("human_review_node", self.human_review_node)
        workflow.add_node("call_tool", tool_node)
        workflow.add_conditional_edges("human_review_node", self.route_after_human)
        workflow.add_conditional_edges(
            "generator",
            self.router,
            {"continue": "reviewer", "call_tool": "call_tool", "human_review_node": "human_review_node", "__end__": END},
        )
        workflow.add_conditional_edges(
            "reviewer",
            self.router,
            {"continue": "generator", "call_tool": "call_tool", "human_review_node": "human_review_node", "__end__": END},
        )
        # After a tool runs, hand control back to whichever agent requested it.
        workflow.add_conditional_edges(
            "call_tool",
            lambda x: x["sender"],
            {
                "generator": "generator",
                "reviewer": "reviewer",
            },
        )
        workflow.add_edge(START, "generator")

        return workflow.compile(interrupt_before=["human_review_node"])

    def before_tool(self, tool_calls):
        """Record the first pending tool call in the research chat metadata.

        Writes into ``self.research.chatHistory[-1].metadata['tools']`` and, if a
        chat is attached, persists via ``update_multiagent_research_chat``.
        Only ``tool_calls[0]`` is recorded.

        NOTE(review): if the last recorded entry still has ``approval`` set, it
        is flipped to False instead of appending the new call. Confirm this
        toggle is intended.
        """
        first_call = tool_calls[0]
        artifact = {
            "tool_input": str(first_call["args"]),
            "tool_name": first_call['name'],
            "file_name": "",
            "response_url": "",
            "approval": True
        }
        last_entry = self.research.chatHistory[-1]
        if not last_entry.metadata:
            last_entry.metadata = {}
        metadata = last_entry.metadata
        if 'tools' in metadata:
            recorded = metadata['tools']
            if len(recorded) > 0 and recorded[-1]['approval']:
                recorded[-1]['approval'] = False
            else:
                recorded.append(artifact)
        else:
            metadata['tools'] = [artifact]
        if self.chat:
            update_multiagent_research_chat(self.chat, self.research)

    def router(self, state:NeoState) -> Literal["call_tool", "human_review_node", "__end__", "continue"]:
        """Pick the next step after a generator/reviewer turn.

        - Tool calls needing approval go to ``human_review_node`` (recorded via ``before_tool``).
        - Other tool calls go to ``call_tool``.
        - The loop ends once the generator's hand-off message carries ``limit == 0``.
        - Otherwise control passes to the other agent.
        """
        last_message = state["messages"][-1]
        # getattr: plain messages may lack the extra tool_calls/limit attributes.
        tool_calls = getattr(last_message, "tool_calls", None)
        if tool_calls:
            if tool_calls[0]['name'] in self.approval_tool_names:
                self.before_tool(tool_calls)
                return "human_review_node"
            return "call_tool"
        if last_message.name == 'generator' and getattr(last_message, "limit", None) == 0:
            return "__end__"
        return "continue"
    ```

Multi agent Class:

```python
import functools
import operator
from typing import Sequence, TypedDict, Literal, List, Dict, Any
from langchain_core.messages import BaseMessage, HumanMessage,AIMessage
from langgraph.graph import END, StateGraph, START,add_messages,MessagesState
from langgraph.checkpoint.postgres import PostgresSaver
from langgraph.checkpoint.memory import MemorySaver
from typing import Annotated
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from common.schema import MultiAgentChatHistorySchema,MultiAgentResearchSchema
from pydantic import BaseModel
import os
from multiagent_api.src.langgraph_wrapper import NeoleadsWorkflow
from psycopg_pool import ConnectionPool
from langchain_core.messages import RemoveMessage
from uuid import uuid4

class AgentState(MessagesState):
    """Supervisor graph state: ``messages`` plus the supervisor's routing decision."""
    next: str  # worker name or "FINISH"; read by the supervisor's conditional edge
    steps: str  # supervisor's step plan; sent to the worker as a HumanMessage before it runs
# Structured-output schema for the supervisor LLM (used via with_structured_output).
# NOTE: deliberately no class docstring. It would likely be forwarded to the model
# as the schema description and change the prompt.
class RouteResponse(BaseModel):
    next: str  # chosen worker name, or "FINISH"
    steps: str  # step-by-step plan text for the chosen worker

class AgentInfo(BaseModel):
    """Name and system prompt of a registered worker, used to build the supervisor's agent summary."""
    name: str  # node name of the worker in the supervisor graph
    system_prompt: str  # the worker's generator system prompt as passed to add_agent
class DynamicMultiAgentWorkflow:
    """Supervisor-routed multi-agent workflow built on a LangGraph ``StateGraph``.

    A ``supervisor`` node asks the LLM (structured as ``RouteResponse``) which
    registered worker should act next, or ``FINISH``. Each worker is a
    ``NeoleadsWorkflow`` generator/reviewer graph wrapped by ``agent_node``.
    After every worker turn the graph is interrupted before ``human_feedback``;
    ``feedback_run`` resumes it with user feedback. Checkpoints are stored with
    ``PostgresSaver`` on ``DB_URI`` under ``self.config``.

    Register all agents with ``add_agent`` before the first ``run`` or
    ``feedback_run``: the supervisor edges are wired only once.
    """

    def __init__(self, llm, inputs, config, chathistory, callbacks=None, system_message=""):
        self.inputs = inputs
        self.llm = llm
        # None default replaces a shared mutable [] default.
        self.callbacks = callbacks if callbacks is not None else []
        self.DB_URI = os.getenv('DB_URI')
        self.connection_kwargs = {
            "autocommit": True,
            "prepare_threshold": 0,
        }
        self.system_message = system_message
        self.config = config
        self.chathistory = chathistory
        self.final_response = "No response yet"
        self.members = []
        # Becomes True once compile() has wired the edges; later calls are no-ops.
        self._edges_wired = False
        self.final_output_prompt = ChatPromptTemplate.from_messages([
            ("system", """ 
Your task is to generate a response or answer based on the user query, context of the conversation, agent responses, and any user feedback provided. Follow these instructions to ensure the response meets the user's expectations:

Instructions:
 1. Analyze the User Query:

   - Start by understanding the user's query clearly. Identify the key points and goals the user wants to achieve.

 2. Review the Context:

   - Consider the conversation history to understand the progress made so far.
   - Incorporate all relevant details from previous agent responses that contribute to solving the user's query.

 3. Integrate Agent Responses:

   - Use information and data from agent responses that have been provided so far in the conversation.
   - If multiple agent responses exist, synthesize the information to ensure a coherent and comprehensive answer.

 4. Consider User Feedback:

   - If the user has provided any feedback during the process, adjust the answer accordingly.
   - The feedback should guide any changes in the response structure or content to better meet the user's expectations.

 5. Generate the Response:

   - Formulate a complete answer based on the information gathered from the user query, context, agent responses, and feedback.
   - Ensure that the answer directly addresses the user's needs, providing a clear and actionable resolution
            
Output Format:
 - Clear and eloborate answer that meets the user's expectations.

Note:
 - Always prioritize user feedback to tailor the final response to their expectations.
 - Ensure clarity and conciseness in the final response.
             """),
            MessagesPlaceholder(variable_name="messages"),
        ])
        self.agents: List[AgentInfo] = []
        self.workflow = StateGraph(AgentState)
        self.supervisor_prompt = ChatPromptTemplate.from_messages([
            ("system", """
As a supervisor, your role is to manage a conversation between multiple AI workers, each with unique capabilities, to resolve the user's query efficiently. Based on the user's query, conversation history, and the available workers, your task is to determine the next step by selecting the most appropriate worker. If the query has been fully answered, select FINISH. Follow the guidelines below to manage the workflow effectively, while also ensuring that user feedback is taken seriously and incorporated into the decision-making process.



***Supervisor Instructions:***

 1. Select the Next Worker:

    - Choose a worker from the provided {options} list based on their capabilities and how they can contribute to answering the user's query.
    - Incorporate any user feedback that may influence which worker should be selected or how the task should proceed.
    - if user feedback is to continue than you have to select the next worker.
    - If no worker is suitable or the query has already been fully answered, select FINISH.


 2. Complete the Task:

    - If the query has been fully answered or no further worker action is needed, select FINISH to end the task.
    - Return FINISH immediately when no more actions are required.
    - Ensure the final output reflects any user feedback that has been incorporated during the process.

***Output Format:***
 1. Worker Name or FINISH:

    - Select and provide the name of the chosen worker from the options list, or select FINISH if the task is complete.

 2. User Query Answer (from Context):

    - Present the response to the user's query based on the current context. This answer should be actual response made by the workers for the user's query.
    - Ensure the answer addresses the user's query and incorporates with any feedback provided by the user.

 3. Steps:

    - Provide a detailed list of steps, outlining what each worker should do next based on their capabilities and the user query:
      - Step 1: Assign the first task to a suitable worker, specifying the exact action required based on their capabilities.
      - Step 2: Once the first worker completes their task, assign the next step to another worker, ensuring that the step matches the worker's specific capabilities.
      - Continue assigning tasks until the query is fully resolved, assigning only tasks that each worker is capable of performing.
      - If any user feedback is provided during the process, adjust the steps accordingly to reflect that feedback.
    - Once the query is answered, return FINISH immediately.
             
**Follow Additional Instructions From User:**
 - Adhere to any specific instructions provided by the user, below are user instructions that are essential for producing the correct response.:
   ```{system_message}``` 
 - To ensure the output aligns with the user's requirements. These guidelines are essential for producing the correct response.
             


             

             """),
            MessagesPlaceholder(variable_name="history"),
            MessagesPlaceholder(variable_name="messages"),
        ])

        self.workflow.add_node("supervisor", self.supervisor_agent)
        self.workflow.add_node("human_feedback", self.human_feedback_node)
        self.workflow.add_node("FINISH", self.finish_node)

    def add_agent(self, name: str, system_prompt: str, reviewer_prompt: str, feedback_limit, config, user_prompt: str, tools, llm, chat, research, approval_tool_names=None):
        """Register a generator/reviewer worker as node ``name`` in the supervisor graph."""
        if approval_tool_names is None:
            approval_tool_names = []
        agent_info = AgentInfo(name=name, system_prompt=system_prompt)
        agent = NeoleadsWorkflow(self.inputs, system_prompt, reviewer_prompt, user_prompt, tools, llm, feedback_limit, self.chathistory, config, chat, research, self.callbacks, approval_tool_names)
        # NeoleadsWorkflow.__init__ already compiled the graph into .workflow;
        # reuse it instead of building a second identical copy.
        node = functools.partial(self.agent_node, agent=agent.workflow, name=name)
        self.workflow.add_node(name, node)
        self.members.append(name)
        self.agents.append(agent_info)

    def compile(self):
        """Bind supervisor prompt partials and wire all edges; safe to call repeatedly.

        LangGraph rejects a second conditional branch with the same name on
        ``supervisor``. Calling ``run`` and then ``feedback_run`` on one
        instance therefore used to fail. The wiring now happens only once.
        """
        if self._edges_wired:
            return
        options = ["FINISH"] + self.members
        agent_info = "\n\n".join(f"*** {agent.name}:***\n```{agent.system_prompt}```" for agent in self.agents)
        self.supervisor_prompt = self.supervisor_prompt.partial(
            options=str(options),
            agent_info=agent_info,
            history=self.chathistory,
            system_message=self.system_message
        )

        conditional_map = {member: member for member in self.members}
        conditional_map["FINISH"] = "FINISH"
        self.workflow.add_conditional_edges(
            "supervisor",
            lambda x: x["next"],
            conditional_map
        )
        self.workflow.add_edge("FINISH", END)
        for member in self.members:
            self.workflow.add_edge(member, "human_feedback")
        self.workflow.add_edge("human_feedback", "supervisor")
        self.workflow.add_edge(START, "supervisor")
        self._edges_wired = True

    def human_feedback_node(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Pass-through node; the graph is interrupted before it so the user can add feedback."""
        return state

    def feedback_run(self, message: str):
        """Resume the interrupted thread in ``self.config`` with user feedback ``message``.

        All existing thread messages are removed first, the feedback is added as
        if produced by ``human_feedback``, and the graph is streamed to the next
        interrupt or the end. ``self.final_response`` is set when ``FINISH`` runs.
        """
        self.compile()
        with ConnectionPool(
            conninfo=self.DB_URI,
            kwargs=self.connection_kwargs,
        ) as pool:
            checkpointer = PostgresSaver(pool)
            graph = self.workflow.compile(checkpointer=checkpointer, interrupt_before=["human_feedback"])
            existing = graph.get_state(self.config).values.get("messages", [])
            if existing:
                graph.update_state(self.config, {"messages": [RemoveMessage(id=m.id) for m in existing]})

            feedback = [HumanMessage(content="My Feedback:" + message)]
            graph.update_state(self.config, {"messages": feedback}, as_node="human_feedback")
            self.config['recursion_limit'] = 150
            self.config['callbacks'] = self.callbacks
            # With subgraphs=True each item is a (namespace, update_chunk) tuple;
            # membership tests must look at the chunk, not the tuple.
            for namespace, chunk in graph.stream(None, config=self.config, subgraphs=True):
                if "__end__" not in chunk:
                    print((namespace, chunk))
                    if "FINISH" in chunk:
                        self.final_response = chunk["FINISH"]['messages'][-1].content
                    print("----")

    def run(self, input_message: str):
        """Start a run for ``input_message`` on the thread in ``self.config`` and print progress."""
        self.compile()
        with ConnectionPool(
            conninfo=self.DB_URI,
            kwargs=self.connection_kwargs,
        ) as pool:
            checkpointer = PostgresSaver(pool)
            print(self.config)
            graph = self.workflow.compile(checkpointer=checkpointer, interrupt_before=["human_feedback"])
            print(graph.get_graph(xray=True).draw_mermaid())
            self.config['recursion_limit'] = 150
            self.config['callbacks'] = self.callbacks
            # NOTE(review): "user_feedback" is not a field of AgentState.
            inputs = {"messages": [HumanMessage(content=input_message)], "user_feedback": "No Feedback till now, so continue your work", }
            for s in graph.stream(inputs, config=self.config, stream_mode="values", subgraphs=True):
                print(s)
            state = graph.get_state(self.config, subgraphs=True)
            print("------")
            # A run that reached END has no pending tasks.
            if state.tasks:
                print(state.tasks[0])
            print("------")

    @staticmethod
    def agent_node(state: Dict[str, Any], agent: Any, name: str, config: Any = None) -> Dict[str, List[BaseMessage]]:
        """Run worker subgraph ``agent`` on the supervisor state and summarise its answer.

        The supervisor's ``steps`` are appended as a HumanMessage to a copy of
        the message list; the parent state is no longer mutated in place. The
        worker's final message is returned as an AIMessage named ``name``,
        followed by an instruction for the supervisor.

        ``config`` is injected by LangGraph and forwarded to the subgraph.
        NOTE(review): issue #2142 reports that ``get_state(subgraphs=True)``
        shows ``state=None`` for a compiled graph invoked from a wrapper function
        like this one. Adding the compiled graph directly as the node does not
        have this problem. Confirm with the LangGraph version in use.
        """
        sub_state = {**state, "messages": [*state["messages"], HumanMessage(content=state["steps"])]}
        result = agent.invoke(sub_state, config)
        return {"messages": [AIMessage(content=result["messages"][-1].content, name=name, id=str(uuid4())),
                             HumanMessage(content=f"""From the Above assistant response if the task that was requested is complete respond with FINISH
                                           if you are a supervisor agent, if you are any other agent always answer what you are designed to do.""", id=str(uuid4()))]}

    def supervisor_agent(self, state: Dict[str, Any]) -> Dict[str, str]:
        """Ask the LLM for the next worker and step plan (a ``RouteResponse``)."""
        supervisor_chain = (
            self.supervisor_prompt
            | self.llm.with_structured_output(RouteResponse)
        )
        return supervisor_chain.invoke(state)

    def finish_node(self, state: Dict[str, Any]):
        """Produce the final answer from the conversation using ``final_output_prompt``."""
        chain = (
            self.final_output_prompt
            | self.llm
        )
        result = chain.invoke(state)
        return {"messages": [AIMessage(**result.dict(exclude={"type", "name"}))]}
    


### Error Message and Stack Trace (if applicable)

```shell
When streaming, the subgraph's latest response contains these state values:

(('GenerateKeywordsIdeas:63abfcb8-9b4b-ec50-2659-cbfa9def5157',), {'messages': [HumanMessage(content='find seed keywords for neoleads.com', additional_kwargs={}, response_metadata={}, id='c3c68ab2-87d7-41a0-b055-fe7355e83712'), HumanMessage(content='Step 1: Use the scraping tool to extract data from neoleads.com.\nStep 2: Identify seed keywords from the scraped data, focusing on meta tags, content, and key phrases used on the website.\nStep 3: Use Google Search, Bing Search, and YouTube Search to find related keywords and searches based on the extracted seed keywords.\nStep 4: Compile a comprehensive list of seed keywords and related searches, ensuring they are relevant to neoleads.com and its services.', additional_kwargs={}, response_metadata={}, id='f7db98e8-d26d-4a72-861c-c96c88148a69'), AIMessage(content=[{'type': 'text', 'text': "Certainly! I'll follow the steps you've outlined to find seed keywords for neoleads.com. Let's begin by scraping the website content and then use that information to find related keywords.\n\nStep 1: Let's use the scraping tool to extract data from neoleads.com.", 'index': 0}, {'type': 'tool_use', 'id': 'toolu_bdrk_01Vn43sFEdBcL3K6xUzE1hU1', 'name': 'apify_website_content_crawler', 'input': {}, 'index': 1, 'partial_json': '{"query": "https://neoleads.com"}'}], additional_kwargs={}, response_metadata={'stop_reason': 'tool_use', 'stop_sequence': None}, id='run-ffa3b9bc-4810-4560-8afc-dd9b892c6c23-0', tool_calls=[{'name': 'apify_website_content_crawler', 'args': {'query': 'https://neoleads.com'}, 'id': 'toolu_bdrk_01Vn43sFEdBcL3K6xUzE1hU1', 'type': 'tool_call'}], usage_metadata={'input_tokens': 1721, 'output_tokens': 113, 'total_tokens': 1834})], 'sender': 'generator', 'limit': 0})

But when I print the subgraph's tasks, the state value is None:

PregelTask(id='63abfcb8-9b4b-ec50-2659-cbfa9def5157', name='GenerateKeywordsIdeas', path=('__pregel_pull', 'GenerateKeywordsIdeas'), error=None, interrupts=(), state=None, result=None)

Description

The subgraph's state values should be saved to the persistence DB.

System Info

requests
langchain
langchain-openai
langchain-aws
langgraph
langchain-community
firebase_admin
supabase
rollbar
tavily-python
google-search-results
pandas
apify-client
mailchimp-marketing
hubspot-api-client
wikipedia
langchain-google-community
langchain-anthropic
psycopg
psycopg-binary
psycopg-pool
langgraph-checkpoint-postgres

@vbarda
Copy link
Collaborator

vbarda commented Oct 21, 2024

@jhachirag7 can you please provide a smaller reproducible code example that demonstrates the issue?

@jhachirag7
Copy link
Author

Actually, I found the cause of the issue:
node = functools.partial(self.agent_node, agent=agent.create_workflow(), name=name)
`agent.create_workflow()` returns a `CompiledGraph` object.

If I set `node = agent.create_workflow()` (adding the compiled graph directly as the node), the inner graph's state messages are stored in the persistence DB,
but if the compiled graph is wrapped in a function, the inner graph's state values are None.
workflow.add_node(name, node)

@jhachirag7
Copy link
Author

For a smaller reproduction, you can use the code below:

from langgraph.graph import StateGraph, END, START, MessagesState
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langchain_aws import ChatBedrock
import os

@tool
def get_weather(city: str):
    """Get the weather for a specific city"""
    # Stub tool: always reports sunny weather. The docstring above is the tool
    # description exposed to the model, so it is intentionally left as-is.
    return f"It's sunny in {city}!"


# Bedrock-hosted Claude chat model, used for routing and for free-form answers.
raw_model = ChatBedrock(
        model_id="anthropic.claude-3-sonnet-20240229-v1:0",
        streaming= True,
        beta_use_converse_api=False,
    )
# Structured-output variant shaped by the get_weather tool's arguments;
# model_node reads the "city" key from its result.
model = raw_model.with_structured_output(get_weather)


class SubGraphState(MessagesState):
    """Weather subgraph state: the conversation plus the city extracted by model_node."""
    city: str  # written by model_node, read by weather_node


def model_node(state: SubGraphState):
    """Extract the target city from the conversation using the structured-output model."""
    extracted = model.invoke(state["messages"])
    city = extracted["city"]
    return {"city": city}


def weather_node(state: SubGraphState):
    """Call the weather tool for the extracted city and append its answer as an assistant message."""
    forecast = get_weather.invoke({"city": state["city"]})
    reply = {"role": "assistant", "content": forecast}
    return {"messages": [reply]}


# Weather subgraph: extract the city (model_node), then call the tool (weather_node).
subgraph = StateGraph(SubGraphState)
subgraph.add_node(model_node)
subgraph.add_node(weather_node)
subgraph.add_edge(START, "model_node")
subgraph.add_edge("model_node", "weather_node")
subgraph.add_edge("weather_node", END)
# Compiled without its own checkpointer and paused before the tool call; the
# compiled graph is added directly as the "weather_graph" node of the parent graph.
subgraph = subgraph.compile(interrupt_before=["weather_node"])
from typing import Literal
from typing_extensions import TypedDict


class RouterState(MessagesState):
    """Parent graph state: the conversation plus the router's classification."""
    route: Literal["weather", "other"]  # set by router_node, read by route_after_prediction


# Structured-output schema for router_model. A class docstring is deliberately
# omitted because it would likely be sent to the model as the schema description.
class Router(TypedDict):
    route: Literal["weather", "other"]  # classification returned by the model


# Chat model constrained to return a Router dict (read as route["route"] in router_node).
router_model = raw_model.with_structured_output(Router)


def router_node(state: RouterState):
    """Classify the conversation as weather-related or not and store the route."""
    prompt = [
        {"role": "system", "content": "Classify the incoming query as either about weather or not."},
        *state["messages"],
    ]
    decision = router_model.invoke(prompt)
    return {"route": decision["route"]}


def normal_llm_node(state: RouterState):
    """Answer non-weather queries directly with the raw chat model."""
    return {"messages": [raw_model.invoke(state["messages"])]}


def route_after_prediction(
    state: RouterState,
) -> Literal["weather_graph", "normal_llm_node"]:
    """Send weather queries to the weather subgraph and everything else to the plain LLM node."""
    return "weather_graph" if state["route"] == "weather" else "normal_llm_node"


# Parent graph: classify first, then either answer directly or run the weather subgraph.
graph = StateGraph(RouterState)
graph.add_node(router_node)
graph.add_node(normal_llm_node)
# The compiled subgraph is added directly as a node (not wrapped in a function).
graph.add_node("weather_graph", subgraph)
graph.add_edge(START, "router_node")
graph.add_conditional_edges("router_node", route_after_prediction)
graph.add_edge("normal_llm_node", END)
graph.add_edge("weather_graph", END)
from psycopg_pool import ConnectionPool
from langgraph.checkpoint.postgres import PostgresSaver
import os
# Postgres connection settings for the checkpointer (DB_URI comes from the environment).
DB_URI = os.getenv('DB_URI')
connection_kwargs = {
            "autocommit": True,
            "prepare_threshold": 0,
        }
# Fixed thread id: every run of this script uses the same checkpoint thread.
config = {"configurable": {"thread_id": "3"}}

state = None
with ConnectionPool(
            conninfo=DB_URI,
            kwargs=connection_kwargs,
        ) as pool:
            checkpointer = PostgresSaver(pool)
            # Rebinds `graph` from the StateGraph builder to the compiled graph.
            graph = graph.compile(checkpointer=checkpointer)

            inputs = {"messages": [{"role": "user", "content": "what's the weather in sf"}]}
            # Drain the stream; a weather query pauses before weather_node inside the subgraph.
            for update in graph.stream(inputs, config=config, stream_mode="values", subgraphs=True):
                # print(update)
                pass

            # subgraphs=True asks for the nested state of pending subgraph tasks.
            state = graph.get_state(config, subgraphs=True)
            
# NOTE(review): raises IndexError if the run finished with no pending tasks
# (e.g. the query was routed to normal_llm_node).
print(state.tasks[0])

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants