Skip to content

Commit

Permalink
[SYCL] Fix sub-group mask for smaller SG sizes (#4916)
Browse files Browse the repository at this point in the history
Fix access to the sub-group mask when the sub-group size is less than 32. Make sure that false is returned for positions greater than or equal to the sub-group size.

Update the test to check this case.
  • Loading branch information
vladimirlaz authored Nov 18, 2021
1 parent 2ebde5f commit c855fd1
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 115 deletions.
33 changes: 20 additions & 13 deletions sycl/include/sycl/ext/oneapi/sub_group_mask.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ struct sub_group_mask {
}

reference(sub_group_mask &gmask, size_t pos) : Ref(gmask.Bits) {
RefBit = 1 << pos % word_size;
RefBit = (pos < gmask.bits_num) ? (1UL << pos) : 0;
}

private:
Expand All @@ -61,16 +61,17 @@ struct sub_group_mask {
};

bool operator[](id<1> id) const {
return Bits & (1 << (id.get(0) % word_size));
return (Bits & ((id.get(0) < bits_num) ? (1UL << id.get(0)) : 0));
}

// Returns a writable proxy object bound to the bit at position id.
reference operator[](id<1> id) {
  return reference(*this, id.get(0));
}
// Reads the bit at position id; forwards to the const subscript operator.
bool test(id<1> id) const {
  return (*this)[id];
}
bool all() const { return !~Bits; }
bool any() const { return Bits; }
bool none() const { return !Bits; }
bool all() const { return count() == bits_num; }
bool any() const { return count() != 0; }
bool none() const { return count() == 0; }
uint32_t count() const {
unsigned int count = 0;
auto word = Bits;
auto word = (Bits & valuable_bits(bits_num));
while (word) {
word &= (word - 1);
count++;
Expand Down Expand Up @@ -99,9 +100,9 @@ struct sub_group_mask {
insert_data <<= pos.get(0);
uint32_t mask = 0;
if (pos.get(0) + insert_size < size())
mask |= (0xffffffff << (pos.get(0) + insert_size));
mask |= (valuable_bits(bits_num) << (pos.get(0) + insert_size));
if (pos.get(0) < size() && pos.get(0))
mask |= (0xffffffff >> (size() - pos.get(0)));
mask |= (valuable_bits(max_bits) >> (max_bits - pos.get(0)));
Bits &= mask;
Bits += insert_data;
}
Expand All @@ -125,14 +126,15 @@ struct sub_group_mask {
template <typename Type,
typename = sycl::detail::enable_if_t<std::is_integral<Type>::value>>
void extract_bits(Type &bits, id<1> pos = 0) const {
uint32_t Res = Bits;
auto Res = Bits;
Res &= valuable_bits(bits_num);
if (pos.get(0) < size()) {
if (pos.get(0) > 0) {
Res >>= pos.get(0);
}

if (sizeof(Type) * CHAR_BIT < size()) {
Res &= (0xffffffff >> (size() - (sizeof(Type) * CHAR_BIT)));
if (sizeof(Type) * CHAR_BIT < max_bits) {
Res &= valuable_bits(sizeof(Type) * CHAR_BIT);
}
bits = (Type)Res;
} else {
Expand All @@ -154,13 +156,13 @@ struct sub_group_mask {
}
}

void set() { Bits = uint32_t{0xffffffff}; }
void set() { Bits = valuable_bits(bits_num); }
// Writes value (true by default) to the bit at position id through the
// reference proxy returned by the non-const subscript operator.
void set(id<1> id, bool value = true) {
  (*this)[id] = value;
}
// Clears every bit of the mask.
void reset() {
  Bits = static_cast<uint32_t>(0);
}
// Clears the single bit at position id through the reference proxy.
void reset(id<1> id) {
  (*this)[id] = false;
}
// Clears the bit at the position reported by find_low().
void reset_low() {
  const auto lowest = find_low();
  reset(lowest);
}
// Clears the bit at the position reported by find_high().
void reset_high() {
  const auto highest = find_high();
  reset(highest);
}
void flip() { Bits = ~Bits; }
void flip() { Bits = (~Bits & valuable_bits(bits_num)); }
// Inverts the single bit at position id through the reference proxy.
void flip(id<1> id) {
  (*this)[id].flip();
}

// Two masks compare equal when their underlying bit words match;
// bits_num is not part of the comparison.
bool operator==(const sub_group_mask &rhs) const {
  return rhs.Bits == Bits;
}
Expand All @@ -177,11 +179,13 @@ struct sub_group_mask {

// XORs rhs into this mask in place and returns *this. The result is
// masked with valuable_bits(bits_num) so only the low bits_num bits survive.
sub_group_mask &operator^=(const sub_group_mask &rhs) {
  const uint32_t combined = Bits ^ rhs.Bits;
  Bits = combined & valuable_bits(bits_num);
  return *this;
}

// Shifts the mask left by pos positions in place and returns *this.
// Bits moved above the low bits_num positions are discarded by masking
// with valuable_bits(bits_num).
sub_group_mask &operator<<=(size_t pos) {
  // Shifting a 32-bit word by its own width or more is undefined behaviour
  // in C++. Every stored bit would be shifted out anyway, so the result is
  // defined here as an empty mask.
  if (pos >= sizeof(Bits) * CHAR_BIT) {
    Bits = 0;
    return *this;
  }
  Bits <<= pos;
  Bits &= valuable_bits(bits_num);
  return *this;
}

Expand Down Expand Up @@ -239,6 +243,9 @@ struct sub_group_mask {
// Builds a mask from the raw bit word rhs with bn meaningful bits.
// bn is expected not to exceed max_bits; this is checked by assert only,
// and rhs is stored as given.
sub_group_mask(uint32_t rhs, size_t bn) : Bits{rhs}, bits_num{bn} {
  assert(bn <= max_bits);
}
// Returns a 32-bit word with the low bn bits set (positions [0, bn)).
// The shift is done on a 64-bit value so that bn == 32 yields 0xffffffff.
// A shift count of 64 or more would itself be undefined, so any bn that
// covers the whole 32-bit word is answered directly with an all-ones mask.
inline uint32_t valuable_bits(size_t bn) const {
  if (bn >= sizeof(uint32_t) * CHAR_BIT)
    return ~uint32_t{0};
  return static_cast<uint32_t>((1ULL << bn) - 1ULL);
}
uint32_t Bits;
// Number of valuable bits
size_t bits_num;
Expand Down
218 changes: 116 additions & 102 deletions sycl/test/extensions/sub_group_mask.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %clangxx -g -O0 -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
// RUN: %t.out

//==-------- sub_group_mask.cpp - SYCL sub-group mask test -----------------==//
Expand All @@ -13,110 +13,124 @@
#include <iostream>

int main() {
auto g = sycl::detail::Builder::createSubGroupMask<
sycl::ext::oneapi::sub_group_mask>(0, 32);
assert(g.none() && !g.any() && !g.all());
assert(g[10] == false); // reference::operator[](id) const;
g[10] = true; // reference::operator=(bool);
assert(g[10] == true);
g[11] = g[10]; // reference::operator=(reference) reference::operator[](id);
assert(g[10].flip() == false); // reference::flip()
assert(~g[10] == true); // refernce::operator~()
assert(g[10] == false);
assert(g[11] == true);
assert(g.test(10) == false && g.test(11) == true);
g.set(30, 1);
g.set(11, 0);
g.set(23, 1);
assert(!g.none() && g.any() && !g.all());
for (size_t sgsize = 32; sgsize > 4; sgsize /= 2) {
std::cout << "Running test for sub-group size = " << sgsize << std::endl;
auto g = sycl::detail::Builder::createSubGroupMask<
sycl::ext::oneapi::sub_group_mask>(0, sgsize);
assert(g.none() && !g.any() && !g.all());
assert(g[5] == false); // reference::operator[](id) const;
g[5] = true; // reference::operator=(bool);
assert(g[5] == true);
g[6] = g[5]; // reference::operator=(reference) reference::operator[](id);
assert(g[5].flip() == false); // reference::flip()
assert(~g[5 % sgsize] == true); // refernce::operator~()
assert(g[5 % sgsize] == false);
assert(g[6 % sgsize] == true);
assert(g.test(5 % sgsize) == false && g.test(6 % sgsize) == true);
g.set(3 % sgsize, 1);
g.set(6 % sgsize, 0);
g.set(2 % sgsize, 1);
assert(!g.none() && g.any() && !g.all());

assert(g.count() == 2);
assert(g.find_low() == 23);
assert(g.find_high() == 30);
assert(g.size() == 32);
assert(g.count() == 2);
assert(g.find_low() == 2 % sgsize);
assert(g.find_high() == 3 % sgsize);
assert(g.size() == sgsize);

g.reset();
assert(g.none() && !g.any() && !g.all());
assert(g.find_low() == g.size() && g.find_high() == g.size());
g.set();
assert(!g.none() && g.any() && g.all());
assert(g.find_low() == 0 && g.find_high() == 31);
g.flip();
assert(g.none() && !g.any() && !g.all());
g.reset();
assert(g.none() && !g.any() && !g.all());
assert(g.find_low() == g.size() && g.find_high() == g.size());
g.set();
assert(!g.none() && g.any() && g.all());
assert(g.find_low() == 0 && g.find_high() == 31 % sgsize);
g.flip();
assert(g.none() && !g.any() && !g.all());

g.flip(13);
g.flip(23);
g.flip(29);
auto b = g;
assert(b == g && !(b != g));
g.flip(31);
assert(g.find_high() == 31);
assert(b.find_high() == 29);
assert(b != g && !(b == g));
b.flip(31);
assert(b == g && !(b != g));
b = g >> 1;
assert(b[12] && b[22] && b[28] && b[30]);
b <<= 1;
assert(b == g);
g ^= ~b;
assert(!g.none() && g.any() && g.all());
assert((g | ~g).all());
assert((g & ~g).none());
assert((g ^ ~g).all());
b.reset_low();
b.reset_high();
assert(!b[13] && b[23] && b[29] && !b[31]);
b.insert_bits(0x01020408);
assert(b[24] && b[17] && b[10] && b[3]);
b <<= 13;
assert(!b[24] && !b[17] && !b[10] && !b[3] && b[30] && b[23] && b[16]);
b.insert_bits((char)0b01010101, 18);
assert(b[18] && b[20] && b[22] && b[24] && b[30] && !b[23] && b[16]);
b[3] = true;
b.insert_bits(sycl::marray<char, 8>{1, 2, 4, 8, 16, 32, 64, 128}, 5);
assert(!b[18] && !b[20] && !b[22] && !b[24] && !b[30] && !b[16] && b[3] &&
b[5] && b[14] && b[23]);
char r, rbc;
const auto b_const{b};
b.extract_bits(r);
b_const.extract_bits(rbc);
assert(r == 0b00101000);
assert(rbc == 0b00101000);
long r2 = -1, r2bc = -1;
b.extract_bits(r2, 16);
b_const.extract_bits(r2bc, 16);
assert(r2 == 128);
assert(r2bc == 128);
g.flip(2);
g.flip(3);
g.flip(7);
auto b = g;
assert(b == g && !(b != g));
g.flip(7);
assert(g.find_high() == 3 % sgsize);
assert(b.find_high() == 7 % sgsize);
assert(b != g && !(b == g));
g.flip(7);
assert(b == g && !(b != g));
b = g >> 1;
assert(b[1] && b[2] && b[6]);
b <<= 1;
assert(b == g);
g ^= ~b;
assert(!g.none() && g.any() && g.all());
assert((g | ~g).all());
assert((g & ~g).none());
assert((g ^ ~g).all());
b.reset_low();
b.reset_high();
assert(!b[2] && b[3] && !b[7]);
b.insert_bits(0x01020408);
assert(((b[24] && b[17]) || sgsize < 32) && (b[10] || sgsize < 16) && b[3]);
b <<= 10;
assert(((!b[24] && !b[17] && b[27] && b[20]) || sgsize < 32) &&
((!b[10] && b[13]) || sgsize < 16) && !b[3]);
b.insert_bits((char)0b01010101, 6);
assert(b[6] && ((b[8] && b[10] && b[12] && !b[13]) || sgsize < 16));
b[3] = true;
b.insert_bits(sycl::marray<char, 8>{1, 2, 4, 8, 16, 32, 64, 128}, 5);
assert(
((!b[18] && !b[20] && !b[22] && !b[24] && !b[30] && !b[16] && b[23]) ||
sgsize < 32) &&
b[3] && b[5] && (b[14] || sgsize < 16));
b.flip(14);
b.flip(23);
char r, rbc;
const auto b_const{b};
b.extract_bits(r);
b_const.extract_bits(rbc);
assert(r == 0b00101000);
assert(rbc == 0b00101000);
long r2 = -1, r2bc = -1;
b.extract_bits(r2, 3);
b_const.extract_bits(r2bc, 3);
assert(r2 == 5);
assert(r2bc == 5);

b[31] = true;
const auto b_const2{b};
sycl::marray<char, 6> r3{-1}, r3bc{-1};
b.extract_bits(r3, 14);
b_const2.extract_bits(r3bc, 14);
assert(r3[0] == 1 && r3[1] == 2 && r3[2] == 2 && !r3[3] && !r3[4] && !r3[5]);
assert(r3bc[0] == 1 && r3bc[1] == 2 && r3bc[2] == 2 && !r3bc[3] && !r3bc[4] &&
!r3bc[5]);
int ibits = 0b1010101010101010101010101010101;
b.insert_bits(ibits);
for (size_t i = 0; i < 32; i++) {
assert(b[i] != (bool)(i % 2));
b.insert_bits((uint32_t)0x08040201);
const auto b_const2{b};
sycl::marray<char, 6> r3{-1}, r3bc{-1};
b.extract_bits(r3);
b_const2.extract_bits(r3bc);
assert(r3[0] == 1 && r3[1] == (sgsize > 8 ? 2 : 0) &&
r3[2] == (sgsize > 16 ? 4 : 0) && r3[3] == (sgsize > 16 ? 8 : 0) &&
!r3[4] && !r3[5]);
assert(r3bc[0] == 1 && r3bc[1] == (sgsize > 8 ? 2 : 0) &&
r3bc[2] == (sgsize > 16 ? 4 : 0) &&
r3bc[3] == (sgsize > 16 ? 8 : 0) && !r3bc[4] && !r3bc[5]);
int ibits = 0b1010101010101010101010101010101;
b.insert_bits(ibits);
for (size_t i = 0; i < sgsize; i++) {
assert(b[i] != (bool)(i % 2));
}
short sbits = 0b0111011101110111;
b.insert_bits(sbits, 7);
b.extract_bits(ibits);
assert(ibits ==
(0b1010101001110111011101111010101 & ((1ULL << sgsize) - 1ULL)));
sbits = 0b1100001111000011;
b.insert_bits(sbits, 23);
b.extract_bits(ibits);
if (sgsize >= 32) {
int64_t lbits = -1;
b.extract_bits(lbits, 33);
assert(lbits == 0);
lbits = -1;
b.extract_bits(lbits, 5);
assert(lbits ==
(0b111000011011101110111011110 & ((1ULL << sgsize) - 1ULL)));
lbits = -1;
b.insert_bits(lbits);
assert(b.all());
}
}
short sbits = 0b0111011101110111;
b.insert_bits(sbits, 7);
b.extract_bits(ibits);
assert(ibits == 0b1010101001110111011101111010101);
sbits = 0b1100001111000011;
b.insert_bits(sbits, 23);
b.extract_bits(ibits);
assert(ibits == 0b11100001101110111011101111010101);
int64_t lbits = -1;
b.extract_bits(lbits, 33);
assert(lbits == 0);
lbits = -1;
b.extract_bits(lbits, 5);
assert(lbits == 0b111000011011101110111011110);
lbits = -1;
b.insert_bits(lbits);
assert(b.all());
}

0 comments on commit c855fd1

Please sign in to comment.