diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java index 0f25c25421289..2c0fe83f04c65 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java @@ -11,9 +11,11 @@ import org.elasticsearch.client.ml.inference.preprocessing.Multi; import org.elasticsearch.client.ml.inference.preprocessing.NGram; import org.elasticsearch.client.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.client.ml.inference.trainedmodel.IndexLocation; import org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModelLocation; import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble; import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Exponent; import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.LogisticRegression; @@ -82,6 +84,11 @@ public List getNamedXContentParsers() { new ParseField(Exponent.NAME), Exponent::fromXContent)); + // location + namedXContent.add(new NamedXContentRegistry.Entry(TrainedModelLocation.class, + new ParseField(IndexLocation.INDEX), + IndexLocation::fromXContent)); + return namedXContent; } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java index f43ef31c928c1..83a798e983506 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java @@ -10,6 +10,8 @@ import org.elasticsearch.Version; import org.elasticsearch.client.common.TimeUtil; import org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModelLocation; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; import org.elasticsearch.common.unit.ByteSizeValue; @@ -33,6 +35,7 @@ public class TrainedModelConfig implements ToXContentObject { public static final String NAME = "trained_model_config"; public static final ParseField MODEL_ID = new ParseField("model_id"); + public static final ParseField MODEL_TYPE = new ParseField("model_type"); public static final ParseField CREATED_BY = new ParseField("created_by"); public static final ParseField VERSION = new ParseField("version"); public static final ParseField DESCRIPTION = new ParseField("description"); @@ -47,12 +50,14 @@ public class TrainedModelConfig implements ToXContentObject { public static final ParseField LICENSE_LEVEL = new ParseField("license_level"); public static final ParseField DEFAULT_FIELD_MAP = new ParseField("default_field_map"); public static final ParseField INFERENCE_CONFIG = new ParseField("inference_config"); + public static final ParseField LOCATION = new ParseField("location"); public static final ObjectParser PARSER = new 
ObjectParser<>(NAME, true, TrainedModelConfig.Builder::new); static { PARSER.declareString(TrainedModelConfig.Builder::setModelId, MODEL_ID); + PARSER.declareString(TrainedModelConfig.Builder::setModelType, MODEL_TYPE); PARSER.declareString(TrainedModelConfig.Builder::setCreatedBy, CREATED_BY); PARSER.declareString(TrainedModelConfig.Builder::setVersion, VERSION); PARSER.declareString(TrainedModelConfig.Builder::setDescription, DESCRIPTION); @@ -74,6 +79,9 @@ public class TrainedModelConfig implements ToXContentObject { PARSER.declareNamedObject(TrainedModelConfig.Builder::setInferenceConfig, (p, c, n) -> p.namedObject(InferenceConfig.class, n, null), INFERENCE_CONFIG); + PARSER.declareNamedObject(TrainedModelConfig.Builder::setLocation, + (p, c, n) -> p.namedObject(TrainedModelLocation.class, n, null), + LOCATION); } public static TrainedModelConfig fromXContent(XContentParser parser) throws IOException { @@ -81,6 +89,7 @@ public static TrainedModelConfig fromXContent(XContentParser parser) throws IOEx } private final String modelId; + private final TrainedModelType modelType; private final String createdBy; private final Version version; private final String description; @@ -95,8 +104,10 @@ public static TrainedModelConfig fromXContent(XContentParser parser) throws IOEx private final String licenseLevel; private final Map defaultFieldMap; private final InferenceConfig inferenceConfig; + private final TrainedModelLocation location; TrainedModelConfig(String modelId, + TrainedModelType modelType, String createdBy, Version version, String description, @@ -110,8 +121,10 @@ public static TrainedModelConfig fromXContent(XContentParser parser) throws IOEx Long estimatedOperations, String licenseLevel, Map defaultFieldMap, - InferenceConfig inferenceConfig) { + InferenceConfig inferenceConfig, + TrainedModelLocation location) { this.modelId = modelId; + this.modelType = modelType; this.createdBy = createdBy; this.version = version; this.createTime = createTime == null ? null : Instant.ofEpochMilli(createTime.toEpochMilli()); @@ -126,12 +139,17 @@ public static TrainedModelConfig fromXContent(XContentParser parser) throws IOEx this.licenseLevel = licenseLevel; this.defaultFieldMap = defaultFieldMap == null ? 
null : Collections.unmodifiableMap(defaultFieldMap); this.inferenceConfig = inferenceConfig; + this.location = location; } public String getModelId() { return modelId; } + public TrainedModelType getModelType() { + return modelType; + } + public String getCreatedBy() { return createdBy; } @@ -164,6 +182,11 @@ public String getCompressedDefinition() { return compressedDefinition; } + @Nullable + public TrainedModelLocation getLocation() { + return location; + } + public TrainedModelInput getInput() { return input; } @@ -202,6 +225,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (modelId != null) { builder.field(MODEL_ID.getPreferredName(), modelId); } + if (modelType != null) { + builder.field(MODEL_TYPE.getPreferredName(), modelType.toString()); + } if (createdBy != null) { builder.field(CREATED_BY.getPreferredName(), createdBy); } @@ -244,6 +270,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (inferenceConfig != null) { writeNamedObject(builder, params, INFERENCE_CONFIG.getPreferredName(), inferenceConfig); } + if (location != null) { + writeNamedObject(builder, params, LOCATION.getPreferredName(), location); + } builder.endObject(); return builder; } @@ -259,6 +288,7 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; TrainedModelConfig that = (TrainedModelConfig) o; return Objects.equals(modelId, that.modelId) && + Objects.equals(modelType, that.modelType) && Objects.equals(createdBy, that.createdBy) && Objects.equals(version, that.version) && Objects.equals(description, that.description) && @@ -272,12 +302,14 @@ public boolean equals(Object o) { Objects.equals(licenseLevel, that.licenseLevel) && Objects.equals(defaultFieldMap, that.defaultFieldMap) && Objects.equals(inferenceConfig, that.inferenceConfig) && - Objects.equals(metadata, that.metadata); + Objects.equals(metadata, that.metadata) && + Objects.equals(location, that.location); } @Override public int hashCode() { return Objects.hash(modelId, + modelType, createdBy, version, createTime, @@ -291,13 +323,15 @@ public int hashCode() { licenseLevel, input, inferenceConfig, - defaultFieldMap); + defaultFieldMap, + location); } public static class Builder { private String modelId; + private TrainedModelType modelType; private String createdBy; private Version version; private String description; @@ -312,12 +346,23 @@ public static class Builder { private String licenseLevel; private Map defaultFieldMap; private InferenceConfig inferenceConfig; + private TrainedModelLocation location; public Builder setModelId(String modelId) { this.modelId = modelId; return this; } + public Builder setModelType(String modelType) { + this.modelType = TrainedModelType.fromString(modelType); + return this; + } + + public Builder setModelType(TrainedModelType modelType) { + this.modelType = modelType; + return this; + } + private Builder setCreatedBy(String createdBy) { this.createdBy = createdBy; return this; @@ -371,6 +416,11 @@ public Builder setDefinition(TrainedModelDefinition definition) { return this; } + public Builder setLocation(TrainedModelLocation location) { + this.location = location; + return this; + } + public Builder setInput(TrainedModelInput input) { this.input = input; return this; @@ -404,6 +454,7 @@ public Builder setInferenceConfig(InferenceConfig inferenceConfig) { public TrainedModelConfig build() { return new TrainedModelConfig( modelId, + modelType, createdBy, version, description, @@ -417,7 +468,8 @@ 
public TrainedModelConfig build() { estimatedOperations, licenseLevel, defaultFieldMap, - inferenceConfig); + inferenceConfig, + location); } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelType.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelType.java new file mode 100644 index 0000000000000..7829a4a56f0ab --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelType.java @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.client.ml.inference; + +import java.util.Locale; + +public enum TrainedModelType { + TREE_ENSEMBLE, LANG_IDENT, PYTORCH; + + public static TrainedModelType fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/IndexLocation.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/IndexLocation.java new file mode 100644 index 0000000000000..457606c2611f8 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/IndexLocation.java @@ -0,0 +1,84 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. 
+ */ + +package org.elasticsearch.client.ml.inference.trainedmodel; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +public class IndexLocation implements TrainedModelLocation { + + public static final String INDEX = "index"; + private static final ParseField MODEL_ID = new ParseField("model_id"); + private static final ParseField NAME = new ParseField("name"); + + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>(INDEX, true, a -> new IndexLocation((String) a[0], (String) a[1])); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), MODEL_ID); + PARSER.declareString(ConstructingObjectParser.constructorArg(), NAME); + } + + public static IndexLocation fromXContent(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + + private final String modelId; + private final String index; + + public IndexLocation(String modelId, String index) { + this.modelId = Objects.requireNonNull(modelId); + this.index = Objects.requireNonNull(index); + } + + public String getModelId() { + return modelId; + } + + public String getIndex() { + return index; + } + + @Override + public String getName() { + return INDEX; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(NAME.getPreferredName(), index); + builder.field(MODEL_ID.getPreferredName(), modelId); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + IndexLocation that = (IndexLocation) o; + return Objects.equals(modelId, that.modelId) + && Objects.equals(index, that.index); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, index); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/TargetType.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/TargetType.java index 5999a260783de..8c0202555eebe 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/TargetType.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/TargetType.java @@ -7,12 +7,16 @@ */ package org.elasticsearch.client.ml.inference.trainedmodel; +import org.elasticsearch.common.ParseField; + import java.util.Locale; public enum TargetType { REGRESSION, CLASSIFICATION; + public static final ParseField TARGET_TYPE = new ParseField("target_type"); + public static TargetType fromString(String name) { return valueOf(name.trim().toUpperCase(Locale.ROOT)); } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/TrainedModelLocation.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/TrainedModelLocation.java new file mode 100644 index 0000000000000..c5914b5bb625d --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/TrainedModelLocation.java @@ -0,0 +1,14 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.client.ml.inference.trainedmodel; + +import org.elasticsearch.client.ml.inference.NamedXContentObject; + +public interface TrainedModelLocation extends NamedXContentObject { +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/Ensemble.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/Ensemble.java index 7af14673fedae..4673eabe1057b 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/Ensemble.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/Ensemble.java @@ -29,7 +29,6 @@ public class Ensemble implements TrainedModel { public static final ParseField FEATURE_NAMES = new ParseField("feature_names"); public static final ParseField TRAINED_MODELS = new ParseField("trained_models"); public static final ParseField AGGREGATE_OUTPUT = new ParseField("aggregate_output"); - public static final ParseField TARGET_TYPE = new ParseField("target_type"); public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels"); public static final ParseField CLASSIFICATION_WEIGHTS = new ParseField("classification_weights"); @@ -48,7 +47,7 @@ public class Ensemble implements TrainedModel { PARSER.declareNamedObject(Ensemble.Builder::setOutputAggregator, (p, c, n) -> p.namedObject(OutputAggregator.class, n, null), AGGREGATE_OUTPUT); - PARSER.declareString(Ensemble.Builder::setTargetType, TARGET_TYPE); + PARSER.declareString(Ensemble.Builder::setTargetType, TargetType.TARGET_TYPE); PARSER.declareStringArray(Ensemble.Builder::setClassificationLabels, CLASSIFICATION_LABELS); PARSER.declareDoubleArray(Ensemble.Builder::setClassificationWeights, CLASSIFICATION_WEIGHTS); } @@ -105,7 +104,7 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par Collections.singletonList(outputAggregator)); } if (targetType != null) { - builder.field(TARGET_TYPE.getPreferredName(), targetType); + builder.field(TargetType.TARGET_TYPE.getPreferredName(), targetType); } if (classificationLabels != null) { builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels); diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/Tree.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/Tree.java index 337bdf77ab70b..4c3314a11acfe 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/Tree.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/Tree.java @@ -30,7 +30,6 @@ public class Tree implements TrainedModel { public static final ParseField FEATURE_NAMES = new ParseField("feature_names"); public static final ParseField TREE_STRUCTURE = new ParseField("tree_structure"); - public static final ParseField TARGET_TYPE = new ParseField("target_type"); public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels"); private static final ObjectParser PARSER = new ObjectParser<>(NAME, true, Builder::new); @@ -38,7 +37,7 @@ 
public class Tree implements TrainedModel { static { PARSER.declareStringArray(Builder::setFeatureNames, FEATURE_NAMES); PARSER.declareObjectArray(Builder::setNodes, (p, c) -> TreeNode.fromXContent(p), TREE_STRUCTURE); - PARSER.declareString(Builder::setTargetType, TARGET_TYPE); + PARSER.declareString(Builder::setTargetType, TargetType.TARGET_TYPE); PARSER.declareStringArray(Builder::setClassificationLabels, CLASSIFICATION_LABELS); } @@ -94,7 +93,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels); } if (targetType != null) { - builder.field(TARGET_TYPE.getPreferredName(), targetType.toString()); + builder.field(TargetType.TARGET_TYPE.getPreferredName(), targetType.toString()); } builder.endObject(); return builder; diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index 25faa195657fc..c5344dacc6e1d 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -700,7 +700,7 @@ public void testDefaultNamedXContents() { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(77, namedXContents.size()); + assertEquals(78, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -710,7 +710,7 @@ public void testProvidedNamedXContents() { categories.put(namedXContent.categoryClass, counter + 1); } } - assertEquals("Had: " + categories, 15, categories.size()); + assertEquals("Had: " + categories, 16, categories.size()); assertEquals(Integer.valueOf(3), categories.get(Aggregation.class)); assertTrue(names.contains(ChildrenAggregationBuilder.NAME)); assertTrue(names.contains(MatrixStatsAggregationBuilder.NAME)); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index 07288adce7436..631bc24527151 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -171,6 +171,7 @@ import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests; import org.elasticsearch.client.ml.inference.TrainedModelInput; import org.elasticsearch.client.ml.inference.TrainedModelStats; +import org.elasticsearch.client.ml.inference.TrainedModelType; import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding; import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.client.ml.inference.trainedmodel.TargetType; @@ -3826,11 +3827,12 @@ public void testPutTrainedModel() throws Exception { .setDefinition(definition) // <1> .setCompressedDefinition(InferenceToXContentCompressor.deflate(definition)) // <2> .setModelId("my-new-trained-model") // <3> - .setInput(new TrainedModelInput("col1", "col2", "col3", "col4")) // <4> - .setDescription("test model") // <5> - .setMetadata(new HashMap<>()) // <6> - .setTags("my_regression_models") // <7> - 
.setInferenceConfig(new RegressionConfig("value", 0)) // <8> + .setModelType(TrainedModelType.TREE_ENSEMBLE) // <4> + .setInput(new TrainedModelInput("col1", "col2", "col3", "col4")) // <5> + .setDescription("test model") // <6> + .setMetadata(new HashMap<>()) // <7> + .setTags("my_regression_models") // <8> + .setInferenceConfig(new RegressionConfig("value", 0)) // <9> .build(); // end::put-trained-model-config diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/IndexLocationTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/IndexLocationTests.java new file mode 100644 index 0000000000000..ab19977d7ee2f --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/IndexLocationTests.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.client.ml.inference; + +import org.elasticsearch.client.ml.inference.trainedmodel.IndexLocation; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class IndexLocationTests extends AbstractXContentTestCase { + + static IndexLocation randomInstance() { + return new IndexLocation(randomAlphaOfLength(7), randomAlphaOfLength(7)); + } + + @Override + protected IndexLocation createTestInstance() { + return randomInstance(); + } + + @Override + protected IndexLocation doParseInstance(XContentParser parser) throws IOException { + return IndexLocation.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java index e3a0dd6f7b68f..3a26c7701411e 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java @@ -34,6 +34,7 @@ public static TrainedModelConfig createTestTrainedModelConfig() { TargetType targetType = randomFrom(TargetType.values()); return new TrainedModelConfig( randomAlphaOfLength(10), + randomBoolean() ? null : randomFrom(TrainedModelType.values()), randomAlphaOfLength(10), Version.CURRENT, randomBoolean() ? null : randomAlphaOfLength(100), @@ -53,7 +54,8 @@ public static TrainedModelConfig createTestTrainedModelConfig() { .collect(Collectors.toMap(Function.identity(), (k) -> randomAlphaOfLength(10))), targetType.equals(TargetType.CLASSIFICATION) ? ClassificationConfigTests.randomClassificationConfig() : - RegressionConfigTests.randomRegressionConfig()); + RegressionConfigTests.randomRegressionConfig(), + randomBoolean() ? 
null : IndexLocationTests.randomInstance()); } @Override diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/langident/LangIdentNeuralNetworkTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/langident/LangIdentNeuralNetworkTests.java index 6c35ad65febcb..13955bb13d0a9 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/langident/LangIdentNeuralNetworkTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/langident/LangIdentNeuralNetworkTests.java @@ -10,7 +10,7 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractXContentTestCase; -import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider; import java.io.IOException; diff --git a/docs/java-rest/high-level/ml/put-trained-model.asciidoc b/docs/java-rest/high-level/ml/put-trained-model.asciidoc index b28ae7be50284..c1104f7ebaa37 100644 --- a/docs/java-rest/high-level/ml/put-trained-model.asciidoc +++ b/docs/java-rest/high-level/ml/put-trained-model.asciidoc @@ -36,11 +36,12 @@ include-tagged::{doc-tests-file}[{api}-config] <2> Optionally, if the {infer} definition is large, you may choose to compress it for transport. Do not supply both the compressed and uncompressed definitions. <3> The unique model id -<4> The input field names for the model definition -<5> Optionally, a human-readable description -<6> Optionally, an object map contain metadata about the model -<7> Optionally, an array of tags to organize the model -<8> The default inference config to use with the model. Must match the underlying +<4> The type of model being configured. If not set, the type is inferred from the model definition +<5> The input field names for the model definition +<6> Optionally, a human-readable description +<7> Optionally, an object map containing metadata about the model +<8> Optionally, an array of tags to organize the model +<9> The default inference config to use with the model. Must match the underlying definition target_type. include::../execution.asciidoc[] diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/ml.infer_trained_model_deployment.json b/rest-api-spec/src/main/resources/rest-api-spec/api/ml.infer_trained_model_deployment.json new file mode 100644 index 0000000000000..dd157151abe66 --- /dev/null +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/ml.infer_trained_model_deployment.json @@ -0,0 +1,30 @@ +{ + "ml.infer_trained_model_deployment":{ + "documentation":{ + "url":"https://www.elastic.co/guide/en/elasticsearch/reference/current/ml-infer-trained-model-deployment.html", + "description":"Evaluate a trained model."
+ }, + "stability":"experimental", + "visibility":"public", + "headers":{ + "accept": [ "application/json"], + "content_type": ["application/json"] + }, + "url":{ + "paths":[ + { + "path":"/_ml/trained_models/{model_id}/deployment/_infer", + "methods":[ + "POST" + ], + "parts":{ + "model_id":{ + "type":"string", + "description":"The ID of the model to perform inference on" + } + } + } + ] + } + } +} diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/ml.start_trained_model_deployment.json b/rest-api-spec/src/main/resources/rest-api-spec/api/ml.start_trained_model_deployment.json new file mode 100644 index 0000000000000..d91159eeed208 --- /dev/null +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/ml.start_trained_model_deployment.json @@ -0,0 +1,30 @@ +{ + "ml.start_trained_model_deployment":{ + "documentation":{ + "url":"https://www.elastic.co/guide/en/elasticsearch/reference/current/ml-start-trained-model-deployment.html", + "description":"Start a trained model deployment." + }, + "stability":"experimental", + "visibility":"public", + "headers":{ + "accept": [ "application/json"], + "content_type": ["application/json"] + }, + "url":{ + "paths":[ + { + "path":"/_ml/trained_models/{model_id}/deployment/_start", + "methods":[ + "POST" + ], + "parts":{ + "model_id":{ + "type":"string", + "description":"The ID of the model to deploy" + } + } + } + ] + } + } +} diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/ml.stop_trained_model_deployment.json b/rest-api-spec/src/main/resources/rest-api-spec/api/ml.stop_trained_model_deployment.json new file mode 100644 index 0000000000000..fcc6f05899a0b --- /dev/null +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/ml.stop_trained_model_deployment.json @@ -0,0 +1,30 @@ +{ + "ml.stop_trained_model_deployment":{ + "documentation":{ + "url":"https://www.elastic.co/guide/en/elasticsearch/reference/current/stop-trained-model-deployment.html", + "description":"Stop a trained model deployment." 
+ }, + "stability":"experimental", + "visibility":"public", + "headers":{ + "accept": [ "application/json"], + "content_type": ["application/json"] + }, + "url":{ + "paths":[ + { + "path":"/_ml/trained_models/{model_id}/deployment/_stop", + "methods":[ + "POST" + ], + "parts":{ + "model_id":{ + "type":"string", + "description":"The ID of the model to undeploy" + } + } + } + ] + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java index c0259fc62b826..34bfd86078fa1 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java @@ -29,11 +29,13 @@ public final class MlTasks { public static final String DATAFEED_TASK_NAME = "xpack/ml/datafeed"; public static final String DATA_FRAME_ANALYTICS_TASK_NAME = "xpack/ml/data_frame/analytics"; public static final String JOB_SNAPSHOT_UPGRADE_TASK_NAME = "xpack/ml/job/snapshot/upgrade"; + public static final String TRAINED_MODEL_DEPLOYMENT_TASK_NAME = "xpack/ml/trained_model/deployment"; public static final String JOB_TASK_ID_PREFIX = "job-"; public static final String DATAFEED_TASK_ID_PREFIX = "datafeed-"; public static final String DATA_FRAME_ANALYTICS_TASK_ID_PREFIX = "data_frame_analytics-"; public static final String JOB_SNAPSHOT_UPGRADE_TASK_ID_PREFIX = "job-snapshot-upgrade-"; + public static final String TRAINED_MODEL_DEPLOYMENT_TASK_ID_PREFIX = "trained_model_deployment-"; public static final PersistentTasksCustomMetadata.Assignment AWAITING_UPGRADE = new PersistentTasksCustomMetadata.Assignment(null, @@ -91,6 +93,10 @@ public static String dataFrameAnalyticsId(String taskId) { return taskId.substring(DATA_FRAME_ANALYTICS_TASK_ID_PREFIX.length()); } + public static String trainedModelDeploymentTaskId(String modelId) { + return TRAINED_MODEL_DEPLOYMENT_TASK_ID_PREFIX + modelId; + } + @Nullable public static PersistentTasksCustomMetadata.PersistentTask getJobTask(String jobId, @Nullable PersistentTasksCustomMetadata tasks) { return tasks == null ? null : tasks.getTask(jobTaskId(jobId)); @@ -115,6 +121,12 @@ public static PersistentTasksCustomMetadata.PersistentTask getSnapshotUpgrade return tasks == null ? null : tasks.getTask(snapshotUpgradeTaskId(jobId, snapshotId)); } + @Nullable + public static PersistentTasksCustomMetadata.PersistentTask getTrainedModelDeploymentTask( + String modelId, @Nullable PersistentTasksCustomMetadata tasks) { + return tasks == null ? null : tasks.getTask(trainedModelDeploymentTaskId(modelId)); + } + /** * Note that the return value of this method does NOT take node relocations into account. 
* Use {@link #getJobStateModifiedForReassignments} to return a value adjusted to the most diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java index 4a4799d894275..0372fb9f8d328 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java @@ -119,7 +119,7 @@ public int hashCode() { public static class Request extends AbstractGetResourcesRequest { public static final ParseField INCLUDE = new ParseField("include"); - public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition"; + public static final String DEFINITION = "definition"; public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); public static final ParseField TAGS = new ParseField("tags"); @@ -138,6 +138,10 @@ public Request(String id, boolean includeModelDefinition, List tags) { } } + public Request(String id) { + this(id, null, null); + } + public Request(String id, List tags, Set includes) { setResourceId(id); setAllowNoResources(true); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentAction.java new file mode 100644 index 0000000000000..41d973d4b8a78 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentAction.java @@ -0,0 +1,143 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.tasks.BaseTasksRequest; +import org.elasticsearch.action.support.tasks.BaseTasksResponse; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; + +import java.io.IOException; +import java.util.Collections; +import java.util.Objects; + +public class InferTrainedModelDeploymentAction extends ActionType { + + public static final InferTrainedModelDeploymentAction INSTANCE = new InferTrainedModelDeploymentAction(); + + // TODO Review security level + public static final String NAME = "cluster:monitor/xpack/ml/trained_models/deployment/infer"; + + public InferTrainedModelDeploymentAction() { + super(NAME, InferTrainedModelDeploymentAction.Response::new); + } + + public static class Request extends BaseTasksRequest implements ToXContentObject { + + public static final String DEPLOYMENT_ID = "deployment_id"; + public static final ParseField INPUT = new ParseField("input"); + + private static final ObjectParser PARSER = new ObjectParser<>(NAME, Request::new); + static { + PARSER.declareString((request, inputs) -> request.input = inputs, INPUT); + } + + public static Request parseRequest(String deploymentId, XContentParser parser) { + Request r = PARSER.apply(parser, null); + r.deploymentId = deploymentId; + return r; + } + + private String deploymentId; + private String input; + + private Request() { + } + + public Request(String deploymentId, String input) { + this.deploymentId = Objects.requireNonNull(deploymentId); + this.input = Objects.requireNonNull(input); + } + + public Request(StreamInput in) throws IOException { + super(in); + deploymentId = in.readString(); + input = in.readString(); + } + + public String getDeploymentId() { + return deploymentId; + } + + public String getInput() { + return input; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(deploymentId); + out.writeString(input); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.field(DEPLOYMENT_ID, deploymentId); + builder.field(INPUT.getPreferredName(), input); + builder.endObject(); + return builder; + } + + @Override + public boolean match(Task task) { + return StartTrainedModelDeploymentAction.TaskMatcher.match(task, deploymentId); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + InferTrainedModelDeploymentAction.Request that = (InferTrainedModelDeploymentAction.Request) o; + return Objects.equals(deploymentId, that.deploymentId) + && Objects.equals(input, that.input); + } + + @Override + public int hashCode() { + return Objects.hash(deploymentId, input); + } + } + + public static class Response extends BaseTasksResponse implements Writeable, ToXContentObject { + 
+ private final InferenceResults results; + + public Response(InferenceResults result) { + super(Collections.emptyList(), Collections.emptyList()); + this.results = Objects.requireNonNull(result); + } + + public Response(StreamInput in) throws IOException { + super(in); + results = in.readNamedWriteable(InferenceResults.class); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + results.toXContent(builder, params); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeNamedWriteable(results); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java new file mode 100644 index 0000000000000..485e7f3f5dd8b --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java @@ -0,0 +1,197 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.Version; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.master.MasterNodeRequest; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.persistent.PersistentTaskParams; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Objects; +import java.util.concurrent.TimeUnit; + +public class StartTrainedModelDeploymentAction extends ActionType { + + public static final StartTrainedModelDeploymentAction INSTANCE = new StartTrainedModelDeploymentAction(); + public static final String NAME = "cluster:admin/xpack/ml/trained_models/deployment/start"; + + public static final TimeValue DEFAULT_TIMEOUT = new TimeValue(20, TimeUnit.SECONDS); + + public StartTrainedModelDeploymentAction() { + super(NAME, NodeAcknowledgedResponse::new); + } + + public static class Request extends MasterNodeRequest implements ToXContentObject { + + private static final ParseField MODEL_ID = new ParseField("model_id"); + private static final ParseField TIMEOUT = new ParseField("timeout"); + + private String modelId; + private TimeValue timeout = DEFAULT_TIMEOUT; + + public Request(String modelId) { + setModelId(modelId); + } + + public Request(StreamInput in) throws IOException { + super(in); + modelId = in.readString(); + timeout = in.readTimeValue(); + } + + public final void setModelId(String modelId) { + this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID); + } + + public String 
getModelId() { + return modelId; + } + + public void setTimeout(TimeValue timeout) { + this.timeout = ExceptionsHelper.requireNonNull(timeout, TIMEOUT); + } + + public TimeValue getTimeout() { + return timeout; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(modelId); + out.writeTimeValue(timeout); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(MODEL_ID.getPreferredName(), modelId); + builder.field(TIMEOUT.getPreferredName(), timeout.getStringRep()); + return builder; + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public int hashCode() { + return Objects.hash(modelId, timeout); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || obj.getClass() != getClass()) { + return false; + } + Request other = (Request) obj; + return Objects.equals(modelId, other.modelId) && Objects.equals(timeout, other.timeout); + } + + @Override + public String toString() { + return Strings.toString(this); + } + } + + public static class TaskParams implements PersistentTaskParams { + + public static final Version VERSION_INTRODUCED = Version.V_8_0_0; + + private final String modelId; + private final String index; + + public TaskParams(String modelId, String index) { + this.modelId = Objects.requireNonNull(modelId); + this.index = Objects.requireNonNull(index); + } + + public TaskParams(StreamInput in) throws IOException { + this.modelId = in.readString(); + this.index = in.readString(); + } + + public String getModelId() { + return modelId; + } + + public String getIndex() { + return index; + } + + @Override + public String getWriteableName() { + return MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME; + } + + @Override + public Version getMinimalSupportedVersion() { + return VERSION_INTRODUCED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + out.writeString(index); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId); + builder.field(IndexLocation.INDEX.getPreferredName(), index); + builder.endObject(); + return builder; + } + + @Override + public int hashCode() { + return Objects.hash(modelId); + } + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if (o == null || getClass() != o.getClass()) return false; + + TaskParams other = (TaskParams) o; + return Objects.equals(modelId, other.modelId); + } + } + + public interface TaskMatcher { + + static boolean match(Task task, String expectedId) { + if (task instanceof TaskMatcher) { + if (Strings.isAllOrWildcard(expectedId)) { + return true; + } + String expectedDescription = MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_ID_PREFIX + expectedId; + return expectedDescription.equals(task.getDescription()); + } + return false; + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopTrainedModelDeploymentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopTrainedModelDeploymentAction.java new file mode 100644 index 0000000000000..2fd52f5baa5d0 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopTrainedModelDeploymentAction.java @@ -0,0 +1,161 
@@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.tasks.BaseTasksRequest; +import org.elasticsearch.action.support.tasks.BaseTasksResponse; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Objects; + +public class StopTrainedModelDeploymentAction extends ActionType { + + public static final StopTrainedModelDeploymentAction INSTANCE = new StopTrainedModelDeploymentAction(); + public static final String NAME = "cluster:admin/xpack/ml/trained_models/deployment/stop"; + + public StopTrainedModelDeploymentAction() { + super(NAME, StopTrainedModelDeploymentAction.Response::new); + } + + public static class Request extends BaseTasksRequest implements ToXContentObject { + + public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); + public static final ParseField FORCE = new ParseField("force"); + + private String id; + private boolean allowNoMatch = true; + private boolean force; + + public Request(String id) { + setId(id); + } + + public Request(StreamInput in) throws IOException { + super(in); + id = in.readString(); + allowNoMatch = in.readBoolean(); + force = in.readBoolean(); + } + + public final void setId(String id) { + this.id = ExceptionsHelper.requireNonNull(id, TrainedModelConfig.MODEL_ID); + } + + public String getId() { + return id; + } + + public void setAllowNoMatch(boolean allowNoMatch) { + this.allowNoMatch = allowNoMatch; + } + + public boolean isAllowNoMatch() { + return allowNoMatch; + } + + public void setForce(boolean force) { + this.force = force; + } + + public boolean isForce() { + return force; + } + + @Override + public boolean match(Task task) { + return StartTrainedModelDeploymentAction.TaskMatcher.match(task, id); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(id); + out.writeBoolean(allowNoMatch); + out.writeBoolean(force); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), id); + builder.field(ALLOW_NO_MATCH.getPreferredName(), allowNoMatch); + builder.field(FORCE.getPreferredName(), force); + builder.endObject(); + return builder; + } + + @Override + public int hashCode() { + return Objects.hash(id, allowNoMatch, force); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + Request that = (Request) o; + return Objects.equals(id, that.id) && + allowNoMatch == that.allowNoMatch && + force == that.force; + } + } + + public static class Response extends BaseTasksResponse implements Writeable, 
ToXContentObject { + + private final boolean undeployed; + + public Response(boolean undeployed) { + super(null, null); + this.undeployed = undeployed; + } + + public Response(StreamInput in) throws IOException { + super(in); + undeployed = in.readBoolean(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeBoolean(undeployed); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + toXContentCommon(builder, params); + builder.field("stopped", undeployed); + builder.endObject(); + return builder; + } + + @Override + public int hashCode() { + return Objects.hash(undeployed); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Response that = (Response) o; + return undeployed == that.undeployed; + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressor.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressor.java index 70f4998a7f96c..b7f7532668b77 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressor.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressor.java @@ -11,7 +11,6 @@ import org.elasticsearch.common.CheckedFunction; import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.breaker.CircuitBreakingException; -import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.unit.ByteSizeValue; @@ -29,8 +28,6 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; -import java.nio.charset.StandardCharsets; -import java.util.Base64; import java.util.Map; import java.util.zip.GZIPInputStream; import java.util.zip.GZIPOutputStream; @@ -47,24 +44,24 @@ public final class InferenceToXContentCompressor { private InferenceToXContentCompressor() {} - public static String deflate(T objectToCompress) throws IOException { + public static BytesReference deflate(T objectToCompress) throws IOException { BytesReference reference = XContentHelper.toXContent(objectToCompress, XContentType.JSON, false); return deflate(reference); } - public static T inflate(String compressedString, + public static T inflate(BytesReference compressedBytes, CheckedFunction parserFunction, NamedXContentRegistry xContentRegistry) throws IOException { - return inflate(compressedString, parserFunction, xContentRegistry, MAX_INFLATED_BYTES); + return inflate(compressedBytes, parserFunction, xContentRegistry, MAX_INFLATED_BYTES); } - static T inflate(String compressedString, + static T inflate(BytesReference compressedBytes, CheckedFunction parserFunction, NamedXContentRegistry xContentRegistry, long maxBytes) throws IOException { try(XContentParser parser = JsonXContent.jsonXContent.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, - inflate(compressedString, maxBytes))) { + inflate(compressedBytes, maxBytes))) { return parserFunction.apply(parser); } catch (XContentParseException parseException) { SimpleBoundedInputStream.StreamSizeExceededException streamSizeCause = @@ -82,32 +79,31 @@ static T inflate(String compressedString, } } - 
static Map inflateToMap(String compressedString) throws IOException { + static Map inflateToMap(BytesReference compressedBytes) throws IOException { // Don't need the xcontent registry as we are not deflating named objects. try(XContentParser parser = JsonXContent.jsonXContent.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, - inflate(compressedString, MAX_INFLATED_BYTES))) { + inflate(compressedBytes, MAX_INFLATED_BYTES))) { return parser.mapOrdered(); } } - static InputStream inflate(String compressedString, long streamSize) throws IOException { - byte[] compressedBytes = Base64.getDecoder().decode(compressedString.getBytes(StandardCharsets.UTF_8)); + static InputStream inflate(BytesReference compressedBytes, long streamSize) throws IOException { // If the compressed length is already too large, it make sense that the inflated length would be as well // In the extremely small string case, the compressed data could actually be longer than the compressed stream - if (compressedBytes.length > Math.max(100L, streamSize)) { + if (compressedBytes.length() > Math.max(100L, streamSize)) { throw new CircuitBreakingException("compressed stream is longer than maximum allowed bytes [" + streamSize + "]", CircuitBreaker.Durability.PERMANENT); } - InputStream gzipStream = new GZIPInputStream(new BytesArray(compressedBytes).streamInput(), BUFFER_SIZE); + InputStream gzipStream = new GZIPInputStream(compressedBytes.streamInput(), BUFFER_SIZE); return new SimpleBoundedInputStream(gzipStream, streamSize); } - private static String deflate(BytesReference reference) throws IOException { + private static BytesReference deflate(BytesReference reference) throws IOException { BytesStreamOutput out = new BytesStreamOutput(); try (OutputStream compressedOutput = new GZIPOutputStream(out, BUFFER_SIZE)) { reference.writeTo(compressedOutput); } - return new String(Base64.getEncoder().encode(BytesReference.toBytes(out.bytes())), StandardCharsets.UTF_8); + return out.bytes(); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java index e186e8885d351..2eb91e140ba0e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java @@ -19,22 +19,28 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor; import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.NerResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation; import 
org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedInferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModelLocation; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModelLocation; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelLocation; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Exponent; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LenientlyParsedOutputAggregator; @@ -128,6 +134,16 @@ public List getNamedXContentParsers() { Exponent.NAME, Exponent::fromXContentStrict)); + // Location lenient + namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedTrainedModelLocation.class, + IndexLocation.INDEX, + IndexLocation::fromXContentLenient)); + + // Location strict + namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedTrainedModelLocation.class, + IndexLocation.INDEX, + IndexLocation::fromXContentStrict)); + // Inference Configs namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedInferenceConfig.class, ClassificationConfig.NAME, ClassificationConfig::fromXContentLenient)); @@ -173,7 +189,7 @@ public List getNamedWriteables() { // Model namedWriteables.add(new NamedWriteableRegistry.Entry(TrainedModel.class, Tree.NAME.getPreferredName(), Tree::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(TrainedModel.class, Ensemble.NAME.getPreferredName(), Ensemble::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(LangIdentNeuralNetwork.class, + namedWriteables.add(new NamedWriteableRegistry.Entry(TrainedModel.class, LangIdentNeuralNetwork.NAME.getPreferredName(), LangIdentNeuralNetwork::new)); @@ -201,6 +217,12 @@ public List getNamedWriteables() { namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, WarningInferenceResults.NAME, WarningInferenceResults::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, + NerResults.NAME, + NerResults::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, + FillMaskResults.NAME, + FillMaskResults::new)); // Inference Configs namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class, @@ -217,6 +239,10 @@ public List getNamedWriteables() { namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfigUpdate.class, EmptyConfigUpdate.NAME, EmptyConfigUpdate::new)); + // Location + namedWriteables.add(new NamedWriteableRegistry.Entry(TrainedModelLocation.class, + IndexLocation.INDEX.getPreferredName(), IndexLocation::new)); + return namedWriteables; } } diff --git 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index 3248e3f8bc2ad..19fefacc8a350 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -12,6 +12,8 @@ import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; @@ -26,7 +28,10 @@ import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedInferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModelLocation; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModelLocation; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelLocation; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.FeatureImportanceBaseline; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.Hyperparameters; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportance; @@ -36,8 +41,11 @@ import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.Arrays; +import java.util.Base64; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -64,6 +72,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage"; public static final ParseField MODEL_ID = new ParseField("model_id"); + public static final ParseField MODEL_TYPE = new ParseField("model_type"); public static final ParseField CREATED_BY = new ParseField("created_by"); public static final ParseField VERSION = new ParseField("version"); public static final ParseField DESCRIPTION = new ParseField("description"); @@ -78,6 +87,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { public static final ParseField LICENSE_LEVEL = new ParseField("license_level"); public static final ParseField DEFAULT_FIELD_MAP = new ParseField("default_field_map"); public static final ParseField INFERENCE_CONFIG = new ParseField("inference_config"); + public static final ParseField LOCATION = new ParseField("location"); + + public static final Version VERSION_3RD_PARTY_CONFIG_ADDED = Version.V_8_0_0; // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly public static final ObjectParser LENIENT_PARSER = createParser(true); @@ -88,6 +100,7 @@ private static ObjectParser createParser(boole ignoreUnknownFields, TrainedModelConfig.Builder::new); 
parser.declareString(TrainedModelConfig.Builder::setModelId, MODEL_ID); + parser.declareString(TrainedModelConfig.Builder::setModelType, MODEL_TYPE); parser.declareString(TrainedModelConfig.Builder::setCreatedBy, CREATED_BY); parser.declareString(TrainedModelConfig.Builder::setVersion, VERSION); parser.declareString(TrainedModelConfig.Builder::setDescription, DESCRIPTION); @@ -113,6 +126,11 @@ private static ObjectParser createParser(boole p.namedObject(LenientlyParsedInferenceConfig.class, n, null) : p.namedObject(StrictlyParsedInferenceConfig.class, n, null), INFERENCE_CONFIG); + parser.declareNamedObject(TrainedModelConfig.Builder::setLocation, + (p, c, n) -> ignoreUnknownFields ? + p.namedObject(LenientlyParsedTrainedModelLocation.class, n, null) : + p.namedObject(StrictlyParsedTrainedModelLocation.class, n, null), + LOCATION); return parser; } @@ -125,6 +143,7 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo private final Version version; private final String description; private final Instant createTime; + private final TrainedModelType modelType; private final List tags; private final Map metadata; private final TrainedModelInput input; @@ -135,8 +154,10 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo private final InferenceConfig inferenceConfig; private final LazyModelDefinition definition; + private final TrainedModelLocation location; TrainedModelConfig(String modelId, + TrainedModelType modelType, String createdBy, Version version, String description, @@ -149,8 +170,10 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo Long estimatedOperations, String licenseLevel, Map defaultFieldMap, - InferenceConfig inferenceConfig) { + InferenceConfig inferenceConfig, + TrainedModelLocation location) { this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID); + this.modelType = modelType; this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY); this.version = ExceptionsHelper.requireNonNull(version, VERSION); this.createTime = Instant.ofEpochMilli(ExceptionsHelper.requireNonNull(createTime, CREATE_TIME).toEpochMilli()); @@ -171,6 +194,7 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo this.licenseLevel = License.OperationMode.parse(ExceptionsHelper.requireNonNull(licenseLevel, LICENSE_LEVEL)); this.defaultFieldMap = defaultFieldMap == null ? 
null : Collections.unmodifiableMap(defaultFieldMap); this.inferenceConfig = inferenceConfig; + this.location = location; } public TrainedModelConfig(StreamInput in) throws IOException { @@ -191,12 +215,24 @@ public TrainedModelConfig(StreamInput in) throws IOException { null; this.inferenceConfig = in.readOptionalNamedWriteable(InferenceConfig.class); + if (in.getVersion().onOrAfter(VERSION_3RD_PARTY_CONFIG_ADDED)) { + this.modelType = in.readOptionalEnum(TrainedModelType.class); + this.location = in.readOptionalNamedWriteable(TrainedModelLocation.class); + } else { + this.modelType = null; + this.location = null; + } } public String getModelId() { return modelId; } + @Nullable + public TrainedModelType getModelType() { + return this.modelType; + } + public String getCreatedBy() { return createdBy; } @@ -231,15 +267,15 @@ public InferenceConfig getInferenceConfig() { } @Nullable - public String getCompressedDefinition() throws IOException { + public BytesReference getCompressedDefinition() throws IOException { if (definition == null) { return null; } - return definition.getCompressedString(); + return definition.getCompressedDefinition(); } public void clearCompressed() { - definition.compressedString = null; + definition.compressedRepresentation = null; } public TrainedModelConfig ensureParsedDefinition(NamedXContentRegistry xContentRegistry) throws IOException { @@ -258,6 +294,11 @@ public TrainedModelDefinition getModelDefinition() { return definition.parsedDefinition; } + @Nullable + public TrainedModelLocation getLocation() { + return location; + } + public TrainedModelInput getInput() { return input; } @@ -299,12 +340,19 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } out.writeOptionalNamedWriteable(inferenceConfig); + if (out.getVersion().onOrAfter(VERSION_3RD_PARTY_CONFIG_ADDED)) { + out.writeOptionalEnum(modelType); + out.writeOptionalNamedWriteable(location); + } } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field(MODEL_ID.getPreferredName(), modelId); + if (modelType != null) { + builder.field(MODEL_TYPE.getPreferredName(), modelType.toString()); + } // If the model is to be exported for future import to another cluster, these fields are irrelevant. 
if (params.paramAsBoolean(EXCLUDE_GENERATED, false) == false) { builder.field(CREATED_BY.getPreferredName(), createdBy); @@ -325,7 +373,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (params.paramAsBoolean(DECOMPRESS_DEFINITION, false)) { builder.field(DEFINITION.getPreferredName(), definition); } else { - builder.field(COMPRESSED_DEFINITION.getPreferredName(), definition.getCompressedString()); + builder.field(COMPRESSED_DEFINITION.getPreferredName(), definition.getBase64CompressedDefinition()); } } builder.field(TAGS.getPreferredName(), tags); @@ -342,6 +390,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (inferenceConfig != null) { writeNamedObject(builder, params, INFERENCE_CONFIG.getPreferredName(), inferenceConfig); } + if (location != null) { + writeNamedObject(builder, params, LOCATION.getPreferredName(), location); + } builder.endObject(); return builder; } @@ -357,6 +408,7 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; TrainedModelConfig that = (TrainedModelConfig) o; return Objects.equals(modelId, that.modelId) && + Objects.equals(modelType, that.modelType) && Objects.equals(createdBy, that.createdBy) && Objects.equals(version, that.version) && Objects.equals(description, that.description) && @@ -369,12 +421,14 @@ public boolean equals(Object o) { Objects.equals(licenseLevel, that.licenseLevel) && Objects.equals(defaultFieldMap, that.defaultFieldMap) && Objects.equals(inferenceConfig, that.inferenceConfig) && - Objects.equals(metadata, that.metadata); + Objects.equals(metadata, that.metadata) && + Objects.equals(location, that.location); } @Override public int hashCode() { return Objects.hash(modelId, + modelType, createdBy, version, createTime, @@ -387,12 +441,14 @@ public int hashCode() { input, licenseLevel, inferenceConfig, - defaultFieldMap); + defaultFieldMap, + location); } public static class Builder { private String modelId; + private TrainedModelType modelType; private String createdBy; private Version version; private String description; @@ -406,11 +462,13 @@ public static class Builder { private String licenseLevel; private Map defaultFieldMap; private InferenceConfig inferenceConfig; + private TrainedModelLocation location; public Builder() {} public Builder(TrainedModelConfig config) { this.modelId = config.getModelId(); + this.modelType = config.getModelType(); this.createdBy = config.getCreatedBy(); this.version = config.getVersion(); this.createTime = config.getCreateTime(); @@ -424,6 +482,7 @@ public Builder(TrainedModelConfig config) { this.licenseLevel = config.licenseLevel.description(); this.defaultFieldMap = config.defaultFieldMap == null ? 
null : new HashMap<>(config.defaultFieldMap); this.inferenceConfig = config.inferenceConfig; + this.location = config.location; } public Builder setModelId(String modelId) { @@ -431,6 +490,20 @@ public Builder setModelId(String modelId) { return this; } + public TrainedModelType getModelType() { + return modelType; + } + + private Builder setModelType(String modelType) { + this.modelType = TrainedModelType.fromString(modelType); + return this; + } + + public Builder setModelType(TrainedModelType modelType) { + this.modelType = modelType; + return this; + } + public String getModelId() { return this.modelId; } @@ -440,6 +513,10 @@ public Builder setCreatedBy(String createdBy) { return this; } + public Version getVersion() { + return version; + } + public Builder setVersion(Version version) { this.version = version; return this; @@ -519,11 +596,11 @@ public Builder setParsedDefinition(TrainedModelDefinition.Builder definition) { return this; } - public Builder setDefinitionFromString(String definitionFromString) { - if (definitionFromString == null) { + public Builder setDefinitionFromBytes(BytesReference definition) { + if (definition == null) { return this; } - this.definition = LazyModelDefinition.fromCompressedString(definitionFromString); + this.definition = LazyModelDefinition.fromCompressedData(definition); return this; } @@ -560,7 +637,12 @@ private Builder setLazyDefinition(String compressedString) { DEFINITION.getPreferredName()) .getFormattedMessage()); } - this.definition = LazyModelDefinition.fromCompressedString(compressedString); + this.definition = LazyModelDefinition.fromBase64String(compressedString); + return this; + } + + public Builder setLocation(TrainedModelLocation location) { + this.location = location; return this; } @@ -606,8 +688,17 @@ public Builder validate() { public Builder validate(boolean forCreation) { // We require a definition to be available here even though it will be stored in a different doc ActionRequestValidationException validationException = null; - if (definition == null) { - validationException = addValidationError("[" + DEFINITION.getPreferredName() + "] must not be null.", validationException); + if (definition == null && location == null) { + validationException = addValidationError("either a model [" + DEFINITION.getPreferredName() + "] " + + "or [" + LOCATION.getPreferredName() + "] must be defined.", validationException); + } + if (definition != null && location != null) { + validationException = addValidationError("[" + DEFINITION.getPreferredName() + "] " + + "and [" + LOCATION.getPreferredName() + "] are both defined but only one can be used.", validationException); + } + if (definition == null && modelType == null) { + validationException = addValidationError("[" + MODEL_TYPE.getPreferredName() + "] must be set if " + + "[" + DEFINITION.getPreferredName() + "] is not defined.", validationException); } if (modelId == null) { validationException = addValidationError("[" + MODEL_ID.getPreferredName() + "] must not be null.", validationException); @@ -698,6 +789,7 @@ private static ActionRequestValidationException checkIllegalSetting(Object value public TrainedModelConfig build() { return new TrainedModelConfig( modelId, + modelType, createdBy == null ? "user" : createdBy, version == null ? Version.CURRENT : version, description, @@ -710,60 +802,83 @@ public TrainedModelConfig build() { estimatedOperations == null ? 0 : estimatedOperations, licenseLevel == null ? 
License.OperationMode.PLATINUM.description() : licenseLevel, defaultFieldMap, - inferenceConfig); + inferenceConfig, + location); } } - public static class LazyModelDefinition implements ToXContentObject, Writeable { + static class LazyModelDefinition implements ToXContentObject, Writeable { - private String compressedString; + private BytesReference compressedRepresentation; private TrainedModelDefinition parsedDefinition; public static LazyModelDefinition fromParsedDefinition(TrainedModelDefinition definition) { return new LazyModelDefinition(null, definition); } - public static LazyModelDefinition fromCompressedString(String compressedString) { - return new LazyModelDefinition(compressedString, null); + public static LazyModelDefinition fromCompressedData(BytesReference compressed) { + return new LazyModelDefinition(compressed, null); + } + + public static LazyModelDefinition fromBase64String(String base64String) { + byte[] decodedBytes = Base64.getDecoder().decode(base64String); + return new LazyModelDefinition(new BytesArray(decodedBytes), null); } public static LazyModelDefinition fromStreamInput(StreamInput input) throws IOException { - return new LazyModelDefinition(input.readString(), null); + if (input.getVersion().onOrAfter(Version.V_8_0_0)) { // TODO adjust on backport + return new LazyModelDefinition(input.readBytesReference(), null); + } else { + return fromBase64String(input.readString()); + } } private LazyModelDefinition(LazyModelDefinition definition) { if (definition != null) { - this.compressedString = definition.compressedString; + this.compressedRepresentation = definition.compressedRepresentation; this.parsedDefinition = definition.parsedDefinition; } } - private LazyModelDefinition(String compressedString, TrainedModelDefinition trainedModelDefinition) { - if (compressedString == null && trainedModelDefinition == null) { + private LazyModelDefinition(BytesReference compressedRepresentation, TrainedModelDefinition trainedModelDefinition) { + if (compressedRepresentation == null && trainedModelDefinition == null) { throw new IllegalArgumentException("unexpected null model definition"); } - this.compressedString = compressedString; + this.compressedRepresentation = compressedRepresentation; this.parsedDefinition = trainedModelDefinition; } - public void ensureParsedDefinition(NamedXContentRegistry xContentRegistry) throws IOException { - if (parsedDefinition == null) { - parsedDefinition = InferenceToXContentCompressor.inflate(compressedString, - parser -> TrainedModelDefinition.fromXContent(parser, true).build(), - xContentRegistry); + private BytesReference getCompressedDefinition() throws IOException { + if (compressedRepresentation == null) { + compressedRepresentation = InferenceToXContentCompressor.deflate(parsedDefinition); } + return compressedRepresentation; } - public String getCompressedString() throws IOException { - if (compressedString == null) { - compressedString = InferenceToXContentCompressor.deflate(parsedDefinition); + private String getBase64CompressedDefinition() throws IOException { + BytesReference compressedDef = getCompressedDefinition(); + + ByteBuffer bb = Base64.getEncoder().encode( + ByteBuffer.wrap(compressedDef.array(), compressedDef.arrayOffset(), compressedDef.length())); + + return new String(bb.array(), StandardCharsets.UTF_8); + } + + private void ensureParsedDefinition(NamedXContentRegistry xContentRegistry) throws IOException { + if (parsedDefinition == null) { + parsedDefinition = 
InferenceToXContentCompressor.inflate(compressedRepresentation, + parser -> TrainedModelDefinition.fromXContent(parser, true).build(), + xContentRegistry); } - return compressedString; } @Override public void writeTo(StreamOutput out) throws IOException { - out.writeString(getCompressedString()); + if (out.getVersion().onOrAfter(Version.V_8_0_0)) { // TODO adjust on backport + out.writeBytesReference(getCompressedDefinition()); + } else { + out.writeString(getBase64CompressedDefinition()); + } } @Override @@ -771,7 +886,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (parsedDefinition != null) { return parsedDefinition.toXContent(builder, params); } - Map map = InferenceToXContentCompressor.inflateToMap(compressedString); + Map map = InferenceToXContentCompressor.inflateToMap(compressedRepresentation); return builder.map(map); } @@ -780,15 +895,13 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; LazyModelDefinition that = (LazyModelDefinition) o; - return Objects.equals(compressedString, that.compressedString) && + return Objects.equals(compressedRepresentation, that.compressedRepresentation) && Objects.equals(parsedDefinition, that.parsedDefinition); } @Override public int hashCode() { - return Objects.hash(compressedString, parsedDefinition); + return Objects.hash(compressedRepresentation, parsedDefinition); } - } - } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelType.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelType.java new file mode 100644 index 0000000000000..087811ffaf1fb --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelType.java @@ -0,0 +1,56 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference; + +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; + +import java.util.Locale; + +public enum TrainedModelType { + + TREE_ENSEMBLE, + LANG_IDENT, + PYTORCH { + @Override + public boolean hasInferenceDefinition() { + return false; + } + }; + + public static TrainedModelType fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } + + /** + * Introspect the given model and return the model type + * representing it. 
+ * @param model A Trained model + * @return The model type or null if unknown + */ + public static TrainedModelType typeFromTrainedModel(TrainedModel model) { + if (model instanceof Ensemble || model instanceof Tree) { + return TrainedModelType.TREE_ENSEMBLE; + } else if (model instanceof LangIdentNeuralNetwork) { + return TrainedModelType.LANG_IDENT; + } else { + return null; + } + } + + public boolean hasInferenceDefinition() { + return true; + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/PyTorchResult.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/PyTorchResult.java new file mode 100644 index 0000000000000..5cc0cf0c4d4e2 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/PyTorchResult.java @@ -0,0 +1,137 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.deployment; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.utils.MlParserUtils; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +/* + * TODO This does not necessarily belong in core. Will have to reconsider + * once we figure the format we store inference results in client calls. 
+*/ +public class PyTorchResult implements ToXContentObject, Writeable { + + private static final ParseField REQUEST_ID = new ParseField("request_id"); + private static final ParseField INFERENCE = new ParseField("inference"); + private static final ParseField ERROR = new ParseField("error"); + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("pytorch_result", + a -> new PyTorchResult((String) a[0], (double[][]) a[1], (String) a[2])); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), REQUEST_ID); + PARSER.declareField(ConstructingObjectParser.optionalConstructorArg(), + (p, c) -> { + List> listOfListOfDoubles = MlParserUtils.parseArrayOfArrays( + INFERENCE.getPreferredName(), XContentParser::doubleValue, p); + double[][] primitiveDoubles = new double[listOfListOfDoubles.size()][]; + for (int i = 0; i < listOfListOfDoubles.size(); i++) { + List row = listOfListOfDoubles.get(i); + primitiveDoubles[i] = row.stream().mapToDouble(d -> d).toArray(); + } + return primitiveDoubles; + }, + INFERENCE, + ObjectParser.ValueType.VALUE_ARRAY + ); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), ERROR); + } + + private final String requestId; + private final double[][] inference; + private final String error; + + public PyTorchResult(String requestId, @Nullable double[][] inference, @Nullable String error) { + this.requestId = Objects.requireNonNull(requestId); + this.inference = inference; + this.error = error; + } + + public PyTorchResult(StreamInput in) throws IOException { + requestId = in.readString(); + boolean hasInference = in.readBoolean(); + if (hasInference) { + inference = in.readArray(StreamInput::readDoubleArray, length -> new double[length][]); + } else { + inference = null; + } + error = in.readOptionalString(); + } + + public String getRequestId() { + return requestId; + } + + public boolean isError() { + return error != null; + } + + public String getError() { + return error; + } + + public double[][] getInferenceResult() { + return inference; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(REQUEST_ID.getPreferredName(), requestId); + if (inference != null) { + builder.field(INFERENCE.getPreferredName(), inference); + } + if (error != null) { + builder.field(ERROR.getPreferredName(), error); + } + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(requestId); + if (inference == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + out.writeArray(StreamOutput::writeDoubleArray, inference); + } + out.writeOptionalString(error); + } + + @Override + public int hashCode() { + return Objects.hash(requestId, Arrays.hashCode(inference), error); + } + + @Override + public boolean equals(Object other) { + if (this == other) return true; + if (other == null || getClass() != other.getClass()) return false; + + PyTorchResult that = (PyTorchResult) other; + return Objects.equals(requestId, that.requestId) + && Objects.equals(inference, that.inference) + && Objects.equals(error, that.error); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/TrainedModelDeploymentState.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/TrainedModelDeploymentState.java new file mode 100644 index 0000000000000..b63b903809e3d 
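
The PyTorchResult parser above reads the "inference" field as an array of arrays of doubles and then copies it into a primitive double[][]. A self-contained sketch of that conversion follows; the class and method names are illustrative only.

import java.util.Arrays;
import java.util.List;

// Mirrors, in plain Java, how the PyTorchResult parser turns the parsed
// List<List<Double>> into the primitive double[][] it stores.
public class InferenceArrayConversion {

    static double[][] toPrimitive(List<List<Double>> rows) {
        double[][] result = new double[rows.size()][];
        for (int i = 0; i < rows.size(); i++) {
            result[i] = rows.get(i).stream().mapToDouble(Double::doubleValue).toArray();
        }
        return result;
    }

    public static void main(String[] args) {
        List<List<Double>> parsed = List.of(List.of(0.1, 0.9), List.of(0.8, 0.2));
        System.out.println(Arrays.deepToString(toPrimitive(parsed)));
        // prints: [[0.1, 0.9], [0.8, 0.2]]
    }
}
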
--- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/TrainedModelDeploymentState.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.deployment; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; + +import java.io.IOException; +import java.util.Locale; + +public enum TrainedModelDeploymentState implements Writeable { + + STARTING, STARTED, STOPPING, STOPPED; + + public static TrainedModelDeploymentState fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } + + public static TrainedModelDeploymentState fromStream(StreamInput in) throws IOException { + return in.readEnum(TrainedModelDeploymentState.class); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeEnum(this); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/TrainedModelDeploymentTaskState.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/TrainedModelDeploymentTaskState.java new file mode 100644 index 0000000000000..29641b6b5512b --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/TrainedModelDeploymentTaskState.java @@ -0,0 +1,112 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.inference.deployment; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.persistent.PersistentTaskState; +import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; + +import java.io.IOException; +import java.util.Objects; + +public class TrainedModelDeploymentTaskState implements PersistentTaskState { + + public static final String NAME = MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME; + + private static ParseField STATE = new ParseField("state"); + private static ParseField ALLOCATION_ID = new ParseField("allocation_id"); + private static ParseField REASON = new ParseField("reason"); + + private final TrainedModelDeploymentState state; + private final long allocationId; + private final String reason; + + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>(NAME, true, + a -> new TrainedModelDeploymentTaskState((TrainedModelDeploymentState) a[0], (long) a[1], (String) a[2])); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), DataFrameAnalyticsState::fromString, STATE); + PARSER.declareLong(ConstructingObjectParser.constructorArg(), ALLOCATION_ID); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), REASON); + } + + public static TrainedModelDeploymentTaskState fromXContent(XContentParser parser) { + try { + return PARSER.parse(parser, null); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public TrainedModelDeploymentTaskState(TrainedModelDeploymentState state, long allocationId, @Nullable String reason) { + this.state = Objects.requireNonNull(state); + this.allocationId = allocationId; + this.reason = reason; + } + + public TrainedModelDeploymentTaskState(StreamInput in) throws IOException { + this.state = TrainedModelDeploymentState.fromStream(in); + this.allocationId = in.readLong(); + this.reason = in.readOptionalString(); + } + + public TrainedModelDeploymentState getState() { + return state; + } + + public String getReason() { + return reason; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(STATE.getPreferredName(), state.toString()); + builder.field(ALLOCATION_ID.getPreferredName(), allocationId); + if (reason != null) { + builder.field(REASON.getPreferredName(), reason); + } + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + state.writeTo(out); + out.writeLong(allocationId); + out.writeOptionalString(reason); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TrainedModelDeploymentTaskState that = (TrainedModelDeploymentTaskState) o; + return allocationId == that.allocationId && + state == that.state && + Objects.equals(reason, that.reason); + } + + @Override + public int hashCode() { + return Objects.hash(state, allocationId, reason); + } +} diff --git 
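
Assuming the TrainedModelDeploymentTaskState constructor shown above, a persistent task state can be built and rendered to JSON roughly as in this sketch; the demo class name is illustrative and the expected output is approximate.

import org.elasticsearch.common.Strings;
import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentState;
import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentTaskState;

public class DeploymentTaskStateDemo {
    public static void main(String[] args) {
        // Build the task state recorded when the deployment transitions to STARTED.
        TrainedModelDeploymentTaskState taskState =
            new TrainedModelDeploymentTaskState(TrainedModelDeploymentState.STARTED, 1L, null);

        // Serializes roughly to: {"state":"started","allocation_id":1}
        // ("reason" is only emitted when non-null).
        System.out.println(Strings.toString(taskState));
    }
}
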
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java index 5f93c6d7b7cfc..f95cf84df99cd 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java @@ -26,8 +26,11 @@ public final class InferenceIndexConstants { * * version: 7.10.0: 000003 * - adds trained_model_metadata object + * + * version: 8.0.0: 000004 + * - adds binary_definition for TrainedModelDefinitionDoc */ - public static final String INDEX_VERSION = "000003"; + public static final String INDEX_VERSION = "000004"; public static final String INDEX_NAME_PREFIX = ".ml-inference-"; public static final String INDEX_PATTERN = INDEX_NAME_PREFIX + "*"; public static final String LATEST_INDEX_NAME = INDEX_NAME_PREFIX + INDEX_VERSION; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java index 3e42b397f1089..1ddedec7cb7c0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java @@ -8,7 +8,6 @@ package org.elasticsearch.xpack.core.ml.inference.preprocessing; import org.apache.lucene.util.RamUsageEstimator; -import org.elasticsearch.common.CheckedFunction; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -22,6 +21,7 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.NGramFeatureExtractor; import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.RelevantScriptFeatureExtractor; import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptFeatureExtractor; +import org.elasticsearch.xpack.core.ml.utils.MlParserUtils; import java.io.IOException; import java.util.ArrayList; @@ -63,7 +63,7 @@ private static ConstructingObjectParser { - List> listOfListOfShorts = parseArrays(EMBEDDING_QUANT_SCALES.getPreferredName(), + List> listOfListOfShorts = MlParserUtils.parseArrayOfArrays(EMBEDDING_QUANT_SCALES.getPreferredName(), XContentParser::shortValue, p); short[][] primitiveShorts = new short[listOfListOfShorts.size()][]; @@ -99,30 +99,6 @@ private static ConstructingObjectParser List> parseArrays(String fieldName, - CheckedFunction fromParser, - XContentParser p) throws IOException { - if (p.currentToken() != XContentParser.Token.START_ARRAY) { - throw new IllegalArgumentException("unexpected token [" + p.currentToken() + "] for [" + fieldName + "]"); - } - List> values = new ArrayList<>(); - while(p.nextToken() != XContentParser.Token.END_ARRAY) { - if (p.currentToken() != XContentParser.Token.START_ARRAY) { - throw new IllegalArgumentException("unexpected token [" + p.currentToken() + "] for [" + fieldName + "]"); - } - List innerList = new ArrayList<>(); - while(p.nextToken() != XContentParser.Token.END_ARRAY) { - if(p.currentToken().isValue() == false) { - throw new IllegalStateException("expected non-null value 
but got [" + p.currentToken() + "] " + - "for [" + fieldName + "]"); - } - innerList.add(fromParser.apply(p)); - } - values.add(innerList); - } - return values; - } - public static CustomWordEmbedding fromXContentStrict(XContentParser parser) { return STRICT_PARSER.apply(parser, PreProcessorParseContext.DEFAULT); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResults.java new file mode 100644 index 0000000000000..b89defe347977 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResults.java @@ -0,0 +1,165 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +public class FillMaskResults implements InferenceResults { + + public static final String NAME = "fill_mask_result"; + public static final String DEFAULT_RESULTS_FIELD = "results"; + + private final List predictions; + + public FillMaskResults(List predictions) { + this.predictions = predictions; + } + + public FillMaskResults(StreamInput in) throws IOException { + this.predictions = in.readList(Prediction::new); + } + + public List getPredictions() { + return predictions; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startArray(); + for (Prediction prediction : predictions) { + prediction.toXContent(builder, params); + } + builder.endArray(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeList(predictions); + } + + @Override + public Map asMap() { + Map map = new LinkedHashMap<>(); + map.put(DEFAULT_RESULTS_FIELD, predictions.stream().map(Prediction::toMap).collect(Collectors.toList())); + return map; + } + + @Override + public Object predictedValue() { + if (predictions.isEmpty()) { + return null; + } + return predictions.get(0).token; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + FillMaskResults that = (FillMaskResults) o; + return Objects.equals(predictions, that.predictions); + } + + @Override + public int hashCode() { + return Objects.hash(predictions); + } + + public static class Prediction implements ToXContentObject, Writeable { + + private static final ParseField TOKEN = new ParseField("token"); + private static final ParseField SCORE = new ParseField("score"); + private static final ParseField SEQUENCE = new ParseField("sequence"); + + private final String token; + private final double score; + private final String sequence; + + 
public Prediction(String token, double score, String sequence) { + this.token = Objects.requireNonNull(token); + this.score = score; + this.sequence = Objects.requireNonNull(sequence); + } + + public Prediction(StreamInput in) throws IOException { + token = in.readString(); + score = in.readDouble(); + sequence = in.readString(); + } + + public double getScore() { + return score; + } + + public String getSequence() { + return sequence; + } + + public String getToken() { + return token; + } + + public Map toMap() { + Map map = new LinkedHashMap<>(); + map.put(TOKEN.getPreferredName(), token); + map.put(SCORE.getPreferredName(), score); + map.put(SEQUENCE.getPreferredName(), sequence); + return map; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TOKEN.getPreferredName(), token); + builder.field(SCORE.getPreferredName(), score); + builder.field(SEQUENCE.getPreferredName(), sequence); + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(token); + out.writeDouble(score); + out.writeString(sequence); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Prediction result = (Prediction) o; + return Double.compare(result.score, score) == 0 && + Objects.equals(token, result.token) && + Objects.equals(sequence, result.sequence); + } + + @Override + public int hashCode() { + return Objects.hash(token, score, sequence); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java new file mode 100644 index 0000000000000..69a236c8ad178 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java @@ -0,0 +1,162 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +public class NerResults implements InferenceResults { + + public static final String NAME = "ner_result"; + + private final List entityGroups; + + public NerResults(List entityGroups) { + this.entityGroups = Objects.requireNonNull(entityGroups); + } + + public NerResults(StreamInput in) throws IOException { + entityGroups = in.readList(EntityGroup::new); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startArray(); + for (EntityGroup entity : entityGroups) { + entity.toXContent(builder, params); + } + builder.endArray(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeList(entityGroups); + } + + @Override + public Map asMap() { + Map map = new LinkedHashMap<>(); + map.put(FillMaskResults.DEFAULT_RESULTS_FIELD, entityGroups.stream().map(EntityGroup::toMap).collect(Collectors.toList())); + return map; + } + + @Override + public Object predictedValue() { + // Used by the inference aggregation + throw new UnsupportedOperationException("Named Entity Recognition does not support a single predicted value"); + } + + public List getEntityGroups() { + return entityGroups; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + NerResults that = (NerResults) o; + return Objects.equals(entityGroups, that.entityGroups); + } + + @Override + public int hashCode() { + return Objects.hash(entityGroups); + } + + public static class EntityGroup implements ToXContentObject, Writeable { + + private static final ParseField LABEL = new ParseField("label"); + private static final ParseField SCORE = new ParseField("score"); + private static final ParseField WORD = new ParseField("word"); + + private final String label; + private final double score; + private final String word; + + public EntityGroup(String label, double score, String word) { + this.label = Objects.requireNonNull(label); + this.score = score; + this.word = Objects.requireNonNull(word); + } + + public EntityGroup(StreamInput in) throws IOException { + label = in.readString(); + score = in.readDouble(); + word = in.readString(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(LABEL.getPreferredName(), label); + builder.field(SCORE.getPreferredName(), score); + builder.field(WORD.getPreferredName(), word); + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(label); + out.writeDouble(score); + out.writeString(word); + } + + public Map toMap() { + Map map = new LinkedHashMap<>(); + map.put(LABEL.getPreferredName(), label); + map.put(SCORE.getPreferredName(), score); + map.put(WORD.getPreferredName(), 
word); + return map; + } + + public String getLabel() { + return label; + } + + public double getScore() { + return score; + } + + public String getWord() { + return word; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + EntityGroup that = (EntityGroup) o; + return Double.compare(that.score, score) == 0 && + Objects.equals(label, that.label) && + Objects.equals(word, that.word); + } + + @Override + public int hashCode() { + return Objects.hash(label, score, word); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/IndexLocation.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/IndexLocation.java new file mode 100644 index 0000000000000..fab9b8400efb3 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/IndexLocation.java @@ -0,0 +1,116 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +public class IndexLocation implements StrictlyParsedTrainedModelLocation, LenientlyParsedTrainedModelLocation { + + public static final ParseField INDEX = new ParseField("index"); + private static final ParseField MODEL_ID = new ParseField("model_id"); + private static final ParseField NAME = new ParseField("name"); + + private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + private static final ConstructingObjectParser STRICT_PARSER = createParser(false); + + private static ConstructingObjectParser createParser(boolean lenient) { + ConstructingObjectParser parser = new ConstructingObjectParser<>( + NAME.getPreferredName(), + lenient, + a -> new IndexLocation((String) a[0], (String) a[1])); + parser.declareString(ConstructingObjectParser.constructorArg(), MODEL_ID); + parser.declareString(ConstructingObjectParser.constructorArg(), NAME); + return parser; + } + + public static IndexLocation fromXContentStrict(XContentParser parser) throws IOException { + return STRICT_PARSER.parse(parser, null); + } + + public static IndexLocation fromXContentLenient(XContentParser parser) throws IOException { + return LENIENT_PARSER.parse(parser, null); + } + + private final String modelId; + private final String indexName; + + IndexLocation(String modelId, String indexName) { + this.modelId = Objects.requireNonNull(modelId); + this.indexName = Objects.requireNonNull(indexName); + } + + public IndexLocation(StreamInput in) throws IOException { + this.modelId = in.readString(); + this.indexName = in.readString(); + } + + @Override + public String getModelId() { + return modelId; + } + + public String getIndexName() { + return indexName; + } + + @Override + public String getResourceName() { + return getIndexName(); + } + + @Override + public XContentBuilder 
toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(MODEL_ID.getPreferredName(), modelId); + builder.field(NAME.getPreferredName(), indexName); + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + out.writeString(indexName); + } + + @Override + public String getWriteableName() { + return INDEX.getPreferredName(); + } + + @Override + public String getName() { + return INDEX.getPreferredName(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + IndexLocation that = (IndexLocation) o; + return Objects.equals(modelId, that.modelId) + && Objects.equals(indexName, that.indexName); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, indexName); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LenientlyParsedTrainedModelLocation.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LenientlyParsedTrainedModelLocation.java new file mode 100644 index 0000000000000..c70062416dbda --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LenientlyParsedTrainedModelLocation.java @@ -0,0 +1,11 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +public interface LenientlyParsedTrainedModelLocation extends TrainedModelLocation{ +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/StrictlyParsedTrainedModelLocation.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/StrictlyParsedTrainedModelLocation.java new file mode 100644 index 0000000000000..36c027583926d --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/StrictlyParsedTrainedModelLocation.java @@ -0,0 +1,11 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +public interface StrictlyParsedTrainedModelLocation extends TrainedModelLocation{ +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TargetType.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TargetType.java index 537408415b066..b6a80b6c89cb5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TargetType.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TargetType.java @@ -6,6 +6,7 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; +import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; @@ -17,6 +18,8 @@ public enum TargetType implements Writeable { REGRESSION, CLASSIFICATION; + public static final ParseField TARGET_TYPE = new ParseField("target_type"); + public static TargetType fromString(String name) { return valueOf(name.trim().toUpperCase(Locale.ROOT)); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModelLocation.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModelLocation.java new file mode 100644 index 0000000000000..046412e8d1a37 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModelLocation.java @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; + +public interface TrainedModelLocation extends NamedXContentObject, NamedWriteable { + + String getModelId(); + + String getResourceName(); + + default Version getMinimalCompatibilityVersion() { + return TrainedModelConfig.VERSION_3RD_PARTY_CONFIG_ADDED; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java index 01a8e15a58637..166df98099df5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java @@ -42,7 +42,6 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai public static final ParseField FEATURE_NAMES = new ParseField("feature_names"); public static final ParseField TRAINED_MODELS = new ParseField("trained_models"); public static final ParseField AGGREGATE_OUTPUT = new ParseField("aggregate_output"); - public static final ParseField TARGET_TYPE = new ParseField("target_type"); public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels"); public static final ParseField CLASSIFICATION_WEIGHTS = new ParseField("classification_weights"); @@ -66,7 +65,7 @@ private static ObjectParser createParser(boolean lenient lenient ? p.namedObject(LenientlyParsedOutputAggregator.class, n, null) : p.namedObject(StrictlyParsedOutputAggregator.class, n, null), AGGREGATE_OUTPUT); - parser.declareString(Ensemble.Builder::setTargetType, TARGET_TYPE); + parser.declareString(Ensemble.Builder::setTargetType, TargetType.TARGET_TYPE); parser.declareStringArray(Ensemble.Builder::setClassificationLabels, CLASSIFICATION_LABELS); parser.declareDoubleArray(Ensemble.Builder::setClassificationWeights, CLASSIFICATION_WEIGHTS); return parser; @@ -96,7 +95,7 @@ public static Ensemble fromXContentLenient(XContentParser parser) { this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES)); this.models = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(models, TRAINED_MODELS)); this.outputAggregator = ExceptionsHelper.requireNonNull(outputAggregator, AGGREGATE_OUTPUT); - this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE); + this.targetType = ExceptionsHelper.requireNonNull(targetType, TargetType.TARGET_TYPE); this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels); this.classificationWeights = classificationWeights == null ? 
null : @@ -163,7 +162,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws false, AGGREGATE_OUTPUT.getPreferredName(), Collections.singletonList(outputAggregator)); - builder.field(TARGET_TYPE.getPreferredName(), targetType.toString()); + builder.field(TargetType.TARGET_TYPE.getPreferredName(), targetType.toString()); if (classificationLabels != null) { builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java index b1090e9583c7e..1b96c988fcf79 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java @@ -50,7 +50,6 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.AGGREGATE_OUTPUT; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_LABELS; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_WEIGHTS; -import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.TARGET_TYPE; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.TRAINED_MODELS; public class EnsembleInferenceModel implements InferenceModel { @@ -75,7 +74,7 @@ public class EnsembleInferenceModel implements InferenceModel { PARSER.declareNamedObject(constructorArg(), (p, c, n) -> p.namedObject(LenientlyParsedOutputAggregator.class, n, null), AGGREGATE_OUTPUT); - PARSER.declareString(constructorArg(), TARGET_TYPE); + PARSER.declareString(constructorArg(), TargetType.TARGET_TYPE); PARSER.declareStringArray(optionalConstructorArg(), CLASSIFICATION_LABELS); PARSER.declareDoubleArray(optionalConstructorArg(), CLASSIFICATION_WEIGHTS); } @@ -99,7 +98,7 @@ private EnsembleInferenceModel(List models, List classificationWeights) { this.models = ExceptionsHelper.requireNonNull(models, TRAINED_MODELS); this.outputAggregator = ExceptionsHelper.requireNonNull(outputAggregator, AGGREGATE_OUTPUT); - this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE); + this.targetType = ExceptionsHelper.requireNonNull(targetType, TargetType.TARGET_TYPE); this.classificationLabels = classificationLabels; this.classificationWeights = classificationWeights == null ? 
null : diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java index 432c65abcfc48..54ba0d26cb0c1 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java @@ -48,7 +48,6 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.decodeFeatureImportances; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree.CLASSIFICATION_LABELS; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree.FEATURE_NAMES; -import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree.TARGET_TYPE; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree.TREE_STRUCTURE; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode.DECISION_TYPE; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode.DEFAULT_LEFT; @@ -77,7 +76,7 @@ public class TreeInferenceModel implements InferenceModel { static { PARSER.declareStringArray(constructorArg(), FEATURE_NAMES); PARSER.declareObjectArray(constructorArg(), NodeBuilder.PARSER::apply, TREE_STRUCTURE); - PARSER.declareString(optionalConstructorArg(), TARGET_TYPE); + PARSER.declareString(optionalConstructorArg(), TargetType.TARGET_TYPE); PARSER.declareStringArray(optionalConstructorArg(), CLASSIFICATION_LABELS); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java index 4294b901058ac..cd2986f9909c5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java @@ -43,7 +43,6 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM public static final ParseField FEATURE_NAMES = new ParseField("feature_names"); public static final ParseField TREE_STRUCTURE = new ParseField("tree_structure"); - public static final ParseField TARGET_TYPE = new ParseField("target_type"); public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels"); private static final ObjectParser LENIENT_PARSER = createParser(true); @@ -56,7 +55,7 @@ private static ObjectParser createParser(boolean lenient) { Tree.Builder::new); parser.declareStringArray(Tree.Builder::setFeatureNames, FEATURE_NAMES); parser.declareObjectArray(Tree.Builder::setNodes, (p, c) -> TreeNode.fromXContent(p, lenient), TREE_STRUCTURE); - parser.declareString(Tree.Builder::setTargetType, TARGET_TYPE); + parser.declareString(Tree.Builder::setTargetType, TargetType.TARGET_TYPE); parser.declareStringArray(Tree.Builder::setClassificationLabels, CLASSIFICATION_LABELS); return parser; } @@ -80,7 +79,7 @@ public static Tree fromXContentLenient(XContentParser parser) { throw new IllegalArgumentException("[tree_structure] must not be empty"); } this.nodes = Collections.unmodifiableList(nodes); - this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE); + this.targetType = 
ExceptionsHelper.requireNonNull(targetType, TargetType.TARGET_TYPE); this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels); } @@ -126,7 +125,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); builder.field(FEATURE_NAMES.getPreferredName(), featureNames); builder.field(TREE_STRUCTURE.getPreferredName(), nodes); - builder.field(TARGET_TYPE.getPreferredName(), targetType.toString()); + builder.field(TargetType.TARGET_TYPE.getPreferredName(), targetType.toString()); if(classificationLabels != null) { builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index 4292b0a89cc09..b86060dae5114 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -115,6 +115,7 @@ public final class Messages { "Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]"; public static final String MODEL_DEFINITION_NOT_FOUND = "Could not find trained model definition [{0}]"; public static final String MODEL_METADATA_NOT_FOUND = "Could not find trained model metadata {0}"; + public static final String TASK_CONFIG_NOT_FOUND = "Could not find task config for model [{0}]"; public static final String INFERENCE_CANNOT_DELETE_ML_MANAGED_MODEL = "Unable to delete model [{0}] as it is required by machine learning"; public static final String MODEL_DEFINITION_TRUNCATED = diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/MlParserUtils.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/MlParserUtils.java new file mode 100644 index 0000000000000..3afc3db3d1ead --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/MlParserUtils.java @@ -0,0 +1,53 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.utils; + +import org.elasticsearch.common.CheckedFunction; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +public final class MlParserUtils { + + private MlParserUtils() {} + + /** + * Parses an array of arrays of the given type + * + * @param fieldName the field name + * @param valueParser the parser to use for the inner array values + * @param parser the outer parser + * @param <T> the type of the values of the inner array + * @return a list of lists representing the array of arrays + * @throws IOException an exception if parsing fails + */ + public static <T> List<List<T>> parseArrayOfArrays(String fieldName, CheckedFunction<XContentParser, T, IOException> valueParser, + XContentParser parser) throws IOException { + if (parser.currentToken() != XContentParser.Token.START_ARRAY) { + throw new IllegalArgumentException("unexpected token [" + parser.currentToken() + "] for [" + fieldName + "]"); + } + List<List<T>> values = new ArrayList<>(); + while(parser.nextToken() != XContentParser.Token.END_ARRAY) { + if (parser.currentToken() != XContentParser.Token.START_ARRAY) { + throw new IllegalArgumentException("unexpected token [" + parser.currentToken() + "] for [" + fieldName + "]"); + } + List<T> innerList = new ArrayList<>(); + while(parser.nextToken() != XContentParser.Token.END_ARRAY) { + if(parser.currentToken().isValue() == false) { + throw new IllegalStateException("expected non-null value but got [" + parser.currentToken() + "] " + + "for [" + fieldName + "]"); + } + innerList.add(valueParser.apply(parser)); + } + values.add(innerList); + } + return values; + } +} diff --git a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_mappings.json b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_mappings.json index 171cbabc52c30..bf164f049dc43 100644 --- a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_mappings.json +++ b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_mappings.json @@ -44,6 +44,9 @@ "definition": { "enabled": false }, + "binary_definition": { + "type": "binary" + }, "compression_version": { "type": "long" }, @@ -135,7 +138,7 @@ "supplied": { "type": "boolean" } - } + } } } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressorTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressorTests.java index 47a131f5d758c..6a98cdb49550a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressorTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressorTests.java @@ -8,6 +8,8 @@ import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.breaker.CircuitBreakingException; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.Streams; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.test.ESTestCase; @@ -16,7 +18,6 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests; import java.io.IOException; -import java.nio.charset.StandardCharsets; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -27,7 +28,7 @@
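A minimal usage sketch for the new MlParserUtils.parseArrayOfArrays helper, assuming a hand-built JSON payload; the field name "rows", the example values and the class name are illustrative only and are not part of this change:

import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.xpack.core.ml.utils.MlParserUtils;

import java.io.IOException;
import java.util.List;

public class ParseArrayOfArraysExample {
    // Parses {"rows": [[1.0, 2.0], [3.0, 4.0]]} into a List<List<Double>>
    static List<List<Double>> parseRows() throws IOException {
        try (XContentParser parser = XContentType.JSON.xContent().createParser(
                NamedXContentRegistry.EMPTY,
                DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
                "{\"rows\": [[1.0, 2.0], [3.0, 4.0]]}")) {
            parser.nextToken();   // START_OBJECT
            parser.nextToken();   // FIELD_NAME "rows"
            parser.nextToken();   // START_ARRAY - the helper expects to be positioned on the outer array
            return MlParserUtils.parseArrayOfArrays("rows", XContentParser::doubleValue, parser);
        }
    }
}

If the parser is not positioned on an array the helper throws IllegalArgumentException, so callers advance past the field name before handing the parser over.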
public class InferenceToXContentCompressorTests extends ESTestCase { public void testInflateAndDeflate() throws IOException { for(int i = 0; i < 10; i++) { TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build(); - String firstDeflate = InferenceToXContentCompressor.deflate(definition); + BytesReference firstDeflate = InferenceToXContentCompressor.deflate(definition); TrainedModelDefinition inflatedDefinition = InferenceToXContentCompressor.inflate(firstDeflate, parser -> TrainedModelDefinition.fromXContent(parser, false).build(), xContentRegistry()); @@ -45,8 +46,8 @@ public void testInflateTooLargeStream() throws IOException { .limit(100) .collect(Collectors.toList())) .build(); - String firstDeflate = InferenceToXContentCompressor.deflate(definition); - int max = firstDeflate.getBytes(StandardCharsets.UTF_8).length + 10; + BytesReference firstDeflate = InferenceToXContentCompressor.deflate(definition); + int max = firstDeflate.length() + 10; IOException ex = expectThrows(IOException.class, () -> Streams.readFully(InferenceToXContentCompressor.inflate(firstDeflate, max))); assertThat(ex.getMessage(), equalTo("" + @@ -54,7 +55,8 @@ public void testInflateTooLargeStream() throws IOException { } public void testInflateGarbage() { - expectThrows(IOException.class, () -> Streams.readFully(InferenceToXContentCompressor.inflate(randomAlphaOfLength(10), 100L))); + expectThrows(IOException.class, () -> Streams.readFully( + InferenceToXContentCompressor.inflate(new BytesArray(randomByteArrayOfLength(10)), 100L))); } public void testInflateParsingTooLargeStream() throws IOException { @@ -65,8 +67,8 @@ public void testInflateParsingTooLargeStream() throws IOException { .limit(100) .collect(Collectors.toList())) .build(); - String compressedString = InferenceToXContentCompressor.deflate(definition); - int max = compressedString.getBytes(StandardCharsets.UTF_8).length + 10; + BytesReference compressedString = InferenceToXContentCompressor.deflate(definition); + int max = compressedString.length() + 10; CircuitBreakingException e = expectThrows(CircuitBreakingException.class, ()-> InferenceToXContentCompressor.inflate( compressedString, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java index a23220b69f8e0..a320b4557dfd8 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java @@ -24,6 +24,7 @@ import org.elasticsearch.search.SearchModule; import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocationTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.MlStrings; @@ -54,6 +55,7 @@ public class TrainedModelConfigTests extends AbstractBWCSerializationTestCase tags = Arrays.asList(generateRandomStringArray(randomIntBetween(0, 5), 15, false)); return TrainedModelConfig.builder() .setInput(TrainedModelInputTests.createRandomInput()) @@ -61,8 +63,9 @@ public static TrainedModelConfig.Builder createTestInstance(String 
modelId) { .setCreateTime(Instant.ofEpochMilli(randomLongBetween(Instant.MIN.getEpochSecond(), Instant.MAX.getEpochSecond()))) .setVersion(Version.CURRENT) .setModelId(modelId) + .setModelType(randomFrom(TrainedModelType.values())) .setCreatedBy(randomAlphaOfLength(10)) - .setDescription(randomBoolean() ? null : randomAlphaOfLength(100)) + .setDescription(randomBoolean() ? null : randomAlphaOfLength(10)) .setEstimatedHeapMemory(randomNonNegativeLong()) .setEstimatedOperations(randomNonNegativeLong()) .setLicenseLevel(randomFrom(License.OperationMode.PLATINUM.description(), @@ -71,7 +74,8 @@ public static TrainedModelConfig.Builder createTestInstance(String modelId) { License.OperationMode.BASIC.description())) .setInferenceConfig(randomFrom(ClassificationConfigTests.randomClassificationConfig(), RegressionConfigTests.randomRegressionConfig())) - .setTags(tags); + .setTags(tags) + .setLocation(randomBoolean() ? null : IndexLocationTests.randomInstance()); } @Before @@ -114,8 +118,7 @@ protected NamedXContentRegistry xContentRegistry() { @Override protected NamedWriteableRegistry getNamedWriteableRegistry() { - List entries = new ArrayList<>(); - entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); + List entries = new ArrayList<>(new MlInferenceNamedXContentProvider().getNamedWriteables()); return new NamedWriteableRegistry(entries); } @@ -134,6 +137,7 @@ public void testToXContentWithParams() throws IOException { .fromParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder().build()); TrainedModelConfig config = new TrainedModelConfig( randomAlphaOfLength(10), + TrainedModelType.TREE_ENSEMBLE, randomAlphaOfLength(10), Version.CURRENT, randomBoolean() ? null : randomAlphaOfLength(100), @@ -149,7 +153,8 @@ public void testToXContentWithParams() throws IOException { Stream.generate(() -> randomAlphaOfLength(10)) .limit(randomIntBetween(1, 10)) .collect(Collectors.toMap(Function.identity(), (k) -> randomAlphaOfLength(10))), - randomFrom(ClassificationConfigTests.randomClassificationConfig(), RegressionConfigTests.randomRegressionConfig())); + randomFrom(ClassificationConfigTests.randomClassificationConfig(), RegressionConfigTests.randomRegressionConfig()), + null); BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false); assertThat(reference.utf8ToString(), containsString("\"compressed_definition\"")); @@ -168,12 +173,13 @@ public void testToXContentWithParams() throws IOException { assertThat(reference.utf8ToString(), containsString("\"definition\"")); assertThat(reference.utf8ToString(), not(containsString("compressed_definition"))); } - + public void testParseWithBothDefinitionAndCompressedSupplied() throws IOException { TrainedModelConfig.LazyModelDefinition lazyModelDefinition = TrainedModelConfig.LazyModelDefinition .fromParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder().build()); TrainedModelConfig config = new TrainedModelConfig( randomAlphaOfLength(10), + TrainedModelType.TREE_ENSEMBLE, randomAlphaOfLength(10), Version.CURRENT, randomBoolean() ? 
null : randomAlphaOfLength(100), @@ -189,7 +195,8 @@ public void testParseWithBothDefinitionAndCompressedSupplied() throws IOExceptio Stream.generate(() -> randomAlphaOfLength(10)) .limit(randomIntBetween(1, 10)) .collect(Collectors.toMap(Function.identity(), (k) -> randomAlphaOfLength(10))), - randomFrom(ClassificationConfigTests.randomClassificationConfig(), RegressionConfigTests.randomRegressionConfig())); + randomFrom(ClassificationConfigTests.randomClassificationConfig(), RegressionConfigTests.randomRegressionConfig()), + null); BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false); Map objectMap = XContentHelper.convertToMap(reference, true, XContentType.JSON).v2(); @@ -207,10 +214,28 @@ public void testParseWithBothDefinitionAndCompressedSupplied() throws IOExceptio } } - public void testValidateWithNullDefinition() { + public void testValidateWithNoDefinitionOrLocation() { ActionRequestValidationException ex = expectThrows(ActionRequestValidationException.class, () -> TrainedModelConfig.builder().validate()); - assertThat(ex.getMessage(), containsString("[definition] must not be null.")); + assertThat(ex.getMessage(), containsString("either a model [definition] or [location] must be defined.")); + } + + public void testValidateWithBothDefinitionAndLocation() { + ActionRequestValidationException ex = expectThrows(ActionRequestValidationException.class, + () -> TrainedModelConfig.builder() + .setLocation(IndexLocationTests.randomInstance()) + .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder()) + .setModelType(TrainedModelType.PYTORCH) + .validate()); + assertThat(ex.getMessage(), containsString("[definition] and [location] are both defined but only one can be used.")); + } + + public void testValidateWithWithMissingTypeAndDefinition() { + ActionRequestValidationException ex = expectThrows(ActionRequestValidationException.class, + () -> TrainedModelConfig.builder() + .setLocation(IndexLocationTests.randomInstance()) + .validate()); + assertThat(ex.getMessage(), containsString("[model_type] must be set if [definition] is not defined")); } public void testValidateWithInvalidID() { @@ -260,9 +285,9 @@ public void testSerializationWithLazyDefinition() throws IOException { xContentTester(this::createParser, () -> { try { - String compressedString = InferenceToXContentCompressor.deflate(TrainedModelDefinitionTests.createRandomBuilder().build()); + BytesReference bytes = InferenceToXContentCompressor.deflate(TrainedModelDefinitionTests.createRandomBuilder().build()); return createTestInstance(randomAlphaOfLength(10)) - .setDefinitionFromString(compressedString) + .setDefinitionFromBytes(bytes) .build(); } catch (IOException ex) { fail(ex.getMessage()); @@ -291,10 +316,10 @@ public void testSerializationWithCompressedLazyDefinition() throws IOException { xContentTester(this::createParser, () -> { try { - String compressedString = + BytesReference bytes = InferenceToXContentCompressor.deflate(TrainedModelDefinitionTests.createRandomBuilder().build()); return createTestInstance(randomAlphaOfLength(10)) - .setDefinitionFromString(compressedString) + .setDefinitionFromBytes(bytes) .build(); } catch (IOException ex) { fail(ex.getMessage()); @@ -328,6 +353,10 @@ protected TrainedModelConfig mutateInstanceForVersion(TrainedModelConfig instanc if (version.before(Version.V_7_8_0)) { builder.setInferenceConfig(null); } + if (version.before(TrainedModelConfig.VERSION_3RD_PARTY_CONFIG_ADDED)) { + 
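+ // model_type and location were added in VERSION_3RD_PARTY_CONFIG_ADDED, so the expected instance must drop them when serialising to an older wire version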
builder.setModelType((TrainedModelType)null); + builder.setLocation(null); + } return builder.build(); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelTypeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelTypeTests.java new file mode 100644 index 0000000000000..009dcf5da794b --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelTypeTests.java @@ -0,0 +1,30 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.EnsembleTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetworkTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests; + +public class TrainedModelTypeTests extends ESTestCase { + + public void testTypeFromTrainedModel() { + { + TrainedModel tm = randomFrom(TreeTests.createRandom(TargetType.CLASSIFICATION), + EnsembleTests.createRandom(TargetType.CLASSIFICATION)); + assertEquals(TrainedModelType.TREE_ENSEMBLE, TrainedModelType.typeFromTrainedModel(tm)); + } + { + TrainedModel tm = LangIdentNeuralNetworkTests.createRandom(); + assertEquals(TrainedModelType.LANG_IDENT, TrainedModelType.typeFromTrainedModel(tm)); + } + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResultsTests.java new file mode 100644 index 0000000000000..57867b6091d0f --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResultsTests.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; + +public class FillMaskResultsTests extends AbstractWireSerializingTestCase { + @Override + protected Writeable.Reader instanceReader() { + return FillMaskResults::new; + } + + @Override + protected FillMaskResults createTestInstance() { + int numResults = randomIntBetween(0, 3); + List resultList = new ArrayList<>(); + for (int i=0; i asMap = testInstance.asMap(); + List> resultList = (List>)asMap.get("results"); + assertThat(resultList, hasSize(testInstance.getPredictions().size())); + for (int i = 0; i map = resultList.get(i); + assertThat(map.get("score"), equalTo(result.getScore())); + assertThat(map.get("token"), equalTo(result.getToken())); + assertThat(map.get("sequence"), equalTo(result.getSequence())); + } + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NerResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NerResultsTests.java new file mode 100644 index 0000000000000..e2764583c53b6 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NerResultsTests.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; + +public class NerResultsTests extends AbstractWireSerializingTestCase { + @Override + protected Writeable.Reader instanceReader() { + return NerResults::new; + } + + @Override + protected NerResults createTestInstance() { + int numEntities = randomIntBetween(0, 3); + List entityGroups = new ArrayList<>(); + for (int i=0; i asMap = testInstance.asMap(); + List> resultList = (List>)asMap.get("results"); + assertThat(resultList, hasSize(testInstance.getEntityGroups().size())); + for (int i=0; i map = resultList.get(i); + assertThat(map.get("label"), equalTo(entity.getLabel())); + assertThat(map.get("score"), equalTo(entity.getScore())); + assertThat(map.get("word"), equalTo(entity.getWord())); + } + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/IndexLocationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/IndexLocationTests.java new file mode 100644 index 0000000000000..5df423df1667c --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/IndexLocationTests.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractSerializingTestCase; + +import java.io.IOException; + +public class IndexLocationTests extends AbstractSerializingTestCase { + + private final boolean lenient = randomBoolean(); + + public static IndexLocation randomInstance() { + return new IndexLocation(randomAlphaOfLength(7), randomAlphaOfLength(7)); + } + + @Override + protected IndexLocation doParseInstance(XContentParser parser) throws IOException { + return lenient ? IndexLocation.fromXContentLenient(parser) : IndexLocation.fromXContentStrict(parser); + } + + @Override + protected Writeable.Reader instanceReader() { + return IndexLocation::new; + } + + @Override + protected IndexLocation createTestInstance() { + return randomInstance(); + } + + @Override + protected boolean supportsUnknownFields() { + return lenient; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinitionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinitionTests.java index d1a486dbe0b8a..58cd3eaad2d2b 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinitionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinitionTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference; +import com.unboundid.util.Base64; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.DeprecationHandler; @@ -24,6 +25,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import java.io.IOException; +import java.text.ParseException; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -59,7 +61,7 @@ public void testTreeSchemaDeserialization() throws IOException { assertThat(definition.getTrainedModel().getClass(), equalTo(TreeInferenceModel.class)); } - public void testMultiClassIrisInference() throws IOException { + public void testMultiClassIrisInference() throws IOException, ParseException { // Fairly simple, random forest classification model built to fit in our format // Trained on the well known Iris dataset String compressedDef = "H4sIAPbiMl4C/+1b246bMBD9lVWet8jjG3b/oN9QVYgmToLEkghIL6r23wukl90" + @@ -83,7 +85,8 @@ public void testMultiClassIrisInference() throws IOException { "aLbAYWcAdpeweKa2IfIT2jz5QzXxD6AoP+DrdXtxeluV7pdWrvkcKqPp7rjS19d+wp/fff/5Ez3FPjzFNy" + "fdpTi9JB0sDp2JR7b309mn5HuPkEAAA=="; - InferenceDefinition definition = InferenceToXContentCompressor.inflate(compressedDef, + byte[] bytes = Base64.decode(compressedDef); + InferenceDefinition definition = InferenceToXContentCompressor.inflate(new BytesArray(bytes), InferenceDefinition::fromXContent, xContentRegistry()); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java index 649bf86f5f323..f6604df19016f 100644 --- 
a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java @@ -10,9 +10,14 @@ import org.elasticsearch.client.Request; import org.elasticsearch.client.Response; import org.elasticsearch.client.ResponseException; +import org.elasticsearch.client.ml.GetTrainedModelsResponse; +import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.client.ml.inference.TrainedModelConfig; import org.elasticsearch.client.ml.inference.TrainedModelDefinition; import org.elasticsearch.client.ml.inference.TrainedModelInput; +import org.elasticsearch.client.ml.inference.trainedmodel.IndexLocation; +import org.elasticsearch.client.ml.inference.TrainedModelType; +import org.elasticsearch.client.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.client.ml.inference.trainedmodel.TargetType; import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel; @@ -23,10 +28,12 @@ import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.test.SecuritySettingsSourceField; import org.elasticsearch.test.rest.ESRestTestCase; @@ -44,11 +51,18 @@ import java.util.List; import java.util.Map; -import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.not; +/** + * This test uses a mixture of HLRC and server side classes. + * + * The server classes have builders that set the one-time fields that + * can only be set on creation e.g. create_time. The HLRC classes must + * be used when creating PUT trained model requests as they do not set + * these one-time fields. 
+ */ public class TrainedModelIT extends ESRestTestCase { private static final String BASIC_AUTH_VALUE = UsernamePasswordToken.basicAuthHeaderValue("x_pack_rest_user", @@ -59,6 +73,11 @@ protected Settings restClientSettings() { return Settings.builder().put(super.restClientSettings()).put(ThreadContext.PREFIX + ".Authorization", BASIC_AUTH_VALUE).build(); } + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + } + @Override protected boolean preserveTemplatesUponCompletion() { return true; @@ -96,6 +115,7 @@ public void testGetTrainedModels() throws IOException { assertThat(response, containsString("\"model_id\":\"a_test_regression_model\"")); assertThat(response, containsString("\"estimated_heap_memory_usage_bytes\"")); assertThat(response, containsString("\"estimated_heap_memory_usage\"")); + assertThat(response, containsString("\"model_type\":\"tree_ensemble\"")); assertThat(response, containsString("\"definition\"")); assertThat(response, not(containsString("\"compressed_definition\""))); assertThat(response, containsString("\"count\":1")); @@ -228,6 +248,24 @@ public void testExportImportModel() throws IOException { assertThat(response, containsString("\"count\":2")); } + public void testPyTorchModelConfig() throws IOException { + String modelId = "pytorch1"; + String pytorchModelId = "pytorch_model"; + String index = "pytorch_index"; + putPyTorchModel(modelId, pytorchModelId, index); + Response getModel = client().performRequest(new Request("GET", MachineLearning.BASE_PATH + "trained_models/" + modelId)); + + try (XContentParser parser = createParser(XContentType.JSON.xContent(), getModel.getEntity().getContent())) { + GetTrainedModelsResponse response = GetTrainedModelsResponse.fromXContent(parser); + TrainedModelConfig model = response.getTrainedModels().get(0); + assertThat(model.getModelType(), equalTo(TrainedModelType.PYTORCH)); + IndexLocation location = (IndexLocation) model.getLocation(); + assertThat(location.getModelId(), equalTo(pytorchModelId)); + assertThat(location.getIndex(), equalTo(index)); + assertThat(model.getEstimatedOperations(), equalTo(0L)); + } + } + private void putRegressionModel(String modelId) throws IOException { try(XContentBuilder builder = XContentFactory.jsonBuilder()) { TrainedModelDefinition.Builder definition = new TrainedModelDefinition.Builder() @@ -245,6 +283,21 @@ private void putRegressionModel(String modelId) throws IOException { } } + private void putPyTorchModel(String modelId, String pytorchModelId, String index) throws IOException { + try(XContentBuilder builder = XContentFactory.jsonBuilder()) { + TrainedModelConfig.builder() + .setLocation(new IndexLocation(pytorchModelId, index)) + .setModelType(TrainedModelType.PYTORCH) + .setInferenceConfig(new ClassificationConfig()) + .setModelId(modelId) + .setInput(new TrainedModelInput(Collections.singletonList("text"))) + .build().toXContent(builder, ToXContent.EMPTY_PARAMS); + Request model = new Request("PUT", "_ml/trained_models/" + modelId); + model.setJsonEntity(XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON)); + assertThat(client().performRequest(model).getStatusLine().getStatusCode(), equalTo(200)); + } + } + private static TrainedModel buildRegression() { List featureNames = Arrays.asList("field.foo", "field.bar", "animal_cat", "animal_dog"); Tree tree1 = Tree.builder() diff --git 
a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java index edb0d52cbc354..1992e25d7f3f2 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java @@ -10,6 +10,7 @@ import org.elasticsearch.Version; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.license.License; import org.elasticsearch.xpack.core.action.util.PageParams; @@ -22,10 +23,11 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInputTests; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.FeatureImportanceBaselineTests; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportanceTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.HyperparametersTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportanceTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata; import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; import org.elasticsearch.xpack.ml.dataframe.process.ChunkedTrainedModelPersister; @@ -43,6 +45,8 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Base64; import java.util.Collections; import java.util.List; import java.util.Map; @@ -74,9 +78,8 @@ public void testStoreModelViaChunkedPersister() throws IOException { .build(); List extractedFieldList = Collections.singletonList(new DocValueField("foo", Collections.emptySet())); TrainedModelConfig.Builder configBuilder = buildTrainedModelConfigBuilder(modelId); - String compressedDefinition = configBuilder.build().getCompressedDefinition(); - int totalSize = compressedDefinition.length(); - List chunks = chunkStringWithSize(compressedDefinition, totalSize/3); + BytesReference compressedDefinition = configBuilder.build().getCompressedDefinition(); + List base64Chunks = chunkBinaryDefinition(compressedDefinition, compressedDefinition.length() / 3); ChunkedTrainedModelPersister persister = new ChunkedTrainedModelPersister(trainedModelProvider, analyticsConfig, @@ -87,9 +90,10 @@ public void testStoreModelViaChunkedPersister() throws IOException { //Accuracy for size is not tested here ModelSizeInfo modelSizeInfo = ModelSizeInfoTests.createRandom(); - persister.createAndIndexInferenceModelConfig(modelSizeInfo); - for (int i = 0; i < chunks.size(); i++) { - persister.createAndIndexInferenceModelDoc(new TrainedModelDefinitionChunk(chunks.get(i), i, i == (chunks.size() - 1))); + persister.createAndIndexInferenceModelConfig(modelSizeInfo, configBuilder.getModelType()); + for (int i = 0; i < base64Chunks.size(); i++) { + persister.createAndIndexInferenceModelDoc( + new 
TrainedModelDefinitionChunk(base64Chunks.get(i), i, i == (base64Chunks.size() - 1))); } ModelMetadata modelMetadata = new ModelMetadata(Stream.generate(TotalFeatureImportanceTests::randomInstance) .limit(randomIntBetween(1, 10)) @@ -141,6 +145,7 @@ private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION)) .setDescription("trained model config for test") .setModelId(modelId) + .setModelType(TrainedModelType.TREE_ENSEMBLE) .setVersion(Version.CURRENT) .setLicenseLevel(License.OperationMode.PLATINUM.description()) .setEstimatedHeapMemory(bytesUsed) @@ -148,12 +153,13 @@ private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String .setInput(TrainedModelInputTests.createRandomInput()); } - public static List chunkStringWithSize(String str, int chunkSize) { - List subStrings = new ArrayList<>((str.length() + chunkSize - 1) / chunkSize); - for (int i = 0; i < str.length(); i += chunkSize) { - subStrings.add(str.substring(i, Math.min(i + chunkSize, str.length()))); + public static List chunkBinaryDefinition(BytesReference bytes, int chunkSize) { + List subStrings = new ArrayList<>((bytes.length() + chunkSize - 1) / chunkSize); + for (int i = 0; i < bytes.length(); i += chunkSize) { + subStrings.add( + Base64.getEncoder().encodeToString( + Arrays.copyOfRange(bytes.array(), i, Math.min(i + chunkSize, bytes.length())))); } return subStrings; } - } diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelRestorerIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelRestorerIT.java new file mode 100644 index 0000000000000..6919812c3000c --- /dev/null +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelRestorerIT.java @@ -0,0 +1,218 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.integration; + +import org.elasticsearch.action.bulk.BulkItemResponse; +import org.elasticsearch.action.bulk.BulkRequestBuilder; +import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.index.IndexRequestBuilder; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; +import org.elasticsearch.xpack.ml.inference.persistence.ChunkedTrainedModelRestorer; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Base64; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import static org.hamcrest.Matchers.hasSize; + +public class ChunkedTrainedModelRestorerIT extends MlSingleNodeTestCase { + + public void testRestoreWithMultipleSearches() throws IOException, InterruptedException { + String modelId = "test-multiple-searches"; + int numDocs = 22; + List modelDefs = new ArrayList<>(numDocs); + + for (int i=0; i expectedDocs = createModelDefinitionDocs(modelDefs, modelId); + putModelDefinitions(expectedDocs, InferenceIndexConstants.LATEST_INDEX_NAME, 0); + + + ChunkedTrainedModelRestorer restorer = new ChunkedTrainedModelRestorer(modelId, client(), + client().threadPool().executor(MachineLearning.UTILITY_THREAD_POOL_NAME), xContentRegistry()); + restorer.setSearchSize(5); + List actualDocs = new ArrayList<>(); + + AtomicReference exceptionHolder = new AtomicReference<>(); + CountDownLatch latch = new CountDownLatch(1); + + restorer.restoreModelDefinition( + actualDocs::add, + success -> latch.countDown(), + failure -> { + exceptionHolder.set(failure); + latch.countDown(); + }); + + latch.await(); + + assertNull(exceptionHolder.get()); + assertEquals(actualDocs, expectedDocs); + } + + public void testCancel() throws IOException, InterruptedException { + String modelId = "test-cancel-search"; + int numDocs = 6; + List modelDefs = new ArrayList<>(numDocs); + + for (int i=0; i expectedDocs = createModelDefinitionDocs(modelDefs, modelId); + putModelDefinitions(expectedDocs, InferenceIndexConstants.LATEST_INDEX_NAME, 0); + + ChunkedTrainedModelRestorer restorer = new ChunkedTrainedModelRestorer(modelId, client(), + client().threadPool().executor(MachineLearning.UTILITY_THREAD_POOL_NAME), xContentRegistry()); + restorer.setSearchSize(5); + List actualDocs = new ArrayList<>(); + + AtomicReference exceptionHolder = new AtomicReference<>(); + AtomicBoolean successValue = new AtomicBoolean(Boolean.TRUE); + CountDownLatch latch = new CountDownLatch(1); + + restorer.restoreModelDefinition( + doc -> { + actualDocs.add(doc); + return false; + }, + success -> { + successValue.set(success); + latch.countDown(); + }, + failure -> { + exceptionHolder.set(failure); + latch.countDown(); + }); + + latch.await(); + + 
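+ // The document consumer above returns false after the first chunk, telling the restorer to stop early: the success callback reports false and only the first doc is received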
assertNull(exceptionHolder.get()); + assertFalse(successValue.get()); + assertThat(actualDocs, hasSize(1)); + assertEquals(expectedDocs.get(0), actualDocs.get(0)); + } + + public void testRestoreWithDocumentsInMultipleIndices() throws IOException, InterruptedException { + String index1 = "foo-1"; + String index2 = "foo-2"; + + for (String index : new String[]{index1, index2}) { + client().admin().indices().prepareCreate(index) + .setMapping(TrainedModelDefinitionDoc.DEFINITION.getPreferredName(), "type=binary", + InferenceIndexConstants.DOC_TYPE.getPreferredName(), "type=keyword", + TrainedModelConfig.MODEL_ID.getPreferredName(), "type=keyword").get(); + } + + String modelId = "test-multiple-indices"; + int numDocs = 24; + List modelDefs = new ArrayList<>(numDocs); + + for (int i=0; i expectedDocs = createModelDefinitionDocs(modelDefs, modelId); + int splitPoint = (numDocs / 2) -1; + putModelDefinitions(expectedDocs.subList(0, splitPoint), index1, 0); + putModelDefinitions(expectedDocs.subList(splitPoint, numDocs), index2, splitPoint); + + ChunkedTrainedModelRestorer restorer = new ChunkedTrainedModelRestorer(modelId, client(), + client().threadPool().executor(MachineLearning.UTILITY_THREAD_POOL_NAME), xContentRegistry()); + restorer.setSearchSize(10); + restorer.setSearchIndex("foo-*"); + + AtomicReference exceptionHolder = new AtomicReference<>(); + CountDownLatch latch = new CountDownLatch(1); + List actualDocs = new ArrayList<>(); + + restorer.restoreModelDefinition( + actualDocs::add, + success -> latch.countDown(), + failure -> { + exceptionHolder.set(failure); + latch.countDown(); + }); + + latch.await(); + + assertNull(exceptionHolder.get()); + // The results are sorted by index first rather than doc_num + // TODO is this the behaviour we want? 
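+ // Rebuild the expected order to match: the chunks indexed into the second index are returned before those indexed into the first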
+ List reorderedDocs = new ArrayList<>(); + reorderedDocs.addAll(expectedDocs.subList(splitPoint, numDocs)); + reorderedDocs.addAll(expectedDocs.subList(0, splitPoint)); + assertEquals(actualDocs, reorderedDocs); + } + + private List createModelDefinitionDocs(List compressedDefinitions, String modelId) { + int totalLength = compressedDefinitions.stream().map(BytesReference::length).reduce(0, Integer::sum); + + List docs = new ArrayList<>(); + for (int i = 0; i < compressedDefinitions.size(); i++) { + docs.add(new TrainedModelDefinitionDoc.Builder() + .setDocNum(i) + .setBinaryData(compressedDefinitions.get(i)) + .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION) + .setTotalDefinitionLength(totalLength) + .setDefinitionLength(compressedDefinitions.get(i).length()) + .setEos(i == compressedDefinitions.size() - 1) + .setModelId(modelId) + .build()); + } + + return docs; + } + + private void putModelDefinitions(List docs, String index, int startingDocNum) throws IOException { + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk(); + for (TrainedModelDefinitionDoc doc : docs) { + try (XContentBuilder xContentBuilder = doc.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) { + IndexRequestBuilder indexRequestBuilder = client().prepareIndex(index) + .setSource(xContentBuilder) + .setId(TrainedModelDefinitionDoc.docId(doc.getModelId(), startingDocNum++)); + + bulkRequestBuilder.add(indexRequestBuilder); + } + } + + BulkResponse bulkResponse = bulkRequestBuilder + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .get(); + if (bulkResponse.hasFailures()) { + int failures = 0; + for (BulkItemResponse itemResponse : bulkResponse) { + if (itemResponse.isFailed()) { + failures++; + logger.error("Item response failure [{}]", itemResponse.getFailureMessage()); + } + } + fail("Bulk response contained " + failures + " failures"); + } + } +} diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java index f9ecad4538045..eb9bef0268779 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -13,6 +13,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; @@ -359,6 +360,7 @@ private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String .setCreatedBy("ml_test") .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder()) .setDescription("trained model config for test") + .setModelType(TrainedModelType.TREE_ENSEMBLE) .setModelId(modelId); } diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/PyTorchStateStreamerIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/PyTorchStateStreamerIT.java new 
file mode 100644 index 0000000000000..89c8e78f492e6 --- /dev/null +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/PyTorchStateStreamerIT.java @@ -0,0 +1,124 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.integration; + +import org.elasticsearch.action.bulk.BulkItemResponse; +import org.elasticsearch.action.bulk.BulkRequestBuilder; +import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.index.IndexRequestBuilder; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc; +import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchStateStreamer; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Base64; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; + +public class PyTorchStateStreamerIT extends MlSingleNodeTestCase { + + public void testRestoreState() throws IOException, InterruptedException { + int numChunks = 5; + int chunkSize = 100; + int modelSize = numChunks * chunkSize; + + String modelId = "test-state-streamer-restore"; + + List chunks = new ArrayList<>(numChunks); + for (int i=0; i docs = createModelDefinitionDocs(chunks, modelId); + putModelDefinition(docs); + + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(modelSize); + PyTorchStateStreamer stateStreamer = new PyTorchStateStreamer(client(), + client().threadPool().executor(MachineLearning.UTILITY_THREAD_POOL_NAME), xContentRegistry()); + + AtomicReference onSuccess = new AtomicReference<>(); + AtomicReference onFailure = new AtomicReference<>(); + blockingCall(listener -> + stateStreamer.writeStateToStream(modelId, InferenceIndexConstants.LATEST_INDEX_NAME, outputStream, listener), + onSuccess, onFailure); + + byte[] writtenData = outputStream.toByteArray(); + + // the first 4 bytes are the model size + int writtenSize = ByteBuffer.wrap(writtenData, 0, 4).getInt(); + assertEquals(modelSize, writtenSize); + + byte[] writtenChunk = new byte[chunkSize]; + for (int i=0; i createModelDefinitionDocs(List binaryChunks, String modelId) { + + int totalLength = binaryChunks.stream().map(arr -> arr.length).reduce(0, Integer::sum); + + List docs = new ArrayList<>(); + for (int i = 0; i < binaryChunks.size(); i++) { + String encodedData = new String(Base64.getEncoder().encode(binaryChunks.get(i)), StandardCharsets.UTF_8); + + docs.add(new TrainedModelDefinitionDoc.Builder() + .setDocNum(i) + .setCompressedString(encodedData) + .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION) + .setTotalDefinitionLength(totalLength) + .setDefinitionLength(encodedData.length()) + .setEos(i == binaryChunks.size() 
- 1) + .setModelId(modelId) + .build()); + } + return docs; + } + + + private void putModelDefinition(List docs) throws IOException { + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk(); + for (int i = 0; i < docs.size(); i++) { + TrainedModelDefinitionDoc doc = docs.get(i); + try (XContentBuilder xContentBuilder = doc.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) { + IndexRequestBuilder indexRequestBuilder = client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME) + .setSource(xContentBuilder) + .setId(TrainedModelDefinitionDoc.docId(doc.getModelId(), i)); + + bulkRequestBuilder.add(indexRequestBuilder); + } + } + + BulkResponse bulkResponse = bulkRequestBuilder + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .get(); + if (bulkResponse.hasFailures()) { + int failures = 0; + for (BulkItemResponse itemResponse : bulkResponse) { + if (itemResponse.isFailed()) { + failures++; + logger.error("Item response failure [{}]", itemResponse.getFailureMessage()); + } + } + fail("Bulk response contained " + failures + " failures"); + } + logger.debug("Indexed [{}] documents", docs.size()); + } +} diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java index cd69df7a65041..1c395b6baf6ce 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java @@ -8,9 +8,13 @@ import org.elasticsearch.Version; import org.elasticsearch.action.admin.indices.refresh.RefreshResponse; +import org.elasticsearch.action.bulk.BulkRequestBuilder; +import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.delete.DeleteRequest; +import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.action.index.IndexResponse; import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; @@ -19,7 +23,9 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInputTests; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.FeatureImportanceBaseline; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata; import org.elasticsearch.xpack.core.ml.job.messages.Messages; @@ -28,6 +34,7 @@ import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.junit.Before; +import java.io.IOException; import java.util.Collections; import java.util.List; import java.util.concurrent.atomic.AtomicReference; @@ -35,7 +42,6 @@ import java.util.stream.IntStream; import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE; -import static 
org.elasticsearch.xpack.ml.integration.ChunkedTrainedModelPersisterIT.chunkStringWithSize; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasKey; @@ -143,6 +149,7 @@ public void testGetTrainedModelConfigWithoutDefinition() throws Exception { .setEstimatedOperations(config.getEstimatedOperations()) .setInput(config.getInput()) .setModelId(config.getModelId()) + .setModelType(TrainedModelType.TREE_ENSEMBLE) .setTags(config.getTags()) .setVersion(config.getVersion()) .setMetadata(config.getMetadata()) @@ -216,7 +223,7 @@ public void testGetTruncatedModelDeprecatedDefinition() throws Exception { TrainedModelDefinitionDoc truncatedDoc = new TrainedModelDefinitionDoc.Builder() .setDocNum(0) - .setCompressedString(config.getCompressedDefinition().substring(0, config.getCompressedDefinition().length() - 10)) + .setBinaryData(config.getCompressedDefinition().slice(0, config.getCompressedDefinition().length() - 10)) .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION) .setDefinitionLength(config.getCompressedDefinition().length()) .setTotalDefinitionLength(config.getCompressedDefinition().length()) @@ -256,34 +263,39 @@ public void testGetTruncatedModelDefinition() throws Exception { assertThat(putConfigHolder.get(), is(true)); assertThat(exceptionHolder.get(), is(nullValue())); - List<String> chunks = chunkStringWithSize(config.getCompressedDefinition(), config.getCompressedDefinition().length()/3); + // The model definition has been put with the config above but it + // is not large enough to be split into chunks. Chunks are required + // for this test so overwrite the definition with multiple chunks. + List<TrainedModelDefinitionDoc.Builder> docBuilders = createModelDefinitionDocs(config.getCompressedDefinition(), modelId); - List<TrainedModelDefinitionDoc.Builder> docBuilders = IntStream.range(0, chunks.size()) - .mapToObj(i -> new TrainedModelDefinitionDoc.Builder() - .setDocNum(i) - .setCompressedString(chunks.get(i)) - .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION) - .setDefinitionLength(chunks.get(i).length()) - .setEos(i == chunks.size() - 1) - .setModelId(modelId)) - .collect(Collectors.toList()); boolean missingEos = randomBoolean(); - docBuilders.get(docBuilders.size() - 1).setEos(missingEos == false); - for (int i = missingEos ?
0 : 1 ; i < docBuilders.size(); ++i) { + if (missingEos) { + // Set the wrong end of stream value + docBuilders.get(docBuilders.size() - 1).setEos(false); + } else { + // else write fewer than the expected number of docs + docBuilders.remove(docBuilders.size() -1); + } + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk(); + for (int i = 0; i < docBuilders.size(); ++i) { TrainedModelDefinitionDoc doc = docBuilders.get(i).build(); - try(XContentBuilder xContentBuilder = doc.toXContent(XContentFactory.jsonBuilder(), - new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true")))) { - AtomicReference putDocHolder = new AtomicReference<>(); - blockingCall(listener -> client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .setSource(xContentBuilder) - .setId(TrainedModelDefinitionDoc.docId(modelId, 0)) - .execute(listener), - putDocHolder, - exceptionHolder); - assertThat(exceptionHolder.get(), is(nullValue())); + try (XContentBuilder xContentBuilder = doc.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) { + + IndexRequestBuilder indexRequestBuilder = client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME) + .setSource(xContentBuilder) + .setId(TrainedModelDefinitionDoc.docId(modelId, i)); + + bulkRequestBuilder.add(indexRequestBuilder); } } + bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + AtomicReference putDocsHolder = new AtomicReference<>(); + blockingCall(bulkRequestBuilder::execute, putDocsHolder, exceptionHolder); + assertThat(exceptionHolder.get(), is(nullValue())); + assertFalse(putDocsHolder.get().hasFailures()); + + AtomicReference getConfigHolder = new AtomicReference<>(); blockingCall( listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), listener), @@ -294,12 +306,68 @@ public void testGetTruncatedModelDefinition() throws Exception { assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId))); } + public void testGetTrainedModelForInference() throws InterruptedException, IOException { + String modelId = "test-model-for-inference"; + TrainedModelConfig config = buildTrainedModelConfig(modelId); + AtomicReference putConfigHolder = new AtomicReference<>(); + AtomicReference exceptionHolder = new AtomicReference<>(); + + blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder); + assertThat(putConfigHolder.get(), is(true)); + assertThat(exceptionHolder.get(), is(nullValue())); + + List docBuilders = createModelDefinitionDocs(config.getCompressedDefinition(), modelId); + + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk(); + for (int i = 0; i < docBuilders.size(); i++) { + TrainedModelDefinitionDoc doc = docBuilders.get(i).build(); + try (XContentBuilder xContentBuilder = doc.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) { + IndexRequestBuilder indexRequestBuilder = client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME) + .setSource(xContentBuilder) + .setId(TrainedModelDefinitionDoc.docId(modelId, i)); + + bulkRequestBuilder.add(indexRequestBuilder); + } + } + bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + AtomicReference putDocsHolder = new AtomicReference<>(); + blockingCall(bulkRequestBuilder::execute, putDocsHolder, exceptionHolder); + + 
assertThat(exceptionHolder.get(), is(nullValue())); + assertFalse(putDocsHolder.get().hasFailures()); + + AtomicReference definitionHolder = new AtomicReference<>(); + blockingCall( + listener -> trainedModelProvider.getTrainedModelForInference(modelId, listener), + definitionHolder, + exceptionHolder); + assertThat(exceptionHolder.get(), is(nullValue())); + assertThat(definitionHolder.get(), is(not(nullValue()))); + } + + private List createModelDefinitionDocs(BytesReference compressedDefinition, String modelId) { + List chunks = TrainedModelProvider.chunkDefinitionWithSize(compressedDefinition, compressedDefinition.length()/3); + + return IntStream.range(0, chunks.size()) + .mapToObj(i -> new TrainedModelDefinitionDoc.Builder() + .setDocNum(i) + .setBinaryData(chunks.get(i)) + .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION) + .setDefinitionLength(chunks.get(i).length()) + .setTotalDefinitionLength(compressedDefinition.length()) + .setEos(i == chunks.size() - 1) + .setModelId(modelId)) + .collect(Collectors.toList()); + } + private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) { return TrainedModelConfig.builder() .setCreatedBy("ml_test") .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder()) .setDescription("trained model config for test") .setModelId(modelId) + .setModelType(TrainedModelType.TREE_ENSEMBLE) .setVersion(Version.CURRENT) .setLicenseLevel(License.OperationMode.PLATINUM.description()) .setEstimatedHeapMemory(0) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 5b7acc8567543..6447c6cb95a57 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -97,6 +97,8 @@ import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction; import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction; import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAliasAction; +import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.EstimateModelMemoryAction; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction; @@ -144,6 +146,7 @@ import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction; import org.elasticsearch.xpack.core.ml.action.StopDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.StopDatafeedAction; +import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.UpdateCalendarJobAction; import org.elasticsearch.xpack.core.ml.action.UpdateDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.UpdateDatafeedAction; @@ -160,6 +163,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStatsNamedWriteablesProvider; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentTaskState; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import 
org.elasticsearch.xpack.core.ml.job.config.JobTaskState; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; @@ -179,6 +183,8 @@ import org.elasticsearch.xpack.ml.action.TransportDeleteModelSnapshotAction; import org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAction; import org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAliasAction; +import org.elasticsearch.xpack.ml.action.TransportInferTrainedModelDeploymentAction; +import org.elasticsearch.xpack.ml.action.TransportStartTrainedModelDeploymentAction; import org.elasticsearch.xpack.ml.action.TransportEstimateModelMemoryAction; import org.elasticsearch.xpack.ml.action.TransportEvaluateDataFrameAction; import org.elasticsearch.xpack.ml.action.TransportExplainDataFrameAnalyticsAction; @@ -226,6 +232,7 @@ import org.elasticsearch.xpack.ml.action.TransportStartDatafeedAction; import org.elasticsearch.xpack.ml.action.TransportStopDataFrameAnalyticsAction; import org.elasticsearch.xpack.ml.action.TransportStopDatafeedAction; +import org.elasticsearch.xpack.ml.action.TransportStopTrainedModelDeploymentAction; import org.elasticsearch.xpack.ml.action.TransportUpdateCalendarJobAction; import org.elasticsearch.xpack.ml.action.TransportUpdateDataFrameAnalyticsAction; import org.elasticsearch.xpack.ml.action.TransportUpdateDatafeedAction; @@ -258,10 +265,13 @@ import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult; import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata; import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService; +import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.elasticsearch.xpack.ml.inference.pytorch.process.NativePyTorchProcessFactory; +import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcessFactory; import org.elasticsearch.xpack.ml.job.JobManager; import org.elasticsearch.xpack.ml.job.JobManagerHolder; import org.elasticsearch.xpack.ml.job.UpdateJobProcessNotifier; @@ -335,10 +345,13 @@ import org.elasticsearch.xpack.ml.rest.filter.RestUpdateFilterAction; import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAction; import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAliasAction; +import org.elasticsearch.xpack.ml.rest.inference.RestInferTrainedModelDeploymentAction; +import org.elasticsearch.xpack.ml.rest.inference.RestStartTrainedModelDeploymentAction; import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction; import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsStatsAction; import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelAction; import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelAliasAction; +import org.elasticsearch.xpack.ml.rest.inference.RestStopTrainedModelDeploymentAction; import org.elasticsearch.xpack.ml.rest.job.RestCloseJobAction; import org.elasticsearch.xpack.ml.rest.job.RestDeleteForecastAction; import org.elasticsearch.xpack.ml.rest.job.RestDeleteJobAction; @@ -531,6 +544,7 @@ public Map getProcessors(Processor.Parameters paramet private final SetOnce inferenceModelBreaker = new SetOnce<>(); private final SetOnce modelLoadingService = new 
SetOnce<>(); private final SetOnce mlAutoscalingDeciderService = new SetOnce<>(); + private final SetOnce deploymentManager = new SetOnce<>(); public MachineLearning(Settings settings, Path configPath) { this.settings = settings; @@ -690,6 +704,7 @@ public Collection createComponents(Client client, ClusterService cluster final NormalizerProcessFactory normalizerProcessFactory; final AnalyticsProcessFactory analyticsProcessFactory; final AnalyticsProcessFactory memoryEstimationProcessFactory; + final PyTorchProcessFactory pyTorchProcessFactory; if (MachineLearningField.AUTODETECT_PROCESS.get(settings)) { try { NativeController nativeController = @@ -711,6 +726,7 @@ public Collection createComponents(Client client, ClusterService cluster dataFrameAnalyticsAuditor); memoryEstimationProcessFactory = new NativeMemoryUsageEstimationProcessFactory(environment, nativeController, clusterService); + pyTorchProcessFactory = new NativePyTorchProcessFactory(environment, nativeController, clusterService); mlController = nativeController; } catch (IOException e) { // The low level cause of failure from the named pipe helper's perspective is almost never the real root cause, so @@ -730,6 +746,7 @@ public Collection createComponents(Client client, ClusterService cluster normalizerProcessFactory = (jobId, quantilesState, bucketSpan, executorService) -> new MultiplyingNormalizerProcess(1.0); analyticsProcessFactory = (jobId, analyticsProcessConfig, hasState, executorService, onProcessCrash) -> null; memoryEstimationProcessFactory = (jobId, analyticsProcessConfig, hasState, executorService, onProcessCrash) -> null; + pyTorchProcessFactory = (jobId, executorService, onProcessCrash) -> null; } NormalizerFactory normalizerFactory = new NormalizerFactory(normalizerProcessFactory, threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)); @@ -770,6 +787,7 @@ public Collection createComponents(Client client, ClusterService cluster clusterService.getNodeName(), inferenceModelBreaker.get()); this.modelLoadingService.set(modelLoadingService); + this.deploymentManager.set(new DeploymentManager(client, xContentRegistry, threadPool, pyTorchProcessFactory)); // Data frame analytics components AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager( @@ -874,7 +892,12 @@ public List> getPersistentTasksExecutor(ClusterServic autodetectProcessManager.get(), memoryTracker.get(), expressionResolver, - client) + client), + new TransportStartTrainedModelDeploymentAction.TaskExecutor(settings, + clusterService, + expressionResolver, + memoryTracker.get(), + deploymentManager.get()) ); } @@ -950,6 +973,9 @@ public List getRestHandlers(Settings settings, RestController restC new RestPutTrainedModelAliasAction(), new RestDeleteTrainedModelAliasAction(), new RestPreviewDataFrameAnalyticsAction(), + new RestStartTrainedModelDeploymentAction(), + new RestStopTrainedModelDeploymentAction(), + new RestInferTrainedModelDeploymentAction(), // CAT Handlers new RestCatJobsAction(), new RestCatTrainedModelsAction(), @@ -1037,6 +1063,9 @@ public List getRestHandlers(Settings settings, RestController restC new ActionHandler<>(DeleteTrainedModelAliasAction.INSTANCE, TransportDeleteTrainedModelAliasAction.class), new ActionHandler<>(PreviewDataFrameAnalyticsAction.INSTANCE, TransportPreviewDataFrameAnalyticsAction.class), new ActionHandler<>(SetResetModeAction.INSTANCE, TransportSetResetModeAction.class), + new ActionHandler<>(StartTrainedModelDeploymentAction.INSTANCE, TransportStartTrainedModelDeploymentAction.class), 
+ new ActionHandler<>(StopTrainedModelDeploymentAction.INSTANCE, TransportStopTrainedModelDeploymentAction.class), + new ActionHandler<>(InferTrainedModelDeploymentAction.INSTANCE, TransportInferTrainedModelDeploymentAction.class), usageAction, infoAction); } @@ -1175,6 +1204,8 @@ public List getNamedWriteables() { StartDataFrameAnalyticsAction.TaskParams::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(PersistentTaskParams.class, MlTasks.JOB_SNAPSHOT_UPGRADE_TASK_NAME, SnapshotUpgradeTaskParams::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(PersistentTaskParams.class, MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME, + StartTrainedModelDeploymentAction.TaskParams::new)); // Persistent task states namedWriteables.add(new NamedWriteableRegistry.Entry(PersistentTaskState.class, JobTaskState.NAME, JobTaskState::new)); @@ -1184,6 +1215,8 @@ public List getNamedWriteables() { namedWriteables.add(new NamedWriteableRegistry.Entry(PersistentTaskState.class, SnapshotUpgradeTaskState.NAME, SnapshotUpgradeTaskState::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(PersistentTaskState.class, + TrainedModelDeploymentTaskState.NAME, TrainedModelDeploymentTaskState::new)); namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables()); namedWriteables.addAll(new AnalysisStatsNamedWriteablesProvider().getNamedWriteables()); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java new file mode 100644 index 0000000000000..7ff1157706079 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java @@ -0,0 +1,81 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.FailedNodeException; +import org.elasticsearch.action.TaskOperationFailure; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.tasks.TransportTasksAction; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.persistent.PersistentTasksCustomMetadata; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask; + +import java.util.List; + +public class TransportInferTrainedModelDeploymentAction extends TransportTasksAction<TrainedModelDeploymentTask, InferTrainedModelDeploymentAction.Request, InferTrainedModelDeploymentAction.Response, InferTrainedModelDeploymentAction.Response> { + + @Inject + public TransportInferTrainedModelDeploymentAction(ClusterService clusterService, TransportService transportService, + ActionFilters actionFilters) { + super(InferTrainedModelDeploymentAction.NAME, clusterService, transportService, actionFilters, + InferTrainedModelDeploymentAction.Request::new, InferTrainedModelDeploymentAction.Response::new, + InferTrainedModelDeploymentAction.Response::new, ThreadPool.Names.SAME); + } + + @Override + protected void doExecute(Task task, InferTrainedModelDeploymentAction.Request request, + ActionListener<InferTrainedModelDeploymentAction.Response> listener) { + String deploymentId = request.getDeploymentId(); + // We need to check whether there is at least an assigned task here, otherwise we cannot redirect to the + // node running the deployment task. + PersistentTasksCustomMetadata tasks = clusterService.state().getMetadata().custom(PersistentTasksCustomMetadata.TYPE); + PersistentTasksCustomMetadata.PersistentTask<?> deploymentTask = MlTasks.getTrainedModelDeploymentTask(deploymentId, tasks); + if (deploymentTask == null || deploymentTask.isAssigned() == false) { + String message = "Cannot perform requested action because deployment [" + deploymentId + "] is not started"; + listener.onFailure(ExceptionsHelper.conflictStatusException(message)); + } else { + request.setNodes(deploymentTask.getExecutorNode()); + super.doExecute(task, request, listener); + } + } + + @Override + protected InferTrainedModelDeploymentAction.Response newResponse(InferTrainedModelDeploymentAction.Request request, + List<InferTrainedModelDeploymentAction.Response> tasks, + List<TaskOperationFailure> taskOperationFailures, + List<FailedNodeException> failedNodeExceptions) { + if (taskOperationFailures.isEmpty() == false) { + throw org.elasticsearch.ExceptionsHelper.convertToElastic(taskOperationFailures.get(0).getCause()); + } else if (failedNodeExceptions.isEmpty() == false) { + throw org.elasticsearch.ExceptionsHelper.convertToElastic(failedNodeExceptions.get(0)); + } else { + return tasks.get(0); + } + } + + @Override + protected void taskOperation(InferTrainedModelDeploymentAction.Request request, TrainedModelDeploymentTask task, + ActionListener<InferTrainedModelDeploymentAction.Response> listener) { + TimeValue timeout = request.getTimeout() == null ?
TimeValue.timeValueSeconds(10) : request.getTimeout(); + task.infer(request.getInput(), timeout, + ActionListener.wrap( + pyTorchResult -> listener.onResponse(new InferTrainedModelDeploymentAction.Response(pyTorchResult)), + listener::onFailure) + ); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java index fed61310f1c35..2be3eb4e16df3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java @@ -35,6 +35,7 @@ import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction.Request; import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction.Response; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -73,69 +74,97 @@ protected void masterOperation(Task task, PutTrainedModelAction.Request request, ClusterState state, ActionListener listener) { + TrainedModelConfig config = request.getTrainedModelConfig(); try { - request.getTrainedModelConfig().ensureParsedDefinition(xContentRegistry); - request.getTrainedModelConfig().getModelDefinition().getTrainedModel().validate(); + config.ensureParsedDefinition(xContentRegistry); } catch (IOException ex) { listener.onFailure(ExceptionsHelper.badRequestException("Failed to parse definition for [{}]", ex, - request.getTrainedModelConfig().getModelId())); - return; - } catch (ElasticsearchException ex) { - listener.onFailure(ExceptionsHelper.badRequestException("Definition for [{}] has validation failures.", - ex, - request.getTrainedModelConfig().getModelId())); + config.getModelId())); return; } - if (request.getTrainedModelConfig() - .getInferenceConfig() - .isTargetTypeSupported(request.getTrainedModelConfig() - .getModelDefinition() - .getTrainedModel() - .targetType()) == false) { - listener.onFailure(ExceptionsHelper.badRequestException( - "Model [{}] inference config type [{}] does not support definition target type [{}]", - request.getTrainedModelConfig().getModelId(), - request.getTrainedModelConfig().getInferenceConfig().getName(), - request.getTrainedModelConfig() + + boolean hasModelDefinition = config.getModelDefinition() != null; + if (hasModelDefinition) { + try { + config.getModelDefinition().getTrainedModel().validate(); + } catch (ElasticsearchException ex) { + listener.onFailure(ExceptionsHelper.badRequestException("Definition for [{}] has validation failures.", + ex, + config.getModelId())); + return; + } + + TrainedModelType trainedModelType = + TrainedModelType.typeFromTrainedModel(config.getModelDefinition().getTrainedModel()); + if (trainedModelType == null) { + listener.onFailure(ExceptionsHelper.badRequestException("Unknown trained model definition class [{}]", + config.getModelDefinition().getTrainedModel().getName())); + return; + } + + if (config.getModelType() == null) { + // Set the model type from the definition + config = new TrainedModelConfig.Builder(config).setModelType(trainedModelType).build(); + } else if (trainedModelType != config.getModelType()) { + 
listener.onFailure(ExceptionsHelper.badRequestException( + "{} [{}] does not match the model definition type [{}]", + TrainedModelConfig.MODEL_TYPE.getPreferredName(), config.getModelType(), + trainedModelType)); + return; + } + + if (config.getInferenceConfig() + .isTargetTypeSupported(config .getModelDefinition() .getTrainedModel() - .targetType())); - return; + .targetType()) == false) { + listener.onFailure(ExceptionsHelper.badRequestException( + "Model [{}] inference config type [{}] does not support definition target type [{}]", + config.getModelId(), + config.getInferenceConfig().getName(), + config.getModelDefinition().getTrainedModel().targetType())); + return; + } + + Version minCompatibilityVersion = config + .getModelDefinition() + .getTrainedModel() + .getMinimalCompatibilityVersion(); + if (state.nodes().getMinNodeVersion().before(minCompatibilityVersion)) { + listener.onFailure(ExceptionsHelper.badRequestException( + "Definition for [{}] requires that all nodes are at least version [{}]", + config.getModelId(), + minCompatibilityVersion.toString())); + return; + } } - Version minCompatibilityVersion = request.getTrainedModelConfig() - .getModelDefinition() - .getTrainedModel() - .getMinimalCompatibilityVersion(); - if (state.nodes().getMinNodeVersion().before(minCompatibilityVersion)) { - listener.onFailure(ExceptionsHelper.badRequestException( - "Definition for [{}] requires that all nodes are at least version [{}]", - request.getTrainedModelConfig().getModelId(), - minCompatibilityVersion.toString())); - return; - } - TrainedModelConfig trainedModelConfig = new TrainedModelConfig.Builder(request.getTrainedModelConfig()) + + + TrainedModelConfig.Builder trainedModelConfig = new TrainedModelConfig.Builder(config) .setVersion(Version.CURRENT) .setCreateTime(Instant.now()) .setCreatedBy("api_user") - .setLicenseLevel(License.OperationMode.PLATINUM.description()) - .setEstimatedHeapMemory(request.getTrainedModelConfig().getModelDefinition().ramBytesUsed()) - .setEstimatedOperations(request.getTrainedModelConfig().getModelDefinition().getTrainedModel().estimatedNumOperations()) - .build(); + .setLicenseLevel(License.OperationMode.PLATINUM.description()); + if (hasModelDefinition) { + trainedModelConfig.setEstimatedHeapMemory(config.getModelDefinition().ramBytesUsed()) + .setEstimatedOperations(config.getModelDefinition().getTrainedModel().estimatedNumOperations()); + } + if (ModelAliasMetadata.fromState(state).getModelId(trainedModelConfig.getModelId()) != null) { listener.onFailure(ExceptionsHelper.badRequestException( "requested model_id [{}] is the same as an existing model_alias. 
Model model_aliases and ids must be unique", - request.getTrainedModelConfig().getModelId() + config.getModelId() )); return; } ActionListener tagsModelIdCheckListener = ActionListener.wrap( - r -> trainedModelProvider.storeTrainedModel(trainedModelConfig, ActionListener.wrap( + r -> trainedModelProvider.storeTrainedModel(trainedModelConfig.build(), ActionListener.wrap( bool -> { - TrainedModelConfig configToReturn = new TrainedModelConfig.Builder(trainedModelConfig).clearDefinition().build(); + TrainedModelConfig configToReturn = trainedModelConfig.clearDefinition().build(); listener.onResponse(new PutTrainedModelAction.Response(configToReturn)); }, listener::onFailure @@ -148,7 +177,7 @@ protected void masterOperation(Task task, listener::onFailure ); - checkModelIdAgainstTags(request.getTrainedModelConfig().getModelId(), modelIdTagCheckListener); + checkModelIdAgainstTags(config.getModelId(), modelIdTagCheckListener); } private void checkModelIdAgainstTags(String modelId, ActionListener listener) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java new file mode 100644 index 0000000000000..66d090daf2e83 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java @@ -0,0 +1,328 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.action; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.ResourceAlreadyExistsException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.master.TransportMasterNodeAction; +import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.block.ClusterBlockException; +import org.elasticsearch.cluster.block.ClusterBlockLevel; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.persistent.AllocatedPersistentTask; +import org.elasticsearch.persistent.PersistentTaskParams; +import org.elasticsearch.persistent.PersistentTaskState; +import org.elasticsearch.persistent.PersistentTasksCustomMetadata; +import org.elasticsearch.persistent.PersistentTasksService; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.XPackField; +import org.elasticsearch.xpack.core.ml.MlTasks; 
+import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; +import org.elasticsearch.xpack.core.ml.action.NodeAcknowledgedResponse; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.TaskParams; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; +import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentState; +import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentTaskState; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager; +import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask; +import org.elasticsearch.xpack.ml.job.JobNodeSelector; +import org.elasticsearch.xpack.ml.process.MlMemoryTracker; +import org.elasticsearch.xpack.ml.task.AbstractJobPersistentTasksExecutor; + +import java.util.Collection; +import java.util.Map; +import java.util.Objects; +import java.util.function.Predicate; + +public class TransportStartTrainedModelDeploymentAction + extends TransportMasterNodeAction { + + private static final Logger logger = LogManager.getLogger(TransportStartTrainedModelDeploymentAction.class); + + private final XPackLicenseState licenseState; + private final Client client; + private final PersistentTasksService persistentTasksService; + private final NamedXContentRegistry xContentRegistry; + + @Inject + public TransportStartTrainedModelDeploymentAction(TransportService transportService, Client client, ClusterService clusterService, + ThreadPool threadPool, ActionFilters actionFilters, XPackLicenseState licenseState, + IndexNameExpressionResolver indexNameExpressionResolver, + PersistentTasksService persistentTasksService, + NamedXContentRegistry xContentRegistry) { + super(StartTrainedModelDeploymentAction.NAME, transportService, clusterService, threadPool, actionFilters, + StartTrainedModelDeploymentAction.Request::new, indexNameExpressionResolver, NodeAcknowledgedResponse::new, + ThreadPool.Names.SAME); + this.licenseState = Objects.requireNonNull(licenseState); + this.client = Objects.requireNonNull(client); + this.persistentTasksService = Objects.requireNonNull(persistentTasksService); + this.xContentRegistry = Objects.requireNonNull(xContentRegistry); + } + + @Override + protected void masterOperation(Task task, StartTrainedModelDeploymentAction.Request request, ClusterState state, + ActionListener listener) throws Exception { + logger.debug(() -> new ParameterizedMessage("[{}] received deploy request", request.getModelId())); + if (licenseState.checkFeature(XPackLicenseState.Feature.MACHINE_LEARNING) == false) { + listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING)); + return; + } + + ActionListener> waitForDeploymentToStart = + ActionListener.wrap( + startedTask -> waitForDeploymentStarted(startedTask, request.getTimeout(), listener), + e -> { + if (ExceptionsHelper.unwrapCause(e) instanceof ResourceAlreadyExistsException) { + e = new ElasticsearchStatusException( + "Cannot start deployment [{}] because it has already been started", + RestStatus.CONFLICT, + e, + request.getModelId() + ); + } + listener.onFailure(e); + } + ); + + ActionListener 
getModelListener = ActionListener.wrap( + getModelResponse -> { + if (getModelResponse.getResources().results().size() > 1) { + listener.onFailure(ExceptionsHelper.badRequestException( + "cannot deploy more than one model at the same time; [{}] matches [{}] models", + request.getModelId(), getModelResponse.getResources().results().size())); + return; + } + + + TrainedModelConfig trainedModelConfig = getModelResponse.getResources().results().get(0); + if (trainedModelConfig.getModelType() != TrainedModelType.PYTORCH) { + listener.onFailure(ExceptionsHelper.badRequestException( + "model [{}] of type [{}] cannot be deployed. Only PyTorch models can be deployed", + trainedModelConfig.getModelId(), trainedModelConfig.getModelType())); + return; + } + + if (trainedModelConfig.getLocation() == null) { + listener.onFailure(ExceptionsHelper.serverError( + "model [{}] does not have a location", trainedModelConfig.getModelId())); + return; + } + + persistentTasksService.sendStartRequest( + MlTasks.trainedModelDeploymentTaskId(request.getModelId()), + MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME, + new TaskParams(trainedModelConfig.getLocation().getModelId(), trainedModelConfig.getLocation().getResourceName()), + waitForDeploymentToStart + ); + }, + listener::onFailure + ); + + GetTrainedModelsAction.Request getModelRequest = new GetTrainedModelsAction.Request(request.getModelId()); + client.execute(GetTrainedModelsAction.INSTANCE, getModelRequest, getModelListener); + } + + private void waitForDeploymentStarted(PersistentTasksCustomMetadata.PersistentTask<TaskParams> task, + TimeValue timeout, ActionListener<NodeAcknowledgedResponse> listener) { + DeploymentStartedPredicate predicate = new DeploymentStartedPredicate(); + persistentTasksService.waitForPersistentTaskCondition(task.getId(), predicate, timeout, + new PersistentTasksService.WaitForPersistentTaskListener<TaskParams>() { + @Override + public void onResponse(PersistentTasksCustomMetadata.PersistentTask<TaskParams> persistentTask) { + if (predicate.exception != null) { + cancelDeploymentStart(task, predicate.exception, listener); + } else { + listener.onResponse(new NodeAcknowledgedResponse(true, predicate.node)); + } + } + + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + }); + } + + private void cancelDeploymentStart( + PersistentTasksCustomMetadata.PersistentTask<TaskParams> persistentTask, Exception exception, + ActionListener<NodeAcknowledgedResponse> listener) { + persistentTasksService.sendRemoveRequest(persistentTask.getId(), ActionListener.wrap( + pTask -> listener.onFailure(exception), + e -> { + logger.error( + new ParameterizedMessage("[{}] Failed to cancel persistent task that could not be assigned due to [{}]", + persistentTask.getParams().getModelId(), exception.getMessage()), + e + ); + listener.onFailure(exception); + } + )); + + } + + @Override + protected ClusterBlockException checkBlock(StartTrainedModelDeploymentAction.Request request, ClusterState state) { + // We only delegate here to PersistentTasksService, but if there is a metadata write block, + // then delegating to PersistentTasksService doesn't make a whole lot of sense, + // because PersistentTasksService will then fail.
+ return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); + } + + private static class DeploymentStartedPredicate implements Predicate<PersistentTasksCustomMetadata.PersistentTask<?>> { + + private volatile Exception exception; + private volatile String node = ""; + private volatile String assignmentExplanation; + + @Override + public boolean test(PersistentTasksCustomMetadata.PersistentTask<?> persistentTask) { + if (persistentTask == null) { + return false; + } + + PersistentTasksCustomMetadata.Assignment assignment = persistentTask.getAssignment(); + + String reason = "__unknown__"; + + if (assignment != null) { + if (assignment.equals(JobNodeSelector.AWAITING_LAZY_ASSIGNMENT)) { + return true; + } + if (assignment.equals(PersistentTasksCustomMetadata.INITIAL_ASSIGNMENT) == false && assignment.isAssigned() == false) { + exception = new ElasticsearchStatusException("Could not start trained model deployment, allocation explanation [{}]", + RestStatus.TOO_MANY_REQUESTS, assignment.getExplanation()); + return true; + } + } + + TrainedModelDeploymentTaskState taskState = (TrainedModelDeploymentTaskState) persistentTask.getState(); + reason = taskState != null ? taskState.getReason() : reason; + TrainedModelDeploymentState deploymentState = taskState == null ? TrainedModelDeploymentState.STARTED : taskState.getState(); + switch (deploymentState) { + case STARTED: + node = persistentTask.getExecutorNode(); + return true; + case STARTING: + case STOPPING: + case STOPPED: + return false; + default: + exception = ExceptionsHelper.serverError("Unexpected task state [{}] with reason [{}] while waiting to be started", + taskState.getState(), reason); + return true; + } + } + } + + public static class TaskExecutor extends AbstractJobPersistentTasksExecutor<TaskParams> { + + private final DeploymentManager manager; + + public TaskExecutor(Settings settings, ClusterService clusterService, IndexNameExpressionResolver expressionResolver, + MlMemoryTracker memoryTracker, DeploymentManager manager) { + super(MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME, + MachineLearning.UTILITY_THREAD_POOL_NAME, + settings, + clusterService, + memoryTracker, + expressionResolver); + this.manager = Objects.requireNonNull(manager); + } + + @Override + protected AllocatedPersistentTask createTask( + long id, String type, String action, TaskId parentTaskId, + PersistentTasksCustomMetadata.PersistentTask<TaskParams> persistentTask, + Map<String, String> headers) { + return new TrainedModelDeploymentTask(id, type, action, parentTaskId, headers, persistentTask.getParams()); + } + + @Override + public PersistentTasksCustomMetadata.Assignment getAssignment(TaskParams params, + Collection<DiscoveryNode> candidateNodes, + ClusterState clusterState) { + JobNodeSelector jobNodeSelector = + new JobNodeSelector( + clusterState, + candidateNodes, + params.getModelId(), + MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME, + memoryTracker, + 0, + node -> nodeFilter(node, params)); + PersistentTasksCustomMetadata.Assignment assignment = jobNodeSelector.selectNode( + maxOpenJobs, + Integer.MAX_VALUE, + maxMachineMemoryPercent, + maxNodeMemory, + useAutoMemoryPercentage + ); + return assignment; + } + + public static String nodeFilter(DiscoveryNode node, TaskParams params) { + String id = params.getModelId(); + + if (node.getVersion().before(TaskParams.VERSION_INTRODUCED)) { + return "Not opening job [" + id + "] on node [" + JobNodeSelector.nodeNameAndVersion(node) + + "], because the trained model deployment requires a node of version [" + + TaskParams.VERSION_INTRODUCED + "] or higher"; + } + + return null; + } + + @Override +
protected void nodeOperation(AllocatedPersistentTask task, TaskParams params, PersistentTaskState state) { + TrainedModelDeploymentTask trainedModelDeploymentTask = (TrainedModelDeploymentTask) task; + trainedModelDeploymentTask.setDeploymentManager(manager); + + TrainedModelDeploymentTaskState deployingState = new TrainedModelDeploymentTaskState( + TrainedModelDeploymentState.STARTING, task.getAllocationId(), null); + task.updatePersistentTaskState(deployingState, ActionListener.wrap( + response -> manager.startDeployment(trainedModelDeploymentTask), + task::markAsFailed + )); + } + + @Override + protected String[] indicesOfInterest(TaskParams params) { + return new String[] { + InferenceIndexConstants.INDEX_PATTERN + }; + } + + @Override + protected String getJobId(TaskParams params) { + return params.getModelId(); + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java new file mode 100644 index 0000000000000..30b3d663dd5b6 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java @@ -0,0 +1,202 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.action; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionListenerResponseHandler; +import org.elasticsearch.action.FailedNodeException; +import org.elasticsearch.action.TaskOperationFailure; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.tasks.TransportTasksAction; +import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.util.concurrent.AbstractRunnable; +import org.elasticsearch.discovery.MasterNotDiscoveredException; +import org.elasticsearch.persistent.PersistentTasksCustomMetadata; +import org.elasticsearch.persistent.PersistentTasksService; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; +import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentState; +import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentTaskState; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask; + +import java.util.Collections; +import java.util.List; +import java.util.Set; + +public class 
TransportStopTrainedModelDeploymentAction extends TransportTasksAction { + + private static final Logger logger = LogManager.getLogger(TransportStopTrainedModelDeploymentAction.class); + + private final Client client; + private final ThreadPool threadPool; + private final PersistentTasksService persistentTasksService; + + @Inject + public TransportStopTrainedModelDeploymentAction(ClusterService clusterService, TransportService transportService, + ActionFilters actionFilters, Client client, ThreadPool threadPool, + PersistentTasksService persistentTasksService) { + super(StopTrainedModelDeploymentAction.NAME, clusterService, transportService, actionFilters, + StopTrainedModelDeploymentAction.Request::new, StopTrainedModelDeploymentAction.Response::new, + StopTrainedModelDeploymentAction.Response::new, ThreadPool.Names.SAME); + this.client = client; + this.threadPool = threadPool; + this.persistentTasksService = persistentTasksService; + } + + @Override + protected void doExecute(Task task, StopTrainedModelDeploymentAction.Request request, + ActionListener listener) { + ClusterState state = clusterService.state(); + DiscoveryNodes nodes = state.nodes(); + if (nodes.isLocalNodeElectedMaster() == false) { + redirectToMasterNode(nodes.getMasterNode(), request, listener); + return; + } + + logger.debug("[{}] Received request to undeploy", request.getId()); + + ActionListener getModelListener = ActionListener.wrap( + getModelsResponse -> { + List models = getModelsResponse.getResources().results(); + if (models.isEmpty()) { + listener.onResponse(new StopTrainedModelDeploymentAction.Response(true)); + return; + } + if (models.size() > 1) { + listener.onFailure(ExceptionsHelper.badRequestException("cannot undeploy multiple models at the same time")); + return; + } + + ClusterState clusterState = clusterService.state(); + PersistentTasksCustomMetadata tasks = clusterState.getMetadata().custom(PersistentTasksCustomMetadata.TYPE); + PersistentTasksCustomMetadata.PersistentTask deployTrainedModelTask = + MlTasks.getTrainedModelDeploymentTask(request.getId(), tasks); + if (deployTrainedModelTask == null) { + listener.onResponse(new StopTrainedModelDeploymentAction.Response(true)); + return; + } + normalUndeploy(task, deployTrainedModelTask, request, listener); + }, + listener::onFailure + ); + + GetTrainedModelsAction.Request getModelRequest = new GetTrainedModelsAction.Request( + request.getId(), null, Collections.emptySet()); + getModelRequest.setAllowNoResources(request.isAllowNoMatch()); + client.execute(GetTrainedModelsAction.INSTANCE, getModelRequest, getModelListener); + } + + private void redirectToMasterNode(DiscoveryNode masterNode, StopTrainedModelDeploymentAction.Request request, + ActionListener listener) { + if (masterNode == null) { + listener.onFailure(new MasterNotDiscoveredException()); + } else { + transportService.sendRequest(masterNode, actionName, request, + new ActionListenerResponseHandler<>(listener, StopTrainedModelDeploymentAction.Response::new)); + } + } + + private void normalUndeploy(Task task, PersistentTasksCustomMetadata.PersistentTask deployTrainedModelTask, + StopTrainedModelDeploymentAction.Request request, + ActionListener listener) { + request.setNodes(deployTrainedModelTask.getExecutorNode()); + + ActionListener finalListener = ActionListener.wrap( + r -> waitForTaskRemoved(Collections.singleton(deployTrainedModelTask.getId()), request, r, listener), + e -> { + if (ExceptionsHelper.unwrapCause(e) instanceof FailedNodeException) { + // A node has dropped out of the 
cluster since we started executing the requests. + // Since undeploying an already undeployed trained model is not an error we can try again. + // The tasks that were running on the node that dropped out of the cluster + // will just have their persistent tasks cancelled. Tasks that were stopped + // by the previous attempt will be noops in the subsequent attempt. + doExecute(task, request, listener); + } else { + listener.onFailure(e); + } + } + ); + + super.doExecute(task, request, finalListener); + } + + void waitForTaskRemoved(Set taskIds, StopTrainedModelDeploymentAction.Request request, + StopTrainedModelDeploymentAction.Response response, + ActionListener listener) { + persistentTasksService.waitForPersistentTasksCondition(persistentTasks -> + persistentTasks.findTasks(MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME, t -> taskIds.contains(t.getId())).isEmpty(), + request.getTimeout(), ActionListener.wrap( + booleanResponse -> { + listener.onResponse(response); + }, + listener::onFailure + ) + ); + } + + @Override + protected StopTrainedModelDeploymentAction.Response newResponse(StopTrainedModelDeploymentAction.Request request, + List tasks, + List taskOperationFailures, + List failedNodeExceptions) { + if (taskOperationFailures.isEmpty() == false) { + throw org.elasticsearch.ExceptionsHelper.convertToElastic(taskOperationFailures.get(0).getCause()); + } else if (failedNodeExceptions.isEmpty() == false) { + throw org.elasticsearch.ExceptionsHelper.convertToElastic(failedNodeExceptions.get(0)); + } else { + return new StopTrainedModelDeploymentAction.Response(true); + } + } + + @Override + protected void taskOperation(StopTrainedModelDeploymentAction.Request request, TrainedModelDeploymentTask task, + ActionListener listener) { + TrainedModelDeploymentTaskState undeployingState = new TrainedModelDeploymentTaskState( + TrainedModelDeploymentState.STOPPING, task.getAllocationId(), "api"); + task.updatePersistentTaskState(undeployingState, ActionListener.wrap( + updatedTask -> { + threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(new AbstractRunnable() { + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + + @Override + protected void doRun() throws Exception { + task.stop("undeploy_trained_model (api)"); + listener.onResponse(new StopTrainedModelDeploymentAction.Response(true)); + } + }); + }, + e -> { + if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { + // the task has disappeared so must have stopped + listener.onResponse(new StopTrainedModelDeploymentAction.Response(true)); + } else { + listener.onFailure(e); + } + } + )); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index 59014fc91d8af..6bcfa8ba7651a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -15,6 +15,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage; import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats; import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import 
org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; @@ -142,7 +143,7 @@ private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJo } ModelSizeInfo modelSize = result.getModelSizeInfo(); if (modelSize != null) { - latestModelId = chunkedTrainedModelPersister.createAndIndexInferenceModelConfig(modelSize); + latestModelId = chunkedTrainedModelPersister.createAndIndexInferenceModelConfig(modelSize, TrainedModelType.TREE_ENSEMBLE); } TrainedModelDefinitionChunk trainedModelDefinitionChunk = result.getTrainedModelDefinitionChunk(); if (trainedModelDefinitionChunk != null && isCancelled == false) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java index 6a99ffe24bd5d..9da146b14b0f7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java @@ -24,6 +24,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata; import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -101,14 +102,14 @@ public void createAndIndexInferenceModelDoc(TrainedModelDefinitionChunk trainedM } } - public String createAndIndexInferenceModelConfig(ModelSizeInfo inferenceModelSize) { + public String createAndIndexInferenceModelConfig(ModelSizeInfo inferenceModelSize, TrainedModelType trainedModelType) { if (readyToStoreNewModel.compareAndSet(true, false) == false) { failureHandler.accept(ExceptionsHelper.serverError( "new inference model is attempting to be stored before completion previous model storage" )); return null; } - TrainedModelConfig trainedModelConfig = createTrainedModelConfig(inferenceModelSize); + TrainedModelConfig trainedModelConfig = createTrainedModelConfig(trainedModelType, inferenceModelSize); CountDownLatch latch = storeTrainedModelConfig(trainedModelConfig); try { if (latch.await(STORE_TIMEOUT_SEC, TimeUnit.SECONDS) == false) { @@ -295,7 +296,7 @@ private long customProcessorSize() { + RamUsageEstimator.NUM_BYTES_OBJECT_REF * preProcessors.size(); } - private TrainedModelConfig createTrainedModelConfig(ModelSizeInfo modelSize) { + private TrainedModelConfig createTrainedModelConfig(TrainedModelType trainedModelType, ModelSizeInfo modelSize) { Instant createTime = Instant.now(); // The native process does not provide estimates for the custom feature_processor objects long customProcessorSize = customProcessorSize(); @@ -312,6 +313,7 @@ private TrainedModelConfig createTrainedModelConfig(ModelSizeInfo modelSize) { .collect(Collectors.toMap(ExtractedField::getParentField, ExtractedField::getName)); return TrainedModelConfig.builder() .setModelId(modelId) + .setModelType(trainedModelType) .setCreatedBy(XPackUser.NAME) .setVersion(Version.CURRENT) .setCreateTime(createTime) diff --git 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java new file mode 100644 index 0000000000000..0316271818164 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java @@ -0,0 +1,267 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.deployment; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.search.SearchAction; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.util.concurrent.AbstractRunnable; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.query.IdsQueryBuilder; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ml.inference.deployment.PyTorchResult; +import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentState; +import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentTaskState; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.inference.nlp.NlpTask; +import org.elasticsearch.xpack.ml.inference.nlp.NlpTaskConfig; +import org.elasticsearch.xpack.ml.inference.pytorch.process.NativePyTorchProcess; +import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcessFactory; +import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchResultProcessor; +import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchStateStreamer; + +import java.io.IOException; +import java.io.InputStream; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Consumer; + +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; +import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; + +public class DeploymentManager { + + private static final Logger logger = LogManager.getLogger(DeploymentManager.class); + private static final AtomicLong requestIdCounter 
= new AtomicLong(1); + + private final Client client; + private final NamedXContentRegistry xContentRegistry; + private final PyTorchProcessFactory pyTorchProcessFactory; + private final ExecutorService executorServiceForDeployment; + private final ExecutorService executorServiceForProcess; + private final ConcurrentMap processContextByAllocation = new ConcurrentHashMap<>(); + + public DeploymentManager(Client client, NamedXContentRegistry xContentRegistry, + ThreadPool threadPool, PyTorchProcessFactory pyTorchProcessFactory) { + this.client = Objects.requireNonNull(client); + this.xContentRegistry = Objects.requireNonNull(xContentRegistry); + this.pyTorchProcessFactory = Objects.requireNonNull(pyTorchProcessFactory); + this.executorServiceForDeployment = threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME); + this.executorServiceForProcess = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME); + } + + public void startDeployment(TrainedModelDeploymentTask task) { + executorServiceForDeployment.execute(() -> doStartDeployment(task)); + } + + private void doStartDeployment(TrainedModelDeploymentTask task) { + logger.debug("[{}] Starting model deployment", task.getModelId()); + + ProcessContext processContext = new ProcessContext(task.getModelId(), task.getIndex(), executorServiceForProcess); + + if (processContextByAllocation.putIfAbsent(task.getAllocationId(), processContext) != null) { + throw ExceptionsHelper.serverError("[{}] Could not create process as one already exists", task.getModelId()); + } + + ActionListener modelLoadedListener = ActionListener.wrap( + success -> { + executorServiceForProcess.execute(() -> processContext.resultProcessor.process(processContext.process.get())); + + TrainedModelDeploymentTaskState startedState = new TrainedModelDeploymentTaskState( + TrainedModelDeploymentState.STARTED, task.getAllocationId(), null); + task.updatePersistentTaskState(startedState, ActionListener.wrap( + response -> logger.info("[{}] trained model loaded", task.getModelId()), + task::markAsFailed + )); + }, + e -> failTask(task, e) + ); + + ActionListener configListener = ActionListener.wrap( + searchResponse -> { + if (searchResponse.getHits().getHits().length == 0) { + failTask(task, new ResourceNotFoundException( + Messages.getMessage(Messages.TASK_CONFIG_NOT_FOUND, task.getModelId()))); + return; + } + + NlpTaskConfig config = parseConfigDocLeniently(searchResponse.getHits().getAt(0)); + NlpTask nlpTask = NlpTask.fromConfig(config); + processContext.nlpTask.set(nlpTask); + processContext.startProcess(); + processContext.loadModel(modelLoadedListener); + }, + e -> failTask(task, e) + ); + + SearchRequest searchRequest = taskConfigSearchRequest(task.getModelId(), task.getIndex()); + executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, configListener); + } + + private SearchRequest taskConfigSearchRequest(String modelId, String index) { + return client.prepareSearch(index) + .setQuery(new IdsQueryBuilder().addIds(NlpTaskConfig.documentId(modelId))) + .setSize(1) + .setTrackTotalHits(false) + .request(); + } + + + public NlpTaskConfig parseConfigDocLeniently(SearchHit hit) throws IOException { + + try (InputStream stream = hit.getSourceRef().streamInput(); + XContentParser parser = XContentFactory.xContent(XContentType.JSON) + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) { + return NlpTaskConfig.fromXContent(parser, true); + } catch (IOException e) { + logger.error(new ParameterizedMessage("failed to 
parse NLP task config [{}]", hit.getId()), e); + throw e; + } + } + + public void stopDeployment(TrainedModelDeploymentTask task) { + ProcessContext processContext; + synchronized (processContextByAllocation) { + processContext = processContextByAllocation.get(task.getAllocationId()); + } + if (processContext != null) { + logger.info("[{}] Stopping deployment", task.getModelId()); + processContext.stopProcess(); + } else { + logger.info("[{}] No process context to stop", task.getModelId()); + } + } + + public void infer(TrainedModelDeploymentTask task, + String input, TimeValue timeout, + ActionListener listener) { + ProcessContext processContext = processContextByAllocation.get(task.getAllocationId()); + + final String requestId = String.valueOf(requestIdCounter.getAndIncrement()); + + executorServiceForProcess.execute(new AbstractRunnable() { + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + + @Override + protected void doRun() { + try { + NlpTask.Processor processor = processContext.nlpTask.get().createProcessor(); + processor.validateInputs(input); + BytesReference request = processor.getRequestBuilder().buildRequest(input, requestId); + logger.trace("Inference Request "+ request.utf8ToString()); + processContext.process.get().writeInferenceRequest(request); + + waitForResult(processContext, requestId, timeout, processor.getResultProcessor(), listener); + } catch (IOException e) { + logger.error(new ParameterizedMessage("[{}] error writing to process", processContext.modelId), e); + onFailure(ExceptionsHelper.serverError("error writing to process", e)); + } + } + }); + } + + private void waitForResult(ProcessContext processContext, + String requestId, + TimeValue timeout, + NlpTask.ResultProcessor inferenceResultsProcessor, + ActionListener listener) { + try { + PyTorchResult pyTorchResult = processContext.resultProcessor.waitForResult(requestId, timeout); + if (pyTorchResult == null) { + listener.onFailure(new ElasticsearchStatusException("timeout [{}] waiting for inference result", + RestStatus.TOO_MANY_REQUESTS, timeout)); + return; + } + + if (pyTorchResult.isError()) { + listener.onFailure(new ElasticsearchStatusException(pyTorchResult.getError(), + RestStatus.INTERNAL_SERVER_ERROR)); + return; + } + + logger.debug(() -> new ParameterizedMessage("[{}] retrieved result for request [{}]", processContext.modelId, requestId)); + InferenceResults results = inferenceResultsProcessor.processResult(pyTorchResult); + logger.debug(() -> new ParameterizedMessage("[{}] processed result for request [{}]", processContext.modelId, requestId)); + listener.onResponse(results); + } catch (InterruptedException e) { + listener.onFailure(e); + } + } + + private void failTask(TrainedModelDeploymentTask task, Exception e) { + logger.error(new ParameterizedMessage("[{}] failing model deployment task with error: ", task.getModelId()), e); + task.markAsFailed(e); + } + + class ProcessContext { + + private final String modelId; + private final String index; + private final SetOnce process = new SetOnce<>(); + private final SetOnce nlpTask = new SetOnce<>(); + private final PyTorchResultProcessor resultProcessor; + private final PyTorchStateStreamer stateStreamer; + + ProcessContext(String modelId, String index, ExecutorService executorService) { + this.modelId = Objects.requireNonNull(modelId); + this.index = Objects.requireNonNull(index); + resultProcessor = new PyTorchResultProcessor(modelId); + this.stateStreamer = new PyTorchStateStreamer(client, executorService, 
xContentRegistry); + } + + synchronized void startProcess() { + process.set(pyTorchProcessFactory.createProcess(modelId, executorServiceForProcess, onProcessCrash())); + } + + synchronized void stopProcess() { + resultProcessor.stop(); + if (process.get() == null) { + return; + } + try { + stateStreamer.cancel(); + process.get().kill(true); + } catch (IOException e) { + logger.error(new ParameterizedMessage("[{}] Failed to kill process", modelId), e); + } + } + + private Consumer onProcessCrash() { + return reason -> logger.error("[{}] process crashed due to reason [{}]", modelId, reason); + } + + void loadModel(ActionListener listener) { + process.get().loadModel(modelId, index, stateStreamer, listener); + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java new file mode 100644 index 0000000000000..290b6f1a2e34e --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java @@ -0,0 +1,65 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.deployment; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.persistent.AllocatedPersistentTask; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.TaskParams; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; + +import java.util.Map; + +public class TrainedModelDeploymentTask extends AllocatedPersistentTask implements StartTrainedModelDeploymentAction.TaskMatcher { + + private static final Logger logger = LogManager.getLogger(TrainedModelDeploymentTask.class); + + private final TaskParams params; + private volatile DeploymentManager manager; + + public TrainedModelDeploymentTask(long id, String type, String action, TaskId parentTask, Map headers, + TaskParams taskParams) { + super(id, type, action, MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_ID_PREFIX + taskParams.getModelId(), parentTask, headers); + this.params = taskParams; + } + + public String getModelId() { + return params.getModelId(); + } + + public String getIndex() { + return params.getIndex(); + } + + public void stop(String reason) { + logger.debug("[{}] Stopping due to reason [{}]", getModelId(), reason); + + assert manager != null : "manager should not be unset when stop is called"; + manager.stopDeployment(this); + markAsCompleted(); + } + + public void setDeploymentManager(DeploymentManager manager) { + this.manager = manager; + } + + @Override + protected void onCancelled() { + String reason = getReasonCancelled(); + stop(reason); + } + + public void infer(String input, TimeValue timeout, ActionListener listener) { + manager.infer(this, input, timeout, listener); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilder.java 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilder.java new file mode 100644 index 0000000000000..09b754c98f31d --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilder.java @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.nlp; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer; + +import java.io.IOException; +import java.util.Arrays; + +public class BertRequestBuilder implements NlpTask.RequestBuilder { + + static final String REQUEST_ID = "request_id"; + static final String TOKENS = "tokens"; + static final String ARG1 = "arg_1"; + static final String ARG2 = "arg_2"; + static final String ARG3 = "arg_3"; + + private final BertTokenizer tokenizer; + private BertTokenizer.TokenizationResult tokenization; + + public BertRequestBuilder(BertTokenizer tokenizer) { + this.tokenizer = tokenizer; + } + + public BertTokenizer.TokenizationResult getTokenization() { + return tokenization; + } + + @Override + public BytesReference buildRequest(String input, String requestId) throws IOException { + tokenization = tokenizer.tokenize(input, true); + return jsonRequest(tokenization.getTokenIds(), requestId); + } + + static BytesReference jsonRequest(int[] tokens, String requestId) throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.field(REQUEST_ID, requestId); + builder.array(TOKENS, tokens); + + int[] inputMask = new int[tokens.length]; + Arrays.fill(inputMask, 1); + int[] segmentMask = new int[tokens.length]; + Arrays.fill(segmentMask, 0); + int[] positionalIds = new int[tokens.length]; + Arrays.setAll(positionalIds, i -> i); + + builder.array(ARG1, inputMask); + builder.array(ARG2, segmentMask); + builder.array(ARG3, positionalIds); + builder.endObject(); + + // BytesReference.bytes closes the builder + return BytesReference.bytes(builder); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java new file mode 100644 index 0000000000000..9333760594295 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java @@ -0,0 +1,75 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.inference.nlp; + +import org.elasticsearch.xpack.core.ml.inference.deployment.PyTorchResult; +import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public class FillMaskProcessor implements NlpTask.Processor { + + private static final int NUM_RESULTS = 5; + + private final BertRequestBuilder bertRequestBuilder; + + FillMaskProcessor(BertTokenizer tokenizer) { + this.bertRequestBuilder = new BertRequestBuilder(tokenizer); + } + + @Override + public void validateInputs(String inputs) { + if (inputs.isBlank()) { + throw new IllegalArgumentException("input request is empty"); + } + + int maskIndex = inputs.indexOf(BertTokenizer.MASK_TOKEN); + if (maskIndex < 0) { + throw new IllegalArgumentException("no " + BertTokenizer.MASK_TOKEN + " token could be found"); + } + + maskIndex = inputs.indexOf(BertTokenizer.MASK_TOKEN, maskIndex + BertTokenizer.MASK_TOKEN.length()); + if (maskIndex > 0) { + throw new IllegalArgumentException("only one " + BertTokenizer.MASK_TOKEN + " token should exist in the input"); + } + } + + @Override + public NlpTask.RequestBuilder getRequestBuilder() { + return bertRequestBuilder; + } + + @Override + public NlpTask.ResultProcessor getResultProcessor() { + return (pyTorchResult) -> processResult(bertRequestBuilder.getTokenization(), pyTorchResult); + } + + InferenceResults processResult(BertTokenizer.TokenizationResult tokenization, + PyTorchResult pyTorchResult) { + + if (tokenization.getTokens().isEmpty()) { + return new FillMaskResults(Collections.emptyList()); + } + + int maskTokenIndex = tokenization.getTokens().indexOf(BertTokenizer.MASK_TOKEN); + double[] normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(pyTorchResult.getInferenceResult()[maskTokenIndex]); + + NlpHelpers.ScoreAndIndex[] scoreAndIndices = NlpHelpers.topK(NUM_RESULTS, normalizedScores); + List results = new ArrayList<>(NUM_RESULTS); + for (NlpHelpers.ScoreAndIndex scoreAndIndex : scoreAndIndices) { + String predictedToken = tokenization.getFromVocab(scoreAndIndex.index); + String sequence = tokenization.getInput().replace(BertTokenizer.MASK_TOKEN, predictedToken); + results.add(new FillMaskResults.Prediction(predictedToken, scoreAndIndex.score, sequence)); + } + return new FillMaskResults(results); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java new file mode 100644 index 0000000000000..730b6f548ebcd --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java @@ -0,0 +1,81 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.inference.nlp; + +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer; + +import java.io.IOException; +import java.util.Locale; + +public class NerProcessor implements NlpTask.Processor { + + public enum Entity implements Writeable { + NONE, MISC, PERSON, ORGANISATION, LOCATION; + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeEnum(this); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } + } + + // Inside-Outside-Beginning (IOB) tag + enum IobTag { + O(Entity.NONE), // Outside of a named entity + B_MISC(Entity.MISC), // Beginning of a miscellaneous entity right after another miscellaneous entity + I_MISC(Entity.MISC), // Miscellaneous entity + B_PER(Entity.PERSON), // Beginning of a person's name right after another person's name + I_PER(Entity.PERSON), // Person's name + B_ORG(Entity.ORGANISATION), // Beginning of an organisation right after another organisation + I_ORG(Entity.ORGANISATION), // Organisation + B_LOC(Entity.LOCATION), // Beginning of a location right after another location + I_LOC(Entity.LOCATION); // Location + + private final Entity entity; + + IobTag(Entity entity) { + this.entity = entity; + } + + Entity getEntity() { + return entity; + } + + boolean isBeginning() { + return name().toLowerCase(Locale.ROOT).startsWith("b"); + } + } + + + private final BertRequestBuilder bertRequestBuilder; + + NerProcessor(BertTokenizer tokenizer) { + this.bertRequestBuilder = new BertRequestBuilder(tokenizer); + } + + @Override + public void validateInputs(String inputs) { + // No validation + } + + @Override + public NlpTask.RequestBuilder getRequestBuilder() { + return bertRequestBuilder; + } + + @Override + public NlpTask.ResultProcessor getResultProcessor() { + return new NerResultProcessor(bertRequestBuilder.getTokenization()); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerResultProcessor.java new file mode 100644 index 0000000000000..7a4309928dc36 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerResultProcessor.java @@ -0,0 +1,147 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.inference.nlp; + +import org.elasticsearch.xpack.core.ml.inference.deployment.PyTorchResult; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.NerResults; +import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +class NerResultProcessor implements NlpTask.ResultProcessor { + + private final BertTokenizer.TokenizationResult tokenization; + + NerResultProcessor(BertTokenizer.TokenizationResult tokenization) { + this.tokenization = Objects.requireNonNull(tokenization); + } + + @Override + public InferenceResults processResult(PyTorchResult pyTorchResult) { + if (tokenization.getTokens().isEmpty()) { + return new NerResults(Collections.emptyList()); + } + // TODO It might be best to do the soft max after averaging scores for + // sub-tokens. If we had a word that is "elastic" which is tokenized to + // "el" and "astic" then perhaps we get a prediction for org of 10 for "el" + // and -5 for "astic". Averaging after softmax would produce a prediction + // of maybe (1 + 0) / 2 = 0.5 while before softmax it'd be exp(10 - 5) / normalization + // which could easily be close to 1. + double[][] normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(pyTorchResult.getInferenceResult()); + List taggedTokens = tagTokens(normalizedScores); + List entities = groupTaggedTokens(taggedTokens); + return new NerResults(entities); + } + + /** + * Here we tag each token with the IoB label that has the max score. + * Additionally, we merge sub-tokens that are part of the same word + * in the original input replacing them with a single token that + * gets labelled based on the average score of all its sub-tokens. + */ + private List tagTokens(double[][] scores) { + List taggedTokens = new ArrayList<>(); + int startTokenIndex = 0; + while (startTokenIndex < tokenization.getTokens().size()) { + int inputMapping = tokenization.getTokenMap()[startTokenIndex]; + if (inputMapping < 0) { + // This token does not map to a token in the input (special tokens) + startTokenIndex++; + continue; + } + int endTokenIndex = startTokenIndex; + StringBuilder word = new StringBuilder(tokenization.getTokens().get(startTokenIndex)); + while (endTokenIndex < tokenization.getTokens().size() - 1 && tokenization.getTokenMap()[endTokenIndex + 1] == inputMapping) { + endTokenIndex++; + // TODO Here we try to get rid of the continuation hashes at the beginning of sub-tokens. + // It is probably more correct to implement detokenization on the tokenizer + // that does reverse lookup based on token IDs. 
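+ // For example, when "elastic" is split into the sub-tokens "el" and "##astic", substring(2) drops the "##" continuation prefix so the merged word becomes "elastic" again.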
+ String endTokenWord = tokenization.getTokens().get(endTokenIndex).substring(2); + word.append(endTokenWord); + } + double[] avgScores = Arrays.copyOf(scores[startTokenIndex], NerProcessor.IobTag.values().length); + for (int i = startTokenIndex + 1; i <= endTokenIndex; i++) { + for (int j = 0; j < scores[i].length; j++) { + avgScores[j] += scores[i][j]; + } + } + int numTokensInBlock = endTokenIndex - startTokenIndex + 1; + if (numTokensInBlock > 1) { + for (int i = 0; i < avgScores.length; i++) { + avgScores[i] /= numTokensInBlock; + } + } + int maxScoreIndex = NlpHelpers.argmax(avgScores); + double score = avgScores[maxScoreIndex]; + taggedTokens.add(new TaggedToken(word.toString(), NerProcessor.IobTag.values()[maxScoreIndex], score)); + startTokenIndex = endTokenIndex + 1; + } + return taggedTokens; + } + + /** + * Now that we have merged sub-tokens and tagged them with their IoB label, + * we group tokens together into the final entity groups. Effectively, + * we want to group B_X I_X B_X so that it results into two + * entities, one for the first B_X I_X and another for the latter B_X, + * where X is the same entity. + * When multiple tokens are grouped together, the entity score is the + * mean score of the tokens. + */ + static List groupTaggedTokens(List tokens) { + if (tokens.isEmpty()) { + return Collections.emptyList(); + } + List entities = new ArrayList<>(); + int startTokenIndex = 0; + while (startTokenIndex < tokens.size()) { + TaggedToken token = tokens.get(startTokenIndex); + if (token.tag.getEntity() == NerProcessor.Entity.NONE) { + startTokenIndex++; + continue; + } + StringBuilder entityWord = new StringBuilder(token.word); + int endTokenIndex = startTokenIndex + 1; + double scoreSum = token.score; + while (endTokenIndex < tokens.size()) { + TaggedToken endToken = tokens.get(endTokenIndex); + if (endToken.tag.isBeginning() || endToken.tag.getEntity() != token.tag.getEntity()) { + break; + } + // TODO Here we add a space between tokens. + // It is probably more correct to implement detokenization on the tokenizer + // that does reverse lookup based on token IDs. + entityWord.append(" ").append(endToken.word); + scoreSum += endToken.score; + endTokenIndex++; + } + entities.add(new NerResults.EntityGroup(token.tag.getEntity().toString(), + scoreSum / (endTokenIndex - startTokenIndex), entityWord.toString())); + startTokenIndex = endTokenIndex; + } + + return entities; + } + + static class TaggedToken { + private final String word; + private final NerProcessor.IobTag tag; + private final double score; + + TaggedToken(String word, NerProcessor.IobTag tag, double score) { + this.word = word; + this.tag = tag; + this.score = score; + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NlpHelpers.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NlpHelpers.java new file mode 100644 index 0000000000000..1509f5172e8f7 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NlpHelpers.java @@ -0,0 +1,133 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.inference.nlp; + +import org.elasticsearch.search.aggregations.pipeline.MovingFunctions; + +import java.util.Comparator; +import java.util.Objects; +import java.util.PriorityQueue; + +public final class NlpHelpers { + + private NlpHelpers() {} + + static double[][] convertToProbabilitiesBySoftMax(double[][] scores) { + double[][] probabilities = new double[scores.length][]; + double[] sum = new double[scores.length]; + for (int i = 0; i < scores.length; i++) { + probabilities[i] = new double[scores[i].length]; + double maxScore = MovingFunctions.max(scores[i]); + for (int j = 0; j < scores[i].length; j++) { + probabilities[i][j] = Math.exp(scores[i][j] - maxScore); + sum[i] += probabilities[i][j]; + } + } + for (int i = 0; i < scores.length; i++) { + for (int j = 0; j < scores[i].length; j++) { + probabilities[i][j] /= sum[i]; + } + } + return probabilities; + } + + static double[] convertToProbabilitiesBySoftMax(double[] scores) { + double[] probabilities = new double[scores.length]; + double sum = 0.0; + double maxScore = MovingFunctions.max(scores); + for (int i = 0; i < scores.length; i++) { + probabilities[i] = Math.exp(scores[i] - maxScore); + sum += probabilities[i]; + } + for (int i = 0; i < scores.length; i++) { + probabilities[i] /= sum; + } + return probabilities; + } + + /** + * Find the index of the highest value in {@code arr} + * @param arr Array to search + * @return Index of highest value + */ + static int argmax(double[] arr) { + int maxIndex = 0; + for (int i = 1; i < arr.length; i++) { + if (arr[i] > arr[maxIndex]) { + maxIndex = i; + } + } + return maxIndex; + } + + + /** + * Find the top K highest values in {@code arr} and their + * index positions. Similar to {@link #argmax(double[])} + * but generalised to k instead of just 1. If {@code arr.length < k} + * then {@code arr.length} items are returned. + * + * The function uses a PriorityQueue of size {@code k} to + * track the highest values + * + * @param k Number of values to track + * @param arr Array to search + * @return Index positions and values of the top k elements. 
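+ * For example, topK(2, new double[]{0.1, 0.5, 0.3, 0.2}) returns (score 0.5, index 1) followed by (score 0.3, index 2), highest value first.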
+ */ + static ScoreAndIndex[] topK(int k, double[] arr) { + if (k > arr.length) { + k = arr.length; + } + + PriorityQueue minHeap = new PriorityQueue<>(k, Comparator.comparingDouble(o -> o.score)); + // initialise with the first k values + for (int i=0; i<k; i++) { + minHeap.add(new ScoreAndIndex(arr[i], i)); + } + + double minValue = minHeap.peek().score; + for (int i=k; i<arr.length; i++) { + if (arr[i] > minValue) { + minHeap.poll(); + minHeap.add(new ScoreAndIndex(arr[i], i)); + minValue = minHeap.peek().score; + } + } + + ScoreAndIndex[] result = new ScoreAndIndex[k]; + // The result should be ordered highest score first + // so reverse the min heap order + for (int i=k-1; i>=0; i--) { + result[i] = minHeap.poll(); + } + return result; + } + + public static class ScoreAndIndex { + final double score; + final int index; + + ScoreAndIndex(double value, int index) { + this.score = value; + this.index = index; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ScoreAndIndex that = (ScoreAndIndex) o; + return Double.compare(that.score, score) == 0 && index == that.index; + } + + @Override + public int hashCode() { + return Objects.hash(score, index); + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NlpTask.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NlpTask.java new file mode 100644 index 0000000000000..3e3065c5f1b9a --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NlpTask.java @@ -0,0 +1,55 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.nlp; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.xpack.core.ml.inference.deployment.PyTorchResult; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer; + +import java.io.IOException; + +public class NlpTask { + + private final TaskType taskType; + private final BertTokenizer tokenizer; + + public static NlpTask fromConfig(NlpTaskConfig config) { + return new NlpTask(config.getTaskType(), config.buildTokenizer()); + } + + private NlpTask(TaskType taskType, BertTokenizer tokenizer) { + this.taskType = taskType; + this.tokenizer = tokenizer; + } + + public Processor createProcessor() throws IOException { + return taskType.createProcessor(tokenizer); + } + + public interface RequestBuilder { + BytesReference buildRequest(String inputs, String requestId) throws IOException; + } + + public interface ResultProcessor { + InferenceResults processResult(PyTorchResult pyTorchResult); + } + + public interface Processor { + /** + * Validate the task input. + * Throws an exception if the inputs fail validation + * + * @param inputs Text to validate + */ + void validateInputs(String inputs); + + RequestBuilder getRequestBuilder(); + ResultProcessor getResultProcessor(); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NlpTaskConfig.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NlpTaskConfig.java new file mode 100644 index 0000000000000..47086bcba532c --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NlpTaskConfig.java @@ -0,0 +1,125 @@ +/* + * Copyright Elasticsearch B.V.
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.nlp; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public class NlpTaskConfig implements ToXContentObject { + + public static final ParseField VOCAB = new ParseField("vocab"); + public static final ParseField TASK_TYPE = new ParseField("task_type"); + public static final ParseField LOWER_CASE = new ParseField("do_lower_case"); + + private static final ObjectParser STRICT_PARSER = createParser(false); + private static final ObjectParser LENIENT_PARSER = createParser(true); + + private static ObjectParser createParser(boolean ignoreUnknownFields) { + ObjectParser parser = new ObjectParser<>("task_config", + ignoreUnknownFields, + Builder::new); + + parser.declareStringArray(Builder::setVocabulary, VOCAB); + parser.declareString(Builder::setTaskType, TASK_TYPE); + parser.declareBoolean(Builder::setDoLowerCase, LOWER_CASE); + return parser; + } + + public static NlpTaskConfig fromXContent(XContentParser parser, boolean lenient) { + return lenient ? LENIENT_PARSER.apply(parser, null).build() : STRICT_PARSER.apply(parser, null).build(); + } + + public static String documentId(String model) { + return model + "_task_config"; + } + + private final TaskType taskType; + private final List vocabulary; + private final boolean doLowerCase; + + NlpTaskConfig(TaskType taskType, List vocabulary, boolean doLowerCase) { + this.taskType = taskType; + this.vocabulary = vocabulary; + this.doLowerCase = doLowerCase; + } + + public TaskType getTaskType() { + return taskType; + } + + public BertTokenizer buildTokenizer() { + return BertTokenizer.builder(vocabulary).setDoLowerCase(doLowerCase).build(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TASK_TYPE.getPreferredName(), taskType.toString()); + builder.field(VOCAB.getPreferredName(), vocabulary); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + NlpTaskConfig that = (NlpTaskConfig) o; + return taskType == that.taskType && + doLowerCase == that.doLowerCase && + Objects.equals(vocabulary, that.vocabulary); + } + + @Override + public int hashCode() { + return Objects.hash(taskType, vocabulary, doLowerCase); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private TaskType taskType; + private List vocabulary; + private boolean doLowerCase = false; + + public Builder setTaskType(TaskType taskType) { + this.taskType = taskType; + return this; + } + + public Builder setTaskType(String taskType) { + this.taskType = TaskType.fromString(taskType); + return this; + } + + public Builder setVocabulary(List vocab) { + this.vocabulary = vocab; + return this; + } + + public Builder setDoLowerCase(boolean doLowerCase) { + 
this.doLowerCase = doLowerCase; + return this; + } + + public NlpTaskConfig build() { + return new NlpTaskConfig(taskType, vocabulary, doLowerCase); + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TaskType.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TaskType.java new file mode 100644 index 0000000000000..b5d5074a9ee57 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TaskType.java @@ -0,0 +1,40 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.nlp; + +import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer; + +import java.io.IOException; +import java.util.Locale; + +public enum TaskType { + + TOKEN_CLASSIFICATION { + public NlpTask.Processor createProcessor(BertTokenizer tokenizer) throws IOException { + return new NerProcessor(tokenizer); + } + }, + FILL_MASK { + public NlpTask.Processor createProcessor(BertTokenizer tokenizer) throws IOException { + return new FillMaskProcessor(tokenizer); + } + }; + + public NlpTask.Processor createProcessor(BertTokenizer tokenizer) throws IOException { + throw new UnsupportedOperationException("json request must be specialised for task type [" + this.name() + "]"); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } + + public static TaskType fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizer.java new file mode 100644 index 0000000000000..e6bb49e41be53 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizer.java @@ -0,0 +1,310 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +package org.elasticsearch.xpack.ml.inference.nlp.tokenizers; + +import joptsimple.internal.Strings; + +import java.text.Normalizer; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Predicate; + +/** + * Basic tokenization of text by whitespace with optional extras: + * 1. Lower case the input + * 2. Convert to Unicode NFD + * 3. Strip accents + * 4. Surround CJK characters with ' ' + * + * Derived from + * https://github.com/huggingface/transformers/blob/ba8c4d0ac04acfcdbdeaed954f698d6d5ec3e532/src/transformers/tokenization_bert.py + */ +public class BasicTokenizer { + + private final boolean isLowerCase; + private final boolean isTokenizeCjkChars; + private final boolean isStripAccents; + private final Set neverSplit; + + /** + * Tokenizer behaviour is controlled by the options passed here.
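+ * For example, new BasicTokenizer(true, true, true, Collections.singleton("[MASK]")) lower-cases the input, surrounds CJK ideographs with spaces, strips accents and never splits the "[MASK]" token.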
+ * + * @param isLowerCase If true convert the input to lowercase + * @param isTokenizeCjkChars Should CJK ideographs be tokenized + * @param isStripAccents Strip all accents + * @param neverSplit The set of tokens that should not be split + */ + public BasicTokenizer(boolean isLowerCase, boolean isTokenizeCjkChars, boolean isStripAccents, + Set neverSplit) { + this.isLowerCase = isLowerCase; + this.isTokenizeCjkChars = isTokenizeCjkChars; + this.isStripAccents = isStripAccents; + this.neverSplit = neverSplit; + } + + public BasicTokenizer(boolean isLowerCase, boolean isTokenizeCjkChars, boolean isStripAccents) { + this.isLowerCase = isLowerCase; + this.isTokenizeCjkChars = isTokenizeCjkChars; + this.isStripAccents = isStripAccents; + this.neverSplit = Collections.emptySet(); + } + + /** + * Strip accents defaults to the value of {@code isLowerCase} + * when not explicitly set + * @param isLowerCase If true convert the input to lowercase + * @param isTokenizeCjkChars Should CJK ideographs be tokenized + */ + public BasicTokenizer(boolean isLowerCase, boolean isTokenizeCjkChars) { + this(isLowerCase, isTokenizeCjkChars, isLowerCase); + } + + BasicTokenizer() { + this(true, true, true); + } + + /** + * Clean the text, tokenize on whitespace, then process the tokens depending + * on the values of {@code isLowerCase}, {@code isTokenizeCjkChars}, + * {@code isStripAccents} and the contents of {@code neverSplit} + * + * @param text The input text to tokenize + * @return List of tokens + */ + public List tokenize(String text) { + text = cleanText(text); + if (isTokenizeCjkChars) { + text = tokenizeCjkChars(text); + } + + String [] tokens = whiteSpaceTokenize(text); + + List processedTokens = new ArrayList<>(tokens.length); + for (String token : tokens) { + + if (Strings.EMPTY.equals(token)) { + continue; + } + + if (neverSplit.contains(token)) { + processedTokens.add(token); + continue; + } + + // At this point text has been tokenized by whitespace + // but one of the special never split tokens could be adjacent + // to a punctuation character. + if (isCommonPunctuation(token.codePointAt(token.length() -1)) && + neverSplit.contains(token.substring(0, token.length() -1))) { + processedTokens.add(token.substring(0, token.length() -1)); + processedTokens.add(token.substring(token.length() -1)); + continue; + } + + if (isLowerCase) { + token = token.toLowerCase(Locale.ROOT); + } + if (isStripAccents) { + token = stripAccents(token); + } + processedTokens.addAll(splitOnPunctuation(token)); + } + + return processedTokens; + } + + public boolean isLowerCase() { + return isLowerCase; + } + + public boolean isStripAccents() { + return isStripAccents; + } + + public boolean isTokenizeCjkChars() { + return isTokenizeCjkChars; + } + + static String [] whiteSpaceTokenize(String text) { + text = text.trim(); + return text.split(" "); + } + + /** + * Normalize unicode text to NFD form + * "Characters are decomposed by canonical equivalence, and multiple + * combining characters are arranged in a specific order" + * from https://en.wikipedia.org/wiki/Unicode_equivalence#Normal_forms + * + * And remove non-spacing marks https://www.compart.com/en/unicode/category/Mn + * + * @param word Word to strip + * @return {@code word} normalized and stripped.
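+ * For example, stripAccents("café") returns "cafe": NFD decomposes the accented character into 'e' followed by a combining acute accent, which is then filtered out as a non-spacing mark.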
+ */ + static String stripAccents(String word) { + String normalizedString = Normalizer.normalize(word, Normalizer.Form.NFD); + + int [] codePoints = normalizedString.codePoints() + .filter(codePoint -> Character.getType(codePoint) != Character.NON_SPACING_MARK) + .toArray(); + + return new String(codePoints, 0, codePoints.length); + } + + static List splitOnPunctuation(String word) { + return splitOnPredicate(word, BasicTokenizer::isPunctuationMark); + } + + static List splitOnPredicate(String word, Predicate test) { + List split = new ArrayList<>(); + int [] codePoints = word.codePoints().toArray(); + + int lastSplit = 0; + for (int i=0; i<codePoints.length; i++) { + if (test.test(codePoints[i])) { + if (i > 0) { + // add a new string for what has gone before + split.add(new String(codePoints, lastSplit, i - lastSplit)); + } + split.add(new String(codePoints, i, 1)); + lastSplit = i+1; + } + } + + if (lastSplit < codePoints.length) { + split.add(new String(codePoints, lastSplit, codePoints.length - lastSplit)); + } + + return split; + } + + /** + * Surrounds any CJK character with whitespace + * @param text To tokenize + * @return tokenized text + */ + static String tokenizeCjkChars(String text) { + StringBuilder sb = new StringBuilder(text.length()); + AtomicBoolean cjkCharFound = new AtomicBoolean(false); + + text.codePoints().forEach(cp -> { + if (isCjkChar(cp)) { + sb.append(' '); + sb.appendCodePoint(cp); + sb.append(' '); + cjkCharFound.set(true); + } else { + sb.appendCodePoint(cp); + } + }); + + // no change + if (cjkCharFound.get() == false) { + return text; + } + + return sb.toString(); + } + + /** + * Remove control chars and normalize white space to ' ' + * @param text Text to clean + * @return Cleaned text + */ + static String cleanText(String text) { + int [] codePoints = text.codePoints() + .filter(codePoint -> (codePoint == 0x00 || codePoint == 0xFFFD || isControlChar(codePoint)) == false) + .map(codePoint -> isWhiteSpace(codePoint) ? ' ' : codePoint) + .toArray(); + + return new String(codePoints, 0, codePoints.length); + } + + static boolean isCjkChar(int codePoint) { + // https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + Character.UnicodeBlock block = Character.UnicodeBlock.of(codePoint); + return Character.UnicodeBlock.CJK_COMPATIBILITY_IDEOGRAPHS.equals(block) || + Character.UnicodeBlock.CJK_UNIFIED_IDEOGRAPHS.equals(block) || + Character.UnicodeBlock.CJK_UNIFIED_IDEOGRAPHS_EXTENSION_A.equals(block) || + Character.UnicodeBlock.CJK_UNIFIED_IDEOGRAPHS_EXTENSION_B.equals(block) || + Character.UnicodeBlock.CJK_UNIFIED_IDEOGRAPHS_EXTENSION_C.equals(block) || + Character.UnicodeBlock.CJK_UNIFIED_IDEOGRAPHS_EXTENSION_D.equals(block) || + Character.UnicodeBlock.CJK_UNIFIED_IDEOGRAPHS_EXTENSION_E.equals(block) || + Character.UnicodeBlock.CJK_COMPATIBILITY_IDEOGRAPHS_SUPPLEMENT.equals(block); + } + + /** + * newline, carriage return and tab are control chars but for + * tokenization purposes they are treated as whitespace. + * + * @param codePoint code point + * @return is control char + */ + static boolean isControlChar(int codePoint) { + if (codePoint == '\n' || codePoint == '\r' || codePoint == '\t' ) { + return false; + } + int category = Character.getType(codePoint); + + return category >= Character.CONTROL && category <= Character.SURROGATE; + } + + /** + * newline, carriage return and tab are technically control chars + * but are not part of the Unicode Space Separator (Zs) group.
+ * For tokenization purposes they are treated as whitespace + * + * @param codePoint code point + * @return is white space + */ + static boolean isWhiteSpace(int codePoint) { + if (codePoint == '\n' || codePoint == '\r' || codePoint == '\t' ) { + return true; + } + return Character.getType(codePoint) == Character.SPACE_SEPARATOR; + } + + /** + * We treat all non-letter/number ASCII as punctuation. + * Characters such as "^", "$", and "`" are not in the Unicode + * Punctuation class but are treated as punctuation for consistency. + * + * @param codePoint code point + * @return true if is punctuation + */ + static boolean isPunctuationMark(int codePoint) { + if ((codePoint >= 33 && codePoint <= 47) || + (codePoint >= 58 && codePoint <= 64) || + (codePoint >= 91 && codePoint <= 96) || + (codePoint >= 123 && codePoint <= 126)) { + return true; + } + + int category = Character.getType(codePoint); + return category >= Character.DASH_PUNCTUATION && category <= Character.OTHER_PUNCTUATION; + } + + /** + * True if the code point is for a common punctuation character + * {@code ! " # $ % & ' ( ) * + , - . / and : ; < = > ?} + * @param codePoint codepoint + * @return true if codepoint is punctuation + */ + static boolean isCommonPunctuation(int codePoint) { + if ((codePoint >= 33 && codePoint <= 47) || + (codePoint >= 58 && codePoint <= 64) ) { + return true; + } + + return false; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java new file mode 100644 index 0000000000000..78b3839b4382f --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java @@ -0,0 +1,246 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +package org.elasticsearch.xpack.ml.inference.nlp.tokenizers; + +import org.elasticsearch.common.util.set.Sets; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.SortedMap; +import java.util.TreeMap; + +/** + * Performs basic tokenization and normalization of input text + * then tokenizes with the WordPiece algorithm using the given + * vocabulary. + *
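+ * For example, assuming a vocabulary that contains "else" and "##where" along with the special tokens, tokenizing "elsewhere" with special tokens yields the tokens [CLS], "else", "##where", [SEP] and the token map -1, 0, 0, -1, since both word pieces map back to source token 0.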

+ * Derived from + * https://github.com/huggingface/transformers/blob/ba8c4d0ac04acfcdbdeaed954f698d6d5ec3e532/src/transformers/tokenization_bert.py + */ +public class BertTokenizer { + + public static final String UNKNOWN_TOKEN = "[UNK]"; + public static final String SEPARATOR_TOKEN = "[SEP]"; + public static final String PAD_TOKEN = "[PAD]"; + public static final String CLASS_TOKEN = "[CLS]"; + public static final String MASK_TOKEN = "[MASK]"; + + public static final int SPECIAL_TOKEN_POSITION = -1; + + public static final int DEFAULT_MAX_INPUT_CHARS_PER_WORD = 100; + + private final Set NEVER_SPLIT = new HashSet<>(Arrays.asList(MASK_TOKEN)); + + private final WordPieceTokenizer wordPieceTokenizer; + private final List originalVocab; + // TODO Not sure this needs to be a sorted map + private final SortedMap vocab; + private final boolean doLowerCase; + private final boolean doTokenizeCjKChars; + private final boolean doStripAccents; + private final Set neverSplit; + + private BertTokenizer( + List originalVocab, + SortedMap vocab, + boolean doLowerCase, + boolean doTokenizeCjKChars, + boolean doStripAccents, + Set neverSplit) { + wordPieceTokenizer = new WordPieceTokenizer(vocab, UNKNOWN_TOKEN, DEFAULT_MAX_INPUT_CHARS_PER_WORD); + this.originalVocab = originalVocab; + this.vocab = vocab; + this.doLowerCase = doLowerCase; + this.doTokenizeCjKChars = doTokenizeCjKChars; + this.doStripAccents = doStripAccents; + this.neverSplit = Sets.union(neverSplit, NEVER_SPLIT); + } + + public TokenizationResult tokenize(String text) { + return tokenize(text, true); + } + + /** + * Tokenize the input according to the basic tokenization options + * then perform Word Piece tokenization with the given vocabulary. + * + * The result is the Word Piece tokens, a map of the Word Piece + * token position to the position of the token in the source + * @param text Text to tokenize + * @param withSpecialTokens Include CLS and SEP tokens + * @return Tokenized text, token Ids and map + */ + public TokenizationResult tokenize(String text, boolean withSpecialTokens) { + BasicTokenizer basicTokenizer = new BasicTokenizer(doLowerCase, doTokenizeCjKChars, doStripAccents, neverSplit); + + List delineatedTokens = basicTokenizer.tokenize(text); + List wordPieceTokens = new ArrayList<>(); + List tokenPositionMap = new ArrayList<>(); + if (withSpecialTokens) { + // insert the first token to simplify the loop counter logic later + tokenPositionMap.add(SPECIAL_TOKEN_POSITION); + } + + for (int sourceIndex = 0; sourceIndex < delineatedTokens.size(); sourceIndex++) { + String token = delineatedTokens.get(sourceIndex); + if (neverSplit.contains(token)) { + wordPieceTokens.add(new WordPieceTokenizer.TokenAndId(token, vocab.getOrDefault(token, vocab.get(UNKNOWN_TOKEN)))); + tokenPositionMap.add(sourceIndex); + } else { + List tokens = wordPieceTokenizer.tokenize(token); + for (int tokenCount = 0; tokenCount < tokens.size(); tokenCount++) { + tokenPositionMap.add(sourceIndex); + } + wordPieceTokens.addAll(tokens); + } + } + + int numTokens = withSpecialTokens ? wordPieceTokens.size() + 2 : wordPieceTokens.size(); + List tokens = new ArrayList<>(numTokens); + int [] tokenIds = new int[numTokens]; + int [] tokenMap = new int[numTokens]; + + if (withSpecialTokens) { + tokens.add(CLASS_TOKEN); + tokenIds[0] = vocab.get(CLASS_TOKEN); + tokenMap[0] = SPECIAL_TOKEN_POSITION; + } + + int i = withSpecialTokens ? 
1 : 0; + for (WordPieceTokenizer.TokenAndId tokenAndId : wordPieceTokens) { + tokens.add(tokenAndId.getToken()); + tokenIds[i] = tokenAndId.getId(); + tokenMap[i] = tokenPositionMap.get(i); + i++; + } + + if (withSpecialTokens) { + tokens.add(SEPARATOR_TOKEN); + tokenIds[i] = vocab.get(SEPARATOR_TOKEN); + tokenMap[i] = SPECIAL_TOKEN_POSITION; + } + + return new TokenizationResult(text, originalVocab, tokens, tokenIds, tokenMap); + } + + public static class TokenizationResult { + + String input; + List vocab; + private final List tokens; + private final int [] tokenIds; + private final int [] tokenMap; + + public TokenizationResult(String input, List vocab, List tokens, int[] tokenIds, int[] tokenMap) { + assert tokens.size() == tokenIds.length; + assert tokenIds.length == tokenMap.length; + this.input = input; + this.vocab = vocab; + this.tokens = tokens; + this.tokenIds = tokenIds; + this.tokenMap = tokenMap; + } + + public String getFromVocab(int tokenId) { + return vocab.get(tokenId); + } + + /** + * The token strings from the tokenization process + * @return A list of tokens + */ + public List getTokens() { + return tokens; + } + + /** + * The integer values of the tokens in {@link #getTokens()} + * @return A list of token Ids + */ + public int[] getTokenIds() { + return tokenIds; + } + + /** + * Maps the token position to the position in the source text. + * Source words may be divided into more than one token so more + * than one token can map back to the source token + * @return Map of source token to + */ + public int[] getTokenMap() { + return tokenMap; + } + + public String getInput() { + return input; + } + } + + public static Builder builder(List vocab) { + return new Builder(vocab); + } + + public static class Builder { + + private final List originalVocab; + private final SortedMap vocab; + private boolean doLowerCase = false; + private boolean doTokenizeCjKChars = true; + private Boolean doStripAccents = null; + private Set neverSplit; + + private Builder(List vocab) { + this.originalVocab = vocab; + this.vocab = buildSortedVocab(vocab); + } + + private static SortedMap buildSortedVocab(List vocab) { + SortedMap sortedVocab = new TreeMap<>(); + for (int i = 0; i < vocab.size(); i++) { + sortedVocab.put(vocab.get(i), i); + } + return sortedVocab; + } + + public Builder setDoLowerCase(boolean doLowerCase) { + this.doLowerCase = doLowerCase; + return this; + } + + public Builder setDoTokenizeCjKChars(boolean doTokenizeCjKChars) { + this.doTokenizeCjKChars = doTokenizeCjKChars; + return this; + } + + public Builder setDoStripAccents(Boolean doStripAccents) { + this.doStripAccents = doStripAccents; + return this; + } + + public Builder setNeverSplit(Set neverSplit) { + this.neverSplit = neverSplit; + return this; + } + + public BertTokenizer build() { + // if not set strip accents defaults to the value of doLowerCase + if (doStripAccents == null) { + doStripAccents = doLowerCase; + } + + if (neverSplit == null) { + neverSplit = Collections.emptySet(); + } + + return new BertTokenizer(originalVocab, vocab, doLowerCase, doTokenizeCjKChars, doStripAccents, neverSplit); + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/WordPieceTokenizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/WordPieceTokenizer.java new file mode 100644 index 0000000000000..23634d200a771 --- /dev/null +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/WordPieceTokenizer.java @@ -0,0 +1,120 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +package org.elasticsearch.xpack.ml.inference.nlp.tokenizers; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * SubWord tokenization via the Word Piece algorithm using the + * provided vocabulary. + * + * The input is split by white space and should be pre-processed + * by {@link BasicTokenizer} + */ +public class WordPieceTokenizer { + + private static final String CONTINUATION = "##"; + + private final Map vocab; + private final String unknownToken; + private final int maxInputCharsPerWord; + + public static class TokenAndId { + private final String token; + private final int id; + + TokenAndId(String token, int id) { + this.token = token; + this.id = id; + } + + public int getId() { + return id; + } + + public String getToken() { + return token; + } + } + + /** + * + * @param vocab The token vocabulary + * @param unknownToken If not found in the vocabulary + * @param maxInputCharsPerWord Inputs tokens longer than this are 'unknown' + */ + public WordPieceTokenizer(Map vocab, String unknownToken, int maxInputCharsPerWord) { + this.vocab = vocab; + this.unknownToken = unknownToken; + this.maxInputCharsPerWord = maxInputCharsPerWord; + } + + /** + * Wordpiece tokenize the input text. + * + * @param text A single token or whitespace separated tokens. + * Input should have been normalized by the {@link BasicTokenizer}. + * @return List of tokens + */ + public List tokenize(String text) { + String[] tokens = BasicTokenizer.whiteSpaceTokenize(text); + + List output = new ArrayList<>(); + for (String token : tokens) { + if (token.length() > maxInputCharsPerWord) { + assert vocab.containsKey(unknownToken); + output.add(new TokenAndId(unknownToken, vocab.get(unknownToken))); + continue; + } + + boolean isBad = false; + int start = 0; + List subTokens = new ArrayList<>(); + int length = token.length(); + while (start < length) { + int end = length; + + String currentValidSubStr = null; + + while (start < end) { + String subStr; + if (start > 0) { + subStr = CONTINUATION + token.substring(start, end); + } else { + subStr = token.substring(start, end); + } + + if (vocab.containsKey(subStr)) { + currentValidSubStr = subStr; + break; + } + + end--; + } + + if (currentValidSubStr == null) { + isBad = true; + break; + } + + subTokens.add(new TokenAndId(currentValidSubStr, vocab.get(currentValidSubStr))); + + start = end; + } + + if (isBad) { + output.add(new TokenAndId(unknownToken, vocab.get(unknownToken))); + } else { + output.addAll(subTokens); + } + } + + return output; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/WordPieceVocabulary.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/WordPieceVocabulary.java new file mode 100644 index 0000000000000..1b9acf6dc34da --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/WordPieceVocabulary.java @@ -0,0 +1,83 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.nlp.tokenizers; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; +import java.util.SortedMap; +import java.util.TreeMap; + +public class WordPieceVocabulary implements ToXContentObject { + + public static final String NAME = "vocab"; + public static final ParseField VOCAB = new ParseField(NAME); + public static final ParseField UNKNOWN_TOKEN = new ParseField("unknown"); + + private static final ConstructingObjectParser STRICT_PARSER = createParser(false); + private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + + @SuppressWarnings("unchecked") + private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME, + ignoreUnknownFields, + a -> new WordPieceVocabulary((List) a[0], (Integer) a[1])); + + parser.declareStringArray(ConstructingObjectParser.constructorArg(), VOCAB); + parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), UNKNOWN_TOKEN); + + return parser; + } + + public static WordPieceVocabulary fromXContent(XContentParser parser, boolean lenient) { + return lenient ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null); + } + + private final SortedMap vocab; + private final int unknownToken; + + public WordPieceVocabulary(List words, Integer unknownToken) { + this.unknownToken = unknownToken == null ? -1 : unknownToken; + vocab = new TreeMap<>(); + for (int i = 0; i < words.size(); i++) { + vocab.put(words.get(i), i); + } + } + + public int token(String word) { + Integer token = vocab.get(word); + if (token == null) { + token = unknownToken; + } + return token; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.field(VOCAB.getPreferredName(), vocab.keySet()); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + WordPieceVocabulary that = (WordPieceVocabulary) o; + return unknownToken == that.unknownToken && Objects.equals(vocab, that.vocab); + } + + @Override + public int hashCode() { + return Objects.hash(vocab, unknownToken); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/ChunkedTrainedModelRestorer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/ChunkedTrainedModelRestorer.java new file mode 100644 index 0000000000000..4cde81d676200 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/ChunkedTrainedModelRestorer.java @@ -0,0 +1,225 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.inference.persistence; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.search.SearchAction; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchRequestBuilder; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.CheckedFunction; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.sort.SortBuilders; +import org.elasticsearch.search.sort.SortOrder; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.io.InputStream; +import java.util.concurrent.ExecutorService; +import java.util.function.Consumer; + +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; +import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; + +/** + * Searches for and emits {@link TrainedModelDefinitionDoc}s in + * order based on the {@code doc_num}. + * + * This is a one-use class; it has internal state to track progress + * and cannot be used again to load another model. + * + * Defaults to searching in {@link InferenceIndexConstants#INDEX_PATTERN} + * if a different index is not set. + */ +public class ChunkedTrainedModelRestorer { + + private static final Logger logger = LogManager.getLogger(ChunkedTrainedModelRestorer.class); + + private static final int MAX_NUM_DEFINITION_DOCS = 20; + + private final Client client; + private final NamedXContentRegistry xContentRegistry; + private final ExecutorService executorService; + private final String modelId; + private String index = InferenceIndexConstants.INDEX_PATTERN; + private int searchSize = 10; + private int numDocsWritten = 0; + + public ChunkedTrainedModelRestorer(String modelId, + Client client, + ExecutorService executorService, + NamedXContentRegistry xContentRegistry) { + this.client = client; + this.executorService = executorService; + this.xContentRegistry = xContentRegistry; + this.modelId = modelId; + } + + public void setSearchSize(int searchSize) { + if (searchSize > MAX_NUM_DEFINITION_DOCS) { + throw new IllegalArgumentException("search size [" + searchSize + "] cannot be bigger than [" + MAX_NUM_DEFINITION_DOCS + "]"); + } + if (searchSize <= 0) { + throw new IllegalArgumentException("search size [" + searchSize + "] must be greater than 0"); + } + this.searchSize = searchSize; + } + + public void setSearchIndex(String indexNameOrPattern) { + this.index = indexNameOrPattern; + } + + public int getNumDocsWritten() { + return numDocsWritten; + } + + /** + * Return the model definitions one at a time on the {@code modelConsumer}.
+ * Either {@code errorConsumer} or {@code successConsumer} will be called + * when the process is finished. + * + * The {@code modelConsumer} has the opportunity to cancel loading by + * returning false, in which case the {@code successConsumer} is called + * with the parameter Boolean.FALSE. + * + * The docs are returned in order based on {@link TrainedModelDefinitionDoc#getDocNum()}. + * There is no error checking for duplicate or missing docs; the consumer should handle + * those errors. + * + * Depending on the search size, multiple searches may be made. + * + * @param modelConsumer Consumes model definition docs + * @param successConsumer Called when all docs have been returned or the loading is cancelled + * @param errorConsumer In the event of an error + */ + public void restoreModelDefinition(CheckedFunction modelConsumer, + Consumer successConsumer, + Consumer errorConsumer) { + + logger.debug("[{}] restoring model", modelId); + SearchRequest searchRequest = buildSearch(client, modelId, index, searchSize); + + executorService.execute(() -> doSearch(searchRequest, modelConsumer, successConsumer, errorConsumer)); + } + + private void doSearch(SearchRequest searchRequest, + CheckedFunction modelConsumer, + Consumer successConsumer, + Consumer errorConsumer) { + + executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap( + searchResponse -> { + if (searchResponse.getHits().getHits().length == 0) { + errorConsumer.accept(new ResourceNotFoundException( + Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))); + return; + } + + // Set lastNum to a non-zero value to prevent an infinite loop of + // search_after requests in the absolute worst case where + // it has all gone wrong. + // Docs are numbered 0..N. We must have seen at least + // this many docs so far.
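As a usage sketch of the consumer-based restore API documented above (the handler methods handleChunk, finish and fail are hypothetical and only illustrate the contract):

    ChunkedTrainedModelRestorer restorer =
        new ChunkedTrainedModelRestorer(modelId, client, executorService, xContentRegistry);
    restorer.setSearchSize(5);
    restorer.restoreModelDefinition(
        doc -> handleChunk(doc),     // CheckedFunction; return false here to cancel the restore early
        success -> finish(success),  // receives Boolean.FALSE if cancelled, Boolean.TRUE when all docs were returned
        e -> fail(e));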
+ int lastNum = numDocsWritten -1; + for (SearchHit hit : searchResponse.getHits().getHits()) { + try { + TrainedModelDefinitionDoc doc = + parseModelDefinitionDocLenientlyFromSource(hit.getSourceRef(), modelId, xContentRegistry); + lastNum = doc.getDocNum(); + + boolean continueSearching = modelConsumer.apply(doc); + if (continueSearching == false) { + // signal the search has finished early + successConsumer.accept(Boolean.FALSE); + return; + } + + } catch (IOException e) { + logger.error(new ParameterizedMessage("[{}] error writing model definition", modelId), e); + errorConsumer.accept(e); + return; + } + } + + numDocsWritten += searchResponse.getHits().getHits().length; + + boolean endOfSearch = searchResponse.getHits().getHits().length < searchSize || + searchResponse.getHits().getTotalHits().value == numDocsWritten; + + if (endOfSearch) { + successConsumer.accept(Boolean.TRUE); + } else { + // search again with after + SearchHit lastHit = searchResponse.getHits().getAt(searchResponse.getHits().getHits().length -1); + SearchRequestBuilder searchRequestBuilder = buildSearchBuilder(client, modelId, index, searchSize); + searchRequestBuilder.searchAfter(new Object[]{lastHit.getIndex(), lastNum}); + executorService.execute(() -> + doSearch(searchRequestBuilder.request(), modelConsumer, successConsumer, errorConsumer)); + } + }, + e -> { + if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { + errorConsumer.accept(new ResourceNotFoundException( + Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))); + } else { + errorConsumer.accept(e); + } + } + )); + } + + private static SearchRequestBuilder buildSearchBuilder(Client client, String modelId, String index, int searchSize) { + return client.prepareSearch(index) + .setQuery(QueryBuilders.constantScoreQuery(QueryBuilders + .boolQuery() + .filter(QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId)) + .filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), + TrainedModelDefinitionDoc.NAME)))) + .setSize(searchSize) + .setTrackTotalHits(true) + // First find the latest index + .addSort("_index", SortOrder.DESC) + // Then, sort by doc_num + .addSort(SortBuilders.fieldSort(TrainedModelDefinitionDoc.DOC_NUM.getPreferredName()) + .order(SortOrder.ASC) + .unmappedType("long")); + } + + public static SearchRequest buildSearch(Client client, String modelId, String index, int searchSize) { + return buildSearchBuilder(client, modelId, index, searchSize).request(); + } + + public static TrainedModelDefinitionDoc parseModelDefinitionDocLenientlyFromSource(BytesReference source, + String modelId, + NamedXContentRegistry xContentRegistry) + throws IOException { + + try (InputStream stream = source.streamInput(); + XContentParser parser = XContentFactory.xContent(XContentType.JSON) + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) { + return TrainedModelDefinitionDoc.fromXContent(parser, true).build(); + } catch (IOException e) { + logger.error(new ParameterizedMessage("[{}] failed to parse model definition", modelId), e); + throw e; + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java index 54da58019f50c..ea609d6129e00 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java 
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java @@ -8,6 +8,8 @@ import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; @@ -17,6 +19,8 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Base64; import java.util.Objects; /** @@ -31,6 +35,7 @@ public class TrainedModelDefinitionDoc implements ToXContentObject { public static final ParseField DOC_NUM = new ParseField("doc_num"); public static final ParseField DEFINITION = new ParseField("definition"); + public static final ParseField BINARY_DEFINITION = new ParseField("binary_definition"); public static final ParseField COMPRESSION_VERSION = new ParseField("compression_version"); public static final ParseField TOTAL_DEFINITION_LENGTH = new ParseField("total_definition_length"); public static final ParseField DEFINITION_LENGTH = new ParseField("definition_length"); @@ -44,8 +49,11 @@ private static ObjectParser createParse ObjectParser parser = new ObjectParser<>(NAME, ignoreUnknownFields, TrainedModelDefinitionDoc.Builder::new); + parser.declareString((a, b) -> {}, InferenceIndexConstants.DOC_TYPE); // type is hard coded but must be parsed parser.declareString(TrainedModelDefinitionDoc.Builder::setModelId, TrainedModelConfig.MODEL_ID); parser.declareString(TrainedModelDefinitionDoc.Builder::setCompressedString, DEFINITION); + parser.declareField(TrainedModelDefinitionDoc.Builder::setBinaryData, (p, c) -> new BytesArray(p.binaryValue()), + BINARY_DEFINITION, ObjectParser.ValueType.VALUE_OBJECT_ARRAY); parser.declareInt(TrainedModelDefinitionDoc.Builder::setDocNum, DOC_NUM); parser.declareInt(TrainedModelDefinitionDoc.Builder::setCompressionVersion, COMPRESSION_VERSION); parser.declareLong(TrainedModelDefinitionDoc.Builder::setDefinitionLength, DEFINITION_LENGTH); @@ -63,23 +71,23 @@ public static String docId(String modelId, int docNum) { return NAME + "-" + modelId + "-" + docNum; } - private final String compressedString; + private final BytesReference binaryData; private final String modelId; private final int docNum; - // for BWC + // for bwc private final Long totalDefinitionLength; private final long definitionLength; private final int compressionVersion; private final boolean eos; - private TrainedModelDefinitionDoc(String compressedString, + private TrainedModelDefinitionDoc(BytesReference binaryData, String modelId, int docNum, Long totalDefinitionLength, long definitionLength, int compressionVersion, boolean eos) { - this.compressedString = ExceptionsHelper.requireNonNull(compressedString, DEFINITION); + this.binaryData = ExceptionsHelper.requireNonNull(binaryData, BINARY_DEFINITION); this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID); if (docNum < 0) { throw new IllegalArgumentException("[doc_num] must be greater than or equal to 0"); @@ -97,8 +105,8 @@ private TrainedModelDefinitionDoc(String compressedString, this.eos = eos; } - public String getCompressedString() { - return compressedString; + public BytesReference getBinaryData() { + return binaryData; } public String getModelId() { @@ -136,8 +144,11 @@ public 
XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId); builder.field(DOC_NUM.getPreferredName(), docNum); builder.field(DEFINITION_LENGTH.getPreferredName(), definitionLength); + if (totalDefinitionLength != null) { + builder.field(TOTAL_DEFINITION_LENGTH.getPreferredName(), totalDefinitionLength); + } builder.field(COMPRESSION_VERSION.getPreferredName(), compressionVersion); - builder.field(DEFINITION.getPreferredName(), compressedString); + builder.field(BINARY_DEFINITION.getPreferredName(), binaryData); builder.field(EOS.getPreferredName(), eos); builder.endObject(); return builder; @@ -159,18 +170,18 @@ public boolean equals(Object o) { Objects.equals(totalDefinitionLength, that.totalDefinitionLength) && Objects.equals(compressionVersion, that.compressionVersion) && Objects.equals(eos, that.eos) && - Objects.equals(compressedString, that.compressedString); + Objects.equals(binaryData, that.binaryData); } @Override public int hashCode() { - return Objects.hash(modelId, docNum, definitionLength, totalDefinitionLength, compressionVersion, compressedString, eos); + return Objects.hash(modelId, docNum, definitionLength, totalDefinitionLength, compressionVersion, binaryData, eos); } public static class Builder { private String modelId; - private String compressedString; + private BytesReference binaryData; private int docNum; private Long totalDefinitionLength; private long definitionLength; @@ -183,7 +194,13 @@ public Builder setModelId(String modelId) { } public Builder setCompressedString(String compressedString) { - this.compressedString = compressedString; + this.binaryData = new BytesArray(Base64.getDecoder() + .decode(compressedString.getBytes(StandardCharsets.UTF_8))); + return this; + } + + public Builder setBinaryData(BytesReference binaryData) { + this.binaryData = binaryData; return this; } @@ -214,7 +231,7 @@ public Builder setEos(boolean eos) { public TrainedModelDefinitionDoc build() { return new TrainedModelDefinitionDoc( - this.compressedString, + this.binaryData, this.modelId, this.docNum, this.totalDefinitionLength, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index 37afd5d193289..0ff20f0df9dcf 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -38,6 +38,7 @@ import org.elasticsearch.common.Numbers; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.bytes.CompositeBytesReference; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.regex.Regex; import org.elasticsearch.common.util.set.Sets; @@ -72,14 +73,17 @@ import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats; +import 
org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelLocation; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; +import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata; import java.io.IOException; @@ -110,9 +114,9 @@ public class TrainedModelProvider { public static final Set MODELS_STORED_AS_RESOURCE = Collections.singleton("lang_ident_model_1"); private static final String MODEL_RESOURCE_PATH = "/org/elasticsearch/xpack/ml/inference/persistence/"; private static final String MODEL_RESOURCE_FILE_EXT = ".json"; - private static final int COMPRESSED_STRING_CHUNK_SIZE = 16 * 1024 * 1024; + private static final int COMPRESSED_MODEL_CHUNK_SIZE = 16 * 1024 * 1024; private static final int MAX_NUM_DEFINITION_DOCS = 100; - private static final int MAX_COMPRESSED_STRING_SIZE = COMPRESSED_STRING_CHUNK_SIZE * MAX_NUM_DEFINITION_DOCS; + private static final int MAX_COMPRESSED_MODEL_SIZE = COMPRESSED_MODEL_CHUNK_SIZE * MAX_NUM_DEFINITION_DOCS; private static final Logger logger = LogManager.getLogger(TrainedModelProvider.class); private final Client client; @@ -143,14 +147,20 @@ public void storeTrainedModel(TrainedModelConfig trainedModelConfig, } TrainedModelDefinition definition = trainedModelConfig.getModelDefinition(); - if (definition == null) { - listener.onFailure(ExceptionsHelper.badRequestException("Unable to store [{}]. [{}] is required", + TrainedModelLocation location = trainedModelConfig.getLocation(); + if (definition == null && location == null) { + listener.onFailure(ExceptionsHelper.badRequestException("Unable to store [{}]. 
[{}] or [{}] is required", trainedModelConfig.getModelId(), - TrainedModelConfig.DEFINITION.getPreferredName())); + TrainedModelConfig.DEFINITION.getPreferredName(), + TrainedModelConfig.LOCATION.getPreferredName())); return; } - storeTrainedModelAndDefinition(trainedModelConfig, listener); + if (definition != null) { + storeTrainedModelAndDefinition(trainedModelConfig, listener); + } else { + storeTrainedModelConfig(trainedModelConfig, listener); + } } public void storeTrainedModelConfig(TrainedModelConfig trainedModelConfig, ActionListener listener) { @@ -161,10 +171,14 @@ public void storeTrainedModelConfig(TrainedModelConfig trainedModelConfig, Actio } assert trainedModelConfig.getModelDefinition() == null; + IndexRequest request = + createRequest(trainedModelConfig.getModelId(), InferenceIndexConstants.LATEST_INDEX_NAME, trainedModelConfig); + request.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + executeAsyncWithOrigin(client, ML_ORIGIN, IndexAction.INSTANCE, - createRequest(trainedModelConfig.getModelId(), InferenceIndexConstants.LATEST_INDEX_NAME, trainedModelConfig), + request, ActionListener.wrap( indexResponse -> listener.onResponse(true), e -> { @@ -173,10 +187,9 @@ public void storeTrainedModelConfig(TrainedModelConfig trainedModelConfig, Actio Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId()))); } else { listener.onFailure( - new ElasticsearchStatusException(Messages.INFERENCE_FAILED_TO_STORE_MODEL, - RestStatus.INTERNAL_SERVER_ERROR, - e, - trainedModelConfig.getModelId())); + new ElasticsearchStatusException( + Messages.getMessage(Messages.INFERENCE_FAILED_TO_STORE_MODEL, trainedModelConfig.getModelId()), + RestStatus.INTERNAL_SERVER_ERROR, e)); } } )); @@ -203,10 +216,9 @@ public void storeTrainedModelDefinitionDoc(TrainedModelDefinitionDoc trainedMode trainedModelDefinitionDoc.getDocNum()))); } else { listener.onFailure( - new ElasticsearchStatusException(Messages.INFERENCE_FAILED_TO_STORE_MODEL, - RestStatus.INTERNAL_SERVER_ERROR, - e, - trainedModelDefinitionDoc.getModelId())); + new ElasticsearchStatusException( + Messages.getMessage(Messages.INFERENCE_FAILED_TO_STORE_MODEL, trainedModelDefinitionDoc.getModelId()), + RestStatus.INTERNAL_SERVER_ERROR, e)); } } )); @@ -231,10 +243,9 @@ public void storeTrainedModelMetadata(TrainedModelMetadata trainedModelMetadata, trainedModelMetadata.getModelId()))); } else { listener.onFailure( - new ElasticsearchStatusException(Messages.INFERENCE_FAILED_TO_STORE_MODEL_METADATA, - RestStatus.INTERNAL_SERVER_ERROR, - e, - trainedModelMetadata.getModelId())); + new ElasticsearchStatusException( + Messages.getMessage(Messages.INFERENCE_FAILED_TO_STORE_MODEL_METADATA, trainedModelMetadata.getModelId()), + RestStatus.INTERNAL_SERVER_ERROR, e)); } } )); @@ -289,25 +300,25 @@ private void storeTrainedModelAndDefinition(TrainedModelConfig trainedModelConfi List trainedModelDefinitionDocs = new ArrayList<>(); try { - String compressedString = trainedModelConfig.getCompressedDefinition(); - if (compressedString.length() > MAX_COMPRESSED_STRING_SIZE) { + BytesReference compressedDefinition = trainedModelConfig.getCompressedDefinition(); + if (compressedDefinition.length() > MAX_COMPRESSED_MODEL_SIZE) { listener.onFailure( ExceptionsHelper.badRequestException( - "Unable to store model as compressed definition has length [{}] the limit is [{}]", - compressedString.length(), - MAX_COMPRESSED_STRING_SIZE)); + "Unable to store model as compressed definition of size [{}] bytes the limit is [{}] 
bytes", + compressedDefinition.length(), + MAX_COMPRESSED_MODEL_SIZE)); return; } - List chunkedStrings = chunkStringWithSize(compressedString, COMPRESSED_STRING_CHUNK_SIZE); - for(int i = 0; i < chunkedStrings.size(); ++i) { + List chunkedDefinition = chunkDefinitionWithSize(compressedDefinition, COMPRESSED_MODEL_CHUNK_SIZE); + for(int i = 0; i < chunkedDefinition.size(); ++i) { trainedModelDefinitionDocs.add(new TrainedModelDefinitionDoc.Builder() .setDocNum(i) .setModelId(trainedModelConfig.getModelId()) - .setCompressedString(chunkedStrings.get(i)) + .setBinaryData(chunkedDefinition.get(i)) .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION) - .setDefinitionLength(chunkedStrings.get(i).length()) + .setDefinitionLength(chunkedDefinition.get(i).length()) // If it is the last doc, it is the EOS - .setEos(i == chunkedStrings.size() - 1) + .setEos(i == chunkedDefinition.size() - 1) .build()); } } catch (IOException ex) { @@ -332,10 +343,9 @@ private void storeTrainedModelAndDefinition(TrainedModelConfig trainedModelConfi Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId()))); } else { listener.onFailure( - new ElasticsearchStatusException(Messages.INFERENCE_FAILED_TO_STORE_MODEL, - RestStatus.INTERNAL_SERVER_ERROR, - e, - trainedModelConfig.getModelId())); + new ElasticsearchStatusException( + Messages.getMessage(Messages.INFERENCE_FAILED_TO_STORE_MODEL, trainedModelConfig.getModelId()), + RestStatus.INTERNAL_SERVER_ERROR, e)); } } ); @@ -374,12 +384,23 @@ private void storeTrainedModelAndDefinition(TrainedModelConfig trainedModelConfi executeAsyncWithOrigin(client, ML_ORIGIN, BulkAction.INSTANCE, bulkRequest.request(), bulkResponseActionListener); } + /** + * Get the model definition for inference. + * + * The caller should ensure the requested model has an InferenceDefinition, + * some models such as {@code org.elasticsearch.xpack.core.ml.inference.trainedmodel.pytorch.PyTorchModel} + * do not. 
+ * + * @param modelId The model to get + * @param listener The listener + */ public void getTrainedModelForInference(final String modelId, final ActionListener listener) { // TODO Change this when we get more than just langIdent stored if (MODELS_STORED_AS_RESOURCE.contains(modelId)) { try { TrainedModelConfig config = loadModelFromResource(modelId, false).build().ensureParsedDefinition(xContentRegistry); assert config.getModelDefinition().getTrainedModel() instanceof LangIdentNeuralNetwork; + assert config.getModelType() == TrainedModelType.LANG_IDENT; listener.onResponse( InferenceDefinition.builder() .setPreProcessors(config.getModelDefinition().getPreProcessors()) @@ -392,53 +413,35 @@ public void getTrainedModelForInference(final String modelId, final ActionListen } } - SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN) - .setQuery(QueryBuilders.constantScoreQuery(QueryBuilders - .boolQuery() - .filter(QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId)) - .filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), - TrainedModelDefinitionDoc.NAME)))) - .setSize(MAX_NUM_DEFINITION_DOCS) - // First find the latest index - .addSort("_index", SortOrder.DESC) - // Then, sort by doc_num - .addSort(SortBuilders.fieldSort(TrainedModelDefinitionDoc.DOC_NUM.getPreferredName()) - .order(SortOrder.ASC) - .unmappedType("long")) - .request(); - executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap( - // TODO how could we stream in the model definition WHILE parsing it? - // This would reduce the overall memory usage as we won't have to load the whole compressed string - // XContentParser supports streams. - searchResponse -> { - if (searchResponse.getHits().getHits().length == 0) { - listener.onFailure(new ResourceNotFoundException( - Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))); - return; - } - List docs = handleHits(searchResponse.getHits().getHits(), - modelId, - this::parseModelDefinitionDocLenientlyFromSource); + List docs = new ArrayList<>(); + ChunkedTrainedModelRestorer modelRestorer = + new ChunkedTrainedModelRestorer(modelId, client, + client.threadPool().executor(MachineLearning.UTILITY_THREAD_POOL_NAME), xContentRegistry); + + // TODO how could we stream in the model definition WHILE parsing it? + // This would reduce the overall memory usage as we won't have to load the whole compressed string + // XContentParser supports streams.
+ modelRestorer.restoreModelDefinition(docs::add, + success -> { try { - String compressedString = getDefinitionFromDocs(docs, modelId); + BytesReference compressedData = getDefinitionFromDocs(docs, modelId); InferenceDefinition inferenceDefinition = InferenceToXContentCompressor.inflate( - compressedString, + compressedData, InferenceDefinition::fromXContent, xContentRegistry); + listener.onResponse(inferenceDefinition); - } catch (ElasticsearchException elasticsearchException) { - listener.onFailure(elasticsearchException); + } catch (Exception e) { + listener.onFailure(e); } }, e -> { if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { listener.onFailure(new ResourceNotFoundException( Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))); - return; } listener.onFailure(e); - } - )); + }); } public void getTrainedModel(final String modelId, @@ -513,30 +516,15 @@ public void getTrainedModel(final String modelId, .request()); if (includes.isIncludeModelDefinition()) { - multiSearchRequestBuilder.add(client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN) - .setQuery(QueryBuilders.constantScoreQuery(QueryBuilders - .boolQuery() - .filter(QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId)) - .filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelDefinitionDoc.NAME)))) - // There should be AT MOST these many docs. There might be more if definitions have been reindex to newer indices - // If this ends up getting duplicate groups of definition documents, the parsing logic will throw away any doc that - // is in a different index than the first index seen. - .setSize(MAX_NUM_DEFINITION_DOCS) - // First find the latest index - .addSort("_index", SortOrder.DESC) - // Then, sort by doc_num - .addSort(SortBuilders.fieldSort(TrainedModelDefinitionDoc.DOC_NUM.getPreferredName()) - .order(SortOrder.ASC) - // We need this for the search not to fail when there are no mappings yet in the index - .unmappedType("long")) - .request()); + multiSearchRequestBuilder.add( + ChunkedTrainedModelRestorer.buildSearch(client, modelId, InferenceIndexConstants.INDEX_PATTERN, MAX_NUM_DEFINITION_DOCS)); } ActionListener multiSearchResponseActionListener = ActionListener.wrap( multiSearchResponse -> { TrainedModelConfig.Builder builder; try { - builder = handleSearchItem(multiSearchResponse.getResponses()[0], modelId, this::parseInferenceDocLenientlyFromSource); + builder = handleSearchItem(multiSearchResponse.getResponses()[0], modelId, this::parseModelConfigLenientlyFromSource); } catch (ResourceNotFoundException ex) { getTrainedModelListener.onFailure(new ResourceNotFoundException( Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); @@ -550,10 +538,12 @@ public void getTrainedModel(final String modelId, try { List docs = handleSearchItems(multiSearchResponse.getResponses()[1], modelId, - this::parseModelDefinitionDocLenientlyFromSource); + (bytes, resourceId) -> + ChunkedTrainedModelRestorer.parseModelDefinitionDocLenientlyFromSource( + bytes, resourceId, xContentRegistry)); try { - String compressedString = getDefinitionFromDocs(docs, modelId); - builder.setDefinitionFromString(compressedString); + BytesReference compressedData = getDefinitionFromDocs(docs, modelId); + builder.setDefinitionFromBytes(compressedData); } catch (ElasticsearchException elasticsearchException) { getTrainedModelListener.onFailure(elasticsearchException); return; @@ -691,7 +681,7 @@ public void getTrainedModels(Map> 
modelIds, try { if (observedIds.contains(searchHit.getId()) == false) { configs.add( - parseInferenceDocLenientlyFromSource(searchHit.getSourceRef(), searchHit.getId()) + parseModelConfigLenientlyFromSource(searchHit.getSourceRef(), searchHit.getId()) ); observedIds.add(searchHit.getId()); } @@ -706,7 +696,8 @@ public void getTrainedModels(Map> modelIds, // Otherwise, treat it as if it was never expanded to begin with. Set missingConfigs = Sets.difference(modelIds.keySet(), observedIds); if (missingConfigs.isEmpty() == false && allowNoResources == false) { - getTrainedModelListener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs)); + getTrainedModelListener.onFailure(new ResourceNotFoundException( + Messages.getMessage(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs))); return; } // Ensure sorted even with the injection of locally resourced models @@ -1104,52 +1095,55 @@ private static List handleHits(SearchHit[] hits, return results; } - private static String getDefinitionFromDocs(List docs, String modelId) throws ElasticsearchException { - String compressedString = docs.stream() - .map(TrainedModelDefinitionDoc::getCompressedString) - .collect(Collectors.joining()); - // BWC for when we tracked the total definition length - // TODO: remove in 9 + private static BytesReference getDefinitionFromDocs(List docs, + String modelId) throws ElasticsearchException { + + BytesReference[] bb = new BytesReference[docs.size()]; + for (int i = 0; i < docs.size(); i++) { + bb[i] = docs.get(i).getBinaryData(); + } + BytesReference bytes = CompositeBytesReference.of(bb); + if (docs.get(0).getTotalDefinitionLength() != null) { - if (compressedString.length() != docs.get(0).getTotalDefinitionLength()) { - throw ExceptionsHelper.serverError(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)); - } - } else { - TrainedModelDefinitionDoc lastDoc = docs.get(docs.size() - 1); - // Either we are missing the last doc, or some previous doc - if(lastDoc.isEos() == false || lastDoc.getDocNum() != docs.size() - 1) { + if (bytes.length() != docs.get(0).getTotalDefinitionLength()) { throw ExceptionsHelper.serverError(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)); } } - return compressedString; - } - static List chunkStringWithSize(String str, int chunkSize) { - List subStrings = new ArrayList<>((int)Math.ceil(str.length()/(double)chunkSize)); - for (int i = 0; i < str.length();i += chunkSize) { - subStrings.add(str.substring(i, Math.min(i + chunkSize, str.length()))); + TrainedModelDefinitionDoc lastDoc = docs.get(docs.size() - 1); + // Either we are missing the last doc, or some previous doc + if (lastDoc.isEos() == false || lastDoc.getDocNum() != docs.size() - 1) { + throw ExceptionsHelper.serverError(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)); } - return subStrings; + return bytes; } - private TrainedModelConfig.Builder parseInferenceDocLenientlyFromSource(BytesReference source, String modelId) throws IOException { - try (InputStream stream = source.streamInput(); - XContentParser parser = XContentFactory.xContent(XContentType.JSON) - .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) { - return TrainedModelConfig.fromXContent(parser, true); - } catch (IOException e) { - logger.error(new ParameterizedMessage("[{}] failed to parse model", modelId), e); - throw e; + public static List chunkDefinitionWithSize(BytesReference definition, int chunkSize) { + List chunks = new 
ArrayList<>((int)Math.ceil(definition.length()/(double)chunkSize)); + for (int i = 0; i < definition.length();i += chunkSize) { + BytesReference chunk = definition.slice(i, Math.min(chunkSize, definition.length() - i)); + chunks.add(chunk); } + return chunks; } - private TrainedModelDefinitionDoc parseModelDefinitionDocLenientlyFromSource(BytesReference source, String modelId) throws IOException { + private TrainedModelConfig.Builder parseModelConfigLenientlyFromSource(BytesReference source, String modelId) throws IOException { try (InputStream stream = source.streamInput(); XContentParser parser = XContentFactory.xContent(XContentType.JSON) .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) { - return TrainedModelDefinitionDoc.fromXContent(parser, true).build(); + TrainedModelConfig.Builder builder = TrainedModelConfig.fromXContent(parser, true); + + if (builder.getModelType() == null) { + // before TrainedModelConfig::modelType was added tree ensembles and the + // lang ident model were the only models supported. Models created after + // VERSION_3RD_PARTY_CONFIG_ADDED must have modelType set, if not set modelType + // is a tree ensemble + assert builder.getVersion().before(TrainedModelConfig.VERSION_3RD_PARTY_CONFIG_ADDED); + builder.setModelType(TrainedModelType.TREE_ENSEMBLE); + } + return builder; } catch (IOException e) { - logger.error(new ParameterizedMessage("[{}] failed to parse model definition", modelId), e); + logger.error(new ParameterizedMessage("[{}] failed to parse model", modelId), e); throw e; } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcess.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcess.java new file mode 100644 index 0000000000000..11eed2b5a2950 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcess.java @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
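As a small illustration of the fixed-size chunking used when storing a compressed definition (the sizes here are invented; chunkDefinitionWithSize is the public helper added above):

    BytesReference definition = new BytesArray(new byte[10]);   // a 10-byte compressed definition
    List<BytesReference> chunks = TrainedModelProvider.chunkDefinitionWithSize(definition, 4);
    // three slices of 4, 4 and 2 bytes; getDefinitionFromDocs later reassembles them
    // with CompositeBytesReference.of(...) when the definition is read back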
+ */ + +package org.elasticsearch.xpack.ml.inference.pytorch.process; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.xpack.core.ml.inference.deployment.PyTorchResult; +import org.elasticsearch.xpack.ml.process.AbstractNativeProcess; +import org.elasticsearch.xpack.ml.process.NativeController; +import org.elasticsearch.xpack.ml.process.ProcessPipes; +import org.elasticsearch.xpack.ml.process.ProcessResultsParser; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.Iterator; +import java.util.List; +import java.util.function.Consumer; + +public class NativePyTorchProcess extends AbstractNativeProcess { + + private static final String NAME = "pytorch_inference"; + + private final ProcessResultsParser resultsParser; + + protected NativePyTorchProcess(String jobId, NativeController nativeController, ProcessPipes processPipes, int numberOfFields, + List filesToDelete, Consumer onProcessCrash) { + super(jobId, nativeController, processPipes, numberOfFields, filesToDelete, onProcessCrash); + this.resultsParser = new ProcessResultsParser<>(PyTorchResult.PARSER, NamedXContentRegistry.EMPTY); + } + + @Override + public String getName() { + return NAME; + } + + @Override + public void persistState() throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public void persistState(long snapshotTimestampMs, String snapshotId, String snapshotDescription) throws IOException { + throw new UnsupportedOperationException(); + } + + public void loadModel(String modelId, String index, PyTorchStateStreamer stateStreamer, ActionListener listener) { + stateStreamer.writeStateToStream(modelId, index, processRestoreStream(), listener); + } + + public Iterator readResults() { + return resultsParser.parseResults(processOutStream()); + } + + public void writeInferenceRequest(BytesReference jsonRequest) throws IOException { + processInStream().write(jsonRequest.array(), jsonRequest.arrayOffset(), jsonRequest.length()); + processInStream().flush(); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcessFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcessFactory.java new file mode 100644 index 0000000000000..dfd87bb324cd3 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcessFactory.java @@ -0,0 +1,103 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.inference.pytorch.process; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; +import org.elasticsearch.core.internal.io.IOUtils; +import org.elasticsearch.env.Environment; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.process.NativeController; +import org.elasticsearch.xpack.ml.process.ProcessPipes; +import org.elasticsearch.xpack.ml.utils.NamedPipeHelper; + +import java.io.IOException; +import java.nio.file.Path; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.ExecutorService; +import java.util.function.Consumer; + +public class NativePyTorchProcessFactory implements PyTorchProcessFactory { + + private static final Logger logger = LogManager.getLogger(NativePyTorchProcessFactory.class); + + private static final NamedPipeHelper NAMED_PIPE_HELPER = new NamedPipeHelper(); + + private final Environment env; + private final NativeController nativeController; + private volatile Duration processConnectTimeout; + + public NativePyTorchProcessFactory(Environment env, + NativeController nativeController, + ClusterService clusterService) { + this.env = Objects.requireNonNull(env); + this.nativeController = Objects.requireNonNull(nativeController); + setProcessConnectTimeout(MachineLearning.PROCESS_CONNECT_TIMEOUT.get(env.settings())); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.PROCESS_CONNECT_TIMEOUT, + this::setProcessConnectTimeout); + } + + void setProcessConnectTimeout(TimeValue processConnectTimeout) { + this.processConnectTimeout = Duration.ofMillis(processConnectTimeout.getMillis()); + } + + @Override + public NativePyTorchProcess createProcess(String modelId, ExecutorService executorService, Consumer onProcessCrash) { + List filesToDelete = new ArrayList<>(); + ProcessPipes processPipes = new ProcessPipes( + env, + NAMED_PIPE_HELPER, + processConnectTimeout, + PyTorchBuilder.PROCESS_NAME, + modelId, + null, + false, + true, + true, + true, + false + ); + + executeProcess(processPipes, filesToDelete); + + NativePyTorchProcess process = new NativePyTorchProcess(modelId, nativeController, processPipes, 0, filesToDelete, onProcessCrash); + + try { + process.start(executorService); + } catch(IOException | EsRejectedExecutionException e) { + String msg = "Failed to connect to pytorch process for job " + modelId; + logger.error(msg); + try { + IOUtils.close(process); + } catch (IOException ioe) { + logger.error("Can't close pytorch process", ioe); + } + throw ExceptionsHelper.serverError(msg, e); + } + return process; + } + + private void executeProcess(ProcessPipes processPipes, List filesToDelete) { + PyTorchBuilder pyTorchBuilder = new PyTorchBuilder(env::tmpFile, nativeController, processPipes, filesToDelete); + try { + pyTorchBuilder.build(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } catch (IOException e) { + throw ExceptionsHelper.serverError("Failed to launch PyTorch process"); + } + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilder.java 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilder.java new file mode 100644 index 0000000000000..60f50a8f70507 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilder.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.pytorch.process; + +import org.elasticsearch.xpack.ml.process.NativeController; +import org.elasticsearch.xpack.ml.process.ProcessPipes; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.function.Supplier; + +public class PyTorchBuilder { + + public static final String PROCESS_NAME = "pytorch_inference"; + private static final String PROCESS_PATH = "./" + PROCESS_NAME; + + private final Supplier tempDirPathSupplier; + private final NativeController nativeController; + private final ProcessPipes processPipes; + private final List filesToDelete; + + public PyTorchBuilder(Supplier tempDirPathSupplier, NativeController nativeController, ProcessPipes processPipes, + List filesToDelete) { + this.tempDirPathSupplier = Objects.requireNonNull(tempDirPathSupplier); + this.nativeController = Objects.requireNonNull(nativeController); + this.processPipes = Objects.requireNonNull(processPipes); + this.filesToDelete = Objects.requireNonNull(filesToDelete); + } + + public void build() throws IOException, InterruptedException { + List command = buildCommand(); + processPipes.addArgs(command); + nativeController.startProcess(command); + } + + private List buildCommand() { + List command = new ArrayList<>(); + command.add(PROCESS_PATH); + return command; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcessFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcessFactory.java new file mode 100644 index 0000000000000..0bf9b206be103 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcessFactory.java @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.pytorch.process; + +import java.util.concurrent.ExecutorService; +import java.util.function.Consumer; + +public interface PyTorchProcessFactory { + + NativePyTorchProcess createProcess(String modelId, ExecutorService executorService, Consumer onProcessCrash); +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcessManager.java new file mode 100644 index 0000000000000..c812e490217ed --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcessManager.java @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.pytorch.process; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +public class PyTorchProcessManager { + + private static final Logger logger = LogManager.getLogger(PyTorchProcessManager.class); + + public PyTorchProcessManager() { + + } + + public void start(String taskId) { + + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java new file mode 100644 index 0000000000000..61311b3ff5a37 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java @@ -0,0 +1,79 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.pytorch.process; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.xpack.core.ml.inference.deployment.PyTorchResult; + +import java.util.Iterator; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +public class PyTorchResultProcessor { + + private static final Logger logger = LogManager.getLogger(PyTorchResultProcessor.class); + + private final ConcurrentMap pendingResults = new ConcurrentHashMap<>(); + + private final String deploymentId; + private volatile boolean isStopping; + + public PyTorchResultProcessor(String deploymentId) { + this.deploymentId = Objects.requireNonNull(deploymentId); + } + + public void process(NativePyTorchProcess process) { + try { + Iterator iterator = process.readResults(); + while (iterator.hasNext()) { + PyTorchResult result = iterator.next(); + logger.debug(() -> new ParameterizedMessage("[{}] Parsed result with id [{}]", deploymentId, result.getRequestId())); + PendingResult pendingResult = pendingResults.get(result.getRequestId()); + if (pendingResult == null) { + logger.warn(() -> new ParameterizedMessage("[{}] no pending result for [{}]", deploymentId, result.getRequestId())); + } else { + pendingResult.result = result; + pendingResult.latch.countDown(); + } + } + } catch (Exception e) { + // No need to report error as we're stopping + if (isStopping == false) { + logger.error(new ParameterizedMessage("[{}] Error processing results", deploymentId), e); + } + } + logger.debug(() -> new ParameterizedMessage("[{}] Results processing finished", deploymentId)); + } + + public PyTorchResult waitForResult(String requestId, TimeValue timeout) throws InterruptedException { + PendingResult pendingResult = pendingResults.computeIfAbsent(requestId, k -> new PendingResult()); + try { + if (pendingResult.latch.await(timeout.millis(), TimeUnit.MILLISECONDS)) { + return pendingResult.result; + } + } finally { + pendingResults.remove(requestId); + } + return 
null; + } + + public void stop() { + isStopping = true; + } + + private static class PendingResult { + private volatile PyTorchResult result; + private final CountDownLatch latch = new CountDownLatch(1); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchStateStreamer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchStateStreamer.java new file mode 100644 index 0000000000000..34c11ec15b947 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchStateStreamer.java @@ -0,0 +1,125 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.pytorch.process; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.client.Client; +import org.elasticsearch.client.OriginSettingClient; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.xpack.ml.inference.persistence.ChunkedTrainedModelRestorer; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc; + +import java.io.IOException; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.util.Locale; +import java.util.Objects; +import java.util.concurrent.ExecutorService; + +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; + +/** + * PyTorch models in the TorchScript format are binary files divided + * into small chunks and base64 encoded for storage in Elasticsearch. + * The model is restored by base64 decoding the stored state and streaming + * the binary objects concatenated in order. There is no delineation between + * individual chunks the state should appear as one contiguous file. + */ +public class PyTorchStateStreamer { + + private static final Logger logger = LogManager.getLogger(PyTorchStateStreamer.class); + + private final OriginSettingClient client; + private final ExecutorService executorService; + private final NamedXContentRegistry xContentRegistry; + private volatile boolean isCancelled; + private boolean modelSizeWritten = false; + + public PyTorchStateStreamer(Client client, ExecutorService executorService, NamedXContentRegistry xContentRegistry) { + this.client = new OriginSettingClient(Objects.requireNonNull(client), ML_ORIGIN); + this.executorService = Objects.requireNonNull(executorService); + this.xContentRegistry = Objects.requireNonNull(xContentRegistry); + } + + /** + * Cancels the state streaming at the first opportunity. + */ + public void cancel() { + isCancelled = true; + } + + /** + * First writes the size of the model so the native process can + * allocated memory then writes the chunks of binary state. 
+ * + * @param modelId The model to write + * @param index The index to search for the model + * @param restoreStream The stream to write to + * @param listener Error and success listener + */ + public void writeStateToStream(String modelId, String index, OutputStream restoreStream, ActionListener listener) { + ChunkedTrainedModelRestorer restorer = new ChunkedTrainedModelRestorer(modelId, client, executorService, xContentRegistry); + restorer.setSearchIndex(index); + restorer.setSearchSize(1); + restorer.restoreModelDefinition(doc -> writeChunk(doc, restoreStream), listener::onResponse, listener::onFailure); + logger.debug("model [{}] state restored in [{}] documents from index [{}]", modelId, restorer.getNumDocsWritten(), index); + } + + private boolean writeChunk(TrainedModelDefinitionDoc doc, OutputStream outputStream) throws IOException { + if (isCancelled) { + return false; + } + + if (modelSizeWritten == false) { + writeModelSize(doc.getModelId(), doc.getTotalDefinitionLength(), outputStream); + modelSizeWritten = true; + } + + // The array backing the BytesReference may be bigger than what is + // referred to, so write only what is after the offset + outputStream.write(doc.getBinaryData().array(), doc.getBinaryData().arrayOffset(), doc.getBinaryData().length()); + return true; + } + + private void writeModelSize(String modelId, Long modelSizeBytes, OutputStream outputStream) throws IOException { + if (modelSizeBytes == null) { + String message = String.format(Locale.ROOT, + "The definition doc for model [%s] has a null value for field [%s]", + modelId, TrainedModelDefinitionDoc.TOTAL_DEFINITION_LENGTH.getPreferredName()); + logger.error(message); + throw new IllegalStateException(message); + } + if (modelSizeBytes <= 0) { + // The other end expects an unsigned 32-bit int; a negative value is invalid.
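For illustration only: the restore stream produced above is a 4-byte big-endian size header followed by the raw model bytes. The real consumer is the native pytorch_inference process, but a hypothetical Java-side reader of the same framing (stateInputStream is assumed) would be:

    DataInputStream in = new DataInputStream(stateInputStream);
    int modelSizeBytes = in.readInt();               // the 4-byte header written by writeModelSize
    byte[] model = in.readNBytes(modelSizeBytes);    // the binary chunks, concatenated in doc_num order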
+ // ByteSizeValue allows -1 bytes as a valid value so this check is still required + String message = String.format(Locale.ROOT, + "The definition doc for model [%s] has a negative value [%s] for field [%s]", + modelId, + modelSizeBytes, + TrainedModelDefinitionDoc.TOTAL_DEFINITION_LENGTH.getPreferredName()); + + logger.error(message); + throw new IllegalStateException(message); + } + + if (modelSizeBytes > Integer.MAX_VALUE) { + // TODO use a long in case models are larger than 2^31 bytes + String message = String.format(Locale.ROOT, + "model [%s] has a size [%s] larger than the max size [%s]", + modelId, modelSizeBytes, Integer.MAX_VALUE); + logger.error(message); + throw new IllegalStateException(message); + } + + ByteBuffer lengthBuffer = ByteBuffer.allocate(4); + lengthBuffer.putInt(modelSizeBytes.intValue()); + outputStream.write(lengthBuffer.array()); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/AbstractNativeProcess.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/AbstractNativeProcess.java index 7bcd5646156bf..781ac95d074fe 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/AbstractNativeProcess.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/AbstractNativeProcess.java @@ -293,7 +293,7 @@ protected InputStream processOutStream() { } @Nullable - private OutputStream processInStream() { + protected OutputStream processInStream() { return processInStream.get(); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/cat/RestCatTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/cat/RestCatTrainedModelsAction.java index 3f74258be7cde..3ebfeca212b97 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/cat/RestCatTrainedModelsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/cat/RestCatTrainedModelsAction.java @@ -153,6 +153,9 @@ protected Table getTableWithHeader(RestRequest request) { table.addCell("description", TableColumnAttributeBuilder.builder("The model description", false) .setAliases("d") .build()); + table.addCell("type", TableColumnAttributeBuilder.builder("The model type") + .setAliases("t") + .build()); // Trained Model Stats table.addCell("ingest.pipelines", TableColumnAttributeBuilder.builder("The number of pipelines referencing the model") @@ -239,6 +242,7 @@ private Table buildTable(RestRequest request, table.addCell(config.getCreateTime()); table.addCell(config.getVersion().toString()); table.addCell(config.getDescription()); + table.addCell(config.getModelType()); GetTrainedModelsStatsAction.Response.TrainedModelStats modelStats = statsMap.get(config.getModelId()); table.addCell(modelStats.getPipelineCount()); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java index 207186cdab7e3..4d39f6c259ecb 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java @@ -42,6 +42,7 @@ public class RestGetTrainedModelsAction extends BaseRestHandler { private static final DeprecationLogger deprecationLogger = DeprecationLogger.getLogger(RestGetTrainedModelsAction.class); + private static final String 
INCLUDE_MODEL_DEFINITION = "include_model_definition"; @Override public List routes() { @@ -74,14 +75,14 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient GetTrainedModelsAction.Request.INCLUDE.getPreferredName(), Strings.EMPTY_ARRAY))); final GetTrainedModelsAction.Request request; - if (restRequest.hasParam(GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION)) { + if (restRequest.hasParam(INCLUDE_MODEL_DEFINITION)) { deprecationLogger.deprecate( DeprecationCategory.API, - GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION, + INCLUDE_MODEL_DEFINITION, "[{}] parameter is deprecated! Use [include=definition] instead.", - GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION); + INCLUDE_MODEL_DEFINITION); request = new GetTrainedModelsAction.Request(modelId, - restRequest.paramAsBoolean(GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION, false), + restRequest.paramAsBoolean(INCLUDE_MODEL_DEFINITION, false), tags); } else { request = new GetTrainedModelsAction.Request(modelId, tags, includes); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestInferTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestInferTrainedModelDeploymentAction.java new file mode 100644 index 0000000000000..5e3b03188a47c --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestInferTrainedModelDeploymentAction.java @@ -0,0 +1,52 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.rest.inference; + +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +import static org.elasticsearch.rest.RestRequest.Method.POST; +import static org.elasticsearch.xpack.ml.MachineLearning.BASE_PATH; + +public class RestInferTrainedModelDeploymentAction extends BaseRestHandler { + + @Override + public String getName() { + return "xpack_ml_infer_trained_models_deployment_action"; + } + + @Override + public List routes() { + return Collections.singletonList( + new Route( + POST, + BASE_PATH + "trained_models/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}/deployment/_infer") + ); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + String deploymentId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName()); + if (restRequest.hasContent() == false) { + throw ExceptionsHelper.badRequestException("requires body"); + } + InferTrainedModelDeploymentAction.Request request = + InferTrainedModelDeploymentAction.Request.parseRequest(deploymentId, restRequest.contentParser()); + + return channel -> client.execute(InferTrainedModelDeploymentAction.INSTANCE, request, new RestToXContentListener<>(channel)); + } +} diff --git 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java new file mode 100644 index 0000000000000..fd5ac662a9edf --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java @@ -0,0 +1,44 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.rest.inference; + +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.ml.MachineLearning; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +import static org.elasticsearch.rest.RestRequest.Method.POST; + +public class RestStartTrainedModelDeploymentAction extends BaseRestHandler { + + @Override + public String getName() { + return "xpack_ml_start_trained_models_deployment_action"; + } + + @Override + public List routes() { + return Collections.singletonList( + new Route(POST, + MachineLearning.BASE_PATH + "trained_models/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}/deployment/_start")); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName()); + StartTrainedModelDeploymentAction.Request request = new StartTrainedModelDeploymentAction.Request(modelId); + return channel -> client.execute(StartTrainedModelDeploymentAction.INSTANCE, request, new RestToXContentListener<>(channel)); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStopTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStopTrainedModelDeploymentAction.java new file mode 100644 index 0000000000000..90ca93371255a --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStopTrainedModelDeploymentAction.java @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.rest.inference; + +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +import static org.elasticsearch.rest.RestRequest.Method.POST; +import static org.elasticsearch.xpack.ml.MachineLearning.BASE_PATH; + +public class RestStopTrainedModelDeploymentAction extends BaseRestHandler { + + @Override + public String getName() { + return "xpack_ml_stop_trained_models_deployment_action"; + } + + @Override + public List routes() { + return Collections.singletonList( + new Route( + POST, + BASE_PATH + "trained_models/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}/deployment/_stop") + ); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName()); + StopTrainedModelDeploymentAction.Request request = new StopTrainedModelDeploymentAction.Request(modelId); + return channel -> client.execute(StopTrainedModelDeploymentAction.INSTANCE, request, new RestToXContentListener<>(channel)); + } +} diff --git a/x-pack/plugin/ml/src/main/resources/org/elasticsearch/xpack/ml/inference/persistence/lang_ident_model_1.json b/x-pack/plugin/ml/src/main/resources/org/elasticsearch/xpack/ml/inference/persistence/lang_ident_model_1.json index 6d0ffcf579643..35f1649f57995 100644 --- a/x-pack/plugin/ml/src/main/resources/org/elasticsearch/xpack/ml/inference/persistence/lang_ident_model_1.json +++ b/x-pack/plugin/ml/src/main/resources/org/elasticsearch/xpack/ml/inference/persistence/lang_ident_model_1.json @@ -1,6 +1,7 @@ { "model_id" : "lang_ident_model_1", "created_by" : "_xpack", + "model_type" : "lang_ident", "version" : "7.6.0", "description" : "Model used for identifying language from arbitrary input text.", "create_time" : 1575548914594, diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java index 4c87429026df9..49be42cbcaee2 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.FeatureImportanceBaselineTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportanceTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.HyperparametersTests; @@ -119,7 +120,7 @@ public void testPersistAllDocs() { .limit(randomIntBetween(1, 10)) .collect(Collectors.toList())); - resultProcessor.createAndIndexInferenceModelConfig(modelSizeInfo); + 
resultProcessor.createAndIndexInferenceModelConfig(modelSizeInfo, TrainedModelType.TREE_ENSEMBLE); resultProcessor.createAndIndexInferenceModelDoc(chunk1); resultProcessor.createAndIndexInferenceModelDoc(chunk2); resultProcessor.createAndIndexInferenceModelMetadata(modelMetadata); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilderTests.java new file mode 100644 index 0000000000000..24037e3bd9fe4 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilderTests.java @@ -0,0 +1,40 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.nlp; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Map; + +import static org.hamcrest.Matchers.hasSize; + +public class BertRequestBuilderTests extends ESTestCase { + + public void testBuildRequest() throws IOException { + BertTokenizer tokenizer = BertTokenizer.builder( + Arrays.asList("Elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN)).build(); + + BertRequestBuilder requestBuilder = new BertRequestBuilder(tokenizer); + BytesReference bytesReference = requestBuilder.buildRequest("Elasticsearch fun", "request1"); + + Map jsonDocAsMap = XContentHelper.convertToMap(bytesReference, true, XContentType.JSON).v2(); + + assertThat(jsonDocAsMap.keySet(), hasSize(5)); + assertEquals("request1", jsonDocAsMap.get("request_id")); + assertEquals(Arrays.asList(3, 0, 1, 2, 4), jsonDocAsMap.get("tokens")); + assertEquals(Arrays.asList(1, 1, 1, 1, 1), jsonDocAsMap.get("arg_1")); + assertEquals(Arrays.asList(0, 0, 0, 0, 0), jsonDocAsMap.get("arg_2")); + assertEquals(Arrays.asList(0, 1, 2, 3, 4), jsonDocAsMap.get("arg_3")); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java new file mode 100644 index 0000000000000..4a2b65670f13c --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java @@ -0,0 +1,93 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.inference.nlp; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.inference.deployment.PyTorchResult; +import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults; +import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.hasSize; +import static org.mockito.Mockito.mock; + +public class FillMaskProcessorTests extends ESTestCase { + + public void testProcessResults() { + // only the scores of the MASK index array + // are used the rest is filler + double[][] scores = { + { 0, 0, 0, 0, 0, 0, 0}, // The + { 0, 0, 0, 0, 0, 0, 0}, // capital + { 0, 0, 0, 0, 0, 0, 0}, // of + { 0.01, 0.01, 0.3, 0.1, 0.01, 0.2, 1.2}, // MASK + { 0, 0, 0, 0, 0, 0, 0}, // is + { 0, 0, 0, 0, 0, 0, 0} // paris + }; + + String input = "The capital of " + BertTokenizer.MASK_TOKEN + " is Paris"; + + List vocab = Arrays.asList("The", "capital", "of", BertTokenizer.MASK_TOKEN, "is", "Paris", "France"); + List tokens = Arrays.asList(input.split(" ")); + int[] tokenMap = new int[] {0, 1, 2, 3, 4, 5}; + int[] tokenIds = new int[] {0, 1, 2, 3, 4, 5}; + + BertTokenizer.TokenizationResult tokenization = new BertTokenizer.TokenizationResult(input, vocab, tokens, + tokenIds, tokenMap); + + FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class)); + FillMaskResults result = (FillMaskResults) processor.processResult(tokenization, new PyTorchResult("1", scores, null)); + assertThat(result.getPredictions(), hasSize(5)); + FillMaskResults.Prediction prediction = result.getPredictions().get(0); + assertEquals("France", prediction.getToken()); + assertEquals("The capital of France is Paris", prediction.getSequence()); + + prediction = result.getPredictions().get(1); + assertEquals("of", prediction.getToken()); + assertEquals("The capital of of is Paris", prediction.getSequence()); + + prediction = result.getPredictions().get(2); + assertEquals("Paris", prediction.getToken()); + assertEquals("The capital of Paris is Paris", prediction.getSequence()); + } + + public void testProcessResults_GivenMissingTokens() { + BertTokenizer.TokenizationResult tokenization = + new BertTokenizer.TokenizationResult("", Collections.emptyList(), Collections.emptyList(), + new int[] {}, new int[] {}); + + FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class)); + PyTorchResult pyTorchResult = new PyTorchResult("1", new double[][]{{}}, null); + FillMaskResults result = (FillMaskResults) processor.processResult(tokenization, pyTorchResult); + assertThat(result.getPredictions(), empty()); + } + + public void testValidate_GivenMissingMaskToken() { + String input = "The capital of France is Paris"; + + FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class)); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, + () -> processor.validateInputs(input)); + assertThat(e.getMessage(), containsString("no [MASK] token could be found")); + } + + + public void testProcessResults_GivenMultipleMaskTokens() { + String input = "The capital of [MASK] is [MASK]"; + + FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class)); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, + () -> processor.validateInputs(input)); + 
assertThat(e.getMessage(), containsString("only one [MASK] token should exist in the input")); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerResultProcessorTests.java new file mode 100644 index 0000000000000..cb3e14081b22b --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerResultProcessorTests.java @@ -0,0 +1,129 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.nlp; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.inference.deployment.PyTorchResult; +import org.elasticsearch.xpack.core.ml.inference.results.NerResults; +import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; + +public class NerResultProcessorTests extends ESTestCase { + + public void testProcessResults_GivenNoTokens() { + NerResultProcessor processor = createProcessor(Collections.emptyList(), ""); + NerResults result = (NerResults) processor.processResult(new PyTorchResult("test", null, null)); + assertThat(result.getEntityGroups(), is(empty())); + } + + public void testProcessResults() { + NerResultProcessor processor = createProcessor(Arrays.asList("el", "##astic", "##search", "many", "use", "in", "london"), + "Many use Elasticsearch in London"); + double[][] scores = { + { 7, 0, 0, 0, 0, 0, 0, 0, 0}, // many + { 7, 0, 0, 0, 0, 0, 0, 0, 0}, // use + { 0.01, 0.01, 0, 0.01, 0, 7, 0, 3, 0}, // el + { 0.01, 0.01, 0, 0, 0, 0, 0, 0, 0}, // ##astic + { 0, 0, 0, 0, 0, 0, 0, 0, 0}, // ##search + { 0, 0, 0, 0, 0, 0, 0, 0, 0}, // in + { 0, 0, 0, 0, 0, 0, 0, 6, 0} // london + }; + NerResults result = (NerResults) processor.processResult(new PyTorchResult("1", scores, null)); + + assertThat(result.getEntityGroups().size(), equalTo(2)); + assertThat(result.getEntityGroups().get(0).getWord(), equalTo("elasticsearch")); + assertThat(result.getEntityGroups().get(0).getLabel(), equalTo(NerProcessor.Entity.ORGANISATION.toString())); + assertThat(result.getEntityGroups().get(1).getWord(), equalTo("london")); + assertThat(result.getEntityGroups().get(1).getLabel(), equalTo(NerProcessor.Entity.LOCATION.toString())); + } + + public void testGroupTaggedTokens() { + List tokens = new ArrayList<>(); + tokens.add(new NerResultProcessor.TaggedToken("Hi", NerProcessor.IobTag.O, 1.0)); + tokens.add(new NerResultProcessor.TaggedToken("Sarah", NerProcessor.IobTag.B_PER, 1.0)); + tokens.add(new NerResultProcessor.TaggedToken("Jessica", NerProcessor.IobTag.I_PER, 1.0)); + tokens.add(new NerResultProcessor.TaggedToken("I", NerProcessor.IobTag.O, 1.0)); + tokens.add(new NerResultProcessor.TaggedToken("live", NerProcessor.IobTag.O, 1.0)); + tokens.add(new NerResultProcessor.TaggedToken("in", NerProcessor.IobTag.O, 1.0)); + tokens.add(new NerResultProcessor.TaggedToken("Manchester", NerProcessor.IobTag.B_LOC, 1.0)); + tokens.add(new NerResultProcessor.TaggedToken("and", 
NerProcessor.IobTag.O, 1.0)); + tokens.add(new NerResultProcessor.TaggedToken("work", NerProcessor.IobTag.O, 1.0)); + tokens.add(new NerResultProcessor.TaggedToken("for", NerProcessor.IobTag.O, 1.0)); + tokens.add(new NerResultProcessor.TaggedToken("Elastic", NerProcessor.IobTag.B_ORG, 1.0)); + + List entityGroups = NerResultProcessor.groupTaggedTokens(tokens); + assertThat(entityGroups, hasSize(3)); + assertThat(entityGroups.get(0).getLabel(), equalTo("person")); + assertThat(entityGroups.get(0).getWord(), equalTo("Sarah Jessica")); + assertThat(entityGroups.get(1).getLabel(), equalTo("location")); + assertThat(entityGroups.get(1).getWord(), equalTo("Manchester")); + assertThat(entityGroups.get(2).getLabel(), equalTo("organisation")); + assertThat(entityGroups.get(2).getWord(), equalTo("Elastic")); + } + + public void testGroupTaggedTokens_GivenNoEntities() { + List tokens = new ArrayList<>(); + tokens.add(new NerResultProcessor.TaggedToken("Hi", NerProcessor.IobTag.O, 1.0)); + tokens.add(new NerResultProcessor.TaggedToken("there", NerProcessor.IobTag.O, 1.0)); + + List entityGroups = NerResultProcessor.groupTaggedTokens(tokens); + assertThat(entityGroups, is(empty())); + } + + public void testGroupTaggedTokens_GivenConsecutiveEntities() { + List tokens = new ArrayList<>(); + tokens.add(new NerResultProcessor.TaggedToken("Rita", NerProcessor.IobTag.B_PER, 1.0)); + tokens.add(new NerResultProcessor.TaggedToken("Sue", NerProcessor.IobTag.B_PER, 1.0)); + tokens.add(new NerResultProcessor.TaggedToken("and", NerProcessor.IobTag.O, 1.0)); + tokens.add(new NerResultProcessor.TaggedToken("Bob", NerProcessor.IobTag.B_PER, 1.0)); + tokens.add(new NerResultProcessor.TaggedToken("to", NerProcessor.IobTag.O, 1.0)); + + List entityGroups = NerResultProcessor.groupTaggedTokens(tokens); + assertThat(entityGroups, hasSize(3)); + assertThat(entityGroups.get(0).getLabel(), equalTo("person")); + assertThat(entityGroups.get(0).getWord(), equalTo("Rita")); + assertThat(entityGroups.get(1).getLabel(), equalTo("person")); + assertThat(entityGroups.get(1).getWord(), equalTo("Sue")); + assertThat(entityGroups.get(2).getLabel(), equalTo("person")); + assertThat(entityGroups.get(2).getWord(), equalTo("Bob")); + } + + public void testGroupTaggedTokens_GivenConsecutiveContinuingEntities() { + List tokens = new ArrayList<>(); + tokens.add(new NerResultProcessor.TaggedToken("FirstName", NerProcessor.IobTag.B_PER, 1.0)); + tokens.add(new NerResultProcessor.TaggedToken("SecondName", NerProcessor.IobTag.I_PER, 1.0)); + tokens.add(new NerResultProcessor.TaggedToken("NextPerson", NerProcessor.IobTag.B_PER, 1.0)); + tokens.add(new NerResultProcessor.TaggedToken("NextPersonSecondName", NerProcessor.IobTag.I_PER, 1.0)); + tokens.add(new NerResultProcessor.TaggedToken("something_else", NerProcessor.IobTag.B_ORG, 1.0)); + + List entityGroups = NerResultProcessor.groupTaggedTokens(tokens); + assertThat(entityGroups, hasSize(3)); + assertThat(entityGroups.get(0).getLabel(), equalTo("person")); + assertThat(entityGroups.get(0).getWord(), equalTo("FirstName SecondName")); + assertThat(entityGroups.get(1).getLabel(), equalTo("person")); + assertThat(entityGroups.get(1).getWord(), equalTo("NextPerson NextPersonSecondName")); + assertThat(entityGroups.get(2).getLabel(), equalTo("organisation")); + } + + private static NerResultProcessor createProcessor(List vocab, String input){ + BertTokenizer tokenizer = BertTokenizer.builder(vocab) + .setDoLowerCase(true) + .build(); + BertTokenizer.TokenizationResult tokenizationResult = 
tokenizer.tokenize(input, false); + return new NerResultProcessor(tokenizationResult); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NlpHelpersTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NlpHelpersTests.java new file mode 100644 index 0000000000000..bb72f3ee88069 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NlpHelpersTests.java @@ -0,0 +1,126 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.nlp; + +import org.elasticsearch.search.aggregations.pipeline.MovingFunctions; +import org.elasticsearch.test.ESTestCase; + +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.lessThanOrEqualTo; + +public class NlpHelpersTests extends ESTestCase { + + public void testConvertToProbabilitiesBySoftMax_GivenConcreteExample() { + double[][] scores = { + { 0.1, 0.2, 3}, + { 6, 0.2, 0.1} + }; + + double[][] probabilities = NlpHelpers.convertToProbabilitiesBySoftMax(scores); + + assertThat(probabilities[0][0], closeTo(0.04931133, 0.00000001)); + assertThat(probabilities[0][1], closeTo(0.05449744, 0.00000001)); + assertThat(probabilities[0][2], closeTo(0.89619123, 0.00000001)); + assertThat(probabilities[1][0], closeTo(0.99426607, 0.00000001)); + assertThat(probabilities[1][1], closeTo(0.00301019, 0.00000001)); + assertThat(probabilities[1][2], closeTo(0.00272374, 0.00000001)); + } + + public void testConvertToProbabilitiesBySoftMax_OneDimension() { + double[] scores = { 0.1, 0.2, 3}; + double[] probabilities = NlpHelpers.convertToProbabilitiesBySoftMax(scores); + + assertThat(probabilities[0], closeTo(0.04931133, 0.00000001)); + assertThat(probabilities[1], closeTo(0.05449744, 0.00000001)); + assertThat(probabilities[2], closeTo(0.89619123, 0.00000001)); + } + + public void testConvertToProbabilitiesBySoftMax_GivenRandom() { + double[][] scores = new double[100][100]; + for (int i = 0; i < scores.length; i++) { + for (int j = 0; j < scores[i].length; j++) { + scores[i][j] = randomDoubleBetween(-10, 10, true); + } + } + + double[][] probabilities = NlpHelpers.convertToProbabilitiesBySoftMax(scores); + + // Assert invariants that + // 1. each row sums to 1 + // 2. 
all values are in [0-1] + assertThat(probabilities.length, equalTo(scores.length)); + for (int i = 0; i < probabilities.length; i++) { + assertThat(probabilities[i].length, equalTo(scores[i].length)); + double rowSum = MovingFunctions.sum(probabilities[i]); + assertThat(rowSum, closeTo(1.0, 0.01)); + for (int j = 0; j < probabilities[i].length; j++) { + assertThat(probabilities[i][j], greaterThanOrEqualTo(0.0)); + assertThat(probabilities[i][j], lessThanOrEqualTo(1.0)); + } + } + } + + public void testTopK_SimpleCase() { + int k = 3; + double[] data = new double[]{1.0, 0.0, 2.0, 8.0, 9.0, 4.2, 4.2, 3.0}; + + NlpHelpers.ScoreAndIndex[] scoreAndIndices = NlpHelpers.topK(k, data); + assertEquals(4, scoreAndIndices[0].index); + assertEquals(3, scoreAndIndices[1].index); + assertEquals(5, scoreAndIndices[2].index); + assertEquals(9.0, scoreAndIndices[0].score, 0.001); + assertEquals(8.0, scoreAndIndices[1].score, 0.001); + assertEquals(4.2, scoreAndIndices[2].score, 0.001); + } + + public void testTopK() { + // in this case use the standard java libraries to sort the + // doubles and track the starting index of each value + int size = randomIntBetween(50, 100); + int k = randomIntBetween(1, 10); + double[] data = new double[size]; + for (int i = 0; i < data.length; i++) { + data[i] = randomDouble(); + } + + AtomicInteger index = new AtomicInteger(0); + List sortedByValue = + Stream.generate(() -> new NlpHelpers.ScoreAndIndex(data[index.get()], index.getAndIncrement())) + .limit(size) + .sorted((o1, o2) -> Double.compare(o2.score, o1.score)) + .collect(Collectors.toList()); + + NlpHelpers.ScoreAndIndex[] scoreAndIndices = NlpHelpers.topK(k, data); + assertEquals(k, scoreAndIndices.length); + + // now compare the starting indices in the sorted list + // to the top k. + for (int i = 0; i < scoreAndIndices.length; i++) { + assertEquals(sortedByValue.get(i), scoreAndIndices[i]); + } + } + + public void testTopK_KGreaterThanArrayLength() { + int k = 6; + double[] data = new double[]{1.0, 0.0, 2.0, 8.0}; + + NlpHelpers.ScoreAndIndex[] scoreAndIndices = NlpHelpers.topK(k, data); + assertEquals(4, scoreAndIndices.length); + assertEquals(3, scoreAndIndices[0].index); + assertEquals(2, scoreAndIndices[1].index); + assertEquals(0, scoreAndIndices[2].index); + assertEquals(1, scoreAndIndices[3].index); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizerTests.java new file mode 100644 index 0000000000000..b06f7cf84839c --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizerTests.java @@ -0,0 +1,184 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.inference.nlp.tokenizers; + +import org.elasticsearch.test.ESTestCase; + +import java.util.Collections; +import java.util.List; + +import static org.hamcrest.Matchers.arrayContaining; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.sameInstance; + +/** + * Some test cases taken from + * https://github.com/huggingface/transformers/blob/ba8c4d0ac04acfcdbdeaed954f698d6d5ec3e532/tests/test_tokenization_bert.py + */ +public class BasicTokenizerTests extends ESTestCase { + + public void testLowerCase() { + BasicTokenizer tokenizer = new BasicTokenizer(); + List tokens = tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "); + assertThat(tokens, contains("hello", "!", "how", "are", "you", "?")); + + tokens = tokenizer.tokenize("H\u00E9llo"); + assertThat(tokens, contains("hello")); + } + + public void testLowerCaseWithoutStripAccents() { + BasicTokenizer tokenizer = new BasicTokenizer(true, true, false); + List tokens = tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "); + assertThat(tokens, contains("hällo", "!", "how", "are", "you", "?")); + + tokens = tokenizer.tokenize("H\u00E9llo"); + assertThat(tokens, contains("h\u00E9llo")); + } + + public void testLowerCaseStripAccentsDefault() { + BasicTokenizer tokenizer = new BasicTokenizer(true, true); + List tokens = tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "); + assertThat(tokens, contains("hallo", "!", "how", "are", "you", "?")); + + tokens = tokenizer.tokenize("H\u00E9llo"); + assertThat(tokens, contains("hello")); + } + + public void testNoLower() { + List tokens = new BasicTokenizer(false, true, false).tokenize(" \tHäLLo!how \n Are yoU? "); + assertThat(tokens, contains("HäLLo", "!", "how", "Are", "yoU", "?")); + } + + public void testNoLowerStripAccents() { + List tokens = new BasicTokenizer(false, true, true).tokenize(" \tHäLLo!how \n Are yoU? "); + assertThat(tokens, contains("HaLLo", "!", "how", "Are", "yoU", "?")); + } + + public void testNeverSplit() { + BasicTokenizer tokenizer = new BasicTokenizer(false, false, false, Collections.singleton("[UNK]")); + List tokens = tokenizer.tokenize(" \tHeLLo!how \n Are yoU? 
[UNK]"); + assertThat(tokens, contains("HeLLo", "!", "how", "Are", "yoU", "?", "[UNK]")); + + tokens = tokenizer.tokenize("Hello [UNK]."); + assertThat(tokens, contains("Hello", "[UNK]", ".")); + + tokens = tokenizer.tokenize("Hello [UNK]?"); + assertThat(tokens, contains("Hello", "[UNK]", "?")); + } + + public void testSplitOnPunctuation() { + List tokens = BasicTokenizer.splitOnPunctuation("hi!"); + assertThat(tokens, contains("hi", "!")); + + tokens = BasicTokenizer.splitOnPunctuation("hi."); + assertThat(tokens, contains("hi", ".")); + + tokens = BasicTokenizer.splitOnPunctuation("!hi"); + assertThat(tokens, contains("!", "hi")); + + tokens = BasicTokenizer.splitOnPunctuation("don't"); + assertThat(tokens, contains("don", "'", "t")); + + tokens = BasicTokenizer.splitOnPunctuation("!!hi"); + assertThat(tokens, contains("!", "!", "hi")); + + tokens = BasicTokenizer.splitOnPunctuation("[hi]"); + assertThat(tokens, contains("[", "hi", "]")); + + tokens = BasicTokenizer.splitOnPunctuation("hi."); + assertThat(tokens, contains("hi", ".")); + } + + public void testStripAccents() { + assertEquals("Hallo", BasicTokenizer.stripAccents("Hällo")); + } + + public void testTokenizeCjkChars() { + assertEquals(" \u535A \u63A8 ", BasicTokenizer.tokenizeCjkChars("\u535A\u63A8")); + + String noCjkChars = "hello"; + assertThat(BasicTokenizer.tokenizeCjkChars(noCjkChars), sameInstance(noCjkChars)); + } + + public void testTokenizeChinese() { + List tokens = new BasicTokenizer().tokenize("ah\u535A\u63A8zz"); + assertThat(tokens, contains("ah", "\u535A", "\u63A8", "zz")); + } + + public void testCleanText() { + assertEquals("change these chars to spaces", + BasicTokenizer.cleanText("change\tthese chars\rto\nspaces")); + assertEquals("filter control chars", + BasicTokenizer.cleanText("\u0000filter \uFFFDcontrol chars\u0005")); + } + + public void testWhiteSpaceTokenize() { + assertThat(BasicTokenizer.whiteSpaceTokenize("nochange"), arrayContaining("nochange")); + assertThat(BasicTokenizer.whiteSpaceTokenize(" some change "), arrayContaining("some", "", "change")); + } + + public void testIsWhitespace() { + assertTrue(BasicTokenizer.isWhiteSpace(' ')); + assertTrue(BasicTokenizer.isWhiteSpace('\t')); + assertTrue(BasicTokenizer.isWhiteSpace('\r')); + assertTrue(BasicTokenizer.isWhiteSpace('\n')); + assertTrue(BasicTokenizer.isWhiteSpace('\u00A0')); + + assertFalse(BasicTokenizer.isWhiteSpace('_')); + assertFalse(BasicTokenizer.isWhiteSpace('A')); + } + + public void testIsControl() { + assertTrue(BasicTokenizer.isControlChar('\u0005')); + assertTrue(BasicTokenizer.isControlChar('\u001C')); + + assertFalse(BasicTokenizer.isControlChar('A')); + assertFalse(BasicTokenizer.isControlChar(' ')); + assertFalse(BasicTokenizer.isControlChar('\t')); + assertFalse(BasicTokenizer.isControlChar('\r')); + } + + public void testIsPunctuation() { + assertTrue(BasicTokenizer.isCommonPunctuation('-')); + assertTrue(BasicTokenizer.isCommonPunctuation('$')); + assertTrue(BasicTokenizer.isCommonPunctuation('.')); + assertFalse(BasicTokenizer.isCommonPunctuation(' ')); + assertFalse(BasicTokenizer.isCommonPunctuation('A')); + assertFalse(BasicTokenizer.isCommonPunctuation('`')); + + assertTrue(BasicTokenizer.isPunctuationMark('-')); + assertTrue(BasicTokenizer.isPunctuationMark('$')); + assertTrue(BasicTokenizer.isPunctuationMark('`')); + assertTrue(BasicTokenizer.isPunctuationMark('.')); + assertFalse(BasicTokenizer.isPunctuationMark(' ')); + assertFalse(BasicTokenizer.isPunctuationMark('A')); + + 
assertFalse(BasicTokenizer.isCommonPunctuation('[')); + assertTrue(BasicTokenizer.isPunctuationMark('[')); + } + + public void testIsCjkChar() { + assertTrue(BasicTokenizer.isCjkChar(0x3400)); + assertFalse(BasicTokenizer.isCjkChar(0x4DC0)); + + assertTrue(BasicTokenizer.isCjkChar(0xF900)); + assertFalse(BasicTokenizer.isCjkChar(0xFB00)); + + assertTrue(BasicTokenizer.isCjkChar(0x20000)); + assertFalse(BasicTokenizer.isCjkChar(0x2A6E0)); + + assertTrue(BasicTokenizer.isCjkChar(0x20000)); + assertFalse(BasicTokenizer.isCjkChar(0x2A6E0)); + + assertTrue(BasicTokenizer.isCjkChar(0x2A700)); + assertFalse(BasicTokenizer.isCjkChar(0x2CEB0)); + + assertTrue(BasicTokenizer.isCjkChar(0x2F800)); + assertFalse(BasicTokenizer.isCjkChar(0x2FA20)); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java new file mode 100644 index 0000000000000..380d868f85acd --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java @@ -0,0 +1,93 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.nlp.tokenizers; + +import org.elasticsearch.test.ESTestCase; + +import java.util.Arrays; +import java.util.Collections; + +import static org.hamcrest.Matchers.contains; + +public class BertTokenizerTests extends ESTestCase { + + public void testTokenize() { + BertTokenizer tokenizer = BertTokenizer.builder(Arrays.asList("Elastic", "##search", "fun")).build(); + + BertTokenizer.TokenizationResult tokenization = tokenizer.tokenize("Elasticsearch fun", false); + assertThat(tokenization.getTokens(), contains("Elastic", "##search", "fun")); + assertArrayEquals(new int[] {0, 1, 2}, tokenization.getTokenIds()); + assertArrayEquals(new int[] {0, 0, 1}, tokenization.getTokenMap()); + } + + public void testTokenizeAppendSpecialTokens() { + BertTokenizer tokenizer = BertTokenizer.builder(Arrays.asList( + "elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN)).build(); + + BertTokenizer.TokenizationResult tokenization = tokenizer.tokenize("elasticsearch fun", true); + assertThat(tokenization.getTokens(), contains("[CLS]", "elastic", "##search", "fun", "[SEP]")); + assertArrayEquals(new int[] {3, 0, 1, 2, 4}, tokenization.getTokenIds()); + assertArrayEquals(new int[] {-1, 0, 0, 1, -1}, tokenization.getTokenMap()); + } + + public void testNeverSplitTokens() { + final String specialToken = "SP001"; + + BertTokenizer tokenizer = BertTokenizer.builder( + Arrays.asList("Elastic", "##search", "fun", specialToken, BertTokenizer.UNKNOWN_TOKEN)) + .setNeverSplit(Collections.singleton(specialToken)) + .build(); + + BertTokenizer.TokenizationResult tokenization = tokenizer.tokenize("Elasticsearch " + specialToken + " fun", false); + assertThat(tokenization.getTokens(), contains("Elastic", "##search", specialToken, "fun")); + assertArrayEquals(new int[] {0, 1, 3, 2}, tokenization.getTokenIds()); + assertArrayEquals(new int[] {0, 0, 1, 2}, tokenization.getTokenMap()); + } + + public void testDoLowerCase() { + { + BertTokenizer tokenizer = BertTokenizer.builder( + Arrays.asList("elastic", "##search", "fun", 
BertTokenizer.UNKNOWN_TOKEN)) + .setDoLowerCase(false) + .build(); + + BertTokenizer.TokenizationResult tokenization = tokenizer.tokenize("Elasticsearch fun", false); + assertThat(tokenization.getTokens(), contains(BertTokenizer.UNKNOWN_TOKEN, "fun")); + assertArrayEquals(new int[] {3, 2}, tokenization.getTokenIds()); + assertArrayEquals(new int[] {0, 1}, tokenization.getTokenMap()); + + tokenization = tokenizer.tokenize("elasticsearch fun", false); + assertThat(tokenization.getTokens(), contains("elastic", "##search", "fun")); + } + + { + BertTokenizer tokenizer = BertTokenizer.builder(Arrays.asList("elastic", "##search", "fun")) + .setDoLowerCase(true) + .build(); + + BertTokenizer.TokenizationResult tokenization = tokenizer.tokenize("Elasticsearch fun", false); + assertThat(tokenization.getTokens(), contains("elastic", "##search", "fun")); + } + } + + public void testPunctuation() { + BertTokenizer tokenizer = BertTokenizer.builder( + Arrays.asList("Elastic", "##search", "fun", ".", ",", + BertTokenizer.MASK_TOKEN, BertTokenizer.UNKNOWN_TOKEN)).build(); + + BertTokenizer.TokenizationResult tokenization = tokenizer.tokenize("Elasticsearch, fun.", false); + assertThat(tokenization.getTokens(), contains("Elastic", "##search", ",", "fun", ".")); + assertArrayEquals(new int[] {0, 1, 4, 2, 3}, tokenization.getTokenIds()); + assertArrayEquals(new int[] {0, 0, 1, 2, 3}, tokenization.getTokenMap()); + + tokenization = tokenizer.tokenize("Elasticsearch, fun [MASK].", false); + assertThat(tokenization.getTokens(), contains("Elastic", "##search", ",", "fun", "[MASK]", ".")); + assertArrayEquals(new int[] {0, 1, 4, 2, 5, 3}, tokenization.getTokenIds()); + assertArrayEquals(new int[] {0, 0, 1, 2, 3, 4}, tokenization.getTokenMap()); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/WordPieceTokenizerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/WordPieceTokenizerTests.java new file mode 100644 index 0000000000000..671ccd2e4fb27 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/WordPieceTokenizerTests.java @@ -0,0 +1,57 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.inference.nlp.tokenizers; + +import org.elasticsearch.test.ESTestCase; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.empty; + +public class WordPieceTokenizerTests extends ESTestCase { + + public static final String UNKNOWN_TOKEN = "[UNK]"; + + public void testTokenize() { + Map vocabMap = + createVocabMap(UNKNOWN_TOKEN, "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"); + WordPieceTokenizer tokenizer = new WordPieceTokenizer(vocabMap, UNKNOWN_TOKEN, 100); + + List tokenAndIds = tokenizer.tokenize(""); + assertThat(tokenAndIds, empty()); + + tokenAndIds = tokenizer.tokenize("unwanted running"); + List tokens = tokenAndIds.stream().map(WordPieceTokenizer.TokenAndId::getToken).collect(Collectors.toList()); + assertThat(tokens, contains("un", "##want", "##ed", "runn", "##ing")); + + tokenAndIds = tokenizer.tokenize("unwantedX running"); + tokens = tokenAndIds.stream().map(WordPieceTokenizer.TokenAndId::getToken).collect(Collectors.toList()); + assertThat(tokens, contains(UNKNOWN_TOKEN, "runn", "##ing")); + } + + public void testMaxCharLength() { + Map vocabMap = createVocabMap("Some", "words", "will", "become", "UNK"); + + WordPieceTokenizer tokenizer = new WordPieceTokenizer(vocabMap, "UNK", 4); + List tokenAndIds = tokenizer.tokenize("Some words will become UNK"); + List tokens = tokenAndIds.stream().map(WordPieceTokenizer.TokenAndId::getToken).collect(Collectors.toList()); + assertThat(tokens, contains("Some", "UNK", "will", "UNK", "UNK")); + } + + static Map createVocabMap(String ... words) { + Map vocabMap = new HashMap<>(); + for (int i=0; i { + + private final boolean isLenient = randomBoolean(); + + public void testParsingDocWithCompressedStringDefinition() throws IOException { + byte[] bytes = randomByteArrayOfLength(50); + String base64 = Base64.getEncoder().encodeToString(bytes); + + // The previous storage format was a base64 encoded string. + // The new format should parse and decode the string storing the raw bytes. 
+ String compressedStringDoc = "{\"doc_type\":\"trained_model_definition_doc\"," + + "\"model_id\":\"bntHUo\"," + + "\"doc_num\":6," + + "\"definition_length\":7," + + "\"total_definition_length\":13," + + "\"compression_version\":3," + + "\"definition\":\"" + base64 + "\"," + + "\"eos\":false}"; + + try (XContentParser parser = createParser(JsonXContent.jsonXContent, compressedStringDoc)) { + TrainedModelDefinitionDoc parsed = doParseInstance(parser); + assertArrayEquals(bytes, parsed.getBinaryData().array()); + } + } + + @Override + protected TrainedModelDefinitionDoc doParseInstance(XContentParser parser) throws IOException { + return TrainedModelDefinitionDoc.fromXContent(parser, isLenient).build(); + } + + @Override + protected boolean supportsUnknownFields() { + return isLenient; + } + + @Override + protected TrainedModelDefinitionDoc createTestInstance() { + int length = randomIntBetween(4, 10); + + return new TrainedModelDefinitionDoc.Builder() + .setModelId(randomAlphaOfLength(6)) + .setDefinitionLength(length) + .setTotalDefinitionLength(randomIntBetween(length, length *2)) + .setBinaryData(new BytesArray(randomByteArrayOfLength(length))) + .setDocNum(randomIntBetween(0, 10)) + .setCompressionVersion(randomIntBetween(1, 5)) + .setEos(randomBoolean()) + .build(); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java index e447dc29de321..f353a8221d459 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java @@ -9,6 +9,8 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.client.Client; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.ConstantScoreQueryBuilder; @@ -25,9 +27,11 @@ import java.util.Arrays; import java.util.Collections; import java.util.HashSet; +import java.util.List; import java.util.TreeSet; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; @@ -146,6 +150,24 @@ public void testGetModelThatExistsAsResourceButIsMissing() { assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, "missing_model"))); } + public void testChunkDefinitionWithSize() { + int totalLength = 100; + int size = 30; + + byte[] bytes = randomByteArrayOfLength(totalLength); + List chunks = TrainedModelProvider.chunkDefinitionWithSize(new BytesArray(bytes), size); + assertThat(chunks, hasSize(4)); + int start = 0; + int end = size; + for (BytesReference chunk : chunks) { + assertArrayEquals(Arrays.copyOfRange(bytes, start, end), + Arrays.copyOfRange(chunk.array(), chunk.arrayOffset(), chunk.arrayOffset() + chunk.length())); + + start += size; + end = Math.min(end + size, totalLength); + } + } + @Override protected NamedXContentRegistry xContentRegistry() { return new NamedXContentRegistry(new 
MlInferenceNamedXContentProvider().getNamedXContentParsers()); diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index 5773ccb74edf3..89cd0b41dc274 100644 --- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -157,6 +157,8 @@ public class Constants { "cluster:admin/xpack/ml/job/update", "cluster:admin/xpack/ml/job/validate", "cluster:admin/xpack/ml/job/validate/detector", + "cluster:admin/xpack/ml/trained_models/deployment/start", + "cluster:admin/xpack/ml/trained_models/deployment/stop", "cluster:admin/xpack/ml/upgrade_mode", "cluster:admin/xpack/monitoring/bulk", "cluster:admin/xpack/monitoring/migrate/alerts", @@ -293,6 +295,7 @@ public class Constants { "cluster:monitor/xpack/ml/job/results/overall_buckets/get", "cluster:monitor/xpack/ml/job/results/records/get", "cluster:monitor/xpack/ml/job/stats/get", + "cluster:monitor/xpack/ml/trained_models/deployment/infer", "cluster:monitor/xpack/repositories_metering/clear_metering_archive", "cluster:monitor/xpack/repositories_metering/get_metrics", "cluster:monitor/xpack/rollup/get", diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml new file mode 100644 index 0000000000000..40dc779dc5290 --- /dev/null +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml @@ -0,0 +1,35 @@ +--- +"Test start deployment": + + - do: + ml.put_trained_model: + model_id: distilbert-finetuned-sst + body: > + { + "description": "distilbert-base-uncased-finetuned-sst-2-english.pt", + "model_type": "pytorch", + "inference_config": { + "classification": { + "num_top_classes": 1 + } + }, + "input": { + "field_names": ["text_field"] + }, + "location": { + "index": { + "model_id": "distilbert-finetuned-sst", + "name": "big_model" + } + } + } + + - do: + ml.get_trained_models: + model_id: distilbert-finetuned-sst + - match: { trained_model_configs.0.location.index.model_id: distilbert-finetuned-sst } + +# - do: +# ml.start_deployment: +# model_id: distilbert-finetuned-sst + diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/trained_model_cat_apis.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/trained_model_cat_apis.yml index 0d7ecee9571d0..e0d6f0afd83cf 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/trained_model_cat_apis.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/trained_model_cat_apis.yml @@ -37,6 +37,7 @@ setup: body: > { "description": "empty model for tests", + "model_type": "tree_ensemble", "inference_config": {"regression":{}}, "input": {"field_names": ["field1", "field2"]}, "definition": { @@ -78,8 +79,8 @@ setup: model_id: a-regression-model-0 - match: $body: | - / #id heap_size operations create_time ingest.pipelines data_frame.id - ^ (a\-regression\-model\-0 \s+ \w+ \s+ \d+ \s+ .*? \s+ \d+ .*? 
\n)+ $/ + / #id\s+heap_size\s+operations\s+create_time\s+type\s+ingest.pipelines\s+data_frame.id + ^ (a-regression-model-0\s+\w+\s+\d+\s+.*?\s+[a-z_]+\s+\d+.*?\n)+ $/ - do: cat.ml_trained_models: @@ -87,8 +88,8 @@ setup: model_id: a-regression-model-0 - match: $body: | - /^ id \s+ heap_size \s+ operations \s+ create_time \s+ ingest\.pipelines \s+ data_frame\.id \n - (a\-regression\-model\-0 \s+ \w+ \s+ \d+ \s+ .*? \s+ \d+ \s+ .*? \n)+ $/ + /id\s+heap_size\s+operations\s+create_time\s+type\s+ingest\.pipelines\s+data_frame\.id\s*\n + (a-regression-model-0\s+\w+\s+\d+\s+.*?\s+tree_ensemble\s+\d+\s+\w+\n)+$/ - do: cat.ml_trained_models: @@ -97,8 +98,8 @@ setup: - match: $body: | /^ id \s+ license \s+ dfid \s+ ip \n - (a\-regression\-model\-0 \s+ \w+ \s+ __none__ \s+ \d+ \n)+ - (a\-regression\-model\-1 \s+ \w+ \s+ __none__ \s+ \d+ \n)+ + (a-regression-model-0 \s+ \w+ \s+ __none__ \s+ \d+ \n)+ + (a-regression-model-1 \s+ \w+ \s+ __none__ \s+ \d+ \n)+ (lang_ident_model_1 \s+ \w+ \s+ prepackaged \s+ \d+ \n)+ $/ - do: @@ -109,4 +110,4 @@ setup: - match: $body: | /^ id \s+ license \s+ dfid \s+ ip \n - (a\-regression\-model\-1 \s+ \w+ \s+ __none__ \s+ \d+ \n)+ $/ + (a-regression-model-1 \s+ \w+ \s+ __none__ \s+ \d+ \n)+ $/
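
Note on the restore framing: PyTorchStateStreamer above restores a model by writing a 4-byte size header followed by the decoded chunks with no delimiters between them. The following is a minimal standalone sketch of just that framing, assuming the same big-endian 4-byte length convention. The class and method names (RestoreFramingSketch, writeModelState) and the byte[] chunk list are made up for illustration; they are not part of this change.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.List;

// Minimal sketch of the restore stream framing: a 4-byte length header
// followed by the decoded chunks written back-to-back with no delimiters.
// Illustrative only; in the PR the work is done by PyTorchStateStreamer
// and ChunkedTrainedModelRestorer.
public final class RestoreFramingSketch {

    static void writeModelState(List<byte[]> chunks, OutputStream out) throws IOException {
        long totalSize = chunks.stream().mapToLong(c -> c.length).sum();
        if (totalSize <= 0 || totalSize > Integer.MAX_VALUE) {
            throw new IllegalStateException("model size must fit in an unsigned 32-bit int: " + totalSize);
        }
        // ByteBuffer writes the int big-endian by default, giving the 4-byte size header.
        ByteBuffer lengthHeader = ByteBuffer.allocate(4).putInt((int) totalSize);
        out.write(lengthHeader.array());
        for (byte[] chunk : chunks) {
            out.write(chunk); // chunks are concatenated; the reader sees one contiguous file
        }
    }

    public static void main(String[] args) throws IOException {
        ByteArrayOutputStream sink = new ByteArrayOutputStream();
        writeModelState(List.of(new byte[] {1, 2, 3}, new byte[] {4, 5}), sink);
        System.out.println(sink.toByteArray().length); // prints 9: 4-byte header plus 5 payload bytes
    }
}
```

The reader on the other side can then allocate a buffer of the advertised size and consume exactly that many bytes from the stream, which is why the streamer rejects null, non-positive, and larger-than-int sizes up front.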
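
Note on the numeric helpers: NlpHelpersTests pins down the expected behaviour of the softmax and top-k helpers with concrete values. The self-contained sketch below reproduces those expectations. It is not the NlpHelpers implementation, and the max-subtraction step in the softmax is an assumption about how such a helper is commonly written rather than something stated in the diff.

```java
import java.util.Arrays;
import java.util.Comparator;

// Standalone sketch of the two numeric helpers the tests exercise:
// softmax over a score vector and selection of the top-k scores by index.
public final class NlpMathSketch {

    // Numerically stable softmax: subtract the max before exponentiating.
    static double[] softmax(double[] scores) {
        double max = Arrays.stream(scores).max().orElse(0.0);
        double[] exp = Arrays.stream(scores).map(s -> Math.exp(s - max)).toArray();
        double sum = Arrays.stream(exp).sum();
        return Arrays.stream(exp).map(e -> e / sum).toArray();
    }

    // Indices of the k largest values, highest score first.
    // If k exceeds the array length every index is returned, as in testTopK_KGreaterThanArrayLength.
    static int[] topK(int k, double[] values) {
        Integer[] indices = new Integer[values.length];
        for (int i = 0; i < values.length; i++) {
            indices[i] = i;
        }
        // Stable sort, so ties keep their original order (index 5 before 6 for the two 4.2 scores).
        Arrays.sort(indices, Comparator.comparingDouble((Integer i) -> values[i]).reversed());
        return Arrays.stream(indices).limit(Math.min(k, values.length)).mapToInt(Integer::intValue).toArray();
    }

    public static void main(String[] args) {
        // Matches testConvertToProbabilitiesBySoftMax_OneDimension:
        // softmax([0.1, 0.2, 3]) is approximately [0.0493, 0.0545, 0.8962]
        System.out.println(Arrays.toString(softmax(new double[] {0.1, 0.2, 3})));

        // Matches testTopK_SimpleCase: the top 3 of {1, 0, 2, 8, 9, 4.2, 4.2, 3} are indices 4, 3, 5
        System.out.println(Arrays.toString(topK(3, new double[] {1.0, 0.0, 2.0, 8.0, 9.0, 4.2, 4.2, 3.0})));
    }
}
```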
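
Note on WordPiece: WordPieceTokenizerTests expects standard WordPiece behaviour, namely greedy longest-match-first against the vocabulary, a "##" prefix on continuation pieces, a whole-word unknown-token fallback, and the unknown token for words longer than the character limit. The sketch below is the textbook version of that algorithm, not the class under test; the vocabulary and tokens are taken from the test but wired up independently here.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

// Textbook WordPiece: greedily take the longest vocabulary match, prefixing
// continuation pieces with "##"; fall back to the unknown token for the whole
// word if no match is found or the word exceeds the length limit.
public final class WordPieceSketch {

    static List<String> tokenizeWord(String word, Set<String> vocab, String unknownToken, int maxChars) {
        if (word.length() > maxChars) {
            return List.of(unknownToken);
        }
        List<String> pieces = new ArrayList<>();
        int start = 0;
        while (start < word.length()) {
            int end = word.length();
            String piece = null;
            while (start < end) {
                String candidate = (start > 0 ? "##" : "") + word.substring(start, end);
                if (vocab.contains(candidate)) {
                    piece = candidate;
                    break;
                }
                end--; // shrink the candidate and try again
            }
            if (piece == null) {
                return List.of(unknownToken); // the whole word becomes the unknown token
            }
            pieces.add(piece);
            start = end;
        }
        return pieces;
    }

    public static void main(String[] args) {
        Set<String> vocab = Set.of("un", "##want", "##ed", "runn", "##ing");
        System.out.println(tokenizeWord("unwanted", vocab, "[UNK]", 100));  // [un, ##want, ##ed]
        System.out.println(tokenizeWord("unwantedX", vocab, "[UNK]", 100)); // [[UNK]]
    }
}
```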