Skip to content

Commit

Permalink
Rewrite tests
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosdelest committed May 14, 2024
1 parent 0f1ef56 commit f50d41f
Showing 1 changed file with 84 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,15 @@
* 2.0.
*/

package org.elasticsearch.xpack.inference.bulk;
package org.elasticsearch.xpack.inference.action.bulk;

import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.bulk.BulkItemRequest;
import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.bulk.BulkShardRequest;
import org.elasticsearch.action.bulk.TransportShardBulkAction;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.support.ActionFilterChain;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.action.support.replication.ClusterStateCreationUtils;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
Expand All @@ -31,15 +28,13 @@
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
import org.elasticsearch.xpack.inference.action.bulk.BulkShardOperationInferenceProcessor;
import org.elasticsearch.xpack.inference.model.TestModel;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.junit.After;
Expand All @@ -64,18 +59,31 @@
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.toChunkedResult;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class BulkShardOperationInferenceProcessorTests extends ESTestCase {
private ThreadPool threadPool;
private ClusterState clusterState;
private Metadata metadata;

@Before
public void setupThreadPool() {
public void setup() {
threadPool = new TestThreadPool(getTestName());

metadata = mock(Metadata.class);
clusterState = mock(ClusterState.class);
when(clusterState.metadata()).thenReturn(metadata);
}

private void mockRequestIndexInferenceFieldMap(BulkShardRequest request, Map<String, InferenceFieldMetadata> inferenceFieldMap) {
IndexMetadata indexMetadata = mock(IndexMetadata.class);
when(metadata.index(request.shardId().getIndexName())).thenReturn(indexMetadata);
when(indexMetadata.getInferenceFields()).thenReturn(inferenceFieldMap);
}

@After
Expand All @@ -84,63 +92,37 @@ public void tearDownThreadPool() throws Exception {
}

@SuppressWarnings({ "unchecked", "rawtypes" })
public void testFilterNoop() throws Exception {
BulkShardOperationInferenceProcessor inferenceProcessor = createInfrenceProcessor(threadPool, Map.of(), DEFAULT_BATCH_SIZE);
CountDownLatch chainExecuted = new CountDownLatch(1);
ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
try {
assertTrue(((BulkShardRequest) request).getInferenceFieldMap().isEmpty());
} finally {
chainExecuted.countDown();
}
};
ActionListener actionListener = mock(ActionListener.class);
Task task = mock(Task.class);
public void testNoop() throws Exception {
BulkShardOperationInferenceProcessor inferenceProcessor = createInferenceProcessor(threadPool, Map.of(), DEFAULT_BATCH_SIZE);
CountDownLatch latch = new CountDownLatch(1);
ActionListener<BulkShardRequest> actionListener = mock(ActionListener.class);
doAnswer(invocation -> {
latch.countDown();
return null;
}).when(actionListener).onResponse(any());
BulkShardRequest request = new BulkShardRequest(
new ShardId("test", "test", 0),
WriteRequest.RefreshPolicy.NONE,
new BulkItemRequest[0]
);
request.setInferenceFieldMap(
mockRequestIndexInferenceFieldMap(
request,
Map.of("foo", new InferenceFieldMetadata("foo", "bar", generateRandomStringArray(5, 10, false, false)))
);
Metadata metadata = mock(Metadata.class);
IndexMetadata indexMetadata = mock(IndexMetadata.class);

when(metadata.index(request.shardId().getIndex())).thenReturn(indexMetadata);
ClusterState clusterState = mock(ClusterState.class);
when(clusterState.metadata()).thenReturn(metadata);

inferenceProcessor.apply(request, actionListener, actionFilterChain);
awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
inferenceProcessor.apply(request, clusterState, actionListener);
awaitLatch(latch, 10, TimeUnit.SECONDS);
verify(actionListener).onResponse(eq(request));
}

@SuppressWarnings({ "unchecked", "rawtypes" })
public void testInferenceNotFound() throws Exception {
StaticModel model = StaticModel.createRandomInstance();
BulkShardOperationInferenceProcessor inferenceProcessor = createInfrenceProcessor(
BulkShardOperationInferenceProcessor inferenceProcessor = createInferenceProcessor(
threadPool,
Map.of(model.getInferenceEntityId(), model),
randomIntBetween(1, 10)
);
CountDownLatch chainExecuted = new CountDownLatch(1);
ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
try {
BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
assertTrue(bulkShardRequest.getInferenceFieldMap().isEmpty());
for (BulkItemRequest item : bulkShardRequest.items()) {
assertNotNull(item.getPrimaryResponse());
assertTrue(item.getPrimaryResponse().isFailed());
BulkItemResponse.Failure failure = item.getPrimaryResponse().getFailure();
assertThat(failure.getStatus(), equalTo(RestStatus.NOT_FOUND));
}
} finally {
chainExecuted.countDown();
}
};
ActionListener actionListener = mock(ActionListener.class);
Task task = mock(Task.class);

Map<String, InferenceFieldMetadata> inferenceFieldMap = Map.of(
"field1",
new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" }),
Expand All @@ -154,26 +136,46 @@ public void testInferenceNotFound() throws Exception {
items[i] = randomBulkItemRequest(Map.of(), inferenceFieldMap)[0];
}
BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items);
request.setInferenceFieldMap(inferenceFieldMap);
inferenceProcessor.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain);
awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
mockRequestIndexInferenceFieldMap(request, inferenceFieldMap);

CountDownLatch latch = new CountDownLatch(1);
ActionListener<BulkShardRequest> actionListener = mock(ActionListener.class);
doAnswer(invocation -> {
try {
BulkShardRequest bulkShardRequest = invocation.getArgument(0);
for (BulkItemRequest item : bulkShardRequest.items()) {
assertNotNull(item.getPrimaryResponse());
assertTrue(item.getPrimaryResponse().isFailed());
BulkItemResponse.Failure failure = item.getPrimaryResponse().getFailure();
assertThat(failure.getStatus(), equalTo(RestStatus.NOT_FOUND));
}
return null;
} finally {
latch.countDown();
}
}).when(actionListener).onResponse(any());

inferenceProcessor.apply(request, clusterState, actionListener);
awaitLatch(latch, 10, TimeUnit.SECONDS);
verify(actionListener).onResponse(any());
}

@SuppressWarnings({ "unchecked", "rawtypes" })
public void testItemFailures() throws Exception {
StaticModel model = StaticModel.createRandomInstance();
BulkShardOperationInferenceProcessor inferenceProcessor = createInfrenceProcessor(
model.putResult("I am a failure", new ErrorChunkedInferenceResults(new IllegalArgumentException("boom")));
model.putResult("I am a success", randomSparseEmbeddings(List.of("I am a success")));
BulkShardOperationInferenceProcessor inferenceProcessor = createInferenceProcessor(
threadPool,
Map.of(model.getInferenceEntityId(), model),
randomIntBetween(1, 10)
);
model.putResult("I am a failure", new ErrorChunkedInferenceResults(new IllegalArgumentException("boom")));
model.putResult("I am a success", randomSparseEmbeddings(List.of("I am a success")));
CountDownLatch chainExecuted = new CountDownLatch(1);
ActionFilterChain actionFilterChain = (task, action, request, listener) -> {

CountDownLatch latch = new CountDownLatch(1);
ActionListener<BulkShardRequest> actionListener = mock(ActionListener.class);
doAnswer(invocation -> {
try {
BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
assertTrue(bulkShardRequest.getInferenceFieldMap().isEmpty());
BulkShardRequest bulkShardRequest = invocation.getArgument(0);
assertThat(bulkShardRequest.items().length, equalTo(3));

// item 0 is a failure
Expand All @@ -192,13 +194,11 @@ public void testItemFailures() throws Exception {
assertTrue(bulkShardRequest.items()[2].getPrimaryResponse().isFailed());
failure = bulkShardRequest.items()[2].getPrimaryResponse().getFailure();
assertThat(failure.getCause().getCause().getMessage(), containsString("boom"));
return null;
} finally {
chainExecuted.countDown();
latch.countDown();
}
};
ActionListener actionListener = mock(ActionListener.class);
Task task = mock(Task.class);

}).when(actionListener).onResponse(any());
Map<String, InferenceFieldMetadata> inferenceFieldMap = Map.of(
"field1",
new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" })
Expand All @@ -208,9 +208,11 @@ public void testItemFailures() throws Exception {
items[1] = new BulkItemRequest(1, new IndexRequest("index").source("field1", "I am a success"));
items[2] = new BulkItemRequest(2, new IndexRequest("index").source("field1", "I am a failure"));
BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items);
request.setInferenceFieldMap(inferenceFieldMap);
inferenceProcessor.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain);
awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
mockRequestIndexInferenceFieldMap(request, inferenceFieldMap);

inferenceProcessor.apply(request, clusterState, actionListener);
awaitLatch(latch, 10, TimeUnit.SECONDS);
verify(actionListener).onResponse(any());
}

@SuppressWarnings({ "unchecked", "rawtypes" })
Expand Down Expand Up @@ -239,13 +241,16 @@ public void testManyRandomDocs() throws Exception {
modifiedRequests[id] = res[1];
}

BulkShardOperationInferenceProcessor inferenceProcessor = createInfrenceProcessor(threadPool, inferenceModelMap, randomIntBetween(10, 30));
CountDownLatch chainExecuted = new CountDownLatch(1);
ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
BulkShardOperationInferenceProcessor inferenceProcessor = createInferenceProcessor(
threadPool,
inferenceModelMap,
randomIntBetween(10, 30)
);
CountDownLatch latch = new CountDownLatch(1);
ActionListener<BulkShardRequest> actionListener = mock(ActionListener.class);
doAnswer(invocation -> {
try {
assertThat(request, instanceOf(BulkShardRequest.class));
BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
assertTrue(bulkShardRequest.getInferenceFieldMap().isEmpty());
BulkShardRequest bulkShardRequest = invocation.getArgument(0);
BulkItemRequest[] items = bulkShardRequest.items();
assertThat(items.length, equalTo(originalRequests.length));
for (int id = 0; id < items.length; id++) {
Expand All @@ -257,20 +262,19 @@ public void testManyRandomDocs() throws Exception {
throw new IllegalStateException(exc);
}
}
return null;
} finally {
chainExecuted.countDown();
latch.countDown();
}
};
ActionListener actionListener = mock(ActionListener.class);
Task task = mock(Task.class);
}).when(actionListener).onResponse(any());
BulkShardRequest original = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, originalRequests);
original.setInferenceFieldMap(inferenceFieldMap);
inferenceProcessor.apply(task, TransportShardBulkAction.ACTION_NAME, original, actionListener, actionFilterChain);
awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
mockRequestIndexInferenceFieldMap(original, inferenceFieldMap);
inferenceProcessor.apply(original, clusterState, actionListener);
awaitLatch(latch, 10, TimeUnit.SECONDS);
}

@SuppressWarnings("unchecked")
private static BulkShardOperationInferenceProcessor createInfrenceProcessor(
private static BulkShardOperationInferenceProcessor createInferenceProcessor(
ThreadPool threadPool,
Map<String, StaticModel> modelMap,
int batchSize
Expand Down

0 comments on commit f50d41f

Please sign in to comment.