Skip to content

Commit

Permalink
fix: internet integration and metrics exception handler
Browse files Browse the repository at this point in the history
  • Loading branch information
leoguillaumegouv committed Dec 6, 2024
1 parent e85f03e commit 22da6ed
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 19 deletions.
21 changes: 20 additions & 1 deletion app/helpers/_metricsmiddleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,26 @@ async def dispatch(self, request: Request, call_next) -> Response:
if not content_type.startswith("multipart/form-data"):
body = await request.body()
body = body.decode(encoding="utf-8")
model = json.loads(body).get("model") if body else None
try:
model = json.loads(body).get("model")
except json.JSONDecodeError as e:
return Response(
status_code=422,
content=json.dumps(
{
"detail": [
{
"type": "json_invalid",
"loc": ["body", e.pos],
"msg": "JSON decode error",
"input": {},
"ctx": {"error": str(e.msg)},
}
]
}
),
media_type="application/json",
)

if authorization and authorization.startswith("Bearer "):
user_id = AuthenticationClient._api_key_to_user_id(input=authorization.split(sep=" ")[1])
Expand Down
9 changes: 8 additions & 1 deletion app/helpers/searchclients/_qdrantsearchclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
HYBRID_SEARCH_TYPE,
PUBLIC_COLLECTION_TYPE,
SEMANTIC_SEARCH_TYPE,
PRIVATE_COLLECTION_TYPE,
)


Expand Down Expand Up @@ -208,7 +209,13 @@ def get_collections(
return collections

def create_collection(
self, collection_id: str, collection_name: str, collection_model: str, collection_type: str, collection_description: str, user: User
self,
collection_id: str,
collection_name: str,
collection_model: str,
user: User,
collection_type: str = PRIVATE_COLLECTION_TYPE,
collection_description: Optional[str] = None,
) -> Collection:
"""
See SearchClient.create_collection
Expand Down
10 changes: 8 additions & 2 deletions app/helpers/searchclients/_searchclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from app.schemas.documents import Document
from app.schemas.search import Search
from app.schemas.security import User
from app.utils.variables import HYBRID_SEARCH_TYPE, LEXICAL_SEARCH_TYPE, SEMANTIC_SEARCH_TYPE
from app.utils.variables import HYBRID_SEARCH_TYPE, LEXICAL_SEARCH_TYPE, SEMANTIC_SEARCH_TYPE, PRIVATE_COLLECTION_TYPE


def to_camel_case(chaine):
Expand Down Expand Up @@ -78,7 +78,13 @@ def get_collections(self, collection_ids: List[str], user: User) -> List[Collect

@abstractmethod
def create_collection(
self, collection_id: str, collection_name: str, collection_model: str, collection_type: str, collection_description: str, user: User
self,
collection_id: str,
collection_name: str,
collection_model: str,
user: User,
collection_type: str = PRIVATE_COLLECTION_TYPE,
collection_description: Optional[str] = None,
) -> Collection:
"""
Create a collection, if collection already exists, return the collection id.
Expand Down
2 changes: 1 addition & 1 deletion app/schemas/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class SearchRequest(BaseModel):
prompt: str
collections: List[Union[UUID, Literal[INTERNET_COLLECTION_DISPLAY_ID]]]
rff_k: int = Field(default=20, description="k constant in RFF algorithm")
k: int = Field(gt=0, description="Number of results to return")
k: int = Field(gt=0, default=4, description="Number of results to return")
method: Literal[HYBRID_SEARCH_TYPE, LEXICAL_SEARCH_TYPE, SEMANTIC_SEARCH_TYPE] = Field(default=SEMANTIC_SEARCH_TYPE)
score_threshold: Optional[float] = Field(
0.0, ge=0.0, le=1.0, description="Score of cosine similarity threshold for filtering results, only available for semantic search method."
Expand Down
14 changes: 0 additions & 14 deletions app/tests/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,20 +77,6 @@ def test_chat_completions_unknown_params(self, args, session_user, setup):
response = session_user.post(f"{args['base_url']}/chat/completions", json=params)
assert response.status_code == 200, f"error: retrieve chat completions ({response.status_code})"

# Checks that POST /chat/completions answers HTTP 400 when max_tokens is set above MAX_CONTEXT_LENGTH.
def test_chat_completions_max_tokens_too_large(self, args, session_user, setup):
# The setup fixture unpacks into a model id and a maximum context length
# (presumably the limit of that same model — confirm against the fixture).
MODEL_ID, MAX_CONTEXT_LENGTH = setup

prompt = "test"
params = {
"model": MODEL_ID,
"messages": [{"role": "user", "content": prompt}],
"stream": False,
"n": 1,
# Deliberately 100 above MAX_CONTEXT_LENGTH so the value is out of range.
"max_tokens": MAX_CONTEXT_LENGTH + 100,
}
response = session_user.post(f"{args['base_url']}/chat/completions", json=params)
# The oversized max_tokens value must be rejected with 400; the failure message carries the actual status code.
assert response.status_code == 400, f"error: retrieve chat completions ({response.status_code})"

def test_chat_completions_context_too_large(self, args, session_user, setup):
MODEL_ID, MAX_CONTEXT_LENGTH = setup

Expand Down

0 comments on commit 22da6ed

Please sign in to comment.