Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support customized DNS resolving for remote registry #696

Merged
merged 17 commits into from
Dec 23, 2022
Merged
57 changes: 55 additions & 2 deletions cmd/oras/internal/option/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"net"
"net/http"
"os"
"strconv"
"strings"
"time"

Expand All @@ -31,6 +32,7 @@ import (
"oras.land/oras-go/v2/registry/remote/auth"
"oras.land/oras/internal/credential"
"oras.land/oras/internal/crypto"
onet "oras.land/oras/internal/net"
"oras.land/oras/internal/trace"
"oras.land/oras/internal/version"
)
Expand All @@ -44,6 +46,9 @@ type Remote struct {
Username string
PasswordFromStdin bool
Password string

resolveFlag []string
resolveDialContext func(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error)
}

// ApplyFlags applies flags to a command flag set.
Expand Down Expand Up @@ -76,6 +81,10 @@ func (opts *Remote) ApplyFlagsWithPrefix(fs *pflag.FlagSet, prefix, description
if fs.Lookup("registry-config") == nil {
fs.StringArrayVarP(&opts.Configs, "registry-config", "", nil, "`path` of the authentication file")
}

if fs.Lookup("resolve") == nil {
fs.StringArrayVarP(&opts.resolveFlag, "resolve", "", nil, "customized DNS formatted in `host:port:address`")
}
}

// ReadPassword tries to read password with optional cmd prompt.
Expand All @@ -94,6 +103,41 @@ func (opts *Remote) ReadPassword() (err error) {
return nil
}

// parseResolve parses resolve flag.
func (opts *Remote) parseResolve() error {
if len(opts.resolveFlag) == 0 {
return nil
}

formatError := func(param, message string) error {
return fmt.Errorf("failed to parse resolve flag %q: %s", param, message)
}
var dialer onet.Dialer
for _, r := range opts.resolveFlag {
parts := strings.SplitN(r, ":", 3)
if len(parts) < 3 {
return formatError(r, "expecting host:port:address")
}

port, err := strconv.Atoi(parts[1])
if err != nil {
return formatError(r, "expecting uint64 port")
}

// ipv6 zone is not parsed
to := net.ParseIP(parts[2])
if to == nil {
return formatError(r, "invalid IP address")
}
dialer.Add(parts[0], port, to)
}
opts.resolveDialContext = func(base *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
dialer.Dialer = base
return dialer.DialContext
}
return nil
}

// tlsConfig assembles the tls config.
func (opts *Remote) tlsConfig() (*tls.Config, error) {
config := &tls.Config{
Expand All @@ -115,15 +159,24 @@ func (opts *Remote) authClient(registry string, debug bool) (client *auth.Client
if err != nil {
return nil, err
}
if err := opts.parseResolve(); err != nil {
return nil, err
}
resolveDialContext := opts.resolveDialContext
if resolveDialContext == nil {
resolveDialContext = func(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
return dialer.DialContext
}
}
client = &auth.Client{
Client: &http.Client{
// default value are derived from http.DefaultTransport
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
DialContext: resolveDialContext(&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
}),
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
Expand Down
73 changes: 69 additions & 4 deletions cmd/oras/internal/option/remote_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,14 @@ import (
"encoding/json"
"encoding/pem"
"fmt"
nhttp "net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"reflect"
"testing"

nhttp "net/http"
"net/http/httptest"
"net/url"

"github.com/spf13/pflag"
"oras.land/oras-go/v2/registry/remote/auth"
)
Expand Down Expand Up @@ -139,6 +138,31 @@ func TestRemote_authClient_CARoots(t *testing.T) {
}
}

func TestRemote_authClient_resolve(t *testing.T) {
URL, err := url.Parse(ts.URL)
if err != nil {
t.Fatalf("invalid url in test server: %s", ts.URL)
}

testHost := "test.unit.oras"
opts := Remote{
resolveFlag: []string{fmt.Sprintf("%s:%s:%s", testHost, URL.Port(), URL.Hostname())},
Insecure: true,
}
client, err := opts.authClient(testHost, false)
if err != nil {
t.Fatalf("unexpected error when creating auth client: %v", err)
}
req, err := nhttp.NewRequestWithContext(context.Background(), nhttp.MethodGet, fmt.Sprintf("https://%s:%s", testHost, URL.Port()), nil)
if err != nil {
t.Fatalf("unexpected error when generating request: %v", err)
}
_, err = client.Do(req)
if err != nil {
t.Fatalf("unexpected error when sending request: %v", err)
}
}

func TestRemote_NewRegistry(t *testing.T) {
caPath := filepath.Join(t.TempDir(), "oras-test.pem")
if err := os.WriteFile(caPath, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ts.Certificate().Raw}), 0644); err != nil {
Expand Down Expand Up @@ -220,3 +244,44 @@ func TestRemote_isPlainHttp_localhost(t *testing.T) {

}
}

func TestRemote_parseResolve_err(t *testing.T) {
tests := []struct {
name string
opts *Remote
wantErr bool
}{
{
name: "invalid flag",
opts: &Remote{resolveFlag: []string{"this-shouldn't_work"}},
wantErr: true,
},
{
name: "no host",
opts: &Remote{resolveFlag: []string{":port:address"}},
wantErr: true,
},
{
name: "no address",
opts: &Remote{resolveFlag: []string{"host:port:"}},
wantErr: true,
},
{
name: "invalid address",
opts: &Remote{resolveFlag: []string{"host:port:invalid-ip"}},
wantErr: true,
},
{
name: "no port",
opts: &Remote{resolveFlag: []string{"host::address"}},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.opts.parseResolve(); (err != nil) != tt.wantErr {
t.Errorf("Remote.parseResolve() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
45 changes: 45 additions & 0 deletions internal/net/net.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
Copyright The ORAS Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package net

import (
"context"
"fmt"
"net"
)

// Dialer struct provides dialing function with predefined DNS resolves.
type Dialer struct {
*net.Dialer
resolve map[string]string
}

// Add adds an entry for DNS resolve.
func (d *Dialer) Add(from string, port int, to net.IP) {
if d.resolve == nil {
d.resolve = make(map[string]string)
}
d.resolve[fmt.Sprintf("%s:%d", from, port)] = fmt.Sprintf("%s:%d", to, port)
}

// DialContext connects to the addr on the named network using
// the provided context.
func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
if resolve, ok := d.resolve[addr]; ok {
addr = resolve
}
return d.Dialer.DialContext(ctx, network, addr)
}
69 changes: 69 additions & 0 deletions internal/net/net_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
Copyright The ORAS Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package net

import (
"context"
"net"
"reflect"
"testing"
)

func TestDialer_DialContext(t *testing.T) {
type args struct {
ctx context.Context
network string
addr string
}
tests := []struct {
name string
d *Dialer
args args
want net.Conn
wantErr bool
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.d.DialContext(tt.args.ctx, tt.args.network, tt.args.addr)
if (err != nil) != tt.wantErr {
t.Errorf("Dialer.DialContext() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Dialer.DialContext() = %v, want %v", got, tt.want)
}
})
}
}

func TestRemote_parseResolve_ipv4(t *testing.T) {
host := "mockedHost"
port := "12345"
address := "192.168.1.1"
var d Dialer
d.Add(host, 12345, net.ParseIP(address))

if len(d.resolve) != 1 {
t.Fatalf("expect 1 resolve entries but got %v", d.resolve)
}
want := make(map[string]string)
want[host+":"+port] = address + ":" + port
if !reflect.DeepEqual(want, d.resolve) {
t.Fatalf("expecting %v but got %v", want, d.resolve)
}
}