diff --git a/api/next/61642.txt b/api/next/61642.txt new file mode 100644 index 0000000000000..4c8bf252df3bd --- /dev/null +++ b/api/next/61642.txt @@ -0,0 +1,2 @@ +pkg net/netip, method (AddrPort) Compare(AddrPort) int #61642 +pkg net/netip, method (Prefix) Compare(Prefix) int #61642 diff --git a/src/net/netip/netip.go b/src/net/netip/netip.go index 0c9dc3246ccfd..99cb754fae877 100644 --- a/src/net/netip/netip.go +++ b/src/net/netip/netip.go @@ -12,6 +12,7 @@ package netip import ( + "cmp" "errors" "math" "strconv" @@ -1102,6 +1103,16 @@ func MustParseAddrPort(s string) AddrPort { // All ports are valid, including zero. func (p AddrPort) IsValid() bool { return p.ip.IsValid() } +// Compare returns an integer comparing two AddrPorts. +// The result will be 0 if p == p2, -1 if p < p2, and +1 if p > p2. +// AddrPorts sort first by IP address, then port. +func (p AddrPort) Compare(p2 AddrPort) int { + if c := p.Addr().Compare(p2.Addr()); c != 0 { + return c + } + return cmp.Compare(p.Port(), p2.Port()) +} + func (p AddrPort) String() string { switch p.ip.z { case z0: @@ -1261,6 +1272,21 @@ func (p Prefix) isZero() bool { return p == Prefix{} } // IsSingleIP reports whether p contains exactly one IP. func (p Prefix) IsSingleIP() bool { return p.IsValid() && p.Bits() == p.ip.BitLen() } +// Compare returns an integer comparing two prefixes. +// The result will be 0 if p == p2, -1 if p < p2, and +1 if p > p2. +// Prefixes sort first by validity (invalid before valid), then +// address family (IPv4 before IPv6), then prefix length, then +// address. +func (p Prefix) Compare(p2 Prefix) int { + if c := cmp.Compare(p.Addr().BitLen(), p2.Addr().BitLen()); c != 0 { + return c + } + if c := cmp.Compare(p.Bits(), p2.Bits()); c != 0 { + return c + } + return p.Addr().Compare(p2.Addr()) +} + // ParsePrefix parses s as an IP address prefix. // The string can be in the form "192.168.1.0/24" or "2001:db8::/32", // the CIDR notation defined in RFC 4632 and RFC 4291. diff --git a/src/net/netip/netip_test.go b/src/net/netip/netip_test.go index 0f80bb0ab0e56..39893e0f6df55 100644 --- a/src/net/netip/netip_test.go +++ b/src/net/netip/netip_test.go @@ -14,6 +14,7 @@ import ( "net" . "net/netip" "reflect" + "slices" "sort" "strings" "testing" @@ -812,7 +813,7 @@ func TestAddrWellKnown(t *testing.T) { } } -func TestLessCompare(t *testing.T) { +func TestAddrLessCompare(t *testing.T) { tests := []struct { a, b Addr want bool @@ -882,6 +883,109 @@ func TestLessCompare(t *testing.T) { } } +func TestAddrPortCompare(t *testing.T) { + tests := []struct { + a, b AddrPort + want int + }{ + {AddrPort{}, AddrPort{}, 0}, + {AddrPort{}, mustIPPort("1.2.3.4:80"), -1}, + + {mustIPPort("1.2.3.4:80"), mustIPPort("1.2.3.4:80"), 0}, + {mustIPPort("[::1]:80"), mustIPPort("[::1]:80"), 0}, + + {mustIPPort("1.2.3.4:80"), mustIPPort("2.3.4.5:22"), -1}, + {mustIPPort("[::1]:80"), mustIPPort("[::2]:22"), -1}, + + {mustIPPort("1.2.3.4:80"), mustIPPort("1.2.3.4:443"), -1}, + {mustIPPort("[::1]:80"), mustIPPort("[::1]:443"), -1}, + + {mustIPPort("1.2.3.4:80"), mustIPPort("[0102:0304::0]:80"), -1}, + } + for _, tt := range tests { + got := tt.a.Compare(tt.b) + if got != tt.want { + t.Errorf("Compare(%q, %q) = %v; want %v", tt.a, tt.b, got, tt.want) + } + + // Also check inverse. + if got == tt.want { + got2 := tt.b.Compare(tt.a) + if want2 := -1 * tt.want; got2 != want2 { + t.Errorf("Compare(%q, %q) was correctly %v, but Compare(%q, %q) was %v", tt.a, tt.b, got, tt.b, tt.a, got2) + } + } + } + + // And just sort. + values := []AddrPort{ + mustIPPort("[::1]:80"), + mustIPPort("[::2]:80"), + AddrPort{}, + mustIPPort("1.2.3.4:443"), + mustIPPort("8.8.8.8:8080"), + mustIPPort("[::1%foo]:1024"), + } + slices.SortFunc(values, func(a, b AddrPort) int { return a.Compare(b) }) + got := fmt.Sprintf("%s", values) + want := `[invalid AddrPort 1.2.3.4:443 8.8.8.8:8080 [::1]:80 [::1%foo]:1024 [::2]:80]` + if got != want { + t.Errorf("unexpected sort\n got: %s\nwant: %s\n", got, want) + } +} + +func TestPrefixCompare(t *testing.T) { + tests := []struct { + a, b Prefix + want int + }{ + {Prefix{}, Prefix{}, 0}, + {Prefix{}, mustPrefix("1.2.3.0/24"), -1}, + + {mustPrefix("1.2.3.0/24"), mustPrefix("1.2.3.0/24"), 0}, + {mustPrefix("fe80::/64"), mustPrefix("fe80::/64"), 0}, + + {mustPrefix("1.2.3.0/24"), mustPrefix("1.2.4.0/24"), -1}, + {mustPrefix("fe80::/64"), mustPrefix("fe90::/64"), -1}, + + {mustPrefix("1.2.0.0/16"), mustPrefix("1.2.0.0/24"), -1}, + {mustPrefix("fe80::/48"), mustPrefix("fe80::/64"), -1}, + + {mustPrefix("1.2.3.0/24"), mustPrefix("fe80::/8"), -1}, + } + for _, tt := range tests { + got := tt.a.Compare(tt.b) + if got != tt.want { + t.Errorf("Compare(%q, %q) = %v; want %v", tt.a, tt.b, got, tt.want) + } + + // Also check inverse. + if got == tt.want { + got2 := tt.b.Compare(tt.a) + if want2 := -1 * tt.want; got2 != want2 { + t.Errorf("Compare(%q, %q) was correctly %v, but Compare(%q, %q) was %v", tt.a, tt.b, got, tt.b, tt.a, got2) + } + } + } + + // And just sort. + values := []Prefix{ + mustPrefix("1.2.3.0/24"), + mustPrefix("fe90::/64"), + mustPrefix("fe80::/64"), + mustPrefix("1.2.0.0/16"), + Prefix{}, + mustPrefix("fe80::/48"), + mustPrefix("1.2.0.0/24"), + } + slices.SortFunc(values, func(a, b Prefix) int { return a.Compare(b) }) + got := fmt.Sprintf("%s", values) + want := `[invalid Prefix 1.2.0.0/16 1.2.0.0/24 1.2.3.0/24 fe80::/48 fe80::/64 fe90::/64]` + if got != want { + t.Errorf("unexpected sort\n got: %s\nwant: %s\n", got, want) + } +} + func TestIPStringExpanded(t *testing.T) { tests := []struct { ip Addr