Skip to content

Commit

Permalink
Resolve racecheck errors in ORC kernels (#9916)
Browse files Browse the repository at this point in the history
Running ORC Python tests with `compute-sanitizer --tool racecheck` results in a number of errors/warnings.
This PR resolves the errors originating in ORC kernels. Remaining errors come from `gpu_inflate`.

Adds a few missing block/warp syncs and performs minor cleanup in the affected code.

Causes a ~~4~~ 2% slowdown on average in ORC reader benchmarks. This is not negligible; will double-check whether the changes are required or are just resolving false positives in `racecheck`.
I ran the benchmarks many more times, and the average time difference is smaller than the variation between runs.

Authors:
  - Vukasin Milovanovic (https://github.com/vuule)

Approvers:
  - Elias Stehle (https://github.com/elstehle)
  - Devavret Makkar (https://github.com/devavret)

URL: #9916
  • Loading branch information
vuule authored Jan 7, 2022
1 parent 120aa62 commit de8c0b8
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 32 deletions.
17 changes: 7 additions & 10 deletions cpp/src/io/comp/gpuinflate.cu
Original file line number Diff line number Diff line change
Expand Up @@ -780,22 +780,19 @@ __device__ void process_symbols(inflate_state_s* s, int t)

do {
volatile uint32_t* b = &s->x.u.symqueue[batch * batch_size];
int batch_len, pos;
int32_t symt;
uint32_t lit_mask;

int batch_len = 0;
if (t == 0) {
while ((batch_len = s->x.batch_len[batch]) == 0) {}
} else {
batch_len = 0;
}
batch_len = shuffle(batch_len);
if (batch_len < 0) { break; }

symt = (t < batch_len) ? b[t] : 256;
lit_mask = ballot(symt >= 256);
pos = min((__ffs(lit_mask) - 1) & 0xff, 32);
auto const symt = (t < batch_len) ? b[t] : 256;
auto const lit_mask = ballot(symt >= 256);
auto pos = min((__ffs(lit_mask) - 1) & 0xff, 32);

if (t == 0) { s->x.batch_len[batch] = 0; }

if (t < pos && out + t < outend) { out[t] = symt; }
out += pos;
batch_len -= pos;
Expand Down Expand Up @@ -825,7 +822,7 @@ __device__ void process_symbols(inflate_state_s* s, int t)
}
}
batch = (batch + 1) & (batch_count - 1);
} while (1);
} while (true);

if (t == 0) { s->out = out; }
}
Expand Down
35 changes: 18 additions & 17 deletions cpp/src/io/orc/stripe_data.cu
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ inline __device__ int decode_base128_varint(volatile orc_bytestream_s* bs, int p
if (b > 0x7f) {
b = bytestream_readbyte(bs, pos++);
v = (v & 0x0fffffff) | (b << 28);
if (sizeof(T) > 4) {
if constexpr (sizeof(T) > 4) {
uint32_t lo = v;
uint64_t hi;
v = b >> 4;
Expand Down Expand Up @@ -650,13 +650,11 @@ static __device__ uint32_t Integer_RLEv2(orc_bytestream_s* bs,
int t,
bool has_buffered_values = false)
{
uint32_t numvals, numruns;
int r, tr;

if (t == 0) {
uint32_t maxpos = min(bs->len, bs->pos + (bytestream_buffer_size - 8u));
uint32_t lastpos = bs->pos;
numvals = numruns = 0;
auto numvals = 0;
auto numruns = 0;
// Find the length and start location of each run
while (numvals < maxvals) {
uint32_t pos = lastpos;
Expand Down Expand Up @@ -713,9 +711,9 @@ static __device__ uint32_t Integer_RLEv2(orc_bytestream_s* bs,
}
__syncthreads();
// Process the runs, 1 warp per run
numruns = rle->num_runs;
r = t >> 5;
tr = t & 0x1f;
auto const numruns = rle->num_runs;
auto const r = t >> 5;
auto const tr = t & 0x1f;
for (uint32_t run = r; run < numruns; run += num_warps) {
uint32_t base, pos, w, n;
int mode;
Expand All @@ -731,7 +729,7 @@ static __device__ uint32_t Integer_RLEv2(orc_bytestream_s* bs,
w = 8 + (byte0 & 0x38); // 8 to 64 bits
n = 3 + (byte0 & 7); // 3 to 10 values
bytestream_readbe(bs, pos * 8, w, baseval);
if (sizeof(T) <= 4) {
if constexpr (sizeof(T) <= 4) {
rle->baseval.u32[r] = baseval;
} else {
rle->baseval.u64[r] = baseval;
Expand All @@ -746,7 +744,7 @@ static __device__ uint32_t Integer_RLEv2(orc_bytestream_s* bs,
uint32_t byte3 = bytestream_readbyte(bs, pos++);
uint32_t bw = 1 + (byte2 >> 5); // base value width, 1 to 8 bytes
uint32_t pw = kRLEv2_W[byte2 & 0x1f]; // patch width, 1 to 64 bits
if (sizeof(T) <= 4) {
if constexpr (sizeof(T) <= 4) {
uint32_t baseval, mask;
bytestream_readbe(bs, pos * 8, bw * 8, baseval);
mask = (1 << (bw * 8 - 1)) - 1;
Expand All @@ -766,7 +764,7 @@ static __device__ uint32_t Integer_RLEv2(orc_bytestream_s* bs,
int64_t delta;
// Delta
pos = decode_varint(bs, pos, baseval);
if (sizeof(T) <= 4) {
if constexpr (sizeof(T) <= 4) {
rle->baseval.u32[r] = baseval;
} else {
rle->baseval.u64[r] = baseval;
Expand All @@ -782,8 +780,9 @@ static __device__ uint32_t Integer_RLEv2(orc_bytestream_s* bs,
pos = shuffle(pos);
n = shuffle(n);
w = shuffle(w);
__syncwarp(); // Not required, included to fix the racecheck warning
for (uint32_t i = tr; i < n; i += 32) {
if (sizeof(T) <= 4) {
if constexpr (sizeof(T) <= 4) {
if (mode == 0) {
vals[base + i] = rle->baseval.u32[r];
} else if (mode == 1) {
Expand Down Expand Up @@ -860,14 +859,15 @@ static __device__ uint32_t Integer_RLEv2(orc_bytestream_s* bs,
if (j & i) vals[base + j] += vals[base + ((j & ~i) | (i - 1))];
}
}
if (sizeof(T) <= 4)
if constexpr (sizeof(T) <= 4)
baseval = rle->baseval.u32[r];
else
baseval = rle->baseval.u64[r];
for (uint32_t j = tr; j < n; j += 32) {
vals[base + j] += baseval;
}
}
__syncwarp();
}
__syncthreads();
return rle->num_vals;
Expand Down Expand Up @@ -1679,11 +1679,12 @@ __global__ void __launch_bounds__(block_size)
}
}
}
if (t == 0 && numvals + vals_skipped > 0 && numvals < s->top.data.max_vals) {
if (s->chunk.type_kind == TIMESTAMP) {
s->top.data.buffered_count = s->top.data.max_vals - numvals;
if (t == 0 && numvals + vals_skipped > 0) {
auto const max_vals = s->top.data.max_vals;
if (max_vals > numvals) {
if (s->chunk.type_kind == TIMESTAMP) { s->top.data.buffered_count = max_vals - numvals; }
s->top.data.max_vals = numvals;
}
s->top.data.max_vals = numvals;
}
__syncthreads();
// Use the valid bits to compute non-null row positions until we get a full batch of values to
Expand Down
7 changes: 2 additions & 5 deletions cpp/src/io/orc/stripe_enc.cu
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ static __device__ uint32_t IntegerRLE(
uint32_t mode1_w, mode2_w;
typename std::make_unsigned<T>::type vrange_mode1, vrange_mode2;
block_vmin = static_cast<uint64_t>(vmin);
if (sizeof(T) > 4) {
if constexpr (sizeof(T) > 4) {
vrange_mode1 = (is_signed) ? max(zigzag(vmin), zigzag(vmax)) : vmax;
vrange_mode2 = vmax - vmin;
mode1_w = 8 - min(CountLeadingBytes64(vrange_mode1), 7);
Expand Down Expand Up @@ -705,10 +705,7 @@ static __device__ void encode_null_mask(orcenc_state_s* s,
}

// reset shared state
if (t == 0) {
s->nnz = 0;
s->numvals = 0;
}
if (t == 0) { s->nnz = 0; }
}

/**
Expand Down

0 comments on commit de8c0b8

Please sign in to comment.