Skip to content

Commit

Permalink
Merge pull request #2454 from Bensuo/l0_cmd-buf_multi-device
Browse files Browse the repository at this point in the history
Fix L0 command-buffer consumption of multi-device kernels
  • Loading branch information
martygrant authored Dec 13, 2024
2 parents 73e5f3c + fcddf07 commit b7047f6
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 25 deletions.
73 changes: 48 additions & 25 deletions source/adapters/level_zero/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -895,28 +895,31 @@ urCommandBufferFinalizeExp(ur_exp_command_buffer_handle_t CommandBuffer) {
/**
* Sets the kernel arguments for a kernel command that will be appended to the
* command buffer.
* @param[in] CommandBuffer The CommandBuffer where the command will be
* @param[in] Device The Device associated with the command-buffer where the
* kernel command will be appended.
* @param[in,out] Arguments stored in the ur_kernel_handle_t object to be set
* on the /p ZeKernel object.
* @param[in] ZeKernel The handle to the Level-Zero kernel that will be
* appended.
* @param[in] Kernel The handle to the kernel that will be appended.
* @return UR_RESULT_SUCCESS or an error code on failure
*/
ur_result_t
setKernelPendingArguments(ur_exp_command_buffer_handle_t CommandBuffer,
ur_kernel_handle_t Kernel) {

ur_result_t setKernelPendingArguments(
ur_device_handle_t Device,
std::vector<ur_kernel_handle_t_::ArgumentInfo> &PendingArguments,
ze_kernel_handle_t ZeKernel) {
// If there are any pending arguments set them now.
for (auto &Arg : Kernel->PendingArguments) {
for (auto &Arg : PendingArguments) {
// The ArgValue may be a NULL pointer in which case a NULL value is used for
// the kernel argument declared as a pointer to global or constant memory.
char **ZeHandlePtr = nullptr;
if (Arg.Value) {
UR_CALL(Arg.Value->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode,
CommandBuffer->Device, nullptr, 0u));
UR_CALL(Arg.Value->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode, Device,
nullptr, 0u));
}
ZE2UR_CALL(zeKernelSetArgumentValue,
(Kernel->ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr));
(ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr));
}
Kernel->PendingArguments.clear();
PendingArguments.clear();

return UR_RESULT_SUCCESS;
}
Expand Down Expand Up @@ -952,21 +955,29 @@ createCommandHandle(ur_exp_command_buffer_handle_t CommandBuffer,
ZE_MUTABLE_COMMAND_EXP_FLAG_GLOBAL_OFFSET;

auto Platform = CommandBuffer->Context->getPlatform();
auto ZeDevice = CommandBuffer->Device->ZeDevice;

if (NumKernelAlternatives > 0) {
ZeMutableCommandDesc.flags |=
ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION;

std::vector<ze_kernel_handle_t> TranslatedKernelHandles(
NumKernelAlternatives + 1, nullptr);

ze_kernel_handle_t ZeMainKernel{};
UR_CALL(getZeKernel(ZeDevice, Kernel, &ZeMainKernel));

// Translate main kernel first
ZE2UR_CALL(zelLoaderTranslateHandle,
(ZEL_HANDLE_KERNEL, Kernel->ZeKernel,
(ZEL_HANDLE_KERNEL, ZeMainKernel,
(void **)&TranslatedKernelHandles[0]));

for (size_t i = 0; i < NumKernelAlternatives; i++) {
ze_kernel_handle_t ZeAltKernel{};
UR_CALL(getZeKernel(ZeDevice, KernelAlternatives[i], &ZeAltKernel));

ZE2UR_CALL(zelLoaderTranslateHandle,
(ZEL_HANDLE_KERNEL, KernelAlternatives[i]->ZeKernel,
(ZEL_HANDLE_KERNEL, ZeAltKernel,
(void **)&TranslatedKernelHandles[i + 1]));
}

Expand Down Expand Up @@ -1023,23 +1034,28 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock(
Kernel->Mutex, Kernel->Program->Mutex, CommandBuffer->Mutex);

auto Device = CommandBuffer->Device;
ze_kernel_handle_t ZeKernel{};
UR_CALL(getZeKernel(Device->ZeDevice, Kernel, &ZeKernel));

if (GlobalWorkOffset != NULL) {
UR_CALL(setKernelGlobalOffset(CommandBuffer->Context, Kernel->ZeKernel,
WorkDim, GlobalWorkOffset));
UR_CALL(setKernelGlobalOffset(CommandBuffer->Context, ZeKernel, WorkDim,
GlobalWorkOffset));
}

// If there are any pending arguments set them now.
if (!Kernel->PendingArguments.empty()) {
UR_CALL(setKernelPendingArguments(CommandBuffer, Kernel));
UR_CALL(
setKernelPendingArguments(Device, Kernel->PendingArguments, ZeKernel));
}

ze_group_count_t ZeThreadGroupDimensions{1, 1, 1};
uint32_t WG[3];
UR_CALL(calculateKernelWorkDimensions(Kernel->ZeKernel, CommandBuffer->Device,
UR_CALL(calculateKernelWorkDimensions(ZeKernel, Device,
ZeThreadGroupDimensions, WG, WorkDim,
GlobalWorkSize, LocalWorkSize));

ZE2UR_CALL(zeKernelSetGroupSize, (Kernel->ZeKernel, WG[0], WG[1], WG[2]));
ZE2UR_CALL(zeKernelSetGroupSize, (ZeKernel, WG[0], WG[1], WG[2]));

CommandBuffer->KernelsList.push_back(Kernel);
for (size_t i = 0; i < NumKernelAlternatives; i++) {
Expand All @@ -1064,7 +1080,7 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
SyncPointWaitList, false, RetSyncPoint, ZeEventList, ZeLaunchEvent));

ZE2UR_CALL(zeCommandListAppendLaunchKernel,
(CommandBuffer->ZeComputeCommandList, Kernel->ZeKernel,
(CommandBuffer->ZeComputeCommandList, ZeKernel,
&ZeThreadGroupDimensions, ZeLaunchEvent, ZeEventList.size(),
getPointerFromVector(ZeEventList)));

Expand Down Expand Up @@ -1837,6 +1853,7 @@ ur_result_t updateKernelCommand(
const auto CommandBuffer = Command->CommandBuffer;
const void *NextDesc = nullptr;
auto Platform = CommandBuffer->Context->getPlatform();
auto ZeDevice = CommandBuffer->Device->ZeDevice;

uint32_t Dim = CommandDesc->newWorkDim;
size_t *NewGlobalWorkOffset = CommandDesc->pNewGlobalWorkOffset;
Expand All @@ -1845,11 +1862,14 @@ ur_result_t updateKernelCommand(

// Kernel handle must be updated first for a given CommandId if required
ur_kernel_handle_t NewKernel = CommandDesc->hNewKernel;

if (NewKernel && Command->Kernel != NewKernel) {
ze_kernel_handle_t ZeNewKernel{};
UR_CALL(getZeKernel(ZeDevice, NewKernel, &ZeNewKernel));

ze_kernel_handle_t ZeKernelTranslated = nullptr;
ZE2UR_CALL(
zelLoaderTranslateHandle,
(ZEL_HANDLE_KERNEL, NewKernel->ZeKernel, (void **)&ZeKernelTranslated));
ZE2UR_CALL(zelLoaderTranslateHandle,
(ZEL_HANDLE_KERNEL, ZeNewKernel, (void **)&ZeKernelTranslated));

ZE2UR_CALL(Platform->ZeMutableCmdListExt
.zexCommandListUpdateMutableCommandKernelsExp,
Expand Down Expand Up @@ -1906,10 +1926,13 @@ ur_result_t updateKernelCommand(
// by the driver for the kernel.
bool UpdateWGSize = NewLocalWorkSize == nullptr;

ze_kernel_handle_t ZeKernel{};
UR_CALL(getZeKernel(ZeDevice, Command->Kernel, &ZeKernel));

uint32_t WG[3];
UR_CALL(calculateKernelWorkDimensions(
Command->Kernel->ZeKernel, CommandBuffer->Device,
ZeThreadGroupDimensions, WG, Dim, NewGlobalWorkSize, NewLocalWorkSize));
UR_CALL(calculateKernelWorkDimensions(ZeKernel, CommandBuffer->Device,
ZeThreadGroupDimensions, WG, Dim,
NewGlobalWorkSize, NewLocalWorkSize));

auto MutableGroupCountDesc =
std::make_unique<ZeStruct<ze_mutable_group_count_exp_desc_t>>();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
urProgramSetSpecializationConstantsTest.InvalidValueSize/*
urProgramSetSpecializationConstantsTest.InvalidValueId/*
urProgramSetSpecializationConstantsTest.InvalidValuePtr/*
{{OPT}}urMultiDeviceCommandBufferExpTest.*
138 changes: 138 additions & 0 deletions test/conformance/program/urMultiDeviceProgramCreateWithBinary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,141 @@ TEST_F(urMultiDeviceProgramCreateWithBinaryTest, CheckProgramGetInfo) {
reinterpret_cast<char *>(property_value.data());
ASSERT_STRNE(returned_kernel_names, "");
}

struct urMultiDeviceCommandBufferExpTest
: urMultiDeviceProgramCreateWithBinaryTest {
void SetUp() override {
UUR_RETURN_ON_FATAL_FAILURE(
urMultiDeviceProgramCreateWithBinaryTest::SetUp());

auto kernelName =
uur::KernelsEnvironment::instance->GetEntryPointNames("foo")[0];

ASSERT_SUCCESS(urProgramBuild(context, binary_program, nullptr));
ASSERT_SUCCESS(
urKernelCreate(binary_program, kernelName.data(), &kernel));
}

void TearDown() override {
if (kernel) {
EXPECT_SUCCESS(urKernelRelease(kernel));
}
UUR_RETURN_ON_FATAL_FAILURE(
urMultiDeviceProgramCreateWithBinaryTest::TearDown());
}

static bool hasCommandBufferSupport(ur_device_handle_t device) {
ur_bool_t cmd_buffer_support = false;
auto res = urDeviceGetInfo(
device, UR_DEVICE_INFO_COMMAND_BUFFER_SUPPORT_EXP,
sizeof(cmd_buffer_support), &cmd_buffer_support, nullptr);

if (res) {
return false;
}

return cmd_buffer_support;
}

static bool hasCommandBufferUpdateSupport(ur_device_handle_t device) {
ur_device_command_buffer_update_capability_flags_t
update_capability_flags;
auto res = urDeviceGetInfo(
device, UR_DEVICE_INFO_COMMAND_BUFFER_UPDATE_CAPABILITIES_EXP,
sizeof(update_capability_flags), &update_capability_flags, nullptr);

if (res) {
return false;
}

return (0 != update_capability_flags);
}

ur_kernel_handle_t kernel = nullptr;

static constexpr size_t global_offset = 0;
static constexpr size_t n_dimensions = 1;
static constexpr size_t global_size = 64;
static constexpr size_t local_size = 4;
};

TEST_F(urMultiDeviceCommandBufferExpTest, Enqueue) {
for (size_t i = 0; i < devices.size(); i++) {
auto device = devices[i];
if (!hasCommandBufferSupport(device)) {
continue;
}

// Create command-buffer
uur::raii::CommandBuffer cmd_buf_handle;
ASSERT_SUCCESS(urCommandBufferCreateExp(context, device, nullptr,
cmd_buf_handle.ptr()));

// Append kernel command to command-buffer and close command-buffer
ASSERT_SUCCESS(urCommandBufferAppendKernelLaunchExp(
cmd_buf_handle, kernel, n_dimensions, &global_offset, &global_size,
&local_size, 0, nullptr, 0, nullptr, 0, nullptr, nullptr, nullptr,
nullptr));
ASSERT_SUCCESS(urCommandBufferFinalizeExp(cmd_buf_handle));

// Verify execution succeeds
ASSERT_SUCCESS(urCommandBufferEnqueueExp(cmd_buf_handle, queues[i], 0,
nullptr, nullptr));
ASSERT_SUCCESS(urQueueFinish(queues[i]));
}
}

TEST_F(urMultiDeviceCommandBufferExpTest, Update) {
for (size_t i = 0; i < devices.size(); i++) {
auto device = devices[i];
if (!(hasCommandBufferSupport(device) &&
hasCommandBufferUpdateSupport(device))) {
continue;
}

// Create a command-buffer with update enabled.
ur_exp_command_buffer_desc_t desc{
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC, nullptr, true, false,
false};

// Create command-buffer
uur::raii::CommandBuffer cmd_buf_handle;
ASSERT_SUCCESS(urCommandBufferCreateExp(context, device, &desc,
cmd_buf_handle.ptr()));

// Append kernel command to command-buffer and close command-buffer
uur::raii::CommandBufferCommand command;
ASSERT_SUCCESS(urCommandBufferAppendKernelLaunchExp(
cmd_buf_handle, kernel, n_dimensions, &global_offset, &global_size,
&local_size, 0, nullptr, 0, nullptr, 0, nullptr, nullptr, nullptr,
command.ptr()));
ASSERT_SUCCESS(urCommandBufferFinalizeExp(cmd_buf_handle));

// Verify execution succeeds
ASSERT_SUCCESS(urCommandBufferEnqueueExp(cmd_buf_handle, queues[i], 0,
nullptr, nullptr));
ASSERT_SUCCESS(urQueueFinish(queues[i]));

// Update kernel and enqueue command-buffer again
ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = {
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype
nullptr, // pNext
kernel, // hNewKernel
0, // numNewMemObjArgs
0, // numNewPointerArgs
0, // numNewValueArgs
n_dimensions, // newWorkDim
nullptr, // pNewMemObjArgList
nullptr, // pNewPointerArgList
nullptr, // pNewValueArgList
nullptr, // pNewGlobalWorkOffset
nullptr, // pNewGlobalWorkSize
nullptr, // pNewLocalWorkSize
};
ASSERT_SUCCESS(
urCommandBufferUpdateKernelLaunchExp(command, &update_desc));
ASSERT_SUCCESS(urCommandBufferEnqueueExp(cmd_buf_handle, queues[i], 0,
nullptr, nullptr));
ASSERT_SUCCESS(urQueueFinish(queues[i]));
}
}

0 comments on commit b7047f6

Please sign in to comment.