-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #8 from boostcampaitech6/sanggi
update: spa + sanggi
- Loading branch information
Showing
8 changed files
with
1,302 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
from fastapi import FastAPI | ||
from fastapi.responses import HTMLResponse, FileResponse | ||
from fastapi.staticfiles import StaticFiles | ||
from pydantic import BaseModel | ||
from typing import Optional | ||
import uvicorn | ||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | ||
from contextlib import asynccontextmanager | ||
from schemas import LineRequest, PoemRequest | ||
from dependency import load_model_tokenizer, get_model_tokenizer, load_poem_model_tokenizer, get_poem_model_tokenizer | ||
from config import config | ||
from loguru import logger | ||
|
||
# from openai import OpenAI | ||
|
||
@asynccontextmanager | ||
async def lifespan(app: FastAPI): | ||
|
||
# 모델&토큰나이저 로드 | ||
# load_model_tokenizer(config.model_path) | ||
load_poem_model_tokenizer(config.poem_model_path) | ||
logger.info("Loading model") | ||
yield | ||
|
||
app = FastAPI(lifespan=lifespan) | ||
template = FileResponse('template/index.html') | ||
app.mount("/static", StaticFiles(directory="template/static"), name="static") | ||
|
||
@app.get("/", response_class=HTMLResponse) | ||
async def home(): | ||
return template | ||
|
||
@app.post("/api/line") | ||
async def generate_line(request: LineRequest): | ||
emotion = request.emotion | ||
|
||
lines = [emotion, emotion, emotion] | ||
# model, tokenizer = get_model_tokenizer() | ||
# # input 데이터 전처리 | ||
# inputs = tokenizer(emotion, return_tensors="pt") | ||
# # 문장 생성 | ||
# for i in range(3): | ||
# output = model.generate(**inputs, do_sample=True) | ||
# decoded_output = tokenizer.decode(output[0], skip_special_tokens=True) | ||
# lines.append(decoded_output) | ||
|
||
return { "lines": lines} | ||
|
||
@app.post("/api/poem") | ||
async def generate_poem(request: PoemRequest): | ||
model, tokenizer = get_poem_model_tokenizer() | ||
|
||
line = request.line + '\n' | ||
|
||
# 이미지 생성 | ||
# OpenAI API_KEY 설정 | ||
# API_KEY = None | ||
# client = OpenAI(api_key=API_KEY) | ||
# response = client.images.generate(model='dall-e-3', | ||
# prompt=line, | ||
# size='1024x1024', | ||
# quality='standard', | ||
# n=1) | ||
# generated_image_url = response.data[0].url | ||
|
||
# 시 생성 | ||
input_ids = tokenizer.encode(line, add_special_tokens=True, return_tensors='pt') | ||
output = model.generate( | ||
input_ids=input_ids, | ||
temperature=0.2, # 생성 다양성 조절 | ||
max_new_tokens=64, # 생성되는 문장의 최대 길이 | ||
top_k=25, # 높은 확률을 가진 top-k 토큰만 고려 | ||
top_p=0.95, # 누적 확률이 p를 초과하는 토큰은 제외 | ||
repetition_penalty=1.2, # 반복을 줄이는 패널티 | ||
do_sample=True, # 샘플링 기반 생성 활성화 | ||
num_return_sequences=1, # 생성할 시퀀스의 수 | ||
) | ||
poem = tokenizer.decode(output[0].tolist(), skip_special_tokens=True) | ||
|
||
return { "poem": poem} | ||
|
||
if __name__ == "__main__": | ||
uvicorn.run(app, host="127.0.0.1", port=8000) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,39 @@ | ||
sentence_generator = None | ||
poem_generator = None | ||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedTokenizerFast, GPT2LMHeadModel | ||
|
||
def load_sentence_generator(model_path: str): | ||
import joblib | ||
model = None | ||
tokenizer = None | ||
|
||
global sentence_generator | ||
sentence_generator = joblib.load(model_path) | ||
poem_model = None | ||
poem_tokenizer = None | ||
|
||
def get_sentence_generator(): | ||
global sentence_generator | ||
return sentence_generator | ||
def load_model_tokenizer(model_path: str): | ||
|
||
global model | ||
global tokenizer | ||
|
||
model = AutoModelForSeq2SeqLM.from_pretrained(model_path) | ||
tokenizer = AutoTokenizer.from_pretrained(model_path) | ||
|
||
def get_model_tokenizer(): | ||
|
||
global model | ||
global tokenizer | ||
|
||
return model,tokenizer | ||
|
||
|
||
def load_poem_generator(model_path: str): | ||
import joblib | ||
|
||
global poem_generator | ||
poem_generator = joblib.load(model_path) | ||
|
||
def get_poem_generator(): | ||
global poem_generator | ||
return poem_generator | ||
def load_poem_model_tokenizer(poem_model_path: str): | ||
|
||
global poem_model | ||
global poem_tokenizer | ||
|
||
poem_model = GPT2LMHeadModel.from_pretrained(poem_model_path) | ||
poem_tokenizer = PreTrainedTokenizerFast.from_pretrained("skt/kogpt2-base-v2", | ||
bos_token='</s>', eos_token='</s>', unk_token='<unk>', | ||
pad_token='<pad>', mask_token='<mask>' | ||
) | ||
|
||
def get_poem_model_tokenizer(): | ||
global poem_model | ||
global poem_tokenizer | ||
|
||
return poem_model, poem_tokenizer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,21 +1,8 @@ | ||
from pydantic import BaseModel, Field | ||
from typing import Optional | ||
|
||
# 형용사 단어를 선택했을 때 문장생성을 위해 모델에 전달할 Request Schema | ||
class SentenceGenerationRequest(BaseModel): | ||
input_word: str | ||
class LineRequest(BaseModel): | ||
emotion: Optional[str] = None | ||
|
||
# 생성된 문장을 받는 Response Schema | ||
class SentenceGenerationResponse(BaseModel): | ||
result_sentence1: str | ||
result_sentence2: str | ||
result_sentence3: str | ||
|
||
# 시를 생성하기 시작할 문장을 모델에 전달할 Request Schema | ||
class PoemImageGenerationRequest(BaseModel): | ||
input_sentence: str | ||
|
||
# 생성된 시와 이미지를 받는 Response Schema | ||
class PoemImageGenerationResponse(BaseModel): | ||
result_poem: str | ||
result_image_url: Optional[str] = Field(default='OPEN_AI_API_KEY is not selected', description="The path to the generated image") | ||
class PoemRequest(BaseModel): | ||
line: Optional[str] = None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
<!DOCTYPE html> | ||
<html lang="ko"> | ||
<head> | ||
<meta charset="UTF-8"> | ||
<meta name="viewport" content="width=device-width, initial-scale=1.0"> | ||
<script src="https://cdn.tailwindcss.com"></script> | ||
<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script> | ||
<link href="https://fonts.googleapis.com/css2?family=Noto+Sans+KR:wght@400;700&display=swap" rel="stylesheet"> | ||
<link rel="stylesheet" type="text/css" href="./static/css/style.css"> | ||
<title>오늘의 시</title> | ||
|
||
<style> | ||
body { | ||
font-family: 'Noto Sans KR', sans-serif; | ||
} | ||
|
||
#firstPage { | ||
display: block; | ||
} | ||
|
||
#secondPage { | ||
display: none; | ||
} | ||
|
||
#thirdPage { | ||
display: none; | ||
} | ||
|
||
</style> | ||
</head> | ||
<body class="bg-gray-100"> | ||
<script src="http://ajax.googleapis.com/ajax/libs/jquery/1.10.2/jquery.min.js"></script> | ||
<script> | ||
$(window).load(function () { | ||
$('body').sakura(); | ||
}); | ||
</script> | ||
|
||
<div class="flex justify-center items-center min-h-screen"> | ||
<div class="container mx-auto p-4 max-w-2xl"> | ||
<!-- #1 감정 선택 페이지 --> | ||
<div id="firstPage" class="bg-white shadow-md rounded px-8 pt-6 pb-8 mb-4"> | ||
<h1 class="block text-gray-700 text-xl font-bold mb-2">오늘의 기분은?</h1> | ||
<p class="text-gray-700 mb-4">오늘 하루 어떠셨나요? 오늘의 감정을 선택해 보세요!</p> | ||
<form id="emotionOptions"> | ||
<div class="mb-4"> | ||
<label class="inline-flex items-center"> | ||
<input type="radio" class="form-radio" name="emotion" value="행복하다"> | ||
<span class="ml-2">행복하다</span> | ||
</label> | ||
<label class="inline-flex items-center ml-6"> | ||
<input type="radio" class="form-radio" name="emotion" value="슬프다"> | ||
<span class="ml-2">슬프다</span> | ||
</label> | ||
<label class="inline-flex items-center ml-6"> | ||
<input type="radio" class="form-radio" name="emotion" value="화나다"> | ||
<span class="ml-2">화나다</span> | ||
</label> | ||
<label class="inline-flex items-center ml-6"> | ||
<input type="radio" class="form-radio" name="emotion" value="지루하다"> | ||
<span class="ml-2">지루하다</span> | ||
</label> | ||
<label class="inline-flex items-center ml-6"> | ||
<input type="radio" class="form-radio" name="emotion" value="놀라다"> | ||
<span class="ml-2">놀라다</span> | ||
</label> | ||
</div> | ||
<button id="submitButton" type="button" class="bg-blue-500 hover:bg-blue-700 text-white font-bold py-2 px-4 rounded" onclick="submitEmotion()">선택완료</button> | ||
</form> | ||
</div> | ||
|
||
<!-- #2 구절 선택 페이지 --> | ||
<div id="secondPage" class="hidden bg-white shadow-md rounded px-8 pt-6 pb-8 mb-4"> | ||
<h1 class="block text-gray-700 text-xl font-bold mb-2">마음에 드는 구절은?</h1> | ||
<p class="text-gray-700 mb-4">마음에 드는 문장을 선택하면 해당 구절로 시를 생성해 드려요!</p> | ||
<div class="mb-4"> | ||
<div class="flex items-center mb-2"> | ||
<input id="line1" type="radio" name="line" value="line1" class="form-radio mr-2"> | ||
<label for="line1" class="text-gray-700">line</label> | ||
</div> | ||
<div class="flex items-center mb-2"> | ||
<input id="line2" type="radio" name="line" value="line2" class="form-radio mr-2"> | ||
<label for="line2" class="text-gray-700">line</label> | ||
</div> | ||
<div class="flex items-center mb-2"> | ||
<input id="line3" type="radio" name="line" value="line3" class="form-radio mr-2"> | ||
<label for="line3" class="text-gray-700">line</label> | ||
</div> | ||
</div> | ||
<div class="flex justify-between"> | ||
<button id="reButton" type="button" class="px-4 py-2 bg-gray-200 text-gray-700 rounded hover:bg-gray-300 transition duration-300" onclick="submitEmotion()">다시 생성</button> | ||
<div> | ||
<button id="submitButton" type="submit" class="px-4 py-2 bg-blue-500 text-white rounded hover:bg-blue-600 transition duration-300 mr-2" onclick="submitLine()">선택완료</button> | ||
<button type="button" class="px-4 py-2 bg-gray-200 text-gray-700 rounded hover:bg-gray-300 transition duration-300" onclick="back()">뒤로가기</button> | ||
</div> | ||
</div> | ||
</div> | ||
|
||
<!-- #3 시 생성 페이지 --> | ||
<div id="thirdPage" class="hidden bg-white shadow-md rounded px-8 pt-6 pb-8 mb-4"> | ||
<h1 class="block text-gray-700 text-xl font-bold mb-2">당신의 오늘과 어울리는 시</h1> | ||
<div class="flex mb-4"> | ||
<div class="w-1/2 mr-2"> | ||
<img src="https://via.placeholder.com/150" alt="이미지" class="rounded"> | ||
</div> | ||
<div class="w-1/2 ml-2"> | ||
<pre id=poemContent class="text-gray-700">생성된 시</pre> | ||
</div> | ||
</div> | ||
<div class="flex justify-center gap-4 mb-4"> | ||
<button id="saveButton" class="bg-blue-500 hover:bg-blue-700 text-white font-bold py-2 px-4 rounded"> | ||
이미지와 시 저장 | ||
</button> | ||
<button class="bg-gray-300 hover:bg-gray-400 text-gray-800 font-bold py-2 px-4 rounded" onclick="home()"> | ||
메인으로 돌아가기 | ||
</button> | ||
</div> | ||
</div> | ||
</div> | ||
</div> | ||
<script> | ||
function submitEmotion() { | ||
var emotion = $('input[name="emotion"]:checked').val(); // 선택된 라디오 버튼 값 가져오기 | ||
if(emotion) { // 선택된 감정이 있을 경우 | ||
generateLine(emotion); // generateLine 함수 호출 | ||
} else { | ||
// 감정이 선택되지 않았을 경우의 처리, 예: 알림 표시 | ||
alert("감정을 선택해주세요."); | ||
} | ||
} | ||
function generateLine(emotion) { | ||
$.ajax({ | ||
type: "POST", | ||
url: "/api/line", | ||
contentType: "application/json", // 요청의 Content-Type을 application/json으로 명시 | ||
data: JSON.stringify({ 'emotion': emotion }), // 객체를 JSON 문자열로 변환 | ||
success: function (response) { | ||
lines = response.lines | ||
showLine(lines) | ||
} | ||
}) | ||
} | ||
function showLine(lines) { | ||
// firstPage를 숨깁니다. | ||
document.getElementById('firstPage').style.display = 'none'; | ||
document.getElementById('secondPage').style.display = 'block'; | ||
|
||
for (let i = 0; i < lines.length; i++) { | ||
document.getElementById(`line${i+1}`).value = lines[i]; | ||
document.querySelector(`label[for=line${i+1}]`).textContent = lines[i]; | ||
} | ||
$('input[name="line"]').prop('checked', false); | ||
} | ||
function submitLine() { | ||
var line = $('input[name="line"]:checked').val(); // 선택된 라디오 버튼 값 가져오기 | ||
if(line) { // 선택된 구절이 있을 경우 | ||
generatePoem(line); // generatePoem 함수 호출 | ||
} else { | ||
// 구절이 선택되지 않았을 경우의 처리 | ||
alert("구절을 선택해주세요."); | ||
} | ||
} | ||
function generatePoem(line) { | ||
$.ajax({ | ||
type: "POST", | ||
url: "/api/poem", | ||
contentType: "application/json", // 요청의 Content-Type을 application/json으로 명시 | ||
data: JSON.stringify({ 'line': line }), // 객체를 JSON 문자열로 변환 | ||
success: function (response) { | ||
poem = response.poem | ||
showPoem(poem) | ||
} | ||
}) | ||
} | ||
function showPoem(poem) { | ||
document.getElementById('secondPage').style.display = 'none'; | ||
document.getElementById('thirdPage').style.display = 'block'; | ||
|
||
var poemContentElement = document.getElementById('poemContent'); | ||
poemContentElement.textContent = poem; | ||
} | ||
function back() { | ||
document.getElementById('firstPage').style.display = 'block'; | ||
document.getElementById('secondPage').style.display = 'none'; | ||
$('input[name="emotion"]').prop('checked', false); | ||
$('input[name="line"]').prop('checked', false); | ||
} | ||
function home() { | ||
document.getElementById('firstPage').style.display = 'block'; | ||
document.getElementById('secondPage').style.display = 'none'; | ||
document.getElementById('thirdPage').style.display = 'none'; | ||
$('input[name="emotion"]').prop('checked', false); | ||
$('input[name="line"]').prop('checked', false); | ||
} | ||
</script> | ||
<script src="./static/js/jquery_sakura.js"></script> | ||
</body> | ||
</html> |
Oops, something went wrong.