Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DimExpr] Remove adt DimExpr #60901

Merged
merged 7 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions paddle/cinn/adt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ if(NOT CINN_ONLY)
equation_value.cc
generate_map_expr.cc
get_sub_reshape_dim_ranges.cc
graph_symbolic_dim_infer_ctx.cc
igroup.cc
index_expr_infer_context.cc
kgroup.cc
Expand All @@ -25,7 +24,6 @@ if(NOT CINN_ONLY)
schedule_dim.cc
schedule_mesh.cc
dim_expr.cc
dim_expr_simplifier.cc
simplify_value.cc
write_broadcast_disabled_bidirection_equation_generator.cc)

Expand All @@ -43,13 +41,4 @@ if(NOT CINN_ONLY)
glog
absl)

cinn_cc_test(
dim_expr_test
SRCS
dim_expr_test.cc
DEPS
gtest
glog
cinncore)

endif()
5 changes: 3 additions & 2 deletions paddle/cinn/adt/adapter_dynamic_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "glog/logging.h"

#include "paddle/cinn/adt/adt.h"
#include "paddle/cinn/adt/dim_expr.h"
#include "paddle/cinn/adt/symbolic_dim.h"
#include "paddle/cinn/hlir/framework/pir/group.h"

Expand All @@ -34,8 +35,8 @@ struct DynamicTensor final {
.size();
}

const std::vector<std::optional<DimExpr>>& GetShape() const {
return group->graph_symbolic_dim_infer_ctx()->GetTensorDimExprs(node_data);
const std::vector<DimExpr>& GetShape() const {
return group->GetShapeOrDataExprs(node_data).shape();
}
};

Expand Down
132 changes: 1 addition & 131 deletions paddle/cinn/adt/dim_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,138 +18,8 @@

namespace cinn::adt {

namespace {

template <typename T0, typename T1>
bool DimExprEqualImpl(const T0&, const T1&) {
LOG(FATAL) << "Dead code";
}

bool DimExprEqualImpl(std::int64_t lhs, std::int64_t rhs) { return lhs == rhs; }

bool DimExprEqualImpl(const SymbolicDim& lhs, const SymbolicDim& rhs) {
return lhs == rhs;
}

bool DimExprEqualImpl(const Negative<DimExpr>& lhs,
const Negative<DimExpr>& rhs) {
const auto& [lhs_arg0] = lhs.tuple();
const auto& [rhs_arg0] = rhs.tuple();
return lhs_arg0 == rhs_arg0;
}

bool DimExprEqualImpl(const Reciprocal<DimExpr>& lhs,
const Reciprocal<DimExpr>& rhs) {
const auto& [lhs_arg0] = lhs.tuple();
const auto& [rhs_arg0] = rhs.tuple();
return lhs_arg0 == rhs_arg0;
}

bool DimExprEqualImpl(const Sum<DimExpr>& lhs, const Sum<DimExpr>& rhs) {
const auto& [lhs_operands] = lhs;
const auto& [rhs_operands] = rhs;
return lhs_operands == rhs_operands;
}

bool DimExprEqualImpl(const Product<DimExpr>& lhs,
const Product<DimExpr>& rhs) {
const auto& [lhs_operands] = lhs;
const auto& [rhs_operands] = rhs;
return lhs_operands == rhs_operands;
}

bool DimExprEqualImpl(const BroadcastedDim<DimExpr>& lhs,
const BroadcastedDim<DimExpr>& rhs) {
const auto& [lhs_operands] = lhs;
const auto& [rhs_operands] = rhs;
return lhs_operands == rhs_operands;
}

} // namespace

bool operator==(const DimExpr& lhs, const DimExpr& rhs) {
return std::visit(
[](const auto& lhs, const auto& rhs) {
if (std::is_same_v<std::decay_t<decltype(lhs)>,
std::decay_t<decltype(rhs)>>) {
return DimExprEqualImpl(lhs, rhs);
} else {
return false;
}
},
lhs.variant(),
rhs.variant());
}

namespace {

std::size_t GetHashValueImpl(std::int64_t expr) { return expr; }

std::size_t GetHashValueImpl(const SymbolicDim& expr) {
return expr.value().unique_id();
}

std::size_t GetHashValueImpl(const Negative<DimExpr>& expr) {
const auto& [item] = expr.tuple();
return -GetHashValue(item);
}

std::size_t GetHashValueImpl(const Reciprocal<DimExpr>& expr) {
const auto& [item] = expr.tuple();
return -GetHashValue(item);
}

std::size_t GetHashValueImpl(const List<DimExpr>& exprs) {
std::size_t ret = 0;
for (const auto& expr : *exprs) {
ret = hash_combine(ret, GetHashValue(expr));
}
}

std::size_t GetHashValueImpl(const Sum<DimExpr>& expr) {
const auto& [operands] = expr;
return GetHashValueImpl(operands);
}

std::size_t GetHashValueImpl(const Product<DimExpr>& expr) {
const auto& [operands] = expr;
return GetHashValueImpl(operands);
}

std::size_t GetHashValueImpl(const BroadcastedDim<DimExpr>& expr) {
const auto& [operands] = expr;
return GetHashValueImpl(operands);
}

} // namespace

std::size_t GetHashValue(const DimExpr& expr) {
return std::visit([&](const auto& impl) { return GetHashValueImpl(impl); },
expr.variant());
}

DimExpr operator+(const DimExpr& lhs, const DimExpr& rhs) {
return Sum<DimExpr>{List<DimExpr>{lhs, rhs}};
}

DimExpr operator-(const DimExpr& lhs, const DimExpr& rhs) {
return Sum<DimExpr>{List<DimExpr>{lhs, Negative<DimExpr>{rhs}}};
}

DimExpr operator*(const DimExpr& lhs, const DimExpr& rhs) {
return Product<DimExpr>{List<DimExpr>{lhs, rhs}};
}

DimExpr operator/(const DimExpr& lhs, const DimExpr& rhs) {
return Product<DimExpr>{List<DimExpr>{lhs, Reciprocal<DimExpr>{rhs}}};
}

DimExpr MakeBroadcastedDim(const DimExpr& lhs, const DimExpr& rhs) {
return BroadcastedDim<DimExpr>{List<DimExpr>{lhs, rhs}};
return ::symbol::Broadcast<DimExpr>{::symbol::List<DimExpr>{lhs, rhs}};
}

std::ostream& operator<<(std::ostream& stream, const DimExpr& expr) {
stream << ToTxtString(expr);
return stream;
}
} // namespace cinn::adt
39 changes: 2 additions & 37 deletions paddle/cinn/adt/dim_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,47 +19,12 @@
#include "paddle/cinn/adt/arithmetic.h"
#include "paddle/cinn/adt/logical.h"
#include "paddle/cinn/adt/symbolic_dim.h"
#include "paddle/pir/dialect/shape/utils/dim_expr.h"

