From e8ecafb8dc628f45b75b4c2844a236d27e0a6d98 Mon Sep 17 00:00:00 2001 From: fis Date: Tue, 16 Jun 2020 20:00:24 +0800 Subject: [PATCH] Accept string for ArrayInterface constructor. --- src/data/array_interface.h | 41 ++++++++++++++++----- src/data/data.cu | 2 +- tests/cpp/data/test_array_interface.cc | 51 ++++++++++++++++++++++++++ tests/cpp/helpers.cc | 8 +++- 4 files changed, 90 insertions(+), 12 deletions(-) create mode 100644 tests/cpp/data/test_array_interface.cc diff --git a/src/data/array_interface.h b/src/data/array_interface.h index 53152b63d9d9..c4db6c5dc66c 100644 --- a/src/data/array_interface.h +++ b/src/data/array_interface.h @@ -1,7 +1,7 @@ /*! * Copyright 2019 by Contributors * \file array_interface.h - * \brief Basic structure holding a reference to arrow columnar data format. + * \brief View of __array_interface__ */ #ifndef XGBOOST_DATA_ARRAY_INTERFACE_H_ #define XGBOOST_DATA_ARRAY_INTERFACE_H_ @@ -11,6 +11,7 @@ #include #include +#include "xgboost/base.h" #include "xgboost/data.h" #include "xgboost/json.h" #include "xgboost/logging.h" @@ -113,6 +114,7 @@ class ArrayInterfaceHandler { get( obj.at("data")) .at(0)))); + CHECK(p_data); return p_data; } @@ -186,7 +188,7 @@ class ArrayInterfaceHandler { return 0; } - static std::pair ExtractShape( + static std::pair ExtractShape( std::map const& column) { auto j_shape = get(column.at("shape")); auto typestr = get(column.at("typestr")); @@ -201,12 +203,12 @@ class ArrayInterfaceHandler { } if (j_shape.size() == 1) { - return {static_cast(get(j_shape.at(0))), 1}; + return {static_cast(get(j_shape.at(0))), 1}; } else { CHECK_EQ(j_shape.size(), 2) << "Only 1D or 2-D arrays currently supported."; - return {static_cast(get(j_shape.at(0))), - static_cast(get(j_shape.at(1)))}; + return {static_cast(get(j_shape.at(0))), + static_cast(get(j_shape.at(1)))}; } } template @@ -219,7 +221,6 @@ class ArrayInterfaceHandler { CHECK_EQ(typestr.at(2), static_cast(sizeof(T) + 48)) << "Input data type and typestr mismatch. typestr: " << typestr; - auto shape = ExtractShape(column); T* p_data = ArrayInterfaceHandler::GetPtrFromArrayData(column); @@ -231,8 +232,8 @@ class ArrayInterfaceHandler { class ArrayInterface { public: ArrayInterface() = default; - explicit ArrayInterface(std::map const &column, - bool allow_mask = true) { + void Initialize(std::map const &column, + bool allow_mask = true) { ArrayInterfaceHandler::Validate(column); data = ArrayInterfaceHandler::GetPtrFromArrayData(column); CHECK(data) << "Column is null"; @@ -263,6 +264,25 @@ class ArrayInterface { this->CheckType(); } + explicit ArrayInterface(std::string const& str, bool allow_mask = true) { + auto jinterface = Json::Load({str.c_str(), str.size()}); + if (IsA(jinterface)) { + this->Initialize(get(jinterface), allow_mask); + return; + } + if (IsA(jinterface)) { + CHECK_EQ(get(jinterface).size(), 1) + << "Column: " << ArrayInterfaceErrors::Dimension(1); + this->Initialize(get(get(jinterface)[0]), allow_mask); + return; + } + } + + explicit ArrayInterface(std::map const &column, + bool allow_mask = true) { + this->Initialize(column, allow_mask); + } + void CheckType() const { if (type[1] == 'f' && type[2] == '4') { return; @@ -291,6 +311,7 @@ class ArrayInterface { } XGBOOST_DEVICE float GetElement(size_t idx) const { + SPAN_CHECK(idx < num_cols * num_rows); if (type[1] == 'f' && type[2] == '4') { return reinterpret_cast(data)[idx]; } else if (type[1] == 'f' && type[2] == '8') { @@ -318,8 +339,8 @@ class ArrayInterface { } RBitField8 valid; - int32_t num_rows; - int32_t num_cols; + bst_row_t num_rows; + bst_feature_t num_cols; void* data; char type[3]; }; diff --git a/src/data/data.cu b/src/data/data.cu index 526f9a67366e..fb57f4751545 100644 --- a/src/data/data.cu +++ b/src/data/data.cu @@ -63,7 +63,7 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) { auto const& j_arr = get(j_interface); CHECK_EQ(j_arr.size(), 1) << "MetaInfo: " << c_key << ". " << ArrayInterfaceErrors::Dimension(1); - ArrayInterface array_interface(get(j_arr[0])); + ArrayInterface array_interface(interface_str); std::string key{c_key}; CHECK(!array_interface.valid.Data()) << "Meta info " << key << " should be dense, found validity mask"; diff --git a/tests/cpp/data/test_array_interface.cc b/tests/cpp/data/test_array_interface.cc new file mode 100644 index 000000000000..5fe93ffa493b --- /dev/null +++ b/tests/cpp/data/test_array_interface.cc @@ -0,0 +1,51 @@ +/*! + * Copyright 2020 by XGBoost Contributors + */ +#include +#include +#include "../helpers.h" +#include "../../../src/data/array_interface.h" + +namespace xgboost { +TEST(ArrayInterface, Initialize) { + size_t constexpr kRows = 10, kCols = 10; + HostDeviceVector storage; + auto array = RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage); + auto arr_interface = ArrayInterface(array); + ASSERT_EQ(arr_interface.num_rows, kRows); + ASSERT_EQ(arr_interface.num_cols, kCols); + ASSERT_EQ(arr_interface.data, storage.ConstHostPointer()); +} + +TEST(ArrayInterface, Error) { + constexpr size_t kRows = 16, kCols = 10; + Json column { Object() }; + std::vector j_shape {Json(Integer(static_cast(kRows)))}; + column["shape"] = Array(j_shape); + std::vector j_data { + Json(Integer(reinterpret_cast(nullptr))), + Json(Boolean(false))}; + + auto const& column_obj = get(column); + // missing version + EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj), dmlc::Error); + column["version"] = Integer(static_cast(1)); + // missing data + EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj), dmlc::Error); + column["data"] = j_data; + // missing typestr + EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj), dmlc::Error); + column["typestr"] = String("(column_obj), dmlc::Error); + + HostDeviceVector storage; + auto array = RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage); + j_data = { + Json(Integer(reinterpret_cast(storage.ConstHostPointer()))), + Json(Boolean(false))}; + column["data"] = j_data; + EXPECT_NO_THROW(ArrayInterfaceHandler::ExtractData(column_obj)); +} + +} // namespace xgboost diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index 893891d13b9f..2274e57e7307 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -182,7 +182,13 @@ Json RandomDataGenerator::ArrayInterfaceImpl(HostDeviceVector *storage, this->GenerateDense(storage); Json array_interface {Object()}; array_interface["data"] = std::vector(2); - array_interface["data"][0] = Integer(reinterpret_cast(storage->DevicePointer())); + if (storage->DeviceCanRead()) { + array_interface["data"][0] = + Integer(reinterpret_cast(storage->ConstDevicePointer())); + } else { + array_interface["data"][0] = + Integer(reinterpret_cast(storage->ConstHostPointer())); + } array_interface["data"][1] = Boolean(false); array_interface["shape"] = std::vector(2);