Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[lazy] Optimize ZSTD_row_getMatchMask for levels 8-10 for ARM #3139

Merged
merged 4 commits into from
May 24, 2022
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 61 additions & 35 deletions lib/compress/zstd_lazy.c
Original file line number Diff line number Diff line change
Expand Up @@ -974,51 +974,76 @@ ZSTD_row_getSSEMask(int nbChunks, const BYTE* const src, const BYTE tag, const U
}
#endif

/* Returns a ZSTD_VecMask (U32) that has the nth bit set to 1 if the newly-computed "tag" matches
* the hash at the nth position in a row of the tagTable.
* Each row is a circular buffer beginning at the value of "head". So we must rotate the "matches" bitfield
* to match up with the actual layout of the entries within the hashTable */
/* Returns the width, in bits, of each group in the match mask: a matching
 * entry sets a whole group of bits rather than a single bit. Not every
 * architecture has a cheap movemask instruction, so grouping makes the
 * resulting bitfield easier and faster to iterate over.
 */
FORCE_INLINE_TEMPLATE U32
ZSTD_row_matchMaskGroupWidth(const U32 rowEntries) {
    assert((rowEntries == 16) || (rowEntries == 32) || rowEntries == 64);
    assert(rowEntries <= ZSTD_ROW_HASH_MAX_ENTRIES);
    (void)rowEntries;
#if defined(ZSTD_ARCH_ARM_NEON)
    /* The NEON narrowing-shift paths produce 64-bit masks, so each of the
     * rowEntries entries occupies 64/rowEntries bits. */
    switch (rowEntries) {
        case 16: return 4;
        case 32: return 2;
        default: return 1;  /* rowEntries == 64 */
    }
#else
    /* Scalar and SSE2 paths set exactly one bit per entry. */
    return 1;
#endif
}

