-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
121 lines (103 loc) · 3.76 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
from urllib.parse import urlparse
import secrets
import voyageai
from flask import Flask, request
from dotenv import load_dotenv
from langchain_text_splitters import RecursiveCharacterTextSplitter
from pinecone import Pinecone
from mistralai.client import MistralClient
from dbbuilder import build_vector_database
from rag import WordPressRAG
# Load variables from a .env file first: every API key below is read from
# os.environ at import time, so this call must stay ahead of them.
load_dotenv()

app = Flask(__name__)
# A new random key is generated each time the module is imported, so the
# key does not survive a process restart.
app.secret_key = secrets.token_hex()

# Shared service clients, built once at import time and reused by every request.
_pinecone_client = Pinecone(os.environ["PINECONE_API_KEY"])
DATABASE = _pinecone_client.Index("wordpress-chatbot")
EMBEDDER = voyageai.Client(os.environ["VOYAGE_API_KEY"])
LLM_CLIENT = MistralClient(os.environ["MISTRAL_API_KEY"])

# RAG chatbot wired to the LLM client, the model name, the embedder and the index.
CHATBOT = WordPressRAG(LLM_CLIENT, "open-mistral-7b", EMBEDDER, DATABASE)
def authorize():
    """
    Check the token in the request's Authorization header against AUTH_KEY.

    Returns:
        A ``(body, status)`` tuple: status 200 when the presented token equals
        the ``AUTH_KEY`` environment variable, otherwise status 401.

    Raises:
        KeyError: if a token was presented but ``AUTH_KEY`` is not set.
    """
    unauthorized = {"message": "Wrong or no authentication key present."}, 401

    auth = request.authorization
    # No header at all, or a scheme that carries no token: reject.
    if auth is None or auth.token is None:
        return unauthorized

    expected_key = os.environ["AUTH_KEY"]
    # Constant-time comparison so response timing does not reveal how much of
    # the key matched. Compare bytes: compare_digest rejects non-ASCII str.
    if not secrets.compare_digest(
        auth.token.encode("utf-8"), expected_key.encode("utf-8")
    ):
        return unauthorized

    return {"message": "Authorization successful."}, 200
@app.route("/db", methods=["POST", "DELETE"])
def db_ops():
"""
Request Body:
{
"site_url": "https://www.example.com",
"create_if_not_present": true
}
Optional:
"create_if_not_present": false (Default)
"""
auth_res = authorize()
if auth_res[1] == 401:
return auth_res
if "site_url" not in request.json:
return {"message": "'site_url' is not present in the request body."}, 400
site_domain = urlparse(request.json["site_url"]).hostname
if not site_domain:
return {"message": "'site_url' is not a url."}, 400
match request.method:
case "POST":
EMBEDDING_SIZE = 1024
match = DATABASE.query(
namespace=site_domain,
vector=[0]*EMBEDDING_SIZE,
top_k=1,
)["matches"]
if len(match) == 1:
return {
"message": f"Database already present for '{site_domain}'",
"database_present": True,
"database_created": False,
}
elif len(match) > 1:
raise AssertionError("Only one match should've been returned.")
if not request.json.get("create_if_not_present", False):
return {
"message": f"Database not present for '{site_domain}'",
"database_present": False,
"database_created": False,
}
build_vector_database(
request.json["site_url"],
RecursiveCharacterTextSplitter.from_tiktoken_encoder(
encoding_name="cl100k_base", chunk_size=200, chunk_overlap=40
),
EMBEDDER,
DATABASE,
)
return {
"message": f"Database created for '{site_domain}'",
"database_present": True,
"database_created": True,
}
case "DELETE":
DATABASE.delete(delete_all=True, namespace=site_domain)
return {"message": f"Database deleted for '{site_domain}'"}
@app.route("/chat", methods=["POST"])
def chat():
    """
    Generate a chatbot reply for a conversation.

    Request Body:
    {
        "site_url": "https://www.example.com",
        "messages": [
            {"role": "assistant", "content": "message1"},
            {"role": "system", "content": "message2"},
            {"role": "user", "content": "message3"}
        ]
    }
    Optional:
    "site_url": If not present, RAG is not performed.

    Returns:
        401 response from authorize() when authorization fails.
        400 when the body is not a JSON object, lacks "messages", or
        "site_url" is a non-string or unparseable value.
        Otherwise whatever CHATBOT.generate returns.
    """
    auth_res = authorize()
    if auth_res[1] == 401:
        return auth_res

    body = request.json
    # Valid JSON can still be null, a list or a scalar; reject those, and a
    # missing "messages" key, with a 400 instead of an unhandled exception.
    if not isinstance(body, dict) or "messages" not in body:
        return {"message": "'messages' is not present in the request body."}, 400

    site_url = body.get("site_url")
    if not site_url:
        # Absent or empty "site_url": no site hostname is passed on.
        site_domain = None
    elif not isinstance(site_url, str):
        return {"message": "'site_url' is not a url."}, 400
    else:
        try:
            # hostname is None when the string has no network location; that
            # is passed through unchanged, the same as an absent "site_url".
            site_domain = urlparse(site_url).hostname
        except ValueError:
            # Raised by urlparse for malformed bracketed hosts.
            return {"message": "'site_url' is not a url."}, 400

    return CHATBOT.generate(site_domain, body["messages"])