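"""
app.py

FastAPI service that turns an uploaded PDF into an exam-prep Q&A CSV:
a LangChain "refine" summarize chain generates questions from the document,
and a RetrievalQA chain over a FAISS index of the same document answers them,
both backed by a local Mistral 7B GGUF model loaded via CTransformers.
"""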
from fastapi import FastAPI, Form, Request, Response, File
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from langchain.llms import CTransformers
from langchain.chains import RetrievalQA
from langchain.chains.summarize import load_summarize_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document
from langchain.document_loaders import PyPDFLoader
from langchain.prompts import PromptTemplate
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.vectorstores import FAISS
import os
import json
import csv
import uvicorn
import aiofiles

app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
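# Note: Jinja2Templates and StaticFiles assume ./templates and ./static exist
# in the working directory; StaticFiles checks this at startup and raises if
# the directory is missing.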

def load_llm():
    # Load the locally downloaded Mistral GGUF model via ctransformers
    llm = CTransformers(
        model="mistral-7b-instruct-v0.1.Q4_K_S.gguf",
        model_type="mistral",
        max_new_tokens=1048,
        temperature=0.3,
    )
    return llm
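
# The model above is referenced by bare filename, so the GGUF weights are
# presumably expected in the server's working directory; adjust the path if
# they live elsewhere.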

def file_processing(file_path):
    # Load data from PDF and concatenate every page into one string
    loader = PyPDFLoader(file_path)
    data = loader.load()

    question_gen = ''
    for page in data:
        question_gen += page.page_content

    # Coarse chunks for question generation
    splitter_ques_gen = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=100,
    )
    chunks_ques_gen = splitter_ques_gen.split_text(question_gen)
    document_ques_gen = [Document(page_content=t) for t in chunks_ques_gen]

    # Finer chunks for answer retrieval
    splitter_ans_gen = RecursiveCharacterTextSplitter(
        chunk_size=300,
        chunk_overlap=30,
    )
    document_answer_gen = splitter_ans_gen.split_documents(document_ques_gen)

    return document_ques_gen, document_answer_gen
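
# The two chunk sizes above appear deliberate: the ~1000-character chunks give
# the question generator enough context per pass, while the ~300-character
# chunks make FAISS retrieval for answer generation more precise.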

def llm_pipeline(file_path):
    document_ques_gen, document_answer_gen = file_processing(file_path)

    llm_ques_gen_pipeline = load_llm()

    prompt_template = """
    You are an expert at creating questions based on coding materials and documentation.
    Your goal is to prepare a coder or programmer for their exam and coding tests.
    You do this by asking questions about the text below:

    ------------
    {text}
    ------------

    Create questions that will prepare the coders or programmers for their tests.
    Make sure not to lose any important information.

    QUESTIONS:
    """

    PROMPT_QUESTIONS = PromptTemplate(template=prompt_template, input_variables=["text"])

    refine_template = """
    You are an expert at creating practice questions based on coding material and documentation.
    Your goal is to help a coder or programmer prepare for a coding test.
    We have already generated some practice questions: {existing_answer}.
    We now have the option to refine those questions or add new ones
    (only if necessary) using some more context below.

    ------------
    {text}
    ------------

    Given the new context, refine the original questions in English.
    If the context is not helpful, keep the original questions.

    QUESTIONS:
    """

    REFINE_PROMPT_QUESTIONS = PromptTemplate(
        input_variables=["existing_answer", "text"],
        template=refine_template,
    )

    # "refine" walks the chunks in order, updating the question list each pass
    ques_gen_chain = load_summarize_chain(
        llm=llm_ques_gen_pipeline,
        chain_type="refine",
        verbose=True,
        question_prompt=PROMPT_QUESTIONS,
        refine_prompt=REFINE_PROMPT_QUESTIONS,
    )
    ques = ques_gen_chain.run(document_ques_gen)

    # Build the retrieval index for answering. Note: HuggingFaceBgeEmbeddings
    # is intended for BGE models and prepends a BGE-style query instruction;
    # HuggingFaceEmbeddings is the usual wrapper for all-mpnet-base-v2.
    embeddings = HuggingFaceBgeEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
    vector_store = FAISS.from_documents(document_answer_gen, embeddings)

    llm_answer_gen = load_llm()

    # Keep only lines that end like complete questions or sentences
    ques_list = ques.split("\n")
    filtered_ques_list = [
        element for element in ques_list
        if element.endswith('?') or element.endswith('.')
    ]

    answer_generation_chain = RetrievalQA.from_chain_type(
        llm=llm_answer_gen,
        chain_type="stuff",
        retriever=vector_store.as_retriever(),
    )

    return answer_generation_chain, filtered_ques_list
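
# A minimal standalone sketch of the pipeline (hypothetical path; run from a
# REPL or script rather than through the web routes):
#
#   chain, questions = llm_pipeline("static/docs/some_doc.pdf")
#   for q in questions:
#       print(q, "->", chain.run(q))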

def get_csv(file_path):
    answer_generation_chain, ques_list = llm_pipeline(file_path)

    base_folder = 'static/output/'
    os.makedirs(base_folder, exist_ok=True)
    output_file = os.path.join(base_folder, "QA.csv")

    with open(output_file, "w", newline="", encoding="utf-8") as csvfile:
        csv_writer = csv.writer(csvfile)
        csv_writer.writerow(["Question", "Answer"])  # header row

        for question in ques_list:
            print("Question: ", question)
            answer = answer_generation_chain.run(question)
            print("Answer: ", answer)
            print("--------------------------------------------------\n\n")

            # Save the question/answer pair to the CSV file
            csv_writer.writerow([question, answer])

    return output_file
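
# Each call rewrites the single static/output/QA.csv, so a new upload
# overwrites earlier results and concurrent /analyze requests would race on
# the same file.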
@app.get("/")
async def index(request: Request):
return templates.TemplateResponse("index.html", {"request": request})
@app.post("/upload")
async def chat(request: Request, pdf_file: bytes = File(), filename: str = Form(...)):
base_folder = 'static/docs/'
if not os.path.isdir(base_folder):
os.mkdir(base_folder)
pdf_filename = os.path.join(base_folder, filename)
async with aiofiles.open(pdf_filename, 'wb') as f:
await f.write(pdf_file)
response_data = jsonable_encoder(json.dumps({"msg": 'success',"pdf_filename": pdf_filename}))
res = Response(response_data)
return res
@app.post("/analyze")
async def chat(request: Request, pdf_filename: str = Form(...)):
output_file = get_csv(pdf_filename)
response_data = jsonable_encoder(json.dumps({"output_file": output_file}))
res = Response(response_data)
return res

if __name__ == "__main__":
    uvicorn.run("app:app", host='0.0.0.0', port=8000, reload=True)
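
# Example client flow (a sketch, assuming the server is on localhost:8000 and
# "doc.pdf" is a hypothetical file; field names match the Form/File params):
#
#   curl -F "pdf_file=@doc.pdf" -F "filename=doc.pdf" \
#        http://localhost:8000/upload
#   curl -F "pdf_filename=static/docs/doc.pdf" \
#        http://localhost:8000/analyze
#
# The second call returns {"output_file": "static/output/QA.csv"}, which is
# downloadable through the /static mount at /static/output/QA.csv.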