Sparse feature support for xgboost models #248

Closed
@@ -88,12 +88,14 @@ public static class Split implements Node {
private static final long BASE_RAM_USED = RamUsageEstimator.shallowSizeOfInstance(Split.class);
private final Node left;
private final Node right;
private final Node onMissing;
private final int feature;
private final float threshold;

public Split(Node left, Node right, int feature, float threshold) {
public Split(Node left, Node right, Node onMissing, int feature, float threshold) {
this.left = Objects.requireNonNull(left);
this.right = Objects.requireNonNull(right);
this.onMissing = onMissing == null ? this.left : onMissing;
this.feature = feature;
this.threshold = threshold;
}
@@ -109,7 +111,9 @@ public float eval(float[] scores) {
while (!n.isLeaf()) {
assert n instanceof Split;
Split s = (Split) n;
if (s.threshold > scores[s.feature]) {
if (scores[s.feature] == Float.MAX_VALUE) {
n = s.onMissing;
} else if (s.threshold > scores[s.feature]) {
n = s.left;
} else {
n = s.right;
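For reference, a minimal sketch of the new missing-value semantics (not part of this diff, and assuming Split and Leaf are usable directly as their public signatures above suggest): Float.MAX_VALUE is the sentinel for an unset sparse feature, and a null onMissing keeps the old behaviour of sending such values down the left branch.

NaiveAdditiveDecisionTree.Node left = new NaiveAdditiveDecisionTree.Leaf(0.2F);
NaiveAdditiveDecisionTree.Node right = new NaiveAdditiveDecisionTree.Leaf(0.5F);
// a null onMissing falls back to the left child, preserving pre-patch behaviour
NaiveAdditiveDecisionTree.Split split =
        new NaiveAdditiveDecisionTree.Split(left, right, null, 0, 1.0F);
float present = split.eval(new float[]{0.4F});            // 0.4F < threshold 1.0F, so left: 0.2F
float missing = split.eval(new float[]{Float.MAX_VALUE}); // sentinel routes to onMissing (left): 0.2F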
@@ -82,7 +82,6 @@ private static class SplitParserState {
private Float threshold;
private Integer rightNodeId;
private Integer leftNodeId;
// Ignored
private Integer missingNodeId;
private Float leaf;
private List<SplitParserState> children;
@@ -161,8 +160,10 @@ boolean isSplit() {

Node toNode(FeatureSet set) {
if (isSplit()) {
return new NaiveAdditiveDecisionTree.Split(children.get(0).toNode(set), children.get(1).toNode(set),
set.featureOrdinal(split), threshold);
Node left = children.get(0).toNode(set);
Node right = children.get(1).toNode(set);
Node onMissing = this.missingNodeId.equals(this.rightNodeId) ? right : left;
return new NaiveAdditiveDecisionTree.Split(left, right, onMissing, set.featureOrdinal(split), threshold);
} else {
return new NaiveAdditiveDecisionTree.Leaf(leaf);
}
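For orientation (not part of this diff): the missingNodeId consumed above comes from the `missing` field of XGBoost's JSON dump (dump_format="json"); the node below is illustrative rather than taken from this repository. A split routing missing values to its `no`/right child might look like:

{"nodeid": 0, "split": "feature1", "split_condition": 2.3,
 "yes": 1, "no": 2, "missing": 2,
 "children": [{"nodeid": 1, "leaf": 0.5},
              {"nodeid": 2, "leaf": -0.4}]}

toNode() compares missingNodeId with rightNodeId, so "missing": 2 here selects the right child as onMissing; any other value (including the `yes` id) falls back to the left child.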
@@ -69,6 +69,30 @@ public void testScore() throws IOException {
assertEquals(expected, ranker.score(vector), Math.ulp(expected));
}

public void testScoreSparseFeatureSet() throws IOException {
NaiveAdditiveDecisionTree ranker = parseTreeModel("simple_tree.txt");
LtrRanker.FeatureVector vector = ranker.newFeatureVector(null);
vector.setFeatureScore(0, 1);
vector.setFeatureScore(1, 3);
vector.setFeatureScore(2, Float.MAX_VALUE);

// simple_tree model does not specify `missing`. We should take the
// left branch in that case.
float expected = 17F*3.4F + 3.2F*2.8F;
assertEquals(expected, ranker.score(vector), Math.ulp(expected));
}

public void testScoreSparseFeatureSetWithMissingField() throws IOException {
NaiveAdditiveDecisionTree ranker = parseTreeModel("simple_tree.txt");
LtrRanker.FeatureVector vector = ranker.newFeatureVector(null);
vector.setFeatureScore(0, Float.MAX_VALUE);
vector.setFeatureScore(1, 2);
vector.setFeatureScore(2, 3);

float expected = 3.2F*3.4F + 3.2F*2.8F;
assertEquals(expected, ranker.score(vector), Math.ulp(expected));
}

public void testPerfAndRobustness() {
SimpleCountRandomTreeGeneratorStatsCollector counts = new SimpleCountRandomTreeGeneratorStatsCollector();
NaiveAdditiveDecisionTree ranker = generateRandomDecTree(100, 1000,
@@ -193,17 +217,18 @@ NaiveAdditiveDecisionTree.Node parseTree() {
if (line.contains("- output")) {
return new NaiveAdditiveDecisionTree.Leaf(extractLastFloat(line));
} else if(line.contains("- split")) {
String featName = line.split(":")[1];
String[] splitString = line.split(":");
String featName = splitString[1];
int ord = set.featureOrdinal(featName);
if (ord < 0 || ord > set.size()) {
throw new IllegalArgumentException("Unknown feature " + featName);
}
float threshold = extractLastFloat(line);
NaiveAdditiveDecisionTree.Node right = parseTree();
NaiveAdditiveDecisionTree.Node left = parseTree();

return new NaiveAdditiveDecisionTree.Split(left, right,
ord, threshold);
NaiveAdditiveDecisionTree.Node onMissing = (splitString.length > 3) ?
(Boolean.parseBoolean(splitString[2]) ? left : right) : null;
return new NaiveAdditiveDecisionTree.Split(left, right, onMissing, ord, threshold);
} else {
throw new IllegalArgumentException("Invalid tree");
}
@@ -268,7 +293,7 @@ private NaiveAdditiveDecisionTree.Node newSplit(int depth) {
int feature = featureGen.get();
float thresh = thresholdGenerator.apply(feature);
statsCollector.newSplit(depth, feature, thresh);
return new NaiveAdditiveDecisionTree.Split(newNode(depth), newNode(depth), feature, thresh);
return new NaiveAdditiveDecisionTree.Split(newNode(depth), newNode(depth), null, feature, thresh);
}

private NaiveAdditiveDecisionTree.Node newLeaf(int depth) {
@@ -73,6 +73,8 @@ public void testReadSimpleSplit() throws IOException {
assertEquals(0.5F, tree.score(v), Math.ulp(0.5F));
v.setFeatureScore(0, 0.123F);
assertEquals(0.2F, tree.score(v), Math.ulp(0.2F));
v.setFeatureScore(0, Float.MAX_VALUE);
assertEquals(0.2F, tree.score(v), Math.ulp(0.2F));
}

public void testMissingField() throws IOException {
@@ -1,7 +1,7 @@
# first line after split is right
# data point: feature1:1, feature2:2, feature3:3
- tree:3.4
- split:feature1:2.3
- split:feature1:false:2.3
- output:3.2
# right wins
- split:feature2:2.2
@@ -11,7 +11,7 @@
# left wins => output 1.2*3.4
- output:1.2
- tree:2.8
- split:feature1:0.1
- split:feature1:false:0.1
# right wins
- split:feature2:1.8
# right wins
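A note on the test-tree syntax above, as the parseTree() helper in the tests now reads it (the angle-bracket placeholders below are illustrative): split lines may carry an optional boolean token between the feature name and the threshold telling the parser where missing values go.

- split:<feature>:<true|false>:<threshold>   # true: missing follows the left branch, false: the right
- split:<feature>:<threshold>                # no token: onMissing stays null and defaults to left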