Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[rsyncable] Ensure ZSTD_compressBound() is respected #2776

Merged
merged 1 commit into from
Sep 14, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions lib/compress/zstdmt_compress.c
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,15 @@ typedef struct {
static const roundBuff_t kNullRoundBuff = {NULL, 0, 0};

#define RSYNC_LENGTH 32
/* Don't create chunks smaller than the zstd block size.
* This stops us from regressing compression ratio too much,
* and ensures our output fits in ZSTD_compressBound().
*
* If this is shrunk < ZSTD_BLOCKSIZELOG_MIN then
* ZSTD_COMPRESSBOUND() will need to be updated.
*/
#define RSYNC_MIN_BLOCK_LOG ZSTD_BLOCKSIZELOG_MAX
#define RSYNC_MIN_BLOCK_SIZE (1<<RSYNC_MIN_BLOCK_LOG)

typedef struct {
U64 hash;
Expand Down Expand Up @@ -1252,6 +1261,9 @@ size_t ZSTDMT_initCStream_internal(
/* Aim for the targetsectionSize as the average job size. */
U32 const jobSizeKB = (U32)(mtctx->targetSectionSize >> 10);
U32 const rsyncBits = (assert(jobSizeKB >= 1), ZSTD_highbit32(jobSizeKB) + 10);
/* We refuse to create jobs < RSYNC_MIN_BLOCK_SIZE bytes, so make sure our
* expected job size is at least 4x larger. */
assert(rsyncBits >= RSYNC_MIN_BLOCK_LOG + 2);
DEBUGLOG(4, "rsyncLog = %u", rsyncBits);
mtctx->rsync.hash = 0;
mtctx->rsync.hitMask = (1ULL << rsyncBits) - 1;
Expand Down Expand Up @@ -1678,6 +1690,11 @@ findSynchronizationPoint(ZSTDMT_CCtx const* mtctx, ZSTD_inBuffer const input)
if (!mtctx->params.rsyncable)
/* Rsync is disabled. */
return syncPoint;
if (mtctx->inBuff.filled + input.size - input.pos < RSYNC_MIN_BLOCK_SIZE)
/* We don't emit synchronization points if it would produce too small blocks.
* We don't have enough input to find a synchronization point, so don't look.
*/
return syncPoint;
if (mtctx->inBuff.filled + syncPoint.toLoad < RSYNC_LENGTH)
/* Not enough to compute the hash.
* We will miss any synchronization points in this RSYNC_LENGTH byte
Expand All @@ -1688,14 +1705,24 @@ findSynchronizationPoint(ZSTDMT_CCtx const* mtctx, ZSTD_inBuffer const input)
*/
return syncPoint;
/* Initialize the loop variables. */
if (mtctx->inBuff.filled >= RSYNC_LENGTH) {
if (mtctx->inBuff.filled < RSYNC_MIN_BLOCK_SIZE - RSYNC_LENGTH) {
/* We don't need to scan the first RSYNC_MIN_BLOCK_SIZE positions
* because they can't possibly be a sync point. So we can start
* part way through the input buffer.
*/
pos = RSYNC_MIN_BLOCK_SIZE - mtctx->inBuff.filled;
assert(pos <= input.size - input.pos /* validated earlier */);
assert(pos >= RSYNC_LENGTH);
prev = istart + pos - RSYNC_LENGTH;
hash = ZSTD_rollingHash_compute(prev, RSYNC_LENGTH);
} else if (mtctx->inBuff.filled >= RSYNC_LENGTH) {
/* We have enough bytes buffered to initialize the hash.
* Start scanning at the beginning of the input.
*/
pos = 0;
prev = (BYTE const*)mtctx->inBuff.buffer.start + mtctx->inBuff.filled - RSYNC_LENGTH;
hash = ZSTD_rollingHash_compute(prev, RSYNC_LENGTH);
if ((hash & hitMask) == hitMask) {
if ((hash & hitMask) == hitMask && mtctx->inBuff.filled >= RSYNC_MIN_BLOCK_SIZE) {
/* We're already at a sync point so don't load any more until
* we're able to flush this sync point.
* This likely happened because the job table was full so we
Expand Down Expand Up @@ -1728,6 +1755,7 @@ findSynchronizationPoint(ZSTDMT_CCtx const* mtctx, ZSTD_inBuffer const input)
BYTE const toRemove = pos < RSYNC_LENGTH ? prev[pos] : istart[pos - RSYNC_LENGTH];
/* if (pos >= RSYNC_LENGTH) assert(ZSTD_rollingHash_compute(istart + pos - RSYNC_LENGTH, RSYNC_LENGTH) == hash); */
hash = ZSTD_rollingHash_rotate(hash, toRemove, istart[pos], primePower);
assert(mtctx->inBuff.filled + pos >= RSYNC_MIN_BLOCK_SIZE);
if ((hash & hitMask) == hitMask) {
syncPoint.toLoad = pos + 1;
syncPoint.flush = 1;
Expand Down