Skip to content

Commit

Permalink
[ML] Store feature importance baselines in model metadata (#1522)
Browse files Browse the repository at this point in the history
With this PR we will be able to store the feature importance baselines explicitly in the model_metadata. Being able
to retrieve the baselines will significantly simplify UI code related to the feature importance visualization.
  • Loading branch information
valeriy42 authored Oct 5, 2020
1 parent 88c8951 commit 3fc14ae
Show file tree
Hide file tree
Showing 7 changed files with 179 additions and 123 deletions.
8 changes: 8 additions & 0 deletions include/api/CInferenceModelMetadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ namespace api {
//! (such as total feature importance) into JSON format.
class API_EXPORT CInferenceModelMetadata {
public:
static const std::string JSON_BASELINE_TAG;
static const std::string JSON_FEATURE_IMPORTANCE_BASELINE_TAG;
static const std::string JSON_CLASS_NAME_TAG;
static const std::string JSON_CLASSES_TAG;
static const std::string JSON_FEATURE_NAME_TAG;
Expand Down Expand Up @@ -48,20 +50,26 @@ class API_EXPORT CInferenceModelMetadata {
//! Add importances \p values to the feature with index \p i to calculate total feature importance.
//! Total feature importance is the mean of the magnitudes of importances for individual data points.
void addToFeatureImportance(std::size_t i, const TVector& values);
//! Set the feature importance baseline (the individual feature importances are additive corrections
//! to the baseline value).
void featureImportanceBaseline(TVector&& baseline);

private:
using TMeanAccumulator =
std::vector<maths::CBasicStatistics::SSampleMean<double>::TAccumulator>;
using TMinMaxAccumulator = std::vector<maths::CBasicStatistics::CMinMax<double>>;
using TSizeMeanAccumulatorUMap = std::unordered_map<std::size_t, TMeanAccumulator>;
using TSizeMinMaxAccumulatorUMap = std::unordered_map<std::size_t, TMinMaxAccumulator>;
using TOptionalVector = boost::optional<TVector>;

private:
void writeTotalFeatureImportance(TRapidJsonWriter& writer) const;
void writeFeatureImportanceBaseline(TRapidJsonWriter& writer) const;

private:
TSizeMeanAccumulatorUMap m_TotalShapValuesMean;
TSizeMinMaxAccumulatorUMap m_TotalShapValuesMinMax;
TOptionalVector m_ShapBaseline;
TStrVec m_ColumnNames;
TStrVec m_ClassValues;
TPredictionFieldTypeResolverWriter m_PredictionFieldTypeResolverWriter =
Expand Down
2 changes: 1 addition & 1 deletion include/maths/CTreeShapFeatureImportance.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class MATHS_EXPORT CTreeShapFeatureImportance {
const TStrVec& columnNames() const;

//! Get the baseline.
double baseline(std::size_t classIdx = 0) const;
TVector baseline() const;

private:
//! Collects the elements of the path through decision tree that are updated together
Expand Down
113 changes: 52 additions & 61 deletions lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,74 +169,61 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
[this](const std::string& categoryValue, core::CRapidJsonConcurrentLineWriter& writer) {
this->writePredictedCategoryValue(categoryValue, writer);
});
featureImportance->shap(row, [&](const maths::CTreeShapFeatureImportance::TSizeVec& indices,
const TStrVec& featureNames,
const maths::CTreeShapFeatureImportance::TVectorVec& shap) {
writer.Key(FEATURE_IMPORTANCE_FIELD_NAME);
writer.StartArray();
TDoubleVec baseline;
baseline.reserve(numberClasses);
for (std::size_t j = 0; j < shap[0].size() && j < numberClasses; ++j) {
baseline.push_back(featureImportance->baseline(j));
}
for (auto i : indices) {
if (shap[i].norm() != 0.0) {
writer.StartObject();
writer.Key(FEATURE_NAME_FIELD_NAME);
writer.String(featureNames[i]);
if (shap[i].size() == 1) {
// output feature importance for individual classes in binary case
writer.Key(CLASSES_FIELD_NAME);
writer.StartArray();
for (std::size_t j = 0; j < numberClasses; ++j) {
writer.StartObject();
writer.Key(CLASS_NAME_FIELD_NAME);
writePredictedCategoryValue(classValues[j], writer);
writer.Key(IMPORTANCE_FIELD_NAME);
if (j == 1) {
writer.Double(shap[i](0));
} else {
writer.Double(-shap[i](0));
featureImportance->shap(
row, [&](const maths::CTreeShapFeatureImportance::TSizeVec& indices,
const TStrVec& featureNames,
const maths::CTreeShapFeatureImportance::TVectorVec& shap) {
writer.Key(FEATURE_IMPORTANCE_FIELD_NAME);
writer.StartArray();
for (auto i : indices) {
if (shap[i].norm() != 0.0) {
writer.StartObject();
writer.Key(FEATURE_NAME_FIELD_NAME);
writer.String(featureNames[i]);
if (shap[i].size() == 1) {
// output feature importance for individual classes in binary case
writer.Key(CLASSES_FIELD_NAME);
writer.StartArray();
for (std::size_t j = 0; j < numberClasses; ++j) {
writer.StartObject();
writer.Key(CLASS_NAME_FIELD_NAME);
writePredictedCategoryValue(classValues[j], writer);
writer.Key(IMPORTANCE_FIELD_NAME);
if (j == 1) {
writer.Double(shap[i](0));
} else {
writer.Double(-shap[i](0));
}
writer.EndObject();
}
writer.EndObject();
}
writer.EndArray();
} else {
// output feature importance for individual classes in multiclass case
writer.Key(CLASSES_FIELD_NAME);
writer.StartArray();
TDoubleVec featureImportanceSum(numberClasses, 0.0);
for (std::size_t j = 0;
j < shap[i].size() && j < numberClasses; ++j) {
for (auto k : indices) {
featureImportanceSum[j] += shap[k](j);
writer.EndArray();
} else {
// output feature importance for individual classes in multiclass case
writer.Key(CLASSES_FIELD_NAME);
writer.StartArray();
for (std::size_t j = 0;
j < shap[i].size() && j < numberClasses; ++j) {
writer.StartObject();
writer.Key(CLASS_NAME_FIELD_NAME);
writePredictedCategoryValue(classValues[j], writer);
writer.Key(IMPORTANCE_FIELD_NAME);
writer.Double(shap[i](j));
writer.EndObject();
}
writer.EndArray();
}
for (std::size_t j = 0;
j < shap[i].size() && j < numberClasses; ++j) {
writer.StartObject();
writer.Key(CLASS_NAME_FIELD_NAME);
writePredictedCategoryValue(classValues[j], writer);
writer.Key(IMPORTANCE_FIELD_NAME);
double correctedShap{
shap[i](j) * (baseline[j] / featureImportanceSum[j] + 1.0)};
writer.Double(correctedShap);
writer.EndObject();
}
writer.EndArray();
writer.EndObject();
}
writer.EndObject();
}
}
writer.EndArray();
writer.EndArray();

for (std::size_t i = 0; i < shap.size(); ++i) {
if (shap[i].lpNorm<1>() != 0) {
const_cast<CDataFrameTrainBoostedTreeClassifierRunner*>(this)
->m_InferenceModelMetadata.addToFeatureImportance(i, shap[i]);
for (std::size_t i = 0; i < shap.size(); ++i) {
if (shap[i].lpNorm<1>() != 0) {
const_cast<CDataFrameTrainBoostedTreeClassifierRunner*>(this)
->m_InferenceModelMetadata.addToFeatureImportance(i, shap[i]);
}
}
}
});
});
}
writer.EndObject();
}
Expand Down Expand Up @@ -306,6 +293,10 @@ CDataFrameTrainBoostedTreeClassifierRunner::inferenceModelDefinition(

CDataFrameAnalysisRunner::TOptionalInferenceModelMetadata
CDataFrameTrainBoostedTreeClassifierRunner::inferenceModelMetadata() const {
    // Refresh the stored SHAP baseline from the tree's feature importance
    // calculator (when SHAP was computed) before handing the metadata out
    // for serialization.
    const auto& featureImportance = this->boostedTree().shap();
    if (featureImportance) {
        // NOTE(review): mutating m_InferenceModelMetadata inside a const
        // member function implies the member is declared mutable — confirm
        // against the class declaration.
        m_InferenceModelMetadata.featureImportanceBaseline(featureImportance->baseline());
    }
    return m_InferenceModelMetadata;
}

Expand Down
6 changes: 5 additions & 1 deletion lib/api/CDataFrameTrainBoostedTreeRegressionRunner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,11 @@ CDataFrameTrainBoostedTreeRegressionRunner::inferenceModelDefinition(

CDataFrameAnalysisRunner::TOptionalInferenceModelMetadata
CDataFrameTrainBoostedTreeRegressionRunner::inferenceModelMetadata() const {
    // Refresh the stored SHAP baseline from the tree's feature importance
    // calculator (when SHAP was computed) before handing the metadata out
    // for serialization.
    //
    // Fix: the block previously contained a stray early
    // `return TOptionalInferenceModelMetadata(m_InferenceModelMetadata);`
    // ahead of these statements, which exited before the baseline was ever
    // written and made the remainder unreachable. The single return below
    // is the intended behavior, matching the classifier runner.
    const auto& featureImportance = this->boostedTree().shap();
    if (featureImportance) {
        m_InferenceModelMetadata.featureImportanceBaseline(featureImportance->baseline());
    }
    return m_InferenceModelMetadata;
}

// clang-format off
Expand Down
54 changes: 54 additions & 0 deletions lib/api/CInferenceModelMetadata.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ namespace api {

void CInferenceModelMetadata::write(TRapidJsonWriter& writer) const {
    // Serialize every metadata field into the supplied JSON writer:
    // total feature importance first, then the feature importance baseline.
    this->writeTotalFeatureImportance(writer);
    this->writeFeatureImportanceBaseline(writer);
}

void CInferenceModelMetadata::writeTotalFeatureImportance(TRapidJsonWriter& writer) const {
Expand Down Expand Up @@ -88,6 +89,53 @@ void CInferenceModelMetadata::writeTotalFeatureImportance(TRapidJsonWriter& writ
writer.EndArray();
}

void CInferenceModelMetadata::writeFeatureImportanceBaseline(TRapidJsonWriter& writer) const {
    // Write the SHAP baseline(s) as a "feature_importance_baseline" JSON
    // object. Individual per-row feature importances are additive
    // corrections to this baseline. No-op until featureImportanceBaseline()
    // has stored a value.
    if (m_ShapBaseline) {
        writer.Key(JSON_FEATURE_IMPORTANCE_BASELINE_TAG);
        writer.StartObject();
        if (m_ShapBaseline->size() == 1 && m_ClassValues.empty()) {
            // Regression: a single scalar baseline.
            writer.Key(JSON_BASELINE_TAG);
            writer.Double(m_ShapBaseline.get()(0));
        } else if (m_ShapBaseline->size() == 1 && m_ClassValues.empty() == false) {
            // Binary classification: one scalar is stored; the baseline for
            // the second class (j == 1) is the stored value and for the
            // first class its negation.
            writer.Key(JSON_CLASSES_TAG);
            writer.StartArray();
            for (std::size_t j = 0; j < m_ClassValues.size(); ++j) {
                writer.StartObject();
                writer.Key(JSON_CLASS_NAME_TAG);
                m_PredictionFieldTypeResolverWriter(m_ClassValues[j], writer);
                writer.Key(JSON_BASELINE_TAG);
                if (j == 1) {
                    writer.Double(m_ShapBaseline.get()(0));
                } else {
                    writer.Double(-m_ShapBaseline.get()(0));
                }
                writer.EndObject();
            }

            writer.EndArray();

        } else {
            // Multiclass classification: one baseline component per class.
            // Bound the loop by both sizes in case they disagree.
            writer.Key(JSON_CLASSES_TAG);
            writer.StartArray();
            for (std::size_t j = 0; j < static_cast<std::size_t>(m_ShapBaseline->size()) &&
                                    j < m_ClassValues.size();
                 ++j) {
                writer.StartObject();
                writer.Key(JSON_CLASS_NAME_TAG);
                m_PredictionFieldTypeResolverWriter(m_ClassValues[j], writer);
                writer.Key(JSON_BASELINE_TAG);
                writer.Double(m_ShapBaseline.get()(j));
                writer.EndObject();
            }
            writer.EndArray();
        }
        writer.EndObject();
    }
}

const std::string& CInferenceModelMetadata::typeString() const {
    // Tag identifying this object as model metadata in the JSON output.
    return JSON_MODEL_METADATA_TAG;
}
Expand Down Expand Up @@ -119,7 +167,13 @@ void CInferenceModelMetadata::addToFeatureImportance(std::size_t i, const TVecto
}
}

void CInferenceModelMetadata::featureImportanceBaseline(TVector&& baseline) {
    // Take ownership of the supplied baseline vector.
    //
    // Fix: a named rvalue-reference parameter is itself an lvalue, so the
    // previous `m_ShapBaseline = baseline;` copied the vector, defeating
    // the purpose of the && interface. std::move is required to actually
    // move into the optional member.
    m_ShapBaseline = std::move(baseline);
}

// clang-format off
// JSON field names used when serializing the model metadata.
const std::string CInferenceModelMetadata::JSON_BASELINE_TAG{"baseline"};
const std::string CInferenceModelMetadata::JSON_FEATURE_IMPORTANCE_BASELINE_TAG{"feature_importance_baseline"};
const std::string CInferenceModelMetadata::JSON_CLASS_NAME_TAG{"class_name"};
const std::string CInferenceModelMetadata::JSON_CLASSES_TAG{"classes"};
const std::string CInferenceModelMetadata::JSON_FEATURE_NAME_TAG{"feature_name"};
Expand Down
Loading

0 comments on commit 3fc14ae

Please sign in to comment.