Improvements in feature sampling (#4278)
With this PR, the feature sampling overhead is greatly reduced, especially for wide datasets (thousands of features). The PR requires some structural changes in RAFT and is therefore marked as WIP.

Authors:
  - Vinay Deshpande (https://github.com/vinaydes)
  - Ray Douglass (https://github.com/raydouglass)
  - Andy Adinets (https://github.com/canonizer)
  - Jordan Jacobelli (https://github.com/Ethyling)
  - Jiwei Liu (https://github.com/daxiongshu)
  - GALI PREM SAGAR (https://github.com/galipremsagar)
  - Christopher Akiki (https://github.com/cakiki)
  - Venkat (https://github.com/venkywonka)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #4278
vinaydes authored Aug 3, 2022
1 parent 350d709 commit 3b3b891
Showing 5 changed files with 316 additions and 60 deletions.
86 changes: 83 additions & 3 deletions cpp/src/decisiontree/batched-levelalgo/builder.cuh
@@ -190,6 +190,7 @@ struct Builder {
int n_blks_for_cols = 10;
/** Memory alignment value */
const size_t align_value = 512;
IdxT* colids;
/** rmm device workspace buffer */
rmm::device_uvector<char> d_buff;
/** pinned host buffer to store the trained nodes */
@@ -281,6 +282,7 @@ struct Builder {
d_wsize += calculateAlignedBytes(sizeof(NodeWorkItem) * max_batch); // d_work_Items
d_wsize += // workload_info
calculateAlignedBytes(sizeof(WorkloadInfo<IdxT>) * max_blocks_dimx);
d_wsize += calculateAlignedBytes(sizeof(IdxT) * max_batch * dataset.n_sampled_cols); // colids

// all nodes in the tree
h_wsize += // h_workload_info
@@ -320,6 +322,8 @@ struct Builder {
d_wspace += calculateAlignedBytes(sizeof(NodeWorkItem) * max_batch);
workload_info = reinterpret_cast<WorkloadInfo<IdxT>*>(d_wspace);
d_wspace += calculateAlignedBytes(sizeof(WorkloadInfo<IdxT>) * max_blocks_dimx);
colids = reinterpret_cast<IdxT*>(d_wspace);
d_wspace += calculateAlignedBytes(sizeof(IdxT) * max_batch * dataset.n_sampled_cols);

RAFT_CUDA_TRY(
cudaMemsetAsync(done_count, 0, sizeof(int) * max_batch * n_col_blks, builder_stream));
@@ -378,7 +382,7 @@ struct Builder {

auto doSplit(const std::vector<NodeWorkItem>& work_items)
{
raft::common::nvtx::range fun_scope("Builder::doSplit @bulder_base.cuh [batched-levelalgo]");
raft::common::nvtx::range fun_scope("Builder::doSplit @builder.cuh [batched-levelalgo]");
// start fresh on the number of *new* nodes created in this batch
RAFT_CUDA_TRY(cudaMemsetAsync(n_nodes, 0, sizeof(IdxT), builder_stream));
initSplit<DataT, IdxT, TPB_DEFAULT>(splits, work_items.size(), builder_stream);
@@ -388,11 +392,86 @@

auto [n_blocks_dimx, n_large_nodes] = this->updateWorkloadInfo(work_items);

// do feature-sampling
if (dataset.n_sampled_cols != dataset.N) {
raft::common::nvtx::range fun_scope("feature-sampling");
constexpr int block_threads = 128;
constexpr int max_samples_per_thread = 72; // register spillage if more than this limit
// decide if the problem size is suitable for the excess-sampling strategy.
//
      // the required shared memory is a function of the number of samples we need to draw (in
      // parallel, with replacement) in excess to get 'k' uniques out of 'n' features. estimated
// static shared memory required by cub's block-wide collectives:
// max_samples_per_thread * block_threads * sizeof(IdxT)
//
      // The maximum number of items to sample (the compile-time constant
      // `max_samples_per_thread`) is calibrated so that:
      // 1. There are no register spills and no accesses to global memory
      // 2. The required static shared memory (i.e., `max_samples_per_thread * block_threads *
      //    sizeof(IdxT)`) does not exceed 46 KB.
//
      // the number of samples we need to draw (in parallel, with replacement) to expect 'k'
      // unique samples out of 'n' is given by: log(1 - k/n) / log(1 - 1/n)
      // ref: https://stats.stackexchange.com/questions/296005/the-expected-number-of-unique-elements-drawn-with-replacement
IdxT n_parallel_samples =
std::ceil(raft::myLog(1 - double(dataset.n_sampled_cols) / double(dataset.N)) /
(raft::myLog(1 - 1.f / double(dataset.N))));
      // maximum sampling work possible by all threads in a block:
      // `max_samples_per_thread * block_threads`
      // dynamically calculated sampling work to be done per block:
      // `n_parallel_samples`
      // the former must be greater than or equal to the latter for the excess-sampling-based strategy
if (max_samples_per_thread * block_threads >= n_parallel_samples) {
raft::common::nvtx::range fun_scope("excess-sampling-based approach");
dim3 grid;
grid.x = work_items.size();
grid.y = 1;
grid.z = 1;

if (n_parallel_samples <= block_threads)
// each thread randomly samples only 1 sample
excess_sample_with_replacement_kernel<IdxT, 1, block_threads>
<<<grid, block_threads, 0, builder_stream>>>(colids,
d_work_items,
work_items.size(),
treeid,
seed,
dataset.N,
dataset.n_sampled_cols,
n_parallel_samples);
else
// each thread does more work and samples `max_samples_per_thread` samples
excess_sample_with_replacement_kernel<IdxT, max_samples_per_thread, block_threads>
<<<grid, block_threads, 0, builder_stream>>>(colids,
d_work_items,
work_items.size(),
treeid,
seed,
dataset.N,
dataset.n_sampled_cols,
n_parallel_samples);
raft::common::nvtx::pop_range();
} else {
raft::common::nvtx::range fun_scope("reservoir-sampling-based approach");
// using algo-L (reservoir sampling) strategy to sample 'dataset.n_sampled_cols' unique
// features from 'dataset.N' total features
dim3 grid;
grid.x = (work_items.size() + 127) / 128;
grid.y = 1;
grid.z = 1;
algo_L_sample_kernel<<<grid, block_threads, 0, builder_stream>>>(
colids, d_work_items, work_items.size(), treeid, seed, dataset.N, dataset.n_sampled_cols);
raft::common::nvtx::pop_range();
}
RAFT_CUDA_TRY(cudaPeekAtLastError());
raft::common::nvtx::pop_range();
}

// iterate through a batch of columns (to reduce the memory pressure) and
// compute the best split at the end
for (IdxT c = 0; c < dataset.n_sampled_cols; c += n_blks_for_cols) {
computeSplit(c, n_blocks_dimx, n_large_nodes);
RAFT_CUDA_TRY(cudaGetLastError());
RAFT_CUDA_TRY(cudaPeekAtLastError());
}

// create child nodes (or make the current ones leaf)
Expand All @@ -407,7 +486,7 @@ struct Builder {
dataset,
d_work_items,
splits);
RAFT_CUDA_TRY(cudaGetLastError());
RAFT_CUDA_TRY(cudaPeekAtLastError());
raft::common::nvtx::pop_range();
raft::update_host(h_splits, splits, work_items.size(), builder_stream);
handle.sync_stream(builder_stream);
@@ -462,6 +541,7 @@ struct Builder {
quantiles,
d_work_items,
col,
colids,
done_count,
mutex,
splits,
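The strategy selection above comes down to a small amount of host-side arithmetic. The following standalone sketch (not part of this commit; the values of N and k and the printed messages are illustrative) mirrors the decision made in doSplit(): compute the expected number of with-replacement draws needed to see k unique features out of N, and pick the excess-sampling kernel only when that count fits within one block's budget of max_samples_per_thread * block_threads draws.

#include <cmath>
#include <cstdio>

int main()
{
  const int block_threads          = 128;  // same constants as in doSplit()
  const int max_samples_per_thread = 72;

  const double N = 1000.0;  // total features (illustrative)
  const double k = 900.0;   // features sampled per node (illustrative)

  // expected number of with-replacement draws needed for k uniques out of N:
  // log(1 - k/N) / log(1 - 1/N)
  const int n_parallel_samples =
    static_cast<int>(std::ceil(std::log(1.0 - k / N) / std::log(1.0 - 1.0 / N)));

  // for N = 1000, k = 900 this is roughly 2300 draws, well under the
  // 72 * 128 = 9216 budget, so the excess-sampling kernel would be chosen;
  // as k approaches N the draw count grows quickly and Algorithm L takes over
  if (max_samples_per_thread * block_threads >= n_parallel_samples) {
    std::printf("excess sampling: %d parallel draws per node\n", n_parallel_samples);
  } else {
    std::printf("reservoir sampling (Algorithm L)\n");
  }
  return 0;
}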
228 changes: 228 additions & 0 deletions cpp/src/decisiontree/batched-levelalgo/kernels/builder_kernels.cuh
@@ -20,6 +20,10 @@
#include "../objectives.cuh"
#include "../quantiles.h"

#include <raft/random/rng.hpp>

#include <cub/cub.cuh>

namespace ML {
namespace DT {

@@ -60,6 +64,13 @@ HDI bool SplitNotValid(const SplitT& split,
(IdxT(num_rows) - split.nLeft) < min_samples_leaf;
}

/* Returns 'dataset' rounded up to a correctly-aligned pointer of type OutT* */
template <typename OutT, typename InT>
DI OutT* alignPointer(InT dataset)
{
return reinterpret_cast<OutT*>(raft::alignTo(reinterpret_cast<size_t>(dataset), sizeof(OutT)));
}

template <typename DataT, typename LabelT, typename IdxT, int TPB>
__global__ void nodeSplitKernel(const IdxT max_depth,
const IdxT min_samples_leaf,
@@ -111,6 +122,222 @@ HDI IdxT lower_bound(DataT* array, IdxT len, DataT element)
return start;
}

template <typename IdxT>
struct CustomDifference {
__device__ IdxT operator()(const IdxT& lhs, const IdxT& rhs)
{
if (lhs == rhs)
return 0;
else
return 1;
}
};

/**
 * @brief Generates 'k' unique feature samples from the 'n'-feature sample space.
 *        Does this for each work item (node), deriving a unique random subsequence from
 *        (treeid, nodeid (=blockIdx.x), threadIdx.x). The method is a random, parallel
 *        sampling with replacement of an excess of 'k' samples (hence the name), followed
 *        by ordering to eliminate the duplicates. The excess number of samples
 *        (=`n_parallel_samples`) is calculated such that, after ordering, there are at
 *        least 'k' uniques.
*/
template <typename IdxT, int MAX_SAMPLES_PER_THREAD, int BLOCK_THREADS = 128>
__global__ void excess_sample_with_replacement_kernel(
IdxT* colids,
const NodeWorkItem* work_items,
size_t work_items_size,
IdxT treeid,
uint64_t seed,
size_t n /* total cols to sample from*/,
size_t k /* number of unique cols to sample */,
int n_parallel_samples /* number of cols to sample with replacement */)
{
if (blockIdx.x >= work_items_size) return;

const uint32_t nodeid = work_items[blockIdx.x].idx;

uint64_t subsequence(fnv1a32_basis);
subsequence = fnv1a32(subsequence, uint32_t(threadIdx.x));
subsequence = fnv1a32(subsequence, uint32_t(treeid));
subsequence = fnv1a32(subsequence, uint32_t(nodeid));

raft::random::PCGenerator gen(seed, subsequence, uint64_t(0));
raft::random::UniformIntDistParams<IdxT, uint64_t> uniform_int_dist_params;

uniform_int_dist_params.start = 0;
uniform_int_dist_params.end = n;
uniform_int_dist_params.diff =
uint64_t(uniform_int_dist_params.end - uniform_int_dist_params.start);

IdxT n_uniques = 0;
IdxT items[MAX_SAMPLES_PER_THREAD];
IdxT col_indices[MAX_SAMPLES_PER_THREAD];
IdxT mask[MAX_SAMPLES_PER_THREAD];
  // mask[i] == 1 marks a sample already known to be unique; start with all zeros
for (int i = 0; i < MAX_SAMPLES_PER_THREAD; ++i)
mask[i] = 0;

do {
// blocked arrangement
for (int cta_sample_idx = MAX_SAMPLES_PER_THREAD * threadIdx.x, thread_local_sample_idx = 0;
thread_local_sample_idx < MAX_SAMPLES_PER_THREAD;
++cta_sample_idx, ++thread_local_sample_idx) {
      // the mask from the previous iteration, if any, is re-used here
      // so that previously generated unique random numbers are kept.
      // newly generated random numbers may or may not duplicate the previously generated ones,
      // but this ensures some forward progress towards generating at least 'k' unique random
      // samples.
if (mask[thread_local_sample_idx] == 0 and cta_sample_idx < n_parallel_samples)
raft::random::custom_next(
gen, &items[thread_local_sample_idx], uniform_int_dist_params, IdxT(0), IdxT(0));
else if (mask[thread_local_sample_idx] ==
0) // indices that exceed `n_parallel_samples` will not generate
items[thread_local_sample_idx] = n - 1;
else
        continue;  // this case is for samples whose mask == 1 (keeping the random number
                   // generated in a previous iteration)
}

// Specialize BlockRadixSort type for our thread block
typedef cub::BlockRadixSort<IdxT, BLOCK_THREADS, MAX_SAMPLES_PER_THREAD> BlockRadixSortT;
// BlockAdjacentDifference
typedef cub::BlockAdjacentDifference<IdxT, BLOCK_THREADS> BlockAdjacentDifferenceT;
// BlockScan
typedef cub::BlockScan<IdxT, BLOCK_THREADS> BlockScanT;

// Shared memory
__shared__ union TempStorage {
typename BlockRadixSortT::TempStorage sort;
typename BlockAdjacentDifferenceT::TempStorage diff;
typename BlockScanT::TempStorage scan;
} temp_storage;

// collectively sort items
BlockRadixSortT(temp_storage.sort).Sort(items);

__syncthreads();

// compute the mask
// compute the adjacent differences according to the functor
// TODO: Replace deprecated 'FlagHeads' with 'SubtractLeft' when it is available
BlockAdjacentDifferenceT(temp_storage.diff)
.FlagHeads(mask, items, mask, CustomDifference<IdxT>());

__syncthreads();

// do a scan on the mask to get the indices for gathering
BlockScanT(temp_storage.scan).ExclusiveSum(mask, col_indices, n_uniques);

__syncthreads();

} while (n_uniques < k);

  // write only the unique items (those with mask[i] == 1) to colids[col_offset + col_indices[i]]
IdxT col_offset = k * blockIdx.x;
for (int i = 0; i < MAX_SAMPLES_PER_THREAD; ++i) {
if (mask[i] and col_indices[i] < k) { colids[col_offset + col_indices[i]] = items[i]; }
}
}
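// For intuition, a host-side analogue of what one thread block above does (illustration
// only, not part of this commit): oversample with replacement, then sort and deduplicate
// until at least 'k' uniques exist, and keep the first 'k'. The kernel performs the sort /
// adjacent-difference / prefix-sum steps with cub block-wide collectives and additionally
// carries uniques found in earlier iterations forward; that detail is omitted here.
#include <algorithm>
#include <cstdint>
#include <random>
#include <vector>

inline std::vector<int> excess_sample_cpu(int n, int k, int n_parallel_samples, uint64_t seed)
{
  std::mt19937_64 gen(seed);
  std::uniform_int_distribution<int> dist(0, n - 1);

  std::vector<int> uniques;
  while (static_cast<int>(uniques.size()) < k) {
    std::vector<int> draws(n_parallel_samples);
    for (auto& d : draws) d = dist(gen);    // sample with replacement
    std::sort(draws.begin(), draws.end());  // BlockRadixSort in the kernel
    draws.erase(std::unique(draws.begin(), draws.end()), draws.end());  // mask + exclusive scan
    uniques = std::move(draws);
  }
  uniques.resize(k);  // keep the first k unique feature ids, as the kernel does
  return uniques;
}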

// algo L of the reservoir sampling algorithm
/**
* @brief For each work item select 'k' features without replacement from 'n' features using algo-L.
* On exit each row of the colids array will contain k random integers from the [0..n-1] range.
*
 * Each thread works on a single row. The tree id, the seed and each work item's node id are
 * used to initialize a unique random subsequence per work item.
 *
 * @param colids the generated random indices, size [work_items_size, k], row-major layout
 * @param work_items the node work items to process (one per row of colids)
 * @param work_items_size number of work items
 * @param treeid id of the tree being built, used for seeding
 * @param seed random seed
 * @param n total cols to sample from
 * @param k number of cols to sample
 *
 * Reference: Algorithm L of reservoir sampling,
 * https://en.wikipedia.org/wiki/Reservoir_sampling#An_optimal_algorithm
*/
template <typename IdxT>
__global__ void algo_L_sample_kernel(int* colids,
const NodeWorkItem* work_items,
size_t work_items_size,
IdxT treeid,
uint64_t seed,
size_t n /* total cols to sample from*/,
size_t k /* cols to sample */)
{
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= work_items_size) return;
const uint32_t nodeid = work_items[tid].idx;
uint64_t subsequence = (uint64_t(treeid) << 32) | uint64_t(nodeid);
raft::random::PCGenerator gen(seed, subsequence, uint64_t(0));
raft::random::UniformIntDistParams<IdxT, uint64_t> uniform_int_dist_params;
uniform_int_dist_params.start = 0;
uniform_int_dist_params.end = k;
uniform_int_dist_params.diff =
uint64_t(uniform_int_dist_params.end - uniform_int_dist_params.start);
float fp_uniform_val;
IdxT int_uniform_val;
// fp_uniform_val will have a random value between 0 and 1
gen.next(fp_uniform_val);
double W = raft::myExp(raft::myLog(fp_uniform_val) / k);

size_t col(0);
  // initially fill the reservoir with columns 0 .. k-1 in increasing order
while (1) {
colids[tid * k + col] = col;
if (col == k - 1)
break;
else
++col;
}
// randomly sample from a geometric distribution
while (col < n) {
// fp_uniform_val will have a random value between 0 and 1
gen.next(fp_uniform_val);
col += static_cast<int>(raft::myLog(fp_uniform_val) / raft::myLog(1 - W)) + 1;
if (col < n) {
      // int_uniform_val will now have a random value in [0, k)
raft::random::custom_next(gen, &int_uniform_val, uniform_int_dist_params, IdxT(0), IdxT(0));
colids[tid * k + int_uniform_val] = col; // the bad memory coalescing here is hidden
// fp_uniform_val will have a random value between 0 and 1
gen.next(fp_uniform_val);
W *= raft::myExp(raft::myLog(fp_uniform_val) / k);
}
}
}
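// A single-threaded reference for Algorithm L (per the Wikipedia link above), shown for
// illustration only; one GPU thread of the kernel above does the equivalent work for one
// node. It samples 'k' of 'n' column indices without replacement.
#include <cmath>
#include <cstdint>
#include <random>
#include <vector>

inline std::vector<int> algo_L_cpu(int n, int k, uint64_t seed)
{
  std::mt19937_64 gen(seed);
  std::uniform_real_distribution<double> uni(0.0, 1.0);
  std::uniform_int_distribution<int> slot(0, k - 1);
  auto rand01 = [&]() { return 1.0 - uni(gen); };  // uniform in (0, 1], avoids log(0)

  std::vector<int> reservoir(k);
  for (int i = 0; i < k; ++i) reservoir[i] = i;  // fill the reservoir with columns 0..k-1

  double w = std::exp(std::log(rand01()) / k);
  int col  = k - 1;
  while (true) {
    // skip ahead by a geometrically distributed number of columns
    col += static_cast<int>(std::floor(std::log(rand01()) / std::log(1.0 - w))) + 1;
    if (col >= n) break;
    reservoir[slot(gen)] = col;  // replace a uniformly random reservoir slot
    w *= std::exp(std::log(rand01()) / k);
  }
  return reservoir;
}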

template <typename IdxT>
__global__ void adaptive_sample_kernel(int* colids,
const NodeWorkItem* work_items,
size_t work_items_size,
IdxT treeid,
uint64_t seed,
int N,
int M)
{
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= work_items_size) return;
const uint32_t nodeid = work_items[tid].idx;

uint64_t subsequence = (uint64_t(treeid) << 32) | uint64_t(nodeid);
raft::random::PCGenerator gen(seed, subsequence, uint64_t(0));

int selected_count = 0;
for (int i = 0; i < N; i++) {
uint32_t toss = 0;
gen.next(toss);
uint64_t lhs = uint64_t(M - selected_count);
lhs <<= 32;
uint64_t rhs = uint64_t(toss) * (N - i);
if (lhs > rhs) {
colids[tid * M + selected_count] = i;
selected_count++;
if (selected_count == M) break;
}
}
}
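// The acceptance test above reads as classic selection sampling: column 'i' is kept with
// probability (M - selected_count) / (N - i), evaluated in fixed point as
// (M - selected_count) << 32  >  toss * (N - i), where 'toss' is a 32-bit uniform draw.
// A host-side sketch of the same test (illustration only, not part of this commit):
#include <cstdint>
#include <random>
#include <vector>

inline std::vector<int> select_M_of_N_cpu(int N, int M, uint64_t seed)
{
  std::mt19937_64 gen(seed);
  std::vector<int> out;
  int selected_count = 0;
  for (int i = 0; i < N && selected_count < M; ++i) {
    const uint32_t toss = static_cast<uint32_t>(gen());  // uniform 32-bit draw
    const uint64_t lhs  = static_cast<uint64_t>(M - selected_count) << 32;
    const uint64_t rhs  = static_cast<uint64_t>(toss) * static_cast<uint64_t>(N - i);
    if (lhs > rhs) {  // accept i with probability (M - selected_count) / (N - i)
      out.push_back(i);
      ++selected_count;
    }
  }
  return out;  // exactly M indices in increasing order (for M <= N)
}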

template <typename DataT,
typename LabelT,
typename IdxT,
@@ -126,6 +353,7 @@ __global__ void computeSplitKernel(BinT* histograms,
const Quantiles<DataT, IdxT> quantiles,
const NodeWorkItem* work_items,
IdxT colStart,
const IdxT* colids,
int* done_count,
int* mutex,
volatile Split<DataT, IdxT>* splits,
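All three sampling kernels write their results in the same row-major layout, which appears to be what the new `colids` parameter of computeSplitKernel receives (how computeSplitKernel consumes it is not shown in this excerpt). The sketch below only illustrates that indexing convention; the sizes and feature ids are made up.

#include <cstdio>
#include <vector>

int main()
{
  const int num_work_items = 2;  // e.g. two tree nodes in the current batch
  const int k              = 3;  // dataset.n_sampled_cols
  // work item b's j-th sampled feature id lives at colids[b * k + j]
  std::vector<int> colids = {4, 9, 17,   // features sampled for work item 0
                             1, 5, 12};  // features sampled for work item 1

  for (int b = 0; b < num_work_items; ++b)
    for (int j = 0; j < k; ++j)
      std::printf("work item %d, sample %d -> feature %d\n", b, j, colids[b * k + j]);
  return 0;
}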