From a600540c0617ac0c6dd6bdb1407a1b20ad435d82 Mon Sep 17 00:00:00 2001 From: Charlie Getzen Date: Fri, 5 Nov 2021 11:32:58 -0500 Subject: [PATCH] tests use single server for multiple requests --- src/net/http/export_test.go | 5 ++--- src/net/http/serve_test.go | 42 +++++++++++++++++++++++-------------- 2 files changed, 28 insertions(+), 19 deletions(-) diff --git a/src/net/http/export_test.go b/src/net/http/export_test.go index d633ca26d5..a849327f45 100644 --- a/src/net/http/export_test.go +++ b/src/net/http/export_test.go @@ -88,13 +88,12 @@ func SetPendingDialHooks(before, after func()) { func SetTestHookServerServe(fn func(*Server, net.Listener)) { testHookServerServe = fn } -func NewTestTimeoutHandler(handler Handler, t time.Duration) (Handler, context.CancelFunc) { - ctx, cancel := context.WithTimeout(context.Background(), t) +func NewTestTimeoutHandler(handler Handler, ctx context.Context) Handler { return &timeoutHandler{ handler: handler, testContext: ctx, // (no body) - }, cancel + } } func ResetCachedEnvironment() { diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index b1083c7213..e320838b9c 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -2273,6 +2273,18 @@ func TestRequestBodyTimeoutClosesConnection(t *testing.T) { } } +// cancelableTimeoutContext overwrites the error message to DeadlineExceeded +type cancelableTimeoutContext struct { + context.Context +} + +func (c cancelableTimeoutContext) Err() error { + if c.Context.Err() != nil { + return context.DeadlineExceeded + } + return nil +} + func TestTimeoutHandler_h1(t *testing.T) { testTimeoutHandler(t, h1Mode) } func TestTimeoutHandler_h2(t *testing.T) { testTimeoutHandler(t, h2Mode) } func testTimeoutHandler(t *testing.T, h2 bool) { @@ -2285,8 +2297,10 @@ func testTimeoutHandler(t *testing.T, h2 bool) { _, werr := w.Write([]byte("hi")) writeErrors <- werr }) - h, cancel := NewTestTimeoutHandler(sayHi, 1*time.Second) + ctx, cancel := context.WithCancel(context.Background()) + h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx}) cst := newClientServerTest(t, h2, h) + defer cst.close() // Succeed without timing out: sendHi <- true @@ -2305,14 +2319,9 @@ func testTimeoutHandler(t *testing.T, h2 bool) { t.Errorf("got unexpected Write error on first request: %v", g) } - cancel() - cst.close() - // Times out: - h, cancel = NewTestTimeoutHandler(sayHi, 0*time.Second) - defer cancel() - cst = newClientServerTest(t, h1Mode, h) - defer cst.close() + cancel() + res, err = cst.c.Get(cst.ts.URL) if err != nil { t.Error(err) @@ -2433,8 +2442,10 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { _, werr := w.Write([]byte("hi")) writeErrors <- werr }) - h, cancel := NewTestTimeoutHandler(sayHi, 1*time.Second) + ctx, cancel := context.WithCancel(context.Background()) + h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx}) cst := newClientServerTest(t, h1Mode, h) + defer cst.close() // Succeed without timing out: sendHi <- true @@ -2452,14 +2463,9 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { if g := <-writeErrors; g != nil { t.Errorf("got unexpected Write error on first request: %v", g) } - cancel() - cst.close() // Times out: - h, cancel = NewTestTimeoutHandler(sayHi, 0*time.Second) - defer cancel() - cst = newClientServerTest(t, h1Mode, h) - defer cst.close() + cancel() res, err = cst.c.Get(cst.ts.URL) if err != nil { @@ -2521,7 +2527,8 @@ func TestTimeoutHandlerContextCanceled(t *testing.T) { _, werr := w.Write([]byte("hi")) writeErrors <- werr }) - h, cancel := NewTestTimeoutHandler(sayHi, 1*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Hour) + h := NewTestTimeoutHandler(sayHi, ctx) cancel() cst := newClientServerTest(t, h1Mode, h) defer cst.close() @@ -2532,6 +2539,9 @@ func TestTimeoutHandlerContextCanceled(t *testing.T) { if err != nil { t.Error(err) } + if g, e := res.StatusCode, StatusServiceUnavailable; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } body, _ := io.ReadAll(res.Body) if g, e := string(body), ""; g != e { t.Errorf("got body %q; expected %q", g, e)