diff --git a/docs/reference/ml/df-analytics/apis/put-trained-models.asciidoc b/docs/reference/ml/df-analytics/apis/put-trained-models.asciidoc index 9f4b63d1283d4..b8d6c8ced2d09 100644 --- a/docs/reference/ml/df-analytics/apis/put-trained-models.asciidoc +++ b/docs/reference/ml/df-analytics/apis/put-trained-models.asciidoc @@ -24,7 +24,7 @@ WARNING: Models created in version 7.8.0 are not backwards compatible [[ml-put-trained-models-prereq]] == {api-prereq-title} -Requires the `manage_ml` cluster privilege. This privilege is included in the +Requires the `manage_ml` cluster privilege. This privilege is included in the `machine_learning_admin` built-in role. @@ -42,6 +42,17 @@ created by {dfanalytics}. (Required, string) include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id] +[[ml-put-trained-models-query-params]] +== {api-query-parms-title} + +`defer_definition_decompression`:: +(Optional, boolean) +If set to `true` and a `compressed_definition` is provided, the request defers +definition decompression and skips relevant validations. +This deferral is useful for systems or users that know a good JVM heap size estimate for their +model and know that their model is valid and likely won't fail during inference. + + [role="child_attributes"] [[ml-put-trained-models-request-body]] == {api-request-body-title} diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/ml.put_trained_model.json b/rest-api-spec/src/main/resources/rest-api-spec/api/ml.put_trained_model.json index 28cb5821cea18..af3ef880bf6bb 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/api/ml.put_trained_model.json +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/ml.put_trained_model.json @@ -26,6 +26,14 @@ } ] }, + "params":{ + "defer_definition_decompression": { + "required": false, + "type": "boolean", + "description": "If set to `true` and a `compressed_definition` is provided, the request defers definition decompression and skips relevant validations.", + "default": false + } + }, "body":{ "description":"The trained model configuration", "required":true diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAction.java index 23c6b8e499812..cbf5814a99863 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAction.java @@ -6,6 +6,7 @@ */ package org.elasticsearch.xpack.core.ml.action; +import org.elasticsearch.Version; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionType; @@ -22,9 +23,12 @@ import java.io.IOException; import java.util.Objects; +import static org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig.ESTIMATED_HEAP_MEMORY_USAGE_BYTES; + public class PutTrainedModelAction extends ActionType<PutTrainedModelAction.Response> { + public static final String DEFER_DEFINITION_DECOMPRESSION = "defer_definition_decompression"; public static final PutTrainedModelAction INSTANCE = new PutTrainedModelAction(); public static final String NAME = "cluster:admin/xpack/ml/inference/put"; private PutTrainedModelAction() { @@ -33,7 +37,7 @@ private PutTrainedModelAction() { public static class Request extends AcknowledgedRequest<Request> { - public static Request parseRequest(String modelId, XContentParser parser) { + public static Request parseRequest(String modelId, boolean deferDefinitionValidation, XContentParser parser) { TrainedModelConfig.Builder builder = TrainedModelConfig.STRICT_PARSER.apply(parser, null); if (builder.getModelId() == null) { @@ -47,18 +51,25 @@ public static Request parseRequest(String modelId, XContentParser parser) { } // Validations are done against the builder so we can build the full config object. // This allows us to not worry about serializing a builder class between nodes. - return new Request(builder.validate(true).build()); + return new Request(builder.validate(true).build(), deferDefinitionValidation); } private final TrainedModelConfig config; + private final boolean deferDefinitionDecompression; - public Request(TrainedModelConfig config) { + public Request(TrainedModelConfig config, boolean deferDefinitionDecompression) { this.config = config; + this.deferDefinitionDecompression = deferDefinitionDecompression; } public Request(StreamInput in) throws IOException { super(in); this.config = new TrainedModelConfig(in); + if (in.getVersion().onOrAfter(Version.V_8_0_0)) { + this.deferDefinitionDecompression = in.readBoolean(); + } else { + this.deferDefinitionDecompression = false; + } } public TrainedModelConfig getTrainedModelConfig() { @@ -67,13 +78,31 @@ public TrainedModelConfig getTrainedModelConfig() { @Override public ActionRequestValidationException validate() { + if (deferDefinitionDecompression + && config.getEstimatedHeapMemory() == 0 + && config.getCompressedDefinitionIfSet() != null) { + ActionRequestValidationException validationException = new ActionRequestValidationException(); + validationException.addValidationError( + "when [" + + DEFER_DEFINITION_DECOMPRESSION + + "] is true and a compressed definition is provided, " + ESTIMATED_HEAP_MEMORY_USAGE_BYTES + " must be set" + ); + return validationException; + } return null; } + public boolean isDeferDefinitionDecompression() { + return deferDefinitionDecompression; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); config.writeTo(out); + if (out.getVersion().onOrAfter(Version.V_8_0_0)) { + out.writeBoolean(deferDefinitionDecompression); + } } @Override @@ -81,12 +110,12 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; Request request = (Request) o; - return Objects.equals(config, request.config); + return Objects.equals(config, request.config) && deferDefinitionDecompression == request.deferDefinitionDecompression; } @Override public int hashCode() { - return Objects.hash(config); + return Objects.hash(config, deferDefinitionDecompression); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index e09b16957770f..dd74ebcebe05d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -283,6 +283,14 @@ public BytesReference getCompressedDefinition() throws IOException { return definition.getCompressedDefinition(); } + public BytesReference getCompressedDefinitionIfSet() { + if (definition == null) { + return null; + } + return definition.getCompressedDefinitionIfSet(); + } + + public void clearCompressed() { definition.compressedRepresentation = null; } @@ -704,6 +712,7 @@ public Builder validate() { /** * Runs validations against the builder. + * @param forCreation indicates if we should validate for model creation or for a model read from storage * @return The current builder object if validations are successful * @throws ActionRequestValidationException when there are validation failures. */ @@ -773,12 +782,6 @@ public Builder validate(boolean forCreation) { validationException = checkIllegalSetting(version, VERSION.getPreferredName(), validationException); validationException = checkIllegalSetting(createdBy, CREATED_BY.getPreferredName(), validationException); validationException = checkIllegalSetting(createTime, CREATE_TIME.getPreferredName(), validationException); - validationException = checkIllegalSetting(estimatedHeapMemory, - ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(), - validationException); - validationException = checkIllegalSetting(estimatedOperations, - ESTIMATED_OPERATIONS.getPreferredName(), - validationException); validationException = checkIllegalSetting(licenseLevel, LICENSE_LEVEL.getPreferredName(), validationException); if (metadata != null) { validationException = checkIllegalSetting( @@ -877,6 +880,10 @@ private BytesReference getCompressedDefinition() throws IOException { return compressedRepresentation; } + private BytesReference getCompressedDefinitionIfSet() { + return compressedRepresentation; + } + private String getBase64CompressedDefinition() throws IOException { BytesReference compressedDef = getCompressedDefinition(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionRequestTests.java index 67a7401b9a3d6..597c99ecec050 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionRequestTests.java @@ -20,9 +20,12 @@ public class PutTrainedModelActionRequestTests extends AbstractWireSerializingTe @Override protected Request createTestInstance() { String modelId = randomAlphaOfLength(10); - return new Request(TrainedModelConfigTests.createTestInstance(modelId) - .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder()) - .build()); + return new Request( + TrainedModelConfigTests.createTestInstance(modelId) + .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder()) + .build(), + randomBoolean() + ); } @Override diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index 5ecfefb85bd95..c50c7090aa459 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -158,6 +158,8 @@ tasks.named("yamlRestTest").configure { 'ml/inference_crud/Test update model alias where alias exists but reassign is false', 'ml/inference_crud/Test delete model alias with missing alias', 'ml/inference_crud/Test delete model alias where alias points to different model', + 'ml/inference_crud/Test put with defer_definition_decompression with invalid compression definition and no memory estimate', + 'ml/inference_crud/Test put with defer_definition_decompression with invalid definition and no memory estimate', 'ml/inference_processor/Test create processor with missing mandatory fields', 'ml/inference_stats_crud/Test get stats given missing trained model', 'ml/inference_stats_crud/Test get stats given expression without matches and allow_no_match is false', diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureResetIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureResetIT.java index b240e169c9a94..60e66fa7e2c57 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureResetIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureResetIT.java @@ -246,7 +246,8 @@ void createModelDeployment() { ) .setLocation(new IndexLocation(indexname)) .setModelId(TRAINED_MODEL_ID) - .build() + .build(), + false ) ) .actionGet(); diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/license/MachineLearningLicensingIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/license/MachineLearningLicensingIT.java index 5d0204f20679b..4a9ead33393e8 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/license/MachineLearningLicensingIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/license/MachineLearningLicensingIT.java @@ -747,7 +747,7 @@ private void putInferenceModel(String modelId) { .setInput(new TrainedModelInput(Collections.singletonList("feature1"))) .setInferenceConfig(RegressionConfig.EMPTY_PARAMS) .build(); - client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config)).actionGet(); + client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config, false)).actionGet(); } private static OperationMode randomInvalidLicenseType() { diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureLicenseTrackingIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureLicenseTrackingIT.java index 9af81a3a64529..3b6f2fd184e4f 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureLicenseTrackingIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureLicenseTrackingIT.java @@ -129,7 +129,7 @@ public void testFeatureTrackingInferenceModelPipeline() throws Exception { .setPreProcessors(Arrays.asList(new OneHotEncoding("other.categorical", oneHotEncoding, false))) .setTrainedModel(buildClassification(true))) .build(); - client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config)).actionGet(); + client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config, false)).actionGet(); String pipelineId = "pipeline-inference-model-tracked"; putTrainedModelIngestPipeline(pipelineId, modelId); diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/UnusedStatsRemoverIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/UnusedStatsRemoverIT.java index 8c39b6386b721..edd80d0c7881d 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/UnusedStatsRemoverIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/UnusedStatsRemoverIT.java @@ -85,7 +85,8 @@ public void testRemoveUnusedStats() throws Exception { .build()) ) .validate(true) - .build())).actionGet(); + .build(), + false)).actionGet(); indexStatDocument(new DataCounts("analytics-with-stats", 1, 1, 1), DataCounts.documentId("analytics-with-stats")); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java index 2be3eb4e16df3..72f4ac53fe0a1 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java @@ -76,7 +76,9 @@ protected void masterOperation(Task task, ActionListener<Response> listener) { TrainedModelConfig config = request.getTrainedModelConfig(); try { - config.ensureParsedDefinition(xContentRegistry); + if (request.isDeferDefinitionDecompression() == false) { + config.ensureParsedDefinition(xContentRegistry); + } } catch (IOException ex) { listener.onFailure(ExceptionsHelper.badRequestException("Failed to parse definition for [{}]", ex, @@ -84,6 +86,7 @@ protected void masterOperation(Task task, return; } + // NOTE: hasModelDefinition is false if we don't parse it. But, if the fully parsed model was already provided, continue boolean hasModelDefinition = config.getModelDefinition() != null; if (hasModelDefinition) { try { @@ -140,9 +143,6 @@ protected void masterOperation(Task task, } } - - - TrainedModelConfig.Builder trainedModelConfig = new TrainedModelConfig.Builder(config) .setVersion(Version.CURRENT) .setCreateTime(Instant.now()) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index eef2c5ed6758c..dc66365b31d70 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -72,7 +72,6 @@ import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats; @@ -137,16 +136,16 @@ public void storeTrainedModel(TrainedModelConfig trainedModelConfig, return; } + BytesReference definition; try { - trainedModelConfig.ensureParsedDefinition(xContentRegistry); + definition = trainedModelConfig.getCompressedDefinition(); } catch (IOException ex) { listener.onFailure(ExceptionsHelper.serverError( - "Unexpected serialization error when parsing model definition for model [" + trainedModelConfig.getModelId() + "]", - ex)); + "Unexpected IOException while serializing definition for storage for model [{}]", + ex, + trainedModelConfig.getModelId())); return; } - - TrainedModelDefinition definition = trainedModelConfig.getModelDefinition(); TrainedModelLocation location = trainedModelConfig.getLocation(); if (definition == null && location == null) { listener.onFailure(ExceptionsHelper.badRequestException("Unable to store [{}]. [{}] or [{}] is required", diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelAction.java index 08cd18451e72f..effacbae02cc6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelAction.java @@ -40,9 +40,9 @@ public String getName() { protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { String id = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName()); XContentParser parser = restRequest.contentParser(); - PutTrainedModelAction.Request putRequest = PutTrainedModelAction.Request.parseRequest(id, parser); + boolean deferDefinitionDecompression = restRequest.paramAsBoolean(PutTrainedModelAction.DEFER_DEFINITION_DECOMPRESSION, false); + PutTrainedModelAction.Request putRequest = PutTrainedModelAction.Request.parseRequest(id, deferDefinitionDecompression, parser); putRequest.timeout(restRequest.paramAsTime("timeout", putRequest.timeout())); - return channel -> client.execute(PutTrainedModelAction.INSTANCE, putRequest, new RestToXContentListener<>(channel)); } } diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/inference_crud.yml index dafe223bfd60d..df75531eecd68 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/inference_crud.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/inference_crud.yml @@ -954,3 +954,78 @@ setup: ml.delete_trained_model_alias: model_alias: "regression-model" model_id: "a-regression-model-0" +--- +"Test put with defer_definition_decompression with invalid compressed definition": + - do: + ml.put_trained_model: + defer_definition_decompression: true + model_id: my-regression-model-with-bad-compressed-definition + body: > + { + "description": "model for tests", + "input": {"field_names": ["field1", "field2"]}, + "inference_config": {"classification": {}}, + "estimated_heap_memory_usage_bytes": 1024, + "compressed_definition": "H4sIAAAAAAAAAEy92a5mW26l9y55HWdj9o3u9RS+SMil4yrBUgpIpywY9fLmR3LMFSpI" + } + +--- +"Test put with defer_definition_decompression with invalid compression definition and no memory estimate": + - do: + catch: /when \[defer_definition_decompression\] is true and a compressed definition is provided, estimated_heap_memory_usage_bytes must be set/ + ml.put_trained_model: + defer_definition_decompression: true + model_id: my-regression-model-compressed-failed + body: > + { + "description": "model for tests", + "input": {"field_names": ["field1", "field2"]}, + "inference_config": {"classification": {}}, + "compressed_definition": "H4sIAAAAAAAAAEy92a5mW26l9y55HWdj9o3u9RS+SMil4yrBUgpIpywY9fLmR3LMFSpI" + } + +--- +"Test put with defer_definition_decompression with invalid definition and no memory estimate": + - do: + catch: /Model \[my-regression-model\] inference config type \[classification\] does not support definition target type \[regression\]/ + ml.put_trained_model: + defer_definition_decompression: true + model_id: my-regression-model + body: > + { + "description": "model for tests", + "input": {"field_names": ["field1", "field2"]}, + "inference_config": {"classification": {}}, + "definition": { + "preprocessors": [], + "trained_model": { + "ensemble": { + "target_type": "regression", + "trained_models": [ + { + "tree": { + "feature_names": ["field1", "field2"], + "tree_structure": [ + {"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2}, + {"node_index": 1, "leaf_value": 0}, + {"node_index": 2, "leaf_value": 1} + ], + "target_type": "regression" + } + }, + { + "tree": { + "feature_names": ["field1", "field2"], + "tree_structure": [ + {"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2}, + {"node_index": 1, "leaf_value": 0}, + {"node_index": 2, "leaf_value": 1} + ], + "target_type": "regression" + } + } + ] + } + } + } + }