Commit fc565ac

optimized BuildHist function

SmirnovEgorRu committed Dec 31, 2019
1 parent 298ebe6 commit fc565ac
Showing 8 changed files with 674 additions and 179 deletions.
84 changes: 18 additions & 66 deletions src/common/hist_util.cc
@@ -659,93 +659,45 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
   }
 }
 
+/*!
+ * \brief fill a histogram by zeroes
+ */
+void InitilizeHistByZeroes(GHistRow hist) {
+  memset(hist.data(), '\0', hist.size()*sizeof(tree::GradStats));
+}
+
 void GHistBuilder::BuildHist(const std::vector<GradientPair>& gpair,
                              const RowSetCollection::Elem row_indices,
                              const GHistIndexMatrix& gmat,
                              GHistRow hist) {
-  const size_t nthread = static_cast<size_t>(this->nthread_);
-  data_.resize(nbins_ * nthread_);
-
   const size_t* rid = row_indices.begin;
   const size_t nrows = row_indices.Size();
   const uint32_t* index = gmat.index.data();
   const size_t* row_ptr = gmat.row_ptr.data();
   const float* pgh = reinterpret_cast<const float*>(gpair.data());
 
   double* hist_data = reinterpret_cast<double*>(hist.data());
-  double* data = reinterpret_cast<double*>(data_.data());
-
-  const size_t block_size = 512;
-  size_t n_blocks = nrows/block_size;
-  n_blocks += !!(nrows - n_blocks*block_size);
-
-  const size_t nthread_to_process = std::min(nthread, n_blocks);
-  memset(thread_init_.data(), '\0', nthread_to_process*sizeof(size_t));
 
   const size_t cache_line_size = 64;
   const size_t prefetch_offset = 10;
   size_t no_prefetch_size = prefetch_offset + cache_line_size/sizeof(*rid);
   no_prefetch_size = no_prefetch_size > nrows ? nrows : no_prefetch_size;
 
-#pragma omp parallel for num_threads(nthread_to_process) schedule(guided)
-  for (bst_omp_uint iblock = 0; iblock < n_blocks; iblock++) {
-    dmlc::omp_uint tid = omp_get_thread_num();
-    double* data_local_hist = ((nthread_to_process == 1) ? hist_data :
-                               reinterpret_cast<double*>(data_.data() + tid * nbins_));
-
-    if (!thread_init_[tid]) {
-      memset(data_local_hist, '\0', 2*nbins_*sizeof(double));
-      thread_init_[tid] = true;
-    }
-
-    const size_t istart = iblock*block_size;
-    const size_t iend = (((iblock+1)*block_size > nrows) ? nrows : istart + block_size);
-    for (size_t i = istart; i < iend; ++i) {
-      const size_t icol_start = row_ptr[rid[i]];
-      const size_t icol_end = row_ptr[rid[i]+1];
-
-      if (i < nrows - no_prefetch_size) {
-        PREFETCH_READ_T0(row_ptr + rid[i + prefetch_offset]);
-        PREFETCH_READ_T0(pgh + 2*rid[i + prefetch_offset]);
-      }
-
-      for (size_t j = icol_start; j < icol_end; ++j) {
-        const uint32_t idx_bin = 2*index[j];
-        const size_t idx_gh = 2*rid[i];
-
-        data_local_hist[idx_bin] += pgh[idx_gh];
-        data_local_hist[idx_bin+1] += pgh[idx_gh+1];
-      }
-    }
-  }
-
-  if (nthread_to_process > 1) {
-    const size_t size = (2*nbins_);
-    const size_t block_size = 1024;
-    size_t n_blocks = size/block_size;
-    n_blocks += !!(size - n_blocks*block_size);
-
-    size_t n_worked_bins = 0;
-    for (size_t i = 0; i < nthread_to_process; ++i) {
-      if (thread_init_[i]) {
-        thread_init_[n_worked_bins++] = i;
-      }
-    }
-
-#pragma omp parallel for num_threads(std::min(nthread, n_blocks)) schedule(guided)
-    for (bst_omp_uint iblock = 0; iblock < n_blocks; iblock++) {
-      const size_t istart = iblock * block_size;
-      const size_t iend = (((iblock + 1) * block_size > size) ? size : istart + block_size);
-
-      const size_t bin = 2 * thread_init_[0] * nbins_;
-      memcpy(hist_data + istart, (data + bin + istart), sizeof(double) * (iend - istart));
-
-      for (size_t i_bin_part = 1; i_bin_part < n_worked_bins; ++i_bin_part) {
-        const size_t bin = 2 * thread_init_[i_bin_part] * nbins_;
-        for (size_t i = istart; i < iend; i++) {
-          hist_data[i] += data[bin + i];
-        }
-      }
-    }
-  }
+  for (size_t i = 0; i < nrows; ++i) {
+    const size_t icol_start = row_ptr[rid[i]];
+    const size_t icol_end = row_ptr[rid[i]+1];
+
+    if (i < nrows - no_prefetch_size) {
+      PREFETCH_READ_T0(row_ptr + rid[i + prefetch_offset]);
+      PREFETCH_READ_T0(pgh + 2*rid[i + prefetch_offset]);
+    }
+
+    for (size_t j = icol_start; j < icol_end; ++j) {
+      const uint32_t idx_bin = 2*index[j];
+      const size_t idx_gh = 2*rid[i];
+
+      hist_data[idx_bin] += pgh[idx_gh];
+      hist_data[idx_bin+1] += pgh[idx_gh+1];
+    }
+  }
 }
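Note on the kernel above: the reinterpret_casts rely on xgboost's GradientPair packing {grad, hess} as two adjacent floats (and tree::GradStats packing two doubles), so entry i is addressed as pgh[2*i] / pgh[2*i+1] and bin b as hist_data[2*b] / hist_data[2*b+1]. A minimal standalone sketch of that addressing, with GradientPairSketch as a hypothetical stand-in for the real type:

#include <cstdio>

// Hypothetical stand-in for xgboost's GradientPair: two contiguous floats.
struct GradientPairSketch {
  float grad;
  float hess;
};

int main() {
  GradientPairSketch gpair[2] = {{0.5f, 1.0f}, {-0.25f, 2.0f}};
  // Same trick as BuildHist: view the array of pairs as a flat float array.
  const float* pgh = reinterpret_cast<const float*>(gpair);
  for (int i = 0; i < 2; ++i) {
    // Entry i contributes pgh[2*i] (grad) and pgh[2*i+1] (hess) to its bin.
    std::printf("row %d: grad=%.2f hess=%.2f\n", i, pgh[2*i], pgh[2*i+1]);
  }
  return 0;
}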
102 changes: 96 additions & 6 deletions src/common/hist_util.h
@@ -251,7 +251,7 @@ class DenseCuts : public CutsBuilder {

 // FIXME(trivialfis): Merge this into generic cut builder.
 /*! \brief Builds the cut matrix on the GPU.
- * 
+ *
  * \return The row stride across the entire dataset.
  */
 size_t DeviceSketch(int device,
@@ -347,6 +347,11 @@ class GHistIndexBlockMatrix {
  */
 using GHistRow = Span<tree::GradStats>;
 
+/*!
+ * \brief fill a histogram by zeroes
+ */
+void InitilizeHistByZeroes(GHistRow hist);
+
 /*!
  * \brief histogram of gradient statistics for multiple nodes
  */
@@ -369,9 +374,12 @@ class HistCollection {

   // initialize histogram collection
   void Init(uint32_t nbins) {
-    nbins_ = nbins;
+    if (nbins_ != nbins) {
+      nbins_ = nbins;
+      // quite expensive operation, so let's do this only once
+      data_.clear();
+    }
     row_ptr_.clear();
-    data_.clear();
   }
 
   // create an empty histogram for i-th node
@@ -382,20 +390,102 @@
     }
     CHECK_EQ(row_ptr_[nid], kMax);
 
-    row_ptr_[nid] = data_.size();
-    data_.resize(data_.size() + nbins_);
+    row_ptr_[nid] = nbins_ * nid;
+
+    if (data_.size() <= nbins_ * (nid + 1)) {
+      data_.resize(nbins_ * (nid + 1));
+    }
   }

  private:
   /*! \brief number of all bins over all features */
-  uint32_t nbins_;
+  uint32_t nbins_ = 0;
 
   std::vector<tree::GradStats> data_;
 
   /*! \brief row_ptr_[nid] locates bin for histogram of node nid */
   std::vector<size_t> row_ptr_;
 };
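The effect of the AddHistRow change, shown as a standalone mock (LayoutSketch and its members are illustrative stand-ins, not the commit's code): a node's histogram now starts at the fixed offset nbins * nid and the backing vector only grows, so a later Init() with an unchanged nbins can reuse the old allocation instead of reallocating.

#include <cstddef>
#include <cstdio>
#include <vector>

struct LayoutSketch {
  std::size_t nbins = 4;             // bins per histogram (illustrative value)
  std::vector<double> data;          // stand-in for data_
  std::vector<std::size_t> row_ptr;  // stand-in for row_ptr_

  void AddHistRow(std::size_t nid) {
    if (row_ptr.size() <= nid) row_ptr.resize(nid + 1);
    row_ptr[nid] = nbins * nid;      // fixed offset instead of data.size()
    if (data.size() < nbins * (nid + 1)) {
      data.resize(nbins * (nid + 1));  // grow only when a new node needs room
    }
  }
};

int main() {
  LayoutSketch hists;
  for (std::size_t nid = 0; nid < 3; ++nid) {
    hists.AddHistRow(nid);
    std::printf("node %zu -> offset %zu, storage %zu\n",
                nid, hists.row_ptr[nid], hists.data.size());
  }
  return 0;  // prints offsets 0, 4, 8 and storage sizes 4, 8, 12
}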

+/*!
+ * \brief Stores temporary histograms so that they can be computed in parallel.
+ * Supports processing multiple tree nodes for nested parallelism
+ * and can reduce histograms across threads efficiently.
+ */
+class HistBuffer {
+ public:
+  void Init(size_t nbins) {
+    if (nbins != nbins_) {
+      hist_.Init(nbins);
+      max_size_ = 0;
+      nbins_ = nbins;
+    }
+  }
+
+  // Add new elements if needed, mark all hists as unused
+  void Reset(size_t nthreads, size_t nodes) {
+    const size_t new_size = nthreads * nodes;
+    for (size_t i = max_size_; i < new_size; ++i) {
+      hist_.AddHistRow(i);
+    }
+    max_size_ = std::max(max_size_, new_size);
+
+    nodes_ = nodes;
+    nthreads_ = nthreads;
+
+    hist_was_used_.resize(nthreads * nodes);
+    std::fill(hist_was_used_.begin(), hist_was_used_.end(), false);
+  }
+
+  // Get the specified hist; initialize it by zeroes if it wasn't used before
+  GHistRow GetInitializedHist(size_t tid, size_t nid) {
+    CHECK_LT(nid, nodes_);
+    CHECK_LT(tid, nthreads_);
+    GHistRow hist = hist_[tid * nodes_ + nid];
+
+    if (!hist_was_used_[tid * nodes_ + nid]) {
+      InitilizeHistByZeroes(hist);
+      hist_was_used_[tid * nodes_ + nid] = true;
+    }
+
+    return hist;
+  }
+
+  // Reduce bins [begin, end) of the nid-th node into dst across threads
+  void ReduceHist(GHistRow dst, size_t nid, size_t begin, size_t end) {
+    CHECK_GT(end, begin);
+    CHECK_LT(nid, nodes_);
+
+    for (size_t tid = 0; tid < nthreads_; ++tid) {
+      if (hist_was_used_[tid * nodes_ + nid]) {
+        GHistRow src = hist_[tid * nodes_ + nid];
+
+        for (size_t i = begin; i < end; ++i) {
+          dst[i].Add(src[i]);
+        }
+      }
+    }
+  }
+
+ protected:
+  /*! \brief number of bins in each histogram */
+  size_t nbins_ = 0;
+  /*! \brief number of threads for parallel computation */
+  size_t nthreads_ = 0;
+  /*! \brief number of nodes which will be processed in parallel */
+  size_t nodes_ = 0;
+  /*! \brief real (allocated) size of hist_ */
+  size_t max_size_ = 0;
+  /*! \brief storage for the per-thread temporary histograms */
+  HistCollection hist_;
+  /*!
+   * \brief Marks which hists were used; only those need to be merged.
+   * Contains only {true, false} values,
+   * but 'int' is used instead of 'bool' because concurrent writes
+   * to neighbouring elements of a std::vector<bool> are not thread safe.
+   */
+  std::vector<int> hist_was_used_;
+};
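A hypothetical usage sketch of HistBuffer (the function and variable names here are assumptions, not part of this commit): each thread accumulates into its own per-(thread, node) scratch histogram, and the scratch copies are then reduced into the final per-node histograms.

#include <omp.h>
#include <vector>
#include "hist_util.h"  // include path is an assumption; provides HistBuffer, GHistRow

namespace xgboost {
namespace common {

void BuildHistsSketch(HistBuffer* buffer,
                      std::vector<GHistRow>* out,  // final histogram of each node
                      size_t nbins, size_t n_nodes, size_t n_threads) {
  buffer->Init(nbins);
  buffer->Reset(n_threads, n_nodes);  // one scratch hist per (thread, node) pair

  #pragma omp parallel num_threads(static_cast<int>(n_threads))
  {
    const size_t tid = static_cast<size_t>(omp_get_thread_num());
    for (size_t nid = 0; nid < n_nodes; ++nid) {
      GHistRow scratch = buffer->GetInitializedHist(tid, nid);
      // ... accumulate this thread's share of node nid's rows into scratch ...
      (void)scratch;
    }
  }

  for (size_t nid = 0; nid < n_nodes; ++nid) {
    buffer->ReduceHist((*out)[nid], nid, 0, nbins);  // sum the thread-local copies
  }
}

}  // namespace common
}  // namespace xgboost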

 /*!
  * \brief builder for histograms of gradient statistics
  */
123 changes: 123 additions & 0 deletions src/common/threading_utils.h
@@ -0,0 +1,123 @@
/*!
* Copyright 2015-2019 by Contributors
 * \file threading_utils.h
* \brief Threading utilities
*/
#ifndef XGBOOST_COMMON_THREADING_UTILS_H_
#define XGBOOST_COMMON_THREADING_UTILS_H_

#include <vector>
#include <algorithm>

#include "xgboost/logging.h"  // for the CHECK_* macros used below

namespace xgboost {
namespace common {

// Represents a simple range of indexes [begin, end)
// Inspired by tbb::blocked_range
class Range1d {
public:
Range1d(size_t begin, size_t end): begin_(begin), end_(end) {
CHECK_LT(begin, end);
}

size_t begin() {
return begin_;
}

size_t end() {
return end_;
}

private:
size_t begin_;
size_t end_;
};


// Splits a 2d space into balanced blocks.
// The implementation is inspired by tbb::blocked_range2d.
// However, TBB supports only a rectangular (n x m) 2d range (a matrix) split into blocks. Example:
// [ 1,2,3 ]
// [ 4,5,6 ]
// [ 7,8,9 ]
// This class can also work with a different size in each 'row'. Example:
// [ 1,2 ]
// [ 3,4,5,6 ]
// [ 7,8,9 ]
// With grain_size = 2 it produces the following blocks:
// [1,2], [3,4], [5,6], [7,8], [9]
// The class helps to process data belonging to several tree nodes (usually unbalanced) in parallel.
// Using nested parallelism (across nodes and across the data within each node)
// helps to improve CPU utilization.
class BlockedSpace2d {
public:
  // Example of a space:
  // [ 1,2 ]
  // [ 3,4,5,6 ]
  // [ 7,8,9 ]
  // BlockedSpace2d will create the following blocks (tasks) if grain_size = 2:
  // 1st block: first_dimension = 0, range of indexes in the 'row' = [0,2) (covers values [1,2])
  // 2nd block: first_dimension = 1, range of indexes in the 'row' = [0,2) (covers values [3,4])
  // 3rd block: first_dimension = 1, range of indexes in the 'row' = [2,4) (covers values [5,6])
  // 4th block: first_dimension = 2, range of indexes in the 'row' = [0,2) (covers values [7,8])
  // 5th block: first_dimension = 2, range of indexes in the 'row' = [2,3) (covers value [9])
  // Arguments:
  // dim1 - size of the first dimension in the space
  // getter_size_dim2 - functor returning the size of the second dimension for each 'row', given the row index
  // grain_size - maximum size of the produced blocks
template<typename Func>
BlockedSpace2d(size_t dim1, Func getter_size_dim2, size_t grain_size) {
for (size_t i = 0; i < dim1; ++i) {
const size_t size = getter_size_dim2(i);
const size_t n_blocks = size/grain_size + !!(size % grain_size);
for (size_t iblock = 0; iblock < n_blocks; ++iblock) {
const size_t begin = iblock * grain_size;
const size_t end = std::min(begin + grain_size, size);
AddBlock(i, begin, end);
}
}
}

  // Number of blocks (tasks) in the space
size_t Size() const {
return ranges_.size();
}

  // Get the index of the first dimension of the i-th block (task)
size_t GetFirstDimension(size_t i) const {
CHECK_LT(i, first_dimension_.size());
return first_dimension_[i];
}

  // Get the range of indexes for the second dimension of the i-th block (task)
Range1d GetRange(size_t i) const {
CHECK_LT(i, ranges_.size());
return ranges_[i];
}

private:
void AddBlock(size_t first_dimension, size_t begin, size_t end) {
first_dimension_.push_back(first_dimension);
ranges_.emplace_back(begin, end);
}

std::vector<Range1d> ranges_;
std::vector<size_t> first_dimension_;
};


// Wrapper to implement nested parallelism with a simple omp parallel for
template<typename Func>
void ParallelFor2d(const BlockedSpace2d& space, Func func) {
const int num_blocks_in_space = static_cast<int>(space.Size());

#pragma omp parallel for
for (auto i = 0; i < num_blocks_in_space; i++) {
func(space.GetFirstDimension(i), space.GetRange(i));
}
}

} // namespace common
} // namespace xgboost

#endif // XGBOOST_COMMON_THREADING_UTILS_H_
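
A small usage sketch for BlockedSpace2d and ParallelFor2d together, reproducing the [1,2] / [3,4,5,6] / [7,8,9] example from the comments above (the include path is an assumption):

#include <cstdio>
#include <vector>
#include "threading_utils.h"

int main() {
  // Row sizes of the uneven space above: 2, 4 and 3 elements.
  std::vector<std::size_t> row_sizes = {2, 4, 3};

  xgboost::common::BlockedSpace2d space(
      row_sizes.size(),                             // dim1: number of 'rows'
      [&](std::size_t i) { return row_sizes[i]; },  // size of each row
      2);                                           // grain_size

  // Expected 5 blocks: (0,[0,2)) (1,[0,2)) (1,[2,4)) (2,[0,2)) (2,[2,3))
  xgboost::common::ParallelFor2d(space, [](std::size_t row, xgboost::common::Range1d r) {
    std::printf("row=%zu range=[%zu,%zu)\n", row, r.begin(), r.end());
  });
  return 0;
}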