[Fix Bug] Fix Bugs of Two Pass #60626

Merged · 5 commits · Jan 10, 2024
61 changes: 17 additions & 44 deletions paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc
@@ -314,51 +314,14 @@ class SubstituteDimExprHelper final {
DimExpr4SymbolNameT DimExpr4SymbolName_;
};

std::optional<DimExpr> SubstituteDimExpr(
DimExpr SubstituteDimExpr(
const DimExpr& dim_expr,
const std::function<std::optional<DimExpr>(const std::string& symbol_name)>&
DimExpr4SymbolName) {
return SubstituteDimExprHelper(DimExpr4SymbolName).Substitute(dim_expr);
}

std::function<std::optional<DimExpr>(const std::string& symbol_name)>
MakeGetterDimExpr4SymbolName(
const std::vector<std::tuple<std::string /*symbol_name*/,
int /*in_tensor_idx*/,
int /*in_tensor_dim_idx*/>>& symbol_bindings,
const std::function<std::optional<DimExpr>(
int in_tensor_idx, int in_tensor_dim_idx)>& DimExpr4InputDim) {
std::unordered_map<std::string, std::vector<std::pair<int, int>>>
symbol_name2in_tensor_dim_pos;
for (const auto& tuple : symbol_bindings) {
const auto& [symbol_name, in_tensor_idx, in_tensor_dim_idx] = tuple;
symbol_name2in_tensor_dim_pos[symbol_name].emplace_back(
std::pair{in_tensor_idx, in_tensor_dim_idx});
}
return [map = std::move(symbol_name2in_tensor_dim_pos), DimExpr4InputDim](
const std::string& symbol_name) -> std::optional<DimExpr> {
const auto& iter = map.find(symbol_name);
if (iter == map.end()) {
return std::nullopt;
}
const auto& positions = iter->second;
std::optional<DimExpr> ret = std::nullopt;
for (const auto& [in_tensor_idx, in_tensor_dim_idx] : positions) {
const auto& current = DimExpr4InputDim(in_tensor_idx, in_tensor_dim_idx);
if (!current.has_value()) {
return std::nullopt;
}
if (ret.has_value()) {
// Same names, same DimExprs.
if (ret.value() != current.value()) {
return std::nullopt;
}
} else {
ret = current;
}
}
return ret;
};
const auto& opt_substituted =
SubstituteDimExprHelper(DimExpr4SymbolName).Substitute(dim_expr);
if (opt_substituted.has_value()) return opt_substituted.value();
return dim_expr;
}

namespace {
@@ -387,6 +350,12 @@ std::optional<DimExpr> GetDimExprBySymbolBindingImpl(
return shape_or_data_dim_expr.shape().at(dim_idx);
}

std::string GetSymbolNameBySymbolBinding(
const GenerateShapeOp::SymbolBinding& symbol_binding) {
return std::visit([](const auto& impl) { return impl.symbol_name; },
symbol_binding);
}

} // namespace

std::function<std::optional<DimExpr>(const std::string& symbol_name)>
@@ -396,6 +365,10 @@ MakeGetterDimExpr4SymbolName(
DimExpr4InputDim) {
std::unordered_map<std::string, std::vector<GenerateShapeOp::SymbolBinding>>
symbol_name2symbol_bindins{};
for (const auto& symbol_binding : symbol_bindings) {
symbol_name2symbol_bindins[GetSymbolNameBySymbolBinding(symbol_binding)]
.emplace_back(symbol_binding);
}
const auto& GetDimExpr =
[&](const GenerateShapeOp::SymbolBinding& symbol_binding) {
return std::visit(
@@ -596,14 +569,14 @@ void GenerateSymbolBindings(
std::vector<pir::Value> GetMinimalInputs(
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value,
const std::vector<pir::Value>& input_tensors) {
std::unordered_set<symbol::DimExpr> handdled_dim_exprs;
std::unordered_set<symbol::DimExpr> handled_dim_exprs;
std::unordered_set<pir::Value> first_occurred_input_tensors;
auto TryCollectFirstOcurredInput_tensor =
[&](pir::Value input_tensor,
const std::vector<symbol::DimExpr>& dim_exprs) {
for (const auto& dim_expr : dim_exprs) {
if (dim_expr.isa<int64_t>()) continue;
if (!handdled_dim_exprs.insert(dim_expr).second) {
if (handled_dim_exprs.insert(dim_expr).second) {
first_occurred_input_tensors.insert(input_tensor);
}
}
12 changes: 2 additions & 10 deletions paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h
@@ -29,27 +29,19 @@ ::pir::Attribute ConvertDimExprToAttribute(pir::IrContext* ctx,
std::optional<symbol::DimExpr> ConvertAttributeToDimExpr(
::pir::Attribute attribute);

std::optional<symbol::DimExpr> SubstituteDimExpr(
symbol::DimExpr SubstituteDimExpr(
const symbol::DimExpr& dim_expr,
const std::function<std::optional<symbol::DimExpr>(
const std::string& symbol_name)>& DimExpr4SymbolName);

std::function<std::optional<symbol::DimExpr>(const std::string& symbol_name)>
MakeGetterDimExpr4SymbolName(
const std::vector<std::tuple<std::string /*symbol_name*/,
int /*in_tensor_idx*/,
int /*in_tensor_dim_idx*/>>& symbol_bindings,
const std::function<std::optional<symbol::DimExpr>(
int in_tensor_idx, int in_tensor_dim_idx)>& DimExpr4InputDim);

std::function<std::optional<symbol::DimExpr>(const std::string& symbol_name)>
MakeGetterDimExpr4SymbolName(
const GenerateShapeOp::SymbolBindings& symbol_bindings,
const std::function<const symbol::ShapeOrDataDimExprs&(int in_tensor_idx)>&
DimExpr4InputDim);

using ShapeOrDataDimExprs4ValueT =
std::function<const symbol::ShapeOrDataDimExprs&(pir::Value)>;
std::function<symbol::ShapeOrDataDimExprs(pir::Value)>;

// Returns true if success.
bool MakeGenerateShapeOpAttribute(
1 change: 1 addition & 0 deletions paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc
@@ -55,6 +55,7 @@ void OperatorDialect::initialize() {
RegisterOp<GroupOp>();
RegisterOp<ConcatOp>();
RegisterOp<SplitOp>();
RegisterOp<GenerateShapeOp>();
RegisterAttribute<GroupInfoAttribute>();
RegisterAttribute<CINNKernelInfoAttribute>();
}
@@ -32,6 +32,8 @@ namespace cinn {
namespace dialect {
namespace ir {

namespace {

pir::Value GetOutputDimTensor(pir::PatternRewriter* rewriter,
pir::Value x,
pir::Value y) {
@@ -42,6 +44,10 @@ pir::Value GetOutputDimTensor(pir::PatternRewriter* rewriter,
}

bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
if (op->operand_source(0).defining_op()->isa<paddle::dialect::ExpandOp>() &&
op->operand_source(1).defining_op()->isa<paddle::dialect::ExpandOp>()) {
return false;
}
pir::Value x = op->operand_source(0);
pir::Value y = op->operand_source(1);
pir::Value output_dim_tensor = GetOutputDimTensor(rewriter, x, y);
@@ -58,6 +64,8 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
return true;
}

} // namespace

template <typename OPTYPE>
class FullyInsertBroadcastPattern : public pir::OpRewritePattern<OPTYPE> {
public:
@@ -27,6 +27,7 @@
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/api/match_context.h"
#include "paddle/pir/core/builtin_dialect.h"
#include "paddle/pir/dialect/shape/utils/shape_utils.h"
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pattern_rewrite/pattern_applicator.h"
#include "paddle/pir/pattern_rewrite/pattern_match.h"
@@ -38,6 +39,9 @@ namespace ir {

namespace {

using ShapeOrDataDimExprs4ValueT =
std::function<symbol::ShapeOrDataDimExprs(pir::Value)>;

std::vector<pir::Value> FindSourceDenseTensorOfDimTensor(
pir::Value shape,
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) {
@@ -126,9 +130,19 @@ std::optional<pir::Value> GetOutOfRewritedGenerateShapeOp(
.out();
}

bool ProcessOp(paddle::dialect::ExpandOp op,
pir::PatternRewriter* rewriter,
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) {
bool ProcessOp(paddle::dialect::ExpandOp op, pir::PatternRewriter* rewriter) {
if (op.shape().defining_op()->isa<cinn::dialect::GenerateShapeOp>()) {
return false;
}
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value =
[&op](pir::Value value) -> symbol::ShapeOrDataDimExprs {
pir::ShapeConstraintIRAnalysis& shape_analysis =
pir::ShapeAnalysisManager::Instance().Get(
op.x().defining_op()->GetParentProgram());
CHECK(shape_analysis.value_id_to_shapeordata_.find(GetValueId(&value)) !=
shape_analysis.value_id_to_shapeordata_.end());
return shape_analysis.value_id_to_shapeordata_.at(GetValueId(&value));
Comment on lines +142 to +144

Contributor:

value_id_to_shapeordata_ has been deprecated on the develop branch; please switch to the Get and Set accessor functions instead.

Contributor Author:

OK, I'll open a separate PR for that.
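For reference, a minimal sketch of what that follow-up might look like. It assumes the develop branch exposes HasShapeOrDataForValue and GetShapeOrDataForValue accessors on pir::ShapeConstraintIRAnalysis; those names are an assumption for illustration, not something this PR confirms.

const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value =
    [&op](pir::Value value) -> symbol::ShapeOrDataDimExprs {
  pir::ShapeConstraintIRAnalysis& shape_analysis =
      pir::ShapeAnalysisManager::Instance().Get(
          op.x().defining_op()->GetParentProgram());
  // Query through the analysis' accessors instead of reading the
  // value_id_to_shapeordata_ member map directly (accessor names assumed).
  CHECK(shape_analysis.HasShapeOrDataForValue(value));
  return shape_analysis.GetShapeOrDataForValue(value);
};

The lambda's behavior is unchanged; it simply stops depending on a member map that the develop branch is removing.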

};
std::optional<pir::Value> opt_generated_shape =
GetOutOfRewritedGenerateShapeOp(
op.shape(), rewriter, ShapeOrDataDimExprs4Value);
@@ -143,32 +157,25 @@ template <typename OPTYPE>
class FuseShapeOpsIntoGenerateShapeOpPattern
: public pir::OpRewritePattern<OPTYPE> {
public:
FuseShapeOpsIntoGenerateShapeOpPattern(
pir::IrContext* context,
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value)
: pir::OpRewritePattern<OPTYPE>(context),
ShapeOrDataDimExprs4Value_(ShapeOrDataDimExprs4Value) {}
explicit FuseShapeOpsIntoGenerateShapeOpPattern(pir::IrContext* context)
: pir::OpRewritePattern<OPTYPE>(context) {}

bool MatchAndRewrite(OPTYPE op,
pir::PatternRewriter& rewriter) const override {
return ProcessOp(op, &rewriter, ShapeOrDataDimExprs4Value_);
return ProcessOp(op, &rewriter);
}

private:
ShapeOrDataDimExprs4ValueT ShapeOrDataDimExprs4Value_;
};

FuseShapeOpsIntoGenerateShapeOpPass::FuseShapeOpsIntoGenerateShapeOpPass(
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value)
: pir::PatternRewritePass("fuse_shape_ops_into_generate_shape_op_pass", 1),
ShapeOrDataDimExprs4Value_(ShapeOrDataDimExprs4Value) {}
FuseShapeOpsIntoGenerateShapeOpPass::FuseShapeOpsIntoGenerateShapeOpPass()
: pir::PatternRewritePass("fuse_shape_ops_into_generate_shape_op_pass", 1) {
}

pir::RewritePatternSet FuseShapeOpsIntoGenerateShapeOpPass::InitializePatterns(
pir::IrContext* context) {
pir::RewritePatternSet ps(context);
// elementwise ops
ps.Add<FuseShapeOpsIntoGenerateShapeOpPattern<paddle::dialect::ExpandOp>>(
context, ShapeOrDataDimExprs4Value_);
context);

return ps;
}
@@ -24,17 +24,11 @@ namespace ir {

class FuseShapeOpsIntoGenerateShapeOpPass : public pir::PatternRewritePass {
public:
using ShapeOrDataDimExprs4ValueT =
std::function<const symbol::ShapeOrDataDimExprs &(pir::Value)>;
explicit FuseShapeOpsIntoGenerateShapeOpPass(
const ShapeOrDataDimExprs4ValueT &ShapeOrDataDimExprs4Value);
FuseShapeOpsIntoGenerateShapeOpPass();

pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override;

bool CanApplyOn(pir::Operation *op) const override;

private:
ShapeOrDataDimExprs4ValueT ShapeOrDataDimExprs4Value_;
};

} // namespace ir
2 changes: 2 additions & 0 deletions paddle/fluid/pybind/CMakeLists.txt
@@ -57,6 +57,8 @@ if(WITH_CINN)
add_broadcast_to_elementwise_pass
pd_to_cinn_pass
sub_graph_checker
fully_insert_broadcast_pass
fuse_shape_ops_into_generate_shape_op_pass
split_generate_shape_into_shape_ops_pass)
endif()

26 changes: 9 additions & 17 deletions test/cpp/pir/cinn/generate_shape_util_test.cc
@@ -48,7 +48,7 @@ TEST(DimExprUtil, Convert) {

TEST(DimExprUtil, Substitute) {
DimExpr dim_expr = CreateExampleDimExpr();
const auto& opt_expr = SubstituteDimExpr(
const auto& substituted_expr = SubstituteDimExpr(
dim_expr, [](const std::string& str) -> std::optional<DimExpr> {
if (str == "S0") {
return DimExpr("symbol0");
@@ -58,9 +58,8 @@ return std::nullopt;
return std::nullopt;
}
});
ASSERT_TRUE(opt_expr.has_value());
const auto& ret_expr = SubstituteDimExpr(
opt_expr.value(), [](const std::string& str) -> std::optional<DimExpr> {
substituted_expr, [](const std::string& str) -> std::optional<DimExpr> {
if (str == "symbol0") {
return DimExpr("S0");
} else if (str == "symbol1") {
@@ -69,26 +68,19 @@ return std::nullopt;
return std::nullopt;
}
});
ASSERT_TRUE(ret_expr.has_value());
ASSERT_EQ(ret_expr.value(), dim_expr);
ASSERT_EQ(ret_expr, dim_expr);
}

TEST(DimExprUtil, MakeGetterDimExpr4SymbolName) {
std::vector<std::tuple<std::string /*symbol_name*/,
int /*in_tensor_idx*/,
int /*in_tensor_dim_idx*/>>
symbol_bindings{};
symbol_bindings.push_back(std::make_tuple("Symbol", 0, 0));
cinn::dialect::GenerateShapeOp::SymbolBindings symbol_bindings{};
using ShapeSymbolBinding = cinn::dialect::GenerateShapeOp::ShapeSymbolBinding;
symbol_bindings.emplace_back(ShapeSymbolBinding{"Symbol", 0, 0});
const auto& dim_expr = CreateExampleDimExpr();
const auto& shape_or_data_dim_exprs = symbol::ShapeOrDataDimExprs({dim_expr});
const auto& DimExpr4SymbolName = MakeGetterDimExpr4SymbolName(
symbol_bindings,
[dim_expr](int in_tensor_idx,
int in_tensor_dim_idx) -> std::optional<DimExpr> {
if (in_tensor_idx == 0 && in_tensor_dim_idx == 0) {
return dim_expr;
} else {
return std::nullopt;
}
[&](int in_tensor_idx) -> const symbol::ShapeOrDataDimExprs& {
return shape_or_data_dim_exprs;
});
const auto& opt_dim_expr = DimExpr4SymbolName("Symbol");
ASSERT_TRUE(opt_dim_expr.has_value());