Skip to content
This repository has been archived by the owner on Nov 19, 2024. It is now read-only.

Change Identify API to accept only an io.Reader #322

Merged
merged 5 commits into from
Mar 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 104 additions & 37 deletions formats.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package archiver

import (
"bytes"
"context"
"errors"
"fmt"
Expand All @@ -25,20 +26,28 @@ func RegisterFormat(format Format) {
// value can be type-asserted to ascertain its capabilities.
//
// If no matching formats were found, special error ErrNoMatch is returned.
func Identify(filename string, stream io.ReadSeeker) (Format, error) {
//
// The returned io.Reader will always be non-nil and will read from the
// same point as the reader which was passed in; it should be used in place
// of the input stream after calling Identify() because it preserves and
// re-reads the bytes that were already read during the identification
// process.
func Identify(filename string, stream io.Reader) (Format, io.Reader, error) {
var compression Compression
var archival Archival

rewindableStream := newRewindReader(stream)

// try compression format first, since that's the outer "layer"
for name, format := range formats {
cf, isCompression := format.(Compression)
if !isCompression {
continue
}

matchResult, err := identifyOne(format, filename, stream, nil)
matchResult, err := identifyOne(format, filename, rewindableStream, nil)
if err != nil {
return nil, fmt.Errorf("matching %s: %w", name, err)
return nil, rewindableStream.reader(), fmt.Errorf("matching %s: %w", name, err)
}

// if matched, wrap input stream with decompression
Expand All @@ -56,9 +65,9 @@ func Identify(filename string, stream io.ReadSeeker) (Format, error) {
continue
}

matchResult, err := identifyOne(format, filename, stream, compression)
matchResult, err := identifyOne(format, filename, rewindableStream, compression)
if err != nil {
return nil, fmt.Errorf("matching %s: %w", name, err)
return nil, rewindableStream.reader(), fmt.Errorf("matching %s: %w", name, err)
}

if matchResult.Matched() {
Expand All @@ -67,57 +76,45 @@ func Identify(filename string, stream io.ReadSeeker) (Format, error) {
}
}

// the stream should be rewound by identifyOne
bufferedStream := rewindableStream.reader()
switch {
case compression != nil && archival == nil:
return compression, nil
return compression, bufferedStream, nil
case compression == nil && archival != nil:
return archival, nil
return archival, bufferedStream, nil
case compression != nil && archival != nil:
return CompressedArchive{compression, archival}, nil
return CompressedArchive{compression, archival}, bufferedStream, nil
default:
return nil, ErrNoMatch
return nil, bufferedStream, ErrNoMatch
}
}

func identifyOne(format Format, filename string, stream io.ReadSeeker, comp Compression) (MatchResult, error) {
if stream == nil {
// shimming an empty stream is easier than hoping every format's
// implementation of Match() expects and handles a nil stream
stream = strings.NewReader("")
}

// reset stream position to beginning, then restore current position when done
previousOffset, err := stream.Seek(0, io.SeekCurrent)
if err != nil {
return MatchResult{}, err
}
_, err = stream.Seek(0, io.SeekStart)
if err != nil {
return MatchResult{}, err
}
defer stream.Seek(previousOffset, io.SeekStart)
func identifyOne(format Format, filename string, stream *rewindReader, comp Compression) (mr MatchResult, err error) {
defer stream.rewind()

// if looking within a compressed format, wrap the stream in a
// reader that can decompress it so we can match the "inner" format
// (yes, we have to make a new reader every time we do a match,
// because we reset/seek the stream each time and that can mess up
// the compression reader's state if we don't discard it also)
if comp != nil {
decompressedStream, err := comp.OpenReader(stream)
if err != nil {
return MatchResult{}, err
decompressedStream, openErr := comp.OpenReader(stream)
if openErr != nil {
return MatchResult{}, openErr
}
defer decompressedStream.Close()
stream = struct {
io.Reader
io.Seeker
}{
Reader: decompressedStream,
Seeker: stream,
}
mr, err = format.Match(filename, decompressedStream)
} else {
mr, err = format.Match(filename, stream)
}

return format.Match(filename, stream)
// if the error is EOF, we can just ignore it.
// Just means we have a small input file.
if errors.Is(err, io.EOF) {
err = nil
}
return mr, err
}

// readAtMost reads at most n bytes from the stream. A nil, empty, or short
Expand Down Expand Up @@ -256,6 +253,76 @@ type MatchResult struct {
// Matched returns true if a match was made by either name or stream.
func (mr MatchResult) Matched() bool { return mr.ByName || mr.ByStream }

// rewindReader is an io.Reader that records every byte it reads from the
// underlying stream so that it can be rewound (reset) to re-read what was
// already read and then continue to read more from the underlying stream.
// When no more rewinding is necessary, call reader() to get a new reader
// that first reads the buffered bytes, then continues to read from the
// stream. This is useful for "peeking" a stream an arbitrary number of
// bytes. Loosely based on the Connection type from
// https://github.com/mholt/caddy-l4.
type rewindReader struct {
	// Reader is the underlying stream.
	io.Reader
	// buf accumulates every byte read from the underlying stream.
	buf *bytes.Buffer
	// bufReader replays buf after rewind(); it is nil when there is
	// nothing (left) to replay.
	bufReader io.Reader
}

// newRewindReader wraps r in a rewindReader with an empty buffer.
// A nil r is treated as an empty stream, so reading from the result
// (or from the reader returned by reader()) yields io.EOF instead of
// panicking on a nil interface.
func newRewindReader(r io.Reader) *rewindReader {
	if r == nil {
		// shimming an empty stream is easier than making every
		// method guard against a nil underlying reader
		r = bytes.NewReader(nil)
	}
	return &rewindReader{
		Reader: r,
		buf:    new(bytes.Buffer),
	}
}

// Read fills p first from the bytes being replayed after a rewind (if any)
// and then, if p still has room, from the underlying stream. Everything
// read from the underlying stream is recorded so it can be replayed later.
// It returns the total number of bytes placed in p and the error, if any,
// reported by the underlying stream.
func (rr *rewindReader) Read(p []byte) (n int, err error) {
	// if there is a buffer we should read from, start
	// with that; we only read from the underlying stream
	// after the buffer has been "depleted"
	if rr.bufReader != nil {
		n, err = rr.bufReader.Read(p)
		if err == io.EOF {
			// replay finished; fall through to the real stream
			rr.bufReader = nil
			err = nil
		}
		if n == len(p) {
			return n, err
		}
	}

	// buffer has been "depleted" so read from
	// the underlying stream
	var nr int
	nr, err = rr.Reader.Read(p[n:])

	// anything that was read needs to be written to
	// the buffer, even if there was an error
	if nr > 0 {
		if nw, errw := rr.buf.Write(p[n : n+nr]); errw != nil {
			// report the replayed bytes too; they are already in p
			return n + nw, errw
		}
	}

	// up to now, n was how many bytes were read from
	// the buffer, and nr was how many bytes were read
	// from the stream; add them to return total count
	n += nr

	return n, err
}

// rewind resets the stream to the beginning by causing
// Read() to start reading from the beginning of the
// buffered bytes.
func (rr *rewindReader) rewind() {
	rr.bufReader = bytes.NewReader(rr.buf.Bytes())
}

// reader returns a reader that reads first from the buffered
// bytes, then from the underlying stream. After calling this,
// no more rewinding is allowed since reads from the stream are
// not recorded, so rewinding properly is impossible.
func (rr *rewindReader) reader() io.Reader {
	return io.MultiReader(bytes.NewReader(rr.buf.Bytes()), rr.Reader)
}

// ErrNoMatch is returned if there are no matching formats.
var ErrNoMatch = fmt.Errorf("no formats matched")

Expand Down
106 changes: 103 additions & 3 deletions formats_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,111 @@ import (
"context"
"io"
"io/fs"
"math/rand"
"os"
"strings"
"testing"
"time"
)

// TestRewindReader checks that a rewindReader replays the same leading
// bytes after every rewind, and that the reader obtained from reader()
// yields the complete stream from the very beginning.
func TestRewindReader(t *testing.T) {
	const (
		header = "the header"
		data   = header + "\nthe body\n"
	)

	rr := newRewindReader(strings.NewReader(data))

	// a scratch buffer exactly as large as the header
	scratch := make([]byte, len(header))

	// each rewind must hand back the header again
	for i := 0; i < 10; i++ {
		rr.rewind()
		n, err := rr.Read(scratch)
		if err != nil {
			t.Fatalf("Read failed: %s", err)
		}
		if got := string(scratch[:n]); got != "the header" {
			t.Fatalf("iteration %d: expected 'the header' but got '%s' (n=%d)", i, got, n)
		}
	}

	// switch to the final (non-rewindable) reader and drain
	// the whole stream through it
	rr.rewind()
	full := make([]byte, len(data))
	n, err := io.ReadFull(rr.reader(), full)
	if err != nil {
		t.Fatalf("ReadFull failed: %s (n=%d)", err, n)
	}
	if got := string(full); got != data {
		t.Fatalf("expected '%s' but got '%s'", data, got)
	}
}

// TestCompression round-trips a random payload through every registered
// compression format: the payload is compressed, the format is detected
// with Identify, and the stream returned by Identify is decompressed and
// compared with the original payload. Each format is tried with a file
// name carrying its extension and, where stream detection is expected to
// work, with an empty file name as well.
func TestCompression(t *testing.T) {
	// log the seed so the random payload of a run can be recreated
	seed := time.Now().UnixNano()
	t.Logf("seed: %d", seed)
	rng := rand.New(rand.NewSource(seed))

	payload := make([]byte, 1024)
	rng.Read(payload)

	// scratch buffer shared by all round trips; reset before each one
	var packed bytes.Buffer

	roundTrip := func(t *testing.T, comp Compression, testFilename string) {
		// compress the payload into the scratch buffer
		packed.Reset()
		w, err := comp.OpenWriter(&packed)
		checkErr(t, err, "opening writer")
		_, err = w.Write(payload)
		checkErr(t, err, "writing contents")
		checkErr(t, w.Close(), "closing writer")

		// Identify has to pick exactly this compression format
		format, stream, err := Identify(testFilename, &packed)
		checkErr(t, err, "identifying")
		if want, got := comp.Name(), format.Name(); got != want {
			t.Fatalf("expected format %s but got %s", want, got)
		}

		// decompress from the stream Identify handed back and compare
		rc, err := format.(Decompressor).OpenReader(stream)
		checkErr(t, err, "opening with decompressor '%s'", format.Name())
		unpacked, err := io.ReadAll(rc)
		checkErr(t, err, "reading decompressed data")
		checkErr(t, rc.Close(), "closing decompressor")
		if !bytes.Equal(unpacked, payload) {
			t.Fatalf("not equal to original")
		}
	}

	// the test does not attempt stream-only identification for brotli
	brotliName := Brotli{}.Name()

	for _, f := range formats {
		// only compressors are of interest here
		comp, isCompression := f.(Compression)
		if !isCompression {
			continue
		}

		t.Run(f.Name()+"_with_extension", func(t *testing.T) {
			roundTrip(t, comp, "file"+f.Name())
		})

		if f.Name() == brotliName {
			continue
		}
		t.Run(f.Name()+"_without_extension", func(t *testing.T) {
			roundTrip(t, comp, "")
		})
	}
}

// checkErr is a test helper that does nothing when err is nil. Otherwise
// it fails the test immediately with a message built from msgFmt and args,
// followed by ": " and the error itself.
func checkErr(t *testing.T, err error, msgFmt string, args ...interface{}) {
	t.Helper()
	if err != nil {
		// the error is always the last formatting argument
		t.Fatalf(msgFmt+": %s", append(args, err)...)
	}
}

func TestIdentifyDoesNotMatchContentFromTrimmedKnownHeaderHaving0Suffix(t *testing.T) {
// Using the outcome of `n, err := io.ReadFull(stream, buf)` without minding n
// may lead to a mis-characterization for cases with known header ending with 0x0
Expand Down Expand Up @@ -41,7 +141,7 @@ func TestIdentifyDoesNotMatchContentFromTrimmedKnownHeaderHaving0Suffix(t *testi
}
headerTrimmed := tt.header[:headerLen-1]
stream := bytes.NewReader(headerTrimmed)
got, err := Identify("", stream)
got, _, err := Identify("", stream)
if got != nil {
t.Errorf("no Format expected for trimmed know %s header: found Format= %v", tt.name, got.Name())
return
Expand Down Expand Up @@ -84,7 +184,7 @@ func TestIdentifyCanAssessSmallOrNoContent(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Identify("", tt.args.stream)
got, _, err := Identify("", tt.args.stream)
if got != nil {
t.Errorf("no Format expected for non archive and not compressed stream: found Format= %v", got.Name())
return
Expand Down Expand Up @@ -274,7 +374,7 @@ func TestIdentifyFindFormatByStreamContent(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
stream := bytes.NewReader(compress(t, tt.compressorName, tt.content, tt.openCompressionWriter))
got, err := Identify("", stream)
got, _, err := Identify("", stream)
if err != nil {
t.Fatalf("should have found a corresponding Format: err :=%+v", err)
return
Expand Down
2 changes: 1 addition & 1 deletion fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func FileSystem(root string) (fs.FS, error) {
return nil, err
}
defer file.Close()
format, err := Identify(filepath.Base(root), file)
format, _, err := Identify(filepath.Base(root), file)
if err != nil && !errors.Is(err, ErrNoMatch) {
return nil, err
}
Expand Down