diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/bulk/BulkShardOperationInferenceProcessorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/bulk/BulkShardOperationInferenceProcessorTests.java similarity index 81% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/bulk/BulkShardOperationInferenceProcessorTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/bulk/BulkShardOperationInferenceProcessorTests.java index 06a57ee7835c7..4de23682dae69 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/bulk/BulkShardOperationInferenceProcessorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/bulk/BulkShardOperationInferenceProcessorTests.java @@ -5,18 +5,15 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.bulk; +package org.elasticsearch.xpack.inference.action.bulk; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.bulk.BulkItemRequest; import org.elasticsearch.action.bulk.BulkItemResponse; import org.elasticsearch.action.bulk.BulkShardRequest; -import org.elasticsearch.action.bulk.TransportShardBulkAction; import org.elasticsearch.action.index.IndexRequest; -import org.elasticsearch.action.support.ActionFilterChain; import org.elasticsearch.action.support.WriteRequest; -import org.elasticsearch.action.support.replication.ClusterStateCreationUtils; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; @@ -31,7 +28,6 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.tasks.Task; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; @@ -39,7 +35,6 @@ import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.inference.action.bulk.BulkShardOperationInferenceProcessor; import org.elasticsearch.xpack.inference.model.TestModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.junit.After; @@ -64,18 +59,31 @@ import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.toChunkedResult; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class BulkShardOperationInferenceProcessorTests extends ESTestCase { private ThreadPool threadPool; + private ClusterState clusterState; + private Metadata metadata; @Before - public void setupThreadPool() { + public void setup() { threadPool = new TestThreadPool(getTestName()); + + metadata = mock(Metadata.class); + clusterState = mock(ClusterState.class); + when(clusterState.metadata()).thenReturn(metadata); + } + + private void mockRequestIndexInferenceFieldMap(BulkShardRequest request, Map inferenceFieldMap) { + IndexMetadata indexMetadata = mock(IndexMetadata.class); + when(metadata.index(request.shardId().getIndexName())).thenReturn(indexMetadata); + when(indexMetadata.getInferenceFields()).thenReturn(inferenceFieldMap); } @After @@ -84,63 +92,37 @@ public void tearDownThreadPool() throws Exception { } @SuppressWarnings({ "unchecked", "rawtypes" }) - public void testFilterNoop() throws Exception { - BulkShardOperationInferenceProcessor inferenceProcessor = createInfrenceProcessor(threadPool, Map.of(), DEFAULT_BATCH_SIZE); - CountDownLatch chainExecuted = new CountDownLatch(1); - ActionFilterChain actionFilterChain = (task, action, request, listener) -> { - try { - assertTrue(((BulkShardRequest) request).getInferenceFieldMap().isEmpty()); - } finally { - chainExecuted.countDown(); - } - }; - ActionListener actionListener = mock(ActionListener.class); - Task task = mock(Task.class); + public void testNoop() throws Exception { + BulkShardOperationInferenceProcessor inferenceProcessor = createInferenceProcessor(threadPool, Map.of(), DEFAULT_BATCH_SIZE); + CountDownLatch latch = new CountDownLatch(1); + ActionListener actionListener = mock(ActionListener.class); + doAnswer(invocation -> { + latch.countDown(); + return null; + }).when(actionListener).onResponse(any()); BulkShardRequest request = new BulkShardRequest( new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, new BulkItemRequest[0] ); - request.setInferenceFieldMap( + mockRequestIndexInferenceFieldMap( + request, Map.of("foo", new InferenceFieldMetadata("foo", "bar", generateRandomStringArray(5, 10, false, false))) ); - Metadata metadata = mock(Metadata.class); - IndexMetadata indexMetadata = mock(IndexMetadata.class); - when(metadata.index(request.shardId().getIndex())).thenReturn(indexMetadata); - ClusterState clusterState = mock(ClusterState.class); - when(clusterState.metadata()).thenReturn(metadata); - - inferenceProcessor.apply(request, actionListener, actionFilterChain); - awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + inferenceProcessor.apply(request, clusterState, actionListener); + awaitLatch(latch, 10, TimeUnit.SECONDS); + verify(actionListener).onResponse(eq(request)); } @SuppressWarnings({ "unchecked", "rawtypes" }) public void testInferenceNotFound() throws Exception { StaticModel model = StaticModel.createRandomInstance(); - BulkShardOperationInferenceProcessor inferenceProcessor = createInfrenceProcessor( + BulkShardOperationInferenceProcessor inferenceProcessor = createInferenceProcessor( threadPool, Map.of(model.getInferenceEntityId(), model), randomIntBetween(1, 10) ); - CountDownLatch chainExecuted = new CountDownLatch(1); - ActionFilterChain actionFilterChain = (task, action, request, listener) -> { - try { - BulkShardRequest bulkShardRequest = (BulkShardRequest) request; - assertTrue(bulkShardRequest.getInferenceFieldMap().isEmpty()); - for (BulkItemRequest item : bulkShardRequest.items()) { - assertNotNull(item.getPrimaryResponse()); - assertTrue(item.getPrimaryResponse().isFailed()); - BulkItemResponse.Failure failure = item.getPrimaryResponse().getFailure(); - assertThat(failure.getStatus(), equalTo(RestStatus.NOT_FOUND)); - } - } finally { - chainExecuted.countDown(); - } - }; - ActionListener actionListener = mock(ActionListener.class); - Task task = mock(Task.class); - Map inferenceFieldMap = Map.of( "field1", new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" }), @@ -154,26 +136,46 @@ public void testInferenceNotFound() throws Exception { items[i] = randomBulkItemRequest(Map.of(), inferenceFieldMap)[0]; } BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items); - request.setInferenceFieldMap(inferenceFieldMap); - inferenceProcessor.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); - awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + mockRequestIndexInferenceFieldMap(request, inferenceFieldMap); + + CountDownLatch latch = new CountDownLatch(1); + ActionListener actionListener = mock(ActionListener.class); + doAnswer(invocation -> { + try { + BulkShardRequest bulkShardRequest = invocation.getArgument(0); + for (BulkItemRequest item : bulkShardRequest.items()) { + assertNotNull(item.getPrimaryResponse()); + assertTrue(item.getPrimaryResponse().isFailed()); + BulkItemResponse.Failure failure = item.getPrimaryResponse().getFailure(); + assertThat(failure.getStatus(), equalTo(RestStatus.NOT_FOUND)); + } + return null; + } finally { + latch.countDown(); + } + }).when(actionListener).onResponse(any()); + + inferenceProcessor.apply(request, clusterState, actionListener); + awaitLatch(latch, 10, TimeUnit.SECONDS); + verify(actionListener).onResponse(any()); } @SuppressWarnings({ "unchecked", "rawtypes" }) public void testItemFailures() throws Exception { StaticModel model = StaticModel.createRandomInstance(); - BulkShardOperationInferenceProcessor inferenceProcessor = createInfrenceProcessor( + model.putResult("I am a failure", new ErrorChunkedInferenceResults(new IllegalArgumentException("boom"))); + model.putResult("I am a success", randomSparseEmbeddings(List.of("I am a success"))); + BulkShardOperationInferenceProcessor inferenceProcessor = createInferenceProcessor( threadPool, Map.of(model.getInferenceEntityId(), model), randomIntBetween(1, 10) ); - model.putResult("I am a failure", new ErrorChunkedInferenceResults(new IllegalArgumentException("boom"))); - model.putResult("I am a success", randomSparseEmbeddings(List.of("I am a success"))); - CountDownLatch chainExecuted = new CountDownLatch(1); - ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + + CountDownLatch latch = new CountDownLatch(1); + ActionListener actionListener = mock(ActionListener.class); + doAnswer(invocation -> { try { - BulkShardRequest bulkShardRequest = (BulkShardRequest) request; - assertTrue(bulkShardRequest.getInferenceFieldMap().isEmpty()); + BulkShardRequest bulkShardRequest = invocation.getArgument(0); assertThat(bulkShardRequest.items().length, equalTo(3)); // item 0 is a failure @@ -192,13 +194,11 @@ public void testItemFailures() throws Exception { assertTrue(bulkShardRequest.items()[2].getPrimaryResponse().isFailed()); failure = bulkShardRequest.items()[2].getPrimaryResponse().getFailure(); assertThat(failure.getCause().getCause().getMessage(), containsString("boom")); + return null; } finally { - chainExecuted.countDown(); + latch.countDown(); } - }; - ActionListener actionListener = mock(ActionListener.class); - Task task = mock(Task.class); - + }).when(actionListener).onResponse(any()); Map inferenceFieldMap = Map.of( "field1", new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" }) @@ -208,9 +208,11 @@ public void testItemFailures() throws Exception { items[1] = new BulkItemRequest(1, new IndexRequest("index").source("field1", "I am a success")); items[2] = new BulkItemRequest(2, new IndexRequest("index").source("field1", "I am a failure")); BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items); - request.setInferenceFieldMap(inferenceFieldMap); - inferenceProcessor.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); - awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + mockRequestIndexInferenceFieldMap(request, inferenceFieldMap); + + inferenceProcessor.apply(request, clusterState, actionListener); + awaitLatch(latch, 10, TimeUnit.SECONDS); + verify(actionListener).onResponse(any()); } @SuppressWarnings({ "unchecked", "rawtypes" }) @@ -239,13 +241,16 @@ public void testManyRandomDocs() throws Exception { modifiedRequests[id] = res[1]; } - BulkShardOperationInferenceProcessor inferenceProcessor = createInfrenceProcessor(threadPool, inferenceModelMap, randomIntBetween(10, 30)); - CountDownLatch chainExecuted = new CountDownLatch(1); - ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + BulkShardOperationInferenceProcessor inferenceProcessor = createInferenceProcessor( + threadPool, + inferenceModelMap, + randomIntBetween(10, 30) + ); + CountDownLatch latch = new CountDownLatch(1); + ActionListener actionListener = mock(ActionListener.class); + doAnswer(invocation -> { try { - assertThat(request, instanceOf(BulkShardRequest.class)); - BulkShardRequest bulkShardRequest = (BulkShardRequest) request; - assertTrue(bulkShardRequest.getInferenceFieldMap().isEmpty()); + BulkShardRequest bulkShardRequest = invocation.getArgument(0); BulkItemRequest[] items = bulkShardRequest.items(); assertThat(items.length, equalTo(originalRequests.length)); for (int id = 0; id < items.length; id++) { @@ -257,20 +262,19 @@ public void testManyRandomDocs() throws Exception { throw new IllegalStateException(exc); } } + return null; } finally { - chainExecuted.countDown(); + latch.countDown(); } - }; - ActionListener actionListener = mock(ActionListener.class); - Task task = mock(Task.class); + }).when(actionListener).onResponse(any()); BulkShardRequest original = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, originalRequests); - original.setInferenceFieldMap(inferenceFieldMap); - inferenceProcessor.apply(task, TransportShardBulkAction.ACTION_NAME, original, actionListener, actionFilterChain); - awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + mockRequestIndexInferenceFieldMap(original, inferenceFieldMap); + inferenceProcessor.apply(original, clusterState, actionListener); + awaitLatch(latch, 10, TimeUnit.SECONDS); } @SuppressWarnings("unchecked") - private static BulkShardOperationInferenceProcessor createInfrenceProcessor( + private static BulkShardOperationInferenceProcessor createInferenceProcessor( ThreadPool threadPool, Map modelMap, int batchSize