diff --git a/.github/.codecov.yml b/.github/codecov.yml similarity index 96% rename from .github/.codecov.yml rename to .github/codecov.yml index e686805d..7ca1cad2 100644 --- a/.github/.codecov.yml +++ b/.github/codecov.yml @@ -15,5 +15,5 @@ coverage: status: project: default: - target: 70% + target: 80% if_ci_failed: error \ No newline at end of file diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index ef8691dd..4e1eac42 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -15,20 +15,24 @@ name: build on: push: - branches: main + branches: + - main + - release-* pull_request: - branches: main + branches: + - main + - release-* jobs: build: runs-on: ubuntu-latest strategy: matrix: - go-version: ['1.19', '1.20'] + go-version: ['1.20', '1.21'] fail-fast: true steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Go ${{ matrix.go-version }} environment uses: actions/setup-go@v4 with: diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 47b50790..4c8067d4 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -15,9 +15,13 @@ name: CodeQL on: push: - branches: main + branches: + - main + - release-* pull_request: - branches: main + branches: + - main + - release-* schedule: - cron: '34 13 * * 3' @@ -31,11 +35,11 @@ jobs: security-events: write strategy: matrix: - go-version: ['1.19', '1.20'] + go-version: ['1.20', '1.21'] fail-fast: false steps: - name: Checkout repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Go ${{ matrix.go-version }} environment uses: actions/setup-go@v4 with: diff --git a/.github/workflows/license-checker.yml b/.github/workflows/license-checker.yml index cc91c47a..33c39831 100644 --- a/.github/workflows/license-checker.yml +++ b/.github/workflows/license-checker.yml @@ -15,9 +15,13 @@ name: License Checker on: push: - branches: main + branches: + - main + - release-* pull_request: - branches: main + branches: + - main + - release-* permissions: contents: write @@ -28,13 +32,13 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Check license header - uses: apache/skywalking-eyes/header@v0.4.0 + uses: apache/skywalking-eyes/header@v0.5.0 with: mode: check config: .github/licenserc.yml - name: Check dependencies license - uses: apache/skywalking-eyes/dependency@v0.4.0 + uses: apache/skywalking-eyes/dependency@v0.5.0 with: config: .github/licenserc.yml diff --git a/README.md b/README.md index 50bb3064..70662555 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,9 @@ The ORAS Go library follows [Semantic Versioning](https://semver.org/), where br The version `2` is actively developed in the [`main`](https://github.com/oras-project/oras-go/tree/main) branch with all new features. +> [!Note] +> The `main` branch follows [Go's Security Policy](https://github.com/golang/go/security/policy) and supports the two latest versions of Go (currently `1.20` and `1.21`). + Examples for common use cases can be found below: - [Copy examples](https://pkg.go.dev/oras.land/oras-go/v2#pkg-examples) diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 00000000..ffefe341 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,3 @@ +# Security Policy + +Please follow the [security policy](https://oras.land/docs/community/reporting_security_concerns) to report a security vulnerability or concern. diff --git a/content.go b/content.go index 53eb6c75..b8bf2638 100644 --- a/content.go +++ b/content.go @@ -29,7 +29,6 @@ import ( "oras.land/oras-go/v2/internal/docker" "oras.land/oras-go/v2/internal/interfaces" "oras.land/oras-go/v2/internal/platform" - "oras.land/oras-go/v2/internal/registryutil" "oras.land/oras-go/v2/internal/syncutil" "oras.land/oras-go/v2/registry" "oras.land/oras-go/v2/registry/remote/auth" @@ -91,7 +90,7 @@ func TagN(ctx context.Context, target Target, srcReference string, dstReferences if err != nil { return ocispec.Descriptor{}, err } - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull, auth.ActionPush) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull, auth.ActionPush) } desc, contentBytes, err := FetchBytes(ctx, target, srcReference, FetchBytesOptions{ @@ -149,7 +148,7 @@ func Tag(ctx context.Context, target Target, src, dst string) (ocispec.Descripto if err != nil { return ocispec.Descriptor{}, err } - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull, auth.ActionPush) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull, auth.ActionPush) } desc, rc, err := refFetcher.FetchReference(ctx, src) if err != nil { diff --git a/content/graph.go b/content/graph.go index fa2f9efe..9ae83728 100644 --- a/content/graph.go +++ b/content/graph.go @@ -75,18 +75,33 @@ func Successors(ctx context.Context, fetcher Fetcher, node ocispec.Descriptor) ( } nodes = append(nodes, manifest.Config) return append(nodes, manifest.Layers...), nil - case docker.MediaTypeManifestList, ocispec.MediaTypeImageIndex: + case docker.MediaTypeManifestList: content, err := FetchAll(ctx, fetcher, node) if err != nil { return nil, err } - // docker manifest list and oci index are equivalent for successors. + // OCI manifest index schema can be used to marshal docker manifest list var index ocispec.Index if err := json.Unmarshal(content, &index); err != nil { return nil, err } return index.Manifests, nil + case ocispec.MediaTypeImageIndex: + content, err := FetchAll(ctx, fetcher, node) + if err != nil { + return nil, err + } + + var index ocispec.Index + if err := json.Unmarshal(content, &index); err != nil { + return nil, err + } + var nodes []ocispec.Descriptor + if index.Subject != nil { + nodes = append(nodes, *index.Subject) + } + return append(nodes, index.Manifests...), nil case spec.MediaTypeArtifactManifest: content, err := FetchAll(ctx, fetcher, node) if err != nil { diff --git a/content/graph_test.go b/content/graph_test.go new file mode 100644 index 00000000..d90dc81a --- /dev/null +++ b/content/graph_test.go @@ -0,0 +1,391 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package content_test + +import ( + "bytes" + "context" + "encoding/json" + "reflect" + "testing" + + "github.com/opencontainers/go-digest" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + "oras.land/oras-go/v2/content" + "oras.land/oras-go/v2/internal/cas" + "oras.land/oras-go/v2/internal/docker" + "oras.land/oras-go/v2/internal/spec" +) + +func TestSuccessors_dockerManifest(t *testing.T) { + storage := cas.NewMemory() + + // generate test content + var blobs [][]byte + var descs []ocispec.Descriptor + appendBlob := func(mediaType string, blob []byte) { + blobs = append(blobs, blob) + descs = append(descs, ocispec.Descriptor{ + MediaType: mediaType, + Digest: digest.FromBytes(blob), + Size: int64(len(blob)), + }) + } + generateManifest := func(config ocispec.Descriptor, layers ...ocispec.Descriptor) { + manifest := ocispec.Manifest{ + Config: config, + Layers: layers, + } + manifestJSON, err := json.Marshal(manifest) + if err != nil { + t.Fatal(err) + } + appendBlob(docker.MediaTypeManifest, manifestJSON) + } + + appendBlob(ocispec.MediaTypeImageConfig, []byte("config")) // Blob 0 + appendBlob(ocispec.MediaTypeImageLayer, []byte("foo")) // Blob 1 + appendBlob(ocispec.MediaTypeImageLayer, []byte("bar")) // Blob 2 + appendBlob(ocispec.MediaTypeImageLayer, []byte("hello")) // Blob 3 + generateManifest(descs[0], descs[1:4]...) // Blob 4 + + ctx := context.Background() + for i := range blobs { + err := storage.Push(ctx, descs[i], bytes.NewReader(blobs[i])) + if err != nil { + t.Fatalf("failed to push test content to src: %d: %v", i, err) + } + } + + // test Successors + manifestDesc := descs[4] + got, err := content.Successors(ctx, storage, manifestDesc) + if err != nil { + t.Fatal("Successors() error =", err) + } + if want := descs[0:4]; !reflect.DeepEqual(got, want) { + t.Errorf("Successors() = %v, want %v", got, want) + } +} + +func TestSuccessors_imageManifest(t *testing.T) { + storage := cas.NewMemory() + + // generate test content + var blobs [][]byte + var descs []ocispec.Descriptor + appendBlob := func(mediaType string, blob []byte) { + blobs = append(blobs, blob) + descs = append(descs, ocispec.Descriptor{ + MediaType: mediaType, + Digest: digest.FromBytes(blob), + Size: int64(len(blob)), + }) + } + generateManifest := func(subject *ocispec.Descriptor, config ocispec.Descriptor, layers ...ocispec.Descriptor) { + manifest := ocispec.Manifest{ + Subject: subject, + Config: config, + Layers: layers, + } + manifestJSON, err := json.Marshal(manifest) + if err != nil { + t.Fatal(err) + } + appendBlob(ocispec.MediaTypeImageManifest, manifestJSON) + } + + appendBlob(ocispec.MediaTypeImageConfig, []byte("config")) // Blob 0 + appendBlob(ocispec.MediaTypeImageLayer, []byte("foo")) // Blob 1 + appendBlob(ocispec.MediaTypeImageLayer, []byte("bar")) // Blob 2 + appendBlob(ocispec.MediaTypeImageLayer, []byte("hello")) // Blob 3 + generateManifest(nil, descs[0], descs[1:4]...) // Blob 4 + appendBlob(ocispec.MediaTypeImageConfig, []byte("{}")) // Blob 5 + appendBlob("test/sig", []byte("sig")) // Blob 6 + generateManifest(&descs[4], descs[5], descs[6]) // Blob 7 + + ctx := context.Background() + for i := range blobs { + err := storage.Push(ctx, descs[i], bytes.NewReader(blobs[i])) + if err != nil { + t.Fatalf("failed to push test content to src: %d: %v", i, err) + } + } + + // test Successors: image manifest without a subject + manifestDesc := descs[4] + got, err := content.Successors(ctx, storage, manifestDesc) + if err != nil { + t.Fatal("Successors() error =", err) + } + if want := descs[0:4]; !reflect.DeepEqual(got, want) { + t.Errorf("Successors() = %v, want %v", got, want) + } + + // test Successors: image manifest with a subject + manifestDesc = descs[7] + got, err = content.Successors(ctx, storage, manifestDesc) + if err != nil { + t.Fatal("Successors() error =", err) + } + if want := descs[4:7]; !reflect.DeepEqual(got, want) { + t.Errorf("Successors() = %v, want %v", got, want) + } +} + +func TestSuccessors_dockerManifestList(t *testing.T) { + storage := cas.NewMemory() + + // generate test content + var blobs [][]byte + var descs []ocispec.Descriptor + appendBlob := func(mediaType string, blob []byte) { + blobs = append(blobs, blob) + descs = append(descs, ocispec.Descriptor{ + MediaType: mediaType, + Digest: digest.FromBytes(blob), + Size: int64(len(blob)), + }) + } + generateManifest := func(config ocispec.Descriptor, layers ...ocispec.Descriptor) { + manifest := ocispec.Manifest{ + Config: config, + Layers: layers, + } + manifestJSON, err := json.Marshal(manifest) + if err != nil { + t.Fatal(err) + } + appendBlob(docker.MediaTypeManifest, manifestJSON) + } + generateIndex := func(manifests ...ocispec.Descriptor) { + index := ocispec.Index{ + Manifests: manifests, + } + indexJSON, err := json.Marshal(index) + if err != nil { + t.Fatal(err) + } + appendBlob(docker.MediaTypeManifestList, indexJSON) + } + + appendBlob(ocispec.MediaTypeImageConfig, []byte("config")) // Blob 0 + appendBlob(ocispec.MediaTypeImageLayer, []byte("foo")) // Blob 1 + appendBlob(ocispec.MediaTypeImageLayer, []byte("bar")) // Blob 2 + appendBlob(ocispec.MediaTypeImageLayer, []byte("hello")) // Blob 3 + generateManifest(descs[0], descs[1:3]...) // Blob 4 + generateManifest(descs[0], descs[3]) // Blob 5 + generateIndex(descs[4:6]...) // Blob 6 + + ctx := context.Background() + for i := range blobs { + err := storage.Push(ctx, descs[i], bytes.NewReader(blobs[i])) + if err != nil { + t.Fatalf("failed to push test content to src: %d: %v", i, err) + } + } + + // test Successors + manifestDesc := descs[6] + got, err := content.Successors(ctx, storage, manifestDesc) + if err != nil { + t.Fatal("Successors() error =", err) + } + if want := descs[4:6]; !reflect.DeepEqual(got, want) { + t.Errorf("Successors() = %v, want %v", got, want) + } +} + +func TestSuccessors_imageIndex(t *testing.T) { + storage := cas.NewMemory() + + // generate test content + var blobs [][]byte + var descs []ocispec.Descriptor + appendBlob := func(mediaType string, blob []byte) { + blobs = append(blobs, blob) + descs = append(descs, ocispec.Descriptor{ + MediaType: mediaType, + Digest: digest.FromBytes(blob), + Size: int64(len(blob)), + }) + } + generateManifest := func(subject *ocispec.Descriptor, config ocispec.Descriptor, layers ...ocispec.Descriptor) { + manifest := ocispec.Manifest{ + Subject: subject, + Config: config, + Layers: layers, + } + manifestJSON, err := json.Marshal(manifest) + if err != nil { + t.Fatal(err) + } + appendBlob(ocispec.MediaTypeImageManifest, manifestJSON) + } + generateIndex := func(subject *ocispec.Descriptor, manifests ...ocispec.Descriptor) { + index := ocispec.Index{ + Subject: subject, + Manifests: manifests, + } + indexJSON, err := json.Marshal(index) + if err != nil { + t.Fatal(err) + } + appendBlob(ocispec.MediaTypeImageIndex, indexJSON) + } + + appendBlob(ocispec.MediaTypeImageConfig, []byte("config")) // Blob 0 + appendBlob(ocispec.MediaTypeImageLayer, []byte("foo")) // Blob 1 + appendBlob(ocispec.MediaTypeImageLayer, []byte("bar")) // Blob 2 + appendBlob(ocispec.MediaTypeImageLayer, []byte("hello")) // Blob 3 + generateManifest(nil, descs[0], descs[1:3]...) // Blob 4 + generateManifest(nil, descs[0], descs[3]) // Blob 5 + appendBlob(ocispec.MediaTypeImageConfig, []byte("{}")) // Blob 6 + appendBlob("test/sig", []byte("sig")) // Blob 7 + generateManifest(&descs[4], descs[5], descs[6]) // Blob 8 + generateIndex(&descs[8], descs[4:6]...) // Blob 9 + + ctx := context.Background() + for i := range blobs { + err := storage.Push(ctx, descs[i], bytes.NewReader(blobs[i])) + if err != nil { + t.Fatalf("failed to push test content to src: %d: %v", i, err) + } + } + + // test Successors + manifestDesc := descs[9] + got, err := content.Successors(ctx, storage, manifestDesc) + if err != nil { + t.Fatal("Successors() error =", err) + } + if want := append([]ocispec.Descriptor{descs[8]}, descs[4:6]...); !reflect.DeepEqual(got, want) { + t.Errorf("Successors() = %v, want %v", got, want) + } +} + +func TestSuccessors_artifactManifest(t *testing.T) { + storage := cas.NewMemory() + + // generate test content + var blobs [][]byte + var descs []ocispec.Descriptor + appendBlob := func(mediaType string, blob []byte) { + blobs = append(blobs, blob) + descs = append(descs, ocispec.Descriptor{ + MediaType: mediaType, + Digest: digest.FromBytes(blob), + Size: int64(len(blob)), + }) + } + generateArtifactManifest := func(subject *ocispec.Descriptor, blobs ...ocispec.Descriptor) { + manifest := spec.Artifact{ + Subject: subject, + Blobs: blobs, + } + manifestJSON, err := json.Marshal(manifest) + if err != nil { + t.Fatal(err) + } + appendBlob(spec.MediaTypeArtifactManifest, manifestJSON) + } + + appendBlob(ocispec.MediaTypeImageLayer, []byte("foo")) // Blob 0 + appendBlob(ocispec.MediaTypeImageLayer, []byte("bar")) // Blob 1 + appendBlob(ocispec.MediaTypeImageLayer, []byte("hello")) // Blob 2 + generateArtifactManifest(nil, descs[0:3]...) // Blob 3 + appendBlob("test/sig", []byte("sig")) // Blob 4 + generateArtifactManifest(&descs[3], descs[4]) // Blob 5 + + ctx := context.Background() + for i := range blobs { + err := storage.Push(ctx, descs[i], bytes.NewReader(blobs[i])) + if err != nil { + t.Fatalf("failed to push test content to src: %d: %v", i, err) + } + } + + // test Successors: image manifest without a subject + manifestDesc := descs[3] + got, err := content.Successors(ctx, storage, manifestDesc) + if err != nil { + t.Fatal("Successors() error =", err) + } + if want := descs[0:3]; !reflect.DeepEqual(got, want) { + t.Errorf("Successors() = %v, want %v", got, want) + } + + // test Successors: image manifest with a subject + manifestDesc = descs[5] + got, err = content.Successors(ctx, storage, manifestDesc) + if err != nil { + t.Fatal("Successors() error =", err) + } + if want := descs[3:5]; !reflect.DeepEqual(got, want) { + t.Errorf("Successors() = %v, want %v", got, want) + } +} + +func TestSuccessors_otherMediaType(t *testing.T) { + storage := cas.NewMemory() + + // generate test content + var blobs [][]byte + var descs []ocispec.Descriptor + appendBlob := func(mediaType string, blob []byte) { + blobs = append(blobs, blob) + descs = append(descs, ocispec.Descriptor{ + MediaType: mediaType, + Digest: digest.FromBytes(blob), + Size: int64(len(blob)), + }) + } + generateManifest := func(mediaType string, config ocispec.Descriptor, layers ...ocispec.Descriptor) { + manifest := ocispec.Manifest{ + Config: config, + Layers: layers, + } + manifestJSON, err := json.Marshal(manifest) + if err != nil { + t.Fatal(err) + } + appendBlob(mediaType, manifestJSON) + } + + appendBlob(ocispec.MediaTypeImageConfig, []byte("config")) // Blob 0 + appendBlob(ocispec.MediaTypeImageLayer, []byte("foo")) // Blob 1 + appendBlob(ocispec.MediaTypeImageLayer, []byte("bar")) // Blob 2 + appendBlob(ocispec.MediaTypeImageLayer, []byte("hello")) // Blob 3 + generateManifest("whatever", descs[0], descs[1:4]...) // Blob 4 + + ctx := context.Background() + for i := range blobs { + err := storage.Push(ctx, descs[i], bytes.NewReader(blobs[i])) + if err != nil { + t.Fatalf("failed to push test content to src: %d: %v", i, err) + } + } + + // test Successors: other media type + manifestDesc := descs[4] + got, err := content.Successors(ctx, storage, manifestDesc) + if err != nil { + t.Fatal("Successors() error =", err) + } + if got != nil { + t.Errorf("Successors() = %v, want nil", got) + } +} diff --git a/content/oci/oci.go b/content/oci/oci.go index a473e5c1..27afde16 100644 --- a/content/oci/oci.go +++ b/content/oci/oci.go @@ -14,7 +14,7 @@ limitations under the License. */ // Package oci provides access to an OCI content store. -// Reference: https://github.com/opencontainers/image-spec/blob/v1.1.0-rc2/image-layout.md +// Reference: https://github.com/opencontainers/image-spec/blob/v1.1.0-rc4/image-layout.md package oci import ( @@ -38,14 +38,9 @@ import ( "oras.land/oras-go/v2/internal/resolver" ) -// ociImageIndexFile is the file name of the index -// from the OCI Image Layout Specification. -// Reference: https://github.com/opencontainers/image-spec/blob/v1.1.0-rc2/image-layout.md#indexjson-file -const ociImageIndexFile = "index.json" - // Store implements `oras.Target`, and represents a content store // based on file system with the OCI-Image layout. -// Reference: https://github.com/opencontainers/image-spec/blob/v1.1.0-rc2/image-layout.md +// Reference: https://github.com/opencontainers/image-spec/blob/v1.1.0-rc4/image-layout.md type Store struct { // AutoSaveIndex controls if the OCI store will automatically save the index // file on each Tag() call. @@ -84,20 +79,20 @@ func NewWithContext(ctx context.Context, root string) (*Store, error) { store := &Store{ AutoSaveIndex: true, root: rootAbs, - indexPath: filepath.Join(rootAbs, ociImageIndexFile), + indexPath: filepath.Join(rootAbs, ocispec.ImageIndexFile), storage: storage, tagResolver: resolver.NewMemory(), graph: graph.NewMemory(), } - if err := ensureDir(rootAbs); err != nil { + if err := ensureDir(filepath.Join(rootAbs, ocispec.ImageBlobsDir)); err != nil { return nil, err } if err := store.ensureOCILayoutFile(); err != nil { return nil, fmt.Errorf("invalid OCI Image Layout: %w", err) } if err := store.loadIndexFile(ctx); err != nil { - return nil, fmt.Errorf("invalid OCI Image Layout: %w", err) + return nil, fmt.Errorf("invalid OCI Image Index: %w", err) } return store, nil @@ -130,7 +125,7 @@ func (s *Store) Exists(ctx context.Context, target ocispec.Descriptor) (bool, er // Tag tags a descriptor with a reference string. // reference should be a valid tag (e.g. "latest"). -// Reference: https://github.com/opencontainers/image-spec/blob/v1.1.0-rc2/image-layout.md#indexjson-file +// Reference: https://github.com/opencontainers/image-spec/blob/v1.1.0-rc4/image-layout.md#indexjson-file func (s *Store) Tag(ctx context.Context, desc ocispec.Descriptor, reference string) error { if err := validateReference(reference); err != nil { return err diff --git a/content/oci/oci_test.go b/content/oci/oci_test.go index 931dabf5..43e57f0a 100644 --- a/content/oci/oci_test.go +++ b/content/oci/oci_test.go @@ -109,7 +109,7 @@ func TestStore_Success(t *testing.T) { } // validate index.json - indexFilePath := filepath.Join(tempDir, ociImageIndexFile) + indexFilePath := filepath.Join(tempDir, "index.json") indexFile, err := os.Open(indexFilePath) if err != nil { t.Errorf("error opening layout file, error = %v", err) @@ -361,7 +361,7 @@ func TestStore_NotExistingRoot(t *testing.T) { } // validate index.json - indexFilePath := filepath.Join(root, ociImageIndexFile) + indexFilePath := filepath.Join(root, "index.json") indexFile, err := os.Open(indexFilePath) if err != nil { t.Errorf("error opening layout file, error = %v", err) @@ -930,7 +930,7 @@ func TestStore_TagByDigest(t *testing.T) { func TestStore_BadIndex(t *testing.T) { tempDir := t.TempDir() content := []byte("whatever") - path := filepath.Join(tempDir, ociImageIndexFile) + path := filepath.Join(tempDir, "index.json") os.WriteFile(path, content, 0666) _, err := New(tempDir) diff --git a/content/oci/readonlyoci.go b/content/oci/readonlyoci.go index b70b7675..66ca54c9 100644 --- a/content/oci/readonlyoci.go +++ b/content/oci/readonlyoci.go @@ -36,7 +36,7 @@ import ( // ReadOnlyStore implements `oras.ReadonlyTarget`, and represents a read-only // content store based on file system with the OCI-Image layout. -// Reference: https://github.com/opencontainers/image-spec/blob/v1.1.0-rc2/image-layout.md +// Reference: https://github.com/opencontainers/image-spec/blob/v1.1.0-rc4/image-layout.md type ReadOnlyStore struct { fsys fs.FS storage content.ReadOnlyStorage @@ -57,7 +57,7 @@ func NewFromFS(ctx context.Context, fsys fs.FS) (*ReadOnlyStore, error) { return nil, fmt.Errorf("invalid OCI Image Layout: %w", err) } if err := store.loadIndexFile(ctx); err != nil { - return nil, fmt.Errorf("invalid OCI Image Layout: %w", err) + return nil, fmt.Errorf("invalid OCI Image Index: %w", err) } return store, nil @@ -154,7 +154,7 @@ func validateOCILayout(layout *ocispec.ImageLayout) error { // loadIndexFile reads index.json from s.fsys. func (s *ReadOnlyStore) loadIndexFile(ctx context.Context) error { - indexFile, err := s.fsys.Open(ociImageIndexFile) + indexFile, err := s.fsys.Open(ocispec.ImageIndexFile) if err != nil { return fmt.Errorf("failed to open index file: %w", err) } diff --git a/content/oci/readonlyoci_test.go b/content/oci/readonlyoci_test.go index 346c126f..d7b4fd0e 100644 --- a/content/oci/readonlyoci_test.go +++ b/content/oci/readonlyoci_test.go @@ -127,7 +127,7 @@ func TestReadOnlyStore(t *testing.T) { fsys[path] = &fstest.MapFile{Data: blobs[i]} } fsys[ocispec.ImageLayoutFile] = &fstest.MapFile{Data: layoutJSON} - fsys[ociImageIndexFile] = &fstest.MapFile{Data: indexJSON} + fsys["index.json"] = &fstest.MapFile{Data: indexJSON} // test read-only store ctx := context.Background() @@ -507,7 +507,7 @@ func TestReadOnlyStore_TarFS(t *testing.T) { func TestReadOnlyStore_BadIndex(t *testing.T) { content := []byte("whatever") fsys := fstest.MapFS{ - ociImageIndexFile: &fstest.MapFile{Data: content}, + "index.json": &fstest.MapFile{Data: content}, } ctx := context.Background() @@ -607,7 +607,7 @@ func TestReadOnlyStore_Copy_OCIToMemory(t *testing.T) { fsys[path] = &fstest.MapFile{Data: blobs[i]} } fsys[ocispec.ImageLayoutFile] = &fstest.MapFile{Data: layoutJSON} - fsys[ociImageIndexFile] = &fstest.MapFile{Data: indexJSON} + fsys["index.json"] = &fstest.MapFile{Data: indexJSON} // test read-only store ctx := context.Background() @@ -721,7 +721,7 @@ func TestReadOnlyStore_Tags(t *testing.T) { fsys[path] = &fstest.MapFile{Data: blobs[i]} } fsys[ocispec.ImageLayoutFile] = &fstest.MapFile{Data: layoutJSON} - fsys[ociImageIndexFile] = &fstest.MapFile{Data: indexJSON} + fsys["index.json"] = &fstest.MapFile{Data: indexJSON} // test read-only store ctx := context.Background() diff --git a/content/oci/readonlystorage.go b/content/oci/readonlystorage.go index 8a5b39c5..6e319a64 100644 --- a/content/oci/readonlystorage.go +++ b/content/oci/readonlystorage.go @@ -31,7 +31,7 @@ import ( // ReadOnlyStorage is a read-only CAS based on file system with the OCI-Image // layout. -// Reference: https://github.com/opencontainers/image-spec/blob/v1.1.0-rc2/image-layout.md +// Reference: https://github.com/opencontainers/image-spec/blob/v1.1.0-rc4/image-layout.md type ReadOnlyStorage struct { fsys fs.FS } @@ -95,5 +95,5 @@ func blobPath(dgst digest.Digest) (string, error) { return "", fmt.Errorf("cannot calculate blob path from invalid digest %s: %w: %v", dgst.String(), errdef.ErrInvalidDigest, err) } - return path.Join("blobs", dgst.Algorithm().String(), dgst.Encoded()), nil + return path.Join(ocispec.ImageBlobsDir, dgst.Algorithm().String(), dgst.Encoded()), nil } diff --git a/content/oci/storage.go b/content/oci/storage.go index 6b0e90a8..efb9f3d8 100644 --- a/content/oci/storage.go +++ b/content/oci/storage.go @@ -20,6 +20,7 @@ import ( "errors" "fmt" "io" + "io/fs" "os" "path/filepath" "sync" @@ -42,7 +43,7 @@ var bufPool = sync.Pool{ } // Storage is a CAS based on file system with the OCI-Image layout. -// Reference: https://github.com/opencontainers/image-spec/blob/v1.1.0-rc2/image-layout.md +// Reference: https://github.com/opencontainers/image-spec/blob/v1.1.0-rc4/image-layout.md type Storage struct { *ReadOnlyStorage // root is the root directory of the OCI layout. @@ -106,6 +107,23 @@ func (s *Storage) Push(_ context.Context, expected ocispec.Descriptor, content i return nil } +// Delete removes the target from the system. +func (s *Storage) Delete(ctx context.Context, target ocispec.Descriptor) error { + path, err := blobPath(target.Digest) + if err != nil { + return fmt.Errorf("%s: %s: %w", target.Digest, target.MediaType, errdef.ErrInvalidDigest) + } + targetPath := filepath.Join(s.root, path) + err = os.Remove(targetPath) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("%s: %s: %w", target.Digest, target.MediaType, errdef.ErrNotFound) + } + return err + } + return nil +} + // ingest write the content into a temporary ingest file. func (s *Storage) ingest(expected ocispec.Descriptor, content io.Reader) (path string, ingestErr error) { if err := ensureDir(s.ingestRoot); err != nil { diff --git a/content/oci/storage_test.go b/content/oci/storage_test.go index 7e3b1e58..244acb21 100644 --- a/content/oci/storage_test.go +++ b/content/oci/storage_test.go @@ -377,3 +377,43 @@ func TestStorage_Fetch_Concurrent(t *testing.T) { t.Fatal(err) } } + +func TestStorage_Delete(t *testing.T) { + content := []byte("test delete") + desc := ocispec.Descriptor{ + MediaType: "test", + Digest: digest.FromBytes(content), + Size: int64(len(content)), + } + tempDir := t.TempDir() + s, err := NewStorage(tempDir) + if err != nil { + t.Fatal("New() error =", err) + } + ctx := context.Background() + if err := s.Push(ctx, desc, bytes.NewReader(content)); err != nil { + t.Fatal("Storage.Push() error =", err) + } + exists, err := s.Exists(ctx, desc) + if err != nil { + t.Fatal("Storage.Exists() error =", err) + } + if !exists { + t.Errorf("Storage.Exists() = %v, want %v", exists, true) + } + err = s.Delete(ctx, desc) + if err != nil { + t.Fatal("Storage.Delete() error =", err) + } + exists, err = s.Exists(ctx, desc) + if err != nil { + t.Fatal("Storage.Exists() error =", err) + } + if exists { + t.Errorf("Storage.Exists() = %v, want %v", exists, false) + } + err = s.Delete(ctx, desc) + if !errors.Is(err, errdef.ErrNotFound) { + t.Fatalf("got error = %v, want %v", err, errdef.ErrNotFound) + } +} diff --git a/content/reader.go b/content/reader.go index 11d27b23..e575378e 100644 --- a/content/reader.go +++ b/content/reader.go @@ -70,7 +70,7 @@ func (vr *VerifyReader) Read(p []byte) (n int, err error) { return } -// Verify verifies the read content against the size and the digest. +// Verify checks for remaining unread content and verifies the read content against the digest func (vr *VerifyReader) Verify() error { if vr.verified { return nil @@ -120,7 +120,10 @@ func ReadAll(r io.Reader, desc ocispec.Descriptor) ([]byte, error) { buf := make([]byte, desc.Size) vr := NewVerifyReader(r, desc) - if _, err := io.ReadFull(vr, buf); err != nil { + if n, err := io.ReadFull(vr, buf); err != nil { + if errors.Is(err, io.ErrUnexpectedEOF) { + return nil, fmt.Errorf("read failed: expected content size of %d, got %d, for digest %s: %w", desc.Size, n, desc.Digest.String(), err) + } return nil, fmt.Errorf("read failed: %w", err) } if err := vr.Verify(); err != nil { diff --git a/content/storage.go b/content/storage.go index 971142cb..47c95d87 100644 --- a/content/storage.go +++ b/content/storage.go @@ -31,7 +31,7 @@ type Fetcher interface { // Pusher pushes content. type Pusher interface { // Push pushes the content, matching the expected descriptor. - // Reader is perferred to Writer so that the suitable buffer size can be + // Reader is preferred to Writer so that the suitable buffer size can be // chosen by the underlying implementation. Furthermore, the implementation // can also do reflection on the Reader for more advanced I/O optimization. Push(ctx context.Context, expected ocispec.Descriptor, content io.Reader) error diff --git a/copy.go b/copy.go index e55312dd..ddb430b8 100644 --- a/copy.go +++ b/copy.go @@ -37,8 +37,9 @@ import ( // defaultConcurrency is the default value of CopyGraphOptions.Concurrency. const defaultConcurrency int = 3 // This value is consistent with dockerd and containerd. -// errSkipDesc signals copyNode() to stop processing a descriptor. -var errSkipDesc = errors.New("skip descriptor") +// ErrSkipDesc signals to stop copying a descriptor. When returned from PreCopy the blob must exist in the target. +// This can be used to signal that a blob has been made available in the target repository by "Mount()" or some other technique. +var ErrSkipDesc = errors.New("skip descriptor") // DefaultCopyOptions provides the default CopyOptions. var DefaultCopyOptions CopyOptions = CopyOptions{ @@ -281,7 +282,7 @@ func doCopyNode(ctx context.Context, src content.ReadOnlyStorage, dst content.St func copyNode(ctx context.Context, src content.ReadOnlyStorage, dst content.Storage, desc ocispec.Descriptor, opts CopyGraphOptions) error { if opts.PreCopy != nil { if err := opts.PreCopy(ctx, desc); err != nil { - if err == errSkipDesc { + if err == ErrSkipDesc { return nil } return err @@ -373,7 +374,7 @@ func prepareCopy(ctx context.Context, dst Target, dstRef string, proxy *cas.Prox } } // skip the regular copy workflow - return errSkipDesc + return ErrSkipDesc } } else { postCopy := opts.PostCopy @@ -393,18 +394,26 @@ func prepareCopy(ctx context.Context, dst Target, dstRef string, proxy *cas.Prox onCopySkipped := opts.OnCopySkipped opts.OnCopySkipped = func(ctx context.Context, desc ocispec.Descriptor) error { - if onCopySkipped != nil { - if err := onCopySkipped(ctx, desc); err != nil { - return err - } - } if !content.Equal(desc, root) { + if onCopySkipped != nil { + return onCopySkipped(ctx, desc) + } return nil } - // enforce tagging when root is skipped + + // enforce tagging when the skipped node is root if refPusher, ok := dst.(registry.ReferencePusher); ok { + // NOTE: refPusher tags the node by copying it with the reference, + // so onCopySkipped shouldn't be invoked in this case return copyCachedNodeWithReference(ctx, proxy, refPusher, desc, dstRef) } + + // invoke onCopySkipped before tagging + if onCopySkipped != nil { + if err := onCopySkipped(ctx, desc); err != nil { + return err + } + } return dst.Tag(ctx, root, dstRef) } diff --git a/copy_test.go b/copy_test.go index b1031bc6..0b6e6c20 100644 --- a/copy_test.go +++ b/copy_test.go @@ -1432,6 +1432,66 @@ func TestCopyGraph_WithOptions(t *testing.T) { if err := oras.CopyGraph(ctx, src, dst, root, opts); !errors.Is(err, errdef.ErrSizeExceedsLimit) { t.Fatalf("CopyGraph() error = %v, wantErr %v", err, errdef.ErrSizeExceedsLimit) } + + t.Run("ErrSkipDesc", func(t *testing.T) { + // test CopyGraph with PreCopy = 1 + root = descs[6] + dst := &countingStorage{storage: cas.NewMemory()} + opts = oras.CopyGraphOptions{ + PreCopy: func(ctx context.Context, desc ocispec.Descriptor) error { + if descs[1].Digest == desc.Digest { + // blob 1 is handled by us (really this would be a Mount but ) + rc, err := src.Fetch(ctx, desc) + if err != nil { + t.Fatalf("Failed to fetch: %v", err) + } + defer rc.Close() + err = dst.storage.Push(ctx, desc, rc) // bypass the counters + if err != nil { + t.Fatalf("Failed to fetch: %v", err) + } + return oras.ErrSkipDesc + } + return nil + }, + } + if err := oras.CopyGraph(ctx, src, dst, root, opts); err != nil { + t.Fatalf("CopyGraph() error = %v, wantErr %v", err, errdef.ErrSizeExceedsLimit) + } + + if got, expected := dst.numExists.Load(), int64(7); got != expected { + t.Errorf("count(Exists()) = %d, want %d", got, expected) + } + if got, expected := dst.numFetch.Load(), int64(0); got != expected { + t.Errorf("count(Fetch()) = %d, want %d", got, expected) + } + // 7 (exists) - 1 (skipped) = 6 pushes expected + if got, expected := dst.numPush.Load(), int64(6); got != expected { + // If we get >=7 then ErrSkipDesc did not short circuit the push like it is supposed to do. + t.Errorf("count(Push()) = %d, want %d", got, expected) + } + }) +} + +// countingStorage counts the calls to its content.Storage methods +type countingStorage struct { + storage content.Storage + numExists, numFetch, numPush atomic.Int64 +} + +func (cs *countingStorage) Exists(ctx context.Context, target ocispec.Descriptor) (bool, error) { + cs.numExists.Add(1) + return cs.storage.Exists(ctx, target) +} + +func (cs *countingStorage) Fetch(ctx context.Context, target ocispec.Descriptor) (io.ReadCloser, error) { + cs.numFetch.Add(1) + return cs.storage.Fetch(ctx, target) +} + +func (cs *countingStorage) Push(ctx context.Context, target ocispec.Descriptor, r io.Reader) error { + cs.numPush.Add(1) + return cs.storage.Push(ctx, target, r) } func TestCopyGraph_WithConcurrencyLimit(t *testing.T) { diff --git a/errdef/errors.go b/errdef/errors.go index 030360ed..7adb44b1 100644 --- a/errdef/errors.go +++ b/errdef/errors.go @@ -22,6 +22,7 @@ var ( ErrAlreadyExists = errors.New("already exists") ErrInvalidDigest = errors.New("invalid digest") ErrInvalidReference = errors.New("invalid reference") + ErrInvalidMediaType = errors.New("invalid media type") ErrMissingReference = errors.New("missing reference") ErrNotFound = errors.New("not found") ErrSizeExceedsLimit = errors.New("size exceeds limit") diff --git a/example_copy_test.go b/example_copy_test.go new file mode 100644 index 00000000..b7e524a9 --- /dev/null +++ b/example_copy_test.go @@ -0,0 +1,357 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package oras_test + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "os" + "strconv" + "strings" + "testing" + + "github.com/opencontainers/go-digest" + specs "github.com/opencontainers/image-spec/specs-go" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + "oras.land/oras-go/v2" + "oras.land/oras-go/v2/content/memory" + "oras.land/oras-go/v2/content/oci" + "oras.land/oras-go/v2/internal/spec" + "oras.land/oras-go/v2/registry/remote" +) + +var exampleMemoryStore oras.Target +var remoteHost string +var ( + exampleManifest, _ = json.Marshal(spec.Artifact{ + MediaType: spec.MediaTypeArtifactManifest, + ArtifactType: "example/content"}) + exampleManifestDescriptor = ocispec.Descriptor{ + MediaType: spec.MediaTypeArtifactManifest, + Digest: digest.Digest(digest.FromBytes(exampleManifest)), + Size: int64(len(exampleManifest))} + exampleSignatureManifest, _ = json.Marshal(spec.Artifact{ + MediaType: spec.MediaTypeArtifactManifest, + ArtifactType: "example/signature", + Subject: &exampleManifestDescriptor}) + exampleSignatureManifestDescriptor = ocispec.Descriptor{ + MediaType: spec.MediaTypeArtifactManifest, + Digest: digest.FromBytes(exampleSignatureManifest), + Size: int64(len(exampleSignatureManifest))} +) + +func pushBlob(ctx context.Context, mediaType string, blob []byte, target oras.Target) (desc ocispec.Descriptor, err error) { + desc = ocispec.Descriptor{ // Generate descriptor based on the media type and blob content + MediaType: mediaType, + Digest: digest.FromBytes(blob), // Calculate digest + Size: int64(len(blob)), // Include blob size + } + return desc, target.Push(ctx, desc, bytes.NewReader(blob)) // Push the blob to the registry target +} + +func generateManifestContent(config ocispec.Descriptor, layers ...ocispec.Descriptor) ([]byte, error) { + content := ocispec.Manifest{ + Config: config, // Set config blob + Layers: layers, // Set layer blobs + Versioned: specs.Versioned{SchemaVersion: 2}, + } + return json.Marshal(content) // Get json content +} + +func TestMain(m *testing.M) { + const exampleTag = "latest" + const exampleUploadUUid = "0bc84d80-837c-41d9-824e-1907463c53b3" + + // Setup example local target + exampleMemoryStore = memory.New() + layerBlob := []byte("Hello layer") + ctx := context.Background() + layerDesc, err := pushBlob(ctx, ocispec.MediaTypeImageLayer, layerBlob, exampleMemoryStore) // push layer blob + if err != nil { + panic(err) + } + configBlob := []byte("Hello config") + configDesc, err := pushBlob(ctx, ocispec.MediaTypeImageConfig, configBlob, exampleMemoryStore) // push config blob + if err != nil { + panic(err) + } + manifestBlob, err := generateManifestContent(configDesc, layerDesc) // generate a image manifest + if err != nil { + panic(err) + } + manifestDesc, err := pushBlob(ctx, ocispec.MediaTypeImageManifest, manifestBlob, exampleMemoryStore) // push manifest blob + if err != nil { + panic(err) + } + err = exampleMemoryStore.Tag(ctx, manifestDesc, exampleTag) + if err != nil { + panic(err) + } + + // Setup example remote target + httpsServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + p := r.URL.Path + m := r.Method + switch { + case strings.Contains(p, "/blobs/uploads/") && m == "POST": + w.Header().Set("Content-Type", ocispec.MediaTypeImageManifest) + w.Header().Set("Location", p+exampleUploadUUid) + w.WriteHeader(http.StatusAccepted) + case strings.Contains(p, "/blobs/uploads/"+exampleUploadUUid) && m == "GET": + w.WriteHeader(http.StatusCreated) + case strings.Contains(p, "/manifests/"+string(exampleSignatureManifestDescriptor.Digest)): + w.Header().Set("Content-Type", spec.MediaTypeArtifactManifest) + w.Header().Set("Docker-Content-Digest", string(exampleSignatureManifestDescriptor.Digest)) + w.Header().Set("Content-Length", strconv.Itoa(len(exampleSignatureManifest))) + w.Write(exampleSignatureManifest) + case strings.Contains(p, "/manifests/latest") && m == "PUT": + w.WriteHeader(http.StatusCreated) + case strings.Contains(p, "/manifests/"+string(exampleManifestDescriptor.Digest)), + strings.Contains(p, "/manifests/latest") && m == "HEAD": + w.Header().Set("Content-Type", spec.MediaTypeArtifactManifest) + w.Header().Set("Docker-Content-Digest", string(exampleManifestDescriptor.Digest)) + w.Header().Set("Content-Length", strconv.Itoa(len(exampleManifest))) + if m == "GET" { + w.Write(exampleManifest) + } + case strings.Contains(p, "/v2/source/referrers/"): + var referrers []ocispec.Descriptor + if p == "/v2/source/referrers/"+exampleManifestDescriptor.Digest.String() { + referrers = []ocispec.Descriptor{exampleSignatureManifestDescriptor} + } + result := ocispec.Index{ + Versioned: specs.Versioned{ + SchemaVersion: 2, // historical value. does not pertain to OCI or docker version + }, + MediaType: ocispec.MediaTypeImageIndex, + Manifests: referrers, + } + if err := json.NewEncoder(w).Encode(result); err != nil { + panic(err) + } + case strings.Contains(p, "/manifests/") && (m == "HEAD" || m == "GET"): + w.Header().Set("Content-Type", ocispec.MediaTypeImageManifest) + w.Header().Set("Docker-Content-Digest", string(manifestDesc.Digest)) + w.Header().Set("Content-Length", strconv.Itoa(len([]byte(manifestBlob)))) + w.Write([]byte(manifestBlob)) + case strings.Contains(p, "/blobs/") && (m == "GET" || m == "HEAD"): + arr := strings.Split(p, "/") + digest := arr[len(arr)-1] + var desc ocispec.Descriptor + var content []byte + switch digest { + case layerDesc.Digest.String(): + desc = layerDesc + content = layerBlob + case configDesc.Digest.String(): + desc = configDesc + content = configBlob + case manifestDesc.Digest.String(): + desc = manifestDesc + content = manifestBlob + } + w.Header().Set("Content-Type", desc.MediaType) + w.Header().Set("Docker-Content-Digest", digest) + w.Header().Set("Content-Length", strconv.Itoa(len([]byte(content)))) + w.Write([]byte(content)) + case strings.Contains(p, "/manifests/") && m == "PUT": + w.WriteHeader(http.StatusCreated) + } + + })) + defer httpsServer.Close() + u, err := url.Parse(httpsServer.URL) + if err != nil { + panic(err) + } + remoteHost = u.Host + http.DefaultTransport = httpsServer.Client().Transport + + os.Exit(m.Run()) +} + +func ExampleCopy_remoteToRemote() { + reg, err := remote.NewRegistry(remoteHost) + if err != nil { + panic(err) // Handle error + } + ctx := context.Background() + src, err := reg.Repository(ctx, "source") + if err != nil { + panic(err) // Handle error + } + dst, err := reg.Repository(ctx, "target") + if err != nil { + panic(err) // Handle error + } + + tagName := "latest" + desc, err := oras.Copy(ctx, src, tagName, dst, tagName, oras.DefaultCopyOptions) + if err != nil { + panic(err) // Handle error + } + fmt.Println(desc.Digest) + + // Output: + // sha256:7cbb44b44e8ede5a89cf193db3f5f2fd019d89697e6b87e8ed2589e60649b0d1 +} + +func ExampleCopy_remoteToLocal() { + reg, err := remote.NewRegistry(remoteHost) + if err != nil { + panic(err) // Handle error + } + + ctx := context.Background() + src, err := reg.Repository(ctx, "source") + if err != nil { + panic(err) // Handle error + } + dst := memory.New() + + tagName := "latest" + desc, err := oras.Copy(ctx, src, tagName, dst, tagName, oras.DefaultCopyOptions) + if err != nil { + panic(err) // Handle error + } + fmt.Println(desc.Digest) + + // Output: + // sha256:7cbb44b44e8ede5a89cf193db3f5f2fd019d89697e6b87e8ed2589e60649b0d1 +} + +func ExampleCopy_localToLocal() { + src := exampleMemoryStore + dst := memory.New() + + tagName := "latest" + ctx := context.Background() + desc, err := oras.Copy(ctx, src, tagName, dst, tagName, oras.DefaultCopyOptions) + if err != nil { + panic(err) // Handle error + } + fmt.Println(desc.Digest) + + // Output: + // sha256:7cbb44b44e8ede5a89cf193db3f5f2fd019d89697e6b87e8ed2589e60649b0d1 +} + +func ExampleCopy_localToOciFile() { + src := exampleMemoryStore + tempDir, err := os.MkdirTemp("", "oras_oci_example_*") + if err != nil { + panic(err) // Handle error + } + defer os.RemoveAll(tempDir) + dst, err := oci.New(tempDir) + if err != nil { + panic(err) // Handle error + } + + tagName := "latest" + ctx := context.Background() + desc, err := oras.Copy(ctx, src, tagName, dst, tagName, oras.DefaultCopyOptions) + if err != nil { + panic(err) // Handle error + } + fmt.Println(desc.Digest) + + // Output: + // sha256:7cbb44b44e8ede5a89cf193db3f5f2fd019d89697e6b87e8ed2589e60649b0d1 +} + +func ExampleCopy_localToRemote() { + src := exampleMemoryStore + reg, err := remote.NewRegistry(remoteHost) + if err != nil { + panic(err) // Handle error + } + ctx := context.Background() + dst, err := reg.Repository(ctx, "target") + if err != nil { + panic(err) // Handle error + } + + tagName := "latest" + desc, err := oras.Copy(ctx, src, tagName, dst, tagName, oras.DefaultCopyOptions) + if err != nil { + panic(err) // Handle error + } + fmt.Println(desc.Digest) + + // Output: + // sha256:7cbb44b44e8ede5a89cf193db3f5f2fd019d89697e6b87e8ed2589e60649b0d1 +} + +// ExampleCopyArtifactManifestRemoteToLocal gives an example of copying +// an artifact manifest from a remote repository into memory. +func Example_copyArtifactManifestRemoteToLocal() { + src, err := remote.NewRepository(fmt.Sprintf("%s/source", remoteHost)) + if err != nil { + panic(err) + } + dst := memory.New() + ctx := context.Background() + + exampleDigest := "sha256:70c29a81e235dda5c2cebb8ec06eafd3cca346cbd91f15ac74cefd98681c5b3d" + descriptor, err := src.Resolve(ctx, exampleDigest) + if err != nil { + panic(err) + } + err = oras.CopyGraph(ctx, src, dst, descriptor, oras.DefaultCopyGraphOptions) + if err != nil { + panic(err) + } + + // verify that the artifact manifest described by the descriptor exists in dst + contentExists, err := dst.Exists(ctx, descriptor) + if err != nil { + panic(err) + } + fmt.Println(contentExists) + + // Output: + // true +} + +// ExampleExtendedCopyArtifactAndReferrersRemoteToLocal gives an example of +// copying an artifact along with its referrers from a remote repository into +// memory. +func Example_extendedCopyArtifactAndReferrersRemoteToLocal() { + src, err := remote.NewRepository(fmt.Sprintf("%s/source", remoteHost)) + if err != nil { + panic(err) + } + dst := memory.New() + ctx := context.Background() + + tagName := "latest" + // ExtendedCopy will copy the artifact tagged by "latest" along with all of its + // referrers from src to dst. + desc, err := oras.ExtendedCopy(ctx, src, tagName, dst, tagName, oras.DefaultExtendedCopyOptions) + if err != nil { + panic(err) + } + + fmt.Println(desc.Digest) + // Output: + // sha256:f396bc4d300934a39ca28ab0d5ac8a3573336d7d63c654d783a68cd1e2057662 +} diff --git a/example_pack_test.go b/example_pack_test.go new file mode 100644 index 00000000..a1d301c5 --- /dev/null +++ b/example_pack_test.go @@ -0,0 +1,98 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package oras_test + +import ( + "context" + "fmt" + + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + "oras.land/oras-go/v2" + "oras.land/oras-go/v2/content" + "oras.land/oras-go/v2/content/memory" +) + +// ExampleImageV11RC4 demonstrates packing an OCI Image Manifest as defined in +// image-spec v1.1.0-rc4. +func ExamplePackManifest_imageV11RC4() { + // 0. Create a storage + store := memory.New() + + // 1. Set optional parameters + opts := oras.PackManifestOptions{ + ManifestAnnotations: map[string]string{ + // this timestamp will be automatically generated if not specified + // use a fixed value here in order to test the output + ocispec.AnnotationCreated: "2000-01-01T00:00:00Z", + }, + } + ctx := context.Background() + + // 2. Pack a manifest + artifactType := "application/vnd.example+type" + manifestDesc, err := oras.PackManifest(ctx, store, oras.PackManifestVersion1_1_RC4, artifactType, opts) + if err != nil { + panic(err) + } + fmt.Println("Manifest descriptor:", manifestDesc) + + // 3. Verify the packed manifest + manifestData, err := content.FetchAll(ctx, store, manifestDesc) + if err != nil { + panic(err) + } + fmt.Println("Manifest content:", string(manifestData)) + + // Output: + // Manifest descriptor: {application/vnd.oci.image.manifest.v1+json sha256:c259a195a48d8029d75449579c81269ca6225cd5b57d36073a7de6458afdfdbd 528 [] map[org.opencontainers.image.created:2000-01-01T00:00:00Z] [] application/vnd.example+type} + // Manifest content: {"schemaVersion":2,"mediaType":"application/vnd.oci.image.manifest.v1+json","artifactType":"application/vnd.example+type","config":{"mediaType":"application/vnd.oci.empty.v1+json","digest":"sha256:44136fa355b3678a1146ad16f7e8649e94fb4fc21fe77e8310c060f61caaff8a","size":2,"data":"e30="},"layers":[{"mediaType":"application/vnd.oci.empty.v1+json","digest":"sha256:44136fa355b3678a1146ad16f7e8649e94fb4fc21fe77e8310c060f61caaff8a","size":2,"data":"e30="}],"annotations":{"org.opencontainers.image.created":"2000-01-01T00:00:00Z"}} +} + +// ExampleImageV10 demonstrates packing an OCI Image Manifest as defined in +// image-spec v1.0.2. +func ExamplePackManifest_imageV10() { + // 0. Create a storage + store := memory.New() + + // 1. Set optional parameters + opts := oras.PackManifestOptions{ + ManifestAnnotations: map[string]string{ + // this timestamp will be automatically generated if not specified + // use a fixed value here in order to test the output + ocispec.AnnotationCreated: "2000-01-01T00:00:00Z", + }, + } + ctx := context.Background() + + // 2. Pack a manifest + artifactType := "application/vnd.example+type" + manifestDesc, err := oras.PackManifest(ctx, store, oras.PackManifestVersion1_0, artifactType, opts) + if err != nil { + panic(err) + } + fmt.Println("Manifest descriptor:", manifestDesc) + + // 3. Verify the packed manifest + manifestData, err := content.FetchAll(ctx, store, manifestDesc) + if err != nil { + panic(err) + } + fmt.Println("Manifest content:", string(manifestData)) + + // Output: + // Manifest descriptor: {application/vnd.oci.image.manifest.v1+json sha256:da221a11559704e4971c3dcf6564303707a333c8de8cb5475fc48b0072b36c19 308 [] map[org.opencontainers.image.created:2000-01-01T00:00:00Z] [] application/vnd.example+type} + // Manifest content: {"schemaVersion":2,"mediaType":"application/vnd.oci.image.manifest.v1+json","config":{"mediaType":"application/vnd.example+type","digest":"sha256:44136fa355b3678a1146ad16f7e8649e94fb4fc21fe77e8310c060f61caaff8a","size":2},"layers":[],"annotations":{"org.opencontainers.image.created":"2000-01-01T00:00:00Z"}} +} diff --git a/example_test.go b/example_test.go index 91fb9da0..f5829641 100644 --- a/example_test.go +++ b/example_test.go @@ -16,341 +16,189 @@ limitations under the License. package oras_test import ( - "bytes" "context" - "encoding/json" "fmt" - "net/http" - "net/http/httptest" - "net/url" - "os" - "strconv" - "strings" - "testing" - "github.com/opencontainers/go-digest" - specs "github.com/opencontainers/image-spec/specs-go" - ocispec "github.com/opencontainers/image-spec/specs-go/v1" - "oras.land/oras-go/v2" - "oras.land/oras-go/v2/content/memory" + v1 "github.com/opencontainers/image-spec/specs-go/v1" + oras "oras.land/oras-go/v2" + "oras.land/oras-go/v2/content/file" "oras.land/oras-go/v2/content/oci" - "oras.land/oras-go/v2/internal/spec" "oras.land/oras-go/v2/registry/remote" + "oras.land/oras-go/v2/registry/remote/auth" + "oras.land/oras-go/v2/registry/remote/credentials" + "oras.land/oras-go/v2/registry/remote/retry" ) -var exampleMemoryStore oras.Target -var remoteHost string -var ( - exampleManifest, _ = json.Marshal(spec.Artifact{ - MediaType: spec.MediaTypeArtifactManifest, - ArtifactType: "example/content"}) - exampleManifestDescriptor = ocispec.Descriptor{ - MediaType: spec.MediaTypeArtifactManifest, - Digest: digest.Digest(digest.FromBytes(exampleManifest)), - Size: int64(len(exampleManifest))} - exampleSignatureManifest, _ = json.Marshal(spec.Artifact{ - MediaType: spec.MediaTypeArtifactManifest, - ArtifactType: "example/signature", - Subject: &exampleManifestDescriptor}) - exampleSignatureManifestDescriptor = ocispec.Descriptor{ - MediaType: spec.MediaTypeArtifactManifest, - Digest: digest.FromBytes(exampleSignatureManifest), - Size: int64(len(exampleSignatureManifest))} -) - -func pushBlob(ctx context.Context, mediaType string, blob []byte, target oras.Target) (desc ocispec.Descriptor, err error) { - desc = ocispec.Descriptor{ // Generate descriptor based on the media type and blob content - MediaType: mediaType, - Digest: digest.FromBytes(blob), // Calculate digest - Size: int64(len(blob)), // Include blob size - } - return desc, target.Push(ctx, desc, bytes.NewReader(blob)) // Push the blob to the registry target -} - -func generateManifestContent(config ocispec.Descriptor, layers ...ocispec.Descriptor) ([]byte, error) { - content := ocispec.Manifest{ - Config: config, // Set config blob - Layers: layers, // Set layer blobs - Versioned: specs.Versioned{SchemaVersion: 2}, - } - return json.Marshal(content) // Get json content -} - -func TestMain(m *testing.M) { - const exampleTag = "latest" - const exampleUploadUUid = "0bc84d80-837c-41d9-824e-1907463c53b3" - - // Setup example local target - exampleMemoryStore = memory.New() - layerBlob := []byte("Hello layer") - ctx := context.Background() - layerDesc, err := pushBlob(ctx, ocispec.MediaTypeImageLayer, layerBlob, exampleMemoryStore) // push layer blob - if err != nil { - panic(err) - } - configBlob := []byte("Hello config") - configDesc, err := pushBlob(ctx, ocispec.MediaTypeImageConfig, configBlob, exampleMemoryStore) // push config blob - if err != nil { - panic(err) - } - manifestBlob, err := generateManifestContent(configDesc, layerDesc) // generate a image manifest - if err != nil { - panic(err) - } - manifestDesc, err := pushBlob(ctx, ocispec.MediaTypeImageManifest, manifestBlob, exampleMemoryStore) // push manifest blob - if err != nil { - panic(err) - } - err = exampleMemoryStore.Tag(ctx, manifestDesc, exampleTag) +// ExamplePullFilesFromRemoteRepository gives an example of pulling files from +// a remote repository to the local file system. +func Example_pullFilesFromRemoteRepository() { + // 0. Create a file store + fs, err := file.New("/tmp/") if err != nil { panic(err) } + defer fs.Close() - // Setup example remote target - httpsServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - p := r.URL.Path - m := r.Method - switch { - case strings.Contains(p, "/blobs/uploads/") && m == "POST": - w.Header().Set("Content-Type", ocispec.MediaTypeImageManifest) - w.Header().Set("Location", p+exampleUploadUUid) - w.WriteHeader(http.StatusAccepted) - case strings.Contains(p, "/blobs/uploads/"+exampleUploadUUid) && m == "GET": - w.WriteHeader(http.StatusCreated) - case strings.Contains(p, "/manifests/"+string(exampleSignatureManifestDescriptor.Digest)): - w.Header().Set("Content-Type", spec.MediaTypeArtifactManifest) - w.Header().Set("Docker-Content-Digest", string(exampleSignatureManifestDescriptor.Digest)) - w.Header().Set("Content-Length", strconv.Itoa(len(exampleSignatureManifest))) - w.Write(exampleSignatureManifest) - case strings.Contains(p, "/manifests/latest") && m == "PUT": - w.WriteHeader(http.StatusCreated) - case strings.Contains(p, "/manifests/"+string(exampleManifestDescriptor.Digest)), - strings.Contains(p, "/manifests/latest") && m == "HEAD": - w.Header().Set("Content-Type", spec.MediaTypeArtifactManifest) - w.Header().Set("Docker-Content-Digest", string(exampleManifestDescriptor.Digest)) - w.Header().Set("Content-Length", strconv.Itoa(len(exampleManifest))) - if m == "GET" { - w.Write(exampleManifest) - } - case strings.Contains(p, "/v2/source/referrers/"): - var referrers []ocispec.Descriptor - if p == "/v2/source/referrers/"+exampleManifestDescriptor.Digest.String() { - referrers = []ocispec.Descriptor{exampleSignatureManifestDescriptor} - } - result := ocispec.Index{ - Versioned: specs.Versioned{ - SchemaVersion: 2, // historical value. does not pertain to OCI or docker version - }, - MediaType: ocispec.MediaTypeImageIndex, - Manifests: referrers, - } - if err := json.NewEncoder(w).Encode(result); err != nil { - panic(err) - } - case strings.Contains(p, "/manifests/") && (m == "HEAD" || m == "GET"): - w.Header().Set("Content-Type", ocispec.MediaTypeImageManifest) - w.Header().Set("Docker-Content-Digest", string(manifestDesc.Digest)) - w.Header().Set("Content-Length", strconv.Itoa(len([]byte(manifestBlob)))) - w.Write([]byte(manifestBlob)) - case strings.Contains(p, "/blobs/") && (m == "GET" || m == "HEAD"): - arr := strings.Split(p, "/") - digest := arr[len(arr)-1] - var desc ocispec.Descriptor - var content []byte - switch digest { - case layerDesc.Digest.String(): - desc = layerDesc - content = layerBlob - case configDesc.Digest.String(): - desc = configDesc - content = configBlob - case manifestDesc.Digest.String(): - desc = manifestDesc - content = manifestBlob - } - w.Header().Set("Content-Type", desc.MediaType) - w.Header().Set("Docker-Content-Digest", digest) - w.Header().Set("Content-Length", strconv.Itoa(len([]byte(content)))) - w.Write([]byte(content)) - case strings.Contains(p, "/manifests/") && m == "PUT": - w.WriteHeader(http.StatusCreated) - } - - })) - defer httpsServer.Close() - u, err := url.Parse(httpsServer.URL) - if err != nil { - panic(err) - } - remoteHost = u.Host - http.DefaultTransport = httpsServer.Client().Transport - - os.Exit(m.Run()) -} - -func ExampleCopy_remoteToRemote() { - reg, err := remote.NewRegistry(remoteHost) - if err != nil { - panic(err) // Handle error - } + // 1. Connect to a remote repository ctx := context.Background() - src, err := reg.Repository(ctx, "source") + reg := "myregistry.example.com" + repo, err := remote.NewRepository(reg + "/myrepo") if err != nil { - panic(err) // Handle error + panic(err) } - dst, err := reg.Repository(ctx, "target") - if err != nil { - panic(err) // Handle error + // Note: The below code can be omitted if authentication is not required + repo.Client = &auth.Client{ + Client: retry.DefaultClient, + Cache: auth.NewCache(), + Credential: auth.StaticCredential(reg, auth.Credential{ + Username: "username", + Password: "password", + }), } - tagName := "latest" - desc, err := oras.Copy(ctx, src, tagName, dst, tagName, oras.DefaultCopyOptions) + // 2. Copy from the remote repository to the file store + tag := "latest" + manifestDescriptor, err := oras.Copy(ctx, repo, tag, fs, tag, oras.DefaultCopyOptions) if err != nil { - panic(err) // Handle error + panic(err) } - fmt.Println(desc.Digest) - - // Output: - // sha256:7cbb44b44e8ede5a89cf193db3f5f2fd019d89697e6b87e8ed2589e60649b0d1 + fmt.Println("manifest descriptor:", manifestDescriptor) } -func ExampleCopy_remoteToLocal() { - reg, err := remote.NewRegistry(remoteHost) +// ExamplePullImageFromRemoteRepository gives an example of pulling an image +// from a remote repository to an OCI Image layout folder. +func Example_pullImageFromRemoteRepository() { + // 0. Create an OCI layout store + store, err := oci.New("/tmp/oci-layout-root") if err != nil { - panic(err) // Handle error + panic(err) } + // 1. Connect to a remote repository ctx := context.Background() - src, err := reg.Repository(ctx, "source") + reg := "myregistry.example.com" + repo, err := remote.NewRepository(reg + "/myrepo") if err != nil { - panic(err) // Handle error + panic(err) } - dst := memory.New() - - tagName := "latest" - desc, err := oras.Copy(ctx, src, tagName, dst, tagName, oras.DefaultCopyOptions) - if err != nil { - panic(err) // Handle error + // Note: The below code can be omitted if authentication is not required + repo.Client = &auth.Client{ + Client: retry.DefaultClient, + Cache: auth.NewCache(), + Credential: auth.StaticCredential(reg, auth.Credential{ + Username: "username", + Password: "password", + }), } - fmt.Println(desc.Digest) - - // Output: - // sha256:7cbb44b44e8ede5a89cf193db3f5f2fd019d89697e6b87e8ed2589e60649b0d1 -} - -func ExampleCopy_localToLocal() { - src := exampleMemoryStore - dst := memory.New() - tagName := "latest" - ctx := context.Background() - desc, err := oras.Copy(ctx, src, tagName, dst, tagName, oras.DefaultCopyOptions) + // 2. Copy from the remote repository to the OCI layout store + tag := "latest" + manifestDescriptor, err := oras.Copy(ctx, repo, tag, store, tag, oras.DefaultCopyOptions) if err != nil { - panic(err) // Handle error + panic(err) } - fmt.Println(desc.Digest) - - // Output: - // sha256:7cbb44b44e8ede5a89cf193db3f5f2fd019d89697e6b87e8ed2589e60649b0d1 + fmt.Println("manifest descriptor:", manifestDescriptor) } -func ExampleCopy_localToOciFile() { - src := exampleMemoryStore - tempDir, err := os.MkdirTemp("", "oras_oci_example_*") - if err != nil { - panic(err) // Handle error - } - defer os.RemoveAll(tempDir) - dst, err := oci.New(tempDir) +// ExamplePullImageUsingDockerCredentials gives an example of pulling an image +// from a remote repository to an OCI Image layout folder using Docker +// credentials. +func Example_pullImageUsingDockerCredentials() { + // 0. Create an OCI layout store + store, err := oci.New("/tmp/oci-layout-root") if err != nil { - panic(err) // Handle error + panic(err) } - tagName := "latest" + // 1. Connect to a remote repository ctx := context.Background() - desc, err := oras.Copy(ctx, src, tagName, dst, tagName, oras.DefaultCopyOptions) + reg := "docker.io" + repo, err := remote.NewRepository(reg + "/user/my-repo") if err != nil { - panic(err) // Handle error + panic(err) } - fmt.Println(desc.Digest) - // Output: - // sha256:7cbb44b44e8ede5a89cf193db3f5f2fd019d89697e6b87e8ed2589e60649b0d1 -} - -func ExampleCopy_localToRemote() { - src := exampleMemoryStore - reg, err := remote.NewRegistry(remoteHost) + // prepare authentication using Docker credentials + storeOpts := credentials.StoreOptions{} + credStore, err := credentials.NewStoreFromDocker(storeOpts) if err != nil { - panic(err) // Handle error + panic(err) } - ctx := context.Background() - dst, err := reg.Repository(ctx, "target") - if err != nil { - panic(err) // Handle error + repo.Client = &auth.Client{ + Client: retry.DefaultClient, + Cache: auth.NewCache(), + Credential: credentials.Credential(credStore), // Use the credentials store } - tagName := "latest" - desc, err := oras.Copy(ctx, src, tagName, dst, tagName, oras.DefaultCopyOptions) + // 2. Copy from the remote repository to the OCI layout store + tag := "latest" + manifestDescriptor, err := oras.Copy(ctx, repo, tag, store, tag, oras.DefaultCopyOptions) if err != nil { - panic(err) // Handle error + panic(err) } - fmt.Println(desc.Digest) - // Output: - // sha256:7cbb44b44e8ede5a89cf193db3f5f2fd019d89697e6b87e8ed2589e60649b0d1 + fmt.Println("manifest pulled:", manifestDescriptor.Digest, manifestDescriptor.MediaType) } -// Example_copyArtifactManifestRemoteToLocal gives an example of copying -// an artifact manifest from a remote repository to local. -func Example_copyArtifactManifestRemoteToLocal() { - src, err := remote.NewRepository(fmt.Sprintf("%s/source", remoteHost)) +// ExamplePushFilesToRemoteRepository gives an example of pushing local files +// to a remote repository. +func Example_pushFilesToRemoteRepository() { + // 0. Create a file store + fs, err := file.New("/tmp/") if err != nil { panic(err) } - dst := memory.New() + defer fs.Close() ctx := context.Background() - exampleDigest := "sha256:70c29a81e235dda5c2cebb8ec06eafd3cca346cbd91f15ac74cefd98681c5b3d" - descriptor, err := src.Resolve(ctx, exampleDigest) - if err != nil { - panic(err) + // 1. Add files to the file store + mediaType := "application/vnd.test.file" + fileNames := []string{"/tmp/myfile"} + fileDescriptors := make([]v1.Descriptor, 0, len(fileNames)) + for _, name := range fileNames { + fileDescriptor, err := fs.Add(ctx, name, mediaType, "") + if err != nil { + panic(err) + } + fileDescriptors = append(fileDescriptors, fileDescriptor) + fmt.Printf("file descriptor for %s: %v\n", name, fileDescriptor) } - err = oras.CopyGraph(ctx, src, dst, descriptor, oras.DefaultCopyGraphOptions) + + // 2. Pack the files and tag the packed manifest + artifactType := "application/vnd.test.artifact" + opts := oras.PackManifestOptions{ + Layers: fileDescriptors, + } + manifestDescriptor, err := oras.PackManifest(ctx, fs, oras.PackManifestVersion1_1_RC4, artifactType, opts) if err != nil { panic(err) } + fmt.Println("manifest descriptor:", manifestDescriptor) - // verify that the artifact manifest described by the descriptor exists in dst - contentExists, err := dst.Exists(ctx, descriptor) - if err != nil { + tag := "latest" + if err = fs.Tag(ctx, manifestDescriptor, tag); err != nil { panic(err) } - fmt.Println(contentExists) - - // Output: - // true -} -// Example_extendedCopyArtifactAndReferrersRemoteToLocal gives an example of -// copying an artifact along with its referrers from a remote repository to local. -func Example_extendedCopyArtifactAndReferrersRemoteToLocal() { - src, err := remote.NewRepository(fmt.Sprintf("%s/source", remoteHost)) + // 3. Connect to a remote repository + reg := "myregistry.example.com" + repo, err := remote.NewRepository(reg + "/myrepo") if err != nil { panic(err) } - dst := memory.New() - ctx := context.Background() + // Note: The below code can be omitted if authentication is not required + repo.Client = &auth.Client{ + Client: retry.DefaultClient, + Cache: auth.NewCache(), + Credential: auth.StaticCredential(reg, auth.Credential{ + Username: "username", + Password: "password", + }), + } - tagName := "latest" - // ExtendedCopy will copy the artifact tagged by "latest" along with all of its - // referrers from src to dst. - desc, err := oras.ExtendedCopy(ctx, src, tagName, dst, tagName, oras.DefaultExtendedCopyOptions) + // 4. Copy from the file store to the remote repository + _, err = oras.Copy(ctx, fs, tag, repo, tag, oras.DefaultCopyOptions) if err != nil { panic(err) } - - fmt.Println(desc.Digest) - // Output: - // sha256:f396bc4d300934a39ca28ab0d5ac8a3573336d7d63c654d783a68cd1e2057662 } diff --git a/extendedcopy_test.go b/extendedcopy_test.go index 6c9d7f3a..08a0a8c3 100644 --- a/extendedcopy_test.go +++ b/extendedcopy_test.go @@ -344,6 +344,124 @@ func TestExtendedCopyGraph_PartialCopy(t *testing.T) { } } +func TestExtendedCopyGraph_artifactIndex(t *testing.T) { + // generate test content + var blobs [][]byte + var descs []ocispec.Descriptor + appendBlob := func(mediaType string, blob []byte) { + blobs = append(blobs, blob) + descs = append(descs, ocispec.Descriptor{ + MediaType: mediaType, + Digest: digest.FromBytes(blob), + Size: int64(len(blob)), + }) + } + generateManifest := func(subject *ocispec.Descriptor, config ocispec.Descriptor, layers ...ocispec.Descriptor) { + manifest := ocispec.Manifest{ + Subject: subject, + Config: config, + Layers: layers, + } + manifestJSON, err := json.Marshal(manifest) + if err != nil { + t.Fatal(err) + } + appendBlob(ocispec.MediaTypeImageManifest, manifestJSON) + } + generateIndex := func(subject *ocispec.Descriptor, manifests ...ocispec.Descriptor) { + index := ocispec.Index{ + Subject: subject, + Manifests: manifests, + } + indexJSON, err := json.Marshal(index) + if err != nil { + t.Fatal(err) + } + appendBlob(ocispec.MediaTypeImageIndex, indexJSON) + } + + appendBlob(ocispec.MediaTypeImageConfig, []byte("config_1")) // Blob 0 + appendBlob(ocispec.MediaTypeImageLayer, []byte("layer_1")) // Blob 1 + generateManifest(nil, descs[0], descs[1]) // Blob 2 + appendBlob(ocispec.MediaTypeImageConfig, []byte("config_2")) // Blob 3 + appendBlob(ocispec.MediaTypeImageLayer, []byte("layer_2")) // Blob 4 + generateManifest(nil, descs[3], descs[4]) // Blob 5 + appendBlob(ocispec.MediaTypeImageLayer, []byte("{}")) // Blob 6 + appendBlob(ocispec.MediaTypeImageLayer, []byte("sbom_1")) // Blob 7 + generateManifest(&descs[2], descs[6], descs[7]) // Blob 8 + appendBlob(ocispec.MediaTypeImageLayer, []byte("sbom_2")) // Blob 9 + generateManifest(&descs[5], descs[6], descs[9]) // Blob 10 + generateIndex(nil, []ocispec.Descriptor{descs[2], descs[5]}...) // Blob 11 (root) + generateIndex(&descs[11], []ocispec.Descriptor{descs[8], descs[10]}...) // Blob 12 (root) + + ctx := context.Background() + verifyCopy := func(dst content.Fetcher, copiedIndice []int, uncopiedIndice []int) { + for _, i := range copiedIndice { + got, err := content.FetchAll(ctx, dst, descs[i]) + if err != nil { + t.Errorf("content[%d] error = %v, wantErr %v", i, err, false) + continue + } + if want := blobs[i]; !bytes.Equal(got, want) { + t.Errorf("content[%d] = %v, want %v", i, got, want) + } + } + for _, i := range uncopiedIndice { + if _, err := content.FetchAll(ctx, dst, descs[i]); !errors.Is(err, errdef.ErrNotFound) { + t.Errorf("content[%d] error = %v, wantErr %v", i, err, errdef.ErrNotFound) + } + } + } + + src := memory.New() + for i := range blobs { + err := src.Push(ctx, descs[i], bytes.NewReader(blobs[i])) + if err != nil { + t.Fatalf("failed to push test content to src: %d: %v", i, err) + } + } + + // test extended copy by descs[0] + dst := memory.New() + if err := oras.ExtendedCopyGraph(ctx, src, dst, descs[0], oras.ExtendedCopyGraphOptions{}); err != nil { + t.Fatalf("ExtendedCopyGraph() error = %v, wantErr %v", err, false) + } + // all blobs should be copied + copiedIndice := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} + uncopiedIndice := []int{} + verifyCopy(dst, copiedIndice, uncopiedIndice) + + // test extended copy by descs[2] + dst = memory.New() + if err := oras.ExtendedCopyGraph(ctx, src, dst, descs[2], oras.ExtendedCopyGraphOptions{}); err != nil { + t.Fatalf("ExtendedCopyGraph() error = %v, wantErr %v", err, false) + } + // all blobs should be copied + copiedIndice = []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} + uncopiedIndice = []int{} + verifyCopy(dst, copiedIndice, uncopiedIndice) + + // test extended copy by descs[8] + dst = memory.New() + if err := oras.ExtendedCopyGraph(ctx, src, dst, descs[8], oras.ExtendedCopyGraphOptions{}); err != nil { + t.Fatalf("ExtendedCopyGraph() error = %v, wantErr %v", err, false) + } + // all blobs should be copied + copiedIndice = []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} + uncopiedIndice = []int{} + verifyCopy(dst, copiedIndice, uncopiedIndice) + + // test extended copy by descs[11] + dst = memory.New() + if err := oras.ExtendedCopyGraph(ctx, src, dst, descs[11], oras.ExtendedCopyGraphOptions{}); err != nil { + t.Fatalf("ExtendedCopyGraph() error = %v, wantErr %v", err, false) + } + // all blobs should be copied + copiedIndice = []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} + uncopiedIndice = []int{} + verifyCopy(dst, copiedIndice, uncopiedIndice) +} + func TestExtendedCopyGraph_WithDepthOption(t *testing.T) { // generate test content var blobs [][]byte diff --git a/go.mod b/go.mod index 062b4bed..14a85315 100644 --- a/go.mod +++ b/go.mod @@ -1,9 +1,9 @@ module oras.land/oras-go/v2 -go 1.19 +go 1.20 require ( github.com/opencontainers/go-digest v1.0.0 - github.com/opencontainers/image-spec v1.1.0-rc.3 - golang.org/x/sync v0.2.0 + github.com/opencontainers/image-spec v1.1.0-rc5 + golang.org/x/sync v0.5.0 ) diff --git a/go.sum b/go.sum index 24fad321..65a6f863 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,6 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= -github.com/opencontainers/image-spec v1.1.0-rc.3 h1:GT9Xon8YrLxz6N7sErbN81V8J4lOQKGUZQmI3ioviqU= -github.com/opencontainers/image-spec v1.1.0-rc.3/go.mod h1:X4pATf0uXsnn3g5aiGIsVnJBR4mxhKzfwmvK/B2NTm8= -golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI= -golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +github.com/opencontainers/image-spec v1.1.0-rc5 h1:Ygwkfw9bpDvs+c9E34SdgGOj41dX/cbdlwvlWt0pnFI= +github.com/opencontainers/image-spec v1.1.0-rc5/go.mod h1:X4pATf0uXsnn3g5aiGIsVnJBR4mxhKzfwmvK/B2NTm8= +golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= +golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= diff --git a/internal/container/set/set.go b/internal/container/set/set.go index a084e288..07c96d47 100644 --- a/internal/container/set/set.go +++ b/internal/container/set/set.go @@ -33,3 +33,8 @@ func (s Set[T]) Contains(item T) bool { _, ok := s[item] return ok } + +// Delete deletes an item from the set. +func (s Set[T]) Delete(item T) { + delete(s, item) +} diff --git a/internal/container/set/set_test.go b/internal/container/set/set_test.go index 94f87c7a..12dc6cea 100644 --- a/internal/container/set/set_test.go +++ b/internal/container/set/set_test.go @@ -52,4 +52,12 @@ func TestSet(t *testing.T) { if got, want := len(set), 2; got != want { t.Errorf("len(Set) = %v, want %v", got, want) } + // test deleting a key + set.Delete(key1) + if got, want := set.Contains(key1), false; got != want { + t.Errorf("Set.Contains(%s) = %v, want %v", key1, got, want) + } + if got, want := len(set), 1; got != want { + t.Errorf("len(Set) = %v, want %v", got, want) + } } diff --git a/internal/graph/memory.go b/internal/graph/memory.go index 0aa25aee..bbb57556 100644 --- a/internal/graph/memory.go +++ b/internal/graph/memory.go @@ -23,6 +23,7 @@ import ( ocispec "github.com/opencontainers/image-spec/specs-go/v1" "oras.land/oras-go/v2/content" "oras.land/oras-go/v2/errdef" + "oras.land/oras-go/v2/internal/container/set" "oras.land/oras-go/v2/internal/descriptor" "oras.land/oras-go/v2/internal/status" "oras.land/oras-go/v2/internal/syncutil" @@ -30,35 +31,31 @@ import ( // Memory is a memory based PredecessorFinder. type Memory struct { - predecessors sync.Map // map[descriptor.Descriptor]map[descriptor.Descriptor]ocispec.Descriptor - indexed sync.Map // map[descriptor.Descriptor]any + nodes map[descriptor.Descriptor]ocispec.Descriptor // nodes saves the map keys of ocispec.Descriptor + predecessors map[descriptor.Descriptor]set.Set[descriptor.Descriptor] + successors map[descriptor.Descriptor]set.Set[descriptor.Descriptor] + lock sync.RWMutex } // NewMemory creates a new memory PredecessorFinder. func NewMemory() *Memory { - return &Memory{} + return &Memory{ + nodes: make(map[descriptor.Descriptor]ocispec.Descriptor), + predecessors: make(map[descriptor.Descriptor]set.Set[descriptor.Descriptor]), + successors: make(map[descriptor.Descriptor]set.Set[descriptor.Descriptor]), + } } // Index indexes predecessors for each direct successor of the given node. -// There is no data consistency issue as long as deletion is not implemented -// for the underlying storage. func (m *Memory) Index(ctx context.Context, fetcher content.Fetcher, node ocispec.Descriptor) error { - successors, err := content.Successors(ctx, fetcher, node) - if err != nil { - return err - } - - m.index(ctx, node, successors) - return nil + _, err := m.index(ctx, fetcher, node) + return err } // Index indexes predecessors for all the successors of the given node. -// There is no data consistency issue as long as deletion is not implemented -// for the underlying storage. func (m *Memory) IndexAll(ctx context.Context, fetcher content.Fetcher, node ocispec.Descriptor) error { // track content status tracker := status.NewTracker() - var fn syncutil.GoFunc[ocispec.Descriptor] fn = func(ctx context.Context, region *syncutil.LimitedRegion, desc ocispec.Descriptor) error { // skip the node if other go routine is working on it @@ -66,15 +63,7 @@ func (m *Memory) IndexAll(ctx context.Context, fetcher content.Fetcher, node oci if !committed { return nil } - - // skip the node if it has been indexed - key := descriptor.FromOCI(desc) - _, exists := m.indexed.Load(key) - if exists { - return nil - } - - successors, err := content.Successors(ctx, fetcher, desc) + successors, err := m.index(ctx, fetcher, desc) if err != nil { if errors.Is(err, errdef.ErrNotFound) { // skip the node if it does not exist @@ -82,9 +71,6 @@ func (m *Memory) IndexAll(ctx context.Context, fetcher content.Fetcher, node oci } return err } - m.index(ctx, desc, successors) - m.indexed.Store(key, nil) - if len(successors) > 0 { // traverse and index successors return syncutil.Go(ctx, nil, fn, successors...) @@ -96,39 +82,73 @@ func (m *Memory) IndexAll(ctx context.Context, fetcher content.Fetcher, node oci // Predecessors returns the nodes directly pointing to the current node. // Predecessors returns nil without error if the node does not exists in the -// store. -// Like other operations, calling Predecessors() is go-routine safe. However, -// it does not necessarily correspond to any consistent snapshot of the stored -// contents. +// store. Like other operations, calling Predecessors() is go-routine safe. +// However, it does not necessarily correspond to any consistent snapshot of +// the stored contents. func (m *Memory) Predecessors(_ context.Context, node ocispec.Descriptor) ([]ocispec.Descriptor, error) { + m.lock.RLock() + defer m.lock.RUnlock() + key := descriptor.FromOCI(node) - value, exists := m.predecessors.Load(key) + set, exists := m.predecessors[key] if !exists { return nil, nil } - predecessors := value.(*sync.Map) - var res []ocispec.Descriptor - predecessors.Range(func(key, value interface{}) bool { - res = append(res, value.(ocispec.Descriptor)) - return true - }) + for k := range set { + res = append(res, m.nodes[k]) + } return res, nil } +// Remove removes the node from its predecessors and successors. +func (m *Memory) Remove(ctx context.Context, node ocispec.Descriptor) error { + m.lock.Lock() + defer m.lock.Unlock() + + nodeKey := descriptor.FromOCI(node) + // remove the node from its successors' predecessor list + for successorKey := range m.successors[nodeKey] { + predecessorEntry := m.predecessors[successorKey] + predecessorEntry.Delete(nodeKey) + + // if none of the predecessors of the node still exists, we remove the + // predecessors entry. Otherwise, we do not remove the entry. + if len(predecessorEntry) == 0 { + delete(m.predecessors, successorKey) + } + } + delete(m.successors, nodeKey) + delete(m.nodes, nodeKey) + return nil +} + // index indexes predecessors for each direct successor of the given node. -// There is no data consistency issue as long as deletion is not implemented -// for the underlying storage. -func (m *Memory) index(ctx context.Context, node ocispec.Descriptor, successors []ocispec.Descriptor) { - if len(successors) == 0 { - return +func (m *Memory) index(ctx context.Context, fetcher content.Fetcher, node ocispec.Descriptor) ([]ocispec.Descriptor, error) { + successors, err := content.Successors(ctx, fetcher, node) + if err != nil { + return nil, err } + m.lock.Lock() + defer m.lock.Unlock() + + // index the node + nodeKey := descriptor.FromOCI(node) + m.nodes[nodeKey] = node - predecessorKey := descriptor.FromOCI(node) + // for each successor, put it into the node's successors list, and + // put node into the succeesor's predecessors list + successorSet := set.New[descriptor.Descriptor]() + m.successors[nodeKey] = successorSet for _, successor := range successors { successorKey := descriptor.FromOCI(successor) - value, _ := m.predecessors.LoadOrStore(successorKey, &sync.Map{}) - predecessors := value.(*sync.Map) - predecessors.Store(predecessorKey, node) + successorSet.Add(successorKey) + predecessorSet, exists := m.predecessors[successorKey] + if !exists { + predecessorSet = set.New[descriptor.Descriptor]() + m.predecessors[successorKey] = predecessorSet + } + predecessorSet.Add(nodeKey) } + return successors, nil } diff --git a/internal/graph/memory_test.go b/internal/graph/memory_test.go new file mode 100644 index 00000000..f9f9e89d --- /dev/null +++ b/internal/graph/memory_test.go @@ -0,0 +1,612 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package graph + +import ( + "bytes" + "context" + "encoding/json" + "io" + "reflect" + "testing" + + "github.com/opencontainers/go-digest" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + "oras.land/oras-go/v2/internal/cas" + "oras.land/oras-go/v2/internal/descriptor" +) + +// +------------------------------+ +// | | +// | +-----------+ | +// | |A(manifest)| | +// | +-----+-----+ | +// | | | +// | +------------+ | +// | | | | +// | v v | +// | +-----+-----+ +---+----+ | +// | |B(manifest)| |C(layer)| | +// | +-----+-----+ +--------+ | +// | | | +// | v | +// | +---+----+ | +// | |D(layer)| | +// | +--------+ | +// | | +// |------------------------------+ +func TestMemory_IndexAndRemove(t *testing.T) { + testFetcher := cas.NewMemory() + testMemory := NewMemory() + ctx := context.Background() + + // generate test content + var blobs [][]byte + var descs []ocispec.Descriptor + appendBlob := func(mediaType string, blob []byte) ocispec.Descriptor { + blobs = append(blobs, blob) + descs = append(descs, ocispec.Descriptor{ + MediaType: mediaType, + Digest: digest.FromBytes(blob), + Size: int64(len(blob)), + }) + return descs[len(descs)-1] + } + generateManifest := func(layers ...ocispec.Descriptor) ocispec.Descriptor { + manifest := ocispec.Manifest{ + Config: ocispec.Descriptor{MediaType: "test config"}, + Layers: layers, + } + manifestJSON, err := json.Marshal(manifest) + if err != nil { + t.Fatal(err) + } + return appendBlob(ocispec.MediaTypeImageManifest, manifestJSON) + } + descC := appendBlob("layer node C", []byte("Node C is a layer")) // blobs[0], layer "C" + descD := appendBlob("layer node D", []byte("Node D is a layer")) // blobs[1], layer "D" + descB := generateManifest(descs[0:2]...) // blobs[2], manifest "B" + descA := generateManifest(descs[1:3]...) // blobs[3], manifest "A" + + // prepare the content in the fetcher, so that it can be used to test Index + testContents := []ocispec.Descriptor{descC, descD, descB, descA} + for i := 0; i < len(blobs); i++ { + testFetcher.Push(ctx, testContents[i], bytes.NewReader(blobs[i])) + } + + // make sure that testFetcher works + rc, err := testFetcher.Fetch(ctx, descA) + if err != nil { + t.Errorf("testFetcher.Fetch() error = %v", err) + } + got, err := io.ReadAll(rc) + if err != nil { + t.Errorf("testFetcher.Fetch().Read() error = %v", err) + } + err = rc.Close() + if err != nil { + t.Errorf("testFetcher.Fetch().Close() error = %v", err) + } + if !bytes.Equal(got, blobs[3]) { + t.Errorf("testFetcher.Fetch() = %v, want %v", got, blobs[4]) + } + + nodeKeyA := descriptor.FromOCI(descA) + nodeKeyB := descriptor.FromOCI(descB) + nodeKeyC := descriptor.FromOCI(descC) + nodeKeyD := descriptor.FromOCI(descD) + + // index and check the information of node D + testMemory.Index(ctx, testFetcher, descD) + // 1. verify its existence in testMemory.nodes + if _, exists := testMemory.nodes[nodeKeyD]; !exists { + t.Errorf("nodes entry of %s should exist", "D") + } + // 2. verify that the entry of D exists in testMemory.successors and it's empty + successorsD, exists := testMemory.successors[nodeKeyD] + if !exists { + t.Errorf("successor entry of %s should exist", "D") + } + if successorsD == nil { + t.Errorf("successors of %s should be an empty set, not nil", "D") + } + if len(successorsD) != 0 { + t.Errorf("successors of %s should be empty", "D") + } + // 3. there should be no entry of D in testMemory.predecessors yet + _, exists = testMemory.predecessors[nodeKeyD] + if exists { + t.Errorf("predecessor entry of %s should not exist yet", "D") + } + + // index and check the information of node C + testMemory.Index(ctx, testFetcher, descC) + // 1. verify its existence in memory.nodes + if _, exists := testMemory.nodes[nodeKeyC]; !exists { + t.Errorf("nodes entry of %s should exist", "C") + } + // 2. verify that the entry of C exists in testMemory.successors and it's empty + successorsC, exists := testMemory.successors[nodeKeyC] + if !exists { + t.Errorf("successor entry of %s should exist", "C") + } + if successorsC == nil { + t.Errorf("successors of %s should be an empty set, not nil", "C") + } + if len(successorsC) != 0 { + t.Errorf("successors of %s should be empty", "C") + } + // 3. there should be no entry of C in testMemory.predecessors yet + _, exists = testMemory.predecessors[nodeKeyC] + if exists { + t.Errorf("predecessor entry of %s should not exist yet", "C") + } + + // index and check the information of node A + testMemory.Index(ctx, testFetcher, descA) + // 1. verify its existence in testMemory.nodes + if _, exists := testMemory.nodes[nodeKeyA]; !exists { + t.Errorf("nodes entry of %s should exist", "A") + } + // 2. verify that the entry of A exists in testMemory.successors and it contains + // node B and node D + successorsA, exists := testMemory.successors[nodeKeyA] + if !exists { + t.Errorf("successor entry of %s should exist", "A") + } + if successorsA == nil { + t.Errorf("successors of %s should be a set, not nil", "A") + } + if !successorsA.Contains(nodeKeyB) { + t.Errorf("successors of %s should contain %s", "A", "B") + } + if !successorsA.Contains(nodeKeyD) { + t.Errorf("successors of %s should contain %s", "A", "D") + } + // 3. verify that node A exists in the predecessors lists of its successors. + // there should be an entry of D in testMemory.predecessors by now and it + // should contain A but not B + predecessorsD, exists := testMemory.predecessors[nodeKeyD] + if !exists { + t.Errorf("predecessor entry of %s should exist by now", "D") + } + if !predecessorsD.Contains(nodeKeyA) { + t.Errorf("predecessors of %s should contain %s", "D", "A") + } + if predecessorsD.Contains(nodeKeyB) { + t.Errorf("predecessors of %s should not contain %s yet", "D", "B") + } + // there should be an entry of B in testMemory.predecessors now + // and it should contain A + predecessorsB, exists := testMemory.predecessors[nodeKeyB] + if !exists { + t.Errorf("predecessor entry of %s should exist by now", "B") + } + if !predecessorsB.Contains(nodeKeyA) { + t.Errorf("predecessors of %s should contain %s", "B", "A") + } + // 4. there should be no entry of A in testMemory.predecessors + _, exists = testMemory.predecessors[nodeKeyA] + if exists { + t.Errorf("predecessor entry of %s should not exist", "A") + } + + // index and check the information of node B + testMemory.Index(ctx, testFetcher, descB) + // 1. verify its existence in testMemory.nodes + if _, exists := testMemory.nodes[nodeKeyB]; !exists { + t.Errorf("nodes entry of %s should exist", "B") + } + // 2. verify that the entry of B exists in testMemory.successors and it contains + // node C and node D + successorsB, exists := testMemory.successors[nodeKeyB] + if !exists { + t.Errorf("successor entry of %s should exist", "B") + } + if successorsB == nil { + t.Errorf("successors of %s should be a set, not nil", "B") + } + if !successorsB.Contains(nodeKeyC) { + t.Errorf("successors of %s should contain %s", "B", "C") + } + if !successorsB.Contains(nodeKeyD) { + t.Errorf("successors of %s should contain %s", "B", "D") + } + // 3. verify that node B exists in the predecessors lists of its successors. + // there should be an entry of C in testMemory.predecessors by now + // and it should contain B + predecessorsC, exists := testMemory.predecessors[nodeKeyC] + if !exists { + t.Errorf("predecessor entry of %s should exist by now", "C") + } + if !predecessorsC.Contains(nodeKeyB) { + t.Errorf("predecessors of %s should contain %s", "C", "B") + } + // predecessors of D should have been updated now to have node A and B + if !predecessorsD.Contains(nodeKeyB) { + t.Errorf("predecessors of %s should contain %s", "D", "B") + } + if !predecessorsD.Contains(nodeKeyA) { + t.Errorf("predecessors of %s should contain %s", "D", "A") + } + + // remove node B and check the stored information + testMemory.Remove(ctx, descB) + // 1. verify that node B no longer exists in testMemory.nodes + if _, exists := testMemory.nodes[nodeKeyB]; exists { + t.Errorf("nodes entry of %s should no longer exist", "B") + } + // 2. verify B' predecessors info: B's entry in testMemory.predecessors should + // still exist, since its predecessor A still exists + predecessorsB, exists = testMemory.predecessors[nodeKeyB] + if !exists { + t.Errorf("testDeletableMemory.predecessors should still contain the entry of %v", "B") + } + if !predecessorsB.Contains(nodeKeyA) { + t.Errorf("predecessors of %s should still contain %s", "B", "A") + } + // 3. verify B' successors info: B's entry in testMemory.successors should no + // longer exist + if _, exists := testMemory.successors[nodeKeyB]; exists { + t.Errorf("testDeletableMemory.successors should not contain the entry of %v", "B") + } + // 4. verify B' predecessors' successors info: B should still exist in A's + // successors + if !successorsA.Contains(nodeKeyB) { + t.Errorf("successors of %s should still contain %s", "A", "B") + } + // 5. verify B' successors' predecessors info: C's entry in testMemory.predecessors + // should no longer exist, since C's only predecessor B is already deleted + if _, exists = testMemory.predecessors[nodeKeyC]; exists { + t.Errorf("predecessor entry of %s should no longer exist by now, since all its predecessors have been deleted", "C") + } + // B should no longer exist in D's predecessors + if predecessorsD.Contains(nodeKeyB) { + t.Errorf("predecessors of %s should not contain %s", "D", "B") + } + // but A still exists in D's predecessors + if !predecessorsD.Contains(nodeKeyA) { + t.Errorf("predecessors of %s should still contain %s", "D", "A") + } + + // remove node A and check the stored information + testMemory.Remove(ctx, descA) + // 1. verify that node A no longer exists in testMemory.nodes + if _, exists := testMemory.nodes[nodeKeyA]; exists { + t.Errorf("nodes entry of %s should no longer exist", "A") + } + // 2. verify A' successors info: A's entry in testMemory.successors should no + // longer exist + if _, exists := testMemory.successors[nodeKeyA]; exists { + t.Errorf("testDeletableMemory.successors should not contain the entry of %v", "A") + } + // 3. verify A' successors' predecessors info: D's entry in testMemory.predecessors + // should no longer exist, since all predecessors of D are already deleted + if _, exists = testMemory.predecessors[nodeKeyD]; exists { + t.Errorf("predecessor entry of %s should no longer exist by now, since all its predecessors have been deleted", "D") + } + // B's entry in testMemory.predecessors should no longer exist, since B's only + // predecessor A is already deleted + if _, exists = testMemory.predecessors[nodeKeyB]; exists { + t.Errorf("predecessor entry of %s should no longer exist by now, since all its predecessors have been deleted", "B") + } +} + +// +-----------------------------------------------+ +// | | +// | +--------+ | +// | |A(index)| | +// | +---+----+ | +// | | | +// | -+--------------+--------------+- | +// | | | | | +// | +-----v-----+ +-----v-----+ +-----v-----+ | +// | |B(manifest)| |C(manifest)| |D(manifest)| | +// | +--------+--+ ++---------++ +--+--------+ | +// | | | | | | +// | | | | | | +// | v v v v | +// | ++------++ ++------++ | +// | |E(layer)| |F(layer)| | +// | +--------+ +--------+ | +// | | +// +-----------------------------------------------+ +func TestMemory_IndexAllAndPredecessors(t *testing.T) { + testFetcher := cas.NewMemory() + testMemory := NewMemory() + ctx := context.Background() + + // generate test content + var blobs [][]byte + var descriptors []ocispec.Descriptor + appendBlob := func(mediaType string, blob []byte) ocispec.Descriptor { + blobs = append(blobs, blob) + descriptors = append(descriptors, ocispec.Descriptor{ + MediaType: mediaType, + Digest: digest.FromBytes(blob), + Size: int64(len(blob)), + }) + return descriptors[len(descriptors)-1] + } + generateManifest := func(layers ...ocispec.Descriptor) ocispec.Descriptor { + manifest := ocispec.Manifest{ + Config: ocispec.Descriptor{MediaType: "test config"}, + Layers: layers, + } + manifestJSON, err := json.Marshal(manifest) + if err != nil { + t.Fatal(err) + } + return appendBlob(ocispec.MediaTypeImageManifest, manifestJSON) + } + generateIndex := func(manifests ...ocispec.Descriptor) ocispec.Descriptor { + index := ocispec.Index{ + Manifests: manifests, + } + indexJSON, err := json.Marshal(index) + if err != nil { + t.Fatal(err) + } + return appendBlob(ocispec.MediaTypeImageIndex, indexJSON) + } + descE := appendBlob("layer node E", []byte("Node E is a layer")) // blobs[0], layer "E" + descF := appendBlob("layer node F", []byte("Node F is a layer")) // blobs[1], layer "F" + descB := generateManifest(descriptors[0:1]...) // blobs[2], manifest "B" + descC := generateManifest(descriptors[0:2]...) // blobs[3], manifest "C" + descD := generateManifest(descriptors[1:2]...) // blobs[4], manifest "D" + descA := generateIndex(descriptors[2:5]...) // blobs[5], index "A" + + // prepare the content in the fetcher, so that it can be used to test IndexAll + testContents := []ocispec.Descriptor{descE, descF, descB, descC, descD, descA} + for i := 0; i < len(blobs); i++ { + testFetcher.Push(ctx, testContents[i], bytes.NewReader(blobs[i])) + } + + // make sure that testFetcher works + rc, err := testFetcher.Fetch(ctx, descA) + if err != nil { + t.Errorf("testFetcher.Fetch() error = %v", err) + } + got, err := io.ReadAll(rc) + if err != nil { + t.Errorf("testFetcher.Fetch().Read() error = %v", err) + } + err = rc.Close() + if err != nil { + t.Errorf("testFetcher.Fetch().Close() error = %v", err) + } + if !bytes.Equal(got, blobs[5]) { + t.Errorf("testFetcher.Fetch() = %v, want %v", got, blobs[4]) + } + + nodeKeyA := descriptor.FromOCI(descA) + nodeKeyB := descriptor.FromOCI(descB) + nodeKeyC := descriptor.FromOCI(descC) + nodeKeyD := descriptor.FromOCI(descD) + nodeKeyE := descriptor.FromOCI(descE) + nodeKeyF := descriptor.FromOCI(descF) + + // index node A into testMemory using IndexAll + testMemory.IndexAll(ctx, testFetcher, descA) + + // check the information of node A + // 1. verify that node A exists in testMemory.nodes + if _, exists := testMemory.nodes[nodeKeyA]; !exists { + t.Errorf("nodes entry of %s should exist", "A") + } + // 2. verify that there is no entry of A in predecessors + if _, exists := testMemory.predecessors[nodeKeyA]; exists { + t.Errorf("there should be no entry of %s in predecessors", "A") + } + // 3. verify that A has successors B, C, D + successorsA, exists := testMemory.successors[nodeKeyA] + if !exists { + t.Errorf("there should be an entry of %s in successors", "A") + } + if !successorsA.Contains(nodeKeyB) { + t.Errorf("successors of %s should contain %s", "A", "B") + } + if !successorsA.Contains(nodeKeyC) { + t.Errorf("successors of %s should contain %s", "A", "C") + } + if !successorsA.Contains(nodeKeyD) { + t.Errorf("successors of %s should contain %s", "A", "D") + } + + // check the information of node B + // 1. verify that node B exists in testMemory.nodes + if _, exists := testMemory.nodes[nodeKeyB]; !exists { + t.Errorf("nodes entry of %s should exist", "B") + } + // 2. verify that B has node A in its predecessors + predecessorsB := testMemory.predecessors[nodeKeyB] + if !predecessorsB.Contains(nodeKeyA) { + t.Errorf("predecessors of %s should contain %s", "B", "A") + } + // 3. verify that B has node E in its successors + successorsB := testMemory.successors[nodeKeyB] + if !successorsB.Contains(nodeKeyE) { + t.Errorf("successors of %s should contain %s", "B", "E") + } + + // check the information of node C + // 1. verify that node C exists in testMemory.nodes + if _, exists := testMemory.nodes[nodeKeyC]; !exists { + t.Errorf("nodes entry of %s should exist", "C") + } + // 2. verify that C has node A in its predecessors + predecessorsC := testMemory.predecessors[nodeKeyC] + if !predecessorsC.Contains(nodeKeyA) { + t.Errorf("predecessors of %s should contain %s", "C", "A") + } + // 3. verify that C has node E and F in its successors + successorsC := testMemory.successors[nodeKeyC] + if !successorsC.Contains(nodeKeyE) { + t.Errorf("successors of %s should contain %s", "C", "E") + } + if !successorsC.Contains(nodeKeyF) { + t.Errorf("successors of %s should contain %s", "C", "F") + } + + // check the information of node D + // 1. verify that node D exists in testMemory.nodes + if _, exists := testMemory.nodes[nodeKeyD]; !exists { + t.Errorf("nodes entry of %s should exist", "D") + } + // 2. verify that D has node A in its predecessors + predecessorsD := testMemory.predecessors[nodeKeyD] + if !predecessorsD.Contains(nodeKeyA) { + t.Errorf("predecessors of %s should contain %s", "D", "A") + } + // 3. verify that D has node F in its successors + successorsD := testMemory.successors[nodeKeyD] + if !successorsD.Contains(nodeKeyF) { + t.Errorf("successors of %s should contain %s", "D", "F") + } + + // check the information of node E + // 1. verify that node E exists in testMemory.nodes + if _, exists := testMemory.nodes[nodeKeyE]; !exists { + t.Errorf("nodes entry of %s should exist", "E") + } + // 2. verify that E has node B and C in its predecessors + predecessorsE := testMemory.predecessors[nodeKeyE] + if !predecessorsE.Contains(nodeKeyB) { + t.Errorf("predecessors of %s should contain %s", "E", "B") + } + if !predecessorsE.Contains(nodeKeyC) { + t.Errorf("predecessors of %s should contain %s", "E", "C") + } + // 3. verify that E has an entry in successors and it's empty + successorsE, exists := testMemory.successors[nodeKeyE] + if !exists { + t.Errorf("entry %s should exist in testMemory.successors", "E") + } + if successorsE == nil { + t.Errorf("successors of %s should be an empty set, not nil", "E") + } + if len(successorsE) != 0 { + t.Errorf("successors of %s should be empty", "E") + } + + // check the information of node F + // 1. verify that node F exists in testMemory.nodes + if _, exists := testMemory.nodes[nodeKeyF]; !exists { + t.Errorf("nodes entry of %s should exist", "F") + } + // 2. verify that F has node C and D in its predecessors + predecessorsF := testMemory.predecessors[nodeKeyF] + if !predecessorsF.Contains(nodeKeyC) { + t.Errorf("predecessors of %s should contain %s", "F", "C") + } + if !predecessorsF.Contains(nodeKeyD) { + t.Errorf("predecessors of %s should contain %s", "F", "D") + } + // 3. verify that F has an entry in successors and it's empty + successorsF, exists := testMemory.successors[nodeKeyF] + if !exists { + t.Errorf("entry %s should exist in testMemory.successors", "F") + } + if successorsF == nil { + t.Errorf("successors of %s should be an empty set, not nil", "F") + } + if len(successorsF) != 0 { + t.Errorf("successors of %s should be empty", "F") + } + + // check that the Predecessors of node C is node A + predsC, err := testMemory.Predecessors(ctx, descC) + if err != nil { + t.Errorf("testFetcher.Predecessors() error = %v", err) + } + expectedLength := 1 + if len(predsC) != expectedLength { + t.Errorf("%s should have length %d", "predsC", expectedLength) + } + if !reflect.DeepEqual(predsC[0], descA) { + t.Errorf("incorrect predecessor result") + } + + // check that the Predecessors of node F are node C and node D + predsF, err := testMemory.Predecessors(ctx, descF) + if err != nil { + t.Errorf("testFetcher.Predecessors() error = %v", err) + } + expectedLength = 2 + if len(predsF) != expectedLength { + t.Errorf("%s should have length %d", "predsF", expectedLength) + } + for _, pred := range predsF { + if !reflect.DeepEqual(pred, descC) && !reflect.DeepEqual(pred, descD) { + t.Errorf("incorrect predecessor results") + } + } + + // remove node C and check the stored information + testMemory.Remove(ctx, descC) + if predecessorsE.Contains(nodeKeyC) { + t.Errorf("predecessors of %s should not contain %s", "E", "C") + } + if predecessorsF.Contains(nodeKeyC) { + t.Errorf("predecessors of %s should not contain %s", "F", "C") + } + if !successorsA.Contains(nodeKeyC) { + t.Errorf("successors of %s should still contain %s", "A", "C") + } + if _, exists := testMemory.successors[nodeKeyC]; exists { + t.Errorf("testMemory.successors should not contain the entry of %v", "C") + } + if _, exists := testMemory.predecessors[nodeKeyC]; !exists { + t.Errorf("entry %s in predecessors should still exists since it still has at least one predecessor node present", "C") + } + + // remove node A and check the stored information + testMemory.Remove(ctx, descA) + if _, exists := testMemory.predecessors[nodeKeyB]; exists { + t.Errorf("entry %s in predecessors should no longer exists", "B") + } + if _, exists := testMemory.predecessors[nodeKeyC]; exists { + t.Errorf("entry %s in predecessors should no longer exists", "C") + } + if _, exists := testMemory.predecessors[nodeKeyD]; exists { + t.Errorf("entry %s in predecessors should no longer exists", "D") + } + if _, exists := testMemory.successors[nodeKeyA]; exists { + t.Errorf("testDeletableMemory.successors should not contain the entry of %v", "A") + } + + // check that the Predecessors of node D is empty + predsD, err := testMemory.Predecessors(ctx, descD) + if err != nil { + t.Errorf("testFetcher.Predecessors() error = %v", err) + } + if predsD != nil { + t.Errorf("%s should be nil", "predsD") + } + + // check that the Predecessors of node E is node B + predsE, err := testMemory.Predecessors(ctx, descE) + if err != nil { + t.Errorf("testFetcher.Predecessors() error = %v", err) + } + expectedLength = 1 + if len(predsE) != expectedLength { + t.Errorf("%s should have length %d", "predsE", expectedLength) + } + if !reflect.DeepEqual(predsE[0], descB) { + t.Errorf("incorrect predecessor result") + } +} diff --git a/internal/manifestutil/parser.go b/internal/manifestutil/parser.go new file mode 100644 index 00000000..c904dc69 --- /dev/null +++ b/internal/manifestutil/parser.go @@ -0,0 +1,63 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package manifestutil + +import ( + "context" + "encoding/json" + + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + "oras.land/oras-go/v2/content" + "oras.land/oras-go/v2/internal/docker" +) + +// Config returns the config of desc, if present. +func Config(ctx context.Context, fetcher content.Fetcher, desc ocispec.Descriptor) (*ocispec.Descriptor, error) { + switch desc.MediaType { + case docker.MediaTypeManifest, ocispec.MediaTypeImageManifest: + content, err := content.FetchAll(ctx, fetcher, desc) + if err != nil { + return nil, err + } + // OCI manifest schema can be used to marshal docker manifest + var manifest ocispec.Manifest + if err := json.Unmarshal(content, &manifest); err != nil { + return nil, err + } + return &manifest.Config, nil + default: + return nil, nil + } +} + +// Manifest returns the manifests of desc, if present. +func Manifests(ctx context.Context, fetcher content.Fetcher, desc ocispec.Descriptor) ([]ocispec.Descriptor, error) { + switch desc.MediaType { + case docker.MediaTypeManifestList, ocispec.MediaTypeImageIndex: + content, err := content.FetchAll(ctx, fetcher, desc) + if err != nil { + return nil, err + } + // OCI manifest index schema can be used to marshal docker manifest list + var index ocispec.Index + if err := json.Unmarshal(content, &index); err != nil { + return nil, err + } + return index.Manifests, nil + default: + return nil, nil + } +} diff --git a/internal/manifestutil/parser_test.go b/internal/manifestutil/parser_test.go new file mode 100644 index 00000000..44c5e43e --- /dev/null +++ b/internal/manifestutil/parser_test.go @@ -0,0 +1,202 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package manifestutil + +import ( + "bytes" + "context" + "encoding/json" + "reflect" + "testing" + + "github.com/opencontainers/go-digest" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + "oras.land/oras-go/v2/internal/cas" + "oras.land/oras-go/v2/internal/docker" +) + +func TestConfig(t *testing.T) { + storage := cas.NewMemory() + + // generate test content + var blobs [][]byte + var descs []ocispec.Descriptor + appendBlob := func(mediaType string, blob []byte) { + blobs = append(blobs, blob) + descs = append(descs, ocispec.Descriptor{ + MediaType: mediaType, + Digest: digest.FromBytes(blob), + Size: int64(len(blob)), + }) + } + generateManifest := func(mediaType string, config ocispec.Descriptor, layers ...ocispec.Descriptor) { + manifest := ocispec.Manifest{ + Config: config, + Layers: layers, + } + manifestJSON, err := json.Marshal(manifest) + if err != nil { + t.Fatal(err) + } + appendBlob(mediaType, manifestJSON) + } + + appendBlob(ocispec.MediaTypeImageConfig, []byte("config")) // Blob 0 + appendBlob(ocispec.MediaTypeImageLayer, []byte("foo")) // Blob 1 + generateManifest(ocispec.MediaTypeImageManifest, descs[0], descs[1]) // Blob 2 + generateManifest(docker.MediaTypeManifest, descs[0], descs[1]) // Blob 3 + generateManifest("whatever", descs[0], descs[1]) // Blob 4 + + ctx := context.Background() + for i := range blobs { + err := storage.Push(ctx, descs[i], bytes.NewReader(blobs[i])) + if err != nil { + t.Fatalf("failed to push test content to src: %d: %v", i, err) + } + } + + tests := []struct { + name string + desc ocispec.Descriptor + want *ocispec.Descriptor + wantErr bool + }{ + { + name: "OCI Image Manifest", + desc: descs[2], + want: &descs[0], + }, + { + name: "Docker Manifest", + desc: descs[3], + want: &descs[0], + wantErr: false, + }, + { + name: "Other media type", + desc: descs[4], + want: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := Config(ctx, storage, tt.desc) + if (err != nil) != tt.wantErr { + t.Errorf("Config() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Config() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestManifests(t *testing.T) { + storage := cas.NewMemory() + + // generate test content + var blobs [][]byte + var descs []ocispec.Descriptor + appendBlob := func(mediaType string, blob []byte) { + blobs = append(blobs, blob) + descs = append(descs, ocispec.Descriptor{ + MediaType: mediaType, + Digest: digest.FromBytes(blob), + Size: int64(len(blob)), + }) + } + generateManifest := func(subject *ocispec.Descriptor, config ocispec.Descriptor, layers ...ocispec.Descriptor) { + manifest := ocispec.Manifest{ + Subject: subject, + Config: config, + Layers: layers, + } + manifestJSON, err := json.Marshal(manifest) + if err != nil { + t.Fatal(err) + } + appendBlob(ocispec.MediaTypeImageManifest, manifestJSON) + } + generateIndex := func(mediaType string, subject *ocispec.Descriptor, manifests ...ocispec.Descriptor) { + index := ocispec.Index{ + Subject: subject, + Manifests: manifests, + } + indexJSON, err := json.Marshal(index) + if err != nil { + t.Fatal(err) + } + appendBlob(mediaType, indexJSON) + } + + appendBlob(ocispec.MediaTypeImageConfig, []byte("config")) // Blob 0 + appendBlob(ocispec.MediaTypeImageLayer, []byte("foo")) // Blob 1 + appendBlob(ocispec.MediaTypeImageLayer, []byte("bar")) // Blob 2 + appendBlob(ocispec.MediaTypeImageLayer, []byte("hello")) // Blob 3 + generateManifest(nil, descs[0], descs[1:3]...) // Blob 4 + generateManifest(nil, descs[0], descs[3]) // Blob 5 + appendBlob(ocispec.MediaTypeImageConfig, []byte("{}")) // Blob 6 + appendBlob("test/sig", []byte("sig")) // Blob 7 + generateManifest(&descs[4], descs[5], descs[6]) // Blob 8 + generateIndex(ocispec.MediaTypeImageIndex, &descs[8], descs[4:6]...) // Blob 9 + generateIndex(docker.MediaTypeManifestList, nil, descs[4:6]...) // Blob 10 + generateIndex("whatever", &descs[8], descs[4:6]...) // Blob 11 + + ctx := context.Background() + for i := range blobs { + err := storage.Push(ctx, descs[i], bytes.NewReader(blobs[i])) + if err != nil { + t.Fatalf("failed to push test content to src: %d: %v", i, err) + } + } + + tests := []struct { + name string + desc ocispec.Descriptor + want []ocispec.Descriptor + wantErr bool + }{ + { + name: "OCI Image Index", + desc: descs[9], + want: descs[4:6], + }, + { + name: "Docker Manifest List", + desc: descs[10], + want: descs[4:6], + wantErr: false, + }, + { + name: "Other media type", + desc: descs[11], + want: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := Manifests(ctx, storage, tt.desc) + if (err != nil) != tt.wantErr { + t.Errorf("Manifests() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Manifests() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/platform/platform.go b/internal/platform/platform.go index 38d8d47f..e903fe3d 100644 --- a/internal/platform/platform.go +++ b/internal/platform/platform.go @@ -25,6 +25,7 @@ import ( "oras.land/oras-go/v2/content" "oras.land/oras-go/v2/errdef" "oras.land/oras-go/v2/internal/docker" + "oras.land/oras-go/v2/internal/manifestutil" ) // Match checks whether the current platform matches the target platform. @@ -35,7 +36,7 @@ import ( // array of the current platform. // // Note: Variant, OSVersion and OSFeatures are optional fields, will skip -// the comparison if the target platform does not provide specfic value. +// the comparison if the target platform does not provide specific value. func Match(got *ocispec.Platform, want *ocispec.Platform) bool { if got.Architecture != want.Architecture || got.OS != want.OS { return false @@ -77,7 +78,7 @@ func isSubset(a, b []string) bool { func SelectManifest(ctx context.Context, src content.ReadOnlyStorage, root ocispec.Descriptor, p *ocispec.Platform) (ocispec.Descriptor, error) { switch root.MediaType { case docker.MediaTypeManifestList, ocispec.MediaTypeImageIndex: - manifests, err := content.Successors(ctx, src, root) + manifests, err := manifestutil.Manifests(ctx, src, root) if err != nil { return ocispec.Descriptor{}, err } @@ -90,7 +91,8 @@ func SelectManifest(ctx context.Context, src content.ReadOnlyStorage, root ocisp } return ocispec.Descriptor{}, fmt.Errorf("%s: %w: no matching manifest was found in the manifest list", root.Digest, errdef.ErrNotFound) case docker.MediaTypeManifest, ocispec.MediaTypeImageManifest: - descs, err := content.Successors(ctx, src, root) + // config will be non-nil for docker manifest and OCI image manifest + config, err := manifestutil.Config(ctx, src, root) if err != nil { return ocispec.Descriptor{}, err } @@ -99,8 +101,7 @@ func SelectManifest(ctx context.Context, src content.ReadOnlyStorage, root ocisp if root.MediaType == ocispec.MediaTypeImageManifest { configMediaType = ocispec.MediaTypeImageConfig } - - cfgPlatform, err := getPlatformFromConfig(ctx, src, descs[0], configMediaType) + cfgPlatform, err := getPlatformFromConfig(ctx, src, *config, configMediaType) if err != nil { return ocispec.Descriptor{}, err } diff --git a/internal/platform/platform_test.go b/internal/platform/platform_test.go index a19e0b37..621ce959 100644 --- a/internal/platform/platform_test.go +++ b/internal/platform/platform_test.go @@ -144,10 +144,11 @@ func TestSelectManifest(t *testing.T) { }, }) } - generateManifest := func(arc, os, variant string, config ocispec.Descriptor, layers ...ocispec.Descriptor) { + generateManifest := func(arc, os, variant string, subject *ocispec.Descriptor, config ocispec.Descriptor, layers ...ocispec.Descriptor) { manifest := ocispec.Manifest{ - Config: config, - Layers: layers, + Subject: subject, + Config: config, + Layers: layers, } manifestJSON, err := json.Marshal(manifest) if err != nil { @@ -155,8 +156,9 @@ func TestSelectManifest(t *testing.T) { } appendManifest(arc, os, variant, ocispec.MediaTypeImageManifest, manifestJSON) } - generateIndex := func(manifests ...ocispec.Descriptor) { + generateIndex := func(subject *ocispec.Descriptor, manifests ...ocispec.Descriptor) { index := ocispec.Index{ + Subject: subject, Manifests: manifests, } indexJSON, err := json.Marshal(index) @@ -166,20 +168,21 @@ func TestSelectManifest(t *testing.T) { appendBlob(ocispec.MediaTypeImageIndex, indexJSON) } + appendBlob("test/subject", []byte("dummy subject")) // Blob 0 appendBlob(ocispec.MediaTypeImageConfig, []byte(`{"mediaType":"application/vnd.oci.image.config.v1+json", "created":"2022-07-29T08:13:55Z", "author":"test author", "architecture":"test-arc-1", "os":"test-os-1", -"variant":"v1"}`)) // Blob 0 - appendBlob(ocispec.MediaTypeImageLayer, []byte("foo")) // Blob 1 - appendBlob(ocispec.MediaTypeImageLayer, []byte("bar")) // Blob 2 - generateManifest(arc_1, os_1, variant_1, descs[0], descs[1:3]...) // Blob 3 - appendBlob(ocispec.MediaTypeImageLayer, []byte("hello1")) // Blob 4 - generateManifest(arc_2, os_2, variant_1, descs[0], descs[4]) // Blob 5 - appendBlob(ocispec.MediaTypeImageLayer, []byte("hello2")) // Blob 6 - generateManifest(arc_1, os_1, variant_2, descs[0], descs[6]) // Blob 7 - generateIndex(descs[3], descs[5], descs[7]) // Blob 8 +"variant":"v1"}`)) // Blob 1 + appendBlob(ocispec.MediaTypeImageLayer, []byte("foo")) // Blob 2 + appendBlob(ocispec.MediaTypeImageLayer, []byte("bar")) // Blob 3 + generateManifest(arc_1, os_1, variant_1, &descs[0], descs[1], descs[2:4]...) // Blob 4 + appendBlob(ocispec.MediaTypeImageLayer, []byte("hello1")) // Blob 5 + generateManifest(arc_2, os_2, variant_1, nil, descs[1], descs[5]) // Blob 6 + appendBlob(ocispec.MediaTypeImageLayer, []byte("hello2")) // Blob 7 + generateManifest(arc_1, os_1, variant_2, nil, descs[1], descs[7]) // Blob 8 + generateIndex(&descs[0], descs[4], descs[6], descs[8]) // Blob 9 ctx := context.Background() for i := range blobs { @@ -190,12 +193,12 @@ func TestSelectManifest(t *testing.T) { } // test SelectManifest on image index, only one matching manifest found - root := descs[8] + root := descs[9] targetPlatform := ocispec.Platform{ Architecture: arc_2, OS: os_2, } - wantDesc := descs[5] + wantDesc := descs[6] gotDesc, err := SelectManifest(ctx, storage, root, &targetPlatform) if err != nil { t.Fatalf("SelectManifest() error = %v, wantErr %v", err, false) @@ -211,7 +214,7 @@ func TestSelectManifest(t *testing.T) { Architecture: arc_1, OS: os_1, } - wantDesc = descs[3] + wantDesc = descs[4] gotDesc, err = SelectManifest(ctx, storage, root, &targetPlatform) if err != nil { t.Fatalf("SelectManifest() error = %v, wantErr %v", err, false) @@ -221,12 +224,12 @@ func TestSelectManifest(t *testing.T) { } // test SelectManifest on manifest - root = descs[7] + root = descs[8] targetPlatform = ocispec.Platform{ Architecture: arc_1, OS: os_1, } - wantDesc = descs[7] + wantDesc = descs[8] gotDesc, err = SelectManifest(ctx, storage, root, &targetPlatform) if err != nil { t.Fatalf("SelectManifest() error = %v, wantErr %v", err, false) @@ -237,7 +240,7 @@ func TestSelectManifest(t *testing.T) { // test SelectManifest on manifest, but there is no matching node. // Should return not found error. - root = descs[7] + root = descs[8] targetPlatform = ocispec.Platform{ Architecture: arc_1, OS: os_1, @@ -255,24 +258,26 @@ func TestSelectManifest(t *testing.T) { Architecture: arc_1, OS: os_1, } - root = descs[1] + root = descs[2] _, err = SelectManifest(ctx, storage, root, &targetPlatform) if !errors.Is(err, errdef.ErrUnsupported) { t.Fatalf("SelectManifest() error = %v, wantErr %v", err, errdef.ErrUnsupported) } // generate incorrect test content + storage = cas.NewMemory() blobs = nil descs = nil + appendBlob("test/subject", []byte("dummy subject")) // Blob 0 appendBlob(docker.MediaTypeConfig, []byte(`{"mediaType":"application/vnd.oci.image.config.v1+json", -"created":"2022-07-29T08:13:55Z", -"author":"test author 1", -"architecture":"test-arc-1", -"os":"test-os-1", -"variant":"v1"}`)) // Blob 0 - appendBlob(ocispec.MediaTypeImageLayer, []byte("foo1")) // Blob 1 - generateManifest(arc_1, os_1, variant_1, descs[0], descs[1]) // Blob 2 - generateIndex(descs[2]) // Blob 3 + "created":"2022-07-29T08:13:55Z", + "author":"test author 1", + "architecture":"test-arc-1", + "os":"test-os-1", + "variant":"v1"}`)) // Blob 1 + appendBlob(ocispec.MediaTypeImageLayer, []byte("foo1")) // Blob 2 + generateManifest(arc_1, os_1, variant_1, &descs[0], descs[1], descs[2]) // Blob 3 + generateIndex(&descs[0], descs[3]) // Blob 4 ctx = context.Background() for i := range blobs { @@ -285,7 +290,7 @@ func TestSelectManifest(t *testing.T) { // test SelectManifest on manifest, but the manifest is // invalid by having docker mediaType config in the manifest and oci // mediaType in the image config. Should return error. - root = descs[2] + root = descs[3] targetPlatform = ocispec.Platform{ Architecture: arc_1, OS: os_1, @@ -297,12 +302,14 @@ func TestSelectManifest(t *testing.T) { } // generate test content with null config blob + storage = cas.NewMemory() blobs = nil descs = nil - appendBlob(ocispec.MediaTypeImageConfig, []byte("null")) // Blob 0 - appendBlob(ocispec.MediaTypeImageLayer, []byte("foo2")) // Blob 1 - generateManifest(arc_1, os_1, variant_1, descs[0], descs[1]) // Blob 2 - generateIndex(descs[2]) // Blob 3 + appendBlob("test/subject", []byte("dummy subject")) // Blob 0 + appendBlob(ocispec.MediaTypeImageConfig, []byte("null")) // Blob 1 + appendBlob(ocispec.MediaTypeImageLayer, []byte("foo2")) // Blob 2 + generateManifest(arc_1, os_1, variant_1, &descs[0], descs[1], descs[2]) // Blob 3 + generateIndex(nil, descs[3]) // Blob 4 ctx = context.Background() for i := range blobs { @@ -314,7 +321,7 @@ func TestSelectManifest(t *testing.T) { // test SelectManifest on manifest with null config blob, // should return not found error. - root = descs[2] + root = descs[3] targetPlatform = ocispec.Platform{ Architecture: arc_1, OS: os_1, @@ -326,12 +333,14 @@ func TestSelectManifest(t *testing.T) { } // generate test content with empty config blob + storage = cas.NewMemory() blobs = nil descs = nil - appendBlob(ocispec.MediaTypeImageConfig, []byte("")) // Blob 0 - appendBlob(ocispec.MediaTypeImageLayer, []byte("foo3")) // Blob 1 - generateManifest(arc_1, os_1, variant_1, descs[0], descs[1]) // Blob 2 - generateIndex(descs[2]) // Blob 3 + appendBlob("test/subject", []byte("dummy subject")) // Blob 0 + appendBlob(ocispec.MediaTypeImageConfig, []byte("")) // Blob 1 + appendBlob(ocispec.MediaTypeImageLayer, []byte("foo3")) // Blob 2 + generateManifest(arc_1, os_1, variant_1, nil, descs[1], descs[2]) // Blob 3 + generateIndex(&descs[0], descs[3]) // Blob 4 ctx = context.Background() for i := range blobs { @@ -343,13 +352,16 @@ func TestSelectManifest(t *testing.T) { // test SelectManifest on manifest with empty config blob // should return not found error - root = descs[2] + root = descs[3] + targetPlatform = ocispec.Platform{ Architecture: arc_1, OS: os_1, } + _, err = SelectManifest(ctx, storage, root, &targetPlatform) expected = fmt.Sprintf("%s: %v: platform in manifest does not match target platform", root.Digest, errdef.ErrNotFound) + if err.Error() != expected { t.Fatalf("SelectManifest() error = %v, wantErr %v", err, expected) } diff --git a/internal/resolver/memory.go b/internal/resolver/memory.go index 6fac5e2d..2111d504 100644 --- a/internal/resolver/memory.go +++ b/internal/resolver/memory.go @@ -48,6 +48,11 @@ func (m *Memory) Tag(_ context.Context, desc ocispec.Descriptor, reference strin return nil } +// Untag removes a reference from index map. +func (m *Memory) Untag(reference string) { + m.index.Delete(reference) +} + // Map dumps the memory into a built-in map structure. // Like other operations, calling Map() is go-routine safe. However, it does not // necessarily correspond to any consistent snapshot of the storage contents. diff --git a/internal/resolver/memory_test.go b/internal/resolver/memory_test.go index eb9e1e56..f6f04b06 100644 --- a/internal/resolver/memory_test.go +++ b/internal/resolver/memory_test.go @@ -54,6 +54,15 @@ func TestMemorySuccess(t *testing.T) { if got := len(s.Map()); got != 1 { t.Errorf("Memory.Map() = %v, want %v", got, 1) } + + s.Untag(ref) + _, err = s.Resolve(ctx, ref) + if !errors.Is(err, errdef.ErrNotFound) { + t.Errorf("Memory.Resolve() error = %v, want %v", err, errdef.ErrNotFound) + } + if got := len(s.Map()); got != 0 { + t.Errorf("Memory.Map() = %v, want %v", got, 0) + } } func TestMemoryNotFound(t *testing.T) { diff --git a/internal/spec/artifact.go b/internal/spec/artifact.go index 8aa8e79e..7f801fd9 100644 --- a/internal/spec/artifact.go +++ b/internal/spec/artifact.go @@ -17,8 +17,16 @@ package spec import ocispec "github.com/opencontainers/image-spec/specs-go/v1" -// AnnotationReferrersFiltersApplied is the annotation key for the comma separated list of filters applied by the registry in the referrers listing. -const AnnotationReferrersFiltersApplied = "org.opencontainers.referrers.filtersApplied" +const ( + // AnnotationArtifactCreated is the annotation key for the date and time on which the artifact was built, conforming to RFC 3339. + AnnotationArtifactCreated = "org.opencontainers.artifact.created" + + // AnnotationArtifactDescription is the annotation key for the human readable description for the artifact. + AnnotationArtifactDescription = "org.opencontainers.artifact.description" + + // AnnotationReferrersFiltersApplied is the annotation key for the comma separated list of filters applied by the registry in the referrers listing. + AnnotationReferrersFiltersApplied = "org.opencontainers.referrers.filtersApplied" +) // MediaTypeArtifactManifest specifies the media type for a content descriptor. const MediaTypeArtifactManifest = "application/vnd.oci.artifact.manifest.v1+json" diff --git a/pack.go b/pack.go index 81fc12ca..08e14e19 100644 --- a/pack.go +++ b/pack.go @@ -21,6 +21,7 @@ import ( "encoding/json" "errors" "fmt" + "regexp" "time" specs "github.com/opencontainers/image-spec/specs-go" @@ -31,38 +32,127 @@ import ( ) const ( - // MediaTypeUnknownConfig is the default mediaType used when no - // config media type is specified. + // MediaTypeUnknownConfig is the default config mediaType used + // - for [Pack] when PackOptions.PackImageManifest is true and + // PackOptions.ConfigDescriptor is not specified. + // - for [PackManifest] when packManifestVersion is PackManifestVersion1_0 + // and PackManifestOptions.ConfigDescriptor is not specified. MediaTypeUnknownConfig = "application/vnd.unknown.config.v1+json" - // MediaTypeUnknownArtifact is the default artifactType used when no - // artifact type is specified. + + // MediaTypeUnknownArtifact is the default artifactType used for [Pack] + // when PackOptions.PackImageManifest is false and artifactType is + // not specified. MediaTypeUnknownArtifact = "application/vnd.unknown.artifact.v1" ) -// ErrInvalidDateTimeFormat is returned by Pack() when -// AnnotationArtifactCreated or AnnotationCreated is provided, but its value -// is not in RFC 3339 format. -// Reference: https://www.rfc-editor.org/rfc/rfc3339#section-5.6 -var ErrInvalidDateTimeFormat = errors.New("invalid date and time format") +var ( + // ErrInvalidDateTimeFormat is returned by [Pack] and [PackManifest] when + // AnnotationArtifactCreated or AnnotationCreated is provided, but its value + // is not in RFC 3339 format. + // Reference: https://www.rfc-editor.org/rfc/rfc3339#section-5.6 + ErrInvalidDateTimeFormat = errors.New("invalid date and time format") + + // ErrMissingArtifactType is returned by [PackManifest] when + // packManifestVersion is PackManifestVersion1_1_RC4 and artifactType is + // empty and the config media type is set to + // "application/vnd.oci.empty.v1+json". + ErrMissingArtifactType = errors.New("missing artifact type") +) + +// PackManifestVersion represents the manifest version used for [PackManifest]. +type PackManifestVersion int + +const ( + // PackManifestVersion1_0 represents the OCI Image Manifest defined in + // image-spec v1.0.2. + // Reference: https://github.com/opencontainers/image-spec/blob/v1.0.2/manifest.md + PackManifestVersion1_0 PackManifestVersion = 1 + + // PackManifestVersion1_1_RC4 represents the OCI Image Manifest defined + // in image-spec v1.1.0-rc4. + // Reference: https://github.com/opencontainers/image-spec/blob/v1.1.0-rc4/manifest.md + PackManifestVersion1_1_RC4 PackManifestVersion = 2 +) + +// PackManifestOptions contains optional parameters for [PackManifest]. +type PackManifestOptions struct { + // Subject is the subject of the manifest. + // This option is only valid when PackManifestVersion is + // NOT PackManifestVersion1_0. + Subject *ocispec.Descriptor + + // Layers is the layers of the manifest. + Layers []ocispec.Descriptor + + // ManifestAnnotations is the annotation map of the manifest. + ManifestAnnotations map[string]string + + // ConfigDescriptor is a pointer to the descriptor of the config blob. + // If not nil, ConfigAnnotations will be ignored. + ConfigDescriptor *ocispec.Descriptor + + // ConfigAnnotations is the annotation map of the config descriptor. + // This option is valid only when ConfigDescriptor is nil. + ConfigAnnotations map[string]string +} + +// mediaTypeRegexp checks the format of media types. +// References: +// - https://github.com/opencontainers/image-spec/blob/v1.1.0-rc4/schema/defs-descriptor.json#L7 +// - https://datatracker.ietf.org/doc/html/rfc6838#section-4.2 +var mediaTypeRegexp = regexp.MustCompile(`^[A-Za-z0-9][A-Za-z0-9!#$&-^_.+]{0,126}/[A-Za-z0-9][A-Za-z0-9!#$&-^_.+]{0,126}$`) -// PackOptions contains parameters for [oras.Pack]. +// PackManifest generates an OCI Image Manifest based on the given parameters +// and pushes the packed manifest to a content storage using pusher. The version +// of the manifest to be packed is determined by packManifestVersion +// (Recommended value: PackManifestVersion1_1_RC4). +// +// - If packManifestVersion is [PackManifestVersion1_1_RC4]: +// artifactType MUST NOT be empty unless opts.ConfigDescriptor is specified. +// - If packManifestVersion is [PackManifestVersion1_0]: +// if opts.ConfigDescriptor is nil, artifactType will be used as the +// config media type; if artifactType is empty, +// "application/vnd.unknown.config.v1+json" will be used. +// if opts.ConfigDescriptor is NOT nil, artifactType will be ignored. +// +// artifactType and opts.ConfigDescriptor.MediaType MUST comply with RFC 6838. +// +// If succeeded, returns a descriptor of the packed manifest. +func PackManifest(ctx context.Context, pusher content.Pusher, packManifestVersion PackManifestVersion, artifactType string, opts PackManifestOptions) (ocispec.Descriptor, error) { + switch packManifestVersion { + case PackManifestVersion1_0: + return packManifestV1_0(ctx, pusher, artifactType, opts) + case PackManifestVersion1_1_RC4: + return packManifestV1_1_RC4(ctx, pusher, artifactType, opts) + default: + return ocispec.Descriptor{}, fmt.Errorf("PackManifestVersion(%v): %w", packManifestVersion, errdef.ErrUnsupported) + } +} + +// PackOptions contains optional parameters for [Pack]. +// +// Deprecated: This type is deprecated and not recommended for future use. +// Use [PackManifestOptions] instead. type PackOptions struct { // Subject is the subject of the manifest. Subject *ocispec.Descriptor + // ManifestAnnotations is the annotation map of the manifest. ManifestAnnotations map[string]string - // PackImageManifest controls whether to pack an image manifest or not. - // - If true, pack an image manifest; artifactType will be used as the - // the config descriptor mediaType of the image manifest. - // - If false, pack an artifact manifest. - // Default: false. + // PackImageManifest controls whether to pack an OCI Image Manifest or not. + // - If true, pack an OCI Image Manifest. + // - If false, pack an OCI Artifact Manifest (deprecated). + // + // Default value: false. PackImageManifest bool + // ConfigDescriptor is a pointer to the descriptor of the config blob. // If not nil, artifactType will be implied by the mediaType of the // specified ConfigDescriptor, and ConfigAnnotations will be ignored. // This option is valid only when PackImageManifest is true. ConfigDescriptor *ocispec.Descriptor + // ConfigAnnotations is the annotation map of the config descriptor. // This option is valid only when PackImageManifest is true // and ConfigDescriptor is nil. @@ -74,23 +164,26 @@ type PackOptions struct { // // When opts.PackImageManifest is true, artifactType will be used as the // the config descriptor mediaType of the image manifest. +// // If succeeded, returns a descriptor of the manifest. +// +// Deprecated: This method is deprecated and not recommended for future use. +// Use [PackManifest] instead. func Pack(ctx context.Context, pusher content.Pusher, artifactType string, blobs []ocispec.Descriptor, opts PackOptions) (ocispec.Descriptor, error) { if opts.PackImageManifest { - return packImage(ctx, pusher, artifactType, blobs, opts) + return packManifestV1_1_RC2(ctx, pusher, artifactType, blobs, opts) } return packArtifact(ctx, pusher, artifactType, blobs, opts) } -// packArtifact packs the given blobs, generates an artifact manifest for the -// pack, and pushes it to a content storage. -// If succeeded, returns a descriptor of the manifest. +// packArtifact packs an Artifact manifest as defined in image-spec v1.1.0-rc2. +// Reference: https://github.com/opencontainers/image-spec/blob/v1.1.0-rc2/artifact.md func packArtifact(ctx context.Context, pusher content.Pusher, artifactType string, blobs []ocispec.Descriptor, opts PackOptions) (ocispec.Descriptor, error) { if artifactType == "" { artifactType = MediaTypeUnknownArtifact } - annotations, err := ensureAnnotationCreated(opts.ManifestAnnotations, ocispec.AnnotationArtifactCreated) + annotations, err := ensureAnnotationCreated(opts.ManifestAnnotations, spec.AnnotationArtifactCreated) if err != nil { return ocispec.Descriptor{}, err } @@ -101,46 +194,72 @@ func packArtifact(ctx context.Context, pusher content.Pusher, artifactType strin Subject: opts.Subject, Annotations: annotations, } - manifestJSON, err := json.Marshal(manifest) - if err != nil { - return ocispec.Descriptor{}, fmt.Errorf("failed to marshal manifest: %w", err) + return pushManifest(ctx, pusher, manifest, manifest.MediaType, manifest.ArtifactType, manifest.Annotations) +} + +// packManifestV1_0 packs an image manifest defined in image-spec v1.0.2. +// Reference: https://github.com/opencontainers/image-spec/blob/v1.0.2/manifest.md +func packManifestV1_0(ctx context.Context, pusher content.Pusher, artifactType string, opts PackManifestOptions) (ocispec.Descriptor, error) { + if opts.Subject != nil { + return ocispec.Descriptor{}, fmt.Errorf("subject is not supported for manifest version %v: %w", PackManifestVersion1_0, errdef.ErrUnsupported) } - manifestDesc := content.NewDescriptorFromBytes(spec.MediaTypeArtifactManifest, manifestJSON) - // populate ArtifactType and Annotations of the manifest into manifestDesc - manifestDesc.ArtifactType = manifest.ArtifactType - manifestDesc.Annotations = manifest.Annotations - // push manifest - if err := pusher.Push(ctx, manifestDesc, bytes.NewReader(manifestJSON)); err != nil && !errors.Is(err, errdef.ErrAlreadyExists) { - return ocispec.Descriptor{}, fmt.Errorf("failed to push manifest: %w", err) + // prepare config + var configDesc ocispec.Descriptor + if opts.ConfigDescriptor != nil { + if err := validateMediaType(opts.ConfigDescriptor.MediaType); err != nil { + return ocispec.Descriptor{}, fmt.Errorf("invalid config mediaType format: %w", err) + } + configDesc = *opts.ConfigDescriptor + } else { + if artifactType == "" { + artifactType = MediaTypeUnknownConfig + } else if err := validateMediaType(artifactType); err != nil { + return ocispec.Descriptor{}, fmt.Errorf("invalid artifactType format: %w", err) + } + var err error + configDesc, err = pushCustomEmptyConfig(ctx, pusher, artifactType, opts.ConfigAnnotations) + if err != nil { + return ocispec.Descriptor{}, err + } } - return manifestDesc, nil + annotations, err := ensureAnnotationCreated(opts.ManifestAnnotations, ocispec.AnnotationCreated) + if err != nil { + return ocispec.Descriptor{}, err + } + if opts.Layers == nil { + opts.Layers = []ocispec.Descriptor{} // make it an empty array to prevent potential server-side bugs + } + manifest := ocispec.Manifest{ + Versioned: specs.Versioned{ + SchemaVersion: 2, // historical value. does not pertain to OCI or docker version + }, + Config: configDesc, + MediaType: ocispec.MediaTypeImageManifest, + Layers: opts.Layers, + Annotations: annotations, + } + return pushManifest(ctx, pusher, manifest, manifest.MediaType, manifest.Config.MediaType, manifest.Annotations) } -// packImage packs the given blobs, generates an image manifest for the pack, -// and pushes it to a content storage. artifactType will be used as the config -// descriptor mediaType of the image manifest. -// If succeeded, returns a descriptor of the manifest. -func packImage(ctx context.Context, pusher content.Pusher, configMediaType string, layers []ocispec.Descriptor, opts PackOptions) (ocispec.Descriptor, error) { +// packManifestV1_1_RC2 packs an image manifest as defined in image-spec +// v1.1.0-rc2. +// Reference: https://github.com/opencontainers/image-spec/blob/v1.1.0-rc2/manifest.md +func packManifestV1_1_RC2(ctx context.Context, pusher content.Pusher, configMediaType string, layers []ocispec.Descriptor, opts PackOptions) (ocispec.Descriptor, error) { if configMediaType == "" { configMediaType = MediaTypeUnknownConfig } + // prepare config var configDesc ocispec.Descriptor if opts.ConfigDescriptor != nil { configDesc = *opts.ConfigDescriptor } else { - // Use an empty JSON object here, because some registries may not accept - // empty config blob. - // As of September 2022, GAR is known to return 400 on empty blob upload. - // See https://github.com/oras-project/oras-go/issues/294 for details. - configBytes := []byte("{}") - configDesc = content.NewDescriptorFromBytes(configMediaType, configBytes) - configDesc.Annotations = opts.ConfigAnnotations - // push config - if err := pusher.Push(ctx, configDesc, bytes.NewReader(configBytes)); err != nil && !errors.Is(err, errdef.ErrAlreadyExists) { - return ocispec.Descriptor{}, fmt.Errorf("failed to push config: %w", err) + var err error + configDesc, err = pushCustomEmptyConfig(ctx, pusher, configMediaType, opts.ConfigAnnotations) + if err != nil { + return ocispec.Descriptor{}, err } } @@ -161,23 +280,124 @@ func packImage(ctx context.Context, pusher content.Pusher, configMediaType strin Subject: opts.Subject, Annotations: annotations, } + return pushManifest(ctx, pusher, manifest, manifest.MediaType, manifest.Config.MediaType, manifest.Annotations) +} + +// packManifestV1_1_RC4 packs an image manifest defined in image-spec v1.1.0-rc4. +// Reference: https://github.com/opencontainers/image-spec/blob/v1.1.0-rc4/manifest.md#guidelines-for-artifact-usage +func packManifestV1_1_RC4(ctx context.Context, pusher content.Pusher, artifactType string, opts PackManifestOptions) (ocispec.Descriptor, error) { + if artifactType == "" && (opts.ConfigDescriptor == nil || opts.ConfigDescriptor.MediaType == ocispec.MediaTypeEmptyJSON) { + // artifactType MUST be set when config.mediaType is set to the empty value + return ocispec.Descriptor{}, ErrMissingArtifactType + } + if artifactType != "" { + if err := validateMediaType(artifactType); err != nil { + return ocispec.Descriptor{}, fmt.Errorf("invalid artifactType format: %w", err) + } + } + + // prepare config + var emptyBlobExists bool + var configDesc ocispec.Descriptor + if opts.ConfigDescriptor != nil { + if err := validateMediaType(opts.ConfigDescriptor.MediaType); err != nil { + return ocispec.Descriptor{}, fmt.Errorf("invalid config mediaType format: %w", err) + } + configDesc = *opts.ConfigDescriptor + } else { + // use the empty descriptor for config + configDesc = ocispec.DescriptorEmptyJSON + configDesc.Annotations = opts.ConfigAnnotations + configBytes := ocispec.DescriptorEmptyJSON.Data + // push config + if err := pushIfNotExist(ctx, pusher, configDesc, configBytes); err != nil { + return ocispec.Descriptor{}, fmt.Errorf("failed to push config: %w", err) + } + emptyBlobExists = true + } + + annotations, err := ensureAnnotationCreated(opts.ManifestAnnotations, ocispec.AnnotationCreated) + if err != nil { + return ocispec.Descriptor{}, err + } + if len(opts.Layers) == 0 { + // use the empty descriptor as the single layer + layerDesc := ocispec.DescriptorEmptyJSON + layerData := ocispec.DescriptorEmptyJSON.Data + if !emptyBlobExists { + if err := pushIfNotExist(ctx, pusher, layerDesc, layerData); err != nil { + return ocispec.Descriptor{}, fmt.Errorf("failed to push layer: %w", err) + } + } + opts.Layers = []ocispec.Descriptor{layerDesc} + } + + manifest := ocispec.Manifest{ + Versioned: specs.Versioned{ + SchemaVersion: 2, // historical value. does not pertain to OCI or docker version + }, + Config: configDesc, + MediaType: ocispec.MediaTypeImageManifest, + Layers: opts.Layers, + Subject: opts.Subject, + ArtifactType: artifactType, + Annotations: annotations, + } + return pushManifest(ctx, pusher, manifest, manifest.MediaType, manifest.ArtifactType, manifest.Annotations) +} + +// pushIfNotExist pushes data described by desc if it does not exist in the +// target. +func pushIfNotExist(ctx context.Context, pusher content.Pusher, desc ocispec.Descriptor, data []byte) error { + if ros, ok := pusher.(content.ReadOnlyStorage); ok { + exists, err := ros.Exists(ctx, desc) + if err != nil { + return fmt.Errorf("failed to check existence: %s: %s: %w", desc.Digest.String(), desc.MediaType, err) + } + if exists { + return nil + } + } + + if err := pusher.Push(ctx, desc, bytes.NewReader(data)); err != nil && !errors.Is(err, errdef.ErrAlreadyExists) { + return fmt.Errorf("failed to push: %s: %s: %w", desc.Digest.String(), desc.MediaType, err) + } + return nil +} + +// pushManifest marshals manifest into JSON bytes and pushes it. +func pushManifest(ctx context.Context, pusher content.Pusher, manifest any, mediaType string, artifactType string, annotations map[string]string) (ocispec.Descriptor, error) { manifestJSON, err := json.Marshal(manifest) if err != nil { return ocispec.Descriptor{}, fmt.Errorf("failed to marshal manifest: %w", err) } - manifestDesc := content.NewDescriptorFromBytes(ocispec.MediaTypeImageManifest, manifestJSON) + manifestDesc := content.NewDescriptorFromBytes(mediaType, manifestJSON) // populate ArtifactType and Annotations of the manifest into manifestDesc - manifestDesc.ArtifactType = manifest.Config.MediaType - manifestDesc.Annotations = manifest.Annotations - + manifestDesc.ArtifactType = artifactType + manifestDesc.Annotations = annotations // push manifest if err := pusher.Push(ctx, manifestDesc, bytes.NewReader(manifestJSON)); err != nil && !errors.Is(err, errdef.ErrAlreadyExists) { return ocispec.Descriptor{}, fmt.Errorf("failed to push manifest: %w", err) } - return manifestDesc, nil } +// pushCustomEmptyConfig generates and pushes an empty config blob. +func pushCustomEmptyConfig(ctx context.Context, pusher content.Pusher, mediaType string, annotations map[string]string) (ocispec.Descriptor, error) { + // Use an empty JSON object here, because some registries may not accept + // empty config blob. + // As of September 2022, GAR is known to return 400 on empty blob upload. + // See https://github.com/oras-project/oras-go/issues/294 for details. + configBytes := []byte("{}") + configDesc := content.NewDescriptorFromBytes(mediaType, configBytes) + configDesc.Annotations = annotations + // push config + if err := pushIfNotExist(ctx, pusher, configDesc, configBytes); err != nil { + return ocispec.Descriptor{}, fmt.Errorf("failed to push config: %w", err) + } + return configDesc, nil +} + // ensureAnnotationCreated ensures that annotationCreatedKey is in annotations, // and that its value conforms to RFC 3339. Otherwise returns a new annotation // map with annotationCreatedKey created. @@ -201,3 +421,11 @@ func ensureAnnotationCreated(annotations map[string]string, annotationCreatedKey copied[annotationCreatedKey] = now.Format(time.RFC3339) return copied, nil } + +// validateMediaType validates the format of mediaType. +func validateMediaType(mediaType string) error { + if !mediaTypeRegexp.MatchString(mediaType) { + return fmt.Errorf("%s: %w", mediaType, errdef.ErrInvalidMediaType) + } + return nil +} diff --git a/pack_test.go b/pack_test.go index 582ef128..0f8bc04b 100644 --- a/pack_test.go +++ b/pack_test.go @@ -30,10 +30,11 @@ import ( ocispec "github.com/opencontainers/image-spec/specs-go/v1" "oras.land/oras-go/v2/content" "oras.land/oras-go/v2/content/memory" + "oras.land/oras-go/v2/errdef" "oras.land/oras-go/v2/internal/spec" ) -func Test_Pack_Default(t *testing.T) { +func Test_Pack_Artifact_NoOption(t *testing.T) { s := memory.New() // prepare test content @@ -50,7 +51,7 @@ func Test_Pack_Default(t *testing.T) { t.Fatal("Oras.Pack() error =", err) } - // test blobs + // verify blobs var manifest spec.Artifact rc, err := s.Fetch(ctx, manifestDesc) if err != nil { @@ -66,28 +67,38 @@ func Test_Pack_Default(t *testing.T) { t.Errorf("Store.Fetch() = %v, want %v", manifest.Blobs, blobs) } - // test media type + // verify media type if got := manifest.MediaType; got != spec.MediaTypeArtifactManifest { t.Fatalf("got media type = %s, want %s", got, spec.MediaTypeArtifactManifest) } - // test artifact type + // verify artifact type if got := manifest.ArtifactType; got != artifactType { t.Fatalf("got artifact type = %s, want %s", got, artifactType) } - // test created time annotation - createdTime, ok := manifest.Annotations[ocispec.AnnotationArtifactCreated] + // verify created time annotation + createdTime, ok := manifest.Annotations[spec.AnnotationArtifactCreated] if !ok { - t.Errorf("Annotation %s = %v, want %v", ocispec.AnnotationArtifactCreated, ok, true) + t.Errorf("Annotation %s = %v, want %v", spec.AnnotationArtifactCreated, ok, true) } _, err = time.Parse(time.RFC3339, createdTime) if err != nil { t.Errorf("error parsing created time: %s, error = %v", createdTime, err) } + + // verify descriptor artifact type + if want := manifest.ArtifactType; !reflect.DeepEqual(manifestDesc.ArtifactType, want) { + t.Errorf("got descriptor artifactType = %v, want %v", manifestDesc.ArtifactType, want) + } + + // verify descriptor annotations + if want := manifest.Annotations; !reflect.DeepEqual(manifestDesc.Annotations, want) { + t.Errorf("got descriptor annotations = %v, want %v", manifestDesc.Annotations, want) + } } -func Test_Pack_WithOptions(t *testing.T) { +func Test_Pack_Artifact_WithOptions(t *testing.T) { s := memory.New() // prepare test content @@ -98,7 +109,8 @@ func Test_Pack_WithOptions(t *testing.T) { artifactType := "application/vnd.test" annotations := map[string]string{ - ocispec.AnnotationArtifactCreated: "2000-01-01T00:00:00Z", + spec.AnnotationArtifactCreated: "2000-01-01T00:00:00Z", + "foo": "bar", } subjectManifest := []byte(`{"layers":[]}`) subjectDesc := ocispec.Descriptor{ @@ -108,12 +120,17 @@ func Test_Pack_WithOptions(t *testing.T) { ArtifactType: artifactType, Annotations: annotations, } + configBytes := []byte("{}") + configDesc := content.NewDescriptorFromBytes("application/vnd.test.config", configBytes) + configAnnotations := map[string]string{"foo": "bar"} // test Pack ctx := context.Background() opts := PackOptions{ Subject: &subjectDesc, ManifestAnnotations: annotations, + ConfigDescriptor: &configDesc, // should not work + ConfigAnnotations: configAnnotations, // should not work } manifestDesc, err := Pack(ctx, s, artifactType, blobs, opts) if err != nil { @@ -132,7 +149,7 @@ func Test_Pack_WithOptions(t *testing.T) { t.Fatal("failed to marshal manifest:", err) } - // test manifest + // verify manifest rc, err := s.Fetch(ctx, manifestDesc) if err != nil { t.Fatal("Store.Fetch() error =", err) @@ -148,9 +165,17 @@ func Test_Pack_WithOptions(t *testing.T) { if !bytes.Equal(got, expectedManifestBytes) { t.Errorf("Store.Fetch() = %v, want %v", got, expectedManifestBytes) } + + // verify descriptor + expectedManifestDesc := content.NewDescriptorFromBytes(expectedManifest.MediaType, expectedManifestBytes) + expectedManifestDesc.ArtifactType = expectedManifest.ArtifactType + expectedManifestDesc.Annotations = expectedManifest.Annotations + if !reflect.DeepEqual(manifestDesc, expectedManifestDesc) { + t.Errorf("Pack() = %v, want %v", manifestDesc, expectedManifestDesc) + } } -func Test_Pack_NoBlob(t *testing.T) { +func Test_Pack_Artifact_NoBlob(t *testing.T) { s := memory.New() // test Pack @@ -173,14 +198,14 @@ func Test_Pack_NoBlob(t *testing.T) { t.Fatal("Store.Fetch().Close() error =", err) } - // test blobs + // verify blobs var expectedBlobs []ocispec.Descriptor if !reflect.DeepEqual(manifest.Blobs, expectedBlobs) { t.Errorf("Store.Fetch() = %v, want %v", manifest.Blobs, expectedBlobs) } } -func Test_Pack_NoArtifactType(t *testing.T) { +func Test_Pack_Artifact_NoArtifactType(t *testing.T) { s := memory.New() ctx := context.Background() @@ -201,7 +226,7 @@ func Test_Pack_NoArtifactType(t *testing.T) { t.Fatal("Store.Fetch().Close() error =", err) } - // test artifact type + // verify artifact type if manifestDesc.ArtifactType != MediaTypeUnknownArtifact { t.Fatalf("got artifact type = %s, want %s", manifestDesc.ArtifactType, MediaTypeUnknownArtifact) } @@ -210,23 +235,23 @@ func Test_Pack_NoArtifactType(t *testing.T) { } } -func Test_Pack_InvalidDateTimeFormat(t *testing.T) { +func Test_Pack_Artifact_InvalidDateTimeFormat(t *testing.T) { s := memory.New() ctx := context.Background() opts := PackOptions{ ManifestAnnotations: map[string]string{ - ocispec.AnnotationArtifactCreated: "2000/01/01 00:00:00", + spec.AnnotationArtifactCreated: "2000/01/01 00:00:00", }, } artifactType := "application/vnd.test" _, err := Pack(ctx, s, artifactType, nil, opts) - if err == nil || !errors.Is(err, ErrInvalidDateTimeFormat) { - t.Errorf("Oras.Pack() error = %v, wantErr = %v", err, ErrInvalidDateTimeFormat) + if wantErr := ErrInvalidDateTimeFormat; !errors.Is(err, wantErr) { + t.Errorf("Oras.Pack() error = %v, wantErr = %v", err, wantErr) } } -func Test_Pack_Image(t *testing.T) { +func Test_Pack_ImageV1_1_RC2(t *testing.T) { s := memory.New() // prepare test content @@ -237,7 +262,7 @@ func Test_Pack_Image(t *testing.T) { // test Pack ctx := context.Background() - artifactType := "testconfig" + artifactType := "application/vnd.test" manifestDesc, err := Pack(ctx, s, artifactType, layers, PackOptions{PackImageManifest: true}) if err != nil { t.Fatal("Oras.Pack() error =", err) @@ -255,13 +280,13 @@ func Test_Pack_Image(t *testing.T) { t.Fatal("Store.Fetch().Close() error =", err) } - // test media type + // verify media type got := manifest.MediaType if got != ocispec.MediaTypeImageManifest { t.Fatalf("got media type = %s, want %s", got, ocispec.MediaTypeImageManifest) } - // test config + // verify config expectedConfigBytes := []byte("{}") expectedConfig := ocispec.Descriptor{ MediaType: artifactType, @@ -272,12 +297,12 @@ func Test_Pack_Image(t *testing.T) { t.Errorf("got config = %v, want %v", manifest.Config, expectedConfig) } - // test layers + // verify layers if !reflect.DeepEqual(manifest.Layers, layers) { t.Errorf("got layers = %v, want %v", manifest.Layers, layers) } - // test created time annotation + // verify created time annotation createdTime, ok := manifest.Annotations[ocispec.AnnotationCreated] if !ok { t.Errorf("Annotation %s = %v, want %v", ocispec.AnnotationCreated, ok, true) @@ -286,9 +311,14 @@ func Test_Pack_Image(t *testing.T) { if err != nil { t.Errorf("error parsing created time: %s, error = %v", createdTime, err) } + + // verify descriptor annotations + if want := manifest.Annotations; !reflect.DeepEqual(manifestDesc.Annotations, want) { + t.Errorf("got descriptor annotations = %v, want %v", manifestDesc.Annotations, want) + } } -func Test_Pack_Image_WithOptions(t *testing.T) { +func Test_Pack_ImageV1_1_RC2_WithOptions(t *testing.T) { s := memory.New() // prepare test content @@ -297,10 +327,11 @@ func Test_Pack_Image_WithOptions(t *testing.T) { content.NewDescriptorFromBytes("test", []byte("goodbye world")), } configBytes := []byte("{}") - configDesc := content.NewDescriptorFromBytes("testconfig", configBytes) + configDesc := content.NewDescriptorFromBytes("application/vnd.test.config", configBytes) configAnnotations := map[string]string{"foo": "bar"} annotations := map[string]string{ ocispec.AnnotationCreated: "2000-01-01T00:00:00Z", + "foo": "bar", } artifactType := "application/vnd.test" subjectManifest := []byte(`{"layers":[]}`) @@ -355,6 +386,14 @@ func Test_Pack_Image_WithOptions(t *testing.T) { t.Errorf("Store.Fetch() = %v, want %v", string(got), string(expectedManifestBytes)) } + // verify descriptor + expectedManifestDesc := content.NewDescriptorFromBytes(expectedManifest.MediaType, expectedManifestBytes) + expectedManifestDesc.ArtifactType = expectedManifest.Config.MediaType + expectedManifestDesc.Annotations = expectedManifest.Annotations + if !reflect.DeepEqual(manifestDesc, expectedManifestDesc) { + t.Errorf("Pack() = %v, want %v", manifestDesc, expectedManifestDesc) + } + // test Pack without ConfigDescriptor opts = PackOptions{ PackImageManifest: true, @@ -399,9 +438,17 @@ func Test_Pack_Image_WithOptions(t *testing.T) { if !bytes.Equal(got, expectedManifestBytes) { t.Errorf("Store.Fetch() = %v, want %v", string(got), string(expectedManifestBytes)) } + + // verify descriptor + expectedManifestDesc = content.NewDescriptorFromBytes(expectedManifest.MediaType, expectedManifestBytes) + expectedManifestDesc.ArtifactType = expectedManifest.Config.MediaType + expectedManifestDesc.Annotations = expectedManifest.Annotations + if !reflect.DeepEqual(manifestDesc, expectedManifestDesc) { + t.Errorf("Pack() = %v, want %v", manifestDesc, expectedManifestDesc) + } } -func Test_Pack_Image_NoArtifactType(t *testing.T) { +func Test_Pack_ImageV1_1_RC2_NoArtifactType(t *testing.T) { s := memory.New() ctx := context.Background() @@ -422,7 +469,7 @@ func Test_Pack_Image_NoArtifactType(t *testing.T) { t.Fatal("Store.Fetch().Close() error =", err) } - // test artifact type and config media type + // verify artifact type and config media type if manifestDesc.ArtifactType != MediaTypeUnknownConfig { t.Fatalf("got artifact type = %s, want %s", manifestDesc.ArtifactType, MediaTypeUnknownConfig) } @@ -431,7 +478,7 @@ func Test_Pack_Image_NoArtifactType(t *testing.T) { } } -func Test_Pack_Image_NoLayer(t *testing.T) { +func Test_Pack_ImageV1_1_RC2_NoLayer(t *testing.T) { s := memory.New() // test Pack @@ -453,14 +500,14 @@ func Test_Pack_Image_NoLayer(t *testing.T) { t.Fatal("Store.Fetch().Close() error =", err) } - // test layers + // verify layers expectedLayers := []ocispec.Descriptor{} if !reflect.DeepEqual(manifest.Layers, expectedLayers) { t.Errorf("got layers = %v, want %v", manifest.Layers, expectedLayers) } } -func Test_Pack_Image_InvalidDateTimeFormat(t *testing.T) { +func Test_Pack_ImageV1_1_RC2_InvalidDateTimeFormat(t *testing.T) { s := memory.New() ctx := context.Background() @@ -471,7 +518,575 @@ func Test_Pack_Image_InvalidDateTimeFormat(t *testing.T) { }, } _, err := Pack(ctx, s, "", nil, opts) - if err == nil || !errors.Is(err, ErrInvalidDateTimeFormat) { - t.Errorf("Oras.Pack() error = %v, wantErr = %v", err, ErrInvalidDateTimeFormat) + if wantErr := ErrInvalidDateTimeFormat; !errors.Is(err, wantErr) { + t.Errorf("Oras.Pack() error = %v, wantErr = %v", err, wantErr) + } +} + +func Test_PackManifest_ImageV1_0(t *testing.T) { + s := memory.New() + + // test Pack + ctx := context.Background() + artifactType := "application/vnd.test" + manifestDesc, err := PackManifest(ctx, s, PackManifestVersion1_0, artifactType, PackManifestOptions{}) + if err != nil { + t.Fatal("Oras.PackManifest() error =", err) + } + + var manifest ocispec.Manifest + rc, err := s.Fetch(ctx, manifestDesc) + if err != nil { + t.Fatal("Store.Fetch() error =", err) + } + if err := json.NewDecoder(rc).Decode(&manifest); err != nil { + t.Fatal("error decoding manifest, error =", err) + } + if err := rc.Close(); err != nil { + t.Fatal("Store.Fetch().Close() error =", err) + } + + // verify media type + got := manifest.MediaType + if got != ocispec.MediaTypeImageManifest { + t.Fatalf("got media type = %s, want %s", got, ocispec.MediaTypeImageManifest) + } + + // verify config + expectedConfigBytes := []byte("{}") + expectedConfig := ocispec.Descriptor{ + MediaType: artifactType, + Digest: digest.FromBytes(expectedConfigBytes), + Size: int64(len(expectedConfigBytes)), + } + if !reflect.DeepEqual(manifest.Config, expectedConfig) { + t.Errorf("got config = %v, want %v", manifest.Config, expectedConfig) + } + + // verify layers + expectedLayers := []ocispec.Descriptor{} + if !reflect.DeepEqual(manifest.Layers, expectedLayers) { + t.Errorf("got layers = %v, want %v", manifest.Layers, expectedLayers) + } + + // verify created time annotation + createdTime, ok := manifest.Annotations[ocispec.AnnotationCreated] + if !ok { + t.Errorf("Annotation %s = %v, want %v", ocispec.AnnotationCreated, ok, true) + } + _, err = time.Parse(time.RFC3339, createdTime) + if err != nil { + t.Errorf("error parsing created time: %s, error = %v", createdTime, err) + } + + // verify descriptor annotations + if want := manifest.Annotations; !reflect.DeepEqual(manifestDesc.Annotations, want) { + t.Errorf("got descriptor annotations = %v, want %v", manifestDesc.Annotations, want) + } +} + +func Test_PackManifest_ImageV1_0_WithOptions(t *testing.T) { + s := memory.New() + + // prepare test content + layers := []ocispec.Descriptor{ + content.NewDescriptorFromBytes("test", []byte("hello world")), + content.NewDescriptorFromBytes("test", []byte("goodbye world")), + } + configBytes := []byte("{}") + configDesc := content.NewDescriptorFromBytes("application/vnd.test.config", configBytes) + configAnnotations := map[string]string{"foo": "bar"} + annotations := map[string]string{ + ocispec.AnnotationCreated: "2000-01-01T00:00:00Z", + "foo": "bar", + } + artifactType := "application/vnd.test" + + // test PackManifest with ConfigDescriptor + ctx := context.Background() + opts := PackManifestOptions{ + Layers: layers, + ConfigDescriptor: &configDesc, + ConfigAnnotations: configAnnotations, + ManifestAnnotations: annotations, + } + manifestDesc, err := PackManifest(ctx, s, PackManifestVersion1_0, artifactType, opts) + if err != nil { + t.Fatal("Oras.PackManifest() error =", err) + } + + expectedManifest := ocispec.Manifest{ + Versioned: specs.Versioned{ + SchemaVersion: 2, // historical value. does not pertain to OCI or docker version + }, + MediaType: ocispec.MediaTypeImageManifest, + Config: configDesc, + Layers: layers, + Annotations: annotations, + } + expectedManifestBytes, err := json.Marshal(expectedManifest) + if err != nil { + t.Fatal("failed to marshal manifest:", err) + } + + rc, err := s.Fetch(ctx, manifestDesc) + if err != nil { + t.Fatal("Store.Fetch() error =", err) + } + got, err := io.ReadAll(rc) + if err != nil { + t.Fatal("Store.Fetch().Read() error =", err) + } + err = rc.Close() + if err != nil { + t.Error("Store.Fetch().Close() error =", err) + } + if !bytes.Equal(got, expectedManifestBytes) { + t.Errorf("Store.Fetch() = %v, want %v", string(got), string(expectedManifestBytes)) + } + + // verify descriptor + expectedManifestDesc := content.NewDescriptorFromBytes(expectedManifest.MediaType, expectedManifestBytes) + expectedManifestDesc.ArtifactType = expectedManifest.Config.MediaType + expectedManifestDesc.Annotations = expectedManifest.Annotations + if !reflect.DeepEqual(manifestDesc, expectedManifestDesc) { + t.Errorf("Pack() = %v, want %v", manifestDesc, expectedManifestDesc) + } + + // test PackManifest without ConfigDescriptor + opts = PackManifestOptions{ + Layers: layers, + ConfigAnnotations: configAnnotations, + ManifestAnnotations: annotations, + } + manifestDesc, err = PackManifest(ctx, s, PackManifestVersion1_0, artifactType, opts) + if err != nil { + t.Fatal("Oras.PackManifest() error =", err) + } + + expectedConfigDesc := content.NewDescriptorFromBytes(artifactType, configBytes) + expectedConfigDesc.Annotations = configAnnotations + expectedManifest = ocispec.Manifest{ + Versioned: specs.Versioned{ + SchemaVersion: 2, // historical value. does not pertain to OCI or docker version + }, + MediaType: ocispec.MediaTypeImageManifest, + Config: expectedConfigDesc, + Layers: layers, + Annotations: annotations, + } + expectedManifestBytes, err = json.Marshal(expectedManifest) + if err != nil { + t.Fatal("failed to marshal manifest:", err) + } + + rc, err = s.Fetch(ctx, manifestDesc) + if err != nil { + t.Fatal("Store.Fetch() error =", err) + } + got, err = io.ReadAll(rc) + if err != nil { + t.Fatal("Store.Fetch().Read() error =", err) + } + err = rc.Close() + if err != nil { + t.Error("Store.Fetch().Close() error =", err) + } + if !bytes.Equal(got, expectedManifestBytes) { + t.Errorf("Store.Fetch() = %v, want %v", string(got), string(expectedManifestBytes)) + } + + // verify descriptor + expectedManifestDesc = content.NewDescriptorFromBytes(expectedManifest.MediaType, expectedManifestBytes) + expectedManifestDesc.ArtifactType = expectedManifest.Config.MediaType + expectedManifestDesc.Annotations = expectedManifest.Annotations + if !reflect.DeepEqual(manifestDesc, expectedManifestDesc) { + t.Errorf("PackManifest() = %v, want %v", manifestDesc, expectedManifestDesc) + } +} + +func Test_PackManifest_ImageV1_0_SubjectUnsupported(t *testing.T) { + s := memory.New() + + // prepare test content + artifactType := "application/vnd.test" + subjectManifest := []byte(`{"layers":[]}`) + subjectDesc := ocispec.Descriptor{ + MediaType: ocispec.MediaTypeImageManifest, + Digest: digest.FromBytes(subjectManifest), + Size: int64(len(subjectManifest)), + } + + // test Pack with ConfigDescriptor + ctx := context.Background() + opts := PackManifestOptions{ + Subject: &subjectDesc, + } + _, err := PackManifest(ctx, s, PackManifestVersion1_0, artifactType, opts) + if wantErr := errdef.ErrUnsupported; !errors.Is(err, wantErr) { + t.Errorf("Oras.PackManifest() error = %v, wantErr %v", err, wantErr) + } +} + +func Test_PackManifest_ImageV1_0_NoArtifactType(t *testing.T) { + s := memory.New() + + ctx := context.Background() + manifestDesc, err := PackManifest(ctx, s, PackManifestVersion1_0, "", PackManifestOptions{}) + if err != nil { + t.Fatal("Oras.PackManifest() error =", err) + } + + var manifest ocispec.Manifest + rc, err := s.Fetch(ctx, manifestDesc) + if err != nil { + t.Fatal("Store.Fetch() error =", err) + } + if err := json.NewDecoder(rc).Decode(&manifest); err != nil { + t.Fatal("error decoding manifest, error =", err) + } + if err := rc.Close(); err != nil { + t.Fatal("Store.Fetch().Close() error =", err) + } + + // verify artifact type and config media type + if manifestDesc.ArtifactType != MediaTypeUnknownConfig { + t.Fatalf("got artifact type = %s, want %s", manifestDesc.ArtifactType, MediaTypeUnknownConfig) + } + if manifest.Config.MediaType != MediaTypeUnknownConfig { + t.Fatalf("got artifact type = %s, want %s", manifest.Config.MediaType, MediaTypeUnknownConfig) + } +} + +func Test_PackManifest_ImageV1_0_InvalidMediaType(t *testing.T) { + s := memory.New() + + ctx := context.Background() + // test invalid artifact type + valid config media type + artifactType := "random" + configBytes := []byte("{}") + configDesc := content.NewDescriptorFromBytes("application/vnd.test.config", configBytes) + opts := PackManifestOptions{ + ConfigDescriptor: &configDesc, + } + _, err := PackManifest(ctx, s, PackManifestVersion1_0, artifactType, opts) + if err != nil { + t.Error("Oras.PackManifest() error =", err) + } + + // test invalid config media type + valid artifact type + artifactType = "application/vnd.test" + configDesc = content.NewDescriptorFromBytes("random", configBytes) + opts = PackManifestOptions{ + ConfigDescriptor: &configDesc, + } + _, err = PackManifest(ctx, s, PackManifestVersion1_0, artifactType, opts) + if wantErr := errdef.ErrInvalidMediaType; !errors.Is(err, wantErr) { + t.Errorf("Oras.PackManifest() error = %v, wantErr = %v", err, wantErr) + } +} + +func Test_PackManifest_ImageV1_0_InvalidDateTimeFormat(t *testing.T) { + s := memory.New() + + ctx := context.Background() + opts := PackManifestOptions{ + ManifestAnnotations: map[string]string{ + ocispec.AnnotationCreated: "2000/01/01 00:00:00", + }, + } + _, err := PackManifest(ctx, s, PackManifestVersion1_0, "", opts) + if wantErr := ErrInvalidDateTimeFormat; !errors.Is(err, wantErr) { + t.Errorf("Oras.PackManifest() error = %v, wantErr = %v", err, wantErr) + } +} + +func Test_PackManifest_ImageV1_1_RC4(t *testing.T) { + s := memory.New() + + // test PackManifest + ctx := context.Background() + artifactType := "application/vnd.test" + manifestDesc, err := PackManifest(ctx, s, PackManifestVersion1_1_RC4, artifactType, PackManifestOptions{}) + if err != nil { + t.Fatal("Oras.PackManifest() error =", err) + } + + var manifest ocispec.Manifest + rc, err := s.Fetch(ctx, manifestDesc) + if err != nil { + t.Fatal("Store.Fetch() error =", err) + } + if err := json.NewDecoder(rc).Decode(&manifest); err != nil { + t.Fatal("error decoding manifest, error =", err) + } + if err := rc.Close(); err != nil { + t.Fatal("Store.Fetch().Close() error =", err) + } + + // verify layers + expectedLayers := []ocispec.Descriptor{ocispec.DescriptorEmptyJSON} + if !reflect.DeepEqual(manifest.Layers, expectedLayers) { + t.Errorf("got layers = %v, want %v", manifest.Layers, expectedLayers) + } +} + +func Test_PackManifest_ImageV1_1_RC4_WithOptions(t *testing.T) { + s := memory.New() + + // prepare test content + layers := []ocispec.Descriptor{ + content.NewDescriptorFromBytes("test", []byte("hello world")), + content.NewDescriptorFromBytes("test", []byte("goodbye world")), + } + configBytes := []byte("config") + configDesc := content.NewDescriptorFromBytes("application/vnd.test.config", configBytes) + configAnnotations := map[string]string{"foo": "bar"} + annotations := map[string]string{ + ocispec.AnnotationCreated: "2000-01-01T00:00:00Z", + "foo": "bar", + } + artifactType := "application/vnd.test" + subjectManifest := []byte(`{"layers":[]}`) + subjectDesc := ocispec.Descriptor{ + MediaType: ocispec.MediaTypeImageManifest, + Digest: digest.FromBytes(subjectManifest), + Size: int64(len(subjectManifest)), + } + + // test PackManifest with ConfigDescriptor + ctx := context.Background() + opts := PackManifestOptions{ + Subject: &subjectDesc, + Layers: layers, + ConfigDescriptor: &configDesc, + ConfigAnnotations: configAnnotations, + ManifestAnnotations: annotations, + } + manifestDesc, err := PackManifest(ctx, s, PackManifestVersion1_1_RC4, artifactType, opts) + if err != nil { + t.Fatal("Oras.PackManifest() error =", err) + } + + expectedManifest := ocispec.Manifest{ + Versioned: specs.Versioned{ + SchemaVersion: 2, // historical value. does not pertain to OCI or docker version + }, + MediaType: ocispec.MediaTypeImageManifest, + ArtifactType: artifactType, + Subject: &subjectDesc, + Config: configDesc, + Layers: layers, + Annotations: annotations, + } + expectedManifestBytes, err := json.Marshal(expectedManifest) + if err != nil { + t.Fatal("failed to marshal manifest:", err) + } + + rc, err := s.Fetch(ctx, manifestDesc) + if err != nil { + t.Fatal("Store.Fetch() error =", err) + } + got, err := io.ReadAll(rc) + if err != nil { + t.Fatal("Store.Fetch().Read() error =", err) + } + err = rc.Close() + if err != nil { + t.Error("Store.Fetch().Close() error =", err) + } + if !bytes.Equal(got, expectedManifestBytes) { + t.Errorf("Store.Fetch() = %v, want %v", string(got), string(expectedManifestBytes)) + } + + // verify descriptor + expectedManifestDesc := content.NewDescriptorFromBytes(expectedManifest.MediaType, expectedManifestBytes) + expectedManifestDesc.ArtifactType = expectedManifest.ArtifactType + expectedManifestDesc.Annotations = expectedManifest.Annotations + if !reflect.DeepEqual(manifestDesc, expectedManifestDesc) { + t.Errorf("PackManifest() = %v, want %v", manifestDesc, expectedManifestDesc) + } + + // test PackManifest with ConfigDescriptor, but without artifactType + opts = PackManifestOptions{ + Subject: &subjectDesc, + Layers: layers, + ConfigDescriptor: &configDesc, + ConfigAnnotations: configAnnotations, + ManifestAnnotations: annotations, + } + manifestDesc, err = PackManifest(ctx, s, PackManifestVersion1_1_RC4, "", opts) + if err != nil { + t.Fatal("Oras.PackManifest() error =", err) + } + + expectedManifest = ocispec.Manifest{ + Versioned: specs.Versioned{ + SchemaVersion: 2, // historical value. does not pertain to OCI or docker version + }, + MediaType: ocispec.MediaTypeImageManifest, + Subject: &subjectDesc, + Config: configDesc, + Layers: layers, + Annotations: annotations, + } + expectedManifestBytes, err = json.Marshal(expectedManifest) + if err != nil { + t.Fatal("failed to marshal manifest:", err) + } + + rc, err = s.Fetch(ctx, manifestDesc) + if err != nil { + t.Fatal("Store.Fetch() error =", err) + } + got, err = io.ReadAll(rc) + if err != nil { + t.Fatal("Store.Fetch().Read() error =", err) + } + err = rc.Close() + if err != nil { + t.Error("Store.Fetch().Close() error =", err) + } + if !bytes.Equal(got, expectedManifestBytes) { + t.Errorf("Store.Fetch() = %v, want %v", string(got), string(expectedManifestBytes)) + } + + // verify descriptor + expectedManifestDesc = content.NewDescriptorFromBytes(expectedManifest.MediaType, expectedManifestBytes) + expectedManifestDesc.ArtifactType = expectedManifest.ArtifactType + expectedManifestDesc.Annotations = expectedManifest.Annotations + if !reflect.DeepEqual(manifestDesc, expectedManifestDesc) { + t.Errorf("PackManifest() = %v, want %v", manifestDesc, expectedManifestDesc) + } + + // test Pack without ConfigDescriptor + opts = PackManifestOptions{ + Subject: &subjectDesc, + Layers: layers, + ConfigAnnotations: configAnnotations, + ManifestAnnotations: annotations, + } + manifestDesc, err = PackManifest(ctx, s, PackManifestVersion1_1_RC4, artifactType, opts) + if err != nil { + t.Fatal("Oras.PackManifest() error =", err) + } + + expectedConfigDesc := ocispec.DescriptorEmptyJSON + expectedConfigDesc.Annotations = configAnnotations + expectedManifest = ocispec.Manifest{ + Versioned: specs.Versioned{ + SchemaVersion: 2, // historical value. does not pertain to OCI or docker version + }, + MediaType: ocispec.MediaTypeImageManifest, + ArtifactType: artifactType, + Subject: &subjectDesc, + Config: expectedConfigDesc, + Layers: layers, + Annotations: annotations, + } + expectedManifestBytes, err = json.Marshal(expectedManifest) + if err != nil { + t.Fatal("failed to marshal manifest:", err) + } + + rc, err = s.Fetch(ctx, manifestDesc) + if err != nil { + t.Fatal("Store.Fetch() error =", err) + } + got, err = io.ReadAll(rc) + if err != nil { + t.Fatal("Store.Fetch().Read() error =", err) + } + err = rc.Close() + if err != nil { + t.Error("Store.Fetch().Close() error =", err) + } + if !bytes.Equal(got, expectedManifestBytes) { + t.Errorf("Store.Fetch() = %v, want %v", string(got), string(expectedManifestBytes)) + } + + // verify descriptor + expectedManifestDesc = content.NewDescriptorFromBytes(expectedManifest.MediaType, expectedManifestBytes) + expectedManifestDesc.ArtifactType = expectedManifest.ArtifactType + expectedManifestDesc.Annotations = expectedManifest.Annotations + if !reflect.DeepEqual(manifestDesc, expectedManifestDesc) { + t.Errorf("PackManifest() = %v, want %v", manifestDesc, expectedManifestDesc) + } +} + +func Test_PackManifest_ImageV1_1_RC4_NoArtifactType(t *testing.T) { + s := memory.New() + + ctx := context.Background() + // test no artifact type and no config + _, err := PackManifest(ctx, s, PackManifestVersion1_1_RC4, "", PackManifestOptions{}) + if wantErr := ErrMissingArtifactType; !errors.Is(err, wantErr) { + t.Errorf("Oras.PackManifest() error = %v, wantErr = %v", err, wantErr) + } + + // test no artifact type and config with empty media type + opts := PackManifestOptions{ + ConfigDescriptor: &ocispec.Descriptor{ + MediaType: ocispec.DescriptorEmptyJSON.MediaType, + }, + } + _, err = PackManifest(ctx, s, PackManifestVersion1_1_RC4, "", opts) + if wantErr := ErrMissingArtifactType; !errors.Is(err, wantErr) { + t.Errorf("Oras.PackManifest() error = %v, wantErr = %v", err, wantErr) + } +} + +func Test_PackManifest_ImageV1_1_RC4_InvalidMediaType(t *testing.T) { + s := memory.New() + + ctx := context.Background() + // test invalid artifact type + valid config media type + artifactType := "random" + configBytes := []byte("{}") + configDesc := content.NewDescriptorFromBytes("application/vnd.test.config", configBytes) + opts := PackManifestOptions{ + ConfigDescriptor: &configDesc, + } + _, err := PackManifest(ctx, s, PackManifestVersion1_1_RC4, artifactType, opts) + if wantErr := errdef.ErrInvalidMediaType; !errors.Is(err, wantErr) { + t.Errorf("Oras.PackManifest() error = %v, wantErr = %v", err, wantErr) + } + + // test invalid config media type + invalid artifact type + artifactType = "application/vnd.test" + configDesc = content.NewDescriptorFromBytes("random", configBytes) + opts = PackManifestOptions{ + ConfigDescriptor: &configDesc, + } + _, err = PackManifest(ctx, s, PackManifestVersion1_1_RC4, artifactType, opts) + if wantErr := errdef.ErrInvalidMediaType; !errors.Is(err, wantErr) { + t.Errorf("Oras.PackManifest() error = %v, wantErr = %v", err, wantErr) + } +} + +func Test_PackManifest_ImageV1_1_RC4_InvalidDateTimeFormat(t *testing.T) { + s := memory.New() + + ctx := context.Background() + opts := PackManifestOptions{ + ManifestAnnotations: map[string]string{ + ocispec.AnnotationCreated: "2000/01/01 00:00:00", + }, + } + artifactType := "application/vnd.test" + _, err := PackManifest(ctx, s, PackManifestVersion1_1_RC4, artifactType, opts) + if wantErr := ErrInvalidDateTimeFormat; !errors.Is(err, wantErr) { + t.Errorf("Oras.PackManifest() error = %v, wantErr = %v", err, wantErr) + } +} + +func Test_PackManifest_UnsupportedPackManifestVersion(t *testing.T) { + s := memory.New() + + ctx := context.Background() + _, err := PackManifest(ctx, s, -1, "", PackManifestOptions{}) + if wantErr := errdef.ErrUnsupported; !errors.Is(err, wantErr) { + t.Errorf("Oras.PackManifest() error = %v, wantErr = %v", err, wantErr) } } diff --git a/registry/reference.go b/registry/reference.go index cea579a1..a4d2003d 100644 --- a/registry/reference.go +++ b/registry/reference.go @@ -31,14 +31,16 @@ var ( // repository name set under OCI distribution spec is a subset of the docker // spec. For maximum compatability, the docker spec is verified client-side. // Further checks are left to the server-side. + // // References: - // - https://github.com/distribution/distribution/blob/v2.7.1/reference/regexp.go#L53 - // - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc1/spec.md#pulling-manifests + // - https://github.com/distribution/distribution/blob/v2.7.1/reference/regexp.go#L53 + // - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#pulling-manifests repositoryRegexp = regexp.MustCompile(`^[a-z0-9]+(?:(?:[._]|__|[-]*)[a-z0-9]+)*(?:/[a-z0-9]+(?:(?:[._]|__|[-]*)[a-z0-9]+)*)*$`) // tagRegexp checks the tag name. // The docker and OCI spec have the same regular expression. - // Reference: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc1/spec.md#pulling-manifests + // + // Reference: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#pulling-manifests tagRegexp = regexp.MustCompile(`^[\w][\w.-]{0,127}$`) ) @@ -115,7 +117,7 @@ func ParseReference(artifact string) (Reference, error) { parts := strings.SplitN(artifact, "/", 2) if len(parts) == 1 { // Invalid Form - return Reference{}, fmt.Errorf("%w: missing repository", errdef.ErrInvalidReference) + return Reference{}, fmt.Errorf("%w: missing registry or repository", errdef.ErrInvalidReference) } registry, path := parts[0], parts[1] @@ -188,7 +190,7 @@ func (r Reference) Validate() error { // ValidateRegistry validates the registry. func (r Reference) ValidateRegistry() error { if uri, err := url.ParseRequestURI("dummy://" + r.Registry); err != nil || uri.Host != r.Registry { - return fmt.Errorf("%w: invalid registry", errdef.ErrInvalidReference) + return fmt.Errorf("%w: invalid registry %q", errdef.ErrInvalidReference, r.Registry) } return nil } @@ -196,7 +198,7 @@ func (r Reference) ValidateRegistry() error { // ValidateRepository validates the repository. func (r Reference) ValidateRepository() error { if !repositoryRegexp.MatchString(r.Repository) { - return fmt.Errorf("%w: invalid repository", errdef.ErrInvalidReference) + return fmt.Errorf("%w: invalid repository %q", errdef.ErrInvalidReference, r.Repository) } return nil } @@ -204,7 +206,7 @@ func (r Reference) ValidateRepository() error { // ValidateReferenceAsTag validates the reference as a tag. func (r Reference) ValidateReferenceAsTag() error { if !tagRegexp.MatchString(r.Reference) { - return fmt.Errorf("%w: invalid tag", errdef.ErrInvalidReference) + return fmt.Errorf("%w: invalid tag %q", errdef.ErrInvalidReference, r.Reference) } return nil } @@ -212,7 +214,7 @@ func (r Reference) ValidateReferenceAsTag() error { // ValidateReferenceAsDigest validates the reference as a digest. func (r Reference) ValidateReferenceAsDigest() error { if _, err := r.Digest(); err != nil { - return fmt.Errorf("%w: invalid digest; %v", errdef.ErrInvalidReference, err) + return fmt.Errorf("%w: invalid digest %q: %v", errdef.ErrInvalidReference, r.Reference, err) } return nil } diff --git a/registry/remote/auth/client.go b/registry/remote/auth/client.go index 37eb65ed..58355161 100644 --- a/registry/remote/auth/client.go +++ b/registry/remote/auth/client.go @@ -54,16 +54,23 @@ var maxResponseBytes int64 = 128 * 1024 // 128 KiB // See also ClientID. var defaultClientID = "oras-go" +// CredentialFunc represents a function that resolves the credential for the +// given registry (i.e. host:port). +// +// [EmptyCredential] is a valid return value and should not be considered as +// an error. +type CredentialFunc func(ctx context.Context, hostport string) (Credential, error) + // StaticCredential specifies static credentials for the given host. -func StaticCredential(registry string, cred Credential) func(context.Context, string) (Credential, error) { +func StaticCredential(registry string, cred Credential) CredentialFunc { if registry == "docker.io" { // it is expected that traffic targeting "docker.io" will be redirected // to "registry-1.docker.io" // reference: https://github.com/moby/moby/blob/v24.0.0-beta.2/registry/config.go#L25-L48 registry = "registry-1.docker.io" } - return func(_ context.Context, target string) (Credential, error) { - if target == registry { + return func(_ context.Context, hostport string) (Credential, error) { + if hostport == registry { return cred, nil } return EmptyCredential, nil @@ -88,10 +95,10 @@ type Client struct { // Credential specifies the function for resolving the credential for the // given registry (i.e. host:port). - // `EmptyCredential` is a valid return value and should not be considered as + // EmptyCredential is a valid return value and should not be considered as // an error. - // If nil, the credential is always resolved to `EmptyCredential`. - Credential func(context.Context, string) (Credential, error) + // If nil, the credential is always resolved to EmptyCredential. + Credential CredentialFunc // Cache caches credentials for direct accessing the remote registry. // If nil, no cache is used. @@ -170,19 +177,19 @@ func (c *Client) Do(originalReq *http.Request) (*http.Response, error) { // attempt cached auth token var attemptedKey string cache := c.cache() - registry := originalReq.Host - scheme, err := cache.GetScheme(ctx, registry) + host := originalReq.Host + scheme, err := cache.GetScheme(ctx, host) if err == nil { switch scheme { case SchemeBasic: - token, err := cache.GetToken(ctx, registry, SchemeBasic, "") + token, err := cache.GetToken(ctx, host, SchemeBasic, "") if err == nil { req.Header.Set("Authorization", "Basic "+token) } case SchemeBearer: - scopes := GetScopes(ctx) + scopes := GetAllScopesForHost(ctx, host) attemptedKey = strings.Join(scopes, " ") - token, err := cache.GetToken(ctx, registry, SchemeBearer, attemptedKey) + token, err := cache.GetToken(ctx, host, SchemeBearer, attemptedKey) if err == nil { req.Header.Set("Authorization", "Bearer "+token) } @@ -204,8 +211,8 @@ func (c *Client) Do(originalReq *http.Request) (*http.Response, error) { case SchemeBasic: resp.Body.Close() - token, err := cache.Set(ctx, registry, SchemeBasic, "", func(ctx context.Context) (string, error) { - return c.fetchBasicAuth(ctx, registry) + token, err := cache.Set(ctx, host, SchemeBasic, "", func(ctx context.Context) (string, error) { + return c.fetchBasicAuth(ctx, host) }) if err != nil { return nil, fmt.Errorf("%s %q: %w", resp.Request.Method, resp.Request.URL, err) @@ -216,17 +223,17 @@ func (c *Client) Do(originalReq *http.Request) (*http.Response, error) { case SchemeBearer: resp.Body.Close() - // merge hinted scopes with challenged scopes - scopes := GetScopes(ctx) - if scope := params["scope"]; scope != "" { - scopes = append(scopes, strings.Split(scope, " ")...) + scopes := GetAllScopesForHost(ctx, host) + if paramScope := params["scope"]; paramScope != "" { + // merge hinted scopes with challenged scopes + scopes = append(scopes, strings.Split(paramScope, " ")...) scopes = CleanScopes(scopes) } key := strings.Join(scopes, " ") // attempt the cache again if there is a scope change if key != attemptedKey { - if token, err := cache.GetToken(ctx, registry, SchemeBearer, key); err == nil { + if token, err := cache.GetToken(ctx, host, SchemeBearer, key); err == nil { req = originalReq.Clone(ctx) req.Header.Set("Authorization", "Bearer "+token) if err := rewindRequestBody(req); err != nil { @@ -247,8 +254,8 @@ func (c *Client) Do(originalReq *http.Request) (*http.Response, error) { // attempt with credentials realm := params["realm"] service := params["service"] - token, err := cache.Set(ctx, registry, SchemeBearer, key, func(ctx context.Context) (string, error) { - return c.fetchBearerToken(ctx, registry, realm, service, scopes) + token, err := cache.Set(ctx, host, SchemeBearer, key, func(ctx context.Context) (string, error) { + return c.fetchBearerToken(ctx, host, realm, service, scopes) }) if err != nil { return nil, fmt.Errorf("%s %q: %w", resp.Request.Method, resp.Request.URL, err) diff --git a/registry/remote/auth/client_test.go b/registry/remote/auth/client_test.go index 9e5ed69d..de879863 100644 --- a/registry/remote/auth/client_test.go +++ b/registry/remote/auth/client_test.go @@ -449,6 +449,205 @@ func TestClient_Do_Bearer_AccessToken_Cached(t *testing.T) { } } +func TestClient_Do_Bearer_AccessToken_Cached_PerHost(t *testing.T) { + as := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + })) + defer as.Close() + // set up server 1 + var requestCount1, wantRequestCount1 int64 + var successCount1, wantSuccessCount1 int64 + var service1 string + scope1 := "repository:test:pull" + accessToken1 := "test/access/token/1" + ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount1, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken1 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as.URL, service1, scope1) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount1, 1) + })) + defer ts1.Close() + uri1, err := url.Parse(ts1.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service1 = uri1.Host + client1 := &Client{ + Credential: StaticCredential(uri1.Host, Credential{ + AccessToken: accessToken1, + }), + Cache: NewCache(), + } + + // set up server 2 + var requestCount2, wantRequestCount2 int64 + var successCount2, wantSuccessCount2 int64 + var service2 string + scope2 := "repository:test:pull,push" + accessToken2 := "test/access/token/2" + ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount2, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken2 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as.URL, service2, scope2) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount2, 1) + })) + defer ts2.Close() + uri2, err := url.Parse(ts2.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service2 = uri2.Host + client2 := &Client{ + Credential: StaticCredential(uri2.Host, Credential{ + AccessToken: accessToken2, + }), + Cache: NewCache(), + } + + ctx := context.Background() + ctx = WithScopesForHost(ctx, uri1.Host, scope1) + ctx = WithScopesForHost(ctx, uri2.Host, scope2) + // first request to server 1 + req1, err := http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err := client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + // first request to server 2 + req2, err := http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err := client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount1 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + + // repeated request to server 1 + req1, err = http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err = client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1++; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + // repeated request to server 2 + req2, err = http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err = client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2++; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + + // credential change for server 1 + accessToken1 = "test/access/token/1/new" + req1, err = http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + client1.Credential = StaticCredential(uri1.Host, Credential{ + AccessToken: accessToken1, + }) + resp1, err = client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + // credential change for server 2 + accessToken2 = "test/access/token/2/new" + req2, err = http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + client2.Credential = StaticCredential(uri2.Host, Credential{ + AccessToken: accessToken2, + }) + resp2, err = client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } +} + func TestClient_Do_Bearer_Auth(t *testing.T) { username := "test_user" password := "test_password" @@ -725,6 +924,297 @@ func TestClient_Do_Bearer_Auth_Cached(t *testing.T) { } } +func TestClient_Do_Bearer_Auth_Cached_PerHost(t *testing.T) { + // set up server 1 + username1 := "test_user1" + password1 := "test_password1" + accessToken1 := "test/access/token/1" + var requestCount1, wantRequestCount1 int64 + var successCount1, wantSuccessCount1 int64 + var authCount1, wantAuthCount1 int64 + var service1 string + scopes1 := []string{ + "repository:src:pull", + } + as1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + return + } + header := "Basic " + base64.StdEncoding.EncodeToString([]byte(username1+":"+password1)) + if auth := r.Header.Get("Authorization"); auth != header { + t.Errorf("unexpected auth: got %s, want %s", auth, header) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.URL.Query().Get("service"); got != service1 { + t.Errorf("unexpected service: got %s, want %s", got, service1) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.URL.Query()["scope"]; !reflect.DeepEqual(got, scopes1) { + t.Errorf("unexpected scope: got %s, want %s", got, scopes1) + w.WriteHeader(http.StatusUnauthorized) + return + } + + atomic.AddInt64(&authCount1, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken1); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + })) + defer as1.Close() + ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount1, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken1 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as1.URL, service1, strings.Join(scopes1, " ")) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount1, 1) + })) + defer ts1.Close() + uri1, err := url.Parse(ts1.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service1 = uri1.Host + client1 := &Client{ + Credential: StaticCredential(uri1.Host, Credential{ + Username: username1, + Password: password1, + }), + Cache: NewCache(), + } + + // set up server 2 + username2 := "test_user2" + password2 := "test_password2" + accessToken2 := "test/access/token/1" + var requestCount2, wantRequestCount2 int64 + var successCount2, wantSuccessCount2 int64 + var authCount2, wantAuthCount2 int64 + var service2 string + scopes2 := []string{ + "repository:dst:pull,push", + } + as2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + return + } + header := "Basic " + base64.StdEncoding.EncodeToString([]byte(username2+":"+password2)) + if auth := r.Header.Get("Authorization"); auth != header { + t.Errorf("unexpected auth: got %s, want %s", auth, header) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.URL.Query().Get("service"); got != service2 { + t.Errorf("unexpected service: got %s, want %s", got, service2) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.URL.Query()["scope"]; !reflect.DeepEqual(got, scopes2) { + t.Errorf("unexpected scope: got %s, want %s", got, scopes2) + w.WriteHeader(http.StatusUnauthorized) + return + } + + atomic.AddInt64(&authCount2, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken2); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + })) + defer as1.Close() + ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount2, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken2 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as2.URL, service2, strings.Join(scopes2, " ")) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount2, 1) + })) + defer ts2.Close() + uri2, err := url.Parse(ts2.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service2 = uri2.Host + client2 := &Client{ + Credential: StaticCredential(uri2.Host, Credential{ + Username: username2, + Password: password2, + }), + Cache: NewCache(), + } + + ctx := context.Background() + ctx = WithScopesForHost(ctx, uri1.Host, scopes1...) + ctx = WithScopesForHost(ctx, uri2.Host, scopes2...) + // first request to server 1 + req1, err := http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err := client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if wantAuthCount1++; authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + + // first request to server 2 + req2, err := http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err := client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if wantAuthCount2++; authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } + + // repeated request to server 1 + req1, err = http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err = client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1++; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + + // repeated request to server 2 + req2, err = http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err = client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2++; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } + + // credential change for server 1 + username1 = "test_user1_new" + password1 = "test_password1_new" + accessToken1 = "test/access/token/1/new" + req1, err = http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + client1.Credential = StaticCredential(uri1.Host, Credential{ + Username: username1, + Password: password1, + }) + resp1, err = client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if wantAuthCount1++; authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + + // credential change for server 2 + username2 = "test_user2_new" + password2 = "test_password2_new" + accessToken2 = "test/access/token/2/new" + req2, err = http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + client2.Credential = StaticCredential(uri2.Host, Credential{ + Username: username2, + Password: password2, + }) + resp2, err = client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if wantAuthCount2++; authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } +} + func TestClient_Do_Bearer_OAuth2_Password(t *testing.T) { username := "test_user" password := "test_password" @@ -1043,18 +1533,19 @@ func TestClient_Do_Bearer_OAuth2_Password_Cached(t *testing.T) { } } -func TestClient_Do_Bearer_OAuth2_RefreshToken(t *testing.T) { - refreshToken := "test/refresh/token" - accessToken := "test/access/token" - var requestCount, wantRequestCount int64 - var successCount, wantSuccessCount int64 - var authCount, wantAuthCount int64 - var service string - scopes := []string{ - "repository:dst:pull,push", +func TestClient_Do_Bearer_OAuth2_Password_Cached_PerHost(t *testing.T) { + // set up server 1 + username1 := "test_user1" + password1 := "test_password1" + accessToken1 := "test/access/token/1" + var requestCount1, wantRequestCount1 int64 + var successCount1, wantSuccessCount1 int64 + var authCount1, wantAuthCount1 int64 + var service1 string + scopes1 := []string{ "repository:src:pull", } - as := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + as1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost || r.URL.Path != "/" { t.Error("unexecuted attempt of authorization service") w.WriteHeader(http.StatusUnauthorized) @@ -1065,13 +1556,13 @@ func TestClient_Do_Bearer_OAuth2_RefreshToken(t *testing.T) { w.WriteHeader(http.StatusUnauthorized) return } - if got := r.PostForm.Get("grant_type"); got != "refresh_token" { - t.Errorf("unexpected grant type: %v, want %v", got, "refresh_token") + if got := r.PostForm.Get("grant_type"); got != "password" { + t.Errorf("unexpected grant type: %v, want %v", got, "password") w.WriteHeader(http.StatusUnauthorized) return } - if got := r.PostForm.Get("service"); got != service { - t.Errorf("unexpected service: %v, want %v", got, service) + if got := r.PostForm.Get("service"); got != service1 { + t.Errorf("unexpected service: %v, want %v", got, service1) w.WriteHeader(http.StatusUnauthorized) return } @@ -1080,108 +1571,298 @@ func TestClient_Do_Bearer_OAuth2_RefreshToken(t *testing.T) { w.WriteHeader(http.StatusUnauthorized) return } - scope := strings.Join(scopes, " ") + scope := strings.Join(scopes1, " ") if got := r.PostForm.Get("scope"); got != scope { t.Errorf("unexpected scope: %v, want %v", got, scope) w.WriteHeader(http.StatusUnauthorized) return } - if got := r.PostForm.Get("refresh_token"); got != refreshToken { - t.Errorf("unexpected refresh token: %v, want %v", got, refreshToken) + if got := r.PostForm.Get("username"); got != username1 { + t.Errorf("unexpected username: %v, want %v", got, username1) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("password"); got != password1 { + t.Errorf("unexpected password: %v, want %v", got, password1) w.WriteHeader(http.StatusUnauthorized) return } - atomic.AddInt64(&authCount, 1) - if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken); err != nil { + atomic.AddInt64(&authCount1, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken1); err != nil { t.Errorf("failed to write %q: %v", r.URL, err) } })) - defer as.Close() - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - atomic.AddInt64(&requestCount, 1) + defer as1.Close() + ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount1, 1) if r.Method != http.MethodGet || r.URL.Path != "/" { t.Errorf("unexpected access: %s %s", r.Method, r.URL) w.WriteHeader(http.StatusNotFound) return } - header := "Bearer " + accessToken + header := "Bearer " + accessToken1 if auth := r.Header.Get("Authorization"); auth != header { - challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as.URL, service, strings.Join(scopes, " ")) + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as1.URL, service1, strings.Join(scopes1, " ")) w.Header().Set("Www-Authenticate", challenge) w.WriteHeader(http.StatusUnauthorized) return } - atomic.AddInt64(&successCount, 1) + atomic.AddInt64(&successCount1, 1) })) - defer ts.Close() - uri, err := url.Parse(ts.URL) + defer ts1.Close() + uri1, err := url.Parse(ts1.URL) if err != nil { t.Fatalf("invalid test http server: %v", err) } - service = uri.Host + service1 = uri1.Host + client1 := &Client{ + Credential: StaticCredential(uri1.Host, Credential{ + Username: username1, + Password: password1, + }), + ForceAttemptOAuth2: true, + Cache: NewCache(), + } + // set up server 2 + username2 := "test_user2" + password2 := "test_password2" + accessToken2 := "test/access/token/2" + var requestCount2, wantRequestCount2 int64 + var successCount2, wantSuccessCount2 int64 + var authCount2, wantAuthCount2 int64 + var service2 string + scopes2 := []string{ + "repository:dst:pull,push", + } + as2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/" { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + return + } + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("grant_type"); got != "password" { + t.Errorf("unexpected grant type: %v, want %v", got, "password") + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("service"); got != service2 { + t.Errorf("unexpected service: %v, want %v", got, service2) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("client_id"); got != defaultClientID { + t.Errorf("unexpected client id: %v, want %v", got, defaultClientID) + w.WriteHeader(http.StatusUnauthorized) + return + } + scope := strings.Join(scopes2, " ") + if got := r.PostForm.Get("scope"); got != scope { + t.Errorf("unexpected scope: %v, want %v", got, scope) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("username"); got != username2 { + t.Errorf("unexpected username: %v, want %v", got, username2) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("password"); got != password2 { + t.Errorf("unexpected password: %v, want %v", got, password2) + w.WriteHeader(http.StatusUnauthorized) + return + } - client := &Client{ - Credential: func(ctx context.Context, reg string) (Credential, error) { - if reg != uri.Host { - err := fmt.Errorf("registry mismatch: got %v, want %v", reg, uri.Host) - t.Error(err) - return EmptyCredential, err - } - return Credential{ - RefreshToken: refreshToken, - }, nil - }, + atomic.AddInt64(&authCount2, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken2); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + })) + defer as2.Close() + ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount2, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken2 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as2.URL, service2, strings.Join(scopes2, " ")) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount2, 1) + })) + defer ts2.Close() + uri2, err := url.Parse(ts2.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service2 = uri2.Host + client2 := &Client{ + Credential: StaticCredential(uri2.Host, Credential{ + Username: username2, + Password: password2, + }), + ForceAttemptOAuth2: true, + Cache: NewCache(), } - // first request - req, err := http.NewRequest(http.MethodGet, ts.URL, nil) + ctx := context.Background() + ctx = WithScopesForHost(ctx, uri1.Host, scopes1...) + ctx = WithScopesForHost(ctx, uri2.Host, scopes2...) + // first request to server 1 + req1, err := http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) if err != nil { t.Fatalf("failed to create test request: %v", err) } - resp, err := client.Do(req) + resp1, err := client1.Do(req1) if err != nil { t.Fatalf("Client.Do() error = %v", err) } - if resp.StatusCode != http.StatusOK { - t.Errorf("Client.Do() = %v, want %v", resp.StatusCode, http.StatusOK) + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) } - if wantRequestCount += 2; requestCount != wantRequestCount { - t.Errorf("unexpected number of requests: %d, want %d", requestCount, wantRequestCount) + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) } - if wantSuccessCount++; successCount != wantSuccessCount { - t.Errorf("unexpected number of successful requests: %d, want %d", successCount, wantSuccessCount) + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) } - if wantAuthCount++; authCount != wantAuthCount { - t.Errorf("unexpected number of auth requests: %d, want %d", authCount, wantAuthCount) + if wantAuthCount1++; authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + // first request to server 2 + req2, err := http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err := client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if wantAuthCount2++; authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) } - // credential change - refreshToken = "test/refresh/token/2" - accessToken = "test/access/token/2" - req, err = http.NewRequest(http.MethodGet, ts.URL, nil) + // repeated request to server 1 + req1, err = http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) if err != nil { t.Fatalf("failed to create test request: %v", err) } - resp, err = client.Do(req) + resp1, err = client1.Do(req1) if err != nil { t.Fatalf("Client.Do() error = %v", err) } - if resp.StatusCode != http.StatusOK { - t.Errorf("Client.Do() = %v, want %v", resp.StatusCode, http.StatusOK) + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) } - if wantRequestCount += 2; requestCount != wantRequestCount { - t.Errorf("unexpected number of requests: %d, want %d", requestCount, wantRequestCount) + if wantRequestCount1++; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) } - if wantSuccessCount++; successCount != wantSuccessCount { - t.Errorf("unexpected number of successful requests: %d, want %d", successCount, wantSuccessCount) + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) } - if wantAuthCount++; authCount != wantAuthCount { - t.Errorf("unexpected number of auth requests: %d, want %d", authCount, wantAuthCount) + if authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + // repeated request to server 2 + req2, err = http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err = client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2++; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } + + // credential change for server 1 + username1 = "test_user1_new" + password1 = "test_password1_new" + accessToken1 = "test/access/token/1/new" + req1, err = http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + client1.Credential = StaticCredential(uri1.Host, Credential{ + Username: username1, + Password: password1, + }) + resp1, err = client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if wantAuthCount1++; authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + // credential change for server 2 + username2 = "test_user2_new" + password2 = "test_password2_new" + accessToken2 = "test/access/token/2/new" + req2, err = http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + client2.Credential = StaticCredential(uri2.Host, Credential{ + Username: username2, + Password: password2, + }) + resp2, err = client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if wantAuthCount2++; authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) } } -func TestClient_Do_Bearer_OAuth2_RefreshToken_Cached(t *testing.T) { +func TestClient_Do_Bearer_OAuth2_RefreshToken(t *testing.T) { refreshToken := "test/refresh/token" accessToken := "test/access/token" var requestCount, wantRequestCount int64 @@ -1270,12 +1951,10 @@ func TestClient_Do_Bearer_OAuth2_RefreshToken_Cached(t *testing.T) { RefreshToken: refreshToken, }, nil }, - Cache: NewCache(), } // first request - ctx := WithScopes(context.Background(), scopes...) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, ts.URL, nil) + req, err := http.NewRequest(http.MethodGet, ts.URL, nil) if err != nil { t.Fatalf("failed to create test request: %v", err) } @@ -1296,32 +1975,10 @@ func TestClient_Do_Bearer_OAuth2_RefreshToken_Cached(t *testing.T) { t.Errorf("unexpected number of auth requests: %d, want %d", authCount, wantAuthCount) } - // repeated request - req, err = http.NewRequestWithContext(ctx, http.MethodGet, ts.URL, nil) - if err != nil { - t.Fatalf("failed to create test request: %v", err) - } - resp, err = client.Do(req) - if err != nil { - t.Fatalf("Client.Do() error = %v", err) - } - if resp.StatusCode != http.StatusOK { - t.Errorf("Client.Do() = %v, want %v", resp.StatusCode, http.StatusOK) - } - if wantRequestCount++; requestCount != wantRequestCount { - t.Errorf("unexpected number of requests: %d, want %d", requestCount, wantRequestCount) - } - if wantSuccessCount++; successCount != wantSuccessCount { - t.Errorf("unexpected number of successful requests: %d, want %d", successCount, wantSuccessCount) - } - if authCount != wantAuthCount { - t.Errorf("unexpected number of auth requests: %d, want %d", authCount, wantAuthCount) - } - // credential change refreshToken = "test/refresh/token/2" accessToken = "test/access/token/2" - req, err = http.NewRequestWithContext(ctx, http.MethodGet, ts.URL, nil) + req, err = http.NewRequest(http.MethodGet, ts.URL, nil) if err != nil { t.Fatalf("failed to create test request: %v", err) } @@ -1343,7 +2000,7 @@ func TestClient_Do_Bearer_OAuth2_RefreshToken_Cached(t *testing.T) { } } -func TestClient_Do_Token_Expire(t *testing.T) { +func TestClient_Do_Bearer_OAuth2_RefreshToken_Cached(t *testing.T) { refreshToken := "test/refresh/token" accessToken := "test/access/token" var requestCount, wantRequestCount int64 @@ -1458,7 +2115,30 @@ func TestClient_Do_Token_Expire(t *testing.T) { t.Errorf("unexpected number of auth requests: %d, want %d", authCount, wantAuthCount) } - // invalidate the access token and request again + // repeated request + req, err = http.NewRequestWithContext(ctx, http.MethodGet, ts.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp, err = client.Do(req) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp.StatusCode, http.StatusOK) + } + if wantRequestCount++; requestCount != wantRequestCount { + t.Errorf("unexpected number of requests: %d, want %d", requestCount, wantRequestCount) + } + if wantSuccessCount++; successCount != wantSuccessCount { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount, wantSuccessCount) + } + if authCount != wantAuthCount { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount, wantAuthCount) + } + + // credential change + refreshToken = "test/refresh/token/2" accessToken = "test/access/token/2" req, err = http.NewRequestWithContext(ctx, http.MethodGet, ts.URL, nil) if err != nil { @@ -1482,20 +2162,18 @@ func TestClient_Do_Token_Expire(t *testing.T) { } } -func TestClient_Do_Scope_Hint_Mismatch(t *testing.T) { - username := "test_user" - password := "test_password" - accessToken := "test/access/token" - var requestCount, wantRequestCount int64 - var successCount, wantSuccessCount int64 - var authCount, wantAuthCount int64 - var service string - scopes := []string{ - "repository:dst:pull,push", +func TestClient_Do_Bearer_OAuth2_RefreshToken_Cached_PerHost(t *testing.T) { + // set up server 1 + refreshToken1 := "test/refresh/token/1" + accessToken1 := "test/access/token/1" + var requestCount1, wantRequestCount1 int64 + var successCount1, wantSuccessCount1 int64 + var authCount1, wantAuthCount1 int64 + var service1 string + scopes1 := []string{ "repository:src:pull", } - scope := "repository:test:delete" - as := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + as1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost || r.URL.Path != "/" { t.Error("unexecuted attempt of authorization service") w.WriteHeader(http.StatusUnauthorized) @@ -1506,13 +2184,13 @@ func TestClient_Do_Scope_Hint_Mismatch(t *testing.T) { w.WriteHeader(http.StatusUnauthorized) return } - if got := r.PostForm.Get("grant_type"); got != "password" { - t.Errorf("unexpected grant type: %v, want %v", got, "password") + if got := r.PostForm.Get("grant_type"); got != "refresh_token" { + t.Errorf("unexpected grant type: %v, want %v", got, "refresh_token") w.WriteHeader(http.StatusUnauthorized) return } - if got := r.PostForm.Get("service"); got != service { - t.Errorf("unexpected service: %v, want %v", got, service) + if got := r.PostForm.Get("service"); got != service1 { + t.Errorf("unexpected service: %v, want %v", got, service1) w.WriteHeader(http.StatusUnauthorized) return } @@ -1521,54 +2199,765 @@ func TestClient_Do_Scope_Hint_Mismatch(t *testing.T) { w.WriteHeader(http.StatusUnauthorized) return } - scopes := CleanScopes(append([]string{scope}, scopes...)) - scope := strings.Join(scopes, " ") + scope := strings.Join(scopes1, " ") if got := r.PostForm.Get("scope"); got != scope { t.Errorf("unexpected scope: %v, want %v", got, scope) w.WriteHeader(http.StatusUnauthorized) return } - if got := r.PostForm.Get("username"); got != username { - t.Errorf("unexpected username: %v, want %v", got, username) - w.WriteHeader(http.StatusUnauthorized) - return - } - if got := r.PostForm.Get("password"); got != password { - t.Errorf("unexpected password: %v, want %v", got, password) + if got := r.PostForm.Get("refresh_token"); got != refreshToken1 { + t.Errorf("unexpected refresh token: %v, want %v", got, refreshToken1) w.WriteHeader(http.StatusUnauthorized) return } - atomic.AddInt64(&authCount, 1) - if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken); err != nil { + atomic.AddInt64(&authCount1, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken1); err != nil { t.Errorf("failed to write %q: %v", r.URL, err) } })) - defer as.Close() - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - atomic.AddInt64(&requestCount, 1) + defer as1.Close() + ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount1, 1) if r.Method != http.MethodGet || r.URL.Path != "/" { t.Errorf("unexpected access: %s %s", r.Method, r.URL) w.WriteHeader(http.StatusNotFound) return } - header := "Bearer " + accessToken + header := "Bearer " + accessToken1 if auth := r.Header.Get("Authorization"); auth != header { - challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as.URL, service, scope) + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as1.URL, service1, strings.Join(scopes1, " ")) w.Header().Set("Www-Authenticate", challenge) w.WriteHeader(http.StatusUnauthorized) return } - atomic.AddInt64(&successCount, 1) + atomic.AddInt64(&successCount1, 1) })) - defer ts.Close() - uri, err := url.Parse(ts.URL) + defer ts1.Close() + uri1, err := url.Parse(ts1.URL) if err != nil { t.Fatalf("invalid test http server: %v", err) } - service = uri.Host - - client := &Client{ + service1 = uri1.Host + client1 := &Client{ + Credential: StaticCredential(uri1.Host, Credential{ + RefreshToken: refreshToken1, + }), + Cache: NewCache(), + } + + // set up server 2 + refreshToken2 := "test/refresh/token/1" + accessToken2 := "test/access/token/1" + var requestCount2, wantRequestCount2 int64 + var successCount2, wantSuccessCount2 int64 + var authCount2, wantAuthCount2 int64 + var service2 string + scopes2 := []string{ + "repository:dst:pull,push", + } + as2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/" { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + return + } + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("grant_type"); got != "refresh_token" { + t.Errorf("unexpected grant type: %v, want %v", got, "refresh_token") + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("service"); got != service2 { + t.Errorf("unexpected service: %v, want %v", got, service2) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("client_id"); got != defaultClientID { + t.Errorf("unexpected client id: %v, want %v", got, defaultClientID) + w.WriteHeader(http.StatusUnauthorized) + return + } + scope := strings.Join(scopes2, " ") + if got := r.PostForm.Get("scope"); got != scope { + t.Errorf("unexpected scope: %v, want %v", got, scope) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("refresh_token"); got != refreshToken2 { + t.Errorf("unexpected refresh token: %v, want %v", got, refreshToken2) + w.WriteHeader(http.StatusUnauthorized) + return + } + + atomic.AddInt64(&authCount2, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken2); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + })) + defer as2.Close() + ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount2, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken2 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as2.URL, service2, strings.Join(scopes2, " ")) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount2, 1) + })) + defer ts2.Close() + uri2, err := url.Parse(ts2.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service2 = uri2.Host + client2 := &Client{ + Credential: StaticCredential(uri2.Host, Credential{ + RefreshToken: refreshToken2, + }), + Cache: NewCache(), + } + + ctx := context.Background() + ctx = WithScopesForHost(ctx, uri1.Host, scopes1...) + ctx = WithScopesForHost(ctx, uri2.Host, scopes2...) + // first request to server 1 + req1, err := http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err := client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if wantAuthCount1++; authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + + // first request to server 2 + req2, err := http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err := client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if wantAuthCount2++; authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } + + // repeated request to server 1 + req1, err = http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err = client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1++; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + // repeated request to server 2 + req2, err = http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err = client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2++; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } + + // credential change to server 1 + refreshToken1 = "test/refresh/token/1/new" + accessToken1 = "test/access/token/1/new" + req1, err = http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + client1.Credential = StaticCredential(uri1.Host, Credential{ + RefreshToken: refreshToken1, + }) + resp1, err = client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if wantAuthCount1++; authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + // credential change to server 2 + refreshToken2 = "test/refresh/token/2/new" + accessToken2 = "test/access/token/2/new" + req2, err = http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + client2.Credential = StaticCredential(uri2.Host, Credential{ + RefreshToken: refreshToken2, + }) + resp2, err = client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if wantAuthCount2++; authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } +} + +func TestClient_Do_Token_Expire(t *testing.T) { + refreshToken := "test/refresh/token" + accessToken := "test/access/token" + var requestCount, wantRequestCount int64 + var successCount, wantSuccessCount int64 + var authCount, wantAuthCount int64 + var service string + scopes := []string{ + "repository:dst:pull,push", + "repository:src:pull", + } + as := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/" { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + return + } + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("grant_type"); got != "refresh_token" { + t.Errorf("unexpected grant type: %v, want %v", got, "refresh_token") + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("service"); got != service { + t.Errorf("unexpected service: %v, want %v", got, service) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("client_id"); got != defaultClientID { + t.Errorf("unexpected client id: %v, want %v", got, defaultClientID) + w.WriteHeader(http.StatusUnauthorized) + return + } + scope := strings.Join(scopes, " ") + if got := r.PostForm.Get("scope"); got != scope { + t.Errorf("unexpected scope: %v, want %v", got, scope) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("refresh_token"); got != refreshToken { + t.Errorf("unexpected refresh token: %v, want %v", got, refreshToken) + w.WriteHeader(http.StatusUnauthorized) + return + } + + atomic.AddInt64(&authCount, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + })) + defer as.Close() + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as.URL, service, strings.Join(scopes, " ")) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount, 1) + })) + defer ts.Close() + uri, err := url.Parse(ts.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service = uri.Host + + client := &Client{ + Credential: func(ctx context.Context, reg string) (Credential, error) { + if reg != uri.Host { + err := fmt.Errorf("registry mismatch: got %v, want %v", reg, uri.Host) + t.Error(err) + return EmptyCredential, err + } + return Credential{ + RefreshToken: refreshToken, + }, nil + }, + Cache: NewCache(), + } + + // first request + ctx := WithScopes(context.Background(), scopes...) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, ts.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp.StatusCode, http.StatusOK) + } + if wantRequestCount += 2; requestCount != wantRequestCount { + t.Errorf("unexpected number of requests: %d, want %d", requestCount, wantRequestCount) + } + if wantSuccessCount++; successCount != wantSuccessCount { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount, wantSuccessCount) + } + if wantAuthCount++; authCount != wantAuthCount { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount, wantAuthCount) + } + + // invalidate the access token and request again + accessToken = "test/access/token/2" + req, err = http.NewRequestWithContext(ctx, http.MethodGet, ts.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp, err = client.Do(req) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp.StatusCode, http.StatusOK) + } + if wantRequestCount += 2; requestCount != wantRequestCount { + t.Errorf("unexpected number of requests: %d, want %d", requestCount, wantRequestCount) + } + if wantSuccessCount++; successCount != wantSuccessCount { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount, wantSuccessCount) + } + if wantAuthCount++; authCount != wantAuthCount { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount, wantAuthCount) + } +} + +func TestClient_Do_Token_Expire_PerHost(t *testing.T) { + // set up server 1 + refreshToken1 := "test/refresh/token/1" + accessToken1 := "test/access/token/1" + var requestCount1, wantRequestCount1 int64 + var successCount1, wantSuccessCount1 int64 + var authCount1, wantAuthCount1 int64 + var service1 string + scopes1 := []string{ + "repository:src:pull", + } + as1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/" { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + return + } + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("grant_type"); got != "refresh_token" { + t.Errorf("unexpected grant type: %v, want %v", got, "refresh_token") + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("service"); got != service1 { + t.Errorf("unexpected service: %v, want %v", got, service1) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("client_id"); got != defaultClientID { + t.Errorf("unexpected client id: %v, want %v", got, defaultClientID) + w.WriteHeader(http.StatusUnauthorized) + return + } + scope := strings.Join(scopes1, " ") + if got := r.PostForm.Get("scope"); got != scope { + t.Errorf("unexpected scope: %v, want %v", got, scope) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("refresh_token"); got != refreshToken1 { + t.Errorf("unexpected refresh token: %v, want %v", got, refreshToken1) + w.WriteHeader(http.StatusUnauthorized) + return + } + + atomic.AddInt64(&authCount1, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken1); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + })) + defer as1.Close() + ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount1, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken1 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as1.URL, service1, strings.Join(scopes1, " ")) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount1, 1) + })) + defer ts1.Close() + uri1, err := url.Parse(ts1.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service1 = uri1.Host + client1 := &Client{ + Credential: StaticCredential(uri1.Host, Credential{ + RefreshToken: refreshToken1, + }), + Cache: NewCache(), + } + // set up server 2 + refreshToken2 := "test/refresh/token/2" + accessToken2 := "test/access/token/2" + var requestCount2, wantRequestCount2 int64 + var successCount2, wantSuccessCount2 int64 + var authCount2, wantAuthCount2 int64 + var service2 string + scopes2 := []string{ + "repository:dst:pull,push", + } + as2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/" { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + return + } + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("grant_type"); got != "refresh_token" { + t.Errorf("unexpected grant type: %v, want %v", got, "refresh_token") + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("service"); got != service2 { + t.Errorf("unexpected service: %v, want %v", got, service2) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("client_id"); got != defaultClientID { + t.Errorf("unexpected client id: %v, want %v", got, defaultClientID) + w.WriteHeader(http.StatusUnauthorized) + return + } + scope := strings.Join(scopes2, " ") + if got := r.PostForm.Get("scope"); got != scope { + t.Errorf("unexpected scope: %v, want %v", got, scope) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("refresh_token"); got != refreshToken2 { + t.Errorf("unexpected refresh token: %v, want %v", got, refreshToken2) + w.WriteHeader(http.StatusUnauthorized) + return + } + + atomic.AddInt64(&authCount2, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken2); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + })) + defer as2.Close() + ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount2, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken2 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as2.URL, service2, strings.Join(scopes2, " ")) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount2, 1) + })) + defer ts2.Close() + uri2, err := url.Parse(ts2.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service2 = uri2.Host + client2 := &Client{ + Credential: StaticCredential(uri2.Host, Credential{ + RefreshToken: refreshToken2, + }), + Cache: NewCache(), + } + + ctx := context.Background() + ctx = WithScopesForHost(ctx, uri1.Host, scopes1...) + ctx = WithScopesForHost(ctx, uri2.Host, scopes2...) + // first request to server 1 + req1, err := http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err := client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if wantAuthCount1++; authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + + // first request to server 2 + req2, err := http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err := client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if wantAuthCount2++; authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } + + // invalidate the access token and request again to server 1 + accessToken1 = "test/access/token/1/new" + req1, err = http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err = client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if wantAuthCount1++; authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + // invalidate the access token and request again to server 2 + accessToken2 = "test/access/token/2/new" + req2, err = http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err = client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if wantAuthCount2++; authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } +} + +func TestClient_Do_Scope_Hint_Mismatch(t *testing.T) { + username := "test_user" + password := "test_password" + accessToken := "test/access/token" + var requestCount, wantRequestCount int64 + var successCount, wantSuccessCount int64 + var authCount, wantAuthCount int64 + var service string + scopes := []string{ + "repository:dst:pull,push", + "repository:src:pull", + } + scope := "repository:test:delete" + as := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/" { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + return + } + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("grant_type"); got != "password" { + t.Errorf("unexpected grant type: %v, want %v", got, "password") + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("service"); got != service { + t.Errorf("unexpected service: %v, want %v", got, service) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("client_id"); got != defaultClientID { + t.Errorf("unexpected client id: %v, want %v", got, defaultClientID) + w.WriteHeader(http.StatusUnauthorized) + return + } + scopes := CleanScopes(append([]string{scope}, scopes...)) + scope := strings.Join(scopes, " ") + if got := r.PostForm.Get("scope"); got != scope { + t.Errorf("unexpected scope: %v, want %v", got, scope) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("username"); got != username { + t.Errorf("unexpected username: %v, want %v", got, username) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("password"); got != password { + t.Errorf("unexpected password: %v, want %v", got, password) + w.WriteHeader(http.StatusUnauthorized) + return + } + + atomic.AddInt64(&authCount, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + })) + defer as.Close() + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as.URL, service, scope) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount, 1) + })) + defer ts.Close() + uri, err := url.Parse(ts.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service = uri.Host + + client := &Client{ Credential: func(ctx context.Context, reg string) (Credential, error) { if reg != uri.Host { err := fmt.Errorf("registry mismatch: got %v, want %v", reg, uri.Host) @@ -1633,6 +3022,293 @@ func TestClient_Do_Scope_Hint_Mismatch(t *testing.T) { } } +func TestClient_Do_Scope_Hint_Mismatch_PerHost(t *testing.T) { + // set up server 1 + username1 := "test_user1" + password1 := "test_password1" + accessToken1 := "test/access/token/1" + var requestCount1, wantRequestCount1 int64 + var successCount1, wantSuccessCount1 int64 + var authCount1, wantAuthCount1 int64 + var service1 string + scopes1 := []string{ + "repository:dst:pull,push", + "repository:src:pull", + } + scope1 := "repository:test1:delete" + as1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/" { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + return + } + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("grant_type"); got != "password" { + t.Errorf("unexpected grant type: %v, want %v", got, "password") + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("service"); got != service1 { + t.Errorf("unexpected service: %v, want %v", got, service1) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("client_id"); got != defaultClientID { + t.Errorf("unexpected client id: %v, want %v", got, defaultClientID) + w.WriteHeader(http.StatusUnauthorized) + return + } + scopes := CleanScopes(append([]string{scope1}, scopes1...)) + scope := strings.Join(scopes, " ") + if got := r.PostForm.Get("scope"); got != scope { + t.Errorf("unexpected scope: %v, want %v", got, scope) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("username"); got != username1 { + t.Errorf("unexpected username: %v, want %v", got, username1) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("password"); got != password1 { + t.Errorf("unexpected password: %v, want %v", got, password1) + w.WriteHeader(http.StatusUnauthorized) + return + } + + atomic.AddInt64(&authCount1, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken1); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + })) + defer as1.Close() + ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount1, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken1 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as1.URL, service1, scope1) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount1, 1) + })) + defer ts1.Close() + uri1, err := url.Parse(ts1.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service1 = uri1.Host + client1 := &Client{ + Credential: StaticCredential(uri1.Host, Credential{ + Username: username1, + Password: password1, + }), + ForceAttemptOAuth2: true, + Cache: NewCache(), + } + + // set up server 1 + username2 := "test_user2" + password2 := "test_password2" + accessToken2 := "test/access/token/2" + var requestCount2, wantRequestCount2 int64 + var successCount2, wantSuccessCount2 int64 + var authCount2, wantAuthCount2 int64 + var service2 string + scopes2 := []string{ + "repository:dst:pull,push", + "repository:src:pull", + } + scope2 := "repository:test2:delete" + as2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/" { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + return + } + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("grant_type"); got != "password" { + t.Errorf("unexpected grant type: %v, want %v", got, "password") + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("service"); got != service2 { + t.Errorf("unexpected service: %v, want %v", got, service2) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("client_id"); got != defaultClientID { + t.Errorf("unexpected client id: %v, want %v", got, defaultClientID) + w.WriteHeader(http.StatusUnauthorized) + return + } + scopes := CleanScopes(append([]string{scope2}, scopes2...)) + scope := strings.Join(scopes, " ") + if got := r.PostForm.Get("scope"); got != scope { + t.Errorf("unexpected scope: %v, want %v", got, scope) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("username"); got != username2 { + t.Errorf("unexpected username: %v, want %v", got, username2) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("password"); got != password2 { + t.Errorf("unexpected password: %v, want %v", got, password2) + w.WriteHeader(http.StatusUnauthorized) + return + } + + atomic.AddInt64(&authCount2, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken2); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + })) + defer as2.Close() + ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount2, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken2 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as2.URL, service2, scope2) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount2, 1) + })) + defer ts1.Close() + uri2, err := url.Parse(ts2.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service2 = uri2.Host + client2 := &Client{ + Credential: StaticCredential(uri2.Host, Credential{ + Username: username2, + Password: password2, + }), + ForceAttemptOAuth2: true, + Cache: NewCache(), + } + + ctx := context.Background() + ctx = WithScopesForHost(ctx, uri1.Host, scopes1...) + ctx = WithScopesForHost(ctx, uri2.Host, scopes2...) + // first request to server 1 + req1, err := http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err := client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if wantAuthCount1++; authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + + // first request to server 1 + req2, err := http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err := client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if wantAuthCount2++; authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } + + // repeated request to server 1 + // although the actual scope does not match the hinted scopes, the client + // with cache cannot avoid a request to obtain a challenge but can prevent + // a repeated call to the authorization server. + req1, err = http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err = client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + + // repeated request to server 2 + // although the actual scope does not match the hinted scopes, the client + // with cache cannot avoid a request to obtain a challenge but can prevent + // a repeated call to the authorization server. + req2, err = http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err = client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } +} + func TestClient_Do_Invalid_Credential_Basic(t *testing.T) { username := "test_user" password := "test_password" diff --git a/registry/remote/auth/scope.go b/registry/remote/auth/scope.go index 24a0f898..fabc2af2 100644 --- a/registry/remote/auth/scope.go +++ b/registry/remote/auth/scope.go @@ -19,6 +19,9 @@ import ( "context" "sort" "strings" + + "oras.land/oras-go/v2/internal/slices" + "oras.land/oras-go/v2/registry" ) // Actions used in scopes. @@ -54,6 +57,28 @@ func ScopeRepository(repository string, actions ...string) string { }, ":") } +// AppendRepositoryScope returns a new context containing scope hints for the +// auth client to fetch bearer tokens with the given actions on the repository. +// If called multiple times, the new scopes will be appended to the existing +// scopes. The resulted scopes are de-duplicated. +// +// For example, uploading blob to the repository "hello-world" does HEAD request +// first then POST and PUT. The HEAD request will return a challenge for scope +// `repository:hello-world:pull`, and the auth client will fetch a token for +// that challenge. Later, the POST request will return a challenge for scope +// `repository:hello-world:push`, and the auth client will fetch a token for +// that challenge again. By invoking AppendRepositoryScope with the actions +// [ActionPull] and [ActionPush] for the repository `hello-world`, +// the auth client with cache is hinted to fetch a token via a single token +// fetch request for all the HEAD, POST, PUT requests. +func AppendRepositoryScope(ctx context.Context, ref registry.Reference, actions ...string) context.Context { + if len(actions) == 0 { + return ctx + } + scope := ScopeRepository(ref.Repository, actions...) + return AppendScopesForHost(ctx, ref.Host(), scope) +} + // scopesContextKey is the context key for scopes. type scopesContextKey struct{} @@ -66,7 +91,7 @@ type scopesContextKey struct{} // `repository:hello-world:pull`, and the auth client will fetch a token for // that challenge. Later, the POST request will return a challenge for scope // `repository:hello-world:push`, and the auth client will fetch a token for -// that challenge again. By invoking `WithScopes()` with the scope +// that challenge again. By invoking WithScopes with the scope // `repository:hello-world:pull,push`, the auth client with cache is hinted to // fetch a token via a single token fetch request for all the HEAD, POST, PUT // requests. @@ -93,11 +118,76 @@ func AppendScopes(ctx context.Context, scopes ...string) context.Context { // GetScopes returns the scopes in the context. func GetScopes(ctx context.Context) []string { if scopes, ok := ctx.Value(scopesContextKey{}).([]string); ok { - return append([]string(nil), scopes...) + return slices.Clone(scopes) + } + return nil +} + +// scopesForHostContextKey is the context key for per-host scopes. +type scopesForHostContextKey string + +// WithScopesForHost returns a context with per-host scopes added. +// Scopes are de-duplicated. +// Scopes are used as hints for the auth client to fetch bearer tokens with +// larger scopes. +// +// For example, uploading blob to the repository "hello-world" does HEAD request +// first then POST and PUT. The HEAD request will return a challenge for scope +// `repository:hello-world:pull`, and the auth client will fetch a token for +// that challenge. Later, the POST request will return a challenge for scope +// `repository:hello-world:push`, and the auth client will fetch a token for +// that challenge again. By invoking WithScopesForHost with the scope +// `repository:hello-world:pull,push`, the auth client with cache is hinted to +// fetch a token via a single token fetch request for all the HEAD, POST, PUT +// requests. +// +// Passing an empty list of scopes will virtually remove the scope hints in the +// context for the given host. +// +// Reference: https://docs.docker.com/registry/spec/auth/scope/ +func WithScopesForHost(ctx context.Context, host string, scopes ...string) context.Context { + scopes = CleanScopes(scopes) + return context.WithValue(ctx, scopesForHostContextKey(host), scopes) +} + +// AppendScopesForHost appends additional scopes to the existing scopes +// in the context for the given host and returns a new context. +// The resulted scopes are de-duplicated. +// The append operation does modify the existing scope in the context passed in. +func AppendScopesForHost(ctx context.Context, host string, scopes ...string) context.Context { + if len(scopes) == 0 { + return ctx + } + oldScopes := GetScopesForHost(ctx, host) + return WithScopesForHost(ctx, host, append(oldScopes, scopes...)...) +} + +// GetScopesForHost returns the scopes in the context for the given host, +// excluding global scopes added by [WithScopes] and [AppendScopes]. +func GetScopesForHost(ctx context.Context, host string) []string { + if scopes, ok := ctx.Value(scopesForHostContextKey(host)).([]string); ok { + return slices.Clone(scopes) } return nil } +// GetAllScopesForHost returns the scopes in the context for the given host, +// including global scopes added by [WithScopes] and [AppendScopes]. +func GetAllScopesForHost(ctx context.Context, host string) []string { + scopes := GetScopesForHost(ctx, host) + globalScopes := GetScopes(ctx) + + if len(scopes) == 0 { + return globalScopes + } + if len(globalScopes) == 0 { + return scopes + } + // re-clean the scopes + allScopes := append(scopes, globalScopes...) + return CleanScopes(allScopes) +} + // CleanScopes merges and sort the actions in ascending order if the scopes have // the same resource type and name. The final scopes are sorted in ascending // order. In other words, the scopes passed in are de-duplicated and sorted. diff --git a/registry/remote/auth/scope_test.go b/registry/remote/auth/scope_test.go index ac41ad7b..ca9fe339 100644 --- a/registry/remote/auth/scope_test.go +++ b/registry/remote/auth/scope_test.go @@ -19,6 +19,8 @@ import ( "context" "reflect" "testing" + + "oras.land/oras-go/v2/registry" ) func TestScopeRepository(t *testing.T) { @@ -103,6 +105,70 @@ func TestScopeRepository(t *testing.T) { } } +func TestWithScopeHints(t *testing.T) { + ctx := context.Background() + ref1, err := registry.ParseReference("registry.example.com/foo") + if err != nil { + t.Fatal("registry.ParseReference() error =", err) + } + ref2, err := registry.ParseReference("docker.io/foo") + if err != nil { + t.Fatal("registry.ParseReference() error =", err) + } + + // with single scope + want1 := []string{ + "repository:foo:pull", + } + want2 := []string{ + "repository:foo:push", + } + ctx = AppendRepositoryScope(ctx, ref1, ActionPull) + ctx = AppendRepositoryScope(ctx, ref2, ActionPush) + if got := GetScopesForHost(ctx, ref1.Host()); !reflect.DeepEqual(got, want1) { + t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want1) + } + if got := GetScopesForHost(ctx, ref2.Host()); !reflect.DeepEqual(got, want2) { + t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want2) + } + + // with duplicated scopes + scopes1 := []string{ + ActionDelete, + ActionDelete, + ActionPull, + } + want1 = []string{ + "repository:foo:delete,pull", + } + scopes2 := []string{ + ActionPush, + ActionPush, + ActionDelete, + } + want2 = []string{ + "repository:foo:delete,push", + } + ctx = AppendRepositoryScope(ctx, ref1, scopes1...) + ctx = AppendRepositoryScope(ctx, ref2, scopes2...) + if got := GetScopesForHost(ctx, ref1.Host()); !reflect.DeepEqual(got, want1) { + t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want1) + } + if got := GetScopesForHost(ctx, ref2.Host()); !reflect.DeepEqual(got, want2) { + t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want2) + } + + // append empty scopes + ctx = AppendRepositoryScope(ctx, ref1) + ctx = AppendRepositoryScope(ctx, ref2) + if got := GetScopesForHost(ctx, ref1.Host()); !reflect.DeepEqual(got, want1) { + t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want1) + } + if got := GetScopesForHost(ctx, ref2.Host()); !reflect.DeepEqual(got, want2) { + t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want2) + } +} + func TestWithScopes(t *testing.T) { ctx := context.Background() @@ -184,6 +250,149 @@ func TestAppendScopes(t *testing.T) { } } +func TestWithScopesPerHost(t *testing.T) { + ctx := context.Background() + reg1 := "registry1.example.com" + reg2 := "registry2.example.com" + + // with single scope + want1 := []string{ + "repository:foo:pull", + } + want2 := []string{ + "repository:foo:push", + } + ctx = WithScopesForHost(ctx, reg1, want1...) + ctx = WithScopesForHost(ctx, reg2, want2...) + if got := GetScopesForHost(ctx, reg1); !reflect.DeepEqual(got, want1) { + t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want1) + } + if got := GetScopesForHost(ctx, reg2); !reflect.DeepEqual(got, want2) { + t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want2) + } + + // overwrite scopes + want1 = []string{ + "repository:bar:push", + } + want2 = []string{ + "repository:bar:pull", + } + ctx = WithScopesForHost(ctx, reg1, want1...) + ctx = WithScopesForHost(ctx, reg2, want2...) + if got := GetScopesForHost(ctx, reg1); !reflect.DeepEqual(got, want1) { + t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want1) + } + if got := GetScopesForHost(ctx, reg2); !reflect.DeepEqual(got, want2) { + t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want2) + } + + // overwrite scopes with de-duplication + scopes1 := []string{ + "repository:hello-world:push", + "repository:alpine:delete", + "repository:hello-world:pull", + "repository:alpine:delete", + } + want1 = []string{ + "repository:alpine:delete", + "repository:hello-world:pull,push", + } + scopes2 := []string{ + "repository:goodbye-world:push", + "repository:nginx:delete", + "repository:goodbye-world:pull", + "repository:nginx:delete", + } + want2 = []string{ + "repository:goodbye-world:pull,push", + "repository:nginx:delete", + } + ctx = WithScopesForHost(ctx, reg1, scopes1...) + ctx = WithScopesForHost(ctx, reg2, scopes2...) + if got := GetScopesForHost(ctx, reg1); !reflect.DeepEqual(got, want1) { + t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want1) + } + if got := GetScopesForHost(ctx, reg2); !reflect.DeepEqual(got, want2) { + t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want2) + } + + // clean scopes + var want []string + ctx = WithScopesForHost(ctx, reg1, want...) + ctx = WithScopesForHost(ctx, reg2, want...) + if got := GetScopesForHost(ctx, reg1); !reflect.DeepEqual(got, want) { + t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want) + } + if got := GetScopesForHost(ctx, reg2); !reflect.DeepEqual(got, want) { + t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want) + } +} + +func TestAppendScopesPerHost(t *testing.T) { + ctx := context.Background() + reg1 := "registry1.example.com" + reg2 := "registry2.example.com" + + // with single scope + want1 := []string{ + "repository:foo:pull", + } + want2 := []string{ + "repository:foo:push", + } + ctx = AppendScopesForHost(ctx, reg1, want1...) + ctx = AppendScopesForHost(ctx, reg2, want2...) + if got := GetScopesForHost(ctx, reg1); !reflect.DeepEqual(got, want1) { + t.Errorf("GetScopesPerRegistry(AppendScopesPerRegistry()) = %v, want %v", got, want1) + } + if got := GetScopesForHost(ctx, reg2); !reflect.DeepEqual(got, want2) { + t.Errorf("GetScopesPerRegistry(AppendScopesPerRegistry()) = %v, want %v", got, want2) + } + + // append scopes with de-duplication + scopes1 := []string{ + "repository:hello-world:push", + "repository:alpine:delete", + "repository:hello-world:pull", + "repository:alpine:delete", + } + want1 = []string{ + "repository:alpine:delete", + "repository:foo:pull", + "repository:hello-world:pull,push", + } + scopes2 := []string{ + "repository:goodbye-world:push", + "repository:nginx:delete", + "repository:goodbye-world:pull", + "repository:nginx:delete", + } + want2 = []string{ + "repository:foo:push", + "repository:goodbye-world:pull,push", + "repository:nginx:delete", + } + ctx = AppendScopesForHost(ctx, reg1, scopes1...) + ctx = AppendScopesForHost(ctx, reg2, scopes2...) + if got := GetScopesForHost(ctx, reg1); !reflect.DeepEqual(got, want1) { + t.Errorf("GetScopesPerRegistry(AppendScopesPerRegistry()) = %v, want %v", got, want1) + } + if got := GetScopesForHost(ctx, reg2); !reflect.DeepEqual(got, want2) { + t.Errorf("GetScopesPerRegistry(AppendScopesPerRegistry()) = %v, want %v", got, want2) + } + + // append empty scopes + ctx = AppendScopesForHost(ctx, reg1) + ctx = AppendScopesForHost(ctx, reg2) + if got := GetScopesForHost(ctx, reg1); !reflect.DeepEqual(got, want1) { + t.Errorf("GetScopesPerRegistry(AppendScopesPerRegistry()) = %v, want %v", got, want1) + } + if got := GetScopesForHost(ctx, reg2); !reflect.DeepEqual(got, want2) { + t.Errorf("GetScopesPerRegistry(AppendScopesPerRegistry()) = %v, want %v", got, want2) + } +} + func TestCleanScopes(t *testing.T) { tests := []struct { name string @@ -449,3 +658,71 @@ func Test_cleanActions(t *testing.T) { }) } } + +func Test_getAllScopesForHost(t *testing.T) { + host := "registry.example.com" + tests := []struct { + name string + scopes []string + globalScopes []string + want []string + }{ + { + name: "Empty per-host scopes", + scopes: []string{}, + globalScopes: []string{ + "repository:hello-world:push", + "repository:alpine:delete", + "repository:hello-world:pull", + "repository:alpine:delete", + }, + want: []string{ + "repository:alpine:delete", + "repository:hello-world:pull,push", + }, + }, + { + name: "Empty global scopes", + scopes: []string{ + "repository:hello-world:push", + "repository:alpine:delete", + "repository:hello-world:pull", + "repository:alpine:delete", + }, + globalScopes: []string{}, + want: []string{ + "repository:alpine:delete", + "repository:hello-world:pull,push", + }, + }, + { + name: "Per-host scopes + global scopes", + scopes: []string{ + "repository:hello-world:push", + "repository:alpine:delete", + "repository:hello-world:pull", + "repository:alpine:delete", + }, + globalScopes: []string{ + "repository:foo:pull", + "repository:hello-world:pull", + "repository:alpine:pull", + }, + want: []string{ + "repository:alpine:delete,pull", + "repository:foo:pull", + "repository:hello-world:pull,push", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + ctx = WithScopesForHost(ctx, host, tt.scopes...) + ctx = WithScopes(ctx, tt.globalScopes...) + if got := GetAllScopesForHost(ctx, host); !reflect.DeepEqual(got, tt.want) { + t.Errorf("getAllScopesForHost() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/registry/remote/credentials/example_test.go b/registry/remote/credentials/example_test.go new file mode 100644 index 00000000..be8eece0 --- /dev/null +++ b/registry/remote/credentials/example_test.go @@ -0,0 +1,239 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package credentials_test + +import ( + "context" + "fmt" + "net/http" + + "oras.land/oras-go/v2/registry/remote" + "oras.land/oras-go/v2/registry/remote/auth" + credentials "oras.land/oras-go/v2/registry/remote/credentials" +) + +func ExampleNewNativeStore() { + ns := credentials.NewNativeStore("pass") + + ctx := context.Background() + // save credentials into the store + err := ns.Put(ctx, "localhost:5000", auth.Credential{ + Username: "username-example", + Password: "password-example", + }) + if err != nil { + panic(err) + } + + // get credentials from the store + cred, err := ns.Get(ctx, "localhost:5000") + if err != nil { + panic(err) + } + fmt.Println(cred) + + // delete the credentials from the store + err = ns.Delete(ctx, "localhost:5000") + if err != nil { + panic(err) + } +} + +func ExampleNewFileStore() { + fs, err := credentials.NewFileStore("example/path/config.json") + if err != nil { + panic(err) + } + + ctx := context.Background() + // save credentials into the store + err = fs.Put(ctx, "localhost:5000", auth.Credential{ + Username: "username-example", + Password: "password-example", + }) + if err != nil { + panic(err) + } + + // get credentials from the store + cred, err := fs.Get(ctx, "localhost:5000") + if err != nil { + panic(err) + } + fmt.Println(cred) + + // delete the credentials from the store + err = fs.Delete(ctx, "localhost:5000") + if err != nil { + panic(err) + } +} + +func ExampleNewStore() { + // NewStore returns a Store based on the given configuration file. It will + // automatically determine which Store (file store or native store) to use. + // If the native store is not available, you can save your credentials in + // the configuration file by specifying AllowPlaintextPut: true, but keep + // in mind that this is an unsafe workaround. + // See the documentation for details. + store, err := credentials.NewStore("example/path/config.json", credentials.StoreOptions{ + AllowPlaintextPut: true, + }) + if err != nil { + panic(err) + } + + ctx := context.Background() + // save credentials into the store + err = store.Put(ctx, "localhost:5000", auth.Credential{ + Username: "username-example", + Password: "password-example", + }) + if err != nil { + panic(err) + } + + // get credentials from the store + cred, err := store.Get(ctx, "localhost:5000") + if err != nil { + panic(err) + } + fmt.Println(cred) + + // delete the credentials from the store + err = store.Delete(ctx, "localhost:5000") + if err != nil { + panic(err) + } +} + +func ExampleNewStoreFromDocker() { + ds, err := credentials.NewStoreFromDocker(credentials.StoreOptions{ + AllowPlaintextPut: true, + }) + if err != nil { + panic(err) + } + + ctx := context.Background() + // save credentials into the store + err = ds.Put(ctx, "localhost:5000", auth.Credential{ + Username: "username-example", + Password: "password-example", + }) + if err != nil { + panic(err) + } + + // get credentials from the store + cred, err := ds.Get(ctx, "localhost:5000") + if err != nil { + panic(err) + } + fmt.Println(cred) + + // delete the credentials from the store + err = ds.Delete(ctx, "localhost:5000") + if err != nil { + panic(err) + } +} + +func ExampleNewStoreWithFallbacks_configAsPrimaryStoreDockerAsFallback() { + primaryStore, err := credentials.NewStore("example/path/config.json", credentials.StoreOptions{ + AllowPlaintextPut: true, + }) + if err != nil { + panic(err) + } + fallbackStore, err := credentials.NewStoreFromDocker(credentials.StoreOptions{}) + sf := credentials.NewStoreWithFallbacks(primaryStore, fallbackStore) + + ctx := context.Background() + // save credentials into the store + err = sf.Put(ctx, "localhost:5000", auth.Credential{ + Username: "username-example", + Password: "password-example", + }) + if err != nil { + panic(err) + } + + // get credentials from the store + cred, err := sf.Get(ctx, "localhost:5000") + if err != nil { + panic(err) + } + fmt.Println(cred) + + // delete the credentials from the store + err = sf.Delete(ctx, "localhost:5000") + if err != nil { + panic(err) + } +} + +func ExampleLogin() { + store, err := credentials.NewStore("example/path/config.json", credentials.StoreOptions{ + AllowPlaintextPut: true, + }) + if err != nil { + panic(err) + } + registry, err := remote.NewRegistry("localhost:5000") + if err != nil { + panic(err) + } + cred := auth.Credential{ + Username: "username-example", + Password: "password-example", + } + err = credentials.Login(context.Background(), store, registry, cred) + if err != nil { + panic(err) + } + fmt.Println("Login succeeded") +} + +func ExampleLogout() { + store, err := credentials.NewStore("example/path/config.json", credentials.StoreOptions{}) + if err != nil { + panic(err) + } + err = credentials.Logout(context.Background(), store, "localhost:5000") + if err != nil { + panic(err) + } + fmt.Println("Logout succeeded") +} + +func ExampleCredential() { + store, err := credentials.NewStore("example/path/config.json", credentials.StoreOptions{}) + if err != nil { + panic(err) + } + + client := auth.DefaultClient + client.Credential = credentials.Credential(store) + + request, err := http.NewRequest(http.MethodGet, "localhost:5000", nil) + if err != nil { + panic(err) + } + + _, err = client.Do(request) + if err != nil { + panic(err) + } +} diff --git a/registry/remote/credentials/file_store.go b/registry/remote/credentials/file_store.go new file mode 100644 index 00000000..7664cc2a --- /dev/null +++ b/registry/remote/credentials/file_store.go @@ -0,0 +1,97 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package credentials + +import ( + "context" + "errors" + "fmt" + "strings" + + "oras.land/oras-go/v2/registry/remote/auth" + "oras.land/oras-go/v2/registry/remote/credentials/internal/config" +) + +// FileStore implements a credentials store using the docker configuration file +// to keep the credentials in plain-text. +// +// Reference: https://docs.docker.com/engine/reference/commandline/cli/#docker-cli-configuration-file-configjson-properties +type FileStore struct { + // DisablePut disables putting credentials in plaintext. + // If DisablePut is set to true, Put() will return ErrPlaintextPutDisabled. + DisablePut bool + + config *config.Config +} + +var ( + // ErrPlaintextPutDisabled is returned by Put() when DisablePut is set + // to true. + ErrPlaintextPutDisabled = errors.New("putting plaintext credentials is disabled") + // ErrBadCredentialFormat is returned by Put() when the credential format + // is bad. + ErrBadCredentialFormat = errors.New("bad credential format") +) + +// NewFileStore creates a new file credentials store. +// +// Reference: https://docs.docker.com/engine/reference/commandline/cli/#docker-cli-configuration-file-configjson-properties +func NewFileStore(configPath string) (*FileStore, error) { + cfg, err := config.Load(configPath) + if err != nil { + return nil, err + } + return newFileStore(cfg), nil +} + +// newFileStore creates a file credentials store based on the given config instance. +func newFileStore(cfg *config.Config) *FileStore { + return &FileStore{config: cfg} +} + +// Get retrieves credentials from the store for the given server address. +func (fs *FileStore) Get(_ context.Context, serverAddress string) (auth.Credential, error) { + return fs.config.GetCredential(serverAddress) +} + +// Put saves credentials into the store for the given server address. +// Returns ErrPlaintextPutDisabled if fs.DisablePut is set to true. +func (fs *FileStore) Put(_ context.Context, serverAddress string, cred auth.Credential) error { + if fs.DisablePut { + return ErrPlaintextPutDisabled + } + if err := validateCredentialFormat(cred); err != nil { + return err + } + + return fs.config.PutCredential(serverAddress, cred) +} + +// Delete removes credentials from the store for the given server address. +func (fs *FileStore) Delete(_ context.Context, serverAddress string) error { + return fs.config.DeleteCredential(serverAddress) +} + +// validateCredentialFormat validates the format of cred. +func validateCredentialFormat(cred auth.Credential) error { + if strings.ContainsRune(cred.Username, ':') { + // Username and password will be encoded in the base64(username:password) + // format in the file. The decoded result will be wrong if username + // contains colon(s). + return fmt.Errorf("%w: colons(:) are not allowed in username", ErrBadCredentialFormat) + } + return nil +} diff --git a/registry/remote/credentials/file_store_test.go b/registry/remote/credentials/file_store_test.go new file mode 100644 index 00000000..dccb7d05 --- /dev/null +++ b/registry/remote/credentials/file_store_test.go @@ -0,0 +1,910 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package credentials + +import ( + "context" + "encoding/json" + "errors" + "os" + "path/filepath" + "reflect" + "testing" + + "oras.land/oras-go/v2/registry/remote/auth" + "oras.land/oras-go/v2/registry/remote/credentials/internal/config/configtest" +) + +func TestNewFileStore_badPath(t *testing.T) { + tempDir := t.TempDir() + + tests := []struct { + name string + configPath string + wantErr bool + }{ + { + name: "Path is a directory", + configPath: tempDir, + wantErr: true, + }, + { + name: "Empty file name", + configPath: filepath.Join(tempDir, ""), + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewFileStore(tt.configPath) + if (err != nil) != tt.wantErr { + t.Errorf("NewFileStore() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + +func TestNewFileStore_badFormat(t *testing.T) { + tests := []struct { + name string + configPath string + wantErr bool + }{ + { + name: "Bad JSON format", + configPath: "testdata/bad_config", + wantErr: true, + }, + { + name: "Invalid auths format", + configPath: "testdata/invalid_auths_config.json", + wantErr: true, + }, + { + name: "No auths field", + configPath: "testdata/no_auths_config.json", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewFileStore(tt.configPath) + if (err != nil) != tt.wantErr { + t.Errorf("NewFileStore() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + +func TestFileStore_Get_validConfig(t *testing.T) { + ctx := context.Background() + fs, err := NewFileStore("testdata/valid_auths_config.json") + if err != nil { + t.Fatal("NewFileStore() error =", err) + } + + tests := []struct { + name string + serverAddress string + want auth.Credential + wantErr bool + }{ + { + name: "Username and password", + serverAddress: "registry1.example.com", + want: auth.Credential{ + Username: "username", + Password: "password", + }, + }, + { + name: "Identity token", + serverAddress: "registry2.example.com", + want: auth.Credential{ + RefreshToken: "identity_token", + }, + }, + { + name: "Registry token", + serverAddress: "registry3.example.com", + want: auth.Credential{ + AccessToken: "registry_token", + }, + }, + { + name: "Username and password, identity token and registry token", + serverAddress: "registry4.example.com", + want: auth.Credential{ + Username: "username", + Password: "password", + RefreshToken: "identity_token", + AccessToken: "registry_token", + }, + }, + { + name: "Empty credential", + serverAddress: "registry5.example.com", + want: auth.EmptyCredential, + }, + { + name: "Username and password, no auth", + serverAddress: "registry6.example.com", + want: auth.Credential{ + Username: "username", + Password: "password", + }, + }, + { + name: "Auth overriding Username and password", + serverAddress: "registry7.example.com", + want: auth.Credential{ + Username: "username", + Password: "password", + }, + }, + { + name: "Not in auths", + serverAddress: "foo.example.com", + want: auth.EmptyCredential, + }, + { + name: "No record", + serverAddress: "registry999.example.com", + want: auth.EmptyCredential, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := fs.Get(ctx, tt.serverAddress) + if (err != nil) != tt.wantErr { + t.Errorf("FileStore.Get() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("FileStore.Get() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestFileStore_Get_invalidConfig(t *testing.T) { + ctx := context.Background() + fs, err := NewFileStore("testdata/invalid_auths_entry_config.json") + if err != nil { + t.Fatal("NewFileStore() error =", err) + } + + tests := []struct { + name string + serverAddress string + want auth.Credential + wantErr bool + }{ + { + name: "Invalid auth encode", + serverAddress: "registry1.example.com", + want: auth.EmptyCredential, + wantErr: true, + }, + { + name: "Invalid auths format", + serverAddress: "registry2.example.com", + want: auth.EmptyCredential, + wantErr: true, + }, + { + name: "Invalid type", + serverAddress: "registry3.example.com", + want: auth.EmptyCredential, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := fs.Get(ctx, tt.serverAddress) + if (err != nil) != tt.wantErr { + t.Errorf("FileStore.Get() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("FileStore.Get() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestFileStore_Get_emptyConfig(t *testing.T) { + ctx := context.Background() + fs, err := NewFileStore("testdata/empty_config.json") + if err != nil { + t.Fatal("NewFileStore() error =", err) + } + + tests := []struct { + name string + serverAddress string + want auth.Credential + wantErr error + }{ + { + name: "Not found", + serverAddress: "registry.example.com", + want: auth.EmptyCredential, + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := fs.Get(ctx, tt.serverAddress) + if !errors.Is(err, tt.wantErr) { + t.Errorf("FileStore.Get() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("FileStore.Get() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestFileStore_Get_notExistConfig(t *testing.T) { + ctx := context.Background() + fs, err := NewFileStore("whatever") + if err != nil { + t.Fatal("NewFileStore() error =", err) + } + + tests := []struct { + name string + serverAddress string + want auth.Credential + wantErr error + }{ + { + name: "Not found", + serverAddress: "registry.example.com", + want: auth.EmptyCredential, + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := fs.Get(ctx, tt.serverAddress) + if !errors.Is(err, tt.wantErr) { + t.Errorf("FileStore.Get() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("FileStore.Get() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestFileStore_Put_notExistConfig(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + ctx := context.Background() + + fs, err := NewFileStore(configPath) + if err != nil { + t.Fatal("NewFileStore() error =", err) + } + + server := "test.example.com" + cred := auth.Credential{ + Username: "username", + Password: "password", + RefreshToken: "refresh_token", + AccessToken: "access_token", + } + + // test put + if err := fs.Put(ctx, server, cred); err != nil { + t.Fatalf("FileStore.Put() error = %v", err) + } + + // verify config file + configFile, err := os.Open(configPath) + if err != nil { + t.Fatalf("failed to open config file: %v", err) + } + defer configFile.Close() + + var cfg configtest.Config + if err := json.NewDecoder(configFile).Decode(&cfg); err != nil { + t.Fatalf("failed to decode config file: %v", err) + } + want := configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{ + server: { + Auth: "dXNlcm5hbWU6cGFzc3dvcmQ=", + IdentityToken: "refresh_token", + RegistryToken: "access_token", + }, + }, + } + if !reflect.DeepEqual(cfg, want) { + t.Errorf("Decoded config = %v, want %v", cfg, want) + } + + // verify get + got, err := fs.Get(ctx, server) + if err != nil { + t.Fatalf("FileStore.Get() error = %v", err) + } + if want := cred; !reflect.DeepEqual(got, want) { + t.Errorf("FileStore.Get() = %v, want %v", got, want) + } +} + +func TestFileStore_Put_addNew(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + ctx := context.Background() + + // prepare test content + server1 := "registry1.example.com" + cred1 := auth.Credential{ + Username: "username", + Password: "password", + RefreshToken: "refresh_token", + AccessToken: "access_token", + } + + cfg := configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{ + server1: { + SomeAuthField: "whatever", + Auth: "dXNlcm5hbWU6cGFzc3dvcmQ=", + IdentityToken: cred1.RefreshToken, + RegistryToken: cred1.AccessToken, + }, + }, + SomeConfigField: 123, + } + jsonStr, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("failed to marshal config: %v", err) + } + if err := os.WriteFile(configPath, jsonStr, 0666); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + // test put + fs, err := NewFileStore(configPath) + if err != nil { + t.Fatal("NewFileStore() error =", err) + } + server2 := "registry2.example.com" + cred2 := auth.Credential{ + Username: "username_2", + Password: "password_2", + RefreshToken: "refresh_token_2", + AccessToken: "access_token_2", + } + if err := fs.Put(ctx, server2, cred2); err != nil { + t.Fatalf("FileStore.Put() error = %v", err) + } + + // verify config file + configFile, err := os.Open(configPath) + if err != nil { + t.Fatalf("failed to open config file: %v", err) + } + defer configFile.Close() + var gotCfg configtest.Config + if err := json.NewDecoder(configFile).Decode(&gotCfg); err != nil { + t.Fatalf("failed to decode config file: %v", err) + } + wantCfg := configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{ + server1: { + SomeAuthField: "whatever", + Auth: "dXNlcm5hbWU6cGFzc3dvcmQ=", + IdentityToken: cred1.RefreshToken, + RegistryToken: cred1.AccessToken, + }, + server2: { + Auth: "dXNlcm5hbWVfMjpwYXNzd29yZF8y", + IdentityToken: "refresh_token_2", + RegistryToken: "access_token_2", + }, + }, + SomeConfigField: cfg.SomeConfigField, + } + if !reflect.DeepEqual(gotCfg, wantCfg) { + t.Errorf("Decoded config = %v, want %v", gotCfg, wantCfg) + } + + // verify get + got, err := fs.Get(ctx, server1) + if err != nil { + t.Fatalf("FileStore.Get() error = %v", err) + } + if want := cred1; !reflect.DeepEqual(got, want) { + t.Errorf("FileStore.Get(%s) = %v, want %v", server1, got, want) + } + + got, err = fs.Get(ctx, server2) + if err != nil { + t.Fatalf("FileStore.Get() error = %v", err) + } + if want := cred2; !reflect.DeepEqual(got, want) { + t.Errorf("FileStore.Get(%s) = %v, want %v", server2, got, want) + } +} + +func TestFileStore_Put_updateOld(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + ctx := context.Background() + + // prepare test content + server := "registry.example.com" + cfg := configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{ + server: { + SomeAuthField: "whatever", + Username: "foo", + Password: "bar", + IdentityToken: "refresh_token", + }, + }, + SomeConfigField: 123, + } + jsonStr, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("failed to marshal config: %v", err) + } + if err := os.WriteFile(configPath, jsonStr, 0666); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + // test put + fs, err := NewFileStore(configPath) + if err != nil { + t.Fatal("NewFileStore() error =", err) + } + cred := auth.Credential{ + Username: "username", + Password: "password", + AccessToken: "access_token", + } + if err := fs.Put(ctx, server, cred); err != nil { + t.Fatalf("FileStore.Put() error = %v", err) + } + + // verify config file + configFile, err := os.Open(configPath) + if err != nil { + t.Fatalf("failed to open config file: %v", err) + } + defer configFile.Close() + var gotCfg configtest.Config + if err := json.NewDecoder(configFile).Decode(&gotCfg); err != nil { + t.Fatalf("failed to decode config file: %v", err) + } + wantCfg := configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{ + server: { + Auth: "dXNlcm5hbWU6cGFzc3dvcmQ=", + RegistryToken: "access_token", + }, + }, + SomeConfigField: cfg.SomeConfigField, + } + if !reflect.DeepEqual(gotCfg, wantCfg) { + t.Errorf("Decoded config = %v, want %v", gotCfg, wantCfg) + } + + // verify get + got, err := fs.Get(ctx, server) + if err != nil { + t.Fatalf("FileStore.Get() error = %v", err) + } + if want := cred; !reflect.DeepEqual(got, want) { + t.Errorf("FileStore.Get(%s) = %v, want %v", server, got, want) + } +} + +func TestFileStore_Put_disablePut(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + ctx := context.Background() + + fs, err := NewFileStore(configPath) + if err != nil { + t.Fatal("NewFileStore() error =", err) + } + fs.DisablePut = true + + server := "test.example.com" + cred := auth.Credential{ + Username: "username", + Password: "password", + RefreshToken: "refresh_token", + AccessToken: "access_token", + } + err = fs.Put(ctx, server, cred) + if wantErr := ErrPlaintextPutDisabled; !errors.Is(err, wantErr) { + t.Errorf("FileStore.Put() error = %v, wantErr %v", err, wantErr) + } +} + +func TestFileStore_Put_usernameContainsColon(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + ctx := context.Background() + + fs, err := NewFileStore(configPath) + if err != nil { + t.Fatal("NewFileStore() error =", err) + } + serverAddr := "test.example.com" + cred := auth.Credential{ + Username: "x:y", + Password: "z", + } + if err := fs.Put(ctx, serverAddr, cred); err == nil { + t.Fatal("FileStore.Put() error is nil, want", ErrBadCredentialFormat) + } +} + +func TestFileStore_Put_passwordContainsColon(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + ctx := context.Background() + + fs, err := NewFileStore(configPath) + if err != nil { + t.Fatal("NewFileStore() error =", err) + } + serverAddr := "test.example.com" + cred := auth.Credential{ + Username: "y", + Password: "y:z", + } + if err := fs.Put(ctx, serverAddr, cred); err != nil { + t.Fatal("FileStore.Put() error =", err) + } + got, err := fs.Get(ctx, serverAddr) + if err != nil { + t.Fatal("FileStore.Get() error =", err) + } + if !reflect.DeepEqual(got, cred) { + t.Errorf("FileStore.Get() = %v, want %v", got, cred) + } +} + +func TestFileStore_Delete(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + ctx := context.Background() + + // prepare test content + server1 := "registry1.example.com" + cred1 := auth.Credential{ + Username: "username", + Password: "password", + RefreshToken: "refresh_token", + AccessToken: "access_token", + } + server2 := "registry2.example.com" + cred2 := auth.Credential{ + Username: "username_2", + Password: "password_2", + RefreshToken: "refresh_token_2", + AccessToken: "access_token_2", + } + + cfg := configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{ + server1: { + Auth: "dXNlcm5hbWU6cGFzc3dvcmQ=", + IdentityToken: cred1.RefreshToken, + RegistryToken: cred1.AccessToken, + }, + server2: { + Auth: "dXNlcm5hbWVfMjpwYXNzd29yZF8y", + IdentityToken: "refresh_token_2", + RegistryToken: "access_token_2", + }, + }, + SomeConfigField: 123, + } + jsonStr, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("failed to marshal config: %v", err) + } + if err := os.WriteFile(configPath, jsonStr, 0666); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + fs, err := NewFileStore(configPath) + if err != nil { + t.Fatal("NewFileStore() error =", err) + } + // test get + got, err := fs.Get(ctx, server1) + if err != nil { + t.Fatalf("FileStore.Get() error = %v", err) + } + if want := cred1; !reflect.DeepEqual(got, want) { + t.Errorf("FileStore.Get(%s) = %v, want %v", server1, got, want) + } + got, err = fs.Get(ctx, server2) + if err != nil { + t.Fatalf("FileStore.Get() error = %v", err) + } + if want := cred2; !reflect.DeepEqual(got, want) { + t.Errorf("FileStore.Get(%s) = %v, want %v", server2, got, want) + } + + // test delete + if err := fs.Delete(ctx, server1); err != nil { + t.Fatalf("FileStore.Delete() error = %v", err) + } + + // verify config file + configFile, err := os.Open(configPath) + if err != nil { + t.Fatalf("failed to open config file: %v", err) + } + defer configFile.Close() + var gotCfg configtest.Config + if err := json.NewDecoder(configFile).Decode(&gotCfg); err != nil { + t.Fatalf("failed to decode config file: %v", err) + } + wantCfg := configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{ + server2: cfg.AuthConfigs[server2], + }, + SomeConfigField: cfg.SomeConfigField, + } + if !reflect.DeepEqual(gotCfg, wantCfg) { + t.Errorf("Decoded config = %v, want %v", gotCfg, wantCfg) + } + + // test get again + got, err = fs.Get(ctx, server1) + if err != nil { + t.Fatalf("FileStore.Get() error = %v", err) + } + if want := auth.EmptyCredential; !reflect.DeepEqual(got, want) { + t.Errorf("FileStore.Get(%s) = %v, want %v", server1, got, want) + } + got, err = fs.Get(ctx, server2) + if err != nil { + t.Fatalf("FileStore.Get() error = %v", err) + } + if want := cred2; !reflect.DeepEqual(got, want) { + t.Errorf("FileStore.Get(%s) = %v, want %v", server2, got, want) + } +} + +func TestFileStore_Delete_lastConfig(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + ctx := context.Background() + + // prepare test content + server := "registry1.example.com" + cred := auth.Credential{ + Username: "username", + Password: "password", + RefreshToken: "refresh_token", + AccessToken: "access_token", + } + + cfg := configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{ + server: { + Auth: "dXNlcm5hbWU6cGFzc3dvcmQ=", + IdentityToken: cred.RefreshToken, + RegistryToken: cred.AccessToken, + }, + }, + SomeConfigField: 123, + } + jsonStr, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("failed to marshal config: %v", err) + } + if err := os.WriteFile(configPath, jsonStr, 0666); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + fs, err := NewFileStore(configPath) + if err != nil { + t.Fatal("NewFileStore() error =", err) + } + // test get + got, err := fs.Get(ctx, server) + if err != nil { + t.Fatalf("FileStore.Get() error = %v", err) + } + if want := cred; !reflect.DeepEqual(got, want) { + t.Errorf("FileStore.Get(%s) = %v, want %v", server, got, want) + } + + // test delete + if err := fs.Delete(ctx, server); err != nil { + t.Fatalf("FileStore.Delete() error = %v", err) + } + + // verify config file + configFile, err := os.Open(configPath) + if err != nil { + t.Fatalf("failed to open config file: %v", err) + } + defer configFile.Close() + var gotCfg configtest.Config + if err := json.NewDecoder(configFile).Decode(&gotCfg); err != nil { + t.Fatalf("failed to decode config file: %v", err) + } + wantCfg := configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{}, + SomeConfigField: cfg.SomeConfigField, + } + if !reflect.DeepEqual(gotCfg, wantCfg) { + t.Errorf("Decoded config = %v, want %v", gotCfg, wantCfg) + } + + // test get again + got, err = fs.Get(ctx, server) + if err != nil { + t.Fatalf("FileStore.Get() error = %v", err) + } + if want := auth.EmptyCredential; !reflect.DeepEqual(got, want) { + t.Errorf("FileStore.Get(%s) = %v, want %v", server, got, want) + } +} + +func TestFileStore_Delete_notExistRecord(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + ctx := context.Background() + + // prepare test content + server := "registry1.example.com" + cred := auth.Credential{ + Username: "username", + Password: "password", + RefreshToken: "refresh_token", + AccessToken: "access_token", + } + cfg := configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{ + server: { + Auth: "dXNlcm5hbWU6cGFzc3dvcmQ=", + IdentityToken: cred.RefreshToken, + RegistryToken: cred.AccessToken, + }, + }, + SomeConfigField: 123, + } + jsonStr, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("failed to marshal config: %v", err) + } + if err := os.WriteFile(configPath, jsonStr, 0666); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + fs, err := NewFileStore(configPath) + if err != nil { + t.Fatal("NewFileStore() error =", err) + } + // test get + got, err := fs.Get(ctx, server) + if err != nil { + t.Fatalf("FileStore.Get() error = %v", err) + } + if want := cred; !reflect.DeepEqual(got, want) { + t.Errorf("FileStore.Get(%s) = %v, want %v", server, got, want) + } + + // test delete + if err := fs.Delete(ctx, "test.example.com"); err != nil { + t.Fatalf("FileStore.Delete() error = %v", err) + } + + // verify config file + configFile, err := os.Open(configPath) + if err != nil { + t.Fatalf("failed to open config file: %v", err) + } + defer configFile.Close() + var gotCfg configtest.Config + if err := json.NewDecoder(configFile).Decode(&gotCfg); err != nil { + t.Fatalf("failed to decode config file: %v", err) + } + wantCfg := configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{ + server: cfg.AuthConfigs[server], + }, + SomeConfigField: cfg.SomeConfigField, + } + if !reflect.DeepEqual(gotCfg, wantCfg) { + t.Errorf("Decoded config = %v, want %v", gotCfg, wantCfg) + } + + // test get again + got, err = fs.Get(ctx, server) + if err != nil { + t.Fatalf("FileStore.Get() error = %v", err) + } + if want := cred; !reflect.DeepEqual(got, want) { + t.Errorf("FileStore.Get(%s) = %v, want %v", server, got, want) + } +} + +func TestFileStore_Delete_notExistConfig(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + ctx := context.Background() + + fs, err := NewFileStore(configPath) + if err != nil { + t.Fatal("NewFileStore() error =", err) + } + + server := "test.example.com" + // test delete + if err := fs.Delete(ctx, server); err != nil { + t.Fatalf("FileStore.Delete() error = %v", err) + } + + // verify config file is not created + _, err = os.Stat(configPath) + if wantErr := os.ErrNotExist; !errors.Is(err, wantErr) { + t.Errorf("Stat(%s) error = %v, wantErr %v", configPath, err, wantErr) + } +} + +func Test_validateCredentialFormat(t *testing.T) { + tests := []struct { + name string + cred auth.Credential + wantErr error + }{ + { + name: "Username contains colon", + cred: auth.Credential{ + Username: "x:y", + Password: "z", + }, + wantErr: ErrBadCredentialFormat, + }, + { + name: "Password contains colon", + cred: auth.Credential{ + Username: "x", + Password: "y:z", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := validateCredentialFormat(tt.cred); !errors.Is(err, tt.wantErr) { + t.Errorf("validateCredentialFormat() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/registry/remote/credentials/internal/config/config.go b/registry/remote/credentials/internal/config/config.go new file mode 100644 index 00000000..5bb66a0e --- /dev/null +++ b/registry/remote/credentials/internal/config/config.go @@ -0,0 +1,327 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package config + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + + "oras.land/oras-go/v2/registry/remote/auth" + "oras.land/oras-go/v2/registry/remote/credentials/internal/ioutil" +) + +const ( + // configFieldAuths is the "auths" field in the config file. + // Reference: https://github.com/docker/cli/blob/v24.0.0-beta.2/cli/config/configfile/file.go#L19 + configFieldAuths = "auths" + // configFieldCredentialsStore is the "credsStore" field in the config file. + configFieldCredentialsStore = "credsStore" + // configFieldCredentialHelpers is the "credHelpers" field in the config file. + configFieldCredentialHelpers = "credHelpers" +) + +// ErrInvalidConfigFormat is returned when the config format is invalid. +var ErrInvalidConfigFormat = errors.New("invalid config format") + +// AuthConfig contains authorization information for connecting to a Registry. +// References: +// - https://github.com/docker/cli/blob/v24.0.0-beta.2/cli/config/configfile/file.go#L17-L45 +// - https://github.com/docker/cli/blob/v24.0.0-beta.2/cli/config/types/authconfig.go#L3-L22 +type AuthConfig struct { + // Auth is a base64-encoded string of "{username}:{password}". + Auth string `json:"auth,omitempty"` + // IdentityToken is used to authenticate the user and get an access token + // for the registry. + IdentityToken string `json:"identitytoken,omitempty"` + // RegistryToken is a bearer token to be sent to a registry. + RegistryToken string `json:"registrytoken,omitempty"` + + Username string `json:"username,omitempty"` // legacy field for compatibility + Password string `json:"password,omitempty"` // legacy field for compatibility +} + +// NewAuthConfig creates an authConfig based on cred. +func NewAuthConfig(cred auth.Credential) AuthConfig { + return AuthConfig{ + Auth: encodeAuth(cred.Username, cred.Password), + IdentityToken: cred.RefreshToken, + RegistryToken: cred.AccessToken, + } +} + +// Credential returns an auth.Credential based on ac. +func (ac AuthConfig) Credential() (auth.Credential, error) { + cred := auth.Credential{ + Username: ac.Username, + Password: ac.Password, + RefreshToken: ac.IdentityToken, + AccessToken: ac.RegistryToken, + } + if ac.Auth != "" { + var err error + // override username and password + cred.Username, cred.Password, err = decodeAuth(ac.Auth) + if err != nil { + return auth.EmptyCredential, fmt.Errorf("failed to decode auth field: %w: %v", ErrInvalidConfigFormat, err) + } + } + return cred, nil +} + +// Config represents a docker configuration file. +// References: +// - https://docs.docker.com/engine/reference/commandline/cli/#docker-cli-configuration-file-configjson-properties +// - https://github.com/docker/cli/blob/v24.0.0-beta.2/cli/config/configfile/file.go#L17-L44 +type Config struct { + // path is the path to the config file. + path string + // rwLock is a read-write-lock for the file store. + rwLock sync.RWMutex + // content is the content of the config file. + // Reference: https://github.com/docker/cli/blob/v24.0.0-beta.2/cli/config/configfile/file.go#L17-L44 + content map[string]json.RawMessage + // authsCache is a cache of the auths field of the config. + // Reference: https://github.com/docker/cli/blob/v24.0.0-beta.2/cli/config/configfile/file.go#L19 + authsCache map[string]json.RawMessage + // credentialsStore is the credsStore field of the config. + // Reference: https://github.com/docker/cli/blob/v24.0.0-beta.2/cli/config/configfile/file.go#L28 + credentialsStore string + // credentialHelpers is the credHelpers field of the config. + // Reference: https://github.com/docker/cli/blob/v24.0.0-beta.2/cli/config/configfile/file.go#L29 + credentialHelpers map[string]string +} + +// Load loads Config from the given config path. +func Load(configPath string) (*Config, error) { + cfg := &Config{path: configPath} + configFile, err := os.Open(configPath) + if err != nil { + if os.IsNotExist(err) { + // init content and caches if the content file does not exist + cfg.content = make(map[string]json.RawMessage) + cfg.authsCache = make(map[string]json.RawMessage) + return cfg, nil + } + return nil, fmt.Errorf("failed to open config file at %s: %w", configPath, err) + } + defer configFile.Close() + + // decode config content if the config file exists + if err := json.NewDecoder(configFile).Decode(&cfg.content); err != nil { + return nil, fmt.Errorf("failed to decode config file at %s: %w: %v", configPath, ErrInvalidConfigFormat, err) + } + + if credsStoreBytes, ok := cfg.content[configFieldCredentialsStore]; ok { + if err := json.Unmarshal(credsStoreBytes, &cfg.credentialsStore); err != nil { + return nil, fmt.Errorf("failed to unmarshal creds store field: %w: %v", ErrInvalidConfigFormat, err) + } + } + + if credHelpersBytes, ok := cfg.content[configFieldCredentialHelpers]; ok { + if err := json.Unmarshal(credHelpersBytes, &cfg.credentialHelpers); err != nil { + return nil, fmt.Errorf("failed to unmarshal cred helpers field: %w: %v", ErrInvalidConfigFormat, err) + } + } + + if authsBytes, ok := cfg.content[configFieldAuths]; ok { + if err := json.Unmarshal(authsBytes, &cfg.authsCache); err != nil { + return nil, fmt.Errorf("failed to unmarshal auths field: %w: %v", ErrInvalidConfigFormat, err) + } + } + if cfg.authsCache == nil { + cfg.authsCache = make(map[string]json.RawMessage) + } + + return cfg, nil +} + +// GetAuthConfig returns an auth.Credential for serverAddress. +func (cfg *Config) GetCredential(serverAddress string) (auth.Credential, error) { + cfg.rwLock.RLock() + defer cfg.rwLock.RUnlock() + + authCfgBytes, ok := cfg.authsCache[serverAddress] + if !ok { + // NOTE: the auth key for the server address may have been stored with + // a http/https prefix in legacy config files, e.g. "registry.example.com" + // can be stored as "https://registry.example.com/". + var matched bool + for addr, auth := range cfg.authsCache { + if toHostname(addr) == serverAddress { + matched = true + authCfgBytes = auth + break + } + } + if !matched { + return auth.EmptyCredential, nil + } + } + var authCfg AuthConfig + if err := json.Unmarshal(authCfgBytes, &authCfg); err != nil { + return auth.EmptyCredential, fmt.Errorf("failed to unmarshal auth field: %w: %v", ErrInvalidConfigFormat, err) + } + return authCfg.Credential() +} + +// PutAuthConfig puts cred for serverAddress. +func (cfg *Config) PutCredential(serverAddress string, cred auth.Credential) error { + cfg.rwLock.Lock() + defer cfg.rwLock.Unlock() + + authCfg := NewAuthConfig(cred) + authCfgBytes, err := json.Marshal(authCfg) + if err != nil { + return fmt.Errorf("failed to marshal auth field: %w", err) + } + cfg.authsCache[serverAddress] = authCfgBytes + return cfg.saveFile() +} + +// DeleteAuthConfig deletes the corresponding credential for serverAddress. +func (cfg *Config) DeleteCredential(serverAddress string) error { + cfg.rwLock.Lock() + defer cfg.rwLock.Unlock() + + if _, ok := cfg.authsCache[serverAddress]; !ok { + // no ops + return nil + } + delete(cfg.authsCache, serverAddress) + return cfg.saveFile() +} + +// GetCredentialHelper returns the credential helpers for serverAddress. +func (cfg *Config) GetCredentialHelper(serverAddress string) string { + return cfg.credentialHelpers[serverAddress] +} + +// CredentialsStore returns the configured credentials store. +func (cfg *Config) CredentialsStore() string { + cfg.rwLock.RLock() + defer cfg.rwLock.RUnlock() + + return cfg.credentialsStore +} + +// SetCredentialsStore puts the configured credentials store. +func (cfg *Config) SetCredentialsStore(credsStore string) error { + cfg.rwLock.Lock() + defer cfg.rwLock.Unlock() + + cfg.credentialsStore = credsStore + return cfg.saveFile() +} + +// IsAuthConfigured returns whether there is authentication configured in this +// config file or not. +func (cfg *Config) IsAuthConfigured() bool { + return cfg.credentialsStore != "" || + len(cfg.credentialHelpers) > 0 || + len(cfg.authsCache) > 0 +} + +// saveFile saves Config into the file. +func (cfg *Config) saveFile() (returnErr error) { + // marshal content + // credentialHelpers is skipped as it's never set + if cfg.credentialsStore != "" { + credsStoreBytes, err := json.Marshal(cfg.credentialsStore) + if err != nil { + return fmt.Errorf("failed to marshal creds store: %w", err) + } + cfg.content[configFieldCredentialsStore] = credsStoreBytes + } else { + // omit empty + delete(cfg.content, configFieldCredentialsStore) + } + authsBytes, err := json.Marshal(cfg.authsCache) + if err != nil { + return fmt.Errorf("failed to marshal credentials: %w", err) + } + cfg.content[configFieldAuths] = authsBytes + jsonBytes, err := json.MarshalIndent(cfg.content, "", "\t") + if err != nil { + return fmt.Errorf("failed to marshal config: %w", err) + } + + // write the content to a ingest file for atomicity + configDir := filepath.Dir(cfg.path) + if err := os.MkdirAll(configDir, 0700); err != nil { + return fmt.Errorf("failed to make directory %s: %w", configDir, err) + } + ingest, err := ioutil.Ingest(configDir, bytes.NewReader(jsonBytes)) + if err != nil { + return fmt.Errorf("failed to save config file: %w", err) + } + defer func() { + if returnErr != nil { + // clean up the ingest file in case of error + os.Remove(ingest) + } + }() + + // overwrite the config file + if err := os.Rename(ingest, cfg.path); err != nil { + return fmt.Errorf("failed to save config file: %w", err) + } + return nil +} + +// encodeAuth base64-encodes username and password into base64(username:password). +func encodeAuth(username, password string) string { + if username == "" && password == "" { + return "" + } + return base64.StdEncoding.EncodeToString([]byte(username + ":" + password)) +} + +// decodeAuth decodes a base64 encoded string and returns username and password. +func decodeAuth(authStr string) (username string, password string, err error) { + if authStr == "" { + return "", "", nil + } + + decoded, err := base64.StdEncoding.DecodeString(authStr) + if err != nil { + return "", "", err + } + decodedStr := string(decoded) + username, password, ok := strings.Cut(decodedStr, ":") + if !ok { + return "", "", fmt.Errorf("auth '%s' does not conform the base64(username:password) format", decodedStr) + } + return username, password, nil +} + +// toHostname normalizes a server address to just its hostname, removing +// the scheme and the path parts. +// It is used to match keys in the auths map, which may be either stored as +// hostname or as hostname including scheme (in legacy docker config files). +// Reference: https://github.com/docker/cli/blob/v24.0.6/cli/config/credentials/file_store.go#L71 +func toHostname(addr string) string { + addr = strings.TrimPrefix(addr, "http://") + addr = strings.TrimPrefix(addr, "https://") + addr, _, _ = strings.Cut(addr, "/") + return addr +} diff --git a/registry/remote/credentials/internal/config/config_test.go b/registry/remote/credentials/internal/config/config_test.go new file mode 100644 index 00000000..6622108b --- /dev/null +++ b/registry/remote/credentials/internal/config/config_test.go @@ -0,0 +1,1452 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package config + +import ( + "encoding/json" + "errors" + "os" + "path/filepath" + "reflect" + "testing" + + "oras.land/oras-go/v2/registry/remote/auth" + "oras.land/oras-go/v2/registry/remote/credentials/internal/config/configtest" +) + +func TestLoad_badPath(t *testing.T) { + tempDir := t.TempDir() + + tests := []struct { + name string + configPath string + wantErr bool + }{ + { + name: "Path is a directory", + configPath: tempDir, + wantErr: true, + }, + { + name: "Empty file name", + configPath: filepath.Join(tempDir, ""), + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := Load(tt.configPath) + if (err != nil) != tt.wantErr { + t.Errorf("Load() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + +func TestLoad_badFormat(t *testing.T) { + tests := []struct { + name string + configPath string + wantErr bool + }{ + { + name: "Bad JSON format", + configPath: "../../testdata/bad_config", + wantErr: true, + }, + { + name: "Invalid auths format", + configPath: "../../testdata/invalid_auths_config.json", + wantErr: true, + }, + { + name: "No auths field", + configPath: "../../testdata/no_auths_config.json", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := Load(tt.configPath) + if (err != nil) != tt.wantErr { + t.Errorf("Load() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + +func TestConfig_GetCredential_validConfig(t *testing.T) { + cfg, err := Load("../../testdata/valid_auths_config.json") + if err != nil { + t.Fatal("Load() error =", err) + } + + tests := []struct { + name string + serverAddress string + want auth.Credential + wantErr bool + }{ + { + name: "Username and password", + serverAddress: "registry1.example.com", + want: auth.Credential{ + Username: "username", + Password: "password", + }, + }, + { + name: "Identity token", + serverAddress: "registry2.example.com", + want: auth.Credential{ + RefreshToken: "identity_token", + }, + }, + { + name: "Registry token", + serverAddress: "registry3.example.com", + want: auth.Credential{ + AccessToken: "registry_token", + }, + }, + { + name: "Username and password, identity token and registry token", + serverAddress: "registry4.example.com", + want: auth.Credential{ + Username: "username", + Password: "password", + RefreshToken: "identity_token", + AccessToken: "registry_token", + }, + }, + { + name: "Empty credential", + serverAddress: "registry5.example.com", + want: auth.EmptyCredential, + }, + { + name: "Username and password, no auth", + serverAddress: "registry6.example.com", + want: auth.Credential{ + Username: "username", + Password: "password", + }, + }, + { + name: "Auth overriding Username and password", + serverAddress: "registry7.example.com", + want: auth.Credential{ + Username: "username", + Password: "password", + }, + }, + { + name: "Not in auths", + serverAddress: "foo.example.com", + want: auth.EmptyCredential, + }, + { + name: "No record", + serverAddress: "registry999.example.com", + want: auth.EmptyCredential, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := cfg.GetCredential(tt.serverAddress) + if (err != nil) != tt.wantErr { + t.Errorf("Config.GetCredential() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Config.GetCredential() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestConfig_GetCredential_legacyConfig(t *testing.T) { + cfg, err := Load("../../testdata/legacy_auths_config.json") + if err != nil { + t.Fatal("Load() error =", err) + } + + tests := []struct { + name string + serverAddress string + want auth.Credential + wantErr bool + }{ + { + name: "Regular address matched", + serverAddress: "registry1.example.com", + want: auth.Credential{ + Username: "username1", + Password: "password1", + }, + }, + { + name: "Another entry for the same address matched", + serverAddress: "https://registry1.example.com/", + want: auth.Credential{ + Username: "foo", + Password: "bar", + }, + }, + { + name: "Address with different scheme unmached", + serverAddress: "http://registry1.example.com/", + want: auth.EmptyCredential, + }, + { + name: "Address with http prefix matched", + serverAddress: "registry2.example.com", + want: auth.Credential{ + Username: "username2", + Password: "password2", + }, + }, + { + name: "Address with https prefix matched", + serverAddress: "registry3.example.com", + want: auth.Credential{ + Username: "username3", + Password: "password3", + }, + }, + { + name: "Address with http prefix and / suffix matched", + serverAddress: "registry4.example.com", + want: auth.Credential{ + Username: "username4", + Password: "password4", + }, + }, + { + name: "Address with https prefix and / suffix matched", + serverAddress: "registry5.example.com", + want: auth.Credential{ + Username: "username5", + Password: "password5", + }, + }, + { + name: "Address with https prefix and path suffix matched", + serverAddress: "registry6.example.com", + want: auth.Credential{ + Username: "username6", + Password: "password6", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := cfg.GetCredential(tt.serverAddress) + if (err != nil) != tt.wantErr { + t.Errorf("Config.GetCredential() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Config.GetCredential() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestConfig_GetCredential_invalidConfig(t *testing.T) { + cfg, err := Load("../../testdata/invalid_auths_entry_config.json") + if err != nil { + t.Fatal("Load() error =", err) + } + + tests := []struct { + name string + serverAddress string + want auth.Credential + wantErr bool + }{ + { + name: "Invalid auth encode", + serverAddress: "registry1.example.com", + want: auth.EmptyCredential, + wantErr: true, + }, + { + name: "Invalid auths format", + serverAddress: "registry2.example.com", + want: auth.EmptyCredential, + wantErr: true, + }, + { + name: "Invalid type", + serverAddress: "registry3.example.com", + want: auth.EmptyCredential, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := cfg.GetCredential(tt.serverAddress) + if (err != nil) != tt.wantErr { + t.Errorf("Config.GetCredential() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Config.GetCredential() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestConfig_GetCredential_emptyConfig(t *testing.T) { + cfg, err := Load("../../testdata/empty_config.json") + if err != nil { + t.Fatal("Load() error =", err) + } + + tests := []struct { + name string + serverAddress string + want auth.Credential + wantErr error + }{ + { + name: "Not found", + serverAddress: "registry.example.com", + want: auth.EmptyCredential, + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := cfg.GetCredential(tt.serverAddress) + if !errors.Is(err, tt.wantErr) { + t.Errorf("Config.GetCredential() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Config.GetCredential() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestConfig_GetCredential_notExistConfig(t *testing.T) { + cfg, err := Load("whatever") + if err != nil { + t.Fatal("Load() error =", err) + } + + tests := []struct { + name string + serverAddress string + want auth.Credential + wantErr error + }{ + { + name: "Not found", + serverAddress: "registry.example.com", + want: auth.EmptyCredential, + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := cfg.GetCredential(tt.serverAddress) + if !errors.Is(err, tt.wantErr) { + t.Errorf("Config.GetCredential() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Config.GetCredential() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestConfig_PutCredential_notExistConfig(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + + cfg, err := Load(configPath) + if err != nil { + t.Fatal("Load() error =", err) + } + + server := "test.example.com" + cred := auth.Credential{ + Username: "username", + Password: "password", + RefreshToken: "refresh_token", + AccessToken: "access_token", + } + + // test put + if err := cfg.PutCredential(server, cred); err != nil { + t.Fatalf("Config.PutCredential() error = %v", err) + } + + // verify config file + configFile, err := os.Open(configPath) + if err != nil { + t.Fatalf("failed to open config file: %v", err) + } + defer configFile.Close() + + var testCfg configtest.Config + if err := json.NewDecoder(configFile).Decode(&testCfg); err != nil { + t.Fatalf("failed to decode config file: %v", err) + } + want := configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{ + server: { + Auth: "dXNlcm5hbWU6cGFzc3dvcmQ=", + IdentityToken: "refresh_token", + RegistryToken: "access_token", + }, + }, + } + if !reflect.DeepEqual(testCfg, want) { + t.Errorf("Decoded config = %v, want %v", testCfg, want) + } + + // verify get + got, err := cfg.GetCredential(server) + if err != nil { + t.Fatalf("Config.GetCredential() error = %v", err) + } + if want := cred; !reflect.DeepEqual(got, want) { + t.Errorf("Config.GetCredential() = %v, want %v", got, want) + } +} + +func TestConfig_PutCredential_addNew(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + // prepare test content + server1 := "registry1.example.com" + cred1 := auth.Credential{ + Username: "username", + Password: "password", + RefreshToken: "refresh_token", + AccessToken: "access_token", + } + + testCfg := configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{ + server1: { + SomeAuthField: "whatever", + Auth: "dXNlcm5hbWU6cGFzc3dvcmQ=", + IdentityToken: cred1.RefreshToken, + RegistryToken: cred1.AccessToken, + }, + }, + SomeConfigField: 123, + } + jsonStr, err := json.Marshal(testCfg) + if err != nil { + t.Fatalf("failed to marshal config: %v", err) + } + if err := os.WriteFile(configPath, jsonStr, 0666); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + // test put + cfg, err := Load(configPath) + if err != nil { + t.Fatal("Load() error =", err) + } + server2 := "registry2.example.com" + cred2 := auth.Credential{ + Username: "username_2", + Password: "password_2", + RefreshToken: "refresh_token_2", + AccessToken: "access_token_2", + } + if err := cfg.PutCredential(server2, cred2); err != nil { + t.Fatalf("Config.PutCredential() error = %v", err) + } + + // verify config file + configFile, err := os.Open(configPath) + if err != nil { + t.Fatalf("failed to open config file: %v", err) + } + defer configFile.Close() + var gotCfg configtest.Config + if err := json.NewDecoder(configFile).Decode(&gotCfg); err != nil { + t.Fatalf("failed to decode config file: %v", err) + } + wantTestCfg := configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{ + server1: { + SomeAuthField: "whatever", + Auth: "dXNlcm5hbWU6cGFzc3dvcmQ=", + IdentityToken: cred1.RefreshToken, + RegistryToken: cred1.AccessToken, + }, + server2: { + Auth: "dXNlcm5hbWVfMjpwYXNzd29yZF8y", + IdentityToken: "refresh_token_2", + RegistryToken: "access_token_2", + }, + }, + SomeConfigField: testCfg.SomeConfigField, + } + if !reflect.DeepEqual(gotCfg, wantTestCfg) { + t.Errorf("Decoded config = %v, want %v", gotCfg, wantTestCfg) + } + + // verify get + got, err := cfg.GetCredential(server1) + if err != nil { + t.Fatalf("Config.GetCredential() error = %v", err) + } + if want := cred1; !reflect.DeepEqual(got, want) { + t.Errorf("Config.GetCredential(%s) = %v, want %v", server1, got, want) + } + + got, err = cfg.GetCredential(server2) + if err != nil { + t.Fatalf("Config.GetCredential() error = %v", err) + } + if want := cred2; !reflect.DeepEqual(got, want) { + t.Errorf("Config.GetCredential(%s) = %v, want %v", server2, got, want) + } +} + +func TestConfig_PutCredential_updateOld(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + + // prepare test content + server := "registry.example.com" + testCfg := configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{ + server: { + SomeAuthField: "whatever", + Username: "foo", + Password: "bar", + IdentityToken: "refresh_token", + }, + }, + SomeConfigField: 123, + } + jsonStr, err := json.Marshal(testCfg) + if err != nil { + t.Fatalf("failed to marshal config: %v", err) + } + if err := os.WriteFile(configPath, jsonStr, 0666); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + // test put + cfg, err := Load(configPath) + if err != nil { + t.Fatal("Load() error =", err) + } + cred := auth.Credential{ + Username: "username", + Password: "password", + AccessToken: "access_token", + } + if err := cfg.PutCredential(server, cred); err != nil { + t.Fatalf("Config.PutCredential() error = %v", err) + } + + // verify config file + configFile, err := os.Open(configPath) + if err != nil { + t.Fatalf("failed to open config file: %v", err) + } + defer configFile.Close() + var gotCfg configtest.Config + if err := json.NewDecoder(configFile).Decode(&gotCfg); err != nil { + t.Fatalf("failed to decode config file: %v", err) + } + wantCfg := configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{ + server: { + Auth: "dXNlcm5hbWU6cGFzc3dvcmQ=", + RegistryToken: "access_token", + }, + }, + SomeConfigField: testCfg.SomeConfigField, + } + if !reflect.DeepEqual(gotCfg, wantCfg) { + t.Errorf("Decoded config = %v, want %v", gotCfg, wantCfg) + } + + // verify get + got, err := cfg.GetCredential(server) + if err != nil { + t.Fatalf("Config.GetCredential() error = %v", err) + } + if want := cred; !reflect.DeepEqual(got, want) { + t.Errorf("Config.GetCredential(%s) = %v, want %v", server, got, want) + } +} + +func TestConfig_DeleteCredential(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + + // prepare test content + server1 := "registry1.example.com" + cred1 := auth.Credential{ + Username: "username", + Password: "password", + RefreshToken: "refresh_token", + AccessToken: "access_token", + } + server2 := "registry2.example.com" + cred2 := auth.Credential{ + Username: "username_2", + Password: "password_2", + RefreshToken: "refresh_token_2", + AccessToken: "access_token_2", + } + + testCfg := configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{ + server1: { + Auth: "dXNlcm5hbWU6cGFzc3dvcmQ=", + IdentityToken: cred1.RefreshToken, + RegistryToken: cred1.AccessToken, + }, + server2: { + Auth: "dXNlcm5hbWVfMjpwYXNzd29yZF8y", + IdentityToken: "refresh_token_2", + RegistryToken: "access_token_2", + }, + }, + SomeConfigField: 123, + } + jsonStr, err := json.Marshal(testCfg) + if err != nil { + t.Fatalf("failed to marshal config: %v", err) + } + if err := os.WriteFile(configPath, jsonStr, 0666); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + cfg, err := Load(configPath) + if err != nil { + t.Fatal("Load() error =", err) + } + // test get + got, err := cfg.GetCredential(server1) + if err != nil { + t.Fatalf("FileStore.GetCredential() error = %v", err) + } + if want := cred1; !reflect.DeepEqual(got, want) { + t.Errorf("FileStore.GetCredential(%s) = %v, want %v", server1, got, want) + } + got, err = cfg.GetCredential(server2) + if err != nil { + t.Fatalf("FileStore.GetCredential() error = %v", err) + } + if want := cred2; !reflect.DeepEqual(got, want) { + t.Errorf("FileStore.Get(%s) = %v, want %v", server2, got, want) + } + + // test delete + if err := cfg.DeleteCredential(server1); err != nil { + t.Fatalf("Config.DeleteCredential() error = %v", err) + } + + // verify config file + configFile, err := os.Open(configPath) + if err != nil { + t.Fatalf("failed to open config file: %v", err) + } + defer configFile.Close() + var gotTestCfg configtest.Config + if err := json.NewDecoder(configFile).Decode(&gotTestCfg); err != nil { + t.Fatalf("failed to decode config file: %v", err) + } + wantTestCfg := configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{ + server2: testCfg.AuthConfigs[server2], + }, + SomeConfigField: testCfg.SomeConfigField, + } + if !reflect.DeepEqual(gotTestCfg, wantTestCfg) { + t.Errorf("Decoded config = %v, want %v", gotTestCfg, wantTestCfg) + } + + // test get again + got, err = cfg.GetCredential(server1) + if err != nil { + t.Fatalf("Config.GetCredential() error = %v", err) + } + if want := auth.EmptyCredential; !reflect.DeepEqual(got, want) { + t.Errorf("Config.GetCredential(%s) = %v, want %v", server1, got, want) + } + got, err = cfg.GetCredential(server2) + if err != nil { + t.Fatalf("Config.GetCredential() error = %v", err) + } + if want := cred2; !reflect.DeepEqual(got, want) { + t.Errorf("Config.GetCredential(%s) = %v, want %v", server2, got, want) + } +} + +func TestConfig_DeleteCredential_lastConfig(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + + // prepare test content + server := "registry1.example.com" + cred := auth.Credential{ + Username: "username", + Password: "password", + RefreshToken: "refresh_token", + AccessToken: "access_token", + } + + testCfg := configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{ + server: { + Auth: "dXNlcm5hbWU6cGFzc3dvcmQ=", + IdentityToken: cred.RefreshToken, + RegistryToken: cred.AccessToken, + }, + }, + SomeConfigField: 123, + } + jsonStr, err := json.Marshal(testCfg) + if err != nil { + t.Fatalf("failed to marshal config: %v", err) + } + if err := os.WriteFile(configPath, jsonStr, 0666); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + cfg, err := Load(configPath) + if err != nil { + t.Fatal("Load() error =", err) + } + // test get + got, err := cfg.GetCredential(server) + if err != nil { + t.Fatalf("Config.GetCredential() error = %v", err) + } + if want := cred; !reflect.DeepEqual(got, want) { + t.Errorf("Config.GetCredential(%s) = %v, want %v", server, got, want) + } + + // test delete + if err := cfg.DeleteCredential(server); err != nil { + t.Fatalf("Config.DeleteCredential() error = %v", err) + } + + // verify config file + configFile, err := os.Open(configPath) + if err != nil { + t.Fatalf("failed to open config file: %v", err) + } + defer configFile.Close() + var gotTestCfg configtest.Config + if err := json.NewDecoder(configFile).Decode(&gotTestCfg); err != nil { + t.Fatalf("failed to decode config file: %v", err) + } + wantTestCfg := configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{}, + SomeConfigField: testCfg.SomeConfigField, + } + if !reflect.DeepEqual(gotTestCfg, wantTestCfg) { + t.Errorf("Decoded config = %v, want %v", gotTestCfg, wantTestCfg) + } + + // test get again + got, err = cfg.GetCredential(server) + if err != nil { + t.Fatalf("Config.GetCredential() error = %v", err) + } + if want := auth.EmptyCredential; !reflect.DeepEqual(got, want) { + t.Errorf("Config.GetCredential(%s) = %v, want %v", server, got, want) + } +} + +func TestConfig_DeleteCredential_notExistRecord(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + + // prepare test content + server := "registry1.example.com" + cred := auth.Credential{ + Username: "username", + Password: "password", + RefreshToken: "refresh_token", + AccessToken: "access_token", + } + testCfg := configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{ + server: { + Auth: "dXNlcm5hbWU6cGFzc3dvcmQ=", + IdentityToken: cred.RefreshToken, + RegistryToken: cred.AccessToken, + }, + }, + SomeConfigField: 123, + } + jsonStr, err := json.Marshal(testCfg) + if err != nil { + t.Fatalf("failed to marshal config: %v", err) + } + if err := os.WriteFile(configPath, jsonStr, 0666); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + cfg, err := Load(configPath) + if err != nil { + t.Fatal("Load() error =", err) + } + // test get + got, err := cfg.GetCredential(server) + if err != nil { + t.Fatalf("Config.GetCredential() error = %v", err) + } + if want := cred; !reflect.DeepEqual(got, want) { + t.Errorf("Config.GetCredential(%s) = %v, want %v", server, got, want) + } + + // test delete + if err := cfg.DeleteCredential("test.example.com"); err != nil { + t.Fatalf("Config.DeleteCredential() error = %v", err) + } + + // verify config file + configFile, err := os.Open(configPath) + if err != nil { + t.Fatalf("failed to open config file: %v", err) + } + defer configFile.Close() + var gotTestCfg configtest.Config + if err := json.NewDecoder(configFile).Decode(&gotTestCfg); err != nil { + t.Fatalf("failed to decode config file: %v", err) + } + wantTestCfg := configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{ + server: testCfg.AuthConfigs[server], + }, + SomeConfigField: testCfg.SomeConfigField, + } + if !reflect.DeepEqual(gotTestCfg, wantTestCfg) { + t.Errorf("Decoded config = %v, want %v", gotTestCfg, wantTestCfg) + } + + // test get again + got, err = cfg.GetCredential(server) + if err != nil { + t.Fatalf("Config.GetCredential() error = %v", err) + } + if want := cred; !reflect.DeepEqual(got, want) { + t.Errorf("Config.GetCredential(%s) = %v, want %v", server, got, want) + } +} + +func TestConfig_DeleteCredential_notExistConfig(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + + cfg, err := Load(configPath) + if err != nil { + t.Fatal("Load() error =", err) + } + + server := "test.example.com" + // test delete + if err := cfg.DeleteCredential(server); err != nil { + t.Fatalf("Config.DeleteCredential() error = %v", err) + } + + // verify config file is not created + _, err = os.Stat(configPath) + if wantErr := os.ErrNotExist; !errors.Is(err, wantErr) { + t.Errorf("Stat(%s) error = %v, wantErr %v", configPath, err, wantErr) + } +} + +func TestConfig_GetCredentialHelper(t *testing.T) { + cfg, err := Load("../../testdata/credHelpers_config.json") + if err != nil { + t.Fatal("Load() error =", err) + } + + tests := []struct { + name string + serverAddress string + want string + }{ + { + name: "Get cred helper: registry_helper1", + serverAddress: "registry1.example.com", + want: "registry1-helper", + }, + { + name: "Get cred helper: registry_helper2", + serverAddress: "registry2.example.com", + want: "registry2-helper", + }, + { + name: "Empty cred helper configured", + serverAddress: "registry3.example.com", + want: "", + }, + { + name: "No cred helper configured", + serverAddress: "whatever.example.com", + want: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := cfg.GetCredentialHelper(tt.serverAddress); got != tt.want { + t.Errorf("Config.GetCredentialHelper() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestConfig_CredentialsStore(t *testing.T) { + tests := []struct { + name string + configPath string + want string + }{ + { + name: "creds store configured", + configPath: "../../testdata/credsStore_config.json", + want: "teststore", + }, + { + name: "No creds store configured", + configPath: "../../testdata/credsHelpers_config.json", + want: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg, err := Load(tt.configPath) + if err != nil { + t.Fatal("Load() error =", err) + } + if got := cfg.CredentialsStore(); got != tt.want { + t.Errorf("Config.CredentialsStore() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestConfig_SetCredentialsStore(t *testing.T) { + // prepare test content + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + testCfg := configtest.Config{ + SomeConfigField: 123, + } + jsonStr, err := json.Marshal(testCfg) + if err != nil { + t.Fatalf("failed to marshal config: %v", err) + } + if err := os.WriteFile(configPath, jsonStr, 0666); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + // test SetCredentialsStore + cfg, err := Load(configPath) + if err != nil { + t.Fatal("Load() error =", err) + } + credsStore := "testStore" + if err := cfg.SetCredentialsStore(credsStore); err != nil { + t.Fatal("Config.SetCredentialsStore() error =", err) + } + + // verify + if got := cfg.credentialsStore; got != credsStore { + t.Errorf("Config.credentialsStore = %v, want %v", got, credsStore) + } + // verify config file + configFile, err := os.Open(configPath) + if err != nil { + t.Fatalf("failed to open config file: %v", err) + } + var gotTestCfg1 configtest.Config + if err := json.NewDecoder(configFile).Decode(&gotTestCfg1); err != nil { + t.Fatalf("failed to decode config file: %v", err) + } + if err := configFile.Close(); err != nil { + t.Fatal("failed to close config file:", err) + } + + wantTestCfg1 := configtest.Config{ + AuthConfigs: make(map[string]configtest.AuthConfig), + CredentialsStore: credsStore, + SomeConfigField: testCfg.SomeConfigField, + } + if !reflect.DeepEqual(gotTestCfg1, wantTestCfg1) { + t.Errorf("Decoded config = %v, want %v", gotTestCfg1, wantTestCfg1) + } + + // test SetCredentialsStore: set as empty + if err := cfg.SetCredentialsStore(""); err != nil { + t.Fatal("Config.SetCredentialsStore() error =", err) + } + // verify + if got := cfg.credentialsStore; got != "" { + t.Errorf("Config.credentialsStore = %v, want empty", got) + } + // verify config file + configFile, err = os.Open(configPath) + if err != nil { + t.Fatalf("failed to open config file: %v", err) + } + var gotTestCfg2 configtest.Config + if err := json.NewDecoder(configFile).Decode(&gotTestCfg2); err != nil { + t.Fatalf("failed to decode config file: %v", err) + } + if err := configFile.Close(); err != nil { + t.Fatal("failed to close config file:", err) + } + + wantTestCfg2 := configtest.Config{ + AuthConfigs: make(map[string]configtest.AuthConfig), + SomeConfigField: testCfg.SomeConfigField, + } + if !reflect.DeepEqual(gotTestCfg2, wantTestCfg2) { + t.Errorf("Decoded config = %v, want %v", gotTestCfg2, wantTestCfg2) + } +} + +func TestConfig_IsAuthConfigured(t *testing.T) { + tempDir := t.TempDir() + + tests := []struct { + name string + fileName string + shouldCreateFile bool + cfg configtest.Config + want bool + }{ + { + name: "not existing file", + fileName: "config.json", + shouldCreateFile: false, + cfg: configtest.Config{}, + want: false, + }, + { + name: "no auth", + fileName: "config.json", + shouldCreateFile: true, + cfg: configtest.Config{ + SomeConfigField: 123, + }, + want: false, + }, + { + name: "empty auths exist", + fileName: "empty_auths.json", + shouldCreateFile: true, + cfg: configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{}, + }, + want: false, + }, + { + name: "auths exist, but no credential", + fileName: "no_cred_auths.json", + shouldCreateFile: true, + cfg: configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{ + "test.example.com": {}, + }, + }, + want: true, + }, + { + name: "auths exist", + fileName: "auths.json", + shouldCreateFile: true, + cfg: configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{ + "test.example.com": { + Auth: "dXNlcm5hbWU6cGFzc3dvcmQ=", + }, + }, + }, + want: true, + }, + { + name: "credsStore exists", + fileName: "credsStore.json", + shouldCreateFile: true, + cfg: configtest.Config{ + CredentialsStore: "teststore", + }, + want: true, + }, + { + name: "empty credHelpers exist", + fileName: "empty_credsStore.json", + shouldCreateFile: true, + cfg: configtest.Config{ + CredentialHelpers: map[string]string{}, + }, + want: false, + }, + { + name: "credHelpers exist", + fileName: "credsStore.json", + shouldCreateFile: true, + cfg: configtest.Config{ + CredentialHelpers: map[string]string{ + "test.example.com": "testhelper", + }, + }, + want: true, + }, + { + name: "all exist", + fileName: "credsStore.json", + shouldCreateFile: true, + cfg: configtest.Config{ + SomeConfigField: 123, + AuthConfigs: map[string]configtest.AuthConfig{ + "test.example.com": {}, + }, + CredentialsStore: "teststore", + CredentialHelpers: map[string]string{ + "test.example.com": "testhelper", + }, + }, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // prepare test content + configPath := filepath.Join(tempDir, tt.fileName) + if tt.shouldCreateFile { + jsonStr, err := json.Marshal(tt.cfg) + if err != nil { + t.Fatalf("failed to marshal config: %v", err) + } + if err := os.WriteFile(configPath, jsonStr, 0666); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + } + + cfg, err := Load(configPath) + if err != nil { + t.Fatal("Load() error =", err) + } + if got := cfg.IsAuthConfigured(); got != tt.want { + t.Errorf("IsAuthConfigured() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestConfig_saveFile(t *testing.T) { + tempDir := t.TempDir() + tests := []struct { + name string + fileName string + shouldCreateFile bool + oldCfg configtest.Config + newCfg configtest.Config + wantCfg configtest.Config + }{ + { + name: "set credsStore in a non-existing file", + fileName: "config.json", + oldCfg: configtest.Config{}, + newCfg: configtest.Config{ + CredentialsStore: "teststore", + }, + wantCfg: configtest.Config{ + AuthConfigs: make(map[string]configtest.AuthConfig), + CredentialsStore: "teststore", + }, + shouldCreateFile: false, + }, + { + name: "set credsStore in empty file", + fileName: "empty.json", + oldCfg: configtest.Config{}, + newCfg: configtest.Config{ + CredentialsStore: "teststore", + }, + wantCfg: configtest.Config{ + AuthConfigs: make(map[string]configtest.AuthConfig), + CredentialsStore: "teststore", + }, + shouldCreateFile: true, + }, + { + name: "set credsStore in a no-auth-configured file", + fileName: "empty.json", + oldCfg: configtest.Config{ + SomeConfigField: 123, + }, + newCfg: configtest.Config{ + CredentialsStore: "teststore", + }, + wantCfg: configtest.Config{ + SomeConfigField: 123, + AuthConfigs: make(map[string]configtest.AuthConfig), + CredentialsStore: "teststore", + }, + shouldCreateFile: true, + }, + { + name: "Set credsStore and credHelpers in an auth-configured file", + fileName: "auth_configured.json", + oldCfg: configtest.Config{ + SomeConfigField: 123, + AuthConfigs: map[string]configtest.AuthConfig{ + "registry1.example.com": { + SomeAuthField: "something", + Auth: "dXNlcm5hbWU6cGFzc3dvcmQ=", + }, + }, + CredentialsStore: "oldstore", + CredentialHelpers: map[string]string{ + "registry2.example.com": "testhelper", + }, + }, + newCfg: configtest.Config{ + AuthConfigs: make(map[string]configtest.AuthConfig), + SomeConfigField: 123, + CredentialsStore: "newstore", + CredentialHelpers: map[string]string{ + "xxx": "yyy", + }, + }, + wantCfg: configtest.Config{ + SomeConfigField: 123, + AuthConfigs: map[string]configtest.AuthConfig{ + "registry1.example.com": { + SomeAuthField: "something", + Auth: "dXNlcm5hbWU6cGFzc3dvcmQ=", + }, + }, + CredentialsStore: "newstore", + CredentialHelpers: map[string]string{ + "registry2.example.com": "testhelper", // cred helpers will not be updated + }, + }, + shouldCreateFile: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // prepare test content + configPath := filepath.Join(tempDir, tt.fileName) + if tt.shouldCreateFile { + jsonStr, err := json.Marshal(tt.oldCfg) + if err != nil { + t.Fatalf("failed to marshal config: %v", err) + } + if err := os.WriteFile(configPath, jsonStr, 0666); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + } + + cfg, err := Load(configPath) + if err != nil { + t.Fatal("Load() error =", err) + } + cfg.credentialsStore = tt.newCfg.CredentialsStore + cfg.credentialHelpers = tt.newCfg.CredentialHelpers + if err := cfg.saveFile(); err != nil { + t.Fatal("saveFile() error =", err) + } + + // verify config file + configFile, err := os.Open(configPath) + if err != nil { + t.Fatalf("failed to open config file: %v", err) + } + defer configFile.Close() + var gotCfg configtest.Config + if err := json.NewDecoder(configFile).Decode(&gotCfg); err != nil { + t.Fatalf("failed to decode config file: %v", err) + } + if !reflect.DeepEqual(gotCfg, tt.wantCfg) { + t.Errorf("Decoded config = %v, want %v", gotCfg, tt.wantCfg) + } + }) + } +} + +func Test_encodeAuth(t *testing.T) { + tests := []struct { + name string + username string + password string + want string + }{ + { + name: "Username and password", + username: "username", + password: "password", + want: "dXNlcm5hbWU6cGFzc3dvcmQ=", + }, + { + name: "Username only", + username: "username", + password: "", + want: "dXNlcm5hbWU6", + }, + { + name: "Password only", + username: "", + password: "password", + want: "OnBhc3N3b3Jk", + }, + { + name: "Empty username and empty password", + username: "", + password: "", + want: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := encodeAuth(tt.username, tt.password); got != tt.want { + t.Errorf("encodeAuth() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_decodeAuth(t *testing.T) { + tests := []struct { + name string + authStr string + username string + password string + wantErr bool + }{ + { + name: "Valid base64", + authStr: "dXNlcm5hbWU6cGFzc3dvcmQ=", // username:password + username: "username", + password: "password", + }, + { + name: "Valid base64, username only", + authStr: "dXNlcm5hbWU6", // username: + username: "username", + }, + { + name: "Valid base64, password only", + authStr: "OnBhc3N3b3Jk", // :password + password: "password", + }, + { + name: "Valid base64, bad format", + authStr: "d2hhdGV2ZXI=", // whatever + username: "", + password: "", + wantErr: true, + }, + { + name: "Invalid base64", + authStr: "whatever", + username: "", + password: "", + wantErr: true, + }, + { + name: "Empty string", + authStr: "", + username: "", + password: "", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotUsername, gotPassword, err := decodeAuth(tt.authStr) + if (err != nil) != tt.wantErr { + t.Errorf("decodeAuth() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotUsername != tt.username { + t.Errorf("decodeAuth() got = %v, want %v", gotUsername, tt.username) + } + if gotPassword != tt.password { + t.Errorf("decodeAuth() got1 = %v, want %v", gotPassword, tt.password) + } + }) + } +} + +func Test_toHostname(t *testing.T) { + tests := []struct { + name string + addr string + want string + }{ + { + addr: "http://test.example.com", + want: "test.example.com", + }, + { + addr: "http://test.example.com/", + want: "test.example.com", + }, + { + addr: "http://test.example.com/foo/bar", + want: "test.example.com", + }, + { + addr: "https://test.example.com", + want: "test.example.com", + }, + { + addr: "https://test.example.com/", + want: "test.example.com", + }, + { + addr: "http://test.example.com/foo/bar", + want: "test.example.com", + }, + { + addr: "test.example.com", + want: "test.example.com", + }, + { + addr: "test.example.com/foo/bar/", + want: "test.example.com", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := toHostname(tt.addr); got != tt.want { + t.Errorf("toHostname() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/registry/remote/credentials/internal/config/configtest/config.go b/registry/remote/credentials/internal/config/configtest/config.go new file mode 100644 index 00000000..5945e12e --- /dev/null +++ b/registry/remote/credentials/internal/config/configtest/config.go @@ -0,0 +1,39 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package configtest + +// Config represents the structure of a config file for testing purpose. +type Config struct { + AuthConfigs map[string]AuthConfig `json:"auths"` + CredentialsStore string `json:"credsStore,omitempty"` + CredentialHelpers map[string]string `json:"credHelpers,omitempty"` + SomeConfigField int `json:"some_config_field"` +} + +// AuthConfig represents the structure of the "auths" field of a config file +// for testing purpose. +type AuthConfig struct { + SomeAuthField string `json:"some_auth_field,omitempty"` + Username string `json:"username,omitempty"` + Password string `json:"password,omitempty"` + Auth string `json:"auth,omitempty"` + + // IdentityToken is used to authenticate the user and get + // an access token for the registry. + IdentityToken string `json:"identitytoken,omitempty"` + // RegistryToken is a bearer token to be sent to a registry + RegistryToken string `json:"registrytoken,omitempty"` +} diff --git a/registry/remote/credentials/internal/executer/executer.go b/registry/remote/credentials/internal/executer/executer.go new file mode 100644 index 00000000..a074c684 --- /dev/null +++ b/registry/remote/credentials/internal/executer/executer.go @@ -0,0 +1,80 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package executer is an abstraction for the docker credential helper protocol +// binaries. It is used by nativeStore to interact with installed binaries. +package executer + +import ( + "bytes" + "context" + "errors" + "io" + "os" + "os/exec" + + "oras.land/oras-go/v2/registry/remote/credentials/trace" +) + +// dockerDesktopHelperName is the name of the docker credentials helper +// execuatable. +const dockerDesktopHelperName = "docker-credential-desktop.exe" + +// Executer is an interface that simulates an executable binary. +type Executer interface { + Execute(ctx context.Context, input io.Reader, action string) ([]byte, error) +} + +// executable implements the Executer interface. +type executable struct { + name string +} + +// New returns a new Executer instance. +func New(name string) Executer { + return &executable{ + name: name, + } +} + +// Execute operates on an executable binary and supports context. +func (c *executable) Execute(ctx context.Context, input io.Reader, action string) ([]byte, error) { + cmd := exec.CommandContext(ctx, c.name, action) + cmd.Stdin = input + cmd.Stderr = os.Stderr + trace := trace.ContextExecutableTrace(ctx) + if trace != nil && trace.ExecuteStart != nil { + trace.ExecuteStart(c.name, action) + } + output, err := cmd.Output() + if trace != nil && trace.ExecuteDone != nil { + trace.ExecuteDone(c.name, action, err) + } + if err != nil { + switch execErr := err.(type) { + case *exec.ExitError: + if errMessage := string(bytes.TrimSpace(output)); errMessage != "" { + return nil, errors.New(errMessage) + } + case *exec.Error: + // check if the error is caused by Docker Desktop not running + if execErr.Err == exec.ErrNotFound && c.name == dockerDesktopHelperName { + return nil, errors.New("credentials store is configured to `desktop.exe` but Docker Desktop seems not running") + } + } + return nil, err + } + return output, nil +} diff --git a/registry/remote/credentials/internal/ioutil/ioutil.go b/registry/remote/credentials/internal/ioutil/ioutil.go new file mode 100644 index 00000000..b2e3179d --- /dev/null +++ b/registry/remote/credentials/internal/ioutil/ioutil.go @@ -0,0 +1,49 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ioutil + +import ( + "fmt" + "io" + "os" +) + +// Ingest writes content into a temporary ingest file with the file name format +// "oras_credstore_temp_{randomString}". +func Ingest(dir string, content io.Reader) (path string, ingestErr error) { + tempFile, err := os.CreateTemp(dir, "oras_credstore_temp_*") + if err != nil { + return "", fmt.Errorf("failed to create ingest file: %w", err) + } + path = tempFile.Name() + defer func() { + if err := tempFile.Close(); err != nil && ingestErr == nil { + ingestErr = fmt.Errorf("failed to close ingest file: %w", err) + } + // remove the temp file in case of error. + if ingestErr != nil { + os.Remove(path) + } + }() + + if err := tempFile.Chmod(0600); err != nil { + return "", fmt.Errorf("failed to ensure permission: %w", err) + } + if _, err := io.Copy(tempFile, content); err != nil { + return "", fmt.Errorf("failed to ingest: %w", err) + } + return +} diff --git a/registry/remote/credentials/memory_store.go b/registry/remote/credentials/memory_store.go new file mode 100644 index 00000000..6eb7749b --- /dev/null +++ b/registry/remote/credentials/memory_store.go @@ -0,0 +1,54 @@ +/* + Copyright The ORAS Authors. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package credentials + +import ( + "context" + "sync" + + "oras.land/oras-go/v2/registry/remote/auth" +) + +// memoryStore is a store that keeps credentials in memory. +type memoryStore struct { + store sync.Map +} + +// NewMemoryStore creates a new in-memory credentials store. +func NewMemoryStore() Store { + return &memoryStore{} +} + +// Get retrieves credentials from the store for the given server address. +func (ms *memoryStore) Get(_ context.Context, serverAddress string) (auth.Credential, error) { + cred, found := ms.store.Load(serverAddress) + if !found { + return auth.EmptyCredential, nil + } + return cred.(auth.Credential), nil +} + +// Put saves credentials into the store for the given server address. +func (ms *memoryStore) Put(_ context.Context, serverAddress string, cred auth.Credential) error { + ms.store.Store(serverAddress, cred) + return nil +} + +// Delete removes credentials from the store for the given server address. +func (ms *memoryStore) Delete(_ context.Context, serverAddress string) error { + ms.store.Delete(serverAddress) + return nil +} diff --git a/registry/remote/credentials/memory_store_test.go b/registry/remote/credentials/memory_store_test.go new file mode 100644 index 00000000..6f08c8fc --- /dev/null +++ b/registry/remote/credentials/memory_store_test.go @@ -0,0 +1,229 @@ +/* + Copyright The ORAS Authors. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package credentials + +import ( + "context" + "reflect" + "testing" + + "oras.land/oras-go/v2/registry/remote/auth" +) + +func TestMemoryStore_Get_notExistRecord(t *testing.T) { + ctx := context.Background() + ms := NewMemoryStore() + + serverAddress := "registry.example.com" + got, err := ms.Get(ctx, serverAddress) + if err != nil { + t.Errorf("MemoryStore.Get() error = %v", err) + return + } + if !reflect.DeepEqual(got, auth.EmptyCredential) { + t.Errorf("MemoryStore.Get() = %v, want %v", got, auth.EmptyCredential) + } +} + +func TestMemoryStore_Get_validRecord(t *testing.T) { + ctx := context.Background() + ms := NewMemoryStore().(*memoryStore) + + serverAddress := "registry.example.com" + want := auth.Credential{ + Username: "username", + Password: "password", + RefreshToken: "identity_token", + AccessToken: "registry_token", + } + ms.store.Store(serverAddress, want) + + got, err := ms.Get(ctx, serverAddress) + if err != nil { + t.Errorf("MemoryStore.Get() error = %v", err) + return + } + if !reflect.DeepEqual(got, want) { + t.Errorf("MemoryStore.Get() = %v, want %v", got, want) + } +} + +func TestMemoryStore_Put_addNew(t *testing.T) { + ctx := context.Background() + ms := NewMemoryStore() + + // Test Put + server1 := "registry.example.com" + cred1 := auth.Credential{ + Username: "username", + Password: "password", + RefreshToken: "identity_token", + AccessToken: "registry_token", + } + if err := ms.Put(ctx, server1, cred1); err != nil { + t.Errorf("MemoryStore.Put() error = %v", err) + return + } + + server2 := "registry2.example.com" + cred2 := auth.Credential{ + Username: "username2", + Password: "password2", + RefreshToken: "identity_token2", + AccessToken: "registry_token2", + } + if err := ms.Put(ctx, server2, cred2); err != nil { + t.Errorf("MemoryStore.Put() error = %v", err) + return + } + + // Verify Content + got1, err := ms.Get(ctx, server1) + if err != nil { + t.Errorf("MemoryStore.Get() error = %v", err) + return + } + if !reflect.DeepEqual(got1, cred1) { + t.Errorf("MemoryStore.Get() = %v, want %v", got1, cred1) + return + } + + got2, err := ms.Get(ctx, server2) + if err != nil { + t.Errorf("MemoryStore.Get() error = %v", err) + return + } + if !reflect.DeepEqual(got2, cred2) { + t.Errorf("MemoryStore.Get() = %v, want %v", got2, cred2) + return + } +} + +func TestMemoryStore_Put_update(t *testing.T) { + ctx := context.Background() + ms := NewMemoryStore() + + // Test Put + serverAddress := "registry.example.com" + cred1 := auth.Credential{ + Username: "username", + Password: "password", + RefreshToken: "identity_token", + AccessToken: "registry_token", + } + if err := ms.Put(ctx, serverAddress, cred1); err != nil { + t.Errorf("MemoryStore.Put() error = %v", err) + return + } + + cred2 := auth.Credential{ + Username: "username2", + Password: "password2", + RefreshToken: "identity_token2", + AccessToken: "registry_token2", + } + if err := ms.Put(ctx, serverAddress, cred2); err != nil { + t.Errorf("MemoryStore.Put() error = %v", err) + return + } + + got, err := ms.Get(ctx, serverAddress) + if err != nil { + t.Errorf("MemoryStore.Get() error = %v", err) + return + } + if !reflect.DeepEqual(got, cred2) { + t.Errorf("MemoryStore.Get() = %v, want %v", got, cred2) + return + } +} + +func TestMemoryStore_Delete_existRecord(t *testing.T) { + ctx := context.Background() + ms := NewMemoryStore() + + // Test Put + serverAddress := "registry.example.com" + cred := auth.Credential{ + Username: "username", + Password: "password", + RefreshToken: "identity_token", + AccessToken: "registry_token", + } + if err := ms.Put(ctx, serverAddress, cred); err != nil { + t.Errorf("MemoryStore.Put() error = %v", err) + return + } + + // Test Get + got, err := ms.Get(ctx, serverAddress) + if err != nil { + t.Errorf("MemoryStore.Get() error = %v", err) + return + } + if !reflect.DeepEqual(got, cred) { + t.Errorf("MemoryStore.Get(%s) = %v, want %v", serverAddress, got, cred) + return + } + + // Test Delete + if err := ms.Delete(ctx, serverAddress); err != nil { + t.Errorf("MemoryStore.Delete() error = %v", err) + return + } + + // Test Get again + got, err = ms.Get(ctx, serverAddress) + if err != nil { + t.Errorf("MemoryStore.Get() error = %v", err) + return + } + if !reflect.DeepEqual(got, auth.EmptyCredential) { + t.Errorf("MemoryStore.Get() = %v, want %v", got, auth.EmptyCredential) + return + } +} + +func TestMemoryStore_Delete_notExistRecord(t *testing.T) { + ctx := context.Background() + ms := NewMemoryStore() + + // Test Put + serverAddress := "registry.example.com" + cred := auth.Credential{ + Username: "username", + Password: "password", + RefreshToken: "identity_token", + AccessToken: "registry_token", + } + if err := ms.Put(ctx, serverAddress, cred); err != nil { + t.Errorf("MemoryStore.Put() error = %v", err) + return + } + + // Test Delete + if err := ms.Delete(ctx, serverAddress); err != nil { + t.Errorf("MemoryStore.Delete() error = %v", err) + return + } + + // Test Delete again + // Expect no error if target record does not exist + if err := ms.Delete(ctx, serverAddress); err != nil { + t.Errorf("MemoryStore.Delete() error = %v", err) + return + } +} diff --git a/registry/remote/credentials/native_store.go b/registry/remote/credentials/native_store.go new file mode 100644 index 00000000..9f4c7f74 --- /dev/null +++ b/registry/remote/credentials/native_store.go @@ -0,0 +1,139 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package credentials + +import ( + "bytes" + "context" + "encoding/json" + "os/exec" + "strings" + + "oras.land/oras-go/v2/registry/remote/auth" + "oras.land/oras-go/v2/registry/remote/credentials/internal/executer" +) + +const ( + remoteCredentialsPrefix = "docker-credential-" + emptyUsername = "" + errCredentialsNotFoundMessage = "credentials not found in native keychain" +) + +// dockerCredentials mimics how docker credential helper binaries store +// credential information. +// Reference: +// - https://docs.docker.com/engine/reference/commandline/login/#credential-helper-protocol +type dockerCredentials struct { + ServerURL string `json:"ServerURL"` + Username string `json:"Username"` + Secret string `json:"Secret"` +} + +// nativeStore implements a credentials store using native keychain to keep +// credentials secure. +type nativeStore struct { + exec executer.Executer +} + +// NewNativeStore creates a new native store that uses a remote helper program to +// manage credentials. +// +// The argument of NewNativeStore can be the native keychains +// ("wincred" for Windows, "pass" for linux and "osxkeychain" for macOS), +// or any program that follows the docker-credentials-helper protocol. +// +// Reference: +// - https://docs.docker.com/engine/reference/commandline/login#credentials-store +func NewNativeStore(helperSuffix string) Store { + return &nativeStore{ + exec: executer.New(remoteCredentialsPrefix + helperSuffix), + } +} + +// NewDefaultNativeStore returns a native store based on the platform-default +// docker credentials helper and a bool indicating if the native store is +// available. +// - Windows: "wincred" +// - Linux: "pass" or "secretservice" +// - macOS: "osxkeychain" +// +// Reference: +// - https://docs.docker.com/engine/reference/commandline/login/#credentials-store +func NewDefaultNativeStore() (Store, bool) { + if helper := getDefaultHelperSuffix(); helper != "" { + return NewNativeStore(helper), true + } + return nil, false +} + +// Get retrieves credentials from the store for the given server. +func (ns *nativeStore) Get(ctx context.Context, serverAddress string) (auth.Credential, error) { + var cred auth.Credential + out, err := ns.exec.Execute(ctx, strings.NewReader(serverAddress), "get") + if err != nil { + if err.Error() == errCredentialsNotFoundMessage { + // do not return an error if the credentials are not in the keychain. + return auth.EmptyCredential, nil + } + return auth.EmptyCredential, err + } + var dockerCred dockerCredentials + if err := json.Unmarshal(out, &dockerCred); err != nil { + return auth.EmptyCredential, err + } + // bearer auth is used if the username is "" + if dockerCred.Username == emptyUsername { + cred.RefreshToken = dockerCred.Secret + } else { + cred.Username = dockerCred.Username + cred.Password = dockerCred.Secret + } + return cred, nil +} + +// Put saves credentials into the store. +func (ns *nativeStore) Put(ctx context.Context, serverAddress string, cred auth.Credential) error { + dockerCred := &dockerCredentials{ + ServerURL: serverAddress, + Username: cred.Username, + Secret: cred.Password, + } + if cred.RefreshToken != "" { + dockerCred.Username = emptyUsername + dockerCred.Secret = cred.RefreshToken + } + credJSON, err := json.Marshal(dockerCred) + if err != nil { + return err + } + _, err = ns.exec.Execute(ctx, bytes.NewReader(credJSON), "store") + return err +} + +// Delete removes credentials from the store for the given server. +func (ns *nativeStore) Delete(ctx context.Context, serverAddress string) error { + _, err := ns.exec.Execute(ctx, strings.NewReader(serverAddress), "erase") + return err +} + +// getDefaultHelperSuffix returns the default credential helper suffix. +func getDefaultHelperSuffix() string { + platformDefault := getPlatformDefaultHelperSuffix() + if _, err := exec.LookPath(remoteCredentialsPrefix + platformDefault); err == nil { + return platformDefault + } + return "" +} diff --git a/registry/remote/credentials/native_store_darwin.go b/registry/remote/credentials/native_store_darwin.go new file mode 100644 index 00000000..1a9aca6f --- /dev/null +++ b/registry/remote/credentials/native_store_darwin.go @@ -0,0 +1,23 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package credentials + +// getPlatformDefaultHelperSuffix returns the platform default credential +// helper suffix. +// Reference: https://docs.docker.com/engine/reference/commandline/login/#default-behavior +func getPlatformDefaultHelperSuffix() string { + return "osxkeychain" +} diff --git a/registry/remote/credentials/native_store_generic.go b/registry/remote/credentials/native_store_generic.go new file mode 100644 index 00000000..5c7d4a3b --- /dev/null +++ b/registry/remote/credentials/native_store_generic.go @@ -0,0 +1,25 @@ +//go:build !windows && !darwin && !linux + +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package credentials + +// getPlatformDefaultHelperSuffix returns the platform default credential +// helper suffix. +// Reference: https://docs.docker.com/engine/reference/commandline/login/#default-behavior +func getPlatformDefaultHelperSuffix() string { + return "" +} diff --git a/internal/registryutil/auth.go b/registry/remote/credentials/native_store_linux.go similarity index 60% rename from internal/registryutil/auth.go rename to registry/remote/credentials/native_store_linux.go index 4a601f0c..f182923b 100644 --- a/internal/registryutil/auth.go +++ b/registry/remote/credentials/native_store_linux.go @@ -13,17 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. */ -package registryutil +package credentials -import ( - "context" +import "os/exec" - "oras.land/oras-go/v2/registry" - "oras.land/oras-go/v2/registry/remote/auth" -) +// getPlatformDefaultHelperSuffix returns the platform default credential +// helper suffix. +// Reference: https://docs.docker.com/engine/reference/commandline/login/#default-behavior +func getPlatformDefaultHelperSuffix() string { + if _, err := exec.LookPath("pass"); err == nil { + return "pass" + } -// WithScopeHint adds a hinted scope to the context. -func WithScopeHint(ctx context.Context, ref registry.Reference, actions ...string) context.Context { - scope := auth.ScopeRepository(ref.Repository, actions...) - return auth.AppendScopes(ctx, scope) + return "secretservice" } diff --git a/registry/remote/credentials/native_store_test.go b/registry/remote/credentials/native_store_test.go new file mode 100644 index 00000000..df465ff8 --- /dev/null +++ b/registry/remote/credentials/native_store_test.go @@ -0,0 +1,385 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package credentials + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "strings" + "testing" + + "oras.land/oras-go/v2/registry/remote/auth" + "oras.land/oras-go/v2/registry/remote/credentials/trace" +) + +const ( + basicAuthHost = "localhost:2333" + bearerAuthHost = "localhost:666" + exeErrorHost = "localhost:500/exeError" + jsonErrorHost = "localhost:500/jsonError" + noCredentialsHost = "localhost:404" + traceHost = "localhost:808" + testUsername = "test_username" + testPassword = "test_password" + testRefreshToken = "test_token" +) + +var ( + errCommandExited = fmt.Errorf("exited with error") + errExecute = fmt.Errorf("Execute failed") + errCredentialsNotFound = fmt.Errorf(errCredentialsNotFoundMessage) +) + +// testExecuter implements the Executer interface for testing purpose. +// It simulates interactions between the docker client and a remote +// credentials helper. +type testExecuter struct{} + +// Execute mocks the behavior of a credential helper binary. It returns responses +// and errors based on the input. +func (e *testExecuter) Execute(ctx context.Context, input io.Reader, action string) ([]byte, error) { + in, err := io.ReadAll(input) + if err != nil { + return nil, err + } + inS := string(in) + switch action { + case "get": + switch inS { + case basicAuthHost: + return []byte(`{"Username": "test_username", "Secret": "test_password"}`), nil + case bearerAuthHost: + return []byte(`{"Username": "", "Secret": "test_token"}`), nil + case exeErrorHost: + return []byte("Execute failed"), errExecute + case jsonErrorHost: + return []byte("json.Unmarshal failed"), nil + case noCredentialsHost: + return []byte("credentials not found"), errCredentialsNotFound + case traceHost: + traceHook := trace.ContextExecutableTrace(ctx) + if traceHook != nil { + if traceHook.ExecuteStart != nil { + traceHook.ExecuteStart("testExecuter", "get") + } + if traceHook.ExecuteDone != nil { + traceHook.ExecuteDone("testExecuter", "get", nil) + } + } + return []byte(`{"Username": "test_username", "Secret": "test_password"}`), nil + default: + return []byte("program failed"), errCommandExited + } + case "store": + var c dockerCredentials + err := json.NewDecoder(strings.NewReader(inS)).Decode(&c) + if err != nil { + return []byte("program failed"), errCommandExited + } + switch c.ServerURL { + case basicAuthHost, bearerAuthHost, exeErrorHost: + return nil, nil + case traceHost: + traceHook := trace.ContextExecutableTrace(ctx) + if traceHook != nil { + if traceHook.ExecuteStart != nil { + traceHook.ExecuteStart("testExecuter", "store") + } + if traceHook.ExecuteDone != nil { + traceHook.ExecuteDone("testExecuter", "store", nil) + } + } + return nil, nil + default: + return []byte("program failed"), errCommandExited + } + case "erase": + switch inS { + case basicAuthHost, bearerAuthHost: + return nil, nil + case traceHost: + traceHook := trace.ContextExecutableTrace(ctx) + if traceHook != nil { + if traceHook.ExecuteStart != nil { + traceHook.ExecuteStart("testExecuter", "erase") + } + if traceHook.ExecuteDone != nil { + traceHook.ExecuteDone("testExecuter", "erase", nil) + } + } + return nil, nil + default: + return []byte("program failed"), errCommandExited + } + } + return []byte(fmt.Sprintf("unknown argument %q with %q", action, inS)), errCommandExited +} + +func TestNativeStore_interface(t *testing.T) { + var ns interface{} = &nativeStore{} + if _, ok := ns.(Store); !ok { + t.Error("&NativeStore{} does not conform Store") + } +} + +func TestNativeStore_basicAuth(t *testing.T) { + ns := &nativeStore{ + &testExecuter{}, + } + // Put + err := ns.Put(context.Background(), basicAuthHost, auth.Credential{Username: testUsername, Password: testPassword}) + if err != nil { + t.Fatalf("basic auth test ns.Put fails: %v", err) + } + // Get + cred, err := ns.Get(context.Background(), basicAuthHost) + if err != nil { + t.Fatalf("basic auth test ns.Get fails: %v", err) + } + if cred.Username != testUsername { + t.Fatal("incorrect username") + } + if cred.Password != testPassword { + t.Fatal("incorrect password") + } + // Delete + err = ns.Delete(context.Background(), basicAuthHost) + if err != nil { + t.Fatalf("basic auth test ns.Delete fails: %v", err) + } +} + +func TestNativeStore_refreshToken(t *testing.T) { + ns := &nativeStore{ + &testExecuter{}, + } + // Put + err := ns.Put(context.Background(), bearerAuthHost, auth.Credential{RefreshToken: testRefreshToken}) + if err != nil { + t.Fatalf("refresh token test ns.Put fails: %v", err) + } + // Get + cred, err := ns.Get(context.Background(), bearerAuthHost) + if err != nil { + t.Fatalf("refresh token test ns.Get fails: %v", err) + } + if cred.Username != "" { + t.Fatalf("expect username to be empty, got %s", cred.Username) + } + if cred.RefreshToken != testRefreshToken { + t.Fatal("incorrect refresh token") + } + // Delete + err = ns.Delete(context.Background(), basicAuthHost) + if err != nil { + t.Fatalf("refresh token test ns.Delete fails: %v", err) + } +} + +func TestNativeStore_errorHandling(t *testing.T) { + ns := &nativeStore{ + &testExecuter{}, + } + // Get Error: Execute error + _, err := ns.Get(context.Background(), exeErrorHost) + if err != errExecute { + t.Fatalf("got error: %v, should get exeErr", err) + } + // Get Error: json.Unmarshal + _, err = ns.Get(context.Background(), jsonErrorHost) + if err == nil { + t.Fatalf("should get error from json.Unmarshal") + } + // Get: Should not return error when credentials are not found + _, err = ns.Get(context.Background(), noCredentialsHost) + if err != nil { + t.Fatalf("should not get error when no credentials are found") + } +} + +func TestNewDefaultNativeStore(t *testing.T) { + defaultHelper := getDefaultHelperSuffix() + wantOK := (defaultHelper != "") + + if _, ok := NewDefaultNativeStore(); ok != wantOK { + t.Errorf("NewDefaultNativeStore() = %v, want %v", ok, wantOK) + } +} + +func TestNativeStore_trace(t *testing.T) { + ns := &nativeStore{ + &testExecuter{}, + } + // create trace hooks that write to buffer + buffer := bytes.Buffer{} + traceHook := &trace.ExecutableTrace{ + ExecuteStart: func(executableName string, action string) { + buffer.WriteString(fmt.Sprintf("test trace, start the execution of executable %s with action %s ", executableName, action)) + }, + ExecuteDone: func(executableName string, action string, err error) { + buffer.WriteString(fmt.Sprintf("test trace, completed the execution of executable %s with action %s and got err %v", executableName, action, err)) + }, + } + ctx := trace.WithExecutableTrace(context.Background(), traceHook) + // Test ns.Put trace + err := ns.Put(ctx, traceHost, auth.Credential{Username: testUsername, Password: testPassword}) + if err != nil { + t.Fatalf("trace test ns.Put fails: %v", err) + } + bufferContent := buffer.String() + if bufferContent != "test trace, start the execution of executable testExecuter with action store test trace, completed the execution of executable testExecuter with action store and got err " { + t.Fatalf("incorrect buffer content: %s", bufferContent) + } + buffer.Reset() + // Test ns.Get trace + _, err = ns.Get(ctx, traceHost) + if err != nil { + t.Fatalf("trace test ns.Get fails: %v", err) + } + bufferContent = buffer.String() + if bufferContent != "test trace, start the execution of executable testExecuter with action get test trace, completed the execution of executable testExecuter with action get and got err " { + t.Fatalf("incorrect buffer content: %s", bufferContent) + } + buffer.Reset() + // Test ns.Delete trace + err = ns.Delete(ctx, traceHost) + if err != nil { + t.Fatalf("trace test ns.Delete fails: %v", err) + } + bufferContent = buffer.String() + if bufferContent != "test trace, start the execution of executable testExecuter with action erase test trace, completed the execution of executable testExecuter with action erase and got err " { + t.Fatalf("incorrect buffer content: %s", bufferContent) + } +} + +// This test ensures that a nil trace will not cause an error. +func TestNativeStore_noTrace(t *testing.T) { + ns := &nativeStore{ + &testExecuter{}, + } + // Put + err := ns.Put(context.Background(), traceHost, auth.Credential{Username: testUsername, Password: testPassword}) + if err != nil { + t.Fatalf("basic auth test ns.Put fails: %v", err) + } + // Get + cred, err := ns.Get(context.Background(), traceHost) + if err != nil { + t.Fatalf("basic auth test ns.Get fails: %v", err) + } + if cred.Username != testUsername { + t.Fatal("incorrect username") + } + if cred.Password != testPassword { + t.Fatal("incorrect password") + } + // Delete + err = ns.Delete(context.Background(), traceHost) + if err != nil { + t.Fatalf("basic auth test ns.Delete fails: %v", err) + } +} + +// This test ensures that an empty trace will not cause an error. +func TestNativeStore_emptyTrace(t *testing.T) { + ns := &nativeStore{ + &testExecuter{}, + } + traceHook := &trace.ExecutableTrace{} + ctx := trace.WithExecutableTrace(context.Background(), traceHook) + // Put + err := ns.Put(ctx, traceHost, auth.Credential{Username: testUsername, Password: testPassword}) + if err != nil { + t.Fatalf("basic auth test ns.Put fails: %v", err) + } + // Get + cred, err := ns.Get(ctx, traceHost) + if err != nil { + t.Fatalf("basic auth test ns.Get fails: %v", err) + } + if cred.Username != testUsername { + t.Fatal("incorrect username") + } + if cred.Password != testPassword { + t.Fatal("incorrect password") + } + // Delete + err = ns.Delete(ctx, traceHost) + if err != nil { + t.Fatalf("basic auth test ns.Delete fails: %v", err) + } +} + +func TestNativeStore_multipleTrace(t *testing.T) { + ns := &nativeStore{ + &testExecuter{}, + } + // create trace hooks that write to buffer + buffer := bytes.Buffer{} + trace1 := &trace.ExecutableTrace{ + ExecuteStart: func(executableName string, action string) { + buffer.WriteString(fmt.Sprintf("trace 1 start %s, %s ", executableName, action)) + }, + ExecuteDone: func(executableName string, action string, err error) { + buffer.WriteString(fmt.Sprintf("trace 1 done %s, %s, %v ", executableName, action, err)) + }, + } + ctx := context.Background() + ctx = trace.WithExecutableTrace(ctx, trace1) + trace2 := &trace.ExecutableTrace{ + ExecuteStart: func(executableName string, action string) { + buffer.WriteString(fmt.Sprintf("trace 2 start %s, %s ", executableName, action)) + }, + ExecuteDone: func(executableName string, action string, err error) { + buffer.WriteString(fmt.Sprintf("trace 2 done %s, %s, %v ", executableName, action, err)) + }, + } + ctx = trace.WithExecutableTrace(ctx, trace2) + trace3 := &trace.ExecutableTrace{} + ctx = trace.WithExecutableTrace(ctx, trace3) + // Test ns.Put trace + err := ns.Put(ctx, traceHost, auth.Credential{Username: testUsername, Password: testPassword}) + if err != nil { + t.Fatalf("trace test ns.Put fails: %v", err) + } + bufferContent := buffer.String() + if bufferContent != "trace 2 start testExecuter, store trace 1 start testExecuter, store trace 2 done testExecuter, store, trace 1 done testExecuter, store, " { + t.Fatalf("incorrect buffer content: %s", bufferContent) + } + buffer.Reset() + // Test ns.Get trace + _, err = ns.Get(ctx, traceHost) + if err != nil { + t.Fatalf("trace test ns.Get fails: %v", err) + } + bufferContent = buffer.String() + if bufferContent != "trace 2 start testExecuter, get trace 1 start testExecuter, get trace 2 done testExecuter, get, trace 1 done testExecuter, get, " { + t.Fatalf("incorrect buffer content: %s", bufferContent) + } + buffer.Reset() + // Test ns.Delete trace + err = ns.Delete(ctx, traceHost) + if err != nil { + t.Fatalf("trace test ns.Delete fails: %v", err) + } + bufferContent = buffer.String() + if bufferContent != "trace 2 start testExecuter, erase trace 1 start testExecuter, erase trace 2 done testExecuter, erase, trace 1 done testExecuter, erase, " { + t.Fatalf("incorrect buffer content: %s", bufferContent) + } +} diff --git a/registry/remote/credentials/native_store_windows.go b/registry/remote/credentials/native_store_windows.go new file mode 100644 index 00000000..e334cc79 --- /dev/null +++ b/registry/remote/credentials/native_store_windows.go @@ -0,0 +1,23 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package credentials + +// getPlatformDefaultHelperSuffix returns the platform default credential +// helper suffix. +// Reference: https://docs.docker.com/engine/reference/commandline/login/#default-behavior +func getPlatformDefaultHelperSuffix() string { + return "wincred" +} diff --git a/registry/remote/credentials/registry.go b/registry/remote/credentials/registry.go new file mode 100644 index 00000000..39735b77 --- /dev/null +++ b/registry/remote/credentials/registry.go @@ -0,0 +1,102 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package credentials + +import ( + "context" + "errors" + "fmt" + + "oras.land/oras-go/v2/registry/remote" + "oras.land/oras-go/v2/registry/remote/auth" +) + +// ErrClientTypeUnsupported is thrown by Login() when the registry's client type +// is not supported. +var ErrClientTypeUnsupported = errors.New("client type not supported") + +// Login provides the login functionality with the given credentials. The target +// registry's client should be nil or of type *auth.Client. Login uses +// a client local to the function and will not modify the original client of +// the registry. +func Login(ctx context.Context, store Store, reg *remote.Registry, cred auth.Credential) error { + // create a clone of the original registry for login purpose + regClone := *reg + // we use the original client if applicable, otherwise use a default client + var authClient auth.Client + if reg.Client == nil { + authClient = *auth.DefaultClient + authClient.Cache = nil // no cache + } else if client, ok := reg.Client.(*auth.Client); ok { + authClient = *client + } else { + return ErrClientTypeUnsupported + } + regClone.Client = &authClient + // update credentials with the client + authClient.Credential = auth.StaticCredential(reg.Reference.Registry, cred) + // validate and store the credential + if err := regClone.Ping(ctx); err != nil { + return fmt.Errorf("failed to validate the credentials for %s: %w", regClone.Reference.Registry, err) + } + hostname := ServerAddressFromRegistry(regClone.Reference.Registry) + if err := store.Put(ctx, hostname, cred); err != nil { + return fmt.Errorf("failed to store the credentials for %s: %w", hostname, err) + } + return nil +} + +// Logout provides the logout functionality given the registry name. +func Logout(ctx context.Context, store Store, registryName string) error { + registryName = ServerAddressFromRegistry(registryName) + if err := store.Delete(ctx, registryName); err != nil { + return fmt.Errorf("failed to delete the credential for %s: %w", registryName, err) + } + return nil +} + +// Credential returns a Credential() function that can be used by auth.Client. +func Credential(store Store) auth.CredentialFunc { + return func(ctx context.Context, hostport string) (auth.Credential, error) { + hostport = ServerAddressFromHostname(hostport) + if hostport == "" { + return auth.EmptyCredential, nil + } + return store.Get(ctx, hostport) + } +} + +// ServerAddressFromRegistry maps a registry to a server address, which is used as +// a key for credentials store. The Docker CLI expects that the credentials of +// the registry 'docker.io' will be added under the key "https://index.docker.io/v1/". +// See: https://github.com/moby/moby/blob/v24.0.2/registry/config.go#L25-L48 +func ServerAddressFromRegistry(registry string) string { + if registry == "docker.io" { + return "https://index.docker.io/v1/" + } + return registry +} + +// ServerAddressFromHostname maps a hostname to a server address, which is used as +// a key for credentials store. It is expected that the traffic targetting the +// host "registry-1.docker.io" will be redirected to "https://index.docker.io/v1/". +// See: https://github.com/moby/moby/blob/v24.0.2/registry/config.go#L25-L48 +func ServerAddressFromHostname(hostname string) string { + if hostname == "registry-1.docker.io" { + return "https://index.docker.io/v1/" + } + return hostname +} diff --git a/registry/remote/credentials/registry_test.go b/registry/remote/credentials/registry_test.go new file mode 100644 index 00000000..1b5dce31 --- /dev/null +++ b/registry/remote/credentials/registry_test.go @@ -0,0 +1,247 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package credentials + +import ( + "context" + "encoding/base64" + "errors" + "net/http" + "net/http/httptest" + "net/url" + "reflect" + "testing" + + "oras.land/oras-go/v2/registry/remote" + "oras.land/oras-go/v2/registry/remote/auth" +) + +// testStore implements the Store interface, used for testing purpose. +type testStore struct { + storage map[string]auth.Credential +} + +func (t *testStore) Get(ctx context.Context, serverAddress string) (auth.Credential, error) { + return t.storage[serverAddress], nil +} + +func (t *testStore) Put(ctx context.Context, serverAddress string, cred auth.Credential) error { + if len(t.storage) == 0 { + t.storage = make(map[string]auth.Credential) + } + t.storage[serverAddress] = cred + return nil +} + +func (t *testStore) Delete(ctx context.Context, serverAddress string) error { + delete(t.storage, serverAddress) + return nil +} + +func TestLogin(t *testing.T) { + // create a test registry + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wantedAuthHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(testUsername+":"+testPassword)) + authHeader := r.Header.Get("Authorization") + if authHeader != wantedAuthHeader { + w.Header().Set("Www-Authenticate", `Basic realm="Test Server"`) + w.WriteHeader(http.StatusUnauthorized) + } + })) + defer ts.Close() + uri, _ := url.Parse(ts.URL) + reg, err := remote.NewRegistry(uri.Host) + if err != nil { + t.Fatalf("cannot create test registry: %v", err) + } + reg.PlainHTTP = true + // create a test store + s := &testStore{} + tests := []struct { + name string + ctx context.Context + registry *remote.Registry + cred auth.Credential + wantErr bool + }{ + { + name: "login succeeds", + ctx: context.Background(), + cred: auth.Credential{Username: testUsername, Password: testPassword}, + wantErr: false, + }, + { + name: "login fails (incorrect password)", + ctx: context.Background(), + cred: auth.Credential{Username: testUsername, Password: "whatever"}, + wantErr: true, + }, + { + name: "login fails (nil context makes remote.Ping fails)", + ctx: nil, + cred: auth.Credential{Username: testUsername, Password: testPassword}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // login to test registry + err := Login(tt.ctx, s, reg, tt.cred) + if (err != nil) != tt.wantErr { + t.Fatalf("Login() error = %v, wantErr %v", err, tt.wantErr) + } + if err != nil { + return + } + if got := s.storage[reg.Reference.Registry]; !reflect.DeepEqual(got, tt.cred) { + t.Fatalf("Stored credential = %v, want %v", got, tt.cred) + } + s.Delete(tt.ctx, reg.Reference.Registry) + }) + } +} + +func TestLogin_unsupportedClient(t *testing.T) { + var testClient http.Client + reg, err := remote.NewRegistry("whatever") + if err != nil { + t.Fatalf("cannot create test registry: %v", err) + } + reg.PlainHTTP = true + reg.Client = &testClient + ctx := context.Background() + + s := &testStore{} + cred := auth.EmptyCredential + err = Login(ctx, s, reg, cred) + if wantErr := ErrClientTypeUnsupported; !errors.Is(err, wantErr) { + t.Errorf("Login() error = %v, wantErr %v", err, wantErr) + } +} + +func TestLogout(t *testing.T) { + // create a test store + s := &testStore{} + s.storage = map[string]auth.Credential{ + "localhost:2333": {Username: "test_user", Password: "test_word"}, + "https://index.docker.io/v1/": {Username: "user", Password: "word"}, + } + tests := []struct { + name string + ctx context.Context + store Store + registryName string + wantErr bool + }{ + { + name: "logout of regular registry", + ctx: context.Background(), + registryName: "localhost:2333", + wantErr: false, + }, + { + name: "logout of docker.io", + ctx: context.Background(), + registryName: "docker.io", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := Logout(tt.ctx, s, tt.registryName); (err != nil) != tt.wantErr { + t.Fatalf("Logout() error = %v, wantErr %v", err, tt.wantErr) + } + if s.storage[tt.registryName] != auth.EmptyCredential { + t.Error("Credentials are not deleted") + } + }) + } +} + +func Test_mapHostname(t *testing.T) { + tests := []struct { + name string + host string + want string + }{ + { + name: "map docker.io to https://index.docker.io/v1/", + host: "docker.io", + want: "https://index.docker.io/v1/", + }, + { + name: "do not map other host names", + host: "localhost:2333", + want: "localhost:2333", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ServerAddressFromRegistry(tt.host); got != tt.want { + t.Errorf("mapHostname() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCredential(t *testing.T) { + // create a test store + s := &testStore{} + s.storage = map[string]auth.Credential{ + "localhost:2333": {Username: "test_user", Password: "test_word"}, + "https://index.docker.io/v1/": {Username: "user", Password: "word"}, + } + // create a test client using Credential + testClient := &auth.Client{} + testClient.Credential = Credential(s) + tests := []struct { + name string + registry string + wantCredential auth.Credential + }{ + { + name: "get credentials for localhost:2333", + registry: "localhost:2333", + wantCredential: auth.Credential{Username: "test_user", Password: "test_word"}, + }, + { + name: "get credentials for registry-1.docker.io", + registry: "registry-1.docker.io", + wantCredential: auth.Credential{Username: "user", Password: "word"}, + }, + { + name: "get credentials for a registry not stored", + registry: "localhost:6666", + wantCredential: auth.EmptyCredential, + }, + { + name: "get credentials for an empty string", + registry: "", + wantCredential: auth.EmptyCredential, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := testClient.Credential(context.Background(), tt.registry) + if err != nil { + t.Errorf("could not get credential: %v", err) + } + if !reflect.DeepEqual(got, tt.wantCredential) { + t.Errorf("Credential() = %v, want %v", got, tt.wantCredential) + } + }) + } +} diff --git a/registry/remote/credentials/store.go b/registry/remote/credentials/store.go new file mode 100644 index 00000000..973e0e67 --- /dev/null +++ b/registry/remote/credentials/store.go @@ -0,0 +1,257 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package credentials supports reading, saving, and removing credentials from +// Docker configuration files and external credential stores that follow +// the Docker credential helper protocol. +// +// Reference: https://docs.docker.com/engine/reference/commandline/login/#credential-stores +package credentials + +import ( + "context" + "fmt" + "os" + "path/filepath" + "sync" + + "oras.land/oras-go/v2/registry/remote/auth" + "oras.land/oras-go/v2/registry/remote/credentials/internal/config" +) + +const ( + dockerConfigDirEnv = "DOCKER_CONFIG" + dockerConfigFileDir = ".docker" + dockerConfigFileName = "config.json" +) + +// Store is the interface that any credentials store must implement. +type Store interface { + // Get retrieves credentials from the store for the given server address. + Get(ctx context.Context, serverAddress string) (auth.Credential, error) + // Put saves credentials into the store for the given server address. + Put(ctx context.Context, serverAddress string, cred auth.Credential) error + // Delete removes credentials from the store for the given server address. + Delete(ctx context.Context, serverAddress string) error +} + +// DynamicStore dynamically determines which store to use based on the settings +// in the config file. +type DynamicStore struct { + config *config.Config + options StoreOptions + detectedCredsStore string + setCredsStoreOnce sync.Once +} + +// StoreOptions provides options for NewStore. +type StoreOptions struct { + // AllowPlaintextPut allows saving credentials in plaintext in the config + // file. + // - If AllowPlaintextPut is set to false (default value), Put() will + // return an error when native store is not available. + // - If AllowPlaintextPut is set to true, Put() will save credentials in + // plaintext in the config file when native store is not available. + AllowPlaintextPut bool + + // DetectDefaultNativeStore enables detecting the platform-default native + // credentials store when the config file has no authentication information. + // + // If DetectDefaultNativeStore is set to true, the store will detect and set + // the default native credentials store in the "credsStore" field of the + // config file. + // - Windows: "wincred" + // - Linux: "pass" or "secretservice" + // - macOS: "osxkeychain" + // + // References: + // - https://docs.docker.com/engine/reference/commandline/login/#credentials-store + // - https://docs.docker.com/engine/reference/commandline/cli/#docker-cli-configuration-file-configjson-properties + DetectDefaultNativeStore bool +} + +// NewStore returns a Store based on the given configuration file. +// +// For Get(), Put() and Delete(), the returned Store will dynamically determine +// which underlying credentials store to use for the given server address. +// The underlying credentials store is determined in the following order: +// 1. Native server-specific credential helper +// 2. Native credentials store +// 3. The plain-text config file itself +// +// References: +// - https://docs.docker.com/engine/reference/commandline/login/#credentials-store +// - https://docs.docker.com/engine/reference/commandline/cli/#docker-cli-configuration-file-configjson-properties +func NewStore(configPath string, opts StoreOptions) (*DynamicStore, error) { + cfg, err := config.Load(configPath) + if err != nil { + return nil, err + } + ds := &DynamicStore{ + config: cfg, + options: opts, + } + if opts.DetectDefaultNativeStore && !cfg.IsAuthConfigured() { + // no authentication configured, detect the default credentials store + ds.detectedCredsStore = getDefaultHelperSuffix() + } + return ds, nil +} + +// NewStoreFromDocker returns a Store based on the default docker config file. +// - If the $DOCKER_CONFIG environment variable is set, +// $DOCKER_CONFIG/config.json will be used. +// - Otherwise, the default location $HOME/.docker/config.json will be used. +// +// NewStoreFromDocker internally calls [NewStore]. +// +// References: +// - https://docs.docker.com/engine/reference/commandline/cli/#configuration-files +// - https://docs.docker.com/engine/reference/commandline/cli/#change-the-docker-directory +func NewStoreFromDocker(opt StoreOptions) (*DynamicStore, error) { + configPath, err := getDockerConfigPath() + if err != nil { + return nil, err + } + return NewStore(configPath, opt) +} + +// Get retrieves credentials from the store for the given server address. +func (ds *DynamicStore) Get(ctx context.Context, serverAddress string) (auth.Credential, error) { + return ds.getStore(serverAddress).Get(ctx, serverAddress) +} + +// Put saves credentials into the store for the given server address. +// Put returns ErrPlaintextPutDisabled if native store is not available and +// [StoreOptions].AllowPlaintextPut is set to false. +func (ds *DynamicStore) Put(ctx context.Context, serverAddress string, cred auth.Credential) (returnErr error) { + if err := ds.getStore(serverAddress).Put(ctx, serverAddress, cred); err != nil { + return err + } + // save the detected creds store back to the config file on first put + ds.setCredsStoreOnce.Do(func() { + if ds.detectedCredsStore != "" { + if err := ds.config.SetCredentialsStore(ds.detectedCredsStore); err != nil { + returnErr = fmt.Errorf("failed to set credsStore: %w", err) + } + } + }) + return returnErr +} + +// Delete removes credentials from the store for the given server address. +func (ds *DynamicStore) Delete(ctx context.Context, serverAddress string) error { + return ds.getStore(serverAddress).Delete(ctx, serverAddress) +} + +// IsAuthConfigured returns whether there is authentication configured in the +// config file or not. +// +// IsAuthConfigured returns true when: +// - The "credsStore" field is not empty +// - Or the "credHelpers" field is not empty +// - Or there is any entry in the "auths" field +func (ds *DynamicStore) IsAuthConfigured() bool { + return ds.config.IsAuthConfigured() +} + +// getHelperSuffix returns the credential helper suffix for the given server +// address. +func (ds *DynamicStore) getHelperSuffix(serverAddress string) string { + // 1. Look for a server-specific credential helper first + if helper := ds.config.GetCredentialHelper(serverAddress); helper != "" { + return helper + } + // 2. Then look for the configured native store + if credsStore := ds.config.CredentialsStore(); credsStore != "" { + return credsStore + } + // 3. Use the detected default store + return ds.detectedCredsStore +} + +// getStore returns a store for the given server address. +func (ds *DynamicStore) getStore(serverAddress string) Store { + if helper := ds.getHelperSuffix(serverAddress); helper != "" { + return NewNativeStore(helper) + } + + fs := newFileStore(ds.config) + fs.DisablePut = !ds.options.AllowPlaintextPut + return fs +} + +// getDockerConfigPath returns the path to the default docker config file. +func getDockerConfigPath() (string, error) { + // first try the environment variable + configDir := os.Getenv(dockerConfigDirEnv) + if configDir == "" { + // then try home directory + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to get user home directory: %w", err) + } + configDir = filepath.Join(homeDir, dockerConfigFileDir) + } + return filepath.Join(configDir, dockerConfigFileName), nil +} + +// storeWithFallbacks is a store that has multiple fallback stores. +type storeWithFallbacks struct { + stores []Store +} + +// NewStoreWithFallbacks returns a new store based on the given stores. +// - Get() searches the primary and the fallback stores +// for the credentials and returns when it finds the +// credentials in any of the stores. +// - Put() saves the credentials into the primary store. +// - Delete() deletes the credentials from the primary store. +func NewStoreWithFallbacks(primary Store, fallbacks ...Store) Store { + if len(fallbacks) == 0 { + return primary + } + return &storeWithFallbacks{ + stores: append([]Store{primary}, fallbacks...), + } +} + +// Get retrieves credentials from the StoreWithFallbacks for the given server. +// It searches the primary and the fallback stores for the credentials of serverAddress +// and returns when it finds the credentials in any of the stores. +func (sf *storeWithFallbacks) Get(ctx context.Context, serverAddress string) (auth.Credential, error) { + for _, s := range sf.stores { + cred, err := s.Get(ctx, serverAddress) + if err != nil { + return auth.EmptyCredential, err + } + if cred != auth.EmptyCredential { + return cred, nil + } + } + return auth.EmptyCredential, nil +} + +// Put saves credentials into the StoreWithFallbacks. It puts +// the credentials into the primary store. +func (sf *storeWithFallbacks) Put(ctx context.Context, serverAddress string, cred auth.Credential) error { + return sf.stores[0].Put(ctx, serverAddress, cred) +} + +// Delete removes credentials from the StoreWithFallbacks for the given server. +// It deletes the credentials from the primary store. +func (sf *storeWithFallbacks) Delete(ctx context.Context, serverAddress string) error { + return sf.stores[0].Delete(ctx, serverAddress) +} diff --git a/registry/remote/credentials/store_test.go b/registry/remote/credentials/store_test.go new file mode 100644 index 00000000..285e2796 --- /dev/null +++ b/registry/remote/credentials/store_test.go @@ -0,0 +1,982 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package credentials + +import ( + "context" + "encoding/json" + "errors" + "os" + "path/filepath" + "reflect" + "testing" + + "oras.land/oras-go/v2/registry/remote/auth" + "oras.land/oras-go/v2/registry/remote/credentials/internal/config/configtest" +) + +type badStore struct{} + +var errBadStore = errors.New("bad store!") + +// Get retrieves credentials from the store for the given server address. +func (s *badStore) Get(ctx context.Context, serverAddress string) (auth.Credential, error) { + return auth.EmptyCredential, errBadStore +} + +// Put saves credentials into the store for the given server address. +func (s *badStore) Put(ctx context.Context, serverAddress string, cred auth.Credential) error { + return errBadStore +} + +// Delete removes credentials from the store for the given server address. +func (s *badStore) Delete(ctx context.Context, serverAddress string) error { + return errBadStore +} + +func Test_DynamicStore_IsAuthConfigured(t *testing.T) { + tempDir := t.TempDir() + + tests := []struct { + name string + fileName string + shouldCreateFile bool + cfg configtest.Config + want bool + }{ + { + name: "not existing file", + fileName: "config.json", + shouldCreateFile: false, + cfg: configtest.Config{}, + want: false, + }, + { + name: "no auth", + fileName: "config.json", + shouldCreateFile: true, + cfg: configtest.Config{ + SomeConfigField: 123, + }, + want: false, + }, + { + name: "empty auths exist", + fileName: "empty_auths.json", + shouldCreateFile: true, + cfg: configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{}, + }, + want: false, + }, + { + name: "auths exist, but no credential", + fileName: "no_cred_auths.json", + shouldCreateFile: true, + cfg: configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{ + "test.example.com": {}, + }, + }, + want: true, + }, + { + name: "auths exist", + fileName: "auths.json", + shouldCreateFile: true, + cfg: configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{ + "test.example.com": { + Auth: "dXNlcm5hbWU6cGFzc3dvcmQ=", + }, + }, + }, + want: true, + }, + { + name: "credsStore exists", + fileName: "credsStore.json", + shouldCreateFile: true, + cfg: configtest.Config{ + CredentialsStore: "teststore", + }, + want: true, + }, + { + name: "empty credHelpers exist", + fileName: "empty_credsStore.json", + shouldCreateFile: true, + cfg: configtest.Config{ + CredentialHelpers: map[string]string{}, + }, + want: false, + }, + { + name: "credHelpers exist", + fileName: "credsStore.json", + shouldCreateFile: true, + cfg: configtest.Config{ + CredentialHelpers: map[string]string{ + "test.example.com": "testhelper", + }, + }, + want: true, + }, + { + name: "all exist", + fileName: "credsStore.json", + shouldCreateFile: true, + cfg: configtest.Config{ + SomeConfigField: 123, + AuthConfigs: map[string]configtest.AuthConfig{ + "test.example.com": {}, + }, + CredentialsStore: "teststore", + CredentialHelpers: map[string]string{ + "test.example.com": "testhelper", + }, + }, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // prepare test content + configPath := filepath.Join(tempDir, tt.fileName) + if tt.shouldCreateFile { + jsonStr, err := json.Marshal(tt.cfg) + if err != nil { + t.Fatalf("failed to marshal config: %v", err) + } + if err := os.WriteFile(configPath, jsonStr, 0666); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + } + + ds, err := NewStore(configPath, StoreOptions{}) + if err != nil { + t.Fatal("newStore() error =", err) + } + if got := ds.IsAuthConfigured(); got != tt.want { + t.Errorf("DynamicStore.IsAuthConfigured() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_DynamicStore_authConfigured(t *testing.T) { + // prepare test content + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "auth_configured.json") + config := configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{ + "xxx": {}, + }, + SomeConfigField: 123, + } + jsonStr, err := json.Marshal(config) + if err != nil { + t.Fatalf("failed to marshal config: %v", err) + } + if err := os.WriteFile(configPath, jsonStr, 0666); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + ds, err := NewStore(configPath, StoreOptions{AllowPlaintextPut: true}) + if err != nil { + t.Fatal("NewStore() error =", err) + } + + // test IsAuthConfigured + authConfigured := ds.IsAuthConfigured() + if want := true; authConfigured != want { + t.Errorf("DynamicStore.IsAuthConfigured() = %v, want %v", authConfigured, want) + } + + serverAddr := "test.example.com" + cred := auth.Credential{ + Username: "username", + Password: "password", + } + ctx := context.Background() + + // test put + if err := ds.Put(ctx, serverAddr, cred); err != nil { + t.Fatal("DynamicStore.Get() error =", err) + } + // Put() should not set detected store back to config + if got := ds.detectedCredsStore; got != "" { + t.Errorf("ds.detectedCredsStore = %v, want empty", got) + } + if got := ds.config.CredentialsStore(); got != "" { + t.Errorf("ds.config.CredentialsStore() = %v, want empty", got) + } + + // test get + got, err := ds.Get(ctx, serverAddr) + if err != nil { + t.Fatal("DynamicStore.Get() error =", err) + } + if want := cred; got != want { + t.Errorf("DynamicStore.Get() = %v, want %v", got, want) + } + + // test delete + err = ds.Delete(ctx, serverAddr) + if err != nil { + t.Fatal("DynamicStore.Delete() error =", err) + } + + // verify delete + got, err = ds.Get(ctx, serverAddr) + if err != nil { + t.Fatal("DynamicStore.Get() error =", err) + } + if want := auth.EmptyCredential; got != want { + t.Errorf("DynamicStore.Get() = %v, want %v", got, want) + } +} + +func Test_DynamicStore_authConfigured_DetectDefaultNativeStore(t *testing.T) { + // prepare test content + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "auth_configured.json") + config := configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{ + "xxx": {}, + }, + SomeConfigField: 123, + } + jsonStr, err := json.Marshal(config) + if err != nil { + t.Fatalf("failed to marshal config: %v", err) + } + if err := os.WriteFile(configPath, jsonStr, 0666); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + opts := StoreOptions{ + AllowPlaintextPut: true, + DetectDefaultNativeStore: true, + } + ds, err := NewStore(configPath, opts) + if err != nil { + t.Fatal("NewStore() error =", err) + } + + // test IsAuthConfigured + authConfigured := ds.IsAuthConfigured() + if want := true; authConfigured != want { + t.Errorf("DynamicStore.IsAuthConfigured() = %v, want %v", authConfigured, want) + } + + serverAddr := "test.example.com" + cred := auth.Credential{ + Username: "username", + Password: "password", + } + ctx := context.Background() + + // test put + if err := ds.Put(ctx, serverAddr, cred); err != nil { + t.Fatal("DynamicStore.Get() error =", err) + } + // Put() should not set detected store back to config + if got := ds.detectedCredsStore; got != "" { + t.Errorf("ds.detectedCredsStore = %v, want empty", got) + } + if got := ds.config.CredentialsStore(); got != "" { + t.Errorf("ds.config.CredentialsStore() = %v, want empty", got) + } + + // test get + got, err := ds.Get(ctx, serverAddr) + if err != nil { + t.Fatal("DynamicStore.Get() error =", err) + } + if want := cred; got != want { + t.Errorf("DynamicStore.Get() = %v, want %v", got, want) + } + + // test delete + err = ds.Delete(ctx, serverAddr) + if err != nil { + t.Fatal("DynamicStore.Delete() error =", err) + } + + // verify delete + got, err = ds.Get(ctx, serverAddr) + if err != nil { + t.Fatal("DynamicStore.Get() error =", err) + } + if want := auth.EmptyCredential; got != want { + t.Errorf("DynamicStore.Get() = %v, want %v", got, want) + } +} + +func Test_DynamicStore_noAuthConfigured(t *testing.T) { + // prepare test content + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "no_auth_configured.json") + cfg := configtest.Config{ + SomeConfigField: 123, + } + jsonStr, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("failed to marshal config: %v", err) + } + if err := os.WriteFile(configPath, jsonStr, 0666); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + ds, err := NewStore(configPath, StoreOptions{AllowPlaintextPut: true}) + if err != nil { + t.Fatal("NewStore() error =", err) + } + + // test IsAuthConfigured + authConfigured := ds.IsAuthConfigured() + if want := false; authConfigured != want { + t.Errorf("DynamicStore.IsAuthConfigured() = %v, want %v", authConfigured, want) + } + + serverAddr := "test.example.com" + cred := auth.Credential{ + Username: "username", + Password: "password", + } + ctx := context.Background() + + // Get() should not set detected store back to config + if _, err := ds.Get(ctx, serverAddr); err != nil { + t.Fatal("DynamicStore.Get() error =", err) + } + + // test put + if err := ds.Put(ctx, serverAddr, cred); err != nil { + t.Fatal("DynamicStore.Put() error =", err) + } + // Put() should not set detected store back to config + if got := ds.detectedCredsStore; got != "" { + t.Errorf("ds.detectedCredsStore = %v, want empty", got) + } + if got := ds.config.CredentialsStore(); got != "" { + t.Errorf("ds.config.CredentialsStore() = %v, want empty", got) + } + + // test get + got, err := ds.Get(ctx, serverAddr) + if err != nil { + t.Fatal("DynamicStore.Get() error =", err) + } + if want := cred; got != want { + t.Errorf("DynamicStore.Get() = %v, want %v", got, want) + } + + // test delete + err = ds.Delete(ctx, serverAddr) + if err != nil { + t.Fatal("DynamicStore.Delete() error =", err) + } + + // verify delete + got, err = ds.Get(ctx, serverAddr) + if err != nil { + t.Fatal("DynamicStore.Get() error =", err) + } + if want := auth.EmptyCredential; got != want { + t.Errorf("DynamicStore.Get() = %v, want %v", got, want) + } +} + +func Test_DynamicStore_noAuthConfigured_DetectDefaultNativeStore(t *testing.T) { + // prepare test content + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "no_auth_configured.json") + cfg := configtest.Config{ + SomeConfigField: 123, + } + jsonStr, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("failed to marshal config: %v", err) + } + if err := os.WriteFile(configPath, jsonStr, 0666); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + opts := StoreOptions{ + AllowPlaintextPut: true, + DetectDefaultNativeStore: true, + } + ds, err := NewStore(configPath, opts) + if err != nil { + t.Fatal("NewStore() error =", err) + } + + // test IsAuthConfigured + authConfigured := ds.IsAuthConfigured() + if want := false; authConfigured != want { + t.Errorf("DynamicStore.IsAuthConfigured() = %v, want %v", authConfigured, want) + } + + serverAddr := "test.example.com" + cred := auth.Credential{ + Username: "username", + Password: "password", + } + ctx := context.Background() + + // Get() should set detectedCredsStore only, but should not save it back to config + if _, err := ds.Get(ctx, serverAddr); err != nil { + t.Fatal("DynamicStore.Get() error =", err) + } + if defaultStore := getDefaultHelperSuffix(); defaultStore != "" { + if got := ds.detectedCredsStore; got != defaultStore { + t.Errorf("ds.detectedCredsStore = %v, want %v", got, defaultStore) + } + } + if got := ds.config.CredentialsStore(); got != "" { + t.Errorf("ds.config.CredentialsStore() = %v, want empty", got) + } + + // test put + if err := ds.Put(ctx, serverAddr, cred); err != nil { + t.Fatal("DynamicStore.Put() error =", err) + } + + // Put() should set the detected store back to config + if got := ds.config.CredentialsStore(); got != ds.detectedCredsStore { + t.Errorf("ds.config.CredentialsStore() = %v, want %v", got, ds.detectedCredsStore) + } + + // test get + got, err := ds.Get(ctx, serverAddr) + if err != nil { + t.Fatal("DynamicStore.Get() error =", err) + } + if want := cred; got != want { + t.Errorf("DynamicStore.Get() = %v, want %v", got, want) + } + + // test delete + err = ds.Delete(ctx, serverAddr) + if err != nil { + t.Fatal("DynamicStore.Delete() error =", err) + } + + // verify delete + got, err = ds.Get(ctx, serverAddr) + if err != nil { + t.Fatal("DynamicStore.Get() error =", err) + } + if want := auth.EmptyCredential; got != want { + t.Errorf("DynamicStore.Get() = %v, want %v", got, want) + } +} + +func Test_DynamicStore_fileStore_AllowPlainTextPut(t *testing.T) { + // prepare test content + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + serverAddr := "newtest.example.com" + cred := auth.Credential{ + Username: "username", + Password: "password", + } + ctx := context.Background() + + cfg := configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{ + "test.example.com": {}, + }, + SomeConfigField: 123, + } + jsonStr, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("failed to marshal config: %v", err) + } + if err := os.WriteFile(configPath, jsonStr, 0666); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + // test default option + ds, err := NewStore(configPath, StoreOptions{}) + if err != nil { + t.Fatal("NewStore() error =", err) + } + err = ds.Put(ctx, serverAddr, cred) + if wantErr := ErrPlaintextPutDisabled; !errors.Is(err, wantErr) { + t.Errorf("DynamicStore.Put() error = %v, wantErr %v", err, wantErr) + } + + // test AllowPlainTextPut = true + ds, err = NewStore(configPath, StoreOptions{AllowPlaintextPut: true}) + if err != nil { + t.Fatal("NewStore() error =", err) + } + if err := ds.Put(ctx, serverAddr, cred); err != nil { + t.Error("DynamicStore.Put() error =", err) + } + + // verify config file + configFile, err := os.Open(configPath) + if err != nil { + t.Fatalf("failed to open config file: %v", err) + } + defer configFile.Close() + var gotCfg configtest.Config + if err := json.NewDecoder(configFile).Decode(&gotCfg); err != nil { + t.Fatalf("failed to decode config file: %v", err) + } + wantCfg := configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{ + "test.example.com": {}, + serverAddr: { + Auth: "dXNlcm5hbWU6cGFzc3dvcmQ=", + }, + }, + SomeConfigField: cfg.SomeConfigField, + } + if !reflect.DeepEqual(gotCfg, wantCfg) { + t.Errorf("Decoded config = %v, want %v", gotCfg, wantCfg) + } +} + +func Test_DynamicStore_getHelperSuffix(t *testing.T) { + tests := []struct { + name string + configPath string + serverAddress string + want string + }{ + { + name: "Get cred helper: registry_helper1", + configPath: "testdata/credHelpers_config.json", + serverAddress: "registry1.example.com", + want: "registry1-helper", + }, + { + name: "Get cred helper: registry_helper2", + configPath: "testdata/credHelpers_config.json", + serverAddress: "registry2.example.com", + want: "registry2-helper", + }, + { + name: "Empty cred helper configured", + configPath: "testdata/credHelpers_config.json", + serverAddress: "registry3.example.com", + want: "", + }, + { + name: "No cred helper and creds store configured", + configPath: "testdata/credHelpers_config.json", + serverAddress: "whatever.example.com", + want: "", + }, + { + name: "Choose cred helper over creds store", + configPath: "testdata/credsStore_config.json", + serverAddress: "test.example.com", + want: "test-helper", + }, + { + name: "No cred helper configured, choose cred store", + configPath: "testdata/credsStore_config.json", + serverAddress: "whatever.example.com", + want: "teststore", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ds, err := NewStore(tt.configPath, StoreOptions{}) + if err != nil { + t.Fatal("NewStore() error =", err) + } + if got := ds.getHelperSuffix(tt.serverAddress); got != tt.want { + t.Errorf("DynamicStore.getHelperSuffix() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_DynamicStore_getStore_nativeStore(t *testing.T) { + tests := []struct { + name string + configPath string + serverAddress string + }{ + { + name: "Cred helper configured for registry1.example.com", + configPath: "testdata/credHelpers_config.json", + serverAddress: "registry1.example.com", + }, + { + name: "Cred helper configured for registry2.example.com", + configPath: "testdata/credHelpers_config.json", + serverAddress: "registry2.example.com", + }, + { + name: "Cred helper configured for test.example.com", + configPath: "testdata/credsStore_config.json", + serverAddress: "test.example.com", + }, + { + name: "No cred helper configured, use creds store", + configPath: "testdata/credsStore_config.json", + serverAddress: "whaterver.example.com", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ds, err := NewStore(tt.configPath, StoreOptions{}) + if err != nil { + t.Fatal("NewStore() error =", err) + } + gotStore := ds.getStore(tt.serverAddress) + if _, ok := gotStore.(*nativeStore); !ok { + t.Errorf("gotStore is not a native store") + } + }) + } +} + +func Test_DynamicStore_getStore_fileStore(t *testing.T) { + tests := []struct { + name string + configPath string + serverAddress string + }{ + { + name: "Empty cred helper configured for registry3.example.com", + configPath: "testdata/credHelpers_config.json", + serverAddress: "registry3.example.com", + }, + { + name: "No cred helper configured", + configPath: "testdata/credHelpers_config.json", + serverAddress: "whatever.example.com", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ds, err := NewStore(tt.configPath, StoreOptions{}) + if err != nil { + t.Fatal("NewStore() error =", err) + } + gotStore := ds.getStore(tt.serverAddress) + gotFS1, ok := gotStore.(*FileStore) + if !ok { + t.Errorf("gotStore is not a file store") + } + + // get again, the two file stores should be based on the same config instance + gotStore = ds.getStore(tt.serverAddress) + gotFS2, ok := gotStore.(*FileStore) + if !ok { + t.Errorf("gotStore is not a file store") + } + if gotFS1.config != gotFS2.config { + t.Errorf("gotFS1 and gotFS2 are not based on the same config") + } + }) + } +} + +func Test_storeWithFallbacks_Get(t *testing.T) { + // prepare test content + server1 := "foo.registry.com" + cred1 := auth.Credential{ + Username: "username", + Password: "password", + } + server2 := "bar.registry.com" + cred2 := auth.Credential{ + RefreshToken: "identity_token", + } + + primaryStore := &testStore{} + fallbackStore1 := &testStore{ + storage: map[string]auth.Credential{ + server1: cred1, + }, + } + fallbackStore2 := &testStore{ + storage: map[string]auth.Credential{ + server2: cred2, + }, + } + sf := NewStoreWithFallbacks(primaryStore, fallbackStore1, fallbackStore2) + ctx := context.Background() + + // test Get() + got1, err := sf.Get(ctx, server1) + if err != nil { + t.Fatalf("storeWithFallbacks.Get(%s) error = %v", server1, err) + } + if want := cred1; got1 != cred1 { + t.Errorf("storeWithFallbacks.Get(%s) = %v, want %v", server1, got1, want) + } + got2, err := sf.Get(ctx, server2) + if err != nil { + t.Fatalf("storeWithFallbacks.Get(%s) error = %v", server2, err) + } + if want := cred2; got2 != cred2 { + t.Errorf("storeWithFallbacks.Get(%s) = %v, want %v", server2, got2, want) + } + + // test Get(): no credential found + got, err := sf.Get(ctx, "whaterver") + if err != nil { + t.Fatal("storeWithFallbacks.Get() error =", err) + } + if want := auth.EmptyCredential; got != want { + t.Errorf("storeWithFallbacks.Get() = %v, want %v", got, want) + } +} + +func Test_storeWithFallbacks_Get_throwError(t *testing.T) { + badStore := &badStore{} + goodStore := &testStore{} + sf := NewStoreWithFallbacks(badStore, goodStore) + ctx := context.Background() + + // test Get(): should throw error + _, err := sf.Get(ctx, "whatever") + if wantErr := errBadStore; !errors.Is(err, wantErr) { + t.Errorf("storeWithFallback.Get() error = %v, wantErr %v", err, wantErr) + } +} + +func Test_storeWithFallbacks_Put(t *testing.T) { + // prepare test content + cfg := configtest.Config{ + SomeConfigField: 123, + } + jsonStr, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("failed to marshal config: %v", err) + } + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "no_auth_configured.json") + if err := os.WriteFile(configPath, jsonStr, 0666); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + opts := StoreOptions{ + AllowPlaintextPut: true, + } + primaryStore, err := NewStore(configPath, opts) // plaintext enabled + if err != nil { + t.Fatalf("NewStore(%s) error = %v", configPath, err) + } + badStore := &badStore{} // bad store + sf := NewStoreWithFallbacks(primaryStore, badStore) + ctx := context.Background() + + server := "example.registry.com" + cred := auth.Credential{ + Username: "username", + Password: "password", + } + // test Put() + if err := sf.Put(ctx, server, cred); err != nil { + t.Fatal("storeWithFallbacks.Put() error =", err) + } + // verify Get() + got, err := sf.Get(ctx, server) + if err != nil { + t.Fatal("storeWithFallbacks.Get() error =", err) + } + if want := cred; got != want { + t.Errorf("storeWithFallbacks.Get() = %v, want %v", got, want) + } +} + +func Test_storeWithFallbacks_Put_throwError(t *testing.T) { + badStore := &badStore{} + goodStore := &testStore{} + sf := NewStoreWithFallbacks(badStore, goodStore) + ctx := context.Background() + + // test Put(): should thrown error + err := sf.Put(ctx, "whatever", auth.Credential{}) + if wantErr := errBadStore; !errors.Is(err, wantErr) { + t.Errorf("storeWithFallback.Put() error = %v, wantErr %v", err, wantErr) + } +} + +func Test_storeWithFallbacks_Delete(t *testing.T) { + // prepare test content + server1 := "foo.registry.com" + cred1 := auth.Credential{ + Username: "username", + Password: "password", + } + server2 := "bar.registry.com" + cred2 := auth.Credential{ + RefreshToken: "identity_token", + } + + primaryStore := &testStore{ + storage: map[string]auth.Credential{ + server1: cred1, + server2: cred2, + }, + } + badStore := &badStore{} + sf := NewStoreWithFallbacks(primaryStore, badStore) + ctx := context.Background() + + // test Delete(): server1 + if err := sf.Delete(ctx, server1); err != nil { + t.Fatal("storeWithFallback.Delete()") + } + // verify primary store + if want := map[string]auth.Credential{server2: cred2}; !reflect.DeepEqual(primaryStore.storage, want) { + t.Errorf("primaryStore.storage = %v, want %v", primaryStore.storage, want) + } + + // test Delete(): server2 + if err := sf.Delete(ctx, server2); err != nil { + t.Fatal("storeWithFallback.Delete()") + } + // verify primary store + if want := map[string]auth.Credential{}; !reflect.DeepEqual(primaryStore.storage, want) { + t.Errorf("primaryStore.storage = %v, want %v", primaryStore.storage, want) + } +} + +func Test_storeWithFallbacks_Delete_throwError(t *testing.T) { + badStore := &badStore{} + goodStore := &testStore{} + sf := NewStoreWithFallbacks(badStore, goodStore) + ctx := context.Background() + + // test Delete(): should throw error + err := sf.Delete(ctx, "whatever") + if wantErr := errBadStore; !errors.Is(err, wantErr) { + t.Errorf("storeWithFallback.Delete() error = %v, wantErr %v", err, wantErr) + } +} + +func Test_getDockerConfigPath_env(t *testing.T) { + dir, err := os.Getwd() + if err != nil { + t.Fatal("os.Getwd() error =", err) + } + t.Setenv("DOCKER_CONFIG", dir) + + got, err := getDockerConfigPath() + if err != nil { + t.Fatal("getDockerConfigPath() error =", err) + } + if want := filepath.Join(dir, "config.json"); got != want { + t.Errorf("getDockerConfigPath() = %v, want %v", got, want) + } +} + +func Test_getDockerConfigPath_homeDir(t *testing.T) { + t.Setenv("DOCKER_CONFIG", "") + + got, err := getDockerConfigPath() + if err != nil { + t.Fatal("getDockerConfigPath() error =", err) + } + homeDir, err := os.UserHomeDir() + if err != nil { + t.Fatal("os.UserHomeDir()") + } + if want := filepath.Join(homeDir, ".docker", "config.json"); got != want { + t.Errorf("getDockerConfigPath() = %v, want %v", got, want) + } +} + +func TestNewStoreFromDocker(t *testing.T) { + // prepare test content + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + t.Setenv("DOCKER_CONFIG", tempDir) + + serverAddr1 := "test.example.com" + cred1 := auth.Credential{ + Username: "foo", + Password: "bar", + } + config := configtest.Config{ + AuthConfigs: map[string]configtest.AuthConfig{ + serverAddr1: { + Auth: "Zm9vOmJhcg==", + }, + }, + SomeConfigField: 123, + } + jsonStr, err := json.Marshal(config) + if err != nil { + t.Fatalf("failed to marshal config: %v", err) + } + if err := os.WriteFile(configPath, jsonStr, 0666); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + ctx := context.Background() + + ds, err := NewStoreFromDocker(StoreOptions{AllowPlaintextPut: true}) + if err != nil { + t.Fatal("NewStoreFromDocker() error =", err) + } + + // test getting an existing credential + got, err := ds.Get(ctx, serverAddr1) + if err != nil { + t.Fatal("DynamicStore.Get() error =", err) + } + if want := cred1; got != want { + t.Errorf("DynamicStore.Get() = %v, want %v", got, want) + } + + // test putting a new credential + serverAddr2 := "newtest.example.com" + cred2 := auth.Credential{ + Username: "username", + Password: "password", + } + if err := ds.Put(ctx, serverAddr2, cred2); err != nil { + t.Fatal("DynamicStore.Get() error =", err) + } + + // test getting the new credential + got, err = ds.Get(ctx, serverAddr2) + if err != nil { + t.Fatal("DynamicStore.Get() error =", err) + } + if want := cred2; got != want { + t.Errorf("DynamicStore.Get() = %v, want %v", got, want) + } + + // test deleting the old credential + err = ds.Delete(ctx, serverAddr1) + if err != nil { + t.Fatal("DynamicStore.Delete() error =", err) + } + + // verify delete + got, err = ds.Get(ctx, serverAddr1) + if err != nil { + t.Fatal("DynamicStore.Get() error =", err) + } + if want := auth.EmptyCredential; got != want { + t.Errorf("DynamicStore.Get() = %v, want %v", got, want) + } +} diff --git a/registry/remote/credentials/testdata/bad_config b/registry/remote/credentials/testdata/bad_config new file mode 100644 index 00000000..22a4f5d5 --- /dev/null +++ b/registry/remote/credentials/testdata/bad_config @@ -0,0 +1 @@ +bad diff --git a/registry/remote/credentials/testdata/credHelpers_config.json b/registry/remote/credentials/testdata/credHelpers_config.json new file mode 100644 index 00000000..f33a98e2 --- /dev/null +++ b/registry/remote/credentials/testdata/credHelpers_config.json @@ -0,0 +1,15 @@ +{ + "auths": { + "registry1.example.com": { + "auth": "dXNlcm5hbWU6cGFzc3dvcmQ=" + }, + "registry3.example.com": { + "auth": "Zm9vOmJhcg==" + } + }, + "credHelpers": { + "registry1.example.com": "registry1-helper", + "registry2.example.com": "registry2-helper", + "registry3.example.com": "" + } +} diff --git a/registry/remote/credentials/testdata/credsStore_config.json b/registry/remote/credentials/testdata/credsStore_config.json new file mode 100644 index 00000000..40eb3843 --- /dev/null +++ b/registry/remote/credentials/testdata/credsStore_config.json @@ -0,0 +1,6 @@ +{ + "credHelpers": { + "test.example.com": "test-helper" + }, + "credsStore": "teststore" +} diff --git a/registry/remote/credentials/testdata/empty.json b/registry/remote/credentials/testdata/empty.json new file mode 100644 index 00000000..e69de29b diff --git a/registry/remote/credentials/testdata/invalid_auths_config.json b/registry/remote/credentials/testdata/invalid_auths_config.json new file mode 100644 index 00000000..8010010a --- /dev/null +++ b/registry/remote/credentials/testdata/invalid_auths_config.json @@ -0,0 +1,3 @@ +{ + "auths": "whaterver" +} diff --git a/registry/remote/credentials/testdata/invalid_auths_entry_config.json b/registry/remote/credentials/testdata/invalid_auths_entry_config.json new file mode 100644 index 00000000..60a45445 --- /dev/null +++ b/registry/remote/credentials/testdata/invalid_auths_entry_config.json @@ -0,0 +1,11 @@ +{ + "auths": { + "registry1.example.com": { + "auth": "username:password" + }, + "registry2.example.com": "whatever", + "registry3.example.com": { + "identitytoken": 123 + } + } +} diff --git a/registry/remote/credentials/testdata/legacy_auths_config.json b/registry/remote/credentials/testdata/legacy_auths_config.json new file mode 100644 index 00000000..20d6a8d1 --- /dev/null +++ b/registry/remote/credentials/testdata/legacy_auths_config.json @@ -0,0 +1,25 @@ +{ + "auths": { + "registry1.example.com": { + "auth": "dXNlcm5hbWUxOnBhc3N3b3JkMQ==" + }, + "http://registry2.example.com": { + "auth": "dXNlcm5hbWUyOnBhc3N3b3JkMg==" + }, + "https://registry3.example.com": { + "auth": "dXNlcm5hbWUzOnBhc3N3b3JkMw==" + }, + "http://registry4.example.com/": { + "auth": "dXNlcm5hbWU0OnBhc3N3b3JkNA==" + }, + "https://registry5.example.com/": { + "auth": "dXNlcm5hbWU1OnBhc3N3b3JkNQ==" + }, + "https://registry6.example.com/path/": { + "auth": "dXNlcm5hbWU2OnBhc3N3b3JkNg==" + }, + "https://registry1.example.com/": { + "auth": "Zm9vOmJhcg==" + } + } +} diff --git a/registry/remote/credentials/testdata/no_auths_config.json b/registry/remote/credentials/testdata/no_auths_config.json new file mode 100644 index 00000000..e07eb621 --- /dev/null +++ b/registry/remote/credentials/testdata/no_auths_config.json @@ -0,0 +1,3 @@ +{ + "key": "val" +} diff --git a/registry/remote/credentials/testdata/valid_auths_config.json b/registry/remote/credentials/testdata/valid_auths_config.json new file mode 100644 index 00000000..e643c082 --- /dev/null +++ b/registry/remote/credentials/testdata/valid_auths_config.json @@ -0,0 +1,28 @@ +{ + "auths": { + "registry1.example.com": { + "auth": "dXNlcm5hbWU6cGFzc3dvcmQ=" + }, + "registry2.example.com": { + "identitytoken": "identity_token" + }, + "registry3.example.com": { + "registrytoken": "registry_token" + }, + "registry4.example.com": { + "auth": "dXNlcm5hbWU6cGFzc3dvcmQ=", + "identitytoken": "identity_token", + "registrytoken": "registry_token" + }, + "registry5.example.com": {}, + "registry6.example.com": { + "username": "username", + "password": "password" + }, + "registry7.example.com": { + "auth": "dXNlcm5hbWU6cGFzc3dvcmQ=", + "username": "foo", + "password": "bar" + } + } +} diff --git a/registry/remote/credentials/trace/example_test.go b/registry/remote/credentials/trace/example_test.go new file mode 100644 index 00000000..65f6af78 --- /dev/null +++ b/registry/remote/credentials/trace/example_test.go @@ -0,0 +1,65 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package trace_test + +import ( + "context" + "fmt" + + "oras.land/oras-go/v2/registry/remote/auth" + "oras.land/oras-go/v2/registry/remote/credentials" + "oras.land/oras-go/v2/registry/remote/credentials/trace" +) + +// An example on how to use ExecutableTrace with Stores. +func Example() { + // ExecutableTrace works with all Stores that may invoke executables, for + // example the Store returned from NewStore and NewNativeStore. + store, err := credentials.NewStore("example/path/config.json", credentials.StoreOptions{}) + if err != nil { + panic(err) + } + + // Define ExecutableTrace and add it to the context. The 'action' argument + // refers to one of 'store', 'get' and 'erase' defined by the docker + // credential helper protocol. + // Reference: https://docs.docker.com/engine/reference/commandline/login/#credential-helper-protocol + traceHooks := &trace.ExecutableTrace{ + ExecuteStart: func(executableName string, action string) { + fmt.Printf("executable %s, action %s started", executableName, action) + }, + ExecuteDone: func(executableName string, action string, err error) { + fmt.Printf("executable %s, action %s finished", executableName, action) + }, + } + ctx := trace.WithExecutableTrace(context.Background(), traceHooks) + + // Get, Put and Delete credentials from store. If any credential helper + // executable is run, traceHooks is executed. + err = store.Put(ctx, "localhost:5000", auth.Credential{Username: "testUsername", Password: "testPassword"}) + if err != nil { + panic(err) + } + + cred, err := store.Get(ctx, "localhost:5000") + if err != nil { + panic(err) + } + fmt.Println(cred) + + err = store.Delete(ctx, "localhost:5000") + if err != nil { + panic(err) + } +} diff --git a/registry/remote/credentials/trace/trace.go b/registry/remote/credentials/trace/trace.go new file mode 100644 index 00000000..b7cd8683 --- /dev/null +++ b/registry/remote/credentials/trace/trace.go @@ -0,0 +1,94 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package trace + +import "context" + +// executableTraceContextKey is a value key used to retrieve the ExecutableTrace +// from Context. +type executableTraceContextKey struct{} + +// ExecutableTrace is a set of hooks used to trace the execution of binary +// executables. Any particular hook may be nil. +type ExecutableTrace struct { + // ExecuteStart is called before the execution of the executable. The + // executableName parameter is the name of the credential helper executable + // used with NativeStore. The action parameter is one of "store", "get" and + // "erase". + // + // Reference: + // - https://docs.docker.com/engine/reference/commandline/login#credentials-store + ExecuteStart func(executableName string, action string) + + // ExecuteDone is called after the execution of an executable completes. + // The executableName parameter is the name of the credential helper + // executable used with NativeStore. The action parameter is one of "store", + // "get" and "erase". The err parameter is the error (if any) returned from + // the execution. + // + // Reference: + // - https://docs.docker.com/engine/reference/commandline/login#credentials-store + ExecuteDone func(executableName string, action string, err error) +} + +// ContextExecutableTrace returns the ExecutableTrace associated with the +// context. If none, it returns nil. +func ContextExecutableTrace(ctx context.Context) *ExecutableTrace { + trace, _ := ctx.Value(executableTraceContextKey{}).(*ExecutableTrace) + return trace +} + +// WithExecutableTrace takes a Context and an ExecutableTrace, and returns a +// Context with the ExecutableTrace added as a Value. If the Context has a +// previously added trace, the hooks defined in the new trace will be added +// in addition to the previous ones. The recent hooks will be called first. +func WithExecutableTrace(ctx context.Context, trace *ExecutableTrace) context.Context { + if trace == nil { + return ctx + } + if oldTrace := ContextExecutableTrace(ctx); oldTrace != nil { + trace.compose(oldTrace) + } + return context.WithValue(ctx, executableTraceContextKey{}, trace) +} + +// compose takes an oldTrace and modifies the existing trace to include +// the hooks defined in the oldTrace. The hooks in the existing trace will +// be called first. +func (trace *ExecutableTrace) compose(oldTrace *ExecutableTrace) { + if oldStart := oldTrace.ExecuteStart; oldStart != nil { + start := trace.ExecuteStart + if start != nil { + trace.ExecuteStart = func(executableName, action string) { + start(executableName, action) + oldStart(executableName, action) + } + } else { + trace.ExecuteStart = oldStart + } + } + if oldDone := oldTrace.ExecuteDone; oldDone != nil { + done := trace.ExecuteDone + if done != nil { + trace.ExecuteDone = func(executableName, action string, err error) { + done(executableName, action, err) + oldDone(executableName, action, err) + } + } else { + trace.ExecuteDone = oldDone + } + } +} diff --git a/registry/remote/errcode/errors.go b/registry/remote/errcode/errors.go index cf0018a0..fb192aa8 100644 --- a/registry/remote/errcode/errors.go +++ b/registry/remote/errcode/errors.go @@ -24,7 +24,7 @@ import ( ) // References: -// - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc1/spec.md#error-codes +// - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#error-codes // - https://docs.docker.com/registry/spec/api/#errors-2 const ( ErrorCodeBlobUnknown = "BLOB_UNKNOWN" @@ -45,7 +45,7 @@ const ( // Error represents a response inner error returned by the remote // registry. // References: -// - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc1/spec.md#error-codes +// - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#error-codes // - https://docs.docker.com/registry/spec/api/#errors-2 type Error struct { Code string `json:"code"` @@ -73,7 +73,7 @@ func (e Error) Error() string { // Errors represents a list of response inner errors returned by the remote // server. // References: -// - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc1/spec.md#error-codes +// - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#error-codes // - https://docs.docker.com/registry/spec/api/#errors-2 type Errors []Error diff --git a/registry/remote/example_test.go b/registry/remote/example_test.go index 5d251189..d2fb529d 100644 --- a/registry/remote/example_test.go +++ b/registry/remote/example_test.go @@ -40,18 +40,20 @@ import ( ) const ( - exampleRepositoryName = "example" - exampleTag = "latest" - exampleConfig = "Example config content" - exampleLayer = "Example layer content" - exampleUploadUUid = "0bc84d80-837c-41d9-824e-1907463c53b3" - ManifestDigest = "sha256:0b696106ecd0654e031f19e0a8cbd1aee4ad457d7c9cea881f07b12a930cd307" - ReferenceManifestDigest = "sha256:6983f495f7ee70d43e571657ae8b39ca3d3ca1b0e77270fd4fbddfb19832a1cf" + _ = ExampleUnplayable + exampleRepositoryName = "example" + exampleTag = "latest" + exampleConfig = "Example config content" + exampleLayer = "Example layer content" + exampleUploadUUid = "0bc84d80-837c-41d9-824e-1907463c53b3" + // For ExampleRepository_Push_artifactReferenceManifest: + ManifestDigest = "sha256:a3f9d449466b9b7194c3a76ca4890d792e11eb4e62e59aa8b4c3cce0a56f129d" + ReferenceManifestDigest = "sha256:2d30397701742b04550891851529abe6b071e4fae920a91897d34612662a3bf6" + // For Example_pushAndIgnoreReferrersIndexError: referrersAPIUnavailableRepositoryName = "no-referrers-api" - referrerDigest = "sha256:21c623eb8ccd273f2702efd74a0abb455dd06a99987f413c2114fb00961ebfe7" + referrerDigest = "sha256:4caba1e18385eb152bd92e9fee1dc01e47c436e594123b3c2833acfcad9883e2" referrersTag = "sha256-c824a9aa7d2e3471306648c6d4baa1abbcb97ff0276181ab4722ca27127cdba0" referrerIndexDigest = "sha256:7baac5147dd58d56fdbaad5a888fa919235a3a90cb71aaa8b56ee5d19f4cd838" - _ = ExampleUnplayable ) var ( @@ -107,8 +109,10 @@ var ( Size: int64(len(exampleManifestWithBlobs))} subjectDescriptor = content.NewDescriptorFromBytes(ocispec.MediaTypeImageManifest, []byte(`{"layers":[]}`)) referrerManifestContent, _ = json.Marshal(ocispec.Manifest{ + Versioned: specs.Versioned{SchemaVersion: 2}, MediaType: ocispec.MediaTypeImageManifest, Subject: &subjectDescriptor, + Config: ocispec.DescriptorEmptyJSON, }) referrerDescriptor = content.NewDescriptorFromBytes(ocispec.MediaTypeImageManifest, referrerManifestContent) referrerIndex, _ = json.Marshal(ocispec.Index{ @@ -194,6 +198,7 @@ func TestMain(m *testing.M) { w.Header().Set("Content-Type", ocispec.MediaTypeImageManifest) w.Header().Set("Docker-Content-Digest", exampleManifestDigest) w.Header().Set("Content-Length", strconv.Itoa(len([]byte(exampleManifest)))) + w.Header().Set("Warning", `299 - "This image is deprecated and will be removed soon."`) if m == "GET" { w.Write([]byte(exampleManifest)) } @@ -303,7 +308,11 @@ func ExampleRepository_Push_artifactReferenceManifest() { // 1. assemble the referenced artifact manifest manifest := ocispec.Manifest{ + Versioned: specs.Versioned{ + SchemaVersion: 2, // historical value. does not pertain to OCI or docker version + }, MediaType: ocispec.MediaTypeImageManifest, + Config: content.NewDescriptorFromBytes(ocispec.MediaTypeImageConfig, []byte("config bytes")), } manifestContent, err := json.Marshal(manifest) if err != nil { @@ -730,6 +739,41 @@ func Example_pullByDigest() { // {"schemaVersion":2,"config":{"mediaType":"application/vnd.oci.image.config.v1+json","digest":"sha256:569224ae188c06e97b9fcadaeb2358fb0fb7c4eb105d49aee2620b2719abea43","size":22},"layers":[{"mediaType":"application/vnd.oci.image.layer.v1.tar","digest":"sha256:ef79e47691ad1bc702d7a256da6323ec369a8fc3159b4f1798a47136f3b38c10","size":21}]} } +func Example_handleWarning() { + repo, err := remote.NewRepository(fmt.Sprintf("%s/%s", host, exampleRepositoryName)) + if err != nil { + panic(err) + } + // 1. specify HandleWarning + repo.HandleWarning = func(warning remote.Warning) { + fmt.Printf("Warning from %s: %s\n", repo.Reference.Repository, warning.Text) + } + + ctx := context.Background() + exampleDigest := "sha256:b53dc03a49f383ba230d8ac2b78a9c4aec132e4a9f36cc96524df98163202cc7" + // 2. resolve the descriptor + descriptor, err := repo.Resolve(ctx, exampleDigest) + if err != nil { + panic(err) + } + fmt.Println(descriptor.Digest) + fmt.Println(descriptor.Size) + + // 3. fetch the content byte[] from the repository + pulledBlob, err := content.FetchAll(ctx, repo, descriptor) + if err != nil { + panic(err) + } + fmt.Println(string(pulledBlob)) + + // Output: + // Warning from example: This image is deprecated and will be removed soon. + // sha256:b53dc03a49f383ba230d8ac2b78a9c4aec132e4a9f36cc96524df98163202cc7 + // 337 + // Warning from example: This image is deprecated and will be removed soon. + // {"schemaVersion":2,"config":{"mediaType":"application/vnd.oci.image.config.v1+json","digest":"sha256:569224ae188c06e97b9fcadaeb2358fb0fb7c4eb105d49aee2620b2719abea43","size":22},"layers":[{"mediaType":"application/vnd.oci.image.layer.v1.tar","digest":"sha256:ef79e47691ad1bc702d7a256da6323ec369a8fc3159b4f1798a47136f3b38c10","size":21}]} +} + // Example_pushAndTag gives example snippet of pushing an OCI image with a tag. func Example_pushAndTag() { repo, err := remote.NewRepository(fmt.Sprintf("%s/%s", host, exampleRepositoryName)) diff --git a/registry/remote/referrers.go b/registry/remote/referrers.go index a3ed08ca..191db9d1 100644 --- a/registry/remote/referrers.go +++ b/registry/remote/referrers.go @@ -22,7 +22,6 @@ import ( ocispec "github.com/opencontainers/image-spec/specs-go/v1" "oras.land/oras-go/v2/content" "oras.land/oras-go/v2/internal/descriptor" - "oras.land/oras-go/v2/internal/spec" ) // zeroDigest represents a digest that consists of zeros. zeroDigest is used @@ -103,17 +102,15 @@ func (e *ReferrersError) IsReferrersIndexDelete() bool { // buildReferrersTag builds the referrers tag for the given manifest descriptor. // Format: - -// Reference: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc1/spec.md#unavailable-referrers-api +// Reference: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#unavailable-referrers-api func buildReferrersTag(desc ocispec.Descriptor) string { alg := desc.Digest.Algorithm().String() encoded := desc.Digest.Encoded() return alg + "-" + encoded } -// isReferrersFilterApplied checks annotations to see if requested is in the -// applied filter list. -func isReferrersFilterApplied(annotations map[string]string, requested string) bool { - applied := annotations[spec.AnnotationReferrersFiltersApplied] +// isReferrersFilterApplied checks if requsted is in the applied filter list. +func isReferrersFilterApplied(applied, requested string) bool { if applied == "" || requested == "" { return false } diff --git a/registry/remote/referrers_test.go b/registry/remote/referrers_test.go index 1b8b98f8..8a1bdda3 100644 --- a/registry/remote/referrers_test.go +++ b/registry/remote/referrers_test.go @@ -63,63 +63,57 @@ func Test_buildReferrersTag(t *testing.T) { func Test_isReferrersFilterApplied(t *testing.T) { tests := []struct { - name string - annotations map[string]string - requested string - want bool + name string + applied string + requested string + want bool }{ { - name: "single filter applied, specified filter matches", - annotations: map[string]string{spec.AnnotationReferrersFiltersApplied: "artifactType"}, - requested: "artifactType", - want: true, + name: "single filter applied, specified filter matches", + applied: "artifactType", + requested: "artifactType", + want: true, }, { - name: "single filter applied, specified filter does not match", - annotations: map[string]string{spec.AnnotationReferrersFiltersApplied: "foo"}, - requested: "artifactType", - want: false, + name: "single filter applied, specified filter does not match", + applied: "foo", + requested: "artifactType", + want: false, }, { - name: "multiple filters applied, specified filter matches", - annotations: map[string]string{spec.AnnotationReferrersFiltersApplied: "foo,artifactType"}, - requested: "artifactType", - want: true, + name: "multiple filters applied, specified filter matches", + applied: "foo,artifactType", + requested: "artifactType", + want: true, }, { - name: "multiple filters applied, specified filter does not match", - annotations: map[string]string{spec.AnnotationReferrersFiltersApplied: "foo,bar"}, - requested: "artifactType", - want: false, + name: "multiple filters applied, specified filter does not match", + applied: "foo,bar", + requested: "artifactType", + want: false, }, { - name: "single filter applied, specified filter empty", - annotations: map[string]string{spec.AnnotationReferrersFiltersApplied: "foo"}, - requested: "", - want: false, + name: "single filter applied, no specified filter", + applied: "foo", + requested: "", + want: false, }, { - name: "no filter applied", - annotations: map[string]string{}, - requested: "artifactType", - want: false, + name: "no filter applied, specified filter does not match", + applied: "", + requested: "artifactType", + want: false, }, { - name: "empty filter applied", - annotations: map[string]string{spec.AnnotationReferrersFiltersApplied: ""}, - requested: "artifactType", - want: false, - }, - { - name: "no filter applied, specified filter empty", - annotations: map[string]string{}, - requested: "", - want: false, + name: "no filter applied, no specified filter", + applied: "", + requested: "", + want: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := isReferrersFilterApplied(tt.annotations, tt.requested); got != tt.want { + if got := isReferrersFilterApplied(tt.applied, tt.requested); got != tt.want { t.Errorf("isReferrersFilterApplied() = %v, want %v", got, tt.want) } }) diff --git a/registry/remote/registry.go b/registry/remote/registry.go index c8c414f1..d1334042 100644 --- a/registry/remote/registry.go +++ b/registry/remote/registry.go @@ -73,13 +73,28 @@ func (r *Registry) client() Client { return r.Client } +// do sends an HTTP request and returns an HTTP response using the HTTP client +// returned by r.client(). +func (r *Registry) do(req *http.Request) (*http.Response, error) { + if r.HandleWarning == nil { + return r.client().Do(req) + } + + resp, err := r.client().Do(req) + if err != nil { + return nil, err + } + handleWarningHeaders(resp.Header.Values(headerWarning), r.HandleWarning) + return resp, nil +} + // Ping checks whether or not the registry implement Docker Registry API V2 or // OCI Distribution Specification. // Ping can be used to check authentication when an auth client is configured. // // References: // - https://docs.docker.com/registry/spec/api/#base -// - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc1/spec.md#api +// - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#api func (r *Registry) Ping(ctx context.Context) error { url := buildRegistryBaseURL(r.PlainHTTP, r.Reference) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) @@ -87,7 +102,7 @@ func (r *Registry) Ping(ctx context.Context) error { return err } - resp, err := r.client().Do(req) + resp, err := r.do(req) if err != nil { return err } @@ -112,7 +127,7 @@ func (r *Registry) Ping(ctx context.Context) error { // // Reference: https://docs.docker.com/registry/spec/api/#catalog func (r *Registry) Repositories(ctx context.Context, last string, fn func(repos []string) error) error { - ctx = auth.AppendScopes(ctx, auth.ScopeRegistryCatalog) + ctx = auth.AppendScopesForHost(ctx, r.Reference.Host(), auth.ScopeRegistryCatalog) url := buildRegistryCatalogURL(r.PlainHTTP, r.Reference) var err error for err == nil { @@ -142,7 +157,7 @@ func (r *Registry) repositories(ctx context.Context, last string, fn func(repos } req.URL.RawQuery = q.Encode() } - resp, err := r.client().Do(req) + resp, err := r.do(req) if err != nil { return "", err } diff --git a/registry/remote/registry_test.go b/registry/remote/registry_test.go index 0cbfbacb..8f91c4e1 100644 --- a/registry/remote/registry_test.go +++ b/registry/remote/registry_test.go @@ -16,9 +16,11 @@ limitations under the License. package remote import ( + "bytes" "context" "encoding/json" "fmt" + "io" "net/http" "net/http/httptest" "net/url" @@ -178,6 +180,7 @@ func TestRegistry_Repository(t *testing.T) { t.Fatalf("NewRegistry() error = %v", err) } reg.PlainHTTP = true + reg.SkipReferrersGC = true reg.RepositoryListPageSize = 50 reg.TagListPageSize = 100 reg.ReferrerListPageSize = 10 @@ -265,7 +268,119 @@ func TestRegistry_Repositories_WithLastParam(t *testing.T) { } } -//indexOf returns the index of an element within a slice +func TestRegistry_do(t *testing.T) { + data := []byte(`hello world!`) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet || r.URL.Path != "/test" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + w.Header().Add("Warning", `299 - "Test 1: Good warning."`) + w.Header().Add("Warning", `199 - "Test 2: Warning with a non-299 code."`) + w.Header().Add("Warning", `299 - "Test 3: Good warning."`) + w.Header().Add("Warning", `299 myregistry.example.com "Test 4: Warning with a non-unknown agent"`) + w.Header().Add("Warning", `299 - "Test 5: Warning with a date." "Sat, 25 Aug 2012 23:34:45 GMT"`) + w.Header().Add("wArnIng", `299 - "Test 6: Good warning."`) + w.Write(data) + })) + defer ts.Close() + uri, err := url.Parse(ts.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + testURL := ts.URL + "/test" + + // test do() without HandleWarning + reg, err := NewRegistry(uri.Host) + if err != nil { + t.Fatal("NewRegistry() error =", err) + } + req, err := http.NewRequest(http.MethodGet, testURL, nil) + if err != nil { + t.Fatal("failed to create test request:", err) + } + resp, err := reg.do(req) + if err != nil { + t.Fatal("Registry.do() error =", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("Registry.do() status code = %v, want %v", resp.StatusCode, http.StatusOK) + } + if got := len(resp.Header["Warning"]); got != 6 { + t.Errorf("Registry.do() warning header len = %v, want %v", got, 6) + } + got, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal("io.ReadAll() error =", err) + } + resp.Body.Close() + if !bytes.Equal(got, data) { + t.Errorf("Registry.do() = %v, want %v", got, data) + } + + // test do() with HandleWarning + reg, err = NewRegistry(uri.Host) + if err != nil { + t.Fatal("NewRegistry() error =", err) + } + var gotWarnings []Warning + reg.HandleWarning = func(warning Warning) { + gotWarnings = append(gotWarnings, warning) + } + + req, err = http.NewRequest(http.MethodGet, testURL, nil) + if err != nil { + t.Fatal("failed to create test request:", err) + } + resp, err = reg.do(req) + if err != nil { + t.Fatal("Registry.do() error =", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("Registry.do() status code = %v, want %v", resp.StatusCode, http.StatusOK) + } + if got := len(resp.Header["Warning"]); got != 6 { + t.Errorf("Registry.do() warning header len = %v, want %v", got, 6) + } + got, err = io.ReadAll(resp.Body) + if err != nil { + t.Errorf("Registry.do() = %v, want %v", got, data) + } + resp.Body.Close() + if !bytes.Equal(got, data) { + t.Errorf("Registry.do() = %v, want %v", got, data) + } + + wantWarnings := []Warning{ + { + WarningValue: WarningValue{ + Code: 299, + Agent: "-", + Text: "Test 1: Good warning.", + }, + }, + { + WarningValue: WarningValue{ + Code: 299, + Agent: "-", + Text: "Test 3: Good warning.", + }, + }, + { + WarningValue: WarningValue{ + Code: 299, + Agent: "-", + Text: "Test 6: Good warning.", + }, + }, + } + if !reflect.DeepEqual(gotWarnings, wantWarnings) { + t.Errorf("Registry.do() = %v, want %v", gotWarnings, wantWarnings) + } +} + +// indexOf returns the index of an element within a slice func indexOf(element string, data []string) int { for ind, val := range data { if element == val { diff --git a/registry/remote/repository.go b/registry/remote/repository.go index 32ac347d..b91054fc 100644 --- a/registry/remote/repository.go +++ b/registry/remote/repository.go @@ -37,7 +37,6 @@ import ( "oras.land/oras-go/v2/internal/cas" "oras.land/oras-go/v2/internal/httputil" "oras.land/oras-go/v2/internal/ioutil" - "oras.land/oras-go/v2/internal/registryutil" "oras.land/oras-go/v2/internal/slices" "oras.land/oras-go/v2/internal/spec" "oras.land/oras-go/v2/internal/syncutil" @@ -47,11 +46,37 @@ import ( "oras.land/oras-go/v2/registry/remote/internal/errutil" ) -// dockerContentDigestHeader - The Docker-Content-Digest header, if present -// on the response, returns the canonical digest of the uploaded blob. -// See https://docs.docker.com/registry/spec/api/#digest-header -// See https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc1/spec.md#pull -const dockerContentDigestHeader = "Docker-Content-Digest" +const ( + // headerDockerContentDigest is the "Docker-Content-Digest" header. + // If present on the response, it contains the canonical digest of the + // uploaded blob. + // + // References: + // - https://docs.docker.com/registry/spec/api/#digest-header + // - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#pull + headerDockerContentDigest = "Docker-Content-Digest" + + // headerOCIFiltersApplied is the "OCI-Filters-Applied" header. + // If present on the response, it contains a comma-separated list of the + // applied filters. + // + // Reference: + // - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#listing-referrers + headerOCIFiltersApplied = "OCI-Filters-Applied" + + // headerOCISubject is the "OCI-Subject" header. + // If present on the response, it contains the digest of the subject, + // indicating that Referrers API is supported by the registry. + headerOCISubject = "OCI-Subject" +) + +// filterTypeArtifactType is the "artifactType" filter applied on the list of +// referrers. +// +// References: +// - Latest spec: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#listing-referrers +// - Compatible spec: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc1/spec.md#listing-referrers +const filterTypeArtifactType = "artifactType" // Client is an interface for a HTTP client. type Client interface { @@ -93,7 +118,7 @@ type Repository struct { // ReferrerListPageSize specifies the page size when invoking the Referrers // API. // If zero, the page size is determined by the remote registry. - // Reference: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc1/spec.md#listing-referrers + // Reference: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#listing-referrers ReferrerListPageSize int // MaxMetadataBytes specifies a limit on how many response bytes are allowed @@ -102,7 +127,26 @@ type Repository struct { // If less than or equal to zero, a default (currently 4MiB) is used. MaxMetadataBytes int64 - // NOTE: Must keep fields in sync with newRepositoryWithOptions function. + // SkipReferrersGC specifies whether to delete the dangling referrers + // index when referrers tag schema is utilized. + // - If false, the old referrers index will be deleted after the new one + // is successfully uploaded. + // - If true, the old referrers index is kept. + // By default, it is disabled (set to false). See also: + // - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#referrers-tag-schema + // - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#pushing-manifests-with-subject + // - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#deleting-manifests + SkipReferrersGC bool + + // HandleWarning handles the warning returned by the remote server. + // Callers SHOULD deduplicate warnings from multiple associated responses. + // + // References: + // - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#warnings + // - https://www.rfc-editor.org/rfc/rfc7234#section-5.5 + HandleWarning func(warning Warning) + + // NOTE: Must keep fields in sync with clone(). // referrersState represents that if the repository supports Referrers API. // default: referrersStateUnknown @@ -141,15 +185,24 @@ func newRepositoryWithOptions(ref registry.Reference, opts *RepositoryOptions) ( if err := ref.ValidateRepository(); err != nil { return nil, err } + repo := (*Repository)(opts).clone() + repo.Reference = ref + return repo, nil +} + +// clone makes a copy of the Repository being careful not to copy non-copyable fields (sync.Mutex and syncutil.Pool types) +func (r *Repository) clone() *Repository { return &Repository{ - Client: opts.Client, - Reference: ref, - PlainHTTP: opts.PlainHTTP, - ManifestMediaTypes: slices.Clone(opts.ManifestMediaTypes), - TagListPageSize: opts.TagListPageSize, - ReferrerListPageSize: opts.ReferrerListPageSize, - MaxMetadataBytes: opts.MaxMetadataBytes, - }, nil + Client: r.Client, + Reference: r.Reference, + PlainHTTP: r.PlainHTTP, + ManifestMediaTypes: slices.Clone(r.ManifestMediaTypes), + TagListPageSize: r.TagListPageSize, + ReferrerListPageSize: r.ReferrerListPageSize, + MaxMetadataBytes: r.MaxMetadataBytes, + SkipReferrersGC: r.SkipReferrersGC, + HandleWarning: r.HandleWarning, + } } // SetReferrersCapability indicates the Referrers API capability of the remote @@ -159,9 +212,9 @@ func newRepositoryWithOptions(ref registry.Reference, opts *RepositoryOptions) ( // SetReferrersCapability returns ErrReferrersCapabilityAlreadySet if the // Referrers API capability has been already set. // - When the capability is set to true, the Referrers() function will always -// request the Referrers API. Reference: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc1/spec.md#listing-referrers +// request the Referrers API. Reference: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#listing-referrers // - When the capability is set to false, the Referrers() function will always -// request the Referrers Tag. Reference: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc1/spec.md#referrers-tag-schema +// request the Referrers Tag. Reference: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#referrers-tag-schema // - When the capability is not set, the Referrers() function will automatically // determine which API to use. func (r *Repository) SetReferrersCapability(capable bool) error { @@ -196,6 +249,21 @@ func (r *Repository) client() Client { return r.Client } +// do sends an HTTP request and returns an HTTP response using the HTTP client +// returned by r.client(). +func (r *Repository) do(req *http.Request) (*http.Response, error) { + if r.HandleWarning == nil { + return r.client().Do(req) + } + + resp, err := r.client().Do(req) + if err != nil { + return nil, err + } + handleWarningHeaders(resp.Header.Values(headerWarning), r.HandleWarning) + return resp, nil +} + // blobStore detects the blob store for the given descriptor. func (r *Repository) blobStore(desc ocispec.Descriptor) registry.BlobStore { if isManifest(r.ManifestMediaTypes, desc) { @@ -320,10 +388,10 @@ func (r *Repository) ParseReference(reference string) (registry.Reference, error // of the Tags list. // // References: -// - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc1/spec.md#content-discovery +// - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#content-discovery // - https://docs.docker.com/registry/spec/api/#tags func (r *Repository) Tags(ctx context.Context, last string, fn func(tags []string) error) error { - ctx = registryutil.WithScopeHint(ctx, r.Reference, auth.ActionPull) + ctx = auth.AppendRepositoryScope(ctx, r.Reference, auth.ActionPull) url := buildRepositoryTagListURL(r.PlainHTTP, r.Reference) var err error for err == nil { @@ -353,7 +421,7 @@ func (r *Repository) tags(ctx context.Context, last string, fn func(tags []strin } req.URL.RawQuery = q.Encode() } - resp, err := r.client().Do(req) + resp, err := r.do(req) if err != nil { return "", err } @@ -379,7 +447,7 @@ func (r *Repository) tags(ctx context.Context, last string, fn func(tags []strin // Predecessors returns the descriptors of image or artifact manifests directly // referencing the given manifest descriptor. // Predecessors internally leverages Referrers. -// Reference: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc1/spec.md#listing-referrers +// Reference: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#listing-referrers func (r *Repository) Predecessors(ctx context.Context, desc ocispec.Descriptor) ([]ocispec.Descriptor, error) { var res []ocispec.Descriptor if err := r.Referrers(ctx, desc, "", func(referrers []ocispec.Descriptor) error { @@ -398,7 +466,7 @@ func (r *Repository) Predecessors(ctx context.Context, desc ocispec.Descriptor) // If artifactType is not empty, only referrers of the same artifact type are // fed to fn. // -// Reference: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc1/spec.md#listing-referrers +// Reference: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#listing-referrers func (r *Repository) Referrers(ctx context.Context, desc ocispec.Descriptor, artifactType string, fn func(referrers []ocispec.Descriptor) error) error { state := r.loadReferrersState() if state == referrersStateUnsupported { @@ -440,7 +508,7 @@ func (r *Repository) Referrers(ctx context.Context, desc ocispec.Descriptor, art func (r *Repository) referrersByAPI(ctx context.Context, desc ocispec.Descriptor, artifactType string, fn func(referrers []ocispec.Descriptor) error) error { ref := r.Reference ref.Reference = desc.Digest.String() - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull) url := buildReferrersURL(r.PlainHTTP, ref, artifactType) var err error @@ -470,7 +538,7 @@ func (r *Repository) referrersPageByAPI(ctx context.Context, artifactType string req.URL.RawQuery = q.Encode() } - resp, err := r.client().Do(req) + resp, err := r.do(req) if err != nil { return "", err } @@ -485,10 +553,19 @@ func (r *Repository) referrersPageByAPI(ctx context.Context, artifactType string if err := json.NewDecoder(lr).Decode(&index); err != nil { return "", fmt.Errorf("%s %q: failed to decode response: %w", resp.Request.Method, resp.Request.URL, err) } + referrers := index.Manifests - if artifactType != "" && !isReferrersFilterApplied(index.Annotations, "artifactType") { - // perform client side filtering if the filter is not applied on the server side - referrers = filterReferrers(referrers, artifactType) + if artifactType != "" { + // check both filters header and filters annotations for compatibility + // latest spec for filters header: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#listing-referrers + // older spec for filters annotations: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc1/spec.md#listing-referrers + filtersHeader := resp.Header.Get(headerOCIFiltersApplied) + filtersAnnotation := index.Annotations[spec.AnnotationReferrersFiltersApplied] + if !isReferrersFilterApplied(filtersHeader, filterTypeArtifactType) && + !isReferrersFilterApplied(filtersAnnotation, filterTypeArtifactType) { + // perform client side filtering if the filter is not applied on the server side + referrers = filterReferrers(referrers, artifactType) + } } if len(referrers) > 0 { if err := fn(referrers); err != nil { @@ -502,7 +579,7 @@ func (r *Repository) referrersPageByAPI(ctx context.Context, artifactType string // referencing the given manifest descriptor by requesting referrers tag. // fn is called for the referrers result. If artifactType is not empty, // only referrers of the same artifact type are fed to fn. -// reference: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc1/spec.md#backwards-compatibility +// reference: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#backwards-compatibility func (r *Repository) referrersByTagSchema(ctx context.Context, desc ocispec.Descriptor, artifactType string, fn func(referrers []ocispec.Descriptor) error) error { referrersTag := buildReferrersTag(desc) _, referrers, err := r.referrersFromIndex(ctx, referrersTag) @@ -565,14 +642,14 @@ func (r *Repository) pingReferrers(ctx context.Context) (bool, error) { ref := r.Reference ref.Reference = zeroDigest - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull) url := buildReferrersURL(r.PlainHTTP, ref, "") req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return false, err } - resp, err := r.client().Do(req) + resp, err := r.do(req) if err != nil { return false, err } @@ -599,7 +676,7 @@ func (r *Repository) pingReferrers(ctx context.Context) (bool, error) { func (r *Repository) delete(ctx context.Context, target ocispec.Descriptor, isManifest bool) error { ref := r.Reference ref.Reference = target.Digest.String() - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionDelete) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionDelete) buildURL := buildRepositoryBlobURL if isManifest { buildURL = buildRepositoryManifestURL @@ -610,7 +687,7 @@ func (r *Repository) delete(ctx context.Context, target ocispec.Descriptor, isMa return err } - resp, err := r.client().Do(req) + resp, err := r.do(req) if err != nil { return err } @@ -635,14 +712,14 @@ type blobStore struct { func (s *blobStore) Fetch(ctx context.Context, target ocispec.Descriptor) (rc io.ReadCloser, err error) { ref := s.repo.Reference ref.Reference = target.Digest.String() - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull) url := buildRepositoryBlobURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, err } - resp, err := s.repo.client().Do(req) + resp, err := s.repo.do(req) if err != nil { return nil, err } @@ -677,19 +754,19 @@ func (s *blobStore) Fetch(ctx context.Context, target ocispec.Descriptor) (rc io func (s *blobStore) Mount(ctx context.Context, desc ocispec.Descriptor, fromRepo string, getContent func() (io.ReadCloser, error)) error { // pushing usually requires both pull and push actions. // Reference: https://github.com/distribution/distribution/blob/v2.7.1/registry/handlers/app.go#L921-L930 - ctx = registryutil.WithScopeHint(ctx, s.repo.Reference, auth.ActionPull, auth.ActionPush) + ctx = auth.AppendRepositoryScope(ctx, s.repo.Reference, auth.ActionPull, auth.ActionPush) // We also need pull access to the source repo. fromRef := s.repo.Reference fromRef.Repository = fromRepo - ctx = registryutil.WithScopeHint(ctx, fromRef, auth.ActionPull) + ctx = auth.AppendRepositoryScope(ctx, fromRef, auth.ActionPull) url := buildRepositoryBlobMountURL(s.repo.PlainHTTP, s.repo.Reference, desc.Digest, fromRepo) req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) if err != nil { return err } - resp, err := s.repo.client().Do(req) + resp, err := s.repo.do(req) if err != nil { return err } @@ -715,7 +792,7 @@ func (s *blobStore) Mount(ctx context.Context, desc ocispec.Descriptor, fromRepo // push it. If the caller has provided a getContent function, we // can use that, otherwise pull the content from the source repository. // - // [spec]: https://github.com/opencontainers/distribution-spec/blob/main/spec.md#mounting-a-blob-from-another-repository + // [spec]: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#mounting-a-blob-from-another-repository var r io.ReadCloser if getContent != nil { @@ -733,10 +810,10 @@ func (s *blobStore) Mount(ctx context.Context, desc ocispec.Descriptor, fromRepo // sibling returns a blob store for another repository in the same // registry. func (s *blobStore) sibling(otherRepoName string) *blobStore { - otherRepo := *s.repo + otherRepo := s.repo.clone() otherRepo.Reference.Repository = otherRepoName return &blobStore{ - repo: &otherRepo, + repo: otherRepo, } } @@ -746,22 +823,23 @@ func (s *blobStore) sibling(otherRepoName string) *blobStore { // Push is done by conventional 2-step monolithic upload instead of a single // `POST` request for better overall performance. It also allows early fail on // authentication errors. +// // References: -// - https://docs.docker.com/registry/spec/api/#pushing-an-image -// - https://docs.docker.com/registry/spec/api/#initiate-blob-upload -// - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc1/spec.md#pushing-a-blob-monolithically +// - https://docs.docker.com/registry/spec/api/#pushing-an-image +// - https://docs.docker.com/registry/spec/api/#initiate-blob-upload +// - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#pushing-a-blob-monolithically func (s *blobStore) Push(ctx context.Context, expected ocispec.Descriptor, content io.Reader) error { // start an upload // pushing usually requires both pull and push actions. // Reference: https://github.com/distribution/distribution/blob/v2.7.1/registry/handlers/app.go#L921-L930 - ctx = registryutil.WithScopeHint(ctx, s.repo.Reference, auth.ActionPull, auth.ActionPush) + ctx = auth.AppendRepositoryScope(ctx, s.repo.Reference, auth.ActionPull, auth.ActionPush) url := buildRepositoryBlobUploadURL(s.repo.PlainHTTP, s.repo.Reference) req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) if err != nil { return err } - resp, err := s.repo.client().Do(req) + resp, err := s.repo.do(req) if err != nil { return err } @@ -816,7 +894,7 @@ func (s *blobStore) completePushAfterInitialPost(ctx context.Context, req *http. if auth := resp.Request.Header.Get("Authorization"); auth != "" { req.Header.Set("Authorization", auth) } - resp, err = s.repo.client().Do(req) + resp, err = s.repo.do(req) if err != nil { return err } @@ -855,14 +933,14 @@ func (s *blobStore) Resolve(ctx context.Context, reference string) (ocispec.Desc if err != nil { return ocispec.Descriptor{}, err } - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull) url := buildRepositoryBlobURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil) if err != nil { return ocispec.Descriptor{}, err } - resp, err := s.repo.client().Do(req) + resp, err := s.repo.do(req) if err != nil { return ocispec.Descriptor{}, err } @@ -890,14 +968,14 @@ func (s *blobStore) FetchReference(ctx context.Context, reference string) (desc return ocispec.Descriptor{}, nil, err } - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull) url := buildRepositoryBlobURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return ocispec.Descriptor{}, nil, err } - resp, err := s.repo.client().Do(req) + resp, err := s.repo.do(req) if err != nil { return ocispec.Descriptor{}, nil, err } @@ -965,7 +1043,7 @@ type manifestStore struct { func (s *manifestStore) Fetch(ctx context.Context, target ocispec.Descriptor) (rc io.ReadCloser, err error) { ref := s.repo.Reference ref.Reference = target.Digest.String() - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull) url := buildRepositoryManifestURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { @@ -973,7 +1051,7 @@ func (s *manifestStore) Fetch(ctx context.Context, target ocispec.Descriptor) (r } req.Header.Set("Accept", target.MediaType) - resp, err := s.repo.client().Do(req) + resp, err := s.repo.do(req) if err != nil { return nil, err } @@ -1032,7 +1110,8 @@ func (s *manifestStore) Delete(ctx context.Context, target ocispec.Descriptor) e // deleteWithIndexing removes the manifest content identified by the descriptor, // and indexes referrers for the manifest when needed. func (s *manifestStore) deleteWithIndexing(ctx context.Context, target ocispec.Descriptor) error { - if target.MediaType == spec.MediaTypeArtifactManifest || target.MediaType == ocispec.MediaTypeImageManifest { + switch target.MediaType { + case spec.MediaTypeArtifactManifest, ocispec.MediaTypeImageManifest, ocispec.MediaTypeImageIndex: if state := s.repo.loadReferrersState(); state == referrersStateSupported { // referrers API is available, no client-side indexing needed return s.repo.delete(ctx, target, true) @@ -1041,6 +1120,7 @@ func (s *manifestStore) deleteWithIndexing(ctx context.Context, target ocispec.D if err := limitSize(target, s.repo.MaxMetadataBytes); err != nil { return err } + ctx = auth.AppendRepositoryScope(ctx, s.repo.Reference, auth.ActionPull, auth.ActionDelete) manifestJSON, err := content.FetchAll(ctx, s, target) if err != nil { return err @@ -1053,9 +1133,12 @@ func (s *manifestStore) deleteWithIndexing(ctx context.Context, target ocispec.D return s.repo.delete(ctx, target, true) } -// indexReferrersForDelete indexes referrers for image or artifact manifest with -// the subject field on manifest delete. -// Reference: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc1/spec.md#deleting-manifests +// indexReferrersForDelete indexes referrers for manifests with a subject field +// on manifest delete. +// +// References: +// - Latest spec: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#deleting-manifests +// - Compatible spec: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc1/spec.md#deleting-manifests func (s *manifestStore) indexReferrersForDelete(ctx context.Context, desc ocispec.Descriptor, manifestJSON []byte) error { var manifest struct { Subject *ocispec.Descriptor `json:"subject"` @@ -1087,7 +1170,7 @@ func (s *manifestStore) Resolve(ctx context.Context, reference string) (ocispec. if err != nil { return ocispec.Descriptor{}, err } - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull) url := buildRepositoryManifestURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil) if err != nil { @@ -1095,7 +1178,7 @@ func (s *manifestStore) Resolve(ctx context.Context, reference string) (ocispec. } req.Header.Set("Accept", manifestAcceptHeader(s.repo.ManifestMediaTypes)) - resp, err := s.repo.client().Do(req) + resp, err := s.repo.do(req) if err != nil { return ocispec.Descriptor{}, err } @@ -1119,7 +1202,7 @@ func (s *manifestStore) FetchReference(ctx context.Context, reference string) (d return ocispec.Descriptor{}, nil, err } - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull) url := buildRepositoryManifestURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { @@ -1127,7 +1210,7 @@ func (s *manifestStore) FetchReference(ctx context.Context, reference string) (d } req.Header.Set("Accept", manifestAcceptHeader(s.repo.ManifestMediaTypes)) - resp, err := s.repo.client().Do(req) + resp, err := s.repo.do(req) if err != nil { return ocispec.Descriptor{}, nil, err } @@ -1162,7 +1245,7 @@ func (s *manifestStore) Tag(ctx context.Context, desc ocispec.Descriptor, refere return err } - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull, auth.ActionPush) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull, auth.ActionPush) rc, err := s.Fetch(ctx, desc) if err != nil { return err @@ -1187,7 +1270,7 @@ func (s *manifestStore) push(ctx context.Context, expected ocispec.Descriptor, c ref.Reference = reference // pushing usually requires both pull and push actions. // Reference: https://github.com/distribution/distribution/blob/v2.7.1/registry/handlers/app.go#L921-L930 - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull, auth.ActionPush) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull, auth.ActionPush) url := buildRepositoryManifestURL(s.repo.PlainHTTP, ref) // unwrap the content for optimizations of built-in types. body := ioutil.UnwrapNopCloser(content) @@ -1225,7 +1308,7 @@ func (s *manifestStore) push(ctx context.Context, expected ocispec.Descriptor, c return err } } - resp, err := client.Do(req) + resp, err := s.repo.do(req) if err != nil { return err } @@ -1234,14 +1317,26 @@ func (s *manifestStore) push(ctx context.Context, expected ocispec.Descriptor, c if resp.StatusCode != http.StatusCreated { return errutil.ParseErrorResponse(resp) } + s.checkOCISubjectHeader(resp) return verifyContentDigest(resp, expected.Digest) } +// checkOCISubjectHeader checks the "OCI-Subject" header in the response and +// sets referrers capability accordingly. +// Reference: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#pushing-manifests-with-subject +func (s *manifestStore) checkOCISubjectHeader(resp *http.Response) { + // Referrers capability is not set to false when the subject header is not + // present, as the server may still conform to an older version of the spec + if subjectHeader := resp.Header.Get(headerOCISubject); subjectHeader != "" { + s.repo.SetReferrersCapability(true) + } +} + // pushWithIndexing pushes the manifest content matching the expected descriptor, // and indexes referrers for the manifest when needed. func (s *manifestStore) pushWithIndexing(ctx context.Context, expected ocispec.Descriptor, r io.Reader, reference string) error { switch expected.MediaType { - case spec.MediaTypeArtifactManifest, ocispec.MediaTypeImageManifest: + case spec.MediaTypeArtifactManifest, ocispec.MediaTypeImageManifest, ocispec.MediaTypeImageIndex: if state := s.repo.loadReferrersState(); state == referrersStateSupported { // referrers API is available, no client-side indexing needed return s.push(ctx, expected, r, reference) @@ -1257,15 +1352,22 @@ func (s *manifestStore) pushWithIndexing(ctx context.Context, expected ocispec.D if err := s.push(ctx, expected, bytes.NewReader(manifestJSON), reference); err != nil { return err } + // check referrers API availability again after push + if state := s.repo.loadReferrersState(); state == referrersStateSupported { + return nil + } return s.indexReferrersForPush(ctx, expected, manifestJSON) default: return s.push(ctx, expected, r, reference) } } -// indexReferrersForPush indexes referrers for image or artifact manifest with -// the subject field on manifest push. -// Reference: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc1/spec.md#pushing-manifests-with-subject +// indexReferrersForPush indexes referrers for manifests with a subject field +// on manifest push. +// +// References: +// - Latest spec: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#pushing-manifests-with-subject +// - Compatible spec: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc1/spec.md#pushing-manifests-with-subject func (s *manifestStore) indexReferrersForPush(ctx context.Context, desc ocispec.Descriptor, manifestJSON []byte) error { var subject ocispec.Descriptor switch desc.MediaType { @@ -1291,7 +1393,22 @@ func (s *manifestStore) indexReferrersForPush(ctx context.Context, desc ocispec. return nil } subject = *manifest.Subject - desc.ArtifactType = manifest.Config.MediaType + desc.ArtifactType = manifest.ArtifactType + if desc.ArtifactType == "" { + desc.ArtifactType = manifest.Config.MediaType + } + desc.Annotations = manifest.Annotations + case ocispec.MediaTypeImageIndex: + var manifest ocispec.Index + if err := json.Unmarshal(manifestJSON, &manifest); err != nil { + return fmt.Errorf("failed to decode manifest: %s: %s: %w", desc.Digest, desc.MediaType, err) + } + if manifest.Subject == nil { + // no subject, no indexing needed + return nil + } + subject = *manifest.Subject + desc.ArtifactType = manifest.ArtifactType desc.Annotations = manifest.Annotations default: return nil @@ -1311,31 +1428,30 @@ func (s *manifestStore) indexReferrersForPush(ctx context.Context, desc ocispec. // updateReferrersIndex updates the referrers index for desc referencing subject // on manifest push and manifest delete. // References: -// - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc1/spec.md#pushing-manifests-with-subject -// - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc1/spec.md#deleting-manifests +// - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#pushing-manifests-with-subject +// - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#deleting-manifests func (s *manifestStore) updateReferrersIndex(ctx context.Context, subject ocispec.Descriptor, change referrerChange) (err error) { referrersTag := buildReferrersTag(subject) - var skipDelete bool - var oldIndexDesc ocispec.Descriptor - var referrers []ocispec.Descriptor + var oldIndexDesc *ocispec.Descriptor + var oldReferrers []ocispec.Descriptor prepare := func() error { // 1. pull the original referrers list using the referrers tag schema - var err error - oldIndexDesc, referrers, err = s.repo.referrersFromIndex(ctx, referrersTag) + indexDesc, referrers, err := s.repo.referrersFromIndex(ctx, referrersTag) if err != nil { if errors.Is(err, errdef.ErrNotFound) { - // no old index found, skip delete - skipDelete = true + // valid case: no old referrers index return nil } return err } + oldIndexDesc = &indexDesc + oldReferrers = referrers return nil } update := func(referrerChanges []referrerChange) error { // 2. apply the referrer changes on the referrers list - updatedReferrers, err := applyReferrerChanges(referrers, referrerChanges) + updatedReferrers, err := applyReferrerChanges(oldReferrers, referrerChanges) if err != nil { if err == errNoReferrerUpdate { return nil @@ -1344,7 +1460,12 @@ func (s *manifestStore) updateReferrersIndex(ctx context.Context, subject ocispe } // 3. push the updated referrers list using referrers tag schema - if len(updatedReferrers) > 0 { + if len(updatedReferrers) > 0 || s.repo.SkipReferrersGC { + // push a new index in either case: + // 1. the referrers list has been updated with a non-zero size + // 2. OR the updated referrers list is empty but referrers GC + // is skipped, in this case an empty index should still be pushed + // as the old index won't get deleted newIndexDesc, newIndex, err := generateIndex(updatedReferrers) if err != nil { return fmt.Errorf("failed to generate referrers index for referrers tag %s: %w", referrersTag, err) @@ -1354,14 +1475,15 @@ func (s *manifestStore) updateReferrersIndex(ctx context.Context, subject ocispe } } - // 4. delete the dangling original referrers index - if !skipDelete { - if err := s.repo.delete(ctx, oldIndexDesc, true); err != nil { - return &ReferrersError{ - Op: opDeleteReferrersIndex, - Err: fmt.Errorf("failed to delete dangling referrers index %s for referrers tag %s: %w", oldIndexDesc.Digest.String(), referrersTag, err), - Subject: subject, - } + // 4. delete the dangling original referrers index, if applicable + if s.repo.SkipReferrersGC || oldIndexDesc == nil { + return nil + } + if err := s.repo.delete(ctx, *oldIndexDesc, true); err != nil { + return &ReferrersError{ + Op: opDeleteReferrersIndex, + Err: fmt.Errorf("failed to delete dangling referrers index %s for referrers tag %s: %w", oldIndexDesc.Digest.String(), referrersTag, err), + Subject: subject, } } return nil @@ -1408,13 +1530,13 @@ func (s *manifestStore) generateDescriptor(resp *http.Response, ref registry.Ref // 4. Validate Server Digest (if present) var serverHeaderDigest digest.Digest - if serverHeaderDigestStr := resp.Header.Get(dockerContentDigestHeader); serverHeaderDigestStr != "" { + if serverHeaderDigestStr := resp.Header.Get(headerDockerContentDigest); serverHeaderDigestStr != "" { if serverHeaderDigest, err = digest.Parse(serverHeaderDigestStr); err != nil { return ocispec.Descriptor{}, fmt.Errorf( "%s %q: invalid response header value: `%s: %s`; %w", resp.Request.Method, resp.Request.URL, - dockerContentDigestHeader, + headerDockerContentDigest, serverHeaderDigestStr, err, ) @@ -1431,7 +1553,7 @@ func (s *manifestStore) generateDescriptor(resp *http.Response, ref registry.Ref // immediate fail return ocispec.Descriptor{}, fmt.Errorf( "HTTP %s request missing required header %q", - httpMethod, dockerContentDigestHeader, + httpMethod, headerDockerContentDigest, ) } // Otherwise, just trust the client-supplied digest @@ -1453,7 +1575,7 @@ func (s *manifestStore) generateDescriptor(resp *http.Response, ref registry.Ref return ocispec.Descriptor{}, fmt.Errorf( "%s %q: invalid response; digest mismatch in %s: received %q when expecting %q", resp.Request.Method, resp.Request.URL, - dockerContentDigestHeader, contentDigest, + headerDockerContentDigest, contentDigest, refDigest, ) } @@ -1485,7 +1607,7 @@ func calculateDigestFromResponse(resp *http.Response, maxMetadataBytes int64) (d // OCI distribution-spec states the Docker-Content-Digest header is optional. // Reference: https://github.com/opencontainers/distribution-spec/blob/v1.0.1/spec.md#legacy-docker-support-http-headers func verifyContentDigest(resp *http.Response, expected digest.Digest) error { - digestStr := resp.Header.Get(dockerContentDigestHeader) + digestStr := resp.Header.Get(headerDockerContentDigest) if len(digestStr) == 0 { return nil @@ -1496,7 +1618,7 @@ func verifyContentDigest(resp *http.Response, expected digest.Digest) error { return fmt.Errorf( "%s %q: invalid response header: `%s: %s`", resp.Request.Method, resp.Request.URL, - dockerContentDigestHeader, digestStr, + headerDockerContentDigest, digestStr, ) } @@ -1504,7 +1626,7 @@ func verifyContentDigest(resp *http.Response, expected digest.Digest) error { return fmt.Errorf( "%s %q: invalid response; digest mismatch in %s: received %q when expecting %q", resp.Request.Method, resp.Request.URL, - dockerContentDigestHeader, contentDigest, + headerDockerContentDigest, contentDigest, expected, ) } diff --git a/registry/remote/repository_test.go b/registry/remote/repository_test.go index f062509d..b6772cbd 100644 --- a/registry/remote/repository_test.go +++ b/registry/remote/repository_test.go @@ -321,7 +321,7 @@ func TestRepository_Mount(t *testing.T) { t.Errorf("unexpected value for 'from' parameter; got %q want %q", got, want) } gotMount++ - w.Header().Set(dockerContentDigestHeader, blobDesc.Digest.String()) + w.Header().Set(headerDockerContentDigest, blobDesc.Digest.String()) w.WriteHeader(201) return default: @@ -693,29 +693,36 @@ func TestRepository_Delete(t *testing.T) { Digest: digest.FromBytes(blob), Size: int64(len(blob)), } - blobDeleted := false index := []byte(`{"manifests":[]}`) indexDesc := ocispec.Descriptor{ MediaType: ocispec.MediaTypeImageIndex, Digest: digest.FromBytes(index), Size: int64(len(index)), } - indexDeleted := false + + var blobDeleted bool + var indexDeleted bool ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodDelete { - t.Errorf("unexpected access: %s %s", r.Method, r.URL) - w.WriteHeader(http.StatusMethodNotAllowed) - return - } - switch r.URL.Path { - case "/v2/test/blobs/" + blobDesc.Digest.String(): + switch { + case r.Method == http.MethodDelete && r.URL.Path == "/v2/test/blobs/"+blobDesc.Digest.String(): blobDeleted = true w.Header().Set("Docker-Content-Digest", blobDesc.Digest.String()) w.WriteHeader(http.StatusAccepted) - case "/v2/test/manifests/" + indexDesc.Digest.String(): + case r.Method == http.MethodDelete && r.URL.Path == "/v2/test/manifests/"+indexDesc.Digest.String(): indexDeleted = true // no "Docker-Content-Digest" header for manifest deletion w.WriteHeader(http.StatusAccepted) + case r.Method == http.MethodGet && r.URL.Path == "/v2/test/manifests/"+indexDesc.Digest.String(): + if accept := r.Header.Get("Accept"); !strings.Contains(accept, indexDesc.MediaType) { + t.Errorf("manifest not convertable: %s", accept) + w.WriteHeader(http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", indexDesc.MediaType) + w.Header().Set("Docker-Content-Digest", indexDesc.Digest.String()) + if _, err := w.Write(index); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } default: t.Errorf("unexpected access: %s %s", r.Method, r.URL) w.WriteHeader(http.StatusNotFound) @@ -1892,6 +1899,8 @@ func TestRepository_Referrers_ServerFiltering(t *testing.T) { }, }, } + + // Test with filter annotations only var ts *httptest.Server ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { path := "/v2/test/referrers/" + manifestDesc.Digest.String() @@ -1930,6 +1939,7 @@ func TestRepository_Referrers_ServerFiltering(t *testing.T) { }, MediaType: ocispec.MediaTypeImageIndex, Manifests: referrers, + // set filter annotations Annotations: map[string]string{ spec.AnnotationReferrersFiltersApplied: "artifactType", }, @@ -1969,6 +1979,164 @@ func TestRepository_Referrers_ServerFiltering(t *testing.T) { if index != len(referrerSet) { t.Errorf("fn invoked %d time(s), want %d", index, len(referrerSet)) } + + // Test with filter header only + ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + path := "/v2/test/referrers/" + manifestDesc.Digest.String() + queryParams, err := url.ParseQuery(r.URL.RawQuery) + if err != nil { + t.Fatal("failed to parse url query") + } + if r.Method != http.MethodGet || + r.URL.Path != path || + reflect.DeepEqual(queryParams["artifactType"], []string{"application%2Fvnd.test"}) { + t.Errorf("unexpected access: %s %q", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + q := r.URL.Query() + n, err := strconv.Atoi(q.Get("n")) + if err != nil || n != 2 { + t.Errorf("bad page size: %s", q.Get("n")) + w.WriteHeader(http.StatusBadRequest) + return + } + var referrers []ocispec.Descriptor + switch q.Get("test") { + case "foo": + referrers = referrerSet[1] + w.Header().Set("Link", fmt.Sprintf(`<%s%s?n=2&test=bar>; rel="next"`, ts.URL, path)) + case "bar": + referrers = referrerSet[2] + default: + referrers = referrerSet[0] + w.Header().Set("Link", fmt.Sprintf(`<%s?n=2&test=foo>; rel="next"`, path)) + } + result := ocispec.Index{ + Versioned: specs.Versioned{ + SchemaVersion: 2, // historical value. does not pertain to OCI or docker version + }, + MediaType: ocispec.MediaTypeImageIndex, + Manifests: referrers, + } + // set filter header + w.Header().Set("OCI-Filters-Applied", "artifactType") + if err := json.NewEncoder(w).Encode(result); err != nil { + t.Errorf("failed to write response: %v", err) + } + })) + defer ts.Close() + uri, err = url.Parse(ts.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + + repo, err = NewRepository(uri.Host + "/test") + if err != nil { + t.Fatalf("NewRepository() error = %v", err) + } + repo.PlainHTTP = true + repo.ReferrerListPageSize = 2 + + ctx = context.Background() + index = 0 + if err := repo.Referrers(ctx, manifestDesc, "application/vnd.test", func(got []ocispec.Descriptor) error { + if index >= len(referrerSet) { + t.Fatalf("out of index bound: %d", index) + } + referrers := referrerSet[index] + index++ + if !reflect.DeepEqual(got, referrers) { + t.Errorf("Repository.Referrers() = %v, want %v", got, referrers) + } + return nil + }); err != nil { + t.Errorf("Repository.Referrers() error = %v", err) + } + if index != len(referrerSet) { + t.Errorf("fn invoked %d time(s), want %d", index, len(referrerSet)) + } + + // Test with both filter annotation and filter header + ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + path := "/v2/test/referrers/" + manifestDesc.Digest.String() + queryParams, err := url.ParseQuery(r.URL.RawQuery) + if err != nil { + t.Fatal("failed to parse url query") + } + if r.Method != http.MethodGet || + r.URL.Path != path || + reflect.DeepEqual(queryParams["artifactType"], []string{"application%2Fvnd.test"}) { + t.Errorf("unexpected access: %s %q", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + q := r.URL.Query() + n, err := strconv.Atoi(q.Get("n")) + if err != nil || n != 2 { + t.Errorf("bad page size: %s", q.Get("n")) + w.WriteHeader(http.StatusBadRequest) + return + } + var referrers []ocispec.Descriptor + switch q.Get("test") { + case "foo": + referrers = referrerSet[1] + w.Header().Set("Link", fmt.Sprintf(`<%s%s?n=2&test=bar>; rel="next"`, ts.URL, path)) + case "bar": + referrers = referrerSet[2] + default: + referrers = referrerSet[0] + w.Header().Set("Link", fmt.Sprintf(`<%s?n=2&test=foo>; rel="next"`, path)) + } + result := ocispec.Index{ + Versioned: specs.Versioned{ + SchemaVersion: 2, // historical value. does not pertain to OCI or docker version + }, + MediaType: ocispec.MediaTypeImageIndex, + Manifests: referrers, + // set filter annotation + Annotations: map[string]string{ + spec.AnnotationReferrersFiltersApplied: "artifactType", + }, + } + // set filter header + w.Header().Set("OCI-Filters-Applied", "artifactType") + if err := json.NewEncoder(w).Encode(result); err != nil { + t.Errorf("failed to write response: %v", err) + } + })) + defer ts.Close() + uri, err = url.Parse(ts.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + + repo, err = NewRepository(uri.Host + "/test") + if err != nil { + t.Fatalf("NewRepository() error = %v", err) + } + repo.PlainHTTP = true + repo.ReferrerListPageSize = 2 + + ctx = context.Background() + index = 0 + if err := repo.Referrers(ctx, manifestDesc, "application/vnd.test", func(got []ocispec.Descriptor) error { + if index >= len(referrerSet) { + t.Fatalf("out of index bound: %d", index) + } + referrers := referrerSet[index] + index++ + if !reflect.DeepEqual(got, referrers) { + t.Errorf("Repository.Referrers() = %v, want %v", got, referrers) + } + return nil + }); err != nil { + t.Errorf("Repository.Referrers() error = %v", err) + } + if index != len(referrerSet) { + t.Errorf("fn invoked %d time(s), want %d", index, len(referrerSet)) + } } func TestRepository_Referrers_ClientFiltering(t *testing.T) { @@ -2940,7 +3108,7 @@ func Test_generateBlobDescriptorWithVariousDockerContentDigestHeaders(t *testing resp := http.Response{ Header: http.Header{ "Content-Type": []string{"application/vnd.docker.distribution.manifest.v2+json"}, - dockerContentDigestHeader: []string{dcdIOStruct.serverCalculatedDigest.String()}, + headerDockerContentDigest: []string{dcdIOStruct.serverCalculatedDigest.String()}, }, } if method == http.MethodGet { @@ -3126,7 +3294,7 @@ func Test_ManifestStore_Push_ReferrersAPIAvailable(t *testing.T) { } artifactJSON, err := json.Marshal(artifact) if err != nil { - t.Errorf("failed to marshal manifest: %v", err) + t.Fatalf("failed to marshal manifest: %v", err) } artifactDesc := content.NewDescriptorFromBytes(artifact.MediaType, artifactJSON) manifest := ocispec.Manifest{ @@ -3135,9 +3303,18 @@ func Test_ManifestStore_Push_ReferrersAPIAvailable(t *testing.T) { } manifestJSON, err := json.Marshal(manifest) if err != nil { - t.Errorf("failed to marshal manifest: %v", err) + t.Fatalf("failed to marshal manifest: %v", err) } manifestDesc := content.NewDescriptorFromBytes(manifest.MediaType, manifestJSON) + index := ocispec.Index{ + MediaType: ocispec.MediaTypeImageIndex, + Subject: &subjectDesc, + } + indexJSON, err := json.Marshal(index) + if err != nil { + t.Fatalf("failed to marshal manifest: %v", err) + } + indexDesc := content.NewDescriptorFromBytes(manifest.MediaType, indexJSON) var gotManifest []byte ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -3153,6 +3330,7 @@ func Test_ManifestStore_Push_ReferrersAPIAvailable(t *testing.T) { } gotManifest = buf.Bytes() w.Header().Set("Docker-Content-Digest", artifactDesc.Digest.String()) + w.Header().Set("OCI-Subject", subjectDesc.Digest.String()) w.WriteHeader(http.StatusCreated) case r.Method == http.MethodPut && r.URL.Path == "/v2/test/manifests/"+manifestDesc.Digest.String(): if contentType := r.Header.Get("Content-Type"); contentType != manifestDesc.MediaType { @@ -3165,18 +3343,21 @@ func Test_ManifestStore_Push_ReferrersAPIAvailable(t *testing.T) { } gotManifest = buf.Bytes() w.Header().Set("Docker-Content-Digest", manifestDesc.Digest.String()) + w.Header().Set("OCI-Subject", subjectDesc.Digest.String()) w.WriteHeader(http.StatusCreated) - case r.Method == http.MethodGet && r.URL.Path == "/v2/test/referrers/"+zeroDigest: - result := ocispec.Index{ - Versioned: specs.Versioned{ - SchemaVersion: 2, // historical value. does not pertain to OCI or docker version - }, - MediaType: ocispec.MediaTypeImageIndex, - Manifests: []ocispec.Descriptor{}, + case r.Method == http.MethodPut && r.URL.Path == "/v2/test/manifests/"+indexDesc.Digest.String(): + if contentType := r.Header.Get("Content-Type"); contentType != indexDesc.MediaType { + w.WriteHeader(http.StatusBadRequest) + break } - if err := json.NewEncoder(w).Encode(result); err != nil { - t.Errorf("failed to write response: %v", err) + buf := bytes.NewBuffer(nil) + if _, err := buf.ReadFrom(r.Body); err != nil { + t.Errorf("fail to read: %v", err) } + gotManifest = buf.Bytes() + w.Header().Set("Docker-Content-Digest", indexDesc.Digest.String()) + w.Header().Set("OCI-Subject", subjectDesc.Digest.String()) + w.WriteHeader(http.StatusCreated) default: t.Errorf("unexpected access: %s %s", r.Method, r.URL) w.WriteHeader(http.StatusNotFound) @@ -3187,15 +3368,14 @@ func Test_ManifestStore_Push_ReferrersAPIAvailable(t *testing.T) { if err != nil { t.Fatalf("invalid test http server: %v", err) } - ctx := context.Background() + + // test pushing artifact with subject repo, err := NewRepository(uri.Host + "/test") if err != nil { t.Fatalf("NewRepository() error = %v", err) } repo.PlainHTTP = true - - // test push artifact with subject if state := repo.loadReferrersState(); state != referrersStateUnknown { t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnknown) } @@ -3206,11 +3386,19 @@ func Test_ManifestStore_Push_ReferrersAPIAvailable(t *testing.T) { if !bytes.Equal(gotManifest, artifactJSON) { t.Errorf("Manifests.Push() = %v, want %v", string(gotManifest), string(artifactJSON)) } - - // test push image manifest with subject if state := repo.loadReferrersState(); state != referrersStateSupported { t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateSupported) } + + // test pushing image manifest with subject + repo, err = NewRepository(uri.Host + "/test") + if err != nil { + t.Fatalf("NewRepository() error = %v", err) + } + repo.PlainHTTP = true + if state := repo.loadReferrersState(); state != referrersStateUnknown { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnknown) + } err = repo.Push(ctx, manifestDesc, bytes.NewReader(manifestJSON)) if err != nil { t.Fatalf("Manifests.Push() error = %v", err) @@ -3218,44 +3406,56 @@ func Test_ManifestStore_Push_ReferrersAPIAvailable(t *testing.T) { if !bytes.Equal(gotManifest, manifestJSON) { t.Errorf("Manifests.Push() = %v, want %v", string(gotManifest), string(manifestJSON)) } + if state := repo.loadReferrersState(); state != referrersStateSupported { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateSupported) + } + + // test pushing image index with subject + err = repo.Push(ctx, indexDesc, bytes.NewReader(indexJSON)) + if err != nil { + t.Fatalf("Manifests.Push() error = %v", err) + } + if !bytes.Equal(gotManifest, indexJSON) { + t.Errorf("Manifests.Push() = %v, want %v", string(gotManifest), string(indexJSON)) + } + if state := repo.loadReferrersState(); state != referrersStateSupported { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateSupported) + } } -func Test_ManifestStore_Push_ReferrersAPIUnavailable(t *testing.T) { +func Test_ManifestStore_Push_ReferrersAPIAvailable_NoSubjectHeader(t *testing.T) { // generate test content subject := []byte(`{"layers":[]}`) subjectDesc := content.NewDescriptorFromBytes(spec.MediaTypeArtifactManifest, subject) - referrersTag := strings.Replace(subjectDesc.Digest.String(), ":", "-", 1) artifact := spec.Artifact{ - MediaType: spec.MediaTypeArtifactManifest, - Subject: &subjectDesc, - ArtifactType: "application/vnd.test", - Annotations: map[string]string{"foo": "bar"}, + MediaType: spec.MediaTypeArtifactManifest, + Subject: &subjectDesc, } artifactJSON, err := json.Marshal(artifact) if err != nil { - t.Errorf("failed to marshal manifest: %v", err) + t.Fatalf("failed to marshal manifest: %v", err) } artifactDesc := content.NewDescriptorFromBytes(artifact.MediaType, artifactJSON) - artifactDesc.ArtifactType = artifact.ArtifactType - artifactDesc.Annotations = artifact.Annotations - - // test push artifact with subject - index_1 := ocispec.Index{ - Versioned: specs.Versioned{ - SchemaVersion: 2, // historical value. does not pertain to OCI or docker version - }, + manifest := ocispec.Manifest{ + MediaType: ocispec.MediaTypeImageManifest, + Subject: &subjectDesc, + } + manifestJSON, err := json.Marshal(manifest) + if err != nil { + t.Fatalf("failed to marshal manifest: %v", err) + } + manifestDesc := content.NewDescriptorFromBytes(manifest.MediaType, manifestJSON) + index := ocispec.Index{ MediaType: ocispec.MediaTypeImageIndex, - Manifests: []ocispec.Descriptor{ - artifactDesc, - }, + Subject: &subjectDesc, } - indexJSON_1, err := json.Marshal(index_1) + indexJSON, err := json.Marshal(index) if err != nil { - t.Errorf("failed to marshal manifest: %v", err) + t.Fatalf("failed to marshal manifest: %v", err) } - indexDesc_1 := content.NewDescriptorFromBytes(index_1.MediaType, indexJSON_1) + indexDesc := content.NewDescriptorFromBytes(manifest.MediaType, indexJSON) + var gotManifest []byte - var gotReferrerIndex []byte ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { case r.Method == http.MethodPut && r.URL.Path == "/v2/test/manifests/"+artifactDesc.Digest.String(): @@ -3270,8 +3470,158 @@ func Test_ManifestStore_Push_ReferrersAPIUnavailable(t *testing.T) { gotManifest = buf.Bytes() w.Header().Set("Docker-Content-Digest", artifactDesc.Digest.String()) w.WriteHeader(http.StatusCreated) - case r.Method == http.MethodGet && r.URL.Path == "/v2/test/referrers/"+zeroDigest: - w.WriteHeader(http.StatusNotFound) + case r.Method == http.MethodPut && r.URL.Path == "/v2/test/manifests/"+manifestDesc.Digest.String(): + if contentType := r.Header.Get("Content-Type"); contentType != manifestDesc.MediaType { + w.WriteHeader(http.StatusBadRequest) + break + } + buf := bytes.NewBuffer(nil) + if _, err := buf.ReadFrom(r.Body); err != nil { + t.Errorf("fail to read: %v", err) + } + gotManifest = buf.Bytes() + w.Header().Set("Docker-Content-Digest", manifestDesc.Digest.String()) + w.WriteHeader(http.StatusCreated) + case r.Method == http.MethodPut && r.URL.Path == "/v2/test/manifests/"+indexDesc.Digest.String(): + if contentType := r.Header.Get("Content-Type"); contentType != indexDesc.MediaType { + w.WriteHeader(http.StatusBadRequest) + break + } + buf := bytes.NewBuffer(nil) + if _, err := buf.ReadFrom(r.Body); err != nil { + t.Errorf("fail to read: %v", err) + } + gotManifest = buf.Bytes() + w.Header().Set("Docker-Content-Digest", indexDesc.Digest.String()) + w.WriteHeader(http.StatusCreated) + case r.Method == http.MethodGet && r.URL.Path == "/v2/test/referrers/"+zeroDigest: + result := ocispec.Index{ + Versioned: specs.Versioned{ + SchemaVersion: 2, // historical value. does not pertain to OCI or docker version + }, + MediaType: ocispec.MediaTypeImageIndex, + Manifests: []ocispec.Descriptor{}, + } + if err := json.NewEncoder(w).Encode(result); err != nil { + t.Errorf("failed to write response: %v", err) + } + default: + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + } + })) + defer ts.Close() + uri, err := url.Parse(ts.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + ctx := context.Background() + + // test pushing artifact with subject + repo, err := NewRepository(uri.Host + "/test") + if err != nil { + t.Fatalf("NewRepository() error = %v", err) + } + repo.PlainHTTP = true + if state := repo.loadReferrersState(); state != referrersStateUnknown { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnknown) + } + err = repo.Push(ctx, artifactDesc, bytes.NewReader(artifactJSON)) + if err != nil { + t.Fatalf("Manifests.Push() error = %v", err) + } + if !bytes.Equal(gotManifest, artifactJSON) { + t.Errorf("Manifests.Push() = %v, want %v", string(gotManifest), string(artifactJSON)) + } + if state := repo.loadReferrersState(); state != referrersStateSupported { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateSupported) + } + + // test pushing image manifest with subject + repo, err = NewRepository(uri.Host + "/test") + if err != nil { + t.Fatalf("NewRepository() error = %v", err) + } + repo.PlainHTTP = true + if state := repo.loadReferrersState(); state != referrersStateUnknown { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnknown) + } + err = repo.Push(ctx, manifestDesc, bytes.NewReader(manifestJSON)) + if err != nil { + t.Fatalf("Manifests.Push() error = %v", err) + } + if !bytes.Equal(gotManifest, manifestJSON) { + t.Errorf("Manifests.Push() = %v, want %v", string(gotManifest), string(manifestJSON)) + } + if state := repo.loadReferrersState(); state != referrersStateSupported { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateSupported) + } + + // test pushing image index with subject + err = repo.Push(ctx, indexDesc, bytes.NewReader(indexJSON)) + if err != nil { + t.Fatalf("Manifests.Push() error = %v", err) + } + if !bytes.Equal(gotManifest, indexJSON) { + t.Errorf("Manifests.Push() = %v, want %v", string(gotManifest), string(indexJSON)) + } + if state := repo.loadReferrersState(); state != referrersStateSupported { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateSupported) + } +} + +func Test_ManifestStore_Push_ReferrersAPIUnavailable(t *testing.T) { + // generate test content + subject := []byte(`{"layers":[]}`) + subjectDesc := content.NewDescriptorFromBytes(spec.MediaTypeArtifactManifest, subject) + referrersTag := strings.Replace(subjectDesc.Digest.String(), ":", "-", 1) + artifact := spec.Artifact{ + MediaType: spec.MediaTypeArtifactManifest, + Subject: &subjectDesc, + ArtifactType: "application/vnd.test", + Annotations: map[string]string{"foo": "bar"}, + } + artifactJSON, err := json.Marshal(artifact) + if err != nil { + t.Fatalf("failed to marshal manifest: %v", err) + } + artifactDesc := content.NewDescriptorFromBytes(artifact.MediaType, artifactJSON) + artifactDesc.ArtifactType = artifact.ArtifactType + artifactDesc.Annotations = artifact.Annotations + + // test pushing artifact with subject, a referrer list should be created + index_1 := ocispec.Index{ + Versioned: specs.Versioned{ + SchemaVersion: 2, // historical value. does not pertain to OCI or docker version + }, + MediaType: ocispec.MediaTypeImageIndex, + Manifests: []ocispec.Descriptor{ + artifactDesc, + }, + } + indexJSON_1, err := json.Marshal(index_1) + if err != nil { + t.Fatalf("failed to marshal manifest: %v", err) + } + indexDesc_1 := content.NewDescriptorFromBytes(index_1.MediaType, indexJSON_1) + var gotManifest []byte + var gotReferrerIndex []byte + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodPut && r.URL.Path == "/v2/test/manifests/"+artifactDesc.Digest.String(): + if contentType := r.Header.Get("Content-Type"); contentType != artifactDesc.MediaType { + w.WriteHeader(http.StatusBadRequest) + break + } + buf := bytes.NewBuffer(nil) + if _, err := buf.ReadFrom(r.Body); err != nil { + t.Errorf("fail to read: %v", err) + } + gotManifest = buf.Bytes() + w.Header().Set("Docker-Content-Digest", artifactDesc.Digest.String()) + w.WriteHeader(http.StatusCreated) + case r.Method == http.MethodGet && r.URL.Path == "/v2/test/referrers/"+zeroDigest: + w.WriteHeader(http.StatusNotFound) case r.Method == http.MethodGet && r.URL.Path == "/v2/test/manifests/"+referrersTag: w.WriteHeader(http.StatusNotFound) case r.Method == http.MethodPut && r.URL.Path == "/v2/test/manifests/"+referrersTag: @@ -3321,7 +3671,93 @@ func Test_ManifestStore_Push_ReferrersAPIUnavailable(t *testing.T) { t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnsupported) } - // test push image manifest with subject, referrer list should be updated + // test pushing artifact with subject when an old empty referrer list exists, + // the referrer list should be updated + emptyIndex := ocispec.Index{ + Versioned: specs.Versioned{ + SchemaVersion: 2, // historical value. does not pertain to OCI or docker version + }, + MediaType: ocispec.MediaTypeImageIndex, + } + emptyIndexJSON, err := json.Marshal(emptyIndex) + if err != nil { + t.Error("failed to marshal index", err) + } + emptyIndexDesc := content.NewDescriptorFromBytes(emptyIndex.MediaType, emptyIndexJSON) + var indexDeleted bool + ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodPut && r.URL.Path == "/v2/test/manifests/"+artifactDesc.Digest.String(): + if contentType := r.Header.Get("Content-Type"); contentType != artifactDesc.MediaType { + w.WriteHeader(http.StatusBadRequest) + break + } + buf := bytes.NewBuffer(nil) + if _, err := buf.ReadFrom(r.Body); err != nil { + t.Errorf("fail to read: %v", err) + } + gotManifest = buf.Bytes() + w.Header().Set("Docker-Content-Digest", artifactDesc.Digest.String()) + w.WriteHeader(http.StatusCreated) + case r.Method == http.MethodGet && r.URL.Path == "/v2/test/referrers/"+zeroDigest: + w.WriteHeader(http.StatusNotFound) + case r.Method == http.MethodGet && r.URL.Path == "/v2/test/manifests/"+referrersTag: + w.Write(emptyIndexJSON) + case r.Method == http.MethodPut && r.URL.Path == "/v2/test/manifests/"+referrersTag: + if contentType := r.Header.Get("Content-Type"); contentType != ocispec.MediaTypeImageIndex { + w.WriteHeader(http.StatusBadRequest) + break + } + buf := bytes.NewBuffer(nil) + if _, err := buf.ReadFrom(r.Body); err != nil { + t.Errorf("fail to read: %v", err) + } + gotReferrerIndex = buf.Bytes() + w.Header().Set("Docker-Content-Digest", indexDesc_1.Digest.String()) + w.WriteHeader(http.StatusCreated) + case r.Method == http.MethodDelete && r.URL.Path == "/v2/test/manifests/"+emptyIndexDesc.Digest.String(): + indexDeleted = true + // no "Docker-Content-Digest" header for manifest deletion + w.WriteHeader(http.StatusAccepted) + default: + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + } + })) + defer ts.Close() + uri, err = url.Parse(ts.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + + ctx = context.Background() + repo, err = NewRepository(uri.Host + "/test") + if err != nil { + t.Fatalf("NewRepository() error = %v", err) + } + repo.PlainHTTP = true + + if state := repo.loadReferrersState(); state != referrersStateUnknown { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnknown) + } + err = repo.Push(ctx, artifactDesc, bytes.NewReader(artifactJSON)) + if err != nil { + t.Fatalf("Manifests.Push() error = %v", err) + } + if !bytes.Equal(gotManifest, artifactJSON) { + t.Errorf("Manifests.Push() = %v, want %v", string(gotManifest), string(artifactJSON)) + } + if !bytes.Equal(gotReferrerIndex, indexJSON_1) { + t.Errorf("got referrers index = %v, want %v", string(gotReferrerIndex), string(indexJSON_1)) + } + if !indexDeleted { + t.Errorf("indexDeleted = %v, want %v", indexDeleted, true) + } + if state := repo.loadReferrersState(); state != referrersStateUnsupported { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnsupported) + } + + // test pushing image manifest with subject, referrer list should be updated manifest := ocispec.Manifest{ MediaType: ocispec.MediaTypeImageManifest, Config: ocispec.Descriptor{ @@ -3332,7 +3768,7 @@ func Test_ManifestStore_Push_ReferrersAPIUnavailable(t *testing.T) { } manifestJSON, err := json.Marshal(manifest) if err != nil { - t.Errorf("failed to marshal manifest: %v", err) + t.Fatalf("failed to marshal manifest: %v", err) } manifestDesc := content.NewDescriptorFromBytes(manifest.MediaType, manifestJSON) manifestDesc.ArtifactType = manifest.Config.MediaType @@ -3349,10 +3785,10 @@ func Test_ManifestStore_Push_ReferrersAPIUnavailable(t *testing.T) { } indexJSON_2, err := json.Marshal(index_2) if err != nil { - t.Errorf("failed to marshal manifest: %v", err) + t.Fatalf("failed to marshal manifest: %v", err) } indexDesc_2 := content.NewDescriptorFromBytes(index_2.MediaType, indexJSON_2) - var manifestDeleted bool + indexDeleted = false ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { case r.Method == http.MethodPut && r.URL.Path == "/v2/test/manifests/"+manifestDesc.Digest.String(): @@ -3384,7 +3820,7 @@ func Test_ManifestStore_Push_ReferrersAPIUnavailable(t *testing.T) { w.Header().Set("Docker-Content-Digest", indexDesc_2.Digest.String()) w.WriteHeader(http.StatusCreated) case r.Method == http.MethodDelete && r.URL.Path == "/v2/test/manifests/"+indexDesc_1.Digest.String(): - manifestDeleted = true + indexDeleted = true // no "Docker-Content-Digest" header for manifest deletion w.WriteHeader(http.StatusAccepted) default: @@ -3417,14 +3853,14 @@ func Test_ManifestStore_Push_ReferrersAPIUnavailable(t *testing.T) { if !bytes.Equal(gotReferrerIndex, indexJSON_2) { t.Errorf("got referrers index = %v, want %v", string(gotReferrerIndex), string(indexJSON_2)) } - if !manifestDeleted { - t.Errorf("manifestDeleted = %v, want %v", manifestDeleted, true) + if !indexDeleted { + t.Errorf("indexDeleted = %v, want %v", indexDeleted, true) } if state := repo.loadReferrersState(); state != referrersStateUnsupported { t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnsupported) } - // test push image manifest with subject again, referrers list should not be changed + // test pushing image manifest with subject again, referrers list should not be changed ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { case r.Method == http.MethodPut && r.URL.Path == "/v2/test/manifests/"+manifestDesc.Digest.String(): @@ -3477,41 +3913,608 @@ func Test_ManifestStore_Push_ReferrersAPIUnavailable(t *testing.T) { if state := repo.loadReferrersState(); state != referrersStateUnsupported { t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnsupported) } -} -func Test_ManifestStore_Exists(t *testing.T) { - manifest := []byte(`{"layers":[]}`) - manifestDesc := ocispec.Descriptor{ - MediaType: ocispec.MediaTypeImageManifest, - Digest: digest.FromBytes(manifest), - Size: int64(len(manifest)), + // push image index with subject, referrer list should be updated + indexManifest := ocispec.Index{ + MediaType: ocispec.MediaTypeImageIndex, + Subject: &subjectDesc, + ArtifactType: "test/index", + Annotations: map[string]string{"foo": "bar"}, } - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodHead { - t.Errorf("unexpected access: %s %s", r.Method, r.URL) - w.WriteHeader(http.StatusMethodNotAllowed) - return - } - switch r.URL.Path { - case "/v2/test/manifests/" + manifestDesc.Digest.String(): - if accept := r.Header.Get("Accept"); !strings.Contains(accept, manifestDesc.MediaType) { - t.Errorf("manifest not convertable: %s", accept) - w.WriteHeader(http.StatusBadRequest) - return - } - w.Header().Set("Content-Type", manifestDesc.MediaType) - w.Header().Set("Docker-Content-Digest", manifestDesc.Digest.String()) - w.Header().Set("Content-Length", strconv.Itoa(int(manifestDesc.Size))) - default: - w.WriteHeader(http.StatusNotFound) - } - })) - defer ts.Close() + indexManifestJSON, err := json.Marshal(indexManifest) + if err != nil { + t.Fatalf("failed to marshal manifest: %v", err) + } + indexManifestDesc := content.NewDescriptorFromBytes(indexManifest.MediaType, indexManifestJSON) + indexManifestDesc.ArtifactType = indexManifest.ArtifactType + indexManifestDesc.Annotations = indexManifest.Annotations + index_3 := ocispec.Index{ + Versioned: specs.Versioned{ + SchemaVersion: 2, // historical value. does not pertain to OCI or docker version + }, + MediaType: ocispec.MediaTypeImageIndex, + Manifests: []ocispec.Descriptor{ + artifactDesc, + manifestDesc, + indexManifestDesc, + }, + } + indexJSON_3, err := json.Marshal(index_3) + if err != nil { + t.Fatalf("failed to marshal manifest: %v", err) + } + indexDesc_3 := content.NewDescriptorFromBytes(index_3.MediaType, indexJSON_3) + indexDeleted = false + ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodPut && r.URL.Path == "/v2/test/manifests/"+indexManifestDesc.Digest.String(): + if contentType := r.Header.Get("Content-Type"); contentType != indexManifestDesc.MediaType { + w.WriteHeader(http.StatusBadRequest) + break + } + buf := bytes.NewBuffer(nil) + if _, err := buf.ReadFrom(r.Body); err != nil { + t.Errorf("fail to read: %v", err) + } + gotManifest = buf.Bytes() + w.Header().Set("Docker-Content-Digest", indexManifestDesc.Digest.String()) + w.WriteHeader(http.StatusCreated) + case r.Method == http.MethodGet && r.URL.Path == "/v2/test/referrers/"+zeroDigest: + w.WriteHeader(http.StatusNotFound) + case r.Method == http.MethodGet && r.URL.Path == "/v2/test/manifests/"+referrersTag: + w.Write(indexJSON_2) + case r.Method == http.MethodPut && r.URL.Path == "/v2/test/manifests/"+referrersTag: + if contentType := r.Header.Get("Content-Type"); contentType != ocispec.MediaTypeImageIndex { + w.WriteHeader(http.StatusBadRequest) + break + } + buf := bytes.NewBuffer(nil) + if _, err := buf.ReadFrom(r.Body); err != nil { + t.Errorf("fail to read: %v", err) + } + gotReferrerIndex = buf.Bytes() + w.Header().Set("Docker-Content-Digest", indexDesc_3.Digest.String()) + w.WriteHeader(http.StatusCreated) + case r.Method == http.MethodDelete && r.URL.Path == "/v2/test/manifests/"+indexDesc_2.Digest.String(): + indexDeleted = true + // no "Docker-Content-Digest" header for manifest deletion + w.WriteHeader(http.StatusAccepted) + default: + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + } + })) + defer ts.Close() + uri, err = url.Parse(ts.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + + ctx = context.Background() + repo, err = NewRepository(uri.Host + "/test") + if err != nil { + t.Fatalf("NewRepository() error = %v", err) + } + repo.PlainHTTP = true + if state := repo.loadReferrersState(); state != referrersStateUnknown { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnknown) + } + err = repo.Push(ctx, indexManifestDesc, bytes.NewReader(indexManifestJSON)) + if err != nil { + t.Fatalf("Manifests.Push() error = %v", err) + } + if !bytes.Equal(gotManifest, indexManifestJSON) { + t.Errorf("Manifests.Push() = %v, want %v", string(gotManifest), string(indexManifestJSON)) + } + if !bytes.Equal(gotReferrerIndex, indexJSON_3) { + t.Errorf("got referrers index = %v, want %v", string(gotReferrerIndex), string(indexJSON_3)) + } + if !indexDeleted { + t.Errorf("indexDeleted = %v, want %v", indexDeleted, true) + } + if state := repo.loadReferrersState(); state != referrersStateUnsupported { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnsupported) + } +} + +func Test_ManifestStore_Push_ReferrersAPIUnavailable_SkipReferrersGC(t *testing.T) { + // generate test content + subject := []byte(`{"layers":[]}`) + subjectDesc := content.NewDescriptorFromBytes(spec.MediaTypeArtifactManifest, subject) + referrersTag := strings.Replace(subjectDesc.Digest.String(), ":", "-", 1) + manifest := ocispec.Manifest{ + MediaType: ocispec.MediaTypeImageManifest, + Config: ocispec.Descriptor{ + MediaType: "testconfig", + }, + Subject: &subjectDesc, + Annotations: map[string]string{"foo": "bar"}, + } + manifestJSON, err := json.Marshal(manifest) + if err != nil { + t.Fatalf("failed to marshal manifest: %v", err) + } + manifestDesc := content.NewDescriptorFromBytes(manifest.MediaType, manifestJSON) + manifestDesc.ArtifactType = manifest.Config.MediaType + manifestDesc.Annotations = manifest.Annotations + index_1 := ocispec.Index{ + Versioned: specs.Versioned{ + SchemaVersion: 2, // historical value. does not pertain to OCI or docker version + }, + MediaType: ocispec.MediaTypeImageIndex, + Manifests: []ocispec.Descriptor{ + manifestDesc, + }, + } + + // test pushing image manifest with subject, a referrers list should be created + indexJSON_1, err := json.Marshal(index_1) + if err != nil { + t.Fatalf("failed to marshal manifest: %v", err) + } + indexDesc_1 := content.NewDescriptorFromBytes(index_1.MediaType, indexJSON_1) + var gotManifest []byte + var gotReferrerIndex []byte + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodPut && r.URL.Path == "/v2/test/manifests/"+manifestDesc.Digest.String(): + if contentType := r.Header.Get("Content-Type"); contentType != manifestDesc.MediaType { + w.WriteHeader(http.StatusBadRequest) + break + } + buf := bytes.NewBuffer(nil) + if _, err := buf.ReadFrom(r.Body); err != nil { + t.Errorf("fail to read: %v", err) + } + gotManifest = buf.Bytes() + w.Header().Set("Docker-Content-Digest", manifestDesc.Digest.String()) + w.WriteHeader(http.StatusCreated) + case r.Method == http.MethodGet && r.URL.Path == "/v2/test/referrers/"+zeroDigest: + w.WriteHeader(http.StatusNotFound) + case r.Method == http.MethodGet && r.URL.Path == "/v2/test/manifests/"+referrersTag: + w.WriteHeader(http.StatusNotFound) + case r.Method == http.MethodPut && r.URL.Path == "/v2/test/manifests/"+referrersTag: + if contentType := r.Header.Get("Content-Type"); contentType != ocispec.MediaTypeImageIndex { + w.WriteHeader(http.StatusBadRequest) + break + } + buf := bytes.NewBuffer(nil) + if _, err := buf.ReadFrom(r.Body); err != nil { + t.Errorf("fail to read: %v", err) + } + gotReferrerIndex = buf.Bytes() + w.Header().Set("Docker-Content-Digest", indexDesc_1.Digest.String()) + w.WriteHeader(http.StatusCreated) + default: + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + } + })) + defer ts.Close() + uri, err := url.Parse(ts.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + + ctx := context.Background() + repo, err := NewRepository(uri.Host + "/test") + if err != nil { + t.Fatalf("NewRepository() error = %v", err) + } + repo.PlainHTTP = true + repo.SkipReferrersGC = true + + if state := repo.loadReferrersState(); state != referrersStateUnknown { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnknown) + } + err = repo.Push(ctx, manifestDesc, bytes.NewReader(manifestJSON)) + if err != nil { + t.Fatalf("Manifests.Push() error = %v", err) + } + if !bytes.Equal(gotManifest, manifestJSON) { + t.Errorf("Manifests.Push() = %v, want %v", string(gotManifest), string(manifestJSON)) + } + if !bytes.Equal(gotReferrerIndex, indexJSON_1) { + t.Errorf("got referrers index = %v, want %v", string(gotReferrerIndex), string(indexJSON_1)) + } + if state := repo.loadReferrersState(); state != referrersStateUnsupported { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnsupported) + } + + // test pushing image manifest with subject when an old empty referrer list exists, + // the referrer list should be updated + emptyIndex := ocispec.Index{ + Versioned: specs.Versioned{ + SchemaVersion: 2, // historical value. does not pertain to OCI or docker version + }, + MediaType: ocispec.MediaTypeImageIndex, + } + emptyIndexJSON, err := json.Marshal(emptyIndex) + if err != nil { + t.Error("failed to marshal index", err) + } + ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodPut && r.URL.Path == "/v2/test/manifests/"+manifestDesc.Digest.String(): + if contentType := r.Header.Get("Content-Type"); contentType != manifestDesc.MediaType { + w.WriteHeader(http.StatusBadRequest) + break + } + buf := bytes.NewBuffer(nil) + if _, err := buf.ReadFrom(r.Body); err != nil { + t.Errorf("fail to read: %v", err) + } + gotManifest = buf.Bytes() + w.Header().Set("Docker-Content-Digest", manifestDesc.Digest.String()) + w.WriteHeader(http.StatusCreated) + case r.Method == http.MethodGet && r.URL.Path == "/v2/test/referrers/"+zeroDigest: + w.WriteHeader(http.StatusNotFound) + case r.Method == http.MethodGet && r.URL.Path == "/v2/test/manifests/"+referrersTag: + w.Write(emptyIndexJSON) + case r.Method == http.MethodPut && r.URL.Path == "/v2/test/manifests/"+referrersTag: + if contentType := r.Header.Get("Content-Type"); contentType != ocispec.MediaTypeImageIndex { + w.WriteHeader(http.StatusBadRequest) + break + } + buf := bytes.NewBuffer(nil) + if _, err := buf.ReadFrom(r.Body); err != nil { + t.Errorf("fail to read: %v", err) + } + gotReferrerIndex = buf.Bytes() + w.Header().Set("Docker-Content-Digest", indexDesc_1.Digest.String()) + w.WriteHeader(http.StatusCreated) + default: + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + } + })) + defer ts.Close() + uri, err = url.Parse(ts.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + + ctx = context.Background() + repo, err = NewRepository(uri.Host + "/test") + if err != nil { + t.Fatalf("NewRepository() error = %v", err) + } + repo.PlainHTTP = true + repo.SkipReferrersGC = true + + if state := repo.loadReferrersState(); state != referrersStateUnknown { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnknown) + } + err = repo.Push(ctx, manifestDesc, bytes.NewReader(manifestJSON)) + if err != nil { + t.Fatalf("Manifests.Push() error = %v", err) + } + if !bytes.Equal(gotManifest, manifestJSON) { + t.Errorf("Manifests.Push() = %v, want %v", string(gotManifest), string(manifestJSON)) + } + if !bytes.Equal(gotReferrerIndex, indexJSON_1) { + t.Errorf("got referrers index = %v, want %v", string(gotReferrerIndex), string(indexJSON_1)) + } + if state := repo.loadReferrersState(); state != referrersStateUnsupported { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnsupported) + } + + // push image index with subject, referrer list should be updated, the old + // one should not be deleted + indexManifest := ocispec.Index{ + MediaType: ocispec.MediaTypeImageIndex, + Subject: &subjectDesc, + ArtifactType: "test/index", + Annotations: map[string]string{"foo": "bar"}, + } + indexManifestJSON, err := json.Marshal(indexManifest) + if err != nil { + t.Fatalf("failed to marshal manifest: %v", err) + } + indexManifestDesc := content.NewDescriptorFromBytes(indexManifest.MediaType, indexManifestJSON) + indexManifestDesc.ArtifactType = indexManifest.ArtifactType + indexManifestDesc.Annotations = indexManifest.Annotations + index_2 := ocispec.Index{ + Versioned: specs.Versioned{ + SchemaVersion: 2, // historical value. does not pertain to OCI or docker version + }, + MediaType: ocispec.MediaTypeImageIndex, + Manifests: []ocispec.Descriptor{ + manifestDesc, + indexManifestDesc, + }, + } + indexJSON_2, err := json.Marshal(index_2) + if err != nil { + t.Fatalf("failed to marshal manifest: %v", err) + } + indexDesc_2 := content.NewDescriptorFromBytes(index_2.MediaType, indexJSON_2) + ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodPut && r.URL.Path == "/v2/test/manifests/"+indexManifestDesc.Digest.String(): + if contentType := r.Header.Get("Content-Type"); contentType != indexManifestDesc.MediaType { + w.WriteHeader(http.StatusBadRequest) + break + } + buf := bytes.NewBuffer(nil) + if _, err := buf.ReadFrom(r.Body); err != nil { + t.Errorf("fail to read: %v", err) + } + gotManifest = buf.Bytes() + w.Header().Set("Docker-Content-Digest", indexManifestDesc.Digest.String()) + w.WriteHeader(http.StatusCreated) + case r.Method == http.MethodGet && r.URL.Path == "/v2/test/referrers/"+zeroDigest: + w.WriteHeader(http.StatusNotFound) + case r.Method == http.MethodGet && r.URL.Path == "/v2/test/manifests/"+referrersTag: + w.Write(indexJSON_1) + case r.Method == http.MethodPut && r.URL.Path == "/v2/test/manifests/"+referrersTag: + if contentType := r.Header.Get("Content-Type"); contentType != ocispec.MediaTypeImageIndex { + w.WriteHeader(http.StatusBadRequest) + break + } + buf := bytes.NewBuffer(nil) + if _, err := buf.ReadFrom(r.Body); err != nil { + t.Errorf("fail to read: %v", err) + } + gotReferrerIndex = buf.Bytes() + w.Header().Set("Docker-Content-Digest", indexDesc_2.Digest.String()) + w.WriteHeader(http.StatusCreated) + default: + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + } + })) + defer ts.Close() + uri, err = url.Parse(ts.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + + ctx = context.Background() + repo, err = NewRepository(uri.Host + "/test") + if err != nil { + t.Fatalf("NewRepository() error = %v", err) + } + repo.PlainHTTP = true + repo.SkipReferrersGC = true + + if state := repo.loadReferrersState(); state != referrersStateUnknown { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnknown) + } + err = repo.Push(ctx, indexManifestDesc, bytes.NewReader(indexManifestJSON)) + if err != nil { + t.Fatalf("Manifests.Push() error = %v", err) + } + if !bytes.Equal(gotManifest, indexManifestJSON) { + t.Errorf("Manifests.Push() = %v, want %v", string(gotManifest), string(indexManifestJSON)) + } + if !bytes.Equal(gotReferrerIndex, indexJSON_2) { + t.Errorf("got referrers index = %v, want %v", string(gotReferrerIndex), string(indexJSON_2)) + } + if state := repo.loadReferrersState(); state != referrersStateUnsupported { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnsupported) + } +} + +func Test_ManifestStore_Exists(t *testing.T) { + manifest := []byte(`{"layers":[]}`) + manifestDesc := ocispec.Descriptor{ + MediaType: ocispec.MediaTypeImageManifest, + Digest: digest.FromBytes(manifest), + Size: int64(len(manifest)), + } + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodHead { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + switch r.URL.Path { + case "/v2/test/manifests/" + manifestDesc.Digest.String(): + if accept := r.Header.Get("Accept"); !strings.Contains(accept, manifestDesc.MediaType) { + t.Errorf("manifest not convertable: %s", accept) + w.WriteHeader(http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", manifestDesc.MediaType) + w.Header().Set("Docker-Content-Digest", manifestDesc.Digest.String()) + w.Header().Set("Content-Length", strconv.Itoa(int(manifestDesc.Size))) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer ts.Close() + uri, err := url.Parse(ts.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + + repo, err := NewRepository(uri.Host + "/test") + if err != nil { + t.Fatalf("NewRepository() error = %v", err) + } + repo.PlainHTTP = true + store := repo.Manifests() + ctx := context.Background() + + exists, err := store.Exists(ctx, manifestDesc) + if err != nil { + t.Fatalf("Manifests.Exists() error = %v", err) + } + if !exists { + t.Errorf("Manifests.Exists() = %v, want %v", exists, true) + } + + content := []byte(`{"manifests":[]}`) + contentDesc := ocispec.Descriptor{ + MediaType: ocispec.MediaTypeImageIndex, + Digest: digest.FromBytes(content), + Size: int64(len(content)), + } + exists, err = store.Exists(ctx, contentDesc) + if err != nil { + t.Fatalf("Manifests.Exists() error = %v", err) + } + if exists { + t.Errorf("Manifests.Exists() = %v, want %v", exists, false) + } +} + +func Test_ManifestStore_Delete(t *testing.T) { + manifest := []byte(`{"layers":[]}`) + manifestDesc := ocispec.Descriptor{ + MediaType: ocispec.MediaTypeImageManifest, + Digest: digest.FromBytes(manifest), + Size: int64(len(manifest)), + } + manifestDeleted := false + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete && r.Method != http.MethodGet { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusMethodNotAllowed) + } + switch { + case r.Method == http.MethodDelete && r.URL.Path == "/v2/test/manifests/"+manifestDesc.Digest.String(): + manifestDeleted = true + // no "Docker-Content-Digest" header for manifest deletion + w.WriteHeader(http.StatusAccepted) + case r.Method == http.MethodGet && r.URL.Path == "/v2/test/manifests/"+manifestDesc.Digest.String(): + if accept := r.Header.Get("Accept"); !strings.Contains(accept, manifestDesc.MediaType) { + t.Errorf("manifest not convertable: %s", accept) + w.WriteHeader(http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", manifestDesc.MediaType) + w.Header().Set("Docker-Content-Digest", manifestDesc.Digest.String()) + if _, err := w.Write(manifest); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer ts.Close() + uri, err := url.Parse(ts.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + + repo, err := NewRepository(uri.Host + "/test") + if err != nil { + t.Fatalf("NewRepository() error = %v", err) + } + repo.PlainHTTP = true + store := repo.Manifests() + ctx := context.Background() + + // test deleting manifest without subject + err = store.Delete(ctx, manifestDesc) + if err != nil { + t.Fatalf("Manifests.Delete() error = %v", err) + } + if !manifestDeleted { + t.Errorf("Manifests.Delete() = %v, want %v", manifestDeleted, true) + } + + // test deleting content that does not exist + content := []byte(`{"manifests":[]}`) + contentDesc := ocispec.Descriptor{ + MediaType: ocispec.MediaTypeImageIndex, + Digest: digest.FromBytes(content), + Size: int64(len(content)), + } + err = store.Delete(ctx, contentDesc) + if !errors.Is(err, errdef.ErrNotFound) { + t.Errorf("Manifests.Delete() error = %v, wantErr %v", err, errdef.ErrNotFound) + } +} + +func Test_ManifestStore_Delete_ReferrersAPIAvailable(t *testing.T) { + // generate test content + subject := []byte(`{"layers":[]}`) + subjectDesc := content.NewDescriptorFromBytes(spec.MediaTypeArtifactManifest, subject) + artifact := spec.Artifact{ + MediaType: spec.MediaTypeArtifactManifest, + Subject: &subjectDesc, + } + artifactJSON, err := json.Marshal(artifact) + if err != nil { + t.Fatalf("failed to marshal manifest: %v", err) + } + artifactDesc := content.NewDescriptorFromBytes(artifact.MediaType, artifactJSON) + + manifest := ocispec.Manifest{ + MediaType: ocispec.MediaTypeImageManifest, + Subject: &subjectDesc, + } + manifestJSON, err := json.Marshal(manifest) + if err != nil { + t.Fatalf("failed to marshal manifest: %v", err) + } + manifestDesc := content.NewDescriptorFromBytes(manifest.MediaType, manifestJSON) + + index := ocispec.Index{ + MediaType: ocispec.MediaTypeImageIndex, + Subject: &subjectDesc, + } + indexJSON, err := json.Marshal(index) + if err != nil { + t.Fatalf("failed to marshal manifest: %v", err) + } + indexDesc := content.NewDescriptorFromBytes(index.MediaType, indexJSON) + + var manifestDeleted bool + var artifactDeleted bool + var indexDeleted bool + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete && r.Method != http.MethodGet { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusMethodNotAllowed) + } + switch { + case r.Method == http.MethodDelete && r.URL.Path == "/v2/test/manifests/"+artifactDesc.Digest.String(): + artifactDeleted = true + // no "Docker-Content-Digest" header for manifest deletion + w.WriteHeader(http.StatusAccepted) + case r.Method == http.MethodDelete && r.URL.Path == "/v2/test/manifests/"+manifestDesc.Digest.String(): + manifestDeleted = true + // no "Docker-Content-Digest" header for manifest deletion + w.WriteHeader(http.StatusAccepted) + case r.Method == http.MethodDelete && r.URL.Path == "/v2/test/manifests/"+indexDesc.Digest.String(): + indexDeleted = true + // no "Docker-Content-Digest" header for manifest deletion + w.WriteHeader(http.StatusAccepted) + case r.Method == http.MethodGet && r.URL.Path == "/v2/test/manifests/"+artifactDesc.Digest.String(): + if accept := r.Header.Get("Accept"); !strings.Contains(accept, artifactDesc.MediaType) { + t.Errorf("manifest not convertable: %s", accept) + w.WriteHeader(http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", artifactDesc.MediaType) + w.Header().Set("Docker-Content-Digest", artifactDesc.Digest.String()) + if _, err := w.Write(artifactJSON); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + case r.Method == http.MethodGet && r.URL.Path == "/v2/test/referrers/"+zeroDigest: + result := ocispec.Index{ + Versioned: specs.Versioned{ + SchemaVersion: 2, // historical value. does not pertain to OCI or docker version + }, + MediaType: ocispec.MediaTypeImageIndex, + Manifests: []ocispec.Descriptor{}, + } + if err := json.NewEncoder(w).Encode(result); err != nil { + t.Errorf("failed to write response: %v", err) + } + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer ts.Close() uri, err := url.Parse(ts.URL) if err != nil { t.Fatalf("invalid test http server: %v", err) } - repo, err := NewRepository(uri.Host + "/test") if err != nil { t.Fatalf("NewRepository() error = %v", err) @@ -3519,60 +4522,181 @@ func Test_ManifestStore_Exists(t *testing.T) { repo.PlainHTTP = true store := repo.Manifests() ctx := context.Background() + // test deleting artifact with subject + if state := repo.loadReferrersState(); state != referrersStateUnknown { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnknown) + } + err = store.Delete(ctx, artifactDesc) + if err != nil { + t.Fatalf("Manifests.Delete() error = %v", err) + } + if !artifactDeleted { + t.Errorf("Manifests.Delete() = %v, want %v", artifactDeleted, true) + } - exists, err := store.Exists(ctx, manifestDesc) + // test deleting manifest with subject + if state := repo.loadReferrersState(); state != referrersStateSupported { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateSupported) + } + err = store.Delete(ctx, manifestDesc) if err != nil { - t.Fatalf("Manifests.Exists() error = %v", err) + t.Fatalf("Manifests.Delete() error = %v", err) } - if !exists { - t.Errorf("Manifests.Exists() = %v, want %v", exists, true) + if !manifestDeleted { + t.Errorf("Manifests.Delete() = %v, want %v", manifestDeleted, true) } - content := []byte(`{"manifests":[]}`) + // test deleting index with subject + if state := repo.loadReferrersState(); state != referrersStateSupported { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateSupported) + } + err = store.Delete(ctx, indexDesc) + if err != nil { + t.Fatalf("Manifests.Delete() error = %v", err) + } + if !indexDeleted { + t.Errorf("Manifests.Delete() = %v, want %v", indexDeleted, true) + } + + // test deleting content that does not exist + content := []byte("whatever") contentDesc := ocispec.Descriptor{ - MediaType: ocispec.MediaTypeImageIndex, + MediaType: ocispec.MediaTypeImageManifest, Digest: digest.FromBytes(content), Size: int64(len(content)), } - exists, err = store.Exists(ctx, contentDesc) - if err != nil { - t.Fatalf("Manifests.Exists() error = %v", err) - } - if exists { - t.Errorf("Manifests.Exists() = %v, want %v", exists, false) + ctx = context.Background() + err = store.Delete(ctx, contentDesc) + if !errors.Is(err, errdef.ErrNotFound) { + t.Errorf("Manifests.Delete() error = %v, wantErr %v", err, errdef.ErrNotFound) } } -func Test_ManifestStore_Delete(t *testing.T) { - manifest := []byte(`{"layers":[]}`) - manifestDesc := ocispec.Descriptor{ +func Test_ManifestStore_Delete_ReferrersAPIUnavailable(t *testing.T) { + // generate test content + subject := []byte(`{"layers":[]}`) + subjectDesc := content.NewDescriptorFromBytes(spec.MediaTypeArtifactManifest, subject) + referrersTag := strings.Replace(subjectDesc.Digest.String(), ":", "-", 1) + + artifact := spec.Artifact{ + MediaType: spec.MediaTypeArtifactManifest, + Subject: &subjectDesc, + } + artifactJSON, err := json.Marshal(artifact) + if err != nil { + t.Fatalf("failed to marshal manifest: %v", err) + } + artifactDesc := content.NewDescriptorFromBytes(artifact.MediaType, artifactJSON) + + manifest := ocispec.Manifest{ MediaType: ocispec.MediaTypeImageManifest, - Digest: digest.FromBytes(manifest), - Size: int64(len(manifest)), + Subject: &subjectDesc, + } + manifestJSON, err := json.Marshal(manifest) + if err != nil { + t.Fatalf("failed to marshal manifest: %v", err) + } + manifestDesc := content.NewDescriptorFromBytes(manifest.MediaType, manifestJSON) + + indexManifest := ocispec.Index{ + MediaType: ocispec.MediaTypeImageIndex, + Subject: &subjectDesc, + } + indexManifestJSON, err := json.Marshal(indexManifest) + if err != nil { + t.Fatalf("failed to marshal manifest: %v", err) + } + indexManifestDesc := content.NewDescriptorFromBytes(indexManifest.MediaType, indexManifestJSON) + + index_1 := ocispec.Index{ + Versioned: specs.Versioned{ + SchemaVersion: 2, // historical value. does not pertain to OCI or docker version + }, + MediaType: ocispec.MediaTypeImageIndex, + Manifests: []ocispec.Descriptor{ + artifactDesc, + manifestDesc, + indexManifestDesc, + }, + } + indexJSON_1, err := json.Marshal(index_1) + if err != nil { + t.Fatalf("failed to marshal manifest: %v", err) + } + indexDesc_1 := content.NewDescriptorFromBytes(index_1.MediaType, indexJSON_1) + index_2 := ocispec.Index{ + Versioned: specs.Versioned{ + SchemaVersion: 2, // historical value. does not pertain to OCI or docker version + }, + MediaType: ocispec.MediaTypeImageIndex, + Manifests: []ocispec.Descriptor{ + manifestDesc, + indexManifestDesc, + }, + } + indexJSON_2, err := json.Marshal(index_2) + if err != nil { + t.Fatalf("failed to marshal manifest: %v", err) + } + indexDesc_2 := content.NewDescriptorFromBytes(index_2.MediaType, indexJSON_2) + index_3 := ocispec.Index{ + Versioned: specs.Versioned{ + SchemaVersion: 2, // historical value. does not pertain to OCI or docker version + }, + MediaType: ocispec.MediaTypeImageIndex, + Manifests: []ocispec.Descriptor{ + indexManifestDesc, + }, } + indexJSON_3, err := json.Marshal(index_3) + if err != nil { + t.Fatalf("failed to marshal manifest: %v", err) + } + indexDesc_3 := content.NewDescriptorFromBytes(index_3.MediaType, indexJSON_3) + + // test deleting artifact with subject, referrers list should be updated manifestDeleted := false + indexDeleted := false + var gotReferrerIndex []byte ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodDelete && r.Method != http.MethodGet { - t.Errorf("unexpected access: %s %s", r.Method, r.URL) - w.WriteHeader(http.StatusMethodNotAllowed) - } switch { - case r.Method == http.MethodDelete && r.URL.Path == "/v2/test/manifests/"+manifestDesc.Digest.String(): + case r.Method == http.MethodDelete && r.URL.Path == "/v2/test/manifests/"+artifactDesc.Digest.String(): manifestDeleted = true // no "Docker-Content-Digest" header for manifest deletion w.WriteHeader(http.StatusAccepted) - case r.Method == http.MethodGet && r.URL.Path == "/v2/test/manifests/"+manifestDesc.Digest.String(): - if accept := r.Header.Get("Accept"); !strings.Contains(accept, manifestDesc.MediaType) { + case r.Method == http.MethodGet && r.URL.Path == "/v2/test/manifests/"+artifactDesc.Digest.String(): + if accept := r.Header.Get("Accept"); !strings.Contains(accept, artifactDesc.MediaType) { t.Errorf("manifest not convertable: %s", accept) w.WriteHeader(http.StatusBadRequest) return } - w.Header().Set("Content-Type", manifestDesc.MediaType) - w.Header().Set("Docker-Content-Digest", manifestDesc.Digest.String()) - if _, err := w.Write(manifest); err != nil { + w.Header().Set("Content-Type", artifactDesc.MediaType) + w.Header().Set("Docker-Content-Digest", artifactDesc.Digest.String()) + if _, err := w.Write(artifactJSON); err != nil { t.Errorf("failed to write %q: %v", r.URL, err) } + case r.Method == http.MethodGet && r.URL.Path == "/v2/test/referrers/"+zeroDigest: + w.WriteHeader(http.StatusNotFound) + case r.Method == http.MethodGet && r.URL.Path == "/v2/test/manifests/"+referrersTag: + w.Write(indexJSON_1) + case r.Method == http.MethodPut && r.URL.Path == "/v2/test/manifests/"+referrersTag: + if contentType := r.Header.Get("Content-Type"); contentType != ocispec.MediaTypeImageIndex { + w.WriteHeader(http.StatusBadRequest) + break + } + buf := bytes.NewBuffer(nil) + if _, err := buf.ReadFrom(r.Body); err != nil { + t.Errorf("fail to read: %v", err) + } + gotReferrerIndex = buf.Bytes() + w.Header().Set("Docker-Content-Digest", indexDesc_2.Digest.String()) + w.WriteHeader(http.StatusCreated) + case r.Method == http.MethodDelete && r.URL.Path == "/v2/test/manifests/"+indexDesc_1.Digest.String(): + indexDeleted = true + // no "Docker-Content-Digest" header for manifest deletion + w.WriteHeader(http.StatusAccepted) default: + t.Errorf("unexpected access: %s %s", r.Method, r.URL) w.WriteHeader(http.StatusNotFound) } })) @@ -3581,7 +4705,6 @@ func Test_ManifestStore_Delete(t *testing.T) { if err != nil { t.Fatalf("invalid test http server: %v", err) } - repo, err := NewRepository(uri.Host + "/test") if err != nil { t.Fatalf("NewRepository() error = %v", err) @@ -3590,215 +4713,249 @@ func Test_ManifestStore_Delete(t *testing.T) { store := repo.Manifests() ctx := context.Background() - // test delete manifest without subject - err = store.Delete(ctx, manifestDesc) + if state := repo.loadReferrersState(); state != referrersStateUnknown { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnknown) + } + err = store.Delete(ctx, artifactDesc) if err != nil { t.Fatalf("Manifests.Delete() error = %v", err) } if !manifestDeleted { t.Errorf("Manifests.Delete() = %v, want %v", manifestDeleted, true) } - - // test delete content that does not exist - content := []byte(`{"manifests":[]}`) - contentDesc := ocispec.Descriptor{ - MediaType: ocispec.MediaTypeImageIndex, - Digest: digest.FromBytes(content), - Size: int64(len(content)), + if !bytes.Equal(gotReferrerIndex, indexJSON_2) { + t.Errorf("got referrers index = %v, want %v", string(gotReferrerIndex), string(indexJSON_2)) } - err = store.Delete(ctx, contentDesc) - if !errors.Is(err, errdef.ErrNotFound) { - t.Errorf("Manifests.Delete() error = %v, wantErr %v", err, errdef.ErrNotFound) + if !indexDeleted { + t.Errorf("Manifests.Delete() = %v, want %v", manifestDeleted, true) + } + if state := repo.loadReferrersState(); state != referrersStateUnsupported { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnsupported) } -} -func Test_ManifestStore_Delete_ReferrersAPIAvailable(t *testing.T) { - // generate test content - subject := []byte(`{"layers":[]}`) - subjectDesc := content.NewDescriptorFromBytes(spec.MediaTypeArtifactManifest, subject) - artifact := spec.Artifact{ - MediaType: spec.MediaTypeArtifactManifest, - Subject: &subjectDesc, + // test deleting manifest with subject, referrers list should be updated + manifestDeleted = false + indexDeleted = false + ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodDelete && r.URL.Path == "/v2/test/manifests/"+manifestDesc.Digest.String(): + manifestDeleted = true + // no "Docker-Content-Digest" header for manifest deletion + w.WriteHeader(http.StatusAccepted) + case r.Method == http.MethodGet && r.URL.Path == "/v2/test/manifests/"+manifestDesc.Digest.String(): + if accept := r.Header.Get("Accept"); !strings.Contains(accept, manifestDesc.MediaType) { + t.Errorf("manifest not convertable: %s", accept) + w.WriteHeader(http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", manifestDesc.MediaType) + w.Header().Set("Docker-Content-Digest", manifestDesc.Digest.String()) + if _, err := w.Write(manifestJSON); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + case r.Method == http.MethodGet && r.URL.Path == "/v2/test/referrers/"+zeroDigest: + w.WriteHeader(http.StatusNotFound) + case r.Method == http.MethodGet && r.URL.Path == "/v2/test/manifests/"+referrersTag: + w.Write(indexJSON_2) + case r.Method == http.MethodPut && r.URL.Path == "/v2/test/manifests/"+referrersTag: + if contentType := r.Header.Get("Content-Type"); contentType != ocispec.MediaTypeImageIndex { + w.WriteHeader(http.StatusBadRequest) + break + } + buf := bytes.NewBuffer(nil) + if _, err := buf.ReadFrom(r.Body); err != nil { + t.Errorf("fail to read: %v", err) + } + gotReferrerIndex = buf.Bytes() + w.Header().Set("Docker-Content-Digest", indexDesc_3.Digest.String()) + w.WriteHeader(http.StatusCreated) + case r.Method == http.MethodDelete && r.URL.Path == "/v2/test/manifests/"+indexDesc_2.Digest.String(): + indexDeleted = true + // no "Docker-Content-Digest" header for manifest deletion + w.WriteHeader(http.StatusAccepted) + default: + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + } + })) + defer ts.Close() + uri, err = url.Parse(ts.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + repo, err = NewRepository(uri.Host + "/test") + if err != nil { + t.Fatalf("NewRepository() error = %v", err) + } + repo.PlainHTTP = true + store = repo.Manifests() + ctx = context.Background() + if state := repo.loadReferrersState(); state != referrersStateUnknown { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnknown) } - artifactJSON, err := json.Marshal(artifact) + err = store.Delete(ctx, manifestDesc) if err != nil { - t.Errorf("failed to marshal manifest: %v", err) + t.Fatalf("Manifests.Delete() error = %v", err) } - artifactDesc := content.NewDescriptorFromBytes(artifact.MediaType, artifactJSON) - manifest := ocispec.Manifest{ - MediaType: ocispec.MediaTypeImageManifest, - Subject: &subjectDesc, + if !manifestDeleted { + t.Errorf("Manifests.Delete() = %v, want %v", manifestDeleted, true) } - manifestJSON, err := json.Marshal(manifest) - if err != nil { - t.Errorf("failed to marshal manifest: %v", err) + if !indexDeleted { + t.Errorf("Manifests.Delete() = %v, want %v", manifestDeleted, true) } - manifestDesc := content.NewDescriptorFromBytes(manifest.MediaType, manifestJSON) - manifestDeleted := false - artifactDeleted := false - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodDelete && r.Method != http.MethodGet { - t.Errorf("unexpected access: %s %s", r.Method, r.URL) - w.WriteHeader(http.StatusMethodNotAllowed) - } + if state := repo.loadReferrersState(); state != referrersStateUnsupported { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnsupported) + } + + // test deleting index with a subject, referrers list should be updated + manifestDeleted = false + indexDeleted = false + ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { - case r.Method == http.MethodDelete && r.URL.Path == "/v2/test/manifests/"+artifactDesc.Digest.String(): - artifactDeleted = true - // no "Docker-Content-Digest" header for manifest deletion - w.WriteHeader(http.StatusAccepted) - case r.Method == http.MethodDelete && r.URL.Path == "/v2/test/manifests/"+manifestDesc.Digest.String(): + case r.Method == http.MethodDelete && r.URL.Path == "/v2/test/manifests/"+indexManifestDesc.Digest.String(): manifestDeleted = true // no "Docker-Content-Digest" header for manifest deletion w.WriteHeader(http.StatusAccepted) - case r.Method == http.MethodGet && r.URL.Path == "/v2/test/manifests/"+artifactDesc.Digest.String(): - if accept := r.Header.Get("Accept"); !strings.Contains(accept, artifactDesc.MediaType) { + case r.Method == http.MethodGet && r.URL.Path == "/v2/test/manifests/"+indexManifestDesc.Digest.String(): + if accept := r.Header.Get("Accept"); !strings.Contains(accept, indexManifestDesc.MediaType) { t.Errorf("manifest not convertable: %s", accept) w.WriteHeader(http.StatusBadRequest) return } - w.Header().Set("Content-Type", artifactDesc.MediaType) - w.Header().Set("Docker-Content-Digest", artifactDesc.Digest.String()) - if _, err := w.Write(artifactJSON); err != nil { + w.Header().Set("Content-Type", indexManifestDesc.MediaType) + w.Header().Set("Docker-Content-Digest", indexManifestDesc.Digest.String()) + if _, err := w.Write(indexManifestJSON); err != nil { t.Errorf("failed to write %q: %v", r.URL, err) } case r.Method == http.MethodGet && r.URL.Path == "/v2/test/referrers/"+zeroDigest: - result := ocispec.Index{ - Versioned: specs.Versioned{ - SchemaVersion: 2, // historical value. does not pertain to OCI or docker version - }, - MediaType: ocispec.MediaTypeImageIndex, - Manifests: []ocispec.Descriptor{}, - } - if err := json.NewEncoder(w).Encode(result); err != nil { - t.Errorf("failed to write response: %v", err) - } + w.WriteHeader(http.StatusNotFound) + case r.Method == http.MethodGet && r.URL.Path == "/v2/test/manifests/"+referrersTag: + w.Write(indexJSON_3) + case r.Method == http.MethodDelete && r.URL.Path == "/v2/test/manifests/"+indexDesc_3.Digest.String(): + indexDeleted = true + // no "Docker-Content-Digest" header for manifest deletion + w.WriteHeader(http.StatusAccepted) default: + t.Errorf("unexpected access: %s %s", r.Method, r.URL) w.WriteHeader(http.StatusNotFound) } })) defer ts.Close() - uri, err := url.Parse(ts.URL) + uri, err = url.Parse(ts.URL) if err != nil { t.Fatalf("invalid test http server: %v", err) } - repo, err := NewRepository(uri.Host + "/test") + repo, err = NewRepository(uri.Host + "/test") if err != nil { t.Fatalf("NewRepository() error = %v", err) } repo.PlainHTTP = true - store := repo.Manifests() - ctx := context.Background() - // test delete artifact with subject + store = repo.Manifests() + ctx = context.Background() if state := repo.loadReferrersState(); state != referrersStateUnknown { t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnknown) } - err = store.Delete(ctx, artifactDesc) - if err != nil { - t.Fatalf("Manifests.Delete() error = %v", err) - } - if !artifactDeleted { - t.Errorf("Manifests.Delete() = %v, want %v", artifactDeleted, true) - } - - // test delete manifest with subject - if state := repo.loadReferrersState(); state != referrersStateSupported { - t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateSupported) - } - err = store.Delete(ctx, manifestDesc) + err = store.Delete(ctx, indexManifestDesc) if err != nil { t.Fatalf("Manifests.Delete() error = %v", err) } if !manifestDeleted { t.Errorf("Manifests.Delete() = %v, want %v", manifestDeleted, true) } - - // test delete content that does not exist - content := []byte("whatever") - contentDesc := ocispec.Descriptor{ - MediaType: ocispec.MediaTypeImageManifest, - Digest: digest.FromBytes(content), - Size: int64(len(content)), + if !indexDeleted { + t.Errorf("Manifests.Delete() = %v, want %v", manifestDeleted, true) } - ctx = context.Background() - err = store.Delete(ctx, contentDesc) - if !errors.Is(err, errdef.ErrNotFound) { - t.Errorf("Manifests.Delete() error = %v, wantErr %v", err, errdef.ErrNotFound) + if state := repo.loadReferrersState(); state != referrersStateUnsupported { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnsupported) } } -func Test_ManifestStore_Delete_ReferrersAPIUnavailable(t *testing.T) { +func Test_ManifestStore_Delete_ReferrersAPIUnavailable_SkipReferrersGC(t *testing.T) { // generate test content subject := []byte(`{"layers":[]}`) subjectDesc := content.NewDescriptorFromBytes(spec.MediaTypeArtifactManifest, subject) referrersTag := strings.Replace(subjectDesc.Digest.String(), ":", "-", 1) - artifact := spec.Artifact{ - MediaType: spec.MediaTypeArtifactManifest, - Subject: &subjectDesc, - } - artifactJSON, err := json.Marshal(artifact) - if err != nil { - t.Errorf("failed to marshal manifest: %v", err) - } - artifactDesc := content.NewDescriptorFromBytes(artifact.MediaType, artifactJSON) + manifest := ocispec.Manifest{ MediaType: ocispec.MediaTypeImageManifest, Subject: &subjectDesc, } manifestJSON, err := json.Marshal(manifest) if err != nil { - t.Errorf("failed to marshal manifest: %v", err) + t.Fatalf("failed to marshal manifest: %v", err) } manifestDesc := content.NewDescriptorFromBytes(manifest.MediaType, manifestJSON) - // test delete artifact with subject + indexManifest := ocispec.Index{ + MediaType: ocispec.MediaTypeImageIndex, + Subject: &subjectDesc, + } + indexManifestJSON, err := json.Marshal(indexManifest) + if err != nil { + t.Fatalf("failed to marshal manifest: %v", err) + } + indexManifestDesc := content.NewDescriptorFromBytes(indexManifest.MediaType, indexManifestJSON) + index_1 := ocispec.Index{ Versioned: specs.Versioned{ SchemaVersion: 2, // historical value. does not pertain to OCI or docker version }, MediaType: ocispec.MediaTypeImageIndex, Manifests: []ocispec.Descriptor{ - artifactDesc, manifestDesc, + indexManifestDesc, }, } indexJSON_1, err := json.Marshal(index_1) if err != nil { - t.Errorf("failed to marshal manifest: %v", err) + t.Fatalf("failed to marshal manifest: %v", err) } - indexDesc_1 := content.NewDescriptorFromBytes(index_1.MediaType, indexJSON_1) index_2 := ocispec.Index{ Versioned: specs.Versioned{ SchemaVersion: 2, // historical value. does not pertain to OCI or docker version }, MediaType: ocispec.MediaTypeImageIndex, Manifests: []ocispec.Descriptor{ - manifestDesc, + indexManifestDesc, }, } indexJSON_2, err := json.Marshal(index_2) if err != nil { - t.Errorf("failed to marshal manifest: %v", err) + t.Fatalf("failed to marshal manifest: %v", err) } indexDesc_2 := content.NewDescriptorFromBytes(index_2.MediaType, indexJSON_2) + index_3 := ocispec.Index{ + Versioned: specs.Versioned{ + SchemaVersion: 2, // historical value. does not pertain to OCI or docker version + }, + MediaType: ocispec.MediaTypeImageIndex, + Manifests: []ocispec.Descriptor{}, + } + indexJSON_3, err := json.Marshal(index_3) + if err != nil { + t.Fatalf("failed to marshal manifest: %v", err) + } + indexDesc_3 := content.NewDescriptorFromBytes(index_3.MediaType, indexJSON_3) + // test deleting image manifest with subject, referrers list should be updated, + // the old one should not be deleted manifestDeleted := false - indexDeleted := false var gotReferrerIndex []byte ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { - case r.Method == http.MethodDelete && r.URL.Path == "/v2/test/manifests/"+artifactDesc.Digest.String(): + case r.Method == http.MethodDelete && r.URL.Path == "/v2/test/manifests/"+manifestDesc.Digest.String(): manifestDeleted = true // no "Docker-Content-Digest" header for manifest deletion w.WriteHeader(http.StatusAccepted) - case r.Method == http.MethodGet && r.URL.Path == "/v2/test/manifests/"+artifactDesc.Digest.String(): - if accept := r.Header.Get("Accept"); !strings.Contains(accept, artifactDesc.MediaType) { + case r.Method == http.MethodGet && r.URL.Path == "/v2/test/manifests/"+manifestDesc.Digest.String(): + if accept := r.Header.Get("Accept"); !strings.Contains(accept, manifestDesc.MediaType) { t.Errorf("manifest not convertable: %s", accept) w.WriteHeader(http.StatusBadRequest) return } - w.Header().Set("Content-Type", artifactDesc.MediaType) - w.Header().Set("Docker-Content-Digest", artifactDesc.Digest.String()) - if _, err := w.Write(artifactJSON); err != nil { + w.Header().Set("Content-Type", manifestDesc.MediaType) + w.Header().Set("Docker-Content-Digest", manifestDesc.Digest.String()) + if _, err := w.Write(manifestJSON); err != nil { t.Errorf("failed to write %q: %v", r.URL, err) } case r.Method == http.MethodGet && r.URL.Path == "/v2/test/referrers/"+zeroDigest: @@ -3817,10 +4974,6 @@ func Test_ManifestStore_Delete_ReferrersAPIUnavailable(t *testing.T) { gotReferrerIndex = buf.Bytes() w.Header().Set("Docker-Content-Digest", indexDesc_2.Digest.String()) w.WriteHeader(http.StatusCreated) - case r.Method == http.MethodDelete && r.URL.Path == "/v2/test/manifests/"+indexDesc_1.Digest.String(): - indexDeleted = true - // no "Docker-Content-Digest" header for manifest deletion - w.WriteHeader(http.StatusAccepted) default: t.Errorf("unexpected access: %s %s", r.Method, r.URL) w.WriteHeader(http.StatusNotFound) @@ -3836,14 +4989,14 @@ func Test_ManifestStore_Delete_ReferrersAPIUnavailable(t *testing.T) { t.Fatalf("NewRepository() error = %v", err) } repo.PlainHTTP = true + repo.SkipReferrersGC = true store := repo.Manifests() ctx := context.Background() - // test delete artifact with subject if state := repo.loadReferrersState(); state != referrersStateUnknown { t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnknown) } - err = store.Delete(ctx, artifactDesc) + err = store.Delete(ctx, manifestDesc) if err != nil { t.Fatalf("Manifests.Delete() error = %v", err) } @@ -3853,41 +5006,46 @@ func Test_ManifestStore_Delete_ReferrersAPIUnavailable(t *testing.T) { if !bytes.Equal(gotReferrerIndex, indexJSON_2) { t.Errorf("got referrers index = %v, want %v", string(gotReferrerIndex), string(indexJSON_2)) } - if !indexDeleted { - t.Errorf("Manifests.Delete() = %v, want %v", manifestDeleted, true) - } if state := repo.loadReferrersState(); state != referrersStateUnsupported { t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnsupported) } - // test delete manifest with subject + // test deleting index with a subject, referrers list should be updated, + // the old one should not be deleted, an empty one should be pushed manifestDeleted = false - indexDeleted = false ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { - case r.Method == http.MethodDelete && r.URL.Path == "/v2/test/manifests/"+manifestDesc.Digest.String(): + case r.Method == http.MethodDelete && r.URL.Path == "/v2/test/manifests/"+indexManifestDesc.Digest.String(): manifestDeleted = true // no "Docker-Content-Digest" header for manifest deletion w.WriteHeader(http.StatusAccepted) - case r.Method == http.MethodGet && r.URL.Path == "/v2/test/manifests/"+manifestDesc.Digest.String(): - if accept := r.Header.Get("Accept"); !strings.Contains(accept, manifestDesc.MediaType) { + case r.Method == http.MethodGet && r.URL.Path == "/v2/test/manifests/"+indexManifestDesc.Digest.String(): + if accept := r.Header.Get("Accept"); !strings.Contains(accept, indexManifestDesc.MediaType) { t.Errorf("manifest not convertable: %s", accept) w.WriteHeader(http.StatusBadRequest) return } - w.Header().Set("Content-Type", manifestDesc.MediaType) - w.Header().Set("Docker-Content-Digest", manifestDesc.Digest.String()) - if _, err := w.Write(manifestJSON); err != nil { + w.Header().Set("Content-Type", indexManifestDesc.MediaType) + w.Header().Set("Docker-Content-Digest", indexManifestDesc.Digest.String()) + if _, err := w.Write(indexManifestJSON); err != nil { t.Errorf("failed to write %q: %v", r.URL, err) } case r.Method == http.MethodGet && r.URL.Path == "/v2/test/referrers/"+zeroDigest: w.WriteHeader(http.StatusNotFound) case r.Method == http.MethodGet && r.URL.Path == "/v2/test/manifests/"+referrersTag: w.Write(indexJSON_2) - case r.Method == http.MethodDelete && r.URL.Path == "/v2/test/manifests/"+indexDesc_2.Digest.String(): - indexDeleted = true - // no "Docker-Content-Digest" header for manifest deletion - w.WriteHeader(http.StatusAccepted) + case r.Method == http.MethodPut && r.URL.Path == "/v2/test/manifests/"+referrersTag: + if contentType := r.Header.Get("Content-Type"); contentType != ocispec.MediaTypeImageIndex { + w.WriteHeader(http.StatusBadRequest) + break + } + buf := bytes.NewBuffer(nil) + if _, err := buf.ReadFrom(r.Body); err != nil { + t.Errorf("fail to read: %v", err) + } + gotReferrerIndex = buf.Bytes() + w.Header().Set("Docker-Content-Digest", indexDesc_3.Digest.String()) + w.WriteHeader(http.StatusCreated) default: t.Errorf("unexpected access: %s %s", r.Method, r.URL) w.WriteHeader(http.StatusNotFound) @@ -3903,20 +5061,22 @@ func Test_ManifestStore_Delete_ReferrersAPIUnavailable(t *testing.T) { t.Fatalf("NewRepository() error = %v", err) } repo.PlainHTTP = true + repo.SkipReferrersGC = true store = repo.Manifests() ctx = context.Background() + if state := repo.loadReferrersState(); state != referrersStateUnknown { t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnknown) } - err = store.Delete(ctx, manifestDesc) + err = store.Delete(ctx, indexManifestDesc) if err != nil { t.Fatalf("Manifests.Delete() error = %v", err) } if !manifestDeleted { t.Errorf("Manifests.Delete() = %v, want %v", manifestDeleted, true) } - if !indexDeleted { - t.Errorf("Manifests.Delete() = %v, want %v", manifestDeleted, true) + if !bytes.Equal(gotReferrerIndex, indexJSON_3) { + t.Errorf("got referrers index = %v, want %v", string(gotReferrerIndex), string(indexJSON_3)) } if state := repo.loadReferrersState(); state != referrersStateUnsupported { t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnsupported) @@ -3934,7 +5094,7 @@ func Test_ManifestStore_Delete_ReferrersAPIUnavailable_InconsistentIndex(t *test } artifactJSON, err := json.Marshal(artifact) if err != nil { - t.Errorf("failed to marshal manifest: %v", err) + t.Fatalf("failed to marshal manifest: %v", err) } artifactDesc := content.NewDescriptorFromBytes(artifact.MediaType, artifactJSON) @@ -4457,13 +5617,130 @@ func Test_ManifestStore_PushReference(t *testing.T) { if _, err := buf.ReadFrom(r.Body); err != nil { t.Errorf("fail to read: %v", err) } - gotIndex = buf.Bytes() + gotIndex = buf.Bytes() + w.Header().Set("Docker-Content-Digest", indexDesc.Digest.String()) + w.WriteHeader(http.StatusCreated) + return + default: + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusForbidden) + } + })) + defer ts.Close() + uri, err := url.Parse(ts.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + + repo, err := NewRepository(uri.Host + "/test") + if err != nil { + t.Fatalf("NewRepository() error = %v", err) + } + store := repo.Manifests() + repo.PlainHTTP = true + ctx := context.Background() + err = store.PushReference(ctx, indexDesc, bytes.NewReader(index), ref) + if err != nil { + t.Fatalf("Repository.PushReference() error = %v", err) + } + if !bytes.Equal(gotIndex, index) { + t.Errorf("Repository.PushReference() = %v, want %v", gotIndex, index) + } +} + +func Test_ManifestStore_PushReference_ReferrersAPIAvailable(t *testing.T) { + // generate test content + subject := []byte(`{"layers":[]}`) + subjectDesc := content.NewDescriptorFromBytes(spec.MediaTypeArtifactManifest, subject) + artifact := spec.Artifact{ + MediaType: spec.MediaTypeArtifactManifest, + Subject: &subjectDesc, + } + artifactJSON, err := json.Marshal(artifact) + if err != nil { + t.Fatalf("failed to marshal manifest: %v", err) + } + artifactDesc := content.NewDescriptorFromBytes(artifact.MediaType, artifactJSON) + artifactRef := "foo" + + manifest := ocispec.Manifest{ + MediaType: ocispec.MediaTypeImageManifest, + Subject: &subjectDesc, + } + manifestJSON, err := json.Marshal(manifest) + if err != nil { + t.Fatalf("failed to marshal manifest: %v", err) + } + manifestDesc := content.NewDescriptorFromBytes(manifest.MediaType, manifestJSON) + manifestRef := "bar" + + index := ocispec.Index{ + MediaType: ocispec.MediaTypeImageIndex, + Subject: &subjectDesc, + } + indexJSON, err := json.Marshal(index) + if err != nil { + t.Fatalf("failed to marshal manifest: %v", err) + } + indexDesc := content.NewDescriptorFromBytes(index.MediaType, indexJSON) + indexRef := "baz" + + var gotManifest []byte + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodPut && r.URL.Path == "/v2/test/manifests/"+artifactRef: + if contentType := r.Header.Get("Content-Type"); contentType != artifactDesc.MediaType { + w.WriteHeader(http.StatusBadRequest) + break + } + buf := bytes.NewBuffer(nil) + if _, err := buf.ReadFrom(r.Body); err != nil { + t.Errorf("fail to read: %v", err) + } + gotManifest = buf.Bytes() + w.Header().Set("Docker-Content-Digest", artifactDesc.Digest.String()) + w.Header().Set("OCI-Subject", subjectDesc.Digest.String()) + w.WriteHeader(http.StatusCreated) + case r.Method == http.MethodPut && r.URL.Path == "/v2/test/manifests/"+manifestRef: + if contentType := r.Header.Get("Content-Type"); contentType != manifestDesc.MediaType { + w.WriteHeader(http.StatusBadRequest) + break + } + buf := bytes.NewBuffer(nil) + if _, err := buf.ReadFrom(r.Body); err != nil { + t.Errorf("fail to read: %v", err) + } + gotManifest = buf.Bytes() + w.Header().Set("Docker-Content-Digest", manifestDesc.Digest.String()) + w.Header().Set("OCI-Subject", subjectDesc.Digest.String()) + w.WriteHeader(http.StatusCreated) + case r.Method == http.MethodPut && r.URL.Path == "/v2/test/manifests/"+indexRef: + if contentType := r.Header.Get("Content-Type"); contentType != indexDesc.MediaType { + w.WriteHeader(http.StatusBadRequest) + break + } + buf := bytes.NewBuffer(nil) + if _, err := buf.ReadFrom(r.Body); err != nil { + t.Errorf("fail to read: %v", err) + } + gotManifest = buf.Bytes() w.Header().Set("Docker-Content-Digest", indexDesc.Digest.String()) + w.Header().Set("OCI-Subject", subjectDesc.Digest.String()) w.WriteHeader(http.StatusCreated) - return + case r.Method == http.MethodGet && r.URL.Path == "/v2/test/referrers/"+zeroDigest: + result := ocispec.Index{ + Versioned: specs.Versioned{ + SchemaVersion: 2, // historical value. does not pertain to OCI or docker version + }, + MediaType: ocispec.MediaTypeImageIndex, + Manifests: []ocispec.Descriptor{}, + } + if err := json.NewEncoder(w).Encode(result); err != nil { + t.Errorf("failed to write response: %v", err) + } default: t.Errorf("unexpected access: %s %s", r.Method, r.URL) - w.WriteHeader(http.StatusForbidden) + w.WriteHeader(http.StatusNotFound) } })) defer ts.Close() @@ -4471,24 +5748,62 @@ func Test_ManifestStore_PushReference(t *testing.T) { if err != nil { t.Fatalf("invalid test http server: %v", err) } + ctx := context.Background() + // test pushing artifact with subject repo, err := NewRepository(uri.Host + "/test") if err != nil { t.Fatalf("NewRepository() error = %v", err) } - store := repo.Manifests() repo.PlainHTTP = true - ctx := context.Background() - err = store.PushReference(ctx, indexDesc, bytes.NewReader(index), ref) + if state := repo.loadReferrersState(); state != referrersStateUnknown { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnknown) + } + err = repo.PushReference(ctx, artifactDesc, bytes.NewReader(artifactJSON), artifactRef) if err != nil { - t.Fatalf("Repository.PushReference() error = %v", err) + t.Fatalf("Manifests.Push() error = %v", err) } - if !bytes.Equal(gotIndex, index) { - t.Errorf("Repository.PushReference() = %v, want %v", gotIndex, index) + if !bytes.Equal(gotManifest, artifactJSON) { + t.Errorf("Manifests.Push() = %v, want %v", string(gotManifest), string(artifactJSON)) + } + if state := repo.loadReferrersState(); state != referrersStateSupported { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateSupported) + } + + // test pushing image manifest with subject + repo, err = NewRepository(uri.Host + "/test") + if err != nil { + t.Fatalf("NewRepository() error = %v", err) + } + repo.PlainHTTP = true + if state := repo.loadReferrersState(); state != referrersStateUnknown { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnknown) + } + err = repo.PushReference(ctx, manifestDesc, bytes.NewReader(manifestJSON), manifestRef) + if err != nil { + t.Fatalf("Manifests.Push() error = %v", err) + } + if !bytes.Equal(gotManifest, manifestJSON) { + t.Errorf("Manifests.Push() = %v, want %v", string(gotManifest), string(manifestJSON)) + } + if state := repo.loadReferrersState(); state != referrersStateSupported { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateSupported) + } + + // test pushing image index with subject + err = repo.PushReference(ctx, indexDesc, bytes.NewReader(indexJSON), indexRef) + if err != nil { + t.Fatalf("Manifests.Push() error = %v", err) + } + if !bytes.Equal(gotManifest, indexJSON) { + t.Errorf("Manifests.Push() = %v, want %v", string(gotManifest), string(indexJSON)) + } + if state := repo.loadReferrersState(); state != referrersStateSupported { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateSupported) } } -func Test_ManifestStore_PushReference_ReferrersAPIAvailable(t *testing.T) { +func Test_ManifestStore_PushReference_ReferrersAPIAvailable_NoSubjectHeader(t *testing.T) { // generate test content subject := []byte(`{"layers":[]}`) subjectDesc := content.NewDescriptorFromBytes(spec.MediaTypeArtifactManifest, subject) @@ -4498,7 +5813,7 @@ func Test_ManifestStore_PushReference_ReferrersAPIAvailable(t *testing.T) { } artifactJSON, err := json.Marshal(artifact) if err != nil { - t.Errorf("failed to marshal manifest: %v", err) + t.Fatalf("failed to marshal manifest: %v", err) } artifactDesc := content.NewDescriptorFromBytes(artifact.MediaType, artifactJSON) artifactRef := "foo" @@ -4509,11 +5824,22 @@ func Test_ManifestStore_PushReference_ReferrersAPIAvailable(t *testing.T) { } manifestJSON, err := json.Marshal(manifest) if err != nil { - t.Errorf("failed to marshal manifest: %v", err) + t.Fatalf("failed to marshal manifest: %v", err) } manifestDesc := content.NewDescriptorFromBytes(manifest.MediaType, manifestJSON) manifestRef := "bar" + index := ocispec.Index{ + MediaType: ocispec.MediaTypeImageIndex, + Subject: &subjectDesc, + } + indexJSON, err := json.Marshal(index) + if err != nil { + t.Fatalf("failed to marshal manifest: %v", err) + } + indexDesc := content.NewDescriptorFromBytes(index.MediaType, indexJSON) + indexRef := "baz" + var gotManifest []byte ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { @@ -4541,6 +5867,18 @@ func Test_ManifestStore_PushReference_ReferrersAPIAvailable(t *testing.T) { gotManifest = buf.Bytes() w.Header().Set("Docker-Content-Digest", manifestDesc.Digest.String()) w.WriteHeader(http.StatusCreated) + case r.Method == http.MethodPut && r.URL.Path == "/v2/test/manifests/"+indexRef: + if contentType := r.Header.Get("Content-Type"); contentType != indexDesc.MediaType { + w.WriteHeader(http.StatusBadRequest) + break + } + buf := bytes.NewBuffer(nil) + if _, err := buf.ReadFrom(r.Body); err != nil { + t.Errorf("fail to read: %v", err) + } + gotManifest = buf.Bytes() + w.Header().Set("Docker-Content-Digest", indexDesc.Digest.String()) + w.WriteHeader(http.StatusCreated) case r.Method == http.MethodGet && r.URL.Path == "/v2/test/referrers/"+zeroDigest: result := ocispec.Index{ Versioned: specs.Versioned{ @@ -4562,15 +5900,14 @@ func Test_ManifestStore_PushReference_ReferrersAPIAvailable(t *testing.T) { if err != nil { t.Fatalf("invalid test http server: %v", err) } - ctx := context.Background() + + // test pushing artifact with subject repo, err := NewRepository(uri.Host + "/test") if err != nil { t.Fatalf("NewRepository() error = %v", err) } repo.PlainHTTP = true - - // test push artifact with subject if state := repo.loadReferrersState(); state != referrersStateUnknown { t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnknown) } @@ -4581,11 +5918,19 @@ func Test_ManifestStore_PushReference_ReferrersAPIAvailable(t *testing.T) { if !bytes.Equal(gotManifest, artifactJSON) { t.Errorf("Manifests.Push() = %v, want %v", string(gotManifest), string(artifactJSON)) } - - // test push image manifest with subject if state := repo.loadReferrersState(); state != referrersStateSupported { t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateSupported) } + + // test pushing image manifest with subject + repo, err = NewRepository(uri.Host + "/test") + if err != nil { + t.Fatalf("NewRepository() error = %v", err) + } + repo.PlainHTTP = true + if state := repo.loadReferrersState(); state != referrersStateUnknown { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnknown) + } err = repo.PushReference(ctx, manifestDesc, bytes.NewReader(manifestJSON), manifestRef) if err != nil { t.Fatalf("Manifests.Push() error = %v", err) @@ -4593,6 +5938,21 @@ func Test_ManifestStore_PushReference_ReferrersAPIAvailable(t *testing.T) { if !bytes.Equal(gotManifest, manifestJSON) { t.Errorf("Manifests.Push() = %v, want %v", string(gotManifest), string(manifestJSON)) } + if state := repo.loadReferrersState(); state != referrersStateSupported { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateSupported) + } + + // test pushing image index with subject + err = repo.PushReference(ctx, indexDesc, bytes.NewReader(indexJSON), indexRef) + if err != nil { + t.Fatalf("Manifests.Push() error = %v", err) + } + if !bytes.Equal(gotManifest, indexJSON) { + t.Errorf("Manifests.Push() = %v, want %v", string(gotManifest), string(indexJSON)) + } + if state := repo.loadReferrersState(); state != referrersStateSupported { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateSupported) + } } func Test_ManifestStore_PushReference_ReferrersAPIUnavailable(t *testing.T) { @@ -4608,14 +5968,14 @@ func Test_ManifestStore_PushReference_ReferrersAPIUnavailable(t *testing.T) { } artifactJSON, err := json.Marshal(artifact) if err != nil { - t.Errorf("failed to marshal manifest: %v", err) + t.Fatalf("failed to marshal manifest: %v", err) } artifactDesc := content.NewDescriptorFromBytes(artifact.MediaType, artifactJSON) artifactDesc.ArtifactType = artifact.ArtifactType artifactDesc.Annotations = artifact.Annotations artifactRef := "foo" - // test push artifact with subject + // test pushing artifact with subject index_1 := ocispec.Index{ Versioned: specs.Versioned{ SchemaVersion: 2, // historical value. does not pertain to OCI or docker version @@ -4627,7 +5987,7 @@ func Test_ManifestStore_PushReference_ReferrersAPIUnavailable(t *testing.T) { } indexJSON_1, err := json.Marshal(index_1) if err != nil { - t.Errorf("failed to marshal manifest: %v", err) + t.Fatalf("failed to marshal manifest: %v", err) } indexDesc_1 := content.NewDescriptorFromBytes(index_1.MediaType, indexJSON_1) var gotManifest []byte @@ -4697,7 +6057,7 @@ func Test_ManifestStore_PushReference_ReferrersAPIUnavailable(t *testing.T) { t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnsupported) } - // test push image manifest with subject, referrers list should be updated + // test pushing image manifest with subject, referrers list should be updated manifest := ocispec.Manifest{ MediaType: ocispec.MediaTypeImageManifest, Config: ocispec.Descriptor{ @@ -4708,7 +6068,7 @@ func Test_ManifestStore_PushReference_ReferrersAPIUnavailable(t *testing.T) { } manifestJSON, err := json.Marshal(manifest) if err != nil { - t.Errorf("failed to marshal manifest: %v", err) + t.Fatalf("failed to marshal manifest: %v", err) } manifestDesc := content.NewDescriptorFromBytes(manifest.MediaType, manifestJSON) manifestDesc.ArtifactType = manifest.Config.MediaType @@ -4727,7 +6087,7 @@ func Test_ManifestStore_PushReference_ReferrersAPIUnavailable(t *testing.T) { } indexJSON_2, err := json.Marshal(index_2) if err != nil { - t.Errorf("failed to marshal manifest: %v", err) + t.Fatalf("failed to marshal manifest: %v", err) } indexDesc_2 := content.NewDescriptorFromBytes(index_2.MediaType, indexJSON_2) var manifestDeleted bool @@ -4802,10 +6162,10 @@ func Test_ManifestStore_PushReference_ReferrersAPIUnavailable(t *testing.T) { t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnsupported) } - // test push image manifest with subject again, referrers list should not be changed + // test pushing image manifest with subject again, referrers list should not be changed ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { - case r.Method == http.MethodPut && r.URL.Path == "/v2/test/manifests/"+manifestDesc.Digest.String(): + case r.Method == http.MethodPut && r.URL.Path == "/v2/test/manifests/"+manifestRef: if contentType := r.Header.Get("Content-Type"); contentType != manifestDesc.MediaType { w.WriteHeader(http.StatusBadRequest) break @@ -4841,12 +6201,12 @@ func Test_ManifestStore_PushReference_ReferrersAPIUnavailable(t *testing.T) { if state := repo.loadReferrersState(); state != referrersStateUnknown { t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnknown) } - err = repo.Push(ctx, manifestDesc, bytes.NewReader(manifestJSON)) + err = repo.PushReference(ctx, manifestDesc, bytes.NewReader(manifestJSON), manifestRef) if err != nil { - t.Fatalf("Manifests.Push() error = %v", err) + t.Fatalf("Manifests.PushReference() error = %v", err) } if !bytes.Equal(gotManifest, manifestJSON) { - t.Errorf("Manifests.Push() = %v, want %v", string(gotManifest), string(manifestJSON)) + t.Errorf("Manifests.PushReference() = %v, want %v", string(gotManifest), string(manifestJSON)) } // referrers list should not be changed if !bytes.Equal(gotReferrerIndex, indexJSON_2) { @@ -4855,6 +6215,109 @@ func Test_ManifestStore_PushReference_ReferrersAPIUnavailable(t *testing.T) { if state := repo.loadReferrersState(); state != referrersStateUnsupported { t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnsupported) } + + // push image index with subject, referrer list should be updated + indexManifest := ocispec.Index{ + MediaType: ocispec.MediaTypeImageIndex, + Subject: &subjectDesc, + ArtifactType: "test/index", + Annotations: map[string]string{"foo": "bar"}, + } + indexManifestJSON, err := json.Marshal(indexManifest) + if err != nil { + t.Fatalf("failed to marshal manifest: %v", err) + } + indexManifestDesc := content.NewDescriptorFromBytes(indexManifest.MediaType, indexManifestJSON) + indexManifestDesc.ArtifactType = indexManifest.ArtifactType + indexManifestDesc.Annotations = indexManifest.Annotations + indexManifestRef := "baz" + index_3 := ocispec.Index{ + Versioned: specs.Versioned{ + SchemaVersion: 2, // historical value. does not pertain to OCI or docker version + }, + MediaType: ocispec.MediaTypeImageIndex, + Manifests: []ocispec.Descriptor{ + artifactDesc, + manifestDesc, + indexManifestDesc, + }, + } + indexJSON_3, err := json.Marshal(index_3) + if err != nil { + t.Fatalf("failed to marshal manifest: %v", err) + } + indexDesc_3 := content.NewDescriptorFromBytes(index_3.MediaType, indexJSON_3) + manifestDeleted = false + ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodPut && r.URL.Path == "/v2/test/manifests/"+indexManifestRef: + if contentType := r.Header.Get("Content-Type"); contentType != indexManifestDesc.MediaType { + w.WriteHeader(http.StatusBadRequest) + break + } + buf := bytes.NewBuffer(nil) + if _, err := buf.ReadFrom(r.Body); err != nil { + t.Errorf("fail to read: %v", err) + } + gotManifest = buf.Bytes() + w.Header().Set("Docker-Content-Digest", indexManifestDesc.Digest.String()) + w.WriteHeader(http.StatusCreated) + case r.Method == http.MethodGet && r.URL.Path == "/v2/test/referrers/"+zeroDigest: + w.WriteHeader(http.StatusNotFound) + case r.Method == http.MethodGet && r.URL.Path == "/v2/test/manifests/"+referrersTag: + w.Write(indexJSON_2) + case r.Method == http.MethodPut && r.URL.Path == "/v2/test/manifests/"+referrersTag: + if contentType := r.Header.Get("Content-Type"); contentType != ocispec.MediaTypeImageIndex { + w.WriteHeader(http.StatusBadRequest) + break + } + buf := bytes.NewBuffer(nil) + if _, err := buf.ReadFrom(r.Body); err != nil { + t.Errorf("fail to read: %v", err) + } + gotReferrerIndex = buf.Bytes() + w.Header().Set("Docker-Content-Digest", indexDesc_3.Digest.String()) + w.WriteHeader(http.StatusCreated) + case r.Method == http.MethodDelete && r.URL.Path == "/v2/test/manifests/"+indexDesc_2.Digest.String(): + manifestDeleted = true + // no "Docker-Content-Digest" header for manifest deletion + w.WriteHeader(http.StatusAccepted) + default: + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + } + })) + defer ts.Close() + uri, err = url.Parse(ts.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + + ctx = context.Background() + repo, err = NewRepository(uri.Host + "/test") + if err != nil { + t.Fatalf("NewRepository() error = %v", err) + } + repo.PlainHTTP = true + if state := repo.loadReferrersState(); state != referrersStateUnknown { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnknown) + } + err = repo.PushReference(ctx, indexManifestDesc, bytes.NewReader(indexManifestJSON), indexManifestRef) + if err != nil { + t.Fatalf("Manifests.PushReference() error = %v", err) + } + if !bytes.Equal(gotManifest, indexManifestJSON) { + t.Errorf("Manifests.PushReference() = %v, want %v", string(gotManifest), string(indexManifestJSON)) + } + if !bytes.Equal(gotReferrerIndex, indexJSON_3) { + t.Errorf("got referrers index = %v, want %v", string(gotReferrerIndex), string(indexJSON_3)) + } + if !manifestDeleted { + t.Errorf("manifestDeleted = %v, want %v", manifestDeleted, true) + } + if state := repo.loadReferrersState(); state != referrersStateUnsupported { + t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateUnsupported) + } } func Test_ManifestStore_generateDescriptorWithVariousDockerContentDigestHeaders(t *testing.T) { @@ -4879,7 +6342,7 @@ func Test_ManifestStore_generateDescriptorWithVariousDockerContentDigestHeaders( resp := http.Response{ Header: http.Header{ "Content-Type": []string{"application/vnd.docker.distribution.manifest.v2+json"}, - dockerContentDigestHeader: []string{dcdIOStruct.serverCalculatedDigest.String()}, + headerDockerContentDigest: []string{dcdIOStruct.serverCalculatedDigest.String()}, }, } if method == http.MethodGet { @@ -5872,3 +7335,136 @@ func TestRepository_pingReferrers_Concurrent(t *testing.T) { t.Errorf("Repository.loadReferrersState() = %v, want %v", state, referrersStateSupported) } } + +func TestRepository_do(t *testing.T) { + data := []byte(`hello world!`) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet || r.URL.Path != "/test" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + w.Header().Add("Warning", `299 - "Test 1: Good warning."`) + w.Header().Add("Warning", `199 - "Test 2: Warning with a non-299 code."`) + w.Header().Add("Warning", `299 - "Test 3: Good warning."`) + w.Header().Add("Warning", `299 myregistry.example.com "Test 4: Warning with a non-unknown agent"`) + w.Header().Add("Warning", `299 - "Test 5: Warning with a date." "Sat, 25 Aug 2012 23:34:45 GMT"`) + w.Header().Add("wArnIng", `299 - "Test 6: Good warning."`) + w.Write(data) + })) + defer ts.Close() + uri, err := url.Parse(ts.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + testURL := ts.URL + "/test" + + // test do() without HandleWarning + repo, err := NewRepository(uri.Host + "/test") + if err != nil { + t.Fatal("NewRepository() error =", err) + } + req, err := http.NewRequest(http.MethodGet, testURL, nil) + if err != nil { + t.Fatal("failed to create test request:", err) + } + resp, err := repo.do(req) + if err != nil { + t.Fatal("Repository.do() error =", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("Repository.do() status code = %v, want %v", resp.StatusCode, http.StatusOK) + } + if got := len(resp.Header["Warning"]); got != 6 { + t.Errorf("Repository.do() warning header len = %v, want %v", got, 6) + } + got, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal("io.ReadAll() error =", err) + } + resp.Body.Close() + if !bytes.Equal(got, data) { + t.Errorf("Repository.do() = %v, want %v", got, data) + } + + // test do() with HandleWarning + repo, err = NewRepository(uri.Host + "/test") + if err != nil { + t.Fatal("NewRepository() error =", err) + } + var gotWarnings []Warning + repo.HandleWarning = func(warning Warning) { + gotWarnings = append(gotWarnings, warning) + } + + req, err = http.NewRequest(http.MethodGet, testURL, nil) + if err != nil { + t.Fatal("failed to create test request:", err) + } + resp, err = repo.do(req) + if err != nil { + t.Fatal("Repository.do() error =", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("Repository.do() status code = %v, want %v", resp.StatusCode, http.StatusOK) + } + if got := len(resp.Header["Warning"]); got != 6 { + t.Errorf("Repository.do() warning header len = %v, want %v", got, 6) + } + got, err = io.ReadAll(resp.Body) + if err != nil { + t.Errorf("Repository.do() = %v, want %v", got, data) + } + resp.Body.Close() + if !bytes.Equal(got, data) { + t.Errorf("Repository.do() = %v, want %v", got, data) + } + + wantWarnings := []Warning{ + { + WarningValue: WarningValue{ + Code: 299, + Agent: "-", + Text: "Test 1: Good warning.", + }, + }, + { + WarningValue: WarningValue{ + Code: 299, + Agent: "-", + Text: "Test 3: Good warning.", + }, + }, + { + WarningValue: WarningValue{ + Code: 299, + Agent: "-", + Text: "Test 6: Good warning.", + }, + }, + } + if !reflect.DeepEqual(gotWarnings, wantWarnings) { + t.Errorf("Repository.do() = %v, want %v", gotWarnings, wantWarnings) + } +} + +func TestRepository_clone(t *testing.T) { + repo, err := NewRepository("localhost:1234/repo/image") + if err != nil { + t.Fatalf("invalid repository: %v", err) + } + + crepo := repo.clone() + + if repo.Reference != crepo.Reference { + t.Fatal("references should be the same") + } + + if !reflect.DeepEqual(&repo.referrersPingLock, &crepo.referrersPingLock) { + t.Fatal("referrersPingLock should be different") + } + + if !reflect.DeepEqual(&repo.referrersMergePool, &crepo.referrersMergePool) { + t.Fatal("referrersMergePool should be different") + } +} diff --git a/registry/remote/url.go b/registry/remote/url.go index d3eee3ee..74258de7 100644 --- a/registry/remote/url.go +++ b/registry/remote/url.go @@ -101,7 +101,7 @@ func buildRepositoryBlobMountURL(plainHTTP bool, ref registry.Reference, d diges // buildReferrersURL builds the URL for querying the Referrers API. // Format: :///v2//referrers/?artifactType= -// Reference: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc1/spec.md#listing-referrers +// Reference: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#listing-referrers func buildReferrersURL(plainHTTP bool, ref registry.Reference, artifactType string) string { var query string if artifactType != "" { diff --git a/registry/remote/warning.go b/registry/remote/warning.go new file mode 100644 index 00000000..ff8f9c02 --- /dev/null +++ b/registry/remote/warning.go @@ -0,0 +1,100 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remote + +import ( + "errors" + "fmt" + "strconv" + "strings" +) + +const ( + // headerWarning is the "Warning" header. + // Reference: https://www.rfc-editor.org/rfc/rfc7234#section-5.5 + headerWarning = "Warning" + + // warnCode299 is the 299 warn-code. + // Reference: https://www.rfc-editor.org/rfc/rfc7234#section-5.5 + warnCode299 = 299 + + // warnAgentUnknown represents an unknown warn-agent. + // Reference: https://www.rfc-editor.org/rfc/rfc7234#section-5.5 + warnAgentUnknown = "-" +) + +// errUnexpectedWarningFormat is returned by parseWarningHeader when +// an unexpected warning format is encountered. +var errUnexpectedWarningFormat = errors.New("unexpected warning format") + +// WarningValue represents the value of the Warning header. +// +// References: +// - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#warnings +// - https://www.rfc-editor.org/rfc/rfc7234#section-5.5 +type WarningValue struct { + // Code is the warn-code. + Code int + // Agent is the warn-agent. + Agent string + // Text is the warn-text. + Text string +} + +// Warning contains the value of the warning header and may contain +// other information related to the warning. +// +// References: +// - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#warnings +// - https://www.rfc-editor.org/rfc/rfc7234#section-5.5 +type Warning struct { + // WarningValue is the value of the warning header. + WarningValue +} + +// parseWarningHeader parses the warning header into WarningValue. +func parseWarningHeader(header string) (WarningValue, error) { + if len(header) < 9 || !strings.HasPrefix(header, `299 - "`) || !strings.HasSuffix(header, `"`) { + // minimum header value: `299 - "x"` + return WarningValue{}, fmt.Errorf("%s: %w", header, errUnexpectedWarningFormat) + } + + // validate text only as code and agent are fixed + quotedText := header[6:] // behind `299 - `, quoted by " + text, err := strconv.Unquote(quotedText) + if err != nil { + return WarningValue{}, fmt.Errorf("%s: unexpected text: %w: %v", header, errUnexpectedWarningFormat, err) + } + + return WarningValue{ + Code: warnCode299, + Agent: warnAgentUnknown, + Text: text, + }, nil +} + +// handleWarningHeaders parses the warning headers and handles the parsed +// warnings using handleWarning. +func handleWarningHeaders(headers []string, handleWarning func(Warning)) { + for _, h := range headers { + if value, err := parseWarningHeader(h); err == nil { + // ignore warnings in unexpected formats + handleWarning(Warning{ + WarningValue: value, + }) + } + } +} diff --git a/registry/remote/warning_test.go b/registry/remote/warning_test.go new file mode 100644 index 00000000..d8c22b66 --- /dev/null +++ b/registry/remote/warning_test.go @@ -0,0 +1,158 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remote + +import ( + "errors" + "reflect" + "testing" +) + +func Test_parseWarningHeader(t *testing.T) { + tests := []struct { + name string + header string + want WarningValue + wantErr error + }{ + { + name: "Valid warning", + header: `299 - "This is a warning."`, + want: WarningValue{ + Code: 299, + Agent: "-", + Text: "This is a warning.", + }, + }, + { + name: "Valid meaningless warning", + header: `299 - " "`, + want: WarningValue{ + Code: 299, + Agent: "-", + Text: " ", + }, + }, + { + name: "Multiple spaces in warning", + header: `299 - "This is a warning."`, + want: WarningValue{}, + wantErr: errUnexpectedWarningFormat, + }, + { + name: "Leading space in warning", + header: ` 299 - "This is a warning."`, + want: WarningValue{}, + wantErr: errUnexpectedWarningFormat, + }, + { + name: "Trailing space in warning", + header: `299 - "This is a warning." `, + want: WarningValue{}, + wantErr: errUnexpectedWarningFormat, + }, + { + name: "Warning with a non-299 code", + header: `199 - "This is a warning."`, + want: WarningValue{}, + wantErr: errUnexpectedWarningFormat, + }, + { + name: "Warning with a non-unknown agent", + header: `299 localhost:5000 "This is a warning."`, + want: WarningValue{}, + wantErr: errUnexpectedWarningFormat, + }, + { + name: "Warning with a date", + header: `299 - "This is a warning." "Sat, 25 Aug 2012 23:34:45 GMT"`, + want: WarningValue{}, + wantErr: errUnexpectedWarningFormat, + }, + { + name: "Invalid format", + header: `299 - "This is a warning." something strange`, + want: WarningValue{}, + wantErr: errUnexpectedWarningFormat, + }, + { + name: "Not a warning", + header: `foo bar baz`, + want: WarningValue{}, + wantErr: errUnexpectedWarningFormat, + }, + { + name: "No code", + header: `- "This is a warning."`, + want: WarningValue{}, + wantErr: errUnexpectedWarningFormat, + }, + { + name: "No agent", + header: `299 "This is a warning."`, + want: WarningValue{}, + wantErr: errUnexpectedWarningFormat, + }, + { + name: "No text", + header: `299 -`, + want: WarningValue{}, + wantErr: errUnexpectedWarningFormat, + }, + { + name: "Empty text", + header: `299 - ""`, + want: WarningValue{}, + wantErr: errUnexpectedWarningFormat, + }, + { + name: "Unquoted text", + header: `299 - This is a warning.`, + want: WarningValue{}, + wantErr: errUnexpectedWarningFormat, + }, + { + name: "Single-quoted text", + header: `299 - 'This is a warning.'`, + want: WarningValue{}, + wantErr: errUnexpectedWarningFormat, + }, + { + name: "Back-quoted text", + header: "299 - `This is a warning.`", + want: WarningValue{}, + wantErr: errUnexpectedWarningFormat, + }, + { + name: "Invalid quotes", + header: `299 - 'This is a warning."`, + want: WarningValue{}, + wantErr: errUnexpectedWarningFormat, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseWarningHeader(tt.header) + if !errors.Is(err, tt.wantErr) { + t.Errorf("parseWarningHeader() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("parseWarningHeader() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/registry/repository.go b/registry/repository.go index 2dd7ff99..b75b7b8e 100644 --- a/registry/repository.go +++ b/registry/repository.go @@ -82,7 +82,7 @@ type ReferenceFetcher interface { } // ReferrerLister provides the Referrers API. -// Reference: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc1/spec.md#listing-referrers +// Reference: https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#listing-referrers type ReferrerLister interface { Referrers(ctx context.Context, desc ocispec.Descriptor, artifactType string, fn func(referrers []ocispec.Descriptor) error) error } @@ -93,16 +93,19 @@ type TagLister interface { // Since the returned tag list may be paginated by the underlying // implementation, a function should be passed in to process the paginated // tag list. + // // `last` argument is the `last` parameter when invoking the tags API. // If `last` is NOT empty, the entries in the response start after the // tag specified by `last`. Otherwise, the response starts from the top // of the Tags list. + // // Note: When implemented by a remote registry, the tags API is called. // However, not all registries supports pagination or conforms the // specification. + // // References: - // - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc1/spec.md#content-discovery - // - https://docs.docker.com/registry/spec/api/#tags + // - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#content-discovery + // - https://docs.docker.com/registry/spec/api/#tags // See also `Tags()` in this package. Tags(ctx context.Context, last string, fn func(tags []string) error) error }