net,os: arrange zero-copy of os.File and net.TCPConn to net.UnixConn

Fixes #58808

goos: linux
goarch: amd64
pkg: net
cpu: DO-Premium-Intel
                             │      old      │                 new                  │
                             │    sec/op     │    sec/op     vs base                │
Splice/tcp-to-unix/1024-4       3.783µ ± 10%   3.201µ ±  7%  -15.40% (p=0.001 n=10)
Splice/tcp-to-unix/2048-4       3.967µ ± 13%   3.818µ ± 16%        ~ (p=0.971 n=10)
Splice/tcp-to-unix/4096-4       4.988µ ± 16%   4.590µ ± 11%        ~ (p=0.089 n=10)
Splice/tcp-to-unix/8192-4       6.981µ ± 13%   5.236µ ±  9%  -25.00% (p=0.000 n=10)
Splice/tcp-to-unix/16384-4     10.192µ ±  9%   7.350µ ±  7%  -27.89% (p=0.000 n=10)
Splice/tcp-to-unix/32768-4      19.65µ ± 13%   10.28µ ± 16%  -47.69% (p=0.000 n=10)
Splice/tcp-to-unix/65536-4      41.89µ ± 18%   15.70µ ± 13%  -62.52% (p=0.000 n=10)
Splice/tcp-to-unix/131072-4     90.05µ ± 11%   29.55µ ± 10%  -67.18% (p=0.000 n=10)
Splice/tcp-to-unix/262144-4    170.24µ ± 15%   52.66µ ±  4%  -69.06% (p=0.000 n=10)
Splice/tcp-to-unix/524288-4     326.4µ ± 13%   109.3µ ± 11%  -66.52% (p=0.000 n=10)
Splice/tcp-to-unix/1048576-4    651.4µ ±  9%   228.3µ ± 14%  -64.95% (p=0.000 n=10)
geomean                         29.42µ         15.62µ        -46.90%

                             │      old      │                  new                   │
                             │      B/s      │      B/s       vs base                 │
Splice/tcp-to-unix/1024-4      258.2Mi ± 11%   305.2Mi ±  8%   +18.21% (p=0.001 n=10)
Splice/tcp-to-unix/2048-4      492.5Mi ± 15%   511.7Mi ± 13%         ~ (p=0.971 n=10)
Splice/tcp-to-unix/4096-4      783.5Mi ± 14%   851.2Mi ± 12%         ~ (p=0.089 n=10)
Splice/tcp-to-unix/8192-4      1.093Gi ± 11%   1.458Gi ±  8%   +33.36% (p=0.000 n=10)
Splice/tcp-to-unix/16384-4     1.497Gi ±  9%   2.076Gi ±  7%   +38.67% (p=0.000 n=10)
Splice/tcp-to-unix/32768-4     1.553Gi ± 11%   2.969Gi ± 14%   +91.17% (p=0.000 n=10)
Splice/tcp-to-unix/65536-4     1.458Gi ± 23%   3.888Gi ± 11%  +166.69% (p=0.000 n=10)
Splice/tcp-to-unix/131072-4    1.356Gi ± 10%   4.131Gi ±  9%  +204.72% (p=0.000 n=10)
Splice/tcp-to-unix/262144-4    1.434Gi ± 13%   4.637Gi ±  4%  +223.32% (p=0.000 n=10)
Splice/tcp-to-unix/524288-4    1.497Gi ± 15%   4.468Gi ± 10%  +198.47% (p=0.000 n=10)
Splice/tcp-to-unix/1048576-4   1.501Gi ± 10%   4.277Gi ± 16%  +184.88% (p=0.000 n=10)
geomean                        1.038Gi         1.954Gi         +88.28%

                             │      old      │                   new                   │
                             │     B/op      │    B/op     vs base                     │
Splice/tcp-to-unix/1024-4      0.000 ±  0%     0.000 ± 0%         ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/2048-4      0.000 ±  0%     0.000 ± 0%         ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/4096-4      0.000 ±  0%     0.000 ± 0%         ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/8192-4      0.000 ±  0%     0.000 ± 0%         ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/16384-4     0.000 ±  0%     0.000 ± 0%         ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/32768-4     0.000 ±  0%     0.000 ± 0%         ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/65536-4     1.000 ±   ?     0.000 ± 0%  -100.00% (p=0.001 n=10)
Splice/tcp-to-unix/131072-4    2.000 ±  0%     0.000 ± 0%  -100.00% (p=0.000 n=10)
Splice/tcp-to-unix/262144-4    4.000 ± 25%     0.000 ± 0%  -100.00% (p=0.000 n=10)
Splice/tcp-to-unix/524288-4    7.500 ± 33%     0.000 ± 0%  -100.00% (p=0.000 n=10)
Splice/tcp-to-unix/1048576-4   17.00 ± 12%      0.00 ± 0%  -100.00% (p=0.000 n=10)
geomean                                    ²               ?                       ² ³
¹ all samples are equal
² summaries must be >0 to compute geomean
³ ratios must be >0 to compute geomean

                             │     old      │                 new                 │
                             │  allocs/op   │ allocs/op   vs base                 │
