-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
5274bc0
commit 9de4aac
Showing
11 changed files
with
185 additions
and
69 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,17 @@ | ||
from typing import List, Optional | ||
|
||
from pydantic import BaseModel | ||
from pydantic import BaseModel, Field, field_validator | ||
|
||
|
||
class SearchRequest(BaseModel):
    """Parameters of a search request.

    Validation performed by pydantic on construction:
      * ``k`` must be strictly greater than 0.
      * ``prompt`` must contain at least one non-whitespace character.
    """

    prompt: str
    model: str
    collections: List[str]
    # Single declaration of ``k``: the constrained field supersedes the bare ``k: int``.
    k: int = Field(gt=0, description="Number of results to return")
    score_threshold: Optional[float] = None

    @field_validator("prompt")
    @classmethod
    def blank_string(cls, value: str) -> str:
        """Reject a prompt that is empty or made only of whitespace.

        Raises:
            ValueError: if ``value`` is blank once stripped.
        """
        if not value.strip():
            raise ValueError("Prompt cannot be empty")
        return value
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
import os | ||
import logging | ||
import pytest | ||
import wget | ||
|
||
from app.schemas.chunks import Chunk, Chunks | ||
from app.schemas.config import EMBEDDINGS_MODEL_TYPE | ||
|
||
|
||
@pytest.fixture(scope="function")
def setup(args, session):
    """Prepare a collection containing one uploaded PDF for the search tests.

    Steps: delete any existing ``pytest`` collection, pick the first model
    returned by ``/models`` whose type is ``EMBEDDINGS_MODEL_TYPE``, download
    the sample PDF if it is not already on disk, upload it, and count the
    chunks returned for it.

    Args:
        args: mapping providing ``base_url`` (supplied by another fixture).
        session: HTTP session used for every request (supplied by another fixture).

    Yields:
        tuple: ``(EMBEDDINGS_MODEL, FILE_ID, MAX_K, COLLECTION)`` where
        ``MAX_K`` is the number of chunks returned for the uploaded file.
    """
    COLLECTION = "pytest"
    FILE_NAME = "pytest.pdf"
    FILE_URL = "http://www.legifrance.gouv.fr/download/file/rxcTl0H4YnnzLkMLiP4x15qORfLSKk_h8QsSb2xnJ8Y=/JOE_TEXTE"
    base_url = args["base_url"]

    # Delete the collection if it exists
    response = session.delete(f"{base_url}/collections/{COLLECTION}")
    assert response.status_code in (204, 404), f"error: delete collection ({response.status_code} - {response.text})"

    # Get an embedding model
    response = session.get(f"{base_url}/models")
    models = response.json()["data"]
    embeddings_models = [model["id"] for model in models if model["type"] == EMBEDDINGS_MODEL_TYPE]
    # Fail with an explicit message instead of a bare IndexError.
    assert embeddings_models, "error: no embeddings model available"
    EMBEDDINGS_MODEL = embeddings_models[0]
    logging.debug(f"model: {EMBEDDINGS_MODEL}")

    # Download a file
    if not os.path.exists(FILE_NAME):
        wget.download(FILE_URL, out=FILE_NAME)

    # Upload the file to the collection
    params = {"embeddings_model": EMBEDDINGS_MODEL, "collection": COLLECTION}
    # The context manager closes the handle once the upload is done, so it is
    # not leaked and the file can be removed below on every platform.
    with open(FILE_NAME, "rb") as pdf:
        files = {"files": (os.path.basename(FILE_NAME), pdf, "application/pdf")}
        response = session.post(f"{base_url}/files", params=params, files=files, timeout=30)
    assert response.status_code == 200, f"error: upload file ({response.status_code} - {response.text})"

    # Check if the file is uploaded
    response = session.get(f"{base_url}/files/{COLLECTION}", timeout=10)
    assert response.status_code == 200, f"error: retrieve files ({response.status_code} - {response.text})"
    files = response.json()
    assert len(files["data"]) == 1
    assert files["data"][0]["file_name"] == FILE_NAME
    FILE_ID = files["data"][0]["id"]

    CHUNK_IDS = files["data"][0]["chunks"]

    # Get chunks of the file
    data = {"chunks": CHUNK_IDS}
    response = session.post(f"{base_url}/chunks/{COLLECTION}", json=data, timeout=10)
    assert response.status_code == 200, f"error: retrieve chunks ({response.status_code} - {response.text})"
    chunks = response.json()
    MAX_K = len(chunks["data"])

    # The local copy is no longer needed once the chunks have been counted.
    if os.path.exists(FILE_NAME):
        os.remove(FILE_NAME)

    yield EMBEDDINGS_MODEL, FILE_ID, MAX_K, COLLECTION
|
||
|
||
@pytest.mark.usefixtures("args", "session")
class TestSearch:
    """Tests exercising the POST /search endpoint against the prepared collection."""

    def test_search_response_status_code(self, args, session, setup):
        """A valid search returns 200 and a payload that parses into Chunks."""
        embeddings_model, _, max_k, collection = setup
        search_url = f"{args['base_url']}/search"
        payload = dict(prompt="test query", model=embeddings_model, collections=[collection], k=max_k)

        reply = session.post(search_url, json=payload)
        assert reply.status_code == 200, f"error: search request ({reply.status_code} - {reply.text})"

        parsed = Chunks(**reply.json())
        assert isinstance(parsed, Chunks)
        assert all(isinstance(item, Chunk) for item in parsed.data)

    def test_search_with_score_threshold(self, args, session, setup):
        """A search carrying a score threshold is accepted with 200."""
        embeddings_model, _, max_k, collection = setup
        search_url = f"{args['base_url']}/search"
        payload = dict(
            prompt="test query",
            model=embeddings_model,
            collections=[collection],
            k=max_k,
            score_threshold=0.5,
        )

        reply = session.post(search_url, json=payload)
        assert reply.status_code == 200

    def test_search_invalid_collection(self, args, session, setup):
        """Searching a collection that does not exist yields 404."""
        embeddings_model, _, max_k, _ = setup
        search_url = f"{args['base_url']}/search"
        payload = dict(prompt="test query", model=embeddings_model, collections=["non_existent_collection"], k=max_k)

        reply = session.post(search_url, json=payload)
        assert reply.status_code == 404

    def test_search_invalid_k(self, args, session, setup):
        """A k of 0 is rejected with 422."""
        embeddings_model, _, _, collection = setup
        search_url = f"{args['base_url']}/search"
        payload = dict(prompt="test query", model=embeddings_model, collections=[collection], k=0)

        reply = session.post(search_url, json=payload)
        assert reply.status_code == 422

    def test_search_empty_prompt(self, args, session, setup):
        """An empty prompt is rejected with 422."""
        embeddings_model, _, max_k, collection = setup
        search_url = f"{args['base_url']}/search"
        payload = dict(prompt="", model=embeddings_model, collections=[collection], k=max_k)

        reply = session.post(search_url, json=payload)
        assert reply.status_code == 422

    def test_search_invalid_model(self, args, session, setup):
        """Searching with an unknown model yields 404."""
        _, _, max_k, collection = setup
        search_url = f"{args['base_url']}/search"
        payload = dict(prompt="test query", model="non_existent_model", collections=[collection], k=max_k)

        reply = session.post(search_url, json=payload)
        assert reply.status_code == 404
Oops, something went wrong.