From 5130abe21b2ad91896c98fce1e34fbae25535598 Mon Sep 17 00:00:00 2001
From: leoguillaume
Date: Thu, 10 Oct 2024 18:44:19 +0200
Subject: [PATCH] feat: refacto

---
 .env                                         |  10 ++
 .github/workflows/build_and_deploy.yml       |  14 +-
 .gitlab-ci.yml                               |  47 -------
 Dockerfile                                   |  16 ---
 README.md                                    |  22 ++--
 app.py                                       |  69 ----------
 app/Dockerfile                               |  12 ++
 .../api_keys.example.json                    |   0
 config.py => app/config.py                   |   3 +-
 deps.py => app/deps.py                       |   2 +-
 llm.py => app/llm.py                         |   6 +-
 app/main.py                                  | 108 +++++++++++++++
 schemas/api.py => app/schemas.py             |   0
 app/scripts/arena.py                         | 123 ++++++++++++++++++
 {scripts => app/scripts}/spp_fewshots.jinja  |   4 +
 security.py => app/security.py               |   2 +-
 subscriptions.py => app/subscriptions.py     |  13 +-
 {tests => app/tests}/conftest.py             |  39 ++----
 tests/mockups/llm.py => app/tests/mockups.py |   6 +-
 {tests => app/tests}/test_main.py            |   2 +-
 docker-compose.yml => compose.yml            |  10 +-
 endpoints/__init__.py                        |   0
 endpoints/api.py                             |  68 ----------
 endpoints/misc.py                            |  10 --
 pyproject.toml                               |   4 +-
 schemas/__init__.py                          |   3 -
 scripts/arena.py                             | 117 -----------------
 tests/__init__.py                            |   0
 tests/mockups/elasticsearch.py               |  52 --------
 tests/mockups/prompt_config.yml              |   7 -
 tests/mockups/simple_prompt_template.jinja   |  10 --
 31 files changed, 308 insertions(+), 471 deletions(-)
 create mode 100644 .env
 delete mode 100644 .gitlab-ci.yml
 delete mode 100644 Dockerfile
 delete mode 100755 app.py
 create mode 100644 app/Dockerfile
 rename api_keys.example.json => app/api_keys.example.json (100%)
 rename config.py => app/config.py (90%)
 rename deps.py => app/deps.py (85%)
 rename llm.py => app/llm.py (89%)
 create mode 100755 app/main.py
 rename schemas/api.py => app/schemas.py (100%)
 create mode 100644 app/scripts/arena.py
 rename {scripts => app/scripts}/spp_fewshots.jinja (85%)
 rename security.py => app/security.py (94%)
 rename subscriptions.py => app/subscriptions.py (76%)
 rename {tests => app/tests}/conftest.py (71%)
 rename tests/mockups/llm.py => app/tests/mockups.py (96%)
 rename {tests => app/tests}/test_main.py (96%)
 rename docker-compose.yml => compose.yml (75%)
 delete mode 100644 endpoints/__init__.py
 delete mode 100755 endpoints/api.py
 delete mode 100755 endpoints/misc.py
 delete mode 100755 schemas/__init__.py
 delete mode 100644 scripts/arena.py
 delete mode 100644 tests/__init__.py
 delete mode 100755 tests/mockups/elasticsearch.py
 delete mode 100644 tests/mockups/prompt_config.yml
 delete mode 100644 tests/mockups/simple_prompt_template.jinja

diff --git a/.env b/.env
new file mode 100644
index 0000000..bfa838b
--- /dev/null
+++ b/.env
@@ -0,0 +1,10 @@
+LANGUAGE_MODEL=AgentPublic/llama3-instruct-8b
+EMBEDDINGS_MODEL=BAAI/bge-m3
+ALBERT_BASE_URL=https://albert.api.dev.etalab.gouv.fr/v1
+#ALBERT_BASE_URL=http://localhost:8080/v1
+#ALBERT_API_KEY=spp-prod-V3QLrqTTmzd5jrZSiAAdHnFw9ijyc7m2DZDK97nAq4md34DMpvQYmCZj7wQYwkta
+ALBERT_API_KEY=leo-qQVK5cFW4R7QbCxx5V33gKE9qzER32tesNb4DTrPeg7sqVrRsUdprJfArwMtAxui
+REDIS_HOST=albert.bdd.001.etalab.gouv.fr
+REDIS_PORT=36379
+REDIS_PASSWORD=gaYauVqErKgFtAgDZrgJt4ZKohjoJ7FXkgQAEU3gMVSAwHwY2TqeaeTwofroeJnk
+COLLECTION_ID=5080b4bc-71a3-49af-acfc-c27f56079f0c
\ No newline at end of file
diff --git a/.github/workflows/build_and_deploy.yml b/.github/workflows/build_and_deploy.yml
index a9bd886..878f901 100644
--- a/.github/workflows/build_and_deploy.yml
+++ b/.github/workflows/build_and_deploy.yml
@@ -4,11 +4,9 @@ on:
   push:
     branches:
       - main
-      - staging
-      - dev
 
 env:
-  IMAGE_NAME: ghcr.io/${{ github.repository }}/albert-spp
+  APP_IMAGE_NAME: ghcr.io/${{ github.repository }}/app
   IMAGE_TAG: ${{ github.sha }}
   application_name: albert-spp
   deployment_environment: staging
 
@@ -35,18 +33,19 @@ jobs:
         uses: docker/build-push-action@v6
         with:
           context: .
+          file: ./app/Dockerfile
           push: true
-          tags: ${{ env.IMAGE_NAME }}:${{ env.IMAGE_TAG }},${{ env.IMAGE_NAME }}:latest
+          tags: ${{ env.APP_IMAGE_NAME }}:${{ env.IMAGE_TAG }},${{ env.APP_IMAGE_NAME }}:latest
           cache-from: type=gha
           cache-to: type=gha,mode=max
 
-  deploy-dev:
+  deploy-staging:
     name: Deploy from ${{ github.ref_name }}/${{ github.sha }}
     runs-on: ubuntu-latest
     needs: build-and-push
-    if: github.ref_name == 'dev'
+    if: github.ref_name == 'main'
     steps:
-      - name: Trigger dev deployment
+      - name: Trigger staging deployment
        run: |
          RESPONSE="$(curl --request POST \
            --form token=${{ secrets.GITLAB_CI_TOKEN }} \
          echo $RESPONSE
          exit 1
          fi
-
diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
deleted file mode 100644
index 5538fe9..0000000
--- a/.gitlab-ci.yml
+++ /dev/null
@@ -1,47 +0,0 @@
-stages:
-  - test
-  - build
-
-cache:
-  paths:
-    - .cache/pip
-    - venv/
-
-test:
-  #stage: test
-  image: python:3.10-slim
-  rules:
-    - if: $CI_MERGE_REQUEST_TARGET_BRANCH_NAME == "dev" || $CI_MERGE_REQUEST_TARGET_BRANCH_NAME == "staging" || $CI_MERGE_REQUEST_TARGET_BRANCH_NAME == "main"
-      changes:
-        - "**/*.py"
-      when: always
-  before_script:
-    - |
-      python -m venv venv
-      source venv/bin/activate
-      pip install --cache-dir .cache/pip .[test]
-  script:
-    - pytest -W ignore
-
-build:
-  rules:
-    - if: $CI_COMMIT_BRANCH == "dev"
-  stage: build
-  image: docker:latest
-  services:
-    - docker:dind
-  script:
-    - | # build and push image to gitlab registry
-      docker login --username gitlab-ci-token --password $CI_JOB_TOKEN $CI_REGISTRY
-
-      if [[ $CI_COMMIT_BRANCH == "dev" ]]; then
-        docker build --tag ${CI_REGISTRY_IMAGE}/api:${CI_COMMIT_SHORT_SHA} --file ./Dockerfile .
-        docker push ${CI_REGISTRY_IMAGE}/api:${CI_COMMIT_SHORT_SHA}
-        docker tag ${CI_REGISTRY_IMAGE}/api:${CI_COMMIT_SHORT_SHA} ${CI_REGISTRY_IMAGE}/api:latest
-        docker push ${CI_REGISTRY_IMAGE}/api:latest
-
-      elif [[ $CI_COMMIT_BRANCH == "main" ]]; then
-        docker build --tag ${CI_REGISTRY_IMAGE}/api:stable --file ./Dockerfile .
-        docker push ${CI_REGISTRY_IMAGE}/api:stable
-
-      fi
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
deleted file mode 100644
index 5a0ac8f..0000000
--- a/Dockerfile
+++ /dev/null
@@ -1,16 +0,0 @@
-FROM python:3.10-slim
-
-RUN addgroup --gid 1101 albert && \
-    adduser --uid 1100 --gid 1101 --home /code albert
-
-WORKDIR /code
-
-# add requirements.txt before the rest of the code to cache the pip install
-ADD ./pyproject.toml /code/api/pyproject.toml
-RUN pip install --no-cache-dir /code/api/.
-
-# add the rest of the code
-ENV PYTHONPATH="$PYTHONPATH:/code/api"
-ADD . /code/api
-
-USER albert
diff --git a/README.md b/README.md
index dd476e6..028b25e 100644
--- a/README.md
+++ b/README.md
@@ -1,17 +1,17 @@
 # Albert - Services Publics Plus
-
+
 1. Envoi du prompt au modèle
 
-```sh
-curl -XPOST https://spp.etalab.gouv.fr/api/spp/anonymize -H "Content-Type: application/json" \
-  -H "Authorization: Bearer $API_KEY" \
-  -d '{"id":"123", "text":"Merci pour service"}'
-```
+   ```sh
+   curl -XPOST https://spp.etalab.gouv.fr/api/spp/anonymize -H "Content-Type: application/json" \
+     -H "Authorization: Bearer $API_KEY" \
+     -d '{"id":"123", "text":"Merci pour service."}'
+   ```
 
 2. Récupération de la réponse du modèle
 
-```sh
-curl -XPOST https://spp.etalab.gouv.fr/api/spp/prod/run/ditp-get-data -H "Content-Type: application/json" \
-  -H "Authorization: Bearer $API_KEY" \
-  -d '{"id":"123"}'
-```
+   ```sh
+   curl -XPOST https://spp.etalab.gouv.fr/api/spp/prod/run/ditp-get-data -H "Content-Type: application/json" \
+     -H "Authorization: Bearer $API_KEY" \
+     -d '{"id":"123"}'
+   ```
diff --git a/app.py b/app.py
deleted file mode 100755
index 81b622f..0000000
--- a/app.py
+++ /dev/null
@@ -1,69 +0,0 @@
-from contextlib import asynccontextmanager
-import requests
-
-from fastapi import FastAPI
-from redis import Redis
-from starlette.middleware.cors import CORSMiddleware
-
-from config import (
-    ALBERT_API_KEY,
-    ALBERT_BASE_URL,
-    LANGUAGE_MODEL,
-    EMBEDDINGS_MODEL,
-    APP_NAME,
-    APP_VERSION,
-    BACKEND_CORS_ORIGINS,
-    ENV,
-)
-from deps import get_redis
-from endpoints import api, misc
-from subscriptions import Listener
-
-
-def init_redis(r: Redis):
-    app.state.listener = Listener(r, ["spp-exp-channel"])
-    app.state.listener.start()
-
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
-    # Startup code
-    if ENV != "unittest":
-        r = next(get_redis(finally_close=False))
-        init_redis(r)
-
-        request = requests.get(f"{ALBERT_BASE_URL}/models", headers={"Authorization": f"Bearer {ALBERT_API_KEY}"})
-        request.raise_for_status()
-        models = [model["id"] for model in request.json()["data"]]
-        assert LANGUAGE_MODEL in models, f"Model {LANGUAGE_MODEL} not found"
-        assert EMBEDDINGS_MODEL in models, f"Model {EMBEDDINGS_MODEL} not found"
-
-    yield
-
-    # Shutdown code
-    app.state.listener.stop()
-
-
-# Init server
-app = FastAPI(
-    title=APP_NAME,
-    version=APP_VERSION,
-    contact={
-        "name": "Etalab",
-        "url": "https://www.etalab.gouv.fr/",
-        "email": "etalab@modernisation.gouv.fr",
-    },
-    lifespan=lifespan,
-)
-
-if BACKEND_CORS_ORIGINS:
-    app.add_middleware(
-        CORSMiddleware,
-        allow_origins=BACKEND_CORS_ORIGINS,
-        allow_credentials=True,
-        allow_methods=["*"],
-        allow_headers=["*"],
-    )
-
-app.include_router(misc.router)
-app.include_router(api.router)
diff --git a/app/Dockerfile b/app/Dockerfile
new file mode 100644
index 0000000..23b9173
--- /dev/null
+++ b/app/Dockerfile
@@ -0,0 +1,12 @@
+FROM python:3.12-slim
+
+RUN groupadd --gid 1100 albert
+RUN useradd --home /home/albert --gid 1100 --uid 1100 albert
+USER albert
+
+WORKDIR /home/albert
+ADD ./pyproject.toml ./pyproject.toml
+RUN pip install .
+ADD ./app /home/albert/app
+ENV PYTHONPATH="/home/albert/app:${PYTHONPATH}"
+ENV PATH="/home/albert/.local/bin:${PATH}"
\ No newline at end of file
diff --git a/api_keys.example.json b/app/api_keys.example.json
similarity index 100%
rename from api_keys.example.json
rename to app/api_keys.example.json
diff --git a/config.py b/app/config.py
similarity index 90%
rename from config.py
rename to app/config.py
index c7ecd5f..bc2a24c 100755
--- a/config.py
+++ b/app/config.py
@@ -10,8 +10,9 @@
 REDIS_HOST = os.environ.get("REDIS_HOST", "redis")
 REDIS_PORT = int(os.environ.get("REDIS_PORT", "6379"))
 REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD", None)
+
 LANGUAGE_MODEL = os.environ["LANGUAGE_MODEL"]
 EMBEDDINGS_MODEL = os.environ["EMBEDDINGS_MODEL"]
 ALBERT_BASE_URL = os.environ["ALBERT_BASE_URL"]
 ALBERT_API_KEY = os.environ["ALBERT_API_KEY"]
-COLLECTION = os.getenv("COLLECTION", "plus-transformation-gouv-fr")
+COLLECTION_ID = os.environ["COLLECTION_ID"]
diff --git a/deps.py b/app/deps.py
similarity index 85%
rename from deps.py
rename to app/deps.py
index 44b76b7..718a54d 100755
--- a/deps.py
+++ b/app/deps.py
@@ -1,6 +1,6 @@
 from redis import Redis
 
-from config import REDIS_HOST, REDIS_PASSWORD, REDIS_PORT
+from app.config import REDIS_HOST, REDIS_PASSWORD, REDIS_PORT
 
 
 def get_redis(finally_close=True) -> Redis:
diff --git a/llm.py b/app/llm.py
similarity index 89%
rename from llm.py
rename to app/llm.py
index 64ea804..d963ae6 100755
--- a/llm.py
+++ b/app/llm.py
@@ -1,6 +1,6 @@
 import requests
 
-from config import ALBERT_BASE_URL, ALBERT_API_KEY, LANGUAGE_MODEL, EMBEDDINGS_MODEL, COLLECTION
+from app.config import ALBERT_API_KEY, ALBERT_BASE_URL, COLLECTION_ID, EMBEDDINGS_MODEL, LANGUAGE_MODEL
 
 
 def few_shots(prompt: str):
@@ -34,7 +34,7 @@ def few_shots(prompt: str):
     Veuillez apporter une réponse circonstanciée à cette question en respectant scrupuleusement les directives énoncées ci-dessus.
""" data = { - "collections": [COLLECTION], + "collections": [COLLECTION_ID], "model": EMBEDDINGS_MODEL, "k": 4, "prompt": prompt, @@ -44,7 +44,7 @@ def few_shots(prompt: str): response = response.json() context = "\n\n\n".join([ - f"Question: {result['chunk']['metadata'].get('question', 'N/A')}\n" f"Réponse: {result['chunk']['metadata'].get('answer', 'N/A')}" + f"Question: {result["chunk"]["metadata"].get("question", "N/A")}\n" f"Réponse: {result["chunk"]["metadata"].get("answer", "N/A")}" for result in response["data"] ]) diff --git a/app/main.py b/app/main.py new file mode 100755 index 0000000..49bdf59 --- /dev/null +++ b/app/main.py @@ -0,0 +1,108 @@ +from contextlib import asynccontextmanager +import datetime as dt +import json +from typing import List, Union +import uuid + +from fastapi import Body, Depends, FastAPI, HTTPException, Response, Security +from redis import Redis +import requests +from starlette.middleware.cors import CORSMiddleware + +from app.config import ( + ALBERT_API_KEY, + ALBERT_BASE_URL, + APP_NAME, + APP_VERSION, + EMBEDDINGS_MODEL, + ENV, + LANGUAGE_MODEL, +) +from app.deps import get_redis +from app.schemas import ExpId, ExpIdWithText +from app.security import check_api_key +from app.subscriptions import Listener + + +def init_redis(r: Redis): + app.state.listener = Listener(r, ["spp-exp-channel"]) + app.state.listener.start() + + +@asynccontextmanager +async def lifespan(app: FastAPI): + # Startup code + if ENV != "unittest": + r = next(get_redis(finally_close=False)) + init_redis(r) + + request = requests.get(f"{ALBERT_BASE_URL}/models", headers={"Authorization": f"Bearer {ALBERT_API_KEY}"}) + request.raise_for_status() + models = [model["id"] for model in request.json()["data"]] + assert LANGUAGE_MODEL in models, f"Model {LANGUAGE_MODEL} not found" + assert EMBEDDINGS_MODEL in models, f"Model {EMBEDDINGS_MODEL} not found" + + yield + + # Shutdown code + app.state.listener.stop() + + +app = FastAPI(title=APP_NAME, version=APP_VERSION, lifespan=lifespan) +app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"]) + + +@app.get("/health") +def healt() -> dict[str, str]: + return Response(status_code=200) + + +@app.post("/anonymize") +def anonymize( + form_data: Union[ExpIdWithText, List[ExpIdWithText]] = Body(...), + redis: Redis = Depends(get_redis), + api_key: str = Security(check_api_key), +): + if not isinstance(form_data, list): + form_data = [form_data] + + for data in form_data: + if not data.id: + # see https://tchap.gouv.fr/#/room/!ZyhOfCwElHmyNMSlcw:agent.dinum.tchap.gouv.fr/$XMeXbIDhGtXBycZu-9Px2frsczn_iU7xiJ5xvjbs-pQ?via=agent.dinum.tchap.gouv.fr&via=agent.externe.tchap.gouv.fr&via=agent.tchap.gouv.fr + data.id = str(uuid.uuid4()) + + data = data.model_dump() + + data["time"] = dt.datetime.now(dt.timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f%z") + print(f"anonymize - {data["id"]}: {data["time"]}") # TODO: replace with logger later + redis.publish("spp-exp-channel", json.dumps(data)) + + if len(form_data) == 1: + responseOutput = {"id": form_data[0].id} + else: + # The spec is ill-defined !! 
+        responseOutput = [{"id": x.id} for x in form_data]
+
+    return {"body": responseOutput}
+
+
+@app.post("/prod/run/ditp-get-data")
+def ditp_get_data(form_data: Union[ExpId, List[ExpId]] = Body(...), redis: Redis = Depends(get_redis), api_key: str = Security(check_api_key)):
+    if not isinstance(form_data, list):
+        form_data = [form_data]
+
+    answers = []
+    for data in form_data:
+        data = data.model_dump()
+        answer = redis.get(data["id"])
+        answers.append(answer)
+
+    if len(form_data) == 1:
+        if answers[0] is None:
+            raise HTTPException(status_code=400, detail="ID not found")
+        responseOutput = {"generated_answer": answers[0]}
+    else:
+        # The spec is ill-defined !!
+        responseOutput = [{"generated_answer": x} for x in answers]
+
+    return {"body": responseOutput}
diff --git a/schemas/api.py b/app/schemas.py
similarity index 100%
rename from schemas/api.py
rename to app/schemas.py
diff --git a/app/scripts/arena.py b/app/scripts/arena.py
new file mode 100644
index 0000000..6697bb6
--- /dev/null
+++ b/app/scripts/arena.py
@@ -0,0 +1,123 @@
+###########################################################
+###################### DEPRECATED #########################
+###########################################################
+
+# import os
+# import re
+
+# os.environ["API_URL"] = "https://franceservices.dev.etalab.gouv.fr"
+# os.environ["ELASTIC_HOST"] = "albert.bdd.001.etalab.gouv.fr"
+# os.environ["ELASTIC_PORT"] = "39200"
+# os.environ["ELASTIC_PASSWORD"] = os.getenv("ELASTIC_PASSWORD")
+
+# import pandas as pd
+# from openai import OpenAI
+# from pyalbert import set_llm_table
+
+# # =============================
+# # !pip install pyalbert==0.7.3
+# # =============================
+# from pyalbert.clients import LlmClient
+# from pyalbert.prompt import Prompter
+
+# # Set model locations
+# # --
+# albert_api_key = os.getenv("ALBERT_API_KEY")
+# jeanzay_api_key = os.getenv("JEANZAY_API_KEY")
+
+# LLM_TABLE = [
+#     # {"model": "AgentPublic/fabrique-reference-2", "url": "http://albert.gpu.001.etalab.gouv.fr:8001/v1", "token": albert_api_key, "legacy":True, "prompt_format":"llama2-chat", "template":"spp_fabrique_simple.jinja"},
+#     {
+#         "model": "AgentPublic/llama3-instruct-8b",
+#         "url": "http://albert.gpu.005.etalab.gouv.fr:8001/v1",
+#         "token": albert_api_key,
+#         "template": "spp_fewshots.jinja",
+#     },
+#     # {"model": "meta-llama/Meta-Llama-3.1-8B-Instruct", "url": "http://llama38b.multivacplatform.org/v1/", "token": jeanzay_api_key, "alias"},
+# ]
+# # We just need to point to the embedding model
+# set_llm_table(
+#     [
+#         {"model": "BAAI/bge-m3", "url": "http://albert.gpu.005.etalab.gouv.fr:8001"},
+#     ]
+# )
+
+# # Load test data
+# # --
+# file_path = "_data/export-expa-c-riences.json"
+# df = pd.read_json(file_path)
+
+# # Filter the df on the given attributes
+# domains = ["MSA", "CNAV"]
+# df = df[df["intitule_typologie_1"].isin(domains)]
+
+# # Randomly sample 12 items from the DataFrame
+# small_df = df.sample(n=12, random_state=1)  # random_state is optional for reproducibility
+# del df
+
+
+# # Just for pedagogical purposes
+# def legacy_generate(messages, model=None, base_url=None, api_key=None, **sampling_params):
+#     client = OpenAI(
+#         base_url=base_url,
+#         api_key=api_key,
+#     )
+
+#     completion = client.chat.completions.create(model=model, messages=messages, **sampling_params)
+
+#     return completion
+
+
+# # system_prompt = "Tu es un générateur de réponse automatique à une expérience utilisateur. Tu réponds directement, sans explications."
+# system_prompt = "Tu es un générateur de réponse automatique à une expérience utilisateur. Tu parles un français courtois." +# spp_sampling_params = { +# "temperature": 0.25, +# "max_tokens": 4096, +# } + +# results = [] +# for i, item in small_df.iterrows(): +# query = item["description"] +# row = { +# "query": query, +# "reponse_SPP": item["reponse_structure_1"], +# "institution": item["intitule_typologie_1"], +# } +# for model_ in LLM_TABLE: +# llm_client = LlmClient(model_["model"], base_url=model_["url"], api_key=model_["token"]) +# model_name = model_["model"].split("/")[-1] + +# # Build the prompt/messages +# # -- +# config = { +# "do_encode_prompt": model_.get("legacy", False), +# "prompt_format": model_.get("prompt_format"), +# "sampling_params": spp_sampling_params, +# "default": {"limit": 5}, +# } +# prompter = Prompter(config=config, template=model_.get("template")) +# prompt = prompter.make_prompt(query=query, system_prompt=system_prompt) + +# # Gnerate the answer +# # -- +# sampling_params = prompter.get_upstream_sampling_params() # eventual sampling param defined in the prompter config/mode +# try: +# result = llm_client.generate(prompt, **sampling_params) +# # Stricly equivalent here to: +# # result = legacy_generate(prompt, model=model_["model"], base_url=model_["url"], api_key=model_["token"], **sampling_params) +# except: +# result = llm_client.generate(prompt, **sampling_params) + +# # Remove artefact from the answer +# answer = result.choices[0].message.content +# answer = re.sub(r"^<[^>]+>|<[^>]+>$", "", answer.strip("\n \"'#`")) +# if answer.startswith("Réponse :"): +# answer = answer[len("Réponse :") :] +# model_name = model_.get("alias", model_name) +# row[f"answser_{model_name}"] = answer.strip() +# print(".", end="", flush=True) + +# results.append(row) + +# df = pd.DataFrame(results) +# df.to_csv("spp_arena.csv", index=False) diff --git a/scripts/spp_fewshots.jinja b/app/scripts/spp_fewshots.jinja similarity index 85% rename from scripts/spp_fewshots.jinja rename to app/scripts/spp_fewshots.jinja index 282ae79..d8e1a9c 100644 --- a/scripts/spp_fewshots.jinja +++ b/app/scripts/spp_fewshots.jinja @@ -1,5 +1,9 @@ Vous êtes un agent expérimenté de l'administration française, spécialisé dans les questions administratives. Votre mission est de répondre aux interrogations des usagers avec professionnalisme, courtoisie et précision. +########################################################### +###################### DEPRECATED ######################### +########################################################### + Instructions : 1. Utilisez un langage soutenu et élégant, tout en restant accessible. 2. Inspirez vous de la base de connaissance pour répondre. 
diff --git a/security.py b/app/security.py
similarity index 94%
rename from security.py
rename to app/security.py
index 4f7da95..96ac9c7 100755
--- a/security.py
+++ b/app/security.py
@@ -3,7 +3,7 @@
 from fastapi import HTTPException, Depends
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 
-from config import API_KEYS, ENV
+from app.config import API_KEYS, ENV
 
 
 if API_KEYS and ENV not in ["dev", "unittest"]:
diff --git a/subscriptions.py b/app/subscriptions.py
similarity index 76%
rename from subscriptions.py
rename to app/subscriptions.py
index 3eb6e0a..0c05ed6 100755
--- a/subscriptions.py
+++ b/app/subscriptions.py
@@ -5,7 +5,7 @@
 
 from redis import Redis
 
-from llm import few_shots
+from app.llm import few_shots
 
 
 class Listener(threading.Thread):
@@ -13,7 +13,7 @@ class Listener(threading.Thread):
     KILL_PILL = "EXIT"
 
     def __init__(self, r: Redis, channels):
-        print("info: listener init")  # TODO: replace with logger later
+        logging.info("listener init")
         threading.Thread.__init__(self)
         self.redis = r
         self.pubsub = self.redis.pubsub()
@@ -21,7 +21,7 @@ def __init__(self, r: Redis, channels):
         self.pubsub.subscribe(self.KILL_PILL)  # Subscribe to a special stop channel
 
     def run(self):
-        print("info: listener run")  # TODO: replace with logger later
+        logging.info("listener run")
         for item in self.pubsub.listen():
             if item["type"] == "message" and item["channel"] == self.KILL_PILL.encode():
                 break
@@ -30,17 +30,16 @@ def run(self):
             data = json.loads(item["data"])
 
             duration = dt.datetime.now(dt.timezone.utc) - dt.datetime.strptime(data["time"], "%Y-%m-%d %H:%M:%S.%f%z")
-            print(f"duration time - {data['id']}: {duration.total_seconds()} s")
+            logging.debug(f"duration time - {data["id"]}: {duration.total_seconds()} s")
 
-            # do not fail silently
             try:
                 answer = few_shots(prompt=data["text"])
             except Exception:
                 import traceback
 
                 error_traceback = traceback.format_exc()
-                logging.error(f"\nRequest prompt:\n{data['text']}\n\nError:\n{error_traceback}")
-                answer = f"error on {data['id']} request, please resend prompt later."
+                logging.error(f"\nRequest prompt:\n{data["text"]}\n\nError:\n{error_traceback}")
+                answer = f"error on {data["id"]} request, please resend prompt later."
 
             self.redis.set(
                 name=data["id"],  # key
diff --git a/tests/conftest.py b/app/tests/conftest.py
similarity index 71%
rename from tests/conftest.py
rename to app/tests/conftest.py
index aa5bfcb..f77a276 100755
--- a/tests/conftest.py
+++ b/app/tests/conftest.py
@@ -2,7 +2,7 @@
 import time
 from pathlib import Path
 from typing import Generator
-from urllib.parse import urlparse
+
 
 import fakeredis
 import pytest
@@ -10,8 +10,8 @@
 from fastapi.testclient import TestClient
 from pytest import fail
 
-from app import app, init_redis
-from deps import get_redis
+from app import init_redis, main
+from app.deps import get_redis
 
 
 def log_and_assert(response, code):
@@ -23,7 +23,7 @@ def log_and_assert(response, code):
         fail(f"Expected status code 200, but got {response.status_code}.\nError details: {response.text}")
 
 
-def start_mock_server(command, health_route="/healthcheck", timeout=10, interval=1, cwd=None):
+def start_mock_server(command, timeout=10, interval=1, cwd=None):
     """Starts a mock server using subprocess.Popen and waits for it
     to be ready by polling a health check endpoint.
""" @@ -33,9 +33,7 @@ def start_mock_server(command, health_route="/healthcheck", timeout=10, interval end_time = time.time() + timeout while True: try: - host = "localhost" - port = command[-1] - response = requests.get(f"http://{host}:{port}" + health_route) + response = requests.get("http://localhost:8000/health") if response.status_code == 200: # Server is ready break @@ -44,7 +42,7 @@ def start_mock_server(command, health_route="/healthcheck", timeout=10, interval pass if time.time() > end_time: - raise RuntimeError("Timeout waiting for server to start") + raise RuntimeError("Timeout waiting for server to start.") time.sleep(interval) except Exception as e: @@ -55,7 +53,7 @@ def start_mock_server(command, health_route="/healthcheck", timeout=10, interval # -# API mockups +# Albert API mockup # APP_FOLDER = Path(__file__).parents[1] @@ -63,20 +61,7 @@ def start_mock_server(command, health_route="/healthcheck", timeout=10, interval @pytest.fixture(scope="session") def mock_llm() -> Generator: - if len(LLM_TABLE) > 0: - LLM_HOST, LLM_PORT = urlparse(LLM_TABLE[0]["url"]).netloc.split(":") - - process = start_mock_server(["uvicorn", "tests.mockups.llm:app", "--port", LLM_PORT], cwd=APP_FOLDER) - yield - process.kill() - - -@pytest.fixture(scope="session") -def mock_server_es(): - process = start_mock_server( - ["uvicorn", "tests.mockups.elasticsearch:app", "--port", ELASTIC_PORT], - cwd=APP_FOLDER, - ) + process = start_mock_server(["uvicorn", "tests.mockups:app", "--port", 8080], cwd=APP_FOLDER) yield process.kill() @@ -112,12 +97,12 @@ def client(redis_client) -> Generator: def override_get_redis(): yield redis_client - app.dependency_overrides[get_redis] = override_get_redis - with TestClient(app) as c: + main.dependency_overrides[get_redis] = override_get_redis + with TestClient(main) as c: yield c # Remove the dependency override after the tests - app.dependency_overrides = {} + main.dependency_overrides = {} class TestApi: @@ -127,6 +112,6 @@ def setup_method(self): def teardown_method(self): pass - def test_mockup(self, mock_llm, mock_server_es, mock_redis): + def test_mockup(self, mock_llm, mock_redis): # Start the server pass diff --git a/tests/mockups/llm.py b/app/tests/mockups.py similarity index 96% rename from tests/mockups/llm.py rename to app/tests/mockups.py index 9620a12..1101a4d 100755 --- a/tests/mockups/llm.py +++ b/app/tests/mockups.py @@ -1,7 +1,7 @@ import time from typing import Optional -from fastapi import FastAPI +from fastapi import FastAPI, Response from pydantic import BaseModel, ConfigDict, Field app = FastAPI() @@ -77,9 +77,9 @@ class EmbeddingResponse(BaseModel): # endpoints -@app.get("/healthcheck") +@app.get("/health") async def healthcheck(): - return "ok" + return Response(status_code=200) @app.post("/chat/completions", response_model=ChatCompletionResponse) diff --git a/tests/test_main.py b/app/tests/test_main.py similarity index 96% rename from tests/test_main.py rename to app/tests/test_main.py index c3b7a68..46fdf2b 100755 --- a/tests/test_main.py +++ b/app/tests/test_main.py @@ -2,7 +2,7 @@ from fastapi.testclient import TestClient -from config import APP_VERSION +from app.config import APP_VERSION from tests.conftest import TestApi, log_and_assert diff --git a/docker-compose.yml b/compose.yml similarity index 75% rename from docker-compose.yml rename to compose.yml index 41284c7..0335d18 100644 --- a/docker-compose.yml +++ b/compose.yml @@ -12,21 +12,19 @@ services: volumes: - redis:/data - api: - image: albert-spp/api:latest - build: - context: . 
-      dockerfile: Dockerfile
+  app:
+    image: ghcr.io/etalab-ia/albert-spp/app:latest
     ports:
       - 8000:8000
     restart: always
-    command: uvicorn api.app:app --proxy-headers --forwarded-allow-ips='*' --host 0.0.0.0 --port 8000 --log-level debug
+    command: uvicorn app.main:app --proxy-headers --forwarded-allow-ips='*' --host 0.0.0.0 --port 8000 --log-level debug
     environment:
       - API_KEYS=changeme
       - ENV=dev
       - APP_VERSION=1.0.0
       - ALBERT_BASE_URL=http://localhost:8080/v1
       - ALBERT_API_KEY=changeme
+      - COLLECTION_ID=1
       - LANGUAGE_MODEL=AgentPublic/llama3-fabrique-texte
       - EMBEDDINGS_MODEL=AgentPublic/llama3-fabrique-texte
     depends_on:
diff --git a/endpoints/__init__.py b/endpoints/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/endpoints/api.py b/endpoints/api.py
deleted file mode 100755
index 949ced0..0000000
--- a/endpoints/api.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import datetime as dt
-import json
-import uuid
-from typing import List, Union
-
-from fastapi import APIRouter, Body, Depends, HTTPException, Security
-from redis import Redis
-
-import schemas
-from deps import get_redis
-from security import check_api_key
-
-router = APIRouter(tags=["api"])
-
-
-@router.post("/anonymize")
-def anonymize(
-    form_data: Union[schemas.ExpIdWithText, List[schemas.ExpIdWithText]] = Body(...),
-    redis: Redis = Depends(get_redis),
-    api_key: str = Security(check_api_key),
-):
-    if not isinstance(form_data, list):
-        form_data = [form_data]
-
-    for data in form_data:
-        if not data.id:
-            # see https://tchap.gouv.fr/#/room/!ZyhOfCwElHmyNMSlcw:agent.dinum.tchap.gouv.fr/$XMeXbIDhGtXBycZu-9Px2frsczn_iU7xiJ5xvjbs-pQ?via=agent.dinum.tchap.gouv.fr&via=agent.externe.tchap.gouv.fr&via=agent.tchap.gouv.fr
-            data.id = str(uuid.uuid4())
-
-        data = data.model_dump()
-
-        data["time"] = dt.datetime.now(dt.timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f%z")
-        print(f"anonymize - {data['id']}: {data['time']}")  # TODO: replace with logger later
-        redis.publish("spp-exp-channel", json.dumps(data))
-
-    if len(form_data) == 1:
-        responseOutput = {"id": form_data[0].id}
-    else:
-        # The spec is ill-defined !!
-        responseOutput = [{"id": x.id} for x in form_data]
-
-    return {"body": responseOutput}
-
-
-@router.post("/prod/run/ditp-get-data")
-def ditp_get_data(
-    form_data: Union[schemas.ExpId, List[schemas.ExpId]] = Body(...),
-    redis: Redis = Depends(get_redis),
-    api_key: str = Security(check_api_key),
-):
-    if not isinstance(form_data, list):
-        form_data = [form_data]
-
-    answers = []
-    for data in form_data:
-        data = data.model_dump()
-        answer = redis.get(data["id"])
-        answers.append(answer)
-
-    if len(form_data) == 1:
-        if answers[0] is None:
-            raise HTTPException(status_code=400, detail="ID not found")
-        responseOutput = {"generated_answer": answers[0]}
-    else:
-        # The spec is ill-defined !!
-        responseOutput = [{"generated_answer": x} for x in answers]
-
-    return {"body": responseOutput}
diff --git a/endpoints/misc.py b/endpoints/misc.py
deleted file mode 100755
index cc894e3..0000000
--- a/endpoints/misc.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from fastapi import APIRouter
-
-from config import APP_VERSION
-
-router = APIRouter(tags=["misc"])
-
-
-@router.get("/healthcheck")
-def get_healthcheck() -> dict[str, str]:
-    return {"msg": "OK", "version": APP_VERSION}
diff --git a/pyproject.toml b/pyproject.toml
index e32d663..1fa79ce 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,16 +1,14 @@
 [project]
 name = "albert-spp"
 version = "1.0.3"
-requires-python = ">=3.10"
+requires-python = ">=3.12"
 license = { text = "MIT" }
 dependencies = [
     "fastapi==0.111.1",
     "pydantic==2.8.2",
-    "python-dotenv==1.0.1",
     "uvicorn==0.30.1",
     "redis==5.0.2",
     "requests==2.32.3",
-    "jinja2==3.1.4",
 ]
 
 [project.optional-dependencies]
diff --git a/schemas/__init__.py b/schemas/__init__.py
deleted file mode 100755
index e5dd28a..0000000
--- a/schemas/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .api import ExpId, ExpIdWithText
-
-__all__ = ["ExpId", "ExpIdWithText"]
diff --git a/scripts/arena.py b/scripts/arena.py
deleted file mode 100644
index ed40853..0000000
--- a/scripts/arena.py
+++ /dev/null
@@ -1,117 +0,0 @@
-import os
-import re
-
-os.environ["API_URL"] = "https://franceservices.dev.etalab.gouv.fr"
-os.environ["ELASTIC_HOST"] = "albert.bdd.001.etalab.gouv.fr"
-os.environ["ELASTIC_PORT"] = "39200"
-os.environ["ELASTIC_PASSWORD"] = os.getenv("ELASTIC_PASSWORD")
-
-import pandas as pd
-from openai import OpenAI
-from pyalbert import set_llm_table
-
-# ============================
-# !pip instal pyalbert==0.7.3
-# ============================
-from pyalbert.clients import LlmClient
-from pyalbert.prompt import Prompter
-
-# Set model locations
-# --
-albert_api_key = os.getenv("ALBERT_API_KEY")
-jeanzay_api_key = os.getenv("JEANZAY_API_KEY")
-
-LLM_TABLE = [
-    #{"model": "AgentPublic/fabrique-reference-2", "url": "http://albert.gpu.001.etalab.gouv.fr:8001/v1", "token": albert_api_key, "legacy":True, "prompt_format":"llama2-chat", "template":"spp_fabrique_simple.jinja"},
-    {"model": "AgentPublic/llama3-instruct-8b", "url": "http://albert.gpu.005.etalab.gouv.fr:8001/v1", "token": albert_api_key, "template":"spp_fewshots.jinja"},
-    #{"model": "meta-llama/Meta-Llama-3.1-8B-Instruct", "url": "http://llama38b.multivacplatform.org/v1/", "token": jeanzay_api_key, "alias"},
-
-    ]
-# We just need or to point the embedding model
-set_llm_table([
-    {"model": "BAAI/bge-m3", "url": "http://albert.gpu.005.etalab.gouv.fr:8001" },
-])
-
-# Load test data
-# --
-file_path = '_data/export-expa-c-riences.json'
-df = pd.read_json(file_path)
-
-# Filter the df on the given attributes
-domains = ["MSA", "CNAV"]
-df = df[df['intitule_typologie_1'].isin(domains)]
-
-# Randomly sample 10 items from the DataFrame
-small_df = df.sample(n=12, random_state=1)  # random_state is optional for reproducibility
-del df
-
-
-# Just for pedogical purpose
-def legacy_generate(messages, model=None, base_url=None, api_key=None, **sampling_params):
-    client = OpenAI(
-        base_url=base_url,
-        api_key=api_key,
-    )
-
-    completion = client.chat.completions.create(
-        model=model,
-        messages=messages,
-        **sampling_params
-    )
-
-    return completion
-
-#system_prompt = "Tu es un générateur de réponse automatique à une expérience utilisateur. Tu réponds directement, sans explications."
-system_prompt = "Tu es un générateur de réponse automatique à une expérience utilisateur. Tu parles un français courtois." -spp_sampling_params = { - "temperature": 0.25, - "max_tokens": 4096, -} - -results = [] -for i, item in small_df.iterrows(): - query = item["description"] - row = { - "query": query, - "reponse_SPP": item["reponse_structure_1"], - "institution": item["intitule_typologie_1"], - } - for model_ in LLM_TABLE: - llm_client = LlmClient(model_["model"], base_url=model_["url"], api_key=model_["token"]) - model_name = model_["model"].split("/")[-1] - - # Build the prompt/messages - # -- - config = { - "do_encode_prompt": model_.get("legacy", False), - "prompt_format": model_.get("prompt_format"), - "sampling_params": spp_sampling_params, - "default": {"limit":5} - } - prompter = Prompter(config=config, template=model_.get("template")) - prompt = prompter.make_prompt(query=query, system_prompt=system_prompt) - - # Gnerate the answer - # -- - sampling_params = prompter.get_upstream_sampling_params() # eventual sampling param defined in the prompter config/mode - try: - result = llm_client.generate(prompt, **sampling_params) - # Stricly equivalent here to: - #result = legacy_generate(prompt, model=model_["model"], base_url=model_["url"], api_key=model_["token"], **sampling_params) - except: - result = llm_client.generate(prompt, **sampling_params) - - # Remove artefact from the answer - answer = result.choices[0].message.content - answer = re.sub(r'^<[^>]+>|<[^>]+>$', '', answer.strip("\n \"'#`")) - if answer.startswith("Réponse :"): - answer = answer[len("Réponse :"):] - model_name = model_.get("alias", model_name) - row[f"answser_{model_name}"] = answer.strip() - print(".", end="", flush=True) - - results.append(row) - -df = pd.DataFrame(results) -df.to_csv('spp_arena.csv', index=False) - diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/mockups/elasticsearch.py b/tests/mockups/elasticsearch.py deleted file mode 100755 index ae2b662..0000000 --- a/tests/mockups/elasticsearch.py +++ /dev/null @@ -1,52 +0,0 @@ -from fastapi import FastAPI -from fastapi.responses import JSONResponse - -app = FastAPI() - - -@app.get("/healthcheck") -async def healthcheck(): - return "ok" - - -@app.post("/{index_name}/_search") -async def search(index_name: str): - data = [] - - data = { - "took": 30, - "timed_out": False, - "_shards": {"total": 5, "successful": 5, "skipped": 0, "failed": 0}, - "hits": { - "total": {"value": 1000, "relation": "eq"}, - "max_score": 1.3862944, - "hits": [ - { - "_index": index_name, - "_type": "_doc", - "_id": "1", - "_score": 1.3862944, - "_source": { - "id_experience": "some_id", - "titre": "Titre", - "description": "Description", - "reponse_structure_1": "Reponse Structure 1", - }, - }, - ], - }, - "aggregations": { - "categories": { - "doc_count_error_upper_bound": 0, - "sum_other_doc_count": 0, - "buckets": [ - {"key": "Programming", "doc_count": 2}, - {"key": "Data Science", "doc_count": 1}, - ], - } - }, - } - - response = JSONResponse(data) - response.headers["X-Elastic-Product"] = "Elasticsearch" - return response diff --git a/tests/mockups/prompt_config.yml b/tests/mockups/prompt_config.yml deleted file mode 100644 index bb71952..0000000 --- a/tests/mockups/prompt_config.yml +++ /dev/null @@ -1,7 +0,0 @@ -prompt_format: llama3-chat -max_tokens: 4096 - -prompts: - - mode: simple - template: simple_prompt_template.jinja - diff --git a/tests/mockups/simple_prompt_template.jinja 
deleted file mode 100644
index 2a6572e..0000000
--- a/tests/mockups/simple_prompt_template.jinja
+++ /dev/null
@@ -1,10 +0,0 @@
-Mode simple
-
-Question soumise au service {% if institution %}{{institution}} {% endif %}: {{query}}
-
-{% if context or links %}{{"\n"}}
-Prompt:{% if context %}{{context}}{% endif %}{% if links %}{{links}}{% endif %}
-{% endif %}
-
-###Réponse :
-
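
A minimal smoke test of the refactored layout, assuming the compose stack above is running locally with the example values from compose.yml (API_KEYS=changeme, app exposed on port 8000); the endpoint paths come from app/main.py, and the id and text values are placeholders:

```sh
# Start the stack defined in compose.yml (redis + app)
docker compose up --detach

# Liveness probe: app/main.py exposes GET /health and returns 200
curl -i http://localhost:8000/health

# Publish a prompt on the spp-exp-channel via POST /anonymize
curl -X POST http://localhost:8000/anonymize \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer changeme" \
  -d '{"id": "123", "text": "Merci pour ce service."}'

# Fetch the generated answer once the Redis listener has processed it
curl -X POST http://localhost:8000/prod/run/ditp-get-data \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer changeme" \
  -d '{"id": "123"}'
```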