diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java index 5e105dffd8b74..1774277d8433b 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java @@ -61,11 +61,11 @@ public static Tree createRandom() { } public static Tree buildRandomTree(List featureNames, int depth, TargetType targetType) { - int numFeatures = featureNames.size(); + int maxFeatureIndex = featureNames.size() -1; Tree.Builder builder = Tree.builder(); builder.setFeatureNames(featureNames); - TreeNode.Builder node = builder.addJunction(0, randomInt(numFeatures), true, randomDouble()); + TreeNode.Builder node = builder.addJunction(0, randomInt(maxFeatureIndex), true, randomDouble()); List childNodes = Arrays.asList(node.getLeftChild(), node.getRightChild()); for (int i = 0; i < depth -1; i++) { @@ -76,7 +76,7 @@ public static Tree buildRandomTree(List featureNames, int depth, TargetT builder.addLeaf(nodeId, randomDouble()); } else { TreeNode.Builder childNode = - builder.addJunction(nodeId, randomInt(numFeatures), true, randomDouble()); + builder.addJunction(nodeId, randomInt(maxFeatureIndex), true, randomDouble()); nextNodes.add(childNode.getLeftChild()); nextNodes.add(childNode.getRightChild()); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java index b0d9a32bb0de6..ff7fbf813db67 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java @@ -253,8 +253,10 @@ public static Builder builder() { @Override public void validate() { - if (featureNames.isEmpty()) { - throw ExceptionsHelper.badRequestException("[{}] must not be empty for tree model", FEATURE_NAMES.getPreferredName()); + int maxFeatureIndex = maxFeatureIndex(); + if (maxFeatureIndex >= featureNames.size()) { + throw ExceptionsHelper.badRequestException("feature index [{}] is out of bounds for the [{}] array", + maxFeatureIndex, FEATURE_NAMES.getPreferredName()); } checkTargetType(); detectMissingNodes(); @@ -267,6 +269,23 @@ public long estimatedNumOperations() { return (long)Math.ceil(Math.log(nodes.size())) + featureNames.size(); } + /** + * The highest index of a feature used any of the nodes. + * If no nodes use a feature return -1. This can only happen + * if the tree contains a single leaf node. + * + * @return The max or -1 + */ + int maxFeatureIndex() { + int maxFeatureIndex = -1; + + for (TreeNode node : nodes) { + maxFeatureIndex = Math.max(maxFeatureIndex, node.getSplitFeature()); + } + + return maxFeatureIndex; + } + private void checkTargetType() { if (this.classificationLabels != null && this.targetType != TargetType.CLASSIFICATION) { throw ExceptionsHelper.badRequestException( diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java index f529c7180c9b1..ec7c35afca1af 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java @@ -29,6 +29,7 @@ import java.util.stream.IntStream; import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -72,10 +73,10 @@ public static Tree createRandom() { public static Tree buildRandomTree(List featureNames, int depth) { Tree.Builder builder = Tree.builder(); - int numFeatures = featureNames.size() - 1; + int maxFeatureIndex = featureNames.size() - 1; builder.setFeatureNames(featureNames); - TreeNode.Builder node = builder.addJunction(0, randomInt(numFeatures), true, randomDouble()); + TreeNode.Builder node = builder.addJunction(0, randomInt(maxFeatureIndex), true, randomDouble()); List childNodes = Arrays.asList(node.getLeftChild(), node.getRightChild()); for (int i = 0; i < depth -1; i++) { @@ -86,7 +87,7 @@ public static Tree buildRandomTree(List featureNames, int depth) { builder.addLeaf(nodeId, randomDouble()); } else { TreeNode.Builder childNode = - builder.addJunction(nodeId, randomInt(numFeatures), true, randomDouble()); + builder.addJunction(nodeId, randomInt(maxFeatureIndex), true, randomDouble()); nextNodes.add(childNode.getLeftChild()); nextNodes.add(childNode.getRightChild()); } @@ -339,26 +340,83 @@ public void testTreeWithTargetTypeAndLabelsMismatch() { assertThat(ex.getMessage(), equalTo(msg)); } - public void testTreeWithEmptyFeatureNames() { - String msg = "[feature_names] must not be empty for tree model"; - ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> { - Tree.builder() - .setRoot(TreeNode.builder(0) - .setLeftChild(1) - .setSplitFeature(1) - .setThreshold(randomDouble())) - .setFeatureNames(Collections.emptyList()) - .build() - .validate(); - }); - assertThat(ex.getMessage(), equalTo(msg)); - } - public void testOperationsEstimations() { Tree tree = buildRandomTree(Arrays.asList("foo", "bar", "baz"), 5); assertThat(tree.estimatedNumOperations(), equalTo(7L)); } + public void testMaxFeatureIndex() { + + int numFeatures = randomIntBetween(1, 15); + // We need a tree where every feature is used, choose a depth big enough to + // accommodate those non-leave nodes (leaf nodes don't have a feature index) + int depth = (int) Math.ceil(Math.log(numFeatures +1) / Math.log(2)) + 1; + List featureNames = new ArrayList<>(numFeatures); + for (int i=0; i childNodes = Arrays.asList(node.getLeftChild(), node.getRightChild()); + + for (int i = 0; i < depth -1; i++) { + List nextNodes = new ArrayList<>(); + for (int nodeId : childNodes) { + if (i == depth -2) { + builder.addLeaf(nodeId, randomDouble()); + } else { + TreeNode.Builder childNode = + builder.addJunction(nodeId, featureIndex++ % numFeatures, true, randomDouble()); + nextNodes.add(childNode.getLeftChild()); + nextNodes.add(childNode.getRightChild()); + } + } + childNodes = nextNodes; + } + + Tree tree = builder.build(); + + assertEquals(numFeatures, tree.maxFeatureIndex() +1); + } + + public void testMaxFeatureIndexSingleNodeTree() { + Tree tree = Tree.builder() + .setRoot(TreeNode.builder(0).setLeafValue(10.0)) + .setFeatureNames(Collections.emptyList()) + .build(); + + assertEquals(-1, tree.maxFeatureIndex()); + } + + public void testValidateGivenMissingFeatures() { + List featureNames = Arrays.asList("foo", "bar", "baz"); + + // build a tree referencing a feature at index 3 which is not in the featureNames list + Tree.Builder builder = Tree.builder().setFeatureNames(featureNames); + builder.addJunction(0, 0, true, randomDouble()); + builder.addJunction(1, 1, true, randomDouble()); + builder.addJunction(2, 3, true, randomDouble()); + builder.addLeaf(3, randomDouble()); + builder.addLeaf(4, randomDouble()); + builder.addLeaf(5, randomDouble()); + builder.addLeaf(6, randomDouble()); + + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> builder.build().validate()); + assertThat(e.getDetailedMessage(), containsString("feature index [3] is out of bounds for the [feature_names] array")); + } + + public void testValidateGivenTreeWithNoFeatures() { + Tree.builder() + .setRoot(TreeNode.builder(0).setLeafValue(10.0)) + .setFeatureNames(Collections.emptyList()) + .build() + .validate(); + } + private static Map zipObjMap(List keys, List values) { return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get)); } diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index fe500ac14abdc..8b4f118756d95 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -137,7 +137,7 @@ integTest.runner { 'ml/inference_crud/Test get given missing trained model', 'ml/inference_crud/Test get given expression without matches and allow_no_match is false', 'ml/inference_crud/Test put ensemble with empty models', - 'ml/inference_crud/Test put ensemble with tree where tree has empty feature-names', + 'ml/inference_crud/Test put ensemble with tree where tree has out of bounds feature_names index', 'ml/inference_crud/Test put model with empty input.field_names', 'ml/inference_stats_crud/Test get stats given missing trained model', 'ml/inference_stats_crud/Test get stats given expression without matches and allow_no_match is false', diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml index 74f12519b4adb..45da23a01e6f5 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml @@ -172,6 +172,40 @@ setup: - match: { count: 1 } - match: { trained_model_configs.0.model_id: "lang_ident_model_1" } --- +"Test put ensemble with single node and empty feature_names": + + - do: + ml.put_trained_model: + model_id: "ensemble_tree_empty_feature_names" + body: > + { + "input": { + "field_names": "fieldy_mc_fieldname" + }, + "definition": { + "trained_model": { + "ensemble": { + "feature_names": [], + "trained_models": [ + { + "tree": { + "feature_names": [], + "tree_structure": [ + { + "node_index": 0, + "decision_type": "lte", + "leaf_value": 12.0, + "default_left": true + }] + } + } + ] + } + } + } + } + +--- "Test put ensemble with empty models": - do: catch: /\[trained_models\] must not be empty/ @@ -192,11 +226,11 @@ setup: } } --- -"Test put ensemble with tree where tree has empty feature-names": +"Test put ensemble with tree where tree has out of bounds feature_names index": - do: - catch: /\[feature_names\] must not be empty/ + catch: /feature index \[1\] is out of bounds for the \[feature_names\] array/ ml.put_trained_model: - model_id: "ensemble_tree_missing_feature_names" + model_id: "ensemble_tree_out_of_bounds_feature_names_index" body: > { "input": { @@ -213,7 +247,7 @@ setup: "tree_structure": [ { "node_index": 0, - "split_feature": 0, + "split_feature": 1, "split_gain": 12.0, "threshold": 10.0, "decision_type": "lte",