Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#68 from tc20042008/xk-cinn-trivalop-fuse
Browse files Browse the repository at this point in the history
refactor ShardableAxesSignature by group_pattern.SoleOutputShardableAxes
  • Loading branch information
tc20042008 authored Mar 14, 2024
2 parents 5180b55 + cda4b1b commit 35506a8
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 28 deletions.
6 changes: 5 additions & 1 deletion paddle/cinn/frontend/group_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,12 @@ struct ShardableAxesUtil {
}
};

struct SoleOutputShardableAxes {
ShardableAxes shardable_axes;
};

struct ShardableAxesSignature {
ShardableAxes output_shardable_axes;
SoleOutputShardableAxes sole_output_sa;
std::unordered_map<OpAndOperandIndex, ShardableAxes> input_shardable_axes;
};

Expand Down
64 changes: 37 additions & 27 deletions paddle/cinn/frontend/group_pattern_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ struct OpTopo {

};

int GetOutputShardableAxesResultIdx(const pir::Operation* op) {
return 0;
}

OpPatternKind GetOpPatternKind(const ::pir::Operation* node) {
return hlir::framework::pir::CompatibleInfo::OpKind(*node);
}
Expand Down Expand Up @@ -215,11 +219,11 @@ ShardableAxesSignature MakeShardableAxesSignature4ElementWiseOp(
rank = GetRank(op->operand_source(i));
}
}
CHECK_EQ(op->num_results(), 1);
const int result_idx = GetOutputShardableAxesResultIdx(op);
if (rank.has_value()) {
CHECK_EQ(rank.value(), GetRank(op->result(0)));
CHECK_EQ(rank.value(), GetRank(op->result(result_idx)));
} else {
rank = GetRank(op->result(0));
rank = GetRank(op->result(result_idx));
}
CHECK(rank.has_value());
return rank.value();
Expand All @@ -231,7 +235,9 @@ ShardableAxesSignature MakeShardableAxesSignature4ElementWiseOp(
input_shardable_axes[OpAndOperandIndex{op, i}] = output_shardable_axes;
}
return ShardableAxesSignature{
.output_shardable_axes = output_shardable_axes,
.sole_output_sa = SoleOutputShardableAxes{
.shardable_axes=output_shardable_axes,
},
.input_shardable_axes = input_shardable_axes,
};
}
Expand Down Expand Up @@ -276,9 +282,11 @@ std::unordered_map<pir::Value, ShardableAxes> ReversedInferShardableAxes(
};
reversed_walker(sinks.begin(), sinks.end(), [&](const auto* op) {
auto shardable_axes_sig = MakeShardableAxesSignature4Op(op);
const auto& sole_output_sa = shardable_axes_sig.sole_output_sa;
const int result_idx = GetOutputShardableAxesResultIdx(op);
const auto& old2new = ShardableAxesUtil::GetOldName2NewName(
shardable_axes_sig.output_shardable_axes,
value2shardable_axes.at(op->result(0)));
sole_output_sa.shardable_axes,
value2shardable_axes.at(op->result(result_idx)));
for (auto& pair : shardable_axes_sig.input_shardable_axes) {
const auto& [my_op, input_idx] = pair.first;
CHECK_EQ(my_op, op);
Expand All @@ -296,8 +304,8 @@ std::unordered_map<pir::Value, ShardableAxes> ReversedInferShardableAxes(
const pir::Operation* sink,
const ShardableAxes& init_sa) {
using OpAndInitValue = std::pair<pir::Value, ShardableAxes>;
CHECK_EQ(sink->num_results(), 1);
std::array<OpAndInitValue, 1> sinks{OpAndInitValue{sink->result(0), init_sa}};
const int result_idx = GetOutputShardableAxesResultIdx(sink);
std::array<OpAndInitValue, 1> sinks{OpAndInitValue{sink->result(result_idx), init_sa}};
return ReversedInferShardableAxes(reversed_walker, sinks.begin(), sinks.end());
}

Expand Down Expand Up @@ -364,7 +372,7 @@ GetAxisName2BoundAxisName(
if (ops->count(input_op) == 0) return std::nullopt;
const auto& iter = op2shardable_axes_signature.find(input_op);
if (iter == op2shardable_axes_signature.end()) return std::nullopt;
const auto& output_sa = iter->second.output_shardable_axes;
const auto& output_sa = iter->second.sole_output_sa.shardable_axes;
return &output_sa;
};
std::map<std::string, std::vector<std::string>> axis_name2bound_axis_name;
Expand Down Expand Up @@ -432,9 +440,11 @@ GetSinkAndInitShardableAxes(
for (const auto* sink : sinks) {
const auto& sig_iter = op2shardable_axes_signature.find(sink);
CHECK(sig_iter != op2shardable_axes_signature.end());
const auto& output_shardable_axes = sig_iter->second.output_shardable_axes;
CHECK_EQ(sink->num_results(), 1);
sink2sa[sink->result(0)] = ConvertByBoundAxisName(output_shardable_axes);
const auto& sole_output_sa = sig_iter->second.sole_output_sa;
const auto& output_shardable_axes = sole_output_sa.shardable_axes;
const int result_idx = GetOutputShardableAxesResultIdx(sink);
sink2sa[sink->result(result_idx)] =
ConvertByBoundAxisName(output_shardable_axes);
}
return sink2sa;
}
Expand Down Expand Up @@ -486,16 +496,17 @@ std::unordered_map<pir::Value, ShardableAxes> InferShardableAxesFromSink(
const OpTopo& op_topo) {
auto reversed_walker = GetOpsReversedTopoWalker(op_topo);
CHECK_GT(op_topo.ops->count(sink), 0);
size_t rank = GetRank(sink->result(0));
const int result_idx = GetOutputShardableAxesResultIdx(sink);
size_t rank = GetRank(sink->result(result_idx));
const auto& init_sa = ShardableAxesUtil::GetFullyShardableAxes(rank);
return ReversedInferShardableAxes(reversed_walker, sink, init_sa);
}


pir::Value GetStmtBigestShapeValueImpl(const IS& injective_source) {
const auto* sink_op = injective_source.sole_sink;
CHECK_EQ(sink_op->num_results(), 1);
return sink_op->result(0);
const int result_idx = GetOutputShardableAxesResultIdx(sink_op);
return sink_op->result(result_idx);
}

pir::Value GetStmtBigestShapeValueImpl(const R& reduce_pattern) {
Expand All @@ -506,8 +517,8 @@ pir::Value GetStmtBigestShapeValueImpl(const R& reduce_pattern) {

pir::Value GetStmtBigestShapeValueImpl(const PS& partial_shardable) {
const auto* sink_op = partial_shardable.sole_sink;
CHECK_EQ(sink_op->num_results(), 1);
return sink_op->result(0);
const int result_idx = GetOutputShardableAxesResultIdx(sink_op);
return sink_op->result(result_idx);
}

pir::Value GetStmtBigestShapeValue(const StmtPattern& stmt) {
Expand Down Expand Up @@ -549,7 +560,6 @@ void SortStmtPtrs(
std::sort(stmt_ptrs->begin(), stmt_ptrs->end(), Cmp);
}


class StmtFusionHelper {
public:
StmtFusionHelper(
Expand Down Expand Up @@ -952,8 +962,10 @@ class StmtFusionHelper {
}();
const auto& shardable_axes_sig = [&] {
ShardableAxesSignature signature;
signature.output_shardable_axes =
value2shardable_axes.at(sink->result(0));
int result_idx = GetOutputShardableAxesResultIdx(sink);
signature.sole_output_sa = SoleOutputShardableAxes{
.shardable_axes=value2shardable_axes.at(sink->result(result_idx)),
};
for (const auto& pair : input_op_operands) {
const auto& [op, idx] = pair;
pir::Value input = op->operand_source(idx);
Expand Down Expand Up @@ -1604,8 +1616,7 @@ class LoopAlignableClusteringPolicy final : public ClusteringPolicy {
const R& src,
const PS& dst) {
const auto* sink_op = src.reduce_op_pattern.reduce_op;
CHECK_EQ(sink_op->num_results(), 1);
pir::Value value = sink_op->result(0);
pir::Value value = sink_op->result(GetOutputShardableAxesResultIdx(sink_op));
const auto& shardable_axes = ShardableAxes4Value(value);
CHECK(shardable_axes.has_value());
return IsStmtSinkOpOutputFullyShardableImpl(src, *shardable_axes.value());
Expand All @@ -1617,8 +1628,7 @@ class LoopAlignableClusteringPolicy final : public ClusteringPolicy {
const R& dst) {
const auto GetSoleOutputValue = [&](const R& reduce_pattern) {
const auto* sink_op = src.reduce_op_pattern.reduce_op;
CHECK_EQ(sink_op->num_results(), 1);
pir::Value value = sink_op->result(0);
pir::Value value = sink_op->result(GetOutputShardableAxesResultIdx(sink_op));
return value;
};
const auto GetShardableAxes = [&](const R& reduce_pattern) {
Expand Down Expand Up @@ -1668,8 +1678,7 @@ class LoopAlignableClusteringPolicy final : public ClusteringPolicy {
const ShardableAxes4ValueT& ShardableAxes4Value,
const StmtPattern& stmt) {
const auto* sink_op = GetStmtSoleSinkOp(stmt);
CHECK_EQ(sink_op->num_results(), 1);
pir::Value value = sink_op->result(0);
pir::Value value = sink_op->result(GetOutputShardableAxesResultIdx(sink_op));
const auto& shardable_axes = ShardableAxes4Value(value);
CHECK(shardable_axes.has_value());
return IsStmtSinkOpOutputFullyShardable(stmt, *shardable_axes.value());
Expand Down Expand Up @@ -1752,7 +1761,8 @@ class LoopAlignableClusteringPolicy final : public ClusteringPolicy {
}
return true;
} else {
return GetRank(reduce_op->result(0)) == shardable_axes.size();
const int result_idx = GetOutputShardableAxesResultIdx(reduce_op);
return GetRank(reduce_op->result(result_idx)) == shardable_axes.size();
}
}

Expand Down

0 comments on commit 35506a8

Please sign in to comment.