Skip to content

Commit

Permalink
feat: Guardrails for Amazon Bedrock (#520)
Browse files Browse the repository at this point in the history
* wip

* fontend merge

* wip

* guardrailの一部機能をデプロイ

* bugfix

* mypy, blackの適用

* cdkのテストコード修正

* python formatter

* fix frontend ci

* fix migration guide

* wip

* bug fix

* fix migration guide

* add migration guide arch img

* add unit tests

* add: ja helpers

* fix: simplify find_public_bot_by_id

* nits: simplify websocket.py

* nits: explanation why BedrockRegionResourcesStack needed

* fix: textAttachement -> attachment

* refactoring example for compose_args

* fix: bedrock client import error on agent

* integrate compose_args_for_converse_api

* remove unused code

* sync chat impl

* chore: mypy

* fix: regex for s3 uri

* fix: not work for multi turn conversaition

* chore: support topK on streaming

* fix: unittests

* lint: black

* doc: add AWS Backup

* doc: add v2 notification

---------

Co-authored-by: statefb <[email protected]>
  • Loading branch information
fsatsuki and statefb authored Oct 10, 2024
1 parent dd51a2e commit 101dc79
Show file tree
Hide file tree
Showing 52 changed files with 3,476 additions and 1,882 deletions.
4 changes: 2 additions & 2 deletions backend/app/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
ConverseApiToolResult,
ConverseApiToolUseContent,
calculate_price,
get_bedrock_client,
get_bedrock_runtime_client,
get_model_id,
)
from app.repositories.models.conversation import (
Expand Down Expand Up @@ -46,7 +46,7 @@ def __init__(
):
self.bot = bot
self.tools = {tool.name: tool for tool in tools}
self.client = get_bedrock_client()
self.client = get_bedrock_runtime_client()
self.model: type_model_name = model
self.model_id = get_model_id(model)
self.on_thinking = on_thinking
Expand Down
164 changes: 98 additions & 66 deletions backend/app/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
from app.config import BEDROCK_PRICING, DEFAULT_EMBEDDING_CONFIG
from app.config import DEFAULT_GENERATION_CONFIG as DEFAULT_CLAUDE_GENERATION_CONFIG
from app.config import DEFAULT_MISTRAL_GENERATION_CONFIG
from app.repositories.models.conversation import MessageModel
from app.repositories.models.conversation import ContentModel, MessageModel
from app.repositories.models.custom_bot import GenerationParamsModel
from app.repositories.models.custom_bot_guardrails import BedrockGuardrailsModel
from app.routes.schemas.conversation import type_model_name
from app.utils import convert_dict_keys_to_camel_case, get_bedrock_client
from app.utils import convert_dict_keys_to_camel_case, get_bedrock_runtime_client
from typing_extensions import NotRequired, TypedDict, no_type_check

logger = logging.getLogger(__name__)
Expand All @@ -24,7 +25,14 @@
else DEFAULT_CLAUDE_GENERATION_CONFIG
)

client = get_bedrock_client()
client = get_bedrock_runtime_client()


class GuardrailConfig(TypedDict):
guardrailIdentifier: str
guardrailVersion: str
trace: str
streamProcessingMode: NotRequired[str]


class ConverseApiToolSpec(TypedDict):
Expand Down Expand Up @@ -56,6 +64,7 @@ class ConverseApiRequest(TypedDict):
messages: list[dict]
stream: bool
system: list[dict]
guardrailConfig: NotRequired[GuardrailConfig]
tool_config: NotRequired[ConverseApiToolConfig]


Expand Down Expand Up @@ -136,56 +145,73 @@ def _convert_to_valid_file_name(file_name: str) -> str:
return file_name


