diff --git a/cmd/oras/internal/option/remote.go b/cmd/oras/internal/option/remote.go index b2e87ac7a..cf918a430 100644 --- a/cmd/oras/internal/option/remote.go +++ b/cmd/oras/internal/option/remote.go @@ -30,6 +30,7 @@ import ( "github.com/spf13/pflag" "oras.land/oras-go/v2/registry/remote" "oras.land/oras-go/v2/registry/remote/auth" + "oras.land/oras-go/v2/registry/remote/retry" "oras.land/oras/internal/credential" "oras.land/oras/internal/crypto" onet "oras.land/oras/internal/net" @@ -205,22 +206,25 @@ func (opts *Remote) authClient(registry string, debug bool) (client *auth.Client return dialer.DialContext } } + // default value are derived from http.DefaultTransport + baseTransport := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: resolveDialContext(&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }), + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + TLSClientConfig: config, + } client = &auth.Client{ Client: &http.Client{ - // default value are derived from http.DefaultTransport - Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: resolveDialContext(&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }), - ForceAttemptHTTP2: true, - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - TLSClientConfig: config, - }, + // http.RoundTripper with a retry using the DefaultPolicy + // see: https://pkg.go.dev/oras.land/oras-go/v2/registry/remote/retry#Policy + Transport: retry.NewTransport(baseTransport), }, Cache: auth.NewCache(), Header: opts.headers, diff --git a/cmd/oras/internal/option/remote_test.go b/cmd/oras/internal/option/remote_test.go index 7428a4844..099c70c04 100644 --- a/cmd/oras/internal/option/remote_test.go +++ b/cmd/oras/internal/option/remote_test.go @@ -23,6 +23,7 @@ import ( "encoding/json" "encoding/pem" "fmt" + "net/http" nhttp "net/http" "net/http/httptest" "net/url" @@ -415,3 +416,53 @@ func TestRemote_parseCustomHeaders(t *testing.T) { }) } } + +func TestRemote_NewRepository_Retry(t *testing.T) { + caPath := filepath.Join(t.TempDir(), "oras-test.pem") + if err := os.WriteFile(caPath, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ts.Certificate().Raw}), 0644); err != nil { + t.Fatalf("unexpected error: %v", err) + } + retries, count := 3, 0 + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count++ + if count < retries { + http.Error(w, "error", http.StatusTooManyRequests) + return + } + json.NewEncoder(w).Encode(testTagList) + })) + defer ts.Close() + + opts := struct { + Remote + Common + }{ + Remote{ + CACertFilePath: caPath, + }, + Common{}, + } + + uri, err := url.ParseRequestURI(ts.URL) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + repo, err := opts.NewRepository(uri.Host+"/"+testRepo, opts.Common) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if err = repo.Tags(context.Background(), "", func(got []string) error { + want := []string{"tag"} + if len(got) != len(testTagList.Tags) || !reflect.DeepEqual(got, want) { + return fmt.Errorf("expect: %v, got: %v", testTagList.Tags, got) + } + return nil + }); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if count != retries { + t.Errorf("expected %d retries, got %d", retries, count) + } +}