Skip to content

Commit

Permalink
[CINN] Add more rules to cinn_to_pd_op pass (PaddlePaddle#64354)
Browse files Browse the repository at this point in the history
* add more rules to cinn_to_pd_op pass

* update
  • Loading branch information
chen2016013 authored and co63oc committed May 19, 2024
1 parent d101508 commit 7b05425
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 18 deletions.
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : uniform_random
args : (int64_t[] shape, float min, float max, int seed, DataType dtype, int diag_num = 0, int diag_step=0, float diag_val=1.0)
args : (int64_t[] shape, float min, float max, int seed, DataType dtype, int diag_num = 0, int diag_step=0, float diag_val=1.0, Place place={})
output : Tensor(out)
infer_meta :
func : CreateVecShapeInferMeta
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ class AddAccuracyCheckPattern
builder.set_insertion_point(fusion_op);

const auto& InsertAccuaryCheckOp = [&](::pir::Operation* op) -> void {
rewriter.SetInsertionPointAfter(fusion_op);
for (size_t i = 0; i < op->num_operands(); ++i) {
rewriter.Build<paddle::dialect::AccuracyCheckOp>(
fusion_op.result(i),
Expand All @@ -67,6 +66,7 @@ class AddAccuracyCheckPattern
i);
}
};

const auto& ConvertCinnOpToPdOp = [&](::pir::Operation* op) -> void {
rewriter.SetInsertionPointAfter(fusion_op);
for (size_t i = 0; i < op->num_operands(); ++i) {
Expand All @@ -86,6 +86,7 @@ class AddAccuracyCheckPattern
}
auto new_op = op->Clone(ir_mapping, clone_options);
rewriter.Insert(new_op);
rewriter.SetInsertionPointAfter(new_op);
};

for (auto& op : op_list) {
Expand All @@ -103,7 +104,7 @@ class AddAccuracyCheckPattern

class AccuarcyCheckPass : public pir::Pass {
public:
AccuarcyCheckPass() : pir::Pass("accuracy_check_pass", /*opt_level=*/4) {}
AccuarcyCheckPass() : pir::Pass("accuracy_check_pass", /*opt_level=*/3) {}

bool Initialize(pir::IrContext* context) override {
pir::RewritePatternSet ps(context);
Expand Down
193 changes: 180 additions & 13 deletions paddle/cinn/hlir/dialect/operator/transforms/cinn_to_pd_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/phi/common/place.h"
#include "paddle/pir/include/core/builtin_dialect.h"
#include "paddle/pir/include/core/ir_mapping.h"
namespace cinn::dialect::details {
Expand All @@ -37,8 +38,8 @@ pir::Attribute ArrayAttributeToIntArrayAttribute(
data.push_back(attr.dyn_cast<::pir::Int64Attribute>().data());
}
}
pir::Attribute attr_data = paddle::dialect::IntArrayAttribute::get(
pir::IrContext::Instance(), phi::IntArray(data));
::pir::Attribute attr_data = paddle::dialect::IntArrayAttribute::get(
::pir::IrContext::Instance(), phi::IntArray(data));
return attr_data;
}

Expand All @@ -49,7 +50,7 @@ const auto& handler_reduce_sum_op =
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
auto attrs = op->attributes();

pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute(
::pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute(
attrs.at("dim").dyn_cast<::pir::ArrayAttribute>());
attrs.insert({"axis", attr_axis});
attrs.insert({"dtype", attrs["dtype"]});
Expand All @@ -74,7 +75,7 @@ const auto& handler_reduce_max_op =

// TODO(chenxi67): 1. CINN op Dialect Normalization;2.AST Op compute
// Normalization
pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute(
::pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute(
attrs.at("dim").dyn_cast<::pir::ArrayAttribute>());
attrs.insert({"axis", attr_axis});
attrs.insert({"keepdim", attrs["keep_dim"]});
Expand All @@ -96,7 +97,7 @@ const auto& handler_reduce_min_op =
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
auto attrs = op->attributes();

pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute(
::pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute(
attrs.at("dim").dyn_cast<::pir::ArrayAttribute>());
attrs.insert({"axis", attr_axis});
attrs.insert({"keepdim", attrs["keep_dim"]});
Expand All @@ -118,7 +119,7 @@ const auto& handler_reduce_prod_op =
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
auto attrs = op->attributes();

pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute(
::pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute(
attrs.at("dim").dyn_cast<::pir::ArrayAttribute>());
attrs.insert({"dims", attr_axis});
attrs.erase("dim");
Expand All @@ -136,9 +137,9 @@ ::pir::Operation* ConvertSliceOp(::pir::Operation* op,
::pir::Builder& builder) { // NOLINT
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
auto attrs = op->attributes();
pir::Attribute starts = ArrayAttributeToIntArrayAttribute(
::pir::Attribute starts = ArrayAttributeToIntArrayAttribute(
attrs.at("starts").dyn_cast<::pir::ArrayAttribute>());
pir::Attribute ends = ArrayAttributeToIntArrayAttribute(
::pir::Attribute ends = ArrayAttributeToIntArrayAttribute(
attrs.at("ends").dyn_cast<::pir::ArrayAttribute>());
attrs["starts"] = starts;
attrs["ends"] = ends;
Expand All @@ -155,8 +156,7 @@ ::pir::Operation* ConvertReshapeOp(::pir::Operation* op,
::pir::Builder& builder) { // NOLINT
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
auto attrs = op->attributes();
attrs.at("shape").dyn_cast<::pir::ArrayAttribute>();
pir::Attribute shape = ArrayAttributeToIntArrayAttribute(
::pir::Attribute shape = ArrayAttributeToIntArrayAttribute(
attrs.at("shape").dyn_cast<::pir::ArrayAttribute>());
attrs["shape"] = shape;
auto pd_op = builder.Build<paddle::dialect::ReshapeOp>(
Expand All @@ -172,9 +172,6 @@ ::pir::Operation* ConvertConcatOp(::pir::Operation* op,
::pir::Builder& builder) { // NOLINT
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
auto attrs = op->attributes();
for (auto item : attrs) {
VLOG(0) << item.first;
}
std::vector<pir::Value> vec_inputs;
for (uint32_t i = 0; i < op->num_operands(); ++i) {
vec_inputs.push_back(ir_mapping.Lookup(op->operand_source(i)));
Expand All @@ -190,6 +187,144 @@ ::pir::Operation* ConvertConcatOp(::pir::Operation* op,
return pd_op;
}

::pir::Operation* ConvertScaleOp(::pir::Operation* op,
::pir::IrMapping& ir_mapping, // NOLINT
::pir::Builder& builder) { // NOLINT
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
auto attrs = op->attributes();

float scale = attrs.at("scale").dyn_cast<::pir::FloatAttribute>().data();
float bias = attrs.at("bias").dyn_cast<pir::FloatAttribute>().data();
bool bias_after_scale =
attrs.at("bias_after_scale").dyn_cast<pir::BoolAttribute>().data();
auto pd_op = builder.Build<paddle::dialect::ScaleOp>(
ir_mapping.Lookup(op->operand_source(0)), scale, bias, bias_after_scale);
for (uint32_t i = 0; i < op->num_results(); ++i) {
ir_mapping.Add(op->result(i), pd_op->result(i));
}
return pd_op;
}

::pir::Operation* ConvertFlipOp(::pir::Operation* op,
::pir::IrMapping& ir_mapping, // NOLINT
::pir::Builder& builder) { // NOLINT
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
auto attrs = op->attributes();
auto pd_op = builder.Build<paddle::dialect::FlipOp>(
ir_mapping.Lookup(op->operand_source(0)), attrs);
for (uint32_t i = 0; i < op->num_results(); ++i) {
ir_mapping.Add(op->result(i), pd_op->result(i));
}
return pd_op;
}

::pir::Operation* ConvertPool2dOp(::pir::Operation* op,
::pir::IrMapping& ir_mapping, // NOLINT
::pir::Builder& builder) { // NOLINT
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
auto attrs = op->attributes();
::pir::Attribute kernel_size = ArrayAttributeToIntArrayAttribute(
attrs.at("kernel_size").dyn_cast<::pir::ArrayAttribute>());
attrs["kernel_size"] = kernel_size;
attrs["strides"] = attrs.at("stride_size");
attrs["paddings"] = attrs.at("padding_size");
attrs.erase("stride_size");
attrs.erase("padding_size");
auto pd_op = builder.Build<paddle::dialect::Pool2dOp>(
ir_mapping.Lookup(op->operand_source(0)), attrs);
for (uint32_t i = 0; i < op->num_results(); ++i) {
ir_mapping.Add(op->result(i), pd_op->result(i));
}
return pd_op;
}

::pir::Operation* ConvertIscloseOp(::pir::Operation* op,
::pir::IrMapping& ir_mapping, // NOLINT
::pir::Builder& builder) { // NOLINT
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
auto attrs = op->attributes();
double rtol = attrs.at("atol").dyn_cast<pir::FloatAttribute>().data();
double atol = attrs.at("atol").dyn_cast<pir::FloatAttribute>().data();
bool equal_nan = attrs.at("equal_nan").dyn_cast<pir::BoolAttribute>().data();
auto pd_op = builder.Build<paddle::dialect::IscloseOp>(
ir_mapping.Lookup(op->operand_source(0)),
ir_mapping.Lookup(op->operand_source(1)),
rtol,
atol,
equal_nan);
for (uint32_t i = 0; i < op->num_results(); ++i) {
ir_mapping.Add(op->result(i), pd_op->result(i));
}
return pd_op;
}

::pir::Operation* ConvertExpandOp(::pir::Operation* op,
::pir::IrMapping& ir_mapping, // NOLINT
::pir::Builder& builder) { // NOLINT
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
auto attrs = op->attributes();

std::vector<int64_t> shape_;
auto attr_shape = attrs.at("out_shape").dyn_cast<::pir::ArrayAttribute>();
for (size_t i = 0; i < attr_shape.size(); ++i) {
shape_.push_back(attr_shape.at(i).dyn_cast<::pir::Int64Attribute>().data());
}

paddle::dialect::FullIntArrayOp full_shape_op =
builder.Build<paddle::dialect::FullIntArrayOp>(
shape_, phi::DataType::INT64, phi::CPUPlace());
::pir::Value out_shape = full_shape_op->result(0);
auto pd_op = builder.Build<paddle::dialect::ExpandOp>(
ir_mapping.Lookup(op->operand_source(0)), out_shape);
for (uint32_t i = 0; i < op->num_results(); ++i) {
ir_mapping.Add(op->result(i), pd_op->result(i));
}
return pd_op;
}

::pir::Operation* ConvertUniformOp(::pir::Operation* op,
::pir::IrMapping& ir_mapping, // NOLINT
::pir::Builder& builder) { // NOLINT
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
auto attrs = op->attributes();
std::vector<int64_t> shape;
auto attr_shape = attrs.at("out_shape").dyn_cast<::pir::ArrayAttribute>();
for (size_t i = 0; i < attr_shape.size(); ++i) {
shape.push_back(attr_shape.at(i).dyn_cast<::pir::Int64Attribute>().data());
}
::phi::DataType dtype =
attrs.at("dtype").dyn_cast<paddle::dialect::DataTypeAttribute>().data();

float min = attrs.at("min").dyn_cast<pir::FloatAttribute>().data();
float max = attrs.at("max").dyn_cast<pir::FloatAttribute>().data();
float seed = attrs.at("diag_num").dyn_cast<pir::FloatAttribute>().data();
::phi::Place place =
attrs.at("place").dyn_cast<paddle::dialect::PlaceAttribute>().data();

auto pd_op = builder.Build<paddle::dialect::UniformOp>(
shape, dtype, min, max, seed, place);
for (uint32_t i = 0; i < op->num_results(); ++i) {
ir_mapping.Add(op->result(i), pd_op->result(i));
}
return pd_op;
}

::pir::Operation* ConvertGatherOp(::pir::Operation* op,
::pir::IrMapping& ir_mapping, // NOLINT
::pir::Builder& builder) { // NOLINT
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
auto attrs = op->attributes();
int axis = attrs.at("axis").dyn_cast<pir::Int32Attribute>().data();
auto pd_op = builder.Build<paddle::dialect::GatherOp>(
ir_mapping.Lookup(op->operand_source(0)),
ir_mapping.Lookup(op->operand_source(1)),
axis);
for (uint32_t i = 0; i < op->num_results(); ++i) {
ir_mapping.Add(op->result(i), pd_op->result(i));
}
return pd_op;
}

bool CanApplyOn(::pir::Operation* op) {
return op->dialect()->name() == "cinn_op";
}
Expand Down Expand Up @@ -262,3 +397,35 @@ REGISTER_TRANSFORM_RULES(reshape_op,
REGISTER_TRANSFORM_RULES(concat_op,
cinn::dialect::ConcatOp::name(),
cinn::dialect::details::ConvertConcatOp);

REGISTER_TRANSFORM_RULES(scale_op,
cinn::dialect::ScaleOp::name(),
cinn::dialect::details::ConvertScaleOp);

REGISTER_TRANSFORM_RULES(
flip_op,
cinn::dialect::ReverseOp::name(), // cinn::dialect::ReverseOp <->
// paddle::dialect::FlipOp
cinn::dialect::details::ConvertFlipOp);

REGISTER_TRANSFORM_RULES(pool2d_op,
cinn::dialect::Pool2dOp::name(),
cinn::dialect::details::ConvertPool2dOp);

REGISTER_TRANSFORM_RULES(isclose_op,
cinn::dialect::IscloseOp::name(),
cinn::dialect::details::ConvertIscloseOp);

REGISTER_TRANSFORM_RULES(
expand_op,
cinn::dialect::BroadcastOp::name(),
cinn::dialect::details::ConvertExpandOp); // cinn::dialect::BroadcastOp <->
// paddle::dialect::ExpandOp

REGISTER_TRANSFORM_RULES(uniform_op,
cinn::dialect::UniformRandomOp::name(),
cinn::dialect::details::ConvertUniformOp);

REGISTER_TRANSFORM_RULES(gather_op,
cinn::dialect::GatherOp::name(),
cinn::dialect::details::ConvertGatherOp);
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,8 @@ class UniformOpPattern : public paddle::drr::DrrPatternBase {
{"dtype", pattern.Attr("uniform_dtype")},
{"diag_num", pattern.Attr("seed")},
{"diag_step", pattern.Attr("seed")},
{"diag_val", pattern.Attr("min_value")}});
{"diag_val", pattern.Attr("min_value")},
{"place", pattern.Attr("uniform_place")}});
res.Tensor("ret") = cinn_uniform();
}
};
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/op_generator/op_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ class {TEST_API} {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{
'pir::ArrayAttribute<pir::StrAttribute>',
'const std::vector<std::string>&',
],
'Place': ['paddle::dialect::PlaceAttribute', 'const Place&'],
'Place': ['paddle::dialect::PlaceAttribute', 'const phi::Place&'],
'DataLayout': [
'paddle::dialect::DataLayoutAttribute',
'DataLayout',
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/python_c_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@
"std::vector<phi::Scalar>": "CastPyArg2ScalarArray",
"paddle::experimental::IntArray": "CastPyArg2IntArray",
"paddle::Place": "CastPyArg2Place",
"phi::Place": "CastPyArg2Place",
"Place": "CastPyArg2Place",
"phi::DataType": "CastPyArg2DataTypeDirectly",
}
Expand Down

0 comments on commit 7b05425

Please sign in to comment.