From a418b4e4787f495abcc1f2e7768bd52a8d5ffc82 Mon Sep 17 00:00:00 2001 From: Nick Terrell Date: Mon, 13 Sep 2021 16:59:20 -0700 Subject: [PATCH] [rsyncable] Ensure ZSTD_compressBound() is respected In degenerate cases `--rsyncable` could create very small blocks (1 byte). This causes the compressed output to be larger than `ZSTD_compressBound()`. Fix the issue by ensuring that rsyncable mode never outputs blocks smaller than 128 KB. The minimum job size is 512 KB, so we shouldn't lose many synchronization points from skipping any that cause blocks smaller than 128 KB. And even if we do, that is fine, because we'll find the next one. This fixes the `raw_dictionary_round_trip` oss-fuzz assert. Credit to OSS-Fuzz --- lib/compress/zstdmt_compress.c | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/lib/compress/zstdmt_compress.c b/lib/compress/zstdmt_compress.c index 22aa3e1245a..94323846830 100644 --- a/lib/compress/zstdmt_compress.c +++ b/lib/compress/zstdmt_compress.c @@ -807,6 +807,15 @@ typedef struct { static const roundBuff_t kNullRoundBuff = {NULL, 0, 0}; #define RSYNC_LENGTH 32 +/* Don't create chunks smaller than the zstd block size. + * This stops us from regressing compression ratio too much, + * and ensures our output fits in ZSTD_compressBound(). + * + * If this is shrunk < ZSTD_BLOCKSIZELOG_MIN then + * ZSTD_COMPRESSBOUND() will need to be updated. + */ +#define RSYNC_MIN_BLOCK_LOG ZSTD_BLOCKSIZELOG_MAX +#define RSYNC_MIN_BLOCK_SIZE (1<targetSectionSize >> 10); U32 const rsyncBits = (assert(jobSizeKB >= 1), ZSTD_highbit32(jobSizeKB) + 10); + /* We refuse to create jobs < RSYNC_MIN_BLOCK_SIZE bytes, so make sure our + * expected job size is at least 4x larger. */ + assert(rsyncBits >= RSYNC_MIN_BLOCK_LOG + 2); DEBUGLOG(4, "rsyncLog = %u", rsyncBits); mtctx->rsync.hash = 0; mtctx->rsync.hitMask = (1ULL << rsyncBits) - 1; @@ -1678,6 +1690,11 @@ findSynchronizationPoint(ZSTDMT_CCtx const* mtctx, ZSTD_inBuffer const input) if (!mtctx->params.rsyncable) /* Rsync is disabled. */ return syncPoint; + if (mtctx->inBuff.filled + input.size - input.pos < RSYNC_MIN_BLOCK_SIZE) + /* We don't emit synchronization points if it would produce too small blocks. + * We don't have enough input to find a synchronization point, so don't look. + */ + return syncPoint; if (mtctx->inBuff.filled + syncPoint.toLoad < RSYNC_LENGTH) /* Not enough to compute the hash. * We will miss any synchronization points in this RSYNC_LENGTH byte @@ -1688,14 +1705,24 @@ findSynchronizationPoint(ZSTDMT_CCtx const* mtctx, ZSTD_inBuffer const input) */ return syncPoint; /* Initialize the loop variables. */ - if (mtctx->inBuff.filled >= RSYNC_LENGTH) { + if (mtctx->inBuff.filled < RSYNC_MIN_BLOCK_SIZE - RSYNC_LENGTH) { + /* We don't need to scan the first RSYNC_MIN_BLOCK_SIZE positions + * because they can't possibly be a sync point. So we can start + * part way through the input buffer. + */ + pos = RSYNC_MIN_BLOCK_SIZE - mtctx->inBuff.filled; + assert(pos <= input.size - input.pos /* validated earlier */); + assert(pos >= RSYNC_LENGTH); + prev = istart + pos - RSYNC_LENGTH; + hash = ZSTD_rollingHash_compute(prev, RSYNC_LENGTH); + } else if (mtctx->inBuff.filled >= RSYNC_LENGTH) { /* We have enough bytes buffered to initialize the hash. * Start scanning at the beginning of the input. */ pos = 0; prev = (BYTE const*)mtctx->inBuff.buffer.start + mtctx->inBuff.filled - RSYNC_LENGTH; hash = ZSTD_rollingHash_compute(prev, RSYNC_LENGTH); - if ((hash & hitMask) == hitMask) { + if ((hash & hitMask) == hitMask && mtctx->inBuff.filled >= RSYNC_MIN_BLOCK_SIZE) { /* We're already at a sync point so don't load any more until * we're able to flush this sync point. * This likely happened because the job table was full so we @@ -1728,6 +1755,7 @@ findSynchronizationPoint(ZSTDMT_CCtx const* mtctx, ZSTD_inBuffer const input) BYTE const toRemove = pos < RSYNC_LENGTH ? prev[pos] : istart[pos - RSYNC_LENGTH]; /* if (pos >= RSYNC_LENGTH) assert(ZSTD_rollingHash_compute(istart + pos - RSYNC_LENGTH, RSYNC_LENGTH) == hash); */ hash = ZSTD_rollingHash_rotate(hash, toRemove, istart[pos], primePower); + assert(mtctx->inBuff.filled + pos >= RSYNC_MIN_BLOCK_SIZE); if ((hash & hitMask) == hitMask) { syncPoint.toLoad = pos + 1; syncPoint.flush = 1;