diff --git a/include/LightGBM/metric.h b/include/LightGBM/metric.h index 61d9fc99ea80..9d505d2768d1 100644 --- a/include/LightGBM/metric.h +++ b/include/LightGBM/metric.h @@ -103,6 +103,14 @@ class DCGCalculator { static double CalMaxDCGAtK(data_size_t k, const label_t* label, data_size_t num_data); + + /*! + * \brief Check the metadata for NDCG and lambdarank + * \param metadata Metadata + * \param num_queries Number of queries + */ + static void CheckMetadata(const Metadata& metadata, data_size_t num_queries); + /*! * \brief Check the label range for NDCG and lambdarank * \param label Pointer of label diff --git a/src/io/metadata.cpp b/src/io/metadata.cpp index 63a1690906a2..49fc834b87df 100644 --- a/src/io/metadata.cpp +++ b/src/io/metadata.cpp @@ -277,6 +277,10 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector 0) { + Log::Debug("Number of queries in %s: %i. Average number of rows per query: %f.", + data_filename_.c_str(), static_cast(num_queries_), static_cast(num_data_) / num_queries_); + } } void Metadata::SetInitScore(const double* init_score, data_size_t len) { diff --git a/src/metric/dcg_calculator.cpp b/src/metric/dcg_calculator.cpp index 58843d89f9e1..1d648bfafd40 100644 --- a/src/metric/dcg_calculator.cpp +++ b/src/metric/dcg_calculator.cpp @@ -152,6 +152,19 @@ void DCGCalculator::CalDCG(const std::vector& ks, const label_t* la } } +void DCGCalculator::CheckMetadata(const Metadata& metadata, data_size_t num_queries) { + const data_size_t* query_boundaries = metadata.query_boundaries(); + if (num_queries > 0 && query_boundaries != nullptr) { + for (data_size_t i = 0; i < num_queries; i++) { + data_size_t num_rows = query_boundaries[i + 1] - query_boundaries[i]; + if (num_rows > kMaxPosition) { + Log::Fatal("Number of rows %i exceeds upper limit of %i for a query", static_cast(num_rows), static_cast(kMaxPosition)); + } + } + } +} + + void DCGCalculator::CheckLabel(const label_t* label, data_size_t num_data) { for (data_size_t i = 0; i < num_data; ++i) { label_t delta = std::fabs(label[i] - static_cast(label[i])); diff --git a/src/metric/rank_metric.hpp b/src/metric/rank_metric.hpp index 3b3afb547eb9..58804f415278 100644 --- a/src/metric/rank_metric.hpp +++ b/src/metric/rank_metric.hpp @@ -37,13 +37,14 @@ class NDCGMetric:public Metric { num_data_ = num_data; // get label label_ = metadata.label(); + num_queries_ = metadata.num_queries(); + DCGCalculator::CheckMetadata(metadata, num_queries_); DCGCalculator::CheckLabel(label_, num_data_); // get query boundaries query_boundaries_ = metadata.query_boundaries(); if (query_boundaries_ == nullptr) { Log::Fatal("The NDCG metric requires query information"); } - num_queries_ = metadata.num_queries(); // get query weights query_weights_ = metadata.query_weights(); if (query_weights_ == nullptr) { diff --git a/src/objective/rank_objective.hpp b/src/objective/rank_objective.hpp index a720a69a3148..9bd7b7d99cf6 100644 --- a/src/objective/rank_objective.hpp +++ b/src/objective/rank_objective.hpp @@ -120,6 +120,7 @@ class LambdarankNDCG : public RankingObjective { void Init(const Metadata& metadata, data_size_t num_data) override { RankingObjective::Init(metadata, num_data); + DCGCalculator::CheckMetadata(metadata, num_queries_); DCGCalculator::CheckLabel(label_, num_data_); inverse_max_dcgs_.resize(num_queries_); #pragma omp parallel for schedule(static)