From d2d8aeade17643bf559bd17c75af43a2da12a6af Mon Sep 17 00:00:00 2001
From: leoguillaume
Date: Thu, 19 Sep 2024 18:20:24 +0200
Subject: [PATCH] feat: remove tools

---
 app/endpoints/chat.py           |  46 +----------
 app/endpoints/collections.py    |   1 -
 app/endpoints/search.py         |   7 +-
 app/endpoints/tools.py          |  26 -------
 app/helpers/_vectorstore.py     |  10 ++-
 app/main.py                     |   3 +-
 app/schemas/chat.py             |  16 ++--
 app/schemas/search.py           |  14 +++-
 app/schemas/tools.py            |  19 -----
 app/tests/test_tools.py         |  21 ------
 app/tests/tools/test_baserag.py | 128 --------------------------------
 app/tools/__init__.py           |   4 -
 app/tools/_baserag.py           |  48 ------------
 app/tools/_fewshots.py          |  64 ----------------
 14 files changed, 40 insertions(+), 367 deletions(-)
 delete mode 100644 app/endpoints/tools.py
 delete mode 100644 app/schemas/tools.py
 delete mode 100644 app/tests/test_tools.py
 delete mode 100644 app/tests/tools/test_baserag.py
 delete mode 100644 app/tools/__init__.py
 delete mode 100644 app/tools/_baserag.py
 delete mode 100644 app/tools/_fewshots.py

diff --git a/app/endpoints/chat.py b/app/endpoints/chat.py
index 6d227ec..e05d80e 100644
--- a/app/endpoints/chat.py
+++ b/app/endpoints/chat.py
@@ -1,22 +1,17 @@
-import httpx
-import json
 from typing import Union
 
-from fastapi import APIRouter, Security, HTTPException
+from fastapi import APIRouter, HTTPException, Security
 from fastapi.responses import StreamingResponse
+import httpx
 
-from app.schemas.chat import ChatCompletionRequest, ChatCompletion, ChatCompletionChunk
-from app.utils.security import check_api_key
-from app.utils.lifespan import clients
-from app.utils.config import LOGGER
-from app.tools import *
-from app.tools import __all__ as tools_list
+from app.schemas.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionRequest
 from app.schemas.config import LANGUAGE_MODEL_TYPE
+from app.utils.lifespan import clients
+from app.utils.security import check_api_key
 
 router = APIRouter()
 
 
-# @TODO: remove tooling from here
 @router.post("/chat/completions")
 async def chat_completions(request: ChatCompletionRequest, user: str = Security(check_api_key)) -> Union[ChatCompletion, ChatCompletionChunk]:
     """Completion API similar to OpenAI's API.
@@ -34,50 +29,19 @@ async def chat_completions(request: ChatCompletionRequest, user: str = Security(
     if not client.check_context_length(model=request["model"], messages=request["messages"]):
         raise HTTPException(status_code=400, detail="Context length too large")
 
-    # tool call
-    metadata = list()
-    tools = request.get("tools")
-    if tools:
-        for tool in tools:
-            if tool["function"]["name"] not in tools_list:
-                raise HTTPException(status_code=404, detail="Tool not found")
-            func = globals()[tool["function"]["name"]](clients=clients)
-            params = request | tool["function"]["parameters"]
-            params["user"] = user
-            LOGGER.debug(f"params: {params}")
-            try:
-                tool_output = await func.get_prompt(**params)
-            except Exception as e:
-                raise HTTPException(status_code=400, detail=f"tool error {e}")
-            metadata.append({tool["function"]["name"]: tool_output.model_dump()})
-            request["messages"] = [{"role": "user", "content": tool_output.prompt}]
-        request.pop("tools")
-
-        if not client.check_context_length(model=request["model"], messages=request["messages"]):
-            raise HTTPException(status_code=400, detail="Context length too large after tool call")
-
     # non stream case
     if not request["stream"]:
         async_client = httpx.AsyncClient(timeout=20)
         response = await async_client.request(method="POST", url=url, headers=headers, json=request)
         response.raise_for_status()
         data = response.json()
-        data["metadata"] = metadata
         return ChatCompletion(**data)
 
     # stream case
     async def forward_stream(url: str, headers: dict, request: dict):
         async with httpx.AsyncClient(timeout=20) as async_client:
             async with async_client.stream(method="POST", url=url, headers=headers, json=request) as response:
-                i = 0
                 async for chunk in response.aiter_raw():
-                    if i == 0:
-                        chunks = chunk.decode("utf-8").split("\n\n")
-                        chunk = json.loads(chunks[0].lstrip("data: "))
-                        chunk["metadata"] = metadata
-                        chunks[0] = f"data: {json.dumps(chunk)}"
-                        chunk = "\n\n".join(chunks).encode("utf-8")
-                        i = 1
                     yield chunk
 
     return StreamingResponse(forward_stream(url, headers, request), media_type="text/event-stream")
diff --git a/app/endpoints/collections.py b/app/endpoints/collections.py
index 0efb531..3533cfc 100644
--- a/app/endpoints/collections.py
+++ b/app/endpoints/collections.py
@@ -12,7 +12,6 @@
 
 router = APIRouter()
 
-# @TODO: remove get one collection and a /collections/search to similarity search (remove /tools)
 @router.get("/collections/{collection}")
 @router.get("/collections")
 async def get_collections(collection: Optional[str] = None, user: str = Security(check_api_key)) -> Union[Collection, Collections]:
diff --git a/app/endpoints/search.py b/app/endpoints/search.py
index ac4fa71..a541760 100644
--- a/app/endpoints/search.py
+++ b/app/endpoints/search.py
@@ -1,8 +1,7 @@
 from fastapi import APIRouter, Security
 
 from app.helpers import VectorStore
-from app.schemas.chunks import Chunks
-from app.schemas.search import SearchRequest
+from app.schemas.search import SearchRequest, Searches
 from app.utils.lifespan import clients
 from app.utils.security import check_api_key
 
@@ -10,7 +9,7 @@
 
 
 @router.post("/search")
-async def search(request: SearchRequest, user: str = Security(check_api_key)) -> Chunks:
+async def search(request: SearchRequest, user: str = Security(check_api_key)) -> Searches:
     """
     Similarity search for chunks in the vector store.
 
@@ -27,4 +26,4 @@ async def search(request: SearchRequest, user: str = Security(check_api_key)) ->
         prompt=request.prompt, model=request.model, collection_names=request.collections, k=request.k, score_threshold=request.score_threshold
     )
 
-    return Chunks(data=data)
+    return Searches(data=data)
diff --git a/app/endpoints/tools.py b/app/endpoints/tools.py
deleted file mode 100644
index 26748a8..0000000
--- a/app/endpoints/tools.py
+++ /dev/null
@@ -1,26 +0,0 @@
-from fastapi import APIRouter, Security
-
-from app.schemas.tools import Tools
-from app.utils.security import check_api_key
-from app.tools import *
-from app.tools import __all__ as tools_list
-
-router = APIRouter()
-
-
-@router.get("/tools")
-def tools(user: str = Security(check_api_key)) -> Tools:
-    """
-    Get list a availables tools. Only RAG functions are currenty supported.
-    """
-    data = [
-        {
-            "id": globals()[tool].__name__,
-            "description": globals()[tool].__doc__.strip(),
-            "object": "tool",
-        }
-        for tool in tools_list
-    ]
-    response = {"object": "list", "data": data}
-
-    return Tools(**response)
diff --git a/app/helpers/_vectorstore.py b/app/helpers/_vectorstore.py
index 83f3ff0..5cd72b5 100644
--- a/app/helpers/_vectorstore.py
+++ b/app/helpers/_vectorstore.py
@@ -9,6 +9,7 @@
 from app.schemas.chunks import Chunk
 from app.schemas.collections import CollectionMetadata
 from app.schemas.config import EMBEDDINGS_MODEL_TYPE, METADATA_COLLECTION, PRIVATE_COLLECTION_TYPE, PUBLIC_COLLECTION_TYPE
+from app.schemas.search import Search
 
 
 class VectorStore:
@@ -64,7 +65,7 @@ def search(
         k: Optional[int] = 4,
         score_threshold: Optional[float] = None,
         filter: Optional[Filter] = None,
-    ) -> List[Chunk]:
+    ) -> List[Search]:
         response = self.models[model].embeddings.create(input=[prompt], model=model)
         vector = response.data[0].embedding
 
@@ -88,9 +89,12 @@
         # sort by similarity score and get top k
         chunks = sorted(chunks, key=lambda x: x.score, reverse=True)[:k]
 
-        chunks = [Chunk(id=chunk.id, content=chunk.payload["page_content"], metadata=chunk.payload["metadata"]) for chunk in chunks]
+        data = [
+            Search(score=chunk.score, chunk=Chunk(id=chunk.id, content=chunk.payload["page_content"], metadata=chunk.payload["metadata"]))
+            for chunk in chunks
+        ]
 
-        return chunks
+        return data
 
     def get_collection_metadata(self, collection_names: List[str] = [], type: str = "all", errors: str = "raise") -> List[CollectionMetadata]:
         """
diff --git a/app/main.py b/app/main.py
index 46de908..894852f 100644
--- a/app/main.py
+++ b/app/main.py
@@ -1,6 +1,6 @@
 from fastapi import FastAPI, Response, Security
 
-from app.endpoints import chat, chunks, collections, completions, embeddings, files, models, search, tools
+from app.endpoints import chat, chunks, collections, completions, embeddings, files, models, search
 from app.utils.config import APP_CONTACT_EMAIL, APP_CONTACT_URL, APP_DESCRIPTION, APP_VERSION
 from app.utils.lifespan import lifespan
 from app.utils.security import check_api_key
@@ -32,4 +32,3 @@ def health(user: str = Security(check_api_key)):
 app.include_router(chunks.router, tags=["Chunks"], prefix="/v1")
 app.include_router(files.router, tags=["Files"], prefix="/v1")
 app.include_router(search.router, tags=["Search"], prefix="/v1")
-app.include_router(tools.router, tags=["Tools"], prefix="/v1")
diff --git a/app/schemas/chat.py b/app/schemas/chat.py
index df7c27a..976a5e3 100644
--- a/app/schemas/chat.py
+++ b/app/schemas/chat.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Literal, Optional, Union
+from typing import List, Literal, Optional, Union
 
 from openai.types.chat import (
     ChatCompletion,
@@ -9,10 +9,9 @@
 )
 from pydantic import BaseModel, Field
 
-from app.schemas.tools import ToolOutput
-
 
 class ChatCompletionRequest(BaseModel):
+    # See https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/protocol.py
     messages: List[ChatCompletionMessageParam]
     model: str
     stream: Optional[Literal[True, False]] = False
@@ -27,11 +26,18 @@
     stop: Union[Optional[str], List[str]] = Field(default_factory=list)
     tool_choice: Optional[Union[Literal["none"], ChatCompletionToolChoiceOptionParam]] = "none"
     tools: List[ChatCompletionToolParam] = None
+    user: Optional[str] = None
+    best_of: Optional[int] = None
+    top_k: int = -1
+    min_p: float = 0.0
+
+    class Config:
+        extra = "allow"
 
 
 class ChatCompletion(ChatCompletion):
-    metadata: Optional[List[Dict[str, ToolOutput]]] = []
+    pass
 
 
 class ChatCompletionChunk(ChatCompletionChunk):
-    metadata: Optional[List[Dict[str, ToolOutput]]] = []
+    pass
diff --git a/app/schemas/search.py b/app/schemas/search.py
index d44e1fb..4f20bd5 100644
--- a/app/schemas/search.py
+++ b/app/schemas/search.py
@@ -1,7 +1,9 @@
-from typing import List, Optional
+from typing import List, Literal, Optional
 
 from pydantic import BaseModel, Field, field_validator
 
+from app.schemas.chunks import Chunk
+
 
 class SearchRequest(BaseModel):
     prompt: str
@@ -15,3 +17,13 @@ def blank_string(value):
         if value.strip() == "":
             raise ValueError("Prompt cannot be empty")
         return value
+
+
+class Search(BaseModel):
+    score: float
+    chunk: Chunk
+
+
+class Searches(BaseModel):
+    object: Literal["list"] = "list"
+    data: List[Search]
diff --git a/app/schemas/tools.py b/app/schemas/tools.py
deleted file mode 100644
index a0ca514..0000000
--- a/app/schemas/tools.py
+++ /dev/null
@@ -1,19 +0,0 @@
-from typing import Literal, List
-
-from pydantic import BaseModel
-
-
-class Tool(BaseModel):
-    object: Literal["tool"] = "tool"
-    id: str
-    description: str
-
-
-class Tools(BaseModel):
-    object: Literal["list"] = "list"
-    data: List[Tool]
-
-
-class ToolOutput(BaseModel):
-    prompt: str
-    metadata: dict
diff --git a/app/tests/test_tools.py b/app/tests/test_tools.py
deleted file mode 100644
index 989c43a..0000000
--- a/app/tests/test_tools.py
+++ /dev/null
@@ -1,21 +0,0 @@
-import pytest
-
-from app.schemas.tools import Tools, Tool
-
-
-@pytest.mark.usefixtures("args", "session")
-class TestTools:
-    def test_get_tools_response_status_code(self, args, session):
-        """Test the GET /tools response status code."""
-        response = session.get(f"{args['base_url']}/tools")
-        assert response.status_code == 200, f"error: retrieve tools ({response.status_code})"
-
-    def test_get_tools_response_schemas(self, args, session):
-        """Test the GET /tools response schemas."""
-        response = session.get(f"{args['base_url']}/tools")
-        response_json = response.json()
-
-        tools = Tools(data=[Tool(**tool) for tool in response_json["data"]])
-
-        assert isinstance(tools, Tools)
-        assert all(isinstance(tool, Tool) for tool in tools.data)
diff --git a/app/tests/tools/test_baserag.py b/app/tests/tools/test_baserag.py
deleted file mode 100644
index 468ad4b..0000000
--- a/app/tests/tools/test_baserag.py
+++ /dev/null
@@ -1,128 +0,0 @@
-import os
-import logging
-
-import pytest
-import wget
-
-from app.schemas.config import EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE
-from app.schemas.models import Models, Model
-from app.schemas.chat import ChatCompletion
-
-
-@pytest.mark.usefixtures("args", "session")
-class TestTools:
-    FILE_URL = "https://www.legifrance.gouv.fr/download/file/rxcTl0H4YnnzLkMLiP4x15qORfLSKk_h8QsSb2xnJ8Y=/JOE_TEXTE"
-    FILE_PATH = "test.pdf"
-    COLLECTION = "pytest"
-
-    def test_baserag_tool_response_status_code(self, args, session):
-        """Setup for other tests."""
-
-        # delete collection if exists
-        response = session.get(f"{args["base_url"]}/collections", params={"collection": self.COLLECTION}, timeout=10)
-        if response.status_code == 200:
-            response = session.delete(
-                f"{args["base_url"]}/collections",
-                params={"collection": self.COLLECTION},
-                timeout=10,
-            )
-            assert response.status_code == 204, f"error: delete collection ({response.status_code})"
-            logging.info(f"collection {self.COLLECTION} deleted")
-
-        # download file
-        if not os.path.exists(self.FILE_PATH):
-            wget.download(self.FILE_URL, out=self.FILE_PATH)
-            logging.info(f"file {self.FILE_PATH} downloaded")
-
-        # get a embeddings_model
-        response = session.get(f"{args["base_url"]}/models", timeout=10)
-        assert response.status_code == 200, f"error: retrieve models ({response.status_code})"
-        models = response.json()
-        models = Models(data=[Model(**model) for model in models["data"]])
-        self.EMBEDDINGS_MODEL = [model for model in models.data if model.type == EMBEDDINGS_MODEL_TYPE][0].id
-        logging.debug(f"embeddings_model: {self.EMBEDDINGS_MODEL}")
-
-        # upload file
-        params = {
-            "embeddings_model": self.EMBEDDINGS_MODEL,
-            "collection": self.COLLECTION,
-        }
-        files = {
-            "files": (
-                os.path.basename(self.FILE_PATH),
-                open(self.FILE_PATH, "rb"),
-                "application/pdf",
-            )
-        }
-        response = session.post(f"{args["base_url"]}/files", params=params, files=files, timeout=30)
-        assert response.status_code == 200, f"error: upload file ({response.status_code})"
-
-        # get a language model
-        response = session.get(f"{args["base_url"]}/models", timeout=10)
-        assert response.status_code == 200, f"error: retrieve models ({response.status_code})"
-        models = response.json()
-        models = Models(data=[Model(**model) for model in models["data"]])
-        self.LANGUAGE_MODEL = [model for model in models.data if model.type == LANGUAGE_MODEL_TYPE][0].id
-        logging.debug(f"language_model: {self.LANGUAGE_MODEL}")
-
-        # test baserag tool
-        data = {
-            "model": self.LANGUAGE_MODEL,
-            "messages": [{"role": "user", "content": "Qui est Ulrich Tan ?"}],
-            "stream": False,
-            "n": 1,
-            "tools": [
-                {
-                    "function": {
-                        "name": "BaseRAG",
-                        "parameters": {
-                            "embeddings_model": self.EMBEDDINGS_MODEL,
-                            "collections": [self.COLLECTION],
-                            "k": 2,
-                        },
-                    },
-                    "type": "function",
-                }
-            ],
-        }
-        response = session.post(f"{args["base_url"]}/chat/completions", json=data, timeout=30)
-        assert response.status_code == 200, f"error: chat completions ({response.status_code})"
-        response = response.json()
-        response = ChatCompletion(**response)
-
-        assert response.choices[0].message.content is not None, "error: response content is None"
-        logging.debug(response.choices[0].message.content)
-        logging.debug(response.metadata)
-
-        # check if metadata
-        assert "BaseRAG" in response.metadata[0], "error: metadata BaseRAG not found"
-
-        # test with wrong embeddings_model
-        response = session.get(f"{args["base_url"]}/models", timeout=10)
-        assert response.status_code == 200, f"error: retrieve models ({response.status_code})"
-        models = response.json()
-        models = Models(data=[Model(**model) for model in models["data"]])
-        wrong_embeddings_model = [model for model in models.data if model.type == EMBEDDINGS_MODEL_TYPE and model.id != self.EMBEDDINGS_MODEL][0].id
-        logging.debug(f"wrong_embeddings_model: {wrong_embeddings_model}")
-
-        data = {
-            "model": self.LANGUAGE_MODEL,
-            "messages": [{"role": "user", "content": "Qui est Ulrich Tan ?"}],
-            "stream": False,
-            "n": 1,
-            "tools": [
-                {
-                    "function": {
-                        "name": "BaseRAG",
-                        "parameters": {
-                            "embeddings_model": wrong_embeddings_model,
-                            "collections": [self.COLLECTION],
-                            "k": 2,
-                        },
-                    },
-                    "type": "function",
-                }
-            ],
-        }
-        response = session.post(f"{args["base_url"]}/chat/completions", json=data, timeout=30)
-        assert response.status_code == 400, f"error: chat completions ({response.status_code})"
diff --git a/app/tools/__init__.py b/app/tools/__init__.py
deleted file mode 100644
index e239995..0000000
--- a/app/tools/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from ._baserag import BaseRAG
-from ._fewshots import FewShots
-
-__all__ = ["BaseRAG", "FewShots"]
diff --git a/app/tools/_baserag.py b/app/tools/_baserag.py
deleted file mode 100644
index a2052a3..0000000
--- a/app/tools/_baserag.py
+++ /dev/null
@@ -1,48 +0,0 @@
-from typing import List, Optional
-
-from fastapi import HTTPException
-
-from app.helpers import VectorStore
-from app.schemas.tools import ToolOutput
-
-
-class BaseRAG:
-    """
-    Base RAG, basic retrival augmented generation.
-
-    Args:
-        embeddings_model (str): OpenAI embeddings model
-        collection (List[str], optional): List of collections to search in. Defaults to None (all collections).
-        k (int, optional): Top K per collection. Defaults to 4.
-        prompt_template (Optional[str], optional): Prompt template. Defaults to DEFAULT_PROMPT_TEMPLATE.
-
-    DEFAULT_PROMPT_TEMPLATE:
-        "Réponds à la question suivante en te basant sur les documents ci-dessous : {prompt}\n\nDocuments :\n\n{documents}"
-    """
-
-    DEFAULT_PROMPT_TEMPLATE = "Réponds à la question suivante en te basant sur les documents ci-dessous : {prompt}\n\nDocuments :\n\n{documents}"
-
-    def __init__(self, clients: dict):
-        self.clients = clients
-
-    async def get_prompt(
-        self,
-        embeddings_model: str,
-        collections: List[str] = [],
-        k: Optional[int] = 4,
-        prompt_template: Optional[str] = DEFAULT_PROMPT_TEMPLATE,
-        **request,
-    ) -> ToolOutput:
-        if "{prompt}" not in prompt_template or "{documents}" not in prompt_template:
-            raise HTTPException(status_code=400, detail="Prompt template must contain '{prompt}' and '{documents}' placeholders.")
-
-        vectorstore = VectorStore(clients=self.clients, user=request["user"])
-        prompt = request["messages"][-1]["content"]
-
-        chunks = vectorstore.search(model=embeddings_model, prompt=prompt, collection_names=collections, k=k)
-
-        metadata = {"chunks": [chunk.metadata for chunk in chunks]}
-        documents = "\n\n".join([chunk.content for chunk in chunks])
-        prompt = prompt_template.format(documents=documents, prompt=prompt)
-
-        return ToolOutput(prompt=prompt, metadata=metadata)
diff --git a/app/tools/_fewshots.py b/app/tools/_fewshots.py
deleted file mode 100644
index 8d1510e..0000000
--- a/app/tools/_fewshots.py
+++ /dev/null
@@ -1,64 +0,0 @@
-from typing import Optional
-
-
-from app.helpers import VectorStore
-from app.schemas.tools import ToolOutput
-
-
-class FewShots:
-    """
-    Fewshots RAG.
-
-    Args:
-        embeddings_model (str): OpenAI embeddings model
-        collection (str): Collection name. The collection must have question/answer pairs in metadata.
-        k (int, optional): Top K per collection. Defaults to 4.
-    """
-
-    COLLECTION = "service-public-plus"
-    DEFAULT_SYSTEM_PROMPT = "Tu es un générateur de réponse automatique à une expérience utilisateur. Tu parles un français courtois."
-
-    DEFAULT_PROMPT_TEMPLATE = """Vous incarnez un agent chevronné de l'administration française, expert en matière de procédures et réglementations administratives. Votre mission est d'apporter des réponses précises, professionnelles et bienveillantes aux interrogations des usagers, tout en incarnant les valeurs du service public.
-
-Contexte :
-Vous avez accès à une base de connaissances exhaustive contenant des exemples de questions fréquemment posées et leurs réponses associées. Utilisez ces informations comme référence pour formuler vos réponses :
-
-{context}
-
-Directives :
-1. Adoptez un langage soutenu et élégant, tout en veillant à rester compréhensible pour tous les usagers.
-2. Basez-vous sur les exemples fournis pour élaborer des réponses pertinentes et précises.
-3. Faites preuve de courtoisie, d'empathie et de pédagogie dans vos interactions, reflétant ainsi l'excellence du service public français.
-4. Structurez votre réponse de manière claire et logique, en utilisant si nécessaire des puces ou des numéros pour faciliter la compréhension.
-5. En cas d'incertitude sur un point spécifique, indiquez-le clairement et orientez l'usager vers les ressources ou services compétents.
-6. Concluez systématiquement votre réponse par une formule de politesse adaptée et proposez votre assistance pour toute question supplémentaire.
-
-Question de l'usager :
-
-{prompt}
-
-Veuillez apporter une réponse circonstanciée à cette question en respectant scrupuleusement les directives énoncées ci-dessus.
-"""
-
-    def __init__(self, clients: dict):
-        self.clients = clients
-
-    async def get_prompt(
-        self,
-        embeddings_model: str,
-        k: Optional[int] = 4,
-        **request,
-    ) -> ToolOutput:
-        vectorstore = VectorStore(clients=self.clients, user=request["user"])
-        collection = vectorstore.get_collection_metadata(collection_names=[self.COLLECTION])[0]
-        prompt = request["messages"][-1]["content"]
-        results = vectorstore.search(collection_names=[collection.name], prompt=prompt, k=k, model=embeddings_model)
-
-        context = "\n\n\n".join(
-            [f"Question: {result.payload.get('question', 'N/A')}\n" f"Réponse: {result.payload.get('answer', 'N/A')}" for result in results]
-        )
-
-        prompt = self.DEFAULT_PROMPT_TEMPLATE.format(context=context, prompt=prompt)
-        metadata = {"chunks": [result.id for result in results]}
-
-        return ToolOutput(prompt=prompt, metadata=metadata)
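
Note for reviewers: with the tool layer removed, retrieval is exposed only through POST /v1/search, whose response changes from bare Chunks to the new Searches schema (score/chunk pairs). Below is a minimal client sketch of the reworked endpoint. The base URL, API key, Bearer scheme, and model id are illustrative assumptions; only the /v1 prefix, the SearchRequest fields, and the Searches/Search/Chunk shapes come from this patch.

import httpx

BASE_URL = "http://localhost:8000/v1"  # assumed local deployment; "/v1" prefix per app/main.py
API_KEY = "my-api-key"                 # hypothetical key; Bearer auth is an assumption

# Request fields mirror SearchRequest as used in app/endpoints/search.py:
# prompt, model, collections, k, score_threshold.
payload = {
    "prompt": "Qui est Ulrich Tan ?",
    "model": "my-embeddings-model",  # hypothetical embeddings model id
    "collections": ["pytest"],
    "k": 4,
}

response = httpx.post(f"{BASE_URL}/search", headers={"Authorization": f"Bearer {API_KEY}"}, json=payload, timeout=10)
response.raise_for_status()

# The body now follows Searches: {"object": "list", "data": [{"score": ...,
# "chunk": {"id": ..., "content": ..., "metadata": ...}}]}, so every chunk
# carries the similarity score that the old Chunks response discarded.
for item in response.json()["data"]:
    print(f"{item['score']:.3f}  {item['chunk']['content'][:80]}")

Returning the score alongside each chunk lets callers apply their own relevance cut-offs client-side instead of relying solely on the server's score_threshold.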