Skip to content

Commit

Permalink
[CINN][New Hardware Update] extend SplitCudaAndHostModule (#64345)
Browse files Browse the repository at this point in the history
* [CINN][New Hardware Update] rename SplitCudaAndHostModule

* rename SplitCudaAndHostModule to SplitDeviceAndHostModule

* [CINN][New Hardware Update] fix CMakeLists

* [CINN][New Hardware Update] extend SplitDeviceAndHostModule

* fix review
  • Loading branch information
DongBaiYue authored May 17, 2024
1 parent b55bb3a commit 2188b4a
Show file tree
Hide file tree
Showing 19 changed files with 89 additions and 38 deletions.
5 changes: 3 additions & 2 deletions paddle/cinn/backends/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ gather_srcs(
extern_func_protos.cc
extern_func_jit_register.cc
modular.cc
compiler.cc)
compiler.cc
codegen_device_util.cc)

if(WITH_CUDA)
add_subdirectory(nvrtc)
list(APPEND srcs cuda_util.cc codegen_cuda_dev.cc codegen_cuda_util.cc)
list(APPEND srcs cuda_util.cc codegen_cuda_dev.cc)
endif()

if(WITH_OPENMP)
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/backends/codegen_cuda_generate_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_host.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/backends/extern_func_jit_register.h"
#include "paddle/cinn/backends/llvm/execution_engine.h"
#include "paddle/cinn/backends/llvm/simple_jit.h"
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/backends/codegen_cuda_host.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#include <string>
#include <unordered_map>

#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/backends/extern_func_emitter_builtin.h"
#include "paddle/cinn/backends/extern_func_jit_register.h"
#include "paddle/cinn/backends/llvm/llvm_util.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"

