From 5a1702f654777dac66c5a7ca0d71d2b9904c343d Mon Sep 17 00:00:00 2001 From: Shoaib Kamil Date: Thu, 23 May 2024 12:31:22 -0400 Subject: [PATCH 01/12] No longer silently hide errors in Metal completion handlers --- src/runtime/metal.cpp | 17 +++++++ test/correctness/CMakeLists.txt | 1 + ...u_metal_completion_handler_error_check.cpp | 51 +++++++++++++++++++ 3 files changed, 69 insertions(+) create mode 100644 test/correctness/gpu_metal_completion_handler_error_check.cpp diff --git a/src/runtime/metal.cpp b/src/runtime/metal.cpp index abc935b0743e..83cd19ab26fc 100644 --- a/src/runtime/metal.cpp +++ b/src/runtime/metal.cpp @@ -426,9 +426,26 @@ WEAK command_buffer_completed_handler_block_descriptor_1 command_buffer_complete 0, sizeof(command_buffer_completed_handler_block_literal)}; WEAK void command_buffer_completed_handler_invoke(command_buffer_completed_handler_block_literal *block, mtl_command_buffer *buffer) { + halide_print(nullptr, "Completion handler invoked\n"); objc_id buffer_error = command_buffer_error(buffer); if (buffer_error != nullptr) { + retain_ns_object(buffer_error); + + // Obtain the localized NSString for the error + typedef objc_id (*localized_description_method_t)(objc_id objc, objc_sel sel); + localized_description_method_t localized_description_method = (localized_description_method_t)&objc_msgSend; + objc_id error_ns_string = (*localized_description_method)(buffer_error, sel_getUid("localizedDescription")); + + // Obtain a C-style string, but do not release the NSString until reporting/printing the error + typedef char* (*utf8_string_method_t)(objc_id objc, objc_sel sel); + utf8_string_method_t utf8_string_method = (utf8_string_method_t)&objc_msgSend; + char* error_string = (*utf8_string_method)(error_ns_string, sel_getUid("UTF8String")); + ns_log_object(buffer_error); + + // This is an error indicating the command buffer wasn't executed, and because it is asynchronous + // with respect to the pipeline that caused it, it is not recoverable + halide_error(nullptr, error_string); release_ns_object(buffer_error); } } diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index ae4a6776ac72..e910eeee6d2d 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -138,6 +138,7 @@ tests(GROUPS correctness gpu_jit_explicit_copy_to_device.cpp gpu_large_alloc.cpp gpu_many_kernels.cpp + gpu_metal_completion_handler_error_check.cpp gpu_mixed_dimensionality.cpp gpu_mixed_shared_mem_types.cpp gpu_multi_kernel.cpp diff --git a/test/correctness/gpu_metal_completion_handler_error_check.cpp b/test/correctness/gpu_metal_completion_handler_error_check.cpp new file mode 100644 index 000000000000..291102a6dbb8 --- /dev/null +++ b/test/correctness/gpu_metal_completion_handler_error_check.cpp @@ -0,0 +1,51 @@ +#include "Halide.h" +#include + +using namespace Halide; + +bool errored = false; +void my_error(JITUserContext *, const char *msg) { + // Emitting "error.*:" to stdout or stderr will cause CMake to report the + // test as a failure on Windows, regardless of error code returned, + // hence the abbreviation to "err". + printf("Expected err: %s\n", msg); + errored = true; +} + +int main(int argc, char **argv) { + Target t = get_jit_target_from_environment(); + if (!t.has_feature(Target::Metal)) { + printf("[SKIP] Metal not enabled\n"); + return 0; + } + + Func f; + Var c, x, ci, xi; + RVar rxi; + RDom r(0, 1000, -327600, 327600); + + // Create a function that is very costly to execute, resulting in a timeout + // on the GPU + f(x, c) = x + 0.1f * c; + f(r.x, c) += cos(sin(tan(cosh(tanh(sinh(exp(tanh(exp(log(tan(cos(exp(f(r.x, c) / cos(cosh(sinh(sin(f(r.x, c))))) + / tanh(tan(tan(f(r.x, c)))))))))) + cast(cast(f(r.x, c) / cast(log(f(r.x, c)))))))))))); + + f.gpu_tile(x, c, xi, ci, 4, 4); + f.update(0).gpu_tile(r.x, c, rxi, ci, 4, 4); + + // Because the error handler is invoked from a Metal runtime thread, setting a custom handler for just + // this pipeline is insufficient. Instead, we set a custom handler for the JIT runtime + JITHandlers handlers; + handlers.custom_error = my_error; + Internal::JITSharedRuntime::set_default_handlers(handlers); + + f.realize({1000, 100}, t); + + if (!errored) { + printf("There was supposed to be an error\n"); + return 1; + } + + printf("Success!\n"); + return 0; +} From 4386a2ae819da94300c9f30d28bf4264e73b1cb8 Mon Sep 17 00:00:00 2001 From: Shoaib Kamil Date: Tue, 28 May 2024 13:22:29 -0400 Subject: [PATCH 02/12] Actually implement alternative --- src/runtime/metal.cpp | 77 +++++++++++++++---- ...u_metal_completion_handler_error_check.cpp | 15 +++- 2 files changed, 76 insertions(+), 16 deletions(-) diff --git a/src/runtime/metal.cpp b/src/runtime/metal.cpp index 83cd19ab26fc..b7fc09997bf9 100644 --- a/src/runtime/metal.cpp +++ b/src/runtime/metal.cpp @@ -381,6 +381,8 @@ WEAK int halide_metal_release_context(void *user_context) { } // extern "C" +extern "C" size_t strnlen(const char *s, size_t maxlen); + namespace Halide { namespace Runtime { namespace Internal { @@ -390,11 +392,15 @@ class MetalContextHolder { objc_id pool; void *const user_context; int status; // must always be a valid halide_error_code_t value + static int saved_status; // must always be a valid halide_error_code_t value + static halide_mutex saved_status_mutex; // mutex for accessing saved status + static char error_string[1024]; public: mtl_device *device; mtl_command_queue *queue; + ALWAYS_INLINE MetalContextHolder(void *user_context, bool create) : pool(create_autorelease_pool()), user_context(user_context) { status = halide_metal_acquire_context(user_context, &device, &queue, create); @@ -404,11 +410,57 @@ class MetalContextHolder { drain_autorelease_pool(pool); } - ALWAYS_INLINE int error() const { - return status; + // We use two variants of this function: one for just checking status, and one + // that returns and clears the previous status. + ALWAYS_INLINE static int get_and_clear_saved_status(char* error_string = nullptr) { + halide_mutex_lock(&saved_status_mutex); + int result = saved_status; + saved_status = halide_error_code_success; + if (error_string != nullptr && result != halide_error_code_success + && strnlen(MetalContextHolder::error_string, 1024) > 0 ) { + strncpy(error_string, MetalContextHolder::error_string, 1024); + MetalContextHolder::error_string[0] = '\0'; + debug(nullptr) << "MetalContextHolder::get_and_clear_saved_status: " << error_string << "\n"; + } + halide_mutex_unlock(&saved_status_mutex); + return result; + } + + // Returns the previous status without clearing, and optionally copies the error string + ALWAYS_INLINE static int get_saved_status(char* error_string = nullptr) { + halide_mutex_lock(&saved_status_mutex); + int result = saved_status; + if (error_string != nullptr && result != halide_error_code_success + && strnlen(MetalContextHolder::error_string, 1024) > 0 ) { + strncpy(error_string, MetalContextHolder::error_string, 1024); + } + halide_mutex_unlock(&saved_status_mutex); + return result; + } + + ALWAYS_INLINE static void set_saved_status(int new_status, char* error_string = nullptr) { + halide_mutex_lock(&saved_status_mutex); + saved_status = new_status; + if (error_string != nullptr) { + strncpy(MetalContextHolder::error_string, error_string, 1024); + debug(nullptr) << "MetalContextHolder::set_saved_status: " << error_string << "\n"; + } + halide_mutex_unlock(&saved_status_mutex); + } + + ALWAYS_INLINE int error(char* error_string = nullptr) const { + return status || get_saved_status(error_string); + } + + ALWAYS_INLINE int get_and_clear_error(char* error_string = nullptr) const { + return status || get_and_clear_saved_status(error_string); } }; +int MetalContextHolder::saved_status = halide_error_code_success; +halide_mutex MetalContextHolder::saved_status_mutex = {0}; +char MetalContextHolder::error_string[1024] = {0}; + struct command_buffer_completed_handler_block_descriptor_1 { unsigned long reserved; unsigned long block_size; @@ -426,7 +478,6 @@ WEAK command_buffer_completed_handler_block_descriptor_1 command_buffer_complete 0, sizeof(command_buffer_completed_handler_block_literal)}; WEAK void command_buffer_completed_handler_invoke(command_buffer_completed_handler_block_literal *block, mtl_command_buffer *buffer) { - halide_print(nullptr, "Completion handler invoked\n"); objc_id buffer_error = command_buffer_error(buffer); if (buffer_error != nullptr) { retain_ns_object(buffer_error); @@ -444,8 +495,8 @@ WEAK void command_buffer_completed_handler_invoke(command_buffer_completed_handl ns_log_object(buffer_error); // This is an error indicating the command buffer wasn't executed, and because it is asynchronous - // with respect to the pipeline that caused it, it is not recoverable - halide_error(nullptr, error_string); + // we store it in a static variable to report on the next check for an error + MetalContextHolder::set_saved_status(halide_error_code_device_run_failed, error_string); release_ns_object(buffer_error); } } @@ -493,7 +544,7 @@ WEAK int halide_metal_device_malloc(void *user_context, halide_buffer_t *buf) { MetalContextHolder metal_context(user_context, true); if (metal_context.error()) { - return metal_context.error(); + return metal_context.get_and_clear_error(); } #ifdef DEBUG_RUNTIME @@ -561,7 +612,7 @@ WEAK int halide_metal_device_free(void *user_context, halide_buffer_t *buf) { WEAK int halide_metal_initialize_kernels(void *user_context, void **state_ptr, const char *source, int source_size) { MetalContextHolder metal_context(user_context, true); if (metal_context.error()) { - return metal_context.error(); + return metal_context.get_and_clear_error(); } #ifdef DEBUG_RUNTIME uint64_t t_before = halide_current_time_ns(user_context); @@ -617,7 +668,7 @@ WEAK int halide_metal_device_sync(void *user_context, struct halide_buffer_t *bu MetalContextHolder metal_context(user_context, true); if (metal_context.error()) { - return metal_context.error(); + return metal_context.get_and_clear_error(); } halide_metal_device_sync_internal(metal_context.queue, buffer); @@ -668,7 +719,7 @@ WEAK int halide_metal_copy_to_device(void *user_context, halide_buffer_t *buffer MetalContextHolder metal_context(user_context, true); if (metal_context.error()) { - return metal_context.error(); + return metal_context.get_and_clear_error(); } if (!(buffer->host && buffer->device)) { @@ -712,7 +763,7 @@ WEAK int halide_metal_copy_to_host(void *user_context, halide_buffer_t *buffer) MetalContextHolder metal_context(user_context, true); if (metal_context.error()) { - return metal_context.error(); + return metal_context.get_and_clear_error(); } halide_metal_device_sync_internal(metal_context.queue, buffer); @@ -755,7 +806,7 @@ WEAK int halide_metal_run(void *user_context, MetalContextHolder metal_context(user_context, true); if (metal_context.error()) { - return metal_context.error(); + return metal_context.get_and_clear_error(); } mtl_command_buffer *command_buffer = new_command_buffer(metal_context.queue, entry_name, strlen(entry_name)); @@ -979,7 +1030,7 @@ WEAK int halide_metal_buffer_copy(void *user_context, struct halide_buffer_t *sr { MetalContextHolder metal_context(user_context, true); if (metal_context.error()) { - return metal_context.error(); + return metal_context.get_and_clear_error(); } debug(user_context) @@ -1053,7 +1104,7 @@ WEAK int metal_device_crop_from_offset(void *user_context, struct halide_buffer_t *dst) { MetalContextHolder metal_context(user_context, true); if (metal_context.error()) { - return metal_context.error(); + return metal_context.get_and_clear_error(); } dst->device_interface = src->device_interface; diff --git a/test/correctness/gpu_metal_completion_handler_error_check.cpp b/test/correctness/gpu_metal_completion_handler_error_check.cpp index 291102a6dbb8..38802d566f41 100644 --- a/test/correctness/gpu_metal_completion_handler_error_check.cpp +++ b/test/correctness/gpu_metal_completion_handler_error_check.cpp @@ -19,7 +19,7 @@ int main(int argc, char **argv) { return 0; } - Func f; + Func f, g; Var c, x, ci, xi; RVar rxi; RDom r(0, 1000, -327600, 327600); @@ -38,8 +38,17 @@ int main(int argc, char **argv) { JITHandlers handlers; handlers.custom_error = my_error; Internal::JITSharedRuntime::set_default_handlers(handlers); - - f.realize({1000, 100}, t); + f.jit_handlers().custom_error = my_error; + + // Metal is surprisingly resilient. Run this in a loop just to make sure we trigger the error. + for (int i = 0; (i < 10) && !errored ; i++) { + auto out = f.realize({1000, 100}, t); + int result = out.device_sync(); + if (result != halide_error_code_success) { + printf("Device sync failed as expected: %d\n", result); + errored = true; + } + } if (!errored) { printf("There was supposed to be an error\n"); From 4b8a7a5f5236f6d70bbd97585f9c0401425d3ca8 Mon Sep 17 00:00:00 2001 From: Steven Johnson Date: Mon, 3 Jun 2024 11:35:13 -0700 Subject: [PATCH 03/12] clang-format --- src/runtime/metal.cpp | 33 +++++++++---------- ...u_metal_completion_handler_error_check.cpp | 9 +++-- 2 files changed, 19 insertions(+), 23 deletions(-) diff --git a/src/runtime/metal.cpp b/src/runtime/metal.cpp index b7fc09997bf9..2d01f6cc34b1 100644 --- a/src/runtime/metal.cpp +++ b/src/runtime/metal.cpp @@ -391,8 +391,8 @@ namespace Metal { class MetalContextHolder { objc_id pool; void *const user_context; - int status; // must always be a valid halide_error_code_t value - static int saved_status; // must always be a valid halide_error_code_t value + int status; // must always be a valid halide_error_code_t value + static int saved_status; // must always be a valid halide_error_code_t value static halide_mutex saved_status_mutex; // mutex for accessing saved status static char error_string[1024]; @@ -400,7 +400,6 @@ class MetalContextHolder { mtl_device *device; mtl_command_queue *queue; - ALWAYS_INLINE MetalContextHolder(void *user_context, bool create) : pool(create_autorelease_pool()), user_context(user_context) { status = halide_metal_acquire_context(user_context, &device, &queue, create); @@ -412,12 +411,11 @@ class MetalContextHolder { // We use two variants of this function: one for just checking status, and one // that returns and clears the previous status. - ALWAYS_INLINE static int get_and_clear_saved_status(char* error_string = nullptr) { + ALWAYS_INLINE static int get_and_clear_saved_status(char *error_string = nullptr) { halide_mutex_lock(&saved_status_mutex); int result = saved_status; saved_status = halide_error_code_success; - if (error_string != nullptr && result != halide_error_code_success - && strnlen(MetalContextHolder::error_string, 1024) > 0 ) { + if (error_string != nullptr && result != halide_error_code_success && strnlen(MetalContextHolder::error_string, 1024) > 0) { strncpy(error_string, MetalContextHolder::error_string, 1024); MetalContextHolder::error_string[0] = '\0'; debug(nullptr) << "MetalContextHolder::get_and_clear_saved_status: " << error_string << "\n"; @@ -427,32 +425,31 @@ class MetalContextHolder { } // Returns the previous status without clearing, and optionally copies the error string - ALWAYS_INLINE static int get_saved_status(char* error_string = nullptr) { + ALWAYS_INLINE static int get_saved_status(char *error_string = nullptr) { halide_mutex_lock(&saved_status_mutex); int result = saved_status; - if (error_string != nullptr && result != halide_error_code_success - && strnlen(MetalContextHolder::error_string, 1024) > 0 ) { + if (error_string != nullptr && result != halide_error_code_success && strnlen(MetalContextHolder::error_string, 1024) > 0) { strncpy(error_string, MetalContextHolder::error_string, 1024); } halide_mutex_unlock(&saved_status_mutex); return result; } - ALWAYS_INLINE static void set_saved_status(int new_status, char* error_string = nullptr) { + ALWAYS_INLINE static void set_saved_status(int new_status, char *error_string = nullptr) { halide_mutex_lock(&saved_status_mutex); saved_status = new_status; - if (error_string != nullptr) { - strncpy(MetalContextHolder::error_string, error_string, 1024); - debug(nullptr) << "MetalContextHolder::set_saved_status: " << error_string << "\n"; + if (error_string != nullptr) { + strncpy(MetalContextHolder::error_string, error_string, 1024); + debug(nullptr) << "MetalContextHolder::set_saved_status: " << error_string << "\n"; } halide_mutex_unlock(&saved_status_mutex); } - ALWAYS_INLINE int error(char* error_string = nullptr) const { + ALWAYS_INLINE int error(char *error_string = nullptr) const { return status || get_saved_status(error_string); } - ALWAYS_INLINE int get_and_clear_error(char* error_string = nullptr) const { + ALWAYS_INLINE int get_and_clear_error(char *error_string = nullptr) const { return status || get_and_clear_saved_status(error_string); } }; @@ -488,13 +485,13 @@ WEAK void command_buffer_completed_handler_invoke(command_buffer_completed_handl objc_id error_ns_string = (*localized_description_method)(buffer_error, sel_getUid("localizedDescription")); // Obtain a C-style string, but do not release the NSString until reporting/printing the error - typedef char* (*utf8_string_method_t)(objc_id objc, objc_sel sel); + typedef char *(*utf8_string_method_t)(objc_id objc, objc_sel sel); utf8_string_method_t utf8_string_method = (utf8_string_method_t)&objc_msgSend; - char* error_string = (*utf8_string_method)(error_ns_string, sel_getUid("UTF8String")); + char *error_string = (*utf8_string_method)(error_ns_string, sel_getUid("UTF8String")); ns_log_object(buffer_error); - // This is an error indicating the command buffer wasn't executed, and because it is asynchronous + // This is an error indicating the command buffer wasn't executed, and because it is asynchronous // we store it in a static variable to report on the next check for an error MetalContextHolder::set_saved_status(halide_error_code_device_run_failed, error_string); release_ns_object(buffer_error); diff --git a/test/correctness/gpu_metal_completion_handler_error_check.cpp b/test/correctness/gpu_metal_completion_handler_error_check.cpp index 38802d566f41..0da8e6430316 100644 --- a/test/correctness/gpu_metal_completion_handler_error_check.cpp +++ b/test/correctness/gpu_metal_completion_handler_error_check.cpp @@ -26,10 +26,9 @@ int main(int argc, char **argv) { // Create a function that is very costly to execute, resulting in a timeout // on the GPU - f(x, c) = x + 0.1f * c; - f(r.x, c) += cos(sin(tan(cosh(tanh(sinh(exp(tanh(exp(log(tan(cos(exp(f(r.x, c) / cos(cosh(sinh(sin(f(r.x, c))))) - / tanh(tan(tan(f(r.x, c)))))))))) + cast(cast(f(r.x, c) / cast(log(f(r.x, c)))))))))))); - + f(x, c) = x + 0.1f * c; + f(r.x, c) += cos(sin(tan(cosh(tanh(sinh(exp(tanh(exp(log(tan(cos(exp(f(r.x, c) / cos(cosh(sinh(sin(f(r.x, c))))) / tanh(tan(tan(f(r.x, c)))))))))) + cast(cast(f(r.x, c) / cast(log(f(r.x, c)))))))))))); + f.gpu_tile(x, c, xi, ci, 4, 4); f.update(0).gpu_tile(r.x, c, rxi, ci, 4, 4); @@ -41,7 +40,7 @@ int main(int argc, char **argv) { f.jit_handlers().custom_error = my_error; // Metal is surprisingly resilient. Run this in a loop just to make sure we trigger the error. - for (int i = 0; (i < 10) && !errored ; i++) { + for (int i = 0; (i < 10) && !errored; i++) { auto out = f.realize({1000, 100}, t); int result = out.device_sync(); if (result != halide_error_code_success) { From a443963058420399a8ce5a4e8d6f17491517cac2 Mon Sep 17 00:00:00 2001 From: Shoaib Kamil Date: Fri, 7 Jun 2024 13:42:44 -0400 Subject: [PATCH 04/12] Implement new API --- src/runtime/metal.cpp | 90 +++++++++++++++++++++++++++++++++---------- 1 file changed, 69 insertions(+), 21 deletions(-) diff --git a/src/runtime/metal.cpp b/src/runtime/metal.cpp index 2d01f6cc34b1..db86b39edd62 100644 --- a/src/runtime/metal.cpp +++ b/src/runtime/metal.cpp @@ -417,6 +417,7 @@ class MetalContextHolder { saved_status = halide_error_code_success; if (error_string != nullptr && result != halide_error_code_success && strnlen(MetalContextHolder::error_string, 1024) > 0) { strncpy(error_string, MetalContextHolder::error_string, 1024); + error_string[1023] = '\0'; MetalContextHolder::error_string[0] = '\0'; debug(nullptr) << "MetalContextHolder::get_and_clear_saved_status: " << error_string << "\n"; } @@ -430,6 +431,7 @@ class MetalContextHolder { int result = saved_status; if (error_string != nullptr && result != halide_error_code_success && strnlen(MetalContextHolder::error_string, 1024) > 0) { strncpy(error_string, MetalContextHolder::error_string, 1024); + error_string[1023] = '\0'; } halide_mutex_unlock(&saved_status_mutex); return result; @@ -440,17 +442,27 @@ class MetalContextHolder { saved_status = new_status; if (error_string != nullptr) { strncpy(MetalContextHolder::error_string, error_string, 1024); + error_string[1023] = '\0'; debug(nullptr) << "MetalContextHolder::set_saved_status: " << error_string << "\n"; } halide_mutex_unlock(&saved_status_mutex); } ALWAYS_INLINE int error(char *error_string = nullptr) const { - return status || get_saved_status(error_string); + if (status != halide_error_code_success) { + return status; + } else { + return get_saved_status(error_string); + } } ALWAYS_INLINE int get_and_clear_error(char *error_string = nullptr) const { - return status || get_and_clear_saved_status(error_string); + auto cleared_status = get_and_clear_saved_status(error_string); + if (status != halide_error_code_success) { + return status; + } else { + return cleared_status; + } } }; @@ -458,6 +470,55 @@ int MetalContextHolder::saved_status = halide_error_code_success; halide_mutex MetalContextHolder::saved_status_mutex = {0}; char MetalContextHolder::error_string[1024] = {0}; +} // namespace Metal +} // namespace Internal +} // namespace Runtime +} // namespace Halide + +extern "C" { +WEAK int halide_metal_command_buffer_completion_handler(void* user_context, mtl_command_buffer *buffer, char **returned_error_string) { + objc_id buffer_error = command_buffer_error(buffer); + if (buffer_error != nullptr) { + retain_ns_object(buffer_error); + + ns_log_object(buffer_error); + + // Obtain the localized NSString for the error + typedef objc_id (*localized_description_method_t)(objc_id objc, objc_sel sel); + localized_description_method_t localized_description_method = (localized_description_method_t)&objc_msgSend; + objc_id error_ns_string = (*localized_description_method)(buffer_error, sel_getUid("localizedDescription")); + + retain_ns_object(error_ns_string); + + // Obtain a C-style string + typedef char *(*utf8_string_method_t)(objc_id objc, objc_sel sel); + utf8_string_method_t utf8_string_method = (utf8_string_method_t)&objc_msgSend; + char *error_string = (*utf8_string_method)(error_ns_string, sel_getUid("UTF8String")); + + // Copy C-style string into a fresh buffer + if (returned_error_string != nullptr) { + *returned_error_string = (char*)malloc(sizeof(char) * 1024); + if (*returned_error_string != nullptr) { + strncpy(*returned_error_string, error_string, 1024); + (*returned_error_string)[1023] = '\0'; + } else { + debug(user_context) << "halide_metal_command_buffer_completion_handler: Failed to allocate memory for error string.\n"; + } + } + + release_ns_object(error_ns_string); + release_ns_object(buffer_error); + return halide_error_code_device_run_failed; + } + return halide_error_code_success; +} +} // extern "C" + +namespace Halide { +namespace Runtime { +namespace Internal { +namespace Metal { + struct command_buffer_completed_handler_block_descriptor_1 { unsigned long reserved; unsigned long block_size; @@ -475,27 +536,14 @@ WEAK command_buffer_completed_handler_block_descriptor_1 command_buffer_complete 0, sizeof(command_buffer_completed_handler_block_literal)}; WEAK void command_buffer_completed_handler_invoke(command_buffer_completed_handler_block_literal *block, mtl_command_buffer *buffer) { - objc_id buffer_error = command_buffer_error(buffer); - if (buffer_error != nullptr) { - retain_ns_object(buffer_error); + retain_ns_object(buffer); + char* error_string = nullptr; + auto status = halide_metal_command_buffer_completion_handler(nullptr, buffer, &error_string); + release_ns_object(buffer); - // Obtain the localized NSString for the error - typedef objc_id (*localized_description_method_t)(objc_id objc, objc_sel sel); - localized_description_method_t localized_description_method = (localized_description_method_t)&objc_msgSend; - objc_id error_ns_string = (*localized_description_method)(buffer_error, sel_getUid("localizedDescription")); + MetalContextHolder::set_saved_status(status, error_string); + free(error_string); - // Obtain a C-style string, but do not release the NSString until reporting/printing the error - typedef char *(*utf8_string_method_t)(objc_id objc, objc_sel sel); - utf8_string_method_t utf8_string_method = (utf8_string_method_t)&objc_msgSend; - char *error_string = (*utf8_string_method)(error_ns_string, sel_getUid("UTF8String")); - - ns_log_object(buffer_error); - - // This is an error indicating the command buffer wasn't executed, and because it is asynchronous - // we store it in a static variable to report on the next check for an error - MetalContextHolder::set_saved_status(halide_error_code_device_run_failed, error_string); - release_ns_object(buffer_error); - } } WEAK command_buffer_completed_handler_block_literal command_buffer_completed_handler_block = { From 79f29d58f807daff98deb897a317e740e9420152 Mon Sep 17 00:00:00 2001 From: Shoaib Kamil Date: Mon, 10 Jun 2024 14:13:59 -0400 Subject: [PATCH 05/12] Implement test and refine the API --- src/runtime/HalideRuntimeMetal.h | 12 +++++ src/runtime/metal.cpp | 43 ++++++++++++++--- ...u_metal_completion_handler_error_check.cpp | 13 ----- test/generator/CMakeLists.txt | 5 ++ ...al_completion_handler_override_aottest.cpp | 48 +++++++++++++++++++ ..._completion_handler_override_generator.cpp | 25 ++++++++++ 6 files changed, 126 insertions(+), 20 deletions(-) create mode 100644 test/generator/metal_completion_handler_override_aottest.cpp create mode 100644 test/generator/metal_completion_handler_override_generator.cpp diff --git a/src/runtime/HalideRuntimeMetal.h b/src/runtime/HalideRuntimeMetal.h index 8fd0f364cebb..255d24455938 100644 --- a/src/runtime/HalideRuntimeMetal.h +++ b/src/runtime/HalideRuntimeMetal.h @@ -68,6 +68,7 @@ extern uint64_t halide_metal_get_crop_offset(void *user_context, struct halide_b struct halide_metal_device; struct halide_metal_command_queue; +struct halide_metal_command_buffer; /** This prototype is exported as applications will typically need to * replace it to get Halide filters to execute on the same device and @@ -93,6 +94,17 @@ extern int halide_metal_acquire_context(void *user_context, struct halide_metal_ */ extern int halide_metal_release_context(void *user_context); +/** This function is called as part of the callback when a Metal command buffer completes. + * The return value, if not halide_error_code_success, will be stashed in Metal runtime and returned + * to the next call into the runtime, and the error string will be saved as well. + * The error string will be freed by the caller. The return value must be a valid Halide error code. + * This is called from the Metal driver, and thus: + * - Any user_context must be preserved between the call to halide_metal_run and the corresponding callback + * - The function must be thread-safe +*/ +extern int halide_metal_command_buffer_completion_handler(void* user_context, struct halide_metal_command_buffer *buffer, + char **returned_error_string); + #ifdef __cplusplus } // End extern "C" #endif diff --git a/src/runtime/metal.cpp b/src/runtime/metal.cpp index db86b39edd62..f01a31a38532 100644 --- a/src/runtime/metal.cpp +++ b/src/runtime/metal.cpp @@ -12,6 +12,7 @@ extern "C" { extern objc_id MTLCreateSystemDefaultDevice(); extern struct ObjectiveCClass _NSConcreteGlobalBlock; +extern struct ObjectiveCClass _NSConcreteStackBlock; void *dlsym(void *, const char *); #define RTLD_DEFAULT ((void *)-2) } @@ -23,8 +24,8 @@ namespace Metal { typedef halide_metal_device mtl_device; typedef halide_metal_command_queue mtl_command_queue; +typedef halide_metal_command_buffer mtl_command_buffer; struct mtl_buffer; -struct mtl_command_buffer; struct mtl_compute_command_encoder; struct mtl_blit_command_encoder; struct mtl_compute_pipeline_state; @@ -476,6 +477,14 @@ char MetalContextHolder::error_string[1024] = {0}; } // namespace Halide extern "C" { +/** This function is called as part of the callback when a Metal command buffer completes. + * The return value, if not halide_error_code_success, will be stashed in Metal runtime and returned + * to the next call into the runtime, and the error string will be saved as well. + * The error string will be freed by the caller. The return value must be a valid Halide error code. + * This is called from the Metal driver, and thus: + * - Any user_context must be preserved between the call to halide_metal_run and the corresponding callback + * - The function must be thread-safe +*/ WEAK int halide_metal_command_buffer_completion_handler(void* user_context, mtl_command_buffer *buffer, char **returned_error_string) { objc_id buffer_error = command_buffer_error(buffer); if (buffer_error != nullptr) { @@ -519,6 +528,14 @@ namespace Runtime { namespace Internal { namespace Metal { +struct user_context_block_byref { + void *isa; + struct user_context_block_byref *forwarding; + int flags; + int size; + void* user_context; +}; + struct command_buffer_completed_handler_block_descriptor_1 { unsigned long reserved; unsigned long block_size; @@ -530,6 +547,7 @@ struct command_buffer_completed_handler_block_literal { int reserved; void (*invoke)(command_buffer_completed_handler_block_literal *, mtl_command_buffer *buffer); struct command_buffer_completed_handler_block_descriptor_1 *descriptor; + struct user_context_block_byref *user_context_holder; }; WEAK command_buffer_completed_handler_block_descriptor_1 command_buffer_completed_handler_descriptor = { @@ -538,7 +556,7 @@ WEAK command_buffer_completed_handler_block_descriptor_1 command_buffer_complete WEAK void command_buffer_completed_handler_invoke(command_buffer_completed_handler_block_literal *block, mtl_command_buffer *buffer) { retain_ns_object(buffer); char* error_string = nullptr; - auto status = halide_metal_command_buffer_completion_handler(nullptr, buffer, &error_string); + auto status = halide_metal_command_buffer_completion_handler(block->user_context_holder->user_context, buffer, &error_string); release_ns_object(buffer); MetalContextHolder::set_saved_status(status, error_string); @@ -546,11 +564,7 @@ WEAK void command_buffer_completed_handler_invoke(command_buffer_completed_handl } -WEAK command_buffer_completed_handler_block_literal command_buffer_completed_handler_block = { - &_NSConcreteGlobalBlock, - (1 << 28) | (1 << 29), // BLOCK_IS_GLOBAL | BLOCK_HAS_DESCRIPTOR - 0, command_buffer_completed_handler_invoke, - &command_buffer_completed_handler_descriptor}; + } // namespace Metal } // namespace Internal @@ -995,6 +1009,21 @@ WEAK int halide_metal_run(void *user_context, threadsX, threadsY, threadsZ); end_encoding(encoder); + // Construct an Objective C block to check for errors on command buffer completion, saving the user context + user_context_block_byref user_context_holder = { + &_NSConcreteStackBlock, + &user_context_holder, + 0, + sizeof(user_context_holder), + user_context}; + + command_buffer_completed_handler_block_literal command_buffer_completed_handler_block = { + &_NSConcreteGlobalBlock, + /*(1 << 28) | */ (1 << 29), // BLOCK_IS_GLOBAL | BLOCK_HAS_DESCRIPTOR + 0, command_buffer_completed_handler_invoke, + &command_buffer_completed_handler_descriptor, + &user_context_holder}; + add_command_buffer_completed_handler(command_buffer, &command_buffer_completed_handler_block); commit_command_buffer(command_buffer); diff --git a/test/correctness/gpu_metal_completion_handler_error_check.cpp b/test/correctness/gpu_metal_completion_handler_error_check.cpp index 0da8e6430316..60e9c206d158 100644 --- a/test/correctness/gpu_metal_completion_handler_error_check.cpp +++ b/test/correctness/gpu_metal_completion_handler_error_check.cpp @@ -4,13 +4,6 @@ using namespace Halide; bool errored = false; -void my_error(JITUserContext *, const char *msg) { - // Emitting "error.*:" to stdout or stderr will cause CMake to report the - // test as a failure on Windows, regardless of error code returned, - // hence the abbreviation to "err". - printf("Expected err: %s\n", msg); - errored = true; -} int main(int argc, char **argv) { Target t = get_jit_target_from_environment(); @@ -32,12 +25,6 @@ int main(int argc, char **argv) { f.gpu_tile(x, c, xi, ci, 4, 4); f.update(0).gpu_tile(r.x, c, rxi, ci, 4, 4); - // Because the error handler is invoked from a Metal runtime thread, setting a custom handler for just - // this pipeline is insufficient. Instead, we set a custom handler for the JIT runtime - JITHandlers handlers; - handlers.custom_error = my_error; - Internal::JITSharedRuntime::set_default_handlers(handlers); - f.jit_handlers().custom_error = my_error; // Metal is surprisingly resilient. Run this in a loop just to make sure we trigger the error. for (int i = 0; (i < 10) && !errored; i++) { diff --git a/test/generator/CMakeLists.txt b/test/generator/CMakeLists.txt index fc1cbfc76e78..2c010ae07717 100644 --- a/test/generator/CMakeLists.txt +++ b/test/generator/CMakeLists.txt @@ -497,6 +497,11 @@ _add_halide_libraries(metadata_tester_ucon _add_halide_aot_tests(metadata_tester HALIDE_LIBRARIES metadata_tester metadata_tester_ucon) +# metal_completion_handler_override_aottest.cpp +# metal_completion_handler_override_generator.cpp +_add_halide_libraries(metal_completion_handler_override FEATURES user_context) +_add_halide_aot_tests(metal_completion_handler_override) + # msan_aottest.cpp # msan_generator.cpp if ("${Halide_TARGET}" MATCHES "webgpu") diff --git a/test/generator/metal_completion_handler_override_aottest.cpp b/test/generator/metal_completion_handler_override_aottest.cpp new file mode 100644 index 000000000000..39abbb82d6e5 --- /dev/null +++ b/test/generator/metal_completion_handler_override_aottest.cpp @@ -0,0 +1,48 @@ +#include "HalideBuffer.h" +#include "HalideRuntime.h" +#include "HalideRuntimeMetal.h" + +#include "metal_completion_handler_override.h" + + +struct MyUserContext { + int counter; + + MyUserContext() : counter(0) {} +}; + +extern "C" int halide_metal_command_buffer_completion_handler(void* user_context, struct halide_metal_command_buffer *, char **) { + auto ctx = (MyUserContext *)user_context; + ctx->counter++; + return halide_error_code_success; +} + +int main(int argc, char* argv[]) { +#if defined(TEST_METAL) + Halide::Runtime::Buffer output(32, 32); + + MyUserContext my_context; + metal_completion_handler_override(&my_context, output); + output.copy_to_host(); + + // Check the output + for (int y = 0; y < output.height(); y++) { + for (int x = 0; x < output.width(); x++) { + if (output(x, y) != x + y * 2) { + printf("Error: output(%d, %d) = %d instead of %d\n", x, y, output(x, y), x + y * 2); + return -1; + } + } + } + + if (my_context.counter < 1) { + printf("Error: completion handler was not called\n"); + return -1; + } + + printf("Success!\n"); +#else + printf("[SKIP] Metal not enabled\n"); +#endif + return 0; +} \ No newline at end of file diff --git a/test/generator/metal_completion_handler_override_generator.cpp b/test/generator/metal_completion_handler_override_generator.cpp new file mode 100644 index 000000000000..56a70aba985f --- /dev/null +++ b/test/generator/metal_completion_handler_override_generator.cpp @@ -0,0 +1,25 @@ +#include "Halide.h" + +namespace { + +class SimpleMetalPipeline : public Halide::Generator { +public: + Output> output{"output"}; + + void generate() { + Var x("x"), y("y"); + + // Create a simple pipeline that scales pixel values by 2. + output(x, y) = x + y * 2; + + Target target = get_target(); + if (target.has_gpu_feature()) { + Var xo, yo, xi, yi; + output.gpu_tile(x, y, xo, yo, xi, yi, 16, 16); + } + } +}; + +} // namespace + +HALIDE_REGISTER_GENERATOR(SimpleMetalPipeline, metal_completion_handler_override) \ No newline at end of file From 9270b21ec51fa23673795b4abc6524da5b8977ec Mon Sep 17 00:00:00 2001 From: Shoaib Kamil Date: Mon, 10 Jun 2024 14:19:03 -0400 Subject: [PATCH 06/12] Format. --- src/runtime/HalideRuntimeMetal.h | 6 ++--- src/runtime/metal.cpp | 23 ++++++++----------- ...u_metal_completion_handler_error_check.cpp | 1 - ...al_completion_handler_override_aottest.cpp | 11 +++++---- ..._completion_handler_override_generator.cpp | 2 +- 5 files changed, 20 insertions(+), 23 deletions(-) diff --git a/src/runtime/HalideRuntimeMetal.h b/src/runtime/HalideRuntimeMetal.h index 255d24455938..30762e07d8ae 100644 --- a/src/runtime/HalideRuntimeMetal.h +++ b/src/runtime/HalideRuntimeMetal.h @@ -96,13 +96,13 @@ extern int halide_metal_release_context(void *user_context); /** This function is called as part of the callback when a Metal command buffer completes. * The return value, if not halide_error_code_success, will be stashed in Metal runtime and returned - * to the next call into the runtime, and the error string will be saved as well. + * to the next call into the runtime, and the error string will be saved as well. * The error string will be freed by the caller. The return value must be a valid Halide error code. * This is called from the Metal driver, and thus: * - Any user_context must be preserved between the call to halide_metal_run and the corresponding callback * - The function must be thread-safe -*/ -extern int halide_metal_command_buffer_completion_handler(void* user_context, struct halide_metal_command_buffer *buffer, + */ +extern int halide_metal_command_buffer_completion_handler(void *user_context, struct halide_metal_command_buffer *buffer, char **returned_error_string); #ifdef __cplusplus diff --git a/src/runtime/metal.cpp b/src/runtime/metal.cpp index f01a31a38532..f6c42c1500bd 100644 --- a/src/runtime/metal.cpp +++ b/src/runtime/metal.cpp @@ -418,7 +418,7 @@ class MetalContextHolder { saved_status = halide_error_code_success; if (error_string != nullptr && result != halide_error_code_success && strnlen(MetalContextHolder::error_string, 1024) > 0) { strncpy(error_string, MetalContextHolder::error_string, 1024); - error_string[1023] = '\0'; + error_string[1023] = '\0'; MetalContextHolder::error_string[0] = '\0'; debug(nullptr) << "MetalContextHolder::get_and_clear_saved_status: " << error_string << "\n"; } @@ -432,7 +432,7 @@ class MetalContextHolder { int result = saved_status; if (error_string != nullptr && result != halide_error_code_success && strnlen(MetalContextHolder::error_string, 1024) > 0) { strncpy(error_string, MetalContextHolder::error_string, 1024); - error_string[1023] = '\0'; + error_string[1023] = '\0'; } halide_mutex_unlock(&saved_status_mutex); return result; @@ -443,7 +443,7 @@ class MetalContextHolder { saved_status = new_status; if (error_string != nullptr) { strncpy(MetalContextHolder::error_string, error_string, 1024); - error_string[1023] = '\0'; + error_string[1023] = '\0'; debug(nullptr) << "MetalContextHolder::set_saved_status: " << error_string << "\n"; } halide_mutex_unlock(&saved_status_mutex); @@ -479,13 +479,13 @@ char MetalContextHolder::error_string[1024] = {0}; extern "C" { /** This function is called as part of the callback when a Metal command buffer completes. * The return value, if not halide_error_code_success, will be stashed in Metal runtime and returned - * to the next call into the runtime, and the error string will be saved as well. + * to the next call into the runtime, and the error string will be saved as well. * The error string will be freed by the caller. The return value must be a valid Halide error code. * This is called from the Metal driver, and thus: * - Any user_context must be preserved between the call to halide_metal_run and the corresponding callback * - The function must be thread-safe -*/ -WEAK int halide_metal_command_buffer_completion_handler(void* user_context, mtl_command_buffer *buffer, char **returned_error_string) { + */ +WEAK int halide_metal_command_buffer_completion_handler(void *user_context, mtl_command_buffer *buffer, char **returned_error_string) { objc_id buffer_error = command_buffer_error(buffer); if (buffer_error != nullptr) { retain_ns_object(buffer_error); @@ -506,7 +506,7 @@ WEAK int halide_metal_command_buffer_completion_handler(void* user_context, mtl_ // Copy C-style string into a fresh buffer if (returned_error_string != nullptr) { - *returned_error_string = (char*)malloc(sizeof(char) * 1024); + *returned_error_string = (char *)malloc(sizeof(char) * 1024); if (*returned_error_string != nullptr) { strncpy(*returned_error_string, error_string, 1024); (*returned_error_string)[1023] = '\0'; @@ -521,7 +521,7 @@ WEAK int halide_metal_command_buffer_completion_handler(void* user_context, mtl_ } return halide_error_code_success; } -} // extern "C" +} // extern "C" namespace Halide { namespace Runtime { @@ -533,7 +533,7 @@ struct user_context_block_byref { struct user_context_block_byref *forwarding; int flags; int size; - void* user_context; + void *user_context; }; struct command_buffer_completed_handler_block_descriptor_1 { @@ -555,17 +555,14 @@ WEAK command_buffer_completed_handler_block_descriptor_1 command_buffer_complete WEAK void command_buffer_completed_handler_invoke(command_buffer_completed_handler_block_literal *block, mtl_command_buffer *buffer) { retain_ns_object(buffer); - char* error_string = nullptr; + char *error_string = nullptr; auto status = halide_metal_command_buffer_completion_handler(block->user_context_holder->user_context, buffer, &error_string); release_ns_object(buffer); MetalContextHolder::set_saved_status(status, error_string); free(error_string); - } - - } // namespace Metal } // namespace Internal } // namespace Runtime diff --git a/test/correctness/gpu_metal_completion_handler_error_check.cpp b/test/correctness/gpu_metal_completion_handler_error_check.cpp index 60e9c206d158..d2e5f26782ff 100644 --- a/test/correctness/gpu_metal_completion_handler_error_check.cpp +++ b/test/correctness/gpu_metal_completion_handler_error_check.cpp @@ -25,7 +25,6 @@ int main(int argc, char **argv) { f.gpu_tile(x, c, xi, ci, 4, 4); f.update(0).gpu_tile(r.x, c, rxi, ci, 4, 4); - // Metal is surprisingly resilient. Run this in a loop just to make sure we trigger the error. for (int i = 0; (i < 10) && !errored; i++) { auto out = f.realize({1000, 100}, t); diff --git a/test/generator/metal_completion_handler_override_aottest.cpp b/test/generator/metal_completion_handler_override_aottest.cpp index 39abbb82d6e5..a855af7b1703 100644 --- a/test/generator/metal_completion_handler_override_aottest.cpp +++ b/test/generator/metal_completion_handler_override_aottest.cpp @@ -4,20 +4,21 @@ #include "metal_completion_handler_override.h" - struct MyUserContext { int counter; - MyUserContext() : counter(0) {} + MyUserContext() + : counter(0) { + } }; -extern "C" int halide_metal_command_buffer_completion_handler(void* user_context, struct halide_metal_command_buffer *, char **) { +extern "C" int halide_metal_command_buffer_completion_handler(void *user_context, struct halide_metal_command_buffer *, char **) { auto ctx = (MyUserContext *)user_context; ctx->counter++; return halide_error_code_success; } -int main(int argc, char* argv[]) { +int main(int argc, char *argv[]) { #if defined(TEST_METAL) Halide::Runtime::Buffer output(32, 32); @@ -39,7 +40,7 @@ int main(int argc, char* argv[]) { printf("Error: completion handler was not called\n"); return -1; } - + printf("Success!\n"); #else printf("[SKIP] Metal not enabled\n"); diff --git a/test/generator/metal_completion_handler_override_generator.cpp b/test/generator/metal_completion_handler_override_generator.cpp index 56a70aba985f..8130a87710dc 100644 --- a/test/generator/metal_completion_handler_override_generator.cpp +++ b/test/generator/metal_completion_handler_override_generator.cpp @@ -10,7 +10,7 @@ class SimpleMetalPipeline : public Halide::Generator { Var x("x"), y("y"); // Create a simple pipeline that scales pixel values by 2. - output(x, y) = x + y * 2; + output(x, y) = x + y * 2; Target target = get_target(); if (target.has_gpu_feature()) { From b7ae0fb31b0d064e86d357dbcfc32f55720f2b3c Mon Sep 17 00:00:00 2001 From: Shoaib Kamil Date: Mon, 10 Jun 2024 14:48:40 -0400 Subject: [PATCH 07/12] Remove some debug code --- src/runtime/metal.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/runtime/metal.cpp b/src/runtime/metal.cpp index f6c42c1500bd..310c9566e01b 100644 --- a/src/runtime/metal.cpp +++ b/src/runtime/metal.cpp @@ -1015,8 +1015,8 @@ WEAK int halide_metal_run(void *user_context, user_context}; command_buffer_completed_handler_block_literal command_buffer_completed_handler_block = { - &_NSConcreteGlobalBlock, - /*(1 << 28) | */ (1 << 29), // BLOCK_IS_GLOBAL | BLOCK_HAS_DESCRIPTOR + &_NSConcreteStackBlock, + (1 << 29), // BLOCK_HAS_DESCRIPTOR 0, command_buffer_completed_handler_invoke, &command_buffer_completed_handler_descriptor, &user_context_holder}; From 51eb3aa870e360fc3d3eadaa2527feda4e4b6c78 Mon Sep 17 00:00:00 2001 From: Shoaib Kamil Date: Mon, 10 Jun 2024 16:15:14 -0400 Subject: [PATCH 08/12] Add missing includes. --- test/correctness/gpu_metal_completion_handler_error_check.cpp | 2 +- test/generator/metal_completion_handler_override_aottest.cpp | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/test/correctness/gpu_metal_completion_handler_error_check.cpp b/test/correctness/gpu_metal_completion_handler_error_check.cpp index d2e5f26782ff..f0bb396e2c12 100644 --- a/test/correctness/gpu_metal_completion_handler_error_check.cpp +++ b/test/correctness/gpu_metal_completion_handler_error_check.cpp @@ -1,5 +1,5 @@ #include "Halide.h" -#include +#include using namespace Halide; diff --git a/test/generator/metal_completion_handler_override_aottest.cpp b/test/generator/metal_completion_handler_override_aottest.cpp index a855af7b1703..270386e1f91a 100644 --- a/test/generator/metal_completion_handler_override_aottest.cpp +++ b/test/generator/metal_completion_handler_override_aottest.cpp @@ -1,3 +1,5 @@ +#include + #include "HalideBuffer.h" #include "HalideRuntime.h" #include "HalideRuntimeMetal.h" From faaa65e0ddc723b68d38d732854a4cc62c809eff Mon Sep 17 00:00:00 2001 From: Shoaib Kamil Date: Mon, 10 Jun 2024 19:10:11 -0400 Subject: [PATCH 09/12] Add comment noting why we manually null-terminate after strncpy --- src/runtime/metal.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/runtime/metal.cpp b/src/runtime/metal.cpp index 310c9566e01b..2f0d4deab3da 100644 --- a/src/runtime/metal.cpp +++ b/src/runtime/metal.cpp @@ -418,6 +418,7 @@ class MetalContextHolder { saved_status = halide_error_code_success; if (error_string != nullptr && result != halide_error_code_success && strnlen(MetalContextHolder::error_string, 1024) > 0) { strncpy(error_string, MetalContextHolder::error_string, 1024); + // Ensure null-termination, since strncpy won't if the source string is too long error_string[1023] = '\0'; MetalContextHolder::error_string[0] = '\0'; debug(nullptr) << "MetalContextHolder::get_and_clear_saved_status: " << error_string << "\n"; @@ -432,6 +433,7 @@ class MetalContextHolder { int result = saved_status; if (error_string != nullptr && result != halide_error_code_success && strnlen(MetalContextHolder::error_string, 1024) > 0) { strncpy(error_string, MetalContextHolder::error_string, 1024); + // Ensure null-termination, since strncpy won't if the source string is too long error_string[1023] = '\0'; } halide_mutex_unlock(&saved_status_mutex); @@ -443,6 +445,7 @@ class MetalContextHolder { saved_status = new_status; if (error_string != nullptr) { strncpy(MetalContextHolder::error_string, error_string, 1024); + // Ensure null-termination, since strncpy won't if the source string is too long error_string[1023] = '\0'; debug(nullptr) << "MetalContextHolder::set_saved_status: " << error_string << "\n"; } @@ -509,6 +512,7 @@ WEAK int halide_metal_command_buffer_completion_handler(void *user_context, mtl_ *returned_error_string = (char *)malloc(sizeof(char) * 1024); if (*returned_error_string != nullptr) { strncpy(*returned_error_string, error_string, 1024); + // Ensure null-termination, since strncpy won't if the source string is too long (*returned_error_string)[1023] = '\0'; } else { debug(user_context) << "halide_metal_command_buffer_completion_handler: Failed to allocate memory for error string.\n"; From 3c3a65679906a8f29441c3a8f0974131e817c016 Mon Sep 17 00:00:00 2001 From: Shoaib Kamil Date: Wed, 12 Jun 2024 14:06:31 -0400 Subject: [PATCH 10/12] Reverse engineer Objective-C API for passing void* in a block; it turns out to be much simpler than I thought --- src/runtime/HalideRuntimeMetal.h | 3 +- src/runtime/metal.cpp | 31 ++++++------------- ...al_completion_handler_override_aottest.cpp | 6 +++- 3 files changed, 16 insertions(+), 24 deletions(-) diff --git a/src/runtime/HalideRuntimeMetal.h b/src/runtime/HalideRuntimeMetal.h index 30762e07d8ae..f171ca33bb6c 100644 --- a/src/runtime/HalideRuntimeMetal.h +++ b/src/runtime/HalideRuntimeMetal.h @@ -101,8 +101,9 @@ extern int halide_metal_release_context(void *user_context); * This is called from the Metal driver, and thus: * - Any user_context must be preserved between the call to halide_metal_run and the corresponding callback * - The function must be thread-safe + * - For Objective-C API reasons, the user context is passed in as a void *const */ -extern int halide_metal_command_buffer_completion_handler(void *user_context, struct halide_metal_command_buffer *buffer, +extern int halide_metal_command_buffer_completion_handler(void *const user_context, struct halide_metal_command_buffer *buffer, char **returned_error_string); #ifdef __cplusplus diff --git a/src/runtime/metal.cpp b/src/runtime/metal.cpp index 2f0d4deab3da..36b60d863a6d 100644 --- a/src/runtime/metal.cpp +++ b/src/runtime/metal.cpp @@ -487,8 +487,9 @@ extern "C" { * This is called from the Metal driver, and thus: * - Any user_context must be preserved between the call to halide_metal_run and the corresponding callback * - The function must be thread-safe + * - For Objective-C API reasons, the user context is passed in as a void *const */ -WEAK int halide_metal_command_buffer_completion_handler(void *user_context, mtl_command_buffer *buffer, char **returned_error_string) { +WEAK int halide_metal_command_buffer_completion_handler(void *const user_context, mtl_command_buffer *buffer, char **returned_error_string) { objc_id buffer_error = command_buffer_error(buffer); if (buffer_error != nullptr) { retain_ns_object(buffer_error); @@ -532,14 +533,6 @@ namespace Runtime { namespace Internal { namespace Metal { -struct user_context_block_byref { - void *isa; - struct user_context_block_byref *forwarding; - int flags; - int size; - void *user_context; -}; - struct command_buffer_completed_handler_block_descriptor_1 { unsigned long reserved; unsigned long block_size; @@ -551,7 +544,7 @@ struct command_buffer_completed_handler_block_literal { int reserved; void (*invoke)(command_buffer_completed_handler_block_literal *, mtl_command_buffer *buffer); struct command_buffer_completed_handler_block_descriptor_1 *descriptor; - struct user_context_block_byref *user_context_holder; + void *const user_context; }; WEAK command_buffer_completed_handler_block_descriptor_1 command_buffer_completed_handler_descriptor = { @@ -560,7 +553,9 @@ WEAK command_buffer_completed_handler_block_descriptor_1 command_buffer_complete WEAK void command_buffer_completed_handler_invoke(command_buffer_completed_handler_block_literal *block, mtl_command_buffer *buffer) { retain_ns_object(buffer); char *error_string = nullptr; - auto status = halide_metal_command_buffer_completion_handler(block->user_context_holder->user_context, buffer, &error_string); + void *const user_context = block->user_context; + + auto status = halide_metal_command_buffer_completion_handler(user_context, buffer, &error_string); release_ns_object(buffer); MetalContextHolder::set_saved_status(status, error_string); @@ -1009,21 +1004,13 @@ WEAK int halide_metal_run(void *user_context, blocksX, blocksY, blocksZ, threadsX, threadsY, threadsZ); end_encoding(encoder); - - // Construct an Objective C block to check for errors on command buffer completion, saving the user context - user_context_block_byref user_context_holder = { - &_NSConcreteStackBlock, - &user_context_holder, - 0, - sizeof(user_context_holder), - user_context}; - + command_buffer_completed_handler_block_literal command_buffer_completed_handler_block = { &_NSConcreteStackBlock, - (1 << 29), // BLOCK_HAS_DESCRIPTOR + 0, // must be 0 for stack blocks 0, command_buffer_completed_handler_invoke, &command_buffer_completed_handler_descriptor, - &user_context_holder}; + user_context}; add_command_buffer_completed_handler(command_buffer, &command_buffer_completed_handler_block); diff --git a/test/generator/metal_completion_handler_override_aottest.cpp b/test/generator/metal_completion_handler_override_aottest.cpp index 270386e1f91a..d6cd7fb5a112 100644 --- a/test/generator/metal_completion_handler_override_aottest.cpp +++ b/test/generator/metal_completion_handler_override_aottest.cpp @@ -14,7 +14,11 @@ struct MyUserContext { } }; -extern "C" int halide_metal_command_buffer_completion_handler(void *user_context, struct halide_metal_command_buffer *, char **) { +extern "C" int halide_metal_command_buffer_completion_handler(void *const user_context, struct halide_metal_command_buffer *, char **) { + if (user_context == nullptr) { + printf("Error: user_context is nullptr\n"); + return -1; + } auto ctx = (MyUserContext *)user_context; ctx->counter++; return halide_error_code_success; From c3a11b2590ef178f86b8a7fcb06147c3c5741617 Mon Sep 17 00:00:00 2001 From: Shoaib Kamil Date: Wed, 12 Jun 2024 14:23:54 -0400 Subject: [PATCH 11/12] Formatting --- src/runtime/metal.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/runtime/metal.cpp b/src/runtime/metal.cpp index 36b60d863a6d..41cca0d0dc42 100644 --- a/src/runtime/metal.cpp +++ b/src/runtime/metal.cpp @@ -1004,10 +1004,10 @@ WEAK int halide_metal_run(void *user_context, blocksX, blocksY, blocksZ, threadsX, threadsY, threadsZ); end_encoding(encoder); - + command_buffer_completed_handler_block_literal command_buffer_completed_handler_block = { &_NSConcreteStackBlock, - 0, // must be 0 for stack blocks + 0, // must be 0 for stack blocks 0, command_buffer_completed_handler_invoke, &command_buffer_completed_handler_descriptor, user_context}; From 588a59c03703f02351260914b1feed4afe2dbcbe Mon Sep 17 00:00:00 2001 From: Shoaib Kamil Date: Wed, 12 Jun 2024 15:10:00 -0400 Subject: [PATCH 12/12] Don't add const-ness to declaration. --- src/runtime/HalideRuntimeMetal.h | 3 +-- src/runtime/metal.cpp | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/runtime/HalideRuntimeMetal.h b/src/runtime/HalideRuntimeMetal.h index f171ca33bb6c..30762e07d8ae 100644 --- a/src/runtime/HalideRuntimeMetal.h +++ b/src/runtime/HalideRuntimeMetal.h @@ -101,9 +101,8 @@ extern int halide_metal_release_context(void *user_context); * This is called from the Metal driver, and thus: * - Any user_context must be preserved between the call to halide_metal_run and the corresponding callback * - The function must be thread-safe - * - For Objective-C API reasons, the user context is passed in as a void *const */ -extern int halide_metal_command_buffer_completion_handler(void *const user_context, struct halide_metal_command_buffer *buffer, +extern int halide_metal_command_buffer_completion_handler(void *user_context, struct halide_metal_command_buffer *buffer, char **returned_error_string); #ifdef __cplusplus diff --git a/src/runtime/metal.cpp b/src/runtime/metal.cpp index 41cca0d0dc42..1fe7d895561b 100644 --- a/src/runtime/metal.cpp +++ b/src/runtime/metal.cpp @@ -487,7 +487,6 @@ extern "C" { * This is called from the Metal driver, and thus: * - Any user_context must be preserved between the call to halide_metal_run and the corresponding callback * - The function must be thread-safe - * - For Objective-C API reasons, the user context is passed in as a void *const */ WEAK int halide_metal_command_buffer_completion_handler(void *const user_context, mtl_command_buffer *buffer, char **returned_error_string) { objc_id buffer_error = command_buffer_error(buffer);