Skip to content

Commit

Permalink
[7.7][ML] Add information about samples per node to the tree (#1006)
Browse files Browse the repository at this point in the history
Backport to #991
  • Loading branch information
valeriy42 authored Feb 18, 2020
1 parent f21507d commit 030d608
Show file tree
Hide file tree
Showing 16 changed files with 215 additions and 139 deletions.
2 changes: 2 additions & 0 deletions docs/CHANGELOG.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ progress, memory usage, etc. (See {ml-pull}906[#906].)

* Improve initialization of learn rate for better and more stable results in regression
and classification. (See {ml-pull}948[#948].)
* Add number of processed training samples to the definition of decision tree nodes.
(See {ml-pull}991[#991].)
* Add new model_size_stats fields to instrument categorization. (See {ml-pull}948[#948]
and {pull}51879[#51879], issue: {issue}50794[#50749].)

Expand Down
1 change: 1 addition & 0 deletions include/api/CBoostedTreeInferenceModelBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class API_EXPORT CBoostedTreeInferenceModelBuilder : public maths::CBoostedTree:
bool assignMissingToLeft,
double nodeValue,
double gain,
std::size_t numberSamples,
maths::CBoostedTreeNode::TOptionalNodeIndex leftChild,
maths::CBoostedTreeNode::TOptionalNodeIndex rightChild) override;
void addIdentityEncoding(std::size_t inputColumnIndex) override;
Expand Down
4 changes: 3 additions & 1 deletion include/api/CInferenceModelDefinition.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ class API_EXPORT CTree final : public CTrainedModel {
double threshold,
bool defaultLeft,
double leafValue,
size_t splitFeature,
std::size_t splitFeature,
std::size_t numberSamples,
const TOptionalNodeIndex& leftChild,
const TOptionalNodeIndex& rightChild,
const TOptionalDouble& splitGain);
Expand All @@ -175,6 +176,7 @@ class API_EXPORT CTree final : public CTrainedModel {
TOptionalNodeIndex m_LeftChild;
TOptionalNodeIndex m_RightChild;
std::size_t m_SplitFeature;
std::size_t m_NumberSamples;
double m_Threshold;
double m_LeafValue;
TOptionalDouble m_SplitGain;
Expand Down
8 changes: 8 additions & 0 deletions include/maths/CBoostedTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ class MATHS_EXPORT CBoostedTreeNode final {
bool assignMissingToLeft,
double nodeValue,
double gain,
std::size_t numberSamples,
TOptionalNodeIndex leftChild,
TOptionalNodeIndex rightChild) = 0;
};
Expand Down Expand Up @@ -334,6 +335,12 @@ class MATHS_EXPORT CBoostedTreeNode final {
//! Get the total curvature at the rows below this node.
double curvature() const { return m_Curvature; }

//! Set the number of samples to \p value.
void numberSamples(std::size_t value);

//! Get number of samples affected by the node.
std::size_t numberSamples() const;

//! Get the index of the left child node.
TNodeIndex leftChildIndex() const { return m_LeftChild.get(); }

Expand Down Expand Up @@ -376,6 +383,7 @@ class MATHS_EXPORT CBoostedTreeNode final {
double m_NodeValue = 0.0;
double m_Gain = 0.0;
double m_Curvature = 0.0;
std::size_t m_NumberSamples = 0;
};

//! \brief A boosted regression tree model.
Expand Down
6 changes: 6 additions & 0 deletions include/maths/CBoostedTreeImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,9 @@ class MATHS_EXPORT CBoostedTreeImpl final {
//! Get the root node of \p tree.
static const CBoostedTreeNode& root(const TNodeVec& tree);

//! Get the root node of \p tree.
static CBoostedTreeNode& root(TNodeVec& tree);

//! Get the forest's prediction for \p row.
static double predictRow(const CEncodedDataFrameRowRef& row, const TNodeVecVec& forest);

Expand Down Expand Up @@ -287,6 +290,9 @@ class MATHS_EXPORT CBoostedTreeImpl final {
//! Record the training state using the \p recordTrainState callback function
void recordState(const TTrainingStateCallback& recordTrainState) const;

//! Populate numberSamples field in the m_BestForest
void computeNumberSamples(const core::CDataFrame& frame);

private:
mutable CPRNG::CXorOShiro128Plus m_Rng;
std::size_t m_NumberThreads;
Expand Down
13 changes: 1 addition & 12 deletions include/maths/CTreeShapFeatureImportance.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,9 @@ class MATHS_EXPORT CTreeShapFeatureImportance {
//! by \p offset.
void shap(core::CDataFrame& frame, const CDataFrameCategoryEncoder& encoder, std::size_t offset);

//! Compute number of training samples from \p frame that pass every node in the \p tree.
static TDoubleVec samplesPerNode(const TTree& tree,
const core::CDataFrame& frame,
const CDataFrameCategoryEncoder& encoder,
std::size_t numThreads);

//! Recursively computes inner node values as weighted average of the children (leaf) values
//! \returns The maximum depth the the tree.
static std::size_t updateNodeValues(TTree& tree,
std::size_t nodeIndex,
const TDoubleVec& samplesPerNode,
std::size_t depth);
static size_t updateNodeValues(TTree& tree, std::size_t nodeIndex, std::size_t depth);

//! Get the reference to the trees.
TTreeVec& trees() { return m_Trees; }
Expand Down Expand Up @@ -126,7 +117,6 @@ class MATHS_EXPORT CTreeShapFeatureImportance {
//! Recursively traverses all pathes in the \p tree and updated SHAP values once it hits a leaf.
//! Ref. Algorithm 2 in the paper by Lundberg et al.
void shapRecursive(const TTree& tree,
const TDoubleVec& samplesPerNode,
const CDataFrameCategoryEncoder& encoder,
const CEncodedDataFrameRowRef& encodedRow,
SPath& splitPath,
Expand All @@ -146,7 +136,6 @@ class MATHS_EXPORT CTreeShapFeatureImportance {
private:
TTreeVec m_Trees;
std::size_t m_NumberThreads;
TDoubleVecVec m_SamplesPerNode;
};
}
}
Expand Down
6 changes: 4 additions & 2 deletions lib/api/CBoostedTreeInferenceModelBuilder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ void CBoostedTreeInferenceModelBuilder::addNode(std::size_t splitFeature,
bool assignMissingToLeft,
double nodeValue,
double gain,
std::size_t numberSamples,
maths::CBoostedTreeNode::TOptionalNodeIndex leftChild,
maths::CBoostedTreeNode::TOptionalNodeIndex rightChild) {
auto ensemble{static_cast<CEnsemble*>(m_Definition.trainedModel().get())};
Expand All @@ -97,8 +98,9 @@ void CBoostedTreeInferenceModelBuilder::addNode(std::size_t splitFeature,
if (tree == nullptr) {
HANDLE_FATAL(<< "Internal error. Tree points to a nullptr.")
}
tree->treeStructure().emplace_back(tree->size(), splitValue, assignMissingToLeft, nodeValue,
splitFeature, leftChild, rightChild, gain);
tree->treeStructure().emplace_back(tree->size(), splitValue, assignMissingToLeft,
nodeValue, splitFeature, numberSamples,
leftChild, rightChild, gain);
}

CBoostedTreeInferenceModelBuilder::CBoostedTreeInferenceModelBuilder(TStrVec fieldNames,
Expand Down
10 changes: 8 additions & 2 deletions lib/api/CInferenceModelDefinition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ const std::string JSON_LEFT_CHILD_TAG{"left_child"};
const std::string JSON_LOGISTIC_REGRESSION_TAG{"logistic_regression"};
const std::string JSON_LT{"lt"};
const std::string JSON_NODE_INDEX_TAG{"node_index"};
const std::string JSON_NUMBER_SAMPLES_TAG{"number_samples"};
const std::string JSON_ONE_HOT_ENCODING_TAG{"one_hot_encoding"};
const std::string JSON_PREPROCESSORS_TAG{"preprocessors"};
const std::string JSON_RIGHT_CHILD_TAG{"right_child"};
Expand Down Expand Up @@ -79,6 +80,9 @@ void addJsonArray(const std::string& tag,
void CTree::CTreeNode::addToDocument(rapidjson::Value& parentObject,
TRapidJsonWriter& writer) const {
writer.addMember(JSON_NODE_INDEX_TAG, rapidjson::Value(m_NodeIndex).Move(), parentObject);
writer.addMember(
JSON_NUMBER_SAMPLES_TAG,
rapidjson::Value(static_cast<std::uint64_t>(m_NumberSamples)).Move(), parentObject);

if (m_LeftChild) {
// internal node
Expand Down Expand Up @@ -118,11 +122,13 @@ CTree::CTreeNode::CTreeNode(TNodeIndex nodeIndex,
bool defaultLeft,
double leafValue,
std::size_t splitFeature,
std::size_t numberSamples,
const TOptionalNodeIndex& leftChild,
const TOptionalNodeIndex& rightChild,
const TOptionalDouble& splitGain)
: m_DefaultLeft(defaultLeft), m_NodeIndex(nodeIndex), m_LeftChild(leftChild),
m_RightChild(rightChild), m_SplitFeature(splitFeature),
: m_DefaultLeft(defaultLeft), m_NodeIndex(nodeIndex),
m_LeftChild(leftChild), m_RightChild(rightChild),
m_SplitFeature(splitFeature), m_NumberSamples(numberSamples),
m_Threshold(threshold), m_LeafValue(leafValue), m_SplitGain(splitGain) {
}

Expand Down
5 changes: 3 additions & 2 deletions lib/api/unittest/CDataFrameAnalyzerFeatureImportanceTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,9 @@ BOOST_FIXTURE_TEST_CASE(testRegressionFeatureImportanceNoImportance, SFixture) {
// c1 explains 95% of the prediction value, i.e. the difference from the prediction is less than 2%.
BOOST_REQUIRE_CLOSE(c1, prediction, 5.0);
for (const auto& feature : {"c2", "c3", "c4"}) {
BOOST_REQUIRE_SMALL(readShapValue(result, feature), 2.0);
cNoImportanceMean.add(std::fabs(readShapValue(result, feature)));
double c = readShapValue(result, feature);
BOOST_REQUIRE_SMALL(c, 2.0);
cNoImportanceMean.add(std::fabs(c));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@
},
"right_child": {
"type": "integer"
},
"number_samples": {
"description": "Number of training samples that were affected by the node.",
"type": "integer"
}
},
"required": [
Expand All @@ -75,7 +79,8 @@
"decision_type",
"default_left",
"left_child",
"right_child"
"right_child",
"number_samples"
],
"additionalProperties": false
},
Expand All @@ -88,11 +93,16 @@
},
"leaf_value": {
"type": "number"
},
"number_samples": {
"description": "Number of training samples that were affected by the node.",
"type": "integer"
}
},
"required": [
"node_index",
"leaf_value"
"leaf_value",
"number_samples"
],
"additionalProperties": false
},
Expand Down Expand Up @@ -234,10 +244,14 @@
"items": {
"type": "number"
}
},
"num_classes": {
"type": "integer"
}
},
"required": [
"weights"
"weights",
"num_classes"
]
}
},
Expand Down
16 changes: 14 additions & 2 deletions lib/maths/CBoostedTree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ const std::string SPLIT_FEATURE_TAG{"split_feature"};
const std::string ASSIGN_MISSING_TO_LEFT_TAG{"assign_missing_to_left "};
const std::string NODE_VALUE_TAG{"node_value"};
const std::string SPLIT_VALUE_TAG{"split_value"};
const std::string NUMBER_SAMPLES_TAG{"number_samples"};

double LOG_EPSILON{std::log(100.0 * std::numeric_limits<double>::epsilon())};

Expand Down Expand Up @@ -393,6 +394,7 @@ void CBoostedTreeNode::acceptPersistInserter(core::CStatePersistInserter& insert
core::CPersistUtils::persist(ASSIGN_MISSING_TO_LEFT_TAG, m_AssignMissingToLeft, inserter);
core::CPersistUtils::persist(NODE_VALUE_TAG, m_NodeValue, inserter);
core::CPersistUtils::persist(SPLIT_VALUE_TAG, m_SplitValue, inserter);
core::CPersistUtils::persist(NUMBER_SAMPLES_TAG, m_NumberSamples, inserter);
}

bool CBoostedTreeNode::acceptRestoreTraverser(core::CStateRestoreTraverser& traverser) {
Expand All @@ -411,6 +413,8 @@ bool CBoostedTreeNode::acceptRestoreTraverser(core::CStateRestoreTraverser& trav
core::CPersistUtils::restore(NODE_VALUE_TAG, m_NodeValue, traverser))
RESTORE(SPLIT_VALUE_TAG,
core::CPersistUtils::restore(SPLIT_VALUE_TAG, m_SplitValue, traverser))
RESTORE(NUMBER_SAMPLES_TAG,
core::CPersistUtils::restore(NUMBER_SAMPLES_TAG, m_NumberSamples, traverser))
} while (traverser.next());
return true;
}
Expand All @@ -435,8 +439,16 @@ std::ostringstream& CBoostedTreeNode::doPrint(std::string pad,
}

void CBoostedTreeNode::accept(CVisitor& visitor) const {
visitor.addNode(m_SplitFeature, m_SplitValue, m_AssignMissingToLeft,
m_NodeValue, m_Gain, m_LeftChild, m_RightChild);
visitor.addNode(m_SplitFeature, m_SplitValue, m_AssignMissingToLeft, m_NodeValue,
m_Gain, m_NumberSamples, m_LeftChild, m_RightChild);
}

void CBoostedTreeNode::numberSamples(std::size_t numberSamples) {
m_NumberSamples = numberSamples;
}

std::size_t CBoostedTreeNode::numberSamples() const {
return m_NumberSamples;
}

CBoostedTree::CBoostedTree(core::CDataFrame& frame,
Expand Down
66 changes: 53 additions & 13 deletions lib/maths/CBoostedTreeImpl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ void CBoostedTreeImpl::train(core::CDataFrame& frame,
this->restoreBestHyperparameters();
std::tie(m_BestForest, std::ignore) = this->trainForest(
frame, allTrainingRowsMask, allTrainingRowsMask, m_TrainingProgress);

m_Instrumentation->nextStep(static_cast<std::uint32_t>(m_CurrentRound));
this->recordState(recordTrainStateCallback);

Expand All @@ -242,12 +243,56 @@ void CBoostedTreeImpl::train(core::CDataFrame& frame,

this->computeProbabilityAtWhichToAssignClassOne(frame);

// populate numberSamples field in the final forest
this->computeNumberSamples(frame);

// Force progress to one because we can have early exit from loop skip altogether.
m_Instrumentation->updateProgress(1.0);
m_Instrumentation->updateMemoryUsage(
static_cast<std::int64_t>(this->memoryUsage()) - lastMemoryUsage);
}

void CBoostedTreeImpl::computeNumberSamples(const core::CDataFrame& frame) {
for (auto& tree : m_BestForest) {
if (tree.size() == 1) {
root(tree).numberSamples(frame.numberRows());
} else {
auto result = frame.readRows(
m_NumberThreads,
core::bindRetrievableState(
[&](TSizeVec& samplesPerNode, const TRowItr& beginRows, const TRowItr& endRows) {
for (auto row = beginRows; row != endRows; ++row) {
auto encodedRow{m_Encoder->encode(*row)};
const CBoostedTreeNode* node{&root(tree)};
samplesPerNode[0] += 1;
std::size_t nextIndex;
while (node->isLeaf() == false) {
if (node->assignToLeft(encodedRow)) {
nextIndex = node->leftChildIndex();
} else {
nextIndex = node->rightChildIndex();
}
samplesPerNode[nextIndex] += 1;
node = &(tree[nextIndex]);
}
}
},
TSizeVec(tree.size())));
auto& state = result.first;
TSizeVec totalSamplesPerNode{std::move(state[0].s_FunctionState)};
for (std::size_t i = 1; i < state.size(); ++i) {
for (std::size_t nodeIndex = 0;
nodeIndex < totalSamplesPerNode.size(); ++nodeIndex) {
totalSamplesPerNode[nodeIndex] += state[i].s_FunctionState[nodeIndex];
}
}
for (std::size_t i = 0; i < tree.size(); ++i) {
tree[i].numberSamples(totalSamplesPerNode[i]);
}
}
}
}

void CBoostedTreeImpl::recordState(const TTrainingStateCallback& recordTrainState) const {
recordTrainState([this](core::CStatePersistInserter& inserter) {
this->acceptPersistInserter(inserter);
Expand Down Expand Up @@ -997,6 +1042,10 @@ const CBoostedTreeNode& CBoostedTreeImpl::root(const TNodeVec& tree) {
return tree[0];
}

CBoostedTreeNode& CBoostedTreeImpl::root(TNodeVec& tree) {
return tree[0];
}

double CBoostedTreeImpl::predictRow(const CEncodedDataFrameRowRef& row,
const TNodeVecVec& forest) {
double result{0.0};
Expand Down Expand Up @@ -1148,9 +1197,8 @@ std::size_t CBoostedTreeImpl::maximumTreeSize(std::size_t numberRows) const {
}

namespace {
const std::string VERSION_7_5_TAG{"7.5"};
const std::string VERSION_7_6_TAG{"7.6"};
const TStrVec SUPPORTED_VERSIONS{VERSION_7_5_TAG, VERSION_7_6_TAG};
const std::string VERSION_7_7_TAG{"7.7"};
const TStrVec SUPPORTED_VERSIONS{VERSION_7_7_TAG};

const std::string BAYESIAN_OPTIMIZATION_TAG{"bayesian_optimization"};
const std::string BEST_FOREST_TAG{"best_forest"};
Expand Down Expand Up @@ -1214,7 +1262,7 @@ CBoostedTreeImpl::TStrVec CBoostedTreeImpl::bestHyperparameterNames() {
}

void CBoostedTreeImpl::acceptPersistInserter(core::CStatePersistInserter& inserter) const {
core::CPersistUtils::persist(VERSION_7_6_TAG, "", inserter);
core::CPersistUtils::persist(VERSION_7_7_TAG, "", inserter);
core::CPersistUtils::persist(BAYESIAN_OPTIMIZATION_TAG, *m_BayesianOptimization, inserter);
core::CPersistUtils::persist(BEST_FOREST_TEST_LOSS_TAG, m_BestForestTestLoss, inserter);
core::CPersistUtils::persist(CURRENT_ROUND_TAG, m_CurrentRound, inserter);
Expand Down Expand Up @@ -1266,15 +1314,7 @@ void CBoostedTreeImpl::acceptPersistInserter(core::CStatePersistInserter& insert
}

bool CBoostedTreeImpl::acceptRestoreTraverser(core::CStateRestoreTraverser& traverser) {
if (traverser.name() == VERSION_7_5_TAG) {
// Force downsample factor to 1.0.
m_DownsampleFactorOverride = 1.0;
m_DownsampleFactor = 1.0;
m_BestHyperparameters.downsampleFactor(1.0);
// We can't stop cross-validation early because we haven't gathered the
// per fold test losses.
m_StopCrossValidationEarly = false;
} else if (traverser.name() != VERSION_7_6_TAG) {
if (traverser.name() != VERSION_7_7_TAG) {
LOG_ERROR(<< "Input error: unsupported state serialization version. "
<< "Currently supported versions: "
<< core::CContainerPrinter::print(SUPPORTED_VERSIONS) << ".");
Expand Down
Loading

0 comments on commit 030d608

Please sign in to comment.