Skip to content

Commit

Permalink
Merge pull request #19 from aws-samples/api
Browse files Browse the repository at this point in the history
feat: Api
  • Loading branch information
supinyu authored Apr 11, 2024
2 parents 6a68d50 + 1219ff2 commit 3c0cfa5
Show file tree
Hide file tree
Showing 27 changed files with 1,152 additions and 21 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

*_test.py
# CDK
node_modules/

Expand All @@ -167,4 +168,4 @@ cdk.out
cdk.context.json
**/cdk.out
package-lock.json
!source/resources/lib/
!source/resources/lib/
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Create an EC2 with following configuration:
- OS Image (AMI): Amazon Linux 2023, Amazon Linux 2(AL2 End of Life is 2025-06-30)
- Instance type: t3.large or higher
- VPC: use default one and choose a public subnet
- Security group: Allow access to 22, 80 port from anywhere (Select "Allow SSH traffic from Anywhere" and "Allow HTTP traffic from the internet")
- Security group: Allow access to 22, 80, 8000 port from anywhere (Select "Allow SSH traffic from Anywhere" and "Allow HTTP traffic from the internet")
- Storage (volumes): 1 GP3 volume(s) - 30 GiB

### 2. Config Permission
Expand Down Expand Up @@ -149,6 +149,12 @@ Open in your browser: `http://<your-ec2-public-ip>`

Note: Use HTTP instead of HTTPS.

### 8. Access the API

Open in your browser: `http://<your-ec2-public-ip>:8000`

Note: Use HTTP instead of HTTPS.

## CDK Deployment Guide

### 1. Prepare CDK Pre-requisites
Expand Down
8 changes: 7 additions & 1 deletion README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
- OS镜像(AMI): Amazon Linux 2023, Amazon Linux 2(AL2将在2025-06-30结束支持)
- 实例类型: t3.large或更高配置
- VPC: 使用默认的VPC并部署在公有子网
- 安全组: 允许任何位置访问22, 80端口 (勾选允许来自以下对象的SSH流量和允许来自互联网的HTTP流量)
- 安全组: 允许任何位置访问22, 80, 8000端口 (勾选允许来自以下对象的SSH流量和允许来自互联网的HTTP流量)
- 存储(卷): 1个GP3卷 - 30 GiB

### 2. 配置权限
Expand Down Expand Up @@ -141,6 +141,12 @@ docker exec nlq-webserver python opensearch_deploy.py custom false

注意:使用 HTTP 而不是 HTTPS。

### 8. 访问API

在浏览器中打开网址: `http://<your-ec2-public-ip>:8000`

注意:使用 HTTP 而不是 HTTPS。

## Demo应用使用自定义数据源的方法
1. 先在Data Connection Management和Data Profile Management页面创建对应的Data Profile

Expand Down
10 changes: 10 additions & 0 deletions application/Dockerfile-api
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
FROM python:3.10-slim

WORKDIR /app

COPY . /app/
RUN pip3 install -r requirements-api.txt

EXPOSE 8000

ENTRYPOINT ["uvicorn", "main:app", "--host", "0.0.0.0"]
Empty file added application/api/__init__.py
Empty file.
23 changes: 23 additions & 0 deletions application/api/constant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
class _Const(object):
class ConstError(TypeError):
def __init__(self, msg):
super().__init__(msg)

def __setattr__(self, name, value):
if name in self.__dict__:
err = self.ConstError("Can't change const.%s" % name)
raise err
if not name.isupper():
err = self.ConstError('Const name "%s" is not all uppercase' % name)
raise err
self.__dict__[name] = value


const = _Const()

const.MODE = 'mode'
const.MODE_DEV = 'dev'
const.BEDROCK_MODEL_IDS = ['anthropic.claude-3-sonnet-20240229-v1:0',
'anthropic.claude-3-haiku-20240307-v1:0',
'anthropic.claude-v2:1',
]
16 changes: 16 additions & 0 deletions application/api/enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from enum import Enum, unique
from .constant import const


@unique
class ErrorEnum(Enum):
SUCCEEDED = {1: "Operation succeeded"}
NOT_SUPPORTED = {1001: "Your query statement is currently not supported by the system"}
INVAILD_BEDROCK_MODEL_ID = {1002: f"Invalid bedrock model id.Vaild ids:{const.BEDROCK_MODEL_IDS}"}
UNKNOWN_ERROR = {9999: "Unknown error."}

