diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py
index c20e3365..5d2e8b57 100644
--- a/libs/aws/langchain_aws/llms/bedrock.py
+++ b/libs/aws/langchain_aws/llms/bedrock.py
@@ -33,6 +33,7 @@
 AMAZON_BEDROCK_TRACE_KEY = "amazon-bedrock-trace"
 GUARDRAILS_BODY_KEY = "amazon-bedrock-guardrailAssessment"
+GUARDRAILS_CONFIG_KEY = "amazon-bedrock-guardrailConfig"
 HUMAN_PROMPT = "\n\nHuman:"
 ASSISTANT_PROMPT = "\n\nAssistant:"
 ALTERNATION_ERROR = (
@@ -427,6 +428,7 @@ class BedrockBase(BaseLanguageModel, ABC):
         "trace": None,
         "guardrailIdentifier": None,
         "guardrailVersion": None,
+        "guardrailConfig": None,
     }
     """
     An optional dictionary to configure guardrails for Bedrock.
@@ -439,21 +441,21 @@ class BedrockBase(BaseLanguageModel, ABC):
         Optional[Mapping[str, str]]: A mapping with 'id' and 'version' keys.
 
     Example:
-    llm = Bedrock(model_id="<model_id>", client=<bedrock_client>,
+    llm = BedrockLLM(model_id="<model_id>", client=<bedrock_client>,
                   model_kwargs={},
                   guardrails={
-                        "id": "<guardrail_id>",
-                        "version": "<guardrail_version>"})
+                        "guardrailIdentifier": "<guardrail_id>",
+                        "guardrailVersion": "<guardrail_version>"})
 
     To enable tracing for guardrails, set the 'trace' key to True and pass a callback
     handler to the 'run_manager' parameter of the 'generate', '_call' methods.
 
     Example:
-    llm = Bedrock(model_id="<model_id>", client=<bedrock_client>,
+    llm = BedrockLLM(model_id="<model_id>", client=<bedrock_client>,
                   model_kwargs={},
                   guardrails={
-                        "id": "<guardrail_id>",
-                        "version": "<guardrail_version>",
+                        "guardrailIdentifier": "<guardrail_id>",
+                        "guardrailVersion": "<guardrail_version>",
                         "trace": True},
                 callbacks=[BedrockAsyncCallbackHandler()])
@@ -523,13 +525,15 @@ def validate_environment(cls, values: Dict) -> Dict:
 
     @property
     def _identifying_params(self) -> Dict[str, Any]:
         _model_kwargs = self.model_kwargs or {}
+        _guardrails = self.guardrails or {}
         return {
             "model_id": self.model_id,
             "provider": self._get_provider(),
             "stream": self.streaming,
-            "trace": self.guardrails.get("trace"),  # type: ignore[union-attr]
-            "guardrailIdentifier": self.guardrails.get("guardrailIdentifier", None),  # type: ignore[union-attr]
-            "guardrailVersion": self.guardrails.get("guardrailVersion", None),  # type: ignore[union-attr]
+            "trace": _guardrails.get("trace"),  # type: ignore[union-attr]
+            "guardrail_identifier": _guardrails.get("guardrailIdentifier", None),  # type: ignore[union-attr]
+            "guardrail_version": _guardrails.get("guardrailVersion", None),  # type: ignore[union-attr]
+            "guardrail_config": _guardrails.get("guardrailConfig", None),  # type: ignore[union-attr]
             **_model_kwargs,
         }
@@ -538,8 +542,7 @@ def _get_provider(self) -> str:
             return self.provider
         if self.model_id.startswith("arn"):
             raise ValueError(
-                "Model provider should be supplied when passing a model ARN as "
-                "model_id"
+                "Model provider should be supplied when passing a model ARN as model_id"
             )
 
         return self.model_id.split(".")[0]
@@ -572,8 +575,8 @@ def _guardrails_enabled(self) -> bool:
         except KeyError as e:
             raise TypeError(
-                "Guardrails must be a dictionary with 'guardrailIdentifier' \
-                and 'guardrailVersion' keys."
+                "Guardrails must be a dictionary with 'guardrailIdentifier' "
+                "and 'guardrailVersion' mandatory keys."
             ) from e
 
     def _prepare_input_and_invoke(
@@ -597,26 +600,32 @@ def _prepare_input_and_invoke(
             system=system,
             messages=messages,
         )
+
+        guardrails = {}
+        if self._guardrails_enabled:
+            guardrails["guardrailIdentifier"] = self.guardrails.get(  # type: ignore[union-attr]
+                "guardrailIdentifier", ""
+            )
+            guardrails["guardrailVersion"] = self.guardrails.get("guardrailVersion", "")  # type: ignore[union-attr]
+            if self.guardrails.get("trace") is not None:  # type: ignore[union-attr]
+                guardrails["trace"] = "ENABLED"
+
+            if self.guardrails.get("guardrailConfig") is not None:  # type: ignore[union-attr]
+                input_body[GUARDRAILS_CONFIG_KEY] = self.guardrails.get(  # type: ignore[union-attr]
+                    "guardrailConfig", ""
+                )
+
         body = json.dumps(input_body)
-        accept = "application/json"
-        contentType = "application/json"
 
         request_options = {
             "body": body,
             "modelId": self.model_id,
-            "accept": accept,
-            "contentType": contentType,
+            "accept": "application/json",
+            "contentType": "application/json",
         }
 
-        if self._guardrails_enabled:
-            request_options["guardrailIdentifier"] = self.guardrails.get(  # type: ignore[union-attr]
-                "guardrailIdentifier", ""
-            )
-            request_options["guardrailVersion"] = self.guardrails.get(  # type: ignore[union-attr]
-                "guardrailVersion", ""
-            )
-            if self.guardrails.get("trace"):  # type: ignore[union-attr]
-                request_options["trace"] = "ENABLED"
+        if guardrails:
+            request_options.update(guardrails)
 
         try:
             response = self.client.invoke_model(**request_options)
@@ -711,6 +720,21 @@ def _prepare_input_and_invoke_stream(
             messages=messages,
             model_kwargs=params,
         )
+
+        guardrails = {}
+        if self._guardrails_enabled:
+            guardrails["guardrailIdentifier"] = self.guardrails.get(  # type: ignore[union-attr]
+                "guardrailIdentifier", ""
+            )
+            guardrails["guardrailVersion"] = self.guardrails.get("guardrailVersion", "")  # type: ignore[union-attr]
+            if self.guardrails.get("trace") is not None:  # type: ignore[union-attr]
+                guardrails["trace"] = "ENABLED"
+
+            if self.guardrails.get("guardrailConfig") is not None:  # type: ignore[union-attr]
+                input_body[GUARDRAILS_CONFIG_KEY] = self.guardrails.get(  # type: ignore[union-attr]
+                    "guardrailConfig", ""
+                )
+
         body = json.dumps(input_body)
 
         request_options = {
@@ -720,15 +744,8 @@
             "contentType": "application/json",
         }
 
-        if self._guardrails_enabled:
-            request_options["guardrailIdentifier"] = self.guardrails.get(  # type: ignore[union-attr]
-                "guardrailIdentifier", ""
-            )
-            request_options["guardrailVersion"] = self.guardrails.get(  # type: ignore[union-attr]
-                "guardrailVersion", ""
-            )
-            if self.guardrails.get("trace"):  # type: ignore[union-attr]
-                request_options["trace"] = "ENABLED"
+        if guardrails:
+            request_options.update(guardrails)
 
         try:
             response = self.client.invoke_model_with_response_stream(**request_options)
diff --git a/libs/aws/tests/callbacks.py b/libs/aws/tests/callbacks.py
index 3a3902a0..9e573c76 100644
--- a/libs/aws/tests/callbacks.py
+++ b/libs/aws/tests/callbacks.py
@@ -22,7 +22,7 @@ class BaseFakeCallbackHandler(BaseModel):
     ignore_retriever_: bool = False
     ignore_chat_model_: bool = False
 
-    # to allow for similar callback handlers that are not technicall equal
+    # to allow for similar callback handlers that are not technically equal
     fake_id: Union[str, None] = None
 
     # add finer-grained counters for easier debugging of failing tests
diff --git a/libs/aws/tests/conftest.py b/libs/aws/tests/conftest.py
new file mode 100644
index 00000000..6cc36697
--- /dev/null
+++ b/libs/aws/tests/conftest.py
@@ -0,0 +1,96 @@
+import json
+from typing import Dict
+from unittest.mock import MagicMock
+
+import pytest
+
+
+@pytest.fixture
+def mistral_response() -> Dict:
+    body = MagicMock()
+    body.read.return_value = json.dumps(
+        {"outputs": [{"text": "This is the Mistral output text."}]}
+    ).encode()
+    response = dict(
+        body=body,
+        ResponseMetadata={
+            "HTTPHeaders": {
+                "x-amzn-bedrock-input-token-count": "18",
+                "x-amzn-bedrock-output-token-count": "28",
+            }
+        },
+    )
+
+    return response
+
+
+@pytest.fixture
+def cohere_response() -> Dict:
+    body = MagicMock()
+    body.read.return_value = json.dumps(
+        {"generations": [{"text": "This is the Cohere output text."}]}
+    ).encode()
+    response = dict(
+        body=body,
+        ResponseMetadata={
+            "HTTPHeaders": {
+                "x-amzn-bedrock-input-token-count": "12",
+                "x-amzn-bedrock-output-token-count": "22",
+            }
+        },
+    )
+    return response
+
+
+@pytest.fixture
+def anthropic_response() -> Dict:
+    body = MagicMock()
+    body.read.return_value = json.dumps(
+        {"completion": "This is the output text."}
+    ).encode()
+    response = dict(
+        body=body,
+        ResponseMetadata={
+            "HTTPHeaders": {
+                "x-amzn-bedrock-input-token-count": "10",
+                "x-amzn-bedrock-output-token-count": "20",
+            }
+        },
+    )
+    return response
+
+
+@pytest.fixture
+def ai21_response() -> Dict:
+    body = MagicMock()
+    body.read.return_value = json.dumps(
+        {"completions": [{"data": {"text": "This is the AI21 output text."}}]}
+    ).encode()
+    response = dict(
+        body=body,
+        ResponseMetadata={
+            "HTTPHeaders": {
+                "x-amzn-bedrock-input-token-count": "15",
+                "x-amzn-bedrock-output-token-count": "25",
+            }
+        },
+    )
+    return response
+
+
+@pytest.fixture
+def response_with_stop_reason() -> Dict:
+    body = MagicMock()
+    body.read.return_value = json.dumps(
+        {"completion": "This is the output text.", "stop_reason": "length"}
+    ).encode()
+    response = dict(
+        body=body,
+        ResponseMetadata={
+            "HTTPHeaders": {
+                "x-amzn-bedrock-input-token-count": "10",
+                "x-amzn-bedrock-output-token-count": "20",
+            }
+        },
+    )
+    return response
diff --git a/libs/aws/tests/unit_tests/llms/test_bedrock.py b/libs/aws/tests/unit_tests/llms/test_bedrock.py
index 7693cb19..f6d6f26c 100644
--- a/libs/aws/tests/unit_tests/llms/test_bedrock.py
+++ b/libs/aws/tests/unit_tests/llms/test_bedrock.py
@@ -13,269 +13,126 @@
     _human_assistant_format,
 )
 
-TEST_CASES = {
-    """Hey""": """
-
-Human: Hey
-
-Assistant:""",
-    """
-
-Human: Hello
-
-Assistant:""": """
-
-Human: Hello
-
-Assistant:""",
-    """Human: Hello
-
-Assistant:""": """
-
-Human: Hello
-
-Assistant:""",
-    """
-Human: Hello
-
-Assistant:""": """
-
-Human: Hello
-
-Assistant:""",
-    """
-
-Human: Human: Hello
-
-Assistant:""": (
-        "Error: Prompt must alternate between '\n\nHuman:' and '\n\nAssistant:'."
-    ),
-    """Human: Hello
-
-Assistant: Hello
-
-Human: Hello
-
-Assistant:""": """
-
-Human: Hello
-
-Assistant: Hello
-
-Human: Hello
-
-Assistant:""",
-    """
-
-Human: Hello
-
-Assistant: Hello
-
-Human: Hello
-
-Assistant:""": """
-
-Human: Hello
-
-Assistant: Hello
-
-Human: Hello
-
-Assistant:""",
-    """
-
-Human: Hello
-
-Assistant: Hello
-
-Human: Hello
-
-Assistant: Hello
-
-Assistant: Hello""": ALTERNATION_ERROR,
-    """
-
-Human: Hi.
-
-Assistant: Hi.
-
-Human: Hi.
-
-Human: Hi.
-
-Assistant:""": ALTERNATION_ERROR,
-    """
-Human: Hello""": """
-
-Human: Hello
-
-Assistant:""",
-    """
-
-Human: Hello
-Hello
-
-Assistant""": """
-
-Human: Hello
-Hello
-
-Assistant
-
-Assistant:""",
-    """Hello
-
-Assistant:""": """
-
-Human: Hello
-
-Assistant:""",
-    """Hello
-
-Human: Hello
-
-""": """Hello
-
-Human: Hello
-
-
-
-Assistant:""",
-    """
-
-Human: Assistant: Hello""": """
-
-Human: 
-
-Assistant: Hello""",
-    """
-
-Human: Human
-
-Assistant: Assistant
-
-Human: Assistant
-
-Assistant: Human""": """
-
-Human: Human
-
-Assistant: Assistant
-
-Human: Assistant
-
-Assistant: Human""",
-    """
-Assistant: Hello there, your name is:
-
-Human.
-
-Human: Hello there, your name is:
-
-Assistant.""": """
-
-Human: 
-
-Assistant: Hello there, your name is:
-
-Human.
-
-Human: Hello there, your name is:
-
-Assistant.
-
-Assistant:""",
-    """
-
-Human: Human: Hi
-
-Assistant: Hi""": ALTERNATION_ERROR,
-    """Human: Hi
-
-Human: Hi""": ALTERNATION_ERROR,
-    """
-
-Assistant: Hi
-
-Human: Hi""": """
-
-Human: 
-
-Assistant: Hi
-
-Human: Hi
-
-Assistant:""",
-    """
-
-Human: Hi
-
-Assistant: Yo
-
-Human: Hey
-
-Assistant: Sup
-
-Human: Hi
-
-Assistant: Hi
-Human: Hi
-Assistant:""": """
-
-Human: Hi
-
-Assistant: Yo
-
-Human: Hey
-
-Assistant: Sup
-
-Human: Hi
-
-Assistant: Hi
-
-Human: Hi
-
-Assistant:""",
-    """
-
-Hello.
-
-Human: Hello.
-
-Assistant:""": """
-
-Hello.
-
-Human: Hello.
-
-Assistant:""",
-}
-
-
-def test__human_assistant_format() -> None:
-    for input_text, expected_output in TEST_CASES.items():
-        if expected_output == ALTERNATION_ERROR:
-            with pytest.warns(UserWarning, match=ALTERNATION_ERROR):
-                _human_assistant_format(input_text)
-        else:
-            output = _human_assistant_format(input_text)
-            assert output == expected_output
-
-
-# Sample mock streaming response data
-MOCK_STREAMING_RESPONSE = [
-    {"chunk": {"bytes": b'{"text": "nice"}'}},
-    {"chunk": {"bytes": b'{"text": " to meet"}'}},
-    {"chunk": {"bytes": b'{"text": " you"}'}},
-]
+mock_boto3 = MagicMock()
+# Mocking the client method of the Session object
+mock_boto3.Session.return_value.client.return_value = MagicMock()
 
 
 async def async_gen_mock_streaming_response() -> AsyncGenerator[Dict, None]:
+    # Sample mock streaming response data
+
+    MOCK_STREAMING_RESPONSE = [
+        {"chunk": {"bytes": b'{"text": "nice"}'}},
+        {"chunk": {"bytes": b'{"text": " to meet"}'}},
+        {"chunk": {"bytes": b'{"text": " you"}'}},
+    ]
     for item in MOCK_STREAMING_RESPONSE:
         yield item
 
 
+@pytest.mark.parametrize(
+    "input_text, expected_output",
+    [
+        (
+            """Hey""",
+            """\n\nHuman: Hey\n\nAssistant:""",
+        ),
+        (
+            """\n\nHuman: Hello\n\nAssistant:""",
+            """\n\nHuman: Hello\n\nAssistant:""",
+        ),
+        (
+            """Human: Hello\n\nAssistant:""",
+            """\n\nHuman: Hello\n\nAssistant:""",
+        ),
+        (
+            """\nHuman: Hello\n\nAssistant:""",
+            """\n\nHuman: Hello\n\nAssistant:""",
+        ),
+        (
+            """\n\nHuman: Human: Hello\n\nAssistant:""",
+            "Error: Prompt must alternate between '\n\nHuman:' and '\n\nAssistant:'.",
+        ),
+        (
+            """Human: Hello\n\nAssistant: Hello\n\nHuman: Hello\n\nAssistant:""",
+            """\n\nHuman: Hello\n\nAssistant: Hello\n\nHuman: Hello\n\nAssistant:""",
+        ),
+        (
+            """\n\nHuman: Hello\n\nAssistant: Hello\n\nHuman: Hello\n\nAssistant:""",
+            """\n\nHuman: Hello\n\nAssistant: Hello\n\nHuman: Hello\n\nAssistant:""",
+        ),
+        (
+            """\n\nHuman: Hello\n\nAssistant: Hello\n\nHuman: """
+            """Hello\n\nAssistant: Hello\n\nAssistant: Hello""",
+            ALTERNATION_ERROR,
+        ),
+        (
+            """\n\nHuman: Hi.\n\nAssistant: Hi.\n\nHuman: Hi.\n\nHuman: Hi."""
+            """\n\nAssistant:""",
+            ALTERNATION_ERROR,
+        ),
+        (
+            """\nHuman: Hello""",
+            """\n\nHuman: Hello\n\nAssistant:""",
+        ),
+        (
+            """\n\nHuman: Hello\nHello\n\nAssistant""",
+            """\n\nHuman: Hello\nHello\n\nAssistant\n\nAssistant:""",
+        ),
+        (
+            """Hello\n\nAssistant:""",
+            """\n\nHuman: Hello\n\nAssistant:""",
+        ),
+        (
+            """Hello\n\nHuman: Hello\n\n""",
+            """Hello\n\nHuman: Hello\n\n\n\nAssistant:""",
+        ),
+        (
+            """\n\nHuman: Assistant: Hello""",
+            """\n\nHuman: \n\nAssistant: Hello""",
+        ),
+        (
+            """\n\nHuman: Human\n\nAssistant: Assistant\n\nHuman: Assistant\n\n"""
+            """Assistant: Human""",
+            """\n\nHuman: Human\n\nAssistant: Assistant\n\nHuman: Assistant\n\n"""
+            """Assistant: Human""",
+        ),
+        (
+            """\n\nAssistant: Hello there, your name is:\n\nHuman.\n\nHuman: """
+            """Hello there, your name is: Assistant.""",
+            """\n\nHuman: \n\nAssistant: Hello there, your name is:\n\nHuman."""
+            """\n\nHuman: Hello there, your name is: Assistant.\n\nAssistant:""",
+        ),
+        ("""\n\nHuman: Human: Hi\n\nAssistant: Hi""", ALTERNATION_ERROR),
+        (
+            """Human: Hi\n\nHuman: Hi""",
+            ALTERNATION_ERROR,
+        ),
+        (
+            """\n\nAssistant: Hi\n\nHuman: Hi""",
+            """\n\nHuman: \n\nAssistant: Hi\n\nHuman: Hi\n\nAssistant:""",
+        ),
+        (
+            """\n\nHuman: Hi\n\nAssistant: Yo\n\nHuman: Hey\n\nAssistant: Sup"""
+            """\n\nHuman: Hi\n\nAssistant: Hi\n\nHuman: Hi\n\nAssistant:""",
+            """\n\nHuman: Hi\n\nAssistant: Yo\n\nHuman: Hey\n\nAssistant: Sup"""
+            """\n\nHuman: Hi\n\nAssistant: Hi\n\nHuman: Hi\n\nAssistant:""",
+        ),
+        (
+            """\n\nHello.\n\nHuman: Hello.\n\nAssistant:""",
+            """\n\nHello.\n\nHuman: Hello.\n\nAssistant:""",
+        ),
+    ],
+)
+def test_human_assistant_format(input_text, expected_output) -> None:
+    if expected_output == ALTERNATION_ERROR:
+        with pytest.warns(UserWarning, match=ALTERNATION_ERROR):
+            _human_assistant_format(input_text)
+    else:
+        output = _human_assistant_format(input_text)
+        assert output == expected_output
+
+
 @pytest.mark.asyncio
 async def test_bedrock_async_streaming_call() -> None:
     # Mock boto3 import
@@ -311,97 +168,6 @@ async def test_bedrock_async_streaming_call() -> None:
     assert chunks[2] == " you"
 
 
-@pytest.fixture
-def mistral_response():
-    body = MagicMock()
-    body.read.return_value = json.dumps(
-        {"outputs": [{"text": "This is the Mistral output text."}]}
-    ).encode()
-    response = dict(
-        body=body,
-        ResponseMetadata={
-            "HTTPHeaders": {
-                "x-amzn-bedrock-input-token-count": "18",
-                "x-amzn-bedrock-output-token-count": "28",
-            }
-        },
-    )
-
-    return response
-
-
-@pytest.fixture
-def cohere_response():
-    body = MagicMock()
-    body.read.return_value = json.dumps(
-        {"generations": [{"text": "This is the Cohere output text."}]}
-    ).encode()
-    response = dict(
-        body=body,
-        ResponseMetadata={
-            "HTTPHeaders": {
-                "x-amzn-bedrock-input-token-count": "12",
-                "x-amzn-bedrock-output-token-count": "22",
-            }
-        },
-    )
-    return response
-
-
-@pytest.fixture
-def anthropic_response():
-    body = MagicMock()
-    body.read.return_value = json.dumps(
-        {"completion": "This is the output text."}
-    ).encode()
-    response = dict(
-        body=body,
-        ResponseMetadata={
-            "HTTPHeaders": {
-                "x-amzn-bedrock-input-token-count": "10",
-                "x-amzn-bedrock-output-token-count": "20",
-            }
-        },
-    )
-    return response
-
-
-@pytest.fixture
-def ai21_response():
-    body = MagicMock()
-    body.read.return_value = json.dumps(
-        {"completions": [{"data": {"text": "This is the AI21 output text."}}]}
-    ).encode()
-    response = dict(
-        body=body,
-        ResponseMetadata={
-            "HTTPHeaders": {
-                "x-amzn-bedrock-input-token-count": "15",
-                "x-amzn-bedrock-output-token-count": "25",
-            }
-        },
-    )
-    return response
-
-
-@pytest.fixture
-def response_with_stop_reason():
-    body = MagicMock()
-    body.read.return_value = json.dumps(
-        {"completion": "This is the output text.", "stop_reason": "length"}
-    ).encode()
-    response = dict(
-        body=body,
-        ResponseMetadata={
-            "HTTPHeaders": {
-                "x-amzn-bedrock-input-token-count": "10",
-                "x-amzn-bedrock-output-token-count": "20",
-            }
-        },
-    )
-    return response
-
-
 def test_prepare_output_for_mistral(mistral_response):
     result = LLMInputOutputAdapter.prepare_output("mistral", mistral_response)
     assert result["text"] == "This is the Mistral output text."
@@ -447,3 +213,56 @@ def test_prepare_output_for_ai21(ai21_response):
     assert result["usage"]["completion_tokens"] == 25
     assert result["usage"]["total_tokens"] == 40
     assert result["stop_reason"] is None
+
+
+@pytest.mark.parametrize(
+    "error_state, guardrail_input",
+    [
+        (True, {}),
+        (True, {"guardrailIdentifier": "some-id"}),
+        (True, {"guardrailVersion": "some-version"}),
+        (True, {"guardrailConfig": {"config": "value"}}),
+        (
+            False,
+            {
+                "guardrailIdentifier": "some-id",
+                "guardrailVersion": "some-version",
+                "trace": True,
+            },
+        ),
+        (
+            False,
+            {
+                "guardrailIdentifier": "some-id",
+                "guardrailVersion": "some-version",
+                "guardrailConfig": {"streamProcessingMode": "SYNCHRONOUS"},
+            },
+        ),
+    ],
+)
+def test_guardrail_input(
+    error_state,
+    guardrail_input,
+):
+    llm = BedrockLLM(
+        client=mock_boto3,
+        model_id="anthropic.claude-v2",
+        guardrails=guardrail_input,
+    )
+    is_valid = False
+    if error_state:
+        with pytest.raises(TypeError) as error:
+            is_valid = llm._guardrails_enabled
+        assert error.value.args[0] == (
+            "Guardrails must be a dictionary with 'guardrailIdentifier' and "
+            "'guardrailVersion' mandatory keys."
+        )
+        assert is_valid is False
+    else:
+        llm = BedrockLLM(
+            client=mock_boto3,
+            model_id="anthropic.claude-v2",
+            guardrails=guardrail_input,
+        )
+        is_valid = llm._guardrails_enabled
+        assert is_valid is True
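
A minimal usage sketch of the surface this patch adds, assuming placeholder guardrail identifiers and region. Per the changes above, `guardrailIdentifier`, `guardrailVersion`, and `trace` travel as top-level request parameters to `invoke_model` / `invoke_model_with_response_stream`, while `guardrailConfig` is embedded in the request body under the new `amazon-bedrock-guardrailConfig` key:

```python
import boto3

from langchain_aws import BedrockLLM

# Placeholder region and guardrail values; substitute your own.
client = boto3.client("bedrock-runtime", region_name="us-east-1")

llm = BedrockLLM(
    model_id="anthropic.claude-v2",
    client=client,
    streaming=True,
    guardrails={
        "guardrailIdentifier": "<guardrail_id>",
        "guardrailVersion": "<guardrail_version>",
        "trace": True,  # forwarded as trace="ENABLED" on the request
        # Embedded in the request body under "amazon-bedrock-guardrailConfig";
        # "SYNCHRONOUS" is the mode exercised by the new unit test.
        "guardrailConfig": {"streamProcessingMode": "SYNCHRONOUS"},
    },
)

for chunk in llm.stream("Tell me a haiku about rivers."):
    print(chunk, end="", flush=True)
```

Omitting either `guardrailIdentifier` or `guardrailVersion` (or passing only `guardrailConfig`) raises the `TypeError` covered by `test_guardrail_input`.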