diff --git a/app/database/chroma_db.py b/app/database/chroma_db.py index e0fa448..0d7e3d1 100644 --- a/app/database/chroma_db.py +++ b/app/database/chroma_db.py @@ -10,7 +10,7 @@ import os import datetime from dotenv import load_dotenv -from app.dto.db_dto import AddScheduleDTO +from app.dto.db_dto import AddScheduleDTO, RecommendationMainRequestDTO load_dotenv() CHROMA_DB_IP_ADDRESS = os.getenv("CHROMA_DB_IP_ADDRESS") @@ -34,25 +34,47 @@ def check_db_heartbeat(): chroma_client.heartbeat() -# description: DB에서 검색하는 함수 +# description: DB에서 검색하는 함수 - chat case 3에 사용 async def search_db_query(query): # 컬렉션 생성 # 컬렉션에 쿼리 전송 result = schedules.query( query_texts=query, - n_results=2 # 결과에서 한 가지 문서만 반환하면 한강공원이, 두 가지 문서 반환하면 AI가 뜸->유사도가 이상하게 검사되는 것 같음 + n_results=5 # 결과에서 한 가지 문서만 반환하면 한강공원이, 두 가지 문서 반환하면 AI가 뜸->유사도가 이상하게 검사되는 것 같음 ) return result # description: DB에 저장하는 함수 # 스프링 백엔드로부터 chroma DB에 저장할 데이터를 받아 DB에 추가한다. async def add_db_data(schedule_data: AddScheduleDTO): + schedule_date = schedule_data.schedule_datetime_start.split("T")[0] schedules.add( documents=[schedule_data.data], ids=[str(schedule_data.schedule_id)], - metadatas=[{"datetime_start": schedule_data.schedule_datetime_start, "datetime_end": schedule_data.schedule_datetime_end, "member": schedule_data.member_id, "category": schedule_data.category, "location": schedule_data.location, "person": schedule_data.person}] + metadatas=[{"date": schedule_date, "datetime_start": schedule_data.schedule_datetime_start, "datetime_end": schedule_data.schedule_datetime_end, "member": schedule_data.member_id, "category": schedule_data.category, "location": schedule_data.location, "person": schedule_data.person}] ) return True +# 메인페이지 한 줄 추천 기능에 사용하는 함수 +# 유저의 id, 해당 날짜로 필터링 +async def db_recommendation_main(user_data: RecommendationMainRequestDTO): + member = user_data.member_id + schedule_datetime_start = user_data.schedule_datetime_start + schedule_datetime_end = user_data.schedule_datetime_end + schedule_date = 
schedule_datetime_start.split("T")[0] + persona = user_data.user_persona or "hard working" + results = schedules.query( + query_texts=[persona], + n_results=5, + where={"$and": + [ + {"member": {"$eq": int(member)}}, + {"date": {"$eq": schedule_date}} + ] + } + # where_document={"$contains":"search_string"} # optional filter + ) + return results['documents'] + def get_chroma_client(): return chroma_client \ No newline at end of file diff --git a/app/dto/db_dto.py b/app/dto/db_dto.py index 8d16d8b..8228fd5 100644 --- a/app/dto/db_dto.py +++ b/app/dto/db_dto.py @@ -11,3 +11,9 @@ class AddScheduleDTO(BaseModel): location: str person: str +class RecommendationMainRequestDTO(BaseModel): + member_id: int + user_persona: str + schedule_datetime_start: str + schedule_datetime_end: str + diff --git a/app/dto/openai_dto.py b/app/dto/openai_dto.py index 3b5f905..cfa374f 100644 --- a/app/dto/openai_dto.py +++ b/app/dto/openai_dto.py @@ -5,4 +5,8 @@ class PromptRequest(BaseModel): prompt: str class ChatResponse(BaseModel): - ness: str \ No newline at end of file + ness: str + +class ChatCaseResponse(BaseModel): + ness: str + case: int \ No newline at end of file diff --git a/app/prompt/openai_config.ini b/app/prompt/openai_config.ini index 258efc3..8a62d43 100644 --- a/app/prompt/openai_config.ini +++ b/app/prompt/openai_config.ini @@ -1,4 +1,14 @@ [NESS_NORMAL] TEMPERATURE = 0 MAX_TOKENS = 2048 +MODEL_NAME = gpt-3.5-turbo-1106 + +[NESS_CASE] +TEMPERATURE = 0 +MAX_TOKENS = 2048 +MODEL_NAME = gpt-4 + +[NESS_RECOMMENDATION] +TEMPERATURE = 0 +MAX_TOKENS = 2048 MODEL_NAME = gpt-3.5-turbo-1106 \ No newline at end of file diff --git a/app/prompt/openai_prompt.py b/app/prompt/openai_prompt.py index 2a38add..e17f97d 100644 --- a/app/prompt/openai_prompt.py +++ b/app/prompt/openai_prompt.py @@ -1,5 +1,17 @@ class Template: - recommendation_template = """한 줄 추천 기능 템플릿""" + recommendation_template = """ + You are an AI assistant designed to recommend daily activities based on a user's 
schedule. You will receive a day's worth of the user's schedule information. Your task is to understand that schedule and, based on it, recommend an activity for the user to perform that day. There are a few rules you must follow in your recommendations: + 1. YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT. + 2. Ensure your recommendation is encouraging, so the user doesn't feel compelled. + 3. The recommendation must be concise, limited to one sentence without any additional commentary. + + Example: + User schedule: [Practice guitar, Calculate accuracy, Study backend development, Run AI models in the lab, Study NEST.JS] + AI Recommendation: "Your day is filled with learning and research. How about taking a short walk in between studies?" + + User schedule: {schedule} + AI Recommendation: + """ # case 분류 잘 안됨 - 수정 필요 case_classify_template = """ Task: User Chat Classification diff --git a/app/routers/chat.py b/app/routers/chat.py index 1c0b882..400ded8 100644 --- a/app/routers/chat.py +++ b/app/routers/chat.py @@ -7,7 +7,7 @@ from langchain_community.chat_models import ChatOpenAI from langchain_core.prompts import PromptTemplate -from app.dto.openai_dto import PromptRequest, ChatResponse +from app.dto.openai_dto import PromptRequest, ChatResponse, ChatCaseResponse from app.prompt import openai_prompt import app.database.chroma_db as vectordb @@ -26,15 +26,15 @@ config = configparser.ConfigParser() config.read(CONFIG_FILE_PATH) -@router.post("/case", status_code=status.HTTP_200_OK, response_model=ChatResponse) -async def get_langchain_case(data: PromptRequest) -> ChatResponse: +@router.post("/case", status_code=status.HTTP_200_OK, response_model=ChatCaseResponse) +async def get_langchain_case(data: PromptRequest) -> ChatCaseResponse: # description: use langchain - config_normal = config['NESS_NORMAL'] + config_chat = config['NESS_CASE'] - chat_model = ChatOpenAI(temperature=config_normal['TEMPERATURE'], # 창의성 (0.0 ~ 2.0) - 
max_tokens=config_normal['MAX_TOKENS'], # 최대 토큰수 - model_name=config_normal['MODEL_NAME'], # 모델명 + chat_model = ChatOpenAI(temperature=config_chat['TEMPERATURE'], # 창의성 (0.0 ~ 2.0) + max_tokens=config_chat['MAX_TOKENS'], # 최대 토큰수 + model_name=config_chat['MODEL_NAME'], # 모델명 openai_api_key=OPENAI_API_KEY # API 키 ) question = data.prompt @@ -48,28 +48,32 @@ async def get_langchain_case(data: PromptRequest) -> ChatResponse: print(case) case = int(case) if case == 1: - return await get_langchain_normal(data) + response = await get_langchain_normal(data) elif case == 2: - return await get_langchain_schedule(data) + response = await get_langchain_schedule(data) elif case == 3: - return await get_langchain_rag(data) + response = await get_langchain_rag(data) else: print("wrong case classification") # 적절한 HTTP 상태 코드와 함께 오류 메시지를 반환하거나, 다른 처리를 할 수 있습니다. raise HTTPException(status_code=400, detail="Wrong case classification") + return ChatCaseResponse(ness=response, case=case) + # case 1 : normal #@router.post("/case/normal") # 테스트용 엔드포인트 -async def get_langchain_normal(data: PromptRequest) -> ChatResponse: # case 1 : normal +async def get_langchain_normal(data: PromptRequest): # case 1 : normal print("running case 1") # description: use langchain - chat_model = ChatOpenAI(temperature=0, # 창의성 (0.0 ~ 2.0) - max_tokens=2048, # 최대 토큰수 - model_name='gpt-3.5-turbo-1106', # 모델명 + config_normal = config['NESS_NORMAL'] + + chat_model = ChatOpenAI(temperature=config_normal['TEMPERATURE'], # 창의성 (0.0 ~ 2.0) + max_tokens=config_normal['MAX_TOKENS'], # 최대 토큰수 + model_name=config_normal['MODEL_NAME'], # 모델명 openai_api_key=OPENAI_API_KEY # API 키 ) question = data.prompt @@ -82,16 +86,18 @@ async def get_langchain_normal(data: PromptRequest) -> ChatResponse: # case 1 : prompt = PromptTemplate.from_template(my_template) response = chat_model.predict(prompt.format(output_language="Korean", question=question)) print(response) - return ChatResponse(ness=response) + return response # case 2 : 
일정 생성 #@router.post("/case/make_schedule") # 테스트용 엔드포인트 -async def get_langchain_schedule(data: PromptRequest) -> ChatResponse: +async def get_langchain_schedule(data: PromptRequest): print("running case 2") # description: use langchain - chat_model = ChatOpenAI(temperature=0, # 창의성 (0.0 ~ 2.0) - max_tokens=2048, # 최대 토큰수 - model_name='gpt-3.5-turbo-1106', # 모델명 + config_normal = config['NESS_NORMAL'] + + chat_model = ChatOpenAI(temperature=config_normal['TEMPERATURE'], # 창의성 (0.0 ~ 2.0) + max_tokens=config_normal['MAX_TOKENS'], # 최대 토큰수 + model_name=config_normal['MODEL_NAME'], # 모델명 openai_api_key=OPENAI_API_KEY # API 키 ) question = data.prompt @@ -100,16 +106,18 @@ async def get_langchain_schedule(data: PromptRequest) -> ChatResponse: prompt = PromptTemplate.from_template(case2_template) response = chat_model.predict(prompt.format(output_language="Korean", question=question)) print(response) - return ChatResponse(ness=response) + return response # case 3 : rag #@router.post("/case/rag") # 테스트용 엔드포인트 -async def get_langchain_rag(data: PromptRequest) -> ChatResponse: +async def get_langchain_rag(data: PromptRequest): print("running case 3") # description: use langchain - chat_model = ChatOpenAI(temperature=0, # 창의성 (0.0 ~ 2.0) - max_tokens=2048, # 최대 토큰수 - model_name='gpt-3.5-turbo-1106', # 모델명 + config_normal = config['NESS_NORMAL'] + + chat_model = ChatOpenAI(temperature=config_normal['TEMPERATURE'], # 창의성 (0.0 ~ 2.0) + max_tokens=config_normal['MAX_TOKENS'], # 최대 토큰수 + model_name=config_normal['MODEL_NAME'], # 모델명 openai_api_key=OPENAI_API_KEY # API 키 ) question = data.prompt @@ -123,4 +131,4 @@ async def get_langchain_rag(data: PromptRequest) -> ChatResponse: prompt = PromptTemplate.from_template(case3_template) response = chat_model.predict(prompt.format(output_language="Korean", question=question, schedule=schedule)) print(response) - return ChatResponse(ness=response) + return response diff --git a/app/routers/recommendation.py 
b/app/routers/recommendation.py index d3116a6..2ca270f 100644 --- a/app/routers/recommendation.py +++ b/app/routers/recommendation.py @@ -2,11 +2,14 @@ import os from dotenv import load_dotenv -from fastapi import APIRouter +from fastapi import APIRouter, Depends, status, HTTPException from langchain_community.chat_models import ChatOpenAI from langchain_core.prompts import PromptTemplate +from app.dto.db_dto import RecommendationMainRequestDTO +from app.dto.openai_dto import ChatResponse from app.prompt import openai_prompt +import app.database.chroma_db as vectordb router = APIRouter( prefix="/recommendation", @@ -22,19 +25,30 @@ config = configparser.ConfigParser() config.read(CONFIG_FILE_PATH) -@router.get("/main") -async def get_recommendation(): +@router.post("/main", status_code=status.HTTP_200_OK) +async def get_recommendation(user_data: RecommendationMainRequestDTO) -> ChatResponse: + try: + # 모델 + config_recommendation = config['NESS_RECOMMENDATION'] - # 모델 - chat_model = ChatOpenAI(temperature=0, # 창의성 (0.0 ~ 2.0) - max_tokens=2048, # 최대 토큰수 - model_name='gpt-3.5-turbo-1106', # 모델명 - openai_api_key=OPENAI_API_KEY # API 키 - ) + chat_model = ChatOpenAI(temperature=config_recommendation['TEMPERATURE'], # 창의성 (0.0 ~ 2.0) + max_tokens=config_recommendation['MAX_TOKENS'], # 최대 토큰수 + model_name=config_recommendation['MODEL_NAME'], # 모델명 + openai_api_key=OPENAI_API_KEY # API 키 + ) + # vectordb에서 유저의 정보를 가져온다. 
+ schedule = await vectordb.db_recommendation_main(user_data) - # 템플릿 - recommendation_template = openai_prompt.Template.recommendation_template + print(schedule) - prompt = PromptTemplate.from_template(recommendation_template) - return chat_model.predict(prompt.format()) \ No newline at end of file + # 템플릿 + recommendation_template = openai_prompt.Template.recommendation_template + + prompt = PromptTemplate.from_template(recommendation_template) + result = chat_model.predict(prompt.format(output_language="Korean", schedule=schedule)) + print(result) + return ChatResponse(ness=result) + + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file