Skip to content

Commit

Permalink
Merge pull request #9774 from reyoung/feature/simplify_data_structures
Browse files Browse the repository at this point in the history
Simplify DataStructure in SSAGraph
  • Loading branch information
reyoung authored Apr 10, 2018
2 parents 0f38bb4 + 17bfe3f commit 161344b
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 25 deletions.
21 changes: 13 additions & 8 deletions paddle/fluid/framework/details/multi_devices_graph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,11 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
auto graph = new SSAGraph();
SSAGraph &result = *graph;
std::unordered_set<std::string> og_has_been_broadcast;
result.vars_.resize(places_.size());

// We cannot invoke resize. It is a bug of GCC 4.8
result.vars_ = std::vector<
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
places_.size());

bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) {
Expand Down Expand Up @@ -147,15 +151,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
if (vars.empty()) { // This device has no data. continue.
continue;
}
auto *prev_grad = &vars[vars.size() - 1];
op_handle->AddInput(prev_grad);
auto &prev_grad = vars[vars.size() - 1];
op_handle->AddInput(prev_grad.get());

auto &var = vars[vars.size()];
var.place_ = p;
var.name_ = og;
var.version_ = vars.size() - 1;
vars.emplace_back(new VarHandle);
auto &var = vars.back();
var->place_ = p;
var->name_ = og;
var->version_ = vars.size() - 1;

op_handle->AddOutput(&var);
op_handle->AddOutput(var.get());
}
#else
PADDLE_ENFORCE("Not implemented");
Expand Down
6 changes: 5 additions & 1 deletion paddle/fluid/framework/details/ssa_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

#include <map>
#include <string>
#include <vector>

#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/var_handle.h"

Expand All @@ -24,7 +26,9 @@ namespace framework {
namespace details {

struct SSAGraph {
std::vector<std::unordered_map<std::string, std::map<int, VarHandle>>> vars_;
std::vector<
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
vars_;
// aux variables to represent dependency. Useful to resolve data hazard.
std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;
std::vector<std::unique_ptr<OpHandleBase>> ops_;
Expand Down
30 changes: 16 additions & 14 deletions paddle/fluid/framework/details/ssa_graph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
auto it_old = name_pair.second.rbegin();
++it_old;
for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
auto *write_op = it_new->second.generated_op_;
auto &read_ops = it_old->second.pending_ops_;
auto *write_op = (*it_new)->generated_op_;
auto &read_ops = (*it_old)->pending_ops_;

for (auto *read_op : read_ops) {
// Manually add a dependency var from read_op to write_op;
Expand All @@ -54,14 +54,15 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
auto &var_holder = var_holders[each_var_name];
VarHandle *var = nullptr;
if (var_holder.empty()) {
var_holder.emplace_back(new VarHandle);
auto &init_var = var_holder[0];
init_var.place_ = place;
init_var.name_ = each_var_name;
init_var.generated_op_ = nullptr;
init_var.version_ = 0;
var = &init_var;
init_var->place_ = place;
init_var->name_ = each_var_name;
init_var->generated_op_ = nullptr;
init_var->version_ = 0;
var = init_var.get();
} else {
var = &var_holder.rbegin()->second;
var = var_holder.rbegin()->get();
}
return var;
}
Expand All @@ -72,19 +73,20 @@ void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
size_t place_offset) {
auto &vars = graph->vars_[place_offset][each_var_name];
size_t version = vars.size();
auto &var = vars[version];
var.version_ = version;
var.name_ = each_var_name;
var.place_ = place;
op_handle->AddOutput(&var);
vars.emplace_back(new VarHandle());
auto &var = vars.back();
var->version_ = version;
var->name_ = each_var_name;
var->place_ = place;
op_handle->AddOutput(var.get());
}

template <typename Callback>
void IterAllVar(const SSAGraph &graph, Callback callback) {
for (auto &each : graph.vars_) {
for (auto &pair1 : each) {
for (auto &pair2 : pair1.second) {
callback(pair2.second);
callback(*pair2);
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for (auto &var_map : graph_->vars_) {
for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) {
InsertPendingVar(version_pair.second);
InsertPendingVar(*version_pair);
}
}
}
Expand All @@ -95,7 +95,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for (auto &var_map : graph_->vars_) {
auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(&it->second.rbegin()->second);
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
}
}
}
Expand Down

0 comments on commit 161344b

Please sign in to comment.