-
Notifications
You must be signed in to change notification settings - Fork 2
/
app.py
153 lines (142 loc) · 5.55 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import contextlib
import json
import streamlit as st
from streamlit_mic_recorder import mic_recorder
from speak import database
from speak.chat_engine import chat
from speak.speech_to_text import transcribe
from speak.spell_check import grammar_coherence_correction
from speak.text_to_speech import generate_audio
def answers(audio_bytes, chat_identifier):
"""
Answers the user's audio input by transcribing it, correcting grammar coherence, generating a response,
and saving the messages in the database.
Args:
audio_bytes (bytes): The audio input from the user.
chat_identifier: The identifier of the chat.
Returns:
None
"""
# save the audio on tmp_file.wave
with open("tmp_file.wav", "wb") as file:
file.write(audio_bytes)
audio_transcribe = transcribe()
message_corrected = grammar_coherence_correction(audio_transcribe)
database.insert_message(
chat_id=chat_id, role="user", content=message_corrected, audio=audio_bytes,
)
all_messages = database.get_messages_by_chat_id(chat_identifier)
# inverse messages list
all_messages = all_messages[::-1]
clean_messages = []
for msg in all_messages:
msg_content = json.loads(msg[3])
if msg[2] in ["system", "assistant"]:
clean_messages.append(msg_content)
elif msg[2] == "user":
clean_messages.append({"role": "user", "content": msg_content["rewritten"]})
response = chat(clean_messages)
response_audio = generate_audio(response["content"])
database.insert_message(
chat_id=chat_identifier,
role="assistant",
content=response,
audio=response_audio,
)
st.set_page_config(
page_title="Speak", page_icon="🧊", layout="wide",
)
st.title("Speak: A chatbot for language learning")
col1, col2 = st.columns([1, 4])
chat_id = None
with col1:
with st.status("**Select a chat**"):
chats = database.get_all_chats()
names = [chat[1] for chat in chats]
if names:
selected_chat = st.selectbox("Select chat", names)
chat_id = chats[names.index(selected_chat)][0]
with st.status("**Create a new chat**"):
name = st.text_input("Chat name", "")
prompt = st.text_area("Chat Prompt", "", help="Write a prompt for the chatbot")
if st.button("Create"):
try:
if name in names:
st.error("Chat already exists")
else:
chat_id = database.insert_chat(name)
content = {
"role": "system",
"content": prompt,
}
database.insert_message(
chat_id=chat_id, role="system", content=content, audio="NULL"
)
st.success("Chat created")
except NameError:
chat_id = database.insert_chat(name)
content = {
"role": "system",
"content": prompt,
}
database.insert_message(
chat_id=chat_id, role="system", content=content, audio="NULL"
)
st.success("Chat created")
with st.status(" **Assistant Description**"):
with contextlib.suppress(Exception):
messages = database.get_messages_by_chat_id(chat_id)
for message in messages:
if message[2] == "system":
st.write(json.loads(message[3])["content"])
with st.status(":red[**Delete**]"):
# Delete all messages
if st.button("Delete all messages", type="primary"):
database.delete_messages_by_chat_id(chat_id)
st.success("All messages deleted")
# Delete chats
if st.button("Delete chat", type="primary"):
database.delete_chat(chat_id)
st.success("Chat deleted")
with col2:
if chat_id:
if selected_chat:
st.write(f" **{selected_chat}**")
st.write("Record your voice:")
c1, c2 = st.columns([9, 10])
with c1:
if audio := mic_recorder(
start_prompt="⏺️ Record", stop_prompt="⏹️Stop", key="recorder"
):
st.audio(audio["bytes"])
with c2:
if st.button("Send"):
answers(audio["bytes"], chat_id)
messages = database.get_messages_by_chat_id(chat_id)
for message in messages:
role = message[2]
content = json.loads(message[3])
audio = message[4]
if role == "system":
continue
st.write(f"{role.capitalize()}: ")
st.audio(audio)
if role == "user":
score_color = (
":green[Correction]"
if content["score"] >= 80
else ":red[Correction]"
)
with st.status(score_color):
st.write(
f"Score: {score_color.replace('Correction', str(content['score']))}"
)
st.write(f"Original: {content['original']}")
st.write(f"Suggestion: {content['rewritten']}")
st.write(f"Grammar corrected: {content['grammar_corrected']}")
st.write(f"Coherence corrected: {content['coherence_corrected']}")
else:
with st.status("Transcription"):
st.write(content["content"])
else:
st.write("Select a chat to continue")