Skip to content

Commit

Permalink
Fixed error for case when mltensor has data as null (opensearch-proje…
Browse files Browse the repository at this point in the history
…ct#404)

* Fixed error for case when mltensor has data as null

Signed-off-by: Martin Gaievski <[email protected]>

* Changed error message handling

Signed-off-by: Martin Gaievski <[email protected]>

---------

Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski authored Oct 6, 2023
1 parent bc6ac7a commit 47d1aee
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;

import org.apache.logging.log4j.util.Strings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;
Expand All @@ -40,6 +43,8 @@
public class MLCommonsClientAccessor {
private static final List<String> TARGET_RESPONSE_FILTERS = List.of("sentence_embedding");
private final MachineLearningNodeClient mlClient;
private static final String EXCEPTION_MESSAGE_MODEL_PREDICT_FAILED = "failed while calling model, check error log for details";
private static final String EXCEPTION_MESSAGE_PREFIX_MODEL_PREDICT_FAILED = "encountered following error while calling a model";

/**
* Wrapper around {@link #inferenceSentences} that expected a single input text and produces a single floating
Expand Down Expand Up @@ -187,6 +192,20 @@ private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
for (final ModelTensors tensors : tensorOutputList) {
final List<ModelTensor> tensorsList = tensors.getMlModelTensors();
for (final ModelTensor tensor : tensorsList) {
if (Objects.isNull(tensor.getData())) {
if (Objects.nonNull(tensor.getDataAsMap()) && Strings.isNotBlank((String) tensor.getDataAsMap().get("message"))) {
String errorFromModel = (String) tensor.getDataAsMap().get("message");
throw new IllegalStateException(
String.format(Locale.ROOT, "%s: %s", EXCEPTION_MESSAGE_PREFIX_MODEL_PREDICT_FAILED, errorFromModel)
);
} else {
log.error(
"Received following output tensor from a model, there is no detailed error message: {}",
tensor.toString()
);
throw new IllegalStateException(EXCEPTION_MESSAGE_MODEL_PREDICT_FAILED);
}
}
vector.add(Arrays.stream(tensor.getData()).map(value -> (Float) value).collect(Collectors.toList()));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.neuralsearch.ml;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;

Expand Down Expand Up @@ -328,6 +330,65 @@ public void testInferenceSentencesMultimodal_whenNodeNotConnectedException_thenR
Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException);
}

public void testInferenceMultimodal_whenInvalidInputAndEmptyTensorOutput_thenFail() {
List<ModelTensors> tensorsList = new ArrayList<>();
List<ModelTensor> mlModelTensorList = List.of(
new ModelTensor(
"someValue",
null,
new long[] { 1, 2 },
MLResultDataType.FLOAT64,
ByteBuffer.wrap(new byte[12]),
"mockResult",
ImmutableMap.of("message", "The system encountered an unexpected error during processing. Try your request again.")
)
);
final ModelTensors modelTensors = new ModelTensors(mlModelTensorList);
ModelTensorOutput outputWithErrorMessage = new ModelTensorOutput(List.of(modelTensors));

Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onResponse(outputWithErrorMessage);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener);

Mockito.verify(client)
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(singleSentenceResultListener).onFailure(any());
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);

clearInvocations(client, singleSentenceResultListener);

List<ModelTensor> mlModelTensorList2 = List.of(
new ModelTensor(
"someValue",
null,
new long[] { 1, 2 },
MLResultDataType.FLOAT64,
ByteBuffer.wrap(new byte[12]),
"mockResult",
ImmutableMap.of("test_key", "test_value")
)
);
final ModelTensors modelTensors2 = new ModelTensors(mlModelTensorList2);
ModelTensorOutput outputWithErrorMessage2 = new ModelTensorOutput(List.of(modelTensors2));

Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onResponse(outputWithErrorMessage2);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener);

Mockito.verify(client)
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(singleSentenceResultListener).onFailure(any());
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
}

private ModelTensorOutput createModelTensorOutput(final Float[] output) {
final List<ModelTensors> tensorsList = new ArrayList<>();
final List<ModelTensor> mlModelTensorList = new ArrayList<>();
Expand Down

0 comments on commit 47d1aee

Please sign in to comment.