diff --git a/copy.go b/copy.go index 9caed9809..ea42d56ce 100644 --- a/copy.go +++ b/copy.go @@ -105,6 +105,12 @@ type CopyGraphOptions struct { // OnCopySkipped will be called when the sub-DAG rooted by the current node // is skipped. OnCopySkipped func(ctx context.Context, desc ocispec.Descriptor) error + // MountFrom returns the candidate repositories that desc may be mounted from. + // The OCI references will be tried in turn. If mounting fails on all of them, + // then it falls back to a copy. + MountFrom func(ctx context.Context, desc ocispec.Descriptor) ([]string, error) + // OnMounted will be invoked when desc is mounted. + OnMounted func(ctx context.Context, desc ocispec.Descriptor) error // FindSuccessors finds the successors of the current node. // fetcher provides cached access to the source storage, and is suitable // for fetching non-leaf nodes like manifests. Since anything fetched from @@ -259,12 +265,86 @@ func copyGraph(ctx context.Context, src content.ReadOnlyStorage, dst content.Sto if exists { return copyNode(ctx, proxy.Cache, dst, desc, opts) } - return copyNode(ctx, src, dst, desc, opts) + return mountOrCopyNode(ctx, src, dst, desc, opts) } return syncutil.Go(ctx, limiter, fn, root) } +// mountOrCopyNode tries to mount the node, if not falls back to copying. +func mountOrCopyNode(ctx context.Context, src content.ReadOnlyStorage, dst content.Storage, desc ocispec.Descriptor, opts CopyGraphOptions) error { + // Need MountFrom and it must be a blob + if opts.MountFrom == nil || descriptor.IsManifest(desc) { + return copyNode(ctx, src, dst, desc, opts) + } + + mounter, ok := dst.(registry.Mounter) + if !ok { + // mounting is not supported by the destination + return copyNode(ctx, src, dst, desc, opts) + } + + sourceRepositories, err := opts.MountFrom(ctx, desc) + if err != nil { + // Technically this error is not fatal, we can still attempt to copy the node + // But for consistency with the other callbacks we bail out. 
+ return err + } + + if len(sourceRepositories) == 0 { + return copyNode(ctx, src, dst, desc, opts) + } + + skipContent := errors.New("skip content") + for i, sourceRepository := range sourceRepositories { + // try mounting this source repository + var mountFailed bool + getContent := func() (io.ReadCloser, error) { + // the invocation of getContent indicates that mounting has failed + mountFailed = true + + if len(sourceRepositories)-1 == i { + // this is the last iteration so we need to actually get the content and do the copy + + // call the original PreCopy function if it exists + if opts.PreCopy != nil { + if err := opts.PreCopy(ctx, desc); err != nil { + return nil, err + } + } + return src.Fetch(ctx, desc) + } + + // We want to return an error that we will test for from mounter.Mount() + return nil, skipContent + } + + // Mount or copy + if err := mounter.Mount(ctx, desc, sourceRepository, getContent); err != nil && !errors.Is(err, skipContent) { + return err + } + + if !mountFailed { + // mounted, success + if opts.OnMounted != nil { + if err := opts.OnMounted(ctx, desc); err != nil { + return err + } + } + return nil + } + } + + // we copied it + if opts.PostCopy != nil { + if err := opts.PostCopy(ctx, desc); err != nil { + return err + } + } + + return nil +} + // doCopyNode copies a single content from the source CAS to the destination CAS. 
func doCopyNode(ctx context.Context, src content.ReadOnlyStorage, dst content.Storage, desc ocispec.Descriptor) error { rc, err := src.Fetch(ctx, desc) diff --git a/copy_test.go b/copy_test.go index 89ac7ed1e..02421524c 100644 --- a/copy_test.go +++ b/copy_test.go @@ -1456,7 +1456,7 @@ func TestCopyGraph_WithOptions(t *testing.T) { }, } if err := oras.CopyGraph(ctx, src, dst, root, opts); err != nil { - t.Fatalf("CopyGraph() error = %v, wantErr %v", err, errdef.ErrSizeExceedsLimit) + t.Fatalf("CopyGraph() error = %v", err) } if got, expected := dst.numExists.Load(), int64(7); got != expected { @@ -1471,11 +1471,418 @@ func TestCopyGraph_WithOptions(t *testing.T) { t.Errorf("count(Push()) = %d, want %d", got, expected) } }) + + t.Run("MountFrom mounted", func(t *testing.T) { + root = descs[6] + dst := &countingStorage{storage: cas.NewMemory()} + var numMount atomic.Int64 + dst.mount = func(ctx context.Context, + desc ocispec.Descriptor, + fromRepo string, + getContent func() (io.ReadCloser, error), + ) error { + numMount.Add(1) + if expected := "source"; fromRepo != expected { + t.Fatalf("fromRepo = %v, want %v", fromRepo, expected) + } + rc, err := src.Fetch(ctx, desc) + if err != nil { + t.Fatalf("Failed to fetch content: %v", err) + } + defer rc.Close() + err = dst.storage.Push(ctx, desc, rc) // bypass the counters + if err != nil { + t.Fatalf("Failed to push content: %v", err) + } + return nil + } + opts = oras.CopyGraphOptions{} + var numPreCopy, numPostCopy, numOnMounted, numMountFrom atomic.Int64 + opts.PreCopy = func(ctx context.Context, desc ocispec.Descriptor) error { + numPreCopy.Add(1) + return nil + } + opts.PostCopy = func(ctx context.Context, desc ocispec.Descriptor) error { + numPostCopy.Add(1) + return nil + } + opts.OnMounted = func(ctx context.Context, d ocispec.Descriptor) error { + numOnMounted.Add(1) + return nil + } + opts.MountFrom = func(ctx context.Context, desc ocispec.Descriptor) ([]string, error) { + numMountFrom.Add(1) + return 
[]string{"source"}, nil + } + if err := oras.CopyGraph(ctx, src, dst, root, opts); err != nil { + t.Fatalf("CopyGraph() error = %v", err) + } + + if got, expected := dst.numExists.Load(), int64(7); got != expected { + t.Errorf("count(Exists()) = %d, want %d", got, expected) + } + if got, expected := dst.numFetch.Load(), int64(0); got != expected { + t.Errorf("count(Fetch()) = %d, want %d", got, expected) + } + // 7 (exists) - 4 (blobs mounted, bypassing the Push counter) = 3 pushes expected + if got, expected := dst.numPush.Load(), int64(3); got != expected { + // If we get more than 3 then some blobs were pushed through the counted Push path instead of being mounted. + t.Errorf("count(Push()) = %d, want %d", got, expected) + } + if got, expected := numMount.Load(), int64(4); got != expected { + t.Errorf("count(Mount()) = %d, want %d", got, expected) + } + if got, expected := numOnMounted.Load(), int64(4); got != expected { + t.Errorf("count(OnMounted()) = %d, want %d", got, expected) + } + if got, expected := numMountFrom.Load(), int64(4); got != expected { + t.Errorf("count(MountFrom()) = %d, want %d", got, expected) + } + if got, expected := numPreCopy.Load(), int64(3); got != expected { + t.Errorf("count(PreCopy()) = %d, want %d", got, expected) + } + if got, expected := numPostCopy.Load(), int64(3); got != expected { + t.Errorf("count(PostCopy()) = %d, want %d", got, expected) + } + }) + + t.Run("MountFrom copied", func(t *testing.T) { + root = descs[6] + dst := &countingStorage{storage: cas.NewMemory()} + var numMount atomic.Int64 + dst.mount = func(ctx context.Context, + desc ocispec.Descriptor, + fromRepo string, + getContent func() (io.ReadCloser, error), + ) error { + numMount.Add(1) + if expected := "source"; fromRepo != expected { + t.Fatalf("fromRepo = %v, want %v", fromRepo, expected) + } + + rc, err := getContent() + if err != nil { + t.Fatalf("Failed to fetch content: %v", err) + } + defer rc.Close() + err = dst.storage.Push(ctx, desc, rc) // bypass the counters + if err != nil { + t.Fatalf("Failed 
to push content: %v", err) + } + return nil + } + opts = oras.CopyGraphOptions{} + var numPreCopy, numPostCopy, numOnMounted, numMountFrom atomic.Int64 + opts.PreCopy = func(ctx context.Context, desc ocispec.Descriptor) error { + numPreCopy.Add(1) + return nil + } + opts.PostCopy = func(ctx context.Context, desc ocispec.Descriptor) error { + numPostCopy.Add(1) + return nil + } + opts.OnMounted = func(ctx context.Context, d ocispec.Descriptor) error { + numOnMounted.Add(1) + return nil + } + opts.MountFrom = func(ctx context.Context, desc ocispec.Descriptor) ([]string, error) { + numMountFrom.Add(1) + return []string{"source"}, nil + } + if err := oras.CopyGraph(ctx, src, dst, root, opts); err != nil { + t.Fatalf("CopyGraph() error = %v", err) + } + + if got, expected := dst.numExists.Load(), int64(7); got != expected { + t.Errorf("count(Exists()) = %d, want %d", got, expected) + } + if got, expected := dst.numFetch.Load(), int64(0); got != expected { + t.Errorf("count(Fetch()) = %d, want %d", got, expected) + } + // 7 (exists) - 4 (blobs routed through Mount, which bypasses the Push counter) = 3 pushes expected + if got, expected := dst.numPush.Load(), int64(3); got != expected { + // If we get more than 3 then some blobs were pushed through the counted Push path instead of going through Mount. 
+ t.Errorf("count(Push()) = %d, want %d", got, expected) + } + if got, expected := numMount.Load(), int64(4); got != expected { + t.Errorf("count(Mount()) = %d, want %d", got, expected) + } + if got, expected := numOnMounted.Load(), int64(0); got != expected { + t.Errorf("count(OnMounted()) = %d, want %d", got, expected) + } + if got, expected := numMountFrom.Load(), int64(4); got != expected { + t.Errorf("count(MountFrom()) = %d, want %d", got, expected) + } + if got, expected := numPreCopy.Load(), int64(7); got != expected { + t.Errorf("count(PreCopy()) = %d, want %d", got, expected) + } + if got, expected := numPostCopy.Load(), int64(7); got != expected { + t.Errorf("count(PostCopy()) = %d, want %d", got, expected) + } + }) + + t.Run("MountFrom mounted second try", func(t *testing.T) { + root = descs[6] + dst := &countingStorage{storage: cas.NewMemory()} + var numMount atomic.Int64 + dst.mount = func(ctx context.Context, + desc ocispec.Descriptor, + fromRepo string, + getContent func() (io.ReadCloser, error), + ) error { + numMount.Add(1) + switch fromRepo { + case "source": + rc, err := src.Fetch(ctx, desc) + if err != nil { + t.Fatalf("Failed to fetch content: %v", err) + } + defer rc.Close() + err = dst.storage.Push(ctx, desc, rc) // bypass the counters + if err != nil { + t.Fatalf("Failed to push content: %v", err) + } + return nil + case "missing/the/data": + // simulate a registry mount will fail, so it will request the content to start the copy. 
+ rc, err := getContent() + if err != nil { + return fmt.Errorf("getContent failed: %w", err) + } + defer rc.Close() + err = dst.storage.Push(ctx, desc, rc) // bypass the counters + if err != nil { + t.Fatalf("Failed to push content: %v", err) + } + return nil + default: + t.Fatalf("fromRepo = %v, want either %v or %v", fromRepo, "missing/the/data", "source") + return nil + } + } + opts = oras.CopyGraphOptions{} + var numPreCopy, numPostCopy, numOnMounted, numMountFrom atomic.Int64 + opts.PreCopy = func(ctx context.Context, desc ocispec.Descriptor) error { + numPreCopy.Add(1) + return nil + } + opts.PostCopy = func(ctx context.Context, desc ocispec.Descriptor) error { + numPostCopy.Add(1) + return nil + } + opts.OnMounted = func(ctx context.Context, d ocispec.Descriptor) error { + numOnMounted.Add(1) + return nil + } + opts.MountFrom = func(ctx context.Context, desc ocispec.Descriptor) ([]string, error) { + numMountFrom.Add(1) + return []string{"missing/the/data", "source"}, nil + } + if err := oras.CopyGraph(ctx, src, dst, root, opts); err != nil { + t.Fatalf("CopyGraph() error = %v", err) + } + + if got, expected := dst.numExists.Load(), int64(7); got != expected { + t.Errorf("count(Exists()) = %d, want %d", got, expected) + } + if got, expected := dst.numFetch.Load(), int64(0); got != expected { + t.Errorf("count(Fetch()) = %d, want %d", got, expected) + } + // 7 (exists) - 4 (blobs mounted, bypassing the Push counter) = 3 pushes expected + if got, expected := dst.numPush.Load(), int64(3); got != expected { + // If we get more than 3 then some blobs were pushed through the counted Push path instead of being mounted. 
+ t.Errorf("count(Push()) = %d, want %d", got, expected) + } + if got, expected := numMount.Load(), int64(4*2); got != expected { + t.Errorf("count(Mount()) = %d, want %d", got, expected) + } + if got, expected := numOnMounted.Load(), int64(4); got != expected { + t.Errorf("count(OnMounted()) = %d, want %d", got, expected) + } + if got, expected := numMountFrom.Load(), int64(4); got != expected { + t.Errorf("count(MountFrom()) = %d, want %d", got, expected) + } + if got, expected := numPreCopy.Load(), int64(3); got != expected { + t.Errorf("count(PreCopy()) = %d, want %d", got, expected) + } + if got, expected := numPostCopy.Load(), int64(3); got != expected { + t.Errorf("count(PostCopy()) = %d, want %d", got, expected) + } + }) + + t.Run("MountFrom copied dst not a Mounter", func(t *testing.T) { + root = descs[6] + dst := cas.NewMemory() + opts = oras.CopyGraphOptions{} + var numPreCopy, numPostCopy, numOnMounted, numMountFrom atomic.Int64 + opts.PreCopy = func(ctx context.Context, desc ocispec.Descriptor) error { + numPreCopy.Add(1) + return nil + } + opts.PostCopy = func(ctx context.Context, desc ocispec.Descriptor) error { + numPostCopy.Add(1) + return nil + } + opts.OnMounted = func(ctx context.Context, d ocispec.Descriptor) error { + numOnMounted.Add(1) + return nil + } + opts.MountFrom = func(ctx context.Context, desc ocispec.Descriptor) ([]string, error) { + numMountFrom.Add(1) + return []string{"source"}, nil + } + if err := oras.CopyGraph(ctx, src, dst, root, opts); err != nil { + t.Fatalf("CopyGraph() error = %v", err) + } + + if got, expected := numOnMounted.Load(), int64(0); got != expected { + t.Errorf("count(OnMounted()) = %d, want %d", got, expected) + } + if got, expected := numMountFrom.Load(), int64(0); got != expected { + t.Errorf("count(MountFrom()) = %d, want %d", got, expected) + } + if got, expected := numPreCopy.Load(), int64(7); got != expected { + t.Errorf("count(PreCopy()) = %d, want %d", got, expected) + } + if got, expected := 
numPostCopy.Load(), int64(7); got != expected { + t.Errorf("count(PostCopy()) = %d, want %d", got, expected) + } + }) + + t.Run("MountFrom empty sourceRepositories", func(t *testing.T) { + root = descs[6] + dst := &countingStorage{storage: cas.NewMemory()} + opts = oras.CopyGraphOptions{} + var numMountFrom atomic.Int64 + opts.MountFrom = func(ctx context.Context, desc ocispec.Descriptor) ([]string, error) { + numMountFrom.Add(1) + return nil, nil + } + if err := oras.CopyGraph(ctx, src, dst, root, opts); err != nil { + t.Fatalf("CopyGraph() error = %v", err) + } + + if got, expected := dst.numExists.Load(), int64(7); got != expected { + t.Errorf("count(Exists()) = %d, want %d", got, expected) + } + if got, expected := dst.numFetch.Load(), int64(0); got != expected { + t.Errorf("count(Fetch()) = %d, want %d", got, expected) + } + if got, expected := dst.numPush.Load(), int64(7); got != expected { + t.Errorf("count(Push()) = %d, want %d", got, expected) + } + if got, expected := numMountFrom.Load(), int64(4); got != expected { + t.Errorf("count(MountFrom()) = %d, want %d", got, expected) + } + }) + + t.Run("MountFrom error", func(t *testing.T) { + root = descs[6] + dst := &countingStorage{storage: cas.NewMemory()} + opts = oras.CopyGraphOptions{} + var numMountFrom atomic.Int64 + e := errors.New("mountFrom error") + opts.MountFrom = func(ctx context.Context, desc ocispec.Descriptor) ([]string, error) { + numMountFrom.Add(1) + return nil, e + } + if err := oras.CopyGraph(ctx, src, dst, root, opts); !errors.Is(err, e) { + t.Fatalf("CopyGraph() error = %v, wantErr %v", err, e) + } + + if got, expected := dst.numExists.Load(), int64(7); got != expected { + t.Errorf("count(Exists()) = %d, want %d", got, expected) + } + if got, expected := dst.numFetch.Load(), int64(0); got != expected { + t.Errorf("count(Fetch()) = %d, want %d", got, expected) + } + if got, expected := dst.numPush.Load(), int64(0); got != expected { + t.Errorf("count(Push()) = %d, want %d", got, 
expected) + } + if got, expected := numMountFrom.Load(), int64(4); got != expected { + t.Errorf("count(MountFrom()) = %d, want %d", got, expected) + } + }) + + t.Run("MountFrom OnMounted error", func(t *testing.T) { + root = descs[6] + dst := &countingStorage{storage: cas.NewMemory()} + var numMount atomic.Int64 + dst.mount = func(ctx context.Context, + desc ocispec.Descriptor, + fromRepo string, + getContent func() (io.ReadCloser, error), + ) error { + numMount.Add(1) + if expected := "source"; fromRepo != expected { + t.Fatalf("fromRepo = %v, want %v", fromRepo, expected) + } + rc, err := src.Fetch(ctx, desc) + if err != nil { + t.Fatalf("Failed to fetch content: %v", err) + } + defer rc.Close() + err = dst.storage.Push(ctx, desc, rc) // bypass the counters + if err != nil { + t.Fatalf("Failed to push content: %v", err) + } + return nil + } + opts = oras.CopyGraphOptions{} + var numPreCopy, numPostCopy, numOnMounted, numMountFrom atomic.Int64 + opts.PreCopy = func(ctx context.Context, desc ocispec.Descriptor) error { + numPreCopy.Add(1) + return nil + } + opts.PostCopy = func(ctx context.Context, desc ocispec.Descriptor) error { + numPostCopy.Add(1) + return nil + } + e := errors.New("onMounted error") + opts.OnMounted = func(ctx context.Context, d ocispec.Descriptor) error { + numOnMounted.Add(1) + return e + } + opts.MountFrom = func(ctx context.Context, desc ocispec.Descriptor) ([]string, error) { + numMountFrom.Add(1) + return []string{"source"}, nil + } + if err := oras.CopyGraph(ctx, src, dst, root, opts); !errors.Is(err, e) { + t.Fatalf("CopyGraph() error = %v, wantErr %v", err, e) + } + + if got, expected := dst.numExists.Load(), int64(7); got != expected { + t.Errorf("count(Exists()) = %d, want %d", got, expected) + } + if got, expected := dst.numFetch.Load(), int64(0); got != expected { + t.Errorf("count(Fetch()) = %d, want %d", got, expected) + } + if got, expected := dst.numPush.Load(), int64(0); got != expected { + t.Errorf("count(Push()) = %d, want 
%d", got, expected) + } + if got, expected := numMount.Load(), int64(4); got != expected { + t.Errorf("count(Mount()) = %d, want %d", got, expected) + } + if got, expected := numOnMounted.Load(), int64(4); got != expected { + t.Errorf("count(OnMounted()) = %d, want %d", got, expected) + } + if got, expected := numMountFrom.Load(), int64(4); got != expected { + t.Errorf("count(MountFrom()) = %d, want %d", got, expected) + } + if got, expected := numPreCopy.Load(), int64(0); got != expected { + t.Errorf("count(PreCopy()) = %d, want %d", got, expected) + } + if got, expected := numPostCopy.Load(), int64(0); got != expected { + t.Errorf("count(PostCopy()) = %d, want %d", got, expected) + } + }) } // countingStorage counts the calls to its content.Storage methods type countingStorage struct { - storage content.Storage + storage content.Storage + mount mountFunc + numExists, numFetch, numPush atomic.Int64 } @@ -1494,6 +1901,16 @@ func (cs *countingStorage) Push(ctx context.Context, target ocispec.Descriptor, return cs.storage.Push(ctx, target, r) } +type mountFunc func(context.Context, ocispec.Descriptor, string, func() (io.ReadCloser, error)) error + +func (cs *countingStorage) Mount(ctx context.Context, + desc ocispec.Descriptor, + fromRepo string, + getContent func() (io.ReadCloser, error), +) error { + return cs.mount(ctx, desc, fromRepo, getContent) +} + func TestCopyGraph_WithConcurrencyLimit(t *testing.T) { src := cas.NewMemory() // generate test content diff --git a/example_copy_test.go b/example_copy_test.go index 58ee9f564..acfeb8188 100644 --- a/example_copy_test.go +++ b/example_copy_test.go @@ -215,6 +215,46 @@ func ExampleCopy_remoteToRemote() { // sha256:7cbb44b44e8ede5a89cf193db3f5f2fd019d89697e6b87e8ed2589e60649b0d1 } +func ExampleCopy_remoteToRemoteWithMount() { + reg, err := remote.NewRegistry(remoteHost) + if err != nil { + panic(err) // Handle error + } + ctx := context.Background() + src, err := reg.Repository(ctx, "source") + if err != nil { + 
panic(err) // Handle error + } + dst, err := reg.Repository(ctx, "target") + if err != nil { + panic(err) // Handle error + } + + tagName := "latest" + + opts := oras.CopyOptions{} + // optionally be notified that a mount occurred. + opts.OnMounted = func(ctx context.Context, desc ocispec.Descriptor) error { + // log.Println("Mounted", desc.Digest) + return nil + } + + // Enable cross-repository blob mounting + opts.MountFrom = func(ctx context.Context, desc ocispec.Descriptor) ([]string, error) { + // the slice of source repositories may also come from a database of known locations of blobs + return []string{"source/repository/name"}, nil + } + + desc, err := oras.Copy(ctx, src, tagName, dst, tagName, opts) + if err != nil { + panic(err) // Handle error + } + fmt.Println("Final", desc.Digest) + + // Output: + // Final sha256:7cbb44b44e8ede5a89cf193db3f5f2fd019d89697e6b87e8ed2589e60649b0d1 +} + func ExampleCopy_remoteToLocal() { reg, err := remote.NewRegistry(remoteHost) if err != nil { diff --git a/registry/remote/repository_test.go b/registry/remote/repository_test.go index b66aec462..583f9e812 100644 --- a/registry/remote/repository_test.go +++ b/registry/remote/repository_test.go @@ -421,16 +421,37 @@ func TestRepository_Mount_Fallback(t *testing.T) { repo.PlainHTTP = true ctx := context.Background() - err = repo.Mount(ctx, blobDesc, "test", nil) - if err != nil { - t.Fatalf("Repository.Push() error = %v", err) - } - if !bytes.Equal(gotBlob, blob) { - t.Errorf("Repository.Mount() = %v, want %v", gotBlob, blob) - } - if got, want := sequence, "post get put "; got != want { - t.Errorf("unexpected request sequence; got %q want %q", got, want) - } + t.Run("getContent is nil", func(t *testing.T) { + sequence = "" + + err = repo.Mount(ctx, blobDesc, "test", nil) + if err != nil { + t.Fatalf("Repository.Push() error = %v", err) + } + if !bytes.Equal(gotBlob, blob) { + t.Errorf("Repository.Mount() = %v, want %v", gotBlob, blob) + } + if got, want := sequence, "post get 
put "; got != want { + t.Errorf("unexpected request sequence; got %q want %q", got, want) + } + }) + + t.Run("getContent is non nil", func(t *testing.T) { + sequence = "" + + err = repo.Mount(ctx, blobDesc, "test", func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(blob)), nil + }) + if err != nil { + t.Fatalf("Repository.Push() error = %v", err) + } + if !bytes.Equal(gotBlob, blob) { + t.Errorf("Repository.Mount() = %v, want %v", gotBlob, blob) + } + if got, want := sequence, "post put "; got != want { + t.Errorf("unexpected request sequence; got %q want %q", got, want) + } + }) } func TestRepository_Mount_Error(t *testing.T) {