Skip to content

Commit

Permalink
add contract tests for bedrock (#106)
Browse files Browse the repository at this point in the history
*Issue #, if available:*

*Description of changes:*

Add contract tests for Bedrock instrumentation which include the
following APIs:

* BedrockRuntime.InvokeModel
* Bedrock.GetGuardrail
* BedrockAgentRuntime.InvokeAgent
* BedrockAgentRuntime.Retrieve
* BedrockAgent.GetAgent
* BedrockAgent.GetKnowledgeBase
* BedrockAgent.GetDataSource


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
  • Loading branch information
pxaws authored Oct 22, 2024
1 parent 3773450 commit a043f25
Show file tree
Hide file tree
Showing 3 changed files with 276 additions and 0 deletions.
4 changes: 4 additions & 0 deletions contract-tests/images/applications/aws-sdk/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
"license": "ISC",
"description": "",
"dependencies": {
"@aws-sdk/client-bedrock": "^3.675.0",
"@aws-sdk/client-bedrock-agent": "^3.675.0",
"@aws-sdk/client-bedrock-agent-runtime": "^3.676.0",
"@aws-sdk/client-bedrock-runtime": "^3.675.0",
"@aws-sdk/client-dynamodb": "^3.658.1",
"@aws-sdk/client-kinesis": "^3.658.1",
"@aws-sdk/client-s3": "^3.658.1",
Expand Down
133 changes: 133 additions & 0 deletions contract-tests/images/applications/aws-sdk/server.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ const { DynamoDBClient, CreateTableCommand, PutItemCommand } = require('@aws-sdk
const { SQSClient, CreateQueueCommand, SendMessageCommand, ReceiveMessageCommand } = require('@aws-sdk/client-sqs');
const { KinesisClient, CreateStreamCommand, PutRecordCommand } = require('@aws-sdk/client-kinesis');
const fetch = require('node-fetch');
const { BedrockClient, GetGuardrailCommand } = require('@aws-sdk/client-bedrock');
const { BedrockAgentClient, GetKnowledgeBaseCommand, GetDataSourceCommand, GetAgentCommand } = require('@aws-sdk/client-bedrock-agent');
const { BedrockRuntimeClient, InvokeModelCommand } = require('@aws-sdk/client-bedrock-runtime');
const { BedrockAgentRuntimeClient, InvokeAgentCommand, RetrieveCommand } = require('@aws-sdk/client-bedrock-agent-runtime');


const _PORT = 8080;
const _ERROR = 'error';
Expand Down Expand Up @@ -141,6 +146,8 @@ async function handleGetRequest(req, res, path) {
await handleSqsRequest(req, res, path);
} else if (path.includes('kinesis')) {
await handleKinesisRequest(req, res, path);
} else if (path.includes('bedrock')) {
await handleBedrockRequest(req, res, path);
} else {
res.writeHead(404);
res.end();
Expand Down Expand Up @@ -485,6 +492,132 @@ async function handleKinesisRequest(req, res, path) {
}
}

async function handleBedrockRequest(req, res, path) {
const bedrockClient = new BedrockClient({ endpoint: _AWS_SDK_ENDPOINT, region: _AWS_REGION });
const bedrockAgentClient = new BedrockAgentClient({ endpoint: _AWS_SDK_ENDPOINT, region: _AWS_REGION });
const bedrockRuntimeClient = new BedrockRuntimeClient({ endpoint: _AWS_SDK_ENDPOINT, region: _AWS_REGION });
const bedrockAgentRuntimeClient = new BedrockAgentRuntimeClient({ endpoint: _AWS_SDK_ENDPOINT, region: _AWS_REGION });

try {
if (path.includes('getknowledgebase/get_knowledge_base')) {
await withInjected200Success(bedrockAgentClient, ['GetKnowledgeBaseCommand'], {}, async () => {
await bedrockAgentClient.send(new GetKnowledgeBaseCommand({ knowledgeBaseId: 'invalid-knowledge-base-id' }));
});
res.statusCode = 200;
} else if (path.includes('getdatasource/get_data_source')) {
await withInjected200Success(bedrockAgentClient, ['GetDataSourceCommand'], {}, async () => {
await bedrockAgentClient.send(new GetDataSourceCommand({ knowledgeBaseId: 'TESTKBSEID', dataSourceId: 'DATASURCID' }));
});
res.statusCode = 200;
} else if (path.includes('getagent/get-agent')) {
await withInjected200Success(bedrockAgentClient, ['GetAgentCommand'], {}, async () => {
await bedrockAgentClient.send(new GetAgentCommand({ agentId: 'TESTAGENTID' }));
});
res.statusCode = 200;
} else if (path.includes('getguardrail/get-guardrail')) {
await withInjected200Success(
bedrockClient,
['GetGuardrailCommand'],
{ guardrailId: 'bt4o77i015cu' },
async () => {
await bedrockClient.send(
new GetGuardrailCommand({
guardrailIdentifier: 'arn:aws:bedrock:us-east-1:000000000000:guardrail/bt4o77i015cu',
})
);
}
);
res.statusCode = 200;
} else if (path.includes('invokeagent/invoke_agent')) {
await withInjected200Success(bedrockAgentRuntimeClient, ['InvokeAgentCommand'], {}, async () => {
await bedrockAgentRuntimeClient.send(
new InvokeAgentCommand({
agentId: 'Q08WFRPHVL',
agentAliasId: 'testAlias',
sessionId: 'testSessionId',
inputText: 'Invoke agent sample input text',
})
);
});
res.statusCode = 200;
} else if (path.includes('retrieve/retrieve')) {
await withInjected200Success(bedrockAgentRuntimeClient, ['RetrieveCommand'], {}, async () => {
await bedrockAgentRuntimeClient.send(
new RetrieveCommand({
knowledgeBaseId: 'test-knowledge-base-id',
retrievalQuery: {
text: 'an example of retrieve query',
},
})
);
});
res.statusCode = 200;
} else if (path.includes('invokemodel/invoke-model')) {
await withInjected200Success(bedrockRuntimeClient, ['InvokeModelCommand'], {}, async () => {
const modelId = 'amazon.titan-text-premier-v1:0';
const userMessage = "Describe the purpose of a 'hello world' program in one line.";
const prompt = `<s>[INST] ${userMessage} [/INST]`;

const body = JSON.stringify({
inputText: prompt,
textGenerationConfig: {
maxTokenCount: 3072,
stopSequences: [],
temperature: 0.7,
topP: 0.9,
},
});

await bedrockRuntimeClient.send(
new InvokeModelCommand({
body: body,
modelId: modelId,
accept: 'application/json',
contentType: 'application/json',
})
);
});
res.statusCode = 200;
} else {
res.statusCode = 404;
}
} catch (error) {
console.error('An error occurred:', error);
res.statusCode = 500;
}

res.end();
}

function inject200Success(client, commandNames, additionalResponse = {}, middlewareName = 'inject200SuccessMiddleware') {
const middleware = (next, context) => async (args) => {
const { commandName } = context;
if (commandNames.includes(commandName)) {
const response = {
$metadata: {
httpStatusCode: 200,
requestId: 'mock-request-id',
},
Message: 'Request succeeded',
...additionalResponse,
};
return { output: response };
}
return next(args);
};
// this middleware intercept the request and inject the response
client.middlewareStack.add(middleware, { step: 'build', name: middlewareName, priority: 'high' });
}

async function withInjected200Success(client, commandNames, additionalResponse, apiCall) {
const middlewareName = 'inject200SuccessMiddleware';
inject200Success(client, commandNames, additionalResponse, middlewareName);
await apiCall();
client.middlewareStack.remove(middlewareName);
}



prepareAwsServer().then(() => {
server.listen(_PORT, '0.0.0.0', () => {
console.log('Server is listening on port', _PORT);
Expand Down
139 changes: 139 additions & 0 deletions contract-tests/tests/test/amazon/aws-sdk/aws_sdk_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@
_AWS_SQS_QUEUE_URL: str = "aws.sqs.queue.url"
_AWS_SQS_QUEUE_NAME: str = "aws.sqs.queue.name"
_AWS_KINESIS_STREAM_NAME: str = "aws.kinesis.stream.name"
_AWS_BEDROCK_AGENT_ID: str = "aws.bedrock.agent.id"
_AWS_BEDROCK_GUARDRAIL_ID: str = "aws.bedrock.guardrail.id"
_AWS_BEDROCK_KNOWLEDGE_BASE_ID: str = "aws.bedrock.knowledge_base.id"
_AWS_BEDROCK_DATA_SOURCE_ID: str = "aws.bedrock.data_source.id"
_GEN_AI_REQUEST_MODEL: str = "gen_ai.request.model"


# pylint: disable=too-many-public-methods
class AWSSDKTest(ContractTestBase):
Expand Down Expand Up @@ -400,6 +406,139 @@ def test_kinesis_fault(self):
span_name="Kinesis.PutRecord",
)

def test_bedrock_runtime_invoke_model(self):
self.do_test_requests(
"bedrock/invokemodel/invoke-model",
"GET",
200,
0,
0,
local_operation="GET /bedrock",
rpc_service="BedrockRuntime",
remote_service="AWS::BedrockRuntime",
remote_operation="InvokeModel",
remote_resource_type="AWS::Bedrock::Model",
remote_resource_identifier="amazon.titan-text-premier-v1:0",
request_specific_attributes={
_GEN_AI_REQUEST_MODEL: "amazon.titan-text-premier-v1:0",
},
span_name="BedrockRuntime.InvokeModel",
)

def test_bedrock_get_guardrail(self):
self.do_test_requests(
"bedrock/getguardrail/get-guardrail",
"GET",
200,
0,
0,
local_operation="GET /bedrock",
rpc_service="Bedrock",
remote_service="AWS::Bedrock",
remote_operation="GetGuardrail",
remote_resource_type="AWS::Bedrock::Guardrail",
remote_resource_identifier="bt4o77i015cu",
request_specific_attributes={
_AWS_BEDROCK_GUARDRAIL_ID: "bt4o77i015cu",
},
span_name="Bedrock.GetGuardrail",
)

def test_bedrock_agent_runtime_invoke_agent(self):
self.do_test_requests(
"bedrock/invokeagent/invoke_agent",
"GET",
200,
0,
0,
local_operation="GET /bedrock",
rpc_service="BedrockAgentRuntime",
remote_service="AWS::Bedrock",
remote_operation="InvokeAgent",
remote_resource_type="AWS::Bedrock::Agent",
remote_resource_identifier="Q08WFRPHVL",
request_specific_attributes={
_AWS_BEDROCK_AGENT_ID: "Q08WFRPHVL",
},
span_name="BedrockAgentRuntime.InvokeAgent",
)

def test_bedrock_agent_runtime_retrieve(self):
self.do_test_requests(
"bedrock/retrieve/retrieve",
"GET",
200,
0,
0,
local_operation="GET /bedrock",
rpc_service="BedrockAgentRuntime",
remote_service="AWS::Bedrock",
remote_operation="Retrieve",
remote_resource_type="AWS::Bedrock::KnowledgeBase",
remote_resource_identifier="test-knowledge-base-id",
request_specific_attributes={
_AWS_BEDROCK_KNOWLEDGE_BASE_ID: "test-knowledge-base-id",
},
span_name="BedrockAgentRuntime.Retrieve",
)

def test_bedrock_agent_get_agent(self):
self.do_test_requests(
"bedrock/getagent/get-agent",
"GET",
200,
0,
0,
local_operation="GET /bedrock",
rpc_service="BedrockAgent",
remote_service="AWS::Bedrock",
remote_operation="GetAgent",
remote_resource_type="AWS::Bedrock::Agent",
remote_resource_identifier="TESTAGENTID",
request_specific_attributes={
_AWS_BEDROCK_AGENT_ID: "TESTAGENTID",
},
span_name="BedrockAgent.GetAgent",
)

def test_bedrock_agent_get_knowledge_base(self):
self.do_test_requests(
"bedrock/getknowledgebase/get_knowledge_base",
"GET",
200,
0,
0,
local_operation="GET /bedrock",
rpc_service="BedrockAgent",
remote_service="AWS::Bedrock",
remote_operation="GetKnowledgeBase",
remote_resource_type="AWS::Bedrock::KnowledgeBase",
remote_resource_identifier="invalid-knowledge-base-id",
request_specific_attributes={
_AWS_BEDROCK_KNOWLEDGE_BASE_ID: "invalid-knowledge-base-id",
},
span_name="BedrockAgent.GetKnowledgeBase",
)

def test_bedrock_agent_get_data_source(self):
self.do_test_requests(
"bedrock/getdatasource/get_data_source",
"GET",
200,
0,
0,
local_operation="GET /bedrock",
rpc_service="BedrockAgent",
remote_service="AWS::Bedrock",
remote_operation="GetDataSource",
remote_resource_type="AWS::Bedrock::DataSource",
remote_resource_identifier="DATASURCID",
request_specific_attributes={
_AWS_BEDROCK_DATA_SOURCE_ID: "DATASURCID",
},
span_name="BedrockAgent.GetDataSource",
)

@override
def _assert_aws_span_attributes(self, resource_scope_spans: List[ResourceScopeSpan], path: str, **kwargs) -> None:
target_spans: List[Span] = []
Expand Down

0 comments on commit a043f25

Please sign in to comment.