Commit: Lint

RAMitchell committed Mar 16, 2020
1 parent caa0132 commit 2bec61a
Showing 5 changed files with 84 additions and 88 deletions.
13 changes: 6 additions & 7 deletions python-package/xgboost/core.py
@@ -1021,13 +1021,12 @@ def __init__(self, data, label=None, weight=None, base_margin=None,
print(type(data))
raise ValueError('Only cupy/cudf currently supported for DeviceDMatrix')


super().__init__(data,label=label,weight=weight, base_margin=base_margin,
missing=missing,
silent=silent,
feature_names=feature_names,
feature_types=feature_types,
nthread=nthread)
super().__init__(data, label=label, weight=weight, base_margin=base_margin,
missing=missing,
silent=silent,
feature_names=feature_names,
feature_types=feature_types,
nthread=nthread)

def _init_from_array_interface(self, data, missing, nthread):
"""Initialize DMatrix from cupy ndarray."""
6 changes: 3 additions & 3 deletions src/common/compressed_iterator.h
@@ -33,7 +33,7 @@ static const int kPadding = 4; // Assign padding so we can read slightly off

// The number of bits required to represent a given unsigned range
inline XGBOOST_DEVICE size_t SymbolBits(size_t num_symbols) {
auto bits = std::ceil(log2(double(num_symbols)));
auto bits = std::ceil(log2(static_cast<double>(num_symbols)));
return std::max(static_cast<size_t>(bits), size_t(1));
}
} // namespace detail
@@ -53,8 +53,8 @@ class CompressedBufferWriter {
size_t symbol_bits_;

public:
XGBOOST_DEVICE explicit CompressedBufferWriter(size_t num_symbols):symbol_bits_(detail::SymbolBits(num_symbols)) {
}
XGBOOST_DEVICE explicit CompressedBufferWriter(size_t num_symbols)
: symbol_bits_(detail::SymbolBits(num_symbols)) {}

/**
* \fn static size_t CompressedBufferWriter::CalculateBufferSize(int
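The compressed_iterator.h changes above are formatting only, but the code they touch is worth illustrating: detail::SymbolBits computes how many bits are needed to encode values in [0, num_symbols), and CompressedBufferWriter then packs each symbol at that fixed bit width into a byte buffer. Below is a minimal host-side C++ sketch of the same idea; PackSymbol and UnpackSymbol are illustrative stand-ins, not the actual CUDA implementation, which writes symbols atomically on the device.

// Host-side illustration of the bit-width calculation in detail::SymbolBits
// and of fixed-width symbol packing in the spirit of CompressedBufferWriter.
// PackSymbol/UnpackSymbol are illustrative helpers, not part of XGBoost.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Number of bits required to represent values in [0, num_symbols).
inline std::size_t SymbolBits(std::size_t num_symbols) {
  auto bits = std::ceil(std::log2(static_cast<double>(num_symbols)));
  return std::max(static_cast<std::size_t>(bits), std::size_t(1));
}

// Write `symbol` at slot `idx` of a byte buffer, using `symbol_bits` per slot.
void PackSymbol(std::vector<std::uint8_t>* buffer, std::size_t symbol,
                std::size_t symbol_bits, std::size_t idx) {
  std::size_t bit_begin = idx * symbol_bits;
  for (std::size_t i = 0; i < symbol_bits; ++i) {
    std::size_t bit = bit_begin + i;
    std::uint8_t mask = 1u << (bit % 8);
    bool set = (symbol >> (symbol_bits - 1 - i)) & 1u;  // MSB first
    if (set) {
      (*buffer)[bit / 8] |= mask;
    }
  }
}

// Read the symbol back from slot `idx`.
std::size_t UnpackSymbol(const std::vector<std::uint8_t>& buffer,
                         std::size_t symbol_bits, std::size_t idx) {
  std::size_t out = 0;
  std::size_t bit_begin = idx * symbol_bits;
  for (std::size_t i = 0; i < symbol_bits; ++i) {
    std::size_t bit = bit_begin + i;
    bool set = buffer[bit / 8] & (1u << (bit % 8));
    out = (out << 1) | static_cast<std::size_t>(set);
  }
  return out;
}

int main() {
  std::size_t num_symbols = 256 + 1;           // e.g. 256 bins plus a null symbol
  std::size_t bits = SymbolBits(num_symbols);  // 9 bits per entry
  std::vector<std::size_t> symbols = {0, 3, 255, 256};
  std::vector<std::uint8_t> buffer((symbols.size() * bits + 7) / 8 + 4, 0);
  for (std::size_t i = 0; i < symbols.size(); ++i) {
    PackSymbol(&buffer, symbols[i], bits, i);
  }
  for (std::size_t i = 0; i < symbols.size(); ++i) {
    std::cout << UnpackSymbol(buffer, bits, i) << " ";  // prints 0 3 255 256
  }
  std::cout << "\n";
}

With, say, 256 bins plus one null symbol, each entry needs only 9 bits instead of a full 32-bit index, which is the point of the compressed buffer.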
136 changes: 67 additions & 69 deletions src/data/device_dmatrix.cu
@@ -4,20 +4,19 @@
* \brief Device-memory version of DMatrix.
*/

#include <thrust/execution_policy.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>
#include <xgboost/base.h>
#include <xgboost/data.h>

#include <memory>
#include <thrust/execution_policy.h>

#include <thrust/iterator/transform_output_iterator.h>
#include <thrust/iterator/discard_iterator.h>
#include <utility>
#include "../common/hist_util.h"
#include "../common/math.h"
#include "adapter.h"
#include "device_dmatrix.h"
#include "device_adapter.cuh"
#include "ellpack_page.cuh"
#include "../common/hist_util.h"
#include "../common/math.h"
#include "device_dmatrix.h"

namespace xgboost {
namespace data {
@@ -37,7 +36,7 @@ struct IsValidFunctor : public thrust::unary_function<Entry, bool> {
// Returns maximum row length
template <typename AdapterBatchT>
size_t GetRowCounts(const AdapterBatchT& batch, common::Span<size_t> offset,
int device_idx, float missing) {
int device_idx, float missing) {
IsValidFunctor is_valid(missing);
// Count elements per row
dh::LaunchN(device_idx, batch.Size(), [=] __device__(size_t idx) {
@@ -51,23 +50,23 @@ size_t GetRowCounts(const AdapterBatchT& batch, common::Span<size_t> offset,
dh::XGBCachingDeviceAllocator<char> alloc;
size_t row_stride = thrust::reduce(
thrust::cuda::par(alloc), thrust::device_pointer_cast(offset.data()),
thrust::device_pointer_cast(offset.data()) + offset.size(),
size_t(0),
thrust::maximum<size_t>());
thrust::device_pointer_cast(offset.data()) + offset.size(), size_t(0),
thrust::maximum<size_t>());
return row_stride;
}

template <typename AdapterBatchT>
struct WriteCompressedEllpackFunctor
{
template <typename AdapterBatchT>
struct WriteCompressedEllpackFunctor {
WriteCompressedEllpackFunctor(common::CompressedByteT* buffer,
const common::CompressedBufferWriter& writer, const AdapterBatchT& batch,const EllpackDeviceAccessor& accessor,const IsValidFunctor&is_valid)
:
d_buffer(buffer),
writer(writer),
batch(batch),accessor(accessor),is_valid(is_valid)
{
}
const common::CompressedBufferWriter& writer,
const AdapterBatchT& batch,
EllpackDeviceAccessor accessor,
const IsValidFunctor& is_valid)
: d_buffer(buffer),
writer(writer),
batch(batch),
accessor(std::move(accessor)),
is_valid(is_valid) {}

common::CompressedByteT* d_buffer;
common::CompressedBufferWriter writer;
@@ -76,55 +75,57 @@ struct WriteCompressedEllpackFunctor
IsValidFunctor is_valid;

using Tuple = thrust::tuple<size_t, size_t, size_t>;
__device__ size_t operator()(Tuple out)
{
__device__ size_t operator()(Tuple out) {
auto e = batch.GetElement(out.get<2>());
if (is_valid(e)) {
// -1 because the scan is inclusive
size_t output_position = accessor.row_stride * e.row_idx + out.get<1>() - 1;
size_t output_position =
accessor.row_stride * e.row_idx + out.get<1>() - 1;
auto bin_idx = accessor.SearchBin(e.value, e.column_idx);
writer.AtomicWriteSymbol(d_buffer, bin_idx, output_position);
}
return 0;

}
};

// Here the data is already correctly ordered and simply needs to be compacted
// to remove missing data
template <typename AdapterBatchT>
void CopyDataRowMajor(const AdapterBatchT& batch, EllpackPageImpl*dst,
int device_idx, float missing,common::Span<size_t> row_counts) {
void CopyDataRowMajor(const AdapterBatchT& batch, EllpackPageImpl* dst,
int device_idx, float missing,
common::Span<size_t> row_counts) {
// Some witchcraft happens here
// The goal is to copy valid elements out of the input to an ellpack matrix with a given row stride, using no extra working memory
// Standard stream compaction needs to be modified to do this, so we manually define a segmented stream compaction via operators on an inclusive scan. The output of this inclusive scan is fed to a custom function which works out the correct output position
// The goal is to copy valid elements out of the input to an ellpack matrix
// with a given row stride, using no extra working memory Standard stream
// compaction needs to be modified to do this, so we manually define a
// segmented stream compaction via operators on an inclusive scan. The output
// of this inclusive scan is fed to a custom function which works out the
// correct output position
auto counting = thrust::make_counting_iterator(0llu);
IsValidFunctor is_valid(missing);
auto key_iter = dh::MakeTransformIterator<size_t >(counting,[=]__device__ (size_t idx)
{
return batch.GetElement(idx).row_idx;
});
auto value_iter = dh::MakeTransformIterator<size_t>(
auto key_iter = dh::MakeTransformIterator<size_t>(
counting,
[=]__device__ (size_t idx) -> size_t
{
return is_valid(batch.GetElement(idx));
});
[=] __device__(size_t idx) { return batch.GetElement(idx).row_idx; });
auto value_iter = dh::MakeTransformIterator<size_t>(
counting, [=] __device__(size_t idx) -> size_t {
return is_valid(batch.GetElement(idx));
});

auto key_value_index_iter = thrust::make_zip_iterator(thrust::make_tuple(key_iter, value_iter, counting));
auto key_value_index_iter = thrust::make_zip_iterator(
thrust::make_tuple(key_iter, value_iter, counting));

// Tuple[0] = The row index of the input, used as a key to define segments
// Tuple[1] = Scanned flags of valid elements for each row
// Tuple[2] = The index in the input data
using Tuple = thrust::tuple<size_t , size_t , size_t >;
using Tuple = thrust::tuple<size_t, size_t, size_t>;

auto device_accessor = dst->GetDeviceAccessor(device_idx);
common::CompressedBufferWriter writer(device_accessor.NumSymbols());
auto d_compressed_buffer = dst->gidx_buffer.DevicePointer();

// We redirect the scan output into this functor to do the actual writing
WriteCompressedEllpackFunctor<AdapterBatchT> functor(d_compressed_buffer, writer,
batch, device_accessor, is_valid);
WriteCompressedEllpackFunctor<AdapterBatchT> functor(
d_compressed_buffer, writer, batch, device_accessor, is_valid);
thrust::discard_iterator<size_t> discard;
thrust::transform_output_iterator<
WriteCompressedEllpackFunctor<AdapterBatchT>, decltype(discard)>
@@ -153,8 +154,8 @@ void CopyDataColumnMajor(AdapterT* adapter, const AdapterBatchT& batch,
dh::LaunchN(adapter->DeviceIdx(), batch.Size(), [=] __device__(size_t idx) {
const auto& e = batch.GetElement(idx);
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
&d_column_sizes[e.column_idx]),
static_cast<unsigned long long>(1)); // NOLINT
&d_column_sizes[e.column_idx]),
static_cast<unsigned long long>(1)); // NOLINT
});

thrust::host_vector<size_t> host_column_sizes = column_sizes;
@@ -173,59 +174,57 @@
size_t end = begin + size;
dh::LaunchN(adapter->DeviceIdx(), end - begin, [=] __device__(size_t idx) {
auto writer_non_const =
writer; // For some reason this variable gets captured as const
writer; // For some reason this variable gets captured as const
const auto& e = batch.GetElement(idx + begin);
if (!is_valid(e)) return;
size_t output_position = e.row_idx * row_stride + d_temp_row_ptr[e.row_idx];
size_t output_position =
e.row_idx * row_stride + d_temp_row_ptr[e.row_idx];
auto bin_idx = device_accessor.SearchBin(e.value, e.column_idx);
writer_non_const.AtomicWriteSymbol(d_compressed_buffer, bin_idx, output_position);
writer_non_const.AtomicWriteSymbol(d_compressed_buffer, bin_idx,
output_position);
d_temp_row_ptr[e.row_idx] += 1;
});

begin = end;
}
}

void WriteNullValues(EllpackPageImpl*dst,
int device_idx, common::Span<size_t> row_counts)
{
// Write the null values
void WriteNullValues(EllpackPageImpl* dst, int device_idx,
common::Span<size_t> row_counts) {
// Write the null values
auto device_accessor = dst->GetDeviceAccessor(device_idx);
common::CompressedBufferWriter writer(device_accessor.NumSymbols());
auto d_compressed_buffer = dst->gidx_buffer.DevicePointer();
auto row_stride = dst->row_stride;
dh::LaunchN(device_idx, row_stride * dst->n_rows, [=] __device__(size_t idx)
{
dh::LaunchN(device_idx, row_stride * dst->n_rows, [=] __device__(size_t idx) {
auto writer_non_const =
writer; // For some reason this variable gets captured as const
writer; // For some reason this variable gets captured as const
size_t row_idx = idx / row_stride;
size_t row_offset = idx % row_stride;
if (row_offset >= row_counts[row_idx])
{
if (row_offset >= row_counts[row_idx]) {
writer_non_const.AtomicWriteSymbol(d_compressed_buffer,
device_accessor.NullValue(), idx);
}
});

}
}
// Does not currently support metainfo as no on-device data source contains this
// Current implementation assumes a single batch. More batches can
// be supported in future. Does not currently support inferring row/column size
template <typename AdapterT>
template <typename AdapterT>
DeviceDMatrix::DeviceDMatrix(AdapterT* adapter, float missing, int nthread) {
common::HistogramCuts cuts = common::AdapterDeviceSketch(adapter, 256, missing);
auto & batch = adapter->Value();
common::HistogramCuts cuts =
common::AdapterDeviceSketch(adapter, 256, missing);
auto& batch = adapter->Value();
// Work out how many valid entries we have in each row
dh::caching_device_vector<size_t> row_counts(adapter->NumRows() + 1,
0);
common::Span<size_t > row_counts_span( row_counts.data().get(),row_counts.size() );
dh::caching_device_vector<size_t> row_counts(adapter->NumRows() + 1, 0);
common::Span<size_t> row_counts_span(row_counts.data().get(),
row_counts.size());
size_t row_stride =
GetRowCounts(batch, row_counts_span, adapter->DeviceIdx(), missing);

dh::XGBCachingDeviceAllocator<char> alloc;
info.num_nonzero_ = thrust::reduce(thrust::cuda::par(alloc),
row_counts.begin(),
row_counts.end());
row_counts.begin(), row_counts.end());
info.num_col_ = adapter->NumColumns();
info.num_row_ = adapter->NumRows();
ellpack_page_.reset(new EllpackPage());
@@ -239,8 +238,7 @@ DeviceDMatrix::DeviceDMatrix(AdapterT* adapter, float missing, int nthread) {
CopyDataColumnMajor(adapter, batch, ellpack_page_->Impl(), missing);
}

WriteNullValues(ellpack_page_->Impl(), adapter->DeviceIdx(),
row_counts_span);
WriteNullValues(ellpack_page_->Impl(), adapter->DeviceIdx(), row_counts_span);

// Synchronise worker columns
rabit::Allreduce<rabit::op::Max>(&info.num_col_, 1);
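The comments in CopyDataRowMajor and WriteNullValues above describe the central trick of this file: a segmented stream compaction built from an inclusive scan keyed by row index, where the running count of valid elements gives each entry its slot within a fixed row_stride, and any leftover slots are padded with the null symbol. The following serial C++ sketch mirrors that layout logic under assumed stand-in types (Entry, BinIndex and kNullSymbol are illustrative, and plain vectors replace the compressed device buffer); it is not the thrust/CUDA implementation.

// Serial sketch of the ELLPACK layout built in device_dmatrix.cu:
// 1) count valid (non-missing) entries per row, 2) take the maximum as
// row_stride, 3) place each valid entry at row * row_stride + (running count
// within its row) - 1, 4) pad unused slots with a null symbol.
// Entry, kNullSymbol and BinIndex() are stand-ins for the real types.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

struct Entry {
  std::size_t row_idx;
  std::size_t column_idx;
  float value;
};

constexpr std::size_t kNullSymbol = 256;  // assumed: one past the last bin

// Stand-in for EllpackDeviceAccessor::SearchBin(): map a value to a bin index.
std::size_t BinIndex(const Entry& e) {
  return static_cast<std::size_t>(e.value) % 256;
}

int main() {
  float missing = -1.0f;
  // Row-major COO input, including some missing values.
  std::vector<Entry> batch = {
      {0, 0, 1.0f}, {0, 1, missing}, {0, 2, 7.0f},
      {1, 0, missing}, {1, 1, 2.0f},
      {2, 0, 5.0f}, {2, 1, 9.0f}, {2, 2, 3.0f}};
  std::size_t n_rows = 3;
  auto is_valid = [&](const Entry& e) {
    return !std::isnan(e.value) && e.value != missing;
  };

  // Step 1: count valid entries per row (GetRowCounts does this with atomics).
  std::vector<std::size_t> row_counts(n_rows, 0);
  for (const auto& e : batch) row_counts[e.row_idx] += is_valid(e) ? 1 : 0;

  // Step 2: row_stride is the longest row (a reduce with thrust::maximum).
  std::size_t row_stride =
      *std::max_element(row_counts.begin(), row_counts.end());

  // Steps 3 and 4: segmented "scan" of valid flags keyed by row index, then
  // null padding, mirroring WriteCompressedEllpackFunctor and WriteNullValues.
  std::vector<std::size_t> ellpack(n_rows * row_stride, kNullSymbol);
  std::size_t scanned = 0;
  std::size_t current_row = 0;
  for (const auto& e : batch) {
    if (e.row_idx != current_row) {  // new segment: restart the running count
      current_row = e.row_idx;
      scanned = 0;
    }
    scanned += is_valid(e) ? 1 : 0;
    if (!is_valid(e)) continue;
    // -1 because the scan is inclusive.
    ellpack[e.row_idx * row_stride + scanned - 1] = BinIndex(e);
  }

  for (std::size_t r = 0; r < n_rows; ++r) {
    for (std::size_t c = 0; c < row_stride; ++c) {
      std::cout << ellpack[r * row_stride + c] << " ";
    }
    std::cout << "\n";
  }
}

On the device, the same running count comes from an inclusive scan over (row index, valid flag, element index) tuples with a custom segmented operator, and WriteCompressedEllpackFunctor consumes the scan output through a transform_output_iterator, so no intermediate buffer is needed.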
16 changes: 7 additions & 9 deletions src/data/device_dmatrix.h
@@ -12,8 +12,8 @@
#include <memory>

#include "adapter.h"
#include "simple_dmatrix.h"
#include "simple_batch_iterator.h"
#include "simple_dmatrix.h"

namespace xgboost {
namespace data {
@@ -22,7 +22,7 @@ class DeviceDMatrix : public DMatrix {
public:
template <typename AdapterT>
explicit DeviceDMatrix(AdapterT* adapter, float missing, int nthread);

MetaInfo& Info() override { return info; }

const MetaInfo& Info() const override { return info; }
@@ -31,19 +31,17 @@ class DeviceDMatrix : public DMatrix {

bool EllpackExists() const override { return true; }
bool SparsePageExists() const override { return false; }

private:
BatchSet<SparsePage> GetRowBatches() override
{
BatchSet<SparsePage> GetRowBatches() override {
LOG(FATAL) << "Not implemented.";
return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
}
BatchSet<CSCPage> GetColumnBatches()override
{
BatchSet<CSCPage> GetColumnBatches() override {
LOG(FATAL) << "Not implemented.";
return BatchSet<CSCPage>(BatchIterator<CSCPage>(nullptr));
}
BatchSet<SortedCSCPage> GetSortedColumnBatches()override
{
BatchSet<SortedCSCPage> GetSortedColumnBatches() override {
LOG(FATAL) << "Not implemented.";
return BatchSet<SortedCSCPage>(BatchIterator<SortedCSCPage>(nullptr));
}
@@ -59,4 +57,4 @@ class DeviceDMatrix : public DMatrix {
};
} // namespace data
} // namespace xgboost
#endif // XGBOOST_DATA_DEVICE_DMATRIX_H_
#endif // XGBOOST_DATA_DEVICE_DMATRIX_H_
1 change: 1 addition & 0 deletions tests/python-gpu/test_gpu_updaters.py
@@ -40,6 +40,7 @@ def test_gpu_hist(self):
cpu_results = run_suite(param, select_datasets=datasets)
assert_gpu_results(cpu_results, gpu_results)

@pytest.mark.skipif(**tm.no_cupy())
def test_gpu_hist_device_dmatrix(self):
# Cannot vary max_bin yet
device_dmatrix_test_param = parameter_combinations({
