Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[xpu]add sine_pos fuse pass and sine_pos xpu kernel #60025

Merged
merged 1 commit into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/external/xpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ set(XPU_XBLAS_LIB_NAME "libxpu_blas.so")
set(XPU_XFA_LIB_NAME "libxpu_flash_attention.so")

if(NOT DEFINED XPU_BASE_DATE)
set(XPU_BASE_DATE "20231203")
set(XPU_BASE_DATE "20231218")
endif()
if(NOT DEFINED XPU_XHPC_BASE_DATE)
set(XPU_XHPC_BASE_DATE "20231229")
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ if(WITH_XPU)
${XPU_PASS_DEPS})
pass_library(elementwise_mul_add_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(sine_pos_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
endif()

cc_library(
Expand Down
286 changes: 286 additions & 0 deletions paddle/fluid/framework/ir/xpu/sine_pos_fuse_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,286 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <map>
#include <string>

#include "glog/logging.h"

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/quantize_helper.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"

namespace phi {
class DenseTensor;
} // namespace phi

namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle

namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
/*
fuse block in vis model to sine_pos_xpu op
------------------------------------------------------
sub block:
        x          y
         \        /
          \      /
           \    /
            mul
           /   \
          /     \
         /       \
     slice      slice
       |          |
       |          |
      sin        cos
        \        /
         \      /
          \    /
           stack
             |
             |
          flatten
             |
            out
------------------------------------------------------
After the pass is applied:
        x          y
         \        /
          \      /
           \    /
        sine_pos_xpu
             |
             |
            out
*/

struct SinePosXPUPattern : public PatternBase {
SinePosXPUPattern(PDPattern* pattern, const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(ew_mul);
PATTERN_DECL_NODE(slice1);
PATTERN_DECL_NODE(slice2);
PATTERN_DECL_NODE(sin);
PATTERN_DECL_NODE(cos);
PATTERN_DECL_NODE(stack);
PATTERN_DECL_NODE(flatten);
// declare variable node's name
PATTERN_DECL_NODE(x);
PATTERN_DECL_NODE(y);
PATTERN_DECL_NODE(ew_mul_out);
PATTERN_DECL_NODE(slice1_out);
PATTERN_DECL_NODE(slice2_out);
PATTERN_DECL_NODE(sin_out);
PATTERN_DECL_NODE(cos_out);
PATTERN_DECL_NODE(stack_out);
PATTERN_DECL_NODE(flatten_out);
};

// Creates every node of the pattern and wires the edges between them.
//
// Constraints enforced on a match:
//   * x is a rank-3 variable whose last dimension is 1,
//   * y is a rank-1 variable whose length is even,
//   * elementwise_mul has axis == -1,
//   * both strided_slice ops use axes == {2} and strides == {2}; the one
//     feeding sin has starts == {0}, the one feeding cos has starts == {1},
//   * sin_out is the 0th and cos_out the 1st "X" input of stack (axis == 3),
//   * flatten_contiguous_range has start_axis == 2 and stop_axis == 3.
SinePosXPUPattern::SinePosXPUPattern(PDPattern* pattern,
                                     const std::string& name_scope)
    : PatternBase(pattern, name_scope, name_scope) {
  // Builds a predicate for a strided_slice op that walks axis 2 with stride 2
  // beginning at `first`. `first` is captured by value because the predicate
  // is stored in the pattern node.
  const auto slices_axis2_from = [](int first) {
    return [first](Node* node) {
      auto* desc = node->Op();
      if (desc->GetAttrIfExists<std::vector<int>>("axes") !=
          std::vector<int>{2}) {
        return false;
      }
      if (desc->GetAttrIfExists<std::vector<int>>("starts") !=
          std::vector<int>{first}) {
        return false;
      }
      return desc->GetAttrIfExists<std::vector<int>>("strides") ==
             std::vector<int>{2};
    };
  };

  // ---- inputs -------------------------------------------------------------
  auto* x_node = pattern->NewNode(x_repr())
                     ->assert_is_op_input("elementwise_mul", "X")
                     ->assert_more([](Node* node) {
                       const auto dims = node->Var()->GetShape();
                       if (dims.size() != 3) {
                         return false;
                       }
                       return dims.back() == 1;
                     });
  auto* y_node = pattern->NewNode(y_repr())
                     ->assert_is_op_input("elementwise_mul", "Y")
                     ->assert_more([](Node* node) {
                       const auto dims = node->Var()->GetShape();
                       if (dims.size() != 1) {
                         return false;
                       }
                       return dims[0] % 2 == 0;
                     });

  // ---- elementwise_mul ----------------------------------------------------
  auto* mul_op =
      pattern->NewNode(ew_mul_repr())
          ->assert_is_op("elementwise_mul")
          ->assert_more([](Node* node) {
            return node->Op()->GetAttrIfExists<int>("axis") == -1;
          });
  auto* mul_result = pattern->NewNode(ew_mul_out_repr())
                         ->assert_is_op_output("elementwise_mul", "Out")
                         ->assert_is_op_input("strided_slice", "Input");
  mul_op->LinksFrom({x_node, y_node}).LinksTo({mul_result});

  // ---- sin branch: strided_slice(starts = {0}) -> sin ---------------------
  auto* sin_slice_op = pattern->NewNode(slice1_repr())
                           ->assert_is_op("strided_slice")
                           ->assert_more(slices_axis2_from(0));
  auto* sin_slice_result = pattern->NewNode(slice1_out_repr())
                               ->assert_is_op_output("strided_slice", "Out")
                               ->assert_is_op_input("sin", "X");
  sin_slice_op->LinksFrom({mul_result}).LinksTo({sin_slice_result});

  auto* sin_op = pattern->NewNode(sin_repr())->assert_is_op("sin");
  auto* sin_result = pattern->NewNode(sin_out_repr())
                         ->assert_is_op_output("sin", "Out")
                         ->assert_is_op_nth_input("stack", "X", 0);
  sin_op->LinksFrom({sin_slice_result}).LinksTo({sin_result});

  // ---- cos branch: strided_slice(starts = {1}) -> cos ---------------------
  auto* cos_slice_op = pattern->NewNode(slice2_repr())
                           ->assert_is_op("strided_slice")
                           ->assert_more(slices_axis2_from(1));
  auto* cos_slice_result = pattern->NewNode(slice2_out_repr())
                               ->assert_is_op_output("strided_slice", "Out")
                               ->assert_is_op_input("cos", "X");
  cos_slice_op->LinksFrom({mul_result}).LinksTo({cos_slice_result});

  auto* cos_op = pattern->NewNode(cos_repr())->assert_is_op("cos");
  auto* cos_result = pattern->NewNode(cos_out_repr())
                         ->assert_is_op_output("cos", "Out")
                         ->assert_is_op_nth_input("stack", "X", 1);
  cos_op->LinksFrom({cos_slice_result}).LinksTo({cos_result});

  // ---- stack --------------------------------------------------------------
  auto* stack_op =
      pattern->NewNode(stack_repr())
          ->assert_is_op("stack")
          ->assert_more([](Node* node) {
            return node->Op()->GetAttrIfExists<int>("axis") == 3;
          });
  auto* stack_result =
      pattern->NewNode(stack_out_repr())
          ->assert_is_op_output("stack", "Y")
          ->assert_is_op_input("flatten_contiguous_range", "X");
  stack_op->LinksFrom({sin_result, cos_result}).LinksTo({stack_result});

  // ---- flatten ------------------------------------------------------------
  auto* flatten_op =
      pattern->NewNode(flatten_repr())
          ->assert_is_op("flatten_contiguous_range")
          ->assert_more([](Node* node) {
            auto* desc = node->Op();
            if (desc->GetAttrIfExists<int>("start_axis") != 2) {
              return false;
            }
            return desc->GetAttrIfExists<int>("stop_axis") == 3;
          });
  auto* flatten_result =
      pattern->NewNode(flatten_out_repr())
          ->assert_is_op_output("flatten_contiguous_range", "Out")
          ->AsOutput();
  flatten_op->LinksFrom({stack_result}).LinksTo({flatten_result});
}

} // namespace patterns

// Fuse pass that replaces every subgraph matched by
// patterns::SinePosXPUPattern with a single sine_pos_xpu op.
class SinePosFusePass : public FusePassBase {
 protected:
  // Runs pattern detection on `graph` and rewrites each match in place.
  void ApplyImpl(ir::Graph* graph) const override;

 private:
  // Name scope handed to Init() and to the pattern.
  const std::string name_scope_ = "sine_pos_fuse_pass";
};

// Detects every SinePosXPUPattern match in `graph` and replaces it with one
// sine_pos_xpu op that reads x and y and writes the original flatten output.
// The number of rewritten subgraphs is reported through AddStatis().
void SinePosFusePass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  Init(name_scope_, graph);

  GraphPatternDetector gpd;
  patterns::SinePosXPUPattern pattern(gpd.mutable_pattern(), name_scope_);
  int fused_count = 0;

  // Invoked once per match; `subgraph` maps pattern nodes to graph nodes.
  auto fuse_match = [&](const GraphPatternDetector::subgraph_t& subgraph,
                        Graph* matched_graph) {
    VLOG(4) << "handle SinePosFusePass fuse";

    // Operator nodes of the match.
    GET_IR_NODE(ew_mul);
    GET_IR_NODE(slice1);
    GET_IR_NODE(slice2);
    GET_IR_NODE(sin);
    GET_IR_NODE(cos);
    GET_IR_NODE(stack);
    GET_IR_NODE(flatten);
    // Variable nodes of the match.
    GET_IR_NODE(x);
    GET_IR_NODE(y);
    GET_IR_NODE(ew_mul_out);
    GET_IR_NODE(slice1_out);
    GET_IR_NODE(slice2_out);
    GET_IR_NODE(sin_out);
    GET_IR_NODE(cos_out);
    GET_IR_NODE(stack_out);
    GET_IR_NODE(flatten_out);

    auto* block = flatten->Op()->Block();
    // The scope is not used below; it is only required to be non-null.
    auto* scope = param_scope();
    PADDLE_ENFORCE_NOT_NULL(
        scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));

    // Describe the fused op: two inputs, one output, no attributes.
    framework::OpDesc desc(block);
    desc.SetType("sine_pos_xpu");
    desc.SetInput("x", {x->Name()});
    desc.SetInput("y", {y->Name()});
    desc.SetOutput("out", {flatten_out->Name()});

    // Insert the fused op between the pattern inputs and the final output.
    auto* fused_node = matched_graph->CreateOpNode(&desc);
    IR_NODE_LINK_TO(x, fused_node);
    IR_NODE_LINK_TO(y, fused_node);
    IR_NODE_LINK_TO(fused_node, flatten_out);

    // Everything strictly between the inputs and flatten_out is now dead.
    const std::unordered_set<const Node*> obsolete{ew_mul,
                                                   ew_mul_out,
                                                   slice1,
                                                   slice1_out,
                                                   slice2,
                                                   slice2_out,
                                                   sin,
                                                   sin_out,
                                                   cos,
                                                   cos_out,
                                                   stack,
                                                   stack_out,
                                                   flatten};
    GraphSafeRemoveNodes(matched_graph, obsolete);
    ++fused_count;
  };

  gpd(graph, fuse_match);

  AddStatis(fused_count);
}

} // namespace ir
} // namespace framework
} // namespace paddle

