net: rewrite nsswitch.conf parsing to work like other parsers

Seems simpler than having two different parsing mechanisms.

Change-Id: I4f8468bc025f8e03f59ec9c79b17721581b64eed
Reviewed-on: https://go-review.googlesource.com/c/go/+/448855
Reviewed-by: Ian Lance Taylor <iant@google.com>
Reviewed-by: Joedian Reid <joedian@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Auto-Submit: Ian Lance Taylor <iant@google.com>
Run-TryBot: Ian Lance Taylor <iant@google.com>
This commit is contained in:
Ian Lance Taylor 2022-11-08 15:08:48 -08:00 committed by Gopher Robot
parent 678cd71d11
commit 37ca171ce7
4 changed files with 81 additions and 89 deletions

View File

@ -8,7 +8,7 @@ package net
import (
"io/fs"
"strings"
"os"
"testing"
"time"
)
@ -19,7 +19,19 @@ type nssHostTest struct {
want hostLookupOrder
}
func nssStr(s string) *nssConf { return parseNSSConf(strings.NewReader(s)) }
func nssStr(t *testing.T, s string) *nssConf {
f, err := os.CreateTemp(t.TempDir(), "nss")
if err != nil {
t.Fatal(err)
}
if _, err := f.WriteString(s); err != nil {
t.Fatal(err)
}
if err := f.Close(); err != nil {
t.Fatal(err)
}
return parseNSSConfFile(f.Name())
}
// represents a dnsConfig returned by parsing a nonexistent resolv.conf
var defaultResolvConf = &dnsConfig{
@ -45,7 +57,7 @@ func TestConfHostLookupOrder(t *testing.T) {
forceCgoLookupHost: true,
},
resolv: defaultResolvConf,
nss: nssStr("foo: bar"),
nss: nssStr(t, "foo: bar"),
hostTests: []nssHostTest{
{"foo.local", "myhostname", hostLookupCgo},
{"google.com", "myhostname", hostLookupCgo},
@ -57,7 +69,7 @@ func TestConfHostLookupOrder(t *testing.T) {
netGo: true,
},
resolv: defaultResolvConf,
nss: nssStr("hosts: dns files"),
nss: nssStr(t, "hosts: dns files"),
hostTests: []nssHostTest{
{"x.com", "myhostname", hostLookupDNSFiles},
},
@ -68,7 +80,7 @@ func TestConfHostLookupOrder(t *testing.T) {
netGo: true,
},
resolv: defaultResolvConf,
nss: nssStr("hosts: dns files something_custom"),
nss: nssStr(t, "hosts: dns files something_custom"),
hostTests: []nssHostTest{
{"x.com", "myhostname", hostLookupFilesDNS},
},
@ -77,7 +89,7 @@ func TestConfHostLookupOrder(t *testing.T) {
name: "ubuntu_trusty_avahi",
c: &conf{},
resolv: defaultResolvConf,
nss: nssStr("hosts: files mdns4_minimal [NOTFOUND=return] dns mdns4"),
nss: nssStr(t, "hosts: files mdns4_minimal [NOTFOUND=return] dns mdns4"),
hostTests: []nssHostTest{
{"foo.local", "myhostname", hostLookupCgo},
{"foo.local.", "myhostname", hostLookupCgo},
@ -92,7 +104,7 @@ func TestConfHostLookupOrder(t *testing.T) {
goos: "freebsd",
},
resolv: defaultResolvConf,
nss: nssStr("foo: bar"),
nss: nssStr(t, "foo: bar"),
hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupFilesDNS}},
},
// On OpenBSD, no resolv.conf means no DNS.
@ -187,14 +199,14 @@ func TestConfHostLookupOrder(t *testing.T) {
goos: "linux",
},
resolv: defaultResolvConf,
nss: nssStr(""),
nss: nssStr(t, ""),
hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupFilesDNS}},
},
{
name: "files_mdns_dns",
c: &conf{},
resolv: defaultResolvConf,
nss: nssStr("hosts: files mdns dns"),
nss: nssStr(t, "hosts: files mdns dns"),
hostTests: []nssHostTest{
{"x.com", "myhostname", hostLookupFilesDNS},
{"x.local", "myhostname", hostLookupCgo},
@ -204,7 +216,7 @@ func TestConfHostLookupOrder(t *testing.T) {
name: "dns_special_hostnames",
c: &conf{},
resolv: defaultResolvConf,
nss: nssStr("hosts: dns"),
nss: nssStr(t, "hosts: dns"),
hostTests: []nssHostTest{
{"x.com", "myhostname", hostLookupDNS},
{"x\\.com", "myhostname", hostLookupCgo}, // punt on weird glibc escape
@ -217,7 +229,7 @@ func TestConfHostLookupOrder(t *testing.T) {
hasMDNSAllow: true,
},
resolv: defaultResolvConf,
nss: nssStr("hosts: files mdns dns"),
nss: nssStr(t, "hosts: files mdns dns"),
hostTests: []nssHostTest{
{"x.com", "myhostname", hostLookupCgo},
{"x.local", "myhostname", hostLookupCgo},
@ -227,7 +239,7 @@ func TestConfHostLookupOrder(t *testing.T) {
name: "files_dns",
c: &conf{},
resolv: defaultResolvConf,
nss: nssStr("hosts: files dns"),
nss: nssStr(t, "hosts: files dns"),
hostTests: []nssHostTest{
{"x.com", "myhostname", hostLookupFilesDNS},
{"x", "myhostname", hostLookupFilesDNS},
@ -238,7 +250,7 @@ func TestConfHostLookupOrder(t *testing.T) {
name: "dns_files",
c: &conf{},
resolv: defaultResolvConf,
nss: nssStr("hosts: dns files"),
nss: nssStr(t, "hosts: dns files"),
hostTests: []nssHostTest{
{"x.com", "myhostname", hostLookupDNSFiles},
{"x", "myhostname", hostLookupDNSFiles},
@ -249,7 +261,7 @@ func TestConfHostLookupOrder(t *testing.T) {
name: "something_custom",
c: &conf{},
resolv: defaultResolvConf,
nss: nssStr("hosts: dns files something_custom"),
nss: nssStr(t, "hosts: dns files something_custom"),
hostTests: []nssHostTest{
{"x.com", "myhostname", hostLookupCgo},
},
@ -258,7 +270,7 @@ func TestConfHostLookupOrder(t *testing.T) {
name: "myhostname",
c: &conf{},
resolv: defaultResolvConf,
nss: nssStr("hosts: files dns myhostname"),
nss: nssStr(t, "hosts: files dns myhostname"),
hostTests: []nssHostTest{
{"x.com", "myhostname", hostLookupFilesDNS},
{"myhostname", "myhostname", hostLookupCgo},
@ -285,7 +297,7 @@ func TestConfHostLookupOrder(t *testing.T) {
name: "ubuntu14.04.02",
c: &conf{},
resolv: defaultResolvConf,
nss: nssStr("hosts: files myhostname mdns4_minimal [NOTFOUND=return] dns mdns4"),
nss: nssStr(t, "hosts: files myhostname mdns4_minimal [NOTFOUND=return] dns mdns4"),
hostTests: []nssHostTest{
{"x.com", "myhostname", hostLookupFilesDNS},
{"somehostname", "myhostname", hostLookupFilesDNS},
@ -300,7 +312,7 @@ func TestConfHostLookupOrder(t *testing.T) {
name: "debian_squeeze",
c: &conf{},
resolv: defaultResolvConf,
nss: nssStr("hosts: dns [success=return notfound=continue unavail=continue tryagain=continue] files [notfound=return]"),
nss: nssStr(t, "hosts: dns [success=return notfound=continue unavail=continue tryagain=continue] files [notfound=return]"),
hostTests: []nssHostTest{
{"x.com", "myhostname", hostLookupDNSFiles},
{"somehostname", "myhostname", hostLookupDNSFiles},
@ -310,7 +322,7 @@ func TestConfHostLookupOrder(t *testing.T) {
name: "resolv.conf-unknown",
c: &conf{},
resolv: &dnsConfig{servers: defaultNS, ndots: 1, timeout: 5, attempts: 2, unknownOpt: true},
nss: nssStr("foo: bar"),
nss: nssStr(t, "foo: bar"),
hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupCgo}},
},
// Android should always use cgo.
@ -320,7 +332,7 @@ func TestConfHostLookupOrder(t *testing.T) {
goos: "android",
},
resolv: defaultResolvConf,
nss: nssStr(""),
nss: nssStr(t, ""),
hostTests: []nssHostTest{
{"x.com", "myhostname", hostLookupCgo},
},
@ -335,7 +347,7 @@ func TestConfHostLookupOrder(t *testing.T) {
netCgo: true,
},
resolv: defaultResolvConf,
nss: nssStr(""),
nss: nssStr(t, ""),
hostTests: []nssHostTest{
{"localhost", "myhostname", hostLookupFilesDNS},
},

View File

@ -7,7 +7,6 @@ package net
import (
"errors"
"internal/bytealg"
"io"
"os"
"sync"
"time"
@ -148,63 +147,62 @@ func (c nssCriterion) standardStatusAction(last bool) bool {
}
func parseNSSConfFile(file string) *nssConf {
f, err := os.Open(file)
f, err := open(file)
if err != nil {
return &nssConf{err: err}
}
defer f.Close()
stat, err := f.Stat()
defer f.close()
mtime, _, err := f.stat()
if err != nil {
return &nssConf{err: err}
}
conf := parseNSSConf(f)
conf.mtime = stat.ModTime()
conf.mtime = mtime
return conf
}
func parseNSSConf(r io.Reader) *nssConf {
slurp, err := readFull(r)
if err != nil {
return &nssConf{err: err}
}
func parseNSSConf(f *file) *nssConf {
conf := new(nssConf)
conf.err = foreachLine(slurp, func(line []byte) error {
for line, ok := f.readLine(); ok; line, ok = f.readLine() {
line = trimSpace(removeComment(line))
if len(line) == 0 {
return nil
continue
}
colon := bytealg.IndexByte(line, ':')
colon := bytealg.IndexByteString(line, ':')
if colon == -1 {
return errors.New("no colon on line")
conf.err = errors.New("no colon on line")
return conf
}
db := string(trimSpace(line[:colon]))
db := trimSpace(line[:colon])
srcs := line[colon+1:]
for {
srcs = trimSpace(srcs)
if len(srcs) == 0 {
break
}
sp := bytealg.IndexByte(srcs, ' ')
sp := bytealg.IndexByteString(srcs, ' ')
var src string
if sp == -1 {
src = string(srcs)
srcs = nil // done
src = srcs
srcs = "" // done
} else {
src = string(srcs[:sp])
src = srcs[:sp]
srcs = trimSpace(srcs[sp+1:])
}
var criteria []nssCriterion
// See if there's a criteria block in brackets.
if len(srcs) > 0 && srcs[0] == '[' {
bclose := bytealg.IndexByte(srcs, ']')
bclose := bytealg.IndexByteString(srcs, ']')
if bclose == -1 {
return errors.New("unclosed criterion bracket")
conf.err = errors.New("unclosed criterion bracket")
return conf
}
var err error
criteria, err = parseCriteria(srcs[1:bclose])
if err != nil {
return errors.New("invalid criteria: " + string(srcs[1:bclose]))
conf.err = errors.New("invalid criteria: " + srcs[1:bclose])
return conf
}
srcs = srcs[bclose+1:]
}
@ -216,14 +214,13 @@ func parseNSSConf(r io.Reader) *nssConf {
criteria: criteria,
})
}
return nil
})
}
return conf
}
// parses "foo=bar !foo=bar"
func parseCriteria(x []byte) (c []nssCriterion, err error) {
err = foreachField(x, func(f []byte) error {
func parseCriteria(x string) (c []nssCriterion, err error) {
err = foreachField(x, func(f string) error {
not := false
if len(f) > 0 && f[0] == '!' {
not = true
@ -232,15 +229,19 @@ func parseCriteria(x []byte) (c []nssCriterion, err error) {
if len(f) < 3 {
return errors.New("criterion too short")
}
eq := bytealg.IndexByte(f, '=')
eq := bytealg.IndexByteString(f, '=')
if eq == -1 {
return errors.New("criterion lacks equal sign")
}
lowerASCIIBytes(f)
if hasUpperCase(f) {
lower := []byte(f)
lowerASCIIBytes(lower)
f = string(lower)
}
c = append(c, nssCriterion{
negate: not,
status: string(f[:eq]),
action: string(f[eq+1:]),
status: f[:eq],
action: f[eq+1:],
})
return nil
})

View File

@ -8,8 +8,8 @@ package net
import (
"reflect"
"strings"
"testing"
"time"
)
const ubuntuTrustyAvahi = `# /etc/nsswitch.conf
@ -34,6 +34,8 @@ netgroup: nis
`
func TestParseNSSConf(t *testing.T) {
t.Parallel()
tests := []struct {
name string
in string
@ -161,7 +163,8 @@ func TestParseNSSConf(t *testing.T) {
}
for _, tt := range tests {
gotConf := parseNSSConf(strings.NewReader(tt.in))
gotConf := nssStr(t, tt.in)
gotConf.mtime = time.Time{} // ignore mtime in comparison
if !reflect.DeepEqual(gotConf, tt.want) {
t.Errorf("%s: mismatch\n got %#v\nwant %#v", tt.name, gotConf, tt.want)
}

View File

@ -64,6 +64,14 @@ func (f *file) readLine() (s string, ok bool) {
return
}
func (f *file) stat() (mtime time.Time, size int64, err error) {
st, err := f.file.Stat()
if err != nil {
return time.Time{}, 0, err
}
return st.ModTime(), st.Size(), nil
}
func open(name string) (*file, error) {
fd, err := os.Open(name)
if err != nil {
@ -236,7 +244,7 @@ func lowerASCII(b byte) byte {
}
// trimSpace returns x without any leading or trailing ASCII whitespace.
func trimSpace(x []byte) []byte {
func trimSpace(x string) string {
for len(x) > 0 && isSpace(x[0]) {
x = x[1:]
}
@ -253,37 +261,19 @@ func isSpace(b byte) bool {
// removeComment returns line, removing any '#' byte and any following
// bytes.
func removeComment(line []byte) []byte {
if i := bytealg.IndexByte(line, '#'); i != -1 {
func removeComment(line string) string {
if i := bytealg.IndexByteString(line, '#'); i != -1 {
return line[:i]
}
return line
}
// foreachLine runs fn on each line of x.
// Each line (except for possibly the last) ends in '\n'.
// It returns the first non-nil error returned by fn.
func foreachLine(x []byte, fn func(line []byte) error) error {
for len(x) > 0 {
nl := bytealg.IndexByte(x, '\n')
if nl == -1 {
return fn(x)
}
line := x[:nl+1]
x = x[nl+1:]
if err := fn(line); err != nil {
return err
}
}
return nil
}
// foreachField runs fn on each non-empty run of non-space bytes in x.
// It returns the first non-nil error returned by fn.
func foreachField(x []byte, fn func(field []byte) error) error {
func foreachField(x string, fn func(field string) error) error {
x = trimSpace(x)
for len(x) > 0 {
sp := bytealg.IndexByte(x, ' ')
sp := bytealg.IndexByteString(x, ' ')
if sp == -1 {
return fn(x)
}
@ -327,17 +317,3 @@ func stringsEqualFold(s, t string) bool {
}
return true
}
func readFull(r io.Reader) (all []byte, err error) {
buf := make([]byte, 1024)
for {
n, err := r.Read(buf)
all = append(all, buf[:n]...)
if err == io.EOF {
return all, nil
}
if err != nil {
return nil, err
}
}
}