From 3c87773e92cc74c0b2e20b0ecdeabd79a9f71789 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Thu, 20 Feb 2020 08:33:19 -0500 Subject: [PATCH 1/3] [ML][Inference] don't return inflated definition in PUT response --- .../core/ml/inference/TrainedModelConfig.java | 3 ++ .../TransportPutTrainedModelAction.java | 5 +- .../rest-api-spec/test/ml/inference_crud.yml | 50 +++++++++++++++++++ 3 files changed, 57 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index 9c49661cfc95f..352931312fe21 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -371,6 +371,9 @@ public Builder(TrainedModelConfig config) { this.tags = config.getTags(); this.metadata = config.getMetadata(); this.input = config.getInput(); + this.estimatedOperations = config.estimatedOperations; + this.estimatedHeapMemory = config.estimatedHeapMemory; + this.licenseLevel = config.licenseLevel.description(); } public Builder setModelId(String modelId) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java index 575b8ac00dfb5..f17ee697b660d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java @@ -108,7 +108,10 @@ protected void masterOperation(Task task, ActionListener tagsModelIdCheckListener = ActionListener.wrap( r -> trainedModelProvider.storeTrainedModel(trainedModelConfig, ActionListener.wrap( - storedConfig -> listener.onResponse(new PutTrainedModelAction.Response(trainedModelConfig)), + bool -> { + TrainedModelConfig configToReturn = new TrainedModelConfig.Builder(trainedModelConfig).clearDefinition().build(); + listener.onResponse(new PutTrainedModelAction.Response(configToReturn)); + }, listener::onFailure )), listener::onFailure diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml index 7f14987c38756..cb5ccd90e1311 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml @@ -460,3 +460,53 @@ setup: } } } +--- +"Test put model": + - do: + ml.put_trained_model: + model_id: my-regression-model + body: > + { + "description": "model for tests", + "input": {"field_names": ["field1", "field2"]}, + "definition": { + "preprocessors": [], + "trained_model": { + "ensemble": { + "target_type": "regression", + "trained_models": [ + { + "tree": { + "feature_names": ["field1", "field2"], + "tree_structure": [ + {"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2}, + {"node_index": 1, "leaf_value": 0}, + {"node_index": 2, "leaf_value": 1} + ], + "target_type": "regression" + } + }, + { + "tree": { + "feature_names": ["field1", "field2"], + "tree_structure": [ + {"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2}, + {"node_index": 1, "leaf_value": 0}, + {"node_index": 2, "leaf_value": 1} + ], + "target_type": "regression" + } + } + ] + } + } + } + } + - match: { model_id: my-regression-model } + - match: { estimated_operations: 6 } + - is_false: definition + - is_false: compressed_definition + - is_true: license_level + - is_true: create_time + - is_true: version + - is_true: estimated_heap_memory_usage_bytes From 88a5026f5b6a0c7828103b48fffd2a0bfe4cded5 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Thu, 20 Feb 2020 09:42:37 -0500 Subject: [PATCH 2/3] defaulting inflation to false, but to true on GET --- .../high-level/ml/put-trained-model.asciidoc | 2 ++ .../core/ml/inference/TrainedModelConfig.java | 2 +- .../xpack/ml/integration/TrainedModelIT.java | 1 + .../inference/RestGetTrainedModelsAction.java | 33 ++++++++++++++++++- 4 files changed, 36 insertions(+), 2 deletions(-) diff --git a/docs/java-rest/high-level/ml/put-trained-model.asciidoc b/docs/java-rest/high-level/ml/put-trained-model.asciidoc index dadc8dcf65a4f..6a0f96a78b961 100644 --- a/docs/java-rest/high-level/ml/put-trained-model.asciidoc +++ b/docs/java-rest/high-level/ml/put-trained-model.asciidoc @@ -46,6 +46,8 @@ include::../execution.asciidoc[] ==== Response The returned +{response}+ contains the newly created trained model. +The +{response}+ will omit the model definition as a precaution against +streaming large model definitions back to the client. ["source","java",subs="attributes,callouts,macros"] -------------------------------------------------- diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index 352931312fe21..dcc2d513a4dc3 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -280,7 +280,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli()); // We don't store the definition in the same document as the configuration if ((params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) && definition != null) { - if (params.paramAsBoolean(DECOMPRESS_DEFINITION, true)) { + if (params.paramAsBoolean(DECOMPRESS_DEFINITION, false)) { builder.field(DEFINITION.getPreferredName(), definition); } else { builder.field(COMPRESSED_DEFINITION.getPreferredName(), definition.getCompressedString()); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java index 0aec6bc337412..e993d523b5430 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java @@ -93,6 +93,7 @@ public void testGetTrainedModels() throws IOException { assertThat(response, containsString("\"estimated_heap_memory_usage_bytes\"")); assertThat(response, containsString("\"estimated_heap_memory_usage\"")); assertThat(response, containsString("\"definition\"")); + assertThat(response, not(containsString("\"compressed_definition\""))); assertThat(response, containsString("\"count\":1")); getModel = client().performRequest(new Request("GET", diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java index 04f5523365dd9..2a908708ebe46 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java @@ -8,8 +8,14 @@ import org.elasticsearch.client.node.NodeClient; import org.elasticsearch.cluster.metadata.MetaData; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.BytesRestResponse; +import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestResponse; import org.elasticsearch.rest.action.RestToXContentListener; import org.elasticsearch.xpack.core.action.util.PageParams; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; @@ -18,7 +24,9 @@ import java.io.IOException; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Set; import static java.util.Arrays.asList; @@ -34,6 +42,8 @@ public List routes() { new Route(GET, MachineLearning.BASE_PATH + "inference")); } + private static final Map DEFAULT_TO_XCONTENT_VALUES = + Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, Boolean.toString(true)); @Override public String getName() { return "ml_get_trained_models_action"; @@ -56,7 +66,9 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE))); } request.setAllowNoResources(restRequest.paramAsBoolean(ALLOW_NO_MATCH.getPreferredName(), request.isAllowNoResources())); - return channel -> client.execute(GetTrainedModelsAction.INSTANCE, request, new RestToXContentListener<>(channel)); + return channel -> client.execute(GetTrainedModelsAction.INSTANCE, + request, + new RestToXContentListenerWithDefaultValues<>(channel, DEFAULT_TO_XCONTENT_VALUES)); } @Override @@ -64,4 +76,23 @@ protected Set responseParams() { return Collections.singleton(TrainedModelConfig.DECOMPRESS_DEFINITION); } + private static class RestToXContentListenerWithDefaultValues extends RestToXContentListener { + private final Map defaultToXContentParamValues; + + private RestToXContentListenerWithDefaultValues(RestChannel channel, Map defaultToXContentParamValues) { + super(channel); + this.defaultToXContentParamValues = defaultToXContentParamValues; + } + + @Override + public RestResponse buildResponse(T response, XContentBuilder builder) throws Exception { + assert response.isFragment() == false; //would be nice if we could make default methods final + Map params = new HashMap<>(channel.request().params()); + defaultToXContentParamValues.forEach((k, v) -> + params.computeIfAbsent(k, defaultToXContentParamValues::get) + ); + response.toXContent(builder, new ToXContent.MapParams(params)); + return new BytesRestResponse(getStatus(response), builder); + } + } } From 2ae1aa87ee84e2d4daaf99023b7fbb5dd0f1a1cf Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Thu, 20 Feb 2020 10:23:42 -0500 Subject: [PATCH 3/3] fixing tests --- .../core/ml/inference/TrainedModelConfigTests.java | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java index 3b0a19b496723..03e155cf9d5ec 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java @@ -143,21 +143,21 @@ public void testToXContentWithParams() throws IOException { "platinum"); BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false); - assertThat(reference.utf8ToString(), containsString("\"definition\"")); + assertThat(reference.utf8ToString(), containsString("\"compressed_definition\"")); reference = XContentHelper.toXContent(config, XContentType.JSON, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")), false); assertThat(reference.utf8ToString(), not(containsString("definition"))); + assertThat(reference.utf8ToString(), not(containsString("compressed_definition"))); reference = XContentHelper.toXContent(config, XContentType.JSON, - new ToXContent.MapParams(Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, "false")), + new ToXContent.MapParams(Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, "true")), false); - assertThat(reference.utf8ToString(), not(containsString("\"definition\""))); - assertThat(reference.utf8ToString(), containsString("compressed_definition")); - assertThat(reference.utf8ToString(), containsString(lazyModelDefinition.getCompressedString())); + assertThat(reference.utf8ToString(), containsString("\"definition\"")); + assertThat(reference.utf8ToString(), not(containsString("compressed_definition"))); } public void testParseWithBothDefinitionAndCompressedSupplied() throws IOException { @@ -180,7 +180,7 @@ public void testParseWithBothDefinitionAndCompressedSupplied() throws IOExceptio BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false); Map objectMap = XContentHelper.convertToMap(reference, true, XContentType.JSON).v2(); - objectMap.put(TrainedModelConfig.COMPRESSED_DEFINITION.getPreferredName(), lazyModelDefinition.getCompressedString()); + objectMap.put(TrainedModelConfig.DEFINITION.getPreferredName(), config.getModelDefinition()); try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(objectMap); XContentParser parser = XContentType.JSON