/* Returns a ZSTD_VecMask (U64) that has the nth group (determined by
* ZSTD_row_matchMaskGroupWidth) of bits set to 1 if the newly-computed "tag"
* matches the hash at the nth position in a row of the tagTable.
* Each row is a circular buffer beginning at the value of "headGrouped". So we
* must rotate the "matches" bitfield to match up with the actual layout of the
* entries within the hashTable */
FORCE_INLINE_TEMPLATE ZSTD_VecMask
ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, const U32 head, const U32 rowEntries)
ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, const U32 headGrouped, const U32 rowEntries)
{
const BYTE* const src = tagRow + ZSTD_ROW_HASH_TAG_OFFSET;
assert((rowEntries == 16) || (rowEntries == 32) || rowEntries == 64);
assert(rowEntries <= ZSTD_ROW_HASH_MAX_ENTRIES);
assert(ZSTD_row_matchMaskGroupWidth(rowEntries) * rowEntries <= sizeof(ZSTD_VecMask) * 8);

#if defined(ZSTD_ARCH_X86_SSE2)

return ZSTD_row_getSSEMask(rowEntries / 16, src, tag, head);
return ZSTD_row_getSSEMask(rowEntries / 16, src, tag, headGrouped);

#else /* SW or NEON-LE */

# if defined(ZSTD_ARCH_ARM_NEON)
/* This NEON path only works for little endian - otherwise use SWAR below */
if (MEM_isLittleEndian()) {
if (rowEntries == 16) {
/* vshrn_n_u16 shifts by 4 every u16 and narrows to 8 lower bits.
danlark1 marked this conversation as resolved.
Show resolved Hide resolved
* After that groups of 4 bits represent the equalMask. We lower
* all bits except the highest in these groups by doing AND with
* 0x88 = 0b10001000.
*/
const uint8x16_t chunk = vld1q_u8(src);
const uint16x8_t equalMask = vreinterpretq_u16_u8(vceqq_u8(chunk, vdupq_n_u8(tag)));
const uint16x8_t t0 = vshlq_n_u16(equalMask, 7);
const uint32x4_t t1 = vreinterpretq_u32_u16(vsriq_n_u16(t0, t0, 14));
const uint64x2_t t2 = vreinterpretq_u64_u32(vshrq_n_u32(t1, 14));
const uint8x16_t t3 = vreinterpretq_u8_u64(vsraq_n_u64(t2, t2, 28));
const U16 hi = (U16)vgetq_lane_u8(t3, 8);
const U16 lo = (U16)vgetq_lane_u8(t3, 0);
return ZSTD_rotateRight_U16((hi << 8) | lo, head);
const uint8x8_t res = vshrn_n_u16(equalMask, 4);
const U64 matches = vget_lane_u64(vreinterpret_u64_u8(res), 0);
return ZSTD_rotateRight_U64(matches, headGrouped) & 0x8888888888888888ull;
} else if (rowEntries == 32) {
const uint16x8x2_t chunk = vld2q_u16((const U16*)(const void*)src);
/* Same idea as with rowEntries == 16 but doing AND with
* 0x55 = 0b01010101.
*/
const uint16x8x2_t chunk = vld2q_u16((const uint16_t*)(const void*)src);
const uint8x16_t chunk0 = vreinterpretq_u8_u16(chunk.val[0]);
const uint8x16_t chunk1 = vreinterpretq_u8_u16(chunk.val[1]);
const uint8x16_t equalMask0 = vceqq_u8(chunk0, vdupq_n_u8(tag));
const uint8x16_t equalMask1 = vceqq_u8(chunk1, vdupq_n_u8(tag));
const int8x8_t pack0 = vqmovn_s16(vreinterpretq_s16_u8(equalMask0));
const int8x8_t pack1 = vqmovn_s16(vreinterpretq_s16_u8(equalMask1));
const uint8x8_t t0 = vreinterpret_u8_s8(pack0);
const uint8x8_t t1 = vreinterpret_u8_s8(pack1);
const uint8x8_t t2 = vsri_n_u8(t1, t0, 2);
const uint8x8x2_t t3 = vuzp_u8(t2, t0);
const uint8x8_t t4 = vsri_n_u8(t3.val[1], t3.val[0], 4);
const U32 matches = vget_lane_u32(vreinterpret_u32_u8(t4), 0);
return ZSTD_rotateRight_U32(matches, head);
const uint8x16_t dup = vdupq_n_u8(tag);
const uint8x8_t t0 = vshrn_n_u16(vreinterpretq_u16_u8(vceqq_u8(chunk0, dup)), 6);
const uint8x8_t t1 = vshrn_n_u16(vreinterpretq_u16_u8(vceqq_u8(chunk1, dup)), 6);
const uint8x8_t res = vsli_n_u8(t0, t1, 4);
const U64 matches = vget_lane_u64(vreinterpret_u64_u8(res), 0) ;
return ZSTD_rotateRight_U64(matches, headGrouped) & 0x5555555555555555ull;
} else { /* rowEntries == 64 */
const uint8x16x4_t chunk = vld4q_u8(src);
const uint8x16_t dup = vdupq_n_u8(tag);
Expand All @@ -1033,7 +1058,7 @@ ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, const U32 head,
const uint8x16_t t3 = vsriq_n_u8(t2, t2, 4);
const uint8x8_t t4 = vshrn_n_u16(vreinterpretq_u16_u8(t3), 4);
const U64 matches = vget_lane_u64(vreinterpret_u64_u8(t4), 0);
return ZSTD_rotateRight_U64(matches, head);
return ZSTD_rotateRight_U64(matches, headGrouped);
}
}
# endif /* ZSTD_ARCH_ARM_NEON */
Expand Down Expand Up @@ -1071,11 +1096,11 @@ ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, const U32 head,
}
matches = ~matches;
if (rowEntries == 16) {
return ZSTD_rotateRight_U16((U16)matches, head);
return ZSTD_rotateRight_U16((U16)matches, headGrouped);
} else if (rowEntries == 32) {
return ZSTD_rotateRight_U32((U32)matches, head);
return ZSTD_rotateRight_U32((U32)matches, headGrouped);
} else {
return ZSTD_rotateRight_U64((U64)matches, head);
return ZSTD_rotateRight_U64((U64)matches, headGrouped);
}
}
#endif
Expand Down Expand Up @@ -1123,6 +1148,7 @@ size_t ZSTD_RowFindBestMatch(
const U32 rowEntries = (1U << rowLog);
const U32 rowMask = rowEntries - 1;
const U32 cappedSearchLog = MIN(cParams->searchLog, rowLog); /* nb of searches is capped at nb entries per row */
const U32 groupWidth = ZSTD_row_matchMaskGroupWidth(rowEntries);
U32 nbAttempts = 1U << cappedSearchLog;
size_t ml=4-1;

Expand Down Expand Up @@ -1165,15 +1191,15 @@ size_t ZSTD_RowFindBestMatch(
U32 const tag = hash & ZSTD_ROW_HASH_TAG_MASK;
U32* const row = hashTable + relRow;
BYTE* tagRow = (BYTE*)(tagTable + relRow);
U32 const head = *tagRow & rowMask;
U32 const headGrouped = (*tagRow & rowMask) * groupWidth;
U32 matchBuffer[ZSTD_ROW_HASH_MAX_ENTRIES];
size_t numMatches = 0;
size_t currMatch = 0;
ZSTD_VecMask matches = ZSTD_row_getMatchMask(tagRow, (BYTE)tag, head, rowEntries);
ZSTD_VecMask matches = ZSTD_row_getMatchMask(tagRow, (BYTE)tag, headGrouped, rowEntries);

/* Cycle through the matches and prefetch */
for (; (matches > 0) && (nbAttempts > 0); --nbAttempts, matches &= (matches - 1)) {
U32 const matchPos = (head + ZSTD_VecMask_next(matches)) & rowMask;
U32 const matchPos = ((headGrouped + ZSTD_VecMask_next(matches)) / groupWidth) & rowMask;
U32 const matchIndex = row[matchPos];
assert(numMatches < rowEntries);
if (matchIndex < lowLimit)
Expand Down Expand Up @@ -1234,14 +1260,14 @@ size_t ZSTD_RowFindBestMatch(
const U32 dmsSize = (U32)(dmsEnd - dmsBase);
const U32 dmsIndexDelta = dictLimit - dmsSize;

{ U32 const head = *dmsTagRow & rowMask;
{ U32 const headGrouped = (*dmsTagRow & rowMask) * groupWidth;
U32 matchBuffer[ZSTD_ROW_HASH_MAX_ENTRIES];
size_t numMatches = 0;
size_t currMatch = 0;
ZSTD_VecMask matches = ZSTD_row_getMatchMask(dmsTagRow, (BYTE)dmsTag, head, rowEntries);
ZSTD_VecMask matches = ZSTD_row_getMatchMask(dmsTagRow, (BYTE)dmsTag, headGrouped, rowEntries);

for (; (matches > 0) && (nbAttempts > 0); --nbAttempts, matches &= (matches - 1)) {
U32 const matchPos = (head + ZSTD_VecMask_next(matches)) & rowMask;
U32 const matchPos = ((headGrouped + ZSTD_VecMask_next(matches)) / groupWidth) & rowMask;
U32 const matchIndex = dmsRow[matchPos];
if (matchIndex < dmsLowestIndex)
break;
Expand Down