Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Merge the pytorch-inference feature branch #73660

Merged
merged 22 commits into from
Jun 3, 2021
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
8ba697b
[ML] Start and stop model deployments (#70713)
dimitris-athanasiou Mar 29, 2021
89a85e1
Merge branch 'master' into feature/pytorch-inference
davidkyle Mar 30, 2021
0897e8a
Merge branch 'master' into feature/pytorch-inference
davidkyle Apr 1, 2021
99ed8b0
[ML] Add PyTorch model configuration (#71035)
davidkyle Apr 1, 2021
a261f0d
[ML] Infer against model deployment (#71177)
dimitris-athanasiou Apr 7, 2021
720fbed
Merge branch 'master' into feature/pytorch-inference
davidkyle Apr 7, 2021
52df9ad
[ML] Model storage for 3rd Party models (#71323)
davidkyle Apr 12, 2021
ce3a136
Merge branch 'master' into feature/pytorch-inference
davidkyle Apr 14, 2021
91eb2cf
Merge branch 'master' into feature/pytorch-inference
davidkyle Apr 20, 2021
64c04e5
[ML] Store compressed model definitions in ByteReferences (#71679)
davidkyle Apr 20, 2021
f157cde
Merge branch 'master' into feature/pytorch-inference
davidkyle Apr 27, 2021
84da366
Merge branch 'master' into feature/pytorch-inference
davidkyle Apr 28, 2021
63bb0c6
[ML] Load and evaluate 3rd Party Model (#72218)
davidkyle Apr 29, 2021
6f113bf
Merge branch 'master' into feature/pytorch-inference
davidkyle May 28, 2021
7ed153c
Merge branch 'master' into feature/pytorch-inference
davidkyle Jun 1, 2021
418985b
Merge branch 'master' into feature/pytorch-inference
davidkyle Jun 1, 2021
8e51034
[ML] Natural Language Processing tasks and models (#73523)
davidkyle Jun 2, 2021
41e8ba4
Resolve minor TODOs
davidkyle Jun 2, 2021
c450f3b
Use more standard ml/trained_models/{ID}/deployment/.. URL
davidkyle Jun 2, 2021
5ec5a54
Merge branch 'master' into feature/pytorch-inference
davidkyle Jun 2, 2021
4fc7cf6
Fix doc test tags
davidkyle Jun 2, 2021
55e9ab0
fix typo
davidkyle Jun 2, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
import org.elasticsearch.client.ml.inference.preprocessing.Multi;
import org.elasticsearch.client.ml.inference.preprocessing.NGram;
import org.elasticsearch.client.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.client.ml.inference.trainedmodel.IndexLocation;
import org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModelLocation;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Exponent;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.LogisticRegression;
Expand Down Expand Up @@ -82,6 +84,11 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
new ParseField(Exponent.NAME),
Exponent::fromXContent));

// location
namedXContent.add(new NamedXContentRegistry.Entry(TrainedModelLocation.class,
new ParseField(IndexLocation.INDEX),
IndexLocation::fromXContent));

return namedXContent;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import org.elasticsearch.Version;
import org.elasticsearch.client.common.TimeUtil;
import org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModelLocation;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.unit.ByteSizeValue;
Expand All @@ -33,6 +35,7 @@ public class TrainedModelConfig implements ToXContentObject {
public static final String NAME = "trained_model_config";

public static final ParseField MODEL_ID = new ParseField("model_id");
public static final ParseField MODEL_TYPE = new ParseField("model_type");
public static final ParseField CREATED_BY = new ParseField("created_by");
public static final ParseField VERSION = new ParseField("version");
public static final ParseField DESCRIPTION = new ParseField("description");
Expand All @@ -47,12 +50,14 @@ public class TrainedModelConfig implements ToXContentObject {
public static final ParseField LICENSE_LEVEL = new ParseField("license_level");
public static final ParseField DEFAULT_FIELD_MAP = new ParseField("default_field_map");
public static final ParseField INFERENCE_CONFIG = new ParseField("inference_config");
public static final ParseField LOCATION = new ParseField("location");

public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
true,
TrainedModelConfig.Builder::new);
static {
PARSER.declareString(TrainedModelConfig.Builder::setModelId, MODEL_ID);
PARSER.declareString(TrainedModelConfig.Builder::setModelType, MODEL_TYPE);
PARSER.declareString(TrainedModelConfig.Builder::setCreatedBy, CREATED_BY);
PARSER.declareString(TrainedModelConfig.Builder::setVersion, VERSION);
PARSER.declareString(TrainedModelConfig.Builder::setDescription, DESCRIPTION);
Expand All @@ -74,13 +79,17 @@ public class TrainedModelConfig implements ToXContentObject {
PARSER.declareNamedObject(TrainedModelConfig.Builder::setInferenceConfig,
(p, c, n) -> p.namedObject(InferenceConfig.class, n, null),
INFERENCE_CONFIG);
PARSER.declareNamedObject(TrainedModelConfig.Builder::setLocation,
(p, c, n) -> p.namedObject(TrainedModelLocation.class, n, null),
LOCATION);
}

public static TrainedModelConfig fromXContent(XContentParser parser) throws IOException {
return PARSER.parse(parser, null).build();
}

private final String modelId;
private final TrainedModelType modelType;
private final String createdBy;
private final Version version;
private final String description;
Expand All @@ -95,8 +104,10 @@ public static TrainedModelConfig fromXContent(XContentParser parser) throws IOEx
private final String licenseLevel;
private final Map<String, String> defaultFieldMap;
private final InferenceConfig inferenceConfig;
private final TrainedModelLocation location;

TrainedModelConfig(String modelId,
TrainedModelType modelType,
String createdBy,
Version version,
String description,
Expand All @@ -110,8 +121,10 @@ public static TrainedModelConfig fromXContent(XContentParser parser) throws IOEx
Long estimatedOperations,
String licenseLevel,
Map<String, String> defaultFieldMap,
InferenceConfig inferenceConfig) {
InferenceConfig inferenceConfig,
TrainedModelLocation location) {
this.modelId = modelId;
this.modelType = modelType;
this.createdBy = createdBy;
this.version = version;
this.createTime = createTime == null ? null : Instant.ofEpochMilli(createTime.toEpochMilli());
Expand All @@ -126,12 +139,17 @@ public static TrainedModelConfig fromXContent(XContentParser parser) throws IOEx
this.licenseLevel = licenseLevel;
this.defaultFieldMap = defaultFieldMap == null ? null : Collections.unmodifiableMap(defaultFieldMap);
this.inferenceConfig = inferenceConfig;
this.location = location;
}

public String getModelId() {
return modelId;
}

public TrainedModelType getModelType() {
return modelType;
}

public String getCreatedBy() {
return createdBy;
}
Expand Down Expand Up @@ -164,6 +182,11 @@ public String getCompressedDefinition() {
return compressedDefinition;
}

@Nullable
public TrainedModelLocation getLocation() {
return location;
}

public TrainedModelInput getInput() {
return input;
}
Expand Down Expand Up @@ -202,6 +225,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (modelId != null) {
builder.field(MODEL_ID.getPreferredName(), modelId);
}
if (modelType != null) {
builder.field(MODEL_TYPE.getPreferredName(), modelType.toString());
}
if (createdBy != null) {
builder.field(CREATED_BY.getPreferredName(), createdBy);
}
Expand Down Expand Up @@ -244,6 +270,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (inferenceConfig != null) {
writeNamedObject(builder, params, INFERENCE_CONFIG.getPreferredName(), inferenceConfig);
}
if (location != null) {
writeNamedObject(builder, params, LOCATION.getPreferredName(), location);
}
builder.endObject();
return builder;
}
Expand All @@ -259,6 +288,7 @@ public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
TrainedModelConfig that = (TrainedModelConfig) o;
return Objects.equals(modelId, that.modelId) &&
Objects.equals(modelType, that.modelType) &&
Objects.equals(createdBy, that.createdBy) &&
Objects.equals(version, that.version) &&
Objects.equals(description, that.description) &&
Expand All @@ -272,12 +302,14 @@ public boolean equals(Object o) {
Objects.equals(licenseLevel, that.licenseLevel) &&
Objects.equals(defaultFieldMap, that.defaultFieldMap) &&
Objects.equals(inferenceConfig, that.inferenceConfig) &&
Objects.equals(metadata, that.metadata);
Objects.equals(metadata, that.metadata) &&
Objects.equals(location, that.location);
}

@Override
public int hashCode() {
return Objects.hash(modelId,
modelType,
createdBy,
version,
createTime,
Expand All @@ -291,13 +323,15 @@ public int hashCode() {
licenseLevel,
input,
inferenceConfig,
defaultFieldMap);
defaultFieldMap,
location);
}


public static class Builder {

private String modelId;
private TrainedModelType modelType;
private String createdBy;
private Version version;
private String description;
Expand All @@ -312,12 +346,23 @@ public static class Builder {
private String licenseLevel;
private Map<String, String> defaultFieldMap;
private InferenceConfig inferenceConfig;
private TrainedModelLocation location;

public Builder setModelId(String modelId) {
this.modelId = modelId;
return this;
}

public Builder setModelType(String modelType) {
this.modelType = TrainedModelType.fromString(modelType);
return this;
}

public Builder setModelType(TrainedModelType modelType) {
this.modelType = modelType;
return this;
}

private Builder setCreatedBy(String createdBy) {
this.createdBy = createdBy;
return this;
Expand Down Expand Up @@ -371,6 +416,11 @@ public Builder setDefinition(TrainedModelDefinition definition) {
return this;
}

public Builder setLocation(TrainedModelLocation location) {
this.location = location;
return this;
}

public Builder setInput(TrainedModelInput input) {
this.input = input;
return this;
Expand Down Expand Up @@ -404,6 +454,7 @@ public Builder setInferenceConfig(InferenceConfig inferenceConfig) {
public TrainedModelConfig build() {
return new TrainedModelConfig(
modelId,
modelType,
createdBy,
version,
description,
Expand All @@ -417,7 +468,8 @@ public TrainedModelConfig build() {
estimatedOperations,
licenseLevel,
defaultFieldMap,
inferenceConfig);
inferenceConfig,
location);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.client.ml.inference;

import java.util.Locale;

public enum TrainedModelType {
TREE_ENSEMBLE, LANG_IDENT, PYTORCH;

public static TrainedModelType fromString(String name) {
return valueOf(name.trim().toUpperCase(Locale.ROOT));
}

@Override
public String toString() {
return name().toLowerCase(Locale.ROOT);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.client.ml.inference.trainedmodel;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.Objects;

public class IndexLocation implements TrainedModelLocation {

public static final String INDEX = "index";
private static final ParseField MODEL_ID = new ParseField("model_id");
private static final ParseField NAME = new ParseField("name");

private static final ConstructingObjectParser<IndexLocation, Void> PARSER =
new ConstructingObjectParser<>(INDEX, true, a -> new IndexLocation((String) a[0], (String) a[1]));

static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), MODEL_ID);
PARSER.declareString(ConstructingObjectParser.constructorArg(), NAME);
}

public static IndexLocation fromXContent(XContentParser parser) throws IOException {
return PARSER.parse(parser, null);
}

private final String modelId;
private final String index;

public IndexLocation(String modelId, String index) {
this.modelId = Objects.requireNonNull(modelId);
this.index = Objects.requireNonNull(index);
}

public String getModelId() {
return modelId;
}

public String getIndex() {
return index;
}

@Override
public String getName() {
return INDEX;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(NAME.getPreferredName(), index);
builder.field(MODEL_ID.getPreferredName(), modelId);
builder.endObject();
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
IndexLocation that = (IndexLocation) o;
return Objects.equals(modelId, that.modelId)
&& Objects.equals(index, that.index);
}

@Override
public int hashCode() {
return Objects.hash(modelId, index);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@
*/
package org.elasticsearch.client.ml.inference.trainedmodel;

import org.elasticsearch.common.ParseField;

import java.util.Locale;

public enum TargetType {

REGRESSION, CLASSIFICATION;

public static final ParseField TARGET_TYPE = new ParseField("target_type");

public static TargetType fromString(String name) {
return valueOf(name.trim().toUpperCase(Locale.ROOT));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.client.ml.inference.trainedmodel;

import org.elasticsearch.client.ml.inference.NamedXContentObject;

public interface TrainedModelLocation extends NamedXContentObject {
}
Loading