Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Natural Language Processing tasks and models #73523

Merged
merged 21 commits into from
Jun 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.tasks.BaseTasksRequest;
import org.elasticsearch.action.support.tasks.BaseTasksResponse;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xpack.core.ml.inference.deployment.PyTorchResult;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;

import java.io.IOException;
import java.util.Collections;
Expand All @@ -36,53 +39,57 @@ public InferTrainedModelDeploymentAction() {

public static class Request extends BaseTasksRequest<Request> implements ToXContentObject {

public static final String REQUEST_ID = "request_id";
public static final String DEPLOYMENT_ID = "deployment_id";
public static final String JSON_REQUEST = "json_request";
public static final ParseField INPUT = new ParseField("input");

private static final ObjectParser<Request, Void> PARSER = new ObjectParser<>(NAME, Request::new);
static {
PARSER.declareString((request, inputs) -> request.input = inputs, INPUT);
}

public static Request parseRequest(String deploymentId, XContentParser parser) {
Request r = PARSER.apply(parser, null);
r.deploymentId = deploymentId;
return r;
}

private String deploymentId;
private String requestId;
private String jsonDoc;
private String input;

public Request(String deploymentId, String requestId, String jsonDoc) {
private Request() {
}

public Request(String deploymentId, String input) {
this.deploymentId = Objects.requireNonNull(deploymentId);
this.requestId = requestId;
this.jsonDoc = Objects.requireNonNull(jsonDoc);
this.input = Objects.requireNonNull(input);
}

public Request(StreamInput in) throws IOException {
super(in);
deploymentId = in.readString();
requestId = in.readOptionalString();
jsonDoc = in.readString();
input = in.readString();
}

public String getDeploymentId() {
return deploymentId;
}

public String getRequestId() {
return requestId;
}

public String getJsonDoc() {
return jsonDoc;
public String getInput() {
return input;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(deploymentId);
out.writeOptionalString(requestId);
out.writeString(jsonDoc);
out.writeString(input);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.field(DEPLOYMENT_ID, deploymentId);
builder.field(REQUEST_ID, requestId);
builder.field(JSON_REQUEST, jsonDoc);
builder.field(INPUT.getPreferredName(), input);
builder.endObject();
return builder;
}
Expand All @@ -98,40 +105,39 @@ public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
InferTrainedModelDeploymentAction.Request that = (InferTrainedModelDeploymentAction.Request) o;
return Objects.equals(deploymentId, that.deploymentId)
&& Objects.equals(requestId, that.requestId)
&& Objects.equals(jsonDoc, that.jsonDoc);
&& Objects.equals(input, that.input);
}

@Override
public int hashCode() {
return Objects.hash(deploymentId, requestId, jsonDoc);
return Objects.hash(deploymentId, input);
}
}

public static class Response extends BaseTasksResponse implements Writeable, ToXContentObject {

private final PyTorchResult result;
private final InferenceResults results;

public Response(PyTorchResult result) {
public Response(InferenceResults result) {
super(Collections.emptyList(), Collections.emptyList());
this.result = Objects.requireNonNull(result);
this.results = Objects.requireNonNull(result);
}

public Response(StreamInput in) throws IOException {
super(in);
result = new PyTorchResult(in);
results = in.readNamedWriteable(InferenceResults.class);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
result.toXContent(builder, params);
results.toXContent(builder, params);
return builder;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
result.writeTo(out);
out.writeNamedWriteable(results);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
Expand Down Expand Up @@ -215,6 +217,12 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,
WarningInferenceResults.NAME,
WarningInferenceResults::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,
NerResults.NAME,
NerResults::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,
FillMaskResults.NAME,
FillMaskResults::new));

// Inference Configs
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,18 @@ public String getRequestId() {
return requestId;
}

public boolean isError() {
return error != null;
}

public String getError() {
return error;
}

public double[][] getInferenceResult() {
return inference;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

public class FillMaskResults implements InferenceResults {

public static final String NAME = "fill_mask_result";
public static final String DEFAULT_RESULTS_FIELD = "results";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this also be predictions?


private final List<Prediction> predictions;

public FillMaskResults(List<Prediction> predictions) {
this.predictions = predictions;
}

public FillMaskResults(StreamInput in) throws IOException {
this.predictions = in.readList(Prediction::new);
}

public List<Prediction> getPredictions() {
return predictions;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startArray();
for (Prediction prediction : predictions) {
prediction.toXContent(builder, params);
}
builder.endArray();
return builder;
}

@Override
public String getWriteableName() {
return NAME;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeList(predictions);
}

@Override
public Map<String, Object> asMap() {
Map<String, Object> map = new LinkedHashMap<>();
map.put(DEFAULT_RESULTS_FIELD, predictions.stream().map(Prediction::toMap).collect(Collectors.toList()));
return map;
}

@Override
public Object predictedValue() {
if (predictions.isEmpty()) {
return null;
}
return predictions.get(0).token;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
FillMaskResults that = (FillMaskResults) o;
return Objects.equals(predictions, that.predictions);
}

@Override
public int hashCode() {
return Objects.hash(predictions);
}

public static class Prediction implements ToXContentObject, Writeable {

private static final ParseField TOKEN = new ParseField("token");
private static final ParseField SCORE = new ParseField("score");
private static final ParseField SEQUENCE = new ParseField("sequence");

private final String token;
private final double score;
private final String sequence;

public Prediction(String token, double score, String sequence) {
this.token = Objects.requireNonNull(token);
this.score = score;
this.sequence = Objects.requireNonNull(sequence);
}

public Prediction(StreamInput in) throws IOException {
token = in.readString();
score = in.readDouble();
sequence = in.readString();
}

public double getScore() {
return score;
}

public String getSequence() {
return sequence;
}

public String getToken() {
return token;
}

public Map<String, Object> toMap() {
Map<String, Object> map = new LinkedHashMap<>();
map.put(TOKEN.getPreferredName(), token);
map.put(SCORE.getPreferredName(), score);
map.put(SEQUENCE.getPreferredName(), sequence);
return map;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(TOKEN.getPreferredName(), token);
builder.field(SCORE.getPreferredName(), score);
builder.field(SEQUENCE.getPreferredName(), sequence);
builder.endObject();
return builder;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(token);
out.writeDouble(score);
out.writeString(sequence);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Prediction result = (Prediction) o;
return Double.compare(result.score, score) == 0 &&
Objects.equals(token, result.token) &&
Objects.equals(sequence, result.sequence);
}

@Override
public int hashCode() {
return Objects.hash(token, score, sequence);
}
}
}
Loading