Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] improve trained model stats API performance #87978

Merged
5 changes: 5 additions & 0 deletions docs/changelog/87978.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 87978
summary: Improve trained model stats API performance
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import static org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor.Factory.countNumberInferenceProcessors;
import static org.elasticsearch.xpack.ml.integration.ClassificationIT.KEYWORD_FIELD;
import static org.elasticsearch.xpack.ml.integration.MlNativeDataFrameAnalyticsIntegTestCase.buildAnalytics;
import static org.elasticsearch.xpack.ml.integration.PyTorchModelIT.BASE_64_ENCODED_MODEL;
Expand All @@ -56,6 +55,7 @@
import static org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase.createScheduledJob;
import static org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase.getDataCounts;
import static org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase.indexDocs;
import static org.elasticsearch.xpack.ml.utils.InferenceProcessorInfoExtractor.countInferenceProcessors;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
Expand Down Expand Up @@ -132,9 +132,7 @@ public void testMLFeatureReset() throws Exception {
client().execute(DeletePipelineAction.INSTANCE, new DeletePipelineRequest("feature_reset_inference_pipeline")).actionGet();
createdPipelines.remove("feature_reset_inference_pipeline");

assertBusy(
() -> assertThat(countNumberInferenceProcessors(client().admin().cluster().prepareState().get().getState()), equalTo(0))
);
assertBusy(() -> assertThat(countInferenceProcessors(client().admin().cluster().prepareState().get().getState()), equalTo(0)));
client().execute(ResetFeatureStateAction.INSTANCE, new ResetFeatureStateRequest()).actionGet();
assertBusy(() -> {
List<String> indices = Arrays.asList(client().admin().indices().prepareGetIndex().addIndices(".ml*").get().indices());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndexFields.RESULTS_INDEX_PREFIX;
import static org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndexFields.STATE_INDEX_PREFIX;
import static org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor.Factory.countNumberInferenceProcessors;
import static org.elasticsearch.xpack.ml.utils.InferenceProcessorInfoExtractor.countInferenceProcessors;

public class MachineLearning extends Plugin
implements
Expand Down Expand Up @@ -1910,7 +1910,7 @@ public void cleanUpFeature(

// validate no pipelines are using machine learning models
ActionListener<AcknowledgedResponse> afterResetModeSet = ActionListener.wrap(acknowledgedResponse -> {
int numberInferenceProcessors = countNumberInferenceProcessors(clusterService.state());
int numberInferenceProcessors = countInferenceProcessors(clusterService.state());
if (numberInferenceProcessors > 0) {
unsetResetModeListener.onFailure(
new RuntimeException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
*/
package org.elasticsearch.xpack.ml.action;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.admin.cluster.node.stats.NodeStats;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsAction;
Expand All @@ -27,10 +26,7 @@
import org.elasticsearch.core.Tuple;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.ingest.IngestMetadata;
import org.elasticsearch.ingest.IngestService;
import org.elasticsearch.ingest.IngestStats;
import org.elasticsearch.ingest.Pipeline;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.tasks.Task;
Expand All @@ -47,15 +43,13 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelSizeStats;
import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand All @@ -65,29 +59,27 @@

import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
import static org.elasticsearch.xpack.ml.utils.InferenceProcessorInfoExtractor.pipelineIdsByModelIdsOrAliases;

public class TransportGetTrainedModelsStatsAction extends HandledTransportAction<
GetTrainedModelsStatsAction.Request,
GetTrainedModelsStatsAction.Response> {

private final Client client;
private final ClusterService clusterService;
private final IngestService ingestService;
private final TrainedModelProvider trainedModelProvider;

/**
 * Builds the transport action behind the trained model stats API.
 *
 * The action name, transport service, action filters and request reader are handed to the
 * superclass; the remaining collaborators are kept in fields.
 *
 * @param transportService     forwarded to the superclass constructor
 * @param actionFilters        forwarded to the superclass constructor
 * @param clusterService       stored in {@code this.clusterService}
 * @param ingestService        stored in {@code this.ingestService}
 * @param trainedModelProvider stored in {@code this.trainedModelProvider}
 * @param client               stored in {@code this.client}
 */
@Inject
public TransportGetTrainedModelsStatsAction(
    TransportService transportService,
    ActionFilters actionFilters,
    ClusterService clusterService,
    IngestService ingestService,
    TrainedModelProvider trainedModelProvider,
    Client client
) {
    super(GetTrainedModelsStatsAction.NAME, transportService, actionFilters, GetTrainedModelsStatsAction.Request::new);
    this.clusterService = clusterService;
    this.ingestService = ingestService;
    this.trainedModelProvider = trainedModelProvider;
    this.client = client;
}

Expand Down Expand Up @@ -136,7 +128,6 @@ protected void doExecute(
.collect(Collectors.toSet());
Map<String, Set<String>> pipelineIdsByModelIdsOrAliases = pipelineIdsByModelIdsOrAliases(
clusterService.state(),
ingestService,
allPossiblePipelineReferences
);
Map<String, IngestStats> modelIdIngestStats = inferenceIngestStatsByModelId(
Expand Down Expand Up @@ -270,37 +261,6 @@ static String[] ingestNodes(final ClusterState clusterState) {
return clusterState.nodes().getIngestNodes().keySet().toArray(String[]::new);
}

/**
 * Maps each requested model ID or alias to the IDs of the ingest pipelines that reference it
 * through an {@code InferenceProcessor}.
 *
 * Every pipeline configuration found in the cluster state's ingest metadata is instantiated
 * with {@code Pipeline.create}, and the processors returned by {@code getProcessors()} are
 * inspected.
 *
 * @param state         cluster state from which the ingest metadata is read
 * @param ingestService supplies the processor factories and script service used to build each pipeline
 * @param modelIds      model IDs or aliases of interest; processors referencing any other ID are ignored
 * @return a map from model ID or alias to an insertion-ordered set of pipeline IDs; empty when the
 *         cluster state carries no ingest metadata
 * @throws ElasticsearchException wrapping any exception raised while building or inspecting a pipeline
 */
static Map<String, Set<String>> pipelineIdsByModelIdsOrAliases(ClusterState state, IngestService ingestService, Set<String> modelIds) {
    Map<String, Set<String>> pipelinesPerModel = new HashMap<>();
    IngestMetadata ingestMetadata = state.metadata().custom(IngestMetadata.TYPE);
    if (ingestMetadata == null) {
        return pipelinesPerModel;
    }

    for (var pipelineEntry : ingestMetadata.getPipelines().entrySet()) {
        String pipelineId = pipelineEntry.getKey();
        try {
            Pipeline pipeline = Pipeline.create(
                pipelineId,
                pipelineEntry.getValue().getConfigAsMap(),
                ingestService.getProcessorFactories(),
                ingestService.getScriptService()
            );
            for (var processor : pipeline.getProcessors()) {
                // Only inference processors that point at one of the requested models are recorded.
                if (processor instanceof InferenceProcessor inferenceProcessor && modelIds.contains(inferenceProcessor.getModelId())) {
                    pipelinesPerModel.computeIfAbsent(inferenceProcessor.getModelId(), ignored -> new LinkedHashSet<>())
                        .add(pipelineId);
                }
            }
        } catch (Exception ex) {
            throw new ElasticsearchException("unexpected failure gathering pipeline information", ex);
        }
    }

    return pipelinesPerModel;
}

static IngestStats ingestStatsForPipelineIds(NodeStats nodeStats, Set<String> pipelineIds) {
IngestStats fullNodeStats = nodeStats.getIngestStats();
Map<String, List<IngestStats.ProcessorStat>> filteredProcessorStats = new HashMap<>(fullNodeStats.getProcessorStats());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,13 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.ingest.AbstractProcessor;
import org.elasticsearch.ingest.ConfigurationUtils;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.ingest.IngestMetadata;
import org.elasticsearch.ingest.Pipeline;
import org.elasticsearch.ingest.PipelineConfiguration;
import org.elasticsearch.ingest.Processor;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
Expand Down Expand Up @@ -55,6 +51,7 @@
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
import org.elasticsearch.xpack.ml.utils.InferenceProcessorInfoExtractor;

import java.util.Collections;
import java.util.HashMap;
Expand All @@ -65,7 +62,6 @@
import java.util.function.Consumer;

import static org.elasticsearch.ingest.IngestDocument.INGEST_KEY;
import static org.elasticsearch.ingest.Pipeline.PROCESSORS_KEY;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.MODEL_ID_RESULTS_FIELD;
Expand Down Expand Up @@ -192,9 +188,6 @@ public String getType() {

public static final class Factory implements Processor.Factory, Consumer<ClusterState> {

private static final String FOREACH_PROCESSOR_NAME = "foreach";
// Any more than 10 nestings of processors, we stop searching for inference processor definitions
private static final int MAX_INFERENCE_PROCESSOR_SEARCH_RECURSIONS = 10;
private static final Logger logger = LogManager.getLogger(Factory.class);

private final Client client;
Expand All @@ -213,86 +206,12 @@ public Factory(Client client, ClusterService clusterService, Settings settings)
/**
 * Refreshes the values this factory caches from the supplied cluster state: the minimum node
 * version and the number of inference processors.
 *
 * @param state the cluster state to read from
 */
@Override
public void accept(ClusterState state) {
minNodeVersion = state.nodes().getMinNodeVersion();
// Recount the inference processors for every cluster state handed to this consumer.
currentInferenceProcessors = countNumberInferenceProcessors(state);
}

public static int countNumberInferenceProcessors(ClusterState state) {
Metadata metadata = state.getMetadata();
if (metadata == null) {
return 0;
}
IngestMetadata ingestMetadata = metadata.custom(IngestMetadata.TYPE);
if (ingestMetadata == null) {
return 0;
}

int count = 0;
for (PipelineConfiguration configuration : ingestMetadata.getPipelines().values()) {
Map<String, Object> configMap = configuration.getConfigAsMap();
try {
List<Map<String, Object>> processorConfigs = ConfigurationUtils.readList(null, null, configMap, PROCESSORS_KEY);
for (Map<String, Object> processorConfigWithKey : processorConfigs) {
for (Map.Entry<String, Object> entry : processorConfigWithKey.entrySet()) {
count += numInferenceProcessors(entry.getKey(), entry.getValue());
}
}
// We cannot throw any exception here. It might break other pipelines.
} catch (Exception ex) {
logger.debug(() -> "failed gathering processors for pipeline [" + configuration.getId() + "]", ex);
}
try {
currentInferenceProcessors = InferenceProcessorInfoExtractor.countInferenceProcessors(state);
} catch (Exception ex) {
// We cannot throw any exception here. It might break other pipelines.
logger.debug("failed gathering processors for pipelines", ex);
benwtrent marked this conversation as resolved.
Show resolved Hide resolved
}
return count;
}

/**
 * Counts the inference processors described by a single processor definition, starting the
 * nested search at level 0.
 *
 * @param processorType       the processor's type name (the key of its configuration entry)
 * @param processorDefinition the processor's configuration; cast to a {@code Map<String, Object>}
 * @return the count produced by the three-argument overload
 */
@SuppressWarnings("unchecked")
static int numInferenceProcessors(String processorType, Object processorDefinition) {
    Map<String, Object> definition = (Map<String, Object>) processorDefinition;
    return numInferenceProcessors(processorType, definition, 0);
}

/**
 * Recursively counts inference processors in a processor definition.
 *
 * The processor itself is counted when its type equals {@code TYPE}. The search then descends
 * into the {@code "processor"} entry of a {@code foreach} processor and into every processor
 * listed under the definition's on-failure key.
 *
 * @param processorType       the processor's type name; {@code null} yields 0
 * @param processorDefinition the processor's configuration map; {@code null} yields 0
 * @param level               current nesting depth; once it exceeds
 *                            {@code MAX_INFERENCE_PROCESSOR_SEARCH_RECURSIONS} the search stops and 0 is returned
 * @return the number of inference processors found at this level and below
 */
@SuppressWarnings("unchecked")
static int numInferenceProcessors(String processorType, Map<String, Object> processorDefinition, int level) {
    // The depth cap is arbitrary, but the recursion has to be bounded somehow.
    if (level > MAX_INFERENCE_PROCESSOR_SEARCH_RECURSIONS || processorType == null || processorDefinition == null) {
        return 0;
    }

    int total = TYPE.equals(processorType) ? 1 : 0;

    if (FOREACH_PROCESSOR_NAME.equals(processorType)) {
        Map<String, Object> nested = (Map<String, Object>) processorDefinition.get("processor");
        if (nested != null) {
            // A foreach processor should hold a SINGLE nested processor; streaming the entries is just simpler.
            total += nested.entrySet()
                .stream()
                .mapToInt(inner -> numInferenceProcessors(inner.getKey(), (Map<String, Object>) inner.getValue(), level + 1))
                .sum();
        }
    }

    if (processorDefinition.containsKey(Pipeline.ON_FAILURE_KEY)) {
        List<Map<String, Object>> failureHandlers = ConfigurationUtils.readList(
            null,
            null,
            processorDefinition,
            Pipeline.ON_FAILURE_KEY
        );
        for (Map<String, Object> handler : failureHandlers) {
            for (Map.Entry<String, Object> handlerEntry : handler.entrySet()) {
                total += numInferenceProcessors(handlerEntry.getKey(), (Map<String, Object>) handlerEntry.getValue(), level + 1);
            }
        }
    }

    return total;
}

/**
 * Returns the inference processor count currently held in {@code currentInferenceProcessors}.
 *
 * @return the cached inference processor count
 */
// Used for testing
int numInferenceProcessors() {
return currentInferenceProcessors;
}

@Override
Expand Down
Loading