diff --git a/src/internal/poll/fd_windows.go b/src/internal/poll/fd_windows.go index 3c11ce5bb4..14b1febbf4 100644 --- a/src/internal/poll/fd_windows.go +++ b/src/internal/poll/fd_windows.go @@ -88,6 +88,17 @@ type operation struct { bufs []syscall.WSABuf } +func (o *operation) overlapped() *syscall.Overlapped { + if o.runtimeCtx == 0 { + // Don't return the overlapped object if the file handle + // doesn't use overlapped I/O. It could be used, but + // that would then use the file pointer stored in the + // overlapped object rather than the real file pointer. + return nil + } + return &o.o +} + func (o *operation) InitBuf(buf []byte) { o.buf.Len = uint32(len(buf)) o.buf.Buf = nil @@ -142,15 +153,9 @@ func (o *operation) InitMsg(p []byte, oob []byte) { } } -// execIO executes a single IO operation o. It submits and cancels -// IO in the current thread for systems where Windows CancelIoEx API -// is available. Alternatively, it passes the request onto -// runtime netpoll and waits for completion or cancels request. +// execIO executes a single IO operation o. +// It supports both synchronous and asynchronous IO. func execIO(o *operation, submit func(o *operation) error) (int, error) { - if o.fd.pd.runtimeCtx == 0 { - return 0, errors.New("internal error: polling on unsupported descriptor type") - } - fd := o.fd // Notify runtime netpoll about starting IO. err := fd.pd.prepare(int(o.mode), fd.isFile) @@ -159,6 +164,12 @@ func execIO(o *operation, submit func(o *operation) error) (int, error) { } // Start IO. err = submit(o) + if !fd.pd.pollable() { + if err != nil { + return 0, err + } + return int(o.qty), nil + } switch err { case nil: // IO completed immediately @@ -176,7 +187,11 @@ func execIO(o *operation, submit func(o *operation) error) (int, error) { // Wait for our request to complete. err = fd.pd.wait(int(o.mode), fd.isFile) if err == nil { - err = windows.WSAGetOverlappedResult(fd.Sysfd, &o.o, &o.qty, false, &o.flags) + if fd.isFile { + err = windows.GetOverlappedResult(fd.Sysfd, &o.o, &o.qty, false) + } else { + err = windows.WSAGetOverlappedResult(fd.Sysfd, &o.o, &o.qty, false, &o.flags) + } // All is good. Extract our IO results and return. if err != nil { // More data available. Return back the size of received data. @@ -204,7 +219,11 @@ func execIO(o *operation, submit func(o *operation) error) (int, error) { } // Wait for cancellation to complete. fd.pd.waitCanceled(int(o.mode)) - err = windows.WSAGetOverlappedResult(fd.Sysfd, &o.o, &o.qty, false, &o.flags) + if fd.isFile { + err = windows.GetOverlappedResult(fd.Sysfd, &o.o, &o.qty, true) + } else { + err = windows.WSAGetOverlappedResult(fd.Sysfd, &o.o, &o.qty, false, &o.flags) + } if err != nil { if err == syscall.ERROR_OPERATION_ABORTED { // IO Canceled err = netpollErr @@ -237,6 +256,11 @@ type FD struct { // Used to implement pread/pwrite. l sync.Mutex + // The file offset for the next read or write. + // Overlapped IO operations don't use the real file pointer, + // so we need to keep track of the offset ourselves. + offset int64 + // For console I/O. lastbits []byte // first few bytes of the last incomplete rune in last write readuint16 []uint16 // buffer to hold uint16s obtained with ReadConsole @@ -263,6 +287,30 @@ type FD struct { kind fileKind } +// setOffset sets the offset fields of the overlapped object +// to the given offset. The fd.l lock must be held. +// +// Overlapped IO operations don't update the offset fields +// of the overlapped object nor the file pointer automatically, +// so we do that manually here. +// Note that this is a best effort that only works if the file +// pointer is completely owned by this operation. We could +// call seek to allow other processes or other operations on the +// same file to see the updated offset. That would be inefficient +// and won't work for concurrent operations anyway. If concurrent +// operations are needed, then the caller should serialize them +// using an external mechanism. +func (fd *FD) setOffset(off int64) { + fd.offset = off + fd.rop.o.OffsetHigh, fd.rop.o.Offset = uint32(off>>32), uint32(off) + fd.wop.o.OffsetHigh, fd.wop.o.Offset = uint32(off>>32), uint32(off) +} + +// addOffset adds the given offset to the current offset. +func (fd *FD) addOffset(off int) { + fd.setOffset(fd.offset + int64(off)) +} + // fileKind describes the kind of file. type fileKind byte @@ -302,6 +350,10 @@ func (fd *FD) Init(net string, pollable bool) error { return errors.New("internal error: unknown network type " + net) } fd.isFile = fd.kind != kindNet + fd.rop.mode = 'r' + fd.wop.mode = 'w' + fd.rop.fd = fd + fd.wop.fd = fd var err error if pollable { @@ -328,10 +380,6 @@ func (fd *FD) Init(net string, pollable bool) error { fd.skipSyncNotif = true } } - fd.rop.mode = 'r' - fd.wop.mode = 'w' - fd.rop.fd = fd - fd.wop.fd = fd fd.rop.runtimeCtx = fd.pd.runtimeCtx fd.wop.runtimeCtx = fd.pd.runtimeCtx return nil @@ -400,12 +448,23 @@ func (fd *FD) Read(buf []byte) (int, error) { case kindConsole: n, err = fd.readConsole(buf) default: - n, err = syscall.Read(fd.Sysfd, buf) - if fd.kind == kindPipe && err == syscall.ERROR_OPERATION_ABORTED { - // Close uses CancelIoEx to interrupt concurrent I/O for pipes. - // If the fd is a pipe and the Read was interrupted by CancelIoEx, - // we assume it is interrupted by Close. - err = ErrFileClosing + o := &fd.rop + o.InitBuf(buf) + n, err = execIO(o, func(o *operation) error { + return syscall.ReadFile(o.fd.Sysfd, unsafe.Slice(o.buf.Buf, o.buf.Len), &o.qty, o.overlapped()) + }) + fd.addOffset(n) + if fd.kind == kindPipe && err != nil { + switch err { + case syscall.ERROR_BROKEN_PIPE: + // Returned by pipes when the other end is closed. + err = nil + case syscall.ERROR_OPERATION_ABORTED: + // Close uses CancelIoEx to interrupt concurrent I/O for pipes. + // If the fd is a pipe and the Read was interrupted by CancelIoEx, + // we assume it is interrupted by Close. + err = ErrFileClosing + } } } if err != nil { @@ -520,27 +579,28 @@ func (fd *FD) Pread(b []byte, off int64) (int, error) { fd.l.Lock() defer fd.l.Unlock() - curoffset, e := syscall.Seek(fd.Sysfd, 0, io.SeekCurrent) - if e != nil { - return 0, e + curoffset, err := syscall.Seek(fd.Sysfd, 0, io.SeekCurrent) + if err != nil { + return 0, err } defer syscall.Seek(fd.Sysfd, curoffset, io.SeekStart) - o := syscall.Overlapped{ - OffsetHigh: uint32(off >> 32), - Offset: uint32(off), - } - var done uint32 - e = syscall.ReadFile(fd.Sysfd, b, &done, &o) - if e != nil { - done = 0 - if e == syscall.ERROR_HANDLE_EOF { - e = io.EOF + defer fd.setOffset(curoffset) + o := &fd.rop + o.InitBuf(b) + fd.setOffset(off) + n, err := execIO(o, func(o *operation) error { + return syscall.ReadFile(o.fd.Sysfd, unsafe.Slice(o.buf.Buf, o.buf.Len), &o.qty, &o.o) + }) + if err != nil { + n = 0 + if err == syscall.ERROR_HANDLE_EOF { + err = io.EOF } } if len(b) != 0 { - e = fd.eofError(int(done), e) + err = fd.eofError(n, err) } - return int(done), e + return n, err } // ReadFrom wraps the recvfrom network call. @@ -654,7 +714,12 @@ func (fd *FD) Write(buf []byte) (int, error) { case kindConsole: n, err = fd.writeConsole(b) default: - n, err = syscall.Write(fd.Sysfd, b) + o := &fd.wop + o.InitBuf(b) + n, err = execIO(o, func(o *operation) error { + return syscall.WriteFile(o.fd.Sysfd, unsafe.Slice(o.buf.Buf, o.buf.Len), &o.qty, o.overlapped()) + }) + fd.addOffset(n) if fd.kind == kindPipe && err == syscall.ERROR_OPERATION_ABORTED { // Close uses CancelIoEx to interrupt concurrent I/O for pipes. // If the fd is a pipe and the Write was interrupted by CancelIoEx, @@ -742,11 +807,12 @@ func (fd *FD) Pwrite(buf []byte, off int64) (int, error) { fd.l.Lock() defer fd.l.Unlock() - curoffset, e := syscall.Seek(fd.Sysfd, 0, io.SeekCurrent) - if e != nil { - return 0, e + curoffset, err := syscall.Seek(fd.Sysfd, 0, io.SeekCurrent) + if err != nil { + return 0, err } defer syscall.Seek(fd.Sysfd, curoffset, io.SeekStart) + defer fd.setOffset(curoffset) ntotal := 0 for len(buf) > 0 { @@ -754,15 +820,15 @@ func (fd *FD) Pwrite(buf []byte, off int64) (int, error) { if len(b) > maxRW { b = b[:maxRW] } - var n uint32 - o := syscall.Overlapped{ - OffsetHigh: uint32(off >> 32), - Offset: uint32(off), - } - e = syscall.WriteFile(fd.Sysfd, b, &n, &o) + o := &fd.wop + o.InitBuf(b) + fd.setOffset(off) + n, err := execIO(o, func(o *operation) error { + return syscall.WriteFile(o.fd.Sysfd, unsafe.Slice(o.buf.Buf, o.buf.Len), &o.qty, &o.o) + }) ntotal += int(n) - if e != nil { - return ntotal, e + if err != nil { + return ntotal, err } buf = buf[n:] off += int64(n) @@ -992,7 +1058,9 @@ func (fd *FD) Seek(offset int64, whence int) (int64, error) { fd.l.Lock() defer fd.l.Unlock() - return syscall.Seek(fd.Sysfd, offset, whence) + n, err := syscall.Seek(fd.Sysfd, offset, whence) + fd.setOffset(n) + return n, err } // Fchmod updates syscall.ByHandleFileInformation.Fileattributes when needed. diff --git a/src/internal/poll/fd_windows_test.go b/src/internal/poll/fd_windows_test.go index 87273c08ac..f2bc9b2a21 100644 --- a/src/internal/poll/fd_windows_test.go +++ b/src/internal/poll/fd_windows_test.go @@ -5,12 +5,17 @@ package poll_test import ( + "bytes" "errors" "fmt" "internal/poll" "internal/syscall/windows" + "io" "os" + "path/filepath" + "strconv" "sync" + "sync/atomic" "syscall" "testing" "unsafe" @@ -184,3 +189,238 @@ type _TCP_INFO_v0 struct { TimeoutEpisodes uint32 SynRetrans uint8 } + +func newFD(t testing.TB, h syscall.Handle, kind string, overlapped bool) *poll.FD { + fd := poll.FD{ + Sysfd: h, + IsStream: true, + ZeroReadIsEOF: true, + } + err := fd.Init(kind, true) + if overlapped && err != nil { + // Overlapped file handles should not error. + t.Fatal(err) + } else if !overlapped && err == nil { + // Non-overlapped file handles should return an error but still + // be usable as sync handles. + t.Fatal("expected error for non-overlapped file handle") + } + return &fd +} + +func newFile(t testing.TB, name string, overlapped bool) *poll.FD { + namep, err := syscall.UTF16PtrFromString(name) + if err != nil { + t.Fatal(err) + } + flags := syscall.FILE_ATTRIBUTE_NORMAL + if overlapped { + flags |= syscall.FILE_FLAG_OVERLAPPED + } + h, err := syscall.CreateFile(namep, + syscall.GENERIC_READ|syscall.GENERIC_WRITE, + syscall.FILE_SHARE_WRITE|syscall.FILE_SHARE_READ, + nil, syscall.OPEN_ALWAYS, uint32(flags), 0) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + if err := syscall.CloseHandle(h); err != nil { + t.Fatal(err) + } + }) + return newFD(t, h, "file", overlapped) +} + +var currentProces = sync.OnceValue(func() string { + // Convert the process ID to a string. + return strconv.FormatUint(uint64(os.Getpid()), 10) +}) + +var pipeCounter atomic.Uint64 + +func newPipe(t testing.TB, overlapped bool) (string, *poll.FD) { + name := `\\.\pipe\go-internal-poll-test-` + currentProces() + `-` + strconv.FormatUint(pipeCounter.Add(1), 10) + wname, err := syscall.UTF16PtrFromString(name) + if err != nil { + t.Fatal(err) + } + // Create the read handle. + flags := windows.PIPE_ACCESS_DUPLEX + if overlapped { + flags |= syscall.FILE_FLAG_OVERLAPPED + } + h, err := windows.CreateNamedPipe(wname, uint32(flags), windows.PIPE_TYPE_BYTE, 1, 4096, 4096, 0, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + if err := syscall.CloseHandle(h); err != nil { + t.Fatal(err) + } + }) + return name, newFD(t, h, "pipe", overlapped) +} + +func testReadWrite(t *testing.T, fdr, fdw *poll.FD) { + write := make(chan string, 1) + read := make(chan struct{}, 1) + go func() { + for s := range write { + n, err := fdw.Write([]byte(s)) + read <- struct{}{} + if err != nil { + t.Error(err) + } + if n != len(s) { + t.Errorf("expected to write %d bytes, got %d", len(s), n) + } + } + }() + for i := range 10 { + s := strconv.Itoa(i) + write <- s + <-read + buf := make([]byte, len(s)) + _, err := io.ReadFull(fdr, buf) + if err != nil { + t.Fatalf("read failed: %v", err) + } + if !bytes.Equal(buf, []byte(s)) { + t.Fatalf("expected %q, got %q", s, buf) + } + } + close(read) + close(write) +} + +func testPreadPwrite(t *testing.T, fdr, fdw *poll.FD) { + type op struct { + s string + off int64 + } + write := make(chan op, 1) + read := make(chan struct{}, 1) + go func() { + for o := range write { + n, err := fdw.Pwrite([]byte(o.s), o.off) + read <- struct{}{} + if err != nil { + t.Error(err) + } + if n != len(o.s) { + t.Errorf("expected to write %d bytes, got %d", len(o.s), n) + } + } + }() + for i := range 10 { + off := int64(i % 3) // exercise some back and forth + s := strconv.Itoa(i) + write <- op{s, off} + <-read + buf := make([]byte, len(s)) + n, err := fdr.Pread(buf, off) + if err != nil { + t.Fatal(err) + } + if n != len(s) { + t.Fatalf("expected to read %d bytes, got %d", len(s), n) + } + if !bytes.Equal(buf, []byte(s)) { + t.Fatalf("expected %q, got %q", s, buf) + } + } + close(read) + close(write) +} + +func TestFile(t *testing.T) { + test := func(t *testing.T, r, w bool) { + name := filepath.Join(t.TempDir(), "foo") + rh := newFile(t, name, r) + wh := newFile(t, name, w) + testReadWrite(t, rh, wh) + testPreadPwrite(t, rh, wh) + } + t.Run("overlapped", func(t *testing.T) { + test(t, true, true) + }) + t.Run("overlapped-read", func(t *testing.T) { + test(t, true, false) + }) + t.Run("overlapped-write", func(t *testing.T) { + test(t, false, true) + }) + t.Run("sync", func(t *testing.T) { + test(t, false, false) + }) +} + +func TestPipe(t *testing.T) { + t.Run("overlapped", func(t *testing.T) { + name, pipe := newPipe(t, true) + file := newFile(t, name, true) + testReadWrite(t, pipe, file) + }) + t.Run("overlapped-write", func(t *testing.T) { + name, pipe := newPipe(t, true) + file := newFile(t, name, false) + testReadWrite(t, file, pipe) + }) + t.Run("overlapped-read", func(t *testing.T) { + name, pipe := newPipe(t, false) + file := newFile(t, name, true) + testReadWrite(t, file, pipe) + }) + t.Run("sync", func(t *testing.T) { + name, pipe := newPipe(t, false) + file := newFile(t, name, false) + testReadWrite(t, file, pipe) + }) + t.Run("anonymous", func(t *testing.T) { + var r, w syscall.Handle + if err := syscall.CreatePipe(&r, &w, nil, 0); err != nil { + t.Fatal(err) + } + defer func() { + if err := syscall.CloseHandle(r); err != nil { + t.Fatal(err) + } + if err := syscall.CloseHandle(w); err != nil { + t.Fatal(err) + } + }() + // CreatePipe always returns sync handles. + fdr := newFD(t, r, "pipe", false) + fdw := newFD(t, w, "file", false) + testReadWrite(t, fdr, fdw) + }) +} + +func BenchmarkReadOverlapped(b *testing.B) { + benchmarkRead(b, true) +} + +func BenchmarkReadSync(b *testing.B) { + benchmarkRead(b, false) +} + +func benchmarkRead(b *testing.B, overlapped bool) { + name := filepath.Join(b.TempDir(), "foo") + const content = "hello world" + err := os.WriteFile(name, []byte(content), 0644) + if err != nil { + b.Fatal(err) + } + file := newFile(b, name, overlapped) + var buf [len(content)]byte + for b.Loop() { + _, err := io.ReadFull(file, buf[:]) + if err != nil { + b.Fatal(err) + } + if _, err := file.Seek(0, io.SeekStart); err != nil { + b.Fatal(err) + } + } +} diff --git a/src/internal/syscall/windows/syscall_windows.go b/src/internal/syscall/windows/syscall_windows.go index e4d42f3dae..af542c8003 100644 --- a/src/internal/syscall/windows/syscall_windows.go +++ b/src/internal/syscall/windows/syscall_windows.go @@ -504,6 +504,17 @@ func QueryPerformanceFrequency() int64 // Implemented in runtime package. //sys GetModuleHandle(modulename *uint16) (handle syscall.Handle, err error) = kernel32.GetModuleHandleW +const ( + PIPE_ACCESS_INBOUND = 0x00000001 + PIPE_ACCESS_OUTBOUND = 0x00000002 + PIPE_ACCESS_DUPLEX = 0x00000003 + + PIPE_TYPE_BYTE = 0x00000000 +) + +//sys GetOverlappedResult(handle syscall.Handle, overlapped *syscall.Overlapped, done *uint32, wait bool) (err error) +//sys CreateNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) [failretval==syscall.InvalidHandle] = CreateNamedPipeW + // NTStatus corresponds with NTSTATUS, error values returned by ntdll.dll and // other native functions. type NTStatus uint32 diff --git a/src/internal/syscall/windows/zsyscall_windows.go b/src/internal/syscall/windows/zsyscall_windows.go index f7b89e9ca3..d2d40440e3 100644 --- a/src/internal/syscall/windows/zsyscall_windows.go +++ b/src/internal/syscall/windows/zsyscall_windows.go @@ -66,6 +66,7 @@ var ( procProcessPrng = modbcryptprimitives.NewProc("ProcessPrng") procGetAdaptersAddresses = modiphlpapi.NewProc("GetAdaptersAddresses") procCreateEventW = modkernel32.NewProc("CreateEventW") + procCreateNamedPipeW = modkernel32.NewProc("CreateNamedPipeW") procGetACP = modkernel32.NewProc("GetACP") procGetComputerNameExW = modkernel32.NewProc("GetComputerNameExW") procGetConsoleCP = modkernel32.NewProc("GetConsoleCP") @@ -74,6 +75,7 @@ var ( procGetFinalPathNameByHandleW = modkernel32.NewProc("GetFinalPathNameByHandleW") procGetModuleFileNameW = modkernel32.NewProc("GetModuleFileNameW") procGetModuleHandleW = modkernel32.NewProc("GetModuleHandleW") + procGetOverlappedResult = modkernel32.NewProc("GetOverlappedResult") procGetTempPath2W = modkernel32.NewProc("GetTempPath2W") procGetVolumeInformationByHandleW = modkernel32.NewProc("GetVolumeInformationByHandleW") procGetVolumeNameForVolumeMountPointW = modkernel32.NewProc("GetVolumeNameForVolumeMountPointW") @@ -266,6 +268,15 @@ func CreateEvent(eventAttrs *SecurityAttributes, manualReset uint32, initialStat return } +func CreateNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *SecurityAttributes) (handle syscall.Handle, err error) { + r0, _, e1 := syscall.Syscall9(procCreateNamedPipeW.Addr(), 8, uintptr(unsafe.Pointer(name)), uintptr(flags), uintptr(pipeMode), uintptr(maxInstances), uintptr(outSize), uintptr(inSize), uintptr(defaultTimeout), uintptr(unsafe.Pointer(sa)), 0) + handle = syscall.Handle(r0) + if handle == syscall.InvalidHandle { + err = errnoErr(e1) + } + return +} + func GetACP() (acp uint32) { r0, _, _ := syscall.Syscall(procGetACP.Addr(), 0, 0, 0, 0) acp = uint32(r0) @@ -330,6 +341,18 @@ func GetModuleHandle(modulename *uint16) (handle syscall.Handle, err error) { return } +func GetOverlappedResult(handle syscall.Handle, overlapped *syscall.Overlapped, done *uint32, wait bool) (err error) { + var _p0 uint32 + if wait { + _p0 = 1 + } + r1, _, e1 := syscall.Syscall6(procGetOverlappedResult.Addr(), 4, uintptr(handle), uintptr(unsafe.Pointer(overlapped)), uintptr(unsafe.Pointer(done)), uintptr(_p0), 0, 0) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + func GetTempPath2(buflen uint32, buf *uint16) (n uint32, err error) { r0, _, e1 := syscall.Syscall(procGetTempPath2W.Addr(), 2, uintptr(buflen), uintptr(unsafe.Pointer(buf)), 0) n = uint32(r0)