Skip to content

Commit

Permalink
Fix blacklistIPs JS configuration (#1004)
Browse files Browse the repository at this point in the history
Net.IPNet does not support JSON unmarshaling, causing the blacklistIPs configuration to fail when set through JS. Fixed by wrapping defining a custom type for net.IPNet and implementing UnmarshalText for it.

Fixes #973
THoelzel authored and na-- committed Apr 25, 2019
1 parent 3fc300c commit 1621aad
Showing 5 changed files with 127 additions and 14 deletions.
7 changes: 3 additions & 4 deletions cmd/options.go
Original file line number Diff line number Diff line change
@@ -22,7 +22,6 @@ package cmd

import (
"fmt"
"net"
"strings"
"time"

@@ -122,9 +121,9 @@ func getOptions(flags *pflag.FlagSet) (lib.Options, error) {
return opts, err
}
for _, s := range blacklistIPStrings {
_, net, err := net.ParseCIDR(s)
if err != nil {
return opts, errors.Wrap(err, "blacklist-ip")
net, parseErr := lib.ParseCIDR(s)
if parseErr != nil {
return opts, errors.Wrap(parseErr, "blacklist-ip")
}
opts.BlacklistIPs = append(opts.BlacklistIPs, net)
}
44 changes: 41 additions & 3 deletions js/runner_test.go
Original file line number Diff line number Diff line change
@@ -747,7 +747,7 @@ func TestVUIntegrationInsecureRequests(t *testing.T) {
}
}

func TestVUIntegrationBlacklist(t *testing.T) {
func TestVUIntegrationBlacklistOption(t *testing.T) {
r1, err := New(&lib.SourceData{
Filename: "/script.js",
Data: []byte(`
@@ -759,13 +759,13 @@ func TestVUIntegrationBlacklist(t *testing.T) {
return
}

_, cidr, err := net.ParseCIDR("10.0.0.0/8")
cidr, err := lib.ParseCIDR("10.0.0.0/8")
if !assert.NoError(t, err) {
return
}
r1.SetOptions(lib.Options{
Throw: null.BoolFrom(true),
BlacklistIPs: []*net.IPNet{cidr},
BlacklistIPs: []*lib.IPNet{cidr},
})

r2, err := NewFromArchive(r1.MakeArchive(), lib.RuntimeOptions{})
@@ -786,6 +786,44 @@ func TestVUIntegrationBlacklist(t *testing.T) {
}
}

func TestVUIntegrationBlacklistScript(t *testing.T) {
r1, err := New(&lib.SourceData{
Filename: "/script.js",
Data: []byte(`
import http from "k6/http";
export let options = {
throw: true,
blacklistIPs: ["10.0.0.0/8"],
};
export default function() { http.get("http://10.1.2.3/"); }
`),
}, afero.NewMemMapFs(), lib.RuntimeOptions{})
if !assert.NoError(t, err) {
return
}

r2, err := NewFromArchive(r1.MakeArchive(), lib.RuntimeOptions{})
if !assert.NoError(t, err) {
return
}

runners := map[string]*Runner{"Source": r1, "Archive": r2}

for name, r := range runners {
r := r
t.Run(name, func(t *testing.T) {
vu, err := r.NewVU(make(chan stats.SampleContainer, 100))
if !assert.NoError(t, err) {
return
}
err = vu.RunOnce(context.Background())
assert.EqualError(t, err, "GoError: Get http://10.1.2.3/: IP (10.1.2.3) is in a blacklisted range (10.0.0.0/8)")
})
}
}

func TestVUIntegrationHosts(t *testing.T) {
tb := testutils.NewHTTPMultiBin(t)
defer tb.Cleanup()
11 changes: 6 additions & 5 deletions lib/netext/dialer.go
Original file line number Diff line number Diff line change
@@ -28,6 +28,7 @@ import (
"sync/atomic"
"time"

"github.com/loadimpact/k6/lib"
"github.com/loadimpact/k6/lib/metrics"
"github.com/loadimpact/k6/stats"

@@ -40,7 +41,7 @@ type Dialer struct {
net.Dialer

Resolver *dnscache.Resolver
Blacklist []*net.IPNet
Blacklist []*lib.IPNet
Hosts map[string]net.IP

BytesRead int64
@@ -58,7 +59,7 @@ func NewDialer(dialer net.Dialer) *Dialer {
// BlackListedIPError is an error that is returned when a given IP is blacklisted
type BlackListedIPError struct {
ip net.IP
net *net.IPNet
net *lib.IPNet
}

func (b BlackListedIPError) Error() string {
@@ -80,9 +81,9 @@ func (d *Dialer) DialContext(ctx context.Context, proto, addr string) (net.Conn,
}
}

for _, net := range d.Blacklist {
if net.Contains(ip) {
return nil, BlackListedIPError{ip: ip, net: net}
for _, ipnet := range d.Blacklist {
if (*net.IPNet)(ipnet).Contains(ip) {
return nil, BlackListedIPError{ip: ip, net: ipnet}
}
}
ipStr := ip.String()
33 changes: 32 additions & 1 deletion lib/options.go
Original file line number Diff line number Diff line change
@@ -205,6 +205,37 @@ func (c *TLSAuth) Certificate() (*tls.Certificate, error) {
return c.certificate, nil
}

// IPNet is a wrapper around net.IPNet for JSON unmarshalling
type IPNet net.IPNet

func (ipnet *IPNet) String() string {
return (*net.IPNet)(ipnet).String()
}

// UnmarshalText populates the IPNet from the given CIDR
func (ipnet *IPNet) UnmarshalText(b []byte) error {
newIPNet, err := ParseCIDR(string(b))
if err != nil {
return errors.Wrap(err, "Failed to parse CIDR")
}

*ipnet = *newIPNet

return nil
}

// ParseCIDR creates an IPNet out of a CIDR string
func ParseCIDR(s string) (*IPNet, error) {
_, ipnet, err := net.ParseCIDR(s)
if err != nil {
return nil, err
}

parsedIPNet := IPNet(*ipnet)

return &parsedIPNet, nil
}

type Options struct {
// Should the test start in a paused state?
Paused null.Bool `json:"paused" envconfig:"paused"`
@@ -258,7 +289,7 @@ type Options struct {
Thresholds map[string]stats.Thresholds `json:"thresholds" envconfig:"thresholds"`

// Blacklist IP ranges that tests may not contact. Mainly useful in hosted setups.
BlacklistIPs []*net.IPNet `json:"blacklistIPs" envconfig:"blacklist_ips"`
BlacklistIPs []*IPNet `json:"blacklistIPs" envconfig:"blacklist_ips"`

// Hosts overrides dns entries for given hosts
Hosts map[string]net.IP `json:"hosts" envconfig:"hosts"`
46 changes: 45 additions & 1 deletion lib/options_test.go
Original file line number Diff line number Diff line change
@@ -308,7 +308,7 @@ func TestOptions(t *testing.T) {
})
t.Run("BlacklistIPs", func(t *testing.T) {
opts := Options{}.Apply(Options{
BlacklistIPs: []*net.IPNet{{
BlacklistIPs: []*IPNet{{
IP: net.IPv4zero,
Mask: net.CIDRMask(1, 1),
}},
@@ -513,3 +513,47 @@ func TestTagSetTextUnmarshal(t *testing.T) {
require.Equal(t, (map[string]bool)(*set), expected)
}
}

func TestCIDRUnmarshal(t *testing.T) {

var testData = []struct {
input string
expectedOutput *IPNet
expactFailure bool
}{
{
"10.0.0.0/8",
&IPNet{
IP: net.IP{10, 0, 0, 0},
Mask: net.IPv4Mask(255, 0, 0, 0),
},
false,
},
{
"fc00:1234:5678::/48",
&IPNet{
IP: net.ParseIP("fc00:1234:5678::"),
Mask: net.CIDRMask(48, 128),
},
false,
},
{"10.0.0.0", nil, true},
{"fc00:1234:5678::", nil, true},
{"fc00::1234::/48", nil, true},
}

for _, data := range testData {
data := data
t.Run(data.input, func(t *testing.T) {
actualIPNet := &IPNet{}
err := actualIPNet.UnmarshalText([]byte(data.input))

if data.expactFailure {
require.EqualError(t, err, "Failed to parse CIDR: invalid CIDR address: "+data.input)
} else {
require.NoError(t, err)
assert.Equal(t, data.expectedOutput, actualIPNet)
}
})
}
}

0 comments on commit 1621aad

Please sign in to comment.