diff --git a/dict/README.md b/dict/README.md new file mode 100644 index 0000000000..e4784fb684 --- /dev/null +++ b/dict/README.md @@ -0,0 +1,108 @@ +# Dictionary builder + +This is an *experimental* dictionary builder for Zstandard, S2, LZ4, deflate and more. + +This diverges from the Zstandard dictionary builder, and may have some failure scenarios for very small or uniform inputs. + +Dictionaries returned should all be valid, but if very little data is supplied, it may not be able to generate a dictionary. + +With a large, diverse sample set, it will generate a dictionary that can compete with the Zstandard dictionary builder, +but for very similar data it will not be able to generate a dictionary that is as good. + +Feedback is welcome. + +## Usage + +First of all a collection of *samples* must be collected. + +These samples should be representative of the input data and should not contain any complete duplicates. + +Only the *beginning* of the samples is important, the rest can be truncated. +Beyond something like 64KB the input is not important anymore. +The commandline tool can do this truncation for you. + +## Command line + +To install the command line tool run: + +``` +$ go install github.com/klauspost/compress/dict/cmd/builddict@latest +``` + +Collect the samples in a directory, for example `samples/`. + +Then run the command line tool. Basic usage is just to pass the directory with the samples: + +``` +$ builddict samples/ +``` + +This will build a Zstandard dictionary and write it to `dictionary.bin` in the current folder. + +The dictionary can be used with the Zstandard command line tool: + +``` +$ zstd -D dictionary.bin input +``` + +### Options + +The command line tool has a few options: + +- `-format`. Output type. "zstd" "s2" or "raw". Default "zstd". + +Output a dictionary in Zstandard format, S2 format or raw bytes. +The raw bytes can be used with Deflate, LZ4, etc. + +- `-hash` Hash bytes match length. Minimum match length. 
Must be 4-8 (inclusive) Default 6. + +The hash bytes are used to define the shortest matches to look for. +Shorter matches can generate a more fractured dictionary with less compression, but can for certain inputs be better. +Usually lengths around 6-8 are best. + +- `-len` Specify custom output size. Default 114688. +- `-max` Max input length to index per input file. Default 32768. All inputs are truncated to this. +- `-o` Output name. Default `dictionary.bin`. +- `-q` Do not print progress +- `-dictID` zstd dictionary ID. 0 will be random. Default 0. +- `-zcompat` Generate dictionary compatible with zstd 1.5.5 and older. Default true. +- `-zlevel` Zstandard compression level. + +The Zstandard compression level to use when compressing the samples. +The dictionary will be built using the specified encoder level, +which will reflect speed and make the dictionary tailored for that level. +Default will use level 4 (best). + +Valid values are 1-4, where 1 = fastest, 2 = default, 3 = better, 4 = best. + +## Library + +The `github.com/klauspost/compress/dict` package can be used to build dictionaries in code. +The caller must supply a collection of (pre-truncated) samples, and the options to use. +The options largely correspond to the command line options. + +```Go +package main + +import ( + "github.com/klauspost/compress/dict" + "github.com/klauspost/compress/zstd" +) + +func main() { + var samples [][]byte + + // ... Fill samples with representative data. + + dict, err := dict.BuildZstdDict(samples, dict.Options{ + HashBytes: 6, + MaxDictSize: 114688, + ZstdDictID: 0, // Random + ZstdDictCompat: false, + ZstdLevel: zstd.SpeedBestCompression, + }) + // ... Handle error, etc. +} +``` + +There are similar functions for S2 and raw dictionaries (`BuildS2Dict` and `BuildRawDict`). diff --git a/dict/builder.go b/dict/builder.go new file mode 100644 index 0000000000..78d8583801 --- /dev/null +++ b/dict/builder.go @@ -0,0 +1,535 @@ +// Copyright 2023+ Klaus Post. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dict + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "math/rand" + "sort" + "time" + + "github.com/klauspost/compress/s2" + "github.com/klauspost/compress/zstd" +) + +type match struct { + hash uint32 + n uint32 + offset int64 +} + +type matchValue struct { + value []byte + followBy map[uint32]uint32 + preceededBy map[uint32]uint32 +} + +type Options struct { + // MaxDictSize is the max size of the backreference dictionary. + MaxDictSize int + + // HashBytes is the minimum length to index. + // Must be >=4 and <=8 + HashBytes int + + // Debug output + Output io.Writer + + // ZstdDictID is the Zstd dictionary ID to use. + // Leave at zero to generate a random ID. + ZstdDictID uint32 + + // ZstdDictCompat will make the dictionary compatible with Zstd v1.5.5 and earlier. + // See https://github.com/facebook/zstd/issues/3724 + ZstdDictCompat bool + + // Use the specified encoder level for Zstandard dictionaries. + // The dictionary will be built using the specified encoder level, + // which will reflect speed and make the dictionary tailored for that level. + // If not set zstd.SpeedBestCompression will be used. + ZstdLevel zstd.EncoderLevel + + outFormat int +} + +const ( + formatRaw = iota + formatZstd + formatS2 +) + +// BuildZstdDict will build a Zstandard dictionary from the provided input. +func BuildZstdDict(input [][]byte, o Options) ([]byte, error) { + o.outFormat = formatZstd + if o.ZstdDictID == 0 { + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + o.ZstdDictID = 32768 + uint32(rng.Int31n((1<<31)-32768)) + } + return buildDict(input, o) +} + +// BuildS2Dict will build a S2 dictionary from the provided input. 
+func BuildS2Dict(input [][]byte, o Options) ([]byte, error) { + o.outFormat = formatS2 + if o.MaxDictSize > s2.MaxDictSize { + return nil, errors.New("max dict size too large") + } + return buildDict(input, o) +} + +// BuildRawDict will build a raw dictionary from the provided input. +// This can be used for deflate, lz4 and others. +func BuildRawDict(input [][]byte, o Options) ([]byte, error) { + o.outFormat = formatRaw + return buildDict(input, o) +} + +func buildDict(input [][]byte, o Options) ([]byte, error) { + matches := make(map[uint32]uint32) + offsets := make(map[uint32]int64) + var total uint64 + + wantLen := o.MaxDictSize + hashBytes := o.HashBytes + if len(input) == 0 { + return nil, fmt.Errorf("no input provided") + } + if hashBytes < 4 || hashBytes > 8 { + return nil, fmt.Errorf("HashBytes must be >= 4 and <= 8") + } + println := func(args ...interface{}) { + if o.Output != nil { + fmt.Fprintln(o.Output, args...) + } + } + printf := func(s string, args ...interface{}) { + if o.Output != nil { + fmt.Fprintf(o.Output, s, args...) + } + } + found := make(map[uint32]struct{}) + for i, b := range input { + for k := range found { + delete(found, k) + } + for i := range b { + rem := b[i:] + if len(rem) < 8 { + break + } + h := hashLen(binary.LittleEndian.Uint64(rem), 32, uint8(hashBytes)) + if _, ok := found[h]; ok { + // Only count first occurrence + continue + } + matches[h]++ + offsets[h] += int64(i) + total++ + found[h] = struct{}{} + } + printf("\r input %d indexed...", i) + } + threshold := uint32(total / uint64(len(matches))) + println("\nTotal", total, "match", len(matches), "avg", threshold) + sorted := make([]match, 0, len(matches)/2) + for k, v := range matches { + if v <= threshold { + continue + } + sorted = append(sorted, match{hash: k, n: v, offset: offsets[k]}) + } + sort.Slice(sorted, func(i, j int) bool { + if true { + // Group very similar counts together and emit low offsets first. 
+ // This will keep together strings that are very similar. + deltaN := int(sorted[i].n) - int(sorted[j].n) + if deltaN < 0 { + deltaN = -deltaN + } + if uint32(deltaN) < sorted[i].n/32 { + return sorted[i].offset < sorted[j].offset + } + } else { + if sorted[i].n == sorted[j].n { + return sorted[i].offset < sorted[j].offset + } + } + return sorted[i].n > sorted[j].n + }) + println("Sorted len:", len(sorted)) + if len(sorted) > wantLen { + sorted = sorted[:wantLen] + } + lowestOcc := sorted[len(sorted)-1].n + println("Cropped len:", len(sorted), "Lowest occurrence:", lowestOcc) + + wantMatches := make(map[uint32]uint32, len(sorted)) + for _, v := range sorted { + wantMatches[v.hash] = v.n + } + + output := make(map[uint32]matchValue, len(sorted)) + var remainCnt [256]int + var remainTotal int + var firstOffsets []int + for i, b := range input { + for i := range b { + rem := b[i:] + if len(rem) < 8 { + break + } + var prev []byte + if i > hashBytes { + prev = b[i-hashBytes:] + } + + h := hashLen(binary.LittleEndian.Uint64(rem), 32, uint8(hashBytes)) + if _, ok := wantMatches[h]; !ok { + remainCnt[rem[0]]++ + remainTotal++ + continue + } + mv := output[h] + if len(mv.value) == 0 { + var tmp = make([]byte, hashBytes) + copy(tmp[:], rem) + mv.value = tmp[:] + } + if mv.followBy == nil { + mv.followBy = make(map[uint32]uint32, 4) + mv.preceededBy = make(map[uint32]uint32, 4) + } + if len(rem) > hashBytes+8 { + // Check if we should add next as well. + hNext := hashLen(binary.LittleEndian.Uint64(rem[hashBytes:]), 32, uint8(hashBytes)) + if _, ok := wantMatches[hNext]; ok { + mv.followBy[hNext]++ + } + } + if len(prev) >= 8 { + // Check if we should prev next as well. 
+ hPrev := hashLen(binary.LittleEndian.Uint64(prev), 32, uint8(hashBytes)) + if _, ok := wantMatches[hPrev]; ok { + mv.preceededBy[hPrev]++ + } + } + output[h] = mv + } + printf("\rinput %d re-indexed...", i) + } + println("") + dst := make([][]byte, 0, wantLen/hashBytes) + added := 0 + const printUntil = 500 + for i, e := range sorted { + if added > o.MaxDictSize { + println("Ending. Next Occurrence:", e.n) + break + } + m, ok := output[e.hash] + if !ok { + // Already added + continue + } + wantLen := e.n / uint32(hashBytes) / 4 + if wantLen <= lowestOcc { + wantLen = lowestOcc + } + + var tmp = make([]byte, 0, hashBytes*2) + { + sortedPrev := make([]match, 0, len(m.followBy)) + for k, v := range m.preceededBy { + if _, ok := output[k]; v < wantLen || !ok { + continue + } + sortedPrev = append(sortedPrev, match{ + hash: k, + n: v, + }) + } + if len(sortedPrev) > 0 { + sort.Slice(sortedPrev, func(i, j int) bool { + return sortedPrev[i].n > sortedPrev[j].n + }) + bestPrev := output[sortedPrev[0].hash] + tmp = append(tmp, bestPrev.value...) + } + } + tmp = append(tmp, m.value...) + delete(output, e.hash) + + sortedFollow := make([]match, 0, len(m.followBy)) + for { + var nh uint32 // Next hash + stopAfter := false + { + sortedFollow = sortedFollow[:0] + for k, v := range m.followBy { + if _, ok := output[k]; !ok { + continue + } + sortedFollow = append(sortedFollow, match{ + hash: k, + n: v, + offset: offsets[k], + }) + } + if len(sortedFollow) == 0 { + // Step back + // Extremely small impact, but helps longer hashes a bit. 
+ const stepBack = 2 + if stepBack > 0 && len(tmp) >= hashBytes+stepBack { + var t8 [8]byte + copy(t8[:], tmp[len(tmp)-hashBytes-stepBack:]) + m, ok = output[hashLen(binary.LittleEndian.Uint64(t8[:]), 32, uint8(hashBytes))] + if ok && len(m.followBy) > 0 { + found := []byte(nil) + for k := range m.followBy { + v, ok := output[k] + if !ok { + continue + } + found = v.value + break + } + if found != nil { + tmp = tmp[:len(tmp)-stepBack] + printf("Step back: %q + %q\n", string(tmp), string(found)) + continue + } + } + break + } else { + if i < printUntil { + printf("FOLLOW: none after %q\n", string(m.value)) + } + } + break + } + sort.Slice(sortedFollow, func(i, j int) bool { + if sortedFollow[i].n == sortedFollow[j].n { + return sortedFollow[i].offset > sortedFollow[j].offset + } + return sortedFollow[i].n > sortedFollow[j].n + }) + nh = sortedFollow[0].hash + stopAfter = sortedFollow[0].n < wantLen + if stopAfter && i < printUntil { + printf("FOLLOW: %d < %d after %q. Stopping after this.\n", sortedFollow[0].n, wantLen, string(m.value)) + } + } + m, ok = output[nh] + if !ok { + break + } + if len(tmp) > 0 { + // Delete all hashes that are in the current string to avoid stuttering. + var toDel [16 + 8]byte + copy(toDel[:], tmp[len(tmp)-hashBytes:]) + copy(toDel[hashBytes:], m.value) + for i := range toDel[:hashBytes*2] { + delete(output, hashLen(binary.LittleEndian.Uint64(toDel[i:]), 32, uint8(hashBytes))) + } + } + tmp = append(tmp, m.value...) + //delete(output, nh) + if stopAfter { + // Last entry was no significant. + break + } + } + if i < printUntil { + printf("ENTRY %d: %q (%d occurrences, cutoff %d)\n", i, string(tmp), e.n, wantLen) + } + // Delete substrings already added. 
+ if len(tmp) > hashBytes { + for j := range tmp[:len(tmp)-hashBytes+1] { + var t8 [8]byte + copy(t8[:], tmp[j:]) + if i < printUntil { + //printf("* POST DELETE %q\n", string(t8[:hashBytes])) + } + delete(output, hashLen(binary.LittleEndian.Uint64(t8[:]), 32, uint8(hashBytes))) + } + } + dst = append(dst, tmp) + added += len(tmp) + // Find offsets + // TODO: This can be better if done as a global search. + if len(firstOffsets) < 3 { + if len(tmp) > 16 { + tmp = tmp[:16] + } + offCnt := make(map[int]int, len(input)) + // Find first offsets + for _, b := range input { + off := bytes.Index(b, tmp) + if off == -1 { + continue + } + offCnt[off]++ + } + for _, off := range firstOffsets { + // Very unlikely, but we deleted it just in case + delete(offCnt, off-added) + } + maxCnt := 0 + maxOffset := 0 + for k, v := range offCnt { + if v == maxCnt && k > maxOffset { + // Prefer the longer offset on ties , since it is more expensive to encode + maxCnt = v + maxOffset = k + continue + } + + if v > maxCnt { + maxCnt = v + maxOffset = k + } + } + if maxCnt > 1 { + firstOffsets = append(firstOffsets, maxOffset+added) + println(" - Offset:", len(firstOffsets), "at", maxOffset+added, "count:", maxCnt, "total added:", added, "src index", maxOffset) + } + } + } + out := bytes.NewBuffer(nil) + written := 0 + for i, toWrite := range dst { + if len(toWrite)+written > wantLen { + toWrite = toWrite[:wantLen-written] + } + dst[i] = toWrite + written += len(toWrite) + if written >= wantLen { + dst = dst[:i+1] + break + } + } + // Write in reverse order. 
+ for i := range dst { + toWrite := dst[len(dst)-i-1] + out.Write(toWrite) + } + if o.outFormat == formatRaw { + return out.Bytes(), nil + } + + if o.outFormat == formatS2 { + dOff := 0 + dBytes := out.Bytes() + if len(dBytes) > s2.MaxDictSize { + dBytes = dBytes[:s2.MaxDictSize] + } + for _, off := range firstOffsets { + myOff := len(dBytes) - off + if myOff < 0 || myOff > s2.MaxDictSrcOffset { + continue + } + dOff = myOff + } + + dict := s2.MakeDictManual(dBytes, uint16(dOff)) + if dict == nil { + return nil, fmt.Errorf("unable to create s2 dictionary") + } + return dict.Bytes(), nil + } + + offsetsZstd := [3]int{1, 4, 8} + for i, off := range firstOffsets { + if i >= 3 || off == 0 || off >= out.Len() { + break + } + offsetsZstd[i] = off + } + println("\nCompressing. Offsets:", offsetsZstd) + return zstd.BuildDict(zstd.BuildDictOptions{ + ID: o.ZstdDictID, + Contents: input, + History: out.Bytes(), + Offsets: offsetsZstd, + CompatV155: o.ZstdDictCompat, + Level: o.ZstdLevel, + DebugOut: o.Output, + }) +} + +const ( + prime3bytes = 506832829 + prime4bytes = 2654435761 + prime5bytes = 889523592379 + prime6bytes = 227718039650203 + prime7bytes = 58295818150454627 + prime8bytes = 0xcf1bbcdcb7a56463 +) + +// hashLen returns a hash of the lowest l bytes of u for a size size of h bytes. +// l must be >=4 and <=8. Any other value will return hash for 4 bytes. +// h should always be <32. +// Preferably h and l should be a constant. +// LENGTH 4 is passed straight through +func hashLen(u uint64, hashLog, mls uint8) uint32 { + switch mls { + case 5: + return hash5(u, hashLog) + case 6: + return hash6(u, hashLog) + case 7: + return hash7(u, hashLog) + case 8: + return hash8(u, hashLog) + default: + return uint32(u) + } +} + +// hash3 returns the hash of the lower 3 bytes of u to fit in a hash table with h bits. +// Preferably h should be a constant and should always be <32. 
+func hash3(u uint32, h uint8) uint32 { + return ((u << (32 - 24)) * prime3bytes) >> ((32 - h) & 31) +} + +// hash4 returns the hash of u to fit in a hash table with h bits. +// Preferably h should be a constant and should always be <32. +func hash4(u uint32, h uint8) uint32 { + return (u * prime4bytes) >> ((32 - h) & 31) +} + +// hash4x64 returns the hash of the lowest 4 bytes of u to fit in a hash table with h bits. +// Preferably h should be a constant and should always be <32. +func hash4x64(u uint64, h uint8) uint32 { + return (uint32(u) * prime4bytes) >> ((32 - h) & 31) +} + +// hash5 returns the hash of the lowest 5 bytes of u to fit in a hash table with h bits. +// Preferably h should be a constant and should always be <64. +func hash5(u uint64, h uint8) uint32 { + return uint32(((u << (64 - 40)) * prime5bytes) >> ((64 - h) & 63)) +} + +// hash6 returns the hash of the lowest 6 bytes of u to fit in a hash table with h bits. +// Preferably h should be a constant and should always be <64. +func hash6(u uint64, h uint8) uint32 { + return uint32(((u << (64 - 48)) * prime6bytes) >> ((64 - h) & 63)) +} + +// hash7 returns the hash of the lowest 7 bytes of u to fit in a hash table with h bits. +// Preferably h should be a constant and should always be <64. +func hash7(u uint64, h uint8) uint32 { + return uint32(((u << (64 - 56)) * prime7bytes) >> ((64 - h) & 63)) +} + +// hash8 returns the hash of u to fit in a hash table with h bits. +// Preferably h should be a constant and should always be <64. +func hash8(u uint64, h uint8) uint32 { + return uint32((u * prime8bytes) >> ((64 - h) & 63)) +} diff --git a/dict/cmd/builddict/main.go b/dict/cmd/builddict/main.go new file mode 100644 index 0000000000..f9c3428469 --- /dev/null +++ b/dict/cmd/builddict/main.go @@ -0,0 +1,112 @@ +// Copyright 2023+ Klaus Post. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package main + +import ( + "flag" + "fmt" + "io" + "log" + "os" + "path/filepath" + "runtime/debug" + + "github.com/klauspost/compress/dict" + "github.com/klauspost/compress/zstd" +) + +var ( + wantLenFlag = flag.Int("len", 112<<10, "Specify custom output size") + wantHashBytes = flag.Int("hash", 6, "Hash bytes match length. Minimum match length.") + wantMaxBytes = flag.Int("max", 32<<10, "Max input length to index per input file") + wantOutput = flag.String("o", "dictionary.bin", "Output name") + wantFormat = flag.String("format", "zstd", `Output type. "zstd" "s2" or "raw"`) + wantZstdID = flag.Uint("dictID", 0, "Zstd dictionary ID. Default (0) will be random") + wantZstdCompat = flag.Bool("zcompat", true, "Generate dictionary compatible with zstd 1.5.5 and older") + wantZstdLevel = flag.Int("zlevel", 0, "Zstd compression level. 0-4") + quiet = flag.Bool("q", false, "Do not print progress") +) + +func main() { + flag.Parse() + debug.SetGCPercent(25) + o := dict.Options{ + MaxDictSize: *wantLenFlag, + HashBytes: *wantHashBytes, + Output: os.Stdout, + ZstdDictID: uint32(*wantZstdID), + ZstdDictCompat: *wantZstdCompat, + ZstdLevel: zstd.EncoderLevel(*wantZstdLevel), + } + if *wantOutput == "" || *quiet { + o.Output = nil + } + var input [][]byte + base := flag.Arg(0) + if base == "" { + flag.Usage() + log.Fatal("no path with files specified") + } + + // Index ALL hashes in all files. 
+ err := filepath.Walk(base, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + return nil + } + + f, err := os.Open(path) + if err != nil { + log.Print(err) + return nil + } + defer f.Close() + b, err := io.ReadAll(io.LimitReader(f, int64(*wantMaxBytes))) + if len(b) < 8 { + return nil + } + if len(b) == 0 { + return nil + } + input = append(input, b) + if !*quiet { + fmt.Print("\r"+info.Name(), " read...") + } + return nil + }) + if err != nil { + log.Fatal(err) + } + if len(input) == 0 { + log.Fatal("no files read") + } + var out []byte + switch *wantFormat { + case "zstd": + out, err = dict.BuildZstdDict(input, o) + case "s2": + out, err = dict.BuildS2Dict(input, o) + case "raw": + out, err = dict.BuildRawDict(input, o) + default: + err = fmt.Errorf("unknown format %q", *wantFormat) + } + if err != nil { + log.Fatal(err) + } + if *wantOutput != "" { + err = os.WriteFile(*wantOutput, out, 0666) + if err != nil { + log.Fatal(err) + } + } else { + _, err = os.Stdout.Write(out) + if err != nil { + log.Fatal(err) + } + } +} diff --git a/s2/dict.go b/s2/dict.go index 24f7ce80bc..f125ad0963 100644 --- a/s2/dict.go +++ b/s2/dict.go @@ -106,6 +106,25 @@ func MakeDict(data []byte, searchStart []byte) *Dict { return &d } +// MakeDictManual will create a dictionary. +// 'data' must be at least MinDictSize and less than or equal to MaxDictSize. +// A manual first repeat index into data must be provided. +// It must be less than len(data)-8. +func MakeDictManual(data []byte, firstIdx uint16) *Dict { + if len(data) < MinDictSize || int(firstIdx) >= len(data)-8 || len(data) > MaxDictSize { + return nil + } + var d Dict + dict := data + d.dict = dict + if cap(d.dict) < len(d.dict)+16 { + d.dict = append(make([]byte, 0, len(d.dict)+16), d.dict...) + } + + d.repeat = int(firstIdx) + return &d +} + // Encode returns the encoded form of src. 
The returned slice may be a sub- // slice of dst if dst was large enough to hold the entire encoded block. // Otherwise, a newly allocated slice will be returned. diff --git a/zstd/blockenc.go b/zstd/blockenc.go index fd4a36f730..aa20c82b24 100644 --- a/zstd/blockenc.go +++ b/zstd/blockenc.go @@ -361,14 +361,21 @@ func (b *blockEnc) encodeLits(lits []byte, raw bool) error { if len(lits) >= 1024 { // Use 4 Streams. out, reUsed, err = huff0.Compress4X(lits, b.litEnc) - } else if len(lits) > 32 { + } else if len(lits) > 16 { // Use 1 stream single = true out, reUsed, err = huff0.Compress1X(lits, b.litEnc) } else { err = huff0.ErrIncompressible } - + if err == nil && len(out)+5 > len(lits) { + // If we are close, we may still be worse or equal to raw. + var lh literalsHeader + lh.setSizes(len(out), len(lits), single) + if len(out)+lh.size() >= len(lits) { + err = huff0.ErrIncompressible + } + } switch err { case huff0.ErrIncompressible: if debugEncoder { @@ -503,7 +510,7 @@ func (b *blockEnc) encode(org []byte, raw, rawAllLits bool) error { if len(b.literals) >= 1024 && !raw { // Use 4 Streams. out, reUsed, err = huff0.Compress4X(b.literals, b.litEnc) - } else if len(b.literals) > 32 && !raw { + } else if len(b.literals) > 16 && !raw { // Use 1 stream single = true out, reUsed, err = huff0.Compress1X(b.literals, b.litEnc) @@ -511,6 +518,17 @@ func (b *blockEnc) encode(org []byte, raw, rawAllLits bool) error { err = huff0.ErrIncompressible } + if err == nil && len(out)+5 > len(b.literals) { + // If we are close, we may still be worse or equal to raw. 
+ var lh literalsHeader + lh.setSize(len(b.literals)) + szRaw := lh.size() + lh.setSizes(len(out), len(b.literals), single) + szComp := lh.size() + if len(out)+szComp >= len(b.literals)+szRaw { + err = huff0.ErrIncompressible + } + } switch err { case huff0.ErrIncompressible: lh.setType(literalsBlockRaw) diff --git a/zstd/dict.go b/zstd/dict.go index ca0951452e..8d5567fe64 100644 --- a/zstd/dict.go +++ b/zstd/dict.go @@ -1,10 +1,13 @@ package zstd import ( + "bytes" "encoding/binary" "errors" "fmt" "io" + "math" + "sort" "github.com/klauspost/compress/huff0" ) @@ -14,9 +17,8 @@ type dict struct { litEnc *huff0.Scratch llDec, ofDec, mlDec sequenceDec - //llEnc, ofEnc, mlEnc []*fseEncoder - offsets [3]int - content []byte + offsets [3]int + content []byte } const dictMagic = "\x37\xa4\x30\xec" @@ -159,3 +161,374 @@ func InspectDictionary(b []byte) (interface { d, err := loadDict(b) return d, err } + +type BuildDictOptions struct { + // Dictionary ID. + ID uint32 + + // Content to use to create dictionary tables. + Contents [][]byte + + // History to use for all blocks. + History []byte + + // Offsets to use. + Offsets [3]int + + // CompatV155 will make the dictionary compatible with Zstd v1.5.5 and earlier. + // See https://github.com/facebook/zstd/issues/3724 + CompatV155 bool + + // Use the specified encoder level. + // The dictionary will be built using the specified encoder level, + // which will reflect speed and make the dictionary tailored for that level. + // If not set SpeedBestCompression will be used. + Level EncoderLevel + + // DebugOut will write stats and other details here if set. + DebugOut io.Writer +} + +func BuildDict(o BuildDictOptions) ([]byte, error) { + initPredefined() + hist := o.History + contents := o.Contents + debug := o.DebugOut != nil + println := func(args ...interface{}) { + if o.DebugOut != nil { + fmt.Fprintln(o.DebugOut, args...) 
+ } + } + printf := func(s string, args ...interface{}) { + if o.DebugOut != nil { + fmt.Fprintf(o.DebugOut, s, args...) + } + } + print := func(args ...interface{}) { + if o.DebugOut != nil { + fmt.Fprint(o.DebugOut, args...) + } + } + + if int64(len(hist)) > dictMaxLength { + return nil, fmt.Errorf("dictionary of size %d > %d", len(hist), int64(dictMaxLength)) + } + if len(hist) < 8 { + return nil, fmt.Errorf("dictionary of size %d < %d", len(hist), 8) + } + if len(contents) == 0 { + return nil, errors.New("no content provided") + } + d := dict{ + id: o.ID, + litEnc: nil, + llDec: sequenceDec{}, + ofDec: sequenceDec{}, + mlDec: sequenceDec{}, + offsets: o.Offsets, + content: hist, + } + block := blockEnc{lowMem: false} + block.init() + enc := encoder(&bestFastEncoder{fastBase: fastBase{maxMatchOff: int32(maxMatchLen), bufferReset: math.MaxInt32 - int32(maxMatchLen*2), lowMem: false}}) + if o.Level != 0 { + eOpts := encoderOptions{ + level: o.Level, + blockSize: maxMatchLen, + windowSize: maxMatchLen, + dict: &d, + lowMem: false, + } + enc = eOpts.encoder() + } else { + o.Level = SpeedBestCompression + } + var ( + remain [256]int + ll [256]int + ml [256]int + of [256]int + ) + addValues := func(dst *[256]int, src []byte) { + for _, v := range src { + dst[v]++ + } + } + addHist := func(dst *[256]int, src *[256]uint32) { + for i, v := range src { + dst[i] += int(v) + } + } + seqs := 0 + nUsed := 0 + litTotal := 0 + newOffsets := make(map[uint32]int, 1000) + for _, b := range contents { + block.reset(nil) + if len(b) < 8 { + continue + } + nUsed++ + enc.Reset(&d, true) + enc.Encode(&block, b) + addValues(&remain, block.literals) + litTotal += len(block.literals) + seqs += len(block.sequences) + block.genCodes() + addHist(&ll, block.coders.llEnc.Histogram()) + addHist(&ml, block.coders.mlEnc.Histogram()) + addHist(&of, block.coders.ofEnc.Histogram()) + for i, seq := range block.sequences { + if i > 3 { + break + } + offset := seq.offset + if offset == 0 { + continue + 
} + if offset > 3 { + newOffsets[offset-3]++ + } else { + newOffsets[uint32(o.Offsets[offset-1])]++ + } + } + } + // Find most used offsets. + var sortedOffsets []uint32 + for k := range newOffsets { + sortedOffsets = append(sortedOffsets, k) + } + sort.Slice(sortedOffsets, func(i, j int) bool { + a, b := sortedOffsets[i], sortedOffsets[j] + if a == b { + // Prefer the longer offset + return sortedOffsets[i] > sortedOffsets[j] + } + return newOffsets[sortedOffsets[i]] > newOffsets[sortedOffsets[j]] + }) + if len(sortedOffsets) > 3 { + if debug { + print("Offsets:") + for i, v := range sortedOffsets { + if i > 20 { + break + } + printf("[%d: %d],", v, newOffsets[v]) + } + println("") + } + + sortedOffsets = sortedOffsets[:3] + } + for i, v := range sortedOffsets { + o.Offsets[i] = int(v) + } + if debug { + println("New repeat offsets", o.Offsets) + } + + if nUsed == 0 || seqs == 0 { + return nil, fmt.Errorf("%d blocks, %d sequences found", nUsed, seqs) + } + if debug { + println("Sequences:", seqs, "Blocks:", nUsed, "Literals:", litTotal) + } + if seqs/nUsed < 512 { + // Use 512 as minimum. 
+ nUsed = seqs / 512 + } + copyHist := func(dst *fseEncoder, src *[256]int) ([]byte, error) { + hist := dst.Histogram() + var maxSym uint8 + var maxCount int + var fakeLength int + for i, v := range src { + if v > 0 { + v = v / nUsed + if v == 0 { + v = 1 + } + } + if v > maxCount { + maxCount = v + } + if v != 0 { + maxSym = uint8(i) + } + fakeLength += v + hist[i] = uint32(v) + } + dst.HistogramFinished(maxSym, maxCount) + dst.reUsed = false + dst.useRLE = false + err := dst.normalizeCount(fakeLength) + if err != nil { + return nil, err + } + if debug { + println("RAW:", dst.count[:maxSym+1], "NORM:", dst.norm[:maxSym+1], "LEN:", fakeLength) + } + return dst.writeCount(nil) + } + if debug { + print("Literal lengths: ") + } + llTable, err := copyHist(block.coders.llEnc, &ll) + if err != nil { + return nil, err + } + if debug { + print("Match lengths: ") + } + mlTable, err := copyHist(block.coders.mlEnc, &ml) + if err != nil { + return nil, err + } + if debug { + print("Offsets: ") + } + ofTable, err := copyHist(block.coders.ofEnc, &of) + if err != nil { + return nil, err + } + + // Literal table + avgSize := litTotal + if avgSize > huff0.BlockSizeMax/2 { + avgSize = huff0.BlockSizeMax / 2 + } + huffBuff := make([]byte, 0, avgSize) + // Target size + div := litTotal / avgSize + if div < 1 { + div = 1 + } + if debug { + println("Huffman weights:") + } + for i, n := range remain[:] { + if n > 0 { + n = n / div + // Allow all entries to be represented. + if n == 0 { + n = 1 + } + huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...) 
+ if debug { + printf("[%d: %d], ", i, n) + } + } + } + if o.CompatV155 && remain[255]/div == 0 { + huffBuff = append(huffBuff, 255) + } + scratch := &huff0.Scratch{TableLog: 11} + for tries := 0; tries < 255; tries++ { + scratch = &huff0.Scratch{TableLog: 11} + _, _, err = huff0.Compress1X(huffBuff, scratch) + if err == nil { + break + } + if debug { + printf("Try %d: Huffman error: %v\n", tries+1, err) + } + huffBuff = huffBuff[:0] + if tries == 250 { + if debug { + println("Huffman: Bailing out with predefined table") + } + + // Bail out.... Just generate something + huffBuff = append(huffBuff, bytes.Repeat([]byte{255}, 10000)...) + for i := 0; i < 128; i++ { + huffBuff = append(huffBuff, byte(i)) + } + continue + } + if errors.Is(err, huff0.ErrIncompressible) { + // Try truncating least common. + for i, n := range remain[:] { + if n > 0 { + n = n / (div * (i + 1)) + if n > 0 { + huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...) + } + } + } + if o.CompatV155 && len(huffBuff) > 0 && huffBuff[len(huffBuff)-1] != 255 { + huffBuff = append(huffBuff, 255) + } + if len(huffBuff) == 0 { + huffBuff = append(huffBuff, 0, 255) + } + } + if errors.Is(err, huff0.ErrUseRLE) { + for i, n := range remain[:] { + n = n / (div * (i + 1)) + // Allow all entries to be represented. + if n == 0 { + n = 1 + } + huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...) 
+ } + } + } + + var out bytes.Buffer + out.Write([]byte(dictMagic)) + out.Write(binary.LittleEndian.AppendUint32(nil, o.ID)) + out.Write(scratch.OutTable) + if debug { + println("huff table:", len(scratch.OutTable), "bytes") + println("of table:", len(ofTable), "bytes") + println("ml table:", len(mlTable), "bytes") + println("ll table:", len(llTable), "bytes") + } + out.Write(ofTable) + out.Write(mlTable) + out.Write(llTable) + out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[0]))) + out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[1]))) + out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[2]))) + out.Write(hist) + if debug { + _, err := loadDict(out.Bytes()) + if err != nil { + panic(err) + } + i, err := InspectDictionary(out.Bytes()) + if err != nil { + panic(err) + } + println("ID:", i.ID()) + println("Content size:", i.ContentSize()) + println("Encoder:", i.LitEncoder() != nil) + println("Offsets:", i.Offsets()) + var totalSize int + for _, b := range contents { + totalSize += len(b) + } + + encWith := func(opts ...EOption) int { + enc, err := NewWriter(nil, opts...) + if err != nil { + panic(err) + } + defer enc.Close() + var dst []byte + var totalSize int + for _, b := range contents { + dst = enc.EncodeAll(b, dst[:0]) + totalSize += len(dst) + } + return totalSize + } + plain := encWith(WithEncoderLevel(o.Level)) + withDict := encWith(WithEncoderLevel(o.Level), WithEncoderDict(out.Bytes())) + println("Input size:", totalSize) + println("Plain Compressed:", plain) + println("Dict Compressed:", withDict) + println("Saved:", plain-withDict, (plain-withDict)/len(contents), "bytes per input (rounded down)") + } + return out.Bytes(), nil +}