[7.6] [ML][Inference] don't return inflated definition when storing trained models (#52573) #52583

Merged 1 commit on Feb 20, 2020
2 changes: 2 additions & 0 deletions docs/java-rest/high-level/ml/put-trained-model.asciidoc
@@ -46,6 +46,8 @@ include::../execution.asciidoc[]
==== Response

The returned +{response}+ contains the newly created trained model.
The +{response}+ will omit the model definition as a precaution against
streaming large model definitions back to the client.

["source","java",subs="attributes,callouts,macros"]
--------------------------------------------------
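Since the definition is now dropped from the put response, high-level REST client callers should expect both definition accessors on the returned config to be null. A minimal sketch of that expectation, assuming the 7.6 HLRC classes and accessor names (`getResponse()`, `getDefinition()`, `getCompressedDefinition()` are not taken from this diff):

```java
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.client.ml.PutTrainedModelRequest;
import org.elasticsearch.client.ml.PutTrainedModelResponse;
import org.elasticsearch.client.ml.inference.TrainedModelConfig;

import java.io.IOException;

class PutTrainedModelSketch {
    // Sketch only: the accessor names below are assumptions about the client classes.
    static void putAndCheck(RestHighLevelClient client, TrainedModelConfig config) throws IOException {
        PutTrainedModelResponse response = client.machineLearning()
            .putTrainedModel(new PutTrainedModelRequest(config), RequestOptions.DEFAULT);
        TrainedModelConfig stored = response.getResponse();
        // Neither the inflated nor the compressed definition is streamed back.
        assert stored.getDefinition() == null;
        assert stored.getCompressedDefinition() == null;
    }
}
```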
@@ -279,7 +279,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli());
// We don't store the definition in the same document as the configuration
if ((params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) && definition != null) {
- if (params.paramAsBoolean(DECOMPRESS_DEFINITION, true)) {
+ if (params.paramAsBoolean(DECOMPRESS_DEFINITION, false)) {
builder.field(DEFINITION.getPreferredName(), definition);
} else {
builder.field(COMPRESSED_DEFINITION.getPreferredName(), definition.getCompressedString());
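With the default for `DECOMPRESS_DEFINITION` flipped to `false`, `toXContent` now writes the compressed form unless the caller explicitly opts into the inflated one. A small sketch of that opt-in, assuming only the server-side `TrainedModelConfig` touched in this hunk (the `config` instance is taken as given):

```java
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;

import java.io.IOException;
import java.util.Collections;

class DefinitionSerializationSketch {
    static XContentBuilder inflated(TrainedModelConfig config) throws IOException {
        // Explicitly request the inflated "definition" field; with EMPTY_PARAMS the
        // new default emits "compressed_definition" instead.
        ToXContent.Params params = new ToXContent.MapParams(
            Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, "true"));
        return config.toXContent(XContentFactory.jsonBuilder(), params);
    }
}
```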
@@ -370,6 +370,9 @@ public Builder(TrainedModelConfig config) {
this.tags = config.getTags();
this.metadata = config.getMetadata();
this.input = config.getInput();
this.estimatedOperations = config.estimatedOperations;
this.estimatedHeapMemory = config.estimatedHeapMemory;
this.licenseLevel = config.licenseLevel.description();
}

public Builder setModelId(String modelId) {
@@ -142,21 +142,21 @@ public void testToXContentWithParams() throws IOException {
"platinum");

BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
- assertThat(reference.utf8ToString(), containsString("\"definition\""));
+ assertThat(reference.utf8ToString(), containsString("\"compressed_definition\""));

reference = XContentHelper.toXContent(config,
XContentType.JSON,
new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")),
false);
- assertThat(reference.utf8ToString(), not(containsString("definition")));
+ assertThat(reference.utf8ToString(), not(containsString("compressed_definition")));

reference = XContentHelper.toXContent(config,
XContentType.JSON,
- new ToXContent.MapParams(Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, "false")),
+ new ToXContent.MapParams(Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, "true")),
false);
- assertThat(reference.utf8ToString(), not(containsString("\"definition\"")));
- assertThat(reference.utf8ToString(), containsString("compressed_definition"));
- assertThat(reference.utf8ToString(), containsString(lazyModelDefinition.getCompressedString()));
+ assertThat(reference.utf8ToString(), containsString("\"definition\""));
+ assertThat(reference.utf8ToString(), not(containsString("compressed_definition")));
}

public void testParseWithBothDefinitionAndCompressedSupplied() throws IOException {
@@ -179,7 +179,7 @@ public void testParseWithBothDefinitionAndCompressedSupplied() throws IOException {
BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
Map<String, Object> objectMap = XContentHelper.convertToMap(reference, true, XContentType.JSON).v2();

- objectMap.put(TrainedModelConfig.COMPRESSED_DEFINITION.getPreferredName(), lazyModelDefinition.getCompressedString());
+ objectMap.put(TrainedModelConfig.DEFINITION.getPreferredName(), config.getModelDefinition());

try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(objectMap);
XContentParser parser = XContentType.JSON
@@ -93,6 +93,7 @@ public void testGetTrainedModels() throws IOException {
assertThat(response, containsString("\"estimated_heap_memory_usage_bytes\""));
assertThat(response, containsString("\"estimated_heap_memory_usage\""));
assertThat(response, containsString("\"definition\""));
assertThat(response, not(containsString("\"compressed_definition\"")));
assertThat(response, containsString("\"count\":1"));

getModel = client().performRequest(new Request("GET",
@@ -105,7 +105,10 @@ protected void masterOperation(Request request, ClusterState state, ActionListen

ActionListener<Void> tagsModelIdCheckListener = ActionListener.wrap(
r -> trainedModelProvider.storeTrainedModel(trainedModelConfig, ActionListener.wrap(
- storedConfig -> listener.onResponse(new PutTrainedModelAction.Response(trainedModelConfig)),
+ bool -> {
+     TrainedModelConfig configToReturn = new TrainedModelConfig.Builder(trainedModelConfig).clearDefinition().build();
+     listener.onResponse(new PutTrainedModelAction.Response(configToReturn));
+ },
listener::onFailure
)),
listener::onFailure
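Returning `new TrainedModelConfig.Builder(trainedModelConfig).clearDefinition().build()` is what makes the copy-constructor additions above necessary: without them, the estimated operations, estimated heap memory and license level would be silently dropped from the acknowledgement. A rough sketch of the intended invariant (the accessor names are assumptions, not taken from this diff):

```java
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;

class ClearDefinitionSketch {
    // Clearing the definition must not also lose the size/licensing metadata
    // that the builder copy constructor now carries over.
    static void check(TrainedModelConfig original) {
        TrainedModelConfig returned =
            new TrainedModelConfig.Builder(original).clearDefinition().build();
        assert returned.getModelDefinition() == null;
        assert returned.getEstimatedHeapMemory() == original.getEstimatedHeapMemory();
        assert returned.getEstimatedOperations() == original.getEstimatedOperations();
        assert returned.getLicenseLevel().equals(original.getLicenseLevel());
    }
}
```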
@@ -8,9 +8,15 @@
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.BytesRestResponse;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestResponse;
import org.elasticsearch.rest.action.RestToXContentListener;
import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
@@ -19,6 +25,8 @@

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

import static org.elasticsearch.rest.RestRequest.Method.GET;
@@ -32,6 +40,8 @@ public RestGetTrainedModelsAction(RestController controller) {
controller.registerHandler(GET, MachineLearning.BASE_PATH + "inference", this);
}

private static final Map<String, String> DEFAULT_TO_XCONTENT_VALUES =
Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, Boolean.toString(true));
@Override
public String getName() {
return "ml_get_trained_models_action";
@@ -53,12 +63,33 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE)));
}
request.setAllowNoResources(restRequest.paramAsBoolean(ALLOW_NO_MATCH.getPreferredName(), request.isAllowNoResources()));
- return channel -> client.execute(GetTrainedModelsAction.INSTANCE, request, new RestToXContentListener<>(channel));
+ return channel -> client.execute(GetTrainedModelsAction.INSTANCE,
+     request,
+     new RestToXContentListenerWithDefaultValues<>(channel, DEFAULT_TO_XCONTENT_VALUES));
}

@Override
protected Set<String> responseParams() {
return Collections.singleton(TrainedModelConfig.DECOMPRESS_DEFINITION);
}

private static class RestToXContentListenerWithDefaultValues<T extends ToXContentObject> extends RestToXContentListener<T> {
private final Map<String, String> defaultToXContentParamValues;

private RestToXContentListenerWithDefaultValues(RestChannel channel, Map<String, String> defaultToXContentParamValues) {
super(channel);
this.defaultToXContentParamValues = defaultToXContentParamValues;
}

@Override
public RestResponse buildResponse(T response, XContentBuilder builder) throws Exception {
assert response.isFragment() == false; //would be nice if we could make default methods final
Map<String, String> params = new HashMap<>(channel.request().params());
defaultToXContentParamValues.forEach((k, v) ->
params.computeIfAbsent(k, defaultToXContentParamValues::get)
);
response.toXContent(builder, new ToXContent.MapParams(params));
return new BytesRestResponse(getStatus(response), builder);
}
}
}
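The merge in `buildResponse` only fills in defaults for params the request did not set, so an explicit `?decompress_definition=false` from the caller still wins; the real code keys the map with `TrainedModelConfig.DECOMPRESS_DEFINITION`. The behaviour boils down to `computeIfAbsent` over a copy of the request params, as in this standalone illustration:

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

class DefaultParamsMergeDemo {
    public static void main(String[] args) {
        // REST-layer default (the handler uses TrainedModelConfig.DECOMPRESS_DEFINITION as the key).
        Map<String, String> defaults = Collections.singletonMap("decompress_definition", "true");

        // Simulated request params: the caller explicitly disabled decompression.
        Map<String, String> requestParams = new HashMap<>();
        requestParams.put("decompress_definition", "false");

        Map<String, String> merged = new HashMap<>(requestParams);
        defaults.forEach((k, v) -> merged.computeIfAbsent(k, defaults::get));

        // Prints "false": explicit params are never overridden; absent keys pick up the default.
        System.out.println(merged.get("decompress_definition"));
    }
}
```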
@@ -299,3 +299,53 @@ setup:
}
}
}
---
"Test put model":
- do:
ml.put_trained_model:
model_id: my-regression-model
body: >
{
"description": "model for tests",
"input": {"field_names": ["field1", "field2"]},
"definition": {
"preprocessors": [],
"trained_model": {
"ensemble": {
"target_type": "regression",
"trained_models": [
{
"tree": {
"feature_names": ["field1", "field2"],
"tree_structure": [
{"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2},
{"node_index": 1, "leaf_value": 0},
{"node_index": 2, "leaf_value": 1}
],
"target_type": "regression"
}
},
{
"tree": {
"feature_names": ["field1", "field2"],
"tree_structure": [
{"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2},
{"node_index": 1, "leaf_value": 0},
{"node_index": 2, "leaf_value": 1}
],
"target_type": "regression"
}
}
]
}
}
}
}
- match: { model_id: my-regression-model }
- match: { estimated_operations: 6 }
- is_false: definition
- is_false: compressed_definition
- is_true: license_level
- is_true: create_time
- is_true: version
- is_true: estimated_heap_memory_usage_bytes
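The put response no longer carries the definition, but it stays retrievable from the stored model. A follow-up GET along these lines should return it; note that the `include_model_definition` query parameter is an assumption about the 7.6 GET trained models API, not something exercised in this change:

```java
import org.apache.http.util.EntityUtils;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.RestClient;

import java.io.IOException;

class FetchStoredDefinitionSketch {
    // Sketch only: the query parameter name is assumed, not verified against this PR.
    static String fetchDefinition(RestClient client) throws IOException {
        Request request = new Request("GET", "_ml/inference/my-regression-model");
        request.addParameter("include_model_definition", "true");
        Response response = client.performRequest(request);
        return EntityUtils.toString(response.getEntity());
    }
}
```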