
[ML] Preserve order of inference results #100143

Merged: 6 commits, Oct 2, 2023
Changes from 3 commits
5 changes: 5 additions & 0 deletions docs/changelog/100143.yaml
@@ -0,0 +1,5 @@
pr: 100143
summary: When calling the _infer API with multiple inputs on a model deployment with more than one allocation, the output order was not guaranteed to match the input order. The fix ensures the output order matches the input order.
area: Machine Learning
type: bug
issues: []
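The fix, visible in TransportInferTrainedModelDeploymentAction below, claims a result slot per input before dispatching and counts completions, so that only the last completion assembles the response. A minimal sketch of that slot-ordering pattern using only the JDK — the class name, executor setup, and simulated delays here are illustrative, not the PR's code:

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReferenceArray;

public class OrderedResponses {
    public static void main(String[] args) throws InterruptedException {
        int inputs = 4;
        // One slot per input, assigned before dispatch, so completion
        // order cannot change the order of the final result list.
        AtomicReferenceArray<String> results = new AtomicReferenceArray<>(inputs);
        AtomicInteger completed = new AtomicInteger();
        CountDownLatch allDone = new CountDownLatch(1);

        ExecutorService pool = Executors.newFixedThreadPool(inputs);
        for (int slot = 0; slot < inputs; slot++) {
            final int s = slot;
            pool.submit(() -> {
                // Simulate allocations finishing in reverse order.
                try {
                    Thread.sleep((inputs - s) * 20L);
                } catch (InterruptedException ignored) {
                    Thread.currentThread().interrupt();
                }
                results.set(s, "result-" + s);
                // Only the final completion releases the latch.
                if (completed.incrementAndGet() == inputs) {
                    allDone.countDown();
                }
            });
        }
        allDone.await();
        pool.shutdown();

        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < inputs; i++) {
            if (i > 0) sb.append(',');
            sb.append(results.get(i));
        }
        System.out.println(sb);
    }
}
```

Because each task writes only to its own pre-assigned slot, completion order cannot affect the order of the assembled result list.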
@@ -21,6 +21,7 @@
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults;
@@ -639,6 +640,9 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceResults.class, WarningInferenceResults.NAME, WarningInferenceResults::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceResults.class, ErrorInferenceResults.NAME, ErrorInferenceResults::new)
);
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, NerResults.NAME, NerResults::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, FillMaskResults.NAME, FillMaskResults::new));
namedWriteables.add(
@@ -0,0 +1,97 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;

public class ErrorInferenceResults implements InferenceResults {

public static final String NAME = "error";
public static final ParseField ERROR = new ParseField("error");

private final Exception exception;

public ErrorInferenceResults(Exception exception) {
this.exception = Objects.requireNonNull(exception);
}

public ErrorInferenceResults(StreamInput in) throws IOException {
this.exception = in.readException();
}

public Exception getException() {
return exception;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeException(exception);
}

@Override
public boolean equals(Object object) {
if (object == this) {
return true;
}
if (object == null || getClass() != object.getClass()) {
return false;
}
ErrorInferenceResults that = (ErrorInferenceResults) object;
// Just compare the message for serialization test purposes
return Objects.equals(exception.getMessage(), that.exception.getMessage());
}

@Override
public int hashCode() {
// Just compare the message for serialization test purposes
return Objects.hash(exception.getMessage());
}

@Override
public String getResultsField() {
return NAME;
}

@Override
public Map<String, Object> asMap() {
Map<String, Object> asMap = new LinkedHashMap<>();
asMap.put(NAME, exception.getMessage());
return asMap;
}

@Override
public String toString() {
return Strings.toString(this);
}

@Override
public Object predictedValue() {
return null;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(NAME, exception.getMessage());
return builder;
}

@Override
public String getWriteableName() {
return NAME;
}
}
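A stripped-down illustration of how this class surfaces an exception as a result map (ErrorLikeResult is a hypothetical stand-in mirroring the asMap method above, using only the JDK):

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;

public class ErrorResultSketch {
    // Hypothetical stand-in for ErrorInferenceResults: the exception is the result.
    record ErrorLikeResult(Exception exception) {
        ErrorLikeResult {
            Objects.requireNonNull(exception);
        }

        Map<String, Object> asMap() {
            Map<String, Object> map = new LinkedHashMap<>();
            map.put("error", exception.getMessage());
            return map;
        }
    }

    public static void main(String[] args) {
        var result = new ErrorLikeResult(new IllegalStateException("inference failed for input 2"));
        System.out.println(result.asMap());
    }
}
```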
@@ -0,0 +1,40 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.rest.RestStatus;

import java.io.IOException;

import static org.hamcrest.Matchers.equalTo;

public class ErrorInferenceResultsTests extends InferenceResultsTestCase<ErrorInferenceResults> {

@Override
protected Writeable.Reader<ErrorInferenceResults> instanceReader() {
return ErrorInferenceResults::new;
}

@Override
protected ErrorInferenceResults createTestInstance() {
return new ErrorInferenceResults(new ElasticsearchStatusException(randomAlphaOfLength(8), randomFrom(RestStatus.values())));
}

@Override
protected ErrorInferenceResults mutateInstance(ErrorInferenceResults instance) throws IOException {
return null; // TODO implement https://github.com/elastic/elasticsearch/issues/25929
}

@Override
void assertFieldValues(ErrorInferenceResults createdInstance, IngestDocument document, String resultsField) {
assertThat(document.getFieldValue(resultsField + ".error", String.class), equalTo(createdInstance.getException().getMessage()));
}
}
@@ -12,23 +12,24 @@
import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.TaskOperationFailure;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.action.support.tasks.TransportTasksAction;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.deployment.NlpInferenceInput;
import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

public class TransportInferTrainedModelDeploymentAction extends TransportTasksAction<
TrainedModelDeploymentTask,
@@ -96,13 +97,55 @@ protected void taskOperation(
}

// Multiple documents to infer on, wait for all results
ActionListener<Collection<InferenceResults>> collectingListener = ActionListener.wrap(pyTorchResults -> {
listener.onResponse(new InferTrainedModelDeploymentAction.Response(new ArrayList<>(pyTorchResults)));
}, listener::onFailure);

GroupedActionListener<InferenceResults> groupedListener = new GroupedActionListener<>(nlpInputs.size(), collectingListener);
// and order the results to match the request order
AtomicInteger count = new AtomicInteger();
AtomicArray<InferenceResults> results = new AtomicArray<>(nlpInputs.size());
int slot = 0;
for (var input : nlpInputs) {
task.infer(input, request.getUpdate(), request.isHighPriority(), request.getInferenceTimeout(), actionTask, groupedListener);
task.infer(
input,
request.getUpdate(),
request.isHighPriority(),
request.getInferenceTimeout(),
actionTask,
orderedListener(count, results, slot++, nlpInputs.size(), listener)
);
}
}

/**
 * Create a listener that groups the results in the correct order.
 * Exceptions are converted to {@link ErrorInferenceResults};
 * the listener never calls {@code finalListener::onFailure},
 * instead failures are returned as inference results.
 */
private ActionListener<InferenceResults> orderedListener(
[Contributor] nit: Can we make this static?

[Member Author] 👍 and I've added a test
AtomicInteger count,
AtomicArray<InferenceResults> results,
int slot,
int totalNumberOfResponses,
ActionListener<InferTrainedModelDeploymentAction.Response> finalListener
) {
return new ActionListener<>() {
@Override
public void onResponse(InferenceResults response) {
results.setOnce(slot, response);
if (count.incrementAndGet() == totalNumberOfResponses) {
sendResponse();
}
}

@Override
public void onFailure(Exception e) {
results.setOnce(slot, new ErrorInferenceResults(e));
if (count.incrementAndGet() == totalNumberOfResponses) {
sendResponse();
}
}

private void sendResponse() {
finalListener.onResponse(new InferTrainedModelDeploymentAction.Response(results.asList()));
}
};
}
}
@@ -35,6 +35,7 @@
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
@@ -348,15 +349,24 @@ public void onFailure(Exception e) {
}

private void sendResponse() {
if (results.nonNullLength() > 0) {
if (failure.get() != null) {
finalListener.onFailure(failure.get());
} else {
for (int i = 0; i < results.length(); i++) {
if (results.get(i) != null) {
responseBuilder.addInferenceResults(results.get(i));
var resultList = results.get(i);
if (resultList != null) {
for (var result : resultList) {
if (result instanceof ErrorInferenceResults errorResult) {
// Any failure fails all requests
// TODO is this the correct behaviour for batched requests?
finalListener.onFailure(errorResult.getException());
[Contributor] I don't know the code well enough, but maybe in the future we could make the response similar to a bulk response, where an entry in the results array can be either a failure or a successful result?

[Member Author] That's the idea. The REST response does not have to change, but internal users (such as ingest) can make better decisions about how to handle a response which is partially successful.
return;
}
}
responseBuilder.addInferenceResults(resultList);
}
}
finalListener.onResponse(responseBuilder.build());
} else {
finalListener.onFailure(failure.get());
}
}
};
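The bulk-style response the reviewers discuss above — each entry being either a success or a failure — could be sketched roughly as follows; the Entry record and its fields are hypothetical, not part of this PR:

```java
import java.util.List;

public class BulkStyleSketch {
    // Hypothetical entry type: carries either a result or a failure, never both.
    record Entry(String result, Exception failure) {
        static Entry ok(String result) { return new Entry(result, null); }
        static Entry failed(Exception e) { return new Entry(null, e); }
        boolean isFailure() { return failure != null; }
    }

    public static void main(String[] args) {
        List<Entry> response = List.of(
            Entry.ok("label-a"),
            Entry.failed(new RuntimeException("timeout on input 1")),
            Entry.ok("label-c")
        );
        // Internal callers such as ingest could then act on partial success
        // instead of failing the whole batch.
        long failures = response.stream().filter(Entry::isFailure).count();
        System.out.println(failures + " of " + response.size() + " inputs failed");
    }
}
```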