#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/common/cas.h"
Expand All @@ -22,7 +22,7 @@ PD_DECLARE_bool(cinn_bucket_compile);
namespace cinn {
namespace backends {

std::tuple<ir::Module, ir::Module> SplitCudaAndHostModule(ir::Module module) {
std::tuple<ir::Module, ir::Module> SplitDeviceAndHostModule(ir::Module module) {
if (FLAGS_cinn_bucket_compile) {
detail::CollectBucketStrategyHostFunctionVisitor visitor(module->name);
Expr expr(module);
Expand Down Expand Up @@ -91,7 +91,16 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
ir::Var kernel_ptr(GenDeviceKernelName(func_node->name, predicate),
type_of<std::string>());

Expr shared_mem_bytes = CalculateSharedMemory(func);
std::optional<Expr> shared_mem_bytes;
cinn::common::DefaultDeviceTarget().arch.Match(
[&](std::variant<common::UnknownArch, common::X86Arch, common::ARMArch>) {
CINN_NOT_IMPLEMENTED;
},
[&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
shared_mem_bytes = CalculateSharedMemory(func);
#endif
});

VLOG(6) << "Add a call node for func_node->name " << func_node->name << "\n"
<< "grid_dim: (" << func_node->cuda_axis_info.grid_dim(0) << ", "
Expand All @@ -100,10 +109,18 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
<< "block_dim: (" << func_node->cuda_axis_info.block_dim(0) << ", "
<< func_node->cuda_axis_info.block_dim(1) << ", "
<< func_node->cuda_axis_info.block_dim(2) << "), "
<< "shared_mem: " << shared_mem_bytes;
<< "shared_mem: " << shared_mem_bytes.value();
std::optional<const char *> call_kernel;
cinn::common::DefaultDeviceTarget().arch.Match(
[&](std::variant<common::UnknownArch, common::X86Arch, common::ARMArch>) {
CINN_NOT_IMPLEMENTED;
},
[&](common::NVGPUArch) {
call_kernel = runtime::intrinsic::call_cuda_kernel;
});
ir::Expr call_extern_api =
ir::Call::Make(Void(),
runtime::intrinsic::call_cuda_kernel,
call_kernel.value(),
{kernel_ptr,
kernel_args_,
kernel_args_num_,
Expand All @@ -113,7 +130,7 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
func_node->cuda_axis_info.block_dim(0), // block_x
func_node->cuda_axis_info.block_dim(1), // block_y
func_node->cuda_axis_info.block_dim(2), // block_z
shared_mem_bytes, // shared_mem
shared_mem_bytes.value(), // shared_mem
kernel_stream_},
{},
ir::CallType::Extern,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@
#include <string>
#include <tuple>
#include <vector>

#ifdef CINN_WITH_CUDA
#include "paddle/cinn/backends/codegen_cuda_dev.h"
#endif
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/runtime/flags.h"

namespace cinn {
namespace backends {
Expand All @@ -43,7 +45,7 @@ namespace backends {
* - replace the original kernel function with a Call node and add it to the
* first module, add a device kernel function to the second module.
*/
std::tuple<ir::Module, ir::Module> SplitCudaAndHostModule(ir::Module module);
std::tuple<ir::Module, ir::Module> SplitDeviceAndHostModule(ir::Module module);

namespace detail {

Expand All @@ -52,7 +54,7 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
: host_module_builder(module_name + "_host",
cinn::common::DefaultHostTarget()),
device_module_builder(module_name + "_gpu_device",
cinn::common::DefaultNVGPUTarget()) {}
cinn::common::DefaultDeviceTarget()) {}

std::tuple<ir::Module, ir::Module> operator()(Expr* expr) {
ir::IRMutator<>::Visit(expr, expr);
Expand Down Expand Up @@ -109,9 +111,18 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
// shared_mem_bytes can be calculated after codegen_cuda_dev buffer creation;
// however, this requires running CodeGenCUDA_Dev before splitting the host
// and device modules. Maybe we could reorder the process.
CodeGenCUDA_Dev codegen_dev(cinn::common::DefaultNVGPUTarget());
codegen_dev.Compile(ir::LoweredFunc(func));
Expr shared_mem_bytes = codegen_dev.GetDynSharedMemOffset();
std::optional<Expr> shared_mem_bytes;
cinn::common::DefaultDeviceTarget().arch.Match(
[&](std::variant<common::UnknownArch,
common::X86Arch,
common::ARMArch>) { CINN_NOT_IMPLEMENTED; },
[&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
CodeGenCUDA_Dev codegen_dev(cinn::common::DefaultNVGPUTarget());
codegen_dev.Compile(ir::LoweredFunc(func));
shared_mem_bytes = codegen_dev.GetDynSharedMemOffset();
#endif
});

VLOG(6) << "Add a call node for func->name " << func->name << "\n"
<< "grid_dim: (" << func->cuda_axis_info.grid_dim(0) << ", "
Expand All @@ -120,10 +131,20 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
<< "block_dim: (" << func->cuda_axis_info.block_dim(0) << ", "
<< func->cuda_axis_info.block_dim(1) << ", "
<< func->cuda_axis_info.block_dim(2) << "), "
<< "shared_mem: " << shared_mem_bytes;
<< "shared_mem: " << shared_mem_bytes.value();

std::optional<const char*> call_kernel;
cinn::common::DefaultDeviceTarget().arch.Match(
[&](std::variant<common::UnknownArch,
common::X86Arch,
common::ARMArch>) { CINN_NOT_IMPLEMENTED; },
[&](common::NVGPUArch) {
call_kernel = runtime::intrinsic::call_cuda_kernel;
});

auto call_extern_api =
ir::Call::Make(Void(),
runtime::intrinsic::call_cuda_kernel,
call_kernel.value(),
{kernel_ptr,
kernel_args,
kernel_args_num,
Expand All @@ -133,7 +154,7 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
func->cuda_axis_info.block_dim(0), // block_x
func->cuda_axis_info.block_dim(1), // block_y
func->cuda_axis_info.block_dim(2), // block_z
shared_mem_bytes,
shared_mem_bytes.value(),
kernel_stream},
{},
ir::CallType::Extern,
Expand Down
7 changes: 4 additions & 3 deletions paddle/cinn/backends/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#ifdef CINN_WITH_CUDA
#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_host.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/backends/nvrtc/nvrtc_util.h"
#include "paddle/cinn/runtime/cuda/cuda_module.h"
#include "paddle/cinn/runtime/cuda/cuda_util.h"
Expand Down Expand Up @@ -246,7 +246,7 @@ std::string Compiler::GetSourceCode(const ir::Module& module) {
[&](common::NVGPUArch) -> std::string {
#ifdef CINN_WITH_CUDA
auto _host_module_device_module_ =
SplitCudaAndHostModule(module); // NOLINT
SplitDeviceAndHostModule(module); // NOLINT
auto& host_module = std::get<0>(_host_module_device_module_);
auto& device_module = std::get<1>(_host_module_device_module_);
CodeGenCUDA_Dev codegen(target_);
Expand All @@ -270,7 +270,8 @@ void Compiler::BuildDefault(const Module& module) {
void Compiler::CompileCudaModule(const Module& module,
const std::string& code) {
#ifdef CINN_WITH_CUDA
auto _host_module_device_module_ = SplitCudaAndHostModule(module); // NOLINT
auto _host_module_device_module_ =
SplitDeviceAndHostModule(module); // NOLINT
auto& host_module = std::get<0>(_host_module_device_module_);
auto& device_module = std::get<1>(_host_module_device_module_);
VLOG(3) << "[CUDA] host module:\n" << host_module;
Expand Down
3 changes: 3 additions & 0 deletions paddle/cinn/common/arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <functional>
#include <ostream>
#include <variant>
#include "paddle/common/overloaded.h"

namespace cinn {
namespace common {
Expand Down Expand Up @@ -45,6 +46,8 @@ struct Arch final : public ArchBase {
return static_cast<const ArchBase&>(*this);
}

DEFINE_MATCH_METHOD();

bool operator==(const auto& other) const {
return this->index() == other.index();
}
Expand Down
4 changes: 2 additions & 2 deletions paddle/cinn/common/cuda_test_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_host.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/backends/nvrtc/nvrtc_util.h"
#include "paddle/cinn/runtime/cuda/cuda_module.h"
#include "paddle/cinn/runtime/cuda/cuda_util.h"
Expand All @@ -28,7 +28,7 @@ namespace common {
void CudaModuleTester::Compile(const ir::Module& m,
const std::string& rewrite_cuda_code) {
auto _host_module_device_module_ =
backends::SplitCudaAndHostModule(m); // NOLINT
backends::SplitDeviceAndHostModule(m); // NOLINT
auto& host_module = std::get<0>(_host_module_device_module_);
auto& device_module = std::get<1>(_host_module_device_module_);
CHECK(!host_module.functions().empty());
Expand Down
6 changes: 6 additions & 0 deletions paddle/cinn/common/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,12 @@ const Target &DefaultNVGPUTarget() {
return target;
}

// Returns the default target for the device backend this build was compiled
// with. Currently only the CUDA backend is wired up; other backends are
// expected to be added as the new-hardware work lands.
const Target &DefaultDeviceTarget() {
#ifdef CINN_WITH_CUDA
  return DefaultNVGPUTarget();
#else
  // Without a device backend compiled in there is no valid Target to return.
  // Falling off the end of a value-returning function is undefined behavior,
  // so fail loudly instead (same convention as the arch.Match() fallbacks
  // elsewhere in this change, which use CINN_NOT_IMPLEMENTED for
  // unsupported architectures).
  CINN_NOT_IMPLEMENTED;
#endif
}

int GetMaxThreads() {
// cudaDeviceGetAttribute ( int* value, cudaDeviceAttr attr, int device )
int max_threads = 1;
Expand Down
2 changes: 2 additions & 0 deletions paddle/cinn/common/target.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ const Target& DefaultHostTarget();

const Target& DefaultNVGPUTarget();

const Target& DefaultDeviceTarget();

const Target& DefaultTarget();

int GetMaxThreads();
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/frontend/paddle/model_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_host.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/frontend/paddle/compatible_pb.h"
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/framework/graph_compiler_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ struct CompilationContext {
void* stream = nullptr;

// Set attached source code, if code is not empty, these codes will replace
// the device_module code after SplitCudaAndHostModule.
// the device_module code after SplitDeviceAndHostModule.
void ApplySourceCode(const std::string& code);
// Apply results of auto-tune to compile.
// Compilation will start from CompilationStage::CODEGEN_AND_JIT when tuning
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/framework/op_lowering_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

#include "paddle/cinn/backends/codegen_c_x86.h"
#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/backends/llvm/execution_engine.h"
#include "paddle/cinn/backends/nvrtc/nvrtc_util.h"
Expand Down
4 changes: 2 additions & 2 deletions paddle/cinn/hlir/framework/parallel_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_host.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/backends/compiler.h"
#include "paddle/cinn/backends/llvm/codegen_x86.h"
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
Expand Down Expand Up @@ -238,7 +238,7 @@ void ParallelCompiler::Task::CodegenAndJit() {
auto ir_module = builder.Build();
if (context->target == cinn::common::DefaultNVGPUTarget()) {
#ifdef CINN_WITH_CUDA
auto splited_module = backends::SplitCudaAndHostModule(ir_module);
auto splited_module = backends::SplitDeviceAndHostModule(ir_module);
auto hmodule = std::get<0>(splited_module);
auto dmodule = std::get<1>(splited_module);

Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

#include "paddle/cinn/adt/map_expr_ctx.h"
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/cinn/hlir/framework/compile_error.h"
#include "paddle/cinn/hlir/framework/pir/op_lowering_util.h"
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/op/custom_call.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/hlir/framework/node.h"
#include "paddle/cinn/hlir/framework/op.h"
Expand Down
4 changes: 2 additions & 2 deletions paddle/cinn/hlir/op/reduction_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_host.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/backends/llvm/execution_engine.h"
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
Expand Down Expand Up @@ -116,7 +116,7 @@ std::pair<ir::Module, std::string> GenReduceCode(
// now.
auto module = builder.Build();
auto host_module_device_module =
backends::SplitCudaAndHostModule(module); // NOLINT
backends::SplitDeviceAndHostModule(module); // NOLINT
auto& host_module = std::get<0>(host_module_device_module);
auto& device_module = std::get<1>(host_module_device_module);

Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/op/transform_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_host.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/backends/llvm/execution_engine.h"
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
Expand Down
8 changes: 4 additions & 4 deletions paddle/cinn/hlir/pe/pe_transform_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include <gtest/gtest.h>

#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/backends/llvm/execution_engine.h"
#include "paddle/cinn/backends/nvrtc/nvrtc_util.h"
Expand Down Expand Up @@ -132,7 +132,7 @@ TEST(ScatterAssign, ScatterAssign) {
builder.AddFunction(func);

auto module = builder.Build();
auto host_module_device_module = backends::SplitCudaAndHostModule(module);
auto host_module_device_module = backends::SplitDeviceAndHostModule(module);
auto &host_module = std::get<0>(host_module_device_module);
auto &device_module = std::get<1>(host_module_device_module);

Expand Down Expand Up @@ -176,7 +176,7 @@ TEST(SliceAssign, SliceAssign) {
builder.AddFunction(func);

auto module = builder.Build();
auto host_module_device_module = backends::SplitCudaAndHostModule(module);
auto host_module_device_module = backends::SplitDeviceAndHostModule(module);
auto &host_module = std::get<0>(host_module_device_module);
auto &device_module = std::get<1>(host_module_device_module);

Expand Down Expand Up @@ -217,7 +217,7 @@ TEST(Concat, ConcatCase0) {
builder.AddFunction(func);

auto module = builder.Build();
auto host_module_device_module = backends::SplitCudaAndHostModule(module);
auto host_module_device_module = backends::SplitDeviceAndHostModule(module);
auto &host_module = std::get<0>(host_module_device_module);
auto &device_module = std::get<1>(host_module_device_module);

Expand Down

0 comments on commit 2188b4a

Please sign in to comment.