@no_type_check
def compose_args_for_converse_api(
messages: list[MessageModel],
model: type_model_name,
instruction: str | None = None,
stream: bool = False,
generation_params: GenerationParamsModel | None = None,
grounding_source: dict | None = None,
guardrail: BedrockGuardrailsModel | None = None,
) -> ConverseApiRequest:
"""Compose arguments for Converse API.
Ref: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse_stream.html
"""
arg_messages = []
for message in messages:
if message.role not in ["system", "instruction"]:
content_blocks = []
for c in message.content:
if c.content_type == "text":
content_blocks.append({"text": c.body})
elif c.content_type == "image":
# e.g. "image/png" -> "png"
format = c.media_type.split("/")[1]
content_blocks.append(
{
"image": {
"format": format,
# decode base64 encoded image
"source": {"bytes": base64.b64decode(c.body)},
}
}
)
elif c.content_type == "attachment":
content_blocks.append(
{
"document": {
"format": _get_converse_supported_format(
Path(c.file_name).suffix[
1:
], # e.g. "document.txt" -> "txt"
),
"name": Path(
_convert_to_valid_file_name(c.file_name)
).stem, # e.g. "document.txt" -> "document"
# encode text attachment body
"source": {"bytes": base64.b64decode(c.body)},
}
def process_content(c: ContentModel, role: str):
if c.content_type == "text":
if role == "user" and guardrail and guardrail.grounding_threshold > 0:
return [
{"guardContent": grounding_source},
{
"guardContent": {
"text": {"text": c.body, "qualifiers": ["query"]}
}
)
else:
raise NotImplementedError()
arg_messages.append({"role": message.role, "content": content_blocks})
},
]
elif role == "assistant":
return [{"text": c.body if isinstance(c.body, str) else None}]
else:
return [{"text": c.body}]
elif c.content_type == "image":
format = c.media_type.split("/")[1] if c.media_type else "unknown"
return [
{
"image": {
"format": format,
"source": {"bytes": base64.b64decode(c.body)},
}
}
]
elif c.content_type == "attachment":
return [
{
"document": {
"format": _get_converse_supported_format(
Path(c.file_name).suffix[1:] # type: ignore
),
"name": Path(c.file_name).stem, # type: ignore
"source": {
"bytes": (
c.body.encode("utf-8")
if isinstance(c.body, str)
else c.body
)
}, # And this line
}
}
]
else:
raise NotImplementedError(f"Unsupported content type: {c.content_type}")

arg_messages = [
{
"role": message.role,
"content": [
block
for c in message.content
for block in process_content(c, message.role)
],
}
for message in messages
if message.role not in ["system", "instruction"]
]

inference_config = {
**DEFAULT_GENERATION_CONFIG,
Expand All @@ -201,40 +227,46 @@ def compose_args_for_converse_api(
),
}

# `top_k` is configured in `additional_model_request_fields` instead of `inference_config`
additional_model_request_fields = {"top_k": inference_config["top_k"]}
del inference_config["top_k"]
additional_model_request_fields = {"top_k": inference_config.pop("top_k")}

args: ConverseApiRequest = {
"inference_config": convert_dict_keys_to_camel_case(inference_config),
"additional_model_request_fields": additional_model_request_fields,
"model_id": get_model_id(model),
"messages": arg_messages,
"stream": stream,
"system": [],
"system": [{"text": instruction}] if instruction else [],
}
if instruction:
args["system"].append({"text": instruction})

if guardrail and guardrail.guardrail_arn and guardrail.guardrail_version:
args["guardrailConfig"] = {
"guardrailIdentifier": guardrail.guardrail_arn,
"guardrailVersion": guardrail.guardrail_version,
"trace": "enabled",
}

if stream:
# https://docs.aws.amazon.com/bedrock/latest/userguide/guardrails-streaming.html
args["guardrailConfig"]["streamProcessingMode"] = "async"

return args


def call_converse_api(args: ConverseApiRequest) -> ConverseApiResponse:
client = get_bedrock_client()
messages = args["messages"]
inference_config = args["inference_config"]
additional_model_request_fields = args["additional_model_request_fields"]
model_id = args["model_id"]
system = args["system"]

response = client.converse(
modelId=model_id,
messages=messages,
inferenceConfig=inference_config,
system=system,
additionalModelRequestFields=additional_model_request_fields,
)
client = get_bedrock_runtime_client()

base_args = {
"modelId": args["model_id"],
"messages": args["messages"],
"inferenceConfig": args["inference_config"],
"system": args["system"],
"additionalModelRequestFields": args["additional_model_request_fields"],
}

if "guardrailConfig" in args:
base_args["guardrailConfig"] = args["guardrailConfig"] # type: ignore

return response
return client.converse(**base_args)


def calculate_price(
Expand Down
3 changes: 2 additions & 1 deletion backend/app/bot_remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@

DB_SECRETS_ARN = os.environ.get("DB_SECRETS_ARN", "")
DOCUMENT_BUCKET = os.environ.get("DOCUMENT_BUCKET", "documents")
BEDROCK_REGION = os.environ.get("BEDROCK_REGION", "us-east-1")

s3_client = boto3.client("s3")
s3_client = boto3.client("s3", BEDROCK_REGION)


def delete_from_postgres(bot_id: str):
Expand Down
4 changes: 3 additions & 1 deletion backend/app/repositories/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
s3_client = boto3.client("s3")

THRESHOLD_LARGE_MESSAGE = 300 * 1024 # 300KB
LARGE_MESSAGE_BUCKET = os.environ.get("LARGE_MESSAGE_BUCKET")

BEDROCK_REGION = os.environ.get("BEDROCK_REGION", "us-east-1")
s3_client = boto3.client("s3", BEDROCK_REGION)


def store_conversation(
user_id: str, conversation: ConversationModel, threshold=THRESHOLD_LARGE_MESSAGE
Expand Down
59 changes: 59 additions & 0 deletions backend/app/repositories/custom_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
KnowledgeModel,
SearchParamsModel,
)
from app.repositories.models.custom_bot_guardrails import BedrockGuardrailsModel
from app.repositories.models.custom_bot_kb import BedrockKnowledgeBaseModel
from app.routes.schemas.bot import type_sync_status
from app.utils import get_current_time
Expand All @@ -50,6 +51,18 @@
sts_client = boto3.client("sts")


class BotNotFoundException(Exception):
"""Exception raised when a bot is not found."""

pass


class BotUpdateError(Exception):
"""Exception raised when there's an error updating a bot."""

pass


def store_bot(user_id: str, custom_bot: BotModel):
table = _get_table_client(user_id)
logger.info(f"Storing bot: {custom_bot}")
Expand Down Expand Up @@ -81,6 +94,8 @@ def store_bot(user_id: str, custom_bot: BotModel):
}
if custom_bot.bedrock_knowledge_base:
item["BedrockKnowledgeBase"] = custom_bot.bedrock_knowledge_base.model_dump()
if custom_bot.bedrock_guardrails:
item["GuardrailsParams"] = custom_bot.bedrock_guardrails.model_dump()

response = table.put_item(Item=item)
return response
Expand All @@ -102,6 +117,7 @@ def update_bot(
display_retrieved_chunks: bool,
conversation_quick_starters: list[ConversationQuickStarterModel],
bedrock_knowledge_base: BedrockKnowledgeBaseModel | None = None,
bedrock_guardrails: BedrockGuardrailsModel | None = None,
):
"""Update bot title, description, and instruction.
NOTE: Use `update_bot_visibility` to update visibility.
Expand Down Expand Up @@ -146,6 +162,12 @@ def update_bot(
bedrock_knowledge_base.model_dump()
)

if bedrock_guardrails:
update_expression += ", GuardrailsParams = :bedrock_guardrails"
expression_attribute_values[":bedrock_guardrails"] = (
bedrock_guardrails.model_dump()
)

try:
response = table.update_item(
Key={"PK": user_id, "SK": compose_bot_id(user_id, bot_id)},
Expand Down Expand Up @@ -291,6 +313,33 @@ def update_knowledge_base_id(
return response


def update_guardrails_params(
user_id: str, bot_id: str, guardrail_arn: str, guardrail_version: str
):
logger.info("update_guardrails_params")
table = _get_table_client(user_id)

try:
response = table.update_item(
Key={"PK": user_id, "SK": compose_bot_id(user_id, bot_id)},
UpdateExpression="SET GuardrailsParams.guardrail_arn = :guardrail_arn, GuardrailsParams.guardrail_version = :guardrail_version",
ExpressionAttributeValues={
":guardrail_arn": guardrail_arn,
":guardrail_version": guardrail_version,
},
ConditionExpression="attribute_exists(PK) AND attribute_exists(SK)",
ReturnValues="ALL_NEW",
)
logger.info(f"Updated guardrails_arn for bot: {bot_id} successfully")
except ClientError as e:
if e.response["Error"]["Code"] == "ConditionalCheckFailedException":
raise RecordNotFoundError(f"Bot with id {bot_id} not found")
else:
raise e

return response


def find_private_bots_by_user_id(
user_id: str, limit: int | None = None
) -> list[BotMeta]:
Expand Down Expand Up @@ -461,6 +510,11 @@ def find_private_bot_by_id(user_id: str, bot_id: str) -> BotModel:
if "BedrockKnowledgeBase" in item
else None
),
bedrock_guardrails=(
BedrockGuardrailsModel(**item["GuardrailsParams"])
if "GuardrailsParams" in item
else None
),
)

logger.info(f"Found bot: {bot}")
Expand Down Expand Up @@ -554,6 +608,11 @@ def find_public_bot_by_id(bot_id: str) -> BotModel:
if "BedrockKnowledgeBase" in item
else None
),
bedrock_guardrails=(
BedrockGuardrailsModel(**item["GuardrailsParams"])
if "GuardrailsParams" in item
else None
),
)
logger.info(f"Found public bot: {bot}")
return bot
Expand Down
2 changes: 2 additions & 0 deletions backend/app/repositories/models/custom_bot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from app.repositories.models.common import Float
from app.repositories.models.custom_bot_kb import BedrockKnowledgeBaseModel
from app.repositories.models.custom_bot_guardrails import BedrockGuardrailsModel
from app.routes.schemas.bot import type_sync_status
from pydantic import BaseModel

Expand Down Expand Up @@ -88,6 +89,7 @@ class BotModel(BaseModel):
display_retrieved_chunks: bool
conversation_quick_starters: list[ConversationQuickStarterModel]
bedrock_knowledge_base: BedrockKnowledgeBaseModel | None
bedrock_guardrails: BedrockGuardrailsModel | None

def has_knowledge(self) -> bool:
return (
Expand Down
Loading

0 comments on commit 101dc79

Please sign in to comment.