diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java index 5eb99085570a3..a65b8e43e6adf 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java @@ -18,7 +18,6 @@ import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.SecretSettings; import org.elasticsearch.inference.ServiceSettings; -import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.XContentBuilder; diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index 068aef404131a..b4d4bfa7bcfb5 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -100,9 +100,7 @@ public void infer( switch (model.getConfigurations().getTaskType()) { case ANY, TEXT_EMBEDDING -> { ServiceSettings modelServiceSettings = model.getServiceSettings(); - listener.onResponse( - makeResults(input, modelServiceSettings.dimensions(), modelServiceSettings.similarity()) - ); + listener.onResponse(makeResults(input, modelServiceSettings.dimensions(), modelServiceSettings.similarity())); } default -> listener.onFailure( new ElasticsearchStatusException( @@ -175,7 +173,7 @@ protected ServiceSettings getServiceSettingsFromMap(Map serviceS private static double[] generateEmbedding(String input, int dimensions, SimilarityMeasure similarityMeasure) { double[] embedding = new double[dimensions]; for (int j = 0; j < dimensions; j++) { - embedding[j] = input.hashCode() + (double)j; + embedding[j] = input.hashCode() + (double) j; } return embedding; diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java index fb3e868de4a86..6b4b658e23285 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java @@ -130,34 +130,34 @@ public void testBulkOperations() throws Exception { } private void storeSparseModel() throws Exception { - ModelRegistry modelRegistry = new ModelRegistry(client()); - - String inferenceEntityId = TestSparseInferenceServiceExtension.TestInferenceService.NAME; Model model = new TestSparseInferenceServiceExtension.TestSparseModel( - inferenceEntityId, - new TestSparseInferenceServiceExtension.TestServiceSettings(inferenceEntityId, null, false) + TestSparseInferenceServiceExtension.TestInferenceService.NAME, + new TestSparseInferenceServiceExtension.TestServiceSettings( + TestSparseInferenceServiceExtension.TestInferenceService.NAME, + null, + false + ) ); - AtomicReference storeModelHolder = new AtomicReference<>(); - AtomicReference exceptionHolder = new AtomicReference<>(); - - blockingCall(listener -> modelRegistry.storeModel(model, listener), storeModelHolder, exceptionHolder); - - assertThat(storeModelHolder.get(), is(true)); - assertThat(exceptionHolder.get(), is(nullValue())); + storeModel(model); } private void storeDenseModel() throws Exception { - ModelRegistry modelRegistry = new ModelRegistry(client()); - - String inferenceEntityId = TestDenseInferenceServiceExtension.TestInferenceService.NAME; Model model = new TestDenseInferenceServiceExtension.TestDenseModel( - inferenceEntityId, + TestDenseInferenceServiceExtension.TestInferenceService.NAME, new TestDenseInferenceServiceExtension.TestServiceSettings( - inferenceEntityId, + TestDenseInferenceServiceExtension.TestInferenceService.NAME, randomIntBetween(1, 100), + // dot product means that we need normalized vectors; it's not worth doing that in this test randomValueOtherThan(SimilarityMeasure.DOT_PRODUCT, () -> randomFrom(SimilarityMeasure.values())) ) ); + + storeModel(model); + } + + private void storeModel(Model model) throws Exception { + ModelRegistry modelRegistry = new ModelRegistry(client()); + AtomicReference storeModelHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>();