Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Natural Language Processing tasks and models #73523

Merged
merged 21 commits into from
Jun 2, 2021
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Move results into core and add tests
  • Loading branch information
davidkyle committed Jun 1, 2021
commit 5edff630088159350fda6688a5d9bd67b0814098
Original file line number Diff line number Diff line change
@@ -19,6 +19,7 @@
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
@@ -219,6 +220,9 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,
NerResults.NAME,
NerResults::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,
FillMaskResults.NAME,
FillMaskResults::new));

// Inference Configs
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
Original file line number Diff line number Diff line change
@@ -5,21 +5,88 @@
* 2.0.
*/

package org.elasticsearch.xpack.ml.inference.pipelines.nlp;
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;

import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

public class FillMaskResult implements InferenceResults {
public class FillMaskResults implements InferenceResults {

public static final String NAME = "fill_mask_result";
public static final String DEFAULT_RESULTS_FIELD = "results";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this also be predictions?


private final List<Result> results;

/**
 * Creates a fill-mask result wrapping the given list of results.
 *
 * @param results the individual results; must not be {@code null}
 * @throws NullPointerException if {@code results} is {@code null}
 */
public FillMaskResults(List<Result> results) {
    // Fail fast: toXContent, writeTo, asMap and predictedValue all dereference
    // the list unconditionally, so a null would only surface later and further
    // away from the caller that supplied it.
    this.results = Objects.requireNonNull(results);
}

/**
 * Deserializing constructor: reads the list of {@link Result} entries from the stream.
 *
 * @param in the stream to read from
 * @throws IOException if reading from the stream fails
 */
public FillMaskResults(StreamInput in) throws IOException {
    List<Result> deserialized = in.readList(Result::new);
    this.results = deserialized;
}

/**
 * Returns the list of results held by this object. The internal list is returned
 * directly, not a copy.
 */
public List<Result> getResults() {
return results;
}

/**
 * Renders the results as an array, delegating to each {@link Result} for its own
 * representation. The array is started without a field name.
 *
 * @return the same builder that was passed in
 */
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
    builder.startArray();
    for (Result entry : results) {
        entry.toXContent(builder, params);
    }
    builder.endArray();

    return builder;
}

/** Returns {@link #NAME}, the writeable name of this result type. */
@Override
public String getWriteableName() {
return NAME;
}

/**
 * Serializes this object by writing the list of results to the stream.
 * The stream constructor reads the same list back.
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeList(results);
}

/**
 * Converts this object to a map with a single entry: {@link #DEFAULT_RESULTS_FIELD}
 * mapped to the list of per-result maps produced by {@link Result#toMap()}.
 */
@Override
public Map<String, Object> asMap() {
    List<Map<String, Object>> resultMaps = results.stream()
        .map(Result::toMap)
        .collect(Collectors.toList());

    Map<String, Object> asMap = new LinkedHashMap<>();
    asMap.put(DEFAULT_RESULTS_FIELD, resultMaps);
    return asMap;
}

/**
 * Returns the token of the first result, or {@code null} when there are no results.
 */
@Override
public Object predictedValue() {
    return results.isEmpty() ? null : results.get(0).token;
}

/**
 * Two instances are equal when they are of exactly the same class and hold equal
 * result lists.
 */
@Override
public boolean equals(Object o) {
    if (o == this) {
        return true;
    }
    if (o == null) {
        return false;
    }
    if (o.getClass() != getClass()) {
        return false;
    }
    FillMaskResults other = (FillMaskResults) o;
    return Objects.equals(this.results, other.results);
}

/** Hash code derived from the result list, the same field that equals compares. */
@Override
public int hashCode() {
return Objects.hash(results);
}

public static class Result implements ToXContentObject, Writeable {

@@ -37,6 +104,32 @@ public Result(String token, double score, String sequence) {
this.sequence = Objects.requireNonNull(sequence);
}

/**
 * Deserializing constructor. Reads the token, then the score, then the sequence;
 * this order has to mirror the order in which writeTo emits the fields.
 *
 * @param in the stream to read from
 * @throws IOException if reading from the stream fails
 */
public Result(StreamInput in) throws IOException {
    this.token = in.readString();
    this.score = in.readDouble();
    this.sequence = in.readString();
}

/** Returns the score attached to this result. */
public double getScore() {
return score;
}

/** Returns the sequence attached to this result. */
public String getSequence() {
return sequence;
}

/** Returns the token attached to this result. */
public String getToken() {
return token;
}

/**
 * Converts this result to an insertion-ordered map keyed by the preferred names of
 * the TOKEN, SCORE and SEQUENCE fields, in that order.
 */
public Map<String, Object> toMap() {
    Map<String, Object> fields = new LinkedHashMap<>();

    fields.put(TOKEN.getPreferredName(), this.token);
    fields.put(SCORE.getPreferredName(), this.score);
    fields.put(SEQUENCE.getPreferredName(), this.sequence);

    return fields;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
@@ -53,43 +146,20 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(score);
out.writeString(sequence);
}
}

private static final String NAME = "fill_mask_result";

private final List<Result> results;

public FillMaskResult(List<Result> results) {
this.results = results;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startArray();
for (Result result : results) {
result.toXContent(builder, params);
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result result = (Result) o;
return Double.compare(result.score, score) == 0 &&
Objects.equals(token, result.token) &&
Objects.equals(sequence, result.sequence);
}
builder.endArray();
return builder;
}

@Override
public String getWriteableName() {
return NAME;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeList(results);
}

@Override
public Map<String, Object> asMap() {
return null;
}

@Override
public Object predictedValue() {
return null;
@Override
public int hashCode() {
return Objects.hash(token, score, sequence);
}
}
}
Original file line number Diff line number Diff line change
@@ -7,23 +7,41 @@

