Skip to content

Commit

Permalink
Merge branch 'develop' of github.com:smallpoxscattered/Paddle into smallpoxscatter
Browse files Browse the repository at this point in the history
  • Loading branch information
smallpoxscattered committed Aug 21, 2024
2 parents 44cd7ec + 3901873 commit 9a050a3
Show file tree
Hide file tree
Showing 845 changed files with 26,580 additions and 11,660 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ repos:
hooks:
- id: black
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.0
rev: v0.6.1
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix, --no-cache]
Expand Down
7 changes: 0 additions & 7 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -593,13 +593,6 @@ if(WITH_RPC)
OFF
CACHE BOOL "Disable WITH_RPC when compiling with XPU" FORCE)
endif()
if(WITH_CINN AND WITH_RPC)
message(
WARNING "Disable WITH_RPC when compiling with CINN. Force WITH_RPC=OFF.")
set(WITH_RPC
OFF
CACHE BOOL "Disable WITH_RPC when compiling with CINN" FORCE)
endif()
endif()

if(WITH_MPI)
Expand Down
2 changes: 1 addition & 1 deletion cmake/external/xpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ if(NOT DEFINED XPU_XRE_BASE_VERSION)
set(XPU_XRE_BASE_VERSION "4.32.0.1")
endif()
if(NOT DEFINED XPU_XHPC_BASE_DATE)
set(XPU_XHPC_BASE_DATE "20240804")
set(XPU_XHPC_BASE_DATE "20240809")
endif()
set(XPU_XCCL_BASE_VERSION "1.2.5")
if(NOT DEFINED XPU_XFT_BASE_VERSION)
Expand Down
31 changes: 24 additions & 7 deletions paddle/cinn/adt/simplify_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,10 @@ struct SimplifyDotUndot {
pre_index_undot = index_undot_value;
}
}
CHECK(pre_index_undot.has_value());
PADDLE_ENFORCE_EQ(
pre_index_undot.has_value(),
true,
phi::errors::InvalidArgument("pre_index_undot should not be null"));
const auto& [index_value, undot_dims] =
pre_index_undot.value()
.Get<IndexUnDotValue<Value, List<DimExpr>>>()
Expand Down Expand Up @@ -195,9 +198,14 @@ struct SimplifyGcdShape {
const auto& iter_values = index_dot_values.Get<List<Value>>();
const auto& undot_dim_values = undot_dims;
const auto& dot_dim_values = dot_dims;
CHECK(IsConstantListAllPositiveInt64(undot_dim_values));
CHECK(IsConstantListAllPositiveInt64(dot_dim_values));

PADDLE_ENFORCE_EQ(IsConstantListAllPositiveInt64(undot_dim_values),
true,
phi::errors::InvalidArgument(
"The undot_dim_values should be all positive int64"));
PADDLE_ENFORCE_EQ(IsConstantListAllPositiveInt64(dot_dim_values),
true,
phi::errors::InvalidArgument(
"The dot_dim_values should be all positive int64"));
const auto& sub_reshape_dim_ranges =
GetSubReshapeDimRanges(undot_dim_values, dot_dim_values);
if (!sub_reshape_dim_ranges.has_value()) {
Expand Down Expand Up @@ -321,7 +329,10 @@ struct SimplifyDotDot {
std::int64_t Product(const List<DimExpr>& dims) {
std::int64_t ret = 1;
for (const auto& dim : *dims) {
CHECK(dim.Has<std::int64_t>());
PADDLE_ENFORCE_EQ(
dim.Has<std::int64_t>(),
true,
phi::errors::InvalidArgument("dim should have std::int64_t"));
ret *= dim.Get<std::int64_t>();
}
return ret;
Expand Down Expand Up @@ -400,7 +411,10 @@ struct SymbolicDim_SimplifyDotUndot {
pre_index_undot = index_undot_value;
}
}
CHECK(pre_index_undot.has_value());
PADDLE_ENFORCE_EQ(
pre_index_undot.has_value(),
true,
phi::errors::InvalidArgument("pre_index_undot should not be null"));
const auto& [index_value, undot_dims] =
pre_index_undot.value()
.Get<IndexUnDotValue<Value, List<DimExpr>>>()
Expand Down Expand Up @@ -447,7 +461,10 @@ struct SymbolicDim_SimplifyDotUndot_DimExpr {
pre_index_undot = index_undot_value;
}
}
CHECK(pre_index_undot.has_value());
PADDLE_ENFORCE_EQ(
pre_index_undot.has_value(),
true,
phi::errors::InvalidArgument("pre_index_undot should not be null"));
const auto& [index_value, undot_dims] =
pre_index_undot.value()
.Get<IndexUnDotValue<Value, List<DimExpr>>>()
Expand Down
19 changes: 16 additions & 3 deletions paddle/cinn/backends/codegen_cuda_host.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ llvm::Value* CodeGenCudaHost::LowerGPUKernelLauncher(
const ir::_LoweredFunc_* func) {
auto body = func->body;
auto* call_ir = body.As<ir::Call>();
CHECK(call_ir);
PADDLE_ENFORCE_EQ(
call_ir,
nullptr,
phi::errors::InvalidArgument("The 'call_ir' must be true."));

// Create the function
// @{
Expand Down Expand Up @@ -144,7 +147,12 @@ llvm::Value* CodeGenCudaHost::LowerGPUKernelLauncher(
b_->getInt8PtrTy(), kvalue, r_arg.as_var()->name + "_ptr_load"));
} else if (r_arg.as_var()->type().is_cpp_handle() ||
r_arg.as_var()->type().is_int(32)) {
CHECK(global_args.count(r_arg.as_var()->name));
PADDLE_ENFORCE_EQ(
global_args.count(r_arg.as_var()->name),
1,
phi::errors::InvalidArgument(
"The argument '%s' must be present in global_args.",
r_arg.as_var()->name.c_str()));
call_args.push_back(global_args[r_arg.as_var()->name]);
} else {
CINN_NOT_IMPLEMENTED;
Expand Down Expand Up @@ -285,7 +293,12 @@ llvm::Value* CodeGenCudaHost::LowerGPUKernelCall(const ir::Call* call_ir) {
call_args.push_back(b_->CreateLoad(
b_->getInt8PtrTy(), kvalue, r_arg.as_var()->name + "_ptr_load"));
} else if (r_arg.as_var()->type().is_cpp_handle()) {
CHECK(global_args.count(r_arg.as_var()->name));
PADDLE_ENFORCE_EQ(
global_args.count(r_arg.as_var()->name),
1,
phi::errors::InvalidArgument(
"The argument '%s' must be present in global_args.",
r_arg.as_var()->name.c_str()));
call_args.push_back(global_args[r_arg.as_var()->name]);
} else if (r_arg.as_var()->type().is_int()) {
call_args.push_back(GetVar(r_arg.as_var()->name, false));
Expand Down
19 changes: 15 additions & 4 deletions paddle/cinn/backends/codegen_invoke_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,17 @@ llvm::Value* CodeGenInvokeModule::LowerParseArgsValueCall(
2,
::common::errors::InvalidArgument(
"The number of arguments of ParseArgsValue should be 2"));
CHECK(call_ir->read_args[0].is_var() &&
call_ir->read_args[0].as_var()->type().is_cpp_handle());
CHECK(call_ir->read_args[1].type().is_int(32));
PADDLE_ENFORCE_EQ(
call_ir->read_args[0].is_var() &&
call_ir->read_args[0].as_var()->type().is_cpp_handle(),
true,
phi::errors::InvalidArgument("The first read argument must be a variable "
"with a C++ handle type."));

PADDLE_ENFORCE_EQ(call_ir->read_args[1].type().is_int(32),
true,
phi::errors::InvalidArgument(
"The second read argument must be of type int32."));
args_type.push_back(CinnTypeToLLVMType(type_of<void*>(), m_));
args_type.push_back(CinnTypeToLLVMType(type_of<int32_t>(), m_));

Expand All @@ -94,7 +102,10 @@ llvm::Value* CodeGenSwitchHost::LowerInnerCaseCall(const ir::Call* op) {
[](auto& arg) { return std::addressof(arg); });
// TODO(Hongqing-work): Add check for parameter type
llvm::Function* call_func = m_->getFunction(op->name);
CHECK(call_func) << "Unknown function referenced. [" << op->name << "]";
PADDLE_ENFORCE_NOT_NULL(
call_func,
phi::errors::InvalidArgument("Unknown function referenced. [%s]",
op->name.c_str()));
b_->CreateCall(call_func, ll_function_args);
return nullptr;
}
Expand Down
20 changes: 13 additions & 7 deletions paddle/cinn/backends/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,10 @@ void Compiler::RegisterCudaModuleSymbol() {
nvrtc::Compiler compiler;
std::string source_code = CodeGenCudaDev::GetSourceHeader() + device_fn_code_;
auto ptx = compiler(source_code);
CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n"
<< source_code;
PADDLE_ENFORCE_EQ(
!ptx.empty(),
true,
phi::errors::InvalidArgument("Compile PTX failed from source code\n"));
using runtime::cuda::CUDAModule;
cuda_module_.reset(new CUDAModule(ptx,
compiler.compile_to_cubin()
Expand All @@ -341,7 +343,9 @@ void Compiler::RegisterCudaModuleSymbol() {
RuntimeSymbols symbols;
for (const auto& kernel_fn_name : device_fn_name_) {
auto fn_kernel = cuda_module_->GetFunction(kernel_fn_name);
CHECK(fn_kernel) << "Fail to get CUfunction kernel_fn_name";
PADDLE_ENFORCE_NOT_NULL(
fn_kernel,
phi::errors::InvalidArgument("Fail to get CUfunction kernel_fn_name"));
fn_ptr_.push_back(reinterpret_cast<void*>(fn_kernel));
symbols.RegisterVar(kernel_fn_name + "_ptr_",
reinterpret_cast<void*>(fn_kernel));
Expand Down Expand Up @@ -407,9 +411,10 @@ void Compiler::CompileCudaModule(const Module& module,
source_code = code;
}

CHECK(!source_code.empty())
<< "Compile CUDA C code failed from device module:\n"
<< device_module;
PADDLE_ENFORCE_EQ(!source_code.empty(),
true,
phi::errors::InvalidArgument(
"Compile CUDA C code failed from device module"));
VLOG(3) << "[CUDA] C:\n" << source_code;
SourceCodePrint::GetInstance()->write(source_code);
device_fn_code_ += source_code;
Expand Down Expand Up @@ -470,7 +475,8 @@ void Compiler::ExportObject(const std::string& path) {
}

void* Compiler::Lookup(absl::string_view fn_name) {
CHECK(engine_);
PADDLE_ENFORCE_NOT_NULL(
engine_, phi::errors::InvalidArgument("Sorry, engine_ is nullptr"));
if (engine_->Lookup(fn_name) != nullptr) {
return engine_->Lookup(fn_name);
}
Expand Down
10 changes: 8 additions & 2 deletions paddle/cinn/backends/extern_func_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ void ExternFunctionEmitterRegistry::Register(const ExternFuncID& name,
utils::GetStreamCnt(name).c_str());
}
#endif // CINN_WITH_DEBUG
CHECK(!x.empty()) << "Extern Function name is empty.";
PADDLE_ENFORCE_EQ(
!x.empty(),
true,
phi::errors::InvalidArgument("Extern Function name is empty."));
data_[name] = x;
}

Expand All @@ -68,7 +71,10 @@ ExternFunctionEmitterRegistry::ExternFunctionEmitterRegistry() {}

const FunctionProto& ExternFunctionEmitter::func_proto() const {
auto* proto = ExternFunctionProtoRegistry::Global().Lookup(func_name());
CHECK(proto) << "No prototype of function [" << func_name() << "]";
PADDLE_ENFORCE_NOT_NULL(
proto,
phi::errors::InvalidArgument("No prototype of function [" +
std::string(func_name()) + "]"));
return *proto;
}

Expand Down
11 changes: 7 additions & 4 deletions paddle/cinn/backends/llvm/execution_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,15 +177,18 @@ void ExecutionEngine::Link(const ir::Module &module) {
VLOG(3) << "ir_emitter->Compile(module) Begin";
ir_emitter->Compile(module);
VLOG(3) << "ir_emitter->Compile(module) Succeed!";
CHECK(!llvm::verifyModule(*m, &llvm::errs())) << "Invalid module found";

PADDLE_ENFORCE_EQ(!llvm::verifyModule(*m, &llvm::errs()),
true,
phi::errors::InvalidArgument("Sorry,Invalid module found"));
auto machine = std::move(llvm::cantFail(
llvm::cantFail(llvm::orc::JITTargetMachineBuilder::detectHost())
.createTargetMachine()));
LLVMModuleOptimizer optimize(machine.get(), 3, {}, true);
optimize(m.get());
CHECK(!llvm::verifyModule(*m, &llvm::errs()))
<< "Invalid optimized module detected";
PADDLE_ENFORCE_EQ(
!llvm::verifyModule(*m, &llvm::errs()),
true,
phi::errors::InvalidArgument("Invalid optimized module detected"));
for (auto &f : *m) {
VLOG(5) << "function: " << DumpToString(f);
}
Expand Down
34 changes: 23 additions & 11 deletions paddle/cinn/common/arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,15 +147,19 @@ GiNaC::ex ExprToGinacConverter::operator()(Expr expr) {
n->As<IfThenElse>();
});

CHECK(complex_nodes.empty()) << "Ginac converter can only deal with simple "
"math expression, but get some complex nodes"
<< expr;

PADDLE_ENFORCE_EQ(complex_nodes.empty(),
true,
::common::errors::InvalidArgument(
"Ginac converter can only deal with simple math "
"expression, but get some complex nodes."));
return BuildHelper(expr);
}

GiNaC::symbol ExprToGinacConverter::CreateGinacSymbol(const std::string& repr) {
CHECK(!repr.empty());
PADDLE_ENFORCE_EQ(
!repr.empty(),
true,
::common::errors::InvalidArgument("The repr should not be empty."));
auto it = repr_to_ginac_.find(repr);
if (it != repr_to_ginac_.end()) return it->second;

Expand All @@ -165,7 +169,9 @@ GiNaC::symbol ExprToGinacConverter::CreateGinacSymbol(const std::string& repr) {
}

GiNaC::symbol ExprToGinacConverter::CreateGinacSymbol(const ir::Expr& var) {
CHECK(var.As<_Var_>());
PADDLE_ENFORCE_NOT_NULL(
var.As<_Var_>(),
::common::errors::InvalidArgument("The var should not be nullptr."));
return CreateGinacSymbol(Repr(var));
}

Expand All @@ -191,8 +197,10 @@ class GiNaCToExprVisitor : public GiNaC::symbol::visitor,

void visit(const GiNaC::symbol& node) override {
auto it = repr_to_expr.find(node.get_name());
CHECK(it != repr_to_expr.end())
<< "node [" << node.get_name() << "] not found";
PADDLE_ENFORCE_NE(
it,
repr_to_expr.end(),
::common::errors::InvalidArgument("The node should be found."));
cur = it->second;
}

Expand Down Expand Up @@ -221,7 +229,9 @@ class GiNaCToExprVisitor : public GiNaC::symbol::visitor,
node.op(1).accept(*this);

auto* intv = cur.As<IntImm>();
CHECK(intv);
PADDLE_ENFORCE_NOT_NULL(
intv,
::common::errors::InvalidArgument("The intv should not be nullptr."));
PADDLE_ENFORCE_EQ(
intv->value,
-1,
Expand Down Expand Up @@ -313,8 +323,10 @@ std::tuple<Expr, bool /*positive*/> Solve(Expr lhs, Expr rhs, Var var) {
// tell the symbol
auto diff = lhs_ex - rhs_ex;
auto diff_res = ginac::diff(diff, symbol);
CHECK(!diff_res.is_zero());

PADDLE_ENFORCE_EQ(
!diff_res.is_zero(),
true,
::common::errors::InvalidArgument("The diff_res should not be zero."));
return std::make_tuple(value, diff_res > 0);
}

Expand Down
10 changes: 8 additions & 2 deletions paddle/cinn/common/dev_info_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,17 @@ class DevInfoMgr final {
using RetType = typename GetDevType<arch>::DevType;

const RetType* operator->() const {
CHECK(!std::is_void<RetType>()) << "Current device can't be recognized!\n";
PADDLE_ENFORCE_EQ(
!std::is_void<RetType>(),
true,
phi::errors::InvalidArgument("Current device can't be recognized!"));
return dynamic_cast<const RetType*>(impl_.get());
}
RetType* operator->() {
CHECK(!std::is_void<RetType>()) << "Current device can't be recognized!\n";
PADDLE_ENFORCE_EQ(
!std::is_void<RetType>(),
true,
phi::errors::InvalidArgument("Current device can't be recognized!"));
return dynamic_cast<RetType*>(impl_.get());
}
};
Expand Down
Loading

0 comments on commit 9a050a3

Please sign in to comment.