Skip to content

Commit

Permalink
Adding tests to improve test coverage
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Oct 2, 2023
1 parent d4d0a71 commit e976c79
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,23 @@ public void testInferenceMultimodal_whenExceptionFromMLClient_thenFailure() {
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
}

public void testInferenceSentencesMultimodal_whenNodeNotConnectedException_thenRetryThreeTimes() {
final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException(
mock(DiscoveryNode.class),
"Node not connected"
);
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onFailure(nodeNodeConnectedException);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener);

Mockito.verify(client, times(4))
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException);
}

private ModelTensorOutput createModelTensorOutput(final Float[] output) {
final List<ModelTensors> tensorsList = new ArrayList<>();
final List<ModelTensor> mlModelTensorList = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public void setup() {
}

@SneakyThrows
private TextImageEmbeddingProcessor createInstance(List<List<Float>> vector) {
private TextImageEmbeddingProcessor createInstance() {
Map<String, Processor.Factory> registry = new HashMap<>();
Map<String, Object> config = new HashMap<>();
config.put(TextImageEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId");
Expand Down Expand Up @@ -112,7 +112,7 @@ public void testExecute_successful() {
sourceAndMetadata.put("my_text_field", "value2");
sourceAndMetadata.put("key3", "value3");
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
TextImageEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2));
TextImageEmbeddingProcessor processor = createInstance();

List<List<Float>> modelTensorList = createMockVectorResult();
doAnswer(invocation -> {
Expand Down Expand Up @@ -157,7 +157,7 @@ public void testExecute_withListTypeInput_successful() {
sourceAndMetadata.put("my_text_field", "value1");
sourceAndMetadata.put("another_text_field", "value2");
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
TextImageEmbeddingProcessor processor = createInstance(createMockVectorWithLength(6));
TextImageEmbeddingProcessor processor = createInstance();

List<List<Float>> modelTensorList = createMockVectorResult();
doAnswer(invocation -> {
Expand All @@ -177,7 +177,7 @@ public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() {
sourceAndMetadata.put("key1", "hello world");
sourceAndMetadata.put("my_text_field", ret);
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
TextImageEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2));
TextImageEmbeddingProcessor processor = createInstance();
BiConsumer handler = mock(BiConsumer.class);
processor.execute(ingestDocument, handler);
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
Expand All @@ -188,7 +188,7 @@ public void testExecute_MLClientAccessorThrowFail_handlerFailure() {
sourceAndMetadata.put("my_text_field", "value1");
sourceAndMetadata.put("key2", "value2");
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
TextImageEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2));
TextImageEmbeddingProcessor processor = createInstance();

doAnswer(invocation -> {
ActionListener<List<List<Float>>> listener = invocation.getArgument(2);
Expand All @@ -208,7 +208,7 @@ public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() {
sourceAndMetadata.put("key1", map1);
sourceAndMetadata.put("my_text_field", map2);
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
TextImageEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2));
TextImageEmbeddingProcessor processor = createInstance();
BiConsumer handler = mock(BiConsumer.class);
processor.execute(ingestDocument, handler);
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
Expand All @@ -221,12 +221,23 @@ public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() {
sourceAndMetadata.put("key1", map1);
sourceAndMetadata.put("my_text_field", map2);
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
TextImageEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2));
TextImageEmbeddingProcessor processor = createInstance();
BiConsumer handler = mock(BiConsumer.class);
processor.execute(ingestDocument, handler);
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testExecute_hybridTypeInput_successful() throws Exception {
List<String> list1 = ImmutableList.of("test1", "test2");
Map<String, List<String>> map1 = ImmutableMap.of("test3", list1);
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put("key2", map1);
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
TextImageEmbeddingProcessor processor = createInstance();
IngestDocument document = processor.execute(ingestDocument);
assert document.getSourceAndMetadata().containsKey("key2");
}

private List<List<Float>> createMockVectorResult() {
List<List<Float>> modelTensorList = new ArrayList<>();
List<Float> number1 = ImmutableList.of(1.234f, 2.354f);
Expand Down

0 comments on commit e976c79

Please sign in to comment.