Skip to content

Commit

Permalink
[ML] Add PyTorch model configuration (#71035)
Browse files Browse the repository at this point in the history
Adds the model_type field to TrainedModelConfig for distinguishing between models
that can be loaded via the model loading service and those that require a native process.
  • Loading branch information
davidkyle authored Apr 1, 2021
1 parent 0897e8a commit 99ed8b0
Show file tree
Hide file tree
Showing 36 changed files with 739 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum;
import org.elasticsearch.client.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
import org.elasticsearch.client.ml.inference.trainedmodel.pytorch.PyTorchModel;
import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;
Expand Down Expand Up @@ -59,6 +60,9 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class,
new ParseField(LangIdentNeuralNetwork.NAME),
LangIdentNeuralNetwork::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class,
new ParseField(PyTorchModel.NAME),
PyTorchModel::fromXContent));

// Inference Config
namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfig.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public class TrainedModelConfig implements ToXContentObject {
public static final String NAME = "trained_model_config";

public static final ParseField MODEL_ID = new ParseField("model_id");
public static final ParseField MODEL_TYPE = new ParseField("model_type");
public static final ParseField CREATED_BY = new ParseField("created_by");
public static final ParseField VERSION = new ParseField("version");
public static final ParseField DESCRIPTION = new ParseField("description");
Expand All @@ -53,6 +54,7 @@ public class TrainedModelConfig implements ToXContentObject {
TrainedModelConfig.Builder::new);
static {
PARSER.declareString(TrainedModelConfig.Builder::setModelId, MODEL_ID);
PARSER.declareString(TrainedModelConfig.Builder::setModelType, MODEL_TYPE);
PARSER.declareString(TrainedModelConfig.Builder::setCreatedBy, CREATED_BY);
PARSER.declareString(TrainedModelConfig.Builder::setVersion, VERSION);
PARSER.declareString(TrainedModelConfig.Builder::setDescription, DESCRIPTION);
Expand Down Expand Up @@ -81,6 +83,7 @@ public static TrainedModelConfig fromXContent(XContentParser parser) throws IOEx
}

private final String modelId;
private final TrainedModelType modelType;
private final String createdBy;
private final Version version;
private final String description;
Expand All @@ -97,6 +100,7 @@ public static TrainedModelConfig fromXContent(XContentParser parser) throws IOEx
private final InferenceConfig inferenceConfig;

TrainedModelConfig(String modelId,
TrainedModelType modelType,
String createdBy,
Version version,
String description,
Expand All @@ -112,6 +116,7 @@ public static TrainedModelConfig fromXContent(XContentParser parser) throws IOEx
Map<String, String> defaultFieldMap,
InferenceConfig inferenceConfig) {
this.modelId = modelId;
this.modelType = modelType;
this.createdBy = createdBy;
this.version = version;
this.createTime = createTime == null ? null : Instant.ofEpochMilli(createTime.toEpochMilli());
Expand All @@ -132,6 +137,10 @@ public String getModelId() {
return modelId;
}

public TrainedModelType getModelType() {
return modelType;
}

public String getCreatedBy() {
return createdBy;
}
Expand Down Expand Up @@ -202,6 +211,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (modelId != null) {
builder.field(MODEL_ID.getPreferredName(), modelId);
}
if (modelType != null) {
builder.field(MODEL_TYPE.getPreferredName(), modelType.toString());
}
if (createdBy != null) {
builder.field(CREATED_BY.getPreferredName(), createdBy);
}
Expand Down Expand Up @@ -259,6 +271,7 @@ public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
TrainedModelConfig that = (TrainedModelConfig) o;
return Objects.equals(modelId, that.modelId) &&
Objects.equals(modelType, that.modelType) &&
Objects.equals(createdBy, that.createdBy) &&
Objects.equals(version, that.version) &&
Objects.equals(description, that.description) &&
Expand All @@ -278,6 +291,7 @@ public boolean equals(Object o) {
@Override
public int hashCode() {
return Objects.hash(modelId,
modelType,
createdBy,
version,
createTime,
Expand All @@ -298,6 +312,7 @@ public int hashCode() {
public static class Builder {

private String modelId;
private TrainedModelType modelType;
private String createdBy;
private Version version;
private String description;
Expand All @@ -318,6 +333,16 @@ public Builder setModelId(String modelId) {
return this;
}

public Builder setModelType(String modelType) {
this.modelType = TrainedModelType.fromString(modelType);
return this;
}

public Builder setModelType(TrainedModelType modelType) {
this.modelType = modelType;
return this;
}

private Builder setCreatedBy(String createdBy) {
this.createdBy = createdBy;
return this;
Expand Down Expand Up @@ -404,6 +429,7 @@ public Builder setInferenceConfig(InferenceConfig inferenceConfig) {
public TrainedModelConfig build() {
return new TrainedModelConfig(
modelId,
modelType,
createdBy,
version,
description,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.client.ml.inference;

import java.util.Locale;

public enum TrainedModelType {
TREE_ENSEMBLE, LANG_IDENT, PYTORCH;

public static TrainedModelType fromString(String name) {
return valueOf(name.trim().toUpperCase(Locale.ROOT));
}

@Override
public String toString() {
return name().toLowerCase(Locale.ROOT);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@
*/
package org.elasticsearch.client.ml.inference.trainedmodel;

import org.elasticsearch.common.ParseField;

import java.util.Locale;

public enum TargetType {

REGRESSION, CLASSIFICATION;

public static final ParseField TARGET_TYPE = new ParseField("target_type");

public static TargetType fromString(String name) {
return valueOf(name.trim().toUpperCase(Locale.ROOT));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ public class Ensemble implements TrainedModel {
public static final ParseField FEATURE_NAMES = new ParseField("feature_names");
public static final ParseField TRAINED_MODELS = new ParseField("trained_models");
public static final ParseField AGGREGATE_OUTPUT = new ParseField("aggregate_output");
public static final ParseField TARGET_TYPE = new ParseField("target_type");
public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");
public static final ParseField CLASSIFICATION_WEIGHTS = new ParseField("classification_weights");

Expand All @@ -48,7 +47,7 @@ public class Ensemble implements TrainedModel {
PARSER.declareNamedObject(Ensemble.Builder::setOutputAggregator,
(p, c, n) -> p.namedObject(OutputAggregator.class, n, null),
AGGREGATE_OUTPUT);
PARSER.declareString(Ensemble.Builder::setTargetType, TARGET_TYPE);
PARSER.declareString(Ensemble.Builder::setTargetType, TargetType.TARGET_TYPE);
PARSER.declareStringArray(Ensemble.Builder::setClassificationLabels, CLASSIFICATION_LABELS);
PARSER.declareDoubleArray(Ensemble.Builder::setClassificationWeights, CLASSIFICATION_WEIGHTS);
}
Expand Down Expand Up @@ -105,7 +104,7 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
Collections.singletonList(outputAggregator));
}
if (targetType != null) {
builder.field(TARGET_TYPE.getPreferredName(), targetType);
builder.field(TargetType.TARGET_TYPE.getPreferredName(), targetType);
}
if (classificationLabels != null) {
builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.client.ml.inference.trainedmodel.pytorch;

import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

public class PyTorchModel implements TrainedModel {

public static final String NAME = "pytorch";
public static final ParseField MODEL_ID = new ParseField("model_id");

private static ObjectParser<PyTorchModel.Builder, Void> PARSER = new ObjectParser<>(
NAME,
true,
PyTorchModel.Builder::new);

static {
PARSER.declareString(PyTorchModel.Builder::setModelId, MODEL_ID);
PARSER.declareString(PyTorchModel.Builder::setTargetType, TargetType.TARGET_TYPE);
}

public static PyTorchModel fromXContent(XContentParser parser) {
return PARSER.apply(parser, null).build();
}

private final String modelId;
private final TargetType targetType;

public PyTorchModel(String modelId, TargetType targetType) {
this.modelId = Objects.requireNonNull(modelId);
this.targetType = Objects.requireNonNull(targetType);
}

public String getModelId() {
return modelId;
}

public TargetType getTargetType() {
return targetType;
}

@Override
public List<String> getFeatureNames() {
return Collections.emptyList();
}

@Override
public String getName() {
return NAME;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(MODEL_ID.getPreferredName(), modelId);
builder.field(TargetType.TARGET_TYPE.getPreferredName(), targetType);
builder.endObject();
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
PyTorchModel that = (PyTorchModel) o;
return Objects.equals(modelId, that.modelId)
&& Objects.equals(targetType, that.targetType);
}

@Override
public int hashCode() {
return Objects.hash(modelId, targetType);
}

public static class Builder {

private String modelId;
private TargetType targetType;

public Builder setModelId(String modelId) {
this.modelId = modelId;
return this;
}

public Builder setTargetType(TargetType targetType) {
this.targetType = targetType;
return this;

}

private Builder setTargetType(String targetType) {
this.targetType = TargetType.fromString(targetType);
return this;
}

PyTorchModel build() {
return new PyTorchModel(modelId, targetType);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,14 @@ public class Tree implements TrainedModel {

public static final ParseField FEATURE_NAMES = new ParseField("feature_names");
public static final ParseField TREE_STRUCTURE = new ParseField("tree_structure");
public static final ParseField TARGET_TYPE = new ParseField("target_type");
public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");

private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME, true, Builder::new);

static {
PARSER.declareStringArray(Builder::setFeatureNames, FEATURE_NAMES);
PARSER.declareObjectArray(Builder::setNodes, (p, c) -> TreeNode.fromXContent(p), TREE_STRUCTURE);
PARSER.declareString(Builder::setTargetType, TARGET_TYPE);
PARSER.declareString(Builder::setTargetType, TargetType.TARGET_TYPE);
PARSER.declareStringArray(Builder::setClassificationLabels, CLASSIFICATION_LABELS);
}

Expand Down Expand Up @@ -94,7 +93,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
}
if (targetType != null) {
builder.field(TARGET_TYPE.getPreferredName(), targetType.toString());
builder.field(TargetType.TARGET_TYPE.getPreferredName(), targetType.toString());
}
builder.endObject();
return builder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ public void testDefaultNamedXContents() {

public void testProvidedNamedXContents() {
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
assertEquals(77, namedXContents.size());
assertEquals(78, namedXContents.size());
Map<Class<?>, Integer> categories = new HashMap<>();
List<String> names = new ArrayList<>();
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
Expand Down Expand Up @@ -797,7 +797,7 @@ public void testProvidedNamedXContents() {
NGram.NAME,
Multi.NAME
));
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class));
assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class));
assertThat(names, hasItems(Tree.NAME, Ensemble.NAME, LangIdentNeuralNetwork.NAME));
assertEquals(Integer.valueOf(4),
categories.get(org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator.class));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@
import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.client.ml.inference.TrainedModelInput;
import org.elasticsearch.client.ml.inference.TrainedModelStats;
import org.elasticsearch.client.ml.inference.TrainedModelType;
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
Expand Down Expand Up @@ -3825,6 +3826,7 @@ public void testPutTrainedModel() throws Exception {
.setDefinition(definition) // <1>
.setCompressedDefinition(InferenceToXContentCompressor.deflate(definition)) // <2>
.setModelId("my-new-trained-model") // <3>
.setModelType(TrainedModelType.TREE_ENSEMBLE) // <4>
.setInput(new TrainedModelInput("col1", "col2", "col3", "col4")) // <4>
.setDescription("test model") // <5>
.setMetadata(new HashMap<>()) // <6>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public static TrainedModelConfig createTestTrainedModelConfig() {
TargetType targetType = randomFrom(TargetType.values());
return new TrainedModelConfig(
randomAlphaOfLength(10),
randomBoolean() ? null : randomFrom(TrainedModelType.values()),
randomAlphaOfLength(10),
Version.CURRENT,
randomBoolean() ? null : randomAlphaOfLength(100),
Expand Down
Loading

0 comments on commit 99ed8b0

Please sign in to comment.