Skip to content

Commit

Permalink
[ML] Merge the pytorch-inference feature branch (#73660)
Browse files Browse the repository at this point in the history
The feature branch contains changes to configure PyTorch models with a 
TrainedModelConfig and defines a format to store the binary models. 
The _start and _stop deployment actions control the model lifecycle 
and the model can be directly evaluated with the _infer endpoint. 
2 Types of NLP tasks are supported: Named Entity Recognition and Fill Mask.

The feature branch consists of these PRs: #73523, #72218, #71679
#71323, #71035, #71177, #70713
  • Loading branch information
davidkyle authored Jun 3, 2021
1 parent c089d91 commit 94adaa5
Show file tree
Hide file tree
Showing 109 changed files with 7,067 additions and 373 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
import org.elasticsearch.client.ml.inference.preprocessing.Multi;
import org.elasticsearch.client.ml.inference.preprocessing.NGram;
import org.elasticsearch.client.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.client.ml.inference.trainedmodel.IndexLocation;
import org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModelLocation;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Exponent;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.LogisticRegression;
Expand Down Expand Up @@ -82,6 +84,11 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
new ParseField(Exponent.NAME),
Exponent::fromXContent));

// location
namedXContent.add(new NamedXContentRegistry.Entry(TrainedModelLocation.class,
new ParseField(IndexLocation.INDEX),
IndexLocation::fromXContent));

