From 8101e4fdb67c70ff127a639351b7491a805bcfac Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Thu, 20 Feb 2020 11:25:34 -0500 Subject: [PATCH] [ML][Inference] don't return inflated definition when storing trained models (#52573) When `PUT` is called to store a trained model, it is useful to return the newly created model config. But, it is NOT useful to return the inflated definition. These definitions can be large and returning the inflated definition causes undue work on the server and client side. --- .../high-level/ml/put-trained-model.asciidoc | 2 + .../core/ml/inference/TrainedModelConfig.java | 5 +- .../ml/inference/TrainedModelConfigTests.java | 12 ++--- .../xpack/ml/integration/TrainedModelIT.java | 1 + .../TransportPutTrainedModelAction.java | 5 +- .../inference/RestGetTrainedModelsAction.java | 33 +++++++++++- .../rest-api-spec/test/ml/inference_crud.yml | 50 +++++++++++++++++++ 7 files changed, 99 insertions(+), 9 deletions(-) diff --git a/docs/java-rest/high-level/ml/put-trained-model.asciidoc b/docs/java-rest/high-level/ml/put-trained-model.asciidoc index dadc8dcf65a4f..6a0f96a78b961 100644 --- a/docs/java-rest/high-level/ml/put-trained-model.asciidoc +++ b/docs/java-rest/high-level/ml/put-trained-model.asciidoc @@ -46,6 +46,8 @@ include::../execution.asciidoc[] ==== Response The returned +{response}+ contains the newly created trained model. +The +{response}+ will omit the model definition as a precaution against +streaming large model definitions back to the client. 
["source","java",subs="attributes,callouts,macros"] -------------------------------------------------- diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index 9bd447319cc4a..57b2ef7091ce5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -279,7 +279,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli()); // We don't store the definition in the same document as the configuration if ((params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) && definition != null) { - if (params.paramAsBoolean(DECOMPRESS_DEFINITION, true)) { + if (params.paramAsBoolean(DECOMPRESS_DEFINITION, false)) { builder.field(DEFINITION.getPreferredName(), definition); } else { builder.field(COMPRESSED_DEFINITION.getPreferredName(), definition.getCompressedString()); @@ -370,6 +370,9 @@ public Builder(TrainedModelConfig config) { this.tags = config.getTags(); this.metadata = config.getMetadata(); this.input = config.getInput(); + this.estimatedOperations = config.estimatedOperations; + this.estimatedHeapMemory = config.estimatedHeapMemory; + this.licenseLevel = config.licenseLevel.description(); } public Builder setModelId(String modelId) { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java index 810112620120f..76c0492ce2c79 100644 --- 
a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java @@ -142,21 +142,21 @@ public void testToXContentWithParams() throws IOException { "platinum"); BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false); - assertThat(reference.utf8ToString(), containsString("\"definition\"")); + assertThat(reference.utf8ToString(), containsString("\"compressed_definition\"")); reference = XContentHelper.toXContent(config, XContentType.JSON, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")), false); assertThat(reference.utf8ToString(), not(containsString("definition"))); + assertThat(reference.utf8ToString(), not(containsString("compressed_definition"))); reference = XContentHelper.toXContent(config, XContentType.JSON, - new ToXContent.MapParams(Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, "false")), + new ToXContent.MapParams(Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, "true")), false); - assertThat(reference.utf8ToString(), not(containsString("\"definition\""))); - assertThat(reference.utf8ToString(), containsString("compressed_definition")); - assertThat(reference.utf8ToString(), containsString(lazyModelDefinition.getCompressedString())); + assertThat(reference.utf8ToString(), containsString("\"definition\"")); + assertThat(reference.utf8ToString(), not(containsString("compressed_definition"))); } public void testParseWithBothDefinitionAndCompressedSupplied() throws IOException { @@ -179,7 +179,7 @@ public void testParseWithBothDefinitionAndCompressedSupplied() throws IOExceptio BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false); Map objectMap = XContentHelper.convertToMap(reference, true, 
XContentType.JSON).v2(); - objectMap.put(TrainedModelConfig.COMPRESSED_DEFINITION.getPreferredName(), lazyModelDefinition.getCompressedString()); + objectMap.put(TrainedModelConfig.DEFINITION.getPreferredName(), config.getModelDefinition()); try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(objectMap); XContentParser parser = XContentType.JSON diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java index 0aec6bc337412..e993d523b5430 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java @@ -93,6 +93,7 @@ public void testGetTrainedModels() throws IOException { assertThat(response, containsString("\"estimated_heap_memory_usage_bytes\"")); assertThat(response, containsString("\"estimated_heap_memory_usage\"")); assertThat(response, containsString("\"definition\"")); + assertThat(response, not(containsString("\"compressed_definition\""))); assertThat(response, containsString("\"count\":1")); getModel = client().performRequest(new Request("GET", diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java index a520f62167272..f71c98b815fda 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java @@ -105,7 +105,10 @@ protected void masterOperation(Request request, ClusterState state, ActionListen ActionListener tagsModelIdCheckListener = 
ActionListener.wrap( r -> trainedModelProvider.storeTrainedModel(trainedModelConfig, ActionListener.wrap( - storedConfig -> listener.onResponse(new PutTrainedModelAction.Response(trainedModelConfig)), + bool -> { + TrainedModelConfig configToReturn = new TrainedModelConfig.Builder(trainedModelConfig).clearDefinition().build(); + listener.onResponse(new PutTrainedModelAction.Response(configToReturn)); + }, listener::onFailure )), listener::onFailure diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java index 1aa0fd42350b5..ffd3f03a59ceb 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java @@ -8,9 +8,15 @@ import org.elasticsearch.client.node.NodeClient; import org.elasticsearch.cluster.metadata.MetaData; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestController; +import org.elasticsearch.rest.BytesRestResponse; +import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestResponse; import org.elasticsearch.rest.action.RestToXContentListener; import org.elasticsearch.xpack.core.action.util.PageParams; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; @@ -19,6 +25,8 @@ import java.io.IOException; import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import java.util.Set; import static org.elasticsearch.rest.RestRequest.Method.GET; @@ -32,6 +40,8 @@ public 
RestGetTrainedModelsAction(RestController controller) { controller.registerHandler(GET, MachineLearning.BASE_PATH + "inference", this); } + private static final Map DEFAULT_TO_XCONTENT_VALUES = + Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, Boolean.toString(true)); @Override public String getName() { return "ml_get_trained_models_action"; @@ -53,7 +63,9 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE))); } request.setAllowNoResources(restRequest.paramAsBoolean(ALLOW_NO_MATCH.getPreferredName(), request.isAllowNoResources())); - return channel -> client.execute(GetTrainedModelsAction.INSTANCE, request, new RestToXContentListener<>(channel)); + return channel -> client.execute(GetTrainedModelsAction.INSTANCE, + request, + new RestToXContentListenerWithDefaultValues<>(channel, DEFAULT_TO_XCONTENT_VALUES)); } @Override @@ -61,4 +73,23 @@ protected Set responseParams() { return Collections.singleton(TrainedModelConfig.DECOMPRESS_DEFINITION); } + private static class RestToXContentListenerWithDefaultValues extends RestToXContentListener { + private final Map defaultToXContentParamValues; + + private RestToXContentListenerWithDefaultValues(RestChannel channel, Map defaultToXContentParamValues) { + super(channel); + this.defaultToXContentParamValues = defaultToXContentParamValues; + } + + @Override + public RestResponse buildResponse(T response, XContentBuilder builder) throws Exception { + assert response.isFragment() == false; //would be nice if we could make default methods final + Map params = new HashMap<>(channel.request().params()); + defaultToXContentParamValues.forEach((k, v) -> + params.computeIfAbsent(k, defaultToXContentParamValues::get) + ); + response.toXContent(builder, new ToXContent.MapParams(params)); + return new BytesRestResponse(getStatus(response), builder); + } + } } diff --git 
a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml index 45da23a01e6f5..f2b2141377d39 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml @@ -299,3 +299,53 @@ setup: } } } +--- +"Test put model": + - do: + ml.put_trained_model: + model_id: my-regression-model + body: > + { + "description": "model for tests", + "input": {"field_names": ["field1", "field2"]}, + "definition": { + "preprocessors": [], + "trained_model": { + "ensemble": { + "target_type": "regression", + "trained_models": [ + { + "tree": { + "feature_names": ["field1", "field2"], + "tree_structure": [ + {"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2}, + {"node_index": 1, "leaf_value": 0}, + {"node_index": 2, "leaf_value": 1} + ], + "target_type": "regression" + } + }, + { + "tree": { + "feature_names": ["field1", "field2"], + "tree_structure": [ + {"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2}, + {"node_index": 1, "leaf_value": 0}, + {"node_index": 2, "leaf_value": 1} + ], + "target_type": "regression" + } + } + ] + } + } + } + } + - match: { model_id: my-regression-model } + - match: { estimated_operations: 6 } + - is_false: definition + - is_false: compressed_definition + - is_true: license_level + - is_true: create_time + - is_true: version + - is_true: estimated_heap_memory_usage_bytes