-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add chat history to user_intent
- Loading branch information
1 parent
76cac5a
commit 99fd24d
Showing
8 changed files
with
263 additions
and
198 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import uuid | ||
from datetime import datetime | ||
from typing import List, Optional, Tuple | ||
|
||
import pydantic | ||
from langchain.schema import AgentAction | ||
from langchain_core.tools import tool | ||
|
||
from src.data.focus_repository import search_focus_items | ||
from src.data.models.focus import FocusItem, FocusState | ||
|
||
|
||
@tool("search_tasks", parse_docstring=True) | ||
def search_tasks( | ||
keyword: str, | ||
search_title: str, | ||
profile_id: uuid.UUID, | ||
due_on: Optional[datetime], | ||
due_after: Optional[datetime], | ||
due_before: Optional[datetime], | ||
status: Optional[FocusState], | ||
) -> List[FocusItem]: | ||
""" | ||
Search for tasks based on a keyword or specific attributes. | ||
Args: | ||
keyword: The keyword to search for tasks | ||
search_title: A user-friendly title for the search | ||
profile_id: The user's Profile ID | ||
due_on: The due date in ISO Date Time Format for the task. Example: "2023-01-01T12:00" | None | ||
due_after: A ISO Date Time Format date used when the users wants to search for tasks after a specific date. Example: "2023-01-01T12:00" | None | ||
due_before: A ISO Date Time Format date used when the users wants to search for tasks before a specific date. Example: "2023-01-01T12:00" | None | ||
status: The status of the task | ||
""" | ||
focus_items = search_focus_items( | ||
keyword=keyword, | ||
due_on=due_on, | ||
due_after=due_after, | ||
due_before=due_before, | ||
status=status, | ||
profile_id=profile_id, | ||
) | ||
|
||
return [focus_item.to_model() for focus_item in focus_items] | ||
|
||
|
||
class SearchIntentParameters(pydantic.BaseModel): | ||
keyword: str | ||
profile_id: uuid.UUID | ||
due_on: Optional[datetime] | None = pydantic.Field( | ||
None, description="The due date in ISO Date Time Format for the task" | ||
) | ||
due_after: Optional[datetime] | None = pydantic.Field( | ||
None, | ||
description="A ISO Date Time Format date used when the users wants to search for tasks after a specific date", | ||
) | ||
due_before: Optional[datetime] | None = pydantic.Field( | ||
None, | ||
description="A ISO Date Time Format date used when the users wants to search for tasks before a specific date", | ||
) | ||
status: Optional[str] | None = pydantic.Field(None, description="The status of the task") | ||
|
||
|
||
class SearchIntentsResponse(pydantic.BaseModel): | ||
input: SearchIntentParameters | ||
output: List[FocusItem] | ||
|
||
|
||
def format_search_tool_calls(intermediate_steps) -> SearchIntentsResponse | None: | ||
search_tasks: List[Tuple[AgentAction, List[FocusItem]]] = list( | ||
filter(lambda x: x[0].tool == "search_tasks", intermediate_steps) | ||
) | ||
|
||
if len(search_tasks) == 0: | ||
return None | ||
|
||
search_task = search_tasks[0] | ||
if search_task[0].tool_input is None: | ||
return None | ||
|
||
return SearchIntentsResponse( | ||
input=search_task[0].tool_input, # type: ignore | ||
output=search_tasks[0][1], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import uuid | ||
from datetime import datetime | ||
from typing import Annotated, Any, Sequence | ||
|
||
import pydantic | ||
from langchain.agents import create_tool_calling_agent | ||
from langchain_core.messages import AIMessage, HumanMessage | ||
from langchain_core.prompts import ( | ||
ChatPromptTemplate, | ||
SystemMessagePromptTemplate, | ||
) | ||
from langgraph.graph import END, Graph | ||
from typing_extensions import Dict, TypedDict | ||
|
||
from src.routers.chat_router import start_chat | ||
from src.services.file_service import get_file_contents | ||
from src.services.openai_service import openai_chat | ||
from src.services.user_intent.tools import search_tasks | ||
from src.services.user_intent.tools.search_tasks import format_search_tool_calls | ||
from src.services.user_intent.user_intent_service import ( | ||
GeneratedIntentsResponse, | ||
create_tasks, | ||
edit_task, | ||
format_chat_tool_call, | ||
format_create_tool_calls, | ||
) | ||
|
||
|
||
class AgentState(TypedDict): | ||
messages: Annotated[Sequence[HumanMessage | AIMessage], pydantic.Field(default_factory=list)] | ||
tools_used: Annotated[list[str], pydantic.Field(default_factory=list)] | ||
profile_id: uuid.UUID | ||
current_date: str | ||
|
||
|
||
def determine_tool(state: AgentState): | ||
system_prompt = get_file_contents("src/services/user_intent/user_intent_prompt.md") | ||
tools = [create_tasks, search_tasks, edit_task, start_chat] | ||
chat_prompt = ChatPromptTemplate( | ||
[ | ||
SystemMessagePromptTemplate.from_template(template=system_prompt), | ||
("human", ("Profile ID: {profile_id} \n\n" + "Today's Date: {current_date} \n\n" + "{input}")), | ||
("placeholder", "{agent_scratchpad}"), | ||
] | ||
) | ||
agent = create_tool_calling_agent(openai_chat, tools, chat_prompt) | ||
result = agent.invoke(state) | ||
tool_to_use = result.tool | ||
return tool_to_use | ||
|
||
|
||
def execute_tool(state: AgentState, tool_name: str): | ||
tools = { | ||
"create_tasks": create_tasks, | ||
"search_tasks": search_tasks, | ||
"edit_task": edit_task, | ||
"chat": start_chat, | ||
} | ||
tool = tools[tool_name] | ||
result = tool(**state) # You might need to adjust this based on your tool implementations | ||
state["tools_used"].append(tool_name) | ||
state["messages"].append(AIMessage(content=str(result))) | ||
return state | ||
|
||
|
||
def should_continue(state: AgentState): | ||
# For simplicity, let's say we continue if we've used less than 3 tools | ||
return len(state["tools_used"]) < 3 | ||
|
||
|
||
workflow = Graph() | ||
|
||
workflow.add_node("determine_tool", determine_tool) | ||
workflow.add_node("execute_tool", execute_tool) | ||
|
||
workflow.add_edge("determine_tool", "execute_tool") | ||
workflow.add_conditional_edges("execute_tool", should_continue, {True: "determine_tool", False: END}) | ||
|
||
workflow.set_entry_point("determine_tool") | ||
|
||
chain = workflow.compile() | ||
|
||
|
||
def get_user_intent_graph(user_input: str, profile_id: uuid.UUID) -> Dict[str, Any]: | ||
initial_state = AgentState( | ||
messages=[HumanMessage(content=user_input)], | ||
tools_used=[], | ||
profile_id=profile_id, | ||
current_date=datetime.now().strftime("%Y-%m-%d"), | ||
) | ||
result = chain.invoke(initial_state) | ||
return result | ||
|
||
|
||
def generate_intent_result_graph(intent: AgentState) -> GeneratedIntentsResponse: | ||
""" | ||
Generates a response based on the user intent. | ||
Args: | ||
intent (AgentState): The final state after running the workflow. | ||
Returns: | ||
result (GeneratedIntentsResponse): The generated response. | ||
""" | ||
steps = intent["tools_used"] | ||
output = intent["messages"][-1].content if intent["messages"] else "" | ||
|
||
return GeneratedIntentsResponse( | ||
input=intent["messages"][0].content if intent["messages"] else None, | ||
output=output, | ||
chat=format_chat_tool_call(steps), | ||
create=format_create_tool_calls(steps), | ||
search=format_search_tool_calls(steps), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.