Skip to content

Commit

Permalink
fix isolated var fetch bug, test=release/2.0 (#24086)
Browse files Browse the repository at this point in the history
  • Loading branch information
sneaxiy authored Apr 23, 2020
1 parent 314ea80 commit 9eef667
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 0 deletions.
26 changes: 26 additions & 0 deletions paddle/fluid/framework/details/multi_devices_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,32 @@ std::vector<std::unique_ptr<ir::Graph>> TrySeparateToMultipleSingleDeviceGraphs(
g->Set(kGraphDepVars, new GraphDepVars());
}

std::vector<VarHandle *> isolated_var_handles;
for (auto *node : graph->Nodes()) {
if (!node->IsWrappedBy<VarHandleBase>()) {
continue;
}

auto &var_handle_base = node->Wrapper<VarHandleBase>();
auto *var_handle = dynamic_cast<VarHandle *>(&var_handle_base);
if (var_handle && var_handle->PendingOps().empty() &&
var_handle->GeneratedOp() == nullptr) {
isolated_var_handles.emplace_back(var_handle);
}
}

for (auto *var_handle : isolated_var_handles) {
auto dev_idx = var_handle->scope_idx();
auto &src_vars = graph->Get<GraphVars>(kGraphVars)[dev_idx];
auto *dst_graph = graphs[dev_idx].get();
auto &dst_vars = dst_graph->Get<GraphVars>(kGraphVars)[0];
VLOG(10) << "Move isolated var " << var_handle->Name() << " at device "
<< dev_idx;
dst_graph->AddNode(graph->RemoveNode(var_handle->Node()).release());
dst_vars[var_handle->Name()].emplace_back(var_handle);
src_vars.erase(var_handle->Name());
}

for (auto &pair : op_to_dev_idx) {
auto *op = pair.first;
auto dev_idx = pair.second;
Expand Down
14 changes: 14 additions & 0 deletions paddle/fluid/framework/ir/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,14 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
all_vars.emplace(var->Name(), var);
}

auto not_visited_vars = all_vars;

for (auto *op : program.Block(0).AllOps()) {
ir::Node *node = CreateOpNode(op);
// For input args, reuse the same var name if it was created before.
// Otherwise, create a new one.
for (auto &each_var_name : op->InputArgumentNames()) {
not_visited_vars.erase(each_var_name);
ir::Node *var = nullptr;
if (var_nodes.find(each_var_name) != var_nodes.end()) {
var = var_nodes.at(each_var_name).back();
Expand All @@ -68,6 +71,7 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
// For output args, always create a new var.
std::unordered_set<std::string> out_arg_set;
for (auto &each_var_name : op->OutputArgumentNames()) {
not_visited_vars.erase(each_var_name);
if (each_var_name != kEmptyVarName) {
PADDLE_ENFORCE_EQ(out_arg_set.count(each_var_name), 0,
platform::errors::InvalidArgument(
Expand All @@ -91,6 +95,16 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
var->inputs.push_back(node);
}
}

for (auto &pair : not_visited_vars) {
const auto &var_name = pair.first;
auto *var_desc = pair.second;
if (var_name != kEmptyVarName) {
VLOG(10) << "Create isolated var node " << var_name;
var_nodes[var_name].push_back(CreateVarNode(var_desc));
}
}

Set<const std::vector<OpDesc *>>(
details::kStaleProgramOpDescs,
new std::vector<OpDesc *>(program.Block(0).AllOps()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,15 @@ void MultiDevSSAGraphBuilderBase::ApplyImpl(ir::Graph *graph) const {
auto nodes = graph->ReleaseNodes();
ir::Graph &result = *graph;

std::vector<ir::Node *> isolated_vars;

for (auto &node : nodes) {
if (node->IsVar() && node->Var()) {
all_vars_.emplace(node->Name(), node->Var());

if (node->inputs.empty() && node->outputs.empty()) {
isolated_vars.emplace_back(node.get());
}
}
}

Expand All @@ -185,6 +191,10 @@ void MultiDevSSAGraphBuilderBase::ApplyImpl(ir::Graph *graph) const {
result.Set(details::kGraphDepVars, new details::GraphDepVars);
result.Set(kGraphOps, new GraphOps);

for (auto *var_node : isolated_vars) {
CreateIsolatedVarNode(&result, var_node);
}

bool is_forwarding = true;

for (ir::Node *node : sorted_ops) {
Expand Down Expand Up @@ -582,6 +592,15 @@ bool MultiDevSSAGraphBuilderBase::IsSparseGradient(
return all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS;
}

void MultiDevSSAGraphBuilderBase::CreateIsolatedVarNode(
ir::Graph *graph, ir::Node *var_node) const {
for (size_t i = 0; i < places_.size(); ++i) {
VLOG(10) << "Create isolated var node " << var_node->Name() << " at device "
<< i;
CreateOrGetLatestVarHandle(graph, var_node, places_[i], i);
}
}

void AllReduceSSAGraphBuilder::InsertCollectiveOp(
ir::Graph *result, const std::string &p_name,
const std::string &g_name) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
void CreateOpHandleIOs(ir::Graph *result, ir::Node *node,
size_t device_id) const;

void CreateIsolatedVarNode(ir::Graph *result, ir::Node *var_node) const;

#if defined(PADDLE_WITH_NCCL)
mutable platform::NCCLContextMap *nccl_ctxs_{nullptr};
mutable platform::NCCLCommunicator *multi_nccl_ctxs_{nullptr};
Expand Down
1 change: 1 addition & 0 deletions python/paddle/fluid/tests/unittests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ set_tests_properties(test_parallel_executor_test_while_train test_parallel_execu
test_data_norm_op test_imperative_using_non_zero_gpu test_fuse_bn_act_pass
test_optimizer_in_control_flow test_dataloader_keep_order
test_dataloader_unkeep_order
test_parallel_executor_fetch_isolated_var
test_parallel_executor_inference_feed_partial_data
test_parallel_ssa_graph_inference_feed_partial_data
test_fetch_unmerged
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
import six
import paddle.fluid as fluid


def enable_parallel_ssa_executor(enabled=True):
if fluid.is_compiled_with_cuda():
fluid.core.globals()['FLAGS_enable_parallel_graph'] = enabled


class TestParallelExecutorFetchIsolatedVarBase(unittest.TestCase):
def build_network(self, is_training):
x = fluid.data(name='x', shape=[-1, 10], dtype='float32')
y = fluid.data(name='y', shape=[-1, 10], dtype='float32')
fc = fluid.layers.fc(x, size=30)
loss = fluid.layers.reduce_mean(fc)
if is_training:
adam = fluid.optimizer.Adam(learning_rate=1e-3)
adam.minimize(loss)

return loss, y

def exec_strategy(self, use_experimental_executor):
strategy = fluid.ExecutionStrategy()
strategy.use_experimental_executor = use_experimental_executor
return strategy

def places(self, use_gpu, dev_cnt):
if use_gpu:
return fluid.cuda_places(list(range(dev_cnt)))
else:
return fluid.cpu_places(dev_cnt)

def test_main(self):
for use_gpu in [False, True]:
for dev_cnt in [1, 2]:
for is_training in [False, True]:
for use_experimental_executor in [False, True]:
for use_parallel_ssa_executor in [False, True]:
func = lambda: self.run_impl(use_gpu, dev_cnt, is_training, use_experimental_executor, use_parallel_ssa_executor)
self.run_func_with_guard(func)

def run_impl(self, use_gpu, dev_cnt, is_training, use_experimental_executor,
use_parallel_ssa_executor):
enable_parallel_ssa_executor(use_parallel_ssa_executor)

if fluid.is_compiled_with_cuda():
if fluid.core.globals()[
'FLAGS_enable_parallel_graph'] and not use_gpu:
return
else:
if use_gpu:
return

loss, isolated_var = self.build_network(is_training)
loss_name = loss.name if is_training else None

places = self.places(use_gpu, dev_cnt)
exe = fluid.Executor(places[0])

exe.run(fluid.default_startup_program())

prog = fluid.CompiledProgram(fluid.default_main_program(
)).with_data_parallel(
loss_name=loss_name,
exec_strategy=self.exec_strategy(use_experimental_executor),
places=places)

BATCH_SIZE = 8 * dev_cnt
for _ in six.moves.range(10):
x_np = np.random.random(size=[BATCH_SIZE, 10]).astype('float32')
y_np = np.random.random(size=[BATCH_SIZE, 10]).astype('float32')

_, y_np_fetch = exe.run(prog,
feed={'x': x_np,
'y': y_np},
fetch_list=[loss, isolated_var])

self.assertTrue(np.array_equal(y_np, y_np_fetch))

enable_parallel_ssa_executor(False)

def run_func_with_guard(self, func):
with fluid.program_guard(fluid.Program(), fluid.Program()):
with fluid.unique_name.guard():
with fluid.scope_guard(fluid.Scope()):
func()


if __name__ == '__main__':
unittest.main()

0 comments on commit 9eef667

Please sign in to comment.