Skip to content

Commit

Permalink
more stable sampling K from N
Browse files Browse the repository at this point in the history
  • Loading branch information
guolinke committed Oct 9, 2017
1 parent 532ae9d commit 7999741
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions include/LightGBM/utils/random.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include <random>
#include <vector>
#include <set>

namespace LightGBM {

Expand Down Expand Up @@ -65,29 +66,29 @@ class Random {
inline std::vector<int> Sample(int N, int K) {
std::vector<int> ret;
ret.reserve(K);
if (K > N || K < 0) {
if (K > N || K <= 0) {
return ret;
} else if (K == N) {
for (int i = 0; i < N; ++i) {
ret.push_back(i);
}
} else if (K > N / 2) {
} else if (K > 1 && K > (N / std::log2(K))) {
for (int i = 0; i < N; ++i) {
double prob = (K - ret.size()) / static_cast<double>(N - i);
if (NextFloat() < prob) {
ret.push_back(i);
}
}
} else {
int min_step = 1;
int avg_step = N / K;
int max_step = 2 * avg_step - min_step;
int start = -1;
for (int i = 0; i < K; ++i) {
int step = NextShort(min_step, max_step + 1);
start += step;
if (start >= N) { break; }
ret.push_back(start);
std::set<int> sample_set;
while (sample_set.size() < K) {
int next = RandInt32() % N;
if (sample_set.count(next) == 0) {
sample_set.insert(next);
}
}
for (auto iter = sample_set.begin(); iter != sample_set.end(); ++iter) {
ret.push_back(*iter);
}
}
return ret;
Expand Down

0 comments on commit 7999741

Please sign in to comment.