diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle
index 48b6156a43039..a024d4d842e95 100644
--- a/x-pack/plugin/inference/build.gradle
+++ b/x-pack/plugin/inference/build.gradle
@@ -30,6 +30,7 @@ dependencies {
   compileOnly project(":server")
   compileOnly project(path: xpackModule('core'))
   testImplementation(testArtifact(project(xpackModule('core'))))
+  testImplementation(project(':x-pack:plugin:inference:qa:test-service-plugin'))
   testImplementation project(':modules:reindex')
 
   clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin')
diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java
index 1bde3704864d5..a65b8e43e6adf 100644
--- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java
+++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java
@@ -27,14 +27,6 @@
 
 public abstract class AbstractTestInferenceService implements InferenceService {
 
-    protected static int stringWeight(String input, int position) {
-        int hashCode = input.hashCode();
-        if (hashCode < 0) {
-            hashCode = -hashCode;
-        }
-        return hashCode + position;
-    }
-
     @Override
     public TransportVersion getMinimalSupportedVersion() {
         return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests
diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java
index a54b14d8fad18..b4d4bfa7bcfb5 100644
--- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java
+++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java
@@ -22,6 +22,7 @@
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
 import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
@@ -43,8 +44,22 @@ public List<Factory> getInferenceServiceFactories() {
         return List.of(TestInferenceService::new);
     }
 
+    public static class TestDenseModel extends Model {
+        public TestDenseModel(String inferenceEntityId, TestDenseInferenceServiceExtension.TestServiceSettings serviceSettings) {
+            super(
+                new ModelConfigurations(
+                    inferenceEntityId,
+                    TaskType.TEXT_EMBEDDING,
+                    TestDenseInferenceServiceExtension.TestInferenceService.NAME,
+                    serviceSettings
+                ),
+                new ModelSecrets(new AbstractTestInferenceService.TestSecretSettings("api_key"))
+            );
+        }
+    }
+
     public static class TestInferenceService extends AbstractTestInferenceService {
-        private static final String NAME = "text_embedding_test_service";
+        public static final String NAME = "text_embedding_test_service";
 
         public TestInferenceService(InferenceServiceFactoryContext context) {}
 
@@ -83,9 +98,10 @@ public void infer(
             ActionListener<InferenceServiceResults> listener
         ) {
             switch (model.getConfigurations().getTaskType()) {
-                case ANY, TEXT_EMBEDDING -> listener.onResponse(
-                    makeResults(input, ((TestServiceModel) model).getServiceSettings().dimensions())
-                );
+                case ANY, TEXT_EMBEDDING -> {
+                    ServiceSettings modelServiceSettings = model.getServiceSettings();
+                    listener.onResponse(makeResults(input, modelServiceSettings.dimensions(), modelServiceSettings.similarity()));
+                }
                 default -> listener.onFailure(
                     new ElasticsearchStatusException(
                         TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()),
@@ -107,9 +123,10 @@ public void chunkedInfer(
             ActionListener<List<ChunkedInferenceServiceResults>> listener
         ) {
             switch (model.getConfigurations().getTaskType()) {
-                case ANY, TEXT_EMBEDDING -> listener.onResponse(
-                    makeChunkedResults(input, ((TestServiceModel) model).getServiceSettings().dimensions())
-                );
+                case ANY, TEXT_EMBEDDING -> {
+                    ServiceSettings modelServiceSettings = model.getServiceSettings();
+                    listener.onResponse(makeChunkedResults(input, modelServiceSettings.dimensions(), modelServiceSettings.similarity()));
+                }
                 default -> listener.onFailure(
                     new ElasticsearchStatusException(
                         TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()),
@@ -119,28 +136,30 @@ public void chunkedInfer(
             }
         }
 
-        private TextEmbeddingResults makeResults(List<String> input, int dimensions) {
+        private TextEmbeddingResults makeResults(List<String> input, int dimensions, SimilarityMeasure similarityMeasure) {
             List<TextEmbeddingResults.Embedding> embeddings = new ArrayList<>();
             for (int i = 0; i < input.size(); i++) {
-                List<Float> values = new ArrayList<>();
+                double[] doubleEmbeddings = generateEmbedding(input.get(i), dimensions, similarityMeasure);
+                List<Float> floatEmbeddings = new ArrayList<>(dimensions);
                 for (int j = 0; j < dimensions; j++) {
-                    values.add((float) stringWeight(input.get(i), j));
+                    floatEmbeddings.add((float) doubleEmbeddings[j]);
                 }
-                embeddings.add(new TextEmbeddingResults.Embedding(values));
+                embeddings.add(new TextEmbeddingResults.Embedding(floatEmbeddings));
             }
             return new TextEmbeddingResults(embeddings);
         }
 
-        private List<ChunkedInferenceServiceResults> makeChunkedResults(List<String> input, int dimensions) {
+        private List<ChunkedInferenceServiceResults> makeChunkedResults(
+            List<String> input,
+            int dimensions,
+            SimilarityMeasure similarityMeasure
+        ) {
             var results = new ArrayList<ChunkedInferenceServiceResults>();
             for (int i = 0; i < input.size(); i++) {
-                double[] values = new double[dimensions];
-                for (int j = 0; j < dimensions; j++) {
-                    values[j] = stringWeight(input.get(i), j);
-                }
+                double[] embeddings = generateEmbedding(input.get(i), dimensions, similarityMeasure);
                 results.add(
                     new org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults(
-                        List.of(new ChunkedTextEmbeddingResults.EmbeddingChunk(input.get(i), values))
+                        List.of(new ChunkedTextEmbeddingResults.EmbeddingChunk(input.get(i), embeddings))
                     )
                 );
             }
@@ -150,6 +169,15 @@ private List<ChunkedInferenceServiceResults> makeChunkedResults(List<String> inp
         protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceSettingsMap) {
             return TestServiceSettings.fromMap(serviceSettingsMap);
         }
+
+        private static double[] generateEmbedding(String input, int dimensions, SimilarityMeasure similarityMeasure) {
+            double[] embedding = new double[dimensions];
+            for (int j = 0; j < dimensions; j++) {
+                embedding[j] = input.hashCode() + (double) j;
+            }
+
+            return embedding;
+        }
     }
 
     public record TestServiceSettings(String model, Integer dimensions, SimilarityMeasure similarity) implements ServiceSettings {
diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java
index 42b8ccd11a64b..d1632879355bc 100644
--- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java
+++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java
@@ -22,6 +22,7 @@
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
 import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;
@@ -44,8 +45,17 @@ public List<Factory> getInferenceServiceFactories() {
         return List.of(TestInferenceService::new);
     }
 
+    public static class TestSparseModel extends Model {
+        public TestSparseModel(String inferenceEntityId, TestServiceSettings serviceSettings) {
+            super(
+                new ModelConfigurations(inferenceEntityId, TaskType.SPARSE_EMBEDDING, TestInferenceService.NAME, serviceSettings),
+                new ModelSecrets(new AbstractTestInferenceService.TestSecretSettings("api_key"))
+            );
+        }
+    }
+
     public static class TestInferenceService extends AbstractTestInferenceService {
-        private static final String NAME = "test_service";
+        public static final String NAME = "test_service";
 
         public TestInferenceService(InferenceServiceExtension.InferenceServiceFactoryContext context) {}
 
@@ -121,7 +131,7 @@ private SparseEmbeddingResults makeResults(List<String> input) {
             for (int i = 0; i < input.size(); i++) {
                 var tokens = new ArrayList<SparseEmbeddingResults.WeightedToken>();
                 for (int j = 0; j < 5; j++) {
-                    tokens.add(new SparseEmbeddingResults.WeightedToken("feature_" + j, stringWeight(input.get(i), j)));
+                    tokens.add(new SparseEmbeddingResults.WeightedToken("feature_" + j, generateEmbedding(input.get(i), j)));
                 }
                 embeddings.add(new SparseEmbeddingResults.Embedding(tokens, false));
             }
@@ -133,7 +143,7 @@ private List<ChunkedInferenceServiceResults> makeChunkedResults(List<String> inp
             for (int i = 0; i < input.size(); i++) {
                 var tokens = new ArrayList<TextExpansionResults.WeightedToken>();
                 for (int j = 0; j < 5; j++) {
-                    tokens.add(new TextExpansionResults.WeightedToken("feature_" + j, stringWeight(input.get(i), j)));
+                    tokens.add(new TextExpansionResults.WeightedToken("feature_" + j, generateEmbedding(input.get(i), j)));
                 }
                 results.add(
                     new ChunkedSparseEmbeddingResults(List.of(new ChunkedTextExpansionResults.ChunkedResult(input.get(i), tokens)))
                 );
@@ -145,6 +155,11 @@ private List<ChunkedInferenceServiceResults> makeChunkedResults(List<String> inp
         protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceSettingsMap) {
             return TestServiceSettings.fromMap(serviceSettingsMap);
         }
+
+        private static float generateEmbedding(String input, int position) {
+            // Ensure non-negative and non-zero values for features
+            return Math.abs(input.hashCode()) + 1 + position;
+        }
     }
 
     public record TestServiceSettings(String model, String hiddenField, boolean shouldReturnHiddenField) implements ServiceSettings {
diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java
new file mode 100644
index 0000000000000..6b4b658e23285
--- /dev/null
+++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java
@@ -0,0 +1,198 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.action.filter;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
+import org.elasticsearch.action.bulk.BulkItemResponse;
+import org.elasticsearch.action.bulk.BulkRequestBuilder;
+import org.elasticsearch.action.bulk.BulkResponse;
+import org.elasticsearch.action.index.IndexRequestBuilder;
+import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.action.update.UpdateRequestBuilder;
+import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.inference.InferenceServiceExtension;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.test.ESIntegTestCase;
+import org.elasticsearch.xpack.inference.InferencePlugin;
+import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
+import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
+import org.elasticsearch.xpack.inference.registry.ModelRegistry;
+import org.junit.Before;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Consumer;
+
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.nullValue;
+
+public class ShardBulkInferenceActionFilterIT extends ESIntegTestCase {
+
+    public static final String INDEX_NAME = "test-index";
+
+    @Before
+    public void setup() throws Exception {
+        storeSparseModel();
+        storeDenseModel();
+    }
+
+    @Override
+    protected Collection<Class<? extends Plugin>> nodePlugins() {
+        return Arrays.asList(TestInferencePlugin.class);
+    }
+
+    public void testBulkOperations() throws Exception {
+        Map<String, Object> shardsSettings = Collections.singletonMap(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 10));
+        indicesAdmin().prepareCreate(INDEX_NAME).setMapping("""
+            {
+                "properties": {
+                    "sparse_field": {
+                        "type": "semantic_text",
+                        "inference_id": "test_service"
+                    },
+                    "dense_field": {
+                        "type": "semantic_text",
+                        "inference_id": "text_embedding_test_service"
+                    }
+                }
+            }
+            """).setSettings(shardsSettings).get();
+
+        int totalBulkReqs = randomIntBetween(2, 100);
+        long totalDocs = 0;
+        for (int bulkReqs = 0; bulkReqs < totalBulkReqs; bulkReqs++) {
+            BulkRequestBuilder bulkReqBuilder = client().prepareBulk();
+            int totalBulkSize = randomIntBetween(1, 100);
+            for (int bulkSize = 0; bulkSize < totalBulkSize; bulkSize++) {
+                String id = Long.toString(totalDocs);
+                boolean isIndexRequest = randomBoolean();
+                Map<String, Object> source = new HashMap<>();
+                source.put("sparse_field", isIndexRequest && rarely() ? null : randomAlphaOfLengthBetween(0, 1000));
+                source.put("dense_field", isIndexRequest && rarely() ? null : randomAlphaOfLengthBetween(0, 1000));
+                if (isIndexRequest) {
+                    bulkReqBuilder.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(id).setSource(source));
+                    totalDocs++;
+                } else {
+                    boolean isUpsert = randomBoolean();
+                    UpdateRequestBuilder request = new UpdateRequestBuilder(client()).setIndex(INDEX_NAME).setDoc(source);
+                    if (isUpsert || totalDocs == 0) {
+                        request.setDocAsUpsert(true);
+                        totalDocs++;
+                    } else {
+                        // Update already existing document
+                        id = Long.toString(randomLongBetween(0, totalDocs - 1));
+                    }
+                    request.setId(id);
+                    bulkReqBuilder.add(request);
+                }
+            }
+            BulkResponse bulkResponse = bulkReqBuilder.get();
+            if (bulkResponse.hasFailures()) {
+                // Get more details in case something fails
+                for (BulkItemResponse bulkItemResponse : bulkResponse.getItems()) {
+                    if (bulkItemResponse.isFailed()) {
+                        fail(
+                            bulkItemResponse.getFailure().getCause(),
+                            "Failed to index document %s: %s",
+                            bulkItemResponse.getId(),
+                            bulkItemResponse.getFailureMessage()
+                        );
+                    }
+                }
+            }
+            assertFalse(bulkResponse.hasFailures());
+        }
+
+        client().admin().indices().refresh(new RefreshRequest(INDEX_NAME)).get();
+
+        SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().size(0).trackTotalHits(true);
+        SearchResponse searchResponse = client().search(new SearchRequest(INDEX_NAME).source(sourceBuilder)).get();
+        assertThat(searchResponse.getHits().getTotalHits().value, equalTo(totalDocs));
+        searchResponse.decRef();
+    }
+
+    private void storeSparseModel() throws Exception {
+        Model model = new TestSparseInferenceServiceExtension.TestSparseModel(
+            TestSparseInferenceServiceExtension.TestInferenceService.NAME,
+            new TestSparseInferenceServiceExtension.TestServiceSettings(
+                TestSparseInferenceServiceExtension.TestInferenceService.NAME,
+                null,
+                false
+            )
+        );
+        storeModel(model);
+    }
+
+    private void storeDenseModel() throws Exception {
+        Model model = new TestDenseInferenceServiceExtension.TestDenseModel(
+            TestDenseInferenceServiceExtension.TestInferenceService.NAME,
+            new TestDenseInferenceServiceExtension.TestServiceSettings(
+                TestDenseInferenceServiceExtension.TestInferenceService.NAME,
+                randomIntBetween(1, 100),
+                // dot product means that we need normalized vectors; it's not worth doing that in this test
+                randomValueOtherThan(SimilarityMeasure.DOT_PRODUCT, () -> randomFrom(SimilarityMeasure.values()))
+            )
+        );
+
+        storeModel(model);
+    }
+
+    private void storeModel(Model model) throws Exception {
+        ModelRegistry modelRegistry = new ModelRegistry(client());
+
+        AtomicReference<Boolean> storeModelHolder = new AtomicReference<>();
+        AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
+
+        blockingCall(listener -> modelRegistry.storeModel(model, listener), storeModelHolder, exceptionHolder);
+
+        assertThat(storeModelHolder.get(), is(true));
+        assertThat(exceptionHolder.get(), is(nullValue()));
+    }
+
+    private <T> void blockingCall(Consumer<ActionListener<T>> function, AtomicReference<T> response, AtomicReference<Exception> error)
+        throws InterruptedException {
+        CountDownLatch latch = new CountDownLatch(1);
+        ActionListener<T> listener = ActionListener.wrap(r -> {
+            response.set(r);
+            latch.countDown();
+        }, e -> {
+            error.set(e);
+            latch.countDown();
+        });
+
+        function.accept(listener);
+        latch.await();
+    }
+
+    public static class TestInferencePlugin extends InferencePlugin {
+        public TestInferencePlugin(Settings settings) {
+            super(settings);
+        }
+
+        @Override
+        public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
+            return List.of(
+                TestSparseInferenceServiceExtension.TestInferenceService::new,
+                TestDenseInferenceServiceExtension.TestInferenceService::new
+            );
+        }
+    }
+}
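
Note: both mock services in this change replace the shared stringWeight helper with their own generateEmbedding methods, so the values they return stay deterministic per input string: the dense service adds the component index to the input's hashCode, and the sparse service uses Math.abs(hashCode) + 1 + position to keep feature weights non-negative and non-zero. The standalone sketch below mirrors that arithmetic outside the Elasticsearch test framework; the class and method names are illustrative only and are not part of the patch.

import java.util.Arrays;

public class MockEmbeddingSketch {

    // Mirrors TestDenseInferenceServiceExtension.generateEmbedding: each component is the
    // input's hashCode plus the component index, so equal inputs yield equal vectors.
    static double[] denseEmbedding(String input, int dimensions) {
        double[] embedding = new double[dimensions];
        for (int j = 0; j < dimensions; j++) {
            embedding[j] = input.hashCode() + (double) j;
        }
        return embedding;
    }

    // Mirrors TestSparseInferenceServiceExtension.generateEmbedding: weights are kept
    // non-negative and non-zero, as the comment in the patch notes.
    static float sparseWeight(String input, int position) {
        return Math.abs(input.hashCode()) + 1 + position;
    }

    public static void main(String[] args) {
        String text = "a test document";
        System.out.println("dense: " + Arrays.toString(denseEmbedding(text, 4)));
        for (int j = 0; j < 5; j++) {
            System.out.println("feature_" + j + " -> " + sparseWeight(text, j));
        }
    }
}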