Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add contract tests for bedrock #106

Merged
merged 1 commit into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions contract-tests/images/applications/aws-sdk/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
"license": "ISC",
"description": "",
"dependencies": {
"@aws-sdk/client-bedrock": "^3.675.0",
"@aws-sdk/client-bedrock-agent": "^3.675.0",
"@aws-sdk/client-bedrock-agent-runtime": "^3.676.0",
"@aws-sdk/client-bedrock-runtime": "^3.675.0",
"@aws-sdk/client-dynamodb": "^3.658.1",
"@aws-sdk/client-kinesis": "^3.658.1",
"@aws-sdk/client-s3": "^3.658.1",
Expand Down
133 changes: 133 additions & 0 deletions contract-tests/images/applications/aws-sdk/server.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ const { DynamoDBClient, CreateTableCommand, PutItemCommand } = require('@aws-sdk
const { SQSClient, CreateQueueCommand, SendMessageCommand, ReceiveMessageCommand } = require('@aws-sdk/client-sqs');
const { KinesisClient, CreateStreamCommand, PutRecordCommand } = require('@aws-sdk/client-kinesis');
const fetch = require('node-fetch');
const { BedrockClient, GetGuardrailCommand } = require('@aws-sdk/client-bedrock');
const { BedrockAgentClient, GetKnowledgeBaseCommand, GetDataSourceCommand, GetAgentCommand } = require('@aws-sdk/client-bedrock-agent');
const { BedrockRuntimeClient, InvokeModelCommand } = require('@aws-sdk/client-bedrock-runtime');
const { BedrockAgentRuntimeClient, InvokeAgentCommand, RetrieveCommand } = require('@aws-sdk/client-bedrock-agent-runtime');


const _PORT = 8080;
const _ERROR = 'error';
Expand Down Expand Up @@ -141,6 +146,8 @@ async function handleGetRequest(req, res, path) {
await handleSqsRequest(req, res, path);
} else if (path.includes('kinesis')) {
await handleKinesisRequest(req, res, path);
} else if (path.includes('bedrock')) {
await handleBedrockRequest(req, res, path);
} else {
res.writeHead(404);
res.end();
Expand Down Expand Up @@ -485,6 +492,132 @@ async function handleKinesisRequest(req, res, path) {
}
}

async function handleBedrockRequest(req, res, path) {
const bedrockClient = new BedrockClient({ endpoint: _AWS_SDK_ENDPOINT, region: _AWS_REGION });
const bedrockAgentClient = new BedrockAgentClient({ endpoint: _AWS_SDK_ENDPOINT, region: _AWS_REGION });
const bedrockRuntimeClient = new BedrockRuntimeClient({ endpoint: _AWS_SDK_ENDPOINT, region: _AWS_REGION });
const bedrockAgentRuntimeClient = new BedrockAgentRuntimeClient({ endpoint: _AWS_SDK_ENDPOINT, region: _AWS_REGION });

try {
if (path.includes('getknowledgebase/get_knowledge_base')) {
await withInjected200Success(bedrockAgentClient, ['GetKnowledgeBaseCommand'], {}, async () => {
await bedrockAgentClient.send(new GetKnowledgeBaseCommand({ knowledgeBaseId: 'invalid-knowledge-base-id' }));
});
res.statusCode = 200;
} else if (path.includes('getdatasource/get_data_source')) {
await withInjected200Success(bedrockAgentClient, ['GetDataSourceCommand'], {}, async () => {
await bedrockAgentClient.send(new GetDataSourceCommand({ knowledgeBaseId: 'TESTKBSEID', dataSourceId: 'DATASURCID' }));
});
res.statusCode = 200;
} else if (path.includes('getagent/get-agent')) {
await withInjected200Success(bedrockAgentClient, ['GetAgentCommand'], {}, async () => {
await bedrockAgentClient.send(new GetAgentCommand({ agentId: 'TESTAGENTID' }));
});
res.statusCode = 200;
} else if (path.includes('getguardrail/get-guardrail')) {
await withInjected200Success(
bedrockClient,
['GetGuardrailCommand'],
{ guardrailId: 'bt4o77i015cu' },
async () => {
await bedrockClient.send(
new GetGuardrailCommand({
guardrailIdentifier: 'arn:aws:bedrock:us-east-1:000000000000:guardrail/bt4o77i015cu',
})
);
}
);
res.statusCode = 200;
} else if (path.includes('invokeagent/invoke_agent')) {
await withInjected200Success(bedrockAgentRuntimeClient, ['InvokeAgentCommand'], {}, async () => {
await bedrockAgentRuntimeClient.send(
new InvokeAgentCommand({
agentId: 'Q08WFRPHVL',
agentAliasId: 'testAlias',
sessionId: 'testSessionId',
inputText: 'Invoke agent sample input text',
})
);
});
res.statusCode = 200;
} else if (path.includes('retrieve/retrieve')) {
await withInjected200Success(bedrockAgentRuntimeClient, ['RetrieveCommand'], {}, async () => {
await bedrockAgentRuntimeClient.send(
new RetrieveCommand({
knowledgeBaseId: 'test-knowledge-base-id',
retrievalQuery: {
text: 'an example of retrieve query',
},
})
);
});
res.statusCode = 200;
} else if (path.includes('invokemodel/invoke-model')) {
await withInjected200Success(bedrockRuntimeClient, ['InvokeModelCommand'], {}, async () => {
const modelId = 'amazon.titan-text-premier-v1:0';
const userMessage = "Describe the purpose of a 'hello world' program in one line.";
const prompt = `<s>[INST] ${userMessage} [/INST]`;

const body = JSON.stringify({
inputText: prompt,
textGenerationConfig: {
maxTokenCount: 3072,
stopSequences: [],
temperature: 0.7,
topP: 0.9,
},
});

await bedrockRuntimeClient.send(
new InvokeModelCommand({
body: body,
modelId: modelId,
accept: 'application/json',
contentType: 'application/json',
})
);
});
res.statusCode = 200;
} else {
res.statusCode = 404;
}
} catch (error) {
console.error('An error occurred:', error);
res.statusCode = 500;
}

res.end();
}

function inject200Success(client, commandNames, additionalResponse = {}, middlewareName = 'inject200SuccessMiddleware') {
const middleware = (next, context) => async (args) => {
const { commandName } = context;
if (commandNames.includes(commandName)) {
const response = {
$metadata: {
httpStatusCode: 200,
requestId: 'mock-request-id',
},
Message: 'Request succeeded',
...additionalResponse,
};
return { output: response };
}
return next(args);
};
// this middleware intercept the request and inject the response
client.middlewareStack.add(middleware, { step: 'build', name: middlewareName, priority: 'high' });
}

async function withInjected200Success(client, commandNames, additionalResponse, apiCall) {
const middlewareName = 'inject200SuccessMiddleware';
inject200Success(client, commandNames, additionalResponse, middlewareName);
await apiCall();
client.middlewareStack.remove(middlewareName);
}



prepareAwsServer().then(() => {
server.listen(_PORT, '0.0.0.0', () => {
console.log('Server is listening on port', _PORT);
Expand Down
139 changes: 139 additions & 0 deletions contract-tests/tests/test/amazon/aws-sdk/aws_sdk_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@
_AWS_SQS_QUEUE_URL: str = "aws.sqs.queue.url"
_AWS_SQS_QUEUE_NAME: str = "aws.sqs.queue.name"
_AWS_KINESIS_STREAM_NAME: str = "aws.kinesis.stream.name"
_AWS_BEDROCK_AGENT_ID: str = "aws.bedrock.agent.id"
_AWS_BEDROCK_GUARDRAIL_ID: str = "aws.bedrock.guardrail.id"
_AWS_BEDROCK_KNOWLEDGE_BASE_ID: str = "aws.bedrock.knowledge_base.id"
_AWS_BEDROCK_DATA_SOURCE_ID: str = "aws.bedrock.data_source.id"
_GEN_AI_REQUEST_MODEL: str = "gen_ai.request.model"


# pylint: disable=too-many-public-methods
class AWSSDKTest(ContractTestBase):
Expand Down Expand Up @@ -400,6 +406,139 @@ def test_kinesis_fault(self):
span_name="Kinesis.PutRecord",
)

def test_bedrock_runtime_invoke_model(self):
self.do_test_requests(
"bedrock/invokemodel/invoke-model",
"GET",
200,
0,
0,
local_operation="GET /bedrock",
rpc_service="BedrockRuntime",
remote_service="AWS::BedrockRuntime",
remote_operation="InvokeModel",
remote_resource_type="AWS::Bedrock::Model",
remote_resource_identifier="amazon.titan-text-premier-v1:0",
request_specific_attributes={
_GEN_AI_REQUEST_MODEL: "amazon.titan-text-premier-v1:0",
},
span_name="BedrockRuntime.InvokeModel",
)

def test_bedrock_get_guardrail(self):
self.do_test_requests(
"bedrock/getguardrail/get-guardrail",
"GET",
200,
0,
0,
local_operation="GET /bedrock",
rpc_service="Bedrock",
remote_service="AWS::Bedrock",
remote_operation="GetGuardrail",
remote_resource_type="AWS::Bedrock::Guardrail",
remote_resource_identifier="bt4o77i015cu",
request_specific_attributes={
_AWS_BEDROCK_GUARDRAIL_ID: "bt4o77i015cu",
},
span_name="Bedrock.GetGuardrail",
)

def test_bedrock_agent_runtime_invoke_agent(self):
self.do_test_requests(
"bedrock/invokeagent/invoke_agent",
"GET",
200,
0,
0,
local_operation="GET /bedrock",
rpc_service="BedrockAgentRuntime",
remote_service="AWS::Bedrock",
remote_operation="InvokeAgent",
remote_resource_type="AWS::Bedrock::Agent",
remote_resource_identifier="Q08WFRPHVL",
request_specific_attributes={
_AWS_BEDROCK_AGENT_ID: "Q08WFRPHVL",
},
span_name="BedrockAgentRuntime.InvokeAgent",
)

def test_bedrock_agent_runtime_retrieve(self):
self.do_test_requests(
"bedrock/retrieve/retrieve",
"GET",
200,
0,
0,
local_operation="GET /bedrock",
rpc_service="BedrockAgentRuntime",
remote_service="AWS::Bedrock",
remote_operation="Retrieve",
remote_resource_type="AWS::Bedrock::KnowledgeBase",
remote_resource_identifier="test-knowledge-base-id",
request_specific_attributes={
_AWS_BEDROCK_KNOWLEDGE_BASE_ID: "test-knowledge-base-id",
},
span_name="BedrockAgentRuntime.Retrieve",
)

def test_bedrock_agent_get_agent(self):
self.do_test_requests(
"bedrock/getagent/get-agent",
"GET",
200,
0,
0,
local_operation="GET /bedrock",
rpc_service="BedrockAgent",
remote_service="AWS::Bedrock",
remote_operation="GetAgent",
remote_resource_type="AWS::Bedrock::Agent",
remote_resource_identifier="TESTAGENTID",
request_specific_attributes={
_AWS_BEDROCK_AGENT_ID: "TESTAGENTID",
},
span_name="BedrockAgent.GetAgent",
)

def test_bedrock_agent_get_knowledge_base(self):
self.do_test_requests(
"bedrock/getknowledgebase/get_knowledge_base",
"GET",
200,
0,
0,
local_operation="GET /bedrock",
rpc_service="BedrockAgent",
remote_service="AWS::Bedrock",
remote_operation="GetKnowledgeBase",
remote_resource_type="AWS::Bedrock::KnowledgeBase",
remote_resource_identifier="invalid-knowledge-base-id",
request_specific_attributes={
_AWS_BEDROCK_KNOWLEDGE_BASE_ID: "invalid-knowledge-base-id",
},
span_name="BedrockAgent.GetKnowledgeBase",
)

def test_bedrock_agent_get_data_source(self):
self.do_test_requests(
"bedrock/getdatasource/get_data_source",
"GET",
200,
0,
0,
local_operation="GET /bedrock",
rpc_service="BedrockAgent",
remote_service="AWS::Bedrock",
remote_operation="GetDataSource",
remote_resource_type="AWS::Bedrock::DataSource",
remote_resource_identifier="DATASURCID",
request_specific_attributes={
_AWS_BEDROCK_DATA_SOURCE_ID: "DATASURCID",
},
span_name="BedrockAgent.GetDataSource",
)

@override
def _assert_aws_span_attributes(self, resource_scope_spans: List[ResourceScopeSpan], path: str, **kwargs) -> None:
target_spans: List[Span] = []
Expand Down
Loading