Support Vector Machine #912

Merged · 103 commits · Sep 27, 2019

Commits
ec5def4
Add LRU Cache prim
tfeher Jun 11, 2019
66d5d32
Remove unused printouts, edit doc and formatting
tfeher Jun 12, 2019
4d4aa09
Update changelog
tfeher Jul 16, 2019
7092cfb
Add general Gram matrices prim
tfeher Jul 21, 2019
85e6c49
Fix const attribute for input vectors
tfeher Jul 12, 2019
aaeced4
Add missing sync and remove printouts
tfeher Jul 21, 2019
c378a33
Merge branch 'fea-ext-cache' into fea-ext-svm-rebase
tfeher Jul 31, 2019
525aeb9
Merge branch 'fea-ext-gram' into fea-ext-svm-rebase
tfeher Jul 31, 2019
d377eed
Mark const methods as such
tfeher Jul 30, 2019
210f4fc
Add Support Vector Machine classifier
tfeher Jul 30, 2019
9fa3902
Add stateless C++ layer below the stateful Sklearn like C++ class
tfeher Jul 31, 2019
4f13e63
Update Changelog
tfeher Jul 31, 2019
106186f
Fix namespace and style issues for the python wrapper. Add missing __…
tfeher Jul 31, 2019
17ff98b
Documentation fixes, remove unused printouts, use asynchronous memset
tfeher Aug 1, 2019
c194aeb
Add one more @brief tag
tfeher Aug 1, 2019
69c859e
Add example usage to the doc
tfeher Aug 1, 2019
12c06f8
Rename variables and edit doc to improve readability
tfeher Aug 1, 2019
c21d9e7
Use input_to_dev_array
tfeher Aug 1, 2019
a4d6d63
Add verbose flag
tfeher Aug 2, 2019
1f91df7
Use unnamed namespace instead of inline keyword to avoid multiple def…
tfeher Aug 2, 2019
58c9347
Rename variables to increase readability of tests
tfeher Aug 2, 2019
815a324
Add device_array_from_ptr helper function
tfeher Aug 4, 2019
9f87471
Add attributes that expose the parameters of the SVM
tfeher Aug 4, 2019
470d320
Make flake happy
tfeher Aug 4, 2019
640b196
Fix const attribute for ws_util args
tfeher Aug 4, 2019
be8dade
Free buffers before training.
tfeher Aug 4, 2019
35ba93b
Fix coef_ attribute for linear kernel
tfeher Aug 4, 2019
b0f5477
Handle degenerate kernels (with negative eta), edit doc
tfeher Aug 5, 2019
e57f47f
Edit doc
tfeher Aug 5, 2019
f8e2019
Fix CalcB for cases when all SVs are bound.
tfeher Aug 5, 2019
f012aa1
Merge branch 'branch-0.9' into fea-ext-cache
tfeher Aug 5, 2019
e62bddb
Handle different x and y data types in python wrapper
tfeher Aug 7, 2019
54ce017
Improve python wrapper.
tfeher Aug 7, 2019
f67eaba
Set default kernel to RBF
tfeher Aug 7, 2019
338a7be
Add pytest for SVM
tfeher Aug 7, 2019
a81fff3
Add verbose option
tfeher Aug 7, 2019
507f6ee
Merge branch 'branch-0.9' into fea-ext-cache
tfeher Aug 8, 2019
c7ae688
edit docstring
tfeher Aug 8, 2019
5183635
Use allocator->deallocate in SVC::free_buffers
tfeher Aug 8, 2019
9c05bac
remove using namespace
tfeher Aug 8, 2019
13a637b
Correct doc
tfeher Aug 8, 2019
fb2e0a5
Use SelectByAlpha for dual coefs
tfeher Aug 9, 2019
1d64dbe
Remove using namespace from results.h
tfeher Aug 9, 2019
68ed902
Remove using namespace from headers, edit docstring and comments.
tfeher Aug 9, 2019
3bd097f
Merge branch 'fea-ext-cache' into fea-ext-svm
tfeher Aug 9, 2019
c01fd54
Update to modified cache API
tfeher Aug 9, 2019
97c9d0c
Fix doc, remove using namespace from header
tfeher Aug 11, 2019
fb29ee6
Introduce typed tests for workingset and value parametrized set for k…
tfeher Aug 11, 2019
f55488e
Typed KernelCacheTest
tfeher Aug 11, 2019
7d49abf
Fix docstring
tfeher Aug 12, 2019
a23ff3a
Forgot to add pytest. Work in progress.
tfeher Aug 12, 2019
0f394a3
Move range into ml-prims, use ASSERT in SelectReduce
tfeher Aug 12, 2019
e9e6afa
Correct doc
tfeher Aug 12, 2019
39f218e
Update docstring for svc
tfeher Aug 12, 2019
48030ef
Add linalg/init.h
tfeher Aug 12, 2019
7f135b9
Make flake happy
tfeher Aug 13, 2019
c0917a5
Fix end_bit for sorting
tfeher Aug 13, 2019
a64e60d
Fix python formatting
tfeher Aug 16, 2019
22c1a46
Merge branch 'branch-0.10' into fea-ext-svm
tfeher Aug 22, 2019
466ce67
Make KernelParams POD
tfeher Aug 22, 2019
159d20b
Rename files for SVC
tfeher Aug 22, 2019
d81a869
Add stateless cpp interface. The original templated implementation is…
tfeher Aug 22, 2019
471a4e4
Keep templates in the C++ API.
tfeher Aug 23, 2019
e835b9d
Group SVM parameters into structures.
tfeher Aug 26, 2019
3c320af
Adapt SVM python wrapper to use svmParameter/svmModel structs.
tfeher Aug 26, 2019
743b772
Merge branch 'branch-0.10' into fea-ext-svm
tfeher Aug 26, 2019
ab94240
Fix get/setstate for pickle
tfeher Aug 26, 2019
bc74a53
Restore test/CMakeLists.txt after accidental commit
tfeher Aug 26, 2019
ec242c0
Add numba DeviceNDArray option for to_nparray
tfeher Aug 26, 2019
ac4feb2
Test pickling for SVM
tfeher Aug 26, 2019
2fa6dc6
Test different floating point and array formats for SVM
tfeher Aug 26, 2019
24b0ad5
Add C wrappers for SVM
tfeher Aug 27, 2019
ccbdb75
Fix type conversion for 1D numba arrays and cudf Series
tfeher Aug 28, 2019
b001b24
Fix test for input arrays
tfeher Aug 28, 2019
9fcbbe9
Fix testing whether model is fitted
tfeher Aug 28, 2019
2ad519d
Test with different datasets, array shapes, and fit/predict many times
tfeher Aug 28, 2019
6f29aea
Remove whitespace
tfeher Aug 28, 2019
548c528
Refer to issue in docstring
tfeher Aug 29, 2019
d91c929
Test with blobs, every test runs both in single and double precision.
tfeher Aug 29, 2019
83829c0
Add docstring to svc.hpp
tfeher Aug 30, 2019
0cdd285
Merge branch 'branch-0.10' into fea-ext-svm
tfeher Sep 12, 2019
6131356
Fix array size during workingset init
tfeher Sep 9, 2019
08bb112
Fix destructor for SVC class
tfeher Sep 12, 2019
d454e42
Fix memory leak test
tfeher Sep 12, 2019
21b5789
Enable more blobs tests, and add memory leak gtest
tfeher Sep 12, 2019
1acbefc
Fix blobs cluster centers, adjust expected results.
tfeher Sep 18, 2019
4a0534b
Fix SVC blobs test: array order was incorrect.
tfeher Sep 18, 2019
5d5b7b7
Fix SVC tests: increase tolerance for n_sv, increase memory consumpti…
tfeher Sep 18, 2019
24e2139
Increase N_SV tolerance
tfeher Sep 18, 2019
9fc815e
Use cosine similarity when testing separating hyperplane for linear SVM
tfeher Sep 24, 2019
f6e1377
Move GramMatrix files into the matrix folder/namespace.
tfeher Sep 24, 2019
dfe1fe4
Replace SVM ws_util's map_to_sorted with thrust iterator magic
tfeher Sep 24, 2019
7552f58
Replace get_rows in KernelCache with Matrix::copyRows
tfeher Sep 25, 2019
a08fa20
Remove unused header
tfeher Sep 25, 2019
28d17c1
Update changelog
tfeher Sep 25, 2019
54cb04a
Adjust python wrapper after moving GramMatrix into Matrix
tfeher Sep 25, 2019
3d7d9d1
Merge branch 'branch-0.10' into fea-ext-svm
tfeher Sep 25, 2019
de476fa
Implement classification accuracy dependent tolerance for SVM tests
tfeher Sep 25, 2019
b6e5616
Adjust blob test parameters case-by-case
tfeher Sep 25, 2019
76336cd
Adjusted SVM stress test parameters
tfeher Sep 25, 2019
6eba3f5
Merge branch 'branch-0.10' into fea-ext-svm
tfeher Sep 26, 2019
af6924f
Increase memory footprint of memory leak detection
tfeher Sep 26, 2019
97d2c00
Remove early return from convert_dtype. This fixes test_input_utils t…
tfeher Sep 26, 2019
Files changed
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,8 @@
- PR #1113: prims: new batched make-symmetric-matrix primitive
- PR #1112: prims: new batched-gemv primitive
- PR #855: Added benchmark tools
- PR #892: General Gram matrices prim
- PR #912: Support Vector Machine

## Improvements
- PR #961: High Performance RF; HIST algo
5 changes: 4 additions & 1 deletion cpp/CMakeLists.txt
@@ -333,6 +333,8 @@ if(BUILD_CUML_CPP_LIBRARY)
src/random_projection/rproj.cu
src/solver/solver.cu
src/spectral/spectral.cu
src/svm/svc.cu
src/svm/ws_util.cu
src/tsne/tsne.cu
src/tsvd/tsvd.cu
src/umap/umap.cu
@@ -375,7 +377,8 @@ if(BUILD_CUML_C_LIBRARY)
src/common/cuml_api.cpp
src/dbscan/dbscan_api.cpp
src/glm/glm_api.cpp
src/holtwinters/holtwinters_api.cpp)
src/holtwinters/holtwinters_api.cpp
src/svm/svm_api.cpp)
target_link_libraries(${CUML_C_TARGET} ${CUML_CPP_TARGET})
endif(BUILD_CUML_C_LIBRARY)

152 changes: 152 additions & 0 deletions cpp/src/svm/kernelcache.h
@@ -0,0 +1,152 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cuda_utils.h>
#include <linalg/gemm.h>
#include "cache/cache.h"
#include "common/cumlHandle.hpp"
#include "common/host_buffer.hpp"
#include "matrix/grammatrix.h"
#include "matrix/matrix.h"
#include "ml_utils.h"

namespace ML {
namespace SVM {

/**
* @brief Buffer to store a kernel tile
*
* We calculate the kernel matrix for the vectors in the working set.
* For every vector x_i in the working set, we always calculate a full row of the
* kernel matrix K(x_j, x_i), j=1..n_rows.
*
* A kernel tile stores all the kernel rows for the working set, i.e. K(x_j, x_i)
* for all i in the working set, and j in 1..n_rows.
*
* The kernel values can be cached to avoid repeated calculation of the kernel
* function.
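*
* The tile is stored in column major order with j (the training-vector
* index) contiguous: the kernel value K(x_j, x_q) with q = ws_idx[i] is
* found at tile[i * n_rows + j].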
*/
template <typename math_t>
class KernelCache {
private:
const math_t *x; //!< pointer to the training vectors

MLCommon::device_buffer<math_t>
x_ws; //!< feature vectors in the current working set
MLCommon::device_buffer<int>
ws_cache_idx; //!< cache positions of the working set vectors
MLCommon::device_buffer<math_t> tile; //!< Kernel matrix tile

int n_rows; //!< number of rows in x
int n_cols; //!< number of columns in x
int n_ws; //!< number of elements in the working set

cublasHandle_t cublas_handle;

MLCommon::Matrix::GramMatrixBase<math_t> *kernel;

const cumlHandle_impl handle;

const int TPB = 256; //!< threads per block for kernels launched

MLCommon::Cache::Cache<math_t> cache;

cudaStream_t stream;

public:
/**
* Construct an object to manage the kernel cache
*
* @param handle reference to cumlHandle implementation
* @param x device array of training vectors in column major format,
* size [n_rows x n_cols]
* @param n_rows number of training vectors
* @param n_cols number of features
* @param n_ws size of working set
* @param kernel pointer to the kernel functor (must not be null)
* @param cache_size size of the kernel cache in MiB (default 200)
*/
KernelCache(const cumlHandle_impl &handle, const math_t *x, int n_rows,
int n_cols, int n_ws,
MLCommon::Matrix::GramMatrixBase<math_t> *kernel,
float cache_size = 200)
: cache(handle.getDeviceAllocator(), handle.getStream(), n_rows,
cache_size),
kernel(kernel),
x(x),
n_rows(n_rows),
n_cols(n_cols),
n_ws(n_ws),
cublas_handle(handle.getCublasHandle()),
x_ws(handle.getDeviceAllocator(), handle.getStream(), n_ws * n_cols),
tile(handle.getDeviceAllocator(), handle.getStream(), n_ws * n_rows),
ws_cache_idx(handle.getDeviceAllocator(), handle.getStream(), n_ws) {
ASSERT(kernel != nullptr, "Kernel pointer required for KernelCache!");

stream = handle.getStream();
}

~KernelCache() {}

/**
* @brief Get all the kernel matrix rows for the working set.
* @param ws_idx indices of the working set
* @return pointer to the kernel tile [n_rows x n_ws], K_j,i = K(x_j, x_q),
* where j = 1..n_rows, q = ws_idx[i], and j is the contiguous dimension
*/
math_t *GetTile(int *ws_idx) {
if (cache.GetSize() > 0) {
int n_cached;
cache.GetCacheIdxPartitioned(ws_idx, n_ws, ws_cache_idx.data(), &n_cached,
stream);
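// ws_idx is now partitioned: the first n_cached entries refer to vectors
// whose kernel rows are already cached (their cache slots are stored in
// ws_cache_idx); the remaining n_ws - n_cached entries must be computed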
// collect already cached values
cache.GetVecs(ws_cache_idx.data(), n_cached, tile.data(), stream);

int non_cached = n_ws - n_cached;
if (non_cached > 0) {
int *ws_idx_new = ws_idx + n_cached;
// AssignCacheIdx can permute ws_idx_new, therefore it has to come
// before the kernel evaluation. It could run on a separate stream, so
// that copyRows executes concurrently with AssignCacheIdx.
cache.AssignCacheIdx(ws_idx_new, non_cached,
ws_cache_idx.data() + n_cached,
stream); // cache stream

// collect training vectors for the kernel elements that need to be calculated
MLCommon::Matrix::copyRows(x, n_rows, n_cols, x_ws.data(), ws_idx_new,
non_cached, stream, false);
math_t *tile_new = tile.data() + n_cached * n_rows;
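// evaluate the kernel between all n_rows training vectors and the
// non_cached working set vectors, filling an [n_rows x non_cached] block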
(*kernel)(x, n_rows, n_cols, x_ws.data(), non_cached, tile_new, stream);
// We need AssignCacheIdx to be finished before calling StoreVecs
cache.StoreVecs(tile_new, n_rows, non_cached,
ws_cache_idx.data() + n_cached, stream);
}
} else {
if (n_ws > 0) {
// collect all the feature vectors in the working set
MLCommon::Matrix::copyRows(x, n_rows, n_cols, x_ws.data(), ws_idx, n_ws,
stream, false);
(*kernel)(x, n_rows, n_cols, x_ws.data(), n_ws, tile.data(), stream);
}
}
return tile.data();
}
};

}  // end namespace SVM
}  // end namespace ML
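
For orientation, a minimal usage sketch of the class above. Only the KernelCache constructor and GetTile come from this diff; the function name, the loop structure, and the working-set update steps are illustrative assumptions about how the SMO solver introduced in this PR might drive the cache.

// Hypothetical sketch, not part of this diff: an SMO-style outer loop
// fetching kernel rows for the current working set through KernelCache.
#include "kernelcache.h"

template <typename math_t>
void smo_outer_loop_sketch(const ML::cumlHandle_impl &handle, const math_t *x,
                           int n_rows, int n_cols, int n_ws,
                           MLCommon::Matrix::GramMatrixBase<math_t> *kernel,
                           int *ws_idx, int n_iter) {
  // one cache for the whole solver run; 200 MiB LRU cache by default
  ML::SVM::KernelCache<math_t> cache(handle, x, n_rows, n_cols, n_ws, kernel);
  for (int it = 0; it < n_iter; it++) {
    // ... select the working set, writing its indices into ws_idx ...
    // GetTile returns a device pointer to an [n_rows x n_ws] column major
    // tile: cached rows are gathered, the rest computed and stored
    math_t *tile = cache.GetTile(ws_idx);
    // ... update the dual coefficients using the tile ...
  }
}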