namespace cinn::adt {

template <typename T>
struct BroadcastedDim final {
List<T> operands;

const BroadcastedDim& tuple() const { return *this; }
};

// DimExpr = std::int64_t
// | SymbolicDim
// | Negative DimExpr
// | Reciprocal DimExpr
// | Sum DimExpr
// | Product DimExpr
// | BroadcastedDim DimExpr
DEFINE_ADT_UNION(DimExpr,
std::int64_t,
SymbolicDim,
Negative<DimExpr>,
Reciprocal<DimExpr>,
Sum<DimExpr>,
Product<DimExpr>,
BroadcastedDim<DimExpr>);

DimExpr operator+(const DimExpr& lhs, const DimExpr& rhs);
DimExpr operator-(const DimExpr& lhs, const DimExpr& rhs);
DimExpr operator*(const DimExpr& lhs, const DimExpr& rhs);
DimExpr operator/(const DimExpr& lhs, const DimExpr& rhs);
using DimExpr = ::symbol::DimExpr;

DimExpr MakeBroadcastedDim(const DimExpr& lhs, const DimExpr& rhs);

bool operator==(const DimExpr& lhs, const DimExpr& rhs);

inline bool operator!=(const DimExpr& lhs, const DimExpr& rhs) {
return !(lhs == rhs);
}

std::size_t GetHashValue(const DimExpr& expr);

std::ostream& operator<<(std::ostream&, const DimExpr& expr);

} // namespace cinn::adt
20 changes: 10 additions & 10 deletions paddle/cinn/adt/dim_expr_match_trait.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,23 +62,23 @@ struct MatchTrait<DimExpr, SymbolicDim> final {
};

template <typename T0>
struct MatchTrait<DimExpr, Negative<T0>> final
: public UnaryDimExprMatchTrait<Negative, T0> {};
struct MatchTrait<DimExpr, ::symbol::Negative<T0>> final
: public UnaryDimExprMatchTrait<::symbol::Negative, T0> {};

template <typename T0>
struct MatchTrait<DimExpr, Reciprocal<T0>> final
: public UnaryDimExprMatchTrait<Reciprocal, T0> {};
struct MatchTrait<DimExpr, ::symbol::Reciprocal<T0>> final
: public UnaryDimExprMatchTrait<::symbol::Reciprocal, T0> {};

template <typename T0>
struct MatchTrait<DimExpr, Sum<T0>> final
: public ListDimExprMatchTrait<Sum, T0> {};
struct MatchTrait<DimExpr, ::symbol::Add<T0>> final
: public ListDimExprMatchTrait<::symbol::Add, T0> {};

template <typename T0>
struct MatchTrait<DimExpr, Product<T0>> final
: public ListDimExprMatchTrait<Product, T0> {};
struct MatchTrait<DimExpr, ::symbol::Mul<T0>> final
: public ListDimExprMatchTrait<::symbol::Mul, T0> {};

template <typename T0>
struct MatchTrait<DimExpr, BroadcastedDim<T0>> final
: public ListDimExprMatchTrait<BroadcastedDim, T0> {};
struct MatchTrait<DimExpr, ::symbol::Broadcast<T0>> final
: public ListDimExprMatchTrait<::symbol::Broadcast, T0> {};

} // namespace cinn::adt
Loading