[CINN] delete redundant SetShapeOrDataForValue #64470

Merged: 4 commits, May 23, 2024. The diff below shows changes from 1 commit of the 4.
14 changes: 12 additions & 2 deletions paddle/cinn/common/broadcast_tree.cc
@@ -120,8 +120,18 @@ using Pattern2Placement = std::unordered_map<symbol::DimExpr, symbol::DimExpr>;
Pattern2Placement ConstructCstrLhsEqRhsReplacement(
const symbol::Broadcastable<symbol::DimExpr>& broadcastable_condition) {
auto [lhs, rhs] = *broadcastable_condition;
-  if (rhs.isa<std::string>()) return Pattern2Placement{{rhs, lhs}};
-  if (lhs.isa<std::string>()) return Pattern2Placement{{lhs, rhs}};
+  if (rhs.isa<std::string>()) {
+    if (GetSymbolSet(lhs).count(ToString(rhs)) != 0) {
+      return Pattern2Placement{{lhs, rhs}};
+    }
+    return Pattern2Placement{{rhs, lhs}};
+  }
+  if (lhs.isa<std::string>()) {
+    if (GetSymbolSet(rhs).count(ToString(lhs)) != 0) {
Contributor: Doesn't this fail to handle cases like S0*S1*S2 <----> S0*S2?

Contributor (author): Here we only need to guarantee that the unified symbol is not substituted away.

Contributor (author): The replacement mechanism has been updated.
+      return Pattern2Placement{{rhs, lhs}};
+    }
+    return Pattern2Placement{{lhs, rhs}};
+  }
  return Pattern2Placement{{lhs, rhs}};
}
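
A minimal, self-contained sketch of the direction rule above, with toy stand-ins (Expr, SymbolsOf, and IsSymbol are hypothetical; the real code uses symbol::DimExpr, GetSymbolSet, and isa<std::string>()): when one side is a plain symbol that also occurs inside the other side, the composite expression is the one that gets replaced, so a substitution can never re-introduce the symbol it just eliminated. Composite-to-composite pairs such as the reviewer's S0*S1*S2 <----> S0*S2 still fall through to the default {{lhs, rhs}} mapping.

#include <iostream>
#include <map>
#include <set>
#include <sstream>
#include <string>

using Expr = std::string;

// Toy version of GetSymbolSet: split "S0*S1" into {"S0", "S1"}.
std::set<std::string> SymbolsOf(const Expr& e) {
  std::set<std::string> out;
  std::stringstream ss(e);
  std::string tok;
  while (std::getline(ss, tok, '*')) out.insert(tok);
  return out;
}

// Toy version of isa<std::string>(): no '*' means a single symbol.
bool IsSymbol(const Expr& e) { return e.find('*') == std::string::npos; }

// Mirrors ConstructCstrLhsEqRhsReplacement's direction choice.
std::map<Expr, Expr> ChooseReplacement(const Expr& lhs, const Expr& rhs) {
  if (IsSymbol(rhs)) {
    if (SymbolsOf(lhs).count(rhs) != 0) return {{lhs, rhs}};  // shrink lhs
    return {{rhs, lhs}};
  }
  if (IsSymbol(lhs)) {
    if (SymbolsOf(rhs).count(lhs) != 0) return {{rhs, lhs}};  // shrink rhs
    return {{lhs, rhs}};
  }
  return {{lhs, rhs}};
}

int main() {
  // Broadcastable(S0*S1, S0): S0 occurs inside the lhs, so the composite
  // S0*S1 is rewritten to the plain symbol S0, never the other way around.
  for (const auto& [from, to] : ChooseReplacement("S0*S1", "S0")) {
    std::cout << from << " -> " << to << "\n";  // prints: S0*S1 -> S0
  }
}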

18 changes: 18 additions & 0 deletions paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
@@ -197,6 +197,24 @@ void FusionOp::Print(pir::IrPrinter& printer) {
os << printer.indentation() << "}";
}

bool FusionOp::InferSymbolicShape(
::pir::InferSymbolicShapeContext* infer_context) {
::pir::InferSymExprForBlock(*block(), infer_context);

for (uint32_t rst_idx = 0; rst_idx < num_results(); rst_idx++) {
auto inner_yield_value = block()->back().operand_source(rst_idx);
const auto& shape =
infer_context->GetShapeOrDataForValue(inner_yield_value);
infer_context->SetShapeOrDataForValue(result(rst_idx), shape);
}

if (VLOG_IS_ON(4)) {
::std::cerr << ">>>>>>>>>>>>>>>>>>>> cinn_op.fusion(op_id: op_"
Contributor: This doesn't look like an error message; if it isn't useful for debugging, please remove it.

Contributor (author): Done

<< block()->back().id() << ") END." << ::std::endl;
}
return true;
}

void YieldStoreOp::Build(pir::Builder& builder,
pir::OperationArgument& argument,
pir::Value x,
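The new FusionOp::InferSymbolicShape above is what makes the SetShapeOrDataForValue deletions in the passes below safe: shapes are inferred once for the ops inside the fusion block, and each result of the fusion op adopts the shape of the corresponding yield operand. A toy model of that propagation step (ValueId, SymShape, and ToyContext are hypothetical stand-ins for pir::Value, symbol::ShapeOrDataDimExprs, and InferSymbolicShapeContext):

#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

using ValueId = int;
using SymShape = std::vector<std::string>;  // e.g. {"S0", "128"}

struct ToyContext {
  std::unordered_map<ValueId, SymShape> shapes;
  const SymShape& Get(ValueId v) const { return shapes.at(v); }
  void Set(ValueId v, const SymShape& s) { shapes[v] = s; }
};

// yield_operands[i] is the value the fusion block yields as result i;
// result i simply adopts its already-inferred shape.
void PropagateFusionResults(const std::vector<ValueId>& yield_operands,
                            const std::vector<ValueId>& fusion_results,
                            ToyContext* ctx) {
  for (std::size_t i = 0; i < fusion_results.size(); ++i) {
    ctx->Set(fusion_results[i], ctx->Get(yield_operands[i]));
  }
}

int main() {
  ToyContext ctx;
  ctx.Set(/*yield operand*/ 1, {"S0", "S1"});  // inferred inside the block
  PropagateFusionResults({1}, {/*fusion result*/ 10}, &ctx);
  assert((ctx.Get(10) == SymShape{"S0", "S1"}));
}
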
5 changes: 4 additions & 1 deletion paddle/cinn/hlir/dialect/operator/ir/manual_op.h
@@ -61,7 +61,8 @@ class IR_API GroupOp

// FusionOp represents a subgraph that can be fused into one kernel.
// Every GroupOp can be lowered to at least one FusionOp
-class IR_API FusionOp : public pir::Op<FusionOp> {
+class IR_API FusionOp
+    : public pir::Op<FusionOp, paddle::dialect::InferSymbolicShapeInterface> {
public:
using Op::Op;
static const char *name() { return "cinn_op.fusion"; }
@@ -81,6 +82,8 @@ class IR_API FusionOp : public pir::Op<FusionOp> {

std::vector<pir::Operation *> GetOperators() const;

bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context);

void VerifySig();
void Print(pir::IrPrinter &printer); // NOLINT
};
@@ -130,18 +130,13 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
.dyn_cast<paddle::dialect::PlaceAttribute>()
.data());
op->operand(0).set_source(new_full->result(0));
-    shape_analysis.SetShapeOrDataForValue(
-        new_full.result(0), symbol::TensorShapeOrDataDimExprs(out_dim));
} else {
auto new_transpose_op = rewriter->Build<cinn::dialect::BroadcastOp>(
op->operand_source(0),
cinn::hlir::framework::pir::GetBroadcastAxis(x_dims, output_shape),
output_shape);

op->operand(0).set_source(new_transpose_op->result(0));
-    shape_analysis.SetShapeOrDataForValue(
-        new_transpose_op.result(0),
-        symbol::TensorShapeOrDataDimExprs(out_dim));
}
}

@@ -160,18 +155,13 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
.data());

op->operand(1).set_source(new_full->result(0));
-    shape_analysis.SetShapeOrDataForValue(
-        new_full.result(0), symbol::TensorShapeOrDataDimExprs(out_dim));
} else {
auto new_transpose_op = rewriter->Build<cinn::dialect::BroadcastOp>(
op->operand_source(1),
cinn::hlir::framework::pir::GetBroadcastAxis(y_dims, output_shape),
output_shape);

op->operand(1).set_source(new_transpose_op->result(0));
-    shape_analysis.SetShapeOrDataForValue(
-        new_transpose_op.result(0),
-        symbol::TensorShapeOrDataDimExprs(out_dim));
}
}

@@ -44,10 +44,6 @@ class AddYieldStoreInFusionOpPattern
op->operand_source(i), op->operand_source(i).type());
auto orignal_base = op->operand_source(i);
op->operand(i).set_source(store_op.result(0));

-      shape_analysis.SetShapeOrDataForValue(
-          store_op.result(0),
-          shape_analysis.GetShapeOrDataForValue(orignal_base));
}

return true;
@@ -178,6 +178,8 @@ ::pir::GroupOpsVec CloneOps(
auto& alignment_schedule_info = node.alignment_schedule_info;
for (auto op : group_ops) {
auto new_op = op->Clone(*ir_mapping, clone_options);
// TODO(Hongqing-work): delete this after fix bug of
// cinn_dynamic_reshape_op_pass
auto& shape_analysis =
pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram());

@@ -350,6 +352,8 @@ class CinnGroupClusterPattern
auto new_group_op = ReplaceWithGroupOp(
&rewriter, uniq_ops, node, output_values, &ir_mapping);

// TODO(Hongqing-work): delete this after fix bug of
// cinn_dynamic_reshape_op_pass
auto& shape_analysis = pir::ShapeAnalysisManager::Instance().Get(
group_op->GetParentProgram());
// update ir mapping
@@ -357,8 +357,6 @@ bool ReplaceShapeOpsToGenerateShape(
GetOutOfRewrittenGenerateShapeOp(
shape_operand, rewriter, ShapeOrDataDimExprs4Value);
if (!opt_generated_shape.has_value()) return false;
-  shape_analysis->SetShapeOrDataForValue(
-      opt_generated_shape.value(), ShapeOrDataDimExprs4Value(shape_operand));
rewriter->ReplaceAllUsesWith(shape_operand, opt_generated_shape.value());
return true;
}
@@ -163,6 +163,8 @@ class GroupOpPattern : public pir::OpRewritePattern<cinn::dialect::GroupOp> {
<< "fusion_yield_values already has key!";
const auto& shape_expr =
shape_analysis.GetShapeOrDataForValue(vec_outs[i]);
// TODO(Hongqing-work): delete this after fix bug of
// cinn_dynamic_reshape_op_pass
shape_analysis.SetShapeOrDataForValue(fusion_op.result(i), shape_expr);
auto find_it = output_value2id.find(vec_outs[i]);
if (find_it != output_value2id.end()) {
@@ -90,12 +90,6 @@ std::optional<pir::Value> InsertGenerateShapeOpToRunFirst(
return std::nullopt;
}

-void CloneDimExprInfo(pir::Value from,
-                      pir::Value to,
-                      const ShapeOrDataDimExprsAccessor& ctx) {
-  ctx.SetShapeOrDataDimExprs(to, ctx.GetShapeOrDataDimExprs(from));
-}

void ReplaceAllUses(pir::Value from, pir::Value to) {
from.ReplaceAllUsesWith(to);
}
@@ -119,7 +113,6 @@ bool RewriteOneGenerateShapeOpToRunFirst(
std::optional<pir::Value> new_shape = InsertGenerateShapeOpToRunFirst(
&builder, block_args, op.out(), dim_exprs_accessor);
if (!new_shape.has_value()) continue;
-    CloneDimExprInfo(op.out(), new_shape.value(), dim_exprs_accessor);
ReplaceAllUses(op.out(), new_shape.value());
EraseGenerateShapeOp(op_iter, block);
return true;
@@ -31,8 +31,6 @@ namespace cinn::dialect {
struct ShapeOrDataDimExprsAccessor {
std::function<const symbol::ShapeOrDataDimExprs&(pir::Value)>
GetShapeOrDataDimExprs;
-  std::function<void(pir::Value, const symbol::ShapeOrDataDimExprs&)>
-      SetShapeOrDataDimExprs;
};

// Returns true if at least one GenerateShapeOp rewrote.
@@ -56,12 +56,7 @@ class GroupOpGenerateShapeOpsPattern
.GetShapeOrDataDimExprs =
[&](pir::Value value) -> const symbol::ShapeOrDataDimExprs& {
return shape_analysis.GetShapeOrDataForValue(value);
-        },
-        .SetShapeOrDataDimExprs =
-            [&](pir::Value value,
-                const symbol::ShapeOrDataDimExprs& dim_exprs) {
-              shape_analysis.SetShapeOrDataForValue(value, dim_exprs);
-            }};
+        }};
return MoveGenerateShapeOpsToPrologue(
ctx, group_op.block(), dim_exprs_accessor);
}
@@ -82,12 +77,7 @@ class MoveGenerateShapeOpsToProloguePass : public pir::Pass {
.GetShapeOrDataDimExprs =
[&](pir::Value value) -> const symbol::ShapeOrDataDimExprs& {
return shape_analysis.GetShapeOrDataForValue(value);
-        },
-        .SetShapeOrDataDimExprs =
-            [&](pir::Value value,
-                const symbol::ShapeOrDataDimExprs& dim_exprs) {
-              shape_analysis.SetShapeOrDataForValue(value, dim_exprs);
-            }};
+        }};
MoveGenerateShapeOpsToPrologue(ctx, group_op.block(), dim_exprs_accessor);
}

@@ -64,9 +64,6 @@ class FusionOpPattern : public pir::OpRewritePattern<cinn::dialect::FusionOp> {
for (size_t i = 0; i < fusion_op.num_results(); ++i) {
rewriter.ReplaceAllUsesWith(fusion_op.result(i),
paddle_op.value()->result(i));
-      shape_analysis.SetShapeOrDataForValue(
-          paddle_op.value()->result(i),
-          shape_analysis.GetShapeOrDataForValue(fusion_op.result(i)));
}

rewriter.EraseOp(fusion_op);
@@ -73,13 +73,11 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
pir::Value broadcasted_x =
rewriter->Build<paddle::dialect::ExpandOp>(x, output_dim_tensor).out();
op->operand(0).set_source(broadcasted_x);
-    shape_analysis.SetShapeOrDataForValue(broadcasted_x, out_shape);
}
if (y_shape.shape() != out_shape.shape()) {
pir::Value broadcasted_y =
rewriter->Build<paddle::dialect::ExpandOp>(y, output_dim_tensor).out();
op->operand(1).set_source(broadcasted_y);
-    shape_analysis.SetShapeOrDataForValue(broadcasted_y, out_shape);
}
return true;
}
@@ -69,24 +69,13 @@ class DynamicExpandOpPattern
op->operand_source(0), broadcast_axes, out_shape);
}();

-    auto& shape_analysis =
-        pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram());
-    shape_analysis.SetShapeOrDataForValue(
-        broadcast->result(0),
-        shape_analysis.GetShapeOrDataForValue(op.result(0)));

if (auto pre_full = broadcast->operand_source(0)
.defining_op()
->dyn_cast<paddle::dialect::FullOp>()) {
auto input_dim = pre_full.result(0)
.type()
.dyn_cast<paddle::dialect::DenseTensorType>()
.dims();
-      if (input_dim.size() == 1 && input_dim[0] == 1) {
-        shape_analysis.SetShapeOrDataForValue(
-            pre_full->result(0),
-            shape_analysis.GetShapeOrDataForValue(op.result(0)));
-      }
}

rewriter.ReplaceAllUsesWith(op->result(0), broadcast->result(0));
3 changes: 3 additions & 0 deletions paddle/pir/include/dialect/shape/utils/dim_expr.h
@@ -19,6 +19,7 @@
#include <optional>
#include <ostream>
#include <string>
#include <unordered_set>
#include <variant>
#include <vector>

@@ -223,6 +224,8 @@ using DimExprConstraint = std::variant<Equal<DimExpr>, Broadcastable<DimExpr>>;

IR_API std::string ToString(const DimExpr& dim_expr);

IR_API std::unordered_set<std::string> GetSymbolSet(const DimExpr& dim_expr);

IR_API std::ostream& operator<<(std::ostream&, const DimExpr& dim_expr);

IR_API std::ostream& operator<<(std::ostream&,
46 changes: 46 additions & 0 deletions paddle/pir/src/dialect/shape/utils/dim_expr.cc
@@ -142,8 +142,54 @@ std::string ListDimExprToString(const List<DimExpr>& dim_exprs,
}
return ret;
}
std::unordered_set<std::string> ListDimExprGetSymbolSet(
const List<DimExpr>& dim_exprs) {
std::unordered_set<std::string> vars;
for (std::size_t i = 0; i < dim_exprs->size(); ++i) {
const auto& inner_vars = GetSymbolSet(dim_exprs->at(i));
vars.insert(inner_vars.begin(), inner_vars.end());
}
return vars;
}
} // namespace

std::unordered_set<std::string> GetSymbolSet(const DimExpr& dim_expr) {
std::unordered_set<std::string> vars;
auto lambdas = common::Overloaded{
[&](std::int64_t dim_expr) { return; },
[&](const std::string& dim_expr) { vars.insert(dim_expr); },
[&](const Negative<DimExpr>& dim_expr) {
const auto& inner_vars = GetSymbolSet(dim_expr->data);
vars.insert(inner_vars.begin(), inner_vars.end());
},
[&](const Reciprocal<DimExpr>& dim_expr) {
const auto& inner_vars = GetSymbolSet(dim_expr->data);
vars.insert(inner_vars.begin(), inner_vars.end());
},
[&](const Add<DimExpr>& dim_expr) {
const auto& inner_vars = ListDimExprGetSymbolSet(dim_expr.operands);
vars.insert(inner_vars.begin(), inner_vars.end());
},
[&](const Mul<DimExpr>& dim_expr) {
const auto& inner_vars = ListDimExprGetSymbolSet(dim_expr.operands);
vars.insert(inner_vars.begin(), inner_vars.end());
},
[&](const Max<DimExpr>& dim_expr) {
const auto& inner_vars = ListDimExprGetSymbolSet(dim_expr.operands);
vars.insert(inner_vars.begin(), inner_vars.end());
},
[&](const Min<DimExpr>& dim_expr) {
const auto& inner_vars = ListDimExprGetSymbolSet(dim_expr.operands);
vars.insert(inner_vars.begin(), inner_vars.end());
},
[&](const Broadcast<DimExpr>& dim_expr) {
const auto& inner_vars = ListDimExprGetSymbolSet(dim_expr.operands);
vars.insert(inner_vars.begin(), inner_vars.end());
}};
std::visit(lambdas, dim_expr.variant());
return vars;
}

std::string ToString(const DimExpr& dim_expr) {
auto lambdas = common::Overloaded{
[](std::int64_t dim_expr) { return std::to_string(dim_expr); },
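
Finally, a hypothetical usage sketch of the new GetSymbolSet helper (not code from this PR; it assumes the DimExpr string/integer constructors and the arithmetic operator overloads declared in dim_expr.h). This is exactly the query ConstructCstrLhsEqRhsReplacement relies on: whether the symbol on one side of a broadcastable constraint occurs anywhere inside the other side.

#include <string>
#include <unordered_set>

#include "paddle/pir/include/dialect/shape/utils/dim_expr.h"

void Demo() {
  symbol::DimExpr s0{"S0"};
  symbol::DimExpr s1{"S1"};
  // Composite expression: S0 * S1 + 32.
  symbol::DimExpr expr = s0 * s1 + symbol::DimExpr{32};
  std::unordered_set<std::string> syms = symbol::GetSymbolSet(expr);
  // syms == {"S0", "S1"}; the integer leaf contributes nothing, and each
  // symbol appears once no matter how often it occurs in the expression.
}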