mirror of https://github.com/golang/go.git
net/http: don't reuse a server connection after any Write errors
Fixes #8534 LGTM=adg R=adg CC=golang-codereviews https://golang.org/cl/149340044
This commit is contained in:
parent
a681749ab5
commit
9d51cd0fee
|
|
@ -2659,6 +2659,103 @@ func TestCloseWrite(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// This verifies that a handler can Flush and then Hijack.
|
||||
//
|
||||
// An similar test crashed once during development, but it was only
|
||||
// testing this tangentially and temporarily until another TODO was
|
||||
// fixed.
|
||||
//
|
||||
// So add an explicit test for this.
|
||||
func TestServerFlushAndHijack(t *testing.T) {
|
||||
defer afterTest(t)
|
||||
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
|
||||
io.WriteString(w, "Hello, ")
|
||||
w.(Flusher).Flush()
|
||||
conn, buf, _ := w.(Hijacker).Hijack()
|
||||
buf.WriteString("6\r\nworld!\r\n0\r\n\r\n")
|
||||
if err := buf.Flush(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}))
|
||||
defer ts.Close()
|
||||
res, err := Get(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
all, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if want := "Hello, world!"; string(all) != want {
|
||||
t.Errorf("Got %q; want %q", all, want)
|
||||
}
|
||||
}
|
||||
|
||||
// golang.org/issue/8534 -- the Server shouldn't reuse a connection
|
||||
// for keep-alive after it's seen any Write error (e.g. a timeout) on
|
||||
// that net.Conn.
|
||||
//
|
||||
// To test, verify we don't timeout or see fewer unique client
|
||||
// addresses (== unique connections) than requests.
|
||||
func TestServerKeepAliveAfterWriteError(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping in -short mode")
|
||||
}
|
||||
defer afterTest(t)
|
||||
const numReq = 3
|
||||
addrc := make(chan string, numReq)
|
||||
ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
|
||||
addrc <- r.RemoteAddr
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
w.(Flusher).Flush()
|
||||
}))
|
||||
ts.Config.WriteTimeout = 250 * time.Millisecond
|
||||
ts.Start()
|
||||
defer ts.Close()
|
||||
|
||||
errc := make(chan error, numReq)
|
||||
go func() {
|
||||
defer close(errc)
|
||||
for i := 0; i < numReq; i++ {
|
||||
res, err := Get(ts.URL)
|
||||
if res != nil {
|
||||
res.Body.Close()
|
||||
}
|
||||
errc <- err
|
||||
}
|
||||
}()
|
||||
|
||||
timeout := time.NewTimer(numReq * 2 * time.Second) // 4x overkill
|
||||
defer timeout.Stop()
|
||||
addrSeen := map[string]bool{}
|
||||
numOkay := 0
|
||||
for {
|
||||
select {
|
||||
case v := <-addrc:
|
||||
addrSeen[v] = true
|
||||
case err, ok := <-errc:
|
||||
if !ok {
|
||||
if len(addrSeen) != numReq {
|
||||
t.Errorf("saw %d unique client addresses; want %d", len(addrSeen), numReq)
|
||||
}
|
||||
if numOkay != 0 {
|
||||
t.Errorf("got %d successful client requests; want 0", numOkay)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err == nil {
|
||||
numOkay++
|
||||
}
|
||||
case <-timeout.C:
|
||||
t.Fatal("timeout waiting for requests to complete")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkClientServer(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.StopTimer()
|
||||
|
|
|
|||
|
|
@ -114,6 +114,8 @@ type conn struct {
|
|||
remoteAddr string // network address of remote side
|
||||
server *Server // the Server on which the connection arrived
|
||||
rwc net.Conn // i/o connection
|
||||
w io.Writer // checkConnErrorWriter's copy of wrc, not zeroed on Hijack
|
||||
werr error // any errors writing to w
|
||||
sr liveSwitchReader // where the LimitReader reads from; usually the rwc
|
||||
lr *io.LimitedReader // io.LimitReader(sr)
|
||||
buf *bufio.ReadWriter // buffered(lr,rwc), reading from bufio->limitReader->sr->rwc
|
||||
|
|
@ -432,13 +434,14 @@ func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) {
|
|||
c.remoteAddr = rwc.RemoteAddr().String()
|
||||
c.server = srv
|
||||
c.rwc = rwc
|
||||
c.w = rwc
|
||||
if debugServerConnections {
|
||||
c.rwc = newLoggingConn("server", c.rwc)
|
||||
}
|
||||
c.sr = liveSwitchReader{r: c.rwc}
|
||||
c.lr = io.LimitReader(&c.sr, noLimit).(*io.LimitedReader)
|
||||
br := newBufioReader(c.lr)
|
||||
bw := newBufioWriterSize(c.rwc, 4<<10)
|
||||
bw := newBufioWriterSize(checkConnErrorWriter{c}, 4<<10)
|
||||
c.buf = bufio.NewReadWriter(br, bw)
|
||||
return c, nil
|
||||
}
|
||||
|
|
@ -956,8 +959,10 @@ func (w *response) bodyAllowed() bool {
|
|||
// 2. (*response).w, a *bufio.Writer of bufferBeforeChunkingSize bytes
|
||||
// 3. chunkWriter.Writer (whose writeHeader finalizes Content-Length/Type)
|
||||
// and which writes the chunk headers, if needed.
|
||||
// 4. conn.buf, a bufio.Writer of default (4kB) bytes
|
||||
// 5. the rwc, the net.Conn.
|
||||
// 4. conn.buf, a bufio.Writer of default (4kB) bytes, writing to ->
|
||||
// 5. checkConnErrorWriter{c}, which notes any non-nil error on Write
|
||||
// and populates c.werr with it if so. but otherwise writes to:
|
||||
// 6. the rwc, the net.Conn.
|
||||
//
|
||||
// TODO(bradfitz): short-circuit some of the buffering when the
|
||||
// initial header contains both a Content-Type and Content-Length.
|
||||
|
|
@ -1027,6 +1032,12 @@ func (w *response) finishRequest() {
|
|||
// Did not write enough. Avoid getting out of sync.
|
||||
w.closeAfterReply = true
|
||||
}
|
||||
|
||||
// There was some error writing to the underlying connection
|
||||
// during the request, so don't re-use this conn.
|
||||
if w.conn.werr != nil {
|
||||
w.closeAfterReply = true
|
||||
}
|
||||
}
|
||||
|
||||
func (w *response) Flush() {
|
||||
|
|
@ -2068,3 +2079,18 @@ func (c *loggingConn) Close() (err error) {
|
|||
log.Printf("%s.Close() = %v", c.name, err)
|
||||
return
|
||||
}
|
||||
|
||||
// checkConnErrorWriter writes to c.rwc and records any write errors to c.werr.
|
||||
// It only contains one field (and a pointer field at that), so it
|
||||
// fits in an interface value without an extra allocation.
|
||||
type checkConnErrorWriter struct {
|
||||
c *conn
|
||||
}
|
||||
|
||||
func (w checkConnErrorWriter) Write(p []byte) (n int, err error) {
|
||||
n, err = w.c.w.Write(p) // c.w == c.rwc, except after a hijack, when rwc is nil.
|
||||
if err != nil && w.c.werr == nil {
|
||||
w.c.werr = err
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue