Use individual streams in DT builder as opposed to handle getstream (#4258)

* FIX Use individual streams in DT builder as opposed to handle getstream

* FIX clang format fixes

* FIX Pass stream from RF to DT

* FIX clang format fixes
dantegd authored Oct 5, 2021
1 parent 7102c02 commit bf6ade1
Showing 3 changed files with 54 additions and 48 deletions.
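
The substance of the builder.cuh change is mechanical but worth spelling out: the Builder now stores the stream it is constructed with as a builder_stream member and enqueues all of its asynchronous work (workspace memsets, kernel launches, and device/host copies) on that member instead of calling handle.get_stream() each time. Below is a minimal sketch of that pattern using plain CUDA runtime calls; ToyBuilder and its members are hypothetical stand-ins for illustration, not the actual cuML types (in the real class the device workspace is an rmm::device_uvector constructed on the same stream).

#include <cuda_runtime.h>

// Hypothetical illustration of the pattern this commit adopts: the builder
// keeps the stream it was handed and uses it for every asynchronous call.
struct ToyBuilder {
  cudaStream_t builder_stream;  // analogous to Builder::builder_stream
  int* d_workspace = nullptr;
  int n            = 0;

  ToyBuilder(cudaStream_t s, int n_ints) : builder_stream(s), n(n_ints)
  {
    cudaMalloc(&d_workspace, sizeof(int) * n);
    // zero the workspace on the builder's own stream, not a globally shared one
    cudaMemsetAsync(d_workspace, 0, sizeof(int) * n, builder_stream);
  }

  ~ToyBuilder() { cudaFree(d_workspace); }
};

int main()
{
  cudaStream_t s;
  cudaStreamCreate(&s);
  {
    ToyBuilder b(s, 128);
    cudaStreamSynchronize(s);  // everything the builder enqueued on s is now done
  }
  cudaStreamDestroy(s);
  return 0;
}
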
92 changes: 46 additions & 46 deletions cpp/src/decisiontree/batched-levelalgo/builder.cuh
@@ -156,6 +156,7 @@ struct Builder {
/** default threads per block for most kernels in here */
static constexpr int TPB_DEFAULT = 128;
const raft::handle_t& handle;
+ cudaStream_t builder_stream;
/** DT params */
DecisionTreeParams params;
/** input dataset */
@@ -194,6 +195,7 @@ struct Builder {
ML::pinned_host_vector<char> h_buff;

Builder(const raft::handle_t& handle,
+ cudaStream_t s,
IdxT treeid,
uint64_t seed,
const DecisionTreeParams& p,
@@ -205,6 +207,7 @@
IdxT nclasses,
std::shared_ptr<const rmm::device_uvector<DataT>> quantiles)
: handle(handle),
+ builder_stream(s),
treeid(treeid),
seed(seed),
params(p),
@@ -218,14 +221,14 @@
rowids->data(),
nclasses,
quantiles->data()},
- d_buff(0, handle.get_stream())
+ d_buff(0, builder_stream)
{
max_blocks = 1 + params.max_batch_size + input.nSampledRows / TPB_DEFAULT;
ASSERT(quantiles != nullptr, "Currently quantiles need to be computed before this call!");
ASSERT(nclasses >= 1, "nclasses should be at least 1");

auto [device_workspace_size, host_workspace_size] = workspaceSize();
- d_buff.resize(device_workspace_size, handle.get_stream());
+ d_buff.resize(device_workspace_size, builder_stream);
h_buff.resize(host_workspace_size);
assignWorkspace(d_buff.data(), h_buff.data());
}
@@ -301,8 +304,8 @@ struct Builder {
d_wspace += calculateAlignedBytes(sizeof(WorkloadInfo<IdxT>) * max_blocks);

CUDA_CHECK(
- cudaMemsetAsync(done_count, 0, sizeof(int) * max_batch * n_col_blks, handle.get_stream()));
- CUDA_CHECK(cudaMemsetAsync(mutex, 0, sizeof(int) * max_batch, handle.get_stream()));
+ cudaMemsetAsync(done_count, 0, sizeof(int) * max_batch * n_col_blks, builder_stream));
+ CUDA_CHECK(cudaMemsetAsync(mutex, 0, sizeof(int) * max_batch, builder_stream));

// host
h_workload_info = reinterpret_cast<WorkloadInfo<IdxT>*>(h_wspace);
@@ -349,19 +352,19 @@
}
total_num_blocks += num_blocks;
}
- raft::update_device(workload_info, h_workload_info, total_num_blocks, handle.get_stream());
+ raft::update_device(workload_info, h_workload_info, total_num_blocks, builder_stream);
return std::make_pair(total_num_blocks, n_large_nodes_in_curr_batch);
}

auto doSplit(const std::vector<NodeWorkItem>& work_items)
{
ML::PUSH_RANGE("Builder::doSplit @bulder_base.cuh [batched-levelalgo]");
// start fresh on the number of *new* nodes created in this batch
- CUDA_CHECK(cudaMemsetAsync(n_nodes, 0, sizeof(IdxT), handle.get_stream()));
- initSplit<DataT, IdxT, TPB_DEFAULT>(splits, work_items.size(), handle.get_stream());
+ CUDA_CHECK(cudaMemsetAsync(n_nodes, 0, sizeof(IdxT), builder_stream));
+ initSplit<DataT, IdxT, TPB_DEFAULT>(splits, work_items.size(), builder_stream);

// get the current set of nodes to be worked upon
- raft::update_device(d_work_items, work_items.data(), work_items.size(), handle.get_stream());
+ raft::update_device(d_work_items, work_items.data(), work_items.size(), builder_stream);

auto [total_blocks, large_blocks] = this->updateWorkloadInfo(work_items);

@@ -376,19 +379,18 @@
auto smemSize = 2 * sizeof(IdxT) * TPB_DEFAULT;
ML::PUSH_RANGE("nodeSplitKernel @builder_base.cuh [batched-levelalgo]");
nodeSplitKernel<DataT, LabelT, IdxT, ObjectiveT, TPB_DEFAULT>
- <<<work_items.size(), TPB_DEFAULT, smemSize, handle.get_stream()>>>(
- params.max_depth,
- params.min_samples_leaf,
- params.min_samples_split,
- params.max_leaves,
- params.min_impurity_decrease,
- input,
- d_work_items,
- splits);
+ <<<work_items.size(), TPB_DEFAULT, smemSize, builder_stream>>>(params.max_depth,
+ params.min_samples_leaf,
+ params.min_samples_split,
+ params.max_leaves,
+ params.min_impurity_decrease,
+ input,
+ d_work_items,
+ splits);
CUDA_CHECK(cudaGetLastError());
ML::POP_RANGE();
- raft::update_host(h_splits, splits, work_items.size(), handle.get_stream());
- CUDA_CHECK(cudaStreamSynchronize(handle.get_stream()));
+ raft::update_host(h_splits, splits, work_items.size(), builder_stream);
+ CUDA_CHECK(cudaStreamSynchronize(builder_stream));
ML::POP_RANGE();
return std::make_tuple(h_splits, work_items.size());
}
@@ -421,25 +423,25 @@
auto smemSize = computeSplitSmemSize();
dim3 grid(total_blocks, colBlks, 1);
int nHistBins = large_blocks * nbins * colBlks * nclasses;
- CUDA_CHECK(cudaMemsetAsync(hist, 0, sizeof(BinT) * nHistBins, handle.get_stream()));
+ CUDA_CHECK(cudaMemsetAsync(hist, 0, sizeof(BinT) * nHistBins, builder_stream));
ML::PUSH_RANGE("computeSplitClassificationKernel @builder_base.cuh [batched-levelalgo]");
ObjectiveT objective(input.numOutputs, params.min_samples_leaf);
computeSplitKernel<DataT, LabelT, IdxT, TPB_DEFAULT>
- <<<grid, TPB_DEFAULT, smemSize, handle.get_stream()>>>(hist,
- params.n_bins,
- params.max_depth,
- params.min_samples_split,
- params.max_leaves,
- input,
- d_work_items,
- col,
- done_count,
- mutex,
- splits,
- objective,
- treeid,
- workload_info,
- seed);
+ <<<grid, TPB_DEFAULT, smemSize, builder_stream>>>(hist,
+ params.n_bins,
+ params.max_depth,
+ params.min_samples_split,
+ params.max_leaves,
+ input,
+ d_work_items,
+ col,
+ done_count,
+ mutex,
+ splits,
+ objective,
+ treeid,
+ workload_info,
+ seed);
ML::POP_RANGE(); // computeSplitKernel
ML::POP_RANGE(); // Builder::computeSplit
}
@@ -453,32 +455,30 @@
"Expected instance range for each node");
// do this in batch to reduce peak memory usage in extreme cases
std::size_t max_batch_size = min(std::size_t(100000), tree->sparsetree.size());
- rmm::device_uvector<NodeT> d_tree(max_batch_size, handle.get_stream());
- rmm::device_uvector<InstanceRange> d_instance_ranges(max_batch_size, handle.get_stream());
- rmm::device_uvector<DataT> d_leaves(max_batch_size * input.numOutputs, handle.get_stream());
+ rmm::device_uvector<NodeT> d_tree(max_batch_size, builder_stream);
+ rmm::device_uvector<InstanceRange> d_instance_ranges(max_batch_size, builder_stream);
+ rmm::device_uvector<DataT> d_leaves(max_batch_size * input.numOutputs, builder_stream);

ObjectiveT objective(input.numOutputs, params.min_samples_leaf);
for (std::size_t batch_begin = 0; batch_begin < tree->sparsetree.size();
batch_begin += max_batch_size) {
std::size_t batch_end = min(batch_begin + max_batch_size, tree->sparsetree.size());
std::size_t batch_size = batch_end - batch_begin;
raft::update_device(
- d_tree.data(), tree->sparsetree.data() + batch_begin, batch_size, handle.get_stream());
- raft::update_device(d_instance_ranges.data(),
- instance_ranges.data() + batch_begin,
- batch_size,
- handle.get_stream());
+ d_tree.data(), tree->sparsetree.data() + batch_begin, batch_size, builder_stream);
+ raft::update_device(
+ d_instance_ranges.data(), instance_ranges.data() + batch_begin, batch_size, builder_stream);

CUDA_CHECK(
- cudaMemsetAsync(d_leaves.data(), 0, sizeof(DataT) * d_leaves.size(), handle.get_stream()));
+ cudaMemsetAsync(d_leaves.data(), 0, sizeof(DataT) * d_leaves.size(), builder_stream));
size_t smemSize = sizeof(BinT) * input.numOutputs;
int num_blocks = batch_size;
- leafKernel<<<num_blocks, TPB_DEFAULT, smemSize, handle.get_stream()>>>(
+ leafKernel<<<num_blocks, TPB_DEFAULT, smemSize, builder_stream>>>(
objective, input, d_tree.data(), d_instance_ranges.data(), d_leaves.data());
raft::update_host(tree->vector_leaf.data() + batch_begin * input.numOutputs,
d_leaves.data(),
batch_size * input.numOutputs,
- handle.get_stream());
+ builder_stream);
}
}
}; // end Builder
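
One detail from doSplit above carries over here as well: the device-to-host copy of the results is issued on builder_stream, so synchronizing that single stream is all that is needed before the host reads them (in the real builder the host buffer is pinned via ML::pinned_host_vector, which lets the copy run asynchronously). Below is a small hedged sketch of the same launch, async copy, then synchronize-one-stream sequence, with a stand-in fill kernel instead of the cuML kernels:

#include <cuda_runtime.h>
#include <cstdio>

// Stand-in kernel for the per-node work a Builder would launch.
__global__ void fill(int* out, int n, int value)
{
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) { out[i] = value; }
}

int main()
{
  const int n = 256;
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  int* d_splits = nullptr;
  cudaMalloc(&d_splits, n * sizeof(int));
  int h_splits[n];

  // kernel and device-to-host copy are ordered on the same stream, so
  // synchronizing that one stream is enough before touching h_splits
  fill<<<(n + 127) / 128, 128, 0, stream>>>(d_splits, n, 42);
  cudaMemcpyAsync(h_splits, d_splits, n * sizeof(int), cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);

  printf("h_splits[0] = %d\n", h_splits[0]);
  cudaFree(d_splits);
  cudaStreamDestroy(stream);
  return 0;
}
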
5 changes: 5 additions & 0 deletions cpp/src/decisiontree/decisiontree.cuh
@@ -232,6 +232,7 @@ class DecisionTree {
template <class DataT, class LabelT>
static std::shared_ptr<DT::TreeMetaDataNode<DataT, LabelT>> fit(
const raft::handle_t& handle,
+ const cudaStream_t s,
const DataT* data,
const int ncols,
const int nrows,
@@ -253,6 +254,7 @@
// Dispatch objective
if (params.split_criterion == CRITERION::GINI) {
return Builder<GiniObjectiveFunction<DataT, LabelT, IdxT>>(handle,
+ s,
treeid,
seed,
params,
@@ -266,6 +268,7 @@
.train();
} else if (params.split_criterion == CRITERION::ENTROPY) {
return Builder<EntropyObjectiveFunction<DataT, LabelT, IdxT>>(handle,
+ s,
treeid,
seed,
params,
@@ -279,6 +282,7 @@
.train();
} else if (params.split_criterion == CRITERION::MSE) {
return Builder<MSEObjectiveFunction<DataT, LabelT, IdxT>>(handle,
+ s,
treeid,
seed,
params,
@@ -292,6 +296,7 @@
.train();
} else if (params.split_criterion == CRITERION::POISSON) {
return Builder<PoissonObjectiveFunction<DataT, LabelT, IdxT>>(handle,
+ s,
treeid,
seed,
params,
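
The fit() change itself is pure plumbing: the new stream parameter is forwarded unchanged into whichever Builder specialization the split criterion selects. A schematic of that shape; Criterion and the build_* functions here are illustrative stand-ins, not cuML's actual types:

#include <cuda_runtime.h>

enum class Criterion { GINI, ENTROPY, MSE, POISSON };

// Stand-ins for constructing a Builder with a given objective; the only point
// is that the caller-provided stream reaches every branch unchanged.
void build_gini(cudaStream_t s) { /* Builder<GiniObjective...>(handle, s, ...) */ }
void build_entropy(cudaStream_t s) { /* Builder<EntropyObjective...>(handle, s, ...) */ }
void build_mse(cudaStream_t s) { /* Builder<MSEObjective...>(handle, s, ...) */ }
void build_poisson(cudaStream_t s) { /* Builder<PoissonObjective...>(handle, s, ...) */ }

void fit(Criterion c, cudaStream_t s)
{
  switch (c) {
    case Criterion::GINI: build_gini(s); break;
    case Criterion::ENTROPY: build_entropy(s); break;
    case Criterion::MSE: build_mse(s); break;
    default: build_poisson(s); break;
  }
}

int main()
{
  cudaStream_t s;
  cudaStreamCreate(&s);
  fit(Criterion::GINI, s);
  cudaStreamSynchronize(s);
  cudaStreamDestroy(s);
  return 0;
}
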
5 changes: 3 additions & 2 deletions cpp/src/randomforest/randomforest.cuh
@@ -162,9 +162,9 @@ class RandomForest {
#pragma omp parallel for num_threads(n_streams)
for (int i = 0; i < this->rf_params.n_trees; i++) {
int stream_id = omp_get_thread_num();
+ auto s = handle.get_internal_stream(stream_id);

- this->get_row_sample(
- i, n_rows, &selected_rows[stream_id], handle.get_internal_stream(stream_id));
+ this->get_row_sample(i, n_rows, &selected_rows[stream_id], s);

/* Build individual tree in the forest.
- input is a pointer to orig data that have n_cols features and n_rows rows.
@@ -176,6 +176,7 @@
*/

forest->trees[i] = DT::DecisionTree::fit(handle,
+ s,
input,
n_cols,
n_rows,
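
On the random-forest side, each OpenMP worker thread now looks up one of the handle's internal streams once, keyed by its thread id, and hands that stream to both the row sampler and DT::DecisionTree::fit, so every tree built by that thread runs on that thread's stream. The sketch below is a simplified stand-in for that loop, using plain CUDA streams and a dummy memset in place of handle.get_internal_stream and the real per-tree work:

#include <cuda_runtime.h>
#include <omp.h>
#include <vector>

int main()
{
  const int n_streams = 4;
  const int n_trees   = 16;

  std::vector<cudaStream_t> streams(n_streams);
  std::vector<int*> scratch(n_streams);
  for (int i = 0; i < n_streams; i++) {
    cudaStreamCreate(&streams[i]);
    cudaMalloc(&scratch[i], sizeof(int));
  }

#pragma omp parallel for num_threads(n_streams)
  for (int i = 0; i < n_trees; i++) {
    int stream_id  = omp_get_thread_num();
    cudaStream_t s = streams[stream_id];  // analogous to handle.get_internal_stream(stream_id)
    // stand-in for: get_row_sample(..., s); DT::DecisionTree::fit(handle, s, ...);
    cudaMemsetAsync(scratch[stream_id], 0, sizeof(int), s);
  }

  for (int i = 0; i < n_streams; i++) {
    cudaStreamSynchronize(streams[i]);
    cudaFree(scratch[i]);
    cudaStreamDestroy(streams[i]);
  }
  return 0;
}
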
