diff --git a/paddle/cinn/hlir/dialect/operator/ir/ops.yaml b/paddle/cinn/hlir/dialect/operator/ir/ops.yaml index 6c9f976bf941b..a8eac75248186 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/ops.yaml +++ b/paddle/cinn/hlir/dialect/operator/ir/ops.yaml @@ -117,7 +117,7 @@ interfaces : paddle::dialect::InferSymbolicShapeInterface - op : uniform_random - args : (int64_t[] shape, float min, float max, int seed, DataType dtype, int diag_num = 0, int diag_step=0, float diag_val=1.0) + args : (int64_t[] shape, float min, float max, int seed, DataType dtype, int diag_num = 0, int diag_step=0, float diag_val=1.0, Place place={}) output : Tensor(out) infer_meta : func : CreateVecShapeInferMeta diff --git a/paddle/cinn/hlir/dialect/operator/transforms/accuracy_check_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/accuracy_check_pass.cc index 34dc952e1f71c..f9be06cc701c0 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/accuracy_check_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/accuracy_check_pass.cc @@ -58,7 +58,6 @@ class AddAccuracyCheckPattern builder.set_insertion_point(fusion_op); const auto& InsertAccuaryCheckOp = [&](::pir::Operation* op) -> void { - rewriter.SetInsertionPointAfter(fusion_op); for (size_t i = 0; i < op->num_operands(); ++i) { rewriter.Build( fusion_op.result(i), @@ -67,6 +66,7 @@ class AddAccuracyCheckPattern i); } }; + const auto& ConvertCinnOpToPdOp = [&](::pir::Operation* op) -> void { rewriter.SetInsertionPointAfter(fusion_op); for (size_t i = 0; i < op->num_operands(); ++i) { @@ -86,6 +86,7 @@ class AddAccuracyCheckPattern } auto new_op = op->Clone(ir_mapping, clone_options); rewriter.Insert(new_op); + rewriter.SetInsertionPointAfter(new_op); }; for (auto& op : op_list) { @@ -103,7 +104,7 @@ class AddAccuracyCheckPattern class AccuarcyCheckPass : public pir::Pass { public: - AccuarcyCheckPass() : pir::Pass("accuracy_check_pass", /*opt_level=*/4) {} + AccuarcyCheckPass() : pir::Pass("accuracy_check_pass", 
/*opt_level=*/3) {} bool Initialize(pir::IrContext* context) override { pir::RewritePatternSet ps(context); diff --git a/paddle/cinn/hlir/dialect/operator/transforms/cinn_to_pd_util.cc b/paddle/cinn/hlir/dialect/operator/transforms/cinn_to_pd_util.cc index df4e11668a5c4..501f34411a3d1 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/cinn_to_pd_util.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/cinn_to_pd_util.cc @@ -22,6 +22,7 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/phi/common/place.h" #include "paddle/pir/include/core/builtin_dialect.h" #include "paddle/pir/include/core/ir_mapping.h" namespace cinn::dialect::details { @@ -37,8 +38,8 @@ pir::Attribute ArrayAttributeToIntArrayAttribute( data.push_back(attr.dyn_cast<::pir::Int64Attribute>().data()); } } - pir::Attribute attr_data = paddle::dialect::IntArrayAttribute::get( - pir::IrContext::Instance(), phi::IntArray(data)); + ::pir::Attribute attr_data = paddle::dialect::IntArrayAttribute::get( + ::pir::IrContext::Instance(), phi::IntArray(data)); return attr_data; } @@ -49,7 +50,7 @@ const auto& handler_reduce_sum_op = VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op"; auto attrs = op->attributes(); - pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute( + ::pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute( attrs.at("dim").dyn_cast<::pir::ArrayAttribute>()); attrs.insert({"axis", attr_axis}); attrs.insert({"dtype", attrs["dtype"]}); @@ -74,7 +75,7 @@ const auto& handler_reduce_max_op = // TODO(chenxi67): 1. 
CINN op Dialect Normalization;2.AST Op compute // Normalization - pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute( + ::pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute( attrs.at("dim").dyn_cast<::pir::ArrayAttribute>()); attrs.insert({"axis", attr_axis}); attrs.insert({"keepdim", attrs["keep_dim"]}); @@ -96,7 +97,7 @@ const auto& handler_reduce_min_op = VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op"; auto attrs = op->attributes(); - pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute( + ::pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute( attrs.at("dim").dyn_cast<::pir::ArrayAttribute>()); attrs.insert({"axis", attr_axis}); attrs.insert({"keepdim", attrs["keep_dim"]}); @@ -118,7 +119,7 @@ const auto& handler_reduce_prod_op = VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op"; auto attrs = op->attributes(); - pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute( + ::pir::Attribute attr_axis = ArrayAttributeToIntArrayAttribute( attrs.at("dim").dyn_cast<::pir::ArrayAttribute>()); attrs.insert({"dims", attr_axis}); attrs.erase("dim"); @@ -136,9 +137,9 @@ ::pir::Operation* ConvertSliceOp(::pir::Operation* op, ::pir::Builder& builder) { // NOLINT VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op"; auto attrs = op->attributes(); - pir::Attribute starts = ArrayAttributeToIntArrayAttribute( + ::pir::Attribute starts = ArrayAttributeToIntArrayAttribute( attrs.at("starts").dyn_cast<::pir::ArrayAttribute>()); - pir::Attribute ends = ArrayAttributeToIntArrayAttribute( + ::pir::Attribute ends = ArrayAttributeToIntArrayAttribute( attrs.at("ends").dyn_cast<::pir::ArrayAttribute>()); attrs["starts"] = starts; attrs["ends"] = ends; @@ -155,8 +156,7 @@ ::pir::Operation* ConvertReshapeOp(::pir::Operation* op, ::pir::Builder& builder) { // NOLINT VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op"; auto attrs = op->attributes(); - 
attrs.at("shape").dyn_cast<::pir::ArrayAttribute>(); - pir::Attribute shape = ArrayAttributeToIntArrayAttribute( + ::pir::Attribute shape = ArrayAttributeToIntArrayAttribute( attrs.at("shape").dyn_cast<::pir::ArrayAttribute>()); attrs["shape"] = shape; auto pd_op = builder.Build( @@ -172,9 +172,6 @@ ::pir::Operation* ConvertConcatOp(::pir::Operation* op, ::pir::Builder& builder) { // NOLINT VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op"; auto attrs = op->attributes(); - for (auto item : attrs) { - VLOG(0) << item.first; - } std::vector vec_inputs; for (uint32_t i = 0; i < op->num_operands(); ++i) { vec_inputs.push_back(ir_mapping.Lookup(op->operand_source(i))); @@ -190,6 +187,144 @@ ::pir::Operation* ConvertConcatOp(::pir::Operation* op, return pd_op; } +::pir::Operation* ConvertScaleOp(::pir::Operation* op, + ::pir::IrMapping& ir_mapping, // NOLINT + ::pir::Builder& builder) { // NOLINT + VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op"; + auto attrs = op->attributes(); + + float scale = attrs.at("scale").dyn_cast<::pir::FloatAttribute>().data(); + float bias = attrs.at("bias").dyn_cast().data(); + bool bias_after_scale = + attrs.at("bias_after_scale").dyn_cast().data(); + auto pd_op = builder.Build( + ir_mapping.Lookup(op->operand_source(0)), scale, bias, bias_after_scale); + for (uint32_t i = 0; i < op->num_results(); ++i) { + ir_mapping.Add(op->result(i), pd_op->result(i)); + } + return pd_op; +} + +::pir::Operation* ConvertFlipOp(::pir::Operation* op, + ::pir::IrMapping& ir_mapping, // NOLINT + ::pir::Builder& builder) { // NOLINT + VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op"; + auto attrs = op->attributes(); + auto pd_op = builder.Build( + ir_mapping.Lookup(op->operand_source(0)), attrs); + for (uint32_t i = 0; i < op->num_results(); ++i) { + ir_mapping.Add(op->result(i), pd_op->result(i)); + } + return pd_op; +} + +::pir::Operation* ConvertPool2dOp(::pir::Operation* op, + ::pir::IrMapping& ir_mapping, 
// NOLINT + ::pir::Builder& builder) { // NOLINT + VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op"; + auto attrs = op->attributes(); + ::pir::Attribute kernel_size = ArrayAttributeToIntArrayAttribute( + attrs.at("kernel_size").dyn_cast<::pir::ArrayAttribute>()); + attrs["kernel_size"] = kernel_size; + attrs["strides"] = attrs.at("stride_size"); + attrs["paddings"] = attrs.at("padding_size"); + attrs.erase("stride_size"); + attrs.erase("padding_size"); + auto pd_op = builder.Build( + ir_mapping.Lookup(op->operand_source(0)), attrs); + for (uint32_t i = 0; i < op->num_results(); ++i) { + ir_mapping.Add(op->result(i), pd_op->result(i)); + } + return pd_op; +} + +::pir::Operation* ConvertIscloseOp(::pir::Operation* op, + ::pir::IrMapping& ir_mapping, // NOLINT + ::pir::Builder& builder) { // NOLINT + VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op"; + auto attrs = op->attributes(); + double rtol = attrs.at("rtol").dyn_cast().data(); + double atol = attrs.at("atol").dyn_cast().data(); + bool equal_nan = attrs.at("equal_nan").dyn_cast().data(); + auto pd_op = builder.Build( + ir_mapping.Lookup(op->operand_source(0)), + ir_mapping.Lookup(op->operand_source(1)), + rtol, + atol, + equal_nan); + for (uint32_t i = 0; i < op->num_results(); ++i) { + ir_mapping.Add(op->result(i), pd_op->result(i)); + } + return pd_op; +} + +::pir::Operation* ConvertExpandOp(::pir::Operation* op, + ::pir::IrMapping& ir_mapping, // NOLINT + ::pir::Builder& builder) { // NOLINT + VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op"; + auto attrs = op->attributes(); + + std::vector shape_; + auto attr_shape = attrs.at("out_shape").dyn_cast<::pir::ArrayAttribute>(); + for (size_t i = 0; i < attr_shape.size(); ++i) { + shape_.push_back(attr_shape.at(i).dyn_cast<::pir::Int64Attribute>().data()); + } + + paddle::dialect::FullIntArrayOp full_shape_op = + builder.Build( + shape_, phi::DataType::INT64, phi::CPUPlace()); + ::pir::Value out_shape =
full_shape_op->result(0); + auto pd_op = builder.Build( + ir_mapping.Lookup(op->operand_source(0)), out_shape); + for (uint32_t i = 0; i < op->num_results(); ++i) { + ir_mapping.Add(op->result(i), pd_op->result(i)); + } + return pd_op; +} + +::pir::Operation* ConvertUniformOp(::pir::Operation* op, + ::pir::IrMapping& ir_mapping, // NOLINT + ::pir::Builder& builder) { // NOLINT + VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op"; + auto attrs = op->attributes(); + std::vector shape; + auto attr_shape = attrs.at("out_shape").dyn_cast<::pir::ArrayAttribute>(); + for (size_t i = 0; i < attr_shape.size(); ++i) { + shape.push_back(attr_shape.at(i).dyn_cast<::pir::Int64Attribute>().data()); + } + ::phi::DataType dtype = + attrs.at("dtype").dyn_cast().data(); + + float min = attrs.at("min").dyn_cast().data(); + float max = attrs.at("max").dyn_cast().data(); + int seed = attrs.at("diag_num").dyn_cast().data(); + ::phi::Place place = + attrs.at("place").dyn_cast().data(); + + auto pd_op = builder.Build( + shape, dtype, min, max, seed, place); + for (uint32_t i = 0; i < op->num_results(); ++i) { + ir_mapping.Add(op->result(i), pd_op->result(i)); + } + return pd_op; +} + +::pir::Operation* ConvertGatherOp(::pir::Operation* op, + ::pir::IrMapping& ir_mapping, // NOLINT + ::pir::Builder& builder) { // NOLINT + VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op"; + auto attrs = op->attributes(); + int axis = attrs.at("axis").dyn_cast().data(); + auto pd_op = builder.Build( + ir_mapping.Lookup(op->operand_source(0)), + ir_mapping.Lookup(op->operand_source(1)), + axis); + for (uint32_t i = 0; i < op->num_results(); ++i) { + ir_mapping.Add(op->result(i), pd_op->result(i)); + } + return pd_op; +} + bool CanApplyOn(::pir::Operation* op) { return op->dialect()->name() == "cinn_op"; } @@ -262,3 +397,35 @@ REGISTER_TRANSFORM_RULES(reshape_op, REGISTER_TRANSFORM_RULES(concat_op, cinn::dialect::ConcatOp::name(), cinn::dialect::details::ConvertConcatOp); +
+REGISTER_TRANSFORM_RULES(scale_op, + cinn::dialect::ScaleOp::name(), + cinn::dialect::details::ConvertScaleOp); + +REGISTER_TRANSFORM_RULES( + flip_op, + cinn::dialect::ReverseOp::name(), // cinn::dialect::ReverseOp <-> + // paddle::dialect::FlipOp + cinn::dialect::details::ConvertFlipOp); + +REGISTER_TRANSFORM_RULES(pool2d_op, + cinn::dialect::Pool2dOp::name(), + cinn::dialect::details::ConvertPool2dOp); + +REGISTER_TRANSFORM_RULES(isclose_op, + cinn::dialect::IscloseOp::name(), + cinn::dialect::details::ConvertIscloseOp); + +REGISTER_TRANSFORM_RULES( + expand_op, + cinn::dialect::BroadcastOp::name(), + cinn::dialect::details::ConvertExpandOp); // cinn::dialect::BroadcastOp <-> + // paddle::dialect::ExpandOp + +REGISTER_TRANSFORM_RULES(uniform_op, + cinn::dialect::UniformRandomOp::name(), + cinn::dialect::details::ConvertUniformOp); + +REGISTER_TRANSFORM_RULES(gather_op, + cinn::dialect::GatherOp::name(), + cinn::dialect::details::ConvertGatherOp); diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc index 5f1f734daa9a8..648b3af363241 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc @@ -794,7 +794,8 @@ class UniformOpPattern : public paddle::drr::DrrPatternBase { {"dtype", pattern.Attr("uniform_dtype")}, {"diag_num", pattern.Attr("seed")}, {"diag_step", pattern.Attr("seed")}, - {"diag_val", pattern.Attr("min_value")}}); + {"diag_val", pattern.Attr("min_value")}, + {"place", pattern.Attr("uniform_place")}}); res.Tensor("ret") = cinn_uniform(); } }; diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index 678d1274cb467..2829d6e4dc624 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -372,7 +372,7 @@ class {TEST_API} {op_name} : public 
pir::Op<{op_name}{interfaces}{traits}> {{ 'pir::ArrayAttribute', 'const std::vector&', ], - 'Place': ['paddle::dialect::PlaceAttribute', 'const Place&'], + 'Place': ['paddle::dialect::PlaceAttribute', 'const phi::Place&'], 'DataLayout': [ 'paddle::dialect::DataLayoutAttribute', 'DataLayout', diff --git a/paddle/fluid/pir/dialect/op_generator/python_c_gen.py b/paddle/fluid/pir/dialect/op_generator/python_c_gen.py index be28dba5ca15c..a1ab47d51eaf7 100644 --- a/paddle/fluid/pir/dialect/op_generator/python_c_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/python_c_gen.py @@ -203,6 +203,7 @@ "std::vector": "CastPyArg2ScalarArray", "paddle::experimental::IntArray": "CastPyArg2IntArray", "paddle::Place": "CastPyArg2Place", + "phi::Place": "CastPyArg2Place", "Place": "CastPyArg2Place", "phi::DataType": "CastPyArg2DataTypeDirectly", }