Skip to content

Commit

Permalink
[ML] Move code specific to the Elasticsearch in cluster services to t…
Browse files Browse the repository at this point in the history
…hose services (#113749)

Remove the platform arch argument from parseRequest and move code
used by internal services out of the transport action into the service.
  • Loading branch information
davidkyle authored Oct 1, 2024
1 parent 0fbb3bc commit 071e7ce
Show file tree
Hide file tree
Showing 44 changed files with 294 additions and 330 deletions.
9 changes: 3 additions & 6 deletions muted-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -278,12 +278,9 @@ tests:
- class: org.elasticsearch.xpack.ml.integration.MlJobIT
method: testCreateJobsWithIndexNameOption
issue: https://github.com/elastic/elasticsearch/issues/113528
- class: org.elasticsearch.xpack.inference.TextEmbeddingCrudIT
method: testPutE5WithTrainedModelAndInference
issue: https://github.com/elastic/elasticsearch/issues/113565
- class: org.elasticsearch.xpack.inference.TextEmbeddingCrudIT
method: testPutE5Small_withPlatformAgnosticVariant
issue: https://github.com/elastic/elasticsearch/issues/113577
- class: org.elasticsearch.validation.DotPrefixClientYamlTestSuiteIT
method: test {p0=dot_prefix/10_basic/Deprecated index template with a dot prefix index pattern}
issue: https://github.com/elastic/elasticsearch/issues/113529
- class: org.elasticsearch.xpack.ml.integration.MlJobIT
method: testCantCreateJobWithSameID
issue: https://github.com/elastic/elasticsearch/issues/113581
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,9 @@ default void init(Client client) {}
* @param modelId Model Id
* @param taskType The model task type
* @param config Configuration options including the secrets
* @param platformArchitectures The Set of platform architectures (OS name and hardware architecture)
* the cluster nodes and models are running on.
* @param parsedModelListener A listener which will handle the resulting model or failure
*/
void parseRequestConfig(
String modelId,
TaskType taskType,
Map<String, Object> config,
Set<String> platformArchitectures,
ActionListener<Model> parsedModelListener
);
void parseRequestConfig(String modelId, TaskType taskType, Map<String, Object> config, ActionListener<Model> parsedModelListener);

/**
* Parse model configuration from {@code config map} from persisted storage and return the parsed {@link Model}. This requires that
Expand Down Expand Up @@ -155,17 +147,6 @@ default void putModel(Model modelVariant, ActionListener<Boolean> listener) {
listener.onResponse(true);
}

/**
* Checks if the modelId has been downloaded to the local Elasticsearch cluster using the trained models API
* The default action does nothing except acknowledge the request (false).
* Any internal services should Override this method.
* @param model
* @param listener The listener
*/
default void isModelDownloaded(Model model, ActionListener<Boolean> listener) {
listener.onResponse(false);
};

/**
* Optionally test the new model configuration in the inference service.
* This function should be called when the model is first created, the
Expand All @@ -188,14 +169,6 @@ default Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
return model;
}

/**
* Return true if this model is hosted in the local Elasticsearch cluster
* @return True if in cluster
*/
default boolean isInClusterService() {
return false;
}

/**
* Defines the version required across all clusters to use this service
* @return {@link TransportVersion} specifying the version
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
package org.elasticsearch.inference;

import org.elasticsearch.client.internal.Client;
import org.elasticsearch.threadpool.ThreadPool;

import java.util.List;

Expand All @@ -20,7 +21,7 @@ public interface InferenceServiceExtension {

List<Factory> getInferenceServiceFactories();

record InferenceServiceFactoryContext(Client client) {}
record InferenceServiceFactoryContext(Client client, ThreadPool threadPool) {}

interface Factory {
/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ protected String getTestRestCluster() {
@Override
protected Settings restClientSettings() {
String token = basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray()));
return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build();
return Settings.builder()
.put(ThreadContext.PREFIX + ".Authorization", token)
.put(CLIENT_SOCKET_TIMEOUT, "120s") // Long timeout for model download
.build();
}

static String mockSparseServiceModelConfig() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class TestDenseInferenceServiceExtension implements InferenceServiceExtension {
@Override
Expand Down Expand Up @@ -76,7 +75,6 @@ public void parseRequestConfig(
String modelId,
TaskType taskType,
Map<String, Object> config,
Set<String> platformArchitectures,
ActionListener<Model> parsedModelListener
) {
var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class TestRerankingServiceExtension implements InferenceServiceExtension {
@Override
Expand Down Expand Up @@ -67,7 +66,6 @@ public void parseRequestConfig(
String modelId,
TaskType taskType,
Map<String, Object> config,
Set<String> platformArchitectures,
ActionListener<Model> parsedModelListener
) {
var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class TestSparseInferenceServiceExtension implements InferenceServiceExtension {
@Override
Expand Down Expand Up @@ -70,7 +69,6 @@ public void parseRequestConfig(
String modelId,
TaskType taskType,
Map<String, Object> config,
Set<String> platformArchitectures,
ActionListener<Model> parsedModelListener
) {
var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ public void parseRequestConfig(
String modelId,
TaskType taskType,
Map<String, Object> config,
Set<String> platformArchitectures,
ActionListener<Model> parsedModelListener
) {
var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.reindex.ReindexPlugin;
import org.elasticsearch.test.ESSingleNodeTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.InferencePlugin;
Expand Down Expand Up @@ -117,7 +118,9 @@ public void testGetModel() throws Exception {

assertEquals(model.getConfigurations().getService(), modelHolder.get().service());

var elserService = new ElserInternalService(new InferenceServiceExtension.InferenceServiceFactoryContext(mock(Client.class)));
var elserService = new ElserInternalService(
new InferenceServiceExtension.InferenceServiceFactoryContext(mock(Client.class), mock(ThreadPool.class))
);
ElserInternalModel roundTripModel = elserService.parsePersistedConfigWithSecrets(
modelHolder.get().inferenceEntityId(),
modelHolder.get().taskType(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ public Collection<?> createComponents(PluginServices services) {
);
}

var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext(services.client());
var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext(services.client(), services.threadPool());
// This must be done after the HttpRequestSenderFactory is created so that the services can get the
// reference correctly
var registry = new InferenceServiceRegistry(inferenceServices, factoryContext);
Expand Down Expand Up @@ -299,15 +299,17 @@ public Collection<SystemIndexDescriptor> getSystemIndexDescriptors(Settings sett

@Override
public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settingsToUse) {
return List.of(
new ScalingExecutorBuilder(
UTILITY_THREAD_POOL_NAME,
0,
10,
TimeValue.timeValueMinutes(10),
false,
"xpack.inference.utility_thread_pool"
)
return List.of(inferenceUtilityExecutor(settings));
}

public static ExecutorBuilder<?> inferenceUtilityExecutor(Settings settings) {
return new ScalingExecutorBuilder(
UTILITY_THREAD_POOL_NAME,
0,
10,
TimeValue.timeValueMinutes(10),
false,
"xpack.inference.utility_thread_pool"
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,13 @@
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.action.support.master.TransportMasterNodeAction;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.block.ClusterBlockException;
import org.elasticsearch.cluster.block.ClusterBlockLevel;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.xcontent.XContentHelper;
Expand All @@ -38,17 +36,14 @@
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentUtils;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;

import java.io.IOException;
import java.util.Map;
import java.util.Set;

import static org.elasticsearch.core.Strings.format;

Expand Down Expand Up @@ -156,50 +151,20 @@ protected void masterOperation(
return;
}

if (service.get().isInClusterService()) {
// Find the cluster platform as the service may need that
// information when creating the model
MlPlatformArchitecturesUtil.getMlNodesArchitecturesSet(listener.delegateFailureAndWrap((delegate, architectures) -> {
if (architectures.isEmpty() && clusterIsInElasticCloud(clusterService.getClusterSettings())) {
parseAndStoreModel(
service.get(),
request.getInferenceEntityId(),
resolvedTaskType,
requestAsMap,
// In Elastic cloud ml nodes run on Linux x86
Set.of("linux-x86_64"),
delegate
);
} else {
// The architecture field could be an empty set, the individual services will need to handle that
parseAndStoreModel(
service.get(),
request.getInferenceEntityId(),
resolvedTaskType,
requestAsMap,
architectures,
delegate
);
}
}), client, threadPool.executor(InferencePlugin.UTILITY_THREAD_POOL_NAME));
} else {
// Not an in cluster service, it does not care about the cluster platform
parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, Set.of(), listener);
}
parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, listener);
}

private void parseAndStoreModel(
InferenceService service,
String inferenceEntityId,
TaskType taskType,
Map<String, Object> config,
Set<String> platformArchitectures,
ActionListener<PutInferenceModelAction.Response> listener
) {
ActionListener<Model> storeModelListener = listener.delegateFailureAndWrap(
(delegate, verifiedModel) -> modelRegistry.storeModel(
verifiedModel,
ActionListener.wrap(r -> putAndStartModel(service, verifiedModel, delegate), e -> {
ActionListener.wrap(r -> startInferenceEndpoint(service, verifiedModel, delegate), e -> {
if (e.getCause() instanceof StrictDynamicMappingException && e.getCause().getMessage().contains("chunking_settings")) {
delegate.onFailure(
new ElasticsearchStatusException(
Expand All @@ -223,36 +188,15 @@ private void parseAndStoreModel(
}
});

service.parseRequestConfig(inferenceEntityId, taskType, config, platformArchitectures, parsedModelListener);

service.parseRequestConfig(inferenceEntityId, taskType, config, parsedModelListener);
}

private void putAndStartModel(InferenceService service, Model model, ActionListener<PutInferenceModelAction.Response> finalListener) {
SubscribableListener.<Boolean>newForked(listener -> {
var errorCatchingListener = ActionListener.<Boolean>wrap(listener::onResponse, e -> { listener.onResponse(false); });
service.isModelDownloaded(model, errorCatchingListener);
}).<Boolean>andThen((listener, isDownloaded) -> {
if (isDownloaded == false) {
service.putModel(model, listener);
} else {
listener.onResponse(true);
}
}).<PutInferenceModelAction.Response>andThen((listener, modelDidPut) -> {
if (modelDidPut) {
if (skipValidationAndStart) {
listener.onResponse(new PutInferenceModelAction.Response(model.getConfigurations()));
} else {
service.start(
model,
listener.delegateFailureAndWrap(
(l3, ok) -> l3.onResponse(new PutInferenceModelAction.Response(model.getConfigurations()))
)
);
}
} else {
logger.warn("Failed to put model [{}]", model.getInferenceEntityId());
}
}).addListener(finalListener);
private void startInferenceEndpoint(InferenceService service, Model model, ActionListener<PutInferenceModelAction.Response> listener) {
if (skipValidationAndStart) {
listener.onResponse(new PutInferenceModelAction.Response(model.getConfigurations()));
} else {
service.start(model, listener.map(started -> new PutInferenceModelAction.Response(model.getConfigurations())));
}
}

private Map<String, Object> requestToMap(PutInferenceModelAction.Request request) throws IOException {
Expand All @@ -276,12 +220,6 @@ protected ClusterBlockException checkBlock(PutInferenceModelAction.Request reque
return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE);
}

static boolean clusterIsInElasticCloud(ClusterSettings settings) {
// use a heuristic to determine if in Elastic cloud.
// One such heuristic is where USE_AUTO_MACHINE_MEMORY_PERCENT == true
return settings.get(MachineLearningField.USE_AUTO_MACHINE_MEMORY_PERCENT);
}

/**
* task_type can be specified as either a URL parameter or in the
* request body. Resolve which to use or throw if the settings are
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@

import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.elasticsearch.xpack.core.inference.action.InferenceAction.Request.DEFAULT_TIMEOUT;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
Expand All @@ -69,7 +68,6 @@ public void parseRequestConfig(
String inferenceEntityId,
TaskType taskType,
Map<String, Object> config,
Set<String> platformArchitectures,
ActionListener<Model> parsedModelListener
) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.elasticsearch.TransportVersions.ML_INFERENCE_AMAZON_BEDROCK_ADDED;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
Expand Down Expand Up @@ -121,7 +120,6 @@ public void parseRequestConfig(
String modelId,
TaskType taskType,
Map<String, Object> config,
Set<String> platformArchitectures,
ActionListener<Model> parsedModelListener
) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@

import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
Expand All @@ -58,7 +57,6 @@ public void parseRequestConfig(
String inferenceEntityId,
TaskType taskType,
Map<String, Object> config,
Set<String> platformArchitectures,
ActionListener<Model> parsedModelListener
) {
try {
Expand Down
Loading

0 comments on commit 071e7ce

Please sign in to comment.