Skip to content

Commit

Permalink
Use existing auth plugins with OCIDownloader
Browse files Browse the repository at this point in the history
This change addresses solutions 2) and 3) of the related issue #5553.
It mainly starts using the (now exposed) `Config.AuthPlugin()` function
of the `rest` package in the `download.OCIDownloader`. This allows it
to use any `HTTPAuthPlugin` that is defined in the `Config.Credentials`
section and makes it much more consistent with behavior of the
`download.Downloader` and potential other uses of the rest package.

Fixes #5553

Signed-off-by: DerGut <[email protected]>
  • Loading branch information
DerGut authored and ashutosh-narkar committed Apr 26, 2023
1 parent 85fbad5 commit b626a2c
Show file tree
Hide file tree
Showing 16 changed files with 1,031 additions and 296 deletions.
40 changes: 27 additions & 13 deletions docs/content/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -739,24 +739,38 @@ func init() {

### Using private image from OCI repositories

When using a private image from an OCI registry the credentials are mandatory as the OCI downloader needs the credentials for the pull operation.
When using a private image from an OCI registry you need to specify an authentication method. Supported authentication methods are listed in the [Services](#services) section. The Azure managed identity plugin
is not supported at this point in time.

Examples of setting credetials for pulling private images:
*AWS ECR* private image usually requires at least basic authentication. The credentials to authenticate can be obtained using the AWS CLI command `aws ecr get-login` and those can be passed to the service configuration as basic bearer credentials as follows:
Examples of setting credentials for pulling private images:
*AWS ECR* private images usually require at least basic authentication. The credentials to authenticate can be obtained using the AWS CLI command `aws ecr get-login` and those can be passed to the service configuration as basic bearer credentials as follows:
```yaml
credentials:
bearer:
scheme: "Basic"
token: "<username>:<password>"
```
credentials:
bearer:
scheme: "Basic"
token: "<username>:<password>"
Other AWS authentication methods also work:
```yaml
credentials:
s3_signing:
service: "ecr"
metadata_credentials:
aws_region: us-east-1
```
The OCI downloader includes a base64 encoder for these credentials so they can be supplied as shown above.
Note, that the authentication method `s3_signing` does work for
signing requests to other AWS services.

A special case is that bearer authentication works differently to normal service authentication. The OCI downloader base64-encodes the credentials for you so that they need to be supplied in plain text.

For *GHCR* (Github Container Registry) you can use a developer PAT (personal access token) when downloading a private image. These can be supplied as:
```
credentials:
bearer:
scheme: "Bearer"
token: "<PAT>"
```yaml
credentials:
bearer:
scheme: "Bearer"
token: "<PAT>"
```

### Miscellaneous
Expand Down
118 changes: 56 additions & 62 deletions download/oci_download.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package download

import (
"context"
"crypto/tls"
"encoding/base64"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -294,44 +292,47 @@ func (d *OCIDownloader) download(ctx context.Context, m metrics.Metrics) (*downl
}

func (d *OCIDownloader) pull(ctx context.Context, ref string) (*ocispec.Descriptor, error) {
authHeader := make(http.Header)
client, err := d.getHTTPClient(&authHeader)
lookup := d.client.AuthPluginLookup()

plugin, err := d.client.Config().AuthPlugin(lookup)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to look up auth plugin: %w", err)
}
urlInfo, err := url.Parse(d.client.Config().URL)

d.logger.Debug("OCIDownloader: using auth plugin: %T", plugin)

resolver, err := dockerResolver(plugin, d.client.Config(), d.logger)
if err != nil {
return nil, fmt.Errorf("invalid host url %s: %w", d.client.Config().URL, err)
}

target := newRemoteManager(d.getResolverHost(client, urlInfo), authHeader, ref)
target := remoteManager{
resolver: resolver,
srcRef: ref,
}

manifestDescriptor, err := oraslib.Copy(ctx, target, ref, d.store, "", oraslib.DefaultCopyOptions)
manifestDescriptor, err := oraslib.Copy(ctx, &target, ref, d.store, "", oraslib.DefaultCopyOptions)
if err != nil {
return nil, fmt.Errorf("download for '%s' failed: %w", ref, err)
}

return &manifestDescriptor, nil
}

func (d *OCIDownloader) getResolverHost(client *http.Client, urlInfo *url.URL) docker.RegistryHosts {
creds := d.client.Config().Credentials
var auth docker.Authorizer
if creds.Bearer == nil {
auth = docker.NewDockerAuthorizer(
docker.WithAuthClient(client),
)
} else {
auth = docker.NewDockerAuthorizer(
docker.WithAuthClient(client),
docker.WithAuthCreds(func(string) (string, string, error) {
creds := d.client.Config().Credentials
if creds.Bearer == nil {
return " ", " ", nil
}

return creds.Bearer.Scheme, creds.Bearer.Token, nil
}))
func dockerResolver(plugin rest.HTTPAuthPlugin, config *rest.Config, logger logging.Logger) (remotes.Resolver, error) {
client, err := plugin.NewClient(*config)
if err != nil {
return nil, fmt.Errorf("failed to create auth client: %w", err)
}

urlInfo, err := url.Parse(config.URL)
if err != nil {
return nil, fmt.Errorf("failed to parse url: %w", err)
}

authorizer := pluginAuthorizer{
plugin: plugin,
logger: logger,
}

registryHost := docker.RegistryHost{
Expand All @@ -340,40 +341,42 @@ func (d *OCIDownloader) getResolverHost(client *http.Client, urlInfo *url.URL) d
Capabilities: docker.HostCapabilityPull | docker.HostCapabilityResolve | docker.HostCapabilityPush,
Client: client,
Path: "/v2",
Authorizer: auth,
Authorizer: &authorizer,
}

return func(string) ([]docker.RegistryHost, error) {
return []docker.RegistryHost{registryHost}, nil
opts := docker.ResolverOptions{
Hosts: func(string) ([]docker.RegistryHost, error) {
return []docker.RegistryHost{registryHost}, nil
},
}

return docker.NewResolver(opts), nil
}

func (d *OCIDownloader) getHTTPClient(authHeader *http.Header) (*http.Client, error) {
var client *http.Client
var err error
clientConfig := d.client.Config()
if clientConfig != nil && clientConfig.Credentials.ClientTLS != nil {
client, err = clientConfig.Credentials.ClientTLS.NewClient(*clientConfig)
if err != nil {
return nil, fmt.Errorf("can not create a new client: %w", err)
}
} else {
if clientConfig != nil && clientConfig.Credentials.Bearer != nil {
client, err = clientConfig.Credentials.Bearer.NewClient(*clientConfig)
if err != nil {
return nil, fmt.Errorf("can not create a new bearer client: %w", err)
}
type pluginAuthorizer struct {
plugin rest.HTTPAuthPlugin

authHeader.Add("Authorization",
fmt.Sprintf("%s %s",
clientConfig.Credentials.Bearer.Scheme,
base64.StdEncoding.EncodeToString([]byte(clientConfig.Credentials.Bearer.Token))),
)
} else {
client = &http.Client{Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: clientConfig.AllowInsecureTLS}}}
}
logger logging.Logger
}

var _ docker.Authorizer = &pluginAuthorizer{}

func (a *pluginAuthorizer) AddResponses(context.Context, []*http.Response) error {
return fmt.Errorf("using custom authorizer: %w", errdefs.ErrNotImplemented)
}

// Authorize uses a rest.HTTPAuthPlugin to Prepare a request.
func (a *pluginAuthorizer) Authorize(_ context.Context, req *http.Request) error {
if err := a.plugin.Prepare(req); err != nil {
err = fmt.Errorf("failed to prepare docker request: %w", err)

// Make sure to log this before passing the error back to docker
a.logger.Error(err.Error())

return err
}
return client, nil

return nil
}

func manifestFromDesc(ctx context.Context, target oras.Target, desc *ocispec.Descriptor) (*ocispec.Manifest, error) {
Expand Down Expand Up @@ -406,15 +409,6 @@ type remoteManager struct {
srcRef string
}

func newRemoteManager(hosts docker.RegistryHosts, headers http.Header, srcRef string) *remoteManager {
resolver := docker.NewResolver(docker.ResolverOptions{
Hosts: hosts,
Headers: headers,
})

return &remoteManager{resolver: resolver, srcRef: srcRef}
}

func (r *remoteManager) Resolve(ctx context.Context, ref string) (ocispec.Descriptor, error) {
_, desc, err := r.resolver.Resolve(ctx, ref)
if err != nil {
Expand Down
122 changes: 110 additions & 12 deletions download/oci_download_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"encoding/base64"
"fmt"
"net/http"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -50,23 +51,40 @@ func TestOCIStartStop(t *testing.T) {
d.Stop(ctx)
}

func TestOCIAuth(t *testing.T) {
func TestOCIBearerAuthPlugin(t *testing.T) {
ctx := context.Background()
fixture := newTestFixture(t)
token := base64.StdEncoding.EncodeToString([]byte("secret")) // token should be base64 encoded
fixture.server.expAuth = fmt.Sprintf("Bearer %s", token) // test on private repository
plainToken := "secret"
token := base64.StdEncoding.EncodeToString([]byte(plainToken)) // token should be base64 encoded
fixture.server.expAuth = fmt.Sprintf("Bearer %s", token) // test on private repository
fixture.server.expEtag = "sha256:c5834dbce332cabe6ae68a364de171a50bf5b08024c27d7c08cc72878b4df7ff"

restConf := fmt.Sprintf(`{
"url": %q,
"type": "oci",
"credentials": {
"bearer": {
"token": %q
}
}
}`, fixture.server.server.URL, plainToken)

client, err := rest.New([]byte(restConf), map[string]*keys.Config{})
if err != nil {
t.Fatal(err)
}

fixture.setClient(client)

config := Config{}
if err := config.ValidateAndInjectDefaults(); err != nil {
t.Fatal(err)
}

d := NewOCI(config, fixture.client, "ghcr.io/org/repo:latest", "/tmp/oci")

err := d.oneShot(ctx)
if err != nil {
t.Fatal("unexpected error")
if err := d.oneShot(ctx); err != nil {
t.Fatal(err)
}
}

Expand All @@ -89,15 +107,33 @@ func TestOCIFailureAuthn(t *testing.T) {
}

func TestOCIEtag(t *testing.T) {
ctx := context.Background()
fixture := newTestFixture(t)
token := base64.StdEncoding.EncodeToString([]byte("secret")) // token should be base64 encoded
fixture.server.expAuth = fmt.Sprintf("Bearer %s", token) // test on private repository
fixture.server.expEtag = "sha256:c5834dbce332cabe6ae68a364de171a50bf5b08024c27d7c08cc72878b4df7ff"

restConfig := []byte(fmt.Sprintf(`{
"url": %q,
"type": "oci",
"credentials": {
"bearer": {
"token": "secret"
}
}
}`, fixture.server.server.URL))

client, err := rest.New(restConfig, map[string]*keys.Config{})
if err != nil {
t.Fatal(err)
}

fixture.setClient(client)

config := Config{}
if err := config.ValidateAndInjectDefaults(); err != nil {
t.Fatal(err)
}

firstResponse := Update{ETag: ""}
d := NewOCI(config, fixture.client, "ghcr.io/org/repo:latest", "/tmp/oci").WithCallback(func(_ context.Context, u Update) {
if firstResponse.ETag == "" {
Expand All @@ -111,17 +147,16 @@ func TestOCIEtag(t *testing.T) {
})

// fill firstResponse
err := d.oneShot(ctx)
if err != nil {
t.Fatal("unexpected error")
if err := d.oneShot(context.Background()); err != nil {
t.Fatal(err)
}
// Give time for some download events to occur
time.Sleep(1 * time.Second)

// second call to verify if nil bundle is returned and same etag
err = d.oneShot(ctx)
err = d.oneShot(context.Background())
if err != nil {
t.Fatal("unexpected error")
t.Fatal(err)
}
}

Expand Down Expand Up @@ -152,3 +187,66 @@ func TestOCIPublicRegistry(t *testing.T) {
t.Fatal("unexpected error")
}
}

func TestOCICustomAuthPlugin(t *testing.T) {
fixture := newTestFixture(t)
defer fixture.server.stop()

restConfig := []byte(fmt.Sprintf(`{
"url": %q,
"credentials": {
"plugin": "my_plugin"
}
}`, fixture.server.server.URL))

client, err := rest.New(
restConfig,
map[string]*keys.Config{},
rest.AuthPluginLookup(mockAuthPluginLookup),
)
if err != nil {
t.Fatal(err)
}

fixture.setClient(client)

config := Config{}
if err := config.ValidateAndInjectDefaults(); err != nil {
t.Fatal(err)
}

tmpDir := t.TempDir()

d := NewOCI(config, fixture.client, "ghcr.io/org/repo:latest", tmpDir)

if err := d.oneShot(context.Background()); err != nil {
t.Fatal(err)
}
}

func mockAuthPluginLookup(string) rest.HTTPAuthPlugin {
return &mockAuthPlugin{}
}

type mockAuthPlugin struct{}

func (p *mockAuthPlugin) NewClient(c rest.Config) (*http.Client, error) {
tlsConfig, err := rest.DefaultTLSConfig(c)
if err != nil {
return nil, err
}

timeoutSec := 10

client := rest.DefaultRoundTripperClient(
tlsConfig,
int64(timeoutSec),
)

return client, nil
}

func (*mockAuthPlugin) Prepare(r *http.Request) error {
r.Header.Set("Authorization", "Bearer secret")
return nil
}
Loading

0 comments on commit b626a2c

Please sign in to comment.