Skip to content

Commit

Permalink
feat: remove tools
Browse files Browse the repository at this point in the history
  • Loading branch information
leoguillaumegouv committed Sep 19, 2024
1 parent 9de4aac commit d2d8aea
Show file tree
Hide file tree
Showing 14 changed files with 41 additions and 367 deletions.
47 changes: 6 additions & 41 deletions app/endpoints/chat.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,17 @@
import httpx
import json
from typing import Union

from fastapi import APIRouter, Security, HTTPException
from fastapi import APIRouter, HTTPException, Security
from fastapi.responses import StreamingResponse
import httpx

from app.schemas.chat import ChatCompletionRequest, ChatCompletion, ChatCompletionChunk
from app.utils.security import check_api_key
from app.utils.lifespan import clients
from app.utils.config import LOGGER
from app.tools import *
from app.tools import __all__ as tools_list
from app.schemas.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionRequest
from app.schemas.config import LANGUAGE_MODEL_TYPE
from app.utils.lifespan import clients
from app.utils.security import check_api_key

router = APIRouter()


# @TODO: remove tooling from here
@router.post("/chat/completions")
async def chat_completions(request: ChatCompletionRequest, user: str = Security(check_api_key)) -> Union[ChatCompletion, ChatCompletionChunk]:
"""Completion API similar to OpenAI's API.
Expand All @@ -34,50 +29,20 @@ async def chat_completions(request: ChatCompletionRequest, user: str = Security(
if not client.check_context_length(model=request["model"], messages=request["messages"]):
raise HTTPException(status_code=400, detail="Context length too large")

# tool call
metadata = list()
tools = request.get("tools")
if tools:
for tool in tools:
if tool["function"]["name"] not in tools_list:
raise HTTPException(status_code=404, detail="Tool not found")
func = globals()[tool["function"]["name"]](clients=clients)
params = request | tool["function"]["parameters"]
params["user"] = user
LOGGER.debug(f"params: {params}")
try:
tool_output = await func.get_prompt(**params)
except Exception as e:
raise HTTPException(status_code=400, detail=f"tool error {e}")
metadata.append({tool["function"]["name"]: tool_output.model_dump()})
request["messages"] = [{"role": "user", "content": tool_output.prompt}]
request.pop("tools")

if not client.check_context_length(model=request["model"], messages=request["messages"]):
raise HTTPException(status_code=400, detail="Context length too large after tool call")

# non stream case
if not request["stream"]:
async_client = httpx.AsyncClient(timeout=20)
response = await async_client.request(method="POST", url=url, headers=headers, json=request)
print(response.text)
response.raise_for_status()
data = response.json()
data["metadata"] = metadata
return ChatCompletion(**data)

# stream case
async def forward_stream(url: str, headers: dict, request: dict):
async with httpx.AsyncClient(timeout=20) as async_client:
async with async_client.stream(method="POST", url=url, headers=headers, json=request) as response:
i = 0
async for chunk in response.aiter_raw():
if i == 0:
chunks = chunk.decode("utf-8").split("\n\n")
chunk = json.loads(chunks[0].lstrip("data: "))
chunk["metadata"] = metadata
chunks[0] = f"data: {json.dumps(chunk)}"
chunk = "\n\n".join(chunks).encode("utf-8")
i = 1
yield chunk

return StreamingResponse(forward_stream(url, headers, request), media_type="text/event-stream")
1 change: 0 additions & 1 deletion app/endpoints/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
router = APIRouter()


# @TODO: remove get one collection and a /collections/search to similarity search (remove /tools)
@router.get("/collections/{collection}")
@router.get("/collections")
async def get_collections(collection: Optional[str] = None, user: str = Security(check_api_key)) -> Union[Collection, Collections]:
Expand Down
7 changes: 3 additions & 4 deletions app/endpoints/search.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
from fastapi import APIRouter, Security

from app.helpers import VectorStore
from app.schemas.chunks import Chunks
from app.schemas.search import SearchRequest
from app.schemas.search import SearchRequest, Searches
from app.utils.lifespan import clients
from app.utils.security import check_api_key

router = APIRouter()


@router.post("/search")
async def search(request: SearchRequest, user: str = Security(check_api_key)) -> Chunks:
async def search(request: SearchRequest, user: str = Security(check_api_key)) -> Searches:
"""
Similarity search for chunks in the vector store.
Expand All @@ -27,4 +26,4 @@ async def search(request: SearchRequest, user: str = Security(check_api_key)) ->
prompt=request.prompt, model=request.model, collection_names=request.collections, k=request.k, score_threshold=request.score_threshold
)

return Chunks(data=data)
return Searches(data=data)
26 changes: 0 additions & 26 deletions app/endpoints/tools.py

This file was deleted.

