From 86627ce90591b12e2c75171f958daa346e2ad67e Mon Sep 17 00:00:00 2001
From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com>
Date: Thu, 23 Jun 2022 10:59:03 -0400
Subject: [PATCH 1/3] [ML] improve trained model stats API performance

---
 .../ml/integration/TestFeatureResetIT.java    |   6 +-
 .../xpack/ml/MachineLearning.java             |   4 +-
 .../TransportGetTrainedModelsStatsAction.java |  42 +---
 .../inference/ingest/InferenceProcessor.java  |  92 +-------
 .../InferenceProcessorInfoExtractor.java      | 158 +++++++++++++
 ...sportGetTrainedModelsStatsActionTests.java |  96 --------
 .../InferenceProcessorFactoryTests.java       |  70 ------
 .../InferenceProcessorInfoExtractorTests.java | 212 ++++++++++++++++++
 8 files changed, 380 insertions(+), 300 deletions(-)
 create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/InferenceProcessorInfoExtractor.java
 create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/InferenceProcessorInfoExtractorTests.java

diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureResetIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureResetIT.java
index d2f0337dd919..d0b9aa1beed9 100644
--- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureResetIT.java
+++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureResetIT.java
@@ -47,7 +47,6 @@
 import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
 
-import static org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor.Factory.countNumberInferenceProcessors;
 import static org.elasticsearch.xpack.ml.integration.ClassificationIT.KEYWORD_FIELD;
 import static org.elasticsearch.xpack.ml.integration.MlNativeDataFrameAnalyticsIntegTestCase.buildAnalytics;
 import static org.elasticsearch.xpack.ml.integration.PyTorchModelIT.BASE_64_ENCODED_MODEL;
@@ -56,6 +55,7 @@
 import static org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase.createScheduledJob;
 import static org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase.getDataCounts;
 import static org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase.indexDocs;
+import static org.elasticsearch.xpack.ml.utils.InferenceProcessorInfoExtractor.countInferenceProcessors;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.equalTo;
@@ -132,9 +132,7 @@ public void testMLFeatureReset() throws Exception {
         client().execute(DeletePipelineAction.INSTANCE, new DeletePipelineRequest("feature_reset_inference_pipeline")).actionGet();
         createdPipelines.remove("feature_reset_inference_pipeline");
 
-        assertBusy(
-            () -> assertThat(countNumberInferenceProcessors(client().admin().cluster().prepareState().get().getState()), equalTo(0))
-        );
+        assertBusy(() -> assertThat(countInferenceProcessors(client().admin().cluster().prepareState().get().getState()), equalTo(0)));
         client().execute(ResetFeatureStateAction.INSTANCE, new ResetFeatureStateRequest()).actionGet();
         assertBusy(() -> {
             List indices = Arrays.asList(client().admin().indices().prepareGetIndex().addIndices(".ml*").get().indices());
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
index 4a6b17d1de71..d44c03e21410
100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -445,7 +445,7 @@ import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndexFields.RESULTS_INDEX_PREFIX; import static org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndexFields.STATE_INDEX_PREFIX; -import static org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor.Factory.countNumberInferenceProcessors; +import static org.elasticsearch.xpack.ml.utils.InferenceProcessorInfoExtractor.countInferenceProcessors; public class MachineLearning extends Plugin implements @@ -1910,7 +1910,7 @@ public void cleanUpFeature( // validate no pipelines are using machine learning models ActionListener afterResetModeSet = ActionListener.wrap(acknowledgedResponse -> { - int numberInferenceProcessors = countNumberInferenceProcessors(clusterService.state()); + int numberInferenceProcessors = countInferenceProcessors(clusterService.state()); if (numberInferenceProcessors > 0) { unsetResetModeListener.onFailure( new RuntimeException( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java index 8d62eedc0ce9..1cae10309f53 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java @@ -6,7 +6,6 @@ */ package org.elasticsearch.xpack.ml.action; -import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.admin.cluster.node.stats.NodeStats; import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsAction; @@ -27,10 +26,7 @@ import org.elasticsearch.core.Tuple; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; -import org.elasticsearch.ingest.IngestMetadata; -import org.elasticsearch.ingest.IngestService; import org.elasticsearch.ingest.IngestStats; -import org.elasticsearch.ingest.Pipeline; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.tasks.Task; @@ -46,7 +42,6 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelSizeStats; import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata; -import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; @@ -54,7 +49,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; -import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -64,6 +58,7 @@ import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; +import static org.elasticsearch.xpack.ml.utils.InferenceProcessorInfoExtractor.pipelineIdsByModelIdsOrAliases; public class TransportGetTrainedModelsStatsAction extends HandledTransportAction< 
GetTrainedModelsStatsAction.Request, @@ -71,7 +66,6 @@ public class TransportGetTrainedModelsStatsAction extends HandledTransportAction private final Client client; private final ClusterService clusterService; - private final IngestService ingestService; private final TrainedModelProvider trainedModelProvider; @Inject @@ -79,14 +73,12 @@ public TransportGetTrainedModelsStatsAction( TransportService transportService, ActionFilters actionFilters, ClusterService clusterService, - IngestService ingestService, TrainedModelProvider trainedModelProvider, Client client ) { super(GetTrainedModelsStatsAction.NAME, transportService, actionFilters, GetTrainedModelsStatsAction.Request::new); this.client = client; this.clusterService = clusterService; - this.ingestService = ingestService; this.trainedModelProvider = trainedModelProvider; } @@ -133,7 +125,6 @@ protected void doExecute( .collect(Collectors.toSet()); Map> pipelineIdsByModelIdsOrAliases = pipelineIdsByModelIdsOrAliases( clusterService.state(), - ingestService, allPossiblePipelineReferences ); Map modelIdIngestStats = inferenceIngestStatsByModelId( @@ -261,37 +252,6 @@ static String[] ingestNodes(final ClusterState clusterState) { return clusterState.nodes().getIngestNodes().keySet().toArray(String[]::new); } - static Map> pipelineIdsByModelIdsOrAliases(ClusterState state, IngestService ingestService, Set modelIds) { - IngestMetadata ingestMetadata = state.metadata().custom(IngestMetadata.TYPE); - Map> pipelineIdsByModelIds = new HashMap<>(); - if (ingestMetadata == null) { - return pipelineIdsByModelIds; - } - - ingestMetadata.getPipelines().forEach((pipelineId, pipelineConfiguration) -> { - try { - Pipeline pipeline = Pipeline.create( - pipelineId, - pipelineConfiguration.getConfigAsMap(), - ingestService.getProcessorFactories(), - ingestService.getScriptService() - ); - pipeline.getProcessors().forEach(processor -> { - if (processor instanceof InferenceProcessor inferenceProcessor) { - if (modelIds.contains(inferenceProcessor.getModelId())) { - pipelineIdsByModelIds.computeIfAbsent(inferenceProcessor.getModelId(), m -> new LinkedHashSet<>()) - .add(pipelineId); - } - } - }); - } catch (Exception ex) { - throw new ElasticsearchException("unexpected failure gathering pipeline information", ex); - } - }); - - return pipelineIdsByModelIds; - } - static IngestStats ingestStatsForPipelineIds(NodeStats nodeStats, Set pipelineIds) { IngestStats fullNodeStats = nodeStats.getIngestStats(); Map> filteredProcessorStats = new HashMap<>(fullNodeStats.getProcessorStats()); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java index 9a04d1cfbd85..cbcd041e15f2 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -14,7 +14,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; -import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; @@ -22,9 +21,6 @@ import org.elasticsearch.ingest.AbstractProcessor; import org.elasticsearch.ingest.ConfigurationUtils; import 
org.elasticsearch.ingest.IngestDocument; -import org.elasticsearch.ingest.IngestMetadata; -import org.elasticsearch.ingest.Pipeline; -import org.elasticsearch.ingest.PipelineConfiguration; import org.elasticsearch.ingest.Processor; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.ml.action.InferModelAction; @@ -55,6 +51,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel; import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; +import org.elasticsearch.xpack.ml.utils.InferenceProcessorInfoExtractor; import java.util.Collections; import java.util.HashMap; @@ -65,7 +62,6 @@ import java.util.function.Consumer; import static org.elasticsearch.ingest.IngestDocument.INGEST_KEY; -import static org.elasticsearch.ingest.Pipeline.PROCESSORS_KEY; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.MODEL_ID_RESULTS_FIELD; @@ -192,9 +188,6 @@ public String getType() { public static final class Factory implements Processor.Factory, Consumer { - private static final String FOREACH_PROCESSOR_NAME = "foreach"; - // Any more than 10 nestings of processors, we stop searching for inference processor definitions - private static final int MAX_INFERENCE_PROCESSOR_SEARCH_RECURSIONS = 10; private static final Logger logger = LogManager.getLogger(Factory.class); private final Client client; @@ -213,86 +206,11 @@ public Factory(Client client, ClusterService clusterService, Settings settings) @Override public void accept(ClusterState state) { minNodeVersion = state.nodes().getMinNodeVersion(); - currentInferenceProcessors = countNumberInferenceProcessors(state); - } - - public static int countNumberInferenceProcessors(ClusterState state) { - Metadata metadata = state.getMetadata(); - if (metadata == null) { - return 0; - } - IngestMetadata ingestMetadata = metadata.custom(IngestMetadata.TYPE); - if (ingestMetadata == null) { - return 0; - } - - int count = 0; - for (PipelineConfiguration configuration : ingestMetadata.getPipelines().values()) { - Map configMap = configuration.getConfigAsMap(); - try { - List> processorConfigs = ConfigurationUtils.readList(null, null, configMap, PROCESSORS_KEY); - for (Map processorConfigWithKey : processorConfigs) { - for (Map.Entry entry : processorConfigWithKey.entrySet()) { - count += numInferenceProcessors(entry.getKey(), entry.getValue()); - } - } - // We cannot throw any exception here. It might break other pipelines. 
- } catch (Exception ex) { - logger.debug(() -> "failed gathering processors for pipeline [" + configuration.getId() + "]", ex); - } + try { + currentInferenceProcessors = InferenceProcessorInfoExtractor.countInferenceProcessors(state); + } catch (Exception ex) { + logger.debug("failed gathering processors for pipelines", ex); } - return count; - } - - @SuppressWarnings("unchecked") - static int numInferenceProcessors(String processorType, Object processorDefinition) { - return numInferenceProcessors(processorType, (Map) processorDefinition, 0); - } - - @SuppressWarnings("unchecked") - static int numInferenceProcessors(String processorType, Map processorDefinition, int level) { - int count = 0; - // arbitrary, but we must limit this somehow - if (level > MAX_INFERENCE_PROCESSOR_SEARCH_RECURSIONS) { - return count; - } - if (processorType == null || processorDefinition == null) { - return count; - } - if (TYPE.equals(processorType)) { - count++; - } - if (FOREACH_PROCESSOR_NAME.equals(processorType)) { - Map innerProcessor = (Map) processorDefinition.get("processor"); - if (innerProcessor != null) { - // a foreach processor should only have a SINGLE nested processor. Iteration is for simplicity's sake. - for (Map.Entry innerProcessorWithName : innerProcessor.entrySet()) { - count += numInferenceProcessors( - innerProcessorWithName.getKey(), - (Map) innerProcessorWithName.getValue(), - level + 1 - ); - } - } - } - if (processorDefinition.containsKey(Pipeline.ON_FAILURE_KEY)) { - List> onFailureConfigs = ConfigurationUtils.readList( - null, - null, - processorDefinition, - Pipeline.ON_FAILURE_KEY - ); - count += onFailureConfigs.stream() - .flatMap(map -> map.entrySet().stream()) - .mapToInt(entry -> numInferenceProcessors(entry.getKey(), (Map) entry.getValue(), level + 1)) - .sum(); - } - return count; - } - - // Used for testing - int numInferenceProcessors() { - return currentInferenceProcessors; } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/InferenceProcessorInfoExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/InferenceProcessorInfoExtractor.java new file mode 100644 index 000000000000..447e6d35f7aa --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/InferenceProcessorInfoExtractor.java @@ -0,0 +1,158 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.utils; + +import org.apache.lucene.util.Counter; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.ingest.ConfigurationUtils; +import org.elasticsearch.ingest.IngestMetadata; +import org.elasticsearch.ingest.Pipeline; + +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Consumer; + +import static org.elasticsearch.ingest.Pipeline.PROCESSORS_KEY; +import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.MODEL_ID_RESULTS_FIELD; +import static org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor.TYPE; + +/** + * Utilities for extracting information around inference processors from IngestMetadata + */ +public final class InferenceProcessorInfoExtractor { + + private static final String FOREACH_PROCESSOR_NAME = "foreach"; + // Any more than 10 nestings of processors, we stop searching for inference processor definitions + private static final int MAX_INFERENCE_PROCESSOR_SEARCH_RECURSIONS = 10; + + private InferenceProcessorInfoExtractor() {} + + /** + * @param state The current cluster state + * @return The current count of inference processors + */ + @SuppressWarnings("unchecked") + public static int countInferenceProcessors(ClusterState state) { + Metadata metadata = state.getMetadata(); + if (metadata == null) { + return 0; + } + IngestMetadata ingestMetadata = metadata.custom(IngestMetadata.TYPE); + if (ingestMetadata == null) { + return 0; + } + Counter counter = Counter.newCounter(); + ingestMetadata.getPipelines().forEach((pipelineId, configuration) -> { + Map configMap = configuration.getConfigAsMap(); + List> processorConfigs = ConfigurationUtils.readList(null, null, configMap, PROCESSORS_KEY); + for (Map processorConfigWithKey : processorConfigs) { + for (Map.Entry entry : processorConfigWithKey.entrySet()) { + addModelsAndPipelines( + entry.getKey(), + pipelineId, + (Map) entry.getValue(), + pam -> counter.addAndGet(1), + 0 + ); + } + } + }); + return (int) counter.get(); + } + + /** + * @param state Current cluster state + * @return a map from Model IDs or Aliases to each pipeline referencing them. 
+ */ + @SuppressWarnings("unchecked") + public static Map> pipelineIdsByModelIdsOrAliases(ClusterState state, Set modelIds) { + Map> pipelineIdsByModelIds = new HashMap<>(); + Metadata metadata = state.metadata(); + if (metadata == null) { + return pipelineIdsByModelIds; + } + IngestMetadata ingestMetadata = metadata.custom(IngestMetadata.TYPE); + if (ingestMetadata == null) { + return pipelineIdsByModelIds; + } + ingestMetadata.getPipelines().forEach((pipelineId, configuration) -> { + Map configMap = configuration.getConfigAsMap(); + List> processorConfigs = ConfigurationUtils.readList(null, null, configMap, PROCESSORS_KEY); + for (Map processorConfigWithKey : processorConfigs) { + for (Map.Entry entry : processorConfigWithKey.entrySet()) { + addModelsAndPipelines(entry.getKey(), pipelineId, (Map) entry.getValue(), pam -> { + if (modelIds.contains(pam.modelIdOrAlias)) { + pipelineIdsByModelIds.computeIfAbsent(pam.modelIdOrAlias, m -> new LinkedHashSet<>()).add(pipelineId); + } + }, 0); + } + } + }); + return pipelineIdsByModelIds; + } + + @SuppressWarnings("unchecked") + private static void addModelsAndPipelines( + String processorType, + String pipelineId, + Map processorDefinition, + Consumer handler, + int level + ) { + // arbitrary, but we must limit this somehow + if (level > MAX_INFERENCE_PROCESSOR_SEARCH_RECURSIONS) { + return; + } + if (processorType == null || processorDefinition == null) { + return; + } + if (TYPE.equals(processorType)) { + String modelId = (String) processorDefinition.get(MODEL_ID_RESULTS_FIELD); + if (modelId != null) { + handler.accept(new PipelineAndModel(pipelineId, modelId)); + } + return; + } + if (FOREACH_PROCESSOR_NAME.equals(processorType)) { + Map innerProcessor = (Map) processorDefinition.get("processor"); + if (innerProcessor != null) { + // a foreach processor should only have a SINGLE nested processor. Iteration is for simplicity's sake. 
+ for (Map.Entry innerProcessorWithName : innerProcessor.entrySet()) { + addModelsAndPipelines( + innerProcessorWithName.getKey(), + pipelineId, + (Map) innerProcessorWithName.getValue(), + handler, + level + 1 + ); + } + } + return; + } + if (processorDefinition.containsKey(Pipeline.ON_FAILURE_KEY)) { + List> onFailureConfigs = ConfigurationUtils.readList( + null, + null, + processorDefinition, + Pipeline.ON_FAILURE_KEY + ); + onFailureConfigs.stream() + .flatMap(map -> map.entrySet().stream()) + .forEach( + entry -> addModelsAndPipelines(entry.getKey(), pipelineId, (Map) entry.getValue(), handler, level + 1) + ); + } + } + + private record PipelineAndModel(String pipelineId, String modelIdOrAlias) {} + +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java index 8708f0506c4b..7ee464a32e98 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java @@ -10,38 +10,27 @@ import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterName; -import org.elasticsearch.cluster.ClusterState; -import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.routing.OperationRouting; import org.elasticsearch.cluster.service.ClusterApplierService; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.cluster.service.MasterService; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.Maps; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.ingest.IngestDocument; -import org.elasticsearch.ingest.IngestMetadata; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.ingest.IngestStats; -import org.elasticsearch.ingest.PipelineConfiguration; import org.elasticsearch.ingest.Processor; import org.elasticsearch.license.MockLicenseState; import org.elasticsearch.plugins.IngestPlugin; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentFactory; -import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.ml.MachineLearningField; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.junit.Before; -import java.io.IOException; import java.time.Instant; import java.util.Arrays; import java.util.Collections; @@ -245,91 +234,6 @@ public void testInferenceIngestStatsByModelId() { assertThat(ingestStatsMap, hasEntry("trained_model_2", expectedStatsModel2)); } - public void testPipelineIdsByModelIds() throws IOException { - String modelId1 = "trained_model_1"; - String modelId2 = "trained_model_2"; - String modelId3 = "trained_model_3"; - Set modelIds = new HashSet<>(Arrays.asList(modelId1, modelId2, modelId3)); - - ClusterState 
clusterState = buildClusterStateWithModelReferences(modelId1, modelId2, modelId3); - - Map> pipelineIdsByModelIds = TransportGetTrainedModelsStatsAction.pipelineIdsByModelIdsOrAliases( - clusterState, - ingestService, - modelIds - ); - - assertThat(pipelineIdsByModelIds.keySet(), equalTo(modelIds)); - assertThat( - pipelineIdsByModelIds, - hasEntry(modelId1, new HashSet<>(Arrays.asList("pipeline_with_model_" + modelId1 + 0, "pipeline_with_model_" + modelId1 + 1))) - ); - assertThat( - pipelineIdsByModelIds, - hasEntry(modelId2, new HashSet<>(Arrays.asList("pipeline_with_model_" + modelId2 + 0, "pipeline_with_model_" + modelId2 + 1))) - ); - assertThat( - pipelineIdsByModelIds, - hasEntry(modelId3, new HashSet<>(Arrays.asList("pipeline_with_model_" + modelId3 + 0, "pipeline_with_model_" + modelId3 + 1))) - ); - - } - - private static ClusterState buildClusterStateWithModelReferences(String... modelId) throws IOException { - Map configurations = Maps.newMapWithExpectedSize(modelId.length); - for (String id : modelId) { - configurations.put("pipeline_with_model_" + id + 0, newConfigurationWithInferenceProcessor(id, 0)); - configurations.put("pipeline_with_model_" + id + 1, newConfigurationWithInferenceProcessor(id, 1)); - } - for (int i = 0; i < 3; i++) { - configurations.put("pipeline_without_model_" + i, newConfigurationWithOutInferenceProcessor(i)); - } - IngestMetadata ingestMetadata = new IngestMetadata(configurations); - - return ClusterState.builder(new ClusterName("_name")) - .metadata(Metadata.builder().putCustom(IngestMetadata.TYPE, ingestMetadata)) - .build(); - } - - private static PipelineConfiguration newConfigurationWithInferenceProcessor(String modelId, int num) throws IOException { - try ( - XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() - .map( - Collections.singletonMap( - "processors", - Collections.singletonList(Collections.singletonMap(InferenceProcessor.TYPE, new HashMap() { - { - put(InferenceResults.MODEL_ID_RESULTS_FIELD, modelId); - put("inference_config", Collections.singletonMap("regression", Collections.emptyMap())); - put("field_map", Collections.emptyMap()); - put("target_field", randomAlphaOfLength(10)); - } - })) - ) - ) - ) { - return new PipelineConfiguration( - "pipeline_with_model_" + modelId + num, - BytesReference.bytes(xContentBuilder), - XContentType.JSON - ); - } - } - - private static PipelineConfiguration newConfigurationWithOutInferenceProcessor(int i) throws IOException { - try ( - XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() - .map( - Collections.singletonMap( - "processors", - Collections.singletonList(Collections.singletonMap("not_inference", Collections.emptyMap())) - ) - ) - ) { - return new PipelineConfiguration("pipeline_without_model_" + i, BytesReference.bytes(xContentBuilder), XContentType.JSON); - } - } - private static NodeStats buildNodeStats( IngestStats.Stats overallStats, List pipelineNames, diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java index 6420c4e80dda..5398719f7832 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java @@ -84,76 +84,6 @@ public void setUpVariables() { clusterService = new ClusterService(settings, 
clusterSettings, tp); } - public void testNumInferenceProcessors() throws Exception { - Metadata metadata = null; - - InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, clusterService, Settings.EMPTY); - processorFactory.accept(buildClusterState(metadata)); - - assertThat(processorFactory.numInferenceProcessors(), equalTo(0)); - metadata = Metadata.builder().build(); - - processorFactory.accept(buildClusterState(metadata)); - assertThat(processorFactory.numInferenceProcessors(), equalTo(0)); - - processorFactory.accept(buildClusterStateWithModelReferences("model1", "model2", "model3")); - assertThat(processorFactory.numInferenceProcessors(), equalTo(3)); - } - - public void testNumInferenceProcessorsRecursivelyDefined() throws Exception { - Metadata metadata = null; - - InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, clusterService, Settings.EMPTY); - processorFactory.accept(buildClusterState(metadata)); - - Map configurations = new HashMap<>(); - configurations.put( - "pipeline_with_model_top_level", - randomBoolean() - ? newConfigurationWithInferenceProcessor("top_level") - : newConfigurationWithForeachProcessorProcessor("top_level") - ); - try ( - XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() - .map(Collections.singletonMap("processors", Collections.singletonList(Collections.singletonMap("set", new HashMap<>() { - { - put("field", "foo"); - put("value", "bar"); - put( - "on_failure", - Arrays.asList(inferenceProcessorForModel("second_level"), forEachProcessorWithInference("third_level")) - ); - } - })))) - ) { - configurations.put( - "pipeline_with_model_nested", - new PipelineConfiguration("pipeline_with_model_nested", BytesReference.bytes(xContentBuilder), XContentType.JSON) - ); - } - - IngestMetadata ingestMetadata = new IngestMetadata(configurations); - - ClusterState cs = ClusterState.builder(new ClusterName("_name")) - .metadata(Metadata.builder().putCustom(IngestMetadata.TYPE, ingestMetadata)) - .nodes( - DiscoveryNodes.builder() - .add(new DiscoveryNode("min_node", new TransportAddress(InetAddress.getLoopbackAddress(), 9300), Version.CURRENT)) - .add(new DiscoveryNode("current_node", new TransportAddress(InetAddress.getLoopbackAddress(), 9302), Version.CURRENT)) - .add(new DiscoveryNode("_node_id", new TransportAddress(InetAddress.getLoopbackAddress(), 9304), Version.CURRENT)) - .localNodeId("_node_id") - .masterNodeId("_node_id") - ) - .build(); - - processorFactory.accept(cs); - assertThat(processorFactory.numInferenceProcessors(), equalTo(3)); - } - - public void testNumInferenceWhenLevelExceedsMaxRecurions() { - assertThat(InferenceProcessor.Factory.numInferenceProcessors(InferenceProcessor.TYPE, Collections.emptyMap(), 100), equalTo(0)); - } - public void testCreateProcessorWithTooManyExisting() throws Exception { InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory( client, diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/InferenceProcessorInfoExtractorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/InferenceProcessorInfoExtractorTests.java new file mode 100644 index 000000000000..077a4abfced0 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/InferenceProcessorInfoExtractorTests.java @@ -0,0 +1,212 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.utils; + +import org.elasticsearch.Version; +import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.transport.TransportAddress; +import org.elasticsearch.common.util.Maps; +import org.elasticsearch.ingest.IngestMetadata; +import org.elasticsearch.ingest.PipelineConfiguration; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasEntry; + +public class InferenceProcessorInfoExtractorTests extends ESTestCase { + + public void testPipelineIdsByModelIds() throws IOException { + String modelId1 = "trained_model_1"; + String modelId2 = "trained_model_2"; + String modelId3 = "trained_model_3"; + Set modelIds = new HashSet<>(Arrays.asList(modelId1, modelId2, modelId3)); + + ClusterState clusterState = buildClusterStateWithModelReferences(2, modelId1, modelId2, modelId3); + + Map> pipelineIdsByModelIds = InferenceProcessorInfoExtractor.pipelineIdsByModelIdsOrAliases( + clusterState, + modelIds + ); + + assertThat(pipelineIdsByModelIds.keySet(), equalTo(modelIds)); + assertThat( + pipelineIdsByModelIds, + hasEntry(modelId1, new HashSet<>(Arrays.asList("pipeline_with_model_" + modelId1 + 0, "pipeline_with_model_" + modelId1 + 1))) + ); + assertThat( + pipelineIdsByModelIds, + hasEntry(modelId2, new HashSet<>(Arrays.asList("pipeline_with_model_" + modelId2 + 0, "pipeline_with_model_" + modelId2 + 1))) + ); + assertThat( + pipelineIdsByModelIds, + hasEntry(modelId3, new HashSet<>(Arrays.asList("pipeline_with_model_" + modelId3 + 0, "pipeline_with_model_" + modelId3 + 1))) + ); + } + + public void testNumInferenceProcessors() throws IOException { + assertThat(InferenceProcessorInfoExtractor.countInferenceProcessors(buildClusterState(null)), equalTo(0)); + assertThat(InferenceProcessorInfoExtractor.countInferenceProcessors(buildClusterState(Metadata.EMPTY_METADATA)), equalTo(0)); + assertThat( + InferenceProcessorInfoExtractor.countInferenceProcessors(buildClusterStateWithModelReferences(1, "model1", "model2", "model3")), + equalTo(3) + ); + } + + public void testNumInferenceProcessorsRecursivelyDefined() throws IOException { + Map configurations = new HashMap<>(); + configurations.put( + "pipeline_with_model_top_level", + randomBoolean() + ? 
newConfigurationWithInferenceProcessor("top_level") + : newConfigurationWithForeachProcessorProcessor("top_level") + ); + try ( + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .map(Collections.singletonMap("processors", Collections.singletonList(Collections.singletonMap("set", new HashMap<>() { + { + put("field", "foo"); + put("value", "bar"); + put( + "on_failure", + Arrays.asList(inferenceProcessorForModel("second_level"), forEachProcessorWithInference("third_level")) + ); + } + })))) + ) { + configurations.put( + "pipeline_with_model_nested", + new PipelineConfiguration("pipeline_with_model_nested", BytesReference.bytes(xContentBuilder), XContentType.JSON) + ); + } + + IngestMetadata ingestMetadata = new IngestMetadata(configurations); + + ClusterState cs = ClusterState.builder(new ClusterName("_name")) + .metadata(Metadata.builder().putCustom(IngestMetadata.TYPE, ingestMetadata)) + .nodes( + DiscoveryNodes.builder() + .add(new DiscoveryNode("min_node", new TransportAddress(InetAddress.getLoopbackAddress(), 9300), Version.CURRENT)) + .add(new DiscoveryNode("current_node", new TransportAddress(InetAddress.getLoopbackAddress(), 9302), Version.CURRENT)) + .add(new DiscoveryNode("_node_id", new TransportAddress(InetAddress.getLoopbackAddress(), 9304), Version.CURRENT)) + .localNodeId("_node_id") + .masterNodeId("_node_id") + ) + .build(); + + assertThat(InferenceProcessorInfoExtractor.countInferenceProcessors(cs), equalTo(3)); + } + + private static PipelineConfiguration newConfigurationWithOutInferenceProcessor(int i) throws IOException { + try ( + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .map( + Collections.singletonMap( + "processors", + Collections.singletonList(Collections.singletonMap("not_inference", Collections.emptyMap())) + ) + ) + ) { + return new PipelineConfiguration("pipeline_without_model_" + i, BytesReference.bytes(xContentBuilder), XContentType.JSON); + } + } + + private static ClusterState buildClusterState(Metadata metadata) { + return ClusterState.builder(new ClusterName("_name")).metadata(metadata).build(); + } + + private static ClusterState buildClusterStateWithModelReferences(int numPipelineReferences, String... modelId) throws IOException { + Map configurations = Maps.newMapWithExpectedSize(modelId.length); + for (String id : modelId) { + for (int i = 0; i < numPipelineReferences; i++) { + configurations.put( + "pipeline_with_model_" + id + i, + randomBoolean() ? 
newConfigurationWithInferenceProcessor(id) : newConfigurationWithForeachProcessorProcessor(id) + ); + } + + } + for (int i = 0; i < randomInt(5); i++) { + configurations.put("pipeline_without_model_" + i, newConfigurationWithOutInferenceProcessor(i)); + } + IngestMetadata ingestMetadata = new IngestMetadata(configurations); + + return ClusterState.builder(new ClusterName("_name")) + .metadata(Metadata.builder().putCustom(IngestMetadata.TYPE, ingestMetadata)) + .nodes( + DiscoveryNodes.builder() + .add(new DiscoveryNode("min_node", new TransportAddress(InetAddress.getLoopbackAddress(), 9300), Version.CURRENT)) + .add(new DiscoveryNode("current_node", new TransportAddress(InetAddress.getLoopbackAddress(), 9302), Version.CURRENT)) + .add(new DiscoveryNode("_node_id", new TransportAddress(InetAddress.getLoopbackAddress(), 9304), Version.CURRENT)) + .localNodeId("_node_id") + .masterNodeId("_node_id") + ) + .build(); + } + + private static PipelineConfiguration newConfigurationWithInferenceProcessor(String modelId) throws IOException { + try ( + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .map(Collections.singletonMap("processors", Collections.singletonList(inferenceProcessorForModel(modelId)))) + ) { + return new PipelineConfiguration("pipeline_with_model_" + modelId, BytesReference.bytes(xContentBuilder), XContentType.JSON); + } + } + + private static PipelineConfiguration newConfigurationWithForeachProcessorProcessor(String modelId) throws IOException { + try ( + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .map(Collections.singletonMap("processors", Collections.singletonList(forEachProcessorWithInference(modelId)))) + ) { + return new PipelineConfiguration("pipeline_with_model_" + modelId, BytesReference.bytes(xContentBuilder), XContentType.JSON); + } + } + + private static Map forEachProcessorWithInference(String modelId) { + return Collections.singletonMap("foreach", new HashMap<>() { + { + put("field", "foo"); + put("processor", inferenceProcessorForModel(modelId)); + } + }); + } + + private static Map inferenceProcessorForModel(String modelId) { + return Collections.singletonMap(InferenceProcessor.TYPE, new HashMap<>() { + { + put(InferenceResults.MODEL_ID_RESULTS_FIELD, modelId); + put( + InferenceProcessor.INFERENCE_CONFIG, + Collections.singletonMap(RegressionConfig.NAME.getPreferredName(), Collections.emptyMap()) + ); + put(InferenceProcessor.TARGET_FIELD, "new_field"); + put(InferenceProcessor.FIELD_MAP, Collections.singletonMap("source", "dest")); + } + }); + } + +} From b18891cf9fc7366c2dea715d735defc8e44e0762 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Thu, 23 Jun 2022 11:02:21 -0400 Subject: [PATCH 2/3] Update docs/changelog/87978.yaml --- docs/changelog/87978.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/87978.yaml diff --git a/docs/changelog/87978.yaml b/docs/changelog/87978.yaml new file mode 100644 index 000000000000..b72d511bfa89 --- /dev/null +++ b/docs/changelog/87978.yaml @@ -0,0 +1,5 @@ +pr: 87978 +summary: Improve trained model stats API performance +area: Machine Learning +type: bug +issues: [] From caee509dc7ca41e7687983f9fdfaf40bf73dce7a Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Mon, 27 Jun 2022 14:18:23 -0400 Subject: [PATCH 3/3] Apply suggestions from code review --- .../xpack/ml/inference/ingest/InferenceProcessor.java | 1 + .../xpack/ml/utils/InferenceProcessorInfoExtractorTests.java | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java
index cbcd041e15f2..48d98c981e14 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java
@@ -209,6 +209,7 @@ public void accept(ClusterState state) {
             try {
                 currentInferenceProcessors = InferenceProcessorInfoExtractor.countInferenceProcessors(state);
             } catch (Exception ex) {
+                // We cannot throw any exception here. It might break other pipelines.
                 logger.debug("failed gathering processors for pipelines", ex);
             }
         }
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/InferenceProcessorInfoExtractorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/InferenceProcessorInfoExtractorTests.java
index 077a4abfced0..cdead95ea5d3 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/InferenceProcessorInfoExtractorTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/InferenceProcessorInfoExtractorTests.java
@@ -150,7 +150,8 @@ private static ClusterState buildClusterStateWithModelReferences(int numPipeline
             }
 
         }
-        for (int i = 0; i < randomInt(5); i++) {
+        int numPipelinesWithoutModel = randomInt(5);
+        for (int i = 0; i < numPipelinesWithoutModel; i++) {
             configurations.put("pipeline_without_model_" + i, newConfigurationWithOutInferenceProcessor(i));
         }
         IngestMetadata ingestMetadata = new IngestMetadata(configurations);
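
Note on the approach: the stats transport action previously resolved which pipelines reference which models by rebuilding every ingest pipeline through IngestService (Pipeline.create with the registered processor factories) just to inspect its processors. The new InferenceProcessorInfoExtractor walks the raw pipeline configuration maps held in IngestMetadata instead, so no processors are instantiated. A minimal usage sketch of the two entry points added here; the wrapper method and the model ids are illustrative only, not part of the patch:

    // Illustrative only: shows how callers in this patch use the new utility.
    static void reportInferenceProcessorUsage(ClusterState state, Set<String> modelIdsOrAliases) {
        // Counts every inference processor across all ingest pipelines, including
        // processors nested in `foreach` processors and `on_failure` blocks.
        int count = InferenceProcessorInfoExtractor.countInferenceProcessors(state);

        // Maps each referenced model id (or model alias) to the pipelines that use it.
        Map<String, Set<String>> pipelinesByModel = InferenceProcessorInfoExtractor.pipelineIdsByModelIdsOrAliases(
            state,
            modelIdsOrAliases
        );
    }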
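
For reference, the extraction in addModelsAndPipelines only recurses into a foreach processor's nested "processor" and into "on_failure" lists, and stops once nesting exceeds MAX_INFERENCE_PROCESSOR_SEARCH_RECURSIONS (10). A sketch of a pipeline body it would count as three inference processors; the maps below mirror the shape the tests build with XContentBuilder, are trimmed to the keys the extractor reads, and are illustrative rather than a complete processor configuration:

    // Illustrative pipeline body, not part of the patch (uses java.util.List and java.util.Map).
    Map<String, Object> inference = Map.of("inference", Map.of("model_id", "my_model"));
    Map<String, Object> wrappedInForeach = Map.of("foreach", Map.of("field", "items", "processor", inference));
    Map<String, Object> setWithFailureHandler = Map.of(
        "set",
        Map.of("field", "foo", "value", "bar", "on_failure", List.of(inference))
    );
    Map<String, Object> pipelineBody = Map.of("processors", List.of(inference, wrappedInForeach, setWithFailureHandler));
    // A pipeline whose JSON body is equivalent to pipelineBody contributes 3 to
    // countInferenceProcessors: one top level, one inside foreach, one inside on_failure.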