Commit
[ML] Utilise parallel allocations where the inference request contains multiple documents (elastic#92819)

Divide work from the _infer API among all allocations
Backport of elastic#92359
davidkyle authored Jan 11, 2023
1 parent bb37dc2 commit 148bf11
Showing 17 changed files with 518 additions and 276 deletions.
6 changes: 6 additions & 0 deletions docs/changelog/92359.yaml
@@ -0,0 +1,6 @@
pr: 92359
summary: Utilise parallel allocations where the inference request contains multiple
documents
area: Machine Learning
type: bug
issues: []
@@ -46,8 +46,8 @@ Controls the amount of time to wait for {infer} results. Defaults to 10 seconds.
`docs`::
(Required, array)
An array of objects to pass to the model for inference. The objects should
-contain a field matching your configured trained model input. Typically, the
-field name is `text_field`. Currently, only a single value is allowed.
+contain a field matching your configured trained model input. Typically, the
+field name is `text_field`.
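With this change the restriction to a single value is lifted: the `docs` array may hold several objects, and the deployment divides them among its allocations. An illustrative request (the model id and input texts are examples, not taken from this commit):

[source,console]
--------------------------------------------------
POST _ml/trained_models/model2/deployment/_infer
{
  "docs": [
    { "text_field": "first example document" },
    { "text_field": "second example document" }
  ]
}
--------------------------------------------------
// TEST[skip:TBD]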

////
[[infer-trained-model-deployment-results]]
@@ -62,7 +62,7 @@ field name is `text_field`. Currently, only a single value is allowed.
[[infer-trained-model-deployment-example]]
== {api-examples-title}

The response depends on the task the model is trained for. If it is a text
classification task, the response is the score. For example:

[source,console]
@@ -123,7 +123,7 @@ The API returns in this case:
----
// NOTCONSOLE

Zero-shot classification tasks require extra configuration defining the class
labels. These labels are passed in the zero-shot inference config.

[source,console]
@@ -150,7 +150,7 @@ POST _ml/trained_models/model2/deployment/_infer
--------------------------------------------------
// TEST[skip:TBD]

The API returns the predicted label and the confidence, as well as the top
classes:

[source,console-result]
@@ -205,7 +205,7 @@ POST _ml/trained_models/model2/deployment/_infer
--------------------------------------------------
// TEST[skip:TBD]

When the input has been truncated due to the limit imposed by the model's
`max_sequence_length` the `is_truncated` field appears in the response.

[source,console-result]
@@ -6,10 +6,10 @@
<titleabbrev>Infer trained model</titleabbrev>
++++

Evaluates a trained model. The model may be any supervised model either trained
by {dfanalytics} or imported.

NOTE: For model deployments with caching enabled, results may be returned
directly from the {infer} cache.

beta::[]
@@ -51,9 +51,7 @@ Controls the amount of time to wait for {infer} results. Defaults to 10 seconds.
(Required, array)
An array of objects to pass to the model for inference. The objects should
contain the fields matching your configured trained model input. Typically for
-NLP models, the field name is `text_field`. Currently for NLP models, only a
-single value is allowed. For {dfanalytics} or imported classification or
-regression models, more than one value is allowed.
+NLP models, the field name is `text_field`.

//Begin inference_config
`inference_config`::
@@ -106,7 +104,7 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-fill-mask]
=====
`num_top_classes`::::
(Optional, integer)
Number of top predicted tokens to return for replacing the mask token. Defaults
to `0`.

`results_field`::::
@@ -277,7 +275,7 @@ The maximum amount of words in the answer. Defaults to `15`.

`num_top_classes`::::
(Optional, integer)
The number the top found answers to return. Defaults to `0`, meaning only the
best found answer is returned.

`question`::::
@@ -374,7 +372,7 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-text-classific

`num_top_classes`::::
(Optional, integer)
Specifies the number of top class predictions to return. Defaults to all classes
(-1).

`results_field`::::
@@ -886,7 +884,7 @@ POST _ml/trained_models/model2/_infer
--------------------------------------------------
// TEST[skip:TBD]

When the input has been truncated due to the limit imposed by the model's
`max_sequence_length` the `is_truncated` field appears in the response.

[source,console-result]
@@ -23,11 +23,12 @@
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@@ -80,43 +81,29 @@ public static Builder parseRequest(String modelId, XContentParser parser) {
private final List<Map<String, Object>> objectsToInfer;
private final InferenceConfigUpdate update;
private final boolean previouslyLicensed;
private final TimeValue timeout;

public Request(String modelId, boolean previouslyLicensed) {
this(modelId, Collections.emptyList(), RegressionConfigUpdate.EMPTY_PARAMS, TimeValue.MAX_VALUE, previouslyLicensed);
}
private TimeValue timeout;

public Request(
String modelId,
List<Map<String, Object>> objectsToInfer,
InferenceConfigUpdate inferenceConfig,
TimeValue timeout,
InferenceConfigUpdate inferenceConfigUpdate,
boolean previouslyLicensed
) {
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
this.objectsToInfer = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(objectsToInfer, DOCS.getPreferredName()));
this.update = ExceptionsHelper.requireNonNull(inferenceConfig, "inference_config");
this.previouslyLicensed = previouslyLicensed;
this.timeout = timeout;
this(modelId, objectsToInfer, inferenceConfigUpdate, DEFAULT_TIMEOUT, previouslyLicensed);
}

public Request(
String modelId,
List<Map<String, Object>> objectsToInfer,
InferenceConfigUpdate inferenceConfig,
InferenceConfigUpdate inferenceConfigUpdate,
TimeValue timeout,
boolean previouslyLicensed
) {
this(modelId, objectsToInfer, inferenceConfig, TimeValue.MAX_VALUE, previouslyLicensed);
}

public Request(String modelId, Map<String, Object> objectToInfer, InferenceConfigUpdate update, boolean previouslyLicensed) {
this(
modelId,
Collections.singletonList(ExceptionsHelper.requireNonNull(objectToInfer, DOCS.getPreferredName())),
update,
TimeValue.MAX_VALUE,
previouslyLicensed
);
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
this.objectsToInfer = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(objectsToInfer, DOCS.getPreferredName()));
this.update = ExceptionsHelper.requireNonNull(inferenceConfigUpdate, "inference_config");
this.previouslyLicensed = previouslyLicensed;
this.timeout = timeout;
}

public Request(StreamInput in) throws IOException {
@@ -132,6 +119,10 @@ public Request(StreamInput in) throws IOException {
}
}

public int numberOfDocuments() {
return objectsToInfer.size();
}

public String getModelId() {
return modelId;
}
@@ -152,6 +143,10 @@ public TimeValue getTimeout() {
return timeout;
}

public void setTimeout(TimeValue timeout) {
this.timeout = timeout;
}

@Override
public ActionRequestValidationException validate() {
return null;
@@ -196,7 +191,7 @@ public static class Builder {
private String modelId;
private List<Map<String, Object>> docs;
private TimeValue timeout;
private InferenceConfigUpdate update;
private InferenceConfigUpdate update = new EmptyConfigUpdate();

private Builder() {}

@@ -302,12 +297,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}

public static class Builder {
private List<InferenceResults> inferenceResults;
private List<InferenceResults> inferenceResults = new ArrayList<>();
private String modelId;
private boolean isLicensed;

public Builder setInferenceResults(List<InferenceResults> inferenceResults) {
this.inferenceResults = inferenceResults;
public Builder addInferenceResults(List<InferenceResults> inferenceResults) {
this.inferenceResults.addAll(inferenceResults);
return this;
}

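In the reworked Request class above, the short constructor now delegates to the full one with `DEFAULT_TIMEOUT` (the docs state {infer} waits 10 seconds by default) instead of the old `TimeValue.MAX_VALUE`. A minimal standalone sketch of that telescoping pattern with simplified stand-in types — this is illustrative, not the production class:

```java
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.Objects;

public class InferRequestSketch {
    // Assumed default; the API docs say inference waits 10 seconds by default.
    static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(10);

    final String modelId;
    final List<Map<String, Object>> docs;
    final Duration timeout;

    // Short form delegates to the full form, supplying the default timeout.
    InferRequestSketch(String modelId, List<Map<String, Object>> docs) {
        this(modelId, docs, DEFAULT_TIMEOUT);
    }

    InferRequestSketch(String modelId, List<Map<String, Object>> docs, Duration timeout) {
        this.modelId = Objects.requireNonNull(modelId, "model_id");
        this.docs = List.copyOf(docs); // defensive, unmodifiable copy
        this.timeout = timeout;
    }

    // Mirrors the numberOfDocuments() accessor added in this commit.
    int numberOfDocuments() {
        return docs.size();
    }
}
```

Funnelling every constructor through one canonical form keeps the null checks in a single place, which is the same motive visible in the commit's consolidated constructor.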
@@ -44,7 +44,14 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo

public static final InferTrainedModelDeploymentAction INSTANCE = new InferTrainedModelDeploymentAction();

// TODO Review security level
/**
* Do not call this action directly, use InferModelAction instead
* which will perform various checks and set the node the request
* should execute on.
*
* The action is poorly named: it was once publicly accessible and exposed
* through a REST API, but now it _must_ only be called internally.
*/
public static final String NAME = "cluster:monitor/xpack/ml/trained_models/deployment/infer";

public InferTrainedModelDeploymentAction() {
@@ -157,10 +164,6 @@ public ActionRequestValidationException validate() {
if (docs.isEmpty()) {
validationException = addValidationError("at least one document is required", validationException);
}
if (docs.size() > 1) {
// TODO support multiple docs
validationException = addValidationError("multiple documents are not supported", validationException);
}
}
return validationException;
}
@@ -244,34 +247,44 @@ public Request build() {

public static class Response extends BaseTasksResponse implements Writeable, ToXContentObject {

private final InferenceResults results;
private final List<InferenceResults> results;

public Response(InferenceResults result) {
public Response(List<InferenceResults> results) {
super(Collections.emptyList(), Collections.emptyList());
this.results = Objects.requireNonNull(result);
this.results = Objects.requireNonNull(results);
}

public Response(StreamInput in) throws IOException {
super(in);
results = in.readNamedWriteable(InferenceResults.class);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
results.toXContent(builder, params);
builder.endObject();
return builder;
// Multiple results added in 8.6.1
if (in.getVersion().onOrAfter(Version.V_8_6_1)) {
results = in.readNamedWriteableList(InferenceResults.class);
} else {
results = List.of(in.readNamedWriteable(InferenceResults.class));
}
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeNamedWriteable(results);
if (out.getVersion().onOrAfter(Version.V_8_6_1)) {
out.writeNamedWriteableList(results);
} else {
out.writeNamedWriteable(results.get(0));
}
}

public InferenceResults getResults() {
public List<InferenceResults> getResults() {
return results;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
results.get(0).toXContent(builder, params);
builder.endObject();
return builder;
}
}
}
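The Response above now carries a list of results, while nodes before 8.6.1 expect exactly one result on the wire, so reading and writing both branch on the stream version. A simplified sketch of this backward-compatibility pattern, using plain Java data streams as stand-ins for Elasticsearch's StreamInput/StreamOutput (the version constant and types here are illustrative):

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class VersionGatedWire {
    // Stand-in for Version.V_8_6_1; any monotonically increasing id works.
    static final int V_8_6_1 = 8_06_01;

    // Newer streams carry the whole list; older streams get only the first result.
    static void writeResults(DataOutputStream out, int streamVersion, List<String> results) throws IOException {
        if (streamVersion >= V_8_6_1) {
            out.writeInt(results.size());
            for (String r : results) {
                out.writeUTF(r);
            }
        } else {
            out.writeUTF(results.get(0));
        }
    }

    static List<String> readResults(DataInputStream in, int streamVersion) throws IOException {
        List<String> results = new ArrayList<>();
        if (streamVersion >= V_8_6_1) {
            int count = in.readInt();
            for (int i = 0; i < count; i++) {
                results.add(in.readUTF());
            }
        } else {
            results.add(in.readUTF()); // old wire format: a single result
        }
        return results;
    }

    static List<String> roundTrip(List<String> results, int streamVersion) throws IOException {
        ByteArrayOutputStream buf = new ByteArrayOutputStream();
        writeResults(new DataOutputStream(buf), streamVersion, results);
        return readResults(new DataInputStream(new ByteArrayInputStream(buf.toByteArray())), streamVersion);
    }
}
```

Both peers must apply the same version check on read and write, which is why the commit gates `readNamedWriteableList`/`writeNamedWriteableList` on the same `Version.V_8_6_1` comparison in both paths.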
@@ -14,6 +14,7 @@
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
@@ -177,7 +178,7 @@ public String[] getStartedNodes() {
.toArray(String[]::new);
}

public Optional<String> selectRandomStartedNodeWeighedOnAllocations() {
public List<Tuple<String, Integer>> selectRandomStartedNodesWeighedOnAllocationsForNRequests(int numberOfRequests) {
List<String> nodeIds = new ArrayList<>(nodeRoutingTable.size());
List<Integer> cumulativeAllocations = new ArrayList<>(nodeRoutingTable.size());
int allocationSum = 0;
@@ -189,18 +190,42 @@ public Optional<String> selectRandomStartedNodeWeighedOnAllocations() {
}
}

if (nodeIds.isEmpty()) {
return List.of();
}

if (allocationSum == 0) {
// If we are in a mixed cluster where there are assignments prior to introducing allocation distribution
// we could have a zero-sum of allocations. We fall back to returning a random started node.
return nodeIds.isEmpty() ? Optional.empty() : Optional.of(nodeIds.get(Randomness.get().nextInt(nodeIds.size())));
int[] counts = new int[nodeIds.size()];
for (int i = 0; i < numberOfRequests; i++) {
counts[Randomness.get().nextInt(nodeIds.size())]++;
}

var nodeCounts = new ArrayList<Tuple<String, Integer>>();
for (int i = 0; i < counts.length; i++) {
nodeCounts.add(new Tuple<>(nodeIds.get(i), counts[i]));
}
return nodeCounts;
}

int[] counts = new int[nodeIds.size()];
var randomIter = Randomness.get().ints(numberOfRequests, 1, allocationSum + 1).iterator();
for (int i = 0; i < numberOfRequests; i++) {
int randomInt = randomIter.nextInt();
int nodeIndex = Collections.binarySearch(cumulativeAllocations, randomInt);
if (nodeIndex < 0) {
nodeIndex = -nodeIndex - 1;
}

counts[nodeIndex]++;
}

int randomInt = Randomness.get().ints(1, 1, allocationSum + 1).iterator().nextInt();
int nodeIndex = Collections.binarySearch(cumulativeAllocations, randomInt);
if (nodeIndex < 0) {
nodeIndex = -nodeIndex - 1;
var nodeCounts = new ArrayList<Tuple<String, Integer>>();
for (int i = 0; i < counts.length; i++) {
nodeCounts.add(new Tuple<>(nodeIds.get(i), counts[i]));
}
return Optional.of(nodeIds.get(nodeIndex));
return nodeCounts;
}

public Optional<String> getReason() {
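The new `selectRandomStartedNodesWeighedOnAllocationsForNRequests` spreads N requests across started nodes in proportion to each node's allocation count: it builds cumulative allocation totals, draws a uniform integer in `[1, allocationSum]`, and binary-searches the cumulative list. A self-contained sketch of that weighting scheme (class and method names here are illustrative, not the production code):

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

public class WeightedRequestSplit {

    /**
     * Distribute numberOfRequests across nodes in proportion to each node's
     * allocation count, using cumulative sums plus binary search as in the
     * commit. Returns a per-node request count.
     */
    static int[] distribute(int[] allocations, int numberOfRequests, Random random) {
        List<Integer> cumulative = new ArrayList<>(allocations.length);
        int allocationSum = 0;
        for (int a : allocations) {
            allocationSum += a;
            cumulative.add(allocationSum); // e.g. [2, 3] -> [2, 5]
        }

        int[] counts = new int[allocations.length];
        for (int i = 0; i < numberOfRequests; i++) {
            // Uniform draw in [1, allocationSum]; node j owns the draws in
            // the interval (cumulative[j-1], cumulative[j]].
            int draw = random.nextInt(allocationSum) + 1;
            int idx = Collections.binarySearch(cumulative, draw);
            if (idx < 0) {
                idx = -idx - 1; // binarySearch returns -(insertionPoint) - 1 on a miss
            }
            counts[idx]++;
        }
        return counts;
    }
}
```

With allocations `{2, 3}` and many requests, roughly 40% land on the first node and 60% on the second. The commit's zero-allocation fallback (a uniform choice over started nodes) covers mixed clusters where allocation counts are absent.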
@@ -28,6 +28,8 @@ public class TextEmbeddingConfigUpdate extends NlpConfigUpdate implements NamedX

public static final String NAME = TextEmbeddingConfig.NAME;

public static TextEmbeddingConfigUpdate EMPTY_INSTANCE = new TextEmbeddingConfigUpdate(null, null);

public static TextEmbeddingConfigUpdate fromMap(Map<String, Object> map) {
Map<String, Object> options = new HashMap<>(map);
String resultsField = (String) options.remove(RESULTS_FIELD.getPreferredName());