Skip to content

Commit

Permalink
use function instead of static map
Browse files Browse the repository at this point in the history
Signed-off-by: Avi Deitcher <[email protected]>
  • Loading branch information
deitch committed Feb 12, 2021
1 parent 38788d6 commit 76a5d13
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 49 deletions.
4 changes: 2 additions & 2 deletions pkg/content/multiwriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
// MultiWriterIngester an ingester that can provide a single writer or multiple writers for a single
// descriptor. Useful when the target of a descriptor can have multiple items within it, e.g. a layer
// that is a tar file with multiple files, each of which should go to a different stream, some of which
// should not be handled at all.
type MultiWriterIngester interface {
	ctrcontent.Ingester
	// Writers returns a lookup function that resolves a name to the writer for
	// that name, letting the target decide lazily how (or whether) to handle
	// each named stream, rather than fixing the mapping up front in a map.
	Writers(ctx context.Context, opts ...ctrcontent.WriterOpt) (func(string) (ctrcontent.Writer, error), error)
}
61 changes: 35 additions & 26 deletions pkg/content/passthrough.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ func (pw *PassthroughWriter) Write(p []byte) (n int, err error) {
}

// Close shuts down the passthrough pipeline. It closes the write side of the
// pipe — guarded against a nil pipe, since a PassthroughWriter constructed via
// NewPassthroughMultiWriter shares the multi-writer's pipe and may not own one —
// so the processing goroutine sees EOF, then closes the underlying writer.
func (pw *PassthroughWriter) Close() error {
	if pw.pipew != nil {
		pw.pipew.Close()
	}
	pw.writer.Close()
	return nil
}
Expand All @@ -82,9 +84,13 @@ func (pw *PassthroughWriter) Digest() digest.Digest {
// Commit always closes the writer, even on error.
// ErrAlreadyExists aborts the writer.
func (pw *PassthroughWriter) Commit(ctx context.Context, size int64, expected digest.Digest, opts ...content.Opt) error {
pw.pipew.Close()
if pw.pipew != nil {
pw.pipew.Close()
}
err := <-pw.done
pw.reader.Close()
if pw.reader != nil {
pw.reader.Close()
}
if err != nil && err != io.EOF {
return err
}
Expand Down Expand Up @@ -152,10 +158,9 @@ type PassthroughMultiWriter struct {
done chan error
startedAt time.Time
updatedAt time.Time
ref string
}

func NewPassthroughMultiWriter(writers []content.Writer, f func(r io.Reader, w []io.Writer, done chan<- error), opts ...WriterOpt) content.Writer {
func NewPassthroughMultiWriter(writers func(name string) (content.Writer, error), f func(r io.Reader, getwriter func(name string) io.Writer, done chan<- error), opts ...WriterOpt) content.Writer {
// process opts for default
wOpts := DefaultWriterOpts()
for _, opt := range opts {
Expand All @@ -164,36 +169,38 @@ func NewPassthroughMultiWriter(writers []content.Writer, f func(r io.Reader, w [
}
}

var pws []*PassthroughWriter
r, w := io.Pipe()
for _, writer := range writers {
pws = append(pws, &PassthroughWriter{

pmw := &PassthroughMultiWriter{
startedAt: time.Now(),
updatedAt: time.Now(),
done: make(chan error, 1),
digester: digest.Canonical.Digester(),
hash: wOpts.InputHash,
pipew: w,
reader: r,
}

// get our output writers
getwriter := func(name string) io.Writer {
writer, err := writers(name)
if err != nil || writer == nil {
return nil
}
pw := &PassthroughWriter{
writer: writer,
pipew: w,
digester: digest.Canonical.Digester(),
underlyingWriter: &underlyingWriter{
writer: writer,
digester: digest.Canonical.Digester(),
hash: wOpts.OutputHash,
},
reader: r,
hash: wOpts.InputHash,
done: make(chan error, 1),
})
}

pmw := &PassthroughMultiWriter{
writers: pws,
startedAt: time.Now(),
updatedAt: time.Now(),
done: make(chan error, 1),
}
// get our output writers
var uws []io.Writer
for _, uw := range pws {
uws = append(uws, uw.underlyingWriter)
}
pmw.writers = append(pmw.writers, pw)
return pw.underlyingWriter
}
go f(r, uws, pmw.done)
go f(r, getwriter, pmw.done)
return pmw
}

Expand Down Expand Up @@ -230,7 +237,9 @@ func (pmw *PassthroughMultiWriter) Digest() digest.Digest {
func (pmw *PassthroughMultiWriter) Commit(ctx context.Context, size int64, expected digest.Digest, opts ...content.Opt) error {
pmw.pipew.Close()
err := <-pmw.done
pmw.reader.Close()
if pmw.reader != nil {
pmw.reader.Close()
}
if err != nil && err != io.EOF {
return err
}
Expand Down
76 changes: 76 additions & 0 deletions pkg/content/passthrough_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package content_test

import (
"bytes"
"context"
"fmt"
"io"
"math/rand"
"testing"

ctrcontent "github.com/containerd/containerd/content"
Expand Down Expand Up @@ -115,3 +117,77 @@ func TestPassthroughWriter(t *testing.T) {
}
}
}

// TestPassthroughMultiWriter exercises NewPassthroughMultiWriter with a lookup
// function that routes two named 512-byte streams to separate buffers.
func TestPassthroughMultiWriter(t *testing.T) {
	// pass through function that selects one of two outputs
	var (
		b1, b2       bytes.Buffer
		name1, name2 = "I am name 01", "I am name 02" // each of these is 12 bytes
		data1, data2 = make([]byte, 500), make([]byte, 500)
	)
	rand.Read(data1)
	rand.Read(data2)
	combined := append([]byte(name1), data1...)
	combined = append(combined, []byte(name2)...)
	combined = append(combined, data2...)
	f := func(r io.Reader, getwriter func(name string) io.Writer, done chan<- error) {
		// test is done rather simply, with a single 1024 byte chunk, split into 2x512 data streams,
		// each of which is 12 bytes of name and 500 bytes of data.
		b := make([]byte, 1024)
		// io.ReadFull: a bare r.Read may legally return fewer than 1024 bytes.
		if _, err := io.ReadFull(r, b); err != nil {
			// Report failures through done: t.Fatalf must only be called from
			// the goroutine running the test, and f runs on its own goroutine.
			done <- fmt.Errorf("data read error: %v", err)
			return
		}

		// get the names and data for each
		n1, n2 := string(b[0:12]), string(b[512+0:512+12])
		w1, w2 := getwriter(n1), getwriter(n2)
		if _, err := w1.Write(b[12:512]); err != nil {
			done <- fmt.Errorf("w1 write error: %v", err)
			return
		}
		if _, err := w2.Write(b[512+12 : 1024]); err != nil {
			done <- fmt.Errorf("w2 write error: %v", err)
			return
		}
		done <- nil
	}

	var (
		opts = []content.WriterOpt{content.WithInputHash(testContentHash), content.WithOutputHash(modifiedContentHash)}
		hash = testContentHash
	)
	ctx := context.Background()
	writers := func(name string) (ctrcontent.Writer, error) {
		switch name {
		case name1:
			return content.NewIoContentWriter(&b1), nil
		case name2:
			return content.NewIoContentWriter(&b2), nil
		}
		return nil, fmt.Errorf("unknown name %s", name)
	}
	writer := content.NewPassthroughMultiWriter(writers, f, opts...)
	n, err := writer.Write(combined)
	if err != nil {
		t.Fatalf("unexpected error on Write: %v", err)
	}
	if n != len(combined) {
		t.Fatalf("wrote %d bytes instead of %d", n, len(combined))
	}
	if err := writer.Commit(ctx, testDescriptor.Size, hash); err != nil {
		t.Errorf("unexpected error on Commit: %v", err)
	}
	if digest := writer.Digest(); digest != hash {
		t.Errorf("mismatched digest: actual %v, expected %v", digest, hash)
	}

	// make sure the data is what we expected
	if !bytes.Equal(data1, b1.Bytes()) {
		t.Errorf("b1 data1 did not match")
	}
	if !bytes.Equal(data2, b2.Bytes()) {
		t.Errorf("b2 data2 did not match")
	}
}
32 changes: 11 additions & 21 deletions pkg/content/untar.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,9 @@ func NewUntarWriter(writer content.Writer, opts ...WriterOpt) content.Writer {
}

// NewUntarWriterByName wrap multiple writers with an untar, so that the stream is untarred and passed
// to the appropriate writer, based on the filename. If a filename is not found, it will not pass it
// to any writer. The filename "" will handle any stream that does not have a specific filename; use
// it for the default of a single file in a tar stream.
func NewUntarWriterByName(writers map[string]content.Writer, opts ...WriterOpt) content.Writer {
// to the appropriate writer, based on the filename. If a filename is not found, it is up to the provided
// writer-lookup function to determine how to process it.
func NewUntarWriterByName(writers func(string) (content.Writer, error), opts ...WriterOpt) content.Writer {
// process opts for default
wOpts := DefaultWriterOpts()
for _, opt := range opts {
Expand All @@ -84,15 +83,8 @@ func NewUntarWriterByName(writers map[string]content.Writer, opts ...WriterOpt)
}
}

// construct an array of content.Writer
nameToIndex := map[string]int{}
var writerSlice []content.Writer
for name, writer := range writers {
writerSlice = append(writerSlice, writer)
nameToIndex[name] = len(writerSlice) - 1
}
// need a PassthroughMultiWriter here
return NewPassthroughMultiWriter(writerSlice, func(r io.Reader, ws []io.Writer, done chan<- error) {
return NewPassthroughMultiWriter(writers, func(r io.Reader, getwriter func(name string) io.Writer, done chan<- error) {
tr := tar.NewReader(r)
var err error
for {
Expand All @@ -109,13 +101,11 @@ func NewUntarWriterByName(writers map[string]content.Writer, opts ...WriterOpt)
}
// get the filename
filename := header.Name
index, ok := nameToIndex[filename]
if !ok {
index, ok = nameToIndex[""]
if !ok {
// we did not find this file or the wildcard, so do not process this file
continue
}

// get the writer for this filename
w := getwriter(filename)
if w == nil {
continue
}

// write out the untarred data
Expand All @@ -133,8 +123,8 @@ func NewUntarWriterByName(writers map[string]content.Writer, opts ...WriterOpt)
if n > len(b) {
l = len(b)
}
if _, err2 := ws[index].Write(b[:l]); err2 != nil {
err = fmt.Errorf("UntarWriter error writing to underlying writer at index %d for name '%s': %v", index, filename, err2)
if _, err2 := w.Write(b[:l]); err2 != nil {
err = fmt.Errorf("UntarWriter error writing to underlying writer at for name '%s': %v", filename, err2)
break
}
if err == io.EOF {
Expand Down

0 comments on commit 76a5d13

Please sign in to comment.