diff --git a/cmd/options.go b/cmd/options.go index 40164b8efb8..91db3805f8a 100644 --- a/cmd/options.go +++ b/cmd/options.go @@ -22,7 +22,6 @@ package cmd import ( "fmt" - "net" "strings" "time" @@ -122,9 +121,9 @@ func getOptions(flags *pflag.FlagSet) (lib.Options, error) { return opts, err } for _, s := range blacklistIPStrings { - _, net, err := net.ParseCIDR(s) - if err != nil { - return opts, errors.Wrap(err, "blacklist-ip") + net, parseErr := lib.ParseCIDR(s) + if parseErr != nil { + return opts, errors.Wrap(parseErr, "blacklist-ip") } opts.BlacklistIPs = append(opts.BlacklistIPs, net) } diff --git a/js/runner_test.go b/js/runner_test.go index bd15201293d..a19bab7f325 100644 --- a/js/runner_test.go +++ b/js/runner_test.go @@ -747,7 +747,7 @@ func TestVUIntegrationInsecureRequests(t *testing.T) { } } -func TestVUIntegrationBlacklist(t *testing.T) { +func TestVUIntegrationBlacklistOption(t *testing.T) { r1, err := New(&lib.SourceData{ Filename: "/script.js", Data: []byte(` @@ -759,13 +759,13 @@ func TestVUIntegrationBlacklist(t *testing.T) { return } - _, cidr, err := net.ParseCIDR("10.0.0.0/8") + cidr, err := lib.ParseCIDR("10.0.0.0/8") if !assert.NoError(t, err) { return } r1.SetOptions(lib.Options{ Throw: null.BoolFrom(true), - BlacklistIPs: []*net.IPNet{cidr}, + BlacklistIPs: []*lib.IPNet{cidr}, }) r2, err := NewFromArchive(r1.MakeArchive(), lib.RuntimeOptions{}) @@ -786,6 +786,44 @@ func TestVUIntegrationBlacklist(t *testing.T) { } } +func TestVUIntegrationBlacklistScript(t *testing.T) { + r1, err := New(&lib.SourceData{ + Filename: "/script.js", + Data: []byte(` + import http from "k6/http"; + + export let options = { + throw: true, + blacklistIPs: ["10.0.0.0/8"], + }; + + export default function() { http.get("http://10.1.2.3/"); } + `), + }, afero.NewMemMapFs(), lib.RuntimeOptions{}) + if !assert.NoError(t, err) { + return + } + + r2, err := NewFromArchive(r1.MakeArchive(), lib.RuntimeOptions{}) + if !assert.NoError(t, err) { + return + } + + runners := map[string]*Runner{"Source": r1, "Archive": r2} + + for name, r := range runners { + r := r + t.Run(name, func(t *testing.T) { + vu, err := r.NewVU(make(chan stats.SampleContainer, 100)) + if !assert.NoError(t, err) { + return + } + err = vu.RunOnce(context.Background()) + assert.EqualError(t, err, "GoError: Get http://10.1.2.3/: IP (10.1.2.3) is in a blacklisted range (10.0.0.0/8)") + }) + } +} + func TestVUIntegrationHosts(t *testing.T) { tb := testutils.NewHTTPMultiBin(t) defer tb.Cleanup() diff --git a/lib/netext/dialer.go b/lib/netext/dialer.go index 7026f859961..f95d91c099b 100644 --- a/lib/netext/dialer.go +++ b/lib/netext/dialer.go @@ -28,6 +28,7 @@ import ( "sync/atomic" "time" + "github.com/loadimpact/k6/lib" "github.com/loadimpact/k6/lib/metrics" "github.com/loadimpact/k6/stats" @@ -40,7 +41,7 @@ type Dialer struct { net.Dialer Resolver *dnscache.Resolver - Blacklist []*net.IPNet + Blacklist []*lib.IPNet Hosts map[string]net.IP BytesRead int64 @@ -58,7 +59,7 @@ func NewDialer(dialer net.Dialer) *Dialer { // BlackListedIPError is an error that is returned when a given IP is blacklisted type BlackListedIPError struct { ip net.IP - net *net.IPNet + net *lib.IPNet } func (b BlackListedIPError) Error() string { @@ -80,9 +81,9 @@ func (d *Dialer) DialContext(ctx context.Context, proto, addr string) (net.Conn, } } - for _, net := range d.Blacklist { - if net.Contains(ip) { - return nil, BlackListedIPError{ip: ip, net: net} + for _, ipnet := range d.Blacklist { + if (*net.IPNet)(ipnet).Contains(ip) { + return nil, BlackListedIPError{ip: ip, net: ipnet} } } ipStr := ip.String() diff --git a/lib/options.go b/lib/options.go index f4a4657e401..4c743e2d4b9 100644 --- a/lib/options.go +++ b/lib/options.go @@ -205,6 +205,37 @@ func (c *TLSAuth) Certificate() (*tls.Certificate, error) { return c.certificate, nil } +// IPNet is a wrapper around net.IPNet for JSON unmarshalling +type IPNet net.IPNet + +func (ipnet *IPNet) String() string { + return (*net.IPNet)(ipnet).String() +} + +// UnmarshalText populates the IPNet from the given CIDR +func (ipnet *IPNet) UnmarshalText(b []byte) error { + newIPNet, err := ParseCIDR(string(b)) + if err != nil { + return errors.Wrap(err, "Failed to parse CIDR") + } + + *ipnet = *newIPNet + + return nil +} + +// ParseCIDR creates an IPNet out of a CIDR string +func ParseCIDR(s string) (*IPNet, error) { + _, ipnet, err := net.ParseCIDR(s) + if err != nil { + return nil, err + } + + parsedIPNet := IPNet(*ipnet) + + return &parsedIPNet, nil +} + type Options struct { // Should the test start in a paused state? Paused null.Bool `json:"paused" envconfig:"paused"` @@ -258,7 +289,7 @@ type Options struct { Thresholds map[string]stats.Thresholds `json:"thresholds" envconfig:"thresholds"` // Blacklist IP ranges that tests may not contact. Mainly useful in hosted setups. - BlacklistIPs []*net.IPNet `json:"blacklistIPs" envconfig:"blacklist_ips"` + BlacklistIPs []*IPNet `json:"blacklistIPs" envconfig:"blacklist_ips"` // Hosts overrides dns entries for given hosts Hosts map[string]net.IP `json:"hosts" envconfig:"hosts"` diff --git a/lib/options_test.go b/lib/options_test.go index 63cf4ae891e..0f601ad9ef4 100644 --- a/lib/options_test.go +++ b/lib/options_test.go @@ -308,7 +308,7 @@ func TestOptions(t *testing.T) { }) t.Run("BlacklistIPs", func(t *testing.T) { opts := Options{}.Apply(Options{ - BlacklistIPs: []*net.IPNet{{ + BlacklistIPs: []*IPNet{{ IP: net.IPv4zero, Mask: net.CIDRMask(1, 1), }}, @@ -513,3 +513,47 @@ func TestTagSetTextUnmarshal(t *testing.T) { require.Equal(t, (map[string]bool)(*set), expected) } } + +func TestCIDRUnmarshal(t *testing.T) { + + var testData = []struct { + input string + expectedOutput *IPNet + expactFailure bool + }{ + { + "10.0.0.0/8", + &IPNet{ + IP: net.IP{10, 0, 0, 0}, + Mask: net.IPv4Mask(255, 0, 0, 0), + }, + false, + }, + { + "fc00:1234:5678::/48", + &IPNet{ + IP: net.ParseIP("fc00:1234:5678::"), + Mask: net.CIDRMask(48, 128), + }, + false, + }, + {"10.0.0.0", nil, true}, + {"fc00:1234:5678::", nil, true}, + {"fc00::1234::/48", nil, true}, + } + + for _, data := range testData { + data := data + t.Run(data.input, func(t *testing.T) { + actualIPNet := &IPNet{} + err := actualIPNet.UnmarshalText([]byte(data.input)) + + if data.expactFailure { + require.EqualError(t, err, "Failed to parse CIDR: invalid CIDR address: "+data.input) + } else { + require.NoError(t, err) + assert.Equal(t, data.expectedOutput, actualIPNet) + } + }) + } +}