Skip to content

Commit

Permalink
resolve conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
MayYouBeProsperous committed Jan 2, 2024
2 parents a018a66 + a08580e commit 075ef78
Show file tree
Hide file tree
Showing 575 changed files with 16,837 additions and 6,836 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ paddle/fluid/pir/dialect/operator/ir/pd_api.*
paddle/fluid/pir/dialect/operator/ir/op_decomp.cc
paddle/fluid/pir/dialect/operator/ir/pd_op_vjp.cc
paddle/fluid/pir/dialect/operator/ir/pd_op.*
paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.*
paddle/fluid/pir/dialect/operator/ir/pd_onednn_op_info.*
paddle/fluid/pir/dialect/operator/ir/pd_op_bwd.*
paddle/fluid/pir/dialect/operator/ir/pd_op_fused.*
paddle/fluid/pir/dialect/operator/ir/pd_op_fused_bwd.*
Expand Down
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ option(WITH_XPU "Compile PaddlePaddle with BAIDU KUNLUN XPU" OFF)
option(WITH_XPU_KP "Compile PaddlePaddle with BAIDU XPU compiler " OFF)
option(WITH_XPU_XFT "Compile PaddlePaddle with BAIDU XPU-XFT" OFF)
option(WITH_XPU_PLUGIN "Compile PaddlePaddle with BAIDU XPU plugin" OFF)
option(WITH_XPU_XHPC "Compile PaddlePaddle with BAIDU XPU-HPC library" OFF)
option(WITH_XPU_XHPC "Compile PaddlePaddle with BAIDU XPU-HPC library"
${WITH_XPU})
option(WITH_WIN_DUMP_DBG "Compile with windows core dump debug mode" OFF)
option(WITH_ROCM "Compile PaddlePaddle with ROCM platform" OFF)
option(WITH_IPU "Compile PaddlePaddle with Graphcore IPU" OFF)
Expand Down
4 changes: 2 additions & 2 deletions cmake/external/xpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ if(NOT DEFINED XPU_BASE_DATE)
set(XPU_BASE_DATE "20231203")
endif()
if(NOT DEFINED XPU_XHPC_BASE_DATE)
set(XPU_XHPC_BASE_DATE "20231215")
set(XPU_XHPC_BASE_DATE "20231229")
endif()
set(XPU_XCCL_BASE_VERSION "1.1.7.1")
set(XPU_XCCL_BASE_VERSION "1.1.8.1")
if(NOT DEFINED XPU_XFT_BASE_VERSION)
set(XPU_XFT_BASE_VERSION "20230602")
endif()
Expand Down
15 changes: 10 additions & 5 deletions cmake/inference_lib.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,16 @@ copy_part_of_thrid_party(inference_lib_dist ${PADDLE_INFERENCE_INSTALL_DIR})

set(src_dir "${PADDLE_SOURCE_DIR}/paddle/fluid")

if(WIN32)
set(paddle_common_lib ${PADDLE_BINARY_DIR}/paddle/common/common.*)
else()
set(paddle_common_lib ${PADDLE_BINARY_DIR}/paddle/common/libcommon.*)
endif()
copy(
inference_lib_dist
SRCS ${paddle_common_lib}
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/lib)

if(WIN32)
if(WITH_STATIC_LIB)
set(paddle_inference_lib
Expand Down Expand Up @@ -268,11 +278,6 @@ else()
SRCS ${paddle_phi_lib}
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/lib)
endif()
set(paddle_common_lib ${PADDLE_BINARY_DIR}/paddle/common/libcommon.*)
copy(
inference_lib_dist
SRCS ${paddle_common_lib}
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/lib)
endif()

copy(
Expand Down
5 changes: 4 additions & 1 deletion paddle/cinn/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ gather_srcs(
python_interpreter_guard.cc
nvgpu_dev_info.cc
integer_set.cc
dim_expr_simplify.cc)
dim_expr_simplify.cc
dim_expr_converter.cc)

