Skip to content

Commit

Permalink
fixed bug with aligned alloc for macos
Browse files Browse the repository at this point in the history
  • Loading branch information
GagaLP authored and fknorr committed Feb 7, 2024
1 parent 28a74bb commit 0b94209
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 21 deletions.
10 changes: 9 additions & 1 deletion include/simsycl/detail/allocation.hh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "../sycl/enums.hh"
#include "../sycl/forward.hh"

#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <optional>
Expand All @@ -15,9 +16,16 @@ namespace simsycl::detail {
// NOTE: returned pointers must be freed with aligned_free
inline void *aligned_alloc(size_t alignment, size_t size) {
#if defined(_MSC_VER)
// MSVC does not have std::aligned_alloc because the pointers it returns cannot be freed with std::free
return _aligned_malloc(size, alignment);
#else
return std::aligned_alloc(alignment, size);
// POSIX and notably macOS requires a minimum alignment of sizeof(void*) for aligned_alloc, so we only use that for
// over-aligned allocations
if(alignment <= alignof(std::max_align_t)) {
return std::malloc(size);
} else {
return std::aligned_alloc(alignment, size);
}
#endif
}

Expand Down
20 changes: 2 additions & 18 deletions src/simsycl/system.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,19 +182,7 @@ void *usm_alloc(const sycl::context &context, sycl::usm::alloc kind, std::option
if(*bytes_free < size_bytes) return nullptr;
}

void *ptr;
#if defined(_MSC_VER)
// MSVC does not have std::aligned_alloc because the pointers it returns cannot be freed with std::free
ptr = _aligned_malloc(size_bytes, alignment_bytes);
#else
// POSIX and notably macOS requires a minimum alignment of sizeof(void*) for aligned_alloc, so we only use that for
// over-aligned allocations
if(alignment_bytes <= alignof(std::max_align_t)) {
ptr = std::malloc(size_bytes);
} else {
ptr = std::aligned_alloc(alignment_bytes, size_bytes);
}
#endif
void *ptr = detail::aligned_alloc(alignment_bytes, size_bytes);

if(ptr == nullptr) return nullptr;
std::memset(ptr, static_cast<int>(uninitialized_memory_pattern), size_bytes);
Expand All @@ -220,11 +208,7 @@ void usm_free(void *ptr, const sycl::context &context) {
throw sycl::exception(sycl::errc::invalid, "Pointer is not associated with the given context");
}

#if defined(_MSC_VER)
_aligned_free(ptr);
#else
std::free(ptr);
#endif
detail::aligned_free(ptr);

if(iter->get_device().has_value()) {
*detail::device_bytes_free(iter->get_device().value()) += iter->get_size_bytes();
Expand Down
2 changes: 1 addition & 1 deletion test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ add_executable(tests
math_tests.cc
reduction_tests.cc
simulation_tests.cc
usm_tests.cc
alloc_tests.cc
vec_tests.cc
)

Expand Down
13 changes: 12 additions & 1 deletion test/usm_tests.cc → test/alloc_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,18 @@

#include <catch2/catch_test_macros.hpp>

TEST_CASE("allocates memory of any alignment", "[aligned_alloc]") {
const size_t largest_sycl_align_bytes = alignof(sycl::long16);
const size_t size_bytes = 4096;
CAPTURE(size_bytes);
for(size_t align_bytes = 1; align_bytes <= largest_sycl_align_bytes; align_bytes *= 2) {
CAPTURE(align_bytes);
const auto p = simsycl::detail::aligned_alloc(align_bytes, size_bytes);
CHECK(p != nullptr);
CHECK(reinterpret_cast<std::uintptr_t>(p) % align_bytes == 0);
simsycl::detail::aligned_free(p);
}
}

TEST_CASE("usm_alloc allocates memory of any alignment", "[usm]") {
const size_t largest_sycl_align_bytes = alignof(sycl::long16);
Expand All @@ -11,7 +23,6 @@ TEST_CASE("usm_alloc allocates memory of any alignment", "[usm]") {
sycl::context ctx;
for(size_t align_bytes = 1; align_bytes <= largest_sycl_align_bytes; align_bytes *= 2) {
CAPTURE(align_bytes);
errno = 0;
const auto p = simsycl::detail::usm_alloc(ctx, sycl::usm::alloc::host, std::nullopt, size_bytes, align_bytes);
CHECK(p != nullptr);
CHECK(reinterpret_cast<std::uintptr_t>(p) % align_bytes == 0);
Expand Down

0 comments on commit 0b94209

Please sign in to comment.