Skip to content

Commit

Permalink
[ML] adding new defer_definition_decompression parameter to put train…
Browse files Browse the repository at this point in the history
…ed model API (#77189)

This new parameter is a boolean parameter that allows
users to put in a compressed model without it having
to be inflated on the master node during the put
request

This is useful for system/module set up and then later
having the model validated and fully parsed when it
is being loaded on a node for usage
  • Loading branch information
benwtrent authored Sep 3, 2021
1 parent 174e226 commit 02e17c3
Show file tree
Hide file tree
Showing 14 changed files with 167 additions and 31 deletions.
13 changes: 12 additions & 1 deletion docs/reference/ml/df-analytics/apis/put-trained-models.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ WARNING: Models created in version 7.8.0 are not backwards compatible
[[ml-put-trained-models-prereq]]
== {api-prereq-title}

Requires the `manage_ml` cluster privilege. This privilege is included in the
Requires the `manage_ml` cluster privilege. This privilege is included in the
`machine_learning_admin` built-in role.


Expand All @@ -42,6 +42,17 @@ created by {dfanalytics}.
(Required, string)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]

[[ml-put-trained-models-query-params]]
== {api-query-parms-title}

`defer_definition_decompression`::
(Optional, boolean)
If set to `true` and a `compressed_definition` is provided, the request defers
definition decompression and skips relevant validations.
This deferral is useful for systems or users that know a good JVM heap size estimate for their
model and know that their model is valid and likely won't fail during inference.


