
[ML] Utilise parallel allocations where the inference request contains multiple documents #92819

Merged · 6 commits · Jan 11, 2023
6 changes: 6 additions & 0 deletions docs/changelog/92359.yaml
@@ -0,0 +1,6 @@
+pr: 92359
+summary: Utilise parallel allocations where the inference request contains multiple
+  documents
+area: Machine Learning
+type: bug
+issues: []
@@ -46,8 +46,8 @@ Controls the amount of time to wait for {infer} results. Defaults to 10 seconds.
`docs`::
(Required, array)
An array of objects to pass to the model for inference. The objects should
-contain a field matching your configured trained model input. Typically, the
-field name is `text_field`. Currently, only a single value is allowed.
+contain a field matching your configured trained model input. Typically, the
+field name is `text_field`.

////
[[infer-trained-model-deployment-results]]
@@ -62,7 +62,7 @@ field name is `text_field`. Currently, only a single value is allowed.
[[infer-trained-model-deployment-example]]
== {api-examples-title}

-The response depends on the task the model is trained for. If it is a text
+The response depends on the task the model is trained for. If it is a text
classification task, the response is the score. For example:

[source,console]
@@ -123,7 +123,7 @@ The API returns in this case:
----
// NOTCONSOLE

-Zero-shot classification tasks require extra configuration defining the class
+Zero-shot classification tasks require extra configuration defining the class
labels. These labels are passed in the zero-shot inference config.

[source,console]
@@ -150,7 +150,7 @@ POST _ml/trained_models/model2/deployment/_infer
--------------------------------------------------
// TEST[skip:TBD]

-The API returns the predicted label and the confidence, as well as the top
+The API returns the predicted label and the confidence, as well as the top
classes:

[source,console-result]
@@ -205,7 +205,7 @@ POST _ml/trained_models/model2/deployment/_infer
--------------------------------------------------
// TEST[skip:TBD]

-When the input has been truncated due to the limit imposed by the model's
+When the input has been truncated due to the limit imposed by the model's
`max_sequence_length` the `is_truncated` field appears in the response.

[source,console-result]
@@ -6,10 +6,10 @@
<titleabbrev>Infer trained model</titleabbrev>
++++

-Evaluates a trained model. The model may be any supervised model either trained
+Evaluates a trained model. The model may be any supervised model either trained
by {dfanalytics} or imported.

-NOTE: For model deployments with caching enabled, results may be returned
+NOTE: For model deployments with caching enabled, results may be returned
directly from the {infer} cache.

beta::[]
@@ -51,9 +51,7 @@ Controls the amount of time to wait for {infer} results. Defaults to 10 seconds.
(Required, array)
An array of objects to pass to the model for inference. The objects should
contain the fields matching your configured trained model input. Typically for
-NLP models, the field name is `text_field`. Currently for NLP models, only a
-single value is allowed. For {dfanalytics} or imported classification or
-regression models, more than one value is allowed.
+NLP models, the field name is `text_field`.

//Begin inference_config
`inference_config`::
@@ -106,7 +104,7 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-fill-mask]
=====
`num_top_classes`::::
(Optional, integer)
-Number of top predicted tokens to return for replacing the mask token. Defaults
+Number of top predicted tokens to return for replacing the mask token. Defaults
to `0`.

`results_field`::::
@@ -277,7 +275,7 @@ The maximum amount of words in the answer. Defaults to `15`.

`num_top_classes`::::
(Optional, integer)
-The number the top found answers to return. Defaults to `0`, meaning only the
+The number the top found answers to return. Defaults to `0`, meaning only the
best found answer is returned.

`question`::::
@@ -374,7 +372,7 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-text-classific

`num_top_classes`::::
(Optional, integer)
-Specifies the number of top class predictions to return. Defaults to all classes
+Specifies the number of top class predictions to return. Defaults to all classes
(-1).

`results_field`::::
@@ -886,7 +884,7 @@ POST _ml/trained_models/model2/_infer
--------------------------------------------------
// TEST[skip:TBD]

-When the input has been truncated due to the limit imposed by the model's
+When the input has been truncated due to the limit imposed by the model's
`max_sequence_length` the `is_truncated` field appears in the response.

[source,console-result]
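Taken together, the documentation changes above remove the single-document restriction from the `docs` array. As an illustration only (model id and field name follow the doc examples above; the cluster address, input strings, and client setup are assumed), a call passing several documents in one request via the low-level Java REST client might look like this sketch:

import org.apache.http.HttpHost;
import org.apache.http.util.EntityUtils;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.RestClient;

public class MultiDocInferExample {
    public static void main(String[] args) throws Exception {
        // Cluster address is an assumption for this sketch.
        try (RestClient client = RestClient.builder(new HttpHost("localhost", 9200, "http")).build()) {
            Request request = new Request("POST", "/_ml/trained_models/model2/_infer");
            // Two documents in a single call; before this change only one was allowed.
            request.setJsonEntity("""
                {
                  "docs": [
                    { "text_field": "The movie was awesome!!" },
                    { "text_field": "The cat is lazy." }
                  ]
                }
                """);
            Response response = client.performRequest(request);
            // One inference result is expected per input document.
            System.out.println(EntityUtils.toString(response.getEntity()));
        }
    }
}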
@@ -23,11 +23,12 @@
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
+import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@@ -80,43 +81,29 @@ public static Builder parseRequest(String modelId, XContentParser parser) {
private final List<Map<String, Object>> objectsToInfer;
private final InferenceConfigUpdate update;
private final boolean previouslyLicensed;
-private final TimeValue timeout;
-
-public Request(String modelId, boolean previouslyLicensed) {
-this(modelId, Collections.emptyList(), RegressionConfigUpdate.EMPTY_PARAMS, TimeValue.MAX_VALUE, previouslyLicensed);
-}
+private TimeValue timeout;

public Request(
String modelId,
List<Map<String, Object>> objectsToInfer,
-InferenceConfigUpdate inferenceConfig,
-TimeValue timeout,
+InferenceConfigUpdate inferenceConfigUpdate,
boolean previouslyLicensed
) {
-this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
-this.objectsToInfer = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(objectsToInfer, DOCS.getPreferredName()));
-this.update = ExceptionsHelper.requireNonNull(inferenceConfig, "inference_config");
-this.previouslyLicensed = previouslyLicensed;
-this.timeout = timeout;
+this(modelId, objectsToInfer, inferenceConfigUpdate, DEFAULT_TIMEOUT, previouslyLicensed);
}

public Request(
String modelId,
List<Map<String, Object>> objectsToInfer,
-InferenceConfigUpdate inferenceConfig,
+InferenceConfigUpdate inferenceConfigUpdate,
TimeValue timeout,
boolean previouslyLicensed
) {
-this(modelId, objectsToInfer, inferenceConfig, TimeValue.MAX_VALUE, previouslyLicensed);
-}
-
-public Request(String modelId, Map<String, Object> objectToInfer, InferenceConfigUpdate update, boolean previouslyLicensed) {
-this(
-modelId,
-Collections.singletonList(ExceptionsHelper.requireNonNull(objectToInfer, DOCS.getPreferredName())),
-update,
-TimeValue.MAX_VALUE,
-previouslyLicensed
-);
+this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
+this.objectsToInfer = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(objectsToInfer, DOCS.getPreferredName()));
+this.update = ExceptionsHelper.requireNonNull(inferenceConfigUpdate, "inference_config");
+this.previouslyLicensed = previouslyLicensed;
+this.timeout = timeout;
}

public Request(StreamInput in) throws IOException {
@@ -132,6 +119,10 @@ public Request(StreamInput in) throws IOException {
}
}

+public int numberOfDocuments() {
+return objectsToInfer.size();
+}
+
public String getModelId() {
return modelId;
}
@@ -152,6 +143,10 @@ public TimeValue getTimeout() {
return timeout;
}

+public void setTimeout(TimeValue timeout) {
+this.timeout = timeout;
+}
+
@Override
public ActionRequestValidationException validate() {
return null;
@@ -196,7 +191,7 @@ public static class Builder {
private String modelId;
private List<Map<String, Object>> docs;
private TimeValue timeout;
-private InferenceConfigUpdate update;
+private InferenceConfigUpdate update = new EmptyConfigUpdate();

private Builder() {}

@@ -302,12 +297,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}

public static class Builder {
-private List<InferenceResults> inferenceResults;
+private List<InferenceResults> inferenceResults = new ArrayList<>();
private String modelId;
private boolean isLicensed;

-public Builder setInferenceResults(List<InferenceResults> inferenceResults) {
-this.inferenceResults = inferenceResults;
+public Builder addInferenceResults(List<InferenceResults> inferenceResults) {
+this.inferenceResults.addAll(inferenceResults);
return this;
}

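Two small changes above do the plumbing for multi-document requests: the `Builder`'s results list is now initialized up front, and `setInferenceResults` becomes an accumulating `addInferenceResults`, so results from several per-node batches can be collected into one response. A generic sketch of that accumulate-and-build pattern (illustrative Java, not the Elasticsearch classes):

import java.util.ArrayList;
import java.util.List;

// Illustrative accumulator-style builder: each partial batch is appended in
// arrival order, and build() exposes the combined, immutable list.
class AccumulatingBuilder<T> {
    private final List<T> results = new ArrayList<>(); // initialized up front, like the PR's builder

    AccumulatingBuilder<T> addResults(List<T> batch) {
        results.addAll(batch); // append rather than replace, so earlier batches survive
        return this;
    }

    List<T> build() {
        return List.copyOf(results);
    }

    public static void main(String[] args) {
        List<String> merged = new AccumulatingBuilder<String>()
            .addResults(List.of("result-from-node-a"))
            .addResults(List.of("result-from-node-b-1", "result-from-node-b-2"))
            .build();
        System.out.println(merged); // [result-from-node-a, result-from-node-b-1, result-from-node-b-2]
    }
}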
@@ -44,7 +44,14 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo

public static final InferTrainedModelDeploymentAction INSTANCE = new InferTrainedModelDeploymentAction();

-// TODO Review security level
+/**
+ * Do not call this action directly; use InferModelAction instead,
+ * which will perform various checks and set the node the request
+ * should execute on.
+ *
+ * The action is poorly named: it was once publicly accessible and
+ * exposed through a REST API, but now it _must_ only be called internally.
+ */
public static final String NAME = "cluster:monitor/xpack/ml/trained_models/deployment/infer";

public InferTrainedModelDeploymentAction() {
@@ -157,10 +164,6 @@ public ActionRequestValidationException validate() {
if (docs.isEmpty()) {
validationException = addValidationError("at least one document is required", validationException);
}
-if (docs.size() > 1) {
-// TODO support multiple docs
-validationException = addValidationError("multiple documents are not supported", validationException);
-}
}
return validationException;
}
@@ -244,34 +247,44 @@ public Request build() {

public static class Response extends BaseTasksResponse implements Writeable, ToXContentObject {

-private final InferenceResults results;
+private final List<InferenceResults> results;

-public Response(InferenceResults result) {
+public Response(List<InferenceResults> results) {
super(Collections.emptyList(), Collections.emptyList());
-this.results = Objects.requireNonNull(result);
+this.results = Objects.requireNonNull(results);
}

public Response(StreamInput in) throws IOException {
super(in);
-results = in.readNamedWriteable(InferenceResults.class);
-}
-
-@Override
-public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-builder.startObject();
-results.toXContent(builder, params);
-builder.endObject();
-return builder;
+
+// Multiple results added in 8.6.1
+if (in.getVersion().onOrAfter(Version.V_8_6_1)) {
+results = in.readNamedWriteableList(InferenceResults.class);
+} else {
+results = List.of(in.readNamedWriteable(InferenceResults.class));
+}
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
-out.writeNamedWriteable(results);
+if (out.getVersion().onOrAfter(Version.V_8_6_1)) {
+out.writeNamedWriteableList(results);
+} else {
+out.writeNamedWriteable(results.get(0));
+}
}

-public InferenceResults getResults() {
+public List<InferenceResults> getResults() {
return results;
}

+@Override
+public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+builder.startObject();
+results.get(0).toXContent(builder, params);
+builder.endObject();
+return builder;
+}
}
}
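The `Response` serialization above is version-gated so that mixed clusters keep working during a rolling upgrade: streams to and from nodes on 8.6.1 or later carry the whole result list, while older nodes still see the original single-result format. A standalone sketch of the same pattern over plain data streams (assumed types and version constant; not the Elasticsearch `StreamInput`/`StreamOutput` API):

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

// Version-gated wire format: the writer adapts to what the remote reader understands.
class VersionedResults {
    static final int MULTIPLE_RESULTS_VERSION = 8_060_100; // stand-in for 8.6.1

    static void write(DataOutputStream out, int remoteVersion, List<String> results) throws IOException {
        if (remoteVersion >= MULTIPLE_RESULTS_VERSION) {
            out.writeInt(results.size()); // new format: length-prefixed list
            for (String result : results) {
                out.writeUTF(result);
            }
        } else {
            out.writeUTF(results.get(0)); // old format: exactly one result
        }
    }

    static List<String> read(DataInputStream in, int remoteVersion) throws IOException {
        if (remoteVersion >= MULTIPLE_RESULTS_VERSION) {
            int count = in.readInt();
            List<String> results = new ArrayList<>(count);
            for (int i = 0; i < count; i++) {
                results.add(in.readUTF());
            }
            return results;
        }
        return List.of(in.readUTF()); // wrap the legacy single result in a list
    }
}

Note how the old-format branch sends only the first result, mirroring `out.writeNamedWriteable(results.get(0))` in the diff above: a multi-result response is truncated for pre-8.6.1 nodes, which is the price of wire compatibility.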
@@ -14,6 +14,7 @@
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Tuple;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
@@ -177,7 +178,7 @@ public String[] getStartedNodes() {
.toArray(String[]::new);
}

-public Optional<String> selectRandomStartedNodeWeighedOnAllocations() {
+public List<Tuple<String, Integer>> selectRandomStartedNodesWeighedOnAllocationsForNRequests(int numberOfRequests) {
List<String> nodeIds = new ArrayList<>(nodeRoutingTable.size());
List<Integer> cumulativeAllocations = new ArrayList<>(nodeRoutingTable.size());
int allocationSum = 0;
@@ -189,18 +190,42 @@ public Optional<String> selectRandomStartedNodeWeighedOnAllocations() {
}
}

+if (nodeIds.isEmpty()) {
+return List.of();
+}
+
if (allocationSum == 0) {
// If we are in a mixed cluster where there are assignments prior to introducing allocation distribution
// we could have a zero-sum of allocations. We fall back to returning a random started node.
-return nodeIds.isEmpty() ? Optional.empty() : Optional.of(nodeIds.get(Randomness.get().nextInt(nodeIds.size())));
+int[] counts = new int[nodeIds.size()];
+for (int i = 0; i < numberOfRequests; i++) {
+counts[Randomness.get().nextInt(nodeIds.size())]++;
+}
+
+var nodeCounts = new ArrayList<Tuple<String, Integer>>();
+for (int i = 0; i < counts.length; i++) {
+nodeCounts.add(new Tuple<>(nodeIds.get(i), counts[i]));
+}
+return nodeCounts;
}

+int[] counts = new int[nodeIds.size()];
+var randomIter = Randomness.get().ints(numberOfRequests, 1, allocationSum + 1).iterator();
+for (int i = 0; i < numberOfRequests; i++) {
+int randomInt = randomIter.nextInt();
+int nodeIndex = Collections.binarySearch(cumulativeAllocations, randomInt);
+if (nodeIndex < 0) {
+nodeIndex = -nodeIndex - 1;
+}
+
+counts[nodeIndex]++;
+}
+
-int randomInt = Randomness.get().ints(1, 1, allocationSum + 1).iterator().nextInt();
-int nodeIndex = Collections.binarySearch(cumulativeAllocations, randomInt);
-if (nodeIndex < 0) {
-nodeIndex = -nodeIndex - 1;
+var nodeCounts = new ArrayList<Tuple<String, Integer>>();
+for (int i = 0; i < counts.length; i++) {
+nodeCounts.add(new Tuple<>(nodeIds.get(i), counts[i]));
+}
-return Optional.of(nodeIds.get(nodeIndex));
+return nodeCounts;
}

public Optional<String> getReason() {
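The new method batches all N requests in one pass instead of drawing one node per call: it draws N integers in [1, allocationSum], maps each draw to a node by binary-searching the cumulative allocation totals, and returns per-node request counts. A self-contained sketch of that sampling scheme (plain Java with assumed names, not the Elasticsearch class), including the uniform fallback the diff uses when the allocation sum is zero:

import java.util.Collections;
import java.util.List;
import java.util.Random;

// Weighted request assignment: node i receives each request with probability
// proportional to its allocation count.
class WeightedNodePicker {
    // cumulativeAllocations holds running totals, e.g. allocations [1, 3, 2]
    // for three nodes give cumulative totals [1, 4, 6].
    static int[] pick(List<Integer> cumulativeAllocations, int numberOfRequests, Random random) {
        int allocationSum = cumulativeAllocations.get(cumulativeAllocations.size() - 1);
        int[] counts = new int[cumulativeAllocations.size()];
        for (int i = 0; i < numberOfRequests; i++) {
            if (allocationSum == 0) {
                // Mixed-cluster fallback from the diff above: no allocation
                // weights yet, so spread requests uniformly at random.
                counts[random.nextInt(counts.length)]++;
                continue;
            }
            // Draw in [1, allocationSum]; the owning node is the first entry
            // whose cumulative total is >= the draw.
            int draw = random.nextInt(allocationSum) + 1;
            int nodeIndex = Collections.binarySearch(cumulativeAllocations, draw);
            if (nodeIndex < 0) {
                nodeIndex = -nodeIndex - 1; // not an exact match: use the insertion point
            }
            counts[nodeIndex]++;
        }
        return counts;
    }

    public static void main(String[] args) {
        // 600 requests over weights 1/3/2 should land near [100, 300, 200].
        System.out.println(java.util.Arrays.toString(pick(List.of(1, 4, 6), 600, new Random(0))));
    }
}

Returning counts per node rather than a node per document also means one transport request per selected node, with each node running its share of documents across its own allocations.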
@@ -28,6 +28,8 @@ public class TextEmbeddingConfigUpdate extends NlpConfigUpdate implements NamedX

public static final String NAME = TextEmbeddingConfig.NAME;

+public static TextEmbeddingConfigUpdate EMPTY_INSTANCE = new TextEmbeddingConfigUpdate(null, null);
+
public static TextEmbeddingConfigUpdate fromMap(Map<String, Object> map) {
Map<String, Object> options = new HashMap<>(map);
String resultsField = (String) options.remove(RESULTS_FIELD.getPreferredName());