diff --git a/src/net/http/httputil/reverseproxy.go b/src/net/http/httputil/reverseproxy.go index eece455ac6..2a76b0b8dc 100644 --- a/src/net/http/httputil/reverseproxy.go +++ b/src/net/http/httputil/reverseproxy.go @@ -524,9 +524,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { // Force chunking if we saw a response trailer. // This prevents net/http from calculating the length for short // bodies and adding a Content-Length. - if fl, ok := rw.(http.Flusher); ok { - fl.Flush() - } + http.NewResponseController(rw).Flush() } if len(res.Trailer) == announcedTrailers { @@ -601,21 +599,22 @@ func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration { return p.FlushInterval } -func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error { +func (p *ReverseProxy) copyResponse(dst http.ResponseWriter, src io.Reader, flushInterval time.Duration) error { + var w io.Writer = dst + if flushInterval != 0 { - if wf, ok := dst.(writeFlusher); ok { - mlw := &maxLatencyWriter{ - dst: wf, - latency: flushInterval, - } - defer mlw.stop() - - // set up initial timer so headers get flushed even if body writes are delayed - mlw.flushPending = true - mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush) - - dst = mlw + mlw := &maxLatencyWriter{ + dst: dst, + flush: http.NewResponseController(dst).Flush, + latency: flushInterval, } + defer mlw.stop() + + // set up initial timer so headers get flushed even if body writes are delayed + mlw.flushPending = true + mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush) + + w = mlw } var buf []byte @@ -623,7 +622,7 @@ func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval buf = p.BufferPool.Get() defer p.BufferPool.Put(buf) } - _, err := p.copyBuffer(dst, src, buf) + _, err := p.copyBuffer(w, src, buf) return err } @@ -668,13 +667,9 @@ func (p *ReverseProxy) logf(format string, args ...any) { } } -type writeFlusher interface { - io.Writer - http.Flusher -} - type maxLatencyWriter struct { - dst writeFlusher + dst io.Writer + flush func() error latency time.Duration // non-zero; negative means to flush immediately mu sync.Mutex // protects t, flushPending, and dst.Flush @@ -687,7 +682,7 @@ func (m *maxLatencyWriter) Write(p []byte) (n int, err error) { defer m.mu.Unlock() n, err = m.dst.Write(p) if m.latency < 0 { - m.dst.Flush() + m.flush() return } if m.flushPending { @@ -708,7 +703,7 @@ func (m *maxLatencyWriter) delayedFlush() { if !m.flushPending { // if stop was called but AfterFunc already started this goroutine return } - m.dst.Flush() + m.flush() m.flushPending = false } @@ -739,17 +734,19 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R return } - hj, ok := rw.(http.Hijacker) - if !ok { - p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) - return - } backConn, ok := res.Body.(io.ReadWriteCloser) if !ok { p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body")) return } + rc := http.NewResponseController(rw) + conn, brw, hijackErr := rc.Hijack() + if errors.Is(hijackErr, http.ErrNotSupported) { + p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) + return + } + backConnCloseCh := make(chan bool) go func() { // Ensure that the cancellation of a request closes the backend. @@ -760,12 +757,10 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R } backConn.Close() }() - defer close(backConnCloseCh) - conn, brw, err := hj.Hijack() - if err != nil { - p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err)) + if hijackErr != nil { + p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", hijackErr)) return } defer conn.Close() diff --git a/src/net/http/httputil/reverseproxy_test.go b/src/net/http/httputil/reverseproxy_test.go index d5b0fb4244..dd3330b615 100644 --- a/src/net/http/httputil/reverseproxy_test.go +++ b/src/net/http/httputil/reverseproxy_test.go @@ -478,6 +478,62 @@ func TestReverseProxyFlushInterval(t *testing.T) { } } +type mockFlusher struct { + http.ResponseWriter + flushed bool +} + +func (m *mockFlusher) Flush() { + m.flushed = true +} + +type wrappedRW struct { + http.ResponseWriter +} + +func (w *wrappedRW) Unwrap() http.ResponseWriter { + return w.ResponseWriter +} + +func TestReverseProxyResponseControllerFlushInterval(t *testing.T) { + const expected = "hi" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(expected)) + })) + defer backend.Close() + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + mf := &mockFlusher{} + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.FlushInterval = -1 // flush immediately + proxyWithMiddleware := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mf.ResponseWriter = w + w = &wrappedRW{mf} + proxyHandler.ServeHTTP(w, r) + }) + + frontend := httptest.NewServer(proxyWithMiddleware) + defer frontend.Close() + + req, _ := http.NewRequest("GET", frontend.URL, nil) + req.Close = true + res, err := frontend.Client().Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected { + t.Errorf("got body %q; expected %q", bodyBytes, expected) + } + if !mf.flushed { + t.Errorf("response writer was not flushed") + } +} + func TestReverseProxyFlushIntervalHeaders(t *testing.T) { const expected = "hi" stopCh := make(chan struct{})