return namedXContent;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import org.elasticsearch.Version;
import org.elasticsearch.client.common.TimeUtil;
import org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModelLocation;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.unit.ByteSizeValue;
Expand All @@ -33,6 +35,7 @@ public class TrainedModelConfig implements ToXContentObject {
public static final String NAME = "trained_model_config";

public static final ParseField MODEL_ID = new ParseField("model_id");
public static final ParseField MODEL_TYPE = new ParseField("model_type");
public static final ParseField CREATED_BY = new ParseField("created_by");
public static final ParseField VERSION = new ParseField("version");
public static final ParseField DESCRIPTION = new ParseField("description");
Expand All @@ -47,12 +50,14 @@ public class TrainedModelConfig implements ToXContentObject {
public static final ParseField LICENSE_LEVEL = new ParseField("license_level");
public static final ParseField DEFAULT_FIELD_MAP = new ParseField("default_field_map");
public static final ParseField INFERENCE_CONFIG = new ParseField("inference_config");
public static final ParseField LOCATION = new ParseField("location");

public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
true,
TrainedModelConfig.Builder::new);
static {
PARSER.declareString(TrainedModelConfig.Builder::setModelId, MODEL_ID);
PARSER.declareString(TrainedModelConfig.Builder::setModelType, MODEL_TYPE);
PARSER.declareString(TrainedModelConfig.Builder::setCreatedBy, CREATED_BY);
PARSER.declareString(TrainedModelConfig.Builder::setVersion, VERSION);
PARSER.declareString(TrainedModelConfig.Builder::setDescription, DESCRIPTION);
Expand All @@ -74,13 +79,17 @@ public class TrainedModelConfig implements ToXContentObject {
PARSER.declareNamedObject(TrainedModelConfig.Builder::setInferenceConfig,
(p, c, n) -> p.namedObject(InferenceConfig.class, n, null),
INFERENCE_CONFIG);
PARSER.declareNamedObject(TrainedModelConfig.Builder::setLocation,
(p, c, n) -> p.namedObject(TrainedModelLocation.class, n, null),
LOCATION);
}

public static TrainedModelConfig fromXContent(XContentParser parser) throws IOException {
return PARSER.parse(parser, null).build();
}

private final String modelId;
private final TrainedModelType modelType;
private final String createdBy;
private final Version version;
private final String description;
Expand All @@ -95,8 +104,10 @@ public static TrainedModelConfig fromXContent(XContentParser parser) throws IOEx
private final String licenseLevel;
private final Map<String, String> defaultFieldMap;
private final InferenceConfig inferenceConfig;
private final TrainedModelLocation location;

TrainedModelConfig(String modelId,
TrainedModelType modelType,
String createdBy,
Version version,
String description,
Expand All @@ -110,8 +121,10 @@ public static TrainedModelConfig fromXContent(XContentParser parser) throws IOEx
Long estimatedOperations,
String licenseLevel,
Map<String, String> defaultFieldMap,
InferenceConfig inferenceConfig) {
InferenceConfig inferenceConfig,
TrainedModelLocation location) {
this.modelId = modelId;
this.modelType = modelType;
this.createdBy = createdBy;
this.version = version;
this.createTime = createTime == null ? null : Instant.ofEpochMilli(createTime.toEpochMilli());
Expand All @@ -126,12 +139,17 @@ public static TrainedModelConfig fromXContent(XContentParser parser) throws IOEx
this.licenseLevel = licenseLevel;
this.defaultFieldMap = defaultFieldMap == null ? null : Collections.unmodifiableMap(defaultFieldMap);
this.inferenceConfig = inferenceConfig;
this.location = location;
}

public String getModelId() {
return modelId;
}

public TrainedModelType getModelType() {
return modelType;
}

public String getCreatedBy() {
return createdBy;
}
Expand Down Expand Up @@ -164,6 +182,11 @@ public String getCompressedDefinition() {
return compressedDefinition;
}

@Nullable
public TrainedModelLocation getLocation() {
return location;
}

public TrainedModelInput getInput() {
return input;
}
Expand Down Expand Up @@ -202,6 +225,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (modelId != null) {
builder.field(MODEL_ID.getPreferredName(), modelId);
}
if (modelType != null) {
builder.field(MODEL_TYPE.getPreferredName(), modelType.toString());
}
if (createdBy != null) {
builder.field(CREATED_BY.getPreferredName(), createdBy);
}
Expand Down Expand Up @@ -244,6 +270,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (inferenceConfig != null) {
writeNamedObject(builder, params, INFERENCE_CONFIG.getPreferredName(), inferenceConfig);
}
if (location != null) {
writeNamedObject(builder, params, LOCATION.getPreferredName(), location);
}
builder.endObject();
return builder;
}
Expand All @@ -259,6 +288,7 @@ public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
TrainedModelConfig that = (TrainedModelConfig) o;
return Objects.equals(modelId, that.modelId) &&
Objects.equals(modelType, that.modelType) &&
Objects.equals(createdBy, that.createdBy) &&
Objects.equals(version, that.version) &&
Objects.equals(description, that.description) &&
Expand All @@ -272,12 +302,14 @@ public boolean equals(Object o) {
Objects.equals(licenseLevel, that.licenseLevel) &&
Objects.equals(defaultFieldMap, that.defaultFieldMap) &&
Objects.equals(inferenceConfig, that.inferenceConfig) &&
Objects.equals(metadata, that.metadata);
Objects.equals(metadata, that.metadata) &&
Objects.equals(location, that.location);
}

@Override
public int hashCode() {
return Objects.hash(modelId,
modelType,
createdBy,
version,
createTime,
Expand All @@ -291,13 +323,15 @@ public int hashCode() {
licenseLevel,
input,
inferenceConfig,
defaultFieldMap);
defaultFieldMap,
location);
}


public static class Builder {

private String modelId;
private TrainedModelType modelType;
private String createdBy;
private Version version;
private String description;
Expand All @@ -312,12 +346,23 @@ public static class Builder {
private String licenseLevel;
private Map<String, String> defaultFieldMap;
private InferenceConfig inferenceConfig;
private TrainedModelLocation location;

public Builder setModelId(String modelId) {
this.modelId = modelId;
return this;
}

public Builder setModelType(String modelType) {
this.modelType = TrainedModelType.fromString(modelType);
return this;
}

public Builder setModelType(TrainedModelType modelType) {
this.modelType = modelType;
return this;
}

private Builder setCreatedBy(String createdBy) {
this.createdBy = createdBy;
return this;
Expand Down Expand Up @@ -371,6 +416,11 @@ public Builder setDefinition(TrainedModelDefinition definition) {
return this;
}

public Builder setLocation(TrainedModelLocation location) {
this.location = location;
return this;
}

public Builder setInput(TrainedModelInput input) {
this.input = input;
return this;
Expand Down Expand Up @@ -404,6 +454,7 @@ public Builder setInferenceConfig(InferenceConfig inferenceConfig) {
public TrainedModelConfig build() {
return new TrainedModelConfig(
modelId,
modelType,
createdBy,
version,
description,
Expand All @@ -417,7 +468,8 @@ public TrainedModelConfig build() {
estimatedOperations,
licenseLevel,
defaultFieldMap,
inferenceConfig);
inferenceConfig,
location);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.client.ml.inference;

import java.util.Locale;

public enum TrainedModelType {
TREE_ENSEMBLE, LANG_IDENT, PYTORCH;

public static TrainedModelType fromString(String name) {
return valueOf(name.trim().toUpperCase(Locale.ROOT));
}

@Override
public String toString() {
return name().toLowerCase(Locale.ROOT);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.client.ml.inference.trainedmodel;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.Objects;

public class IndexLocation implements TrainedModelLocation {

public static final String INDEX = "index";
private static final ParseField MODEL_ID = new ParseField("model_id");
private static final ParseField NAME = new ParseField("name");

private static final ConstructingObjectParser<IndexLocation, Void> PARSER =
new ConstructingObjectParser<>(INDEX, true, a -> new IndexLocation((String) a[0], (String) a[1]));

static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), MODEL_ID);
PARSER.declareString(ConstructingObjectParser.constructorArg(), NAME);
}

public static IndexLocation fromXContent(XContentParser parser) throws IOException {
return PARSER.parse(parser, null);
}

private final String modelId;
private final String index;

public IndexLocation(String modelId, String index) {
this.modelId = Objects.requireNonNull(modelId);
this.index = Objects.requireNonNull(index);
}

public String getModelId() {
return modelId;
}

public String getIndex() {
return index;
}

@Override
public String getName() {
return INDEX;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(NAME.getPreferredName(), index);
builder.field(MODEL_ID.getPreferredName(), modelId);
builder.endObject();
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
IndexLocation that = (IndexLocation) o;
return Objects.equals(modelId, that.modelId)
&& Objects.equals(index, that.index);
}

@Override
public int hashCode() {
return Objects.hash(modelId, index);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@
*/
package org.elasticsearch.client.ml.inference.trainedmodel;

import org.elasticsearch.common.ParseField;

import java.util.Locale;

public enum TargetType {

REGRESSION, CLASSIFICATION;

public static final ParseField TARGET_TYPE = new ParseField("target_type");

public static TargetType fromString(String name) {
return valueOf(name.trim().toUpperCase(Locale.ROOT));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.client.ml.inference.trainedmodel;

import org.elasticsearch.client.ml.inference.NamedXContentObject;

public interface TrainedModelLocation extends NamedXContentObject {
}
Loading

0 comments on commit 94adaa5

Please sign in to comment.