Use centralised vector db #7

Merged
12 commits merged on May 10, 2024
2 changes: 2 additions & 0 deletions .env.default
@@ -15,3 +15,5 @@ LANGCHAIN_TRACING_V2=true
LANGCHAIN_ENDPOINT="https://api.smith.langchain.com"
LANGCHAIN_API_KEY=<your-langchain-api-key>
LANGCHAIN_PROJECT="virtual-contributor"
VECTOR_DB_HOST=localhost
VECTOR_DB_PORT=8000
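
The two new variables point the adapter at a centralised Chroma server instead of a local index. As a minimal sketch (not part of this diff; the fallback defaults and the heartbeat check are illustrative assumptions), they could be consumed like this:

import os
import chromadb

# Read the connection details introduced in .env.default; the fallbacks are assumptions.
host = os.getenv("VECTOR_DB_HOST", "localhost")
port = int(os.getenv("VECTOR_DB_PORT", "8000"))

client = chromadb.HttpClient(host=host, port=port)
client.heartbeat()  # raises a connection error if the centralised vector DB is unreachable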
264 changes: 55 additions & 209 deletions ai_adapter.py
@@ -1,85 +1,28 @@
import json
from langchain.vectorstores import FAISS
from langchain_core.prompts import HumanMessagePromptTemplate, SystemMessagePromptTemplate
from langchain_openai import AzureOpenAIEmbeddings
from langchain_openai import AzureOpenAI
import chromadb
from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction

from langchain_core.prompts import (
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.prompts.prompt import PromptTemplate
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
from langchain_openai import AzureChatOpenAI
from langchain.schema import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain.schema import format_document
from langchain_core.messages import get_buffer_string
from langchain_core.messages.ai import AIMessage
from langchain_core.runnables import RunnableBranch
from langchain.callbacks import get_openai_callback

from operator import itemgetter
import logging
import sys
import io
from config import config, vectordb_path, local_path, LOG_LEVEL, max_token_limit
from ingest import ingest
from numpy import source
from config import config, local_path, LOG_LEVEL, max_token_limit


import os
# configure logging
logger = logging.getLogger(__name__)
assert LOG_LEVEL in ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']
logger.setLevel(getattr(logging, LOG_LEVEL)) # Set logger level


# Create handlers
c_handler = logging.StreamHandler(io.TextIOWrapper(sys.stdout.buffer, line_buffering=True))
f_handler = logging.FileHandler(os.path.join(os.path.expanduser(local_path), 'app.log'))

c_handler.setLevel(level=getattr(logging, LOG_LEVEL))
f_handler.setLevel(logging.WARNING)
from logger import setup_logger

# Create formatters and add them to handlers
c_format = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s', '%m-%d %H:%M:%S')
f_format = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s', '%m-%d %H:%M:%S')
c_handler.setFormatter(c_format)
f_handler.setFormatter(f_format)

# Add handlers to the logger
logger.addHandler(c_handler)
logger.addHandler(f_handler)

logger.info(f"log level {os.path.basename(__file__)}: {LOG_LEVEL}")
logger = setup_logger(__name__)

# verbose output for LLMs
if LOG_LEVEL == "DEBUG":
verbose_models = True
else:
verbose_models = False

# define internal configuration parameters

# does chain return the source documents?
return_source_documents = True


# Define a dictionary containing country codes as keys and related languages as values

language_mapping = {
'EN': 'English',
'US': 'English',
'UK': 'English',
'FR': 'French',
'DE': 'German',
'ES': 'Spanish',
'NL': 'Dutch',
'BG': 'Bulgarian',
'UA': "Ukranian"
}

# function to retrieve language from country
def get_language_by_code(language_code):
"""Returns the language associated with the given code. If no match is found, it returns 'English'."""
return language_mapping.get(language_code, 'English')


chat_system_template = """
You are a friendly and talkative conversational agent, tasked with answering questions based on the context provided below delimited by triple pluses.
Use the following step-by-step instructions to respond to user inputs:
@@ -111,76 +54,57 @@ def get_language_by_code(language_code):
)


# generic_llm = AzureOpenAI(azure_deployment=os.environ["LLM_DEPLOYMENT_NAME"],
# temperature=0, verbose=verbose_models)

chat_llm = AzureChatOpenAI(azure_deployment=os.environ["LLM_DEPLOYMENT_NAME"],
temperature=float(os.environ["AI_MODEL_TEMPERATURE"]),
max_tokens=max_token_limit, verbose=verbose_models)

# condense_llm = AzureChatOpenAI(azure_deployment=os.environ["LLM_DEPLOYMENT_NAME"],
# temperature=0,
# verbose=verbose_models)

embeddings = AzureOpenAIEmbeddings(
azure_deployment=config['embeddings_deployment_name'],
chunk_size=1
llm = AzureChatOpenAI(
azure_deployment=config["llm_deployment_name"],
temperature=float(config["model_temperature"]),
max_tokens=max_token_limit,
verbose=verbose_models,
)

def load_vector_db():
"""
Purpose:
Load the data into the vector database.
Args:

Returns:
vectorstore: the vectorstore object
"""
# Check if the vector database exists
if os.path.exists(vectordb_path + os.sep + "index.pkl"):
logger.info(f"The file vector database is present")
else:
logger.info(f"The file vector database is not present, ingesting")
ingest()

return FAISS.load_local(vectordb_path, embeddings)


vectorstore = load_vector_db()

retriever = vectorstore.as_retriever(search_type="similarity_score_threshold", search_kwargs={"score_threshold": .5})
embed_func = OpenAIEmbeddingFunction(
api_key=config["openai_api_key"],
api_base=config["openai_endpoint"],
api_type="azure",
api_version=config["openai_api_version"],
model_name=config["embeddings_deployment_name"],
)

def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)

def _combine_documents(docs, document_separator="\n\n"):
return document_separator.join(docs)

DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")

def _combine_documents(
docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
):
doc_strings = [format_document(doc, document_prompt) for doc in docs]
return document_separator.join(doc_strings)
# how do we handle languages? not all spaces are in Dutch obviously
# translating the question to the data _base language_ should be a separate call
# so the translation could be used for embeddings retrieval
async def query_chain(message, language, history):

async def query_chain(question, language, chat_history):
space_name = message["spaceNameID"]
question = message["question"]
logger.info(
"Query chaing invoked for '%s' with question: %s" % (space_name, question)
)

logger.info(question)
chroma_client = chromadb.HttpClient(host=config["db_host"], port=config["db_port"])
collection = chroma_client.get_collection(space_name, embedding_function=embed_func)

docs = retriever.invoke(question['question'])
docs = collection.query(
query_texts=[question], include=["documents", "metadatas"], n_results=4
)

logger.info(list(map(lambda d: d.metadata['source'], docs)))
logger.info(docs["metadatas"])
logger.info("Documents with ids [%s] selected" % ",".join(list(docs["ids"][0])))

review_system_prompt = SystemMessagePromptTemplate(
prompt=PromptTemplate(
input_variables=["context"],
template=chat_system_template
)
prompt=PromptTemplate(
input_variables=["context"], template=chat_system_template
)
)

review_human_prompt = HumanMessagePromptTemplate(
prompt=PromptTemplate(
input_variables=["question"],
template=condense_question_template
input_variables=["question"], template=condense_question_template
)
)

@@ -191,93 +115,15 @@ async def query_chain(question, language, chat_history):
messages=messages,
)

# source_documents = map(lambda doc: doc.metadata['source'], docs)

# context = json.loads(question['context'])
# space_description = context["space"]["description"] + "\n" + context["space"]["tagline"]
# exchanged_messages = "\n".join(map(lambda message: message["message"], context["messages"]))

review_chain = review_prompt_template | chat_llm

result = review_chain.invoke({"question": question["question"], "context": _combine_documents(docs) })
return {'answer': result, 'source_documents': docs}

# # check whether the chat history is empty
# if chat_history.buffer == []:
# first_call = True
# else:
# first_call = False

# # add first_call to the question
# question.update({"first_call": first_call})
review_chain = review_prompt_template | llm

# logger.info(f"first call: {first_call}\n")
# logger.debug(f"chat history: {chat_history.buffer}\n")

# # First we add a step to load memory
# # This adds a "memory" key to the input object
# loaded_memory = RunnablePassthrough.assign(
# chat_history=RunnableLambda(chat_history.load_memory_variables) | itemgetter("history"),
# )

# logger.debug(f"loaded memory {loaded_memory}\n")
# logger.debug(f"chat history {chat_history}\n")


# # Now we calculate the standalone question if the chat_history is not empty
# standalone_question = {
# "standalone_question": {
# "question": lambda x: x["question"],
# "chat_history": lambda x: get_buffer_string(x["chat_history"]),
# }
# | condense_question_prompt
# | condense_llm
# | StrOutputParser(),
# }

# # pass the question directly on the first call in a chat sequence of the chatbot
# direct_question = {
# "question": lambda x: x["question"],
# }
# # Now we retrieve the documents
# # in case it is the first call (chat history empty)
# retrieved_documents = {
# "docs": itemgetter("question") | retriever,
# "question": lambda x: x["question"],
# }
# # or when the chat history is not empty, rephrase the question taking into account the chat history
# retrieved_documents_sa = {
# "docs": itemgetter("standalone_question") | retriever,
# "question": lambda x: x["standalone_question"],
# }

# # Now we construct the inputs for the final prompt
# final_inputs = {
# "context": lambda x: _combine_documents(x["docs"]),
# "chat_history" : lambda x: chat_history.buffer,
# "question": itemgetter("question"),
# "language": lambda x: language['language'],
# }

# # And finally, we do the part that returns the answers
# answer = {
# "answer": final_inputs | chat_prompt | chat_llm,
# "docs": itemgetter("docs"),
# }

# # And now we put it all together in a 'RunnableBranch', so we only invoke the rephrasing part when the chat history is not empty
# final_chain = RunnableBranch(
# (lambda x: x["first_call"], loaded_memory | direct_question | retrieved_documents | answer),
# loaded_memory | standalone_question | retrieved_documents_sa | answer,
# )

# try:
# logger.debug(f"final chain {final_chain}\n")
# result = await final_chain.ainvoke(question)
# except Exception as e:
# logger.error(f"An error occurred while generating a response: {str(e)}")
# # Handle the error appropriately here
# return {'answer': AIMessage(content='An error occurred while generating a response.'), 'source_documents': []}
# else:
if docs["documents"] and docs["metadatas"]:
result = review_chain.invoke(
{
"question": question,
"context": _combine_documents(docs["documents"][0]),
}
)
return {"answer": result, "source_documents": docs["metadatas"][0]}

# return {'answer': result['answer'], 'source_documents': result['docs'] if result['docs'] else []}
return {"answer": "", "source_documents": []}
8 changes: 7 additions & 1 deletion config.py
@@ -3,13 +3,18 @@
load_dotenv()

config = {
"db_host": os.getenv('VECTOR_DB_HOST'),
"db_port": os.getenv('VECTOR_DB_PORT'),
"llm_deployment_name": os.getenv('LLM_DEPLOYMENT_NAME'),
"model_temperature": os.getenv('AI_MODEL_TEMPERATURE'),
"embeddings_deployment_name": os.getenv('EMBEDDINGS_DEPLOYMENT_NAME'),
"openai_endpoint": os.getenv('AZURE_OPENAI_ENDPOINT'),
"openai_api_key": os.getenv('AZURE_OPENAI_API_KEY'),
"openai_api_version": os.getenv('OPENAI_API_VERSION'),
"rabbitmq_host": os.getenv('RABBITMQ_HOST'),
"rabbitmq_user": os.getenv('RABBITMQ_USER'),
"rabbitmq_password": os.getenv('RABBITMQ_PASSWORD'),
"rabbitmqrequestqueue": os.getenv('RABBITMQ_QUEUE'),
"rabbitmq_queue": os.getenv('RABBITMQ_QUEUE'),
"source_website": os.getenv('AI_SOURCE_WEBSITE'),
"local_path": os.getenv('AI_LOCAL_PATH') or ''
}
@@ -22,3 +27,4 @@
max_token_limit = 2000

LOG_LEVEL = os.getenv('LOG_LEVEL') # Possible values: 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'
assert LOG_LEVEL in ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']
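
With the LOG_LEVEL assertion moved into config.py, an invalid level now aborts at import time rather than deep inside ai_adapter. A rough illustration, assuming default load_dotenv() behaviour (it does not override variables already set in the environment):

import os

os.environ["LOG_LEVEL"] = "VERBOSE"  # not one of the allowed values

try:
    import config  # the assert at the bottom of config.py fires here
except AssertionError:
    print("invalid LOG_LEVEL rejected at import time")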
7 changes: 0 additions & 7 deletions data/callout-0ac80071-cd92-473e-942b-4b7524033146.txt

This file was deleted.

21 changes: 0 additions & 21 deletions data/callout-10ce8d3a-b78a-43be-a418-64772d08a5e6.txt

This file was deleted.

7 changes: 0 additions & 7 deletions data/callout-149e7315-4529-4869-8d9b-ba5e3a05b415.txt

This file was deleted.

7 changes: 0 additions & 7 deletions data/callout-15d72ecc-9f44-4b13-b09a-d8fc818bea7d.txt

This file was deleted.
