-
Notifications
You must be signed in to change notification settings - Fork 1
/
joonhai_rag
119 lines (96 loc) · 4.62 KB
/
joonhai_rag
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# utils/rag.py
import sys
import os
from langchain_community.document_loaders import PyPDFLoader
from sentence_transformers import SentenceTransformer, util
from rank_bm25 import BM25Okapi
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
import gradio as gr
def load_documents(pdf_path):
    """Load a PDF file and return the text content of each loaded document.

    Args:
        pdf_path: Filesystem path of the PDF to read.

    Returns:
        A list of strings: the ``page_content`` of every Document returned by
        ``PyPDFLoader(pdf_path).load()``, in the loader's order.

    Raises:
        ValueError: If ``pdf_path`` does not exist on disk.
    """
    if os.path.exists(pdf_path):
        loaded_docs = PyPDFLoader(pdf_path).load()
        return [page.page_content for page in loaded_docs]
    raise ValueError(f"File path {pdf_path} is not valid. Please check if the file exists.")
def load_qa_data(qa_path, an_path):
    """Load question and answer lines from two UTF-8 text files.

    Every line of each file (surrounding whitespace stripped, blank lines
    kept as empty strings) becomes one list entry. No length check is done;
    callers pair the two lists by index.

    Args:
        qa_path: Path of the questions file.
        an_path: Path of the answers file.

    Returns:
        A ``(questions, answers)`` tuple of lists of strings.
    """
    def _read_stripped_lines(path):
        # One entry per line, whitespace (including the newline) stripped.
        with open(path, 'r', encoding='utf-8') as handle:
            return [raw_line.strip() for raw_line in handle]

    # Tuple elements evaluate left to right, so the questions file is read first.
    return _read_stripped_lines(qa_path), _read_stripped_lines(an_path)
def setup_bm25(documents):
    """Build a BM25Okapi keyword index over ``documents``.

    Documents are tokenized with ``str.split()`` (no argument), i.e. on any
    run of whitespace. Splitting on a single space instead would glue words
    together across newlines/tabs and would emit empty-string tokens wherever
    spaces repeat, both of which degrade keyword matching.

    Args:
        documents: Non-empty sequence of document texts.

    Returns:
        A ``BM25Okapi`` index whose document order matches ``documents``.

    Raises:
        ValueError: If ``documents`` is empty. Without this check, rank_bm25
            would fail with a less helpful error when given zero documents.
    """
    if len(documents) == 0:
        raise ValueError("Cannot build a BM25 index from an empty document list.")
    tokenized_docs = [doc.split() for doc in documents]
    return BM25Okapi(tokenized_docs)
def setup_embedding_model(model_name, documents):
    """Load a SentenceTransformer model and embed every document with it.

    Args:
        model_name: Model name or path passed to ``SentenceTransformer``.
        documents: Texts to embed.

    Returns:
        ``(model, embeddings)``, where ``embeddings`` is the result of
        ``model.encode(documents, convert_to_tensor=True)``.
    """
    model = SentenceTransformer(model_name)
    return model, model.encode(documents, convert_to_tensor=True)
def _min_max_normalize(scores):
    """Linearly rescale ``scores`` to [0, 1]; a constant vector maps to all zeros."""
    scores = np.asarray(scores, dtype=float)
    if scores.size == 0:
        return scores
    low = scores.min()
    span = scores.max() - low
    if span == 0:
        # No spread (e.g. no query term occurs in any document): this signal
        # cannot discriminate, so it contributes nothing to the ranking.
        return np.zeros_like(scores)
    return (scores - low) / span


def hybrid_search(query, documents, bm25, embedding_model, document_embeddings, bm25_weight=0.3, embedding_weight=0.7):
    """Return the single best-matching document for ``query``.

    Combines a BM25 keyword score with the cosine similarity between the
    query embedding and each document embedding.

    BM25 scores are unbounded, while cosine similarity lies in [-1, 1].
    Adding the raw values lets the BM25 term dominate no matter what weights
    are given. Both score vectors are therefore min-max normalised to [0, 1]
    before the weighted sum, so that ``bm25_weight`` and ``embedding_weight``
    express the intended relative importance.

    Args:
        query: Free-text query.
        documents: Candidate texts, in the same order used to build ``bm25``
            and ``document_embeddings``.
        bm25: Keyword index exposing ``get_scores(tokens)`` (e.g. ``BM25Okapi``).
        embedding_model: Model exposing ``encode(text, convert_to_tensor=True)``.
        document_embeddings: Document embeddings produced by the same model.
        bm25_weight: Weight of the normalised BM25 score.
        embedding_weight: Weight of the normalised cosine-similarity score.

    Returns:
        The element of ``documents`` with the highest combined score.
    """
    # Keyword search. Tokenize on any whitespace (spaces, tabs, newlines) so
    # repeated spaces do not produce empty-string tokens.
    bm25_scores = np.asarray(bm25.get_scores(query.split()), dtype=float)

    # Semantic search. ``[0]`` selects the single query row of the
    # similarity matrix.
    query_embedding = embedding_model.encode(query, convert_to_tensor=True)
    cosine_scores = util.pytorch_cos_sim(query_embedding, document_embeddings)[0].cpu().numpy()

    # Ensemble score: weighted sum of the two signals on a common [0, 1] scale.
    combined_scores = (
        bm25_weight * _min_max_normalize(bm25_scores)
        + embedding_weight * _min_max_normalize(cosine_scores)
    )

    # Index of the highest-scoring document.
    top_index = int(np.argmax(combined_scores))
    return documents[top_index]
