Skip to content

Commit

Permalink
add IpuInplacePass (PaddlePaddle#112)
Browse files Browse the repository at this point in the history
* code format

* add IpuInplacePass
  • Loading branch information
gglin001 authored Aug 31, 2021
1 parent 7cb619a commit 8effd41
Show file tree
Hide file tree
Showing 19 changed files with 227 additions and 95 deletions.
4 changes: 3 additions & 1 deletion paddle/fluid/framework/ipu/ipu_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,9 @@ void IpuBackend::Prepare() {

std::vector<int64_t> IpuBackend::GetTensorShape(const std::string& var_name) {
auto oshape = compiler_->GetTensorShape(var_name);
oshape.insert(oshape.begin(), ipu_strategy_->batches_per_step);
if (ipu_strategy_->batches_per_step != 1) {
oshape.insert(oshape.begin(), ipu_strategy_->batches_per_step);
}
return oshape;
}

Expand Down
14 changes: 7 additions & 7 deletions paddle/fluid/framework/ipu/ipu_strategy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ipu/ipu_strategy.h"
#include "paddle/fluid/framework/ir/ipu/ipu_graph_builder_pass.h"

#include <glog/logging.h>
#include "paddle/fluid/framework/ir/graph_printer.h"

namespace paddle {
namespace framework {
namespace ipu {}
}
}
namespace ipu {

//

} // namespace ipu
} // namespace framework
} // namespace paddle
15 changes: 0 additions & 15 deletions paddle/fluid/framework/ipu/ipu_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,10 @@ limitations under the License. */

#pragma once

#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include <popart/sessionoptions.hpp>

#include "boost/optional.hpp"
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {

namespace ipu {

using VirtualGraphMode = popart::VirtualGraphMode;
Expand All @@ -44,6 +30,5 @@ struct IpuStrategy {
};

} // namespace ipu

} // namespace framework
} // namespace paddle
1 change: 1 addition & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ if(WITH_IPU)
pass_library(inference_extract_pass base DIR ipu DEPS ipu_pass_base)
pass_library(popart_canonicalization_pass base DIR ipu DEPS ipu_pass_base)
target_link_libraries(popart_canonicalization_pass -Wl,--whole-archive popart_canonicalization_utils -Wl,--no-whole-archive)
pass_library(ipu_inplace_pass base DIR ipu DEPS ipu_pass_base)
endif()

cc_library(fuse_bn_act_pass SRCS fuse_bn_act_pass.cc DEPS pass graph_pattern_detector )
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/framework/ir/ipu/forward_graph_extract_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"

Expand All @@ -26,4 +28,4 @@ class ForwardGraphExtractPass : public IPUPassBase {

} // namespace ir
} // namespace framework
} // namespace paddle
} // namespace paddle
10 changes: 2 additions & 8 deletions paddle/fluid/framework/ir/ipu/inference_extract_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,9 @@

#include "paddle/fluid/framework/ir/ipu/inference_extract_pass.h"

#include <string>

#include "paddle/fluid/framework/ipu/ipu_backend.h"
#include "paddle/fluid/framework/ipu/ipu_strategy.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/enforce.h"

// debug
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"

namespace paddle {
Expand Down Expand Up @@ -73,4 +67,4 @@ void InferenceExtractPass::ApplyImpl(ir::Graph* graph) const {
} // namespace paddle

REGISTER_PASS(inference_extract_pass,
paddle::framework::ir::InferenceExtractPass);
paddle::framework::ir::InferenceExtractPass);
7 changes: 3 additions & 4 deletions paddle/fluid/framework/ir/ipu/inference_extract_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/graph.h"
#pragma once

#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"

namespace paddle {
namespace framework {
Expand All @@ -23,9 +23,8 @@ namespace ir {
class InferenceExtractPass : public IPUPassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;

};

} // namespace ir
} // namespace framework
} // namespace paddle
} // namespace paddle
20 changes: 0 additions & 20 deletions paddle/fluid/framework/ir/ipu/ipu_graph_builder_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,7 @@

#include "paddle/fluid/framework/ir/ipu/ipu_graph_builder_pass.h"

#include <algorithm>
#include <array>
#include <fstream>
#include <iosfwd>
#include <memory>
#include <ostream>
#include <string>
#include <unordered_map>
#include <unordered_set>

#include "paddle/fluid/framework/ipu/ipu_backend.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"

// debug
#include "paddle/fluid/framework/ir/pass_tester_helper.h"

