From ff10202b647b4e1f5c42c5fe76340f8ea551d82c Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Thu, 17 Aug 2023 12:09:32 +0200 Subject: [PATCH 1/7] WIP: Add dictionary builder Functional and ok, but has failure modes. --- dict/builder.go | 491 +++++++++++++++++++++++++++++++++++++ dict/cmd/builddict/main.go | 93 +++++++ s2/dict.go | 22 ++ zstd/dict.go | 213 ++++++++++++++++ 4 files changed, 819 insertions(+) create mode 100644 dict/builder.go create mode 100644 dict/cmd/builddict/main.go diff --git a/dict/builder.go b/dict/builder.go new file mode 100644 index 0000000000..d0430ad073 --- /dev/null +++ b/dict/builder.go @@ -0,0 +1,491 @@ +// Copyright 2023+ Klaus Post. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dict + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "math/rand" + "sort" + "time" + + "github.com/klauspost/compress/s2" + "github.com/klauspost/compress/zstd" +) + +type match struct { + hash uint32 + n uint32 + offset int64 +} + +type matchValue struct { + value []byte + followBy map[uint32]uint32 + preceededBy map[uint32]uint32 +} + +type Options struct { + // MaxDictSize is the max size of the backreference dictionary. + MaxDictSize int + + // HashBytes is the minimum length to index. + // Must be >=4 and <=8 + HashBytes int + + // Debug output + Output io.Writer + + // ZstdDictID is the Zstd dictionary ID to use. + // Leave at zero to generate a random ID. 
+ ZstdDictID uint32 + + outFormat int +} + +const ( + formatRaw = iota + formatZstd + formatS2 +) + +func BuildZstdDict(input [][]byte, o Options) ([]byte, error) { + o.outFormat = formatZstd + if o.ZstdDictID == 0 { + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + o.ZstdDictID = 32768 + uint32(rng.Int31n((1<<31)-32768)) + } + return buildDict(input, o) +} + +func BuildS2Dict(input [][]byte, o Options) ([]byte, error) { + o.outFormat = formatS2 + return buildDict(input, o) +} + +func BuildRawDict(input [][]byte, o Options) ([]byte, error) { + o.outFormat = formatRaw + return buildDict(input, o) +} + +func buildDict(input [][]byte, o Options) ([]byte, error) { + matches := make(map[uint32]uint32) + offsets := make(map[uint32]int64) + var total uint64 + + wantLen := o.MaxDictSize + hashBytes := o.HashBytes + if len(input) == 0 { + return nil, fmt.Errorf("no input provided") + } + if hashBytes < 4 || hashBytes > 8 { + return nil, fmt.Errorf("HashBytes must be >= 4 and <= 8") + } + println := func(args ...interface{}) { + if o.Output != nil { + fmt.Fprintln(o.Output, args...) + } + } + printf := func(s string, args ...interface{}) { + if o.Output != nil { + fmt.Fprintf(o.Output, s, args...) 
+ } + } + found := make(map[uint32]struct{}) + for i, b := range input { + for k := range found { + delete(found, k) + } + for i := range b { + rem := b[i:] + if len(rem) < 8 { + break + } + h := hashLen(binary.LittleEndian.Uint64(rem), 32, uint8(hashBytes)) + if _, ok := found[h]; ok { + // Only count first occurrence + continue + } + matches[h]++ + offsets[h] += int64(i) + total++ + found[h] = struct{}{} + } + printf("\r input %d indexed...", i) + } + threshold := uint32(total / uint64(len(matches))) + println("\nTotal", total, "match", len(matches), "avg", threshold) + sorted := make([]match, 0, len(matches)/2) + for k, v := range matches { + if v <= threshold { + continue + } + sorted = append(sorted, match{hash: k, n: v, offset: offsets[k]}) + } + sort.Slice(sorted, func(i, j int) bool { + if sorted[i].n == sorted[j].n { + return sorted[i].offset < sorted[j].offset + } + return sorted[i].n > sorted[j].n + }) + println("Sorted len:", len(sorted)) + if len(sorted) > wantLen { + sorted = sorted[:wantLen] + } + lowestOcc := sorted[len(sorted)-1].n + println("Cropped len:", len(sorted), "Lowest occurrence:", lowestOcc) + + wantMatches := make(map[uint32]uint32, len(sorted)) + for _, v := range sorted { + wantMatches[v.hash] = v.n + } + + output := make(map[uint32]matchValue, len(sorted)) + var remainCnt [256]int + var remainTotal int + var firstOffsets []int + for i, b := range input { + for i := range b { + rem := b[i:] + if len(rem) < 8 { + break + } + var prev []byte + if i > hashBytes { + prev = b[i-hashBytes:] + } + + h := hashLen(binary.LittleEndian.Uint64(rem), 32, uint8(hashBytes)) + if _, ok := wantMatches[h]; !ok { + remainCnt[rem[0]]++ + remainTotal++ + continue + } + mv := output[h] + if len(mv.value) == 0 { + var tmp = make([]byte, hashBytes) + copy(tmp[:], rem) + mv.value = tmp[:] + } + if mv.followBy == nil { + mv.followBy = make(map[uint32]uint32, 4) + mv.preceededBy = make(map[uint32]uint32, 4) + } + if len(rem) > hashBytes+8 { + // Check if we 
should add next as well. + hNext := hashLen(binary.LittleEndian.Uint64(rem[hashBytes:]), 32, uint8(hashBytes)) + if _, ok := wantMatches[hNext]; ok { + mv.followBy[hNext]++ + } + } + if len(prev) >= 8 { + // Check if we should prev next as well. + hPrev := hashLen(binary.LittleEndian.Uint64(prev), 32, uint8(hashBytes)) + if _, ok := wantMatches[hPrev]; ok { + mv.preceededBy[hPrev]++ + } + } + output[h] = mv + } + printf("\rinput %d re-indexed...", i) + } + println("") + dst := make([][]byte, 0, wantLen/hashBytes) + added := 0 + const printUntil = 500 + for i, e := range sorted { + if added > o.MaxDictSize { + break + } + m, ok := output[e.hash] + if !ok { + // Already added + continue + } + var tmp = make([]byte, 0, hashBytes*2) + { + sortedPrev := make([]match, 0, len(m.followBy)) + for k, v := range m.preceededBy { + if _, ok := output[k]; !ok { + continue + } + sortedPrev = append(sortedPrev, match{ + hash: k, + n: v, + }) + } + if len(sortedPrev) > 0 { + sort.Slice(sortedPrev, func(i, j int) bool { + return sortedPrev[i].n > sortedPrev[j].n + }) + bestPrev := output[sortedPrev[0].hash] + tmp = append(tmp, bestPrev.value...) + } + } + tmp = append(tmp, m.value...) + delete(output, e.hash) + wantLen := e.n / uint32(hashBytes) / 4 + if wantLen <= lowestOcc { + wantLen = lowestOcc + } + for { + var nh uint32 // Next hash + stopAfter := false + if true { + sortedFollow := make([]match, 0, len(m.followBy)) + for k, v := range m.followBy { + if _, ok := output[k]; !ok { + continue + } + sortedFollow = append(sortedFollow, match{ + hash: k, + n: v, + }) + } + if len(sortedFollow) == 0 { + break + } + sort.Slice(sortedFollow, func(i, j int) bool { + return sortedFollow[i].n > sortedFollow[j].n + }) + nh = sortedFollow[0].hash + stopAfter = sortedFollow[0].n < wantLen + } + m, ok = output[nh] + if !ok { + break + } + if len(tmp) > 0 { + // Delete all hashes that are in the current string to avoid stuttering. 
+ var toDel [16 + 8]byte + copy(toDel[:], tmp[len(tmp)-hashBytes:]) + copy(toDel[hashBytes:], m.value) + for i := range toDel[:hashBytes*2] { + delete(output, hashLen(binary.LittleEndian.Uint64(toDel[i:]), 32, uint8(hashBytes))) + } + } + tmp = append(tmp, m.value...) + //delete(output, nh) + if stopAfter { + // Last entry was no significant. + break + } + } + if i < printUntil { + printf("ENTRY %d: %q (%d occurrences, cutoff %d)\n", i, string(tmp), e.n, wantLen) + } + // Delete substrings already added. + if len(tmp) > hashBytes { + for j := range tmp[:len(tmp)-hashBytes+1] { + var t8 [8]byte + copy(t8[:], tmp[j:]) + if i < 100 { + if false { + printf("DELETE %q\n", string(t8[:hashBytes])) + } + } + delete(output, hashLen(binary.LittleEndian.Uint64(t8[:]), 32, uint8(hashBytes))) + } + } + dst = append(dst, tmp) + added += len(tmp) + // Find offsets + // TODO: This can be better if done as a global search. + if len(firstOffsets) < 3 { + if len(tmp) > 16 { + tmp = tmp[:16] + } + offCnt := make(map[int]int, len(input)) + // Find first offsets + for _, b := range input { + off := bytes.Index(b, tmp) + if off == -1 { + continue + } + offCnt[off]++ + } + for _, off := range firstOffsets { + // Very unlikely, but we deleted it just in case + delete(offCnt, off-added) + } + maxCnt := 0 + maxOffset := 0 + for k, v := range offCnt { + if v == maxCnt && k > maxOffset { + // Prefer the longer offset on ties , since it is more expensive to encode + maxCnt = v + maxOffset = k + continue + } + + if v > maxCnt { + maxCnt = v + maxOffset = k + } + } + if maxCnt > 1 { + firstOffsets = append(firstOffsets, maxOffset+added) + println(" - Offset:", len(firstOffsets), "at", maxOffset+added, "count:", maxCnt, "total added:", added, "src index", maxOffset) + } + } + } + out := bytes.NewBuffer(nil) + written := 0 + for i, toWrite := range dst { + if len(toWrite)+written > wantLen { + toWrite = toWrite[:wantLen-written] + } + dst[i] = toWrite + written += len(toWrite) + if written >= 
wantLen { + dst = dst[:i+1] + break + } + } + // Write in reverse order. + for i := range dst { + toWrite := dst[len(dst)-i-1] + out.Write(toWrite) + } + if o.outFormat == formatRaw { + return out.Bytes(), nil + } + + if o.outFormat == formatS2 { + dOff := 0 + dBytes := out.Bytes() + if len(dBytes) > s2.MaxDictSize { + dBytes = dBytes[:s2.MaxDictSize] + } + for _, off := range firstOffsets { + myOff := len(dBytes) - off + if myOff < 0 || myOff > s2.MaxDictSrcOffset { + continue + } + dOff = myOff + } + + dict := s2.MakeDictManual(dBytes, uint16(dOff)) + if dict == nil { + return nil, fmt.Errorf("unable to create s2 dictionary") + } + return dict.Bytes(), nil + } + /* + avgSize := 256 + println("\nHuffman: literal total:", remainTotal, "normalized counts on remainder size:", avgSize) + huffBuff := make([]byte, 0, avgSize) + // Target size + div := remainTotal / avgSize + if div < 1 { + div = 1 + } + for i, n := range remainCnt[:] { + if n > 0 { + n = n / div + if n == 0 { + n = 1 + } + huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...) + fmt.Printf("[%d: %d], ", i, n) + } + } + println("") + scratch := &huff0.Scratch{} + _, _, err := huff0.Compress1X(huffBuff, scratch) + if err != nil { + // TODO: Handle RLE + return nil, err + } + println("Huffman table:", len(scratch.OutTable), "bytes") + */ + offsetsZstd := [3]int{1, 4, 8} + for i, off := range firstOffsets { + if i >= 3 || off == 0 || off >= out.Len() { + break + } + offsetsZstd[i] = off + } + println("\nCompressing. Offsets:", offsetsZstd) + return zstd.BuildDict(zstd.BuildDictOptions{ + ID: o.ZstdDictID, + Contents: input, + History: out.Bytes(), + Offsets: offsetsZstd, + }) +} + +const ( + prime3bytes = 506832829 + prime4bytes = 2654435761 + prime5bytes = 889523592379 + prime6bytes = 227718039650203 + prime7bytes = 58295818150454627 + prime8bytes = 0xcf1bbcdcb7a56463 +) + +// hashLen returns a hash of the lowest l bytes of u for a size size of h bytes. +// l must be >=4 and <=8. 
Any other value will return hash for 4 bytes. +// h should always be <32. +// Preferably h and l should be a constant. +// LENGTH 4 is passed straight through +func hashLen(u uint64, hashLog, mls uint8) uint32 { + switch mls { + case 5: + return hash5(u, hashLog) + case 6: + return hash6(u, hashLog) + case 7: + return hash7(u, hashLog) + case 8: + return hash8(u, hashLog) + default: + return uint32(u) + } +} + +// hash3 returns the hash of the lower 3 bytes of u to fit in a hash table with h bits. +// Preferably h should be a constant and should always be <32. +func hash3(u uint32, h uint8) uint32 { + return ((u << (32 - 24)) * prime3bytes) >> ((32 - h) & 31) +} + +// hash4 returns the hash of u to fit in a hash table with h bits. +// Preferably h should be a constant and should always be <32. +func hash4(u uint32, h uint8) uint32 { + return (u * prime4bytes) >> ((32 - h) & 31) +} + +// hash4x64 returns the hash of the lowest 4 bytes of u to fit in a hash table with h bits. +// Preferably h should be a constant and should always be <32. +func hash4x64(u uint64, h uint8) uint32 { + return (uint32(u) * prime4bytes) >> ((32 - h) & 31) +} + +// hash5 returns the hash of the lowest 5 bytes of u to fit in a hash table with h bits. +// Preferably h should be a constant and should always be <64. +func hash5(u uint64, h uint8) uint32 { + return uint32(((u << (64 - 40)) * prime5bytes) >> ((64 - h) & 63)) +} + +// hash6 returns the hash of the lowest 6 bytes of u to fit in a hash table with h bits. +// Preferably h should be a constant and should always be <64. +func hash6(u uint64, h uint8) uint32 { + return uint32(((u << (64 - 48)) * prime6bytes) >> ((64 - h) & 63)) +} + +// hash7 returns the hash of the lowest 7 bytes of u to fit in a hash table with h bits. +// Preferably h should be a constant and should always be <64. 
+func hash7(u uint64, h uint8) uint32 { + return uint32(((u << (64 - 56)) * prime7bytes) >> ((64 - h) & 63)) +} + +// hash8 returns the hash of u to fit in a hash table with h bits. +// Preferably h should be a constant and should always be <64. +func hash8(u uint64, h uint8) uint32 { + return uint32((u * prime8bytes) >> ((64 - h) & 63)) +} diff --git a/dict/cmd/builddict/main.go b/dict/cmd/builddict/main.go new file mode 100644 index 0000000000..f57325d7ac --- /dev/null +++ b/dict/cmd/builddict/main.go @@ -0,0 +1,93 @@ +// Copyright 2023+ Klaus Post. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "flag" + "fmt" + "io" + "log" + "os" + "path/filepath" + + "github.com/klauspost/compress/dict" +) + +var ( + wantLenFlag = flag.Int("len", 112<<10, "Specify custom output size") + wantHashBytes = flag.Int("hash", 8, "Hash bytes match length. Minimum match length.") + wantMaxBytes = flag.Int("max", 32<<10, "Max input length to index per input file") + wantOutput = flag.String("o", "dictionary.bin", "Output name") + wantFormat = flag.String("format", "zstd", `Output type. "zstd" "s2" or "raw"`) + wantZstdID = flag.Uint("zstdid", 0, "Zstd dictionary ID. 0 will be random") + quiet = flag.Bool("q", false, "Do not print progress") +) + +func main() { + flag.Parse() + o := dict.Options{ + MaxDictSize: *wantLenFlag, + HashBytes: *wantHashBytes, + Output: os.Stdout, + ZstdDictID: uint32(*wantZstdID), + } + if *wantOutput == "" || *quiet { + o.Output = nil + } + var input [][]byte + base := flag.Arg(0) + if base == "" { + log.Fatal("no path with files specified") + } + + // Index ALL hashes in all files. 
+ filepath.Walk(base, func(path string, info os.FileInfo, err error) error { + if info.IsDir() { + return nil + } + + f, err := os.Open(path) + if err != nil { + log.Print(err) + return nil + } + defer f.Close() + b, err := io.ReadAll(io.LimitReader(f, int64(*wantMaxBytes))) + if len(b) < 8 { + return nil + } + input = append(input, b) + if !*quiet { + fmt.Print("\r"+info.Name(), " read...") + } + return nil + }) + var out []byte + var err error + switch *wantFormat { + case "zstd": + out, err = dict.BuildZstdDict(input, o) + case "s2": + out, err = dict.BuildS2Dict(input, o) + case "raw": + out, err = dict.BuildRawDict(input, o) + default: + err = fmt.Errorf("unknown format %q", *wantFormat) + } + if err != nil { + log.Fatal(err) + } + if *wantOutput != "" { + err = os.WriteFile(*wantOutput, out, 0666) + if err != nil { + log.Fatal(err) + } + } else { + _, err = os.Stdout.Write(out) + if err != nil { + log.Fatal(err) + } + } +} diff --git a/s2/dict.go b/s2/dict.go index 24f7ce80bc..93e858ba65 100644 --- a/s2/dict.go +++ b/s2/dict.go @@ -106,6 +106,28 @@ func MakeDict(data []byte, searchStart []byte) *Dict { return &d } +// MakeDict will create a dictionary. +// 'data' must be at least MinDictSize. +// If data is longer than MaxDictSize only the last MaxDictSize bytes will be used. +// A manual first repeat value must be provided. It cannot be 0. +func MakeDictManual(data []byte, firstIdx uint16) *Dict { + if len(data) == 0 || int(firstIdx) > len(data)-8 || len(data) > MaxDictSize { + return nil + } + var d Dict + dict := data + d.dict = dict + if cap(d.dict) < len(d.dict)+16 { + d.dict = append(make([]byte, 0, len(d.dict)+16), d.dict...) + } + if len(dict) < MinDictSize { + return nil + } + + d.repeat = int(firstIdx) + return &d +} + // Encode returns the encoded form of src. The returned slice may be a sub- // slice of dst if dst was large enough to hold the entire encoded block. // Otherwise, a newly allocated slice will be returned. 
diff --git a/zstd/dict.go b/zstd/dict.go index ca0951452e..5acd6457c4 100644 --- a/zstd/dict.go +++ b/zstd/dict.go @@ -1,10 +1,12 @@ package zstd import ( + "bytes" "encoding/binary" "errors" "fmt" "io" + "math" "github.com/klauspost/compress/huff0" ) @@ -159,3 +161,214 @@ func InspectDictionary(b []byte) (interface { d, err := loadDict(b) return d, err } + +type BuildDictOptions struct { + // Dictionary ID. + ID uint32 + + // Content to use to create dictionary tables. + Contents [][]byte + + // History to use for all blocks. + History []byte + + // Offsets to use. + Offsets [3]int +} + +func BuildDict(o BuildDictOptions) ([]byte, error) { + initPredefined() + hist := o.History + contents := o.Contents + const debug = false + if len(hist) > dictMaxLength { + return nil, fmt.Errorf("dictionary of size %d > %d", len(hist), dictMaxLength) + } + if len(hist) < 8 { + return nil, fmt.Errorf("dictionary of size %d < %d", len(hist), 8) + } + if len(contents) == 0 { + return nil, errors.New("no content provided") + } + d := dict{ + id: o.ID, + litEnc: nil, + llDec: sequenceDec{}, + ofDec: sequenceDec{}, + mlDec: sequenceDec{}, + offsets: o.Offsets, + content: hist, + } + block := blockEnc{lowMem: false} + block.init() + enc := &bestFastEncoder{fastBase: fastBase{maxMatchOff: int32(maxMatchLen), bufferReset: math.MaxInt32 - int32(maxMatchLen*2), lowMem: false}} + var ( + remain [256]int + ll [256]int + ml [256]int + of [256]int + ) + addValues := func(dst *[256]int, src []byte) { + for _, v := range src { + dst[v]++ + } + } + addHist := func(dst *[256]int, src *[256]uint32) { + for i, v := range src { + dst[i] += int(v) + } + } + seqs := 0 + nUsed := 0 + litTotal := 0 + for _, b := range contents { + block.reset(nil) + if len(b) < 8 { + continue + } + nUsed++ + enc.Reset(&d, true) + enc.Encode(&block, b) + addValues(&remain, block.literals) + litTotal += len(block.literals) + seqs += len(block.sequences) + block.genCodes() + addHist(&ll, block.coders.llEnc.Histogram()) + 
addHist(&ml, block.coders.mlEnc.Histogram()) + addHist(&of, block.coders.ofEnc.Histogram()) + } + if nUsed == 0 || seqs == 0 { + return nil, fmt.Errorf("%d blocks, %d sequences found", nUsed, seqs) + } + if debug { + fmt.Println("Sequences:", seqs, "Blocks:", nUsed, "Literals:", litTotal) + } + if seqs/nUsed < 512 { + // Use 512 as minimum. + nUsed = seqs / 512 + } + copyHist := func(dst *fseEncoder, src *[256]int) ([]byte, error) { + hist := dst.Histogram() + var maxSym uint8 + var maxCount int + var fakeLength int + for i, v := range src { + if v > 0 { + v = v / nUsed + if v == 0 { + v = 1 + } + } + if v > maxCount { + maxCount = v + } + if v != 0 { + maxSym = uint8(i) + } + fakeLength += v + hist[i] = uint32(v) + } + dst.HistogramFinished(maxSym, maxCount) + dst.reUsed = false + dst.useRLE = false + err := dst.normalizeCount(fakeLength) + if err != nil { + return nil, err + } + if debug { + fmt.Println("RAW:", dst.count[:maxSym+1], "NORM:", dst.norm[:maxSym+1], "LEN:", fakeLength) + } + return dst.writeCount(nil) + } + if debug { + fmt.Print("Literal lengths: ") + } + llTable, err := copyHist(block.coders.llEnc, &ll) + if err != nil { + return nil, err + } + if debug { + fmt.Print("Match lengths: ") + } + mlTable, err := copyHist(block.coders.mlEnc, &ml) + if err != nil { + return nil, err + } + if debug { + fmt.Print("Offsets: ") + } + ofTable, err := copyHist(block.coders.ofEnc, &of) + if err != nil { + return nil, err + } + + // Liteal table + avgSize := litTotal + if avgSize > huff0.BlockSizeMax/2 { + avgSize = huff0.BlockSizeMax / 2 + } + huffBuff := make([]byte, 0, avgSize) + // Target size + div := litTotal / avgSize + if div < 1 { + div = 1 + } + if debug { + fmt.Println("Huffman weights:") + } + for i, n := range remain[:] { + if n > 0 { + n = n / div + // Allow all entries to be represented. + if n == 0 { + n = 1 + } + huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...) 
+ if debug { + fmt.Printf("[%d: %d], ", i, n) + } + } + } + if remain[255]/div == 0 { + huffBuff = append(huffBuff, 255) + } + scratch := &huff0.Scratch{TableLog: 11} + _, _, err = huff0.Compress1X(huffBuff, scratch) + if err != nil { + // TODO: Handle RLE + return nil, err + } + + var out bytes.Buffer + out.Write([]byte(dictMagic)) + out.Write(binary.LittleEndian.AppendUint32(nil, o.ID)) + out.Write(scratch.OutTable) + if debug { + fmt.Println("huff table:", len(scratch.OutTable), "bytes") + fmt.Println("of table:", len(ofTable), "bytes") + fmt.Println("ml table:", len(mlTable), "bytes") + fmt.Println("ll table:", len(llTable), "bytes") + } + out.Write(ofTable) + out.Write(mlTable) + out.Write(llTable) + out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[0]))) + out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[1]))) + out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[2]))) + out.Write(hist) + if debug { + _, err := loadDict(out.Bytes()) + if err != nil { + panic(err) + } + i, err := InspectDictionary(out.Bytes()) + if err != nil { + panic(err) + } + fmt.Println("ID:", i.ID()) + fmt.Println("Content size:", i.ContentSize()) + fmt.Println("Encoder:", i.LitEncoder() != nil) + fmt.Println("Offsets:", i.Offsets()) + } + return out.Bytes(), nil +} From dbdb94e878004d465b5dde25f2be9b806b3542f3 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Thu, 17 Aug 2023 14:28:06 +0200 Subject: [PATCH 2/7] Reduce minimum to 16 bytes. Keep linter happy. --- zstd/blockenc.go | 4 ++-- zstd/dict.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/zstd/blockenc.go b/zstd/blockenc.go index fd4a36f730..25b1830e0b 100644 --- a/zstd/blockenc.go +++ b/zstd/blockenc.go @@ -361,7 +361,7 @@ func (b *blockEnc) encodeLits(lits []byte, raw bool) error { if len(lits) >= 1024 { // Use 4 Streams. 
out, reUsed, err = huff0.Compress4X(lits, b.litEnc) - } else if len(lits) > 32 { + } else if len(lits) > 16 { // Use 1 stream single = true out, reUsed, err = huff0.Compress1X(lits, b.litEnc) @@ -503,7 +503,7 @@ func (b *blockEnc) encode(org []byte, raw, rawAllLits bool) error { if len(b.literals) >= 1024 && !raw { // Use 4 Streams. out, reUsed, err = huff0.Compress4X(b.literals, b.litEnc) - } else if len(b.literals) > 32 && !raw { + } else if len(b.literals) > 16 && !raw { // Use 1 stream single = true out, reUsed, err = huff0.Compress1X(b.literals, b.litEnc) diff --git a/zstd/dict.go b/zstd/dict.go index 5acd6457c4..44ccb4a521 100644 --- a/zstd/dict.go +++ b/zstd/dict.go @@ -181,8 +181,8 @@ func BuildDict(o BuildDictOptions) ([]byte, error) { hist := o.History contents := o.Contents const debug = false - if len(hist) > dictMaxLength { - return nil, fmt.Errorf("dictionary of size %d > %d", len(hist), dictMaxLength) + if int64(len(hist)) > dictMaxLength { + return nil, fmt.Errorf("dictionary of size %d > %d", len(hist), int64(dictMaxLength)) } if len(hist) < 8 { return nil, fmt.Errorf("dictionary of size %d < %d", len(hist), 8) From 1ca482fd4bf7a524263aea59fa82f05844c4df96 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Thu, 17 Aug 2023 15:38:23 +0200 Subject: [PATCH 3/7] Check sizes compared to header sizes. --- zstd/blockenc.go | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/zstd/blockenc.go b/zstd/blockenc.go index 25b1830e0b..aa20c82b24 100644 --- a/zstd/blockenc.go +++ b/zstd/blockenc.go @@ -368,7 +368,14 @@ func (b *blockEnc) encodeLits(lits []byte, raw bool) error { } else { err = huff0.ErrIncompressible } - + if err == nil && len(out)+5 > len(lits) { + // If we are close, we may still be worse or equal to raw. 
+ var lh literalsHeader + lh.setSizes(len(out), len(lits), single) + if len(out)+lh.size() >= len(lits) { + err = huff0.ErrIncompressible + } + } switch err { case huff0.ErrIncompressible: if debugEncoder { @@ -511,6 +518,17 @@ func (b *blockEnc) encode(org []byte, raw, rawAllLits bool) error { err = huff0.ErrIncompressible } + if err == nil && len(out)+5 > len(b.literals) { + // If we are close, we may still be worse or equal to raw. + var lh literalsHeader + lh.setSize(len(b.literals)) + szRaw := lh.size() + lh.setSizes(len(out), len(b.literals), single) + szComp := lh.size() + if len(out)+szComp >= len(b.literals)+szRaw { + err = huff0.ErrIncompressible + } + } switch err { case huff0.ErrIncompressible: lh.setType(literalsBlockRaw) From 523fb7acb1d6090f2787b4a68d916dfc9e5ece1b Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Mon, 21 Aug 2023 11:50:44 +0200 Subject: [PATCH 4/7] Try more huffman tables before bailing out. Add "lazy" matching. Should probably be a skipping attempt instead. --- dict/builder.go | 27 +++++++++++++++---- dict/cmd/builddict/main.go | 2 +- zstd/dict.go | 55 +++++++++++++++++++++++++++++++++----- 3 files changed, 71 insertions(+), 13 deletions(-) diff --git a/dict/builder.go b/dict/builder.go index d0430ad073..4a57aeab7d 100644 --- a/dict/builder.go +++ b/dict/builder.go @@ -228,15 +228,17 @@ func buildDict(input [][]byte, o Options) ([]byte, error) { } tmp = append(tmp, m.value...) 
delete(output, e.hash) + wantLen := e.n / uint32(hashBytes) / 4 if wantLen <= lowestOcc { wantLen = lowestOcc } + sortedFollow := make([]match, 0, len(m.followBy)) for { var nh uint32 // Next hash stopAfter := false if true { - sortedFollow := make([]match, 0, len(m.followBy)) + sortedFollow = sortedFollow[:0] for k, v := range m.followBy { if _, ok := output[k]; !ok { continue @@ -247,6 +249,20 @@ func buildDict(input [][]byte, o Options) ([]byte, error) { }) } if len(sortedFollow) == 0 { + // Step back + if len(tmp) > hashBytes { + var t8 [8]byte + copy(t8[:], tmp[len(tmp)-hashBytes-1:]) + m, ok = output[hashLen(binary.LittleEndian.Uint64(t8[:]), 32, uint8(hashBytes))] + if ok && len(m.followBy) > 0 { + tmp = tmp[:len(tmp)-1] + continue + } + } else { + if i < printUntil { + printf("FOLLOW: none after %q\n", string(m.value)) + } + } break } sort.Slice(sortedFollow, func(i, j int) bool { @@ -254,6 +270,9 @@ func buildDict(input [][]byte, o Options) ([]byte, error) { }) nh = sortedFollow[0].hash stopAfter = sortedFollow[0].n < wantLen + if stopAfter && i < printUntil { + printf("FOLLOW: %d > %d after %q\n", sortedFollow[0].n, wantLen, string(m.value)) + } } m, ok = output[nh] if !ok { @@ -283,10 +302,8 @@ func buildDict(input [][]byte, o Options) ([]byte, error) { for j := range tmp[:len(tmp)-hashBytes+1] { var t8 [8]byte copy(t8[:], tmp[j:]) - if i < 100 { - if false { - printf("DELETE %q\n", string(t8[:hashBytes])) - } + if i < printUntil { + printf("* POST DELETE %q\n", string(t8[:hashBytes])) } delete(output, hashLen(binary.LittleEndian.Uint64(t8[:]), 32, uint8(hashBytes))) } diff --git a/dict/cmd/builddict/main.go b/dict/cmd/builddict/main.go index f57325d7ac..c80ec93f9c 100644 --- a/dict/cmd/builddict/main.go +++ b/dict/cmd/builddict/main.go @@ -17,7 +17,7 @@ import ( var ( wantLenFlag = flag.Int("len", 112<<10, "Specify custom output size") - wantHashBytes = flag.Int("hash", 8, "Hash bytes match length. 
Minimum match length.") + wantHashBytes = flag.Int("hash", 6, "Hash bytes match length. Minimum match length.") wantMaxBytes = flag.Int("max", 32<<10, "Max input length to index per input file") wantOutput = flag.String("o", "dictionary.bin", "Output name") wantFormat = flag.String("format", "zstd", `Output type. "zstd" "s2" or "raw"`) diff --git a/zstd/dict.go b/zstd/dict.go index 44ccb4a521..e7def04dd1 100644 --- a/zstd/dict.go +++ b/zstd/dict.go @@ -16,9 +16,8 @@ type dict struct { litEnc *huff0.Scratch llDec, ofDec, mlDec sequenceDec - //llEnc, ofEnc, mlEnc []*fseEncoder - offsets [3]int - content []byte + offsets [3]int + content []byte } const dictMagic = "\x37\xa4\x30\xec" @@ -333,10 +332,52 @@ func BuildDict(o BuildDictOptions) ([]byte, error) { huffBuff = append(huffBuff, 255) } scratch := &huff0.Scratch{TableLog: 11} - _, _, err = huff0.Compress1X(huffBuff, scratch) - if err != nil { - // TODO: Handle RLE - return nil, err + for tries := 0; tries < 255; tries++ { + scratch = &huff0.Scratch{TableLog: 11} + _, _, err = huff0.Compress1X(huffBuff, scratch) + if err == nil { + break + } + if debug { + fmt.Printf("Try %d: Huffman error: %v\n", tries+1, err) + } + huffBuff = huffBuff[:0] + if tries == 250 { + if debug { + fmt.Println("Huffman: Bailing out with predefined table") + } + + // Bail out.... Just generate something + huffBuff = append(huffBuff, bytes.Repeat([]byte{255}, 10000)...) + for i := 0; i < 128; i++ { + huffBuff = append(huffBuff, byte(i)) + } + continue + } + if errors.Is(err, huff0.ErrIncompressible) { + // Try truncating least common. + for i, n := range remain[:] { + if n > 0 { + n = n / (div * (i + 1)) + if n > 0 { + huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...) 
+ } + } + if i == 255 && (len(huffBuff) == 0 || huffBuff[len(huffBuff)-1] != 255) { + huffBuff = append(huffBuff, 255) + } + } + } + if errors.Is(err, huff0.ErrUseRLE) { + for i, n := range remain[:] { + n = n / (div * (i + 1)) + // Allow all entries to be represented. + if n == 0 { + n = 1 + } + huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...) + } + } } var out bytes.Buffer From ed11654b37f3bd22055f9a5b2b6fdf2b7dc32eab Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Fri, 25 Aug 2023 14:52:56 +0200 Subject: [PATCH 5/7] Rewrite zstd offsets Group output and use offset as secondary. Add step back for long hashes. --- dict/builder.go | 65 +++++++++++++++++++++++++++++++++++++------------ zstd/dict.go | 64 +++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 112 insertions(+), 17 deletions(-) diff --git a/dict/builder.go b/dict/builder.go index 4a57aeab7d..1c7f0af0f4 100644 --- a/dict/builder.go +++ b/dict/builder.go @@ -127,8 +127,20 @@ func buildDict(input [][]byte, o Options) ([]byte, error) { sorted = append(sorted, match{hash: k, n: v, offset: offsets[k]}) } sort.Slice(sorted, func(i, j int) bool { - if sorted[i].n == sorted[j].n { - return sorted[i].offset < sorted[j].offset + if true { + // Group very similar counts together and emit low offsets first. + // This will keep together strings that are very similar. + deltaN := int(sorted[i].n) - int(sorted[j].n) + if deltaN < 0 { + deltaN = -deltaN + } + if uint32(deltaN) < sorted[i].n/32 { + return sorted[i].offset < sorted[j].offset + } + } else { + if sorted[i].n == sorted[j].n { + return sorted[i].offset < sorted[j].offset + } } return sorted[i].n > sorted[j].n }) @@ -199,6 +211,7 @@ func buildDict(input [][]byte, o Options) ([]byte, error) { const printUntil = 500 for i, e := range sorted { if added > o.MaxDictSize { + println("Ending. 
Next Occurrence:", e.n) break } m, ok := output[e.hash] @@ -206,11 +219,16 @@ func buildDict(input [][]byte, o Options) ([]byte, error) { // Already added continue } + wantLen := e.n / uint32(hashBytes) / 4 + if wantLen <= lowestOcc { + wantLen = lowestOcc + } + var tmp = make([]byte, 0, hashBytes*2) { sortedPrev := make([]match, 0, len(m.followBy)) for k, v := range m.preceededBy { - if _, ok := output[k]; !ok { + if _, ok := output[k]; v < wantLen || !ok { continue } sortedPrev = append(sortedPrev, match{ @@ -229,35 +247,47 @@ func buildDict(input [][]byte, o Options) ([]byte, error) { tmp = append(tmp, m.value...) delete(output, e.hash) - wantLen := e.n / uint32(hashBytes) / 4 - if wantLen <= lowestOcc { - wantLen = lowestOcc - } sortedFollow := make([]match, 0, len(m.followBy)) for { var nh uint32 // Next hash stopAfter := false - if true { + { sortedFollow = sortedFollow[:0] for k, v := range m.followBy { if _, ok := output[k]; !ok { continue } sortedFollow = append(sortedFollow, match{ - hash: k, - n: v, + hash: k, + n: v, + offset: offsets[k], }) } if len(sortedFollow) == 0 { // Step back - if len(tmp) > hashBytes { + // Extremely small impact, but helps longer hashes a bit. 
+ const stepBack = 2 + if stepBack > 0 && len(tmp) >= hashBytes+stepBack { var t8 [8]byte - copy(t8[:], tmp[len(tmp)-hashBytes-1:]) + copy(t8[:], tmp[len(tmp)-hashBytes-stepBack:]) m, ok = output[hashLen(binary.LittleEndian.Uint64(t8[:]), 32, uint8(hashBytes))] if ok && len(m.followBy) > 0 { - tmp = tmp[:len(tmp)-1] - continue + found := []byte(nil) + for k := range m.followBy { + v, ok := output[k] + if !ok { + continue + } + found = v.value + break + } + if found != nil { + tmp = tmp[:len(tmp)-stepBack] + printf("Step back: %q + %q\n", string(tmp), string(found)) + continue + } } + break } else { if i < printUntil { printf("FOLLOW: none after %q\n", string(m.value)) @@ -266,12 +296,15 @@ func buildDict(input [][]byte, o Options) ([]byte, error) { break } sort.Slice(sortedFollow, func(i, j int) bool { + if sortedFollow[i].n == sortedFollow[j].n { + return sortedFollow[i].offset > sortedFollow[j].offset + } return sortedFollow[i].n > sortedFollow[j].n }) nh = sortedFollow[0].hash stopAfter = sortedFollow[0].n < wantLen if stopAfter && i < printUntil { - printf("FOLLOW: %d > %d after %q\n", sortedFollow[0].n, wantLen, string(m.value)) + printf("FOLLOW: %d < %d after %q. 
Stopping after this.\n", sortedFollow[0].n, wantLen, string(m.value)) } } m, ok = output[nh] @@ -303,7 +336,7 @@ func buildDict(input [][]byte, o Options) ([]byte, error) { var t8 [8]byte copy(t8[:], tmp[j:]) if i < printUntil { - printf("* POST DELETE %q\n", string(t8[:hashBytes])) + //printf("* POST DELETE %q\n", string(t8[:hashBytes])) } delete(output, hashLen(binary.LittleEndian.Uint64(t8[:]), 32, uint8(hashBytes))) } diff --git a/zstd/dict.go b/zstd/dict.go index e7def04dd1..bb46ef7224 100644 --- a/zstd/dict.go +++ b/zstd/dict.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "math" + "sort" "github.com/klauspost/compress/huff0" ) @@ -220,6 +221,7 @@ func BuildDict(o BuildDictOptions) ([]byte, error) { seqs := 0 nUsed := 0 litTotal := 0 + newOffsets := make(map[uint32]int, 1000) for _, b := range contents { block.reset(nil) if len(b) < 8 { @@ -235,7 +237,55 @@ func BuildDict(o BuildDictOptions) ([]byte, error) { addHist(&ll, block.coders.llEnc.Histogram()) addHist(&ml, block.coders.mlEnc.Histogram()) addHist(&of, block.coders.ofEnc.Histogram()) + for i, seq := range block.sequences { + if i > 3 { + break + } + offset := seq.offset + if offset == 0 { + continue + } + if offset > 3 { + newOffsets[offset-3]++ + } else { + newOffsets[uint32(o.Offsets[offset-1])]++ + } + } + } + // Find most used offsets. 
+ var sortedOffsets []uint32 + for k := range newOffsets { + sortedOffsets = append(sortedOffsets, k) + } + sort.Slice(sortedOffsets, func(i, j int) bool { + a, b := sortedOffsets[i], sortedOffsets[j] + if a == b { + // Prefer the longer offset + return sortedOffsets[i] > sortedOffsets[j] + } + return newOffsets[sortedOffsets[i]] > newOffsets[sortedOffsets[j]] + }) + if len(sortedOffsets) > 3 { + if debug { + fmt.Print("Offsets:") + for i, v := range sortedOffsets { + if i > 20 { + break + } + fmt.Printf("[%d: %d],", v, newOffsets[v]) + } + fmt.Println("") + } + + sortedOffsets = sortedOffsets[:3] + } + for i, v := range sortedOffsets { + o.Offsets[i] = int(v) + } + if debug { + fmt.Println("New repeat offsets", o.Offsets) } + if nUsed == 0 || seqs == 0 { return nil, fmt.Errorf("%d blocks, %d sequences found", nUsed, seqs) } @@ -301,7 +351,7 @@ func BuildDict(o BuildDictOptions) ([]byte, error) { return nil, err } - // Liteal table + // Literal table avgSize := litTotal if avgSize > huff0.BlockSizeMax/2 { avgSize = huff0.BlockSizeMax / 2 @@ -410,6 +460,18 @@ func BuildDict(o BuildDictOptions) ([]byte, error) { fmt.Println("Content size:", i.ContentSize()) fmt.Println("Encoder:", i.LitEncoder() != nil) fmt.Println("Offsets:", i.Offsets()) + enc, err := NewWriter(nil, WithEncoderDict(out.Bytes())) + if err != nil { + panic(err) + } + defer enc.Close() + var dst []byte + var totalSize int + for _, b := range contents { + dst = enc.EncodeAll(b, dst[:0]) + totalSize += len(dst) + } + fmt.Println("Compressed size:", totalSize) } return out.Bytes(), nil } From 9663333ffddb25658b5a59853a403cb21c046eea Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Tue, 29 Aug 2023 12:09:47 +0200 Subject: [PATCH 6/7] Add docs and zstd level and compat. 
--- dict/README.md | 108 +++++++++++++++++++++++++++++++ dict/builder.go | 58 ++++++++--------- dict/cmd/builddict/main.go | 45 +++++++++---- zstd/dict.go | 129 ++++++++++++++++++++++++++----------- 4 files changed, 259 insertions(+), 81 deletions(-) create mode 100644 dict/README.md diff --git a/dict/README.md b/dict/README.md new file mode 100644 index 0000000000..9d51ce2693 --- /dev/null +++ b/dict/README.md @@ -0,0 +1,108 @@ +# Dictionary builder + +This is an *experimental* dictionary builder for Zstandard, S2, LZ4 and more. + +This diverges from the Zstandard dictionary builder, and may have some failure scenarios for very small or uniform inputs. + +Dictionaries returned should all be valid, but if very little data is supplied, it may not be able to generate a dictionary. + +With a large, diverse sample set, it will generate a dictionary that can compete with the Zstandard dictionary builder, +but for very similar data it will not be able to generate a dictionary that is as good. + +Feedback is welcome. + +## Usage + +First of all a collection of *samples* must be collected. + +These samples should be representative of the input data and should not contain any complete duplicates. + +Only the *beginning* of the samples is important, the rest can be truncated. +Beyond something like 64KB the input is not important anymore. +The commandline tool can do this truncation for you. + +## Command line + +To install the command line tool run: + +``` +$ go install github.com/klauspost/compress/dict/cmd/builddict@latest +``` + +Collect the samples in a directory, for example `samples/`. + +Then run the command line tool. Basic usage is just to pass the directory with the samples: + +``` +$ builddict samples/ +``` + +This will build a Zstandard dictionary and write it to `dictionary.bin` in the current folder. 
+ +The dictionary can be used with the Zstandard command line tool: + +``` +$ zstd -D dictionary.bin input +``` + +### Options + +The command line tool has a few options: + +- `-format`. Output type. "zstd" "s2" or "raw". Default "zstd". + +Output a dictionary in Zstandard format, S2 format or raw bytes. +The raw bytes can be used with Deflate, LZ4, etc. + +- `-hash` Hash bytes match length. Minimum match length. Must be 4-8 (inclusive) Default 6. + +The hash bytes are used to define the shortest matches to look for. +Shorter matches can generate a more fractured dictionary with less compression, but can for certain inputs be better. +Usually lengths around 6-8 are best. + +- `-len` Specify custom output size. Default 114688. +- `-max` Max input length to index per input file. Default 32768. All inputs are truncated to this. +- `-o` Output name. Default `dictionary.bin`. +- `-q` Do not print progress +- `-dictID` zstd dictionary ID. 0 will be random. Default 0. +- `-zcompat` Generate dictionary compatible with zstd 1.5.5 and older. Default true. +- `-zlevel` Zstandard compression level. + +The Zstandard compression level to use when compressing the samples. +The dictionary will be built using the specified encoder level, +which will reflect speed and make the dictionary tailored for that level. +Default will use level 4 (best). + +Valid values are 1-4, where 1 = fastest, 2 = default, 3 = better, 4 = best. + +## Library + +The `github.com/klauspost/compress/dict` package can be used to build dictionaries in code. +The caller must supply a collection of (pre-truncated) samples, and the options to use. +The options largely correspond to the command line options. + +```Go +package main + +import ( + "github.com/klauspost/compress/dict" + "github.com/klauspost/compress/zstd" +) + +func main() { + var samples [][]byte + + // ... Fill samples with representative data. 
+ + dict, err := dict.BuildZstdDict(samples, dict.Options{ + HashBytes: 6, + MaxDictSize: 114688, + ZstdDictID: 0, // Random + ZstdDictCompat: false, + ZstdLevel: zstd.SpeedBestCompression, + }) + // ... Handle error, etc. +} +``` + +There are similar functions for S2 and raw dictionaries (`BuildS2Dict` and `BuildRawDict`). diff --git a/dict/builder.go b/dict/builder.go index 1c7f0af0f4..78d8583801 100644 --- a/dict/builder.go +++ b/dict/builder.go @@ -7,6 +7,7 @@ package dict import ( "bytes" "encoding/binary" + "errors" "fmt" "io" "math/rand" @@ -44,6 +45,16 @@ type Options struct { // Leave at zero to generate a random ID. ZstdDictID uint32 + // ZstdDictCompat will make the dictionary compatible with Zstd v1.5.5 and earlier. + // See https://github.com/facebook/zstd/issues/3724 + ZstdDictCompat bool + + // Use the specified encoder level for Zstandard dictionaries. + // The dictionary will be built using the specified encoder level, + // which will reflect speed and make the dictionary tailored for that level. + // If not set zstd.SpeedBestCompression will be used. + ZstdLevel zstd.EncoderLevel + outFormat int } @@ -53,6 +64,7 @@ const ( formatS2 ) +// BuildZstdDict will build a Zstandard dictionary from the provided input. func BuildZstdDict(input [][]byte, o Options) ([]byte, error) { o.outFormat = formatZstd if o.ZstdDictID == 0 { @@ -62,11 +74,17 @@ func BuildZstdDict(input [][]byte, o Options) ([]byte, error) { return buildDict(input, o) } +// BuildS2Dict will build a S2 dictionary from the provided input. func BuildS2Dict(input [][]byte, o Options) ([]byte, error) { o.outFormat = formatS2 + if o.MaxDictSize > s2.MaxDictSize { + return nil, errors.New("max dict size too large") + } return buildDict(input, o) } +// BuildRawDict will build a raw dictionary from the provided input. +// This can be used for deflate, lz4 and others. 
func BuildRawDict(input [][]byte, o Options) ([]byte, error) { o.outFormat = formatRaw return buildDict(input, o) @@ -425,34 +443,7 @@ func buildDict(input [][]byte, o Options) ([]byte, error) { } return dict.Bytes(), nil } - /* - avgSize := 256 - println("\nHuffman: literal total:", remainTotal, "normalized counts on remainder size:", avgSize) - huffBuff := make([]byte, 0, avgSize) - // Target size - div := remainTotal / avgSize - if div < 1 { - div = 1 - } - for i, n := range remainCnt[:] { - if n > 0 { - n = n / div - if n == 0 { - n = 1 - } - huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...) - fmt.Printf("[%d: %d], ", i, n) - } - } - println("") - scratch := &huff0.Scratch{} - _, _, err := huff0.Compress1X(huffBuff, scratch) - if err != nil { - // TODO: Handle RLE - return nil, err - } - println("Huffman table:", len(scratch.OutTable), "bytes") - */ + offsetsZstd := [3]int{1, 4, 8} for i, off := range firstOffsets { if i >= 3 || off == 0 || off >= out.Len() { @@ -462,10 +453,13 @@ func buildDict(input [][]byte, o Options) ([]byte, error) { } println("\nCompressing. Offsets:", offsetsZstd) return zstd.BuildDict(zstd.BuildDictOptions{ - ID: o.ZstdDictID, - Contents: input, - History: out.Bytes(), - Offsets: offsetsZstd, + ID: o.ZstdDictID, + Contents: input, + History: out.Bytes(), + Offsets: offsetsZstd, + CompatV155: o.ZstdDictCompat, + Level: o.ZstdLevel, + DebugOut: o.Output, }) } diff --git a/dict/cmd/builddict/main.go b/dict/cmd/builddict/main.go index c80ec93f9c..f9c3428469 100644 --- a/dict/cmd/builddict/main.go +++ b/dict/cmd/builddict/main.go @@ -11,27 +11,34 @@ import ( "log" "os" "path/filepath" + "runtime/debug" "github.com/klauspost/compress/dict" + "github.com/klauspost/compress/zstd" ) var ( - wantLenFlag = flag.Int("len", 112<<10, "Specify custom output size") - wantHashBytes = flag.Int("hash", 6, "Hash bytes match length. 
Minimum match length.") - wantMaxBytes = flag.Int("max", 32<<10, "Max input length to index per input file") - wantOutput = flag.String("o", "dictionary.bin", "Output name") - wantFormat = flag.String("format", "zstd", `Output type. "zstd" "s2" or "raw"`) - wantZstdID = flag.Uint("zstdid", 0, "Zstd dictionary ID. 0 will be random") - quiet = flag.Bool("q", false, "Do not print progress") + wantLenFlag = flag.Int("len", 112<<10, "Specify custom output size") + wantHashBytes = flag.Int("hash", 6, "Hash bytes match length. Minimum match length.") + wantMaxBytes = flag.Int("max", 32<<10, "Max input length to index per input file") + wantOutput = flag.String("o", "dictionary.bin", "Output name") + wantFormat = flag.String("format", "zstd", `Output type. "zstd" "s2" or "raw"`) + wantZstdID = flag.Uint("dictID", 0, "Zstd dictionary ID. Default (0) will be random") + wantZstdCompat = flag.Bool("zcompat", true, "Generate dictionary compatible with zstd 1.5.5 and older") + wantZstdLevel = flag.Int("zlevel", 0, "Zstd compression level. 0-4") + quiet = flag.Bool("q", false, "Do not print progress") ) func main() { flag.Parse() + debug.SetGCPercent(25) o := dict.Options{ - MaxDictSize: *wantLenFlag, - HashBytes: *wantHashBytes, - Output: os.Stdout, - ZstdDictID: uint32(*wantZstdID), + MaxDictSize: *wantLenFlag, + HashBytes: *wantHashBytes, + Output: os.Stdout, + ZstdDictID: uint32(*wantZstdID), + ZstdDictCompat: *wantZstdCompat, + ZstdLevel: zstd.EncoderLevel(*wantZstdLevel), } if *wantOutput == "" || *quiet { o.Output = nil @@ -39,11 +46,15 @@ func main() { var input [][]byte base := flag.Arg(0) if base == "" { + flag.Usage() log.Fatal("no path with files specified") } // Index ALL hashes in all files. 
- filepath.Walk(base, func(path string, info os.FileInfo, err error) error { + err := filepath.Walk(base, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } if info.IsDir() { return nil } @@ -58,14 +69,22 @@ func main() { if len(b) < 8 { return nil } + if len(b) == 0 { + return nil + } input = append(input, b) if !*quiet { fmt.Print("\r"+info.Name(), " read...") } return nil }) + if err != nil { + log.Fatal(err) + } + if len(input) == 0 { + log.Fatal("no files read") + } var out []byte - var err error switch *wantFormat { case "zstd": out, err = dict.BuildZstdDict(input, o) diff --git a/zstd/dict.go b/zstd/dict.go index bb46ef7224..8d5567fe64 100644 --- a/zstd/dict.go +++ b/zstd/dict.go @@ -174,13 +174,42 @@ type BuildDictOptions struct { // Offsets to use. Offsets [3]int + + // CompatV155 will make the dictionary compatible with Zstd v1.5.5 and earlier. + // See https://github.com/facebook/zstd/issues/3724 + CompatV155 bool + + // Use the specified encoder level. + // The dictionary will be built using the specified encoder level, + // which will reflect speed and make the dictionary tailored for that level. + // If not set SpeedBestCompression will be used. + Level EncoderLevel + + // DebugOut will write stats and other details here if set. + DebugOut io.Writer } func BuildDict(o BuildDictOptions) ([]byte, error) { initPredefined() hist := o.History contents := o.Contents - const debug = false + debug := o.DebugOut != nil + println := func(args ...interface{}) { + if o.DebugOut != nil { + fmt.Fprintln(o.DebugOut, args...) + } + } + printf := func(s string, args ...interface{}) { + if o.DebugOut != nil { + fmt.Fprintf(o.DebugOut, s, args...) + } + } + print := func(args ...interface{}) { + if o.DebugOut != nil { + fmt.Fprint(o.DebugOut, args...) 
+ } + } + if int64(len(hist)) > dictMaxLength { return nil, fmt.Errorf("dictionary of size %d > %d", len(hist), int64(dictMaxLength)) } @@ -201,7 +230,19 @@ func BuildDict(o BuildDictOptions) ([]byte, error) { } block := blockEnc{lowMem: false} block.init() - enc := &bestFastEncoder{fastBase: fastBase{maxMatchOff: int32(maxMatchLen), bufferReset: math.MaxInt32 - int32(maxMatchLen*2), lowMem: false}} + enc := encoder(&bestFastEncoder{fastBase: fastBase{maxMatchOff: int32(maxMatchLen), bufferReset: math.MaxInt32 - int32(maxMatchLen*2), lowMem: false}}) + if o.Level != 0 { + eOpts := encoderOptions{ + level: o.Level, + blockSize: maxMatchLen, + windowSize: maxMatchLen, + dict: &d, + lowMem: false, + } + enc = eOpts.encoder() + } else { + o.Level = SpeedBestCompression + } var ( remain [256]int ll [256]int @@ -267,14 +308,14 @@ func BuildDict(o BuildDictOptions) ([]byte, error) { }) if len(sortedOffsets) > 3 { if debug { - fmt.Print("Offsets:") + print("Offsets:") for i, v := range sortedOffsets { if i > 20 { break } - fmt.Printf("[%d: %d],", v, newOffsets[v]) + printf("[%d: %d],", v, newOffsets[v]) } - fmt.Println("") + println("") } sortedOffsets = sortedOffsets[:3] @@ -283,14 +324,14 @@ func BuildDict(o BuildDictOptions) ([]byte, error) { o.Offsets[i] = int(v) } if debug { - fmt.Println("New repeat offsets", o.Offsets) + println("New repeat offsets", o.Offsets) } if nUsed == 0 || seqs == 0 { return nil, fmt.Errorf("%d blocks, %d sequences found", nUsed, seqs) } if debug { - fmt.Println("Sequences:", seqs, "Blocks:", nUsed, "Literals:", litTotal) + println("Sequences:", seqs, "Blocks:", nUsed, "Literals:", litTotal) } if seqs/nUsed < 512 { // Use 512 as minimum. 
@@ -325,26 +366,26 @@ func BuildDict(o BuildDictOptions) ([]byte, error) { return nil, err } if debug { - fmt.Println("RAW:", dst.count[:maxSym+1], "NORM:", dst.norm[:maxSym+1], "LEN:", fakeLength) + println("RAW:", dst.count[:maxSym+1], "NORM:", dst.norm[:maxSym+1], "LEN:", fakeLength) } return dst.writeCount(nil) } if debug { - fmt.Print("Literal lengths: ") + print("Literal lengths: ") } llTable, err := copyHist(block.coders.llEnc, &ll) if err != nil { return nil, err } if debug { - fmt.Print("Match lengths: ") + print("Match lengths: ") } mlTable, err := copyHist(block.coders.mlEnc, &ml) if err != nil { return nil, err } if debug { - fmt.Print("Offsets: ") + print("Offsets: ") } ofTable, err := copyHist(block.coders.ofEnc, &of) if err != nil { @@ -363,7 +404,7 @@ func BuildDict(o BuildDictOptions) ([]byte, error) { div = 1 } if debug { - fmt.Println("Huffman weights:") + println("Huffman weights:") } for i, n := range remain[:] { if n > 0 { @@ -374,11 +415,11 @@ func BuildDict(o BuildDictOptions) ([]byte, error) { } huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...) if debug { - fmt.Printf("[%d: %d], ", i, n) + printf("[%d: %d], ", i, n) } } } - if remain[255]/div == 0 { + if o.CompatV155 && remain[255]/div == 0 { huffBuff = append(huffBuff, 255) } scratch := &huff0.Scratch{TableLog: 11} @@ -389,12 +430,12 @@ func BuildDict(o BuildDictOptions) ([]byte, error) { break } if debug { - fmt.Printf("Try %d: Huffman error: %v\n", tries+1, err) + printf("Try %d: Huffman error: %v\n", tries+1, err) } huffBuff = huffBuff[:0] if tries == 250 { if debug { - fmt.Println("Huffman: Bailing out with predefined table") + println("Huffman: Bailing out with predefined table") } // Bail out.... Just generate something @@ -413,9 +454,12 @@ func BuildDict(o BuildDictOptions) ([]byte, error) { huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...) 
} } - if i == 255 && (len(huffBuff) == 0 || huffBuff[len(huffBuff)-1] != 255) { - huffBuff = append(huffBuff, 255) - } + } + if o.CompatV155 && len(huffBuff) > 0 && huffBuff[len(huffBuff)-1] != 255 { + huffBuff = append(huffBuff, 255) + } + if len(huffBuff) == 0 { + huffBuff = append(huffBuff, 0, 255) } } if errors.Is(err, huff0.ErrUseRLE) { @@ -435,10 +479,10 @@ func BuildDict(o BuildDictOptions) ([]byte, error) { out.Write(binary.LittleEndian.AppendUint32(nil, o.ID)) out.Write(scratch.OutTable) if debug { - fmt.Println("huff table:", len(scratch.OutTable), "bytes") - fmt.Println("of table:", len(ofTable), "bytes") - fmt.Println("ml table:", len(mlTable), "bytes") - fmt.Println("ll table:", len(llTable), "bytes") + println("huff table:", len(scratch.OutTable), "bytes") + println("of table:", len(ofTable), "bytes") + println("ml table:", len(mlTable), "bytes") + println("ll table:", len(llTable), "bytes") } out.Write(ofTable) out.Write(mlTable) @@ -456,22 +500,35 @@ func BuildDict(o BuildDictOptions) ([]byte, error) { if err != nil { panic(err) } - fmt.Println("ID:", i.ID()) - fmt.Println("Content size:", i.ContentSize()) - fmt.Println("Encoder:", i.LitEncoder() != nil) - fmt.Println("Offsets:", i.Offsets()) - enc, err := NewWriter(nil, WithEncoderDict(out.Bytes())) - if err != nil { - panic(err) - } - defer enc.Close() - var dst []byte + println("ID:", i.ID()) + println("Content size:", i.ContentSize()) + println("Encoder:", i.LitEncoder() != nil) + println("Offsets:", i.Offsets()) var totalSize int for _, b := range contents { - dst = enc.EncodeAll(b, dst[:0]) - totalSize += len(dst) + totalSize += len(b) + } + + encWith := func(opts ...EOption) int { + enc, err := NewWriter(nil, opts...) 
+ if err != nil { + panic(err) + } + defer enc.Close() + var dst []byte + var totalSize int + for _, b := range contents { + dst = enc.EncodeAll(b, dst[:0]) + totalSize += len(dst) + } + return totalSize } - fmt.Println("Compressed size:", totalSize) + plain := encWith(WithEncoderLevel(o.Level)) + withDict := encWith(WithEncoderLevel(o.Level), WithEncoderDict(out.Bytes())) + println("Input size:", totalSize) + println("Plain Compressed:", plain) + println("Dict Compressed:", withDict) + println("Saved:", plain-withDict, (plain-withDict)/len(contents), "bytes per input (rounded down)") } return out.Bytes(), nil } From 5b71795b7fc9ba61a3a550260837420b35c4309d Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Tue, 29 Aug 2023 13:19:11 +0200 Subject: [PATCH 7/7] Final tweaks --- dict/README.md | 2 +- s2/dict.go | 13 +++++-------- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/dict/README.md b/dict/README.md index 9d51ce2693..e4784fb684 100644 --- a/dict/README.md +++ b/dict/README.md @@ -1,6 +1,6 @@ # Dictionary builder -This is an *experimental* dictionary builder for Zstandard, S2, LZ4 and more. +This is an *experimental* dictionary builder for Zstandard, S2, LZ4, deflate and more. This diverges from the Zstandard dictionary builder, and may have some failure scenarios for very small or uniform inputs. diff --git a/s2/dict.go b/s2/dict.go index 93e858ba65..f125ad0963 100644 --- a/s2/dict.go +++ b/s2/dict.go @@ -106,12 +106,12 @@ func MakeDict(data []byte, searchStart []byte) *Dict { return &d } -// MakeDict will create a dictionary. -// 'data' must be at least MinDictSize. -// If data is longer than MaxDictSize only the last MaxDictSize bytes will be used. -// A manual first repeat value must be provided. It cannot be 0. +// MakeDictManual will create a dictionary. +// 'data' must be at least MinDictSize and less than or equal to MaxDictSize. +// A manual first repeat index into data must be provided. +// It must be less than len(data)-8. 
func MakeDictManual(data []byte, firstIdx uint16) *Dict { - if len(data) == 0 || int(firstIdx) > len(data)-8 || len(data) > MaxDictSize { + if len(data) < MinDictSize || int(firstIdx) >= len(data)-8 || len(data) > MaxDictSize { return nil } var d Dict @@ -120,9 +120,6 @@ func MakeDictManual(data []byte, firstIdx uint16) *Dict { if cap(d.dict) < len(d.dict)+16 { d.dict = append(make([]byte, 0, len(d.dict)+16), d.dict...) } - if len(dict) < MinDictSize { - return nil - } d.repeat = int(firstIdx) return &d