From b9a1a485e6cfd902dad132c7a051e09e121a5c4f Mon Sep 17 00:00:00 2001 From: uommou Date: Tue, 7 May 2024 12:59:41 +0900 Subject: [PATCH 01/11] =?UTF-8?q?hotfix:=20=EC=B6=A9=EB=8F=8C=ED=95=B4?= =?UTF-8?q?=EA=B2=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/dto/openai_dto.py | 10 +--------- app/prompt/openai_prompt.py | 2 +- app/routers/recommendation.py | 11 ----------- 3 files changed, 2 insertions(+), 21 deletions(-) diff --git a/app/dto/openai_dto.py b/app/dto/openai_dto.py index efcce7c..d5e0805 100644 --- a/app/dto/openai_dto.py +++ b/app/dto/openai_dto.py @@ -23,7 +23,6 @@ class EmailResponse(BaseModel): text: str image: str -<<<<<<< Updated upstream class ActivityDescription(BaseModel): activity: str imageTag: str @@ -31,11 +30,4 @@ class ActivityDescription(BaseModel): class RecommendationResponse(BaseModel): ness: str activityList: List[ActivityDescription] -======= -class ActivityInfo(BaseModel): - activity: str - imageTag: str -class RecommendationResponse(BaseModel): - ness: str - activityList: List[ActivityInfo] ->>>>>>> Stashed changes + diff --git a/app/prompt/openai_prompt.py b/app/prompt/openai_prompt.py index e67aabe..317221d 100644 --- a/app/prompt/openai_prompt.py +++ b/app/prompt/openai_prompt.py @@ -84,7 +84,7 @@ class Template: 2. Organize the event the user wants to add into a json format for saving in a database. The returned json will have keys for info, location, person, and date. - info: Summarizes what the user wants to do. This value must always be present. - location: If the user's event information includes a place, save that place as the value. - - person: If the user's event mentions a person they want to include, save that person as the value. + - person: If th e user's event mentions a person they want to include, save that person as the value. - date: If the user's event information includes a specific date and time, save that date and time in datetime format. Dates should be organized based on the current time at the user's location. Current time is {current_time}. Separate the outputs for tasks 1 and 2 with a special token . diff --git a/app/routers/recommendation.py b/app/routers/recommendation.py index ca8bc14..4cbe35f 100644 --- a/app/routers/recommendation.py +++ b/app/routers/recommendation.py @@ -8,11 +8,7 @@ from langchain_core.prompts import PromptTemplate from app.dto.db_dto import RecommendationMainRequestDTO -<<<<<<< Updated upstream from app.dto.openai_dto import RecommendationResponse, ActivityDescription -======= -from app.dto.openai_dto import RecommendationResponse ->>>>>>> Stashed changes from app.prompt import openai_prompt, persona_prompt import app.database.chroma_db as vectordb @@ -54,7 +50,6 @@ async def get_recommendation(user_data: RecommendationMainRequestDTO) -> Recomme user_persona_prompt = persona_prompt.Template.from_persona(persona) prompt = PromptTemplate.from_template(recommendation_template) -<<<<<<< Updated upstream ness = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", schedule=day_schedule @@ -87,12 +82,6 @@ async def get_recommendation(user_data: RecommendationMainRequestDTO) -> Recomme response = RecommendationResponse(ness=ness, activityList=activity_list) return response -======= - ness = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", schedule=schedule)) - print(ness) - - return RecommendationResponse(ness=ness) ->>>>>>> Stashed changes except Exception as e: raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file From 7ef38e322effff8f2297a579d008c7090fbd66be Mon Sep 17 00:00:00 2001 From: uommou Date: Wed, 15 May 2024 17:20:14 +0900 Subject: [PATCH 02/11] =?UTF-8?q?feat:=20whisper=20=EC=97=B0=EA=B2=B0?= =?UTF-8?q?=EC=9D=84=20=EC=9C=84=ED=95=9C=20api=20=EC=8A=A4=ED=8E=99=20?= =?UTF-8?q?=EB=B3=80=EA=B2=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/dto/openai_dto.py | 1 + app/routers/chat.py | 16 +++++++++++++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/app/dto/openai_dto.py b/app/dto/openai_dto.py index d5e0805..bb04a4c 100644 --- a/app/dto/openai_dto.py +++ b/app/dto/openai_dto.py @@ -4,6 +4,7 @@ class PromptRequest(BaseModel): prompt: str persona: str + chatType: str class ChatResponse(BaseModel): ness: str diff --git a/app/routers/chat.py b/app/routers/chat.py index e8b35ef..2b110d2 100644 --- a/app/routers/chat.py +++ b/app/routers/chat.py @@ -38,14 +38,24 @@ async def get_langchain_case(data: PromptRequest) -> ChatCaseResponse: model_name=config_chat['MODEL_NAME'], # 모델명 openai_api_key=OPENAI_API_KEY # API 키 ) - question = data.prompt + question = data.prompt # 유저로부터 받은 채팅의 내용 + chat_type = data.chatType # 위스퍼 사용 여부 [STT, USER] # description: give NESS's instruction as for case analysis my_template = openai_prompt.Template.case_classify_template - prompt = PromptTemplate.from_template(my_template) - case = chat_model.predict(prompt.format(question=question)) + # chat type에 따라 적합한 프롬프트를 삽입 + if chat_type == "STT": + prompt = PromptTemplate.from_template(my_template) + case = chat_model.predict(prompt.format(question=question)) + elif chat_type == "USER": + prompt = PromptTemplate.from_template(my_template) + case = chat_model.predict(prompt.format(question=question)) + else: + prompt = PromptTemplate.from_template(my_template) + case = chat_model.predict(prompt.format(question=question)) + # 각 케이스에도 chat type에 따라 적합한 프롬프트 삽입 필요 print(case) case = int(case) if case == 1: From d65978212370de4cadb86a43427682bc04e641a8 Mon Sep 17 00:00:00 2001 From: uommou Date: Wed, 15 May 2024 17:37:21 +0900 Subject: [PATCH 03/11] =?UTF-8?q?feat:=20stt=20=EC=82=AC=EC=9A=A9=EC=8B=9C?= =?UTF-8?q?=20=EC=B6=94=EA=B0=80=ED=95=A0=20=ED=94=84=EB=A1=AC=ED=94=84?= =?UTF-8?q?=ED=8A=B8=20=EC=9E=91=EC=84=B1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/prompt/openai_prompt.py | 9 +++++++++ app/routers/chat.py | 12 +++++------- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/app/prompt/openai_prompt.py b/app/prompt/openai_prompt.py index 317221d..980bfb7 100644 --- a/app/prompt/openai_prompt.py +++ b/app/prompt/openai_prompt.py @@ -41,6 +41,8 @@ class Template: Task: User Chat Classification You are a case classifier integrated in scheduler application. Please analyze User Chat according to the following criteria and return the appropriate case number (1, 2, 3). + {chat_type} + - Case 1: \ The question is a general information request, advice, or simple conversation, and does not require accessing the user's schedule database. - Case 2: \ @@ -71,6 +73,13 @@ class Template: User Chat: {question} Answer: """ + chat_type_stt_template = """ + You should keep in mind that this user's input was written using speech to text technology. + Therefore, there may be inaccuracies in the text due to errors in the STT process. + You need to consider this aspect when performing the given task. + """ + chat_type_user_template = """ + """ case1_template = """ {persona} YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT. diff --git a/app/routers/chat.py b/app/routers/chat.py index 2b110d2..00e4a24 100644 --- a/app/routers/chat.py +++ b/app/routers/chat.py @@ -47,13 +47,14 @@ async def get_langchain_case(data: PromptRequest) -> ChatCaseResponse: # chat type에 따라 적합한 프롬프트를 삽입 if chat_type == "STT": prompt = PromptTemplate.from_template(my_template) - case = chat_model.predict(prompt.format(question=question)) + chat_type_prompt = openai_prompt.Template.chat_type_stt_template + case = chat_model.predict(prompt.format(question=question, chat_type=chat_type_prompt)) elif chat_type == "USER": prompt = PromptTemplate.from_template(my_template) - case = chat_model.predict(prompt.format(question=question)) + chat_type_prompt = openai_prompt.Template.chat_type_user_template + case = chat_model.predict(prompt.format(question=question, chat_type=chat_type_prompt)) else: - prompt = PromptTemplate.from_template(my_template) - case = chat_model.predict(prompt.format(question=question)) + raise HTTPException(status_code=500, detail="WRONG CHAT TYPE") # 각 케이스에도 chat type에 따라 적합한 프롬프트 삽입 필요 print(case) @@ -68,9 +69,6 @@ async def get_langchain_case(data: PromptRequest) -> ChatCaseResponse: response = await get_langchain_rag(data) else: - # print("wrong case classification") - # # 적절한 HTTP 상태 코드와 함께 오류 메시지를 반환하거나, 다른 처리를 할 수 있습니다. - # raise HTTPException(status_code=400, detail="Wrong case classification") response = "좀 더 명확한 요구가 필요해요. 다시 한 번 얘기해주실 수 있나요?" case = "Exception" From 372901467c57d4d2446e9a4e6c58cc2a6e19250c Mon Sep 17 00:00:00 2001 From: uommou Date: Wed, 15 May 2024 17:54:08 +0900 Subject: [PATCH 04/11] =?UTF-8?q?feat:=20whisper=20=EC=82=AC=EC=9A=A9=20?= =?UTF-8?q?=EA=B0=80=EB=8A=A5=ED=95=98=EB=8F=84=EB=A1=9D=20=ED=94=84?= =?UTF-8?q?=EB=A1=AC=ED=94=84=ED=8A=B8=20=EC=9E=91=EC=84=B1=20=EC=99=84?= =?UTF-8?q?=EB=A3=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/prompt/openai_prompt.py | 6 +++++- app/routers/chat.py | 27 ++++++++++++++------------- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/app/prompt/openai_prompt.py b/app/prompt/openai_prompt.py index 980bfb7..9ee23ca 100644 --- a/app/prompt/openai_prompt.py +++ b/app/prompt/openai_prompt.py @@ -82,11 +82,13 @@ class Template: """ case1_template = """ {persona} - YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT. + {chat_type} + YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT. Current time is {current_time}. Respond to the user considering the current time. User input: {question} """ case2_template = """ {persona} + {chat_type} The user's input contains information about a new event they want to add to their schedule. You have two tasks to perform: 1. Respond kindly to the user's input. YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT. @@ -118,6 +120,8 @@ class Template: case3_template = """ {persona} + {chat_type} + Current time is {current_time}. Respond to the user considering the current time. When responding to user inputs, it's crucial to adapt your responses to the specified output language, maintaining a consistent and accessible communication style. YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT. Your responses should not only be accurate but also display empathy and understanding of the user's needs. You are equipped with a state-of-the-art RAG (Retrieval-Augmented Generation) technique, enabling you to dynamically pull relevant schedule information from a comprehensive database tailored to the user's specific inquiries. This technique enhances your ability to provide precise, context-aware responses by leveraging real-time data retrieval combined with advanced natural language understanding. diff --git a/app/routers/chat.py b/app/routers/chat.py index 00e4a24..b7c60ae 100644 --- a/app/routers/chat.py +++ b/app/routers/chat.py @@ -46,27 +46,26 @@ async def get_langchain_case(data: PromptRequest) -> ChatCaseResponse: # chat type에 따라 적합한 프롬프트를 삽입 if chat_type == "STT": - prompt = PromptTemplate.from_template(my_template) chat_type_prompt = openai_prompt.Template.chat_type_stt_template - case = chat_model.predict(prompt.format(question=question, chat_type=chat_type_prompt)) elif chat_type == "USER": - prompt = PromptTemplate.from_template(my_template) chat_type_prompt = openai_prompt.Template.chat_type_user_template - case = chat_model.predict(prompt.format(question=question, chat_type=chat_type_prompt)) else: raise HTTPException(status_code=500, detail="WRONG CHAT TYPE") + prompt = PromptTemplate.from_template(my_template) + case = chat_model.predict(prompt.format(question=question, chat_type=chat_type_prompt)) + # 각 케이스에도 chat type에 따라 적합한 프롬프트 삽입 필요 print(case) case = int(case) if case == 1: - response = await get_langchain_normal(data) + response = await get_langchain_normal(data, chat_type_prompt) elif case == 2: - response = await get_langchain_schedule(data) + response = await get_langchain_schedule(data, chat_type_prompt) elif case == 3: - response = await get_langchain_rag(data) + response = await get_langchain_rag(data, chat_type_prompt) else: response = "좀 더 명확한 요구가 필요해요. 다시 한 번 얘기해주실 수 있나요?" @@ -77,7 +76,7 @@ async def get_langchain_case(data: PromptRequest) -> ChatCaseResponse: # case 1 : normal #@router.post("/case/normal") # 테스트용 엔드포인트 -async def get_langchain_normal(data: PromptRequest): # case 1 : normal +async def get_langchain_normal(data: PromptRequest, chat_type_prompt): # case 1 : normal print("running case 1") # description: use langchain config_normal = config['NESS_NORMAL'] @@ -95,13 +94,14 @@ async def get_langchain_normal(data: PromptRequest): # case 1 : normal my_template = openai_prompt.Template.case1_template prompt = PromptTemplate.from_template(my_template) - response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question)) + current_time = datetime.now() + response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question, current_time=current_time, chat_type=chat_type_prompt)) print(response) return response # case 2 : 일정 생성 #@router.post("/case/make_schedule") # 테스트용 엔드포인트 -async def get_langchain_schedule(data: PromptRequest): +async def get_langchain_schedule(data: PromptRequest, chat_type_prompt): print("running case 2") # description: use langchain config_normal = config['NESS_NORMAL'] @@ -118,13 +118,13 @@ async def get_langchain_schedule(data: PromptRequest): prompt = PromptTemplate.from_template(case2_template) current_time = datetime.now() - response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question, current_time=current_time)) + response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question, current_time=current_time, chat_type=chat_type_prompt)) print(response) return response # case 3 : rag #@router.post("/case/rag") # 테스트용 엔드포인트 -async def get_langchain_rag(data: PromptRequest): +async def get_langchain_rag(data: PromptRequest, chat_type_prompt): print("running case 3") # description: use langchain config_normal = config['NESS_NORMAL'] @@ -145,6 +145,7 @@ async def get_langchain_rag(data: PromptRequest): case3_template = openai_prompt.Template.case3_template prompt = PromptTemplate.from_template(case3_template) - response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question, schedule=schedule)) + current_time = datetime.now() + response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question, schedule=schedule, current_time=current_time, chat_type=chat_type_prompt)) print(response) return response From d058b5c9bcb21a16ff2e0ce826bdce58c92b001a Mon Sep 17 00:00:00 2001 From: uommou Date: Fri, 24 May 2024 19:44:35 +0900 Subject: [PATCH 05/11] =?UTF-8?q?fix:=20case3=EC=97=90=20=EB=A9=A4?= =?UTF-8?q?=EB=B2=84=20=ED=95=84=ED=84=B0=20=EC=B6=94=EA=B0=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. DTO 수정 2. RAG 쿼리 필터 수정 --- app/database/chroma_db.py | 8 ++++---- app/dto/openai_dto.py | 1 + app/routers/chat.py | 3 ++- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/app/database/chroma_db.py b/app/database/chroma_db.py index cb8539d..d474f1e 100644 --- a/app/database/chroma_db.py +++ b/app/database/chroma_db.py @@ -37,12 +37,12 @@ def check_db_heartbeat(): chroma_client.heartbeat() # description: DB에서 검색하는 함수 - chat case 3에 사용 -async def search_db_query(query): - # 컬렉션 생성 - # 컬렉션에 쿼리 전송 +async def search_db_query(member_id, query): + member = member_id result = schedules.query( query_texts=query, - n_results=5 # 결과에서 한 가지 문서만 반환하면 한강공원이, 두 가지 문서 반환하면 AI가 뜸->유사도가 이상하게 검사되는 것 같음 + n_results=5, # 결과에서 한 가지 문서만 반환하면 한강공원이, 두 가지 문서 반환하면 AI가 뜸->유사도가 이상하게 검사되는 것 같음 + where={"member": {"$eq": int(member)}} ) return result diff --git a/app/dto/openai_dto.py b/app/dto/openai_dto.py index bb04a4c..0c235f9 100644 --- a/app/dto/openai_dto.py +++ b/app/dto/openai_dto.py @@ -2,6 +2,7 @@ from typing import List class PromptRequest(BaseModel): + member_id: int prompt: str persona: str chatType: str diff --git a/app/routers/chat.py b/app/routers/chat.py index b7c60ae..8828d4a 100644 --- a/app/routers/chat.py +++ b/app/routers/chat.py @@ -134,12 +134,13 @@ async def get_langchain_rag(data: PromptRequest, chat_type_prompt): model_name=config_normal['MODEL_NAME'], # 모델명 openai_api_key=OPENAI_API_KEY # API 키 ) + member_id = data.member_id question = data.prompt persona = data.persona user_persona_prompt = persona_prompt.Template.from_persona(persona) # vectordb.search_db_query를 비동기적으로 호출합니다. - schedule = await vectordb.search_db_query(question) # vector db에서 검색 + schedule = await vectordb.search_db_query(member_id, question) # vector db에서 검색 # description: give NESS's ideal instruction as template case3_template = openai_prompt.Template.case3_template From a1986d2c796b88563c7c8bcffcae49707a14bb64 Mon Sep 17 00:00:00 2001 From: uommou Date: Fri, 24 May 2024 21:24:59 +0900 Subject: [PATCH 06/11] =?UTF-8?q?feat:=20delete=20schedule=20=EC=B6=94?= =?UTF-8?q?=EA=B0=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/database/chroma_db.py | 11 ++++++++++- app/dto/db_dto.py | 4 ++++ app/routers/chromadb.py | 13 +++++++++++-- 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/app/database/chroma_db.py b/app/database/chroma_db.py index d474f1e..0abe634 100644 --- a/app/database/chroma_db.py +++ b/app/database/chroma_db.py @@ -10,7 +10,7 @@ import os import datetime from dotenv import load_dotenv -from app.dto.db_dto import AddScheduleDTO, RecommendationMainRequestDTO, ReportTagsRequestDTO +from app.dto.db_dto import AddScheduleDTO, DeleteScheduleDTO, RecommendationMainRequestDTO, ReportTagsRequestDTO load_dotenv() CHROMA_DB_IP_ADDRESS = os.getenv("CHROMA_DB_IP_ADDRESS") @@ -61,6 +61,15 @@ async def add_db_data(schedule_data: AddScheduleDTO): return True # 메인페이지 한 줄 추천 기능에 사용하는 함수 +async def delete_db_data(schedule_data: DeleteScheduleDTO): + member_id = schedule_data.member_id + schedule_id = schedule_data.schedule_id + schedules.delete( + ids=[str(schedule_id)], + where={"member": {"$eq": int(member_id)}} + ) + return True + # 유저의 id, 해당 날짜로 필터링 async def db_daily_schedule(user_data: RecommendationMainRequestDTO): member = user_data.member_id diff --git a/app/dto/db_dto.py b/app/dto/db_dto.py index 8be0800..5d84981 100644 --- a/app/dto/db_dto.py +++ b/app/dto/db_dto.py @@ -11,6 +11,10 @@ class AddScheduleDTO(BaseModel): location: str person: str +class DeleteScheduleDTO(BaseModel): + schedule_id: int + member_id: int + class RecommendationMainRequestDTO(BaseModel): member_id: int user_persona: str diff --git a/app/routers/chromadb.py b/app/routers/chromadb.py index 99ce84a..f204b95 100644 --- a/app/routers/chromadb.py +++ b/app/routers/chromadb.py @@ -4,8 +4,8 @@ from dotenv import load_dotenv from fastapi import APIRouter, HTTPException, Depends, status -from app.dto.db_dto import AddScheduleDTO -from app.database.chroma_db import add_db_data, get_chroma_client +from app.dto.db_dto import AddScheduleDTO, DeleteScheduleDTO +from app.database.chroma_db import add_db_data, delete_db_data, get_chroma_client router = APIRouter( prefix="/chromadb", @@ -30,3 +30,12 @@ async def add_schedule_endpoint(schedule_data: AddScheduleDTO, chroma_client=Dep return {"message": "Schedule added successfully"} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) + +@router.post("/delete_schedule", status_code=status.HTTP_204_NO_CONTENT) +async def delete_schedule_endpoint(schedule_data: DeleteScheduleDTO, chroma_client=Depends(get_chroma_client)): + try: + # 직접 `add_db_data` 함수를 비동기적으로 호출합니다. + await delete_db_data(schedule_data) + return {"message": "Schedule delete successfully"} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) From 3cad2ebef2bdfb5f8d630400c979ff065a762189 Mon Sep 17 00:00:00 2001 From: uommou Date: Sat, 25 May 2024 20:53:43 +0900 Subject: [PATCH 07/11] fix: delete schedule --- app/routers/chromadb.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/app/routers/chromadb.py b/app/routers/chromadb.py index f204b95..4fd8775 100644 --- a/app/routers/chromadb.py +++ b/app/routers/chromadb.py @@ -31,7 +31,7 @@ async def add_schedule_endpoint(schedule_data: AddScheduleDTO, chroma_client=Dep except Exception as e: raise HTTPException(status_code=500, detail=str(e)) -@router.post("/delete_schedule", status_code=status.HTTP_204_NO_CONTENT) +@router.delete("/delete_schedule", status_code=status.HTTP_204_NO_CONTENT) async def delete_schedule_endpoint(schedule_data: DeleteScheduleDTO, chroma_client=Depends(get_chroma_client)): try: # 직접 `add_db_data` 함수를 비동기적으로 호출합니다. @@ -39,3 +39,13 @@ async def delete_schedule_endpoint(schedule_data: DeleteScheduleDTO, chroma_clie return {"message": "Schedule delete successfully"} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) + + +@router.delete("/delete_schedule", status_code=status.HTTP_204_NO_CONTENT) +async def delete_schedule_endpoint(schedule_data: DeleteScheduleDTO, chroma_client=Depends(get_chroma_client)): + try: + # 직접 `add_db_data` 함수를 비동기적으로 호출합니다. + await delete_db_data(schedule_data) + return {"message": "Schedule delete successfully"} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file From 285fbb17a70a1f19c1e3dfbfac0448ed15c8b2d8 Mon Sep 17 00:00:00 2001 From: uommou Date: Sat, 25 May 2024 21:28:52 +0900 Subject: [PATCH 08/11] feat: update schedule --- app/database/chroma_db.py | 27 ++++++++++++++++++++++++++- app/dto/db_dto.py | 12 +++++++++++- app/routers/chromadb.py | 14 +++++++------- 3 files changed, 44 insertions(+), 9 deletions(-) diff --git a/app/database/chroma_db.py b/app/database/chroma_db.py index 0abe634..677854a 100644 --- a/app/database/chroma_db.py +++ b/app/database/chroma_db.py @@ -10,7 +10,7 @@ import os import datetime from dotenv import load_dotenv -from app.dto.db_dto import AddScheduleDTO, DeleteScheduleDTO, RecommendationMainRequestDTO, ReportTagsRequestDTO +from app.dto.db_dto import AddScheduleDTO, DeleteScheduleDTO, UpdateScheduleDTO, RecommendationMainRequestDTO, ReportTagsRequestDTO load_dotenv() CHROMA_DB_IP_ADDRESS = os.getenv("CHROMA_DB_IP_ADDRESS") @@ -70,6 +70,31 @@ async def delete_db_data(schedule_data: DeleteScheduleDTO): ) return True +# 데이터베이스 업데이트 함수 정의 +async def update_db_data(schedule_data: UpdateScheduleDTO): + schedule_date = schedule_data.schedule_datetime_start.split("T")[0] + year = int(schedule_date.split("-")[0]) + month = int(schedule_date.split("-")[1]) + date = int(schedule_date.split("-")[2]) + + # 기존 스케줄 업데이트 로직 + schedules.update( + documents=[schedule_data.data], + ids=[str(schedule_data.schedule_id)], + metadatas=[{ + "year": year, + "month": month, + "date": date, + "datetime_start": schedule_data.schedule_datetime_start, + "datetime_end": schedule_data.schedule_datetime_end, + "member": schedule_data.member_id, + "category": schedule_data.category, + "location": schedule_data.location, + "person": schedule_data.person + }] + ) + return True + # 유저의 id, 해당 날짜로 필터링 async def db_daily_schedule(user_data: RecommendationMainRequestDTO): member = user_data.member_id diff --git a/app/dto/db_dto.py b/app/dto/db_dto.py index 5d84981..cac7dc9 100644 --- a/app/dto/db_dto.py +++ b/app/dto/db_dto.py @@ -7,7 +7,7 @@ class AddScheduleDTO(BaseModel): schedule_datetime_end: str schedule_id: int member_id: int - category: int + category: str location: str person: str @@ -15,6 +15,16 @@ class DeleteScheduleDTO(BaseModel): schedule_id: int member_id: int +class UpdateScheduleDTO(BaseModel): + data: str + schedule_datetime_start: str + schedule_datetime_end: str + schedule_id: int + member_id: int + category: str + location: str + person: str + class RecommendationMainRequestDTO(BaseModel): member_id: int user_persona: str diff --git a/app/routers/chromadb.py b/app/routers/chromadb.py index 4fd8775..6bca24f 100644 --- a/app/routers/chromadb.py +++ b/app/routers/chromadb.py @@ -4,8 +4,8 @@ from dotenv import load_dotenv from fastapi import APIRouter, HTTPException, Depends, status -from app.dto.db_dto import AddScheduleDTO, DeleteScheduleDTO -from app.database.chroma_db import add_db_data, delete_db_data, get_chroma_client +from app.dto.db_dto import AddScheduleDTO, DeleteScheduleDTO, UpdateScheduleDTO +from app.database.chroma_db import add_db_data, delete_db_data, update_db_data, get_chroma_client router = APIRouter( prefix="/chromadb", @@ -41,11 +41,11 @@ async def delete_schedule_endpoint(schedule_data: DeleteScheduleDTO, chroma_clie raise HTTPException(status_code=500, detail=str(e)) -@router.delete("/delete_schedule", status_code=status.HTTP_204_NO_CONTENT) -async def delete_schedule_endpoint(schedule_data: DeleteScheduleDTO, chroma_client=Depends(get_chroma_client)): +@router.put("/update_schedule", status_code=status.HTTP_200_OK) +async def update_schedule_endpoint(schedule_data: UpdateScheduleDTO, chroma_client=Depends(get_chroma_client)): try: - # 직접 `add_db_data` 함수를 비동기적으로 호출합니다. - await delete_db_data(schedule_data) - return {"message": "Schedule delete successfully"} + # 데이터베이스 업데이트 함수 호출 + await update_db_data(schedule_data) + return {"message": "Schedule updated successfully"} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file From 8c49fcc36fd4d8be59fcbe1219a35cf2a3af4beb Mon Sep 17 00:00:00 2001 From: uommou Date: Sat, 25 May 2024 21:33:36 +0900 Subject: [PATCH 09/11] =?UTF-8?q?fix:=20category=20id=20=EB=B0=9B=EC=95=84?= =?UTF-8?q?=EC=98=A4=EA=B8=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/database/chroma_db.py | 22 +++++++++++++++++----- app/dto/db_dto.py | 2 ++ 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/app/database/chroma_db.py b/app/database/chroma_db.py index 677854a..25ccc5d 100644 --- a/app/database/chroma_db.py +++ b/app/database/chroma_db.py @@ -41,7 +41,7 @@ async def search_db_query(member_id, query): member = member_id result = schedules.query( query_texts=query, - n_results=5, # 결과에서 한 가지 문서만 반환하면 한강공원이, 두 가지 문서 반환하면 AI가 뜸->유사도가 이상하게 검사되는 것 같음 + n_results=15, # 결과에서 한 가지 문서만 반환하면 한강공원이, 두 가지 문서 반환하면 AI가 뜸->유사도가 이상하게 검사되는 것 같음 where={"member": {"$eq": int(member)}} ) return result @@ -50,16 +50,27 @@ async def search_db_query(member_id, query): # 스프링 백엔드로부터 chroma DB에 저장할 데이터를 받아 DB에 추가한다. async def add_db_data(schedule_data: AddScheduleDTO): schedule_date = schedule_data.schedule_datetime_start.split("T")[0] - year = int(schedule_date.split("-")[0]) - month = int(schedule_date.split("-")[1]) - date = int(schedule_date.split("-")[2]) + year, month, date = map(int, schedule_date.split("-")) + schedules.add( documents=[schedule_data.data], ids=[str(schedule_data.schedule_id)], - metadatas=[{"year": year, "month": month, "date": date, "datetime_start": schedule_data.schedule_datetime_start, "datetime_end": schedule_data.schedule_datetime_end, "member": schedule_data.member_id, "category": schedule_data.category, "location": schedule_data.location, "person": schedule_data.person}] + metadatas=[{ + "year": year, + "month": month, + "date": date, + "datetime_start": schedule_data.schedule_datetime_start, + "datetime_end": schedule_data.schedule_datetime_end, + "member": schedule_data.member_id, + "category": schedule_data.category, + "category_id": schedule_data.schedule_id, + "location": schedule_data.location, + "person": schedule_data.person + }] ) return True + # 메인페이지 한 줄 추천 기능에 사용하는 함수 async def delete_db_data(schedule_data: DeleteScheduleDTO): member_id = schedule_data.member_id @@ -89,6 +100,7 @@ async def update_db_data(schedule_data: UpdateScheduleDTO): "datetime_end": schedule_data.schedule_datetime_end, "member": schedule_data.member_id, "category": schedule_data.category, + "category_id": schedule_data.schedule_id, "location": schedule_data.location, "person": schedule_data.person }] diff --git a/app/dto/db_dto.py b/app/dto/db_dto.py index cac7dc9..2377e92 100644 --- a/app/dto/db_dto.py +++ b/app/dto/db_dto.py @@ -8,6 +8,7 @@ class AddScheduleDTO(BaseModel): schedule_id: int member_id: int category: str + category_id: int location: str person: str @@ -22,6 +23,7 @@ class UpdateScheduleDTO(BaseModel): schedule_id: int member_id: int category: str + category_id: int location: str person: str From b90581bc56100a7693bb03f6547c6ca3d63a5e58 Mon Sep 17 00:00:00 2001 From: uommou Date: Sat, 25 May 2024 23:12:43 +0900 Subject: [PATCH 10/11] =?UTF-8?q?feat:=20rds=20=EC=BB=A4=EB=84=A5=EC=85=98?= =?UTF-8?q?=20=EC=83=9D=EC=84=B1=20=EB=B0=8F=20=EC=A0=95=EB=B3=B4=20?= =?UTF-8?q?=EC=9D=BD=EC=96=B4=EC=98=A4=EA=B8=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/database/connect_rds.py | 31 +++++++++++++++++++++++++++++++ requirements.txt | 1 + 2 files changed, 32 insertions(+) create mode 100644 app/database/connect_rds.py diff --git a/app/database/connect_rds.py b/app/database/connect_rds.py new file mode 100644 index 0000000..c1a9ee4 --- /dev/null +++ b/app/database/connect_rds.py @@ -0,0 +1,31 @@ +import os +import pymysql +from pymysql.cursors import DictCursor +from dotenv import load_dotenv + +# .env 파일에서 환경 변수 로드 +load_dotenv() + +def get_rds_connection(): + return pymysql.connect( + host=os.getenv('RDS_HOST'), + user=os.getenv('RDS_USER'), + password=os.getenv('RDS_PASSWORD'), + database=os.getenv('RDS_DATABASE'), + cursorclass=DictCursor + ) + +def fetch_category_classification_data(member_id): + connection = get_rds_connection() + try: + with connection.cursor() as cursor: + sql = """ + SELECT c.* + FROM category c + WHERE c.member_id = %s + """ + cursor.execute(sql, (member_id,)) + result = cursor.fetchall() + return result + finally: + connection.close() diff --git a/requirements.txt b/requirements.txt index e6ace11..4f4da86 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ python-dotenv==1.0.0 starlette==0.35.1 pydantic==2.5.3 sentence-transformers==2.5.1 +pymysql \ No newline at end of file From 62163fb0c35372fa56cfd9267578adfc7cbe39ec Mon Sep 17 00:00:00 2001 From: uommou Date: Sat, 25 May 2024 23:55:54 +0900 Subject: [PATCH 11/11] =?UTF-8?q?feat:=20category=20classification=20?= =?UTF-8?q?=EC=99=84=EC=84=B1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/prompt/openai_prompt.py | 57 ++++++++++++++++++---------------- app/routers/chat.py | 62 +++++++++++++++++++++++++------------ 2 files changed, 74 insertions(+), 45 deletions(-) diff --git a/app/prompt/openai_prompt.py b/app/prompt/openai_prompt.py index 9ee23ca..2fff551 100644 --- a/app/prompt/openai_prompt.py +++ b/app/prompt/openai_prompt.py @@ -86,37 +86,42 @@ class Template: YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT. Current time is {current_time}. Respond to the user considering the current time. User input: {question} """ + case2_template = """ - {persona} - {chat_type} - The user's input contains information about a new event they want to add to their schedule. You have two tasks to perform: + {persona} + {chat_type} + The user's input contains information about a new event they want to add to their schedule. You have two tasks to perform: - 1. Respond kindly to the user's input. YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT. - 2. Organize the event the user wants to add into a json format for saving in a database. The returned json will have keys for info, location, person, and date. - - info: Summarizes what the user wants to do. This value must always be present. - - location: If the user's event information includes a place, save that place as the value. - - person: If th e user's event mentions a person they want to include, save that person as the value. - - date: If the user's event information includes a specific date and time, save that date and time in datetime format. Dates should be organized based on the current time at the user's location. Current time is {current_time}. - Separate the outputs for tasks 1 and 2 with a special token . + 1. Respond kindly to the user's input. YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT. + 2. Organize the event the user wants to add into a json format for saving in a database. The returned json will have keys for info, location, person, start_time, end_time, and category. + - info: Summarizes what the user wants to do. This value must always be present. + - location: If the user's event information includes a place, save that place as the value. + - person: If the user's event mentions a person they want to include, save that person as the value. + - start_time: If the user's event information includes a specific date and time, save that date and time in datetime format. Dates should be organized based on the current time at the user's location. Current time is {current_time}. + - end_time: If the user's event information includes an end time, save that date and time in datetime format. + - category: Choose the most appropriate category for the event from the following list: {categories}. + Separate the outputs for tasks 1 and 2 with a special token . - Example for one-shot learning: + Example for one-shot learning: - User input: I have a meeting with Dr. Smith at her office on March 3rd at 10am. + User input: I have a meeting with Dr. Smith at her office on March 3rd from 10am to 11am. - Response to user: - Shall I add your meeting with Dr. Smith at her office on March 3rd at 10am to your schedule? - - {{ - "info": "meeting with Dr. Smith", - "location": "Dr. Smith's office", - "person": "Dr. Smith", - "date": "2023-03-03T10:00:00" - }} - - User input: {question} - - Response to user: - """ + Response to user: + Shall I add your meeting with Dr. Smith at her office on March 3rd from 10am to 11am to your schedule? + + {{ + "info": "meeting with Dr. Smith", + "location": "Dr. Smith's office", + "person": "Dr. Smith", + "start_time": "2023-03-03T10:00:00", + "end_time": "2023-03-03T11:00:00", + "category": "Work" + }} + + User input: {question} + + Response to user: + """ case3_template = """ {persona} diff --git a/app/routers/chat.py b/app/routers/chat.py index 8828d4a..70fb275 100644 --- a/app/routers/chat.py +++ b/app/routers/chat.py @@ -8,6 +8,7 @@ from langchain_core.prompts import PromptTemplate from datetime import datetime +from app.database.connect_rds import fetch_category_classification_data from app.dto.openai_dto import PromptRequest, ChatResponse, ChatCaseResponse from app.prompt import openai_prompt, persona_prompt @@ -102,25 +103,48 @@ async def get_langchain_normal(data: PromptRequest, chat_type_prompt): # case 1 # case 2 : 일정 생성 #@router.post("/case/make_schedule") # 테스트용 엔드포인트 async def get_langchain_schedule(data: PromptRequest, chat_type_prompt): - print("running case 2") - # description: use langchain - config_normal = config['NESS_NORMAL'] - - chat_model = ChatOpenAI(temperature=config_normal['TEMPERATURE'], # 창의성 (0.0 ~ 2.0) - max_tokens=config_normal['MAX_TOKENS'], # 최대 토큰수 - model_name=config_normal['MODEL_NAME'], # 모델명 - openai_api_key=OPENAI_API_KEY # API 키 - ) - question = data.prompt - persona = data.persona - user_persona_prompt = persona_prompt.Template.from_persona(persona) - case2_template = openai_prompt.Template.case2_template - - prompt = PromptTemplate.from_template(case2_template) - current_time = datetime.now() - response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question, current_time=current_time, chat_type=chat_type_prompt)) - print(response) - return response + try: + print("running case 2") + member_id = data.member_id + categories = fetch_category_classification_data(member_id) + # 카테고리 데이터를 텍스트로 변환 (JSON 형식으로 변환) + categories_text = ", ".join([category['name'] for category in categories]) + print(categories_text) + + # description: use langchain + config_normal = config['NESS_NORMAL'] + + chat_model = ChatOpenAI(temperature=config_normal['TEMPERATURE'], # 창의성 (0.0 ~ 2.0) + max_tokens=config_normal['MAX_TOKENS'], # 최대 토큰수 + model_name=config_normal['MODEL_NAME'], # 모델명 + openai_api_key=OPENAI_API_KEY # API 키 + ) + + question = data.prompt + persona = data.persona + user_persona_prompt = persona_prompt.Template.from_persona(persona) + case2_template = openai_prompt.Template.case2_template + + prompt = PromptTemplate.from_template(case2_template) + current_time = datetime.now() + + # OpenAI 프롬프트에 데이터 통합 + response = chat_model.predict( + prompt.format( + persona=user_persona_prompt, + output_language="Korean", + question=question, + current_time=current_time, + chat_type=chat_type_prompt, + categories=categories_text # 카테고리 데이터를 프롬프트에 포함 + ) + ) + + print(response) + return response + + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) # case 3 : rag #@router.post("/case/rag") # 테스트용 엔드포인트