net: avoid memory copy calling absDomainName

Change-Id: I8ea9bec8bc33e29b8c265fbca40871bc23667144
Reviewed-on: https://go-review.googlesource.com/c/go/+/330470
Reviewed-by: Damien Neil <dneil@google.com>
Trust: Damien Neil <dneil@google.com>
Trust: Michael Knyszek <mknyszek@google.com>
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
This commit is contained in:
Andy Pan 2021-06-24 12:50:14 +08:00 committed by Damien Neil
parent 6406227d71
commit c04a32e59a
8 changed files with 43 additions and 38 deletions

View File

@ -323,7 +323,7 @@ func cgoLookupAddrPTR(addr string, sa *C.struct_sockaddr, salen C.socklen_t) (na
break break
} }
} }
return []string{absDomainName(b)}, nil return []string{absDomainName(string(b))}, nil
} }
func cgoReverseLookup(result chan<- reverseLookupResult, addr string, sa *C.struct_sockaddr, salen C.socklen_t) { func cgoReverseLookup(result chan<- reverseLookupResult, addr string, sa *C.struct_sockaddr, salen C.socklen_t) {

View File

@ -5,6 +5,7 @@
package net package net
import ( import (
"internal/bytealg"
"internal/itoa" "internal/itoa"
"sort" "sort"
@ -136,18 +137,11 @@ func isDomainName(s string) bool {
// It's hard to tell so we settle on the heuristic that names without dots // It's hard to tell so we settle on the heuristic that names without dots
// (like "localhost" or "myhost") do not get trailing dots, but any other // (like "localhost" or "myhost") do not get trailing dots, but any other
// names do. // names do.
func absDomainName(b []byte) string { func absDomainName(s string) string {
hasDots := false if bytealg.IndexByteString(s, '.') != -1 && s[len(s)-1] != '.' {
for _, x := range b { s += "."
if x == '.' {
hasDots = true
break
} }
} return s
if hasDots && b[len(b)-1] != '.' {
b = append(b, '.')
}
return string(b)
} }
// An SRV represents a single DNS SRV record. // An SRV represents a single DNS SRV record.

View File

@ -82,10 +82,10 @@ func readHosts() {
continue continue
} }
for i := 1; i < len(f); i++ { for i := 1; i < len(f); i++ {
name := absDomainName([]byte(f[i])) name := absDomainName(f[i])
h := []byte(f[i]) h := []byte(f[i])
lowerASCIIBytes(h) lowerASCIIBytes(h)
key := absDomainName(h) key := absDomainName(string(h))
hs[key] = append(hs[key], addr) hs[key] = append(hs[key], addr)
is[addr] = append(is[addr], name) is[addr] = append(is[addr], name)
} }
@ -106,11 +106,12 @@ func lookupStaticHost(host string) []string {
defer hosts.Unlock() defer hosts.Unlock()
readHosts() readHosts()
if len(hosts.byName) != 0 { if len(hosts.byName) != 0 {
// TODO(jbd,bradfitz): avoid this alloc if host is already all lowercase? if hasUpperCase(host) {
// or linear scan the byName map if it's small enough?
lowerHost := []byte(host) lowerHost := []byte(host)
lowerASCIIBytes(lowerHost) lowerASCIIBytes(lowerHost)
if ips, ok := hosts.byName[absDomainName(lowerHost)]; ok { host = string(lowerHost)
}
if ips, ok := hosts.byName[absDomainName(host)]; ok {
ipsCp := make([]string, len(ips)) ipsCp := make([]string, len(ips))
copy(ipsCp, ips) copy(ipsCp, ips)
return ipsCp return ipsCp

View File

@ -70,7 +70,7 @@ func TestLookupStaticHost(t *testing.T) {
} }
func testStaticHost(t *testing.T, hostsPath string, ent staticHostEntry) { func testStaticHost(t *testing.T, hostsPath string, ent staticHostEntry) {
ins := []string{ent.in, absDomainName([]byte(ent.in)), strings.ToLower(ent.in), strings.ToUpper(ent.in)} ins := []string{ent.in, absDomainName(ent.in), strings.ToLower(ent.in), strings.ToUpper(ent.in)}
for _, in := range ins { for _, in := range ins {
addrs := lookupStaticHost(in) addrs := lookupStaticHost(in)
if !reflect.DeepEqual(addrs, ent.out) { if !reflect.DeepEqual(addrs, ent.out) {
@ -141,7 +141,7 @@ func TestLookupStaticAddr(t *testing.T) {
func testStaticAddr(t *testing.T, hostsPath string, ent staticHostEntry) { func testStaticAddr(t *testing.T, hostsPath string, ent staticHostEntry) {
hosts := lookupStaticAddr(ent.in) hosts := lookupStaticAddr(ent.in)
for i := range ent.out { for i := range ent.out {
ent.out[i] = absDomainName([]byte(ent.out[i])) ent.out[i] = absDomainName(ent.out[i])
} }
if !reflect.DeepEqual(hosts, ent.out) { if !reflect.DeepEqual(hosts, ent.out) {
t.Errorf("%s, lookupStaticAddr(%s) = %v; want %v", hostsPath, ent.in, hosts, ent.out) t.Errorf("%s, lookupStaticAddr(%s) = %v; want %v", hostsPath, ent.in, hosts, ent.out)

View File

@ -262,8 +262,8 @@ func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (cn
if !(portOk && priorityOk && weightOk) { if !(portOk && priorityOk && weightOk) {
continue continue
} }
addrs = append(addrs, &SRV{absDomainName([]byte(f[5])), uint16(port), uint16(priority), uint16(weight)}) addrs = append(addrs, &SRV{absDomainName(f[5]), uint16(port), uint16(priority), uint16(weight)})
cname = absDomainName([]byte(f[0])) cname = absDomainName(f[0])
} }
byPriorityWeight(addrs).sort() byPriorityWeight(addrs).sort()
return return
@ -280,7 +280,7 @@ func (*Resolver) lookupMX(ctx context.Context, name string) (mx []*MX, err error
continue continue
} }
if pref, _, ok := dtoi(f[2]); ok { if pref, _, ok := dtoi(f[2]); ok {
mx = append(mx, &MX{absDomainName([]byte(f[3])), uint16(pref)}) mx = append(mx, &MX{absDomainName(f[3]), uint16(pref)})
} }
} }
byPref(mx).sort() byPref(mx).sort()
@ -297,7 +297,7 @@ func (*Resolver) lookupNS(ctx context.Context, name string) (ns []*NS, err error
if len(f) < 3 { if len(f) < 3 {
continue continue
} }
ns = append(ns, &NS{absDomainName([]byte(f[2]))}) ns = append(ns, &NS{absDomainName(f[2])})
} }
return return
} }
@ -329,7 +329,7 @@ func (*Resolver) lookupAddr(ctx context.Context, addr string) (name []string, er
if len(f) < 3 { if len(f) < 3 {
continue continue
} }
name = append(name, absDomainName([]byte(f[2]))) name = append(name, absDomainName(f[2]))
} }
return return
} }

View File

@ -226,7 +226,7 @@ func (*Resolver) lookupCNAME(ctx context.Context, name string) (string, error) {
// windows returns DNS_INFO_NO_RECORDS if there are no CNAME-s // windows returns DNS_INFO_NO_RECORDS if there are no CNAME-s
if errno, ok := e.(syscall.Errno); ok && errno == syscall.DNS_INFO_NO_RECORDS { if errno, ok := e.(syscall.Errno); ok && errno == syscall.DNS_INFO_NO_RECORDS {
// if there are no aliases, the canonical name is the input name // if there are no aliases, the canonical name is the input name
return absDomainName([]byte(name)), nil return absDomainName(name), nil
} }
if e != nil { if e != nil {
return "", &DNSError{Err: winError("dnsquery", e).Error(), Name: name} return "", &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
@ -235,7 +235,7 @@ func (*Resolver) lookupCNAME(ctx context.Context, name string) (string, error) {
resolved := resolveCNAME(syscall.StringToUTF16Ptr(name), r) resolved := resolveCNAME(syscall.StringToUTF16Ptr(name), r)
cname := windows.UTF16PtrToString(resolved) cname := windows.UTF16PtrToString(resolved)
return absDomainName([]byte(cname)), nil return absDomainName(cname), nil
} }
func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) { func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) {
@ -258,10 +258,10 @@ func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (st
srvs := make([]*SRV, 0, 10) srvs := make([]*SRV, 0, 10)
for _, p := range validRecs(r, syscall.DNS_TYPE_SRV, target) { for _, p := range validRecs(r, syscall.DNS_TYPE_SRV, target) {
v := (*syscall.DNSSRVData)(unsafe.Pointer(&p.Data[0])) v := (*syscall.DNSSRVData)(unsafe.Pointer(&p.Data[0]))
srvs = append(srvs, &SRV{absDomainName([]byte(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:]))), v.Port, v.Priority, v.Weight}) srvs = append(srvs, &SRV{absDomainName(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:])), v.Port, v.Priority, v.Weight})
} }
byPriorityWeight(srvs).sort() byPriorityWeight(srvs).sort()
return absDomainName([]byte(target)), srvs, nil return absDomainName(target), srvs, nil
} }
func (*Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) { func (*Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) {
@ -278,7 +278,7 @@ func (*Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) {
mxs := make([]*MX, 0, 10) mxs := make([]*MX, 0, 10)
for _, p := range validRecs(r, syscall.DNS_TYPE_MX, name) { for _, p := range validRecs(r, syscall.DNS_TYPE_MX, name) {
v := (*syscall.DNSMXData)(unsafe.Pointer(&p.Data[0])) v := (*syscall.DNSMXData)(unsafe.Pointer(&p.Data[0]))
mxs = append(mxs, &MX{absDomainName([]byte(windows.UTF16PtrToString(v.NameExchange))), v.Preference}) mxs = append(mxs, &MX{absDomainName(windows.UTF16PtrToString(v.NameExchange)), v.Preference})
} }
byPref(mxs).sort() byPref(mxs).sort()
return mxs, nil return mxs, nil
@ -298,7 +298,7 @@ func (*Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) {
nss := make([]*NS, 0, 10) nss := make([]*NS, 0, 10)
for _, p := range validRecs(r, syscall.DNS_TYPE_NS, name) { for _, p := range validRecs(r, syscall.DNS_TYPE_NS, name) {
v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0])) v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
nss = append(nss, &NS{absDomainName([]byte(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:])))}) nss = append(nss, &NS{absDomainName(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]))})
} }
return nss, nil return nss, nil
} }
@ -344,7 +344,7 @@ func (*Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error)
ptrs := make([]string, 0, 10) ptrs := make([]string, 0, 10)
for _, p := range validRecs(r, syscall.DNS_TYPE_PTR, arpa) { for _, p := range validRecs(r, syscall.DNS_TYPE_PTR, arpa) {
v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0])) v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
ptrs = append(ptrs, absDomainName([]byte(windows.UTF16PtrToString(v.Host)))) ptrs = append(ptrs, absDomainName(windows.UTF16PtrToString(v.Host)))
} }
return ptrs, nil return ptrs, nil
} }

