Skip to content

Commit

Permalink
Fix SVM model parameter handling in case n_support=0 (#4097)
Browse files Browse the repository at this point in the history
Fixes #4033

This PR fixes SVM model parameter handling in case the fitted model has no support vectors, only bias.
C++ side changes:
- The bias calculation is updated to calculate the bias as the average function value in this case.
- The prediction function is modified to avoid kernel function calculation in this case.
- Added an SVR unit test to check model fitting and prediction.

Python side changes:
- It was incorrectly assumed that n_support==0 means the model is not fitted correctly, this is removed.
- Model attributes (`dual_coef_`, `support_`, `support_vectors_`) are defined as empty arrays in this case.
- `coef_` attribute is an array of zeros if there are no support vectors.
- Unit test added to check training prediction and model attributes.

Authors:
  - Tamas Bela Feher (https://github.com/tfeher)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #4097
  • Loading branch information
tfeher authored Jul 26, 2021
1 parent 40af8af commit cb32219
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 44 deletions.
18 changes: 13 additions & 5 deletions cpp/src/svm/results.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -117,17 +117,17 @@ class Results {
{
CombineCoefs(alpha, val_tmp.data());
GetDualCoefs(val_tmp.data(), dual_coefs, n_support);
*b = CalcB(alpha, f, *n_support);
if (*n_support > 0) {
*idx = GetSupportVectorIndices(val_tmp.data(), *n_support);
*x_support = CollectSupportVectors(*idx, *n_support);
*b = CalcB(alpha, f);
// Make sure that all pending GPU calculations finished before we return
CUDA_CHECK(cudaStreamSynchronize(stream));
} else {
*dual_coefs = nullptr;
*idx = nullptr;
*x_support = nullptr;
}
// Make sure that all pending GPU calculations finished before we return
CUDA_CHECK(cudaStreamSynchronize(stream));
}

/**
Expand Down Expand Up @@ -192,6 +192,7 @@ class Results {
*n_support = SelectByCoef(val_tmp, n_rows, val_tmp, select_op, val_selected.data());
*dual_coefs = (math_t*)allocator->allocate(*n_support * sizeof(math_t), stream);
raft::copy(*dual_coefs, val_selected.data(), *n_support, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
}

/**
Expand All @@ -218,15 +219,22 @@ class Results {
* @param [in] f optimality indicator vector, size [n_rows]
* @return the value of b
*/
math_t CalcB(const math_t* alpha, const math_t* f)
math_t CalcB(const math_t* alpha, const math_t* f, int n_support)
{
if (n_support == 0) {
math_t f_sum;
cub::DeviceReduce::Sum(
cub_storage.data(), cub_bytes, f, d_val_reduced.data(), n_train, stream);
raft::update_host(&f_sum, d_val_reduced.data(), 1, stream);
return -f_sum / n_train;
}
// We know that for an unbound support vector i, the decision function
// (before taking the sign) has value F(x_i) = y_i, where
// F(x_i) = \sum_j y_j \alpha_j K(x_j, x_i) + b, and j runs through all
// support vectors. The constant b can be expressed from these formulas.
// Note that F and f denote different quantities. The lower case f is the
// optimality indicator vector defined as
// f_i = y_i - \sum_j y_j \alpha_j K(x_j, x_i).
// f_i = - y_i + \sum_j y_j \alpha_j K(x_j, x_i).
// For unbound support vectors f_i = -b.

// Select f for unbound support vectors (0 < alpha < C)
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/svm/smosolver.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ class SmoSolver {
raft::linalg::unaryOp(
f, yr, n_rows, [epsilon] __device__(math_t y) { return epsilon - y; }, stream);

// f_i = epsilon - y_i, for i \in [n_rows..2*n_rows-1]
// f_i = -epsilon - y_i, for i \in [n_rows..2*n_rows-1]
raft::linalg::unaryOp(
f + n_rows, yr, n_rows, [epsilon] __device__(math_t y) { return -epsilon - y; }, stream);
}
Expand Down
5 changes: 4 additions & 1 deletion cpp/src/svm/svc_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ void svcPredict(const raft::handle_t& handle,
MLCommon::device_buffer<math_t> K(
handle_impl.get_device_allocator(), stream, n_batch * model.n_support);
MLCommon::device_buffer<math_t> y(handle_impl.get_device_allocator(), stream, n_rows);
if (model.n_support == 0) {
CUDA_CHECK(cudaMemsetAsync(y.data(), 0, n_rows * sizeof(math_t), stream));
}
MLCommon::device_buffer<math_t> x_rbf(handle_impl.get_device_allocator(), stream);
MLCommon::device_buffer<int> idx(handle_impl.get_device_allocator(), stream);

Expand All @@ -137,7 +140,7 @@ void svcPredict(const raft::handle_t& handle,
// We process the input data batchwise:
// - calculate the kernel values K[x_batch, x_support]
// - calculate y(x_batch) = K[x_batch, x_support] * dual_coeffs
for (int i = 0; i < n_rows; i += n_batch) {
for (int i = 0; i < n_rows && model.n_support > 0; i += n_batch) {
if (i + n_batch >= n_rows) { n_batch = n_rows - i; }
math_t* x_ptr = nullptr;
int ld1 = 0;
Expand Down
11 changes: 10 additions & 1 deletion cpp/test/sg/svc_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1448,7 +1448,16 @@ class SvrTest : public ::testing::Test {
{1, 1, 1, 10, 2, 10, 1} // sample weights
},
smoOutput2<math_t>{
6, {}, -15.5, {3.9}, {1.0, 2.0, 3.0, 4.0, 6.0, 7.0}, {0, 1, 2, 3, 5, 6}, {}}}};
6, {}, -15.5, {3.9}, {1.0, 2.0, 3.0, 4.0, 6.0, 7.0}, {0, 1, 2, 3, 5, 6}, {}}},
{SvrInput<math_t>{
svmParameter{1, 0, 100, 10, 1e-6, CUML_LEVEL_INFO, 0.1, EPSILON_SVR},
KernelParams{LINEAR, 3, 1, 0},
7, // n_rows
1, // n_cols
{1, 2, 3, 4, 5, 6, 7}, // x
{2, 2, 2, 2, 2, 2, 2} // y
},
smoOutput2<math_t>{0, {}, 2, {}, {}, {}, {}}}};
for (auto d : data) {
auto p = d.first;
auto exp = d.second;
Expand Down
90 changes: 54 additions & 36 deletions python/cuml/svm/svm_base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,8 @@ class SVMBase(Base,
return self.gamma

def _calc_coef(self):
if (self.n_support_ == 0):
return cupy.zeros((1, self.n_cols), dtype=self.dtype)
with using_output_type("cupy"):
return cupy.dot(self.dual_coef_, self.support_vectors_)

Expand Down Expand Up @@ -429,29 +431,28 @@ class SVMBase(Base,

if self.dtype == np.float32:
model_f = <svmModel[float]*><uintptr_t> self._model
if model_f.n_support == 0:
self._fit_status_ = 1 # incorrect fit
return
self._intercept_ = CumlArray.full(1, model_f.b, np.float32)
self.n_support_ = model_f.n_support

self.dual_coef_ = CumlArray(
data=<uintptr_t>model_f.dual_coefs,
shape=(1, self.n_support_),
dtype=self.dtype,
order='F')
if model_f.n_support > 0:
self.dual_coef_ = CumlArray(
data=<uintptr_t>model_f.dual_coefs,
shape=(1, self.n_support_),
dtype=self.dtype,
order='F')

self.support_ = CumlArray(
data=<uintptr_t>model_f.support_idx,
shape=(self.n_support_,),
dtype=np.int32,
order='F')
self.support_ = CumlArray(
data=<uintptr_t>model_f.support_idx,
shape=(self.n_support_,),
dtype=np.int32,
order='F')

self.support_vectors_ = CumlArray(
data=<uintptr_t>model_f.x_support,
shape=(self.n_support_, self.n_cols),
dtype=self.dtype,
order='F')

self.support_vectors_ = CumlArray(
data=<uintptr_t>model_f.x_support,
shape=(self.n_support_, self.n_cols),
dtype=self.dtype,
order='F')
self.n_classes_ = model_f.n_classes
if self.n_classes_ > 0:
self._unique_labels_ = CumlArray(
Expand All @@ -463,29 +464,28 @@ class SVMBase(Base,
self._unique_labels_ = None
else:
model_d = <svmModel[double]*><uintptr_t> self._model
if model_d.n_support == 0:
self._fit_status_ = 1 # incorrect fit
return
self._intercept_ = CumlArray.full(1, model_d.b, np.float64)
self.n_support_ = model_d.n_support

self.dual_coef_ = CumlArray(
data=<uintptr_t>model_d.dual_coefs,
shape=(1, self.n_support_),
dtype=self.dtype,
order='F')
if model_d.n_support > 0:
self.dual_coef_ = CumlArray(
data=<uintptr_t>model_d.dual_coefs,
shape=(1, self.n_support_),
dtype=self.dtype,
order='F')

self.support_ = CumlArray(
data=<uintptr_t>model_d.support_idx,
shape=(self.n_support_,),
dtype=np.int32,
order='F')
self.support_ = CumlArray(
data=<uintptr_t>model_d.support_idx,
shape=(self.n_support_,),
dtype=np.int32,
order='F')

self.support_vectors_ = CumlArray(
data=<uintptr_t>model_d.x_support,
shape=(self.n_support_, self.n_cols),
dtype=self.dtype,
order='F')

self.support_vectors_ = CumlArray(
data=<uintptr_t>model_d.x_support,
shape=(self.n_support_, self.n_cols),
dtype=self.dtype,
order='F')
self.n_classes_ = model_d.n_classes
if self.n_classes_ > 0:
self._unique_labels_ = CumlArray(
Expand All @@ -496,6 +496,24 @@ class SVMBase(Base,
else:
self._unique_labels_ = None

if self.n_support_ == 0:
self.dual_coef_ = CumlArray.empty(
shape=(1, 0),
dtype=self.dtype,
order='F')

self.support_ = CumlArray.empty(
shape=(0,),
dtype=np.int32,
order='F')

# Setting all dims to zero due to issue
# https://github.com/rapidsai/cuml/issues/4095
self.support_vectors_ = CumlArray.empty(
shape=(0, 0),
dtype=self.dtype,
order='F')

def predict(self, X, predict_class, convert_dtype=True) -> CumlArray:
"""
Predicts the y for X, where y is either the decision function value
Expand Down
22 changes: 22 additions & 0 deletions python/cuml/test/test_svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#

import pytest
import cupy as cp
import numpy as np
from numba import cuda

Expand Down Expand Up @@ -681,3 +682,24 @@ def test_svm_predict_convert_dtype(train_dtype, test_dtype, classifier):
clf = cu_svm.SVR()
clf.fit(X_train, y_train)
clf.predict(X_test.astype(test_dtype))


def test_svm_no_support_vectors():
n_rows = 10
n_cols = 3
X = cp.random.uniform(size=(n_rows, n_cols), dtype=cp.float64)
y = cp.ones((n_rows, 1))
model = cuml.svm.SVR(kernel="linear", C=10)
model.fit(X, y)
pred = model.predict(X)

assert array_equal(pred, y, 0)

assert model.n_support_ == 0
assert abs(model.intercept_ - 1) <= 1e-6
assert array_equal(model.coef_, cp.zeros((1, n_cols)))
assert model.dual_coef_.shape == (1, 0)
assert model.support_.shape == (0,)
assert model.support_vectors_.shape[0] == 0
# Check disabled due to https://github.com/rapidsai/cuml/issues/4095
# assert model.support_vectors_.shape[1] == n_cols

0 comments on commit cb32219

Please sign in to comment.