Skip to content

Commit

Permalink
Default internet models (#93)
Browse files Browse the repository at this point in the history
* chore: change internet client

* feat: add score threshold

* feat: async client

* fix: internet and exceptions

---------

Co-authored-by: leoguillaume <[email protected]>
  • Loading branch information
leoguillaume and leoguillaumegouv authored Dec 6, 2024
1 parent a1cc34f commit e85f03e
Show file tree
Hide file tree
Showing 24 changed files with 508 additions and 1,625 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Albert API est une initiative d'[Etalab](https://www.etalab.gouv.fr/). Il s'agit
- servir des modèles de langage avec [vLLM](https://github.com/vllm-project/vllm)
- servir des modèles d'embeddings avec [HuggingFace Text Embeddings Inference](https://github.com/huggingface/text-embeddings-inference)
- servir des modèles de reconnaissance vocale avec [Whisper OpenAI API](https://github.com/etalab-ia/whisper-openai-api)
- accès un *vector store* avec [Qdrant](https://qdrant.tech/) pour la recherche de similarité
- accéder à un *vector store* avec [Elasticsearch](https://www.elastic.co/fr/products/elasticsearch) pour la recherche de similarité (lexicale, sémantique ou hybride) ou [Qdrant](https://qdrant.tech/) pour la recherche sémantique uniquement.

En se basant sur les conventions définies par OpenAI, l'API Albert expose des endpoints qui peuvent être appelés avec le [client officiel python d'OpenAI](https://github.com/openai/openai-python/tree/main). Ce formalisme permet d'intégrer facilement l'API Albert avec des bibliothèques tierces comme [Langchain](https://www.langchain.com/) ou [LlamaIndex](https://www.llamaindex.ai/).

Expand Down
60 changes: 37 additions & 23 deletions app/endpoints/audio.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import json
from typing import List, Literal

from fastapi import APIRouter, Form, Security, Request, UploadFile, File
from fastapi import APIRouter, File, Form, HTTPException, Request, Security, UploadFile
from fastapi.responses import PlainTextResponse
import httpx

from app.schemas.audio import AudioTranscription, AudioTranscriptionVerbose
from app.schemas.audio import AudioTranscription
from app.schemas.settings import AUDIO_MODEL_TYPE
from app.utils.settings import settings
from app.utils.security import check_api_key, check_rate_limit, User
from app.utils.lifespan import clients, limiter
from app.utils.exceptions import ModelNotFoundException
from app.utils.variables import SUPPORTED_LANGUAGES

from app.utils.lifespan import clients, limiter
from app.utils.security import User, check_api_key, check_rate_limit
from app.utils.settings import settings
from app.utils.variables import DEFAULT_TIMEOUT, SUPPORTED_LANGUAGES

router = APIRouter()
SUPPORTED_LANGUAGES_VALUES = sorted(set(SUPPORTED_LANGUAGES.values())) + sorted(set(SUPPORTED_LANGUAGES.keys()))
Expand All @@ -21,32 +23,44 @@ async def audio_transcriptions(
request: Request,
file: UploadFile = File(...),
model: str = Form(...),
language: Literal[*SUPPORTED_LANGUAGES_VALUES] = Form("fr"),
language: Literal[*SUPPORTED_LANGUAGES_VALUES] = Form(default="fr"),
prompt: str = Form(None),
response_format: str = Form("json"),
response_format: Literal["json", "text"] = Form(default="json"),
temperature: float = Form(0),
timestamp_granularities: List[str] = Form(alias="timestamp_granularities[]", default=["segment"]),
user: User = Security(check_api_key),
) -> AudioTranscription | AudioTranscriptionVerbose:
user: User = Security(dependency=check_api_key),
) -> AudioTranscription:
"""
API de transcription similaire à l'API d'OpenAI.
"""

client = clients.models[model]

# @TODO: check if the file is an audio file
if client.type != AUDIO_MODEL_TYPE:
raise ModelNotFoundException()

# @TODO: Implement prompt
# @TODO: Implement timestamp_granularities
# @TODO: Implement verbose response format

file_content = await file.read()

response = await client.audio.transcriptions.create(
file=(file.filename, file_content, file.content_type),
model=model,
language=language,
prompt=prompt,
response_format=response_format,
temperature=temperature,
timestamp_granularities=timestamp_granularities,
)
return response
url = f"{client.base_url}audio/transcriptions"
headers = {"Authorization": f"Bearer {client.api_key}"}

try:
async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as async_client:
response = await async_client.post(
url=url,
headers=headers,
files={"file": (file.filename, file_content, file.content_type)},
data={"language": language, "response_format": response_format, "temperature": temperature},
)
response.raise_for_status()
if response_format == "text":
return PlainTextResponse(content=response.text)

data = response.json()
return AudioTranscription(**data)

except Exception as e:
raise HTTPException(status_code=e.response.status_code, detail=json.loads(s=e.response.text)["message"])
44 changes: 25 additions & 19 deletions app/endpoints/chat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Union

from fastapi import APIRouter, Request, Security
import json
from fastapi import APIRouter, Request, Security, HTTPException
from fastapi.responses import StreamingResponse
import httpx

Expand All @@ -9,14 +9,15 @@
from app.utils.settings import settings
from app.utils.lifespan import clients, limiter
from app.utils.security import check_api_key, check_rate_limit
from app.utils.variables import DEFAULT_TIMEOUT

router = APIRouter()


@router.post(path="/chat/completions")
@limiter.limit(limit_value=settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
async def chat_completions(
    request: Request, body: ChatCompletionRequest, user: User = Security(dependency=check_api_key)
) -> Union[ChatCompletion, ChatCompletionChunk]:
    """Completion API similar to OpenAI's API.
    See https://platform.openai.com/docs/api-reference/chat/create for the API specification.
    """
    # NOTE(review): this lookup was in a collapsed diff region — confirm against upstream.
    client = clients.models[body.model]
    url = f"{client.base_url}chat/completions"
    headers = {"Authorization": f"Bearer {client.api_key}"}

    try:
        # non stream case: forward the request once and re-validate the payload
        if not body.stream:
            async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as async_client:
                response = await async_client.request(method="POST", url=url, headers=headers, json=body.model_dump())
                response.raise_for_status()

                data = response.json()
                return ChatCompletion(**data)

        # stream case: re-emit the upstream SSE stream chunk by chunk.
        # NOTE(review): raise_for_status inside the generator fires after the
        # StreamingResponse has started, so it cannot become an HTTP error for
        # the caller — it aborts the stream instead.
        async def forward_stream(url: str, headers: dict, request: dict):
            async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as async_client:
                async with async_client.stream(method="POST", url=url, headers=headers, json=request) as response:
                    response.raise_for_status()
                    async for chunk in response.aiter_raw():
                        yield chunk

        return StreamingResponse(forward_stream(url, headers, body.model_dump()), media_type="text/event-stream")

    # Only httpx.HTTPStatusError carries a `.response`; a bare `except Exception`
    # would itself crash with AttributeError on timeouts or connection errors.
    except httpx.HTTPStatusError as e:
        try:
            detail = json.loads(e.response.text)["message"]
        except (json.JSONDecodeError, KeyError, TypeError):
            detail = e.response.text
        raise HTTPException(status_code=e.response.status_code, detail=detail)
    except httpx.HTTPError as e:
        raise HTTPException(status_code=500, detail=str(e))
24 changes: 15 additions & 9 deletions app/endpoints/completions.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from fastapi import APIRouter, Request, Security
from fastapi import APIRouter, Request, Security, HTTPException
import httpx
import json

from app.schemas.completions import CompletionRequest, Completions
from app.schemas.security import User
from app.utils.settings import settings
from app.utils.lifespan import clients, limiter
from app.utils.security import check_api_key, check_rate_limit
from app.utils.variables import DEFAULT_TIMEOUT

router = APIRouter()


@router.post(path="/completions")
@limiter.limit(limit_value=settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
async def completions(request: Request, body: CompletionRequest, user: User = Security(dependency=check_api_key)) -> Completions:
    """
    Completion API similar to OpenAI's API.
    See https://platform.openai.com/docs/api-reference/completions/create for the API specification.
    """
    # NOTE(review): this lookup was in a collapsed diff region — confirm against upstream.
    client = clients.models[body.model]
    url = f"{client.base_url}completions"
    headers = {"Authorization": f"Bearer {client.api_key}"}

    try:
        async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as async_client:
            response = await async_client.request(method="POST", url=url, headers=headers, json=body.model_dump())
            response.raise_for_status()

            data = response.json()
            return Completions(**data)

    # Only httpx.HTTPStatusError carries a `.response`; a bare `except Exception`
    # would itself crash with AttributeError on timeouts or connection errors.
    except httpx.HTTPStatusError as e:
        try:
            detail = json.loads(e.response.text)["message"]
        except (json.JSONDecodeError, KeyError, TypeError):
            detail = e.response.text
        raise HTTPException(status_code=e.response.status_code, detail=detail)
    except httpx.HTTPError as e:
        raise HTTPException(status_code=500, detail=str(e))
36 changes: 19 additions & 17 deletions app/endpoints/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
from fastapi import APIRouter, Request, Security
from fastapi import APIRouter, Request, Security, HTTPException
import httpx
import json

from app.schemas.embeddings import Embeddings, EmbeddingsRequest
from app.schemas.security import User
from app.utils.settings import settings
from app.utils.exceptions import ContextLengthExceededException, WrongModelTypeException
from app.utils.exceptions import WrongModelTypeException
from app.utils.lifespan import clients, limiter
from app.utils.security import check_api_key, check_rate_limit
from app.utils.variables import EMBEDDINGS_MODEL_TYPE
from app.utils.variables import EMBEDDINGS_MODEL_TYPE, DEFAULT_TIMEOUT

router = APIRouter()


@router.post(path="/embeddings")
@limiter.limit(limit_value=settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
async def embeddings(request: Request, body: EmbeddingsRequest, user: User = Security(dependency=check_api_key)) -> Embeddings:
    """
    Embedding API similar to OpenAI's API.
    See https://platform.openai.com/docs/api-reference/embeddings/create for the API specification.
    """
    # NOTE(review): the model lookup and type check were in a collapsed diff
    # region — reconstructed from the imports; confirm against upstream.
    client = clients.models[body.model]
    if client.type != EMBEDDINGS_MODEL_TYPE:
        raise WrongModelTypeException()

    url = f"{client.base_url}embeddings"
    headers = {"Authorization": f"Bearer {client.api_key}"}

    try:
        async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as async_client:
            response = await async_client.request(method="POST", url=url, headers=headers, json=body.model_dump())
            response.raise_for_status()

            data = response.json()
            return Embeddings(**data)

    # Only httpx.HTTPStatusError carries a `.response`; a bare `except Exception`
    # would itself crash with AttributeError on timeouts or connection errors.
    except httpx.HTTPStatusError as e:
        try:
            detail = json.loads(e.response.text)["message"]
        except (json.JSONDecodeError, KeyError, TypeError):
            detail = e.response.text
        raise HTTPException(status_code=e.response.status_code, detail=detail)
    except httpx.HTTPError as e:
        raise HTTPException(status_code=500, detail=str(e))
50 changes: 36 additions & 14 deletions app/endpoints/search.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,72 @@
import uuid

from fastapi import APIRouter, Request, Security

from app.schemas.search import Searches, SearchRequest
from app.schemas.security import User
from app.utils.settings import settings
from app.utils.lifespan import clients, limiter
from app.utils.security import check_api_key, check_rate_limit
from app.utils.settings import settings
from app.utils.variables import INTERNET_COLLECTION_DISPLAY_ID


router = APIRouter()


@router.post(path="/search")
@limiter.limit(limit_value=settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
async def search(request: Request, body: SearchRequest, user: User = Security(dependency=check_api_key)) -> Searches:
    """
    Endpoint to search on the internet or with our engine client.

    When the internet pseudo-collection is requested (or no collection is given),
    internet results are ingested into a temporary collection, queried alongside
    the user's collections, then the temporary collection is deleted.
    """
    # TODO: to be handled by a service top to InternetExplorer
    body = await request.json()
    body = SearchRequest(**body)

    # Internet search
    need_internet_search = not body.collections or INTERNET_COLLECTION_DISPLAY_ID in body.collections
    internet_chunks = []
    if need_internet_search:
        # get internet results chunks
        internet_collection_id = str(uuid.uuid4())
        internet_chunks = clients.internet.get_chunks(prompt=body.prompt, collection_id=internet_collection_id)

        if internet_chunks:
            # Reuse the embeddings model of the user's collections when present,
            # otherwise fall back to the internet client's default model.
            internet_embeddings_model_id = (
                clients.internet.default_embeddings_model_id
                if body.collections == [INTERNET_COLLECTION_DISPLAY_ID]
                else clients.search.get_collections(collection_ids=body.collections, user=user)[0].model
            )

            clients.search.create_collection(
                collection_id=internet_collection_id,
                collection_name=internet_collection_id,
                collection_model=internet_embeddings_model_id,
                user=user,
            )
            clients.search.upsert(chunks=internet_chunks, collection_id=internet_collection_id, user=user)

        # case: no other collections, only internet, and no internet results
        elif body.collections == [INTERNET_COLLECTION_DISPLAY_ID]:
            return Searches(data=[])

    # case: other collections or only internet and internet results
    if INTERNET_COLLECTION_DISPLAY_ID in body.collections:
        body.collections.remove(INTERNET_COLLECTION_DISPLAY_ID)
    # NOTE(review): nesting of this guard is ambiguous in the diff — placed at
    # function level so an empty request with no internet results short-circuits.
    if not body.collections and not internet_chunks:
        return Searches(data=[])
    if internet_chunks:
        body.collections.append(internet_collection_id)

    searches = clients.search.query(
        prompt=body.prompt,
        collection_ids=body.collections,
        method=body.method,
        k=body.k,
        rff_k=body.rff_k,
        score_threshold=body.score_threshold,
        user=user,
    )

    # Temporary internet collection is only needed for the query above.
    if internet_chunks:
        clients.search.delete_collection(collection_id=internet_collection_id, user=user)

    # Defensive post-filter for search backends that ignore score_threshold.
    # Renamed loop variable: `search` would shadow this function's own name.
    if body.score_threshold:
        searches = [result for result in searches if result.score >= body.score_threshold]

    return Searches(data=searches)
2 changes: 1 addition & 1 deletion app/helpers/_clientsmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def set(self):

self.search = SearchClient.import_constructor(self.settings.search.type)(models=self.models, **self.settings.search.args)

self.internet = InternetClient(model_clients=self.models, search_client=self.search, **self.settings.internet.args)
self.internet = InternetClient(model_clients=self.models, search_client=self.search, **self.settings.internet.args.model_dump())

self.auth = AuthenticationClient(cache=self.cache, **self.settings.auth.args) if self.settings.auth else None

Expand Down
Loading

0 comments on commit e85f03e

Please sign in to comment.