diff --git a/docs/CHANGELOG.asciidoc b/docs/CHANGELOG.asciidoc index 8b401473e4..3696f71751 100644 --- a/docs/CHANGELOG.asciidoc +++ b/docs/CHANGELOG.asciidoc @@ -39,6 +39,8 @@ progress, memory usage, etc. (See {ml-pull}906[#906].) * Improve initialization of learn rate for better and more stable results in regression and classification. (See {ml-pull}948[#948].) +* Add number of processed training samples to the definition of decision tree nodes. +(See {ml-pull}991[#991].) * Add new model_size_stats fields to instrument categorization. (See {ml-pull}948[#948] and {pull}51879[#51879], issue: {issue}50794[#50749].) diff --git a/include/api/CBoostedTreeInferenceModelBuilder.h b/include/api/CBoostedTreeInferenceModelBuilder.h index de1083bf41..2d3b31eeca 100644 --- a/include/api/CBoostedTreeInferenceModelBuilder.h +++ b/include/api/CBoostedTreeInferenceModelBuilder.h @@ -38,6 +38,7 @@ class API_EXPORT CBoostedTreeInferenceModelBuilder : public maths::CBoostedTree: bool assignMissingToLeft, double nodeValue, double gain, + std::size_t numberSamples, maths::CBoostedTreeNode::TOptionalNodeIndex leftChild, maths::CBoostedTreeNode::TOptionalNodeIndex rightChild) override; void addIdentityEncoding(std::size_t inputColumnIndex) override; diff --git a/include/api/CInferenceModelDefinition.h b/include/api/CInferenceModelDefinition.h index 38c65db383..ac6df728fb 100644 --- a/include/api/CInferenceModelDefinition.h +++ b/include/api/CInferenceModelDefinition.h @@ -158,7 +158,8 @@ class API_EXPORT CTree final : public CTrainedModel { double threshold, bool defaultLeft, double leafValue, - size_t splitFeature, + std::size_t splitFeature, + std::size_t numberSamples, const TOptionalNodeIndex& leftChild, const TOptionalNodeIndex& rightChild, const TOptionalDouble& splitGain); @@ -175,6 +176,7 @@ class API_EXPORT CTree final : public CTrainedModel { TOptionalNodeIndex m_LeftChild; TOptionalNodeIndex m_RightChild; std::size_t m_SplitFeature; + std::size_t m_NumberSamples; double 
m_Threshold; double m_LeafValue; TOptionalDouble m_SplitGain; diff --git a/include/maths/CBoostedTree.h b/include/maths/CBoostedTree.h index fb770b1d53..9c99c4259c 100644 --- a/include/maths/CBoostedTree.h +++ b/include/maths/CBoostedTree.h @@ -293,6 +293,7 @@ class MATHS_EXPORT CBoostedTreeNode final { bool assignMissingToLeft, double nodeValue, double gain, + std::size_t numberSamples, TOptionalNodeIndex leftChild, TOptionalNodeIndex rightChild) = 0; }; @@ -334,6 +335,12 @@ class MATHS_EXPORT CBoostedTreeNode final { //! Get the total curvature at the rows below this node. double curvature() const { return m_Curvature; } + //! Set the number of samples to \p value. + void numberSamples(std::size_t value); + + //! Get number of samples affected by the node. + std::size_t numberSamples() const; + //! Get the index of the left child node. TNodeIndex leftChildIndex() const { return m_LeftChild.get(); } @@ -376,6 +383,7 @@ class MATHS_EXPORT CBoostedTreeNode final { double m_NodeValue = 0.0; double m_Gain = 0.0; double m_Curvature = 0.0; + std::size_t m_NumberSamples = 0; }; //! \brief A boosted regression tree model. diff --git a/include/maths/CBoostedTreeImpl.h b/include/maths/CBoostedTreeImpl.h index a3b808bf46..4da500f90c 100644 --- a/include/maths/CBoostedTreeImpl.h +++ b/include/maths/CBoostedTreeImpl.h @@ -252,6 +252,9 @@ class MATHS_EXPORT CBoostedTreeImpl final { //! Get the root node of \p tree. static const CBoostedTreeNode& root(const TNodeVec& tree); + //! Get the root node of \p tree. + static CBoostedTreeNode& root(TNodeVec& tree); + //! Get the forest's prediction for \p row. static double predictRow(const CEncodedDataFrameRowRef& row, const TNodeVecVec& forest); @@ -287,6 +290,9 @@ class MATHS_EXPORT CBoostedTreeImpl final { //! Record the training state using the \p recordTrainState callback function void recordState(const TTrainingStateCallback& recordTrainState) const; + //! 
Populate numberSamples field in the m_BestForest + void computeNumberSamples(const core::CDataFrame& frame); + private: mutable CPRNG::CXorOShiro128Plus m_Rng; std::size_t m_NumberThreads; diff --git a/include/maths/CTreeShapFeatureImportance.h b/include/maths/CTreeShapFeatureImportance.h index 64f56b6b47..3b0eaba5f4 100644 --- a/include/maths/CTreeShapFeatureImportance.h +++ b/include/maths/CTreeShapFeatureImportance.h @@ -40,18 +40,9 @@ class MATHS_EXPORT CTreeShapFeatureImportance { //! by \p offset. void shap(core::CDataFrame& frame, const CDataFrameCategoryEncoder& encoder, std::size_t offset); - //! Compute number of training samples from \p frame that pass every node in the \p tree. - static TDoubleVec samplesPerNode(const TTree& tree, - const core::CDataFrame& frame, - const CDataFrameCategoryEncoder& encoder, - std::size_t numThreads); - //! Recursively computes inner node values as weighted average of the children (leaf) values //! \returns The maximum depth the the tree. - static std::size_t updateNodeValues(TTree& tree, - std::size_t nodeIndex, - const TDoubleVec& samplesPerNode, - std::size_t depth); + static std::size_t updateNodeValues(TTree& tree, std::size_t nodeIndex, std::size_t depth); //! Get the reference to the trees. TTreeVec& trees() { return m_Trees; } @@ -126,7 +117,6 @@ class MATHS_EXPORT CTreeShapFeatureImportance { //! Recursively traverses all pathes in the \p tree and updated SHAP values once it hits a leaf. //! Ref. Algorithm 2 in the paper by Lundberg et al. 
void shapRecursive(const TTree& tree, - const TDoubleVec& samplesPerNode, const CDataFrameCategoryEncoder& encoder, const CEncodedDataFrameRowRef& encodedRow, SPath& splitPath, @@ -146,7 +136,6 @@ class MATHS_EXPORT CTreeShapFeatureImportance { private: TTreeVec m_Trees; std::size_t m_NumberThreads; - TDoubleVecVec m_SamplesPerNode; }; } } diff --git a/lib/api/CBoostedTreeInferenceModelBuilder.cc b/lib/api/CBoostedTreeInferenceModelBuilder.cc index 7783c736b9..a1ebacb230 100644 --- a/lib/api/CBoostedTreeInferenceModelBuilder.cc +++ b/lib/api/CBoostedTreeInferenceModelBuilder.cc @@ -89,6 +89,7 @@ void CBoostedTreeInferenceModelBuilder::addNode(std::size_t splitFeature, bool assignMissingToLeft, double nodeValue, double gain, + std::size_t numberSamples, maths::CBoostedTreeNode::TOptionalNodeIndex leftChild, maths::CBoostedTreeNode::TOptionalNodeIndex rightChild) { auto ensemble{static_cast(m_Definition.trainedModel().get())}; @@ -97,8 +98,9 @@ void CBoostedTreeInferenceModelBuilder::addNode(std::size_t splitFeature, if (tree == nullptr) { HANDLE_FATAL(<< "Internal error. 
Tree points to a nullptr.") } - tree->treeStructure().emplace_back(tree->size(), splitValue, assignMissingToLeft, nodeValue, - splitFeature, leftChild, rightChild, gain); + tree->treeStructure().emplace_back(tree->size(), splitValue, assignMissingToLeft, + nodeValue, splitFeature, numberSamples, + leftChild, rightChild, gain); } CBoostedTreeInferenceModelBuilder::CBoostedTreeInferenceModelBuilder(TStrVec fieldNames, diff --git a/lib/api/CInferenceModelDefinition.cc b/lib/api/CInferenceModelDefinition.cc index 4ee8cd5e52..817ba2b0ac 100644 --- a/lib/api/CInferenceModelDefinition.cc +++ b/lib/api/CInferenceModelDefinition.cc @@ -34,6 +34,7 @@ const std::string JSON_LEFT_CHILD_TAG{"left_child"}; const std::string JSON_LOGISTIC_REGRESSION_TAG{"logistic_regression"}; const std::string JSON_LT{"lt"}; const std::string JSON_NODE_INDEX_TAG{"node_index"}; +const std::string JSON_NUMBER_SAMPLES_TAG{"number_samples"}; const std::string JSON_ONE_HOT_ENCODING_TAG{"one_hot_encoding"}; const std::string JSON_PREPROCESSORS_TAG{"preprocessors"}; const std::string JSON_RIGHT_CHILD_TAG{"right_child"}; @@ -79,6 +80,9 @@ void addJsonArray(const std::string& tag, void CTree::CTreeNode::addToDocument(rapidjson::Value& parentObject, TRapidJsonWriter& writer) const { writer.addMember(JSON_NODE_INDEX_TAG, rapidjson::Value(m_NodeIndex).Move(), parentObject); + writer.addMember( + JSON_NUMBER_SAMPLES_TAG, + rapidjson::Value(static_cast(m_NumberSamples)).Move(), parentObject); if (m_LeftChild) { // internal node @@ -118,11 +122,13 @@ CTree::CTreeNode::CTreeNode(TNodeIndex nodeIndex, bool defaultLeft, double leafValue, std::size_t splitFeature, + std::size_t numberSamples, const TOptionalNodeIndex& leftChild, const TOptionalNodeIndex& rightChild, const TOptionalDouble& splitGain) - : m_DefaultLeft(defaultLeft), m_NodeIndex(nodeIndex), m_LeftChild(leftChild), - m_RightChild(rightChild), m_SplitFeature(splitFeature), + : m_DefaultLeft(defaultLeft), m_NodeIndex(nodeIndex), + 
m_LeftChild(leftChild), m_RightChild(rightChild), + m_SplitFeature(splitFeature), m_NumberSamples(numberSamples), m_Threshold(threshold), m_LeafValue(leafValue), m_SplitGain(splitGain) { } diff --git a/lib/api/unittest/CDataFrameAnalyzerFeatureImportanceTest.cc b/lib/api/unittest/CDataFrameAnalyzerFeatureImportanceTest.cc index 600b99643a..645f5199d3 100644 --- a/lib/api/unittest/CDataFrameAnalyzerFeatureImportanceTest.cc +++ b/lib/api/unittest/CDataFrameAnalyzerFeatureImportanceTest.cc @@ -299,8 +299,9 @@ BOOST_FIXTURE_TEST_CASE(testRegressionFeatureImportanceNoImportance, SFixture) { // c1 explains 95% of the prediction value, i.e. the difference from the prediction is less than 2%. BOOST_REQUIRE_CLOSE(c1, prediction, 5.0); for (const auto& feature : {"c2", "c3", "c4"}) { - BOOST_REQUIRE_SMALL(readShapValue(result, feature), 2.0); - cNoImportanceMean.add(std::fabs(readShapValue(result, feature))); + double c = readShapValue(result, feature); + BOOST_REQUIRE_SMALL(c, 2.0); + cNoImportanceMean.add(std::fabs(c)); } } } diff --git a/lib/api/unittest/testfiles/inference_json_schema/model_definition.schema.json b/lib/api/unittest/testfiles/inference_json_schema/model_definition.schema.json index 594dcaa5e2..470fc3b5bb 100644 --- a/lib/api/unittest/testfiles/inference_json_schema/model_definition.schema.json +++ b/lib/api/unittest/testfiles/inference_json_schema/model_definition.schema.json @@ -66,6 +66,10 @@ }, "right_child": { "type": "integer" + }, + "number_samples": { + "description": "Number of training samples that were affected by the node.", + "type": "integer" } }, "required": [ @@ -75,7 +79,8 @@ "decision_type", "default_left", "left_child", - "right_child" + "right_child", + "number_samples" ], "additionalProperties": false }, @@ -88,11 +93,16 @@ }, "leaf_value": { "type": "number" + }, + "number_samples": { + "description": "Number of training samples that were affected by the node.", + "type": "integer" } }, "required": [ "node_index", - "leaf_value" + 
"leaf_value", + "number_samples" ], "additionalProperties": false }, @@ -234,10 +244,14 @@ "items": { "type": "number" } + }, + "num_classes": { + "type": "integer" } }, "required": [ - "weights" + "weights", + "num_classes" ] } }, diff --git a/lib/maths/CBoostedTree.cc b/lib/maths/CBoostedTree.cc index a967808e3c..db60a79e23 100644 --- a/lib/maths/CBoostedTree.cc +++ b/lib/maths/CBoostedTree.cc @@ -27,6 +27,7 @@ const std::string SPLIT_FEATURE_TAG{"split_feature"}; const std::string ASSIGN_MISSING_TO_LEFT_TAG{"assign_missing_to_left "}; const std::string NODE_VALUE_TAG{"node_value"}; const std::string SPLIT_VALUE_TAG{"split_value"}; +const std::string NUMBER_SAMPLES_TAG{"number_samples"}; double LOG_EPSILON{std::log(100.0 * std::numeric_limits::epsilon())}; @@ -393,6 +394,7 @@ void CBoostedTreeNode::acceptPersistInserter(core::CStatePersistInserter& insert core::CPersistUtils::persist(ASSIGN_MISSING_TO_LEFT_TAG, m_AssignMissingToLeft, inserter); core::CPersistUtils::persist(NODE_VALUE_TAG, m_NodeValue, inserter); core::CPersistUtils::persist(SPLIT_VALUE_TAG, m_SplitValue, inserter); + core::CPersistUtils::persist(NUMBER_SAMPLES_TAG, m_NumberSamples, inserter); } bool CBoostedTreeNode::acceptRestoreTraverser(core::CStateRestoreTraverser& traverser) { @@ -411,6 +413,8 @@ bool CBoostedTreeNode::acceptRestoreTraverser(core::CStateRestoreTraverser& trav core::CPersistUtils::restore(NODE_VALUE_TAG, m_NodeValue, traverser)) RESTORE(SPLIT_VALUE_TAG, core::CPersistUtils::restore(SPLIT_VALUE_TAG, m_SplitValue, traverser)) + RESTORE(NUMBER_SAMPLES_TAG, + core::CPersistUtils::restore(NUMBER_SAMPLES_TAG, m_NumberSamples, traverser)) } while (traverser.next()); return true; } @@ -435,8 +439,16 @@ std::ostringstream& CBoostedTreeNode::doPrint(std::string pad, } void CBoostedTreeNode::accept(CVisitor& visitor) const { - visitor.addNode(m_SplitFeature, m_SplitValue, m_AssignMissingToLeft, - m_NodeValue, m_Gain, m_LeftChild, m_RightChild); + visitor.addNode(m_SplitFeature, 
m_SplitValue, m_AssignMissingToLeft, m_NodeValue, + m_Gain, m_NumberSamples, m_LeftChild, m_RightChild); +} + +void CBoostedTreeNode::numberSamples(std::size_t numberSamples) { + m_NumberSamples = numberSamples; +} + +std::size_t CBoostedTreeNode::numberSamples() const { + return m_NumberSamples; } CBoostedTree::CBoostedTree(core::CDataFrame& frame, diff --git a/lib/maths/CBoostedTreeImpl.cc b/lib/maths/CBoostedTreeImpl.cc index d8ed7b5ae6..1a5280e2f5 100644 --- a/lib/maths/CBoostedTreeImpl.cc +++ b/lib/maths/CBoostedTreeImpl.cc @@ -226,6 +226,7 @@ void CBoostedTreeImpl::train(core::CDataFrame& frame, this->restoreBestHyperparameters(); std::tie(m_BestForest, std::ignore) = this->trainForest( frame, allTrainingRowsMask, allTrainingRowsMask, m_TrainingProgress); + m_Instrumentation->nextStep(static_cast(m_CurrentRound)); this->recordState(recordTrainStateCallback); @@ -242,12 +243,56 @@ void CBoostedTreeImpl::train(core::CDataFrame& frame, this->computeProbabilityAtWhichToAssignClassOne(frame); + // populate numberSamples field in the final forest + this->computeNumberSamples(frame); + // Force progress to one because we can have early exit from loop skip altogether. 
m_Instrumentation->updateProgress(1.0); m_Instrumentation->updateMemoryUsage( static_cast(this->memoryUsage()) - lastMemoryUsage); } +void CBoostedTreeImpl::computeNumberSamples(const core::CDataFrame& frame) { + for (auto& tree : m_BestForest) { + if (tree.size() == 1) { + root(tree).numberSamples(frame.numberRows()); + } else { + auto result = frame.readRows( + m_NumberThreads, + core::bindRetrievableState( + [&](TSizeVec& samplesPerNode, const TRowItr& beginRows, const TRowItr& endRows) { + for (auto row = beginRows; row != endRows; ++row) { + auto encodedRow{m_Encoder->encode(*row)}; + const CBoostedTreeNode* node{&root(tree)}; + samplesPerNode[0] += 1; + std::size_t nextIndex; + while (node->isLeaf() == false) { + if (node->assignToLeft(encodedRow)) { + nextIndex = node->leftChildIndex(); + } else { + nextIndex = node->rightChildIndex(); + } + samplesPerNode[nextIndex] += 1; + node = &(tree[nextIndex]); + } + } + }, + TSizeVec(tree.size()))); + auto& state = result.first; + TSizeVec totalSamplesPerNode{std::move(state[0].s_FunctionState)}; + for (std::size_t i = 1; i < state.size(); ++i) { + for (std::size_t nodeIndex = 0; + nodeIndex < totalSamplesPerNode.size(); ++nodeIndex) { + totalSamplesPerNode[nodeIndex] += state[i].s_FunctionState[nodeIndex]; + } + } + for (std::size_t i = 0; i < tree.size(); ++i) { + tree[i].numberSamples(totalSamplesPerNode[i]); + } + } + } +} + void CBoostedTreeImpl::recordState(const TTrainingStateCallback& recordTrainState) const { recordTrainState([this](core::CStatePersistInserter& inserter) { this->acceptPersistInserter(inserter); @@ -997,6 +1042,10 @@ const CBoostedTreeNode& CBoostedTreeImpl::root(const TNodeVec& tree) { return tree[0]; } +CBoostedTreeNode& CBoostedTreeImpl::root(TNodeVec& tree) { + return tree[0]; +} + double CBoostedTreeImpl::predictRow(const CEncodedDataFrameRowRef& row, const TNodeVecVec& forest) { double result{0.0}; @@ -1148,9 +1197,8 @@ std::size_t CBoostedTreeImpl::maximumTreeSize(std::size_t 
numberRows) const { } namespace { -const std::string VERSION_7_5_TAG{"7.5"}; -const std::string VERSION_7_6_TAG{"7.6"}; -const TStrVec SUPPORTED_VERSIONS{VERSION_7_5_TAG, VERSION_7_6_TAG}; +const std::string VERSION_7_7_TAG{"7.7"}; +const TStrVec SUPPORTED_VERSIONS{VERSION_7_7_TAG}; const std::string BAYESIAN_OPTIMIZATION_TAG{"bayesian_optimization"}; const std::string BEST_FOREST_TAG{"best_forest"}; @@ -1214,7 +1262,7 @@ CBoostedTreeImpl::TStrVec CBoostedTreeImpl::bestHyperparameterNames() { } void CBoostedTreeImpl::acceptPersistInserter(core::CStatePersistInserter& inserter) const { - core::CPersistUtils::persist(VERSION_7_6_TAG, "", inserter); + core::CPersistUtils::persist(VERSION_7_7_TAG, "", inserter); core::CPersistUtils::persist(BAYESIAN_OPTIMIZATION_TAG, *m_BayesianOptimization, inserter); core::CPersistUtils::persist(BEST_FOREST_TEST_LOSS_TAG, m_BestForestTestLoss, inserter); core::CPersistUtils::persist(CURRENT_ROUND_TAG, m_CurrentRound, inserter); @@ -1266,15 +1314,7 @@ void CBoostedTreeImpl::acceptPersistInserter(core::CStatePersistInserter& insert } bool CBoostedTreeImpl::acceptRestoreTraverser(core::CStateRestoreTraverser& traverser) { - if (traverser.name() == VERSION_7_5_TAG) { - // Force downsample factor to 1.0. - m_DownsampleFactorOverride = 1.0; - m_DownsampleFactor = 1.0; - m_BestHyperparameters.downsampleFactor(1.0); - // We can't stop cross-validation early because we haven't gathered the - // per fold test losses. - m_StopCrossValidationEarly = false; - } else if (traverser.name() != VERSION_7_6_TAG) { + if (traverser.name() != VERSION_7_7_TAG) { LOG_ERROR(<< "Input error: unsupported state serialization version. 
" << "Currently supported versions: " << core::CContainerPrinter::print(SUPPORTED_VERSIONS) << "."); diff --git a/lib/maths/CTreeShapFeatureImportance.cc b/lib/maths/CTreeShapFeatureImportance.cc index 88ea00d979..bbdbd4f121 100644 --- a/lib/maths/CTreeShapFeatureImportance.cc +++ b/lib/maths/CTreeShapFeatureImportance.cc @@ -22,87 +22,42 @@ void CTreeShapFeatureImportance::shap(core::CDataFrame& frame, TSizeVec maxDepthVec; maxDepthVec.reserve(m_Trees.size()); for (auto& tree : m_Trees) { - auto samplesPerNode = CTreeShapFeatureImportance::samplesPerNode( - tree, frame, encoder, m_NumberThreads); - std::size_t maxDepth = - CTreeShapFeatureImportance::updateNodeValues(tree, 0, samplesPerNode, 0); + std::size_t maxDepth = CTreeShapFeatureImportance::updateNodeValues(tree, 0, 0); maxDepthVec.push_back(maxDepth); - m_SamplesPerNode.emplace_back(std::move(samplesPerNode)); } - auto result = frame.writeColumns( - m_NumberThreads, [&](const TRowItr& beginRows, const TRowItr& endRows) { - for (auto row = beginRows; row != endRows; ++row) { - auto encodedRow{encoder.encode(*row)}; - for (std::size_t i = 0; i < m_Trees.size(); ++i) { - SPath path(maxDepthVec[i] + 1); - CTreeShapFeatureImportance::shapRecursive( - m_Trees[i], m_SamplesPerNode[i], encoder, encodedRow, - path, 0, 1.0, 1.0, -1, offset, row); - } + auto result = frame.writeColumns(m_NumberThreads, [&](const TRowItr& beginRows, + const TRowItr& endRows) { + for (auto row = beginRows; row != endRows; ++row) { + auto encodedRow{encoder.encode(*row)}; + for (std::size_t i = 0; i < m_Trees.size(); ++i) { + SPath path(maxDepthVec[i] + 1); + CTreeShapFeatureImportance::shapRecursive( + m_Trees[i], encoder, encodedRow, path, 0, 1.0, 1.0, -1, offset, row); } - }); -} - -CTreeShapFeatureImportance::TDoubleVec -CTreeShapFeatureImportance::samplesPerNode(const TTree& tree, - const core::CDataFrame& frame, - const CDataFrameCategoryEncoder& encoder, - std::size_t numThreads) { - auto result = frame.readRows( - numThreads, 
core::bindRetrievableState( - [&](TDoubleVec& samplesPerNode, - const TRowItr& beginRows, const TRowItr& endRows) { - for (auto row = beginRows; row != endRows; ++row) { - auto encodedRow{encoder.encode(*row)}; - const CBoostedTreeNode* node{&tree[0]}; - samplesPerNode[0] += 1.0; - std::size_t nextIndex; - while (node->isLeaf() == false) { - if (node->assignToLeft(encodedRow)) { - nextIndex = node->leftChildIndex(); - } else { - nextIndex = node->rightChildIndex(); - } - samplesPerNode[nextIndex] += 1.0; - node = &(tree[nextIndex]); - } - } - }, - TDoubleVec(tree.size()))); - - auto& state = result.first; - TDoubleVec totalSamplesPerNode{std::move(state[0].s_FunctionState)}; - for (std::size_t i = 1; i < state.size(); ++i) { - for (std::size_t nodeIndex = 0; nodeIndex < totalSamplesPerNode.size(); ++nodeIndex) { - totalSamplesPerNode[nodeIndex] += state[i].s_FunctionState[nodeIndex]; - } - } - - return totalSamplesPerNode; + }); } CTreeShapFeatureImportance::CTreeShapFeatureImportance(TTreeVec trees, std::size_t threads) - : m_Trees{std::move(trees)}, m_NumberThreads{threads}, m_SamplesPerNode() { - m_SamplesPerNode.reserve(m_Trees.size()); + : m_Trees{std::move(trees)}, m_NumberThreads{threads} { } -std::size_t CTreeShapFeatureImportance::updateNodeValues(TTree& tree, - std::size_t nodeIndex, - const TDoubleVec& samplesPerNode, - std::size_t depth) { +std::size_t CTreeShapFeatureImportance::updateNodeValues(TTree& tree, + std::size_t nodeIndex, + std::size_t depth) { auto& node{tree[nodeIndex]}; if (node.isLeaf()) { return 0; } std::size_t depthLeft{CTreeShapFeatureImportance::updateNodeValues( - tree, node.leftChildIndex(), samplesPerNode, depth + 1)}; + tree, node.leftChildIndex(), depth + 1)}; std::size_t depthRight{CTreeShapFeatureImportance::updateNodeValues( - tree, node.rightChildIndex(), samplesPerNode, depth + 1)}; + tree, node.rightChildIndex(), depth + 1)}; - double leftWeight{samplesPerNode[node.leftChildIndex()]}; - double 
rightWeight{samplesPerNode[node.rightChildIndex()]}; + std::size_t leftWeight{tree[node.leftChildIndex()].numberSamples()}; + std::size_t rightWeight{tree[node.rightChildIndex()].numberSamples()}; double averageValue{(leftWeight * tree[node.leftChildIndex()].value() + rightWeight * tree[node.rightChildIndex()].value()) / (leftWeight + rightWeight)}; @@ -111,7 +66,6 @@ std::size_t CTreeShapFeatureImportance::updateNodeValues(TTree& tree, } void CTreeShapFeatureImportance::shapRecursive(const TTree& tree, - const TDoubleVec& samplesPerNode, const CDataFrameCategoryEncoder& encoder, const CEncodedDataFrameRowRef& encodedRow, SPath& splitPath, @@ -170,16 +124,18 @@ void CTreeShapFeatureImportance::shapRecursive(const TTree& tree, CTreeShapFeatureImportance::unwindPath(splitPath, pathIndex); } - double hotFractionZero = samplesPerNode[hotIndex] / samplesPerNode[nodeIndex]; - double coldFractionZero = samplesPerNode[coldIndex] / samplesPerNode[nodeIndex]; + double hotFractionZero{static_cast(tree[hotIndex].numberSamples()) / + tree[nodeIndex].numberSamples()}; + double coldFractionZero{static_cast(tree[coldIndex].numberSamples()) / + tree[nodeIndex].numberSamples()}; std::size_t nextIndex = splitPath.nextIndex(); - this->shapRecursive(tree, samplesPerNode, encoder, encodedRow, splitPath, - hotIndex, incomingFractionZero * hotFractionZero, + this->shapRecursive(tree, encoder, encodedRow, splitPath, hotIndex, + incomingFractionZero * hotFractionZero, incomingFractionOne, splitFeature, offset, row); this->unwindPath(splitPath, nextIndex); - this->shapRecursive(tree, samplesPerNode, encoder, encodedRow, splitPath, - coldIndex, incomingFractionZero * coldFractionZero, - 0.0, splitFeature, offset, row); + this->shapRecursive(tree, encoder, encodedRow, splitPath, coldIndex, + incomingFractionZero * coldFractionZero, 0.0, + splitFeature, offset, row); this->unwindPath(splitPath, nextIndex); if (backupPath) { // now we swap to restore the data before unwinding diff --git 
a/lib/maths/unittest/CTreeShapFeatureImportanceTest.cc b/lib/maths/unittest/CTreeShapFeatureImportanceTest.cc index 2f2b7e6545..e6c48cf770 100644 --- a/lib/maths/unittest/CTreeShapFeatureImportanceTest.cc +++ b/lib/maths/unittest/CTreeShapFeatureImportanceTest.cc @@ -79,6 +79,14 @@ struct SFixtureSingleTree { tree[5].value(13); tree[6].value(18); + tree[0].numberSamples(4); + tree[1].numberSamples(2); + tree[2].numberSamples(2); + tree[3].numberSamples(1); + tree[4].numberSamples(1); + tree[5].numberSamples(1); + tree[6].numberSamples(1); + s_TreeFeatureImportance = std::make_unique>( {tree}); @@ -97,13 +105,15 @@ struct SFixtureSingleTreeRandom { SFixtureSingleTreeRandom() : s_TreeFeatureImportance{}, s_Encoder{} { test::CRandomNumbers rng; this->initFrame(rng); + + CStubMakeDataFrameCategoryEncoder stubParameters{1, *s_Frame, 0, s_NumberFeatures}; + s_Encoder = std::make_unique(stubParameters); + this->initTree(rng); s_TreeFeatureImportance = std::make_unique>( {s_Tree}); - CStubMakeDataFrameCategoryEncoder stubParameters{1, *s_Frame, 0, s_NumberFeatures}; - s_Encoder = std::make_unique(stubParameters); } void initFrame(test::CRandomNumbers& rng) { @@ -158,6 +168,32 @@ struct SFixtureSingleTreeRandom { for (std::size_t i = 0; i < numberLeafs; ++i) { s_Tree[s_NumberInnerNodes + i].value(leafValues[i]); } + + // set correct number samples + auto result = s_Frame->readRows( + 1, core::bindRetrievableState( + [&](TSizeVec& numberSamples, const TRowItr& beginRows, const TRowItr& endRows) { + for (auto row = beginRows; row != endRows; ++row) { + auto node{&(s_Tree[0])}; + auto encodedRow{s_Encoder->encode(*row)}; + numberSamples[0] += 1; + std::size_t nextIndex; + while (node->isLeaf() == false) { + if (node->assignToLeft(encodedRow)) { + nextIndex = node->leftChildIndex(); + } else { + nextIndex = node->rightChildIndex(); + } + numberSamples[nextIndex] += 1; + node = &(s_Tree[nextIndex]); + } + } + }, + TSizeVec(s_Tree.size()))); + TSizeVec 
numberSamples{std::move(result.first[0].s_FunctionState)}; + for (std::size_t i = 0; i < numberSamples.size(); ++i) { + s_Tree[i].numberSamples(numberSamples[i]); + } } TDataFrameUPtr s_Frame; @@ -189,21 +225,35 @@ struct SFixtureMultipleTrees { TTree tree1(1); tree1[0].split(0, 0.55, true, 0.0, 0.0, tree1); + tree1[0].numberSamples(10); tree1[1].split(0, 0.41, true, 0.0, 0.0, tree1); + tree1[1].numberSamples(6); tree1[2].split(1, 0.25, true, 0.0, 0.0, tree1); + tree1[2].numberSamples(4); tree1[3].value(1.18230136); + tree1[3].numberSamples(5); tree1[4].value(1.98006658); + tree1[4].numberSamples(1); tree1[5].value(3.25350885); + tree1[5].numberSamples(3); tree1[6].value(2.42384369); + tree1[6].numberSamples(1); TTree tree2(1); tree2[0].split(0, 0.45, true, 0.0, 0.0, tree2); + tree2[0].numberSamples(10); tree2[1].split(0, 0.25, true, 0.0, 0.0, tree2); + tree2[1].numberSamples(5); tree2[2].split(0, 0.59, true, 0.0, 0.0, tree2); + tree2[2].numberSamples(5); tree2[3].value(1.04476388); + tree2[3].numberSamples(3); tree2[4].value(1.52799228); + tree2[4].numberSamples(2); tree2[5].value(1.98006658); + tree2[5].numberSamples(1); tree2[6].value(2.950216); + tree2[6].numberSamples(4); s_TreeFeatureImportance = std::make_unique>( @@ -221,8 +271,8 @@ struct SFixtureMultipleTrees { class BruteForceTreeShap { public: - BruteForceTreeShap(const TTree& tree, const TDoubleVec& samplesPerNode, std::size_t numberFeatures) - : m_Tree{tree}, m_SamplesPerNode{samplesPerNode}, m_Powerset{}, m_NumberFeatures{numberFeatures} { + BruteForceTreeShap(const TTree& tree, std::size_t numberFeatures) + : m_Tree{tree}, m_Powerset{}, m_NumberFeatures{numberFeatures} { this->initPowerset({}, numberFeatures); } @@ -251,7 +301,7 @@ class BruteForceTreeShap { (boost::math::binomial_coefficient( static_cast(m_NumberFeatures), static_cast(S.size())) * - (static_cast(m_NumberFeatures) - S.size()))}; + (static_cast(m_NumberFeatures - S.size())))}; double fWithoutIndex = 
this->conditionalExpectation(encodedRow, S); S.insert(inputColumnIndex); @@ -305,36 +355,28 @@ class BruteForceTreeShap { } } else { - return this->conditionalExpectation(x, S, leftChildIndex, - weight * m_SamplesPerNode[leftChildIndex] / - m_SamplesPerNode[nodeIndex]) + - this->conditionalExpectation(x, S, rightChildIndex, - weight * m_SamplesPerNode[rightChildIndex] / - m_SamplesPerNode[nodeIndex]); + return this->conditionalExpectation( + x, S, leftChildIndex, + weight * m_Tree[leftChildIndex].numberSamples() / + m_Tree[nodeIndex].numberSamples()) + + this->conditionalExpectation( + x, S, rightChildIndex, + weight * m_Tree[rightChildIndex].numberSamples() / + m_Tree[nodeIndex].numberSamples()); } } } private: const TTree& m_Tree; - const TDoubleVec& m_SamplesPerNode; TSizePowerset m_Powerset{}; std::size_t m_NumberFeatures; }; -BOOST_FIXTURE_TEST_CASE(testSingleTreeSamplesPerNode, SFixtureSingleTree) { - - auto samplesPerNode = maths::CTreeShapFeatureImportance::samplesPerNode( - s_TreeFeatureImportance->trees()[0], *s_Frame, *s_Encoder, 1); - TDoubleVec expectedSamplesPerNode{4, 2, 2, 1, 1, 1, 1}; - BOOST_TEST_REQUIRE(samplesPerNode == expectedSamplesPerNode); -} - BOOST_FIXTURE_TEST_CASE(testSingleTreeExpectedNodeValues, SFixtureSingleTree) { - TDoubleVec samplesPerNode{4, 2, 2, 1, 1, 1, 1}; std::size_t depth = maths::CTreeShapFeatureImportance::updateNodeValues( - s_TreeFeatureImportance->trees()[0], 0, samplesPerNode, 0); + s_TreeFeatureImportance->trees()[0], 0, 0); BOOST_TEST_REQUIRE(depth == 2); TDoubleVec expectedValues{10.5, 5.5, 15.5, 3.0, 8.0, 13.0, 18.0}; auto& tree{s_TreeFeatureImportance->trees()[0]}; @@ -379,10 +421,7 @@ BOOST_FIXTURE_TEST_CASE(testMultipleTreesShap, SFixtureMultipleTrees) { } BOOST_FIXTURE_TEST_CASE(testSingleTreeBruteForceShap, SFixtureSingleTree) { - auto samplesPerNode = maths::CTreeShapFeatureImportance::samplesPerNode( - s_TreeFeatureImportance->trees()[0], *s_Frame, *s_Encoder, 1); - BruteForceTreeShap 
bfShap(s_TreeFeatureImportance->trees()[0], - samplesPerNode, s_NumberFeatures); + BruteForceTreeShap bfShap(s_TreeFeatureImportance->trees()[0], s_NumberFeatures); auto actualPhi = bfShap.shap(*s_Frame, *s_Encoder, 1); TDoubleVecVec expectedPhi{{-5., -2.5}, {-5., 2.5}, {5., -2.5}, {5., 2.5}}; for (std::size_t i = 0; i < s_NumberRows; ++i) { @@ -395,9 +434,7 @@ BOOST_FIXTURE_TEST_CASE(testSingleTreeBruteForceShap, SFixtureSingleTree) { BOOST_FIXTURE_TEST_CASE(testSingleTreeShapRandomDataFrame, SFixtureSingleTreeRandom) { // Compare tree shap algorithm with the brute force approach (Algorithm // 1 in paper by Lundberg et al.) on a random data set with a random tree. - auto samplesPerNode = maths::CTreeShapFeatureImportance::samplesPerNode( - s_TreeFeatureImportance->trees()[0], *s_Frame, *s_Encoder, 1); - BruteForceTreeShap bfShap(this->s_Tree, samplesPerNode, s_NumberFeatures); + BruteForceTreeShap bfShap(this->s_Tree, s_NumberFeatures); auto expectedPhi = bfShap.shap(*s_Frame, *s_Encoder, 1); std::size_t offset{s_Frame->numberColumns()}; s_Frame->resizeColumns(1, offset * 2); diff --git a/lib/maths/unittest/testfiles/error_bayesian_optimisation_state.json b/lib/maths/unittest/testfiles/error_bayesian_optimisation_state.json index 98459ef5d0..f840442489 100644 --- a/lib/maths/unittest/testfiles/error_bayesian_optimisation_state.json +++ b/lib/maths/unittest/testfiles/error_bayesian_optimisation_state.json @@ -1,5 +1,5 @@ { - "7.5": "", + "7.7": "", "bayesian_optimization": { "7.5": "", "rng": "16294208416658607535:7960286522194355700", diff --git a/lib/maths/unittest/testfiles/error_boosted_tree_impl_state.json b/lib/maths/unittest/testfiles/error_boosted_tree_impl_state.json index 3d46515f81..37f4fe3495 100644 --- a/lib/maths/unittest/testfiles/error_boosted_tree_impl_state.json +++ b/lib/maths/unittest/testfiles/error_boosted_tree_impl_state.json @@ -1,5 +1,5 @@ { - "7.5": "", + "7.7": "", "bayesian_optimization": { "7.5": "", "rng": 
"16294208416658607535:7960286522194355700",