From bbf4bf4316d9a4075fd50d2e8fcf27bad74ca989 Mon Sep 17 00:00:00 2001 From: Akshay Kumar Date: Wed, 30 Oct 2019 23:21:46 -0700 Subject: [PATCH 1/4] Add new xgboost parameter to represent missing features in feature vector --- .../dectree/NaiveAdditiveDecisionTree.java | 48 +++++++++++++++++++ .../ltr/ranker/parser/XGBoostJsonParser.java | 22 +++++++-- .../NaiveAdditiveDecisionTreeTests.java | 33 +++++++++++-- .../ranker/parser/XGBoostJsonParserTests.java | 29 +++++++++++ .../es/ltr/ranker/dectree/simple_tree.txt | 4 +- 5 files changed, 126 insertions(+), 10 deletions(-) diff --git a/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java b/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java index eb115c16..c683351f 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java +++ b/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java @@ -132,6 +132,54 @@ public long ramBytesUsed() { } } + public static class SplitWithMissing implements Node { + private static final long BASE_RAM_USED = RamUsageEstimator.shallowSizeOfInstance(SplitWithMissing.class); + private final Node left; + private final Node right; + private final int feature; + private final float threshold; + private final Node onMissing; + + public SplitWithMissing(Node left, Node right, Node onMissing, int feature, float threshold) { + this.left = Objects.requireNonNull(left); + this.right = Objects.requireNonNull(right); + this.feature = feature; + this.threshold = threshold; + this.onMissing = onMissing == null ? left : onMissing; + } + + @Override + public boolean isLeaf() { + return false; + } + + @Override + public float eval(float[] scores) { + Node n = this; + while (!n.isLeaf()) { + assert n instanceof SplitWithMissing; + SplitWithMissing s = (SplitWithMissing) n; + if (scores[s.feature] == Float.MAX_VALUE) { + n = s.onMissing; + } else if (s.threshold > scores[s.feature]) { + n = s.left; + } else { + n = s.right; + } + } + assert n instanceof Leaf; + return n.eval(scores); + } + + /** + * Return the memory usage of this object in bytes. Negative values are illegal. 
+ */ + @Override + public long ramBytesUsed() { + return BASE_RAM_USED + left.ramBytesUsed() + right.ramBytesUsed(); + } + } + public static class Leaf implements Node { private static final long BASE_RAM_USED = RamUsageEstimator.shallowSizeOfInstance(Split.class); diff --git a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParser.java b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParser.java index 1710507a..9fde11f2 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParser.java +++ b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParser.java @@ -65,10 +65,12 @@ private static class XGBoostDefinition { static { PARSER = new ObjectParser<>("xgboost_definition", XGBoostDefinition::new); PARSER.declareString(XGBoostDefinition::setNormalizer, new ParseField("objective")); + PARSER.declareBoolean(XGBoostDefinition::setFloatMaxForMissing, new ParseField("use_float_max_for_missing")); PARSER.declareObjectArray(XGBoostDefinition::setSplitParserStates, SplitParserState::parse, new ParseField("splits")); } private Normalizer normalizer; + private boolean useFloatMaxForMissing; private List splitParserStates; public static XGBoostDefinition parse(XContentParser parser, FeatureSet set) throws IOException { @@ -105,6 +107,7 @@ public static XGBoostDefinition parse(XContentParser parser, FeatureSet set) thr XGBoostDefinition() { normalizer = Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME); + useFloatMaxForMissing = false; } /** @@ -132,6 +135,10 @@ void setNormalizer(String objectiveName) { } } + void setFloatMaxForMissing(boolean useFloatMaxForMissing) { + this.useFloatMaxForMissing = useFloatMaxForMissing; + } + void setSplitParserStates(List splitParserStates) { this.splitParserStates = splitParserStates; } @@ -140,7 +147,7 @@ Node[] getTrees(FeatureSet set) { Node[] trees = new Node[splitParserStates.size()]; ListIterator it = splitParserStates.listIterator(); while(it.hasNext()) { - trees[it.nextIndex()] = it.next().toNode(set); + trees[it.nextIndex()] = it.next().toNode(set, this); } return trees; } @@ -169,7 +176,6 @@ private static class SplitParserState { private Float threshold; private Integer rightNodeId; private Integer leftNodeId; - // Ignored private Integer missingNodeId; private Float leaf; private List children; @@ -246,10 +252,16 @@ boolean isSplit() { } - Node toNode(FeatureSet set) { + Node toNode(FeatureSet set, XGBoostDefinition xgb) { if (isSplit()) { - return new NaiveAdditiveDecisionTree.Split(children.get(0).toNode(set), children.get(1).toNode(set), - set.featureOrdinal(split), threshold); + Node left = children.get(0).toNode(set, xgb); + Node right = children.get(1).toNode(set, xgb); + if (xgb.useFloatMaxForMissing) { + Node onMissing = this.missingNodeId.equals(this.rightNodeId) ? 
right : left; + return new NaiveAdditiveDecisionTree.SplitWithMissing(left, right, onMissing, set.featureOrdinal(split), threshold); + } else { + return new NaiveAdditiveDecisionTree.Split(left, right, set.featureOrdinal(split), threshold); + } } else { return new NaiveAdditiveDecisionTree.Leaf(leaf); } diff --git a/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java b/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java index 6abb1f25..a5e96b73 100644 --- a/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java +++ b/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java @@ -83,6 +83,31 @@ public void testSigmoidScore() throws IOException { assertEquals(expected, ranker.score(vector), Math.ulp(expected)); } + public void testScoreSparseFeatureSet() throws IOException { + NaiveAdditiveDecisionTree ranker = parseTreeModel("simple_tree.txt", Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME)); + LtrRanker.FeatureVector vector = ranker.newFeatureVector(null); + vector.setFeatureScore(0, 1); + vector.setFeatureScore(1, 3); + vector.setFeatureScore(2, Float.MAX_VALUE); + + // simple_tree model does not specify `missing`. We should take the + // left branch in that case. + float expected = 17F*3.4F + 3.2F*2.8F; + assertEquals(expected, ranker.score(vector), Math.ulp(expected)); + } + + public void testScoreSparseFeatureSetWithMissingField() throws IOException { + NaiveAdditiveDecisionTree ranker = parseTreeModel("simple_tree.txt", Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME)); + LtrRanker.FeatureVector vector = ranker.newFeatureVector(null); + vector.setFeatureScore(0, Float.MAX_VALUE); + vector.setFeatureScore(1, 2); + vector.setFeatureScore(2, 3); + + float expected = 3.2F*3.4F + 3.2F*2.8F; + assertEquals(expected, ranker.score(vector), Math.ulp(expected)); + } + + public void testPerfAndRobustness() { SimpleCountRandomTreeGeneratorStatsCollector counts = new SimpleCountRandomTreeGeneratorStatsCollector(); NaiveAdditiveDecisionTree ranker = generateRandomDecTree(100, 1000, @@ -207,7 +232,8 @@ NaiveAdditiveDecisionTree.Node parseTree() { if (line.contains("- output")) { return new NaiveAdditiveDecisionTree.Leaf(extractLastFloat(line)); } else if(line.contains("- split")) { - String featName = line.split(":")[1]; + String[] splitString = line.split(":"); + String featName = splitString[1]; int ord = set.featureOrdinal(featName); if (ord < 0 || ord > set.size()) { throw new IllegalArgumentException("Unknown feature " + featName); @@ -215,9 +241,10 @@ NaiveAdditiveDecisionTree.Node parseTree() { float threshold = extractLastFloat(line); NaiveAdditiveDecisionTree.Node right = parseTree(); NaiveAdditiveDecisionTree.Node left = parseTree(); + NaiveAdditiveDecisionTree.Node onMissing = (splitString.length > 3) ? + (Boolean.parseBoolean(splitString[2]) ? 
left : right) : null; + return new NaiveAdditiveDecisionTree.SplitWithMissing(left, right, onMissing, ord, threshold); - return new NaiveAdditiveDecisionTree.Split(left, right, - ord, threshold); } else { throw new IllegalArgumentException("Invalid tree"); } diff --git a/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserTests.java b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserTests.java index 17eebf21..797b0b55 100644 --- a/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserTests.java +++ b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserTests.java @@ -128,6 +128,35 @@ public void testReadSimpleSplitWithObjective() throws IOException { assertEquals(0.2F, tree.score(v), Math.ulp(0.2F)); } + public void testReadSimpleSplitWithSupportForMissing() throws IOException { + String model = "{" + + "\"use_float_max_for_missing\": true," + + "\"splits\": [{" + + " \"nodeid\": 0," + + " \"split\":\"feat1\"," + + " \"depth\":0," + + " \"split_condition\":0.123," + + " \"yes\":1," + + " \"no\": 2," + + " \"missing\":2,"+ + " \"children\": [" + + " {\"nodeid\": 1, \"depth\": 1, \"leaf\": 0.5}," + + " {\"nodeid\": 2, \"depth\": 1, \"leaf\": 0.2}" + + "]}]}"; + + FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1"))); + NaiveAdditiveDecisionTree tree = parser.parse(set, model); + FeatureVector v = tree.newFeatureVector(null); + v.setFeatureScore(0, 0.124F); + assertEquals(0.2F, tree.score(v), Math.ulp(0.2F)); + v.setFeatureScore(0, 0.122F); + assertEquals(0.5F, tree.score(v), Math.ulp(0.5F)); + v.setFeatureScore(0, 0.123F); + assertEquals(0.2F, tree.score(v), Math.ulp(0.2F)); + v.setFeatureScore(0, Float.MAX_VALUE); + assertEquals(0.2F, tree.score(v), Math.ulp(0.2F)); + } + public void testReadSplitWithUnknownParams() throws IOException { String model = "{" + "\"not_param\": \"value\"," + diff --git a/src/test/resources/com/o19s/es/ltr/ranker/dectree/simple_tree.txt b/src/test/resources/com/o19s/es/ltr/ranker/dectree/simple_tree.txt index 109b1cb3..ad9def02 100644 --- a/src/test/resources/com/o19s/es/ltr/ranker/dectree/simple_tree.txt +++ b/src/test/resources/com/o19s/es/ltr/ranker/dectree/simple_tree.txt @@ -1,7 +1,7 @@ # first line after split is right # data point: feature1:1, feature2:2, feature3:3 - tree:3.4 - - split:feature1:2.3 + - split:feature1:false:2.3 - output:3.2 # right wins - split:feature2:2.2 @@ -11,7 +11,7 @@ # left wins => output 1.2*3.4 - output:1.2 - tree:2.8 - - split:feature1:0.1 + - split:feature1:false:0.1 # right wins - split:feature2:1.8 # right wins From e401247d776c7e192e9b4527545c259efbd83f97 Mon Sep 17 00:00:00 2001 From: Akshay Kumar Date: Thu, 5 Dec 2019 13:54:48 -0800 Subject: [PATCH 2/4] Make NaiveAdditiveDecisionTree aware of missing value sentinel --- .../dectree/NaiveAdditiveDecisionTree.java | 133 ++++++++---------- .../ltr/ranker/parser/XGBoostJsonParser.java | 10 +- .../NaiveAdditiveDecisionTreeTests.java | 36 ++--- .../es/ltr/ranker/dectree/simple_tree.txt | 4 +- .../dectree/tree_with_missing_branches.txt | 18 +++ 5 files changed, 99 insertions(+), 102 deletions(-) create mode 100644 src/test/resources/com/o19s/es/ltr/ranker/dectree/tree_with_missing_branches.txt diff --git a/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java b/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java index c683351f..6aa98214 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java +++ 
b/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java @@ -35,6 +35,7 @@ public class NaiveAdditiveDecisionTree extends DenseLtrRanker implements Account private final float[] weights; private final int modelSize; private final Normalizer normalizer; + private final Float missingValue; /** * TODO: Constructor for these classes are strict and not really @@ -45,13 +46,15 @@ public class NaiveAdditiveDecisionTree extends DenseLtrRanker implements Account * @param weights the respective weights * @param modelSize the modelSize in number of feature used * @param normalizer class to perform any normalization on model score + * @param missingValue sentinel value to indicate feature is missing in the feature vector (optional) */ - public NaiveAdditiveDecisionTree(Node[] trees, float[] weights, int modelSize, Normalizer normalizer) { + public NaiveAdditiveDecisionTree(Node[] trees, float[] weights, int modelSize, Normalizer normalizer, Float missingValue) { assert trees.length == weights.length; this.trees = trees; this.weights = weights; this.modelSize = modelSize; this.normalizer = normalizer; + this.missingValue = missingValue; } @Override @@ -63,12 +66,50 @@ public String name() { protected float score(DenseFeatureVector vector) { float sum = 0; float[] scores = vector.scores; - for (int i = 0; i < trees.length; i++) { - sum += weights[i]*trees[i].eval(scores); + if (this.missingValue != null) { + for (int i = 0; i < trees.length; i++) { + sum += weights[i] * evalWithMissing(trees[i], scores); + } + } else { + for (int i = 0; i < trees.length; i++) { + sum += weights[i] * eval(trees[i], scores); + } } return normalizer.normalize(sum); } + private float eval(Node rootNode, float[] scores) { + Node n = rootNode; + while (!n.isLeaf()) { + assert n instanceof Split; + Split s = (Split) n; + if (s.threshold > scores[s.feature]) { + n = s.left; + } else { + n = s.right; + } + } + assert n instanceof Leaf; + return ((Leaf) n).getOutput(); + } + + private float evalWithMissing(Node rootNode, float[] scores) { + Node n = rootNode; + while (!n.isLeaf()) { + assert n instanceof Split; + Split s = (Split) n; + if (scores[s.feature] == this.missingValue) { + n = s.onMissing; + } else if (s.threshold > scores[s.feature]) { + n = s.left; + } else { + n = s.right; + } + } + assert n instanceof Leaf; + return ((Leaf) n).getOutput(); + } + @Override protected int size() { return modelSize; @@ -85,67 +126,22 @@ public long ramBytesUsed() { public interface Node extends Accountable { boolean isLeaf(); - float eval(float[] scores); } public static class Split implements Node { private static final long BASE_RAM_USED = RamUsageEstimator.shallowSizeOfInstance(Split.class); - private final Node left; - private final Node right; - private final int feature; - private final float threshold; - - public Split(Node left, Node right, int feature, float threshold) { - this.left = Objects.requireNonNull(left); - this.right = Objects.requireNonNull(right); - this.feature = feature; - this.threshold = threshold; - } - - @Override - public boolean isLeaf() { - return false; - } - - @Override - public float eval(float[] scores) { - Node n = this; - while (!n.isLeaf()) { - assert n instanceof Split; - Split s = (Split) n; - if (s.threshold > scores[s.feature]) { - n = s.left; - } else { - n = s.right; - } - } - assert n instanceof Leaf; - return n.eval(scores); - } - - /** - * Return the memory usage of this object in bytes. Negative values are illegal. 
- */ - @Override - public long ramBytesUsed() { - return BASE_RAM_USED + left.ramBytesUsed() + right.ramBytesUsed(); - } - } - - public static class SplitWithMissing implements Node { - private static final long BASE_RAM_USED = RamUsageEstimator.shallowSizeOfInstance(SplitWithMissing.class); - private final Node left; - private final Node right; - private final int feature; - private final float threshold; - private final Node onMissing; + final Node left; + final Node right; + final int feature; + final float threshold; + final Node onMissing; - public SplitWithMissing(Node left, Node right, Node onMissing, int feature, float threshold) { + public Split(Node left, Node right, Node onMissing, int feature, float threshold) { this.left = Objects.requireNonNull(left); this.right = Objects.requireNonNull(right); this.feature = feature; this.threshold = threshold; - this.onMissing = onMissing == null ? left : onMissing; + this.onMissing = onMissing; } @Override @@ -153,24 +149,6 @@ public boolean isLeaf() { return false; } - @Override - public float eval(float[] scores) { - Node n = this; - while (!n.isLeaf()) { - assert n instanceof SplitWithMissing; - SplitWithMissing s = (SplitWithMissing) n; - if (scores[s.feature] == Float.MAX_VALUE) { - n = s.onMissing; - } else if (s.threshold > scores[s.feature]) { - n = s.left; - } else { - n = s.right; - } - } - assert n instanceof Leaf; - return n.eval(scores); - } - /** * Return the memory usage of this object in bytes. Negative values are illegal. */ @@ -183,6 +161,10 @@ public long ramBytesUsed() { public static class Leaf implements Node { private static final long BASE_RAM_USED = RamUsageEstimator.shallowSizeOfInstance(Split.class); + float getOutput() { + return output; + } + private final float output; public Leaf(float output) { @@ -194,11 +176,6 @@ public boolean isLeaf() { return true; } - @Override - public float eval(float[] scores) { - return output; - } - /** * Return the memory usage of this object in bytes. Negative values are illegal. */ diff --git a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParser.java b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParser.java index 9fde11f2..0f86d9ca 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParser.java +++ b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParser.java @@ -57,7 +57,9 @@ public NaiveAdditiveDecisionTree parse(FeatureSet set, String model) { float[] weights = new float[trees.length]; // Tree weights are already encoded in outputs Arrays.fill(weights, 1F); - return new NaiveAdditiveDecisionTree(trees, weights, set.size(), modelDefinition.normalizer); + return new NaiveAdditiveDecisionTree(trees, weights, set.size(), + modelDefinition.normalizer, + modelDefinition.useFloatMaxForMissing ? Float.MAX_VALUE : null); } private static class XGBoostDefinition { @@ -256,11 +258,11 @@ Node toNode(FeatureSet set, XGBoostDefinition xgb) { if (isSplit()) { Node left = children.get(0).toNode(set, xgb); Node right = children.get(1).toNode(set, xgb); + Node onMissing = this.missingNodeId.equals(this.rightNodeId) ? right : left; if (xgb.useFloatMaxForMissing) { - Node onMissing = this.missingNodeId.equals(this.rightNodeId) ? 
right : left; - return new NaiveAdditiveDecisionTree.SplitWithMissing(left, right, onMissing, set.featureOrdinal(split), threshold); + return new NaiveAdditiveDecisionTree.Split(left, right, onMissing, set.featureOrdinal(split), threshold); } else { - return new NaiveAdditiveDecisionTree.Split(left, right, set.featureOrdinal(split), threshold); + return new NaiveAdditiveDecisionTree.Split(left, right, null, set.featureOrdinal(split), threshold); } } else { return new NaiveAdditiveDecisionTree.Leaf(leaf); diff --git a/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java b/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java index a5e96b73..bc1d8ce3 100644 --- a/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java +++ b/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java @@ -56,12 +56,12 @@ public class NaiveAdditiveDecisionTreeTests extends LuceneTestCase { static final Logger LOG = LogManager.getLogger(NaiveAdditiveDecisionTreeTests.class); public void testName() { NaiveAdditiveDecisionTree dectree = new NaiveAdditiveDecisionTree(new NaiveAdditiveDecisionTree.Node[0], - new float[0], 0, Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME)); + new float[0], 0, Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME), null); assertEquals("naive_additive_decision_tree", dectree.name()); } public void testScore() throws IOException { - NaiveAdditiveDecisionTree ranker = parseTreeModel("simple_tree.txt", Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME)); + NaiveAdditiveDecisionTree ranker = parseTreeModel("simple_tree.txt", Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME), null); LtrRanker.FeatureVector vector = ranker.newFeatureVector(null); vector.setFeatureScore(0, 1); vector.setFeatureScore(1, 2); @@ -72,7 +72,7 @@ public void testScore() throws IOException { } public void testSigmoidScore() throws IOException { - NaiveAdditiveDecisionTree ranker = parseTreeModel("simple_tree.txt", Normalizers.get(Normalizers.SIGMOID_NORMALIZER_NAME)); + NaiveAdditiveDecisionTree ranker = parseTreeModel("simple_tree.txt", Normalizers.get(Normalizers.SIGMOID_NORMALIZER_NAME), null); LtrRanker.FeatureVector vector = ranker.newFeatureVector(null); vector.setFeatureScore(0, 1); vector.setFeatureScore(1, 2); @@ -84,26 +84,24 @@ public void testSigmoidScore() throws IOException { } public void testScoreSparseFeatureSet() throws IOException { - NaiveAdditiveDecisionTree ranker = parseTreeModel("simple_tree.txt", Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME)); + NaiveAdditiveDecisionTree ranker = parseTreeModel("tree_with_missing_branches.txt", Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME), Float.MAX_VALUE); LtrRanker.FeatureVector vector = ranker.newFeatureVector(null); vector.setFeatureScore(0, 1); - vector.setFeatureScore(1, 3); + vector.setFeatureScore(1, 2); vector.setFeatureScore(2, Float.MAX_VALUE); - // simple_tree model does not specify `missing`. We should take the - // left branch in that case. 
- float expected = 17F*3.4F + 3.2F*2.8F; + float expected = 1.2F*3.4F + 10F*2.8F; assertEquals(expected, ranker.score(vector), Math.ulp(expected)); } public void testScoreSparseFeatureSetWithMissingField() throws IOException { - NaiveAdditiveDecisionTree ranker = parseTreeModel("simple_tree.txt", Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME)); + NaiveAdditiveDecisionTree ranker = parseTreeModel("tree_with_missing_branches.txt", Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME), Float.MAX_VALUE); LtrRanker.FeatureVector vector = ranker.newFeatureVector(null); vector.setFeatureScore(0, Float.MAX_VALUE); vector.setFeatureScore(1, 2); vector.setFeatureScore(2, 3); - float expected = 3.2F*3.4F + 3.2F*2.8F; + float expected = 1.2F*3.4F + 23F*2.8F; assertEquals(expected, ranker.score(vector), Math.ulp(expected)); } @@ -164,16 +162,16 @@ public static NaiveAdditiveDecisionTree generateRandomDecTree(int nbFeatures, in } trees[i] = new RandomTreeGenerator(nbFeatures, minDepth, maxDepth, collector).genTree(); } - return new NaiveAdditiveDecisionTree(trees, weights, nbFeatures, Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME)); + return new NaiveAdditiveDecisionTree(trees, weights, nbFeatures, Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME), null); } public void testSize() { NaiveAdditiveDecisionTree ranker = new NaiveAdditiveDecisionTree(new NaiveAdditiveDecisionTree.Node[0], - new float[0], 3, Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME)); + new float[0], 3, Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME), null); assertEquals(ranker.size(), 3); } - private NaiveAdditiveDecisionTree parseTreeModel(String textRes, Normalizer normalizer) throws IOException { + private NaiveAdditiveDecisionTree parseTreeModel(String textRes, Normalizer normalizer, Float missingValue) throws IOException { List features = new ArrayList<>(3); features.add(new PrebuiltFeature("feature1", new MatchAllDocsQuery())); features.add(new PrebuiltFeature("feature2", new MatchAllDocsQuery())); @@ -188,7 +186,7 @@ private NaiveAdditiveDecisionTree parseTreeModel(String textRes, Normalizer norm weights[i] = treesAndWeight.get(i).weight; trees[i] = treesAndWeight.get(i).tree; } - return new NaiveAdditiveDecisionTree(trees, weights, set.size(), normalizer); + return new NaiveAdditiveDecisionTree(trees, weights, set.size(), normalizer, missingValue); } private static class TreeTextParser { @@ -241,9 +239,11 @@ NaiveAdditiveDecisionTree.Node parseTree() { float threshold = extractLastFloat(line); NaiveAdditiveDecisionTree.Node right = parseTree(); NaiveAdditiveDecisionTree.Node left = parseTree(); - NaiveAdditiveDecisionTree.Node onMissing = (splitString.length > 3) ? - (Boolean.parseBoolean(splitString[2]) ? left : right) : null; - return new NaiveAdditiveDecisionTree.SplitWithMissing(left, right, onMissing, ord, threshold); + NaiveAdditiveDecisionTree.Node onMissing = null; + if (splitString.length > 3) { + onMissing = Boolean.parseBoolean(splitString[2]) ? 
left : right; + } + return new NaiveAdditiveDecisionTree.Split(left, right, onMissing, ord, threshold); } else { throw new IllegalArgumentException("Invalid tree"); @@ -309,7 +309,7 @@ private NaiveAdditiveDecisionTree.Node newSplit(int depth) { int feature = featureGen.get(); float thresh = thresholdGenerator.apply(feature); statsCollector.newSplit(depth, feature, thresh); - return new NaiveAdditiveDecisionTree.Split(newNode(depth), newNode(depth), feature, thresh); + return new NaiveAdditiveDecisionTree.Split(newNode(depth), newNode(depth), null, feature, thresh); } private NaiveAdditiveDecisionTree.Node newLeaf(int depth) { diff --git a/src/test/resources/com/o19s/es/ltr/ranker/dectree/simple_tree.txt b/src/test/resources/com/o19s/es/ltr/ranker/dectree/simple_tree.txt index ad9def02..109b1cb3 100644 --- a/src/test/resources/com/o19s/es/ltr/ranker/dectree/simple_tree.txt +++ b/src/test/resources/com/o19s/es/ltr/ranker/dectree/simple_tree.txt @@ -1,7 +1,7 @@ # first line after split is right # data point: feature1:1, feature2:2, feature3:3 - tree:3.4 - - split:feature1:false:2.3 + - split:feature1:2.3 - output:3.2 # right wins - split:feature2:2.2 @@ -11,7 +11,7 @@ # left wins => output 1.2*3.4 - output:1.2 - tree:2.8 - - split:feature1:false:0.1 + - split:feature1:0.1 # right wins - split:feature2:1.8 # right wins diff --git a/src/test/resources/com/o19s/es/ltr/ranker/dectree/tree_with_missing_branches.txt b/src/test/resources/com/o19s/es/ltr/ranker/dectree/tree_with_missing_branches.txt new file mode 100644 index 00000000..a383ed0a --- /dev/null +++ b/src/test/resources/com/o19s/es/ltr/ranker/dectree/tree_with_missing_branches.txt @@ -0,0 +1,18 @@ +# first line after split is right +# third field in split -> take left branch on missing +- tree:3.4 + - split:feature1:true:2.3 + - output:3.2 + - split:feature2:false:2.2 + - split:feature3:true:3.2 + - output:11 + - output:17 + - output:1.2 +- tree:2.8 + - split:feature1:true:0.1 + - split:feature2:true:1.8 + - split:feature3:false:3.2 + - output:10 + - output:3.2 + - output:15 + - output:23 From ffd3c212e835d7bbfe9d91e06d9f6d6449b382a5 Mon Sep 17 00:00:00 2001 From: Akshay Kumar Date: Thu, 5 Dec 2019 14:03:33 -0800 Subject: [PATCH 3/4] Make checkstyleTest pass --- .../ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java b/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java index bc1d8ce3..88dd299e 100644 --- a/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java +++ b/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java @@ -84,7 +84,8 @@ public void testSigmoidScore() throws IOException { } public void testScoreSparseFeatureSet() throws IOException { - NaiveAdditiveDecisionTree ranker = parseTreeModel("tree_with_missing_branches.txt", Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME), Float.MAX_VALUE); + NaiveAdditiveDecisionTree ranker = parseTreeModel("tree_with_missing_branches.txt", + Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME), Float.MAX_VALUE); LtrRanker.FeatureVector vector = ranker.newFeatureVector(null); vector.setFeatureScore(0, 1); vector.setFeatureScore(1, 2); @@ -95,7 +96,8 @@ public void testScoreSparseFeatureSet() throws IOException { } public void testScoreSparseFeatureSetWithMissingField() throws IOException { - NaiveAdditiveDecisionTree ranker = 
parseTreeModel("tree_with_missing_branches.txt", Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME), Float.MAX_VALUE); + NaiveAdditiveDecisionTree ranker = parseTreeModel("tree_with_missing_branches.txt", + Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME), Float.MAX_VALUE); LtrRanker.FeatureVector vector = ranker.newFeatureVector(null); vector.setFeatureScore(0, Float.MAX_VALUE); vector.setFeatureScore(1, 2); From 80abe566fb053a2090dd532da1a9531ecb861f95 Mon Sep 17 00:00:00 2001 From: Akshay Kumar Date: Thu, 5 Dec 2019 17:06:27 -0800 Subject: [PATCH 4/4] Allow users to specify any floating point value as missing value --- .../ltr/ranker/parser/XGBoostJsonParser.java | 25 +++++++++++--- .../ranker/parser/XGBoostJsonParserTests.java | 33 +++++++++++++++++-- 2 files changed, 51 insertions(+), 7 deletions(-) diff --git a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParser.java b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParser.java index 0f86d9ca..d9757127 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParser.java +++ b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParser.java @@ -59,7 +59,7 @@ public NaiveAdditiveDecisionTree parse(FeatureSet set, String model) { Arrays.fill(weights, 1F); return new NaiveAdditiveDecisionTree(trees, weights, set.size(), modelDefinition.normalizer, - modelDefinition.useFloatMaxForMissing ? Float.MAX_VALUE : null); + modelDefinition.missingValue); } private static class XGBoostDefinition { @@ -67,12 +67,24 @@ private static class XGBoostDefinition { static { PARSER = new ObjectParser<>("xgboost_definition", XGBoostDefinition::new); PARSER.declareString(XGBoostDefinition::setNormalizer, new ParseField("objective")); + + // Parameters to represent missing features in feature vector + // + // There are convenience options like `use_float_max_for_missing` + // which can be used instead of specifying the actual floating + // point value. + // + // Only one of the following parameter should be specified + // otherwise the behavior is undefined (depends on the order of + // names in JSON). + PARSER.declareFloat(XGBoostDefinition::setMissingValue, new ParseField("missing_value")); PARSER.declareBoolean(XGBoostDefinition::setFloatMaxForMissing, new ParseField("use_float_max_for_missing")); + PARSER.declareObjectArray(XGBoostDefinition::setSplitParserStates, SplitParserState::parse, new ParseField("splits")); } private Normalizer normalizer; - private boolean useFloatMaxForMissing; + private Float missingValue; private List splitParserStates; public static XGBoostDefinition parse(XContentParser parser, FeatureSet set) throws IOException { @@ -109,7 +121,6 @@ public static XGBoostDefinition parse(XContentParser parser, FeatureSet set) thr XGBoostDefinition() { normalizer = Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME); - useFloatMaxForMissing = false; } /** @@ -138,7 +149,11 @@ void setNormalizer(String objectiveName) { } void setFloatMaxForMissing(boolean useFloatMaxForMissing) { - this.useFloatMaxForMissing = useFloatMaxForMissing; + this.missingValue = Float.MAX_VALUE; + } + + void setMissingValue(float missingValue) { + this.missingValue = missingValue; } void setSplitParserStates(List splitParserStates) { @@ -259,7 +274,7 @@ Node toNode(FeatureSet set, XGBoostDefinition xgb) { Node left = children.get(0).toNode(set, xgb); Node right = children.get(1).toNode(set, xgb); Node onMissing = this.missingNodeId.equals(this.rightNodeId) ? 
right : left; - if (xgb.useFloatMaxForMissing) { + if (xgb.missingValue != null) { return new NaiveAdditiveDecisionTree.Split(left, right, onMissing, set.featureOrdinal(split), threshold); } else { return new NaiveAdditiveDecisionTree.Split(left, right, null, set.featureOrdinal(split), threshold); diff --git a/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserTests.java b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserTests.java index 797b0b55..57ce8fe8 100644 --- a/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserTests.java +++ b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserTests.java @@ -130,7 +130,7 @@ public void testReadSimpleSplitWithObjective() throws IOException { public void testReadSimpleSplitWithSupportForMissing() throws IOException { String model = "{" + - "\"use_float_max_for_missing\": true," + + "\"missing_value\": -999," + "\"splits\": [{" + " \"nodeid\": 0," + " \"split\":\"feat1\"," + @@ -153,8 +153,37 @@ public void testReadSimpleSplitWithSupportForMissing() throws IOException { assertEquals(0.5F, tree.score(v), Math.ulp(0.5F)); v.setFeatureScore(0, 0.123F); assertEquals(0.2F, tree.score(v), Math.ulp(0.2F)); - v.setFeatureScore(0, Float.MAX_VALUE); + v.setFeatureScore(0, -999); + assertEquals(0.2F, tree.score(v), Math.ulp(0.2F)); + } + + public void testReadSimpleSplitWithFloatMaxForMissing() throws IOException { + String model = "{" + + "\"use_float_max_for_missing\": true," + + "\"splits\": [{" + + " \"nodeid\": 0," + + " \"split\":\"feat1\"," + + " \"depth\":0," + + " \"split_condition\":0.123," + + " \"yes\":1," + + " \"no\": 2," + + " \"missing\":1,"+ + " \"children\": [" + + " {\"nodeid\": 1, \"depth\": 1, \"leaf\": 0.5}," + + " {\"nodeid\": 2, \"depth\": 1, \"leaf\": 0.2}" + + "]}]}"; + + FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1"))); + NaiveAdditiveDecisionTree tree = parser.parse(set, model); + FeatureVector v = tree.newFeatureVector(null); + v.setFeatureScore(0, 0.124F); assertEquals(0.2F, tree.score(v), Math.ulp(0.2F)); + v.setFeatureScore(0, 0.122F); + assertEquals(0.5F, tree.score(v), Math.ulp(0.5F)); + v.setFeatureScore(0, 0.123F); + assertEquals(0.2F, tree.score(v), Math.ulp(0.2F)); + v.setFeatureScore(0, Float.MAX_VALUE); + assertEquals(0.5F, tree.score(v), Math.ulp(0.2F)); } public void testReadSplitWithUnknownParams() throws IOException {
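
Note: a minimal end-to-end sketch of the API this series introduces (not part of the patches above; the import paths and the public score(FeatureVector) entry point are assumed from the file layout and the unit tests in PATCH 2/4 and 4/4). After the series, a Split carries an onMissing child, and the ranker routes any feature whose score equals the configured sentinel down that branch instead of comparing it against the threshold:

    import com.o19s.es.ltr.ranker.LtrRanker;
    import com.o19s.es.ltr.ranker.dectree.NaiveAdditiveDecisionTree;
    import com.o19s.es.ltr.ranker.normalizer.Normalizers;

    public class MissingValueSketch {
        public static void main(String[] args) {
            // Single-split tree: feat0 < 0.123 -> leaf 0.5, otherwise -> leaf 0.2;
            // a missing feat0 is routed to the right child (leaf 0.2).
            NaiveAdditiveDecisionTree.Node left = new NaiveAdditiveDecisionTree.Leaf(0.5F);
            NaiveAdditiveDecisionTree.Node right = new NaiveAdditiveDecisionTree.Leaf(0.2F);
            NaiveAdditiveDecisionTree.Node root =
                    new NaiveAdditiveDecisionTree.Split(left, right, /* onMissing */ right, /* feature */ 0, /* threshold */ 0.123F);

            // -999F is the sentinel meaning "feature absent", mirroring the
            // "missing_value": -999 option parsed by XGBoostJsonParser in PATCH 4/4.
            NaiveAdditiveDecisionTree ranker = new NaiveAdditiveDecisionTree(
                    new NaiveAdditiveDecisionTree.Node[]{root}, new float[]{1F}, /* modelSize */ 1,
                    Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME), /* missingValue */ -999F);

            LtrRanker.FeatureVector v = ranker.newFeatureVector(null);
            v.setFeatureScore(0, -999F);          // feature 0 marked missing
            System.out.println(ranker.score(v));  // 0.2, via the onMissing branch
        }
    }

When the sentinel is null (the pre-existing behaviour, and the default when neither "missing_value" nor "use_float_max_for_missing" appears in the model definition), the ranker uses the plain eval loop, so models that never mark features as missing pay no per-split cost for the new check.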