From 22da6ed30f870d882562ca04a75ea7c1158c9baa Mon Sep 17 00:00:00 2001 From: leoguillaume Date: Fri, 6 Dec 2024 15:36:48 +0100 Subject: [PATCH] fix: internet integration and metrics exception handler --- app/helpers/_metricsmiddleware.py | 21 ++++++++++++++++++- .../searchclients/_qdrantsearchclient.py | 9 +++++++- app/helpers/searchclients/_searchclient.py | 10 +++++++-- app/schemas/search.py | 2 +- app/tests/test_chat.py | 14 ------------- 5 files changed, 37 insertions(+), 19 deletions(-) diff --git a/app/helpers/_metricsmiddleware.py b/app/helpers/_metricsmiddleware.py index 5d3c3251..015d2488 100644 --- a/app/helpers/_metricsmiddleware.py +++ b/app/helpers/_metricsmiddleware.py @@ -26,7 +26,26 @@ async def dispatch(self, request: Request, call_next) -> Response: if not content_type.startswith("multipart/form-data"): body = await request.body() body = body.decode(encoding="utf-8") - model = json.loads(body).get("model") if body else None + try: + model = json.loads(body).get("model") + except json.JSONDecodeError as e: + return Response( + status_code=422, + content=json.dumps( + { + "detail": [ + { + "type": "json_invalid", + "loc": ["body", e.pos], + "msg": "JSON decode error", + "input": {}, + "ctx": {"error": str(e.msg)}, + } + ] + } + ), + media_type="application/json", + ) if authorization and authorization.startswith("Bearer "): user_id = AuthenticationClient._api_key_to_user_id(input=authorization.split(sep=" ")[1]) diff --git a/app/helpers/searchclients/_qdrantsearchclient.py b/app/helpers/searchclients/_qdrantsearchclient.py index 8916f757..8962d4b5 100644 --- a/app/helpers/searchclients/_qdrantsearchclient.py +++ b/app/helpers/searchclients/_qdrantsearchclient.py @@ -35,6 +35,7 @@ HYBRID_SEARCH_TYPE, PUBLIC_COLLECTION_TYPE, SEMANTIC_SEARCH_TYPE, + PRIVATE_COLLECTION_TYPE, ) @@ -208,7 +209,13 @@ def get_collections( return collections def create_collection( - self, collection_id: str, collection_name: str, collection_model: str, collection_type: str, collection_description: str, user: User + self, + collection_id: str, + collection_name: str, + collection_model: str, + user: User, + collection_type: str = PRIVATE_COLLECTION_TYPE, + collection_description: Optional[str] = None, ) -> Collection: """ See SearchClient.create_collection diff --git a/app/helpers/searchclients/_searchclient.py b/app/helpers/searchclients/_searchclient.py index 0f89ffc7..f755df13 100644 --- a/app/helpers/searchclients/_searchclient.py +++ b/app/helpers/searchclients/_searchclient.py @@ -8,7 +8,7 @@ from app.schemas.documents import Document from app.schemas.search import Search from app.schemas.security import User -from app.utils.variables import HYBRID_SEARCH_TYPE, LEXICAL_SEARCH_TYPE, SEMANTIC_SEARCH_TYPE +from app.utils.variables import HYBRID_SEARCH_TYPE, LEXICAL_SEARCH_TYPE, SEMANTIC_SEARCH_TYPE, PRIVATE_COLLECTION_TYPE def to_camel_case(chaine): @@ -78,7 +78,13 @@ def get_collections(self, collection_ids: List[str], user: User) -> List[Collect @abstractmethod def create_collection( - self, collection_id: str, collection_name: str, collection_model: str, collection_type: str, collection_description: str, user: User + self, + collection_id: str, + collection_name: str, + collection_model: str, + user: User, + collection_type: str = PRIVATE_COLLECTION_TYPE, + collection_description: Optional[str] = None, ) -> Collection: """ Create a collection, if collection already exists, return the collection id. diff --git a/app/schemas/search.py b/app/schemas/search.py index 4ed3f345..f65f58f1 100644 --- a/app/schemas/search.py +++ b/app/schemas/search.py @@ -11,7 +11,7 @@ class SearchRequest(BaseModel): prompt: str collections: List[Union[UUID, Literal[INTERNET_COLLECTION_DISPLAY_ID]]] rff_k: int = Field(default=20, description="k constant in RFF algorithm") - k: int = Field(gt=0, description="Number of results to return") + k: int = Field(gt=0, default=4, description="Number of results to return") method: Literal[HYBRID_SEARCH_TYPE, LEXICAL_SEARCH_TYPE, SEMANTIC_SEARCH_TYPE] = Field(default=SEMANTIC_SEARCH_TYPE) score_threshold: Optional[float] = Field( 0.0, ge=0.0, le=1.0, description="Score of cosine similarity threshold for filtering results, only available for semantic search method." diff --git a/app/tests/test_chat.py b/app/tests/test_chat.py index 5b247db9..f113658d 100644 --- a/app/tests/test_chat.py +++ b/app/tests/test_chat.py @@ -77,20 +77,6 @@ def test_chat_completions_unknown_params(self, args, session_user, setup): response = session_user.post(f"{args['base_url']}/chat/completions", json=params) assert response.status_code == 200, f"error: retrieve chat completions ({response.status_code})" - def test_chat_completions_max_tokens_too_large(self, args, session_user, setup): - MODEL_ID, MAX_CONTEXT_LENGTH = setup - - prompt = "test" - params = { - "model": MODEL_ID, - "messages": [{"role": "user", "content": prompt}], - "stream": False, - "n": 1, - "max_tokens": MAX_CONTEXT_LENGTH + 100, - } - response = session_user.post(f"{args['base_url']}/chat/completions", json=params) - assert response.status_code == 400, f"error: retrieve chat completions ({response.status_code})" - def test_chat_completions_context_too_large(self, args, session_user, setup): MODEL_ID, MAX_CONTEXT_LENGTH = setup