mirror of https://github.com/golang/go.git
net,os: arrange zero-copy of os.File and net.TCPConn to net.UnixConn
Fixes #58808 goos: linux goarch: amd64 pkg: net cpu: DO-Premium-Intel │ old │ new │ │ sec/op │ sec/op vs base │ Splice/tcp-to-unix/1024-4 3.783µ ± 10% 3.201µ ± 7% -15.40% (p=0.001 n=10) Splice/tcp-to-unix/2048-4 3.967µ ± 13% 3.818µ ± 16% ~ (p=0.971 n=10) Splice/tcp-to-unix/4096-4 4.988µ ± 16% 4.590µ ± 11% ~ (p=0.089 n=10) Splice/tcp-to-unix/8192-4 6.981µ ± 13% 5.236µ ± 9% -25.00% (p=0.000 n=10) Splice/tcp-to-unix/16384-4 10.192µ ± 9% 7.350µ ± 7% -27.89% (p=0.000 n=10) Splice/tcp-to-unix/32768-4 19.65µ ± 13% 10.28µ ± 16% -47.69% (p=0.000 n=10) Splice/tcp-to-unix/65536-4 41.89µ ± 18% 15.70µ ± 13% -62.52% (p=0.000 n=10) Splice/tcp-to-unix/131072-4 90.05µ ± 11% 29.55µ ± 10% -67.18% (p=0.000 n=10) Splice/tcp-to-unix/262144-4 170.24µ ± 15% 52.66µ ± 4% -69.06% (p=0.000 n=10) Splice/tcp-to-unix/524288-4 326.4µ ± 13% 109.3µ ± 11% -66.52% (p=0.000 n=10) Splice/tcp-to-unix/1048576-4 651.4µ ± 9% 228.3µ ± 14% -64.95% (p=0.000 n=10) geomean 29.42µ 15.62µ -46.90% │ old │ new │ │ B/s │ B/s vs base │ Splice/tcp-to-unix/1024-4 258.2Mi ± 11% 305.2Mi ± 8% +18.21% (p=0.001 n=10) Splice/tcp-to-unix/2048-4 492.5Mi ± 15% 511.7Mi ± 13% ~ (p=0.971 n=10) Splice/tcp-to-unix/4096-4 783.5Mi ± 14% 851.2Mi ± 12% ~ (p=0.089 n=10) Splice/tcp-to-unix/8192-4 1.093Gi ± 11% 1.458Gi ± 8% +33.36% (p=0.000 n=10) Splice/tcp-to-unix/16384-4 1.497Gi ± 9% 2.076Gi ± 7% +38.67% (p=0.000 n=10) Splice/tcp-to-unix/32768-4 1.553Gi ± 11% 2.969Gi ± 14% +91.17% (p=0.000 n=10) Splice/tcp-to-unix/65536-4 1.458Gi ± 23% 3.888Gi ± 11% +166.69% (p=0.000 n=10) Splice/tcp-to-unix/131072-4 1.356Gi ± 10% 4.131Gi ± 9% +204.72% (p=0.000 n=10) Splice/tcp-to-unix/262144-4 1.434Gi ± 13% 4.637Gi ± 4% +223.32% (p=0.000 n=10) Splice/tcp-to-unix/524288-4 1.497Gi ± 15% 4.468Gi ± 10% +198.47% (p=0.000 n=10) Splice/tcp-to-unix/1048576-4 1.501Gi ± 10% 4.277Gi ± 16% +184.88% (p=0.000 n=10) geomean 1.038Gi 1.954Gi +88.28% │ old │ new │ │ B/op │ B/op vs base │ Splice/tcp-to-unix/1024-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/2048-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/4096-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/8192-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/16384-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/32768-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/65536-4 1.000 ± ? 0.000 ± 0% -100.00% (p=0.001 n=10) Splice/tcp-to-unix/131072-4 2.000 ± 0% 0.000 ± 0% -100.00% (p=0.000 n=10) Splice/tcp-to-unix/262144-4 4.000 ± 25% 0.000 ± 0% -100.00% (p=0.000 n=10) Splice/tcp-to-unix/524288-4 7.500 ± 33% 0.000 ± 0% -100.00% (p=0.000 n=10) Splice/tcp-to-unix/1048576-4 17.00 ± 12% 0.00 ± 0% -100.00% (p=0.000 n=10) geomean ² ? ² ³ ¹ all samples are equal ² summaries must be >0 to compute geomean ³ ratios must be >0 to compute geomean │ old │ new │ │ allocs/op │ allocs/op vs base │ Splice/tcp-to-unix/1024-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/2048-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/4096-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/8192-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/16384-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/32768-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/65536-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/131072-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/262144-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/524288-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/1048576-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ geomean ² +0.00% ² ¹ all samples are equal ² summaries must be >0 to compute geomean Change-Id: I829061b009a0929a8ef1a15c183793c0b9104dde Reviewed-on: https://go-review.googlesource.com/c/go/+/472475 Reviewed-by: Damien Neil <dneil@google.com> Reviewed-by: Bryan Mills <bcmills@google.com> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
This commit is contained in:
parent
f67b2d8f0b
commit
f664031bc1
|
|
@ -0,0 +1,2 @@
|
|||
pkg net, method (*TCPConn) WriteTo(io.Writer) (int64, error) #58808
|
||||
pkg os, method (*File) WriteTo(io.Writer) (int64, error) #58808
|
||||
|
|
@ -81,3 +81,14 @@ func consume(v *[][]byte, n int64) {
|
|||
|
||||
// TestHookDidWritev is a hook for testing writev.
|
||||
var TestHookDidWritev = func(wrote int) {}
|
||||
|
||||
// String is an internal string definition for methods/functions
|
||||
// that is not intended for use outside the standard libraries.
|
||||
//
|
||||
// Other packages in std that import internal/poll and have some
|
||||
// exported APIs (now we've got some in net.rawConn) which are only used
|
||||
// internally and are not intended to be used outside the standard libraries,
|
||||
// Therefore, we make those APIs use internal types like poll.FD or poll.String
|
||||
// in their function signatures to disable the usability of these APIs from
|
||||
// external codebase.
|
||||
type String string
|
||||
|
|
|
|||
|
|
@ -264,6 +264,12 @@ func TestTransferWriterWriteBodyReaderTypes(t *testing.T) {
|
|||
actualReader = reflect.TypeOf(lr.R)
|
||||
} else {
|
||||
actualReader = reflect.TypeOf(mw.CalledReader)
|
||||
// We have to handle this special case for genericWriteTo in os,
|
||||
// this struct is introduced to support a zero-copy optimization,
|
||||
// check out https://go.dev/issue/58808 for details.
|
||||
if actualReader.Kind() == reflect.Struct && actualReader.PkgPath() == "os" && actualReader.Name() == "fileWithoutWriteTo" {
|
||||
actualReader = actualReader.Field(1).Type
|
||||
}
|
||||
}
|
||||
|
||||
if tc.expectedReader != actualReader {
|
||||
|
|
|
|||
|
|
@ -664,15 +664,53 @@ var errClosed = poll.ErrNetClosing
|
|||
// errors.Is(err, net.ErrClosed).
|
||||
var ErrClosed error = errClosed
|
||||
|
||||
type writerOnly struct {
|
||||
io.Writer
|
||||
// noReadFrom can be embedded alongside another type to
|
||||
// hide the ReadFrom method of that other type.
|
||||
type noReadFrom struct{}
|
||||
|
||||
// ReadFrom hides another ReadFrom method.
|
||||
// It should never be called.
|
||||
func (noReadFrom) ReadFrom(io.Reader) (int64, error) {
|
||||
panic("can't happen")
|
||||
}
|
||||
|
||||
// tcpConnWithoutReadFrom implements all the methods of *TCPConn other
|
||||
// than ReadFrom. This is used to permit ReadFrom to call io.Copy
|
||||
// without leading to a recursive call to ReadFrom.
|
||||
type tcpConnWithoutReadFrom struct {
|
||||
noReadFrom
|
||||
*TCPConn
|
||||
}
|
||||
|
||||
// Fallback implementation of io.ReaderFrom's ReadFrom, when sendfile isn't
|
||||
// applicable.
|
||||
func genericReadFrom(w io.Writer, r io.Reader) (n int64, err error) {
|
||||
func genericReadFrom(c *TCPConn, r io.Reader) (n int64, err error) {
|
||||
// Use wrapper to hide existing r.ReadFrom from io.Copy.
|
||||
return io.Copy(writerOnly{w}, r)
|
||||
return io.Copy(tcpConnWithoutReadFrom{TCPConn: c}, r)
|
||||
}
|
||||
|
||||
// noWriteTo can be embedded alongside another type to
|
||||
// hide the WriteTo method of that other type.
|
||||
type noWriteTo struct{}
|
||||
|
||||
// WriteTo hides another WriteTo method.
|
||||
// It should never be called.
|
||||
func (noWriteTo) WriteTo(io.Writer) (int64, error) {
|
||||
panic("can't happen")
|
||||
}
|
||||
|
||||
// tcpConnWithoutWriteTo implements all the methods of *TCPConn other
|
||||
// than WriteTo. This is used to permit WriteTo to call io.Copy
|
||||
// without leading to a recursive call to WriteTo.
|
||||
type tcpConnWithoutWriteTo struct {
|
||||
noWriteTo
|
||||
*TCPConn
|
||||
}
|
||||
|
||||
// Fallback implementation of io.WriterTo's WriteTo, when zero-copy isn't applicable.
|
||||
func genericWriteTo(c *TCPConn, w io.Writer) (n int64, err error) {
|
||||
// Use wrapper to hide existing w.WriteTo from io.Copy.
|
||||
return io.Copy(w, tcpConnWithoutWriteTo{TCPConn: c})
|
||||
}
|
||||
|
||||
// Limit the number of concurrent cgo-using goroutines, because
|
||||
|
|
|
|||
|
|
@ -79,6 +79,17 @@ func newRawConn(fd *netFD) *rawConn {
|
|||
return &rawConn{fd: fd}
|
||||
}
|
||||
|
||||
// Network returns the network type of the underlying connection.
|
||||
//
|
||||
// Other packages in std that import internal/poll and are unable to
|
||||
// import net (such as os) can use a type assertion to access this
|
||||
// extension method so that they can distinguish different socket types.
|
||||
//
|
||||
// Network is not intended for use outside the standard library.
|
||||
func (c *rawConn) Network() poll.String {
|
||||
return poll.String(c.fd.net)
|
||||
}
|
||||
|
||||
type rawListener struct {
|
||||
rawConn
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,29 +14,36 @@ import (
|
|||
)
|
||||
|
||||
func BenchmarkSendFile(b *testing.B) {
|
||||
b.Run("file-to-tcp", func(b *testing.B) { benchmarkSendFile(b, "tcp") })
|
||||
b.Run("file-to-unix", func(b *testing.B) { benchmarkSendFile(b, "unix") })
|
||||
}
|
||||
|
||||
func benchmarkSendFile(b *testing.B, proto string) {
|
||||
for i := 0; i <= 10; i++ {
|
||||
size := 1 << (i + 10)
|
||||
bench := sendFileBench{chunkSize: size}
|
||||
bench := sendFileBench{
|
||||
proto: proto,
|
||||
chunkSize: size,
|
||||
}
|
||||
b.Run(strconv.Itoa(size), bench.benchSendFile)
|
||||
}
|
||||
}
|
||||
|
||||
type sendFileBench struct {
|
||||
proto string
|
||||
chunkSize int
|
||||
}
|
||||
|
||||
func (bench sendFileBench) benchSendFile(b *testing.B) {
|
||||
fileSize := b.N * bench.chunkSize
|
||||
f := createTempFile(b, fileSize)
|
||||
fileName := f.Name()
|
||||
defer os.Remove(fileName)
|
||||
defer f.Close()
|
||||
|
||||
client, server := spliceTestSocketPair(b, "tcp")
|
||||
client, server := spliceTestSocketPair(b, bench.proto)
|
||||
defer server.Close()
|
||||
|
||||
cleanUp, err := startSpliceClient(client, "r", bench.chunkSize, fileSize)
|
||||
if err != nil {
|
||||
client.Close()
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer cleanUp()
|
||||
|
|
@ -51,15 +58,18 @@ func (bench sendFileBench) benchSendFile(b *testing.B) {
|
|||
b.Fatalf("failed to copy data with sendfile, error: %v", err)
|
||||
}
|
||||
if sent != int64(fileSize) {
|
||||
b.Fatalf("bytes sent mismatch\n\texpect: %d\n\tgot: %d", fileSize, sent)
|
||||
b.Fatalf("bytes sent mismatch, got: %d, want: %d", sent, fileSize)
|
||||
}
|
||||
}
|
||||
|
||||
func createTempFile(b *testing.B, size int) *os.File {
|
||||
f, err := os.CreateTemp("", "linux-sendfile-test")
|
||||
f, err := os.CreateTemp(b.TempDir(), "linux-sendfile-bench")
|
||||
if err != nil {
|
||||
b.Fatalf("failed to create temporary file: %v", err)
|
||||
}
|
||||
b.Cleanup(func() {
|
||||
f.Close()
|
||||
})
|
||||
|
||||
data := make([]byte, size)
|
||||
if _, err := f.Write(data); err != nil {
|
||||
|
|
|
|||
|
|
@ -9,12 +9,12 @@ import (
|
|||
"io"
|
||||
)
|
||||
|
||||
// splice transfers data from r to c using the splice system call to minimize
|
||||
// copies from and to userspace. c must be a TCP connection. Currently, splice
|
||||
// is only enabled if r is a TCP or a stream-oriented Unix connection.
|
||||
// spliceFrom transfers data from r to c using the splice system call to minimize
|
||||
// copies from and to userspace. c must be a TCP connection.
|
||||
// Currently, spliceFrom is only enabled if r is a TCP or a stream-oriented Unix connection.
|
||||
//
|
||||
// If splice returns handled == false, it has performed no work.
|
||||
func splice(c *netFD, r io.Reader) (written int64, err error, handled bool) {
|
||||
// If spliceFrom returns handled == false, it has performed no work.
|
||||
func spliceFrom(c *netFD, r io.Reader) (written int64, err error, handled bool) {
|
||||
var remain int64 = 1<<63 - 1 // by default, copy until EOF
|
||||
lr, ok := r.(*io.LimitedReader)
|
||||
if ok {
|
||||
|
|
@ -25,14 +25,17 @@ func splice(c *netFD, r io.Reader) (written int64, err error, handled bool) {
|
|||
}
|
||||
|
||||
var s *netFD
|
||||
if tc, ok := r.(*TCPConn); ok {
|
||||
s = tc.fd
|
||||
} else if uc, ok := r.(*UnixConn); ok {
|
||||
if uc.fd.net != "unix" {
|
||||
switch v := r.(type) {
|
||||
case *TCPConn:
|
||||
s = v.fd
|
||||
case tcpConnWithoutWriteTo:
|
||||
s = v.fd
|
||||
case *UnixConn:
|
||||
if v.fd.net != "unix" {
|
||||
return 0, nil, false
|
||||
}
|
||||
s = uc.fd
|
||||
} else {
|
||||
s = v.fd
|
||||
default:
|
||||
return 0, nil, false
|
||||
}
|
||||
|
||||
|
|
@ -42,3 +45,18 @@ func splice(c *netFD, r io.Reader) (written int64, err error, handled bool) {
|
|||
}
|
||||
return written, wrapSyscallError(sc, err), handled
|
||||
}
|
||||
|
||||
// spliceTo transfers data from c to w using the splice system call to minimize
|
||||
// copies from and to userspace. c must be a TCP connection.
|
||||
// Currently, spliceTo is only enabled if w is a stream-oriented Unix connection.
|
||||
//
|
||||
// If spliceTo returns handled == false, it has performed no work.
|
||||
func spliceTo(w io.Writer, c *netFD) (written int64, err error, handled bool) {
|
||||
uc, ok := w.(*UnixConn)
|
||||
if !ok || uc.fd.net != "unix" {
|
||||
return
|
||||
}
|
||||
|
||||
written, handled, sc, err := poll.Splice(&uc.fd.pfd, &c.pfd, 1<<63-1)
|
||||
return written, wrapSyscallError(sc, err), handled
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,10 @@ package net
|
|||
|
||||
import "io"
|
||||
|
||||
func splice(c *netFD, r io.Reader) (int64, error, bool) {
|
||||
func spliceFrom(_ *netFD, _ io.Reader) (int64, error, bool) {
|
||||
return 0, nil, false
|
||||
}
|
||||
|
||||
func spliceTo(_ io.Writer, _ *netFD) (int64, error, bool) {
|
||||
return 0, nil, false
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ func TestSplice(t *testing.T) {
|
|||
t.Skip("skipping unix-to-tcp tests")
|
||||
}
|
||||
t.Run("unix-to-tcp", func(t *testing.T) { testSplice(t, "unix", "tcp") })
|
||||
t.Run("tcp-to-unix", func(t *testing.T) { testSplice(t, "tcp", "unix") })
|
||||
t.Run("tcp-to-file", func(t *testing.T) { testSpliceToFile(t, "tcp", "file") })
|
||||
t.Run("unix-to-file", func(t *testing.T) { testSpliceToFile(t, "unix", "file") })
|
||||
t.Run("no-unixpacket", testSpliceNoUnixpacket)
|
||||
|
|
@ -159,6 +160,13 @@ func (tc spliceTestCase) testFile(t *testing.T) {
|
|||
}
|
||||
|
||||
func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
|
||||
// UnixConn doesn't implement io.ReaderFrom, which will fail
|
||||
// the following test in asserting a UnixConn to be an io.ReaderFrom,
|
||||
// so skip this test.
|
||||
if upNet == "unix" || downNet == "unix" {
|
||||
t.Skip("skipping test on unix socket")
|
||||
}
|
||||
|
||||
clientUp, serverUp := spliceTestSocketPair(t, upNet)
|
||||
defer clientUp.Close()
|
||||
clientDown, serverDown := spliceTestSocketPair(t, downNet)
|
||||
|
|
@ -166,16 +174,16 @@ func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
|
|||
|
||||
serverUp.Close()
|
||||
|
||||
// We'd like to call net.splice here and check the handled return
|
||||
// We'd like to call net.spliceFrom here and check the handled return
|
||||
// value, but we disable splice on old Linux kernels.
|
||||
//
|
||||
// In that case, poll.Splice and net.splice return a non-nil error
|
||||
// In that case, poll.Splice and net.spliceFrom return a non-nil error
|
||||
// and handled == false. We'd ideally like to see handled == true
|
||||
// because the source reader is at EOF, but if we're running on an old
|
||||
// kernel, and splice is disabled, we won't see EOF from net.splice,
|
||||
// kernel, and splice is disabled, we won't see EOF from net.spliceFrom,
|
||||
// because we won't touch the reader at all.
|
||||
//
|
||||
// Trying to untangle the errors from net.splice and match them
|
||||
// Trying to untangle the errors from net.spliceFrom and match them
|
||||
// against the errors created by the poll package would be brittle,
|
||||
// so this is a higher level test.
|
||||
//
|
||||
|
|
@ -268,7 +276,7 @@ func testSpliceNoUnixpacket(t *testing.T) {
|
|||
//
|
||||
// What we want is err == nil and handled == false, i.e. we never
|
||||
// called poll.Splice, because we know the unix socket's network.
|
||||
_, err, handled := splice(serverDown.(*TCPConn).fd, serverUp)
|
||||
_, err, handled := spliceFrom(serverDown.(*TCPConn).fd, serverUp)
|
||||
if err != nil || handled != false {
|
||||
t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
|
||||
}
|
||||
|
|
@ -289,7 +297,7 @@ func testSpliceNoUnixgram(t *testing.T) {
|
|||
defer clientDown.Close()
|
||||
defer serverDown.Close()
|
||||
// Analogous to testSpliceNoUnixpacket.
|
||||
_, err, handled := splice(serverDown.(*TCPConn).fd, up)
|
||||
_, err, handled := spliceFrom(serverDown.(*TCPConn).fd, up)
|
||||
if err != nil || handled != false {
|
||||
t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
|
||||
}
|
||||
|
|
@ -300,6 +308,7 @@ func BenchmarkSplice(b *testing.B) {
|
|||
|
||||
b.Run("tcp-to-tcp", func(b *testing.B) { benchSplice(b, "tcp", "tcp") })
|
||||
b.Run("unix-to-tcp", func(b *testing.B) { benchSplice(b, "unix", "tcp") })
|
||||
b.Run("tcp-to-unix", func(b *testing.B) { benchSplice(b, "tcp", "unix") })
|
||||
}
|
||||
|
||||
func benchSplice(b *testing.B, upNet, downNet string) {
|
||||
|
|
|
|||
|
|
@ -134,6 +134,18 @@ func (c *TCPConn) ReadFrom(r io.Reader) (int64, error) {
|
|||
return n, err
|
||||
}
|
||||
|
||||
// WriteTo implements the io.WriterTo WriteTo method.
|
||||
func (c *TCPConn) WriteTo(w io.Writer) (int64, error) {
|
||||
if !c.ok() {
|
||||
return 0, syscall.EINVAL
|
||||
}
|
||||
n, err := c.writeTo(w)
|
||||
if err != nil && err != io.EOF {
|
||||
err = &OpError{Op: "writeto", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// CloseRead shuts down the reading side of the TCP connection.
|
||||
// Most callers should just use Close.
|
||||
func (c *TCPConn) CloseRead() error {
|
||||
|
|
|
|||
|
|
@ -14,6 +14,10 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
|
|||
return genericReadFrom(c, r)
|
||||
}
|
||||
|
||||
func (c *TCPConn) writeTo(w io.Writer) (int64, error) {
|
||||
return genericWriteTo(c, w)
|
||||
}
|
||||
|
||||
func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
|
||||
if h := sd.testHookDialTCP; h != nil {
|
||||
return h(ctx, sd.network, laddr, raddr)
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ func (a *TCPAddr) toLocal(net string) sockaddr {
|
|||
}
|
||||
|
||||
func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
|
||||
if n, err, handled := splice(c.fd, r); handled {
|
||||
if n, err, handled := spliceFrom(c.fd, r); handled {
|
||||
return n, err
|
||||
}
|
||||
if n, err, handled := sendFile(c.fd, r); handled {
|
||||
|
|
@ -54,6 +54,13 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
|
|||
return genericReadFrom(c, r)
|
||||
}
|
||||
|
||||
func (c *TCPConn) writeTo(w io.Writer) (int64, error) {
|
||||
if n, err, handled := spliceTo(w, c.fd); handled {
|
||||
return n, err
|
||||
}
|
||||
return genericWriteTo(c, w)
|
||||
}
|
||||
|
||||
func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
|
||||
if h := sd.testHookDialTCP; h != nil {
|
||||
return h(ctx, sd.network, laddr, raddr)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@
|
|||
package os
|
||||
|
||||
var (
|
||||
PollCopyFileRangeP = &pollCopyFileRange
|
||||
PollSpliceFile = &pollSplice
|
||||
GetPollFDForTest = getPollFD
|
||||
PollCopyFileRangeP = &pollCopyFileRange
|
||||
PollSpliceFile = &pollSplice
|
||||
PollSendFile = &pollSendFile
|
||||
GetPollFDAndNetwork = getPollFDAndNetwork
|
||||
)
|
||||
|
|
|
|||
|
|
@ -157,20 +157,26 @@ func (f *File) ReadFrom(r io.Reader) (n int64, err error) {
|
|||
return n, f.wrapErr("write", e)
|
||||
}
|
||||
|
||||
func genericReadFrom(f *File, r io.Reader) (int64, error) {
|
||||
return io.Copy(fileWithoutReadFrom{f}, r)
|
||||
// noReadFrom can be embedded alongside another type to
|
||||
// hide the ReadFrom method of that other type.
|
||||
type noReadFrom struct{}
|
||||
|
||||
// ReadFrom hides another ReadFrom method.
|
||||
// It should never be called.
|
||||
func (noReadFrom) ReadFrom(io.Reader) (int64, error) {
|
||||
panic("can't happen")
|
||||
}
|
||||
|
||||
// fileWithoutReadFrom implements all the methods of *File other
|
||||
// than ReadFrom. This is used to permit ReadFrom to call io.Copy
|
||||
// without leading to a recursive call to ReadFrom.
|
||||
type fileWithoutReadFrom struct {
|
||||
noReadFrom
|
||||
*File
|
||||
}
|
||||
|
||||
// This ReadFrom method hides the *File ReadFrom method.
|
||||
func (fileWithoutReadFrom) ReadFrom(fileWithoutReadFrom) {
|
||||
panic("unreachable")
|
||||
func genericReadFrom(f *File, r io.Reader) (int64, error) {
|
||||
return io.Copy(fileWithoutReadFrom{File: f}, r)
|
||||
}
|
||||
|
||||
// Write writes len(b) bytes from b to the File.
|
||||
|
|
@ -229,6 +235,40 @@ func (f *File) WriteAt(b []byte, off int64) (n int, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
// WriteTo implements io.WriterTo.
|
||||
func (f *File) WriteTo(w io.Writer) (n int64, err error) {
|
||||
if err := f.checkValid("read"); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n, handled, e := f.writeTo(w)
|
||||
if handled {
|
||||
return n, f.wrapErr("read", e)
|
||||
}
|
||||
return genericWriteTo(f, w) // without wrapping
|
||||
}
|
||||
|
||||
// noWriteTo can be embedded alongside another type to
|
||||
// hide the WriteTo method of that other type.
|
||||
type noWriteTo struct{}
|
||||
|
||||
// WriteTo hides another WriteTo method.
|
||||
// It should never be called.
|
||||
func (noWriteTo) WriteTo(io.Writer) (int64, error) {
|
||||
panic("can't happen")
|
||||
}
|
||||
|
||||
// fileWithoutWriteTo implements all the methods of *File other
|
||||
// than WriteTo. This is used to permit WriteTo to call io.Copy
|
||||
// without leading to a recursive call to WriteTo.
|
||||
type fileWithoutWriteTo struct {
|
||||
noWriteTo
|
||||
*File
|
||||
}
|
||||
|
||||
func genericWriteTo(f *File, w io.Writer) (int64, error) {
|
||||
return io.Copy(w, fileWithoutWriteTo{File: f})
|
||||
}
|
||||
|
||||
// Seek sets the offset for the next Read or Write on file to offset, interpreted
|
||||
// according to whence: 0 means relative to the origin of the file, 1 means
|
||||
// relative to the current offset, and 2 means relative to the end.
|
||||
|
|
|
|||
|
|
@ -749,12 +749,12 @@ func TestProcCopy(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestGetPollFDFromReader(t *testing.T) {
|
||||
t.Run("tcp", func(t *testing.T) { testGetPollFromReader(t, "tcp") })
|
||||
t.Run("unix", func(t *testing.T) { testGetPollFromReader(t, "unix") })
|
||||
func TestGetPollFDAndNetwork(t *testing.T) {
|
||||
t.Run("tcp4", func(t *testing.T) { testGetPollFDAndNetwork(t, "tcp4") })
|
||||
t.Run("unix", func(t *testing.T) { testGetPollFDAndNetwork(t, "unix") })
|
||||
}
|
||||
|
||||
func testGetPollFromReader(t *testing.T, proto string) {
|
||||
func testGetPollFDAndNetwork(t *testing.T, proto string) {
|
||||
_, server := createSocketPair(t, proto)
|
||||
sc, ok := server.(syscall.Conn)
|
||||
if !ok {
|
||||
|
|
@ -765,12 +765,15 @@ func testGetPollFromReader(t *testing.T, proto string) {
|
|||
t.Fatalf("server SyscallConn error: %v", err)
|
||||
}
|
||||
if err = rc.Control(func(fd uintptr) {
|
||||
pfd := GetPollFDForTest(server)
|
||||
pfd, network := GetPollFDAndNetwork(server)
|
||||
if pfd == nil {
|
||||
t.Fatalf("GetPollFDForTest didn't return poll.FD")
|
||||
t.Fatalf("GetPollFDAndNetwork didn't return poll.FD")
|
||||
}
|
||||
if string(network) != proto {
|
||||
t.Fatalf("GetPollFDAndNetwork returned wrong network, got: %s, want: %s", network, proto)
|
||||
}
|
||||
if pfd.Sysfd != int(fd) {
|
||||
t.Fatalf("GetPollFDForTest returned wrong poll.FD, got: %d, want: %d", pfd.Sysfd, int(fd))
|
||||
t.Fatalf("GetPollFDAndNetwork returned wrong poll.FD, got: %d, want: %d", pfd.Sysfd, int(fd))
|
||||
}
|
||||
if !pfd.IsStream {
|
||||
t.Fatalf("expected IsStream to be true")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,171 @@
|
|||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package os_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"internal/poll"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
. "os"
|
||||
"strconv"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSendFile(t *testing.T) {
|
||||
sizes := []int{
|
||||
1,
|
||||
42,
|
||||
1025,
|
||||
syscall.Getpagesize() + 1,
|
||||
32769,
|
||||
}
|
||||
t.Run("sendfile-to-unix", func(t *testing.T) {
|
||||
for _, size := range sizes {
|
||||
t.Run(strconv.Itoa(size), func(t *testing.T) {
|
||||
testSendFile(t, "unix", int64(size))
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("sendfile-to-tcp", func(t *testing.T) {
|
||||
for _, size := range sizes {
|
||||
t.Run(strconv.Itoa(size), func(t *testing.T) {
|
||||
testSendFile(t, "tcp", int64(size))
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func testSendFile(t *testing.T, proto string, size int64) {
|
||||
dst, src, recv, data, hook := newSendFileTest(t, proto, size)
|
||||
|
||||
// Now call WriteTo (through io.Copy), which will hopefully call poll.SendFile
|
||||
n, err := io.Copy(dst, src)
|
||||
if err != nil {
|
||||
t.Fatalf("io.Copy error: %v", err)
|
||||
}
|
||||
|
||||
// We should have called poll.Splice with the right file descriptor arguments.
|
||||
if n > 0 && !hook.called {
|
||||
t.Fatal("expected to called poll.SendFile")
|
||||
}
|
||||
if hook.called && hook.srcfd != int(src.Fd()) {
|
||||
t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, src.Fd())
|
||||
}
|
||||
sc, ok := dst.(syscall.Conn)
|
||||
if !ok {
|
||||
t.Fatalf("destination is not a syscall.Conn")
|
||||
}
|
||||
rc, err := sc.SyscallConn()
|
||||
if err != nil {
|
||||
t.Fatalf("destination SyscallConn error: %v", err)
|
||||
}
|
||||
if err = rc.Control(func(fd uintptr) {
|
||||
if hook.called && hook.dstfd != int(fd) {
|
||||
t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, int(fd))
|
||||
}
|
||||
}); err != nil {
|
||||
t.Fatalf("destination Conn Control error: %v", err)
|
||||
}
|
||||
|
||||
// Verify the data size and content.
|
||||
dataSize := len(data)
|
||||
dstData := make([]byte, dataSize)
|
||||
m, err := io.ReadFull(recv, dstData)
|
||||
if err != nil {
|
||||
t.Fatalf("server Conn Read error: %v", err)
|
||||
}
|
||||
if n != int64(dataSize) {
|
||||
t.Fatalf("data length mismatch for io.Copy, got %d, want %d", n, dataSize)
|
||||
}
|
||||
if m != dataSize {
|
||||
t.Fatalf("data length mismatch for net.Conn.Read, got %d, want %d", m, dataSize)
|
||||
}
|
||||
if !bytes.Equal(dstData, data) {
|
||||
t.Errorf("data mismatch, got %s, want %s", dstData, data)
|
||||
}
|
||||
}
|
||||
|
||||
// newSendFileTest initializes a new test for sendfile.
|
||||
//
|
||||
// It creates source file and destination sockets, and populates the source file
|
||||
// with random data of the specified size. It also hooks package os' call
|
||||
// to poll.Sendfile and returns the hook so it can be inspected.
|
||||
func newSendFileTest(t *testing.T, proto string, size int64) (net.Conn, *File, net.Conn, []byte, *sendFileHook) {
|
||||
t.Helper()
|
||||
|
||||
hook := hookSendFile(t)
|
||||
|
||||
client, server := createSocketPair(t, proto)
|
||||
tempFile, data := createTempFile(t, size)
|
||||
|
||||
return client, tempFile, server, data, hook
|
||||
}
|
||||
|
||||
func hookSendFile(t *testing.T) *sendFileHook {
|
||||
h := new(sendFileHook)
|
||||
h.install()
|
||||
t.Cleanup(h.uninstall)
|
||||
return h
|
||||
}
|
||||
|
||||
type sendFileHook struct {
|
||||
called bool
|
||||
dstfd int
|
||||
srcfd int
|
||||
remain int64
|
||||
|
||||
written int64
|
||||
handled bool
|
||||
err error
|
||||
|
||||
original func(dst *poll.FD, src int, remain int64) (int64, error, bool)
|
||||
}
|
||||
|
||||
func (h *sendFileHook) install() {
|
||||
h.original = *PollSendFile
|
||||
*PollSendFile = func(dst *poll.FD, src int, remain int64) (int64, error, bool) {
|
||||
h.called = true
|
||||
h.dstfd = dst.Sysfd
|
||||
h.srcfd = src
|
||||
h.remain = remain
|
||||
h.written, h.err, h.handled = h.original(dst, src, remain)
|
||||
return h.written, h.err, h.handled
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sendFileHook) uninstall() {
|
||||
*PollSendFile = h.original
|
||||
}
|
||||
|
||||
func createTempFile(t *testing.T, size int64) (*File, []byte) {
|
||||
f, err := CreateTemp(t.TempDir(), "writeto-sendfile-to-socket")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temporary file: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
f.Close()
|
||||
})
|
||||
|
||||
randSeed := time.Now().Unix()
|
||||
t.Logf("random data seed: %d\n", randSeed)
|
||||
prng := rand.New(rand.NewSource(randSeed))
|
||||
data := make([]byte, size)
|
||||
prng.Read(data)
|
||||
if _, err := f.Write(data); err != nil {
|
||||
t.Fatalf("failed to create and feed the file: %v", err)
|
||||
}
|
||||
if err := f.Sync(); err != nil {
|
||||
t.Fatalf("failed to save the file: %v", err)
|
||||
}
|
||||
if _, err := f.Seek(0, io.SeekStart); err != nil {
|
||||
t.Fatalf("failed to rewind the file: %v", err)
|
||||
}
|
||||
|
||||
return f, data
|
||||
}
|
||||
|
|
@ -13,8 +13,33 @@ import (
|
|||
var (
|
||||
pollCopyFileRange = poll.CopyFileRange
|
||||
pollSplice = poll.Splice
|
||||
pollSendFile = poll.SendFile
|
||||
)
|
||||
|
||||
func (f *File) writeTo(w io.Writer) (written int64, handled bool, err error) {
|
||||
pfd, network := getPollFDAndNetwork(w)
|
||||
// TODO(panjf2000): same as File.spliceToFile.
|
||||
if pfd == nil || !pfd.IsStream || !isUnixOrTCP(string(network)) {
|
||||
return
|
||||
}
|
||||
|
||||
sc, err := f.SyscallConn()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
rerr := sc.Read(func(fd uintptr) (done bool) {
|
||||
written, err, handled = pollSendFile(pfd, int(fd), 1<<63-1)
|
||||
return true
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
err = rerr
|
||||
}
|
||||
|
||||
return written, handled, wrapSyscallError("sendfile", err)
|
||||
}
|
||||
|
||||
func (f *File) readFrom(r io.Reader) (written int64, handled bool, err error) {
|
||||
// Neither copy_file_range(2) nor splice(2) supports destinations opened with
|
||||
// O_APPEND, so don't bother to try zero-copy with these system calls.
|
||||
|
|
@ -41,7 +66,7 @@ func (f *File) spliceToFile(r io.Reader) (written int64, handled bool, err error
|
|||
return 0, true, nil
|
||||
}
|
||||
|
||||
pfd := getPollFD(r)
|
||||
pfd, _ := getPollFDAndNetwork(r)
|
||||
// TODO(panjf2000): run some tests to see if we should unlock the non-streams for splice.
|
||||
// Streams benefit the most from the splice(2), non-streams are not even supported in old kernels
|
||||
// where splice(2) will just return EINVAL; newer kernels support non-streams like UDP, but I really
|
||||
|
|
@ -63,25 +88,6 @@ func (f *File) spliceToFile(r io.Reader) (written int64, handled bool, err error
|
|||
return written, handled, wrapSyscallError(syscallName, err)
|
||||
}
|
||||
|
||||
// getPollFD tries to get the poll.FD from the given io.Reader by expecting
|
||||
// the underlying type of r to be the implementation of syscall.Conn that contains
|
||||
// a *net.rawConn.
|
||||
func getPollFD(r io.Reader) *poll.FD {
|
||||
sc, ok := r.(syscall.Conn)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
rc, err := sc.SyscallConn()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
ipfd, ok := rc.(interface{ PollFD() *poll.FD })
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return ipfd.PollFD()
|
||||
}
|
||||
|
||||
func (f *File) copyFileRange(r io.Reader) (written int64, handled bool, err error) {
|
||||
var (
|
||||
remain int64
|
||||
|
|
@ -91,10 +97,16 @@ func (f *File) copyFileRange(r io.Reader) (written int64, handled bool, err erro
|
|||
return 0, true, nil
|
||||
}
|
||||
|
||||
src, ok := r.(*File)
|
||||
if !ok {
|
||||
var src *File
|
||||
switch v := r.(type) {
|
||||
case *File:
|
||||
src = v
|
||||
case fileWithoutWriteTo:
|
||||
src = v.File
|
||||
default:
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
if src.checkValid("ReadFrom") != nil {
|
||||
// Avoid returning the error as we report handled as false,
|
||||
// leave further error handling as the responsibility of the caller.
|
||||
|
|
@ -108,6 +120,28 @@ func (f *File) copyFileRange(r io.Reader) (written int64, handled bool, err erro
|
|||
return written, handled, wrapSyscallError("copy_file_range", err)
|
||||
}
|
||||
|
||||
// getPollFDAndNetwork tries to get the poll.FD and network type from the given interface
|
||||
// by expecting the underlying type of i to be the implementation of syscall.Conn
|
||||
// that contains a *net.rawConn.
|
||||
func getPollFDAndNetwork(i any) (*poll.FD, poll.String) {
|
||||
sc, ok := i.(syscall.Conn)
|
||||
if !ok {
|
||||
return nil, ""
|
||||
}
|
||||
rc, err := sc.SyscallConn()
|
||||
if err != nil {
|
||||
return nil, ""
|
||||
}
|
||||
irc, ok := rc.(interface {
|
||||
PollFD() *poll.FD
|
||||
Network() poll.String
|
||||
})
|
||||
if !ok {
|
||||
return nil, ""
|
||||
}
|
||||
return irc.PollFD(), irc.Network()
|
||||
}
|
||||
|
||||
// tryLimitedReader tries to assert the io.Reader to io.LimitedReader, it returns the io.LimitedReader,
|
||||
// the underlying io.Reader and the remaining amount of bytes if the assertion succeeds,
|
||||
// otherwise it just returns the original io.Reader and the theoretical unlimited remaining amount of bytes.
|
||||
|
|
@ -122,3 +156,12 @@ func tryLimitedReader(r io.Reader) (*io.LimitedReader, io.Reader, int64) {
|
|||
remain = lr.N
|
||||
return lr, lr.R, remain
|
||||
}
|
||||
|
||||
func isUnixOrTCP(network string) bool {
|
||||
switch network {
|
||||
case "tcp", "tcp4", "tcp6", "unix":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
|
@ -8,6 +8,10 @@ package os
|
|||
|
||||
import "io"
|
||||
|
||||
func (f *File) writeTo(w io.Writer) (written int64, handled bool, err error) {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
func (f *File) readFrom(r io.Reader) (n int64, handled bool, err error) {
|
||||
return 0, false, nil
|
||||
}
|
||||
Loading…
Reference in New Issue