Add support for amazon-bedrock-guardrailConfig #59

Open · wants to merge 1 commit into base: main
87 changes: 52 additions & 35 deletions libs/aws/langchain_aws/llms/bedrock.py
@@ -33,6 +33,7 @@

AMAZON_BEDROCK_TRACE_KEY = "amazon-bedrock-trace"
GUARDRAILS_BODY_KEY = "amazon-bedrock-guardrailAssessment"
GUARDRAILS_CONFIG_KEY = "amazon-bedrock-guardrailConfig"
HUMAN_PROMPT = "\n\nHuman:"
ASSISTANT_PROMPT = "\n\nAssistant:"
ALTERNATION_ERROR = (
@@ -427,6 +428,7 @@ class BedrockBase(BaseLanguageModel, ABC):
"trace": None,
"guardrailIdentifier": None,
"guardrailVersion": None,
"guardrailConfig": None,
}
"""
An optional dictionary to configure guardrails for Bedrock.
@@ -439,21 +441,21 @@ class BedrockBase(BaseLanguageModel, ABC):
Optional[Mapping[str, str]]: A mapping with 'id' and 'version' keys.

Example:
llm = Bedrock(model_id="<model_id>", client=<bedrock_client>,
llm = BedrockLLM(model_id="<model_id>", client=<bedrock_client>,
model_kwargs={},
guardrails={
"id": "<guardrail_id>",
"version": "<guardrail_version>"})
"guardrailIdentifier": "<guardrail_id>",
"guardrailVersion": "<guardrail_version>"})

To enable tracing for guardrails, set the 'trace' key to True and pass a callback handler to the
'run_manager' parameter of the 'generate', '_call' methods.

Example:
llm = Bedrock(model_id="<model_id>", client=<bedrock_client>,
llm = BedrockLLM(model_id="<model_id>", client=<bedrock_client>,
model_kwargs={},
guardrails={
"id": "<guardrail_id>",
"version": "<guardrail_version>",
"guardrailIdentifier": "<guardrail_id>",
"guardrailVersion": "<guardrail_version>",
"trace": True},
callbacks=[BedrockAsyncCallbackHandler()])

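The new 'guardrailConfig' key added by this change is forwarded verbatim into the request body under "amazon-bedrock-guardrailConfig". A minimal sketch of how a caller might supply it, assuming Bedrock's streaming option 'streamProcessingMode' (that option and its value are assumptions, not part of this diff):

# Sketch only: "streamProcessingMode"/"ASYNCHRONOUS" are assumed Bedrock
# options, not defined in this diff; the "guardrailConfig" key is what is new.
llm = BedrockLLM(model_id="<model_id>", client=<bedrock_client>,
                 model_kwargs={},
                 guardrails={
                     "guardrailIdentifier": "<guardrail_id>",
                     "guardrailVersion": "<guardrail_version>",
                     # Embedded into the JSON body as
                     # "amazon-bedrock-guardrailConfig".
                     "guardrailConfig": {"streamProcessingMode": "ASYNCHRONOUS"}})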
@@ -523,13 +525,15 @@ def validate_environment(cls, values: Dict) -> Dict:
@property
def _identifying_params(self) -> Dict[str, Any]:
_model_kwargs = self.model_kwargs or {}
_guardrails = self.guardrails or {}
return {
"model_id": self.model_id,
"provider": self._get_provider(),
"stream": self.streaming,
"trace": self.guardrails.get("trace"), # type: ignore[union-attr]
"guardrailIdentifier": self.guardrails.get("guardrailIdentifier", None), # type: ignore[union-attr]
"guardrailVersion": self.guardrails.get("guardrailVersion", None), # type: ignore[union-attr]
"trace": _guardrails.get("trace"), # type: ignore[union-attr]
"guardrail_identifier": _guardrails.get("guardrailIdentifier", None), # type: ignore[union-attr]
"guardrail_version": _guardrails.get("guardrailVersion", None), # type: ignore[union-attr]
"guardrail_config": _guardrails.get("guardrailConfig", None), # type: ignore[union-attr]
**_model_kwargs,
}

@@ -538,8 +542,7 @@ def _get_provider(self) -> str:
return self.provider
if self.model_id.startswith("arn"):
raise ValueError(
"Model provider should be supplied when passing a model ARN as "
"model_id"
"Model provider should be supplied when passing a model ARN as model_id"
)

return self.model_id.split(".")[0]
@@ -572,8 +575,8 @@ def _guardrails_enabled(self) -> bool:

except KeyError as e:
raise TypeError(
"Guardrails must be a dictionary with 'guardrailIdentifier' \
and 'guardrailVersion' keys."
"Guardrails must be a dictionary with 'guardrailIdentifier' "
"and 'guardrailVersion' mandatory keys."
) from e

def _prepare_input_and_invoke(
@@ -597,26 +600,32 @@ def _prepare_input_and_invoke(
system=system,
messages=messages,
)

guardrails = {}
if self._guardrails_enabled:
guardrails["guardrailIdentifier"] = self.guardrails.get( # type: ignore[union-attr]
"guardrailIdentifier", ""
)
guardrails["guardrailVersion"] = self.guardrails.get("guardrailVersion", "") # type: ignore[union-attr]
if self.guardrails.get("trace") is not None: # type: ignore[union-attr]
guardrails["trace"] = "ENABLED"

if self.guardrails.get("guardrailConfig") is not None: # type: ignore[union-attr]
input_body[GUARDRAILS_CONFIG_KEY] = self.guardrails.get( # type: ignore[union-attr]
"guardrailConfig", ""
)

body = json.dumps(input_body)
accept = "application/json"
contentType = "application/json"

request_options = {
"body": body,
"modelId": self.model_id,
"accept": accept,
"contentType": contentType,
"accept": "application/json",
"contentType": "application/json",
}

if self._guardrails_enabled:
request_options["guardrailIdentifier"] = self.guardrails.get( # type: ignore[union-attr]
"guardrailIdentifier", ""
)
request_options["guardrailVersion"] = self.guardrails.get( # type: ignore[union-attr]
"guardrailVersion", ""
)
if self.guardrails.get("trace"): # type: ignore[union-attr]
request_options["trace"] = "ENABLED"
if guardrails:
request_options.update(guardrails)

try:
response = self.client.invoke_model(**request_options)
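With the guardrail parameters now collected once into the local guardrails dict and merged via request_options.update(guardrails), the final call is equivalent to the sketch below (the values are placeholders, not taken from this diff):

# Illustrative shape of the merged invoke_model call; values are placeholders.
response = self.client.invoke_model(
    body=json.dumps(input_body),
    modelId="<model_id>",
    accept="application/json",
    contentType="application/json",
    guardrailIdentifier="<guardrail_id>",
    guardrailVersion="<guardrail_version>",
    trace="ENABLED",  # present only when guardrails["trace"] is set
)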
@@ -711,6 +720,21 @@ def _prepare_input_and_invoke_stream(
messages=messages,
model_kwargs=params,
)

guardrails = {}
if self._guardrails_enabled:
guardrails["guardrailIdentifier"] = self.guardrails.get( # type: ignore[union-attr]
"guardrailIdentifier", ""
)
guardrails["guardrailVersion"] = self.guardrails.get("guardrailVersion", "") # type: ignore[union-attr]
if self.guardrails.get("trace") is not None: # type: ignore[union-attr]
guardrails["trace"] = "ENABLED"

if self.guardrails.get("guardrailConfig") is not None: # type: ignore[union-attr]
input_body[GUARDRAILS_CONFIG_KEY] = self.guardrails.get( # type: ignore[union-attr]
"guardrailConfig", ""
)

body = json.dumps(input_body)

request_options = {
Expand All @@ -720,15 +744,8 @@ def _prepare_input_and_invoke_stream(
"contentType": "application/json",
}

if self._guardrails_enabled:
request_options["guardrailIdentifier"] = self.guardrails.get( # type: ignore[union-attr]
"guardrailIdentifier", ""
)
request_options["guardrailVersion"] = self.guardrails.get( # type: ignore[union-attr]
"guardrailVersion", ""
)
if self.guardrails.get("trace"): # type: ignore[union-attr]
request_options["trace"] = "ENABLED"
if guardrails:
request_options.update(guardrails)

try:
response = self.client.invoke_model_with_response_stream(**request_options)
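The streaming path mirrors the non-streaming one: 'guardrailIdentifier', 'guardrailVersion', and 'trace' travel as top-level request parameters, while 'guardrailConfig' rides inside the JSON body. A sketch of that body once 'guardrailConfig' is set (the prompt field is provider-specific and illustrative, as is the config value):

# Body sketch for the streaming path; the prompt field and the
# guardrailConfig contents are illustrative placeholders.
input_body = {
    "prompt": "\n\nHuman: Hello\n\nAssistant:",
    "amazon-bedrock-guardrailConfig": {"streamProcessingMode": "ASYNCHRONOUS"},
}
body = json.dumps(input_body)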
2 changes: 1 addition & 1 deletion libs/aws/tests/callbacks.py
@@ -22,7 +22,7 @@ class BaseFakeCallbackHandler(BaseModel):
ignore_retriever_: bool = False
ignore_chat_model_: bool = False

# to allow for similar callback handlers that are not technicall equal
# to allow for similar callback handlers that are not technically equal
fake_id: Union[str, None] = None

# add finer-grained counters for easier debugging of failing tests
96 changes: 96 additions & 0 deletions libs/aws/tests/conftest.py
@@ -0,0 +1,96 @@
import json
from typing import Dict
from unittest.mock import MagicMock

import pytest


@pytest.fixture
def mistral_response() -> Dict:
body = MagicMock()
body.read.return_value = json.dumps(
{"outputs": [{"text": "This is the Mistral output text."}]}
).encode()
response = dict(
body=body,
ResponseMetadata={
"HTTPHeaders": {
"x-amzn-bedrock-input-token-count": "18",
"x-amzn-bedrock-output-token-count": "28",
}
},
)

return response


@pytest.fixture
def cohere_response() -> Dict:
body = MagicMock()
body.read.return_value = json.dumps(
{"generations": [{"text": "This is the Cohere output text."}]}
).encode()
response = dict(
body=body,
ResponseMetadata={
"HTTPHeaders": {
"x-amzn-bedrock-input-token-count": "12",
"x-amzn-bedrock-output-token-count": "22",
}
},
)
return response


@pytest.fixture
def anthropic_response() -> Dict:
body = MagicMock()
body.read.return_value = json.dumps(
{"completion": "This is the output text."}
).encode()
response = dict(
body=body,
ResponseMetadata={
"HTTPHeaders": {
"x-amzn-bedrock-input-token-count": "10",
"x-amzn-bedrock-output-token-count": "20",
}
},
)
return response


@pytest.fixture
def ai21_response() -> Dict:
body = MagicMock()
body.read.return_value = json.dumps(
{"completions": [{"data": {"text": "This is the AI21 output text."}}]}
).encode()
response = dict(
body=body,
ResponseMetadata={
"HTTPHeaders": {
"x-amzn-bedrock-input-token-count": "15",
"x-amzn-bedrock-output-token-count": "25",
}
},
)
return response


@pytest.fixture
def response_with_stop_reason() -> Dict:
body = MagicMock()
body.read.return_value = json.dumps(
{"completion": "This is the output text.", "stop_reason": "length"}
).encode()
response = dict(
body=body,
ResponseMetadata={
"HTTPHeaders": {
"x-amzn-bedrock-input-token-count": "10",
"x-amzn-bedrock-output-token-count": "20",
}
},
)
return response
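Each fixture mimics a raw invoke_model response: a mock whose body.read() returns provider-shaped JSON, plus token-count headers. A hedged sketch of a test that could consume one, assuming a MagicMock client and a Mistral model id (neither appears in this diff):

# Sketch only: the model id, import path, and expected text are assumptions
# based on the fixture contents, not on tests included in this PR.
from unittest.mock import MagicMock

from langchain_aws.llms.bedrock import BedrockLLM


def test_mistral_response_parsing(mistral_response) -> None:
    client = MagicMock()
    client.invoke_model.return_value = mistral_response
    llm = BedrockLLM(model_id="mistral.mistral-7b-instruct-v0:2", client=client)
    assert llm.invoke("Hello") == "This is the Mistral output text."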