
Merge pull request rapidsai#3892 from ajschmidt8/branch-21.08-merge-21.06

Fix merge conflicts [skip ci]
ajschmidt8 authored May 24, 2021
2 parents c4f19c9 + 1568bba commit 3fe9cbb
Showing 29 changed files with 420 additions and 813 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
@@ -2,9 +2,9 @@

Please see https://github.com/rapidsai/cuml/releases/tag/v21.08.0a for the latest changes to this development branch.

-# cuML 0.20.0 (Date TBD)
+# cuML 21.06.00 (Date TBD)

-Please see https://github.com/rapidsai/cuml/releases/tag/v0.20.0a for the latest changes to this development branch.
+Please see https://github.com/rapidsai/cuml/releases/tag/v21.06.00a for the latest changes to this development branch.

# cuML 0.19.0 (21 Apr 2021)

11 changes: 7 additions & 4 deletions cpp/include/cuml/manifold/umapparams.h
@@ -1,5 +1,5 @@
/*
-* Copyright (c) 2019-2020, NVIDIA CORPORATION.
+* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -150,9 +150,12 @@ class UMAPParams {

uint64_t random_state = 0;

-bool multicore_implem = true;
-
-int optim_batch_size = 0;
+/**
+* Whether to use the deterministic algorithm. This should be set to true if
+random_state is provided; otherwise it is false. When true, cuML will have
+higher memory usage but produce stable numeric output.
+*/
+bool deterministic = true;

Internals::GraphBasedDimRedCallback* callback = nullptr;
};
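The new deterministic flag replaces the old multicore_implem/optim_batch_size pair: reproducibility is now a single switch that trades extra memory for stable output. A minimal usage sketch, not part of this commit; the configure() helper is illustrative and assumes only the two fields shown above:

#include <cuml/manifold/umapparams.h>

#include <cstdint>

// Pair a caller-supplied seed with the deterministic algorithm, as the doc
// comment above prescribes: set deterministic when random_state is fixed.
void configure(ML::UMAPParams& params, std::uint64_t seed) {
  params.random_state = seed;    // fixed seed for reproducible runs
  params.deterministic = true;   // higher memory use, stable numeric output
}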
27 changes: 5 additions & 22 deletions cpp/src/decisiontree/batched-levelalgo/builder_base.cuh
@@ -71,10 +71,6 @@ struct Builder {
int* hist;
/** sum of predictions (regression only) */
DataT* pred;
-/** MAE computation (regression only) */
-DataT* pred2;
-/** parent MAE computation (regression only) */
-DataT* pred2P;
/** node count tracker for averaging (regression only) */
IdxT* pred_count;
/** threadblock arrival count */
@@ -221,9 +217,6 @@ struct Builder {
// x2 for left and right children
d_wsize +=
calculateAlignedBytes(2 * nPredCounts * sizeof(DataT)); // pred
-d_wsize +=
-calculateAlignedBytes(2 * nPredCounts * sizeof(DataT));  // pred2
-d_wsize += calculateAlignedBytes(nPredCounts * sizeof(DataT));  // pred2P
d_wsize +=
calculateAlignedBytes(nPredCounts * sizeof(IdxT)); // pred_count
}
@@ -263,10 +256,6 @@ struct Builder {
} else {
pred = reinterpret_cast<DataT*>(d_wspace);
d_wspace += calculateAlignedBytes(2 * nPredCounts * sizeof(DataT));
-pred2 = reinterpret_cast<DataT*>(d_wspace);
-d_wspace += calculateAlignedBytes(2 * nPredCounts * sizeof(DataT));
-pred2P = reinterpret_cast<DataT*>(d_wspace);
-d_wspace += calculateAlignedBytes(nPredCounts * sizeof(DataT));
pred_count = reinterpret_cast<IdxT*>(d_wspace);
d_wspace += calculateAlignedBytes(nPredCounts * sizeof(IdxT));
}
@@ -536,9 +525,6 @@ struct RegTraits {
nbins * sizeof(int) + // pdf_scount
nbins * sizeof(int) + // cdf_scount
nbins * sizeof(DataT) + // sbins
-2 * nbins * sizeof(DataT) +  // spred2
-nbins * sizeof(DataT) +      // spred2P
-nbins * sizeof(DataT) +      // spredP
sizeof(int); // sDone
// Room for alignment (see alignPointer in computeSplitRegressionKernel)
smemSize1 += 6 * sizeof(DataT) + 3 * sizeof(int);
@@ -556,21 +542,18 @@ struct RegTraits {

CUDA_CHECK(
cudaMemsetAsync(b.pred, 0, sizeof(DataT) * b.nPredCounts * 2, s));
-CUDA_CHECK(
-cudaMemsetAsync(b.pred2, 0, sizeof(DataT) * b.nPredCounts * 2, s));
-CUDA_CHECK(cudaMemsetAsync(b.pred2P, 0, sizeof(DataT) * b.nPredCounts, s));
CUDA_CHECK(
cudaMemsetAsync(b.pred_count, 0, sizeof(IdxT) * b.nPredCounts, s));

ML::PUSH_RANGE(
"computeSplitRegressionKernel @builder_base.cuh [batched-levelalgo]");
computeSplitRegressionKernel<DataT, DataT, IdxT, TPB_DEFAULT>
<<<grid, TPB_DEFAULT, smemSize, s>>>(
-b.pred, b.pred2, b.pred2P, b.pred_count, b.params.n_bins,
-b.params.max_depth, b.params.min_samples_split,
-b.params.min_samples_leaf, b.params.min_impurity_decrease,
-b.params.max_leaves, b.input, b.curr_nodes, col, b.done_count, b.mutex,
-b.n_leaves, b.splits, b.block_sync, splitType, b.treeid, b.seed);
+b.pred, b.pred_count, b.params.n_bins, b.params.max_depth,
+b.params.min_samples_split, b.params.min_samples_leaf,
+b.params.min_impurity_decrease, b.params.max_leaves, b.input,
+b.curr_nodes, col, b.done_count, b.mutex, b.n_leaves, b.splits,
+b.block_sync, splitType, b.treeid, b.seed);
ML::POP_RANGE(); //computeSplitRegressionKernel
ML::POP_RANGE(); //Builder::computeSplit
}
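Builder sizes one device workspace up front and then carves typed sub-buffers out of it, bumping a raw pointer by the same aligned byte counts it summed earlier; dropping pred2 and pred2P removes the buffers from both passes. A self-contained sketch of the carving idiom, where alignedBytes stands in for cuML's calculateAlignedBytes and the 256-byte alignment is an assumption:

#include <cstddef>

// Round a byte count up to a fixed alignment (illustrative stand-in for the
// calculateAlignedBytes helper used in the diff above).
inline std::size_t alignedBytes(std::size_t nbytes, std::size_t align = 256) {
  return (nbytes + align - 1) / align * align;
}

struct Workspace {
  double* pred;     // per-bin label sums, x2 for left/right children
  int* pred_count;  // per-bin sample counts
};

// Pass 1: accumulate the total size in the same order the buffers are used.
inline std::size_t workspaceSize(int nPredCounts) {
  std::size_t d_wsize = 0;
  d_wsize += alignedBytes(2 * nPredCounts * sizeof(double));  // pred
  d_wsize += alignedBytes(nPredCounts * sizeof(int));         // pred_count
  return d_wsize;
}

// Pass 2: walk the single allocation in that same order, casting each carved
// region to its element type.
inline Workspace assignWorkspace(char* d_wspace, int nPredCounts) {
  Workspace w;
  w.pred = reinterpret_cast<double*>(d_wspace);
  d_wspace += alignedBytes(2 * nPredCounts * sizeof(double));
  w.pred_count = reinterpret_cast<int*>(d_wspace);
  return w;
}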
103 changes: 7 additions & 96 deletions cpp/src/decisiontree/batched-levelalgo/kernels.cuh
@@ -510,9 +510,8 @@ __global__ void computeSplitClassificationKernel(

template <typename DataT, typename LabelT, typename IdxT, int TPB>
__global__ void computeSplitRegressionKernel(
-DataT* pred, DataT* pred2, DataT* pred2P, IdxT* count, IdxT nbins,
-IdxT max_depth, IdxT min_samples_split, IdxT min_samples_leaf,
-DataT min_impurity_decrease, IdxT max_leaves,
+DataT* pred, IdxT* count, IdxT nbins, IdxT max_depth, IdxT min_samples_split,
+IdxT min_samples_leaf, DataT min_impurity_decrease, IdxT max_leaves,
Input<DataT, LabelT, IdxT> input, const Node<DataT, LabelT, IdxT>* nodes,
IdxT colStart, int* done_count, int* mutex, const IdxT* n_leaves,
volatile Split<DataT, IdxT>* splits, void* workspace, CRITERION splitType,
@@ -531,9 +530,8 @@ __global__ void computeSplitRegressionKernel(

// variables
auto end = range_start + range_len;
-auto len = nbins * 2;
auto pdf_spred_len = 1 + nbins;
-auto cdf_spred_len = 2 * nbins;
+auto cdf_spred_len = nbins;
IdxT stride = blockDim.x * gridDim.x;
IdxT tid = threadIdx.x + blockIdx.x * blockDim.x;
IdxT col;
@@ -544,10 +542,7 @@ __global__ void computeSplitRegressionKernel(
auto* pdf_scount = alignPointer<int>(cdf_spred + cdf_spred_len);
auto* cdf_scount = alignPointer<int>(pdf_scount + nbins);
auto* sbins = alignPointer<DataT>(cdf_scount + nbins);
-auto* spred2 = alignPointer<DataT>(sbins + nbins);
-auto* spred2P = alignPointer<DataT>(spred2 + len);
-auto* spredP = alignPointer<DataT>(spred2P + nbins);
-auto* sDone = alignPointer<int>(spredP + nbins);
+auto* sDone = alignPointer<int>(sbins + nbins);

// select random feature to split-check
// (if feature-sampling is true)
@@ -603,22 +598,13 @@ __global__ void computeSplitRegressionKernel(
__threadfence(); // for commit guarantee
__syncthreads();

-// Wait until all blockIdx.x's are done
-MLCommon::GridSync gs(workspace, MLCommon::SyncType::ACROSS_X, false);
-gs.sync();
-
// transfer from global to smem
for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
pdf_scount[i] = count[gcOffset + i];
-spred2P[i] = DataT(0.0);
}
for (IdxT i = threadIdx.x; i < pdf_spred_len; i += blockDim.x) {
pdf_spred[i] = pred[gOffset + i];
}
-// memset spred2
-for (IdxT i = threadIdx.x; i < len; i += blockDim.x) {
-spred2[i] = DataT(0.0);
-}
__syncthreads();

/** pdf to cdf conversion **/
@@ -627,77 +613,10 @@ __global__ void computeSplitRegressionKernel(
// cdf of samples lesser-than-equal to threshold
DataT total_sum = pdf_to_cdf<DataT, IdxT, TPB>(pdf_spred, cdf_spred, nbins);

-// cdf of samples greater than threshold
-// calculated by subtracting lesser-than-equals from total_sum
-for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
-*(cdf_spred + nbins + i) = total_sum - *(cdf_spred + i);
-}
-
/** get cdf of scount from pdf_scount **/
pdf_to_cdf<int, IdxT, TPB>(pdf_scount, cdf_scount, nbins);
__syncthreads();

-// calculating prediction average-sums
-for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
-spredP[i] = cdf_spred[i] + cdf_spred[i + nbins];
-}
-__syncthreads();
-
-// now, compute the mean value to be used for metric update
-auto invlen = DataT(1.0) / range_len;
-for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
-auto cnt_l = DataT(cdf_scount[i]);
-auto cnt_r = DataT(range_len - cdf_scount[i]);
-cdf_spred[i] /= cnt_l;
-cdf_spred[i + nbins] /= cnt_r;
-spredP[i] *= invlen;
-}
-__syncthreads();
-
-/* Make a second pass over the data to compute gain */
-
-// 2nd pass over data to compute partial metric across blockIdx.x's
-if (splitType == CRITERION::MAE) {
-for (auto i = range_start + tid; i < end; i += stride) {
-auto row = input.rowids[i];
-auto d = input.data[row + coloffset];
-auto label = input.labels[row];
-for (IdxT b = 0; b < nbins; ++b) {
-auto isRight = d > sbins[b];  // no divergence
-auto offset = isRight * nbins + b;
-auto diff = label - (isRight ? cdf_spred[nbins + b] : cdf_spred[b]);
-atomicAdd(spred2 + offset, raft::myAbs(diff));
-atomicAdd(spred2P + b, raft::myAbs(label - spredP[b]));
-}
-}
-} else {
-for (auto i = range_start + tid; i < end; i += stride) {
-auto row = input.rowids[i];
-auto d = input.data[row + coloffset];
-auto label = input.labels[row];
-for (IdxT b = 0; b < nbins; ++b) {
-auto isRight = d > sbins[b];  // no divergence
-auto offset = isRight * nbins + b;
-auto diff = label - (isRight ? cdf_spred[nbins + b] : cdf_spred[b]);
-auto diff2 = label - spredP[b];
-atomicAdd(spred2 + offset, (diff * diff));
-atomicAdd(spred2P + b, (diff2 * diff2));
-}
-}
-}
-__syncthreads();
-
-// update the corresponding global location for pred2P
-for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
-atomicAdd(pred2P + gcOffset + i, spred2P[i]);
-}
-
-// changing gOffset for pred2 from that of pred
-gOffset = ((nid * gridDim.y) + blockIdx.y) * len;
-// update the corresponding global location for pred2
-for (IdxT i = threadIdx.x; i < len; i += blockDim.x) {
-atomicAdd(pred2 + gOffset + i, spred2[i]);
-}
__threadfence(); // for commit guarantee
__syncthreads();

@@ -716,19 +635,11 @@ __global__ void computeSplitRegressionKernel(
Split<DataT, IdxT> sp;
sp.init();

-// store global pred2 and pred2P into shared memory of last x-dim block
-for (IdxT i = threadIdx.x; i < len; i += blockDim.x) {
-spred2[i] = pred2[gOffset + i];
-}
-for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
-spred2P[i] = pred2P[gcOffset + i];
-}
-__syncthreads();
-
// calculate the best candidate bins (one for each block-thread) in current
// feature and corresponding regression-metric gain for splitting
-regressionMetricGain(spred2, spred2P, cdf_scount, sbins, sp, col, range_len,
-nbins, min_samples_leaf, min_impurity_decrease);
+regressionMetricGain(cdf_spred, cdf_scount, total_sum, sbins, sp, col,
+range_len, nbins, min_samples_leaf,
+min_impurity_decrease);
__syncthreads();

// calculate best bins among candidate bins per feature using warp reduce
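After this change the regression split needs only one pass over the data: the kernel histograms per-bin label sums and counts (a pdf), pdf_to_cdf scans them into running totals (a cdf) and returns the node-wide label sum, and the gain is read straight off those totals. A serial host-side sketch of that scan, an illustrative stand-in for the cooperative block-wide version the kernel uses:

#include <cstddef>
#include <vector>

// Inclusive prefix sum over per-bin values: cdf[i] becomes the total for all
// samples at or below bin i, and the return value is the grand total,
// matching total_sum in the kernel above.
template <typename T>
T pdf_to_cdf(const std::vector<T>& pdf, std::vector<T>& cdf) {
  T running = T(0);
  for (std::size_t i = 0; i < pdf.size(); ++i) {
    running += pdf[i];
    cdf[i] = running;
  }
  return running;
}

With the cdf in hand, the left child of a split at bin i is described by cdf_spred[i] and cdf_scount[i] and the right child by subtraction from the totals, which is what makes the spred2/spred2P buffers and the second data pass unnecessary.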
41 changes: 10 additions & 31 deletions cpp/src/decisiontree/batched-levelalgo/metrics.cuh
@@ -184,47 +184,26 @@ DI void entropyGain(int* shist, DataT* sbins, Split<DataT, IdxT>& sp, IdxT col,
}
}

-/**
-* @brief Compute gain based on MSE or MAE
-*
-* @param[in] spred left/right child sum of abs diff of
-* prediction for all bins [dim = 2 x bins]
-* @param[in] spredP parent's sum of abs diff of prediction
-* for all bins [dim = 2 x bins]
-* @param[in] scount left child count for all bins
-* [len = nbins]
-* @param[in] sbins quantiles for the current column
-* [len = nbins]
-* @param[inout] sp will contain the per-thread best split
-* so far
-* @param[in] col current column
-* @param[in] len total number of samples for current node
-* to be split
-* @param[in] nbins number of bins
-* @param[in] min_samples_leaf minimum number of samples per each leaf.
-* Any splits that lead to a leaf node with
-* samples fewer than min_samples_leaf will
-* be ignored.
-* @param[in] min_impurity_decrease minimum improvement in MSE metric. Any
-* splits that do not improve (decrease)
-* the MSE metric at least by this amount
-* will be ignored.
-*/
template <typename DataT, typename IdxT>
-DI void regressionMetricGain(DataT* spred, DataT* spredP, IdxT* scount,
-DataT* sbins, Split<DataT, IdxT>& sp, IdxT col,
-IdxT len, IdxT nbins, IdxT min_samples_leaf,
+DI void regressionMetricGain(DataT* slabel_cdf, IdxT* scount_cdf,
+DataT label_sum, DataT* sbins,
+Split<DataT, IdxT>& sp, IdxT col, IdxT len,
+IdxT nbins, IdxT min_samples_leaf,
DataT min_impurity_decrease) {
auto invlen = DataT(1.0) / len;
for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
-auto nLeft = scount[i];
+auto nLeft = scount_cdf[i];
auto nRight = len - nLeft;
DataT gain;
// if there aren't enough samples in this split, don't bother!
if (nLeft < min_samples_leaf || nRight < min_samples_leaf) {
gain = -NumericLimits<DataT>::kMax;
} else {
-gain = spredP[i] - spred[i] - spred[i + nbins];
+DataT parent_obj = -label_sum * label_sum / len;
+DataT left_obj = -(slabel_cdf[i] * slabel_cdf[i]) / nLeft;
+DataT right_label_sum = slabel_cdf[i] - label_sum;
+DataT right_obj = -(right_label_sum * right_label_sum) / nRight;
+gain = parent_obj - (left_obj + right_obj);
gain *= invlen;
}
// if the gain is not "enough", don't bother!
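The new gain is the standard variance-reduction identity. Writing S, S_L, S_R for the parent, left, and right label sums and n, n_L, n_R for the counts, the sum-of-squared-labels terms cancel between parent and children, leaving exactly the objectives that parent_obj, left_obj, and right_obj compute; a sketch of the algebra, not text from the commit:

\begin{aligned}
\text{gain} &= \frac{1}{n}\sum_{i}(y_i-\bar{y})^2
            - \frac{1}{n}\sum_{i\in L}(y_i-\bar{y}_L)^2
            - \frac{1}{n}\sum_{i\in R}(y_i-\bar{y}_R)^2 \\
            &= \frac{1}{n}\Big[\Big(\sum_{i} y_i^2-\frac{S^2}{n}\Big)
            - \Big(\sum_{i\in L} y_i^2-\frac{S_L^2}{n_L}\Big)
            - \Big(\sum_{i\in R} y_i^2-\frac{S_R^2}{n_R}\Big)\Big] \\
            &= \frac{1}{n}\Big(\frac{S_L^2}{n_L}+\frac{S_R^2}{n_R}-\frac{S^2}{n}\Big).
\end{aligned}

Since S_R = S - S_L, the code's right_label_sum = slabel_cdf[i] - label_sum equals -S_R, whose square is S_R^2, so no right-side cdf needs to be stored.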
1 change: 1 addition & 0 deletions cpp/src/decisiontree/decisiontree.cu
@@ -95,6 +95,7 @@ void validity_check(const DecisionTreeParams params) {
"max_features value %f outside permitted (0, 1] range",
params.max_features);
ASSERT((params.n_bins > 0), "Invalid n_bins %d", params.n_bins);
+ASSERT((params.split_criterion != 3), "MAE not supported.");
ASSERT((params.split_algo >= 0) &&
(params.split_algo < SPLIT_ALGO::SPLIT_ALGO_END),
"split_algo value %d outside permitted [0, %d) range",
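The literal 3 in the new assertion is the integer value of the MAE criterion. A sketch of the same check written against the enumeration, assuming cuML's CRITERION enum of this era (GINI, ENTROPY, MSE, MAE, ...), in which MAE is the fourth enumerator and hence 3:

#include <cassert>

// Illustrative stand-ins for cuML's CRITERION enum and ASSERT macro.
enum CRITERION { GINI, ENTROPY, MSE, MAE, CRITERION_END };

void check_split_criterion(CRITERION split_criterion) {
  // Same check as the diff above, without the magic number.
  assert(split_criterion != MAE && "MAE not supported.");
}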
15 changes: 15 additions & 0 deletions cpp/src/hierarchy/pw_dist_graph.cuh
@@ -24,6 +24,7 @@
#include <raft/distance/distance.cuh>

#include <rmm/device_uvector.hpp>
+#include <rmm/exec_policy.hpp>

#include <raft/linalg/distance_type.h>
#include <raft/mr/device/buffer.hpp>
@@ -71,6 +72,7 @@ void pairwise_distances(const raft::handle_t &handle, const value_t *X,
value_idx *indptr, value_idx *indices, value_t *data) {
auto d_alloc = handle.get_device_allocator();
auto stream = handle.get_stream();
+auto exec_policy = rmm::exec_policy(stream);

value_idx nnz = m * m;

@@ -90,6 +92,19 @@ void pairwise_distances(const raft::handle_t &handle, const value_t *X,
// usage to hand it a sparse array here.
raft::distance::pairwise_distance<value_t, value_idx>(
X, X, data, m, m, n, workspace, metric, stream);

+// self-loops get max distance
+auto transform_in = thrust::make_zip_iterator(
+thrust::make_tuple(thrust::make_counting_iterator(0), data));
+
+thrust::transform(
+exec_policy, transform_in, transform_in + nnz, data,
+[=] __device__(const thrust::tuple<value_idx, value_t> &tup) {
+value_idx idx = thrust::get<0>(tup);
+bool self_loop = idx % m == idx / m;
+return (self_loop * std::numeric_limits<value_t>::max()) +
+(!self_loop * thrust::get<1>(tup));
+});
}

/**
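The lambda identifies diagonal entries of the dense row-major m-by-m distance matrix arithmetically: for flat index idx, idx / m is the row and idx % m the column, so the predicate holds exactly on self-loops, and the branch-free blend keeps either the max-distance sentinel or the computed value. A host-side sketch of the same masking, plain C++ in place of thrust, with illustrative names:

#include <cstddef>
#include <limits>
#include <vector>

// Overwrite diagonal (self-loop) entries of a dense row-major m x m distance
// matrix with the largest representable distance, mirroring the
// thrust::transform above.
template <typename T>
void mask_self_loops(std::vector<T>& data, std::size_t m) {
  for (std::size_t idx = 0; idx < m * m; ++idx) {
    bool self_loop = (idx % m) == (idx / m);  // column index == row index
    if (self_loop) data[idx] = std::numeric_limits<T>::max();
  }
}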
2 changes: 1 addition & 1 deletion cpp/src/umap/runner.cuh
@@ -465,7 +465,7 @@ void _transform(const raft::handle_t &handle, const umap_inputs &inputs,

SimplSetEmbedImpl::optimize_layout<TPB_X, value_t>(
transformed, inputs.n, embedding, embedding_n, comp_coo.rows(),
-comp_coo.cols(), comp_coo.nnz, epochs_per_sample.data(), inputs.n,
+comp_coo.cols(), comp_coo.nnz, epochs_per_sample.data(),
params->repulsion_strength, params, n_epochs, d_alloc, stream);
ML::POP_RANGE();
