[Review] Random Forest & Decision Tree Regression + major updates to Classification #635

Merged Jul 1, 2019 · 58 commits (changes shown from 46 commits)

Commits
e333239
Added rfRegressor class to random forest.
myrtw May 7, 2019
af36570
Added base dt class and DecisionTreeRegressor
myrtw May 7, 2019
545582d
More class updates.
myrtw May 8, 2019
8c5dd25
TreeNode updates & more dt class changes.
myrtw May 8, 2019
b9c5c6f
added regression kernels, modified naming convention to metric quest…
vishalmehta1991 May 8, 2019
04da29a
More decision tree changes
myrtw May 8, 2019
14fb0e0
added kernels for mean squared error
vishalmehta1991 May 9, 2019
feaac6d
added all regression code / kernels, now compiles, next step is testing
vishalmehta1991 May 9, 2019
09eb479
Code in flux. Regression related changes.
myrtw May 9, 2019
1b2b8d3
fixed right mse, it needs to be computed in kernel
vishalmehta1991 May 10, 2019
a6457fc
Added support for MSE or MAE split criterion.
myrtw May 13, 2019
d8176b7
Fixed split_criterion config in rf_test.
myrtw May 13, 2019
229bd03
relocating functors, adding inline to device functors, adding entropy…
vishalmehta1991 May 13, 2019
f2a8336
Removed useless mem alloc+added tmp testing script
myrtw May 13, 2019
76d2e7e
added iota and permute on GPU using thrust and ml-prims
vishalmehta1991 May 14, 2019
9a5f32c
Preprocess quantiles in batches.
myrtw May 14, 2019
b6ec0fb
merged new cuml dir structure
vishalmehta1991 May 15, 2019
b472bbd
Swapped cudaMemcpy w/ updateDevice/Host/Async.
myrtw May 15, 2019
d6027a9
Made rowids and colids unsigned int.
myrtw May 16, 2019
29ac9ee
now using minmax primitive with column sampler
vishalmehta1991 May 16, 2019
8c5ed32
deleted col_minmax kernel now using ml-prims
vishalmehta1991 May 16, 2019
280bce6
adding missing stream in cub
vishalmehta1991 May 17, 2019
8648363
Reordered call to find_best_fruit_all function.
myrtw May 21, 2019
d48e9c2
Fixed nbins bug for GLOBAL_QUANTILE.
myrtw May 27, 2019
21d30e9
Changelog update.
myrtw May 27, 2019
6d374b3
removing unused function parameters and some comments
vishalmehta1991 May 28, 2019
507601d
adding label sampler in tree build
vishalmehta1991 May 29, 2019
76a5d3b
name change for gini/mse metric files
vishalmehta1991 May 29, 2019
c5abd26
Renamed dt class; sorted rf rowIDs; del MemGetInfo
myrtw May 30, 2019
35f1b59
Moved plant, grow_tree methods to DecisionTreeBase class.
myrtw May 30, 2019
f8cb923
added depth zero helper function
vishalmehta1991 Jun 5, 2019
8417d37
Copied RF predictions on device. Added metrics too.
myrtw Jun 6, 2019
895ccf6
moving allocations outside the loop
vishalmehta1991 Jun 6, 2019
5250be8
Added helper for tree fit.
myrtw Jun 6, 2019
e67b1f1
Merge remote-tracking branch 'origin/branch-0.8' into fea-ext-randomf…
vishalmehta1991 Jun 7, 2019
1dd63f4
Made RF's data input for predictions a GPU ptr.
myrtw Jun 7, 2019
babb624
Added unit-tests for Accuracy score.
myrtw Jun 12, 2019
5f4a5fd
added support for when large number of features histograms do not fin…
vishalmehta1991 Jun 13, 2019
ae744ad
fixed missing plus sign
vishalmehta1991 Jun 13, 2019
fae84fe
Added unit-tests for regression metrics
myrtw Jun 13, 2019
dd7fa2d
Merge branch 'fea-ext-randomforest_regression' of github.com:vishalme…
vishalmehta1991 Jun 13, 2019
6f604cc
blocks max limit for classifier
vishalmehta1991 Jun 13, 2019
18ef5b2
Added support for wider datasets for minmax prim.
myrtw Jun 14, 2019
4ab84a2
loop around regressor kernels for large number of features
vishalmehta1991 Jun 14, 2019
83d4dfb
Merge branch 'fea-ext-randomforest_regression' of github.com:vishalme…
vishalmehta1991 Jun 14, 2019
f79f8fa
Minor kernel fix + helper function.
myrtw Jun 14, 2019
e677e55
Merge branch 'branch-0.8' of github.com:rapidsai/cuml into fea-ext-ra…
myrtw Jun 18, 2019
8d326b3
Python related updates to rf/dt, randomforest.pyx
myrtw Jun 18, 2019
b588047
WIP python wrapper changes.
myrtw Jun 18, 2019
1181890
Fixed Python wrapper
myrtw Jun 19, 2019
5907085
Merge branch 'branch-0.8' of github.com:rapidsai/cuml into fea-ext-ra…
myrtw Jun 19, 2019
51017ac
adding host quantile data structure and removing host mem copies for n…
vishalmehta1991 Jun 24, 2019
10473d3
quantile fix; once per RF; quantile per tree flag; default false; tem…
vishalmehta1991 Jun 26, 2019
95eb5f1
critical fix: entropy functor cannot have log(0)
vishalmehta1991 Jun 26, 2019
9ad7e0f
Merge branch 'branch-0.9' of github.com:rapidsai/cuml into fea-ext-ra…
myrtw Jun 27, 2019
db787ef
Python style fix.
myrtw Jun 28, 2019
ffb80e2
using proper split criterion in rf test, fixing max value range for e…
vishalmehta1991 Jul 1, 2019
c6c5a5e
changing ChangeLog entry to branch-0.9
vishalmehta1991 Jul 1, 2019
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -3,6 +3,7 @@
## New Features
- PR #515: Added Random Projection feature
- PR #504: Contingency matrix ml-prim
- PR #635: Random Forest & Decision Tree Regression (Single-GPU)

## Improvements

5 changes: 5 additions & 0 deletions cpp/src/decisiontree/algo_helper.h
@@ -21,4 +21,9 @@ namespace ML {
enum SPLIT_ALGO {
HIST, GLOBAL_QUANTILE, SPLIT_ALGO_END,
};

enum CRITERION {
GINI, ENTROPY, MSE, MAE, CRITERION_END,
};

};
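
The new CRITERION values are consumed through the split_criterion field added to DecisionTreeParams in the decisiontree.h diff below. A minimal sketch of selecting a criterion, assuming those definitions (include paths are assumptions):

#include "decisiontree/algo_helper.h"   // include path assumed
#include "decisiontree/decisiontree.h"  // include path assumed

// GINI and ENTROPY apply to classification, MSE and MAE to regression
// (per the split_criterion comment in the header diff below);
// split_criterion defaults to CRITERION_END in DecisionTreeParams.
ML::DecisionTree::DecisionTreeParams make_regression_params() {
  ML::DecisionTree::DecisionTreeParams params;   // library defaults
  params.split_criterion = ML::CRITERION::MSE;   // or ML::CRITERION::MAE
  return params;
}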
509 changes: 355 additions & 154 deletions cpp/src/decisiontree/decisiontree.cu

Large diffs are not rendered by default.

146 changes: 95 additions & 51 deletions cpp/src/decisiontree/decisiontree.h
@@ -16,8 +16,7 @@

#pragma once
#include "algo_helper.h"
#include "kernels/gini_def.h"
#include "memory.cuh"
#include "kernels/metric_def.h"
#include <common/Timer.h>
#include <vector>
#include <algorithm>
@@ -27,22 +26,25 @@
#include <common/cumlHandle.hpp>

namespace ML {

bool is_dev_ptr(const void *p);

namespace DecisionTree {

template<class T>
struct Question {
int column;
T value;
void update(const GiniQuestion<T> & ques);
void update(const MetricQuestion<T> & ques);
};

template<class T>
template<class T, class L>
struct TreeNode {
TreeNode *left = nullptr;
TreeNode *right = nullptr;
int class_predict;
L prediction;
Question<T> question;
T gini_val;
T split_metric_val;

void print(std::ostream& os) const;
};
@@ -81,75 +83,104 @@ struct DecisionTreeParams {
*/
int min_rows_per_node = 2;
/**
* Wheather to bootstarp columns with or without replacement
* Whether to bootstrap columns with or without replacement.
*/
bool bootstrap_features = false;

/**
* Node split criterion. GINI and Entropy for classification, MSE or MAE for regression.
*/
CRITERION split_criterion = CRITERION_END;

DecisionTreeParams();
DecisionTreeParams(int cfg_max_depth, int cfg_max_leaves, float cfg_max_features, int cfg_n_bins, int cfg_split_algo, int cfg_min_rows_per_node, bool cfg_bootstrap_features);
DecisionTreeParams(int cfg_max_depth, int cfg_max_leaves, float cfg_max_features, int cfg_n_bins, int cfg_split_algo, int cfg_min_rows_per_node, bool cfg_bootstrap_features, CRITERION cfg_split_criterion);
void validity_check() const;
void print() const;
};

template<class T>
class DecisionTreeClassifier {
template<class T, class L>
class DecisionTreeBase {
protected:
int split_algo;
TreeNode<T, L> *root = nullptr;
int nbins;
DataInfo dinfo;
int treedepth;
int depth_counter = 0;
int maxleaves;
int leaf_counter = 0;
std::vector<std::shared_ptr<TemporaryMemory<T, L>>> tempmem;
size_t total_temp_mem;
const int MAXSTREAMS = 1;
size_t max_shared_mem;
size_t shmem_used = 0;
int n_unique_labels = -1; // number of unique labels in dataset
double construct_time;
int min_rows_per_node;
bool bootstrap_features;
CRITERION split_criterion;
std::vector<unsigned int> feature_selector;

void print_node(const std::string& prefix, const TreeNode<T, L>* const node, bool isLeft) const;
void split_branch(T *data, MetricQuestion<T> & ques, const int n_sampled_rows, int& nrowsleft, int& nrowsright, unsigned int* rowids);

void plant(const cumlHandle_impl& handle, T *data, const int ncols, const int nrows, L *labels, unsigned int *rowids, const int n_sampled_rows, int unique_labels,
int maxdepth = -1, int max_leaf_nodes = -1, const float colper = 1.0, int n_bins = 8, int split_algo_flag = SPLIT_ALGO::HIST, int cfg_min_rows_per_node=2,
bool cfg_bootstrap_features=false, CRITERION cfg_split_criterion=CRITERION::CRITERION_END);
void init_depth_zero(const L* labels, std::vector<unsigned int>& colselector, const unsigned int* rowids, const int n_sampled_rows, const std::shared_ptr<TemporaryMemory<T,L>> tempmem);
TreeNode<T, L> * grow_tree(T *data, const float colper, L *labels, int depth, unsigned int* rowids, const int n_sampled_rows, MetricInfo<T> prev_split_info);
virtual void find_best_fruit_all(T *data, L *labels, const float colper, MetricQuestion<T> & ques, float& gain, unsigned int* rowids,
const int n_sampled_rows, MetricInfo<T> split_info[3], int depth) = 0;
void base_fit(const ML::cumlHandle& handle, T *data, const int ncols, const int nrows, L *labels, unsigned int *rowids,
const int n_sampled_rows, int unique_labels, DecisionTreeParams & tree_params, bool is_classifier);

public:
// Printing utility for high level tree info.
void print_tree_summary() const;

// Printing utility for debug and looking at nodes and leaves.
void print() const;

// Predict labels for n_rows rows, with n_cols features each, for a given tree. rows in row-major format.
void predict(const ML::cumlHandle& handle, const T * rows, const int n_rows, const int n_cols, L * predictions, bool verbose=false) const;
void predict_all(const T * rows, const int n_rows, const int n_cols, L * preds, bool verbose=false) const;
L predict_one(const T * row, const TreeNode<T, L> * const node, bool verbose=false) const;

}; // End DecisionTreeBase Class

private:
int split_algo;
TreeNode<T> *root = nullptr;
int nbins;
DataInfo dinfo;
int treedepth;
int depth_counter = 0;
int maxleaves;
int leaf_counter = 0;
std::vector<std::shared_ptr<TemporaryMemory<T>>> tempmem;
size_t total_temp_mem;
const int MAXSTREAMS = 1;
size_t max_shared_mem;
size_t shmem_used = 0;
int n_unique_labels = -1; // number of unique labels in dataset
double construct_time;
int min_rows_per_node;
bool bootstrap_features;
std::vector<int> feature_selector;

template<class T>
class DecisionTreeClassifier : public DecisionTreeBase<T, int> {
public:
// Expects column major T dataset, integer labels
// data, labels are both device ptr.
// Assumption: labels are all mapped to contiguous numbers starting from 0 during preprocessing. Needed for gini hist impl.
void fit(const ML::cumlHandle& handle, T *data, const int ncols, const int nrows, int *labels, unsigned int *rowids,
const int n_sampled_rows, const int unique_labels, DecisionTreeParams tree_params);

/* Predict labels for n_rows rows, with n_cols features each, for a given tree. rows in row-major format. */
void predict(const ML::cumlHandle& handle, const T * rows, const int n_rows, const int n_cols, int* predictions, bool verbose=false) const;

// Printing utility for high level tree info.
void print_tree_summary() const;

// Printing utility for debug and looking at nodes and leaves.
void print() const;

private:
// Same as above fit, but planting is better for a tree then fitting.
void plant(const cumlHandle_impl& handle, T *data, const int ncols, const int nrows, int *labels, unsigned int *rowids, const int n_sampled_rows, int unique_labels,
int maxdepth = -1, int max_leaf_nodes = -1, const float colper = 1.0, int n_bins = 8, int split_algo_flag = SPLIT_ALGO::HIST, int cfg_min_rows_per_node=2, bool cfg_bootstrap_features=false);
/* depth is used to distinguish between root and other tree nodes for computations */
void find_best_fruit_all(T *data, int *labels, const float colper, MetricQuestion<T> & ques, float& gain, unsigned int* rowids,
const int n_sampled_rows, MetricInfo<T> split_info[3], int depth);
}; // End DecisionTreeClassifier Class

TreeNode<T> * grow_tree(T *data, const float colper, int *labels, int depth, unsigned int* rowids, const int n_sampled_rows, GiniInfo prev_split_info);
template<class T>
class DecisionTreeRegressor : public DecisionTreeBase<T, T> {
public:
void fit(const ML::cumlHandle& handle, T *data, const int ncols, const int nrows, T *labels, unsigned int *rowids,
const int n_sampled_rows, DecisionTreeParams tree_params);

private:
/* depth is used to distinguish between root and other tree nodes for computations */
void find_best_fruit_all(T *data, int *labels, const float colper, GiniQuestion<T> & ques, float& gain, unsigned int* rowids,
const int n_sampled_rows, GiniInfo split_info[3], int depth);
void split_branch(T *data, GiniQuestion<T> & ques, const int n_sampled_rows, int& nrowsleft, int& nrowsright, unsigned int* rowids);
void classify_all(const T * rows, const int n_rows, const int n_cols, int* preds, bool verbose=false) const;
int classify(const T * row, const TreeNode<T> * const node, bool verbose=false) const;
void print_node(const std::string& prefix, const TreeNode<T>* const node, bool isLeft) const;
}; // End DecisionTree Class
void find_best_fruit_all(T *data, T *labels, const float colper, MetricQuestion<T> & ques, float& gain, unsigned int* rowids,
const int n_sampled_rows, MetricInfo<T> split_info[3], int depth);
}; // End DecisionTreeRegressor Class

} //End namespace DecisionTree


// Stateless API functions

Member:
There is one very interesting issue that just arose while @Salonijain27 was writing cython classes for Random Forest: these are not stateless APIs! These functions seem to just be an alternative way of calling the DecisionTreeClassifier (or equivalent RF class) fit, but the object is still stateful! A stateless API means that the state is kept by the client (Python) and the C++ code does not keep the state; in this case the C++ object still keeps the state in spite of there being a fit function outside of it.

This not only complicates the cython work (not a deal breaker; it can be made to work) but also potentially complicates the pickling (model saving) functionality. For all the algorithms that are stateless (most of cuML) we get that functionality automatically, since the Python objects are the state of the algorithm. Wanted to start a quick conversation regarding this.

Member:
(Of course feel free to correct any of this if the analysis/conclusions above are wrong!)

Member:
Hmm.. that's interesting. IIUC, UMAP/kNN both expose such "stateful" classes, no? How are they being handled in cython world?

@cjnolet ^^^

Member:
Yes they do, and that is the reason UMAP and KNN are the two classes that cannot currently be pickled (this was raised only a few weeks ago when a user ran into that issue). There is even an open issue #415 to create a stateless version of kNN.

Member:
Also, this is not necessarily something that needs to change per se, though it could make things much easier. We can look into using the __getstate__ and __setstate__ Python functions to implement the pickling. I will be looking into this today and the first days of next week. But even then, the question of whether we want uniformity in the statelessness (or statefulness) of the C++ API still stands.

Member @cjnolet (May 31, 2019):
I noticed in the initial PR for the RF and DT algorithms that the stateless functions are being called by functions that maintain their own state and not the other way around. Unfortunately, as we were ramping up to release, I did not have the time to bring it up before it was merged. UMAP was implemented with a stateful convenience API built on top of a set of stateless functions and converting the cython to use those stateless functions will be trivial.

KNN was originally forced to be stateful because the FAISS Indexing API was designed in such a way that wrapping it with cython was non-trivial and the state was better off hidden from the cython layer altogether. Recent updates to their API have eliminated this problem and kNN will soon be reduced to a simple stateless kneighbors function on the C++ side. I'm also considering moving it to the prims.

I do think we should continue to provide stateful convenience classes in the C++ layer. I would recommend, as a solution here, to expose a set of flat stateless functions, have the caller maintain the state, and also expose a stateful convenience API that is based on those stateless functions, maintaining the state for the users of the C++ API.

Contributor (Author):
We plan to implement a newer version of the flat API for the 0.9 release. The current rfRegressor API is in line with the branch-0.8 rfClassifier.
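
For illustration, a minimal sketch of the split @cjnolet describes above: caller-owned state, a flat stateless function, and a thin stateful convenience wrapper built on it. All names here are hypothetical, not cuML's actual API.

#include <vector>

struct TreeModel {                  // hypothetical caller-owned state
  std::vector<int> nodes;
};

// Stateless: the caller (e.g. the Python layer) owns the model object,
// so pickling reduces to serializing TreeModel.
void fit_tree(TreeModel& model, const float* data, int nrows, int ncols) {
  model.nodes.assign(1, 0);         // placeholder for the real tree build
}

// Stateful convenience API, implemented on top of the stateless function.
class DecisionTreeHandle {
  TreeModel model_;
 public:
  void fit(const float* data, int nrows, int ncols) {
    fit_tree(model_, data, nrows, ncols);
  }
  const TreeModel& model() const { return model_; }
};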

// ----------------------------- Classification ----------------------------------- //

void fit(const ML::cumlHandle& handle, DecisionTree::DecisionTreeClassifier<float> * dt_classifier, float *data, const int ncols, const int nrows, int *labels,
unsigned int *rowids, const int n_sampled_rows, int unique_labels, DecisionTree::DecisionTreeParams tree_params);

@@ -161,4 +192,17 @@ void predict(const ML::cumlHandle& handle, const DecisionTree::DecisionTreeClass
void predict(const ML::cumlHandle& handle, const DecisionTree::DecisionTreeClassifier<double> * dt_classifier, const double * rows,
const int n_rows, const int n_cols, int* predictions, bool verbose=false);

// ----------------------------- Regression ----------------------------------- //

void fit(const ML::cumlHandle& handle, DecisionTree::DecisionTreeRegressor<float> * dt_regressor, float *data, const int ncols, const int nrows, float *labels,
unsigned int *rowids, const int n_sampled_rows, DecisionTree::DecisionTreeParams tree_params);

void fit(const ML::cumlHandle& handle, DecisionTree::DecisionTreeRegressor<double> * dt_regressor, double *data, const int ncols, const int nrows, double *labels,
unsigned int *rowids, const int n_sampled_rows, DecisionTree::DecisionTreeParams tree_params);

void predict(const ML::cumlHandle& handle, const DecisionTree::DecisionTreeRegressor<float> * dt_regressor, const float * rows,
const int n_rows, const int n_cols, float * predictions, bool verbose=false);
void predict(const ML::cumlHandle& handle, const DecisionTree::DecisionTreeRegressor<double> * dt_regressor, const double * rows,
const int n_rows, const int n_cols, double * predictions, bool verbose=false);

} //End namespace ML
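
A usage sketch of the regression entry points above, assuming a valid cumlHandle and device allocations for the column-major data, labels, row IDs, and output buffer (illustrative, not a tested example from this PR):

// All d_* pointers are assumed to be valid device allocations;
// d_data is column-major with ncols columns of nrows values each.
void fit_and_predict(const ML::cumlHandle& handle,
                     float* d_data, int ncols, int nrows,
                     float* d_labels, unsigned int* d_rowids,
                     int n_sampled_rows,
                     const float* d_test_rows, int n_test_rows,
                     float* d_predictions) {
  ML::DecisionTree::DecisionTreeRegressor<float> dt_regressor;
  ML::DecisionTree::DecisionTreeParams tree_params;  // library defaults
  tree_params.split_criterion = ML::CRITERION::MSE;  // or ML::CRITERION::MAE
  ML::fit(handle, &dt_regressor, d_data, ncols, nrows, d_labels,
          d_rowids, n_sampled_rows, tree_params);
  ML::predict(handle, &dt_regressor, d_test_rows, n_test_rows, ncols,
              d_predictions, /*verbose=*/false);
}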
38 changes: 38 additions & 0 deletions cpp/src/decisiontree/kernels/batch_cal.cuh
@@ -0,0 +1,38 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
/* Return max. possible number of columns that can be processed within avail_shared_memory.
Expects that requested_shared_memory is a multiple of ncols. */
int get_batch_cols_cnt(const size_t avail_shared_memory, const size_t requested_shared_memory, const int ncols) {
int ncols_in_batch = ncols;
int ncols_factor = requested_shared_memory / ncols;
if (requested_shared_memory > avail_shared_memory) {
ncols_in_batch = avail_shared_memory / ncols_factor; // floor div.
}
return ncols_in_batch;
}


/* Update batch_ncols (max. possible number of columns that can be processed within avail_shared_memory),
blocks (for next kernel launch), and shmemsize (requested shared memory for next kernel launch).
Precondition: requested_shared_memory is a multiple of ncols. */
void update_kernel_config(const size_t avail_shared_memory, const size_t requested_shared_memory, const int ncols,
const int nrows, const int threads, int & batch_ncols, int & blocks, size_t & shmemsize) {
batch_ncols = get_batch_cols_cnt(avail_shared_memory, requested_shared_memory, ncols);
shmemsize = (requested_shared_memory / ncols) * batch_ncols; // requested_shared_memory is a multiple of ncols for all kernels
blocks = min(MLCommon::ceildiv(batch_ncols * nrows, threads), 65536);
}
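
A usage sketch of these helpers with hypothetical sizes; the per-column shared-memory cost of nbins floats is an assumption for illustration:

// Suppose a kernel needs nbins floats of shared memory per column.
void configure_launch_example() {
  const int ncols = 2000, nrows = 100000, threads = 128, nbins = 8;
  const size_t avail_shmem = 48 * 1024;  // e.g. 48 KB per block
  const size_t requested_shmem =
      static_cast<size_t>(ncols) * nbins * sizeof(float);  // multiple of ncols

  int batch_ncols, blocks;
  size_t shmemsize;
  update_kernel_config(avail_shmem, requested_shmem, ncols, nrows, threads,
                       batch_ncols, blocks, shmemsize);
  // Launch as kernel<<<blocks, threads, shmemsize>>>(...), looping over the
  // columns batch_ncols at a time until all ncols are processed.
}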
55 changes: 7 additions & 48 deletions cpp/src/decisiontree/kernels/col_condenser.cuh
@@ -28,71 +28,30 @@ __global__ void get_sampled_column_kernel(const T* __restrict__ column, T *outco
return;
}

void get_sampled_labels(const int *labels, int *outlabels, unsigned int* rowids, const int n_sampled_rows, const cudaStream_t stream) {
template<typename T>
void get_sampled_labels(const T *labels, T *outlabels, const unsigned int* rowids, const int n_sampled_rows, const cudaStream_t stream) {
int threads = 128;
get_sampled_column_kernel<int><<<MLCommon::ceildiv(n_sampled_rows, threads), threads, 0, stream>>>(labels, outlabels, rowids, n_sampled_rows);
get_sampled_column_kernel<T><<<MLCommon::ceildiv(n_sampled_rows, threads), threads, 0, stream>>>(labels, outlabels, rowids, n_sampled_rows);
CUDA_CHECK(cudaGetLastError());
return;
}

template<typename T>
__global__ void allcolsampler_kernel(const T* __restrict__ data, const unsigned int* __restrict__ rowids, const int* __restrict__ colids, const int nrows, const int ncols, const int rowoffset, T* sampledcols)
__global__ void allcolsampler_kernel(const T* __restrict__ data, const unsigned int* __restrict__ rowids, const unsigned int* __restrict__ colids, const int nrows, const int ncols, const int rowoffset, T* sampledcols)
{
int tid = threadIdx.x + blockIdx.x * blockDim.x;

for (unsigned int i = tid; i < nrows*ncols; i += blockDim.x*gridDim.x) {
int newcolid = (int)(i / nrows);
int myrowstart;
if( colids != nullptr)
if (colids != nullptr) {
myrowstart = colids[ newcolid ] * rowoffset;
else
} else {
myrowstart = newcolid * rowoffset;
}

int index = rowids[ i % nrows] + myrowstart;
sampledcols[i] = data[index];
}
return;
}

template<typename T>
__global__ void allcolsampler_minmax_kernel(const T* __restrict__ data, const unsigned int* __restrict__ rowids, const int* __restrict__ colids, const int nrows, const int ncols, const int rowoffset, T* globalmin, T* globalmax, T* sampledcols, T init_min_val)
{
int tid = threadIdx.x + blockIdx.x * blockDim.x;
extern __shared__ char shmem[];
T *minshared = (T*)shmem;
T *maxshared = (T*)(shmem + sizeof(T) * ncols);

for (int i = threadIdx.x; i < ncols; i += blockDim.x) {
minshared[i] = init_min_val;
maxshared[i] = -init_min_val;
}

// Initialize min max in global memory
if (tid < ncols) {
globalmin[tid] = init_min_val;
globalmax[tid] = -init_min_val;
}

__syncthreads();

for (unsigned int i = tid; i < nrows*ncols; i += blockDim.x*gridDim.x) {
int newcolid = (int)(i / nrows);
int myrowstart = colids[ newcolid ] * rowoffset;
int index = rowids[ i % nrows] + myrowstart;
T coldata = data[index];

MLCommon::myAtomicMin(&minshared[newcolid], coldata);
MLCommon::myAtomicMax(&maxshared[newcolid], coldata);
sampledcols[i] = coldata;
}

__syncthreads();

for (int j = threadIdx.x; j < ncols; j+= blockDim.x) {
MLCommon::myAtomicMin(&globalmin[j], minshared[j]);
MLCommon::myAtomicMax(&globalmax[j], maxshared[j]);
}

return;
}
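
A hedged sketch of launching the revised kernel; the wrapper function and pointer setup are assumptions, not code from this PR:

// Gather ncols sampled columns for n_sampled_rows rows into a contiguous
// column-major buffer; all pointers are assumed to be device allocations.
template <typename T>
void sample_all_columns(const T* d_data, const unsigned int* d_rowids,
                        const unsigned int* d_colids, const int n_sampled_rows,
                        const int ncols, const int rowoffset, T* d_sampledcols,
                        const cudaStream_t stream) {
  const int threads = 128;
  const int blocks = MLCommon::ceildiv(n_sampled_rows * ncols, threads);
  allcolsampler_kernel<T><<<blocks, threads, 0, stream>>>(
      d_data, d_rowids, d_colids, n_sampled_rows, ncols, rowoffset,
      d_sampledcols);
  CUDA_CHECK(cudaGetLastError());
}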
