diff --git a/copy.go b/copy.go index 9fc66ff1..66f4d7e9 100644 --- a/copy.go +++ b/copy.go @@ -342,5 +342,22 @@ func prepareCopy(ctx context.Context, dst Target, dstRef string, proxy *cas.Prox } } + onCopySkipped := opts.OnCopySkipped + opts.OnCopySkipped = func(ctx context.Context, desc ocispec.Descriptor) error { + if onCopySkipped != nil { + if err := onCopySkipped(ctx, desc); err != nil { + return err + } + } + if !descriptor.EqualOCI(desc, root) { + return nil + } + // enforce tagging when root is skipped + if refPusher, ok := dst.(registry.ReferencePusher); ok { + return copyCachedNodeWithReference(ctx, proxy, refPusher, desc, dstRef) + } + return dst.Tag(ctx, root, dstRef) + } + return nil } diff --git a/copy_test.go b/copy_test.go index 5ac56726..c8c36714 100644 --- a/copy_test.go +++ b/copy_test.go @@ -135,6 +135,114 @@ func TestCopy_FullCopy(t *testing.T) { } } +func TestCopy_ExistedRoot(t *testing.T) { + src := memory.New() + dst := memory.New() + + // generate test content + var blobs [][]byte + var descs []ocispec.Descriptor + appendBlob := func(mediaType string, blob []byte) { + blobs = append(blobs, blob) + descs = append(descs, ocispec.Descriptor{ + MediaType: mediaType, + Digest: digest.FromBytes(blob), + Size: int64(len(blob)), + }) + } + generateManifest := func(config ocispec.Descriptor, layers ...ocispec.Descriptor) { + manifest := ocispec.Manifest{ + Config: config, + Layers: layers, + } + manifestJSON, err := json.Marshal(manifest) + if err != nil { + t.Fatal(err) + } + appendBlob(ocispec.MediaTypeImageManifest, manifestJSON) + } + + appendBlob(ocispec.MediaTypeImageConfig, []byte("config")) // Blob 0 + appendBlob(ocispec.MediaTypeImageLayer, []byte("foo")) // Blob 1 + appendBlob(ocispec.MediaTypeImageLayer, []byte("bar")) // Blob 2 + generateManifest(descs[0], descs[1:3]...) // Blob 3 + + ctx := context.Background() + for i := range blobs { + err := src.Push(ctx, descs[i], bytes.NewReader(blobs[i])) + if err != nil { + t.Fatalf("failed to push test content to src: %d: %v", i, err) + } + } + + root := descs[3] + ref := "foobar" + newTag := "newtag" + err := src.Tag(ctx, root, ref) + if err != nil { + t.Fatal("fail to tag root node", err) + } + + var skippedCount int64 + copyOpts := oras.CopyOptions{ + CopyGraphOptions: oras.CopyGraphOptions{ + OnCopySkipped: func(ctx context.Context, desc ocispec.Descriptor) error { + atomic.AddInt64(&skippedCount, 1) + return nil + }, + }, + } + + // copy with src tag + gotDesc, err := oras.Copy(ctx, src, ref, dst, "", copyOpts) + if err != nil { + t.Fatalf("Copy() error = %v, wantErr %v", err, false) + } + if !reflect.DeepEqual(gotDesc, root) { + t.Errorf("Copy() = %v, want %v", gotDesc, root) + } + // copy with new tag + gotDesc, err = oras.Copy(ctx, src, ref, dst, newTag, copyOpts) + if err != nil { + t.Fatalf("Copy() error = %v, wantErr %v", err, false) + } + if !reflect.DeepEqual(gotDesc, root) { + t.Errorf("Copy() = %v, want %v", gotDesc, root) + } + + // verify contents + for i, desc := range descs { + exists, err := dst.Exists(ctx, desc) + if err != nil { + t.Fatalf("dst.Exists(%d) error = %v", i, err) + } + if !exists { + t.Errorf("dst.Exists(%d) = %v, want %v", i, exists, true) + } + } + + // verify src tag + gotDesc, err = dst.Resolve(ctx, ref) + if err != nil { + t.Fatal("dst.Resolve() error =", err) + } + if !reflect.DeepEqual(gotDesc, root) { + t.Errorf("dst.Resolve() = %v, want %v", gotDesc, root) + } + // verify new tag + gotDesc, err = dst.Resolve(ctx, newTag) + if err != nil { + t.Fatal("dst.Resolve() error =", err) + } + if !reflect.DeepEqual(gotDesc, root) { + t.Errorf("dst.Resolve() = %v, want %v", gotDesc, root) + } + // verify invocation of onCopySkipped() + if got, want := skippedCount, int64(1); got != want { + t.Errorf("count(OnCopySkipped()) = %v, want %v", got, want) + } +} + func TestCopyGraph_FullCopy(t *testing.T) { src := cas.NewMemory() dst := cas.NewMemory()