package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

public class NerResults implements InferenceResults {

public static final String NAME = "ner_result";

public NerResults(StreamInput in) throws IOException {
private final List<EntityGroup> entityGroups;

/**
 * Creates a result wrapping the given entity groups.
 *
 * @param entityGroups the entity groups; must not be {@code null}
 * @throws NullPointerException if {@code entityGroups} is {@code null}
 */
public NerResults(List<EntityGroup> entityGroups) {
    Objects.requireNonNull(entityGroups);
    this.entityGroups = entityGroups;
}

/**
 * Deserializing constructor: reads the list of {@link EntityGroup} entries from the stream.
 *
 * @param in the stream to read from
 * @throws IOException if reading from the stream fails
 */
public NerResults(StreamInput in) throws IOException {
    List<EntityGroup> deserialized = in.readList(EntityGroup::new);
    this.entityGroups = deserialized;
}

/**
 * Renders the entity groups as an array, delegating to each {@link EntityGroup} for
 * its own representation. The array is started without a field name.
 *
 * @return the same builder that was passed in
 */
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
    builder.startArray();
    for (EntityGroup group : entityGroups) {
        group.toXContent(builder, params);
    }
    builder.endArray();

    return builder;
}

@@ -34,18 +52,111 @@ public String getWriteableName() {

/**
 * Serializes this object by writing the list of entity groups to the stream.
 * The stream constructor reads the same list back.
 */
@Override
public void writeTo(StreamOutput out) throws IOException {

out.writeList(entityGroups);
}

@Override
public Map<String, Object> asMap() {
// TODO required for Ingest Pipelines
return null;
Map<String, Object> map = new LinkedHashMap<>();
map.put(FillMaskResults.DEFAULT_RESULTS_FIELD, entityGroups.stream().map(EntityGroup::toMap).collect(Collectors.toList()));
return map;
}

@Override
public Object predictedValue() {
// TODO required for Ingest Pipelines
return null;
// Used by the inference aggregation
throw new UnsupportedOperationException("Named Entity Recognition does not support a single predicted value");
}

/**
 * Returns the entity groups held by this object. The internal list is returned
 * directly, not a copy.
 */
public List<EntityGroup> getEntityGroups() {
return entityGroups;
}

/**
 * Two instances are equal when they are of exactly the same class and hold equal
 * entity group lists.
 */
@Override
public boolean equals(Object o) {
    if (o == this) {
        return true;
    }
    if (o == null) {
        return false;
    }
    if (o.getClass() != getClass()) {
        return false;
    }
    NerResults other = (NerResults) o;
    return Objects.equals(this.entityGroups, other.entityGroups);
}

/** Hash code derived from the entity group list, the same field that equals compares. */
@Override
public int hashCode() {
return Objects.hash(entityGroups);
}

/**
 * One entry of a named entity recognition result: a label, a score and the word the
 * label applies to. All fields are final; the label and word are never {@code null}
 * when built through the public constructor.
 */
public static class EntityGroup implements ToXContentObject, Writeable {

    private static final ParseField LABEL = new ParseField("label");
    private static final ParseField SCORE = new ParseField("score");
    private static final ParseField WORD = new ParseField("word");

    private final String label;
    private final double score;
    private final String word;

    /**
     * @param label the entity label; must not be {@code null}
     * @param score the score attached to the entity
     * @param word  the word the label applies to; must not be {@code null}
     * @throws NullPointerException if {@code label} or {@code word} is {@code null}
     */
    public EntityGroup(String label, double score, String word) {
        this.label = Objects.requireNonNull(label);
        this.score = score;
        this.word = Objects.requireNonNull(word);
    }

    /**
     * Deserializing constructor. Reads label, score and word in the same order
     * that {@link #writeTo(StreamOutput)} writes them.
     */
    public EntityGroup(StreamInput in) throws IOException {
        this.label = in.readString();
        this.score = in.readDouble();
        this.word = in.readString();
    }

    /** Returns the entity label. */
    public String getLabel() {
        return label;
    }

    /** Returns the score attached to the entity. */
    public double getScore() {
        return score;
    }

    /** Returns the word the label applies to. */
    public String getWord() {
        return word;
    }

    /** Renders this entity as an object with the label, score and word fields. */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        {
            builder.field(LABEL.getPreferredName(), this.label);
            builder.field(SCORE.getPreferredName(), this.score);
            builder.field(WORD.getPreferredName(), this.word);
        }
        builder.endObject();
        return builder;
    }

    /** Writes label, score and word to the stream, in that order. */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(this.label);
        out.writeDouble(this.score);
        out.writeString(this.word);
    }

    /**
     * Converts this entity to an insertion-ordered map with the same keys and
     * ordering that {@link #toXContent} uses.
     */
    public Map<String, Object> toMap() {
        Map<String, Object> fields = new LinkedHashMap<>();

        fields.put(LABEL.getPreferredName(), this.label);
        fields.put(SCORE.getPreferredName(), this.score);
        fields.put(WORD.getPreferredName(), this.word);

        return fields;
    }

    /** Equality over exact class, score (via {@link Double#compare}), label and word. */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (o == null || o.getClass() != getClass()) {
            return false;
        }
        EntityGroup other = (EntityGroup) o;
        if (Double.compare(other.score, score) != 0) {
            return false;
        }
        return Objects.equals(label, other.label) && Objects.equals(word, other.word);
    }

    /** Hash code over the same three fields that equals compares. */
    @Override
    public int hashCode() {
        return Objects.hash(label, score, word);
    }
}
}
Loading