diff --git a/paddle/fluid/distributed/table/graph/graph_weighted_sampler.cc b/paddle/fluid/distributed/table/graph/graph_weighted_sampler.cc index 4520513588267..a26017443fc99 100644 --- a/paddle/fluid/distributed/table/graph/graph_weighted_sampler.cc +++ b/paddle/fluid/distributed/table/graph/graph_weighted_sampler.cc @@ -24,47 +24,33 @@ void RandomSampler::build(GraphEdgeBlob *edges) { this->edges = edges; } std::vector RandomSampler::sample_k(int k, const std::shared_ptr rng) { int n = edges->size(); - if (k > n) { + if (k >= n) { k = n; - } - std::vector sample_result; - for(int i = 0;i < k;i ++ ) { + std::vector sample_result; + for (int i = 0; i < k; i++) { sample_result.push_back(i); + } + return sample_result; } - if (k == n) { - return sample_result; - } - - std::uniform_int_distribution distrib(0, n - 1); + std::vector sample_result; std::unordered_map replace_map; + while (k--) { + std::uniform_int_distribution distrib(0, n - 1); + int rand_int = distrib(*rng); + auto iter = replace_map.find(rand_int); + if (iter == replace_map.end()) { + sample_result.push_back(rand_int); + } else { + sample_result.push_back(iter->second); + } - for(int i = 0; i < k; i ++) { - int j = distrib(*rng); - if (j >= i) { - // buff_nid[offset + i] = nid[j] if m.find(j) == m.end() else nid[m[j]] - auto iter_j = replace_map.find(j); - if(iter_j == replace_map.end()) { - sample_result[i] = j; - } else { - sample_result[i] = iter_j -> second; - } - // m[j] = i if m.find(i) == m.end() else m[i] - auto iter_i = replace_map.find(i); - if(iter_i == replace_map.end()) { - replace_map[j] = i; - } else { - replace_map[j] = (iter_i -> second); - } + iter = replace_map.find(n - 1); + if (iter == replace_map.end()) { + replace_map[rand_int] = n - 1; } else { - sample_result[i] = sample_result[j]; - // buff_nid[offset + j] = nid[i] if m.find(i) == m.end() else nid[m[i]] - auto iter_i = replace_map.find(i); - if(iter_i == replace_map.end()) { - sample_result[j] = i; - } else { - sample_result[j] = (iter_i -> second); - } + replace_map[rand_int] = iter->second; } + --n; } return sample_result; }