Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

module mlir supported on ROCm #99

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/gpu/include/tfrt/gpu/wrapper/hip_forwards.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ using ncclComm_t = struct ncclComm *;

// Forward declaration of hipFFT types.
using hipfftHandle = struct hipfftHandle_t *;
// Forward declaration of hiprtcProgram
using hiprtcProgram = struct _hiprtcProgram *;
// Enums for corresponding #defines in the hipFFT headers.
enum hipfftDirection_t : int {
HIPFFT_FORWARD = -1,
Expand Down
23 changes: 23 additions & 0 deletions backends/gpu/include/tfrt/gpu/wrapper/hip_stub.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,29 @@ extern "C" {

const char* hipGetErrorName(hipError_t hip_error);
const char* hipGetErrorString(hipError_t hip_error);
const char *hiprtcGetErrorString(hiprtcResult result);
hiprtcResult hiprtcVersion(int* major, int* minor);
hiprtcResult hiprtcAddNameExpression(hiprtcProgram prog, const char* name_expression);
hiprtcResult hiprtcCompileProgram(
hiprtcProgram prog,
int numOptions,
const char** options);
hiprtcResult hiprtcCreateProgram(
hiprtcProgram* prog,
const char* src,
const char* name,
int numberHeaders,
char** headers,
const char** includeNames);
hiprtcResult hiprtcDestroyProgram(hiprtcProgram* prog);
hiprtcResult hiprtcGetLoweredName(
hiprtcProgram prog,
const char* name_expression,
const char** lowered_name);
hiprtcResult hiprtcGetProgramLog(hiprtcProgram prog, char* log);
hiprtcResult hiprtcGetProgramLogSize(hiprtcProgram prog, size_t* logSizeRet);
hiprtcResult hiprtcGetCode(hiprtcProgram prog, char* code);
hiprtcResult hiprtcGetCodeSize(hiprtcProgram prog, size_t* codeSizeRet);

// Enums for corresponding #defines in the HIP headers.
enum hipDeviceFlags_t {
Expand Down
3 changes: 3 additions & 0 deletions backends/gpu/include/tfrt/gpu/wrapper/hip_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ namespace gpu {
namespace wrapper {

raw_ostream& Print(raw_ostream& os, hipError_t error);
raw_ostream& Print(raw_ostream& os, hiprtcResult result);

namespace internal {
template <>
Expand Down Expand Up @@ -162,6 +163,8 @@ llvm::Error HipMemsetD32Async(CurrentContext current, Pointer<void> dst,

llvm::Expected<OwningModule> HipModuleLoadData(CurrentContext current,
const void* image);
llvm::Expected<OwningModule> HipRTCModuleLoadData(CurrentContext current,
const void* image);
llvm::Expected<OwningModule> HipModuleLoadDataEx(
CurrentContext current, const void* image,
llvm::ArrayRef<hipJitOption> options, llvm::ArrayRef<void*> option_values);
Expand Down
1 change: 0 additions & 1 deletion backends/gpu/lib/kernels/driver_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,6 @@ static Expected<GpuModule> GpuModuleLoad(Argument<GpuContext> context,
MakeStringError("GPU JIT error log: ", error_log));
}
#endif

return GpuModule(context.ValueRef(), std::move(*module));
}

Expand Down
2 changes: 1 addition & 1 deletion backends/gpu/lib/wrapper/driver_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -852,7 +852,7 @@ llvm::Expected<OwningModule> ModuleLoadData(CurrentContext current,
case Platform::CUDA:
return CuModuleLoadData(current, image);
case Platform::ROCm:
return HipModuleLoadData(current, image);
return HipRTCModuleLoadData(current, image);
default:
return InvalidPlatform(platform);
}
Expand Down
89 changes: 89 additions & 0 deletions backends/gpu/lib/wrapper/hip_stub.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,92 @@ const char *hipGetErrorString(hipError_t hip_error) {
if (!func_ptr) return "FAILED_TO_LOAD_FUNCTION_SYMBOL";
return func_ptr(hip_error);
}

const char *hiprtcGetErrorString(hiprtcResult result) {
static auto func_ptr =
GetFunctionPointer("hiprtcGetErrorString", hiprtcGetErrorString);
if (!func_ptr) return "FAILED_TO_LOAD_FUNCTION_SYMBOL";
return func_ptr(result);
}

hiprtcResult hiprtcVersion(int* major, int* minor){
static auto func_ptr =
GetFunctionPointer("hiprtcVersion", hiprtcVersion);
if (!func_ptr) return HIPRTC_ERROR_INTERNAL_ERROR;
return func_ptr(major, minor);
}

hiprtcResult hiprtcAddNameExpression(hiprtcProgram prog, const char* name_expression){
static auto func_ptr =
GetFunctionPointer("hiprtcAddNameExpression", hiprtcAddNameExpression);
if (!func_ptr) return HIPRTC_ERROR_NAME_EXPRESSION_NOT_VALID;
return func_ptr(prog, name_expression);
}

hiprtcResult hiprtcCompileProgram(
hiprtcProgram prog,
int numOptions,
const char** options){
static auto func_ptr =
GetFunctionPointer("hiprtcCompileProgram", hiprtcCompileProgram);
if (!func_ptr) return HIPRTC_ERROR_INTERNAL_ERROR;
return func_ptr(prog, numOptions, options);
}

hiprtcResult hiprtcCreateProgram(
hiprtcProgram* prog,
const char* src,
const char* name,
int numberHeaders,
char** headers,
const char** includeNames){
static auto func_ptr =
GetFunctionPointer("hiprtcCreateProgram", hiprtcCreateProgram);
if (!func_ptr) return HIPRTC_ERROR_PROGRAM_CREATION_FAILURE;
return func_ptr(prog, src, name, numberHeaders, headers, includeNames);
}

hiprtcResult hiprtcDestroyProgram(hiprtcProgram* prog){
static auto func_ptr =
GetFunctionPointer("hiprtcDestroyProgram", hiprtcDestroyProgram);
if (!func_ptr) return HIPRTC_ERROR_INTERNAL_ERROR;
return func_ptr(prog);
}

hiprtcResult hiprtcGetLoweredName(
hiprtcProgram prog,
const char* name_expression,
const char** lowered_name){
static auto func_ptr =
GetFunctionPointer("hiprtcGetLoweredName", hiprtcGetLoweredName);
if (!func_ptr) return HIPRTC_ERROR_INTERNAL_ERROR;
return func_ptr(prog, name_expression, lowered_name);
}

hiprtcResult hiprtcGetProgramLog(hiprtcProgram prog, char* log){
static auto func_ptr =
GetFunctionPointer("hiprtcGetProgramLog", hiprtcGetProgramLog);
if (!func_ptr) return HIPRTC_ERROR_INTERNAL_ERROR;
return func_ptr(prog, log);
}

hiprtcResult hiprtcGetProgramLogSize(hiprtcProgram prog, size_t* logSizeRet){
static auto func_ptr =
GetFunctionPointer("hiprtcGetProgramLogSize", hiprtcGetProgramLogSize);
if (!func_ptr) return HIPRTC_ERROR_INTERNAL_ERROR;
return func_ptr(prog, logSizeRet);
}

hiprtcResult hiprtcGetCode(hiprtcProgram prog, char* code){
static auto func_ptr =
GetFunctionPointer("hiprtcGetCode", hiprtcGetCode);
if (!func_ptr) return HIPRTC_ERROR_INTERNAL_ERROR;
return func_ptr(prog, code);
}

hiprtcResult hiprtcGetCodeSize(hiprtcProgram prog, size_t* codeSizeRet){
static auto func_ptr =
GetFunctionPointer("hiprtcGetCodeSize", hiprtcGetCodeSize);
if (!func_ptr) return HIPRTC_ERROR_INTERNAL_ERROR;
return func_ptr(prog, codeSizeRet);
}
45 changes: 45 additions & 0 deletions backends/gpu/lib/wrapper/hip_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ llvm::raw_ostream& Print(llvm::raw_ostream& os, hipError_t error) {
return os;
}

llvm::raw_ostream& Print(llvm::raw_ostream& os, hiprtcResult result) {
const char* msg = hiprtcGetErrorString(result);
if (msg != nullptr) os << "hiprtc Error: (" << msg << ")";
return os;
}

// Convert wrapper types to HIP types.
static hipDevice_t ToRocm(Device device) { return device.id(Platform::ROCm); }

Expand Down Expand Up @@ -540,6 +546,45 @@ llvm::Expected<OwningModule> HipModuleLoadData(CurrentContext current,
return OwningModule(module);
}

llvm::Expected<OwningModule> HipRTCModuleLoadData(CurrentContext current,
const void* image) {
CheckHipContext(current);
hiprtcProgram prog;
//auto img = reinterpret_cast<const char*>(const_cast<void*>(image));
auto kernel = static_cast<const char*>(image);
std::string kname(kernel);
kname += ".cu";
RETURN_IF_ERROR(hiprtcCreateProgram(&prog,
kernel,
kname.c_str(),
0,
nullptr,
nullptr
));
hiprtcResult compileResult = hiprtcCompileProgram(prog, 0, nullptr);
if (compileResult != HIPRTC_SUCCESS) {
size_t logSize;
hiprtcGetProgramLogSize(prog, &logSize);
if (logSize) {
std::string log(logSize, '\0');
hiprtcGetProgramLog(prog, &log[0]);
MakeStringError(log.c_str());
}
}

size_t code_size;
RETURN_IF_ERROR(hiprtcGetCodeSize(prog, &code_size));
std::vector<char> code(code_size);
RETURN_IF_ERROR(hiprtcGetCode(prog, code.data()));
RETURN_IF_ERROR(hiprtcDestroyProgram(&prog));

hipModule_t module;
RETURN_IF_ERROR(hipModuleLoadData(&module, code.data()));

NotifyResourceCreated(ResourceType::kModule, module);
return OwningModule(module);
}

llvm::Expected<OwningModule> HipModuleLoadDataEx(
CurrentContext current, const void* image,
llvm::ArrayRef<hipJitOption> options, llvm::ArrayRef<void*> option_values) {
Expand Down
Loading