View File

@ -220,14 +220,14 @@ func nslookupMX(name string) (mx []*MX, err error) {
rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+mail exchanger\s*=\s*([0-9]+)\s*([a-z0-9.\-]+)$`) rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+mail exchanger\s*=\s*([0-9]+)\s*([a-z0-9.\-]+)$`)
for _, ans := range rx.FindAllStringSubmatch(r, -1) { for _, ans := range rx.FindAllStringSubmatch(r, -1) {
pref, _, _ := dtoi(ans[2]) pref, _, _ := dtoi(ans[2])
mx = append(mx, &MX{absDomainName([]byte(ans[3])), uint16(pref)}) mx = append(mx, &MX{absDomainName(ans[3]), uint16(pref)})
} }
// windows nslookup syntax // windows nslookup syntax
// gmail.com MX preference = 30, mail exchanger = alt3.gmail-smtp-in.l.google.com // gmail.com MX preference = 30, mail exchanger = alt3.gmail-smtp-in.l.google.com
rx = regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+MX preference\s*=\s*([0-9]+)\s*,\s*mail exchanger\s*=\s*([a-z0-9.\-]+)$`) rx = regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+MX preference\s*=\s*([0-9]+)\s*,\s*mail exchanger\s*=\s*([a-z0-9.\-]+)$`)
for _, ans := range rx.FindAllStringSubmatch(r, -1) { for _, ans := range rx.FindAllStringSubmatch(r, -1) {
pref, _, _ := dtoi(ans[2]) pref, _, _ := dtoi(ans[2])
mx = append(mx, &MX{absDomainName([]byte(ans[3])), uint16(pref)}) mx = append(mx, &MX{absDomainName(ans[3]), uint16(pref)})
} }
return return
} }
@ -241,7 +241,7 @@ func nslookupNS(name string) (ns []*NS, err error) {
// golang.org nameserver = ns1.google.com. // golang.org nameserver = ns1.google.com.
rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+nameserver\s*=\s*([a-z0-9.\-]+)$`) rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+nameserver\s*=\s*([a-z0-9.\-]+)$`)
for _, ans := range rx.FindAllStringSubmatch(r, -1) { for _, ans := range rx.FindAllStringSubmatch(r, -1) {
ns = append(ns, &NS{absDomainName([]byte(ans[2]))}) ns = append(ns, &NS{absDomainName(ans[2])})
} }
return return
} }
@ -258,7 +258,7 @@ func nslookupCNAME(name string) (cname string, err error) {
for _, ans := range rx.FindAllStringSubmatch(r, -1) { for _, ans := range rx.FindAllStringSubmatch(r, -1) {
last = ans[2] last = ans[2]
} }
return absDomainName([]byte(last)), nil return absDomainName(last), nil
} }
func nslookupTXT(name string) (txt []string, err error) { func nslookupTXT(name string) (txt []string, err error) {
@ -299,7 +299,7 @@ func lookupPTR(name string) (ptr []string, err error) {
ptr = make([]string, 0, 10) ptr = make([]string, 0, 10)
rx := regexp.MustCompile(`(?m)^Pinging\s+([a-zA-Z0-9.\-]+)\s+\[.*$`) rx := regexp.MustCompile(`(?m)^Pinging\s+([a-zA-Z0-9.\-]+)\s+\[.*$`)
for _, ans := range rx.FindAllStringSubmatch(r, -1) { for _, ans := range rx.FindAllStringSubmatch(r, -1) {
ptr = append(ptr, absDomainName([]byte(ans[1]))) ptr = append(ptr, absDomainName(ans[1]))
} }
return return
} }

View File

@ -208,6 +208,16 @@ func last(s string, b byte) int {
return i return i
} }
// hasUpperCase tells whether the given string contains at least one upper-case.
func hasUpperCase(s string) bool {
for i := range s {
if 'A' <= s[i] && s[i] <= 'Z' {
return true
}
}
return false
}
// lowerASCIIBytes makes x ASCII lowercase in-place. // lowerASCIIBytes makes x ASCII lowercase in-place.
func lowerASCIIBytes(x []byte) { func lowerASCIIBytes(x []byte) {
for i, b := range x { for i, b := range x {