Skip to content

Commit

Permalink
[ML] Infer against model deployment
Browse files Browse the repository at this point in the history
This adds a temporary API for doing inference against
a trained model deployment.
  • Loading branch information
dimitris-athanasiou committed Apr 2, 2021
1 parent 99ed8b0 commit 4badd50
Show file tree
Hide file tree
Showing 15 changed files with 632 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ public static PersistentTasksCustomMetadata.PersistentTask<?> getSnapshotUpgrade
}

@Nullable
public static PersistentTasksCustomMetadata.PersistentTask<?> getDeployTrainedModelTask(String modelId,
@Nullable PersistentTasksCustomMetadata tasks) {
public static PersistentTasksCustomMetadata.PersistentTask<?> getTrainedModelDeploymentTask(
String modelId, @Nullable PersistentTasksCustomMetadata tasks) {
return tasks == null ? null : tasks.getTask(trainedModelDeploymentTaskId(modelId));
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.tasks.BaseTasksRequest;
import org.elasticsearch.action.support.tasks.BaseTasksResponse;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xpack.core.ml.inference.deployment.PyTorchResult;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

/**
 * Action for running inference against a started trained model deployment.
 * The request is routed (as a tasks request) to the node holding the
 * deployment task for the given model.
 */
public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedModelDeploymentAction.Response> {

    public static final InferTrainedModelDeploymentAction INSTANCE = new InferTrainedModelDeploymentAction();

    // TODO Review security level
    public static final String NAME = "cluster:monitor/xpack/ml/trained_models/deployment/infer";

    public InferTrainedModelDeploymentAction() {
        super(NAME, InferTrainedModelDeploymentAction.Response::new);
    }

    /**
     * A tasks request targeting the deployment task of a trained model and
     * carrying a flat array of doubles to run inference on.
     */
    public static class Request extends BaseTasksRequest<Request> implements ToXContentObject {

        private static final ParseField DEPLOYMENT_ID = new ParseField("deployment_id");
        public static final ParseField INPUTS = new ParseField("inputs");

        private static final ObjectParser<Request, Void> PARSER = new ObjectParser<>("infer_trained_model_request", Request::new);

        static {
            PARSER.declareString((request, deploymentId) -> request.deploymentId = deploymentId, DEPLOYMENT_ID);
            PARSER.declareDoubleArray(Request::setInputs, INPUTS);
        }

        /**
         * Parses a request body. A non-null {@code deploymentId} (typically taken
         * from the URL path) overrides any deployment_id present in the body.
         */
        public static Request parseRequest(String deploymentId, XContentParser parser) {
            Request request = PARSER.apply(parser, null);
            if (deploymentId != null) {
                request.deploymentId = deploymentId;
            }
            return request;
        }

        private String deploymentId;
        // May be null until setInputs is called (e.g. a request built only from the URL path)
        private double[] inputs;

        private Request() {
        }

        public Request(String deploymentId) {
            this.deploymentId = Objects.requireNonNull(deploymentId);
        }

        public Request(StreamInput in) throws IOException {
            super(in);
            deploymentId = in.readString();
            inputs = in.readDoubleArray();
        }

        public String getDeploymentId() {
            return deploymentId;
        }

        public void setInputs(List<Double> inputs) {
            ExceptionsHelper.requireNonNull(inputs, INPUTS);
            this.inputs = inputs.stream().mapToDouble(d -> d).toArray();
        }

        public double[] getInputs() {
            return inputs;
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            super.writeTo(out);
            out.writeString(deploymentId);
            out.writeDoubleArray(inputs);
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
            builder.startObject();
            builder.field(DEPLOYMENT_ID.getPreferredName(), deploymentId);
            // inputs is null until setInputs is called; builder.array would otherwise NPE
            if (inputs != null) {
                builder.array(INPUTS.getPreferredName(), inputs);
            }
            builder.endObject();
            return builder;
        }

        @Override
        public boolean match(Task task) {
            // Routes this tasks request to the deployment task of the requested model
            return StartTrainedModelDeploymentAction.TaskMatcher.match(task, deploymentId);
        }
    }

    public static class Response extends BaseTasksResponse implements Writeable, ToXContentObject {

        private final PyTorchResult result;

        public Response(PyTorchResult result) {
            super(Collections.emptyList(), Collections.emptyList());
            this.result = Objects.requireNonNull(result);
        }

        public Response(StreamInput in) throws IOException {
            super(in);
            result = new PyTorchResult(in);
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            // Bug fix: result was read in the StreamInput constructor but never
            // written back, which corrupts the stream when the response is
            // serialized across the transport layer.
            super.writeTo(out);
            result.writeTo(out);
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
            // PyTorchResult emits its own start/end object
            result.toXContent(builder, params);
            return builder;
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.inference.deployment;

import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.MlParserUtils;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

/*
 * TODO This does not necessarily belong in core. Will have to reconsider
 * once we figure out the format we store inference results in client calls.
 */
/**
 * The result of one inference call against a PyTorch model deployment:
 * a request id plus either an inference matrix or an error message.
 */
public class PyTorchResult implements ToXContentObject, Writeable {

    private static final ParseField REQUEST_ID = new ParseField("request_id");
    private static final ParseField INFERENCE = new ParseField("inference");
    private static final ParseField ERROR = new ParseField("error");

    public static final ConstructingObjectParser<PyTorchResult, Void> PARSER = new ConstructingObjectParser<>("pytorch_result",
        a -> new PyTorchResult((String) a[0], (double[][]) a[1], (String) a[2]));

    static {
        PARSER.declareString(ConstructingObjectParser.constructorArg(), REQUEST_ID);
        PARSER.declareField(ConstructingObjectParser.optionalConstructorArg(),
            (p, c) -> {
                // Parse the 2-D inference matrix and unbox it into primitive arrays
                List<List<Double>> listOfListOfDoubles = MlParserUtils.parseArrayOfArrays(
                    INFERENCE.getPreferredName(), XContentParser::doubleValue, p);
                double[][] primitiveDoubles = new double[listOfListOfDoubles.size()][];
                for (int i = 0; i < listOfListOfDoubles.size(); i++) {
                    List<Double> row = listOfListOfDoubles.get(i);
                    double[] primitiveRow = new double[row.size()];
                    for (int j = 0; j < row.size(); j++) {
                        primitiveRow[j] = row.get(j);
                    }
                    primitiveDoubles[i] = primitiveRow;
                }
                return primitiveDoubles;
            },
            INFERENCE,
            ObjectParser.ValueType.VALUE_ARRAY
        );
        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), ERROR);
    }

    private final String requestId;
    private final double[][] inference;
    private final String error;

    public PyTorchResult(String requestId, @Nullable double[][] inference, @Nullable String error) {
        this.requestId = Objects.requireNonNull(requestId);
        this.inference = inference;
        this.error = error;
    }

    public PyTorchResult(StreamInput in) throws IOException {
        requestId = in.readString();
        boolean hasInference = in.readBoolean();
        if (hasInference) {
            inference = in.readArray(StreamInput::readDoubleArray, length -> new double[length][]);
        } else {
            inference = null;
        }
        error = in.readOptionalString();
    }

    public String getRequestId() {
        return requestId;
    }

    @Nullable
    public double[][] getInference() {
        return inference;
    }

    @Nullable
    public String getError() {
        return error;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(REQUEST_ID.getPreferredName(), requestId);
        if (inference != null) {
            builder.field(INFERENCE.getPreferredName(), inference);
        }
        if (error != null) {
            builder.field(ERROR.getPreferredName(), error);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(requestId);
        // A boolean flag precedes the optional array; mirrored by the StreamInput constructor
        if (inference == null) {
            out.writeBoolean(false);
        } else {
            out.writeBoolean(true);
            out.writeArray(StreamOutput::writeDoubleArray, inference);
        }
        out.writeOptionalString(error);
    }

    @Override
    public int hashCode() {
        // Bug fix: use deepHashCode — Arrays.hashCode on a double[][] hashes the
        // inner arrays by identity, violating the equals/hashCode contract.
        return Objects.hash(requestId, Arrays.deepHashCode(inference), error);
    }

    @Override
    public boolean equals(Object other) {
        if (this == other) return true;
        if (other == null || getClass() != other.getClass()) return false;

        PyTorchResult that = (PyTorchResult) other;
        // Bug fix: Objects.equals on arrays compares references only; use
        // deepEquals so results with equal matrix contents compare equal.
        return Objects.equals(requestId, that.requestId)
            && Arrays.deepEquals(inference, that.inference)
            && Objects.equals(error, that.error);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
package org.elasticsearch.xpack.core.ml.inference.preprocessing;

import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.CheckedFunction;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
Expand All @@ -22,6 +21,7 @@
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.NGramFeatureExtractor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.RelevantScriptFeatureExtractor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptFeatureExtractor;
import org.elasticsearch.xpack.core.ml.utils.MlParserUtils;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -63,7 +63,7 @@ private static ConstructingObjectParser<CustomWordEmbedding, PreProcessorParseCo

parser.declareField(ConstructingObjectParser.constructorArg(),
(p, c) -> {
List<List<Short>> listOfListOfShorts = parseArrays(EMBEDDING_QUANT_SCALES.getPreferredName(),
List<List<Short>> listOfListOfShorts = MlParserUtils.parseArrayOfArrays(EMBEDDING_QUANT_SCALES.getPreferredName(),
XContentParser::shortValue,
p);
short[][] primitiveShorts = new short[listOfListOfShorts.size()][];
Expand Down Expand Up @@ -99,30 +99,6 @@ private static ConstructingObjectParser<CustomWordEmbedding, PreProcessorParseCo
return parser;
}

private static <T> List<List<T>> parseArrays(String fieldName,
CheckedFunction<XContentParser, T, IOException> fromParser,
XContentParser p) throws IOException {
if (p.currentToken() != XContentParser.Token.START_ARRAY) {
throw new IllegalArgumentException("unexpected token [" + p.currentToken() + "] for [" + fieldName + "]");
}
List<List<T>> values = new ArrayList<>();
while(p.nextToken() != XContentParser.Token.END_ARRAY) {
if (p.currentToken() != XContentParser.Token.START_ARRAY) {
throw new IllegalArgumentException("unexpected token [" + p.currentToken() + "] for [" + fieldName + "]");
}
List<T> innerList = new ArrayList<>();
while(p.nextToken() != XContentParser.Token.END_ARRAY) {
if(p.currentToken().isValue() == false) {
throw new IllegalStateException("expected non-null value but got [" + p.currentToken() + "] " +
"for [" + fieldName + "]");
}
innerList.add(fromParser.apply(p));
}
values.add(innerList);
}
return values;
}

public static CustomWordEmbedding fromXContentStrict(XContentParser parser) {
return STRICT_PARSER.apply(parser, PreProcessorParseContext.DEFAULT);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.utils;

import org.elasticsearch.common.CheckedFunction;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * Static helpers shared by ML xcontent parsers.
 */
public final class MlParserUtils {

    private MlParserUtils() {}

    /**
     * Reads a 2-dimensional array of scalar values as a list of lists.
     * The parser must be positioned on the outer array's START_ARRAY token.
     *
     * @param fieldName the field name, used only for error messages
     * @param valueParser the parser to use for the inner array values
     * @param parser the outer parser
     * @param <T> the type of the values of the inner array
     * @return a list of lists representing the array of arrays
     * @throws IOException an exception if parsing fails
     */
    public static <T> List<List<T>> parseArrayOfArrays(String fieldName, CheckedFunction<XContentParser, T, IOException> valueParser,
                                                       XContentParser parser) throws IOException {
        requireStartArray(fieldName, parser);
        List<List<T>> rows = new ArrayList<>();
        while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
            requireStartArray(fieldName, parser);
            rows.add(parseRow(fieldName, valueParser, parser));
        }
        return rows;
    }

    /** Reads one inner array of scalar values; the parser must be on its START_ARRAY token. */
    private static <T> List<T> parseRow(String fieldName, CheckedFunction<XContentParser, T, IOException> valueParser,
                                        XContentParser parser) throws IOException {
        List<T> row = new ArrayList<>();
        XContentParser.Token token;
        while ((token = parser.nextToken()) != XContentParser.Token.END_ARRAY) {
            if (token.isValue() == false) {
                throw new IllegalStateException("expected non-null value but got [" + parser.currentToken() + "] " +
                    "for [" + fieldName + "]");
            }
            row.add(valueParser.apply(parser));
        }
        return row;
    }

    /** Fails unless the parser's current token opens an array. */
    private static void requireStartArray(String fieldName, XContentParser parser) {
        if (parser.currentToken() != XContentParser.Token.START_ARRAY) {
            throw new IllegalArgumentException("unexpected token [" + parser.currentToken() + "] for [" + fieldName + "]");
        }
    }
}
Loading

0 comments on commit 4badd50

Please sign in to comment.