diff --git a/docs/overview.md b/docs/overview.md index 1b1085771..27a0e618b 100644 --- a/docs/overview.md +++ b/docs/overview.md @@ -176,6 +176,17 @@ host = "exampleregistry.io" insecure = true ``` +`header` field allows to set headers to send to the server. + +```toml +[[resolver.host."registry2:5000".mirrors]] + host = "registry2:5000" + [resolver.host."registry2:5000".mirrors.header] + x-custom-2 = ["value3", "value4"] +``` + +> NOTE: Headers aren't passed to the redirected location. + The config file can be passed to stargz snapshotter using `containerd-stargz-grpc`'s `--config` option. ## Make your remote snapshotter diff --git a/fs/remote/resolver.go b/fs/remote/resolver.go index c15fc3173..8efb4eb47 100644 --- a/fs/remote/resolver.go +++ b/fs/remote/resolver.go @@ -238,7 +238,7 @@ func newHTTPFetcher(ctx context.Context, fc *fetcherConfig) (*httpFetcher, int64 path.Join(host.Host, host.Path), strings.TrimPrefix(fc.refspec.Locator, fc.refspec.Hostname()+"/"), digest) - url, err := redirect(ctx, blobURL, tr, timeout) + url, header, err := redirect(ctx, blobURL, tr, timeout, host.Header) if err != nil { rErr = fmt.Errorf("failed to redirect (host %q, ref:%q, digest:%q): %v: %w", host.Host, fc.refspec, digest, err, rErr) continue // Try another @@ -247,7 +247,7 @@ func newHTTPFetcher(ctx context.Context, fc *fetcherConfig) (*httpFetcher, int64 // Get size information // TODO: we should try to use the Size field in the descriptor here. start := time.Now() // start time before getting layer header - size, err := getSize(ctx, url, tr, timeout) + size, err := getSize(ctx, url, tr, timeout, header) commonmetrics.MeasureLatencyInMilliseconds(commonmetrics.StargzHeaderGet, digest, start) // time to get layer header if err != nil { rErr = fmt.Errorf("failed to get size (host %q, ref:%q, digest:%q): %v: %w", host.Host, fc.refspec, digest, err, rErr) @@ -256,11 +256,13 @@ func newHTTPFetcher(ctx context.Context, fc *fetcherConfig) (*httpFetcher, int64 // Hit one destination return &httpFetcher{ - url: url, - tr: tr, - blobURL: blobURL, - digest: digest, - timeout: timeout, + url: url, + tr: tr, + blobURL: blobURL, + digest: digest, + timeout: timeout, + header: header, + orgHeader: host.Header, }, size, nil } @@ -309,7 +311,7 @@ func (tr *transport) RoundTrip(req *http.Request) (*http.Response, error) { return resp, nil } -func redirect(ctx context.Context, blobURL string, tr http.RoundTripper, timeout time.Duration) (url string, err error) { +func redirect(ctx context.Context, blobURL string, tr http.RoundTripper, timeout time.Duration, header http.Header) (url string, withHeader http.Header, err error) { if timeout > 0 { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, timeout) @@ -320,13 +322,17 @@ func redirect(ctx context.Context, blobURL string, tr http.RoundTripper, timeout // ghcr.io returns 200 on HEAD without Location header (2020). req, err := http.NewRequestWithContext(ctx, "GET", blobURL, nil) if err != nil { - return "", fmt.Errorf("failed to make request to the registry: %w", err) + return "", nil, fmt.Errorf("failed to make request to the registry: %w", err) + } + req.Header = http.Header{} + for k, v := range header { + req.Header[k] = v } req.Close = false req.Header.Set("Range", "bytes=0-1") res, err := tr.RoundTrip(req) if err != nil { - return "", fmt.Errorf("failed to request: %w", err) + return "", nil, fmt.Errorf("failed to request: %w", err) } defer func() { io.Copy(io.Discard, res.Body) @@ -335,17 +341,19 @@ func redirect(ctx context.Context, blobURL string, tr http.RoundTripper, timeout if res.StatusCode/100 == 2 { url = blobURL + withHeader = header } else if redir := res.Header.Get("Location"); redir != "" && res.StatusCode/100 == 3 { // TODO: Support nested redirection url = redir + // Do not pass headers to the redirected location. } else { - return "", fmt.Errorf("failed to access to the registry with code %v", res.StatusCode) + return "", nil, fmt.Errorf("failed to access to the registry with code %v", res.StatusCode) } return } -func getSize(ctx context.Context, url string, tr http.RoundTripper, timeout time.Duration) (int64, error) { +func getSize(ctx context.Context, url string, tr http.RoundTripper, timeout time.Duration, header http.Header) (int64, error) { if timeout > 0 { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, timeout) @@ -355,6 +363,10 @@ func getSize(ctx context.Context, url string, tr http.RoundTripper, timeout time if err != nil { return 0, err } + req.Header = http.Header{} + for k, v := range header { + req.Header[k] = v + } req.Close = false res, err := tr.RoundTrip(req) if err != nil { @@ -373,6 +385,10 @@ func getSize(ctx context.Context, url string, tr http.RoundTripper, timeout time if err != nil { return 0, fmt.Errorf("failed to make request to the registry: %w", err) } + req.Header = http.Header{} + for k, v := range header { + req.Header[k] = v + } req.Close = false req.Header.Set("Range", "bytes=0-1") res, err = tr.RoundTrip(req) @@ -404,6 +420,8 @@ type httpFetcher struct { singleRange bool singleRangeMu sync.Mutex timeout time.Duration + header http.Header + orgHeader http.Header } type multipartReadCloser interface { @@ -443,6 +461,10 @@ func (f *httpFetcher) fetch(ctx context.Context, rs []region, retry bool) (multi if err != nil { return nil, err } + req.Header = http.Header{} + for k, v := range f.header { + req.Header[k] = v + } var ranges string for _, reg := range requests { ranges += fmt.Sprintf("%d-%d,", reg.b, reg.e) @@ -514,6 +536,10 @@ func (f *httpFetcher) check() error { if err != nil { return fmt.Errorf("check failed: failed to make request: %w", err) } + req.Header = http.Header{} + for k, v := range f.header { + req.Header[k] = v + } req.Close = false req.Header.Set("Range", "bytes=0-1") res, err := f.tr.RoundTrip(req) @@ -544,12 +570,13 @@ func (f *httpFetcher) check() error { } func (f *httpFetcher) refreshURL(ctx context.Context) error { - newURL, err := redirect(ctx, f.blobURL, f.tr, f.timeout) + newURL, headers, err := redirect(ctx, f.blobURL, f.tr, f.timeout, f.orgHeader) if err != nil { return err } f.urlMu.Lock() f.url = newURL + f.header = headers f.urlMu.Unlock() return nil } diff --git a/fs/remote/resolver_test.go b/fs/remote/resolver_test.go index 421544cc6..490786e05 100644 --- a/fs/remote/resolver_test.go +++ b/fs/remote/resolver_test.go @@ -35,6 +35,7 @@ import ( "github.com/containerd/containerd/reference" "github.com/containerd/containerd/remotes/docker" + "github.com/containerd/stargz-snapshotter/fs/source" rhttp "github.com/hashicorp/go-retryablehttp" digest "github.com/opencontainers/go-digest" ocispec "github.com/opencontainers/image-spec/specs-go/v1" @@ -55,117 +56,219 @@ func TestMirror(t *testing.T) { tests := []struct { name string - tr http.RoundTripper - mirrors []string + hosts func(t *testing.T) source.RegistryHosts wantHost string error bool }{ { - name: "no-mirror", - tr: &sampleRoundTripper{okURLs: []string{refHost}}, - mirrors: nil, + name: "no-mirror", + hosts: hostsConfig( + &sampleRoundTripper{okURLs: []string{refHost}}, + ), wantHost: refHost, }, { - name: "valid-mirror", - tr: &sampleRoundTripper{okURLs: []string{"mirrorexample.com"}}, - mirrors: []string{"mirrorexample.com"}, + name: "valid-mirror", + hosts: hostsConfig( + &sampleRoundTripper{okURLs: []string{"mirrorexample.com"}}, + hostSimple("mirrorexample.com"), + ), wantHost: "mirrorexample.com", }, { name: "invalid-mirror", - tr: &sampleRoundTripper{ - withCode: map[string]int{ - "mirrorexample1.com": http.StatusInternalServerError, - "mirrorexample2.com": http.StatusUnauthorized, - "mirrorexample3.com": http.StatusNotFound, + hosts: hostsConfig( + &sampleRoundTripper{ + withCode: map[string]int{ + "mirrorexample1.com": http.StatusInternalServerError, + "mirrorexample2.com": http.StatusUnauthorized, + "mirrorexample3.com": http.StatusNotFound, + }, + okURLs: []string{"mirrorexample4.com", refHost}, }, - okURLs: []string{"mirrorexample4.com", refHost}, - }, - mirrors: []string{ - "mirrorexample1.com", - "mirrorexample2.com", - "mirrorexample3.com", - "mirrorexample4.com", - }, + hostSimple("mirrorexample1.com"), + hostSimple("mirrorexample2.com"), + hostSimple("mirrorexample3.com"), + hostSimple("mirrorexample4.com"), + ), wantHost: "mirrorexample4.com", }, { name: "invalid-all-mirror", - tr: &sampleRoundTripper{ - withCode: map[string]int{ - "mirrorexample1.com": http.StatusInternalServerError, - "mirrorexample2.com": http.StatusUnauthorized, - "mirrorexample3.com": http.StatusNotFound, + hosts: hostsConfig( + &sampleRoundTripper{ + withCode: map[string]int{ + "mirrorexample1.com": http.StatusInternalServerError, + "mirrorexample2.com": http.StatusUnauthorized, + "mirrorexample3.com": http.StatusNotFound, + }, + okURLs: []string{refHost}, }, - okURLs: []string{refHost}, - }, - mirrors: []string{ - "mirrorexample1.com", - "mirrorexample2.com", - "mirrorexample3.com", - }, + hostSimple("mirrorexample1.com"), + hostSimple("mirrorexample2.com"), + hostSimple("mirrorexample3.com"), + ), wantHost: refHost, }, { name: "invalid-hostname-of-mirror", - tr: &sampleRoundTripper{ - okURLs: []string{`.*`}, - }, - mirrors: []string{"mirrorexample.com/somepath/"}, + hosts: hostsConfig( + &sampleRoundTripper{ + okURLs: []string{`.*`}, + }, + hostSimple("mirrorexample.com/somepath/"), + ), wantHost: refHost, }, { name: "redirected-mirror", - tr: &sampleRoundTripper{ - redirectURL: map[string]string{ - regexp.QuoteMeta(fmt.Sprintf("mirrorexample.com%s", blobPath)): "https://backendexample.com/blobs/" + blobDigest.String(), + hosts: hostsConfig( + &sampleRoundTripper{ + redirectURL: map[string]string{ + regexp.QuoteMeta(fmt.Sprintf("mirrorexample.com%s", blobPath)): "https://backendexample.com/blobs/" + blobDigest.String(), + }, + okURLs: []string{`.*`}, }, - okURLs: []string{`.*`}, - }, - mirrors: []string{"mirrorexample.com"}, + hostSimple("mirrorexample.com"), + ), wantHost: "backendexample.com", }, { name: "invalid-redirected-mirror", - tr: &sampleRoundTripper{ - withCode: map[string]int{ - "backendexample.com": http.StatusInternalServerError, - }, - redirectURL: map[string]string{ - regexp.QuoteMeta(fmt.Sprintf("mirrorexample.com%s", blobPath)): "https://backendexample.com/blobs/" + blobDigest.String(), + hosts: hostsConfig( + &sampleRoundTripper{ + withCode: map[string]int{ + "backendexample.com": http.StatusInternalServerError, + }, + redirectURL: map[string]string{ + regexp.QuoteMeta(fmt.Sprintf("mirrorexample.com%s", blobPath)): "https://backendexample.com/blobs/" + blobDigest.String(), + }, + okURLs: []string{`.*`}, }, - okURLs: []string{`.*`}, - }, - mirrors: []string{"mirrorexample.com"}, + hostSimple("mirrorexample.com"), + ), wantHost: refHost, }, { - name: "fail-all", - tr: &sampleRoundTripper{}, - mirrors: []string{"mirrorexample.com"}, + name: "fail-all", + hosts: hostsConfig( + &sampleRoundTripper{}, + hostSimple("mirrorexample.com"), + ), wantHost: "", error: true, }, + { + name: "headers", + hosts: hostsConfig( + &sampleRoundTripper{ + okURLs: []string{`.*`}, + wantHeaders: map[string]http.Header{ + "mirrorexample.com": http.Header(map[string][]string{ + "test-a-key": {"a-value-1", "a-value-2"}, + "test-b-key": {"b-value-1"}, + }), + }, + }, + hostWithHeaders("mirrorexample.com", map[string][]string{ + "test-a-key": {"a-value-1", "a-value-2"}, + "test-b-key": {"b-value-1"}, + }), + ), + wantHost: "mirrorexample.com", + }, + { + name: "headers-with-mirrors", + hosts: hostsConfig( + &sampleRoundTripper{ + withCode: map[string]int{ + "mirrorexample1.com": http.StatusInternalServerError, + "mirrorexample2.com": http.StatusInternalServerError, + }, + okURLs: []string{"mirrorexample3.com", refHost}, + wantHeaders: map[string]http.Header{ + "mirrorexample1.com": http.Header(map[string][]string{ + "test-a-key": {"a-value"}, + }), + "mirrorexample2.com": http.Header(map[string][]string{ + "test-b-key": {"b-value"}, + "test-b-key-2": {"b-value-2", "b-value-3"}, + }), + "mirrorexample3.com": http.Header(map[string][]string{ + "test-c-key": {"c-value"}, + }), + }, + }, + hostWithHeaders("mirrorexample1.com", map[string][]string{ + "test-a-key": {"a-value"}, + }), + hostWithHeaders("mirrorexample2.com", map[string][]string{ + "test-b-key": {"b-value"}, + "test-b-key-2": {"b-value-2", "b-value-3"}, + }), + hostWithHeaders("mirrorexample3.com", map[string][]string{ + "test-c-key": {"c-value"}, + }), + ), + wantHost: "mirrorexample3.com", + }, + { + name: "headers-with-mirrors-invalid-all", + hosts: hostsConfig( + &sampleRoundTripper{ + withCode: map[string]int{ + "mirrorexample1.com": http.StatusInternalServerError, + "mirrorexample2.com": http.StatusInternalServerError, + }, + okURLs: []string{"mirrorexample3.com", refHost}, + wantHeaders: map[string]http.Header{ + "mirrorexample1.com": http.Header(map[string][]string{ + "test-a-key": {"a-value"}, + }), + "mirrorexample2.com": http.Header(map[string][]string{ + "test-b-key": {"b-value"}, + "test-b-key-2": {"b-value-2", "b-value-3"}, + }), + }, + }, + hostWithHeaders("mirrorexample1.com", map[string][]string{ + "test-a-key": {"a-value"}, + }), + hostWithHeaders("mirrorexample2.com", map[string][]string{ + "test-b-key": {"b-value"}, + "test-b-key-2": {"b-value-2", "b-value-3"}, + }), + ), + wantHost: refHost, + }, + { + name: "headers-with-redirected-mirror", + hosts: hostsConfig( + &sampleRoundTripper{ + redirectURL: map[string]string{ + regexp.QuoteMeta(fmt.Sprintf("mirrorexample.com%s", blobPath)): "https://backendexample.com/blobs/" + blobDigest.String(), + }, + okURLs: []string{`.*`}, + wantHeaders: map[string]http.Header{ + "mirrorexample.com": http.Header(map[string][]string{ + "test-a-key": {"a-value"}, + "test-b-key-2": {"b-value-2", "b-value-3"}, + }), + }, + }, + hostWithHeaders("mirrorexample.com", map[string][]string{ + "test-a-key": {"a-value"}, + "test-b-key-2": {"b-value-2", "b-value-3"}, + }), + ), + wantHost: "backendexample.com", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - hosts := func(refspec reference.Spec) (reghosts []docker.RegistryHost, _ error) { - host := refspec.Hostname() - for _, m := range append(tt.mirrors, host) { - reghosts = append(reghosts, docker.RegistryHost{ - Client: &http.Client{Transport: tt.tr}, - Host: m, - Scheme: "https", - Path: "/v2", - Capabilities: docker.HostCapabilityPull, - }) - } - return - } fetcher, _, err := newHTTPFetcher(context.Background(), &fetcherConfig{ - hosts: hosts, + hosts: tt.hosts(t), refspec: refspec, desc: ocispec.Descriptor{Digest: blobDigest}, }) @@ -175,25 +278,84 @@ func TestMirror(t *testing.T) { } t.Fatalf("failed to resolve reference: %v", err) } - nurl, err := url.Parse(fetcher.url) - if err != nil { - t.Fatalf("failed to parse url %q: %v", fetcher.url, err) + checkFetcherURL(t, fetcher, tt.wantHost) + + // Test check() + if err := fetcher.check(); err != nil { + t.Fatalf("failed to check fetcher: %v", err) } - if nurl.Hostname() != tt.wantHost { - t.Errorf("invalid hostname %q(%q); want %q", - nurl.Hostname(), nurl.String(), tt.wantHost) + + // Test refreshURL() + if err := fetcher.refreshURL(context.TODO()); err != nil { + t.Fatalf("failed to refresh URL: %v", err) } + checkFetcherURL(t, fetcher, tt.wantHost) }) } } +func checkFetcherURL(t *testing.T, f *httpFetcher, wantHost string) { + nurl, err := url.Parse(f.url) + if err != nil { + t.Fatalf("failed to parse url %q: %v", f.url, err) + } + if nurl.Hostname() != wantHost { + t.Errorf("invalid hostname %q(%q); want %q", nurl.Hostname(), nurl.String(), wantHost) + } +} + type sampleRoundTripper struct { + t *testing.T withCode map[string]int redirectURL map[string]string okURLs []string + wantHeaders map[string]http.Header +} + +func getTestHeaders(headers map[string][]string) map[string][]string { + res := make(map[string][]string) + for k, v := range headers { + if strings.HasPrefix(k, "test-") { + res[k] = v + } + } + return res } func (tr *sampleRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + reqHeader := getTestHeaders(req.Header) + for host, wHeaders := range tr.wantHeaders { + wantHeader := getTestHeaders(wHeaders) + if ok, _ := regexp.Match(host, []byte(req.URL.String())); ok { + if len(wantHeader) != len(reqHeader) { + tr.t.Fatalf("unexpected num of headers; got %d, wanted %d", len(wantHeader), len(reqHeader)) + } + for k, v := range wantHeader { + gotV, ok := reqHeader[k] + if !ok { + tr.t.Fatalf("required header %q not found; got %+v", k, reqHeader) + } + wantVM := make(map[string]struct{}) + for _, e := range v { + wantVM[e] = struct{}{} + } + if len(gotV) != len(v) { + tr.t.Fatalf("unexpected num of header values of %q; got %d, wanted %d", k, len(gotV), len(v)) + } + for _, gotE := range gotV { + delete(wantVM, gotE) + } + if len(wantVM) != 0 { + tr.t.Fatalf("header %q must have elements %+v", k, wantVM) + } + delete(reqHeader, k) + } + } + } + if len(reqHeader) != 0 { + tr.t.Fatalf("unexpected headers %+v", reqHeader) + } + for host, code := range tr.withCode { if ok, _ := regexp.Match(host, []byte(req.URL.String())); ok { return &http.Response{ @@ -332,3 +494,44 @@ func (r *retryRoundTripper) RoundTrip(req *http.Request) (res *http.Response, er } return } + +type hostFactory func(tr http.RoundTripper) docker.RegistryHost + +func hostSimple(host string) hostFactory { + return func(tr http.RoundTripper) docker.RegistryHost { + return docker.RegistryHost{ + Client: &http.Client{Transport: tr}, + Host: host, + Scheme: "https", + Path: "/v2", + Capabilities: docker.HostCapabilityPull, + } + } +} + +func hostWithHeaders(host string, headers http.Header) hostFactory { + return func(tr http.RoundTripper) docker.RegistryHost { + return docker.RegistryHost{ + Client: &http.Client{Transport: tr}, + Host: host, + Scheme: "https", + Path: "/v2", + Capabilities: docker.HostCapabilityPull, + Header: headers, + } + } +} + +func hostsConfig(tr *sampleRoundTripper, mirrors ...hostFactory) func(t *testing.T) source.RegistryHosts { + return func(t *testing.T) source.RegistryHosts { + tr.t = t + return func(refspec reference.Spec) (reghosts []docker.RegistryHost, _ error) { + host := refspec.Hostname() + for _, m := range mirrors { + reghosts = append(reghosts, m(tr)) + } + reghosts = append(reghosts, hostSimple(host)(tr)) + return + } + } +} diff --git a/service/resolver/registry.go b/service/resolver/registry.go index 544ec34d5..ec3d7ad8d 100644 --- a/service/resolver/registry.go +++ b/service/resolver/registry.go @@ -17,6 +17,8 @@ package resolver import ( + "fmt" + "net/http" "time" "github.com/containerd/containerd/reference" @@ -48,6 +50,9 @@ type MirrorConfig struct { // RequestTimeoutSec == 0 indicates the default timeout (defaultRequestTimeoutSec). // RequestTimeoutSec < 0 indicates no timeout. RequestTimeoutSec int `toml:"request_timeout_sec"` + + // Header are additional headers to send to the server + Header map[string]interface{} `toml:"header"` } type Credential func(string, reference.Spec) (string, string, error) @@ -69,6 +74,24 @@ func RegistryHostsFromConfig(cfg Config, credsFuncs ...Credential) source.Regist tr.Timeout = time.Duration(h.RequestTimeoutSec) * time.Second } } // h.RequestTimeoutSec < 0 means "no timeout" + var header http.Header + var err error + if h.Header != nil { + header = http.Header{} + for key, ty := range h.Header { + switch value := ty.(type) { + case string: + header[key] = []string{value} + case []interface{}: + header[key], err = makeStringSlice(value, nil) + if err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("invalid type %v for header %q", ty, key) + } + } + } config := docker.RegistryHost{ Client: tr, Host: h.Host, @@ -78,6 +101,7 @@ func RegistryHostsFromConfig(cfg Config, credsFuncs ...Credential) source.Regist Authorizer: docker.NewDockerAuthorizer( docker.WithAuthClient(tr), docker.WithAuthCreds(multiCredsFuncs(ref, credsFuncs...))), + Header: header, } if localhost, _ := docker.MatchLocalhost(config.Host); localhost || h.Insecure { config.Scheme = "http" @@ -103,3 +127,23 @@ func multiCredsFuncs(ref reference.Spec, credsFuncs ...Credential) func(string) return "", "", nil } } + +// makeStringSlice is a helper func to convert from []interface{} to []string. +// Additionally an optional cb func may be passed to perform string mapping. +// NOTE: Ported from https://github.com/containerd/containerd/blob/v1.6.9/remotes/docker/config/hosts.go#L516-L533 +func makeStringSlice(slice []interface{}, cb func(string) string) ([]string, error) { + out := make([]string, len(slice)) + for i, value := range slice { + str, ok := value.(string) + if !ok { + return nil, fmt.Errorf("unable to cast %v to string", value) + } + + if cb != nil { + out[i] = cb(str) + } else { + out[i] = str + } + } + return out, nil +}