# localGPT_UI.py
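"""Streamlit UI for localGPT: chat with documents that have been ingested into a
local Chroma vector store, answered by a locally hosted LLM. Typically launched
with `streamlit run localGPT_UI.py`, after building the index with `python ingest.py`
(see the commented-out ingestion block below)."""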
import torch
import subprocess
import streamlit as st
from run_localGPT import load_model
from langchain.vectorstores import Chroma
from constants import CHROMA_SETTINGS, EMBEDDING_MODEL_NAME, PERSIST_DIRECTORY, MODEL_ID, MODEL_BASENAME
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.chains import RetrievalQA
from streamlit_extras.add_vertical_space import add_vertical_space
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferMemory
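

# Build the prompt template and conversation memory used by the QA chain below:
# {context} receives the retrieved document chunks, {history} the running chat
# history tracked by ConversationBufferMemory, and {question} the user's query.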
def model_memory():
    # Add chat history to the model prompt.
    template = """Use the following pieces of context to answer the question at the end. If you don't know the answer,\
just say that you don't know, don't try to make up an answer.

    {context}

    {history}
    Question: {question}
    Helpful Answer:"""

    prompt = PromptTemplate(input_variables=["history", "context", "question"], template=template)
    memory = ConversationBufferMemory(input_key="question", memory_key="history")

    return prompt, memory


# Sidebar contents
with st.sidebar:
    st.title("🤗💬 Converse with your Data")
    st.markdown(
        """
    ## About
    This app is an LLM-powered chatbot built using:
    - [Streamlit](https://streamlit.io/)
    - [LangChain](https://python.langchain.com/)
    - [LocalGPT](https://github.com/PromtEngineer/localGPT)
    """
    )
    add_vertical_space(5)
    st.write("Made with ❤️ by [Prompt Engineer](https://youtube.com/@engineerprompt)")
if torch.backends.mps.is_available():
    DEVICE_TYPE = "mps"
elif torch.cuda.is_available():
    DEVICE_TYPE = "cuda"
else:
    DEVICE_TYPE = "cpu"
# if "result" not in st.session_state:
# # Run the document ingestion process.
# run_langest_commands = ["python", "ingest.py"]
# run_langest_commands.append("--device_type")
# run_langest_commands.append(DEVICE_TYPE)
# result = subprocess.run(run_langest_commands, capture_output=True)
# st.session_state.result = result
# Define the retreiver
# load the vectorstore
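# Streamlit reruns this whole script on every user interaction, so the heavyweight
# objects (embeddings, vector store, retriever, LLM, QA chain) are cached in
# st.session_state and only constructed on the first run.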
if "EMBEDDINGS" not in st.session_state:
EMBEDDINGS = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": DEVICE_TYPE})
st.session_state.EMBEDDINGS = EMBEDDINGS
if "DB" not in st.session_state:
DB = Chroma(
persist_directory=PERSIST_DIRECTORY,
embedding_function=st.session_state.EMBEDDINGS,
client_settings=CHROMA_SETTINGS,
)
st.session_state.DB = DB
if "RETRIEVER" not in st.session_state:
RETRIEVER = DB.as_retriever()
st.session_state.RETRIEVER = RETRIEVER
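
# Load the local LLM; MODEL_ID and MODEL_BASENAME are configured in constants.py.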
if "LLM" not in st.session_state:
LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME)
st.session_state["LLM"] = LLM
if "QA" not in st.session_state:
prompt, memory = model_memory()
QA = RetrievalQA.from_chain_type(
llm=LLM,
chain_type="stuff",
retriever=RETRIEVER,
return_source_documents=True,
chain_type_kwargs={"prompt": prompt, "memory": memory},
)
st.session_state["QA"] = QA
st.title("LocalGPT App 💬")
# Create a text input box for the user
prompt = st.text_input("Input your prompt here")
# If the user hits enter
if prompt:
    # Then pass the prompt to the LLM
    response = st.session_state["QA"](prompt)
    answer, docs = response["result"], response["source_documents"]
    # ...and write it out to the screen
    st.write(answer)

    # With a streamlit expander
    with st.expander("Document Similarity Search"):
        # Find the relevant pages
        search = st.session_state.DB.similarity_search_with_score(prompt)
        # Write out each source document; similarity_search_with_score returns (Document, score) tuples.
        for i, doc in enumerate(search):
            st.write(f"Source Document # {i + 1} : {doc[0].metadata['source'].split('/')[-1]}")
            st.write(doc[0].page_content)
            st.write("--------------------------------")