diff --git a/content/storage.go b/content/storage.go index 8b527a24..2c2687fd 100644 --- a/content/storage.go +++ b/content/storage.go @@ -20,7 +20,6 @@ import ( "io" ocispec "github.com/opencontainers/image-spec/specs-go/v1" - "oras.land/oras-go/v2/internal/ioutil" ) // Fetcher fetches content. @@ -64,7 +63,7 @@ func FetchAll(ctx context.Context, fetcher Fetcher, desc ocispec.Descriptor) ([] return nil, err } defer rc.Close() - return ioutil.ReadAll(rc, desc) + return ReadAll(rc, desc) } // FetcherFunc is the basic Fetch method defined in Fetcher. diff --git a/content/utils.go b/content/utils.go new file mode 100644 index 00000000..193becce --- /dev/null +++ b/content/utils.go @@ -0,0 +1,64 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package content + +import ( + "errors" + "fmt" + "io" + + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + "oras.land/oras-go/v2/internal/ioutil" +) + +var ( + // ErrInvalidDescriptorSize is returned by ReadAll() when + // the descriptor has an invalid size. + ErrInvalidDescriptorSize = errors.New("invalid descriptor size") + + // ErrMismatchedDigest is returned by ReadAll() when + // the descriptor has an invalid digest. + ErrMismatchedDigest = errors.New("mismatched digest") + + // ErrTrailingData is returned by ReadAll() when + // there exists trailing data unread when the read terminates. + ErrTrailingData = errors.New("trailing data") +) + +// ReadAll safely reads the content described by the descriptor. +// The read content is verified against the size and the digest. +func ReadAll(r io.Reader, desc ocispec.Descriptor) ([]byte, error) { + if desc.Size < 0 { + return nil, ErrInvalidDescriptorSize + } + buf := make([]byte, desc.Size) + + // verify while reading + verifier := desc.Digest.Verifier() + r = io.TeeReader(r, verifier) + // verify the size of the read content + if _, err := io.ReadFull(r, buf); err != nil { + return nil, fmt.Errorf("read failed: %w", err) + } + if err := ioutil.EnsureEOF(r); err != nil { + return nil, ErrTrailingData + } + // verify the digest of the read content + if !verifier.Verified() { + return nil, ErrMismatchedDigest + } + return buf, nil +} diff --git a/content/utils_test.go b/content/utils_test.go new file mode 100644 index 00000000..f77bf555 --- /dev/null +++ b/content/utils_test.go @@ -0,0 +1,113 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package content + +import ( + "bytes" + _ "crypto/sha256" + "errors" + "io" + "testing" + + "github.com/opencontainers/go-digest" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" +) + +func TestReadAllCorrectDescriptor(t *testing.T) { + content := []byte("example content") + desc := ocispec.Descriptor{ + MediaType: ocispec.MediaTypeImageLayer, + Digest: digest.FromBytes(content), + Size: int64(len(content))} + r := bytes.NewReader([]byte(content)) + got, err := ReadAll(r, desc) + if err != nil { + t.Fatal("ReadAll() error = ", err) + } + if !bytes.Equal(got, content) { + t.Errorf("ReadAll() = %v, want %v", got, content) + } +} + +func TestReadAllReadSizeSmallerThanDescriptorSize(t *testing.T) { + content := []byte("example content") + desc := ocispec.Descriptor{ + MediaType: ocispec.MediaTypeImageLayer, + Digest: digest.FromBytes(content), + Size: int64(len(content) + 1)} + r := bytes.NewReader([]byte(content)) + _, err := ReadAll(r, desc) + if err == nil || !errors.Is(err, io.ErrUnexpectedEOF) { + t.Errorf("ReadAll() error = %v, want %v", err, io.ErrUnexpectedEOF) + } +} + +func TestReadAllReadSizeLargerThanDescriptorSize(t *testing.T) { + content := []byte("example content") + desc := ocispec.Descriptor{ + MediaType: ocispec.MediaTypeImageLayer, + Digest: digest.FromBytes(content), + Size: int64(len(content) - 1)} + r := bytes.NewReader([]byte(content)) + _, err := ReadAll(r, desc) + if err == nil || !errors.Is(err, ErrTrailingData) { + t.Errorf("ReadAll() error = %v, want %v", err, ErrTrailingData) + } +} + +func TestReadAllInvalidDigest(t *testing.T) { + content := []byte("example content") + desc := ocispec.Descriptor{ + MediaType: ocispec.MediaTypeImageLayer, + Digest: digest.FromBytes([]byte("wrong content")), + Size: int64(len(content))} + r := bytes.NewReader([]byte(content)) + _, err := ReadAll(r, desc) + if err == nil || !errors.Is(err, ErrMismatchedDigest) { + t.Errorf("ReadAll() error = %v, want %v", err, ErrMismatchedDigest) + } +} + +func TestReadAllEmptyContent(t *testing.T) { + content := []byte("") + desc := ocispec.Descriptor{ + MediaType: ocispec.MediaTypeImageLayer, + Digest: digest.FromBytes(content), + Size: int64(len(content)), + } + r := bytes.NewReader([]byte(content)) + got, err := ReadAll(r, desc) + if err != nil { + t.Fatal("ReadAll() error = ", err) + } + if !bytes.Equal(got, content) { + t.Errorf("ReadAll() = %v, want %v", got, content) + } +} + +func TestReadAllInvalidDescriptorSize(t *testing.T) { + content := []byte("example content") + desc := ocispec.Descriptor{ + MediaType: ocispec.MediaTypeImageLayer, + Digest: digest.FromBytes(content), + Size: -1, + } + r := bytes.NewReader([]byte(content)) + _, err := ReadAll(r, desc) + if err == nil || !errors.Is(err, ErrInvalidDescriptorSize) { + t.Errorf("ReadAll() error = %v, want %v", err, ErrInvalidDescriptorSize) + } +} diff --git a/internal/cas/memory.go b/internal/cas/memory.go index 7bd694e9..7e358e13 100644 --- a/internal/cas/memory.go +++ b/internal/cas/memory.go @@ -23,9 +23,9 @@ import ( "sync" ocispec "github.com/opencontainers/image-spec/specs-go/v1" + contentpkg "oras.land/oras-go/v2/content" "oras.land/oras-go/v2/errdef" "oras.land/oras-go/v2/internal/descriptor" - "oras.land/oras-go/v2/internal/ioutil" ) // Memory is a memory based CAS. @@ -58,7 +58,7 @@ func (m *Memory) Push(_ context.Context, expected ocispec.Descriptor, content io } // read and try to store the content. - value, err := ioutil.ReadAll(content, expected) + value, err := contentpkg.ReadAll(content, expected) if err != nil { return err } diff --git a/internal/ioutil/io.go b/internal/ioutil/io.go index 84631ea6..431fdc9e 100644 --- a/internal/ioutil/io.go +++ b/internal/ioutil/io.go @@ -32,28 +32,6 @@ func (fn CloserFunc) Close() error { return fn() } -// ReadAll safely reads the content described by the descriptor. -// The read content is verified against the size and the digest. -func ReadAll(r io.Reader, desc ocispec.Descriptor) ([]byte, error) { - // verify while reading - verifier := desc.Digest.Verifier() - r = io.TeeReader(r, verifier) - buf := make([]byte, desc.Size) - _, err := io.ReadFull(r, buf) - if err != nil { - return nil, fmt.Errorf("read failed: %w", err) - } - if !verifier.Verified() { - return nil, errors.New("digest verification failed") - } - - if err := ensureEOF(r); err != nil { - return nil, err - } - - return buf, nil -} - // CopyBuffer copies from src to dst through the provided buffer // until either EOF is reached on src, or an error occurs. // The copied content is verified against the size and the digest. @@ -71,10 +49,12 @@ func CopyBuffer(dst io.Writer, src io.Reader, buf []byte, desc ocispec.Descripto return errors.New("digest verification failed") } - return ensureEOF(lr) + return EnsureEOF(lr) } -func ensureEOF(r io.Reader) error { +// EnsureEOF ensures the read operation ends with an EOF and no +// trailing data is present. +func EnsureEOF(r io.Reader) error { var peek [1]byte _, err := io.ReadFull(r, peek[:]) if err != io.EOF {