cinn_cc_test(test_equation_graph_topo_walker SRCS
equation_graph_topo_walker_test.cc DEPS gtest glog)
Expand All @@ -49,4 +50,6 @@ endif()
if(NOT CINN_ONLY)
cinn_cc_test(dim_expr_simplify_test SRCS dim_expr_simplify_test.cc DEPS
cinncore)
cinn_cc_test(dim_expr_converter_test SRCS dim_expr_converter_test.cc DEPS
cinncore)
endif()
101 changes: 101 additions & 0 deletions paddle/cinn/common/dim_expr_converter.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/common/dim_expr_converter.h"
#include "paddle/cinn/common/ir_util.h"

namespace cinn::common {
using namespace symbol; // NOLINT

namespace {

struct DimExprToIrExprVisitor {
ir::Expr ConvertToIrExpr(const DimExpr& dim_expr) {
return std::visit(*this, dim_expr.variant());
}

ir::Expr operator()(const int64_t& dim) { return ir::Expr(dim); }

ir::Expr operator()(const std::string& dim_expr) {
Var x = ir::_Var_::Make(dim_expr, Int(64));
return x;
}

ir::Expr operator()(const Negative<DimExpr>& dim_expr) {
const auto& [operand] = *dim_expr;
return ir::Sub::Make(ir::Expr(std::int64_t(0)), ConvertToIrExpr(operand));
}

ir::Expr operator()(const Reciprocal<DimExpr>& dim_expr) {
const auto& [operand] = *dim_expr;
return ir::Div::Make(ir::Expr(std::int64_t(1)), ConvertToIrExpr(operand));
}

ir::Expr operator()(const Add<DimExpr>& dim_expr) {
const auto& [operands] = dim_expr;
if (operands->empty()) {
return ir::Expr(std::int64_t(0));
}
ir::Expr sum = ConvertToIrExpr(operands->at(0));
for (std::size_t i = 1; i < operands->size(); ++i) {
sum = ir::Add::Make(sum, ConvertToIrExpr(operands->at(i)));
}
return sum;
}

ir::Expr operator()(const Mul<DimExpr>& dim_expr) {
const auto& [operands] = dim_expr;
if (operands->empty()) {
return ir::Expr(std::int64_t(1));
}
ir::Expr product = ConvertToIrExpr(operands->at(0));
for (std::size_t i = 1; i < operands->size(); ++i) {
product = ir::Mul::Make(product, ConvertToIrExpr(operands->at(i)));
}
return product;
}

ir::Expr operator()(const Max<DimExpr>& dim_expr) {
const auto& [operands] = dim_expr;
CHECK(!operands->empty());
ir::Expr max = ConvertToIrExpr(operands->at(0));
for (std::size_t i = 1; i < operands->size(); ++i) {
max = ir::Max::Make(max, ConvertToIrExpr(operands->at(i)));
}
return max;
}

ir::Expr operator()(const Min<DimExpr>& dim_expr) {
const auto& [operands] = dim_expr;
CHECK(!operands->empty());
ir::Expr min = ConvertToIrExpr(operands->at(0));
for (std::size_t i = 1; i < operands->size(); ++i) {
min = ir::Min::Make(min, ConvertToIrExpr(operands->at(i)));
}
return min;
}

ir::Expr operator()(const Broadcast<DimExpr>& dim_expr) {
LOG(FATAL)
<< "no support for converting from Broadcast<DimExpr> to ir::Expr";
}
};

} // namespace

ir::Expr DimExprConverter::ConvertToIrExpr(const DimExpr& dim_expr) const {
return DimExprToIrExprVisitor().ConvertToIrExpr(dim_expr);
}

} // namespace cinn::common
26 changes: 26 additions & 0 deletions paddle/cinn/common/dim_expr_converter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/cinn/ir/ir.h"
#include "paddle/pir/dialect/shape/utils/dim_expr.h"

namespace cinn::common {

struct DimExprConverter final {
ir::Expr ConvertToIrExpr(const symbol::DimExpr&) const;
};

} // namespace cinn::common
79 changes: 79 additions & 0 deletions paddle/cinn/common/dim_expr_converter_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <sstream>

