diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/transform/transforms/util/TimeUtil.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/common/TimeUtil.java similarity index 87% rename from client/rest-high-level/src/main/java/org/elasticsearch/client/transform/transforms/util/TimeUtil.java rename to client/rest-high-level/src/main/java/org/elasticsearch/client/common/TimeUtil.java index e2d72f91e5550..c3f26be56075e 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/transform/transforms/util/TimeUtil.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/common/TimeUtil.java @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -package org.elasticsearch.client.transform.transforms.util; +package org.elasticsearch.client.common; import org.elasticsearch.common.time.DateFormatters; import org.elasticsearch.common.xcontent.XContentParser; @@ -46,6 +46,14 @@ public static Date parseTimeField(XContentParser parser, String fieldName) throw "unexpected token [" + parser.currentToken() + "] for [" + fieldName + "]"); } + /** + * Parse out an Instant object given the current parser and field name. + * + * @param parser current XContentParser + * @param fieldName the field's preferred name (utilized in exception) + * @return parsed Instant object + * @throws IOException from XContentParser + */ public static Instant parseTimeFieldToInstant(XContentParser parser, String fieldName) throws IOException { if (parser.currentToken() == XContentParser.Token.VALUE_NUMBER) { return Instant.ofEpochMilli(parser.longValue()); diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/calendars/ScheduledEvent.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/calendars/ScheduledEvent.java index decaff728c6c7..663c329f6186a 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/calendars/ScheduledEvent.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/calendars/ScheduledEvent.java @@ -18,7 +18,7 @@ */ package org.elasticsearch.client.ml.calendars; -import org.elasticsearch.client.ml.job.util.TimeUtil; +import org.elasticsearch.client.common.TimeUtil; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.xcontent.ConstructingObjectParser; diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsConfig.java index 1c333c0bad02b..7830bccb45069 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsConfig.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsConfig.java @@ -20,7 +20,7 @@ package org.elasticsearch.client.ml.dataframe; import org.elasticsearch.Version; -import org.elasticsearch.client.transform.transforms.util.TimeUtil; +import org.elasticsearch.client.common.TimeUtil; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/NamedXContentObject.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/NamedXContentObject.java new file mode 100644 index 0000000000000..969add5254766 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/NamedXContentObject.java @@ -0,0 +1,34 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference; + +import org.elasticsearch.common.xcontent.ToXContentObject; + +/** + * Simple interface for XContent Objects that are named. + * + * This affords more general handling when serializing and de-serializing this type of XContent when it is used in a NamedObjects + * parser. + */ +public interface NamedXContentObject extends ToXContentObject { + /** + * @return The name of the XContentObject that is to be serialized + */ + String getName(); +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/NamedXContentObjectHelper.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/NamedXContentObjectHelper.java new file mode 100644 index 0000000000000..1795f5da49511 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/NamedXContentObjectHelper.java @@ -0,0 +1,57 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference; + +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; + +public final class NamedXContentObjectHelper { + + private NamedXContentObjectHelper() {} + + public static XContentBuilder writeNamedObjects(XContentBuilder builder, + ToXContent.Params params, + boolean useExplicitOrder, + String namedObjectsName, + List namedObjects) throws IOException { + if (useExplicitOrder) { + builder.startArray(namedObjectsName); + } else { + builder.startObject(namedObjectsName); + } + for (NamedXContentObject object : namedObjects) { + if (useExplicitOrder) { + builder.startObject(); + } + builder.field(object.getName(), object, params); + if (useExplicitOrder) { + builder.endObject(); + } + } + if (useExplicitOrder) { + builder.endArray(); + } else { + builder.endObject(); + } + return builder; + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java new file mode 100644 index 0000000000000..616aaea21d12b --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java @@ -0,0 +1,299 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference; + +import org.elasticsearch.Version; +import org.elasticsearch.client.common.TimeUtil; +import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.time.Instant; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +public class TrainedModelConfig implements ToXContentObject { + + public static final String NAME = "trained_model_doc"; + + public static final ParseField MODEL_ID = new ParseField("model_id"); + public static final ParseField CREATED_BY = new ParseField("created_by"); + public static final ParseField VERSION = new ParseField("version"); + public static final ParseField DESCRIPTION = new ParseField("description"); + public static final ParseField CREATED_TIME = new ParseField("created_time"); + public static final ParseField MODEL_VERSION = new ParseField("model_version"); + public static final ParseField DEFINITION = new ParseField("definition"); + public static final ParseField MODEL_TYPE = new ParseField("model_type"); + public static final ParseField METADATA = new ParseField("metadata"); + + public static final ObjectParser PARSER = new ObjectParser<>(NAME, + true, + TrainedModelConfig.Builder::new); + static { + PARSER.declareString(TrainedModelConfig.Builder::setModelId, MODEL_ID); + PARSER.declareString(TrainedModelConfig.Builder::setCreatedBy, CREATED_BY); + PARSER.declareString(TrainedModelConfig.Builder::setVersion, VERSION); + PARSER.declareString(TrainedModelConfig.Builder::setDescription, DESCRIPTION); + PARSER.declareField(TrainedModelConfig.Builder::setCreatedTime, + (p, c) -> TimeUtil.parseTimeFieldToInstant(p, CREATED_TIME.getPreferredName()), + CREATED_TIME, + ObjectParser.ValueType.VALUE); + PARSER.declareLong(TrainedModelConfig.Builder::setModelVersion, MODEL_VERSION); + PARSER.declareString(TrainedModelConfig.Builder::setModelType, MODEL_TYPE); + PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA); + PARSER.declareNamedObjects(TrainedModelConfig.Builder::setDefinition, + (p, c, n) -> p.namedObject(TrainedModel.class, n, null), + (modelDocBuilder) -> { /* Noop does not matter client side */ }, + DEFINITION); + } + + public static TrainedModelConfig.Builder fromXContent(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + + private final String modelId; + private final String createdBy; + private final Version version; + private final String description; + private final Instant createdTime; + private final Long modelVersion; + private final String modelType; + private final Map metadata; + private final TrainedModel definition; + + TrainedModelConfig(String modelId, + String createdBy, + Version version, + String description, + Instant createdTime, + Long modelVersion, + String modelType, + TrainedModel definition, + Map metadata) { + this.modelId = modelId; + this.createdBy = createdBy; + this.version = version; + this.createdTime = Instant.ofEpochMilli(createdTime.toEpochMilli()); + this.modelType = modelType; + this.definition = definition; + this.description = description; + this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata); + this.modelVersion = modelVersion; + } + + public String getModelId() { + return modelId; + } + + public String getCreatedBy() { + return createdBy; + } + + public Version getVersion() { + return version; + } + + public String getDescription() { + return description; + } + + public Instant getCreatedTime() { + return createdTime; + } + + public Long getModelVersion() { + return modelVersion; + } + + public String getModelType() { + return modelType; + } + + public Map getMetadata() { + return metadata; + } + + public TrainedModel getDefinition() { + return definition; + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (modelId != null) { + builder.field(MODEL_ID.getPreferredName(), modelId); + } + if (createdBy != null) { + builder.field(CREATED_BY.getPreferredName(), createdBy); + } + if (version != null) { + builder.field(VERSION.getPreferredName(), version.toString()); + } + if (description != null) { + builder.field(DESCRIPTION.getPreferredName(), description); + } + if (createdTime != null) { + builder.timeField(CREATED_TIME.getPreferredName(), CREATED_TIME.getPreferredName() + "_string", createdTime.toEpochMilli()); + } + if (modelVersion != null) { + builder.field(MODEL_VERSION.getPreferredName(), modelVersion); + } + if (modelType != null) { + builder.field(MODEL_TYPE.getPreferredName(), modelType); + } + if (definition != null) { + NamedXContentObjectHelper.writeNamedObjects(builder, + params, + false, + DEFINITION.getPreferredName(), + Collections.singletonList(definition)); + } + if (metadata != null) { + builder.field(METADATA.getPreferredName(), metadata); + } + builder.endObject(); + return builder; + } + + @Override + public String toString() { + return Strings.toString(this); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TrainedModelConfig that = (TrainedModelConfig) o; + return Objects.equals(modelId, that.modelId) && + Objects.equals(createdBy, that.createdBy) && + Objects.equals(version, that.version) && + Objects.equals(description, that.description) && + Objects.equals(createdTime, that.createdTime) && + Objects.equals(modelVersion, that.modelVersion) && + Objects.equals(modelType, that.modelType) && + Objects.equals(definition, that.definition) && + Objects.equals(metadata, that.metadata); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, + createdBy, + version, + createdTime, + modelType, + definition, + description, + metadata, + modelVersion); + } + + + public static class Builder { + + private String modelId; + private String createdBy; + private Version version; + private String description; + private Instant createdTime; + private Long modelVersion; + private String modelType; + private Map metadata; + private TrainedModel definition; + + public Builder setModelId(String modelId) { + this.modelId = modelId; + return this; + } + + private Builder setCreatedBy(String createdBy) { + this.createdBy = createdBy; + return this; + } + + private Builder setVersion(Version version) { + this.version = version; + return this; + } + + private Builder setVersion(String version) { + return this.setVersion(Version.fromString(version)); + } + + public Builder setDescription(String description) { + this.description = description; + return this; + } + + private Builder setCreatedTime(Instant createdTime) { + this.createdTime = createdTime; + return this; + } + + public Builder setModelVersion(Long modelVersion) { + this.modelVersion = modelVersion; + return this; + } + + public Builder setModelType(String modelType) { + this.modelType = modelType; + return this; + } + + public Builder setMetadata(Map metadata) { + this.metadata = metadata; + return this; + } + + public Builder setDefinition(TrainedModel definition) { + this.definition = definition; + return this; + } + + private Builder setDefinition(List definition) { + assert definition.size() == 1; + return setDefinition(definition.get(0)); + } + + public TrainedModelConfig build() { + return new TrainedModelConfig( + modelId, + createdBy, + version, + description, + createdTime, + modelVersion, + modelType, + definition, + metadata); + } + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/TrainedModel.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/TrainedModel.java index fb1f5c3b4ab92..43ff877089b51 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/TrainedModel.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/TrainedModel.java @@ -18,11 +18,11 @@ */ package org.elasticsearch.client.ml.inference.trainedmodel; -import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.client.ml.inference.NamedXContentObject; import java.util.List; -public interface TrainedModel extends ToXContentObject { +public interface TrainedModel extends NamedXContentObject { /** * @return List of featureNames expected by the model. In the order that they are expected diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/config/Job.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/config/Job.java index 119f4e86ffd7f..c5f9f00f895ff 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/config/Job.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/config/Job.java @@ -18,7 +18,7 @@ */ package org.elasticsearch.client.ml.job.config; -import org.elasticsearch.client.ml.job.util.TimeUtil; +import org.elasticsearch.client.common.TimeUtil; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; import org.elasticsearch.common.unit.TimeValue; diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/process/DataCounts.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/process/DataCounts.java index c0e16622ba593..a19902eb8fbfa 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/process/DataCounts.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/process/DataCounts.java @@ -18,8 +18,8 @@ */ package org.elasticsearch.client.ml.job.process; +import org.elasticsearch.client.common.TimeUtil; import org.elasticsearch.client.ml.job.config.Job; -import org.elasticsearch.client.ml.job.util.TimeUtil; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser.ValueType; diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/process/ModelSizeStats.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/process/ModelSizeStats.java index 6ea3cede0e3f1..5f7a1e2988560 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/process/ModelSizeStats.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/process/ModelSizeStats.java @@ -18,9 +18,9 @@ */ package org.elasticsearch.client.ml.job.process; +import org.elasticsearch.client.common.TimeUtil; import org.elasticsearch.client.ml.job.config.Job; import org.elasticsearch.client.ml.job.results.Result; -import org.elasticsearch.client.ml.job.util.TimeUtil; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser.ValueType; diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/process/ModelSnapshot.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/process/ModelSnapshot.java index 6a92eaf019021..b282baa4f9f5b 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/process/ModelSnapshot.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/process/ModelSnapshot.java @@ -19,8 +19,8 @@ package org.elasticsearch.client.ml.job.process; import org.elasticsearch.Version; +import org.elasticsearch.client.common.TimeUtil; import org.elasticsearch.client.ml.job.config.Job; -import org.elasticsearch.client.ml.job.util.TimeUtil; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ObjectParser.ValueType; diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/results/AnomalyRecord.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/results/AnomalyRecord.java index 3f743b3642256..3c52aad74d0a8 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/results/AnomalyRecord.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/results/AnomalyRecord.java @@ -18,8 +18,8 @@ */ package org.elasticsearch.client.ml.job.results; +import org.elasticsearch.client.common.TimeUtil; import org.elasticsearch.client.ml.job.config.Job; -import org.elasticsearch.client.ml.job.util.TimeUtil; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser.ValueType; diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/results/Bucket.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/results/Bucket.java index 9f549f16bbc0b..01c2a6cd6f5bf 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/results/Bucket.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/results/Bucket.java @@ -18,8 +18,8 @@ */ package org.elasticsearch.client.ml.job.results; +import org.elasticsearch.client.common.TimeUtil; import org.elasticsearch.client.ml.job.config.Job; -import org.elasticsearch.client.ml.job.util.TimeUtil; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser.ValueType; diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/results/BucketInfluencer.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/results/BucketInfluencer.java index ade5a5a2f50f2..63e7e2022f735 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/results/BucketInfluencer.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/results/BucketInfluencer.java @@ -18,8 +18,8 @@ */ package org.elasticsearch.client.ml.job.results; +import org.elasticsearch.client.common.TimeUtil; import org.elasticsearch.client.ml.job.config.Job; -import org.elasticsearch.client.ml.job.util.TimeUtil; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser.ValueType; diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/results/Influencer.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/results/Influencer.java index 4892b7f93468d..906c5b3ef9c99 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/results/Influencer.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/results/Influencer.java @@ -18,8 +18,8 @@ */ package org.elasticsearch.client.ml.job.results; +import org.elasticsearch.client.common.TimeUtil; import org.elasticsearch.client.ml.job.config.Job; -import org.elasticsearch.client.ml.job.util.TimeUtil; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser.ValueType; diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/results/OverallBucket.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/results/OverallBucket.java index 722c2361b6762..67ebc55c87f24 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/results/OverallBucket.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/results/OverallBucket.java @@ -18,8 +18,8 @@ */ package org.elasticsearch.client.ml.job.results; +import org.elasticsearch.client.common.TimeUtil; import org.elasticsearch.client.ml.job.config.Job; -import org.elasticsearch.client.ml.job.util.TimeUtil; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser; diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/util/TimeUtil.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/util/TimeUtil.java deleted file mode 100644 index 254979a360d0f..0000000000000 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/util/TimeUtil.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.elasticsearch.client.ml.job.util; - -import org.elasticsearch.common.time.DateFormatters; -import org.elasticsearch.common.xcontent.XContentParser; - -import java.io.IOException; -import java.time.format.DateTimeFormatter; -import java.util.Date; - -public final class TimeUtil { - - /** - * Parse out a Date object given the current parser and field name. - * - * @param parser current XContentParser - * @param fieldName the field's preferred name (utilized in exception) - * @return parsed Date object - * @throws IOException from XContentParser - */ - public static Date parseTimeField(XContentParser parser, String fieldName) throws IOException { - if (parser.currentToken() == XContentParser.Token.VALUE_NUMBER) { - return new Date(parser.longValue()); - } else if (parser.currentToken() == XContentParser.Token.VALUE_STRING) { - return new Date(DateFormatters.from(DateTimeFormatter.ISO_INSTANT.parse(parser.text())).toInstant().toEpochMilli()); - } - throw new IllegalArgumentException( - "unexpected token [" + parser.currentToken() + "] for [" + fieldName + "]"); - } - -} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/transform/transforms/TransformCheckpointingInfo.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/transform/transforms/TransformCheckpointingInfo.java index d5ba364384416..5edb42779a2c9 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/transform/transforms/TransformCheckpointingInfo.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/transform/transforms/TransformCheckpointingInfo.java @@ -19,7 +19,7 @@ package org.elasticsearch.client.transform.transforms; -import org.elasticsearch.client.transform.transforms.util.TimeUtil; +import org.elasticsearch.client.common.TimeUtil; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.xcontent.ConstructingObjectParser; diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/transform/transforms/TransformConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/transform/transforms/TransformConfig.java index ff740cfcf242d..23d65dd5ca7a6 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/transform/transforms/TransformConfig.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/transform/transforms/TransformConfig.java @@ -20,8 +20,8 @@ package org.elasticsearch.client.transform.transforms; import org.elasticsearch.Version; +import org.elasticsearch.client.common.TimeUtil; import org.elasticsearch.client.transform.transforms.pivot.PivotConfig; -import org.elasticsearch.client.transform.transforms.util.TimeUtil; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/NamedXContentObjectHelperTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/NamedXContentObjectHelperTests.java new file mode 100644 index 0000000000000..9eca65e529928 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/NamedXContentObjectHelperTests.java @@ -0,0 +1,112 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; + +public class NamedXContentObjectHelperTests extends ESTestCase { + + static class NamedTestObject implements NamedXContentObject { + + private String fieldValue; + public static final ObjectParser PARSER = + new ObjectParser<>("my_named_object", true, NamedTestObject::new); + static { + PARSER.declareString(NamedTestObject::setFieldValue, new ParseField("my_field")); + } + + NamedTestObject() { + + } + + NamedTestObject(String value) { + this.fieldValue = value; + } + + @Override + public String getName() { + return "my_named_object"; + } + + public void setFieldValue(String value) { + this.fieldValue = value; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (fieldValue != null) { + builder.field("my_field", fieldValue); + } + builder.endObject(); + return builder; + } + } + + public void testSerializeInOrder() throws IOException { + String expected = + "{\"my_objects\":[{\"my_named_object\":{\"my_field\":\"value1\"}},{\"my_named_object\":{\"my_field\":\"value2\"}}]}"; + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + builder.startObject(); + List objects = Arrays.asList(new NamedTestObject("value1"), new NamedTestObject("value2")); + NamedXContentObjectHelper.writeNamedObjects(builder, ToXContent.EMPTY_PARAMS, true, "my_objects", objects); + builder.endObject(); + assertThat(BytesReference.bytes(builder).utf8ToString(), equalTo(expected)); + } + } + + public void testSerialize() throws IOException { + String expected = "{\"my_objects\":{\"my_named_object\":{\"my_field\":\"value1\"},\"my_named_object\":{\"my_field\":\"value2\"}}}"; + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + builder.startObject(); + List objects = Arrays.asList(new NamedTestObject("value1"), new NamedTestObject("value2")); + NamedXContentObjectHelper.writeNamedObjects(builder, ToXContent.EMPTY_PARAMS, false, "my_objects", objects); + builder.endObject(); + assertThat(BytesReference.bytes(builder).utf8ToString(), equalTo(expected)); + } + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.addAll(Collections.singletonList(new NamedXContentRegistry.Entry(NamedXContentObject.class, + new ParseField("my_named_object"), + (p, c) -> NamedTestObject.PARSER.apply(p, null)))); + namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + return new NamedXContentRegistry(namedXContent); + } + +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java new file mode 100644 index 0000000000000..1f484a991aa9a --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java @@ -0,0 +1,76 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference; + +import org.elasticsearch.Version; +import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeTests; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.function.Predicate; + + +public class TrainedModelConfigTests extends AbstractXContentTestCase { + + @Override + protected TrainedModelConfig doParseInstance(XContentParser parser) throws IOException { + return TrainedModelConfig.fromXContent(parser).build(); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + return field -> !field.isEmpty(); + } + + @Override + protected TrainedModelConfig createTestInstance() { + return new TrainedModelConfig( + randomAlphaOfLength(10), + randomAlphaOfLength(10), + Version.CURRENT, + randomBoolean() ? null : randomAlphaOfLength(100), + Instant.ofEpochMilli(randomNonNegativeLong()), + randomBoolean() ? null : randomNonNegativeLong(), + randomAlphaOfLength(10), + randomFrom(TreeTests.createRandom()), + randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10))); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + return new NamedXContentRegistry(namedXContent); + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java new file mode 100644 index 0000000000000..6b8694e7e3b0c --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -0,0 +1,372 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference; + +import org.elasticsearch.Version; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.common.time.TimeUtils; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.MlStrings; +import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; + +import java.io.IOException; +import java.time.Instant; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +public class TrainedModelConfig implements ToXContentObject, Writeable { + + public static final String NAME = "trained_model_doc"; + + public static final ParseField MODEL_ID = new ParseField("model_id"); + public static final ParseField CREATED_BY = new ParseField("created_by"); + public static final ParseField VERSION = new ParseField("version"); + public static final ParseField DESCRIPTION = new ParseField("description"); + public static final ParseField CREATED_TIME = new ParseField("created_time"); + public static final ParseField MODEL_VERSION = new ParseField("model_version"); + public static final ParseField DEFINITION = new ParseField("definition"); + public static final ParseField MODEL_TYPE = new ParseField("model_type"); + public static final ParseField METADATA = new ParseField("metadata"); + + // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly + public static final ObjectParser LENIENT_PARSER = createParser(true); + public static final ObjectParser STRICT_PARSER = createParser(false); + + private static ObjectParser createParser(boolean ignoreUnknownFields) { + ObjectParser parser = new ObjectParser<>(NAME, + ignoreUnknownFields, + TrainedModelConfig.Builder::new); + parser.declareString(TrainedModelConfig.Builder::setModelId, MODEL_ID); + parser.declareString(TrainedModelConfig.Builder::setCreatedBy, CREATED_BY); + parser.declareString(TrainedModelConfig.Builder::setVersion, VERSION); + parser.declareString(TrainedModelConfig.Builder::setDescription, DESCRIPTION); + parser.declareField(TrainedModelConfig.Builder::setCreatedTime, + (p, c) -> TimeUtils.parseTimeFieldToInstant(p, CREATED_TIME.getPreferredName()), + CREATED_TIME, + ObjectParser.ValueType.VALUE); + parser.declareLong(TrainedModelConfig.Builder::setModelVersion, MODEL_VERSION); + parser.declareString(TrainedModelConfig.Builder::setModelType, MODEL_TYPE); + parser.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA); + parser.declareNamedObjects(TrainedModelConfig.Builder::setDefinition, + (p, c, n) -> ignoreUnknownFields ? + p.namedObject(LenientlyParsedTrainedModel.class, n, null) : + p.namedObject(StrictlyParsedTrainedModel.class, n, null), + (modelDocBuilder) -> { /* Noop does not matter as we will throw if more than one is defined */ }, + DEFINITION); + return parser; + } + + public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boolean lenient) throws IOException { + return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null); + } + + public static String documentId(String modelId, long modelVersion) { + return NAME + "-" + modelId + "-" + modelVersion; + } + + + private final String modelId; + private final String createdBy; + private final Version version; + private final String description; + private final Instant createdTime; + private final long modelVersion; + private final String modelType; + private final Map metadata; + // TODO how to reference and store large models that will not be executed in Java??? + // Potentially allow this to be null and have an {index: indexName, doc: model_doc_id} or something + // TODO Should this be lazily parsed when loading via the index??? + private final TrainedModel definition; + TrainedModelConfig(String modelId, + String createdBy, + Version version, + String description, + Instant createdTime, + Long modelVersion, + String modelType, + TrainedModel definition, + Map metadata) { + this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID); + this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY); + this.version = ExceptionsHelper.requireNonNull(version, VERSION); + this.createdTime = Instant.ofEpochMilli(ExceptionsHelper.requireNonNull(createdTime, CREATED_TIME).toEpochMilli()); + this.modelType = ExceptionsHelper.requireNonNull(modelType, MODEL_TYPE); + this.definition = definition; + this.description = description; + this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata); + this.modelVersion = modelVersion == null ? 0 : modelVersion; + } + + public TrainedModelConfig(StreamInput in) throws IOException { + modelId = in.readString(); + createdBy = in.readString(); + version = Version.readVersion(in); + description = in.readOptionalString(); + createdTime = in.readInstant(); + modelVersion = in.readVLong(); + modelType = in.readString(); + definition = in.readOptionalNamedWriteable(TrainedModel.class); + metadata = in.readMap(); + } + + public String getModelId() { + return modelId; + } + + public String getCreatedBy() { + return createdBy; + } + + public Version getVersion() { + return version; + } + + public String getDescription() { + return description; + } + + public Instant getCreatedTime() { + return createdTime; + } + + public long getModelVersion() { + return modelVersion; + } + + public String getModelType() { + return modelType; + } + + public Map getMetadata() { + return metadata; + } + + @Nullable + public TrainedModel getDefinition() { + return definition; + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + out.writeString(createdBy); + Version.writeVersion(version, out); + out.writeOptionalString(description); + out.writeInstant(createdTime); + out.writeVLong(modelVersion); + out.writeString(modelType); + out.writeOptionalNamedWriteable(definition); + out.writeMap(metadata); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(MODEL_ID.getPreferredName(), modelId); + builder.field(CREATED_BY.getPreferredName(), createdBy); + builder.field(VERSION.getPreferredName(), version.toString()); + if (description != null) { + builder.field(DESCRIPTION.getPreferredName(), description); + } + builder.timeField(CREATED_TIME.getPreferredName(), CREATED_TIME.getPreferredName() + "_string", createdTime.toEpochMilli()); + builder.field(MODEL_VERSION.getPreferredName(), modelVersion); + builder.field(MODEL_TYPE.getPreferredName(), modelType); + if (definition != null) { + NamedXContentObjectHelper.writeNamedObjects(builder, + params, + false, + DEFINITION.getPreferredName(), + Collections.singletonList(definition)); + } + if (metadata != null) { + builder.field(METADATA.getPreferredName(), metadata); + } + builder.endObject(); + return builder; + } + + @Override + public String toString() { + return Strings.toString(this); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TrainedModelConfig that = (TrainedModelConfig) o; + return Objects.equals(modelId, that.modelId) && + Objects.equals(createdBy, that.createdBy) && + Objects.equals(version, that.version) && + Objects.equals(description, that.description) && + Objects.equals(createdTime, that.createdTime) && + Objects.equals(modelVersion, that.modelVersion) && + Objects.equals(modelType, that.modelType) && + Objects.equals(definition, that.definition) && + Objects.equals(metadata, that.metadata); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, + createdBy, + version, + createdTime, + modelType, + definition, + description, + metadata, + modelVersion); + } + + + public static class Builder { + + private String modelId; + private String createdBy; + private Version version; + private String description; + private Instant createdTime; + private Long modelVersion; + private String modelType; + private Map metadata; + private TrainedModel definition; + + public Builder setModelId(String modelId) { + this.modelId = modelId; + return this; + } + + public Builder setCreatedBy(String createdBy) { + this.createdBy = createdBy; + return this; + } + + public Builder setVersion(Version version) { + this.version = version; + return this; + } + + private Builder setVersion(String version) { + return this.setVersion(Version.fromString(version)); + } + + public Builder setDescription(String description) { + this.description = description; + return this; + } + + public Builder setCreatedTime(Instant createdTime) { + this.createdTime = createdTime; + return this; + } + + public Builder setModelVersion(Long modelVersion) { + this.modelVersion = modelVersion; + return this; + } + + public Builder setModelType(String modelType) { + this.modelType = modelType; + return this; + } + + public Builder setMetadata(Map metadata) { + this.metadata = metadata; + return this; + } + + public Builder setDefinition(TrainedModel definition) { + this.definition = definition; + return this; + } + + private Builder setDefinition(List definition) { + if (definition.size() != 1) { + throw ExceptionsHelper.badRequestException("[{}] must have exactly one trained model defined.", + DEFINITION.getPreferredName()); + } + return setDefinition(definition.get(0)); + } + + // TODO move to REST level instead of here in the builder + public void validate() { + // We require a definition to be available until we support other means of supplying the definition + ExceptionsHelper.requireNonNull(definition, DEFINITION); + ExceptionsHelper.requireNonNull(modelId, MODEL_ID); + + if (MlStrings.isValidId(modelId) == false) { + throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INVALID_ID, MODEL_ID.getPreferredName(), modelId)); + } + + if (MlStrings.hasValidLengthForId(modelId) == false) { + throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.ID_TOO_LONG, + MODEL_ID.getPreferredName(), + modelId, + MlStrings.ID_LENGTH_LIMIT)); + } + + if (version != null) { + throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation", VERSION.getPreferredName()); + } + + if (createdBy != null) { + throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation", + CREATED_BY.getPreferredName()); + } + + if (createdTime != null) { + throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation", + CREATED_TIME.getPreferredName()); + } + } + + public TrainedModelConfig build() { + return new TrainedModelConfig( + modelId, + createdBy, + version, + description, + createdTime, + modelVersion, + modelType, + definition, + metadata); + } + + public TrainedModelConfig build(Version version) { + return new TrainedModelConfig( + modelId, + createdBy, + version, + description, + Instant.now(), + modelVersion, + modelType, + definition, + metadata); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java new file mode 100644 index 0000000000000..e5820f4068e60 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.persistence; + +/** + * Class containing the index constants so that the index version, name, and prefix are available to a wider audience. + */ +public final class InferenceIndexConstants { + + public static final String INDEX_VERSION = "000001"; + public static final String INDEX_NAME_PREFIX = ".ml-inference-"; + public static final String INDEX_PATTERN = INDEX_NAME_PREFIX + "*"; + public static final String LATEST_INDEX_NAME = INDEX_NAME_PREFIX + INDEX_VERSION; + + private InferenceIndexConstants() {} + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index 4d33260053af5..8054f3c802b43 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -75,9 +75,16 @@ public final class Messages { "Inconsistent {0}; ''{1}'' specified in the body differs from ''{2}'' specified as a URL argument"; public static final String INVALID_ID = "Invalid {0}; ''{1}'' can contain lowercase alphanumeric (a-z and 0-9), hyphens or " + "underscores; must start and end with alphanumeric"; + public static final String ID_TOO_LONG = "Invalid {0}; ''{1}'' cannot contain more than {2} characters."; public static final String INVALID_GROUP = "Invalid group id ''{0}''; must be non-empty string and may contain lowercase alphanumeric" + " (a-z and 0-9), hyphens or underscores; must start and end with alphanumeric"; + public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] with version [{1}] already exists"; + public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]"; + public static final String INFERENCE_FAILED_TO_SERIALIZE_MODEL = + "Failed to serialize the trained model [{0}] with version [{1}] for storage"; + public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}] with version [{1}]"; + public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again"; public static final String JOB_AUDIT_CREATED = "Job created"; public static final String JOB_AUDIT_UPDATED = "Job updated: {0}"; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/NamedXContentObjectHelper.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/NamedXContentObjectHelper.java new file mode 100644 index 0000000000000..a7a6d22ae3e0c --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/NamedXContentObjectHelper.java @@ -0,0 +1,44 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.utils; + +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; + +public final class NamedXContentObjectHelper { + + private NamedXContentObjectHelper() {} + + public static XContentBuilder writeNamedObjects(XContentBuilder builder, + ToXContent.Params params, + boolean useExplicitOrder, + String namedObjectsName, + List namedObjects) throws IOException { + if (useExplicitOrder) { + builder.startArray(namedObjectsName); + } else { + builder.startObject(namedObjectsName); + } + for (NamedXContentObject object : namedObjects) { + if (useExplicitOrder) { + builder.startObject(); + } + builder.field(object.getName(), object, params); + if (useExplicitOrder) { + builder.endObject(); + } + } + if (useExplicitOrder) { + builder.endArray(); + } else { + builder.endObject(); + } + return builder; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java new file mode 100644 index 0000000000000..7a6b884eee34c --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java @@ -0,0 +1,134 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.MlStrings; +import org.junit.Before; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.hamcrest.Matchers.equalTo; + + +public class TrainedModelConfigTests extends AbstractSerializingTestCase { + + private boolean lenient; + + @Before + public void chooseStrictOrLenient() { + lenient = randomBoolean(); + } + + @Override + protected TrainedModelConfig doParseInstance(XContentParser parser) throws IOException { + return TrainedModelConfig.fromXContent(parser, lenient).build(); + } + + @Override + protected boolean supportsUnknownFields() { + return lenient; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + return field -> !field.isEmpty(); + } + + @Override + protected TrainedModelConfig createTestInstance() { + return new TrainedModelConfig( + randomAlphaOfLength(10), + randomAlphaOfLength(10), + Version.CURRENT, + randomBoolean() ? null : randomAlphaOfLength(100), + Instant.ofEpochMilli(randomNonNegativeLong()), + randomBoolean() ? null : randomNonNegativeLong(), + randomAlphaOfLength(10), + randomBoolean() ? null : randomFrom(TreeTests.createRandom()), + randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10))); + } + + @Override + protected Writeable.Reader instanceReader() { + return TrainedModelConfig::new; + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + return new NamedXContentRegistry(namedXContent); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + List entries = new ArrayList<>(); + entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); + return new NamedWriteableRegistry(entries); + } + + public void testValidateWithNullDefinition() { + IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> TrainedModelConfig.builder().validate()); + assertThat(ex.getMessage(), equalTo("[definition] must not be null.")); + } + + public void testValidateWithInvalidID() { + String modelId = "InvalidID-"; + ElasticsearchException ex = expectThrows(ElasticsearchException.class, + () -> TrainedModelConfig.builder().setDefinition(randomFrom(TreeTests.createRandom())).setModelId(modelId).validate()); + assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INVALID_ID, "model_id", modelId))); + } + + public void testValidateWithLongID() { + String modelId = IntStream.range(0, 100).mapToObj(x -> "a").collect(Collectors.joining()); + ElasticsearchException ex = expectThrows(ElasticsearchException.class, + () -> TrainedModelConfig.builder().setDefinition(randomFrom(TreeTests.createRandom())).setModelId(modelId).validate()); + assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.ID_TOO_LONG, "model_id", modelId, MlStrings.ID_LENGTH_LIMIT))); + } + + public void testValidateWithIllegallyUserProvidedFields() { + String modelId = "simplemodel"; + ElasticsearchException ex = expectThrows(ElasticsearchException.class, + () -> TrainedModelConfig.builder() + .setDefinition(randomFrom(TreeTests.createRandom())) + .setCreatedTime(Instant.now()) + .setModelId(modelId).validate()); + assertThat(ex.getMessage(), equalTo("illegal to set [created_time] at inference model creation")); + + ex = expectThrows(ElasticsearchException.class, + () -> TrainedModelConfig.builder() + .setDefinition(randomFrom(TreeTests.createRandom())) + .setVersion(Version.CURRENT) + .setModelId(modelId).validate()); + assertThat(ex.getMessage(), equalTo("illegal to set [version] at inference model creation")); + + ex = expectThrows(ElasticsearchException.class, + () -> TrainedModelConfig.builder() + .setDefinition(randomFrom(TreeTests.createRandom())) + .setCreatedBy("ml_user") + .setModelId(modelId).validate()); + assertThat(ex.getMessage(), equalTo("illegal to set [created_by] at inference model creation")); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/utils/NamedXContentObjectHelperTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/utils/NamedXContentObjectHelperTests.java new file mode 100644 index 0000000000000..a9a30d68c0e6f --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/utils/NamedXContentObjectHelperTests.java @@ -0,0 +1,101 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.utils; + +import org.elasticsearch.client.ml.inference.NamedXContentObject; +import org.elasticsearch.client.ml.inference.NamedXContentObjectHelper; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; + +public class NamedXContentObjectHelperTests extends ESTestCase { + + static class NamedTestObject implements NamedXContentObject { + + private String fieldValue; + public static final ObjectParser PARSER = + new ObjectParser<>("my_named_object", true, NamedTestObject::new); + static { + PARSER.declareString(NamedTestObject::setFieldValue, new ParseField("my_field")); + } + + NamedTestObject() { + + } + + NamedTestObject(String value) { + this.fieldValue = value; + } + + @Override + public String getName() { + return "my_named_object"; + } + + void setFieldValue(String value) { + this.fieldValue = value; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (fieldValue != null) { + builder.field("my_field", fieldValue); + } + builder.endObject(); + return builder; + } + } + + public void testSerializeInOrder() throws IOException { + String expected = + "{\"my_objects\":[{\"my_named_object\":{\"my_field\":\"value1\"}},{\"my_named_object\":{\"my_field\":\"value2\"}}]}"; + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + builder.startObject(); + List objects = Arrays.asList(new NamedTestObject("value1"), new NamedTestObject("value2")); + NamedXContentObjectHelper.writeNamedObjects(builder, ToXContent.EMPTY_PARAMS, true, "my_objects", objects); + builder.endObject(); + assertThat(BytesReference.bytes(builder).utf8ToString(), equalTo(expected)); + } + } + + public void testSerialize() throws IOException { + String expected = "{\"my_objects\":{\"my_named_object\":{\"my_field\":\"value1\"},\"my_named_object\":{\"my_field\":\"value2\"}}}"; + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + builder.startObject(); + List objects = Arrays.asList(new NamedTestObject("value1"), new NamedTestObject("value2")); + NamedXContentObjectHelper.writeNamedObjects(builder, ToXContent.EMPTY_PARAMS, false, "my_objects", objects); + builder.endObject(); + assertThat(BytesReference.bytes(builder).utf8ToString(), equalTo(expected)); + } + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.addAll(Collections.singletonList(new NamedXContentRegistry.Entry(NamedXContentObject.class, + new ParseField("my_named_object"), + (p, c) -> NamedTestObject.PARSER.apply(p, null)))); + namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + return new NamedXContentRegistry(namedXContent); + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index cc01e15b36658..32ba339fa48c7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -124,6 +124,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndexFields; import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings; @@ -199,6 +200,7 @@ import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult; import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory; import org.elasticsearch.xpack.ml.dataframe.process.NativeAnalyticsProcessFactory; +import org.elasticsearch.xpack.ml.inference.persistence.InferenceInternalIndex; import org.elasticsearch.xpack.ml.job.JobManager; import org.elasticsearch.xpack.ml.job.JobManagerHolder; import org.elasticsearch.xpack.ml.job.UpdateJobProcessNotifier; @@ -898,6 +900,12 @@ public UnaryOperator> getIndexTemplateMetaDat logger.error("Error loading the template for the " + AnomalyDetectorsIndex.jobResultsIndexPrefix() + " indices", e); } + try { + templates.put(InferenceIndexConstants.LATEST_INDEX_NAME, InferenceInternalIndex.getIndexTemplateMetaData()); + } catch (IOException e) { + logger.error("Error loading the template for the " + InferenceIndexConstants.LATEST_INDEX_NAME + " index", e); + } + return templates; }; } @@ -909,7 +917,8 @@ public static boolean allTemplatesInstalled(ClusterState clusterState) { AuditorField.NOTIFICATIONS_INDEX, MlMetaIndex.INDEX_NAME, AnomalyDetectorsIndexFields.STATE_INDEX_PREFIX, - AnomalyDetectorsIndex.jobResultsIndexPrefix()); + AnomalyDetectorsIndex.jobResultsIndexPrefix(), + InferenceIndexConstants.LATEST_INDEX_NAME); for (String templateName : templateNames) { allPresent = allPresent && TemplateUtils.checkTemplateExistsAndVersionIsGTECurrentVersion(templateName, clusterState); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/InferenceInternalIndex.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/InferenceInternalIndex.java new file mode 100644 index 0000000000000..2f1cf2aed4ef2 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/InferenceInternalIndex.java @@ -0,0 +1,106 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.inference.persistence; + +import org.elasticsearch.Version; +import org.elasticsearch.cluster.metadata.IndexMetaData; +import org.elasticsearch.cluster.metadata.IndexTemplateMetaData; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; + +import java.io.IOException; +import java.util.Collections; + +import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder; +import static org.elasticsearch.index.mapper.MapperService.SINGLE_MAPPING_NAME; +import static org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants.LATEST_INDEX_NAME; +import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.DATE; +import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.DYNAMIC; +import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.ENABLED; +import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.KEYWORD; +import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.LONG; +import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.PROPERTIES; +import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.TEXT; +import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.TYPE; +import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.addMetaInformation; + + +/** + * Changelog of internal index versions + * + * Please list changes, increase the version in {@link InferenceInternalIndex} if you are 1st in this release cycle + * + * version 1 (7.5): initial + */ +public final class InferenceInternalIndex { + + private InferenceInternalIndex() {} + + public static XContentBuilder mappings() throws IOException { + return configMapping(SINGLE_MAPPING_NAME); + } + + public static IndexTemplateMetaData getIndexTemplateMetaData() throws IOException { + IndexTemplateMetaData inferenceTemplate = IndexTemplateMetaData.builder(LATEST_INDEX_NAME) + .patterns(Collections.singletonList(LATEST_INDEX_NAME)) + .version(Version.CURRENT.id) + .settings(Settings.builder() + // the configurations are expected to be small + .put(IndexMetaData.SETTING_NUMBER_OF_SHARDS, 1) + .put(IndexMetaData.SETTING_AUTO_EXPAND_REPLICAS, "0-1")) + .putMapping(SINGLE_MAPPING_NAME, Strings.toString(mappings())) + .build(); + return inferenceTemplate; + } + + public static XContentBuilder configMapping(String mappingType) throws IOException { + XContentBuilder builder = jsonBuilder(); + builder.startObject(); + builder.startObject(mappingType); + addMetaInformation(builder); + + // do not allow anything outside of the defined schema + builder.field(DYNAMIC, "false"); + + builder.startObject(PROPERTIES); + addInferenceDocFields(builder); + return builder.endObject() + .endObject() + .endObject(); + } + + private static void addInferenceDocFields(XContentBuilder builder) throws IOException { + builder.startObject(TrainedModelConfig.MODEL_ID.getPreferredName()) + .field(TYPE, KEYWORD) + .endObject() + .startObject(TrainedModelConfig.CREATED_BY.getPreferredName()) + .field(TYPE, KEYWORD) + .endObject() + .startObject(TrainedModelConfig.VERSION.getPreferredName()) + .field(TYPE, KEYWORD) + .endObject() + .startObject(TrainedModelConfig.DESCRIPTION.getPreferredName()) + .field(TYPE, TEXT) + .endObject() + .startObject(TrainedModelConfig.CREATED_TIME.getPreferredName()) + .field(TYPE, DATE) + .endObject() + .startObject(TrainedModelConfig.MODEL_VERSION.getPreferredName()) + .field(TYPE, LONG) + .endObject() + .startObject(TrainedModelConfig.DEFINITION.getPreferredName()) + .field(ENABLED, false) + .endObject() + .startObject(TrainedModelConfig.MODEL_TYPE.getPreferredName()) + .field(TYPE, KEYWORD) + .endObject() + .startObject(TrainedModelConfig.METADATA.getPreferredName()) + .field(ENABLED, false) + .endObject(); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java new file mode 100644 index 0000000000000..e569edc07fd8d --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -0,0 +1,134 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.inference.persistence; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.ResourceAlreadyExistsException; +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.DocWriteRequest; +import org.elasticsearch.action.index.IndexAction; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.search.SearchAction; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.engine.VersionConflictEngineException; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.search.sort.SortOrder; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; + +import java.io.IOException; +import java.io.InputStream; + +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; +import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; + +public class TrainedModelProvider { + + private static final Logger logger = LogManager.getLogger(TrainedModelProvider.class); + private final Client client; + private final NamedXContentRegistry xContentRegistry; + + public TrainedModelProvider(Client client, NamedXContentRegistry xContentRegistry) { + this.client = client; + this.xContentRegistry = xContentRegistry; + } + + public void storeTrainedModel(TrainedModelConfig trainedModelConfig, ActionListener listener) { + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + XContentBuilder source = trainedModelConfig.toXContent(builder, ToXContent.EMPTY_PARAMS); + + IndexRequest indexRequest = new IndexRequest(InferenceIndexConstants.LATEST_INDEX_NAME) + .opType(DocWriteRequest.OpType.CREATE) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .id(TrainedModelConfig.documentId(trainedModelConfig.getModelId(), trainedModelConfig.getModelVersion())) + .source(source); + executeAsyncWithOrigin(client, ML_ORIGIN, IndexAction.INSTANCE, indexRequest, + ActionListener.wrap( + r -> listener.onResponse(true), + e -> { + logger.error( + new ParameterizedMessage("[{}][{}] failed to store trained model for inference", + trainedModelConfig.getModelId(), + trainedModelConfig.getModelVersion()), + e); + if (e instanceof VersionConflictEngineException) { + listener.onFailure(new ResourceAlreadyExistsException( + Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, + trainedModelConfig.getModelId(), trainedModelConfig.getModelVersion()))); + } else { + listener.onFailure( + new ElasticsearchStatusException(Messages.INFERENCE_FAILED_TO_STORE_MODEL, + RestStatus.INTERNAL_SERVER_ERROR, + e, + trainedModelConfig.getModelId())); + } + })); + } catch (IOException e) { + // not expected to happen but for the sake of completeness + listener.onFailure(new ElasticsearchParseException( + Messages.getMessage(Messages.INFERENCE_FAILED_TO_SERIALIZE_MODEL, trainedModelConfig.getModelId()), + e)); + } + } + + public void getTrainedModel(String modelId, long modelVersion, ActionListener listener) { + QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders + .idsQuery() + .addIds(TrainedModelConfig.documentId(modelId, modelVersion))); + SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN) + .setQuery(queryBuilder) + // use sort to get the last + .addSort("_index", SortOrder.DESC) + .setSize(1) + .request(); + + executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, + ActionListener.wrap( + searchResponse -> { + if (searchResponse.getHits().getHits().length == 0) { + listener.onFailure(new ResourceNotFoundException( + Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId, modelVersion))); + return; + } + BytesReference source = searchResponse.getHits().getHits()[0].getSourceRef(); + parseInferenceDocLenientlyFromSource(source, modelId, modelVersion, listener); + }, + listener::onFailure)); + } + + + private void parseInferenceDocLenientlyFromSource(BytesReference source, + String modelId, + long modelVersion, + ActionListener modelListener) { + try (InputStream stream = source.streamInput(); + XContentParser parser = XContentFactory.xContent(XContentType.JSON) + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) { + modelListener.onResponse(TrainedModelConfig.fromXContent(parser, true).build()); + } catch (Exception e) { + logger.error(new ParameterizedMessage("[{}][{}] failed to parse model", modelId, modelVersion), e); + modelListener.onFailure(e); + } + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java new file mode 100644 index 0000000000000..8da2f4a3b4ef3 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java @@ -0,0 +1,113 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.integration; + +import org.elasticsearch.Version; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.junit.Before; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; + +public class TrainedModelProviderIT extends MlSingleNodeTestCase { + + private TrainedModelProvider trainedModelProvider; + + @Before + public void createComponents() throws Exception { + trainedModelProvider = new TrainedModelProvider(client(), xContentRegistry()); + waitForMlTemplates(); + } + + public void testPutTrainedModelConfig() throws Exception { + String modelId = "test-put-trained-model-config"; + TrainedModelConfig config = buildTrainedModelConfig(modelId, 0); + AtomicReference putConfigHolder = new AtomicReference<>(); + AtomicReference exceptionHolder = new AtomicReference<>(); + + blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder); + assertThat(putConfigHolder.get(), is(true)); + assertThat(exceptionHolder.get(), is(nullValue())); + } + + public void testPutTrainedModelConfigThatAlreadyExists() throws Exception { + String modelId = "test-put-trained-model-config-exists"; + TrainedModelConfig config = buildTrainedModelConfig(modelId, 0); + AtomicReference putConfigHolder = new AtomicReference<>(); + AtomicReference exceptionHolder = new AtomicReference<>(); + + blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder); + assertThat(putConfigHolder.get(), is(true)); + assertThat(exceptionHolder.get(), is(nullValue())); + + blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder); + assertThat(exceptionHolder.get(), is(not(nullValue()))); + assertThat(exceptionHolder.get().getMessage(), + equalTo(Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, modelId, 0))); + } + + public void testGetTrainedModelConfig() throws Exception { + String modelId = "test-get-trained-model-config"; + TrainedModelConfig config = buildTrainedModelConfig(modelId, 0); + AtomicReference putConfigHolder = new AtomicReference<>(); + AtomicReference exceptionHolder = new AtomicReference<>(); + + blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder); + assertThat(putConfigHolder.get(), is(true)); + assertThat(exceptionHolder.get(), is(nullValue())); + + AtomicReference getConfigHolder = new AtomicReference<>(); + blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, 0, listener), getConfigHolder, exceptionHolder); + assertThat(getConfigHolder.get(), is(not(nullValue()))); + assertThat(getConfigHolder.get(), equalTo(config)); + } + + public void testGetMissingTrainingModelConfig() throws Exception { + String modelId = "test-get-missing-trained-model-config"; + AtomicReference getConfigHolder = new AtomicReference<>(); + AtomicReference exceptionHolder = new AtomicReference<>(); + blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, 0, listener), getConfigHolder, exceptionHolder); + assertThat(exceptionHolder.get(), is(not(nullValue()))); + assertThat(exceptionHolder.get().getMessage(), + equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId, 0))); + } + + private static TrainedModelConfig buildTrainedModelConfig(String modelId, long modelVersion) { + return TrainedModelConfig.builder() + .setCreatedBy("ml_test") + .setDefinition(TreeTests.createRandom()) + .setDescription("trained model config for test") + .setModelId(modelId) + .setModelType("binary_decision_tree") + .setModelVersion(modelVersion) + .build(Version.CURRENT); + } + + @Override + public NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + return new NamedXContentRegistry(namedXContent); + + } + +}