From 148bf11f925717b1a4731b451dae88b5053fdb62 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Wed, 11 Jan 2023 15:20:26 +0000 Subject: [PATCH] [ML] Utilise parallel allocations where the inference request contains multiple documents (#92819) Divide work from the _infer API among all allocations Backport of #92359 --- docs/changelog/92359.yaml | 6 + .../infer-trained-model-deployment.asciidoc | 12 +- .../apis/infer-trained-model.asciidoc | 16 +- .../core/ml/action/InferModelAction.java | 53 +++-- .../InferTrainedModelDeploymentAction.java | 49 +++-- .../assignment/TrainedModelAssignment.java | 39 +++- .../TextEmbeddingConfigUpdate.java | 2 + .../action/InferModelActionRequestTests.java | 3 +- ...erTrainedModelDeploymentResponseTests.java | 71 +++++++ .../TrainedModelAssignmentTests.java | 58 +++--- .../xpack/ml/integration/PyTorchModelIT.java | 117 ++++++++++- ...portInferTrainedModelDeploymentAction.java | 99 ++------- .../TransportInternalInferModelAction.java | 197 ++++++++++++------ .../inference/ingest/InferenceProcessor.java | 2 +- .../loadingservice/ModelLoadingService.java | 21 +- .../RestInferTrainedModelAction.java | 11 +- ...RestInferTrainedModelDeploymentAction.java | 38 +++- 17 files changed, 518 insertions(+), 276 deletions(-) create mode 100644 docs/changelog/92359.yaml create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentResponseTests.java diff --git a/docs/changelog/92359.yaml b/docs/changelog/92359.yaml new file mode 100644 index 0000000000000..cf76f3713a139 --- /dev/null +++ b/docs/changelog/92359.yaml @@ -0,0 +1,6 @@ +pr: 92359 +summary: Utilise parallel allocations where the inference request contains multiple + documents +area: Machine Learning +type: bug +issues: [] diff --git a/docs/reference/ml/trained-models/apis/infer-trained-model-deployment.asciidoc b/docs/reference/ml/trained-models/apis/infer-trained-model-deployment.asciidoc index 132d8e9de0700..d92d74d894a33 100644 --- a/docs/reference/ml/trained-models/apis/infer-trained-model-deployment.asciidoc +++ b/docs/reference/ml/trained-models/apis/infer-trained-model-deployment.asciidoc @@ -46,8 +46,8 @@ Controls the amount of time to wait for {infer} results. Defaults to 10 seconds. `docs`:: (Required, array) An array of objects to pass to the model for inference. The objects should -contain a field matching your configured trained model input. Typically, the -field name is `text_field`. Currently, only a single value is allowed. +contain a field matching your configured trained model input. Typically, the +field name is `text_field`. //// [[infer-trained-model-deployment-results]] @@ -62,7 +62,7 @@ field name is `text_field`. Currently, only a single value is allowed. [[infer-trained-model-deployment-example]] == {api-examples-title} -The response depends on the task the model is trained for. If it is a text +The response depends on the task the model is trained for. If it is a text classification task, the response is the score. For example: [source,console] @@ -123,7 +123,7 @@ The API returns in this case: ---- // NOTCONSOLE -Zero-shot classification tasks require extra configuration defining the class +Zero-shot classification tasks require extra configuration defining the class labels. These labels are passed in the zero-shot inference config. [source,console] @@ -150,7 +150,7 @@ POST _ml/trained_models/model2/deployment/_infer -------------------------------------------------- // TEST[skip:TBD] -The API returns the predicted label and the confidence, as well as the top +The API returns the predicted label and the confidence, as well as the top classes: [source,console-result] @@ -205,7 +205,7 @@ POST _ml/trained_models/model2/deployment/_infer -------------------------------------------------- // TEST[skip:TBD] -When the input has been truncated due to the limit imposed by the model's +When the input has been truncated due to the limit imposed by the model's `max_sequence_length` the `is_truncated` field appears in the response. [source,console-result] diff --git a/docs/reference/ml/trained-models/apis/infer-trained-model.asciidoc b/docs/reference/ml/trained-models/apis/infer-trained-model.asciidoc index 9f64a4a0e10dd..d5dbc90bfab0d 100644 --- a/docs/reference/ml/trained-models/apis/infer-trained-model.asciidoc +++ b/docs/reference/ml/trained-models/apis/infer-trained-model.asciidoc @@ -6,10 +6,10 @@ Infer trained model ++++ -Evaluates a trained model. The model may be any supervised model either trained +Evaluates a trained model. The model may be any supervised model either trained by {dfanalytics} or imported. -NOTE: For model deployments with caching enabled, results may be returned +NOTE: For model deployments with caching enabled, results may be returned directly from the {infer} cache. beta::[] @@ -51,9 +51,7 @@ Controls the amount of time to wait for {infer} results. Defaults to 10 seconds. (Required, array) An array of objects to pass to the model for inference. The objects should contain the fields matching your configured trained model input. Typically for -NLP models, the field name is `text_field`. Currently for NLP models, only a -single value is allowed. For {dfanalytics} or imported classification or -regression models, more than one value is allowed. +NLP models, the field name is `text_field`. //Begin inference_config `inference_config`:: @@ -106,7 +104,7 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-fill-mask] ===== `num_top_classes`:::: (Optional, integer) -Number of top predicted tokens to return for replacing the mask token. Defaults +Number of top predicted tokens to return for replacing the mask token. Defaults to `0`. `results_field`:::: @@ -277,7 +275,7 @@ The maximum amount of words in the answer. Defaults to `15`. `num_top_classes`:::: (Optional, integer) -The number the top found answers to return. Defaults to `0`, meaning only the +The number the top found answers to return. Defaults to `0`, meaning only the best found answer is returned. `question`:::: @@ -374,7 +372,7 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-text-classific `num_top_classes`:::: (Optional, integer) -Specifies the number of top class predictions to return. Defaults to all classes +Specifies the number of top class predictions to return. Defaults to all classes (-1). `results_field`:::: @@ -886,7 +884,7 @@ POST _ml/trained_models/model2/_infer -------------------------------------------------- // TEST[skip:TBD] -When the input has been truncated due to the limit imposed by the model's +When the input has been truncated due to the limit imposed by the model's `max_sequence_length` the `is_truncated` field appears in the response. [source,console-result] diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java index e2e957481680e..7182dfe18f028 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java @@ -23,11 +23,12 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; @@ -80,43 +81,29 @@ public static Builder parseRequest(String modelId, XContentParser parser) { private final List> objectsToInfer; private final InferenceConfigUpdate update; private final boolean previouslyLicensed; - private final TimeValue timeout; - - public Request(String modelId, boolean previouslyLicensed) { - this(modelId, Collections.emptyList(), RegressionConfigUpdate.EMPTY_PARAMS, TimeValue.MAX_VALUE, previouslyLicensed); - } + private TimeValue timeout; public Request( String modelId, List> objectsToInfer, - InferenceConfigUpdate inferenceConfig, - TimeValue timeout, + InferenceConfigUpdate inferenceConfigUpdate, boolean previouslyLicensed ) { - this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID); - this.objectsToInfer = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(objectsToInfer, DOCS.getPreferredName())); - this.update = ExceptionsHelper.requireNonNull(inferenceConfig, "inference_config"); - this.previouslyLicensed = previouslyLicensed; - this.timeout = timeout; + this(modelId, objectsToInfer, inferenceConfigUpdate, DEFAULT_TIMEOUT, previouslyLicensed); } public Request( String modelId, List> objectsToInfer, - InferenceConfigUpdate inferenceConfig, + InferenceConfigUpdate inferenceConfigUpdate, + TimeValue timeout, boolean previouslyLicensed ) { - this(modelId, objectsToInfer, inferenceConfig, TimeValue.MAX_VALUE, previouslyLicensed); - } - - public Request(String modelId, Map objectToInfer, InferenceConfigUpdate update, boolean previouslyLicensed) { - this( - modelId, - Collections.singletonList(ExceptionsHelper.requireNonNull(objectToInfer, DOCS.getPreferredName())), - update, - TimeValue.MAX_VALUE, - previouslyLicensed - ); + this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID); + this.objectsToInfer = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(objectsToInfer, DOCS.getPreferredName())); + this.update = ExceptionsHelper.requireNonNull(inferenceConfigUpdate, "inference_config"); + this.previouslyLicensed = previouslyLicensed; + this.timeout = timeout; } public Request(StreamInput in) throws IOException { @@ -132,6 +119,10 @@ public Request(StreamInput in) throws IOException { } } + public int numberOfDocuments() { + return objectsToInfer.size(); + } + public String getModelId() { return modelId; } @@ -152,6 +143,10 @@ public TimeValue getTimeout() { return timeout; } + public void setTimeout(TimeValue timeout) { + this.timeout = timeout; + } + @Override public ActionRequestValidationException validate() { return null; @@ -196,7 +191,7 @@ public static class Builder { private String modelId; private List> docs; private TimeValue timeout; - private InferenceConfigUpdate update; + private InferenceConfigUpdate update = new EmptyConfigUpdate(); private Builder() {} @@ -302,12 +297,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } public static class Builder { - private List inferenceResults; + private List inferenceResults = new ArrayList<>(); private String modelId; private boolean isLicensed; - public Builder setInferenceResults(List inferenceResults) { - this.inferenceResults = inferenceResults; + public Builder addInferenceResults(List inferenceResults) { + this.inferenceResults.addAll(inferenceResults); return this; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentAction.java index f285da6b935ff..d49bba4a026bf 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentAction.java @@ -44,7 +44,14 @@ public class InferTrainedModelDeploymentAction extends ActionType 1) { - // TODO support multiple docs - validationException = addValidationError("multiple documents are not supported", validationException); - } } return validationException; } @@ -244,34 +247,44 @@ public Request build() { public static class Response extends BaseTasksResponse implements Writeable, ToXContentObject { - private final InferenceResults results; + private final List results; - public Response(InferenceResults result) { + public Response(List results) { super(Collections.emptyList(), Collections.emptyList()); - this.results = Objects.requireNonNull(result); + this.results = Objects.requireNonNull(results); } public Response(StreamInput in) throws IOException { super(in); - results = in.readNamedWriteable(InferenceResults.class); - } - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - results.toXContent(builder, params); - builder.endObject(); - return builder; + // Multiple results added in 8.6.1 + if (in.getVersion().onOrAfter(Version.V_8_6_1)) { + results = in.readNamedWriteableList(InferenceResults.class); + } else { + results = List.of(in.readNamedWriteable(InferenceResults.class)); + } } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeNamedWriteable(results); + if (out.getVersion().onOrAfter(Version.V_8_6_1)) { + out.writeNamedWriteableList(results); + } else { + out.writeNamedWriteable(results.get(0)); + } } - public InferenceResults getResults() { + public List getResults() { return results; } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + results.get(0).toXContent(builder, params); + builder.endObject(); + return builder; + } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java index 901416a71b513..0468514d31eb1 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.Randomness; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Tuple; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.ParseField; @@ -177,7 +178,7 @@ public String[] getStartedNodes() { .toArray(String[]::new); } - public Optional selectRandomStartedNodeWeighedOnAllocations() { + public List> selectRandomStartedNodesWeighedOnAllocationsForNRequests(int numberOfRequests) { List nodeIds = new ArrayList<>(nodeRoutingTable.size()); List cumulativeAllocations = new ArrayList<>(nodeRoutingTable.size()); int allocationSum = 0; @@ -189,18 +190,42 @@ public Optional selectRandomStartedNodeWeighedOnAllocations() { } } + if (nodeIds.isEmpty()) { + return List.of(); + } + if (allocationSum == 0) { // If we are in a mixed cluster where there are assignments prior to introducing allocation distribution // we could have a zero-sum of allocations. We fall back to returning a random started node. - return nodeIds.isEmpty() ? Optional.empty() : Optional.of(nodeIds.get(Randomness.get().nextInt(nodeIds.size()))); + int[] counts = new int[nodeIds.size()]; + for (int i = 0; i < numberOfRequests; i++) { + counts[Randomness.get().nextInt(nodeIds.size())]++; + } + + var nodeCounts = new ArrayList>(); + for (int i = 0; i < counts.length; i++) { + nodeCounts.add(new Tuple<>(nodeIds.get(i), counts[i])); + } + return nodeCounts; + } + + int[] counts = new int[nodeIds.size()]; + var randomIter = Randomness.get().ints(numberOfRequests, 1, allocationSum + 1).iterator(); + for (int i = 0; i < numberOfRequests; i++) { + int randomInt = randomIter.nextInt(); + int nodeIndex = Collections.binarySearch(cumulativeAllocations, randomInt); + if (nodeIndex < 0) { + nodeIndex = -nodeIndex - 1; + } + + counts[nodeIndex]++; } - int randomInt = Randomness.get().ints(1, 1, allocationSum + 1).iterator().nextInt(); - int nodeIndex = Collections.binarySearch(cumulativeAllocations, randomInt); - if (nodeIndex < 0) { - nodeIndex = -nodeIndex - 1; + var nodeCounts = new ArrayList>(); + for (int i = 0; i < counts.length; i++) { + nodeCounts.add(new Tuple<>(nodeIds.get(i), counts[i])); } - return Optional.of(nodeIds.get(nodeIndex)); + return nodeCounts; } public Optional getReason() { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigUpdate.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigUpdate.java index 589b71bd631d0..d8884069997d0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigUpdate.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigUpdate.java @@ -28,6 +28,8 @@ public class TextEmbeddingConfigUpdate extends NlpConfigUpdate implements NamedX public static final String NAME = TextEmbeddingConfig.NAME; + public static TextEmbeddingConfigUpdate EMPTY_INSTANCE = new TextEmbeddingConfigUpdate(null, null); + public static TextEmbeddingConfigUpdate fromMap(Map map) { Map options = new HashMap<>(map); String resultsField = (String) options.remove(RESULTS_FIELD.getPreferredName()); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java index ad3c926543528..0beb2bd0a9c38 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java @@ -53,7 +53,7 @@ protected Request createTestInstance() { TimeValue.parseTimeValue(randomTimeValue(), null, "test"), randomBoolean() ) - : new Request(randomAlphaOfLength(10), randomMap(), randomInferenceConfigUpdate(), randomBoolean()); + : new Request(randomAlphaOfLength(10), List.of(randomMap()), randomInferenceConfigUpdate(), randomBoolean()); } private static InferenceConfigUpdate randomInferenceConfigUpdate() { @@ -115,6 +115,7 @@ protected Request mutateInstanceForVersion(Request instance, Version version) { } else { adjustedUpdate = currentUpdate; } + return version.before(Version.V_8_3_0) ? new Request( instance.getModelId(), diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentResponseTests.java new file mode 100644 index 0000000000000..349894e9d0e34 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentResponseTests.java @@ -0,0 +1,71 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xcontent.NamedXContentRegistry; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResultsTests; +import org.junit.Before; + +import java.util.List; + +public class InferTrainedModelDeploymentResponseTests extends AbstractBWCWireSerializationTestCase< + InferTrainedModelDeploymentAction.Response> { + + private NamedWriteableRegistry namedWriteableRegistry; + private NamedXContentRegistry namedXContentRegistry; + + @Before + public void registerNamedXContents() { + namedXContentRegistry = new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + namedWriteableRegistry = new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables()); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + return namedXContentRegistry; + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return namedWriteableRegistry; + } + + @Override + protected Writeable.Reader instanceReader() { + return InferTrainedModelDeploymentAction.Response::new; + } + + @Override + protected InferTrainedModelDeploymentAction.Response createTestInstance() { + return new InferTrainedModelDeploymentAction.Response( + List.of( + TextEmbeddingResultsTests.createRandomResults(), + TextEmbeddingResultsTests.createRandomResults(), + TextEmbeddingResultsTests.createRandomResults(), + TextEmbeddingResultsTests.createRandomResults() + ) + ); + } + + @Override + protected InferTrainedModelDeploymentAction.Response mutateInstanceForVersion( + InferTrainedModelDeploymentAction.Response instance, + Version version + ) { + if (version.before(Version.V_8_6_1)) { + return new InferTrainedModelDeploymentAction.Response(instance.getResults().subList(0, 1)); + } + + return instance; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java index bb6d1904a99ab..f09dbc7afcdc3 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.core.Tuple; import org.elasticsearch.test.AbstractXContentSerializingTestCase; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; @@ -19,17 +20,16 @@ import org.elasticsearch.xpack.core.ml.stats.CountAccumulator; import java.io.IOException; +import java.util.HashMap; import java.util.HashSet; import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; import java.util.stream.Stream; import static org.hamcrest.Matchers.arrayContainingInAnyOrder; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.lessThanOrEqualTo; @@ -157,62 +157,62 @@ public void testCalculateAndSetAssignmentState_GivenStoppingAssignment() { ); } - public void testSelectRandomStartedNodeWeighedOnAllocations_GivenNoStartedAllocations() { + public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenNoStartedAllocations() { TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTING, "")); builder.addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STOPPED, "")); TrainedModelAssignment assignment = builder.build(); - assertThat(assignment.selectRandomStartedNodeWeighedOnAllocations().isEmpty(), is(true)); + assertThat(assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1).isEmpty(), is(true)); } - public void testSelectRandomStartedNodeWeighedOnAllocations_GivenSingleStartedNode() { + public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenSingleStartedNode() { TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTED, "")); TrainedModelAssignment assignment = builder.build(); - Optional node = assignment.selectRandomStartedNodeWeighedOnAllocations(); + var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1); - assertThat(node.isPresent(), is(true)); - assertThat(node.get(), equalTo("node-1")); + assertThat(nodes, hasSize(1)); + assertThat(nodes.get(0), equalTo(new Tuple<>("node-1", 1))); } - public void testSelectRandomStartedNodeWeighedOnAllocations_GivenMultipleStartedNodes() { + public void testSelectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenMultipleStartedNodes() { TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6)); builder.addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")); builder.addRoutingEntry("node-2", new RoutingInfo(2, 2, RoutingState.STARTED, "")); builder.addRoutingEntry("node-3", new RoutingInfo(3, 3, RoutingState.STARTED, "")); TrainedModelAssignment assignment = builder.build(); - final long selectionCount = 10000; + final int selectionCount = 10000; final CountAccumulator countsPerNodeAccumulator = new CountAccumulator(); - for (int i = 0; i < selectionCount; i++) { - Optional node = assignment.selectRandomStartedNodeWeighedOnAllocations(); - assertThat(node.isPresent(), is(true)); - countsPerNodeAccumulator.add(node.get(), 1L); - } + var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(selectionCount); - Map countsPerNode = countsPerNodeAccumulator.asMap(); - assertThat(countsPerNode.keySet(), contains("node-1", "node-2", "node-3")); - assertThat(countsPerNode.get("node-1") + countsPerNode.get("node-2") + countsPerNode.get("node-3"), equalTo(selectionCount)); + assertThat(nodes, hasSize(3)); + assertThat(nodes.stream().mapToInt(Tuple::v2).sum(), equalTo(selectionCount)); + var asMap = new HashMap(); + for (var node : nodes) { + asMap.put(node.v1(), node.v2()); + } - assertValueWithinPercentageOfExpectedRatio(countsPerNode.get("node-1"), selectionCount, 1.0 / 6.0, 0.2); - assertValueWithinPercentageOfExpectedRatio(countsPerNode.get("node-2"), selectionCount, 2.0 / 6.0, 0.2); - assertValueWithinPercentageOfExpectedRatio(countsPerNode.get("node-3"), selectionCount, 3.0 / 6.0, 0.2); + assertValueWithinPercentageOfExpectedRatio(asMap.get("node-1"), selectionCount, 1.0 / 6.0, 0.2); + assertValueWithinPercentageOfExpectedRatio(asMap.get("node-2"), selectionCount, 2.0 / 6.0, 0.2); + assertValueWithinPercentageOfExpectedRatio(asMap.get("node-3"), selectionCount, 3.0 / 6.0, 0.2); } - public void testSelectRandomStartedNodeWeighedOnAllocations_GivenMultipleStartedNodesWithZeroAllocations() { + public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenMultipleStartedNodesWithZeroAllocations() { TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6)); builder.addRoutingEntry("node-1", new RoutingInfo(0, 0, RoutingState.STARTED, "")); builder.addRoutingEntry("node-2", new RoutingInfo(0, 0, RoutingState.STARTED, "")); builder.addRoutingEntry("node-3", new RoutingInfo(0, 0, RoutingState.STARTED, "")); TrainedModelAssignment assignment = builder.build(); - final long selectionCount = 1000; - Set selectedNodes = new HashSet<>(); - for (int i = 0; i < selectionCount; i++) { - Optional selectedNode = assignment.selectRandomStartedNodeWeighedOnAllocations(); - assertThat(selectedNode.isPresent(), is(true)); - selectedNodes.add(selectedNode.get()); + final int selectionCount = 1000; + var nodeCounts = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(selectionCount); + assertThat(nodeCounts, hasSize(3)); + + var selectedNodes = new HashSet(); + for (var node : nodeCounts) { + selectedNodes.add(node.v1()); } assertThat(selectedNodes, contains("node-1", "node-2", "node-3")); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java index 650bfcba71cfa..588dfb9228148 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java @@ -20,8 +20,10 @@ import org.elasticsearch.xpack.core.ml.inference.assignment.Priority; import java.io.IOException; +import java.util.ArrayList; import java.util.Base64; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; @@ -93,6 +95,44 @@ public class PyTorchModelIT extends PyTorchModelRestTestCase { RAW_MODEL_SIZE = Base64.getDecoder().decode(BASE_64_ENCODED_MODEL).length; } + static final String BASE_64_ENCODED_TEXT_EMBEDDING_MODEL = "UEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAUAA4Ac2ltcGxlbW9kZWwvZGF0YS5wa2xGQgoAWl" + + "paWlpaWlpaWoACY19fdG9yY2hfXwpUaW55VGV4dEVtYmVkZGluZwpxACmBfShYCAAAAHRy" + + "YWluaW5ncQGJWBYAAABfaXNfZnVsbF9iYWNrd2FyZF9ob29rcQJOdWJxAy5QSwcIsFTQsF" + + "gAAABYAAAAUEsDBBQACAgIAAAAAAAAAAAAAAAAAAAAAAAdAB0Ac2ltcGxlbW9kZWwvY29k" + + "ZS9fX3RvcmNoX18ucHlGQhkAWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWoWPMWvDMBCF9/" + + "yKGy1IQ7Ia0q1j2yWbMYdsnWphWWd0Em3+fS3bBEopXd99j/dd77UI3Fy43+grvUwdGePC" + + "R/XKJntS9QEAcdZRT5QoCiJcoWnXtMvW/ohS1C4sZaihY/YFcoI2e4+d7sdPHQ0OzONyf5" + + "+T46B9U8DSNWTBcixMJeRtvQwkjv2AePpld1wKAC7MOaEzUsONgnDc4sQjBUz3mbbbY2qD" + + "2usbB9rQmcWV47/gOiVIReAvUsHT8y5S7yKL/mnSIWuPQmSqLRm0DJWkWD0eUEqtjUgpx7" + + "AXow6mai5HuJzPrTp8A1BLBwiD/6yJ6gAAAKkBAABQSwMEFAAICAgAAAAAAAAAAAAAAAAA" + + "AAAAACcAQQBzaW1wbGVtb2RlbC9jb2RlL19fdG9yY2hfXy5weS5kZWJ1Z19wa2xGQj0AWl" + + "paWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpa" + + "WlpaWlpaWo2Qz0rDQBDGk/5RmjfwlmMCbWivBZ9gWL0IFkRCdLcmmOwmuxu0N08O3r2rCO" + + "rdx9CDgm/hWUUQMdugzUk6LCwzv++bGeak5YE1saoorNgCCwsbzFc9sm1PvivQo2zqToU8" + + "iiT1FEunfadXRcLzUocJVWN3i3ElZF3W4pDxUM9yVrPNXCeCR+lOLdp1190NwVktzoVKDF" + + "5COh+nQpbtsX+0/tjpOWYJuR8HMuJUZEEW8TJKQ8UY9eJIxZ7S0vvb3vf9yiCZLiV3Fz5v" + + "1HdHw6HvFK3JWnUElWR5ygbz8TThB4NMUJYG+axowyoWHbiHBwQbSWbHHXiEJ4QWkmOTPM" + + "MLQhvJaZOgSX49Z3a8uPq5Ia/whtBBctEkl4a8wwdCF8lVk1wb8glfCCtIbprkttntrkF0" + + "0Q1+AFBLBwi4BIswOAEAAP0BAABQSwMEAAAICAAAAAAAAAAAAAAAAAAAAAAAABkAQQBzaW" + + "1wbGVtb2RlbC9jb25zdGFudHMucGtsRkI9AFpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpa" + + "WlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlqAAikuUEsHCG0vCVcEAAAABA" + + "AAAFBLAwQAAAgIAAAAAAAAAAAAAAAAAAAAAAAAEwA7AHNpbXBsZW1vZGVsL3ZlcnNpb25G" + + "QjcAWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWl" + + "paWlpaWjMKUEsHCNGeZ1UCAAAAAgAAAFBLAQIAAAAACAgAAAAAAACwVNCwWAAAAFgAAAAU" + + "AAAAAAAAAAAAAAAAAAAAAABzaW1wbGVtb2RlbC9kYXRhLnBrbFBLAQIAABQACAgIAAAAAA" + + "CD/6yJ6gAAAKkBAAAdAAAAAAAAAAAAAAAAAKgAAABzaW1wbGVtb2RlbC9jb2RlL19fdG9y" + + "Y2hfXy5weVBLAQIAABQACAgIAAAAAAC4BIswOAEAAP0BAAAnAAAAAAAAAAAAAAAAAPoBAA" + + "BzaW1wbGVtb2RlbC9jb2RlL19fdG9yY2hfXy5weS5kZWJ1Z19wa2xQSwECAAAAAAgIAAAA" + + "AAAAbS8JVwQAAAAEAAAAGQAAAAAAAAAAAAAAAADIAwAAc2ltcGxlbW9kZWwvY29uc3Rhbn" + + "RzLnBrbFBLAQIAAAAACAgAAAAAAADRnmdVAgAAAAIAAAATAAAAAAAAAAAAAAAAAFQEAABz" + + "aW1wbGVtb2RlbC92ZXJzaW9uUEsGBiwAAAAAAAAAHgMtAAAAAAAAAAAABQAAAAAAAAAFAA" + + "AAAAAAAGoBAAAAAAAA0gQAAAAAAABQSwYHAAAAADwGAAAAAAAAAQAAAFBLBQYAAAAABQAFAGoBAADSBAAAAAA="; + + static final long RAW_TEXT_EMBEDDING_MODEL_SIZE; // size of the model before base64 encoding + static { + RAW_TEXT_EMBEDDING_MODEL_SIZE = Base64.getDecoder().decode(BASE_64_ENCODED_TEXT_EMBEDDING_MODEL).length; + } + public void testEvaluate() throws IOException, InterruptedException { String modelId = "test_evaluate"; createPassThroughModel(modelId); @@ -418,6 +458,79 @@ public void testInferWithMissingModel() { assertThat(ex.getMessage(), containsString("Could not find trained model [missing_model]")); } + @SuppressWarnings("unchecked") + public void testInferWithMultipleDocs() throws IOException { + String modelId = "infer_multi_docs"; + // Use the text embedding model from SemanticSearchIT so + // that each response can be linked to the originating request. + // The test ensures the responses are returned in the same order + // as the requests + createTextEmbeddingModel(modelId); + putModelDefinition(modelId, BASE_64_ENCODED_TEXT_EMBEDDING_MODEL, RAW_TEXT_EMBEDDING_MODEL_SIZE); + putVocabulary( + List.of("these", "are", "my", "words", "the", "washing", "machine", "is", "leaking", "octopus", "comforter", "smells"), + modelId + ); + startDeployment(modelId, AllocationStatus.State.FULLY_ALLOCATED.toString()); + + List inputs = List.of( + "my words", + "the machine is leaking", + "washing machine", + "these are my words", + "the octopus comforter smells", + "the octopus comforter is leaking", + "washing machine smells" + ); + + List> expectedEmbeddings = new ArrayList<>(); + + // Generate the text embeddings one at a time using the _infer API + // then index them for search + for (var input : inputs) { + Response inference = infer(input, modelId); + List> responseMap = (List>) entityAsMap(inference).get("inference_results"); + Map inferenceResult = responseMap.get(0); + List embedding = (List) inferenceResult.get("predicted_value"); + expectedEmbeddings.add(embedding); + } + + // Now do the same with all documents sent at once + var docsBuilder = new StringBuilder(); + int numInputs = inputs.size(); + for (int i = 0; i < numInputs - 1; i++) { + docsBuilder.append("{\"input\":\"").append(inputs.get(i)).append("\"},"); + } + docsBuilder.append("{\"input\":\"").append(inputs.get(numInputs - 1)).append("\"}"); + + { + Request request = new Request("POST", "/_ml/trained_models/" + modelId + "/_infer"); + request.setJsonEntity(String.format(Locale.ROOT, """ + { "docs": [%s] } + """, docsBuilder)); + Response response = client().performRequest(request); + var responseMap = entityAsMap(response); + List> inferenceResults = (List>) responseMap.get("inference_results"); + assertThat(inferenceResults, hasSize(numInputs)); + + // Check the result order matches the input order by comparing + // the to the pre-calculated embeddings + for (int i = 0; i < numInputs; i++) { + List embedding = (List) inferenceResults.get(i).get("predicted_value"); + assertArrayEquals(expectedEmbeddings.get(i).toArray(), embedding.toArray()); + } + } + { + // the deprecated deployment/_infer endpoint does not support multiple docs + Request request = new Request("POST", "/_ml/trained_models/" + modelId + "/deployment/_infer"); + request.setJsonEntity(String.format(Locale.ROOT, """ + { "docs": [%s] } + """, docsBuilder)); + Exception ex = expectThrows(Exception.class, () -> client().performRequest(request)); + assertThat(ex.getMessage(), containsString("multiple documents are not supported")); + } + } + public void testGetPytorchModelWithDefinition() throws IOException { String model = "should-fail-get"; createPassThroughModel(model); @@ -475,7 +588,7 @@ public void testInferencePipelineAgainstUnallocatedModel() throws IOException { assertThat( response, allOf( - containsString("model [not-deployed] must be deployed to use. Please deploy with the start trained model deployment API."), + containsString("Model [not-deployed] must be deployed to use. Please deploy with the start trained model deployment API."), containsString("error"), not(containsString("warning")) ) @@ -498,7 +611,7 @@ public void testInferencePipelineAgainstUnallocatedModel() throws IOException { } """); Exception ex = expectThrows(Exception.class, () -> client().performRequest(request)); - assertThat(ex.getMessage(), containsString("Trained model [not-deployed] is not deployed.")); + assertThat(ex.getMessage(), containsString("Model [not-deployed] must be deployed to use.")); } public void testTruncation() throws IOException { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java index f4b8bb9246c92..ce41cf97b8a0f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java @@ -7,37 +7,27 @@ package org.elasticsearch.xpack.ml.action; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.FailedNodeException; import org.elasticsearch.action.TaskOperationFailure; import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.GroupedActionListener; import org.elasticsearch.action.support.tasks.TransportTasksAction; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; -import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; -import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; -import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState; -import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; -import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata; -import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentMetadata; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask; -import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import java.util.ArrayList; +import java.util.Collection; import java.util.List; -import java.util.Optional; - -import static org.elasticsearch.core.Strings.format; public class TransportInferTrainedModelDeploymentAction extends TransportTasksAction< TrainedModelDeploymentTask, @@ -45,16 +35,11 @@ public class TransportInferTrainedModelDeploymentAction extends TransportTasksAc InferTrainedModelDeploymentAction.Response, InferTrainedModelDeploymentAction.Response> { - private static final Logger logger = LogManager.getLogger(TransportInferTrainedModelDeploymentAction.class); - - private final TrainedModelProvider provider; - @Inject public TransportInferTrainedModelDeploymentAction( ClusterService clusterService, TransportService transportService, - ActionFilters actionFilters, - TrainedModelProvider provider + ActionFilters actionFilters ) { super( InferTrainedModelDeploymentAction.NAME, @@ -66,59 +51,6 @@ public TransportInferTrainedModelDeploymentAction( InferTrainedModelDeploymentAction.Response::new, ThreadPool.Names.SAME ); - this.provider = provider; - } - - @Override - protected void doExecute( - Task task, - InferTrainedModelDeploymentAction.Request request, - ActionListener listener - ) { - TaskId taskId = new TaskId(clusterService.localNode().getId(), task.getId()); - // Update the requests model ID if it's an alias - Optional.ofNullable(ModelAliasMetadata.fromState(clusterService.state()).getModelId(request.getModelId())) - .ifPresent(request::setModelId); - // We need to check whether there is at least an assigned task here, otherwise we cannot redirect to the - // node running the job task. - TrainedModelAssignment assignment = TrainedModelAssignmentMetadata.assignmentForModelId( - clusterService.state(), - request.getModelId() - ).orElse(null); - if (assignment == null) { - // If there is no assignment, verify the model even exists so that we can provide a nicer error message - provider.getTrainedModel(request.getModelId(), GetTrainedModelsAction.Includes.empty(), taskId, ActionListener.wrap(config -> { - if (config.getModelType() != TrainedModelType.PYTORCH) { - listener.onFailure( - ExceptionsHelper.badRequestException( - "Only [pytorch] models are supported by _infer, provided model [{}] has type [{}]", - config.getModelId(), - config.getModelType() - ) - ); - return; - } - String message = "Trained model [" + request.getModelId() + "] is not deployed"; - listener.onFailure(ExceptionsHelper.conflictStatusException(message)); - }, listener::onFailure)); - return; - } - if (assignment.getAssignmentState() == AssignmentState.STOPPING) { - String message = "Trained model [" + request.getModelId() + "] is STOPPING"; - listener.onFailure(ExceptionsHelper.conflictStatusException(message)); - return; - } - logger.trace(() -> format("[%s] selecting node from routing table: %s", assignment.getModelId(), assignment.getNodeRoutingTable())); - assignment.selectRandomStartedNodeWeighedOnAllocations().ifPresentOrElse(node -> { - logger.trace(() -> format("[%s] selected node [%s]", assignment.getModelId(), node)); - request.setNodes(node); - super.doExecute(task, request, listener); - }, () -> { - logger.trace(() -> format("[%s] model not allocated to any node [%s]", assignment.getModelId())); - listener.onFailure( - ExceptionsHelper.conflictStatusException("Trained model [" + request.getModelId() + "] is not allocated to any nodes") - ); - }); } @Override @@ -139,6 +71,7 @@ protected InferTrainedModelDeploymentAction.Response newResponse( request.getModelId() ); } else { + assert tasks.size() == 1; return tasks.get(0); } } @@ -151,16 +84,16 @@ protected void taskOperation( ActionListener listener ) { assert actionTask instanceof CancellableTask : "task [" + actionTask + "] not cancellable"; - task.infer( - request.getDocs().get(0), - request.getUpdate(), - request.isSkipQueue(), - request.getInferenceTimeout(), - actionTask, - ActionListener.wrap( - pyTorchResult -> listener.onResponse(new InferTrainedModelDeploymentAction.Response(pyTorchResult)), - listener::onFailure - ) + + // Multiple documents to infer on, wait for all results + ActionListener> collectingListener = ActionListener.wrap( + pyTorchResults -> { listener.onResponse(new InferTrainedModelDeploymentAction.Response(new ArrayList<>(pyTorchResults))); }, + listener::onFailure ); + + GroupedActionListener groupedListener = new GroupedActionListener<>(collectingListener, request.getDocs().size()); + for (var doc : request.getDocs()) { + task.infer(doc, request.getUpdate(), request.isSkipQueue(), request.getInferenceTimeout(), actionTask, groupedListener); + } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java index 31cfa27ad626c..c50d24952e939 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java @@ -6,13 +6,15 @@ */ package org.elasticsearch.xpack.ml.action; +import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; -import org.elasticsearch.core.TimeValue; +import org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.core.Tuple; import org.elasticsearch.license.License; import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; @@ -28,8 +30,11 @@ import org.elasticsearch.xpack.core.ml.action.InferModelAction.Request; import org.elasticsearch.xpack.core.ml.action.InferModelAction.Response; import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; +import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState; +import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata; import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentMetadata; @@ -38,9 +43,10 @@ import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor; -import java.util.Collections; -import java.util.Map; +import java.util.List; import java.util.Optional; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; @@ -132,19 +138,19 @@ private void doInfer( ) { String concreteModelId = Optional.ofNullable(ModelAliasMetadata.fromState(clusterService.state()).getModelId(request.getModelId())) .orElse(request.getModelId()); - if (isAllocatedModel(concreteModelId)) { + + responseBuilder.setModelId(concreteModelId); + + TrainedModelAssignmentMetadata trainedModelAssignmentMetadata = TrainedModelAssignmentMetadata.fromState(clusterService.state()); + + if (trainedModelAssignmentMetadata.isAssigned(concreteModelId)) { // It is important to use the resolved model ID here as the alias could change between transport calls. - inferAgainstAllocatedModel(request, concreteModelId, responseBuilder, parentTaskId, listener); + inferAgainstAllocatedModel(trainedModelAssignmentMetadata, request, concreteModelId, responseBuilder, parentTaskId, listener); } else { getModelAndInfer(request, responseBuilder, parentTaskId, (CancellableTask) task, listener); } } - private boolean isAllocatedModel(String modelId) { - TrainedModelAssignmentMetadata trainedModelAssignmentMetadata = TrainedModelAssignmentMetadata.fromState(clusterService.state()); - return trainedModelAssignmentMetadata.isAssigned(modelId); - } - private void getModelAndInfer( Request request, Response.Builder responseBuilder, @@ -169,75 +175,144 @@ private void getModelAndInfer( typedChainTaskExecutor.execute(ActionListener.wrap(inferenceResultsInterfaces -> { model.release(); - listener.onResponse(responseBuilder.setInferenceResults(inferenceResultsInterfaces).setModelId(model.getModelId()).build()); + listener.onResponse(responseBuilder.addInferenceResults(inferenceResultsInterfaces).build()); }, e -> { model.release(); listener.onFailure(e); })); - }, listener::onFailure); + }, e -> { + if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { + listener.onFailure(e); + return; + } + + // The model was found, check if a more relevant error message can be returned + trainedModelProvider.getTrainedModel( + request.getModelId(), + GetTrainedModelsAction.Includes.empty(), + parentTaskId, + ActionListener.wrap(trainedModelConfig -> { + if (trainedModelConfig.getModelType() == TrainedModelType.PYTORCH) { + // The PyTorch model cannot be allocated if we got here + listener.onFailure( + ExceptionsHelper.conflictStatusException( + "Model [" + + request.getModelId() + + "] must be deployed to use. Please deploy with the start trained model deployment API.", + request.getModelId() + ) + ); + } else { + // return the original error + listener.onFailure(e); + } + }, listener::onFailure) + ); + }); + // TODO should `getModelForInternalInference` be used here?? modelLoadingService.getModelForPipeline(request.getModelId(), parentTaskId, getModelListener); } private void inferAgainstAllocatedModel( + TrainedModelAssignmentMetadata assignmentMeta, Request request, String concreteModelId, Response.Builder responseBuilder, TaskId parentTaskId, ActionListener listener ) { - TypedChainTaskExecutor typedChainTaskExecutor = new TypedChainTaskExecutor<>( - client.threadPool().executor(ThreadPool.Names.SAME), - // run through all tasks - r -> true, - // Always fail immediately and return an error - ex -> true - ); - request.getObjectsToInfer() - .forEach( - stringObjectMap -> typedChainTaskExecutor.add( - chainedTask -> inferSingleDocAgainstAllocatedModel( - concreteModelId, - request.getTimeout(), - request.getUpdate(), - stringObjectMap, - parentTaskId, - chainedTask - ) - ) + TrainedModelAssignment assignment = assignmentMeta.getModelAssignment(concreteModelId); + + if (assignment.getAssignmentState() == AssignmentState.STOPPING) { + String message = "Trained model [" + request.getModelId() + "] is STOPPING"; + listener.onFailure(ExceptionsHelper.conflictStatusException(message)); + return; + } + + // Get a list of nodes to send the requests to and the number of + // documents for each node. + var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(request.numberOfDocuments()); + if (nodes.isEmpty()) { + logger.trace(() -> format("[%s] model not allocated to any node [%s]", assignment.getModelId())); + listener.onFailure( + ExceptionsHelper.conflictStatusException("Trained model [" + request.getModelId() + "] is not allocated to any nodes") ); + return; + } - typedChainTaskExecutor.execute( - ActionListener.wrap( - inferenceResults -> listener.onResponse( - responseBuilder.setInferenceResults(inferenceResults).setModelId(concreteModelId).build() - ), - listener::onFailure - ) - ); + assert nodes.stream().mapToInt(Tuple::v2).sum() == request.numberOfDocuments() + : "mismatch; sum of node requests does not match number of documents in request"; + + AtomicInteger count = new AtomicInteger(); + AtomicArray> results = new AtomicArray<>(nodes.size()); + AtomicReference failure = new AtomicReference<>(); + + int startPos = 0; + int slot = 0; + for (var node : nodes) { + InferTrainedModelDeploymentAction.Request deploymentRequest; + deploymentRequest = new InferTrainedModelDeploymentAction.Request( + concreteModelId, + request.getUpdate(), + request.getObjectsToInfer().subList(startPos, startPos + node.v2()), + request.getTimeout() + ); + deploymentRequest.setNodes(node.v1()); + deploymentRequest.setParentTask(parentTaskId); + + startPos += node.v2(); + + executeAsyncWithOrigin( + client, + ML_ORIGIN, + InferTrainedModelDeploymentAction.INSTANCE, + deploymentRequest, + collectingListener(count, results, failure, slot, nodes.size(), responseBuilder, listener) + ); + + slot++; + } } - private void inferSingleDocAgainstAllocatedModel( - String modelId, - TimeValue timeValue, - InferenceConfigUpdate inferenceConfigUpdate, - Map doc, - TaskId parentTaskId, - ActionListener listener + private ActionListener collectingListener( + AtomicInteger count, + AtomicArray> results, + AtomicReference failure, + int slot, + int totalNumberOfResponses, + Response.Builder responseBuilder, + ActionListener finalListener ) { - InferTrainedModelDeploymentAction.Request request = new InferTrainedModelDeploymentAction.Request( - modelId, - inferenceConfigUpdate, - Collections.singletonList(doc), - timeValue - ); - request.setParentTask(parentTaskId); - executeAsyncWithOrigin( - client, - ML_ORIGIN, - InferTrainedModelDeploymentAction.INSTANCE, - request, - ActionListener.wrap(r -> listener.onResponse(r.getResults()), listener::onFailure) - ); + return new ActionListener<>() { + @Override + public void onResponse(InferTrainedModelDeploymentAction.Response response) { + results.setOnce(slot, response.getResults()); + if (count.incrementAndGet() == totalNumberOfResponses) { + sendResponse(); + } + } + + @Override + public void onFailure(Exception e) { + failure.set(e); + if (count.incrementAndGet() == totalNumberOfResponses) { + sendResponse(); + } + } + + private void sendResponse() { + if (results.nonNullLength() > 0) { + for (int i = 0; i < results.length(); i++) { + if (results.get(i) != null) { + responseBuilder.addInferenceResults(results.get(i)); + } + } + finalListener.onResponse(responseBuilder.build()); + } else { + finalListener.onFailure(failure.get()); + } + } + }; } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java index 48d98c981e14e..7e78146b8c5c6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -150,7 +150,7 @@ InferModelAction.Request buildRequest(IngestDocument ingestDocument) { fields.put(INGEST_KEY, ingestDocument.getIngestMetadata()); } LocalModel.mapFieldsIfNecessary(fields, fieldMap); - return new InferModelAction.Request(modelId, fields, inferenceConfig, previouslyLicensed); + return new InferModelAction.Request(modelId, List.of(fields), inferenceConfig, previouslyLicensed); } void auditWarningAboutLicenseIfNecessary() { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index d99d81b9c39f2..f893341cbec81 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -355,10 +355,7 @@ private void loadModel(String modelId, Consumer consumer) { ); return; } - handleLoadFailure( - modelId, - new ElasticsearchStatusException("Trained model [{}] is not deployed.", RestStatus.BAD_REQUEST, modelId) - ); + handleLoadFailure(modelId, modelMustBeDeployedError(modelId)); return; } auditNewReferencedModel(modelId); @@ -409,13 +406,7 @@ private void loadWithoutCaching( ); return; } - modelActionListener.onFailure( - new ElasticsearchStatusException( - "model [{}] must be deployed to use. Please deploy with the start trained model deployment API.", - RestStatus.BAD_REQUEST, - modelId - ) - ); + modelActionListener.onFailure(modelMustBeDeployedError(modelId)); return; } // Verify we can pull the model into memory without causing OOM @@ -483,6 +474,14 @@ private void updateCircuitBreakerEstimate( } } + private ElasticsearchStatusException modelMustBeDeployedError(String modelId) { + return new ElasticsearchStatusException( + "Model [{}] must be deployed to use. Please deploy with the start trained model deployment API.", + RestStatus.BAD_REQUEST, + modelId + ); + } + private void handleLoadSuccess( String modelId, Consumer consumer, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestInferTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestInferTrainedModelAction.java index ce32e74a62723..9c8b3aa634a4a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestInferTrainedModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestInferTrainedModelAction.java @@ -14,9 +14,7 @@ import org.elasticsearch.rest.action.RestCancellableNodeClient; import org.elasticsearch.rest.action.RestToXContentListener; import org.elasticsearch.xpack.core.ml.action.InferModelAction; -import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -48,16 +46,13 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient } InferModelAction.Request.Builder request = InferModelAction.Request.parseRequest(modelId, restRequest.contentParser()); - if (restRequest.hasParam(InferTrainedModelDeploymentAction.Request.TIMEOUT.getPreferredName())) { + if (restRequest.hasParam(InferModelAction.Request.TIMEOUT.getPreferredName())) { TimeValue inferTimeout = restRequest.paramAsTime( - InferTrainedModelDeploymentAction.Request.TIMEOUT.getPreferredName(), - InferTrainedModelDeploymentAction.Request.DEFAULT_TIMEOUT + InferModelAction.Request.TIMEOUT.getPreferredName(), + InferModelAction.Request.DEFAULT_TIMEOUT ); request.setInferenceTimeout(inferTimeout); } - if (request.getUpdate() == null) { - request.setUpdate(new EmptyConfigUpdate()); - } return channel -> new RestCancellableNodeClient(client, restRequest.getHttpChannel()).execute( InferModelAction.EXTERNAL_INSTANCE, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestInferTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestInferTrainedModelDeploymentAction.java index 8f736569df425..01f10862bc5f9 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestInferTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestInferTrainedModelDeploymentAction.java @@ -7,13 +7,16 @@ package org.elasticsearch.xpack.ml.rest.inference; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.common.ValidationException; import org.elasticsearch.core.RestApiVersion; import org.elasticsearch.core.TimeValue; import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.action.RestCancellableNodeClient; import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -60,23 +63,36 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient if (restRequest.hasContent() == false) { throw ExceptionsHelper.badRequestException("requires body"); } - InferTrainedModelDeploymentAction.Request.Builder request = InferTrainedModelDeploymentAction.Request.parseRequest( - modelId, - restRequest.contentParser() - ); + InferModelAction.Request.Builder requestBuilder = InferModelAction.Request.parseRequest(modelId, restRequest.contentParser()); - if (restRequest.hasParam(InferTrainedModelDeploymentAction.Request.TIMEOUT.getPreferredName())) { + if (restRequest.hasParam(InferModelAction.Request.TIMEOUT.getPreferredName())) { TimeValue inferTimeout = restRequest.paramAsTime( - InferTrainedModelDeploymentAction.Request.TIMEOUT.getPreferredName(), - InferTrainedModelDeploymentAction.Request.DEFAULT_TIMEOUT + InferModelAction.Request.TIMEOUT.getPreferredName(), + InferModelAction.Request.DEFAULT_TIMEOUT ); - request.setInferenceTimeout(inferTimeout); + requestBuilder.setInferenceTimeout(inferTimeout); + } + + // Unlike the _infer API, deployment/_infer only accepts a single document + var request = requestBuilder.build(); + if (request.getObjectsToInfer() != null && request.getObjectsToInfer().size() > 1) { + ValidationException ex = new ValidationException(); + ex.addValidationError("multiple documents are not supported"); + throw ex; } return channel -> new RestCancellableNodeClient(client, restRequest.getHttpChannel()).execute( - InferTrainedModelDeploymentAction.INSTANCE, - request.build(), - new RestToXContentListener<>(channel) + InferModelAction.EXTERNAL_INSTANCE, + request, + // This API is deprecated but refactoring makes it simpler to call + // the new replacement API and swap in the old response. + ActionListener.wrap(response -> { + InferTrainedModelDeploymentAction.Response oldResponse = new InferTrainedModelDeploymentAction.Response( + response.getInferenceResults() + ); + new RestToXContentListener<>(channel).onResponse(oldResponse); + }, e -> new RestToXContentListener<>(channel).onFailure(e)) + ); } }