Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow user to specify destination port via resolve flag #801

Merged
merged 14 commits into from
Feb 22, 2023
29 changes: 18 additions & 11 deletions cmd/oras/internal/option/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func (opts *Remote) ApplyFlagsWithPrefix(fs *pflag.FlagSet, prefix, description
}

if fs.Lookup("resolve") == nil {
fs.StringArrayVarP(&opts.resolveFlag, "resolve", "", nil, "customized DNS formatted in `host:port:address`")
fs.StringArrayVarP(&opts.resolveFlag, "resolve", "", nil, "customized DNS formatted in `host:port:address[:address_port]`")
}
}

Expand Down Expand Up @@ -144,22 +144,29 @@ func (opts *Remote) parseResolve() error {
}
var dialer onet.Dialer
for _, r := range opts.resolveFlag {
parts := strings.SplitN(r, ":", 3)
if len(parts) < 3 {
return formatError(r, "expecting host:port:address")
parts := strings.SplitN(r, ":", 4)
length := len(parts)
if length < 3 {
return formatError(r, "expecting host:port:address[:address_port]")
}

port, err := strconv.Atoi(parts[1])
host := parts[0]
hostPort, err := strconv.Atoi(parts[1])
if err != nil {
return formatError(r, "expecting uint64 port")
return formatError(r, "expecting uint64 host port")
}

// ipv6 zone is not parsed
to := net.ParseIP(parts[2])
if to == nil {
address := net.ParseIP(parts[2])
if address == nil {
return formatError(r, "invalid IP address")
}
dialer.Add(parts[0], port, to)
addressPort := hostPort
if length > 3 {
addressPort, err = strconv.Atoi(parts[3])
if err != nil {
return formatError(r, "expecting uint64 address port")
}
}
dialer.Add(host, hostPort, address, addressPort)
}
opts.resolveDialContext = func(base *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
dialer.Dialer = base
Expand Down
69 changes: 49 additions & 20 deletions cmd/oras/internal/option/remote_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,40 +247,69 @@ func TestRemote_isPlainHttp_localhost(t *testing.T) {

func TestRemote_parseResolve_err(t *testing.T) {
tests := []struct {
name string
opts *Remote
wantErr bool
name string
opts *Remote
}{
{
name: "invalid flag",
opts: &Remote{resolveFlag: []string{"this-shouldn't_work"}},
wantErr: true,
name: "invalid flag",
opts: &Remote{resolveFlag: []string{"this-shouldn't_work"}},
},
{
name: "no host",
opts: &Remote{resolveFlag: []string{":port:address"}},
wantErr: true,
name: "no host",
opts: &Remote{resolveFlag: []string{":port:address"}},
},
{
name: "no address",
opts: &Remote{resolveFlag: []string{"host:port:"}},
wantErr: true,
name: "no address",
opts: &Remote{resolveFlag: []string{"host:port:"}},
},
{
name: "invalid address",
opts: &Remote{resolveFlag: []string{"host:port:invalid-ip"}},
wantErr: true,
name: "invalid address",
opts: &Remote{resolveFlag: []string{"host:port:invalid-ip"}},
},
{
name: "no port",
opts: &Remote{resolveFlag: []string{"host::address"}},
wantErr: true,
name: "no port",
opts: &Remote{resolveFlag: []string{"host::address"}},
},
{
name: "invalid source port",
opts: &Remote{resolveFlag: []string{"host:port:address"}},
},
{
name: "invalid destination port",
opts: &Remote{resolveFlag: []string{"host:443:address:port"}},
},
{
name: "no source port",
opts: &Remote{resolveFlag: []string{"host::address"}},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.opts.parseResolve(); err == nil {
t.Errorf("Expecting error in Remote.parseResolve()")
}
})
}
}

func TestRemote_parseResolve(t *testing.T) {
tests := []struct {
name string
opts *Remote
}{
{
name: "fromHost:fromPort:toIp",
opts: &Remote{resolveFlag: []string{"host:443:0.0.0.0"}},
},
{
name: "fromHost:fromPort:toIp:toPort",
opts: &Remote{resolveFlag: []string{"host:443:0.0.0.0:5000"}},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.opts.parseResolve(); (err != nil) != tt.wantErr {
t.Errorf("Remote.parseResolve() error = %v, wantErr %v", err, tt.wantErr)
if err := tt.opts.parseResolve(); err != nil {
t.Errorf("Remote.parseResolve() error = %v", err)
}
})
}
Expand Down
8 changes: 4 additions & 4 deletions internal/net/net.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,18 @@ type Dialer struct {
}

// Add adds an entry for DNS resolve.
func (d *Dialer) Add(from string, port int, to net.IP) {
func (d *Dialer) Add(from string, fromPort int, to net.IP, toPort int) {
if d.resolve == nil {
d.resolve = make(map[string]string)
}
d.resolve[fmt.Sprintf("%s:%d", from, port)] = fmt.Sprintf("%s:%d", to, port)
d.resolve[fmt.Sprintf("%s:%d", from, fromPort)] = fmt.Sprintf("%s:%d", to, toPort)
}

// DialContext connects to the addr on the named network using the provided
// context.
func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
if resolve, ok := d.resolve[addr]; ok {
addr = resolve
if resolved, ok := d.resolve[addr]; ok {
addr = resolved
}
return d.Dialer.DialContext(ctx, network, addr)
}
38 changes: 5 additions & 33 deletions internal/net/net_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,53 +16,25 @@ limitations under the License.
package net

import (
"context"
"fmt"
"net"
"reflect"
"testing"
)

func TestDialer_DialContext(t *testing.T) {
type args struct {
ctx context.Context
network string
addr string
}
tests := []struct {
name string
d *Dialer
args args
want net.Conn
wantErr bool
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.d.DialContext(tt.args.ctx, tt.args.network, tt.args.addr)
if (err != nil) != tt.wantErr {
t.Errorf("Dialer.DialContext() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Dialer.DialContext() = %v, want %v", got, tt.want)
}
})
}
}

func TestRemote_parseResolve_ipv4(t *testing.T) {
host := "mockedHost"
port := "12345"
hostPort := 443
address := "192.168.1.1"
addressPort := 12345
var d Dialer
d.Add(host, 12345, net.ParseIP(address))
d.Add(host, hostPort, net.ParseIP(address), addressPort)

if len(d.resolve) != 1 {
t.Fatalf("expect 1 resolve entries but got %v", d.resolve)
}
want := make(map[string]string)
want[host+":"+port] = address + ":" + port
want[host+":"+fmt.Sprint(hostPort)] = address + ":" + fmt.Sprint(addressPort)
if !reflect.DeepEqual(want, d.resolve) {
t.Fatalf("expecting %v but got %v", want, d.resolve)
}
Expand Down