Skip to content

Commit

Permalink
ARROW-13003: [C++] Fix key map unaligned access
Browse files Browse the repository at this point in the history
Closes apache#10489 from cyb70289/13003-unaligned-access

Authored-by: Yibo Cai <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
  • Loading branch information
cyb70289 authored and pitrou committed Jun 15, 2021
1 parent 4b3f6c3 commit 889291b
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 49 deletions.
21 changes: 11 additions & 10 deletions cpp/src/arrow/compute/exec/key_compare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <cstdint>

#include "arrow/compute/exec/util.h"
#include "arrow/util/ubsan.h"

namespace arrow {
namespace compute {
Expand Down Expand Up @@ -170,19 +171,19 @@ void KeyCompare::CompareFixedLengthImp(uint32_t num_rows_already_processed,
//
if (num_64bit_words == 0) {
for (; istripe < num_loops_less_one; ++istripe) {
uint64_t key_left = key_left_ptr[istripe];
uint64_t key_right = key_right_ptr[istripe];
uint64_t key_left = util::SafeLoad(&key_left_ptr[istripe]);
uint64_t key_right = util::SafeLoad(&key_right_ptr[istripe]);
result_or |= (key_left ^ key_right);
}
} else if (num_64bit_words == 2) {
uint64_t key_left = key_left_ptr[istripe];
uint64_t key_right = key_right_ptr[istripe];
uint64_t key_left = util::SafeLoad(&key_left_ptr[istripe]);
uint64_t key_right = util::SafeLoad(&key_right_ptr[istripe]);
result_or |= (key_left ^ key_right);
++istripe;
}

uint64_t key_left = key_left_ptr[istripe];
uint64_t key_right = key_right_ptr[istripe];
uint64_t key_left = util::SafeLoad(&key_left_ptr[istripe]);
uint64_t key_right = util::SafeLoad(&key_right_ptr[istripe]);
result_or |= (tail_mask & (key_left ^ key_right));

int result = (result_or == 0 ? 0xff : 0);
Expand Down Expand Up @@ -246,16 +247,16 @@ void KeyCompare::CompareVaryingLengthImp(
int32_t istripe;
// length can be zero
for (istripe = 0; istripe < (static_cast<int32_t>(length) + 7) / 8 - 1; ++istripe) {
uint64_t key_left = key_left_ptr[istripe];
uint64_t key_right = key_right_ptr[istripe];
uint64_t key_left = util::SafeLoad(&key_left_ptr[istripe]);
uint64_t key_right = util::SafeLoad(&key_right_ptr[istripe]);
result_or |= (key_left ^ key_right);
}

uint32_t length_remaining = length - static_cast<uint32_t>(istripe) * 8;
uint64_t tail_mask = tail_masks[length_remaining];

uint64_t key_left = key_left_ptr[istripe];
uint64_t key_right = key_right_ptr[istripe];
uint64_t key_left = util::SafeLoad(&key_left_ptr[istripe]);
uint64_t key_right = util::SafeLoad(&key_right_ptr[istripe]);
result_or |= (tail_mask & (key_left ^ key_right));

int result = (result_or == 0 ? 0xff : 0);
Expand Down
71 changes: 39 additions & 32 deletions cpp/src/arrow/compute/exec/key_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include "arrow/util/bit_util.h"
#include "arrow/util/bitmap_ops.h"
#include "arrow/util/ubsan.h"

namespace arrow {

Expand Down Expand Up @@ -153,7 +154,7 @@ void SwissTable::lookup_1(const uint16_t* selection, const int num_keys,
for (int i = 0; i < num_keys; ++i) {
int id;
if (use_selection) {
id = selection[i];
id = util::SafeLoad(&selection[i]);
} else {
id = i;
}
Expand All @@ -168,7 +169,7 @@ void SwissTable::lookup_1(const uint16_t* selection, const int num_keys,
uint32_t num_block_bytes = num_groupid_bits + 8;
const uint8_t* blockbase = reinterpret_cast<const uint8_t*>(blocks_) +
static_cast<uint64_t>(iblock) * num_block_bytes;
uint64_t block = *reinterpret_cast<const uint64_t*>(blockbase);
uint64_t block = util::SafeLoadAs<uint64_t>(blockbase);

// Call helper functions to obtain the output triplet:
// - match (of a stamp) found flag
Expand All @@ -182,8 +183,8 @@ void SwissTable::lookup_1(const uint16_t* selection, const int num_keys,
uint64_t islot = next_slot_to_visit(iblock, islot_in_block, match_found);

out_match_bitvector[id / 8] |= match_found << (id & 7);
out_groupids[id] = static_cast<uint32_t>(groupid);
out_slot_ids[id] = static_cast<uint32_t>(islot);
util::SafeStore(&out_groupids[id], static_cast<uint32_t>(groupid));
util::SafeStore(&out_slot_ids[id], static_cast<uint32_t>(islot));
}
}

Expand Down Expand Up @@ -239,7 +240,7 @@ Status SwissTable::lookup_2(const uint32_t* hashes, uint32_t* inout_num_selected
uint16_t* ids[3]{inout_selection, ids_for_comparison_buf.mutable_data(),
ids_inserted_buf.mutable_data()};
auto push_id = [&num_ids, &ids](int category, int id) {
ids[category][num_ids[category]++] = static_cast<uint16_t>(id);
util::SafeStore(&ids[category][num_ids[category]++], static_cast<uint16_t>(id));
};

uint64_t num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
Expand All @@ -256,9 +257,9 @@ Status SwissTable::lookup_2(const uint32_t* hashes, uint32_t* inout_num_selected
num_inserted_ + num_ids[category_inserted] < num_groups_limit;
++num_processed) {
// row id in original batch
int id = inout_selection[num_processed];
int id = util::SafeLoad(&inout_selection[num_processed]);

uint64_t slot_id = wrap_global_slot_id(inout_next_slot_ids[id]);
uint64_t slot_id = wrap_global_slot_id(util::SafeLoad(&inout_next_slot_ids[id]));
uint64_t block_id = slot_id >> 3;
uint32_t hash = hashes[id];
uint8_t* blockbase = blocks_ + num_block_bytes * block_id;
Expand All @@ -278,11 +279,13 @@ Status SwissTable::lookup_2(const uint32_t* hashes, uint32_t* inout_num_selected
// In that case we can insert group id value using aligned 64-bit word access.
ARROW_DCHECK(num_groupid_bits == 8 || num_groupid_bits == 16 ||
num_groupid_bits == 32 || num_groupid_bits == 64);
reinterpret_cast<uint64_t*>(blockbase + 8)[groupid_bit_offset >> 6] |=
(static_cast<uint64_t>(group_id) << (groupid_bit_offset & 63));
uint64_t* ptr =
&reinterpret_cast<uint64_t*>(blockbase + 8)[groupid_bit_offset >> 6];
util::SafeStore(ptr, util::SafeLoad(ptr) | (static_cast<uint64_t>(group_id)
<< (groupid_bit_offset & 63)));

hashes_[slot_id] = hash;
out_group_ids[id] = group_id;
util::SafeStore(&out_group_ids[id], group_id);
push_id(category_inserted, id);
} else {
// We search for a slot with a matching stamp within a single block.
Expand All @@ -298,8 +301,8 @@ Status SwissTable::lookup_2(const uint32_t* hashes, uint32_t* inout_num_selected
ARROW_DCHECK(new_groupid < num_inserted_ + num_ids[category_inserted]);
new_slot =
static_cast<int>(next_slot_to_visit(block_id, new_slot, new_match_found));
inout_next_slot_ids[id] = new_slot;
out_group_ids[id] = new_groupid;
util::SafeStore(&inout_next_slot_ids[id], new_slot);
util::SafeStore(&out_group_ids[id], new_groupid);
push_id(new_match_found, id);
}
}
Expand Down Expand Up @@ -410,7 +413,8 @@ Status SwissTable::map(const int num_keys, const uint32_t* hashes,
//
for (uint32_t i = 0; i < num_ids; ++i) {
// First slot in the new starting block
slot_ids[ids[i]] = (hashes[ids[i]] >> (bits_hash_ - log_blocks_)) * 8;
const int16_t id = util::SafeLoad(&ids[i]);
util::SafeStore(&slot_ids[id], (hashes[id] >> (bits_hash_ - log_blocks_)) * 8);
}
}
} while (num_ids > 0);
Expand Down Expand Up @@ -457,9 +461,8 @@ Status SwissTable::grow_double() {
static_cast<int>(CountLeadingZeros(block & kHighBitOfEachByte) >> 3);
int full_slots_new[2];
full_slots_new[0] = full_slots_new[1] = 0;
*reinterpret_cast<uint64_t*>(double_block_base_new) = kHighBitOfEachByte;
*reinterpret_cast<uint64_t*>(double_block_base_new + block_size_after) =
kHighBitOfEachByte;
util::SafeStore(double_block_base_new, kHighBitOfEachByte);
util::SafeStore(double_block_base_new + block_size_after, kHighBitOfEachByte);

for (int j = 0; j < full_slots; ++j) {
uint64_t slot_id = i * 8 + j;
Expand All @@ -474,18 +477,20 @@ Status SwissTable::grow_double() {
uint8_t stamp_new =
hash >> ((bits_hash_ - log_blocks_after - bits_stamp_)) & stamp_mask;
uint64_t group_id_bit_offs = j * num_group_id_bits_before;
uint64_t group_id = (*reinterpret_cast<const uint64_t*>(block_base + 8 +
(group_id_bit_offs >> 3)) >>
(group_id_bit_offs & 7)) &
group_id_mask_before;
uint64_t group_id =
(util::SafeLoadAs<uint64_t>(block_base + 8 + (group_id_bit_offs >> 3)) >>
(group_id_bit_offs & 7)) &
group_id_mask_before;

uint64_t slot_id_new = i * 16 + ihalf * 8 + full_slots_new[ihalf];
hashes_new[slot_id_new] = hash;
uint8_t* block_base_new = double_block_base_new + ihalf * block_size_after;
block_base_new[7 - full_slots_new[ihalf]] = stamp_new;
int group_id_bit_offs_new = full_slots_new[ihalf] * num_group_id_bits_after;
*reinterpret_cast<uint64_t*>(block_base_new + 8 + (group_id_bit_offs_new >> 3)) |=
(group_id << (group_id_bit_offs_new & 7));
uint64_t* ptr =
reinterpret_cast<uint64_t*>(block_base_new + 8 + (group_id_bit_offs_new >> 3));
util::SafeStore(ptr,
util::SafeLoad(ptr) | (group_id << (group_id_bit_offs_new & 7)));
full_slots_new[ihalf]++;
}
}
Expand All @@ -495,7 +500,7 @@ Status SwissTable::grow_double() {
for (int i = 0; i < (1 << log_blocks_); ++i) {
// How many full slots in this block
uint8_t* block_base = blocks_ + i * block_size_before;
uint64_t block = *reinterpret_cast<const uint64_t*>(block_base);
uint64_t block = util::SafeLoadAs<uint64_t>(block_base);
int full_slots = static_cast<int>(CountLeadingZeros(block & kHighBitOfEachByte) >> 3);

for (int j = 0; j < full_slots; ++j) {
Expand All @@ -508,30 +513,32 @@ Status SwissTable::grow_double() {
}

uint64_t group_id_bit_offs = j * num_group_id_bits_before;
uint64_t group_id = (*reinterpret_cast<const uint64_t*>(block_base + 8 +
(group_id_bit_offs >> 3)) >>
(group_id_bit_offs & 7)) &
group_id_mask_before;
uint64_t group_id =
(util::SafeLoadAs<uint64_t>(block_base + 8 + (group_id_bit_offs >> 3)) >>
(group_id_bit_offs & 7)) &
group_id_mask_before;
uint8_t stamp_new =
hash >> ((bits_hash_ - log_blocks_after - bits_stamp_)) & stamp_mask;

uint8_t* block_base_new = blocks_new + block_id_new * block_size_after;
uint64_t block_new = *reinterpret_cast<const uint64_t*>(block_base_new);
uint64_t block_new = util::SafeLoadAs<uint64_t>(block_base_new);
int full_slots_new =
static_cast<int>(CountLeadingZeros(block_new & kHighBitOfEachByte) >> 3);
while (full_slots_new == 8) {
block_id_new = (block_id_new + 1) & ((1 << log_blocks_after) - 1);
block_base_new = blocks_new + block_id_new * block_size_after;
block_new = *reinterpret_cast<const uint64_t*>(block_base_new);
block_new = util::SafeLoadAs<uint64_t>(block_base_new);
full_slots_new =
static_cast<int>(CountLeadingZeros(block_new & kHighBitOfEachByte) >> 3);
}

hashes_new[block_id_new * 8 + full_slots_new] = hash;
block_base_new[7 - full_slots_new] = stamp_new;
int group_id_bit_offs_new = full_slots_new * num_group_id_bits_after;
*reinterpret_cast<uint64_t*>(block_base_new + 8 + (group_id_bit_offs_new >> 3)) |=
(group_id << (group_id_bit_offs_new & 7));
uint64_t* ptr =
reinterpret_cast<uint64_t*>(block_base_new + 8 + (group_id_bit_offs_new >> 3));
util::SafeStore(ptr,
util::SafeLoad(ptr) | (group_id << (group_id_bit_offs_new & 7)));
}
}

Expand Down Expand Up @@ -567,7 +574,7 @@ Status SwissTable::init(int64_t hardware_flags, MemoryPool* pool,

// Initialize all status bytes to represent an empty slot.
for (uint64_t i = 0; i < (static_cast<uint64_t>(1) << log_blocks_); ++i) {
*reinterpret_cast<uint64_t*>(blocks_ + i * block_bytes) = kHighBitOfEachByte;
util::SafeStore(blocks_ + i * block_bytes, kHighBitOfEachByte);
}

uint64_t num_slots = 1ULL << (log_blocks_ + 3);
Expand Down
16 changes: 9 additions & 7 deletions cpp/src/arrow/compute/exec/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include "arrow/util/bit_util.h"
#include "arrow/util/bitmap_ops.h"
#include "arrow/util/ubsan.h"

namespace arrow {

Expand Down Expand Up @@ -66,7 +67,7 @@ void BitUtil::bits_to_indexes_internal(int64_t hardware_flags, const int num_bit
#endif
*num_indexes = 0;
for (int i = 0; i < num_bits / unroll; ++i) {
uint64_t word = reinterpret_cast<const uint64_t*>(bits)[i];
uint64_t word = util::SafeLoad(&reinterpret_cast<const uint64_t*>(bits)[i]);
if (bit_to_search == 0) {
word = ~word;
}
Expand All @@ -81,7 +82,8 @@ void BitUtil::bits_to_indexes_internal(int64_t hardware_flags, const int num_bit
#endif
// Optionally process the last partial word with masking out bits outside range
if (tail) {
uint64_t word = reinterpret_cast<const uint64_t*>(bits)[num_bits / unroll];
uint64_t word =
util::SafeLoad(&reinterpret_cast<const uint64_t*>(bits)[num_bits / unroll]);
if (bit_to_search == 0) {
word = ~word;
}
Expand Down Expand Up @@ -144,7 +146,7 @@ void BitUtil::bits_to_bytes_internal(const int num_bits, const uint8_t* bits,
unpacked |= (bits_next & 1);
unpacked &= 0x0101010101010101ULL;
unpacked *= 255;
reinterpret_cast<uint64_t*>(bytes)[i] = unpacked;
util::SafeStore(&reinterpret_cast<uint64_t*>(bytes)[i], unpacked);
}
}

Expand All @@ -153,7 +155,7 @@ void BitUtil::bytes_to_bits_internal(const int num_bits, const uint8_t* bytes,
constexpr int unroll = 8;
// Process 8 bits at a time
for (int i = 0; i < (num_bits + unroll - 1) / unroll; ++i) {
uint64_t bytes_next = reinterpret_cast<const uint64_t*>(bytes)[i];
uint64_t bytes_next = util::SafeLoad(&reinterpret_cast<const uint64_t*>(bytes)[i]);
bytes_next &= 0x0101010101010101ULL;
bytes_next |= (bytes_next >> 7); // Pairs of adjacent output bits in individual bytes
bytes_next |= (bytes_next >> 14); // 4 adjacent output bits in individual bytes
Expand Down Expand Up @@ -184,7 +186,7 @@ void BitUtil::bits_to_bytes(int64_t hardware_flags, const int num_bits,
unpacked |= (bits_next & 1);
unpacked &= 0x0101010101010101ULL;
unpacked *= 255;
reinterpret_cast<uint64_t*>(bytes)[i] = unpacked;
util::SafeStore(&reinterpret_cast<uint64_t*>(bytes)[i], unpacked);
}
}

Expand All @@ -201,7 +203,7 @@ void BitUtil::bytes_to_bits(int64_t hardware_flags, const int num_bits,
// Process 8 bits at a time
constexpr int unroll = 8;
for (int i = num_processed / unroll; i < (num_bits + unroll - 1) / unroll; ++i) {
uint64_t bytes_next = reinterpret_cast<const uint64_t*>(bytes)[i];
uint64_t bytes_next = util::SafeLoad(&reinterpret_cast<const uint64_t*>(bytes)[i]);
bytes_next &= 0x0101010101010101ULL;
bytes_next |= (bytes_next >> 7); // Pairs of adjacent output bits in individual bytes
bytes_next |= (bytes_next >> 14); // 4 adjacent output bits in individual bytes
Expand All @@ -220,7 +222,7 @@ bool BitUtil::are_all_bytes_zero(int64_t hardware_flags, const uint8_t* bytes,
uint64_t result_or = 0;
uint32_t i;
for (i = 0; i < num_bytes / 8; ++i) {
uint64_t x = reinterpret_cast<const uint64_t*>(bytes)[i];
uint64_t x = util::SafeLoad(&reinterpret_cast<const uint64_t*>(bytes)[i]);
result_or |= x;
}
if (num_bytes % 8 > 0) {
Expand Down

0 comments on commit 889291b

Please sign in to comment.