Skip to content

Commit

Permalink
pull push dedup, use pull feature value clip size (PaddlePaddle#62)
Browse files Browse the repository at this point in the history
pull push dedup, clip pull feature value
  • Loading branch information
qingshui authored Jul 12, 2022
1 parent 84d4b98 commit 15aeb84
Show file tree
Hide file tree
Showing 14 changed files with 1,749 additions and 859 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/ps/table/common_graph_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1205,7 +1205,7 @@ std::pair<uint64_t, uint64_t> GraphTable::parse_node_file(
auto node = feature_shards[idx][index]->add_feature_node(id, false);
if (node != NULL) {
node->set_feature_size(feat_name[idx].size());
for (int i = 1; i < n; ++i) {
for (int i = 1; i < num; ++i) {
auto &v = vals[i];
parse_feature(idx, v.ptr, v.len, node);
}
Expand Down
19 changes: 19 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/feature_value.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,24 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor {
int optimizer_type_;
};

// Layout descriptor for one pulled (GPU -> trainer) feature value.
// All fields are packed into a flat float buffer in this order:
//   float show;                  // impression count
//   float click;                 // click count
//   float embed_w;               // scalar embedding weight
//   float mf_size;               // actual embedx length (may be 0 when mf not expanded)
//   float embedx_w[mf_dim];      // embedx weights
struct CommonPullValue {
  // Offsets are chained so the layout order is stated exactly once.
  __host__ __device__ int ShowIndex() { return 0; }
  __host__ __device__ int ClickIndex() { return ShowIndex() + 1; }
  __host__ __device__ int EmbedWIndex() { return ClickIndex() + 1; }
  // Holds the actual mf size (e.g. 0 when the mf part is absent).
  __host__ __device__ int MfSizeIndex() { return EmbedWIndex() + 1; }
  __host__ __device__ int EmbedxWIndex() { return MfSizeIndex() + 1; }
  // Total byte size of one pull value: 4 scalar floats + mf_dim embedx floats.
  __host__ __device__ int Size(const int mf_dim) {
    return (EmbedxWIndex() + mf_dim) * sizeof(float);
  }
};

struct CommonPushValue {
/*
float slot;
Expand Down Expand Up @@ -251,6 +269,7 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor {
public:
CommonFeatureValue common_feature_value;
CommonPushValue common_push_value;
CommonPullValue common_pull_value;
};


Expand Down
41 changes: 14 additions & 27 deletions paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -90,37 +90,24 @@ __global__ void dy_mf_search_kernel(Table* table,
// return;
if (i < len) {
auto it = table->find(keys[i]);

if (it != table->end()) {
uint64_t offset = i * pull_feature_value_size;
float* cur = (float*)(vals + offset);
float* input = it->second;
int mf_dim = int(input[feature_value_accessor.common_feature_value.MfDimIndex()]);

*(reinterpret_cast<uint64_t*>(cur + feature_value_accessor.common_feature_value.CpuPtrIndex())) =
*(reinterpret_cast<uint64_t*>(input + feature_value_accessor.common_feature_value.CpuPtrIndex()));
cur[feature_value_accessor.common_feature_value.DeltaScoreIndex()] =
input[feature_value_accessor.common_feature_value.DeltaScoreIndex()];
cur[feature_value_accessor.common_feature_value.ShowIndex()] =
input[feature_value_accessor.common_feature_value.ShowIndex()];
cur[feature_value_accessor.common_feature_value.ClickIndex()] =
input[feature_value_accessor.common_feature_value.ClickIndex()];
cur[feature_value_accessor.common_feature_value.EmbedWIndex()] =
input[feature_value_accessor.common_feature_value.EmbedWIndex()];
for (int x = 0; x < feature_value_accessor.common_feature_value.EmbedDim(); x++) {
cur[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + x] =
input[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + x];
}
cur[feature_value_accessor.common_feature_value.SlotIndex()] =
input[feature_value_accessor.common_feature_value.SlotIndex()];
cur[feature_value_accessor.common_feature_value.MfDimIndex()] =
input[feature_value_accessor.common_feature_value.MfDimIndex()];
cur[feature_value_accessor.common_feature_value.MfSizeIndex()] =
input[feature_value_accessor.common_feature_value.MfSizeIndex()];

for (int x = feature_value_accessor.common_feature_value.EmbedxG2SumIndex();
x < int(feature_value_accessor.common_feature_value.Size(mf_dim) / sizeof(float)); x++){
cur[x] = input[x];

cur[feature_value_accessor.common_pull_value.ShowIndex()] =
input[feature_value_accessor.common_feature_value.ShowIndex()];
cur[feature_value_accessor.common_pull_value.ClickIndex()] =
input[feature_value_accessor.common_feature_value.ClickIndex()];
cur[feature_value_accessor.common_pull_value.EmbedWIndex()] =
input[feature_value_accessor.common_feature_value.EmbedWIndex()];
int embedx_dim = int(input[feature_value_accessor.common_feature_value.MfSizeIndex()]);
cur[feature_value_accessor.common_pull_value.MfSizeIndex()] = embedx_dim;

int embedx_off = feature_value_accessor.common_pull_value.EmbedxWIndex();
int value_off = feature_value_accessor.common_feature_value.EmbedxWOffsetIndex(input);
for (int i = 0; i < embedx_dim; ++i) {
cur[embedx_off + i] = input[value_off + i];
}
}
}
Expand Down
38 changes: 28 additions & 10 deletions paddle/fluid/framework/fleet/heter_ps/heter_comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <thread>
#include <vector>

#include "cub/cub.cuh"
#include "cub/util_allocator.cuh"
#if defined(PADDLE_WITH_CUDA)
Expand All @@ -25,6 +26,7 @@ limitations under the License. */
#include "thrust/pair.h"
#elif defined(PADDLE_WITH_XPU_KP)
#include <xpu/runtime.h>

#include "paddle/fluid/platform/device/xpu/enforce_xpu.h"
#endif

Expand All @@ -47,7 +49,7 @@ template <typename KeyType, typename ValType, typename GradType>
class HeterComm {
public:
HeterComm(size_t capacity, std::shared_ptr<HeterPsResource> resource);
HeterComm(size_t capacity, std::shared_ptr<HeterPsResource> resource,
HeterComm(size_t capacity, std::shared_ptr<HeterPsResource> resource,
CommonFeatureValueAccessor& accessor);
virtual ~HeterComm();
HeterComm(const HeterComm&) = delete;
Expand All @@ -61,18 +63,19 @@ class HeterComm {
uint32_t* d_restore_idx,
size_t & uniq_len);
void merge_grad(int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len,
int& uniq_len); // NOLINT
int& uniq_len); // NOLINT
void dynamic_merge_grad(int gpu_num, KeyType* d_keys, float* d_grads,
size_t len, int& uniq_len, size_t& segment_len, bool enable_segment_merge_grad);
size_t len, int& uniq_len, size_t& segment_len,
bool enable_segment_merge_grad);
void segment_merge_grad(int gpu_num, KeyType* d_keys, float* d_grads,
const uint32_t* d_index, size_t len,
const uint32_t* d_fea_num_info,
size_t uniq_len, size_t& segment_len);
const uint32_t* d_index, size_t len,
const uint32_t* d_fea_num_info, size_t uniq_len,
size_t& segment_len);
void pull_sparse(int num, KeyType* d_keys, float* d_vals, size_t len);
void build_ps(int num, KeyType* h_keys, ValType* h_vals, size_t len,
size_t chunk_size, int stream_num, int offset = -1);
size_t chunk_size, int stream_num, int offset = -1);
void build_ps(int num, KeyType* h_keys, char* pool, size_t len,
size_t feature_value_size, size_t chunk_size, int stream_num);
size_t feature_value_size, size_t chunk_size, int stream_num);
void dump();
void show_one_table(int gpu_num);
void show_table_collisions();
Expand Down Expand Up @@ -124,7 +127,7 @@ class HeterComm {
}

void set_accessor(CommonFeatureValueAccessor& accessor) {
feature_value_accessor_ = accessor;
feature_value_accessor_ = accessor;
}
#endif

Expand All @@ -137,6 +140,19 @@ class HeterComm {
int get_transfer_devid(int send_id) { return (send_id + 4) % 8; }

void end_pass();
#if defined(PADDLE_WITH_CUDA)
// dedup
int dedup_keys_and_fillidx(const int gpu_id,
const int total_fea_num,
const KeyType* d_keys, // input
KeyType* d_merged_keys, // output
KeyType* d_sorted_keys,
uint32_t* d_restore_idx,
uint32_t* d_sorted_idx,
uint32_t* d_offset,
uint32_t* d_merged_cnts,
bool filter_zero);
#endif

struct Node {
ppStream in_stream;
Expand Down Expand Up @@ -243,7 +259,9 @@ class HeterComm {
ValType* src_val);
void walk_to_src(int start_index, int gpu_num, int* h_left, int* h_right,
char* src_val, size_t val_size);

protected:
void pull_merge_sparse(int num, KeyType* d_keys, float* d_vals, size_t len);
void pull_normal_sparse(int num, KeyType* d_keys, float* d_vals, size_t len);

protected:
using Table = HashTable<KeyType, ValType>;
Expand Down
Loading

0 comments on commit 15aeb84

Please sign in to comment.