def get_code(self):
return list(self.value.keys())[0]

def get_message(self):
return list(self.value.values())[0]
61 changes: 61 additions & 0 deletions application/api/exception_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import os
from fastapi.responses import JSONResponse
from fastapi import status, FastAPI, Request, Response
from fastapi.exceptions import RequestValidationError
from .enum import ErrorEnum
from .constant import const
import traceback
from loguru import logger


def response_error(code: int, message: str, status_code: int = status.HTTP_400_BAD_REQUEST) -> Response:
headers = {}
if os.getenv(const.MODE) == const.MODE_DEV:
headers = {
'Access-Control-Allow-Origin': '*',
'Access-Control-Allow-Headers': '*',
'Access-Control-Allow-Methods': '*',
'Access-Control-Allow-Credentials': 'true'
}
return JSONResponse(
content={
'code': code,
'message': message,
},
headers=headers,
status_code=status_code,
)


def biz_exception(app: FastAPI):
# customize request validation error
@app.exception_handler(RequestValidationError)
async def val_exception_handler(req: Request, rve: RequestValidationError, code: int = status.HTTP_422_UNPROCESSABLE_ENTITY):
lst = []
for error in rve.errors():
lst.append('{}=>{}'.format('.'.join(error['loc']), error['msg']))
return response_error(code, ' , '.join(lst))

# customize business error
@app.exception_handler(BizException)
async def biz_exception_handler(req: Request, exc: BizException):
return response_error(exc.code, exc.message)

# system error
@app.exception_handler(Exception)
async def exception_handler(req: Request, exc: Exception):
if isinstance(exc, BizException):
return
error_msg = traceback.format_exc()
logger.error(error_msg)
return response_error(ErrorEnum.UNKNOWN_ERROR.get_code(), error_msg, status.HTTP_500_INTERNAL_SERVER_ERROR)


class BizException(Exception):
def __init__(self, error_message: ErrorEnum):
self.code = error_message.get_code()
self.message = error_message.get_message()


def __msg__(self):
return self.message
54 changes: 54 additions & 0 deletions application/api/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import json
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from loguru import logger
from .schemas import Question, Answer, Option
from . import service

router = APIRouter(prefix="/qa", tags=["qa"])


@router.get("/option", response_model=Option)
def option():
return service.get_option()


@router.post("/ask", response_model=Answer)
def ask(question: Question):
return service.ask(question)


@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
try:
while True:
data = await websocket.receive_text()
question_json = json.loads(data)
question = Question(**question_json)
logger.info(question)
current_nlq_chain = service.get_nlq_chain(question)
if question.use_rag:
examples = service.get_example(current_nlq_chain)
await websocket.send_text("Examples:\n```json\n")
await websocket.send_text(str(examples))
await websocket.send_text("\n```\n")
response = service.ask_with_response_stream(question, current_nlq_chain)
result_pieces = []
for event in response['body']:
final_answer = event["chunk"]["bytes"].decode('utf8')
# logger.info(final_answer)
current_content = json.loads(final_answer)
if current_content.get("type") == "content_block_delta":
current_text = current_content.get("delta").get("text")
result_pieces.append(current_text)
await websocket.send_text(current_text)
elif current_content.get("type") == "content_block_stop":
break
current_nlq_chain.set_generated_sql_response(''.join(result_pieces))
if question.query_result:
final_sql_query_result = service.get_executed_result(current_nlq_chain)
await websocket.send_text("\n\nQuery result: \n")
await websocket.send_text(final_sql_query_result)
await websocket.send_text("\n")
except WebSocketDisconnect:
logger.info(f"{websocket.client.host} disconnected.")
29 changes: 29 additions & 0 deletions application/api/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Any
from pydantic import BaseModel


class Question(BaseModel):
keywords: str
bedrock_model_id: str = "anthropic.claude-3-sonnet-20240229-v1:0"
use_rag: bool = True
query_result: bool = True
intent_ner_recognition: bool = False
profile_name: str = "shopping_guide"


class Example(BaseModel):
score: float
question: str
answer: str


class Answer(BaseModel):
examples: list[Example]
sql: str
sql_explain: str
sql_query_result: list[Any]


class Option(BaseModel):
data_profiles: list[str]
model_ids: list[str]
Loading

0 comments on commit 3c0cfa5

Please sign in to comment.