Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Prometheus instrumentation and modify UI to get model_id using /info endpoint and source_docs using langchain #5

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
74 changes: 44 additions & 30 deletions examples/ui/gradio/gradio-hftgi-rag-redis/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from langchain.llms import HuggingFaceTextGenInference
from langchain.prompts import PromptTemplate
from langchain.vectorstores.redis import Redis
from prometheus_client import start_http_server, Counter
from prometheus_client import start_http_server, Counter, Histogram, Gauge
import requests

load_dotenv()

Expand All @@ -24,7 +25,7 @@
APP_TITLE = os.getenv('APP_TITLE', 'Talk with your documentation')

INFERENCE_SERVER_URL = os.getenv('INFERENCE_SERVER_URL')
MAX_NEW_TOKENS = int(os.getenv('MAX_NEW_TOKENS', 512))
MAX_NEW_TOKENS = int(os.getenv('MAX_NEW_TOKENS', 100))
TOP_K = int(os.getenv('TOP_K', 10))
TOP_P = float(os.getenv('TOP_P', 0.95))
TYPICAL_P = float(os.getenv('TYPICAL_P', 0.95))
Expand All @@ -33,16 +34,18 @@

REDIS_URL = os.getenv('REDIS_URL')
REDIS_INDEX = os.getenv('REDIS_INDEX')

TIMEOUT = int(os.getenv('TIMEOUT', 30))
# Start Prometheus metrics server
start_http_server(8000)

# Create a counter metric
FEEDBACK_COUNTER = Counter("feedback_stars", "Number of feedbacks by stars", ["stars"])
# Create metrics
FEEDBACK_COUNTER = Counter("feedback_stars", "Number of feedbacks by stars", ["stars", "model_id"])
MODEL_USAGE_COUNTER = Counter('model_usage', 'Number of times a model was used', ['model_id'])
REQUEST_TIME = Gauge('request_duration_seconds', 'Time spent processing a request', ['model_id'])

model_id = ""

client = Client(base_url=INFERENCE_SERVER_URL)
client = Client(base_url=INFERENCE_SERVER_URL,timeout=TIMEOUT)

# Streaming implementation
class QueueCallback(BaseCallbackHandler):
Expand All @@ -64,30 +67,39 @@ def remove_source_duplicates(input_list):
unique_list.append(item.metadata['source'])
return unique_list


def stream(input_text) -> Generator:

global model_id

# Create a Queue
# Create queue
job_done = object()

# Create a function to call - this will run in a thread
def task():
resp = qa_chain({"query": input_text})
sources = remove_source_duplicates(resp['source_documents'])

sources = remove_source_duplicates(resp['source_documents'])
input = str(input_text)
response = client.generate(input, max_new_tokens=MAX_NEW_TOKENS)
text = response.generated_text
model_id = response.model_id
q.put({"model_id": response.model_id})
start_time = time.perf_counter() # start and end time to get the precise timing of the request

try:
response = requests.get(INFERENCE_SERVER_URL + "/info")
json_response = response.json()
print(json_response)
model_id = json_response['model_id']
end_time = time.perf_counter()
# Record successful request time
REQUEST_TIME.labels(model_id=model_id).set(end_time - start_time)
except TimeoutError: # or whatever exception your client throws on timeout
end_time = time.perf_counter()

q.put({"model_id": model_id})
print("MODEL ID IS:",model_id)
print("Question:",input)
if len(sources) != 0:
q.put("\n*Sources:* \n")
for source in sources:
q.put("* " + str(source) + "\n")
q.put(job_done)
print("Saving it...")

# Create a thread and start the function
t = Thread(target=task)
Expand All @@ -103,7 +115,7 @@ def task():
break
if isinstance(next_token, dict) and 'model_id' in next_token:
model_id = next_token['model_id']
MODEL_USAGE_COUNTER.labels(model_id=model_id).inc()
MODEL_USAGE_COUNTER.labels(model_id=model_id).inc()
elif isinstance(next_token, str):
content += next_token
yield next_token, content, model_id
Expand Down Expand Up @@ -163,39 +175,41 @@ def task():
return_source_documents=True
)

# Gradio implementation
def ask_llm(message, history):
    """Stream the answer to *message*, yielding the text accumulated so far.

    Iterates over ``stream(message)``, which produces
    ``(next_token, content, model_id)`` tuples, and yields one formatted
    string per tuple: the accumulated ``content`` followed by a
    ``Model ID:`` footer.

    ``history`` is accepted but not used in this function.
    """
    for chunk in stream(message):
        # Only the accumulated text and the model id are used; the
        # individual token is ignored.
        _token, content, model_id = chunk
        print(model_id)
        # NOTE(review): the return value of ``update`` is discarded here --
        # confirm whether this call has any effect on the hidden textbox.
        model_id_box.update(value=model_id)
        yield f"{content}\n\nModel ID: {model_id}"


