From 94f24fd054a9263174774b91268aab8da23d5d48 Mon Sep 17 00:00:00 2001 From: David Anderson Date: Thu, 31 Aug 2023 22:46:45 +0000 Subject: [PATCH] net/netip: add AddrPort.Compare and Prefix.Compare Fixes #61642 Change-Id: I2262855dbe75135f70008e5df4634d2cfff76550 GitHub-Last-Rev: 949685a9e426f50f37753045d7527ebcccb082e7 GitHub-Pull-Request: golang/go#62387 Reviewed-on: https://go-review.googlesource.com/c/go/+/524616 Reviewed-by: Joseph Tsai Reviewed-by: Brad Fitzpatrick LUCI-TryBot-Result: Go LUCI Reviewed-by: Matthew Dempsky Reviewed-by: Heschi Kreinick --- api/next/61642.txt | 2 + src/net/netip/netip.go | 26 +++++++++ src/net/netip/netip_test.go | 106 +++++++++++++++++++++++++++++++++++- 3 files changed, 133 insertions(+), 1 deletion(-) create mode 100644 api/next/61642.txt diff --git a/api/next/61642.txt b/api/next/61642.txt new file mode 100644 index 00000000000000..4c8bf252df3bd6 --- /dev/null +++ b/api/next/61642.txt @@ -0,0 +1,2 @@ +pkg net/netip, method (AddrPort) Compare(AddrPort) int #61642 +pkg net/netip, method (Prefix) Compare(Prefix) int #61642 diff --git a/src/net/netip/netip.go b/src/net/netip/netip.go index 0c9dc3246ccfd4..99cb754fae8776 100644 --- a/src/net/netip/netip.go +++ b/src/net/netip/netip.go @@ -12,6 +12,7 @@ package netip import ( + "cmp" "errors" "math" "strconv" @@ -1102,6 +1103,16 @@ func MustParseAddrPort(s string) AddrPort { // All ports are valid, including zero. func (p AddrPort) IsValid() bool { return p.ip.IsValid() } +// Compare returns an integer comparing two AddrPorts. +// The result will be 0 if p == p2, -1 if p < p2, and +1 if p > p2. +// AddrPorts sort first by IP address, then port. +func (p AddrPort) Compare(p2 AddrPort) int { + if c := p.Addr().Compare(p2.Addr()); c != 0 { + return c + } + return cmp.Compare(p.Port(), p2.Port()) +} + func (p AddrPort) String() string { switch p.ip.z { case z0: @@ -1261,6 +1272,21 @@ func (p Prefix) isZero() bool { return p == Prefix{} } // IsSingleIP reports whether p contains exactly one IP. func (p Prefix) IsSingleIP() bool { return p.IsValid() && p.Bits() == p.ip.BitLen() } +// Compare returns an integer comparing two prefixes. +// The result will be 0 if p == p2, -1 if p < p2, and +1 if p > p2. +// Prefixes sort first by validity (invalid before valid), then +// address family (IPv4 before IPv6), then prefix length, then +// address. +func (p Prefix) Compare(p2 Prefix) int { + if c := cmp.Compare(p.Addr().BitLen(), p2.Addr().BitLen()); c != 0 { + return c + } + if c := cmp.Compare(p.Bits(), p2.Bits()); c != 0 { + return c + } + return p.Addr().Compare(p2.Addr()) +} + // ParsePrefix parses s as an IP address prefix. // The string can be in the form "192.168.1.0/24" or "2001:db8::/32", // the CIDR notation defined in RFC 4632 and RFC 4291. diff --git a/src/net/netip/netip_test.go b/src/net/netip/netip_test.go index 0f80bb0ab0e56f..39893e0f6df556 100644 --- a/src/net/netip/netip_test.go +++ b/src/net/netip/netip_test.go @@ -14,6 +14,7 @@ import ( "net" . "net/netip" "reflect" + "slices" "sort" "strings" "testing" @@ -812,7 +813,7 @@ func TestAddrWellKnown(t *testing.T) { } } -func TestLessCompare(t *testing.T) { +func TestAddrLessCompare(t *testing.T) { tests := []struct { a, b Addr want bool @@ -882,6 +883,109 @@ func TestLessCompare(t *testing.T) { } } +func TestAddrPortCompare(t *testing.T) { + tests := []struct { + a, b AddrPort + want int + }{ + {AddrPort{}, AddrPort{}, 0}, + {AddrPort{}, mustIPPort("1.2.3.4:80"), -1}, + + {mustIPPort("1.2.3.4:80"), mustIPPort("1.2.3.4:80"), 0}, + {mustIPPort("[::1]:80"), mustIPPort("[::1]:80"), 0}, + + {mustIPPort("1.2.3.4:80"), mustIPPort("2.3.4.5:22"), -1}, + {mustIPPort("[::1]:80"), mustIPPort("[::2]:22"), -1}, + + {mustIPPort("1.2.3.4:80"), mustIPPort("1.2.3.4:443"), -1}, + {mustIPPort("[::1]:80"), mustIPPort("[::1]:443"), -1}, + + {mustIPPort("1.2.3.4:80"), mustIPPort("[0102:0304::0]:80"), -1}, + } + for _, tt := range tests { + got := tt.a.Compare(tt.b) + if got != tt.want { + t.Errorf("Compare(%q, %q) = %v; want %v", tt.a, tt.b, got, tt.want) + } + + // Also check inverse. + if got == tt.want { + got2 := tt.b.Compare(tt.a) + if want2 := -1 * tt.want; got2 != want2 { + t.Errorf("Compare(%q, %q) was correctly %v, but Compare(%q, %q) was %v", tt.a, tt.b, got, tt.b, tt.a, got2) + } + } + } + + // And just sort. + values := []AddrPort{ + mustIPPort("[::1]:80"), + mustIPPort("[::2]:80"), + AddrPort{}, + mustIPPort("1.2.3.4:443"), + mustIPPort("8.8.8.8:8080"), + mustIPPort("[::1%foo]:1024"), + } + slices.SortFunc(values, func(a, b AddrPort) int { return a.Compare(b) }) + got := fmt.Sprintf("%s", values) + want := `[invalid AddrPort 1.2.3.4:443 8.8.8.8:8080 [::1]:80 [::1%foo]:1024 [::2]:80]` + if got != want { + t.Errorf("unexpected sort\n got: %s\nwant: %s\n", got, want) + } +} + +func TestPrefixCompare(t *testing.T) { + tests := []struct { + a, b Prefix + want int + }{ + {Prefix{}, Prefix{}, 0}, + {Prefix{}, mustPrefix("1.2.3.0/24"), -1}, + + {mustPrefix("1.2.3.0/24"), mustPrefix("1.2.3.0/24"), 0}, + {mustPrefix("fe80::/64"), mustPrefix("fe80::/64"), 0}, + + {mustPrefix("1.2.3.0/24"), mustPrefix("1.2.4.0/24"), -1}, + {mustPrefix("fe80::/64"), mustPrefix("fe90::/64"), -1}, + + {mustPrefix("1.2.0.0/16"), mustPrefix("1.2.0.0/24"), -1}, + {mustPrefix("fe80::/48"), mustPrefix("fe80::/64"), -1}, + + {mustPrefix("1.2.3.0/24"), mustPrefix("fe80::/8"), -1}, + } + for _, tt := range tests { + got := tt.a.Compare(tt.b) + if got != tt.want { + t.Errorf("Compare(%q, %q) = %v; want %v", tt.a, tt.b, got, tt.want) + } + + // Also check inverse. + if got == tt.want { + got2 := tt.b.Compare(tt.a) + if want2 := -1 * tt.want; got2 != want2 { + t.Errorf("Compare(%q, %q) was correctly %v, but Compare(%q, %q) was %v", tt.a, tt.b, got, tt.b, tt.a, got2) + } + } + } + + // And just sort. + values := []Prefix{ + mustPrefix("1.2.3.0/24"), + mustPrefix("fe90::/64"), + mustPrefix("fe80::/64"), + mustPrefix("1.2.0.0/16"), + Prefix{}, + mustPrefix("fe80::/48"), + mustPrefix("1.2.0.0/24"), + } + slices.SortFunc(values, func(a, b Prefix) int { return a.Compare(b) }) + got := fmt.Sprintf("%s", values) + want := `[invalid Prefix 1.2.0.0/16 1.2.0.0/24 1.2.3.0/24 fe80::/48 fe80::/64 fe90::/64]` + if got != want { + t.Errorf("unexpected sort\n got: %s\nwant: %s\n", got, want) + } +} + func TestIPStringExpanded(t *testing.T) { tests := []struct { ip Addr