-
Notifications
You must be signed in to change notification settings - Fork 38
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #19 from aws-samples/api
feat: Api
- Loading branch information
Showing
27 changed files
with
1,152 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
FROM python:3.10-slim | ||
|
||
WORKDIR /app | ||
|
||
COPY . /app/ | ||
RUN pip3 install -r requirements-api.txt | ||
|
||
EXPOSE 8000 | ||
|
||
ENTRYPOINT ["uvicorn", "main:app", "--host", "0.0.0.0"] |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
class _Const(object): | ||
class ConstError(TypeError): | ||
def __init__(self, msg): | ||
super().__init__(msg) | ||
|
||
def __setattr__(self, name, value): | ||
if name in self.__dict__: | ||
err = self.ConstError("Can't change const.%s" % name) | ||
raise err | ||
if not name.isupper(): | ||
err = self.ConstError('Const name "%s" is not all uppercase' % name) | ||
raise err | ||
self.__dict__[name] = value | ||
|
||
|
||
const = _Const() | ||
|
||
const.MODE = 'mode' | ||
const.MODE_DEV = 'dev' | ||
const.BEDROCK_MODEL_IDS = ['anthropic.claude-3-sonnet-20240229-v1:0', | ||
'anthropic.claude-3-haiku-20240307-v1:0', | ||
'anthropic.claude-v2:1', | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from enum import Enum, unique | ||
from .constant import const | ||
|
||
|
||
@unique | ||
class ErrorEnum(Enum): | ||
SUCCEEDED = {1: "Operation succeeded"} | ||
NOT_SUPPORTED = {1001: "Your query statement is currently not supported by the system"} | ||
INVAILD_BEDROCK_MODEL_ID = {1002: f"Invalid bedrock model id.Vaild ids:{const.BEDROCK_MODEL_IDS}"} | ||
UNKNOWN_ERROR = {9999: "Unknown error."} | ||
|
||
def get_code(self): | ||
return list(self.value.keys())[0] | ||
|
||
def get_message(self): | ||
return list(self.value.values())[0] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import os | ||
from fastapi.responses import JSONResponse | ||
from fastapi import status, FastAPI, Request, Response | ||
from fastapi.exceptions import RequestValidationError | ||
from .enum import ErrorEnum | ||
from .constant import const | ||
import traceback | ||
from loguru import logger | ||
|
||
|
||
def response_error(code: int, message: str, status_code: int = status.HTTP_400_BAD_REQUEST) -> Response: | ||
headers = {} | ||
if os.getenv(const.MODE) == const.MODE_DEV: | ||
headers = { | ||
'Access-Control-Allow-Origin': '*', | ||
'Access-Control-Allow-Headers': '*', | ||
'Access-Control-Allow-Methods': '*', | ||
'Access-Control-Allow-Credentials': 'true' | ||
} | ||
return JSONResponse( | ||
content={ | ||
'code': code, | ||
'message': message, | ||
}, | ||
headers=headers, | ||
status_code=status_code, | ||
) | ||
|
||
|
||
def biz_exception(app: FastAPI): | ||
# customize request validation error | ||
@app.exception_handler(RequestValidationError) | ||
async def val_exception_handler(req: Request, rve: RequestValidationError, code: int = status.HTTP_422_UNPROCESSABLE_ENTITY): | ||
lst = [] | ||
for error in rve.errors(): | ||
lst.append('{}=>{}'.format('.'.join(error['loc']), error['msg'])) | ||
return response_error(code, ' , '.join(lst)) | ||
|
||
# customize business error | ||
@app.exception_handler(BizException) | ||
async def biz_exception_handler(req: Request, exc: BizException): | ||
return response_error(exc.code, exc.message) | ||
|
||
# system error | ||
@app.exception_handler(Exception) | ||
async def exception_handler(req: Request, exc: Exception): | ||
if isinstance(exc, BizException): | ||
return | ||
error_msg = traceback.format_exc() | ||
logger.error(error_msg) | ||
return response_error(ErrorEnum.UNKNOWN_ERROR.get_code(), error_msg, status.HTTP_500_INTERNAL_SERVER_ERROR) | ||
|
||
|
||
class BizException(Exception): | ||
def __init__(self, error_message: ErrorEnum): | ||
self.code = error_message.get_code() | ||
self.message = error_message.get_message() | ||
|
||
|
||
def __msg__(self): | ||
return self.message |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import json | ||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect | ||
from loguru import logger | ||
from .schemas import Question, Answer, Option | ||
from . import service | ||
|
||
router = APIRouter(prefix="/qa", tags=["qa"]) | ||
|
||
|
||
@router.get("/option", response_model=Option) | ||
def option(): | ||
return service.get_option() | ||
|
||
|
||
@router.post("/ask", response_model=Answer) | ||
def ask(question: Question): | ||
return service.ask(question) | ||
|
||
|
||
@router.websocket("/ws") | ||
async def websocket_endpoint(websocket: WebSocket): | ||
await websocket.accept() | ||
try: | ||
while True: | ||
data = await websocket.receive_text() | ||
question_json = json.loads(data) | ||
question = Question(**question_json) | ||
logger.info(question) | ||
current_nlq_chain = service.get_nlq_chain(question) | ||
if question.use_rag: | ||
examples = service.get_example(current_nlq_chain) | ||
await websocket.send_text("Examples:\n```json\n") | ||
await websocket.send_text(str(examples)) | ||
await websocket.send_text("\n```\n") | ||
response = service.ask_with_response_stream(question, current_nlq_chain) | ||
result_pieces = [] | ||
for event in response['body']: | ||
final_answer = event["chunk"]["bytes"].decode('utf8') | ||
# logger.info(final_answer) | ||
current_content = json.loads(final_answer) | ||
if current_content.get("type") == "content_block_delta": | ||
current_text = current_content.get("delta").get("text") | ||
result_pieces.append(current_text) | ||
await websocket.send_text(current_text) | ||
elif current_content.get("type") == "content_block_stop": | ||
break | ||
current_nlq_chain.set_generated_sql_response(''.join(result_pieces)) | ||
if question.query_result: | ||
final_sql_query_result = service.get_executed_result(current_nlq_chain) | ||
await websocket.send_text("\n\nQuery result: \n") | ||
await websocket.send_text(final_sql_query_result) | ||
await websocket.send_text("\n") | ||
except WebSocketDisconnect: | ||
logger.info(f"{websocket.client.host} disconnected.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from typing import Any | ||
from pydantic import BaseModel | ||
|
||
|
||
class Question(BaseModel): | ||
keywords: str | ||
bedrock_model_id: str = "anthropic.claude-3-sonnet-20240229-v1:0" | ||
use_rag: bool = True | ||
query_result: bool = True | ||
intent_ner_recognition: bool = False | ||
profile_name: str = "shopping_guide" | ||
|
||
|
||
class Example(BaseModel): | ||
score: float | ||
question: str | ||
answer: str | ||
|
||
|
||
class Answer(BaseModel): | ||
examples: list[Example] | ||
sql: str | ||
sql_explain: str | ||
sql_query_result: list[Any] | ||
|
||
|
||
class Option(BaseModel): | ||
data_profiles: list[str] | ||
model_ids: list[str] |
Oops, something went wrong.