Skip to content

Commit

Permalink
multithread memory optimize error fix (#37894)
Browse files Browse the repository at this point in the history
* multithread_memory_optimize
  • Loading branch information
JZZ-NOTE committed Jan 6, 2022
1 parent 1e8432f commit 4f87ebe
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 18 deletions.
29 changes: 16 additions & 13 deletions paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ typedef struct {
// The traversal order also affect the lifecycles, so different sort_kind is
// used.
void MemoryOptimizePass::CollectLifeCycle(
std::unordered_map<std::string, lifecycle_t>* lifecycles,
Graph* graph, std::unordered_map<std::string, lifecycle_t>* lifecycles,
int sort_kind) const {
max_lifecycle_ = 0;
int max_lifecycle = 0;
for (auto* op_node : framework::ir::TopologyVarientSort(
*graph_, static_cast<framework::ir::SortKind>(sort_kind))) {
*graph, static_cast<framework::ir::SortKind>(sort_kind))) {
if (!op_node->IsOp()) continue;
auto reads = op_node->inputs;
auto writes = op_node->outputs;
Expand All @@ -77,20 +77,20 @@ void MemoryOptimizePass::CollectLifeCycle(
if (node->Var()->Persistable()) continue;
std::string var = node->Name();
if (!lifecycles->count(var)) {
(*lifecycles)[var] = std::make_pair(max_lifecycle_, max_lifecycle_);
(*lifecycles)[var] = std::make_pair(max_lifecycle, max_lifecycle);
} else {
(*lifecycles)[var].second =
std::max(max_lifecycle_, lifecycles->at(var).second); // max()
std::max(max_lifecycle, lifecycles->at(var).second); // max()
}
}
}

++max_lifecycle_;
++max_lifecycle;
}
}

void MemoryOptimizePass::CollectVarMemorySize(
space_table_t* space_table) const {
Graph* graph, space_table_t* space_table) const {
const int fake_batch_size = 1;

auto valid_var = [&](framework::ir::Node* node) -> bool {
Expand Down Expand Up @@ -130,7 +130,7 @@ void MemoryOptimizePass::CollectVarMemorySize(
// although it's not always the case. so black list is the best compromise
// between performance and underlying principle.
std::unordered_set<std::string> black_list;
for (auto* node : graph_->Nodes()) {
for (auto* node : graph->Nodes()) {
if (node->IsVar() &&
node->Var()->GetType() ==
framework::proto::VarType::Type::VarType_Type_LOD_TENSOR) {
Expand All @@ -141,7 +141,7 @@ void MemoryOptimizePass::CollectVarMemorySize(
}

// Collect tensors from graph.
for (auto* node : graph_->Nodes()) {
for (auto* node : graph->Nodes()) {
if (node->IsVar() &&
node->Var()->GetType() ==
framework::proto::VarType::Type::VarType_Type_LOD_TENSOR &&
Expand Down Expand Up @@ -304,18 +304,21 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
// 3. Perform reuse plan: Replace all var's name in the model according to the
// mapping table.
if (!argument->enable_memory_optim()) return;
graph_ = argument->main_graph_ptr();
// Because of pass is a singleton, graph can not be member
// variables,otherwise,errors will be caused under multithreading
// conditions.
auto graph = argument->main_graph_ptr();

int sort_kind = 0;
std::unordered_map<std::string, lifecycle_t> lifecycles;
space_table_t space_table;
std::unordered_map<std::string, std::string> node2cluster;
std::unordered_map<std::string, int> cluster_size;

CollectLifeCycle(&lifecycles, sort_kind);
CollectVarMemorySize(&space_table);
CollectLifeCycle(graph, &lifecycles, sort_kind);
CollectVarMemorySize(graph, &space_table);
MakeSimpleReusePlan(lifecycles, space_table, &node2cluster, &cluster_size);
UpdateOpDescsByReuse(graph_, node2cluster, sort_kind);
UpdateOpDescsByReuse(graph, node2cluster, sort_kind);
return;
}

Expand Down
8 changes: 3 additions & 5 deletions paddle/fluid/inference/analysis/passes/memory_optimize_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,15 @@ class MemoryOptimizePass : public AnalysisPass {

private:
void CollectLifeCycle(
framework::ir::Graph *graph,
std::unordered_map<std::string, lifecycle_t> *lifecycles,
int sort_kind) const;

void CollectVarMemorySize(space_table_t *space_table) const;
void CollectVarMemorySize(framework::ir::Graph *graph,
space_table_t *space_table) const;

public:
std::string repr() const override;

private:
mutable framework::ir::Graph *graph_{nullptr};
mutable int max_lifecycle_{-1};
};

} // namespace analysis
Expand Down

1 comment on commit 4f87ebe

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.