Skip to content

Commit

Permalink
feat: support customized DNS resolving for remote registry (#696)
Browse files Browse the repository at this point in the history
Resolves #688

Signed-off-by: Billy Zha <[email protected]>
  • Loading branch information
qweeah authored Dec 23, 2022
1 parent 42f62e0 commit 3ad6eee
Show file tree
Hide file tree
Showing 4 changed files with 238 additions and 6 deletions.
57 changes: 55 additions & 2 deletions cmd/oras/internal/option/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"net"
"net/http"
"os"
"strconv"
"strings"
"time"

Expand All @@ -31,6 +32,7 @@ import (
"oras.land/oras-go/v2/registry/remote/auth"
"oras.land/oras/internal/credential"
"oras.land/oras/internal/crypto"
onet "oras.land/oras/internal/net"
"oras.land/oras/internal/trace"
"oras.land/oras/internal/version"
)
Expand All @@ -44,6 +46,9 @@ type Remote struct {
Username string
PasswordFromStdin bool
Password string

resolveFlag []string
resolveDialContext func(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error)
}

// ApplyFlags applies flags to a command flag set.
Expand Down Expand Up @@ -76,6 +81,10 @@ func (opts *Remote) ApplyFlagsWithPrefix(fs *pflag.FlagSet, prefix, description
if fs.Lookup("registry-config") == nil {
fs.StringArrayVarP(&opts.Configs, "registry-config", "", nil, "`path` of the authentication file")
}

if fs.Lookup("resolve") == nil {
fs.StringArrayVarP(&opts.resolveFlag, "resolve", "", nil, "customized DNS formatted in `host:port:address`")
}
}

// ReadPassword tries to read password with optional cmd prompt.
Expand All @@ -94,6 +103,41 @@ func (opts *Remote) ReadPassword() (err error) {
return nil
}

// parseResolve parses resolve flag.
func (opts *Remote) parseResolve() error {
if len(opts.resolveFlag) == 0 {
return nil
}

formatError := func(param, message string) error {
return fmt.Errorf("failed to parse resolve flag %q: %s", param, message)
}
var dialer onet.Dialer
for _, r := range opts.resolveFlag {
parts := strings.SplitN(r, ":", 3)
if len(parts) < 3 {
return formatError(r, "expecting host:port:address")
}

port, err := strconv.Atoi(parts[1])
if err != nil {
return formatError(r, "expecting uint64 port")
}

// ipv6 zone is not parsed
to := net.ParseIP(parts[2])
if to == nil {
return formatError(r, "invalid IP address")
}
dialer.Add(parts[0], port, to)
}
opts.resolveDialContext = func(base *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
dialer.Dialer = base
return dialer.DialContext
}
return nil
}

// tlsConfig assembles the tls config.
func (opts *Remote) tlsConfig() (*tls.Config, error) {
config := &tls.Config{
Expand All @@ -115,15 +159,24 @@ func (opts *Remote) authClient(registry string, debug bool) (client *auth.Client
if err != nil {
return nil, err
}
if err := opts.parseResolve(); err != nil {
return nil, err
}
resolveDialContext := opts.resolveDialContext
if resolveDialContext == nil {
resolveDialContext = func(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
return dialer.DialContext
}
}
client = &auth.Client{
Client: &http.Client{
// default value are derived from http.DefaultTransport
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
DialContext: resolveDialContext(&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
}),
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
Expand Down
73 changes: 69 additions & 4 deletions cmd/oras/internal/option/remote_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,14 @@ import (
"encoding/json"
"encoding/pem"
"fmt"
nhttp "net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"reflect"
"testing"

nhttp "net/http"
"net/http/httptest"
"net/url"

"github.com/spf13/pflag"
"oras.land/oras-go/v2/registry/remote/auth"
)
Expand Down Expand Up @@ -139,6 +138,31 @@ func TestRemote_authClient_CARoots(t *testing.T) {
}
}

func TestRemote_authClient_resolve(t *testing.T) {
URL, err := url.Parse(ts.URL)
if err != nil {
t.Fatalf("invalid url in test server: %s", ts.URL)
}

testHost := "test.unit.oras"
opts := Remote{
resolveFlag: []string{fmt.Sprintf("%s:%s:%s", testHost, URL.Port(), URL.Hostname())},
Insecure: true,
}
client, err := opts.authClient(testHost, false)
if err != nil {
t.Fatalf("unexpected error when creating auth client: %v", err)
}
req, err := nhttp.NewRequestWithContext(context.Background(), nhttp.MethodGet, fmt.Sprintf("https://%s:%s", testHost, URL.Port()), nil)
if err != nil {
t.Fatalf("unexpected error when generating request: %v", err)
}
_, err = client.Do(req)
if err != nil {
t.Fatalf("unexpected error when sending request: %v", err)
}
}

func TestRemote_NewRegistry(t *testing.T) {
caPath := filepath.Join(t.TempDir(), "oras-test.pem")
if err := os.WriteFile(caPath, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ts.Certificate().Raw}), 0644); err != nil {
Expand Down Expand Up @@ -220,3 +244,44 @@ func TestRemote_isPlainHttp_localhost(t *testing.T) {

}
}

func TestRemote_parseResolve_err(t *testing.T) {
tests := []struct {
name string
opts *Remote
wantErr bool
}{
{
name: "invalid flag",
opts: &Remote{resolveFlag: []string{"this-shouldn't_work"}},
wantErr: true,
},
{
name: "no host",
opts: &Remote{resolveFlag: []string{":port:address"}},
wantErr: true,
},
{
name: "no address",
opts: &Remote{resolveFlag: []string{"host:port:"}},
wantErr: true,
},
{
name: "invalid address",
opts: &Remote{resolveFlag: []string{"host:port:invalid-ip"}},
wantErr: true,
},
{
name: "no port",
opts: &Remote{resolveFlag: []string{"host::address"}},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.opts.parseResolve(); (err != nil) != tt.wantErr {
t.Errorf("Remote.parseResolve() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
45 changes: 45 additions & 0 deletions internal/net/net.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
Copyright The ORAS Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package net

import (
"context"
"fmt"
"net"
)

// Dialer struct provides dialing function with predefined DNS resolves.
type Dialer struct {
*net.Dialer
resolve map[string]string
}

// Add adds an entry for DNS resolve.
func (d *Dialer) Add(from string, port int, to net.IP) {
if d.resolve == nil {
d.resolve = make(map[string]string)
}
d.resolve[fmt.Sprintf("%s:%d", from, port)] = fmt.Sprintf("%s:%d", to, port)
}

// DialContext connects to the addr on the named network using the provided
// context.
func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
if resolve, ok := d.resolve[addr]; ok {
addr = resolve
}
return d.Dialer.DialContext(ctx, network, addr)
}
69 changes: 69 additions & 0 deletions internal/net/net_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
Copyright The ORAS Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package net

import (
"context"
"net"
"reflect"
"testing"
)

func TestDialer_DialContext(t *testing.T) {
type args struct {
ctx context.Context
network string
addr string
}
tests := []struct {
name string
d *Dialer
args args
want net.Conn
wantErr bool
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.d.DialContext(tt.args.ctx, tt.args.network, tt.args.addr)
if (err != nil) != tt.wantErr {
t.Errorf("Dialer.DialContext() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Dialer.DialContext() = %v, want %v", got, tt.want)
}
})
}
}

func TestRemote_parseResolve_ipv4(t *testing.T) {
host := "mockedHost"
port := "12345"
address := "192.168.1.1"
var d Dialer
d.Add(host, 12345, net.ParseIP(address))

if len(d.resolve) != 1 {
t.Fatalf("expect 1 resolve entries but got %v", d.resolve)
}
want := make(map[string]string)
want[host+":"+port] = address + ":" + port
if !reflect.DeepEqual(want, d.resolve) {
t.Fatalf("expecting %v but got %v", want, d.resolve)
}
}

0 comments on commit 3ad6eee

Please sign in to comment.