diff --git a/lib/options.go b/lib/options.go index a0d1be95ff0..f138af4a2bf 100644 --- a/lib/options.go +++ b/lib/options.go @@ -27,11 +27,11 @@ import ( "net" "reflect" "strconv" - - "gopkg.in/guregu/null.v3" + "strings" "go.k6.io/k6/lib/types" "go.k6.io/k6/stats" + "gopkg.in/guregu/null.v3" ) // DefaultScenarioName is used as the default key/ID of the scenario config entries @@ -178,10 +178,20 @@ func (ipnet *IPNet) UnmarshalText(b []byte) error { } *ipnet = *newIPNet - return nil } +// MarshalJSON returns the JSON representation of IPNet using CIDR notation. +func (ipnet *IPNet) MarshalJSON() ([]byte, error) { + return []byte(fmt.Sprintf("%q", ipnet.String())), nil +} + +// UnmarshalJSON decodes a JSON representation of IPNet using CIDR notation. +func (ipnet *IPNet) UnmarshalJSON(b []byte) error { + // Remove quotes and fallback to unmarshal text + return ipnet.UnmarshalText([]byte(strings.Trim(string(b), `"`))) +} + // HostAddress stores information about IP and port // for a host. type HostAddress net.TCPAddr @@ -572,6 +582,11 @@ func (o Options) Validate() []error { o.ExecutionSegment, o.ExecutionSegmentSequence)) } } + for _, ip := range o.BlacklistIPs { + if ip.String() == "" { + errors = append(errors, fmt.Errorf("blacklisted IP %q has an invalid CIDR", ip.IP.String())) + } + } return append(errors, o.Scenarios.Validate()...) } diff --git a/lib/options_test.go b/lib/options_test.go index 5b74862d29a..3be105faa8c 100644 --- a/lib/options_test.go +++ b/lib/options_test.go @@ -307,15 +307,28 @@ func TestOptions(t *testing.T) { opts := Options{}.Apply(Options{ BlacklistIPs: []*IPNet{{ IPNet: net.IPNet{ - IP: net.IPv4zero, - Mask: net.CIDRMask(1, 1), + IP: net.IPv4bcast, + Mask: net.CIDRMask(31, 32), }, }}, }) assert.NotNil(t, opts.BlacklistIPs) assert.NotEmpty(t, opts.BlacklistIPs) - assert.Equal(t, net.IPv4zero, opts.BlacklistIPs[0].IP) - assert.Equal(t, net.CIDRMask(1, 1), opts.BlacklistIPs[0].Mask) + assert.Equal(t, net.IPv4bcast, opts.BlacklistIPs[0].IP) + assert.Equal(t, net.CIDRMask(31, 32), opts.BlacklistIPs[0].Mask) + + t.Run("JSON", func(t *testing.T) { + t.Parallel() + + b, err := json.Marshal(opts) + require.NoError(t, err) + + var uopts Options + err = json.Unmarshal(b, &uopts) + require.NoError(t, err) + require.Len(t, uopts.BlacklistIPs, 1) + require.Equal(t, "255.255.255.254/31", uopts.BlacklistIPs[0].String()) + }) }) t.Run("BlockedHostnames", func(t *testing.T) { blockedHostnames, err := types.NewNullHostnameTrie([]string{"test.k6.io", "*valid.pattern"}) @@ -519,6 +532,23 @@ func TestOptionsEnv(t *testing.T) { } } +func TestOptionsValidate(t *testing.T) { + t.Parallel() + t.Run("BlacklistIPs", func(t *testing.T) { + t.Parallel() + + opts := Options{}.Apply(Options{ + BlacklistIPs: []*IPNet{ + {IPNet: net.IPNet{IP: []byte{192, 0, 2, 1}, Mask: []byte{255, 255, 255, 0}}}, + {IPNet: net.IPNet{IP: net.IPv4bcast, Mask: net.CIDRMask(1, 1)}}, + {IPNet: net.IPNet{IP: []byte(""), Mask: net.CIDRMask(64, 128)}}, + }, + }) + errs := opts.Validate() + require.Len(t, errs, 2) + }) +} + func TestCIDRUnmarshal(t *testing.T) { testData := []struct { input string