s2: Support ReadAt in ReadSeeker (#747)
Also simplifies seeking.
klauspost authored Feb 5, 2023
1 parent 69922df commit c847bde
Showing 3 changed files with 183 additions and 29 deletions.
121 changes: 92 additions & 29 deletions s2/decode.go
@@ -880,15 +880,20 @@ func (r *Reader) Skip(n int64) error {
// See Reader.ReadSeeker
type ReadSeeker struct {
*Reader
readAtMu sync.Mutex
}

// ReadSeeker will return an io.ReadSeeker compatible version of the reader.
// ReadSeeker will return an io.ReadSeeker and io.ReaderAt
// compatible version of the reader.
// If 'random' is specified the returned io.Seeker can be used for
// random seeking, otherwise only forward seeking is supported.
// Enabling random seeking requires the original input to support
// the io.Seeker interface.
// A custom index can be specified which will be used if supplied.
// When using a custom index, it will not be read from the input stream.
// The ReadAt position will affect regular reads and the current position of Seek.
// So using Read after ReadAt will continue from where the ReadAt stopped.
// No functions should be used concurrently.
// The returned ReadSeeker contains a shallow reference to the existing Reader,
// meaning changes performed to one are reflected in the other.
func (r *Reader) ReadSeeker(random bool, index []byte) (*ReadSeeker, error) {
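For context, a minimal usage sketch of the API described above (not part of the diff). It assumes the input stream was written with an index, for example via the WriterAddIndex writer option, so random seeking works without supplying a custom index; the file name is a placeholder.

    package main

    import (
        "io"
        "log"
        "os"

        "github.com/klauspost/compress/s2"
    )

    func main() {
        // "data.s2" is a placeholder: an s2 stream written with an index,
        // backed by something that supports io.Seeker (an *os.File does).
        f, err := os.Open("data.s2")
        if err != nil {
            log.Fatal(err)
        }
        defer f.Close()

        dec := s2.NewReader(f)

        // random=true enables random seeking; with a nil index it is
        // loaded from the stream itself.
        rs, err := dec.ReadSeeker(true, nil)
        if err != nil {
            log.Fatal(err)
        }

        // Seek to an absolute uncompressed offset and read from there.
        if _, err := rs.Seek(1<<20, io.SeekStart); err != nil {
            log.Fatal(err)
        }
        buf := make([]byte, 64)
        if _, err := io.ReadFull(rs, buf); err != nil {
            log.Fatal(err)
        }
        log.Printf("read %d bytes at offset 1MiB", len(buf))
    }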
@@ -958,42 +963,55 @@ func (r *ReadSeeker) Seek(offset int64, whence int) (int64, error) {
// Reset on EOF
r.err = nil
}
if offset == 0 && whence == io.SeekCurrent {
return r.blockStart + int64(r.i), nil

// Calculate absolute offset.
absOffset := offset

switch whence {
case io.SeekStart:
case io.SeekCurrent:
absOffset = r.blockStart + int64(r.i) + offset
case io.SeekEnd:
if r.index == nil {
return 0, ErrUnsupported
}
absOffset = r.index.TotalUncompressed + offset
default:
r.err = ErrUnsupported
return 0, r.err
}

if absOffset < 0 {
return 0, errors.New("seek before start of file")
}

if !r.readHeader {
// Make sure we read the header.
_, r.err = r.Read([]byte{})
if r.err != nil {
return 0, r.err
}
}

// If we are inside current block no need to seek.
// This includes no offset changes.
if absOffset >= r.blockStart && absOffset < r.blockStart+int64(r.j) {
r.i = int(absOffset - r.blockStart)
return r.blockStart + int64(r.i), nil
}

rs, ok := r.r.(io.ReadSeeker)
if r.index == nil || !ok {
if whence == io.SeekCurrent && offset >= 0 {
err := r.Skip(offset)
return r.blockStart + int64(r.i), err
}
if whence == io.SeekStart && offset >= r.blockStart+int64(r.i) {
err := r.Skip(offset - r.blockStart - int64(r.i))
currOffset := r.blockStart + int64(r.i)
if absOffset >= currOffset {
err := r.Skip(absOffset - currOffset)
return r.blockStart + int64(r.i), err
}
return 0, ErrUnsupported

}

switch whence {
case io.SeekCurrent:
offset += r.blockStart + int64(r.i)
case io.SeekEnd:
if offset > 0 {
return 0, errors.New("seek after end of file")
}
offset = r.index.TotalUncompressed + offset
}

if offset < 0 {
return 0, errors.New("seek before start of file")
}

c, u, err := r.index.Find(offset)
// We can seek and we have an index.
c, u, err := r.index.Find(absOffset)
if err != nil {
return r.blockStart + int64(r.i), err
}
@@ -1004,12 +1022,57 @@ func (r *ReadSeeker) Seek(offset int64, whence int) (int64, error) {
return 0, err
}

r.i = r.j // Remove rest of current block.
if u < offset {
r.i = r.j // Remove rest of current block.
r.blockStart = u - int64(r.j) // Adjust current block start for accounting.
if u < absOffset {
// Forward inside block
return offset, r.Skip(offset - u)
return absOffset, r.Skip(absOffset - u)
}
if u > absOffset {
return 0, fmt.Errorf("s2 seek: (internal error) u (%d) > absOffset (%d)", u, absOffset)
}
return absOffset, nil
}
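
The simplification mentioned in the commit message is visible here: every (offset, whence) pair is first reduced to a single absolute uncompressed offset, after which Seek only has to choose between staying inside the current block, skipping forward, or consulting the index. A standalone sketch of that reduction, under the assumption that a negative total means no index is loaded; absTarget and its parameters are illustrative and not part of the package.

    package seeksketch

    import (
        "errors"
        "io"
    )

    // absTarget mirrors the reduction performed above: every (offset, whence)
    // pair becomes one absolute uncompressed position before any seeking or
    // skipping is attempted. current is the present uncompressed position;
    // total is the total uncompressed size, or -1 when no index is loaded.
    func absTarget(offset, current, total int64, whence int) (int64, error) {
        abs := offset
        switch whence {
        case io.SeekStart:
            // abs already holds the target.
        case io.SeekCurrent:
            abs = current + offset
        case io.SeekEnd:
            if total < 0 {
                return 0, errors.New("seeking relative to the end requires an index")
            }
            abs = total + offset
        default:
            return 0, errors.New("invalid whence")
        }
        if abs < 0 {
            return 0, errors.New("seek before start of file")
        }
        return abs, nil
    }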

// ReadAt reads len(p) bytes into p starting at offset off in the
// underlying input source. It returns the number of bytes
// read (0 <= n <= len(p)) and any error encountered.
//
// When ReadAt returns n < len(p), it returns a non-nil error
// explaining why more bytes were not returned. In this respect,
// ReadAt is stricter than Read.
//
// Even if ReadAt returns n < len(p), it may use all of p as scratch
// space during the call. If some data is available but not len(p) bytes,
// ReadAt blocks until either all the data is available or an error occurs.
// In this respect ReadAt is different from Read.
//
// If the n = len(p) bytes returned by ReadAt are at the end of the
// input source, ReadAt may return either err == EOF or err == nil.
//
// If ReadAt is reading from an input source with a seek offset,
// ReadAt should not affect nor be affected by the underlying
// seek offset.
//
// Clients of ReadAt can execute parallel ReadAt calls on the
// same input source. This is however not recommended.
func (r *ReadSeeker) ReadAt(p []byte, offset int64) (int, error) {
r.readAtMu.Lock()
defer r.readAtMu.Unlock()
_, err := r.Seek(offset, io.SeekStart)
if err != nil {
return 0, err
}
n := 0
for n < len(p) {
n2, err := r.Read(p[n:])
if err != nil {
// This will include io.EOF
return n + n2, err
}
n += n2
}
return offset, nil
return n, nil
}
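
With io.ReaderAt implemented, the decoder can be handed to standard helpers such as io.NewSectionReader. A hedged sketch under that assumption follows (readSection is illustrative and not part of the package); note that, as documented above, ReadAt still moves the regular read position, and parallel ReadAt calls, although serialized by readAtMu, are not recommended.

    package readatsketch

    import (
        "io"

        "github.com/klauspost/compress/s2"
    )

    // readSection exposes length decompressed bytes starting at off through
    // io.NewSectionReader, which only requires an io.ReaderAt. A short read
    // at the end of the stream is returned as a short result, not an error.
    func readSection(rs *s2.ReadSeeker, off, length int64) ([]byte, error) {
        sec := io.NewSectionReader(rs, off, length)
        out := make([]byte, length)
        n, err := io.ReadFull(sec, out)
        if err == io.EOF || err == io.ErrUnexpectedEOF {
            err = nil
        }
        return out[:n], err
    }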

// ReadByte satisfies the io.ByteReader interface.
31 changes: 31 additions & 0 deletions s2/encode_test.go
@@ -366,6 +366,37 @@ func TestIndex(t *testing.T) {
}
})
}
t.Run(fmt.Sprintf("ReadAt"), func(t *testing.T) {
// Read it from a seekable stream
dec = NewReader(bytes.NewReader(compressed))

rs, err := dec.ReadSeeker(true, nil)
fatalErr(t, err)

// Read a little...
var tmp = make([]byte, len(input)/2)
_, err = io.ReadFull(rs, tmp[:])
fatalErr(t, err)
wantLen := len(tmp)
if wantLen+int(wantOffset) > len(input) {
wantLen = len(input) - int(wantOffset)
}
// Read from wantOffset
n, err := rs.ReadAt(tmp, wantOffset)
if n != wantLen {
t.Errorf("got length %d, want %d", n, wantLen)
}
if err != io.EOF {
fatalErr(t, err)
}
want := want[:n]
got := tmp[:n]

// Read the rest of the stream...
if !bytes.Equal(got, want) {
t.Error("Result mismatch", wantOffset)
}
})
})
}
}
60 changes: 60 additions & 0 deletions s2/index_test.go
@@ -234,6 +234,66 @@ func TestSeeking(t *testing.T) {
}
})
}
// Test seek current
t.Run(fmt.Sprintf("seekCurrent"), func(t *testing.T) {
dec := s2.NewReader(io.ReadSeeker(bytes.NewReader(compressed.Bytes())))

seeker, err := dec.ReadSeeker(true, index)
if err != nil {
t.Fatal(err)
}
buf := make([]byte, 25)
rng := rand.New(rand.NewSource(0))
var currentOff int64
for i := 0; i < nElems/10; i++ {
rec := rng.Intn(nElems)
offset := int64(rec * 25)
//t.Logf("Reading record %d", rec)
absOff, err := seeker.Seek(offset-currentOff, io.SeekCurrent)
if err != nil {
t.Fatalf("Failed to seek: %v", err)
}
if absOff != offset {
t.Fatalf("Unexpected seek offset: want %v, got %v", offset, absOff)
}
_, err = io.ReadFull(dec, buf)
if err != nil {
t.Fatalf("Failed to seek: %v", err)
}
expected := fmt.Sprintf("Item %019d\n", rec)
if string(buf) != expected {
t.Fatalf("Expected %q, got %q", expected, buf)
}
// Adjust offset
currentOff = offset + int64(len(buf))
}
})
// Test ReadAt
t.Run(fmt.Sprintf("ReadAt"), func(t *testing.T) {
dec := s2.NewReader(io.ReadSeeker(bytes.NewReader(compressed.Bytes())))

seeker, err := dec.ReadSeeker(true, index)
if err != nil {
t.Fatal(err)
}
buf := make([]byte, 25)
rng := rand.New(rand.NewSource(0))
for i := 0; i < nElems/10; i++ {
rec := rng.Intn(nElems)
offset := int64(rec * 25)
n, err := seeker.ReadAt(buf, offset)
if err != nil {
t.Fatalf("Failed to seek: %v", err)
}
if n != len(buf) {
t.Fatalf("Unexpected read length: want %v, got %v", len(buf), n)
}
expected := fmt.Sprintf("Item %019d\n", rec)
if string(buf) != expected {
t.Fatalf("Expected %q, got %q", expected, buf)
}
}
})
}

// ExampleIndexStream shows an example of indexing a stream
