Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed potential overflows in SVM, minor adjustments to nvtx ranges #5504

Merged
merged 1 commit into from
Jul 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cpp/src/svm/kernelcache.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ class KernelCache {
batch_size_base = n_rows;

// enable batching for kernel > 1 GB (default)
if (n_rows * n_ws * sizeof(math_t) > kernel_tile_byte_limit) {
if ((size_t)n_rows * n_ws * sizeof(math_t) > kernel_tile_byte_limit) {
batching_enabled = true;
// only select based on desired big-kernel size
batch_size_base = std::max(1ul, kernel_tile_byte_limit / n_ws / sizeof(math_t));
Expand All @@ -373,7 +373,7 @@ class KernelCache {

// enable sparse row extraction for sparse input where n_ws * n_cols > 1 GB
// Warning: kernel computation will be much slower!
if (is_csr && (n_cols * n_ws * sizeof(math_t) > dense_extract_byte_limit)) {
if (is_csr && ((size_t)n_cols * n_ws * sizeof(math_t) > dense_extract_byte_limit)) {
sparse_extract = true;
}

Expand Down
3 changes: 2 additions & 1 deletion cpp/src/svm/results.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ class Results {
{
SupportStorage<math_t> support_matrix;
// allow ~1GB dense support matrix
if (isDenseType<MatrixViewType>() || (n_support * n_cols * sizeof(math_t) < (1 << 30))) {
if (isDenseType<MatrixViewType>() ||
((size_t)n_support * n_cols * sizeof(math_t) < (1 << 30))) {
support_matrix.data =
(math_t*)rmm_alloc->allocate(n_support * n_cols * sizeof(math_t), stream);
ML::SVM::extractRows<math_t>(matrix, support_matrix.data, idx, n_support, handle);
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/svm/smosolver.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ class SmoSolver {
raft::update_host(host_return_buff, return_buff.data(), 2, stream);
raft::common::nvtx::pop_range();
raft::common::nvtx::push_range("SmoSolver::UpdateF");
raft::common::nvtx::push_range("SmoSolver::UpdateF::getNnzDaRows");
int nnz_da;
GetNonzeroDeltaAlpha(delta_alpha.data(),
n_ws,
Expand All @@ -232,11 +233,10 @@ class SmoSolver {
RAFT_CUDA_TRY(cudaPeekAtLastError());
// The following should be performed only for elements with nonzero delta_alpha
if (nnz_da > 0) {
raft::common::nvtx::push_range("SmoSolver::UpdateF::getNnzDaRows");
auto batch_descriptor = cache.InitFullTileBatching(nz_da_idx.data(), nnz_da);
raft::common::nvtx::pop_range();

while (cache.getNextBatchKernel(batch_descriptor)) {
raft::common::nvtx::pop_range();
raft::common::nvtx::push_range("SmoSolver::UpdateF::updateBatch");
// do (partial) update
UpdateF(f.data() + batch_descriptor.offset,
Expand All @@ -245,10 +245,10 @@ class SmoSolver {
nnz_da,
batch_descriptor.kernel_data);
RAFT_CUDA_TRY(cudaPeekAtLastError());
raft::common::nvtx::pop_range();
}
}
handle.sync_stream(stream);
raft::common::nvtx::pop_range();
raft::common::nvtx::pop_range(); // ("SmoSolver::UpdateF");

math_t diff = host_return_buff[0];
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/svm/svc_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ void svcFitX(const raft::handle_t& handle,
&(model.b),
param.max_iter);
model.n_cols = n_cols;
handle_impl.sync_stream(stream);
delete kernel;
}

Expand Down Expand Up @@ -150,7 +151,7 @@ void svcPredictX(const raft::handle_t& handle,
int n_batch = n_rows;
// Limit the memory size of the prediction buffer
buffer_size = buffer_size * 1024 * 1024;
if (n_batch * model.n_support * sizeof(math_t) > buffer_size) {
if ((size_t)n_batch * model.n_support * sizeof(math_t) > buffer_size) {
n_batch = buffer_size / (model.n_support * sizeof(math_t));
if (n_batch < 1) n_batch = 1;
}
Expand Down
2 changes: 1 addition & 1 deletion python/cuml/svm/svc.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ class SVC(SVMBase,
cdef int n_rows = self.n_rows
cdef int n_cols = self.n_cols

cdef int n_nnz = X_m.nnz if is_sparse else self.n_rows * self.n_cols
cdef int n_nnz = X_m.nnz if is_sparse else -1
cdef uintptr_t X_indptr = X_m.indptr.ptr if is_sparse else X_m.ptr
cdef uintptr_t X_indices = X_m.indices.ptr if is_sparse else X_m.ptr
cdef uintptr_t X_data = X_m.data.ptr if is_sparse else X_m.ptr
Expand Down
2 changes: 1 addition & 1 deletion python/cuml/svm/svr.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ class SVR(SVMBase, RegressorMixin):

cdef int n_rows = self.n_rows
cdef int n_cols = self.n_cols
cdef int n_nnz = X_m.nnz if is_sparse else self.n_rows * self.n_cols
cdef int n_nnz = X_m.nnz if is_sparse else -1
cdef uintptr_t X_indptr = X_m.indptr.ptr if is_sparse else X_m.ptr
cdef uintptr_t X_indices = X_m.indices.ptr if is_sparse else X_m.ptr
cdef uintptr_t X_data = X_m.data.ptr if is_sparse else X_m.ptr
Expand Down