[role="child_attributes"]
[[ml-put-trained-models-request-body]]
== {api-request-body-title}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@
}
]
},
"params":{
"defer_definition_decompression": {
"required": false,
"type": "boolean",
"description": "If set to `true` and a `compressed_definition` is provided, the request defers definition decompression and skips relevant validations.",
"default": false
}
},
"body":{
"description":"The trained model configuration",
"required":true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
*/
package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.Version;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
Expand All @@ -22,9 +23,12 @@
import java.io.IOException;
import java.util.Objects;

import static org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig.ESTIMATED_HEAP_MEMORY_USAGE_BYTES;


public class PutTrainedModelAction extends ActionType<PutTrainedModelAction.Response> {

public static final String DEFER_DEFINITION_DECOMPRESSION = "defer_definition_decompression";
public static final PutTrainedModelAction INSTANCE = new PutTrainedModelAction();
public static final String NAME = "cluster:admin/xpack/ml/inference/put";
private PutTrainedModelAction() {
Expand All @@ -33,7 +37,7 @@ private PutTrainedModelAction() {

public static class Request extends AcknowledgedRequest<Request> {

public static Request parseRequest(String modelId, XContentParser parser) {
public static Request parseRequest(String modelId, boolean deferDefinitionValidation, XContentParser parser) {
TrainedModelConfig.Builder builder = TrainedModelConfig.STRICT_PARSER.apply(parser, null);

if (builder.getModelId() == null) {
Expand All @@ -47,18 +51,25 @@ public static Request parseRequest(String modelId, XContentParser parser) {
}
// Validations are done against the builder so we can build the full config object.
// This allows us to not worry about serializing a builder class between nodes.
return new Request(builder.validate(true).build());
return new Request(builder.validate(true).build(), deferDefinitionValidation);
}

private final TrainedModelConfig config;
private final boolean deferDefinitionDecompression;

public Request(TrainedModelConfig config) {
public Request(TrainedModelConfig config, boolean deferDefinitionDecompression) {
this.config = config;
this.deferDefinitionDecompression = deferDefinitionDecompression;
}

public Request(StreamInput in) throws IOException {
super(in);
this.config = new TrainedModelConfig(in);
if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
this.deferDefinitionDecompression = in.readBoolean();
} else {
this.deferDefinitionDecompression = false;
}
}

public TrainedModelConfig getTrainedModelConfig() {
Expand All @@ -67,26 +78,44 @@ public TrainedModelConfig getTrainedModelConfig() {

@Override
public ActionRequestValidationException validate() {
if (deferDefinitionDecompression
&& config.getEstimatedHeapMemory() == 0
&& config.getCompressedDefinitionIfSet() != null) {
ActionRequestValidationException validationException = new ActionRequestValidationException();
validationException.addValidationError(
"when ["
+ DEFER_DEFINITION_DECOMPRESSION
+ "] is true and a compressed definition is provided, " + ESTIMATED_HEAP_MEMORY_USAGE_BYTES + " must be set"
);
return validationException;
}
return null;
}

public boolean isDeferDefinitionDecompression() {
return deferDefinitionDecompression;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
config.writeTo(out);
if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
out.writeBoolean(deferDefinitionDecompression);
}
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Request request = (Request) o;
return Objects.equals(config, request.config);
return Objects.equals(config, request.config) && deferDefinitionDecompression == request.deferDefinitionDecompression;
}

@Override
public int hashCode() {
return Objects.hash(config);
return Objects.hash(config, deferDefinitionDecompression);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,14 @@ public BytesReference getCompressedDefinition() throws IOException {
return definition.getCompressedDefinition();
}

public BytesReference getCompressedDefinitionIfSet() {
if (definition == null) {
return null;
}
return definition.getCompressedDefinitionIfSet();
}


public void clearCompressed() {
definition.compressedRepresentation = null;
}
Expand Down Expand Up @@ -704,6 +712,7 @@ public Builder validate() {

/**
* Runs validations against the builder.
* @param forCreation indicates if we should validate for model creation or for a model read from storage
* @return The current builder object if validations are successful
* @throws ActionRequestValidationException when there are validation failures.
*/
Expand Down Expand Up @@ -773,12 +782,6 @@ public Builder validate(boolean forCreation) {
validationException = checkIllegalSetting(version, VERSION.getPreferredName(), validationException);
validationException = checkIllegalSetting(createdBy, CREATED_BY.getPreferredName(), validationException);
validationException = checkIllegalSetting(createTime, CREATE_TIME.getPreferredName(), validationException);
validationException = checkIllegalSetting(estimatedHeapMemory,
ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(),
validationException);
validationException = checkIllegalSetting(estimatedOperations,
ESTIMATED_OPERATIONS.getPreferredName(),
validationException);
validationException = checkIllegalSetting(licenseLevel, LICENSE_LEVEL.getPreferredName(), validationException);
if (metadata != null) {
validationException = checkIllegalSetting(
Expand Down Expand Up @@ -877,6 +880,10 @@ private BytesReference getCompressedDefinition() throws IOException {
return compressedRepresentation;
}

private BytesReference getCompressedDefinitionIfSet() {
return compressedRepresentation;
}

private String getBase64CompressedDefinition() throws IOException {
BytesReference compressedDef = getCompressedDefinition();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@ public class PutTrainedModelActionRequestTests extends AbstractWireSerializingTe
@Override
protected Request createTestInstance() {
String modelId = randomAlphaOfLength(10);
return new Request(TrainedModelConfigTests.createTestInstance(modelId)
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.build());
return new Request(
TrainedModelConfigTests.createTestInstance(modelId)
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.build(),
randomBoolean()
);
}

@Override
Expand Down
2 changes: 2 additions & 0 deletions x-pack/plugin/ml/qa/ml-with-security/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ tasks.named("yamlRestTest").configure {
'ml/inference_crud/Test update model alias where alias exists but reassign is false',
'ml/inference_crud/Test delete model alias with missing alias',
'ml/inference_crud/Test delete model alias where alias points to different model',
'ml/inference_crud/Test put with defer_definition_decompression with invalid compression definition and no memory estimate',
'ml/inference_crud/Test put with defer_definition_decompression with invalid definition and no memory estimate',
'ml/inference_processor/Test create processor with missing mandatory fields',
'ml/inference_stats_crud/Test get stats given missing trained model',
'ml/inference_stats_crud/Test get stats given expression without matches and allow_no_match is false',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,8 @@ void createModelDeployment() {
)
.setLocation(new IndexLocation(indexname))
.setModelId(TRAINED_MODEL_ID)
.build()
.build(),
false
)
)
.actionGet();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,7 @@ private void putInferenceModel(String modelId) {
.setInput(new TrainedModelInput(Collections.singletonList("feature1")))
.setInferenceConfig(RegressionConfig.EMPTY_PARAMS)
.build();
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config)).actionGet();
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config, false)).actionGet();
}

private static OperationMode randomInvalidLicenseType() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ public void testFeatureTrackingInferenceModelPipeline() throws Exception {
.setPreProcessors(Arrays.asList(new OneHotEncoding("other.categorical", oneHotEncoding, false)))
.setTrainedModel(buildClassification(true)))
.build();
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config)).actionGet();
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config, false)).actionGet();

String pipelineId = "pipeline-inference-model-tracked";
putTrainedModelIngestPipeline(pipelineId, modelId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ public void testRemoveUnusedStats() throws Exception {
.build())
)
.validate(true)
.build())).actionGet();
.build(),
false)).actionGet();

indexStatDocument(new DataCounts("analytics-with-stats", 1, 1, 1),
DataCounts.documentId("analytics-with-stats"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,17 @@ protected void masterOperation(Task task,
ActionListener<Response> listener) {
TrainedModelConfig config = request.getTrainedModelConfig();
try {
config.ensureParsedDefinition(xContentRegistry);
if (request.isDeferDefinitionDecompression() == false) {
config.ensureParsedDefinition(xContentRegistry);
}
} catch (IOException ex) {
listener.onFailure(ExceptionsHelper.badRequestException("Failed to parse definition for [{}]",
ex,
config.getModelId()));
return;
}

// NOTE: hasModelDefinition is false if we don't parse it. But, if the fully parsed model was already provided, continue
boolean hasModelDefinition = config.getModelDefinition() != null;
if (hasModelDefinition) {
try {
Expand Down Expand Up @@ -140,9 +143,6 @@ protected void masterOperation(Task task,
}
}




TrainedModelConfig.Builder trainedModelConfig = new TrainedModelConfig.Builder(config)
.setVersion(Version.CURRENT)
.setCreateTime(Instant.now())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
Expand Down Expand Up @@ -137,16 +136,16 @@ public void storeTrainedModel(TrainedModelConfig trainedModelConfig,
return;
}

BytesReference definition;
try {
trainedModelConfig.ensureParsedDefinition(xContentRegistry);
definition = trainedModelConfig.getCompressedDefinition();
} catch (IOException ex) {
listener.onFailure(ExceptionsHelper.serverError(
"Unexpected serialization error when parsing model definition for model [" + trainedModelConfig.getModelId() + "]",
ex));
"Unexpected IOException while serializing definition for storage for model [{}]",
ex,
trainedModelConfig.getModelId()));
return;
}

TrainedModelDefinition definition = trainedModelConfig.getModelDefinition();
TrainedModelLocation location = trainedModelConfig.getLocation();
if (definition == null && location == null) {
listener.onFailure(ExceptionsHelper.badRequestException("Unable to store [{}]. [{}] or [{}] is required",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ public String getName() {
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
String id = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName());
XContentParser parser = restRequest.contentParser();
PutTrainedModelAction.Request putRequest = PutTrainedModelAction.Request.parseRequest(id, parser);
boolean deferDefinitionDecompression = restRequest.paramAsBoolean(PutTrainedModelAction.DEFER_DEFINITION_DECOMPRESSION, false);
PutTrainedModelAction.Request putRequest = PutTrainedModelAction.Request.parseRequest(id, deferDefinitionDecompression, parser);
putRequest.timeout(restRequest.paramAsTime("timeout", putRequest.timeout()));

return channel -> client.execute(PutTrainedModelAction.INSTANCE, putRequest, new RestToXContentListener<>(channel));
}
}
Loading

0 comments on commit 02e17c3

Please sign in to comment.