Skip to content

Commit

Permalink
lib/options: BlacklistIP un/marshaling
Browse files Browse the repository at this point in the history
The BlacklistIP configuration uses the CIDR notation.
Implemented json (Un)Marshaler interfaces to support correct
encoding and decoding based on the notation.
  • Loading branch information
codebien committed Jul 5, 2021
1 parent aa1fd6a commit 7351da9
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 7 deletions.
20 changes: 17 additions & 3 deletions lib/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ import (
"net"
"reflect"
"strconv"

"gopkg.in/guregu/null.v3"
"strings"

"go.k6.io/k6/lib/types"
"go.k6.io/k6/stats"
"gopkg.in/guregu/null.v3"
)

// DefaultScenarioName is used as the default key/ID of the scenario config entries
Expand Down Expand Up @@ -178,10 +178,18 @@ func (ipnet *IPNet) UnmarshalText(b []byte) error {
}

*ipnet = *newIPNet

return nil
}

func (ipnet *IPNet) MarshalJSON() ([]byte, error) {
return []byte(fmt.Sprintf("%q", ipnet.String())), nil
}

func (ipnet *IPNet) UnmarshalJSON(b []byte) error {
// Remove quotes and fallback to unmarshal text
return ipnet.UnmarshalText([]byte(strings.Trim(string(b), `"`)))
}

// HostAddress stores information about IP and port
// for a host.
type HostAddress net.TCPAddr
Expand Down Expand Up @@ -572,6 +580,12 @@ func (o Options) Validate() []error {
o.ExecutionSegment, o.ExecutionSegmentSequence))
}
}
for _, ip := range o.BlacklistIPs {
if ip.String() == "<nil>" {
errors = append(errors, fmt.Errorf("blacklisted IP %q has an invalid CIDR", ip.IP.String()))
}

}
return append(errors, o.Scenarios.Validate()...)
}

Expand Down
33 changes: 29 additions & 4 deletions lib/options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,15 +307,26 @@ func TestOptions(t *testing.T) {
opts := Options{}.Apply(Options{
BlacklistIPs: []*IPNet{{
IPNet: net.IPNet{
IP: net.IPv4zero,
Mask: net.CIDRMask(1, 1),
IP: net.IPv4bcast,
Mask: net.CIDRMask(31, 32),
},
}},
})
assert.NotNil(t, opts.BlacklistIPs)
assert.NotEmpty(t, opts.BlacklistIPs)
assert.Equal(t, net.IPv4zero, opts.BlacklistIPs[0].IP)
assert.Equal(t, net.CIDRMask(1, 1), opts.BlacklistIPs[0].Mask)
assert.Equal(t, net.IPv4bcast, opts.BlacklistIPs[0].IP)
assert.Equal(t, net.CIDRMask(31, 32), opts.BlacklistIPs[0].Mask)

t.Run("JSON", func(t *testing.T) {
b, err := json.Marshal(opts)
require.NoError(t, err)

var uopts Options
err = json.Unmarshal(b, &uopts)
require.NoError(t, err)
require.Len(t, uopts.BlacklistIPs, 1)
require.Equal(t, "255.255.255.254/31", uopts.BlacklistIPs[0].String())
})
})
t.Run("BlockedHostnames", func(t *testing.T) {
blockedHostnames, err := types.NewNullHostnameTrie([]string{"test.k6.io", "*valid.pattern"})
Expand Down Expand Up @@ -519,6 +530,20 @@ func TestOptionsEnv(t *testing.T) {
}
}

func TestOptionsValidate(t *testing.T) {
t.Run("BlacklistIPs", func(t *testing.T) {
opts := Options{}.Apply(Options{
BlacklistIPs: []*IPNet{
{IPNet: net.IPNet{IP: []byte{192, 0, 2, 1}, Mask: []byte{255, 255, 255, 0}}},
{IPNet: net.IPNet{IP: net.IPv4bcast, Mask: net.CIDRMask(1, 1)}},
{IPNet: net.IPNet{IP: []byte(""), Mask: net.CIDRMask(64, 128)}},
},
})
errs := opts.Validate()
require.Len(t, errs, 2)
})
}

func TestCIDRUnmarshal(t *testing.T) {
testData := []struct {
input string
Expand Down

0 comments on commit 7351da9

Please sign in to comment.