diff --git a/src/net/http/client_test.go b/src/net/http/client_test.go index 44b532ae1f..8b53c41687 100644 --- a/src/net/http/client_test.go +++ b/src/net/http/client_test.go @@ -67,11 +67,9 @@ func (w chanWriter) Write(p []byte) (n int, err error) { return len(p), nil } -func TestClient(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(robotsTxtHandler) - defer ts.Close() +func TestClient(t *testing.T) { run(t, testClient) } +func testClient(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, robotsTxtHandler).ts c := ts.Client() r, err := c.Get(ts.URL) @@ -87,14 +85,9 @@ func TestClient(t *testing.T) { } } -func TestClientHead_h1(t *testing.T) { testClientHead(t, h1Mode) } -func TestClientHead_h2(t *testing.T) { testClientHead(t, h2Mode) } - -func testClientHead(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, robotsTxtHandler) - defer cst.close() - +func TestClientHead(t *testing.T) { run(t, testClientHead) } +func testClientHead(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, robotsTxtHandler) r, err := cst.c.Head(cst.ts.URL) if err != nil { t.Fatal(err) @@ -200,11 +193,10 @@ func TestPostFormRequestFormat(t *testing.T) { } } -func TestClientRedirects(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestClientRedirects(t *testing.T) { run(t, testClientRedirects) } +func testClientRedirects(t *testing.T, mode testMode) { var ts *httptest.Server - ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts = newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { n, _ := strconv.Atoi(r.FormValue("n")) // Test Referer header. (7 is arbitrary position to test at) if n == 7 { @@ -217,8 +209,7 @@ func TestClientRedirects(t *testing.T) { return } fmt.Fprintf(w, "n=%d", n) - })) - defer ts.Close() + })).ts c := ts.Client() _, err := c.Get(ts.URL) @@ -299,13 +290,11 @@ func TestClientRedirects(t *testing.T) { } // Tests that Client redirects' contexts are derived from the original request's context. -func TestClientRedirectContext(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestClientRedirectsContext(t *testing.T) { run(t, testClientRedirectsContext) } +func testClientRedirectsContext(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { Redirect(w, r, "/", StatusTemporaryRedirect) - })) - defer ts.Close() + })).ts ctx, cancel := context.WithCancel(context.Background()) c := ts.Client() @@ -373,7 +362,9 @@ func TestPostRedirects(t *testing.T) { `POST /?code=404 "c404"`, } want := strings.Join(wantSegments, "\n") - testRedirectsByMethod(t, "POST", postRedirectTests, want) + run(t, func(t *testing.T, mode testMode) { + testRedirectsByMethod(t, mode, "POST", postRedirectTests, want) + }) } func TestDeleteRedirects(t *testing.T) { @@ -410,17 +401,18 @@ func TestDeleteRedirects(t *testing.T) { `DELETE /?code=404 "c404"`, } want := strings.Join(wantSegments, "\n") - testRedirectsByMethod(t, "DELETE", deleteRedirectTests, want) + run(t, func(t *testing.T, mode testMode) { + testRedirectsByMethod(t, mode, "DELETE", deleteRedirectTests, want) + }) } -func testRedirectsByMethod(t *testing.T, method string, table []redirectTest, want string) { - defer afterTest(t) +func testRedirectsByMethod(t *testing.T, mode testMode, method string, table []redirectTest, want string) { var log struct { sync.Mutex bytes.Buffer } var ts *httptest.Server - ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts = newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { log.Lock() slurp, _ := io.ReadAll(r.Body) fmt.Fprintf(&log.Buffer, "%s %s %q", r.Method, r.RequestURI, slurp) @@ -445,8 +437,7 @@ func testRedirectsByMethod(t *testing.T, method string, table []redirectTest, wa } w.WriteHeader(code) } - })) - defer ts.Close() + })).ts c := ts.Client() for _, tt := range table { @@ -491,12 +482,11 @@ func removeCommonLines(a, b string) (asuffix, bsuffix string, commonLines int) { } } -func TestClientRedirectUseResponse(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestClientRedirectUseResponse(t *testing.T) { run(t, testClientRedirectUseResponse) } +func testClientRedirectUseResponse(t *testing.T, mode testMode) { const body = "Hello, world." var ts *httptest.Server - ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts = newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if strings.Contains(r.URL.Path, "/other") { io.WriteString(w, "wrong body") } else { @@ -504,8 +494,7 @@ func TestClientRedirectUseResponse(t *testing.T) { w.WriteHeader(StatusFound) io.WriteString(w, body) } - })) - defer ts.Close() + })).ts c := ts.Client() c.CheckRedirect = func(req *Request, via []*Request) error { @@ -533,18 +522,16 @@ func TestClientRedirectUseResponse(t *testing.T) { // Issues 17773 and 49281: don't follow a 3xx if the response doesn't // have a Location header. -func TestClientRedirectNoLocation(t *testing.T) { +func TestClientRedirectNoLocation(t *testing.T) { run(t, testClientRedirectNoLocation) } +func testClientRedirectNoLocation(t *testing.T, mode testMode) { for _, code := range []int{301, 308} { t.Run(fmt.Sprint(code), func(t *testing.T) { setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Foo", "Bar") w.WriteHeader(code) })) - defer ts.Close() - c := ts.Client() - res, err := c.Get(ts.URL) + res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) } @@ -560,15 +547,13 @@ func TestClientRedirectNoLocation(t *testing.T) { } // Don't follow a 307/308 if we can't resent the request body. -func TestClientRedirect308NoGetBody(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestClientRedirect308NoGetBody(t *testing.T) { run(t, testClientRedirect308NoGetBody) } +func testClientRedirect308NoGetBody(t *testing.T, mode testMode) { const fakeURL = "https://localhost:1234/" // won't be hit - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Location", fakeURL) w.WriteHeader(308) - })) - defer ts.Close() + })).ts req, err := NewRequest("POST", ts.URL, strings.NewReader("some body")) if err != nil { t.Fatal(err) @@ -659,12 +644,10 @@ func (j *TestJar) Cookies(u *url.URL) []*Cookie { return j.perURL[u.Host] } -func TestRedirectCookiesJar(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestRedirectCookiesJar(t *testing.T) { run(t, testRedirectCookiesJar) } +func testRedirectCookiesJar(t *testing.T, mode testMode) { var ts *httptest.Server - ts = httptest.NewServer(echoCookiesRedirectHandler) - defer ts.Close() + ts = newClientServerTest(t, mode, echoCookiesRedirectHandler).ts c := ts.Client() c.Jar = new(TestJar) u, _ := url.Parse(ts.URL) @@ -696,9 +679,9 @@ func matchReturnedCookies(t *testing.T, expected, given []*Cookie) { } } -func TestJarCalls(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestJarCalls(t *testing.T) { run(t, testJarCalls, []testMode{http1Mode}) } +func testJarCalls(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { pathSuffix := r.RequestURI[1:] if r.RequestURI == "/nosetcookie" { return // don't set cookies for this path @@ -707,8 +690,7 @@ func TestJarCalls(t *testing.T) { if r.RequestURI == "/" { Redirect(w, r, "http://secondhost.fake/secondpath", 302) } - })) - defer ts.Close() + })).ts jar := new(RecordingJar) c := ts.Client() c.Jar = jar @@ -757,20 +739,16 @@ func (j *RecordingJar) logf(format string, args ...any) { fmt.Fprintf(&j.log, format, args...) } -func TestStreamingGet_h1(t *testing.T) { testStreamingGet(t, h1Mode) } -func TestStreamingGet_h2(t *testing.T) { testStreamingGet(t, h2Mode) } - -func testStreamingGet(t *testing.T, h2 bool) { - defer afterTest(t) +func TestStreamingGet(t *testing.T) { run(t, testStreamingGet) } +func testStreamingGet(t *testing.T, mode testMode) { say := make(chan string) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.(Flusher).Flush() for str := range say { w.Write([]byte(str)) w.(Flusher).Flush() } })) - defer cst.close() c := cst.c res, err := c.Get(cst.ts.URL) @@ -811,11 +789,10 @@ func (c *writeCountingConn) Write(p []byte) (int, error) { // TestClientWrites verifies that client requests are buffered and we // don't send a TCP packet per line of the http request + body. -func TestClientWrites(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - })) - defer ts.Close() +func TestClientWrites(t *testing.T) { run(t, testClientWrites, []testMode{http1Mode}) } +func testClientWrites(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + })).ts writes := 0 dialer := func(netz string, addr string) (net.Conn, error) { @@ -847,11 +824,12 @@ func TestClientWrites(t *testing.T) { } func TestClientInsecureTransport(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testClientInsecureTransport, []testMode{https1Mode, http2Mode}) +} +func testClientInsecureTransport(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("Hello")) - })) + })).ts errc := make(chanWriter, 10) // but only expecting 1 ts.Config.ErrorLog = log.New(errc, "", 0) defer ts.Close() @@ -898,15 +876,15 @@ func TestClientErrorWithRequestURI(t *testing.T) { } func TestClientWithCorrectTLSServerName(t *testing.T) { - defer afterTest(t) - + run(t, testClientWithCorrectTLSServerName, []testMode{https1Mode, http2Mode}) +} +func testClientWithCorrectTLSServerName(t *testing.T, mode testMode) { const serverName = "example.com" - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.TLS.ServerName != serverName { t.Errorf("expected client to set ServerName %q, got: %q", serverName, r.TLS.ServerName) } - })) - defer ts.Close() + })).ts c := ts.Client() c.Transport.(*Transport).TLSClientConfig.ServerName = serverName @@ -916,9 +894,10 @@ func TestClientWithCorrectTLSServerName(t *testing.T) { } func TestClientWithIncorrectTLSServerName(t *testing.T) { - defer afterTest(t) - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) - defer ts.Close() + run(t, testClientWithIncorrectTLSServerName, []testMode{https1Mode, http2Mode}) +} +func testClientWithIncorrectTLSServerName(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts errc := make(chanWriter, 10) // but only expecting 1 ts.Config.ErrorLog = log.New(errc, "", 0) @@ -951,11 +930,12 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) { // // The httptest.Server has a cert with "example.com" as its name. func TestTransportUsesTLSConfigServerName(t *testing.T) { - defer afterTest(t) - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportUsesTLSConfigServerName, []testMode{https1Mode, http2Mode}) +} +func testTransportUsesTLSConfigServerName(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("Hello")) - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -971,11 +951,12 @@ func TestTransportUsesTLSConfigServerName(t *testing.T) { } func TestResponseSetsTLSConnectionState(t *testing.T) { - defer afterTest(t) - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testResponseSetsTLSConnectionState, []testMode{https1Mode}) +} +func testResponseSetsTLSConnectionState(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("Hello")) - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -1001,10 +982,11 @@ func TestResponseSetsTLSConnectionState(t *testing.T) { // to determine that the server is speaking HTTP. // See golang.org/issue/11111. func TestHTTPSClientDetectsHTTPServer(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + run(t, testHTTPSClientDetectsHTTPServer, []testMode{http1Mode}) +} +func testHTTPSClientDetectsHTTPServer(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts ts.Config.ErrorLog = quietLog - defer ts.Close() _, err := Get(strings.Replace(ts.URL, "http", "https", 1)) if got := err.Error(); !strings.Contains(got, "HTTP response to HTTPS client") { @@ -1013,22 +995,13 @@ func TestHTTPSClientDetectsHTTPServer(t *testing.T) { } // Verify Response.ContentLength is populated. https://golang.org/issue/4126 -func TestClientHeadContentLength_h1(t *testing.T) { - testClientHeadContentLength(t, h1Mode) -} - -func TestClientHeadContentLength_h2(t *testing.T) { - testClientHeadContentLength(t, h2Mode) -} - -func testClientHeadContentLength(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestClientHeadContentLength(t *testing.T) { run(t, testClientHeadContentLength) } +func testClientHeadContentLength(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if v := r.FormValue("cl"); v != "" { w.Header().Set("Content-Length", v) } })) - defer cst.close() tests := []struct { suffix string want int64 @@ -1056,11 +1029,10 @@ func testClientHeadContentLength(t *testing.T, h2 bool) { } } -func TestEmptyPasswordAuth(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestEmptyPasswordAuth(t *testing.T) { run(t, testEmptyPasswordAuth) } +func testEmptyPasswordAuth(t *testing.T, mode testMode) { gopher := "gopher" - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { auth := r.Header.Get("Authorization") if strings.HasPrefix(auth, "Basic ") { encoded := auth[6:] @@ -1076,7 +1048,7 @@ func TestEmptyPasswordAuth(t *testing.T) { } else { t.Errorf("Invalid auth %q", auth) } - })) + })).ts defer ts.Close() req, err := NewRequest("GET", ts.URL, nil) if err != nil { @@ -1205,19 +1177,14 @@ func TestStripPasswordFromError(t *testing.T) { } } -func TestClientTimeout_h1(t *testing.T) { testClientTimeout(t, h1Mode) } -func TestClientTimeout_h2(t *testing.T) { testClientTimeout(t, h2Mode) } - -func testClientTimeout(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - +func TestClientTimeout(t *testing.T) { run(t, testClientTimeout) } +func testClientTimeout(t *testing.T, mode testMode) { var ( mu sync.Mutex nonce string // a unique per-request string sawSlowNonce bool // true if the handler saw /slow?nonce= ) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { _ = r.ParseForm() if r.URL.Path == "/" { Redirect(w, r, "/slow?nonce="+r.Form.Get("nonce"), StatusFound) @@ -1238,7 +1205,6 @@ func testClientTimeout(t *testing.T, h2 bool) { return } })) - defer cst.close() // Try to trigger a timeout after reading part of the response body. // The initial timeout is emprically usually long enough on a decently fast @@ -1308,18 +1274,13 @@ func testClientTimeout(t *testing.T, h2 bool) { } } -func TestClientTimeout_Headers_h1(t *testing.T) { testClientTimeout_Headers(t, h1Mode) } -func TestClientTimeout_Headers_h2(t *testing.T) { testClientTimeout_Headers(t, h2Mode) } - // Client.Timeout firing before getting to the body -func testClientTimeout_Headers(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestClientTimeout_Headers(t *testing.T) { run(t, testClientTimeout_Headers) } +func testClientTimeout_Headers(t *testing.T, mode testMode) { donec := make(chan bool, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { <-donec }), optQuietLog) - defer cst.close() // Note that we use a channel send here and not a close. // The race detector doesn't know that we're waiting for a timeout // and thinks that the waitgroup inside httptest.Server is added to concurrently @@ -1355,18 +1316,15 @@ func testClientTimeout_Headers(t *testing.T, h2 bool) { // Issue 16094: if Client.Timeout is set but not hit, a Timeout error shouldn't be // returned. -func TestClientTimeoutCancel(t *testing.T) { - setParallel(t) - defer afterTest(t) - +func TestClientTimeoutCancel(t *testing.T) { run(t, testClientTimeoutCancel) } +func testClientTimeoutCancel(t *testing.T, mode testMode) { testDone := make(chan struct{}) ctx, cancel := context.WithCancel(context.Background()) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.(Flusher).Flush() <-testDone })) - defer cst.close() defer close(testDone) cst.c.Timeout = 1 * time.Hour @@ -1383,18 +1341,12 @@ func TestClientTimeoutCancel(t *testing.T) { } } -func TestClientTimeoutDoesNotExpire_h1(t *testing.T) { testClientTimeoutDoesNotExpire(t, h1Mode) } -func TestClientTimeoutDoesNotExpire_h2(t *testing.T) { testClientTimeoutDoesNotExpire(t, h2Mode) } - // Issue 49366: if Client.Timeout is set but not hit, no error should be returned. -func testClientTimeoutDoesNotExpire(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestClientTimeoutDoesNotExpire(t *testing.T) { run(t, testClientTimeoutDoesNotExpire) } +func testClientTimeoutDoesNotExpire(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("body")) })) - defer cst.close() cst.c.Timeout = 1 * time.Hour req, _ := NewRequest("GET", cst.ts.URL, nil) @@ -1410,19 +1362,15 @@ func testClientTimeoutDoesNotExpire(t *testing.T, h2 bool) { } } -func TestClientRedirectEatsBody_h1(t *testing.T) { testClientRedirectEatsBody(t, h1Mode) } -func TestClientRedirectEatsBody_h2(t *testing.T) { testClientRedirectEatsBody(t, h2Mode) } -func testClientRedirectEatsBody(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestClientRedirectEatsBody_h1(t *testing.T) { run(t, testClientRedirectEatsBody) } +func testClientRedirectEatsBody(t *testing.T, mode testMode) { saw := make(chan string, 2) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { saw <- r.RemoteAddr if r.URL.Path == "/" { Redirect(w, r, "/foo", StatusFound) // which includes a body } })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -1522,13 +1470,14 @@ func TestClientRedirectResponseWithoutRequest(t *testing.T) { } // Issue 4800: copy (some) headers when Client follows a redirect. -func TestClientCopyHeadersOnRedirect(t *testing.T) { +func TestClientCopyHeadersOnRedirect(t *testing.T) { run(t, testClientCopyHeadersOnRedirect) } +func testClientCopyHeadersOnRedirect(t *testing.T, mode testMode) { const ( ua = "some-agent/1.2" xfoo = "foo-val" ) var ts2URL string - ts1 := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts1 := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { want := Header{ "User-Agent": []string{ua}, "X-Foo": []string{xfoo}, @@ -1543,12 +1492,10 @@ func TestClientCopyHeadersOnRedirect(t *testing.T) { } else { w.Header().Set("Result", "ok") } - })) - defer ts1.Close() - ts2 := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + })).ts + ts2 := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { Redirect(w, r, ts1.URL, StatusFound) - })) - defer ts2.Close() + })).ts ts2URL = ts2.URL c := ts1.Client() @@ -1583,22 +1530,24 @@ func TestClientCopyHeadersOnRedirect(t *testing.T) { } // Issue 22233: copy host when Client follows a relative redirect. -func TestClientCopyHostOnRedirect(t *testing.T) { +func TestClientCopyHostOnRedirect(t *testing.T) { run(t, testClientCopyHostOnRedirect) } +func testClientCopyHostOnRedirect(t *testing.T, mode testMode) { // Virtual hostname: should not receive any request. - virtual := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + virtual := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { t.Errorf("Virtual host received request %v", r.URL) w.WriteHeader(403) io.WriteString(w, "should not see this response") - })) + })).ts defer virtual.Close() virtualHost := strings.TrimPrefix(virtual.URL, "http://") + virtualHost = strings.TrimPrefix(virtualHost, "https://") t.Logf("Virtual host is %v", virtualHost) // Actual hostname: should not receive any request. const wantBody = "response body" var tsURL string var tsHost string - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { switch r.URL.Path { case "/": // Relative redirect. @@ -1630,10 +1579,10 @@ func TestClientCopyHostOnRedirect(t *testing.T) { t.Errorf("Serving unexpected path %q", r.URL.Path) w.WriteHeader(404) } - })) - defer ts.Close() + })).ts tsURL = ts.URL tsHost = strings.TrimPrefix(ts.URL, "http://") + tsHost = strings.TrimPrefix(tsHost, "https://") t.Logf("Server host is %v", tsHost) c := ts.Client() @@ -1653,7 +1602,8 @@ func TestClientCopyHostOnRedirect(t *testing.T) { } // Issue 17494: cookies should be altered when Client follows redirects. -func TestClientAltersCookiesOnRedirect(t *testing.T) { +func TestClientAltersCookiesOnRedirect(t *testing.T) { run(t, testClientAltersCookiesOnRedirect) } +func testClientAltersCookiesOnRedirect(t *testing.T, mode testMode) { cookieMap := func(cs []*Cookie) map[string][]string { m := make(map[string][]string) for _, c := range cs { @@ -1662,7 +1612,7 @@ func TestClientAltersCookiesOnRedirect(t *testing.T) { return m } - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { var want map[string][]string got := cookieMap(r.Cookies()) @@ -1717,8 +1667,7 @@ func TestClientAltersCookiesOnRedirect(t *testing.T) { if !reflect.DeepEqual(got, want) { t.Errorf("redirect %s, Cookie = %v, want %v", c.Value, got, want) } - })) - defer ts.Close() + })).ts jar, _ := cookiejar.New(nil) c := ts.Client() @@ -1790,10 +1739,8 @@ func TestShouldCopyHeaderOnRedirect(t *testing.T) { } } -func TestClientRedirectTypes(t *testing.T) { - setParallel(t) - defer afterTest(t) - +func TestClientRedirectTypes(t *testing.T) { run(t, testClientRedirectTypes) } +func testClientRedirectTypes(t *testing.T, mode testMode) { tests := [...]struct { method string serverStatus int @@ -1838,11 +1785,10 @@ func TestClientRedirectTypes(t *testing.T) { handlerc := make(chan HandlerFunc, 1) - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { h := <-handlerc h(rw, req) - })) - defer ts.Close() + })).ts c := ts.Client() for i, tt := range tests { @@ -1898,18 +1844,16 @@ func (b issue18239Body) Close() error { // Issue 18239: make sure the Transport doesn't retry requests with bodies // if Request.GetBody is not defined. -func TestTransportBodyReadError(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportBodyReadError(t *testing.T) { run(t, testTransportBodyReadError) } +func testTransportBodyReadError(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.URL.Path == "/ping" { return } buf := make([]byte, 1) n, err := r.Body.Read(buf) w.Header().Set("X-Body-Read", fmt.Sprintf("%v, %v", n, err)) - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -1993,22 +1937,13 @@ func TestClientPropagatesTimeoutToContext(t *testing.T) { c.Get("https://example.tld/") } -func TestClientDoCanceledVsTimeout_h1(t *testing.T) { - testClientDoCanceledVsTimeout(t, h1Mode) -} - -func TestClientDoCanceledVsTimeout_h2(t *testing.T) { - testClientDoCanceledVsTimeout(t, h2Mode) -} - // Issue 33545: lock-in the behavior promised by Client.Do's // docs about request cancellation vs timing out. -func testClientDoCanceledVsTimeout(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestClientDoCanceledVsTimeout(t *testing.T) { run(t, testClientDoCanceledVsTimeout) } +func testClientDoCanceledVsTimeout(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("Hello, World!")) })) - defer cst.close() cases := []string{"timeout", "canceled"} @@ -2084,13 +2019,11 @@ func TestClientPopulatesNilResponseBody(t *testing.T) { } // Issue 40382: Client calls Close multiple times on Request.Body. -func TestClientCallsCloseOnlyOnce(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestClientCallsCloseOnlyOnce(t *testing.T) { run(t, testClientCallsCloseOnlyOnce) } +func testClientCallsCloseOnlyOnce(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(StatusNoContent) })) - defer cst.close() // Issue occurred non-deterministically: needed to occur after a successful // write (into TCP buffer) but before end of body. @@ -2140,17 +2073,15 @@ func (b *issue40382Body) Close() error { return nil } -func TestProbeZeroLengthBody(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestProbeZeroLengthBody(t *testing.T) { run(t, testProbeZeroLengthBody) } +func testProbeZeroLengthBody(t *testing.T, mode testMode) { reqc := make(chan struct{}) - cst := newClientServerTest(t, false, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { close(reqc) if _, err := io.Copy(w, r.Body); err != nil { t.Errorf("error copying request body: %v", err) } })) - defer cst.close() bodyr, bodyw := io.Pipe() var gotBody string diff --git a/src/net/http/clientserver_test.go b/src/net/http/clientserver_test.go index b472ca4b78..87e34cef85 100644 --- a/src/net/http/clientserver_test.go +++ b/src/net/http/clientserver_test.go @@ -35,8 +35,65 @@ import ( "time" ) +type testMode string + +const ( + http1Mode = testMode("h1") // HTTP/1.1 + https1Mode = testMode("https1") // HTTPS/1.1 + http2Mode = testMode("h2") // HTTP/2 +) + +type testNotParallelOpt struct{} + +var ( + testNotParallel = testNotParallelOpt{} +) + +type TBRun[T any] interface { + testing.TB + Run(string, func(T)) bool +} + +// run runs a client/server test in a variety of test configurations. +// +// Tests execute in HTTP/1.1 and HTTP/2 modes by default. +// To run in a different set of configurations, pass a []testMode option. +// +// Tests call t.Parallel() by default. +// To disable parallel execution, pass the testNotParallel option. +func run[T TBRun[T]](t T, f func(t T, mode testMode), opts ...any) { + t.Helper() + modes := []testMode{http1Mode, http2Mode} + parallel := true + for _, opt := range opts { + switch opt := opt.(type) { + case []testMode: + modes = opt + case testNotParallelOpt: + parallel = false + default: + t.Fatalf("unknown option type %T", opt) + } + } + if t, ok := any(t).(*testing.T); ok && parallel { + setParallel(t) + } + for _, mode := range modes { + t.Run(string(mode), func(t T) { + t.Helper() + if t, ok := any(t).(*testing.T); ok && parallel { + setParallel(t) + } + t.Cleanup(func() { + afterTest(t) + }) + f(t, mode) + }) + } +} + type clientServerTest struct { - t *testing.T + t testing.TB h2 bool h Handler ts *httptest.Server @@ -69,11 +126,6 @@ func (t *clientServerTest) scheme() string { return "http" } -const ( - h1Mode = false - h2Mode = true -) - var optQuietLog = func(ts *httptest.Server) { ts.Config.ErrorLog = quietLog } @@ -84,23 +136,33 @@ func optWithServerLog(lg *log.Logger) func(*httptest.Server) { } } -func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...any) *clientServerTest { - if h2 { +// newClientServerTest creates and starts an httptest.Server. +// +// The mode parameter selects the implementation to test: +// HTTP/1, HTTP/2, etc. Tests using newClientServerTest should use +// the 'run' function, which will start a subtests for each tested mode. +// +// The vararg opts parameter can include functions to configure the +// test server or transport. +// +// func(*httptest.Server) // run before starting the server +// func(*http.Transport) +func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *clientServerTest { + if mode == http2Mode { CondSkipHTTP2(t) } cst := &clientServerTest{ t: t, - h2: h2, + h2: mode == http2Mode, h: h, - tr: &Transport{}, } - cst.c = &Client{Transport: cst.tr} cst.ts = httptest.NewUnstartedServer(h) + var transportFuncs []func(*Transport) for _, opt := range opts { switch opt := opt.(type) { case func(*Transport): - opt(cst.tr) + transportFuncs = append(transportFuncs, opt) case func(*httptest.Server): opt(cst.ts) default: @@ -108,60 +170,84 @@ func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...any) *clientS } } - if !h2 { + switch mode { + case http1Mode: cst.ts.Start() - return cst + case https1Mode: + cst.ts.StartTLS() + case http2Mode: + ExportHttp2ConfigureServer(cst.ts.Config, nil) + cst.ts.TLS = cst.ts.Config.TLSConfig + cst.ts.StartTLS() + default: + t.Fatalf("unknown test mode %v", mode) } - ExportHttp2ConfigureServer(cst.ts.Config, nil) - cst.ts.TLS = cst.ts.Config.TLSConfig - cst.ts.StartTLS() - - cst.tr.TLSClientConfig = &tls.Config{ - InsecureSkipVerify: true, + cst.c = cst.ts.Client() + cst.tr = cst.c.Transport.(*Transport) + if mode == http2Mode { + if err := ExportHttp2ConfigureTransport(cst.tr); err != nil { + t.Fatal(err) + } } - if err := ExportHttp2ConfigureTransport(cst.tr); err != nil { - t.Fatal(err) + for _, f := range transportFuncs { + f(cst.tr) } + t.Cleanup(func() { + cst.close() + }) return cst } // Testing the newClientServerTest helper itself. func TestNewClientServerTest(t *testing.T) { + run(t, testNewClientServerTest, []testMode{http1Mode, https1Mode, http2Mode}) +} +func testNewClientServerTest(t *testing.T, mode testMode) { var got struct { sync.Mutex - log []string + proto string + hasTLS bool } h := HandlerFunc(func(w ResponseWriter, r *Request) { got.Lock() defer got.Unlock() - got.log = append(got.log, r.Proto) + got.proto = r.Proto + got.hasTLS = r.TLS != nil }) - for _, v := range [2]bool{false, true} { - cst := newClientServerTest(t, v, h) - if _, err := cst.c.Head(cst.ts.URL); err != nil { - t.Fatal(err) - } - cst.close() + cst := newClientServerTest(t, mode, h) + if _, err := cst.c.Head(cst.ts.URL); err != nil { + t.Fatal(err) } - got.Lock() // no need to unlock - if want := []string{"HTTP/1.1", "HTTP/2.0"}; !reflect.DeepEqual(got.log, want) { - t.Errorf("got %q; want %q", got.log, want) + var wantProto string + var wantTLS bool + switch mode { + case http1Mode: + wantProto = "HTTP/1.1" + wantTLS = false + case https1Mode: + wantProto = "HTTP/1.1" + wantTLS = true + case http2Mode: + wantProto = "HTTP/2.0" + wantTLS = true + } + if got.proto != wantProto { + t.Errorf("req.Proto = %q, want %q", got.proto, wantProto) + } + if got.hasTLS != wantTLS { + t.Errorf("req.TLS set: %v, want %v", got.hasTLS, wantTLS) } } -func TestChunkedResponseHeaders_h1(t *testing.T) { testChunkedResponseHeaders(t, h1Mode) } -func TestChunkedResponseHeaders_h2(t *testing.T) { testChunkedResponseHeaders(t, h2Mode) } - -func testChunkedResponseHeaders(t *testing.T, h2 bool) { - defer afterTest(t) +func TestChunkedResponseHeaders(t *testing.T) { run(t, testChunkedResponseHeaders) } +func testChunkedResponseHeaders(t *testing.T, mode testMode) { log.SetOutput(io.Discard) // is noisy otherwise defer log.SetOutput(os.Stderr) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted w.(Flusher).Flush() fmt.Fprintf(w, "I am a chunked response.") })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -172,7 +258,7 @@ func testChunkedResponseHeaders(t *testing.T, h2 bool) { t.Errorf("expected ContentLength of %d; got %d", e, g) } wantTE := []string{"chunked"} - if h2 { + if mode == http2Mode { wantTE = nil } if !reflect.DeepEqual(res.TransferEncoding, wantTE) { @@ -204,9 +290,9 @@ func (tt h12Compare) reqFunc() reqFunc { func (tt h12Compare) run(t *testing.T) { setParallel(t) - cst1 := newClientServerTest(t, false, HandlerFunc(tt.Handler), tt.Opts...) + cst1 := newClientServerTest(t, http1Mode, HandlerFunc(tt.Handler), tt.Opts...) defer cst1.close() - cst2 := newClientServerTest(t, true, HandlerFunc(tt.Handler), tt.Opts...) + cst2 := newClientServerTest(t, http2Mode, HandlerFunc(tt.Handler), tt.Opts...) defer cst2.close() res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL) @@ -459,12 +545,9 @@ func TestH12_AutoGzip_Disabled(t *testing.T) { // Test304Responses verifies that 304s don't declare that they're // chunking in their response headers and aren't allowed to produce // output. -func Test304Responses_h1(t *testing.T) { test304Responses(t, h1Mode) } -func Test304Responses_h2(t *testing.T) { test304Responses(t, h2Mode) } - -func test304Responses(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func Test304Responses(t *testing.T) { run(t, test304Responses) } +func test304Responses(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(StatusNotModified) _, err := w.Write([]byte("illegal body")) if err != ErrBodyNotAllowed { @@ -528,20 +611,17 @@ func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int6 // Tests that closing the Request.Cancel channel also while still // reading the response body. Issue 13159. -func TestCancelRequestMidBody_h1(t *testing.T) { testCancelRequestMidBody(t, h1Mode) } -func TestCancelRequestMidBody_h2(t *testing.T) { testCancelRequestMidBody(t, h2Mode) } -func testCancelRequestMidBody(t *testing.T, h2 bool) { - defer afterTest(t) +func TestCancelRequestMidBody(t *testing.T) { run(t, testCancelRequestMidBody) } +func testCancelRequestMidBody(t *testing.T, mode testMode) { unblock := make(chan bool) didFlush := make(chan bool, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, "Hello") w.(Flusher).Flush() didFlush <- true <-unblock io.WriteString(w, ", world.") })) - defer cst.close() defer close(unblock) req, _ := NewRequest("GET", cst.ts.URL, nil) @@ -577,12 +657,9 @@ func testCancelRequestMidBody(t *testing.T, h2 bool) { } // Tests that clients can send trailers to a server and that the server can read them. -func TestTrailersClientToServer_h1(t *testing.T) { testTrailersClientToServer(t, h1Mode) } -func TestTrailersClientToServer_h2(t *testing.T) { testTrailersClientToServer(t, h2Mode) } - -func testTrailersClientToServer(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTrailersClientToServer(t *testing.T) { run(t, testTrailersClientToServer) } +func testTrailersClientToServer(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { var decl []string for k := range r.Trailer { decl = append(decl, k) @@ -605,7 +682,6 @@ func testTrailersClientToServer(t *testing.T, h2 bool) { r.Trailer.Get("Client-Trailer-B")) } })) - defer cst.close() var req *Request req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader( @@ -632,15 +708,20 @@ func testTrailersClientToServer(t *testing.T, h2 bool) { } // Tests that servers send trailers to a client and that the client can read them. -func TestTrailersServerToClient_h1(t *testing.T) { testTrailersServerToClient(t, h1Mode, false) } -func TestTrailersServerToClient_h2(t *testing.T) { testTrailersServerToClient(t, h2Mode, false) } -func TestTrailersServerToClient_Flush_h1(t *testing.T) { testTrailersServerToClient(t, h1Mode, true) } -func TestTrailersServerToClient_Flush_h2(t *testing.T) { testTrailersServerToClient(t, h2Mode, true) } +func TestTrailersServerToClient(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testTrailersServerToClient(t, mode, false) + }) +} +func TestTrailersServerToClientFlush(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testTrailersServerToClient(t, mode, true) + }) +} -func testTrailersServerToClient(t *testing.T, h2, flush bool) { - defer afterTest(t) +func testTrailersServerToClient(t *testing.T, mode testMode, flush bool) { const body = "Some body" - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B") w.Header().Add("Trailer", "Server-Trailer-C") @@ -657,7 +738,6 @@ func testTrailersServerToClient(t *testing.T, h2, flush bool) { w.Header().Set("Server-Trailer-C", "valuec") // skipping B w.Header().Set("Server-Trailer-NotDeclared", "should be omitted") })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -668,7 +748,7 @@ func testTrailersServerToClient(t *testing.T, h2, flush bool) { "Content-Type": {"text/plain; charset=utf-8"}, } wantLen := -1 - if h2 && !flush { + if mode == http2Mode && !flush { // In HTTP/1.1, any use of trailers forces HTTP/1.1 // chunking and a flush at the first write. That's // unnecessary with HTTP/2's framing, so the server @@ -708,16 +788,12 @@ func testTrailersServerToClient(t *testing.T, h2, flush bool) { } // Don't allow a Body.Read after Body.Close. Issue 13648. -func TestResponseBodyReadAfterClose_h1(t *testing.T) { testResponseBodyReadAfterClose(t, h1Mode) } -func TestResponseBodyReadAfterClose_h2(t *testing.T) { testResponseBodyReadAfterClose(t, h2Mode) } - -func testResponseBodyReadAfterClose(t *testing.T, h2 bool) { - defer afterTest(t) +func TestResponseBodyReadAfterClose(t *testing.T) { run(t, testResponseBodyReadAfterClose) } +func testResponseBodyReadAfterClose(t *testing.T, mode testMode) { const body = "Some body" - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, body) })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -729,13 +805,11 @@ func testResponseBodyReadAfterClose(t *testing.T, h2 bool) { } } -func TestConcurrentReadWriteReqBody_h1(t *testing.T) { testConcurrentReadWriteReqBody(t, h1Mode) } -func TestConcurrentReadWriteReqBody_h2(t *testing.T) { testConcurrentReadWriteReqBody(t, h2Mode) } -func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) { - defer afterTest(t) +func TestConcurrentReadWriteReqBody(t *testing.T) { run(t, testConcurrentReadWriteReqBody) } +func testConcurrentReadWriteReqBody(t *testing.T, mode testMode) { const reqBody = "some request body" const resBody = "some response body" - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { var wg sync.WaitGroup wg.Add(2) didRead := make(chan bool, 1) @@ -754,7 +828,7 @@ func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) { // Write in another goroutine. go func() { defer wg.Done() - if !h2 { + if mode != http2Mode { // our HTTP/1 implementation intentionally // doesn't permit writes during read (mostly // due to it being undefined); if that is ever @@ -765,7 +839,6 @@ func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) { }() wg.Wait() })) - defer cst.close() req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody)) req.Header.Add("Expect", "100-continue") // just to complicate things res, err := cst.c.Do(req) @@ -782,15 +855,12 @@ func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) { } } -func TestConnectRequest_h1(t *testing.T) { testConnectRequest(t, h1Mode) } -func TestConnectRequest_h2(t *testing.T) { testConnectRequest(t, h2Mode) } -func testConnectRequest(t *testing.T, h2 bool) { - defer afterTest(t) +func TestConnectRequest(t *testing.T) { run(t, testConnectRequest) } +func testConnectRequest(t *testing.T, mode testMode) { gotc := make(chan *Request, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { gotc <- r })) - defer cst.close() u, err := url.Parse(cst.ts.URL) if err != nil { @@ -840,17 +910,14 @@ func testConnectRequest(t *testing.T, h2 bool) { } } -func TestTransportUserAgent_h1(t *testing.T) { testTransportUserAgent(t, h1Mode) } -func TestTransportUserAgent_h2(t *testing.T) { testTransportUserAgent(t, h2Mode) } -func testTransportUserAgent(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportUserAgent(t *testing.T) { run(t, testTransportUserAgent) } +func testTransportUserAgent(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%q", r.Header["User-Agent"]) })) - defer cst.close() either := func(a, b string) string { - if h2 { + if mode == http2Mode { return b } return a @@ -901,19 +968,22 @@ func testTransportUserAgent(t *testing.T, h2 bool) { } } -func TestStarRequestFoo_h1(t *testing.T) { testStarRequest(t, "FOO", h1Mode) } -func TestStarRequestFoo_h2(t *testing.T) { testStarRequest(t, "FOO", h2Mode) } -func TestStarRequestOptions_h1(t *testing.T) { testStarRequest(t, "OPTIONS", h1Mode) } -func TestStarRequestOptions_h2(t *testing.T) { testStarRequest(t, "OPTIONS", h2Mode) } -func testStarRequest(t *testing.T, method string, h2 bool) { - defer afterTest(t) +func TestStarRequestMethod(t *testing.T) { + for _, method := range []string{"FOO", "OPTIONS"} { + t.Run(method, func(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testStarRequest(t, method, mode) + }) + }) + } +} +func testStarRequest(t *testing.T, method string, mode testMode) { gotc := make(chan *Request, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("foo", "bar") gotc <- r w.(Flusher).Flush() })) - defer cst.close() u, err := url.Parse(cst.ts.URL) if err != nil { @@ -972,9 +1042,10 @@ func testStarRequest(t *testing.T, method string, h2 bool) { // Issue 13957 func TestTransportDiscardsUnneededConns(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportDiscardsUnneededConns, []testMode{http2Mode}) +} +func testTransportDiscardsUnneededConns(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "Hello, %v", r.RemoteAddr) })) defer cst.close() @@ -1058,20 +1129,19 @@ func TestTransportDiscardsUnneededConns(t *testing.T) { } // tests that Transport doesn't retain a pointer to the provided request. -func TestTransportGCRequest_Body_h1(t *testing.T) { testTransportGCRequest(t, h1Mode, true) } -func TestTransportGCRequest_Body_h2(t *testing.T) { testTransportGCRequest(t, h2Mode, true) } -func TestTransportGCRequest_NoBody_h1(t *testing.T) { testTransportGCRequest(t, h1Mode, false) } -func TestTransportGCRequest_NoBody_h2(t *testing.T) { testTransportGCRequest(t, h2Mode, false) } -func testTransportGCRequest(t *testing.T, h2, body bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportGCRequest(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + t.Run("Body", func(t *testing.T) { testTransportGCRequest(t, mode, true) }) + t.Run("NoBody", func(t *testing.T) { testTransportGCRequest(t, mode, false) }) + }) +} +func testTransportGCRequest(t *testing.T, mode testMode, body bool) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.ReadAll(r.Body) if body { io.WriteString(w, "Hello.") } })) - defer cst.close() didGC := make(chan struct{}) (func() { @@ -1103,19 +1173,11 @@ func testTransportGCRequest(t *testing.T, h2, body bool) { } } -func TestTransportRejectsInvalidHeaders_h1(t *testing.T) { - testTransportRejectsInvalidHeaders(t, h1Mode) -} -func TestTransportRejectsInvalidHeaders_h2(t *testing.T) { - testTransportRejectsInvalidHeaders(t, h2Mode) -} -func testTransportRejectsInvalidHeaders(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportRejectsInvalidHeaders(t *testing.T) { run(t, testTransportRejectsInvalidHeaders) } +func testTransportRejectsInvalidHeaders(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "Handler saw headers: %q", r.Header) }), optQuietLog) - defer cst.close() cst.tr.DisableKeepAlives = true tests := []struct { @@ -1161,27 +1223,22 @@ func testTransportRejectsInvalidHeaders(t *testing.T, h2 bool) { } } -func TestInterruptWithPanic_h1(t *testing.T) { testInterruptWithPanic(t, h1Mode, "boom") } -func TestInterruptWithPanic_h2(t *testing.T) { testInterruptWithPanic(t, h2Mode, "boom") } -func TestInterruptWithPanic_nil_h1(t *testing.T) { testInterruptWithPanic(t, h1Mode, nil) } -func TestInterruptWithPanic_nil_h2(t *testing.T) { testInterruptWithPanic(t, h2Mode, nil) } -func TestInterruptWithPanic_ErrAbortHandler_h1(t *testing.T) { - testInterruptWithPanic(t, h1Mode, ErrAbortHandler) +func TestInterruptWithPanic(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + t.Run("boom", func(t *testing.T) { testInterruptWithPanic(t, mode, "boom") }) + t.Run("nil", func(t *testing.T) { testInterruptWithPanic(t, mode, nil) }) + t.Run("ErrAbortHandler", func(t *testing.T) { testInterruptWithPanic(t, mode, ErrAbortHandler) }) + }) } -func TestInterruptWithPanic_ErrAbortHandler_h2(t *testing.T) { - testInterruptWithPanic(t, h2Mode, ErrAbortHandler) -} -func testInterruptWithPanic(t *testing.T, h2 bool, panicValue any) { - setParallel(t) +func testInterruptWithPanic(t *testing.T, mode testMode, panicValue any) { const msg = "hello" - defer afterTest(t) testDone := make(chan struct{}) defer close(testDone) var errorLog lockedBytesBuffer gotHeaders := make(chan bool, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, msg) w.(Flusher).Flush() @@ -1193,7 +1250,6 @@ func testInterruptWithPanic(t *testing.T, h2 bool, panicValue any) { }), func(ts *httptest.Server) { ts.Config.ErrorLog = log.New(&errorLog, "", 0) }) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -1274,15 +1330,11 @@ func TestH12_AutoGzipWithDumpResponse(t *testing.T) { } // Issue 14607 -func TestCloseIdleConnections_h1(t *testing.T) { testCloseIdleConnections(t, h1Mode) } -func TestCloseIdleConnections_h2(t *testing.T) { testCloseIdleConnections(t, h2Mode) } -func testCloseIdleConnections(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestCloseIdleConnections(t *testing.T) { run(t, testCloseIdleConnections) } +func testCloseIdleConnections(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("X-Addr", r.RemoteAddr) })) - defer cst.close() get := func() string { res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -1320,15 +1372,11 @@ func (r testErrorReader) Read(p []byte) (n int, err error) { return 0, io.EOF } -func TestNoSniffExpectRequestBody_h1(t *testing.T) { testNoSniffExpectRequestBody(t, h1Mode) } -func TestNoSniffExpectRequestBody_h2(t *testing.T) { testNoSniffExpectRequestBody(t, h2Mode) } - -func testNoSniffExpectRequestBody(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestNoSniffExpectRequestBody(t *testing.T) { run(t, testNoSniffExpectRequestBody) } +func testNoSniffExpectRequestBody(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(StatusUnauthorized) })) - defer cst.close() // Set ExpectContinueTimeout non-zero so RoundTrip won't try to write it. cst.tr.ExpectContinueTimeout = 10 * time.Second @@ -1349,18 +1397,15 @@ func testNoSniffExpectRequestBody(t *testing.T, h2 bool) { } } -func TestServerUndeclaredTrailers_h1(t *testing.T) { testServerUndeclaredTrailers(t, h1Mode) } -func TestServerUndeclaredTrailers_h2(t *testing.T) { testServerUndeclaredTrailers(t, h2Mode) } -func testServerUndeclaredTrailers(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServerUndeclaredTrailers(t *testing.T) { run(t, testServerUndeclaredTrailers) } +func testServerUndeclaredTrailers(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Foo", "Bar") w.Header().Set("Trailer:Foo", "Baz") w.(Flusher).Flush() w.Header().Add("Trailer:Foo", "Baz2") w.Header().Set("Trailer:Bar", "Quux") })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -1381,8 +1426,10 @@ func testServerUndeclaredTrailers(t *testing.T, h2 bool) { } func TestBadResponseAfterReadingBody(t *testing.T) { - defer afterTest(t) - cst := newClientServerTest(t, false, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testBadResponseAfterReadingBody, []testMode{http1Mode}) +} +func testBadResponseAfterReadingBody(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { _, err := io.Copy(io.Discard, r.Body) if err != nil { t.Fatal(err) @@ -1394,7 +1441,6 @@ func TestBadResponseAfterReadingBody(t *testing.T) { defer c.Close() fmt.Fprintln(c, "some bogus crap") })) - defer cst.close() closes := 0 res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")}) @@ -1407,12 +1453,10 @@ func TestBadResponseAfterReadingBody(t *testing.T) { } } -func TestWriteHeader0_h1(t *testing.T) { testWriteHeader0(t, h1Mode) } -func TestWriteHeader0_h2(t *testing.T) { testWriteHeader0(t, h2Mode) } -func testWriteHeader0(t *testing.T, h2 bool) { - defer afterTest(t) +func TestWriteHeader0(t *testing.T) { run(t, testWriteHeader0) } +func testWriteHeader0(t *testing.T, mode testMode) { gotpanic := make(chan bool, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { defer close(gotpanic) defer func() { if e := recover(); e != nil { @@ -1431,7 +1475,6 @@ func testWriteHeader0(t *testing.T, h2 bool) { }() w.WriteHeader(0) })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -1446,15 +1489,17 @@ func testWriteHeader0(t *testing.T, h2 bool) { // Issue 23010: don't be super strict checking WriteHeader's code if // it's not even valid to call WriteHeader then anyway. -func TestWriteHeaderNoCodeCheck_h1(t *testing.T) { testWriteHeaderAfterWrite(t, h1Mode, false) } -func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) { testWriteHeaderAfterWrite(t, h1Mode, true) } -func TestWriteHeaderNoCodeCheck_h2(t *testing.T) { testWriteHeaderAfterWrite(t, h2Mode, false) } -func testWriteHeaderAfterWrite(t *testing.T, h2, hijack bool) { - setParallel(t) - defer afterTest(t) - +func TestWriteHeaderNoCodeCheck(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testWriteHeaderAfterWrite(t, mode, false) + }) +} +func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) { + testWriteHeaderAfterWrite(t, http1Mode, true) +} +func testWriteHeaderAfterWrite(t *testing.T, mode testMode, hijack bool) { var errorLog lockedBytesBuffer - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if hijack { conn, _, _ := w.(Hijacker).Hijack() defer conn.Close() @@ -1470,7 +1515,6 @@ func testWriteHeaderAfterWrite(t *testing.T, h2, hijack bool) { }), func(ts *httptest.Server) { ts.Config.ErrorLog = log.New(&errorLog, "", 0) }) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -1485,7 +1529,7 @@ func testWriteHeaderAfterWrite(t *testing.T, h2, hijack bool) { } // Also check the stderr output: - if h2 { + if mode == http2Mode { // TODO: also emit this log message for HTTP/2? // We historically haven't, so don't check. return @@ -1501,14 +1545,14 @@ func testWriteHeaderAfterWrite(t *testing.T, h2, hijack bool) { } func TestBidiStreamReverseProxy(t *testing.T) { - setParallel(t) - defer afterTest(t) - backend := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testBidiStreamReverseProxy, []testMode{http2Mode}) +} +func testBidiStreamReverseProxy(t *testing.T, mode testMode) { + backend := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if _, err := io.Copy(w, r.Body); err != nil { log.Printf("bidi backend copy: %v", err) } })) - defer backend.close() backURL, err := url.Parse(backend.ts.URL) if err != nil { @@ -1516,10 +1560,9 @@ func TestBidiStreamReverseProxy(t *testing.T) { } rp := httputil.NewSingleHostReverseProxy(backURL) rp.Transport = backend.tr - proxy := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + proxy := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { rp.ServeHTTP(w, r) })) - defer proxy.close() bodyRes := make(chan any, 1) // error or hash.Hash pr, pw := io.Pipe() @@ -1586,15 +1629,10 @@ func TestH12_WebSocketUpgrade(t *testing.T) { }.run(t) } -func TestIdentityTransferEncoding_h1(t *testing.T) { testIdentityTransferEncoding(t, h1Mode) } -func TestIdentityTransferEncoding_h2(t *testing.T) { testIdentityTransferEncoding(t, h2Mode) } - -func testIdentityTransferEncoding(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - +func TestIdentityTransferEncoding(t *testing.T) { run(t, testIdentityTransferEncoding) } +func testIdentityTransferEncoding(t *testing.T, mode testMode) { const body = "body" - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { gotBody, _ := io.ReadAll(r.Body) if got, want := string(gotBody), body; got != want { t.Errorf("got request body = %q; want %q", got, want) @@ -1604,7 +1642,6 @@ func testIdentityTransferEncoding(t *testing.T, h2 bool) { w.(Flusher).Flush() io.WriteString(w, body) })) - defer cst.close() req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body)) res, err := cst.c.Do(req) if err != nil { @@ -1620,14 +1657,11 @@ func testIdentityTransferEncoding(t *testing.T, h2 bool) { } } -func TestEarlyHintsRequest_h1(t *testing.T) { testEarlyHintsRequest(t, h1Mode) } -func TestEarlyHintsRequest_h2(t *testing.T) { testEarlyHintsRequest(t, h2Mode) } -func testEarlyHintsRequest(t *testing.T, h2 bool) { - defer afterTest(t) - +func TestEarlyHintsRequest(t *testing.T) { run(t, testEarlyHintsRequest) } +func testEarlyHintsRequest(t *testing.T, mode testMode) { var wg sync.WaitGroup wg.Add(1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { h := w.Header() h.Add("Content-Length", "123") // must be ignored @@ -1642,7 +1676,6 @@ func testEarlyHintsRequest(t *testing.T, h2 bool) { w.Write([]byte("Hello")) })) - defer cst.close() checkLinkHeaders := func(t *testing.T, expected, got []string) { t.Helper() diff --git a/src/net/http/export_test.go b/src/net/http/export_test.go index 205ca83f40..fb5ab9396a 100644 --- a/src/net/http/export_test.go +++ b/src/net/http/export_test.go @@ -60,7 +60,7 @@ func init() { } } -func CondSkipHTTP2(t *testing.T) { +func CondSkipHTTP2(t testing.TB) { if omitBundledHTTP2 { t.Skip("skipping HTTP/2 test when nethttpomithttp2 build tag in use") } @@ -72,8 +72,6 @@ var ( ) func SetReadLoopBeforeNextReadHook(f func()) { - testHookMu.Lock() - defer testHookMu.Unlock() unnilTestHook(&f) testHookReadLoopBeforeNextRead = f } diff --git a/src/net/http/fs_test.go b/src/net/http/fs_test.go index 71fc064367..47526152b3 100644 --- a/src/net/http/fs_test.go +++ b/src/net/http/fs_test.go @@ -68,13 +68,11 @@ var ServeFileRangeTests = []struct { {r: "bytes=100-1000", code: StatusRequestedRangeNotSatisfiable}, } -func TestServeFile(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServeFile(t *testing.T) { run(t, testServeFile) } +func testServeFile(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "testdata/file") - })) - defer ts.Close() + })).ts c := ts.Client() var err error @@ -228,13 +226,12 @@ var fsRedirectTestData = []struct { {"/test/testdata/file/", "/test/testdata/file"}, } -func TestFSRedirect(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(StripPrefix("/test", FileServer(Dir(".")))) - defer ts.Close() +func TestFSRedirect(t *testing.T) { run(t, testFSRedirect) } +func testFSRedirect(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, StripPrefix("/test", FileServer(Dir(".")))).ts for _, data := range fsRedirectTestData { - res, err := Get(ts.URL + data.original) + res, err := ts.Client().Get(ts.URL + data.original) if err != nil { t.Fatal(err) } @@ -278,8 +275,8 @@ func TestFileServerCleans(t *testing.T) { } } -func TestFileServerEscapesNames(t *testing.T) { - defer afterTest(t) +func TestFileServerEscapesNames(t *testing.T) { run(t, testFileServerEscapesNames) } +func testFileServerEscapesNames(t *testing.T, mode testMode) { const dirListPrefix = "
\n"
 	const dirListSuffix = "\n
\n" tests := []struct { @@ -304,11 +301,10 @@ func TestFileServerEscapesNames(t *testing.T) { fs[fmt.Sprintf("/%d/%s", i, test.name)] = testFile } - ts := httptest.NewServer(FileServer(&fs)) - defer ts.Close() + ts := newClientServerTest(t, mode, FileServer(&fs)).ts for i, test := range tests { url := fmt.Sprintf("%s/%d", ts.URL, i) - res, err := Get(url) + res, err := ts.Client().Get(url) if err != nil { t.Fatalf("test %q: Get: %v", test.name, err) } @@ -327,8 +323,8 @@ func TestFileServerEscapesNames(t *testing.T) { } } -func TestFileServerSortsNames(t *testing.T) { - defer afterTest(t) +func TestFileServerSortsNames(t *testing.T) { run(t, testFileServerSortsNames) } +func testFileServerSortsNames(t *testing.T, mode testMode) { const contents = "I am a fake file" dirMod := time.Unix(123, 0).UTC() fileMod := time.Unix(1000000000, 0).UTC() @@ -351,10 +347,9 @@ func TestFileServerSortsNames(t *testing.T) { }, } - ts := httptest.NewServer(FileServer(&fs)) - defer ts.Close() + ts := newClientServerTest(t, mode, FileServer(&fs)).ts - res, err := Get(ts.URL) + res, err := ts.Client().Get(ts.URL) if err != nil { t.Fatalf("Get: %v", err) } @@ -377,16 +372,15 @@ func mustRemoveAll(dir string) { } } -func TestFileServerImplicitLeadingSlash(t *testing.T) { - defer afterTest(t) +func TestFileServerImplicitLeadingSlash(t *testing.T) { run(t, testFileServerImplicitLeadingSlash) } +func testFileServerImplicitLeadingSlash(t *testing.T, mode testMode) { tempDir := t.TempDir() if err := os.WriteFile(filepath.Join(tempDir, "foo.txt"), []byte("Hello world"), 0644); err != nil { t.Fatalf("WriteFile: %v", err) } - ts := httptest.NewServer(StripPrefix("/bar/", FileServer(Dir(tempDir)))) - defer ts.Close() + ts := newClientServerTest(t, mode, StripPrefix("/bar/", FileServer(Dir(tempDir)))).ts get := func(suffix string) string { - res, err := Get(ts.URL + suffix) + res, err := ts.Client().Get(ts.URL + suffix) if err != nil { t.Fatalf("Get %s: %v", suffix, err) } @@ -405,11 +399,10 @@ func TestFileServerImplicitLeadingSlash(t *testing.T) { } } -func TestFileServerMethodOptions(t *testing.T) { - defer afterTest(t) +func TestFileServerMethodOptions(t *testing.T) { run(t, testFileServerMethodOptions) } +func testFileServerMethodOptions(t *testing.T, mode testMode) { const want = "GET, HEAD, OPTIONS" - ts := httptest.NewServer(FileServer(Dir("."))) - defer ts.Close() + ts := newClientServerTest(t, mode, FileServer(Dir("."))).ts tests := []struct { method string @@ -496,10 +489,10 @@ func TestEmptyDirOpenCWD(t *testing.T) { test(Dir("./")) } -func TestServeFileContentType(t *testing.T) { - defer afterTest(t) +func TestServeFileContentType(t *testing.T) { run(t, testServeFileContentType) } +func testServeFileContentType(t *testing.T, mode testMode) { const ctype = "icecream/chocolate" - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { switch r.FormValue("override") { case "1": w.Header().Set("Content-Type", ctype) @@ -508,10 +501,9 @@ func TestServeFileContentType(t *testing.T) { w.Header()["Content-Type"] = []string{} } ServeFile(w, r, "testdata/file") - })) - defer ts.Close() + })).ts get := func(override string, want []string) { - resp, err := Get(ts.URL + "?override=" + override) + resp, err := ts.Client().Get(ts.URL + "?override=" + override) if err != nil { t.Fatal(err) } @@ -525,13 +517,12 @@ func TestServeFileContentType(t *testing.T) { get("2", nil) } -func TestServeFileMimeType(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServeFileMimeType(t *testing.T) { run(t, testServeFileMimeType) } +func testServeFileMimeType(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "testdata/style.css") - })) - defer ts.Close() - resp, err := Get(ts.URL) + })).ts + resp, err := ts.Client().Get(ts.URL) if err != nil { t.Fatal(err) } @@ -542,13 +533,12 @@ func TestServeFileMimeType(t *testing.T) { } } -func TestServeFileFromCWD(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServeFileFromCWD(t *testing.T) { run(t, testServeFileFromCWD) } +func testServeFileFromCWD(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "fs_test.go") - })) - defer ts.Close() - r, err := Get(ts.URL) + })).ts + r, err := ts.Client().Get(ts.URL) if err != nil { t.Fatal(err) } @@ -559,14 +549,13 @@ func TestServeFileFromCWD(t *testing.T) { } // Issue 13996 -func TestServeDirWithoutTrailingSlash(t *testing.T) { +func TestServeDirWithoutTrailingSlash(t *testing.T) { run(t, testServeDirWithoutTrailingSlash) } +func testServeDirWithoutTrailingSlash(t *testing.T, mode testMode) { e := "/testdata/" - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, ".") - })) - defer ts.Close() - r, err := Get(ts.URL + "/testdata") + })).ts + r, err := ts.Client().Get(ts.URL + "/testdata") if err != nil { t.Fatal(err) } @@ -578,11 +567,9 @@ func TestServeDirWithoutTrailingSlash(t *testing.T) { // Tests that ServeFile doesn't add a Content-Length if a Content-Encoding is // specified. -func TestServeFileWithContentEncoding_h1(t *testing.T) { testServeFileWithContentEncoding(t, h1Mode) } -func TestServeFileWithContentEncoding_h2(t *testing.T) { testServeFileWithContentEncoding(t, h2Mode) } -func testServeFileWithContentEncoding(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServeFileWithContentEncoding(t *testing.T) { run(t, testServeFileWithContentEncoding) } +func testServeFileWithContentEncoding(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Encoding", "foo") ServeFile(w, r, "testdata/file") @@ -595,7 +582,6 @@ func testServeFileWithContentEncoding(t *testing.T, h2 bool) { // Content-Length and test ServeFile only, flush here. w.(Flusher).Flush() })) - defer cst.close() resp, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -608,11 +594,9 @@ func testServeFileWithContentEncoding(t *testing.T, h2 bool) { // Tests that ServeFile does not generate representation metadata when // file has not been modified, as per RFC 7232 section 4.1. -func TestServeFileNotModified_h1(t *testing.T) { testServeFileNotModified(t, h1Mode) } -func TestServeFileNotModified_h2(t *testing.T) { testServeFileNotModified(t, h2Mode) } -func testServeFileNotModified(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServeFileNotModified(t *testing.T) { run(t, testServeFileNotModified) } +func testServeFileNotModified(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Encoding", "foo") w.Header().Set("Etag", `"123"`) @@ -627,7 +611,6 @@ func testServeFileNotModified(t *testing.T, h2 bool) { // Content-Length and test ServeFile only, flush here. w.(Flusher).Flush() })) - defer cst.close() req, err := NewRequest("GET", cst.ts.URL, nil) if err != nil { t.Fatal(err) @@ -660,9 +643,8 @@ func testServeFileNotModified(t *testing.T, h2 bool) { } } -func TestServeIndexHtml(t *testing.T) { - defer afterTest(t) - +func TestServeIndexHtml(t *testing.T) { run(t, testServeIndexHtml) } +func testServeIndexHtml(t *testing.T, mode testMode) { for i := 0; i < 2; i++ { var h Handler var name string @@ -676,11 +658,10 @@ func TestServeIndexHtml(t *testing.T) { } t.Run(name, func(t *testing.T) { const want = "index.html says hello\n" - ts := httptest.NewServer(h) - defer ts.Close() + ts := newClientServerTest(t, mode, h).ts for _, path := range []string{"/testdata/", "/testdata/index.html"} { - res, err := Get(ts.URL + path) + res, err := ts.Client().Get(ts.URL + path) if err != nil { t.Fatal(err) } @@ -697,14 +678,14 @@ func TestServeIndexHtml(t *testing.T) { } } -func TestServeIndexHtmlFS(t *testing.T) { - defer afterTest(t) +func TestServeIndexHtmlFS(t *testing.T) { run(t, testServeIndexHtmlFS) } +func testServeIndexHtmlFS(t *testing.T, mode testMode) { const want = "index.html says hello\n" - ts := httptest.NewServer(FileServer(Dir("."))) + ts := newClientServerTest(t, mode, FileServer(Dir("."))).ts defer ts.Close() for _, path := range []string{"/testdata/", "/testdata/index.html"} { - res, err := Get(ts.URL + path) + res, err := ts.Client().Get(ts.URL + path) if err != nil { t.Fatal(err) } @@ -719,10 +700,9 @@ func TestServeIndexHtmlFS(t *testing.T) { } } -func TestFileServerZeroByte(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(FileServer(Dir("."))) - defer ts.Close() +func TestFileServerZeroByte(t *testing.T) { run(t, testFileServerZeroByte) } +func testFileServerZeroByte(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, FileServer(Dir("."))).ts c, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -809,8 +789,8 @@ func (fsys fakeFS) Open(name string) (File, error) { return &fakeFile{ReadSeeker: strings.NewReader(f.contents), fi: f, path: name}, nil } -func TestDirectoryIfNotModified(t *testing.T) { - defer afterTest(t) +func TestDirectoryIfNotModified(t *testing.T) { run(t, testDirectoryIfNotModified) } +func testDirectoryIfNotModified(t *testing.T, mode testMode) { const indexContents = "I am a fake index.html file" fileMod := time.Unix(1000000000, 0).UTC() fileModStr := fileMod.Format(TimeFormat) @@ -829,10 +809,9 @@ func TestDirectoryIfNotModified(t *testing.T) { "/index.html": indexFile, } - ts := httptest.NewServer(FileServer(fs)) - defer ts.Close() + ts := newClientServerTest(t, mode, FileServer(fs)).ts - res, err := Get(ts.URL) + res, err := ts.Client().Get(ts.URL) if err != nil { t.Fatal(err) } @@ -884,8 +863,8 @@ func mustStat(t *testing.T, fileName string) fs.FileInfo { return fi } -func TestServeContent(t *testing.T) { - defer afterTest(t) +func TestServeContent(t *testing.T) { run(t, testServeContent) } +func testServeContent(t *testing.T, mode testMode) { type serveParam struct { name string modtime time.Time @@ -894,7 +873,7 @@ func TestServeContent(t *testing.T) { etag string } servec := make(chan serveParam, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { p := <-servec if p.etag != "" { w.Header().Set("ETag", p.etag) @@ -903,8 +882,7 @@ func TestServeContent(t *testing.T) { w.Header().Set("Content-Type", p.contentType) } ServeContent(w, r, p.name, p.modtime, p.content) - })) - defer ts.Close() + })).ts type testCase struct { // One of file or content must be set: @@ -1213,8 +1191,8 @@ type issue12991File struct{ File } func (issue12991File) Stat() (fs.FileInfo, error) { return nil, fs.ErrPermission } func (issue12991File) Close() error { return nil } -func TestServeContentErrorMessages(t *testing.T) { - defer afterTest(t) +func TestServeContentErrorMessages(t *testing.T) { run(t, testServeContentErrorMessages) } +func testServeContentErrorMessages(t *testing.T, mode testMode) { fs := fakeFS{ "/500": &fakeFileInfo{ err: errors.New("random error"), @@ -1223,8 +1201,7 @@ func TestServeContentErrorMessages(t *testing.T) { err: &fs.PathError{Err: fs.ErrPermission}, }, } - ts := httptest.NewServer(FileServer(fs)) - defer ts.Close() + ts := newClientServerTest(t, mode, FileServer(fs)).ts c := ts.Client() for _, code := range []int{403, 404, 500} { res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, code)) @@ -1342,20 +1319,20 @@ func TestLinuxSendfileChild(*testing.T) { // Issues 18984, 49552: tests that requests for paths beyond files return not-found errors func TestFileServerNotDirError(t *testing.T) { - defer afterTest(t) - t.Run("Dir", func(t *testing.T) { - testFileServerNotDirError(t, func(path string) FileSystem { return Dir(path) }) - }) - t.Run("FS", func(t *testing.T) { - testFileServerNotDirError(t, func(path string) FileSystem { return FS(os.DirFS(path)) }) + run(t, func(t *testing.T, mode testMode) { + t.Run("Dir", func(t *testing.T) { + testFileServerNotDirError(t, mode, func(path string) FileSystem { return Dir(path) }) + }) + t.Run("FS", func(t *testing.T) { + testFileServerNotDirError(t, mode, func(path string) FileSystem { return FS(os.DirFS(path)) }) + }) }) } -func testFileServerNotDirError(t *testing.T, newfs func(string) FileSystem) { - ts := httptest.NewServer(FileServer(newfs("testdata"))) - defer ts.Close() +func testFileServerNotDirError(t *testing.T, mode testMode, newfs func(string) FileSystem) { + ts := newClientServerTest(t, mode, FileServer(newfs("testdata"))).ts - res, err := Get(ts.URL + "/index.html/not-a-file") + res, err := ts.Client().Get(ts.URL + "/index.html/not-a-file") if err != nil { t.Fatal(err) } @@ -1459,19 +1436,11 @@ func Test_scanETag(t *testing.T) { // Issue 40940: Ensure that we only accept non-negative suffix-lengths // in "Range": "bytes=-N", and should reject "bytes=--2". -func TestServeFileRejectsInvalidSuffixLengths_h1(t *testing.T) { - testServeFileRejectsInvalidSuffixLengths(t, h1Mode) +func TestServeFileRejectsInvalidSuffixLengths(t *testing.T) { + run(t, testServeFileRejectsInvalidSuffixLengths, []testMode{http1Mode, https1Mode, http2Mode}) } -func TestServeFileRejectsInvalidSuffixLengths_h2(t *testing.T) { - testServeFileRejectsInvalidSuffixLengths(t, h2Mode) -} - -func testServeFileRejectsInvalidSuffixLengths(t *testing.T, h2 bool) { - defer afterTest(t) - cst := httptest.NewUnstartedServer(FileServer(Dir("testdata"))) - cst.EnableHTTP2 = h2 - cst.StartTLS() - defer cst.Close() +func testServeFileRejectsInvalidSuffixLengths(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, FileServer(Dir("testdata"))).ts tests := []struct { r string diff --git a/src/net/http/request_test.go b/src/net/http/request_test.go index 27e9eb30ee..686a8699fb 100644 --- a/src/net/http/request_test.go +++ b/src/net/http/request_test.go @@ -15,7 +15,6 @@ import ( "math" "mime/multipart" . "net/http" - "net/http/httptest" "net/url" "os" "reflect" @@ -289,10 +288,11 @@ Content-Type: text/plain // the payload size and the internal leeway buffer size of 10MiB overflows, that we // correctly return an error. func TestMaxInt64ForMultipartFormMaxMemoryOverflow(t *testing.T) { - defer afterTest(t) - + run(t, testMaxInt64ForMultipartFormMaxMemoryOverflow) +} +func testMaxInt64ForMultipartFormMaxMemoryOverflow(t *testing.T, mode testMode) { payloadSize := 1 << 10 - cst := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { // The combination of: // MaxInt64 + payloadSize + (internal spare of 10MiB) // triggers the overflow. See issue https://golang.org/issue/40430/ @@ -300,8 +300,7 @@ func TestMaxInt64ForMultipartFormMaxMemoryOverflow(t *testing.T) { Error(rw, err.Error(), StatusBadRequest) return } - })) - defer cst.Close() + })).ts fBuf := new(bytes.Buffer) mw := multipart.NewWriter(fBuf) mf, err := mw.CreateFormFile("file", "myfile.txt") @@ -329,11 +328,9 @@ func TestMaxInt64ForMultipartFormMaxMemoryOverflow(t *testing.T) { } } -func TestRedirect_h1(t *testing.T) { testRedirect(t, h1Mode) } -func TestRedirect_h2(t *testing.T) { testRedirect(t, h2Mode) } -func testRedirect(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestRequestRedirect(t *testing.T) { run(t, testRequestRedirect) } +func testRequestRedirect(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { switch r.URL.Path { case "/": w.Header().Set("Location", "/foo/") @@ -344,7 +341,6 @@ func testRedirect(t *testing.T, h2 bool) { w.WriteHeader(StatusBadRequest) } })) - defer cst.close() var end = regexp.MustCompile("/foo/$") r, err := cst.c.Get(cst.ts.URL) @@ -1035,19 +1031,10 @@ func TestRequestCloneTransferEncoding(t *testing.T) { } } -func TestNoPanicOnRoundTripWithBasicAuth_h1(t *testing.T) { - testNoPanicWithBasicAuth(t, h1Mode) -} - -func TestNoPanicOnRoundTripWithBasicAuth_h2(t *testing.T) { - testNoPanicWithBasicAuth(t, h2Mode) -} - // Issue 34878: verify we don't panic when including basic auth (Go 1.13 regression) -func testNoPanicWithBasicAuth(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {})) - defer cst.close() +func TestNoPanicOnRoundTripWithBasicAuth(t *testing.T) { run(t, testNoPanicWithBasicAuth) } +func testNoPanicWithBasicAuth(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})) u, err := url.Parse(cst.ts.URL) if err != nil { @@ -1328,11 +1315,6 @@ Host: localhost:8080 `) } -const ( - withTLS = true - noTLS = false -) - func BenchmarkFileAndServer_1KB(b *testing.B) { benchmarkFileAndServer(b, 1<<10) } @@ -1360,16 +1342,12 @@ func benchmarkFileAndServer(b *testing.B, n int64) { b.Fatalf("Failed to copy %d bytes: %v", n, err) } - b.Run("NoTLS", func(b *testing.B) { - runFileAndServerBenchmarks(b, noTLS, f, n) - }) - - b.Run("TLS", func(b *testing.B) { - runFileAndServerBenchmarks(b, withTLS, f, n) - }) + run(b, func(b *testing.B, mode testMode) { + runFileAndServerBenchmarks(b, mode, f, n) + }, []testMode{http1Mode, https1Mode, http2Mode}) } -func runFileAndServerBenchmarks(b *testing.B, tlsOption bool, f *os.File, n int64) { +func runFileAndServerBenchmarks(b *testing.B, mode testMode, f *os.File, n int64) { handler := HandlerFunc(func(rw ResponseWriter, req *Request) { defer req.Body.Close() nc, err := io.Copy(io.Discard, req.Body) @@ -1382,14 +1360,8 @@ func runFileAndServerBenchmarks(b *testing.B, tlsOption bool, f *os.File, n int6 } }) - var cst *httptest.Server - if tlsOption == withTLS { - cst = httptest.NewTLSServer(handler) - } else { - cst = httptest.NewServer(handler) - } + cst := newClientServerTest(b, mode, handler).ts - defer cst.Close() b.ResetTimer() for i := 0; i < b.N; i++ { // Perform some setup. diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index 4fadc56c9e..a93f6eff1b 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -246,15 +246,13 @@ var vtests = []struct { {"http://someHost.com/someDir", "/someDir/"}, } -func TestHostHandlers(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestHostHandlers(t *testing.T) { run(t, testHostHandlers, []testMode{http1Mode}) } +func testHostHandlers(t *testing.T, mode testMode) { mux := NewServeMux() for _, h := range handlers { mux.Handle(h.pattern, stringHandler(h.msg)) } - ts := httptest.NewServer(mux) - defer ts.Close() + ts := newClientServerTest(t, mode, mux).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -487,9 +485,9 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) { // properly sets the query string in the redirect URL. // See Issue 17841. func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) { - setParallel(t) - defer afterTest(t) - + run(t, testServeWithSlashRedirectKeepsQueryString, []testMode{http1Mode}) +} +func testServeWithSlashRedirectKeepsQueryString(t *testing.T, mode testMode) { writeBackQuery := func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%s", r.URL.RawQuery) } @@ -502,8 +500,7 @@ func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) { fmt.Fprintf(w, "%s:bar", r.URL.RawQuery) }) - ts := httptest.NewServer(mux) - defer ts.Close() + ts := newClientServerTest(t, mode, mux).ts tests := [...]struct { path string @@ -546,7 +543,6 @@ func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) { func TestServeWithSlashRedirectForHostPatterns(t *testing.T) { setParallel(t) - defer afterTest(t) mux := NewServeMux() mux.Handle("example.com/pkg/foo/", stringHandler("example.com/pkg/foo/")) @@ -578,9 +574,6 @@ func TestServeWithSlashRedirectForHostPatterns(t *testing.T) { {"CONNECT", "http://example.com:3000/pkg/connect", 301, "/pkg/connect/", ""}, } - ts := httptest.NewServer(mux) - defer ts.Close() - for i, tt := range tests { req, _ := NewRequest(tt.method, tt.url, nil) w := httptest.NewRecorder() @@ -602,13 +595,10 @@ func TestServeWithSlashRedirectForHostPatterns(t *testing.T) { } } -func TestShouldRedirectConcurrency(t *testing.T) { - setParallel(t) - defer afterTest(t) - +func TestShouldRedirectConcurrency(t *testing.T) { run(t, testShouldRedirectConcurrency) } +func testShouldRedirectConcurrency(t *testing.T, mode testMode) { mux := NewServeMux() - ts := httptest.NewServer(mux) - defer ts.Close() + newClientServerTest(t, mode, mux) mux.HandleFunc("/", func(w ResponseWriter, r *Request) {}) } @@ -656,13 +646,12 @@ func benchmarkServeMux(b *testing.B, runHandler bool) { } } -func TestServerTimeouts(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestServerTimeouts(t *testing.T) { run(t, testServerTimeouts, []testMode{http1Mode}) } +func testServerTimeouts(t *testing.T, mode testMode) { // Try three times, with increasing timeouts. tries := []time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second} for i, timeout := range tries { - err := testServerTimeouts(timeout) + err := testServerTimeoutsWithTimeout(t, timeout, mode) if err == nil { return } @@ -674,16 +663,15 @@ func TestServerTimeouts(t *testing.T) { t.Fatal("all attempts failed") } -func testServerTimeouts(timeout time.Duration) error { +func testServerTimeoutsWithTimeout(t *testing.T, timeout time.Duration, mode testMode) error { reqNum := 0 - ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) { reqNum++ fmt.Fprintf(res, "req=%d", reqNum) - })) - ts.Config.ReadTimeout = timeout - ts.Config.WriteTimeout = timeout - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.ReadTimeout = timeout + ts.Config.WriteTimeout = timeout + }).ts // Hit the HTTP server successfully. c := ts.Client() @@ -749,22 +737,20 @@ func testServerTimeouts(timeout time.Duration) error { } // Test that the HTTP/2 server handles Server.WriteTimeout (Issue 18437) -func TestHTTP2WriteDeadlineExtendedOnNewRequest(t *testing.T) { +func TestWriteDeadlineExtendedOnNewRequest(t *testing.T) { + run(t, testWriteDeadlineExtendedOnNewRequest) +} +func testWriteDeadlineExtendedOnNewRequest(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } - setParallel(t) - defer afterTest(t) - ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) {})) - ts.Config.WriteTimeout = 250 * time.Millisecond - ts.TLS = &tls.Config{NextProtos: []string{"h2"}} - ts.StartTLS() - defer ts.Close() + ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {}), + func(ts *httptest.Server) { + ts.Config.WriteTimeout = 250 * time.Millisecond + }, + ).ts c := ts.Client() - if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil { - t.Fatal(err) - } for i := 1; i <= 3; i++ { req, err := NewRequest("GET", ts.URL, nil) @@ -785,9 +771,6 @@ func TestHTTP2WriteDeadlineExtendedOnNewRequest(t *testing.T) { t.Fatalf("http2 Get #%d: %v", i, err) } r.Body.Close() - if r.ProtoMajor != 2 { - t.Fatalf("http2 Get expected HTTP/2.0, got %q", r.Proto) - } time.Sleep(ts.Config.WriteTimeout / 2) } } @@ -810,33 +793,31 @@ func tryTimeouts(t *testing.T, testFunc func(timeout time.Duration) error) { } // Test that the HTTP/2 server RSTs stream on slow write. -func TestHTTP2WriteDeadlineEnforcedPerStream(t *testing.T) { +func TestWriteDeadlineEnforcedPerStream(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } setParallel(t) - defer afterTest(t) - tryTimeouts(t, testHTTP2WriteDeadlineEnforcedPerStream) + run(t, func(t *testing.T, mode testMode) { + tryTimeouts(t, func(timeout time.Duration) error { + return testWriteDeadlineEnforcedPerStream(t, mode, timeout) + }) + }) } -func testHTTP2WriteDeadlineEnforcedPerStream(timeout time.Duration) error { +func testWriteDeadlineEnforcedPerStream(t *testing.T, mode testMode, timeout time.Duration) error { reqNum := 0 - ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) { reqNum++ if reqNum == 1 { return // first request succeeds } time.Sleep(timeout) // second request times out - })) - ts.Config.WriteTimeout = timeout / 2 - ts.TLS = &tls.Config{NextProtos: []string{"h2"}} - ts.StartTLS() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.WriteTimeout = timeout / 2 + }).ts c := ts.Client() - if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil { - return fmt.Errorf("ExportHttp2ConfigureTransport: %v", err) - } req, err := NewRequest("GET", ts.URL, nil) if err != nil { @@ -844,12 +825,9 @@ func testHTTP2WriteDeadlineEnforcedPerStream(timeout time.Duration) error { } r, err := c.Do(req) if err != nil { - return fmt.Errorf("http2 Get #1: %v", err) + return fmt.Errorf("Get #1: %v", err) } r.Body.Close() - if r.ProtoMajor != 2 { - return fmt.Errorf("http2 Get expected HTTP/2.0, got %q", r.Proto) - } req, err = NewRequest("GET", ts.URL, nil) if err != nil { @@ -858,45 +836,42 @@ func testHTTP2WriteDeadlineEnforcedPerStream(timeout time.Duration) error { r, err = c.Do(req) if err == nil { r.Body.Close() - if r.ProtoMajor != 2 { - return fmt.Errorf("http2 Get expected HTTP/2.0, got %q", r.Proto) - } - return fmt.Errorf("http2 Get #2 expected error, got nil") + return fmt.Errorf("Get #2 expected error, got nil") } - expected := "stream ID 3; INTERNAL_ERROR" // client IDs are odd, second stream should be 3 - if !strings.Contains(err.Error(), expected) { - return fmt.Errorf("http2 Get #2: expected error to contain %q, got %q", expected, err) + if mode == http2Mode { + expected := "stream ID 3; INTERNAL_ERROR" // client IDs are odd, second stream should be 3 + if !strings.Contains(err.Error(), expected) { + return fmt.Errorf("http2 Get #2: expected error to contain %q, got %q", expected, err) + } } return nil } // Test that the HTTP/2 server does not send RST when WriteDeadline not set. -func TestHTTP2NoWriteDeadline(t *testing.T) { +func TestNoWriteDeadline(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } setParallel(t) defer afterTest(t) - tryTimeouts(t, testHTTP2NoWriteDeadline) + run(t, func(t *testing.T, mode testMode) { + tryTimeouts(t, func(timeout time.Duration) error { + return testNoWriteDeadline(t, mode, timeout) + }) + }) } -func testHTTP2NoWriteDeadline(timeout time.Duration) error { +func testNoWriteDeadline(t *testing.T, mode testMode, timeout time.Duration) error { reqNum := 0 - ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) { reqNum++ if reqNum == 1 { return // first request succeeds } time.Sleep(timeout) // second request timesout - })) - ts.TLS = &tls.Config{NextProtos: []string{"h2"}} - ts.StartTLS() - defer ts.Close() + })).ts c := ts.Client() - if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil { - return fmt.Errorf("ExportHttp2ConfigureTransport: %v", err) - } for i := 0; i < 2; i++ { req, err := NewRequest("GET", ts.URL, nil) @@ -905,12 +880,9 @@ func testHTTP2NoWriteDeadline(timeout time.Duration) error { } r, err := c.Do(req) if err != nil { - return fmt.Errorf("http2 Get #%d: %v", i, err) + return fmt.Errorf("Get #%d: %v", i, err) } r.Body.Close() - if r.ProtoMajor != 2 { - return fmt.Errorf("http2 Get expected HTTP/2.0, got %q", r.Proto) - } } return nil } @@ -918,15 +890,14 @@ func testHTTP2NoWriteDeadline(timeout time.Duration) error { // golang.org/issue/4741 -- setting only a write timeout that triggers // shouldn't cause a handler to block forever on reads (next HTTP // request) that will never happen. -func TestOnlyWriteTimeout(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestOnlyWriteTimeout(t *testing.T) { run(t, testOnlyWriteTimeout, []testMode{http1Mode}) } +func testOnlyWriteTimeout(t *testing.T, mode testMode) { var ( mu sync.RWMutex conn net.Conn ) var afterTimeoutErrc = make(chan error, 1) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) { buf := make([]byte, 512<<10) _, err := w.Write(buf) if err != nil { @@ -942,10 +913,9 @@ func TestOnlyWriteTimeout(t *testing.T) { conn.SetWriteDeadline(time.Now().Add(-30 * time.Second)) _, err = w.Write(buf) afterTimeoutErrc <- err - })) - ts.Listener = trackLastConnListener{ts.Listener, &mu, &conn} - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Listener = trackLastConnListener{ts.Listener, &mu, &conn} + }).ts c := ts.Client() @@ -992,9 +962,12 @@ func (l trackLastConnListener) Accept() (c net.Conn, err error) { } // TestIdentityResponse verifies that a handler can unset -func TestIdentityResponse(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestIdentityResponse(t *testing.T) { run(t, testIdentityResponse) } +func testIdentityResponse(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("https://go.dev/issue/56019") + } + handler := HandlerFunc(func(rw ResponseWriter, req *Request) { rw.Header().Set("Content-Length", "3") rw.Header().Set("Transfer-Encoding", req.FormValue("te")) @@ -1012,9 +985,7 @@ func TestIdentityResponse(t *testing.T) { } }) - ts := httptest.NewServer(handler) - defer ts.Close() - + ts := newClientServerTest(t, mode, handler).ts c := ts.Client() // Note: this relies on the assumption (which is true) that @@ -1048,6 +1019,10 @@ func TestIdentityResponse(t *testing.T) { } res.Body.Close() + if mode != http1Mode { + return + } + // Verify that the connection is closed when the declared Content-Length // is larger than what the handler wrote. conn, err := net.Dial("tcp", ts.Listener.Addr().String()) @@ -1070,9 +1045,7 @@ func TestIdentityResponse(t *testing.T) { func testTCPConnectionCloses(t *testing.T, req string, h Handler) { setParallel(t) - defer afterTest(t) - s := httptest.NewServer(h) - defer s.Close() + s := newClientServerTest(t, http1Mode, h).ts conn, err := net.Dial("tcp", s.Listener.Addr().String()) if err != nil { @@ -1114,9 +1087,7 @@ func testTCPConnectionCloses(t *testing.T, req string, h Handler) { func testTCPConnectionStaysOpen(t *testing.T, req string, handler Handler) { setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(handler) - defer ts.Close() + ts := newClientServerTest(t, http1Mode, handler).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatal(err) @@ -1192,14 +1163,12 @@ func TestHTTP10KeepAlive304Response(t *testing.T) { } // Issue 15703 -func TestKeepAliveFinalChunkWithEOF(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, false /* h1 */, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestKeepAliveFinalChunkWithEOF(t *testing.T) { run(t, testKeepAliveFinalChunkWithEOF) } +func testKeepAliveFinalChunkWithEOF(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.(Flusher).Flush() // force chunked encoding w.Write([]byte("{\"Addr\": \"" + r.RemoteAddr + "\"}")) })) - defer cst.close() type data struct { Addr string } @@ -1222,16 +1191,11 @@ func TestKeepAliveFinalChunkWithEOF(t *testing.T) { } } -func TestSetsRemoteAddr_h1(t *testing.T) { testSetsRemoteAddr(t, h1Mode) } -func TestSetsRemoteAddr_h2(t *testing.T) { testSetsRemoteAddr(t, h2Mode) } - -func testSetsRemoteAddr(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestSetsRemoteAddr(t *testing.T) { run(t, testSetsRemoteAddr) } +func testSetsRemoteAddr(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%s", r.RemoteAddr) })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -1276,17 +1240,18 @@ func (c *blockingRemoteAddrConn) RemoteAddr() net.Addr { // Issue 12943 func TestServerAllowsBlockingRemoteAddr(t *testing.T) { - defer afterTest(t) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { - fmt.Fprintf(w, "RA:%s", r.RemoteAddr) - })) + run(t, testServerAllowsBlockingRemoteAddr, []testMode{http1Mode}) +} +func testServerAllowsBlockingRemoteAddr(t *testing.T, mode testMode) { conns := make(chan net.Conn) - ts.Listener = &blockingRemoteAddrListener{ - Listener: ts.Listener, - conns: conns, - } - ts.Start() - defer ts.Close() + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "RA:%s", r.RemoteAddr) + }), func(ts *httptest.Server) { + ts.Listener = &blockingRemoteAddrListener{ + Listener: ts.Listener, + conns: conns, + } + }).ts c := ts.Client() c.Timeout = time.Second @@ -1351,13 +1316,9 @@ func TestServerAllowsBlockingRemoteAddr(t *testing.T) { // TestHeadResponses verifies that all MIME type sniffing and Content-Length // counting of GET requests also happens on HEAD requests. -func TestHeadResponses_h1(t *testing.T) { testHeadResponses(t, h1Mode) } -func TestHeadResponses_h2(t *testing.T) { testHeadResponses(t, h2Mode) } - -func testHeadResponses(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestHeadResponses(t *testing.T) { run(t, testHeadResponses) } +func testHeadResponses(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { _, err := w.Write([]byte("")) if err != nil { t.Errorf("ResponseWriter.Write: %v", err) @@ -1369,7 +1330,6 @@ func testHeadResponses(t *testing.T, h2 bool) { t.Errorf("Copy(ResponseWriter, ...): %v", err) } })) - defer cst.close() res, err := cst.c.Head(cst.ts.URL) if err != nil { t.Error(err) @@ -1393,14 +1353,16 @@ func testHeadResponses(t *testing.T, h2 bool) { } func TestTLSHandshakeTimeout(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + run(t, testTLSHandshakeTimeout, []testMode{https1Mode, http2Mode}) +} +func testTLSHandshakeTimeout(t *testing.T, mode testMode) { errc := make(chanWriter, 10) // but only expecting 1 - ts.Config.ReadTimeout = 250 * time.Millisecond - ts.Config.ErrorLog = log.New(errc, "", 0) - ts.StartTLS() - defer ts.Close() + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), + func(ts *httptest.Server) { + ts.Config.ReadTimeout = 250 * time.Millisecond + ts.Config.ErrorLog = log.New(errc, "", 0) + }, + ).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("Dial: %v", err) @@ -1423,19 +1385,18 @@ func TestTLSHandshakeTimeout(t *testing.T) { } } -func TestTLSServer(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTLSServer(t *testing.T) { run(t, testTLSServer, []testMode{https1Mode, http2Mode}) } +func testTLSServer(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.TLS != nil { w.Header().Set("X-TLS-Set", "true") if r.TLS.HandshakeComplete { w.Header().Set("X-TLS-HandshakeComplete", "true") } } - })) - ts.Config.ErrorLog = log.New(io.Discard, "", 0) - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.ErrorLog = log.New(io.Discard, "", 0) + }).ts // Connect an idle TCP connection to this server before we run // our real tests. This idle connection used to block forever @@ -1528,14 +1489,15 @@ func TestServeTLS(t *testing.T) { // Test that the HTTPS server nicely rejects plaintext HTTP/1.x requests. func TestTLSServerRejectHTTPRequests(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTLSServerRejectHTTPRequests, []testMode{https1Mode, http2Mode}) +} +func testTLSServerRejectHTTPRequests(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { t.Error("unexpected HTTPS request") - })) - var errBuf bytes.Buffer - ts.Config.ErrorLog = log.New(&errBuf, "", 0) - defer ts.Close() + }), func(ts *httptest.Server) { + var errBuf bytes.Buffer + ts.Config.ErrorLog = log.New(&errBuf, "", 0) + }).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatal(err) @@ -1727,11 +1689,9 @@ var serverExpectTests = []serverExpectTest{ // Tests that the server responds to the "Expect" request header // correctly. -// http2 test: TestServer_Response_Automatic100Continue -func TestServerExpect(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServerExpect(t *testing.T) { run(t, testServerExpect, []testMode{http1Mode}) } +func testServerExpect(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { // Note using r.FormValue("readbody") because for POST // requests that would read from r.Body, which we only // conditionally want to do. @@ -1741,8 +1701,7 @@ func TestServerExpect(t *testing.T) { } else { w.WriteHeader(StatusUnauthorized) } - })) - defer ts.Close() + })).ts runTest := func(test serverExpectTest) { conn, err := net.Dial("tcp", ts.Listener.Addr().String()) @@ -2287,11 +2246,8 @@ func (c cancelableTimeoutContext) Err() error { return nil } -func TestTimeoutHandler_h1(t *testing.T) { testTimeoutHandler(t, h1Mode) } -func TestTimeoutHandler_h2(t *testing.T) { testTimeoutHandler(t, h2Mode) } -func testTimeoutHandler(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestTimeoutHandler(t *testing.T) { run(t, testTimeoutHandler) } +func testTimeoutHandler(t *testing.T, mode testMode) { sendHi := make(chan bool, 1) writeErrors := make(chan error, 1) sayHi := HandlerFunc(func(w ResponseWriter, r *Request) { @@ -2301,8 +2257,7 @@ func testTimeoutHandler(t *testing.T, h2 bool) { }) ctx, cancel := context.WithCancel(context.Background()) h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx}) - cst := newClientServerTest(t, h2, h) - defer cst.close() + cst := newClientServerTest(t, mode, h) // Succeed without timing out: sendHi <- true @@ -2348,10 +2303,8 @@ func testTimeoutHandler(t *testing.T, h2 bool) { } // See issues 8209 and 8414. -func TestTimeoutHandlerRace(t *testing.T) { - setParallel(t) - defer afterTest(t) - +func TestTimeoutHandlerRace(t *testing.T) { run(t, testTimeoutHandlerRace) } +func testTimeoutHandlerRace(t *testing.T, mode testMode) { delayHi := HandlerFunc(func(w ResponseWriter, r *Request) { ms, _ := strconv.Atoi(r.URL.Path[1:]) if ms == 0 { @@ -2363,8 +2316,7 @@ func TestTimeoutHandlerRace(t *testing.T) { } }) - ts := httptest.NewServer(TimeoutHandler(delayHi, 20*time.Millisecond, "")) - defer ts.Close() + ts := newClientServerTest(t, mode, TimeoutHandler(delayHi, 20*time.Millisecond, "")).ts c := ts.Client() @@ -2393,16 +2345,13 @@ func TestTimeoutHandlerRace(t *testing.T) { // See issues 8209 and 8414. // Both issues involved panics in the implementation of TimeoutHandler. -func TestTimeoutHandlerRaceHeader(t *testing.T) { - setParallel(t) - defer afterTest(t) - +func TestTimeoutHandlerRaceHeader(t *testing.T) { run(t, testTimeoutHandlerRaceHeader) } +func testTimeoutHandlerRaceHeader(t *testing.T, mode testMode) { delay204 := HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(204) }) - ts := httptest.NewServer(TimeoutHandler(delay204, time.Nanosecond, "")) - defer ts.Close() + ts := newClientServerTest(t, mode, TimeoutHandler(delay204, time.Nanosecond, "")).ts var wg sync.WaitGroup gate := make(chan bool, 50) @@ -2433,9 +2382,8 @@ func TestTimeoutHandlerRaceHeader(t *testing.T) { } // Issue 9162 -func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { run(t, testTimeoutHandlerRaceHeaderTimeout) } +func testTimeoutHandlerRaceHeaderTimeout(t *testing.T, mode testMode) { sendHi := make(chan bool, 1) writeErrors := make(chan error, 1) sayHi := HandlerFunc(func(w ResponseWriter, r *Request) { @@ -2446,8 +2394,7 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { }) ctx, cancel := context.WithCancel(context.Background()) h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx}) - cst := newClientServerTest(t, h1Mode, h) - defer cst.close() + cst := newClientServerTest(t, mode, h) // Succeed without timing out: sendHi <- true @@ -2491,15 +2438,17 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { // Issue 14568. func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) { + run(t, testTimeoutHandlerStartTimerWhenServing) +} +func testTimeoutHandlerStartTimerWhenServing(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping sleeping test in -short mode") } - defer afterTest(t) var handler HandlerFunc = func(w ResponseWriter, _ *Request) { w.WriteHeader(StatusNoContent) } timeout := 300 * time.Millisecond - ts := httptest.NewServer(TimeoutHandler(handler, timeout, "")) + ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts defer ts.Close() c := ts.Client() @@ -2518,9 +2467,8 @@ func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) { } } -func TestTimeoutHandlerContextCanceled(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestTimeoutHandlerContextCanceled(t *testing.T) { run(t, testTimeoutHandlerContextCanceled) } +func testTimeoutHandlerContextCanceled(t *testing.T, mode testMode) { writeErrors := make(chan error, 1) sayHi := HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Type", "text/plain") @@ -2540,7 +2488,7 @@ func TestTimeoutHandlerContextCanceled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() h := NewTestTimeoutHandler(sayHi, ctx) - cst := newClientServerTest(t, h1Mode, h) + cst := newClientServerTest(t, mode, h) defer cst.close() res, err := cst.c.Get(cst.ts.URL) @@ -2560,15 +2508,13 @@ func TestTimeoutHandlerContextCanceled(t *testing.T) { } // https://golang.org/issue/15948 -func TestTimeoutHandlerEmptyResponse(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestTimeoutHandlerEmptyResponse(t *testing.T) { run(t, testTimeoutHandlerEmptyResponse) } +func testTimeoutHandlerEmptyResponse(t *testing.T, mode testMode) { var handler HandlerFunc = func(w ResponseWriter, _ *Request) { // No response. } timeout := 300 * time.Millisecond - ts := httptest.NewServer(TimeoutHandler(handler, timeout, "")) - defer ts.Close() + ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts c := ts.Client() @@ -2587,7 +2533,9 @@ func TestTimeoutHandlerPanicRecovery(t *testing.T) { wrapper := func(h Handler) Handler { return TimeoutHandler(h, time.Second, "") } - testHandlerPanic(t, false, false, wrapper, "intentional death for testing") + run(t, func(t *testing.T, mode testMode) { + testHandlerPanic(t, false, mode, wrapper, "intentional death for testing") + }, testNotParallel) } func TestRedirectBadPath(t *testing.T) { @@ -2705,17 +2653,10 @@ func TestRedirectContentTypeAndBody(t *testing.T) { // connection immediately. But when it re-uses the connection, it typically closes // the previous request's body, which is not optimal for zero-lengthed bodies, // as the client would then see http.ErrBodyReadAfterClose and not 0, io.EOF. -func TestZeroLengthPostAndResponse_h1(t *testing.T) { - testZeroLengthPostAndResponse(t, h1Mode) -} -func TestZeroLengthPostAndResponse_h2(t *testing.T) { - testZeroLengthPostAndResponse(t, h2Mode) -} +func TestZeroLengthPostAndResponse(t *testing.T) { run(t, testZeroLengthPostAndResponse) } -func testZeroLengthPostAndResponse(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, r *Request) { +func testZeroLengthPostAndResponse(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) { all, err := io.ReadAll(r.Body) if err != nil { t.Fatalf("handler ReadAll: %v", err) @@ -2725,7 +2666,6 @@ func testZeroLengthPostAndResponse(t *testing.T, h2 bool) { } rw.Header().Set("Content-Length", "0") })) - defer cst.close() req, err := NewRequest("POST", cst.ts.URL, strings.NewReader("")) if err != nil { @@ -2752,23 +2692,26 @@ func testZeroLengthPostAndResponse(t *testing.T, h2 bool) { } } -func TestHandlerPanicNil_h1(t *testing.T) { testHandlerPanic(t, false, h1Mode, nil, nil) } -func TestHandlerPanicNil_h2(t *testing.T) { testHandlerPanic(t, false, h2Mode, nil, nil) } - -func TestHandlerPanic_h1(t *testing.T) { - testHandlerPanic(t, false, h1Mode, nil, "intentional death for testing") +func TestHandlerPanicNil(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testHandlerPanic(t, false, mode, nil, nil) + }, testNotParallel) } -func TestHandlerPanic_h2(t *testing.T) { - testHandlerPanic(t, false, h2Mode, nil, "intentional death for testing") + +func TestHandlerPanic(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testHandlerPanic(t, false, mode, nil, "intentional death for testing") + }, testNotParallel) } func TestHandlerPanicWithHijack(t *testing.T) { // Only testing HTTP/1, and our http2 server doesn't support hijacking. - testHandlerPanic(t, true, h1Mode, nil, "intentional death for testing") + run(t, func(t *testing.T, mode testMode) { + testHandlerPanic(t, true, mode, nil, "intentional death for testing") + }, testNotParallel, []testMode{http1Mode}) } -func testHandlerPanic(t *testing.T, withHijack, h2 bool, wrapper func(Handler) Handler, panicValue any) { - defer afterTest(t) +func testHandlerPanic(t *testing.T, withHijack bool, mode testMode, wrapper func(Handler) Handler, panicValue any) { // Unlike the other tests that set the log output to io.Discard // to quiet the output, this test uses a pipe. The pipe serves three // purposes: @@ -2803,8 +2746,7 @@ func testHandlerPanic(t *testing.T, withHijack, h2 bool, wrapper func(Handler) H if wrapper != nil { handler = wrapper(handler) } - cst := newClientServerTest(t, h2, handler) - defer cst.close() + cst := newClientServerTest(t, mode, handler) // Do a blocking read on the log output pipe so its logging // doesn't bleed into the next test. But wait only 5 seconds @@ -2847,9 +2789,11 @@ func (w terrorWriter) Write(p []byte) (int, error) { // Issue 16456: allow writing 0 bytes on hijacked conn to test hijack // without any log spam. func TestServerWriteHijackZeroBytes(t *testing.T) { - defer afterTest(t) + run(t, testServerWriteHijackZeroBytes, []testMode{http1Mode}) +} +func testServerWriteHijackZeroBytes(t *testing.T, mode testMode) { done := make(chan struct{}) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { defer close(done) w.(Flusher).Flush() conn, _, err := w.(Hijacker).Hijack() @@ -2862,10 +2806,9 @@ func TestServerWriteHijackZeroBytes(t *testing.T) { if err != ErrHijacked { t.Errorf("Write error = %v; want ErrHijacked", err) } - })) - ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0) - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0) + }).ts c := ts.Client() res, err := c.Get(ts.URL) @@ -2880,19 +2823,23 @@ func TestServerWriteHijackZeroBytes(t *testing.T) { } } -func TestServerNoDate_h1(t *testing.T) { testServerNoHeader(t, h1Mode, "Date") } -func TestServerNoDate_h2(t *testing.T) { testServerNoHeader(t, h2Mode, "Date") } -func TestServerNoContentType_h1(t *testing.T) { testServerNoHeader(t, h1Mode, "Content-Type") } -func TestServerNoContentType_h2(t *testing.T) { testServerNoHeader(t, h2Mode, "Content-Type") } +func TestServerNoDate(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testServerNoHeader(t, mode, "Date") + }) +} -func testServerNoHeader(t *testing.T, h2 bool, header string) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServerContentType(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testServerNoHeader(t, mode, "Content-Type") + }) +} + +func testServerNoHeader(t *testing.T, mode testMode, header string) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header()[header] = nil io.WriteString(w, "foo") // non-empty })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -2903,15 +2850,13 @@ func testServerNoHeader(t *testing.T, h2 bool, header string) { } } -func TestStripPrefix(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestStripPrefix(t *testing.T) { run(t, testStripPrefix) } +func testStripPrefix(t *testing.T, mode testMode) { h := HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("X-Path", r.URL.Path) w.Header().Set("X-RawPath", r.URL.RawPath) }) - ts := httptest.NewServer(StripPrefix("/foo/bar", h)) - defer ts.Close() + ts := newClientServerTest(t, mode, StripPrefix("/foo/bar", h)).ts c := ts.Client() @@ -2961,15 +2906,11 @@ func TestStripPrefixNotModifyRequest(t *testing.T) { } } -func TestRequestLimit_h1(t *testing.T) { testRequestLimit(t, h1Mode) } -func TestRequestLimit_h2(t *testing.T) { testRequestLimit(t, h2Mode) } -func testRequestLimit(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestRequestLimit(t *testing.T) { run(t, testRequestLimit) } +func testRequestLimit(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { t.Fatalf("didn't expect to get request in Handler") }), optQuietLog) - defer cst.close() req, _ := NewRequest("GET", cst.ts.URL, nil) var bytesPerHeader = len("header12345: val12345\r\n") for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ { @@ -2979,7 +2920,7 @@ func testRequestLimit(t *testing.T, h2 bool) { if res != nil { defer res.Body.Close() } - if h2 { + if mode == http2Mode { // In HTTP/2, the result depends on a race. If the client has received the // server's SETTINGS before RoundTrip starts sending the request, then RoundTrip // will fail with an error. Otherwise, the client should receive a 431 from the @@ -3021,13 +2962,10 @@ func (cr countReader) Read(p []byte) (n int, err error) { return } -func TestRequestBodyLimit_h1(t *testing.T) { testRequestBodyLimit(t, h1Mode) } -func TestRequestBodyLimit_h2(t *testing.T) { testRequestBodyLimit(t, h2Mode) } -func testRequestBodyLimit(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestRequestBodyLimit(t *testing.T) { run(t, testRequestBodyLimit) } +func testRequestBodyLimit(t *testing.T, mode testMode) { const limit = 1 << 20 - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { r.Body = MaxBytesReader(w, r.Body, limit) n, err := io.Copy(io.Discard, r.Body) if err == nil { @@ -3044,7 +2982,6 @@ func testRequestBodyLimit(t *testing.T, h2 bool) { t.Errorf("MaxBytesError.Limit = %d, want %d", mbErr.Limit, limit) } })) - defer cst.close() nWritten := new(int64) req, _ := NewRequest("POST", cst.ts.URL, io.LimitReader(countReader{neverEnding('a'), nWritten}, limit*200)) @@ -3068,13 +3005,12 @@ func testRequestBodyLimit(t *testing.T, h2 bool) { // TestClientWriteShutdown tests that if the client shuts down the write // side of their TCP connection, the server doesn't send a 400 Bad Request. -func TestClientWriteShutdown(t *testing.T) { +func TestClientWriteShutdown(t *testing.T) { run(t, testClientWriteShutdown) } +func testClientWriteShutdown(t *testing.T, mode testMode) { if runtime.GOOS == "plan9" { t.Skip("skipping test; see https://golang.org/issue/17906") } - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) - defer ts.Close() + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("Dial: %v", err) @@ -3119,12 +3055,12 @@ func TestServerBufferedChunking(t *testing.T) { // closing the TCP connection, causing the client to get a RST. // See https://golang.org/issue/3595 func TestServerGracefulClose(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testServerGracefulClose, []testMode{http1Mode}) +} +func testServerGracefulClose(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { Error(w, "bye", StatusUnauthorized) - })) - defer ts.Close() + })).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -3162,11 +3098,9 @@ func TestServerGracefulClose(t *testing.T) { <-writeErr } -func TestCaseSensitiveMethod_h1(t *testing.T) { testCaseSensitiveMethod(t, h1Mode) } -func TestCaseSensitiveMethod_h2(t *testing.T) { testCaseSensitiveMethod(t, h2Mode) } -func testCaseSensitiveMethod(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestCaseSensitiveMethod(t *testing.T) { run(t, testCaseSensitiveMethod) } +func testCaseSensitiveMethod(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "get" { t.Errorf(`Got method %q; want "get"`, r.Method) } @@ -3187,8 +3121,10 @@ func testCaseSensitiveMethod(t *testing.T, h2 bool) { // response, the net/http package adds a "Content-Length: 0" response // header. func TestContentLengthZero(t *testing.T) { - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {})) - defer ts.Close() + run(t, testContentLengthZero, []testMode{http1Mode}) +} +func testContentLengthZero(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {})).ts for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} { conn, err := net.Dial("tcp", ts.Listener.Addr().String()) @@ -3215,15 +3151,17 @@ func TestContentLengthZero(t *testing.T) { } func TestCloseNotifier(t *testing.T) { - defer afterTest(t) + run(t, testCloseNotifier, []testMode{http1Mode}) +} +func testCloseNotifier(t *testing.T, mode testMode) { gotReq := make(chan bool, 1) sawClose := make(chan bool, 1) - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { gotReq <- true cc := rw.(CloseNotifier).CloseNotify() <-cc sawClose <- true - })) + })).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("error dialing: %v", err) @@ -3257,11 +3195,12 @@ For: // // Issue 13165 (where it used to deadlock), but behavior changed in Issue 23921. func TestCloseNotifierPipelined(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testCloseNotifierPipelined, []testMode{http1Mode}) +} +func testCloseNotifierPipelined(t *testing.T, mode testMode) { gotReq := make(chan bool, 2) sawClose := make(chan bool, 2) - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { gotReq <- true cc := rw.(CloseNotifier).CloseNotify() select { @@ -3270,8 +3209,7 @@ func TestCloseNotifierPipelined(t *testing.T) { case <-time.After(100 * time.Millisecond): } sawClose <- true - })) - defer ts.Close() + })).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("error dialing: %v", err) @@ -3341,12 +3279,14 @@ func TestCloseNotifierChanLeak(t *testing.T) { // Issue 9763. // HTTP/1-only test. (http2 doesn't have Hijack) func TestHijackAfterCloseNotifier(t *testing.T) { - defer afterTest(t) + run(t, testHijackAfterCloseNotifier, []testMode{http1Mode}) +} +func testHijackAfterCloseNotifier(t *testing.T, mode testMode) { script := make(chan string, 2) script <- "closenotify" script <- "hijack" close(script) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { plan := <-script switch plan { default: @@ -3369,13 +3309,12 @@ func TestHijackAfterCloseNotifier(t *testing.T) { c.Close() return } - })) - defer ts.Close() - res1, err := Get(ts.URL) + })).ts + res1, err := ts.Client().Get(ts.URL) if err != nil { log.Fatal(err) } - res2, err := Get(ts.URL) + res2, err := ts.Client().Get(ts.URL) if err != nil { log.Fatal(err) } @@ -3387,12 +3326,13 @@ func TestHijackAfterCloseNotifier(t *testing.T) { } func TestHijackBeforeRequestBodyRead(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testHijackBeforeRequestBodyRead, []testMode{http1Mode}) +} +func testHijackBeforeRequestBodyRead(t *testing.T, mode testMode) { var requestBody = bytes.Repeat([]byte("a"), 1<<20) bodyOkay := make(chan bool, 1) gotCloseNotify := make(chan bool, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { defer close(bodyOkay) // caller will read false if nothing else reqBody := r.Body @@ -3419,8 +3359,7 @@ func TestHijackBeforeRequestBodyRead(t *testing.T) { case <-time.After(5 * time.Second): gotCloseNotify <- false } - })) - defer ts.Close() + })).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -3440,14 +3379,14 @@ func TestHijackBeforeRequestBodyRead(t *testing.T) { } } -func TestOptions(t *testing.T) { +func TestOptions(t *testing.T) { run(t, testOptions, []testMode{http1Mode}) } +func testOptions(t *testing.T, mode testMode) { uric := make(chan string, 2) // only expect 1, but leave space for 2 mux := NewServeMux() mux.HandleFunc("/", func(w ResponseWriter, r *Request) { uric <- r.RequestURI }) - ts := httptest.NewServer(mux) - defer ts.Close() + ts := newClientServerTest(t, mode, mux).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -3492,15 +3431,15 @@ func TestOptions(t *testing.T) { } } -func TestOptionsHandler(t *testing.T) { +func TestOptionsHandler(t *testing.T) { run(t, testOptionsHandler, []testMode{http1Mode}) } +func testOptionsHandler(t *testing.T, mode testMode) { rc := make(chan *Request, 1) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { rc <- r - })) - ts.Config.DisableGeneralOptionsHandler = true - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.DisableGeneralOptionsHandler = true + }).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -3804,12 +3743,12 @@ func TestDoubleHijack(t *testing.T) { // optimization and is pointless if dealing with a // badly behaved client. func TestHTTP10ConnectionHeader(t *testing.T) { - defer afterTest(t) - + run(t, testHTTP10ConnectionHeader, []testMode{http1Mode}) +} +func testHTTP10ConnectionHeader(t *testing.T, mode testMode) { mux := NewServeMux() mux.Handle("/", HandlerFunc(func(ResponseWriter, *Request) {})) - ts := httptest.NewServer(mux) - defer ts.Close() + ts := newClientServerTest(t, mode, mux).ts // net/http uses HTTP/1.1 for requests, so write requests manually tests := []struct { @@ -3856,14 +3795,11 @@ func TestHTTP10ConnectionHeader(t *testing.T) { } // See golang.org/issue/5660 -func TestServerReaderFromOrder_h1(t *testing.T) { testServerReaderFromOrder(t, h1Mode) } -func TestServerReaderFromOrder_h2(t *testing.T) { testServerReaderFromOrder(t, h2Mode) } -func testServerReaderFromOrder(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestServerReaderFromOrder(t *testing.T) { run(t, testServerReaderFromOrder) } +func testServerReaderFromOrder(t *testing.T, mode testMode) { pr, pw := io.Pipe() const size = 3 << 20 - cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { rw.Header().Set("Content-Type", "text/plain") // prevent sniffing path done := make(chan bool) go func() { @@ -3883,7 +3819,6 @@ func testServerReaderFromOrder(t *testing.T, h2 bool) { pw.Close() <-done })) - defer cst.close() req, err := NewRequest("POST", cst.ts.URL, io.LimitReader(neverEnding('a'), size)) if err != nil { @@ -3957,16 +3892,10 @@ func TestContentTypeOkayOn204(t *testing.T) { // proxy). So then two people own that Request.Body (both the server // and the http client), and both think they can close it on failure. // Therefore, all incoming server requests Bodies need to be thread-safe. -func TestTransportAndServerSharedBodyRace_h1(t *testing.T) { - testTransportAndServerSharedBodyRace(t, h1Mode) +func TestTransportAndServerSharedBodyRace(t *testing.T) { + run(t, testTransportAndServerSharedBodyRace) } -func TestTransportAndServerSharedBodyRace_h2(t *testing.T) { - testTransportAndServerSharedBodyRace(t, h2Mode) -} -func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - +func testTransportAndServerSharedBodyRace(t *testing.T, mode testMode) { const bodySize = 1 << 20 // errorf is like t.Errorf, but also writes to println. When @@ -3980,7 +3909,7 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { } unblockBackend := make(chan bool) - backend := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) { + backend := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { gone := rw.(CloseNotifier).CloseNotify() didCopy := make(chan any) go func() { @@ -4007,7 +3936,7 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { backendRespc := make(chan *Response, 1) var proxy *clientServerTest - proxy = newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) { + proxy = newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { req2, _ := NewRequest("POST", backend.ts.URL, req.Body) req2.ContentLength = bodySize cancel := make(chan struct{}) @@ -4027,7 +3956,7 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { // Try to cause a race: Both the Transport and the proxy handler's Server // will try to read/close req.Body (aka req2.Body) - if h2 { + if mode == http2Mode { close(cancel) } else { proxy.c.Transport.(*Transport).CancelRequest(req2) @@ -4071,22 +4000,23 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { // cause the Handler goroutine's Request.Body.Close to block. // See issue 7121. func TestRequestBodyCloseDoesntBlock(t *testing.T) { + run(t, testRequestBodyCloseDoesntBlock, []testMode{http1Mode}) +} +func testRequestBodyCloseDoesntBlock(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in -short mode") } - defer afterTest(t) readErrCh := make(chan error, 1) errCh := make(chan error, 2) - server := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + server := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { go func(body io.Reader) { _, err := body.Read(make([]byte, 100)) readErrCh <- err }(req.Body) time.Sleep(500 * time.Millisecond) - })) - defer server.Close() + })).ts closeConn := make(chan bool) defer close(closeConn) @@ -4149,9 +4079,8 @@ func TestAppendTime(t *testing.T) { } } -func TestServerConnState(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestServerConnState(t *testing.T) { run(t, testServerConnState, []testMode{http1Mode}) } +func testServerConnState(t *testing.T, mode testMode) { handler := map[string]func(w ResponseWriter, r *Request){ "/": func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "Hello.") @@ -4217,37 +4146,36 @@ func TestServerConnState(t *testing.T) { // next call to wantLog. } - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { handler[r.URL.Path](w, r) - })) + }), func(ts *httptest.Server) { + ts.Config.ErrorLog = log.New(io.Discard, "", 0) + ts.Config.ConnState = func(c net.Conn, state ConnState) { + if c == nil { + t.Errorf("nil conn seen in state %s", state) + return + } + sl := <-activeLog + if sl.active == nil && state == StateNew { + sl.active = c + } else if sl.active != c { + t.Errorf("unexpected conn in state %s", state) + activeLog <- sl + return + } + sl.got = append(sl.got, state) + if sl.complete != nil && (len(sl.got) >= len(sl.want) || !reflect.DeepEqual(sl.got, sl.want[:len(sl.got)])) { + close(sl.complete) + sl.complete = nil + } + activeLog <- sl + } + }).ts defer func() { activeLog <- &stateLog{} // If the test failed, allow any remaining ConnState callbacks to complete. ts.Close() }() - ts.Config.ErrorLog = log.New(io.Discard, "", 0) - ts.Config.ConnState = func(c net.Conn, state ConnState) { - if c == nil { - t.Errorf("nil conn seen in state %s", state) - return - } - sl := <-activeLog - if sl.active == nil && state == StateNew { - sl.active = c - } else if sl.active != c { - t.Errorf("unexpected conn in state %s", state) - activeLog <- sl - return - } - sl.got = append(sl.got, state) - if sl.complete != nil && (len(sl.got) >= len(sl.want) || !reflect.DeepEqual(sl.got, sl.want[:len(sl.got)])) { - close(sl.complete) - sl.complete = nil - } - activeLog <- sl - } - - ts.Start() c := ts.Client() mustGet := func(url string, headers ...string) { @@ -4329,13 +4257,15 @@ func TestServerConnState(t *testing.T) { }, StateNew, StateActive, StateIdle, StateClosed) } -func TestServerKeepAlivesEnabled(t *testing.T) { - defer afterTest(t) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) - ts.Config.SetKeepAlivesEnabled(false) - ts.Start() - defer ts.Close() - res, err := Get(ts.URL) +func TestServerKeepAlivesEnabledResultClose(t *testing.T) { + run(t, testServerKeepAlivesEnabledResultClose, []testMode{http1Mode}) +} +func testServerKeepAlivesEnabledResultClose(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + }), func(ts *httptest.Server) { + ts.Config.SetKeepAlivesEnabled(false) + }).ts + res, err := ts.Client().Get(ts.URL) if err != nil { t.Fatal(err) } @@ -4346,16 +4276,12 @@ func TestServerKeepAlivesEnabled(t *testing.T) { } // golang.org/issue/7856 -func TestServerEmptyBodyRace_h1(t *testing.T) { testServerEmptyBodyRace(t, h1Mode) } -func TestServerEmptyBodyRace_h2(t *testing.T) { testServerEmptyBodyRace(t, h2Mode) } -func testServerEmptyBodyRace(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestServerEmptyBodyRace(t *testing.T) { run(t, testServerEmptyBodyRace) } +func testServerEmptyBodyRace(t *testing.T, mode testMode) { var n int32 - cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { atomic.AddInt32(&n, 1) }), optQuietLog) - defer cst.close() var wg sync.WaitGroup const reqs = 20 for i := 0; i < reqs; i++ { @@ -4436,9 +4362,9 @@ func TestCloseWrite(t *testing.T) { // fixed. // // So add an explicit test for this. -func TestServerFlushAndHijack(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServerFlushAndHijack(t *testing.T) { run(t, testServerFlushAndHijack, []testMode{http1Mode}) } +func testServerFlushAndHijack(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, "Hello, ") w.(Flusher).Flush() conn, buf, _ := w.(Hijacker).Hijack() @@ -4449,8 +4375,7 @@ func TestServerFlushAndHijack(t *testing.T) { if err := conn.Close(); err != nil { t.Error(err) } - })) - defer ts.Close() + })).ts res, err := Get(ts.URL) if err != nil { t.Fatal(err) @@ -4472,20 +4397,21 @@ func TestServerFlushAndHijack(t *testing.T) { // To test, verify we don't timeout or see fewer unique client // addresses (== unique connections) than requests. func TestServerKeepAliveAfterWriteError(t *testing.T) { + run(t, testServerKeepAliveAfterWriteError, []testMode{http1Mode}) +} +func testServerKeepAliveAfterWriteError(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in -short mode") } - defer afterTest(t) const numReq = 3 addrc := make(chan string, numReq) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { addrc <- r.RemoteAddr time.Sleep(500 * time.Millisecond) w.(Flusher).Flush() - })) - ts.Config.WriteTimeout = 250 * time.Millisecond - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.WriteTimeout = 250 * time.Millisecond + }).ts errc := make(chan error, numReq) go func() { @@ -4529,12 +4455,13 @@ func TestServerKeepAliveAfterWriteError(t *testing.T) { // Issue 9987: shouldn't add automatic Content-Length (or // Content-Type) if a Transfer-Encoding was set by the handler. func TestNoContentLengthIfTransferEncoding(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testNoContentLengthIfTransferEncoding, []testMode{http1Mode}) +} +func testNoContentLengthIfTransferEncoding(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Transfer-Encoding", "foo") io.WriteString(w, "") - })) - defer ts.Close() + })).ts c, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("Dial: %v", err) @@ -4682,15 +4609,12 @@ func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) { } } -func TestHandlerSetsBodyNil_h1(t *testing.T) { testHandlerSetsBodyNil(t, h1Mode) } -func TestHandlerSetsBodyNil_h2(t *testing.T) { testHandlerSetsBodyNil(t, h2Mode) } -func testHandlerSetsBodyNil(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestHandlerSetsBodyNil(t *testing.T) { run(t, testHandlerSetsBodyNil) } +func testHandlerSetsBodyNil(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { r.Body = nil fmt.Fprintf(w, "%v", r.RemoteAddr) })) - defer cst.close() get := func() string { res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -4780,9 +4704,11 @@ func TestServerValidatesHostHeader(t *testing.T) { } func TestServerHandlersCanHandleH2PRI(t *testing.T) { + run(t, testServerHandlersCanHandleH2PRI, []testMode{http1Mode}) +} +func testServerHandlersCanHandleH2PRI(t *testing.T, mode testMode) { const upgradeResponse = "upgrade here" - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, br, err := w.(Hijacker).Hijack() if err != nil { t.Error(err) @@ -4804,8 +4730,7 @@ func TestServerHandlersCanHandleH2PRI(t *testing.T) { return } io.WriteString(conn, upgradeResponse) - })) - defer ts.Close() + })).ts c, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -4872,17 +4797,12 @@ func TestServerValidatesHeaders(t *testing.T) { } } -func TestServerRequestContextCancel_ServeHTTPDone_h1(t *testing.T) { - testServerRequestContextCancel_ServeHTTPDone(t, h1Mode) +func TestServerRequestContextCancel_ServeHTTPDone(t *testing.T) { + run(t, testServerRequestContextCancel_ServeHTTPDone) } -func TestServerRequestContextCancel_ServeHTTPDone_h2(t *testing.T) { - testServerRequestContextCancel_ServeHTTPDone(t, h2Mode) -} -func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, mode testMode) { ctxc := make(chan context.Context, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { ctx := r.Context() select { case <-ctx.Done(): @@ -4891,7 +4811,6 @@ func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, h2 bool) { } ctxc <- ctx })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -4910,16 +4829,16 @@ func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, h2 bool) { // is always blocked in a Read call so it notices the EOF from the client. // See issues 15927 and 15224. func TestServerRequestContextCancel_ConnClose(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testServerRequestContextCancel_ConnClose, []testMode{http1Mode}) +} +func testServerRequestContextCancel_ConnClose(t *testing.T, mode testMode) { inHandler := make(chan struct{}) handlerDone := make(chan struct{}) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { close(inHandler) <-r.Context().Done() close(handlerDone) - })) - defer ts.Close() + })).ts c, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatal(err) @@ -4931,23 +4850,17 @@ func TestServerRequestContextCancel_ConnClose(t *testing.T) { <-handlerDone } -func TestServerContext_ServerContextKey_h1(t *testing.T) { - testServerContext_ServerContextKey(t, h1Mode) +func TestServerContext_ServerContextKey(t *testing.T) { + run(t, testServerContext_ServerContextKey) } -func TestServerContext_ServerContextKey_h2(t *testing.T) { - testServerContext_ServerContextKey(t, h2Mode) -} -func testServerContext_ServerContextKey(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func testServerContext_ServerContextKey(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { ctx := r.Context() got := ctx.Value(ServerContextKey) if _, ok := got.(*Server); !ok { t.Errorf("context value = %T; want *http.Server", got) } })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -4955,20 +4868,14 @@ func testServerContext_ServerContextKey(t *testing.T, h2 bool) { res.Body.Close() } -func TestServerContext_LocalAddrContextKey_h1(t *testing.T) { - testServerContext_LocalAddrContextKey(t, h1Mode) +func TestServerContext_LocalAddrContextKey(t *testing.T) { + run(t, testServerContext_LocalAddrContextKey) } -func TestServerContext_LocalAddrContextKey_h2(t *testing.T) { - testServerContext_LocalAddrContextKey(t, h2Mode) -} -func testServerContext_LocalAddrContextKey(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func testServerContext_LocalAddrContextKey(t *testing.T, mode testMode) { ch := make(chan any, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { ch <- r.Context().Value(LocalAddrContextKey) })) - defer cst.close() if _, err := cst.c.Head(cst.ts.URL); err != nil { t.Fatal(err) } @@ -5021,16 +4928,19 @@ func TestHandlerSetTransferEncodingGzip(t *testing.T) { } func BenchmarkClientServer(b *testing.B) { + run(b, benchmarkClientServer, []testMode{http1Mode, https1Mode, http2Mode}) +} +func benchmarkClientServer(b *testing.B, mode testMode) { b.ReportAllocs() b.StopTimer() - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) { fmt.Fprintf(rw, "Hello world.\n") - })) - defer ts.Close() + })).ts b.StartTimer() + c := ts.Client() for i := 0; i < b.N; i++ { - res, err := Get(ts.URL) + res, err := c.Get(ts.URL) if err != nil { b.Fatal("Get:", err) } @@ -5048,33 +4958,21 @@ func BenchmarkClientServer(b *testing.B) { b.StopTimer() } -func BenchmarkClientServerParallel4(b *testing.B) { - benchmarkClientServerParallel(b, 4, false) -} - -func BenchmarkClientServerParallel64(b *testing.B) { - benchmarkClientServerParallel(b, 64, false) -} - -func BenchmarkClientServerParallelTLS4(b *testing.B) { - benchmarkClientServerParallel(b, 4, true) -} - -func BenchmarkClientServerParallelTLS64(b *testing.B) { - benchmarkClientServerParallel(b, 64, true) -} - -func benchmarkClientServerParallel(b *testing.B, parallelism int, useTLS bool) { - b.ReportAllocs() - ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) { - fmt.Fprintf(rw, "Hello world.\n") - })) - if useTLS { - ts.StartTLS() - } else { - ts.Start() +func BenchmarkClientServerParallel(b *testing.B) { + for _, parallelism := range []int{4, 64} { + b.Run(fmt.Sprint(parallelism), func(b *testing.B) { + run(b, func(b *testing.B, mode testMode) { + benchmarkClientServerParallel(b, parallelism, mode) + }, []testMode{http1Mode, https1Mode, http2Mode}) + }) } - defer ts.Close() +} + +func benchmarkClientServerParallel(b *testing.B, parallelism int, mode testMode) { + b.ReportAllocs() + ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) { + fmt.Fprintf(rw, "Hello world.\n") + })).ts b.ResetTimer() b.SetParallelism(parallelism) b.RunParallel(func(pb *testing.PB) { @@ -5464,15 +5362,15 @@ Host: golang.org } } -func BenchmarkCloseNotifier(b *testing.B) { +func BenchmarkCloseNotifier(b *testing.B) { run(b, benchmarkCloseNotifier, []testMode{http1Mode}) } +func benchmarkCloseNotifier(b *testing.B, mode testMode) { b.ReportAllocs() b.StopTimer() sawClose := make(chan bool) - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { <-rw.(CloseNotifier).CloseNotify() sawClose <- true - })) - defer ts.Close() + })).ts tot := time.NewTimer(5 * time.Second) defer tot.Stop() b.StartTimer() @@ -5508,20 +5406,18 @@ func TestConcurrentServerServe(t *testing.T) { } } -func TestServerIdleTimeout(t *testing.T) { +func TestServerIdleTimeout(t *testing.T) { run(t, testServerIdleTimeout, []testMode{http1Mode}) } +func testServerIdleTimeout(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } - setParallel(t) - defer afterTest(t) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.Copy(io.Discard, r.Body) io.WriteString(w, r.RemoteAddr) - })) - ts.Config.ReadHeaderTimeout = 1 * time.Second - ts.Config.IdleTimeout = 2 * time.Second - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.ReadHeaderTimeout = 1 * time.Second + ts.Config.IdleTimeout = 2 * time.Second + }).ts c := ts.Client() get := func() string { @@ -5576,12 +5472,12 @@ func get(t *testing.T, c *Client, url string) string { // Tests that calls to Server.SetKeepAlivesEnabled(false) closes any // currently-open connections. func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testServerSetKeepAlivesEnabledClosesConns, []testMode{http1Mode}) +} +func testServerSetKeepAlivesEnabledClosesConns(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, r.RemoteAddr) - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -5620,16 +5516,8 @@ func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) { } } -func TestServerShutdown_h1(t *testing.T) { - testServerShutdown(t, h1Mode) -} -func TestServerShutdown_h2(t *testing.T) { - testServerShutdown(t, h2Mode) -} - -func testServerShutdown(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestServerShutdown(t *testing.T) { run(t, testServerShutdown) } +func testServerShutdown(t *testing.T, mode testMode) { var doShutdown func() // set later var doStateCount func() var shutdownRes = make(chan error, 1) @@ -5645,10 +5533,9 @@ func testServerShutdown(t *testing.T, h2 bool) { time.Sleep(20 * time.Millisecond) io.WriteString(w, r.RemoteAddr) }) - cst := newClientServerTest(t, h2, handler, func(srv *httptest.Server) { + cst := newClientServerTest(t, mode, handler, func(srv *httptest.Server) { srv.Config.RegisterOnShutdown(func() { gotOnShutdown <- struct{}{} }) }) - defer cst.close() doShutdown = func() { shutdownRes <- cst.ts.Config.Shutdown(context.Background()) @@ -5678,24 +5565,22 @@ func testServerShutdown(t *testing.T, h2 bool) { } } -func TestServerShutdownStateNew(t *testing.T) { +func TestServerShutdownStateNew(t *testing.T) { run(t, testServerShutdownStateNew) } +func testServerShutdownStateNew(t *testing.T, mode testMode) { if testing.Short() { t.Skip("test takes 5-6 seconds; skipping in short mode") } - setParallel(t) - defer afterTest(t) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { - // nothing. - })) var connAccepted sync.WaitGroup - ts.Config.ConnState = func(conn net.Conn, state ConnState) { - if state == StateNew { - connAccepted.Done() + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + // nothing. + }), func(ts *httptest.Server) { + ts.Config.ConnState = func(conn net.Conn, state ConnState) { + if state == StateNew { + connAccepted.Done() + } } - } - ts.Start() - defer ts.Close() + }).ts // Start a connection but never write to it. connAccepted.Add(1) @@ -5757,16 +5642,14 @@ func TestServerCloseDeadlock(t *testing.T) { // Issue 17717: tests that Server.SetKeepAlivesEnabled is respected by // both HTTP/1 and HTTP/2. -func TestServerKeepAlivesEnabled_h1(t *testing.T) { testServerKeepAlivesEnabled(t, h1Mode) } -func TestServerKeepAlivesEnabled_h2(t *testing.T) { testServerKeepAlivesEnabled(t, h2Mode) } -func testServerKeepAlivesEnabled(t *testing.T, h2 bool) { - if h2 { +func TestServerKeepAlivesEnabled(t *testing.T) { run(t, testServerKeepAlivesEnabled, testNotParallel) } +func testServerKeepAlivesEnabled(t *testing.T, mode testMode) { + if mode == http2Mode { restore := ExportSetH2GoawayTimeout(10 * time.Millisecond) defer restore() } // Not parallel: messes with global variable. (http2goAwayTimeout) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {})) + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})) defer cst.close() srv := cst.ts.Config srv.SetKeepAlivesEnabled(false) @@ -5803,9 +5686,8 @@ func testServerKeepAlivesEnabled(t *testing.T, h2 bool) { // Issue 18447: test that the Server's ReadTimeout is stopped while // the server's doing its 1-byte background read between requests, // waiting for the connection to maybe close. -func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { run(t, testServerCancelsReadTimeoutWhenIdle) } +func testServerCancelsReadTimeoutWhenIdle(t *testing.T, mode testMode) { runTimeSensitiveTest(t, []time.Duration{ 10 * time.Millisecond, 50 * time.Millisecond, @@ -5813,17 +5695,16 @@ func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { time.Second, 2 * time.Second, }, func(t *testing.T, timeout time.Duration) error { - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { select { case <-time.After(2 * timeout): fmt.Fprint(w, "ok") case <-r.Context().Done(): fmt.Fprint(w, r.Context().Err()) } - })) - ts.Config.ReadTimeout = timeout - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.ReadTimeout = timeout + }).ts c := ts.Client() @@ -5847,8 +5728,9 @@ func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { // beginning of a request has been received, rather than including time the // connection spent idle. func TestServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testServerCancelsReadHeaderTimeoutWhenIdle, []testMode{http1Mode}) +} +func testServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T, mode testMode) { runTimeSensitiveTest(t, []time.Duration{ 10 * time.Millisecond, 50 * time.Millisecond, @@ -5856,11 +5738,10 @@ func TestServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T) { time.Second, 2 * time.Second, }, func(t *testing.T, timeout time.Duration) error { - ts := httptest.NewUnstartedServer(serve(200)) - ts.Config.ReadHeaderTimeout = timeout - ts.Config.IdleTimeout = 0 // disable idle timeout - ts.Start() - defer ts.Close() + ts := newClientServerTest(t, mode, serve(200), func(ts *httptest.Server) { + ts.Config.ReadHeaderTimeout = timeout + ts.Config.IdleTimeout = 0 // disable idle timeout + }).ts // rather than using an http.Client, create a single connection, so that // we can ensure this connection is not closed. @@ -5912,13 +5793,13 @@ func runTimeSensitiveTest(t *testing.T, durations []time.Duration, test func(t * // Issue 18535: test that the Server doesn't try to do a background // read if it's already done one. func TestServerDuplicateBackgroundRead(t *testing.T) { + run(t, testServerDuplicateBackgroundRead, []testMode{http1Mode}) +} +func testServerDuplicateBackgroundRead(t *testing.T, mode testMode) { if runtime.GOOS == "netbsd" && runtime.GOARCH == "arm" { testenv.SkipFlaky(t, 24826) } - setParallel(t) - defer afterTest(t) - goroutines := 5 requests := 2000 if testing.Short() { @@ -5926,8 +5807,7 @@ func TestServerDuplicateBackgroundRead(t *testing.T) { requests = 100 } - hts := httptest.NewServer(HandlerFunc(NotFound)) - defer hts.Close() + hts := newClientServerTest(t, mode, HandlerFunc(NotFound)).ts reqBytes := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n") @@ -5970,14 +5850,15 @@ func TestServerDuplicateBackgroundRead(t *testing.T) { // bufio.Reader.Buffered(), without resorting to Reading it // (potentially blocking) to get at it. func TestServerHijackGetsBackgroundByte(t *testing.T) { + run(t, testServerHijackGetsBackgroundByte, []testMode{http1Mode}) +} +func testServerHijackGetsBackgroundByte(t *testing.T, mode testMode) { if runtime.GOOS == "plan9" { t.Skip("skipping test; see https://golang.org/issue/18657") } - setParallel(t) - defer afterTest(t) done := make(chan struct{}) inHandler := make(chan bool, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { defer close(done) // Tell the client to send more data after the GET request. @@ -6000,8 +5881,7 @@ func TestServerHijackGetsBackgroundByte(t *testing.T) { t.Error("context unexpectedly canceled") default: } - })) - defer ts.Close() + })).ts cn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -6030,14 +5910,15 @@ func TestServerHijackGetsBackgroundByte(t *testing.T) { // immediate 1MB of data to the server to fill up the server's 4KB // buffer. func TestServerHijackGetsBackgroundByte_big(t *testing.T) { + run(t, testServerHijackGetsBackgroundByte_big, []testMode{http1Mode}) +} +func testServerHijackGetsBackgroundByte_big(t *testing.T, mode testMode) { if runtime.GOOS == "plan9" { t.Skip("skipping test; see https://golang.org/issue/18657") } - setParallel(t) - defer afterTest(t) done := make(chan struct{}) const size = 8 << 10 - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { defer close(done) conn, buf, err := w.(Hijacker).Hijack() @@ -6061,8 +5942,7 @@ func TestServerHijackGetsBackgroundByte_big(t *testing.T) { } else if !allX { t.Errorf("read %q; want %d 'x'", slurp, size) } - })) - defer ts.Close() + })).ts cn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -6198,73 +6078,27 @@ func TestStripPortFromHost(t *testing.T) { } } -func TestServerContexts(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestServerContexts(t *testing.T) { run(t, testServerContexts) } +func testServerContexts(t *testing.T, mode testMode) { type baseKey struct{} type connKey struct{} ch := make(chan context.Context, 1) - ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) { ch <- r.Context() - })) - ts.Config.BaseContext = func(ln net.Listener) context.Context { - if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") { - t.Errorf("unexpected onceClose listener type %T", ln) + }), func(ts *httptest.Server) { + ts.Config.BaseContext = func(ln net.Listener) context.Context { + if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") { + t.Errorf("unexpected onceClose listener type %T", ln) + } + return context.WithValue(context.Background(), baseKey{}, "base") } - return context.WithValue(context.Background(), baseKey{}, "base") - } - ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { - if got, want := ctx.Value(baseKey{}), "base"; got != want { - t.Errorf("in ConnContext, base context key = %#v; want %q", got, want) + ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { + if got, want := ctx.Value(baseKey{}), "base"; got != want { + t.Errorf("in ConnContext, base context key = %#v; want %q", got, want) + } + return context.WithValue(ctx, connKey{}, "conn") } - return context.WithValue(ctx, connKey{}, "conn") - } - ts.Start() - defer ts.Close() - res, err := ts.Client().Get(ts.URL) - if err != nil { - t.Fatal(err) - } - res.Body.Close() - ctx := <-ch - if got, want := ctx.Value(baseKey{}), "base"; got != want { - t.Errorf("base context key = %#v; want %q", got, want) - } - if got, want := ctx.Value(connKey{}), "conn"; got != want { - t.Errorf("conn context key = %#v; want %q", got, want) - } -} - -func TestServerContextsHTTP2(t *testing.T) { - setParallel(t) - defer afterTest(t) - type baseKey struct{} - type connKey struct{} - ch := make(chan context.Context, 1) - ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) { - if r.ProtoMajor != 2 { - t.Errorf("unexpected HTTP/1.x request") - } - ch <- r.Context() - })) - ts.Config.BaseContext = func(ln net.Listener) context.Context { - if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") { - t.Errorf("unexpected onceClose listener type %T", ln) - } - return context.WithValue(context.Background(), baseKey{}, "base") - } - ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { - if got, want := ctx.Value(baseKey{}), "base"; got != want { - t.Errorf("in ConnContext, base context key = %#v; want %q", got, want) - } - return context.WithValue(ctx, connKey{}, "conn") - } - ts.TLS = &tls.Config{ - NextProtos: []string{"h2", "http/1.1"}, - } - ts.StartTLS() - defer ts.Close() - ts.Client().Transport.(*Transport).ForceAttemptHTTP2 = true + }).ts res, err := ts.Client().Get(ts.URL) if err != nil { t.Fatal(err) @@ -6281,20 +6115,20 @@ func TestServerContextsHTTP2(t *testing.T) { // Issue 35750: check ConnContext not modifying context for other connections func TestConnContextNotModifyingAllContexts(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testConnContextNotModifyingAllContexts) +} +func testConnContextNotModifyingAllContexts(t *testing.T, mode testMode) { type connKey struct{} - ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) { rw.Header().Set("Connection", "close") - })) - ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { - if got := ctx.Value(connKey{}); got != nil { - t.Errorf("in ConnContext, unexpected context key = %#v", got) + }), func(ts *httptest.Server) { + ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { + if got := ctx.Value(connKey{}); got != nil { + t.Errorf("in ConnContext, unexpected context key = %#v", got) + } + return context.WithValue(ctx, connKey{}, "conn") } - return context.WithValue(ctx, connKey{}, "conn") - } - ts.Start() - defer ts.Close() + }).ts var res *Response var err error @@ -6315,10 +6149,12 @@ func TestConnContextNotModifyingAllContexts(t *testing.T) { // Issue 30710: ensure that as per the spec, a server responds // with 501 Not Implemented for unsupported transfer-encodings. func TestUnsupportedTransferEncodingsReturn501(t *testing.T) { - cst := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testUnsupportedTransferEncodingsReturn501, []testMode{http1Mode}) +} +func testUnsupportedTransferEncodingsReturn501(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("Hello, World!")) - })) - defer cst.Close() + })).ts serverURL, err := url.Parse(cst.URL) if err != nil { @@ -6353,19 +6189,9 @@ func TestUnsupportedTransferEncodingsReturn501(t *testing.T) { } } -func TestContentEncodingNoSniffing_h1(t *testing.T) { - testContentEncodingNoSniffing(t, h1Mode) -} - -func TestContentEncodingNoSniffing_h2(t *testing.T) { - testContentEncodingNoSniffing(t, h2Mode) -} - // Issue 31753: don't sniff when Content-Encoding is set -func testContentEncodingNoSniffing(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - +func TestContentEncodingNoSniffing(t *testing.T) { run(t, testContentEncodingNoSniffing) } +func testContentEncodingNoSniffing(t *testing.T, mode testMode) { type setting struct { name string body []byte @@ -6428,13 +6254,12 @@ func testContentEncodingNoSniffing(t *testing.T, h2 bool) { for _, tt := range settings { t.Run(tt.name, func(t *testing.T) { - cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) { if tt.contentEncoding != nil { rw.Header().Set("Content-Encoding", tt.contentEncoding.(string)) } rw.Write(tt.body) })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -6460,13 +6285,13 @@ func testContentEncodingNoSniffing(t *testing.T, h2 bool) { // Issue 30803: ensure that TimeoutHandler logs spurious // WriteHeader calls, for consistency with other Handlers. func TestTimeoutHandlerSuperfluousLogs(t *testing.T) { + run(t, testTimeoutHandlerSuperfluousLogs, []testMode{http1Mode}) +} +func testTimeoutHandlerSuperfluousLogs(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } - setParallel(t) - defer afterTest(t) - pc, curFile, _, _ := runtime.Caller(0) curFileBaseName := filepath.Base(curFile) testFuncName := runtime.FuncForPC(pc).Name() @@ -6520,7 +6345,7 @@ func TestTimeoutHandlerSuperfluousLogs(t *testing.T) { dur = 10 * time.Second } th := TimeoutHandler(sh, dur, timeoutMsg) - cst := newClientServerTest(t, h1Mode /* the test is protocol-agnostic */, th, optWithServerLog(srvLog)) + cst := newClientServerTest(t, mode, th, optWithServerLog(srvLog)) defer cst.close() res, err := cst.c.Get(cst.ts.URL) @@ -6590,15 +6415,16 @@ func BenchmarkResponseStatusLine(b *testing.B) { } }) } + func TestDisableKeepAliveUpgrade(t *testing.T) { + run(t, testDisableKeepAliveUpgrade, []testMode{http1Mode}) +} +func testDisableKeepAliveUpgrade(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } - setParallel(t) - defer afterTest(t) - - s := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + s := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "someProto") w.WriteHeader(StatusSwitchingProtocols) @@ -6611,10 +6437,9 @@ func TestDisableKeepAliveUpgrade(t *testing.T) { // Copy from the *bufio.ReadWriter, which may contain buffered data. // Copy to the net.Conn, to avoid buffering the output. io.Copy(c, buf) - })) - s.Config.SetKeepAlivesEnabled(false) - s.Start() - defer s.Close() + }), func(ts *httptest.Server) { + ts.Config.SetKeepAlivesEnabled(false) + }).ts cl := s.Client() cl.Transport.(*Transport).DisableKeepAlives = true @@ -6683,21 +6508,21 @@ func TestQuerySemicolon(t *testing.T) { {"?a=1;x=good;x=bad", "", "good", true}, } - for _, tt := range tests { - t.Run(tt.query+"/allow=false", func(t *testing.T) { - allowSemicolons := false - testQuerySemicolon(t, tt.query, tt.xNoSemicolons, allowSemicolons, tt.warning) - }) - t.Run(tt.query+"/allow=true", func(t *testing.T) { - allowSemicolons, expectWarning := true, false - testQuerySemicolon(t, tt.query, tt.xWithSemicolons, allowSemicolons, expectWarning) - }) - } + run(t, func(t *testing.T, mode testMode) { + for _, tt := range tests { + t.Run(tt.query+"/allow=false", func(t *testing.T) { + allowSemicolons := false + testQuerySemicolon(t, mode, tt.query, tt.xNoSemicolons, allowSemicolons, tt.warning) + }) + t.Run(tt.query+"/allow=true", func(t *testing.T) { + allowSemicolons, expectWarning := true, false + testQuerySemicolon(t, mode, tt.query, tt.xWithSemicolons, allowSemicolons, expectWarning) + }) + } + }) } -func testQuerySemicolon(t *testing.T, query string, wantX string, allowSemicolons, expectWarning bool) { - setParallel(t) - +func testQuerySemicolon(t *testing.T, mode testMode, query string, wantX string, allowSemicolons, expectWarning bool) { writeBackX := func(w ResponseWriter, r *Request) { x := r.URL.Query().Get("x") if expectWarning { @@ -6720,11 +6545,10 @@ func testQuerySemicolon(t *testing.T, query string, wantX string, allowSemicolon h = AllowQuerySemicolons(h) } - ts := httptest.NewUnstartedServer(h) logBuf := &strings.Builder{} - ts.Config.ErrorLog = log.New(logBuf, "", 0) - ts.Start() - defer ts.Close() + ts := newClientServerTest(t, mode, h, func(ts *httptest.Server) { + ts.Config.ErrorLog = log.New(logBuf, "", 0) + }).ts req, _ := NewRequest("GET", ts.URL+query, nil) res, err := ts.Client().Do(req) @@ -6759,13 +6583,15 @@ func TestMaxBytesHandler(t *testing.T) { for _, requestSize := range []int64{100, 1_000, 1_000_000} { t.Run(fmt.Sprintf("max size %d request size %d", maxSize, requestSize), func(t *testing.T) { - testMaxBytesHandler(t, maxSize, requestSize) + run(t, func(t *testing.T, mode testMode) { + testMaxBytesHandler(t, mode, maxSize, requestSize) + }) }) } } } -func testMaxBytesHandler(t *testing.T, maxSize, requestSize int64) { +func testMaxBytesHandler(t *testing.T, mode testMode, maxSize, requestSize int64) { var ( handlerN int64 handlerErr error @@ -6776,7 +6602,7 @@ func testMaxBytesHandler(t *testing.T, maxSize, requestSize int64) { io.Copy(w, &buf) }) - ts := httptest.NewServer(MaxBytesHandler(echo, maxSize)) + ts := newClientServerTest(t, mode, MaxBytesHandler(echo, maxSize)).ts defer ts.Close() c := ts.Client() @@ -6843,13 +6669,12 @@ func TestProcessing(t *testing.T) { } } -func TestParseFormCleanup_h1(t *testing.T) { testParseFormCleanup(t, h1Mode) } -func TestParseFormCleanup_h2(t *testing.T) { - t.Skip("https://go.dev/issue/20253") - testParseFormCleanup(t, h2Mode) -} +func TestParseFormCleanup(t *testing.T) { run(t, testParseFormCleanup) } +func testParseFormCleanup(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("https://go.dev/issue/20253") + } -func testParseFormCleanup(t *testing.T, h2 bool) { const maxMemory = 1024 const key = "file" @@ -6858,9 +6683,7 @@ func testParseFormCleanup(t *testing.T, h2 bool) { t.Skip("https://go.dev/issue/25965") } - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { r.ParseMultipartForm(maxMemory) f, _, err := r.FormFile(key) if err != nil { @@ -6874,7 +6697,6 @@ func testParseFormCleanup(t *testing.T, h2 bool) { } w.Write([]byte(of.Name())) })) - defer cst.close() fBuf := new(bytes.Buffer) mw := multipart.NewWriter(fBuf) @@ -6911,33 +6733,23 @@ func testParseFormCleanup(t *testing.T, h2 bool) { func TestHeadBody(t *testing.T) { const identityMode = false const chunkedMode = true - t.Run("h1", func(t *testing.T) { - t.Run("identity", func(t *testing.T) { testHeadBody(t, h1Mode, identityMode, "HEAD") }) - t.Run("chunked", func(t *testing.T) { testHeadBody(t, h1Mode, chunkedMode, "HEAD") }) - }) - t.Run("h2", func(t *testing.T) { - t.Run("identity", func(t *testing.T) { testHeadBody(t, h2Mode, identityMode, "HEAD") }) - t.Run("chunked", func(t *testing.T) { testHeadBody(t, h2Mode, chunkedMode, "HEAD") }) + run(t, func(t *testing.T, mode testMode) { + t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "HEAD") }) + t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "HEAD") }) }) } func TestGetBody(t *testing.T) { const identityMode = false const chunkedMode = true - t.Run("h1", func(t *testing.T) { - t.Run("identity", func(t *testing.T) { testHeadBody(t, h1Mode, identityMode, "GET") }) - t.Run("chunked", func(t *testing.T) { testHeadBody(t, h1Mode, chunkedMode, "GET") }) - }) - t.Run("h2", func(t *testing.T) { - t.Run("identity", func(t *testing.T) { testHeadBody(t, h2Mode, identityMode, "GET") }) - t.Run("chunked", func(t *testing.T) { testHeadBody(t, h2Mode, chunkedMode, "GET") }) + run(t, func(t *testing.T, mode testMode) { + t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "GET") }) + t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "GET") }) }) } -func testHeadBody(t *testing.T, h2, chunked bool, method string) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func testHeadBody(t *testing.T, mode testMode, chunked bool, method string) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { b, err := io.ReadAll(r.Body) if err != nil { t.Errorf("server reading body: %v", err) diff --git a/src/net/http/sniff_test.go b/src/net/http/sniff_test.go index e91335729a..d6ef40905e 100644 --- a/src/net/http/sniff_test.go +++ b/src/net/http/sniff_test.go @@ -88,13 +88,9 @@ func TestDetectContentType(t *testing.T) { } } -func TestServerContentType_h1(t *testing.T) { testServerContentType(t, h1Mode) } -func TestServerContentType_h2(t *testing.T) { testServerContentType(t, h2Mode) } - -func testServerContentType(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServerContentTypeSniff(t *testing.T) { run(t, testServerContentTypeSniff) } +func testServerContentTypeSniff(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { i, _ := strconv.Atoi(r.FormValue("i")) tt := sniffTests[i] n, err := w.Write(tt.data) @@ -134,15 +130,12 @@ func testServerContentType(t *testing.T, h2 bool) { // Issue 5953: shouldn't sniff if the handler set a Content-Type header, // even if it's the empty string. -func TestServerIssue5953_h1(t *testing.T) { testServerIssue5953(t, h1Mode) } -func TestServerIssue5953_h2(t *testing.T) { testServerIssue5953(t, h2Mode) } -func testServerIssue5953(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServerIssue5953(t *testing.T) { run(t, testServerIssue5953) } +func testServerIssue5953(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header()["Content-Type"] = []string{""} fmt.Fprintf(w, "hi") })) - defer cst.close() resp, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -173,11 +166,8 @@ func (b *byteAtATimeReader) Read(p []byte) (n int, err error) { return 1, nil } -func TestContentTypeWithVariousSources_h1(t *testing.T) { testContentTypeWithVariousSources(t, h1Mode) } -func TestContentTypeWithVariousSources_h2(t *testing.T) { testContentTypeWithVariousSources(t, h2Mode) } -func testContentTypeWithVariousSources(t *testing.T, h2 bool) { - defer afterTest(t) - +func TestContentTypeWithVariousSources(t *testing.T) { run(t, testContentTypeWithVariousSources) } +func testContentTypeWithVariousSources(t *testing.T, mode testMode) { const ( input = "\n\n\t\n" expected = "text/html; charset=utf-8" @@ -239,8 +229,7 @@ func testContentTypeWithVariousSources(t *testing.T, h2 bool) { }, }} { t.Run(test.name, func(t *testing.T) { - cst := newClientServerTest(t, h2, HandlerFunc(test.handler)) - defer cst.close() + cst := newClientServerTest(t, mode, HandlerFunc(test.handler)) resp, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -265,12 +254,9 @@ func testContentTypeWithVariousSources(t *testing.T, h2 bool) { } } -func TestSniffWriteSize_h1(t *testing.T) { testSniffWriteSize(t, h1Mode) } -func TestSniffWriteSize_h2(t *testing.T) { testSniffWriteSize(t, h2Mode) } -func testSniffWriteSize(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestSniffWriteSize(t *testing.T) { run(t, testSniffWriteSize) } +func testSniffWriteSize(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { size, _ := strconv.Atoi(r.FormValue("size")) written, err := io.WriteString(w, strings.Repeat("a", size)) if err != nil { @@ -281,7 +267,6 @@ func testSniffWriteSize(t *testing.T, h2 bool) { t.Errorf("write of %d bytes wrote %d bytes", size, written) } })) - defer cst.close() for _, size := range []int{0, 1, 200, 600, 999, 1000, 1023, 1024, 512 << 10, 1 << 20} { res, err := cst.c.Get(fmt.Sprintf("%s/?size=%d", cst.ts.URL, size)) if err != nil { diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index 26293befb4..8748cf6f7b 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -135,12 +135,11 @@ func (tcs *testConnSet) check(t *testing.T) { } } -func TestReuseRequest(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestReuseRequest(t *testing.T) { run(t, testReuseRequest) } +func testReuseRequest(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("{}")) - })) - defer ts.Close() + })).ts c := ts.Client() req, _ := NewRequest("GET", ts.URL, nil) @@ -165,10 +164,9 @@ func TestReuseRequest(t *testing.T) { // Two subsequent requests and verify their response is the same. // The response from the server is our own IP:port -func TestTransportKeepAlives(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() +func TestTransportKeepAlives(t *testing.T) { run(t, testTransportKeepAlives, []testMode{http1Mode}) } +func testTransportKeepAlives(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, hostPortHandler).ts c := ts.Client() for _, disableKeepAlive := range []bool{false, true} { @@ -197,9 +195,10 @@ func TestTransportKeepAlives(t *testing.T) { } func TestTransportConnectionCloseOnResponse(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() + run(t, testTransportConnectionCloseOnResponse) +} +func testTransportConnectionCloseOnResponse(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, hostPortHandler).ts connSet, testDial := makeTestDial(t) @@ -253,9 +252,10 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) { // describes the source source connection it got (remote port number + // address of its net.Conn). func TestTransportConnectionCloseOnRequest(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() + run(t, testTransportConnectionCloseOnRequest, []testMode{http1Mode}) +} +func testTransportConnectionCloseOnRequest(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, hostPortHandler).ts connSet, testDial := makeTestDial(t) @@ -317,9 +317,10 @@ func TestTransportConnectionCloseOnRequest(t *testing.T) { // send Connection: close. // HTTP/1-only (Connection: close doesn't exist in h2) func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() + run(t, testTransportConnectionCloseOnRequestDisableKeepAlive, []testMode{http1Mode}) +} +func testTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, hostPortHandler).ts c := ts.Client() c.Transport.(*Transport).DisableKeepAlives = true @@ -337,6 +338,9 @@ func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) { // Test that Transport only sends one "Connection: close", regardless of // how "close" was indicated. func TestTransportRespectRequestWantsClose(t *testing.T) { + run(t, testTransportRespectRequestWantsClose, []testMode{http1Mode}) +} +func testTransportRespectRequestWantsClose(t *testing.T, mode testMode) { tests := []struct { disableKeepAlives bool close bool @@ -350,9 +354,7 @@ func TestTransportRespectRequestWantsClose(t *testing.T) { for _, tc := range tests { t.Run(fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", tc.disableKeepAlives, tc.close), func(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() + ts := newClientServerTest(t, mode, hostPortHandler).ts c := ts.Client() c.Transport.(*Transport).DisableKeepAlives = tc.disableKeepAlives @@ -387,9 +389,10 @@ func TestTransportRespectRequestWantsClose(t *testing.T) { } func TestTransportIdleCacheKeys(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() + run(t, testTransportIdleCacheKeys, []testMode{http1Mode}) +} +func testTransportIdleCacheKeys(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, hostPortHandler).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -420,12 +423,12 @@ func TestTransportIdleCacheKeys(t *testing.T) { // Tests that the HTTP transport re-uses connections when a client // reads to the end of a response Body without closing it. -func TestTransportReadToEndReusesConn(t *testing.T) { - defer afterTest(t) +func TestTransportReadToEndReusesConn(t *testing.T) { run(t, testTransportReadToEndReusesConn) } +func testTransportReadToEndReusesConn(t *testing.T, mode testMode) { const msg = "foobar" var addrSeen map[string]int - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { addrSeen[r.RemoteAddr]++ if r.URL.Path == "/chunked/" { w.WriteHeader(200) @@ -435,16 +438,13 @@ func TestTransportReadToEndReusesConn(t *testing.T) { w.WriteHeader(200) } w.Write([]byte(msg)) - })) - defer ts.Close() - - buf := make([]byte, len(msg)) + })).ts for pi, path := range []string{"/content-length/", "/chunked/"} { wantLen := []int{len(msg), -1}[pi] addrSeen = make(map[string]int) for i := 0; i < 3; i++ { - res, err := Get(ts.URL + path) + res, err := ts.Client().Get(ts.URL + path) if err != nil { t.Errorf("Get %s: %v", path, err) continue @@ -459,9 +459,9 @@ func TestTransportReadToEndReusesConn(t *testing.T) { if res.ContentLength != int64(wantLen) { t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen) } - n, err := res.Body.Read(buf) - if n != len(msg) || err != io.EOF { - t.Errorf("%s Read = %v, %v; want %d, EOF", path, n, err, len(msg)) + got, err := io.ReadAll(res.Body) + if string(got) != msg || err != nil { + t.Errorf("%s ReadAll(Body) = %q, %v; want %q, nil", path, string(got), err, msg) } } if len(addrSeen) != 1 { @@ -471,13 +471,15 @@ func TestTransportReadToEndReusesConn(t *testing.T) { } func TestTransportMaxPerHostIdleConns(t *testing.T) { - defer afterTest(t) + run(t, testTransportMaxPerHostIdleConns, []testMode{http1Mode}) +} +func testTransportMaxPerHostIdleConns(t *testing.T, mode testMode) { stop := make(chan struct{}) // stop marks the exit of main Test goroutine defer close(stop) resch := make(chan string) gotReq := make(chan bool) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { gotReq <- true var msg string select { @@ -490,8 +492,7 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { t.Errorf("Write: %v", err) return } - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -559,14 +560,15 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { } func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportMaxConnsPerHostIncludeDialInProgress) +} +func testTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { _, err := w.Write([]byte("foo")) if err != nil { t.Fatalf("Write: %v", err) } - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) dialStarted := make(chan struct{}) @@ -626,7 +628,9 @@ func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) { } func TestTransportMaxConnsPerHost(t *testing.T) { - defer afterTest(t) + run(t, testTransportMaxConnsPerHost, []testMode{http1Mode, https1Mode, http2Mode}) +} +func testTransportMaxConnsPerHost(t *testing.T, mode testMode) { CondSkipHTTP2(t) h := HandlerFunc(func(w ResponseWriter, r *Request) { @@ -636,115 +640,101 @@ func TestTransportMaxConnsPerHost(t *testing.T) { } }) - testMaxConns := func(scheme string, ts *httptest.Server) { - defer ts.Close() - - c := ts.Client() - tr := c.Transport.(*Transport) - tr.MaxConnsPerHost = 1 - if err := ExportHttp2ConfigureTransport(tr); err != nil { - t.Fatalf("ExportHttp2ConfigureTransport: %v", err) - } - - mu := sync.Mutex{} - var conns []net.Conn - var dialCnt, gotConnCnt, tlsHandshakeCnt int32 - tr.Dial = func(network, addr string) (net.Conn, error) { - atomic.AddInt32(&dialCnt, 1) - c, err := net.Dial(network, addr) - mu.Lock() - defer mu.Unlock() - conns = append(conns, c) - return c, err - } - - doReq := func() { - trace := &httptrace.ClientTrace{ - GotConn: func(connInfo httptrace.GotConnInfo) { - if !connInfo.Reused { - atomic.AddInt32(&gotConnCnt, 1) - } - }, - TLSHandshakeStart: func() { - atomic.AddInt32(&tlsHandshakeCnt, 1) - }, - } - req, _ := NewRequest("GET", ts.URL, nil) - req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) - - resp, err := c.Do(req) - if err != nil { - t.Fatalf("request failed: %v", err) - } - defer resp.Body.Close() - _, err = io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read body failed: %v", err) - } - } - - wg := sync.WaitGroup{} - for i := 0; i < 10; i++ { - wg.Add(1) - go func() { - defer wg.Done() - doReq() - }() - } - wg.Wait() - - expected := int32(tr.MaxConnsPerHost) - if dialCnt != expected { - t.Errorf("round 1: too many dials (%s): %d != %d", scheme, dialCnt, expected) - } - if gotConnCnt != expected { - t.Errorf("round 1: too many get connections (%s): %d != %d", scheme, gotConnCnt, expected) - } - if ts.TLS != nil && tlsHandshakeCnt != expected { - t.Errorf("round 1: too many tls handshakes (%s): %d != %d", scheme, tlsHandshakeCnt, expected) - } - - if t.Failed() { - t.FailNow() - } + ts := newClientServerTest(t, mode, h).ts + c := ts.Client() + tr := c.Transport.(*Transport) + tr.MaxConnsPerHost = 1 + mu := sync.Mutex{} + var conns []net.Conn + var dialCnt, gotConnCnt, tlsHandshakeCnt int32 + tr.Dial = func(network, addr string) (net.Conn, error) { + atomic.AddInt32(&dialCnt, 1) + c, err := net.Dial(network, addr) mu.Lock() - for _, c := range conns { - c.Close() - } - conns = nil - mu.Unlock() - tr.CloseIdleConnections() + defer mu.Unlock() + conns = append(conns, c) + return c, err + } - doReq() - expected++ - if dialCnt != expected { - t.Errorf("round 2: too many dials (%s): %d", scheme, dialCnt) + doReq := func() { + trace := &httptrace.ClientTrace{ + GotConn: func(connInfo httptrace.GotConnInfo) { + if !connInfo.Reused { + atomic.AddInt32(&gotConnCnt, 1) + } + }, + TLSHandshakeStart: func() { + atomic.AddInt32(&tlsHandshakeCnt, 1) + }, } - if gotConnCnt != expected { - t.Errorf("round 2: too many get connections (%s): %d != %d", scheme, gotConnCnt, expected) + req, _ := NewRequest("GET", ts.URL, nil) + req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) + + resp, err := c.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) } - if ts.TLS != nil && tlsHandshakeCnt != expected { - t.Errorf("round 2: too many tls handshakes (%s): %d != %d", scheme, tlsHandshakeCnt, expected) + defer resp.Body.Close() + _, err = io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body failed: %v", err) } } - testMaxConns("http", httptest.NewServer(h)) - testMaxConns("https", httptest.NewTLSServer(h)) + wg := sync.WaitGroup{} + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + doReq() + }() + } + wg.Wait() - ts := httptest.NewUnstartedServer(h) - ts.TLS = &tls.Config{NextProtos: []string{"h2"}} - ts.StartTLS() - testMaxConns("http2", ts) + expected := int32(tr.MaxConnsPerHost) + if dialCnt != expected { + t.Errorf("round 1: too many dials: %d != %d", dialCnt, expected) + } + if gotConnCnt != expected { + t.Errorf("round 1: too many get connections: %d != %d", gotConnCnt, expected) + } + if ts.TLS != nil && tlsHandshakeCnt != expected { + t.Errorf("round 1: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected) + } + + if t.Failed() { + t.FailNow() + } + + mu.Lock() + for _, c := range conns { + c.Close() + } + conns = nil + mu.Unlock() + tr.CloseIdleConnections() + + doReq() + expected++ + if dialCnt != expected { + t.Errorf("round 2: too many dials: %d", dialCnt) + } + if gotConnCnt != expected { + t.Errorf("round 2: too many get connections: %d != %d", gotConnCnt, expected) + } + if ts.TLS != nil && tlsHandshakeCnt != expected { + t.Errorf("round 2: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected) + } } func TestTransportRemovesDeadIdleConnections(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportRemovesDeadIdleConnections, []testMode{http1Mode}) +} +func testTransportRemovesDeadIdleConnections(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, r.RemoteAddr) - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -789,10 +779,10 @@ func TestTransportRemovesDeadIdleConnections(t *testing.T) { // Test that the Transport notices when a server hangs up on its // unexpectedly (a keep-alive connection is closed). func TestTransportServerClosingUnexpectedly(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() + run(t, testTransportServerClosingUnexpectedly, []testMode{http1Mode}) +} +func testTransportServerClosingUnexpectedly(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, hostPortHandler).ts c := ts.Client() fetch := func(n, retries int) string { @@ -846,11 +836,13 @@ func TestTransportServerClosingUnexpectedly(t *testing.T) { // Test for https://golang.org/issue/2616 (appropriate issue number) // This fails pretty reliably with GOMAXPROCS=100 or something high. func TestStressSurpriseServerCloses(t *testing.T) { - defer afterTest(t) + run(t, testStressSurpriseServerCloses, []testMode{http1Mode}) +} +func testStressSurpriseServerCloses(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping test in short mode") } - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Length", "5") w.Header().Set("Content-Type", "text/plain") w.Write([]byte("Hello")) @@ -858,8 +850,7 @@ func TestStressSurpriseServerCloses(t *testing.T) { conn, buf, _ := w.(Hijacker).Hijack() buf.Flush() conn.Close() - })) - defer ts.Close() + })).ts c := ts.Client() // Do a bunch of traffic from different goroutines. Send to activityc @@ -906,16 +897,15 @@ func TestStressSurpriseServerCloses(t *testing.T) { // TestTransportHeadResponses verifies that we deal with Content-Lengths // with no bodies properly -func TestTransportHeadResponses(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportHeadResponses(t *testing.T) { run(t, testTransportHeadResponses) } +func testTransportHeadResponses(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "HEAD" { panic("expected HEAD; got " + r.Method) } w.Header().Set("Content-Length", "123") w.WriteHeader(200) - })) - defer ts.Close() + })).ts c := ts.Client() for i := 0; i < 2; i++ { @@ -941,16 +931,17 @@ func TestTransportHeadResponses(t *testing.T) { // TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding // on responses to HEAD requests. func TestTransportHeadChunkedResponse(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportHeadChunkedResponse, []testMode{http1Mode}, testNotParallel) +} +func testTransportHeadChunkedResponse(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "HEAD" { panic("expected HEAD; got " + r.Method) } w.Header().Set("Transfer-Encoding", "chunked") // client should ignore w.Header().Set("x-client-ipport", r.RemoteAddr) w.WriteHeader(200) - })) - defer ts.Close() + })).ts c := ts.Client() // Ensure that we wait for the readLoop to complete before @@ -991,11 +982,10 @@ var roundTripTests = []struct { } // Test that the modification made to the Request by the RoundTripper is cleaned up -func TestRoundTripGzip(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestRoundTripGzip(t *testing.T) { run(t, testRoundTripGzip) } +func testRoundTripGzip(t *testing.T, mode testMode) { const responseBody = "test response body" - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { accept := req.Header.Get("Accept-Encoding") if expect := req.FormValue("expect_accept"); accept != expect { t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q", @@ -1010,8 +1000,7 @@ func TestRoundTripGzip(t *testing.T) { rw.Header().Set("Content-Encoding", accept) rw.Write([]byte(responseBody)) } - })) - defer ts.Close() + })).ts tr := ts.Client().Transport.(*Transport) for i, test := range roundTripTests { @@ -1055,12 +1044,14 @@ func TestRoundTripGzip(t *testing.T) { } -func TestTransportGzip(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestTransportGzip(t *testing.T) { run(t, testTransportGzip) } +func testTransportGzip(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("https://go.dev/issue/56020") + } const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" const nRandBytes = 1024 * 1024 - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { if req.Method == "HEAD" { if g := req.Header.Get("Accept-Encoding"); g != "" { t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g) @@ -1087,8 +1078,7 @@ func TestTransportGzip(t *testing.T) { io.CopyN(gz, rand.Reader, nRandBytes) } gz.Close() - })) - defer ts.Close() + })).ts c := ts.Client() for _, chunked := range []string{"1", "0"} { @@ -1153,10 +1143,10 @@ func TestTransportGzip(t *testing.T) { // If a request has Expect:100-continue header, the request blocks sending body until the first response. // Premature consumption of the request body should not be occurred. func TestTransportExpect100Continue(t *testing.T) { - setParallel(t) - defer afterTest(t) - - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + run(t, testTransportExpect100Continue, []testMode{http1Mode}) +} +func testTransportExpect100Continue(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { switch req.URL.Path { case "/100": // This endpoint implicitly responds 100 Continue and reads body. @@ -1194,8 +1184,7 @@ func TestTransportExpect100Continue(t *testing.T) { conn.Close() } - })) - defer ts.Close() + })).ts tests := []struct { path string @@ -1242,7 +1231,9 @@ func TestTransportExpect100Continue(t *testing.T) { } func TestSOCKS5Proxy(t *testing.T) { - defer afterTest(t) + run(t, testSOCKS5Proxy, []testMode{http1Mode, https1Mode, http2Mode}) +} +func testSOCKS5Proxy(t *testing.T, mode testMode) { ch := make(chan string, 1) l := newLocalListener(t) defer l.Close() @@ -1322,12 +1313,7 @@ func TestSOCKS5Proxy(t *testing.T) { }) for _, useTLS := range []bool{false, true} { t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) { - var ts *httptest.Server - if useTLS { - ts = httptest.NewTLSServer(h) - } else { - ts = httptest.NewServer(h) - } + ts := newClientServerTest(t, mode, h).ts go proxy(t) c := ts.Client() c.Transport.(*Transport).Proxy = ProxyURL(pu) @@ -1359,16 +1345,16 @@ func TestSOCKS5Proxy(t *testing.T) { func TestTransportProxy(t *testing.T) { defer afterTest(t) - testCases := []struct{ httpsSite, httpsProxy bool }{ - {false, false}, - {false, true}, - {true, false}, - {true, true}, + testCases := []struct{ siteMode, proxyMode testMode }{ + {http1Mode, http1Mode}, + {http1Mode, https1Mode}, + {https1Mode, http1Mode}, + {https1Mode, https1Mode}, } for _, testCase := range testCases { - httpsSite := testCase.httpsSite - httpsProxy := testCase.httpsProxy - t.Run(fmt.Sprintf("httpsSite=%v, httpsProxy=%v", httpsSite, httpsProxy), func(t *testing.T) { + siteMode := testCase.siteMode + proxyMode := testCase.proxyMode + t.Run(fmt.Sprintf("site=%v/proxy=%v", siteMode, proxyMode), func(t *testing.T) { siteCh := make(chan *Request, 1) h1 := HandlerFunc(func(w ResponseWriter, r *Request) { siteCh <- r @@ -1414,18 +1400,8 @@ func TestTransportProxy(t *testing.T) { }() } }) - var ts *httptest.Server - if httpsSite { - ts = httptest.NewTLSServer(h1) - } else { - ts = httptest.NewServer(h1) - } - var proxy *httptest.Server - if httpsProxy { - proxy = httptest.NewTLSServer(h2) - } else { - proxy = httptest.NewServer(h2) - } + ts := newClientServerTest(t, siteMode, h1).ts + proxy := newClientServerTest(t, proxyMode, h2).ts pu, err := url.Parse(proxy.URL) if err != nil { @@ -1436,7 +1412,7 @@ func TestTransportProxy(t *testing.T) { // If only one server is HTTPS, c must be derived from that server in order // to ensure that it is configured to use the fake root CA from testcert.go. c := proxy.Client() - if httpsSite { + if siteMode == https1Mode { c = ts.Client() } @@ -1453,7 +1429,7 @@ func TestTransportProxy(t *testing.T) { c.Transport.(*Transport).CloseIdleConnections() ts.Close() proxy.Close() - if httpsSite { + if siteMode == https1Mode { // First message should be a CONNECT, asking for a socket to the real server, if got.Method != "CONNECT" { t.Errorf("Wrong method for secure proxying: %q", got.Method) @@ -1602,10 +1578,10 @@ func TestTransportDialPreservesNetOpProxyError(t *testing.T) { // (A bug caused dialConn to instead write the per-request Proxy-Authorization // header through to the shared Header instance, introducing a data race.) func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) { - setParallel(t) - defer afterTest(t) - - proxy := httptest.NewTLSServer(NotFoundHandler()) + run(t, testTransportProxyDialDoesNotMutateProxyConnectHeader) +} +func testTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T, mode testMode) { + proxy := newClientServerTest(t, mode, NotFoundHandler()).ts defer proxy.Close() c := proxy.Client() @@ -1639,13 +1615,12 @@ func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) { // client gets the same value back. This is more cute than anything, // but checks that we don't recurse forever, and checks that // Content-Encoding is removed. -func TestTransportGzipRecursive(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportGzipRecursive(t *testing.T) { run(t, testTransportGzipRecursive) } +func testTransportGzipRecursive(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Encoding", "gzip") w.Write(rgz) - })) - defer ts.Close() + })).ts c := ts.Client() res, err := c.Get(ts.URL) @@ -1667,13 +1642,12 @@ func TestTransportGzipRecursive(t *testing.T) { // golang.org/issue/7750: request fails when server replies with // a short gzip body -func TestTransportGzipShort(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportGzipShort(t *testing.T) { run(t, testTransportGzipShort) } +func testTransportGzipShort(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Encoding", "gzip") w.Write([]byte{0x1f, 0x8b}) - })) - defer ts.Close() + })).ts c := ts.Client() res, err := c.Get(ts.URL) @@ -1703,19 +1677,20 @@ func waitNumGoroutine(nmax int) int { // tests that persistent goroutine connections shut down when no longer desired. func TestTransportPersistConnLeak(t *testing.T) { + run(t, testTransportPersistConnLeak, testNotParallel) +} +func testTransportPersistConnLeak(t *testing.T, mode testMode) { // Not parallel: counts goroutines - defer afterTest(t) const numReq = 25 gotReqCh := make(chan bool, numReq) unblockCh := make(chan bool, numReq) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { gotReqCh <- true <-unblockCh w.Header().Set("Content-Length", "0") w.WriteHeader(204) - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -1773,11 +1748,12 @@ func TestTransportPersistConnLeak(t *testing.T) { // golang.org/issue/4531: Transport leaks goroutines when // request.ContentLength is explicitly short func TestTransportPersistConnLeakShortBody(t *testing.T) { + run(t, testTransportPersistConnLeakShortBody, testNotParallel) +} +func testTransportPersistConnLeakShortBody(t *testing.T, mode testMode) { // Not parallel: measures goroutines. - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - })) - defer ts.Close() + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -1851,9 +1827,10 @@ func (d *countingDialer) Read() (total, live int64) { } func TestTransportPersistConnLeakNeverIdle(t *testing.T) { - defer afterTest(t) - - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportPersistConnLeakNeverIdle, []testMode{http1Mode}) +} +func testTransportPersistConnLeakNeverIdle(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { // Close every connection so that it cannot be kept alive. conn, _, err := w.(Hijacker).Hijack() if err != nil { @@ -1861,8 +1838,7 @@ func TestTransportPersistConnLeakNeverIdle(t *testing.T) { return } conn.Close() - })) - defer ts.Close() + })).ts var d countingDialer c := ts.Client() @@ -1923,13 +1899,17 @@ func (cc *contextCounter) Read() (live int64) { } func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) { - defer afterTest(t) + run(t, testTransportPersistConnContextLeakMaxConnsPerHost) +} +func testTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("https://go.dev/issue/56021") + } - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { runtime.Gosched() w.WriteHeader(StatusOK) - })) - defer ts.Close() + })).ts c := ts.Client() c.Transport.(*Transport).MaxConnsPerHost = 1 @@ -1979,16 +1959,15 @@ func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) { } // This used to crash; https://golang.org/issue/3266 -func TestTransportIdleConnCrash(t *testing.T) { - defer afterTest(t) +func TestTransportIdleConnCrash(t *testing.T) { run(t, testTransportIdleConnCrash) } +func testTransportIdleConnCrash(t *testing.T, mode testMode) { var tr *Transport unblockCh := make(chan bool, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { <-unblockCh tr.CloseIdleConnections() - })) - defer ts.Close() + })).ts c := ts.Client() tr = c.Transport.(*Transport) @@ -2010,16 +1989,15 @@ func TestTransportIdleConnCrash(t *testing.T) { // before the response body has been read. This was a regression // which sadly lacked a triggering test. The large response body made // the old race easier to trigger. -func TestIssue3644(t *testing.T) { - defer afterTest(t) +func TestIssue3644(t *testing.T) { run(t, testIssue3644) } +func testIssue3644(t *testing.T, mode testMode) { const numFoos = 5000 - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Connection", "close") for i := 0; i < numFoos; i++ { w.Write([]byte("foo ")) } - })) - defer ts.Close() + })).ts c := ts.Client() res, err := c.Get(ts.URL) if err != nil { @@ -2037,14 +2015,12 @@ func TestIssue3644(t *testing.T) { // Test that a client receives a server's reply, even if the server doesn't read // the entire request body. -func TestIssue3595(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestIssue3595(t *testing.T) { run(t, testIssue3595) } +func testIssue3595(t *testing.T, mode testMode) { const deniedMsg = "sorry, denied." - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { Error(w, deniedMsg, StatusUnauthorized) - })) - defer ts.Close() + })).ts c := ts.Client() res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a')) if err != nil { @@ -2062,12 +2038,11 @@ func TestIssue3595(t *testing.T) { // From https://golang.org/issue/4454 , // "client fails to handle requests with no body and chunked encoding" -func TestChunkedNoContent(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestChunkedNoContent(t *testing.T) { run(t, testChunkedNoContent) } +func testChunkedNoContent(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(StatusNoContent) - })) - defer ts.Close() + })).ts c := ts.Client() for _, closeBody := range []bool{true, false} { @@ -2086,17 +2061,18 @@ func TestChunkedNoContent(t *testing.T) { } func TestTransportConcurrency(t *testing.T) { + run(t, testTransportConcurrency, testNotParallel, []testMode{http1Mode}) +} +func testTransportConcurrency(t *testing.T, mode testMode) { // Not parallel: uses global test hooks. - defer afterTest(t) maxProcs, numReqs := 16, 500 if testing.Short() { maxProcs, numReqs = 4, 50 } defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs)) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%v", r.FormValue("echo")) - })) - defer ts.Close() + })).ts var wg sync.WaitGroup wg.Add(numReqs) @@ -2147,16 +2123,14 @@ func TestTransportConcurrency(t *testing.T) { wg.Wait() } -func TestIssue4191_InfiniteGetTimeout(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestIssue4191_InfiniteGetTimeout(t *testing.T) { run(t, testIssue4191_InfiniteGetTimeout) } +func testIssue4191_InfiniteGetTimeout(t *testing.T, mode testMode) { const debug = false mux := NewServeMux() mux.HandleFunc("/get", func(w ResponseWriter, r *Request) { io.Copy(w, neverEnding('a')) }) - ts := httptest.NewServer(mux) - defer ts.Close() + ts := newClientServerTest(t, mode, mux).ts timeout := 100 * time.Millisecond c := ts.Client() @@ -2206,8 +2180,9 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) { } func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testIssue4191_InfiniteGetToPutTimeout, []testMode{http1Mode}) +} +func testIssue4191_InfiniteGetToPutTimeout(t *testing.T, mode testMode) { const debug = false mux := NewServeMux() mux.HandleFunc("/get", func(w ResponseWriter, r *Request) { @@ -2217,7 +2192,7 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { defer r.Body.Close() io.Copy(io.Discard, r.Body) }) - ts := httptest.NewServer(mux) + ts := newClientServerTest(t, mode, mux).ts timeout := 100 * time.Millisecond c := ts.Client() @@ -2270,9 +2245,8 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { ts.Close() } -func TestTransportResponseHeaderTimeout(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestTransportResponseHeaderTimeout(t *testing.T) { run(t, testTransportResponseHeaderTimeout) } +func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping timeout test in -short mode") } @@ -2285,8 +2259,7 @@ func TestTransportResponseHeaderTimeout(t *testing.T) { inHandler <- true time.Sleep(2 * time.Second) }) - ts := httptest.NewServer(mux) - defer ts.Close() + ts := newClientServerTest(t, mode, mux).ts c := ts.Client() c.Transport.(*Transport).ResponseHeaderTimeout = 500 * time.Millisecond @@ -2342,18 +2315,18 @@ func TestTransportResponseHeaderTimeout(t *testing.T) { } func TestTransportCancelRequest(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testTransportCancelRequest, []testMode{http1Mode}) +} +func testTransportCancelRequest(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping test in -short mode") } unblockc := make(chan bool) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "Hello") w.(Flusher).Flush() // send headers and some body <-unblockc - })) - defer ts.Close() + })).ts defer close(unblockc) c := ts.Client() @@ -2395,17 +2368,14 @@ func TestTransportCancelRequest(t *testing.T) { } } -func testTransportCancelRequestInDo(t *testing.T, body io.Reader) { - setParallel(t) - defer afterTest(t) +func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader) { if testing.Short() { t.Skip("skipping test in -short mode") } unblockc := make(chan bool) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { <-unblockc - })) - defer ts.Close() + })).ts defer close(unblockc) c := ts.Client() @@ -2432,11 +2402,15 @@ func testTransportCancelRequestInDo(t *testing.T, body io.Reader) { } func TestTransportCancelRequestInDo(t *testing.T) { - testTransportCancelRequestInDo(t, nil) + run(t, func(t *testing.T, mode testMode) { + testTransportCancelRequestInDo(t, mode, nil) + }, []testMode{http1Mode}) } func TestTransportCancelRequestWithBodyInDo(t *testing.T) { - testTransportCancelRequestInDo(t, bytes.NewBuffer([]byte{0})) + run(t, func(t *testing.T, mode testMode) { + testTransportCancelRequestInDo(t, mode, bytes.NewBuffer([]byte{0})) + }, []testMode{http1Mode}) } func TestTransportCancelRequestInDial(t *testing.T) { @@ -2497,19 +2471,17 @@ Get = Get "http://something.no-network.tld/": net/http: request canceled while w } } -func TestCancelRequestWithChannel(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestCancelRequestWithChannel(t *testing.T) { run(t, testCancelRequestWithChannel) } +func testCancelRequestWithChannel(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping test in -short mode") } unblockc := make(chan bool) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "Hello") w.(Flusher).Flush() // send headers and some body <-unblockc - })) - defer ts.Close() + })).ts defer close(unblockc) c := ts.Client() @@ -2555,19 +2527,20 @@ func TestCancelRequestWithChannel(t *testing.T) { } func TestCancelRequestWithChannelBeforeDo_Cancel(t *testing.T) { - testCancelRequestWithChannelBeforeDo(t, false) + run(t, func(t *testing.T, mode testMode) { + testCancelRequestWithChannelBeforeDo(t, mode, false) + }) } func TestCancelRequestWithChannelBeforeDo_Context(t *testing.T) { - testCancelRequestWithChannelBeforeDo(t, true) + run(t, func(t *testing.T, mode testMode) { + testCancelRequestWithChannelBeforeDo(t, mode, true) + }) } -func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) { - setParallel(t) - defer afterTest(t) +func testCancelRequestWithChannelBeforeDo(t *testing.T, mode testMode, withCtx bool) { unblockc := make(chan bool) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { <-unblockc - })) - defer ts.Close() + })).ts defer close(unblockc) c := ts.Client() @@ -2642,11 +2615,11 @@ func TestTransportCancelBeforeResponseHeaders(t *testing.T) { // golang.org/issue/3672 -- Client can't close HTTP stream // Calling Close on a Response.Body used to just read until EOF. // Now it actually closes the TCP connection. -func TestTransportCloseResponseBody(t *testing.T) { - defer afterTest(t) +func TestTransportCloseResponseBody(t *testing.T) { run(t, testTransportCloseResponseBody) } +func testTransportCloseResponseBody(t *testing.T, mode testMode) { writeErr := make(chan error, 1) msg := []byte("young\n") - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { for { _, err := w.Write(msg) if err != nil { @@ -2655,8 +2628,7 @@ func TestTransportCloseResponseBody(t *testing.T) { } w.(Flusher).Flush() } - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -2761,10 +2733,8 @@ func TestTransportEmptyMethod(t *testing.T) { } } -func TestTransportSocketLateBinding(t *testing.T) { - setParallel(t) - defer afterTest(t) - +func TestTransportSocketLateBinding(t *testing.T) { run(t, testTransportSocketLateBinding) } +func testTransportSocketLateBinding(t *testing.T, mode testMode) { mux := NewServeMux() fooGate := make(chan bool, 1) mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) { @@ -2775,8 +2745,7 @@ func TestTransportSocketLateBinding(t *testing.T) { mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) { w.Header().Set("bar-ipport", r.RemoteAddr) }) - ts := httptest.NewServer(mux) - defer ts.Close() + ts := newClientServerTest(t, mode, mux).ts dialGate := make(chan bool, 1) c := ts.Client() @@ -2920,15 +2889,15 @@ Content-Length: %d // Issue 17739: the HTTP client must ignore any unknown 1xx // informational responses before the actual response. func TestTransportIgnore1xxResponses(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportIgnore1xxResponses, []testMode{http1Mode}) +} +func testTransportIgnore1xxResponses(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, buf, _ := w.(Hijacker).Hijack() buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello")) buf.Flush() conn.Close() })) - defer cst.close() cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway var got strings.Builder @@ -2954,9 +2923,10 @@ func TestTransportIgnore1xxResponses(t *testing.T) { } func TestTransportLimits1xxResponses(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportLimits1xxResponses, []testMode{http1Mode}) +} +func testTransportLimits1xxResponses(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, buf, _ := w.(Hijacker).Hijack() for i := 0; i < 10; i++ { buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\n\r\n")) @@ -2965,7 +2935,6 @@ func TestTransportLimits1xxResponses(t *testing.T) { buf.Flush() conn.Close() })) - defer cst.close() cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway res, err := cst.c.Get(cst.ts.URL) @@ -2982,16 +2951,16 @@ func TestTransportLimits1xxResponses(t *testing.T) { // Issue 26161: the HTTP client must treat 101 responses // as the final response. func TestTransportTreat101Terminal(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportTreat101Terminal, []testMode{http1Mode}) +} +func testTransportTreat101Terminal(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, buf, _ := w.(Hijacker).Hijack() buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n")) buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n")) buf.Flush() conn.Close() })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -3123,16 +3092,18 @@ func TestProxyFromEnvironmentLowerCase(t *testing.T) { } func TestIdleConnChannelLeak(t *testing.T) { + run(t, testIdleConnChannelLeak, []testMode{http1Mode}, testNotParallel) +} +func testIdleConnChannelLeak(t *testing.T, mode testMode) { // Not parallel: uses global test hooks. var mu sync.Mutex var n int - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { mu.Lock() n++ mu.Unlock() - })) - defer ts.Close() + })).ts const nReqs = 5 didRead := make(chan bool, nReqs) @@ -3180,11 +3151,12 @@ func TestIdleConnChannelLeak(t *testing.T) { // body into a ReadCloser if it's a Closer, and that the Transport // then closes it. func TestTransportClosesRequestBody(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportClosesRequestBody, []testMode{http1Mode}) +} +func testTransportClosesRequestBody(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.Copy(io.Discard, r.Body) - })) - defer ts.Close() + })).ts c := ts.Client() @@ -3261,10 +3233,11 @@ func TestTransportTLSHandshakeTimeout(t *testing.T) { // Trying to repro golang.org/issue/3514 func TestTLSServerClosesConnection(t *testing.T) { - defer afterTest(t) - + run(t, testTLSServerClosesConnection, []testMode{https1Mode}) +} +func testTLSServerClosesConnection(t *testing.T, mode testMode) { closedc := make(chan bool, 1) - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if strings.Contains(r.URL.Path, "/keep-alive-then-die") { conn, _, _ := w.(Hijacker).Hijack() conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) @@ -3273,8 +3246,7 @@ func TestTLSServerClosesConnection(t *testing.T) { return } fmt.Fprintf(w, "hello") - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -3345,8 +3317,9 @@ func (c byteFromChanReader) Read(p []byte) (n int, err error) { // questionable state. // golang.org/issue/7569 func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testTransportNoReuseAfterEarlyResponse, []testMode{http1Mode}) +} +func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) { var sconn struct { sync.Mutex c net.Conn @@ -3365,7 +3338,7 @@ func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { } defer closeConn() - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method == "GET" { io.WriteString(w, "bar") return @@ -3376,8 +3349,7 @@ func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { sconn.Unlock() conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive go io.Copy(io.Discard, conn) - })) - defer ts.Close() + })).ts c := ts.Client() const bodySize = 256 << 10 @@ -3410,9 +3382,9 @@ func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { // Tests that we don't leak Transport persistConn.readLoop goroutines // when a server hangs up immediately after saying it would keep-alive. -func TestTransportIssue10457(t *testing.T) { - defer afterTest(t) // used to fail in goroutine leak check - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportIssue10457(t *testing.T) { run(t, testTransportIssue10457, []testMode{http1Mode}) } +func testTransportIssue10457(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { // Send a response with no body, keep-alive // (implicit), and then lie and immediately close the // connection. This forces the Transport's readLoop to @@ -3421,8 +3393,7 @@ func TestTransportIssue10457(t *testing.T) { conn, _, _ := w.(Hijacker).Hijack() conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n")) // keep-alive conn.Close() - })) - defer ts.Close() + })).ts c := ts.Client() res, err := c.Get(ts.URL) @@ -3463,6 +3434,9 @@ func (c writerFuncConn) Write(p []byte) (n int, err error) { return c.write(p) } // This automatically prevents an infinite resend loop because we'll run out of // the cached keep-alive connections eventually. func TestRetryRequestsOnError(t *testing.T) { + run(t, testRetryRequestsOnError, testNotParallel, []testMode{http1Mode}) +} +func testRetryRequestsOnError(t *testing.T, mode testMode) { newRequest := func(method, urlStr string, body io.Reader) *Request { req, err := NewRequest(method, urlStr, body) if err != nil { @@ -3533,8 +3507,6 @@ func TestRetryRequestsOnError(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - defer afterTest(t) - var ( mu sync.Mutex logbuf strings.Builder @@ -3546,11 +3518,10 @@ func TestRetryRequestsOnError(t *testing.T) { logbuf.WriteByte('\n') } - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { logf("Handler") w.Header().Set("X-Status", "ok") - })) - defer ts.Close() + })).ts var writeNumAtomic int32 c := ts.Client() @@ -3620,15 +3591,13 @@ Handler } // Issue 6981 -func TestTransportClosesBodyOnError(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestTransportClosesBodyOnError(t *testing.T) { run(t, testTransportClosesBodyOnError) } +func testTransportClosesBodyOnError(t *testing.T, mode testMode) { readBody := make(chan error, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { _, err := io.ReadAll(r.Body) readBody <- err - })) - defer ts.Close() + })).ts c := ts.Client() fakeErr := errors.New("fake error") didClose := make(chan bool, 1) @@ -3668,17 +3637,17 @@ func TestTransportClosesBodyOnError(t *testing.T) { } func TestTransportDialTLS(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testTransportDialTLS, []testMode{https1Mode, http2Mode}) +} +func testTransportDialTLS(t *testing.T, mode testMode) { var mu sync.Mutex // guards following var gotReq, didDial bool - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { mu.Lock() gotReq = true mu.Unlock() - })) - defer ts.Close() + })).ts c := ts.Client() c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) { mu.Lock() @@ -3705,19 +3674,17 @@ func TestTransportDialTLS(t *testing.T) { } } -func TestTransportDialContext(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestTransportDialContext(t *testing.T) { run(t, testTransportDialContext) } +func testTransportDialContext(t *testing.T, mode testMode) { var mu sync.Mutex // guards following var gotReq bool var receivedContext context.Context - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { mu.Lock() gotReq = true mu.Unlock() - })) - defer ts.Close() + })).ts c := ts.Client() c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) { mu.Lock() @@ -3746,18 +3713,18 @@ func TestTransportDialContext(t *testing.T) { } func TestTransportDialTLSContext(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testTransportDialTLSContext, []testMode{https1Mode, http2Mode}) +} +func testTransportDialTLSContext(t *testing.T, mode testMode) { var mu sync.Mutex // guards following var gotReq bool var receivedContext context.Context - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { mu.Lock() gotReq = true mu.Unlock() - })) - defer ts.Close() + })).ts c := ts.Client() c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) { mu.Lock() @@ -3879,6 +3846,9 @@ func TestTransportTraceGotConnH2IdleConns(t *testing.T) { } func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) { + run(t, testTransportRemovesH2ConnsAfterIdle, []testMode{http2Mode}) +} +func testTransportRemovesH2ConnsAfterIdle(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } @@ -3888,8 +3858,7 @@ func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) { tr.MaxIdleConnsPerHost = 1 tr.IdleConnTimeout = 10 * time.Millisecond } - cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), trFunc) - defer cst.close() + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), trFunc) if _, err := cst.c.Get(cst.ts.URL); err != nil { t.Fatalf("got error: %s", err) @@ -3920,13 +3889,12 @@ func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) { // implicitly ask for gzip support. If they want that, they need to do it // on their own. // golang.org/issue/8923 -func TestTransportRangeAndGzip(t *testing.T) { - defer afterTest(t) +func TestTransportRangeAndGzip(t *testing.T) { run(t, testTransportRangeAndGzip) } +func testTransportRangeAndGzip(t *testing.T, mode testMode) { reqc := make(chan *Request, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { reqc <- r - })) - defer ts.Close() + })).ts c := ts.Client() req, _ := NewRequest("GET", ts.URL, nil) @@ -3951,15 +3919,13 @@ func TestTransportRangeAndGzip(t *testing.T) { } // Test for issue 10474 -func TestTransportResponseCancelRace(t *testing.T) { - defer afterTest(t) - - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportResponseCancelRace(t *testing.T) { run(t, testTransportResponseCancelRace) } +func testTransportResponseCancelRace(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { // important that this response has a body. var b [1024]byte w.Write(b[:]) - })) - defer ts.Close() + })).ts tr := ts.Client().Transport.(*Transport) req, err := NewRequest("GET", ts.URL, nil) @@ -3991,19 +3957,19 @@ func TestTransportResponseCancelRace(t *testing.T) { // Test for issue 19248: Content-Encoding's value is case insensitive. func TestTransportContentEncodingCaseInsensitive(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testTransportContentEncodingCaseInsensitive) +} +func testTransportContentEncodingCaseInsensitive(t *testing.T, mode testMode) { for _, ce := range []string{"gzip", "GZIP"} { ce := ce t.Run(ce, func(t *testing.T) { const encodedString = "Hello Gopher" - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Encoding", ce) gz := gzip.NewWriter(w) gz.Write([]byte(encodedString)) gz.Close() - })) - defer ts.Close() + })).ts res, err := ts.Client().Get(ts.URL) if err != nil { @@ -4024,10 +3990,10 @@ func TestTransportContentEncodingCaseInsensitive(t *testing.T) { } func TestTransportDialCancelRace(t *testing.T) { - defer afterTest(t) - - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) - defer ts.Close() + run(t, testTransportDialCancelRace, testNotParallel, []testMode{http1Mode}) +} +func testTransportDialCancelRace(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts tr := ts.Client().Transport.(*Transport) req, err := NewRequest("GET", ts.URL, nil) @@ -4140,13 +4106,12 @@ func TestTransportFlushesBodyChunks(t *testing.T) { } // Issue 22088: flush Transport request headers if we're not sure the body won't block on read. -func TestTransportFlushesRequestHeader(t *testing.T) { - defer afterTest(t) +func TestTransportFlushesRequestHeader(t *testing.T) { run(t, testTransportFlushesRequestHeader) } +func testTransportFlushesRequestHeader(t *testing.T, mode testMode) { gotReq := make(chan struct{}) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { close(gotReq) })) - defer cst.close() pr, pw := io.Pipe() req, err := NewRequest("POST", cst.ts.URL, pr) @@ -4175,20 +4140,21 @@ func TestTransportFlushesRequestHeader(t *testing.T) { // Issue 11745. func TestTransportPrefersResponseOverWriteError(t *testing.T) { + run(t, testTransportPrefersResponseOverWriteError) +} +func testTransportPrefersResponseOverWriteError(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } - defer afterTest(t) const contentLengthLimit = 1024 * 1024 // 1MB - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.ContentLength >= contentLengthLimit { w.WriteHeader(StatusBadRequest) r.Body.Close() return } w.WriteHeader(StatusOK) - })) - defer ts.Close() + })).ts c := ts.Client() fail := 0 @@ -4296,12 +4262,13 @@ func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) { // Plus it's nice to be consistent and not have timing-dependent // behavior. func TestTransportReuseConnEmptyResponseBody(t *testing.T) { - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportReuseConnEmptyResponseBody) +} +func testTransportReuseConnEmptyResponseBody(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("X-Addr", r.RemoteAddr) // Empty response body. })) - defer cst.close() n := 100 if testing.Short() { n = 10 @@ -4406,27 +4373,28 @@ func TestNoCrashReturningTransportAltConn(t *testing.T) { } func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) { - testTransportReuseConnection_Gzip(t, true) + run(t, func(t *testing.T, mode testMode) { + testTransportReuseConnection_Gzip(t, mode, true) + }) } func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) { - testTransportReuseConnection_Gzip(t, false) + run(t, func(t *testing.T, mode testMode) { + testTransportReuseConnection_Gzip(t, mode, false) + }) } // Make sure we re-use underlying TCP connection for gzipped responses too. -func testTransportReuseConnection_Gzip(t *testing.T, chunked bool) { - setParallel(t) - defer afterTest(t) +func testTransportReuseConnection_Gzip(t *testing.T, mode testMode, chunked bool) { addr := make(chan string, 2) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { addr <- r.RemoteAddr w.Header().Set("Content-Encoding", "gzip") if chunked { w.(Flusher).Flush() } w.Write(rgz) // arbitrary gzip response - })) - defer ts.Close() + })).ts c := ts.Client() trace := &httptrace.ClientTrace{ @@ -4459,15 +4427,16 @@ func testTransportReuseConnection_Gzip(t *testing.T, chunked bool) { } } -func TestTransportResponseHeaderLength(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportResponseHeaderLength(t *testing.T) { run(t, testTransportResponseHeaderLength) } +func testTransportResponseHeaderLength(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("HTTP/2 Transport doesn't support MaxResponseHeaderBytes") + } + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.URL.Path == "/long" { w.Header().Set("Long", strings.Repeat("a", 1<<20)) } - })) - defer ts.Close() + })).ts c := ts.Client() c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10 @@ -4493,18 +4462,23 @@ func TestTransportResponseHeaderLength(t *testing.T) { } } -func TestTransportEventTrace(t *testing.T) { testTransportEventTrace(t, h1Mode, false) } -func TestTransportEventTrace_h2(t *testing.T) { testTransportEventTrace(t, h2Mode, false) } +func TestTransportEventTrace(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testTransportEventTrace(t, mode, false) + }) +} // test a non-nil httptrace.ClientTrace but with all hooks set to zero. -func TestTransportEventTrace_NoHooks(t *testing.T) { testTransportEventTrace(t, h1Mode, true) } -func TestTransportEventTrace_NoHooks_h2(t *testing.T) { testTransportEventTrace(t, h2Mode, true) } +func TestTransportEventTrace_NoHooks(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testTransportEventTrace(t, mode, true) + }) +} -func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { - defer afterTest(t) +func testTransportEventTrace(t *testing.T, mode testMode, noHooks bool) { const resBody = "some body" gotWroteReqEvent := make(chan struct{}, 500) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method == "GET" { // Do nothing for the second request. return @@ -4520,7 +4494,11 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { } } io.WriteString(w, resBody) - })) + }), func(tr *Transport) { + if tr.TLSClientConfig != nil { + tr.TLSClientConfig.InsecureSkipVerify = true + } + }) defer cst.close() cst.tr.ExpectContinueTimeout = 1 * time.Second @@ -4579,7 +4557,7 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { gotWroteReqEvent <- struct{}{} }, } - if h2 { + if mode == http2Mode { trace.TLSHandshakeStart = func() { logf("tls handshake start") } trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) { logf("tls handshake done. ConnectionState = %v \n err = %v", s, err) @@ -4636,7 +4614,7 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { wantOnceOrMore("connected to tcp " + addrStr + " = ") wantOnce("Reused:false WasIdle:false IdleTime:0s") wantOnce("first response byte") - if h2 { + if mode == http2Mode { wantOnce("tls handshake start") wantOnce("tls handshake done") } else { @@ -4684,6 +4662,9 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { } func TestTransportEventTraceTLSVerify(t *testing.T) { + run(t, testTransportEventTraceTLSVerify, []testMode{https1Mode, http2Mode}) +} +func testTransportEventTraceTLSVerify(t *testing.T, mode testMode) { var mu sync.Mutex var buf strings.Builder logf := func(format string, args ...any) { @@ -4693,14 +4674,14 @@ func TestTransportEventTraceTLSVerify(t *testing.T) { buf.WriteByte('\n') } - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { t.Error("Unexpected request") - })) - defer ts.Close() - ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) { - logf("%s", p) - return len(p), nil - }), "", 0) + }), func(ts *httptest.Server) { + ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) { + logf("%s", p) + return len(p), nil + }), "", 0) + }).ts certpool := x509.NewCertPool() certpool.AddCert(ts.Certificate()) @@ -4834,9 +4815,10 @@ func TestTransportRejectsAlphaPort(t *testing.T) { // Test the httptrace.TLSHandshake{Start,Done} hooks with a https http1 // connections. The http2 test is done in TestTransportEventTrace_h2 func TestTLSHandshakeTrace(t *testing.T) { - defer afterTest(t) - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) - defer ts.Close() + run(t, testTLSHandshakeTrace, []testMode{https1Mode, http2Mode}) +} +func testTLSHandshakeTrace(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts var mu sync.Mutex var start, done bool @@ -4879,11 +4861,12 @@ func TestTLSHandshakeTrace(t *testing.T) { } func TestTransportMaxIdleConns(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportMaxIdleConns, []testMode{http1Mode}) +} +func testTransportMaxIdleConns(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { // No body for convenience. - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) tr.MaxIdleConns = 4 @@ -4931,27 +4914,24 @@ func TestTransportMaxIdleConns(t *testing.T) { } } -func TestTransportIdleConnTimeout_h1(t *testing.T) { testTransportIdleConnTimeout(t, h1Mode) } -func TestTransportIdleConnTimeout_h2(t *testing.T) { testTransportIdleConnTimeout(t, h2Mode) } -func testTransportIdleConnTimeout(t *testing.T, h2 bool) { +func TestTransportIdleConnTimeout(t *testing.T) { run(t, testTransportIdleConnTimeout) } +func testTransportIdleConnTimeout(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } - defer afterTest(t) const timeout = 1 * time.Second - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { // No body for convenience. })) - defer cst.close() tr := cst.tr tr.IdleConnTimeout = timeout defer tr.CloseIdleConnections() c := &Client{Transport: tr} idleConns := func() []string { - if h2 { + if mode == http2Mode { return tr.IdleConnStrsForTesting_h2() } else { return tr.IdleConnStrsForTesting() @@ -5005,12 +4985,11 @@ func testTransportIdleConnTimeout(t *testing.T, h2 bool) { // real connection until after the RoundTrip saw the error. Then we // know the successful tls.Dial from DialTLS will need to go into the // idle pool. Then we give it a of time to explode. -func TestIdleConnH2Crash(t *testing.T) { - setParallel(t) - cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestIdleConnH2Crash(t *testing.T) { run(t, testIdleConnH2Crash, []testMode{http2Mode}) } +func testIdleConnH2Crash(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { // nothing })) - defer cst.close() ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -5101,21 +5080,18 @@ func TestTransportReturnsPeekError(t *testing.T) { } // Issue 13835: international domain names should work -func TestTransportIDNA_h1(t *testing.T) { testTransportIDNA(t, h1Mode) } -func TestTransportIDNA_h2(t *testing.T) { testTransportIDNA(t, h2Mode) } -func testTransportIDNA(t *testing.T, h2 bool) { - defer afterTest(t) - +func TestTransportIDNA(t *testing.T) { run(t, testTransportIDNA) } +func testTransportIDNA(t *testing.T, mode testMode) { const uniDomain = "гофер.го" const punyDomain = "xn--c1ae0ajs.xn--c1aw" var port string - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { want := punyDomain + ":" + port if r.Host != want { t.Errorf("Host header = %q; want %q", r.Host, want) } - if h2 { + if mode == http2Mode { if r.TLS == nil { t.Errorf("r.TLS == nil") } else if r.TLS.ServerName != punyDomain { @@ -5123,8 +5099,11 @@ func testTransportIDNA(t *testing.T, h2 bool) { } } w.Header().Set("Hit-Handler", "1") - })) - defer cst.close() + }), func(tr *Transport) { + if tr.TLSClientConfig != nil { + tr.TLSClientConfig.InsecureSkipVerify = true + } + }) ip, port, err := net.SplitHostPort(cst.ts.Listener.Addr().String()) if err != nil { @@ -5172,9 +5151,11 @@ func testTransportIDNA(t *testing.T, h2 bool) { // Issue 13290: send User-Agent in proxy CONNECT func TestTransportProxyConnectHeader(t *testing.T) { - defer afterTest(t) + run(t, testTransportProxyConnectHeader, []testMode{http1Mode}) +} +func testTransportProxyConnectHeader(t *testing.T, mode testMode) { reqc := make(chan *Request, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "CONNECT" { t.Errorf("method = %q; want CONNECT", r.Method) } @@ -5185,8 +5166,7 @@ func TestTransportProxyConnectHeader(t *testing.T) { return } c.Close() - })) - defer ts.Close() + })).ts c := ts.Client() c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) { @@ -5216,9 +5196,11 @@ func TestTransportProxyConnectHeader(t *testing.T) { } func TestTransportProxyGetConnectHeader(t *testing.T) { - defer afterTest(t) + run(t, testTransportProxyGetConnectHeader, []testMode{http1Mode}) +} +func testTransportProxyGetConnectHeader(t *testing.T, mode testMode) { reqc := make(chan *Request, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "CONNECT" { t.Errorf("method = %q; want CONNECT", r.Method) } @@ -5229,8 +5211,7 @@ func TestTransportProxyGetConnectHeader(t *testing.T) { return } c.Close() - })) - defer ts.Close() + })).ts c := ts.Client() c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) { @@ -5417,14 +5398,15 @@ func doFetchCheckPanic(tr *Transport, req *Request) (res *Response, err error, p // Issue 22330: do not allow the response body to be read when the status code // forbids a response body. func TestNoBodyOnChunked304Response(t *testing.T) { - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testNoBodyOnChunked304Response, []testMode{http1Mode}) +} +func testNoBodyOnChunked304Response(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, buf, _ := w.(Hijacker).Hijack() buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n")) buf.Flush() conn.Close() })) - defer cst.close() // Our test server above is sending back bogus data after the // response (the "0\r\n\r\n" part), which causes the Transport @@ -5477,11 +5459,12 @@ func TestTransportCheckContextDoneEarly(t *testing.T) { // This is the test variant that times out before the server replies with // any response headers. func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testClientTimeoutKillsConn_BeforeHeaders, []testMode{http1Mode}) +} +func testClientTimeoutKillsConn_BeforeHeaders(t *testing.T, mode testMode) { inHandler := make(chan net.Conn, 1) handlerReadReturned := make(chan bool, 1) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, _, err := w.(Hijacker).Hijack() if err != nil { t.Error(err) @@ -5494,7 +5477,6 @@ func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) { } handlerReadReturned <- true })) - defer cst.close() const timeout = 50 * time.Millisecond cst.c.Timeout = timeout @@ -5529,11 +5511,12 @@ func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) { // This is the test variant that has the server send response headers // first, and time out during the write of the response body. func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testClientTimeoutKillsConn_AfterHeaders, []testMode{http1Mode}) +} +func testClientTimeoutKillsConn_AfterHeaders(t *testing.T, mode testMode) { inHandler := make(chan net.Conn, 1) handlerResult := make(chan error, 1) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Length", "100") w.(Flusher).Flush() conn, _, err := w.(Hijacker).Hijack() @@ -5555,7 +5538,6 @@ func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) { } handlerResult <- nil })) - defer cst.close() // Set Timeout to something very long but non-zero to exercise // the codepaths that check for it. But rather than wait for it to fire @@ -5601,11 +5583,12 @@ func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) { } func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testTransportResponseBodyWritableOnProtocolSwitch, []testMode{http1Mode}) +} +func testTransportResponseBodyWritableOnProtocolSwitch(t *testing.T, mode testMode) { done := make(chan struct{}) defer close(done) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, _, err := w.(Hijacker).Hijack() if err != nil { t.Error(err) @@ -5618,7 +5601,6 @@ func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) { fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text())) <-done })) - defer cst.close() req, _ := NewRequest("GET", cst.ts.URL, nil) req.Header.Set("Upgrade", "foo") @@ -5651,10 +5633,10 @@ func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) { } } -func TestTransportCONNECTBidi(t *testing.T) { - defer afterTest(t) +func TestTransportCONNECTBidi(t *testing.T) { run(t, testTransportCONNECTBidi, []testMode{http1Mode}) } +func testTransportCONNECTBidi(t *testing.T, mode testMode) { const target = "backend:443" - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "CONNECT" { t.Errorf("unexpected method %q", r.Method) w.WriteHeader(500) @@ -5685,7 +5667,6 @@ func TestTransportCONNECTBidi(t *testing.T) { brw.Flush() } })) - defer cst.close() pr, pw := io.Pipe() defer pw.Close() req, err := NewRequest("CONNECT", cst.ts.URL, pr) @@ -5782,7 +5763,8 @@ func (c *testMockTCPConn) ReadFrom(r io.Reader) (int64, error) { return c.TCPConn.ReadFrom(r) } -func TestTransportRequestWriteRoundTrip(t *testing.T) { +func TestTransportRequestWriteRoundTrip(t *testing.T) { run(t, testTransportRequestWriteRoundTrip) } +func testTransportRequestWriteRoundTrip(t *testing.T, mode testMode) { nBytes := int64(1 << 10) newFileFunc := func() (r io.Reader, done func(), err error) { f, err := os.CreateTemp("", "net-http-newfilefunc") @@ -5876,7 +5858,7 @@ func TestTransportRequestWriteRoundTrip(t *testing.T) { cst := newClientServerTest( t, - h1Mode, + mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.Copy(io.Discard, r.Body) r.Body.Close() @@ -5884,7 +5866,6 @@ func TestTransportRequestWriteRoundTrip(t *testing.T) { }), trFunc, ) - defer cst.close() req, err := NewRequest("PUT", cst.ts.URL, r) if err != nil { @@ -5901,11 +5882,15 @@ func TestTransportRequestWriteRoundTrip(t *testing.T) { t.Fatalf("status code = %d; want 200", resp.StatusCode) } - if !tConn.ReadFromCalled && tc.expectedReadFrom { + expectedReadFrom := tc.expectedReadFrom + if mode != http1Mode { + expectedReadFrom = false + } + if !tConn.ReadFromCalled && expectedReadFrom { t.Fatalf("did not call ReadFrom") } - if tConn.ReadFromCalled && !tc.expectedReadFrom { + if tConn.ReadFromCalled && !expectedReadFrom { t.Fatalf("ReadFrom was unexpectedly invoked") } }) @@ -5985,17 +5970,17 @@ func TestIs408(t *testing.T) { } } -func TestTransportIgnores408(t *testing.T) { +func TestTransportIgnores408(t *testing.T) { run(t, testTransportIgnores408, []testMode{http1Mode}) } +func testTransportIgnores408(t *testing.T, mode testMode) { // Not parallel. Relies on mutating the log package's global Output. defer log.SetOutput(log.Writer()) var logout strings.Builder log.SetOutput(&logout) - defer afterTest(t) const target = "backend:443" - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { nc, _, err := w.(Hijacker).Hijack() if err != nil { t.Error(err) @@ -6005,7 +5990,6 @@ func TestTransportIgnores408(t *testing.T) { nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok")) nc.Write([]byte("HTTP/1.1 408 bye\r\n")) // changing 408 to 409 makes test fail })) - defer cst.close() req, err := NewRequest("GET", cst.ts.URL, nil) if err != nil { t.Fatal(err) @@ -6039,9 +6023,10 @@ func TestTransportIgnores408(t *testing.T) { } func TestInvalidHeaderResponse(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testInvalidHeaderResponse, []testMode{http1Mode}) +} +func testInvalidHeaderResponse(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, buf, _ := w.(Hijacker).Hijack() buf.Write([]byte("HTTP/1.1 200 OK\r\n" + "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" + @@ -6051,7 +6036,6 @@ func TestInvalidHeaderResponse(t *testing.T) { buf.Flush() conn.Close() })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -6078,10 +6062,12 @@ func (bc *bodyCloser) Read(b []byte) (n int, err error) { // Issue 35015: ensure that Transport closes the body on any error // with an invalid request, as promised by Client.Do docs. func TestTransportClosesBodyOnInvalidRequests(t *testing.T) { - cst := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportClosesBodyOnInvalidRequests) +} +func testTransportClosesBodyOnInvalidRequests(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { t.Errorf("Should not have been invoked") - })) - defer cst.Close() + })).ts u, _ := url.Parse(cst.URL) @@ -6146,7 +6132,7 @@ func TestTransportClosesBodyOnInvalidRequests(t *testing.T) { var bc bodyCloser req := tt.req req.Body = &bc - _, err := DefaultClient.Do(tt.req) + _, err := cst.Client().Do(tt.req) if err == nil { t.Fatal("Expected an error") } @@ -6183,8 +6169,10 @@ func (w *breakableConn) Write(b []byte) (n int, err error) { // Issue 34978: don't cache a broken HTTP/2 connection func TestDontCacheBrokenHTTP2Conn(t *testing.T) { - cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog) - defer cst.close() + run(t, testDontCacheBrokenHTTP2Conn, []testMode{http2Mode}) +} +func testDontCacheBrokenHTTP2Conn(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog) var brokenState brokenState @@ -6246,7 +6234,9 @@ func TestDontCacheBrokenHTTP2Conn(t *testing.T) { // http.http2noCachedConnError is reported on multiple requests. There should // only be one decrement regardless of the number of failures. func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) { - defer afterTest(t) + run(t, testTransportDecrementConnWhenIdleConnRemoved, []testMode{http2Mode}) +} +func testTransportDecrementConnWhenIdleConnRemoved(t *testing.T, mode testMode) { CondSkipHTTP2(t) h := HandlerFunc(func(w ResponseWriter, r *Request) { @@ -6256,17 +6246,11 @@ func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) { } }) - ts := httptest.NewUnstartedServer(h) - ts.EnableHTTP2 = true - ts.StartTLS() - defer ts.Close() + ts := newClientServerTest(t, mode, h).ts c := ts.Client() tr := c.Transport.(*Transport) tr.MaxConnsPerHost = 1 - if err := ExportHttp2ConfigureTransport(tr); err != nil { - t.Fatalf("ExportHttp2ConfigureTransport: %v", err) - } errCh := make(chan error, 300) doReq := func() { @@ -6335,14 +6319,13 @@ type roundTripFunc func(r *Request) (*Response, error) func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) { return f(r) } // Issue 32441: body is not reset after ErrSkipAltProtocol -func TestIssue32441(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestIssue32441(t *testing.T) { run(t, testIssue32441, []testMode{http1Mode}) } +func testIssue32441(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if n, _ := io.Copy(io.Discard, r.Body); n == 0 { t.Error("body length is zero") } - })) - defer ts.Close() + })).ts c := ts.Client() c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) { // Draining body to trigger failure condition on actual request to server. @@ -6359,11 +6342,13 @@ func TestIssue32441(t *testing.T) { // Issue 39017. Ensure that HTTP/1 transports reject Content-Length headers // that contain a sign (eg. "+3"), per RFC 2616, Section 14.13. func TestTransportRejectsSignInContentLength(t *testing.T) { - cst := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportRejectsSignInContentLength, []testMode{http1Mode}) +} +func testTransportRejectsSignInContentLength(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Length", "+3") w.Write([]byte("abc")) - })) - defer cst.Close() + })).ts c := cst.Client() res, err := c.Get(cst.URL) @@ -6477,14 +6462,16 @@ func TestErrorWriteLoopRace(t *testing.T) { // Test that a new request which uses the connection of an active request // cannot cause it to be canceled as well. func TestCancelRequestWhenSharingConnection(t *testing.T) { + run(t, testCancelRequestWhenSharingConnection, []testMode{http1Mode}) +} +func testCancelRequestWhenSharingConnection(t *testing.T, mode testMode) { reqc := make(chan chan struct{}, 2) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) { ch := make(chan struct{}, 1) reqc <- ch <-ch w.Header().Add("Content-Length", "0") - })) - defer ts.Close() + })).ts client := ts.Client() transport := client.Transport.(*Transport) @@ -6549,15 +6536,12 @@ func TestCancelRequestWhenSharingConnection(t *testing.T) { wg.Wait() } -func TestHandlerAbortRacesBodyRead(t *testing.T) { - setParallel(t) - defer afterTest(t) - - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { +func TestHandlerAbortRacesBodyRead(t *testing.T) { run(t, testHandlerAbortRacesBodyRead) } +func testHandlerAbortRacesBodyRead(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { go io.Copy(io.Discard, req.Body) panic(ErrAbortHandler) - })) - defer ts.Close() + })).ts var wg sync.WaitGroup for i := 0; i < 2; i++ {