Skip to content

Commit

Permalink
feat: add chat history to user_intent
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesponti committed Sep 25, 2024
1 parent 76cac5a commit 99fd24d
Show file tree
Hide file tree
Showing 8 changed files with 263 additions and 198 deletions.
2 changes: 1 addition & 1 deletion src/data/focus_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def search_focus_items(
elif status:
query = query.filter(Focus.state == status.value)

focus_items = query.all()
focus_items = query.limit(10).all()

return focus_items
except Exception as e:
Expand Down
7 changes: 4 additions & 3 deletions src/routers/admin_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
from src.services import chroma_service
from src.services.file_service import get_file_contents
from src.services.keywords.keywords_service import get_query_keywords
from src.services.user_intent.user_intent_graph import generate_intent_result_graph, get_user_intent_graph
from src.services.user_intent.user_intent_service import (
generate_intent_result,
generate_intent_result_graph,
get_user_intent,
get_user_intent_graph,
)
from src.utils.config import settings
from src.utils.context import SessionDep

admin_router = APIRouter()

Expand Down Expand Up @@ -53,6 +53,7 @@ def sherpa_keyword_generator(
@admin_router.post("/intent")
def sherpa_user_intent(
request: Request,
session: SessionDep,
input: str = Form(...),
formatted: bool = Form(False),
profile_id: uuid.UUID = Form(...),
Expand All @@ -63,7 +64,7 @@ def sherpa_user_intent(
content = get_file_contents("src/prompts/test_user_input.md")

try:
intents = get_user_intent(content, profile_id=profile_id)
intents = get_user_intent(user_input=content, profile_id=profile_id, session=session)

if formatted:
result = generate_intent_result(intents)
Expand Down
25 changes: 12 additions & 13 deletions src/routers/chat_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from src.data.chat_repository import insert_message
from src.data.models.chat import Chat, ChatState, Message, MessageOutput, MessageRole
from src.data.models.focus import FocusItem
from src.services.user_intent.user_intent_service import generate_intent_result, get_user_intent
from src.utils.context import CurrentProfile, CurrentUser, SessionDep

Expand Down Expand Up @@ -145,37 +144,38 @@ class ChatMessageInput(BaseModel):
class SendChatMessageOutput(BaseModel):
messages: List[MessageOutput]
function_calls: List[str]
focus_items: List[FocusItem]


@chat_router.post("/{chat_id}/messages")
async def send_chat_message(
db: SessionDep, profile: CurrentProfile, input: ChatMessageInput, request: Request
db: SessionDep, profile: CurrentProfile, input: ChatMessageInput, request: Request, chat_id: UUID
) -> SendChatMessageOutput:
chat_id = UUID(request.path_params["chat_id"])

# Insert new message into the database
user_message = insert_message(
db, chat_id=chat_id, profile_id=profile.id, message=input.message, role="user"
)

# Retrieve message from ChatGPT
user_intent = get_user_intent(user_input=input.message, profile_id=profile.id)
user_intent = get_user_intent(
session=db, user_input=input.message, profile_id=profile.id, chat_id=chat_id
)
sherpa_response = generate_intent_result(intent=user_intent)
if sherpa_response is None:
raise Exception("No response from the model")

focus_items = []
function_calls = []

if sherpa_response.create is not None:
focus_items = sherpa_response.create.output
focus_items.extend(sherpa_response.create.output)
function_calls.append("create_tasks")
elif sherpa_response.search is not None:
focus_items = sherpa_response.search.output

if sherpa_response.search is not None:
focus_items.extend(sherpa_response.search.output)
function_calls.append("search_tasks")

# Save system response to the database
system_message = insert_message(
# Save assistant response to the database
assistant_message = insert_message(
db,
chat_id=chat_id,
profile_id=profile.id,
Expand All @@ -187,8 +187,7 @@ async def send_chat_message(
return SendChatMessageOutput(
messages=[
user_message.to_model(session=db),
system_message.to_model(session=db),
assistant_message.to_model(session=db),
],
function_calls=function_calls,
focus_items=[],
)
12 changes: 7 additions & 5 deletions src/routers/sherpa_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
generate_intent_result,
get_user_intent,
)
from src.utils.context import CurrentProfile
from src.utils.context import CurrentProfile, SessionDep
from src.utils.logger import logger

sherpa_router = APIRouter()
Expand All @@ -25,10 +25,10 @@ class GenerateTextIntentInput(BaseModel):

@sherpa_router.post("/text")
async def handle_text_input_route(
profile: CurrentProfile, input: GenerateTextIntentInput
session: SessionDep, profile: CurrentProfile, input: GenerateTextIntentInput
) -> GeneratedIntentsResponse:
try:
intent = get_user_intent(input.content, profile_id=profile.id)
intent = get_user_intent(user_input=input.content, profile_id=profile.id, session=session)
result = generate_intent_result(intent)

return result
Expand Down Expand Up @@ -83,7 +83,9 @@ async def transcribe_audio(audio: AudioUpload, profile: CurrentProfile) -> str:


@sherpa_router.post("/voice")
async def handle_audio_upload_route(audio: AudioUpload, profile: CurrentProfile) -> GeneratedIntentsResponse:
async def handle_audio_upload_route(
session: SessionDep, audio: AudioUpload, profile: CurrentProfile
) -> GeneratedIntentsResponse:
try:
with tempfile.TemporaryDirectory() as temp_dir:
temp_file_path = os.path.join(temp_dir, "temp_audio.m4a")
Expand All @@ -98,7 +100,7 @@ async def handle_audio_upload_route(audio: AudioUpload, profile: CurrentProfile)
model="whisper-1", file=audio_file, response_format="text"
)

intent = get_user_intent(str(transcription), profile_id=profile.id)
intent = get_user_intent(user_input=str(transcription), profile_id=profile.id, session=session)
result = generate_intent_result(intent)

return result
Expand Down
84 changes: 84 additions & 0 deletions src/services/user_intent/tools/search_tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import uuid
from datetime import datetime
from typing import List, Optional, Tuple

import pydantic
from langchain.schema import AgentAction
from langchain_core.tools import tool

from src.data.focus_repository import search_focus_items
from src.data.models.focus import FocusItem, FocusState


@tool("search_tasks", parse_docstring=True)
def search_tasks(
keyword: str,
search_title: str,
profile_id: uuid.UUID,
due_on: Optional[datetime],
due_after: Optional[datetime],
due_before: Optional[datetime],
status: Optional[FocusState],
) -> List[FocusItem]:
"""
Search for tasks based on a keyword or specific attributes.
Args:
keyword: The keyword to search for tasks
search_title: A user-friendly title for the search
profile_id: The user's Profile ID
due_on: The due date in ISO Date Time Format for the task. Example: "2023-01-01T12:00" | None
due_after: A ISO Date Time Format date used when the users wants to search for tasks after a specific date. Example: "2023-01-01T12:00" | None
due_before: A ISO Date Time Format date used when the users wants to search for tasks before a specific date. Example: "2023-01-01T12:00" | None
status: The status of the task
"""
focus_items = search_focus_items(
keyword=keyword,
due_on=due_on,
due_after=due_after,
due_before=due_before,
status=status,
profile_id=profile_id,
)

return [focus_item.to_model() for focus_item in focus_items]


class SearchIntentParameters(pydantic.BaseModel):
keyword: str
profile_id: uuid.UUID
due_on: Optional[datetime] | None = pydantic.Field(
None, description="The due date in ISO Date Time Format for the task"
)
due_after: Optional[datetime] | None = pydantic.Field(
None,
description="A ISO Date Time Format date used when the users wants to search for tasks after a specific date",
)
due_before: Optional[datetime] | None = pydantic.Field(
None,
description="A ISO Date Time Format date used when the users wants to search for tasks before a specific date",
)
status: Optional[str] | None = pydantic.Field(None, description="The status of the task")


class SearchIntentsResponse(pydantic.BaseModel):
input: SearchIntentParameters
output: List[FocusItem]


def format_search_tool_calls(intermediate_steps) -> SearchIntentsResponse | None:
search_tasks: List[Tuple[AgentAction, List[FocusItem]]] = list(
filter(lambda x: x[0].tool == "search_tasks", intermediate_steps)
)

if len(search_tasks) == 0:
return None

search_task = search_tasks[0]
if search_task[0].tool_input is None:
return None

return SearchIntentsResponse(
input=search_task[0].tool_input, # type: ignore
output=search_tasks[0][1],
)
114 changes: 114 additions & 0 deletions src/services/user_intent/user_intent_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import uuid
from datetime import datetime
from typing import Annotated, Any, Sequence

import pydantic
from langchain.agents import create_tool_calling_agent
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
)
from langgraph.graph import END, Graph
from typing_extensions import Dict, TypedDict

from src.routers.chat_router import start_chat
from src.services.file_service import get_file_contents
from src.services.openai_service import openai_chat
from src.services.user_intent.tools import search_tasks
from src.services.user_intent.tools.search_tasks import format_search_tool_calls
from src.services.user_intent.user_intent_service import (
GeneratedIntentsResponse,
create_tasks,
edit_task,
format_chat_tool_call,
format_create_tool_calls,
)


class AgentState(TypedDict):
messages: Annotated[Sequence[HumanMessage | AIMessage], pydantic.Field(default_factory=list)]
tools_used: Annotated[list[str], pydantic.Field(default_factory=list)]
profile_id: uuid.UUID
current_date: str


def determine_tool(state: AgentState):
system_prompt = get_file_contents("src/services/user_intent/user_intent_prompt.md")
tools = [create_tasks, search_tasks, edit_task, start_chat]
chat_prompt = ChatPromptTemplate(
[
SystemMessagePromptTemplate.from_template(template=system_prompt),
("human", ("Profile ID: {profile_id} \n\n" + "Today's Date: {current_date} \n\n" + "{input}")),
("placeholder", "{agent_scratchpad}"),
]
)
agent = create_tool_calling_agent(openai_chat, tools, chat_prompt)
result = agent.invoke(state)
tool_to_use = result.tool
return tool_to_use


def execute_tool(state: AgentState, tool_name: str):
tools = {
"create_tasks": create_tasks,
"search_tasks": search_tasks,
"edit_task": edit_task,
"chat": start_chat,
}
tool = tools[tool_name]
result = tool(**state) # You might need to adjust this based on your tool implementations
state["tools_used"].append(tool_name)
state["messages"].append(AIMessage(content=str(result)))
return state


def should_continue(state: AgentState):
# For simplicity, let's say we continue if we've used less than 3 tools
return len(state["tools_used"]) < 3


workflow = Graph()

workflow.add_node("determine_tool", determine_tool)
workflow.add_node("execute_tool", execute_tool)

workflow.add_edge("determine_tool", "execute_tool")
workflow.add_conditional_edges("execute_tool", should_continue, {True: "determine_tool", False: END})

workflow.set_entry_point("determine_tool")

chain = workflow.compile()


def get_user_intent_graph(user_input: str, profile_id: uuid.UUID) -> Dict[str, Any]:
initial_state = AgentState(
messages=[HumanMessage(content=user_input)],
tools_used=[],
profile_id=profile_id,
current_date=datetime.now().strftime("%Y-%m-%d"),
)
result = chain.invoke(initial_state)
return result


def generate_intent_result_graph(intent: AgentState) -> GeneratedIntentsResponse:
"""
Generates a response based on the user intent.
Args:
intent (AgentState): The final state after running the workflow.
Returns:
result (GeneratedIntentsResponse): The generated response.
"""
steps = intent["tools_used"]
output = intent["messages"][-1].content if intent["messages"] else ""

return GeneratedIntentsResponse(
input=intent["messages"][0].content if intent["messages"] else None,
output=output,
chat=format_chat_tool_call(steps),
create=format_create_tool_calls(steps),
search=format_search_tool_calls(steps),
)
6 changes: 5 additions & 1 deletion src/services/user_intent/user_intent_prompt.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ You MUST include at least 10 keywords in the task keywords list.

## Search


### SEARCH INPUT
- if the user wants to perform a search, the `keyword` parameter should be used to search for tasks based on a keyword or specific attributes. The `keyword` parameter should be a string that represents the search query. It should be the primary focus of the user's query. For example:
- "What do I need at the pet store today?" -> `keyword` = "pet store"
Expand Down Expand Up @@ -78,7 +79,10 @@ You MUST include at least 10 keywords in the task keywords list.
- If a date is used, `You had one event yesterday.`
- If no keyword is used, `Found 10 results`


## Search Rules
- DO NOT EXECUTE THE SEARCH TOOL IF THE USER DOES NOT SPECIFY A KEYWORD OR DATE.
- IF THE USER ASKS A GENERAL QUESTION or A QUESTION RELATED TO THE CONVERSATION, DO NOT EXECUTE THE SEARCH TOOL.
- THE SEARCH TOOL SHOULD BE CALLED WITH ALL REQUIRED PARAMETERS.

---

Expand Down
Loading

0 comments on commit 99fd24d

Please sign in to comment.