From 2185f4e55e20d55b55fb63fba519c2b4c8354ab3 Mon Sep 17 00:00:00 2001 From: Derek Gerstmann Date: Wed, 7 Feb 2024 09:43:58 -0800 Subject: [PATCH] Fix bool conversion bug in Vulkan code generator (#8067) * Fix bug in Vulkan code generator that was incorrectly passing the address of a byte vector, instead of its contents to builder.declare_constant() * Add bool_predicate_cast correctness test to verify bool conversion for Vulkan codegen works as expected --------- Co-authored-by: Derek Gerstmann --- src/CodeGen_Vulkan_Dev.cpp | 7 +++-- test/correctness/CMakeLists.txt | 1 + test/correctness/bool_predicate_cast.cpp | 39 ++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 2 deletions(-) create mode 100644 test/correctness/bool_predicate_cast.cpp diff --git a/src/CodeGen_Vulkan_Dev.cpp b/src/CodeGen_Vulkan_Dev.cpp index 7e06447a27fc..b86c99f9269e 100644 --- a/src/CodeGen_Vulkan_Dev.cpp +++ b/src/CodeGen_Vulkan_Dev.cpp @@ -548,6 +548,9 @@ void fill_bytes_with_value(uint8_t *bytes, int count, int value) { } SpvId CodeGen_Vulkan_Dev::SPIRV_Emitter::convert_to_bool(Type target_type, Type value_type, SpvId value_id) { + debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::convert_to_bool(): casting from value type '" + << value_type << "' to target type '" << target_type << "' for value id '" << value_id << "' !\n"; + if (!value_type.is_bool()) { value_id = cast_type(Bool(), value_type, value_id); } @@ -590,8 +593,8 @@ SpvId CodeGen_Vulkan_Dev::SPIRV_Emitter::convert_to_bool(Type target_type, Type SpvId result_id = builder.reserve_id(SpvResultId); SpvId target_type_id = builder.declare_type(target_type); - SpvId true_value_id = builder.declare_constant(target_type, &true_data); - SpvId false_value_id = builder.declare_constant(target_type, &false_data); + SpvId true_value_id = builder.declare_constant(target_type, &true_data[0]); + SpvId false_value_id = builder.declare_constant(target_type, &false_data[0]); builder.append(SpvFactory::select(target_type_id, result_id, value_id, true_value_id, false_value_id)); return result_id; } diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index cd66f21a346e..5960e7922658 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -14,6 +14,7 @@ tests(GROUPS correctness bit_counting.cpp bitwise_ops.cpp bool_compute_root_vectorize.cpp + bool_predicate_cast.cpp bound.cpp bound_small_allocations.cpp bound_storage.cpp diff --git a/test/correctness/bool_predicate_cast.cpp b/test/correctness/bool_predicate_cast.cpp new file mode 100644 index 000000000000..1043f329b76c --- /dev/null +++ b/test/correctness/bool_predicate_cast.cpp @@ -0,0 +1,39 @@ +#include "Halide.h" +#include + +using namespace Halide; + +int main(int argc, char **argv) { + + // Test explicit casting of a predicate to an integer as part of a reduction + // NOTE: triggers a convert_to_bool in Vulkan for a SelectOp + Target target = get_jit_target_from_environment(); + Var x("x"), y("y"); + + Func input("input"); + input(x, y) = cast(x + y); + + Func test("test"); + test(x, y) = cast(UInt(8), input(x, y) >= 32); + + if (target.has_gpu_feature()) { + Var xi("xi"), yi("yi"); + test.gpu_tile(x, y, xi, yi, 8, 8); + } + + Realization result = test.realize({96, 96}); + Buffer a = result[0]; + for (int y = 0; y < a.height(); y++) { + for (int x = 0; x < a.width(); x++) { + uint8_t correct_a = ((x + y) >= 32) ? 1 : 0; + if (a(x, y) != correct_a) { + printf("result(%d, %d) = (%d) instead of (%d)\n", + x, y, a(x, y), correct_a); + return 1; + } + } + } + + printf("Success!\n"); + return 0; +}