diff --git a/pkg/ipnet/ipnet.go b/pkg/ipnet/ipnet.go index 48ca01e6706..fc66cd9360b 100644 --- a/pkg/ipnet/ipnet.go +++ b/pkg/ipnet/ipnet.go @@ -16,6 +16,35 @@ type IPNet struct { net.IPNet } +// String returns a CIDR serialization of the subnet, or an empty +// string if the subnet is nil. +func (ipnet *IPNet) String() string { + if ipnet == nil { + return "" + } + return ipnet.IPNet.String() +} + +// DeepCopyInto copies the receiver into out. out must be non-nil. +func (ipnet *IPNet) DeepCopyInto(out *IPNet) { + if ipnet == nil { + *out = *new(IPNet) + } else { + *out = *ipnet + } + return +} + +// DeepCopy copies the receiver, creating a new IPNet. +func (ipnet *IPNet) DeepCopy() *IPNet { + if ipnet == nil { + return nil + } + out := new(IPNet) + ipnet.DeepCopyInto(out) + return out +} + // MarshalJSON interface for an IPNet func (ipnet IPNet) MarshalJSON() (data []byte, err error) { if reflect.DeepEqual(ipnet.IPNet, emptyIPNet) { diff --git a/pkg/ipnet/ipnet_test.go b/pkg/ipnet/ipnet_test.go index 94b927d6f47..e12f3abaaef 100644 --- a/pkg/ipnet/ipnet_test.go +++ b/pkg/ipnet/ipnet_test.go @@ -38,25 +38,60 @@ func TestUnmarshal(t *testing.T) { Mask: net.IPv4Mask(255, 255, 255, 0), }}, } { - data, err := json.Marshal(ipNetIn) - if err != nil { - t.Fatal(err) - } + t.Run(ipNetIn.String(), func(t *testing.T) { + data, err := json.Marshal(ipNetIn) + if err != nil { + t.Fatal(err) + } - t.Run(string(data), func(t *testing.T) { var ipNetOut *IPNet - err := json.Unmarshal(data, &ipNetOut) + err = json.Unmarshal(data, &ipNetOut) if err != nil { t.Fatal(err) } - if ipNetIn == nil { - if ipNetOut != nil { - t.Fatalf("%v != %v", ipNetOut, ipNetIn) - } - } else if ipNetOut.String() != ipNetIn.String() { + if ipNetOut.String() != ipNetIn.String() { t.Fatalf("%v != %v", ipNetOut, ipNetIn) } }) } } + +func TestDeepCopy(t *testing.T) { + for _, ipNetIn := range []*IPNet{ + {}, + {IPNet: net.IPNet{ + IP: net.IP{192, 168, 0, 10}, + Mask: net.IPv4Mask(255, 255, 255, 0), + }}, + } { + t.Run(ipNetIn.String(), func(t *testing.T) { + t.Run("DeepCopyInto", func(t *testing.T) { + ipNetOut := &IPNet{IPNet: net.IPNet{ + IP: net.IP{10, 0, 0, 0}, + Mask: net.IPv4Mask(255, 0, 0, 0), + }} + + ipNetIn.DeepCopyInto(ipNetOut) + if ipNetOut.String() != ipNetIn.String() { + t.Fatalf("%v != %v", ipNetOut, ipNetIn) + } + }) + + t.Run("DeepCopy", func(t *testing.T) { + ipNetOut := ipNetIn.DeepCopy() + if ipNetOut.String() != ipNetIn.String() { + t.Fatalf("%v != %v", ipNetOut, ipNetIn) + } + + ipNetIn.IPNet = net.IPNet{ + IP: net.IP{192, 168, 10, 10}, + Mask: net.IPv4Mask(255, 255, 255, 255), + } + if ipNetOut.String() == ipNetIn.String() { + t.Fatalf("%v (%q) == %v (%q)", ipNetOut, ipNetOut.String(), ipNetIn, ipNetIn.String()) + } + }) + }) + } +}