From 466a12567cd0913e73cc73631f40c009a29be12b Mon Sep 17 00:00:00 2001 From: Ping Xiang Date: Tue, 22 Oct 2024 17:52:42 +0000 Subject: [PATCH] add contract tests for bedrock --- .../images/applications/aws-sdk/package.json | 4 + .../images/applications/aws-sdk/server.js | 133 +++++++++++++++++ .../tests/test/amazon/aws-sdk/aws_sdk_test.py | 139 ++++++++++++++++++ 3 files changed, 276 insertions(+) diff --git a/contract-tests/images/applications/aws-sdk/package.json b/contract-tests/images/applications/aws-sdk/package.json index 25ade2b..cef37d8 100644 --- a/contract-tests/images/applications/aws-sdk/package.json +++ b/contract-tests/images/applications/aws-sdk/package.json @@ -10,6 +10,10 @@ "license": "ISC", "description": "", "dependencies": { + "@aws-sdk/client-bedrock": "^3.675.0", + "@aws-sdk/client-bedrock-agent": "^3.675.0", + "@aws-sdk/client-bedrock-agent-runtime": "^3.676.0", + "@aws-sdk/client-bedrock-runtime": "^3.675.0", "@aws-sdk/client-dynamodb": "^3.658.1", "@aws-sdk/client-kinesis": "^3.658.1", "@aws-sdk/client-s3": "^3.658.1", diff --git a/contract-tests/images/applications/aws-sdk/server.js b/contract-tests/images/applications/aws-sdk/server.js index 9f9ce9e..0a84155 100644 --- a/contract-tests/images/applications/aws-sdk/server.js +++ b/contract-tests/images/applications/aws-sdk/server.js @@ -11,6 +11,11 @@ const { DynamoDBClient, CreateTableCommand, PutItemCommand } = require('@aws-sdk const { SQSClient, CreateQueueCommand, SendMessageCommand, ReceiveMessageCommand } = require('@aws-sdk/client-sqs'); const { KinesisClient, CreateStreamCommand, PutRecordCommand } = require('@aws-sdk/client-kinesis'); const fetch = require('node-fetch'); +const { BedrockClient, GetGuardrailCommand } = require('@aws-sdk/client-bedrock'); +const { BedrockAgentClient, GetKnowledgeBaseCommand, GetDataSourceCommand, GetAgentCommand } = require('@aws-sdk/client-bedrock-agent'); +const { BedrockRuntimeClient, InvokeModelCommand } = require('@aws-sdk/client-bedrock-runtime'); +const { BedrockAgentRuntimeClient, InvokeAgentCommand, RetrieveCommand } = require('@aws-sdk/client-bedrock-agent-runtime'); + const _PORT = 8080; const _ERROR = 'error'; @@ -141,6 +146,8 @@ async function handleGetRequest(req, res, path) { await handleSqsRequest(req, res, path); } else if (path.includes('kinesis')) { await handleKinesisRequest(req, res, path); + } else if (path.includes('bedrock')) { + await handleBedrockRequest(req, res, path); } else { res.writeHead(404); res.end(); @@ -485,6 +492,132 @@ async function handleKinesisRequest(req, res, path) { } } +async function handleBedrockRequest(req, res, path) { + const bedrockClient = new BedrockClient({ endpoint: _AWS_SDK_ENDPOINT, region: _AWS_REGION }); + const bedrockAgentClient = new BedrockAgentClient({ endpoint: _AWS_SDK_ENDPOINT, region: _AWS_REGION }); + const bedrockRuntimeClient = new BedrockRuntimeClient({ endpoint: _AWS_SDK_ENDPOINT, region: _AWS_REGION }); + const bedrockAgentRuntimeClient = new BedrockAgentRuntimeClient({ endpoint: _AWS_SDK_ENDPOINT, region: _AWS_REGION }); + + try { + if (path.includes('getknowledgebase/get_knowledge_base')) { + await withInjected200Success(bedrockAgentClient, ['GetKnowledgeBaseCommand'], {}, async () => { + await bedrockAgentClient.send(new GetKnowledgeBaseCommand({ knowledgeBaseId: 'invalid-knowledge-base-id' })); + }); + res.statusCode = 200; + } else if (path.includes('getdatasource/get_data_source')) { + await withInjected200Success(bedrockAgentClient, ['GetDataSourceCommand'], {}, async () => { + await bedrockAgentClient.send(new GetDataSourceCommand({ knowledgeBaseId: 'TESTKBSEID', dataSourceId: 'DATASURCID' })); + }); + res.statusCode = 200; + } else if (path.includes('getagent/get-agent')) { + await withInjected200Success(bedrockAgentClient, ['GetAgentCommand'], {}, async () => { + await bedrockAgentClient.send(new GetAgentCommand({ agentId: 'TESTAGENTID' })); + }); + res.statusCode = 200; + } else if (path.includes('getguardrail/get-guardrail')) { + await withInjected200Success( + bedrockClient, + ['GetGuardrailCommand'], + { guardrailId: 'bt4o77i015cu' }, + async () => { + await bedrockClient.send( + new GetGuardrailCommand({ + guardrailIdentifier: 'arn:aws:bedrock:us-east-1:000000000000:guardrail/bt4o77i015cu', + }) + ); + } + ); + res.statusCode = 200; + } else if (path.includes('invokeagent/invoke_agent')) { + await withInjected200Success(bedrockAgentRuntimeClient, ['InvokeAgentCommand'], {}, async () => { + await bedrockAgentRuntimeClient.send( + new InvokeAgentCommand({ + agentId: 'Q08WFRPHVL', + agentAliasId: 'testAlias', + sessionId: 'testSessionId', + inputText: 'Invoke agent sample input text', + }) + ); + }); + res.statusCode = 200; + } else if (path.includes('retrieve/retrieve')) { + await withInjected200Success(bedrockAgentRuntimeClient, ['RetrieveCommand'], {}, async () => { + await bedrockAgentRuntimeClient.send( + new RetrieveCommand({ + knowledgeBaseId: 'test-knowledge-base-id', + retrievalQuery: { + text: 'an example of retrieve query', + }, + }) + ); + }); + res.statusCode = 200; + } else if (path.includes('invokemodel/invoke-model')) { + await withInjected200Success(bedrockRuntimeClient, ['InvokeModelCommand'], {}, async () => { + const modelId = 'amazon.titan-text-premier-v1:0'; + const userMessage = "Describe the purpose of a 'hello world' program in one line."; + const prompt = `[INST] ${userMessage} [/INST]`; + + const body = JSON.stringify({ + inputText: prompt, + textGenerationConfig: { + maxTokenCount: 3072, + stopSequences: [], + temperature: 0.7, + topP: 0.9, + }, + }); + + await bedrockRuntimeClient.send( + new InvokeModelCommand({ + body: body, + modelId: modelId, + accept: 'application/json', + contentType: 'application/json', + }) + ); + }); + res.statusCode = 200; + } else { + res.statusCode = 404; + } + } catch (error) { + console.error('An error occurred:', error); + res.statusCode = 500; + } + + res.end(); +} + +function inject200Success(client, commandNames, additionalResponse = {}, middlewareName = 'inject200SuccessMiddleware') { + const middleware = (next, context) => async (args) => { + const { commandName } = context; + if (commandNames.includes(commandName)) { + const response = { + $metadata: { + httpStatusCode: 200, + requestId: 'mock-request-id', + }, + Message: 'Request succeeded', + ...additionalResponse, + }; + return { output: response }; + } + return next(args); + }; + // this middleware intercept the request and inject the response + client.middlewareStack.add(middleware, { step: 'build', name: middlewareName, priority: 'high' }); +} + +async function withInjected200Success(client, commandNames, additionalResponse, apiCall) { + const middlewareName = 'inject200SuccessMiddleware'; + inject200Success(client, commandNames, additionalResponse, middlewareName); + await apiCall(); + client.middlewareStack.remove(middlewareName); +} + + + prepareAwsServer().then(() => { server.listen(_PORT, '0.0.0.0', () => { console.log('Server is listening on port', _PORT); diff --git a/contract-tests/tests/test/amazon/aws-sdk/aws_sdk_test.py b/contract-tests/tests/test/amazon/aws-sdk/aws_sdk_test.py index 8773bb7..bb2ce3c 100644 --- a/contract-tests/tests/test/amazon/aws-sdk/aws_sdk_test.py +++ b/contract-tests/tests/test/amazon/aws-sdk/aws_sdk_test.py @@ -29,6 +29,12 @@ _AWS_SQS_QUEUE_URL: str = "aws.sqs.queue.url" _AWS_SQS_QUEUE_NAME: str = "aws.sqs.queue.name" _AWS_KINESIS_STREAM_NAME: str = "aws.kinesis.stream.name" +_AWS_BEDROCK_AGENT_ID: str = "aws.bedrock.agent.id" +_AWS_BEDROCK_GUARDRAIL_ID: str = "aws.bedrock.guardrail.id" +_AWS_BEDROCK_KNOWLEDGE_BASE_ID: str = "aws.bedrock.knowledge_base.id" +_AWS_BEDROCK_DATA_SOURCE_ID: str = "aws.bedrock.data_source.id" +_GEN_AI_REQUEST_MODEL: str = "gen_ai.request.model" + # pylint: disable=too-many-public-methods class AWSSDKTest(ContractTestBase): @@ -400,6 +406,139 @@ def test_kinesis_fault(self): span_name="Kinesis.PutRecord", ) + def test_bedrock_runtime_invoke_model(self): + self.do_test_requests( + "bedrock/invokemodel/invoke-model", + "GET", + 200, + 0, + 0, + local_operation="GET /bedrock", + rpc_service="BedrockRuntime", + remote_service="AWS::BedrockRuntime", + remote_operation="InvokeModel", + remote_resource_type="AWS::Bedrock::Model", + remote_resource_identifier="amazon.titan-text-premier-v1:0", + request_specific_attributes={ + _GEN_AI_REQUEST_MODEL: "amazon.titan-text-premier-v1:0", + }, + span_name="BedrockRuntime.InvokeModel", + ) + + def test_bedrock_get_guardrail(self): + self.do_test_requests( + "bedrock/getguardrail/get-guardrail", + "GET", + 200, + 0, + 0, + local_operation="GET /bedrock", + rpc_service="Bedrock", + remote_service="AWS::Bedrock", + remote_operation="GetGuardrail", + remote_resource_type="AWS::Bedrock::Guardrail", + remote_resource_identifier="bt4o77i015cu", + request_specific_attributes={ + _AWS_BEDROCK_GUARDRAIL_ID: "bt4o77i015cu", + }, + span_name="Bedrock.GetGuardrail", + ) + + def test_bedrock_agent_runtime_invoke_agent(self): + self.do_test_requests( + "bedrock/invokeagent/invoke_agent", + "GET", + 200, + 0, + 0, + local_operation="GET /bedrock", + rpc_service="BedrockAgentRuntime", + remote_service="AWS::Bedrock", + remote_operation="InvokeAgent", + remote_resource_type="AWS::Bedrock::Agent", + remote_resource_identifier="Q08WFRPHVL", + request_specific_attributes={ + _AWS_BEDROCK_AGENT_ID: "Q08WFRPHVL", + }, + span_name="BedrockAgentRuntime.InvokeAgent", + ) + + def test_bedrock_agent_runtime_retrieve(self): + self.do_test_requests( + "bedrock/retrieve/retrieve", + "GET", + 200, + 0, + 0, + local_operation="GET /bedrock", + rpc_service="BedrockAgentRuntime", + remote_service="AWS::Bedrock", + remote_operation="Retrieve", + remote_resource_type="AWS::Bedrock::KnowledgeBase", + remote_resource_identifier="test-knowledge-base-id", + request_specific_attributes={ + _AWS_BEDROCK_KNOWLEDGE_BASE_ID: "test-knowledge-base-id", + }, + span_name="BedrockAgentRuntime.Retrieve", + ) + + def test_bedrock_agent_get_agent(self): + self.do_test_requests( + "bedrock/getagent/get-agent", + "GET", + 200, + 0, + 0, + local_operation="GET /bedrock", + rpc_service="BedrockAgent", + remote_service="AWS::Bedrock", + remote_operation="GetAgent", + remote_resource_type="AWS::Bedrock::Agent", + remote_resource_identifier="TESTAGENTID", + request_specific_attributes={ + _AWS_BEDROCK_AGENT_ID: "TESTAGENTID", + }, + span_name="BedrockAgent.GetAgent", + ) + + def test_bedrock_agent_get_knowledge_base(self): + self.do_test_requests( + "bedrock/getknowledgebase/get_knowledge_base", + "GET", + 200, + 0, + 0, + local_operation="GET /bedrock", + rpc_service="BedrockAgent", + remote_service="AWS::Bedrock", + remote_operation="GetKnowledgeBase", + remote_resource_type="AWS::Bedrock::KnowledgeBase", + remote_resource_identifier="invalid-knowledge-base-id", + request_specific_attributes={ + _AWS_BEDROCK_KNOWLEDGE_BASE_ID: "invalid-knowledge-base-id", + }, + span_name="BedrockAgent.GetKnowledgeBase", + ) + + def test_bedrock_agent_get_data_source(self): + self.do_test_requests( + "bedrock/getdatasource/get_data_source", + "GET", + 200, + 0, + 0, + local_operation="GET /bedrock", + rpc_service="BedrockAgent", + remote_service="AWS::Bedrock", + remote_operation="GetDataSource", + remote_resource_type="AWS::Bedrock::DataSource", + remote_resource_identifier="DATASURCID", + request_specific_attributes={ + _AWS_BEDROCK_DATA_SOURCE_ID: "DATASURCID", + }, + span_name="BedrockAgent.GetDataSource", + ) + @override def _assert_aws_span_attributes(self, resource_scope_spans: List[ResourceScopeSpan], path: str, **kwargs) -> None: target_spans: List[Span] = []