Device dmatrix #5420

Merged · 18 commits · Mar 28, 2020
7 changes: 4 additions & 3 deletions python-package/xgboost/__init__.py
@@ -8,12 +8,13 @@
import sys
import warnings

from .core import DMatrix, Booster
from .core import DMatrix, DeviceQuantileDMatrix, Booster
from .training import train, cv
from . import rabit # noqa
from . import rabit # noqa
from . import tracker # noqa
from .tracker import RabitTracker # noqa
from . import dask

try:
from .sklearn import XGBModel, XGBClassifier, XGBRegressor, XGBRanker
from .sklearn import XGBRFClassifier, XGBRFRegressor
@@ -31,7 +32,7 @@
with open(VERSION_FILE) as f:
__version__ = f.read().strip()

__all__ = ['DMatrix', 'Booster',
__all__ = ['DMatrix', 'DeviceQuantileDMatrix', 'Booster',
'train', 'cv',
'RabitTracker',
'XGBModel', 'XGBClassifier', 'XGBRegressor', 'XGBRanker',
79 changes: 72 additions & 7 deletions python-package/xgboost/core.py
@@ -287,6 +287,18 @@ def _maybe_pandas_data(data, feature_names, feature_types,
return data, feature_names, feature_types


def _cudf_array_interfaces(df):
'''Extract CuDF __cuda_array_interface__'''
interfaces = []
for col in df:
interface = df[col].__cuda_array_interface__
if 'mask' in interface:
interface['mask'] = interface['mask'].__cuda_array_interface__
interfaces.append(interface)
interfaces_str = bytes(json.dumps(interfaces, indent=2), 'utf-8')
return interfaces_str


def _maybe_cudf_dataframe(data, feature_names, feature_types):
"""Extract internal data from cudf.DataFrame for DMatrix data."""
if not (CUDF_INSTALLED and isinstance(data,
@@ -592,16 +604,10 @@ def _init_from_dt(self, data, nthread):

def _init_from_array_interface_columns(self, df, missing, nthread):
"""Initialize DMatrix from columnar memory format."""
interfaces = []
for col in df:
interface = df[col].__cuda_array_interface__
if 'mask' in interface:
interface['mask'] = interface['mask'].__cuda_array_interface__
interfaces.append(interface)
interfaces_str = _cudf_array_interfaces(df)
handle = ctypes.c_void_p()
missing = missing if missing is not None else np.nan
nthread = nthread if nthread is not None else 1
interfaces_str = bytes(json.dumps(interfaces, indent=2), 'utf-8')
_check_call(
_LIB.XGDMatrixCreateFromArrayInterfaceColumns(
interfaces_str,
@@ -1001,6 +1007,65 @@ def feature_types(self, feature_types):
self._feature_types = feature_types


class DeviceQuantileDMatrix(DMatrix):
"""Device memory Data Matrix used in XGBoost for training with tree_method='gpu_hist'. Do not
use this for test/validation tasks as some information may be lost in quantisation. This
DMatrix is primarily designed to save memory in training and avoids intermediate steps,
directly creating a compressed representation for training without allocating additional
memory. Implementation does not currently consider weights in quantisation process(unlike
DMatrix).

You can construct DeviceDMatrix from cupy/cudf
"""

def __init__(self, data, label=None, weight=None, base_margin=None,
missing=None,
silent=False,
feature_names=None,
feature_types=None,
nthread=None, max_bin=256):
self.max_bin = max_bin
if not (hasattr(data, "__cuda_array_interface__") or (
CUDF_INSTALLED and isinstance(data, CUDF_DataFrame))):
raise ValueError('Only cupy/cudf currently supported for DeviceDMatrix')

super().__init__(data, label=label, weight=weight, base_margin=base_margin,
missing=missing,
silent=silent,
feature_names=feature_names,
feature_types=feature_types,
nthread=nthread)

def _init_from_array_interface_columns(self, df, missing, nthread):
"""Initialize DMatrix from columnar memory format."""
interfaces_str = _cudf_array_interfaces(df)
handle = ctypes.c_void_p()
missing = missing if missing is not None else np.nan
nthread = nthread if nthread is not None else 1
_check_call(
_LIB.XGDeviceQuantileDMatrixCreateFromArrayInterfaceColumns(
interfaces_str,
ctypes.c_float(missing), ctypes.c_int(nthread),
ctypes.c_int(self.max_bin), ctypes.byref(handle)))
self.handle = handle

def _init_from_array_interface(self, data, missing, nthread):
"""Initialize DMatrix from cupy ndarray."""
interface = data.__cuda_array_interface__
if 'mask' in interface:
interface['mask'] = interface['mask'].__cuda_array_interface__
interface_str = bytes(json.dumps(interface, indent=2), 'utf-8')

handle = ctypes.c_void_p()
missing = missing if missing is not None else np.nan
nthread = nthread if nthread is not None else 1
_check_call(
_LIB.XGDeviceQuantileDMatrixCreateFromArrayInterface(
interface_str,
ctypes.c_float(missing), ctypes.c_int(nthread),
ctypes.c_int(self.max_bin), ctypes.byref(handle)))
self.handle = handle

class Booster(object):
# pylint: disable=too-many-public-methods
"""A Booster of XGBoost.
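For context (not part of the diff): a minimal usage sketch of the new DeviceQuantileDMatrix, assuming cupy and a GPU-enabled xgboost build are installed; the shapes and parameter values below are made up for illustration.

    import cupy as cp
    import numpy as np
    import xgboost as xgb

    # Feature matrix resident in device memory; labels can stay on the host.
    X = cp.random.rand(1000, 10, dtype=cp.float32)
    y = np.random.randint(0, 2, size=1000)

    # Quantised, GPU-resident DMatrix, intended only for training with gpu_hist.
    dtrain = xgb.DeviceQuantileDMatrix(X, label=y, max_bin=256)
    booster = xgb.train({'tree_method': 'gpu_hist', 'objective': 'binary:logistic'},
                        dtrain, num_boost_round=10)

    # A cuDF DataFrame works the same way (if cudf is installed):
    #   import cudf
    #   df = cudf.DataFrame({'f0': [...], 'f1': [...]})
    #   dtrain = xgb.DeviceQuantileDMatrix(df, label=y)

    # For prediction or validation, build a regular DMatrix, since quantisation
    # loses information.
    preds = booster.predict(xgb.DMatrix(cp.asnumpy(X)))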
23 changes: 23 additions & 0 deletions src/c_api/c_api.cu
@@ -4,6 +4,7 @@
#include "xgboost/learner.h"
#include "c_api_error.h"
#include "../data/device_adapter.cuh"
#include "../data/device_dmatrix.h"

using namespace xgboost; // NOLINT

@@ -29,3 +30,25 @@ XGB_DLL int XGDMatrixCreateFromArrayInterface(char const* c_json_strs,
new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, nthread));
API_END();
}

XGB_DLL int XGDeviceQuantileDMatrixCreateFromArrayInterfaceColumns(char const* c_json_strs,
bst_float missing, int nthread, int max_bin,
DMatrixHandle* out) {
API_BEGIN();
std::string json_str{c_json_strs};
data::CudfAdapter adapter(json_str);
*out =
new std::shared_ptr<DMatrix>(new data::DeviceDMatrix(&adapter, missing, nthread, max_bin));
API_END();
}

XGB_DLL int XGDeviceQuantileDMatrixCreateFromArrayInterface(char const* c_json_strs,
bst_float missing, int nthread, int max_bin,
DMatrixHandle* out) {
API_BEGIN();
std::string json_str{c_json_strs};
data::CupyAdapter adapter(json_str);
*out =
new std::shared_ptr<DMatrix>(new data::DeviceDMatrix(&adapter, missing, nthread, max_bin));
API_END();
}
22 changes: 8 additions & 14 deletions src/common/compressed_iterator.h
@@ -32,8 +32,8 @@ static const int kPadding = 4; // Assign padding so we can read slightly off
// the beginning of the array

// The number of bits required to represent a given unsigned range
static size_t SymbolBits(size_t num_symbols) {
auto bits = std::ceil(std::log2(num_symbols));
inline XGBOOST_DEVICE size_t SymbolBits(size_t num_symbols) {
auto bits = std::ceil(log2(static_cast<double>(num_symbols)));
return std::max(static_cast<size_t>(bits), size_t(1));
}
} // namespace detail
@@ -50,14 +50,11 @@ static size_t SymbolBits(size_t num_symbols) {
*/

class CompressedBufferWriter {
private:
size_t symbol_bits_;
size_t offset_;

public:
explicit CompressedBufferWriter(size_t num_symbols) : offset_(0) {
symbol_bits_ = detail::SymbolBits(num_symbols);
}
XGBOOST_DEVICE explicit CompressedBufferWriter(size_t num_symbols)
: symbol_bits_(detail::SymbolBits(num_symbols)) {}

/**
* \fn static size_t CompressedBufferWriter::CalculateBufferSize(int
@@ -159,18 +156,15 @@ class CompressedBufferWriter {
}
};

template <typename T>

/**
* \class CompressedIterator
*
* \brief Read symbols from a bit compressed memory buffer. Usable on device and
* host.
* \brief Read symbols from a bit compressed memory buffer. Usable on device and host.
*
* \author Rory
* \date 7/9/2017
*
* \tparam T Generic type parameter.
*/

template <typename T>
class CompressedIterator {
public:
// Type definitions for thrust
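As a side note (not from the PR), the arithmetic in detail::SymbolBits is simply the bit width needed to encode a given number of distinct symbols, clamped to at least one bit; a small Python illustration:

    import math

    def symbol_bits(num_symbols):
        # Bits required to represent values 0 .. num_symbols - 1, at least one bit.
        return max(math.ceil(math.log2(num_symbols)), 1)

    assert symbol_bits(2) == 1     # {0, 1} fits in a single bit
    assert symbol_bits(256) == 8   # 256 gradient-index symbols fit in one byte
    assert symbol_bits(257) == 9   # one extra symbol forces another bit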
8 changes: 8 additions & 0 deletions src/common/device_helpers.cuh
@@ -1540,4 +1540,12 @@ DEV_INLINE void AtomicAddGpair(OutputGradientT* dest,
static_cast<typename OutputGradientT::ValueT>(gpair.GetHess()));
}


// Thrust version of this function causes error on Windows
template <typename ReturnT, typename IterT, typename FuncT>
thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIterator(
IterT iter, FuncT func) {
return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
}

} // namespace dh
31 changes: 3 additions & 28 deletions src/common/hist_util.cu
@@ -338,31 +338,6 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
return cuts;
}

struct IsValidFunctor : public thrust::unary_function<Entry, bool> {
explicit IsValidFunctor(float missing) : missing(missing) {}

float missing;
__device__ bool operator()(const data::COOTuple& e) const {
if (common::CheckNAN(e.value) || e.value == missing) {
return false;
}
return true;
}
__device__ bool operator()(const Entry& e) const {
if (common::CheckNAN(e.fvalue) || e.fvalue == missing) {
return false;
}
return true;
}
};

// Thrust version of this function causes error on Windows
template <typename ReturnT, typename IterT, typename FuncT>
thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIterator(
IterT iter, FuncT func) {
return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
}

template <typename AdapterT>
void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
SketchContainer* sketch_container, int num_cuts) {
@@ -372,10 +347,10 @@ void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
auto &batch = adapter->Value();
// Enforce single batch
CHECK(!adapter->Next());
auto batch_iter = MakeTransformIterator<data::COOTuple>(
auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
thrust::make_counting_iterator(0llu),
[=] __device__(size_t idx) { return batch.GetElement(idx); });
auto entry_iter = MakeTransformIterator<Entry>(
auto entry_iter = dh::MakeTransformIterator<Entry>(
thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) {
return Entry(batch.GetElement(idx).column_idx,
batch.GetElement(idx).value);
0);

auto d_column_sizes_scan = column_sizes_scan.data().get();
IsValidFunctor is_valid(missing);
data::IsValidFunctor is_valid(missing);
dh::LaunchN(adapter->DeviceIdx(), end - begin, [=] __device__(size_t idx) {
auto e = batch_iter[begin + idx];
if (is_valid(e)) {
6 changes: 3 additions & 3 deletions src/common/hist_util.h
@@ -105,10 +105,10 @@ class HistogramCuts {
auto end = cut_ptrs_.ConstHostVector().at(column_id + 1);
const auto &values = cut_values_.ConstHostVector();
auto it = std::upper_bound(values.cbegin() + beg, values.cbegin() + end, value);
if (it == values.cend()) {
it = values.cend() - 1;
}
BinIdx idx = it - values.cbegin();
if (idx == end) {
idx -= 1;
}
return idx;
}

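To illustrate the SearchBin change above (a sketch, not from the PR): the clamp now uses the column's own end offset rather than the end of the whole flattened cut-value array, so an out-of-range value falls into the last bin of its own column. In Python terms:

    from bisect import bisect_right

    def search_bin(cut_ptrs, cut_values, column_id, value):
        beg = cut_ptrs[column_id]
        end = cut_ptrs[column_id + 1]
        idx = beg + bisect_right(cut_values[beg:end], value)
        if idx == end:  # value is at or beyond the column's largest cut
            idx -= 1
        return idx

    cut_ptrs = [0, 3, 6]                  # two columns, three cuts each
    cut_values = [0.5, 1.0, 2.0,          # column 0
                  10.0, 20.0, 30.0]       # column 1
    assert search_bin(cut_ptrs, cut_values, 0, 5.0) == 2    # clamped within column 0
    assert search_bin(cut_ptrs, cut_values, 1, 15.0) == 4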
2 changes: 1 addition & 1 deletion src/data/data.cu
@@ -10,7 +10,7 @@
#include "array_interface.h"
#include "../common/device_helpers.cuh"
#include "device_adapter.cuh"
#include "simple_dmatrix.h"
#include "device_dmatrix.h"

namespace xgboost {

19 changes: 19 additions & 0 deletions src/data/device_adapter.cuh
@@ -8,12 +8,31 @@
#include <memory>
#include <string>
#include "../common/device_helpers.cuh"
#include "../common/math.h"
#include "adapter.h"
#include "array_interface.h"

namespace xgboost {
namespace data {

struct IsValidFunctor : public thrust::unary_function<Entry, bool> {
explicit IsValidFunctor(float missing) : missing(missing) {}

float missing;
__device__ bool operator()(const data::COOTuple& e) const {
if (common::CheckNAN(e.value) || e.value == missing) {
return false;
}
return true;
}
__device__ bool operator()(const Entry& e) const {
if (common::CheckNAN(e.fvalue) || e.fvalue == missing) {
return false;
}
return true;
}
};

class CudfAdapterBatch : public detail::NoMetaInfo {
public:
CudfAdapterBatch() = default;
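The relocated IsValidFunctor is the predicate the device adapters use to filter out missing entries; roughly (a Python sketch, not from the PR), an element is kept only if it is neither NaN nor equal to the user-supplied missing value:

    import math

    def is_valid(value, missing):
        # Mirrors the device-side check: reject NaN and the sentinel missing value.
        return not (math.isnan(value) or value == missing)

    assert is_valid(1.5, missing=-999.0)
    assert not is_valid(-999.0, missing=-999.0)
    assert not is_valid(float('nan'), missing=-999.0)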