diff --git a/CHANGELOG.md b/CHANGELOG.md index 562db4aafb..3cbd400320 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## New Features - PR #766: Expose score method based on inertia for KMeans +- PR #635: Random Forest & Decision Tree Regression (Single-GPU) ## Improvements diff --git a/cpp/src/decisiontree/algo_helper.h b/cpp/src/decisiontree/algo_helper.h index b6121772d8..9f805eb0ff 100644 --- a/cpp/src/decisiontree/algo_helper.h +++ b/cpp/src/decisiontree/algo_helper.h @@ -23,4 +23,13 @@ enum SPLIT_ALGO { GLOBAL_QUANTILE, SPLIT_ALGO_END, }; + +enum CRITERION { + GINI, + ENTROPY, + MSE, + MAE, + CRITERION_END, }; + +}; // namespace ML diff --git a/cpp/src/decisiontree/decisiontree.cu b/cpp/src/decisiontree/decisiontree.cu index 162d625ba0..f6857c6e2d 100644 --- a/cpp/src/decisiontree/decisiontree.cu +++ b/cpp/src/decisiontree/decisiontree.cu @@ -16,47 +16,50 @@ #include #include "decisiontree.h" -#include "kernels/gini.cuh" -#include "kernels/split_labels.cuh" #include "kernels/col_condenser.cuh" -#include "kernels/evaluate.cuh" +#include "kernels/evaluate_classifier.cuh" +#include "kernels/evaluate_regressor.cuh" +#include "kernels/metric.cuh" #include "kernels/quantile.cuh" -#include "algo_helper.h" -#include "kernels/gini_def.h" -#include -#include "memory.h" -#include "memory.cu" -#include -#include -#include -#include -#include -#include +#include "kernels/split_labels.cuh" +#include "memory.cuh" namespace ML { -namespace DecisionTree { -template -void Question::update(const GiniQuestion & ques) { - column = ques.original_column; - value = ques.value; +bool is_dev_ptr(const void *p) { + cudaPointerAttributes pointer_attr; + cudaError_t err = cudaPointerGetAttributes(&pointer_attr, p); + if (err == cudaSuccess) { + return pointer_attr.devicePointer; + } else { + err = cudaGetLastError(); + return false; + } } -template -void TreeNode::print(std::ostream& os) const { +namespace DecisionTree { - if (left == nullptr && right == nullptr) - os << 
"(leaf, " << class_predict << ", " << gini_val << ")" ; - else - os << "(" << question.column << ", " << question.value << ", " << gini_val << ")" ; +template +void Question::update(const MetricQuestion &ques) { + column = ques.original_column; + value = ques.value; +} - return; +template +void TreeNode::print(std::ostream &os) const { + if (left == nullptr && right == nullptr) { + os << "(leaf, " << prediction << ", " << split_metric_val << ")"; + } else { + os << "(" << question.column << ", " << question.value << ", " + << split_metric_val << ")"; + } + return; } -template -std::ostream& operator<<(std::ostream& os, const TreeNode * const node) { - node->print(os); - return os; +template +std::ostream &operator<<(std::ostream &os, const TreeNode *const node) { + node->print(os); + return os; } /** @@ -67,295 +70,520 @@ DecisionTreeParams::DecisionTreeParams() {} /** * @brief Decision tree hyper-parameter object constructor to set all DecisionTreeParams members. */ -DecisionTreeParams::DecisionTreeParams(int cfg_max_depth, int cfg_max_leaves, float cfg_max_features, int cfg_n_bins, int cfg_split_algo, - int cfg_min_rows_per_node, bool cfg_bootstrap_features):max_depth(cfg_max_depth), max_leaves(cfg_max_leaves), - max_features(cfg_max_features), n_bins(cfg_n_bins), split_algo(cfg_split_algo), - min_rows_per_node(cfg_min_rows_per_node), bootstrap_features(cfg_bootstrap_features) {} +DecisionTreeParams::DecisionTreeParams( + int cfg_max_depth, int cfg_max_leaves, float cfg_max_features, int cfg_n_bins, + int cfg_split_algo, int cfg_min_rows_per_node, bool cfg_bootstrap_features, + CRITERION cfg_split_criterion, bool cfg_quantile_per_tree) + : max_depth(cfg_max_depth), + max_leaves(cfg_max_leaves), + max_features(cfg_max_features), + n_bins(cfg_n_bins), + split_algo(cfg_split_algo), + min_rows_per_node(cfg_min_rows_per_node), + bootstrap_features(cfg_bootstrap_features), + split_criterion(cfg_split_criterion), + quantile_per_tree(cfg_quantile_per_tree) {} /** * 
@brief Check validity of all decision tree hyper-parameters. */ void DecisionTreeParams::validity_check() const { - ASSERT((max_depth == -1) || (max_depth > 0), "Invalid max depth %d", max_depth); - ASSERT((max_leaves == -1) || (max_leaves > 0), "Invalid max leaves %d", max_leaves); - ASSERT((max_features > 0) && (max_features <= 1.0), "max_features value %f outside permitted (0, 1] range", max_features); - ASSERT((n_bins > 0), "Invalid n_bins %d", n_bins); - ASSERT((split_algo >= 0) && (split_algo < SPLIT_ALGO::SPLIT_ALGO_END), "split_algo value %d outside permitted [0, %d) range", - split_algo, SPLIT_ALGO::SPLIT_ALGO_END); - ASSERT((min_rows_per_node > 0), "Invalid min # rows per node %d", min_rows_per_node); + ASSERT((max_depth == -1) || (max_depth > 0), "Invalid max depth %d", + max_depth); + ASSERT((max_leaves == -1) || (max_leaves > 0), "Invalid max leaves %d", + max_leaves); + ASSERT((max_features > 0) && (max_features <= 1.0), + "max_features value %f outside permitted (0, 1] range", max_features); + ASSERT((n_bins > 0), "Invalid n_bins %d", n_bins); + ASSERT((split_algo >= 0) && (split_algo < SPLIT_ALGO::SPLIT_ALGO_END), + "split_algo value %d outside permitted [0, %d) range", split_algo, + SPLIT_ALGO::SPLIT_ALGO_END); + ASSERT((min_rows_per_node >= 2), + "Invalid min # rows per node value %d. Should be >= 2.", + min_rows_per_node); } /** * @brief Print all decision tree hyper-parameters. */ void DecisionTreeParams::print() const { - std::cout << "max_depth: " << max_depth << std::endl; - std::cout << "max_leaves: " << max_leaves << std::endl; - std::cout << "max_features: " << max_features << std::endl; - std::cout << "n_bins: " << n_bins << std::endl; - std::cout << "split_algo: " << split_algo << std::endl; - std::cout << "min_rows_per_node: " << min_rows_per_node << std::endl; -} - -/** - * @brief Build (i.e., fit, train) Decision Tree classifier for input data. - * @tparam T: data type for input data (float or double). 
- * @param[in] handle: cumlHandle - * @param[in] data: train data (nrows samples, ncols features) in column major format, excluding labels. Device pointer. - * @param[in] ncols: number of features (i.e., columns) excluding target feature. - * @param[in] nrows: number of training data samples of the whole unsampled dataset. - * @param[in] labels: 1D array of target features (int only). One label per training sample. Device pointer. - Assumption: labels need to be preprocessed to map to ascending numbers from 0; - needed for current gini impl. in decision tree. - * @param[in,out] rowids: array of n_sampled_rows integers in [0, nrows) range. Device pointer. - The same array is then rearranged when splits are made, allowing us to construct trees without rearranging the actual dataset. - * @param[in] n_sampled_rows: number of training samples, after sampling. If using decision tree directly over the whole dataset: n_sampled_rows = nrows - * @param[in] n_unique_labels: #unique label values. Number of categories of classification. - * @param[in] tree_params: Decision Tree training hyper parameter struct. - */ -template -void DecisionTreeClassifier::fit(const ML::cumlHandle& handle, T *data, const int ncols, const int nrows, int *labels, - unsigned int *rowids, const int n_sampled_rows, int unique_labels, DecisionTreeParams tree_params) { - tree_params.validity_check(); - if (tree_params.n_bins > n_sampled_rows) { - std::cout << "Warning! Calling with number of bins > number of rows! "; - std::cout << "Resetting n_bins to " << n_sampled_rows << "." 
<< std::endl; - tree_params.n_bins = n_sampled_rows; - } - return plant(handle.getImpl(), data, ncols, nrows, labels, rowids, n_sampled_rows, unique_labels, tree_params.max_depth, - tree_params.max_leaves, tree_params.max_features, tree_params.n_bins, tree_params.split_algo, tree_params.min_rows_per_node, tree_params.bootstrap_features); -} - -/** - * @brief Predict target feature for input data; n-ary classification for single feature supported. Inference of trees is CPU only for now. - * @tparam T: data type for input data (float or double). - * @param[in] handle: cumlHandle (currently unused; API placeholder) - * @param[in] rows: test data (n_rows samples, n_cols features) in row major format. CPU pointer. - * @param[in] n_rows: number of data samples. - * @param[in] n_cols: number of features (excluding target feature). - * @param[in,out] predictions: n_rows predicted labels. CPU pointer, user allocated. - * @param[in] verbose: flag for debugging purposes. - */ -template -void DecisionTreeClassifier::predict(const ML::cumlHandle& handle, const T * rows, const int n_rows, const int n_cols, int* predictions, bool verbose) const { - ASSERT(root, "Cannot predict w/ empty tree!"); - ASSERT((n_rows > 0), "Invalid n_rows %d", n_rows); - ASSERT((n_cols > 0), "Invalid n_cols %d", n_cols); - classify_all(rows, n_rows, n_cols, predictions, verbose); + std::cout << "max_depth: " << max_depth << std::endl; + std::cout << "max_leaves: " << max_leaves << std::endl; + std::cout << "max_features: " << max_features << std::endl; + std::cout << "n_bins: " << n_bins << std::endl; + std::cout << "split_algo: " << split_algo << std::endl; + std::cout << "min_rows_per_node: " << min_rows_per_node << std::endl; + std::cout << "split_criterion: " << split_criterion << std::endl; } /** * @brief Print high-level tree information. * @tparam T: data type for input data (float or double). + * @tparam L: data type for labels (int type for classification, T type for regression). 
*/ -template -void DecisionTreeClassifier::print_tree_summary() const { - std::cout << " Decision Tree depth --> " << depth_counter << " and n_leaves --> " << leaf_counter << std::endl; - std::cout << " Total temporary memory usage--> "<< ((double)total_temp_mem / (1024*1024)) << " MB" << std::endl; - std::cout << " Tree growing time --> " << construct_time << " seconds" << std::endl; - std::cout << " Shared memory used --> " << shmem_used << " bytes " << std::endl; +template +void DecisionTreeBase::print_tree_summary() const { + std::cout << " Decision Tree depth --> " << depth_counter + << " and n_leaves --> " << leaf_counter << std::endl; + std::cout << " Total temporary memory usage--> " + << ((double)total_temp_mem / (1024 * 1024)) << " MB" << std::endl; + std::cout << " Tree growing time --> " << construct_time << " seconds" + << std::endl; + std::cout << " Shared memory used --> " << shmem_used << " bytes " + << std::endl; } /** * @brief Print detailed tree information. * @tparam T: data type for input data (float or double). + * @tparam L: data type for labels (int type for classification, T type for regression). 
*/ -template -void DecisionTreeClassifier::print() const { - print_tree_summary(); - print_node("", root, false); +template +void DecisionTreeBase::print() const { + print_tree_summary(); + print_node("", root, false); } -template -void DecisionTreeClassifier::plant(const cumlHandle_impl& handle, T *data, const int ncols, const int nrows, int *labels, unsigned int *rowids, const int n_sampled_rows, - int unique_labels, int maxdepth, int max_leaf_nodes, const float colper, int n_bins, int split_algo_flag, int cfg_min_rows_per_node, bool cfg_bootstrap_features) { - - split_algo = split_algo_flag; - dinfo.NLocalrows = nrows; - dinfo.NGlobalrows = nrows; - dinfo.Ncols = ncols; - nbins = n_bins; - treedepth = maxdepth; - maxleaves = max_leaf_nodes; - tempmem.resize(MAXSTREAMS); - n_unique_labels = unique_labels; - min_rows_per_node = cfg_min_rows_per_node; - bootstrap_features = cfg_bootstrap_features; - - //Bootstrap features - feature_selector.resize(dinfo.Ncols); - if (bootstrap_features) { - srand( n_bins ); - for(int i=0; i < dinfo.Ncols; i++) { - feature_selector.push_back( rand() % dinfo.Ncols ); - } - } else { - std::iota(feature_selector.begin(), feature_selector.end(), 0); - } - - std::random_shuffle(feature_selector.begin(),feature_selector.end()); - feature_selector.resize((int) (colper * dinfo.Ncols)); - - cudaDeviceProp prop; - CUDA_CHECK(cudaGetDeviceProperties(&prop, 0)); - max_shared_mem = prop.sharedMemPerBlock; - - if (split_algo == SPLIT_ALGO::HIST) { - shmem_used += 2 * sizeof(T) * ncols; - shmem_used += nbins * n_unique_labels * sizeof(int) * ncols; - } else { - shmem_used += nbins * n_unique_labels * sizeof(int) * ncols; - } - ASSERT(shmem_used <= max_shared_mem, "Shared memory per block limit %zd , requested %zd \n", max_shared_mem, shmem_used); - - for (int i = 0; i < MAXSTREAMS; i++) { - tempmem[i] = std::make_shared>(handle, n_sampled_rows, ncols, MAXSTREAMS, unique_labels, n_bins, split_algo); - if (split_algo == SPLIT_ALGO::GLOBAL_QUANTILE) 
{ - preprocess_quantile(data, rowids, n_sampled_rows, ncols, dinfo.NLocalrows, n_bins, tempmem[i]); - } - } - total_temp_mem = tempmem[0]->totalmem; - total_temp_mem *= MAXSTREAMS; - GiniInfo split_info; - MLCommon::TimerCPU timer; - root = grow_tree(data, colper, labels, 0, rowids, n_sampled_rows, split_info); - construct_time = timer.getElapsedSeconds(); - - for (int i = 0; i < MAXSTREAMS; i++) { - tempmem[i].reset(); - } - - return; +template +void DecisionTreeBase::print_node(const std::string &prefix, + const TreeNode *const node, + bool isLeft) const { + if (node != nullptr) { + std::cout << prefix; + + std::cout << (isLeft ? "├" : "└"); + + // print the value of the node + std::cout << node << std::endl; + + // enter the next tree level - left and right branch + print_node(prefix + (isLeft ? "│ " : " "), node->left, true); + print_node(prefix + (isLeft ? "│ " : " "), node->right, false); + } } -template -TreeNode* DecisionTreeClassifier::grow_tree(T *data, const float colper, int *labels, int depth, unsigned int* rowids, - const int n_sampled_rows, GiniInfo prev_split_info) { - - TreeNode *node = new TreeNode(); - GiniQuestion ques; - Question node_ques; - float gain = 0.0; - GiniInfo split_info[3]; // basis, left, right. 
Populate this - split_info[0] = prev_split_info; - - bool condition = ((depth != 0) && (prev_split_info.best_gini == 0.0f)); // This node is a leaf, no need to search for best split - condition = condition || (n_sampled_rows < min_rows_per_node); // Do not split a node with less than min_rows_per_node samples - - if (!condition) { - find_best_fruit_all(data, labels, colper, ques, gain, rowids, n_sampled_rows, &split_info[0], depth); //ques and gain are output here - condition = condition || (gain == 0.0f); - } - - if (treedepth != -1) - condition = (condition || (depth == treedepth)); - - if (maxleaves != -1) - condition = (condition || (leaf_counter >= maxleaves)); // FIXME not fully respecting maxleaves, but >= constraints it more than == - - if (condition) { - node->class_predict = get_class_hist(split_info[0].hist); - node->gini_val = split_info[0].best_gini; - - leaf_counter++; - if (depth > depth_counter) - depth_counter = depth; - } else { - int nrowsleft, nrowsright; - split_branch(data, ques, n_sampled_rows, nrowsleft, nrowsright, rowids); // populates ques.value - node_ques.update(ques); - node->question = node_ques; - node->left = grow_tree(data, colper, labels, depth+1, &rowids[0], nrowsleft, split_info[1]); - node->right = grow_tree(data, colper, labels, depth+1, &rowids[nrowsleft], nrowsright, split_info[2]); - node->gini_val = split_info[0].best_gini; - } - return node; +template +void DecisionTreeBase::split_branch(T *data, MetricQuestion &ques, + const int n_sampled_rows, + int &nrowsleft, int &nrowsright, + unsigned int *rowids) { + T *temp_data = tempmem[0]->temp_data->data(); + T *sampledcolumn = &temp_data[n_sampled_rows * ques.bootstrapped_column]; + make_split(sampledcolumn, ques, n_sampled_rows, nrowsleft, nrowsright, rowids, + split_algo, tempmem[0]); } +template +void DecisionTreeBase::plant( + const cumlHandle_impl &handle, T *data, const int ncols, const int nrows, + L *labels, unsigned int *rowids, const int n_sampled_rows, int 
unique_labels, + int maxdepth, int max_leaf_nodes, const float colper, int n_bins, + int split_algo_flag, int cfg_min_rows_per_node, bool cfg_bootstrap_features, + CRITERION cfg_split_criterion, bool quantile_per_tree, + std::shared_ptr> in_tempmem) { + split_algo = split_algo_flag; + dinfo.NLocalrows = nrows; + dinfo.NGlobalrows = nrows; + dinfo.Ncols = ncols; + nbins = n_bins; + treedepth = maxdepth; + maxleaves = max_leaf_nodes; + tempmem.resize(MAXSTREAMS); + n_unique_labels = unique_labels; + min_rows_per_node = cfg_min_rows_per_node; + bootstrap_features = cfg_bootstrap_features; + split_criterion = cfg_split_criterion; + + //Bootstrap features + feature_selector.resize(dinfo.Ncols); + if (bootstrap_features) { + srand(n_bins); + for (int i = 0; i < dinfo.Ncols; i++) { + feature_selector.push_back(rand() % dinfo.Ncols); + } + } else { + std::iota(feature_selector.begin(), feature_selector.end(), 0); + } + + std::random_shuffle(feature_selector.begin(), feature_selector.end()); + feature_selector.resize((int)(colper * dinfo.Ncols)); + + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, handle.getDevice())); + max_shared_mem = prop.sharedMemPerBlock; + + if (split_algo == SPLIT_ALGO::HIST) { + shmem_used += 2 * sizeof(T); + } + if (typeid(L) == typeid(int)) { // Classification + shmem_used += nbins * n_unique_labels * sizeof(int); + } else { // Regression + shmem_used += nbins * sizeof(T) * 3; + shmem_used += nbins * sizeof(int); + } + ASSERT(shmem_used <= max_shared_mem, + "Shared memory per block limit %zd , requested %zd \n", max_shared_mem, + shmem_used); + + for (int i = 0; i < MAXSTREAMS; i++) { + if (in_tempmem != nullptr) { + tempmem[i] = in_tempmem; + } else { + tempmem[i] = std::make_shared>( + handle, n_sampled_rows, ncols, MAXSTREAMS, unique_labels, n_bins, + split_algo); + quantile_per_tree = true; + } + if (split_algo == SPLIT_ALGO::GLOBAL_QUANTILE && + quantile_per_tree == true) { + preprocess_quantile(data, rowids, n_sampled_rows, 
ncols, dinfo.NLocalrows, + n_bins, tempmem[i]); + } + } + total_temp_mem = tempmem[0]->totalmem; + total_temp_mem *= MAXSTREAMS; + MetricInfo split_info; + MLCommon::TimerCPU timer; + root = grow_tree(data, colper, labels, 0, rowids, n_sampled_rows, split_info); + construct_time = timer.getElapsedSeconds(); + if (in_tempmem == nullptr) { + for (int i = 0; i < MAXSTREAMS; i++) { + tempmem[i].reset(); + } + } +} -template -void DecisionTreeClassifier::find_best_fruit_all(T *data, int *labels, const float colper, GiniQuestion & ques, float& gain, - unsigned int* rowids, const int n_sampled_rows, GiniInfo split_info[3], int depth) { - std::vector& colselector = feature_selector; - - // Optimize ginibefore; no need to compute except for root. - if (depth == 0) { - CUDA_CHECK(cudaHostRegister(colselector.data(), sizeof(int) * colselector.size(), cudaHostRegisterDefault)); - // Copy sampled column IDs to device memory - CUDA_CHECK(cudaMemcpyAsync(tempmem[0]->d_colids->data(), colselector.data(), sizeof(int) * colselector.size(), cudaMemcpyHostToDevice, tempmem[0]->stream)); - CUDA_CHECK(cudaStreamSynchronize(tempmem[0]->stream)); - - int *labelptr = tempmem[0]->sampledlabels->data(); - get_sampled_labels(labels, labelptr, rowids, n_sampled_rows, tempmem[0]->stream); - gini(labelptr, n_sampled_rows, tempmem[0], split_info[0], n_unique_labels); - //Unregister - CUDA_CHECK(cudaHostUnregister(colselector.data())); - } - - int current_nbins = (n_sampled_rows < nbins) ? 
n_sampled_rows : nbins; - best_split_all_cols(data, rowids, labels, current_nbins, n_sampled_rows, n_unique_labels, dinfo.NLocalrows, colselector, - tempmem[0], &split_info[0], ques, gain, split_algo); +template +TreeNode *DecisionTreeBase::grow_tree( + T *data, const float colper, L *labels, int depth, unsigned int *rowids, + const int n_sampled_rows, MetricInfo prev_split_info) { + TreeNode *node = new TreeNode(); + MetricQuestion ques; + Question node_ques; + float gain = 0.0; + MetricInfo split_info[3]; // basis, left, right. Populate this + split_info[0] = prev_split_info; + + bool condition = + ((depth != 0) && + (prev_split_info.best_metric == + 0.0f)); // This node is a leaf, no need to search for best split + condition = + condition || + (n_sampled_rows < + min_rows_per_node); // Do not split a node with less than min_rows_per_node samples + + if (treedepth != -1) { + condition = (condition || (depth == treedepth)); + } + + if (maxleaves != -1) { + condition = + (condition || + (leaf_counter >= + maxleaves)); // FIXME not fully respecting maxleaves, but >= constraints it more than == + } + + if (!condition) { + find_best_fruit_all(data, labels, colper, ques, gain, rowids, + n_sampled_rows, &split_info[0], + depth); //ques and gain are output here + condition = condition || (gain == 0.0f); + } + + if (condition) { + if (typeid(L) == typeid(int)) { // classification + node->prediction = get_class_hist(split_info[0].hist); + } else { // regression (typeid(L) == typeid(T)) + node->prediction = split_info[0].predict; + } + node->split_metric_val = split_info[0].best_metric; + + leaf_counter++; + if (depth > depth_counter) { + depth_counter = depth; + } + } else { + int nrowsleft, nrowsright; + split_branch(data, ques, n_sampled_rows, nrowsleft, nrowsright, + rowids); // populates ques.value + node_ques.update(ques); + node->question = node_ques; + node->left = grow_tree(data, colper, labels, depth + 1, &rowids[0], + nrowsleft, split_info[1]); + node->right = 
grow_tree(data, colper, labels, depth + 1, &rowids[nrowsleft], + nrowsright, split_info[2]); + node->split_metric_val = split_info[0].best_metric; + } + return node; } -template -void DecisionTreeClassifier::split_branch(T *data, GiniQuestion & ques, const int n_sampled_rows, int& nrowsleft, - int& nrowsright, unsigned int* rowids) { +template +void DecisionTreeBase::init_depth_zero( + const L *labels, std::vector &colselector, + const unsigned int *rowids, const int n_sampled_rows, + const std::shared_ptr> tempmem) { + CUDA_CHECK(cudaHostRegister(colselector.data(), + sizeof(unsigned int) * colselector.size(), + cudaHostRegisterDefault)); + // Copy sampled column IDs to device memory + MLCommon::updateDevice(tempmem->d_colids->data(), colselector.data(), + colselector.size(), tempmem->stream); + CUDA_CHECK(cudaStreamSynchronize(tempmem->stream)); + + L *labelptr = tempmem->sampledlabels->data(); + get_sampled_labels(labels, labelptr, rowids, n_sampled_rows, + tempmem->stream); + + //Unregister + CUDA_CHECK(cudaHostUnregister(colselector.data())); +} - T *temp_data = tempmem[0]->temp_data->data(); - T *sampledcolumn = &temp_data[n_sampled_rows * ques.bootstrapped_column]; - make_split(sampledcolumn, ques, n_sampled_rows, nrowsleft, nrowsright, rowids, split_algo, tempmem[0]); +/** + * @brief Predict target feature for input data; n-ary classification or regression for single feature supported. Inference of trees is CPU only for now. + * @tparam T: data type for input data (float or double). + * @tparam L: data type for labels (int type for classification, T type for regression). + * @param[in] handle: cumlHandle (currently unused; API placeholder) + * @param[in] rows: test data (n_rows samples, n_cols features) in row major format. Current impl. expects a CPU pointer. TODO future API change. + * @param[in] n_rows: number of data samples. + * @param[in] n_cols: number of features (excluding target feature). + * @param[in,out] predictions: n_rows predicted labels. 
Current impl. expects a CPU pointer, user allocated. TODO future API change. + * @param[in] verbose: flag for debugging purposes. + */ +template +void DecisionTreeBase::predict(const ML::cumlHandle &handle, + const T *rows, const int n_rows, + const int n_cols, L *predictions, + bool verbose) const { + ASSERT(!is_dev_ptr(rows) && !is_dev_ptr(predictions), + "DT Error: Current impl. expects both input and predictions to be CPU " + "pointers.\n"); + + ASSERT(root, "Cannot predict w/ empty tree!"); + ASSERT((n_rows > 0), "Invalid n_rows %d", n_rows); + ASSERT((n_cols > 0), "Invalid n_cols %d", n_cols); + + predict_all(rows, n_rows, n_cols, predictions, verbose); } -template -void DecisionTreeClassifier::classify_all(const T * rows, const int n_rows, const int n_cols, int* preds, bool verbose) const { - for (int row_id = 0; row_id < n_rows; row_id++) { - preds[row_id] = classify(&rows[row_id * n_cols], root, verbose); - } - return; +template +void DecisionTreeBase::predict_all(const T *rows, const int n_rows, + const int n_cols, L *preds, + bool verbose) const { + for (int row_id = 0; row_id < n_rows; row_id++) { + preds[row_id] = predict_one(&rows[row_id * n_cols], root, verbose); + } } -template -int DecisionTreeClassifier::classify(const T * row, const TreeNode* const node, bool verbose) const { - - Question q = node->question; - if (node->left && (row[q.column] <= q.value)) { - if (verbose) - std::cout << "Classifying Left @ node w/ column " << q.column << " and value " << q.value << std::endl; - return classify(row, node->left, verbose); - } else if (node->right && (row[q.column] > q.value)) { - if (verbose) - std::cout << "Classifying Right @ node w/ column " << q.column << " and value " << q.value << std::endl; - return classify(row, node->right, verbose); - } else { - if (verbose) - std::cout << "Leaf node. 
Predicting " << node->class_predict << std::endl; - return node->class_predict; - } +template +L DecisionTreeBase::predict_one(const T *row, + const TreeNode *const node, + bool verbose) const { + Question q = node->question; + if (node->left && (row[q.column] <= q.value)) { + if (verbose) { + std::cout << "Classifying Left @ node w/ column " << q.column + << " and value " << q.value << std::endl; + } + return predict_one(row, node->left, verbose); + } else if (node->right && (row[q.column] > q.value)) { + if (verbose) { + std::cout << "Classifying Right @ node w/ column " << q.column + << " and value " << q.value << std::endl; + } + return predict_one(row, node->right, verbose); + } else { + if (verbose) { + std::cout << "Leaf node. Predicting " << node->prediction << std::endl; + } + return node->prediction; + } } -template -void DecisionTreeClassifier::print_node(const std::string& prefix, const TreeNode* const node, bool isLeft) const { +template +void DecisionTreeBase::base_fit( + const ML::cumlHandle &handle, T *data, const int ncols, const int nrows, + L *labels, unsigned int *rowids, const int n_sampled_rows, int unique_labels, + DecisionTreeParams &tree_params, bool is_classifier, + std::shared_ptr> in_tempmem) { + const char *CRITERION_NAME[] = {"GINI", "ENTROPY", "MSE", "MAE", "END"}; + CRITERION default_criterion = + (is_classifier) ? CRITERION::GINI : CRITERION::MSE; + CRITERION last_criterion = + (is_classifier) ? CRITERION::ENTROPY : CRITERION::MAE; + + tree_params.validity_check(); + if (tree_params.n_bins > n_sampled_rows) { + std::cout << "Warning! Calling with number of bins > number of rows! "; + std::cout << "Resetting n_bins to " << n_sampled_rows << "." 
<< std::endl; + tree_params.n_bins = n_sampled_rows; + } + + if ( + tree_params.split_criterion == + CRITERION:: + CRITERION_END) { // Set default to GINI (classification) or MSE (regression) + tree_params.split_criterion = default_criterion; + } + ASSERT((tree_params.split_criterion >= default_criterion) && + (tree_params.split_criterion <= last_criterion), + "Unsupported criterion %s\n", + CRITERION_NAME[tree_params.split_criterion]); + + plant(handle.getImpl(), data, ncols, nrows, labels, rowids, n_sampled_rows, + unique_labels, tree_params.max_depth, tree_params.max_leaves, + tree_params.max_features, tree_params.n_bins, tree_params.split_algo, + tree_params.min_rows_per_node, tree_params.bootstrap_features, + tree_params.split_criterion, tree_params.quantile_per_tree, in_tempmem); +} - if (node != nullptr) { - std::cout << prefix; +/** + * @brief Build (i.e., fit, train) Decision Tree classifier for input data. + * @tparam T: data type for input data (float or double). + * @param[in] handle: cumlHandle + * @param[in] data: train data (nrows samples, ncols features) in column major format, excluding labels. Device pointer. + * @param[in] ncols: number of features (i.e., columns) excluding target feature. + * @param[in] nrows: number of training data samples of the whole unsampled dataset. + * @param[in] labels: 1D array of target features (int only). One label per training sample. Device pointer. + Assumption: labels need to be preprocessed to map to ascending numbers from 0; + needed for current gini impl. in decision tree. + * @param[in,out] rowids: array of n_sampled_rows integers in [0, nrows) range. Device pointer. + The same array is then rearranged when splits are made, allowing us to construct trees without rearranging the actual dataset. + * @param[in] n_sampled_rows: number of training samples, after sampling. If using decision tree directly over the whole dataset: n_sampled_rows = nrows + * @param[in] n_unique_labels: #unique label values. 
Number of categories of classification. + * @param[in] tree_params: Decision Tree training hyper parameter struct. + */ +template +void DecisionTreeClassifier::fit( + const ML::cumlHandle &handle, T *data, const int ncols, const int nrows, + int *labels, unsigned int *rowids, const int n_sampled_rows, + int unique_labels, DecisionTreeParams tree_params, + std::shared_ptr> in_tempmem) { + this->base_fit(handle, data, ncols, nrows, labels, rowids, n_sampled_rows, + unique_labels, tree_params, true, in_tempmem); +} - std::cout << (isLeft ? "├" : "└" ); +template +void DecisionTreeClassifier::find_best_fruit_all( + T *data, int *labels, const float colper, MetricQuestion &ques, + float &gain, unsigned int *rowids, const int n_sampled_rows, + MetricInfo split_info[3], int depth) { + std::vector &colselector = this->feature_selector; + + // Optimize ginibefore; no need to compute except for root. + if (depth == 0) { + this->init_depth_zero(labels, colselector, rowids, n_sampled_rows, + this->tempmem[0]); + int *labelptr = this->tempmem[0]->sampledlabels->data(); + if (this->split_criterion == CRITERION::GINI) { + gini(labelptr, n_sampled_rows, this->tempmem[0], + split_info[0], this->n_unique_labels); + } else { + gini(labelptr, n_sampled_rows, this->tempmem[0], + split_info[0], this->n_unique_labels); + } + } + + // Do not update bin count for the GLOBAL_QUANTILE split algorithm, as all potential split points were precomputed. + int current_nbins = ((this->split_algo != SPLIT_ALGO::GLOBAL_QUANTILE) && + (n_sampled_rows < this->nbins)) + ? 
n_sampled_rows + : this->nbins; + + if (this->split_criterion == CRITERION::GINI) { + best_split_all_cols_classifier( + data, rowids, labels, current_nbins, n_sampled_rows, + this->n_unique_labels, this->dinfo.NLocalrows, colselector, + this->tempmem[0], &split_info[0], ques, gain, this->split_algo, + this->max_shared_mem); + } else { + best_split_all_cols_classifier( + data, rowids, labels, current_nbins, n_sampled_rows, + this->n_unique_labels, this->dinfo.NLocalrows, colselector, + this->tempmem[0], &split_info[0], ques, gain, this->split_algo, + this->max_shared_mem); + } +} - // print the value of the node - std::cout << node << std::endl; +/** + * @brief Build (i.e., fit, train) Decision Tree regressor for input data. + * @tparam T: data type for input data (float or double). + * @param[in] handle: cumlHandle + * @param[in] data: train data (nrows samples, ncols features) in column major format, excluding labels. Device pointer. + * @param[in] ncols: number of features (i.e., columns) excluding target feature. + * @param[in] nrows: number of training data samples of the whole unsampled dataset. + * @param[in] labels: 1D array of target features (float or double). One label per training sample. Device pointer. + * @param[in,out] rowids: array of n_sampled_rows integers in [0, nrows) range. Device pointer. + The same array is then rearranged when splits are made, allowing us to construct trees without rearranging the actual dataset. + * @param[in] n_sampled_rows: number of training samples, after sampling. If using decision tree directly over the whole dataset: n_sampled_rows = nrows + * @param[in] tree_params: Decision Tree training hyper parameter struct. 
+ */ +template +void DecisionTreeRegressor::fit( + const ML::cumlHandle &handle, T *data, const int ncols, const int nrows, + T *labels, unsigned int *rowids, const int n_sampled_rows, + DecisionTreeParams tree_params, + std::shared_ptr> in_tempmem) { + this->base_fit(handle, data, ncols, nrows, labels, rowids, n_sampled_rows, 1, + tree_params, false, in_tempmem); +} - // enter the next tree level - left and right branch - print_node( prefix + (isLeft ? "│ " : " "), node->left, true); - print_node( prefix + (isLeft ? "│ " : " "), node->right, false); - } +template +void DecisionTreeRegressor::find_best_fruit_all( + T *data, T *labels, const float colper, MetricQuestion &ques, float &gain, + unsigned int *rowids, const int n_sampled_rows, MetricInfo split_info[3], + int depth) { + std::vector &colselector = this->feature_selector; + + if (depth == 0) { + this->init_depth_zero(labels, colselector, rowids, n_sampled_rows, + this->tempmem[0]); + T *labelptr = this->tempmem[0]->sampledlabels->data(); + if (this->split_criterion == CRITERION::MSE) { + mse(labelptr, n_sampled_rows, this->tempmem[0], + split_info[0]); + } else { + mse(labelptr, n_sampled_rows, this->tempmem[0], + split_info[0]); + } + } + + // Do not update bin count for the GLOBAL_QUANTILE split algorithm, as all potential split points were precomputed. + int current_nbins = ((this->split_algo != SPLIT_ALGO::GLOBAL_QUANTILE) && + (n_sampled_rows < this->nbins)) + ? 
n_sampled_rows + : this->nbins; + + if (this->split_criterion == CRITERION::MSE) { + best_split_all_cols_regressor( + data, rowids, labels, current_nbins, n_sampled_rows, + this->dinfo.NLocalrows, colselector, this->tempmem[0], split_info, ques, + gain, this->split_algo, this->max_shared_mem); + } else { + best_split_all_cols_regressor( + data, rowids, labels, current_nbins, n_sampled_rows, + this->dinfo.NLocalrows, colselector, this->tempmem[0], split_info, ques, + gain, this->split_algo, this->max_shared_mem); + } } //Class specializations +template class DecisionTreeBase; +template class DecisionTreeBase; +template class DecisionTreeBase; +template class DecisionTreeBase; + template class DecisionTreeClassifier; template class DecisionTreeClassifier; -} //End namespace DecisionTree +template class DecisionTreeRegressor; +template class DecisionTreeRegressor; + +} //End namespace DecisionTree +// Stateless API functions + +// ----------------------------- Classification ----------------------------------- // /** * @brief Build (i.e., fit, train) Decision Tree classifier for input data. @@ -372,9 +600,13 @@ template class DecisionTreeClassifier; * @param[in] n_unique_labels: #unique label values. Number of categories of classification. 
* @param[in] tree_params: Decision Tree training hyper parameter struct */ -void fit(const ML::cumlHandle& handle, DecisionTree::DecisionTreeClassifier * dt_classifier, float *data, const int ncols, const int nrows, int *labels, - unsigned int *rowids, const int n_sampled_rows, int unique_labels, DecisionTree::DecisionTreeParams tree_params) { - dt_classifier->fit(handle, data, ncols, nrows, labels, rowids, n_sampled_rows, unique_labels, tree_params); +void fit(const ML::cumlHandle &handle, + DecisionTree::DecisionTreeClassifier *dt_classifier, + float *data, const int ncols, const int nrows, int *labels, + unsigned int *rowids, const int n_sampled_rows, int unique_labels, + DecisionTree::DecisionTreeParams tree_params) { + dt_classifier->fit(handle, data, ncols, nrows, labels, rowids, n_sampled_rows, + unique_labels, tree_params); } /** @@ -392,38 +624,131 @@ void fit(const ML::cumlHandle& handle, DecisionTree::DecisionTreeClassifier * dt_classifier, double *data, const int ncols, const int nrows, int *labels, - unsigned int *rowids, const int n_sampled_rows, int unique_labels, DecisionTree::DecisionTreeParams tree_params) { - dt_classifier->fit(handle, data, ncols, nrows, labels, rowids, n_sampled_rows, unique_labels, tree_params); + */ +void fit(const ML::cumlHandle &handle, + DecisionTree::DecisionTreeClassifier *dt_classifier, + double *data, const int ncols, const int nrows, int *labels, + unsigned int *rowids, const int n_sampled_rows, int unique_labels, + DecisionTree::DecisionTreeParams tree_params) { + dt_classifier->fit(handle, data, ncols, nrows, labels, rowids, n_sampled_rows, + unique_labels, tree_params); } /** * @brief Predict target feature for input data; n-ary classification for single feature supported. Inference of trees is CPU only for now. * @param[in] handle: cumlHandle (currently unused; API placeholder) - * @param[in] dt_classifier: Pointer to decision tree object, which holds the trained tree. 
+ * @param[in] dt_classifier: Pointer to decision tree object, which holds the trained tree. * @param[in] rows: test data type float (n_rows samples, n_cols features) in row major format. CPU pointer. * @param[in] n_rows: number of data samples. * @param[in] n_cols: number of features (excluding target feature). * @param[in,out] predictions: n_rows predicted labels. CPU pointer, user allocated. * @param[in] verbose: flag for debugging purposes. */ -void predict(const ML::cumlHandle& handle, const DecisionTree::DecisionTreeClassifier * dt_classifier, const float * rows, const int n_rows, const int n_cols, int* predictions, bool verbose) { - return dt_classifier->predict(handle, rows, n_rows, n_cols, predictions, verbose); +void predict(const ML::cumlHandle &handle, + const DecisionTree::DecisionTreeClassifier *dt_classifier, + const float *rows, const int n_rows, const int n_cols, + int *predictions, bool verbose) { + return dt_classifier->predict(handle, rows, n_rows, n_cols, predictions, + verbose); } - + /** * @brief Predict target feature for input data; n-ary classification for single feature supported. Inference of trees is CPU only for now. * @param[in] handle: cumlHandle (currently unused; API placeholder) - * @param[in] dt_classifier: Pointer to decision tree object, which holds the trained tree. + * @param[in] dt_classifier: Pointer to decision tree object, which holds the trained tree. + * @param[in] rows: test data type double (n_rows samples, n_cols features) in row major format. CPU pointer. + * @param[in] n_rows: number of data samples. + * @param[in] n_cols: number of features (excluding target feature). + * @param[in,out] predictions: n_rows predicted labels. CPU pointer, user allocated. + * @param[in] verbose: flag for debugging purposes. 
+ */ +void predict(const ML::cumlHandle &handle, + const DecisionTree::DecisionTreeClassifier *dt_classifier, + const double *rows, const int n_rows, const int n_cols, + int *predictions, bool verbose) { + return dt_classifier->predict(handle, rows, n_rows, n_cols, predictions, + verbose); +} + +// ----------------------------- Regression ----------------------------------- // + +/** + * @brief Build (i.e., fit, train) Decision Tree regressor for input data. + * @param[in] handle: cumlHandle + * @param[in,out] dt_regressor: Pointer to Decision Tree Regressor object. The object holds the trained tree. + * @param[in] data: train data in float (nrows samples, ncols features) in column major format, excluding labels. Device pointer. + * @param[in] ncols: number of features (i.e., columns) excluding target feature. + * @param[in] nrows: number of training data samples of the whole unsampled dataset. + * @param[in] labels: 1D array of target features (type float). One label per training sample. Device pointer. + * @param[in,out] rowids: This array consists of n_sampled_rows integers in the [0, nrows) range; the same array is then rearranged when splits are made. This allows us to construct trees without rearranging the actual dataset. Device pointer. + * @param[in] n_sampled_rows: number of training samples, after sampling. If using decision tree directly over the whole dataset: n_sampled_rows = nrows. + * @param[in] tree_params: Decision Tree training hyper parameter struct + */ +void fit(const ML::cumlHandle &handle, + DecisionTree::DecisionTreeRegressor *dt_regressor, float *data, + const int ncols, const int nrows, float *labels, unsigned int *rowids, + const int n_sampled_rows, + DecisionTree::DecisionTreeParams tree_params) { + dt_regressor->fit(handle, data, ncols, nrows, labels, rowids, n_sampled_rows, + tree_params); +} + +/** + * @brief Build (i.e., fit, train) Decision Tree regressor for input data. 
+ * @param[in] handle: cumlHandle + * @param[in,out] dt_regressor: Pointer to Decision Tree Regressor object. The object holds the trained tree. + * @param[in] data: train data in double (nrows samples, ncols features) in column major format, excluding labels. Device pointer. + * @param[in] ncols: number of features (i.e., columns) excluding target feature. + * @param[in] nrows: number of training data samples of the whole unsampled dataset. + * @param[in] labels: 1D array of target features (type double). One label per training sample. Device pointer. + * @param[in,out] rowids: array of n_sampled_rows integers in [0, nrows) range. Device pointer. + The same array is then rearranged when splits are made, allowing us to construct trees without rearranging the actual dataset. + * @param[in] n_sampled_rows: number of training samples, after sampling. If using decision tree directly over the whole dataset: n_sampled_rows = nrows. + * @param[in] tree_params: Decision Tree training hyper parameter struct. + */ +void fit(const ML::cumlHandle &handle, + DecisionTree::DecisionTreeRegressor *dt_regressor, + double *data, const int ncols, const int nrows, double *labels, + unsigned int *rowids, const int n_sampled_rows, + DecisionTree::DecisionTreeParams tree_params) { + dt_regressor->fit(handle, data, ncols, nrows, labels, rowids, n_sampled_rows, + tree_params); +} + +/** + * @brief Predict target feature for input data; regression for single feature supported. Inference of trees is CPU only for now. + * @param[in] handle: cumlHandle (currently unused; API placeholder) + * @param[in] dt_regressor: Pointer to decision tree object, which holds the trained tree. + * @param[in] rows: test data type float (n_rows samples, n_cols features) in row major format. CPU pointer. + * @param[in] n_rows: number of data samples. + * @param[in] n_cols: number of features (excluding target feature). + * @param[in,out] predictions: n_rows predicted labels. CPU pointer, user allocated. 
+ * @param[in] verbose: flag for debugging purposes. + */ +void predict(const ML::cumlHandle &handle, + const DecisionTree::DecisionTreeRegressor *dt_regressor, + const float *rows, const int n_rows, const int n_cols, + float *predictions, bool verbose) { + return dt_regressor->predict(handle, rows, n_rows, n_cols, predictions, + verbose); +} + +/** + * @brief Predict target feature for input data; regression for single feature supported. Inference of trees is CPU only for now. + * @param[in] handle: cumlHandle (currently unused; API placeholder) + * @param[in] dt_regressor: Pointer to decision tree object, which holds the trained tree. * @param[in] rows: test data type double (n_rows samples, n_cols features) in row major format. CPU pointer. * @param[in] n_rows: number of data samples. * @param[in] n_cols: number of features (excluding target feature). * @param[in,out] predictions: n_rows predicted labels. CPU pointer, user allocated. * @param[in] verbose: flag for debugging purposes. 
*/ -void predict(const ML::cumlHandle& handle, const DecisionTree::DecisionTreeClassifier * dt_classifier, const double * rows, const int n_rows, const int n_cols, int* predictions, bool verbose) { - return dt_classifier->predict(handle, rows, n_rows, n_cols, predictions, verbose); +void predict(const ML::cumlHandle &handle, + const DecisionTree::DecisionTreeRegressor *dt_regressor, + const double *rows, const int n_rows, const int n_cols, + double *predictions, bool verbose) { + return dt_regressor->predict(handle, rows, n_rows, n_cols, predictions, + verbose); } -} //End namespace ML +} //End namespace ML diff --git a/cpp/src/decisiontree/decisiontree.h b/cpp/src/decisiontree/decisiontree.h index 498a19a342..7bbf0e0bc4 100644 --- a/cpp/src/decisiontree/decisiontree.h +++ b/cpp/src/decisiontree/decisiontree.h @@ -15,153 +15,250 @@ */ #pragma once -#include "algo_helper.h" +#include +#include +#include #include -#include "kernels/gini_def.h" -#include "memory.h" +#include +#include +#include +#include "algo_helper.h" +#include "kernels/metric_def.h" namespace ML { + +bool is_dev_ptr(const void *p); + namespace DecisionTree { -template +template struct Question { - int column; - T value; - void update(const GiniQuestion & ques); + int column; + T value; + void update(const MetricQuestion &ques); }; -template +template struct TreeNode { - TreeNode *left = nullptr; - TreeNode *right = nullptr; - int class_predict; - Question question; - T gini_val; + TreeNode *left = nullptr; + TreeNode *right = nullptr; + L prediction; + Question question; + T split_metric_val; - void print(std::ostream& os) const; + void print(std::ostream &os) const; }; struct DataInfo { - unsigned int NLocalrows; - unsigned int NGlobalrows; - unsigned int Ncols; + unsigned int NLocalrows; + unsigned int NGlobalrows; + unsigned int Ncols; }; - struct DecisionTreeParams { - /** - * Maximum tree depth. Unlimited (e.g., until leaves are pure), if -1. 
- */ - int max_depth = -1; - /** - * Maximum leaf nodes per tree. Soft constraint. Unlimited, if -1. - */ - int max_leaves = -1; - /** - * Ratio of number of features (columns) to consider per node split. - TODO SKL's default is sqrt(n_cols) - */ - float max_features = 1.0; - /** - * Number of bins used by the split algorithm. - */ - int n_bins = 8; - /** - * The split algorithm: HIST or GLOBAL_QUANTILE. - */ - int split_algo = SPLIT_ALGO::HIST; - /** - * The minimum number of samples (rows) needed to split a node. - */ - int min_rows_per_node = 2; - /** - * Wheather to bootstarp columns with or without replacement - */ - bool bootstrap_features = false; - - DecisionTreeParams(); - DecisionTreeParams(int cfg_max_depth, int cfg_max_leaves, float cfg_max_features, int cfg_n_bins, int cfg_split_aglo, int cfg_min_rows_per_node, bool cfg_bootstrap_features); - void validity_check() const; - void print() const; + /** + * Maximum tree depth. Unlimited (e.g., until leaves are pure), if -1. + */ + int max_depth = -1; + /** + * Maximum leaf nodes per tree. Soft constraint. Unlimited, if -1. + */ + int max_leaves = -1; + /** + * Ratio of number of features (columns) to consider per node split. + * TODO SKL's default is sqrt(n_cols) + */ + float max_features = 1.0; + /** + * Number of bins used by the split algorithm. + */ + int n_bins = 8; + /** + * The split algorithm: HIST or GLOBAL_QUANTILE. + */ + int split_algo = SPLIT_ALGO::HIST; + /** + * The minimum number of samples (rows) needed to split a node. + */ + int min_rows_per_node = 2; + /** + * Whether to bootstrap columns with or without replacement. + */ + bool bootstrap_features = false; + /** + * Whether a quantile needs to be computed for individual trees in RF. + * Default: compute quantiles once per RF. Only affects GLOBAL_QUANTILE split_algo. + **/ + bool quantile_per_tree = false; + /** + * Node split criterion. GINI and Entropy for classification, MSE or MAE for regression. 
+ */ + CRITERION split_criterion = CRITERION_END; + + DecisionTreeParams(); + DecisionTreeParams(int cfg_max_depth, int cfg_max_leaves, + float cfg_max_features, int cfg_n_bins, int cfg_split_aglo, + int cfg_min_rows_per_node, bool cfg_bootstrap_features, + CRITERION cfg_split_criterion, bool cfg_quantile_per_tree); + void validity_check() const; + void print() const; }; -template -class DecisionTreeClassifier { - -private: - int split_algo; - TreeNode *root = nullptr; - int nbins; - DataInfo dinfo; - int treedepth; - int depth_counter = 0; - int maxleaves; - int leaf_counter = 0; - std::vector>> tempmem; - size_t total_temp_mem; - const int MAXSTREAMS = 1; - size_t max_shared_mem; - size_t shmem_used = 0; - int n_unique_labels = -1; // number of unique labels in dataset - double construct_time; - int min_rows_per_node; - bool bootstrap_features; - std::vector feature_selector; - -public: - // Expects column major T dataset, integer labels - // data, labels are both device ptr. - // Assumption: labels are all mapped to contiguous numbers starting from 0 during preprocessing. Needed for gini hist impl. - void fit(const ML::cumlHandle& handle, T *data, const int ncols, const int nrows, int *labels, unsigned int *rowids, - const int n_sampled_rows, const int unique_labels, DecisionTreeParams tree_params); - - /* Predict labels for n_rows rows, with n_cols features each, for a given tree. rows in row-major format. */ - void predict(const ML::cumlHandle& handle, const T * rows, const int n_rows, const int n_cols, int* predictions, bool verbose=false) const; - - // Printing utility for high level tree info. - void print_tree_summary() const; - - // Printing utility for debug and looking at nodes and leaves. - void print() const; - -private: - // Same as above fit, but planting is better for a tree then fitting. 
- void plant(const cumlHandle_impl& handle, T *data, const int ncols, const int nrows, int *labels, unsigned int *rowids, const int n_sampled_rows, int unique_labels, - int maxdepth = -1, int max_leaf_nodes = -1, const float colper = 1.0, int n_bins = 8, int split_algo_flag = SPLIT_ALGO::HIST, int cfg_min_rows_per_node=2, bool cfg_bootstrap_features=false); - - TreeNode * grow_tree(T *data, const float colper, int *labels, int depth, unsigned int* rowids, const int n_sampled_rows, GiniInfo prev_split_info); - - /* depth is used to distinguish between root and other tree nodes for computations */ - void find_best_fruit_all(T *data, int *labels, const float colper, GiniQuestion & ques, float& gain, unsigned int* rowids, - const int n_sampled_rows, GiniInfo split_info[3], int depth); - void split_branch(T *data, GiniQuestion & ques, const int n_sampled_rows, int& nrowsleft, int& nrowsright, unsigned int* rowids); - void classify_all(const T * rows, const int n_rows, const int n_cols, int* preds, bool verbose=false) const; - int classify(const T * row, const TreeNode * const node, bool verbose=false) const; - void print_node(const std::string& prefix, const TreeNode* const node, bool isLeft) const; -}; // End DecisionTree Class - -} //End namespace DecisionTree +template +class DecisionTreeBase { + protected: + int split_algo; + TreeNode *root = nullptr; + int nbins; + DataInfo dinfo; + int treedepth; + int depth_counter = 0; + int maxleaves; + int leaf_counter = 0; + std::vector>> tempmem; + size_t total_temp_mem; + const int MAXSTREAMS = 1; + size_t max_shared_mem; + size_t shmem_used = 0; + int n_unique_labels = -1; // number of unique labels in dataset + double construct_time; + int min_rows_per_node; + bool bootstrap_features; + CRITERION split_criterion; + std::vector feature_selector; + + void print_node(const std::string &prefix, const TreeNode *const node, + bool isLeft) const; + void split_branch(T *data, MetricQuestion &ques, const int n_sampled_rows, + int 
&nrowsleft, int &nrowsright, unsigned int *rowids); + + void plant(const cumlHandle_impl &handle, T *data, const int ncols, + const int nrows, L *labels, unsigned int *rowids, + const int n_sampled_rows, int unique_labels, int maxdepth = -1, + int max_leaf_nodes = -1, const float colper = 1.0, int n_bins = 8, + int split_algo_flag = SPLIT_ALGO::HIST, + int cfg_min_rows_per_node = 2, bool cfg_bootstrap_features = false, + CRITERION cfg_split_criterion = CRITERION::CRITERION_END, + bool cfg_quantile_per_tree = false, + std::shared_ptr> in_tempmem = nullptr); + void init_depth_zero(const L *labels, std::vector &colselector, + const unsigned int *rowids, const int n_sampled_rows, + const std::shared_ptr> tempmem); + TreeNode *grow_tree(T *data, const float colper, L *labels, int depth, + unsigned int *rowids, const int n_sampled_rows, + MetricInfo prev_split_info); + virtual void find_best_fruit_all(T *data, L *labels, const float colper, + MetricQuestion &ques, float &gain, + unsigned int *rowids, + const int n_sampled_rows, + MetricInfo split_info[3], int depth) = 0; + void base_fit(const ML::cumlHandle &handle, T *data, const int ncols, + const int nrows, L *labels, unsigned int *rowids, + const int n_sampled_rows, int unique_labels, + DecisionTreeParams &tree_params, bool is_classifier, + std::shared_ptr> in_tempmem); + + public: + // Printing utility for high level tree info. + void print_tree_summary() const; + + // Printing utility for debug and looking at nodes and leaves. + void print() const; + + // Predict labels for n_rows rows, with n_cols features each, for a given tree. rows in row-major format. 
+ void predict(const ML::cumlHandle &handle, const T *rows, const int n_rows, + const int n_cols, L *predictions, bool verbose = false) const; + void predict_all(const T *rows, const int n_rows, const int n_cols, L *preds, + bool verbose = false) const; + L predict_one(const T *row, const TreeNode *const node, + bool verbose = false) const; + +}; // End DecisionTreeBase Class + +template +class DecisionTreeClassifier : public DecisionTreeBase { + public: + // Expects column major T dataset, integer labels + // data, labels are both device ptr. + // Assumption: labels are all mapped to contiguous numbers starting from 0 during preprocessing. Needed for gini hist impl. + void fit(const ML::cumlHandle &handle, T *data, const int ncols, + const int nrows, int *labels, unsigned int *rowids, + const int n_sampled_rows, const int unique_labels, + DecisionTreeParams tree_params, + std::shared_ptr> in_tempmem = nullptr); + + private: + /* depth is used to distinguish between root and other tree nodes for computations */ + void find_best_fruit_all(T *data, int *labels, const float colper, + MetricQuestion &ques, float &gain, + unsigned int *rowids, const int n_sampled_rows, + MetricInfo split_info[3], int depth); +}; // End DecisionTreeClassifier Class + +template +class DecisionTreeRegressor : public DecisionTreeBase { + public: + void fit(const ML::cumlHandle &handle, T *data, const int ncols, + const int nrows, T *labels, unsigned int *rowids, + const int n_sampled_rows, DecisionTreeParams tree_params, + std::shared_ptr> in_tempmem = nullptr); + + private: + /* depth is used to distinguish between root and other tree nodes for computations */ + void find_best_fruit_all(T *data, T *labels, const float colper, + MetricQuestion &ques, float &gain, + unsigned int *rowids, const int n_sampled_rows, + MetricInfo split_info[3], int depth); +}; // End DecisionTreeRegressor Class + +} //End namespace DecisionTree // Stateless API functions -void fit(const ML::cumlHandle& handle, - 
DecisionTree::DecisionTreeClassifier* dt_classifier, - float* data, const int ncols, const int nrows, int* labels, - unsigned int* rowids, const int n_sampled_rows, int unique_labels, + +// ----------------------------- Classification ----------------------------------- // + +void fit(const ML::cumlHandle &handle, + DecisionTree::DecisionTreeClassifier *dt_classifier, + float *data, const int ncols, const int nrows, int *labels, + unsigned int *rowids, const int n_sampled_rows, int unique_labels, + DecisionTree::DecisionTreeParams tree_params); + +void fit(const ML::cumlHandle &handle, + DecisionTree::DecisionTreeClassifier *dt_classifier, + double *data, const int ncols, const int nrows, int *labels, + unsigned int *rowids, const int n_sampled_rows, int unique_labels, + DecisionTree::DecisionTreeParams tree_params); + +void predict(const ML::cumlHandle &handle, + const DecisionTree::DecisionTreeClassifier *dt_classifier, + const float *rows, const int n_rows, const int n_cols, + int *predictions, bool verbose = false); +void predict(const ML::cumlHandle &handle, + const DecisionTree::DecisionTreeClassifier *dt_classifier, + const double *rows, const int n_rows, const int n_cols, + int *predictions, bool verbose = false); + +// ----------------------------- Regression ----------------------------------- // + +void fit(const ML::cumlHandle &handle, + DecisionTree::DecisionTreeRegressor *dt_regressor, float *data, + const int ncols, const int nrows, float *labels, unsigned int *rowids, + const int n_sampled_rows, DecisionTree::DecisionTreeParams tree_params); -void fit(const ML::cumlHandle& handle, - DecisionTree::DecisionTreeClassifier* dt_classifier, - double* data, const int ncols, const int nrows, int* labels, - unsigned int* rowids, const int n_sampled_rows, int unique_labels, +void fit(const ML::cumlHandle &handle, + DecisionTree::DecisionTreeRegressor *dt_regressor, + double *data, const int ncols, const int nrows, double *labels, + unsigned int *rowids, const 
int n_sampled_rows, DecisionTree::DecisionTreeParams tree_params); -void predict(const ML::cumlHandle& handle, - const DecisionTree::DecisionTreeClassifier* dt_classifier, - const float* rows, const int n_rows, const int n_cols, - int* predictions, bool verbose = false); -void predict(const ML::cumlHandle& handle, - const DecisionTree::DecisionTreeClassifier* dt_classifier, - const double* rows, const int n_rows, const int n_cols, - int* predictions, bool verbose = false); +void predict(const ML::cumlHandle &handle, + const DecisionTree::DecisionTreeRegressor *dt_regressor, + const float *rows, const int n_rows, const int n_cols, + float *predictions, bool verbose = false); +void predict(const ML::cumlHandle &handle, + const DecisionTree::DecisionTreeRegressor *dt_regressor, + const double *rows, const int n_rows, const int n_cols, + double *predictions, bool verbose = false); } //End namespace ML diff --git a/cpp/src/decisiontree/kernels/batch_cal.cuh b/cpp/src/decisiontree/kernels/batch_cal.cuh new file mode 100644 index 0000000000..10681d1033 --- /dev/null +++ b/cpp/src/decisiontree/kernels/batch_cal.cuh @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2019, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +/* Return max. possible number of columns that can be processed within avail_shared_memory. + Expects that requested_shared_memory is a multiple of ncols. 
*/ +int get_batch_cols_cnt(const size_t avail_shared_memory, + const size_t requested_shared_memory, const int ncols) { + int ncols_in_batch = ncols; + int ncols_factor = requested_shared_memory / ncols; + if (requested_shared_memory > avail_shared_memory) { + ncols_in_batch = avail_shared_memory / ncols_factor; // floor div. + } + return ncols_in_batch; +} + +/* Update batch_ncols (max. possible number of columns that can be processed within avail_shared_memory), + blocks (for next kernel launch), and shmemsize (requested shared memory for next kernel launch). + Precondition: requested_shared_memory is a multiple of ncols. */ +void update_kernel_config(const size_t avail_shared_memory, + const size_t requested_shared_memory, const int ncols, + const int nrows, const int threads, int& batch_ncols, + int& blocks, size_t& shmemsize) { + batch_ncols = + get_batch_cols_cnt(avail_shared_memory, requested_shared_memory, ncols); + shmemsize = + (requested_shared_memory / ncols) * + batch_ncols; // requested_shared_memory is a multiple of ncols for all kernels + blocks = min(MLCommon::ceildiv(batch_ncols * nrows, threads), 65536); +} diff --git a/cpp/src/decisiontree/kernels/col_condenser.cuh b/cpp/src/decisiontree/kernels/col_condenser.cuh index 158fce94fd..891d28617f 100644 --- a/cpp/src/decisiontree/kernels/col_condenser.cuh +++ b/cpp/src/decisiontree/kernels/col_condenser.cuh @@ -29,10 +29,12 @@ __global__ void get_sampled_column_kernel( return; } -void get_sampled_labels(const int* labels, int* outlabels, unsigned int* rowids, - const int n_sampled_rows, const cudaStream_t stream) { +template +void get_sampled_labels(const T* labels, T* outlabels, + const unsigned int* rowids, const int n_sampled_rows, + const cudaStream_t stream) { int threads = 128; - get_sampled_column_kernel + get_sampled_column_kernel <<>>( labels, outlabels, rowids, n_sampled_rows); CUDA_CHECK(cudaGetLastError()); @@ -42,7 +44,7 @@ void get_sampled_labels(const int* labels, int* outlabels, 
unsigned int* rowids, template __global__ void allcolsampler_kernel(const T* __restrict__ data, const unsigned int* __restrict__ rowids, - const int* __restrict__ colids, + const unsigned int* __restrict__ colids, const int nrows, const int ncols, const int rowoffset, T* sampledcols) { int tid = threadIdx.x + blockIdx.x * blockDim.x; @@ -50,58 +52,19 @@ __global__ void allcolsampler_kernel(const T* __restrict__ data, for (unsigned int i = tid; i < nrows * ncols; i += blockDim.x * gridDim.x) { int newcolid = (int)(i / nrows); int myrowstart; - if (colids != nullptr) + if (colids != nullptr) { myrowstart = colids[newcolid] * rowoffset; - else + } else { myrowstart = newcolid * rowoffset; + } - int index = rowids[i % nrows] + myrowstart; + int index; + if (rowids != nullptr) { + index = rowids[i % nrows] + myrowstart; + } else { + index = i % nrows + myrowstart; + } sampledcols[i] = data[index]; } return; } - -template -__global__ void allcolsampler_minmax_kernel( - const T* __restrict__ data, const unsigned int* __restrict__ rowids, - const int* __restrict__ colids, const int nrows, const int ncols, - const int rowoffset, T* globalmin, T* globalmax, T* sampledcols, - T init_min_val) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; - extern __shared__ char shmem[]; - T* minshared = (T*)shmem; - T* maxshared = (T*)(shmem + sizeof(T) * ncols); - - for (int i = threadIdx.x; i < ncols; i += blockDim.x) { - minshared[i] = init_min_val; - maxshared[i] = -init_min_val; - } - - // Initialize min max in global memory - if (tid < ncols) { - globalmin[tid] = init_min_val; - globalmax[tid] = -init_min_val; - } - - __syncthreads(); - - for (unsigned int i = tid; i < nrows * ncols; i += blockDim.x * gridDim.x) { - int newcolid = (int)(i / nrows); - int myrowstart = colids[newcolid] * rowoffset; - int index = rowids[i % nrows] + myrowstart; - T coldata = data[index]; - - MLCommon::myAtomicMin(&minshared[newcolid], coldata); - MLCommon::myAtomicMax(&maxshared[newcolid], coldata); 
- sampledcols[i] = coldata; - } - - __syncthreads(); - - for (int j = threadIdx.x; j < ncols; j += blockDim.x) { - MLCommon::myAtomicMin(&globalmin[j], minshared[j]); - MLCommon::myAtomicMax(&globalmax[j], maxshared[j]); - } - - return; -} diff --git a/cpp/src/decisiontree/kernels/evaluate.cuh b/cpp/src/decisiontree/kernels/evaluate.cuh deleted file mode 100644 index 575066e47a..0000000000 --- a/cpp/src/decisiontree/kernels/evaluate.cuh +++ /dev/null @@ -1,285 +0,0 @@ -/* - * Copyright (c) 2019, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once -#include -#include -#include -#include "../algo_helper.h" -#include "../memory.h" -#include "col_condenser.cuh" -#include "gini.cuh" - -/* - The output of the function is a histogram array, of size ncols * nbins * n_unique_lables - column order is as per colids (bootstrapped random cols) for each col there are nbins histograms - */ -template -__global__ void all_cols_histograms_kernel( - const T* __restrict__ data, const int* __restrict__ labels, - const unsigned int* __restrict__ rowids, const int* __restrict__ colids, - const int nbins, const int nrows, const int ncols, const int rowoffset, - const int n_unique_labels, const T* __restrict__ globalminmax, int* histout) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; - extern __shared__ char shmem[]; - T* minmaxshared = (T*)shmem; - int* shmemhist = (int*)(shmem + 2 * ncols * sizeof(T)); - - for (int i = threadIdx.x; i < 2 * ncols; i += blockDim.x) { - minmaxshared[i] = globalminmax[i]; - } - - for (int i = threadIdx.x; i < n_unique_labels * nbins * ncols; - i += blockDim.x) { - shmemhist[i] = 0; - } - - __syncthreads(); - - for (unsigned int i = tid; i < nrows * ncols; i += blockDim.x * gridDim.x) { - int mycolid = (int)(i / nrows); - int coloffset = mycolid * n_unique_labels * nbins; - - // nbins is # batched bins. Use (batched bins + 1) for delta computation. 
- T delta = (minmaxshared[mycolid + ncols] - minmaxshared[mycolid]) / (nbins); - T base_quesval = minmaxshared[mycolid] + delta; - - T localdata = data[i]; - int label = labels[rowids[i % nrows]]; - for (int j = 0; j < nbins; j++) { - T quesval = base_quesval + j * delta; - - if (localdata <= quesval) { - atomicAdd(&shmemhist[label + n_unique_labels * j + coloffset], 1); - } - } - } - - __syncthreads(); - - for (int i = threadIdx.x; i < ncols * n_unique_labels * nbins; - i += blockDim.x) { - atomicAdd(&histout[i], shmemhist[i]); - } -} - -template -__global__ void all_cols_histograms_global_quantile_kernel( - const T* __restrict__ data, const int* __restrict__ labels, - const unsigned int* __restrict__ rowids, const int* __restrict__ colids, - const int nbins, const int nrows, const int ncols, const int rowoffset, - const int n_unique_labels, int* histout, const T* __restrict__ quantile) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; - extern __shared__ char shmem[]; - int* shmemhist = (int*)(shmem); - - for (int i = threadIdx.x; i < n_unique_labels * nbins * ncols; - i += blockDim.x) { - shmemhist[i] = 0; - } - - __syncthreads(); - - for (unsigned int i = tid; i < nrows * ncols; i += blockDim.x * gridDim.x) { - int mycolid = (int)(i / nrows); - int coloffset = mycolid * n_unique_labels * nbins; - - // nbins is # batched bins. 
- T localdata = data[i]; - int label = labels[rowids[i % nrows]]; - for (int j = 0; j < nbins; j++) { - int quantile_index = colids[mycolid] * nbins + j; - T quesval = quantile[quantile_index]; - if (localdata <= quesval) { - atomicAdd(&shmemhist[label + n_unique_labels * j + coloffset], 1); - } - } - } - - __syncthreads(); - - for (int i = threadIdx.x; i < ncols * n_unique_labels * nbins; - i += blockDim.x) { - atomicAdd(&histout[i], shmemhist[i]); - } -} - -template -void find_best_split(const std::shared_ptr> tempmem, - const int nbins, const int n_unique_labels, - const std::vector& col_selector, - GiniInfo split_info[3], const int nrows, - GiniQuestion& ques, float& gain, const int split_algo) { - gain = 0.0f; - int best_col_id = -1; - int best_bin_id = -1; - - int n_cols = col_selector.size(); - for (int col_id = 0; col_id < n_cols; col_id++) { - int col_hist_base_index = col_id * nbins * n_unique_labels; - // tempmem->h_histout holds n_cols histograms of nbins of n_unique_labels each. - for (int i = 0; i < nbins; i++) { - // if tmp_lnrows or tmp_rnrows is 0, the corresponding gini will be 1 but that doesn't - // matter as it won't count in the info_gain computation. - float tmp_gini_left = 1.0f; - float tmp_gini_right = 1.0f; - int tmp_lnrows = 0; - - //separate loop for now to avoid overflow. - for (int j = 0; j < n_unique_labels; j++) { - int hist_index = i * n_unique_labels + j; - tmp_lnrows += - tempmem->h_histout->data()[col_hist_base_index + hist_index]; - } - int tmp_rnrows = nrows - tmp_lnrows; - - if (tmp_lnrows == 0 || tmp_rnrows == 0) continue; - - // Compute gini right and gini left value for each bin. 
- for (int j = 0; j < n_unique_labels; j++) { - int hist_index = i * n_unique_labels + j; - - float prob_left = - (float)(tempmem->h_histout - ->data()[col_hist_base_index + hist_index]) / - tmp_lnrows; - tmp_gini_left -= prob_left * prob_left; - - float prob_right = - (float)(split_info[0].hist[j] - - tempmem->h_histout - ->data()[col_hist_base_index + hist_index]) / - tmp_rnrows; - tmp_gini_right -= prob_right * prob_right; - } - - ASSERT((tmp_gini_left >= 0.0f) && (tmp_gini_left <= 1.0f), - "gini left value %f not in [0.0, 1.0]", tmp_gini_left); - ASSERT((tmp_gini_right >= 0.0f) && (tmp_gini_right <= 1.0f), - "gini right value %f not in [0.0, 1.0]", tmp_gini_right); - - float impurity = (tmp_lnrows * 1.0f / nrows) * tmp_gini_left + - (tmp_rnrows * 1.0f / nrows) * tmp_gini_right; - float info_gain = split_info[0].best_gini - impurity; - - // Compute best information col_gain so far - if (info_gain > gain) { - gain = info_gain; - best_bin_id = i; - best_col_id = col_id; - split_info[1].best_gini = tmp_gini_left; - split_info[2].best_gini = tmp_gini_right; - } - } - } - - if (best_col_id == -1 || best_bin_id == -1) return; - - split_info[1].hist.resize(n_unique_labels); - split_info[2].hist.resize(n_unique_labels); - for (int j = 0; j < n_unique_labels; j++) { - split_info[1].hist[j] = - tempmem->h_histout->data()[best_col_id * n_unique_labels * nbins + - best_bin_id * n_unique_labels + j]; - split_info[2].hist[j] = split_info[0].hist[j] - split_info[1].hist[j]; - } - - if (split_algo == ML::SPLIT_ALGO::HIST) { - ques.set_question_fields( - best_col_id, col_selector[best_col_id], best_bin_id, nbins, n_cols, - std::numeric_limits::max(), -std::numeric_limits::max(), (T)0); - } else if (split_algo == ML::SPLIT_ALGO::GLOBAL_QUANTILE) { - T ques_val; - T* d_quantile = tempmem->d_quantile->data(); - int q_index = col_selector[best_col_id] * nbins + best_bin_id; - CUDA_CHECK(cudaMemcpyAsync(&ques_val, &d_quantile[q_index], sizeof(T), - cudaMemcpyDeviceToHost, 
tempmem->stream)); - CUDA_CHECK(cudaStreamSynchronize(tempmem->stream)); - ques.set_question_fields( - best_col_id, col_selector[best_col_id], best_bin_id, nbins, n_cols, - std::numeric_limits::max(), -std::numeric_limits::max(), ques_val); - } - return; -} - -template -void best_split_all_cols(const T* data, const unsigned int* rowids, - const int* labels, const int nbins, const int nrows, - const int n_unique_labels, const int rowoffset, - const std::vector& colselector, - const std::shared_ptr> tempmem, - GiniInfo split_info[3], GiniQuestion& ques, - float& gain, const int split_algo) { - int* d_colids = tempmem->d_colids->data(); - T* d_globalminmax = tempmem->d_globalminmax->data(); - int* d_histout = tempmem->d_histout->data(); - int* h_histout = tempmem->h_histout->data(); - - int ncols = colselector.size(); - int col_minmax_bytes = sizeof(T) * 2 * ncols; - int n_hist_bytes = n_unique_labels * nbins * sizeof(int) * ncols; - - CUDA_CHECK( - cudaMemsetAsync((void*)d_histout, 0, n_hist_bytes, tempmem->stream)); - - int threads = 512; - int blocks = MLCommon::ceildiv(nrows * ncols, threads); - if (blocks > 65536) blocks = 65536; - - /* Kernel allcolsampler_*_kernel: - - populates tempmem->tempdata with the sampled column data, - - and computes min max histograms in tempmem->d_globalminmax *if minmax in name - across all columns. 
- */ - size_t shmemsize = col_minmax_bytes; - if (split_algo == ML::SPLIT_ALGO::HIST) { // Histograms (min, max) - allcolsampler_minmax_kernel<<stream>>>( - data, rowids, d_colids, nrows, ncols, rowoffset, &d_globalminmax[0], - &d_globalminmax[colselector.size()], tempmem->temp_data->data(), - std::numeric_limits::max()); - } else if (split_algo == - ML::SPLIT_ALGO:: - GLOBAL_QUANTILE) { // Global quantiles; just col condenser - allcolsampler_kernel<<stream>>>( - data, rowids, d_colids, nrows, ncols, rowoffset, - tempmem->temp_data->data()); - } - CUDA_CHECK(cudaGetLastError()); - - shmemsize = n_hist_bytes; - - if (split_algo == ML::SPLIT_ALGO::HIST) { - shmemsize += col_minmax_bytes; - all_cols_histograms_kernel<<stream>>>( - tempmem->temp_data->data(), labels, rowids, d_colids, nbins, nrows, ncols, - rowoffset, n_unique_labels, d_globalminmax, d_histout); - } else if (split_algo == ML::SPLIT_ALGO::GLOBAL_QUANTILE) { - all_cols_histograms_global_quantile_kernel<<stream>>>( - tempmem->temp_data->data(), labels, rowids, d_colids, nbins, nrows, ncols, - rowoffset, n_unique_labels, d_histout, tempmem->d_quantile->data()); - } - CUDA_CHECK(cudaGetLastError()); - - CUDA_CHECK(cudaMemcpyAsync(h_histout, d_histout, n_hist_bytes, - cudaMemcpyDeviceToHost, tempmem->stream)); - CUDA_CHECK(cudaStreamSynchronize(tempmem->stream)); - - find_best_split(tempmem, nbins, n_unique_labels, colselector, &split_info[0], - nrows, ques, gain, split_algo); - return; -} diff --git a/cpp/src/decisiontree/kernels/evaluate_classifier.cuh b/cpp/src/decisiontree/kernels/evaluate_classifier.cuh new file mode 100644 index 0000000000..5d63c094e2 --- /dev/null +++ b/cpp/src/decisiontree/kernels/evaluate_classifier.cuh @@ -0,0 +1,314 @@ +/* + * Copyright (c) 2019, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include "../algo_helper.h" +#include "../memory.h" +#include "batch_cal.cuh" +#include "col_condenser.cuh" +#include "metric.cuh" +#include "stats/minmax.h" + +/* + The output of the function is a histogram array, of size ncols * nbins * n_unique_labels + column order is as per colids (bootstrapped random cols) for each col there are nbins histograms + */ +template +__global__ void all_cols_histograms_kernel_class( + const T* __restrict__ data, const int* __restrict__ labels, const int nbins, + const int nrows, const int ncols, const int batch_ncols, + const int n_unique_labels, const T* __restrict__ globalminmax, int* histout) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + extern __shared__ char shmem[]; + + int colstep = (int)(ncols / batch_ncols); + if ((ncols % batch_ncols) != 0) colstep++; + + int batchsz = batch_ncols; + for (int k = 0; k < colstep; k++) { + if (k == (colstep - 1) && ((ncols % batch_ncols) != 0)) { + batchsz = ncols % batch_ncols; + } + + T* minmaxshared = (T*)shmem; + int* shmemhist = (int*)(shmem + 2 * batchsz * sizeof(T)); + + for (int i = threadIdx.x; i < 2 * batchsz; i += blockDim.x) { + (i < batchsz) ? 
(minmaxshared[i] = globalminmax[k * batch_ncols + i]) + : (minmaxshared[i] = + globalminmax[k * batch_ncols + (i - batchsz) + ncols]); + } + + for (int i = threadIdx.x; i < n_unique_labels * nbins * batchsz; + i += blockDim.x) { + shmemhist[i] = 0; + } + + __syncthreads(); + + for (unsigned int i = tid; i < nrows * batchsz; + i += blockDim.x * gridDim.x) { + int mycolid = (int)(i / nrows); + int coloffset = mycolid * n_unique_labels * nbins; + + T delta = + (minmaxshared[mycolid + batchsz] - minmaxshared[mycolid]) / (nbins); + T base_quesval = minmaxshared[mycolid] + delta; + + T localdata = data[i + k * batch_ncols * nrows]; + int label = labels[i % nrows]; + for (int j = 0; j < nbins; j++) { + T quesval = base_quesval + j * delta; + + if (localdata <= quesval) { + atomicAdd(&shmemhist[label + n_unique_labels * j + coloffset], 1); + } + } + } + + __syncthreads(); + for (int i = threadIdx.x; i < batchsz * n_unique_labels * nbins; + i += blockDim.x) { + atomicAdd(&histout[k * batch_ncols * n_unique_labels * nbins + i], + shmemhist[i]); + } + + __syncthreads(); + } +} + +template +__global__ void all_cols_histograms_global_quantile_kernel_class( + const T* __restrict__ data, const int* __restrict__ labels, + const unsigned int* __restrict__ colids, const int nbins, const int nrows, + const int ncols, const int batch_ncols, const int n_unique_labels, + int* histout, const T* __restrict__ quantile) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + extern __shared__ char shmem[]; + int* shmemhist = (int*)(shmem); + + int colstep = (int)(ncols / batch_ncols); + if ((ncols % batch_ncols) != 0) colstep++; + + int batchsz = batch_ncols; + for (int k = 0; k < colstep; k++) { + if (k == (colstep - 1) && ((ncols % batch_ncols) != 0)) { + batchsz = ncols % batch_ncols; + } + + for (int i = threadIdx.x; i < n_unique_labels * nbins * batchsz; + i += blockDim.x) { + shmemhist[i] = 0; + } + + __syncthreads(); + + for (unsigned int i = tid; i < nrows * batchsz; + i += 
blockDim.x * gridDim.x) { + int mycolid = (int)(i / nrows); + int coloffset = mycolid * n_unique_labels * nbins; + + T localdata = data[k * batch_ncols * nrows + i]; + int label = labels[i % nrows]; + for (int j = 0; j < nbins; j++) { + int quantile_index = colids[k * batch_ncols + mycolid] * nbins + j; + T quesval = quantile[quantile_index]; + if (localdata <= quesval) { + atomicAdd(&shmemhist[label + n_unique_labels * j + coloffset], 1); + } + } + } + + __syncthreads(); + + for (int i = threadIdx.x; i < batchsz * n_unique_labels * nbins; + i += blockDim.x) { + atomicAdd(&histout[k * batch_ncols * n_unique_labels * nbins + i], + shmemhist[i]); + } + + __syncthreads(); + } +} + +template +void find_best_split_classifier( + const std::shared_ptr> tempmem, const int nbins, + const int n_unique_labels, const std::vector& col_selector, + MetricInfo split_info[3], const int nrows, MetricQuestion& ques, + float& gain, const int split_algo) { + gain = 0.0f; + int best_col_id = -1; + int best_bin_id = -1; + + int n_cols = col_selector.size(); + for (int col_id = 0; col_id < n_cols; col_id++) { + int col_hist_base_index = col_id * nbins * n_unique_labels; + // tempmem->h_histout holds n_cols histograms of nbins of n_unique_labels each. + for (int i = 0; i < nbins; i++) { + // if tmp_lnrows or tmp_rnrows is 0, the corresponding gini will be 1 but that doesn't + // matter as it won't count in the info_gain computation. + int tmp_lnrows = 0; + + //separate loop for now to avoid overflow. + for (int j = 0; j < n_unique_labels; j++) { + int hist_index = i * n_unique_labels + j; + tmp_lnrows += + tempmem->h_histout->data()[col_hist_base_index + hist_index]; + } + int tmp_rnrows = nrows - tmp_lnrows; + + if (tmp_lnrows == 0 || tmp_rnrows == 0) continue; + + std::vector tmp_histleft(n_unique_labels); + std::vector tmp_histright(n_unique_labels); + + // Compute gini right and gini left value for each bin. 
+ for (int j = 0; j < n_unique_labels; j++) { + int hist_index = i * n_unique_labels + j; + tmp_histleft[j] = + tempmem->h_histout->data()[col_hist_base_index + hist_index]; + tmp_histright[j] = split_info[0].hist[j] - tmp_histleft[j]; + } + + float tmp_gini_left = F::exec(tmp_histleft, tmp_lnrows); + float tmp_gini_right = F::exec(tmp_histright, tmp_rnrows); + + float max_value = F::max_val(n_unique_labels); + ASSERT((tmp_gini_left >= 0.0f) && (tmp_gini_left <= max_value), + "gini left value %f not in [0.0, 1.0]", tmp_gini_left); + ASSERT((tmp_gini_right >= 0.0f) && (tmp_gini_right <= max_value), + "gini right value %f not in [0.0, 1.0]", tmp_gini_right); + + float impurity = (tmp_lnrows * 1.0f / nrows) * tmp_gini_left + + (tmp_rnrows * 1.0f / nrows) * tmp_gini_right; + float info_gain = split_info[0].best_metric - impurity; + + // Compute best information col_gain so far + if (info_gain > gain) { + gain = info_gain; + best_bin_id = i; + best_col_id = col_id; + split_info[1].best_metric = tmp_gini_left; + split_info[2].best_metric = tmp_gini_right; + } + } + } + + if (best_col_id == -1 || best_bin_id == -1) return; + + split_info[1].hist.resize(n_unique_labels); + split_info[2].hist.resize(n_unique_labels); + for (int j = 0; j < n_unique_labels; j++) { + split_info[1].hist[j] = + tempmem->h_histout->data()[best_col_id * n_unique_labels * nbins + + best_bin_id * n_unique_labels + j]; + split_info[2].hist[j] = split_info[0].hist[j] - split_info[1].hist[j]; + } + + if (split_algo == ML::SPLIT_ALGO::HIST) { + ques.set_question_fields( + best_col_id, col_selector[best_col_id], best_bin_id, nbins, n_cols, + std::numeric_limits::max(), -std::numeric_limits::max(), (T)0); + } else if (split_algo == ML::SPLIT_ALGO::GLOBAL_QUANTILE) { + T ques_val; + int q_index = col_selector[best_col_id] * nbins + best_bin_id; + ques_val = tempmem->h_quantile->data()[q_index]; + ques.set_question_fields( + best_col_id, col_selector[best_col_id], best_bin_id, nbins, n_cols, + 
std::numeric_limits::max(), -std::numeric_limits::max(), ques_val); + } + return; +} + +template +void best_split_all_cols_classifier( + const T* data, const unsigned int* rowids, const L* labels, const int nbins, + const int nrows, const int n_unique_labels, const int rowoffset, + const std::vector& colselector, + const std::shared_ptr> tempmem, + MetricInfo split_info[3], MetricQuestion& ques, float& gain, + const int split_algo, const size_t max_shared_mem) { + unsigned int* d_colids = tempmem->d_colids->data(); + T* d_globalminmax = tempmem->d_globalminmax->data(); + int* d_histout = tempmem->d_histout->data(); + int* h_histout = tempmem->h_histout->data(); + + int ncols = colselector.size(); + int col_minmax_bytes = sizeof(T) * 2 * ncols; + int n_hist_elements = n_unique_labels * nbins * ncols; + int n_hist_bytes = n_hist_elements * sizeof(int); + + CUDA_CHECK( + cudaMemsetAsync((void*)d_histout, 0, n_hist_bytes, tempmem->stream)); + + const int threads = 512; + int blocks = min(MLCommon::ceildiv(nrows * ncols, threads), 65536); + + /* Kernel allcolsampler_*_kernel: + - populates tempmem->tempdata with the sampled column data, + - and computes min max histograms in tempmem->d_globalminmax *if minmax in name + across all columns. 
+ */ + size_t shmemsize = col_minmax_bytes; + if (split_algo == ML::SPLIT_ALGO::HIST) { // Histograms (min, max) + MLCommon::Stats::minmax( + data, rowids, d_colids, nrows, ncols, rowoffset, &d_globalminmax[0], + &d_globalminmax[colselector.size()], tempmem->temp_data->data(), + tempmem->stream); + } else if (split_algo == + ML::SPLIT_ALGO:: + GLOBAL_QUANTILE) { // Global quantiles; just col condenser + allcolsampler_kernel<<stream>>>( + data, rowids, d_colids, nrows, ncols, rowoffset, + tempmem->temp_data->data()); + } + CUDA_CHECK(cudaGetLastError()); + L* labelptr = tempmem->sampledlabels->data(); + get_sampled_labels(labels, labelptr, rowids, nrows, tempmem->stream); + + int batch_ncols; + size_t shmem_needed = n_hist_bytes; + if (split_algo == ML::SPLIT_ALGO::HIST) { + shmem_needed += col_minmax_bytes; + } + update_kernel_config(max_shared_mem, shmem_needed, ncols, nrows, threads, + batch_ncols, blocks, shmemsize); + + if (split_algo == ML::SPLIT_ALGO::HIST) { + all_cols_histograms_kernel_class<<stream>>>( + tempmem->temp_data->data(), labelptr, nbins, nrows, ncols, batch_ncols, + n_unique_labels, d_globalminmax, d_histout); + } else if (split_algo == ML::SPLIT_ALGO::GLOBAL_QUANTILE) { + all_cols_histograms_global_quantile_kernel_class<<< + blocks, threads, shmemsize, tempmem->stream>>>( + tempmem->temp_data->data(), labelptr, d_colids, nbins, nrows, ncols, + batch_ncols, n_unique_labels, d_histout, tempmem->d_quantile->data()); + } + CUDA_CHECK(cudaGetLastError()); + + MLCommon::updateHost(h_histout, d_histout, n_hist_elements, tempmem->stream); + CUDA_CHECK(cudaStreamSynchronize(tempmem->stream)); + + find_best_split_classifier(tempmem, nbins, n_unique_labels, + colselector, &split_info[0], nrows, ques, + gain, split_algo); + return; +} diff --git a/cpp/src/decisiontree/kernels/evaluate_regressor.cuh b/cpp/src/decisiontree/kernels/evaluate_regressor.cuh new file mode 100644 index 0000000000..a651f256af --- /dev/null +++ 
b/cpp/src/decisiontree/kernels/evaluate_regressor.cuh @@ -0,0 +1,473 @@ +/* + * Copyright (c) 2019, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include "../algo_helper.h" +#include "../memory.h" +#include "batch_cal.cuh" +#include "col_condenser.cuh" +#include "metric.cuh" +#include "stats/minmax.h" + +template +__global__ void compute_mse_minmax_kernel_reg( + const T* __restrict__ data, const T* __restrict__ labels, const int nbins, + const int nrows, const int ncols, const int batch_ncols, + const T* __restrict__ globalminmax, T* mseout, const T* __restrict__ predout, + const int* __restrict__ countout, const T pred_parent) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + extern __shared__ char shmem[]; + + int colstep = (int)(ncols / batch_ncols); + if ((ncols % batch_ncols) != 0) colstep++; + + int batchsz = batch_ncols; + for (int k = 0; k < colstep; k++) { + if (k == (colstep - 1) && ((ncols % batch_ncols) != 0)) { + batchsz = ncols % batch_ncols; + } + + T* minmaxshared = (T*)shmem; + T* shmem_pred = (T*)(shmem + 2 * batchsz * sizeof(T)); + T* shmem_mse = + (T*)(shmem + 2 * batchsz * sizeof(T) + nbins * batchsz * sizeof(T)); + int* shmem_count = + (int*)(shmem + 2 * batchsz * sizeof(T) + 3 * nbins * batchsz * sizeof(T)); + + for (int i = threadIdx.x; i < 2 * batchsz; i += blockDim.x) { + (i < batchsz) ? 
(minmaxshared[i] = globalminmax[k * batch_ncols + i]) + : (minmaxshared[i] = + globalminmax[k * batch_ncols + (i - batchsz) + ncols]); + } + + for (int i = threadIdx.x; i < nbins * batchsz; i += blockDim.x) { + shmem_count[i] = countout[i + k * nbins * batch_ncols]; + shmem_pred[i] = predout[i + k * nbins * batch_ncols]; + shmem_mse[i] = 0.0; + shmem_mse[i + batchsz * nbins] = 0.0; + } + + __syncthreads(); + + for (unsigned int i = tid; i < nrows * batchsz; + i += blockDim.x * gridDim.x) { + int mycolid = (int)(i / nrows); + int coloffset = mycolid * nbins; + + T delta = + (minmaxshared[mycolid + batchsz] - minmaxshared[mycolid]) / (nbins); + T base_quesval = minmaxshared[mycolid] + delta; + + T localdata = data[i + k * batch_ncols * nrows]; + T label = labels[i % nrows]; + for (int j = 0; j < nbins; j++) { + T quesval = base_quesval + j * delta; + + if (localdata <= quesval) { + T temp = shmem_pred[coloffset + j] / shmem_count[coloffset + j]; + temp = label - temp; + atomicAdd(&shmem_mse[j + coloffset], F::exec(temp)); + } else { + T temp = (pred_parent * nrows - shmem_pred[coloffset + j]) / + (nrows - shmem_count[coloffset + j]); + temp = label - temp; + atomicAdd(&shmem_mse[j + coloffset + batchsz * nbins], F::exec(temp)); + } + } + } + + __syncthreads(); + + for (int i = threadIdx.x; i < batchsz * nbins; i += blockDim.x) { + atomicAdd(&mseout[i + k * batch_ncols * nbins], shmem_mse[i]); + atomicAdd(&mseout[i + k * batch_ncols * nbins + ncols * nbins], + shmem_mse[i + batchsz * nbins]); + } + + __syncthreads(); + } +} + +/* + The output of the function is a histogram array, of size ncols * nbins * n_unique_lables + column order is as per colids (bootstrapped random cols) for each col there are nbins histograms + */ +template +__global__ void all_cols_histograms_minmax_kernel_reg( + const T* __restrict__ data, const T* __restrict__ labels, const int nbins, + const int nrows, const int ncols, const int batch_ncols, + const T* __restrict__ globalminmax, T* predout, 
int* countout) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + extern __shared__ char shmem[]; + + int colstep = (int)(ncols / batch_ncols); + if ((ncols % batch_ncols) != 0) colstep++; + + int batchsz = batch_ncols; + for (int k = 0; k < colstep; k++) { + if (k == (colstep - 1) && ((ncols % batch_ncols) != 0)) { + batchsz = ncols % batch_ncols; + } + + T* minmaxshared = (T*)shmem; + T* shmem_pred = (T*)(shmem + 2 * batchsz * sizeof(T)); + int* shmem_count = + (int*)(shmem + 2 * batchsz * sizeof(T) + nbins * batchsz * sizeof(T)); + + for (int i = threadIdx.x; i < 2 * batchsz; i += blockDim.x) { + (i < batchsz) ? (minmaxshared[i] = globalminmax[k * batch_ncols + i]) + : (minmaxshared[i] = + globalminmax[k * batch_ncols + (i - batchsz) + ncols]); + } + + for (int i = threadIdx.x; i < nbins * batchsz; i += blockDim.x) { + shmem_pred[i] = 0; + shmem_count[i] = 0; + } + + __syncthreads(); + + for (unsigned int i = tid; i < nrows * batchsz; + i += blockDim.x * gridDim.x) { + int mycolid = (int)(i / nrows); + int coloffset = mycolid * nbins; + + T delta = + (minmaxshared[mycolid + batchsz] - minmaxshared[mycolid]) / (nbins); + T base_quesval = minmaxshared[mycolid] + delta; + + T localdata = data[i + k * batch_ncols * nrows]; + T label = labels[i % nrows]; + for (int j = 0; j < nbins; j++) { + T quesval = base_quesval + j * delta; + + if (localdata <= quesval) { + atomicAdd(&shmem_count[j + coloffset], 1); + atomicAdd(&shmem_pred[j + coloffset], label); + } + } + } + + __syncthreads(); + + for (int i = threadIdx.x; i < batchsz * nbins; i += blockDim.x) { + atomicAdd(&predout[i + k * batch_ncols * nbins], shmem_pred[i]); + atomicAdd(&countout[i + k * batch_ncols * nbins], shmem_count[i]); + } + + __syncthreads(); + } +} + +template +__global__ void all_cols_histograms_global_quantile_kernel_reg( + const T* __restrict__ data, const T* __restrict__ labels, + const unsigned int* __restrict__ colids, const int nbins, const int nrows, + const int ncols, const int 
batch_ncols, T* predout, int* countout, + const T* __restrict__ quantile) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + extern __shared__ char shmem[]; + + int colstep = (int)(ncols / batch_ncols); + if ((ncols % batch_ncols) != 0) colstep++; + + int batchsz = batch_ncols; + for (int k = 0; k < colstep; k++) { + if (k == (colstep - 1) && ((ncols % batch_ncols) != 0)) { + batchsz = ncols % batch_ncols; + } + + T* shmem_pred = (T*)(shmem); + int* shmem_count = (int*)(shmem + nbins * batchsz * sizeof(T)); + + for (int i = threadIdx.x; i < nbins * batchsz; i += blockDim.x) { + shmem_pred[i] = 0; + shmem_count[i] = 0; + } + + __syncthreads(); + + for (unsigned int i = tid; i < nrows * batchsz; + i += blockDim.x * gridDim.x) { + int mycolid = (int)(i / nrows); + int coloffset = mycolid * nbins; + + T localdata = data[i + k * batch_ncols * nrows]; + T label = labels[i % nrows]; + for (int j = 0; j < nbins; j++) { + int quantile_index = colids[mycolid + k * batch_ncols] * nbins + j; + T quesval = quantile[quantile_index]; + if (localdata <= quesval) { + atomicAdd(&shmem_count[j + coloffset], 1); + atomicAdd(&shmem_pred[j + coloffset], label); + } + } + } + + __syncthreads(); + + for (int i = threadIdx.x; i < batchsz * nbins; i += blockDim.x) { + atomicAdd(&predout[i + k * batch_ncols * nbins], shmem_pred[i]); + atomicAdd(&countout[i + k * batch_ncols * nbins], shmem_count[i]); + } + __syncthreads(); + } +} + +template +__global__ void compute_mse_global_quantile_kernel_reg( + const T* __restrict__ data, const T* __restrict__ labels, + const unsigned int* __restrict__ colids, const int nbins, const int nrows, + const int ncols, const int batch_ncols, T* mseout, + const T* __restrict__ predout, const int* __restrict__ countout, + const T* __restrict__ quantile, const T pred_parent) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + extern __shared__ char shmem[]; + + int colstep = (int)(ncols / batch_ncols); + if ((ncols % batch_ncols) != 0) colstep++; + + int 
batchsz = batch_ncols; + for (int k = 0; k < colstep; k++) { + if (k == (colstep - 1) && ((ncols % batch_ncols) != 0)) { + batchsz = ncols % batch_ncols; + } + + T* shmem_pred = (T*)(shmem); + T* shmem_mse = (T*)(shmem + nbins * batchsz * sizeof(T)); + int* shmem_count = (int*)(shmem + 3 * nbins * batchsz * sizeof(T)); + + for (int i = threadIdx.x; i < nbins * batchsz; i += blockDim.x) { + shmem_count[i] = countout[i + k * nbins * batch_ncols]; + shmem_pred[i] = predout[i + k * nbins * batch_ncols]; + shmem_mse[i] = 0.0; + shmem_mse[i + batchsz * nbins] = 0.0; + } + + __syncthreads(); + + for (unsigned int i = tid; i < nrows * batchsz; + i += blockDim.x * gridDim.x) { + int mycolid = (int)(i / nrows); + int coloffset = mycolid * nbins; + + T localdata = data[i + k * batch_ncols * nrows]; + T label = labels[i % nrows]; + for (int j = 0; j < nbins; j++) { + int quantile_index = colids[mycolid + k * batch_ncols] * nbins + j; + T quesval = quantile[quantile_index]; + + if (localdata <= quesval) { + T temp = shmem_pred[coloffset + j] / shmem_count[coloffset + j]; + temp = label - temp; + atomicAdd(&shmem_mse[j + coloffset], F::exec(temp)); + } else { + T temp = (pred_parent * nrows - shmem_pred[coloffset + j]) / + (nrows - shmem_count[coloffset + j]); + temp = label - temp; + atomicAdd(&shmem_mse[j + coloffset + batchsz * nbins], F::exec(temp)); + } + } + } + + __syncthreads(); + + for (int i = threadIdx.x; i < batchsz * nbins; i += blockDim.x) { + atomicAdd(&mseout[i + k * batch_ncols * nbins], shmem_mse[i]); + atomicAdd(&mseout[i + k * batch_ncols * nbins + ncols * nbins], + shmem_mse[i + batchsz * nbins]); + } + __syncthreads(); + } +} + +template +void find_best_split_regressor( + const std::shared_ptr> tempmem, const int nbins, + const std::vector& col_selector, MetricInfo split_info[3], + const int nrows, MetricQuestion& ques, float& gain, const int split_algo) { + gain = 0.0f; + int best_col_id = -1; + int best_bin_id = -1; + + int n_cols = col_selector.size(); + 
for (int col_id = 0; col_id < n_cols; col_id++) { + int col_count_base_index = col_id * nbins; + // tempmem->h_histout holds n_cols histograms of nbins of n_unique_labels each. + for (int i = 0; i < nbins; i++) { + int tmp_lnrows = tempmem->h_histout->data()[col_count_base_index + i]; + int tmp_rnrows = nrows - tmp_lnrows; + + if (tmp_lnrows == 0 || tmp_rnrows == 0) continue; + + float tmp_pred_left = + tempmem->h_predout->data()[col_count_base_index + i]; + float tmp_pred_right = (nrows * split_info[0].predict) - tmp_pred_left; + tmp_pred_left /= tmp_lnrows; + tmp_pred_right /= tmp_rnrows; + + // Compute MSE right and MSE left value for each bin. + float tmp_mse_left = tempmem->h_mseout->data()[col_count_base_index + i]; + float tmp_mse_right = + tempmem->h_mseout->data()[col_count_base_index + i + n_cols * nbins]; + tmp_mse_left /= tmp_lnrows; + tmp_mse_right /= tmp_rnrows; + + float impurity = (tmp_lnrows * 1.0f / nrows) * tmp_mse_left + + (tmp_rnrows * 1.0f / nrows) * tmp_mse_right; + float info_gain = split_info[0].best_metric - impurity; + + // Compute best information col_gain so far + if (info_gain > gain) { + gain = info_gain; + best_bin_id = i; + best_col_id = col_id; + split_info[1].best_metric = tmp_mse_left; + split_info[2].best_metric = tmp_mse_right; + split_info[1].predict = tmp_pred_left; + split_info[2].predict = tmp_pred_right; + } + } + } + + if (best_col_id == -1 || best_bin_id == -1) return; + + if (split_algo == ML::SPLIT_ALGO::HIST) { + ques.set_question_fields( + best_col_id, col_selector[best_col_id], best_bin_id, nbins, n_cols, + std::numeric_limits::max(), -std::numeric_limits::max(), (T)0); + } else if (split_algo == ML::SPLIT_ALGO::GLOBAL_QUANTILE) { + T ques_val; + int q_index = col_selector[best_col_id] * nbins + best_bin_id; + ques_val = tempmem->h_quantile->data()[q_index]; + ques.set_question_fields( + best_col_id, col_selector[best_col_id], best_bin_id, nbins, n_cols, + std::numeric_limits::max(), -std::numeric_limits::max(), 
ques_val); + } + return; +} + +template +void best_split_all_cols_regressor( + const T* data, const unsigned int* rowids, const T* labels, const int nbins, + const int nrows, const int rowoffset, + const std::vector& colselector, + const std::shared_ptr> tempmem, + MetricInfo split_info[3], MetricQuestion& ques, float& gain, + const int split_algo, const size_t max_shared_mem) { + unsigned int* d_colids = tempmem->d_colids->data(); + T* d_globalminmax = tempmem->d_globalminmax->data(); + int* d_histout = tempmem->d_histout->data(); + int* h_histout = tempmem->h_histout->data(); + T* d_mseout = tempmem->d_mseout->data(); + T* h_mseout = tempmem->h_mseout->data(); + T* d_predout = tempmem->d_predout->data(); + T* h_predout = tempmem->h_predout->data(); + + int ncols = colselector.size(); + int col_minmax_bytes = sizeof(T) * 2 * ncols; + int n_pred_bytes = nbins * sizeof(T) * ncols; + int n_count_bytes = nbins * ncols * sizeof(int); + int n_mse_bytes = 2 * nbins * sizeof(T) * ncols; + + CUDA_CHECK(cudaMemsetAsync((void*)d_mseout, 0, n_mse_bytes, tempmem->stream)); + CUDA_CHECK( + cudaMemsetAsync((void*)d_predout, 0, n_pred_bytes, tempmem->stream)); + CUDA_CHECK( + cudaMemsetAsync((void*)d_histout, 0, n_count_bytes, tempmem->stream)); + + const int threads = 512; + int blocks = MLCommon::ceildiv(nrows * ncols, threads); + if (blocks > 65536) blocks = 65536; + + /* Kernel allcolsampler_*_kernel: + - populates tempmem->tempdata with the sampled column data, + - and computes min max histograms in tempmem->d_globalminmax *if minmax in name + across all columns. 
+ */ + size_t shmemsize = col_minmax_bytes; + if (split_algo == ML::SPLIT_ALGO::HIST) { // Histograms (min, max) + MLCommon::Stats::minmax( + data, rowids, d_colids, nrows, ncols, rowoffset, &d_globalminmax[0], + &d_globalminmax[colselector.size()], tempmem->temp_data->data(), + tempmem->stream); + } else if (split_algo == + ML::SPLIT_ALGO:: + GLOBAL_QUANTILE) { // Global quantiles; just col condenser + allcolsampler_kernel<<stream>>>( + data, rowids, d_colids, nrows, ncols, rowoffset, + tempmem->temp_data->data()); + } + CUDA_CHECK(cudaGetLastError()); + + int batch_ncols; + size_t shmem_needed; + + T* labelptr = tempmem->sampledlabels->data(); + get_sampled_labels(labels, labelptr, rowids, nrows, tempmem->stream); + + if (split_algo == ML::SPLIT_ALGO::HIST) { + shmem_needed = n_pred_bytes + n_count_bytes + col_minmax_bytes; + update_kernel_config(max_shared_mem, shmem_needed, ncols, nrows, threads, + batch_ncols, blocks, shmemsize); + all_cols_histograms_minmax_kernel_reg<<stream>>>( + tempmem->temp_data->data(), labelptr, nbins, nrows, ncols, batch_ncols, + d_globalminmax, d_predout, d_histout); + + shmem_needed += n_mse_bytes; + update_kernel_config(max_shared_mem, shmem_needed, ncols, nrows, threads, + batch_ncols, blocks, shmemsize); + compute_mse_minmax_kernel_reg + <<stream>>>( + tempmem->temp_data->data(), labelptr, nbins, nrows, ncols, batch_ncols, + d_globalminmax, d_mseout, d_predout, d_histout, split_info[0].predict); + } else if (split_algo == ML::SPLIT_ALGO::GLOBAL_QUANTILE) { + shmem_needed = n_pred_bytes + n_count_bytes; + update_kernel_config(max_shared_mem, shmem_needed, ncols, nrows, threads, + batch_ncols, blocks, shmemsize); + all_cols_histograms_global_quantile_kernel_reg<<stream>>>( + tempmem->temp_data->data(), labelptr, d_colids, nbins, nrows, ncols, + batch_ncols, d_predout, d_histout, tempmem->d_quantile->data()); + + shmem_needed += n_mse_bytes; + update_kernel_config(max_shared_mem, shmem_needed, ncols, nrows, threads, + batch_ncols, 
blocks, shmemsize); + + compute_mse_global_quantile_kernel_reg + <<stream>>>( + tempmem->temp_data->data(), labelptr, d_colids, nbins, nrows, ncols, + batch_ncols, d_mseout, d_predout, d_histout, + tempmem->d_quantile->data(), split_info[0].predict); + } + CUDA_CHECK(cudaGetLastError()); + + MLCommon::updateHost(h_mseout, d_mseout, n_mse_bytes / sizeof(T), + tempmem->stream); + MLCommon::updateHost(h_histout, d_histout, n_count_bytes / sizeof(int), + tempmem->stream); + MLCommon::updateHost(h_predout, d_predout, n_pred_bytes / sizeof(T), + tempmem->stream); + CUDA_CHECK(cudaStreamSynchronize(tempmem->stream)); + + find_best_split_regressor(tempmem, nbins, colselector, &split_info[0], nrows, + ques, gain, split_algo); + return; +} diff --git a/cpp/src/decisiontree/kernels/gini.cuh b/cpp/src/decisiontree/kernels/gini.cuh deleted file mode 100644 index 77d5e2dcfa..0000000000 --- a/cpp/src/decisiontree/kernels/gini.cuh +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright (c) 2019, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once -#include -#include -#include "../memory.h" -#include "cub/cub.cuh" -#include "cuda_utils.h" -#include "gini_def.h" - -template -void GiniQuestion::set_question_fields(int cfg_bootcolumn, int cfg_column, - int cfg_batch_id, int cfg_nbins, - int cfg_ncols, T cfg_min, T cfg_max, - T cfg_value) { - bootstrapped_column = cfg_bootcolumn; - original_column = cfg_column; - batch_id = cfg_batch_id; - min = cfg_min; - max = cfg_max; - nbins = cfg_nbins; - ncols = cfg_ncols; - value = cfg_value; // Will be updated in make_split -} - -__global__ void gini_kernel(const int *__restrict__ labels, const int nrows, - const int nmax, int *histout) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; - extern __shared__ unsigned int shmemhist[]; - if (threadIdx.x < nmax) shmemhist[threadIdx.x] = 0; - - __syncthreads(); - - if (tid < nrows) { - int label = labels[tid]; - atomicAdd(&shmemhist[label], 1); - } - - __syncthreads(); - - if (threadIdx.x < nmax) - atomicAdd(&histout[threadIdx.x], shmemhist[threadIdx.x]); - - return; -} - -template -void gini(int *labels_in, const int nrows, - const std::shared_ptr> tempmem, - GiniInfo &split_info, int &unique_labels) { - int *dhist = tempmem->d_hist->data(); - int *hhist = tempmem->h_hist->data(); - float gval = 1.0; - - CUDA_CHECK( - cudaMemsetAsync(dhist, 0, sizeof(int) * unique_labels, tempmem->stream)); - gini_kernel<<stream>>>(labels_in, nrows, unique_labels, dhist); - CUDA_CHECK(cudaGetLastError()); - CUDA_CHECK(cudaMemcpyAsync(hhist, dhist, sizeof(int) * unique_labels, - cudaMemcpyDeviceToHost, tempmem->stream)); - CUDA_CHECK(cudaStreamSynchronize(tempmem->stream)); - - split_info.hist.resize(unique_labels, 0); - for (int i = 0; i < unique_labels; i++) { - split_info.hist[i] = hhist[i]; //update_gini_hist - float prob = ((float)hhist[i]) / nrows; - gval -= prob * prob; - } - - split_info.best_gini = gval; //Update gini val - - return; -} diff --git a/cpp/src/decisiontree/kernels/metric.cuh 
b/cpp/src/decisiontree/kernels/metric.cuh new file mode 100644 index 0000000000..33c6266141 --- /dev/null +++ b/cpp/src/decisiontree/kernels/metric.cuh @@ -0,0 +1,195 @@ +/* + * Copyright (c) 2019, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "cuda_utils.h" +#include "metric_def.h" + +template +void MetricQuestion::set_question_fields(int cfg_bootcolumn, int cfg_column, + int cfg_batch_id, int cfg_nbins, + int cfg_ncols, T cfg_min, T cfg_max, + T cfg_value) { + bootstrapped_column = cfg_bootcolumn; + original_column = cfg_column; + batch_id = cfg_batch_id; + min = cfg_min; + max = cfg_max; + nbins = cfg_nbins; + ncols = cfg_ncols; + value = cfg_value; // Will be updated in make_split +} + +template +__device__ __forceinline__ T SquareFunctor::exec(T x) { + return MLCommon::myPow(x, (T)2); +} + +template +__device__ __forceinline__ T AbsFunctor::exec(T x) { + return MLCommon::myAbs(x); +} + +float GiniFunctor::max_val(int nclass) { return 1.0; } + +float EntropyFunctor::max_val(int nclass) { + float prob = 1.0 / nclass; + return (-1.0 * nclass * prob * logf(prob)); +} +float GiniFunctor::exec(std::vector &hist, int nrows) { + float gval = 1.0; + for (int i = 0; i < hist.size(); i++) { + float prob = ((float)hist[i]) / nrows; + gval -= prob * prob; + } + return gval; +} + +float EntropyFunctor::exec(std::vector &hist, int nrows) { + float eval = 0.0; + for (int i = 0; i < hist.size(); i++) { + if (hist[i] 
!= 0) { + float prob = ((float)hist[i]) / nrows; + eval += prob * logf(prob); + } + } + return (-1 * eval); +} + +__global__ void gini_kernel(const int *__restrict__ labels, const int nrows, + const int nmax, int *histout) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + extern __shared__ unsigned int shmemhist[]; + if (threadIdx.x < nmax) shmemhist[threadIdx.x] = 0; + + __syncthreads(); + + if (tid < nrows) { + int label = labels[tid]; + atomicAdd(&shmemhist[label], 1); + } + + __syncthreads(); + + if (threadIdx.x < nmax) + atomicAdd(&histout[threadIdx.x], shmemhist[threadIdx.x]); + + return; +} + +template +__global__ void pred_kernel(const T *__restrict__ labels, const int nrows, + T *predout) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + __shared__ T shmempred; + + if (threadIdx.x == 0) shmempred = 0; + + __syncthreads(); + + if (tid < nrows) { + T label = labels[tid]; + atomicAdd(&shmempred, label); + } + + __syncthreads(); + + if (threadIdx.x == 0) { + atomicAdd(predout, shmempred); + } + + return; +} + +template +__global__ void mse_kernel(const T *__restrict__ labels, const int nrows, + const T *predout, T *mseout) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + __shared__ T shmemmse; + + if (threadIdx.x == 0) { + shmemmse = 0; + } + + __syncthreads(); + + if (tid < nrows) { + T label = labels[tid] - (predout[0] / nrows); + atomicAdd(&shmemmse, F::exec(label)); + } + + __syncthreads(); + + if (threadIdx.x == 0) { + atomicAdd(mseout, shmemmse); + } + + return; +} + +template +void gini(int *labels_in, const int nrows, + const std::shared_ptr> tempmem, + MetricInfo &split_info, const int unique_labels) { + int *dhist = tempmem->d_hist->data(); + int *hhist = tempmem->h_hist->data(); + + CUDA_CHECK( + cudaMemsetAsync(dhist, 0, sizeof(int) * unique_labels, tempmem->stream)); + gini_kernel<<stream>>>(labels_in, nrows, unique_labels, dhist); + CUDA_CHECK(cudaGetLastError()); + MLCommon::updateHost(hhist, dhist, unique_labels, 
tempmem->stream); + CUDA_CHECK(cudaStreamSynchronize(tempmem->stream)); + + split_info.hist.resize(unique_labels, 0); + for (int i = 0; i < unique_labels; i++) { + split_info.hist[i] = hhist[i]; + } + + split_info.best_metric = F::exec(split_info.hist, nrows); + + return; +} + +template +void mse(T *labels_in, const int nrows, + const std::shared_ptr> tempmem, + MetricInfo &split_info) { + T *dpred = tempmem->d_predout->data(); + T *dmse = tempmem->d_mseout->data(); + T *hmse = tempmem->h_mseout->data(); + T *hpred = tempmem->h_predout->data(); + + CUDA_CHECK(cudaMemsetAsync(dpred, 0, sizeof(T), tempmem->stream)); + CUDA_CHECK(cudaMemsetAsync(dmse, 0, sizeof(T), tempmem->stream)); + + pred_kernel<<stream>>>( + labels_in, nrows, dpred); + CUDA_CHECK(cudaGetLastError()); + mse_kernel<<stream>>>( + labels_in, nrows, dpred, dmse); + CUDA_CHECK(cudaGetLastError()); + + MLCommon::updateHost(hmse, dmse, 1, tempmem->stream); + MLCommon::updateHost(hpred, dpred, 1, tempmem->stream); + CUDA_CHECK(cudaStreamSynchronize(tempmem->stream)); + + split_info.best_metric = + (float)hmse[0] / (float)nrows; //Update split metric value + split_info.predict = hpred[0] / (T)nrows; + return; +} diff --git a/cpp/src/decisiontree/kernels/gini_def.h b/cpp/src/decisiontree/kernels/metric_def.h similarity index 71% rename from cpp/src/decisiontree/kernels/gini_def.h rename to cpp/src/decisiontree/kernels/metric_def.h index ca89ae8f8a..68a8d379b7 100644 --- a/cpp/src/decisiontree/kernels/gini_def.h +++ b/cpp/src/decisiontree/kernels/metric_def.h @@ -15,11 +15,13 @@ */ #pragma once +#include #include #include +#include "../memory.h" template -struct GiniQuestion { +struct MetricQuestion { int bootstrapped_column; int original_column; T value; @@ -45,8 +47,30 @@ struct GiniQuestion { T cfg_value); }; -struct GiniInfo { - float best_gini = -1.0f; +template +struct MetricInfo { + float best_metric = -1.0f; + T predict = 0; std::vector - hist; //Element hist[i] stores # labels with label i for a 
given node. + hist; //Element hist[i] stores # labels with label i for a given node. for classification +}; + +struct SquareFunctor { + template + static __device__ __forceinline__ T exec(T x); +}; + +struct AbsFunctor { + template + static __device__ __forceinline__ T exec(T x); +}; + +struct GiniFunctor { + static float exec(std::vector& hist, int nrows); + static float max_val(int nclass); +}; + +struct EntropyFunctor { + static float exec(std::vector& hist, int nrows); + static float max_val(int nclass); }; diff --git a/cpp/src/decisiontree/kernels/quantile.cuh b/cpp/src/decisiontree/kernels/quantile.cuh index e2b0b255ae..649090bceb 100644 --- a/cpp/src/decisiontree/kernels/quantile.cuh +++ b/cpp/src/decisiontree/kernels/quantile.cuh @@ -17,6 +17,7 @@ #pragma once #include "col_condenser.cuh" #include "cub/cub.cuh" +#include "quantile.h" __global__ void set_sorting_offset(const int nrows, const int ncols, int *offsets) { @@ -39,41 +40,58 @@ __global__ void get_all_quantiles(const T *__restrict__ data, T *quantile, return; } -template +template void preprocess_quantile(const T *data, const unsigned int *rowids, const int n_sampled_rows, const int ncols, const int rowoffset, const int nbins, - std::shared_ptr> tempmem) { + std::shared_ptr> tempmem) { + /* + // Dynamically determine batch_cols (number of columns processed per loop iteration) from the available device memory. + size_t free_mem, total_mem; + CUDA_CHECK(cudaMemGetInfo(&free_mem, &total_mem)); + int max_ncols = free_mem / (2 * n_sampled_rows * sizeof(T)); + int batch_cols = (max_ncols > ncols) ? ncols : max_ncols; + ASSERT(max_ncols != 0, "Cannot preprocess quantiles due to insufficient device memory."); + */ + int batch_cols = + 1; // Processing one column at a time, for now, until an appropriate getMemInfo function is provided for the deviceAllocator interface. 
+ int threads = 128; - int num_items = - n_sampled_rows * - ncols; // number of items to sort across all segments (i.e., cols) - int num_segments = ncols; MLCommon::device_buffer *d_offsets; MLCommon::device_buffer *d_keys_out; T *d_keys_in = tempmem->temp_data->data(); - int *colids = nullptr; + unsigned int *colids = nullptr; d_offsets = new MLCommon::device_buffer( - tempmem->ml_handle.getDeviceAllocator(), tempmem->stream, num_segments + 1); - d_keys_out = new MLCommon::device_buffer( - tempmem->ml_handle.getDeviceAllocator(), tempmem->stream, num_items); + tempmem->ml_handle.getDeviceAllocator(), tempmem->stream, batch_cols + 1); int blocks = MLCommon::ceildiv(ncols * n_sampled_rows, threads); allcolsampler_kernel<<stream>>>( - data, rowids, colids, n_sampled_rows, ncols, rowoffset, d_keys_in); + data, rowids, colids, n_sampled_rows, ncols, rowoffset, + d_keys_in); // d_keys_in already allocated for all ncols CUDA_CHECK(cudaGetLastError()); - blocks = MLCommon::ceildiv(ncols + 1, threads); + + blocks = MLCommon::ceildiv(batch_cols + 1, threads); set_sorting_offset<<stream>>>( - n_sampled_rows, ncols, d_offsets->data()); + n_sampled_rows, batch_cols, d_offsets->data()); CUDA_CHECK(cudaGetLastError()); // Determine temporary device storage requirements MLCommon::device_buffer *d_temp_storage = nullptr; size_t temp_storage_bytes = 0; + + int batch_cnt = + MLCommon::ceildiv(ncols, batch_cols); // number of loop iterations + int last_batch_size = + ncols - batch_cols * (batch_cnt - 1); // number of columns in last batch + int batch_items = + n_sampled_rows * batch_cols; // used to determine d_temp_storage size + + d_keys_out = new MLCommon::device_buffer( + tempmem->ml_handle.getDeviceAllocator(), tempmem->stream, batch_items); CUDA_CHECK(cub::DeviceSegmentedRadixSort::SortKeys( d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out->data(), - num_items, num_segments, d_offsets->data(), d_offsets->data() + 1, 0, + batch_items, batch_cols, d_offsets->data(), 
d_offsets->data() + 1, 0, 8 * sizeof(T), tempmem->stream)); // Allocate temporary storage @@ -81,19 +99,32 @@ void preprocess_quantile(const T *data, const unsigned int *rowids, new MLCommon::device_buffer(tempmem->ml_handle.getDeviceAllocator(), tempmem->stream, temp_storage_bytes); - // Run sorting operation - CUDA_CHECK(cub::DeviceSegmentedRadixSort::SortKeys( - (void *)d_temp_storage->data(), temp_storage_bytes, d_keys_in, - d_keys_out->data(), num_items, num_segments, d_offsets->data(), - d_offsets->data() + 1, 0, 8 * sizeof(T), tempmem->stream)); - - blocks = MLCommon::ceildiv(ncols * nbins, threads); - get_all_quantiles<<stream>>>( - d_keys_out->data(), tempmem->d_quantile->data(), n_sampled_rows, ncols, - nbins); - CUDA_CHECK(cudaGetLastError()); - CUDA_CHECK(cudaStreamSynchronize(tempmem->stream)); + // Compute quantiles for cur_batch_cols columns per loop iteration. + for (int batch = 0; batch < batch_cnt; batch++) { + int cur_batch_cols = (batch == batch_cnt - 1) + ? last_batch_size + : batch_cols; // properly handle the last batch + + int batch_offset = batch * n_sampled_rows * batch_cols; + int quantile_offset = batch * nbins * batch_cols; + // Run sorting operation + CUDA_CHECK(cub::DeviceSegmentedRadixSort::SortKeys( + (void *)d_temp_storage->data(), temp_storage_bytes, + &d_keys_in[batch_offset], d_keys_out->data(), n_sampled_rows * batch_cols, + cur_batch_cols, d_offsets->data(), d_offsets->data() + 1, 0, + 8 * sizeof(T), tempmem->stream)); + + blocks = MLCommon::ceildiv(cur_batch_cols * nbins, threads); + get_all_quantiles<<stream>>>( + d_keys_out->data(), &tempmem->d_quantile->data()[quantile_offset], + n_sampled_rows, cur_batch_cols, nbins); + + CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaStreamSynchronize(tempmem->stream)); + } + MLCommon::updateHost(tempmem->h_quantile->data(), tempmem->d_quantile->data(), + nbins * ncols, tempmem->stream); d_keys_out->release(tempmem->stream); d_offsets->release(tempmem->stream); 
d_temp_storage->release(tempmem->stream); diff --git a/cpp/src/decisiontree/kernels/quantile.h b/cpp/src/decisiontree/kernels/quantile.h new file mode 100644 index 0000000000..61cac17b13 --- /dev/null +++ b/cpp/src/decisiontree/kernels/quantile.h @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2019, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "../memory.h" +template +void preprocess_quantile(const T *data, const unsigned int *rowids, + const int n_sampled_rows, const int ncols, + const int rowoffset, const int nbins, + std::shared_ptr> tempmem); diff --git a/cpp/src/decisiontree/kernels/split_labels.cuh b/cpp/src/decisiontree/kernels/split_labels.cuh index fae33c5ab5..702e871dee 100644 --- a/cpp/src/decisiontree/kernels/split_labels.cuh +++ b/cpp/src/decisiontree/kernels/split_labels.cuh @@ -19,7 +19,7 @@ #include #include "../algo_helper.h" #include "cub/cub.cuh" -#include "gini.cuh" +#include "metric.cuh" template __global__ void flag_kernel(T* column, char* leftflag, char* rightflag, @@ -79,12 +79,12 @@ int get_class_hist(std::vector& node_hist) { return classval; } -template -void make_split(T* column, GiniQuestion& ques, const int nrows, +template +void make_split(T* column, MetricQuestion& ques, const int nrows, int& nrowsleft, int& nrowsright, unsigned int* rowids, int split_algo, - const std::shared_ptr> tempmem) { - int* temprowids = tempmem->temprowids->data(); + const std::shared_ptr> tempmem) { + 
unsigned int* temprowids = tempmem->temprowids->data(); char* d_flags_left = tempmem->d_flags_left->data(); char* d_flags_right = tempmem->d_flags_right->data(); T* question_value = tempmem->question_value->data(); @@ -109,26 +109,23 @@ void make_split(T* column, GiniQuestion& ques, const int nrows, cub::DeviceSelect::Flagged(d_temp_storage, temp_storage_bytes, rowids, d_flags_left, temprowids, d_num_selected_out, - nrows); - CUDA_CHECK(cudaMemcpyAsync(&nrowsleftright[0], d_num_selected_out, - sizeof(int), cudaMemcpyDeviceToHost, - tempmem->stream)); + nrows, tempmem->stream); + MLCommon::updateHost(&nrowsleftright[0], d_num_selected_out, 1, + tempmem->stream); CUDA_CHECK(cudaStreamSynchronize(tempmem->stream)); nrowsleft = nrowsleftright[0]; cub::DeviceSelect::Flagged(d_temp_storage, temp_storage_bytes, rowids, d_flags_right, &temprowids[nrowsleft], - d_num_selected_out, nrows); - CUDA_CHECK(cudaMemcpyAsync(&nrowsleftright[1], d_num_selected_out, - sizeof(int), cudaMemcpyDeviceToHost, - tempmem->stream)); - CUDA_CHECK(cudaMemcpyAsync(rowids, temprowids, nrows * sizeof(int), - cudaMemcpyDeviceToDevice, tempmem->stream)); + d_num_selected_out, nrows, tempmem->stream); + MLCommon::updateHost(&nrowsleftright[1], d_num_selected_out, 1, + tempmem->stream); + MLCommon::copyAsync(rowids, temprowids, nrows, tempmem->stream); // Copy GPU-computed question value to tree node. 
- if (split_algo == ML::SPLIT_ALGO::HIST) - CUDA_CHECK(cudaMemcpyAsync(&(ques.value), question_value, sizeof(T), - cudaMemcpyDeviceToHost, tempmem->stream)); + if (split_algo == ML::SPLIT_ALGO::HIST) { + MLCommon::updateHost(&(ques.value), question_value, 1, tempmem->stream); + } CUDA_CHECK(cudaStreamSynchronize(tempmem->stream)); nrowsright = nrowsleftright[1]; diff --git a/cpp/src/decisiontree/memory.cu b/cpp/src/decisiontree/memory.cu deleted file mode 100644 index d4e0dc85ec..0000000000 --- a/cpp/src/decisiontree/memory.cu +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Copyright (c) 2019, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once -#include -#include "cub/cub.cuh" -#include -#include "common/cumlHandle.hpp" -#include -#include -#include "memory.h" -#include - -template -TemporaryMemory::TemporaryMemory(const ML::cumlHandle_impl& handle, int N, int Ncols, int maxstr, int n_unique, int n_bins, const int split_algo):ml_handle(handle) - { - - //Assign Stream from cumlHandle - stream = ml_handle.getStream(); - - int n_hist_elements = n_unique * n_bins; - - h_hist = new MLCommon::host_buffer(handle.getHostAllocator(), stream, n_hist_elements); - d_hist = new MLCommon::device_buffer(handle.getDeviceAllocator(), stream, n_hist_elements); - nrowsleftright = new MLCommon::host_buffer(handle.getHostAllocator(), stream, 2); - int extra_elements = Ncols; - int quantile_elements = (split_algo == ML::SPLIT_ALGO::GLOBAL_QUANTILE) ? extra_elements : 1; - - temp_data = new MLCommon::device_buffer(handle.getDeviceAllocator(), stream, N * extra_elements); - totalmem += n_hist_elements * sizeof(int) + N * extra_elements * sizeof(T); - - if (split_algo == ML::SPLIT_ALGO::GLOBAL_QUANTILE) { - d_quantile = new MLCommon::device_buffer(handle.getDeviceAllocator(), stream, n_bins * quantile_elements); - d_temp_sampledcolumn = new MLCommon::device_buffer(handle.getDeviceAllocator(), stream, N * extra_elements); - totalmem += (n_bins + N) * extra_elements * sizeof(T); - } - - sampledlabels = new MLCommon::device_buffer(handle.getDeviceAllocator(), stream, N); - totalmem += N*sizeof(int); - - //Allocate Temporary for split functions - d_num_selected_out = new MLCommon::device_buffer(handle.getDeviceAllocator(), stream, 1); - d_flags_left = new MLCommon::device_buffer(handle.getDeviceAllocator(), stream, N); - d_flags_right = new MLCommon::device_buffer(handle.getDeviceAllocator(), stream, N); - temprowids = new MLCommon::device_buffer(handle.getDeviceAllocator(), stream, N); - question_value = new MLCommon::device_buffer(handle.getDeviceAllocator(), stream, 1); - - 
cub::DeviceSelect::Flagged(d_split_temp_storage, split_temp_storage_bytes, temprowids->data(), d_flags_left->data(), temprowids->data(), d_num_selected_out->data(), N); - d_split_temp_storage = new MLCommon::device_buffer(handle.getDeviceAllocator(), stream, split_temp_storage_bytes); - - totalmem += split_temp_storage_bytes + (N + 1)*sizeof(int) + 2*N*sizeof(char) + sizeof(T); - - h_histout = new MLCommon::host_buffer(handle.getHostAllocator(), stream, n_hist_elements * Ncols); - - d_globalminmax = new MLCommon::device_buffer(handle.getDeviceAllocator(), stream, Ncols * 2); - d_histout = new MLCommon::device_buffer(handle.getDeviceAllocator(), stream, n_hist_elements * Ncols); - d_colids = new MLCommon::device_buffer(handle.getDeviceAllocator(), stream, Ncols); - totalmem += (n_hist_elements * sizeof(int) + sizeof(int) + 2*sizeof(T))* Ncols; - - } - -template -void TemporaryMemory::print_info() - { - std::cout <<" Inside the print_info function \n" << std::flush; - std::cout << " Total temporary memory usage--> "<< ((double)totalmem/ (1024*1024)) << " MB" << std::endl; - return; - } - -template -TemporaryMemory::~TemporaryMemory() - { - - h_hist->release(stream); - d_hist->release(stream); - nrowsleftright->release(stream); - temp_data->release(stream); - - delete h_hist; - delete d_hist; - delete temp_data; - - if (d_quantile != nullptr) { - d_quantile->release(stream); - delete d_quantile; - } - if (d_temp_sampledcolumn != nullptr) { - d_temp_sampledcolumn->release(stream); - delete d_temp_sampledcolumn; - } - - sampledlabels->release(stream); - d_split_temp_storage->release(stream); - d_num_selected_out->release(stream); - d_flags_left->release(stream); - d_flags_right->release(stream); - temprowids->release(stream); - question_value->release(stream); - h_histout->release(stream); - - delete sampledlabels; - delete d_split_temp_storage; - delete d_num_selected_out; - delete d_flags_left; - delete d_flags_right; - delete temprowids; - delete question_value; - 
delete h_histout; - - d_globalminmax->release(stream); - d_histout->release(stream); - d_colids->release(stream); - - delete d_globalminmax; - delete d_histout; - delete d_colids; - - } - diff --git a/cpp/src/decisiontree/memory.cuh b/cpp/src/decisiontree/memory.cuh new file mode 100644 index 0000000000..f6de336080 --- /dev/null +++ b/cpp/src/decisiontree/memory.cuh @@ -0,0 +1,166 @@ +/* + * Copyright (c) 2019, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include "cub/cub.cuh" +#include "memory.h" + +template +TemporaryMemory::TemporaryMemory(const ML::cumlHandle_impl& handle, int N, + int Ncols, int maxstr, int n_unique, + int n_bins, const int split_algo) + : ml_handle(handle) { + //Assign Stream from cumlHandle + stream = ml_handle.getStream(); + + int n_hist_elements = n_unique * n_bins; + + h_hist = new MLCommon::host_buffer(handle.getHostAllocator(), stream, + n_hist_elements); + d_hist = new MLCommon::device_buffer(handle.getDeviceAllocator(), stream, + n_hist_elements); + nrowsleftright = + new MLCommon::host_buffer(handle.getHostAllocator(), stream, 2); + + int extra_elements = Ncols; + int quantile_elements = + (split_algo == ML::SPLIT_ALGO::GLOBAL_QUANTILE) ? 
extra_elements : 1; + + temp_data = new MLCommon::device_buffer(handle.getDeviceAllocator(), + stream, N * Ncols); + totalmem += n_hist_elements * sizeof(int) + N * extra_elements * sizeof(T); + + if (split_algo == ML::SPLIT_ALGO::GLOBAL_QUANTILE) { + h_quantile = new MLCommon::host_buffer(handle.getHostAllocator(), stream, + n_bins * quantile_elements); + d_quantile = new MLCommon::device_buffer( + handle.getDeviceAllocator(), stream, n_bins * quantile_elements); + totalmem += n_bins * extra_elements * sizeof(T); + } + + sampledlabels = + new MLCommon::device_buffer(handle.getDeviceAllocator(), stream, N); + totalmem += N * sizeof(L); + + //Allocate Temporary for split functions + d_num_selected_out = + new MLCommon::device_buffer(handle.getDeviceAllocator(), stream, 1); + d_flags_left = + new MLCommon::device_buffer(handle.getDeviceAllocator(), stream, N); + d_flags_right = + new MLCommon::device_buffer(handle.getDeviceAllocator(), stream, N); + temprowids = new MLCommon::device_buffer( + handle.getDeviceAllocator(), stream, N); + question_value = + new MLCommon::device_buffer(handle.getDeviceAllocator(), stream, 1); + + cub::DeviceSelect::Flagged(d_split_temp_storage, split_temp_storage_bytes, + temprowids->data(), d_flags_left->data(), + temprowids->data(), d_num_selected_out->data(), N, + stream); + d_split_temp_storage = new MLCommon::device_buffer( + handle.getDeviceAllocator(), stream, split_temp_storage_bytes); + + totalmem += split_temp_storage_bytes + (N + 1) * sizeof(int) + + 2 * N * sizeof(char) + sizeof(T); + + h_histout = new MLCommon::host_buffer(handle.getHostAllocator(), stream, + n_hist_elements * Ncols); + int mse_elements = Ncols * n_bins; + h_mseout = new MLCommon::host_buffer(handle.getHostAllocator(), stream, + 2 * mse_elements); + h_predout = new MLCommon::host_buffer(handle.getHostAllocator(), stream, + mse_elements); + + d_globalminmax = new MLCommon::device_buffer(handle.getDeviceAllocator(), + stream, Ncols * 2); + d_histout = new 
MLCommon::device_buffer(handle.getDeviceAllocator(), + stream, n_hist_elements * Ncols); + d_mseout = new MLCommon::device_buffer(handle.getDeviceAllocator(), stream, + 2 * mse_elements); + d_predout = new MLCommon::device_buffer(handle.getDeviceAllocator(), + stream, mse_elements); + + d_colids = new MLCommon::device_buffer( + handle.getDeviceAllocator(), stream, Ncols); + // memory of d_histout + d_colids + d_globalminmax + (d_mseout + d_predout) + totalmem += (n_hist_elements * sizeof(int) + sizeof(unsigned int) + + 2 * sizeof(T) + 3 * n_bins * sizeof(T)) * + Ncols; + + //this->print_info(); +} + +template +void TemporaryMemory::print_info() { + std::cout << " Total temporary memory usage--> " + << ((double)totalmem / (1024 * 1024)) << " MB" << std::endl; +} + +/** + * @brief Releases the storage of every temporary buffer on the stored stream and frees the buffer objects that the constructor created with new. + */ +template +TemporaryMemory::~TemporaryMemory() { + h_hist->release(stream); + d_hist->release(stream); + nrowsleftright->release(stream); + temp_data->release(stream); + + delete h_hist; + delete d_hist; + /* nrowsleftright is created with new in the constructor like its siblings, so it must be freed here too */ + delete nrowsleftright; + delete temp_data; + + if (d_quantile != nullptr) { + d_quantile->release(stream); + h_quantile->release(stream); + delete h_quantile; + delete d_quantile; + } + + sampledlabels->release(stream); + d_split_temp_storage->release(stream); + d_num_selected_out->release(stream); + d_flags_left->release(stream); + d_flags_right->release(stream); + temprowids->release(stream); + question_value->release(stream); + h_histout->release(stream); + h_mseout->release(stream); + h_predout->release(stream); + + delete sampledlabels; + delete d_split_temp_storage; + delete d_num_selected_out; + delete d_flags_left; + delete d_flags_right; + delete temprowids; + delete question_value; + delete h_histout; + delete h_mseout; + delete h_predout; + + d_globalminmax->release(stream); + d_histout->release(stream); + d_mseout->release(stream); + d_predout->release(stream); + d_colids->release(stream); + + delete d_globalminmax; + delete d_histout; + delete d_mseout; + delete d_predout; + delete d_colids; +} +diff --git 
a/cpp/src/decisiontree/memory.h b/cpp/src/decisiontree/memory.h index 2324f91e89..803b4bf44b 100644 --- a/cpp/src/decisiontree/memory.h +++ b/cpp/src/decisiontree/memory.h @@ -14,55 +14,54 @@ * limitations under the License. */ - #pragma once #include -#include "common/cumlHandle.hpp" #include #include +#include "common/cumlHandle.hpp" +template +struct TemporaryMemory { + // Labels after bootstrapping + MLCommon::device_buffer *sampledlabels; -template -struct TemporaryMemory -{ - // Labels after boostrapping - MLCommon::device_buffer *sampledlabels; - - // Used for gini histograms (root tree node) - MLCommon::device_buffer *d_hist; - MLCommon::host_buffer *h_hist; + // Used for gini histograms (root tree node) + MLCommon::device_buffer *d_hist; + MLCommon::host_buffer *h_hist; - //Host/Device histograms and device minmaxs - MLCommon::device_buffer *d_globalminmax; - MLCommon::device_buffer *d_histout, *d_colids; - MLCommon::host_buffer *h_histout; + //Host/Device histograms and device minmaxs + MLCommon::device_buffer *d_globalminmax; + MLCommon::device_buffer *d_histout; + MLCommon::device_buffer *d_colids; + MLCommon::host_buffer *h_histout; + MLCommon::device_buffer *d_mseout, *d_predout; + MLCommon::host_buffer *h_mseout, *h_predout; - //Below pointers are shared for split functions - MLCommon::device_buffer *d_flags_left, *d_flags_right; - MLCommon::host_buffer *nrowsleftright; - MLCommon::device_buffer *d_split_temp_storage = nullptr; - size_t split_temp_storage_bytes = 0; + //Below pointers are shared for split functions + MLCommon::device_buffer *d_flags_left, *d_flags_right; + MLCommon::host_buffer *nrowsleftright; + MLCommon::device_buffer *d_split_temp_storage = nullptr; + size_t split_temp_storage_bytes = 0; - MLCommon::device_buffer *d_num_selected_out, *temprowids; - MLCommon::device_buffer *question_value, *temp_data; + MLCommon::device_buffer *d_num_selected_out; + MLCommon::device_buffer *temprowids; + MLCommon::device_buffer *question_value, 
*temp_data; - //Total temp mem - size_t totalmem = 0; + //Total temp mem + size_t totalmem = 0; - //CUDA stream - cudaStream_t stream; + //CUDA stream + cudaStream_t stream; - //For quantiles - MLCommon::device_buffer *d_quantile = nullptr; - MLCommon::device_buffer *d_temp_sampledcolumn = nullptr; - - const ML::cumlHandle_impl& ml_handle; - - TemporaryMemory(const ML::cumlHandle_impl& handle, int N, int Ncols, int maxstr, int n_unique, int n_bins, const int split_algo); - - void print_info(); - ~TemporaryMemory(); + //For quantiles + MLCommon::device_buffer *d_quantile = nullptr; + MLCommon::host_buffer *h_quantile = nullptr; -}; + const ML::cumlHandle_impl &ml_handle; + TemporaryMemory(const ML::cumlHandle_impl &handle, int N, int Ncols, + int maxstr, int n_unique, int n_bins, const int split_algo); + void print_info(); + ~TemporaryMemory(); +}; diff --git a/cpp/src/knn/knn.hpp b/cpp/src/knn/knn.hpp index bd29806d4a..024ebe598a 100644 --- a/cpp/src/knn/knn.hpp +++ b/cpp/src/knn/knn.hpp @@ -18,20 +18,19 @@ #include "common/cumlHandle.hpp" -#include -#include #include +#include +#include -#include #include #include +#include #include namespace ML { - - /** +/** * @brief Flat C++ API function to perform a brute force knn on * a series of input arrays and combine the results into a single * output array for indexes and distances. 
@@ -47,13 +46,11 @@ namespace ML { * @param res_D the resulting distance array of size n * k * @param k the number of nearest neighbors to return */ - void brute_force_knn( - cumlHandle &handle, - float **input, int*sizes, int n_params, int D, - float *search_items, int n, - long *res_I, float *res_D, int k); +void brute_force_knn(cumlHandle &handle, float **input, int *sizes, + int n_params, int D, float *search_items, int n, + long *res_I, float *res_D, int k); - /** +/** * @brief A flat C++ API function that chunks a host array up into * some number of different devices * @@ -65,36 +62,32 @@ namespace ML { * @param sizes output array sizes * @param n_chunks number of chunks to spread across device arrays */ - void chunk_host_array( - cumlHandle &handle, - const float *ptr, int n, int D, - int* devices, float **output, int *sizes, int n_chunks); - - class kNN { +void chunk_host_array(cumlHandle &handle, const float *ptr, int n, int D, + int *devices, float **output, int *sizes, int n_chunks); - float **ptrs; - int *sizes; +class kNN { + float **ptrs; + int *sizes; - int total_n; - int indices; - int D; - bool verbose; - bool owner; + int total_n; + int indices; + int D; + bool verbose; + bool owner; - cumlHandle *handle; + cumlHandle *handle; - public: - /** + public: + /** * Build a kNN object for training and querying a k-nearest neighbors model. 
* @param D number of features in each vector */ - kNN(const cumlHandle &handle, int D, bool verbose = false); - ~kNN(); - - void reset(); + kNN(const cumlHandle &handle, int D, bool verbose = false); + ~kNN(); + void reset(); - /** + /** * Search the kNN for the k-nearest neighbors of a set of query vectors * @param search_items set of vectors to query for neighbors * @param n number of items in search_items @@ -102,18 +95,18 @@ namespace ML { * @param res_D pointer to device memory for returning k nearest distances * @param k number of neighbors to query */ - void search(float *search_items, int search_items_size, - long *res_I, float *res_D, int k); + void search(float *search_items, int search_items_size, long *res_I, + float *res_D, int k); - /** + /** * Fit a kNN model by creating separate indices for multiple given * instances of kNNParams. * @param input an array of pointers to data on (possibly different) devices * @param N number of items in input array. */ - void fit(float **input, int *sizes, int N); + void fit(float **input, int *sizes, int N); - /** + /** * Chunk a host array up into one or many GPUs (determined by the provided * list of gpu ids) and fit a knn model. * @@ -123,7 +116,6 @@ namespace ML { * @param n_chunks number of elements in gpus * @param out host pointer to copy output */ - void fit_from_host(float *ptr, int n, int* devices, int n_chunks); - }; + void fit_from_host(float *ptr, int n, int *devices, int n_chunks); }; - +}; // namespace ML diff --git a/cpp/src/knn/knn_api.h b/cpp/src/knn/knn_api.h index d01cf5e990..4404ad7c4a 100644 --- a/cpp/src/knn/knn_api.h +++ b/cpp/src/knn/knn_api.h @@ -22,7 +22,7 @@ extern "C" { #endif - /** +/** * @brief Flat C API function to perform a brute force knn on * a series of input arrays and combine the results into a single * output array for indexes and distances. 
@@ -38,12 +38,9 @@ extern "C" { * @param res_D the resulting distance array of size n * k * @param k the number of nearest neighbors to return */ -cumlError_t knn_search( - const cumlHandle_t handle, - float **input, int *size, int n_params, int D, - const float *search_items, int n, - long *res_I, float *res_D, int k -); +cumlError_t knn_search(const cumlHandle_t handle, float **input, int *size, + int n_params, int D, const float *search_items, int n, + long *res_I, float *res_D, int k); /** * @brief A flat C++ API function that chunks a host array up into @@ -57,11 +54,9 @@ cumlError_t knn_search( * @param sizes output array sizes * @param n_chunks number of chunks to spread across device arrays */ -cumlError_t chunk_host_array( - const cumlHandle_t handle, - const float *ptr, int n, int D, - int* devices, float **output, int *sizes, int n_chunks, -); +cumlError_t chunk_host_array(const cumlHandle_t handle, const float *ptr, int n, + int D, int *devices, float **output, int *sizes, + int n_chunks); #ifdef __cplusplus } diff --git a/cpp/src/metrics/trustworthiness.h b/cpp/src/metrics/trustworthiness.h index 2f8c06cb03..9c5fb82ed4 100644 --- a/cpp/src/metrics/trustworthiness.h +++ b/cpp/src/metrics/trustworthiness.h @@ -22,8 +22,8 @@ namespace ML { namespace Metrics { - template - double trustworthiness_score(const cumlHandle& h, math_t* X, - math_t* X_embedded, int n, int m, int d, int n_neighbors); -} +template +double trustworthiness_score(const cumlHandle& h, math_t* X, math_t* X_embedded, + int n, int m, int d, int n_neighbors); } +} // namespace ML diff --git a/cpp/src/ml_cuda_utils.h b/cpp/src/ml_cuda_utils.h index 0e8aa889f4..7892bde5f9 100644 --- a/cpp/src/ml_cuda_utils.h +++ b/cpp/src/ml_cuda_utils.h @@ -21,8 +21,6 @@ namespace ML { - - int get_device(const void *ptr) { cudaPointerAttributes att; cudaPointerGetAttributes(&att, ptr); @@ -30,20 +28,20 @@ int get_device(const void *ptr) { } cudaMemoryType memory_type(const void *p) { - 
cudaPointerAttributes att; - cudaError_t err = cudaPointerGetAttributes(&att, p); - ASSERT(err == cudaSuccess || - err == cudaErrorInvalidValue, "%s", cudaGetErrorString(err)); - - if (err == cudaErrorInvalidValue) { - // Make sure the current thread error status has been reset - err = cudaGetLastError(); - ASSERT(err == cudaErrorInvalidValue, "%s", cudaGetErrorString(err)); - } - #if CUDA_VERSION >= 10000 - return att.type; - #else - return att.memoryType; - #endif + cudaPointerAttributes att; + cudaError_t err = cudaPointerGetAttributes(&att, p); + ASSERT(err == cudaSuccess || err == cudaErrorInvalidValue, "%s", + cudaGetErrorString(err)); + + if (err == cudaErrorInvalidValue) { + // Make sure the current thread error status has been reset + err = cudaGetLastError(); + ASSERT(err == cudaErrorInvalidValue, "%s", cudaGetErrorString(err)); } +#if CUDA_VERSION >= 10000 + return att.type; +#else + return att.memoryType; +#endif } +} // namespace ML diff --git a/cpp/src/randomforest/randomforest.cu b/cpp/src/randomforest/randomforest.cu index a7f4a29e0a..595369ba8f 100644 --- a/cpp/src/randomforest/randomforest.cu +++ b/cpp/src/randomforest/randomforest.cu @@ -14,16 +14,12 @@ * limitations under the License. */ -#include -#include -#include -#include -#include -#include -#include -#include +#include "../decisiontree/kernels/quantile.h" +#include "../decisiontree/memory.h" +#include "random/permute.h" #include "random/rng.h" #include "randomforest.h" +#include "score/scores.h" namespace ML { @@ -31,12 +27,35 @@ namespace ML { * @brief Construct RF_metrics. * @param[in] cfg_accuracy: accuracy. */ -RF_metrics::RF_metrics(float cfg_accuracy) : accuracy(cfg_accuracy){}; +RF_metrics::RF_metrics(float cfg_accuracy) + : rf_type(RF_type::CLASSIFICATION), accuracy(cfg_accuracy){}; /** - * @brief Print accuracy metric. + * @brief Construct RF_metrics. + * @param[in] cfg_mean_abs_error: mean absolute error. + * @param[in] cfg_mean_squared_error: mean squared error. 
+ * @param[in] cfg_median_abs_error: median absolute error. */ -void RF_metrics::print() { std::cout << "Accuracy: " << accuracy << std::endl; } +RF_metrics::RF_metrics(double cfg_mean_abs_error, double cfg_mean_squared_error, + double cfg_median_abs_error) + : rf_type(RF_type::REGRESSION), + mean_abs_error(cfg_mean_abs_error), + mean_squared_error(cfg_mean_squared_error), + median_abs_error(cfg_median_abs_error){}; + +/** + * @brief Print either accuracy metric for classification, or mean absolute error, mean squared error, + and median absolute error metrics for regression. + */ +void RF_metrics::print() { + if (rf_type == RF_type::CLASSIFICATION) { + std::cout << "Accuracy: " << accuracy << std::endl; + } else if (rf_type == RF_type::REGRESSION) { + std::cout << "Mean Absolute Error: " << mean_abs_error << std::endl; + std::cout << "Mean Squared Error: " << mean_squared_error << std::endl; + std::cout << "Median Absolute Error: " << median_abs_error << std::endl; + } +} /** * @brief Update labels so they are unique from 0 to n_unique_labels values. @@ -92,9 +111,10 @@ void postprocess_labels(int n_rows, std::vector& labels, } /** - * @brief Random forest default constructor. + * @brief Random forest hyper-parameter object default constructor (1 tree). */ RF_params::RF_params() : n_trees(1) {} + /** * @brief Random forest hyper-parameter object constructor to set n_trees member. */ @@ -150,39 +170,34 @@ void RF_params::print() const { /** * @brief Construct rf (random forest) object. * @tparam T: data type for input data (float or double). + * @tparam L: data type for labels (int type for classification, T type for regression). * @param[in] cfg_rf_params: Random forest hyper-parameter struct. * @param[in] cfg_rf_type: Random forest type. Only CLASSIFICATION is currently supported. 
*/ -template -rf::rf(RF_params cfg_rf_params, int cfg_rf_type) - : rf_params(cfg_rf_params), rf_type(cfg_rf_type), trees(nullptr) { +template +rf::rf(RF_params cfg_rf_params, int cfg_rf_type) + : rf_params(cfg_rf_params), rf_type(cfg_rf_type) { rf_params.validity_check(); } -/** - * @brief Destructor for random forest object. - * @tparam T: data type for input data (float or double). - */ -template -rf::~rf() { - delete[] trees; -} - /** * @brief Return number of trees in the forest. * @tparam T: data type for input data (float or double). + * @tparam L: data type for labels (int type for classification, T type for regression). */ -template -int rf::get_ntrees() { +template +int rf::get_ntrees() { return rf_params.n_trees; } /** * @brief Print summary for all trees in the random forest. * @tparam T: data type for input data (float or double). + * @tparam L: data type for labels (int type for classification, T type for regression). */ -template -void rf::print_rf_summary() { +template +void rf::print_rf_summary() { + const DecisionTree::DecisionTreeBase* trees = get_trees_ptr(); if (!trees) { std::cout << "Empty forest" << std::endl; } else { @@ -200,9 +215,11 @@ void rf::print_rf_summary() { /** * @brief Print detailed view of all trees in the random forest. * @tparam T: data type for input data (float or double). + * @tparam L: data type for labels (int type for classification, T type for regression). */ -template -void rf::print_rf_detailed() { +template +void rf::print_rf_detailed() { + const DecisionTree::DecisionTreeBase* trees = get_trees_ptr(); if (!trees) { std::cout << "Empty forest" << std::endl; } else { @@ -217,6 +234,88 @@ void rf::print_rf_detailed() { } } +/** + * @brief Sample row IDs for tree fitting and bootstrap if requested. + * @tparam T: data type for input data (float or double). + * @tparam L: data type for labels (int type for classification, T type for regression). 
+ * @param[in] handle: cumlHandle + * @param[in] tree_id: unique tree ID + * @param[in] n_rows: total number of data samples. + * @param[in] n_sampled_rows: number of rows used for training + * @param[in, out] selected_rows: already allocated array w/ row IDs + * @param[in, out] sorted_selected_rows: already allocated array. Will contain sorted row IDs. + * @param[in, out] rows_temp_storage: temp. storage used for sorting (previously allocated). + * @param[in] temp_storage_bytes: size in bytes of rows_temp_storage. + */ +template +void rf::prepare_fit_per_tree(const ML::cumlHandle_impl& handle, + int tree_id, int n_rows, int n_sampled_rows, + unsigned int* selected_rows, + unsigned int* sorted_selected_rows, + char* rows_temp_storage, + size_t temp_storage_bytes) { + cudaStream_t stream = handle.getStream(); + + if (rf_params.bootstrap) { + MLCommon::Random::Rng r( + tree_id * + 1000); // Ensure the seed for each tree is different and meaningful. + r.uniformInt(selected_rows, n_sampled_rows, (unsigned int)0, + (unsigned int)n_rows, stream); + //thrust::sequence(thrust::cuda::par.on(stream), sorted_selected_rows, + // sorted_selected_rows + n_sampled_rows); + + CUDA_CHECK(cub::DeviceRadixSort::SortKeys( + (void*)rows_temp_storage, temp_storage_bytes, selected_rows, + sorted_selected_rows, n_sampled_rows, 0, 8 * sizeof(unsigned int), + stream)); + } else { // Sampling w/o replacement + MLCommon::device_buffer* inkeys = + new MLCommon::device_buffer(handle.getDeviceAllocator(), + stream, n_rows); + MLCommon::device_buffer* outkeys = + new MLCommon::device_buffer(handle.getDeviceAllocator(), + stream, n_rows); + thrust::sequence(thrust::cuda::par.on(stream), inkeys->data(), + inkeys->data() + n_rows); + int* perms = nullptr; + MLCommon::Random::permute(perms, outkeys->data(), inkeys->data(), 1, n_rows, + false, stream); + // outkeys has more rows than selected_rows; doing the shuffling before the resize to differentiate the per-tree rows sample. 
+ CUDA_CHECK(cub::DeviceRadixSort::SortKeys( + (void*)rows_temp_storage, temp_storage_bytes, outkeys->data(), + sorted_selected_rows, n_sampled_rows, 0, 8 * sizeof(unsigned int), + stream)); + inkeys->release(stream); + outkeys->release(stream); + delete inkeys; + delete outkeys; + } +} + +template +void rf::error_checking(const T* input, L* predictions, int n_rows, + int n_cols, bool predict) const { + if (predict) { + ASSERT(get_trees_ptr(), "Cannot predict! No trees in the forest."); + ASSERT(predictions != nullptr, + "Error! User has not allocated memory for predictions."); + } else { + ASSERT(!get_trees_ptr(), "Cannot fit an existing forest."); + } + ASSERT((n_rows > 0), "Invalid n_rows %d", n_rows); + ASSERT((n_cols > 0), "Invalid n_cols %d", n_cols); + + bool input_is_dev_ptr = is_dev_ptr(input); + bool preds_is_dev_ptr = is_dev_ptr(predictions); + + if (!input_is_dev_ptr || (input_is_dev_ptr != preds_is_dev_ptr)) { + ASSERT(false, + "RF Error: Expected both input and labels/predictions to be GPU " + "pointers"); + } +} + /** * @brief Construct rfClassifier object. * @tparam T: data type for input data (float or double). @@ -224,7 +323,26 @@ void rf::print_rf_detailed() { */ template rfClassifier::rfClassifier(RF_params cfg_rf_params) - : rf::rf(cfg_rf_params, RF_type::CLASSIFICATION){}; + : rf::rf(cfg_rf_params, RF_type::CLASSIFICATION){}; + +/** + * @brief Destructor for random forest classifier object. + * @tparam T: data type for input data (float or double). + */ +template +rfClassifier::~rfClassifier() { + delete[] trees; +} + +/** + * @brief Return a const pointer to decision trees. + * @tparam T: data type for input data (float or double). + */ +template +const DecisionTree::DecisionTreeClassifier* rfClassifier::get_trees_ptr() + const { + return trees; +} /** * @brief Build (i.e., fit, train) random forest classifier for input data. 
@@ -241,71 +359,92 @@ rfClassifier::rfClassifier(RF_params cfg_rf_params) template void rfClassifier::fit(const cumlHandle& user_handle, T* input, int n_rows, int n_cols, int* labels, int n_unique_labels) { - ASSERT(!this->trees, "Cannot fit an existing forest."); - ASSERT((n_rows > 0), "Invalid n_rows %d", n_rows); - ASSERT((n_cols > 0), "Invalid n_cols %d", n_cols); + this->error_checking(input, labels, n_rows, n_cols, false); - rfClassifier::trees = - new DecisionTree::DecisionTreeClassifier[this->rf_params.n_trees]; + trees = new DecisionTree::DecisionTreeClassifier[this->rf_params.n_trees]; int n_sampled_rows = this->rf_params.rows_sample * n_rows; + const cumlHandle_impl& handle = user_handle.getImpl(); cudaStream_t stream = user_handle.getStream(); + + // Select n_sampled_rows (with replacement) numbers from [0, n_rows) per tree. + // selected_rows: randomly generated IDs for bootstrapped samples (w/ replacement); a device ptr. + MLCommon::device_buffer selected_rows( + handle.getDeviceAllocator(), stream, n_sampled_rows); + MLCommon::device_buffer sorted_selected_rows( + handle.getDeviceAllocator(), stream, n_sampled_rows); + + // Will sort selected_rows (row IDs), prior to fit, to improve access patterns + MLCommon::device_buffer* rows_temp_storage = nullptr; + size_t temp_storage_bytes = 0; + CUDA_CHECK(cub::DeviceRadixSort::SortKeys( + rows_temp_storage, temp_storage_bytes, selected_rows.data(), + sorted_selected_rows.data(), n_sampled_rows, 0, 8 * sizeof(unsigned int), + stream)); + // Allocate temporary storage + rows_temp_storage = new MLCommon::device_buffer( + handle.getDeviceAllocator(), stream, temp_storage_bytes); + std::shared_ptr> tempmem = + std::make_shared>( + user_handle.getImpl(), n_sampled_rows, n_cols, 1, n_unique_labels, + this->rf_params.tree_params.n_bins, + this->rf_params.tree_params.split_algo); + if ((this->rf_params.tree_params.split_algo == SPLIT_ALGO::GLOBAL_QUANTILE) && + !(this->rf_params.tree_params.quantile_per_tree)) { + 
preprocess_quantile(input, nullptr, n_sampled_rows, n_cols, n_rows, + this->rf_params.tree_params.n_bins, tempmem); + } for (int i = 0; i < this->rf_params.n_trees; i++) { - // Select n_sampled_rows (with replacement) numbers from [0, n_rows) per tree. - // selected_rows: randomly generated IDs for bootstrapped samples (w/ replacement); a device ptr. - MLCommon::device_buffer selected_rows( - handle.getDeviceAllocator(), stream, n_sampled_rows); - - if (this->rf_params.bootstrap) { - MLCommon::Random::Rng r( - i * - 1000); // Ensure the seed for each tree is different and meaningful. - r.uniformInt(selected_rows.data(), n_sampled_rows, (unsigned int)0, - (unsigned int)n_rows, stream); - } else { - std::vector h_selected_rows(n_rows); - std::iota(h_selected_rows.begin(), h_selected_rows.end(), 0); - std::random_shuffle(h_selected_rows.begin(), h_selected_rows.end()); - h_selected_rows.resize(n_sampled_rows); - MLCommon::updateDevice(selected_rows.data(), h_selected_rows.data(), - n_sampled_rows, stream); - } + this->prepare_fit_per_tree(handle, i, n_rows, n_sampled_rows, + selected_rows.data(), + sorted_selected_rows.data(), + rows_temp_storage->data(), temp_storage_bytes); /* Build individual tree in the forest. - input is a pointer to orig data that have n_cols features and n_rows rows. - n_sampled_rows: # rows sampled for tree's bootstrap sample. - - selected_rows: points to a list of row #s (w/ n_sampled_rows elements) used to build the bootstrapped sample. - Expectation: Each tree node will contain (a) # n_sampled_rows and (b) a pointer to a list of row numbers w.r.t original data. + - sorted_selected_rows: points to a list of row #s (w/ n_sampled_rows elements) used to build the bootstrapped sample. + Expectation: Each tree node will contain (a) # n_sampled_rows and (b) a pointer to a list of row numbers w.r.t original data. 
*/ - this->trees[i].fit(user_handle, input, n_cols, n_rows, labels, - selected_rows.data(), n_sampled_rows, n_unique_labels, - this->rf_params.tree_params); - //Cleanup - selected_rows.release(stream); + trees[i].fit(user_handle, input, n_cols, n_rows, labels, + sorted_selected_rows.data(), n_sampled_rows, n_unique_labels, + this->rf_params.tree_params, tempmem); } + + //Cleanup + rows_temp_storage->release(stream); + selected_rows.release(stream); + sorted_selected_rows.release(stream); + tempmem.reset(); + delete rows_temp_storage; } /** * @brief Predict target feature for input data; n-ary classification for single feature supported. * @tparam T: data type for input data (float or double). - * @param[in] user_handle: cumlHandle (currently unused; API placeholder) - * @param[in] input: test data (n_rows samples, n_cols features) in row major format. CPU pointer. + * @param[in] user_handle: cumlHandle. + * @param[in] input: test data (n_rows samples, n_cols features) in row major format. GPU pointer. * @param[in] n_rows: number of data samples. * @param[in] n_cols: number of features (excluding target feature). - * @param[in, out] predictions: n_rows predicted labels. CPU pointer, user allocated. + * @param[in, out] predictions: n_rows predicted labels. GPU pointer, user allocated. * @param[in] verbose: flag for debugging purposes. */ template void rfClassifier::predict(const cumlHandle& user_handle, const T* input, int n_rows, int n_cols, int* predictions, bool verbose) const { - ASSERT(this->trees, "Cannot predict! No trees in the forest."); - ASSERT((n_rows > 0), "Invalid n_rows %d", n_rows); - ASSERT((n_cols > 0), "Invalid n_cols %d", n_cols); - ASSERT(predictions != nullptr, - "Error! 
User has not allocated memory for predictions."); + this->error_checking(input, predictions, n_rows, n_cols, true); + + std::vector h_predictions(n_rows); + const cumlHandle_impl& handle = user_handle.getImpl(); + cudaStream_t stream = user_handle.getStream(); + + std::vector h_input(n_rows * n_cols); + MLCommon::updateHost(h_input.data(), input, n_rows * n_cols, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + int row_size = n_cols; for (int row_id = 0; row_id < n_rows; row_id++) { @@ -313,7 +452,7 @@ void rfClassifier::predict(const cumlHandle& user_handle, const T* input, std::cout << "\n\n"; std::cout << "Predict for sample: "; for (int i = 0; i < n_cols; i++) - std::cout << input[row_id * row_size + i] << ", "; + std::cout << h_input[row_id * row_size + i] << ", "; std::cout << std::endl; } @@ -326,11 +465,11 @@ void rfClassifier::predict(const cumlHandle& user_handle, const T* input, //Return prediction for one sample. if (verbose) { std::cout << "Printing tree " << i << std::endl; - //this->trees[i].print(); + trees[i].print(); } int prediction; - this->trees[i].predict(user_handle, &input[row_id * row_size], 1, n_cols, - &prediction, verbose); + trees[i].predict(user_handle, &h_input[row_id * row_size], 1, n_cols, + &prediction, verbose); ret = prediction_to_cnt.insert(std::pair(prediction, 1)); if (!(ret.second)) { ret.first->second += 1; @@ -341,19 +480,22 @@ void rfClassifier::predict(const cumlHandle& user_handle, const T* input, } } - predictions[row_id] = majority_prediction; + h_predictions[row_id] = majority_prediction; } + + MLCommon::updateDevice(predictions, h_predictions.data(), n_rows, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); } /** * @brief Predict target feature for input data and validate against ref_labels. * @tparam T: data type for input data (float or double). - * @param[in] user_handle: cumlHandle (currently unused; API placeholder) - * @param[in] input: test data (n_rows samples, n_cols features) in row major format. 
CPU pointer. - * @param[in] ref_labels: label values for cross validation (n_rows elements); CPU pointer. + * @param[in] user_handle: cumlHandle. + * @param[in] input: test data (n_rows samples, n_cols features) in row major format. GPU pointer. + * @param[in] ref_labels: label values for cross validation (n_rows elements); GPU pointer. * @param[in] n_rows: number of data samples. * @param[in] n_cols: number of features (excluding target feature). - * @param[in, out] predictions: n_rows predicted labels. CPU pointer, user allocated. + * @param[in, out] predictions: n_rows predicted labels. GPU pointer, user allocated. * @param[in] verbose: flag for debugging purposes. */ template @@ -362,12 +504,10 @@ RF_metrics rfClassifier::score(const cumlHandle& user_handle, const T* input, int* predictions, bool verbose) const { predict(user_handle, input, n_rows, n_cols, predictions, verbose); - unsigned long long correctly_predicted = 0ULL; - for (int i = 0; i < n_rows; i++) { - correctly_predicted += (predictions[i] == ref_labels[i]); - } - - float accuracy = correctly_predicted * 1.0f / n_rows; + cudaStream_t stream = user_handle.getImpl().getStream(); + auto d_alloc = user_handle.getDeviceAllocator(); + float accuracy = MLCommon::Score::accuracy_score(predictions, ref_labels, + n_rows, d_alloc, stream); RF_metrics stats(accuracy); if (verbose) stats.print(); @@ -376,14 +516,209 @@ RF_metrics rfClassifier::score(const cumlHandle& user_handle, const T* input, return stats; } -template class rf; -template class rf; +/** + * @brief Construct rfRegressor object. + * @tparam T: data type for input data (float or double). + * @param[in] cfg_rf_params: Random forest hyper-parameter struct. + */ +template +rfRegressor::rfRegressor(RF_params cfg_rf_params) + : rf::rf(cfg_rf_params, RF_type::REGRESSION){}; + +/** + * @brief Destructor for random forest regressor object. + * @tparam T: data type for input data (float or double). 
+ */ +template +rfRegressor::~rfRegressor() { + delete[] trees; +} + +/** + * @brief Return a const pointer to decision trees. + * @tparam T: data type for input data (float or double). + */ +template +const DecisionTree::DecisionTreeRegressor* rfRegressor::get_trees_ptr() + const { + return trees; +} + +/** + * @brief Build (i.e., fit, train) random forest regressor for input data. + * @tparam T: data type for input data (float or double). + * @param[in] user_handle: cumlHandle + * @param[in] input: train data (n_rows samples, n_cols features) in column major format, excluding labels. Device pointer. + * @param[in] n_rows: number of training data samples. + * @param[in] n_cols: number of features (i.e., columns) excluding target feature. + * @param[in] labels: 1D array of target features (float or double), with one label per training sample. Device pointer. + */ +template +void rfRegressor::fit(const cumlHandle& user_handle, T* input, int n_rows, + int n_cols, T* labels) { + this->error_checking(input, labels, n_rows, n_cols, false); + + trees = new DecisionTree::DecisionTreeRegressor[this->rf_params.n_trees]; + + int n_sampled_rows = this->rf_params.rows_sample * n_rows; + + const cumlHandle_impl& handle = user_handle.getImpl(); + cudaStream_t stream = user_handle.getStream(); + + // Select n_sampled_rows (with replacement) numbers from [0, n_rows) per tree. + // selected_rows: randomly generated IDs for bootstrapped samples (w/ replacement); a device ptr. 
+ MLCommon::device_buffer selected_rows( + handle.getDeviceAllocator(), stream, n_sampled_rows); + MLCommon::device_buffer sorted_selected_rows( + handle.getDeviceAllocator(), stream, n_sampled_rows); + + // Will sort selected_rows (row IDs), prior to fit, to improve access patterns + MLCommon::device_buffer* rows_temp_storage = nullptr; + size_t temp_storage_bytes = 0; + CUDA_CHECK(cub::DeviceRadixSort::SortKeys( + rows_temp_storage, temp_storage_bytes, selected_rows.data(), + sorted_selected_rows.data(), n_sampled_rows, 0, 8 * sizeof(unsigned int), + stream)); + // Allocate temporary storage + rows_temp_storage = new MLCommon::device_buffer( + handle.getDeviceAllocator(), stream, temp_storage_bytes); + std::shared_ptr> tempmem = + std::make_shared>( + user_handle.getImpl(), n_sampled_rows, n_cols, 1, 1, + this->rf_params.tree_params.n_bins, + this->rf_params.tree_params.split_algo); + + if ((this->rf_params.tree_params.split_algo == SPLIT_ALGO::GLOBAL_QUANTILE) && + !(this->rf_params.tree_params.quantile_per_tree)) { + preprocess_quantile(input, nullptr, n_sampled_rows, n_cols, n_rows, + this->rf_params.tree_params.n_bins, tempmem); + } + for (int i = 0; i < this->rf_params.n_trees; i++) { + this->prepare_fit_per_tree(handle, i, n_rows, n_sampled_rows, + selected_rows.data(), + sorted_selected_rows.data(), + rows_temp_storage->data(), temp_storage_bytes); + + /* Build individual tree in the forest. + - input is a pointer to orig data that have n_cols features and n_rows rows. + - n_sampled_rows: # rows sampled for tree's bootstrap sample. + - sorted_selected_rows: points to a list of row #s (w/ n_sampled_rows elements) used to build the bootstrapped sample. + Expectation: Each tree node will contain (a) # n_sampled_rows and (b) a pointer to a list of row numbers w.r.t original data. 
+ */ + + trees[i].fit(user_handle, input, n_cols, n_rows, labels, + sorted_selected_rows.data(), n_sampled_rows, + this->rf_params.tree_params, tempmem); + } + //Cleanup + rows_temp_storage->release(stream); + selected_rows.release(stream); + sorted_selected_rows.release(stream); + tempmem.reset(); + delete rows_temp_storage; +} + +/** + * @brief Predict target feature for input data; regression for single feature supported. + * @tparam T: data type for input data (float or double). + * @param[in] user_handle: cumlHandle. + * @param[in] input: test data (n_rows samples, n_cols features) in row major format. GPU pointer. + * @param[in] n_rows: number of data samples. + * @param[in] n_cols: number of features (excluding target feature). + * @param[in, out] predictions: n_rows predicted labels. GPU pointer, user allocated. + * @param[in] verbose: flag for debugging purposes. + */ +template +void rfRegressor::predict(const cumlHandle& user_handle, const T* input, + int n_rows, int n_cols, T* predictions, + bool verbose) const { + this->error_checking(input, predictions, n_rows, n_cols, true); + + std::vector h_predictions(n_rows); + const cumlHandle_impl& handle = user_handle.getImpl(); + cudaStream_t stream = user_handle.getStream(); + + std::vector h_input(n_rows * n_cols); + MLCommon::updateHost(h_input.data(), input, n_rows * n_cols, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + int row_size = n_cols; + + for (int row_id = 0; row_id < n_rows; row_id++) { + if (verbose) { + std::cout << "\n\n"; + std::cout << "Predict for sample: "; + for (int i = 0; i < n_cols; i++) + std::cout << h_input[row_id * row_size + i] << ", "; + std::cout << std::endl; + } + + T sum_predictions = 0; + + for (int i = 0; i < this->rf_params.n_trees; i++) { + //Return prediction for one sample. 
+ if (verbose) { + std::cout << "Printing tree " << i << std::endl; + trees[i].print(); + } + T prediction; + trees[i].predict(user_handle, &h_input[row_id * row_size], 1, n_cols, + &prediction, verbose); + sum_predictions += prediction; + } + // Random forest's prediction is the arithmetic mean of all its decision tree predictions. + h_predictions[row_id] = sum_predictions / this->rf_params.n_trees; + } + + MLCommon::updateDevice(predictions, h_predictions.data(), n_rows, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); +} + +/** + * @brief Predict target feature for input data and validate against ref_labels. + * @tparam T: data type for input data (float or double). + * @param[in] user_handle: cumlHandle. + * @param[in] input: test data (n_rows samples, n_cols features) in row major format. GPU pointer. + * @param[in] ref_labels: label values for cross validation (n_rows elements); GPU pointer. + * @param[in] n_rows: number of data samples. + * @param[in] n_cols: number of features (excluding target feature). + * @param[in, out] predictions: n_rows predicted labels. GPU pointer, user allocated. + * @param[in] verbose: flag for debugging purposes. 
+ */ +template +RF_metrics rfRegressor::score(const cumlHandle& user_handle, const T* input, + const T* ref_labels, int n_rows, int n_cols, + T* predictions, bool verbose) const { + predict(user_handle, input, n_rows, n_cols, predictions, verbose); + + cudaStream_t stream = user_handle.getImpl().getStream(); + auto d_alloc = user_handle.getDeviceAllocator(); + + double mean_abs_error, mean_squared_error, median_abs_error; + MLCommon::Score::regression_metrics(predictions, ref_labels, n_rows, d_alloc, + stream, mean_abs_error, + mean_squared_error, median_abs_error); + RF_metrics stats(mean_abs_error, mean_squared_error, median_abs_error); + if (verbose) stats.print(); + + return stats; +} + +template class rf; +template class rf; +template class rf; +template class rf; template class rfClassifier; template class rfClassifier; +template class rfRegressor; +template class rfRegressor; + // Stateless API functions: fit, predict and score +// ----------------------------- Classification ----------------------------------- // + /** * @brief Build (i.e., fit, train) random forest classifier for input data of type float. * @param[in] user_handle: cumlHandle @@ -424,12 +759,12 @@ void fit(const cumlHandle& user_handle, rfClassifier* rf_classifier, /** * @brief Predict target feature for input data of type float; n-ary classification for single feature supported. - * @param[in] user_handle: cumlHandle (currently unused; API placeholder) + * @param[in] user_handle: cumlHandle. * @param[in] rf_classifier: pointer to the rfClassifier object. The user should have previously called fit to build the random forest. - * @param[in] input: test data (n_rows samples, n_cols features) in row major format. CPU pointer. + * @param[in] input: test data (n_rows samples, n_cols features) in row major format. GPU pointer. * @param[in] n_rows: number of data samples. * @param[in] n_cols: number of features (excluding target feature). - * @param[in, out] predictions: n_rows predicted labels. 
CPU pointer, user allocated. + * @param[in, out] predictions: n_rows predicted labels. GPU pointer, user allocated. * @param[in] verbose: flag for debugging purposes. */ void predict(const cumlHandle& user_handle, @@ -441,12 +776,12 @@ void predict(const cumlHandle& user_handle, /** * @brief Predict target feature for input data of type double; n-ary classification for single feature supported. - * @param[in] user_handle: cumlHandle (currently unused; API placeholder) + * @param[in] user_handle: cumlHandle. * @param[in] rf_classifier: pointer to the rfClassifier object. The user should have previously called fit to build the random forest. - * @param[in] input: test data (n_rows samples, n_cols features) in row major format. CPU pointer. + * @param[in] input: test data (n_rows samples, n_cols features) in row major format. GPU pointer. * @param[in] n_rows: number of data samples. * @param[in] n_cols: number of features (excluding target feature). - * @param[in, out] predictions: n_rows predicted labels. CPU pointer, user allocated. + * @param[in, out] predictions: n_rows predicted labels. GPU pointer, user allocated. * @param[in] verbose: flag for debugging purposes. */ void predict(const cumlHandle& user_handle, @@ -458,13 +793,13 @@ void predict(const cumlHandle& user_handle, /** * @brief Predict target feature for input data of type float and validate against ref_labels. - * @param[in] user_handle: cumlHandle (currently unused; API placeholder) + * @param[in] user_handle: cumlHandle. * @param[in] rf_classifier: pointer to the rfClassifier object. The user should have previously called fit to build the random forest. - * @param[in] input: test data (n_rows samples, n_cols features) in row major format. CPU pointer. - * @param[in] ref_labels: label values for cross validation (n_rows elements); CPU pointer. + * @param[in] input: test data (n_rows samples, n_cols features) in row major format. GPU pointer. 
+ * @param[in] ref_labels: label values for cross validation (n_rows elements); GPU pointer. * @param[in] n_rows: number of data samples. * @param[in] n_cols: number of features (excluding target feature). - * @param[in, out] predictions: n_rows predicted labels. CPU pointer, user allocated. + * @param[in, out] predictions: n_rows predicted labels. GPU pointer, user allocated. * @param[in] verbose: flag for debugging purposes. */ RF_metrics score(const cumlHandle& user_handle, @@ -477,13 +812,13 @@ RF_metrics score(const cumlHandle& user_handle, /** * @brief Predict target feature for input data of type double and validate against ref_labels. - * @param[in] user_handle: cumlHandle (currently unused; API placeholder) + * @param[in] user_handle: cumlHandle. * @param[in] rf_classifier: pointer to the rfClassifier object. The user should have previously called fit to build the random forest. - * @param[in] input: test data (n_rows samples, n_cols features) in row major format. CPU pointer. - * @param[in] ref_labels: label values for cross validation (n_rows elements); CPU pointer. + * @param[in] input: test data (n_rows samples, n_cols features) in row major format. GPU pointer. + * @param[in] ref_labels: label values for cross validation (n_rows elements); GPU pointer. * @param[in] n_rows: number of data samples. * @param[in] n_cols: number of features (excluding target feature). - * @param[in, out] predictions: n_rows predicted labels. CPU pointer, user allocated. + * @param[in, out] predictions: n_rows predicted labels. GPU pointer, user allocated. * @param[in] verbose: flag for debugging purposes. 
*/ RF_metrics score(const cumlHandle& user_handle, @@ -497,13 +832,117 @@ RF_metrics score(const cumlHandle& user_handle, RF_params set_rf_class_obj(int max_depth, int max_leaves, float max_features, int n_bins, int split_algo, int min_rows_per_node, bool bootstrap_features, bool bootstrap, int n_trees, - float rows_sample) { + float rows_sample, CRITERION split_criterion, + bool quantile_per_tree) { DecisionTree::DecisionTreeParams tree_params( max_depth, max_leaves, max_features, n_bins, split_algo, min_rows_per_node, - bootstrap_features); + bootstrap_features, split_criterion, quantile_per_tree); RF_params rf_params(bootstrap, bootstrap_features, n_trees, rows_sample, tree_params); return rf_params; } + +// ----------------------------- Regression ----------------------------------- // + +/** + * @brief Build (i.e., fit, train) random forest regressor for input data of type float. + * @param[in] user_handle: cumlHandle + * @param[in,out] rf_regressor: pointer to the rfRegressor object, previously constructed by the user. + * @param[in] input: train data (n_rows samples, n_cols features) in column major format, excluding labels. Device pointer. + * @param[in] n_rows: number of training data samples. + * @param[in] n_cols: number of features (i.e., columns) excluding target feature. + * @param[in] labels: 1D array of target features (float), with one label per training sample. Device pointer. + */ +void fit(const cumlHandle& user_handle, rfRegressor* rf_regressor, + float* input, int n_rows, int n_cols, float* labels) { + rf_regressor->fit(user_handle, input, n_rows, n_cols, labels); +} + +/** + * @brief Build (i.e., fit, train) random forest regressor for input data of type double. + * @param[in] user_handle: cumlHandle + * @param[in,out] rf_regressor: pointer to the rfRegressor object, previously constructed by the user. + * @param[in] input: train data (n_rows samples, n_cols features) in column major format, excluding labels. Device pointer.
+ * @param[in] n_rows: number of training data samples. + * @param[in] n_cols: number of features (i.e., columns) excluding target feature. + * @param[in] labels: 1D array of target features (double), with one label per training sample. Device pointer. + */ +void fit(const cumlHandle& user_handle, rfRegressor* rf_regressor, + double* input, int n_rows, int n_cols, double* labels) { + rf_regressor->fit(user_handle, input, n_rows, n_cols, labels); +} + +/** + * @brief Predict target feature for input data of type float; regression for single feature supported. + * @param[in] user_handle: cumlHandle. + * @param[in] rf_regressor: pointer to the rfRegressor object. The user should have previously called fit to build the random forest. + * @param[in] input: test data (n_rows samples, n_cols features) in row major format. GPU pointer. + * @param[in] n_rows: number of data samples. + * @param[in] n_cols: number of features (excluding target feature). + * @param[in, out] predictions: n_rows predicted labels. GPU pointer, user allocated. + * @param[in] verbose: flag for debugging purposes. + */ +void predict(const cumlHandle& user_handle, + const rfRegressor* rf_regressor, const float* input, + int n_rows, int n_cols, float* predictions, bool verbose) { + rf_regressor->predict(user_handle, input, n_rows, n_cols, predictions, + verbose); +} + +/** + * @brief Predict target feature for input data of type double; regression for single feature supported. + * @param[in] user_handle: cumlHandle. + * @param[in] rf_regressor: pointer to the rfRegressor object. The user should have previously called fit to build the random forest. + * @param[in] input: test data (n_rows samples, n_cols features) in row major format. GPU pointer. + * @param[in] n_rows: number of data samples. + * @param[in] n_cols: number of features (excluding target feature). + * @param[in, out] predictions: n_rows predicted labels. GPU pointer, user allocated. + * @param[in] verbose: flag for debugging purposes. 
+ */ +void predict(const cumlHandle& user_handle, + const rfRegressor* rf_regressor, const double* input, + int n_rows, int n_cols, double* predictions, bool verbose) { + rf_regressor->predict(user_handle, input, n_rows, n_cols, predictions, + verbose); +} + +/** + * @brief Predict target feature for input data of type float and validate against ref_labels. + * @param[in] user_handle: cumlHandle. + * @param[in] rf_regressor: pointer to the rfRegressor object. The user should have previously called fit to build the random forest. + * @param[in] input: test data (n_rows samples, n_cols features) in row major format. GPU pointer. + * @param[in] ref_labels: label values for cross validation (n_rows elements); GPU pointer. + * @param[in] n_rows: number of data samples. + * @param[in] n_cols: number of features (excluding target feature). + * @param[in, out] predictions: n_rows predicted labels. GPU pointer, user allocated. + * @param[in] verbose: flag for debugging purposes. + */ +RF_metrics score(const cumlHandle& user_handle, + const rfRegressor* rf_regressor, const float* input, + const float* ref_labels, int n_rows, int n_cols, + float* predictions, bool verbose) { + return rf_regressor->score(user_handle, input, ref_labels, n_rows, n_cols, + predictions, verbose); +} + +/** + * @brief Predict target feature for input data of type double and validate against ref_labels. + * @param[in] user_handle: cumlHandle. + * @param[in] rf_regressor: pointer to the rfRegressor object. The user should have previously called fit to build the random forest. + * @param[in] input: test data (n_rows samples, n_cols features) in row major format. GPU pointer. + * @param[in] ref_labels: label values for cross validation (n_rows elements); GPU pointer. + * @param[in] n_rows: number of data samples. + * @param[in] n_cols: number of features (excluding target feature). + * @param[in, out] predictions: n_rows predicted labels. GPU pointer, user allocated. 
+ * @param[in] verbose: flag for debugging purposes. + */ +RF_metrics score(const cumlHandle& user_handle, + const rfRegressor* rf_regressor, const double* input, + const double* ref_labels, int n_rows, int n_cols, + double* predictions, bool verbose) { + return rf_regressor->score(user_handle, input, ref_labels, n_rows, n_cols, + predictions, verbose); +} + }; // namespace ML // end namespace ML diff --git a/cpp/src/randomforest/randomforest.h b/cpp/src/randomforest/randomforest.h index d450a4e1fc..27a19c5bae 100644 --- a/cpp/src/randomforest/randomforest.h +++ b/cpp/src/randomforest/randomforest.h @@ -20,40 +20,51 @@ namespace ML { +enum RF_type { + CLASSIFICATION, + REGRESSION, +}; + struct RF_metrics { - float accuracy; + RF_type rf_type; + + // Classification metrics + float accuracy = -1.0f; + + // Regression metrics + double mean_abs_error = -1.0; + double mean_squared_error = -1.0; + double median_abs_error = -1.0; RF_metrics(float cfg_accuracy); + RF_metrics(double cfg_mean_abs_error, double cfg_mean_squared_error, + double cfg_median_abs_error); void print(); }; -enum RF_type { - CLASSIFICATION, - REGRESSION, -}; - struct RF_params { /** - * Control bootstrapping. If set, each tree in the forest is built on a bootstrapped sample with replacement. - * If false, sampling without replacement is done. - */ + * Control bootstrapping. If set, each tree in the forest is built on a bootstrapped sample with replacement. + * If false, sampling without replacement is done. + */ bool bootstrap = true; /** - * Control bootstrapping for features. If features are drawn with or without replacement - */ + * Control bootstrapping for features. If features are drawn with or without replacement + */ bool bootstrap_features = false; /** - * Number of decision trees in the random forest. - */ + * Number of decision trees in the random forest. + */ int n_trees; /** - * Ratio of dataset rows used while fitting each tree. 
- */ + * Ratio of dataset rows used while fitting each tree. + */ float rows_sample = 1.0f; /** - * Decision tree traingin hyper parameter struct. - */ + * Decision tree training hyper parameter struct. + */ DecisionTree::DecisionTreeParams tree_params; + RF_params(); RF_params(int cfg_n_trees); RF_params(bool cfg_bootstrap, bool cfg_bootstrap_features, int cfg_n_trees, @@ -65,16 +76,34 @@ struct RF_params { void print() const; }; -template +/* Update labels so they are unique from 0 to n_unique_vals. + Create an old_label to new_label map per random forest. +*/ +void preprocess_labels(int n_rows, std::vector& labels, + std::map& labels_map, bool verbose = false); + +/* Revert preprocessing effect, if needed. */ +void postprocess_labels(int n_rows, std::vector& labels, + std::map& labels_map, bool verbose = false); + +template class rf { protected: RF_params rf_params; int rf_type; - DecisionTree::DecisionTreeClassifier* trees; + virtual const DecisionTree::DecisionTreeBase* get_trees_ptr() const = 0; + virtual ~rf() = default; + void prepare_fit_per_tree(const ML::cumlHandle_impl& handle, int tree_id, + int n_rows, int n_sampled_rows, + unsigned int* selected_rows, + unsigned int* sorted_selected_rows, + char* rows_temp_storage, size_t temp_storage_bytes); + + void error_checking(const T* input, L* predictions, int n_rows, int n_cols, + bool is_predict) const; public: rf(RF_params cfg_rf_params, int cfg_rf_type = RF_type::CLASSIFICATION); - ~rf(); int get_ntrees(); void print_rf_summary(); @@ -82,9 +111,14 @@ class rf { }; template -class rfClassifier : public rf { +class rfClassifier : public rf { + private: + DecisionTree::DecisionTreeClassifier* trees = nullptr; + const DecisionTree::DecisionTreeClassifier* get_trees_ptr() const; + public: rfClassifier(RF_params cfg_rf_params); + ~rfClassifier(); void fit(const cumlHandle& user_handle, T* input, int n_rows, int n_cols, int* labels, int n_unique_labels); @@ -95,17 +129,29 @@ class rfClassifier : public rf { int* 
predictions, bool verbose = false) const; }; -/* Update labels so they are unique from 0 to n_unique_vals. - Create an old_label to new_label map per random forest. -*/ -void preprocess_labels(int n_rows, std::vector& labels, - std::map& labels_map, bool verbose = false); +template +class rfRegressor : public rf { + private: + DecisionTree::DecisionTreeRegressor* trees = nullptr; + const DecisionTree::DecisionTreeRegressor* get_trees_ptr() const; -/* Revert preprocessing effect, if needed. */ -void postprocess_labels(int n_rows, std::vector& labels, - std::map& labels_map, bool verbose = false); + public: + rfRegressor(RF_params cfg_rf_params); + ~rfRegressor(); + + void fit(const cumlHandle& user_handle, T* input, int n_rows, int n_cols, + T* labels); + void predict(const cumlHandle& user_handle, const T* input, int n_rows, + int n_cols, T* predictions, bool verbose = false) const; + RF_metrics score(const cumlHandle& user_handle, const T* input, + const T* ref_labels, int n_rows, int n_cols, + T* predictions, bool verbose = false) const; +}; // Stateless API functions: fit, predict and score. 
+ +// ----------------------------- Classification ----------------------------------- // + void fit(const cumlHandle& user_handle, rfClassifier* rf_classifier, float* input, int n_rows, int n_cols, int* labels, int n_unique_labels); @@ -132,5 +178,31 @@ RF_metrics score(const cumlHandle& user_handle, RF_params set_rf_class_obj(int max_depth, int max_leaves, float max_features, int n_bins, int split_algo, int min_rows_per_node, bool bootstrap_features, bool bootstrap, int n_trees, - float rows_sample); + float rows_sample, CRITERION split_criterion, + bool quantile_per_tree); + +// ----------------------------- Regression ----------------------------------- // + +void fit(const cumlHandle& user_handle, rfRegressor* rf_regressor, + float* input, int n_rows, int n_cols, float* labels); +void fit(const cumlHandle& user_handle, rfRegressor* rf_regressor, + double* input, int n_rows, int n_cols, double* labels); + +void predict(const cumlHandle& user_handle, + const rfRegressor* rf_regressor, const float* input, + int n_rows, int n_cols, float* predictions, bool verbose = false); +void predict(const cumlHandle& user_handle, + const rfRegressor* rf_regressor, const double* input, + int n_rows, int n_cols, double* predictions, bool verbose = false); + +RF_metrics score(const cumlHandle& user_handle, + const rfRegressor* rf_regressor, + const float* input, const float* ref_labels, + int n_rows, int n_cols, float* predictions, + bool verbose = false); +RF_metrics score(const cumlHandle& user_handle, + const rfRegressor* rf_regressor, + const double* input, const double* ref_labels, + int n_rows, int n_cols, double* predictions, + bool verbose = false); }; // namespace ML diff --git a/cpp/src/spectral/spectral.h b/cpp/src/spectral/spectral.h index f3207b6d36..3de9a06940 100644 --- a/cpp/src/spectral/spectral.h +++ b/cpp/src/spectral/spectral.h @@ -174,15 +174,15 @@ void fit_clusters(const cumlHandle &handle, T *X, int m, int n, int n_neighbors, 
MLCommon::allocate(knn_indices, m * n_neighbors); MLCommon::allocate(knn_dists, m * n_neighbors); - float **ptrs = new float*[1]; + float **ptrs = new float *[1]; int *sizes = new int[1]; ptrs[0] = X; sizes[0] = m; knn.fit(ptrs, sizes, 1); knn.search(X, m, knn_indices, knn_dists, n_neighbors); - fit_clusters(handle, knn_indices, knn_dists, m, n_neighbors, - n_clusters, eigen_tol, out); + fit_clusters(handle, knn_indices, knn_dists, m, n_neighbors, n_clusters, + eigen_tol, out); CUDA_CHECK(cudaFree(knn_indices)); CUDA_CHECK(cudaFree(knn_dists)); @@ -328,7 +328,7 @@ void fit_embedding(const cumlHandle &handle, T *X, int m, int n, MLCommon::allocate(knn_indices, m * n_neighbors); MLCommon::allocate(knn_dists, m * n_neighbors); - float **ptrs = new float*[1]; + float **ptrs = new float *[1]; int *sizes = new int[1]; ptrs[0] = X; sizes[0] = m; @@ -339,11 +339,11 @@ void fit_embedding(const cumlHandle &handle, T *X, int m, int n, fit_embedding(handle, knn_indices, knn_dists, m, n_neighbors, n_components, out); - CUDA_CHECK(cudaFree(knn_indices)); - CUDA_CHECK(cudaFree(knn_dists)); + CUDA_CHECK(cudaFree(knn_indices)); + CUDA_CHECK(cudaFree(knn_dists)); - delete ptrs; - delete sizes; + delete ptrs; + delete sizes; } } // namespace Spectral } // namespace ML diff --git a/cpp/src_prims/cuda_utils.h b/cpp/src_prims/cuda_utils.h index cca6f648db..0a7a308bad 100644 --- a/cpp/src_prims/cuda_utils.h +++ b/cpp/src_prims/cuda_utils.h @@ -106,17 +106,16 @@ std::string arr2Str(const T *arr, int size, std::string name, return ss.str(); } -template +template void ASSERT_DEVICE_MEM(T *ptr, std::string name) { - cudaPointerAttributes s_att; - cudaError_t s_err = cudaPointerGetAttributes(&s_att, ptr); + cudaPointerAttributes s_att; + cudaError_t s_err = cudaPointerGetAttributes(&s_att, ptr); - if(s_err != 0 || s_att.device == -1) - std::cout << "Invalid device pointer encountered in " << name << - ". 
device=" << s_att.device << ", err=" << s_err << std::endl; + if (s_err != 0 || s_att.device == -1) + std::cout << "Invalid device pointer encountered in " << name + << ". device=" << s_att.device << ", err=" << s_err << std::endl; }; - /** number of threads per warp */ static const int WarpSize = 32; diff --git a/cpp/src_prims/label/classlabels.h b/cpp/src_prims/label/classlabels.h index f7b4e40be6..bf020bde09 100644 --- a/cpp/src_prims/label/classlabels.h +++ b/cpp/src_prims/label/classlabels.h @@ -109,23 +109,23 @@ void getOvrLabels(math_t *y, int n, math_t *y_unique, int n_classes, // TODO: add one-versus-one selection: select two classes, relabel them to // +/-1, return array with the new class labels and corresponding indices. - template - __global__ void map_label_kernel(Type *map_ids, size_t N_labels, Type *in, - Type *out, size_t N, Lambda filter_op) { - int tid = threadIdx.x + blockIdx.x * TPB_X; - if (tid < N) { - if (!filter_op(in[tid])) { - for (size_t i = 0; i < N_labels; i++) { - if (in[tid] == map_ids[i]) { - out[tid] = i + 1; - break; - } +template +__global__ void map_label_kernel(Type *map_ids, size_t N_labels, Type *in, + Type *out, size_t N, Lambda filter_op) { + int tid = threadIdx.x + blockIdx.x * TPB_X; + if (tid < N) { + if (!filter_op(in[tid])) { + for (size_t i = 0; i < N_labels; i++) { + if (in[tid] == map_ids[i]) { + out[tid] = i + 1; + break; } } } } +} - /** +/** * Maps an input array containing a series of numbers into a new array * where numbers have been mapped to a monotonically increasing set * of labels. This can be useful in machine learning algorithms, for instance, @@ -142,27 +142,27 @@ void getOvrLabels(math_t *y, int n, math_t *y_unique, int n_classes, * @param filter_op an optional function for specifying which values * should have monotonically increasing labels applied to them. 
*/ - template - void make_monotonic(Type *out, Type *in, size_t N, cudaStream_t stream, - Lambda filter_op) { - static const size_t TPB_X = 256; +template +void make_monotonic(Type *out, Type *in, size_t N, cudaStream_t stream, + Lambda filter_op) { + static const size_t TPB_X = 256; - dim3 blocks(ceildiv(N, TPB_X)); - dim3 threads(TPB_X); + dim3 blocks(ceildiv(N, TPB_X)); + dim3 threads(TPB_X); - std::shared_ptr allocator(new defaultDeviceAllocator); + std::shared_ptr allocator(new defaultDeviceAllocator); - Type *map_ids; - int num_clusters; - getUniqueLabels(in, N, &map_ids, &num_clusters, stream, allocator); + Type *map_ids; + int num_clusters; + getUniqueLabels(in, N, &map_ids, &num_clusters, stream, allocator); - map_label_kernel<<>>( - map_ids, num_clusters, in, out, N, filter_op); + map_label_kernel<<>>( + map_ids, num_clusters, in, out, N, filter_op); - allocator->deallocate(map_ids, num_clusters * sizeof(Type), stream); - } + allocator->deallocate(map_ids, num_clusters * sizeof(Type), stream); +} - /** +/** * Maps an input array containing a series of numbers into a new array * where numbers have been mapped to a monotonically increasing set * of labels. 
This can be useful in machine learning algorithms, for instance, @@ -177,10 +177,10 @@ void getOvrLabels(math_t *y, int n, math_t *y_unique, int n_classes, * @param N number of elements in the input array * @param stream cuda stream to use */ - template - void make_monotonic(Type *out, Type *in, size_t N, cudaStream_t stream) { - make_monotonic(out, in, N, stream, - [] __device__(Type val) { return false; }); - } +template +void make_monotonic(Type *out, Type *in, size_t N, cudaStream_t stream) { + make_monotonic(out, in, N, stream, + [] __device__(Type val) { return false; }); +} }; // namespace Label }; // end namespace MLCommon diff --git a/cpp/src_prims/score/scores.h b/cpp/src_prims/score/scores.h index 01a798b6ad..564d67c974 100644 --- a/cpp/src_prims/score/scores.h +++ b/cpp/src_prims/score/scores.h @@ -25,9 +25,9 @@ #include "common/cuml_allocator.hpp" -#include "selection/knn.h" -#include "distance/distance.h" #include +#include "distance/distance.h" +#include "selection/knn.h" #include #include @@ -35,181 +35,165 @@ #define MAX_BATCH_SIZE 512 #define N_THREADS 512 - namespace MLCommon { - namespace Score { - - /** - * @brief Compute a the rank of trustworthiness score - * @input param ind_X: indexes given by pairwise distance and sorting - * @input param ind_X_embedded: indexes given by KNN - * @input param n: Number of samples - * @input param n_neighbors: Number of neighbors considered by trustworthiness score - * @input param work: Batch to consider (to do it at once use n * n_neighbors) - * @output param rank: Resulting rank - */ - template - __global__ void compute_rank(math_t *ind_X, knn_index_t *ind_X_embedded, - int n, int n_neighbors, int work, double * rank) - { - int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= work) - return; - - int n_idx = i / n_neighbors; - int nn_idx = (i % n_neighbors) + 1; - - knn_index_t idx = ind_X_embedded[n_idx * (n_neighbors+1) + nn_idx]; - math_t* sample_i = &ind_X[n_idx * n]; - for (int r = 1; r < n; r++) 
- { - if (sample_i[r] == idx) - { - int tmp = r - n_neighbors; - if (tmp > 0) - atomicAdd(rank, tmp); - break; - } - } - } - - - /** - * @brief Compute a kNN and returns the indexes of the nearest neighbors - * @param input Input matrix holding the dataset - * @param n Number of samples - * @param d Number of features - * @param d_alloc the device allocator to use for temp device memory - * @param stream cuda stream to use - * @return Matrix holding the indexes of the nearest neighbors - */ - template - long* get_knn_indexes(math_t* input, int n, - int d, int n_neighbors, - std::shared_ptr d_alloc, - cudaStream_t stream) - { - long* d_pred_I = (long*)d_alloc->allocate(n * n_neighbors * sizeof(long), stream); - math_t* d_pred_D = (math_t*)d_alloc->allocate(n * n_neighbors * sizeof(math_t), stream); - - float **ptrs = new float*[1]; - ptrs[0] = input; - - int *sizes = new int[1]; - sizes[0] = n; - - MLCommon::Selection::brute_force_knn(ptrs, sizes, 1, d, - input, n, d_pred_I, d_pred_D, n_neighbors, stream); - - d_alloc->deallocate(d_pred_D, n * n_neighbors * sizeof(math_t), stream); - return d_pred_I; - } - - /** - * @brief Compute the trustworthiness score - * @tparam distance_type: Distance type to consider - * @param X: Data in original dimension - * @param X_embedde: Data in target dimension (embedding) - * @param n: Number of samples - * @param m: Number of features in high/original dimension - * @param d: Number of features in low/embedded dimension - * @param n_neighbors Number of neighbors considered by trustworthiness score - * @param d_alloc device allocator to use for temp device memory - * @param stream the cuda stream to use - * @return Trustworthiness score - */ - template - double trustworthiness_score(math_t* X, - math_t* X_embedded, int n, int m, int d, - int n_neighbors, - std::shared_ptr d_alloc, - cudaStream_t stream) - { - const int TMP_SIZE = MAX_BATCH_SIZE * n; - - size_t workspaceSize = 0; // EucUnexpandedL2Sqrt does not require workspace (may 
need change for other distances) - typedef cutlass::Shape<8, 128, 128> OutputTile_t; - bool bAllocWorkspace = false; - - math_t* d_pdist_tmp = (math_t*)d_alloc->allocate(TMP_SIZE * sizeof(math_t), stream); - int* d_ind_X_tmp = (int*)d_alloc->allocate(TMP_SIZE * sizeof(int), stream); - - long* ind_X_embedded = get_knn_indexes( - X_embedded, - n, d, n_neighbors + 1, - d_alloc, stream); - - double t_tmp = 0.0; - double t = 0.0; - double* d_t = (double*)d_alloc->allocate(sizeof(double), stream); - - int toDo = n; - while (toDo > 0) - { - int batchSize = min(toDo, MAX_BATCH_SIZE); - // Takes at most MAX_BATCH_SIZE vectors at a time - - MLCommon::Distance::distance - (&X[(n - toDo) * m], X, - d_pdist_tmp, - batchSize, n, m, - (void*)nullptr, workspaceSize, - stream - ); - CUDA_CHECK(cudaPeekAtLastError()); - - MLCommon::Selection::sortColumnsPerRow(d_pdist_tmp, d_ind_X_tmp, - batchSize, n, - bAllocWorkspace, NULL, workspaceSize, - stream); - CUDA_CHECK(cudaPeekAtLastError()); - - t_tmp = 0.0; - updateDevice(d_t, &t_tmp, 1, stream); - - int work = batchSize * n_neighbors; - int n_blocks = work / N_THREADS + 1; - compute_rank<<>>(d_ind_X_tmp, - &ind_X_embedded[(n - toDo) * (n_neighbors+1)], - n, - n_neighbors, - batchSize * n_neighbors, - d_t); - CUDA_CHECK(cudaPeekAtLastError()); - - updateHost(&t_tmp, d_t, 1, stream); - t += t_tmp; - - toDo -= batchSize; - } - - t = 1.0 - ((2.0 / ((n * n_neighbors) * ((2.0 * n) - (3.0 * n_neighbors) - 1.0))) * t); - - d_alloc->deallocate(ind_X_embedded, n * (n_neighbors + 1) * sizeof(long), stream); - d_alloc->deallocate(d_pdist_tmp, TMP_SIZE * sizeof(math_t), stream); - d_alloc->deallocate(d_ind_X_tmp, TMP_SIZE * sizeof(int), stream); - d_alloc->deallocate(d_t, sizeof(double), stream); - - return t; - } +namespace Score { + +/** + * @brief Compute the rank of trustworthiness score + * @input param ind_X: indexes given by pairwise distance and sorting + * @input param ind_X_embedded: indexes given by KNN + * @input param n: Number of
samples + * @input param n_neighbors: Number of neighbors considered by trustworthiness score + * @input param work: Batch to consider (to do it at once use n * n_neighbors) + * @output param rank: Resulting rank + */ +template +__global__ void compute_rank(math_t *ind_X, knn_index_t *ind_X_embedded, int n, + int n_neighbors, int work, double *rank) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= work) return; + + int n_idx = i / n_neighbors; + int nn_idx = (i % n_neighbors) + 1; + + knn_index_t idx = ind_X_embedded[n_idx * (n_neighbors + 1) + nn_idx]; + math_t *sample_i = &ind_X[n_idx * n]; + for (int r = 1; r < n; r++) { + if (sample_i[r] == idx) { + int tmp = r - n_neighbors; + if (tmp > 0) atomicAdd(rank, tmp); + break; + } + } +} +/** + * @brief Compute a kNN and returns the indexes of the nearest neighbors + * @param input Input matrix holding the dataset + * @param n Number of samples + * @param d Number of features + * @param d_alloc the device allocator to use for temp device memory + * @param stream cuda stream to use + * @return Matrix holding the indexes of the nearest neighbors + */ +template +long *get_knn_indexes(math_t *input, int n, int d, int n_neighbors, + std::shared_ptr d_alloc, + cudaStream_t stream) { + long *d_pred_I = + (long *)d_alloc->allocate(n * n_neighbors * sizeof(long), stream); + math_t *d_pred_D = + (math_t *)d_alloc->allocate(n * n_neighbors * sizeof(math_t), stream); + + float **ptrs = new float *[1]; + ptrs[0] = input; + + int *sizes = new int[1]; + sizes[0] = n; + + MLCommon::Selection::brute_force_knn(ptrs, sizes, 1, d, input, n, d_pred_I, + d_pred_D, n_neighbors, stream); + + d_alloc->deallocate(d_pred_D, n * n_neighbors * sizeof(math_t), stream); + return d_pred_I; +} +/** + * @brief Compute the trustworthiness score + * @tparam distance_type: Distance type to consider + * @param X: Data in original dimension + * @param X_embedde: Data in target dimension (embedding) + * @param n: Number of samples + * @param 
m: Number of features in high/original dimension + * @param d: Number of features in low/embedded dimension + * @param n_neighbors Number of neighbors considered by trustworthiness score + * @param d_alloc device allocator to use for temp device memory + * @param stream the cuda stream to use + * @return Trustworthiness score + */ +template +double trustworthiness_score(math_t *X, math_t *X_embedded, int n, int m, int d, + int n_neighbors, + std::shared_ptr d_alloc, + cudaStream_t stream) { + const int TMP_SIZE = MAX_BATCH_SIZE * n; + + size_t workspaceSize = + 0; // EucUnexpandedL2Sqrt does not require workspace (may need change for other distances) + typedef cutlass::Shape<8, 128, 128> OutputTile_t; + bool bAllocWorkspace = false; + + math_t *d_pdist_tmp = + (math_t *)d_alloc->allocate(TMP_SIZE * sizeof(math_t), stream); + int *d_ind_X_tmp = (int *)d_alloc->allocate(TMP_SIZE * sizeof(int), stream); + + long *ind_X_embedded = + get_knn_indexes(X_embedded, n, d, n_neighbors + 1, d_alloc, stream); + + double t_tmp = 0.0; + double t = 0.0; + double *d_t = (double *)d_alloc->allocate(sizeof(double), stream); + + int toDo = n; + while (toDo > 0) { + int batchSize = min(toDo, MAX_BATCH_SIZE); + // Takes at most MAX_BATCH_SIZE vectors at a time + + MLCommon::Distance::distance( + &X[(n - toDo) * m], X, d_pdist_tmp, batchSize, n, m, (void *)nullptr, + workspaceSize, stream); + CUDA_CHECK(cudaPeekAtLastError()); + + MLCommon::Selection::sortColumnsPerRow(d_pdist_tmp, d_ind_X_tmp, batchSize, + n, bAllocWorkspace, NULL, + workspaceSize, stream); + CUDA_CHECK(cudaPeekAtLastError()); + + t_tmp = 0.0; + updateDevice(d_t, &t_tmp, 1, stream); + + int work = batchSize * n_neighbors; + int n_blocks = work / N_THREADS + 1; + compute_rank<<>>( + d_ind_X_tmp, &ind_X_embedded[(n - toDo) * (n_neighbors + 1)], n, + n_neighbors, batchSize * n_neighbors, d_t); + CUDA_CHECK(cudaPeekAtLastError()); + + updateHost(&t_tmp, d_t, 1, stream); + t += t_tmp; + + toDo -= batchSize; + } + + t = + 1.0 
- + ((2.0 / ((n * n_neighbors) * ((2.0 * n) - (3.0 * n_neighbors) - 1.0))) * t); + + d_alloc->deallocate(ind_X_embedded, n * (n_neighbors + 1) * sizeof(long), + stream); + d_alloc->deallocate(d_pdist_tmp, TMP_SIZE * sizeof(math_t), stream); + d_alloc->deallocate(d_ind_X_tmp, TMP_SIZE * sizeof(int), stream); + d_alloc->deallocate(d_t, sizeof(double), stream); + + return t; +} /** - * Calculates the "Coefficient of Determination" (R-Squared) score - * normalizing the sum of squared errors by the total sum of squares. - * - * This score indicates the proportionate amount of variation in an - * expected response variable is explained by the independent variables - * in a linear regression model. The larger the R-squared value, the - * more variability is explained by the linear regression model. - * - * @param y: Array of ground-truth response variables - * @param y_hat: Array of predicted response variables - * @param n: Number of elements in y and y_hat - * @return: The R-squared value. - */ -template + * Calculates the "Coefficient of Determination" (R-Squared) score + * normalizing the sum of squared errors by the total sum of squares. + * + * This score indicates the proportionate amount of variation in an + * expected response variable is explained by the independent variables + * in a linear regression model. The larger the R-squared value, the + * more variability is explained by the linear regression model. + * + * @param y: Array of ground-truth response variables + * @param y_hat: Array of predicted response variables + * @param n: Number of elements in y and y_hat + * @return: The R-squared value. + */ +template math_t r2_score(math_t *y, math_t *y_hat, int n, cudaStream_t stream) { math_t *y_bar; MLCommon::allocate(y_bar, 1); @@ -244,5 +228,128 @@ math_t r2_score(math_t *y, math_t *y_hat, int n, cudaStream_t stream) { return 1.0 - sse / ssto; } + +/** + * @brief Compute accuracy of predictions. Useful for classification. 
+ * @tparam math_t: data type for predictions (e.g., int for classification) + * @param[in] predictions: array of predictions (GPU pointer). + * @param[in] ref_predictions: array of reference (ground-truth) predictions (GPU pointer). + * @param[in] n: number of elements in each of predictions, ref_predictions. + * @param[in] d_alloc: device allocator. + * @param[in] stream: cuda stream. + * @return: Accuracy score in [0, 1]; higher is better. + */ +template +float accuracy_score(const math_t *predictions, const math_t *ref_predictions, + int n, std::shared_ptr d_alloc, + cudaStream_t stream) { + unsigned long long correctly_predicted = 0ULL; + math_t *diffs_array = (math_t *)d_alloc->allocate(n * sizeof(math_t), stream); + + //TODO could write a kernel instead + MLCommon::LinAlg::eltwiseSub(diffs_array, predictions, ref_predictions, n, + stream); + CUDA_CHECK(cudaGetLastError()); + correctly_predicted = thrust::count(thrust::cuda::par.on(stream), diffs_array, + diffs_array + n, 0); + d_alloc->deallocate(diffs_array, n * sizeof(math_t), stream); + + float accuracy = correctly_predicted * 1.0f / n; + return accuracy; +} + +template +__global__ void reg_metrics_kernel(const T *predictions, + const T *ref_predictions, int n, + double *abs_diffs, double *tmp_sums) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + __shared__ double shmem[2]; // {abs_difference_sum, squared difference sum} + + for (int i = threadIdx.x; i < 2; i += blockDim.x) { + shmem[i] = 0; + } + __syncthreads(); + + for (int i = tid; i < n; i += blockDim.x * gridDim.x) { + double diff = predictions[i] - ref_predictions[i]; + double abs_diff = abs(diff); + atomicAdd(&shmem[0], abs_diff); + atomicAdd(&shmem[1], diff * diff); + + // update absolute difference in global memory for subsequent abs. median computation + abs_diffs[i] = abs_diff; + } + __syncthreads(); + + // Update tmp_sum w/ total abs_difference_sum and squared difference sum. 
+ for (int i = threadIdx.x; i < 2; i += blockDim.x) { + atomicAdd(&tmp_sums[i], shmem[i]); + } +} + +/** + * @brief Compute regression metrics mean absolute error, mean squared error, median absolute error + * @tparam T: data type for predictions (e.g., float or double for regression). + * @param[in] predictions: array of predictions (GPU pointer). + * @param[in] ref_predictions: array of reference (ground-truth) predictions (GPU pointer). + * @param[in] n: number of elements in each of predictions, ref_predictions. Should be > 0. + * @param[in] d_alloc: device allocator. + * @param[in] stream: cuda stream. + * @param[out] mean_abs_error: Mean Absolute Error. Sum over n of (|predictions[i] - ref_predictions[i]|) / n. + * @param[out] mean_squared_error: Mean Squared Error. Sum over n of ((predictions[i] - ref_predictions[i])^2) / n. + * @param[out] median_abs_error: Median Absolute Error. Median of |predictions[i] - ref_predictions[i]| for i in [0, n). + */ +template +void regression_metrics(const T *predictions, const T *ref_predictions, int n, + std::shared_ptr d_alloc, + cudaStream_t stream, double &mean_abs_error, + double &mean_squared_error, double &median_abs_error) { + std::vector mean_errors(2); + std::vector h_sorted_abs_diffs(n); + int thread_cnt = 256; + int block_cnt = ceildiv(n, thread_cnt); + + int array_size = n * sizeof(double); + double *abs_diffs_array = (double *)d_alloc->allocate(array_size, stream); + double *sorted_abs_diffs = (double *)d_alloc->allocate(array_size, stream); + double *tmp_sums = (double *)d_alloc->allocate(2 * sizeof(double), stream); + CUDA_CHECK(cudaMemsetAsync(tmp_sums, 0, 2 * sizeof(double), stream)); + + reg_metrics_kernel<<>>( + predictions, ref_predictions, n, abs_diffs_array, tmp_sums); + CUDA_CHECK(cudaGetLastError()); + MLCommon::updateHost(&mean_errors[0], tmp_sums, 2, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + mean_abs_error = mean_errors[0] / n; + mean_squared_error = mean_errors[1] / n; + + // 
Compute median error. Sort diffs_array and pick median value + char *temp_storage = nullptr; + size_t temp_storage_bytes; + CUDA_CHECK(cub::DeviceRadixSort::SortKeys( + (void *)temp_storage, temp_storage_bytes, abs_diffs_array, sorted_abs_diffs, + n, 0, 8 * sizeof(double), stream)); + temp_storage = (char *)d_alloc->allocate(temp_storage_bytes, stream); + CUDA_CHECK(cub::DeviceRadixSort::SortKeys( + (void *)temp_storage, temp_storage_bytes, abs_diffs_array, sorted_abs_diffs, + n, 0, 8 * sizeof(double), stream)); + + MLCommon::updateHost(h_sorted_abs_diffs.data(), sorted_abs_diffs, n, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + int middle = n / 2; + if (n % 2 == 1) { + median_abs_error = h_sorted_abs_diffs[middle]; + } else { + median_abs_error = + (h_sorted_abs_diffs[middle] + h_sorted_abs_diffs[middle - 1]) / 2; + } + + d_alloc->deallocate(abs_diffs_array, array_size, stream); + d_alloc->deallocate(sorted_abs_diffs, array_size, stream); + d_alloc->deallocate(temp_storage, temp_storage_bytes, stream); + d_alloc->deallocate(tmp_sums, 2 * sizeof(double), stream); +} } // namespace Score } // namespace MLCommon diff --git a/cpp/src_prims/stats/minmax.h b/cpp/src_prims/stats/minmax.h index 9257873d07..bece2d7f55 100644 --- a/cpp/src_prims/stats/minmax.h +++ b/cpp/src_prims/stats/minmax.h @@ -87,43 +87,62 @@ __global__ void minmaxInitKernel(int ncols, T* globalmin, T* globalmax, } template -__global__ void minmaxKernel(const T* data, const int* rowids, - const int* colids, int nrows, int ncols, +__global__ void minmaxKernel(const T* data, const unsigned int* rowids, + const unsigned int* colids, int nrows, int ncols, int row_stride, T* g_min, T* g_max, T* sampledcols, - T init_min_val) { + T init_min_val, int batch_ncols, int num_batches) { int tid = threadIdx.x + blockIdx.x * blockDim.x; extern __shared__ char shmem[]; T* s_min = (T*)shmem; - T* s_max = (T*)(shmem + sizeof(T) * ncols); - for (int i = threadIdx.x; i < ncols; i += blockDim.x) { - 
*(E*)&s_min[i] = encode(init_min_val); - *(E*)&s_max[i] = encode(-init_min_val); + T* s_max = (T*)(shmem + sizeof(T) * batch_ncols); + + int last_batch_ncols = ncols % batch_ncols; + if (last_batch_ncols == 0) { + last_batch_ncols = batch_ncols; } - __syncthreads(); - for (int i = tid; i < nrows * ncols; i += blockDim.x * gridDim.x) { - int col = i / nrows; - int row = i % nrows; - if (colids != nullptr) { - col = colids[col]; + int orig_batch_ncols = batch_ncols; + + for (int batch_id = 0; batch_id < num_batches; batch_id++) { + if (batch_id == num_batches - 1) { + batch_ncols = last_batch_ncols; } - if (rowids != nullptr) { - row = rowids[row]; + + for (int i = threadIdx.x; i < batch_ncols; i += blockDim.x) { + *(E*)&s_min[i] = encode(init_min_val); + *(E*)&s_max[i] = encode(-init_min_val); } - int index = row + col * row_stride; - T coldata = data[index]; - if (!isnan(coldata)) { - atomicMinBits(&s_min[col], coldata); - atomicMaxBits(&s_max[col], coldata); + __syncthreads(); + + for (int i = tid; i < nrows * batch_ncols; i += blockDim.x * gridDim.x) { + int col = (batch_id * orig_batch_ncols) + (i / nrows); + int row = i % nrows; + if (colids != nullptr) { + col = colids[col]; + } + if (rowids != nullptr) { + row = rowids[row]; + } + int index = row + col * row_stride; + T coldata = data[index]; + if (!isnan(coldata)) { + //Min max values are saved in shared memory and global memory as per the shuffled colids. 
+ atomicMinBits(&s_min[(int)(i / nrows)], coldata); + atomicMaxBits(&s_max[(int)(i / nrows)], coldata); + } + if (sampledcols != nullptr) { + sampledcols[batch_id * orig_batch_ncols + i] = coldata; + } } - if (sampledcols != nullptr) { - sampledcols[i] = coldata; + __syncthreads(); + + // finally, perform global mem atomics + for (int j = threadIdx.x; j < batch_ncols; j += blockDim.x) { + atomicMinBits(&g_min[batch_id * orig_batch_ncols + j], + decode(*(E*)&s_min[j])); + atomicMaxBits(&g_max[batch_id * orig_batch_ncols + j], + decode(*(E*)&s_max[j])); } - } - __syncthreads(); - // finally, perform global mem atomics - for (int j = threadIdx.x; j < ncols; j += blockDim.x) { - atomicMinBits(&g_min[j], decode(*(E*)&s_min[j])); - atomicMaxBits(&g_max[j], decode(*(E*)&s_max[j])); + __syncthreads(); } } @@ -146,7 +165,7 @@ __global__ void minmaxKernel(const T* data, const int* rowids, * @param globalmin final col-wise global minimum (size = ncols) * @param globalmax final col-wise global maximum (size = ncols) * @param sampledcols output sampled data. Pass nullptr if you don't need this - * @param init_val initial minimum value to be + * @param init_val initial minimum value to be * @param stream: cuda stream * @note This method makes the following assumptions: * 1. 
input and output matrices are assumed to be col-major @@ -154,9 +173,9 @@ __global__ void minmaxKernel(const T* data, const int* rowids, * in shared memory */ template -void minmax(const T* data, const int* rowids, const int* colids, int nrows, - int ncols, int row_stride, T* globalmin, T* globalmax, - T* sampledcols, cudaStream_t stream) { +void minmax(const T* data, const unsigned int* rowids, + const unsigned int* colids, int nrows, int ncols, int row_stride, + T* globalmin, T* globalmax, T* sampledcols, cudaStream_t stream) { using E = typename encode_traits::E; int nblks = ceildiv(ncols, TPB); T init_val = std::numeric_limits::max(); @@ -164,11 +183,24 @@ void minmax(const T* data, const int* rowids, const int* colids, int nrows, <<>>(ncols, globalmin, globalmax, init_val); CUDA_CHECK(cudaPeekAtLastError()); nblks = ceildiv(nrows * ncols, TPB); - nblks = max(nblks, 65536); + nblks = min(nblks, 65536); size_t smemSize = sizeof(T) * 2 * ncols; + + // Get available shared memory size. + cudaDeviceProp prop; + int dev_ID = 0; + CUDA_CHECK(cudaGetDevice(&dev_ID)); + CUDA_CHECK(cudaGetDeviceProperties(&prop, dev_ID)); + size_t max_shared_mem = prop.sharedMemPerBlock; + + // Compute the batch_ncols, in [1, ncols] range, that meet the available shared memory constraints. 
+ int batch_ncols = min(ncols, (int)(max_shared_mem / (sizeof(T) * 2))); + int num_batches = ceildiv(ncols, batch_ncols); + smemSize = sizeof(T) * 2 * batch_ncols; + minmaxKernel<<>>( data, rowids, colids, nrows, ncols, row_stride, globalmin, globalmax, - sampledcols, init_val); + sampledcols, init_val, batch_ncols, num_batches); CUDA_CHECK(cudaPeekAtLastError()); decodeKernel<<>>(globalmin, globalmax, ncols); CUDA_CHECK(cudaPeekAtLastError()); diff --git a/cpp/test/prims/cuda_utils.cu b/cpp/test/prims/cuda_utils.cu index 43fbba3e57..a84bddb07a 100644 --- a/cpp/test/prims/cuda_utils.cu +++ b/cpp/test/prims/cuda_utils.cu @@ -19,8 +19,6 @@ namespace MLCommon { - - TEST(Utils, Assert) { ASSERT_NO_THROW(ASSERT(1 == 1, "Should not assert!")); ASSERT_THROW(ASSERT(1 != 1, "Should assert!"), Exception); diff --git a/cpp/test/prims/make_blobs.cu b/cpp/test/prims/make_blobs.cu index f25ce58d4f..4fcd367dfd 100644 --- a/cpp/test/prims/make_blobs.cu +++ b/cpp/test/prims/make_blobs.cu @@ -49,8 +49,7 @@ struct MakeBlobsInputs { }; template -class MakeBlobsTest - : public ::testing::TestWithParam> { +class MakeBlobsTest : public ::testing::TestWithParam> { protected: void SetUp() override { // Tests are configured with their expected test-values sigma. 
For example, @@ -95,7 +94,7 @@ class MakeBlobsTest protected: MakeBlobsInputs params; - int *labels; + int* labels; T *data, *stats, *mu_vec; T h_stats[2]; // mean, var std::shared_ptr allocator; diff --git a/cpp/test/prims/minmax.cu b/cpp/test/prims/minmax.cu index 9666581e0e..6d311da51a 100644 --- a/cpp/test/prims/minmax.cu +++ b/cpp/test/prims/minmax.cu @@ -134,7 +134,8 @@ const std::vector> inputsf = { {0.00001f, 4096, 512, 1234ULL}, {0.00001f, 4096, 1024, 1234ULL}, {0.00001f, 8192, 32, 1234ULL}, {0.00001f, 8192, 64, 1234ULL}, {0.00001f, 8192, 128, 1234ULL}, {0.00001f, 8192, 256, 1234ULL}, - {0.00001f, 8192, 512, 1234ULL}, {0.00001f, 8192, 1024, 1234ULL}}; + {0.00001f, 8192, 512, 1234ULL}, {0.00001f, 8192, 1024, 1234ULL}, + {0.00001f, 1024, 8192, 1234ULL}}; const std::vector> inputsd = { {0.0000001, 1024, 32, 1234ULL}, {0.0000001, 1024, 64, 1234ULL}, @@ -145,7 +146,8 @@ const std::vector> inputsd = { {0.0000001, 4096, 512, 1234ULL}, {0.0000001, 4096, 1024, 1234ULL}, {0.0000001, 8192, 32, 1234ULL}, {0.0000001, 8192, 64, 1234ULL}, {0.0000001, 8192, 128, 1234ULL}, {0.0000001, 8192, 256, 1234ULL}, - {0.0000001, 8192, 512, 1234ULL}, {0.0000001, 8192, 1024, 1234ULL}}; + {0.0000001, 8192, 512, 1234ULL}, {0.0000001, 8192, 1024, 1234ULL}, + {0.0000001, 1024, 8192, 1234ULL}}; typedef MinMaxTest MinMaxTestF; TEST_P(MinMaxTestF, Result) { diff --git a/cpp/test/prims/score.cu b/cpp/test/prims/score.cu index ca9b582dfe..867e5827d2 100644 --- a/cpp/test/prims/score.cu +++ b/cpp/test/prims/score.cu @@ -16,6 +16,7 @@ #include #include +#include #include "random/rng.h" #include "score/scores.h" #include "test_utils.h" @@ -74,5 +75,429 @@ TEST(ScoreTestLowScore, Result) { CUDA_CHECK(cudaStreamDestroy(stream)); } -} // namespace Score -} // namespace MLCommon +// Tests for accuracy_score + +struct AccuracyInputs { + /** + * Number of predictions. + */ + int n; + /** + * Number of predictions w/ different values than their corresponding element in reference predictions. 
+ * Valid range [0, n]. changed_n in [0, n] will yield accuracy of (n - changed_n) / n. + */ + int changed_n; + /** + * Seed for randomly generated predictions. + */ + unsigned long long int seed; +}; + +std::ostream &operator<<(::std::ostream &os, const AccuracyInputs &acc_inputs) { + os << "AccuracyInputs are {" << acc_inputs.n << ", " << acc_inputs.changed_n + << ", " << acc_inputs.seed << "}" << std::endl; + return os; +} + +template +__global__ void change_vals(T *predictions, T *ref_predictions, + const int changed_n) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid < changed_n) { + predictions[tid] = + ref_predictions[tid] + 1; // change first changed_n predictions + } +} + +template +class AccuracyTest : public ::testing::TestWithParam { + protected: + void SetUp() override { + params = ::testing::TestWithParam::GetParam(); + ASSERT((params.changed_n <= params.n) && (params.changed_n >= 0), + "Invalid params."); + + Random::Rng r(params.seed); + CUDA_CHECK(cudaStreamCreate(&stream)); + std::shared_ptr d_allocator(new defaultDeviceAllocator); + + allocate(predictions, params.n); + allocate(ref_predictions, params.n); + r.normal(ref_predictions, params.n, (T)0.0, (T)1.0, stream); + copyAsync(predictions, ref_predictions, params.n, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + //Modify params.changed_n unique predictions to a different value. New value is irrelevant. + if (params.changed_n > 0) { + int threads = 64; + int blocks = ceildiv(params.changed_n, threads); + //@todo Could also generate params.changed_n unique random positions in [0, n) range, instead of changing the first ones. 
+ change_vals<<>>( + predictions, ref_predictions, params.changed_n); + CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaStreamSynchronize(stream)); + } + + computed_accuracy = MLCommon::Score::accuracy_score( + predictions, ref_predictions, params.n, d_allocator, stream); + ref_accuracy = (params.n - params.changed_n) * 1.0f / params.n; + //std::cout << "computed_accuracy is " << computed_accuracy << " ref_accuracy is " << ref_accuracy << std::endl; + } + + void TearDown() override { + CUDA_CHECK(cudaFree(predictions)); + CUDA_CHECK(cudaFree(ref_predictions)); + CUDA_CHECK(cudaStreamDestroy(stream)); + computed_accuracy = -1.0f; + ref_accuracy = -1.0f; + } + + AccuracyInputs params; + T *predictions, *ref_predictions; + float computed_accuracy, ref_accuracy; + cudaStream_t stream; +}; + +const std::vector inputs = { + {1, 1, 1234ULL}, // single element, wrong prediction + {1, 0, 1234ULL}, // single element, perfect prediction + {2, 1, 1234ULL}, // multiple elements, 0.5 accuracy + {1000, 0, 1234ULL}, // multiple elements, perfect predictions + {1000, 1000, 1234ULL}, // multiple elements, no correct predictions + {1000, 80, 1234ULL}, // multiple elements, prediction mix + {1000, 45, 1234ULL} // multiple elements, prediction mix +}; + +typedef AccuracyTest AccuracyTestF; +TEST_P(AccuracyTestF, Result) { + ASSERT_TRUE(computed_accuracy == ref_accuracy); +} + +typedef AccuracyTest AccuracyTestD; +TEST_P(AccuracyTestD, Result) { + ASSERT_TRUE(computed_accuracy == ref_accuracy); +} + +INSTANTIATE_TEST_CASE_P(AccuracyTests, AccuracyTestF, + ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(AccuracyTests, AccuracyTestD, + ::testing::ValuesIn(inputs)); + +// Tests for regression_metrics + +template +struct RegressionInputs { + T tolerance; + int n; // number of predictions + bool + hardcoded_preds; // (hardcoded_preds) ? use predictions, ref_predictions : use randomly generated arrays. 
+ std::vector predictions; + std::vector ref_predictions; + T predictions_range + [2]; // predictions in predictions_range if not hardcoded_preds + T ref_predictions_range + [2]; // predictions in ref_predictions_range if not hardcoded_preds + unsigned long long int seed; +}; + +template +std::ostream &operator<<(std::ostream &os, + const RegressionInputs ®_inputs) { + os << "RegressionInputs are {" << reg_inputs.tolerance << ", " << reg_inputs.n + << ", " << reg_inputs.hardcoded_preds << ", "; + if (reg_inputs.hardcoded_preds) { + os << "{"; + for (int i = 0; i < reg_inputs.n; i++) + os << reg_inputs.predictions[i] << ", "; + os << "}, {"; + for (int i = 0; i < reg_inputs.n; i++) + os << reg_inputs.ref_predictions[i] << ", "; + os << "}"; + os << "{" << reg_inputs.predictions_range[0] << ", " + << reg_inputs.predictions_range[1] << "}, "; + os << "{" << reg_inputs.ref_predictions_range[0] << ", " + << reg_inputs.ref_predictions_range[1] << "}"; + } else { + os << "{}, {}, {}, {}"; + } + os << ", " << reg_inputs.seed; + return os; +} + +template +void host_regression_computations(std::vector &predictions, + std::vector &ref_predictions, const int n, + std::vector ®ression_metrics) { + double abs_difference_sum = 0; + double mse_sum = 0; + std::vector abs_diffs(n); + + for (int i = 0; i < n; i++) { + double abs_diff = abs(predictions[i] - ref_predictions[i]); + abs_difference_sum += abs_diff; + mse_sum += pow(predictions[i] - ref_predictions[i], 2); + abs_diffs[i] = abs_diff; + } + + regression_metrics[0] = abs_difference_sum / n; + regression_metrics[1] = mse_sum / n; + + std::sort(abs_diffs.begin(), abs_diffs.end()); + int middle = n / 2; + if (n % 2 == 1) { + regression_metrics[2] = abs_diffs[middle]; + } else { + regression_metrics[2] = (abs_diffs[middle] + abs_diffs[middle - 1]) / 2; + } +} + +template +class RegressionMetricsTest + : public ::testing::TestWithParam> { + protected: + void SetUp() override { + params = ::testing::TestWithParam>::GetParam(); + 
computed_regression_metrics.assign(3, -1.0); + ref_regression_metrics.assign(3, -1.0); + + CUDA_CHECK(cudaStreamCreate(&stream)); + std::shared_ptr d_allocator(new defaultDeviceAllocator); + + allocate(d_predictions, params.n); + allocate(d_ref_predictions, params.n); + + if (params.hardcoded_preds) { + updateDevice(d_predictions, params.predictions.data(), params.n, stream); + updateDevice(d_ref_predictions, params.ref_predictions.data(), params.n, + stream); + } else { + params.predictions.resize(params.n); + params.ref_predictions.resize(params.n); + Random::Rng r(params.seed); + // randomly generate arrays + r.uniform(d_predictions, params.n, params.predictions_range[0], + params.predictions_range[1], stream); + r.uniform(d_ref_predictions, params.n, params.ref_predictions_range[0], + params.ref_predictions_range[1], stream); + // copy to host to compute reference regression metrics + updateHost(params.predictions.data(), d_predictions, params.n, stream); + updateHost(params.ref_predictions.data(), d_ref_predictions, params.n, + stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + } + + MLCommon::Score::regression_metrics( + d_predictions, d_ref_predictions, params.n, d_allocator, stream, + computed_regression_metrics[0], computed_regression_metrics[1], + computed_regression_metrics[2]); + + host_regression_computations(params.predictions, params.ref_predictions, + params.n, ref_regression_metrics); + CUDA_CHECK(cudaStreamSynchronize(stream)); + } + + void TearDown() override { + CUDA_CHECK(cudaStreamDestroy(stream)); + CUDA_CHECK(cudaFree(d_predictions)); + CUDA_CHECK(cudaFree(d_ref_predictions)); + } + + RegressionInputs params; + T *d_predictions, *d_ref_predictions; + std::vector computed_regression_metrics; + std::vector ref_regression_metrics; + cudaStream_t stream; +}; + +const std::vector> regression_inputs_float = { + {0.00001f, 1, true, {10.2f}, {20.2f}, {}, {}, 1234ULL}, // single element + {0.00001f, + 2, + true, + {10.2f, 40.2f}, + {20.2f, 
80.2f}, + {}, + {}, + 1234ULL}, // two elements, mean same as median + // next three inputs should result in identical regression metrics values + {0.00001f, + 6, + true, + {10.5f, 20.5f, 30.5f, 40.5f, 50.5f, 60.5f}, + {20.5f, 40.5f, 55.5f, 80.5f, 100.5f, 120.5f}, + {}, + {}, + 1234ULL}, // diffs all negative, reverse sorted + {0.00001f, + 6, + true, + {20.5f, 40.5f, 55.5f, 80.5f, 100.5f, 120.5f}, + {10.5f, 20.5f, 30.5f, 40.5f, 50.5f, 60.5f}, + {}, + {}, + 1234ULL}, // diffs all positive, already sorted + {0.00001f, + 6, + true, + {40.5f, 55.5f, 20.5f, 120.5f, 100.5f, 80.5f}, + {20.5f, 30.5f, 10.5f, 60.5f, 50.5f, 40.5f}, + {}, + {}, + 1234ULL}, // mix + {0.00001f, + 6, + true, + {10.5f, 20.5f, 30.5f, 40.5f, 50.5f, 60.5f}, + {10.5f, 20.5f, 30.5f, 40.5f, 50.5f, 60.5f}, + {}, + {}, + 1234ULL}, // identical predictions (0 error) + {0.00001f, + 6, + true, + {10.5f, 20.5f, 30.5f, 40.5f, 50.5f, 60.5f}, + {20.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f}, + {}, + {}, + 1234ULL}, // predictions[i] - ref_predictions[i] const for each i + {0.00001f, + 2048, + false, + {}, + {}, + {-2048.0f, 2048.0f}, + {-2048.0f, 2048.0f}, + 1234ULL}, // random mix, even number of elements + {0.00001f, + 2049, + false, + {}, + {}, + {-2048.0f, 2048.0f}, + {-2048.0f, 2048.0f}, + 1234ULL}, // random mix, odd number of elements + {0.00001f, + 1024, + false, + {}, + {}, + {0.0f, 2048.0f}, + {8192.0f, 16384.0f}, + 1234ULL}, // random mix, diffs are all negative + {0.00001f, + 1024, + false, + {}, + {}, + {8192.0f, 16384.0f}, + {0.0f, 2048.0f}, + 1234ULL} // random mix, diffs are all positive +}; + +const std::vector> regression_inputs_double = { + {0.0000001, 1, true, {10.2}, {20.2}, {}, {}, 1234ULL}, // single element + {0.0000001, + 2, + true, + {10.2, 40.2}, + {20.2, 80.2}, + {}, + {}, + 1234ULL}, // two elements + {0.0000001, + 6, + true, + {10.5, 20.5, 30.5, 40.5, 50.5, 60.5}, + {20.5, 40.5, 55.5, 80.5, 100.5, 120.5}, + {}, + {}, + 1234ULL}, // diffs all negative, reverse sorted + {0.0000001, + 6, + 
true, + {20.5, 40.5, 55.5, 80.5, 100.5, 120.5}, + {10.5, 20.5, 30.5, 40.5, 50.5, 60.5}, + {}, + {}, + 1234ULL}, // diffs all positive, already sorted + {0.0000001, + 6, + true, + {40.5, 55.5, 20.5, 120.5, 100.5, 80.5}, + {20.5, 30.5, 10.5, 60.5, 50.5, 40.5}, + {}, + {}, + 1234ULL}, // mix + {0.0000001, + 6, + true, + {10.5, 20.5, 30.5, 40.5, 50.5, 60.5}, + {10.5, 20.5, 30.5, 40.5, 50.5, 60.5}, + {}, + {}, + 1234ULL}, // identical predictions (0 error) + {0.0000001, + 6, + true, + {10.5, 20.5, 30.5, 40.5, 50.5, 60.5}, + {20.5, 30.5, 40.5, 50.5, 60.5, 70.5}, + {}, + {}, + 1234ULL}, // predictions[i] - ref_predictions[i] const for each i + {0.0000001, + 2048, + false, + {}, + {}, + {-2048.0, 2048.0}, + {-2048.0, 2048.0}, + 1234ULL}, // random mix, even number of elements + {0.0000001, + 2049, + false, + {}, + {}, + {-2048.0, 2048.0}, + {-2048.0, 2048.0}, + 1234ULL}, // random mix, odd number of elements + {0.0000001, + 1024, + false, + {}, + {}, + {0, 2048}, + {8192.0, 16384.0}, + 1234ULL}, // random mix, diffs are all negative + {0.0000001, + 1024, + false, + {}, + {}, + {8192.0, 16384.0}, + {0.0, 2048}, + 1234ULL} // random mix, diffs are all positive +}; + +typedef RegressionMetricsTest RegressionMetricsTestF; +TEST_P(RegressionMetricsTestF, Result) { + for (int i = 0; i < 3; i++) { + ASSERT_TRUE(match(computed_regression_metrics[i], ref_regression_metrics[i], + CompareApprox(params.tolerance))); + } +} + +typedef RegressionMetricsTest RegressionMetricsTestD; +TEST_P(RegressionMetricsTestD, Result) { + for (int i = 0; i < 3; i++) { + ASSERT_TRUE(match(computed_regression_metrics[i], ref_regression_metrics[i], + CompareApprox(params.tolerance))); + } +} + +INSTANTIATE_TEST_CASE_P(RegressionMetricsTests, RegressionMetricsTestF, + ::testing::ValuesIn(regression_inputs_float)); +INSTANTIATE_TEST_CASE_P(RegressionMetricsTests, RegressionMetricsTestD, + ::testing::ValuesIn(regression_inputs_double)); + +} // end namespace Score +} // end namespace MLCommon diff --git 
a/cpp/test/prims/test_utils.h b/cpp/test/prims/test_utils.h index f944cd48c3..a9b7b36c46 100644 --- a/cpp/test/prims/test_utils.h +++ b/cpp/test/prims/test_utils.h @@ -168,9 +168,10 @@ ::testing::AssertionResult devArrMatch(T expected, const T *actual, size_t rows, * @return the testing assertion to be later used by ASSERT_TRUE/EXPECT_TRUE */ template -::testing::AssertionResult devArrMatchHost(const T *expected_h, const T *actual_d, - size_t size, L eq_compare, - cudaStream_t stream = 0) { +::testing::AssertionResult devArrMatchHost(const T *expected_h, + const T *actual_d, size_t size, + L eq_compare, + cudaStream_t stream = 0) { std::shared_ptr act_h(new T[size]); updateHost(act_h.get(), actual_d, size, stream); CUDA_CHECK(cudaStreamSynchronize(stream)); @@ -180,8 +181,8 @@ ::testing::AssertionResult devArrMatchHost(const T *expected_h, const T *actual_ auto exp = expected_h[i]; auto act = act_h.get()[i]; if (!eq_compare(exp, act)) { - ok = false; - fail<<"actual=" << act << " != expected=" << exp << " @" << i <<"; "; + ok = false; + fail << "actual=" << act << " != expected=" << exp << " @" << i << "; "; } } if (!ok) return fail; diff --git a/cpp/test/prims/trustworthiness.cu b/cpp/test/prims/trustworthiness.cu index c70cc3ae02..3dc8948b8d 100644 --- a/cpp/test/prims/trustworthiness.cu +++ b/cpp/test/prims/trustworthiness.cu @@ -14,430 +14,439 @@ * limitations under the License. 
*/ +#include #include +#include +#include #include #include "distance/distance.h" -#include #include "test_utils.h" -#include -#include namespace MLCommon { namespace Score { - class TrustworthinessScoreTest: public ::testing::Test { - protected: - void basicTest() { - std::vector X = { - 5.6142087,8.59787,-4.382763,-3.6452143,-5.8816037, - -0.6330313,4.6920023,-0.79210913,0.6106314,2.1210914, - 5.919943,-8.43784,-6.4819884,0.41001374,-6.1052523, - -4.0825715,-5.314755,-2.834671,5.751696,-6.5012555, - -0.4719201,-7.53353,7.6789393,-1.4959852,-5.5977287, - -9.564147,1.2902534,3.559834,-6.7659483,8.265964, - 4.595404,9.133477,-6.1553917,-6.319754,-2.9039452, - 4.4150834,-3.094395,-4.426273,9.584571,-5.64133, - 6.6209483,7.4044604,3.9620576,5.639907,10.33007, - -0.8792053,5.143776,-7.464049,1.2448754,-5.6300974, - 5.4518576,4.119535,6.749645,7.627064,-7.2298336, - 1.9681473,-6.9083176,6.404673,0.07186685,9.0994835, - 8.51037,-8.986389,0.40534487,2.115397,4.086756, - 1.2284287,-2.6272132,0.06527536,-9.587425,-7.206078, - 7.864875,7.4397306,-6.9233336,-2.6643622,3.3466153, - 7.0408177,-3.6069896,-9.971769,4.4075623,7.9063697, - 2.559074,4.323717,1.6867131,-1.1576937,-9.893141, - -3.251416,-7.4889135,-4.0588717,-2.73338,-7.4852257, - 3.4460473,9.759119,-5.4680476,-4.722435,-8.032619, - -1.4598992,4.227361,3.135568,1.1950601,1.1982028, - 6.998856,-6.131138,-6.6921015,0.5361224,-7.1213965, - -5.6104236,-7.2212887,-2.2710054,8.544764,-6.0254574, - 1.4582269,-5.5587835,8.031556,-0.26328218,-5.2591386, - -9.262641,2.8691363,5.299787,-9.209455,8.523085, - 5.180329,10.655528,-5.7171874,-6.7739563,-3.6306462, - 4.067106,-1.5912259,-3.2345476,8.042973,-3.6364832, - 4.1242137,9.886953,5.4743724,6.3058076,9.369645, - -0.5175337,4.9859877,-7.879498,1.358422,-4.147944, - 3.8984218,5.894656,6.4903927,8.702036,-8.023722, - 2.802145,-7.748032,5.8461113,-0.34215945,11.298865, - 1.4107164,-9.949621,-1.6257563,-10.655836,2.4528909, - 1.1570255,5.170669,2.8398793,7.1838694,9.088459, - 
2.631155,3.964414,2.8769252,0.04198391,-0.16993195, - 3.6747139,-2.8377378,6.1782537,10.759618,-4.5642614, - -8.522967,0.8614642,6.623416,-1.029324,5.5488334, - -7.804511,2.128833,7.9042315,7.789576,-2.7944536, - 0.72271067,-10.511495,-0.78634536,-10.661714,2.9376361, - 1.9148129,6.22859,0.26264945,8.028384,6.8743043, - 0.9351067,7.0690722,4.2846055,1.4134506,-0.18144785, - 5.2778087,-1.7140163,9.217541,8.602799,-2.6537218, - -7.8377395,1.1244944,5.4540544,-0.38506773,3.9885726, - -10.76455,1.4440702,9.136163,6.664117,-5.7046547, - 8.038592,-9.229767,-0.2799413,3.6064725,4.187257, - 1.0516582,-2.0707326,-0.7615968,-8.561018,-3.7831352, - 10.300297,5.332594,-6.5880876,-4.2508664,1.7985519, - 5.7226253,-4.1223383,-9.6697855,1.4885283,7.524974, - 1.7206005,4.890457,3.7264557,0.4428284,-9.922455, - -4.250455,-6.4410596,-2.107994,-1.4109765,-6.1325397, - 0.32883006,6.0489736,7.7257385,-8.281174,1.0129383, - -10.792166,8.378851,10.802716,9.848448,-9.188757, - 1.3151443,1.9971865,-2.521849,4.3268294,-7.775683, - -2.2902298,3.0824065,-7.17559,9.6100855,7.3965735, - -10.476525,5.895973,-3.6974669,-7.6688933,1.7354839, - -7.4045196,-1.7992063,-4.0394845,5.2471714,-2.250571, - 2.528036,-8.343515,-2.2374575,-10.019771,0.73371273, - 3.1853926,2.7994921,2.6637669,7.620401,7.515571, - 0.68636256,5.834537,4.650282,-1.0362619,0.4461701, - 3.7870514,-4.1340904,7.202998,9.736904,-3.005512, - -8.920467,1.1228397,6.2598724,1.2812365,4.5442104, - -8.791537,0.92113096,8.464749,8.359035,-4.3923397, - 1.2252625,-10.1986475,-1.4409319,-10.013967,3.9071581, - 1.683064,4.877419,1.6570637,9.559105,7.3546534, - 0.36635467,5.220211,4.6303267,0.6601065,0.16149978, - 3.8818731,-3.4438233,8.42085,8.659159,-3.0935583, - -8.039611,2.3060374,5.134666,1.0458113,6.0190983, - -9.143728,0.99048865,9.210842,6.670241,-5.9614363, - 0.8747396,7.078824,8.067469,-10.314754,0.45977542, - -9.28306,9.1838665,9.318644,7.189082,-11.092555, - 1.0320464,3.882163,0.10953151,7.9029684,-6.9068265, - 
-1.3526366,5.3996363,-8.430931,11.452577,6.39663, - -11.090514,4.6662245,-3.1268113,-8.357452,2.2276728, - -10.357126,-0.9291848,-3.4193344,3.1289792,-2.5030103, - 6.772719,11.457757,-4.2125936,-6.684548,-4.7611327, - 3.6960156,-2.3030636,-3.0591488,10.452471,-4.1267314, - 5.66614,7.501461,5.072407,6.636537,8.990381, - -0.2559256,4.737867,-6.2149944,2.535682,-5.5484023, - 5.7113924,3.4742818,7.9915137,7.0052586,-7.156467, - 1.4354781,-8.286235,5.7523417,-2.4175215,9.678009, - 0.05066403,-9.645226,-2.2658763,-9.518178,4.493372, - 2.3232365,2.1659086,0.42507997,8.360246,8.23535, - 2.6878164,5.236947,3.4924245,-0.6089895,0.8884741, - 4.359464,-4.6073823,7.83441,8.958755,-3.4690795, - -9.182282,1.2478025,5.6311107,-1.2408862,3.6316886, - -8.684654,2.1078515,7.2813864,7.9265943,-3.6135032, - 0.4571511,8.493568,10.496853,-7.432897,0.8625995, - -9.607528,7.2899456,8.83158,8.908199,-10.300263, - 1.1451302,3.7871468,-0.97040755,5.7664757,-8.9688, - -2.146672,5.9641485,-6.2908535,10.126465,6.1553903, - -12.066902,6.301596,-5.0419583,-8.228695,2.4879954, - -8.918582,-3.7434099,-4.1593685,3.7431836,-1.1704745, - 0.5524103,9.109399,9.571567,-11.209955,1.2462777, - -9.554555,9.091726,11.477966,7.630937,-10.450911, - 1.9205878,5.358983,-0.44546837,6.7611346,-9.74753, - -0.5939732,3.8892255,-6.437991,10.294727,5.6723895, - -10.7883,6.192348,-5.293862,-10.811491,1.0194173, - -7.074576,-3.192368,-2.5231771,4.2791643,-0.53309685, - 0.501366,9.636625,7.710316,-6.4219728,1.0975566, - -8.218886,6.9011984,9.873679,8.903804,-9.316832, - 1.2404599,4.9039655,1.2272617,4.541515,-5.2753224, - -3.2196746,3.1303136,-7.285681,9.041425,5.6417427, - -9.93667,5.7548947,-5.113397,-8.544622,4.182665, - -7.7709813,-3.2810235,-3.312072,3.8900535,-2.0604856, - 6.709082,-8.461194,1.2666026,4.8770437,2.6955879, - 3.0340345,-1.1614609,-3.536341,-7.090382,-5.36146, - 9.072544,6.4554095,-4.4728956,-1.88395,3.1095037, - 8.782348,-3.316743,-8.65248,1.6802986,8.186188, - 
2.1783829,4.931278,4.158475,1.4033595,-11.320101, - -3.7084908,-6.740436,-2.5555193,-1.0451177,-6.5569925, - 0.82810307,8.505919,8.332857,-9.488569,-0.21588463, - -8.056692,8.493993,7.6401625,8.812983,-9.377281, - 2.4369764,3.1766508,0.6300803,5.6666765,-7.913654, - -0.42301777,4.506412,-7.8954244,10.904591,5.042256, - -9.626183,8.347351,-3.605006,-7.923387,1.1024277, - -8.705793,-2.5151258,-2.5066147,4.0515003,-2.060757, - 6.2635093,8.286584,-6.0509276,-6.76452,-3.1158175, - 1.6578803,-1.4608748,-1.24211,8.151246,-4.2970877, - 6.093071,7.4911637,4.51018,4.8425875,9.211085, - -2.4386222,4.5830803,-5.6079445,2.3713675,-4.0707507, - 3.1787417,5.462342,6.915912,6.3928423,-7.2970796, - 5.0112796,-9.140893,4.9990606,0.38391754,7.7088532, - 1.9340848,8.18833,8.16617,-9.42086,-0.3388326, - -9.659727,8.243045,8.099073,8.439428,-7.038694, - 2.1077902,3.3866816,-1.9975324,7.4972878,-7.2525196, - -1.553731,4.08758,-6.6922374,9.50525,4.026735, - -9.243538,7.2740564,-3.9319072,-6.3228955,1.6693478, - -7.923119,-3.7423058,-2.2813146,5.3469067,-1.8285407, - 3.3118162,8.826356,-4.4641976,-6.4751124,-9.200089, - -2.519147,4.225298,2.4105988,-0.4344186,0.53441775, - 5.2836394,-8.2816105,-4.996147,-1.6870759,-7.8543897, - -3.9788852,-7.0346904,-3.1289773,7.4567637,-5.6227813, - 1.0709786,-8.866012,8.427324,-1.1755563,-5.789216, - -8.197835,5.3342214,6.0646234,-6.8975716,7.717031, - 3.480355,8.312151,-3.6645212,-3.0976524,-8.090359, - -1.9176173,2.4257212,1.9700835,0.4098958,2.1341088, - 7.652741,-9.9595585,-5.989757,0.10119354,-7.935407, - -5.792786,-5.22783,-4.318978,5.414037,-6.4621663, - 1.670883,-6.9224787,8.696932,-2.0214002,-6.6681314, - -8.326418,4.9049683,5.4442496,-6.403739,7.5822453, - 7.0972915,-9.072851,-0.23897195,1.7662339,5.3096304, - 1.983179,-2.222645,-0.34700772,-9.094717,-6.107907, - 9.525174,8.1550665,-5.6940084,-4.1636486,1.7360662, - 8.528821,-3.7299833,-9.341266,2.608542,9.108706, - 0.7978509,4.2488184,2.454484,0.9446999,-10.106636, - 
-3.8973773,-6.6566644,-4.5647273,-0.99837756,-6.568582, - 9.324853,-7.9020953,2.0910501,2.2896829,1.6790711, - 1.3159255,-3.5258796,1.8898442,-8.105812,-4.924962, - 8.771129,7.1202874,-5.991957,-3.4106019,2.4450088, - 7.796387,-3.055946,-7.8971434,1.9856719,9.001636, - 1.8511922,3.019749,3.1227696,0.4822102,-10.021213, - -3.530504,-6.225959,-3.0029628,-1.7881511,-7.3879776, - 1.3925704,9.499782,-3.7318087,-3.7074296,-7.7466836, - -1.5284524,4.0535855,3.112011,0.10340207,-0.5429599, - 6.67026,-9.155924,-4.924038,0.64248866,-10.0103655, - -3.2742946,-4.850029,-3.6707063,8.586258,-5.855605, - 4.906918,-6.7813993,7.9938135,-2.5473144,-5.688948, - -7.822478,2.1421318,4.66659,-9.701272,9.549149, - 0.8998125,-8.651497,-0.56899565,-8.639817,2.3088377, - 2.1264515,3.2764478,2.341989,8.594338,8.630639, - 2.8440373,6.2043204,4.433932,0.6320018,-1.8179281, - 5.09452,-1.5741565,8.153934,8.744339,-3.6945698, - -8.883078,1.5329908,5.2745943,0.44716078,4.8809066, - -7.9594903,1.134374,9.233994,6.5528665,-4.520542, - 9.477355,-8.622195,-0.23191702,2.0485356,3.9379985, - 1.5916302,-1.4516805,-0.0843819,-7.8554378,-5.88308, - 7.999766,6.2572145,-5.585321,-4.0097756,0.42382592, - 6.160884,-3.631315,-8.333449,2.770595,7.8495173, - 3.3331623,4.940415,3.6207345,-0.037517,-11.034698, - -3.185103,-6.614664,-3.2177854,-2.0792234,-6.8879867, - 7.821685,-8.455084,1.0784642,4.0033927,2.7343264, - 2.6052725,-4.1224284,-0.89305353,-6.8267674,-4.9715133, - 8.880253,5.6994023,-5.9695024,-4.9181266,1.3017995, - 7.972617,-3.9452884,-10.424556,2.4504194,6.21529, - 0.93840516,4.2070026,6.159839,0.91979957,-8.706724, - -4.317946,-6.6823545,-3.0388,-2.464262,-7.3716645, - 1.3926703,6.544412,-5.6251183,-5.122411,-8.622049, - -2.3905911,3.9138813,1.9779967,-0.05011125,0.13310997, - 7.229751,-9.742043,-8.08724,1.2426697,-7.9230795, - -3.3162494,-7.129571,-3.5488048,7.4701195,-5.2357526, - 0.5917681,-6.272206,6.342328,-2.909731,-4.991607, - -8.845513,3.3228495,7.033246,-7.8180246,8.214469, - 
6.3910093,9.185153,-6.20472,-7.713809,-3.8481297, - 3.5579286,0.7078448,-3.2893546,7.384514,-4.448121, - 3.0104196,9.492943,8.024847,4.9114385,9.965594, - -3.014036,5.182494,-5.8806014,2.5312455,-5.9926524, - 4.474469,6.3717875,6.993105,6.493093,-8.935534, - 3.004074,-8.055647,8.315765,-1.3026813,8.250377, - 0.02606229,6.8508425,9.655665,-7.0116496,-0.41060972, - -10.049198,7.897801,6.7791023,8.3362,-9.821014, - 2.491157,3.5160472,-1.6228812,7.398063,-8.769123, - -3.1743705,3.2827861,-6.497855,10.831924,5.2761307, - -9.704417,4.3817043,-3.9841619,-8.111647,1.1883026, - -8.115312,-2.9240117,-5.8879666,4.20928,-0.3587938, - 6.935672,-10.177582,0.48819053,3.1250648,2.9306343, - 3.082544,-3.477687,-1.3768549,-7.4922366,-3.756631, - 10.039836,3.6670392,-5.9761434,-4.4728765,3.244255, - 7.027899,-2.3806512,-10.4100685,1.605716,7.7953773, - 0.5408159,1.7156523,3.824097,-1.0604783,-10.142124, - -5.246805,-6.5283823,-4.579547,-2.42714,-6.709197, - 2.7782338,7.33353,-6.454507,-2.9929368,-7.8362985, - -2.695445,2.4900775,1.6682367,0.4641757,-1.0495365, - 6.9631333,-9.291356,-8.23837,-0.34263706,-8.275113, - -2.8454232,-5.0864096,-2.681942,7.5450225,-6.2517986, - 0.06810654,-6.470652,4.9042645,-1.8369255,-6.6937943, - -7.9625087,2.8510258,6.180508,-8.282598,7.919079, - 1.4897474,6.7217417,-4.2459426,-4.114431,-8.375707, - -2.143264,5.6972933,1.5574739,0.39375135,1.7930849, - 5.1737595,-7.826241,-5.160268,-0.80433255,-7.839536, - -5.2620406,-5.4643164,-3.185536,6.620315,-7.065227, - 1.0524757,-6.125088,5.7126627,-1.6161644,-3.852159, - -9.164279,2.7005782,5.946544,-8.468236,8.2145405, - 1.1035942,6.590157,-4.0461283,-4.8090615,-7.6702685, - -2.1121511,5.1147075,1.6128504,2.0064135,1.0544407, - 6.0038295,-7.8282537,-4.801278,0.32349443,-8.0649805, - -4.372714,-5.61336,-5.21394,8.176595,-5.4753284, - 1.7800134,-8.267283,7.2133374,-0.16594432,-6.317046, - -9.490406,4.1261597,5.473317,-7.7551675,7.007468, - 7.478628,-8.801905,0.10975724,3.5478222,4.797803, - 
1.3825226,-3.357369,0.99262005,-6.94877,-5.4781394, - 9.632604,5.7492557,-5.9014316,-3.1632116,2.340859, - 8.708098,-3.1255999,-8.848661,4.5612836,8.455157, - 0.73460823,4.112301,4.392744,-0.30759293,-6.8036823, - -3.0331545,-8.269506,-2.82415,-0.9411246,-5.993506, - 2.1618164,-8.716055,-0.7432543,-10.255819,3.095418, - 2.5131428,4.752442,0.9907621,7.8279433,7.85814, - 0.50430876,5.2840405,4.457291,0.03330028,-0.40692952, - 3.9244103,-2.117118,7.6977615,8.759009,-4.2157164, - -9.136053,3.247858,4.668686,0.76162136,5.3833632, - -9.231471,0.44309422,8.380872,6.7211227,-3.091507, - 2.173508,-9.038242,-1.3666698,-9.819077,0.37825826, - 2.3898845,4.2440815,1.9161536,7.24787,6.9124637, - 1.6238527,5.1140285,3.1935842,1.02845,-1.1273454, - 5.638998,-2.497932,8.342559,8.586319,-2.9069402, - -7.6387944,3.5975037,4.4115705,0.41506064,4.9078383, - -9.68327,1.8159529,9.744613,8.40622,-4.495336, - 9.244892,-8.789869,1.3158468,4.018167,3.3922846, - 2.652022,-2.7495477,0.2528986,-8.268324,-6.004913, - 10.428784,6.6580734,-5.537176,-1.7177434,2.7504628, - 6.7735,-2.4454272,-9.998361,2.9483433,6.8266654, - 2.3787718,4.472637,2.5871701,0.7355365,-7.7027745, - -4.1879907,-7.172832,-4.1843605,-0.03646783,-5.419406, - 6.958486,11.011111,-7.1821184,-7.956423,-3.408451, - 4.6850276,-2.348787,-4.398289,6.9787564,-3.8324208, - 5.967827,8.433518,4.660108,5.5657144,9.964243, - -1.3515275,6.404833,-6.4805903,2.4379845,-6.0816774, - 1.752272,5.3771873,6.9613523,6.9788294,-6.3894596, - 3.7521114,-6.8034263,6.4458385,-0.7233525,10.512529, - 4.362273,9.231461,-6.3382263,-7.659,-3.461823, - 4.71463,0.17817476,-3.685746,7.2962036,-4.6489477, - 5.218017,11.546999,4.7218375,6.8498397,9.281103, - -3.900459,6.844054,-7.0886965,-0.05019227,-8.233724, - 5.5808983,6.374517,8.321048,7.969449,-7.3478637, - 1.4917561,-8.003144,4.780668,-1.1981848,7.753739, - 2.0260844,-8.880096,-3.4258451,-7.141975,1.9637157, - 1.814725,5.311151,1.4831505,7.8483663,7.257948, - 
1.395786,6.417756,5.376912,0.59505713,0.00062552, - 3.6634305,-4.159713,7.3571978,10.966816,-2.5419605, - -8.466229,1.904205,5.6338267,-0.52567476,5.59736, - -8.361799,0.5009981,8.460681,7.3891273,-3.5272243, - 5.0552278,9.921456,-7.69693,-7.286378,-1.9198836, - 3.1666567,-2.5832257,-2.2445817,9.888111,-5.076563, - 5.677401,7.497946,5.662994,5.414262,8.566503, - -2.5530663,7.1032815,-6.0612082,1.3419591,-4.9595256, - 4.3377542,4.3790717,6.793512,8.383502,-7.1278043, - 3.3240774,-9.379446,6.838661,-0.81241214,8.694813, - 0.79141915,7.632467,8.575382,-8.533798,0.28954387, - -7.5675836,5.8653326,8.97235,7.1649346,-10.575289, - 0.9359381,5.02381,-0.5609511,5.543464,-7.69131, - -2.1792977,2.4729247,-6.1917787,10.373678,7.6549597, - -8.809486,5.5657206,-3.3169382,-8.042887,2.0874746, - -7.079005,-3.33398,-3.6843317,4.0172358,-2.0754814, - 1.1726758,7.4618697,6.9483604,-8.469206,0.7401797, - -10.318176,8.384557,10.5476265,9.146971,-9.250223, - 0.6290606,4.4941425,-0.7514017,7.2271705,-8.309598, - -1.4761636,4.0140634,-6.021102,9.132852,5.6610966, - -11.249811,8.359293,-1.9445792,-7.7393436,-0.3931331, - -8.824441,-2.5995944,-2.5714035,4.140213,-3.6863053, - 5.517265,9.020411,-4.9286127,-7.871219,-3.7446704, - 2.5179656,-1.4543481,-2.2703636,7.010597,-3.6436229, - 6.753862,7.4129915,7.1406755,5.653706,9.5445175, - 0.15698843,4.761813,-7.698002,1.6870106,-4.5410123, - 4.171763,5.3747005,6.341021,7.456738,-8.231657, - 2.763487,-9.208167,6.676799,-1.1957736,10.062605, - 4.0975976,7.312957,-2.4981596,-2.9658387,-8.150425, - -2.1075552,2.64375,1.6636052,1.1483809,0.09276015, - 5.8556347,-7.8481026,-5.9913163,-0.02840613,-9.937289, - -1.0486673,-5.2340155,-3.83912,7.7165728,-8.409944, - 0.80863273,-6.9119215,7.5712357,0.36031485,-6.056131, - -8.470033,1.8678337,3.0121377,-7.3096333,8.205484, - 5.262654,8.774514,-4.7603083,-7.2096143,-4.437014, - 3.6080024,-1.624254,-4.2787876,8.880863,-4.8984556, - 5.1782074,9.944454,3.911282,3.5396595,8.867042, - 
-1.2006199,5.393288,-5.6455317,0.7829499,-4.0338907, - 2.479272,6.5080743,8.582535,7.0097537,-6.9823785, - 3.984318,-7.225381,5.3135114,-1.0391048,8.951443, - -0.70119005,-8.510742,-0.42949116,-10.9224825,2.8176029, - 1.6800792,5.778404,1.7269998,7.1975236,7.7258267, - 2.7632928,5.3399253,3.4650044,0.01971426,-1.6468811, - 4.114996,-1.5110453,6.8689218,8.269899,-3.1568048, - -7.0344677,1.2911975,5.950357,0.19028673,4.657226, - -8.199647,2.246055,8.989509,5.3101015,-4.2400866 - }; +class TrustworthinessScoreTest : public ::testing::Test { + protected: + void basicTest() { + std::vector X = { + 5.6142087, 8.59787, -4.382763, -3.6452143, -5.8816037, + -0.6330313, 4.6920023, -0.79210913, 0.6106314, 2.1210914, + 5.919943, -8.43784, -6.4819884, 0.41001374, -6.1052523, + -4.0825715, -5.314755, -2.834671, 5.751696, -6.5012555, + -0.4719201, -7.53353, 7.6789393, -1.4959852, -5.5977287, + -9.564147, 1.2902534, 3.559834, -6.7659483, 8.265964, + 4.595404, 9.133477, -6.1553917, -6.319754, -2.9039452, + 4.4150834, -3.094395, -4.426273, 9.584571, -5.64133, + 6.6209483, 7.4044604, 3.9620576, 5.639907, 10.33007, + -0.8792053, 5.143776, -7.464049, 1.2448754, -5.6300974, + 5.4518576, 4.119535, 6.749645, 7.627064, -7.2298336, + 1.9681473, -6.9083176, 6.404673, 0.07186685, 9.0994835, + 8.51037, -8.986389, 0.40534487, 2.115397, 4.086756, + 1.2284287, -2.6272132, 0.06527536, -9.587425, -7.206078, + 7.864875, 7.4397306, -6.9233336, -2.6643622, 3.3466153, + 7.0408177, -3.6069896, -9.971769, 4.4075623, 7.9063697, + 2.559074, 4.323717, 1.6867131, -1.1576937, -9.893141, + -3.251416, -7.4889135, -4.0588717, -2.73338, -7.4852257, + 3.4460473, 9.759119, -5.4680476, -4.722435, -8.032619, + -1.4598992, 4.227361, 3.135568, 1.1950601, 1.1982028, + 6.998856, -6.131138, -6.6921015, 0.5361224, -7.1213965, + -5.6104236, -7.2212887, -2.2710054, 8.544764, -6.0254574, + 1.4582269, -5.5587835, 8.031556, -0.26328218, -5.2591386, + -9.262641, 2.8691363, 5.299787, -9.209455, 8.523085, + 5.180329, 10.655528, 
-5.7171874, -6.7739563, -3.6306462, + 4.067106, -1.5912259, -3.2345476, 8.042973, -3.6364832, + 4.1242137, 9.886953, 5.4743724, 6.3058076, 9.369645, + -0.5175337, 4.9859877, -7.879498, 1.358422, -4.147944, + 3.8984218, 5.894656, 6.4903927, 8.702036, -8.023722, + 2.802145, -7.748032, 5.8461113, -0.34215945, 11.298865, + 1.4107164, -9.949621, -1.6257563, -10.655836, 2.4528909, + 1.1570255, 5.170669, 2.8398793, 7.1838694, 9.088459, + 2.631155, 3.964414, 2.8769252, 0.04198391, -0.16993195, + 3.6747139, -2.8377378, 6.1782537, 10.759618, -4.5642614, + -8.522967, 0.8614642, 6.623416, -1.029324, 5.5488334, + -7.804511, 2.128833, 7.9042315, 7.789576, -2.7944536, + 0.72271067, -10.511495, -0.78634536, -10.661714, 2.9376361, + 1.9148129, 6.22859, 0.26264945, 8.028384, 6.8743043, + 0.9351067, 7.0690722, 4.2846055, 1.4134506, -0.18144785, + 5.2778087, -1.7140163, 9.217541, 8.602799, -2.6537218, + -7.8377395, 1.1244944, 5.4540544, -0.38506773, 3.9885726, + -10.76455, 1.4440702, 9.136163, 6.664117, -5.7046547, + 8.038592, -9.229767, -0.2799413, 3.6064725, 4.187257, + 1.0516582, -2.0707326, -0.7615968, -8.561018, -3.7831352, + 10.300297, 5.332594, -6.5880876, -4.2508664, 1.7985519, + 5.7226253, -4.1223383, -9.6697855, 1.4885283, 7.524974, + 1.7206005, 4.890457, 3.7264557, 0.4428284, -9.922455, + -4.250455, -6.4410596, -2.107994, -1.4109765, -6.1325397, + 0.32883006, 6.0489736, 7.7257385, -8.281174, 1.0129383, + -10.792166, 8.378851, 10.802716, 9.848448, -9.188757, + 1.3151443, 1.9971865, -2.521849, 4.3268294, -7.775683, + -2.2902298, 3.0824065, -7.17559, 9.6100855, 7.3965735, + -10.476525, 5.895973, -3.6974669, -7.6688933, 1.7354839, + -7.4045196, -1.7992063, -4.0394845, 5.2471714, -2.250571, + 2.528036, -8.343515, -2.2374575, -10.019771, 0.73371273, + 3.1853926, 2.7994921, 2.6637669, 7.620401, 7.515571, + 0.68636256, 5.834537, 4.650282, -1.0362619, 0.4461701, + 3.7870514, -4.1340904, 7.202998, 9.736904, -3.005512, + -8.920467, 1.1228397, 6.2598724, 1.2812365, 4.5442104, + 
-8.791537, 0.92113096, 8.464749, 8.359035, -4.3923397, + 1.2252625, -10.1986475, -1.4409319, -10.013967, 3.9071581, + 1.683064, 4.877419, 1.6570637, 9.559105, 7.3546534, + 0.36635467, 5.220211, 4.6303267, 0.6601065, 0.16149978, + 3.8818731, -3.4438233, 8.42085, 8.659159, -3.0935583, + -8.039611, 2.3060374, 5.134666, 1.0458113, 6.0190983, + -9.143728, 0.99048865, 9.210842, 6.670241, -5.9614363, + 0.8747396, 7.078824, 8.067469, -10.314754, 0.45977542, + -9.28306, 9.1838665, 9.318644, 7.189082, -11.092555, + 1.0320464, 3.882163, 0.10953151, 7.9029684, -6.9068265, + -1.3526366, 5.3996363, -8.430931, 11.452577, 6.39663, + -11.090514, 4.6662245, -3.1268113, -8.357452, 2.2276728, + -10.357126, -0.9291848, -3.4193344, 3.1289792, -2.5030103, + 6.772719, 11.457757, -4.2125936, -6.684548, -4.7611327, + 3.6960156, -2.3030636, -3.0591488, 10.452471, -4.1267314, + 5.66614, 7.501461, 5.072407, 6.636537, 8.990381, + -0.2559256, 4.737867, -6.2149944, 2.535682, -5.5484023, + 5.7113924, 3.4742818, 7.9915137, 7.0052586, -7.156467, + 1.4354781, -8.286235, 5.7523417, -2.4175215, 9.678009, + 0.05066403, -9.645226, -2.2658763, -9.518178, 4.493372, + 2.3232365, 2.1659086, 0.42507997, 8.360246, 8.23535, + 2.6878164, 5.236947, 3.4924245, -0.6089895, 0.8884741, + 4.359464, -4.6073823, 7.83441, 8.958755, -3.4690795, + -9.182282, 1.2478025, 5.6311107, -1.2408862, 3.6316886, + -8.684654, 2.1078515, 7.2813864, 7.9265943, -3.6135032, + 0.4571511, 8.493568, 10.496853, -7.432897, 0.8625995, + -9.607528, 7.2899456, 8.83158, 8.908199, -10.300263, + 1.1451302, 3.7871468, -0.97040755, 5.7664757, -8.9688, + -2.146672, 5.9641485, -6.2908535, 10.126465, 6.1553903, + -12.066902, 6.301596, -5.0419583, -8.228695, 2.4879954, + -8.918582, -3.7434099, -4.1593685, 3.7431836, -1.1704745, + 0.5524103, 9.109399, 9.571567, -11.209955, 1.2462777, + -9.554555, 9.091726, 11.477966, 7.630937, -10.450911, + 1.9205878, 5.358983, -0.44546837, 6.7611346, -9.74753, + -0.5939732, 3.8892255, -6.437991, 10.294727, 5.6723895, + 
-10.7883, 6.192348, -5.293862, -10.811491, 1.0194173, + -7.074576, -3.192368, -2.5231771, 4.2791643, -0.53309685, + 0.501366, 9.636625, 7.710316, -6.4219728, 1.0975566, + -8.218886, 6.9011984, 9.873679, 8.903804, -9.316832, + 1.2404599, 4.9039655, 1.2272617, 4.541515, -5.2753224, + -3.2196746, 3.1303136, -7.285681, 9.041425, 5.6417427, + -9.93667, 5.7548947, -5.113397, -8.544622, 4.182665, + -7.7709813, -3.2810235, -3.312072, 3.8900535, -2.0604856, + 6.709082, -8.461194, 1.2666026, 4.8770437, 2.6955879, + 3.0340345, -1.1614609, -3.536341, -7.090382, -5.36146, + 9.072544, 6.4554095, -4.4728956, -1.88395, 3.1095037, + 8.782348, -3.316743, -8.65248, 1.6802986, 8.186188, + 2.1783829, 4.931278, 4.158475, 1.4033595, -11.320101, + -3.7084908, -6.740436, -2.5555193, -1.0451177, -6.5569925, + 0.82810307, 8.505919, 8.332857, -9.488569, -0.21588463, + -8.056692, 8.493993, 7.6401625, 8.812983, -9.377281, + 2.4369764, 3.1766508, 0.6300803, 5.6666765, -7.913654, + -0.42301777, 4.506412, -7.8954244, 10.904591, 5.042256, + -9.626183, 8.347351, -3.605006, -7.923387, 1.1024277, + -8.705793, -2.5151258, -2.5066147, 4.0515003, -2.060757, + 6.2635093, 8.286584, -6.0509276, -6.76452, -3.1158175, + 1.6578803, -1.4608748, -1.24211, 8.151246, -4.2970877, + 6.093071, 7.4911637, 4.51018, 4.8425875, 9.211085, + -2.4386222, 4.5830803, -5.6079445, 2.3713675, -4.0707507, + 3.1787417, 5.462342, 6.915912, 6.3928423, -7.2970796, + 5.0112796, -9.140893, 4.9990606, 0.38391754, 7.7088532, + 1.9340848, 8.18833, 8.16617, -9.42086, -0.3388326, + -9.659727, 8.243045, 8.099073, 8.439428, -7.038694, + 2.1077902, 3.3866816, -1.9975324, 7.4972878, -7.2525196, + -1.553731, 4.08758, -6.6922374, 9.50525, 4.026735, + -9.243538, 7.2740564, -3.9319072, -6.3228955, 1.6693478, + -7.923119, -3.7423058, -2.2813146, 5.3469067, -1.8285407, + 3.3118162, 8.826356, -4.4641976, -6.4751124, -9.200089, + -2.519147, 4.225298, 2.4105988, -0.4344186, 0.53441775, + 5.2836394, -8.2816105, -4.996147, -1.6870759, -7.8543897, + 
-3.9788852, -7.0346904, -3.1289773, 7.4567637, -5.6227813, + 1.0709786, -8.866012, 8.427324, -1.1755563, -5.789216, + -8.197835, 5.3342214, 6.0646234, -6.8975716, 7.717031, + 3.480355, 8.312151, -3.6645212, -3.0976524, -8.090359, + -1.9176173, 2.4257212, 1.9700835, 0.4098958, 2.1341088, + 7.652741, -9.9595585, -5.989757, 0.10119354, -7.935407, + -5.792786, -5.22783, -4.318978, 5.414037, -6.4621663, + 1.670883, -6.9224787, 8.696932, -2.0214002, -6.6681314, + -8.326418, 4.9049683, 5.4442496, -6.403739, 7.5822453, + 7.0972915, -9.072851, -0.23897195, 1.7662339, 5.3096304, + 1.983179, -2.222645, -0.34700772, -9.094717, -6.107907, + 9.525174, 8.1550665, -5.6940084, -4.1636486, 1.7360662, + 8.528821, -3.7299833, -9.341266, 2.608542, 9.108706, + 0.7978509, 4.2488184, 2.454484, 0.9446999, -10.106636, + -3.8973773, -6.6566644, -4.5647273, -0.99837756, -6.568582, + 9.324853, -7.9020953, 2.0910501, 2.2896829, 1.6790711, + 1.3159255, -3.5258796, 1.8898442, -8.105812, -4.924962, + 8.771129, 7.1202874, -5.991957, -3.4106019, 2.4450088, + 7.796387, -3.055946, -7.8971434, 1.9856719, 9.001636, + 1.8511922, 3.019749, 3.1227696, 0.4822102, -10.021213, + -3.530504, -6.225959, -3.0029628, -1.7881511, -7.3879776, + 1.3925704, 9.499782, -3.7318087, -3.7074296, -7.7466836, + -1.5284524, 4.0535855, 3.112011, 0.10340207, -0.5429599, + 6.67026, -9.155924, -4.924038, 0.64248866, -10.0103655, + -3.2742946, -4.850029, -3.6707063, 8.586258, -5.855605, + 4.906918, -6.7813993, 7.9938135, -2.5473144, -5.688948, + -7.822478, 2.1421318, 4.66659, -9.701272, 9.549149, + 0.8998125, -8.651497, -0.56899565, -8.639817, 2.3088377, + 2.1264515, 3.2764478, 2.341989, 8.594338, 8.630639, + 2.8440373, 6.2043204, 4.433932, 0.6320018, -1.8179281, + 5.09452, -1.5741565, 8.153934, 8.744339, -3.6945698, + -8.883078, 1.5329908, 5.2745943, 0.44716078, 4.8809066, + -7.9594903, 1.134374, 9.233994, 6.5528665, -4.520542, + 9.477355, -8.622195, -0.23191702, 2.0485356, 3.9379985, + 1.5916302, -1.4516805, -0.0843819, 
-7.8554378, -5.88308, + 7.999766, 6.2572145, -5.585321, -4.0097756, 0.42382592, + 6.160884, -3.631315, -8.333449, 2.770595, 7.8495173, + 3.3331623, 4.940415, 3.6207345, -0.037517, -11.034698, + -3.185103, -6.614664, -3.2177854, -2.0792234, -6.8879867, + 7.821685, -8.455084, 1.0784642, 4.0033927, 2.7343264, + 2.6052725, -4.1224284, -0.89305353, -6.8267674, -4.9715133, + 8.880253, 5.6994023, -5.9695024, -4.9181266, 1.3017995, + 7.972617, -3.9452884, -10.424556, 2.4504194, 6.21529, + 0.93840516, 4.2070026, 6.159839, 0.91979957, -8.706724, + -4.317946, -6.6823545, -3.0388, -2.464262, -7.3716645, + 1.3926703, 6.544412, -5.6251183, -5.122411, -8.622049, + -2.3905911, 3.9138813, 1.9779967, -0.05011125, 0.13310997, + 7.229751, -9.742043, -8.08724, 1.2426697, -7.9230795, + -3.3162494, -7.129571, -3.5488048, 7.4701195, -5.2357526, + 0.5917681, -6.272206, 6.342328, -2.909731, -4.991607, + -8.845513, 3.3228495, 7.033246, -7.8180246, 8.214469, + 6.3910093, 9.185153, -6.20472, -7.713809, -3.8481297, + 3.5579286, 0.7078448, -3.2893546, 7.384514, -4.448121, + 3.0104196, 9.492943, 8.024847, 4.9114385, 9.965594, + -3.014036, 5.182494, -5.8806014, 2.5312455, -5.9926524, + 4.474469, 6.3717875, 6.993105, 6.493093, -8.935534, + 3.004074, -8.055647, 8.315765, -1.3026813, 8.250377, + 0.02606229, 6.8508425, 9.655665, -7.0116496, -0.41060972, + -10.049198, 7.897801, 6.7791023, 8.3362, -9.821014, + 2.491157, 3.5160472, -1.6228812, 7.398063, -8.769123, + -3.1743705, 3.2827861, -6.497855, 10.831924, 5.2761307, + -9.704417, 4.3817043, -3.9841619, -8.111647, 1.1883026, + -8.115312, -2.9240117, -5.8879666, 4.20928, -0.3587938, + 6.935672, -10.177582, 0.48819053, 3.1250648, 2.9306343, + 3.082544, -3.477687, -1.3768549, -7.4922366, -3.756631, + 10.039836, 3.6670392, -5.9761434, -4.4728765, 3.244255, + 7.027899, -2.3806512, -10.4100685, 1.605716, 7.7953773, + 0.5408159, 1.7156523, 3.824097, -1.0604783, -10.142124, + -5.246805, -6.5283823, -4.579547, -2.42714, -6.709197, + 2.7782338, 7.33353, 
-6.454507, -2.9929368, -7.8362985, + -2.695445, 2.4900775, 1.6682367, 0.4641757, -1.0495365, + 6.9631333, -9.291356, -8.23837, -0.34263706, -8.275113, + -2.8454232, -5.0864096, -2.681942, 7.5450225, -6.2517986, + 0.06810654, -6.470652, 4.9042645, -1.8369255, -6.6937943, + -7.9625087, 2.8510258, 6.180508, -8.282598, 7.919079, + 1.4897474, 6.7217417, -4.2459426, -4.114431, -8.375707, + -2.143264, 5.6972933, 1.5574739, 0.39375135, 1.7930849, + 5.1737595, -7.826241, -5.160268, -0.80433255, -7.839536, + -5.2620406, -5.4643164, -3.185536, 6.620315, -7.065227, + 1.0524757, -6.125088, 5.7126627, -1.6161644, -3.852159, + -9.164279, 2.7005782, 5.946544, -8.468236, 8.2145405, + 1.1035942, 6.590157, -4.0461283, -4.8090615, -7.6702685, + -2.1121511, 5.1147075, 1.6128504, 2.0064135, 1.0544407, + 6.0038295, -7.8282537, -4.801278, 0.32349443, -8.0649805, + -4.372714, -5.61336, -5.21394, 8.176595, -5.4753284, + 1.7800134, -8.267283, 7.2133374, -0.16594432, -6.317046, + -9.490406, 4.1261597, 5.473317, -7.7551675, 7.007468, + 7.478628, -8.801905, 0.10975724, 3.5478222, 4.797803, + 1.3825226, -3.357369, 0.99262005, -6.94877, -5.4781394, + 9.632604, 5.7492557, -5.9014316, -3.1632116, 2.340859, + 8.708098, -3.1255999, -8.848661, 4.5612836, 8.455157, + 0.73460823, 4.112301, 4.392744, -0.30759293, -6.8036823, + -3.0331545, -8.269506, -2.82415, -0.9411246, -5.993506, + 2.1618164, -8.716055, -0.7432543, -10.255819, 3.095418, + 2.5131428, 4.752442, 0.9907621, 7.8279433, 7.85814, + 0.50430876, 5.2840405, 4.457291, 0.03330028, -0.40692952, + 3.9244103, -2.117118, 7.6977615, 8.759009, -4.2157164, + -9.136053, 3.247858, 4.668686, 0.76162136, 5.3833632, + -9.231471, 0.44309422, 8.380872, 6.7211227, -3.091507, + 2.173508, -9.038242, -1.3666698, -9.819077, 0.37825826, + 2.3898845, 4.2440815, 1.9161536, 7.24787, 6.9124637, + 1.6238527, 5.1140285, 3.1935842, 1.02845, -1.1273454, + 5.638998, -2.497932, 8.342559, 8.586319, -2.9069402, + -7.6387944, 3.5975037, 4.4115705, 0.41506064, 4.9078383, + 
-9.68327, 1.8159529, 9.744613, 8.40622, -4.495336, + 9.244892, -8.789869, 1.3158468, 4.018167, 3.3922846, + 2.652022, -2.7495477, 0.2528986, -8.268324, -6.004913, + 10.428784, 6.6580734, -5.537176, -1.7177434, 2.7504628, + 6.7735, -2.4454272, -9.998361, 2.9483433, 6.8266654, + 2.3787718, 4.472637, 2.5871701, 0.7355365, -7.7027745, + -4.1879907, -7.172832, -4.1843605, -0.03646783, -5.419406, + 6.958486, 11.011111, -7.1821184, -7.956423, -3.408451, + 4.6850276, -2.348787, -4.398289, 6.9787564, -3.8324208, + 5.967827, 8.433518, 4.660108, 5.5657144, 9.964243, + -1.3515275, 6.404833, -6.4805903, 2.4379845, -6.0816774, + 1.752272, 5.3771873, 6.9613523, 6.9788294, -6.3894596, + 3.7521114, -6.8034263, 6.4458385, -0.7233525, 10.512529, + 4.362273, 9.231461, -6.3382263, -7.659, -3.461823, + 4.71463, 0.17817476, -3.685746, 7.2962036, -4.6489477, + 5.218017, 11.546999, 4.7218375, 6.8498397, 9.281103, + -3.900459, 6.844054, -7.0886965, -0.05019227, -8.233724, + 5.5808983, 6.374517, 8.321048, 7.969449, -7.3478637, + 1.4917561, -8.003144, 4.780668, -1.1981848, 7.753739, + 2.0260844, -8.880096, -3.4258451, -7.141975, 1.9637157, + 1.814725, 5.311151, 1.4831505, 7.8483663, 7.257948, + 1.395786, 6.417756, 5.376912, 0.59505713, 0.00062552, + 3.6634305, -4.159713, 7.3571978, 10.966816, -2.5419605, + -8.466229, 1.904205, 5.6338267, -0.52567476, 5.59736, + -8.361799, 0.5009981, 8.460681, 7.3891273, -3.5272243, + 5.0552278, 9.921456, -7.69693, -7.286378, -1.9198836, + 3.1666567, -2.5832257, -2.2445817, 9.888111, -5.076563, + 5.677401, 7.497946, 5.662994, 5.414262, 8.566503, + -2.5530663, 7.1032815, -6.0612082, 1.3419591, -4.9595256, + 4.3377542, 4.3790717, 6.793512, 8.383502, -7.1278043, + 3.3240774, -9.379446, 6.838661, -0.81241214, 8.694813, + 0.79141915, 7.632467, 8.575382, -8.533798, 0.28954387, + -7.5675836, 5.8653326, 8.97235, 7.1649346, -10.575289, + 0.9359381, 5.02381, -0.5609511, 5.543464, -7.69131, + -2.1792977, 2.4729247, -6.1917787, 10.373678, 7.6549597, + -8.809486, 
5.5657206, -3.3169382, -8.042887, 2.0874746, + -7.079005, -3.33398, -3.6843317, 4.0172358, -2.0754814, + 1.1726758, 7.4618697, 6.9483604, -8.469206, 0.7401797, + -10.318176, 8.384557, 10.5476265, 9.146971, -9.250223, + 0.6290606, 4.4941425, -0.7514017, 7.2271705, -8.309598, + -1.4761636, 4.0140634, -6.021102, 9.132852, 5.6610966, + -11.249811, 8.359293, -1.9445792, -7.7393436, -0.3931331, + -8.824441, -2.5995944, -2.5714035, 4.140213, -3.6863053, + 5.517265, 9.020411, -4.9286127, -7.871219, -3.7446704, + 2.5179656, -1.4543481, -2.2703636, 7.010597, -3.6436229, + 6.753862, 7.4129915, 7.1406755, 5.653706, 9.5445175, + 0.15698843, 4.761813, -7.698002, 1.6870106, -4.5410123, + 4.171763, 5.3747005, 6.341021, 7.456738, -8.231657, + 2.763487, -9.208167, 6.676799, -1.1957736, 10.062605, + 4.0975976, 7.312957, -2.4981596, -2.9658387, -8.150425, + -2.1075552, 2.64375, 1.6636052, 1.1483809, 0.09276015, + 5.8556347, -7.8481026, -5.9913163, -0.02840613, -9.937289, + -1.0486673, -5.2340155, -3.83912, 7.7165728, -8.409944, + 0.80863273, -6.9119215, 7.5712357, 0.36031485, -6.056131, + -8.470033, 1.8678337, 3.0121377, -7.3096333, 8.205484, + 5.262654, 8.774514, -4.7603083, -7.2096143, -4.437014, + 3.6080024, -1.624254, -4.2787876, 8.880863, -4.8984556, + 5.1782074, 9.944454, 3.911282, 3.5396595, 8.867042, + -1.2006199, 5.393288, -5.6455317, 0.7829499, -4.0338907, + 2.479272, 6.5080743, 8.582535, 7.0097537, -6.9823785, + 3.984318, -7.225381, 5.3135114, -1.0391048, 8.951443, + -0.70119005, -8.510742, -0.42949116, -10.9224825, 2.8176029, + 1.6800792, 5.778404, 1.7269998, 7.1975236, 7.7258267, + 2.7632928, 5.3399253, 3.4650044, 0.01971426, -1.6468811, + 4.114996, -1.5110453, 6.8689218, 8.269899, -3.1568048, + -7.0344677, 1.2911975, 5.950357, 0.19028673, 4.657226, + -8.199647, 2.246055, 8.989509, 5.3101015, -4.2400866}; - std::vector X_embedded = { - -0.41849962,-0.53906363,0.46958843,-0.35832694,-0.23779503,-0.29751351, - 
-0.01072748,-0.21353109,-0.54769957,-0.55086273,0.37093949,-0.12714292, - -0.06639574,-0.36098689,-0.13060696,-0.07362658,-1.01205945,-0.39285606, - 0.2864089,-0.32031146,-0.19595343,0.08900568,-0.04813879,-0.06563424, - -0.42655188,-0.69014251,0.51459783,-0.1942696,-0.07767916,-0.6119386, - 0.04813685,-0.22557008,-0.56890118,-0.60293794,0.43429622,-0.09240723, - -0.00624062,-0.25800395,-0.1886092,0.01655941,-0.01961523,-0.14147359, - 0.41414487,-0.8512944,-0.61199242,-0.18586016,0.14024924,-0.41635606, - -0.02890144,0.1065347,0.39700791,-1.14060664,-0.95313865,0.14416681, - 0.17306046,-0.53189689,-0.98987544,-0.67918193,0.41787854,-0.20878236, - -0.06612862,0.03502904,-0.03765266,-0.0980606,-0.00971657,0.29432917, - 0.36575687,-1.1645509,-0.89094597,0.03718805,0.2310573,-0.38345811, - -0.10401925,-0.10653082,0.38469055,-0.88302094,-0.80197543,0.03548668, - 0.02775662,-0.54374295,0.03379983,0.00923623,0.29320273,-1.05263519, - -0.93360096,0.03778313,0.12360487,-0.56437284,0.0644429,0.33432651, - 0.36450726,-1.22978747,-0.83822101,-0.18796451,0.34888434,-0.3801491, - -0.45327303,-0.59747899,0.39697698,-0.15616602,-0.06159166,-0.40301991, - -0.11725303,-0.11913263,-0.12406619,-0.11227967,0.43083835,-0.90535849, - -0.81646025,0.10012121,-0.0141237,-0.63747931,0.04805023,0.34190539, - 0.50725192,-1.17861414,-0.74641538,-0.09333111,0.27992678,-0.56214809, - 0.04970971,0.36249384,0.57705611,-1.16913795,-0.69849908,0.10957897, - 0.27983218,-0.62088525,0.0410459,0.23973398,0.40960434,-1.14183664, - -0.83321381,0.02149482,0.21720445,-0.49869928,-0.95655465,-0.51680422, - 0.45761383,-0.08351214,-0.12151554,0.00819737,-0.20813803,-0.01055793, - 0.25319234,0.36154974,0.1822421,-1.15837133,-0.92209691,-0.0501582, - 0.08535917,-0.54003763,-1.08675635,-1.04009593,0.09408128,0.07009826, - -0.01762833,-0.19180447,-0.18029785,-0.20342001,0.04034991,0.1814747, - 0.36906669,-1.13532007,-0.8852452,0.0782818,0.16825101,-0.50301319, - 
-0.29128098,-0.65341312,0.51484352,-0.38758236,-0.22531103,-0.55021971, - 0.10804344,-0.3521522,-0.38849035,-0.74110794,0.53761131,-0.25142813, - -0.1118066,-0.47453368,0.06347904,-0.23796193,-1.02682328,-0.47594091, - 0.39515916,-0.2782529,-0.16566519,0.08063579,0.00810116,-0.06213913, - -1.059654,-0.62496334,0.53698546,-0.11806234,0.00356161,0.11513405, - -0.14213292,0.04102662,-0.36622161,-0.73686272,0.48323864,-0.27338892, - -0.14203401,-0.41736352,0.03332564,-0.21907479,-0.06396769,0.01831361, - 0.46263444,-1.01878166,-0.86486858,0.17622118,-0.01249686,-0.74530888, - -0.9354887,-0.5027945,0.38170099,-0.15547098,0.00677824,-0.04677663, - -0.13541745,0.07253501,-0.97933143,-0.58001202,0.48235369,-0.18836913, - -0.02430783,0.07572441,-0.08101331,0.00630076,-0.16881248,-0.67989182, - 0.46083611,-0.43910736,-0.29321918,-0.38735861,0.07669903,-0.29749861, - -0.40047669,-0.56722462,0.33168188,-0.13118173,-0.06672747,-0.56856316, - -0.26269144,-0.14236671,0.10651901,0.4962585,0.38848072,-1.06653547, - -0.64079332,-0.47378591,0.43195483,-0.04856951,-0.9840439,-0.70610428, - 0.34028092,-0.2089237,-0.05382041,0.01625874,-0.02080803,-0.12535211, - -0.04146428,-1.24533033,0.48944879,0.0578458,0.26708388,-0.90321028, - 0.35377088,-0.36791429,-0.35382384,-0.52748734,0.42854419,-0.31744713, - -0.19174226,-0.39073724,-0.03258846,-0.19978228,-0.36185205,-0.57412046, - 0.43681973,-0.25414538,-0.12904905,-0.46334973,-0.03123853,-0.11303604, - -0.87073672,-0.45441297,0.41825858,-0.25303507,-0.21845073,0.10248682, - -0.11045569,-0.10002795,-0.00572806,0.16519061,0.42651513,-1.11417019, - -0.83789682,0.02995787,0.16843079,-0.53874511,0.03056994,0.17877036, - 0.49632853,-1.03276777,-0.74778616,-0.03971953,0.10907949,-0.67385727, - -0.9523471,-0.56550741,0.40409449,-0.2703723,-0.10175014,0.13605487, - -0.06306008,-0.01768126,-0.4749442,-0.56964815,0.39389887,-0.19248079, - -0.04161081,-0.38728487,-0.20341556,-0.12656988,-0.35949609,-0.46137866, - 
0.28798422,-0.06603147,-0.04363992,-0.60343552,-0.23565227,-0.10242701, - -0.06792886,0.09689897,0.33259571,-0.98854214,-0.84444433,0.00673901, - 0.13457057,-0.43145794,-0.51500046,-0.50821936,0.38000089,0.0132636, - 0.0580942,-0.40157595,-0.11967677,0.02549113,-0.10350953,0.22918226, - 0.40411913,-1.05619383,-0.71218503,-0.02197581,0.26422262,-0.34765676, - 0.06601537,0.21712676,0.34723559,-1.20982027,-0.95646334,0.00793948, - 0.27620381,-0.43475035,-0.67326003,-0.6137197,0.43724492,-0.17666136, - -0.06591748,-0.18937394,-0.07400128,-0.06881691,-0.5201112,-0.61088628, - 0.4225319,-0.18969463,-0.06921366,-0.33993208,-0.06990873,-0.10288513, - -0.70659858,-0.56003648,0.46628812,-0.16090363,-0.0185108,-0.1431348, - -0.1128775,-0.0078648,-0.02323332,0.04292452,0.39291084,-0.94897962, - -0.63863206,-0.16546988,0.23698957,-0.30633628 - }; + std::vector X_embedded = { + -0.41849962, -0.53906363, 0.46958843, -0.35832694, -0.23779503, + -0.29751351, -0.01072748, -0.21353109, -0.54769957, -0.55086273, + 0.37093949, -0.12714292, -0.06639574, -0.36098689, -0.13060696, + -0.07362658, -1.01205945, -0.39285606, 0.2864089, -0.32031146, + -0.19595343, 0.08900568, -0.04813879, -0.06563424, -0.42655188, + -0.69014251, 0.51459783, -0.1942696, -0.07767916, -0.6119386, + 0.04813685, -0.22557008, -0.56890118, -0.60293794, 0.43429622, + -0.09240723, -0.00624062, -0.25800395, -0.1886092, 0.01655941, + -0.01961523, -0.14147359, 0.41414487, -0.8512944, -0.61199242, + -0.18586016, 0.14024924, -0.41635606, -0.02890144, 0.1065347, + 0.39700791, -1.14060664, -0.95313865, 0.14416681, 0.17306046, + -0.53189689, -0.98987544, -0.67918193, 0.41787854, -0.20878236, + -0.06612862, 0.03502904, -0.03765266, -0.0980606, -0.00971657, + 0.29432917, 0.36575687, -1.1645509, -0.89094597, 0.03718805, + 0.2310573, -0.38345811, -0.10401925, -0.10653082, 0.38469055, + -0.88302094, -0.80197543, 0.03548668, 0.02775662, -0.54374295, + 0.03379983, 0.00923623, 0.29320273, -1.05263519, -0.93360096, + 0.03778313, 
0.12360487, -0.56437284, 0.0644429, 0.33432651, + 0.36450726, -1.22978747, -0.83822101, -0.18796451, 0.34888434, + -0.3801491, -0.45327303, -0.59747899, 0.39697698, -0.15616602, + -0.06159166, -0.40301991, -0.11725303, -0.11913263, -0.12406619, + -0.11227967, 0.43083835, -0.90535849, -0.81646025, 0.10012121, + -0.0141237, -0.63747931, 0.04805023, 0.34190539, 0.50725192, + -1.17861414, -0.74641538, -0.09333111, 0.27992678, -0.56214809, + 0.04970971, 0.36249384, 0.57705611, -1.16913795, -0.69849908, + 0.10957897, 0.27983218, -0.62088525, 0.0410459, 0.23973398, + 0.40960434, -1.14183664, -0.83321381, 0.02149482, 0.21720445, + -0.49869928, -0.95655465, -0.51680422, 0.45761383, -0.08351214, + -0.12151554, 0.00819737, -0.20813803, -0.01055793, 0.25319234, + 0.36154974, 0.1822421, -1.15837133, -0.92209691, -0.0501582, + 0.08535917, -0.54003763, -1.08675635, -1.04009593, 0.09408128, + 0.07009826, -0.01762833, -0.19180447, -0.18029785, -0.20342001, + 0.04034991, 0.1814747, 0.36906669, -1.13532007, -0.8852452, + 0.0782818, 0.16825101, -0.50301319, -0.29128098, -0.65341312, + 0.51484352, -0.38758236, -0.22531103, -0.55021971, 0.10804344, + -0.3521522, -0.38849035, -0.74110794, 0.53761131, -0.25142813, + -0.1118066, -0.47453368, 0.06347904, -0.23796193, -1.02682328, + -0.47594091, 0.39515916, -0.2782529, -0.16566519, 0.08063579, + 0.00810116, -0.06213913, -1.059654, -0.62496334, 0.53698546, + -0.11806234, 0.00356161, 0.11513405, -0.14213292, 0.04102662, + -0.36622161, -0.73686272, 0.48323864, -0.27338892, -0.14203401, + -0.41736352, 0.03332564, -0.21907479, -0.06396769, 0.01831361, + 0.46263444, -1.01878166, -0.86486858, 0.17622118, -0.01249686, + -0.74530888, -0.9354887, -0.5027945, 0.38170099, -0.15547098, + 0.00677824, -0.04677663, -0.13541745, 0.07253501, -0.97933143, + -0.58001202, 0.48235369, -0.18836913, -0.02430783, 0.07572441, + -0.08101331, 0.00630076, -0.16881248, -0.67989182, 0.46083611, + -0.43910736, -0.29321918, -0.38735861, 0.07669903, -0.29749861, + 
-0.40047669, -0.56722462, 0.33168188, -0.13118173, -0.06672747, + -0.56856316, -0.26269144, -0.14236671, 0.10651901, 0.4962585, + 0.38848072, -1.06653547, -0.64079332, -0.47378591, 0.43195483, + -0.04856951, -0.9840439, -0.70610428, 0.34028092, -0.2089237, + -0.05382041, 0.01625874, -0.02080803, -0.12535211, -0.04146428, + -1.24533033, 0.48944879, 0.0578458, 0.26708388, -0.90321028, + 0.35377088, -0.36791429, -0.35382384, -0.52748734, 0.42854419, + -0.31744713, -0.19174226, -0.39073724, -0.03258846, -0.19978228, + -0.36185205, -0.57412046, 0.43681973, -0.25414538, -0.12904905, + -0.46334973, -0.03123853, -0.11303604, -0.87073672, -0.45441297, + 0.41825858, -0.25303507, -0.21845073, 0.10248682, -0.11045569, + -0.10002795, -0.00572806, 0.16519061, 0.42651513, -1.11417019, + -0.83789682, 0.02995787, 0.16843079, -0.53874511, 0.03056994, + 0.17877036, 0.49632853, -1.03276777, -0.74778616, -0.03971953, + 0.10907949, -0.67385727, -0.9523471, -0.56550741, 0.40409449, + -0.2703723, -0.10175014, 0.13605487, -0.06306008, -0.01768126, + -0.4749442, -0.56964815, 0.39389887, -0.19248079, -0.04161081, + -0.38728487, -0.20341556, -0.12656988, -0.35949609, -0.46137866, + 0.28798422, -0.06603147, -0.04363992, -0.60343552, -0.23565227, + -0.10242701, -0.06792886, 0.09689897, 0.33259571, -0.98854214, + -0.84444433, 0.00673901, 0.13457057, -0.43145794, -0.51500046, + -0.50821936, 0.38000089, 0.0132636, 0.0580942, -0.40157595, + -0.11967677, 0.02549113, -0.10350953, 0.22918226, 0.40411913, + -1.05619383, -0.71218503, -0.02197581, 0.26422262, -0.34765676, + 0.06601537, 0.21712676, 0.34723559, -1.20982027, -0.95646334, + 0.00793948, 0.27620381, -0.43475035, -0.67326003, -0.6137197, + 0.43724492, -0.17666136, -0.06591748, -0.18937394, -0.07400128, + -0.06881691, -0.5201112, -0.61088628, 0.4225319, -0.18969463, + -0.06921366, -0.33993208, -0.06990873, -0.10288513, -0.70659858, + -0.56003648, 0.46628812, -0.16090363, -0.0185108, -0.1431348, + -0.1128775, -0.0078648, -0.02323332, 0.04292452, 
0.39291084, + -0.94897962, -0.63863206, -0.16546988, 0.23698957, -0.30633628}; - cudaStream_t stream; - cudaStreamCreate(&stream); + cudaStream_t stream; + cudaStreamCreate(&stream); - allocator.reset(new defaultDeviceAllocator); + allocator.reset(new defaultDeviceAllocator); - float* d_X = (float*)allocator->allocate(X.size() * sizeof(float), stream); - float* d_X_embedded = (float*)allocator->allocate(X_embedded.size() * sizeof(float), stream); + float* d_X = (float*)allocator->allocate(X.size() * sizeof(float), stream); + float* d_X_embedded = + (float*)allocator->allocate(X_embedded.size() * sizeof(float), stream); - updateDevice(d_X, X.data(), X.size(), stream); - updateDevice(d_X_embedded, X_embedded.data(), X_embedded.size(), stream); + updateDevice(d_X, X.data(), X.size(), stream); + updateDevice(d_X_embedded, X_embedded.data(), X_embedded.size(), stream); - // euclidean test - score = trustworthiness_score(d_X, d_X_embedded, 50, 30, 8, 5, allocator, stream); + // euclidean test + score = trustworthiness_score( + d_X, d_X_embedded, 50, 30, 8, 5, allocator, stream); - allocator->deallocate(d_X, X.size() * sizeof(float), stream); - allocator->deallocate(d_X_embedded, X_embedded.size() * sizeof(float), stream); + allocator->deallocate(d_X, X.size() * sizeof(float), stream); + allocator->deallocate(d_X_embedded, X_embedded.size() * sizeof(float), + stream); - cudaStreamDestroy(stream); - } - - void SetUp() override { - basicTest(); - } - - void TearDown() override { - } + cudaStreamDestroy(stream); + } - protected: - double score; - std::shared_ptr allocator; - }; + void SetUp() override { basicTest(); } + void TearDown() override {} - typedef TrustworthinessScoreTest TrustworthinessScoreTestF; - TEST_F(TrustworthinessScoreTestF, Result) { - ASSERT_TRUE(0.9374 < score && score < 0.9376); - } -}; + protected: + double score; + std::shared_ptr allocator; }; +typedef TrustworthinessScoreTest TrustworthinessScoreTestF; +TEST_F(TrustworthinessScoreTestF, Result) { + 
ASSERT_TRUE(0.9374 < score && score < 0.9376); +} +}; // namespace Score +}; // namespace MLCommon diff --git a/cpp/test/sg/rf_test.cu b/cpp/test/sg/rf_test.cu index a582354772..e00bda61de 100644 --- a/cpp/test/sg/rf_test.cu +++ b/cpp/test/sg/rf_test.cu @@ -39,6 +39,7 @@ struct RfInputs { int n_bins; int split_algo; int min_rows_per_node; + CRITERION split_criterion; }; template @@ -47,14 +48,15 @@ template } template -class RfTest : public ::testing::TestWithParam> { +class RfClassifierTest : public ::testing::TestWithParam> { protected: void basicTest() { params = ::testing::TestWithParam>::GetParam(); DecisionTree::DecisionTreeParams tree_params( params.max_depth, params.max_leaves, params.max_features, params.n_bins, - params.split_algo, params.min_rows_per_node, params.bootstrap_features); + params.split_algo, params.min_rows_per_node, params.bootstrap_features, + params.split_criterion, false); RF_params rf_params(params.bootstrap, params.bootstrap_features, params.n_trees, params.rows_sample, tree_params); //rf_params.print(); @@ -66,6 +68,8 @@ class RfTest : public ::testing::TestWithParam> { int data_len = params.n_rows * params.n_cols; allocate(data, data_len); allocate(labels, params.n_rows); + allocate(predicted_labels, params.n_inference_rows); + cudaStream_t stream; CUDA_CHECK(cudaStreamCreate(&stream)); @@ -89,18 +93,21 @@ class RfTest : public ::testing::TestWithParam> { labels_map.size()); CUDA_CHECK(cudaStreamSynchronize(stream)); - CUDA_CHECK(cudaStreamDestroy(stream)); // Inference data: same as train, but row major int inference_data_len = params.n_inference_rows * params.n_cols; inference_data_h = {30.0, 10.0, 1.0, 20.0, 2.0, 10.0, 0.0, 40.0}; inference_data_h.resize(inference_data_len); + allocate(inference_data_d, inference_data_len); + updateDevice(inference_data_d, inference_data_h.data(), data_len, stream); // Predict and compare against known labels - predicted_labels.resize(params.n_inference_rows); - RF_metrics tmp = score(handle, 
rf_classifier, inference_data_h.data(), - labels_h.data(), params.n_inference_rows, - params.n_cols, predicted_labels.data(), false); + RF_metrics tmp = + score(handle, rf_classifier, inference_data_d, labels, + params.n_inference_rows, params.n_cols, predicted_labels, false); + CUDA_CHECK(cudaStreamSynchronize(stream)); + CUDA_CHECK(cudaStreamDestroy(stream)); + accuracy = tmp.accuracy; } @@ -112,16 +119,17 @@ class RfTest : public ::testing::TestWithParam> { inference_data_h.clear(); labels_h.clear(); labels_map.clear(); - predicted_labels.clear(); CUDA_CHECK(cudaFree(labels)); + CUDA_CHECK(cudaFree(predicted_labels)); CUDA_CHECK(cudaFree(data)); + CUDA_CHECK(cudaFree(inference_data_d)); delete rf_classifier; } protected: RfInputs params; - T* data; + T *data, *inference_data_d; int* labels; std::vector inference_data_h; std::vector labels_h; @@ -131,33 +139,153 @@ class RfTest : public ::testing::TestWithParam> { rfClassifier* rf_classifier; float accuracy = -1.0f; // overriden in each test SetUp and TearDown - std::vector predicted_labels; + int* predicted_labels; }; -const std::vector> inputsf2 = { - {4, 2, 1, 1.0f, 1.0f, 4, -1, -1, false, false, 4, SPLIT_ALGO::HIST, - 2}, // single tree forest, bootstrap false, unlimited depth, 4 bins - {4, 2, 1, 1.0f, 1.0f, 4, 8, -1, false, false, 4, SPLIT_ALGO::HIST, - 2}, // single tree forest, bootstrap false, depth of 8, 4 bins - {4, 2, 10, 1.0f, 1.0f, 4, 8, -1, false, false, 4, SPLIT_ALGO::HIST, - 2}, //forest with 10 trees, all trees should produce identical predictions (no bootstrapping or column subsampling) - {4, 2, 10, 0.8f, 0.8f, 4, 8, -1, true, false, 3, SPLIT_ALGO::HIST, - 2}, //forest with 10 trees, with bootstrap and column subsampling enabled, 3 bins - {4, 2, 10, 0.8f, 0.8f, 4, 8, -1, true, false, 3, SPLIT_ALGO::GLOBAL_QUANTILE, - 2} //forest with 10 trees, with bootstrap and column subsampling enabled, 3 bins, different split algorithm 
+//------------------------------------------------------------------------------------------------------------------------------------- + +template +class RfRegressorTest : public ::testing::TestWithParam> { + protected: + void basicTest() { + params = ::testing::TestWithParam>::GetParam(); + + DecisionTree::DecisionTreeParams tree_params( + params.max_depth, params.max_leaves, params.max_features, params.n_bins, + params.split_algo, params.min_rows_per_node, params.bootstrap_features, + params.split_criterion, false); + RF_params rf_params(params.bootstrap, params.bootstrap_features, + params.n_trees, params.rows_sample, tree_params); + //rf_params.print(); + + //-------------------------------------------------------- + // Random Forest + //-------------------------------------------------------- + + int data_len = params.n_rows * params.n_cols; + allocate(data, data_len); + allocate(labels, params.n_rows); + allocate(predicted_labels, params.n_inference_rows); + cudaStream_t stream; + CUDA_CHECK(cudaStreamCreate(&stream)); + + // Populate data (assume Col major) + std::vector data_h = {0.0, 0.0, 0.0, 0.0, 10.0, 20.0, 30.0, 40.0}; + data_h.resize(data_len); + updateDevice(data, data_h.data(), data_len, stream); + + // Populate labels + labels_h = {1.0, 2.0, 3.0, 4.0}; + labels_h.resize(params.n_rows); + updateDevice(labels, labels_h.data(), params.n_rows, stream); + + rf_regressor = new typename rfRegressor::rfRegressor(rf_params); + + cumlHandle handle; + handle.setStream(stream); + + fit(handle, rf_regressor, data, params.n_rows, params.n_cols, labels); + + CUDA_CHECK(cudaStreamSynchronize(stream)); + + // Inference data: same as train, but row major + int inference_data_len = params.n_inference_rows * params.n_cols; + inference_data_h = {0.0, 10.0, 0.0, 20.0, 0.0, 30.0, 0.0, 40.0}; + inference_data_h.resize(inference_data_len); + allocate(inference_data_d, inference_data_len); + updateDevice(inference_data_d, inference_data_h.data(), data_len, stream); + + // 
Predict and compare against known labels + RF_metrics tmp = + score(handle, rf_regressor, inference_data_d, labels, + params.n_inference_rows, params.n_cols, predicted_labels, false); + CUDA_CHECK(cudaStreamSynchronize(stream)); + CUDA_CHECK(cudaStreamDestroy(stream)); + + mse = tmp.mean_squared_error; + } + + void SetUp() override { basicTest(); } + + void TearDown() override { + mse = -1.0f; // reset mse + inference_data_h.clear(); + labels_h.clear(); + + CUDA_CHECK(cudaFree(labels)); + CUDA_CHECK(cudaFree(predicted_labels)); + CUDA_CHECK(cudaFree(data)); + CUDA_CHECK(cudaFree(inference_data_d)); + delete rf_regressor; + } + + protected: + RfInputs params; + T *data, *inference_data_d; + T* labels; + std::vector inference_data_h; + std::vector labels_h; + + rfRegressor* rf_regressor; + float mse = -1.0f; // overriden in each test SetUp and TearDown + + T* predicted_labels; }; +//------------------------------------------------------------------------------------------------------------------------------------- -const std::vector> inputsd2 = { // Same as inputsf2 - {4, 2, 1, 1.0f, 1.0f, 4, -1, -1, false, false, 4, SPLIT_ALGO::HIST, 2}, - {4, 2, 1, 1.0f, 1.0f, 4, 8, -1, false, false, 4, SPLIT_ALGO::HIST, 2}, - {4, 2, 10, 1.0f, 1.0f, 4, 8, -1, false, false, 4, SPLIT_ALGO::HIST, 2}, - {4, 2, 10, 0.8f, 0.8f, 4, 8, -1, true, false, 3, SPLIT_ALGO::HIST, 2}, +const std::vector> inputsf2_clf = { + {4, 2, 1, 1.0f, 1.0f, 4, -1, -1, false, false, 4, SPLIT_ALGO::HIST, 2, + CRITERION:: + GINI}, // single tree forest, bootstrap false, unlimited depth, 4 bins + {4, 2, 1, 1.0f, 1.0f, 4, 8, -1, false, false, 4, SPLIT_ALGO::HIST, 2, + CRITERION::GINI}, // single tree forest, bootstrap false, depth of 8, 4 bins + {4, 2, 10, 1.0f, 1.0f, 4, 8, -1, false, false, 4, SPLIT_ALGO::HIST, 2, + CRITERION:: + GINI}, //forest with 10 trees, all trees should produce identical predictions (no bootstrapping or column subsampling) + {4, 2, 10, 0.8f, 0.8f, 4, 8, -1, true, false, 3, SPLIT_ALGO::HIST, 
2, + CRITERION:: + GINI}, //forest with 10 trees, with bootstrap and column subsampling enabled, 3 bins + {4, 2, 10, 0.8f, 0.8f, 4, 8, -1, true, false, 3, SPLIT_ALGO::GLOBAL_QUANTILE, + 2, + CRITERION:: + CRITERION_END}, //forest with 10 trees, with bootstrap and column subsampling enabled, 3 bins, different split algorithm + {4, 2, 1, 1.0f, 1.0f, 4, -1, -1, false, false, 4, SPLIT_ALGO::HIST, 2, + CRITERION::ENTROPY}, + {4, 2, 1, 1.0f, 1.0f, 4, 8, -1, false, false, 4, SPLIT_ALGO::HIST, 2, + CRITERION::ENTROPY}, + {4, 2, 10, 1.0f, 1.0f, 4, 8, -1, false, false, 4, SPLIT_ALGO::HIST, 2, + CRITERION::ENTROPY}, + {4, 2, 10, 0.8f, 0.8f, 4, 8, -1, true, false, 3, SPLIT_ALGO::HIST, 2, + CRITERION::ENTROPY}, {4, 2, 10, 0.8f, 0.8f, 4, 8, -1, true, false, 3, SPLIT_ALGO::GLOBAL_QUANTILE, - 2}}; + 2, CRITERION::ENTROPY}}; -typedef RfTest RfTestF; -TEST_P(RfTestF, Fit) { - //rf_classifier->print_rf_detailed(); // Prints all trees in the forest. Leaf nodes use the remapped values from labels_map. +const std::vector> inputsd2_clf = { // Same as inputsf2_clf + {4, 2, 1, 1.0f, 1.0f, 4, -1, -1, false, false, 4, SPLIT_ALGO::HIST, 2, + CRITERION::GINI}, + {4, 2, 1, 1.0f, 1.0f, 4, 8, -1, false, false, 4, SPLIT_ALGO::HIST, 2, + CRITERION::GINI}, + {4, 2, 10, 1.0f, 1.0f, 4, 8, -1, false, false, 4, SPLIT_ALGO::HIST, 2, + CRITERION::GINI}, + {4, 2, 10, 0.8f, 0.8f, 4, 8, -1, true, false, 3, SPLIT_ALGO::HIST, 2, + CRITERION::GINI}, + {4, 2, 10, 0.8f, 0.8f, 4, 8, -1, true, false, 3, SPLIT_ALGO::GLOBAL_QUANTILE, + 2, CRITERION::CRITERION_END}, + {4, 2, 1, 1.0f, 1.0f, 4, -1, -1, false, false, 4, SPLIT_ALGO::HIST, 2, + CRITERION::ENTROPY}, + {4, 2, 1, 1.0f, 1.0f, 4, 8, -1, false, false, 4, SPLIT_ALGO::HIST, 2, + CRITERION::ENTROPY}, + {4, 2, 10, 1.0f, 1.0f, 4, 8, -1, false, false, 4, SPLIT_ALGO::HIST, 2, + CRITERION::ENTROPY}, + {4, 2, 10, 0.8f, 0.8f, 4, 8, -1, true, false, 3, SPLIT_ALGO::HIST, 2, + CRITERION::ENTROPY}, + {4, 2, 10, 0.8f, 0.8f, 4, 8, -1, true, false, 3, 
SPLIT_ALGO::GLOBAL_QUANTILE, + 2, CRITERION::ENTROPY}}; + +typedef RfClassifierTest RfClassifierTestF; +TEST_P(RfClassifierTestF, Fit) { + //rf_classifier + // ->print_rf_detailed(); // Prints all trees in the forest. Leaf nodes use the remapped values from labels_map. if (!params.bootstrap && (params.max_features == 1.0f)) { ASSERT_TRUE(accuracy == 1.0f); } else { @@ -165,8 +293,8 @@ TEST_P(RfTestF, Fit) { } } -typedef RfTest RfTestD; -TEST_P(RfTestD, Fit) { +typedef RfClassifierTest RfClassifierTestD; +TEST_P(RfClassifierTestD, Fit) { if (!params.bootstrap && (params.max_features == 1.0f)) { ASSERT_TRUE(accuracy == 1.0f); } else { @@ -174,8 +302,63 @@ TEST_P(RfTestD, Fit) { } } -INSTANTIATE_TEST_CASE_P(RfTests, RfTestF, ::testing::ValuesIn(inputsf2)); +INSTANTIATE_TEST_CASE_P(RfClassifierTests, RfClassifierTestF, + ::testing::ValuesIn(inputsf2_clf)); + +INSTANTIATE_TEST_CASE_P(RfClassifierTests, RfClassifierTestD, + ::testing::ValuesIn(inputsd2_clf)); + +typedef RfRegressorTest RfRegressorTestF; +TEST_P(RfRegressorTestF, Fit) { + //rf_regressor->print_rf_detailed(); // Prints all trees in the forest. 
+ if (!params.bootstrap && (params.max_features == 1.0f)) { + ASSERT_TRUE(mse == 0.0f); + } else { + ASSERT_TRUE(mse <= 0.2f); + } +} + +typedef RfRegressorTest RfRegressorTestD; +TEST_P(RfRegressorTestD, Fit) { + if (!params.bootstrap && (params.max_features == 1.0f)) { + ASSERT_TRUE(mse == 0.0f); + } else { + ASSERT_TRUE(mse <= 0.2f); + } +} + +const std::vector> inputsf2_reg = { + {4, 2, 1, 1.0f, 1.0f, 4, -1, -1, false, false, 4, SPLIT_ALGO::HIST, 2, + CRITERION::MSE}, + {4, 2, 1, 1.0f, 1.0f, 4, 8, -1, false, false, 4, SPLIT_ALGO::HIST, 2, + CRITERION::MSE}, + {4, 2, 5, 1.0f, 1.0f, 4, 8, -1, false, false, 4, SPLIT_ALGO::HIST, 2, + CRITERION:: + CRITERION_END}, // CRITERION_END uses the default criterion (GINI for classification, MSE for regression) + {4, 2, 1, 1.0f, 1.0f, 4, -1, -1, false, false, 4, SPLIT_ALGO::HIST, 2, + CRITERION::MAE}, + {4, 2, 1, 1.0f, 1.0f, 4, 8, -1, false, false, 4, SPLIT_ALGO::GLOBAL_QUANTILE, + 2, CRITERION::MAE}, + {4, 2, 5, 1.0f, 1.0f, 4, 8, -1, true, false, 4, SPLIT_ALGO::HIST, 2, + CRITERION::CRITERION_END}}; + +const std::vector> inputsd2_reg = { // Same as inputsf2_reg + {4, 2, 1, 1.0f, 1.0f, 4, -1, -1, false, false, 4, SPLIT_ALGO::HIST, 2, + CRITERION::MSE}, + {4, 2, 1, 1.0f, 1.0f, 4, 8, -1, false, false, 4, SPLIT_ALGO::HIST, 2, + CRITERION::MSE}, + {4, 2, 5, 1.0f, 1.0f, 4, 8, -1, false, false, 4, SPLIT_ALGO::HIST, 2, + CRITERION::CRITERION_END}, + {4, 2, 1, 1.0f, 1.0f, 4, -1, -1, false, false, 4, SPLIT_ALGO::HIST, 2, + CRITERION::MAE}, + {4, 2, 1, 1.0f, 1.0f, 4, 8, -1, false, false, 4, SPLIT_ALGO::GLOBAL_QUANTILE, + 2, CRITERION::MAE}, + {4, 2, 5, 1.0f, 1.0f, 4, 8, -1, true, false, 4, SPLIT_ALGO::HIST, 2, + CRITERION::CRITERION_END}}; -INSTANTIATE_TEST_CASE_P(RfTests, RfTestD, ::testing::ValuesIn(inputsd2)); +INSTANTIATE_TEST_CASE_P(RfRegressorTests, RfRegressorTestF, + ::testing::ValuesIn(inputsf2_reg)); +INSTANTIATE_TEST_CASE_P(RfRegressorTests, RfRegressorTestD, + ::testing::ValuesIn(inputsd2_reg)); } // end namespace ML diff 
--git a/python/cuml/ensemble/randomforest.pyx b/python/cuml/ensemble/randomforest.pyx index bb1fbf9bd4..f7f71c0328 100644 --- a/python/cuml/ensemble/randomforest.pyx +++ b/python/cuml/ensemble/randomforest.pyx @@ -45,6 +45,13 @@ cdef extern from "randomforest/randomforest.h" namespace "ML": CLASSIFICATION, REGRESSION + cdef enum CRITERION: + GINI, + ENTROPY, + MSE, + MAE, + CRITERION_END + cdef struct RF_params: pass @@ -95,7 +102,8 @@ cdef extern from "randomforest/randomforest.h" namespace "ML": cdef RF_params set_rf_class_obj(int, int, float, int, int, int, - bool, bool, int, float) except + + bool, bool, int, float, CRITERION, + bool) except + cdef class RandomForest_impl(): @@ -109,6 +117,7 @@ cdef class RandomForest_impl(): cdef object max_features cdef object n_bins cdef object split_algo + cdef object split_criterion cdef object min_rows_per_node cdef object bootstrap cdef object bootstrap_features @@ -117,20 +126,22 @@ cdef class RandomForest_impl(): cdef object n_cols cdef object rows_sample cdef object max_leaves + cdef object quantile_per_tree cdef object gdf_datatype cdef object stats cdef object dtype def __cinit__(self, n_estimators=10, max_depth=-1, handle=None, max_features=1.0, n_bins=8, - split_algo=0, min_rows_per_node=2, + split_algo=0, split_criterion=0, min_rows_per_node=2, bootstrap=True, bootstrap_features=False, type_model="classifier", verbose=False, - rows_sample=1.0, max_leaves=-1, + rows_sample=1.0, max_leaves=-1, quantile_per_tree=False, gdf_datatype=None): self.handle = handle self.split_algo = split_algo + self.split_criterion = split_criterion self.min_rows_per_node = min_rows_per_node self.bootstrap_features = bootstrap_features self.rows_sample = rows_sample @@ -139,6 +150,7 @@ cdef class RandomForest_impl(): self.max_depth = max_depth self.max_features = max_features self.type_model = self._get_type(type_model) + self.quantile_per_tree = quantile_per_tree self.bootstrap = bootstrap self.verbose = verbose self.n_bins = n_bins @@ 
-205,7 +217,9 @@ cdef class RandomForest_impl(): self.bootstrap_features, self.bootstrap, self.n_estimators, - self.rows_sample) + self.rows_sample, + self.split_criterion, + self.quantile_per_tree) self.rf_classifier32 = new \ rfClassifier[float](rf_param) @@ -239,8 +253,9 @@ cdef class RandomForest_impl(): def predict(self, X): cdef uintptr_t X_ptr - X_ptr = X.ctypes.data - n_rows, n_cols = np.shape(X) + # row major format + X_m, X_ptr, n_rows, n_cols, _ = \ + input_to_dev_array(X, order='C') if n_cols != self.n_cols: raise ValueError("The number of columns/features in the training" " and test data should be the same ") @@ -248,10 +263,10 @@ cdef class RandomForest_impl(): raise ValueError("The datatype of the training data is different" " from the datatype of the testing data") - preds = np.zeros(n_rows, - dtype=np.int32) - - cdef uintptr_t preds_ptr = preds.ctypes.data + preds = np.zeros(n_rows, dtype=np.int32) + cdef uintptr_t preds_ptr + preds_m, preds_ptr, _, _, _ = \ + input_to_dev_array(preds) cdef cumlHandle* handle_ =\ self.handle.getHandle() @@ -279,14 +294,18 @@ cdef class RandomForest_impl(): % (str(self.dtype))) self.handle.sync() + # synchronous w/o a stream + preds = preds_m.copy_to_host() + del(X_m) + del(preds_m) return preds def score(self, X, y): cdef uintptr_t X_ptr, y_ptr - X_ptr = X.ctypes.data - y_ptr = y.ctypes.data - n_rows, n_cols = np.shape(X) + X_m, X_ptr, n_rows, n_cols, _ = \ + input_to_dev_array(X, order='C') + y_m, y_ptr, _, _, _ = input_to_dev_array(y) if n_cols != self.n_cols: raise ValueError("The number of columns/features in the training" @@ -300,8 +319,9 @@ cdef class RandomForest_impl(): preds = np.zeros(n_rows, dtype=np.int32) - - cdef uintptr_t preds_ptr = (preds).ctypes.data + cdef uintptr_t preds_ptr + preds_m, preds_ptr, _, _, _ = \ + input_to_dev_array(preds) cdef cumlHandle* handle_ =\ self.handle.getHandle() @@ -327,6 +347,9 @@ cdef class RandomForest_impl(): self.verbose) self.handle.sync() + del(X_m) + del(y_m) + 
del(preds_m) return self.stats @@ -391,9 +414,17 @@ class RandomForestClassifier(Base): number of trees in the forest. handle : cuml.Handle If it is None, a new one is created just for this class. + split_criterion: The criterion used to split nodes. + 0 for GINI, 1 for ENTROPY, 4 for CRITERION_END. + 2 and 3 not valid for classification + (default = 0) split_algo : 0 for HIST and 1 for GLOBAL_QUANTILE (default = 0) the algorithm to determine how nodes are split in the tree. + split_criterion: The criterion used to split nodes. + 0 for GINI, 1 for ENTROPY, 4 for CRITERION_END. + 2 and 3 not valid for classification + (default = 0) bootstrap : boolean (default = True) Control bootstrapping. If set, each tree in the forest is built @@ -418,22 +449,25 @@ class RandomForestClassifier(Base): min_rows_per_node : int (default = 2) The minimum number of samples (rows) needed to split a node. + quantile_per_tree : boolean (default = False) + Whether quantile is computed for individual trees in RF. + Only relevant for GLOBAL_QUANTILE split_algo. 
""" variables = ['n_estimators', 'max_depth', 'handle', 'max_features', 'n_bins', - 'split_algo', 'min_rows_per_node', + 'split_algo', 'split_criterion', 'min_rows_per_node', 'bootstrap', 'bootstrap_features', 'verbose', 'rows_sample', - 'max_leaves'] + 'max_leaves', 'quantile_per_tree'] def __init__(self, n_estimators=10, max_depth=-1, handle=None, max_features=1.0, n_bins=8, - split_algo=0, min_rows_per_node=2, + split_algo=0, split_criterion=0, min_rows_per_node=2, bootstrap=True, bootstrap_features=False, type_model="classifier", verbose=False, - rows_sample=1.0, max_leaves=-1, + rows_sample=1.0, max_leaves=-1, quantile_per_tree=False, gdf_datatype=None, criterion=None, min_samples_leaf=None, min_weight_fraction_leaf=None, max_leaf_nodes=None, min_impurity_decrease=None, @@ -461,6 +495,7 @@ class RandomForestClassifier(Base): super(RandomForestClassifier, self).__init__(handle, verbose) self.split_algo = split_algo + self.split_criterion = split_criterion self.min_rows_per_node = min_rows_per_node self.bootstrap_features = bootstrap_features self.rows_sample = rows_sample @@ -472,13 +507,16 @@ class RandomForestClassifier(Base): self.verbose = verbose self.n_bins = n_bins self.n_cols = None + self.quantile_per_tree = quantile_per_tree self._impl = RandomForest_impl(n_estimators, max_depth, self.handle, max_features, n_bins, - split_algo, min_rows_per_node, + split_algo, split_criterion, + min_rows_per_node, bootstrap, bootstrap_features, type_model, verbose, rows_sample, max_leaves, + quantile_per_tree, gdf_datatype) def fit(self, X, y): diff --git a/python/cuml/test/test_random_forest.py b/python/cuml/test/test_random_forest.py index d25f9c3049..53f3005270 100644 --- a/python/cuml/test/test_random_forest.py +++ b/python/cuml/test/test_random_forest.py @@ -60,8 +60,10 @@ def test_rf_predict_numpy(datatype, use_handle, split_algo, # Initialize, fit and predict using cuML's # random forest classification model cuml_model = curfc(max_features=1.0, - n_bins=8, 
split_algo=split_algo, min_rows_per_node=2, - n_estimators=30, handle=handle, max_leaves=-1) + n_bins=8, split_algo=0, split_criterion=0, + min_rows_per_node=2, + n_estimators=40, handle=handle, max_leaves=-1, + max_depth=-1) cuml_model.fit(X_train, y_train) cu_predict = cuml_model.predict(X_test) cu_acc = accuracy_score(y_test, cu_predict)