Skip to content

Commit

Permalink
Add a new IT test that uses both an image and a documenet in RAG.
Browse files Browse the repository at this point in the history
Signed-off-by: Austin Lee <[email protected]>
  • Loading branch information
austintlee committed Oct 9, 2024
1 parent 74c211e commit c984f62
Showing 1 changed file with 69 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,72 @@ public void testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat() throws
assertNotNull(answer);
}

public void testBM25WithBedrockConverseUsingLlmMessagesForImageAndDocument() throws Exception {
// Skip test if key is null
if (AWS_ACCESS_KEY_ID == null) {
return;
}
Response response = createConnector(BEDROCK_DOCUMENT_CONVERSE_CONNECTOR_BLUEPRINT2);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "Bedrock Anthropic Claude", connectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = RestMLRemoteInferenceIT.getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");
response = deployRemoteModel(modelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);

PipelineParameters pipelineParameters = new PipelineParameters();
pipelineParameters.tag = "testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat";
pipelineParameters.description = "desc";
pipelineParameters.modelId = modelId;
// pipelineParameters.systemPrompt = "You are a helpful assistant";
pipelineParameters.userInstructions = "none";
pipelineParameters.context_field = "text";
Response response1 = createSearchPipeline2("pipeline_test", pipelineParameters);
assertEquals(200, response1.getStatusLine().getStatusCode());

byte[] rawImage = FileUtils
.readFileToByteArray(Path.of(classLoader.getResource(TEST_DOC_PATH + "openai_boardwalk.jpg").toURI()).toFile());
String imageContent = Base64.getEncoder().encodeToString(rawImage);

byte[] docBytes = FileUtils.readFileToByteArray(Path.of(classLoader.getResource(TEST_DOC_PATH + "lincoln.pdf").toURI()).toFile());
String docContent = Base64.getEncoder().encodeToString(docBytes);

SearchRequestParameters requestParameters;
requestParameters = new SearchRequestParameters();
requestParameters.source = "text";
requestParameters.match = "president";
requestParameters.llmModel = BEDROCK_CONVERSE_ANTHROPIC_CLAUDE;
requestParameters.llmQuestion = "use the information from the attached document to tell me something interesting about lincoln";
requestParameters.contextSize = 5;
requestParameters.interactionSize = 5;
requestParameters.timeout = 60;
requestParameters.imageFormat = "jpeg";
requestParameters.imageType = "data"; // Bedrock does not support URLs
requestParameters.imageData = imageContent;
requestParameters.documentFormat = "pdf";
requestParameters.documentName = "lincoln";
requestParameters.documentData = docContent;
Response response3 = performSearch(INDEX_NAME, "pipeline_test", 5, requestParameters);
assertEquals(200, response3.getStatusLine().getStatusCode());

Map responseMap3 = parseResponseToMap(response3);
Map ext = (Map) responseMap3.get("ext");
assertNotNull(ext);
Map rag = (Map) ext.get("retrieval_augmented_generation");
assertNotNull(rag);

// TODO handle errors such as throttling
String answer = (String) rag.get("answer");
assertNotNull(answer);
}

public void testBM25WithOpenAIWithConversation() throws Exception {
// Skip test if key is null
if (OPENAI_KEY == null) {
Expand Down Expand Up @@ -1352,6 +1418,9 @@ private Response performSearch(String indexName, String pipeline, int size, Sear
requestParameters.interactionSize,
requestParameters.timeout
);

System.out.println(httpEntity);

return makeRequest(
client(),
"POST",
Expand Down

0 comments on commit c984f62

Please sign in to comment.