Skip to content

Commit

Permalink
Merge branch 'develop' into fix_gate
Browse files Browse the repository at this point in the history
  • Loading branch information
co63oc committed Jul 18, 2024
2 parents 990d29e + 1c15774 commit 2d0d54a
Show file tree
Hide file tree
Showing 162 changed files with 1,769 additions and 1,233 deletions.
13 changes: 8 additions & 5 deletions paddle/cinn/backends/codegen_device_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ ir::Module CreateSwitchWithBroadcastConditionModule(
ir::Argument(kernel_args_num, ir::Argument::IO::kInput),
ir::Argument(tensor_shape_args, ir::Argument::IO::kOutput)};

const auto &CreateSymbolArgDefines = [&]() -> std::vector<ir::Expr> {
const auto &symbolic_arg_define = [&]() -> std::vector<ir::Expr> {
std::vector<ir::Expr> arg_defs;
for (const auto &item : symbolic_shape_var_index) {
ir::Expr call_get_value_in_kernel_args =
Expand All @@ -68,13 +68,13 @@ ir::Module CreateSwitchWithBroadcastConditionModule(
arg_defs.push_back(stmt);
}
return arg_defs;
};
}();

const auto &CreateSwitchFunction =
[&](std::vector<ir::Argument> func_arguments,
const std::vector<ir::Expr> &read_args,
std::string name_extend) -> ir::Expr {
std::vector<ir::Expr> body_stmts(CreateSymbolArgDefines());
std::vector<ir::Expr> body_stmts(symbolic_arg_define);
for (int i = 0; i < broadcast_conditions.size(); ++i) {
ir::Expr callee = ir::Call::Make(Void(),
case_func_names[i] + name_extend,
Expand Down Expand Up @@ -113,8 +113,11 @@ ir::Module CreateSwitchWithBroadcastConditionModule(
module_builder.AddFunctionWithoutOptim(
infer_shape_func_caller.as_lowered_func_ref());
// No real CX86 function is needed here; a stub with an empty body is registered instead.
ir::Expr cx86_func_caller = ir::_LoweredFunc_::Make(
wrapper_func_name + "_CX86", host_func_arguments, ir::Expr(), {});
ir::Expr cx86_func_caller =
ir::_LoweredFunc_::Make(wrapper_func_name + "_CX86",
host_func_arguments,
ir::Block::Make({}),
{});
module_builder.AddFunctionWithoutOptim(
cx86_func_caller.as_lowered_func_ref());
return module_builder.Build();
Expand Down
41 changes: 19 additions & 22 deletions paddle/cinn/backends/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -227,32 +227,30 @@ void SourceCodePrint::write(const std::string& source_code) {
}
}

void Compiler::Build(const Module& module,
const std::string& code,
const bool end) {
auto PatternMatch = adt::match{
void Compiler::Build(const Module& module, const std::string& code) {
target_.arch.Match(
[&](common::UnknownArch) { CINN_NOT_IMPLEMENTED; },
[&](common::X86Arch) { CompileX86Module(module); },
[&](common::ARMArch) { CINN_NOT_IMPLEMENTED; },
[&](common::NVGPUArch) { CompileCudaModule(module, code); },
[&](common::HygonDCUArchHIP) { CompileHipModule(module, code); }};
std::visit(PatternMatch, target_.arch.variant());
if (end) {
RegisterDeviceModuleSymbol();
engine_->AddSelfModule();
}
[&](common::HygonDCUArchHIP) { CompileHipModule(module, code); });
}

void Compiler::AppendCX86(const Module& module, const bool end) {
void Compiler::AppendCX86(const Module& module) {
VLOG(3) << "Start Compiler::BuildCX86" << module;
CompileX86Module(module);
if (end) {
RegisterDeviceModuleSymbol();
engine_->AddSelfModule();
}
VLOG(3) << "Over Compiler::BuildCX86";
}

// Links a broadcast-condition switch module into the execution engine,
// dispatching through the switch-host code generator.
void Compiler::AppendBroadcastSwitchModule(const ir::Module& module) {
  (*engine_).Link<CodeGenSwitchHost>(module);
}

// Finalizes a multi-step compilation: registers the device module's symbols
// with the execution engine, then adds the engine's own module.
// NOTE(review): the ordering looks intentional — symbols are registered
// before AddSelfModule() — confirm before reordering.
void Compiler::EndCompile() {
  RegisterDeviceModuleSymbol();
  engine_->AddSelfModule();
}

std::string Compiler::GetSourceCode(const ir::Module& module) {
return target_.arch.Match(
[&](common::UnknownArch) -> std::string { CINN_NOT_IMPLEMENTED; },
Expand Down Expand Up @@ -305,13 +303,12 @@ std::string GetFileContent(const std::string& path) {
} // namespace

// Registers the compiled device module's symbols with the execution engine.
// Only NVGPU performs registration; x86/ARM are no-ops, and unknown/HIP
// targets hit CINN_NOT_IMPLEMENTED.
// NOTE(review): the scraped diff interleaved the old std::visit/adt::match
// body with the new arch.Match body; this is the coherent post-merge version.
void Compiler::RegisterDeviceModuleSymbol() {
  return target_.arch.Match(
      [&](common::UnknownArch) { CINN_NOT_IMPLEMENTED; },
      [&](common::X86Arch) { return; },
      [&](common::ARMArch) { return; },
      [&](common::NVGPUArch) { RegisterCudaModuleSymbol(); },
      [&](common::HygonDCUArchHIP) { CINN_NOT_IMPLEMENTED; });
}

void Compiler::RegisterCudaModuleSymbol() {
Expand Down
11 changes: 7 additions & 4 deletions paddle/cinn/backends/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,13 @@ class Compiler final {
/**
* Compile and link to a CINN module.
*/
void Build(const ir::Module& module,
const std::string& code = "",
const bool end = true);
void AppendCX86(const ir::Module& module, const bool end = true);
void Build(const ir::Module& module, const std::string& code = "");

void AppendCX86(const ir::Module& module);

void AppendBroadcastSwitchModule(const ir::Module& module);

void EndCompile();

void ExportObject(const std::string& path);

Expand Down
4 changes: 3 additions & 1 deletion paddle/cinn/backends/llvm/execution_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ void ExecutionEngine::RegisterModuleRuntimeSymbols(
module_symbols_ = std::forward<RuntimeSymbols>(module_symbols);
auto *session = &jit_->getExecutionSession();
for (const auto &sym : module_symbols_.All()) {
VLOG(0) << "Add symbol: {" << sym.first << ":" << sym.second << "}";
VLOG(3) << "Add symbol: {" << sym.first << ":" << sym.second << "}";
llvm::cantFail(jit_->define(llvm::orc::absoluteSymbols(
{{session->intern(sym.first),
{llvm::pointerToJITTargetAddress(sym.second),
Expand Down Expand Up @@ -276,5 +276,7 @@ void ExecutionEngine::RegisterGlobalRuntimeSymbols() {
// Explicit instantiations of ExecutionEngine::Link<> for each host-side code
// generator used by the compiler (CodeGenSwitchHost is the newly added one).
template void ExecutionEngine::Link<CodeGenLLVM>(const ir::Module &module);
template void ExecutionEngine::Link<CodeGenX86>(const ir::Module &module);
template void ExecutionEngine::Link<CodeGenCUDA_Host>(const ir::Module &module);
template void ExecutionEngine::Link<CodeGenSwitchHost>(
    const ir::Module &module);

} // namespace cinn::backends
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ class BlockDimExprsAsserter {
for (int i = 0; i < op->num_operands(); ++i) {
const auto& input = op->operand_source(i);
if (!input || !input.type()) continue;
if (input.type().isa<pir::VectorType>()) {
if (input.type().isa<pir::VectorType>() ||
input.type().isa<paddle::dialect::DenseTensorArrayType>()) {
return std::vector<pir::Value>{};
}
inputs.push_back(input);
Expand Down
Loading

0 comments on commit 2d0d54a

Please sign in to comment.