Skip to content

Commit

Permalink
fix multiple docs support (opensearch-project#1516) (opensearch-proje…
Browse files Browse the repository at this point in the history
…ct#1517)

Signed-off-by: Yaliang Wu <[email protected]>
(cherry picked from commit 27ffeb2)

Co-authored-by: Yaliang Wu <[email protected]>
  • Loading branch information
opensearch-trigger-bot[bot] and ylwu-amzn authored Oct 13, 2023
1 parent ccf0b7b commit b60b2c6
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ default ModelTensorOutput executePredict(MLInput mlInput) {
List<String> textDocs = textDocsInputDataSet.getDocs().subList(processedDocs, textDocsInputDataSet.getDocs().size());
List<ModelTensors> tempTensorOutputs = new ArrayList<>();
preparePayloadAndInvokeRemoteModel(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()).build(), tempTensorOutputs);
processedDocs += Math.max(tempTensorOutputs.size(), 1);
int tensorCount = 0;
if (tempTensorOutputs.size() > 0 && tempTensorOutputs.get(0).getMlModelTensors() != null) {
tensorCount = tempTensorOutputs.get(0).getMlModelTensors().size();
}
processedDocs += Math.max(tensorCount, 1);
tensorOutputs.addAll(tempTensorOutputs);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ public void executePredict_TextDocsInput() throws IOException {
when(executor.getHttpClient()).thenReturn(httpClient);
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().size());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size());
Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData());
Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1).getData());
Expand Down

0 comments on commit b60b2c6

Please sign in to comment.