Skip to content

Commit

Permalink
refactor LGBM_DatasetGetFeatureNames (#3022)
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS authored Jun 11, 2020
1 parent b3a84df commit f30e0bb
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 12 deletions.
13 changes: 10 additions & 3 deletions R-package/src/lightgbm_R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,16 +158,23 @@ LGBM_SE LGBM_DatasetGetFeatureNames_R(LGBM_SE handle,
R_API_BEGIN();
int len = 0;
CHECK_CALL(LGBM_DatasetGetNumFeature(R_GET_PTR(handle), &len));
const size_t reserved_string_size = 256;
std::vector<std::vector<char>> names(len);
std::vector<char*> ptr_names(len);
for (int i = 0; i < len; ++i) {
names[i].resize(256);
names[i].resize(reserved_string_size);
ptr_names[i] = names[i].data();
}
int out_len;
CHECK_CALL(LGBM_DatasetGetFeatureNames(R_GET_PTR(handle),
ptr_names.data(), &out_len));
size_t required_string_size;
CHECK_CALL(
LGBM_DatasetGetFeatureNames(
R_GET_PTR(handle),
len, &out_len,
reserved_string_size, &required_string_size,
ptr_names.data()));
CHECK_EQ(len, out_len);
CHECK_GE(reserved_string_size, required_string_size);
auto merge_str = Join<char*>(ptr_names, "\t");
EncodeChar(feature_names, merge_str.c_str(), buf_len, actual_len, merge_str.size() + 1);
R_API_END();
Expand Down
14 changes: 11 additions & 3 deletions include/LightGBM/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -280,13 +280,21 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetFeatureNames(DatasetHandle handle,
/*!
* \brief Get feature names of dataset.
* \param handle Handle of dataset
* \param[out] feature_names Feature names, should pre-allocate memory
* \param len Number of ``char*`` pointers stored at ``out_strs``.
* If smaller than the max size, only this many strings are copied
* \param[out] num_feature_names Number of feature names
* \param buffer_len Size of pre-allocated strings.
* Content is copied up to ``buffer_len - 1`` and null-terminated
* \param[out] out_buffer_len String sizes required to do the full string copies
* \param[out] feature_names Feature names, should pre-allocate memory
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetGetFeatureNames(DatasetHandle handle,
char** feature_names,
int* num_feature_names);
const int len,
int* num_feature_names,
const size_t buffer_len,
size_t* out_buffer_len,
char** feature_names);

/*!
* \brief Free space for dataset.
Expand Down
32 changes: 32 additions & 0 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1553,6 +1553,38 @@ def set_group(self, group):
self.set_field('group', group)
return self

def get_feature_name(self):
"""Get the names of columns (features) in the Dataset.
Returns
-------
feature_names : list
The names of columns (features) in the Dataset.
"""
if self.handle is None:
raise LightGBMError("Cannot get feature_name before construct dataset")
num_feature = self.num_feature()
tmp_out_len = ctypes.c_int(0)
reserved_string_buffer_size = 255
required_string_buffer_size = ctypes.c_size_t(0)
string_buffers = [ctypes.create_string_buffer(reserved_string_buffer_size) for i in range_(num_feature)]
ptr_string_buffers = (ctypes.c_char_p * num_feature)(*map(ctypes.addressof, string_buffers))
_safe_call(_LIB.LGBM_DatasetGetFeatureNames(
self.handle,
num_feature,
ctypes.byref(tmp_out_len),
reserved_string_buffer_size,
ctypes.byref(required_string_buffer_size),
ptr_string_buffers))
if num_feature != tmp_out_len.value:
raise ValueError("Length of feature names doesn't equal with num_feature")
if reserved_string_buffer_size < required_string_buffer_size.value:
raise BufferError(
"Allocated feature name buffer size ({}) was inferior to the needed size ({})."
.format(reserved_string_buffer_size, required_string_buffer_size.value)
)
return [string_buffers[i].value.decode('utf-8') for i in range_(num_feature)]

def get_label(self):
"""Get the label of the Dataset.
Expand Down
18 changes: 13 additions & 5 deletions src/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1110,15 +1110,23 @@ int LGBM_DatasetSetFeatureNames(
}

int LGBM_DatasetGetFeatureNames(
DatasetHandle handle,
char** feature_names,
int* num_feature_names) {
API_BEGIN();
DatasetHandle handle,
const int len,
int* num_feature_names,
const size_t buffer_len,
size_t* out_buffer_len,
char** feature_names) {
API_BEGIN();
*out_buffer_len = 0;
auto dataset = reinterpret_cast<Dataset*>(handle);
auto inside_feature_name = dataset->feature_names();
*num_feature_names = static_cast<int>(inside_feature_name.size());
for (int i = 0; i < *num_feature_names; ++i) {
std::memcpy(feature_names[i], inside_feature_name[i].c_str(), inside_feature_name[i].size() + 1);
if (i < len) {
std::memcpy(feature_names[i], inside_feature_name[i].c_str(), std::min(inside_feature_name[i].size() + 1, buffer_len));
feature_names[i][buffer_len - 1] = '\0';
}
*out_buffer_len = std::max(inside_feature_name[i].size() + 1, *out_buffer_len);
}
API_END();
}
Expand Down
7 changes: 6 additions & 1 deletion tests/python_package_test/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,15 +271,20 @@ def check_asserts(data):
self.assertTrue(np.all(np.isclose([data.label[0], data.weight[0], data.init_score[0]],
data.label[0])))
self.assertAlmostEqual(data.label[1], data.weight[1])
self.assertListEqual(data.feature_name, data.get_feature_name())

X, y = load_breast_cancer(True)
sequence = np.ones(y.shape[0])
sequence[0] = np.nan
sequence[1] = np.inf
lgb_data = lgb.Dataset(X, sequence, weight=sequence, init_score=sequence).construct()
feature_names = ['f{0}'.format(i) for i in range(X.shape[1])]
lgb_data = lgb.Dataset(X, sequence,
weight=sequence, init_score=sequence,
feature_name=feature_names).construct()
check_asserts(lgb_data)
lgb_data = lgb.Dataset(X, y).construct()
lgb_data.set_label(sequence)
lgb_data.set_weight(sequence)
lgb_data.set_init_score(sequence)
lgb_data.set_feature_name(feature_names)
check_asserts(lgb_data)

0 comments on commit f30e0bb

Please sign in to comment.