10 changes: 7 additions & 3 deletions app/helpers/_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from app.schemas.chunks import Chunk
from app.schemas.collections import CollectionMetadata
from app.schemas.config import EMBEDDINGS_MODEL_TYPE, METADATA_COLLECTION, PRIVATE_COLLECTION_TYPE, PUBLIC_COLLECTION_TYPE
from app.schemas.search import Search


class VectorStore:
Expand Down Expand Up @@ -64,7 +65,7 @@ def search(
k: Optional[int] = 4,
score_threshold: Optional[float] = None,
filter: Optional[Filter] = None,
) -> List[Chunk]:
) -> List[Search]:
response = self.models[model].embeddings.create(input=[prompt], model=model)
vector = response.data[0].embedding

Expand All @@ -88,9 +89,12 @@ def search(

# sort by similarity score and get top k
chunks = sorted(chunks, key=lambda x: x.score, reverse=True)[:k]
chunks = [Chunk(id=chunk.id, content=chunk.payload["page_content"], metadata=chunk.payload["metadata"]) for chunk in chunks]
data = [
Search(score=chunk.score, chunk=Chunk(id=chunk.id, content=chunk.payload["page_content"], metadata=chunk.payload["metadata"]))
for chunk in chunks
]

return chunks
return data

def get_collection_metadata(self, collection_names: List[str] = [], type: str = "all", errors: str = "raise") -> List[CollectionMetadata]:
"""
Expand Down
3 changes: 1 addition & 2 deletions app/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from fastapi import FastAPI, Response, Security

from app.endpoints import chat, chunks, collections, completions, embeddings, files, models, search, tools
from app.endpoints import chat, chunks, collections, completions, embeddings, files, models, search
from app.utils.config import APP_CONTACT_EMAIL, APP_CONTACT_URL, APP_DESCRIPTION, APP_VERSION
from app.utils.lifespan import lifespan
from app.utils.security import check_api_key
Expand Down Expand Up @@ -32,4 +32,3 @@ def health(user: str = Security(check_api_key)):
app.include_router(chunks.router, tags=["Chunks"], prefix="/v1")
app.include_router(files.router, tags=["Files"], prefix="/v1")
app.include_router(search.router, tags=["Search"], prefix="/v1")
app.include_router(tools.router, tags=["Tools"], prefix="/v1")
16 changes: 11 additions & 5 deletions app/schemas/chat.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Literal, Optional, Union
from typing import List, Literal, Optional, Union

from openai.types.chat import (
ChatCompletion,
Expand All @@ -9,10 +9,9 @@
)
from pydantic import BaseModel, Field

from app.schemas.tools import ToolOutput


class ChatCompletionRequest(BaseModel):
# See https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/protocol.py
messages: List[ChatCompletionMessageParam]
model: str
stream: Optional[Literal[True, False]] = False
Expand All @@ -27,11 +26,18 @@ class ChatCompletionRequest(BaseModel):
stop: Union[Optional[str], List[str]] = Field(default_factory=list)
tool_choice: Optional[Union[Literal["none"], ChatCompletionToolChoiceOptionParam]] = "none"
tools: List[ChatCompletionToolParam] = None
user: Optional[str] = None
best_of: Optional[int] = None
top_k: int = -1
min_p: float = 0.0

class Config:
extra = "allow"


class ChatCompletion(ChatCompletion):
metadata: Optional[List[Dict[str, ToolOutput]]] = []
pass


class ChatCompletionChunk(ChatCompletionChunk):
metadata: Optional[List[Dict[str, ToolOutput]]] = []
pass
14 changes: 13 additions & 1 deletion app/schemas/search.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import List, Optional
from typing import List, Literal, Optional

from pydantic import BaseModel, Field, field_validator

from app.schemas.chunks import Chunk


class SearchRequest(BaseModel):
prompt: str
Expand All @@ -15,3 +17,13 @@ def blank_string(value):
if value.strip() == "":
raise ValueError("Prompt cannot be empty")
return value


class Search(BaseModel):
    """A single similarity-search hit: a matched chunk with its relevance score."""

    score: float  # similarity score from the vector store (results are sorted on this, descending)
    chunk: Chunk  # the matched chunk (id, content, metadata) — see app.schemas.chunks.Chunk


class Searches(BaseModel):
    """Response envelope returned by the /search endpoint: a list of Search hits."""

    object: Literal["list"] = "list"  # constant discriminator; always the string "list"
    data: List[Search]  # top-k hits; presumably ordered best-score-first — TODO confirm against VectorStore.search
19 changes: 0 additions & 19 deletions app/schemas/tools.py

This file was deleted.

21 changes: 0 additions & 21 deletions app/tests/test_tools.py

This file was deleted.

Loading

0 comments on commit d2d8aea

Please sign in to comment.