// matchLen returns the length of the longest common prefix of a and b.
//
// Callers conventionally pass the shorter slice as a (fastBase.matchlen
// forwards to this function as int32(matchLen(src[s:], src[t:]))), but either
// ordering is accepted: the comparison never reads past the end of the
// shorter slice.
//
// Words are read directly with binary.LittleEndian.Uint64, so no separate
// load64 helper is required.
func matchLen(a, b []byte) (n int) {
	// Fast path: compare eight bytes per iteration while both slices still
	// hold a full 64-bit word.
	for ; len(a) >= 8 && len(b) >= 8; a, b = a[8:], b[8:] {
		diff := binary.LittleEndian.Uint64(a) ^ binary.LittleEndian.Uint64(b)
		if diff != 0 {
			// With a little-endian load the lowest set bit of the XOR lies in
			// the first differing byte; dividing the bit index by 8 turns it
			// into a byte offset within this word.
			return n + bits.TrailingZeros64(diff)>>3
		}
		n += 8
	}

	// Tail: clamp both slices to the shorter remaining length. Without this a
	// shorter b would make b[i] index out of range below.
	if len(b) < len(a) {
		a = a[:len(b)]
	}
	b = b[:len(a)]
	for i := range a {
		if a[i] != b[i] {
			break
		}
		n++
	}
	return n
}

// TestMatchLen checks matchLen for every possible first-mismatch position in a
// 130-byte buffer, with equal-length inputs, with a truncated first argument,
// and with the arguments swapped so that the second one is the shorter.
func TestMatchLen(t *testing.T) {
	a := make([]byte, 130)
	for i := range a {
		a[i] = byte(i)
	}
	b := append([]byte{}, a...)

	check := func(x, y []byte, l int) {
		t.Helper()
		if m := matchLen(x, y); m != l {
			t.Errorf("matchLen(len=%d, len=%d): expected %d, got %d", len(x), len(y), l, m)
		}
	}

	for l := range a {
		a[l] = ^a[l] // force the first mismatch to sit at index l
		check(a, b, l)
		check(a[:l], b, l)
		check(b, a[:l], l)
		a[l] = ^a[l] // restore the byte for the next iteration
	}
}