diff --git a/docs/java-rest/high-level/ml/put-trained-model.asciidoc b/docs/java-rest/high-level/ml/put-trained-model.asciidoc index dadc8dcf65a4f..6a0f96a78b961 100644 --- a/docs/java-rest/high-level/ml/put-trained-model.asciidoc +++ b/docs/java-rest/high-level/ml/put-trained-model.asciidoc @@ -46,6 +46,8 @@ include::../execution.asciidoc[] ==== Response The returned +{response}+ contains the newly created trained model. +The +{response}+ will omit the model definition as a precaution against +streaming large model definitions back to the client. ["source","java",subs="attributes,callouts,macros"] -------------------------------------------------- diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index 9bd447319cc4a..57b2ef7091ce5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -279,7 +279,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli()); // We don't store the definition in the same document as the configuration if ((params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) && definition != null) { - if (params.paramAsBoolean(DECOMPRESS_DEFINITION, true)) { + if (params.paramAsBoolean(DECOMPRESS_DEFINITION, false)) { builder.field(DEFINITION.getPreferredName(), definition); } else { builder.field(COMPRESSED_DEFINITION.getPreferredName(), definition.getCompressedString()); @@ -370,6 +370,9 @@ public Builder(TrainedModelConfig config) { this.tags = config.getTags(); this.metadata = config.getMetadata(); this.input = config.getInput(); + this.estimatedOperations = config.estimatedOperations; + this.estimatedHeapMemory = config.estimatedHeapMemory; + this.licenseLevel = config.licenseLevel.description(); } public Builder setModelId(String modelId) { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java index 810112620120f..76c0492ce2c79 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java @@ -142,21 +142,21 @@ public void testToXContentWithParams() throws IOException { "platinum"); BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false); - assertThat(reference.utf8ToString(), containsString("\"definition\"")); + assertThat(reference.utf8ToString(), containsString("\"compressed_definition\"")); reference = XContentHelper.toXContent(config, XContentType.JSON, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")), false); assertThat(reference.utf8ToString(), not(containsString("definition"))); + assertThat(reference.utf8ToString(), not(containsString("compressed_definition"))); reference = XContentHelper.toXContent(config, XContentType.JSON, - new ToXContent.MapParams(Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, "false")), + new ToXContent.MapParams(Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, "true")), false); - assertThat(reference.utf8ToString(), not(containsString("\"definition\""))); - assertThat(reference.utf8ToString(), containsString("compressed_definition")); - assertThat(reference.utf8ToString(), containsString(lazyModelDefinition.getCompressedString())); + assertThat(reference.utf8ToString(), containsString("\"definition\"")); + assertThat(reference.utf8ToString(), not(containsString("compressed_definition"))); } public void testParseWithBothDefinitionAndCompressedSupplied() throws IOException { @@ -179,7 +179,7 @@ public void testParseWithBothDefinitionAndCompressedSupplied() throws IOExceptio BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false); Map objectMap = XContentHelper.convertToMap(reference, true, XContentType.JSON).v2(); - objectMap.put(TrainedModelConfig.COMPRESSED_DEFINITION.getPreferredName(), lazyModelDefinition.getCompressedString()); + objectMap.put(TrainedModelConfig.DEFINITION.getPreferredName(), config.getModelDefinition()); try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(objectMap); XContentParser parser = XContentType.JSON diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java index 0aec6bc337412..e993d523b5430 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java @@ -93,6 +93,7 @@ public void testGetTrainedModels() throws IOException { assertThat(response, containsString("\"estimated_heap_memory_usage_bytes\"")); assertThat(response, containsString("\"estimated_heap_memory_usage\"")); assertThat(response, containsString("\"definition\"")); + assertThat(response, not(containsString("\"compressed_definition\""))); assertThat(response, containsString("\"count\":1")); getModel = client().performRequest(new Request("GET", diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java index a520f62167272..f71c98b815fda 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java @@ -105,7 +105,10 @@ protected void masterOperation(Request request, ClusterState state, ActionListen ActionListener tagsModelIdCheckListener = ActionListener.wrap( r -> trainedModelProvider.storeTrainedModel(trainedModelConfig, ActionListener.wrap( - storedConfig -> listener.onResponse(new PutTrainedModelAction.Response(trainedModelConfig)), + bool -> { + TrainedModelConfig configToReturn = new TrainedModelConfig.Builder(trainedModelConfig).clearDefinition().build(); + listener.onResponse(new PutTrainedModelAction.Response(configToReturn)); + }, listener::onFailure )), listener::onFailure diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java index 1aa0fd42350b5..ffd3f03a59ceb 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java @@ -8,9 +8,15 @@ import org.elasticsearch.client.node.NodeClient; import org.elasticsearch.cluster.metadata.MetaData; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestController; +import org.elasticsearch.rest.BytesRestResponse; +import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestResponse; import org.elasticsearch.rest.action.RestToXContentListener; import org.elasticsearch.xpack.core.action.util.PageParams; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; @@ -19,6 +25,8 @@ import java.io.IOException; import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import java.util.Set; import static org.elasticsearch.rest.RestRequest.Method.GET; @@ -32,6 +40,8 @@ public RestGetTrainedModelsAction(RestController controller) { controller.registerHandler(GET, MachineLearning.BASE_PATH + "inference", this); } + private static final Map DEFAULT_TO_XCONTENT_VALUES = + Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, Boolean.toString(true)); @Override public String getName() { return "ml_get_trained_models_action"; @@ -53,7 +63,9 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE))); } request.setAllowNoResources(restRequest.paramAsBoolean(ALLOW_NO_MATCH.getPreferredName(), request.isAllowNoResources())); - return channel -> client.execute(GetTrainedModelsAction.INSTANCE, request, new RestToXContentListener<>(channel)); + return channel -> client.execute(GetTrainedModelsAction.INSTANCE, + request, + new RestToXContentListenerWithDefaultValues<>(channel, DEFAULT_TO_XCONTENT_VALUES)); } @Override @@ -61,4 +73,23 @@ protected Set responseParams() { return Collections.singleton(TrainedModelConfig.DECOMPRESS_DEFINITION); } + private static class RestToXContentListenerWithDefaultValues extends RestToXContentListener { + private final Map defaultToXContentParamValues; + + private RestToXContentListenerWithDefaultValues(RestChannel channel, Map defaultToXContentParamValues) { + super(channel); + this.defaultToXContentParamValues = defaultToXContentParamValues; + } + + @Override + public RestResponse buildResponse(T response, XContentBuilder builder) throws Exception { + assert response.isFragment() == false; //would be nice if we could make default methods final + Map params = new HashMap<>(channel.request().params()); + defaultToXContentParamValues.forEach((k, v) -> + params.computeIfAbsent(k, defaultToXContentParamValues::get) + ); + response.toXContent(builder, new ToXContent.MapParams(params)); + return new BytesRestResponse(getStatus(response), builder); + } + } } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml index 45da23a01e6f5..f2b2141377d39 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml @@ -299,3 +299,53 @@ setup: } } } +--- +"Test put model": + - do: + ml.put_trained_model: + model_id: my-regression-model + body: > + { + "description": "model for tests", + "input": {"field_names": ["field1", "field2"]}, + "definition": { + "preprocessors": [], + "trained_model": { + "ensemble": { + "target_type": "regression", + "trained_models": [ + { + "tree": { + "feature_names": ["field1", "field2"], + "tree_structure": [ + {"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2}, + {"node_index": 1, "leaf_value": 0}, + {"node_index": 2, "leaf_value": 1} + ], + "target_type": "regression" + } + }, + { + "tree": { + "feature_names": ["field1", "field2"], + "tree_structure": [ + {"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2}, + {"node_index": 1, "leaf_value": 0}, + {"node_index": 2, "leaf_value": 1} + ], + "target_type": "regression" + } + } + ] + } + } + } + } + - match: { model_id: my-regression-model } + - match: { estimated_operations: 6 } + - is_false: definition + - is_false: compressed_definition + - is_true: license_level + - is_true: create_time + - is_true: version + - is_true: estimated_heap_memory_usage_bytes