-
Notifications
You must be signed in to change notification settings - Fork 5.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CINN / PIR] Cinn trivalop fuse (#62088)
* implement FuseFilteredStmtPatterns * update * split trivial op into a single file. * fix compiler complaints * rename StmtIter to StmtPtr * declare group_pattern.InferShardableAxes * refine signature of group_pattern.InferShardableAxes * move group_pattern.InferShardableAxes to group_pattern_util.InferShardableAxes * implement group_pattern_util.InferShardableAxes * add group_pattern_util.InferShardableAxesFromSink * ReversedInferShardableAxes support sinks * update op lower * support multiple sinks in group_pattern_util.InferShardableAxes * update * fix link error * update * remove FusionOp to OpList * update * update * update * update * declare group_pattern_util.h * fix compiler complains * declare group_pattern_util.ClusteringHelper * refine signature of group_pattern_util.ClusterIntoGroupPatternsFromOpList * update op lowr * add todo * minor refine by group_pattern_util.OpSet * update * update * update (#57) * update * update * Cinn trivalop fuse (#58) * fix * refactor StmtFusionHelper by OpTopo * Complete: CreateReduceExpr function. * update * recursive done. * update * Cinn trivalop fuse (#59) * clean all the TODO. * update * fix cluster * remove unused OpTopo.downstream_disconnected_ops * Cinn trivalop fuse (#60) * fix compile rror * update * Cinn trivalop fuse (#61) * add R + T skeleon * add search utils. * update * Cinn trivalop fuse (#62) * push * update * fix * fix transformer * fix * Implement iterator vars fetching in ReduceOp * small fix * add GetOuterIterVars API * fix * fix compile complain * modify GetOutputIters of TrivialOp * remove dumplicate code in visit * implement ClusterIntoGroupPatternsFromOpList * Fix most error in trivial_op.cc. * CreateReduceExpr is OK! * fix * add CheckIterEq * implement group_pattern_util.ClusteringEngine and groupp_pattern_util.ClusteringPolicy * SinkTrivialTransform OK! * update * fix init_tensor name problem. * update * fix compiler complains * refactor ShardableAxesSignature by group_pattern.SoleOutputShardableAxes * split trivial_op.cc * update * implement group_pattern_util.MakeShardableAxesSignature4ReduceOp * update * implement group_pattern_util.MakeEmptyShardableAxesSignature * add helper class group_pattern_util.ShardableAxesProvider * implement group_pattern_util.MakeShardableAxesSignature4BroadcastOp * update * update * fix softmax error.! * fix * update * merge * fix * Implement new OpMergeWithOp and add a relevant flag * update * update * fix reduce_load error. add splitReduceTransform * fix conflict * update * update * update * disable horizontal fusion * fix * Add some VLOG * Fix group cluster bug (#71) * fix * fix dyshape * fix * init split cluster files * update * update * update * spliting * update * spliting * spliting * pattern utils * update * update * clean cmake * update * update * update * fix clustering_engine * fix fusion_helper * update * fix * update * update * update * update * fix * fix some erros * update * update * fix split with num problem * update * fix * fix static issues * fix * init split cluster files (#72) * update * update * update * update * update * update * update * update * update * split shardable axes provider (#73) * update * update * fix broadcast (#75) * update * update * fix * fix code format * fix code format * remove unittest * update * update (#77) * update * update * update --------- Co-authored-by: tc20042008 <[email protected]> Co-authored-by: feifei-111 <[email protected]> Co-authored-by: jiahy0825 <[email protected]> Co-authored-by: zhangbaizhou <[email protected]> Co-authored-by: Baizhou Zhang <[email protected]>
- Loading branch information
1 parent
c3f5747
commit fec0b3d
Showing
46 changed files
with
3,198 additions
and
109 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include <list> | ||
#include <variant> | ||
#include <vector> | ||
|
||
namespace cinn::api { | ||
|
||
template <typename T> | ||
struct ErrorPattern {}; | ||
|
||
// ElementWise/Broadcast/Injective Ops without reduction ancestors. | ||
template <typename T> | ||
struct InjectiveSourcePattern {}; | ||
|
||
// Reduce op | ||
template <typename T> | ||
struct SingleReductionOpPattern {}; | ||
|
||
// ElementWise/Broadcast ops which have shardable dimentions and reduction | ||
// ancestors. | ||
template <typename T> | ||
struct PartialShardablePattern {}; | ||
|
||
// Reduce base pattern | ||
template <typename T> | ||
struct ReductionPattern { | ||
using Nothing = std::monostate; | ||
std::variant<Nothing, InjectiveSourcePattern<T>, PartialShardablePattern<T>> | ||
input; | ||
SingleReductionOpPattern<T> reduce_op_pattern; | ||
|
||
bool HasFusedInput() const { | ||
return !std::holds_alternative<Nothing>(this->input); | ||
} | ||
}; | ||
|
||
// Stmt := IS | R | PS | ||
// ops in StmtPattern will be lowered into a inlined cuda code. | ||
template <typename T> | ||
using StmtPattern = std::variant<InjectiveSourcePattern<T>, | ||
ReductionPattern<T>, | ||
PartialShardablePattern<T>>; | ||
|
||
// Stmts := [Stmt] | ||
template <typename T> | ||
using StmtPatternVec = std::vector<StmtPattern<T>>; | ||
// fuse rules: | ||
// 1. IS * IS -> IS | ||
// 2. PS * PS -> PS | ||
// 3. IS * PS -> PS | ||
// 4. IS * R -> R | ||
// 5. PS * R -> R | ||
// lifting rules: | ||
// 1. R -> Stmts | ||
// 2. PS -> Stmts | ||
// 3. Stmts * Stmts -> Stmts | ||
// OpTopoPattern := Error | Stmts | ||
|
||
template <typename T> | ||
using OpTopoPattern = std::variant<ErrorPattern<T>, StmtPatternVec<T>>; | ||
|
||
} // namespace cinn::api |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
gather_srcs(group_cluster_src SRCS common_utils.cc pattern_node.cc | ||
pattern_graph.cc) | ||
|
||
add_subdirectory(cluster_policy) | ||
|
||
cc_library(group_cluster SRCS ${group_cluster_src}) |
3 changes: 3 additions & 0 deletions
3
paddle/cinn/frontend/group_cluster/cluster_policy/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
gather_srcs(group_cluster_src SRCS general_topo_policy.cc policy_manager.cc) | ||
|
||
add_subdirectory(shardable_axes_policy) |
25 changes: 25 additions & 0 deletions
25
paddle/cinn/frontend/group_cluster/cluster_policy/general_topo_policy.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/cinn/frontend/group_cluster/cluster_policy/general_topo_policy.h" | ||
|
||
namespace cinn::frontend::group_cluster::policy { | ||
|
||
bool GeneralTopoPolicy::CanFuse(const PatternNodePtr upstream, | ||
const PatternNodePtr downstream) { | ||
// TODO(wuzhanfei) topo policy (if lead to loop) | ||
return false; | ||
} | ||
|
||
} // namespace cinn::frontend::group_cluster::policy |
25 changes: 25 additions & 0 deletions
25
paddle/cinn/frontend/group_cluster/cluster_policy/general_topo_policy.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
#include "paddle/cinn/frontend/group_cluster/cluster_policy/policy_manager.h" | ||
|
||
namespace cinn::frontend::group_cluster::policy { | ||
|
||
class GeneralTopoPolicy final : virtual public Policy { | ||
public: | ||
bool CanFuse(const PatternNodePtr upstream, const PatternNodePtr downstream); | ||
}; | ||
|
||
} // namespace cinn::frontend::group_cluster::policy |
28 changes: 28 additions & 0 deletions
28
paddle/cinn/frontend/group_cluster/cluster_policy/policy_manager.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/cinn/frontend/group_cluster/cluster_policy/policy_manager.h" | ||
#include "paddle/common/enforce.h" | ||
|
||
namespace cinn::frontend::group_cluster::policy { | ||
|
||
bool PolicyManager::CanFuse(const PatternNodePtr upstream, | ||
const PatternNodePtr downstream) { | ||
for (const auto& policy : policies_) { | ||
if (!policy->CanFuse(upstream, downstream)) return false; | ||
} | ||
return true; | ||
} | ||
|
||
} // namespace cinn::frontend::group_cluster::policy |
39 changes: 39 additions & 0 deletions
39
paddle/cinn/frontend/group_cluster/cluster_policy/policy_manager.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include "paddle/cinn/frontend/group_cluster/pattern_node.h" | ||
|
||
namespace cinn::frontend::group_cluster::policy { | ||
|
||
class Policy { | ||
public: | ||
virtual bool CanFuse(const PatternNodePtr upstream, | ||
const PatternNodePtr downstream) = 0; | ||
}; | ||
|
||
using PolicyPtr = std::shared_ptr<Policy>; | ||
|
||
class PolicyManager { | ||
public: | ||
explicit PolicyManager(const std::vector<PolicyPtr>& policies) | ||
: policies_(policies) {} | ||
bool CanFuse(const PatternNodePtr upstream, const PatternNodePtr downstream); | ||
|
||
private: | ||
std::vector<PolicyPtr> policies_; | ||
}; | ||
|
||
} // namespace cinn::frontend::group_cluster::policy |
2 changes: 2 additions & 0 deletions
2
paddle/cinn/frontend/group_cluster/cluster_policy/shardable_axes_policy/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
gather_srcs(group_cluster_src SRCS shardable_axes_base.cc | ||
shardable_axes_policy.cc) |
Oops, something went wrong.