Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new xgboost parameter to represent missing features in feature vector #252

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public class NaiveAdditiveDecisionTree extends DenseLtrRanker implements Account
private final float[] weights;
private final int modelSize;
private final Normalizer normalizer;
private final Float missingValue;

/**
* TODO: Constructor for these classes are strict and not really
Expand All @@ -45,13 +46,15 @@ public class NaiveAdditiveDecisionTree extends DenseLtrRanker implements Account
* @param weights the respective weights
* @param modelSize the modelSize in number of feature used
* @param normalizer class to perform any normalization on model score
* @param missingValue sentinel value to indicate feature is missing in the feature vector (optional)
*/
public NaiveAdditiveDecisionTree(Node[] trees, float[] weights, int modelSize, Normalizer normalizer) {
// missingValue is the sentinel marking features absent from the feature vector;
// a null missingValue disables missing-value handling entirely (score() then
// never routes through a split's onMissing branch).
public NaiveAdditiveDecisionTree(Node[] trees, float[] weights, int modelSize, Normalizer normalizer, Float missingValue) {
assert trees.length == weights.length;
this.trees = trees;
this.weights = weights;
this.modelSize = modelSize;
this.normalizer = normalizer;
this.missingValue = missingValue;
}

@Override
Expand All @@ -63,12 +66,50 @@ public String name() {
protected float score(DenseFeatureVector vector) {
float sum = 0;
float[] scores = vector.scores;
for (int i = 0; i < trees.length; i++) {
sum += weights[i]*trees[i].eval(scores);
if (this.missingValue != null) {
for (int i = 0; i < trees.length; i++) {
sum += weights[i] * evalWithMissing(trees[i], scores);
}
} else {
for (int i = 0; i < trees.length; i++) {
sum += weights[i] * eval(trees[i], scores);
}
}
return normalizer.normalize(sum);
}

/**
 * Walks a single regression tree down to a leaf and returns its output.
 * Splits route left when the feature score is strictly below the threshold.
 */
private float eval(Node rootNode, float[] scores) {
    Node current = rootNode;
    for (;;) {
        if (current.isLeaf()) {
            assert current instanceof Leaf;
            return ((Leaf) current).getOutput();
        }
        assert current instanceof Split;
        final Split split = (Split) current;
        // scores[f] < threshold is equivalent to the original threshold > scores[f]
        current = scores[split.feature] < split.threshold ? split.left : split.right;
    }
}

/**
 * Walks a single regression tree down to a leaf, routing through the split's
 * dedicated onMissing branch whenever the feature score equals the configured
 * sentinel value.
 *
 * @param rootNode root of the tree to evaluate
 * @param scores   dense per-feature scores indexed by feature ordinal
 * @return the output of the leaf reached
 */
private float evalWithMissing(Node rootNode, float[] scores) {
    // Hoist the unboxing out of the traversal loop: this.missingValue is a
    // Float and would otherwise be unboxed at every visited node. score()
    // only calls this method when missingValue is non-null.
    // NOTE(review): if the sentinel were NaN, the == test below would never
    // match (NaN != NaN) — confirm callers never configure NaN as sentinel.
    final float missing = this.missingValue;
    Node n = rootNode;
    while (!n.isLeaf()) {
        assert n instanceof Split;
        Split s = (Split) n;
        float score = scores[s.feature];
        if (score == missing) {
            n = s.onMissing;
        } else if (s.threshold > score) {
            n = s.left;
        } else {
            n = s.right;
        }
    }
    assert n instanceof Leaf;
    return ((Leaf) n).getOutput();
}

@Override
protected int size() {
return modelSize;
Expand All @@ -85,44 +126,29 @@ public long ramBytesUsed() {

public interface Node extends Accountable {
boolean isLeaf();
float eval(float[] scores);
}

public static class Split implements Node {
private static final long BASE_RAM_USED = RamUsageEstimator.shallowSizeOfInstance(Split.class);
private final Node left;
private final Node right;
private final int feature;
private final float threshold;
final Node left;
final Node right;
final int feature;
final float threshold;
final Node onMissing;

public Split(Node left, Node right, int feature, float threshold) {
public Split(Node left, Node right, Node onMissing, int feature, float threshold) {
this.left = Objects.requireNonNull(left);
this.right = Objects.requireNonNull(right);
this.feature = feature;
this.threshold = threshold;
this.onMissing = onMissing;
}

@Override
public boolean isLeaf() {
return false;
}

@Override
public float eval(float[] scores) {
Node n = this;
while (!n.isLeaf()) {
assert n instanceof Split;
Split s = (Split) n;
if (s.threshold > scores[s.feature]) {
n = s.left;
} else {
n = s.right;
}
}
assert n instanceof Leaf;
return n.eval(scores);
}

/**
* Return the memory usage of this object in bytes. Negative values are illegal.
*/
Expand All @@ -135,6 +161,10 @@ public long ramBytesUsed() {
public static class Leaf implements Node {
private static final long BASE_RAM_USED = RamUsageEstimator.shallowSizeOfInstance(Split.class);

float getOutput() {
return output;
}

private final float output;

public Leaf(float output) {
Expand All @@ -146,11 +176,6 @@ public boolean isLeaf() {
return true;
}

@Override
public float eval(float[] scores) {
return output;
}

/**
* Return the memory usage of this object in bytes. Negative values are illegal.
*/
Expand Down
41 changes: 35 additions & 6 deletions src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParser.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,34 @@ public NaiveAdditiveDecisionTree parse(FeatureSet set, String model) {
float[] weights = new float[trees.length];
// Tree weights are already encoded in outputs
Arrays.fill(weights, 1F);
return new NaiveAdditiveDecisionTree(trees, weights, set.size(), modelDefinition.normalizer);
return new NaiveAdditiveDecisionTree(trees, weights, set.size(),
modelDefinition.normalizer,
modelDefinition.missingValue);
}

private static class XGBoostDefinition {
private static final ObjectParser<XGBoostDefinition, FeatureSet> PARSER;
static {
PARSER = new ObjectParser<>("xgboost_definition", XGBoostDefinition::new);
PARSER.declareString(XGBoostDefinition::setNormalizer, new ParseField("objective"));

// Parameters to represent missing features in feature vector
//
// There are convenience options like `use_float_max_for_missing`
// which can be used instead of specifying the actual floating
// point value.
//
// Only one of the following parameters should be specified
// otherwise the behavior is undefined (depends on the order of
// names in JSON).
PARSER.declareFloat(XGBoostDefinition::setMissingValue, new ParseField("missing_value"));
Copy link
Collaborator

@nomoa nomoa Dec 13, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of adding a new param it might be more coherent to be able to parse a float and string
this way one could pass:
"missing_value": "max"
or
"missing_value": 0.0

everything else would throw an exception.

See https://github.com/elastic/elasticsearch/blob/237650e9c054149fd08213b38a81a3666c1868e5/server/src/main/java/org/elasticsearch/search/suggest/completion/FuzzyOptions.java#L66 to declare a field that accepts any kind of value
and https://github.com/elastic/elasticsearch/blob/f92ebb2ff909d0083ae988e04ecd398d979e9210/server/src/main/java/org/elasticsearch/common/unit/Fuzziness.java#L160 for how to parse the value

PARSER.declareBoolean(XGBoostDefinition::setFloatMaxForMissing, new ParseField("use_float_max_for_missing"));

PARSER.declareObjectArray(XGBoostDefinition::setSplitParserStates, SplitParserState::parse, new ParseField("splits"));
}

private Normalizer normalizer;
private Float missingValue;
private List<SplitParserState> splitParserStates;

public static XGBoostDefinition parse(XContentParser parser, FeatureSet set) throws IOException {
Expand Down Expand Up @@ -132,6 +148,14 @@ void setNormalizer(String objectiveName) {
}
}

void setFloatMaxForMissing(boolean useFloatMaxForMissing) {
this.missingValue = Float.MAX_VALUE;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

useFloatMaxForMissing is now ignored

}

void setMissingValue(float missingValue) {
this.missingValue = missingValue;
}

void setSplitParserStates(List<SplitParserState> splitParserStates) {
this.splitParserStates = splitParserStates;
}
Expand All @@ -140,7 +164,7 @@ Node[] getTrees(FeatureSet set) {
Node[] trees = new Node[splitParserStates.size()];
ListIterator<SplitParserState> it = splitParserStates.listIterator();
while(it.hasNext()) {
trees[it.nextIndex()] = it.next().toNode(set);
trees[it.nextIndex()] = it.next().toNode(set, this);
}
return trees;
}
Expand Down Expand Up @@ -169,7 +193,6 @@ private static class SplitParserState {
private Float threshold;
private Integer rightNodeId;
private Integer leftNodeId;
// Ignored
private Integer missingNodeId;
private Float leaf;
private List<SplitParserState> children;
Expand Down Expand Up @@ -246,10 +269,16 @@ boolean isSplit() {
}


Node toNode(FeatureSet set) {
Node toNode(FeatureSet set, XGBoostDefinition xgb) {
if (isSplit()) {
return new NaiveAdditiveDecisionTree.Split(children.get(0).toNode(set), children.get(1).toNode(set),
set.featureOrdinal(split), threshold);
Node left = children.get(0).toNode(set, xgb);
Node right = children.get(1).toNode(set, xgb);
Node onMissing = this.missingNodeId.equals(this.rightNodeId) ? right : left;
if (xgb.missingValue != null) {
return new NaiveAdditiveDecisionTree.Split(left, right, onMissing, set.featureOrdinal(split), threshold);
} else {
return new NaiveAdditiveDecisionTree.Split(left, right, null, set.featureOrdinal(split), threshold);
}
} else {
return new NaiveAdditiveDecisionTree.Leaf(leaf);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ public class NaiveAdditiveDecisionTreeTests extends LuceneTestCase {
static final Logger LOG = LogManager.getLogger(NaiveAdditiveDecisionTreeTests.class);
public void testName() {
NaiveAdditiveDecisionTree dectree = new NaiveAdditiveDecisionTree(new NaiveAdditiveDecisionTree.Node[0],
new float[0], 0, Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME));
new float[0], 0, Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME), null);
assertEquals("naive_additive_decision_tree", dectree.name());
}

public void testScore() throws IOException {
NaiveAdditiveDecisionTree ranker = parseTreeModel("simple_tree.txt", Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME));
NaiveAdditiveDecisionTree ranker = parseTreeModel("simple_tree.txt", Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME), null);
LtrRanker.FeatureVector vector = ranker.newFeatureVector(null);
vector.setFeatureScore(0, 1);
vector.setFeatureScore(1, 2);
Expand All @@ -72,7 +72,7 @@ public void testScore() throws IOException {
}

public void testSigmoidScore() throws IOException {
NaiveAdditiveDecisionTree ranker = parseTreeModel("simple_tree.txt", Normalizers.get(Normalizers.SIGMOID_NORMALIZER_NAME));
NaiveAdditiveDecisionTree ranker = parseTreeModel("simple_tree.txt", Normalizers.get(Normalizers.SIGMOID_NORMALIZER_NAME), null);
LtrRanker.FeatureVector vector = ranker.newFeatureVector(null);
vector.setFeatureScore(0, 1);
vector.setFeatureScore(1, 2);
Expand All @@ -83,6 +83,31 @@ public void testSigmoidScore() throws IOException {
assertEquals(expected, ranker.score(vector), Math.ulp(expected));
}

// Feature 2 is scored with the sentinel (Float.MAX_VALUE), so evaluation
// must follow the tree's dedicated "missing" branch for that feature.
public void testScoreSparseFeatureSet() throws IOException {
NaiveAdditiveDecisionTree ranker = parseTreeModel("tree_with_missing_branches.txt",
Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME), Float.MAX_VALUE);
LtrRanker.FeatureVector vector = ranker.newFeatureVector(null);
vector.setFeatureScore(0, 1);
vector.setFeatureScore(1, 2);
vector.setFeatureScore(2, Float.MAX_VALUE);

// Expected: weighted leaf outputs reached via the missing branch.
float expected = 1.2F*3.4F + 10F*2.8F;
assertEquals(expected, ranker.score(vector), Math.ulp(expected));
}

// Here feature 0 carries the sentinel instead, exercising the missing
// branch at a different split of the same model.
public void testScoreSparseFeatureSetWithMissingField() throws IOException {
NaiveAdditiveDecisionTree ranker = parseTreeModel("tree_with_missing_branches.txt",
Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME), Float.MAX_VALUE);
LtrRanker.FeatureVector vector = ranker.newFeatureVector(null);
vector.setFeatureScore(0, Float.MAX_VALUE);
vector.setFeatureScore(1, 2);
vector.setFeatureScore(2, 3);

// Expected: weighted leaf outputs reached via the missing branch.
float expected = 1.2F*3.4F + 23F*2.8F;
assertEquals(expected, ranker.score(vector), Math.ulp(expected));
}


public void testPerfAndRobustness() {
SimpleCountRandomTreeGeneratorStatsCollector counts = new SimpleCountRandomTreeGeneratorStatsCollector();
NaiveAdditiveDecisionTree ranker = generateRandomDecTree(100, 1000,
Expand Down Expand Up @@ -139,16 +164,16 @@ public static NaiveAdditiveDecisionTree generateRandomDecTree(int nbFeatures, in
}
trees[i] = new RandomTreeGenerator(nbFeatures, minDepth, maxDepth, collector).genTree();
}
return new NaiveAdditiveDecisionTree(trees, weights, nbFeatures, Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME));
return new NaiveAdditiveDecisionTree(trees, weights, nbFeatures, Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME), null);
}

public void testSize() {
NaiveAdditiveDecisionTree ranker = new NaiveAdditiveDecisionTree(new NaiveAdditiveDecisionTree.Node[0],
new float[0], 3, Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME));
new float[0], 3, Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME), null);
assertEquals(ranker.size(), 3);
}

private NaiveAdditiveDecisionTree parseTreeModel(String textRes, Normalizer normalizer) throws IOException {
private NaiveAdditiveDecisionTree parseTreeModel(String textRes, Normalizer normalizer, Float missingValue) throws IOException {
List<PrebuiltFeature> features = new ArrayList<>(3);
features.add(new PrebuiltFeature("feature1", new MatchAllDocsQuery()));
features.add(new PrebuiltFeature("feature2", new MatchAllDocsQuery()));
Expand All @@ -163,7 +188,7 @@ private NaiveAdditiveDecisionTree parseTreeModel(String textRes, Normalizer norm
weights[i] = treesAndWeight.get(i).weight;
trees[i] = treesAndWeight.get(i).tree;
}
return new NaiveAdditiveDecisionTree(trees, weights, set.size(), normalizer);
return new NaiveAdditiveDecisionTree(trees, weights, set.size(), normalizer, missingValue);
}

private static class TreeTextParser {
Expand Down Expand Up @@ -207,17 +232,21 @@ NaiveAdditiveDecisionTree.Node parseTree() {
if (line.contains("- output")) {
return new NaiveAdditiveDecisionTree.Leaf(extractLastFloat(line));
} else if(line.contains("- split")) {
String featName = line.split(":")[1];
String[] splitString = line.split(":");
String featName = splitString[1];
int ord = set.featureOrdinal(featName);
if (ord < 0 || ord > set.size()) {
throw new IllegalArgumentException("Unknown feature " + featName);
}
float threshold = extractLastFloat(line);
NaiveAdditiveDecisionTree.Node right = parseTree();
NaiveAdditiveDecisionTree.Node left = parseTree();
NaiveAdditiveDecisionTree.Node onMissing = null;
if (splitString.length > 3) {
onMissing = Boolean.parseBoolean(splitString[2]) ? left : right;
}
return new NaiveAdditiveDecisionTree.Split(left, right, onMissing, ord, threshold);

return new NaiveAdditiveDecisionTree.Split(left, right,
ord, threshold);
} else {
throw new IllegalArgumentException("Invalid tree");
}
Expand Down Expand Up @@ -282,7 +311,7 @@ private NaiveAdditiveDecisionTree.Node newSplit(int depth) {
int feature = featureGen.get();
float thresh = thresholdGenerator.apply(feature);
statsCollector.newSplit(depth, feature, thresh);
return new NaiveAdditiveDecisionTree.Split(newNode(depth), newNode(depth), feature, thresh);
return new NaiveAdditiveDecisionTree.Split(newNode(depth), newNode(depth), null, feature, thresh);
}

private NaiveAdditiveDecisionTree.Node newLeaf(int depth) {
Expand Down
Loading