Skip to content

Commit

Permalink
feat(ui): refacto playground UI (#115)
Browse files Browse the repository at this point in the history
* feat(playground): change ui + cache optimization

* feat(playground): update pyproject.toml

* feat(playground): update style for dsfr

* feat(playground): update style for dsfr 2

* feat: cleaning and add reranker

---------

Co-authored-by: camilleAND <[email protected]>
Co-authored-by: leoguillaume <[email protected]>
  • Loading branch information
3 people authored Dec 23, 2024
1 parent 741b0cd commit ee37cf1
Show file tree
Hide file tree
Showing 9 changed files with 419 additions and 296 deletions.
6 changes: 1 addition & 5 deletions compose.yml
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
services:
fastapi:
build:
context: .
dockerfile: ./app/Dockerfile

image: ghcr.io/etalab-ia/albert-api/app:latest
command: uvicorn app.main:app --host 0.0.0.0 --port 8000
environment:
Expand All @@ -21,7 +17,7 @@ services:

streamlit:
image: ghcr.io/etalab-ia/albert-api/ui:latest
command: streamlit run /ui/chat.py --server.port=8501 --browser.gatherUsageStats false --theme.base light --server.maxUploadSize=20
command: streamlit run ui/main.py --server.port=8501 --browser.gatherUsageStats false --theme.base=light --theme.primaryColor=#6a6af4 --server.maxUploadSize=20
restart: always
environment:
- BASE_URL=http://fastapi:8000/v1
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dependencies = [

[project.optional-dependencies]
ui = [
"streamlit==1.39.0",
"streamlit==1.40.2",
"streamlit-extras==0.5.0",
]
app = [
Expand Down
119 changes: 0 additions & 119 deletions ui/chat.py

This file was deleted.

1 change: 1 addition & 0 deletions ui/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
EMBEDDINGS_MODEL_TYPE = "text-embeddings-inference"
LANGUAGE_MODEL_TYPE = "text-generation"
AUDIO_MODEL_TYPE = "automatic-speech-recognition"
RERANK_MODEL_TYPE = "text-classification"
INTERNET_COLLECTION_DISPLAY_ID = "internet"
PRIVATE_COLLECTION_TYPE = "private"
SUPPORTED_LANGUAGES = [
Expand Down
30 changes: 30 additions & 0 deletions ui/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import streamlit as st

from config import BASE_URL

# Entry point of the Streamlit playground: configure the page, display the
# logo, then route to the selected page and run it.
# NOTE(review): st.set_page_config is kept as the first Streamlit call of the
# script — Streamlit requires this per its documentation.
st.set_page_config(
    page_title="Albert playground",
    page_icon="https://www.systeme-de-design.gouv.fr/uploads/apple_touch_icon_8ffa1fa80c.png",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        "Get Help": "mailto:[email protected]",
        "Report a bug": "https://github.com/etalab-ia/albert-api/issues",
        "About": "https://github.com/etalab-ia/albert-api",
    },
)

# Marianne logo shown above the sidebar; clicking it links back to the
# playground root (BASE_URL with the /v1 API suffix swapped for /playground).
st.logo(
    image="https://upload.wikimedia.org/wikipedia/fr/thumb/5/50/Bloc_Marianne.svg/1200px-Bloc_Marianne.svg.png",
    link=BASE_URL.replace("/v1", "/playground"),
    size="large",
)

# Multi-page navigation: one entry per playground feature. Each page lives in
# ui/pages/ and is executed by pg.run() on every rerun.
pg = st.navigation(
    pages=[
        st.Page(page="pages/chat.py", title="Chat", icon=":material/chat:"),
        st.Page(page="pages/documents.py", title="Documents", icon=":material/file_copy:"),
        st.Page(page="pages/transcription.py", title="Transcription", icon=":material/graphic_eq:"),
    ]
)
pg.run()
179 changes: 179 additions & 0 deletions ui/pages/chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
import logging
import traceback


import streamlit as st

from config import INTERNET_COLLECTION_DISPLAY_ID
from utils import generate_stream, get_collections, get_models, header

# Authenticate the user and render the shared page header; returns the API key
# used for every backend call below.
API_KEY = header()

# Data: fetch the models available to this key and the user's collections.
# Any failure here is fatal for the page, so report, log and stop the script.
try:
    language_models, embeddings_models, _, rerank_models = get_models(api_key=API_KEY)
    collections = get_collections(api_key=API_KEY)
except Exception:
    # Fixed user-facing message ("Error to fetch user data." was ungrammatical).
    st.error("Failed to fetch user data.")
    logging.error(traceback.format_exc())
    st.stop()

# State: initialize session state on first run only, so values survive reruns.

if "selected_model" not in st.session_state:
    # Default to the first available language model.
    st.session_state["selected_model"] = language_models[0]

if "selected_collections" not in st.session_state:
    st.session_state.selected_collections = []

if "messages" not in st.session_state:
    # Invariant: "messages" and "sources" are kept in lockstep — one sources
    # list (possibly empty) per message, appended together further down.
    st.session_state["messages"] = []
    st.session_state["sources"] = []

# Sidebar
with st.sidebar:
new_chat = st.button(label="**:material/refresh: New chat**", key="new", use_container_width=True)
if new_chat:
st.session_state.pop("messages", None)
st.session_state.pop("sources", None)
st.rerun()
params = {"sampling_params": dict(), "rag": dict()}

st.subheader(body="Chat parameters")
st.session_state["selected_model"] = st.selectbox(
label="Language model", options=language_models, index=language_models.index(st.session_state.selected_model)
)

params["sampling_params"]["model"] = st.session_state["selected_model"]
params["sampling_params"]["temperature"] = st.slider(label="Temperature", value=0.2, min_value=0.0, max_value=1.0, step=0.1)

if st.toggle(label="Max tokens", value=False):
max_tokens = st.number_input(label="Max tokens", value=100, min_value=0, step=100)
params["sampling_params"]["max_tokens"] = max_tokens

st.subheader(body="RAG parameters")
params["rag"]["embeddings_model"] = st.selectbox(label="Embeddings model", options=embeddings_models)
model_collections = [
f"{collection["name"]} - {collection["id"]}" for collection in collections if collection["model"] == params["rag"]["embeddings_model"]
] + [f"Internet - {INTERNET_COLLECTION_DISPLAY_ID}"]

if model_collections:

@st.dialog("Select collections")
def add_collection(collections: list) -> None:
selected_collections = st.session_state.selected_collections
col1, col2 = st.columns(spec=2)

for collection in collections:
collection_id = collection.split(" - ")[1]
if st.checkbox(
label=f"{collection.split(" - ")[0]} (*{collection_id[:8]}*)",
value=False if collection_id not in st.session_state.selected_collections else True,
):
selected_collections.append(collection_id)
elif collection_id in selected_collections:
selected_collections.remove(collection_id)

with col1:
if st.button(label="**Submit :material/check_circle:**", use_container_width=True):
st.session_state.selected_collections = list(set(selected_collections))
st.rerun()
with col2:
if st.button(label="**Clear :material/close:**", use_container_width=True):
st.session_state.selected_collections = []
st.rerun()

option_map = {0: f"{len(set(st.session_state.selected_collections))} selected"}
pill = st.pills(
label="Collections",
options=option_map.keys(),
format_func=lambda option: option_map[option],
selection_mode="single",
default=None,
key="add_collections",
)
if pill == 0:
add_collection(collections=model_collections)

params["rag"]["collections"] = st.session_state.selected_collections
params["rag"]["k"] = st.number_input(label="Number of chunks to retrieve (k)", value=3)

if st.session_state.selected_collections:
rag = st.toggle(label="Activated RAG", value=True, disabled=not bool(params["rag"]["collections"]))
else:
rag = st.toggle(label="Activated RAG", value=False, disabled=True, help="You need to select at least one collection to activate RAG.")

if st.session_state.selected_collections and rag:
rerank = st.toggle(
label="Add rerank",
value=False,
disabled=not bool(params["rag"]["collections"]),
help="When activated, that retrieve the double number of chunks (k*2) and keep the best k chunks after reranking.",
)
if rerank:
params["rag"]["rerank_model"] = st.selectbox(label="Rerank model", options=rerank_models)
else:
rerank = st.toggle(
label="Add rerank", value=False, disabled=True, help="You need to select at least one collection to activate rerank and activate RAG."
)

# Main: static greeting, replay of the conversation history, then the chat
# input / streaming-response loop.
with st.chat_message(name="assistant"):
    st.markdown(
        body="""Bonjour je suis Albert, et je peux vous aider si vous avez des questions administratives !
Je peux me connecter à vos bases de connaissances, pour ça sélectionnez les collections voulues dans le menu de gauche. Je peux également chercher sur les sites officiels de l'État, pour ça sélectionnez la collection "Internet" à gauche. Si vous ne souhaitez pas utiliser de collection, désactivez le RAG en décochant la fonction "Activated RAG".
Comment puis-je vous aider ?
"""
    )

# Replay history; sources[i] holds the (possibly empty) source pills of message i.
for i, message in enumerate(st.session_state.messages):
    with st.chat_message(message["role"], avatar=":material/face:" if message["role"] == "user" else None):
        st.markdown(message["content"])
        if st.session_state.sources[i]:
            st.pills(label="Sources", options=st.session_state.sources[i], label_visibility="hidden")

sources = []
if prompt := st.chat_input(placeholder="Message to Albert"):
    # send message to the model
    user_message = {"role": "user", "content": prompt}
    st.session_state.messages.append(user_message)
    # User messages carry no sources; keep messages/sources in lockstep.
    st.session_state.sources.append([])
    with st.chat_message(name="user", avatar=":material/face:"):
        st.markdown(body=prompt)

    with st.chat_message(name="assistant"):
        try:
            stream, sources = generate_stream(
                messages=st.session_state.messages,
                params=params,
                api_key=API_KEY,
                rag=rag,
                rerank=rerank,
            )
            response = st.write_stream(stream=stream)
        except Exception:
            # Fixed user-facing message ("Error to generate response." was ungrammatical).
            st.error(body="Failed to generate response.")
            logging.error(traceback.format_exc())
            st.stop()

        # Render sources as pills: URLs become clickable globe links, other
        # sources (documents) get a book icon; labels truncated to 15 chars.
        formatted_sources = []
        if sources:
            for source in sources:
                formatted_source = source[:15] + "..." if len(source) > 15 else source
                if source.lower().startswith("http"):
                    formatted_sources.append(f":material/globe: [{formatted_source}]({source})")
                else:
                    formatted_sources.append(f":material/import_contacts: {formatted_source}")
            st.pills(label="Sources", options=formatted_sources, label_visibility="hidden")

    assistant_message = {"role": "assistant", "content": response}
    st.session_state.messages.append(assistant_message)
    st.session_state.sources.append(formatted_sources)

# Persistent footer pinned below the chat input (st._bottom is a private
# Streamlit container — NOTE(review): may break across Streamlit upgrades).
with st._bottom:
    st.caption(
        body='<p style="text-align: center;"><i>I can make mistakes, please always verify my sources and answers.</i></p>',
        unsafe_allow_html=True,
    )
Loading

0 comments on commit ee37cf1

Please sign in to comment.