with gr.Blocks(title="HatBot", css="footer {visibility: hidden}") as demo:
# Gradio implementation
with gr.Blocks(title="HatBot") as demo:

input_box = gr.Textbox(label="Your Question")
output_answer = gr.Textbox(label="Answer", readonly=True)
with gr.Column(class_name="column-class"):
project_box = gr.Textbox(label="Your Project")
customer_box = gr.Textbox(label="Customer")
input_box = gr.Textbox(label="Your Question")
submit_button = gr.Button("Submit")
output_answer = gr.Textbox(label="Answer", readonly=True)

model_id_box = gr.Textbox(visible=False) # will hold the model_id

gr.Interface(
submit_button.click(
fn=ask_llm,
inputs=[input_box],
outputs=[output_answer],
clear_btn=None,
retry_btn=None,
undo_btn=None,
stop_btn=None,
description=APP_TITLE
)
outputs=[output_answer]
)

radio = gr.Radio(["1", "2", "3", "4", "5"], label="Star Rating")
output = gr.Textbox(label="Output Box")

download_button = gr.Button("Download")

@radio.input(inputs=radio, outputs=output)
def get_feedback(star):
print("Rating: " + star)
# Increment the counter based on the star rating received
FEEDBACK_COUNTER.labels(stars=str(star)).inc()

FEEDBACK_COUNTER.labels(stars=str(star), model_id=model_id).inc()
return f"Received {star} star feedback. Thank you!"


Expand Down
49 changes: 28 additions & 21 deletions examples/ui/gradio/gradio-hftgi-rag-redis/chatbot_test.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,45 @@
# Generated by Selenium IDE
# import pytest
# Import necessary libraries
import time
import json
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.common.action_chains import ActionChains
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.wait import WebDriverWait
from selenium.webdriver.common.keys import Keys
from selenium.webdriver.common.desired_capabilities import DesiredCapabilities
from selenium.common.exceptions import TimeoutException
import random

# read yaml file
# loop
driver = webdriver.Firefox()
vars = {}

driver.get("https://gradio-hftgi-rag-redis-vectordb.apps.ai-dev01.kni.syseng.devcluster.openshift.com")
driver.get("https://canary-gradio-vectordb.apps.ai-dev01.kni.syseng.devcluster.openshift.com")
driver.set_window_size(1084, 811)
timeout = 10
try:
# element_present = EC.presence_of_element_located((By.CSS_SELECTOR, "#chatinput .scroll-hide"))
element_present = EC.presence_of_element_located((By.CSS_SELECTOR, "#component-1 .scroll-hide"))

for user in range(20):
element_present = EC.presence_of_element_located((By.CSS_SELECTOR, "#component-0 .scroll-hide"))
WebDriverWait(driver, timeout).until(element_present)
except TimeoutException:
print("Timed out waiting for page to load")

driver.find_element(By.CSS_SELECTOR, "#component-1 .scroll-hide")
driver.find_element(By.CSS_SELECTOR, "#component-1 .scroll-hide").send_keys("hi how are you")
driver.find_element(By.ID, "component-12").click()
# User enters a question
project_input = driver.find_element(By.CSS_SELECTOR, "#component-3 .scroll-hide")
project_input.clear() # Clearing any previous input
project_input.send_keys(f"User {user + 1}: OpenShift AI")
customer_input = driver.find_element(By.CSS_SELECTOR, "#component-4 .scroll-hide")
customer_input.clear() # Clearing any previous input
customer_input.send_keys(f"User {user + 1}: Accenture")
question_input = driver.find_element(By.CSS_SELECTOR, "#component-5 .scroll-hide")
question_input.clear() # Clearing any previous input
question_input.send_keys(f"User {user + 1}: What is OpenShift AI?")
driver.find_element(By.ID, "component-6").click()

label_list=[1,2,3,4,5]
random_num = random.choice(label_list)
labelname=str(random_num)+'-radio-label'
label_id="label[data-testid='"+labelname+"']"

label_id="label[data-testid='2-radio-label']"
x = WebDriverWait(driver, 20).until(EC.visibility_of_element_located((By.CSS_SELECTOR, label_id))).click()
# label_id = "label[data-testid='2-radio-label']"
WebDriverWait(driver, 20).until(EC.visibility_of_element_located((By.CSS_SELECTOR, label_id))).click()
time.sleep(2) # Adding a delay for better simulation of user interaction

# if needed add some delay
# driver.quit()
# end loop
# Close the browser after the loop completes
#driver.quit()