From 071e7ce620baf3a3751855f29eaacc82463e3ca8 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Tue, 1 Oct 2024 15:23:09 +0100 Subject: [PATCH] [ML] Move code specific to the Elasticsearch in cluster services to those services (#113749) Remove the platform arch argument from parseRequest and move code used by internal services out of the transport action into the service. --- muted-tests.yml | 9 +- .../inference/InferenceService.java | 29 +---- .../inference/InferenceServiceExtension.java | 3 +- .../inference/InferenceBaseRestTest.java | 5 +- .../TestDenseInferenceServiceExtension.java | 2 - .../mock/TestRerankingServiceExtension.java | 2 - .../TestSparseInferenceServiceExtension.java | 2 - ...stStreamingCompletionServiceExtension.java | 1 - .../integration/ModelRegistryIT.java | 5 +- .../xpack/inference/InferencePlugin.java | 22 ++-- .../TransportPutInferenceModelAction.java | 80 ++----------- .../AlibabaCloudSearchService.java | 2 - .../amazonbedrock/AmazonBedrockService.java | 2 - .../services/anthropic/AnthropicService.java | 2 - .../azureaistudio/AzureAiStudioService.java | 2 - .../azureopenai/AzureOpenAiService.java | 2 - .../services/cohere/CohereService.java | 2 - .../elastic/ElasticInferenceService.java | 2 - .../BaseElasticsearchInternalService.java | 106 +++++++++++++----- .../ElasticsearchInternalModel.java | 6 + .../ElasticsearchInternalService.java | 16 ++- .../services/elser/ElserInternalService.java | 52 ++++++--- .../googleaistudio/GoogleAiStudioService.java | 2 - .../googlevertexai/GoogleVertexAiService.java | 2 - .../huggingface/HuggingFaceBaseService.java | 2 - .../ibmwatsonx/IbmWatsonxService.java | 2 - .../services/mistral/MistralService.java | 2 - .../services/openai/OpenAiService.java | 2 - .../services/SenderServiceTests.java | 2 - .../AlibabaCloudSearchServiceTests.java | 2 - .../AmazonBedrockServiceTests.java | 15 +-- .../anthropic/AnthropicServiceTests.java | 11 +- .../AzureAiStudioServiceTests.java | 27 ++--- 
.../azureopenai/AzureOpenAiServiceTests.java | 12 +- .../services/cohere/CohereServiceTests.java | 13 +-- .../elastic/ElasticInferenceServiceTests.java | 11 +- .../ElasticsearchInternalServiceTests.java | 64 ++++++++--- .../elser/ElserInternalServiceTests.java | 35 +++--- .../GoogleAiStudioServiceTests.java | 12 +- .../GoogleVertexAiServiceTests.java | 12 +- .../huggingface/HuggingFaceServiceTests.java | 9 +- .../ibmwatsonx/IbmWatsonxServiceTests.java | 5 +- .../services/mistral/MistralServiceTests.java | 10 +- .../services/openai/OpenAiServiceTests.java | 18 +-- 44 files changed, 294 insertions(+), 330 deletions(-) diff --git a/muted-tests.yml b/muted-tests.yml index 26f51e6fdef21..5ff8254328793 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -278,12 +278,9 @@ tests: - class: org.elasticsearch.xpack.ml.integration.MlJobIT method: testCreateJobsWithIndexNameOption issue: https://github.com/elastic/elasticsearch/issues/113528 -- class: org.elasticsearch.xpack.inference.TextEmbeddingCrudIT - method: testPutE5WithTrainedModelAndInference - issue: https://github.com/elastic/elasticsearch/issues/113565 -- class: org.elasticsearch.xpack.inference.TextEmbeddingCrudIT - method: testPutE5Small_withPlatformAgnosticVariant - issue: https://github.com/elastic/elasticsearch/issues/113577 +- class: org.elasticsearch.validation.DotPrefixClientYamlTestSuiteIT + method: test {p0=dot_prefix/10_basic/Deprecated index template with a dot prefix index pattern} + issue: https://github.com/elastic/elasticsearch/issues/113529 - class: org.elasticsearch.xpack.ml.integration.MlJobIT method: testCantCreateJobWithSameID issue: https://github.com/elastic/elasticsearch/issues/113581 diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index 854c58b4f57ad..aba644b392cec 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ 
b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -39,17 +39,9 @@ default void init(Client client) {} * @param modelId Model Id * @param taskType The model task type * @param config Configuration options including the secrets - * @param platformArchitectures The Set of platform architectures (OS name and hardware architecture) - * the cluster nodes and models are running on. * @param parsedModelListener A listener which will handle the resulting model or failure */ - void parseRequestConfig( - String modelId, - TaskType taskType, - Map config, - Set platformArchitectures, - ActionListener parsedModelListener - ); + void parseRequestConfig(String modelId, TaskType taskType, Map config, ActionListener parsedModelListener); /** * Parse model configuration from {@code config map} from persisted storage and return the parsed {@link Model}. This requires that @@ -155,17 +147,6 @@ default void putModel(Model modelVariant, ActionListener listener) { listener.onResponse(true); } - /** - * Checks if the modelId has been downloaded to the local Elasticsearch cluster using the trained models API - * The default action does nothing except acknowledge the request (false). - * Any internal services should Override this method. - * @param model - * @param listener The listener - */ - default void isModelDownloaded(Model model, ActionListener listener) { - listener.onResponse(false); - }; - /** * Optionally test the new model configuration in the inference service. 
* This function should be called when the model is first created, the @@ -188,14 +169,6 @@ default Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { return model; } - /** - * Return true if this model is hosted in the local Elasticsearch cluster - * @return True if in cluster - */ - default boolean isInClusterService() { - return false; - } - /** * Defines the version required across all clusters to use this service * @return {@link TransportVersion} specifying the version diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceExtension.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceExtension.java index b0502bb5b37fd..68dc865b4c7db 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceExtension.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceExtension.java @@ -10,6 +10,7 @@ package org.elasticsearch.inference; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.threadpool.ThreadPool; import java.util.List; @@ -20,7 +21,7 @@ public interface InferenceServiceExtension { List getInferenceServiceFactories(); - record InferenceServiceFactoryContext(Client client) {} + record InferenceServiceFactoryContext(Client client, ThreadPool threadPool) {} interface Factory { /** diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index c19cd916055d3..55a6292ecd165 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -53,7 +53,10 @@ protected String getTestRestCluster() { 
@Override protected Settings restClientSettings() { String token = basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray())); - return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build(); + return Settings.builder() + .put(ThreadContext.PREFIX + ".Authorization", token) + .put(CLIENT_SOCKET_TIMEOUT, "120s") // Long timeout for model download + .build(); } static String mockSparseServiceModelConfig() { diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index daa29d33699ef..cd9a773f49f44 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -38,7 +38,6 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.Set; public class TestDenseInferenceServiceExtension implements InferenceServiceExtension { @Override @@ -76,7 +75,6 @@ public void parseRequestConfig( String modelId, TaskType taskType, Map config, - Set platformArchitectures, ActionListener parsedModelListener ) { var serviceSettingsMap = (Map) config.remove(ModelConfigurations.SERVICE_SETTINGS); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java index 1894db6db8df6..d8ee70986a57d 100644 --- 
a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java @@ -34,7 +34,6 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.Set; public class TestRerankingServiceExtension implements InferenceServiceExtension { @Override @@ -67,7 +66,6 @@ public void parseRequestConfig( String modelId, TaskType taskType, Map config, - Set platformArchitectures, ActionListener parsedModelListener ) { var serviceSettingsMap = (Map) config.remove(ModelConfigurations.SERVICE_SETTINGS); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index 1a5df146a3aa4..6eb0caad36261 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -37,7 +37,6 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.Set; public class TestSparseInferenceServiceExtension implements InferenceServiceExtension { @Override @@ -70,7 +69,6 @@ public void parseRequestConfig( String modelId, TaskType taskType, Map config, - Set platformArchitectures, ActionListener parsedModelListener ) { var serviceSettingsMap = (Map) config.remove(ModelConfigurations.SERVICE_SETTINGS); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java 
b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index 4313026e92521..206aa1f3e5d28 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -67,7 +67,6 @@ public void parseRequestConfig( String modelId, TaskType taskType, Map config, - Set platformArchitectures, ActionListener parsedModelListener ) { var serviceSettingsMap = (Map) config.remove(ModelConfigurations.SERVICE_SETTINGS); diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index 5157683f2dce9..524cd5014c19e 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -23,6 +23,7 @@ import org.elasticsearch.plugins.Plugin; import org.elasticsearch.reindex.ReindexPlugin; import org.elasticsearch.test.ESSingleNodeTestCase; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.InferencePlugin; @@ -117,7 +118,9 @@ public void testGetModel() throws Exception { assertEquals(model.getConfigurations().getService(), modelHolder.get().service()); - var elserService = new ElserInternalService(new InferenceServiceExtension.InferenceServiceFactoryContext(mock(Client.class))); + var elserService = new ElserInternalService( + new 
InferenceServiceExtension.InferenceServiceFactoryContext(mock(Client.class), mock(ThreadPool.class)) + ); ElserInternalModel roundTripModel = elserService.parsePersistedConfigWithSecrets( modelHolder.get().inferenceEntityId(), modelHolder.get().taskType(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 16bd0942c6c26..f2f019490444e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -206,7 +206,7 @@ public Collection createComponents(PluginServices services) { ); } - var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext(services.client()); + var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext(services.client(), services.threadPool()); // This must be done after the HttpRequestSenderFactory is created so that the services can get the // reference correctly var registry = new InferenceServiceRegistry(inferenceServices, factoryContext); @@ -299,15 +299,17 @@ public Collection getSystemIndexDescriptors(Settings sett @Override public List> getExecutorBuilders(Settings settingsToUse) { - return List.of( - new ScalingExecutorBuilder( - UTILITY_THREAD_POOL_NAME, - 0, - 10, - TimeValue.timeValueMinutes(10), - false, - "xpack.inference.utility_thread_pool" - ) + return List.of(inferenceUtilityExecutor(settings)); + } + + public static ExecutorBuilder inferenceUtilityExecutor(Settings settings) { + return new ScalingExecutorBuilder( + UTILITY_THREAD_POOL_NAME, + 0, + 10, + TimeValue.timeValueMinutes(10), + false, + "xpack.inference.utility_thread_pool" ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index ec54294432fe8..efd4d4dfd19d7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -12,7 +12,6 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; -import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.action.support.master.TransportMasterNodeAction; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; @@ -20,7 +19,6 @@ import org.elasticsearch.cluster.block.ClusterBlockLevel; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.xcontent.XContentHelper; @@ -38,17 +36,14 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; -import org.elasticsearch.xpack.core.ml.MachineLearningField; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentUtils; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; -import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.io.IOException; import java.util.Map; -import 
java.util.Set; import static org.elasticsearch.core.Strings.format; @@ -156,36 +151,7 @@ protected void masterOperation( return; } - if (service.get().isInClusterService()) { - // Find the cluster platform as the service may need that - // information when creating the model - MlPlatformArchitecturesUtil.getMlNodesArchitecturesSet(listener.delegateFailureAndWrap((delegate, architectures) -> { - if (architectures.isEmpty() && clusterIsInElasticCloud(clusterService.getClusterSettings())) { - parseAndStoreModel( - service.get(), - request.getInferenceEntityId(), - resolvedTaskType, - requestAsMap, - // In Elastic cloud ml nodes run on Linux x86 - Set.of("linux-x86_64"), - delegate - ); - } else { - // The architecture field could be an empty set, the individual services will need to handle that - parseAndStoreModel( - service.get(), - request.getInferenceEntityId(), - resolvedTaskType, - requestAsMap, - architectures, - delegate - ); - } - }), client, threadPool.executor(InferencePlugin.UTILITY_THREAD_POOL_NAME)); - } else { - // Not an in cluster service, it does not care about the cluster platform - parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, Set.of(), listener); - } + parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, listener); } private void parseAndStoreModel( @@ -193,13 +159,12 @@ private void parseAndStoreModel( String inferenceEntityId, TaskType taskType, Map config, - Set platformArchitectures, ActionListener listener ) { ActionListener storeModelListener = listener.delegateFailureAndWrap( (delegate, verifiedModel) -> modelRegistry.storeModel( verifiedModel, - ActionListener.wrap(r -> putAndStartModel(service, verifiedModel, delegate), e -> { + ActionListener.wrap(r -> startInferenceEndpoint(service, verifiedModel, delegate), e -> { if (e.getCause() instanceof StrictDynamicMappingException && e.getCause().getMessage().contains("chunking_settings")) { 
delegate.onFailure( new ElasticsearchStatusException( @@ -223,36 +188,15 @@ private void parseAndStoreModel( } }); - service.parseRequestConfig(inferenceEntityId, taskType, config, platformArchitectures, parsedModelListener); - + service.parseRequestConfig(inferenceEntityId, taskType, config, parsedModelListener); } - private void putAndStartModel(InferenceService service, Model model, ActionListener finalListener) { - SubscribableListener.newForked(listener -> { - var errorCatchingListener = ActionListener.wrap(listener::onResponse, e -> { listener.onResponse(false); }); - service.isModelDownloaded(model, errorCatchingListener); - }).andThen((listener, isDownloaded) -> { - if (isDownloaded == false) { - service.putModel(model, listener); - } else { - listener.onResponse(true); - } - }).andThen((listener, modelDidPut) -> { - if (modelDidPut) { - if (skipValidationAndStart) { - listener.onResponse(new PutInferenceModelAction.Response(model.getConfigurations())); - } else { - service.start( - model, - listener.delegateFailureAndWrap( - (l3, ok) -> l3.onResponse(new PutInferenceModelAction.Response(model.getConfigurations())) - ) - ); - } - } else { - logger.warn("Failed to put model [{}]", model.getInferenceEntityId()); - } - }).addListener(finalListener); + private void startInferenceEndpoint(InferenceService service, Model model, ActionListener listener) { + if (skipValidationAndStart) { + listener.onResponse(new PutInferenceModelAction.Response(model.getConfigurations())); + } else { + service.start(model, listener.map(started -> new PutInferenceModelAction.Response(model.getConfigurations()))); + } } private Map requestToMap(PutInferenceModelAction.Request request) throws IOException { @@ -276,12 +220,6 @@ protected ClusterBlockException checkBlock(PutInferenceModelAction.Request reque return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); } - static boolean clusterIsInElasticCloud(ClusterSettings settings) { - // use a heuristic to 
determine if in Elastic cloud. - // One such heuristic is where USE_AUTO_MACHINE_MEMORY_PERCENT == true - return settings.get(MachineLearningField.USE_AUTO_MACHINE_MEMORY_PERCENT); - } - /** * task_type can be specified as either a URL parameter or in the * request body. Resolve which to use or throw if the settings are diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index 994bad194aef6..0bd0eee1aa9a1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -42,7 +42,6 @@ import java.util.List; import java.util.Map; -import java.util.Set; import static org.elasticsearch.xpack.core.inference.action.InferenceAction.Request.DEFAULT_TIMEOUT; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; @@ -69,7 +68,6 @@ public void parseRequestConfig( String inferenceEntityId, TaskType taskType, Map config, - Set platformArchitectures, ActionListener parsedModelListener ) { try { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index c00932a169c24..bc0d10279ae44 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -41,7 +41,6 @@ import java.io.IOException; import 
java.util.List; import java.util.Map; -import java.util.Set; import static org.elasticsearch.TransportVersions.ML_INFERENCE_AMAZON_BEDROCK_ADDED; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; @@ -121,7 +120,6 @@ public void parseRequestConfig( String modelId, TaskType taskType, Map config, - Set platformArchitectures, ActionListener parsedModelListener ) { try { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java index d7b945cd709fc..07c45e7c6e710 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java @@ -33,7 +33,6 @@ import java.util.List; import java.util.Map; -import java.util.Set; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; @@ -58,7 +57,6 @@ public void parseRequestConfig( String inferenceEntityId, TaskType taskType, Map config, - Set platformArchitectures, ActionListener parsedModelListener ) { try { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index bd648250a509b..7981fb393a842 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -40,7 +40,6 @@ import java.util.List; 
import java.util.Map; -import java.util.Set; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; @@ -110,7 +109,6 @@ public void parseRequestConfig( String inferenceEntityId, TaskType taskType, Map config, - Set platformArchitectures, ActionListener parsedModelListener ) { try { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index e22500cc9dad7..96399bb954cd2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -39,7 +39,6 @@ import java.util.List; import java.util.Map; -import java.util.Set; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; @@ -65,7 +64,6 @@ public void parseRequestConfig( String inferenceEntityId, TaskType taskType, Map config, - Set platformArchitectures, ActionListener parsedModelListener ) { try { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index 27f8fdf3a029a..728a4ac137dff 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -39,7 +39,6 @@ import java.util.List; import java.util.Map; -import 
java.util.Set; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; @@ -70,7 +69,6 @@ public void parseRequestConfig( String inferenceEntityId, TaskType taskType, Map config, - Set platformArchitectures, ActionListener parsedModelListener ) { try { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index abbe893823b96..7cfbc272aac5a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -39,7 +39,6 @@ import java.util.List; import java.util.Map; -import java.util.Set; import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; @@ -117,7 +116,6 @@ public void parseRequestConfig( String inferenceEntityId, TaskType taskType, Map config, - Set platformArchitectures, ActionListener parsedModelListener ) { try { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java index 457416370e559..1dd7a36315c19 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java @@ -10,7 +10,9 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.client.internal.OriginSettingClient; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; @@ -28,12 +30,16 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; +import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil; +import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.services.elser.ElserInternalModel; import java.io.IOException; import java.util.EnumSet; import java.util.List; import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.function.Consumer; import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; @@ -41,11 +47,29 @@ public abstract class BaseElasticsearchInternalService implements InferenceService { protected final OriginSettingClient client; + protected final ExecutorService inferenceExecutor; + protected final Consumer>> platformArch; private static final Logger logger = LogManager.getLogger(BaseElasticsearchInternalService.class); public BaseElasticsearchInternalService(InferenceServiceExtension.InferenceServiceFactoryContext context) { this.client = new OriginSettingClient(context.client(), ClientHelper.INFERENCE_ORIGIN); + this.inferenceExecutor = 
context.threadPool().executor(InferencePlugin.UTILITY_THREAD_POOL_NAME); + this.platformArch = this::platformArchitecture; + } + + // For testing. + // platformArchFn enables simulating different architectures + // without extensive mocking on the client to simulate the nodes info response. + // TODO make package private once the elser service is moved to the Elasticsearch + // service package. + public BaseElasticsearchInternalService( + InferenceServiceExtension.InferenceServiceFactoryContext context, + Consumer>> platformArchFn + ) { + this.client = new OriginSettingClient(context.client(), ClientHelper.INFERENCE_ORIGIN); + this.inferenceExecutor = context.threadPool().executor(InferencePlugin.UTILITY_THREAD_POOL_NAME); + this.platformArch = platformArchFn; } /** @@ -55,24 +79,34 @@ public BaseElasticsearchInternalService(InferenceServiceExtension.InferenceServi protected abstract EnumSet supportedTaskTypes(); @Override - public void start(Model model, ActionListener listener) { - if (model instanceof ElasticsearchInternalModel == false) { - listener.onFailure(notElasticsearchModelException(model)); - return; - } + public void start(Model model, ActionListener finalListener) { + if (model instanceof ElasticsearchInternalModel esModel) { - if (supportedTaskTypes().contains(model.getTaskType()) == false) { - listener.onFailure( - new IllegalStateException(TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name())) - ); - return; - } + if (supportedTaskTypes().contains(model.getTaskType()) == false) { + finalListener.onFailure( + new IllegalStateException(TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name())) + ); + return; + } - var esModel = (ElasticsearchInternalModel) model; - var startRequest = esModel.getStartTrainedModelDeploymentActionRequest(); - var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, listener); + SubscribableListener.newForked(forkedListener -> { 
isBuiltinModelPut(model, forkedListener); }) + .andThen((l, modelConfigExists) -> { + if (modelConfigExists == false) { + putModel(model, l); + } else { + l.onResponse(true); + } + }) + .andThen((l2, modelDidPut) -> { + var startRequest = esModel.getStartTrainedModelDeploymentActionRequest(); + var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, finalListener); + client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener); + }) + .addListener(finalListener); - client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener); + } else { + finalListener.onFailure(notElasticsearchModelException(model)); + } } @Override @@ -136,13 +170,18 @@ private void putBuiltInModel(String modelId, ActionListener listener) { ); } - @Override - public void isModelDownloaded(Model model, ActionListener listener) { - ActionListener getModelsResponseListener = listener.delegateFailure((delegate, response) -> { + protected void isBuiltinModelPut(Model model, ActionListener listener) { + ActionListener getModelsResponseListener = ActionListener.wrap(response -> { if (response.getResources().count() < 1) { - delegate.onResponse(Boolean.FALSE); + listener.onResponse(Boolean.FALSE); + } else { + listener.onResponse(Boolean.TRUE); + } + }, exception -> { + if (exception instanceof ResourceNotFoundException) { + listener.onResponse(Boolean.FALSE); } else { - delegate.onResponse(Boolean.TRUE); + listener.onFailure(exception); } }); @@ -163,11 +202,6 @@ public void isModelDownloaded(Model model, ActionListener listener) { } } - @Override - public boolean isInClusterService() { - return true; - } - @Override public void close() throws IOException {} @@ -187,6 +221,28 @@ public static String selectDefaultModelVariantBasedOnClusterArchitecture( } } + private void platformArchitecture(ActionListener> platformArchitectureListener) { + // Find the cluster platform as the service may need that + // information 
when creating the model + MlPlatformArchitecturesUtil.getMlNodesArchitecturesSet( + platformArchitectureListener.delegateFailureAndWrap((delegate, architectures) -> { + if (architectures.isEmpty() && clusterIsInElasticCloud()) { + // In Elastic cloud ml nodes run on Linux x86 + delegate.onResponse(Set.of("linux-x86_64")); + } else { + delegate.onResponse(architectures); + } + }), + client, + inferenceExecutor + ); + } + + static boolean clusterIsInElasticCloud() { + // use a heuristic to determine if in Elastic cloud. + return true; // TODO + } + public static InferModelAction.Request buildInferenceRequest( String id, InferenceConfigUpdate update, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java index 405c687839629..07d0cc14b2ac8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.elasticsearch; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.Strings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.TaskSettings; @@ -56,4 +57,9 @@ public abstract ActionListener getC Model model, ActionListener listener ); + + @Override + public String toString() { + return Strings.toString(this.getConfigurations()); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 675bc275c8bd1..b0c0fb0b8e7cc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -48,6 +48,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.Consumer; import java.util.function.Function; import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException; @@ -70,6 +71,14 @@ public ElasticsearchInternalService(InferenceServiceExtension.InferenceServiceFa super(context); } + // for testing + ElasticsearchInternalService( + InferenceServiceExtension.InferenceServiceFactoryContext context, + Consumer>> platformArch + ) { + super(context, platformArch); + } + @Override protected EnumSet supportedTaskTypes() { return EnumSet.of(TaskType.RERANK, TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING); @@ -80,7 +89,6 @@ public void parseRequestConfig( String inferenceEntityId, TaskType taskType, Map config, - Set platformArchitectures, ActionListener modelListener ) { try { @@ -94,7 +102,11 @@ public void parseRequestConfig( throw new ValidationException().addValidationError("Error parsing request config, model id is missing"); } if (MULTILINGUAL_E5_SMALL_VALID_IDS.contains(modelId)) { - e5Case(inferenceEntityId, taskType, config, platformArchitectures, serviceSettingsMap, modelListener); + platformArch.accept( + modelListener.delegateFailureAndWrap( + (delegate, arch) -> e5Case(inferenceEntityId, taskType, config, arch, serviceSettingsMap, modelListener) + ) + ); } else { customElandCase(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, modelListener); } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java index 746cb6e89fad0..1198be7ab7a3b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java @@ -41,6 +41,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.Consumer; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; @@ -57,6 +58,14 @@ public ElserInternalService(InferenceServiceExtension.InferenceServiceFactoryCon super(context); } + // for testing + ElserInternalService( + InferenceServiceExtension.InferenceServiceFactoryContext context, + Consumer>> platformArch + ) { + super(context, platformArch); + } + @Override protected EnumSet supportedTaskTypes() { return EnumSet.of(TaskType.SPARSE_EMBEDDING); @@ -67,19 +76,12 @@ public void parseRequestConfig( String inferenceEntityId, TaskType taskType, Map config, - Set modelArchitectures, ActionListener parsedModelListener ) { try { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); var serviceSettingsBuilder = ElserInternalServiceSettings.fromRequestMap(serviceSettingsMap); - if (serviceSettingsBuilder.getModelId() == null) { - serviceSettingsBuilder.setModelId( - selectDefaultModelVariantBasedOnClusterArchitecture(modelArchitectures, ELSER_V2_MODEL_LINUX_X86, ELSER_V2_MODEL) - ); - } - Map taskSettingsMap; // task settings are optional if (config.containsKey(ModelConfigurations.TASK_SETTINGS)) { @@ -94,15 +96,33 @@ public void parseRequestConfig( 
throwIfNotEmptyMap(serviceSettingsMap, NAME); throwIfNotEmptyMap(taskSettingsMap, NAME); - parsedModelListener.onResponse( - new ElserInternalModel( - inferenceEntityId, - taskType, - NAME, - new ElserInternalServiceSettings(serviceSettingsBuilder.build()), - taskSettings - ) - ); + if (serviceSettingsBuilder.getModelId() == null) { + platformArch.accept(parsedModelListener.delegateFailureAndWrap((delegate, arch) -> { + serviceSettingsBuilder.setModelId( + selectDefaultModelVariantBasedOnClusterArchitecture(arch, ELSER_V2_MODEL_LINUX_X86, ELSER_V2_MODEL) + ); + // Respond only after the architecture lookup has chosen the default model id; a lookup failure goes to onFailure instead. + delegate.onResponse( + new ElserInternalModel( + inferenceEntityId, + taskType, + NAME, + new ElserInternalServiceSettings(serviceSettingsBuilder.build()), + taskSettings + ) + ); + })); + } else { + parsedModelListener.onResponse( + new ElserInternalModel( + inferenceEntityId, + taskType, + NAME, + new ElserInternalServiceSettings(serviceSettingsBuilder.build()), + taskSettings + ) + ); + } } catch (Exception e) { parsedModelListener.onFailure(e); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index 422fc5b0ed720..b84dadc27dd77 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -39,7 +39,6 @@ import java.util.List; import java.util.Map; -import java.util.Set; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; @@ -66,7 +65,6 @@ public void parseRequestConfig( String inferenceEntityId, TaskType taskType, Map config, - Set 
platformArchitectures, ActionListener parsedModelListener ) { try { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 2bbf219438280..d9d8850048564 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -38,7 +38,6 @@ import java.util.List; import java.util.Map; -import java.util.Set; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; @@ -65,7 +64,6 @@ public void parseRequestConfig( String inferenceEntityId, TaskType taskType, Map config, - Set platformArchitectures, ActionListener parseModelListener ) { try { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java index d129a0c44e835..fb4221c0d5ebc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java @@ -23,7 +23,6 @@ import org.elasticsearch.xpack.inference.services.ServiceComponents; import java.util.Map; -import java.util.Set; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; @@ -48,7 +47,6 @@ 
public void parseRequestConfig( String inferenceEntityId, TaskType taskType, Map config, - Set platformArchitectures, ActionListener parsedModelListener ) { try { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index 895ebaa66c806..14be1a70b5daa 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -38,7 +38,6 @@ import java.util.List; import java.util.Map; -import java.util.Set; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; @@ -65,7 +64,6 @@ public void parseRequestConfig( String inferenceEntityId, TaskType taskType, Map config, - Set platformArchitectures, ActionListener parsedModelListener ) { try { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index 221951f7a621e..47191ba96cb82 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -37,7 +37,6 @@ import java.util.List; import java.util.Map; -import java.util.Set; import static org.elasticsearch.TransportVersions.ADD_MISTRAL_EMBEDDINGS_INFERENCE; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; @@ -110,7 +109,6 @@ public void parseRequestConfig( String modelId, 
TaskType taskType, Map config, - Set platfromArchitectures, ActionListener parsedModelListener ) { try { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index f9565a915124f..bba8721c48c88 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -42,7 +42,6 @@ import java.util.List; import java.util.Map; -import java.util.Set; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; @@ -69,7 +68,6 @@ public void parseRequestConfig( String inferenceEntityId, TaskType taskType, Map config, - Set platformArchitectures, ActionListener parsedModelListener ) { try { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java index 6ad17424dbcaa..a063a398a4947 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java @@ -29,7 +29,6 @@ import java.io.IOException; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; @@ -138,7 +137,6 @@ public void parseRequestConfig( String inferenceEntityId, TaskType taskType, Map config, - Set platformArchitectures, ActionListener parsedModelListener ) { parsedModelListener.onResponse(null); diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java index 9d9dbfaf86c15..e8c34eec96171 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java @@ -46,7 +46,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; @@ -95,7 +94,6 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws IOException AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), getSecretSettingsMap("secret") ), - Set.of(), modelVerificationListener ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index 297a42f9d1fa7..0b3cf533d818f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -52,7 +52,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; @@ -111,7 +110,6 @@ public void testParseRequestConfig_CreatesAnAmazonBedrockModel() throws IOExcept Map.of(), 
getAmazonBedrockSecretSettingsMap("access", "secret") ), - Set.of(), modelVerificationListener ); } @@ -135,7 +133,6 @@ public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOExcepti Map.of(), getAmazonBedrockSecretSettingsMap("access", "secret") ), - Set.of(), modelVerificationListener ); } @@ -159,7 +156,6 @@ public void testCreateModel_ForEmbeddingsTask_InvalidProvider() throws IOExcepti Map.of(), getAmazonBedrockSecretSettingsMap("access", "secret") ), - Set.of(), modelVerificationListener ); } @@ -183,7 +179,6 @@ public void testCreateModel_TopKParameter_NotAvailable() throws IOException { getChatCompletionTaskSettingsMap(1.0, 0.5, 0.2, 128), getAmazonBedrockSecretSettingsMap("access", "secret") ), - Set.of(), modelVerificationListener ); } @@ -210,7 +205,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I } ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); } } @@ -231,7 +226,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa ); }); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); } } @@ -255,7 +250,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() ); }); - service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), modelVerificationListener); + service.parseRequestConfig("id", TaskType.COMPLETION, config, modelVerificationListener); } } @@ -279,7 +274,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap ); }); - service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), modelVerificationListener); + service.parseRequestConfig("id", TaskType.COMPLETION, config, 
modelVerificationListener); } } @@ -305,7 +300,6 @@ public void testParseRequestConfig_MovesModel() throws IOException { Map.of(), getAmazonBedrockSecretSettingsMap("access", "secret") ), - Set.of(), modelVerificationListener ); } @@ -329,7 +323,6 @@ public void testCreateModel_ForEmbeddingsTask_DimensionsIsNotAllowed() throws IO Map.of(), getAmazonBedrockSecretSettingsMap("access", "secret") ), - Set.of(), modelVerificationListener ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java index c3693c227c435..c502425b22ac3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java @@ -43,7 +43,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.inference.Utils.buildExpectationCompletions; @@ -108,7 +107,6 @@ public void testParseRequestConfig_CreatesACompletionModel() throws IOException new HashMap<>(Map.of(AnthropicServiceFields.MAX_TOKENS, 1)), getSecretSettingsMap(apiKey) ), - Set.of(), modelListener ); } @@ -129,7 +127,6 @@ public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOExcepti new HashMap<>(Map.of()), getSecretSettingsMap("secret") ), - Set.of(), failureListener ); } @@ -148,7 +145,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I ElasticsearchStatusException.class, "Model configuration contains settings [{extra_key=value}] unknown to the [anthropic] service" ); - service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), failureListener); + service.parseRequestConfig("id", 
TaskType.COMPLETION, config, failureListener); } } @@ -167,7 +164,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa ElasticsearchStatusException.class, "Model configuration contains settings [{extra_key=value}] unknown to the [anthropic] service" ); - service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), failureListener); + service.parseRequestConfig("id", TaskType.COMPLETION, config, failureListener); } } @@ -186,7 +183,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() ElasticsearchStatusException.class, "Model configuration contains settings [{extra_key=value}] unknown to the [anthropic] service" ); - service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), failureListener); + service.parseRequestConfig("id", TaskType.COMPLETION, config, failureListener); } } @@ -205,7 +202,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap ElasticsearchStatusException.class, "Model configuration contains settings [{extra_key=value}] unknown to the [anthropic] service" ); - service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), failureListener); + service.parseRequestConfig("id", TaskType.COMPLETION, config, failureListener); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index bb736f592fbdb..9c3afc68306b2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -56,7 +56,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; import 
java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; @@ -120,7 +119,6 @@ public void testParseRequestConfig_CreatesAnAzureAiStudioEmbeddingsModel() throw getEmbeddingsTaskSettingsMap("user"), getSecretSettingsMap("secret") ), - Set.of(), modelVerificationListener ); } @@ -148,7 +146,6 @@ public void testParseRequestConfig_CreatesAnAzureAiStudioChatCompletionModel() t getChatCompletionTaskSettingsMap(null, null, true, null), getSecretSettingsMap("secret") ), - Set.of(), modelVerificationListener ); } @@ -172,7 +169,6 @@ public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOExcepti getChatCompletionTaskSettingsMap(null, null, true, null), getSecretSettingsMap("secret") ), - Set.of(), modelVerificationListener ); } @@ -198,7 +194,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I } ); - service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), modelVerificationListener); + service.parseRequestConfig("id", TaskType.COMPLETION, config, modelVerificationListener); } } @@ -220,7 +216,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingServiceS } ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); } } @@ -243,7 +239,7 @@ public void testParseRequestConfig_ThrowsWhenDimsSetByUserExistsInEmbeddingServi } ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); } } @@ -269,7 +265,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingTaskSett } ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, 
modelVerificationListener); } } @@ -295,7 +291,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingSecretSe } ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); } } @@ -321,7 +317,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInChatCompletionSer } ); - service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), modelVerificationListener); + service.parseRequestConfig("id", TaskType.COMPLETION, config, modelVerificationListener); } } @@ -347,7 +343,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInChatCompletionTas } ); - service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), modelVerificationListener); + service.parseRequestConfig("id", TaskType.COMPLETION, config, modelVerificationListener); } } @@ -373,7 +369,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInChatCompletionSec } ); - service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), modelVerificationListener); + service.parseRequestConfig("id", TaskType.COMPLETION, config, modelVerificationListener); } } @@ -391,7 +387,7 @@ public void testParseRequestConfig_ThrowsWhenProviderIsNotValidForEmbeddings() t } ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); } } @@ -412,7 +408,7 @@ public void testParseRequestConfig_ThrowsWhenEndpointTypeIsNotValidForEmbeddings } ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); } } @@ -437,7 +433,7 @@ public void testParseRequestConfig_ThrowsWhenEndpointTypeIsNotValidForChatComple } ); - 
service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), modelVerificationListener); + service.parseRequestConfig("id", TaskType.COMPLETION, config, modelVerificationListener); } } @@ -506,7 +502,6 @@ public void testParsePersistedConfig_ThrowsUnsupportedModelType() throws IOExcep getChatCompletionTaskSettingsMap(null, null, true, null), getSecretSettingsMap("secret") ), - Set.of(), modelVerificationListener ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index 142877c09180f..098e41b72ea8f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -50,7 +50,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; @@ -118,7 +117,6 @@ public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModel() throws IOExc getAzureOpenAiRequestTaskSettingsMap("user"), getAzureOpenAiSecretSettingsMap("secret", null) ), - Set.of(), modelVerificationListener ); } @@ -142,7 +140,6 @@ public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOExcepti getAzureOpenAiRequestTaskSettingsMap("user"), getAzureOpenAiSecretSettingsMap("secret", null) ), - Set.of(), modelVerificationListener ); } @@ -168,7 +165,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I } ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); 
} } @@ -193,7 +190,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa ); }); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); } } @@ -218,7 +215,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() ); }); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); } } @@ -243,7 +240,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap ); }); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); } } @@ -268,7 +265,6 @@ public void testParseRequestConfig_MovesModel() throws IOException { getAzureOpenAiRequestTaskSettingsMap("user"), getAzureOpenAiSecretSettingsMap("secret", null) ), - Set.of(), modelVerificationListener ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index a577a6664d39d..22503108b5262 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -54,7 +54,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; @@ -123,7 +122,6 @@ public void testParseRequestConfig_CreatesACohereEmbeddingsModel() 
throws IOExce getTaskSettingsMap(InputType.INGEST, CohereTruncation.START), getSecretSettingsMap("secret") ), - Set.of(), modelListener ); @@ -151,7 +149,6 @@ public void testParseRequestConfig_OptionalTaskSettings() throws IOException { CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", CohereEmbeddingType.FLOAT), getSecretSettingsMap("secret") ), - Set.of(), modelListener ); @@ -173,7 +170,6 @@ public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOExcepti getTaskSettingsMapEmpty(), getSecretSettingsMap("secret") ), - Set.of(), failureListener ); } @@ -199,7 +195,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I ElasticsearchStatusException.class, "Model configuration contains settings [{extra_key=value}] unknown to the [cohere] service" ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), failureListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); } } @@ -214,7 +210,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa ElasticsearchStatusException.class, "Model configuration contains settings [{extra_key=value}] unknown to the [cohere] service" ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), failureListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); } } @@ -233,7 +229,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() ElasticsearchStatusException.class, "Model configuration contains settings [{extra_key=value}] unknown to the [cohere] service" ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), failureListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); } } @@ -253,7 +249,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap ElasticsearchStatusException.class, "Model 
configuration contains settings [{extra_key=value}] unknown to the [cohere] service" ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), failureListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); } } @@ -276,7 +272,6 @@ public void testParseRequestConfig_CreatesACohereEmbeddingsModelWithoutUrl() thr getTaskSettingsMapEmpty(), getSecretSettingsMap("secret") ), - Set.of(), modelListener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 0bbf2be7301d8..ab85e112418f5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -46,7 +46,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; @@ -104,7 +103,6 @@ public void testParseRequestConfig_CreatesASparseEmbeddingsModel() throws IOExce "id", TaskType.SPARSE_EMBEDDING, getRequestConfigMap(Map.of(ServiceFields.MODEL_ID, ElserModels.ELSER_V2_MODEL), Map.of(), Map.of()), - Set.of(), modelListener ); } @@ -121,7 +119,6 @@ public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOExcepti "id", TaskType.COMPLETION, getRequestConfigMap(Map.of(ServiceFields.MODEL_ID, ElserModels.ELSER_V2_MODEL), Map.of(), Map.of()), - Set.of(), failureListener ); } @@ -136,7 +133,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I ElasticsearchStatusException.class, "Model configuration contains settings [{extra_key=value}] unknown to the [elastic] 
service" ); - service.parseRequestConfig("id", TaskType.SPARSE_EMBEDDING, config, Set.of(), failureListener); + service.parseRequestConfig("id", TaskType.SPARSE_EMBEDDING, config, failureListener); } } @@ -151,7 +148,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa ElasticsearchStatusException.class, "Model configuration contains settings [{extra_key=value}] unknown to the [elastic] service" ); - service.parseRequestConfig("id", TaskType.SPARSE_EMBEDDING, config, Set.of(), failureListener); + service.parseRequestConfig("id", TaskType.SPARSE_EMBEDDING, config, failureListener); } } @@ -165,7 +162,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() ElasticsearchStatusException.class, "Model configuration contains settings [{extra_key=value}] unknown to the [elastic] service" ); - service.parseRequestConfig("id", TaskType.SPARSE_EMBEDDING, config, Set.of(), failureListener); + service.parseRequestConfig("id", TaskType.SPARSE_EMBEDDING, config, failureListener); } } @@ -179,7 +176,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap ElasticsearchStatusException.class, "Model configuration contains settings [{extra_key=value}] unknown to the [elastic] service" ); - service.parseRequestConfig("id", TaskType.SPARSE_EMBEDDING, config, Set.of(), failureListener); + service.parseRequestConfig("id", TaskType.SPARSE_EMBEDDING, config, failureListener); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 8569117c348b1..de9298f1b08dd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ChunkedInferenceServiceResults; @@ -49,6 +50,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate; import org.elasticsearch.xpack.core.utils.FloatConversionUtils; +import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.junit.After; import org.junit.Before; @@ -86,7 +88,7 @@ public class ElasticsearchInternalServiceTests extends ESTestCase { @Before public void setUpThreadPool() { - threadPool = new TestThreadPool("test"); + threadPool = createThreadPool(InferencePlugin.inferenceUtilityExecutor(Settings.EMPTY)); } @After @@ -110,7 +112,7 @@ public void testParseRequestConfig() { ); var taskType = randomFrom(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.SPARSE_EMBEDDING); - service.parseRequestConfig(randomInferenceEntityId, taskType, settings, Set.of(), modelListener); + service.parseRequestConfig(randomInferenceEntityId, taskType, settings, modelListener); } public void testParseRequestConfig_Misconfigured() { @@ -131,7 +133,7 @@ public void testParseRequestConfig_Misconfigured() { ); var taskType = randomFrom(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.SPARSE_EMBEDDING); - service.parseRequestConfig(randomInferenceEntityId, taskType, settings, Set.of(), modelListener); + service.parseRequestConfig(randomInferenceEntityId, taskType, settings, 
modelListener); } // Invalid config map @@ -152,13 +154,13 @@ public void testParseRequestConfig_Misconfigured() { ); var taskType = randomFrom(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.SPARSE_EMBEDDING); - service.parseRequestConfig(randomInferenceEntityId, taskType, settings, Set.of(), modelListener); + service.parseRequestConfig(randomInferenceEntityId, taskType, settings, modelListener); } } public void testParseRequestConfig_E5() { { - var service = createService(mock(Client.class)); + var service = createService(mock(Client.class), Set.of("Aarch64")); var settings = new HashMap(); settings.put( ModelConfigurations.SERVICE_SETTINGS, @@ -180,14 +182,45 @@ public void testParseRequestConfig_E5() { randomInferenceEntityId, TaskType.TEXT_EMBEDDING, settings, - Set.of(), + getModelVerificationActionListener(e5ServiceSettings) + ); + } + + { + var service = createService(mock(Client.class), Set.of("linux-x86_64")); + var settings = new HashMap(); + settings.put( + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of( + ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS, + 1, + ElasticsearchInternalServiceSettings.NUM_THREADS, + 4, + ElasticsearchInternalServiceSettings.MODEL_ID, + ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86 + ) + ) + ); + + var e5ServiceSettings = new MultilingualE5SmallInternalServiceSettings( + 1, + 4, + ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86, + null + ); + + service.parseRequestConfig( + randomInferenceEntityId, + TaskType.TEXT_EMBEDDING, + settings, getModelVerificationActionListener(e5ServiceSettings) ); } // Invalid service settings { - var service = createService(mock(Client.class)); + var service = createService(mock(Client.class), Set.of("Aarch64")); var settings = new HashMap(); settings.put( ModelConfigurations.SERVICE_SETTINGS, @@ -210,7 +243,7 @@ public void testParseRequestConfig_E5() { e -> assertThat(e, instanceOf(ElasticsearchStatusException.class)) ); - 
service.parseRequestConfig(randomInferenceEntityId, TaskType.TEXT_EMBEDDING, settings, Set.of(), modelListener); + service.parseRequestConfig(randomInferenceEntityId, TaskType.TEXT_EMBEDDING, settings, modelListener); } } @@ -257,7 +290,7 @@ public void testParseRequestConfig_Rerank() { assertEquals(returnDocs, ((CustomElandRerankTaskSettings) model.getTaskSettings()).returnDocuments()); }, e -> { fail("Model parsing failed " + e.getMessage()); }); - service.parseRequestConfig(randomInferenceEntityId, TaskType.RERANK, settings, Set.of(), modelListener); + service.parseRequestConfig(randomInferenceEntityId, TaskType.RERANK, settings, modelListener); } } @@ -299,7 +332,7 @@ public void testParseRequestConfig_Rerank_DefaultTaskSettings() { assertEquals(Boolean.TRUE, ((CustomElandRerankTaskSettings) model.getTaskSettings()).returnDocuments()); }, e -> { fail("Model parsing failed " + e.getMessage()); }); - service.parseRequestConfig(randomInferenceEntityId, TaskType.RERANK, settings, Set.of(), modelListener); + service.parseRequestConfig(randomInferenceEntityId, TaskType.RERANK, settings, modelListener); } } @@ -338,7 +371,7 @@ public void testParseRequestConfig_SparseEmbedding() { assertThat(model.getServiceSettings(), instanceOf(CustomElandInternalServiceSettings.class)); }, e -> { fail("Model parsing failed " + e.getMessage()); }); - service.parseRequestConfig(randomInferenceEntityId, TaskType.SPARSE_EMBEDDING, settings, Set.of(), modelListener); + service.parseRequestConfig(randomInferenceEntityId, TaskType.SPARSE_EMBEDDING, settings, modelListener); } private ActionListener getModelVerificationActionListener(MultilingualE5SmallInternalServiceSettings e5ServiceSettings) { @@ -787,7 +820,7 @@ public void testParseRequestConfigEland_PreservesTaskType() { CustomElandModel expectedModel = getCustomElandModel(taskType); PlainActionFuture listener = new PlainActionFuture<>(); - service.parseRequestConfig(randomInferenceEntityId, taskType, settings, Set.of(), listener); + 
service.parseRequestConfig(randomInferenceEntityId, taskType, settings, listener); var model = listener.actionGet(TimeValue.THIRTY_SECONDS); assertThat(model, is(expectedModel)); } @@ -968,7 +1001,12 @@ public void testModelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic() } private ElasticsearchInternalService createService(Client client) { - var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client); + var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client, threadPool); return new ElasticsearchInternalService(context); } + + private ElasticsearchInternalService createService(Client client, Set architectures) { + var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client, threadPool); + return new ElasticsearchInternalService(context, l -> l.onResponse(architectures)); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java index 85add1a0090c8..adf9c7b4f5bc5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; import org.elasticsearch.inference.InferenceResults; @@ -33,6 +34,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.InferenceChunkedTextExpansionResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; import 
org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate; +import org.elasticsearch.xpack.inference.InferencePlugin; import org.junit.After; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -64,7 +66,7 @@ public class ElserInternalServiceTests extends ESTestCase { @Before public void setUpThreadPool() { - threadPool = new TestThreadPool("test"); + threadPool = createThreadPool(InferencePlugin.inferenceUtilityExecutor(Settings.EMPTY)); } @After @@ -114,7 +116,7 @@ public void testParseConfigStrict() { var modelVerificationListener = getModelVerificationListener(expectedModel); - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), modelVerificationListener); + service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, modelVerificationListener); } @@ -159,7 +161,7 @@ private static ActionListener getModelVerificationListener(ElserInternalM } public void testParseConfigStrictWithNoTaskSettings() { - var service = createService(mock(Client.class)); + var service = createService(mock(Client.class), Set.of("Aarch64")); var settings = new HashMap(); settings.put( @@ -177,8 +179,7 @@ public void testParseConfigStrictWithNoTaskSettings() { var modelVerificationListener = getModelVerificationListener(expectedModel); - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), modelVerificationListener); - + service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, modelVerificationListener); } public void testParseConfigStrictWithUnknownSettings() { @@ -228,7 +229,7 @@ public void testParseConfigStrictWithUnknownSettings() { ); } else { - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), errorVerificationListener); + service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, errorVerificationListener); } } @@ -271,7 +272,7 @@ public void testParseConfigStrictWithUnknownSettings() { Collections.emptyMap() ); } else { - 
service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), errorVerificationListener); + service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, errorVerificationListener); } } @@ -316,15 +317,15 @@ public void testParseConfigStrictWithUnknownSettings() { Collections.emptyMap() ); } else { - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), errorVerificationListener); + service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, errorVerificationListener); } } } } public void testParseRequestConfig_DefaultModel() { - var service = createService(mock(Client.class)); { + var service = createService(mock(Client.class), Set.of()); var settings = new HashMap(); settings.put( ModelConfigurations.SERVICE_SETTINGS, @@ -333,11 +334,12 @@ public void testParseRequestConfig_DefaultModel() { ActionListener modelActionListener = ActionListener.wrap((model) -> { assertEquals(".elser_model_2", ((ElserInternalModel) model).getServiceSettings().modelId()); - }, (e) -> { fail("Model verification should not fail"); }); + }, (e) -> { fail(e, "Model verification should not fail"); }); - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), modelActionListener); + service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, modelActionListener); } { + var service = createService(mock(Client.class), Set.of("linux-x86_64")); var settings = new HashMap(); settings.put( ModelConfigurations.SERVICE_SETTINGS, @@ -346,9 +348,9 @@ public void testParseRequestConfig_DefaultModel() { ActionListener modelActionListener = ActionListener.wrap((model) -> { assertEquals(".elser_model_2_linux-x86_64", ((ElserInternalModel) model).getServiceSettings().modelId()); - }, (e) -> { fail("Model verification should not fail"); }); + }, (e) -> { fail(e, "Model verification should not fail"); }); - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of("linux-x86_64"), 
modelActionListener); + service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, modelActionListener); } } @@ -510,7 +512,12 @@ public void onFailure(Exception e) { } private ElserInternalService createService(Client client) { - var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client); + var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client, threadPool); return new ElserInternalService(context); } + + private ElserInternalService createService(Client client, Set architectures) { + var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client, threadPool); + return new ElserInternalService(context, (l) -> l.onResponse(architectures)); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index 5d79d0e01f401..f311340101279 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -52,7 +52,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; @@ -122,7 +121,6 @@ public void testParseRequestConfig_CreatesAGoogleAiStudioCompletionModel() throw new HashMap<>(Map.of()), getSecretSettingsMap(apiKey) ), - Set.of(), modelListener ); } @@ -149,7 +147,6 @@ public void testParseRequestConfig_CreatesAGoogleAiStudioEmbeddingsModel() throw new HashMap<>(Map.of()), getSecretSettingsMap(apiKey) ), - Set.of(), modelListener ); } @@ -170,7 +167,6 @@ public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOExcepti new 
HashMap<>(Map.of()), getSecretSettingsMap("secret") ), - Set.of(), failureListener ); } @@ -189,7 +185,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I ElasticsearchStatusException.class, "Model configuration contains settings [{extra_key=value}] unknown to the [googleaistudio] service" ); - service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), failureListener); + service.parseRequestConfig("id", TaskType.COMPLETION, config, failureListener); } } @@ -204,7 +200,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa ElasticsearchStatusException.class, "Model configuration contains settings [{extra_key=value}] unknown to the [googleaistudio] service" ); - service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), failureListener); + service.parseRequestConfig("id", TaskType.COMPLETION, config, failureListener); } } @@ -223,7 +219,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() ElasticsearchStatusException.class, "Model configuration contains settings [{extra_key=value}] unknown to the [googleaistudio] service" ); - service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), failureListener); + service.parseRequestConfig("id", TaskType.COMPLETION, config, failureListener); } } @@ -242,7 +238,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap ElasticsearchStatusException.class, "Model configuration contains settings [{extra_key=value}] unknown to the [googleaistudio] service" ); - service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), failureListener); + service.parseRequestConfig("id", TaskType.COMPLETION, config, failureListener); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java index d8c727c5a58bc..6a96d289a8190 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java @@ -33,7 +33,6 @@ import java.io.IOException; import java.util.HashMap; import java.util.Map; -import java.util.Set; import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; @@ -103,7 +102,6 @@ public void testParseRequestConfig_CreatesGoogleVertexAiEmbeddingsModel() throws new HashMap<>(Map.of()), getSecretSettingsMap(serviceAccountJson) ), - Set.of(), modelListener ); } @@ -135,7 +133,6 @@ public void testParseRequestConfig_CreatesGoogleVertexAiRerankModel() throws IOE new HashMap<>(Map.of()), getSecretSettingsMap(serviceAccountJson) ), - Set.of(), modelListener ); } @@ -165,7 +162,6 @@ public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOExcepti new HashMap<>(Map.of()), getSecretSettingsMap("{}") ), - Set.of(), failureListener ); } @@ -193,7 +189,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I ElasticsearchStatusException.class, "Model configuration contains settings [{extra_key=value}] unknown to the [googlevertexai] service" ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), failureListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); } } @@ -217,7 +213,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa ElasticsearchStatusException.class, "Model configuration contains settings [{extra_key=value}] unknown to the [googlevertexai] service" ); - 
service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), failureListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); } } @@ -245,7 +241,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() ElasticsearchStatusException.class, "Model configuration contains settings [{extra_key=value}] unknown to the [googlevertexai] service" ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), failureListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); } } @@ -273,7 +269,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap ElasticsearchStatusException.class, "Model configuration contains settings [{extra_key=value}] unknown to the [googlevertexai] service" ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), failureListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index d13dea2ab6b4c..1645b92e63b20 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -50,7 +50,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResultsTests.asMapWithListsInsteadOfArrays; @@ -104,7 +103,6 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws IOException "id", 
TaskType.TEXT_EMBEDDING, getRequestConfigMap(getServiceSettingsMap("url"), getSecretSettingsMap("secret")), - Set.of(), modelVerificationActionListener ); } @@ -124,7 +122,6 @@ public void testParseRequestConfig_CreatesAnElserModel() throws IOException { "id", TaskType.SPARSE_EMBEDDING, getRequestConfigMap(getServiceSettingsMap("url"), getSecretSettingsMap("secret")), - Set.of(), modelVerificationActionListener ); } @@ -146,7 +143,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I } ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationActionListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationActionListener); } } @@ -168,7 +165,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa } ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationActionListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationActionListener); } } @@ -190,7 +187,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap } ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationActionListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationActionListener); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index a2de7c15d54da..f8f08e6f880ab 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -56,7 +56,6 @@ import java.util.HashMap; import 
java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; @@ -139,7 +138,6 @@ public void testParseRequestConfig_CreatesAIbmWatsonxEmbeddingsModel() throws IO new HashMap<>(Map.of()), getSecretSettingsMap(apiKey) ), - Set.of(), modelListener ); } @@ -160,7 +158,6 @@ public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOExcepti new HashMap<>(Map.of()), getSecretSettingsMap("secret") ), - Set.of(), failureListener ); } @@ -192,7 +189,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap ElasticsearchStatusException.class, "Model configuration contains settings [{extra_key=value}] unknown to the [watsonxai] service" ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), failureListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index 33a2b43caf174..76075fc85f202 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -52,7 +52,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; @@ -114,7 +113,6 @@ public void testParseRequestConfig_CreatesAMistralEmbeddingsModel() throws IOExc getEmbeddingsTaskSettingsMap(), getSecretSettingsMap("secret") ), - Set.of(), modelVerificationListener ); } @@ -138,7 +136,6 @@ public void 
testParseRequestConfig_ThrowsUnsupportedModelType() throws IOExcepti getEmbeddingsTaskSettingsMap(), getSecretSettingsMap("secret") ), - Set.of(), modelVerificationListener ); } @@ -164,7 +161,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I } ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); } } @@ -190,7 +187,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingTaskSett } ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); } } @@ -216,7 +213,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingSecretSe } ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); } } @@ -258,7 +255,6 @@ public void testParsePersistedConfig_ThrowsUnsupportedModelType() throws IOExcep getEmbeddingsTaskSettingsMap(), getSecretSettingsMap("secret") ), - Set.of(), modelVerificationListener ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 32099c4bd0be9..508da45ac2fc2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -53,7 +53,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.TimeUnit; import 
static org.elasticsearch.xpack.inference.Utils.getInvalidModel; @@ -124,7 +123,6 @@ public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModel() throws IOExc getTaskSettingsMap("user"), getSecretSettingsMap("secret") ), - Set.of(), modelVerificationListener ); } @@ -159,7 +157,6 @@ public void testParseRequestConfig_CreatesAnOpenAiChatCompletionsModel() throws getTaskSettingsMap(user), getSecretSettingsMap(secret) ), - Set.of(), modelVerificationListener ); } @@ -183,7 +180,6 @@ public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOExcepti getTaskSettingsMap("user"), getSecretSettingsMap("secret") ), - Set.of(), modelVerificationListener ); } @@ -209,7 +205,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I } ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); } } @@ -227,7 +223,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa assertThat(e.getMessage(), is("Model configuration contains settings [{extra_key=value}] unknown to the [openai] service")); }); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); } } @@ -245,7 +241,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() assertThat(e.getMessage(), is("Model configuration contains settings [{extra_key=value}] unknown to the [openai] service")); }); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); } } @@ -263,7 +259,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap assertThat(e.getMessage(), is("Model 
configuration contains settings [{extra_key=value}] unknown to the [openai] service")); }); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); } } @@ -284,7 +280,6 @@ public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModelWithoutUserUrlO "id", TaskType.TEXT_EMBEDDING, getRequestConfigMap(getServiceSettingsMap("model", null, null), getTaskSettingsMap(null), getSecretSettingsMap("secret")), - Set.of(), modelVerificationListener ); } @@ -311,7 +306,6 @@ public void testParseRequestConfig_CreatesAnOpenAiChatCompletionsModelWithoutUse "id", TaskType.COMPLETION, getRequestConfigMap(getServiceSettingsMap(model, null, null), getTaskSettingsMap(null), getSecretSettingsMap(secret)), - Set.of(), modelVerificationListener ); } @@ -338,7 +332,6 @@ public void testParseRequestConfig_MovesModel() throws IOException { getTaskSettingsMap("user"), getSecretSettingsMap("secret") ), - Set.of(), modelVerificationListener ); } @@ -365,7 +358,6 @@ public void testParseRequestConfig_ThrowsElasticsearchStatusExceptionWhenChunkin createRandomChunkingSettingsMap(), getSecretSettingsMap("secret") ), - Set.of(), modelVerificationListener ); } @@ -396,7 +388,6 @@ public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingSet createRandomChunkingSettingsMap(), getSecretSettingsMap("secret") ), - Set.of(), modelVerificationListener ); } @@ -422,7 +413,6 @@ public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingSet "id", TaskType.TEXT_EMBEDDING, getRequestConfigMap(getServiceSettingsMap("model", null, null), getTaskSettingsMap(null), getSecretSettingsMap("secret")), - Set.of(), modelVerificationListener ); }