Skip to content

Commit

Permalink
[ML] improve trained model stats API performance (#87978)
Browse files Browse the repository at this point in the history
Previously, the get trained model stats API would build every pipeline defined in cluster state.

This is problematic when MANY pipelines are defined. Especially if those pipelines take some time to parse (consider GROK).

This improvement is part of fixing: #87931
  • Loading branch information
benwtrent authored Jun 28, 2022
1 parent 87bc952 commit 6847c0b
Show file tree
Hide file tree
Showing 9 changed files with 387 additions and 300 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/87978.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 87978
summary: Improve trained model stats API performance
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import static org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor.Factory.countNumberInferenceProcessors;
import static org.elasticsearch.xpack.ml.integration.ClassificationIT.KEYWORD_FIELD;
import static org.elasticsearch.xpack.ml.integration.MlNativeDataFrameAnalyticsIntegTestCase.buildAnalytics;
import static org.elasticsearch.xpack.ml.integration.PyTorchModelIT.BASE_64_ENCODED_MODEL;
Expand All @@ -56,6 +55,7 @@
import static org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase.createScheduledJob;
import static org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase.getDataCounts;
import static org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase.indexDocs;
import static org.elasticsearch.xpack.ml.utils.InferenceProcessorInfoExtractor.countInferenceProcessors;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
Expand Down Expand Up @@ -132,9 +132,7 @@ public void testMLFeatureReset() throws Exception {
client().execute(DeletePipelineAction.INSTANCE, new DeletePipelineRequest("feature_reset_inference_pipeline")).actionGet();
createdPipelines.remove("feature_reset_inference_pipeline");

assertBusy(
() -> assertThat(countNumberInferenceProcessors(client().admin().cluster().prepareState().get().getState()), equalTo(0))
);
assertBusy(() -> assertThat(countInferenceProcessors(client().admin().cluster().prepareState().get().getState()), equalTo(0)));
client().execute(ResetFeatureStateAction.INSTANCE, new ResetFeatureStateRequest()).actionGet();
assertBusy(() -> {
List<String> indices = Arrays.asList(client().admin().indices().prepareGetIndex().addIndices(".ml*").get().indices());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndexFields.RESULTS_INDEX_PREFIX;
import static org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndexFields.STATE_INDEX_PREFIX;
import static org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor.Factory.countNumberInferenceProcessors;
import static org.elasticsearch.xpack.ml.utils.InferenceProcessorInfoExtractor.countInferenceProcessors;

public class MachineLearning extends Plugin
implements
Expand Down Expand Up @@ -1910,7 +1910,7 @@ public void cleanUpFeature(

// validate no pipelines are using machine learning models
ActionListener<AcknowledgedResponse> afterResetModeSet = ActionListener.wrap(acknowledgedResponse -> {
int numberInferenceProcessors = countNumberInferenceProcessors(clusterService.state());
int numberInferenceProcessors = countInferenceProcessors(clusterService.state());
if (numberInferenceProcessors > 0) {
unsetResetModeListener.onFailure(
new RuntimeException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
*/
package org.elasticsearch.xpack.ml.action;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.admin.cluster.node.stats.NodeStats;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsAction;
Expand All @@ -27,10 +26,7 @@
import org.elasticsearch.core.Tuple;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.ingest.IngestMetadata;
import org.elasticsearch.ingest.IngestService;
import org.elasticsearch.ingest.IngestStats;
import org.elasticsearch.ingest.Pipeline;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.tasks.Task;
Expand All @@ -47,15 +43,13 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelSizeStats;
import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand All @@ -65,29 +59,27 @@

import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
import static org.elasticsearch.xpack.ml.utils.InferenceProcessorInfoExtractor.pipelineIdsByModelIdsOrAliases;

public class TransportGetTrainedModelsStatsAction extends HandledTransportAction<
GetTrainedModelsStatsAction.Request,
GetTrainedModelsStatsAction.Response> {

private final Client client;
private final ClusterService clusterService;
private final IngestService ingestService;
private final TrainedModelProvider trainedModelProvider;

/**
 * Constructor wired up by the injection framework ({@code @Inject}).
 * Registers this handler for {@link GetTrainedModelsStatsAction} requests and
 * stores the services needed to gather model stats: the client for follow-up
 * requests, the cluster service for cluster-state access, the ingest service
 * for pipeline parsing, and the provider for trained model lookups.
 */
@Inject
public TransportGetTrainedModelsStatsAction(
    TransportService transportService,
    ActionFilters actionFilters,
    ClusterService clusterService,
    IngestService ingestService,
    TrainedModelProvider trainedModelProvider,
    Client client
) {
    super(GetTrainedModelsStatsAction.NAME, transportService, actionFilters, GetTrainedModelsStatsAction.Request::new);
    this.client = client;
    this.clusterService = clusterService;
    this.ingestService = ingestService;
    this.trainedModelProvider = trainedModelProvider;
}

Expand Down Expand Up @@ -136,7 +128,6 @@ protected void doExecute(
.collect(Collectors.toSet());
Map<String, Set<String>> pipelineIdsByModelIdsOrAliases = pipelineIdsByModelIdsOrAliases(
clusterService.state(),
ingestService,
allPossiblePipelineReferences
);
Map<String, IngestStats> modelIdIngestStats = inferenceIngestStatsByModelId(
Expand Down Expand Up @@ -270,37 +261,6 @@ static String[] ingestNodes(final ClusterState clusterState) {
return clusterState.nodes().getIngestNodes().keySet().toArray(String[]::new);
}

/**
 * Maps each referenced model id (or model alias) to the set of ingest pipeline ids
 * whose inference processors use it.
 *
 * NOTE(review): this fully parses and constructs every pipeline in the cluster state,
 * which can be expensive when many pipelines are defined.
 *
 * @param state         current cluster state supplying the ingest metadata
 * @param ingestService supplies processor factories and the script service for pipeline parsing
 * @param modelIds      model ids/aliases to look for
 * @return model id/alias to referencing pipeline ids; empty map when no ingest metadata exists
 */
static Map<String, Set<String>> pipelineIdsByModelIdsOrAliases(ClusterState state, IngestService ingestService, Set<String> modelIds) {
    Map<String, Set<String>> result = new HashMap<>();
    IngestMetadata ingestMetadata = state.metadata().custom(IngestMetadata.TYPE);
    if (ingestMetadata == null) {
        return result;
    }

    for (var pipelineEntry : ingestMetadata.getPipelines().entrySet()) {
        String pipelineId = pipelineEntry.getKey();
        try {
            Pipeline parsedPipeline = Pipeline.create(
                pipelineId,
                pipelineEntry.getValue().getConfigAsMap(),
                ingestService.getProcessorFactories(),
                ingestService.getScriptService()
            );
            for (var processor : parsedPipeline.getProcessors()) {
                // Only record pipelines whose inference processors reference one of the requested models.
                if (processor instanceof InferenceProcessor inferenceProcessor
                    && modelIds.contains(inferenceProcessor.getModelId())) {
                    result.computeIfAbsent(inferenceProcessor.getModelId(), m -> new LinkedHashSet<>()).add(pipelineId);
                }
            }
        } catch (Exception ex) {
            throw new ElasticsearchException("unexpected failure gathering pipeline information", ex);
        }
    }

    return result;
}

static IngestStats ingestStatsForPipelineIds(NodeStats nodeStats, Set<String> pipelineIds) {
IngestStats fullNodeStats = nodeStats.getIngestStats();
Map<String, List<IngestStats.ProcessorStat>> filteredProcessorStats = new HashMap<>(fullNodeStats.getProcessorStats());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,13 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.ingest.AbstractProcessor;
import org.elasticsearch.ingest.ConfigurationUtils;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.ingest.IngestMetadata;
import org.elasticsearch.ingest.Pipeline;
import org.elasticsearch.ingest.PipelineConfiguration;
import org.elasticsearch.ingest.Processor;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
Expand Down Expand Up @@ -55,6 +51,7 @@
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
import org.elasticsearch.xpack.ml.utils.InferenceProcessorInfoExtractor;

import java.util.Collections;
import java.util.HashMap;
Expand All @@ -65,7 +62,6 @@
import java.util.function.Consumer;

import static org.elasticsearch.ingest.IngestDocument.INGEST_KEY;
import static org.elasticsearch.ingest.Pipeline.PROCESSORS_KEY;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.MODEL_ID_RESULTS_FIELD;
Expand Down Expand Up @@ -192,9 +188,6 @@ public String getType() {

public static final class Factory implements Processor.Factory, Consumer<ClusterState> {

private static final String FOREACH_PROCESSOR_NAME = "foreach";
// Any more than 10 nestings of processors, we stop searching for inference processor definitions
private static final int MAX_INFERENCE_PROCESSOR_SEARCH_RECURSIONS = 10;
private static final Logger logger = LogManager.getLogger(Factory.class);

private final Client client;
Expand All @@ -213,86 +206,12 @@ public Factory(Client client, ClusterService clusterService, Settings settings)
/**
 * Cluster-state listener hook: refreshes the cached minimum node version and the
 * current count of inference processors whenever cluster state changes.
 */
@Override
public void accept(ClusterState state) {
    minNodeVersion = state.nodes().getMinNodeVersion();
    currentInferenceProcessors = countNumberInferenceProcessors(state);
}

public static int countNumberInferenceProcessors(ClusterState state) {
Metadata metadata = state.getMetadata();
if (metadata == null) {
return 0;
}
IngestMetadata ingestMetadata = metadata.custom(IngestMetadata.TYPE);
if (ingestMetadata == null) {
return 0;
}

int count = 0;
for (PipelineConfiguration configuration : ingestMetadata.getPipelines().values()) {
Map<String, Object> configMap = configuration.getConfigAsMap();
try {
List<Map<String, Object>> processorConfigs = ConfigurationUtils.readList(null, null, configMap, PROCESSORS_KEY);
for (Map<String, Object> processorConfigWithKey : processorConfigs) {
for (Map.Entry<String, Object> entry : processorConfigWithKey.entrySet()) {
count += numInferenceProcessors(entry.getKey(), entry.getValue());
}
}
// We cannot throw any exception here. It might break other pipelines.
} catch (Exception ex) {
logger.debug(() -> "failed gathering processors for pipeline [" + configuration.getId() + "]", ex);
}
try {
currentInferenceProcessors = InferenceProcessorInfoExtractor.countInferenceProcessors(state);
} catch (Exception ex) {
// We cannot throw any exception here. It might break other pipelines.
logger.debug("failed gathering processors for pipelines", ex);
}
return count;
}

/**
 * Counts inference processors within a single processor config entry.
 * Delegates to the recursive overload, starting at nesting level 0.
 * The unchecked cast mirrors how ingest pipeline configs are stored (maps of maps).
 */
@SuppressWarnings("unchecked")
static int numInferenceProcessors(String processorType, Object processorDefinition) {
    Map<String, Object> definitionMap = (Map<String, Object>) processorDefinition;
    return numInferenceProcessors(processorType, definitionMap, 0);
}

/**
 * Recursively counts inference processor definitions inside a single processor configuration.
 * Descends into {@code foreach} wrappers and {@code on_failure} handlers, incrementing the
 * nesting level on each step; recursion stops past MAX_INFERENCE_PROCESSOR_SEARCH_RECURSIONS.
 *
 * @param processorType       the processor's type key (e.g. "inference", "foreach"); may be null
 * @param processorDefinition the processor's config map; may be null
 * @param level               current nesting depth
 * @return number of inference processors found in this definition subtree
 */
@SuppressWarnings("unchecked")
static int numInferenceProcessors(String processorType, Map<String, Object> processorDefinition, int level) {
    // Bail out on malformed entries, or once the (arbitrary, but necessary) recursion cap is hit.
    if (level > MAX_INFERENCE_PROCESSOR_SEARCH_RECURSIONS || processorType == null || processorDefinition == null) {
        return 0;
    }
    int total = TYPE.equals(processorType) ? 1 : 0;
    if (FOREACH_PROCESSOR_NAME.equals(processorType)) {
        Map<String, Object> wrapped = (Map<String, Object>) processorDefinition.get("processor");
        if (wrapped != null) {
            // a foreach processor should only have a SINGLE nested processor. Iteration is for simplicity's sake.
            for (Map.Entry<String, Object> nested : wrapped.entrySet()) {
                total += numInferenceProcessors(nested.getKey(), (Map<String, Object>) nested.getValue(), level + 1);
            }
        }
    }
    if (processorDefinition.containsKey(Pipeline.ON_FAILURE_KEY)) {
        List<Map<String, Object>> onFailureConfigs = ConfigurationUtils.readList(
            null,
            null,
            processorDefinition,
            Pipeline.ON_FAILURE_KEY
        );
        for (Map<String, Object> handler : onFailureConfigs) {
            for (Map.Entry<String, Object> entry : handler.entrySet()) {
                total += numInferenceProcessors(entry.getKey(), (Map<String, Object>) entry.getValue(), level + 1);
            }
        }
    }
    return total;
}

/**
 * Returns the most recently cached count of inference processors in the cluster,
 * as updated by {@link #accept(ClusterState)}. Package-private; used for testing.
 */
int numInferenceProcessors() {
    return currentInferenceProcessors;
}

@Override
Expand Down
Loading

0 comments on commit 6847c0b

Please sign in to comment.