// Register the pass under the name used in the XPU pass strategy list.
REGISTER_PASS(sine_pos_fuse_pass, paddle::framework::ir::SinePosFusePass);

// Declare the op version this pass is compatible with. The op name must be
// the fused op the pass emits ("sine_pos_xpu"); the previous spelling
// "sin_pos_xpu" did not correspond to any op created by this pass.
REGISTER_PASS_CAPABILITY(sine_pos_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
            "sine_pos_xpu", 0));
1 change: 1 addition & 0 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"yolo_box_xpu_fuse_pass",
"fast_where_xpu_fuse_pass",
"elementwise_mul_add_fuse_pass",
"sine_pos_fuse_pass",
// "auto_mixed_precision_pass",
"cast_mixed_precision_op_fuse_pass",
"xpu_quantize_op_pass",
Expand Down
9 changes: 9 additions & 0 deletions paddle/phi/api/yaml/fused_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,15 @@
func : self_dp_attention
data_type : x

- op : sine_pos_xpu
args : (Tensor x, Tensor y)
output : Tensor(out)
infer_meta :
func : SinePosXPUInferMeta
kernel :
func : sine_pos_xpu
data_type : x

- op : skip_layernorm
args : (Tensor x, Tensor y, Tensor scale, Tensor bias, float epsilon, int begin_norm_axis)
output : Tensor(out)
Expand Down
7 changes: 5 additions & 2 deletions paddle/phi/backends/xpu/xpu2_op_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"pool3d",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"pow", XPUKernelSet({phi::DataType::FLOAT32})},
{"pow", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"pow_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"pow2_decay_with_linear_warmup", XPUKernelSet({phi::DataType::FLOAT32})},
{"prior_box", XPUKernelSet({phi::DataType::FLOAT32})},
Expand All @@ -707,7 +707,8 @@ XPUOpMap& get_kl2_ops() {
{"reduce_max",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
phi::DataType::INT64})},
phi::DataType::INT64,
phi::DataType::FLOAT16})},
{"reduce_mean_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_mean",
XPUKernelSet({phi::DataType::FLOAT32,
Expand Down Expand Up @@ -1171,6 +1172,8 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::FLOAT16,
phi::DataType::FLOAT32,
phi::DataType::INT32})},
{"sine_pos_xpu",
XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
};

