Skip to content

Commit

Permalink
[DimExpr] DimExpr support hash (PaddlePaddle#60471)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiahy0825 authored and Wanglongzhi2001 committed Jan 7, 2024
1 parent f7fb89f commit fd3e7a8
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 6 deletions.
52 changes: 52 additions & 0 deletions paddle/pir/dialect/shape/utils/dim_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/pir/dialect/shape/utils/dim_expr.h"
#include "paddle/pir/core/utils.h"

namespace symbol {

Expand Down Expand Up @@ -184,4 +185,55 @@ std::ostream& operator<<(std::ostream& stream, const DimExpr& dim_expr) {
return stream;
}

namespace {

std::size_t GetHashValueImpl(const std::int64_t& dim_expr) { return dim_expr; }

std::size_t GetHashValueImpl(const std::string& dim_expr) {
return std::hash<std::string>()(dim_expr);
}

std::size_t GetHashValueImpl(const Negative<DimExpr>& dim_expr) {
return -GetHashValue(dim_expr->data);
}

std::size_t GetHashValueImpl(const Reciprocal<DimExpr>& dim_expr) {
return pir::hash_combine(1, -GetHashValue(dim_expr->data));
}

std::size_t GetHashValueImpl(const List<DimExpr>& exprs) {
std::size_t ret = 0;
for (const auto& expr : *exprs) {
ret = pir::hash_combine(ret, GetHashValue(expr));
}
return ret;
}

std::size_t GetHashValueImpl(const Add<DimExpr>& dim_expr) {
return pir::hash_combine(1, GetHashValueImpl(dim_expr.operands));
}

std::size_t GetHashValueImpl(const Mul<DimExpr>& dim_expr) {
return pir::hash_combine(2, GetHashValueImpl(dim_expr.operands));
}

std::size_t GetHashValueImpl(const Max<DimExpr>& dim_expr) {
return pir::hash_combine(3, GetHashValueImpl(dim_expr.operands));
}

std::size_t GetHashValueImpl(const Min<DimExpr>& dim_expr) {
return pir::hash_combine(4, GetHashValueImpl(dim_expr.operands));
}

std::size_t GetHashValueImpl(const Broadcast<DimExpr>& dim_expr) {
return pir::hash_combine(5, GetHashValueImpl(dim_expr.operands));
}

} // namespace

std::size_t GetHashValue(const DimExpr& dim_expr) {
return std::visit([](const auto& impl) { return GetHashValueImpl(impl); },
dim_expr.variant());
}

} // namespace symbol
13 changes: 13 additions & 0 deletions paddle/pir/dialect/shape/utils/dim_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,4 +253,17 @@ IR_API std::string ToString(const DimExpr& dim_expr);

IR_API std::ostream& operator<<(std::ostream&, const DimExpr& dim_expr);

IR_API std::size_t GetHashValue(const DimExpr& dim_expr);

} // namespace symbol

namespace std {

template <>
struct hash<symbol::DimExpr> {
std::size_t operator()(const symbol::DimExpr& dim_expr) const {
return symbol::GetHashValue(dim_expr);
}
};

} // namespace std
34 changes: 28 additions & 6 deletions test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@
namespace symbol::test {

// Construct DimExpr by overloaded operator(+, - , *, /)
TEST(DimExpr, dim_expr_naive) {
TEST(DimExpr, DimExprNaive) {
DimExpr sym0 = DimExpr("S0");
DimExpr sym1 = DimExpr("S1");
DimExpr constant1 = DimExpr(1);
DimExpr output = (sym0 + sym1) * constant1;
}

// Construct DimExpr by DimExprBuilder
TEST(DimExpr, dim_expr_builder) {
TEST(DimExpr, DimExprBuilder) {
DimExprBuilder builder{nullptr};
DimExpr sym0 = DimExpr("S0");
DimExpr sym1 = DimExpr("S1");
Expand All @@ -40,7 +40,7 @@ TEST(DimExpr, dim_expr_builder) {
}

// Add constraints by DimExprBuilder
TEST(DimExpr, constraint) {
TEST(DimExpr, Constraint) {
std::vector<DimExprConstraint> constraints{};
DimExprBuilder builder(&constraints);
DimExpr sym0 = DimExpr("S0");
Expand All @@ -55,7 +55,7 @@ TEST(DimExpr, constraint) {
extend_x = x.shape
out = pd.reshape(y, extend_x)
*/
TEST(DimExpr, data_shape_expr) {
TEST(DimExpr, DataShapeExpr) {
// Show ideal ShapeOrDataDimExprs of each pir::Value
std::vector<DimExpr> x_shapes{DimExpr("S0"), DimExpr(2)};
std::vector<DimExpr> y_shapes{DimExpr(1), DimExpr("S1"), DimExpr(2)};
Expand All @@ -80,7 +80,7 @@ TEST(Simplify, NumberArithmetic) {
ASSERT_EQ((mul_div.Get<std::int64_t>()), 1);
}

TEST(DimExpr, equal) {
TEST(DimExpr, Equal) {
DimExprBuilder builder{nullptr};
DimExpr sym0 = DimExpr("S0");
DimExpr sym1 = DimExpr("S1");
Expand Down Expand Up @@ -111,7 +111,7 @@ TEST(DimExpr, equal) {
builder.Broadcast(DimExpr("S0"), constant1));
}

TEST(DimExpr, print) {
TEST(DimExpr, Print) {
DimExprBuilder builder{nullptr};
DimExpr sym0 = DimExpr("S0");
DimExpr sym1 = DimExpr("S1");
Expand All @@ -124,4 +124,26 @@ TEST(DimExpr, print) {
ASSERT_EQ((ToString(builder.Broadcast(sym0, sym1))), "Broadcast(S0, S1)");
}

TEST(DimExpr, Hash) {
DimExprBuilder builder{nullptr};
DimExpr sym0 = DimExpr("S0");
DimExpr sym1 = DimExpr("S1");
ASSERT_EQ((std::hash<DimExpr>()(sym0 + sym1)),
(std::hash<DimExpr>()(sym0 + sym1)));
ASSERT_NE((std::hash<DimExpr>()(sym0 + sym1)),
(std::hash<DimExpr>()(sym1 + sym0)));
ASSERT_NE((std::hash<DimExpr>()(sym0 + sym1)),
(std::hash<DimExpr>()(sym0 - sym1)));
ASSERT_NE((std::hash<DimExpr>()(sym0 + sym1)),
(std::hash<DimExpr>()(sym0 * sym1)));
ASSERT_NE((std::hash<DimExpr>()(sym0 + sym1)),
(std::hash<DimExpr>()(sym0 / sym1)));
ASSERT_NE((std::hash<DimExpr>()(sym0 + sym1)),
(std::hash<DimExpr>()(builder.Max(sym0, sym1))));
ASSERT_NE((std::hash<DimExpr>()(sym0 + sym1)),
(std::hash<DimExpr>()(builder.Min(sym0, sym1))));
ASSERT_NE((std::hash<DimExpr>()(sym0 + sym1)),
(std::hash<DimExpr>()(builder.Broadcast(sym0, sym1))));
}

} // namespace symbol::test

0 comments on commit fd3e7a8

Please sign in to comment.