Gpugraph #46
Merged (3 commits, Jun 29, 2022)
36 changes: 18 additions & 18 deletions paddle/fluid/distributed/ps/table/common_graph_table.cc
@@ -1763,8 +1763,8 @@ int GraphTable::parse_feature(int idx, const std::string& feat_str,
return -1;
}

std::vector<std::vector<uint64_t>> GraphTable::get_all_id(int type_id, int slice_num) {
std::vector<std::vector<uint64_t>> res(slice_num);
int GraphTable::get_all_id(int type_id, int slice_num, std::vector<std::vector<uint64_t>> *output) {
output->resize(slice_num);
auto &search_shards = type_id == 0 ? edge_shards : feature_shards;
std::vector<std::future<std::vector<uint64_t>>> tasks;
for (int idx = 0; idx < search_shards.size(); idx++) {
@@ -1781,14 +1781,14 @@ std::vector<std::vector<uint64_t>> GraphTable::get_all_id(int type_id, int slice
for (size_t i = 0; i < tasks.size(); i++) {
auto ids = tasks[i].get();
for (auto &id : ids) {
res[(uint64_t)(id) % slice_num].push_back(id);
(*output)[(uint64_t)(id) % slice_num].push_back(id);
}
}
return res;
return 0;
}

std::vector<std::vector<uint64_t>> GraphTable::get_all_neighbor_id(int type_id, int slice_num) {
std::vector<std::vector<uint64_t>> res(slice_num);
int GraphTable::get_all_neighbor_id(int type_id, int slice_num, std::vector<std::vector<uint64_t>> *output) {
output->resize(slice_num);
auto &search_shards = type_id == 0 ? edge_shards : feature_shards;
std::vector<std::future<std::vector<uint64_t>>> tasks;
for (int idx = 0; idx < search_shards.size(); idx++) {
@@ -1805,15 +1805,15 @@ std::vector<std::vector<uint64_t>> GraphTable::get_all_neighbor_id(int type_id,
for (size_t i = 0; i < tasks.size(); i++) {
auto ids = tasks[i].get();
for (auto &id : ids) {
res[(uint64_t)(id) % slice_num].push_back(id);
(*output)[(uint64_t)(id) % slice_num].push_back(id);
}
}
return res;
return 0;
}

std::vector<std::vector<uint64_t>> GraphTable::get_all_id(int type_id, int idx,
int slice_num) {
std::vector<std::vector<uint64_t>> res(slice_num);
int GraphTable::get_all_id(int type_id, int idx,
int slice_num, std::vector<std::vector<uint64_t>> *output) {
output->resize(slice_num);
auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
std::vector<std::future<std::vector<uint64_t>>> tasks;
VLOG(0) << "begin task, task_pool_size_[" << task_pool_size_ << "]";
@@ -1829,14 +1829,14 @@ std::vector<std::vector<uint64_t>> GraphTable::get_all_id(int type_id, int idx,
VLOG(0) << "end task, task_pool_size_[" << task_pool_size_ << "]";
for (size_t i = 0; i < tasks.size(); i++) {
auto ids = tasks[i].get();
for (auto &id : ids) res[id % slice_num].push_back(id);
for (auto &id : ids) (*output)[id % slice_num].push_back(id);
}
return res;
return 0;
}

std::vector<std::vector<uint64_t>> GraphTable::get_all_neighbor_id(int type_id, int idx,
int slice_num) {
std::vector<std::vector<uint64_t>> res(slice_num);
int GraphTable::get_all_neighbor_id(int type_id, int idx,
int slice_num, std::vector<std::vector<uint64_t>> *output) {
output->resize(slice_num);
auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
std::vector<std::future<std::vector<uint64_t>>> tasks;
VLOG(0) << "begin task, task_pool_size_[" << task_pool_size_ << "]";
@@ -1852,9 +1852,9 @@ std::vector<std::vector<uint64_t>> GraphTable::get_all_neighbor_id(int type_id,
VLOG(0) << "end task, task_pool_size_[" << task_pool_size_ << "]";
for (size_t i = 0; i < tasks.size(); i++) {
auto ids = tasks[i].get();
for (auto &id : ids) res[id % slice_num].push_back(id);
for (auto &id : ids) (*output)[id % slice_num].push_back(id);
}
return res;
return 0;
}

int GraphTable::get_all_feature_ids(int type_id, int idx, int slice_num,
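All four of the converted functions above share the same pattern: fan one scan task out per shard, then bucket the collected ids into the caller-provided vector by id % slice_num and return 0 on success. Below is a minimal standalone sketch of that pattern, with a stand-in Shard type and std::async in place of the table's shard task pool; it is an illustration, not the actual GraphTable code.

#include <cstdint>
#include <future>
#include <vector>

// Stand-in for a graph shard; only the id scan is modeled here.
struct Shard {
  std::vector<uint64_t> ids;
  std::vector<uint64_t> get_all_id() const { return ids; }
};

// Same shape as the new GraphTable::get_all_id: results go into *output,
// the return value is a status code.
int get_all_id_sketch(const std::vector<Shard> &shards, int slice_num,
                      std::vector<std::vector<uint64_t>> *output) {
  output->resize(slice_num);
  std::vector<std::future<std::vector<uint64_t>>> tasks;
  for (const auto &shard : shards) {
    // One asynchronous scan per shard (the real table uses its task pool).
    tasks.push_back(
        std::async(std::launch::async, [&shard] { return shard.get_all_id(); }));
  }
  for (auto &task : tasks) {
    for (uint64_t id : task.get()) {
      (*output)[id % slice_num].push_back(id);  // bucket ids across slices
    }
  }
  return 0;
}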
12 changes: 6 additions & 6 deletions paddle/fluid/distributed/ps/table/common_graph_table.h
@@ -503,12 +503,12 @@ class GraphTable : public Table {
int32_t load_edges(const std::string &path, bool reverse,
const std::string &edge_type);

std::vector<std::vector<uint64_t>> get_all_id(int type, int slice_num);
std::vector<std::vector<uint64_t>> get_all_neighbor_id(int type, int slice_num);
std::vector<std::vector<uint64_t>> get_all_id(int type, int idx,
int slice_num);
std::vector<std::vector<uint64_t>> get_all_neighbor_id(int type_id, int idx,
int slice_num);
int get_all_id(int type, int slice_num, std::vector<std::vector<uint64_t>> *output);
int get_all_neighbor_id(int type, int slice_num, std::vector<std::vector<uint64_t>> *output);
int get_all_id(int type, int idx,
int slice_num, std::vector<std::vector<uint64_t>> *output);
int get_all_neighbor_id(int type_id, int id,
int slice_num, std::vector<std::vector<uint64_t>> *output);
int get_all_feature_ids(int type, int idx,
int slice_num, std::vector<std::vector<uint64_t>>* output);
int32_t load_nodes(const std::string &path, std::string node_type = std::string());
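For a caller, the migration implied by the declarations above looks like the sketch below; GraphTableRef is a hypothetical stand-in for the GraphTable interface, and the real call-site updates are in data_set.cc further down in this diff.

#include <cstdint>
#include <vector>

// Hypothetical stand-in with the new output-parameter signature.
struct GraphTableRef {
  int get_all_id(int type, int idx, int slice_num,
                 std::vector<std::vector<uint64_t>> *output) {
    output->assign(slice_num, std::vector<uint64_t>());  // stub body
    return 0;
  }
};

void collect_node_keys(GraphTableRef &table, int node_idx, int thread_num) {
  // Before: auto keys = table.get_all_id(1, node_idx, thread_num);
  // After: the caller owns the buffer and checks an int status instead.
  std::vector<std::vector<uint64_t>> keys;
  int ret = table.get_all_id(/*type=*/1, node_idx, thread_num, &keys);
  (void)ret;  // 0 on success
}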
11 changes: 6 additions & 5 deletions paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc
@@ -29,6 +29,7 @@ int CtrDymfAccessor::Initialize() {
_embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
_embedx_sgd_rule->LoadConfig(_config.embedx_sgd_param(),
_config.embedx_dim());
common_feature_value.optimizer_name = name;

common_feature_value.embed_sgd_dim = _embed_sgd_rule->Dim();
common_feature_value.embedx_dim = _config.embedx_dim();
@@ -182,7 +183,8 @@ int32_t CtrDymfAccessor::Create(float** values, size_t num) {
value[common_feature_value.SlotIndex()] = -1;
value[common_feature_value.MfDimIndex()] = -1;
_embed_sgd_rule->InitValue(value + common_feature_value.EmbedWIndex(),
value + common_feature_value.EmbedG2SumIndex());
value + common_feature_value.EmbedG2SumIndex(),
false); // adam: embed init is non-zero; adagrad: embed init is zero
_embedx_sgd_rule->InitValue(value + common_feature_value.EmbedxWIndex(),
value + common_feature_value.EmbedxG2SumIndex(),
false);
@@ -288,18 +290,17 @@ std::string CtrDymfAccessor::ParseToString(const float* v, int param) {
os << v[0] << " " << v[1] << " " << v[2] << " " << v[3] << " " << v[4];
// << v[5] << " " << v[6];
for (int i = common_feature_value.EmbedG2SumIndex();
i < common_feature_value.SlotIndex(); i++) {
i < common_feature_value.EmbedxG2SumIndex(); i++) {
os << " " << v[i];
}
os << " " << common_feature_value.Slot(const_cast<float*>(v)) << " "
<< common_feature_value.MfDim(const_cast<float*>(v));
auto show = common_feature_value.Show(const_cast<float*>(v));
auto click = common_feature_value.Click(const_cast<float*>(v));
auto score = ShowClickScore(show, click);
auto mf_dim = int(common_feature_value.MfDim(const_cast<float*>(v)));
if (score >= _config.embedx_threshold() &&
param > common_feature_value.EmbedxG2SumIndex()) {
for (auto i = common_feature_value.EmbedxG2SumIndex();
i < common_feature_value.Dim(); ++i) {
i < common_feature_value.Dim(mf_dim); ++i) {
os << " " << v[i];
}
}
16 changes: 16 additions & 0 deletions paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h
@@ -57,6 +57,21 @@ class CtrDymfAccessor : public ValueAccessor {
int EmbedxG2SumIndex() { return MfDimIndex() + 1; }
int EmbedxWIndex() { return EmbedxG2SumIndex() + embedx_sgd_dim; }

// Total value length (in floats), computed from mf_dim
int Dim(int& mf_dim) {
int tmp_embedx_sgd_dim = 1;
if (optimizer_name == "SparseAdamSGDRule") {//adam
tmp_embedx_sgd_dim = mf_dim * 2 + 2;
} else if (optimizer_name == "SparseSharedAdamSGDRule") { //shared_adam
tmp_embedx_sgd_dim = 4;
}
return 7 + embed_sgd_dim + tmp_embedx_sgd_dim + mf_dim;
}

// Total size in bytes, computed from mf_dim
int Size(int& mf_dim) { return (Dim(mf_dim)) * sizeof(float); }


float& UnseenDays(float* val) { return val[UnseenDaysIndex()]; }
float& DeltaScore(float* val) { return val[DeltaScoreIndex()]; }
float& Show(float* val) { return val[ShowIndex()]; }
@@ -71,6 +86,7 @@ class CtrDymfAccessor : public ValueAccessor {
int embed_sgd_dim;
int embedx_dim;
int embedx_sgd_dim;
std::string optimizer_name;
};

struct CtrDymfPushValue {
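To make the new Dim()/Size() arithmetic easy to check by hand, here is a standalone restatement of the formula introduced above. The embed_sgd_dim value used in the example (4 for Adam) and the slot breakdown in the comments are assumptions for illustration; in the accessor these values come from the configured SGD rules at Initialize().

#include <cstdio>
#include <string>

// Mirrors CtrDymfFeatureValue::Dim(mf_dim); embed_sgd_dim is a parameter here,
// while the accessor reads it from the embed SGD rule.
int dymf_dim(const std::string &optimizer_name, int embed_sgd_dim, int mf_dim) {
  int tmp_embedx_sgd_dim = 1;  // default (e.g. adagrad): a single g2sum slot
  if (optimizer_name == "SparseAdamSGDRule") {
    tmp_embedx_sgd_dim = mf_dim * 2 + 2;  // adam: per-dim state plus two extra slots
  } else if (optimizer_name == "SparseSharedAdamSGDRule") {
    tmp_embedx_sgd_dim = 4;  // shared_adam: fixed-size shared state
  }
  // 7 fixed slots + embed optimizer state + embedx optimizer state + mf weights
  return 7 + embed_sgd_dim + tmp_embedx_sgd_dim + mf_dim;
}

int main() {
  // Adam with an assumed embed_sgd_dim of 4 and mf_dim of 8:
  // 7 + 4 + (8 * 2 + 2) + 8 = 37 floats, i.e. 148 bytes.
  int dim = dymf_dim("SparseAdamSGDRule", /*embed_sgd_dim=*/4, /*mf_dim=*/8);
  std::printf("dim = %d floats, size = %zu bytes\n", dim, dim * sizeof(float));
  return 0;
}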
101 changes: 56 additions & 45 deletions paddle/fluid/framework/data_feed.cu
@@ -369,7 +369,8 @@ int GraphDataGenerator::FillInsBuf() {
delete[] h_ins_buf;

if (!FLAGS_enable_opt_get_features && slot_num_ > 0) {
uint64_t *feature_buf = reinterpret_cast<uint64_t *>(d_feature_buf_->ptr());
uint64_t *feature_buf =
reinterpret_cast<uint64_t *>(d_feature_buf_->ptr());
uint64_t h_feature_buf[(batch_size_ * 2 * 2) * slot_num_];
cudaMemcpy(h_feature_buf, feature_buf,
(batch_size_ * 2 * 2) * slot_num_ * sizeof(uint64_t),
@@ -383,65 +384,67 @@ int GraphDataGenerator::FillInsBuf() {
}

int GraphDataGenerator::GenerateBatch() {
int total_instance = 0;
platform::CUDADeviceGuard guard(gpuid_);
int res = 0;
if (!gpu_graph_training_) {
while (cursor_ < h_device_keys_.size()) {
size_t device_key_size = h_device_keys_[cursor_]->size();
if (infer_node_type_start_[cursor_] >= device_key_size) {
cursor_++;
continue;
}
int total_instance =
total_instance =
(infer_node_type_start_[cursor_] + batch_size_ <= device_key_size)
? batch_size_
: device_key_size - infer_node_type_start_[cursor_];
uint64_t *d_type_keys =
reinterpret_cast<uint64_t *>(d_device_keys_[cursor_]->ptr());
d_type_keys += infer_node_type_start_[cursor_];
infer_node_type_start_[cursor_] += total_instance;
VLOG(1) << "in graph_data generator:batch_size = " << batch_size_
<< " instance = " << total_instance;
total_instance *= 2;
id_tensor_ptr_ = feed_vec_[0]->mutable_data<int64_t>({total_instance, 1},
this->place_);
show_tensor_ptr_ =
feed_vec_[1]->mutable_data<int64_t>({total_instance}, this->place_);
clk_tensor_ptr_ =
feed_vec_[2]->mutable_data<int64_t>({total_instance}, this->place_);
/*
cudaMemcpyAsync(id_tensor_ptr_, d_type_keys, sizeof(int64_t) * total_instance,
cudaMemcpyDeviceToDevice, stream_);
*/
CopyDuplicateKeys<<<GET_BLOCKS(total_instance / 2), CUDA_NUM_THREADS, 0,
stream_>>>(id_tensor_ptr_, d_type_keys,
total_instance / 2);
GraphFillCVMKernel<<<GET_BLOCKS(total_instance), CUDA_NUM_THREADS, 0,
stream_>>>(show_tensor_ptr_, total_instance);
GraphFillCVMKernel<<<GET_BLOCKS(total_instance), CUDA_NUM_THREADS, 0,
stream_>>>(clk_tensor_ptr_, total_instance);
return total_instance / 2;
break;
}
return 0;
}
platform::CUDADeviceGuard guard(gpuid_);
int res = 0;
while (ins_buf_pair_len_ < batch_size_) {
res = FillInsBuf();
if (res == -1) {
if (ins_buf_pair_len_ == 0) {
return 0;
} else {
break;
if (total_instance == 0) {
return 0;
}
} else {
while (ins_buf_pair_len_ < batch_size_) {
res = FillInsBuf();
if (res == -1) {
if (ins_buf_pair_len_ == 0) {
return 0;
} else {
break;
}
}
}
}
int total_instance =
ins_buf_pair_len_ < batch_size_ ? ins_buf_pair_len_ : batch_size_;
total_instance =
ins_buf_pair_len_ < batch_size_ ? ins_buf_pair_len_ : batch_size_;

total_instance *= 2;
id_tensor_ptr_ =
feed_vec_[0]->mutable_data<int64_t>({total_instance, 1}, this->place_);
show_tensor_ptr_ =
feed_vec_[1]->mutable_data<int64_t>({total_instance}, this->place_);
clk_tensor_ptr_ =
feed_vec_[2]->mutable_data<int64_t>({total_instance}, this->place_);
total_instance *= 2;
id_tensor_ptr_ =
feed_vec_[0]->mutable_data<int64_t>({total_instance, 1}, this->place_);
show_tensor_ptr_ =
feed_vec_[1]->mutable_data<int64_t>({total_instance}, this->place_);
clk_tensor_ptr_ =
feed_vec_[2]->mutable_data<int64_t>({total_instance}, this->place_);
}

int64_t *slot_tensor_ptr_[slot_num_];
int64_t *slot_lod_tensor_ptr_[slot_num_];
@@ -452,7 +455,7 @@ int GraphDataGenerator::GenerateBatch() {
slot_lod_tensor_ptr_[i] = feed_vec_[3 + 2 * i + 1]->mutable_data<int64_t>(
{total_instance + 1}, this->place_);
}
if (FLAGS_enable_opt_get_features) {
if (FLAGS_enable_opt_get_features || !gpu_graph_training_) {
cudaMemcpyAsync(d_slot_tensor_ptr_->ptr(), slot_tensor_ptr_,
sizeof(uint64_t *) * slot_num_, cudaMemcpyHostToDevice,
stream_);
@@ -462,22 +465,31 @@ int GraphDataGenerator::GenerateBatch() {
}
}

VLOG(2) << "total_instance: " << total_instance
<< ", ins_buf_pair_len = " << ins_buf_pair_len_;
uint64_t *ins_buf = reinterpret_cast<uint64_t *>(d_ins_buf_->ptr());
uint64_t *ins_cursor = ins_buf + ins_buf_pair_len_ * 2 - total_instance;
cudaMemcpyAsync(id_tensor_ptr_, ins_cursor, sizeof(uint64_t) * total_instance,
cudaMemcpyDeviceToDevice, stream_);
uint64_t *ins_cursor, *ins_buf;
if (gpu_graph_training_) {
VLOG(2) << "total_instance: " << total_instance
<< ", ins_buf_pair_len = " << ins_buf_pair_len_;
// uint64_t *ins_buf = reinterpret_cast<uint64_t *>(d_ins_buf_->ptr());
// uint64_t *ins_cursor = ins_buf + ins_buf_pair_len_ * 2 - total_instance;
ins_buf = reinterpret_cast<uint64_t *>(d_ins_buf_->ptr());
ins_cursor = ins_buf + ins_buf_pair_len_ * 2 - total_instance;
cudaMemcpyAsync(id_tensor_ptr_, ins_cursor,
sizeof(uint64_t) * total_instance, cudaMemcpyDeviceToDevice,
stream_);

GraphFillCVMKernel<<<GET_BLOCKS(total_instance), CUDA_NUM_THREADS, 0,
stream_>>>(show_tensor_ptr_, total_instance);
GraphFillCVMKernel<<<GET_BLOCKS(total_instance), CUDA_NUM_THREADS, 0,
stream_>>>(clk_tensor_ptr_, total_instance);
GraphFillCVMKernel<<<GET_BLOCKS(total_instance), CUDA_NUM_THREADS, 0,
stream_>>>(show_tensor_ptr_, total_instance);
GraphFillCVMKernel<<<GET_BLOCKS(total_instance), CUDA_NUM_THREADS, 0,
stream_>>>(clk_tensor_ptr_, total_instance);
} else {
ins_cursor = (uint64_t *)id_tensor_ptr_;
}

if (slot_num_ > 0) {
uint64_t *feature_buf = reinterpret_cast<uint64_t *>(d_feature_buf_->ptr());
if (FLAGS_enable_opt_get_features) {
if (FLAGS_enable_opt_get_features || !gpu_graph_training_) {
FillFeatureBuf(ins_cursor, feature_buf, total_instance);
// FillFeatureBuf(id_tensor_ptr_, feature_buf, total_instance);
if (debug_mode_) {
uint64_t h_walk[total_instance];
cudaMemcpy(h_walk, ins_cursor, total_instance * sizeof(uint64_t),
@@ -538,10 +550,9 @@ int GraphDataGenerator::GenerateBatch() {
}
}

ins_buf_pair_len_ -= total_instance / 2;

cudaStreamSynchronize(stream_);

if (!gpu_graph_training_) return total_instance / 2;
ins_buf_pair_len_ -= total_instance / 2;
if (debug_mode_) {
uint64_t h_slot_tensor[slot_num_][total_instance];
uint64_t h_slot_lod_tensor[slot_num_][total_instance + 1];
@@ -704,8 +715,8 @@ int GraphDataGenerator::FillFeatureBuf(

auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
int ret = gpu_graph_ptr->get_feature_of_nodes(
gpuid_, (uint64_t *)d_walk->ptr(), (uint64_t *)d_feature->ptr(), buf_size_,
slot_num_);
gpuid_, (uint64_t *)d_walk->ptr(), (uint64_t *)d_feature->ptr(),
buf_size_, slot_num_);
return ret;
}

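The kernel launches in this file (CopyDuplicateKeys, GraphFillCVMKernel) all follow the usual one-thread-per-element convention via GET_BLOCKS and CUDA_NUM_THREADS. The snippet below is a generic illustration of that launch pattern, not Paddle's actual helpers or kernels; the block size of 512 and the fill kernel are assumptions for the example.

#include <cstdint>

constexpr int kNumThreads = 512;  // assumed block size
inline int NumBlocks(int n) { return (n + kNumThreads - 1) / kNumThreads; }

// Illustrative fill: write `value` into each of the first `len` entries,
// one thread per element (same spirit as the show/clk fills above).
__global__ void FillConstKernel(int64_t *ptr, int len, int64_t value) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) ptr[i] = value;
}

// Launch on a stream, e.g.:
//   FillConstKernel<<<NumBlocks(n), kNumThreads, 0, stream>>>(d_ptr, n, 1);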
11 changes: 6 additions & 5 deletions paddle/fluid/framework/data_set.cc
@@ -460,8 +460,8 @@ void DatasetImpl<T>::LoadIntoMemory() {
int cnt = 0;
for (auto& iter : node_to_id) {
int node_idx = iter.second;
auto gpu_graph_device_keys =
gpu_graph_ptr->get_all_id(1, node_idx, thread_num_);
std::vector<std::vector<uint64_t>> gpu_graph_device_keys;
gpu_graph_ptr->get_all_id(1, node_idx, thread_num_, &gpu_graph_device_keys);
auto& type_total_key = graph_all_type_total_keys_[cnt];
type_total_key.resize(thread_num_);
for (size_t i = 0; i < gpu_graph_device_keys.size(); i++) {
@@ -500,8 +500,8 @@ void DatasetImpl<T>::LoadIntoMemory() {
// FIX: trick for iterate edge table
for (auto& iter : edge_to_id) {
int edge_idx = iter.second;
auto gpu_graph_device_keys =
gpu_graph_ptr->get_all_id(0, edge_idx, thread_num_);
std::vector<std::vector<uint64_t>> gpu_graph_device_keys;
gpu_graph_ptr->get_all_id(0, edge_idx, thread_num_, &gpu_graph_device_keys);
for (size_t i = 0; i < gpu_graph_device_keys.size(); i++) {
VLOG(1) << "edge type: " << edge_idx << ", gpu_graph_device_keys[" << i
<< "] = " << gpu_graph_device_keys[i].size();
@@ -510,7 +510,8 @@ void DatasetImpl<T>::LoadIntoMemory() {
}
}
if (FLAGS_graph_get_neighbor_id) {
auto gpu_graph_neighbor_keys = gpu_graph_ptr->get_all_neighbor_id(0, edge_idx, thread_num_);
std::vector<std::vector<uint64_t>> gpu_graph_neighbor_keys;
gpu_graph_ptr->get_all_neighbor_id(0, edge_idx, thread_num_, &gpu_graph_neighbor_keys);
for (size_t i = 0; i < gpu_graph_neighbor_keys.size(); i++) {
for (size_t k = 0; k < gpu_graph_neighbor_keys[i].size(); k++) {
gpu_graph_total_keys_.push_back(gpu_graph_neighbor_keys[i][k]);