Skip to content

Commit

Permalink
Fix USM function pointers caching (intel#1008)
Browse files Browse the repository at this point in the history
Fix function pointer caching to properly distinguish different functions
as func ptr types are not unique.

Signed-off-by: James Brodman <[email protected]>
  • Loading branch information
jbrodman authored and romanovvlad committed Jan 15, 2020
1 parent fba2e06 commit fca1736
Showing 1 changed file with 46 additions and 28 deletions.
74 changes: 46 additions & 28 deletions sycl/plugins/opencl/pi_opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,21 @@ template <class To, class From> To cast(From value) {
return (To)(value);
}

// Names of USM functions that are queried from OpenCL
const char clHostMemAllocName[] = "clHostMemAllocINTEL";
const char clDeviceMemAllocName[] = "clDeviceMemAllocINTEL";
const char clSharedMemAllocName[] = "clSharedMemAllocINTEL";
const char clMemFreeName[] = "clMemFreeINTEL";
const char clSetKernelArgMemPointerName[] = "clSetKernelArgMemPointerINTEL";
const char clEnqueueMemsetName[] = "clEnqueueMemsetINTEL";
const char clEnqueueMemcpyName[] = "clEnqueueMemcpyINTEL";
const char clEnqueueMigrateMemName[] = "clEnqueueMigrateMemINTEL";
const char clEnqueueMemAdviseName[] = "clEnqueueMemAdviseINTEL";
const char clGetMemAllocInfoName[] = "clGetMemAllocInfoINTEL";

// USM helper function to get an extension function pointer
template <typename T>
pi_result getExtFuncFromContext(pi_context context, const char *func, T *fptr) {
template <const char *FuncName, typename T>
static pi_result getExtFuncFromContext(pi_context context, T *fptr) {
// TODO
// Potentially redo caching as PI interface changes.
thread_local static std::map<pi_context, T> FuncPtrs;
Expand Down Expand Up @@ -68,11 +80,11 @@ pi_result getExtFuncFromContext(pi_context context, const char *func, T *fptr) {
return PI_INVALID_CONTEXT;
}

T FuncPtr = (T) clGetExtensionFunctionAddressForPlatform(curPlatform,
func);
if (!FuncPtr) {
T FuncPtr =
(T)clGetExtensionFunctionAddressForPlatform(curPlatform, FuncName);

if (!FuncPtr)
return PI_INVALID_VALUE;
}

*fptr = FuncPtr;
FuncPtrs[context] = FuncPtr;
Expand All @@ -98,24 +110,24 @@ static pi_result USMSetIndirectAccess(pi_kernel kernel) {
return cast<pi_result>(CLErr);
}

getExtFuncFromContext<clHostMemAllocINTEL_fn>(cast<pi_context>(CLContext),
"clHostMemAllocINTEL", &HFunc);
getExtFuncFromContext<clHostMemAllocName, clHostMemAllocINTEL_fn>(
cast<pi_context>(CLContext), &HFunc);
if (HFunc) {
clSetKernelExecInfo(cast<cl_kernel>(kernel),
CL_KERNEL_EXEC_INFO_INDIRECT_HOST_ACCESS_INTEL,
sizeof(cl_bool), &TrueVal);
}

getExtFuncFromContext<clDeviceMemAllocINTEL_fn>(
cast<pi_context>(CLContext), "clDeviceMemAllocINTEL", &DFunc);
getExtFuncFromContext<clDeviceMemAllocName, clDeviceMemAllocINTEL_fn>(
cast<pi_context>(CLContext), &DFunc);
if (DFunc) {
clSetKernelExecInfo(cast<cl_kernel>(kernel),
CL_KERNEL_EXEC_INFO_INDIRECT_DEVICE_ACCESS_INTEL,
sizeof(cl_bool), &TrueVal);
}

getExtFuncFromContext<clSharedMemAllocINTEL_fn>(
cast<pi_context>(CLContext), "clSharedMemAllocINTEL", &SFunc);
getExtFuncFromContext<clSharedMemAllocName, clSharedMemAllocINTEL_fn>(
cast<pi_context>(CLContext), &SFunc);
if (SFunc) {
clSetKernelExecInfo(cast<cl_kernel>(kernel),
CL_KERNEL_EXEC_INFO_INDIRECT_SHARED_ACCESS_INTEL,
Expand Down Expand Up @@ -569,8 +581,8 @@ pi_result OCL(piextUSMHostAlloc)(void **result_ptr, pi_context context,

// First we need to look up the function pointer
clHostMemAllocINTEL_fn FuncPtr = nullptr;
RetVal = getExtFuncFromContext<clHostMemAllocINTEL_fn>(
context, "clHostMemAllocINTEL", &FuncPtr);
RetVal = getExtFuncFromContext<clHostMemAllocName, clHostMemAllocINTEL_fn>(
context, &FuncPtr);

if (FuncPtr) {
Ptr = FuncPtr(cast<cl_context>(context),
Expand Down Expand Up @@ -601,8 +613,9 @@ pi_result OCL(piextUSMDeviceAlloc)(void **result_ptr, pi_context context,

// First we need to look up the function pointer
clDeviceMemAllocINTEL_fn FuncPtr = nullptr;
RetVal = getExtFuncFromContext<clDeviceMemAllocINTEL_fn>(
context, "clDeviceMemAllocINTEL", &FuncPtr);
RetVal =
getExtFuncFromContext<clDeviceMemAllocName, clDeviceMemAllocINTEL_fn>(
context, &FuncPtr);

if (FuncPtr) {
Ptr = FuncPtr(cast<cl_context>(context), cast<cl_device_id>(device),
Expand Down Expand Up @@ -633,8 +646,9 @@ pi_result OCL(piextUSMSharedAlloc)(void **result_ptr, pi_context context,

// First we need to look up the function pointer
clSharedMemAllocINTEL_fn FuncPtr = nullptr;
RetVal = getExtFuncFromContext<clSharedMemAllocINTEL_fn>(
context, "clSharedMemAllocINTEL", &FuncPtr);
RetVal =
getExtFuncFromContext<clSharedMemAllocName, clSharedMemAllocINTEL_fn>(
context, &FuncPtr);

if (FuncPtr) {
Ptr = FuncPtr(cast<cl_context>(context), cast<cl_device_id>(device),
Expand All @@ -655,8 +669,8 @@ pi_result OCL(piextUSMFree)(pi_context context, void *ptr) {

clMemFreeINTEL_fn FuncPtr = nullptr;
pi_result RetVal = PI_INVALID_OPERATION;
RetVal = getExtFuncFromContext<clMemFreeINTEL_fn>(context, "clMemFreeINTEL",
&FuncPtr);
RetVal = getExtFuncFromContext<clMemFreeName, clMemFreeINTEL_fn>(context,
&FuncPtr);

if (FuncPtr) {
RetVal = cast<pi_result>(FuncPtr(cast<cl_context>(context), ptr));
Expand Down Expand Up @@ -687,8 +701,9 @@ pi_result OCL(piextKernelSetArgPointer)(pi_kernel kernel, pi_uint32 arg_index,
}

clSetKernelArgMemPointerINTEL_fn FuncPtr = nullptr;
pi_result RetVal = getExtFuncFromContext<clSetKernelArgMemPointerINTEL_fn>(
cast<pi_context>(CLContext), "clSetKernelArgMemPointerINTEL", &FuncPtr);
pi_result RetVal = getExtFuncFromContext<clSetKernelArgMemPointerName,
clSetKernelArgMemPointerINTEL_fn>(
cast<pi_context>(CLContext), &FuncPtr);

if (FuncPtr) {
// OpenCL passes pointers by value not by reference
Expand Down Expand Up @@ -727,8 +742,9 @@ pi_result OCL(piextUSMEnqueueMemset)(pi_queue queue, void *ptr, pi_int32 value,
}

clEnqueueMemsetINTEL_fn FuncPtr = nullptr;
pi_result RetVal = getExtFuncFromContext<clEnqueueMemsetINTEL_fn>(
cast<pi_context>(CLContext), "clEnqueueMemsetINTEL", &FuncPtr);
pi_result RetVal =
getExtFuncFromContext<clEnqueueMemsetName, clEnqueueMemsetINTEL_fn>(
cast<pi_context>(CLContext), &FuncPtr);

if (FuncPtr) {
RetVal = cast<pi_result>(FuncPtr(cast<cl_command_queue>(queue), ptr, value,
Expand Down Expand Up @@ -767,8 +783,9 @@ pi_result OCL(piextUSMEnqueueMemcpy)(pi_queue queue, pi_bool blocking,
}

clEnqueueMemcpyINTEL_fn FuncPtr = nullptr;
pi_result RetVal = getExtFuncFromContext<clEnqueueMemcpyINTEL_fn>(
cast<pi_context>(CLContext), "clEnqueueMemcpyINTEL", &FuncPtr);
pi_result RetVal =
getExtFuncFromContext<clEnqueueMemcpyName, clEnqueueMemcpyINTEL_fn>(
cast<pi_context>(CLContext), &FuncPtr);

if (FuncPtr) {
RetVal = cast<pi_result>(
Expand Down Expand Up @@ -893,8 +910,9 @@ pi_result OCL(piextUSMGetMemAllocInfo)(pi_context context, const void *ptr,
size_t *param_value_size_ret) {

clGetMemAllocInfoINTEL_fn FuncPtr = nullptr;
pi_result RetVal = getExtFuncFromContext<clGetMemAllocInfoINTEL_fn>(
context, "clGetMemAllocInfoINTEL", &FuncPtr);
pi_result RetVal =
getExtFuncFromContext<clGetMemAllocInfoName, clGetMemAllocInfoINTEL_fn>(
context, &FuncPtr);

if (FuncPtr) {
RetVal = cast<pi_result>(FuncPtr(cast<cl_context>(context), ptr, param_name,
Expand Down

0 comments on commit fca1736

Please sign in to comment.