diff --git a/lib/maths/CBoostedTreeFactory.cc b/lib/maths/CBoostedTreeFactory.cc index 8acaac45ce..61048b20b8 100644 --- a/lib/maths/CBoostedTreeFactory.cc +++ b/lib/maths/CBoostedTreeFactory.cc @@ -181,42 +181,70 @@ void CBoostedTreeFactory::initializeHyperparameterOptimisation() const { // less than p_1, this translates to using log parameter values. CBayesianOptimisation::TDoubleDoublePrVec boundingBox; - if (m_TreeImpl->m_DownsampleFactorOverride == boost::none) { - boundingBox.emplace_back( - m_LogDownsampleFactorSearchInterval(MIN_REGULARIZER_INDEX), - m_LogDownsampleFactorSearchInterval(MAX_REGULARIZER_INDEX)); - } - if (m_TreeImpl->m_RegularizationOverride.depthPenaltyMultiplier() == boost::none) { - boundingBox.emplace_back( - m_LogDepthPenaltyMultiplierSearchInterval(MIN_REGULARIZER_INDEX), - m_LogDepthPenaltyMultiplierSearchInterval(MAX_REGULARIZER_INDEX)); - } - if (m_TreeImpl->m_RegularizationOverride.leafWeightPenaltyMultiplier() == boost::none) { - boundingBox.emplace_back( - m_LogLeafWeightPenaltyMultiplierSearchInterval(MIN_REGULARIZER_INDEX), - m_LogLeafWeightPenaltyMultiplierSearchInterval(MAX_REGULARIZER_INDEX)); - } - if (m_TreeImpl->m_RegularizationOverride.treeSizePenaltyMultiplier() == boost::none) { - boundingBox.emplace_back( - m_LogTreeSizePenaltyMultiplierSearchInterval(MIN_REGULARIZER_INDEX), - m_LogTreeSizePenaltyMultiplierSearchInterval(MAX_REGULARIZER_INDEX)); - } - if (m_TreeImpl->m_RegularizationOverride.softTreeDepthLimit() == boost::none) { - boundingBox.emplace_back(m_SoftDepthLimitSearchInterval(MIN_REGULARIZER_INDEX), - m_SoftDepthLimitSearchInterval(MAX_REGULARIZER_INDEX)); - } - if (m_TreeImpl->m_RegularizationOverride.softTreeDepthTolerance() == boost::none) { - boundingBox.emplace_back(MIN_SOFT_DEPTH_LIMIT_TOLERANCE, MAX_SOFT_DEPTH_LIMIT_TOLERANCE); - } - if (m_TreeImpl->m_EtaOverride == boost::none) { - double rate{m_TreeImpl->m_EtaGrowthRatePerTree - 1.0}; - boundingBox.emplace_back(m_LogEtaSearchInterval(MIN_REGULARIZER_INDEX), - m_LogEtaSearchInterval(MAX_REGULARIZER_INDEX)); - boundingBox.emplace_back(1.0 + MIN_ETA_GROWTH_RATE_SCALE * rate, - 1.0 + MAX_ETA_GROWTH_RATE_SCALE * rate); - } - if (m_TreeImpl->m_FeatureBagFractionOverride == boost::none) { - boundingBox.emplace_back(MIN_FEATURE_BAG_FRACTION, MAX_FEATURE_BAG_FRACTION); + for (int i = 0; i < static_cast(NUMBER_HYPERPARAMETERS); ++i) { + switch (i) { + case E_DownsampleFactor: + if (m_TreeImpl->m_DownsampleFactorOverride == boost::none) { + boundingBox.emplace_back( + m_LogDownsampleFactorSearchInterval(MIN_REGULARIZER_INDEX), + m_LogDownsampleFactorSearchInterval(MAX_REGULARIZER_INDEX)); + } + break; + case E_Alpha: + if (m_TreeImpl->m_RegularizationOverride.depthPenaltyMultiplier() == boost::none) { + boundingBox.emplace_back( + m_LogDepthPenaltyMultiplierSearchInterval(MIN_REGULARIZER_INDEX), + m_LogDepthPenaltyMultiplierSearchInterval(MAX_REGULARIZER_INDEX)); + } + break; + case E_Lambda: + if (m_TreeImpl->m_RegularizationOverride.leafWeightPenaltyMultiplier() == + boost::none) { + boundingBox.emplace_back( + m_LogLeafWeightPenaltyMultiplierSearchInterval(MIN_REGULARIZER_INDEX), + m_LogLeafWeightPenaltyMultiplierSearchInterval(MAX_REGULARIZER_INDEX)); + } + break; + case E_Gamma: + if (m_TreeImpl->m_RegularizationOverride.treeSizePenaltyMultiplier() == + boost::none) { + boundingBox.emplace_back( + m_LogTreeSizePenaltyMultiplierSearchInterval(MIN_REGULARIZER_INDEX), + m_LogTreeSizePenaltyMultiplierSearchInterval(MAX_REGULARIZER_INDEX)); + } + break; + case E_SoftTreeDepthLimit: + if (m_TreeImpl->m_RegularizationOverride.softTreeDepthLimit() == boost::none) { + boundingBox.emplace_back( + m_SoftDepthLimitSearchInterval(MIN_REGULARIZER_INDEX), + m_SoftDepthLimitSearchInterval(MAX_REGULARIZER_INDEX)); + } + break; + case E_SoftTreeDepthTolerance: + if (m_TreeImpl->m_RegularizationOverride.softTreeDepthTolerance() == boost::none) { + boundingBox.emplace_back(MIN_SOFT_DEPTH_LIMIT_TOLERANCE, + MAX_SOFT_DEPTH_LIMIT_TOLERANCE); + } + break; + case E_Eta: + if (m_TreeImpl->m_EtaOverride == boost::none) { + boundingBox.emplace_back(m_LogEtaSearchInterval(MIN_REGULARIZER_INDEX), + m_LogEtaSearchInterval(MAX_REGULARIZER_INDEX)); + } + break; + case E_EtaGrowthRatePerTree: + if (m_TreeImpl->m_EtaOverride == boost::none) { + double rate{m_TreeImpl->m_EtaGrowthRatePerTree - 1.0}; + boundingBox.emplace_back(1.0 + MIN_ETA_GROWTH_RATE_SCALE * rate, + 1.0 + MAX_ETA_GROWTH_RATE_SCALE * rate); + } + break; + case E_FeatureBagFraction: + if (m_TreeImpl->m_FeatureBagFractionOverride == boost::none) { + boundingBox.emplace_back(MIN_FEATURE_BAG_FRACTION, MAX_FEATURE_BAG_FRACTION); + } + break; + } } LOG_TRACE(<< "hyperparameter search bounding box = " << core::CContainerPrinter::print(boundingBox)); diff --git a/lib/maths/CBoostedTreeImpl.cc b/lib/maths/CBoostedTreeImpl.cc index a396dbdd36..c08f1a94eb 100644 --- a/lib/maths/CBoostedTreeImpl.cc +++ b/lib/maths/CBoostedTreeImpl.cc @@ -1229,40 +1229,50 @@ bool CBoostedTreeImpl::selectNextHyperparameters(const TMeanVarAccumulator& loss // of each of the base learners so we scale the other regularisation terms // and the weight shrinkage to compensate. double scale{1.0}; - if (minBoundary.size() > 0) { - scale = std::min(scale, 2.0 * m_DownsampleFactor / - (CTools::stableExp(minBoundary(0)) + - CTools::stableExp(maxBoundary(0)))); + if (m_DownsampleFactorOverride == boost::none) { + auto i = std::distance(m_TunableHyperparameters.begin(), + std::find(m_TunableHyperparameters.begin(), + m_TunableHyperparameters.end(), E_DownsampleFactor)); + if (static_cast(i) < m_TunableHyperparameters.size()) { + scale = std::min(1.0, 2.0 * m_DownsampleFactor / + (CTools::stableExp(minBoundary(i)) + + CTools::stableExp(maxBoundary(i)))); + } } // Read parameters for last round. - int i{0}; - if (m_DownsampleFactorOverride == boost::none) { - parameters(i++) = CTools::stableLog(m_DownsampleFactor); - } - if (m_RegularizationOverride.depthPenaltyMultiplier() == boost::none) { - parameters(i++) = CTools::stableLog(m_Regularization.depthPenaltyMultiplier()); - } - if (m_RegularizationOverride.leafWeightPenaltyMultiplier() == boost::none) { - parameters(i++) = - CTools::stableLog(m_Regularization.leafWeightPenaltyMultiplier() / scale); - } - if (m_RegularizationOverride.treeSizePenaltyMultiplier() == boost::none) { - parameters(i++) = - CTools::stableLog(m_Regularization.treeSizePenaltyMultiplier() / scale); - } - if (m_RegularizationOverride.softTreeDepthLimit() == boost::none) { - parameters(i++) = m_Regularization.softTreeDepthLimit(); - } - if (m_RegularizationOverride.softTreeDepthTolerance() == boost::none) { - parameters(i++) = m_Regularization.softTreeDepthTolerance(); - } - if (m_EtaOverride == boost::none) { - parameters(i++) = CTools::stableLog(m_Eta) / scale; - parameters(i++) = m_EtaGrowthRatePerTree; - } - if (m_FeatureBagFractionOverride == boost::none) { - parameters(i++) = m_FeatureBagFraction; + for (std::size_t i = 0; i < m_TunableHyperparameters.size(); ++i) { + switch (m_TunableHyperparameters[i]) { + case E_Alpha: + parameters(i) = CTools::stableLog(m_Regularization.depthPenaltyMultiplier()); + break; + case E_DownsampleFactor: + parameters(i) = CTools::stableLog(m_DownsampleFactor); + break; + case E_Eta: + parameters(i) = CTools::stableLog(m_Eta) / scale; + break; + case E_EtaGrowthRatePerTree: + parameters(i) = m_EtaGrowthRatePerTree; + break; + case E_FeatureBagFraction: + parameters(i) = m_FeatureBagFraction; + break; + case E_Gamma: + parameters(i) = CTools::stableLog( + m_Regularization.treeSizePenaltyMultiplier() / scale); + break; + case E_Lambda: + parameters(i) = CTools::stableLog( + m_Regularization.leafWeightPenaltyMultiplier() / scale); + break; + case E_SoftTreeDepthLimit: + parameters(i) = m_Regularization.softTreeDepthLimit(); + break; + case E_SoftTreeDepthTolerance: + parameters(i) = m_Regularization.softTreeDepthTolerance(); + break; + } } double meanLoss{CBasicStatistics::mean(lossMoments)}; @@ -1291,41 +1301,43 @@ bool CBoostedTreeImpl::selectNextHyperparameters(const TMeanVarAccumulator& loss std::find(m_TunableHyperparameters.begin(), m_TunableHyperparameters.end(), E_DownsampleFactor)); if (static_cast(i) < m_TunableHyperparameters.size()) { - scale = std::min(1.0, 2.0 * parameters(i) / - (CTools::stableExp(minBoundary(0)) + - CTools::stableExp(maxBoundary(0)))); + scale = std::min(1.0, 2.0 * CTools::stableExp(parameters(i)) / + (CTools::stableExp(minBoundary(i)) + + CTools::stableExp(maxBoundary(i)))); } } - i = 0; - if (m_DownsampleFactorOverride == boost::none) { - m_DownsampleFactor = CTools::stableExp(parameters(i++)); - scale = std::min(1.0, 2.0 * m_DownsampleFactor / - (CTools::stableExp(minBoundary(0)) + - CTools::stableExp(maxBoundary(0)))); - } - if (m_RegularizationOverride.depthPenaltyMultiplier() == boost::none) { - m_Regularization.depthPenaltyMultiplier(CTools::stableExp(parameters(i++))); - } - if (m_RegularizationOverride.leafWeightPenaltyMultiplier() == boost::none) { - m_Regularization.leafWeightPenaltyMultiplier( - scale * CTools::stableExp(parameters(i++))); - } - if (m_RegularizationOverride.treeSizePenaltyMultiplier() == boost::none) { - m_Regularization.treeSizePenaltyMultiplier( - scale * CTools::stableExp(parameters(i++))); - } - if (m_RegularizationOverride.softTreeDepthLimit() == boost::none) { - m_Regularization.softTreeDepthLimit(parameters(i++)); - } - if (m_RegularizationOverride.softTreeDepthTolerance() == boost::none) { - m_Regularization.softTreeDepthTolerance(parameters(i++)); - } - if (m_EtaOverride == boost::none) { - m_Eta = CTools::stableExp(scale * parameters(i++)); - m_EtaGrowthRatePerTree = parameters(i++); - } - if (m_FeatureBagFractionOverride == boost::none) { - m_FeatureBagFraction = parameters(i++); + for (std::size_t i = 0; i < m_TunableHyperparameters.size(); ++i) { + switch (m_TunableHyperparameters[i]) { + case E_Alpha: + m_Regularization.depthPenaltyMultiplier(CTools::stableExp(parameters(i))); + break; + case E_DownsampleFactor: + m_DownsampleFactor = CTools::stableExp(parameters(i)); + break; + case E_Eta: + m_Eta = CTools::stableExp(scale * parameters(i)); + break; + case E_EtaGrowthRatePerTree: + m_EtaGrowthRatePerTree = parameters(i); + break; + case E_FeatureBagFraction: + m_FeatureBagFraction = parameters(i); + break; + case E_Gamma: + m_Regularization.treeSizePenaltyMultiplier( + scale * CTools::stableExp(parameters(i))); + break; + case E_Lambda: + m_Regularization.leafWeightPenaltyMultiplier( + scale * CTools::stableExp(parameters(i))); + break; + case E_SoftTreeDepthLimit: + m_Regularization.softTreeDepthLimit(parameters(i)); + break; + case E_SoftTreeDepthTolerance: + m_Regularization.softTreeDepthTolerance(parameters(i)); + break; + } } return true;