diff --git a/include/api/CInferenceModelMetadata.h b/include/api/CInferenceModelMetadata.h
index d844bc3d38..8e3256a778 100644
--- a/include/api/CInferenceModelMetadata.h
+++ b/include/api/CInferenceModelMetadata.h
@@ -21,6 +21,8 @@ namespace api {
 //! (such as total feature importance) into JSON format.
 class API_EXPORT CInferenceModelMetadata {
 public:
+    static const std::string JSON_BASELINE_TAG;
+    static const std::string JSON_FEATURE_IMPORTANCE_BASELINE_TAG;
     static const std::string JSON_CLASS_NAME_TAG;
     static const std::string JSON_CLASSES_TAG;
     static const std::string JSON_FEATURE_NAME_TAG;
@@ -48,6 +50,9 @@ class API_EXPORT CInferenceModelMetadata {
     //! Add importances \p values to the feature with index \p i to calculate total feature importance.
     //! Total feature importance is the mean of the magnitudes of importances for individual data points.
     void addToFeatureImportance(std::size_t i, const TVector& values);
+    //! Set the feature importance baseline (the individual feature importances are additive corrections
+    //! to the baseline value).
+    void featureImportanceBaseline(TVector&& baseline);
 
 private:
     using TMeanAccumulator =
@@ -55,13 +60,16 @@ class API_EXPORT CInferenceModelMetadata {
     using TMinMaxAccumulator = std::vector<maths::CBasicStatistics::CMinMax<double>>;
     using TSizeMeanAccumulatorUMap = std::unordered_map<std::size_t, TMeanAccumulator>;
     using TSizeMinMaxAccumulatorUMap = std::unordered_map<std::size_t, TMinMaxAccumulator>;
+    using TOptionalVector = boost::optional<TVector>;
 
 private:
     void writeTotalFeatureImportance(TRapidJsonWriter& writer) const;
+    void writeFeatureImportanceBaseline(TRapidJsonWriter& writer) const;
 
 private:
     TSizeMeanAccumulatorUMap m_TotalShapValuesMean;
     TSizeMinMaxAccumulatorUMap m_TotalShapValuesMinMax;
+    TOptionalVector m_ShapBaseline;
     TStrVec m_ColumnNames;
     TStrVec m_ClassValues;
     TPredictionFieldTypeResolverWriter m_PredictionFieldTypeResolverWriter =
diff --git a/include/maths/CTreeShapFeatureImportance.h b/include/maths/CTreeShapFeatureImportance.h
index 6b46ce90e2..d94ee0c3f1 100644
--- a/include/maths/CTreeShapFeatureImportance.h
+++ b/include/maths/CTreeShapFeatureImportance.h
@@ -77,7 +77,7 @@ class MATHS_EXPORT CTreeShapFeatureImportance {
     const TStrVec& columnNames() const;
 
     //! Get the baseline.
-    double baseline(std::size_t classIdx = 0) const;
+    TVector baseline() const;
 
 private:
     //! Collects the elements of the path through the decision tree that are updated together
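Note: the header changes encode the contract stated in the new doc comment, namely that SHAP values are additive corrections to a per-model baseline. The snippet below is a minimal sketch, not ml-cpp code (all names are invented); it shows the identity the new API is built around: for any row, the model output equals the baseline plus the sum of that row's per-feature SHAP values.

    #include <numeric>
    #include <vector>

    // Hypothetical helper: reconstruct a regression prediction from the
    // reported baseline and the row's per-feature SHAP values.
    double reconstructPrediction(double baseline, const std::vector<double>& shap) {
        return std::accumulate(shap.begin(), shap.end(), baseline);
    }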
diff --git a/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc b/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc
index 7e98620702..78730cb95c 100644
--- a/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc
+++ b/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc
@@ -169,74 +169,61 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
             [this](const std::string& categoryValue, core::CRapidJsonConcurrentLineWriter& writer) {
                 this->writePredictedCategoryValue(categoryValue, writer);
             });
-        featureImportance->shap(row, [&](const maths::CTreeShapFeatureImportance::TSizeVec& indices,
-                                         const TStrVec& featureNames,
-                                         const maths::CTreeShapFeatureImportance::TVectorVec& shap) {
-            writer.Key(FEATURE_IMPORTANCE_FIELD_NAME);
-            writer.StartArray();
-            TDoubleVec baseline;
-            baseline.reserve(numberClasses);
-            for (std::size_t j = 0; j < shap[0].size() && j < numberClasses; ++j) {
-                baseline.push_back(featureImportance->baseline(j));
-            }
-            for (auto i : indices) {
-                if (shap[i].norm() != 0.0) {
-                    writer.StartObject();
-                    writer.Key(FEATURE_NAME_FIELD_NAME);
-                    writer.String(featureNames[i]);
-                    if (shap[i].size() == 1) {
-                        // output feature importance for individual classes in binary case
-                        writer.Key(CLASSES_FIELD_NAME);
-                        writer.StartArray();
-                        for (std::size_t j = 0; j < numberClasses; ++j) {
-                            writer.StartObject();
-                            writer.Key(CLASS_NAME_FIELD_NAME);
-                            writePredictedCategoryValue(classValues[j], writer);
-                            writer.Key(IMPORTANCE_FIELD_NAME);
-                            if (j == 1) {
-                                writer.Double(shap[i](0));
-                            } else {
-                                writer.Double(-shap[i](0));
+        featureImportance->shap(
+            row, [&](const maths::CTreeShapFeatureImportance::TSizeVec& indices,
+                     const TStrVec& featureNames,
+                     const maths::CTreeShapFeatureImportance::TVectorVec& shap) {
+                writer.Key(FEATURE_IMPORTANCE_FIELD_NAME);
+                writer.StartArray();
+                for (auto i : indices) {
+                    if (shap[i].norm() != 0.0) {
+                        writer.StartObject();
+                        writer.Key(FEATURE_NAME_FIELD_NAME);
+                        writer.String(featureNames[i]);
+                        if (shap[i].size() == 1) {
+                            // output feature importance for individual classes in binary case
+                            writer.Key(CLASSES_FIELD_NAME);
+                            writer.StartArray();
+                            for (std::size_t j = 0; j < numberClasses; ++j) {
+                                writer.StartObject();
+                                writer.Key(CLASS_NAME_FIELD_NAME);
+                                writePredictedCategoryValue(classValues[j], writer);
+                                writer.Key(IMPORTANCE_FIELD_NAME);
+                                if (j == 1) {
+                                    writer.Double(shap[i](0));
+                                } else {
+                                    writer.Double(-shap[i](0));
+                                }
+                                writer.EndObject();
                             }
-                            writer.EndObject();
-                        }
-                        writer.EndArray();
-                    } else {
-                        // output feature importance for individual classes in multiclass case
-                        writer.Key(CLASSES_FIELD_NAME);
-                        writer.StartArray();
-                        TDoubleVec featureImportanceSum(numberClasses, 0.0);
-                        for (std::size_t j = 0;
-                             j < shap[i].size() && j < numberClasses; ++j) {
-                            for (auto k : indices) {
-                                featureImportanceSum[j] += shap[k](j);
+                            writer.EndArray();
+                        } else {
+                            // output feature importance for individual classes in multiclass case
+                            writer.Key(CLASSES_FIELD_NAME);
+                            writer.StartArray();
+                            for (std::size_t j = 0;
+                                 j < shap[i].size() && j < numberClasses; ++j) {
+                                writer.StartObject();
+                                writer.Key(CLASS_NAME_FIELD_NAME);
+                                writePredictedCategoryValue(classValues[j], writer);
+                                writer.Key(IMPORTANCE_FIELD_NAME);
+                                writer.Double(shap[i](j));
+                                writer.EndObject();
                             }
+                            writer.EndArray();
                         }
-                        for (std::size_t j = 0;
-                             j < shap[i].size() && j < numberClasses; ++j) {
-                            writer.StartObject();
-                            writer.Key(CLASS_NAME_FIELD_NAME);
-                            writePredictedCategoryValue(classValues[j], writer);
-                            writer.Key(IMPORTANCE_FIELD_NAME);
-                            double correctedShap{
-                                shap[i](j) * (baseline[j] / featureImportanceSum[j] + 1.0)};
-                            writer.Double(correctedShap);
-                            writer.EndObject();
-                        }
-                        writer.EndArray();
+                        writer.EndObject();
                     }
-                    writer.EndObject();
                 }
-            }
-            writer.EndArray();
+                writer.EndArray();
 
-            for (std::size_t i = 0; i < shap.size(); ++i) {
-                if (shap[i].lpNorm<1>() != 0) {
-                    const_cast<CDataFrameTrainBoostedTreeClassifierRunner*>(this)
-                        ->m_InferenceModelMetadata.addToFeatureImportance(i, shap[i]);
+                for (std::size_t i = 0; i < shap.size(); ++i) {
+                    if (shap[i].lpNorm<1>() != 0) {
+                        const_cast<CDataFrameTrainBoostedTreeClassifierRunner*>(this)
+                            ->m_InferenceModelMetadata.addToFeatureImportance(i, shap[i]);
+                    }
                 }
-            }
-        });
+            });
     }
     writer.EndObject();
 }
@@ -306,6 +293,10 @@ CDataFrameTrainBoostedTreeClassifierRunner::inferenceModelDefinition(
 
 CDataFrameAnalysisRunner::TOptionalInferenceModelMetadata
 CDataFrameTrainBoostedTreeClassifierRunner::inferenceModelMetadata() const {
+    const auto& featureImportance = this->boostedTree().shap();
+    if (featureImportance) {
+        m_InferenceModelMetadata.featureImportanceBaseline(featureImportance->baseline());
+    }
     return m_InferenceModelMetadata;
 }
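Note: the removed correctedShap block spread the baseline multiplicatively across per-feature SHAP values; that hack is no longer needed because the baseline is now reported once in the model metadata. In the binary case the writer reports a single SHAP value per feature and emits it with opposite signs for the two classes, since the log-odds of the two classes are mirror images. A minimal sketch of that convention (illustrative names, not ml-cpp code):

    #include <cstddef>

    // For a two-class model with per-feature SHAP value s, defined with
    // respect to class 1, class 1 is assigned s and class 0 is assigned -s.
    double importanceForClass(double s, std::size_t classIndex) {
        return classIndex == 1 ? s : -s;
    }

The same sign convention is applied to the baseline in CInferenceModelMetadata.cc below, which keeps the test invariant baselineFoo == -baselineBar exact.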
diff --git a/lib/api/CDataFrameTrainBoostedTreeRegressionRunner.cc b/lib/api/CDataFrameTrainBoostedTreeRegressionRunner.cc
index daa0c34b95..da38052361 100644
--- a/lib/api/CDataFrameTrainBoostedTreeRegressionRunner.cc
+++ b/lib/api/CDataFrameTrainBoostedTreeRegressionRunner.cc
@@ -155,7 +155,11 @@ CDataFrameTrainBoostedTreeRegressionRunner::inferenceModelDefinition(
 
 CDataFrameAnalysisRunner::TOptionalInferenceModelMetadata
 CDataFrameTrainBoostedTreeRegressionRunner::inferenceModelMetadata() const {
-    return TOptionalInferenceModelMetadata(m_InferenceModelMetadata);
+    const auto& featureImportance = this->boostedTree().shap();
+    if (featureImportance) {
+        m_InferenceModelMetadata.featureImportanceBaseline(featureImportance->baseline());
+    }
+    return m_InferenceModelMetadata;
 }
 
 // clang-format off
diff --git a/lib/api/CInferenceModelMetadata.cc b/lib/api/CInferenceModelMetadata.cc
index 2a6553d581..bbc4605533 100644
--- a/lib/api/CInferenceModelMetadata.cc
+++ b/lib/api/CInferenceModelMetadata.cc
@@ -12,6 +12,7 @@ namespace api {
 
 void CInferenceModelMetadata::write(TRapidJsonWriter& writer) const {
     this->writeTotalFeatureImportance(writer);
+    this->writeFeatureImportanceBaseline(writer);
 }
 
 void CInferenceModelMetadata::writeTotalFeatureImportance(TRapidJsonWriter& writer) const {
@@ -88,6 +89,53 @@ void CInferenceModelMetadata::writeTotalFeatureImportance(TRapidJsonWriter& writ
     writer.EndArray();
 }
 
+void CInferenceModelMetadata::writeFeatureImportanceBaseline(TRapidJsonWriter& writer) const {
+    if (m_ShapBaseline) {
+        writer.Key(JSON_FEATURE_IMPORTANCE_BASELINE_TAG);
+        writer.StartObject();
+        if (m_ShapBaseline->size() == 1 && m_ClassValues.empty()) {
+            // Regression
+            writer.Key(JSON_BASELINE_TAG);
+            writer.Double(m_ShapBaseline.get()(0));
+        } else if (m_ShapBaseline->size() == 1 && m_ClassValues.empty() == false) {
+            // Binary classification
+            writer.Key(JSON_CLASSES_TAG);
+            writer.StartArray();
+            for (std::size_t j = 0; j < m_ClassValues.size(); ++j) {
+                writer.StartObject();
+                writer.Key(JSON_CLASS_NAME_TAG);
+                m_PredictionFieldTypeResolverWriter(m_ClassValues[j], writer);
+                writer.Key(JSON_BASELINE_TAG);
+                if (j == 1) {
+                    writer.Double(m_ShapBaseline.get()(0));
+                } else {
+                    writer.Double(-m_ShapBaseline.get()(0));
+                }
+                writer.EndObject();
+            }
+
+            writer.EndArray();
+
+        } else {
+            // Multiclass classification
+            writer.Key(JSON_CLASSES_TAG);
+            writer.StartArray();
+            for (std::size_t j = 0; j < static_cast<std::size_t>(m_ShapBaseline->size()) &&
+                                    j < m_ClassValues.size();
+                 ++j) {
+                writer.StartObject();
+                writer.Key(JSON_CLASS_NAME_TAG);
+                m_PredictionFieldTypeResolverWriter(m_ClassValues[j], writer);
+                writer.Key(JSON_BASELINE_TAG);
+                writer.Double(m_ShapBaseline.get()(j));
+                writer.EndObject();
+            }
+            writer.EndArray();
+        }
+        writer.EndObject();
+    }
+}
+
 const std::string& CInferenceModelMetadata::typeString() const {
     return JSON_MODEL_METADATA_TAG;
 }
@@ -119,7 +167,13 @@ void CInferenceModelMetadata::addToFeatureImportance(std::size_t i, const TVecto
     }
 }
 
+void CInferenceModelMetadata::featureImportanceBaseline(TVector&& baseline) {
+    m_ShapBaseline = baseline;
+}
+
 // clang-format off
+const std::string CInferenceModelMetadata::JSON_BASELINE_TAG{"baseline"};
+const std::string CInferenceModelMetadata::JSON_FEATURE_IMPORTANCE_BASELINE_TAG{"feature_importance_baseline"};
const std::string CInferenceModelMetadata::JSON_CLASS_NAME_TAG{"class_name"};
 const std::string CInferenceModelMetadata::JSON_CLASSES_TAG{"classes"};
 const std::string CInferenceModelMetadata::JSON_FEATURE_NAME_TAG{"feature_name"};
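Note: with these changes the model metadata gains a feature_importance_baseline object; a regression model carries a single baseline field, while classification carries one baseline per class. The shapes below are a sketch of the emitted JSON derived from the writer code above; the numeric values and class names are invented for illustration.

Regression:

    "feature_importance_baseline": {"baseline": 42.5}

Binary or multiclass classification:

    "feature_importance_baseline": {
        "classes": [
            {"class_name": "foo", "baseline": -0.52},
            {"class_name": "bar", "baseline": 0.52}
        ]
    }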
diff --git a/lib/api/unittest/CDataFrameAnalyzerFeatureImportanceTest.cc b/lib/api/unittest/CDataFrameAnalyzerFeatureImportanceTest.cc
index d55624f3a4..5706127052 100644
--- a/lib/api/unittest/CDataFrameAnalyzerFeatureImportanceTest.cc
+++ b/lib/api/unittest/CDataFrameAnalyzerFeatureImportanceTest.cc
@@ -487,6 +487,41 @@ double readTotalShapValue(const RESULTS& results, std::string shapField, std::st
     }
     return 0.0;
 }
+
+template <typename RESULTS>
+double readBaselineValue(const RESULTS& results) {
+    using TModelMetadata = api::CInferenceModelMetadata;
+    for (const auto& result : results.GetArray()) {
+        if (result.HasMember(TModelMetadata::JSON_MODEL_METADATA_TAG) &&
+            result[TModelMetadata::JSON_MODEL_METADATA_TAG].HasMember(
+                TModelMetadata::JSON_FEATURE_IMPORTANCE_BASELINE_TAG)) {
+            return result[TModelMetadata::JSON_MODEL_METADATA_TAG][TModelMetadata::JSON_FEATURE_IMPORTANCE_BASELINE_TAG]
+                         [TModelMetadata::JSON_BASELINE_TAG]
+                             .GetDouble();
+        }
+    }
+    return 0.0;
+}
+
+template <typename RESULTS>
+double readBaselineValue(const RESULTS& results, std::string className) {
+    using TModelMetadata = api::CInferenceModelMetadata;
+    for (const auto& result : results.GetArray()) {
+        if (result.HasMember(TModelMetadata::JSON_MODEL_METADATA_TAG) &&
+            result[TModelMetadata::JSON_MODEL_METADATA_TAG].HasMember(
+                TModelMetadata::JSON_FEATURE_IMPORTANCE_BASELINE_TAG)) {
+            for (const auto& item :
+                 result[TModelMetadata::JSON_MODEL_METADATA_TAG][TModelMetadata::JSON_FEATURE_IMPORTANCE_BASELINE_TAG]
+                       [TModelMetadata::JSON_CLASSES_TAG]
+                           .GetArray()) {
+                if (item[TModelMetadata::JSON_CLASS_NAME_TAG].GetString() == className) {
+                    return item[TModelMetadata::JSON_BASELINE_TAG].GetDouble();
+                }
+            }
+        }
+    }
+    return 0.0;
+}
 }
 
 BOOST_FIXTURE_TEST_CASE(testRegressionFeatureImportanceAllShap, SFixture) {
@@ -509,14 +544,7 @@ BOOST_FIXTURE_TEST_CASE(testRegressionFeatureImportanceAllShap, SFixture) {
 
     double c1TotalShapActual{0.0}, c2TotalShapActual{0.0}, c3TotalShapActual{0.0}, c4TotalShapActual{0.0};
     bool hasTotalFeatureImportance{false};
-    for (const auto& result : results.GetArray()) {
-        if (result.HasMember("row_results")) {
-            double prediction{
-                result["row_results"]["results"]["ml"]["target_prediction"].GetDouble()};
-            baselineAccumulator.add(prediction);
-        }
-    }
-    double baseline{maths::CBasicStatistics::mean(baselineAccumulator)};
+    double baseline{readBaselineValue(results)};
     for (const auto& result : results.GetArray()) {
         if (result.HasMember("row_results")) {
             double c1{readShapValue(result, "c1")};
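Note: the fixtures now read the baseline from the reported model metadata instead of estimating it by averaging predictions, so the regression test can check local accuracy directly. A self-contained sketch of the identity the test relies on (helper name and tolerance are invented, not test code):

    #include <cmath>
    #include <numeric>
    #include <vector>

    // True when the row's prediction equals the baseline plus the sum of its
    // SHAP values, up to the given tolerance.
    bool locallyAccurate(double prediction, double baseline,
                         const std::vector<double>& shap, double tolerance) {
        double sum{std::accumulate(shap.begin(), shap.end(), baseline)};
        return std::fabs(sum - prediction) <= tolerance;
    }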
@@ -609,8 +637,6 @@ BOOST_FIXTURE_TEST_CASE(testClassificationFeatureImportanceAllShap, SFixture) {
 
     // values are indeed a local approximation of the predicted log-odds.
     std::size_t topShapValues{4};
-    TMeanAccumulator baselineFooAccumulator;
-    TMeanAccumulator baselineBarAccumulator;
     auto resultsPair{runBinaryClassification(topShapValues, {0.5, -0.7, 0.2, -0.2})};
     auto results{std::move(resultsPair.first)};
     TMeanAccumulator c1TotalShapExpected;
@@ -621,29 +647,8 @@ BOOST_FIXTURE_TEST_CASE(testClassificationFeatureImportanceAllShap, SFixture) {
 
     double c1TotalShapActual[2], c2TotalShapActual[2], c3TotalShapActual[2], c4TotalShapActual[2];
     bool hasTotalFeatureImportance{false};
-    for (const auto& result : results.GetArray()) {
-        if (result.HasMember("row_results")) {
-            std::string targetPrediction{
-                result["row_results"]["results"]["ml"]["target_prediction"].GetString()};
-            double predictionProbability{
-                result["row_results"]["results"]["ml"]["prediction_probability"].GetDouble()};
-            double logOdds{std::log(predictionProbability /
-                                    (1.0 - predictionProbability + 1e-10))};
-            if (targetPrediction == "bar") {
-                // there are many ways to compute the baseline. This way generalizes to
-                // the multi-class classification
-                baselineBarAccumulator.add(logOdds);
-                baselineFooAccumulator.add(-logOdds);
-            } else if (targetPrediction == "foo") {
-                baselineFooAccumulator.add(logOdds);
-                baselineBarAccumulator.add(-logOdds);
-            } else {
-                BOOST_TEST_FAIL("Unknown predicted class " + targetPrediction);
-            }
-        }
-    }
-    double baselineFoo{maths::CBasicStatistics::mean(baselineFooAccumulator)};
-    double baselineBar{maths::CBasicStatistics::mean(baselineBarAccumulator)};
+    double baselineFoo{readBaselineValue(results, "foo")};
+    double baselineBar{readBaselineValue(results, "bar")};
     BOOST_TEST_REQUIRE(baselineFoo == -baselineBar);
     TStrVec classes{"foo", "bar"};
     for (const auto& result : results.GetArray()) {
@@ -728,19 +733,12 @@ BOOST_FIXTURE_TEST_CASE(testMultiClassClassificationFeatureImportanceAllShap, SF
     double c4TotalShapActual[3];
     bool hasTotalFeatureImportance{false};
     TStrVec classes{"foo", "bar", "baz"};
-    TMeanAccumulatorVec baselineAccumulator(3);
+    TDoubleVec baselines;
+    baselines.reserve(3);
     // get baselines
-    for (const auto& result : results.GetArray()) {
-        if (result.HasMember("row_results")) {
-            for (std::size_t i = 0; i < classes.size(); ++i) {
-                double classProbability{readClassProbability(result, classes[i])};
-                double logOdds =
-                    std::log(classProbability / (1.0 - classProbability + 1e-10));
-                baselineAccumulator[i].add(logOdds);
-            }
-        }
+    for (const auto& className : classes) {
+        baselines.push_back(readBaselineValue(results, className));
     }
-
     double localApproximations[3];
     double classProbabilities[3];
     for (const auto& result : results.GetArray()) {
@@ -766,7 +764,8 @@ BOOST_FIXTURE_TEST_CASE(testMultiClassClassificationFeatureImportanceAllShap, SF
             c4TotalShapExpected[i].add(std::abs(c4ClassName));
 
             double classProbability{readClassProbability(result, classes[i])};
-            double localApproximation{c1ClassName + c2ClassName + c3ClassName + c4ClassName};
+            double localApproximation{baselines[i] + c1ClassName +
+                                      c2ClassName + c3ClassName + c4ClassName};
             localApproximations[i] = localApproximation;
             classProbabilities[i] = classProbability;
             denominator += std::exp(localApproximation);
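Note: for multiclass classification the test recovers class probabilities by applying a softmax to the per-class local approximations, baselines[i] plus the summed SHAP values. A self-contained sketch of that mapping, mirroring the test arithmetic above (names invented):

    #include <cmath>
    #include <vector>

    // Map per-class local approximations (baseline plus summed SHAP values)
    // to class probabilities via softmax, as the multiclass test does.
    std::vector<double> softmax(const std::vector<double>& localApproximations) {
        double denominator{0.0};
        for (double a : localApproximations) {
            denominator += std::exp(a);
        }
        std::vector<double> probabilities;
        probabilities.reserve(localApproximations.size());
        for (double a : localApproximations) {
            probabilities.push_back(std::exp(a) / denominator);
        }
        return probabilities;
    }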
@@ -775,9 +774,8 @@ BOOST_FIXTURE_TEST_CASE(testMultiClassClassificationFeatureImportanceAllShap, SF
         // Test that the sum of feature importances is a local approximation of
         // prediction probabilities for all classes
         for (std::size_t i = 0; i < classes.size(); ++i) {
-            BOOST_REQUIRE_SMALL(classProbabilities[i] -
-                                    std::exp(localApproximations[i]) / denominator,
-                                1e-3);
+            BOOST_REQUIRE_CLOSE(classProbabilities[i],
+                                std::exp(localApproximations[i]) / denominator, 1.0);
         }
 
         // We should have at least one feature that is important
@@ -803,15 +801,14 @@ BOOST_FIXTURE_TEST_CASE(testMultiClassClassificationFeatureImportanceAllShap, SF
 
             LOG_INFO(<< "Incorrect results, missing total shap values: "
                      << resultsPair.second);
         }
-        // TODO now I cannot test for feature
-        // BOOST_REQUIRE_CLOSE(c1TotalShapActual[i],
-        //                     maths::CBasicStatistics::mean(c1TotalShapExpected[i]), 1.0);
-        // BOOST_REQUIRE_CLOSE(c2TotalShapActual[i],
-        //                     maths::CBasicStatistics::mean(c2TotalShapExpected[i]), 1.0);
-        // BOOST_REQUIRE_CLOSE(c3TotalShapActual[i],
-        //                     maths::CBasicStatistics::mean(c3TotalShapExpected[i]), 1.0);
-        // BOOST_REQUIRE_CLOSE(c4TotalShapActual[i],
-        //                     maths::CBasicStatistics::mean(c4TotalShapExpected[i]), 1.0);
+        BOOST_REQUIRE_CLOSE(c1TotalShapActual[i],
+                            maths::CBasicStatistics::mean(c1TotalShapExpected[i]), 1.0);
+        BOOST_REQUIRE_CLOSE(c2TotalShapActual[i],
+                            maths::CBasicStatistics::mean(c2TotalShapExpected[i]), 1.0);
+        BOOST_REQUIRE_CLOSE(c3TotalShapActual[i],
+                            maths::CBasicStatistics::mean(c3TotalShapExpected[i]), 1.0);
+        BOOST_REQUIRE_CLOSE(c4TotalShapActual[i],
+                            maths::CBasicStatistics::mean(c4TotalShapExpected[i]), 1.0);
     }
 }
diff --git a/lib/maths/CTreeShapFeatureImportance.cc b/lib/maths/CTreeShapFeatureImportance.cc
index b80ece0112..809726026e 100644
--- a/lib/maths/CTreeShapFeatureImportance.cc
+++ b/lib/maths/CTreeShapFeatureImportance.cc
@@ -367,10 +367,12 @@ const CTreeShapFeatureImportance::TStrVec& CTreeShapFeatureImportance::columnNam
     return m_ColumnNames;
 }
 
-double CTreeShapFeatureImportance::baseline(std::size_t classIdx) const {
-    double result{0.0};
+CTreeShapFeatureImportance::TVector CTreeShapFeatureImportance::baseline() const {
+    // The value of the root node, i.e. the first node in each tree, is set to the average of
+    // the tree's leaf values. So we compute the baseline simply by summing the root node values.
+    TVector result{las::zero((*m_Forest)[0][0].value())};
     for (const auto& tree : *m_Forest) {
-        result += tree[0].value();
+        result += tree[0].value();
     }
     return result;
 }
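Note: baseline() now returns a vector because each class carries its own expected value. Since every tree's root stores the average of that tree's leaf values, and a boosted-tree prediction is a sum over trees, the forest baseline is just the sum of the root values. A toy illustration with invented numbers (not ml-cpp code):

    #include <vector>

    // The forest baseline is the sum of the trees' root values; for root
    // values 1.5 and -0.5 the baseline is 1.0, the expected model output.
    double forestBaseline(const std::vector<double>& rootValues) {
        double baseline{0.0};
        for (double value : rootValues) {
            baseline += value;
        }
        return baseline;
    }
    // forestBaseline({1.5, -0.5}) == 1.0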