mirror of https://github.com/golang/go.git
net: add shutdown: TCPConn.CloseWrite and CloseRead
R=golang-dev, rsc, iant CC=golang-dev https://golang.org/cl/5136052
This commit is contained in:
parent
260991ad5f
commit
394842e2a5
|
|
@ -358,6 +358,22 @@ func (fd *netFD) Close() os.Error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (fd *netFD) CloseRead() os.Error {
|
||||
if fd == nil || fd.sysfile == nil {
|
||||
return os.EINVAL
|
||||
}
|
||||
syscall.Shutdown(fd.sysfd, syscall.SHUT_RD)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fd *netFD) CloseWrite() os.Error {
|
||||
if fd == nil || fd.sysfile == nil {
|
||||
return os.EINVAL
|
||||
}
|
||||
syscall.Shutdown(fd.sysfd, syscall.SHUT_WR)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fd *netFD) Read(p []byte) (n int, err os.Error) {
|
||||
if fd == nil {
|
||||
return 0, os.EINVAL
|
||||
|
|
|
|||
|
|
@ -312,6 +312,22 @@ func (fd *netFD) Close() os.Error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (fd *netFD) CloseRead() os.Error {
|
||||
if fd == nil || fd.sysfd == syscall.InvalidHandle {
|
||||
return os.EINVAL
|
||||
}
|
||||
syscall.Shutdown(fd.sysfd, syscall.SHUT_RD)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fd *netFD) CloseWrite() os.Error {
|
||||
if fd == nil || fd.sysfd == syscall.InvalidHandle {
|
||||
return os.EINVAL
|
||||
}
|
||||
syscall.Shutdown(fd.sysfd, syscall.SHUT_WR)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read from network.
|
||||
|
||||
type readOp struct {
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ package net
|
|||
|
||||
import (
|
||||
"flag"
|
||||
"os"
|
||||
"regexp"
|
||||
"testing"
|
||||
)
|
||||
|
|
@ -119,3 +120,46 @@ func TestReverseAddress(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestShutdown(t *testing.T) {
|
||||
l, err := Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
if l, err = Listen("tcp6", "[::1]:0"); err != nil {
|
||||
t.Fatalf("ListenTCP on :0: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
c, err := l.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("Accept: %v", err)
|
||||
}
|
||||
var buf [10]byte
|
||||
n, err := c.Read(buf[:])
|
||||
if n != 0 || err != os.EOF {
|
||||
t.Fatalf("server Read = %d, %v; want 0, os.EOF", n, err)
|
||||
}
|
||||
c.Write([]byte("response"))
|
||||
c.Close()
|
||||
}()
|
||||
|
||||
c, err := Dial("tcp", l.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
err = c.(*TCPConn).CloseWrite()
|
||||
if err != nil {
|
||||
t.Fatalf("CloseWrite: %v", err)
|
||||
}
|
||||
var buf [10]byte
|
||||
n, err := c.Read(buf[:])
|
||||
if err != nil {
|
||||
t.Fatalf("client Read: %d, %v", n, err)
|
||||
}
|
||||
got := string(buf[:n])
|
||||
if got != "response" {
|
||||
t.Errorf("read = %q, want \"response\"", got)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -100,6 +100,24 @@ func (c *TCPConn) Close() os.Error {
|
|||
return err
|
||||
}
|
||||
|
||||
// CloseRead shuts down the reading side of the TCP connection.
|
||||
// Most callers should just use Close.
|
||||
func (c *TCPConn) CloseRead() os.Error {
|
||||
if !c.ok() {
|
||||
return os.EINVAL
|
||||
}
|
||||
return c.fd.CloseRead()
|
||||
}
|
||||
|
||||
// CloseWrite shuts down the writing side of the TCP connection.
|
||||
// Most callers should just use Close.
|
||||
func (c *TCPConn) CloseWrite() os.Error {
|
||||
if !c.ok() {
|
||||
return os.EINVAL
|
||||
}
|
||||
return c.fd.CloseWrite()
|
||||
}
|
||||
|
||||
// LocalAddr returns the local network address, a *TCPAddr.
|
||||
func (c *TCPConn) LocalAddr() Addr {
|
||||
if !c.ok() {
|
||||
|
|
|
|||
Loading…
Reference in New Issue