Skip to content

Commit

Permalink
fix: search endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
leoguillaumegouv committed Sep 18, 2024
1 parent 5274bc0 commit 9de4aac
Show file tree
Hide file tree
Showing 11 changed files with 185 additions and 69 deletions.
22 changes: 13 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
# Albert API
![](https://img.shields.io/badge/python-3.12-green) ![](https://img.shields.io/badge/vLLM-v0.5.5-blue) ![](https://img.shields.io/badge/HuggingFace%20Text%20Embeddings%20Inference-1.5-red)

## Fonctionnalités
Albert API est une API open source d'IA générative développée par Etalab. Elle sert de proxy entre des modèles de langage et vos données. Elle agrège les services suivants :
- [vLLM](https://github.com/vllm-project/vllm) pour la gestion des modèles de langage
- [HuggingFace Text Embeddings Inference](https://github.com/huggingface/text-embeddings-inference) pour la génération d'embeddings
- [Qdrant](https://qdrant.tech/) pour la recherche de similarité

### OpenAI conventions


En se basant sur les conventions définies par OpenAI, l'API Albert expose des endpoints qui peuvent être appelés avec le [client officiel python d'OpenAI](https://github.com/openai/openai-python/tree/main). Ce formalisme permet d'intégrer facilement l'API Albert avec des bibliothèques tierces comme [Langchain](https://www.langchain.com/) ou [LlamaIndex](https://www.llamaindex.ai/).

## ⚙️ Fonctionnalités

### Converser avec un modèle de langage (chat memory)

<a target="_blank" href="https://colab.research.google.com/github/etalab-ia/albert-api/blob/main/docs/tutorials/chat_completions.ipynb">
Expand All @@ -23,14 +28,13 @@ Albert API intègre nativement la mémorisation des messages pour les conversati

Grâce à un fichier de configuration (*[config.example.yml](./config.example.yml)*) vous pouvez connecter autant d'API de modèles que vous le souhaitez. L'API Albert se charge de mutualiser l'accès à tous ces modèles dans une unique API. Vous pouvez obtenir la liste des différents modèles accessibles en appelant le endpoint `/v1/models`.

### Fonctionnalités avancées (tools)

Les tools sont une fonctionnalité définie par OpenAI que l'on surcharge dans le cas de l'API Albert pour permettre de configurer des tâches spécifiques comme du RAG ou la génération de résumé. Vous pouvez appeler le endpoint `/tools` pour voir la liste des tools disponibles.

![](./docs/assets/chatcompletion.png)

#### Interroger des documents (RAG)
### Interroger des documents (RAG)

<a target="_blank" href="https://colab.research.google.com/github/etalab-ia/albert-api/blob/main/docs/tutorials/retrival_augmented_generation.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>


## 🧑‍💻 Contribuez au projet

Albert API est un projet open source : pour y contribuer, veuillez lire notre [guide de contribution](./CONTRIBUTING.md).
33 changes: 20 additions & 13 deletions app/endpoints/chunks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Union, Optional

from fastapi import APIRouter, Security
from qdrant_client.http.models import Filter, HasIdCondition

Expand All @@ -11,24 +9,33 @@
router = APIRouter()


# @TODO: add pagination
@router.get("/chunks/{collection}/{chunk}")
@router.post("/chunks/{collection}")
async def chunks(
async def get_chunk(
collection: str,
chunk: Optional[str] = None,
request: Optional[ChunkRequest] = None,
chunk: str,
user: str = Security(check_api_key),
) -> Union[Chunk, Chunks]:
) -> Chunk:
"""
Get a chunk.
Get a single chunk.
"""

vectorstore = VectorStore(clients=clients, user=user)
ids = [chunk] if chunk else dict(request)["chunks"]
ids = [chunk]
filter = Filter(must=[HasIdCondition(has_id=ids)])
chunks = vectorstore.get_chunks(collection_name=collection, filter=filter)
if not request:
return chunks[0]
return chunks[0]


@router.post("/chunks/{collection}")
async def get_chunks(
    collection: str,
    request: ChunkRequest,
    user: str = Security(check_api_key),
) -> Chunks:
    """
    Get multiple chunks.

    Looks up every chunk id listed in the request body within the given
    collection and returns them wrapped in a Chunks payload.
    """
    store = VectorStore(clients=clients, user=user)
    # Restrict the lookup to exactly the ids the caller asked for.
    id_filter = Filter(must=[HasIdCondition(has_id=request.chunks)])
    results = store.get_chunks(collection_name=collection, filter=id_filter)
    return Chunks(data=results)
15 changes: 7 additions & 8 deletions app/endpoints/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,11 @@ async def upload_files(
file.file,
collection_id,
file_id,
ExtraArgs={"ContentType": file.content_type, "Metadata": {"filename": encoded_file_name, "id": file_id}},
ExtraArgs={"ContentType": file.content_type, "Metadata": {"file_name": encoded_file_name, "id": file_id}},
)
except Exception as e:
LOGGER.error(f"store {file_name}:\n{e}")
data.append(Upload(id=file_id, filename=file_name, status="failed"))
data.append(Upload(id=file_id, file_name=file_name, status="failed"))
continue

try:
Expand All @@ -87,7 +87,7 @@ async def upload_files(
except Exception as e:
LOGGER.error(f"convert {file_name} into documents:\n{e}")
clients["files"].delete_object(Bucket=collection_id, Key=file_id)
data.append(Upload(id=file_id, filename=file_name, status="failed"))
data.append(Upload(id=file_id, file_name=file_name, status="failed"))
continue

try:
Expand All @@ -99,15 +99,14 @@ async def upload_files(
except Exception as e:
LOGGER.error(f"create vectors of {file_name}:\n{e}")
clients["files"].delete_object(Bucket=collection_id, Key=file_id)
data.append(Upload(id=file_id, filename=file_name, status="failed"))
data.append(Upload(id=file_id, file_name=file_name, status="failed"))
continue

data.append(Upload(id=file_id, filename=file_name, status="success"))
data.append(Upload(id=file_id, file_name=file_name, status="success"))

return Uploads(data=data)


# @TODO: add pagination
@router.get("/files/{collection}/{file}")
@router.get("/files/{collection}")
async def files(
Expand Down Expand Up @@ -142,8 +141,8 @@ async def files(
id=object["Key"],
object="file",
bytes=object["Size"],
filename=base64.b64decode(object["filename"].encode("ascii")).decode("utf-8"),
chunk_ids=chunk_ids,
file_name=base64.b64decode(object["file_name"].encode("ascii")).decode("utf-8"),
chunks=chunk_ids,
created_at=round(object["LastModified"].timestamp()),
)
data.append(object)
Expand Down
4 changes: 3 additions & 1 deletion app/endpoints/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ async def search(request: SearchRequest, user: str = Security(check_api_key)) ->
"""

vectorstore = VectorStore(clients=clients, user=user)
data = vectorstore.search(prompt=request.prompt, collection_names=request.collections, k=request.k, score_threshold=request.score_threshold)
data = vectorstore.search(
prompt=request.prompt, model=request.model, collection_names=request.collections, k=request.k, score_threshold=request.score_threshold
)

return Chunks(data=data)
32 changes: 10 additions & 22 deletions app/helpers/_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def from_documents(self, documents: List[Document], model: str, collection_name:

collection = self.get_collection_metadata(collection_names=[collection_name])[0]
if collection.model != model:
raise HTTPException(status_code=400, detail=f"Model {collection.model} does not match {model}")
raise HTTPException(status_code=400, detail="Wrong model collection")

for i in range(0, len(documents), self.BATCH_SIZE):
batch = documents[i : i + self.BATCH_SIZE]
Expand Down Expand Up @@ -72,7 +72,7 @@ def search(
collections = self.get_collection_metadata(collection_names=collection_names)
for collection in collections:
if collection.model != model:
raise HTTPException(status_code=400, detail=f"Model {collection.model} does not match {model}")
raise HTTPException(status_code=400, detail="Wrong model collection")

results = self.vectors.search(
collection_name=collection.id,
Expand All @@ -82,18 +82,13 @@ def search(
with_payload=True,
query_filter=filter,
)
for i, result in enumerate(results):
results[i] = result.model_dump()
results[i]["collection"] = collection.name

for result in results:
result.payload["metadata"]["collection"] = collection.name
chunks.extend(results)

# sort by similarity score and get top k
chunks = sorted(chunks, key=lambda x: x["score"], reverse=True)[:k]
chunks = [
Chunk(id=chunk["id"], collection=chunk["collection"], content=chunk["payload"]["page_content"], metadata=chunk["payload"]["metadata"])
for chunk in chunks
]
chunks = sorted(chunks, key=lambda x: x.score, reverse=True)[:k]
chunks = [Chunk(id=chunk.id, content=chunk.payload["page_content"], metadata=chunk.payload["metadata"]) for chunk in chunks]

return chunks

Expand Down Expand Up @@ -251,16 +246,9 @@ def get_chunks(self, collection_name: str, filter: Optional[Filter] = None) -> L
List[Chunk]: A list of Chunk objects containing the retrieved chunks.
"""
collection = self.get_collection_metadata(collection_names=[collection_name], type="all")[0]
chunks = self.vectors.scroll(
collection_name=collection.id,
with_payload=True,
with_vectors=False,
scroll_filter=filter,
limit=100, # @TODO: add pagination
)[0]
chunks = [
Chunk(collection=collection_name, id=chunk.id, metadata=chunk.payload["metadata"], content=chunk.payload["page_content"])
for chunk in chunks
]
chunks = self.vectors.scroll(collection_name=collection.id, with_payload=True, with_vectors=False, scroll_filter=filter)[0]
for chunk in chunks:
chunk.payload["metadata"]["collection"] = collection.name
chunks = [Chunk(id=chunk.id, metadata=chunk.payload["metadata"], content=chunk.payload["page_content"]) for chunk in chunks]

return chunks
1 change: 0 additions & 1 deletion app/schemas/chunks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

class Chunk(BaseModel):
object: Literal["chunk"] = "chunk"
collection: str
id: str
metadata: dict
content: str
Expand Down
6 changes: 3 additions & 3 deletions app/schemas/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ class File(BaseModel):
object: Literal["file"] = "file"
id: UUID
bytes: int
filename: str
chunk_ids: Optional[list] = []
file_name: str
chunks: Optional[list] = []
created_at: int


Expand All @@ -21,7 +21,7 @@ class Files(BaseModel):
class Upload(BaseModel):
object: Literal["upload"] = "upload"
id: UUID
filename: str
file_name: str
status: Literal["success", "failed"] = "success"


Expand Down
11 changes: 9 additions & 2 deletions app/schemas/search.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
from typing import List, Optional

from pydantic import BaseModel
from pydantic import BaseModel, Field, field_validator


class SearchRequest(BaseModel):
prompt: str
model: str
collections: List[str]
k: int
k: int = Field(gt=0, description="Number of results to return")
score_threshold: Optional[float] = None

@field_validator("prompt")
def blank_string(value):
if value.strip() == "":
raise ValueError("Prompt cannot be empty")
return value
2 changes: 1 addition & 1 deletion app/tests/test_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_upload_file(self, args, session, setup):
files["data"] = [File(**file) for file in files["data"]]
assert len(files["data"]) == 1, f"error: number of files ({len(files)})"
files = Files(**files)
assert files.data[0].filename == FILE_NAME, f"error: filename ({files.data[0].filename})"
assert files.data[0].file_name == FILE_NAME, f"error: file name ({files.data[0].file_name})"
assert files.data[0].id == file_id, f"error: file id ({files.data[0].id})"

def test_collection_creation(self, args, session, setup):
Expand Down
111 changes: 111 additions & 0 deletions app/tests/test_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import os
import logging
import pytest
import wget

from app.schemas.chunks import Chunk, Chunks
from app.schemas.config import EMBEDDINGS_MODEL_TYPE


@pytest.fixture(scope="function")
def setup(args, session):
    """Provision a test collection containing one uploaded PDF.

    Deletes any pre-existing "pytest" collection, uploads a downloaded PDF
    into it, and yields (embeddings model id, file id, number of chunks,
    collection name) for the search tests.
    """
    COLLECTION = "pytest"
    FILE_NAME = "pytest.pdf"
    FILE_URL = "http://www.legifrance.gouv.fr/download/file/rxcTl0H4YnnzLkMLiP4x15qORfLSKk_h8QsSb2xnJ8Y=/JOE_TEXTE"

    # Delete the collection if it exists (404 simply means it was already absent)
    response = session.delete(f"{args['base_url']}/collections/{COLLECTION}")
    assert response.status_code in (204, 404), f"error: delete collection ({response.status_code} - {response.text})"

    # Get an embeddings model id from the models endpoint
    response = session.get(f"{args['base_url']}/models")
    response = response.json()["data"]
    EMBEDDINGS_MODEL = [model["id"] for model in response if model["type"] == EMBEDDINGS_MODEL_TYPE][0]
    logging.debug(f"model: {EMBEDDINGS_MODEL}")

    # Download the test file only if it is not already cached locally
    if not os.path.exists(FILE_NAME):
        wget.download(FILE_URL, out=FILE_NAME)

    # Upload the file to the collection. The context manager guarantees the
    # handle is closed (the previous version leaked it, which makes the
    # os.remove below fail on Windows).
    params = {"embeddings_model": EMBEDDINGS_MODEL, "collection": COLLECTION}
    with open(FILE_NAME, "rb") as file_handle:
        files = {"files": (os.path.basename(FILE_NAME), file_handle, "application/pdf")}
        response = session.post(f"{args['base_url']}/files", params=params, files=files, timeout=30)
    assert response.status_code == 200, f"error: upload file ({response.status_code} - {response.text})"

    # Check if the file is uploaded
    response = session.get(f"{args['base_url']}/files/{COLLECTION}", timeout=10)
    assert response.status_code == 200, f"error: retrieve files ({response.status_code} - {response.text})"
    files = response.json()
    assert len(files["data"]) == 1
    assert files["data"][0]["file_name"] == FILE_NAME
    FILE_ID = files["data"][0]["id"]

    CHUNK_IDS = files["data"][0]["chunks"]

    # Fetch the chunks of the file so tests know the maximum usable k
    data = {"chunks": CHUNK_IDS}
    response = session.post(f"{args['base_url']}/chunks/{COLLECTION}", json=data, timeout=10)
    assert response.status_code == 200, f"error: retrieve chunks ({response.status_code} - {response.text})"
    chunks = response.json()
    MAX_K = len(chunks["data"])

    # Remove the local copy so each run starts from a clean slate
    if os.path.exists(FILE_NAME):
        os.remove(FILE_NAME)

    yield EMBEDDINGS_MODEL, FILE_ID, MAX_K, COLLECTION


@pytest.mark.usefixtures("args", "session")
class TestSearch:
    def test_search_response_status_code(self, args, session, setup):
        """POST /search succeeds and returns a well-formed Chunks payload."""

        model_id, _, max_k, collection = setup
        payload = {"prompt": "test query", "model": model_id, "collections": [collection], "k": max_k}
        resp = session.post(f"{args['base_url']}/search", json=payload)
        assert resp.status_code == 200, f"error: search request ({resp.status_code} - {resp.text})"

        parsed = Chunks(**resp.json())
        assert isinstance(parsed, Chunks)
        for item in parsed.data:
            assert isinstance(item, Chunk)

    def test_search_with_score_threshold(self, args, session, setup):
        """A score_threshold parameter is accepted by the endpoint."""

        model_id, _, max_k, collection = setup
        payload = {"prompt": "test query", "model": model_id, "collections": [collection], "k": max_k, "score_threshold": 0.5}
        resp = session.post(f"{args['base_url']}/search", json=payload)
        assert resp.status_code == 200

    def test_search_invalid_collection(self, args, session, setup):
        """Searching an unknown collection yields 404."""

        model_id, _, max_k, _ = setup
        payload = {"prompt": "test query", "model": model_id, "collections": ["non_existent_collection"], "k": max_k}
        resp = session.post(f"{args['base_url']}/search", json=payload)
        assert resp.status_code == 404

    def test_search_invalid_k(self, args, session, setup):
        """k must be strictly positive; k=0 is rejected with 422."""

        model_id, _, _, collection = setup
        payload = {"prompt": "test query", "model": model_id, "collections": [collection], "k": 0}
        resp = session.post(f"{args['base_url']}/search", json=payload)
        assert resp.status_code == 422

    def test_search_empty_prompt(self, args, session, setup):
        """A blank prompt is rejected with 422 by the request validator."""

        model_id, _, max_k, collection = setup
        payload = {"prompt": "", "model": model_id, "collections": [collection], "k": max_k}
        resp = session.post(f"{args['base_url']}/search", json=payload)
        assert resp.status_code == 422

    def test_search_invalid_model(self, args, session, setup):
        """Searching with an unknown model yields 404."""

        _, _, max_k, collection = setup
        payload = {"prompt": "test query", "model": "non_existent_model", "collections": [collection], "k": max_k}
        resp = session.post(f"{args['base_url']}/search", json=payload)
        assert resp.status_code == 404
Loading

0 comments on commit 9de4aac

Please sign in to comment.