def evaluate_performance(questions, answers, documents, bm25, embedding_model, document_embeddings, bm25_weight=0.3, embedding_weight=0.7):
    """Score the retriever by exact match between retrieved text and reference answers.

    For each question, ``hybrid_search`` retrieves one document. That document
    text is compared with the answer at the same index.

    NOTE(review): the prediction is the *entire retrieved document text*, so
    a question only counts as correct when its reference answer is identical
    to the whole retrieved document. Confirm that this is the intended metric
    (as opposed to, e.g., "answer is contained in the retrieved document").

    Args:
        questions: Query strings.
        answers: Reference answers, paired with ``questions`` by index.
        documents, bm25, embedding_model, document_embeddings: Search state
            forwarded unchanged to ``hybrid_search``.
        bm25_weight, embedding_weight: Forwarded to ``hybrid_search``. The
            defaults equal ``hybrid_search``'s own defaults.

    Returns:
        ``(accuracy, f1)``. ``f1`` is micro-averaged over the string labels.
        For single-label multiclass data, micro-F1 equals accuracy, so the
        two numbers coincide.

    Raises:
        ValueError: If ``questions`` and ``answers`` differ in length. This is
            checked before any (expensive) search is run.
    """
    if len(questions) != len(answers):
        raise ValueError(
            f"Number of questions ({len(questions)}) does not match "
            f"number of answers ({len(answers)})."
        )

    predictions = [
        hybrid_search(
            question,
            documents,
            bm25,
            embedding_model,
            document_embeddings,
            bm25_weight=bm25_weight,
            embedding_weight=embedding_weight,
        )
        for question in questions
    ]

    # Accuracy and micro-F1 over exact string matches.
    accuracy = accuracy_score(answers, predictions)
    f1 = f1_score(answers, predictions, average='micro')
    return accuracy, f1
# Callback functions used by the Gradio interface.
def gradio_search_interface(query):
    """Gradio callback: run ``hybrid_search`` and return a preview of the best hit.

    Relies on the module-level ``documents``, ``bm25``, ``embedding_model``
    and ``document_embeddings`` that are assigned in the ``__main__`` block.
    A result longer than 500 characters is cut to its first 500 characters,
    followed by "...".
    """
    best_match = hybrid_search(query, documents, bm25, embedding_model, document_embeddings)
    if len(best_match) <= 500:
        return best_match
    return best_match[:500] + "..."
def gradio_evaluation_interface():
    """Gradio callback: evaluate the retriever and return a formatted summary.

    Relies on the module-level ``questions``, ``answers`` and search state
    that are assigned in the ``__main__`` block. Both scores are formatted
    with two decimals.
    """
    scores = evaluate_performance(
        questions,
        answers,
        documents,
        bm25,
        embedding_model,
        document_embeddings,
    )
    accuracy, f1 = scores
    return f"정확도: {accuracy:.2f}, F1 점수: {f1:.2f}"
if __name__ == "__main__":
    # Input files. The defaults are the original absolute paths. Each path can
    # be overridden with an environment variable so the script also runs on
    # other machines.
    pdf_path = os.environ.get("RAG_PDF_PATH", '/home/joonhai/rag_lecture/rag_model/insurance.pdf')
    qa_path = os.environ.get("RAG_QA_PATH", '/home/joonhai/rag_lecture/rag_model/qa.txt')
    an_path = os.environ.get("RAG_AN_PATH", '/home/joonhai/rag_lecture/rag_model/an.txt')

    try:
        # Load the documents and initialise both retrievers. These names are
        # module-level globals read by the Gradio callbacks.
        documents = load_documents(pdf_path)
        print("문서가 성공적으로 로드되었습니다.")
        bm25 = setup_bm25(documents)
        embedding_model, document_embeddings = setup_embedding_model('BAAI/bge-m3', documents)
        # Load the evaluation question/answer pairs.
        questions, answers = load_qa_data(qa_path, an_path)
    except (ValueError, OSError) as e:
        # Report the failure and exit non-zero: a missing PDF raises ValueError
        # (load_documents), and missing/unreadable Q&A files raise OSError
        # (open). Previously the script exited with status 0 after printing,
        # and OSError escaped as a raw traceback.
        print(f"Error: {e}")
        sys.exit(1)

    # Question-answering tab.
    interface = gr.Interface(
        fn=gradio_search_interface,
        inputs="text",
        outputs="text",
        title="질의응답 시스템",
        description="질문을 입력하면 문서에서 답변을 제공합니다."
    )
    # Evaluation tab: takes no inputs, just reports the scores.
    evaluation_interface = gr.Interface(
        fn=gradio_evaluation_interface,
        inputs=[],
        outputs="text",
        title="성능 평가 결과",
        description="성능 평가 점수를 확인합니다."
    )
    # Combine both interfaces into a single tabbed Gradio app.
    combined_interface = gr.TabbedInterface([interface, evaluation_interface], ["질의응답", "성능 평가"])
    # NOTE(review): per the Gradio docs, share=True creates a publicly
    # reachable share link, so anyone with the URL can query the documents.
    # Confirm this exposure is intended.
    combined_interface.launch(share=True)