From 7ab8d80234e1418131378718b45898251ac5404f Mon Sep 17 00:00:00 2001 From: miaoli06 <106585574+miaoli06@users.noreply.github.com> Date: Mon, 20 Feb 2023 21:18:21 +0800 Subject: [PATCH] fix infer batch num in multi node (#215) --- paddle/fluid/framework/data_feed.h | 21 ++++++++++++++++- paddle/fluid/framework/data_set.cc | 37 ++++++++++++++++++++++++++++++ paddle/fluid/framework/data_set.h | 1 + 3 files changed, 58 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index 9ff0a172b1977..94420c022926b 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -938,6 +938,14 @@ class GraphDataGenerator { uint64_t CopyUniqueNodes(); int GetPathNum() { return total_row_; } void ResetPathNum() { total_row_ = 0; } + int GetGraphBatchsize() {return batch_size_;}; + void SetNewBatchsize(int batch_num) { + if (!gpu_graph_training_ && !sage_mode_) { + batch_size_ = (total_row_ + batch_num - 1) / batch_num; + } else { + return; + } + } void ResetEpochFinish() { epoch_finish_ = false; } void ClearSampleState(); void DumpWalkPath(std::string dump_path, size_t dump_rate); @@ -1166,7 +1174,18 @@ class DataFeed { virtual const std::vector& GetInsContentVec() const { return ins_content_vec_; } - virtual int GetCurBatchSize() { return batch_size_; } + virtual int GetCurBatchSize() { +#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS) + return gpu_graph_data_generator_.GetGraphBatchsize(); +#else + return batch_size_; +#endif + } + virtual void SetNewBatchsize(int batch_num) { +#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS) + gpu_graph_data_generator_.SetNewBatchsize(batch_num); +#endif + } virtual int GetGraphPathNum() { #if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS) return gpu_graph_data_generator_.GetPathNum(); diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index 0d4cd507360bc..49fed23a642a5 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -1931,8 +1931,45 @@ void SlotRecordDataset::PrepareTrain() { return; } +void SlotRecordDataset::DynamicAdjustBatchNum() { + VLOG(3) << "dynamic adjust batch num of graph in multi node"; +#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS) + if (gpu_graph_mode_) { + int thread_max_batch_num = 0; + for (size_t i = 0; i < readers_.size(); i++) { + int batch_size = readers_[i]->GetCurBatchSize(); + int64_t ins_num = readers_[i]->GetGraphPathNum(); + int batch_num = (ins_num + batch_size - 1) / batch_size; + if (batch_num > thread_max_batch_num) { + thread_max_batch_num = batch_num; + } + VLOG(3) << "ins num:" << ins_num << ", batch size:" + << batch_size << ", batch_num:" << thread_max_batch_num; + } +#ifdef PADDLE_WITH_GLOO + auto gloo_wrapper = paddle::framework::GlooWrapper::GetInstance(); + if (gloo_wrapper->Size() > 1) { + if (!gloo_wrapper->IsInitialized()) { + VLOG(0) << "GLOO is not inited"; + gloo_wrapper->Init(); + } + std::vector thread_batch_num_vec(1, thread_max_batch_num); + auto thread_max_batch_num_vec = + gloo_wrapper->AllReduce(thread_batch_num_vec, "max"); + thread_max_batch_num = thread_max_batch_num_vec[0]; + VLOG(3) << "thread max batch num:" << thread_max_batch_num; + for (size_t i = 0; i < readers_.size(); i++) { + readers_[i]->SetNewBatchsize(thread_max_batch_num); + } + } +#endif + } +#endif +} + void SlotRecordDataset::DynamicAdjustReadersNum(int thread_num) { if (thread_num_ == thread_num) { + DynamicAdjustBatchNum(); VLOG(3) << "DatasetImpl::DynamicAdjustReadersNum thread_num_=" << thread_num_ << ", thread_num_=thread_num, no need to adjust"; return; diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h index c112e0c02fe97..35f08002c1152 100644 --- a/paddle/fluid/framework/data_set.h +++ b/paddle/fluid/framework/data_set.h @@ -410,6 +410,7 @@ class SlotRecordDataset : public DatasetImpl { bool discard_remaining_ins); virtual void PrepareTrain(); virtual void DynamicAdjustReadersNum(int thread_num); + void DynamicAdjustBatchNum(); protected: bool enable_heterps_ = true;