diff --git a/include/api/CDataFrameAnalysisConfigReader.h b/include/api/CDataFrameAnalysisConfigReader.h index f4a086033a..974f74a708 100644 --- a/include/api/CDataFrameAnalysisConfigReader.h +++ b/include/api/CDataFrameAnalysisConfigReader.h @@ -84,6 +84,8 @@ class API_EXPORT CDataFrameAnalysisConfigReader { bool fallback(bool value) const; //! Get an unsigned integer parameter. std::size_t fallback(std::size_t value) const; + //! Get a signed integer parameter. + std::ptrdiff_t fallback(std::ptrdiff_t value) const; //! Get a floating point parameter. double fallback(double value) const; //! Get a string parameter. diff --git a/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h b/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h index 2ee662eca5..fefeb84b82 100644 --- a/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h +++ b/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h @@ -86,7 +86,7 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final core::CRapidJsonConcurrentLineWriter& writer) const; private: - std::size_t m_NumTopClasses; + std::ptrdiff_t m_NumTopClasses; EPredictionFieldType m_PredictionFieldType; mutable CInferenceModelMetadata m_InferenceModelMetadata; }; diff --git a/lib/api/CDataFrameAnalysisConfigReader.cc b/lib/api/CDataFrameAnalysisConfigReader.cc index 20f91b8122..f98960d2b0 100644 --- a/lib/api/CDataFrameAnalysisConfigReader.cc +++ b/lib/api/CDataFrameAnalysisConfigReader.cc @@ -112,6 +112,17 @@ std::size_t CDataFrameAnalysisConfigReader::CParameter::fallback(std::size_t val return m_Value->GetUint64(); } +std::ptrdiff_t CDataFrameAnalysisConfigReader::CParameter::fallback(std::ptrdiff_t value) const { + if (m_Value == nullptr) { + return value; + } + if (m_Value->IsInt64() == false) { + this->handleFatal(); + return value; + } + return m_Value->GetInt64(); +} + double CDataFrameAnalysisConfigReader::CParameter::fallback(double value) const { if (m_Value == nullptr) { return value; diff --git a/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc b/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc index 78730cb95c..ce61c06abc 100644 --- a/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc +++ b/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc @@ -81,7 +81,7 @@ CDataFrameTrainBoostedTreeClassifierRunner::CDataFrameTrainBoostedTreeClassifier : CDataFrameTrainBoostedTreeRunner{ spec, parameters, loss(parameters[NUM_CLASSES].as())} { - m_NumTopClasses = parameters[NUM_TOP_CLASSES].fallback(std::size_t{0}); + m_NumTopClasses = parameters[NUM_TOP_CLASSES].fallback(std::ptrdiff_t{0}); m_PredictionFieldType = parameters[PREDICTION_FIELD_TYPE].fallback(E_PredictionFieldTypeString); this->boostedTreeFactory().classAssignmentObjective( @@ -138,14 +138,18 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow( writer.Key(IS_TRAINING_FIELD_NAME); writer.Bool(maths::CDataFrameUtils::isMissing(actualClassId) == false); - if (m_NumTopClasses > 0) { + if (m_NumTopClasses != 0) { TSizeVec classIds(scores.size()); std::iota(classIds.begin(), classIds.end(), 0); std::sort(classIds.begin(), classIds.end(), [&scores](std::size_t lhs, std::size_t rhs) { return scores[lhs] > scores[rhs]; }); - classIds.resize(std::min(classIds.size(), m_NumTopClasses)); + // -1 is a special value meaning "output all the classes" + classIds.resize(m_NumTopClasses == -1 + ? classIds.size() + : std::min(classIds.size(), + static_cast(m_NumTopClasses))); writer.Key(TOP_CLASSES_FIELD_NAME); writer.StartArray(); for (std::size_t i : classIds) {