Skip to content

Commit

Permalink
Merge pull request #42 from studio-recoding/dev
Browse files Browse the repository at this point in the history
[feat] 메인 페이지 기능 1차 개발 완료
  • Loading branch information
uommou authored Apr 7, 2024
2 parents 0117e37 + c723048 commit 91a1abd
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 44 deletions.
30 changes: 26 additions & 4 deletions app/database/chroma_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import os
import datetime
from dotenv import load_dotenv
from app.dto.db_dto import AddScheduleDTO
from app.dto.db_dto import AddScheduleDTO, RecommendationMainRequestDTO

load_dotenv()
CHROMA_DB_IP_ADDRESS = os.getenv("CHROMA_DB_IP_ADDRESS")
Expand All @@ -34,25 +34,47 @@
def check_db_heartbeat():
    """Ping the Chroma server to confirm the connection is alive."""
    chroma_client.heartbeat()

# description: search the DB - used for chat case 3
async def search_db_query(query):
    """Query the `schedules` collection and return the raw Chroma result.

    Args:
        query: query text (or list of texts) to embed and match against
            the stored schedule documents.

    Returns:
        The Chroma query result mapping (ids/documents/metadatas/distances).
    """
    # Send the query to the collection.
    # NOTE(review): with only 1-2 results the similarity ranking looked wrong
    # (irrelevant documents surfaced), so 5 results are requested - confirm.
    result = schedules.query(
        query_texts=query,
        n_results=5,
    )
    return result

# description: store data in the DB
# Receives schedule data from the Spring backend and adds it to Chroma.
async def add_db_data(schedule_data: AddScheduleDTO):
    """Add one schedule document to the `schedules` collection.

    The date part ("YYYY-MM-DD", text before the "T") of the start datetime
    is stored as a separate `date` metadata field so queries can filter by day.

    Returns:
        True on completion.
    """
    schedule_date = schedule_data.schedule_datetime_start.split("T")[0]
    schedules.add(
        documents=[schedule_data.data],
        ids=[str(schedule_data.schedule_id)],
        metadatas=[{
            "date": schedule_date,
            "datetime_start": schedule_data.schedule_datetime_start,
            "datetime_end": schedule_data.schedule_datetime_end,
            "member": schedule_data.member_id,
            "category": schedule_data.category,
            "location": schedule_data.location,
            "person": schedule_data.person,
        }],
    )
    return True

# Used by the main-page one-line recommendation feature.
# Filters stored schedules by the user's id and the request's date.
async def db_recommendation_main(user_data: RecommendationMainRequestDTO):
    """Return the user's schedule documents for the day, ranked by persona similarity.

    Args:
        user_data: request DTO carrying member id, persona text, and the
            day's start/end datetime strings.

    Returns:
        The `documents` portion of the Chroma query result (list of lists
        of matching schedule texts).
    """
    member = user_data.member_id
    # Only the date part ("YYYY-MM-DD", text before "T") is used for filtering;
    # the end datetime is not needed here.
    schedule_date = user_data.schedule_datetime_start.split("T")[0]
    # Fall back to a default persona when none was supplied.
    persona = user_data.user_persona or "hard working"
    results = schedules.query(
        query_texts=[persona],
        n_results=5,
        where={"$and": [
            {"member": {"$eq": int(member)}},
            {"date": {"$eq": schedule_date}},
        ]},
        # where_document={"$contains": "search_string"}  # optional filter
    )
    return results['documents']

def get_chroma_client():
    """Return the module-level Chroma client instance."""
    return chroma_client
6 changes: 6 additions & 0 deletions app/dto/db_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,9 @@ class AddScheduleDTO(BaseModel):
location: str
person: str

class RecommendationMainRequestDTO(BaseModel):
    """Request body for the main-page one-line recommendation feature."""
    member_id: int
    # Persona text used as the similarity query; downstream code substitutes
    # a default when this is falsy (see db_recommendation_main).
    user_persona: str
    # Datetime strings; downstream code takes the date part before "T".
    schedule_datetime_start: str
    schedule_datetime_end: str

6 changes: 5 additions & 1 deletion app/dto/openai_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,8 @@ class PromptRequest(BaseModel):
prompt: str

class ChatResponse(BaseModel):
ness: str
ness: str

class ChatCaseResponse(BaseModel):
    """Chat response that also reports which case the prompt was classified as."""
    ness: str  # model-generated reply text
    case: int  # classified case number (1: normal, 2: schedule creation, 3: RAG)
10 changes: 10 additions & 0 deletions app/prompt/openai_config.ini
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
[NESS_NORMAL]
TEMPERATURE = 0
MAX_TOKENS = 2048
MODEL_NAME = gpt-3.5-turbo-1106