Splice/tcp-to-unix/1024-4      0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/2048-4      0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/4096-4      0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/8192-4      0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/16384-4     0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/32768-4     0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/65536-4     0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/131072-4    0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/262144-4    0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/524288-4    0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/1048576-4   0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
geomean                                   ²               +0.00%                ²
¹ all samples are equal
² summaries must be >0 to compute geomean

Change-Id: I829061b009a0929a8ef1a15c183793c0b9104dde
Reviewed-on: https://go-review.googlesource.com/c/go/+/472475
Reviewed-by: Damien Neil <dneil@google.com>
Reviewed-by: Bryan Mills <bcmills@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
This commit is contained in:
Andy Pan 2023-02-28 16:39:15 +08:00 committed by Damien Neil
parent f67b2d8f0b
commit f664031bc1
18 changed files with 461 additions and 67 deletions

2
api/next/58808.txt Normal file
View File

@ -0,0 +1,2 @@
pkg net, method (*TCPConn) WriteTo(io.Writer) (int64, error) #58808
pkg os, method (*File) WriteTo(io.Writer) (int64, error) #58808

View File

@ -81,3 +81,14 @@ func consume(v *[][]byte, n int64) {
// TestHookDidWritev is a hook for testing writev.
var TestHookDidWritev = func(wrote int) {}
// String is an internal string type used in the signatures of
// methods and functions that are not intended for use outside
// the standard library.
//
// Some packages in std that import internal/poll have exported APIs
// (for example, methods on net.rawConn) that are only used internally
// and are not meant to be called from outside the standard library.
// Those APIs use internal types such as poll.FD or poll.String in their
// signatures so that external code cannot use them.
type String string

View File

@ -264,6 +264,12 @@ func TestTransferWriterWriteBodyReaderTypes(t *testing.T) {
actualReader = reflect.TypeOf(lr.R)
} else {
actualReader = reflect.TypeOf(mw.CalledReader)
// Special case for genericWriteTo in os: it wraps the *os.File in the
// fileWithoutWriteTo struct to support a zero-copy optimization, so
// unwrap it here. See https://go.dev/issue/58808 for details.
if actualReader.Kind() == reflect.Struct && actualReader.PkgPath() == "os" && actualReader.Name() == "fileWithoutWriteTo" {
actualReader = actualReader.Field(1).Type
}
}
if tc.expectedReader != actualReader {

View File

@ -664,15 +664,53 @@ var errClosed = poll.ErrNetClosing
// errors.Is(err, net.ErrClosed).
var ErrClosed error = errClosed
type writerOnly struct {
io.Writer
// noReadFrom can be embedded alongside another type to
// hide the ReadFrom method of that other type.
type noReadFrom struct{}
// ReadFrom hides another ReadFrom method.
// It should never be called.
func (noReadFrom) ReadFrom(io.Reader) (int64, error) {
panic("can't happen")
}
// tcpConnWithoutReadFrom implements all the methods of *TCPConn other
// than ReadFrom. This is used to permit ReadFrom to call io.Copy
// without leading to a recursive call to ReadFrom.
type tcpConnWithoutReadFrom struct {
noReadFrom
*TCPConn
}
// Fallback implementation of io.ReaderFrom's ReadFrom, when sendfile isn't
// applicable.
func genericReadFrom(w io.Writer, r io.Reader) (n int64, err error) {
func genericReadFrom(c *TCPConn, r io.Reader) (n int64, err error) {
// Use wrapper to hide existing r.ReadFrom from io.Copy.
return io.Copy(writerOnly{w}, r)
return io.Copy(tcpConnWithoutReadFrom{TCPConn: c}, r)
}
// noWriteTo can be embedded alongside another type to
// hide the WriteTo method of that other type.
//
// When both are embedded at the same depth, the two WriteTo methods
// make the selector ambiguous, so the enclosing struct gets no WriteTo
// method at all.
type noWriteTo struct{}
// WriteTo hides another WriteTo method.
// It should never be called; it panics if it is.
func (noWriteTo) WriteTo(io.Writer) (int64, error) {
panic("can't happen")
}
// tcpConnWithoutWriteTo implements all the methods of *TCPConn other
// than WriteTo. This is used to permit WriteTo to call io.Copy
// without leading to a recursive call to WriteTo.
//
// Both embedded fields provide a WriteTo method at the same depth, so the
// selector is ambiguous and the struct itself has no WriteTo method.
type tcpConnWithoutWriteTo struct {
noWriteTo
*TCPConn
}
// genericWriteTo is the fallback implementation of io.WriterTo's WriteTo,
// used when zero-copy isn't applicable.
func genericWriteTo(c *TCPConn, w io.Writer) (n int64, err error) {
	// Wrap c so that io.Copy does not see its WriteTo method and
	// call straight back into it.
	src := tcpConnWithoutWriteTo{TCPConn: c}
	return io.Copy(w, src)
}
// Limit the number of concurrent cgo-using goroutines, because

View File

@ -79,6 +79,17 @@ func newRawConn(fd *netFD) *rawConn {
return &rawConn{fd: fd}
}
// Network returns the network type of the underlying connection,
// i.e. the net field of the wrapped netFD converted to poll.String.
//
// Other packages in std that import internal/poll and are unable to
// import net (such as os) can use a type assertion to access this
// extension method so that they can distinguish different socket types.
//
// Network is not intended for use outside the standard library.
func (c *rawConn) Network() poll.String {
return poll.String(c.fd.net)
}
type rawListener struct {
rawConn
}

View File

@ -14,29 +14,36 @@ import (
)
// BenchmarkSendFile measures copying a file to a TCP socket and to a
// Unix socket, one sub-benchmark per destination network.
func BenchmarkSendFile(b *testing.B) {
	for _, proto := range []string{"tcp", "unix"} {
		b.Run("file-to-"+proto, func(b *testing.B) {
			benchmarkSendFile(b, proto)
		})
	}
}
func benchmarkSendFile(b *testing.B, proto string) {
for i := 0; i <= 10; i++ {
size := 1 << (i + 10)
bench := sendFileBench{chunkSize: size}
bench := sendFileBench{
proto: proto,
chunkSize: size,
}
b.Run(strconv.Itoa(size), bench.benchSendFile)
}
}
// sendFileBench holds the parameters of a single sendfile benchmark case.
type sendFileBench struct {
// proto is the network passed to spliceTestSocketPair for the destination socket.
proto string
// chunkSize is the per-iteration byte count; the source file holds b.N * chunkSize bytes.
chunkSize int
}
func (bench sendFileBench) benchSendFile(b *testing.B) {
fileSize := b.N * bench.chunkSize
f := createTempFile(b, fileSize)
fileName := f.Name()
defer os.Remove(fileName)
defer f.Close()
client, server := spliceTestSocketPair(b, "tcp")
client, server := spliceTestSocketPair(b, bench.proto)
defer server.Close()
cleanUp, err := startSpliceClient(client, "r", bench.chunkSize, fileSize)
if err != nil {
client.Close()
b.Fatal(err)
}
defer cleanUp()
@ -51,15 +58,18 @@ func (bench sendFileBench) benchSendFile(b *testing.B) {
b.Fatalf("failed to copy data with sendfile, error: %v", err)
}
if sent != int64(fileSize) {
b.Fatalf("bytes sent mismatch\n\texpect: %d\n\tgot: %d", fileSize, sent)
b.Fatalf("bytes sent mismatch, got: %d, want: %d", sent, fileSize)
}
}
func createTempFile(b *testing.B, size int) *os.File {
f, err := os.CreateTemp("", "linux-sendfile-test")
f, err := os.CreateTemp(b.TempDir(), "linux-sendfile-bench")
if err != nil {
b.Fatalf("failed to create temporary file: %v", err)
}
b.Cleanup(func() {
f.Close()
})
data := make([]byte, size)
if _, err := f.Write(data); err != nil {

View File

@ -9,12 +9,12 @@ import (
"io"
)
// splice transfers data from r to c using the splice system call to minimize
// copies from and to userspace. c must be a TCP connection. Currently, splice
// is only enabled if r is a TCP or a stream-oriented Unix connection.
// spliceFrom transfers data from r to c using the splice system call to minimize
// copies from and to userspace. c must be a TCP connection.
// Currently, spliceFrom is only enabled if r is a TCP or a stream-oriented Unix connection.
//
// If splice returns handled == false, it has performed no work.
func splice(c *netFD, r io.Reader) (written int64, err error, handled bool) {
// If spliceFrom returns handled == false, it has performed no work.
func spliceFrom(c *netFD, r io.Reader) (written int64, err error, handled bool) {
var remain int64 = 1<<63 - 1 // by default, copy until EOF
lr, ok := r.(*io.LimitedReader)
if ok {
@ -25,14 +25,17 @@ func splice(c *netFD, r io.Reader) (written int64, err error, handled bool) {
}
var s *netFD
if tc, ok := r.(*TCPConn); ok {
s = tc.fd
} else if uc, ok := r.(*UnixConn); ok {
if uc.fd.net != "unix" {
switch v := r.(type) {
case *TCPConn:
s = v.fd
case tcpConnWithoutWriteTo:
s = v.fd
case *UnixConn:
if v.fd.net != "unix" {
return 0, nil, false
}
s = uc.fd
} else {
s = v.fd
default:
return 0, nil, false
}
@ -42,3 +45,18 @@ func splice(c *netFD, r io.Reader) (written int64, err error, handled bool) {
}
return written, wrapSyscallError(sc, err), handled
}
// spliceTo moves data from c to w with the splice system call so that the
// bytes do not have to be copied through userspace. c must be a TCP
// connection. For now the fast path is only taken when w is a
// stream-oriented Unix connection.
//
// When spliceTo returns handled == false, it has performed no work.
func spliceTo(w io.Writer, c *netFD) (written int64, err error, handled bool) {
	dst, isUnixConn := w.(*UnixConn)
	if !isUnixConn {
		return 0, nil, false
	}
	if dst.fd.net != "unix" {
		// Other Unix socket networks are not eligible.
		return 0, nil, false
	}

	var syscallName string
	written, handled, syscallName, err = poll.Splice(&dst.fd.pfd, &c.pfd, 1<<63-1)
	return written, wrapSyscallError(syscallName, err), handled
}

View File

@ -8,6 +8,10 @@ package net
import "io"
func splice(c *netFD, r io.Reader) (int64, error, bool) {
func spliceFrom(_ *netFD, _ io.Reader) (int64, error, bool) {
return 0, nil, false
}
// spliceTo always reports handled == false and performs no work, so
// callers fall back to their generic copy path.
// NOTE(review): presumably the stub for platforms without splice(2) —
// confirm against this file's build constraints.
func spliceTo(_ io.Writer, _ *netFD) (int64, error, bool) {
return 0, nil, false
}

View File

@ -23,6 +23,7 @@ func TestSplice(t *testing.T) {
t.Skip("skipping unix-to-tcp tests")
}
t.Run("unix-to-tcp", func(t *testing.T) { testSplice(t, "unix", "tcp") })
t.Run("tcp-to-unix", func(t *testing.T) { testSplice(t, "tcp", "unix") })
t.Run("tcp-to-file", func(t *testing.T) { testSpliceToFile(t, "tcp", "file") })
t.Run("unix-to-file", func(t *testing.T) { testSpliceToFile(t, "unix", "file") })
t.Run("no-unixpacket", testSpliceNoUnixpacket)
@ -159,6 +160,13 @@ func (tc spliceTestCase) testFile(t *testing.T) {
}
func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
// UnixConn doesn't implement io.ReaderFrom, which will fail
// the following test in asserting a UnixConn to be an io.ReaderFrom,
// so skip this test.
if upNet == "unix" || downNet == "unix" {
t.Skip("skipping test on unix socket")
}
clientUp, serverUp := spliceTestSocketPair(t, upNet)
defer clientUp.Close()
clientDown, serverDown := spliceTestSocketPair(t, downNet)
@ -166,16 +174,16 @@ func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
serverUp.Close()
// We'd like to call net.splice here and check the handled return
// We'd like to call net.spliceFrom here and check the handled return
// value, but we disable splice on old Linux kernels.
//
// In that case, poll.Splice and net.splice return a non-nil error
// In that case, poll.Splice and net.spliceFrom return a non-nil error
// and handled == false. We'd ideally like to see handled == true
// because the source reader is at EOF, but if we're running on an old
// kernel, and splice is disabled, we won't see EOF from net.splice,
// kernel, and splice is disabled, we won't see EOF from net.spliceFrom,
// because we won't touch the reader at all.
//
// Trying to untangle the errors from net.splice and match them
// Trying to untangle the errors from net.spliceFrom and match them
// against the errors created by the poll package would be brittle,
// so this is a higher level test.
//
@ -268,7 +276,7 @@ func testSpliceNoUnixpacket(t *testing.T) {
//
// What we want is err == nil and handled == false, i.e. we never
// called poll.Splice, because we know the unix socket's network.
_, err, handled := splice(serverDown.(*TCPConn).fd, serverUp)
_, err, handled := spliceFrom(serverDown.(*TCPConn).fd, serverUp)
if err != nil || handled != false {
t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
}
@ -289,7 +297,7 @@ func testSpliceNoUnixgram(t *testing.T) {
defer clientDown.Close()
defer serverDown.Close()
// Analogous to testSpliceNoUnixpacket.
_, err, handled := splice(serverDown.(*TCPConn).fd, up)
_, err, handled := spliceFrom(serverDown.(*TCPConn).fd, up)
if err != nil || handled != false {
t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
}
@ -300,6 +308,7 @@ func BenchmarkSplice(b *testing.B) {
b.Run("tcp-to-tcp", func(b *testing.B) { benchSplice(b, "tcp", "tcp") })
b.Run("unix-to-tcp", func(b *testing.B) { benchSplice(b, "unix", "tcp") })
b.Run("tcp-to-unix", func(b *testing.B) { benchSplice(b, "tcp", "unix") })
}
func benchSplice(b *testing.B, upNet, downNet string) {

View File

@ -134,6 +134,18 @@ func (c *TCPConn) ReadFrom(r io.Reader) (int64, error) {
return n, err
}
// WriteTo implements the io.WriterTo WriteTo method.
//
// Any error other than io.EOF is reported wrapped in an *OpError whose
// Op is "writeto".
func (c *TCPConn) WriteTo(w io.Writer) (int64, error) {
	if !c.ok() {
		return 0, syscall.EINVAL
	}

	n, err := c.writeTo(w)
	if err == nil || err == io.EOF {
		return n, err
	}
	return n, &OpError{Op: "writeto", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
}
// CloseRead shuts down the reading side of the TCP connection.
// Most callers should just use Close.
func (c *TCPConn) CloseRead() error {

View File

@ -14,6 +14,10 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
return genericReadFrom(c, r)
}
// writeTo copies from c to w using genericWriteTo; this variant has no
// zero-copy path.
func (c *TCPConn) writeTo(w io.Writer) (int64, error) {
return genericWriteTo(c, w)
}
func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
if h := sd.testHookDialTCP; h != nil {
return h(ctx, sd.network, laddr, raddr)

View File

@ -45,7 +45,7 @@ func (a *TCPAddr) toLocal(net string) sockaddr {
}
func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
if n, err, handled := splice(c.fd, r); handled {
if n, err, handled := spliceFrom(c.fd, r); handled {
return n, err
}
if n, err, handled := sendFile(c.fd, r); handled {
@ -54,6 +54,13 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
return genericReadFrom(c, r)
}
// writeTo copies from c to w, trying spliceTo first and falling back to
// genericWriteTo when spliceTo reports that it did not handle the copy.
func (c *TCPConn) writeTo(w io.Writer) (int64, error) {
	n, err, handled := spliceTo(w, c.fd)
	if !handled {
		return genericWriteTo(c, w)
	}
	return n, err
}
func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
if h := sd.testHookDialTCP; h != nil {
return h(ctx, sd.network, laddr, raddr)

View File

@ -5,7 +5,8 @@
package os
var (
PollCopyFileRangeP = &pollCopyFileRange
PollSpliceFile = &pollSplice
GetPollFDForTest = getPollFD
PollCopyFileRangeP = &pollCopyFileRange
PollSpliceFile = &pollSplice
PollSendFile = &pollSendFile
GetPollFDAndNetwork = getPollFDAndNetwork
)

View File

@ -157,20 +157,26 @@ func (f *File) ReadFrom(r io.Reader) (n int64, err error) {
return n, f.wrapErr("write", e)
}
func genericReadFrom(f *File, r io.Reader) (int64, error) {
return io.Copy(fileWithoutReadFrom{f}, r)
// noReadFrom can be embedded alongside another type to
// hide the ReadFrom method of that other type.
type noReadFrom struct{}
// ReadFrom hides another ReadFrom method.
// It should never be called.
func (noReadFrom) ReadFrom(io.Reader) (int64, error) {
panic("can't happen")
}
// fileWithoutReadFrom implements all the methods of *File other
// than ReadFrom. This is used to permit ReadFrom to call io.Copy
// without leading to a recursive call to ReadFrom.
type fileWithoutReadFrom struct {
noReadFrom
*File
}
// This ReadFrom method hides the *File ReadFrom method.
func (fileWithoutReadFrom) ReadFrom(fileWithoutReadFrom) {
panic("unreachable")
func genericReadFrom(f *File, r io.Reader) (int64, error) {
return io.Copy(fileWithoutReadFrom{File: f}, r)
}
// Write writes len(b) bytes from b to the File.
@ -229,6 +235,40 @@ func (f *File) WriteAt(b []byte, off int64) (n int, err error) {
return
}
// WriteTo implements io.WriterTo.
//
// It first offers the copy to the platform-specific writeTo; if that does
// not handle it, the copy falls back to genericWriteTo.
func (f *File) WriteTo(w io.Writer) (n int64, err error) {
	if verr := f.checkValid("read"); verr != nil {
		return 0, verr
	}

	written, handled, werr := f.writeTo(w)
	if !handled {
		// The generic path's error is returned without wrapping.
		return genericWriteTo(f, w)
	}
	return written, f.wrapErr("read", werr)
}
// noWriteTo can be embedded alongside another type to
// hide the WriteTo method of that other type.
//
// When both are embedded at the same depth, the two WriteTo methods
// make the selector ambiguous, so the enclosing struct gets no WriteTo
// method at all.
type noWriteTo struct{}
// WriteTo hides another WriteTo method.
// It should never be called; it panics if it is.
func (noWriteTo) WriteTo(io.Writer) (int64, error) {
panic("can't happen")
}
// fileWithoutWriteTo implements all the methods of *File other
// than WriteTo. This is used to permit WriteTo to call io.Copy
// without leading to a recursive call to WriteTo.
//
// Both embedded fields provide a WriteTo method at the same depth, so the
// selector is ambiguous and the struct itself has no WriteTo method.
type fileWithoutWriteTo struct {
noWriteTo
*File
}
// genericWriteTo copies from f to w with io.Copy. f is wrapped in
// fileWithoutWriteTo so that io.Copy does not call back into f's WriteTo.
func genericWriteTo(f *File, w io.Writer) (int64, error) {
	src := fileWithoutWriteTo{File: f}
	return io.Copy(w, src)
}
// Seek sets the offset for the next Read or Write on file to offset, interpreted
// according to whence: 0 means relative to the origin of the file, 1 means
// relative to the current offset, and 2 means relative to the end.

View File

@ -749,12 +749,12 @@ func TestProcCopy(t *testing.T) {
}
}
func TestGetPollFDFromReader(t *testing.T) {
t.Run("tcp", func(t *testing.T) { testGetPollFromReader(t, "tcp") })
t.Run("unix", func(t *testing.T) { testGetPollFromReader(t, "unix") })
func TestGetPollFDAndNetwork(t *testing.T) {
t.Run("tcp4", func(t *testing.T) { testGetPollFDAndNetwork(t, "tcp4") })
t.Run("unix", func(t *testing.T) { testGetPollFDAndNetwork(t, "unix") })
}
func testGetPollFromReader(t *testing.T, proto string) {
func testGetPollFDAndNetwork(t *testing.T, proto string) {
_, server := createSocketPair(t, proto)
sc, ok := server.(syscall.Conn)
if !ok {
@ -765,12 +765,15 @@ func testGetPollFromReader(t *testing.T, proto string) {
t.Fatalf("server SyscallConn error: %v", err)
}
if err = rc.Control(func(fd uintptr) {
pfd := GetPollFDForTest(server)
pfd, network := GetPollFDAndNetwork(server)
if pfd == nil {
t.Fatalf("GetPollFDForTest didn't return poll.FD")
t.Fatalf("GetPollFDAndNetwork didn't return poll.FD")
}
if string(network) != proto {
t.Fatalf("GetPollFDAndNetwork returned wrong network, got: %s, want: %s", network, proto)
}
if pfd.Sysfd != int(fd) {
t.Fatalf("GetPollFDForTest returned wrong poll.FD, got: %d, want: %d", pfd.Sysfd, int(fd))
t.Fatalf("GetPollFDAndNetwork returned wrong poll.FD, got: %d, want: %d", pfd.Sysfd, int(fd))
}
if !pfd.IsStream {
t.Fatalf("expected IsStream to be true")

View File

@ -0,0 +1,171 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package os_test
import (
"bytes"
"internal/poll"
"io"
"math/rand"
"net"
. "os"
"strconv"
"syscall"
"testing"
"time"
)
// TestSendFile runs testSendFile for a range of file sizes against a Unix
// socket and then a TCP socket.
func TestSendFile(t *testing.T) {
	sizes := []int{
		1,
		42,
		1025,
		syscall.Getpagesize() + 1,
		32769,
	}

	for _, proto := range []string{"unix", "tcp"} {
		t.Run("sendfile-to-"+proto, func(t *testing.T) {
			for _, size := range sizes {
				t.Run(strconv.Itoa(size), func(t *testing.T) {
					testSendFile(t, proto, int64(size))
				})
			}
		})
	}
}
// testSendFile copies a temporary file of the given size to a socket of the
// given network with io.Copy. It then checks that the copy went through the
// hooked poll.SendFile with the expected file descriptors and that the
// receiving end sees exactly the bytes that were written to the file.
func testSendFile(t *testing.T, proto string, size int64) {
	dst, src, recv, data, hook := newSendFileTest(t, proto, size)

	// Now call WriteTo (through io.Copy), which will hopefully call poll.SendFile
	n, err := io.Copy(dst, src)
	if err != nil {
		t.Fatalf("io.Copy error: %v", err)
	}

	// We should have called poll.SendFile with the right file descriptor arguments.
	if n > 0 && !hook.called {
		t.Fatal("expected poll.SendFile to be called")
	}
	if hook.called && hook.srcfd != int(src.Fd()) {
		t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, src.Fd())
	}
	sc, ok := dst.(syscall.Conn)
	if !ok {
		t.Fatalf("destination is not a syscall.Conn")
	}
	rc, err := sc.SyscallConn()
	if err != nil {
		t.Fatalf("destination SyscallConn error: %v", err)
	}
	if err = rc.Control(func(fd uintptr) {
		if hook.called && hook.dstfd != int(fd) {
			t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, int(fd))
		}
	}); err != nil {
		t.Fatalf("destination Conn Control error: %v", err)
	}

	// Check the number of bytes copied before reading them back: if io.Copy
	// came up short, the io.ReadFull below could block waiting for data
	// that never arrives instead of reporting the mismatch.
	dataSize := len(data)
	if n != int64(dataSize) {
		t.Fatalf("data length mismatch for io.Copy, got %d, want %d", n, dataSize)
	}

	// Verify the data size and content.
	dstData := make([]byte, dataSize)
	m, err := io.ReadFull(recv, dstData)
	if err != nil {
		t.Fatalf("server Conn Read error: %v", err)
	}
	if m != dataSize {
		t.Fatalf("data length mismatch for net.Conn.Read, got %d, want %d", m, dataSize)
	}
	if !bytes.Equal(dstData, data) {
		t.Errorf("data mismatch, got %s, want %s", dstData, data)
	}
}
// newSendFileTest sets up the fixtures for one sendfile test.
//
// It installs a hook on package os' call to poll.SendFile, creates a socket
// pair for the given network, and creates a source file filled with random
// data of the specified size. It returns, in order: the socket to copy to,
// the source file, the socket to read the copied data from, the file's
// contents, and the hook so the caller can inspect it.
func newSendFileTest(t *testing.T, proto string, size int64) (net.Conn, *File, net.Conn, []byte, *sendFileHook) {
	t.Helper()

	sfHook := hookSendFile(t)
	dst, recv := createSocketPair(t, proto)
	src, payload := createTempFile(t, size)

	return dst, src, recv, payload, sfHook
}
// hookSendFile installs a fresh sendFileHook and registers its removal
// with t.Cleanup.
func hookSendFile(t *testing.T) *sendFileHook {
	hook := &sendFileHook{}
	hook.install()
	t.Cleanup(hook.uninstall)
	return hook
}
// sendFileHook records the arguments and results of the most recent call
// made through the hooked PollSendFile function.
type sendFileHook struct {
// called is set once the hooked function has been invoked.
called bool
// dstfd is the Sysfd of the destination poll.FD passed to the call.
dstfd int
// srcfd is the source file descriptor passed to the call.
srcfd int
// remain is the remain argument passed to the call.
remain int64
// written, handled and err hold the results returned by original.
written int64
handled bool
err error
// original is the function that was in place before install replaced it.
original func(dst *poll.FD, src int, remain int64) (int64, error, bool)
}
// install saves the current *PollSendFile in h.original and replaces it with
// a wrapper that records the call's arguments, forwards to h.original, and
// records and returns its results.
func (h *sendFileHook) install() {
	h.original = *PollSendFile

	recorder := func(dst *poll.FD, src int, remain int64) (int64, error, bool) {
		h.called = true
		h.dstfd = dst.Sysfd
		h.srcfd = src
		h.remain = remain

		h.written, h.err, h.handled = h.original(dst, src, remain)
		return h.written, h.err, h.handled
	}
	*PollSendFile = recorder
}
// uninstall puts h.original back into *PollSendFile.
func (h *sendFileHook) uninstall() {
*PollSendFile = h.original
}
// createTempFile creates a file in t.TempDir(), fills it with size bytes of
// pseudo-random data, syncs it, and rewinds it to the start. The seed is
// logged so the data can be reproduced. The file is closed by a t.Cleanup
// function. It returns the open file and the bytes written to it.
func createTempFile(t *testing.T, size int64) (*File, []byte) {
	tmp, err := CreateTemp(t.TempDir(), "writeto-sendfile-to-socket")
	if err != nil {
		t.Fatalf("failed to create temporary file: %v", err)
	}
	t.Cleanup(func() { tmp.Close() })

	seed := time.Now().Unix()
	t.Logf("random data seed: %d\n", seed)
	payload := make([]byte, size)
	rand.New(rand.NewSource(seed)).Read(payload)

	_, err = tmp.Write(payload)
	if err != nil {
		t.Fatalf("failed to create and feed the file: %v", err)
	}
	err = tmp.Sync()
	if err != nil {
		t.Fatalf("failed to save the file: %v", err)
	}
	_, err = tmp.Seek(0, io.SeekStart)
	if err != nil {
		t.Fatalf("failed to rewind the file: %v", err)
	}

	return tmp, payload
}

View File

@ -13,8 +13,33 @@ import (
var (
pollCopyFileRange = poll.CopyFileRange
pollSplice = poll.Splice
pollSendFile = poll.SendFile
)
// writeTo tries to copy f to w through pollSendFile. It only attempts this
// when w yields a stream poll.FD whose network is TCP or "unix"; otherwise
// it returns handled == false without doing any work.
func (f *File) writeTo(w io.Writer) (written int64, handled bool, err error) {
	dst, network := getPollFDAndNetwork(w)
	// TODO(panjf2000): same as File.spliceToFile.
	if dst == nil || !dst.IsStream {
		return 0, false, nil
	}
	if !isUnixOrTCP(string(network)) {
		return 0, false, nil
	}

	rawConn, err := f.SyscallConn()
	if err != nil {
		return 0, false, err
	}

	readErr := rawConn.Read(func(fd uintptr) bool {
		written, err, handled = pollSendFile(dst, int(fd), 1<<63-1)
		return true
	})

	// An error from pollSendFile takes precedence over one from Read.
	if err == nil {
		err = readErr
	}

	return written, handled, wrapSyscallError("sendfile", err)
}
func (f *File) readFrom(r io.Reader) (written int64, handled bool, err error) {
// Neither copy_file_range(2) nor splice(2) supports destinations opened with
// O_APPEND, so don't bother to try zero-copy with these system calls.
@ -41,7 +66,7 @@ func (f *File) spliceToFile(r io.Reader) (written int64, handled bool, err error
return 0, true, nil
}
pfd := getPollFD(r)
pfd, _ := getPollFDAndNetwork(r)
// TODO(panjf2000): run some tests to see if we should unlock the non-streams for splice.
// Streams benefit the most from the splice(2), non-streams are not even supported in old kernels
// where splice(2) will just return EINVAL; newer kernels support non-streams like UDP, but I really
@ -63,25 +88,6 @@ func (f *File) spliceToFile(r io.Reader) (written int64, handled bool, err error
return written, handled, wrapSyscallError(syscallName, err)
}
// getPollFD tries to get the poll.FD from the given io.Reader by expecting
// the underlying type of r to be the implementation of syscall.Conn that contains
// a *net.rawConn.
func getPollFD(r io.Reader) *poll.FD {
sc, ok := r.(syscall.Conn)
if !ok {
return nil
}
rc, err := sc.SyscallConn()
if err != nil {
return nil
}
ipfd, ok := rc.(interface{ PollFD() *poll.FD })
if !ok {
return nil
}
return ipfd.PollFD()
}
func (f *File) copyFileRange(r io.Reader) (written int64, handled bool, err error) {
var (
remain int64
@ -91,10 +97,16 @@ func (f *File) copyFileRange(r io.Reader) (written int64, handled bool, err erro
return 0, true, nil
}
src, ok := r.(*File)
if !ok {
var src *File
switch v := r.(type) {
case *File:
src = v
case fileWithoutWriteTo:
src = v.File
default:
return 0, false, nil
}
if src.checkValid("ReadFrom") != nil {
// Avoid returning the error as we report handled as false,
// leave further error handling as the responsibility of the caller.
@ -108,6 +120,28 @@ func (f *File) copyFileRange(r io.Reader) (written int64, handled bool, err erro
return written, handled, wrapSyscallError("copy_file_range", err)
}
// getPollFDAndNetwork extracts the poll.FD and the network type from i.
// It expects i to implement syscall.Conn and the resulting raw connection
// to provide PollFD and Network methods (as *net.rawConn does). If any of
// those steps fails, it returns (nil, "").
func getPollFDAndNetwork(i any) (*poll.FD, poll.String) {
	type pollFDNetworker interface {
		PollFD() *poll.FD
		Network() poll.String
	}

	conn, isConn := i.(syscall.Conn)
	if !isConn {
		return nil, ""
	}
	raw, err := conn.SyscallConn()
	if err != nil {
		return nil, ""
	}
	provider, hasMethods := raw.(pollFDNetworker)
	if !hasMethods {
		return nil, ""
	}
	return provider.PollFD(), provider.Network()
}
// tryLimitedReader tries to assert the io.Reader to io.LimitedReader, it returns the io.LimitedReader,
// the underlying io.Reader and the remaining amount of bytes if the assertion succeeds,
// otherwise it just returns the original io.Reader and the theoretical unlimited remaining amount of bytes.
@ -122,3 +156,12 @@ func tryLimitedReader(r io.Reader) (*io.LimitedReader, io.Reader, int64) {
remain = lr.N
return lr, lr.R, remain
}
// isUnixOrTCP reports whether network is one of "tcp", "tcp4", "tcp6"
// or "unix".
func isUnixOrTCP(network string) bool {
	return network == "tcp" ||
		network == "tcp4" ||
		network == "tcp6" ||
		network == "unix"
}

View File

@ -8,6 +8,10 @@ package os
import "io"
// writeTo always reports handled == false and performs no work.
// NOTE(review): presumably the stub for platforms without a zero-copy
// path — confirm against this file's build constraints.
func (f *File) writeTo(w io.Writer) (written int64, handled bool, err error) {
return 0, false, nil
}
// readFrom always reports handled == false and performs no work.
// NOTE(review): presumably the stub for platforms without a zero-copy
// path — confirm against this file's build constraints.
func (f *File) readFrom(r io.Reader) (n int64, handled bool, err error) {
return 0, false, nil
}