net/netip: add AddrPort.Compare and Prefix.Compare

Fixes #61642

Change-Id: I2262855dbe75135f70008e5df4634d2cfff76550
GitHub-Last-Rev: 949685a9e4
GitHub-Pull-Request: golang/go#62387
Reviewed-on: https://go-review.googlesource.com/c/go/+/524616
Reviewed-by: Joseph Tsai <joetsai@digital-static.net>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
Reviewed-by: Heschi Kreinick <heschi@google.com>
This commit is contained in:
David Anderson 2023-08-31 22:46:45 +00:00 committed by Brad Fitzpatrick
parent dfb2e4265b
commit 94f24fd054
3 changed files with 133 additions and 1 deletions

2
api/next/61642.txt Normal file
View File

@ -0,0 +1,2 @@
pkg net/netip, method (AddrPort) Compare(AddrPort) int #61642
pkg net/netip, method (Prefix) Compare(Prefix) int #61642

View File

@ -12,6 +12,7 @@
package netip
import (
"cmp"
"errors"
"math"
"strconv"
@ -1102,6 +1103,16 @@ func MustParseAddrPort(s string) AddrPort {
// All ports are valid, including zero.
func (p AddrPort) IsValid() bool { return p.ip.IsValid() }
// Compare returns an integer comparing two AddrPorts.
// The result will be 0 if p == p2, -1 if p < p2, and +1 if p > p2.
// AddrPorts sort first by IP address, then port.
func (p AddrPort) Compare(p2 AddrPort) int {
if c := p.Addr().Compare(p2.Addr()); c != 0 {
return c
}
return cmp.Compare(p.Port(), p2.Port())
}
func (p AddrPort) String() string {
switch p.ip.z {
case z0:
@ -1261,6 +1272,21 @@ func (p Prefix) isZero() bool { return p == Prefix{} }
// IsSingleIP reports whether p contains exactly one IP.
func (p Prefix) IsSingleIP() bool { return p.IsValid() && p.Bits() == p.ip.BitLen() }
// Compare returns an integer comparing two prefixes.
// The result will be 0 if p == p2, -1 if p < p2, and +1 if p > p2.
// Prefixes sort first by validity (invalid before valid), then
// address family (IPv4 before IPv6), then prefix length, then
// address.
func (p Prefix) Compare(p2 Prefix) int {
if c := cmp.Compare(p.Addr().BitLen(), p2.Addr().BitLen()); c != 0 {
return c
}
if c := cmp.Compare(p.Bits(), p2.Bits()); c != 0 {
return c
}
return p.Addr().Compare(p2.Addr())
}
// ParsePrefix parses s as an IP address prefix.
// The string can be in the form "192.168.1.0/24" or "2001:db8::/32",
// the CIDR notation defined in RFC 4632 and RFC 4291.

View File

@ -14,6 +14,7 @@ import (
"net"
. "net/netip"
"reflect"
"slices"
"sort"
"strings"
"testing"
@ -812,7 +813,7 @@ func TestAddrWellKnown(t *testing.T) {
}
}
func TestLessCompare(t *testing.T) {
func TestAddrLessCompare(t *testing.T) {
tests := []struct {
a, b Addr
want bool
@ -882,6 +883,109 @@ func TestLessCompare(t *testing.T) {
}
}
func TestAddrPortCompare(t *testing.T) {
tests := []struct {
a, b AddrPort
want int
}{
{AddrPort{}, AddrPort{}, 0},
{AddrPort{}, mustIPPort("1.2.3.4:80"), -1},
{mustIPPort("1.2.3.4:80"), mustIPPort("1.2.3.4:80"), 0},
{mustIPPort("[::1]:80"), mustIPPort("[::1]:80"), 0},
{mustIPPort("1.2.3.4:80"), mustIPPort("2.3.4.5:22"), -1},
{mustIPPort("[::1]:80"), mustIPPort("[::2]:22"), -1},
{mustIPPort("1.2.3.4:80"), mustIPPort("1.2.3.4:443"), -1},
{mustIPPort("[::1]:80"), mustIPPort("[::1]:443"), -1},
{mustIPPort("1.2.3.4:80"), mustIPPort("[0102:0304::0]:80"), -1},
}
for _, tt := range tests {
got := tt.a.Compare(tt.b)
if got != tt.want {
t.Errorf("Compare(%q, %q) = %v; want %v", tt.a, tt.b, got, tt.want)
}
// Also check inverse.
if got == tt.want {
got2 := tt.b.Compare(tt.a)
if want2 := -1 * tt.want; got2 != want2 {
t.Errorf("Compare(%q, %q) was correctly %v, but Compare(%q, %q) was %v", tt.a, tt.b, got, tt.b, tt.a, got2)
}
}
}
// And just sort.
values := []AddrPort{
mustIPPort("[::1]:80"),
mustIPPort("[::2]:80"),
AddrPort{},
mustIPPort("1.2.3.4:443"),
mustIPPort("8.8.8.8:8080"),
mustIPPort("[::1%foo]:1024"),
}
slices.SortFunc(values, func(a, b AddrPort) int { return a.Compare(b) })
got := fmt.Sprintf("%s", values)
want := `[invalid AddrPort 1.2.3.4:443 8.8.8.8:8080 [::1]:80 [::1%foo]:1024 [::2]:80]`
if got != want {
t.Errorf("unexpected sort\n got: %s\nwant: %s\n", got, want)
}
}
func TestPrefixCompare(t *testing.T) {
tests := []struct {
a, b Prefix
want int
}{
{Prefix{}, Prefix{}, 0},
{Prefix{}, mustPrefix("1.2.3.0/24"), -1},
{mustPrefix("1.2.3.0/24"), mustPrefix("1.2.3.0/24"), 0},
{mustPrefix("fe80::/64"), mustPrefix("fe80::/64"), 0},
{mustPrefix("1.2.3.0/24"), mustPrefix("1.2.4.0/24"), -1},
{mustPrefix("fe80::/64"), mustPrefix("fe90::/64"), -1},
{mustPrefix("1.2.0.0/16"), mustPrefix("1.2.0.0/24"), -1},
{mustPrefix("fe80::/48"), mustPrefix("fe80::/64"), -1},
{mustPrefix("1.2.3.0/24"), mustPrefix("fe80::/8"), -1},
}
for _, tt := range tests {
got := tt.a.Compare(tt.b)
if got != tt.want {
t.Errorf("Compare(%q, %q) = %v; want %v", tt.a, tt.b, got, tt.want)
}
// Also check inverse.
if got == tt.want {
got2 := tt.b.Compare(tt.a)
if want2 := -1 * tt.want; got2 != want2 {
t.Errorf("Compare(%q, %q) was correctly %v, but Compare(%q, %q) was %v", tt.a, tt.b, got, tt.b, tt.a, got2)
}
}
}
// And just sort.
values := []Prefix{
mustPrefix("1.2.3.0/24"),
mustPrefix("fe90::/64"),
mustPrefix("fe80::/64"),
mustPrefix("1.2.0.0/16"),
Prefix{},
mustPrefix("fe80::/48"),
mustPrefix("1.2.0.0/24"),
}
slices.SortFunc(values, func(a, b Prefix) int { return a.Compare(b) })
got := fmt.Sprintf("%s", values)
want := `[invalid Prefix 1.2.0.0/16 1.2.0.0/24 1.2.3.0/24 fe80::/48 fe80::/64 fe90::/64]`
if got != want {
t.Errorf("unexpected sort\n got: %s\nwant: %s\n", got, want)
}
}
func TestIPStringExpanded(t *testing.T) {
tests := []struct {
ip Addr