From 851d336d3b124f877be4b60502d6e9182c0952d8 Mon Sep 17 00:00:00 2001 From: seemingwang Date: Mon, 27 Jun 2022 22:32:45 +0800 Subject: [PATCH 1/3] fix slot_feature in infer (#47) * fix slot_feature in infer --- paddle/fluid/framework/data_feed.cu | 101 +++++++++++++----------- paddle/fluid/framework/device_worker.cc | 10 ++- 2 files changed, 64 insertions(+), 47 deletions(-) diff --git a/paddle/fluid/framework/data_feed.cu b/paddle/fluid/framework/data_feed.cu index 6f9f4464a45de..5647ad107bd53 100644 --- a/paddle/fluid/framework/data_feed.cu +++ b/paddle/fluid/framework/data_feed.cu @@ -369,7 +369,8 @@ int GraphDataGenerator::FillInsBuf() { delete[] h_ins_buf; if (!FLAGS_enable_opt_get_features && slot_num_ > 0) { - uint64_t *feature_buf = reinterpret_cast(d_feature_buf_->ptr()); + uint64_t *feature_buf = + reinterpret_cast(d_feature_buf_->ptr()); uint64_t h_feature_buf[(batch_size_ * 2 * 2) * slot_num_]; cudaMemcpy(h_feature_buf, feature_buf, (batch_size_ * 2 * 2) * slot_num_ * sizeof(uint64_t), @@ -383,6 +384,9 @@ int GraphDataGenerator::FillInsBuf() { } int GraphDataGenerator::GenerateBatch() { + int total_instance = 0; + platform::CUDADeviceGuard guard(gpuid_); + int res = 0; if (!gpu_graph_training_) { while (cursor_ < h_device_keys_.size()) { size_t device_key_size = h_device_keys_[cursor_]->size(); @@ -390,7 +394,7 @@ int GraphDataGenerator::GenerateBatch() { cursor_++; continue; } - int total_instance = + total_instance = (infer_node_type_start_[cursor_] + batch_size_ <= device_key_size) ? batch_size_ : device_key_size - infer_node_type_start_[cursor_]; @@ -398,6 +402,8 @@ int GraphDataGenerator::GenerateBatch() { reinterpret_cast(d_device_keys_[cursor_]->ptr()); d_type_keys += infer_node_type_start_[cursor_]; infer_node_type_start_[cursor_] += total_instance; + VLOG(1) << "in graph_data generator:batch_size = " << batch_size_ + << " instance = " << total_instance; total_instance *= 2; id_tensor_ptr_ = feed_vec_[0]->mutable_data({total_instance, 1}, this->place_); @@ -405,10 +411,6 @@ int GraphDataGenerator::GenerateBatch() { feed_vec_[1]->mutable_data({total_instance}, this->place_); clk_tensor_ptr_ = feed_vec_[2]->mutable_data({total_instance}, this->place_); - /* - cudaMemcpyAsync(id_tensor_ptr_, d_type_keys, sizeof(int64_t) * total_instance, - cudaMemcpyDeviceToDevice, stream_); - */ CopyDuplicateKeys<<>>(id_tensor_ptr_, d_type_keys, total_instance / 2); @@ -416,32 +418,33 @@ int GraphDataGenerator::GenerateBatch() { stream_>>>(show_tensor_ptr_, total_instance); GraphFillCVMKernel<<>>(clk_tensor_ptr_, total_instance); - return total_instance / 2; + break; } - return 0; - } - platform::CUDADeviceGuard guard(gpuid_); - int res = 0; - while (ins_buf_pair_len_ < batch_size_) { - res = FillInsBuf(); - if (res == -1) { - if (ins_buf_pair_len_ == 0) { - return 0; - } else { - break; + if (total_instance == 0) { + return 0; + } + } else { + while (ins_buf_pair_len_ < batch_size_) { + res = FillInsBuf(); + if (res == -1) { + if (ins_buf_pair_len_ == 0) { + return 0; + } else { + break; + } } } - } - int total_instance = - ins_buf_pair_len_ < batch_size_ ? ins_buf_pair_len_ : batch_size_; + total_instance = + ins_buf_pair_len_ < batch_size_ ? 
ins_buf_pair_len_ : batch_size_; - total_instance *= 2; - id_tensor_ptr_ = - feed_vec_[0]->mutable_data({total_instance, 1}, this->place_); - show_tensor_ptr_ = - feed_vec_[1]->mutable_data({total_instance}, this->place_); - clk_tensor_ptr_ = - feed_vec_[2]->mutable_data({total_instance}, this->place_); + total_instance *= 2; + id_tensor_ptr_ = + feed_vec_[0]->mutable_data({total_instance, 1}, this->place_); + show_tensor_ptr_ = + feed_vec_[1]->mutable_data({total_instance}, this->place_); + clk_tensor_ptr_ = + feed_vec_[2]->mutable_data({total_instance}, this->place_); + } int64_t *slot_tensor_ptr_[slot_num_]; int64_t *slot_lod_tensor_ptr_[slot_num_]; @@ -452,7 +455,7 @@ int GraphDataGenerator::GenerateBatch() { slot_lod_tensor_ptr_[i] = feed_vec_[3 + 2 * i + 1]->mutable_data( {total_instance + 1}, this->place_); } - if (FLAGS_enable_opt_get_features) { + if (FLAGS_enable_opt_get_features || !gpu_graph_training_) { cudaMemcpyAsync(d_slot_tensor_ptr_->ptr(), slot_tensor_ptr_, sizeof(uint64_t *) * slot_num_, cudaMemcpyHostToDevice, stream_); @@ -462,22 +465,31 @@ int GraphDataGenerator::GenerateBatch() { } } - VLOG(2) << "total_instance: " << total_instance - << ", ins_buf_pair_len = " << ins_buf_pair_len_; - uint64_t *ins_buf = reinterpret_cast(d_ins_buf_->ptr()); - uint64_t *ins_cursor = ins_buf + ins_buf_pair_len_ * 2 - total_instance; - cudaMemcpyAsync(id_tensor_ptr_, ins_cursor, sizeof(uint64_t) * total_instance, - cudaMemcpyDeviceToDevice, stream_); + uint64_t *ins_cursor, *ins_buf; + if (gpu_graph_training_) { + VLOG(2) << "total_instance: " << total_instance + << ", ins_buf_pair_len = " << ins_buf_pair_len_; + // uint64_t *ins_buf = reinterpret_cast(d_ins_buf_->ptr()); + // uint64_t *ins_cursor = ins_buf + ins_buf_pair_len_ * 2 - total_instance; + ins_buf = reinterpret_cast(d_ins_buf_->ptr()); + ins_cursor = ins_buf + ins_buf_pair_len_ * 2 - total_instance; + cudaMemcpyAsync(id_tensor_ptr_, ins_cursor, + sizeof(uint64_t) * total_instance, cudaMemcpyDeviceToDevice, + stream_); - GraphFillCVMKernel<<>>(show_tensor_ptr_, total_instance); - GraphFillCVMKernel<<>>(clk_tensor_ptr_, total_instance); + GraphFillCVMKernel<<>>(show_tensor_ptr_, total_instance); + GraphFillCVMKernel<<>>(clk_tensor_ptr_, total_instance); + } else { + ins_cursor = (uint64_t *)id_tensor_ptr_; + } if (slot_num_ > 0) { uint64_t *feature_buf = reinterpret_cast(d_feature_buf_->ptr()); - if (FLAGS_enable_opt_get_features) { + if (FLAGS_enable_opt_get_features || !gpu_graph_training_) { FillFeatureBuf(ins_cursor, feature_buf, total_instance); + // FillFeatureBuf(id_tensor_ptr_, feature_buf, total_instance); if (debug_mode_) { uint64_t h_walk[total_instance]; cudaMemcpy(h_walk, ins_cursor, total_instance * sizeof(uint64_t), @@ -538,10 +550,9 @@ int GraphDataGenerator::GenerateBatch() { } } - ins_buf_pair_len_ -= total_instance / 2; - cudaStreamSynchronize(stream_); - + if (!gpu_graph_training_) return total_instance / 2; + ins_buf_pair_len_ -= total_instance / 2; if (debug_mode_) { uint64_t h_slot_tensor[slot_num_][total_instance]; uint64_t h_slot_lod_tensor[slot_num_][total_instance + 1]; @@ -704,8 +715,8 @@ int GraphDataGenerator::FillFeatureBuf( auto gpu_graph_ptr = GraphGpuWrapper::GetInstance(); int ret = gpu_graph_ptr->get_feature_of_nodes( - gpuid_, (uint64_t *)d_walk->ptr(), (uint64_t *)d_feature->ptr(), buf_size_, - slot_num_); + gpuid_, (uint64_t *)d_walk->ptr(), (uint64_t *)d_feature->ptr(), + buf_size_, slot_num_); return ret; } diff --git a/paddle/fluid/framework/device_worker.cc 
b/paddle/fluid/framework/device_worker.cc index fecf16b3baabe..e13004eaf1500 100644 --- a/paddle/fluid/framework/device_worker.cc +++ b/paddle/fluid/framework/device_worker.cc @@ -273,8 +273,12 @@ void DeviceWorker::DumpField(const Scope& scope, int dump_mode, } auto set_output_str = [&, this](size_t begin, size_t end, LoDTensor* tensor) { + std::pair bound; + auto& dims = tensor->dims(); for (size_t i = begin; i < end; ++i) { - auto bound = GetTensorBound(tensor, i); + bound = {i * dims[1], (i + 1) * dims[1]}; + // auto bound = GetTensorBound(tensor, i); + if (ars[i].size() > 0) ars[i] += "\t"; // ars[i] += '['; PrintLodTensor(tensor, bound.first, bound.second, ars[i], ' ', false); @@ -303,10 +307,12 @@ void DeviceWorker::DumpField(const Scope& scope, int dump_mode, cpu_tensor.set_lod(tensor->lod()); tensor = &cpu_tensor; } - if (!CheckValidOutput(tensor, batch_size)) { + auto& dims = tensor->dims(); + if (dims.size() != 2 || dims[0] != static_cast(batch_size)) { VLOG(0) << "Note: field[" << field << "] cannot pass check, so it was " "skipped. Maybe the dimension is " "wrong "; + VLOG(0) << dims.size() << " " << dims[0] << " * " << dims[1]; continue; } size_t acutal_thread_num = From 87d49ec5887a2f3ea85b6a3530c5e900fce4c179 Mon Sep 17 00:00:00 2001 From: miaoli06 <106585574+miaoli06@users.noreply.github.com> Date: Tue, 28 Jun 2022 12:09:38 +0800 Subject: [PATCH 2/3] opt cpu to gpu load (#50) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * opt cpu to gpu load:1. remove vector copy;2. parallel * remove useless code --- .../ps/table/common_graph_table.cc | 36 ++++---- .../distributed/ps/table/common_graph_table.h | 12 +-- paddle/fluid/framework/data_set.cc | 11 +-- .../fleet/heter_ps/graph_gpu_wrapper.cu | 84 +++++++++++-------- .../fleet/heter_ps/graph_gpu_wrapper.h | 17 ++-- paddle/fluid/pybind/fleet_py.cc | 13 +-- 6 files changed, 95 insertions(+), 78 deletions(-) diff --git a/paddle/fluid/distributed/ps/table/common_graph_table.cc b/paddle/fluid/distributed/ps/table/common_graph_table.cc index 8527c5031e70d..dd26f7ec41d92 100644 --- a/paddle/fluid/distributed/ps/table/common_graph_table.cc +++ b/paddle/fluid/distributed/ps/table/common_graph_table.cc @@ -1763,8 +1763,8 @@ int GraphTable::parse_feature(int idx, const std::string& feat_str, return -1; } -std::vector> GraphTable::get_all_id(int type_id, int slice_num) { - std::vector> res(slice_num); +int GraphTable::get_all_id(int type_id, int slice_num, std::vector> *output) { + output->resize(slice_num); auto &search_shards = type_id == 0 ? edge_shards : feature_shards; std::vector>> tasks; for (int idx = 0; idx < search_shards.size(); idx++) { @@ -1781,14 +1781,14 @@ std::vector> GraphTable::get_all_id(int type_id, int slice for (size_t i = 0; i < tasks.size(); i++) { auto ids = tasks[i].get(); for (auto &id : ids) { - res[(uint64_t)(id) % slice_num].push_back(id); + (*output)[(uint64_t)(id) % slice_num].push_back(id); } } - return res; + return 0; } -std::vector> GraphTable::get_all_neighbor_id(int type_id, int slice_num) { - std::vector> res(slice_num); +int GraphTable::get_all_neighbor_id(int type_id, int slice_num, std::vector> *output) { + output->resize(slice_num); auto &search_shards = type_id == 0 ? 
edge_shards : feature_shards; std::vector>> tasks; for (int idx = 0; idx < search_shards.size(); idx++) { @@ -1805,15 +1805,15 @@ std::vector> GraphTable::get_all_neighbor_id(int type_id, for (size_t i = 0; i < tasks.size(); i++) { auto ids = tasks[i].get(); for (auto &id : ids) { - res[(uint64_t)(id) % slice_num].push_back(id); + (*output)[(uint64_t)(id) % slice_num].push_back(id); } } - return res; + return 0; } -std::vector> GraphTable::get_all_id(int type_id, int idx, - int slice_num) { - std::vector> res(slice_num); +int GraphTable::get_all_id(int type_id, int idx, + int slice_num, std::vector> *output) { + output->resize(slice_num); auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx]; std::vector>> tasks; VLOG(0) << "begin task, task_pool_size_[" << task_pool_size_ << "]"; @@ -1829,14 +1829,14 @@ std::vector> GraphTable::get_all_id(int type_id, int idx, VLOG(0) << "end task, task_pool_size_[" << task_pool_size_ << "]"; for (size_t i = 0; i < tasks.size(); i++) { auto ids = tasks[i].get(); - for (auto &id : ids) res[id % slice_num].push_back(id); + for (auto &id : ids) (*output)[id % slice_num].push_back(id); } - return res; + return 0; } -std::vector> GraphTable::get_all_neighbor_id(int type_id, int idx, - int slice_num) { - std::vector> res(slice_num); +int GraphTable::get_all_neighbor_id(int type_id, int idx, + int slice_num, std::vector> *output) { + output->resize(slice_num); auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx]; std::vector>> tasks; VLOG(0) << "begin task, task_pool_size_[" << task_pool_size_ << "]"; @@ -1852,9 +1852,9 @@ std::vector> GraphTable::get_all_neighbor_id(int type_id, VLOG(0) << "end task, task_pool_size_[" << task_pool_size_ << "]"; for (size_t i = 0; i < tasks.size(); i++) { auto ids = tasks[i].get(); - for (auto &id : ids) res[id % slice_num].push_back(id); + for (auto &id : ids) (*output)[id % slice_num].push_back(id); } - return res; + return 0; } int GraphTable::get_all_feature_ids(int type_id, int idx, int slice_num, diff --git a/paddle/fluid/distributed/ps/table/common_graph_table.h b/paddle/fluid/distributed/ps/table/common_graph_table.h index a6bfbec34b755..06ea0b4e4b154 100644 --- a/paddle/fluid/distributed/ps/table/common_graph_table.h +++ b/paddle/fluid/distributed/ps/table/common_graph_table.h @@ -503,12 +503,12 @@ class GraphTable : public Table { int32_t load_edges(const std::string &path, bool reverse, const std::string &edge_type); - std::vector> get_all_id(int type, int slice_num); - std::vector> get_all_neighbor_id(int type, int slice_num); - std::vector> get_all_id(int type, int idx, - int slice_num); - std::vector> get_all_neighbor_id(int type_id, int idx, - int slice_num); + int get_all_id(int type, int slice_num, std::vector> *output); + int get_all_neighbor_id(int type, int slice_num, std::vector> *output); + int get_all_id(int type, int idx, + int slice_num, std::vector> *output); + int get_all_neighbor_id(int type_id, int id, + int slice_num, std::vector> *output); int get_all_feature_ids(int type, int idx, int slice_num, std::vector>* output); int32_t load_nodes(const std::string &path, std::string node_type = std::string()); diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index b9adb4d861b01..335a9c990f023 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -460,8 +460,8 @@ void DatasetImpl::LoadIntoMemory() { int cnt = 0; for (auto& iter : node_to_id) { int node_idx = iter.second; - auto 
gpu_graph_device_keys = - gpu_graph_ptr->get_all_id(1, node_idx, thread_num_); + std::vector> gpu_graph_device_keys; + gpu_graph_ptr->get_all_id(1, node_idx, thread_num_, &gpu_graph_device_keys); auto& type_total_key = graph_all_type_total_keys_[cnt]; type_total_key.resize(thread_num_); for (size_t i = 0; i < gpu_graph_device_keys.size(); i++) { @@ -500,8 +500,8 @@ void DatasetImpl::LoadIntoMemory() { // FIX: trick for iterate edge table for (auto& iter : edge_to_id) { int edge_idx = iter.second; - auto gpu_graph_device_keys = - gpu_graph_ptr->get_all_id(0, edge_idx, thread_num_); + std::vector> gpu_graph_device_keys; + gpu_graph_ptr->get_all_id(0, edge_idx, thread_num_, &gpu_graph_device_keys); for (size_t i = 0; i < gpu_graph_device_keys.size(); i++) { VLOG(1) << "edge type: " << edge_idx << ", gpu_graph_device_keys[" << i << "] = " << gpu_graph_device_keys[i].size(); @@ -510,7 +510,8 @@ void DatasetImpl::LoadIntoMemory() { } } if (FLAGS_graph_get_neighbor_id) { - auto gpu_graph_neighbor_keys = gpu_graph_ptr->get_all_neighbor_id(0, edge_idx, thread_num_); + std::vector> gpu_graph_neighbor_keys; + gpu_graph_ptr->get_all_neighbor_id(0, edge_idx, thread_num_, &gpu_graph_neighbor_keys); for (size_t i = 0; i < gpu_graph_neighbor_keys.size(); i++) { for (size_t k = 0; k < gpu_graph_neighbor_keys[i].size(); k++) { gpu_graph_total_keys_.push_back(gpu_graph_neighbor_keys[i][k]); diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu index 17b6905b8a5e4..68aae6d3c4a9b 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu @@ -28,31 +28,29 @@ void GraphGpuWrapper::set_device(std::vector ids) { } } -std::vector> GraphGpuWrapper::get_all_id(int type, - int slice_num) { +int GraphGpuWrapper::get_all_id(int type, int slice_num, + std::vector>* output) { return ((GpuPsGraphTable *)graph_table) - ->cpu_graph_table_->get_all_id(type, slice_num); + ->cpu_graph_table_->get_all_id(type, slice_num, output); } -std::vector> GraphGpuWrapper::get_all_neighbor_id(int type, - int slice_num) { +int GraphGpuWrapper::get_all_neighbor_id(int type, int slice_num, + std::vector>* output) { return ((GpuPsGraphTable *)graph_table) - ->cpu_graph_table_->get_all_neighbor_id(type, slice_num); + ->cpu_graph_table_->get_all_neighbor_id(type, slice_num, output); } -std::vector> GraphGpuWrapper::get_all_id(int type, - int idx, - int slice_num) { +int GraphGpuWrapper::get_all_id(int type, int idx, + int slice_num, std::vector>* output) { return ((GpuPsGraphTable *)graph_table) - ->cpu_graph_table_->get_all_id(type, idx, slice_num); + ->cpu_graph_table_->get_all_id(type, idx, slice_num, output); } -std::vector> GraphGpuWrapper::get_all_neighbor_id(int type, - int idx, - int slice_num) { +int GraphGpuWrapper::get_all_neighbor_id(int type, int idx, + int slice_num, std::vector>* output) { return ((GpuPsGraphTable *)graph_table) - ->cpu_graph_table_->get_all_neighbor_id(type, idx, slice_num); + ->cpu_graph_table_->get_all_neighbor_id(type, idx, slice_num, output); } int GraphGpuWrapper::get_all_feature_ids(int type, int idx, int slice_num, @@ -203,46 +201,64 @@ void GraphGpuWrapper::init_service() { g->init_cpu_table(table_proto); g->cpu_graph_table_->set_feature_separator(feature_separator_); graph_table = (char *)g; + upload_task_pool.reset(new ::ThreadPool(upload_num)); } void GraphGpuWrapper::finalize() { ((GpuPsGraphTable *)graph_table)->show_table_collisions(); } -void 
GraphGpuWrapper::upload_batch(int idx, - std::vector> &ids) { +void GraphGpuWrapper::upload_batch(int type, int idx, int slice_num, const std::string &edge_type) { + VLOG(0) << "begin upload edge, type[" << edge_type << "]"; + std::vector> ids; + ((GpuPsGraphTable *)graph_table) + ->cpu_graph_table_->get_all_id(type, idx, slice_num, &ids); debug_gpu_memory_info("upload_batch node start"); GpuPsGraphTable *g = (GpuPsGraphTable *)graph_table; + std::vector> tasks; + for (int i = 0; i < ids.size(); i++) { - GpuPsCommGraph sub_graph = - g->cpu_graph_table_->make_gpu_ps_graph(idx, ids[i]); - g->build_graph_on_single_gpu(sub_graph, i, idx); - sub_graph.release_on_cpu(); - VLOG(0) << "sub graph on gpu " << i << " is built"; + tasks.push_back(upload_task_pool->enqueue( + [&, i, idx, this]() -> int { + VLOG(0) << "begin make_gpu_ps_graph, node_id[" << i << "]_size[" + << ids[i].size() << "]"; + GpuPsCommGraph sub_graph = + g->cpu_graph_table_->make_gpu_ps_graph(idx, ids[i]); + g->build_graph_on_single_gpu(sub_graph, i, idx); + sub_graph.release_on_cpu(); + VLOG(0) << "sub graph on gpu " << i << " is built"; + return 0; + })); } + for (size_t i = 0; i < tasks.size(); i++) tasks[i].get(); debug_gpu_memory_info("upload_batch node end"); } // feature table -void GraphGpuWrapper::upload_batch(std::vector> &node_ids, - int slot_num) { +void GraphGpuWrapper::upload_batch(int type, int slice_num, int slot_num) { + std::vector> node_ids; + ((GpuPsGraphTable *)graph_table) + ->cpu_graph_table_->get_all_id(type, slice_num, &node_ids); debug_gpu_memory_info("upload_batch feature start"); GpuPsGraphTable *g = (GpuPsGraphTable *)graph_table; + std::vector> tasks; for (int i = 0; i < node_ids.size(); i++) { - VLOG(0) << "begin make_gpu_ps_graph_fea, node_ids[" << i << "]_size[" + tasks.push_back(upload_task_pool->enqueue( + [&, i, this]() -> int { + VLOG(0) << "begin make_gpu_ps_graph_fea, node_ids[" << i << "]_size[" << node_ids[i].size() << "]"; - GpuPsCommGraphFea sub_graph = g->cpu_graph_table_->make_gpu_ps_graph_fea( - node_ids[i], slot_num); - - // sub_graph.display_on_cpu(); - VLOG(0) << "begin build_graph_fea_on_single_gpu, node_ids[" << i + GpuPsCommGraphFea sub_graph = g->cpu_graph_table_->make_gpu_ps_graph_fea( + node_ids[i], slot_num); + // sub_graph.display_on_cpu(); + VLOG(0) << "begin build_graph_fea_on_single_gpu, node_ids[" << i << "]_size[" << node_ids[i].size() << "]"; - g->build_graph_fea_on_single_gpu(sub_graph, i); - - sub_graph.release_on_cpu(); - - VLOG(0) << "sub graph fea on gpu " << i << " is built"; + g->build_graph_fea_on_single_gpu(sub_graph, i); + sub_graph.release_on_cpu(); + VLOG(0) << "sub graph fea on gpu " << i << " is built"; + return 0; + })); } + for (size_t i = 0; i < tasks.size(); i++) tasks[i].get(); // g->build_graph_from_cpu(vec); debug_gpu_memory_info("upload_batch feature end"); } diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h index d5527e235c0e7..0cf6f2649184a 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h @@ -36,9 +36,8 @@ class GraphGpuWrapper { void init_service(); void set_up_types(std::vector& edge_type, std::vector& node_type); - void upload_batch(int etype_id, std::vector>& ids); - void upload_batch(std::vector>& ids, - int slot_num); + void upload_batch(int type, int idx, int slice_num, const std::string &edge_type); + void upload_batch(int type, int slice_num, int slot_num); void 
add_table_feat_conf(std::string table_name, std::string feat_name, std::string feat_dtype, int feat_shape); void load_edge_file(std::string name, std::string filepath, bool reverse); @@ -54,12 +53,10 @@ class GraphGpuWrapper { void make_complementary_graph(int idx, int64_t byte_size); void set_search_level(int level); void init_search_level(int level); - std::vector> get_all_id(int type, int slice_num); - std::vector> get_all_neighbor_id(int type, int slice_num); - std::vector> get_all_id(int type, int idx, - int slice_num); - std::vector> get_all_neighbor_id(int type, int idx, - int slice_num); + int get_all_id(int type, int slice_num, std::vector>* output); + int get_all_neighbor_id(int type, int slice_num, std::vector>* output); + int get_all_id(int type, int idx, int slice_num, std::vector>* output); + int get_all_neighbor_id(int type, int idx, int slice_num, std::vector>* output); int get_all_feature_ids(int type, int idx, int slice_num, std::vector>* output); NodeQueryResult query_node_list(int gpu_id, int idx, int start, @@ -85,6 +82,8 @@ class GraphGpuWrapper { std::vector device_id_mapping; int search_level = 1; void* graph_table; + int upload_num = 8; + std::shared_ptr<::ThreadPool> upload_task_pool; std::string feature_separator_ = std::string(" "); }; #endif diff --git a/paddle/fluid/pybind/fleet_py.cc b/paddle/fluid/pybind/fleet_py.cc index 27068183c9c02..7bd4bc98645db 100644 --- a/paddle/fluid/pybind/fleet_py.cc +++ b/paddle/fluid/pybind/fleet_py.cc @@ -351,15 +351,16 @@ void BindGraphGpuWrapper(py::module* m) { .def("load_edge_file", &GraphGpuWrapper::load_edge_file) .def("load_node_and_edge", &GraphGpuWrapper::load_node_and_edge) .def("upload_batch", - py::overload_cast>&>( + py::overload_cast( &GraphGpuWrapper::upload_batch)) .def("upload_batch", - py::overload_cast>&, int>( + py::overload_cast( &GraphGpuWrapper::upload_batch)) - .def("get_all_id", py::overload_cast(&GraphGpuWrapper::get_all_id)) - .def("get_all_id", py::overload_cast(&GraphGpuWrapper::get_all_id)) - .def("get_all_neighbor_id", py::overload_cast(&GraphGpuWrapper::get_all_neighbor_id)) - .def("get_all_neighbor_id", py::overload_cast(&GraphGpuWrapper::get_all_neighbor_id)) + .def("get_all_id", + py::overload_cast>*>( + &GraphGpuWrapper::get_all_id)) + .def("get_all_id", py::overload_cast>*>( + &GraphGpuWrapper::get_all_id)) .def("load_next_partition", &GraphGpuWrapper::load_next_partition) .def("make_partitions", &GraphGpuWrapper::make_partitions) .def("make_complementary_graph", From 726344277a753ef3cf1b394a5fbde5aaaa02c176 Mon Sep 17 00:00:00 2001 From: danleifeng <52735331+danleifeng@users.noreply.github.com> Date: Tue, 28 Jun 2022 14:28:56 +0800 Subject: [PATCH 3/3] fix adam with multi dim (#39) * fix adam with multi dim; test=develop --- .../distributed/ps/table/ctr_dymf_accessor.cc | 11 ++- .../distributed/ps/table/ctr_dymf_accessor.h | 16 ++++ .../framework/fleet/heter_ps/feature_value.h | 94 +++++++++++-------- .../fleet/heter_ps/hashtable_kernel.cu | 39 ++++---- .../framework/fleet/heter_ps/heter_comm_inl.h | 21 ++--- .../fleet/heter_ps/heter_comm_kernel.cu | 33 +++---- .../framework/fleet/heter_ps/optimizer.cuh.h | 19 ++-- .../fluid/framework/fleet/ps_gpu_wrapper.cc | 60 ++++++------ .../fluid/framework/fleet/ps_gpu_wrapper.cu | 3 +- paddle/fluid/framework/fleet/ps_gpu_wrapper.h | 12 +-- .../distributed/passes/ps_trainer_pass.py | 85 ++++++++++++----- python/paddle/distributed/ps/the_one_ps.py | 7 +- 12 files changed, 219 insertions(+), 181 deletions(-) diff --git 
a/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc b/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc index adeba7285643f..4f901c6f71fb3 100644 --- a/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc +++ b/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc @@ -29,6 +29,7 @@ int CtrDymfAccessor::Initialize() { _embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name); _embedx_sgd_rule->LoadConfig(_config.embedx_sgd_param(), _config.embedx_dim()); + common_feature_value.optimizer_name = name; common_feature_value.embed_sgd_dim = _embed_sgd_rule->Dim(); common_feature_value.embedx_dim = _config.embedx_dim(); @@ -182,7 +183,8 @@ int32_t CtrDymfAccessor::Create(float** values, size_t num) { value[common_feature_value.SlotIndex()] = -1; value[common_feature_value.MfDimIndex()] = -1; _embed_sgd_rule->InitValue(value + common_feature_value.EmbedWIndex(), - value + common_feature_value.EmbedG2SumIndex()); + value + common_feature_value.EmbedG2SumIndex(), + false); // adam embed init not zero, adagrad embed init zero _embedx_sgd_rule->InitValue(value + common_feature_value.EmbedxWIndex(), value + common_feature_value.EmbedxG2SumIndex(), false); @@ -288,18 +290,17 @@ std::string CtrDymfAccessor::ParseToString(const float* v, int param) { os << v[0] << " " << v[1] << " " << v[2] << " " << v[3] << " " << v[4]; // << v[5] << " " << v[6]; for (int i = common_feature_value.EmbedG2SumIndex(); - i < common_feature_value.SlotIndex(); i++) { + i < common_feature_value.EmbedxG2SumIndex(); i++) { os << " " << v[i]; } - os << " " << common_feature_value.Slot(const_cast<float*>(v)) << " " - << common_feature_value.MfDim(const_cast<float*>(v)); auto show = common_feature_value.Show(const_cast<float*>(v)); auto click = common_feature_value.Click(const_cast<float*>(v)); auto score = ShowClickScore(show, click); + auto mf_dim = int(common_feature_value.MfDim(const_cast<float*>(v))); if (score >= _config.embedx_threshold() && param > common_feature_value.EmbedxG2SumIndex()) { for (auto i = common_feature_value.EmbedxG2SumIndex(); - i < common_feature_value.Dim(); ++i) { + i < common_feature_value.Dim(mf_dim); ++i) { os << " " << v[i]; } } diff --git a/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h b/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h index 209ba76c67cdd..15efdd2b0bc37 100644 --- a/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h +++ b/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h @@ -57,6 +57,21 @@ class CtrDymfAccessor : public ValueAccessor { int EmbedxG2SumIndex() { return MfDimIndex() + 1; } int EmbedxWIndex() { return EmbedxG2SumIndex() + embedx_sgd_dim; } + // total value length computed from mf_dim + int Dim(int& mf_dim) { + int tmp_embedx_sgd_dim = 1; + if (optimizer_name == "SparseAdamSGDRule") { // adam + tmp_embedx_sgd_dim = mf_dim * 2 + 2; + } else if (optimizer_name == "SparseSharedAdamSGDRule") { // shared_adam + tmp_embedx_sgd_dim = 4; + } + return 7 + embed_sgd_dim + tmp_embedx_sgd_dim + mf_dim; + } + + // total byte size computed from mf_dim + int Size(int& mf_dim) { return (Dim(mf_dim)) * sizeof(float); } + + float& UnseenDays(float* val) { return val[UnseenDaysIndex()]; } float& DeltaScore(float* val) { return val[DeltaScoreIndex()]; } float& Show(float* val) { return val[ShowIndex()]; } @@ -71,6 +86,7 @@ class CtrDymfAccessor : public ValueAccessor { int embed_sgd_dim; int embedx_dim; int embedx_sgd_dim; + std::string optimizer_name; }; struct CtrDymfPushValue { diff --git a/paddle/fluid/framework/fleet/heter_ps/feature_value.h b/paddle/fluid/framework/fleet/heter_ps/feature_value.h index 22f092952698d..221915fc713a8
100644 --- a/paddle/fluid/framework/fleet/heter_ps/feature_value.h +++ b/paddle/fluid/framework/fleet/heter_ps/feature_value.h @@ -27,21 +27,6 @@ namespace framework { typedef uint64_t FeatureKey; -struct GpuAccessorInfo { - // value dimension - size_t dim; - // size of each value dimension - size_t size; - // embedx dimension - size_t embedx_dim; - // push value dimension - size_t update_dim; - // size of each push value dimension - size_t update_size; - // total size of the dynamic-length mf part of value; takes effect for sparse - size_t mf_size; -}; - class FeatureValueAccessor { public: __host__ __device__ FeatureValueAccessor() {} @@ -54,11 +39,8 @@ class FeatureValueAccessor { } __host__ __device__ virtual int Initialize() = 0; - __host__ __device__ virtual GpuAccessorInfo GetAccessorInfo() { return _accessor_info; } - protected: std::unordered_map<std::string, float> _config; - GpuAccessorInfo _accessor_info; }; // adagrad: embed_sgd_dim=1, embedx_sgd_dim=1,embedx_dim=n @@ -81,9 +63,9 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { std::vector<float> embedx_w; */ - __host__ __device__ int Dim() { return 8 + embed_sgd_dim + embedx_sgd_dim + embedx_dim; } // has cpu_ptr(1) + __host__ __device__ int Dim() { return 9 + embed_sgd_dim + embedx_sgd_dim + embedx_dim; } // has cpu_ptr(2) __host__ __device__ int DimSize(size_t dim, int embedx_dim) { return sizeof(float); } - __host__ __device__ int Size() { return (Dim()-1) * sizeof(float) + sizeof(uint64_t); } // cpu_ptr:uint64 + __host__ __device__ int Size() { return Dim() * sizeof(float); } // cpu_ptr:uint64=2float __host__ __device__ int EmbedDim() { return embed_sgd_dim;} __host__ __device__ int EmbedXDim() { return embedx_sgd_dim;} __host__ __device__ int EmbedWDim() { return embedx_dim;} @@ -98,6 +80,52 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { __host__ __device__ int MfSizeIndex() { return MfDimIndex() + 1; } // actual mf size (ex.
0) __host__ __device__ int EmbedxG2SumIndex() { return MfSizeIndex() + 1; } __host__ __device__ int EmbedxWIndex() { return EmbedxG2SumIndex() + embedx_sgd_dim; } + + + // total value length computed from mf_dim + __host__ __device__ int Dim(int& mf_dim) { + int tmp_embedx_sgd_dim = 1; + if (optimizer_type_ == 3) { // adam + tmp_embedx_sgd_dim = mf_dim * 2 + 2; + } else if (optimizer_type_ == 4) { // shared_adam + tmp_embedx_sgd_dim = 4; + } + return 9 + embed_sgd_dim + tmp_embedx_sgd_dim + mf_dim; + } + + // total byte size computed from mf_dim + __host__ __device__ int Size(int& mf_dim) { + return Dim(mf_dim) * sizeof(float); // cpu_ptr:2float + } + + // mf_size byte count computed from mf_dim + __host__ __device__ int MFSize(int& mf_dim) { + int tmp_embedx_sgd_dim = 1; + if (optimizer_type_ == 3) { // adam + tmp_embedx_sgd_dim = mf_dim * 2 + 2; + } else if (optimizer_type_ == 4) { // shared_adam + tmp_embedx_sgd_dim = 4; + } + return (tmp_embedx_sgd_dim + mf_dim) * sizeof(float); + } + + __host__ __device__ int EmbedxG2SumOffsetIndex() { return 0; } + __host__ __device__ int EmbedxWOffsetIndex(float* val) { + // has mf + int tmp_embedx_sgd_dim = 1; + if (int(MfSize(val)) > 0) { + if (optimizer_type_ == 3) { // adam + tmp_embedx_sgd_dim = int(MfDim(val)) * 2 + 2; + } else if (optimizer_type_ == 4) { // shared_adam + tmp_embedx_sgd_dim = 4; + } + return EmbedxG2SumIndex() + tmp_embedx_sgd_dim; + } else { + // no mf + return 0; + } + } + + __host__ __device__ uint64_t CpuPtr(float* val) { return *(reinterpret_cast<uint64_t*>(val)); } __host__ __device__ float& DeltaScore(float* val) { return val[DeltaScoreIndex()]; } @@ -114,6 +142,7 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { int embed_sgd_dim; int embedx_dim; int embedx_sgd_dim; + int optimizer_type_; }; struct CommonPushValue { @@ -177,30 +206,12 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { common_feature_value.embed_sgd_dim = 1; common_feature_value.embedx_sgd_dim = 1; } - + common_feature_value.optimizer_type_ = optimizer_type; common_feature_value.embedx_dim = sparse_embedx_dim; - // VLOG(0) << " INTO FeatureValueAccessor::Initialize()"; - InitAccessorInfo(); return 0; } - // initialize AccessorInfo - __host__ __device__ virtual void InitAccessorInfo() { - _accessor_info.dim = common_feature_value.Dim(); - _accessor_info.size = common_feature_value.Size(); - - int embedx_dim = (_config.find("embedx_dim") == _config.end()) - ?
8 - : int(_config["embedx_dim"]); - // VLOG(0) << "feature value InitAccessorInfo embedx_dim:" << embedx_dim; - _accessor_info.embedx_dim = embedx_dim; - _accessor_info.update_dim = 5 + embedx_dim; - _accessor_info.update_size = _accessor_info.update_dim * sizeof(float); - _accessor_info.mf_size = - (embedx_dim + common_feature_value.embedx_sgd_dim) * sizeof(float); - } - __host__ __device__ std::string ParseToString(const float* v, int param_size) { /* uint64_t cpu_ptr; // 2float @@ -223,13 +234,14 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { i < common_feature_value.SlotIndex(); i++) { os << " " << v[i]; } + int mf_dim = int(common_feature_value.MfDim(const_cast(v))); os << " slot: " << common_feature_value.Slot(const_cast(v)) - << " mf_dim: " << common_feature_value.MfDim(const_cast(v)) + << " mf_dim: " << mf_dim << " mf_size: " << common_feature_value.MfSize(const_cast(v)) << " mf: "; if (param_size > common_feature_value.EmbedxG2SumIndex()) { for (auto i = common_feature_value.EmbedxG2SumIndex(); - i < int(common_feature_value.Size() / sizeof(float)); ++i) { + i < common_feature_value.Dim(mf_dim); ++i) { os << " " << v[i]; } } diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu index f1b332428b6c6..c430dfa669c45 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu @@ -95,37 +95,32 @@ __global__ void dy_mf_search_kernel(Table* table, uint64_t offset = i * pull_feature_value_size; float* cur = (float*)(vals + offset); float* input = it->second; - - cur[feature_value_accessor.common_feature_value.SlotIndex()] = - input[feature_value_accessor.common_feature_value.SlotIndex()]; + int mf_dim = int(input[feature_value_accessor.common_feature_value.MfDimIndex()]); + + *(reinterpret_cast(cur + feature_value_accessor.common_feature_value.CpuPtrIndex())) = + *(reinterpret_cast(input + feature_value_accessor.common_feature_value.CpuPtrIndex())); + cur[feature_value_accessor.common_feature_value.DeltaScoreIndex()] = + input[feature_value_accessor.common_feature_value.DeltaScoreIndex()]; cur[feature_value_accessor.common_feature_value.ShowIndex()] = input[feature_value_accessor.common_feature_value.ShowIndex()]; cur[feature_value_accessor.common_feature_value.ClickIndex()] = input[feature_value_accessor.common_feature_value.ClickIndex()]; - cur[feature_value_accessor.common_feature_value.MfDimIndex()] = - input[feature_value_accessor.common_feature_value.MfDimIndex()]; cur[feature_value_accessor.common_feature_value.EmbedWIndex()] = input[feature_value_accessor.common_feature_value.EmbedWIndex()]; + for (int x = 0; x < feature_value_accessor.common_feature_value.EmbedDim(); x++) { + cur[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + x] = + input[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + x]; + } + cur[feature_value_accessor.common_feature_value.SlotIndex()] = + input[feature_value_accessor.common_feature_value.SlotIndex()]; + cur[feature_value_accessor.common_feature_value.MfDimIndex()] = + input[feature_value_accessor.common_feature_value.MfDimIndex()]; cur[feature_value_accessor.common_feature_value.MfSizeIndex()] = input[feature_value_accessor.common_feature_value.MfSizeIndex()]; - cur[feature_value_accessor.common_feature_value.CpuPtrIndex()] = - input[feature_value_accessor.common_feature_value.CpuPtrIndex()]; - 
cur[feature_value_accessor.common_feature_value.DeltaScoreIndex()] = - input[feature_value_accessor.common_feature_value.DeltaScoreIndex()]; - cur[feature_value_accessor.common_feature_value.EmbedWIndex()] = - input[feature_value_accessor.common_feature_value.EmbedWIndex()]; - for (int i = 0; i < feature_value_accessor.common_feature_value.EmbedDim(); i++) { - cur[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + i] = - input[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + i]; - } - for (int x = 0; x < feature_value_accessor.common_feature_value.EmbedXDim(); x++) { - cur[feature_value_accessor.common_feature_value.EmbedxG2SumIndex() + x] = - input[feature_value_accessor.common_feature_value.EmbedxG2SumIndex() + x]; - } - for (int x = 0; x < feature_value_accessor.common_feature_value.EmbedWDim(); x++) { - cur[feature_value_accessor.common_feature_value.EmbedxWIndex() + x] = - input[feature_value_accessor.common_feature_value.EmbedxWIndex() + x]; + for (int x = feature_value_accessor.common_feature_value.EmbedxG2SumIndex(); + x < int(feature_value_accessor.common_feature_value.Size(mf_dim) / sizeof(float)); x++){ + cur[x] = input[x]; } } } diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index 45c02029cc3b0..21b85acef9e14 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -77,10 +77,10 @@ HeterComm::HeterComm( } else { max_mf_dim_ = resource_->max_mf_dim(); feature_value_accessor_ = feature_value_accessor; - VLOG(3) << " HeterComm init, feature_value_size:" << feature_value_accessor_.GetAccessorInfo().size - << ", feature_value_push_size:" << feature_value_accessor_.GetAccessorInfo().update_size; - size_t val_type_size = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().size); - size_t grad_type_size = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().update_size); + size_t val_type_size = TYPEALIGN(8, feature_value_accessor_.common_feature_value.Size(max_mf_dim_)); + size_t grad_type_size = TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_)); + VLOG(0) << " HeterComm init, max feature_value_size:" << val_type_size + << ", feature_value_push_size:" << grad_type_size; auto ptr_table = new PtrTable(capacity / load_factor_); ptr_table->set_accessor(feature_value_accessor_); ptr_table->set_feature_value_size(val_type_size, grad_type_size); @@ -629,8 +629,8 @@ void HeterComm::dynamic_merge_grad( size_t temp_storage_bytes; - size_t grad_dim = feature_value_accessor_.GetAccessorInfo().embedx_dim; - size_t grad_value_size = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().update_size); + size_t grad_dim = max_mf_dim_; + size_t grad_value_size = TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_)); auto d_merge_keys = memory::Alloc(place, len * sizeof(KeyType)); KeyType* d_merge_keys_ptr = reinterpret_cast(d_merge_keys->ptr()); @@ -785,8 +785,9 @@ void HeterComm::pull_sparse(int num, auto d_idx = memory::Alloc(place, len * sizeof(int)); int* d_idx_ptr = reinterpret_cast(d_idx->ptr()); - size_t val_type_size = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().size); - VLOG(5) << "pull_sparse len:" << len << " val_type_size: " << val_type_size; + + size_t val_type_size = TYPEALIGN(8, feature_value_accessor_.common_feature_value.Size(max_mf_dim_)); + VLOG(3) << "pull_sparse len:" << len << " val_type_size: " << val_type_size; auto d_shard_keys = 
memory::Alloc(place, len * sizeof(KeyType)); KeyType* d_shard_keys_ptr = reinterpret_cast(d_shard_keys->ptr()); auto d_shard_vals = memory::Alloc(place, len * val_type_size); @@ -855,10 +856,8 @@ void HeterComm::pull_sparse(int num, sync_stream(node.out_stream); } } - heter_comm_kernel_->dy_mf_fill_dvals(d_shard_vals_ptr, d_vals, d_idx_ptr, len, val_type_size, stream); - sync_stream(stream); if (!FLAGS_gpugraph_enable_gpu_direct_access) { for (int i = 0; i < total_device; ++i) { @@ -886,7 +885,7 @@ void HeterComm::push_sparse(int dev_num, int dev_id = resource_->dev_id(dev_num); size_t grad_value_size = - TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().update_size); + TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_)); DevPlace place = DevPlace(dev_id); AnyDeviceGuard guard(dev_id); auto stream = resource_->local_stream(dev_num, 0); diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu index 8ba383407fa5c..415865ebba8dd 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu @@ -217,36 +217,31 @@ __global__ void dy_mf_fill_dvals_kernel(float* d_shard_vals, float* d_vals, uint64_t new_offset = uint64_t(idx[i]) * val_size; float* cur = (float*)((char*)d_vals + new_offset); float* shard_val = (float*)((char*)d_shard_vals + uint64_t(i) * val_size); - cur[feature_value_accessor.common_feature_value.SlotIndex()] = - shard_val[feature_value_accessor.common_feature_value.SlotIndex()]; + int mf_dim = int(shard_val[feature_value_accessor.common_feature_value.MfDimIndex()]); + + *(reinterpret_cast(cur + feature_value_accessor.common_feature_value.CpuPtrIndex())) = + *(reinterpret_cast(shard_val + feature_value_accessor.common_feature_value.CpuPtrIndex())); + cur[feature_value_accessor.common_feature_value.DeltaScoreIndex()] = + shard_val[feature_value_accessor.common_feature_value.DeltaScoreIndex()]; cur[feature_value_accessor.common_feature_value.ShowIndex()] = shard_val[feature_value_accessor.common_feature_value.ShowIndex()]; cur[feature_value_accessor.common_feature_value.ClickIndex()] = shard_val[feature_value_accessor.common_feature_value.ClickIndex()]; - cur[feature_value_accessor.common_feature_value.MfDimIndex()] = - shard_val[feature_value_accessor.common_feature_value.MfDimIndex()]; - cur[feature_value_accessor.common_feature_value.EmbedWIndex()] = - shard_val[feature_value_accessor.common_feature_value.EmbedWIndex()]; - cur[feature_value_accessor.common_feature_value.MfSizeIndex()] = - shard_val[feature_value_accessor.common_feature_value.MfSizeIndex()]; - cur[feature_value_accessor.common_feature_value.CpuPtrIndex()] = - shard_val[feature_value_accessor.common_feature_value.CpuPtrIndex()]; - cur[feature_value_accessor.common_feature_value.DeltaScoreIndex()] = - shard_val[feature_value_accessor.common_feature_value.DeltaScoreIndex()]; cur[feature_value_accessor.common_feature_value.EmbedWIndex()] = shard_val[feature_value_accessor.common_feature_value.EmbedWIndex()]; for (int i = 0; i < feature_value_accessor.common_feature_value.EmbedDim(); i++) { cur[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + i] = shard_val[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + i]; } + cur[feature_value_accessor.common_feature_value.SlotIndex()] = + shard_val[feature_value_accessor.common_feature_value.SlotIndex()]; + cur[feature_value_accessor.common_feature_value.MfDimIndex()] = 
mf_dim; + cur[feature_value_accessor.common_feature_value.MfSizeIndex()] = + shard_val[feature_value_accessor.common_feature_value.MfSizeIndex()]; - for (int x = 0; x < feature_value_accessor.common_feature_value.EmbedXDim(); x++) { - cur[feature_value_accessor.common_feature_value.EmbedxG2SumIndex() + x] = - shard_val[feature_value_accessor.common_feature_value.EmbedxG2SumIndex() + x]; - } - for (int x = 0; x < feature_value_accessor.common_feature_value.EmbedWDim(); x++) { - cur[feature_value_accessor.common_feature_value.EmbedxWIndex() + x] = - shard_val[feature_value_accessor.common_feature_value.EmbedxWIndex() + x]; + for (int x = feature_value_accessor.common_feature_value.EmbedxG2SumIndex(); + x < int(feature_value_accessor.common_feature_value.Size(mf_dim) / sizeof(float)); x++){ + cur[x] = shard_val[x]; } } } diff --git a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h index 393a13b93b616..11be566a9a13f 100644 --- a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h +++ b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h @@ -116,7 +116,8 @@ class SparseAdagradOptimizer : public Optimizer { (ptr[feature_value_accessor_.common_feature_value.ShowIndex()] - ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) + optimizer_config.clk_coeff * ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) { - ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] = mf_dim; + ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] = + feature_value_accessor_.common_feature_value.MFSize(mf_dim) / sizeof(float); int tid_x = blockIdx.x * blockDim.x + threadIdx.x; curandState state; @@ -245,14 +246,14 @@ class SparseAdamOptimizer : public Optimizer { grad + feature_value_accessor_.common_push_value.EmbedGIndex(), g_show); int mf_dim = int(ptr[feature_value_accessor_.common_feature_value.MfDimIndex()]); - // printf("mf_dim: %f, lr_gsum: %f, ", mf_dim, ptr[feature_value_accessor_.common_feature_value.EmbedG2SumIndex()]); if (ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] == 0) { if (optimizer_config.mf_create_thresholds <= optimizer_config.nonclk_coeff * (ptr[feature_value_accessor_.common_feature_value.ShowIndex()] - ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) + optimizer_config.clk_coeff * ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) { - ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] = mf_dim; + ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] = + feature_value_accessor_.common_feature_value.MFSize(mf_dim) / sizeof(float); int tid_x = blockIdx.x * blockDim.x + threadIdx.x; curandState state; @@ -261,10 +262,6 @@ class SparseAdamOptimizer : public Optimizer { ptr[feature_value_accessor_.common_feature_value.EmbedxWIndex() + i] = (curand_uniform(&state)) * optimizer_config.mf_initial_range; } - ptr[feature_value_accessor_.common_feature_value.EmbedG2SumIndex() + Beta1PowIndex()] = - optimizer_config.beta1_decay_rate; - ptr[feature_value_accessor_.common_feature_value.EmbedG2SumIndex() + Beta2PowIndex()] = - optimizer_config.beta2_decay_rate; ptr[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + EmbedxBeta1PowIndex()] = optimizer_config.beta1_decay_rate; ptr[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + EmbedxBeta2PowIndex()] = @@ -357,7 +354,6 @@ class SparseAdamSharedOptimizer : public Optimizer { float g_show = grad[feature_value_accessor_.common_push_value.ShowIndex()]; 
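// Note: the shared-adam hunk below mirrors the adagrad/adam changes above. Once the
// mf-creation threshold passes, the mf_size slot now stores the float count of the
// dynamic embedx region, MFSize(mf_dim) / sizeof(float), instead of the raw mf_dim,
// and the embed-level Beta1Pow/Beta2Pow seeding is dropped (only the embedx beta
// pows are initialized here).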
float g_click = grad[feature_value_accessor_.common_push_value.ClickIndex()]; - ptr[feature_value_accessor_.common_feature_value.SlotIndex()] = grad[feature_value_accessor_.common_push_value.SlotIndex()]; ptr[feature_value_accessor_.common_feature_value.ShowIndex()] += g_show; @@ -378,7 +374,8 @@ class SparseAdamSharedOptimizer : public Optimizer { (ptr[feature_value_accessor_.common_feature_value.ShowIndex()] - ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) + optimizer_config.clk_coeff * ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) { - ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] = mf_dim; + ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] = + feature_value_accessor_.common_feature_value.MFSize(mf_dim) / sizeof(float); int tid_x = blockIdx.x * blockDim.x + threadIdx.x; curandState state; @@ -387,10 +384,6 @@ class SparseAdamSharedOptimizer : public Optimizer { ptr[feature_value_accessor_.common_feature_value.EmbedxWIndex() + i] = (curand_uniform(&state)) * optimizer_config.mf_initial_range; } - ptr[feature_value_accessor_.common_feature_value.EmbedG2SumIndex() + Beta1PowIndex()] = - optimizer_config.beta1_decay_rate; - ptr[feature_value_accessor_.common_feature_value.EmbedG2SumIndex() + Beta2PowIndex()] = - optimizer_config.beta2_decay_rate; ptr[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + EmbedxBeta1PowIndex()] = optimizer_config.beta1_decay_rate; ptr[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + EmbedxBeta2PowIndex()] = diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index 83a6c48d3bac6..ea97bfe362f0c 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -240,7 +240,7 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr gpu_task) { iter != total_data.begin() + end_index; iter++) { uint64_t cur_key = *iter; int shard_id = cur_key % thread_keys_shard_num_; - // int dim_id = slot_index_vec_[slot_idx]; + // TODO: feasign <-> slot <-> multi_dim this->thread_dim_keys_[i][shard_id][0].insert(cur_key); } }; @@ -593,9 +593,10 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { // this->HeterPs_->set_accessor(feature_value_accessor_); int mf_dim = this->index_dim_vec_[j]; VLOG(0) << "building table: " << i << "with mf dim: " << mf_dim - << " feature_value_DIM:" << feature_value_accessor_.GetAccessorInfo().dim; + << " feature_value_dim:" << feature_value_accessor_.common_feature_value.Dim(mf_dim) + << " feature_value_size:" << feature_value_accessor_.common_feature_value.Size(mf_dim); size_t feature_value_size = - TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().size); + TYPEALIGN(8, feature_value_accessor_.common_feature_value.Size(mf_dim)); auto& device_dim_keys = gpu_task->device_dim_keys_[i][j]; auto& device_dim_ptrs = gpu_task->device_dim_ptr_[i][j]; size_t len = device_dim_keys.size(); @@ -658,30 +659,27 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { } *(reinterpret_cast(val + feature_value_accessor_.common_feature_value.CpuPtrIndex())) = (uint64_t)(device_dim_ptrs[k]); + ptr_val[cpu_table_accessor_->common_feature_value.MfDimIndex()] = float(mf_dim); val[feature_value_accessor_.common_feature_value.MfDimIndex()] = mf_dim; if (dim > cpu_table_accessor_->GetAccessorInfo().dim - cpu_table_accessor_->GetAccessorInfo().mf_size / sizeof(float)) { val[feature_value_accessor_.common_feature_value.MfSizeIndex()] = - 
feature_value_accessor_.GetAccessorInfo().mf_size / sizeof(float); - for (int x = 0; x < feature_value_accessor_.common_feature_value.EmbedXDim(); x++) { - val[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + x] = + feature_value_accessor_.common_feature_value.MFSize(mf_dim) / sizeof(float); + + for (int x = 0; x < int(feature_value_accessor_.common_feature_value.MFSize(mf_dim) / sizeof(float)); + x++) { + val[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + x] = ptr_val[cpu_table_accessor_->common_feature_value.EmbedxG2SumIndex() + x]; } - for (int x = 0; x < mf_dim; x++) { - val[feature_value_accessor_.common_feature_value.EmbedxWIndex() + x] = - ptr_val[cpu_table_accessor_->common_feature_value.EmbedxWIndex() + x]; - } } else { val[feature_value_accessor_.common_feature_value.MfSizeIndex()] = 0; - for (int x = 0; x < feature_value_accessor_.common_feature_value.EmbedXDim(); x++) { - val[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + x] = 0; - } - for (int x = 0; x < mf_dim; x++) { - val[feature_value_accessor_.common_feature_value.EmbedxWIndex() + x] = 0; + for (int x = feature_value_accessor_.common_feature_value.EmbedxG2SumIndex(); + x < int(feature_value_accessor_.common_feature_value.Size(mf_dim) / sizeof(float)); x++){ + val[x] = 0; } } - VLOG(5) << "build "<< k << " : "<< feature_value_accessor_.ParseToString(val, feature_value_accessor_.GetAccessorInfo().dim); + VLOG(5) << "build "<< k << " : "<< feature_value_accessor_.ParseToString(val, feature_value_accessor_.common_feature_value.Dim(mf_dim)); } #endif @@ -829,10 +827,10 @@ void PSGPUWrapper::EndPass() { auto& device_keys = this->current_task_->device_dim_keys_[i][j]; size_t len = device_keys.size(); int mf_dim = this->index_dim_vec_[j]; - VLOG(0) << "dump pool to cpu table: " << i << "with mf dim: " << mf_dim - << " key_len :" << len; size_t feature_value_size = - TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().size); + TYPEALIGN(8, feature_value_accessor_.common_feature_value.Size(mf_dim)); + VLOG(0) << "dump pool to cpu table: " << i << "with mf dim: " << mf_dim + << " key_len :" << len << " feature_value_size:" << feature_value_size; char* test_build_values = (char*)malloc(feature_value_size * len); cudaMemcpy(test_build_values, hbm_pool->mem(), feature_value_size * len, @@ -880,8 +878,8 @@ void PSGPUWrapper::EndPass() { size_t downpour_value_size = downpour_value->size(); if (gpu_val[feature_value_accessor_.common_feature_value.MfSizeIndex()] > 0 && downpour_value_size == (cpu_table_accessor_->GetAccessorInfo().dim - - cpu_table_accessor_->GetAccessorInfo().mf_size / sizeof(float))) { // cpu_accessor - downpour_value->resize(cpu_table_accessor_->GetAccessorInfo().dim); + int(cpu_table_accessor_->GetAccessorInfo().mf_size / sizeof(float)))) { // cpu_accessor + downpour_value->resize(cpu_table_accessor_->common_feature_value.Dim(mf_dim)); } float* cpu_val = downpour_value->data(); @@ -902,16 +900,14 @@ void PSGPUWrapper::EndPass() { } if (gpu_val[feature_value_accessor_.common_feature_value.MfSizeIndex()] > 0) { - for (int x = 0; x < feature_value_accessor_.common_feature_value.EmbedXDim(); x++) { - cpu_val[cpu_table_accessor_->common_feature_value.EmbedxG2SumIndex() + x] = + + for (int x = 0; x < int(feature_value_accessor_.common_feature_value.MFSize(mf_dim) / sizeof(float)); + x++) { + cpu_val[cpu_table_accessor_->common_feature_value.EmbedxG2SumIndex() + x] = gpu_val[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + x]; } - for (int x = 0; x < 
feature_value_accessor_.common_feature_value.EmbedWDim(); x++) { - cpu_val[cpu_table_accessor_->common_feature_value.EmbedxWIndex() + x] = - gpu_val[feature_value_accessor_.common_feature_value.EmbedxWIndex() + x]; - } } - VLOG(5) << "dump to cpu "<< index << " : "<< feature_value_accessor_.ParseToString(gpu_val, feature_value_accessor_.GetAccessorInfo().dim) + VLOG(5) << "dump to cpu "<< index << " : "<< feature_value_accessor_.ParseToString(gpu_val, feature_value_accessor_.common_feature_value.Dim(mf_dim)) << " ===== CPU:" << cpu_table_accessor_->ParseToString(cpu_val, downpour_value->size()); } @@ -969,8 +965,8 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place, std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL); size_t feature_value_size = 0; - feature_value_size = TYPEALIGN( 8, feature_value_accessor_.GetAccessorInfo().size); - VLOG(5) << "PullSparse feature_value_size:" << feature_value_accessor_.GetAccessorInfo().size; + feature_value_size = TYPEALIGN(8, feature_value_accessor_.common_feature_value.Size(max_mf_dim_)); + VLOG(3) << "PullSparse max_dim:" << max_mf_dim_ << " feature_value_size:" << feature_value_size; #ifdef PADDLE_WITH_CUDA VLOG(3) << "Begine Gpu Ps PullSparse"; @@ -1103,9 +1099,9 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place, // #ifdef PADDLE_WITH_CUDA VLOG(3) << "Begin GPUPS PushSparseGrad"; size_t grad_value_size = - TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().update_size); + TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_)); auto buf = memory::Alloc(place, total_length * grad_value_size); - VLOG(3) << "Push Sparse Max mf dimention: " << max_mf_dim_; + VLOG(3) << "Push Sparse Max mf dimension: " << max_mf_dim_ << " grad_value_size:" << grad_value_size; float* total_grad_values_gpu = reinterpret_cast<float*>(buf->ptr()); if (platform::is_cpu_place(place)) { diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu index d04da131e98ef..15d22ab57428d 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu @@ -96,7 +96,8 @@ __global__ void PullCopy(float** dest, const float* src, } } else { for (int j = 0; j < mf_dim; j++) { - *(dest[x] + y * (mf_dim + 3) + 3 + j) = feature_value_ptr[feature_value_accessor.common_feature_value.EmbedxWIndex() + j]; + *(dest[x] + y * (mf_dim + 3) + 3 + j) = + feature_value_ptr[feature_value_accessor.common_feature_value.EmbedxWOffsetIndex(feature_value_ptr) + j]; } } } diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h index be8546c366994..6da628db72455 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h @@ -329,8 +329,6 @@ class PSGPUWrapper { } feature_value_accessor_.Configure(config); - VLOG(0) << "INIT feature_value_accessor_:" << feature_value_accessor_.GetAccessorInfo().dim - << " EMBX:" << feature_value_accessor_.common_feature_value.embedx_sgd_dim; InitializeGPUServer(config); } #endif @@ -509,13 +507,9 @@ class PSGPUWrapper { for (size_t i = 0; i < slot_index_vec_.size(); i++) { slot_index_vec_[i] = dim_index_map[slot_mf_dim_vector_[i]]; } //TODO(FENGDANLEI): max_mf VLOG(0) << "InitSlotInfo embed_sgd_dim_:" << embed_sgd_dim_ << " embedx_sgd_dim_:" << embedx_sgd_dim_ << " embedx_dim_:" << embedx_dim_ << " optimizer_type_:" << optimizer_type_; VLOG(0) << "InitSlotInfo:" <<
feature_value_accessor_.GetAccessorInfo().size; - val_type_size_ =TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().size); - grad_type_size_ = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().update_size); + val_type_size_ = TYPEALIGN(8, feature_value_accessor_.common_feature_value.Size(max_mf_dim_)); + grad_type_size_ = TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_)); + VLOG(0) << "InitSlotInfo: val_type_size_" << val_type_size_ << " grad_type_size_:" << grad_type_size_; slot_info_initialized_ = true; } #endif diff --git a/python/paddle/distributed/passes/ps_trainer_pass.py b/python/paddle/distributed/passes/ps_trainer_pass.py index 6112a9a1f45b6..9a4a513b752e3 100755 --- a/python/paddle/distributed/passes/ps_trainer_pass.py +++ b/python/paddle/distributed/passes/ps_trainer_pass.py @@ -311,6 +311,14 @@ def dag_check_up_and_reorder(program, inputs, outputs): for i in range(len(global_block.ops)): assert global_block.desc.op(i) == global_block.ops[i].desc + if attrs['use_ps_gpu']: + gpups_inputs_idxs = list() + gpups_outputs_idxs = list() + gpups_inputs = list() + gpups_outputs = list() + gpups_w_size = list() + gpups_min_distributed_idx = len(_program.global_block().ops) + 1 + for param, ops in pull_sparse_ops.items(): all_ops = _program.global_block().ops op_device = "" @@ -366,38 +374,37 @@ def dag_check_up_and_reorder(program, inputs, outputs): outputs_idxs[out_id] = min(idx, outputs_idxs[out_id]) + if attrs['use_ps_gpu']: + gpups_inputs_idxs.extend(inputs_idxs) + gpups_outputs_idxs.extend(outputs_idxs) + gpups_inputs.extend(inputs) + gpups_outputs.extend(outputs) + gpups_w_size.extend([w.shape[1]] * len(inputs)) + gpups_min_distributed_idx = min(min(op_idxs), + gpups_min_distributed_idx) + continue + if min(outputs_idxs) - max(inputs_idxs) >= 1: if max(inputs_idxs) == -1: distributed_idx = min(op_idxs) else: distributed_idx = max(inputs_idxs) + 1 - if attrs['use_ps_gpu']: - _program.global_block()._insert_op( - index=distributed_idx, - type="pull_gpups_sparse", - inputs={"Ids": inputs, - 'W': w}, - outputs={"Out": outputs}, - attrs={ - "size": [w.shape[1] for i in inputs], - "is_distributed": True, - "is_sparse": True - }) - else: - _program.global_block()._insert_op( - index=distributed_idx, - type="distributed_lookup_table", - inputs={"Ids": inputs, - 'W': w}, - outputs={"Outputs": outputs}, - attrs={ - "is_distributed": is_distributed, - "padding_idx": padding_idx, - "table_id": table_id, - "lookup_table_version": op_type, - "op_device": op_device - }) + _program.global_block()._insert_op( + index=distributed_idx, + type="distributed_lookup_table", + inputs={ + "Ids": inputs, + 'W': w + }, + outputs={"Outputs": outputs}, + attrs={ + "is_distributed": is_distributed, + "padding_idx": padding_idx, + "table_id": table_id, + "lookup_table_version": op_type, + "op_device": op_device + }) else: for i in range(len(inputs_idxs)): distributed_idx = op_idxs[i] @@ -416,6 +423,32 @@ def dag_check_up_and_reorder(program, inputs, outputs): "op_device": op_device }) + if attrs['use_ps_gpu'] and len(gpups_inputs) > 0: + if max(gpups_inputs_idxs) > 0: + raise ValueError("There can't be ops before embedding in gpups") + + _program.global_block()._insert_op(index=gpups_min_distributed_idx, + type="pull_gpups_sparse", + inputs={ + "Ids": gpups_inputs, + }, + outputs={"Out": gpups_outputs}, + attrs={ + "size": gpups_w_size, + "is_distributed": True, + "is_sparse": True + }) + PSGPU = paddle.fluid.core.PSGPU() + try: + gpu_slot = [int(var.name) for var in 
gpups_inputs] except (ValueError): raise ValueError( "The slot name in gpups should be able to convert to integer." ) PSGPU.set_slot_vector(gpu_slot) gpu_mf_sizes = [x - 3 for x in gpups_w_size] PSGPU.set_slot_dim_vector(gpu_mf_sizes) def _get_pull_sparse_ops(self, _program, attrs): pull_sparse_ops = {} pull_sparse_ids = {} diff --git a/python/paddle/distributed/ps/the_one_ps.py b/python/paddle/distributed/ps/the_one_ps.py index 0c1565aa936b1..ac7dc9878749e 100755 --- a/python/paddle/distributed/ps/the_one_ps.py +++ b/python/paddle/distributed/ps/the_one_ps.py @@ -587,8 +587,11 @@ def _set(self, table_proto): if proto.table_name == self.common.table_name: usr_table_proto = proto break - table_proto.table_class = 'MemorySparseTable' - warnings.warn("The PS mode must use MemorySparseTable.") + if usr_table_proto.HasField("table_class"): + table_proto.table_class = usr_table_proto.table_class + else: + table_proto.table_class = 'MemorySparseTable' + warnings.warn("The PS mode must use MemorySparseTable.") if usr_table_proto.HasField("shard_num"): table_proto.shard_num = usr_table_proto.shard_num else:
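
A minimal standalone sketch of the variable-length value layout that PATCH 3/3 introduces in CommonFeatureValueAccessor (see feature_value.h above). The optimizer codes (3 = adam, 4 = shared adam), the adagrad default of 1, and the 9-float fixed header (cpu_ptr counted as 2 floats, plus delta_score, show, click, embed_w, slot, mf_dim, mf_size) are taken from the patch; the embed_sgd_dim value and the names below are illustrative assumptions, not Paddle API.

// layout_sketch.cc - mirrors the Dim()/Size() arithmetic from feature_value.h.
#include <cstdio>
#include <initializer_list>

int value_dim(int optimizer_type, int embed_sgd_dim, int mf_dim) {
  int embedx_sgd_dim = 1;            // adagrad: one shared g2sum
  if (optimizer_type == 3) {         // adam: gsum + g2sum per dim, plus two beta pows
    embedx_sgd_dim = mf_dim * 2 + 2;
  } else if (optimizer_type == 4) {  // shared adam: shared gsum/g2sum plus beta pows
    embedx_sgd_dim = 4;
  }
  return 9 + embed_sgd_dim + embedx_sgd_dim + mf_dim;  // 9 = fixed header floats
}

int main() {
  const int embed_sgd_dim = 4;  // assumed example; the real value comes from the SGD rule
  for (int mf_dim : {8, 64}) {
    int d = value_dim(/*adam=*/3, embed_sgd_dim, mf_dim);
    std::printf("adam, mf_dim=%d -> %d floats (%zu bytes)\n", mf_dim, d,
                d * sizeof(float));
  }
  return 0;
}

This is why the patch replaces the fixed GetAccessorInfo().size with common_feature_value.Size(max_mf_dim_) at the TYPEALIGN call sites: under adam the per-feature value size grows with mf_dim, so pull/push buffers must be sized for the largest mf_dim in the pass.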