#include "gtest/gtest.h"

#include "paddle/cinn/common/dim_expr_converter.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir_printer.h"

namespace cinn::common::test {

using namespace symbol; // NOLINT

TEST(Convert, AddExpr) {
List<DimExpr> num_lists{DimExpr(4), DimExpr(5), DimExpr("sym_0")};
DimExpr dim_expr{Add<DimExpr>{num_lists}};
ir::Expr src_expr = DimExprConverter().ConvertToIrExpr(dim_expr);

ir::Expr expr1 =
ir::Add::Make(ir::Expr(std::int64_t(4)), ir::Expr(std::int64_t(5)));
ir::Expr dst_expr = ir::Add::Make(expr1, ir::_Var_::Make("sym_0", Int(64)));
ASSERT_TRUE(MathEqual(src_expr, dst_expr));
}

TEST(Convert, SubExpr) {
DimExpr dim_expr = DimExpr(4) - DimExpr("sym_0");
ir::Expr src_expr = DimExprConverter().ConvertToIrExpr(dim_expr);

ir::Expr expr1 = ir::Sub::Make(ir::Expr(std::int64_t(0)),
ir::_Var_::Make("sym_0", Int(64)));
ir::Expr dst_expr = ir::Add::Make(ir::Expr(std::int64_t(4)), expr1);
ASSERT_TRUE(MathEqual(src_expr, dst_expr));
}

TEST(Convert, MulExpr) {
List<DimExpr> num_lists{DimExpr(4), DimExpr(5), DimExpr("sym_0")};
DimExpr dim_expr{Mul<DimExpr>{num_lists}};
ir::Expr src_expr = DimExprConverter().ConvertToIrExpr(dim_expr);

ir::Expr expr1 =
ir::Mul::Make(ir::Expr(std::int64_t(4)), ir::Expr(std::int64_t(5)));
ir::Expr dst_expr = ir::Mul::Make(expr1, ir::_Var_::Make("sym_0", Int(64)));
ASSERT_TRUE(MathEqual(src_expr, dst_expr));
}

TEST(Convert, MaxExpr) {
List<DimExpr> num_lists{DimExpr(4), DimExpr(5), DimExpr("sym_0")};
DimExpr dim_expr{Max<DimExpr>{num_lists}};
ir::Expr src_expr = DimExprConverter().ConvertToIrExpr(dim_expr);

std::ostringstream stream;
stream << src_expr;
ASSERT_EQ(stream.str(), "cinn_max(cinn_max(4ll, 5ll), sym_0)");
}

TEST(Convert, MinExpr) {
List<DimExpr> num_lists{DimExpr(4), DimExpr(5), DimExpr("sym_0")};
DimExpr dim_expr{Min<DimExpr>{num_lists}};
ir::Expr src_expr = DimExprConverter().ConvertToIrExpr(dim_expr);

std::ostringstream stream;
stream << src_expr;
ASSERT_EQ(stream.str(), "cinn_min(cinn_min(4ll, 5ll), sym_0)");
}

} // namespace cinn::common::test
57 changes: 42 additions & 15 deletions paddle/cinn/common/integer_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@ std::optional<bool> SymbolicExprAnalyzer::ProveEQ(const ir::Expr& lhs,
if (diff.is_constant()) {
return diff.get_constant() == 0;
}
ir::Expr diff_lower_bound = LowerBound(diff);
VLOG(6) << "lower bound of " << diff << " = " << diff_lower_bound;
ir::Expr diff_upper_bound = UpperBound(diff);
VLOG(6) << "upper bound of " << diff << " = " << diff_upper_bound;
if (diff_lower_bound.is_constant() && diff_upper_bound.is_constant() &&
diff_lower_bound.get_constant() == diff_upper_bound.get_constant()) {
return diff_lower_bound.get_constant() == 0;
}
std::optional<bool> prove_gt = ProveGT(lhs, rhs);
if (prove_gt.has_value() && prove_gt.value()) {
return false;
Expand All @@ -71,22 +79,11 @@ std::optional<bool> SymbolicExprAnalyzer::ProveEQ(const ir::Expr& lhs,

std::optional<bool> SymbolicExprAnalyzer::ProveNE(const ir::Expr& lhs,
const ir::Expr& rhs) const {
if (lhs == rhs) {
return false;
}
ir::Expr diff = AutoSimplify(ir::Sub::Make(lhs, rhs), var_intervals_);
if (diff.is_constant()) {
return diff.get_constant() != 0;
}
std::optional<bool> prove_gt = ProveGT(lhs, rhs);
if (prove_gt.has_value() && prove_gt.value()) {
return true;
}
std::optional<bool> prove_lt = ProveLT(lhs, rhs);
if (prove_lt.has_value() && prove_lt.value()) {
return true;
std::optional<bool> prove_eq = ProveEQ(lhs, rhs);
if (!prove_eq.has_value()) {
return std::nullopt;
}
return std::nullopt;
return !prove_eq.value();
}

std::optional<bool> SymbolicExprAnalyzer::ProveGE(const ir::Expr& lhs,
Expand Down Expand Up @@ -456,5 +453,35 @@ std::optional<bool> SingleIntervalIntSet::ProveSuperSet(
return std::nullopt;
}

ir::Expr EnhancedSimplifyModExpr(
ir::Expr e,
const absl::flat_hash_map<std::string, CasInterval>& var_intervals) {
struct Mutator : public ir::IRMutator<ir::Expr*> {
explicit Mutator(
const absl::flat_hash_map<std::string, CasInterval>& var_intervals)
: var_intervals_(var_intervals), analyzer_(var_intervals_) {}

void operator()(ir::Expr* expr) { Visit(expr); }
void Visit(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }

private:
void Visit(const ir::Mod* op, ir::Expr* expr) override {
std::optional<bool> prove_lt = analyzer_.ProveLT(op->a(), op->b());
if (prove_lt.has_value() && prove_lt.value()) {
*expr = op->a();
}
}

private:
const absl::flat_hash_map<std::string, CasInterval>& var_intervals_;
SymbolicExprAnalyzer analyzer_;
};

Mutator mutator(var_intervals);
ir::Expr copied = ir::ir_utils::IRCopy(e);
mutator(&copied);
return copied;
}

} // namespace common
} // namespace cinn
20 changes: 17 additions & 3 deletions paddle/cinn/common/integer_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,10 @@ struct SymbolicExprLimit {
// The set consisting of all integers in the interval from min to max
class SingleIntervalIntSet {
public:
explicit SingleIntervalIntSet(const ir::Expr& min,
const ir::Expr& max,
cas_intervals_t var_intervals = {});
explicit SingleIntervalIntSet(
const ir::Expr& min = SymbolicExprLimit::positive_inf,
const ir::Expr& max = SymbolicExprLimit::negative_inf,
cas_intervals_t var_intervals = {});
SingleIntervalIntSet(const SingleIntervalIntSet& set) = default;
SingleIntervalIntSet(SingleIntervalIntSet&& set) = default;
SingleIntervalIntSet& operator=(const SingleIntervalIntSet& set) = default;
Expand Down Expand Up @@ -92,5 +93,18 @@ class SingleIntervalIntSet {
cas_intervals_t var_intervals_;
};

std::optional<bool> ProveEQ(const SingleIntervalIntSet& lhs,
const SingleIntervalIntSet& rhs);
std::optional<SingleIntervalIntSet> ProvedUnion(const SingleIntervalIntSet& a,
const SingleIntervalIntSet& b);
std::optional<SingleIntervalIntSet> ProvedIntersect(
const SingleIntervalIntSet& a, const SingleIntervalIntSet& b);
cas_intervals_t MergeVarIntervals(const SingleIntervalIntSet& a,
const SingleIntervalIntSet& b);

ir::Expr EnhancedSimplifyModExpr(
ir::Expr e,
const absl::flat_hash_map<std::string, CasInterval>& var_intervals);

} // namespace common
} // namespace cinn
Loading

0 comments on commit 075ef78

Please sign in to comment.