From cf2c4ecfeb92173b707cfaa9b23020b412d1ad84 Mon Sep 17 00:00:00 2001 From: JZZ-NOTE Date: Mon, 6 Dec 2021 14:14:06 +0000 Subject: [PATCH 1/2] multithread_memory_optimize --- .../analysis/passes/memory_optimize_pass.cc | 32 +++++++++++-------- .../analysis/passes/memory_optimize_pass.h | 12 +++---- 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc index 2202b94bee727..db0fc126d8017 100644 --- a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc +++ b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc @@ -52,11 +52,11 @@ typedef struct { // The traversal order also affect the lifecycles, so different sort_kind is // used. void MemoryOptimizePass::CollectLifeCycle( - std::unordered_map* lifecycles, - int sort_kind) const { - max_lifecycle_ = 0; + Graph* graph, std::unordered_map* lifecycles, + int sort_kind, int max_lifecycle) const { + max_lifecycle = 0; for (auto* op_node : framework::ir::TopologyVarientSort( - *graph_, static_cast(sort_kind))) { + *graph, static_cast(sort_kind))) { if (!op_node->IsOp()) continue; auto reads = op_node->inputs; auto writes = op_node->outputs; @@ -77,20 +77,20 @@ void MemoryOptimizePass::CollectLifeCycle( if (node->Var()->Persistable()) continue; std::string var = node->Name(); if (!lifecycles->count(var)) { - (*lifecycles)[var] = std::make_pair(max_lifecycle_, max_lifecycle_); + (*lifecycles)[var] = std::make_pair(max_lifecycle, max_lifecycle); } else { (*lifecycles)[var].second = - std::max(max_lifecycle_, lifecycles->at(var).second); // max() + std::max(max_lifecycle, lifecycles->at(var).second); // max() } } } - ++max_lifecycle_; + ++max_lifecycle; } } void MemoryOptimizePass::CollectVarMemorySize( - space_table_t* space_table) const { + Graph* graph, space_table_t* space_table) const { const int fake_batch_size = 1; auto valid_var = [&](framework::ir::Node* node) -> bool { @@ -130,7 +130,7 @@ void MemoryOptimizePass::CollectVarMemorySize( // although it's not always the case. so black list is the best compromise // between performance and underlying principle. std::unordered_set black_list; - for (auto* node : graph_->Nodes()) { + for (auto* node : graph->Nodes()) { if (node->IsVar() && node->Var()->GetType() == framework::proto::VarType::Type::VarType_Type_LOD_TENSOR) { @@ -141,7 +141,7 @@ void MemoryOptimizePass::CollectVarMemorySize( } // Collect tensors from graph. - for (auto* node : graph_->Nodes()) { + for (auto* node : graph->Nodes()) { if (node->IsVar() && node->Var()->GetType() == framework::proto::VarType::Type::VarType_Type_LOD_TENSOR && @@ -304,7 +304,11 @@ void MemoryOptimizePass::RunImpl(Argument* argument) { // 3. Perform reuse plan: Replace all var's name in the model according to the // mapping table. if (!argument->enable_memory_optim()) return; - graph_ = argument->main_graph_ptr(); + // Because of pass is a singleton, graph and max_lifecycle can not be member + // variables,otherwise,errors will be caused under multithreading + // conditions. + auto graph = argument->main_graph_ptr(); + int max_lifecycle = -1; int sort_kind = 0; std::unordered_map lifecycles; @@ -312,10 +316,10 @@ void MemoryOptimizePass::RunImpl(Argument* argument) { std::unordered_map node2cluster; std::unordered_map cluster_size; - CollectLifeCycle(&lifecycles, sort_kind); - CollectVarMemorySize(&space_table); + CollectLifeCycle(graph, &lifecycles, sort_kind, max_lifecycle); + CollectVarMemorySize(graph, &space_table); MakeSimpleReusePlan(lifecycles, space_table, &node2cluster, &cluster_size); - UpdateOpDescsByReuse(graph_, node2cluster, sort_kind); + UpdateOpDescsByReuse(graph, node2cluster, sort_kind); return; } diff --git a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h index 6d20aee295b7c..cdeef8d7364fd 100644 --- a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h +++ b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h @@ -57,17 +57,15 @@ class MemoryOptimizePass : public AnalysisPass { private: void CollectLifeCycle( - std::unordered_map *lifecycles, - int sort_kind) const; + framework::ir::Graph *graph, + std::unordered_map *lifecycles, int sort_kind, + int max_lifecycle) const; - void CollectVarMemorySize(space_table_t *space_table) const; + void CollectVarMemorySize(framework::ir::Graph *graph, + space_table_t *space_table) const; public: std::string repr() const override; - - private: - mutable framework::ir::Graph *graph_{nullptr}; - mutable int max_lifecycle_{-1}; }; } // namespace analysis From bf33242dc098adc33dabca14844fed95ca97876d Mon Sep 17 00:00:00 2001 From: JZZ-NOTE Date: Tue, 7 Dec 2021 03:07:50 +0000 Subject: [PATCH 2/2] multithread_memory_optimize --- .../inference/analysis/passes/memory_optimize_pass.cc | 9 ++++----- .../inference/analysis/passes/memory_optimize_pass.h | 4 ++-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc index db0fc126d8017..3fa417c2ea631 100644 --- a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc +++ b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc @@ -53,8 +53,8 @@ typedef struct { // used. void MemoryOptimizePass::CollectLifeCycle( Graph* graph, std::unordered_map* lifecycles, - int sort_kind, int max_lifecycle) const { - max_lifecycle = 0; + int sort_kind) const { + int max_lifecycle = 0; for (auto* op_node : framework::ir::TopologyVarientSort( *graph, static_cast(sort_kind))) { if (!op_node->IsOp()) continue; @@ -304,11 +304,10 @@ void MemoryOptimizePass::RunImpl(Argument* argument) { // 3. Perform reuse plan: Replace all var's name in the model according to the // mapping table. if (!argument->enable_memory_optim()) return; - // Because of pass is a singleton, graph and max_lifecycle can not be member + // Because of pass is a singleton, graph can not be member // variables,otherwise,errors will be caused under multithreading // conditions. auto graph = argument->main_graph_ptr(); - int max_lifecycle = -1; int sort_kind = 0; std::unordered_map lifecycles; @@ -316,7 +315,7 @@ void MemoryOptimizePass::RunImpl(Argument* argument) { std::unordered_map node2cluster; std::unordered_map cluster_size; - CollectLifeCycle(graph, &lifecycles, sort_kind, max_lifecycle); + CollectLifeCycle(graph, &lifecycles, sort_kind); CollectVarMemorySize(graph, &space_table); MakeSimpleReusePlan(lifecycles, space_table, &node2cluster, &cluster_size); UpdateOpDescsByReuse(graph, node2cluster, sort_kind); diff --git a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h index cdeef8d7364fd..57052243d2f18 100644 --- a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h +++ b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h @@ -58,8 +58,8 @@ class MemoryOptimizePass : public AnalysisPass { private: void CollectLifeCycle( framework::ir::Graph *graph, - std::unordered_map *lifecycles, int sort_kind, - int max_lifecycle) const; + std::unordered_map *lifecycles, + int sort_kind) const; void CollectVarMemorySize(framework::ir::Graph *graph, space_table_t *space_table) const;