namespace paddle {
Expand Down Expand Up @@ -67,5 +49,3 @@ REGISTER_PASS(ipu_graph_builder_pass,
paddle::framework::ir::IpuGraphBuilderPass)
.RequirePassAttr("feed_list")
.RequirePassAttr("fetch_list");

USE_PASS(graph_viz_pass);
5 changes: 3 additions & 2 deletions paddle/fluid/framework/ir/ipu/ipu_graph_builder_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"

namespace paddle {
namespace framework {
namespace ir {


class IpuGraphBuilderPass : public IPUPassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
};

} // namespace ir
} // namespace framework
} // namespace paddle
} // namespace paddle
82 changes: 82 additions & 0 deletions paddle/fluid/framework/ir/ipu/ipu_inplace_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/ipu/ipu_inplace_pass.h"

#include "paddle/fluid/framework/ir/pass_tester_helper.h"

namespace paddle {
namespace framework {
namespace ir {

Node *GetInputVarNode(const std::string &var_name, const ir::Node *node) {
PADDLE_ENFORCE_EQ(node->IsOp(), true,
platform::errors::InvalidArgument("node is not Op"));
for (auto *n : node->inputs) {
if (n->Name() == var_name) {
return n;
}
}
return nullptr;
}

void IpuInplacePass::ApplyImpl(ir::Graph *graph) const {
// use this pass after forward_graph_extract_pass
VLOG(10) << "enter IpuInplacePass::ApplyImpl";
VLOG(10) << "Raw Graph: ";
VLOG(10) << DebugString(graph);

for (auto *node : graph->Nodes()) {
if (!node->IsOp()) {
continue;
}

RenameInplaceVar(node);
}

VLOG(10) << "Post Graph: ";
VLOG(10) << DebugString(graph);
VLOG(10) << "leave IpuInplacePass::ApplyImpl";
}

void IpuInplacePass::RenameInplaceVar(ir::Node *node) const {
// rename input_var, only support one input_var rename
auto *op = node->Op();
for (auto name : op->Input("__inputs__")) {
for (auto name_out : op->Output("__outputs__")) {
if (name == name_out) {
auto new_name = name + "_1";
VLOG(10) << "replace op node: " << node->Name()
<< " input var: " << name << " to " << new_name;
auto var = GetInputVarNode(name, node);
if (var) {
var->RenameVar(new_name);
for (auto *op_in : var->inputs) {
op_in->Op()->RenameOutput(name, new_name);
}
for (auto *op_out : var->outputs) {
op_out->Op()->RenameInput(name, new_name);
}
return;
}
}
}
}
}

} // namespace ir
} // namespace framework
} // namespace paddle

REGISTER_PASS(ipu_inplace_pass, paddle::framework::ir::IpuInplacePass);
33 changes: 33 additions & 0 deletions paddle/fluid/framework/ir/ipu/ipu_inplace_pass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"

namespace paddle {
namespace framework {
namespace ir {

class IpuInplacePass : public IPUPassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;

private:
void RenameInplaceVar(ir::Node* node) const;
};

} // namespace ir
} // namespace framework
} // namespace paddle
8 changes: 0 additions & 8 deletions paddle/fluid/framework/ir/ipu/ipu_pass_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,10 @@

#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"

namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle

namespace paddle {
namespace framework {
namespace ir {

class Graph;

void IPUPassBase::Init(const std::string& repr, Graph* graph) const {
repr_ = repr;
graph_ = graph;
Expand Down
11 changes: 0 additions & 11 deletions paddle/fluid/framework/ir/ipu/ipu_pass_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,14 @@

#pragma once

#include <string>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/scope.h"

namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle

namespace paddle {
namespace framework {
namespace ir {

class Graph;
class Node;

class IPUPassBase : public Pass {
public:
void Init(const std::string& repr, Graph* graph) const;
Expand Down
12 changes: 1 addition & 11 deletions paddle/fluid/framework/ir/ipu/ipu_runtime_replacer_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,6 @@

#include "paddle/fluid/framework/ir/ipu/ipu_runtime_replacer_pass.h"

#include <algorithm>
#include <array>
#include <fstream>
#include <iosfwd>
#include <memory>
#include <ostream>
#include <string>
#include <unordered_map>
#include <unordered_set>

#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"

Expand Down Expand Up @@ -124,4 +114,4 @@ REGISTER_PASS(ipu_runtime_replacer_pass,
.RequirePassAttr("feed_list")
.RequirePassAttr("fetch_list");

USE_PASS(graph_viz_pass);
// USE_PASS(graph_viz_pass);
4 changes: 3 additions & 1 deletion paddle/fluid/framework/ir/ipu/ipu_runtime_replacer_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"

Expand All @@ -26,4 +28,4 @@ class IpuRuntimeReplacerPass : public IPUPassBase {

} // namespace ir
} // namespace framework
} // namespace paddle
} // namespace paddle
Loading

0 comments on commit 8effd41

Please sign in to comment.