feat: improve baserag and update user parameters to collection
leoguillaumegouv committed Jul 10, 2024
1 parent 80dd8cf commit 57dc7e1
Showing 4 changed files with 157 additions and 106 deletions.
72 changes: 31 additions & 41 deletions app/endpoints/albert.py
@@ -50,15 +50,14 @@ async def chat_history(

return chat_history


@router.post("/files", tags=["Albert"])
async def upload_files(
user: str,
files: List[UploadFile],
collection: str,
model: str,
chunk_size: int = 3000,
chunk_overlap: int = 400,
chunk_min_size: int = 90,
files: List[UploadFile],
chunk_size: Optional[int] = 512,
chunk_overlap: Optional[int] = 0,
chunk_min_size: Optional[int] = 10,
api_key: str = Security(check_api_key),
) -> FileUploadResponse:
"""
@@ -68,9 +67,9 @@ async def upload_files(
response = {"object": "list", "data": []}

try:
clients["files"].head_bucket(Bucket=user)
clients["files"].head_bucket(Bucket=collection)
except ClientError:
clients["files"].create_bucket(Bucket=user)
clients["files"].create_bucket(Bucket=collection)

loader = S3FileLoader(
s3=clients["files"],
@@ -101,7 +100,7 @@ async def upload_files(
# upload files into S3 bucket
clients["files"].upload_fileobj(
file.file,
user,
collection,
file_id,
ExtraArgs={
"ContentType": file.content_type,
@@ -119,11 +118,11 @@ async def upload_files(

try:
# convert files into langchain documents
documents = loader._get_elements(file_id=file_id, bucket=user)
documents = loader._get_elements(file_id=file_id, bucket=collection)
except Exception as e:
logging.error(f"convert {file_name} into documents:\n{e}")
status = "failed"
clients["files"].delete_object(Bucket=user, Key=file_id)
clients["files"].delete_object(Bucket=collection, Key=file_id)
response["data"].append({"object": "upload", "id": file_id, "filename": file_name, "status": status}) # fmt: off
continue

@@ -132,14 +131,14 @@ async def upload_files(
db = await VectorStore.afrom_documents(
documents=documents,
embedding=embedding,
collection_name=user,
collection_name=collection,
url=clients["vectors"].url,
api_key=clients["vectors"].api_key,
)
except Exception as e:
logging.error(f"create vectors of {file_name}:\n{e}")
status = "failed"
clients["files"].delete_object(Bucket=user, Key=file_id)
clients["files"].delete_object(Bucket=collection, Key=file_id)
response["data"].append({"object": "upload", "id": file_id, "filename": file_name, "status": status}) # fmt: off
continue

@@ -150,10 +149,10 @@ async def upload_files(
return response
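For illustration, a minimal client-side sketch of the updated upload endpoint. The base URL, API key, and model name are placeholders, and the query/multipart layout is an assumption read off the FastAPI signature above:

# Hypothetical usage sketch; BASE_URL, the API key and the model name are assumptions.
import requests

BASE_URL = "http://localhost:8000"
headers = {"Authorization": "Bearer my-api-key"}

with open("report.pdf", "rb") as f:
    response = requests.post(
        f"{BASE_URL}/files",
        headers=headers,
        params={
            "collection": "my-collection",   # formerly the `user` parameter
            "model": "my-embeddings-model",  # placeholder embeddings model
            # chunk_size / chunk_overlap / chunk_min_size now default to 512 / 0 / 10
        },
        files=[("files", ("report.pdf", f, "application/pdf"))],
    )
response.raise_for_status()
print(response.json())  # per-file upload status ("success" or "failed")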


@router.get("/files/{user}/{file_id}")
@router.get("/files/{user}")
@router.get("/files/{collection}/{file_id}")
@router.get("/files/{collection}")
def files(
user: str, file_id: Optional[str] = None, api_key: str = Security(check_api_key)
collection: str, file_id: Optional[str] = None, api_key: str = Security(check_api_key)
) -> Union[File, FileResponse]:
response = {"object": "list", "metadata": {"files": 0, "vectors": 0}, "data": []}
"""
@@ -162,13 +161,13 @@ def files(
"""

try:
clients["files"].head_bucket(Bucket=user)
clients["files"].head_bucket(Bucket=collection)
except ClientError:
raise HTTPException(status_code=404, detail="Files not found")

response = {"object": "list", "data": []}
objects = clients["files"].list_objects_v2(Bucket=user).get("Contents", [])
objects = [object | clients["files"].head_object(Bucket=user, Key=object["Key"])["Metadata"] for object in objects] # fmt: off
objects = clients["files"].list_objects_v2(Bucket=collection).get("Contents", [])
objects = [object | clients["files"].head_object(Bucket=collection, Key=object["Key"])["Metadata"] for object in objects] # fmt: off
for object in objects:
data = {
"id": object["Key"],
@@ -188,32 +187,32 @@ def files(
return response
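A similarly hedged sketch for the renamed read routes; the URL, key, and file id are placeholders:

import requests

BASE_URL = "http://localhost:8000"  # assumption: local deployment
headers = {"Authorization": "Bearer my-api-key"}  # assumption

# List every file stored in a collection.
all_files = requests.get(f"{BASE_URL}/files/my-collection", headers=headers).json()

# Fetch a single file's metadata, using an id returned by the upload endpoint.
one_file = requests.get(f"{BASE_URL}/files/my-collection/my-file-id", headers=headers).json()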


@router.delete("/files/{user}/{file_id}")
@router.delete("/files/{user}")
@router.delete("/files/{collection}/{file_id}")
@router.delete("/files/{collection}")
def delete_file(
user: str, file_id: Optional[str] = None, api_key: str = Security(check_api_key)
collection: str, file_id: Optional[str] = None, api_key: str = Security(check_api_key)
) -> Response:
"""
Delete files from configured files and vectors databases.
"""

try:
clients["files"].head_bucket(Bucket=user)
clients["files"].head_bucket(Bucket=collection)
except ClientError:
raise HTTPException(status_code=404, detail="Bucket not found")

if file_id is None:
objects = clients["files"].list_objects_v2(Bucket=user)
objects = clients["files"].list_objects_v2(Bucket=collection)
if "Contents" in objects:
objects = [{"Key": obj["Key"]} for obj in objects["Contents"]]
clients["files"].delete_objects(Bucket=user, Delete={"Objects": objects})
clients["files"].delete_objects(Bucket=collection, Delete={"Objects": objects})

clients["files"].delete_bucket(Bucket=user)
clients["vectors"].delete_collection(user)
clients["files"].delete_bucket(Bucket=collection)
clients["vectors"].delete_collection(collection)
else:
clients["files"].delete_object(Bucket=user, Key=file_id)
clients["files"].delete_object(Bucket=collection, Key=file_id)
filter = rest.Filter(must=[rest.FieldCondition(key="metadata.file_id", match=rest.MatchAny(any=[file_id]))]) # fmt: off
clients["vectors"].delete(collection_name=user, points_selector=rest.FilterSelector(filter=filter)) # fmt: off
clients["vectors"].delete(collection_name=collection, points_selector=rest.FilterSelector(filter=filter)) # fmt: off

return Response(status_code=204)
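And for deletion, a sketch under the same placeholder assumptions; note the two granularities the diff introduces:

import requests

BASE_URL = "http://localhost:8000"  # assumption
headers = {"Authorization": "Bearer my-api-key"}  # assumption

# Delete one file: removes the S3 object and the vectors whose metadata.file_id matches.
requests.delete(f"{BASE_URL}/files/my-collection/my-file-id", headers=headers)

# Delete a whole collection: empties and removes the bucket, then drops the vector collection.
requests.delete(f"{BASE_URL}/files/my-collection", headers=headers)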

@@ -236,22 +235,13 @@ def tools(api_key: str = Security(check_api_key)) -> ToolResponse:
return response


@router.get("/collections/{user}")
def collections(user: Optional[str], api_key: str = Security(check_api_key)) -> CollectionResponse:
@router.get("/collections")
def collections(api_key: str = Security(check_api_key)) -> CollectionResponse:
"""
Get list of collections.
"""

response = clients["vectors"].get_collections()
collections = [
{
"object": "collection",
"name": collection,
"type": "user" if collection == user else "public",
}
for collection in response.collections
if collection.name.startswith("public_") or collection.name == user
]

response = {"object": "list", "data": collections}

return response
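Listing collections no longer takes a user path parameter; a sketch of the new call (placeholder URL and key):

import requests

response = requests.get(
    "http://localhost:8000/collections",             # assumption: local deployment
    headers={"Authorization": "Bearer my-api-key"},  # assumption
)
print(response.json())  # CollectionResponse: {"object": "list", "data": [...]}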
50 changes: 30 additions & 20 deletions app/tools/_baserag.py
@@ -8,17 +8,18 @@

class BaseRAG:
"""
    Base RAG, basic retrieval augmented generation.
    Base RAG, basic retrieval augmented generation.
Args:
embeddings_model (str): OpenAI embeddings model
collection_name (Optional[str], optional): Collection name. Defaults to "user" parameters.
file_ids (Optional[List[str]], optional): List of file ids. Defaults to None.
k (int, optional): Top K. Defaults to 4.
prompt_template (Optional[str], optional): Prompt template. Defaults to DEFAULT_PROMPT_TEMPLATE.
Args:
embeddings_model (str): OpenAI embeddings model
collection (List[Optional[str]]): Collection names. Defaults to "user" parameter.
        file_ids (Optional[List[str]], optional): List of file ids for user collections (after uploading files). Defaults to None.
k (int, optional): Top K per collection (max: 6). Defaults to 4.
prompt_template (Optional[str], optional): Prompt template. Defaults to DEFAULT_PROMPT_TEMPLATE.
"""

DEFAULT_PROMPT_TEMPLATE = "Réponds à la question suivante en te basant sur les documents ci-dessous : %(prompt)s\n\nDocuments :\n\n%(docs)s"
MAX_K = 6

def __init__(self, clients: dict, user: str):
self.user = user
@@ -27,15 +28,14 @@ def __init__(self, clients: dict, user: str):
def get_rag_prompt(
self,
embeddings_model: str,
#@TODO: add multiple collections support
collection_name: Optional[str] = None,
collections: List[Optional[str]],
file_ids: Optional[List[str]] = None,
#@TODO: add max value of k to ensure that the value is not too high
k: int = 4,
k: Optional[int] = 4,
prompt_template: Optional[str] = DEFAULT_PROMPT_TEMPLATE,
**request,
) -> str:
collection_name = collection_name or self.user
if k > self.MAX_K:
raise HTTPException(status_code=400, detail=f"K must be less than or equal to {self.MAX_K}")

try:
model_url = str(self.clients["openai"][embeddings_model].base_url)
@@ -48,15 +48,25 @@
huggingfacehub_api_token=self.clients["openai"][embeddings_model].api_key,
)

vectorstore = Qdrant(
client=self.clients["vectors"],
embeddings=embeddings,
collection_name=collection_name,
)
filter = rest.Filter(must=[rest.FieldCondition(key="metadata.file_id", match=rest.MatchAny(any=file_ids))]) if file_ids else None # fmt: off

all_collections = [
collection.name for collection in self.clients["vectors"].get_collections().collections
]
filter = rest.Filter(must=[rest.FieldCondition(key="metadata.file_id", match=rest.MatchAny(any=file_ids))]) if file_ids else None # fmt: off
prompt = request["messages"][-1]["content"]
docs = vectorstore.similarity_search(prompt, k=k, filter=filter)

docs = []
for collection in collections:
# check if collections exists
if collection not in all_collections:
raise HTTPException(status_code=404, detail=f"Collection {collection} not found")

vectorstore = Qdrant(
client=self.clients["vectors"],
embeddings=embeddings,
collection_name=collection,
)
docs.extend(vectorstore.similarity_search(prompt, k=k, filter=filter))

docs = "\n\n".join([doc.page_content for doc in docs])

prompt = prompt_template % {"docs": docs, "prompt": prompt}
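To illustrate the new multi-collection signature, a hypothetical tool payload; only the parameter names (embeddings_model, collections, file_ids, k) come from the code above, and the envelope itself is an assumption:

# Hypothetical request body invoking BaseRAG over several collections.
data = {
    "model": "AgentPublic/llama3-instruct-8b",
    "messages": [{"role": "user", "content": "Que disent mes documents sur ce sujet ?"}],
    "tools": [
        {
            "function": {
                "name": "BaseRAG",
                "parameters": {
                    "embeddings_model": "my-embeddings-model",      # placeholder
                    "collections": ["my-collection", "public_docs"],  # searched one by one
                    "k": 4,  # top-k per collection; must not exceed MAX_K (6)
                },
            }
        }
    ],
}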
70 changes: 52 additions & 18 deletions tutorials/chat_completions.ipynb
@@ -59,7 +59,7 @@
{
"data": {
"text/plain": [
"ChatCompletion(id='cmpl-380600d1542042ba81e3de8aee46f18d', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='Salut ! Comment vas-tu ?', role='assistant', function_call=None, tool_calls=[]), stop_reason=128009)], created=1719511595, model='AgentPublic/llama3-instruct-8b', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=9, prompt_tokens=14, total_tokens=23))"
"ChatCompletion(id='cmpl-2fa6f08162cb43b8bd8b38920090a033', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=\"Bonjour ! Je suis Albert, un modèle linguistique d'intelligence artificielle. Je suis ravi de discuter avec toi ! Qu'est-ce que tu veux parler ou demander ?\", role='assistant', function_call=None, tool_calls=[]), stop_reason=128009)], created=1720595202, model='AgentPublic/llama3-instruct-8b', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=41, prompt_tokens=14, total_tokens=55))"
]
},
"execution_count": 4,
@@ -119,16 +119,33 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 19,
"id": "d60dff87-eb79-46d0-8c7d-e54cf33d3181",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"b'{\"id\":\"340bf5f9-3be0-4325-a1eb-44269d3dcb94\",\"choices\":[{\"finish_reason\":\"stop\",\"index\":0,\"logprobs\":null,\"message\":{\"content\":\"Salut ! Comment vas-tu aujourd\\'hui ?\",\"role\":\"assistant\",\"function_call\":null,\"tool_calls\":[]},\"stop_reason\":128009}],\"created\":1719511608,\"model\":\"AgentPublic/llama3-instruct-8b\",\"object\":\"chat.completion\",\"service_tier\":null,\"system_fingerprint\":null,\"usage\":{\"completion_tokens\":11,\"prompt_tokens\":14,\"total_tokens\":25}}'\n"
]
"data": {
"text/plain": [
"{'id': '91d4ed10-9ba4-409c-8c15-01548f7f8a68',\n",
" 'choices': [{'finish_reason': 'stop',\n",
" 'index': 0,\n",
" 'logprobs': None,\n",
" 'message': {'content': \"Salut ! Comment vas-tu aujourd'hui ? Je suis Albert, un modèle de langage artificiel, et je suis là pour discuter avec toi ! Qu'est-ce que tu veux parler de ?\",\n",
" 'role': 'assistant',\n",
" 'function_call': None,\n",
" 'tool_calls': []},\n",
" 'stop_reason': 128009}],\n",
" 'created': 1720595309,\n",
" 'model': 'AgentPublic/llama3-instruct-8b',\n",
" 'object': 'chat.completion',\n",
" 'service_tier': None,\n",
" 'system_fingerprint': None,\n",
" 'usage': {'completion_tokens': 44, 'prompt_tokens': 14, 'total_tokens': 58}}"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
@@ -145,33 +162,50 @@
"\n",
"response = session.post(url=f\"{base_url}/chat/completions\", json=data)\n",
"\n",
"print(response.content)"
"response.json()"
]
},
{
"cell_type": "markdown",
"id": "28980487-dd6a-4258-9cc9-c4d0ef97bd57",
"id": "fbfec4f8-3c73-4eb7-a8d1-ad7adfb175f9",
"metadata": {},
"source": [
"Now, to continue the conversation with the current chat history, you need to pass the `id` parameter that was returned in the previous template response (here *340bf5f9-3be0-4325-a1eb-44269d3dcb94*)."
"Now, to continue the conversation with the current chat history, you need to pass the `id` parameter that was returned in the previous template response."
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 18,
"id": "6a062fab-d72d-4f82-a69d-1a74ad98a214",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"b'{\"id\":\"340bf5f9-3be0-4325-a1eb-44269d3dcb94\",\"choices\":[{\"finish_reason\":\"stop\",\"index\":0,\"logprobs\":null,\"message\":{\"content\":\"Ton pr\\xc3\\xa9c\\xc3\\xa9dent message \\xc3\\xa9tait : \\\\\"Salut Albert !\\\\\"\",\"role\":\"assistant\",\"function_call\":null,\"tool_calls\":[]},\"stop_reason\":128009}],\"created\":1719511628,\"model\":\"AgentPublic/llama3-instruct-8b\",\"object\":\"chat.completion\",\"service_tier\":null,\"system_fingerprint\":null,\"usage\":{\"completion_tokens\":13,\"prompt_tokens\":43,\"total_tokens\":56}}'\n"
]
"data": {
"text/plain": [
"{'id': 'd433a128-8b53-42bd-b89f-156c75e102b3',\n",
" 'choices': [{'finish_reason': 'stop',\n",
" 'index': 0,\n",
" 'logprobs': None,\n",
" 'message': {'content': 'Votre précédent message est également : \"Salut Albert !\"',\n",
" 'role': 'assistant',\n",
" 'function_call': None,\n",
" 'tool_calls': []},\n",
" 'stop_reason': 128009}],\n",
" 'created': 1720595300,\n",
" 'model': 'AgentPublic/llama3-instruct-8b',\n",
" 'object': 'chat.completion',\n",
" 'service_tier': None,\n",
" 'system_fingerprint': None,\n",
" 'usage': {'completion_tokens': 14, 'prompt_tokens': 162, 'total_tokens': 176}}"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chat_id = \"340bf5f9-3be0-4325-a1eb-44269d3dcb94\"\n",
"chat_id = response.json()[\"id\"]\n",
"\n",
"data = {\n",
" \"model\": \"AgentPublic/llama3-instruct-8b\",\n",
@@ -183,7 +217,7 @@
"}\n",
"response = session.post(url=f\"{base_url}/chat/completions\", json=data)\n",
"\n",
"print(response.content)"
"response.json()"
]
},
{
