From 0922d8dc69ecdd581345900b5753c674453d6ed4 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Thu, 11 Apr 2024 09:56:49 +0000 Subject: [PATCH 01/19] [CINN] Support horizontal fusion --- paddle/cinn/api/op_topo_pattern.h | 2 +- .../cluster_policy/relative_judge_policy.h | 2 +- .../frontend/group_cluster/pattern_graph.cc | 4 +- test/ir/pir/cinn/test_horizontal_fusion.py | 63 +++++++++++++++++++ 4 files changed, 68 insertions(+), 3 deletions(-) create mode 100644 test/ir/pir/cinn/test_horizontal_fusion.py diff --git a/paddle/cinn/api/op_topo_pattern.h b/paddle/cinn/api/op_topo_pattern.h index 34f17fbfde9e0..96e9484087110 100644 --- a/paddle/cinn/api/op_topo_pattern.h +++ b/paddle/cinn/api/op_topo_pattern.h @@ -31,7 +31,7 @@ struct InjectiveSourcePattern {}; template struct SingleReductionOpPattern {}; -// ElementWise/Broadcast ops which have shardable dimentions and reduction +// ElementWise/Broadcast ops which have shardable dimensions and reduction // ancestors. template struct PartialShardablePattern {}; diff --git a/paddle/cinn/frontend/group_cluster/cluster_policy/relative_judge_policy.h b/paddle/cinn/frontend/group_cluster/cluster_policy/relative_judge_policy.h index e98b68dc893af..9352a4efe1261 100644 --- a/paddle/cinn/frontend/group_cluster/cluster_policy/relative_judge_policy.h +++ b/paddle/cinn/frontend/group_cluster/cluster_policy/relative_judge_policy.h @@ -155,7 +155,7 @@ static ValueDimRelation CreateOpRelativenessForReduce(pir::Operation* op) { int out_idx = 0; bool keep_dim = GetReduceOpKeepDims(op); for (int i = 0; i < input_rank; i++) { - if (std::find(reduce_axis_idx.begin(), reduce_axis_idx.end(), i) != + if (std::find(reduce_axis_idx.begin(), reduce_axis_idx.end(), i) == reduce_axis_idx.end()) { res[ValueDim(op->operand_source(0), i)] [ValueDim(op->result(0), out_idx)] = true; diff --git a/paddle/cinn/frontend/group_cluster/pattern_graph.cc b/paddle/cinn/frontend/group_cluster/pattern_graph.cc index bbd49d1b17503..465613db8e792 100644 --- a/paddle/cinn/frontend/group_cluster/pattern_graph.cc +++ b/paddle/cinn/frontend/group_cluster/pattern_graph.cc @@ -91,7 +91,9 @@ void PatternGraph::ReduceLiftReduceTree() { void PatternGraph::HorizontalFusion() { GraphTransformer, + Or, + StmtPatternGraphMatcher>, + StmtPatternGraphMatcher>, LiftToHorizontalFusionPatternOperation>(this); GraphTransformer Date: Fri, 12 Apr 2024 02:32:02 +0000 Subject: [PATCH 02/19] Change data type --- test/ir/pir/cinn/test_horizontal_fusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ir/pir/cinn/test_horizontal_fusion.py b/test/ir/pir/cinn/test_horizontal_fusion.py index 83556ec647bf8..b9bcf8b7010a1 100644 --- a/test/ir/pir/cinn/test_horizontal_fusion.py +++ b/test/ir/pir/cinn/test_horizontal_fusion.py @@ -35,7 +35,7 @@ def setUp(self): self.prepare_data() def prepare_data(self): - self.x = paddle.randn([256, 128], dtype="bfloat16") + self.x = paddle.randn([256, 128], dtype="float32") self.x.stop_gradient = True def check_jit_kernel_info(self, static_fn): From a3bb64b335fa1f4dbd190432d48bf160b9f2c0f5 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Mon, 22 Apr 2024 08:56:39 +0000 Subject: [PATCH 03/19] Support horizontal fusion --- paddle/cinn/operator_fusion/backend/pattern.h | 17 ---- .../cinn/operator_fusion/frontend/pattern.h | 17 ---- paddle/cinn/operator_fusion/pattern.h | 21 ++++- paddle/cinn/operator_fusion/pattern_graph.cc | 13 ++- paddle/cinn/operator_fusion/pattern_graph.h | 93 +++++++++++++++---- 5 files changed, 103 insertions(+), 58 deletions(-) diff --git a/paddle/cinn/operator_fusion/backend/pattern.h b/paddle/cinn/operator_fusion/backend/pattern.h index 7006fb02b829e..4df4e9419cc62 100644 --- a/paddle/cinn/operator_fusion/backend/pattern.h +++ b/paddle/cinn/operator_fusion/backend/pattern.h @@ -68,21 +68,4 @@ struct UnsupportPattern { static std::string name() { return "Unsupport"; } }; -template <> -struct HorizontalFusionPattern { - explicit HorizontalFusionPattern( - const std::vector>& patterns) - : patterns_(patterns) {} - std::vector> patterns_; - std::vector ops() const { - std::vector result; - for (const auto& pattern : patterns_) { - auto ops = GetOpsInPattern(pattern); - ExtendVector(&result, ops); - } - return result; - } - static std::string name() { return "HorizontalFusionPattern"; } -}; - } // namespace cinn::fusion diff --git a/paddle/cinn/operator_fusion/frontend/pattern.h b/paddle/cinn/operator_fusion/frontend/pattern.h index e267483e56586..5c2641a808ce1 100644 --- a/paddle/cinn/operator_fusion/frontend/pattern.h +++ b/paddle/cinn/operator_fusion/frontend/pattern.h @@ -72,21 +72,4 @@ struct UnsupportPattern { static std::string name() { return "Unsupport"; } }; -template <> -struct HorizontalFusionPattern { - explicit HorizontalFusionPattern( - const std::vector>& patterns) - : patterns_(patterns) {} - std::vector> patterns_; - std::vector ops() const { - std::vector result; - for (const auto& pattern : patterns_) { - auto ops = GetOpsInPattern(pattern); - ExtendVector(&result, ops); - } - return result; - } - static std::string name() { return "HorizontalFusionPattern"; } -}; - } // namespace cinn::fusion diff --git a/paddle/cinn/operator_fusion/pattern.h b/paddle/cinn/operator_fusion/pattern.h index 908b4a4348bfc..73558dd1535ad 100644 --- a/paddle/cinn/operator_fusion/pattern.h +++ b/paddle/cinn/operator_fusion/pattern.h @@ -79,9 +79,26 @@ struct ReduceTreePlusTrivialPattern { }; template -class UnsupportPattern {}; +struct StmtPattern; + +template +struct UnsupportPattern {}; + template -class HorizontalFusionPattern {}; +struct HorizontalFusionPattern { + explicit HorizontalFusionPattern(const std::vector>& patterns) + : patterns_(patterns) {} + std::vector> patterns_; + std::vector ops() const { + std::vector result; + for (const auto& pattern : patterns_) { + auto ops = GetOpsInPattern(pattern); + ExtendVector(&result, ops); + } + return result; + } + static std::string name() { return "HorizontalFusionPattern"; } +}; template using StmtPatternBase = std::variant, diff --git a/paddle/cinn/operator_fusion/pattern_graph.cc b/paddle/cinn/operator_fusion/pattern_graph.cc index a8ab68cf809b3..4c3a745cbb95e 100644 --- a/paddle/cinn/operator_fusion/pattern_graph.cc +++ b/paddle/cinn/operator_fusion/pattern_graph.cc @@ -99,14 +99,17 @@ void PatternGraph::ReduceLiftReduceTree() { template void PatternGraph::HorizontalFusion() { - GraphTransformer>, - LiftToHorizontalFusionPatternOperation>(this); + GraphTransformer< + NodePattern, + T, + Or>, + StmtPatternGraphMatcher>>, + StmtPatternGraphMatcher>>, + LiftToHorizontalFusionPatternOperation>(this); GraphTransformer, HorizontalFusionOperation>(this); } diff --git a/paddle/cinn/operator_fusion/pattern_graph.h b/paddle/cinn/operator_fusion/pattern_graph.h index e6ba134262349..4103e6b1ce491 100644 --- a/paddle/cinn/operator_fusion/pattern_graph.h +++ b/paddle/cinn/operator_fusion/pattern_graph.h @@ -322,29 +322,88 @@ struct CanFuseReduceTreeAndTrivialMatcher { } }; +template struct HorizontalFusionConstrain { - template bool operator()(const PatternGraph& graph, - const PatternNodePtr& first, - const PatternNodePtr& second) { - if (!StmtPatternGraphMatcher>()(graph, first)) { + const PatternNodePtr& lhs, + const PatternNodePtr& rhs) { + if (!StmtPatternGraphMatcher>()(graph, lhs)) { return false; } - if (!StmtPatternGraphMatcher>()(graph, second)) { + if (!StmtPatternGraphMatcher>()(graph, rhs)) { return false; } - const auto& first_dim = first->sink_op() - ->result(0) - .type() - .template dyn_cast() - .dims(); - const auto& second_dim = second->sink_op() - ->result(0) - .type() - .template dyn_cast() - .dims(); - return graph.topo_manager().CanFuse(first, second) && - first_dim == second_dim; + const auto& lhs_pattern = + lhs->stmt_pattern().template Get(); + const auto& rhs_pattern = + rhs->stmt_pattern().template Get(); + + return graph.topo_manager().CanFuse(lhs, rhs) && + PatternDimMatch(lhs_pattern.patterns_.back(), + rhs_pattern.patterns_.back()); + } + + const pir::DenseTensorType::Dim& GetInputDim(pir::Operation* op) { + return op->operand_source(0) + .type() + .template dyn_cast() + .dims(); + } + + const pir::DenseTensorType::Dim& GetOutputDim(pir::Operation* op) { + return op->result(0) + .type() + .template dyn_cast() + .dims(); + } + + template