Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DimExpr] DimExpr support hash #60471

Merged
merged 1 commit into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions paddle/pir/dialect/shape/utils/dim_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/pir/dialect/shape/utils/dim_expr.h"
#include "paddle/pir/core/utils.h"

namespace symbol {

Expand Down Expand Up @@ -184,4 +185,55 @@ std::ostream& operator<<(std::ostream& stream, const DimExpr& dim_expr) {
return stream;
}

namespace {

std::size_t GetHashValueImpl(const std::int64_t& dim_expr) { return dim_expr; }

std::size_t GetHashValueImpl(const std::string& dim_expr) {
return std::hash<std::string>()(dim_expr);
}

std::size_t GetHashValueImpl(const Negative<DimExpr>& dim_expr) {
return -GetHashValue(dim_expr->data);
}

std::size_t GetHashValueImpl(const Reciprocal<DimExpr>& dim_expr) {
return pir::hash_combine(1, -GetHashValue(dim_expr->data));
}

std::size_t GetHashValueImpl(const List<DimExpr>& exprs) {
std::size_t ret = 0;
for (const auto& expr : *exprs) {
ret = pir::hash_combine(ret, GetHashValue(expr));
}
return ret;
}

std::size_t GetHashValueImpl(const Add<DimExpr>& dim_expr) {
return pir::hash_combine(1, GetHashValueImpl(dim_expr.operands));
}

std::size_t GetHashValueImpl(const Mul<DimExpr>& dim_expr) {
return pir::hash_combine(2, GetHashValueImpl(dim_expr.operands));
}

std::size_t GetHashValueImpl(const Max<DimExpr>& dim_expr) {
return pir::hash_combine(3, GetHashValueImpl(dim_expr.operands));
}

std::size_t GetHashValueImpl(const Min<DimExpr>& dim_expr) {
return pir::hash_combine(4, GetHashValueImpl(dim_expr.operands));
}

std::size_t GetHashValueImpl(const Broadcast<DimExpr>& dim_expr) {
return pir::hash_combine(5, GetHashValueImpl(dim_expr.operands));
}

} // namespace

std::size_t GetHashValue(const DimExpr& dim_expr) {
return std::visit([](const auto& impl) { return GetHashValueImpl(impl); },
dim_expr.variant());
}

} // namespace symbol
13 changes: 13 additions & 0 deletions paddle/pir/dialect/shape/utils/dim_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,4 +253,17 @@ IR_API std::string ToString(const DimExpr& dim_expr);

IR_API std::ostream& operator<<(std::ostream&, const DimExpr& dim_expr);

IR_API std::size_t GetHashValue(const DimExpr& dim_expr);

} // namespace symbol

namespace std {

template <>
struct hash<symbol::DimExpr> {
std::size_t operator()(const symbol::DimExpr& dim_expr) const {
return symbol::GetHashValue(dim_expr);
}
};

} // namespace std
34 changes: 28 additions & 6 deletions test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@
namespace symbol::test {

// Construct DimExpr by overloaded operator(+, - , *, /)
TEST(DimExpr, dim_expr_naive) {
TEST(DimExpr, DimExprNaive) {
DimExpr sym0 = DimExpr("S0");
DimExpr sym1 = DimExpr("S1");
DimExpr constant1 = DimExpr(1);
DimExpr output = (sym0 + sym1) * constant1;
}

// Construct DimExpr by DimExprBuilder
TEST(DimExpr, dim_expr_builder) {
TEST(DimExpr, DimExprBuilder) {
DimExprBuilder builder{nullptr};
DimExpr sym0 = DimExpr("S0");
DimExpr sym1 = DimExpr("S1");
Expand All @@ -40,7 +40,7 @@ TEST(DimExpr, dim_expr_builder) {
}

// Add constraints by DimExprBuilder
TEST(DimExpr, constraint) {
TEST(DimExpr, Constraint) {
std::vector<DimExprConstraint> constraints{};
DimExprBuilder builder(&constraints);
DimExpr sym0 = DimExpr("S0");
Expand All @@ -55,7 +55,7 @@ TEST(DimExpr, constraint) {
extend_x = x.shape
out = pd.reshape(y, extend_x)
*/
TEST(DimExpr, data_shape_expr) {
TEST(DimExpr, DataShapeExpr) {
// Show ideal ShapeOrDataDimExprs of each pir::Value
std::vector<DimExpr> x_shapes{DimExpr("S0"), DimExpr(2)};
std::vector<DimExpr> y_shapes{DimExpr(1), DimExpr("S1"), DimExpr(2)};
Expand All @@ -80,7 +80,7 @@ TEST(Simplify, NumberArithmetic) {
ASSERT_EQ((mul_div.Get<std::int64_t>()), 1);
}

TEST(DimExpr, equal) {
TEST(DimExpr, Equal) {
DimExprBuilder builder{nullptr};
DimExpr sym0 = DimExpr("S0");
DimExpr sym1 = DimExpr("S1");
Expand Down Expand Up @@ -111,7 +111,7 @@ TEST(DimExpr, equal) {
builder.Broadcast(DimExpr("S0"), constant1));
}

TEST(DimExpr, print) {
TEST(DimExpr, Print) {
DimExprBuilder builder{nullptr};
DimExpr sym0 = DimExpr("S0");
DimExpr sym1 = DimExpr("S1");
Expand All @@ -124,4 +124,26 @@ TEST(DimExpr, print) {
ASSERT_EQ((ToString(builder.Broadcast(sym0, sym1))), "Broadcast(S0, S1)");
}

TEST(DimExpr, Hash) {
DimExprBuilder builder{nullptr};
DimExpr sym0 = DimExpr("S0");
DimExpr sym1 = DimExpr("S1");
ASSERT_EQ((std::hash<DimExpr>()(sym0 + sym1)),
(std::hash<DimExpr>()(sym0 + sym1)));
ASSERT_NE((std::hash<DimExpr>()(sym0 + sym1)),
(std::hash<DimExpr>()(sym1 + sym0)));
ASSERT_NE((std::hash<DimExpr>()(sym0 + sym1)),
(std::hash<DimExpr>()(sym0 - sym1)));
ASSERT_NE((std::hash<DimExpr>()(sym0 + sym1)),
(std::hash<DimExpr>()(sym0 * sym1)));
ASSERT_NE((std::hash<DimExpr>()(sym0 + sym1)),
(std::hash<DimExpr>()(sym0 / sym1)));
ASSERT_NE((std::hash<DimExpr>()(sym0 + sym1)),
(std::hash<DimExpr>()(builder.Max(sym0, sym1))));
ASSERT_NE((std::hash<DimExpr>()(sym0 + sym1)),
(std::hash<DimExpr>()(builder.Min(sym0, sym1))));
ASSERT_NE((std::hash<DimExpr>()(sym0 + sym1)),
(std::hash<DimExpr>()(builder.Broadcast(sym0, sym1))));
}

} // namespace symbol::test