From 79bb493dcbb9b76d9d2ff9cd0854b29d634f8b73 Mon Sep 17 00:00:00 2001 From: Julien Salleyron Date: Mon, 24 Dec 2018 09:07:07 +0100 Subject: [PATCH] Add rw headers to upgrade response --- src/net/http/httputil/reverseproxy.go | 3 +++ src/net/http/httputil/reverseproxy_test.go | 18 +++++++++++++++--- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/src/net/http/httputil/reverseproxy.go b/src/net/http/httputil/reverseproxy.go index 5d07ba3d36..c13b99ff72 100644 --- a/src/net/http/httputil/reverseproxy.go +++ b/src/net/http/httputil/reverseproxy.go @@ -497,6 +497,9 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType)) return } + + copyHeader(res.Header, rw.Header()) + hj, ok := rw.(http.Hijacker) if !ok { p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) diff --git a/src/net/http/httputil/reverseproxy_test.go b/src/net/http/httputil/reverseproxy_test.go index 5caa206066..bda569acc7 100644 --- a/src/net/http/httputil/reverseproxy_test.go +++ b/src/net/http/httputil/reverseproxy_test.go @@ -1013,7 +1013,12 @@ func TestReverseProxyWebSocket(t *testing.T) { rproxy := NewSingleHostReverseProxy(backURL) rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests - frontendProxy := httptest.NewServer(rproxy) + handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("X-Header", "X-Value") + rproxy.ServeHTTP(rw, req) + }) + + frontendProxy := httptest.NewServer(handler) defer frontendProxy.Close() req, _ := http.NewRequest("GET", frontendProxy.URL, nil) @@ -1028,6 +1033,13 @@ func TestReverseProxyWebSocket(t *testing.T) { if res.StatusCode != 101 { t.Fatalf("status = %v; want 101", res.Status) } + + got := res.Header.Get("X-Header") + want := "X-Value" + if got != want { + t.Errorf("Header(XHeader) = %q; want %q", got, want) + } + if upgradeType(res.Header) != "websocket" { t.Fatalf("not websocket upgrade; got %#v", res.Header) } @@ -1042,8 +1054,8 @@ func TestReverseProxyWebSocket(t *testing.T) { if !bs.Scan() { t.Fatalf("Scan: %v", bs.Err()) } - got := bs.Text() - want := `backend got "Hello"` + got = bs.Text() + want = `backend got "Hello"` if got != want { t.Errorf("got %#q, want %#q", got, want) }