[NESS_CASE]
TEMPERATURE = 0
MAX_TOKENS = 2048
MODEL_NAME = gpt-4

[NESS_RECOMMENDATION]
TEMPERATURE = 0
MAX_TOKENS = 2048
MODEL_NAME = gpt-3.5-turbo-1106
14 changes: 13 additions & 1 deletion app/prompt/openai_prompt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
class Template:
recommendation_template = """한 줄 추천 기능 템플릿"""
recommendation_template = """
You are an AI assistant designed to recommend daily activities based on a user's schedule. You will receive a day's worth of the user's schedule information. Your task is to understand that schedule and, based on it, recommend an activity for the user to perform that day. There are a few rules you must follow in your recommendations:
1. YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT.
2. Ensure your recommendation is encouraging, so the user doesn't feel compelled.
3. The recommendation must be concise, limited to one sentence without any additional commentary.
Example:
User schedule: [Practice guitar, Calculate accuracy, Study backend development, Run AI models in the lab, Study NEST.JS]
AI Recommendation: "Your day is filled with learning and research. how about taking a short walk in between studies?"
User schedule: {schedule}
AI Recommendation:
"""
# case 분류 잘 안됨 - 수정 필요
case_classify_template = """
Task: User Chat Classification
Expand Down
58 changes: 33 additions & 25 deletions app/routers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from langchain_community.chat_models import ChatOpenAI
from langchain_core.prompts import PromptTemplate

from app.dto.openai_dto import PromptRequest, ChatResponse
from app.dto.openai_dto import PromptRequest, ChatResponse, ChatCaseResponse
from app.prompt import openai_prompt

import app.database.chroma_db as vectordb
Expand All @@ -26,15 +26,15 @@
config = configparser.ConfigParser()
config.read(CONFIG_FILE_PATH)

@router.post("/case", status_code=status.HTTP_200_OK, response_model=ChatResponse)
async def get_langchain_case(data: PromptRequest) -> ChatResponse:
@router.post("/case", status_code=status.HTTP_200_OK, response_model=ChatCaseResponse)
async def get_langchain_case(data: PromptRequest) -> ChatCaseResponse:
# description: use langchain

config_normal = config['NESS_NORMAL']
config_chat = config['NESS_CHAT']

chat_model = ChatOpenAI(temperature=config_normal['TEMPERATURE'], # 창의성 (0.0 ~ 2.0)
max_tokens=config_normal['MAX_TOKENS'], # 최대 토큰수
model_name=config_normal['MODEL_NAME'], # 모델명
chat_model = ChatOpenAI(temperature=config_chat['TEMPERATURE'], # 창의성 (0.0 ~ 2.0)
max_tokens=config_chat['MAX_TOKENS'], # 최대 토큰수
model_name=config_chat['MODEL_NAME'], # 모델명
openai_api_key=OPENAI_API_KEY # API 키
)
question = data.prompt
Expand All @@ -48,28 +48,32 @@ async def get_langchain_case(data: PromptRequest) -> ChatResponse:
print(case)
case = int(case)
if case == 1:
return await get_langchain_normal(data)
response = await get_langchain_normal(data)

elif case == 2:
return await get_langchain_schedule(data)
response = await get_langchain_schedule(data)

elif case == 3:
return await get_langchain_rag(data)
response = await get_langchain_rag(data)

else:
print("wrong case classification")
# 적절한 HTTP 상태 코드와 함께 오류 메시지를 반환하거나, 다른 처리를 할 수 있습니다.
raise HTTPException(status_code=400, detail="Wrong case classification")

return ChatCaseResponse(ness=response, case=case)


# case 1 : normal
#@router.post("/case/normal") # 테스트용 엔드포인트
async def get_langchain_normal(data: PromptRequest) -> ChatResponse: # case 1 : normal
async def get_langchain_normal(data: PromptRequest): # case 1 : normal
print("running case 1")
# description: use langchain
chat_model = ChatOpenAI(temperature=0, # 창의성 (0.0 ~ 2.0)
max_tokens=2048, # 최대 토큰수
model_name='gpt-3.5-turbo-1106', # 모델명
config_normal = config['NESS_NORMAL']

chat_model = ChatOpenAI(temperature=config_normal['TEMPERATURE'], # 창의성 (0.0 ~ 2.0)
max_tokens=config_normal['MAX_TOKENS'], # 최대 토큰수
model_name=config_normal['MODEL_NAME'], # 모델명
openai_api_key=OPENAI_API_KEY # API 키
)
question = data.prompt
Expand All @@ -82,16 +86,18 @@ async def get_langchain_normal(data: PromptRequest) -> ChatResponse: # case 1 :
prompt = PromptTemplate.from_template(my_template)
response = chat_model.predict(prompt.format(output_language="Korean", question=question))
print(response)
return ChatResponse(ness=response)
return response

# case 2 : 일정 생성
#@router.post("/case/make_schedule") # 테스트용 엔드포인트
async def get_langchain_schedule(data: PromptRequest) -> ChatResponse:
async def get_langchain_schedule(data: PromptRequest):
print("running case 2")
# description: use langchain
chat_model = ChatOpenAI(temperature=0, # 창의성 (0.0 ~ 2.0)
max_tokens=2048, # 최대 토큰수
model_name='gpt-3.5-turbo-1106', # 모델명
config_normal = config['NESS_NORMAL']

chat_model = ChatOpenAI(temperature=config_normal['TEMPERATURE'], # 창의성 (0.0 ~ 2.0)
max_tokens=config_normal['MAX_TOKENS'], # 최대 토큰수
model_name=config_normal['MODEL_NAME'], # 모델명
openai_api_key=OPENAI_API_KEY # API 키
)
question = data.prompt
Expand All @@ -100,16 +106,18 @@ async def get_langchain_schedule(data: PromptRequest) -> ChatResponse:
prompt = PromptTemplate.from_template(case2_template)
response = chat_model.predict(prompt.format(output_language="Korean", question=question))
print(response)
return ChatResponse(ness=response)
return response

# case 3 : rag
#@router.post("/case/rag") # 테스트용 엔드포인트
async def get_langchain_rag(data: PromptRequest) -> ChatResponse:
async def get_langchain_rag(data: PromptRequest):
print("running case 3")
# description: use langchain
chat_model = ChatOpenAI(temperature=0, # 창의성 (0.0 ~ 2.0)
max_tokens=2048, # 최대 토큰수
model_name='gpt-3.5-turbo-1106', # 모델명
config_normal = config['NESS_NORMAL']

chat_model = ChatOpenAI(temperature=config_normal['TEMPERATURE'], # 창의성 (0.0 ~ 2.0)
max_tokens=config_normal['MAX_TOKENS'], # 최대 토큰수
model_name=config_normal['MODEL_NAME'], # 모델명
openai_api_key=OPENAI_API_KEY # API 키
)
question = data.prompt
Expand All @@ -123,4 +131,4 @@ async def get_langchain_rag(data: PromptRequest) -> ChatResponse:
prompt = PromptTemplate.from_template(case3_template)
response = chat_model.predict(prompt.format(output_language="Korean", question=question, schedule=schedule))
print(response)
return ChatResponse(ness=response)
return response
40 changes: 27 additions & 13 deletions app/routers/recommendation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
import os

from dotenv import load_dotenv
from fastapi import APIRouter
from fastapi import APIRouter, Depends, status, HTTPException
from langchain_community.chat_models import ChatOpenAI
from langchain_core.prompts import PromptTemplate

from app.dto.db_dto import RecommendationMainRequestDTO
from app.dto.openai_dto import ChatResponse
from app.prompt import openai_prompt
import app.database.chroma_db as vectordb

router = APIRouter(
prefix="/recommendation",
Expand All @@ -22,19 +25,30 @@
config = configparser.ConfigParser()
config.read(CONFIG_FILE_PATH)

@router.get("/main")
async def get_recommendation():
@router.post("/main", status_code=status.HTTP_200_OK)
async def get_recommendation(user_data: RecommendationMainRequestDTO) -> ChatResponse:
    """Generate a one-line activity recommendation for the main page.

    Fetches the user's schedule documents for the day from the vector DB,
    then asks the LLM for a single encouraging recommendation in Korean.

    Raises:
        HTTPException: 500 with the original error message on any failure.
    """
    try:
        # Model settings come from the [NESS_RECOMMENDATION] config section.
        config_recommendation = config['NESS_RECOMMENDATION']

        chat_model = ChatOpenAI(temperature=config_recommendation['TEMPERATURE'],  # creativity (0.0 ~ 2.0)
                                max_tokens=config_recommendation['MAX_TOKENS'],    # max tokens in the reply
                                model_name=config_recommendation['MODEL_NAME'],    # model name
                                openai_api_key=OPENAI_API_KEY                      # API key
                                )

        # Pull the user's schedule documents for the day from the vector DB.
        schedule = await vectordb.db_recommendation_main(user_data)
        print(schedule)

        # Prompt template for the one-line recommendation.
        recommendation_template = openai_prompt.Template.recommendation_template

        prompt = PromptTemplate.from_template(recommendation_template)
        result = chat_model.predict(prompt.format(output_language="Korean", schedule=schedule))
        print(result)
        return ChatResponse(ness=result)

    except Exception as e:
        # Surface any failure as a 500, preserving the original cause chain.
        raise HTTPException(status_code=500, detail=str(e)) from e

0 comments on commit 91a1abd

Please sign in to comment.