mirror of https://github.com/golang/go.git
net/http/httputil: use response controller in reverse proxy
Previously, the reverse proxy is unable to detect the support for hijack or flush if those things are residing in the response writer in a wrapped manner. The reverse proxy now makes use of the new http response controller as the means to discover the underlying flusher and hijacker associated with the response writer, allowing wrapped flusher and hijacker become discoverable. Change-Id: I53acbb12315c3897be068e8c00598ef42fc74649 Reviewed-on: https://go-review.googlesource.com/c/go/+/468755 Run-TryBot: Damien Neil <dneil@google.com> Auto-Submit: Damien Neil <dneil@google.com> TryBot-Result: Gopher Robot <gobot@golang.org> Reviewed-by: Damien Neil <dneil@google.com> Reviewed-by: Cherry Mui <cherryyz@google.com>
This commit is contained in:
parent
602e6aa979
commit
2449bbb5e6
|
|
@ -524,9 +524,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||
// Force chunking if we saw a response trailer.
|
||||
// This prevents net/http from calculating the length for short
|
||||
// bodies and adding a Content-Length.
|
||||
if fl, ok := rw.(http.Flusher); ok {
|
||||
fl.Flush()
|
||||
}
|
||||
http.NewResponseController(rw).Flush()
|
||||
}
|
||||
|
||||
if len(res.Trailer) == announcedTrailers {
|
||||
|
|
@ -601,21 +599,22 @@ func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration {
|
|||
return p.FlushInterval
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error {
|
||||
func (p *ReverseProxy) copyResponse(dst http.ResponseWriter, src io.Reader, flushInterval time.Duration) error {
|
||||
var w io.Writer = dst
|
||||
|
||||
if flushInterval != 0 {
|
||||
if wf, ok := dst.(writeFlusher); ok {
|
||||
mlw := &maxLatencyWriter{
|
||||
dst: wf,
|
||||
latency: flushInterval,
|
||||
}
|
||||
defer mlw.stop()
|
||||
|
||||
// set up initial timer so headers get flushed even if body writes are delayed
|
||||
mlw.flushPending = true
|
||||
mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
|
||||
|
||||
dst = mlw
|
||||
mlw := &maxLatencyWriter{
|
||||
dst: dst,
|
||||
flush: http.NewResponseController(dst).Flush,
|
||||
latency: flushInterval,
|
||||
}
|
||||
defer mlw.stop()
|
||||
|
||||
// set up initial timer so headers get flushed even if body writes are delayed
|
||||
mlw.flushPending = true
|
||||
mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
|
||||
|
||||
w = mlw
|
||||
}
|
||||
|
||||
var buf []byte
|
||||
|
|
@ -623,7 +622,7 @@ func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval
|
|||
buf = p.BufferPool.Get()
|
||||
defer p.BufferPool.Put(buf)
|
||||
}
|
||||
_, err := p.copyBuffer(dst, src, buf)
|
||||
_, err := p.copyBuffer(w, src, buf)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -668,13 +667,9 @@ func (p *ReverseProxy) logf(format string, args ...any) {
|
|||
}
|
||||
}
|
||||
|
||||
type writeFlusher interface {
|
||||
io.Writer
|
||||
http.Flusher
|
||||
}
|
||||
|
||||
type maxLatencyWriter struct {
|
||||
dst writeFlusher
|
||||
dst io.Writer
|
||||
flush func() error
|
||||
latency time.Duration // non-zero; negative means to flush immediately
|
||||
|
||||
mu sync.Mutex // protects t, flushPending, and dst.Flush
|
||||
|
|
@ -687,7 +682,7 @@ func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
|
|||
defer m.mu.Unlock()
|
||||
n, err = m.dst.Write(p)
|
||||
if m.latency < 0 {
|
||||
m.dst.Flush()
|
||||
m.flush()
|
||||
return
|
||||
}
|
||||
if m.flushPending {
|
||||
|
|
@ -708,7 +703,7 @@ func (m *maxLatencyWriter) delayedFlush() {
|
|||
if !m.flushPending { // if stop was called but AfterFunc already started this goroutine
|
||||
return
|
||||
}
|
||||
m.dst.Flush()
|
||||
m.flush()
|
||||
m.flushPending = false
|
||||
}
|
||||
|
||||
|
|
@ -739,17 +734,19 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R
|
|||
return
|
||||
}
|
||||
|
||||
hj, ok := rw.(http.Hijacker)
|
||||
if !ok {
|
||||
p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
|
||||
return
|
||||
}
|
||||
backConn, ok := res.Body.(io.ReadWriteCloser)
|
||||
if !ok {
|
||||
p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
|
||||
return
|
||||
}
|
||||
|
||||
rc := http.NewResponseController(rw)
|
||||
conn, brw, hijackErr := rc.Hijack()
|
||||
if errors.Is(hijackErr, http.ErrNotSupported) {
|
||||
p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
|
||||
return
|
||||
}
|
||||
|
||||
backConnCloseCh := make(chan bool)
|
||||
go func() {
|
||||
// Ensure that the cancellation of a request closes the backend.
|
||||
|
|
@ -760,12 +757,10 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R
|
|||
}
|
||||
backConn.Close()
|
||||
}()
|
||||
|
||||
defer close(backConnCloseCh)
|
||||
|
||||
conn, brw, err := hj.Hijack()
|
||||
if err != nil {
|
||||
p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err))
|
||||
if hijackErr != nil {
|
||||
p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", hijackErr))
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
|
|
|||
|
|
@ -478,6 +478,62 @@ func TestReverseProxyFlushInterval(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
type mockFlusher struct {
|
||||
http.ResponseWriter
|
||||
flushed bool
|
||||
}
|
||||
|
||||
func (m *mockFlusher) Flush() {
|
||||
m.flushed = true
|
||||
}
|
||||
|
||||
type wrappedRW struct {
|
||||
http.ResponseWriter
|
||||
}
|
||||
|
||||
func (w *wrappedRW) Unwrap() http.ResponseWriter {
|
||||
return w.ResponseWriter
|
||||
}
|
||||
|
||||
func TestReverseProxyResponseControllerFlushInterval(t *testing.T) {
|
||||
const expected = "hi"
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(expected))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
backendURL, err := url.Parse(backend.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mf := &mockFlusher{}
|
||||
proxyHandler := NewSingleHostReverseProxy(backendURL)
|
||||
proxyHandler.FlushInterval = -1 // flush immediately
|
||||
proxyWithMiddleware := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mf.ResponseWriter = w
|
||||
w = &wrappedRW{mf}
|
||||
proxyHandler.ServeHTTP(w, r)
|
||||
})
|
||||
|
||||
frontend := httptest.NewServer(proxyWithMiddleware)
|
||||
defer frontend.Close()
|
||||
|
||||
req, _ := http.NewRequest("GET", frontend.URL, nil)
|
||||
req.Close = true
|
||||
res, err := frontend.Client().Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Get: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected {
|
||||
t.Errorf("got body %q; expected %q", bodyBytes, expected)
|
||||
}
|
||||
if !mf.flushed {
|
||||
t.Errorf("response writer was not flushed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyFlushIntervalHeaders(t *testing.T) {
|
||||
const expected = "hi"
|
||||
stopCh := make(chan struct{})
|
||||
|
|
|
|||
Loading…
Reference in New Issue