Skip to content

Commit

Permalink
[XPU] delete op device (#51029)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhupengyang authored Mar 1, 2023
1 parent af149c0 commit c930994
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 8 deletions.
16 changes: 12 additions & 4 deletions paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ pass_library(shuffle_channel_detect_pass inference)
pass_library(delete_quant_dequant_op_pass inference)
pass_library(delete_quant_dequant_filter_op_pass inference)
pass_library(trt_delete_weight_dequant_linear_op_pass inference)
pass_library(delete_op_device_pass inference)
pass_library(delete_weight_dequant_linear_op_pass inference)
pass_library(delete_quant_dequant_linear_op_pass inference)
pass_library(delete_dropout_op_pass inference)
Expand Down Expand Up @@ -221,13 +222,16 @@ if(WITH_XPU)
SRCS xpu/pass_utils.cc
DEPS pass)
set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils)
pass_library(embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu)
pass_library(embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(fc_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(multi_encoder_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(multi_encoder_xpu_slice_fuse_pass inference DIR xpu)
pass_library(generate_sequence_xpu_fuse_pass inference DIR xpu)
pass_library(link_xpu_op_max_pass inference DIR xpu)
pass_library(multi_encoder_xpu_slice_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(generate_sequence_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(link_xpu_op_max_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
endif()

cc_library(
Expand Down Expand Up @@ -372,6 +376,10 @@ cc_test(
test_generate_pass_cc
SRCS generate_pass_tester.cc
DEPS generate_pass pass_desc_proto)
cc_test(
test_delete_op_device_pass
SRCS delete_op_device_pass_test.cc
DEPS delete_op_device_pass)
cc_test(
test_delete_dropout_pass_cc
SRCS delete_dropout_op_pass_test.cc
Expand Down
57 changes: 57 additions & 0 deletions paddle/fluid/framework/ir/delete_op_device_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>
#include "paddle/fluid/framework/ir/pass.h"

namespace phi {
class DenseTensor;
} // namespace phi

namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle

namespace paddle {
namespace framework {
namespace ir {

// Removes the "op_device" attr from every op in the graph. "op_device" is
// only meaningful during model training: it pins an op's kernel to a specific
// place, which would override the intended kernel place at inference time, so
// "delete_op_device_pass" strips it before inference runs.
class DeleteOpDevicePass : public Pass {
 protected:
  // Walks all op nodes of `graph` and deletes their "op_device" attr, if any.
  void ApplyImpl(ir::Graph* graph) const override;
};

// Deletes the "op_device" attr from every op node in `graph`.
// Note: this pass removes attrs from individual ops; it does not detect
// subgraphs, so the counter and log message refer to ops, not subgraphs.
void DeleteOpDevicePass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  // Count only ops that actually carried the attr, so the log is meaningful.
  int delete_op_count = 0;
  for (auto* node : graph->Nodes()) {
    // Var nodes have no attrs; skip them and ops without "op_device".
    if (!node->IsOp() || !node->Op()->HasAttr("op_device")) continue;
    node->Op()->RemoveAttr("op_device");
    delete_op_count++;
  }
  if (delete_op_count > 0) {
    LOG(INFO) << "--- deleted 'op_device' attr from " << delete_op_count
              << " ops";
  }
}

} // namespace ir
} // namespace framework
} // namespace paddle

REGISTER_PASS(delete_op_device_pass, paddle::framework::ir::DeleteOpDevicePass);
52 changes: 52 additions & 0 deletions paddle/fluid/framework/ir/delete_op_device_pass_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>

#include "paddle/fluid/framework/ir/delete_dropout_op_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"

namespace paddle {
namespace framework {
namespace ir {

TEST(delete_op_device_pass, relu) {
  // Build a minimal program: one relu op carrying an "op_device" attr.
  ProgramDesc program;
  auto* block = program.MutableBlock(0);
  auto* in_var = block->Var("relu_x");
  auto* result_var = block->Var("relu_out");
  OpDesc* op = block->AppendOp();
  op->SetType("relu");
  op->SetInput("X", {in_var->Name()});
  op->SetOutput("Out", {result_var->Name()});
  op->SetAttr("op_device", std::string{"gpu:0"});

  // Run the pass, then check that the attr has been stripped from relu.
  std::unique_ptr<Graph> graph(new Graph(program));
  auto pass = PassRegistry::Instance().Get("delete_op_device_pass");
  graph.reset(pass->Apply(graph.release()));
  for (auto* node : graph->Nodes()) {
    if (!node->IsOp()) continue;
    if (node->Op()->Type() == "relu") {
      PADDLE_ENFORCE(!node->Op()->HasAttr("op_device"),
                     platform::errors::InvalidArgument(
                         "Run delete_op_device_pass failed. Relu op still has "
                         "'op_device' attr."));
    }
  }
}

} // namespace ir
} // namespace framework
} // namespace paddle

USE_PASS(delete_op_device_pass);
2 changes: 1 addition & 1 deletion paddle/fluid/framework/ir/pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ static const std::vector<std::string> support_subgraph_passes = {
"fuse_multi_transformer_layer_pass",
"delete_quant_dequant_linear_op_pass",
"delete_weight_dequant_linear_op_pass",
};
"delete_op_device_pass"};

Graph *Pass::Apply(Graph *graph) const {
VLOG(10) << "start to apply pass " << Type() << " to graph";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"

Expand Down
1 change: 0 additions & 1 deletion paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"

Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"multi_encoder_xpu_slice_fuse_pass",
"fc_xpu_fuse_pass",
"link_xpu_op_max_pass",
"delete_op_device_pass",
});
use_xpu_ = true;
}
Expand Down

0 comments on commit c930994

Please sign in to comment.