diff --git a/cmd/oras/internal/option/remote.go b/cmd/oras/internal/option/remote.go index 523c087f3..118d75559 100644 --- a/cmd/oras/internal/option/remote.go +++ b/cmd/oras/internal/option/remote.go @@ -23,6 +23,7 @@ import ( "net" "net/http" "os" + "strconv" "strings" "time" @@ -31,6 +32,7 @@ import ( "oras.land/oras-go/v2/registry/remote/auth" "oras.land/oras/internal/credential" "oras.land/oras/internal/crypto" + onet "oras.land/oras/internal/net" "oras.land/oras/internal/trace" "oras.land/oras/internal/version" ) @@ -44,6 +46,9 @@ type Remote struct { Username string PasswordFromStdin bool Password string + + resolveFlag []string + resolveDialContext func(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) } // ApplyFlags applies flags to a command flag set. @@ -76,6 +81,10 @@ func (opts *Remote) ApplyFlagsWithPrefix(fs *pflag.FlagSet, prefix, description if fs.Lookup("registry-config") == nil { fs.StringArrayVarP(&opts.Configs, "registry-config", "", nil, "`path` of the authentication file") } + + if fs.Lookup("resolve") == nil { + fs.StringArrayVarP(&opts.resolveFlag, "resolve", "", nil, "customized DNS formatted in `host:port:address`") + } } // ReadPassword tries to read password with optional cmd prompt. @@ -94,6 +103,41 @@ func (opts *Remote) ReadPassword() (err error) { return nil } +// parseResolve parses resolve flag. +func (opts *Remote) parseResolve() error { + if len(opts.resolveFlag) == 0 { + return nil + } + + formatError := func(param, message string) error { + return fmt.Errorf("failed to parse resolve flag %q: %s", param, message) + } + var dialer onet.Dialer + for _, r := range opts.resolveFlag { + parts := strings.SplitN(r, ":", 3) + if len(parts) < 3 { + return formatError(r, "expecting host:port:address") + } + + port, err := strconv.Atoi(parts[1]) + if err != nil { + return formatError(r, "expecting uint64 port") + } + + // ipv6 zone is not parsed + to := net.ParseIP(parts[2]) + if to == nil { + return formatError(r, "invalid IP address") + } + dialer.Add(parts[0], port, to) + } + opts.resolveDialContext = func(base *net.Dialer) func(context.Context, string, string) (net.Conn, error) { + dialer.Dialer = base + return dialer.DialContext + } + return nil +} + // tlsConfig assembles the tls config. func (opts *Remote) tlsConfig() (*tls.Config, error) { config := &tls.Config{ @@ -115,15 +159,24 @@ func (opts *Remote) authClient(registry string, debug bool) (client *auth.Client if err != nil { return nil, err } + if err := opts.parseResolve(); err != nil { + return nil, err + } + resolveDialContext := opts.resolveDialContext + if resolveDialContext == nil { + resolveDialContext = func(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) { + return dialer.DialContext + } + } client = &auth.Client{ Client: &http.Client{ // default value are derived from http.DefaultTransport Transport: &http.Transport{ Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ + DialContext: resolveDialContext(&net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, - }).DialContext, + }), ForceAttemptHTTP2: true, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, diff --git a/cmd/oras/internal/option/remote_test.go b/cmd/oras/internal/option/remote_test.go index ebac9fdef..b57ea7422 100644 --- a/cmd/oras/internal/option/remote_test.go +++ b/cmd/oras/internal/option/remote_test.go @@ -23,15 +23,14 @@ import ( "encoding/json" "encoding/pem" "fmt" + nhttp "net/http" + "net/http/httptest" + "net/url" "os" "path/filepath" "reflect" "testing" - nhttp "net/http" - "net/http/httptest" - "net/url" - "github.com/spf13/pflag" "oras.land/oras-go/v2/registry/remote/auth" ) @@ -139,6 +138,31 @@ func TestRemote_authClient_CARoots(t *testing.T) { } } +func TestRemote_authClient_resolve(t *testing.T) { + URL, err := url.Parse(ts.URL) + if err != nil { + t.Fatalf("invalid url in test server: %s", ts.URL) + } + + testHost := "test.unit.oras" + opts := Remote{ + resolveFlag: []string{fmt.Sprintf("%s:%s:%s", testHost, URL.Port(), URL.Hostname())}, + Insecure: true, + } + client, err := opts.authClient(testHost, false) + if err != nil { + t.Fatalf("unexpected error when creating auth client: %v", err) + } + req, err := nhttp.NewRequestWithContext(context.Background(), nhttp.MethodGet, fmt.Sprintf("https://%s:%s", testHost, URL.Port()), nil) + if err != nil { + t.Fatalf("unexpected error when generating request: %v", err) + } + _, err = client.Do(req) + if err != nil { + t.Fatalf("unexpected error when sending request: %v", err) + } +} + func TestRemote_NewRegistry(t *testing.T) { caPath := filepath.Join(t.TempDir(), "oras-test.pem") if err := os.WriteFile(caPath, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ts.Certificate().Raw}), 0644); err != nil { @@ -220,3 +244,44 @@ func TestRemote_isPlainHttp_localhost(t *testing.T) { } } + +func TestRemote_parseResolve_err(t *testing.T) { + tests := []struct { + name string + opts *Remote + wantErr bool + }{ + { + name: "invalid flag", + opts: &Remote{resolveFlag: []string{"this-shouldn't_work"}}, + wantErr: true, + }, + { + name: "no host", + opts: &Remote{resolveFlag: []string{":port:address"}}, + wantErr: true, + }, + { + name: "no address", + opts: &Remote{resolveFlag: []string{"host:port:"}}, + wantErr: true, + }, + { + name: "invalid address", + opts: &Remote{resolveFlag: []string{"host:port:invalid-ip"}}, + wantErr: true, + }, + { + name: "no port", + opts: &Remote{resolveFlag: []string{"host::address"}}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.opts.parseResolve(); (err != nil) != tt.wantErr { + t.Errorf("Remote.parseResolve() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/internal/net/net.go b/internal/net/net.go new file mode 100644 index 000000000..e2d571c8c --- /dev/null +++ b/internal/net/net.go @@ -0,0 +1,45 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package net + +import ( + "context" + "fmt" + "net" +) + +// Dialer struct provides dialing function with predefined DNS resolves. +type Dialer struct { + *net.Dialer + resolve map[string]string +} + +// Add adds an entry for DNS resolve. +func (d *Dialer) Add(from string, port int, to net.IP) { + if d.resolve == nil { + d.resolve = make(map[string]string) + } + d.resolve[fmt.Sprintf("%s:%d", from, port)] = fmt.Sprintf("%s:%d", to, port) +} + +// DialContext connects to the addr on the named network using the provided +// context. +func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + if resolve, ok := d.resolve[addr]; ok { + addr = resolve + } + return d.Dialer.DialContext(ctx, network, addr) +} diff --git a/internal/net/net_test.go b/internal/net/net_test.go new file mode 100644 index 000000000..ff24da87b --- /dev/null +++ b/internal/net/net_test.go @@ -0,0 +1,69 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package net + +import ( + "context" + "net" + "reflect" + "testing" +) + +func TestDialer_DialContext(t *testing.T) { + type args struct { + ctx context.Context + network string + addr string + } + tests := []struct { + name string + d *Dialer + args args + want net.Conn + wantErr bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.d.DialContext(tt.args.ctx, tt.args.network, tt.args.addr) + if (err != nil) != tt.wantErr { + t.Errorf("Dialer.DialContext() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Dialer.DialContext() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestRemote_parseResolve_ipv4(t *testing.T) { + host := "mockedHost" + port := "12345" + address := "192.168.1.1" + var d Dialer + d.Add(host, 12345, net.ParseIP(address)) + + if len(d.resolve) != 1 { + t.Fatalf("expect 1 resolve entries but got %v", d.resolve) + } + want := make(map[string]string) + want[host+":"+port] = address + ":" + port + if !reflect.DeepEqual(want, d.resolve) { + t.Fatalf("expecting %v but got %v", want, d.resolve) + } +}