Skip to content

Commit

Permalink
lib/options: BlacklistIP un/marshaling
Browse files Browse the repository at this point in the history
The BlacklistIP configuration uses the CIDR notation.
Implemented json (Un)Marshaler interfaces to support correct
encoding and decoding based on the notation.
  • Loading branch information
codebien committed Jul 6, 2021
1 parent aa1fd6a commit 2cfa6ad
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 7 deletions.
21 changes: 18 additions & 3 deletions lib/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ import (
"net"
"reflect"
"strconv"

"gopkg.in/guregu/null.v3"
"strings"

"go.k6.io/k6/lib/types"
"go.k6.io/k6/stats"
"gopkg.in/guregu/null.v3"
)

// DefaultScenarioName is used as the default key/ID of the scenario config entries
Expand Down Expand Up @@ -178,10 +178,20 @@ func (ipnet *IPNet) UnmarshalText(b []byte) error {
}

*ipnet = *newIPNet

return nil
}

// MarshalJSON returns the JSON representation of IPNet using CIDR notation.
func (ipnet *IPNet) MarshalJSON() ([]byte, error) {
return []byte(fmt.Sprintf("%q", ipnet.String())), nil
}

// UnmarshalJSON decodes a JSON representation of IPNet using CIDR notation.
func (ipnet *IPNet) UnmarshalJSON(b []byte) error {
// Remove quotes and fallback to unmarshal text
return ipnet.UnmarshalText([]byte(strings.Trim(string(b), `"`)))
}

// HostAddress stores information about IP and port
// for a host.
type HostAddress net.TCPAddr
Expand Down Expand Up @@ -572,6 +582,11 @@ func (o Options) Validate() []error {
o.ExecutionSegment, o.ExecutionSegmentSequence))
}
}
for _, ip := range o.BlacklistIPs {
if ip.String() == "<nil>" {
errors = append(errors, fmt.Errorf("blacklisted IP %q has an invalid CIDR", ip.IP.String()))
}
}
return append(errors, o.Scenarios.Validate()...)
}

Expand Down
38 changes: 34 additions & 4 deletions lib/options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,15 +307,28 @@ func TestOptions(t *testing.T) {
opts := Options{}.Apply(Options{
BlacklistIPs: []*IPNet{{
IPNet: net.IPNet{
IP: net.IPv4zero,
Mask: net.CIDRMask(1, 1),
IP: net.IPv4bcast,
Mask: net.CIDRMask(31, 32),
},
}},
})
assert.NotNil(t, opts.BlacklistIPs)
assert.NotEmpty(t, opts.BlacklistIPs)
assert.Equal(t, net.IPv4zero, opts.BlacklistIPs[0].IP)
assert.Equal(t, net.CIDRMask(1, 1), opts.BlacklistIPs[0].Mask)
assert.Equal(t, net.IPv4bcast, opts.BlacklistIPs[0].IP)
assert.Equal(t, net.CIDRMask(31, 32), opts.BlacklistIPs[0].Mask)

t.Run("JSON", func(t *testing.T) {
t.Parallel()

b, err := json.Marshal(opts)
require.NoError(t, err)

var uopts Options
err = json.Unmarshal(b, &uopts)
require.NoError(t, err)
require.Len(t, uopts.BlacklistIPs, 1)
require.Equal(t, "255.255.255.254/31", uopts.BlacklistIPs[0].String())
})
})
t.Run("BlockedHostnames", func(t *testing.T) {
blockedHostnames, err := types.NewNullHostnameTrie([]string{"test.k6.io", "*valid.pattern"})
Expand Down Expand Up @@ -519,6 +532,23 @@ func TestOptionsEnv(t *testing.T) {
}
}

func TestOptionsValidate(t *testing.T) {
t.Parallel()
t.Run("BlacklistIPs", func(t *testing.T) {
t.Parallel()

opts := Options{}.Apply(Options{
BlacklistIPs: []*IPNet{
{IPNet: net.IPNet{IP: []byte{192, 0, 2, 1}, Mask: []byte{255, 255, 255, 0}}},
{IPNet: net.IPNet{IP: net.IPv4bcast, Mask: net.CIDRMask(1, 1)}},
{IPNet: net.IPNet{IP: []byte(""), Mask: net.CIDRMask(64, 128)}},
},
})
errs := opts.Validate()
require.Len(t, errs, 2)
})
}

func TestCIDRUnmarshal(t *testing.T) {
testData := []struct {
input string
Expand Down

0 comments on commit 2cfa6ad

Please sign in to comment.