Skip to content

Commit

Permalink
[ML] Allow unbounded num_top_classes in classification analysis (#1526)
Browse files Browse the repository at this point in the history
  • Loading branch information
przemekwitek authored Oct 12, 2020
1 parent 499338d commit 49dd77c
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 4 deletions.
2 changes: 2 additions & 0 deletions include/api/CDataFrameAnalysisConfigReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ class API_EXPORT CDataFrameAnalysisConfigReader {
bool fallback(bool value) const;
//! Get an unsigned integer parameter.
std::size_t fallback(std::size_t value) const;
//! Get a signed integer parameter.
std::ptrdiff_t fallback(std::ptrdiff_t value) const;
//! Get a floating point parameter.
double fallback(double value) const;
//! Get a string parameter.
Expand Down
2 changes: 1 addition & 1 deletion include/api/CDataFrameTrainBoostedTreeClassifierRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final
core::CRapidJsonConcurrentLineWriter& writer) const;

private:
std::size_t m_NumTopClasses;
std::ptrdiff_t m_NumTopClasses;
EPredictionFieldType m_PredictionFieldType;
mutable CInferenceModelMetadata m_InferenceModelMetadata;
};
Expand Down
11 changes: 11 additions & 0 deletions lib/api/CDataFrameAnalysisConfigReader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,17 @@ std::size_t CDataFrameAnalysisConfigReader::CParameter::fallback(std::size_t val
return m_Value->GetUint64();
}

std::ptrdiff_t CDataFrameAnalysisConfigReader::CParameter::fallback(std::ptrdiff_t value) const {
if (m_Value == nullptr) {
return value;
}
if (m_Value->IsInt64() == false) {
this->handleFatal();
return value;
}
return m_Value->GetInt64();
}

double CDataFrameAnalysisConfigReader::CParameter::fallback(double value) const {
if (m_Value == nullptr) {
return value;
Expand Down
10 changes: 7 additions & 3 deletions lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ CDataFrameTrainBoostedTreeClassifierRunner::CDataFrameTrainBoostedTreeClassifier
: CDataFrameTrainBoostedTreeRunner{
spec, parameters, loss(parameters[NUM_CLASSES].as<std::size_t>())} {

m_NumTopClasses = parameters[NUM_TOP_CLASSES].fallback(std::size_t{0});
m_NumTopClasses = parameters[NUM_TOP_CLASSES].fallback(std::ptrdiff_t{0});
m_PredictionFieldType =
parameters[PREDICTION_FIELD_TYPE].fallback(E_PredictionFieldTypeString);
this->boostedTreeFactory().classAssignmentObjective(
Expand Down Expand Up @@ -138,14 +138,18 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
writer.Key(IS_TRAINING_FIELD_NAME);
writer.Bool(maths::CDataFrameUtils::isMissing(actualClassId) == false);

if (m_NumTopClasses > 0) {
if (m_NumTopClasses != 0) {
TSizeVec classIds(scores.size());
std::iota(classIds.begin(), classIds.end(), 0);
std::sort(classIds.begin(), classIds.end(),
[&scores](std::size_t lhs, std::size_t rhs) {
return scores[lhs] > scores[rhs];
});
classIds.resize(std::min(classIds.size(), m_NumTopClasses));
// -1 is a special value meaning "output all the classes"
classIds.resize(m_NumTopClasses == -1
? classIds.size()
: std::min(classIds.size(),
static_cast<std::size_t>(m_NumTopClasses)));
writer.Key(TOP_CLASSES_FIELD_NAME);
writer.StartArray();
for (std::size_t i : classIds) {
Expand Down

0 comments on commit 49dd77c

Please sign in to comment.