From 40585c1ff28aaf55d4ec74cc1b58c35d39ae5d81 Mon Sep 17 00:00:00 2001 From: Yibo Cai Date: Tue, 15 Jun 2021 16:44:37 +0200 Subject: [PATCH] ARROW-13003: [C++] Fix key map unaligned access Closes #10489 from cyb70289/13003-unaligned-access Authored-by: Yibo Cai Signed-off-by: Antoine Pitrou --- cpp/src/arrow/compute/exec/key_compare.cc | 21 +++---- cpp/src/arrow/compute/exec/key_map.cc | 71 +++++++++++++---------- cpp/src/arrow/compute/exec/util.cc | 16 ++--- 3 files changed, 59 insertions(+), 49 deletions(-) diff --git a/cpp/src/arrow/compute/exec/key_compare.cc b/cpp/src/arrow/compute/exec/key_compare.cc index f8d74859b0170..7a5b0be9990f6 100644 --- a/cpp/src/arrow/compute/exec/key_compare.cc +++ b/cpp/src/arrow/compute/exec/key_compare.cc @@ -21,6 +21,7 @@ #include #include "arrow/compute/exec/util.h" +#include "arrow/util/ubsan.h" namespace arrow { namespace compute { @@ -170,19 +171,19 @@ void KeyCompare::CompareFixedLengthImp(uint32_t num_rows_already_processed, // if (num_64bit_words == 0) { for (; istripe < num_loops_less_one; ++istripe) { - uint64_t key_left = key_left_ptr[istripe]; - uint64_t key_right = key_right_ptr[istripe]; + uint64_t key_left = util::SafeLoad(&key_left_ptr[istripe]); + uint64_t key_right = util::SafeLoad(&key_right_ptr[istripe]); result_or |= (key_left ^ key_right); } } else if (num_64bit_words == 2) { - uint64_t key_left = key_left_ptr[istripe]; - uint64_t key_right = key_right_ptr[istripe]; + uint64_t key_left = util::SafeLoad(&key_left_ptr[istripe]); + uint64_t key_right = util::SafeLoad(&key_right_ptr[istripe]); result_or |= (key_left ^ key_right); ++istripe; } - uint64_t key_left = key_left_ptr[istripe]; - uint64_t key_right = key_right_ptr[istripe]; + uint64_t key_left = util::SafeLoad(&key_left_ptr[istripe]); + uint64_t key_right = util::SafeLoad(&key_right_ptr[istripe]); result_or |= (tail_mask & (key_left ^ key_right)); int result = (result_or == 0 ? 0xff : 0); @@ -246,16 +247,16 @@ void KeyCompare::CompareVaryingLengthImp( int32_t istripe; // length can be zero for (istripe = 0; istripe < (static_cast(length) + 7) / 8 - 1; ++istripe) { - uint64_t key_left = key_left_ptr[istripe]; - uint64_t key_right = key_right_ptr[istripe]; + uint64_t key_left = util::SafeLoad(&key_left_ptr[istripe]); + uint64_t key_right = util::SafeLoad(&key_right_ptr[istripe]); result_or |= (key_left ^ key_right); } uint32_t length_remaining = length - static_cast(istripe) * 8; uint64_t tail_mask = tail_masks[length_remaining]; - uint64_t key_left = key_left_ptr[istripe]; - uint64_t key_right = key_right_ptr[istripe]; + uint64_t key_left = util::SafeLoad(&key_left_ptr[istripe]); + uint64_t key_right = util::SafeLoad(&key_right_ptr[istripe]); result_or |= (tail_mask & (key_left ^ key_right)); int result = (result_or == 0 ? 0xff : 0); diff --git a/cpp/src/arrow/compute/exec/key_map.cc b/cpp/src/arrow/compute/exec/key_map.cc index c48487793e035..ac47c04403c72 100644 --- a/cpp/src/arrow/compute/exec/key_map.cc +++ b/cpp/src/arrow/compute/exec/key_map.cc @@ -24,6 +24,7 @@ #include "arrow/util/bit_util.h" #include "arrow/util/bitmap_ops.h" +#include "arrow/util/ubsan.h" namespace arrow { @@ -153,7 +154,7 @@ void SwissTable::lookup_1(const uint16_t* selection, const int num_keys, for (int i = 0; i < num_keys; ++i) { int id; if (use_selection) { - id = selection[i]; + id = util::SafeLoad(&selection[i]); } else { id = i; } @@ -168,7 +169,7 @@ void SwissTable::lookup_1(const uint16_t* selection, const int num_keys, uint32_t num_block_bytes = num_groupid_bits + 8; const uint8_t* blockbase = reinterpret_cast(blocks_) + static_cast(iblock) * num_block_bytes; - uint64_t block = *reinterpret_cast(blockbase); + uint64_t block = util::SafeLoadAs(blockbase); // Call helper functions to obtain the output triplet: // - match (of a stamp) found flag @@ -182,8 +183,8 @@ void SwissTable::lookup_1(const uint16_t* selection, const int num_keys, uint64_t islot = next_slot_to_visit(iblock, islot_in_block, match_found); out_match_bitvector[id / 8] |= match_found << (id & 7); - out_groupids[id] = static_cast(groupid); - out_slot_ids[id] = static_cast(islot); + util::SafeStore(&out_groupids[id], static_cast(groupid)); + util::SafeStore(&out_slot_ids[id], static_cast(islot)); } } @@ -239,7 +240,7 @@ Status SwissTable::lookup_2(const uint32_t* hashes, uint32_t* inout_num_selected uint16_t* ids[3]{inout_selection, ids_for_comparison_buf.mutable_data(), ids_inserted_buf.mutable_data()}; auto push_id = [&num_ids, &ids](int category, int id) { - ids[category][num_ids[category]++] = static_cast(id); + util::SafeStore(&ids[category][num_ids[category]++], static_cast(id)); }; uint64_t num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); @@ -256,9 +257,9 @@ Status SwissTable::lookup_2(const uint32_t* hashes, uint32_t* inout_num_selected num_inserted_ + num_ids[category_inserted] < num_groups_limit; ++num_processed) { // row id in original batch - int id = inout_selection[num_processed]; + int id = util::SafeLoad(&inout_selection[num_processed]); - uint64_t slot_id = wrap_global_slot_id(inout_next_slot_ids[id]); + uint64_t slot_id = wrap_global_slot_id(util::SafeLoad(&inout_next_slot_ids[id])); uint64_t block_id = slot_id >> 3; uint32_t hash = hashes[id]; uint8_t* blockbase = blocks_ + num_block_bytes * block_id; @@ -278,11 +279,13 @@ Status SwissTable::lookup_2(const uint32_t* hashes, uint32_t* inout_num_selected // In that case we can insert group id value using aligned 64-bit word access. ARROW_DCHECK(num_groupid_bits == 8 || num_groupid_bits == 16 || num_groupid_bits == 32 || num_groupid_bits == 64); - reinterpret_cast(blockbase + 8)[groupid_bit_offset >> 6] |= - (static_cast(group_id) << (groupid_bit_offset & 63)); + uint64_t* ptr = + &reinterpret_cast(blockbase + 8)[groupid_bit_offset >> 6]; + util::SafeStore(ptr, util::SafeLoad(ptr) | (static_cast(group_id) + << (groupid_bit_offset & 63))); hashes_[slot_id] = hash; - out_group_ids[id] = group_id; + util::SafeStore(&out_group_ids[id], group_id); push_id(category_inserted, id); } else { // We search for a slot with a matching stamp within a single block. @@ -298,8 +301,8 @@ Status SwissTable::lookup_2(const uint32_t* hashes, uint32_t* inout_num_selected ARROW_DCHECK(new_groupid < num_inserted_ + num_ids[category_inserted]); new_slot = static_cast(next_slot_to_visit(block_id, new_slot, new_match_found)); - inout_next_slot_ids[id] = new_slot; - out_group_ids[id] = new_groupid; + util::SafeStore(&inout_next_slot_ids[id], new_slot); + util::SafeStore(&out_group_ids[id], new_groupid); push_id(new_match_found, id); } } @@ -410,7 +413,8 @@ Status SwissTable::map(const int num_keys, const uint32_t* hashes, // for (uint32_t i = 0; i < num_ids; ++i) { // First slot in the new starting block - slot_ids[ids[i]] = (hashes[ids[i]] >> (bits_hash_ - log_blocks_)) * 8; + const int16_t id = util::SafeLoad(&ids[i]); + util::SafeStore(&slot_ids[id], (hashes[id] >> (bits_hash_ - log_blocks_)) * 8); } } } while (num_ids > 0); @@ -457,9 +461,8 @@ Status SwissTable::grow_double() { static_cast(CountLeadingZeros(block & kHighBitOfEachByte) >> 3); int full_slots_new[2]; full_slots_new[0] = full_slots_new[1] = 0; - *reinterpret_cast(double_block_base_new) = kHighBitOfEachByte; - *reinterpret_cast(double_block_base_new + block_size_after) = - kHighBitOfEachByte; + util::SafeStore(double_block_base_new, kHighBitOfEachByte); + util::SafeStore(double_block_base_new + block_size_after, kHighBitOfEachByte); for (int j = 0; j < full_slots; ++j) { uint64_t slot_id = i * 8 + j; @@ -474,18 +477,20 @@ Status SwissTable::grow_double() { uint8_t stamp_new = hash >> ((bits_hash_ - log_blocks_after - bits_stamp_)) & stamp_mask; uint64_t group_id_bit_offs = j * num_group_id_bits_before; - uint64_t group_id = (*reinterpret_cast(block_base + 8 + - (group_id_bit_offs >> 3)) >> - (group_id_bit_offs & 7)) & - group_id_mask_before; + uint64_t group_id = + (util::SafeLoadAs(block_base + 8 + (group_id_bit_offs >> 3)) >> + (group_id_bit_offs & 7)) & + group_id_mask_before; uint64_t slot_id_new = i * 16 + ihalf * 8 + full_slots_new[ihalf]; hashes_new[slot_id_new] = hash; uint8_t* block_base_new = double_block_base_new + ihalf * block_size_after; block_base_new[7 - full_slots_new[ihalf]] = stamp_new; int group_id_bit_offs_new = full_slots_new[ihalf] * num_group_id_bits_after; - *reinterpret_cast(block_base_new + 8 + (group_id_bit_offs_new >> 3)) |= - (group_id << (group_id_bit_offs_new & 7)); + uint64_t* ptr = + reinterpret_cast(block_base_new + 8 + (group_id_bit_offs_new >> 3)); + util::SafeStore(ptr, + util::SafeLoad(ptr) | (group_id << (group_id_bit_offs_new & 7))); full_slots_new[ihalf]++; } } @@ -495,7 +500,7 @@ Status SwissTable::grow_double() { for (int i = 0; i < (1 << log_blocks_); ++i) { // How many full slots in this block uint8_t* block_base = blocks_ + i * block_size_before; - uint64_t block = *reinterpret_cast(block_base); + uint64_t block = util::SafeLoadAs(block_base); int full_slots = static_cast(CountLeadingZeros(block & kHighBitOfEachByte) >> 3); for (int j = 0; j < full_slots; ++j) { @@ -508,21 +513,21 @@ Status SwissTable::grow_double() { } uint64_t group_id_bit_offs = j * num_group_id_bits_before; - uint64_t group_id = (*reinterpret_cast(block_base + 8 + - (group_id_bit_offs >> 3)) >> - (group_id_bit_offs & 7)) & - group_id_mask_before; + uint64_t group_id = + (util::SafeLoadAs(block_base + 8 + (group_id_bit_offs >> 3)) >> + (group_id_bit_offs & 7)) & + group_id_mask_before; uint8_t stamp_new = hash >> ((bits_hash_ - log_blocks_after - bits_stamp_)) & stamp_mask; uint8_t* block_base_new = blocks_new + block_id_new * block_size_after; - uint64_t block_new = *reinterpret_cast(block_base_new); + uint64_t block_new = util::SafeLoadAs(block_base_new); int full_slots_new = static_cast(CountLeadingZeros(block_new & kHighBitOfEachByte) >> 3); while (full_slots_new == 8) { block_id_new = (block_id_new + 1) & ((1 << log_blocks_after) - 1); block_base_new = blocks_new + block_id_new * block_size_after; - block_new = *reinterpret_cast(block_base_new); + block_new = util::SafeLoadAs(block_base_new); full_slots_new = static_cast(CountLeadingZeros(block_new & kHighBitOfEachByte) >> 3); } @@ -530,8 +535,10 @@ Status SwissTable::grow_double() { hashes_new[block_id_new * 8 + full_slots_new] = hash; block_base_new[7 - full_slots_new] = stamp_new; int group_id_bit_offs_new = full_slots_new * num_group_id_bits_after; - *reinterpret_cast(block_base_new + 8 + (group_id_bit_offs_new >> 3)) |= - (group_id << (group_id_bit_offs_new & 7)); + uint64_t* ptr = + reinterpret_cast(block_base_new + 8 + (group_id_bit_offs_new >> 3)); + util::SafeStore(ptr, + util::SafeLoad(ptr) | (group_id << (group_id_bit_offs_new & 7))); } } @@ -567,7 +574,7 @@ Status SwissTable::init(int64_t hardware_flags, MemoryPool* pool, // Initialize all status bytes to represent an empty slot. for (uint64_t i = 0; i < (static_cast(1) << log_blocks_); ++i) { - *reinterpret_cast(blocks_ + i * block_bytes) = kHighBitOfEachByte; + util::SafeStore(blocks_ + i * block_bytes, kHighBitOfEachByte); } uint64_t num_slots = 1ULL << (log_blocks_ + 3); diff --git a/cpp/src/arrow/compute/exec/util.cc b/cpp/src/arrow/compute/exec/util.cc index 5f1c0776c564d..88303348645b1 100644 --- a/cpp/src/arrow/compute/exec/util.cc +++ b/cpp/src/arrow/compute/exec/util.cc @@ -19,6 +19,7 @@ #include "arrow/util/bit_util.h" #include "arrow/util/bitmap_ops.h" +#include "arrow/util/ubsan.h" namespace arrow { @@ -66,7 +67,7 @@ void BitUtil::bits_to_indexes_internal(int64_t hardware_flags, const int num_bit #endif *num_indexes = 0; for (int i = 0; i < num_bits / unroll; ++i) { - uint64_t word = reinterpret_cast(bits)[i]; + uint64_t word = util::SafeLoad(&reinterpret_cast(bits)[i]); if (bit_to_search == 0) { word = ~word; } @@ -81,7 +82,8 @@ void BitUtil::bits_to_indexes_internal(int64_t hardware_flags, const int num_bit #endif // Optionally process the last partial word with masking out bits outside range if (tail) { - uint64_t word = reinterpret_cast(bits)[num_bits / unroll]; + uint64_t word = + util::SafeLoad(&reinterpret_cast(bits)[num_bits / unroll]); if (bit_to_search == 0) { word = ~word; } @@ -144,7 +146,7 @@ void BitUtil::bits_to_bytes_internal(const int num_bits, const uint8_t* bits, unpacked |= (bits_next & 1); unpacked &= 0x0101010101010101ULL; unpacked *= 255; - reinterpret_cast(bytes)[i] = unpacked; + util::SafeStore(&reinterpret_cast(bytes)[i], unpacked); } } @@ -153,7 +155,7 @@ void BitUtil::bytes_to_bits_internal(const int num_bits, const uint8_t* bytes, constexpr int unroll = 8; // Process 8 bits at a time for (int i = 0; i < (num_bits + unroll - 1) / unroll; ++i) { - uint64_t bytes_next = reinterpret_cast(bytes)[i]; + uint64_t bytes_next = util::SafeLoad(&reinterpret_cast(bytes)[i]); bytes_next &= 0x0101010101010101ULL; bytes_next |= (bytes_next >> 7); // Pairs of adjacent output bits in individual bytes bytes_next |= (bytes_next >> 14); // 4 adjacent output bits in individual bytes @@ -184,7 +186,7 @@ void BitUtil::bits_to_bytes(int64_t hardware_flags, const int num_bits, unpacked |= (bits_next & 1); unpacked &= 0x0101010101010101ULL; unpacked *= 255; - reinterpret_cast(bytes)[i] = unpacked; + util::SafeStore(&reinterpret_cast(bytes)[i], unpacked); } } @@ -201,7 +203,7 @@ void BitUtil::bytes_to_bits(int64_t hardware_flags, const int num_bits, // Process 8 bits at a time constexpr int unroll = 8; for (int i = num_processed / unroll; i < (num_bits + unroll - 1) / unroll; ++i) { - uint64_t bytes_next = reinterpret_cast(bytes)[i]; + uint64_t bytes_next = util::SafeLoad(&reinterpret_cast(bytes)[i]); bytes_next &= 0x0101010101010101ULL; bytes_next |= (bytes_next >> 7); // Pairs of adjacent output bits in individual bytes bytes_next |= (bytes_next >> 14); // 4 adjacent output bits in individual bytes @@ -220,7 +222,7 @@ bool BitUtil::are_all_bytes_zero(int64_t hardware_flags, const uint8_t* bytes, uint64_t result_or = 0; uint32_t i; for (i = 0; i < num_bytes / 8; ++i) { - uint64_t x = reinterpret_cast(bytes)[i]; + uint64_t x = util::SafeLoad(&reinterpret_cast(bytes)[i]); result_or |= x; } if (num_bytes % 8 > 0) {