Skip to content

Commit

Permalink
[CINN/Fusion] horizontal support dynamic shape and enhance fusion abi…
Browse files Browse the repository at this point in the history
…lity (PaddlePaddle#63913)

* [CINN] Support horizontal fusion

* Change data type

* Support horizontal fusion

* Fix compile error

* add topo sort in backend fusion

* horizontal support dynamic shape and enhance fusion ability

* fix

* xx

* fix some bugs

* fix

* xxxx

* fix

* horizontal operator fusion enhance

* fix

* fix

* fix

* fix

* fix by code review

* fix

---------

Co-authored-by: jiahongyu <[email protected]>
  • Loading branch information
2 people authored and co63oc committed May 18, 2024
1 parent 806936f commit d3b786b
Show file tree
Hide file tree
Showing 24 changed files with 923 additions and 170 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ void ApplyDivideGroupOpToFusionOpPass(
std::shared_ptr<pir::PassManager> pass_manager = CreatePassManager();
if (FLAGS_group_schedule_tiling_first) {
pass_manager->AddPass(cinn::dialect::ir::CreateCinnGroupClusterPass());
pass_manager->AddPass(cinn::dialect::ir::CreateAddStoreInFusionOpPass());
// pass_manager->AddPass(cinn::dialect::ir::CreateAddStoreInFusionOpPass());
} else {
pass_manager->AddPass(
cinn::dialect::ir::CreateDivideGroupOpToFusionOpPass());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ std::vector<GroupClusterNode> GroupSplit(cinn::dialect::GroupOp group_op) {
std::function<cinn::fusion::FrontendContent(pir::Operation*)> func =
[](pir::Operation* op) { return cinn::fusion::FrontendContent(op); };
const auto& contents = cinn::fusion::MapVector(group_op.GetOperators(), func);
auto cluster_result = cinn::fusion::ClusterOps(contents);
auto cluster_result = cinn::fusion::ClusterOps(contents, {});
std::vector<std::vector<pir::Operation*>> result;
std::transform(
cluster_result.begin(),
Expand Down Expand Up @@ -390,6 +390,9 @@ class CinnGroupClusterPass : public pir::PatternRewritePass {
}

// Decides whether this pass should run on `op`: a FusionOp is always
// rejected; any other operation is accepted when it has at least one region.
bool CanApplyOn(pir::Operation* op) const override {
  const bool is_fusion_op = op->isa<cinn::dialect::FusionOp>();
  return !is_fusion_op && op->num_regions() > 0;
}
};
Expand Down
5 changes: 3 additions & 2 deletions paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ BucketLoweredFuncsWrapper OpLowererImpl::BucketLower(

// =========== OpFusion ============

func_bodies = OperationFusion(ops, func_bodies);
// VLOG(4) << "Bucket Lower output values is : " << group->output_values();
func_bodies = OperationFusion(ops, func_bodies, group->output_values());
const auto& fusion_group_info = GetFusionGroupInfo(func_bodies);

// =========== CodeGen And Optimizer ================
Expand Down Expand Up @@ -728,7 +729,7 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
group->mut_output_names().clear();

// collect all output tensor.
for (auto op_result : group->GetGroupOutputValues()) {
for (auto op_result : group->output_values()) {
if (tensor_map.count(op_result) == 0) {
continue;
}
Expand Down
67 changes: 48 additions & 19 deletions paddle/cinn/hlir/framework/pir/trivial_op_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,17 +187,21 @@ std::vector<ir::Var> GetOutputIters(const FusibleOp& op) {
return AppendBound(std::visit(Visitor(), op), _GetRootExpr(op));
}

// Returns the iter_values of the single ScheduleBlockRealize under `expr`
// that passes the ScheduleBlockRealizeIsNotInit filter, converted from
// expressions to variables via ComposeUtils::ExprVec2VarVec.
std::vector<ir::Var> GetAllIterVars(const ir::Expr& expr) {
  const auto compute_realize_finder =
      ExprSetFinderUtils::ChildScheduleBlockRealizes *
      ExprSetFinderUtils::ScheduleBlockRealizeIsNotInit;
  ir::Expr compute_realize = compute_realize_finder.GetSingle(expr);

  const auto* realize_node = compute_realize.As<ir::ScheduleBlockRealize>();
  return ComposeUtils::ExprVec2VarVec(realize_node->iter_values);
}

std::vector<ir::Var> GetReduceIters(const ReduceOp& op) {
auto GetUnorderedAllIterVars = [](const ReduceOp& op) {
ir::Expr compute_schedule_block_realize =
(ExprSetFinderUtils::ChildScheduleBlockRealizes *
ExprSetFinderUtils::ScheduleBlockRealizeIsNotInit)
.GetSingle(_GetRootExpr(op));

const std::vector<Expr>& all_iter_expr =
compute_schedule_block_realize.As<ir::ScheduleBlockRealize>()
->iter_values;
return ComposeUtils::ExprVec2VarVec(all_iter_expr);
return GetAllIterVars(_GetRootExpr(op));
};

// Iter Vars not appearing in outer_iter_vars are pushed into
Expand Down Expand Up @@ -560,16 +564,39 @@ std::pair<TrivialOp, ReduceOp> SplitReduceOp(const ReduceOp& reduce_op) {
return std::make_pair(result_trivial, result_reduce);
}

// Returns the loop variables of every ir::For inside `expr` that is an
// ancestor of a ScheduleBlockRealize accepted by
// ScheduleBlockRealizeIsNotInit. The result follows the order in which the
// composed finder yields the For nodes.
std::vector<ir::Var> GetAllForIters(const ir::Expr& expr) {
  namespace finder =
      cinn::hlir::framework::pir::trivial_fusion_detail::ExprSetFinderUtils;

  const auto& enclosing_fors =
      (finder::ChildScheduleBlockRealizes *
       finder::ScheduleBlockRealizeIsNotInit * finder::FindFather(expr) *
       finder::IsFor)(expr);

  std::vector<ir::Var> vars;
  for (const auto& enclosing_for : enclosing_fors) {
    const auto* for_node = enclosing_for.As<ir::For>();
    vars.push_back(for_node->loop_var);
  }

  VLOG(4) << "GetAllForIters : " << expr
          << "\n var is : " << utils::Join(vars, ",");
  return vars;
}

} // namespace trivial_fusion_detail

std::vector<ir::Expr> OperationFusion(
const std::vector<::pir::Operation*>& original_ops,
const std::vector<ir::Expr>& op_compute_bodies) {
PADDLE_ENFORCE_EQ(FLAGS_group_schedule_tiling_first,
true,
::common::errors::PreconditionNotMet(
"TrivialFusion must be used with tiling first, set "
"FLAGS_group_schedule_tiling_first=1"));
const std::vector<ir::Expr>& op_compute_bodies,
const std::vector<::pir::Value>& outputs) {
PADDLE_ENFORCE(FLAGS_group_schedule_tiling_first,
::common::errors::PreconditionNotMet(
"TrivialFusion must be used with tiling first, set "
"FLAGS_group_schedule_tiling_first=1"));
const auto& ops = trivial_fusion_detail::FilterVector(
original_ops, [](const ::pir::Operation* op) {
if (op->name() == "cinn_op.generate_shape") {
Expand All @@ -581,10 +608,9 @@ std::vector<ir::Expr> OperationFusion(
std::vector<cinn::fusion::BackendContent> contents;
for (int i = 0; i < ops.size(); i++) {
contents.emplace_back(ops[i], op_compute_bodies[i]);
// contents.emplace_back(ops[i]);
}
const auto& fusion_nodes =
cinn::fusion::ClusterOps<cinn::fusion::BackendStage>(contents);
cinn::fusion::ClusterOps<cinn::fusion::BackendStage>(contents, outputs);

PADDLE_ENFORCE_EQ(fusion_nodes.size(),
1,
Expand All @@ -601,6 +627,8 @@ std::vector<ir::Expr> OperationFusion(

FusionGroupInfo GetFusionGroupInfo(
const std::vector<ir::Expr>& op_compute_bodies) {
using trivial_fusion_detail::AppendBound;
using trivial_fusion_detail::GetAllForIters;
using trivial_fusion_detail::ReduceOp;
using trivial_fusion_detail::ComposeUtils::ConcatVector;
using trivial_fusion_detail::ExprSetFinderUtils::ChildScheduleBlockRealizes;
Expand All @@ -618,7 +646,7 @@ FusionGroupInfo GetFusionGroupInfo(
ReduceOp op = ReduceOp(body);
if (group_info.reduce_var_name.empty()) {
std::vector<ir::Var> all_iters =
ConcatVector(GetOutputIters(op), GetReduceIters(op));
AppendBound(GetAllForIters(body), body);
std::transform(all_iters.begin(),
all_iters.end(),
std::back_inserter(group_info.loop_ranges),
Expand All @@ -631,7 +659,8 @@ FusionGroupInfo GetFusionGroupInfo(
return (int64_t)-1;
}
});
std::vector<ir::Var> reduce_iters = GetReduceIters(op);
std::vector<ir::Var> reduce_iters = fusion::FilterVector(
all_iters, [](const ir::Var& var) { return var->is_reduce_axis; });
for (int64_t i = all_iters.size() - reduce_iters.size();
i < all_iters.size();
i++) {
Expand Down
7 changes: 6 additions & 1 deletion paddle/cinn/hlir/framework/pir/trivial_op_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ FusibleOp SinkTrivialLoopAlign(TrivialOp trivial_op,
ReduceOp reduce_op,
std::vector<size_t> fake_reduce_iter_idx);

// Returns the iter_values of the single non-root ScheduleBlockRealize in
// `expr` that passes the ScheduleBlockRealizeIsNotInit filter, converted to
// ir::Var.
std::vector<ir::Var> GetAllIterVars(const ir::Expr& expr);

// Returns the loop_var of every ir::For in `expr` that encloses a
// ScheduleBlockRealize passing the ScheduleBlockRealizeIsNotInit filter.
std::vector<ir::Var> GetAllForIters(const ir::Expr& expr);

} // namespace trivial_fusion_detail

struct FusionGroupInfo {
Expand All @@ -178,7 +182,8 @@ FusionGroupInfo GetFusionGroupInfo(

std::vector<ir::Expr> OperationFusion(
const std::vector<::pir::Operation*>& ops,
const std::vector<ir::Expr>& op_compute_bodies);
const std::vector<ir::Expr>& op_compute_bodies,
const std::vector<::pir::Value>& outputs);

} // namespace pir
} // namespace framework
Expand Down
77 changes: 71 additions & 6 deletions paddle/cinn/hlir/framework/pir/trivial_op_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ std::vector<ir::Expr> VarVec2ExprVec(const std::vector<ir::Var>& in) {
std::vector<ir::Expr> GetEachTensorLoadExpr(const ir::Expr& body,
const ir::Tensor& tensor) {
VLOG(4) << "GetEachTensorLoadExpr: " << tensor;
std::set<Expr> load_exprs = cinn::ir::ir_utils::CollectIRNodesWithoutTensor(
std::vector<Expr> load_exprs = cinn::ir::ir_utils::CollectIRNodesInOrder(
body, [&tensor](const Expr* expr) {
return expr->As<ir::Load>() && expr->As<ir::Load>()->is_addr_tensor() &&
expr->As<ir::Load>()->tensor.as_tensor_ref()->name ==
Expand All @@ -65,7 +65,7 @@ std::vector<ir::Expr> GetEachTensorLoadExpr(const ir::Expr& body,
for (auto& t : load_exprs) {
VLOG(4) << "GetEachTensorLoadExpr Found: " << t << " " << t.ptr();
}
return std::vector(load_exprs.begin(), load_exprs.end());
return load_exprs;
}

MappingTargetExprToDestExprMutator::MappingTargetExprToDestExprMutator(
Expand All @@ -83,6 +83,16 @@ void MappingTargetExprToDestExprMutator::Visit(const ir::Load* load, Expr* op) {
IRMutator::Visit(load, op);
}
}

// Replaces the visited For node with dest_ when it is the exact node held by
// source_ (pointer identity); otherwise keeps descending through the default
// mutator.
void MappingTargetExprToDestExprMutator::Visit(const ir::For* for_node,
                                               Expr* op) {
  if (for_node != source_.ptr()) {
    IRMutator::Visit(for_node, op);
    return;
  }
  *op = dest_;
}

void MappingTargetExprToDestExprMutator::Visit(const ir::Store* store,
Expr* op) {
if (store == source_.ptr()) {
Expand All @@ -91,6 +101,7 @@ void MappingTargetExprToDestExprMutator::Visit(const ir::Store* store,
IRMutator::Visit(store, op);
}
}

void MappingTargetExprToDestExprMutator::Visit(const ir::Reduce* reduce,
Expr* op) {
if (reduce == source_.ptr()) {
Expand Down Expand Up @@ -196,7 +207,7 @@ ExprSetFinder ExprSetFinder::operator*(ExprSetFinder x) const {
std::vector<ir::Expr> res;
for (const auto& r : rs) {
const auto& x_res = x.f_(r);
res.insert(res.begin(), x_res.begin(), x_res.end());
res.insert(res.end(), x_res.begin(), x_res.end());
}
return res;
};
Expand Down Expand Up @@ -246,6 +257,15 @@ ExprSetFinder ScheduleBlockRealizeNotRoot = FilterMaker(
},
"ScheduleBlockRealizeNotRoot");

// Filter that keeps only ScheduleBlockRealize expressions whose schedule
// block name contains the substring "root".
ExprSetFinder ScheduleBlockRealizeIsRoot = FilterMaker(
    [](const ir::Expr& e) -> bool {
      const auto* realize = e.As<ir::ScheduleBlockRealize>();
      if (!realize) {
        return false;
      }
      const auto* block = realize->schedule_block.As<ir::ScheduleBlock>();
      return block->name.find("root") != std::string::npos;
    },
    "ScheduleBlockRealizeIsRoot");

ExprSetFinder ScheduleBlockRealizeIsNotInit = FilterMaker(
[](const ir::Expr& e) -> bool {
return (e.As<ir::ScheduleBlockRealize>() &&
Expand Down Expand Up @@ -277,6 +297,12 @@ ExprSetFinder ChildScheduleBlockRealizes =
"ChildScheduleBlockRealizes") *
ScheduleBlockRealizeNotRoot;

// Finder for the ScheduleBlockRealize nodes under an expression that also
// pass the ScheduleBlockRealizeIsRoot filter.
ExprSetFinder ChildRootScheduleBlockRealizes =
    Collector(
        [](const ir::Expr* node) -> bool {
          return node->As<ir::ScheduleBlockRealize>() != nullptr;
        },
        "ChildScheduleBlockRealizes") *
    ScheduleBlockRealizeIsRoot;

ExprSetFinder IsForIterVar(const ir::Var& var) {
return FilterMaker(
[var = var](const ir::Expr& e) -> bool {
Expand Down Expand Up @@ -304,7 +330,7 @@ ExprSetFinder ChildTensorLoads = Collector(

// Finder for every ir::Store node under an expression for which
// is_addr_tensor() holds.
ExprSetFinder ChildTensorStores = Collector(
    [](const ir::Expr* e) {
      // The node must be tested as a Store before it is queried as one.
      // Testing As<ir::Load>() here would let Load nodes through and then
      // dereference a null Store pointer.
      return e->As<ir::Store>() && e->As<ir::Store>()->is_addr_tensor();
    },
    "ChildTensorStores");

Expand All @@ -324,8 +350,10 @@ ExprSetFinder FindFather(const ir::Expr& root) {
const auto& f = [&](const auto& child) -> ExprSet {
ExprSetFinder find_child =
Collector([child](const ir::Expr* e) { return *e == child; });
const auto& father_collector = Collector(
[&](const ir::Expr* current) { return !find_child(*current).empty(); });
const auto& father_collector = Collector([&](const ir::Expr* current) {
auto res = (*current != child) && !find_child(*current).empty();
return res;
});
return father_collector(root);
};
return ExprSetFinder(f, "FindFather");
Expand Down Expand Up @@ -373,6 +401,35 @@ ExprTransformer WrapForsTransformer(const std::vector<ir::Var>& vs) {
return ExprTransformer(f);
}

// Builds a transformer that inserts one extra loop level into an expression.
//
// The transformer copies its input with IRCopy, locates exactly one node in
// the copy with `followed_finder`, and applies
// WrapForTransformer(to_append_var) to that node's body:
//   - ir::For: the wrapped body is placed in a new single-statement Block.
//   - ir::ScheduleBlockRealize: the body of its ScheduleBlock is wrapped.
// Any other node type raises an error. The input expression is not modified;
// the modified copy is returned.
ExprTransformer UnsqueezeForTransformer(
    const ExprSetFinderUtils::ExprSetFinder& followed_finder,
    const ir::Var& to_append_var) {
  // Capture by value: the returned transformer may outlive the arguments
  // (callers can pass temporaries), so holding references to the parameters
  // would leave the closure dangling.
  const auto& unsqueeze_for_func =
      [followed_finder, to_append_var](const ir::Expr& e) -> ir::Expr {
    auto copied_e = ir::ir_utils::IRCopy(e);
    ir::Expr followed_expr = followed_finder.GetSingle(copied_e);
    VLOG(6) << "UnsqueezeForTransformer: for insert after " << followed_expr;
    if (followed_expr.As<ir::For>()) {
      followed_expr.As<ir::For>()->body = ir::Block::Make({WrapForTransformer(
          to_append_var)(followed_expr.As<ir::For>()->body)});
    } else if (followed_expr.As<ir::ScheduleBlockRealize>()) {
      const auto& schedule_block = followed_expr.As<ir::ScheduleBlockRealize>()
                                       ->schedule_block.As<ir::ScheduleBlock>();
      schedule_block->body =
          WrapForTransformer(to_append_var)(schedule_block->body);
    } else {
      PADDLE_THROW(
          "UnsqueezeForTransformer: only support insert after a (For / "
          "ScheduleBlockRealizer): %s",
          followed_expr);
    }
    VLOG(6) << "UnsqueezeForTransformer: After changed: " << copied_e;
    return copied_e;
  };
  return ExprTransformer(unsqueeze_for_func);
}

ExprTransformer ChangeTensorLoadTransformer(const ir::Tensor& tensor,
const ir::Expr& dst_load) {
const auto& f = [&](const ir::Expr& e) -> ir::Expr {
Expand Down Expand Up @@ -420,6 +477,14 @@ ExprTransformer ChangeVarTransformer(const std::vector<ir::Var>& target_vars,
return ExprTransformer(f);
}

// Builds a transformer that returns
// ComposeUtils::CopyedReplaceExpr(e, target_vars, dest_expr) for its input.
// Both vectors are copied into the closure, so the transformer does not
// depend on the lifetime of the arguments.
ExprTransformer ReplaceVarTransformer(const std::vector<ir::Var>& target_vars,
                                      const std::vector<ir::Expr>& dest_expr) {
  const auto& replace_func =
      [target_vars, dest_expr](const ir::Expr& input) -> ir::Expr {
    return ComposeUtils::CopyedReplaceExpr(input, target_vars, dest_expr);
  };
  return ExprTransformer(replace_func);
}

// True when at least one of the two operands has a bool type.
bool IsReduceBool(const ir::Expr& lhs, const ir::Expr& rhs) {
  if (lhs.type().is_bool()) {
    return true;
  }
  return rhs.type().is_bool();
}
Expand Down
13 changes: 12 additions & 1 deletion paddle/cinn/hlir/framework/pir/trivial_op_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ struct MappingTargetExprToDestExprMutator : public ir::IRMutator<> {
void Visit(const ir::Load* load, Expr* op) override;
void Visit(const ir::Store* store, Expr* op) override;
void Visit(const ir::Reduce* reduce, Expr* op) override;
void Visit(const ir::For* for_node, Expr* op) override;

private:
ir::Expr source_;
Expand Down Expand Up @@ -132,7 +133,7 @@ template <typename Teller>
ExprSetFinder Collector(Teller t, std::string name = "") {
return ExprSetFinder(
[=](const ir::Expr& x) -> ExprSet {
const auto& rs = cinn::ir::ir_utils::CollectIRNodesWithoutTensor(x, t);
const auto& rs = cinn::ir::ir_utils::CollectIRNodesInOrder(x, t);
return std::vector(rs.begin(), rs.end());
},
name);
Expand Down Expand Up @@ -170,6 +171,8 @@ extern ExprSetFinder ChildScheduleBlocks;

extern ExprSetFinder ChildScheduleBlockRealizes;

extern ExprSetFinder ChildRootScheduleBlockRealizes;

extern ExprSetFinder For2Min;

extern ExprSetFinder For2Max;
Expand Down Expand Up @@ -230,6 +233,14 @@ std::vector<ir::Var> CreateInnerBlockVars(
ExprTransformer ChangeVarTransformer(const std::vector<ir::Var>& target_vars,
const std::vector<ir::Var>& dest_vars);

// Returns a transformer that applies ComposeUtils::CopyedReplaceExpr to its
// input with `target_vars` and `dest_exprs`.
// NOTE(review): presumably target_vars[i] is replaced by dest_exprs[i] on a
// copy of the input — confirm against CopyedReplaceExpr.
ExprTransformer ReplaceVarTransformer(const std::vector<ir::Var>& target_vars,
const std::vector<ir::Expr>& dest_exprs);

// Returns a transformer that works on a copy of its input: it locates a single
// node with `followed_finder` and applies WrapForTransformer(to_append_var) to
// that node's body. Only ir::For and ir::ScheduleBlockRealize targets are
// supported; any other node type raises an error.
ExprTransformer UnsqueezeForTransformer(
const ExprSetFinderUtils::ExprSetFinder& followed_finder,
const ir::Var& to_append_var);

ExprTransformer SubstitudeByScheduleBlockRealize(const ir::Expr& realize);

ExprTransformer WrapScheduleRealizer(const std::vector<ir::Var>& block_vars,
Expand Down
Loading

0 comments on commit d3b786b

Please sign in to comment.