diff --git a/src/net/http/clientserver_test.go b/src/net/http/clientserver_test.go index 0873038757..32d97ea9f0 100644 --- a/src/net/http/clientserver_test.go +++ b/src/net/http/clientserver_test.go @@ -40,9 +40,10 @@ import ( type testMode string const ( - http1Mode = testMode("h1") // HTTP/1.1 - https1Mode = testMode("https1") // HTTPS/1.1 - http2Mode = testMode("h2") // HTTP/2 + http1Mode = testMode("h1") // HTTP/1.1 + https1Mode = testMode("https1") // HTTPS/1.1 + http2Mode = testMode("h2") // HTTP/2 + http2UnencryptedMode = testMode("h2unencrypted") // HTTP/2 ) type testNotParallelOpt struct{} @@ -132,6 +133,7 @@ type clientServerTest struct { ts *httptest.Server tr *Transport c *Client + li *fakeNetListener } func (t *clientServerTest) close() { @@ -169,6 +171,8 @@ func optWithServerLog(lg *log.Logger) func(*httptest.Server) { } } +var optFakeNet = new(struct{}) + // newClientServerTest creates and starts an httptest.Server. // // The mode parameter selects the implementation to test: @@ -180,6 +184,9 @@ func optWithServerLog(lg *log.Logger) func(*httptest.Server) { // // func(*httptest.Server) // run before starting the server // func(*http.Transport) +// +// The optFakeNet option configures the server and client to use a fake network implementation, +// suitable for use in testing/synctest tests. func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *clientServerTest { if mode == http2Mode { CondSkipHTTP2(t) @@ -189,9 +196,31 @@ func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *c h2: mode == http2Mode, h: h, } - cst.ts = httptest.NewUnstartedServer(h) var transportFuncs []func(*Transport) + + if idx := slices.Index(opts, any(optFakeNet)); idx >= 0 { + opts = slices.Delete(opts, idx, idx+1) + cst.li = fakeNetListen() + cst.ts = &httptest.Server{ + Config: &Server{Handler: h}, + Listener: cst.li, + } + transportFuncs = append(transportFuncs, func(tr *Transport) { + tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + return cst.li.connect(), nil + } + }) + } else { + cst.ts = httptest.NewUnstartedServer(h) + } + + if mode == http2UnencryptedMode { + p := &Protocols{} + p.SetUnencryptedHTTP2(true) + cst.ts.Config.Protocols = p + } + for _, opt := range opts { switch opt := opt.(type) { case func(*Transport): @@ -212,6 +241,9 @@ func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *c cst.ts.Start() case https1Mode: cst.ts.StartTLS() + case http2UnencryptedMode: + ExportHttp2ConfigureServer(cst.ts.Config, nil) + cst.ts.Start() case http2Mode: ExportHttp2ConfigureServer(cst.ts.Config, nil) cst.ts.TLS = cst.ts.Config.TLSConfig @@ -221,7 +253,7 @@ func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *c } cst.c = cst.ts.Client() cst.tr = cst.c.Transport.(*Transport) - if mode == http2Mode { + if mode == http2Mode || mode == http2UnencryptedMode { if err := ExportHttp2ConfigureTransport(cst.tr); err != nil { t.Fatal(err) } @@ -229,6 +261,13 @@ func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *c for _, f := range transportFuncs { f(cst.tr) } + + if mode == http2UnencryptedMode { + p := &Protocols{} + p.SetUnencryptedHTTP2(true) + cst.tr.Protocols = p + } + t.Cleanup(func() { cst.close() }) @@ -246,9 +285,19 @@ func (w testLogWriter) Write(b []byte) (int, error) { // Testing the newClientServerTest helper itself. func TestNewClientServerTest(t *testing.T) { - run(t, testNewClientServerTest, []testMode{http1Mode, https1Mode, http2Mode}) + modes := []testMode{http1Mode, https1Mode, http2Mode} + t.Run("realnet", func(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testNewClientServerTest(t, mode) + }, modes) + }) + t.Run("synctest", func(t *testing.T) { + runSynctest(t, func(t testing.TB, mode testMode) { + testNewClientServerTest(t, mode, optFakeNet) + }, modes) + }) } -func testNewClientServerTest(t *testing.T, mode testMode) { +func testNewClientServerTest(t testing.TB, mode testMode, opts ...any) { var got struct { sync.Mutex proto string @@ -260,7 +309,7 @@ func testNewClientServerTest(t *testing.T, mode testMode) { got.proto = r.Proto got.hasTLS = r.TLS != nil }) - cst := newClientServerTest(t, mode, h) + cst := newClientServerTest(t, mode, h, opts...) if _, err := cst.c.Head(cst.ts.URL); err != nil { t.Fatal(err) } diff --git a/src/net/http/netconn_test.go b/src/net/http/netconn_test.go index 251b919f67..ed02b98d43 100644 --- a/src/net/http/netconn_test.go +++ b/src/net/http/netconn_test.go @@ -19,9 +19,10 @@ import ( func fakeNetListen() *fakeNetListener { li := &fakeNetListener{ - setc: make(chan struct{}, 1), - unsetc: make(chan struct{}, 1), - addr: net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:8000")), + setc: make(chan struct{}, 1), + unsetc: make(chan struct{}, 1), + addr: netip.MustParseAddrPort("127.0.0.1:8000"), + locPort: 10000, } li.unsetc <- struct{}{} return li @@ -31,7 +32,13 @@ type fakeNetListener struct { setc, unsetc chan struct{} queue []net.Conn closed bool - addr net.Addr + addr netip.AddrPort + locPort uint16 + + onDial func() // called when making a new connection + + trackConns bool // set this to record all created conns + conns []*fakeNetConn } func (li *fakeNetListener) lock() { @@ -50,10 +57,18 @@ func (li *fakeNetListener) unlock() { } func (li *fakeNetListener) connect() *fakeNetConn { + if li.onDial != nil { + li.onDial() + } li.lock() defer li.unlock() - c0, c1 := fakeNetPipe() + locAddr := netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), li.locPort) + li.locPort++ + c0, c1 := fakeNetPipe(li.addr, locAddr) li.queue = append(li.queue, c0) + if li.trackConns { + li.conns = append(li.conns, c0) + } return c1 } @@ -76,7 +91,7 @@ func (li *fakeNetListener) Close() error { } func (li *fakeNetListener) Addr() net.Addr { - return li.addr + return net.TCPAddrFromAddrPort(li.addr) } // fakeNetPipe creates an in-memory, full duplex network connection. @@ -84,13 +99,16 @@ func (li *fakeNetListener) Addr() net.Addr { // Unlike net.Pipe, the connection is not synchronous. // Writes are made to a buffer, and return immediately. // By default, the buffer size is unlimited. -func fakeNetPipe() (r, w *fakeNetConn) { - s1addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:8000")) - s2addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:8001")) +func fakeNetPipe(s1ap, s2ap netip.AddrPort) (r, w *fakeNetConn) { + s1addr := net.TCPAddrFromAddrPort(s1ap) + s2addr := net.TCPAddrFromAddrPort(s2ap) s1 := newSynctestNetConnHalf(s1addr) s2 := newSynctestNetConnHalf(s2addr) - return &fakeNetConn{loc: s1, rem: s2}, - &fakeNetConn{loc: s2, rem: s1} + c1 := &fakeNetConn{loc: s1, rem: s2} + c2 := &fakeNetConn{loc: s2, rem: s1} + c1.peer = c2 + c2.peer = c1 + return c1, c2 } // A fakeNetConn is one endpoint of the connection created by fakeNetPipe. @@ -102,6 +120,11 @@ type fakeNetConn struct { // When set, synctest.Wait is automatically called before reads and after writes. autoWait bool + + // peer is the other endpoint. + peer *fakeNetConn + + onClose func() // called when closing } // Read reads data from the connection. @@ -143,6 +166,9 @@ func (c *fakeNetConn) IsClosedByPeer() bool { // Close closes the connection. func (c *fakeNetConn) Close() error { + if c.onClose != nil { + c.onClose() + } // Local half of the conn is now closed. c.loc.lock() c.loc.writeErr = net.ErrClosed diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index d742b78cf8..2963255b87 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -22,6 +22,7 @@ import ( "fmt" "go/token" "internal/nettrace" + "internal/synctest" "io" "log" mrand "math/rand" @@ -4219,53 +4220,175 @@ func TestTransportTraceGotConnH2IdleConns(t *testing.T) { wantIdle("after round trip", 1) } -func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) { - run(t, testTransportRemovesH2ConnsAfterIdle, []testMode{http2Mode}) +// https://go.dev/issue/70515 +// +// When the first request on a new connection fails, we do not retry the request. +// If the first request on a connection races with IdleConnTimeout, +// we should not fail the request. +func TestTransportIdleConnRacesRequest(t *testing.T) { + // Use unencrypted HTTP/2, since the *tls.Conn interfers with our ability to + // block the connection closing. + runSynctest(t, testTransportIdleConnRacesRequest, []testMode{http1Mode, http2UnencryptedMode}) } -func testTransportRemovesH2ConnsAfterIdle(t *testing.T, mode testMode) { +func testTransportIdleConnRacesRequest(t testing.TB, mode testMode) { + if mode == http2UnencryptedMode { + t.Skip("remove skip when #70515 is fixed") + } + timeout := 1 * time.Millisecond + trFunc := func(tr *Transport) { + tr.IdleConnTimeout = timeout + } + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + }), trFunc, optFakeNet) + cst.li.trackConns = true + + // We want to put a connection into the pool which has never had a request made on it. + // + // Make a request and cancel it before the dial completes. + // Then complete the dial. + dialc := make(chan struct{}) + cst.li.onDial = func() { + <-dialc + } + ctx, cancel := context.WithCancel(context.Background()) + req1c := make(chan error) + go func() { + req, _ := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil) + resp, err := cst.c.Do(req) + if err == nil { + resp.Body.Close() + } + req1c <- err + }() + // Wait for the connection attempt to start. + synctest.Wait() + // Cancel the request. + cancel() + synctest.Wait() + if err := <-req1c; err == nil { + t.Fatal("expected request to fail, but it succeeded") + } + // Unblock the dial, placing a new, unused connection into the Transport's pool. + close(dialc) + + // We want IdleConnTimeout to race with a new request. + // + // There's no perfect way to do this, but the following exercises the bug in #70515: + // Block net.Conn.Close, wait until IdleConnTimeout occurs, and make a request while + // the connection close is still blocked. + // + // First: Wait for IdleConnTimeout. The net.Conn.Close blocks. + synctest.Wait() + closec := make(chan struct{}) + cst.li.conns[0].peer.onClose = func() { + <-closec + } + time.Sleep(timeout) + synctest.Wait() + // Make a request, which will use a new connection (since the existing one is closing). + req2c := make(chan error) + go func() { + resp, err := cst.c.Get(cst.ts.URL) + if err == nil { + resp.Body.Close() + } + req2c <- err + }() + // Don't synctest.Wait here: The HTTP/1 transport closes the idle conn + // with a mutex held, and we'll end up in a deadlock. + close(closec) + if err := <-req2c; err != nil { + t.Fatalf("Get: %v", err) + } +} + +func TestTransportRemovesConnsAfterIdle(t *testing.T) { + runSynctest(t, testTransportRemovesConnsAfterIdle) +} +func testTransportRemovesConnsAfterIdle(t testing.TB, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } - timeout := 1 * time.Millisecond - retry := true - for retry { - trFunc := func(tr *Transport) { - tr.MaxConnsPerHost = 1 - tr.MaxIdleConnsPerHost = 1 - tr.IdleConnTimeout = timeout - } - cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), trFunc) + timeout := 1 * time.Second + trFunc := func(tr *Transport) { + tr.MaxConnsPerHost = 1 + tr.MaxIdleConnsPerHost = 1 + tr.IdleConnTimeout = timeout + } + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("X-Addr", r.RemoteAddr) + }), trFunc, optFakeNet) - retry = false - tooShort := func(err error) bool { - if err == nil || !strings.Contains(err.Error(), "use of closed network connection") { - return false - } - if !retry { - t.Helper() - t.Logf("idle conn timeout %v may be too short; retrying with longer", timeout) - timeout *= 2 - retry = true - cst.close() - } - return true - } - - if _, err := cst.c.Get(cst.ts.URL); err != nil { - if tooShort(err) { - continue - } + // makeRequest returns the local address a request was made from + // (unique for each connection). + makeRequest := func() string { + resp, err := cst.c.Get(cst.ts.URL) + if err != nil { t.Fatalf("got error: %s", err) } + resp.Body.Close() + return resp.Header.Get("X-Addr") + } - time.Sleep(10 * timeout) - if _, err := cst.c.Get(cst.ts.URL); err != nil { - if tooShort(err) { - continue - } + addr1 := makeRequest() + + time.Sleep(timeout / 2) + synctest.Wait() + addr2 := makeRequest() + if addr1 != addr2 { + t.Fatalf("two requests made within IdleConnTimeout should have used the same conn, but used %v, %v", addr1, addr2) + } + + time.Sleep(timeout) + synctest.Wait() + addr3 := makeRequest() + if addr1 == addr3 { + t.Fatalf("two requests made more than IdleConnTimeout apart should have used different conns, but used %v, %v", addr1, addr3) + } +} + +func TestTransportRemovesConnsAfterBroken(t *testing.T) { + runSynctest(t, testTransportRemovesConnsAfterBroken) +} +func testTransportRemovesConnsAfterBroken(t testing.TB, mode testMode) { + if testing.Short() { + t.Skip("skipping in short mode") + } + + trFunc := func(tr *Transport) { + tr.MaxConnsPerHost = 1 + tr.MaxIdleConnsPerHost = 1 + } + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("X-Addr", r.RemoteAddr) + }), trFunc, optFakeNet) + cst.li.trackConns = true + + // makeRequest returns the local address a request was made from + // (unique for each connection). + makeRequest := func() string { + resp, err := cst.c.Get(cst.ts.URL) + if err != nil { t.Fatalf("got error: %s", err) } + resp.Body.Close() + return resp.Header.Get("X-Addr") + } + + addr1 := makeRequest() + addr2 := makeRequest() + if addr1 != addr2 { + t.Fatalf("successive requests should have used the same conn, but used %v, %v", addr1, addr2) + } + + // The connection breaks. + synctest.Wait() + cst.li.conns[0].peer.Close() + synctest.Wait() + addr3 := makeRequest() + if addr1 == addr3 { + t.Fatalf("successive requests made with conn broken between should have used different conns, but used %v, %v", addr1, addr3) } }