return s_xpu2_kernels;
Expand Down
31 changes: 31 additions & 0 deletions paddle/phi/infermeta/fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3687,4 +3687,35 @@ void QKVAttentionXPUInferMeta(const MetaTensor& q,
qkv_max->set_dtype(out_dtype);
qkv_max->set_layout(q.layout());
}
// Infers the output metadata of the sine_pos_xpu op.
//
// x   : must be rank 3 with a last dimension of 1.
// y   : must be rank 1.
// out : receives dims [x.dims[0], x.dims[1], y.dims[0]] and the dtype of x.
//
// Raises InvalidArgument when any of the shape requirements is violated.
void SinePosXPUInferMeta(const MetaTensor& x,
                         const MetaTensor& y,
                         MetaTensor* out) {
  auto x_dims = x.dims();
  auto x_dims_size = x_dims.size();
  PADDLE_ENFORCE_EQ(
      x_dims_size,
      3,
      phi::errors::InvalidArgument(
          "x_dims_size should be 3, but received x_dims_size is %d",
          x_dims_size));
  PADDLE_ENFORCE_EQ(x_dims[x_dims_size - 1],
                    1,
                    phi::errors::InvalidArgument(
                        "x last dim size should be 1, but received is %d",
                        x_dims[x_dims_size - 1]));
  auto y_dims = y.dims();
  auto y_dims_size = y_dims.size();
  // The message must describe y and its expected rank (1); it previously
  // repeated the text of the x rank check.
  PADDLE_ENFORCE_EQ(
      y_dims_size,
      1,
      phi::errors::InvalidArgument(
          "y_dims_size should be 1, but received y_dims_size is %d",
          y_dims_size));

  phi::DDim out_dim = phi::make_ddim({x_dims[0], x_dims[1], y_dims[0]});

  out->set_dims(out_dim);
  out->set_dtype(x.dtype());
}

} // namespace phi
Loading