double to float with GradientPair usage
SHVETS, KIRILL authored and ShvetsKS committed May 1, 2020
1 parent 8de7f19 commit 19120cf
Showing 8 changed files with 59 additions and 30 deletions.
19 changes: 19 additions & 0 deletions include/xgboost/base.h
@@ -141,6 +141,25 @@ class GradientPairInternal {
 public:
  using ValueT = T;

+  inline void Add(const GradientPairInternal& b) {
+    grad_ += b.grad_;
+    hess_ += b.hess_;
+  }
+
+  inline void Add(const ValueT& grad, const ValueT& hess) {
+    grad_ += grad;
+    hess_ += hess;
+  }
+
+  inline void SetSubstract(const GradientPairInternal& a, const GradientPairInternal& b) {
+    grad_ = a.grad_ - b.grad_;
+    hess_ = a.hess_ - b.hess_;
+  }
+
+  inline static void Reduce(GradientPairInternal& a, const GradientPairInternal& b) { // NOLINT(*)
+    a.Add(b);
+  }
+
  XGBOOST_DEVICE GradientPairInternal() : grad_(0), hess_(0) {}

  XGBOOST_DEVICE GradientPairInternal(T grad, T hess) {
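The four new helpers turn GradientPairInternal into a drop-in histogram accumulator: Add covers per-pair and per-component accumulation, SetSubstract supports the parent-minus-sibling histogram trick, and the static Reduce satisfies the rabit::Reducer interface used in updater_quantile_hist.h below. A minimal standalone sketch of how they compose (hypothetical usage, not part of the commit; GradientPair is the float instantiation of GradientPairInternal):

#include <iostream>
#include <xgboost/base.h>

int main() {
  xgboost::GradientPair parent(1.5f, 2.0f);  // summed grad/hess of a parent bin
  xgboost::GradientPair left;
  left.Add(0.5f, 1.0f);                      // accumulate rows routed left
  xgboost::GradientPair right;
  right.SetSubstract(parent, left);          // right = parent - left, per bin
  std::cout << right.GetGrad() << "/" << right.GetHess() << std::endl;  // 1/1
}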
14 changes: 7 additions & 7 deletions src/common/hist_util.cc
@@ -832,17 +832,17 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
 */
void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end) {
#if defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1
-  std::fill(hist.begin() + begin, hist.begin() + end, tree::GradStats());
+  std::fill(hist.begin() + begin, hist.begin() + end, GradientPair());
#else // defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1
-  memset(hist.data() + begin, '\0', (end-begin)*sizeof(tree::GradStats));
+  memset(hist.data() + begin, '\0', (end-begin)*sizeof(GradientPair));
#endif // defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1
}

/*!
 * \brief Increment hist as dst += add in range [begin, end)
 */
void IncrementHist(GHistRow dst, const GHistRow add, size_t begin, size_t end) {
-  using FPType = decltype(tree::GradStats::sum_grad);
+  using FPType = GradientPair::ValueT;
  FPType* pdst = reinterpret_cast<FPType*>(dst.data());
  const FPType* padd = reinterpret_cast<const FPType*>(add.data());

@@ -855,7 +855,7 @@ void IncrementHist(GHistRow dst, const GHistRow add, size_t begin, size_t end) {
 * \brief Copy hist from src to dst in range [begin, end)
 */
void CopyHist(GHistRow dst, const GHistRow src, size_t begin, size_t end) {
-  using FPType = decltype(tree::GradStats::sum_grad);
+  using FPType = GradientPair::ValueT;
  FPType* pdst = reinterpret_cast<FPType*>(dst.data());
  const FPType* psrc = reinterpret_cast<const FPType*>(src.data());

@@ -869,7 +869,7 @@ void CopyHist(GHistRow dst, const GHistRow src, size_t begin, size_t end) {
 */
void SubtractionHist(GHistRow dst, const GHistRow src1, const GHistRow src2,
                     size_t begin, size_t end) {
-  using FPType = decltype(tree::GradStats::sum_grad);
+  using FPType = GradientPair::ValueT;
  FPType* pdst = reinterpret_cast<FPType*>(dst.data());
  const FPType* psrc1 = reinterpret_cast<const FPType*>(src1.data());
  const FPType* psrc2 = reinterpret_cast<const FPType*>(src2.data());
@@ -1027,7 +1027,7 @@ void GHistBuilder::BuildHist(const std::vector<GradientPair>& gpair,
                             const GHistIndexMatrix& gmat,
                             GHistRow hist,
                             bool isDense) {
-  using FPType = decltype(tree::GradStats::sum_grad);
+  using FPType = GradientPair::ValueT;
  const size_t nrows = row_indices.Size();
  const size_t no_prefetch_size = Prefetch::NoPrefetchSize(nrows);

@@ -1058,7 +1058,7 @@ void GHistBuilder::BuildBlockHist(const std::vector<GradientPair>& gpair,
#if defined(_OPENMP)
  const auto nthread = static_cast<bst_omp_uint>(this->nthread_); // NOLINT
#endif // defined(_OPENMP)
-  tree::GradStats* p_hist = hist.data();
+  GradientPair* p_hist = hist.data();

#pragma omp parallel for num_threads(nthread) schedule(guided)
  for (bst_omp_uint bid = 0; bid < nblock; ++bid) {
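Each kernel above now derives FPType from GradientPair::ValueT (float) instead of decltype(tree::GradStats::sum_grad) (double), while keeping the pattern of aliasing the bin array as a flat scalar array so grad and hess are updated in one uniform, vectorizable loop. A sketch of that pattern, assuming GradientPair is exactly two packed ValueT members with no padding (the same assumption the existing reinterpret_casts make):

#include <cstddef>
#include <xgboost/base.h>

// Sketch of the IncrementHist pattern: dst += add over bins [begin, end).
// Each bin holds (grad, hess), so the flat scalar range is [2*begin, 2*end).
void Increment(xgboost::GradientPair* dst, const xgboost::GradientPair* add,
               std::size_t begin, std::size_t end) {
  using FPType = xgboost::GradientPair::ValueT;  // float after this commit
  FPType* pdst = reinterpret_cast<FPType*>(dst);
  const FPType* padd = reinterpret_cast<const FPType*>(add);
  for (std::size_t i = 2 * begin; i < 2 * end; ++i) {
    pdst[i] += padd[i];  // grad and hess interleaved, one uniform loop
  }
}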
14 changes: 4 additions & 10 deletions src/common/hist_util.h
@@ -401,13 +401,7 @@ class GHistIndexBlockMatrix {
  std::vector<Block> blocks_;
};

-/*!
- * \brief histogram of gradient statistics for a single node.
- * Consists of multiple GradStats, each entry showing total gradient statistics
- * for that particular bin
- * Uses global bin id so as to represent all features simultaneously
- */
-using GHistRow = Span<tree::GradStats>;
+using GHistRow = Span<GradientPair>;

/*!
 * \brief fill a histogram by zeros
@@ -439,8 +433,8 @@ class HistCollection {
  GHistRow operator[](bst_uint nid) const {
    constexpr uint32_t kMax = std::numeric_limits<uint32_t>::max();
    CHECK_NE(row_ptr_[nid], kMax);
-    tree::GradStats* ptr =
-        const_cast<tree::GradStats*>(dmlc::BeginPtr(data_) + row_ptr_[nid]);
+    GradientPair* ptr =
+        const_cast<GradientPair*>(dmlc::BeginPtr(data_) + row_ptr_[nid]);
    return {ptr, nbins_};
  }

@@ -483,7 +477,7 @@ class HistCollection {
  /*! \brief amount of active nodes in hist collection */
  uint32_t n_nodes_added_ = 0;

-  std::vector<tree::GradStats> data_;
+  std::vector<GradientPair> data_;

  /*! \brief row_ptr_[nid] locates bin for histogram of node nid */
  std::vector<size_t> row_ptr_;
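Switching the histogram element type is where the memory saving lands: tree::GradStats holds two doubles under XGBOOST_ALIGNAS(16), while GradientPair holds two floats, so each bin in HistCollection::data_ (and every per-thread histogram copy) shrinks from 16 to 8 bytes. A compile-time check one could add to document the packed-layout assumption behind the flat-array kernels above (illustrative only, not in the commit):

#include <xgboost/base.h>

static_assert(sizeof(xgboost::GradientPair) ==
                  2 * sizeof(xgboost::GradientPair::ValueT),
              "hist kernels alias GradientPair as two packed scalars");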
19 changes: 12 additions & 7 deletions src/tree/param.h
@@ -332,14 +332,15 @@ XGBOOST_DEVICE inline float CalcWeight(const TrainingParams &p, GpairT sum_grad)

/*! \brief core statistics used for tree construction */
struct XGBOOST_ALIGNAS(16) GradStats {
+  using GradType = double;
  /*! \brief sum gradient statistics */
-  double sum_grad { 0 };
+  GradType sum_grad { 0 };
  /*! \brief sum hessian statistics */
-  double sum_hess { 0 };
+  GradType sum_hess { 0 };

 public:
-  XGBOOST_DEVICE double GetGrad() const { return sum_grad; }
-  XGBOOST_DEVICE double GetHess() const { return sum_hess; }
+  XGBOOST_DEVICE GradType GetGrad() const { return sum_grad; }
+  XGBOOST_DEVICE GradType GetHess() const { return sum_hess; }

  friend std::ostream& operator<<(std::ostream& os, GradStats s) {
    os << s.GetGrad() << "/" << s.GetHess();
@@ -354,7 +355,7 @@ struct XGBOOST_ALIGNAS(16) GradStats {
  template <typename GpairT>
  XGBOOST_DEVICE explicit GradStats(const GpairT &sum)
      : sum_grad(sum.GetGrad()), sum_hess(sum.GetHess()) {}
-  explicit GradStats(const double grad, const double hess)
+  explicit GradStats(const GradType grad, const GradType hess)
      : sum_grad(grad), sum_hess(hess) {}
  /*!
   * \brief accumulate statistics
@@ -379,7 +380,7 @@ struct XGBOOST_ALIGNAS(16) GradStats {
  /*! \return whether the statistics is not used yet */
  inline bool Empty() const { return sum_hess == 0.0; }
  /*! \brief add statistics to the data */
-  inline void Add(double grad, double hess) {
+  inline void Add(GradType grad, GradType hess) {
    sum_grad += grad;
    sum_hess += hess;
  }
@@ -425,7 +426,11 @@ struct SplitEntryContainer {
   * \param split_index the feature index where the split is on
   */
  bool NeedReplace(bst_float new_loss_chg, unsigned split_index) const {
-    if (this->SplitIndex() <= split_index) {
+    if (std::isinf(new_loss_chg)) {  // in some cases new_loss_chg can be NaN or Inf,
+                                     // for example when lambda = 0 & min_child_weight = 0
+                                     // skip value in this case
+      return false;
+    } else if (this->SplitIndex() <= split_index) {
      return new_loss_chg > this->loss_chg;
    } else {
      return !(this->loss_chg > new_loss_chg);
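The new std::isinf guard in NeedReplace matters because loss_chg derives from gain terms of the form G^2 / (H + lambda); with reg_lambda = 0 and min_child_weight = 0 an empty child divides by zero, and the resulting Inf would otherwise beat every finite candidate and force a degenerate split. A toy reproduction of where the Inf comes from (simplified gain, hypothetical values):

#include <cmath>
#include <cstdio>

// Simplified split-gain term used in the evaluator: G^2 / (H + lambda).
float GainTerm(float sum_grad, float sum_hess, float lambda) {
  return sum_grad * sum_grad / (sum_hess + lambda);
}

int main() {
  // lambda = 0 plus an empty child (sum_hess = 0, permitted when
  // min_child_weight = 0) divides by zero:
  float g = GainTerm(0.5f, 0.0f, 0.0f);
  std::printf("isinf: %d\n", std::isinf(g));  // prints 1; NeedReplace now skips it
}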
9 changes: 9 additions & 0 deletions src/tree/split_evaluator.cc
@@ -138,6 +138,15 @@ class ElasticNet final : public SplitEvaluator {
      return 0.0;
    }
  }
+  inline float ThresholdL1(float g) const {
+    if (g > params_->reg_alpha) {
+      return g - params_->reg_alpha;
+    } else if (g < -params_->reg_alpha) {
+      return g + params_->reg_alpha;
+    } else {
+      return 0.0;
+    }
+  }
};

XGBOOST_REGISTER_SPLIT_EVALUATOR(ElasticNet, "elastic_net")
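The added method is a float overload of the existing double ThresholdL1, so gains computed from float GradientPair sums avoid a round trip through double. Both overloads implement L1 soft-thresholding: shrink g toward zero by reg_alpha, with a dead zone on [-reg_alpha, reg_alpha]. A standalone sketch of the rule with hypothetical values:

#include <cstdio>

// Soft-thresholding: T(g) = sign(g) * max(|g| - alpha, 0).
float ThresholdL1(float g, float alpha) {
  if (g > alpha) return g - alpha;
  if (g < -alpha) return g + alpha;
  return 0.0f;
}

int main() {
  const float alpha = 0.3f;  // hypothetical reg_alpha
  std::printf("%g %g %g\n",
              ThresholdL1(1.0f, alpha),    // 0.7
              ThresholdL1(-0.2f, alpha),   // 0 (inside the dead zone)
              ThresholdL1(-0.5f, alpha));  // -0.2
}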
10 changes: 6 additions & 4 deletions src/tree/updater_quantile_hist.cc
@@ -1077,23 +1077,25 @@ void QuantileHistMaker::Builder::InitNewNode(int nid,
  {
    auto& stats = snode_[nid].stats;
    GHistRow hist = hist_[nid];
+    GradientPair grad_stat;
    if (tree[nid].IsRoot()) {
      if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
        const std::vector<uint32_t>& row_ptr = gmat.cut.Ptrs();
        const uint32_t ibegin = row_ptr[fid_least_bins_];
        const uint32_t iend = row_ptr[fid_least_bins_ + 1];
        auto begin = hist.data();
        for (uint32_t i = ibegin; i < iend; ++i) {
-          const GradStats et = begin[i];
-          stats.Add(et.sum_grad, et.sum_hess);
+          const GradientPair et = begin[i];
+          grad_stat.Add(et.GetGrad(), et.GetHess());
        }
      } else {
        const RowSetCollection::Elem e = row_set_collection_[nid];
        for (const size_t* it = e.begin; it < e.end; ++it) {
-          stats.Add(gpair[*it]);
+          grad_stat.Add(gpair[*it]);
        }
      }
-      histred_.Allreduce(&snode_[nid].stats, 1);
+      histred_.Allreduce(&grad_stat, 1);
+      snode_[nid].stats = tree::GradStats(grad_stat.GetGrad(), grad_stat.GetHess());
    } else {
      int parent_id = tree[nid].Parent();
      if (tree[nid].IsLeftChild()) {
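InitNewNode now accumulates root statistics in a single-precision GradientPair, matching both the histogram element type and the float Reducer below, and widens to the double-precision tree::GradStats (GradType = double, kept in param.h) only once, after the allreduce. A condensed sketch of that flow; RootStats is a hypothetical free-function rendering of the loop, and the include path for tree::GradStats is assumed:

#include <vector>
#include <xgboost/base.h>
#include "param.h"  // tree::GradStats; path assumed relative to src/tree

// Accumulate in float, reduce in float, widen to double exactly once.
xgboost::tree::GradStats RootStats(
    const std::vector<xgboost::GradientPair>& gpair) {
  xgboost::GradientPair grad_stat;  // float accumulator, as in the diff
  for (const auto& g : gpair) {
    grad_stat.Add(g);  // per-row single-precision sums
  }
  // In the builder, histred_.Allreduce(&grad_stat, 1) runs here so every
  // worker holds the global sums before the widening conversion.
  return xgboost::tree::GradStats(grad_stat.GetGrad(), grad_stat.GetHess());
}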
2 changes: 1 addition & 1 deletion src/tree/updater_quantile_hist.h
@@ -366,7 +366,7 @@ class QuantileHistMaker: public TreeUpdater {

    common::Monitor builder_monitor_;
    common::ParallelGHistBuilder hist_buffer_;
-    rabit::Reducer<GradStats, GradStats::Reduce> histred_;
+    rabit::Reducer<GradientPair, GradientPair::Reduce> histred_;
  };

  std::unique_ptr<Builder> builder_;
2 changes: 1 addition & 1 deletion tests/python/test_with_sklearn.py
@@ -518,7 +518,7 @@ def test_kwargs_grid_search():
    from sklearn import datasets

    params = {'tree_method': 'hist'}
-    clf = xgb.XGBClassifier(n_estimators=1, learning_rate=1.0, **params)
+    clf = xgb.XGBClassifier(n_estimators=5, learning_rate=1.0, **params)
    assert clf.get_params()['tree_method'] == 'hist'
    # 'max_leaves' is not a default argument of XGBClassifier
    # Check we can still do grid search over this parameter
