diff --git a/src/pkg/net/http/httputil/reverseproxy.go b/src/pkg/net/http/httputil/reverseproxy.go index 9c4bd6e09a..2f08a8c0c9 100644 --- a/src/pkg/net/http/httputil/reverseproxy.go +++ b/src/pkg/net/http/httputil/reverseproxy.go @@ -17,6 +17,10 @@ import ( "time" ) +// beforeCopyResponse is a callback set by tests to intercept the state of the +// output io.Writer before the data is copied to it. +var beforeCopyResponse func(dst io.Writer) + // ReverseProxy is an HTTP Handler that takes an incoming request and // sends it to another server, proxying the response back to the // client. @@ -112,20 +116,32 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusInternalServerError) return } + defer res.Body.Close() copyHeader(rw.Header(), res.Header) rw.WriteHeader(res.StatusCode) + p.copyResponse(rw, res.Body) +} - if res.Body != nil { - var dst io.Writer = rw - if p.FlushInterval != 0 { - if wf, ok := rw.(writeFlusher); ok { - dst = &maxLatencyWriter{dst: wf, latency: p.FlushInterval} +func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { + if p.FlushInterval != 0 { + if wf, ok := dst.(writeFlusher); ok { + mlw := &maxLatencyWriter{ + dst: wf, + latency: p.FlushInterval, + done: make(chan bool), } + go mlw.flushLoop() + defer mlw.stop() + dst = mlw } - io.Copy(dst, res.Body) } + + if beforeCopyResponse != nil { + beforeCopyResponse(dst) + } + io.Copy(dst, src) } type writeFlusher interface { @@ -137,22 +153,14 @@ type maxLatencyWriter struct { dst writeFlusher latency time.Duration - lk sync.Mutex // protects init of done, as well Write + Flush + lk sync.Mutex // protects Write + Flush done chan bool } -func (m *maxLatencyWriter) Write(p []byte) (n int, err error) { +func (m *maxLatencyWriter) Write(p []byte) (int, error) { m.lk.Lock() defer m.lk.Unlock() - if m.done == nil { - m.done = make(chan bool) - go m.flushLoop() - } - n, err = m.dst.Write(p) - if err != nil { - m.done <- true - } - return + return m.dst.Write(p) } func (m *maxLatencyWriter) flushLoop() { @@ -160,13 +168,15 @@ func (m *maxLatencyWriter) flushLoop() { defer t.Stop() for { select { + case <-m.done: + return case <-t.C: m.lk.Lock() m.dst.Flush() m.lk.Unlock() - case <-m.done: - return } } panic("unreached") } + +func (m *maxLatencyWriter) stop() { m.done <- true } diff --git a/src/pkg/net/http/httputil/reverseproxy_test.go b/src/pkg/net/http/httputil/reverseproxy_test.go index 28e9c90ad3..3bcb23c077 100644 --- a/src/pkg/net/http/httputil/reverseproxy_test.go +++ b/src/pkg/net/http/httputil/reverseproxy_test.go @@ -7,11 +7,14 @@ package httputil import ( + "io" "io/ioutil" "net/http" "net/http/httptest" "net/url" + "runtime" "testing" + "time" ) func TestReverseProxy(t *testing.T) { @@ -107,3 +110,58 @@ func TestReverseProxyQuery(t *testing.T) { frontend.Close() } } + +func TestReverseProxyFlushInterval(t *testing.T) { + if testing.Short() { + return + } + + const expected = "hi" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(expected)) + })) + defer backend.Close() + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.FlushInterval = time.Microsecond + + dstChan := make(chan io.Writer, 1) + beforeCopyResponse = func(dst io.Writer) { dstChan <- dst } + defer func() { beforeCopyResponse = nil }() + + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + initGoroutines := runtime.NumGoroutine() + for i := 0; i < 100; i++ { + req, _ := http.NewRequest("GET", frontend.URL, nil) + req.Close = true + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected { + t.Errorf("got body %q; expected %q", bodyBytes, expected) + } + + select { + case dst := <-dstChan: + if _, ok := dst.(*maxLatencyWriter); !ok { + t.Errorf("got writer %T; expected %T", dst, &maxLatencyWriter{}) + } + default: + t.Error("maxLatencyWriter Write() was never called") + } + + res.Body.Close() + } + // Allow up to 50 additional goroutines over 100 requests. + if delta := runtime.NumGoroutine() - initGoroutines; delta > 50 { + t.Errorf("grew %d goroutines; leak?", delta) + } +}