From 79f143907a6bbedbd52fc96ed2afd7084ffb54fa Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Wed, 5 Feb 2020 08:26:44 -0500 Subject: [PATCH] [7.x] [ML] add _cat/ml/trained_models API (#51529) (#51936) * [ML] add _cat/ml/trained_models API (#51529) This adds _cat/ml/trained_models. --- .../action/GetTrainedModelsStatsAction.java | 8 + .../xpack/ml/MachineLearning.java | 2 + .../process/AnalyticsResultProcessor.java | 1 + .../rest/cat/RestCatTrainedModelsAction.java | 283 ++++++++++++++++++ .../api/cat.ml.trained_models.json | 100 +++++++ .../test/ml/trained_model_cat_apis.yml | 110 +++++++ 6 files changed, 504 insertions(+) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/cat/RestCatTrainedModelsAction.java create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/api/cat.ml.trained_models.json create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/test/ml/trained_model_cat_apis.yml diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java index 3e91fd0444b0a..0bf3582ffc5c3 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java @@ -103,6 +103,14 @@ public String getModelId() { return modelId; } + public IngestStats getIngestStats() { + return ingestStats; + } + + public int getPipelineCount() { + return pipelineCount; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 6bede03c073d1..0ad37eac88436 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -256,6 +256,7 @@ import org.elasticsearch.xpack.ml.rest.calendar.RestPutCalendarJobAction; import org.elasticsearch.xpack.ml.rest.cat.RestCatDatafeedsAction; import org.elasticsearch.xpack.ml.rest.cat.RestCatJobsAction; +import org.elasticsearch.xpack.ml.rest.cat.RestCatTrainedModelsAction; import org.elasticsearch.xpack.ml.rest.datafeeds.RestDeleteDatafeedAction; import org.elasticsearch.xpack.ml.rest.datafeeds.RestGetDatafeedStatsAction; import org.elasticsearch.xpack.ml.rest.datafeeds.RestGetDatafeedsAction; @@ -786,6 +787,7 @@ public List getRestHandlers(Settings settings, RestController restC new RestPutTrainedModelAction(restController), // CAT Handlers new RestCatJobsAction(restController), + new RestCatTrainedModelsAction(restController), new RestCatDatafeedsAction(restController) ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index 5168c9296d2f9..9ce7a16084461 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -177,6 +177,7 @@ private TrainedModelConfig createTrainedModelConfig(TrainedModelDefinition.Build .setCreatedBy(XPackUser.NAME) .setVersion(Version.CURRENT) .setCreateTime(createTime) + // NOTE: GET _cat/ml/trained_models relies on the creating analytics ID being in the tags .setTags(Collections.singletonList(analytics.getId())) .setDescription(analytics.getDescription()) .setMetadata(Collections.singletonMap("analytics_config", diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/cat/RestCatTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/cat/RestCatTrainedModelsAction.java new file mode 100644 index 0000000000000..9f62a3bba786c --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/cat/RestCatTrainedModelsAction.java @@ -0,0 +1,283 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.rest.cat; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.support.GroupedActionListener; +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.cluster.metadata.MetaData; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.Table; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.rest.RestController; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestResponse; +import org.elasticsearch.rest.action.RestResponseListener; +import org.elasticsearch.rest.action.cat.AbstractCatAction; +import org.elasticsearch.rest.action.cat.RestTable; +import org.elasticsearch.xpack.core.action.util.PageParams; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.security.user.XPackUser; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static org.elasticsearch.rest.RestRequest.Method.GET; +import static org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request.ALLOW_NO_MATCH; + +public class RestCatTrainedModelsAction extends AbstractCatAction { + + public RestCatTrainedModelsAction(RestController controller) { + controller.registerHandler(GET, "_cat/ml/trained_models/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}", this); + controller.registerHandler(GET, "_cat/ml/trained_models", this); + } + + @Override + public String getName() { + return "cat_ml_get_trained_models_action"; + } + + @Override + protected RestChannelConsumer doCatRequest(RestRequest restRequest, NodeClient client) { + String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName()); + if (Strings.isNullOrEmpty(modelId)) { + modelId = MetaData.ALL; + } + GetTrainedModelsStatsAction.Request statsRequest = new GetTrainedModelsStatsAction.Request(modelId); + GetTrainedModelsAction.Request modelsAction = new GetTrainedModelsAction.Request(modelId, false, null); + if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) { + statsRequest.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM), + restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE))); + modelsAction.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM), + restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE))); + } + statsRequest.setAllowNoResources(true); + modelsAction.setAllowNoResources(restRequest.paramAsBoolean(ALLOW_NO_MATCH.getPreferredName(), + statsRequest.isAllowNoResources())); + + return channel -> { + final ActionListener listener = ActionListener.notifyOnce(new RestResponseListener
(channel) { + @Override + public RestResponse buildResponse(final Table table) throws Exception { + return RestTable.buildResponse(table, channel); + } + }); + + client.execute(GetTrainedModelsAction.INSTANCE, modelsAction, ActionListener.wrap( + trainedModels -> { + final List trainedModelConfigs = trainedModels.getResources().results(); + + Set potentialAnalyticsIds = new HashSet<>(); + // Analytics Configs are created by the XPackUser + trainedModelConfigs.stream() + .filter(c -> XPackUser.NAME.equals(c.getCreatedBy())) + .forEach(c -> potentialAnalyticsIds.addAll(c.getTags())); + + + // Find the related DataFrameAnalyticsConfigs + String requestIdPattern = Strings.collectionToDelimitedString(potentialAnalyticsIds, "*,") + "*"; + + final GroupedActionListener groupedListener = createGroupedListener(restRequest, + 2, + trainedModels.getResources().results(), + listener); + + client.execute(GetTrainedModelsStatsAction.INSTANCE, + statsRequest, + ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure)); + + GetDataFrameAnalyticsAction.Request dataFrameAnalyticsRequest = + new GetDataFrameAnalyticsAction.Request(requestIdPattern); + dataFrameAnalyticsRequest.setAllowNoResources(true); + dataFrameAnalyticsRequest.setPageParams(new PageParams(0, potentialAnalyticsIds.size())); + client.execute(GetDataFrameAnalyticsAction.INSTANCE, + dataFrameAnalyticsRequest, + ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure)); + }, + listener::onFailure + )); + }; + } + + @Override + protected void documentation(StringBuilder sb) { + sb.append("/_cat/ml/trained_models\n"); + sb.append("/_cat/ml/trained_models/{model_id}\n"); + } + + @Override + protected Table getTableWithHeader(RestRequest request) { + Table table = new Table(); + table.startHeaders(); + + // Trained Model Info + table.addCell("id", TableColumnAttributeBuilder.builder().setDescription("the trained model id").build()); + table.addCell("created_by", TableColumnAttributeBuilder.builder("who created the model", false) + .setAliases("c", "createdBy") + .setTextAlignment(TableColumnAttributeBuilder.TextAlign.RIGHT) + .build()); + table.addCell("heap_size", TableColumnAttributeBuilder.builder() + .setDescription("the estimated heap size to keep the model in memory") + .setAliases("hs","modelHeapSize") + .build()); + table.addCell("operations", TableColumnAttributeBuilder.builder() + .setDescription("the estimated number of operations to use the model") + .setAliases("o", "modelOperations") + .build()); + table.addCell("license", TableColumnAttributeBuilder.builder("The license level of the model", false) + .setAliases("l") + .build()); + table.addCell("create_time", TableColumnAttributeBuilder.builder("The time the model was created") + .setAliases("ct") + .build()); + table.addCell("version", TableColumnAttributeBuilder.builder("The version of Elasticsearch when the model was created", false) + .setAliases("v") + .build()); + table.addCell("description", TableColumnAttributeBuilder.builder("The model description", false) + .setAliases("d") + .build()); + + // Trained Model Stats + table.addCell("ingest.pipelines", TableColumnAttributeBuilder.builder("The number of pipelines referencing the model") + .setAliases("ip", "ingestPipelines") + .build()); + table.addCell("ingest.count", TableColumnAttributeBuilder.builder("The total number of docs processed by the model", false) + .setAliases("ic", "ingestCount") + .build()); + table.addCell("ingest.time", TableColumnAttributeBuilder.builder( + "The total time spent processing docs with this model", + false) + .setAliases("it", "ingestTime") + .build()); + table.addCell("ingest.current", TableColumnAttributeBuilder.builder( + "The total documents currently being handled by the model", + false) + .setAliases("icurr", "ingestCurrent") + .build()); + table.addCell("ingest.failed", TableColumnAttributeBuilder.builder( + "The total count of failed ingest attempts with this model", + false) + .setAliases("if", "ingestFailed") + .build()); + + table.addCell("data_frame.id", TableColumnAttributeBuilder.builder( + "The data frame analytics config id that created the model (if still available)") + .setAliases("dfid", "dataFrameAnalytics") + .build()); + table.addCell("data_frame.create_time", TableColumnAttributeBuilder.builder( + "The time the data frame analytics config was created", + false) + .setAliases("dft", "dataFrameAnalyticsTime") + .build()); + table.addCell("data_frame.source_index", TableColumnAttributeBuilder.builder( + "The source index used to train in the data frame analysis", + false) + .setAliases("dfsi", "dataFrameAnalyticsSrcIndex") + .build()); + table.addCell("data_frame.analysis", TableColumnAttributeBuilder.builder( + "The analysis used by the data frame to build the model", + false) + .setAliases("dfa", "dataFrameAnalyticsAnalysis") + .build()); + + table.endHeaders(); + return table; + } + + private GroupedActionListener createGroupedListener(final RestRequest request, + final int size, + final List configs, + final ActionListener
listener) { + return new GroupedActionListener<>(new ActionListener>() { + @Override + public void onResponse(final Collection responses) { + GetTrainedModelsStatsAction.Response statsResponse = extractResponse(responses, GetTrainedModelsStatsAction.Response.class); + GetDataFrameAnalyticsAction.Response analytics = extractResponse(responses, GetDataFrameAnalyticsAction.Response.class); + listener.onResponse(buildTable(request, + statsResponse.getResources().results(), + configs, + analytics == null ? Collections.emptyList() : analytics.getResources().results())); + } + + @Override + public void onFailure(final Exception e) { + listener.onFailure(e); + } + }, size); + } + + + private Table buildTable(RestRequest request, + List stats, + List configs, + List analyticsConfigs) { + Table table = getTableWithHeader(request); + assert configs.size() == stats.size(); + + Map analyticsMap = analyticsConfigs.stream() + .collect(Collectors.toMap(DataFrameAnalyticsConfig::getId, Function.identity())); + Map statsMap = stats.stream() + .collect(Collectors.toMap(GetTrainedModelsStatsAction.Response.TrainedModelStats::getModelId, Function.identity())); + + configs.forEach(config -> { + table.startRow(); + // Trained Model Info + table.addCell(config.getModelId()); + table.addCell(config.getCreatedBy()); + table.addCell(new ByteSizeValue(config.getEstimatedHeapMemory())); + table.addCell(config.getEstimatedOperations()); + table.addCell(config.getLicenseLevel()); + table.addCell(config.getCreateTime()); + table.addCell(config.getVersion().toString()); + table.addCell(config.getDescription()); + + GetTrainedModelsStatsAction.Response.TrainedModelStats modelStats = statsMap.get(config.getModelId()); + table.addCell(modelStats.getPipelineCount()); + boolean hasIngestStats = modelStats != null && modelStats.getIngestStats() != null; + table.addCell(hasIngestStats ? modelStats.getIngestStats().getTotalStats().getIngestCount() : 0); + table.addCell(hasIngestStats ? + TimeValue.timeValueMillis(modelStats.getIngestStats().getTotalStats().getIngestTimeInMillis()) : + TimeValue.timeValueMillis(0)); + table.addCell(hasIngestStats ? modelStats.getIngestStats().getTotalStats().getIngestCurrent() : 0); + table.addCell(hasIngestStats ? modelStats.getIngestStats().getTotalStats().getIngestFailedCount() : 0); + + DataFrameAnalyticsConfig dataFrameAnalyticsConfig = config.getTags() + .stream() + .filter(analyticsMap::containsKey) + .map(analyticsMap::get) + .findFirst() + .orElse(null); + table.addCell(dataFrameAnalyticsConfig == null ? "__none__" : dataFrameAnalyticsConfig.getId()); + table.addCell(dataFrameAnalyticsConfig == null ? null : dataFrameAnalyticsConfig.getCreateTime()); + table.addCell(dataFrameAnalyticsConfig == null ? + null : + Strings.arrayToCommaDelimitedString(dataFrameAnalyticsConfig.getSource().getIndex())); + DataFrameAnalysis analysis = dataFrameAnalyticsConfig == null ? null : dataFrameAnalyticsConfig.getAnalysis(); + table.addCell(analysis == null ? null : analysis.getWriteableName()); + + table.endRow(); + }); + return table; + } + + @SuppressWarnings("unchecked") + private static A extractResponse(final Collection responses, Class c) { + return (A) responses.stream().filter(c::isInstance).findFirst().get(); + } +} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/cat.ml.trained_models.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/cat.ml.trained_models.json new file mode 100644 index 0000000000000..a76cb993ca1fa --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/cat.ml.trained_models.json @@ -0,0 +1,100 @@ +{ + "cat.ml.trained_models":{ + "documentation":{ + "url":"https://www.elastic.co/guide/en/elasticsearch/reference/current/get-inference-stats.html" + }, + "stability":"stable", + "url":{ + "paths":[ + { + "path":"/_cat/ml/trained_models", + "methods":[ + "GET" + ] + }, + { + "path":"/_cat/ml/trained_models/{model_id}", + "methods":[ + "GET" + ], + "parts":{ + "model_id":{ + "type":"string", + "description":"The ID of the trained models stats to fetch" + } + } + } + ] + }, + "params":{ + "allow_no_match":{ + "type":"boolean", + "required":false, + "description":"Whether to ignore if a wildcard expression matches no trained models. (This includes `_all` string or when no trained models have been specified)", + "default":true + }, + "from":{ + "type":"int", + "description":"skips a number of trained models", + "default":0 + }, + "size":{ + "type":"int", + "description":"specifies a max number of trained models to get", + "default":100 + }, + "bytes":{ + "type":"enum", + "description":"The unit in which to display byte values", + "options":[ + "b", + "k", + "kb", + "m", + "mb", + "g", + "gb", + "t", + "tb", + "p", + "pb" + ] + }, + "format":{ + "type":"string", + "description":"a short version of the Accept header, e.g. json, yaml" + }, + "h":{ + "type":"list", + "description":"Comma-separated list of column names to display" + }, + "help":{ + "type":"boolean", + "description":"Return help information", + "default":false + }, + "s":{ + "type":"list", + "description":"Comma-separated list of column names or column aliases to sort by" + }, + "time":{ + "type":"enum", + "description":"The unit in which to display time values", + "options":[ + "d (Days)", + "h (Hours)", + "m (Minutes)", + "s (Seconds)", + "ms (Milliseconds)", + "micros (Microseconds)", + "nanos (Nanoseconds)" + ] + }, + "v":{ + "type":"boolean", + "description":"Verbose mode. Display column headers", + "default":false + } + } + } +} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/trained_model_cat_apis.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/trained_model_cat_apis.yml new file mode 100644 index 0000000000000..9837f8d0c23c6 --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/trained_model_cat_apis.yml @@ -0,0 +1,110 @@ +setup: + - skip: + features: headers + - do: + indices.create: + index: index-source + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + ml.put_trained_model: + model_id: a-regression-model-0 + body: > + { + "description": "empty model for tests", + "tags": ["regression", "tag1"], + "input": {"field_names": ["field1", "field2"]}, + "definition": { + "preprocessors": [], + "trained_model": { + "tree": { + "feature_names": ["field1", "field2"], + "tree_structure": [ + {"node_index": 0, "leaf_value": 1} + ], + "target_type": "regression" + } + } + } + } + + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + ml.put_trained_model: + model_id: a-regression-model-1 + body: > + { + "description": "empty model for tests", + "input": {"field_names": ["field1", "field2"]}, + "definition": { + "preprocessors": [], + "trained_model": { + "tree": { + "feature_names": ["field1", "field2"], + "tree_structure": [ + {"node_index": 0, "leaf_value": 1} + ], + "target_type": "regression" + } + } + } + } + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + ml.put_data_frame_analytics: + id: "prepackaged" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": {"regression":{ + "dependent_variable": "to_predict" + }} + } + +--- +"Test cat trained models": + + - do: + cat.ml.trained_models: + model_id: a-regression-model-0 + - match: + $body: | + / #id heap_size operations create_time ingest.pipelines data_frame.id + ^ (a\-regression\-model\-0 \s+ \w+ \s+ \d+ \s+ .*? \s+ \d+ .*? \n)+ $/ + + - do: + cat.ml.trained_models: + v: true + model_id: a-regression-model-0 + - match: + $body: | + /^ id \s+ heap_size \s+ operations \s+ create_time \s+ ingest\.pipelines \s+ data_frame\.id \n + (a\-regression\-model\-0 \s+ \w+ \s+ \d+ \s+ .*? \s+ \d+ \s+ .*? \n)+ $/ + + - do: + cat.ml.trained_models: + h: id,license,dfid,ip + v: true + - match: + $body: | + /^ id \s+ license \s+ dfid \s+ ip \n + (a\-regression\-model\-0 \s+ \w+ \s+ __none__ \s+ \d+ \n)+ + (a\-regression\-model\-1 \s+ \w+ \s+ __none__ \s+ \d+ \n)+ + (lang_ident_model_1 \s+ \w+ \s+ prepackaged \s+ \d+ \n)+ $/ + + - do: + cat.ml.trained_models: + model_id: a-regression-model-1 + h: id,license,dfid,ip + v: true + - match: + $body: | + /^ id \s+ license \s+ dfid \s+ ip \n + (a\-regression\-model\-1 \s+ \w+ \s+ __none__ \s+ \d+ \n)+ $/