-
Notifications
You must be signed in to change notification settings - Fork 24.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ML] Add PyTorch model configuration (#71035)
Adds the model_type field to TrainedModelConfig for distinguishing between models that can be loaded via the model loading service and those that require a native process.
- Loading branch information
Showing
36 changed files
with
739 additions
and
65 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
24 changes: 24 additions & 0 deletions
24
...rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelType.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0 and the Server Side Public License, v 1; you may not use this file except | ||
* in compliance with, at your election, the Elastic License 2.0 or the Server | ||
* Side Public License, v 1. | ||
*/ | ||
|
||
package org.elasticsearch.client.ml.inference; | ||
|
||
import java.util.Locale; | ||
|
||
public enum TrainedModelType { | ||
TREE_ENSEMBLE, LANG_IDENT, PYTORCH; | ||
|
||
public static TrainedModelType fromString(String name) { | ||
return valueOf(name.trim().toUpperCase(Locale.ROOT)); | ||
} | ||
|
||
@Override | ||
public String toString() { | ||
return name().toLowerCase(Locale.ROOT); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
120 changes: 120 additions & 0 deletions
120
...rc/main/java/org/elasticsearch/client/ml/inference/trainedmodel/pytorch/PyTorchModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0 and the Server Side Public License, v 1; you may not use this file except | ||
* in compliance with, at your election, the Elastic License 2.0 or the Server | ||
* Side Public License, v 1. | ||
*/ | ||
|
||
package org.elasticsearch.client.ml.inference.trainedmodel.pytorch; | ||
|
||
import org.elasticsearch.client.ml.inference.trainedmodel.TargetType; | ||
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel; | ||
import org.elasticsearch.common.ParseField; | ||
import org.elasticsearch.common.xcontent.ObjectParser; | ||
import org.elasticsearch.common.xcontent.XContentBuilder; | ||
import org.elasticsearch.common.xcontent.XContentParser; | ||
|
||
import java.io.IOException; | ||
import java.util.Collections; | ||
import java.util.List; | ||
import java.util.Objects; | ||
|
||
public class PyTorchModel implements TrainedModel { | ||
|
||
public static final String NAME = "pytorch"; | ||
public static final ParseField MODEL_ID = new ParseField("model_id"); | ||
|
||
private static ObjectParser<PyTorchModel.Builder, Void> PARSER = new ObjectParser<>( | ||
NAME, | ||
true, | ||
PyTorchModel.Builder::new); | ||
|
||
static { | ||
PARSER.declareString(PyTorchModel.Builder::setModelId, MODEL_ID); | ||
PARSER.declareString(PyTorchModel.Builder::setTargetType, TargetType.TARGET_TYPE); | ||
} | ||
|
||
public static PyTorchModel fromXContent(XContentParser parser) { | ||
return PARSER.apply(parser, null).build(); | ||
} | ||
|
||
private final String modelId; | ||
private final TargetType targetType; | ||
|
||
public PyTorchModel(String modelId, TargetType targetType) { | ||
this.modelId = Objects.requireNonNull(modelId); | ||
this.targetType = Objects.requireNonNull(targetType); | ||
} | ||
|
||
public String getModelId() { | ||
return modelId; | ||
} | ||
|
||
public TargetType getTargetType() { | ||
return targetType; | ||
} | ||
|
||
@Override | ||
public List<String> getFeatureNames() { | ||
return Collections.emptyList(); | ||
} | ||
|
||
@Override | ||
public String getName() { | ||
return NAME; | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
builder.startObject(); | ||
builder.field(MODEL_ID.getPreferredName(), modelId); | ||
builder.field(TargetType.TARGET_TYPE.getPreferredName(), targetType); | ||
builder.endObject(); | ||
return builder; | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
if (this == o) { | ||
return true; | ||
} | ||
if (o == null || getClass() != o.getClass()) { | ||
return false; | ||
} | ||
PyTorchModel that = (PyTorchModel) o; | ||
return Objects.equals(modelId, that.modelId) | ||
&& Objects.equals(targetType, that.targetType); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(modelId, targetType); | ||
} | ||
|
||
public static class Builder { | ||
|
||
private String modelId; | ||
private TargetType targetType; | ||
|
||
public Builder setModelId(String modelId) { | ||
this.modelId = modelId; | ||
return this; | ||
} | ||
|
||
public Builder setTargetType(TargetType targetType) { | ||
this.targetType = targetType; | ||
return this; | ||
|
||
} | ||
|
||
private Builder setTargetType(String targetType) { | ||
this.targetType = TargetType.fromString(targetType); | ||
return this; | ||
} | ||
|
||
PyTorchModel build() { | ||
return new PyTorchModel(modelId, targetType); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.