Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] Add op_callstack to Pir #62139

Merged
merged 17 commits into from
Mar 6, 2024
12 changes: 10 additions & 2 deletions paddle/cinn/hlir/framework/pir/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -314,9 +314,17 @@ static utils::Attribute ConvertArrayAttribute(
CASE_ATTRIBUTE(float, FloatAttribute)
} else if (attr_vec[0].isa<::pir::DoubleAttribute>()) {
CASE_ATTRIBUTE(double, DoubleAttribute)
} else if (attr_vec[0].isa<::pir::StrAttribute>()) {
std::vector<std::string> dst_attr;
for (auto element : attr_vec) {
dst_attr.push_back(
element.dyn_cast<::pir::StrAttribute>().AsString());
}
dst_attr = dst_attr;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

啊,这里忘记删掉了,如果没有其他问题的话下个 PR 再删吧

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

新增的 op_callstack 是 Array<Str> 所以这里加了相关处理

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已删除

} else {
LOG(FATAL) << "only support bool/int32/int64/float/double attribute in "
"ArrayAttribute";
LOG(FATAL)
<< "only support bool/int32/int64/float/double/string attribute in "
"ArrayAttribute";
}
}
} else {
Expand Down
56 changes: 54 additions & 2 deletions paddle/fluid/pir/dialect/op_generator/python_c_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@
#include "paddle/fluid/pybind/op_function_common.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/fluid/pir/dialect/operator/ir/api_builder.h"
#include "paddle/fluid/pybind/op_callstack_utils.h"
#include "paddle/fluid/framework/op_proto_maker.h"


{body}
Expand All @@ -71,8 +74,24 @@
{attrs}

// Call ir static api
pir::InsertionPoint before_insertion_point =
paddle::dialect::ApiBuilder::Instance().GetCurrentInsertionPoint();
auto before_insertion_iterator = before_insertion_point.second;
pir::Attribute callstack_info_attr = get_op_callstack_info();
before_insertion_iterator--;
auto static_api_out = paddle::dialect::{api_name}({args});

before_insertion_iterator++;
pir::InsertionPoint after_insertion_point =
paddle::dialect::ApiBuilder::Instance().GetCurrentInsertionPoint();
PADDLE_ENFORCE_EQ(before_insertion_point.first,after_insertion_point.first);
auto after_insertion_iterator = after_insertion_point.second;
for (auto block_iterator = before_insertion_iterator;
block_iterator != after_insertion_iterator;
block_iterator++) {{
block_iterator->set_attribute(paddle::framework::OpProtoAndCheckerMaker::
OpCreationCallstackAttrName(),
callstack_info_attr);
}}
return ToPyObject(static_api_out);
}} catch (...) {{
ThrowExceptionToPython(std::current_exception());
Expand All @@ -94,8 +113,24 @@
{attrs}

// Call ir static api
pir::InsertionPoint before_insertion_point =
paddle::dialect::ApiBuilder::Instance().GetCurrentInsertionPoint();
auto before_insertion_iterator = before_insertion_point.second;
pir::Attribute callstack_info_attr = get_op_callstack_info();
before_insertion_iterator--;
paddle::dialect::{api_name}({args});

before_insertion_iterator++;
pir::InsertionPoint after_insertion_point =
paddle::dialect::ApiBuilder::Instance().GetCurrentInsertionPoint();
PADDLE_ENFORCE_EQ(before_insertion_point.first,after_insertion_point.first);
auto after_insertion_iterator = after_insertion_point.second;
for (auto block_iterator = before_insertion_iterator;
block_iterator != after_insertion_iterator;
block_iterator++) {{
block_iterator->set_attribute(paddle::framework::OpProtoAndCheckerMaker::
OpCreationCallstackAttrName(),
callstack_info_attr);
}}
return nullptr;
}} catch (...) {{
ThrowExceptionToPython(std::current_exception());
Expand Down Expand Up @@ -129,7 +164,24 @@
{cast_attrs}

// Call ir static api
pir::InsertionPoint before_insertion_point =
paddle::dialect::ApiBuilder::Instance().GetCurrentInsertionPoint();
auto before_insertion_iterator = before_insertion_point.second;
pir::Attribute callstack_info_attr = get_op_callstack_info();
before_insertion_iterator--;
auto static_api_out = paddle::dialect::{api_name}({args_with_mutable_attrs});
before_insertion_iterator++;
pir::InsertionPoint after_insertion_point =
paddle::dialect::ApiBuilder::Instance().GetCurrentInsertionPoint();
PADDLE_ENFORCE_EQ(before_insertion_point.first,after_insertion_point.first);
auto after_insertion_iterator = after_insertion_point.second;
for (auto block_iterator = before_insertion_iterator;
block_iterator != after_insertion_iterator;
block_iterator++) {{
block_iterator->set_attribute(paddle::framework::OpProtoAndCheckerMaker::
OpCreationCallstackAttrName(),
callstack_info_attr);
}}
return ToPyObject(static_api_out);


Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/pir/include/core/builtin_op.h"
#include "paddle/pir/include/core/parameter.h"

namespace paddle {
namespace dialect {

Expand Down Expand Up @@ -241,6 +242,7 @@ std::tuple<pir::Value, pir::Value> fused_gemm_epilogue(pir::Value x,
{"trans_x", pir::BoolAttribute::get(ctx, trans_x)},
{"trans_y", pir::BoolAttribute::get(ctx, trans_y)},
{"activation", pir::StrAttribute::get(ctx, activation)}};

auto fused_gemm_epilogue_op =
ApiBuilder::Instance()
.GetBuilder()
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/pybind/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ set(PYBIND_SRCS
auto_parallel_py.cc
eval_frame_tools.cc
cpython_internals.c
eval_frame.c)
eval_frame.c
op_callstack_utils.cc)

if(NOT WITH_SHARED_IR)
# Note: We want to compile pir source into paddle.so directly, because
Expand Down
Loading