diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java index be7c3c00af2c2..2325bbf27baa0 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java @@ -19,6 +19,10 @@ package org.elasticsearch.client.ml.inference; import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble; +import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator; +import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode; +import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum; import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding; import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding; @@ -47,6 +51,15 @@ public List getNamedXContentParsers() { // Model namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class, new ParseField(Tree.NAME), Tree::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class, new ParseField(Ensemble.NAME), Ensemble::fromXContent)); + + // Aggregating output + namedXContent.add(new NamedXContentRegistry.Entry(OutputAggregator.class, + new ParseField(WeightedMode.NAME), + WeightedMode::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(OutputAggregator.class, + new ParseField(WeightedSum.NAME), + WeightedSum::fromXContent)); return namedXContent; } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/TargetType.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/TargetType.java new file mode 100644 index 0000000000000..694a72f1cc5f8 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/TargetType.java @@ -0,0 +1,35 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml.inference.trainedmodel; + +import java.util.Locale; + +public enum TargetType { + + REGRESSION, CLASSIFICATION; + + public static TargetType fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/Ensemble.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/Ensemble.java new file mode 100644 index 0000000000000..d16d758769c2b --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/Ensemble.java @@ -0,0 +1,188 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference.trainedmodel.ensemble; + +import org.elasticsearch.client.ml.inference.NamedXContentObjectHelper; +import org.elasticsearch.client.ml.inference.trainedmodel.TargetType; +import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +public class Ensemble implements TrainedModel { + + public static final String NAME = "ensemble"; + public static final ParseField FEATURE_NAMES = new ParseField("feature_names"); + public static final ParseField TRAINED_MODELS = new ParseField("trained_models"); + public static final ParseField AGGREGATE_OUTPUT = new ParseField("aggregate_output"); + public static final ParseField TARGET_TYPE = new ParseField("target_type"); + public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels"); + + private static final ObjectParser PARSER = new ObjectParser<>( + NAME, + true, + Ensemble.Builder::new); + + static { + PARSER.declareStringArray(Ensemble.Builder::setFeatureNames, FEATURE_NAMES); + PARSER.declareNamedObjects(Ensemble.Builder::setTrainedModels, + (p, c, n) -> + p.namedObject(TrainedModel.class, n, null), + (ensembleBuilder) -> { /* Noop does not matter client side */ }, + TRAINED_MODELS); + PARSER.declareNamedObjects(Ensemble.Builder::setOutputAggregatorFromParser, + (p, c, n) -> p.namedObject(OutputAggregator.class, n, null), + (ensembleBuilder) -> { /* Noop does not matter client side */ }, + AGGREGATE_OUTPUT); + PARSER.declareString(Ensemble.Builder::setTargetType, 
TARGET_TYPE); + PARSER.declareStringArray(Ensemble.Builder::setClassificationLabels, CLASSIFICATION_LABELS); + } + + public static Ensemble fromXContent(XContentParser parser) { + return PARSER.apply(parser, null).build(); + } + + private final List featureNames; + private final List models; + private final OutputAggregator outputAggregator; + private final TargetType targetType; + private final List classificationLabels; + + Ensemble(List featureNames, + List models, + @Nullable OutputAggregator outputAggregator, + TargetType targetType, + @Nullable List classificationLabels) { + this.featureNames = featureNames; + this.models = models; + this.outputAggregator = outputAggregator; + this.targetType = targetType; + this.classificationLabels = classificationLabels; + } + + @Override + public List getFeatureNames() { + return featureNames; + } + + @Override + public String getName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + if (featureNames != null) { + builder.field(FEATURE_NAMES.getPreferredName(), featureNames); + } + if (models != null) { + NamedXContentObjectHelper.writeNamedObjects(builder, params, true, TRAINED_MODELS.getPreferredName(), models); + } + if (outputAggregator != null) { + NamedXContentObjectHelper.writeNamedObjects(builder, + params, + false, + AGGREGATE_OUTPUT.getPreferredName(), + Collections.singletonList(outputAggregator)); + } + if (targetType != null) { + builder.field(TARGET_TYPE.getPreferredName(), targetType); + } + if (classificationLabels != null) { + builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Ensemble that = (Ensemble) o; + return Objects.equals(featureNames, that.featureNames) + && Objects.equals(models, that.models) + && Objects.equals(targetType, that.targetType) + && Objects.equals(classificationLabels, that.classificationLabels) + && Objects.equals(outputAggregator, that.outputAggregator); + } + + @Override + public int hashCode() { + return Objects.hash(featureNames, models, outputAggregator, classificationLabels, targetType); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private List featureNames; + private List trainedModels; + private OutputAggregator outputAggregator; + private TargetType targetType; + private List classificationLabels; + + public Builder setFeatureNames(List featureNames) { + this.featureNames = featureNames; + return this; + } + + public Builder setTrainedModels(List trainedModels) { + this.trainedModels = trainedModels; + return this; + } + + public Builder setOutputAggregator(OutputAggregator outputAggregator) { + this.outputAggregator = outputAggregator; + return this; + } + + public Builder setTargetType(TargetType targetType) { + this.targetType = targetType; + return this; + } + + public Builder setClassificationLabels(List classificationLabels) { + this.classificationLabels = classificationLabels; + return this; + } + + private void setOutputAggregatorFromParser(List outputAggregators) { + this.setOutputAggregator(outputAggregators.get(0)); + } + + private void setTargetType(String targetType) { + this.targetType = TargetType.fromString(targetType); + } + + public Ensemble build() { + return new 
Ensemble(featureNames, trainedModels, outputAggregator, targetType, classificationLabels); + } + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/OutputAggregator.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/OutputAggregator.java new file mode 100644 index 0000000000000..955def1999ae3 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/OutputAggregator.java @@ -0,0 +1,28 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference.trainedmodel.ensemble; + +import org.elasticsearch.client.ml.inference.NamedXContentObject; + +public interface OutputAggregator extends NamedXContentObject { + /** + * @return The name of the output aggregator + */ + String getName(); +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedMode.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedMode.java new file mode 100644 index 0000000000000..37d589badd1e4 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedMode.java @@ -0,0 +1,84 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml.inference.trainedmodel.ensemble; + + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + + +public class WeightedMode implements OutputAggregator { + + public static final String NAME = "weighted_mode"; + public static final ParseField WEIGHTS = new ParseField("weights"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME, + true, + a -> new WeightedMode((List)a[0])); + static { + PARSER.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS); + } + + public static WeightedMode fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final List weights; + + public WeightedMode(List weights) { + this.weights = weights; + } + + @Override + public String getName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + if (weights != null) { + builder.field(WEIGHTS.getPreferredName(), weights); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + WeightedMode that = (WeightedMode) o; + return Objects.equals(weights, that.weights); + } + + @Override + public int hashCode() { + return Objects.hash(weights); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedSum.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedSum.java new file mode 100644 index 0000000000000..534eb8d4def2d --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedSum.java @@ -0,0 +1,84 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml.inference.trainedmodel.ensemble; + + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public class WeightedSum implements OutputAggregator { + + public static final String NAME = "weighted_sum"; + public static final ParseField WEIGHTS = new ParseField("weights"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME, + true, + a -> new WeightedSum((List)a[0])); + + static { + PARSER.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS); + } + + public static WeightedSum fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final List weights; + + public WeightedSum(List weights) { + this.weights = weights; + } + + @Override + public String getName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + if (weights != null) { + builder.field(WEIGHTS.getPreferredName(), weights); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + WeightedSum that = (WeightedSum) o; + return Objects.equals(weights, that.weights); + } + + @Override + public int hashCode() { + return Objects.hash(weights); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/Tree.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/Tree.java index de040ec6f9ed7..5a1e07f34e256 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/Tree.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/Tree.java @@ -18,7 +18,9 @@ */ package org.elasticsearch.client.ml.inference.trainedmodel.tree; +import org.elasticsearch.client.ml.inference.trainedmodel.TargetType; import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.ObjectParser; @@ -28,7 +30,6 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.stream.Collectors; @@ -39,12 +40,16 @@ public class Tree implements TrainedModel { public static final ParseField FEATURE_NAMES = new ParseField("feature_names"); public static final ParseField TREE_STRUCTURE = new ParseField("tree_structure"); + public static final ParseField TARGET_TYPE = new ParseField("target_type"); + public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels"); private static final ObjectParser PARSER = new ObjectParser<>(NAME, true, Builder::new); static { PARSER.declareStringArray(Builder::setFeatureNames, FEATURE_NAMES); PARSER.declareObjectArray(Builder::setNodes, (p, c) -> TreeNode.fromXContent(p), TREE_STRUCTURE); + 
PARSER.declareString(Builder::setTargetType, TARGET_TYPE); + PARSER.declareStringArray(Builder::setClassificationLabels, CLASSIFICATION_LABELS); } public static Tree fromXContent(XContentParser parser) { @@ -53,10 +58,14 @@ public static Tree fromXContent(XContentParser parser) { private final List featureNames; private final List nodes; - - Tree(List featureNames, List nodes) { - this.featureNames = Collections.unmodifiableList(Objects.requireNonNull(featureNames)); - this.nodes = Collections.unmodifiableList(Objects.requireNonNull(nodes)); + private final TargetType targetType; + private final List classificationLabels; + + Tree(List featureNames, List nodes, TargetType targetType, List classificationLabels) { + this.featureNames = featureNames; + this.nodes = nodes; + this.targetType = targetType; + this.classificationLabels = classificationLabels; } @Override @@ -73,11 +82,30 @@ public List getNodes() { return nodes; } + @Nullable + public List getClassificationLabels() { + return classificationLabels; + } + + public TargetType getTargetType() { + return targetType; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(FEATURE_NAMES.getPreferredName(), featureNames); - builder.field(TREE_STRUCTURE.getPreferredName(), nodes); + if (featureNames != null) { + builder.field(FEATURE_NAMES.getPreferredName(), featureNames); + } + if (nodes != null) { + builder.field(TREE_STRUCTURE.getPreferredName(), nodes); + } + if (classificationLabels != null) { + builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels); + } + if (targetType != null) { + builder.field(TARGET_TYPE.getPreferredName(), targetType.toString()); + } builder.endObject(); return builder; } @@ -93,12 +121,14 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; Tree that = (Tree) o; return Objects.equals(featureNames, that.featureNames) + && Objects.equals(classificationLabels, that.classificationLabels) + && Objects.equals(targetType, that.targetType) && Objects.equals(nodes, that.nodes); } @Override public int hashCode() { - return Objects.hash(featureNames, nodes); + return Objects.hash(featureNames, nodes, targetType, classificationLabels); } public static Builder builder() { @@ -109,6 +139,8 @@ public static class Builder { private List featureNames; private ArrayList nodes; private int numNodes; + private TargetType targetType; + private List classificationLabels; public Builder() { nodes = new ArrayList<>(); @@ -137,6 +169,20 @@ public Builder setNodes(TreeNode.Builder... nodes) { return setNodes(Arrays.asList(nodes)); } + public Builder setTargetType(TargetType targetType) { + this.targetType = targetType; + return this; + } + + public Builder setClassificationLabels(List classificationLabels) { + this.classificationLabels = classificationLabels; + return this; + } + + private void setTargetType(String targetType) { + this.targetType = TargetType.fromString(targetType); + } + /** * Add a decision node. Space for the child nodes is allocated * @param nodeIndex Where to place the node. 
This is either 0 (root) or an existing child node index @@ -185,7 +231,9 @@ public Builder addLeaf(int nodeIndex, double value) { public Tree build() { return new Tree(featureNames, - nodes.stream().map(TreeNode.Builder::build).collect(Collectors.toList())); + nodes.stream().map(TreeNode.Builder::build).collect(Collectors.toList()), + targetType, + classificationLabels); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index b5394e5dcbdf3..e632b1f8165ab 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -67,6 +67,9 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric; +import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble; +import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode; +import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum; import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding; import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding; @@ -683,7 +686,7 @@ public void testDefaultNamedXContents() { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(44, namedXContents.size()); + assertEquals(47, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -693,7 +696,7 @@ public void testProvidedNamedXContents() { categories.put(namedXContent.categoryClass, counter + 1); } } - assertEquals("Had: " + categories, 11, categories.size()); + assertEquals("Had: " + categories, 12, categories.size()); assertEquals(Integer.valueOf(3), categories.get(Aggregation.class)); assertTrue(names.contains(ChildrenAggregationBuilder.NAME)); assertTrue(names.contains(MatrixStatsAggregationBuilder.NAME)); @@ -744,8 +747,11 @@ public void testProvidedNamedXContents() { RSquaredMetric.NAME)); assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class)); assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME)); - assertEquals(Integer.valueOf(1), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class)); - assertThat(names, hasItems(Tree.NAME)); + assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class)); + assertThat(names, hasItems(Tree.NAME, Ensemble.NAME)); + assertEquals(Integer.valueOf(2), + categories.get(org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator.class)); + assertThat(names, hasItems(WeightedMode.NAME, WeightedSum.NAME)); } public void testApiNamingConventions() throws Exception { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java 
b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java new file mode 100644 index 0000000000000..774ab26bc17c7 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -0,0 +1,97 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference.trainedmodel.ensemble; + +import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.client.ml.inference.trainedmodel.TargetType; +import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeTests; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; + + +public class EnsembleTests extends AbstractXContentTestCase { + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + return field -> !field.isEmpty(); + } + + @Override + protected Ensemble doParseInstance(XContentParser parser) throws IOException { + return Ensemble.fromXContent(parser); + } + + public static Ensemble createRandom() { + int numberOfFeatures = randomIntBetween(1, 10); + List featureNames = Stream.generate(() -> randomAlphaOfLength(10)) + .limit(numberOfFeatures) + .collect(Collectors.toList()); + int numberOfModels = randomIntBetween(1, 10); + List models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6)) + .limit(numberOfModels) + .collect(Collectors.toList()); + OutputAggregator outputAggregator = null; + if (randomBoolean()) { + List weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList()); + outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights)); + } + List categoryLabels = null; + if (randomBoolean()) { + categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false)); + } + return new Ensemble(featureNames, + models, + outputAggregator, + randomFrom(TargetType.values()), + categoryLabels); + } + + @Override + protected Ensemble createTestInstance() { + return createRandom(); + } + +
@Override + protected NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + return new NamedXContentRegistry(namedXContent); + } + +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedModeTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedModeTests.java new file mode 100644 index 0000000000000..a04652c1d3813 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedModeTests.java @@ -0,0 +1,51 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference.trainedmodel.ensemble; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; +import java.util.stream.Collectors; +import java.util.stream.Stream; + + +public class WeightedModeTests extends AbstractXContentTestCase { + + WeightedMode createTestInstance(int numberOfWeights) { + return new WeightedMode(Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList())); + } + + @Override + protected WeightedMode doParseInstance(XContentParser parser) throws IOException { + return WeightedMode.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected WeightedMode createTestInstance() { + return randomBoolean() ? new WeightedMode(null) : createTestInstance(randomIntBetween(1, 100)); + } + +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedSumTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedSumTests.java new file mode 100644 index 0000000000000..ddc4aeccfd34d --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedSumTests.java @@ -0,0 +1,51 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference.trainedmodel.ensemble; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; +import java.util.stream.Collectors; +import java.util.stream.Stream; + + +public class WeightedSumTests extends AbstractXContentTestCase { + + WeightedSum createTestInstance(int numberOfWeights) { + return new WeightedSum(Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList())); + } + + @Override + protected WeightedSum doParseInstance(XContentParser parser) throws IOException { + return WeightedSum.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected WeightedSum createTestInstance() { + return randomBoolean() ? new WeightedSum(null) : createTestInstance(randomIntBetween(1, 100)); + } + +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java index 66cdb44b10073..cb06469eaeaf1 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java @@ -18,11 +18,13 @@ */ package org.elasticsearch.client.ml.inference.trainedmodel.tree; +import org.elasticsearch.client.ml.inference.trainedmodel.TargetType; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractXContentTestCase; import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.function.Predicate; @@ -50,16 +52,17 @@ protected Tree createTestInstance() { } public static Tree createRandom() { - return buildRandomTree(randomIntBetween(2, 15), 6); + int numberOfFeatures = randomIntBetween(1, 10); + List featureNames = new ArrayList<>(); + for (int i = 0; i < numberOfFeatures; i++) { + featureNames.add(randomAlphaOfLength(10)); + } + return buildRandomTree(featureNames, 6); } - public static Tree buildRandomTree(int numFeatures, int depth) { - + public static Tree buildRandomTree(List featureNames, int depth) { + int numFeatures = featureNames.size(); Tree.Builder builder = Tree.builder(); - List featureNames = new ArrayList<>(numFeatures); - for(int i = 0; i < numFeatures; i++) { - featureNames.add(randomAlphaOfLength(10)); - } builder.setFeatureNames(featureNames); TreeNode.Builder node = builder.addJunction(0, randomInt(numFeatures), true, randomDouble()); @@ -80,8 +83,13 @@ public static Tree buildRandomTree(int numFeatures, int depth) { } childNodes = nextNodes; } - - return builder.build(); + List categoryLabels = null; + if (randomBoolean()) { + categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false)); + } + return 
builder.setClassificationLabels(categoryLabels) + .setTargetType(randomFrom(TargetType.values())) + .build(); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java index 7f14077a1504e..7fff4d6abbd3b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java @@ -11,6 +11,12 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LenientlyParsedOutputAggregator; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.StrictlyParsedOutputAggregator; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding; import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor; @@ -46,9 +52,27 @@ public List getNamedXContentParsers() { // Model Lenient namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentLenient)); + namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedTrainedModel.class, Ensemble.NAME, Ensemble::fromXContentLenient)); + + // Output Aggregator Lenient + namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedOutputAggregator.class, + WeightedMode.NAME, + WeightedMode::fromXContentLenient)); + namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedOutputAggregator.class, + WeightedSum.NAME, + WeightedSum::fromXContentLenient)); // Model Strict namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentStrict)); + namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedTrainedModel.class, Ensemble.NAME, Ensemble::fromXContentStrict)); + + // Output Aggregator Strict + namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedOutputAggregator.class, + WeightedMode.NAME, + WeightedMode::fromXContentStrict)); + namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedOutputAggregator.class, + WeightedSum.NAME, + WeightedSum::fromXContentStrict)); return namedXContent; } @@ -66,6 +90,15 @@ public List getNamedWriteables() { // Model namedWriteables.add(new NamedWriteableRegistry.Entry(TrainedModel.class, Tree.NAME.getPreferredName(), Tree::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(TrainedModel.class, Ensemble.NAME.getPreferredName(), Ensemble::new)); + + // Output Aggregator + namedWriteables.add(new NamedWriteableRegistry.Entry(OutputAggregator.class, + WeightedSum.NAME.getPreferredName(), + WeightedSum::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(OutputAggregator.class, + 
WeightedMode.NAME.getPreferredName(), + WeightedMode::new)); return namedWriteables; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TargetType.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TargetType.java new file mode 100644 index 0000000000000..9897231f5911a --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TargetType.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; + +import java.io.IOException; +import java.util.Locale; + +public enum TargetType implements Writeable { + + REGRESSION, CLASSIFICATION; + + public static TargetType fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } + + public static TargetType fromStream(StreamInput in) throws IOException { + return in.readEnum(TargetType.class); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeEnum(this); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java index 1d68e3d6d3f46..cad5a6c0a8c74 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; @@ -28,17 +29,47 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable { double infer(Map fields); /** - * @return {@code true} if the model is classification, {@code false} otherwise. + * @param fields similar to {@link TrainedModel#infer(Map)}, but fields are already in order and doubles + * @return The predicted value. */ - boolean isClassification(); + double infer(List fields); + + /** + * @return {@link TargetType} for the model. + */ + TargetType targetType(); /** * This gathers the probabilities for each potential classification value. * + * The probabilities are indexed by classification ordinal label encoding. + * The length of this list is equal to the number of classification labels. 
+ * + * This should only return if the implementation model is inferring classification values and not regression + * @param fields The fields and their values to infer against + * @return The probabilities of each classification value + */ - List inferProbabilities(Map fields); + List classificationProbability(Map fields); + + /** + * @param fields similar to {@link TrainedModel#classificationProbability(Map)} but the fields are already in order and doubles + * @return The probabilities of each classification value + */ + List classificationProbability(List fields); + /** + * The ordinal encoded list of the classification labels. + * @return Ordinal encoded list of classification labels. + */ + @Nullable + List classificationLabels(); + + /** + * Runs validations against the model. + * + * Example: {@link org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree} should check if there are any loops + * + * @throws org.elasticsearch.ElasticsearchException if validations fail + */ + void validate(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java new file mode 100644 index 0000000000000..7f2a7cc9a02ce --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java @@ -0,0 +1,311 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel { + + // TODO should we have regression/classification sub-classes that accept the builder?
+ public static final ParseField NAME = new ParseField("ensemble"); + public static final ParseField FEATURE_NAMES = new ParseField("feature_names"); + public static final ParseField TRAINED_MODELS = new ParseField("trained_models"); + public static final ParseField AGGREGATE_OUTPUT = new ParseField("aggregate_output"); + public static final ParseField TARGET_TYPE = new ParseField("target_type"); + public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels"); + + private static final ObjectParser LENIENT_PARSER = createParser(true); + private static final ObjectParser STRICT_PARSER = createParser(false); + + private static ObjectParser createParser(boolean lenient) { + ObjectParser parser = new ObjectParser<>( + NAME.getPreferredName(), + lenient, + Ensemble.Builder::builderForParser); + parser.declareStringArray(Ensemble.Builder::setFeatureNames, FEATURE_NAMES); + parser.declareNamedObjects(Ensemble.Builder::setTrainedModels, + (p, c, n) -> + lenient ? p.namedObject(LenientlyParsedTrainedModel.class, n, null) : + p.namedObject(StrictlyParsedTrainedModel.class, n, null), + (ensembleBuilder) -> ensembleBuilder.setModelsAreOrdered(true), + TRAINED_MODELS); + parser.declareNamedObjects(Ensemble.Builder::setOutputAggregatorFromParser, + (p, c, n) -> + lenient ? p.namedObject(LenientlyParsedOutputAggregator.class, n, null) : + p.namedObject(StrictlyParsedOutputAggregator.class, n, null), + (ensembleBuilder) -> {/*Noop as it could be an array or object, it just has to be exactly one*/}, + AGGREGATE_OUTPUT); + parser.declareString(Ensemble.Builder::setTargetType, TARGET_TYPE); + parser.declareStringArray(Ensemble.Builder::setClassificationLabels, CLASSIFICATION_LABELS); + return parser; + } + + public static Ensemble fromXContentStrict(XContentParser parser) { + return STRICT_PARSER.apply(parser, null).build(); + } + + public static Ensemble fromXContentLenient(XContentParser parser) { + return LENIENT_PARSER.apply(parser, null).build(); + } + + private final List featureNames; + private final List models; + private final OutputAggregator outputAggregator; + private final TargetType targetType; + private final List classificationLabels; + + Ensemble(List featureNames, + List models, + OutputAggregator outputAggregator, + TargetType targetType, + @Nullable List classificationLabels) { + this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES)); + this.models = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(models, TRAINED_MODELS)); + this.outputAggregator = ExceptionsHelper.requireNonNull(outputAggregator, AGGREGATE_OUTPUT); + this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE); + this.classificationLabels = classificationLabels == null ?
null : Collections.unmodifiableList(classificationLabels); + } + + public Ensemble(StreamInput in) throws IOException { + this.featureNames = Collections.unmodifiableList(in.readStringList()); + this.models = Collections.unmodifiableList(in.readNamedWriteableList(TrainedModel.class)); + this.outputAggregator = in.readNamedWriteable(OutputAggregator.class); + this.targetType = TargetType.fromStream(in); + if (in.readBoolean()) { + this.classificationLabels = in.readStringList(); + } else { + this.classificationLabels = null; + } + } + + @Override + public List getFeatureNames() { + return featureNames; + } + + @Override + public double infer(Map fields) { + List features = featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList()); + return infer(features); + } + + @Override + public double infer(List fields) { + List processedInferences = inferAndProcess(fields); + return outputAggregator.aggregate(processedInferences); + } + + @Override + public TargetType targetType() { + return targetType; + } + + @Override + public List classificationProbability(Map fields) { + if ((targetType == TargetType.CLASSIFICATION) == false) { + throw new UnsupportedOperationException( + "Cannot determine classification probability with target_type [" + targetType.toString() + "]"); + } + List features = featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList()); + return classificationProbability(features); + } + + @Override + public List classificationProbability(List fields) { + if ((targetType == TargetType.CLASSIFICATION) == false) { + throw new UnsupportedOperationException( + "Cannot determine classification probability with target_type [" + targetType.toString() + "]"); + } + return inferAndProcess(fields); + } + + @Override + public List classificationLabels() { + return classificationLabels; + } + + private List inferAndProcess(List fields) { + List modelInferences = models.stream().map(m -> m.infer(fields)).collect(Collectors.toList()); + return outputAggregator.processValues(modelInferences); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeStringCollection(featureNames); + out.writeNamedWriteableList(models); + out.writeNamedWriteable(outputAggregator); + targetType.writeTo(out); + out.writeBoolean(classificationLabels != null); + if (classificationLabels != null) { + out.writeStringCollection(classificationLabels); + } + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(FEATURE_NAMES.getPreferredName(), featureNames); + NamedXContentObjectHelper.writeNamedObjects(builder, params, true, TRAINED_MODELS.getPreferredName(), models); + NamedXContentObjectHelper.writeNamedObjects(builder, + params, + false, + AGGREGATE_OUTPUT.getPreferredName(), + Collections.singletonList(outputAggregator)); + builder.field(TARGET_TYPE.getPreferredName(), targetType.toString()); + if (classificationLabels != null) { + builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Ensemble that = (Ensemble) o; + return Objects.equals(featureNames, 
that.featureNames) + && Objects.equals(models, that.models) + && Objects.equals(targetType, that.targetType) + && Objects.equals(classificationLabels, that.classificationLabels) + && Objects.equals(outputAggregator, that.outputAggregator); + } + + @Override + public int hashCode() { + return Objects.hash(featureNames, models, outputAggregator, targetType, classificationLabels); + } + + @Override + public void validate() { + if (this.featureNames != null) { + if (this.models.stream() + .anyMatch(trainedModel -> trainedModel.getFeatureNames().equals(this.featureNames) == false)) { + throw ExceptionsHelper.badRequestException( + "[{}] must be the same and in the same order for each of the {}", + FEATURE_NAMES.getPreferredName(), + TRAINED_MODELS.getPreferredName()); + } + } + if (outputAggregator.expectedValueSize() != null && + outputAggregator.expectedValueSize() != models.size()) { + throw ExceptionsHelper.badRequestException( + "[{}] expects value array of size [{}] but number of models is [{}]", + AGGREGATE_OUTPUT.getPreferredName(), + outputAggregator.expectedValueSize(), + models.size()); + } + if ((this.targetType == TargetType.CLASSIFICATION) != (this.classificationLabels != null)) { + throw ExceptionsHelper.badRequestException( + "[target_type] should be [classification] if [classification_labels] is provided, and vice versa"); + } + this.models.forEach(TrainedModel::validate); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private List featureNames; + private List trainedModels; + private OutputAggregator outputAggregator = new WeightedSum(); + private TargetType targetType = TargetType.REGRESSION; + private List classificationLabels; + private boolean modelsAreOrdered; + + private Builder (boolean modelsAreOrdered) { + this.modelsAreOrdered = modelsAreOrdered; + } + + private static Builder builderForParser() { + return new Builder(false); + } + + public Builder() { + this(true); + } + + public Builder setFeatureNames(List featureNames) { + this.featureNames = featureNames; + return this; + } + + public Builder setTrainedModels(List trainedModels) { + this.trainedModels = trainedModels; + return this; + } + + public Builder setOutputAggregator(OutputAggregator outputAggregator) { + this.outputAggregator = ExceptionsHelper.requireNonNull(outputAggregator, AGGREGATE_OUTPUT); + return this; + } + + public Builder setTargetType(TargetType targetType) { + this.targetType = targetType; + return this; + } + + public Builder setClassificationLabels(List classificationLabels) { + this.classificationLabels = classificationLabels; + return this; + } + + private void setOutputAggregatorFromParser(List outputAggregators) { + if (outputAggregators.size() != 1) { + throw ExceptionsHelper.badRequestException("[{}] must have exactly one aggregator defined.", + AGGREGATE_OUTPUT.getPreferredName()); + } + this.setOutputAggregator(outputAggregators.get(0)); + } + + private void setTargetType(String targetType) { + this.targetType = TargetType.fromString(targetType); + } + + private void setModelsAreOrdered(boolean value) { + this.modelsAreOrdered = value; + } + + public Ensemble build() { + // This is essentially a serialization error but the underlying xcontent parsing does not allow us to inject this requirement + // So, we verify the models were parsed in an ordered fashion here instead. 
+ if (modelsAreOrdered == false && trainedModels != null && trainedModels.size() > 1) { + throw ExceptionsHelper.badRequestException("[trained_models] needs to be an array of objects"); + } + return new Ensemble(featureNames, trainedModels, outputAggregator, targetType, classificationLabels); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LenientlyParsedOutputAggregator.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LenientlyParsedOutputAggregator.java new file mode 100644 index 0000000000000..29ba4e3aa7389 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LenientlyParsedOutputAggregator.java @@ -0,0 +1,10 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; + + +public interface LenientlyParsedOutputAggregator extends OutputAggregator { +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java new file mode 100644 index 0000000000000..1f882b724ee94 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; + +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; + +import java.util.List; + +public interface OutputAggregator extends NamedXContentObject, NamedWriteable { + + /** + * @return The expected size of the values array when aggregating. `null` implies there is no expected size. + */ + Integer expectedValueSize(); + + /** + * This pre-processes the values so that they may be passed directly to the {@link OutputAggregator#aggregate(List)} method. + * + * Two major types of pre-processed values could be returned: + * - The confidence/probability scaled values given the input values (See: {@link WeightedMode#processValues(List)}) + * - A simple transformation of the passed values in preparation for aggregation (See: {@link WeightedSum#processValues(List)}) + * @param values the values to process + * @return A new list containing the processed values or the same list if no processing is required + */ + List processValues(List values); + + /** + * Function to aggregate the processed values into a single double + * + * This may be as simple as returning the index of the maximum value. + * + * Or as complex as a mathematical reduction of all the passed values (i.e. summation, average, etc.). + * + * @param processedValues The values to aggregate + * @return the aggregated value.
+ */ + double aggregate(List processedValues); + + /** + * @return The name of the output aggregator + */ + String getName(); +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/StrictlyParsedOutputAggregator.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/StrictlyParsedOutputAggregator.java new file mode 100644 index 0000000000000..017340fda44ac --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/StrictlyParsedOutputAggregator.java @@ -0,0 +1,10 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; + + +public interface StrictlyParsedOutputAggregator extends OutputAggregator { +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java new file mode 100644 index 0000000000000..739a4e13d8659 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java @@ -0,0 +1,161 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; + + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax; + +public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator { + + public static final ParseField NAME = new ParseField("weighted_mode"); + public static final ParseField WEIGHTS = new ParseField("weights"); + + private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + private static final ConstructingObjectParser STRICT_PARSER = createParser(false); + + @SuppressWarnings("unchecked") + private static ConstructingObjectParser createParser(boolean lenient) { + ConstructingObjectParser parser = new ConstructingObjectParser<>( + NAME.getPreferredName(), + lenient, + a -> new WeightedMode((List)a[0])); + parser.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS); + return parser; + } + + public static WeightedMode fromXContentStrict(XContentParser parser) { + return STRICT_PARSER.apply(parser, null); + } + + public static WeightedMode fromXContentLenient(XContentParser parser) { + return LENIENT_PARSER.apply(parser, null); + } + + private final List weights; + + WeightedMode() { + this.weights = null; + } + + public WeightedMode(List weights) { + 
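+        // A null weights list is deliberately allowed: it means the values are unweighted, and
+        // processValues treats every entry as having a weight of 1.0.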
this.weights = weights == null ? null : Collections.unmodifiableList(weights);
+    }
+
+    public WeightedMode(StreamInput in) throws IOException {
+        if (in.readBoolean()) {
+            this.weights = Collections.unmodifiableList(in.readList(StreamInput::readDouble));
+        } else {
+            this.weights = null;
+        }
+    }
+
+    @Override
+    public Integer expectedValueSize() {
+        return this.weights == null ? null : this.weights.size();
+    }
+
+    @Override
+    public List<Double> processValues(List<Double> values) {
+        Objects.requireNonNull(values, "values must not be null");
+        if (weights != null && values.size() != weights.size()) {
+            throw new IllegalArgumentException("values must be the same length as weights.");
+        }
+        List<Integer> freqArray = new ArrayList<>();
+        int maxVal = 0;
+        for (Double value : values) {
+            if (value == null) {
+                throw new IllegalArgumentException("values must not contain null values");
+            }
+            if (Double.isNaN(value) || Double.isInfinite(value) || value < 0.0 || value != Math.rint(value)) {
+                throw new IllegalArgumentException("values must be whole, non-infinite, and non-negative");
+            }
+            int integerValue = value.intValue();
+            freqArray.add(integerValue);
+            if (integerValue > maxVal) {
+                maxVal = integerValue;
+            }
+        }
+        List<Double> frequencies = new ArrayList<>(Collections.nCopies(maxVal + 1, Double.NEGATIVE_INFINITY));
+        for (int i = 0; i < freqArray.size(); i++) {
+            double weight = weights == null ? 1.0 : weights.get(i);
+            int value = freqArray.get(i);
+            double frequency = frequencies.get(value) == Double.NEGATIVE_INFINITY ? weight : frequencies.get(value) + weight;
+            frequencies.set(value, frequency);
+        }
+        return softMax(frequencies);
+    }
+
+    @Override
+    public double aggregate(List<Double> values) {
+        Objects.requireNonNull(values, "values must not be null");
+        int bestValue = 0;
+        double bestFreq = Double.NEGATIVE_INFINITY;
+        for (int i = 0; i < values.size(); i++) {
+            if (values.get(i) == null) {
+                throw new IllegalArgumentException("values must not contain null values");
+            }
+            if (values.get(i) > bestFreq) {
+                bestFreq = values.get(i);
+                bestValue = i;
+            }
+        }
+        return bestValue;
+    }
+
+    @Override
+    public String getName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeBoolean(weights != null);
+        if (weights != null) {
+            out.writeCollection(weights, StreamOutput::writeDouble);
+        }
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (weights != null) {
+            builder.field(WEIGHTS.getPreferredName(), weights);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        WeightedMode that = (WeightedMode) o;
+        return Objects.equals(weights, that.weights);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(weights);
+    }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java
new file mode 100644
index 0000000000000..f5812dabf88f2
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java
@@ -0,0 +1,138 @@
+/*
+ * Copyright Elasticsearch B.V.
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; + + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public class WeightedSum implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator { + + public static final ParseField NAME = new ParseField("weighted_sum"); + public static final ParseField WEIGHTS = new ParseField("weights"); + + private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + private static final ConstructingObjectParser STRICT_PARSER = createParser(false); + + @SuppressWarnings("unchecked") + private static ConstructingObjectParser createParser(boolean lenient) { + ConstructingObjectParser parser = new ConstructingObjectParser<>( + NAME.getPreferredName(), + lenient, + a -> new WeightedSum((List)a[0])); + parser.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS); + return parser; + } + + public static WeightedSum fromXContentStrict(XContentParser parser) { + return STRICT_PARSER.apply(parser, null); + } + + public static WeightedSum fromXContentLenient(XContentParser parser) { + return LENIENT_PARSER.apply(parser, null); + } + + private final List weights; + + WeightedSum() { + this.weights = null; + } + + public WeightedSum(List weights) { + this.weights = weights == null ? 
null : Collections.unmodifiableList(weights); + } + + public WeightedSum(StreamInput in) throws IOException { + if (in.readBoolean()) { + this.weights = Collections.unmodifiableList(in.readList(StreamInput::readDouble)); + } else { + this.weights = null; + } + } + + @Override + public List processValues(List values) { + Objects.requireNonNull(values, "values must not be null"); + if (weights == null) { + return values; + } + if (values.size() != weights.size()) { + throw new IllegalArgumentException("values must be the same length as weights."); + } + return IntStream.range(0, weights.size()).mapToDouble(i -> values.get(i) * weights.get(i)).boxed().collect(Collectors.toList()); + } + + @Override + public double aggregate(List values) { + Objects.requireNonNull(values, "values must not be null"); + if (values.isEmpty()) { + throw new IllegalArgumentException("values must not be empty"); + } + Optional summation = values.stream().reduce(Double::sum); + if (summation.isPresent()) { + return summation.get(); + } + throw new IllegalArgumentException("values must not contain null values"); + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeBoolean(weights != null); + if (weights != null) { + out.writeCollection(weights, StreamOutput::writeDouble); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (weights != null) { + builder.field(WEIGHTS.getPreferredName(), weights); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + WeightedSum that = (WeightedSum) o; + return Objects.equals(weights, that.weights); + } + + @Override + public int hashCode() { + return Objects.hash(weights); + } + + @Override + public Integer expectedValueSize() { + return weights == null ? 
null : this.weights.size(); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java index 8e48fa488a0a8..5dca29d58437e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java @@ -9,11 +9,13 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.util.CachedSupplier; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -31,10 +33,13 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel { + // TODO should we have regression/classification sub-classes that accept the builder? public static final ParseField NAME = new ParseField("tree"); public static final ParseField FEATURE_NAMES = new ParseField("feature_names"); public static final ParseField TREE_STRUCTURE = new ParseField("tree_structure"); + public static final ParseField TARGET_TYPE = new ParseField("target_type"); + public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels"); private static final ObjectParser LENIENT_PARSER = createParser(true); private static final ObjectParser STRICT_PARSER = createParser(false); @@ -46,6 +51,8 @@ private static ObjectParser createParser(boolean lenient) { Tree.Builder::new); parser.declareStringArray(Tree.Builder::setFeatureNames, FEATURE_NAMES); parser.declareObjectArray(Tree.Builder::setNodes, (p, c) -> TreeNode.fromXContent(p, lenient), TREE_STRUCTURE); + parser.declareString(Tree.Builder::setTargetType, TARGET_TYPE); + parser.declareStringArray(Tree.Builder::setClassificationLabels, CLASSIFICATION_LABELS); return parser; } @@ -59,15 +66,28 @@ public static Tree fromXContentLenient(XContentParser parser) { private final List featureNames; private final List nodes; + private final TargetType targetType; + private final List classificationLabels; + private final CachedSupplier highestOrderCategory; - Tree(List featureNames, List nodes) { + Tree(List featureNames, List nodes, TargetType targetType, List classificationLabels) { this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES)); this.nodes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE)); + this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE); + this.classificationLabels = classificationLabels == null ? 
null : Collections.unmodifiableList(classificationLabels);
+        this.highestOrderCategory = new CachedSupplier<>(() -> this.maxLeafValue());
    }

    public Tree(StreamInput in) throws IOException {
        this.featureNames = Collections.unmodifiableList(in.readStringList());
        this.nodes = Collections.unmodifiableList(in.readList(TreeNode::new));
+        this.targetType = TargetType.fromStream(in);
+        if (in.readBoolean()) {
+            this.classificationLabels = Collections.unmodifiableList(in.readStringList());
+        } else {
+            this.classificationLabels = null;
+        }
+        this.highestOrderCategory = new CachedSupplier<>(() -> this.maxLeafValue());
    }

    @Override
@@ -90,7 +110,8 @@ public double infer(Map<String, Object> fields) {
        return infer(features);
    }

-    private double infer(List<Double> features) {
+    @Override
+    public double infer(List<Double> features) {
        TreeNode node = nodes.get(0);
        while(node.isLeaf() == false) {
            node = nodes.get(node.compare(features));
@@ -115,13 +136,40 @@ public List trace(List<Double> features) {
    }

    @Override
-    public boolean isClassification() {
-        return false;
+    public TargetType targetType() {
+        return targetType;
+    }
+
+    @Override
+    public List<Double> classificationProbability(Map<String, Object> fields) {
+        if ((targetType == TargetType.CLASSIFICATION) == false) {
+            throw new UnsupportedOperationException(
+                "Cannot determine classification probability with target_type [" + targetType.toString() + "]");
+        }
+        return classificationProbability(featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList()));
+    }
+
+    @Override
+    public List<Double> classificationProbability(List<Double> fields) {
+        if ((targetType == TargetType.CLASSIFICATION) == false) {
+            throw new UnsupportedOperationException(
+                "Cannot determine classification probability with target_type [" + targetType.toString() + "]");
+        }
+        double label = infer(fields);
+        // For classification, the inferred value is expected to be a whole number (the class index).
+        assert label == Math.rint(label);
+        double maxCategory = this.highestOrderCategory.get();
+        // Likewise, the largest leaf value is expected to be a whole number.
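+        // (For classification, leaf values are class indices, so the largest leaf value fixes the
+        // length of the probability vector built below.)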
+ assert maxCategory == Math.rint(maxCategory); + List list = new ArrayList<>(Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0)); + // TODO, eventually have TreeNodes contain confidence levels + list.set(Double.valueOf(label).intValue(), 1.0); + return list; } @Override - public List inferProbabilities(Map fields) { - throw new UnsupportedOperationException("Cannot infer probabilities against a regression model."); + public List classificationLabels() { + return classificationLabels; } @Override @@ -133,6 +181,11 @@ public String getWriteableName() { public void writeTo(StreamOutput out) throws IOException { out.writeStringCollection(featureNames); out.writeCollection(nodes); + targetType.writeTo(out); + out.writeBoolean(classificationLabels != null); + if (classificationLabels != null) { + out.writeStringCollection(classificationLabels); + } } @Override @@ -140,6 +193,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); builder.field(FEATURE_NAMES.getPreferredName(), featureNames); builder.field(TREE_STRUCTURE.getPreferredName(), nodes); + builder.field(TARGET_TYPE.getPreferredName(), targetType.toString()); + if(classificationLabels != null) { + builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels); + } builder.endObject(); return builder; } @@ -155,22 +212,96 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; Tree that = (Tree) o; return Objects.equals(featureNames, that.featureNames) - && Objects.equals(nodes, that.nodes); + && Objects.equals(nodes, that.nodes) + && Objects.equals(targetType, that.targetType) + && Objects.equals(classificationLabels, that.classificationLabels); } @Override public int hashCode() { - return Objects.hash(featureNames, nodes); + return Objects.hash(featureNames, nodes, targetType, classificationLabels); } public static Builder builder() { return new Builder(); } + @Override + public void validate() { + checkTargetType(); + detectMissingNodes(); + detectCycle(); + } + + private void checkTargetType() { + if ((this.targetType == TargetType.CLASSIFICATION) != (this.classificationLabels != null)) { + throw ExceptionsHelper.badRequestException( + "[target_type] should be [classification] if [classification_labels] is provided, and vice versa"); + } + } + + private void detectCycle() { + if (nodes.isEmpty()) { + return; + } + Set visited = new HashSet<>(nodes.size()); + Queue toVisit = new ArrayDeque<>(nodes.size()); + toVisit.add(0); + while(toVisit.isEmpty() == false) { + Integer nodeIdx = toVisit.remove(); + if (visited.contains(nodeIdx)) { + throw ExceptionsHelper.badRequestException("[tree] contains cycle at node {}", nodeIdx); + } + visited.add(nodeIdx); + TreeNode treeNode = nodes.get(nodeIdx); + if (treeNode.getLeftChild() >= 0) { + toVisit.add(treeNode.getLeftChild()); + } + if (treeNode.getRightChild() >= 0) { + toVisit.add(treeNode.getRightChild()); + } + } + } + + private void detectMissingNodes() { + if (nodes.isEmpty()) { + return; + } + + List missingNodes = new ArrayList<>(); + for (int i = 0; i < nodes.size(); i++) { + TreeNode currentNode = nodes.get(i); + if (currentNode == null) { + continue; + } + if (nodeMissing(currentNode.getLeftChild(), nodes)) { + missingNodes.add(currentNode.getLeftChild()); + } + if (nodeMissing(currentNode.getRightChild(), nodes)) { + missingNodes.add(currentNode.getRightChild()); + } + } + if (missingNodes.isEmpty() == false) { + throw 
ExceptionsHelper.badRequestException("[tree] contains missing nodes {}", missingNodes); + } + } + + private static boolean nodeMissing(int nodeIdx, List nodes) { + return nodeIdx >= nodes.size(); + } + + private Double maxLeafValue() { + return targetType == TargetType.CLASSIFICATION ? + this.nodes.stream().filter(TreeNode::isLeaf).mapToDouble(TreeNode::getLeafValue).max().getAsDouble() : + null; + } + public static class Builder { private List featureNames; private ArrayList nodes; private int numNodes; + private TargetType targetType = TargetType.REGRESSION; + private List classificationLabels; public Builder() { nodes = new ArrayList<>(); @@ -185,13 +316,18 @@ public Builder setFeatureNames(List featureNames) { return this; } + public Builder setRoot(TreeNode.Builder root) { + nodes.set(0, root); + return this; + } + public Builder addNode(TreeNode.Builder node) { nodes.add(node); return this; } public Builder setNodes(List nodes) { - this.nodes = new ArrayList<>(nodes); + this.nodes = new ArrayList<>(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE.getPreferredName())); return this; } @@ -199,6 +335,21 @@ public Builder setNodes(TreeNode.Builder... nodes) { return setNodes(Arrays.asList(nodes)); } + + public Builder setTargetType(TargetType targetType) { + this.targetType = targetType; + return this; + } + + public Builder setClassificationLabels(List classificationLabels) { + this.classificationLabels = classificationLabels; + return this; + } + + private void setTargetType(String targetType) { + this.targetType = TargetType.fromString(targetType); + } + /** * Add a decision node. Space for the child nodes is allocated * @param nodeIndex Where to place the node. This is either 0 (root) or an existing child node index @@ -231,61 +382,6 @@ TreeNode.Builder addJunction(int nodeIndex, int featureIndex, boolean isDefaultL return node; } - void detectCycle(List nodes) { - if (nodes.isEmpty()) { - return; - } - Set visited = new HashSet<>(); - Queue toVisit = new ArrayDeque<>(nodes.size()); - toVisit.add(0); - while(toVisit.isEmpty() == false) { - Integer nodeIdx = toVisit.remove(); - if (visited.contains(nodeIdx)) { - throw new IllegalArgumentException("[tree] contains cycle at node " + nodeIdx); - } - visited.add(nodeIdx); - TreeNode.Builder treeNode = nodes.get(nodeIdx); - if (treeNode.getLeftChild() != null) { - toVisit.add(treeNode.getLeftChild()); - } - if (treeNode.getRightChild() != null) { - toVisit.add(treeNode.getRightChild()); - } - } - } - - void detectNullOrMissingNode(List nodes) { - if (nodes.isEmpty()) { - return; - } - if (nodes.get(0) == null) { - throw new IllegalArgumentException("[tree] must have non-null root node."); - } - List nullOrMissingNodes = new ArrayList<>(); - for (int i = 0; i < nodes.size(); i++) { - TreeNode.Builder currentNode = nodes.get(i); - if (currentNode == null) { - continue; - } - if (nodeNullOrMissing(currentNode.getLeftChild())) { - nullOrMissingNodes.add(currentNode.getLeftChild()); - } - if (nodeNullOrMissing(currentNode.getRightChild())) { - nullOrMissingNodes.add(currentNode.getRightChild()); - } - } - if (nullOrMissingNodes.isEmpty() == false) { - throw new IllegalArgumentException("[tree] contains null or missing nodes " + nullOrMissingNodes); - } - } - - private boolean nodeNullOrMissing(Integer nodeIdx) { - if (nodeIdx == null) { - return false; - } - return nodeIdx >= nodes.size() || nodes.get(nodeIdx) == null; - } - /** * Sets the node at {@code nodeIndex} to a leaf node. 
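     *
     * For example (mirroring the usage in the tests): {@code addJunction(0, 0, true, 0.5)} allocates
     * the root node and space for its two children, after which
     * {@code addLeaf(rootNode.getRightChild(), 0.3)} turns the allocated right child into a leaf
     * with value 0.3.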
     * @param nodeIndex The index as allocated by a call to {@link #addJunction(int, int, boolean, double)}
@@ -301,10 +397,13 @@ Tree.Builder addLeaf(int nodeIndex, double value) {
        }

        public Tree build() {
-            detectNullOrMissingNode(nodes);
-            detectCycle(nodes);
+            if (nodes.stream().anyMatch(Objects::isNull)) {
+                throw ExceptionsHelper.badRequestException("[tree] cannot contain null nodes");
+            }
            return new Tree(featureNames,
-                nodes.stream().map(TreeNode.Builder::build).collect(Collectors.toList()));
+                nodes.stream().map(TreeNode.Builder::build).collect(Collectors.toList()),
+                targetType,
+                classificationLabels);
        }
    }
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNode.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNode.java
index f0dbb0617503b..9beda88e2c50a 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNode.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNode.java
@@ -143,7 +143,7 @@ public int getRightChild() {
    }

    public boolean isLeaf() {
-        return leftChild < 1;
+        return leftChild < 0;
    }

    public int compare(List<Double> features) {
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java
new file mode 100644
index 0000000000000..cb44d03e22bb2
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java
@@ -0,0 +1,52 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.utils;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+public final class Statistics {
+
+    private Statistics() {}
+
+    /**
+     * Calculates the softmax of the passed values.
+     *
+     * Any {@link Double#isInfinite()}, {@link Double#NaN}, or {@code null} values are ignored in
+     * the calculation and returned as 0.0 in the result.
+     * @param values Values on which to run softmax.
+     * @return A new list containing the softmax of the passed values
+     */
+    public static List<Double> softMax(List<Double> values) {
+        Double expSum = 0.0;
+        Double max = values.stream().filter(v -> isInvalid(v) == false).max(Double::compareTo).orElse(null);
+        if (max == null) {
+            throw new IllegalArgumentException("no valid values present");
+        }
+        List<Double> exps = values.stream().map(v -> isInvalid(v) ?
Double.NEGATIVE_INFINITY : v - max) + .collect(Collectors.toList()); + for (int i = 0; i < exps.size(); i++) { + if (isInvalid(exps.get(i)) == false) { + Double exp = Math.exp(exps.get(i)); + expSum += exp; + exps.set(i, exp); + } + } + for (int i = 0; i < exps.size(); i++) { + if (isInvalid(exps.get(i))) { + exps.set(i, 0.0); + } else { + exps.set(i, exps.get(i)/expSum); + } + } + return exps; + } + + public static boolean isInvalid(Double v) { + return v == null || Double.isInfinite(v) || Double.isNaN(v); + } + +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/NamedXContentObjectsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/NamedXContentObjectsTests.java index 3a3856cbe95a4..2db86e64e3502 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/NamedXContentObjectsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/NamedXContentObjectsTests.java @@ -17,6 +17,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.EnsembleTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests; import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests; import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor; @@ -157,7 +158,7 @@ public NamedObjectContainer createTestInstance() { NamedObjectContainer container = new NamedObjectContainer(); container.setPreProcessors(preProcessors); container.setUseExplicitPreprocessorOrder(true); - container.setModel(TreeTests.buildRandomTree(5, 4)); + container.setModel(randomFrom(TreeTests.createRandom(), EnsembleTests.createRandom())); return container; } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java new file mode 100644 index 0000000000000..1e1b1f8f7286d --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -0,0 +1,402 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests; +import org.junit.Before; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; + +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.equalTo; + +public class EnsembleTests extends AbstractSerializingTestCase { + + private boolean lenient; + + @Before + public void chooseStrictOrLenient() { + lenient = randomBoolean(); + } + + @Override + protected boolean supportsUnknownFields() { + return lenient; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + return field -> !field.isEmpty(); + } + + @Override + protected Ensemble doParseInstance(XContentParser parser) throws IOException { + return lenient ? Ensemble.fromXContentLenient(parser) : Ensemble.fromXContentStrict(parser); + } + + public static Ensemble createRandom() { + int numberOfFeatures = randomIntBetween(1, 10); + List featureNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numberOfFeatures).collect(Collectors.toList()); + int numberOfModels = randomIntBetween(1, 10); + List models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6)) + .limit(numberOfModels) + .collect(Collectors.toList()); + List weights = randomBoolean() ? 
+ null : + Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList()); + OutputAggregator outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights)); + List categoryLabels = null; + if (randomBoolean()) { + categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false)); + } + + return new Ensemble(featureNames, + models, + outputAggregator, + randomFrom(TargetType.values()), + categoryLabels); + } + + @Override + protected Ensemble createTestInstance() { + return createRandom(); + } + + @Override + protected Writeable.Reader instanceReader() { + return Ensemble::new; + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + return new NamedXContentRegistry(namedXContent); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + List entries = new ArrayList<>(); + entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); + return new NamedWriteableRegistry(entries); + } + + public void testEnsembleWithModelsThatHaveDifferentFeatureNames() { + List featureNames = Arrays.asList("foo", "bar", "baz", "farequote"); + ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> { + Ensemble.builder().setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList(TreeTests.buildRandomTree(Arrays.asList("bar", "foo", "baz", "farequote"), 6))) + .build() + .validate(); + }); + assertThat(ex.getMessage(), equalTo("[feature_names] must be the same and in the same order for each of the trained_models")); + + ex = expectThrows(ElasticsearchException.class, () -> { + Ensemble.builder().setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList(TreeTests.buildRandomTree(Arrays.asList("completely_different"), 6))) + .build() + .validate(); + }); + assertThat(ex.getMessage(), equalTo("[feature_names] must be the same and in the same order for each of the trained_models")); + } + + public void testEnsembleWithAggregatedOutputDifferingFromTrainedModels() { + List featureNames = Arrays.asList("foo", "bar"); + int numberOfModels = 5; + List weights = new ArrayList<>(numberOfModels + 2); + for (int i = 0; i < numberOfModels + 2; i++) { + weights.add(randomDouble()); + } + OutputAggregator outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights)); + + List models = new ArrayList<>(numberOfModels); + for (int i = 0; i < numberOfModels; i++) { + models.add(TreeTests.buildRandomTree(featureNames, 6)); + } + ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> { + Ensemble.builder() + .setTrainedModels(models) + .setOutputAggregator(outputAggregator) + .setFeatureNames(featureNames) + .build() + .validate(); + }); + assertThat(ex.getMessage(), equalTo("[aggregate_output] expects value array of size [7] but number of models is [5]")); + } + + public void testEnsembleWithInvalidModel() { + List featureNames = Arrays.asList("foo", "bar"); + expectThrows(ElasticsearchException.class, () -> { + Ensemble.builder() + .setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList( + // Tree with loop + Tree.builder() + .setNodes(TreeNode.builder(0) + .setLeftChild(1) + .setSplitFeature(1) + .setThreshold(randomDouble()), + TreeNode.builder(0) + 
.setLeftChild(0) + .setSplitFeature(1) + .setThreshold(randomDouble())) + .setFeatureNames(featureNames) + .build())) + .build() + .validate(); + }); + } + + public void testEnsembleWithTargetTypeAndLabelsMismatch() { + List featureNames = Arrays.asList("foo", "bar"); + String msg = "[target_type] should be [classification] if [classification_labels] is provided, and vice versa"; + ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> { + Ensemble.builder() + .setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList( + Tree.builder() + .setNodes(TreeNode.builder(0) + .setLeftChild(1) + .setSplitFeature(1) + .setThreshold(randomDouble())) + .setFeatureNames(featureNames) + .build())) + .setClassificationLabels(Arrays.asList("label1", "label2")) + .build() + .validate(); + }); + assertThat(ex.getMessage(), equalTo(msg)); + ex = expectThrows(ElasticsearchException.class, () -> { + Ensemble.builder() + .setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList( + Tree.builder() + .setNodes(TreeNode.builder(0) + .setLeftChild(1) + .setSplitFeature(1) + .setThreshold(randomDouble())) + .setFeatureNames(featureNames) + .build())) + .setTargetType(TargetType.CLASSIFICATION) + .build() + .validate(); + }); + assertThat(ex.getMessage(), equalTo(msg)); + } + + public void testClassificationProbability() { + List featureNames = Arrays.asList("foo", "bar"); + Tree tree1 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(0.5)) + .addNode(TreeNode.builder(1).setLeafValue(1.0)) + .addNode(TreeNode.builder(2) + .setThreshold(0.8) + .setSplitFeature(1) + .setLeftChild(3) + .setRightChild(4)) + .addNode(TreeNode.builder(3).setLeafValue(0.0)) + .addNode(TreeNode.builder(4).setLeafValue(1.0)).build(); + Tree tree2 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(0.5)) + .addNode(TreeNode.builder(1).setLeafValue(0.0)) + .addNode(TreeNode.builder(2).setLeafValue(1.0)) + .build(); + Tree tree3 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(1) + .setThreshold(1.0)) + .addNode(TreeNode.builder(1).setLeafValue(1.0)) + .addNode(TreeNode.builder(2).setLeafValue(0.0)) + .build(); + Ensemble ensemble = Ensemble.builder() + .setTargetType(TargetType.CLASSIFICATION) + .setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList(tree1, tree2, tree3)) + .setOutputAggregator(new WeightedMode(Arrays.asList(0.7, 0.5, 1.0))) + .build(); + + List featureVector = Arrays.asList(0.4, 0.0); + Map featureMap = zipObjMap(featureNames, featureVector); + List expected = Arrays.asList(0.231475216, 0.768524783); + double eps = 0.000001; + List probabilities = ensemble.classificationProbability(featureMap); + for(int i = 0; i < expected.size(); i++) { + assertThat(probabilities.get(i), closeTo(expected.get(i), eps)); + } + + featureVector = Arrays.asList(2.0, 0.7); + featureMap = zipObjMap(featureNames, featureVector); + expected = Arrays.asList(0.3100255188, 0.689974481); + probabilities = ensemble.classificationProbability(featureMap); + for(int i = 0; i < expected.size(); i++) { + assertThat(probabilities.get(i), closeTo(expected.get(i), eps)); + } + + featureVector = Arrays.asList(0.0, 1.0); + featureMap = zipObjMap(featureNames, featureVector); + expected = Arrays.asList(0.231475216, 
0.768524783); + probabilities = ensemble.classificationProbability(featureMap); + for(int i = 0; i < expected.size(); i++) { + assertThat(probabilities.get(i), closeTo(expected.get(i), eps)); + } + } + + public void testClassificationInference() { + List featureNames = Arrays.asList("foo", "bar"); + Tree tree1 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(0.5)) + .addNode(TreeNode.builder(1).setLeafValue(1.0)) + .addNode(TreeNode.builder(2) + .setThreshold(0.8) + .setSplitFeature(1) + .setLeftChild(3) + .setRightChild(4)) + .addNode(TreeNode.builder(3).setLeafValue(0.0)) + .addNode(TreeNode.builder(4).setLeafValue(1.0)).build(); + Tree tree2 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(0.5)) + .addNode(TreeNode.builder(1).setLeafValue(0.0)) + .addNode(TreeNode.builder(2).setLeafValue(1.0)) + .build(); + Tree tree3 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(1) + .setThreshold(1.0)) + .addNode(TreeNode.builder(1).setLeafValue(1.0)) + .addNode(TreeNode.builder(2).setLeafValue(0.0)) + .build(); + Ensemble ensemble = Ensemble.builder() + .setTargetType(TargetType.CLASSIFICATION) + .setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList(tree1, tree2, tree3)) + .setOutputAggregator(new WeightedMode(Arrays.asList(0.7, 0.5, 1.0))) + .build(); + + List featureVector = Arrays.asList(0.4, 0.0); + Map featureMap = zipObjMap(featureNames, featureVector); + assertEquals(1.0, ensemble.infer(featureMap), 0.00001); + + featureVector = Arrays.asList(2.0, 0.7); + featureMap = zipObjMap(featureNames, featureVector); + assertEquals(1.0, ensemble.infer(featureMap), 0.00001); + + featureVector = Arrays.asList(0.0, 1.0); + featureMap = zipObjMap(featureNames, featureVector); + assertEquals(1.0, ensemble.infer(featureMap), 0.00001); + } + + public void testRegressionInference() { + List featureNames = Arrays.asList("foo", "bar"); + Tree tree1 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(0.5)) + .addNode(TreeNode.builder(1).setLeafValue(0.3)) + .addNode(TreeNode.builder(2) + .setThreshold(0.8) + .setSplitFeature(1) + .setLeftChild(3) + .setRightChild(4)) + .addNode(TreeNode.builder(3).setLeafValue(0.1)) + .addNode(TreeNode.builder(4).setLeafValue(0.2)).build(); + Tree tree2 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(0.5)) + .addNode(TreeNode.builder(1).setLeafValue(1.5)) + .addNode(TreeNode.builder(2).setLeafValue(0.9)) + .build(); + Ensemble ensemble = Ensemble.builder() + .setTargetType(TargetType.REGRESSION) + .setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList(tree1, tree2)) + .setOutputAggregator(new WeightedSum(Arrays.asList(0.5, 0.5))) + .build(); + + List featureVector = Arrays.asList(0.4, 0.0); + Map featureMap = zipObjMap(featureNames, featureVector); + assertEquals(0.9, ensemble.infer(featureMap), 0.00001); + + featureVector = Arrays.asList(2.0, 0.7); + featureMap = zipObjMap(featureNames, featureVector); + assertEquals(0.5, ensemble.infer(featureMap), 0.00001); + + // Test with NO aggregator supplied, verifies default behavior of non-weighted 
sum + ensemble = Ensemble.builder() + .setTargetType(TargetType.REGRESSION) + .setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList(tree1, tree2)) + .build(); + + featureVector = Arrays.asList(0.4, 0.0); + featureMap = zipObjMap(featureNames, featureVector); + assertEquals(1.8, ensemble.infer(featureMap), 0.00001); + + featureVector = Arrays.asList(2.0, 0.7); + featureMap = zipObjMap(featureNames, featureVector); + assertEquals(1.0, ensemble.infer(featureMap), 0.00001); + } + + private static Map zipObjMap(List keys, List values) { + return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get)); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedAggregatorTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedAggregatorTests.java new file mode 100644 index 0000000000000..02bfe2797d990 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedAggregatorTests.java @@ -0,0 +1,51 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; + +import org.elasticsearch.test.AbstractSerializingTestCase; +import org.junit.Before; + +import java.util.ArrayList; +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; + +public abstract class WeightedAggregatorTests extends AbstractSerializingTestCase { + + protected boolean lenient; + + @Before + public void chooseStrictOrLenient() { + lenient = randomBoolean(); + } + + @Override + protected boolean supportsUnknownFields() { + return lenient; + } + + public void testWithNullValues() { + OutputAggregator outputAggregator = createTestInstance(); + NullPointerException ex = expectThrows(NullPointerException.class, () -> outputAggregator.processValues(null)); + assertThat(ex.getMessage(), equalTo("values must not be null")); + } + + public void testWithValuesOfWrongLength() { + int numberOfValues = randomIntBetween(5, 10); + List values = new ArrayList<>(numberOfValues); + for (int i = 0; i < numberOfValues; i++) { + values.add(randomDouble()); + } + + OutputAggregator outputAggregatorWithTooFewWeights = createTestInstance(randomIntBetween(1, numberOfValues - 1)); + expectThrows(IllegalArgumentException.class, () -> outputAggregatorWithTooFewWeights.processValues(values)); + + OutputAggregator outputAggregatorWithTooManyWeights = createTestInstance(randomIntBetween(numberOfValues + 1, numberOfValues + 10)); + expectThrows(IllegalArgumentException.class, () -> outputAggregatorWithTooManyWeights.processValues(values)); + } + + abstract T createTestInstance(int numberOfWeights); +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java new file mode 100644 index 0000000000000..7849d6d071ef1 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.hamcrest.Matchers.equalTo; + +public class WeightedModeTests extends WeightedAggregatorTests { + + @Override + WeightedMode createTestInstance(int numberOfWeights) { + List weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList()); + return new WeightedMode(weights); + } + + @Override + protected WeightedMode doParseInstance(XContentParser parser) throws IOException { + return lenient ? WeightedMode.fromXContentLenient(parser) : WeightedMode.fromXContentStrict(parser); + } + + @Override + protected WeightedMode createTestInstance() { + return randomBoolean() ? new WeightedMode() : createTestInstance(randomIntBetween(1, 100)); + } + + @Override + protected Writeable.Reader instanceReader() { + return WeightedMode::new; + } + + public void testAggregate() { + List ones = Arrays.asList(1.0, 1.0, 1.0, 1.0, 1.0); + List values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0); + + WeightedMode weightedMode = new WeightedMode(ones); + assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0)); + + List variedWeights = Arrays.asList(1.0, -1.0, .5, 1.0, 5.0); + + weightedMode = new WeightedMode(variedWeights); + assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(5.0)); + + weightedMode = new WeightedMode(); + assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0)); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java new file mode 100644 index 0000000000000..89222365c83d8 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.hamcrest.Matchers.equalTo; + +public class WeightedSumTests extends WeightedAggregatorTests { + + @Override + WeightedSum createTestInstance(int numberOfWeights) { + List weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList()); + return new WeightedSum(weights); + } + + @Override + protected WeightedSum doParseInstance(XContentParser parser) throws IOException { + return lenient ? 
WeightedSum.fromXContentLenient(parser) : WeightedSum.fromXContentStrict(parser); + } + + @Override + protected WeightedSum createTestInstance() { + return randomBoolean() ? new WeightedSum() : createTestInstance(randomIntBetween(1, 100)); + } + + @Override + protected Writeable.Reader instanceReader() { + return WeightedSum::new; + } + + public void testAggregate() { + List ones = Arrays.asList(1.0, 1.0, 1.0, 1.0, 1.0); + List values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0); + + WeightedSum weightedSum = new WeightedSum(ones); + assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(13.0)); + + List variedWeights = Arrays.asList(1.0, -1.0, .5, 1.0, 5.0); + + weightedSum = new WeightedSum(variedWeights); + assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(28.0)); + + weightedSum = new WeightedSum(); + assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(13.0)); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java index 391f2e4b7e59a..ce27120d671be 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java @@ -5,9 +5,12 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.junit.Before; import java.io.IOException; @@ -47,23 +50,23 @@ protected Predicate getRandomFieldsExcludeFilter() { return field -> field.startsWith("feature_names"); } - @Override protected Tree createTestInstance() { return createRandom(); } public static Tree createRandom() { - return buildRandomTree(randomIntBetween(2, 15), 6); + int numberOfFeatures = randomIntBetween(1, 10); + List featureNames = new ArrayList<>(); + for (int i = 0; i < numberOfFeatures; i++) { + featureNames.add(randomAlphaOfLength(10)); + } + return buildRandomTree(featureNames, 6); } - public static Tree buildRandomTree(int numFeatures, int depth) { - + public static Tree buildRandomTree(List featureNames, int depth) { Tree.Builder builder = Tree.builder(); - List featureNames = new ArrayList<>(numFeatures); - for(int i = 0; i < numFeatures; i++) { - featureNames.add(randomAlphaOfLength(10)); - } + int numFeatures = featureNames.size() - 1; builder.setFeatureNames(featureNames); TreeNode.Builder node = builder.addJunction(0, randomInt(numFeatures), true, randomDouble()); @@ -84,8 +87,14 @@ public static Tree buildRandomTree(int numFeatures, int depth) { } childNodes = nextNodes; } + List categoryLabels = null; + if (randomBoolean()) { + categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false)); + } - return builder.build(); + return builder.setTargetType(randomFrom(TargetType.values())) + .setClassificationLabels(categoryLabels) + .build(); } @Override @@ -96,7 +105,7 @@ protected Writeable.Reader instanceReader() { public void testInfer() { // Build a tree with 2 nodes and 3 leaves 
using 2 features // The leaves have unique values 0.1, 0.2, 0.3 - Tree.Builder builder = Tree.builder(); + Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION); TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5); builder.addLeaf(rootNode.getRightChild(), 0.3); TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8); @@ -124,37 +133,76 @@ public void testInfer() { assertEquals(0.2, tree.infer(featureMap), 0.00001); } + public void testTreeClassificationProbability() { + // Build a tree with 2 nodes and 3 leaves using 2 features + // The leaves have unique values 0.1, 0.2, 0.3 + Tree.Builder builder = Tree.builder().setTargetType(TargetType.CLASSIFICATION); + TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5); + builder.addLeaf(rootNode.getRightChild(), 1.0); + TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8); + builder.addLeaf(leftChildNode.getLeftChild(), 1.0); + builder.addLeaf(leftChildNode.getRightChild(), 0.0); + + List featureNames = Arrays.asList("foo", "bar"); + Tree tree = builder.setFeatureNames(featureNames).build(); + + // This feature vector should hit the right child of the root node + List featureVector = Arrays.asList(0.6, 0.0); + Map featureMap = zipObjMap(featureNames, featureVector); + assertEquals(Arrays.asList(0.0, 1.0), tree.classificationProbability(featureMap)); + + // This should hit the left child of the left child of the root node + // i.e. it takes the path left, left + featureVector = Arrays.asList(0.3, 0.7); + featureMap = zipObjMap(featureNames, featureVector); + assertEquals(Arrays.asList(0.0, 1.0), tree.classificationProbability(featureMap)); + + // This should hit the right child of the left child of the root node + // i.e. 
it takes the path left, right + featureVector = Arrays.asList(0.3, 0.9); + featureMap = zipObjMap(featureNames, featureVector); + assertEquals(Arrays.asList(1.0, 0.0), tree.classificationProbability(featureMap)); + } + public void testTreeWithNullRoot() { - IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, - () -> Tree.builder().setNodes(Collections.singletonList(null)) + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, + () -> Tree.builder() + .setNodes(Collections.singletonList(null)) + .setFeatureNames(Arrays.asList("foo", "bar")) .build()); - assertThat(ex.getMessage(), equalTo("[tree] must have non-null root node.")); + assertThat(ex.getMessage(), equalTo("[tree] cannot contain null nodes")); } public void testTreeWithInvalidNode() { - IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, - () -> Tree.builder().setNodes(TreeNode.builder(0) + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, + () -> Tree.builder() + .setNodes(TreeNode.builder(0) .setLeftChild(1) .setSplitFeature(1) .setThreshold(randomDouble())) - .build()); - assertThat(ex.getMessage(), equalTo("[tree] contains null or missing nodes [1]")); + .setFeatureNames(Arrays.asList("foo", "bar")) + .build().validate()); + assertThat(ex.getMessage(), equalTo("[tree] contains missing nodes [1]")); } public void testTreeWithNullNode() { - IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, - () -> Tree.builder().setNodes(TreeNode.builder(0) + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, + () -> Tree.builder() + .setNodes(TreeNode.builder(0) .setLeftChild(1) .setSplitFeature(1) .setThreshold(randomDouble()), null) - .build()); - assertThat(ex.getMessage(), equalTo("[tree] contains null or missing nodes [1]")); + .setFeatureNames(Arrays.asList("foo", "bar")) + .build() + .validate()); + assertThat(ex.getMessage(), equalTo("[tree] cannot contain null nodes")); } public void testTreeWithCycle() { - IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, - () -> Tree.builder().setNodes(TreeNode.builder(0) + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, + () -> Tree.builder() + .setNodes(TreeNode.builder(0) .setLeftChild(1) .setSplitFeature(1) .setThreshold(randomDouble()), @@ -162,10 +210,41 @@ public void testTreeWithCycle() { .setLeftChild(0) .setSplitFeature(1) .setThreshold(randomDouble())) - .build()); + .setFeatureNames(Arrays.asList("foo", "bar")) + .build() + .validate()); assertThat(ex.getMessage(), equalTo("[tree] contains cycle at node 0")); } + public void testTreeWithTargetTypeAndLabelsMismatch() { + List featureNames = Arrays.asList("foo", "bar"); + String msg = "[target_type] should be [classification] if [classification_labels] is provided, and vice versa"; + ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> { + Tree.builder() + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setSplitFeature(1) + .setThreshold(randomDouble())) + .setFeatureNames(featureNames) + .setClassificationLabels(Arrays.asList("label1", "label2")) + .build() + .validate(); + }); + assertThat(ex.getMessage(), equalTo(msg)); + ex = expectThrows(ElasticsearchException.class, () -> { + Tree.builder() + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setSplitFeature(1) + .setThreshold(randomDouble())) + .setFeatureNames(featureNames) + .setTargetType(TargetType.CLASSIFICATION) + .build() + 
.validate(); + }); + assertThat(ex.getMessage(), equalTo(msg)); + } + private static Map zipObjMap(List keys, List values) { return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get)); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/utils/StatisticsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/utils/StatisticsTests.java new file mode 100644 index 0000000000000..5fb69238b1579 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/utils/StatisticsTests.java @@ -0,0 +1,33 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.utils; + +import org.elasticsearch.test.ESTestCase; + +import java.util.Arrays; +import java.util.List; + +import static org.hamcrest.Matchers.closeTo; + +public class StatisticsTests extends ESTestCase { + + public void testSoftMax() { + List values = Arrays.asList(Double.NEGATIVE_INFINITY, 1.0, -0.5, null, Double.NaN, Double.POSITIVE_INFINITY, 1.0, 5.0); + List softMax = Statistics.softMax(values); + + List expected = Arrays.asList(0.0, 0.017599040, 0.003926876, 0.0, 0.0, 0.0, 0.017599040, 0.960875042); + + for(int i = 0; i < expected.size(); i++) { + assertThat(softMax.get(i), closeTo(expected.get(i), 0.000001)); + } + } + + public void testSoftMaxWithNoValidValues() { + List values = Arrays.asList(Double.NEGATIVE_INFINITY, null, Double.NaN, Double.POSITIVE_INFINITY); + expectThrows(IllegalArgumentException.class, () -> Statistics.softMax(values)); + } + +}