mirror of https://github.com/golang/go.git
7372 lines
199 KiB
Go
7372 lines
199 KiB
Go
// Copyright 2011 The Go Authors. All rights reserved.
|
||
// Use of this source code is governed by a BSD-style
|
||
// license that can be found in the LICENSE file.
|
||
|
||
// Tests for transport.go.
|
||
//
|
||
// More tests are in clientserver_test.go (for things testing both client & server for both
|
||
// HTTP/1 and HTTP/2). This
|
||
|
||
package http_test
|
||
|
||
import (
|
||
"bufio"
|
||
"bytes"
|
||
"compress/gzip"
|
||
"context"
|
||
"crypto/rand"
|
||
"crypto/tls"
|
||
"crypto/x509"
|
||
"encoding/binary"
|
||
"errors"
|
||
"fmt"
|
||
"go/token"
|
||
"internal/nettrace"
|
||
"io"
|
||
"log"
|
||
mrand "math/rand"
|
||
"net"
|
||
. "net/http"
|
||
"net/http/httptest"
|
||
"net/http/httptrace"
|
||
"net/http/httputil"
|
||
"net/http/internal/testcert"
|
||
"net/textproto"
|
||
"net/url"
|
||
"os"
|
||
"reflect"
|
||
"runtime"
|
||
"slices"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
"sync/atomic"
|
||
"testing"
|
||
"testing/iotest"
|
||
"time"
|
||
|
||
"golang.org/x/net/http/httpguts"
|
||
)
|
||
|
||
// TODO: test 5 pipelined requests with responses: 1) OK, 2) OK, Connection: Close
|
||
// and then verify that the final 2 responses get errors back.
|
||
|
||
// hostPortHandler writes back the client's "host:port".
|
||
var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
if r.FormValue("close") == "true" {
|
||
w.Header().Set("Connection", "close")
|
||
}
|
||
w.Header().Set("X-Saw-Close", fmt.Sprint(r.Close))
|
||
w.Write([]byte(r.RemoteAddr))
|
||
|
||
// Include the address of the net.Conn in addition to the RemoteAddr,
|
||
// in case kernels reuse source ports quickly (see Issue 52450)
|
||
if c, ok := ResponseWriterConnForTesting(w); ok {
|
||
fmt.Fprintf(w, ", %T %p", c, c)
|
||
}
|
||
})
|
||
|
||
// testCloseConn is a net.Conn tracked by a testConnSet.
|
||
type testCloseConn struct {
|
||
net.Conn
|
||
set *testConnSet
|
||
}
|
||
|
||
func (c *testCloseConn) Close() error {
|
||
c.set.remove(c)
|
||
return c.Conn.Close()
|
||
}
|
||
|
||
// testConnSet tracks a set of TCP connections and whether they've
|
||
// been closed.
|
||
type testConnSet struct {
|
||
t *testing.T
|
||
mu sync.Mutex // guards closed and list
|
||
closed map[net.Conn]bool
|
||
list []net.Conn // in order created
|
||
}
|
||
|
||
func (tcs *testConnSet) insert(c net.Conn) {
|
||
tcs.mu.Lock()
|
||
defer tcs.mu.Unlock()
|
||
tcs.closed[c] = false
|
||
tcs.list = append(tcs.list, c)
|
||
}
|
||
|
||
func (tcs *testConnSet) remove(c net.Conn) {
|
||
tcs.mu.Lock()
|
||
defer tcs.mu.Unlock()
|
||
tcs.closed[c] = true
|
||
}
|
||
|
||
// some tests use this to manage raw tcp connections for later inspection
|
||
func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) {
|
||
connSet := &testConnSet{
|
||
t: t,
|
||
closed: make(map[net.Conn]bool),
|
||
}
|
||
dial := func(n, addr string) (net.Conn, error) {
|
||
c, err := net.Dial(n, addr)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
tc := &testCloseConn{c, connSet}
|
||
connSet.insert(tc)
|
||
return tc, nil
|
||
}
|
||
return connSet, dial
|
||
}
|
||
|
||
func (tcs *testConnSet) check(t *testing.T) {
|
||
tcs.mu.Lock()
|
||
defer tcs.mu.Unlock()
|
||
for i := 4; i >= 0; i-- {
|
||
for i, c := range tcs.list {
|
||
if tcs.closed[c] {
|
||
continue
|
||
}
|
||
if i != 0 {
|
||
// TODO(bcmills): What is the Sleep here doing, and why is this
|
||
// Unlock/Sleep/Lock cycle needed at all?
|
||
tcs.mu.Unlock()
|
||
time.Sleep(50 * time.Millisecond)
|
||
tcs.mu.Lock()
|
||
continue
|
||
}
|
||
t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list))
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestReuseRequest(t *testing.T) { run(t, testReuseRequest) }
|
||
func testReuseRequest(t *testing.T, mode testMode) {
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
w.Write([]byte("{}"))
|
||
})).ts
|
||
|
||
c := ts.Client()
|
||
req, _ := NewRequest("GET", ts.URL, nil)
|
||
res, err := c.Do(req)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
err = res.Body.Close()
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
res, err = c.Do(req)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
err = res.Body.Close()
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
}
|
||
|
||
// Two subsequent requests and verify their response is the same.
|
||
// The response from the server is our own IP:port
|
||
func TestTransportKeepAlives(t *testing.T) { run(t, testTransportKeepAlives, []testMode{http1Mode}) }
|
||
func testTransportKeepAlives(t *testing.T, mode testMode) {
|
||
ts := newClientServerTest(t, mode, hostPortHandler).ts
|
||
|
||
c := ts.Client()
|
||
for _, disableKeepAlive := range []bool{false, true} {
|
||
c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive
|
||
fetch := func(n int) string {
|
||
res, err := c.Get(ts.URL)
|
||
if err != nil {
|
||
t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err)
|
||
}
|
||
body, err := io.ReadAll(res.Body)
|
||
if err != nil {
|
||
t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err)
|
||
}
|
||
return string(body)
|
||
}
|
||
|
||
body1 := fetch(1)
|
||
body2 := fetch(2)
|
||
|
||
bodiesDiffer := body1 != body2
|
||
if bodiesDiffer != disableKeepAlive {
|
||
t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
|
||
disableKeepAlive, bodiesDiffer, body1, body2)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestTransportConnectionCloseOnResponse(t *testing.T) {
|
||
run(t, testTransportConnectionCloseOnResponse)
|
||
}
|
||
func testTransportConnectionCloseOnResponse(t *testing.T, mode testMode) {
|
||
ts := newClientServerTest(t, mode, hostPortHandler).ts
|
||
|
||
connSet, testDial := makeTestDial(t)
|
||
|
||
c := ts.Client()
|
||
tr := c.Transport.(*Transport)
|
||
tr.Dial = testDial
|
||
|
||
for _, connectionClose := range []bool{false, true} {
|
||
fetch := func(n int) string {
|
||
req := new(Request)
|
||
var err error
|
||
req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose))
|
||
if err != nil {
|
||
t.Fatalf("URL parse error: %v", err)
|
||
}
|
||
req.Method = "GET"
|
||
req.Proto = "HTTP/1.1"
|
||
req.ProtoMajor = 1
|
||
req.ProtoMinor = 1
|
||
|
||
res, err := c.Do(req)
|
||
if err != nil {
|
||
t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
|
||
}
|
||
defer res.Body.Close()
|
||
body, err := io.ReadAll(res.Body)
|
||
if err != nil {
|
||
t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
|
||
}
|
||
return string(body)
|
||
}
|
||
|
||
body1 := fetch(1)
|
||
body2 := fetch(2)
|
||
bodiesDiffer := body1 != body2
|
||
if bodiesDiffer != connectionClose {
|
||
t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
|
||
connectionClose, bodiesDiffer, body1, body2)
|
||
}
|
||
|
||
tr.CloseIdleConnections()
|
||
}
|
||
|
||
connSet.check(t)
|
||
}
|
||
|
||
// TestTransportConnectionCloseOnRequest tests that the Transport's doesn't reuse
|
||
// an underlying TCP connection after making an http.Request with Request.Close set.
|
||
//
|
||
// It tests the behavior by making an HTTP request to a server which
|
||
// describes the source connection it got (remote port number +
|
||
// address of its net.Conn).
|
||
func TestTransportConnectionCloseOnRequest(t *testing.T) {
|
||
run(t, testTransportConnectionCloseOnRequest, []testMode{http1Mode})
|
||
}
|
||
func testTransportConnectionCloseOnRequest(t *testing.T, mode testMode) {
|
||
ts := newClientServerTest(t, mode, hostPortHandler).ts
|
||
|
||
connSet, testDial := makeTestDial(t)
|
||
|
||
c := ts.Client()
|
||
tr := c.Transport.(*Transport)
|
||
tr.Dial = testDial
|
||
for _, reqClose := range []bool{false, true} {
|
||
fetch := func(n int) string {
|
||
req := new(Request)
|
||
var err error
|
||
req.URL, err = url.Parse(ts.URL)
|
||
if err != nil {
|
||
t.Fatalf("URL parse error: %v", err)
|
||
}
|
||
req.Method = "GET"
|
||
req.Proto = "HTTP/1.1"
|
||
req.ProtoMajor = 1
|
||
req.ProtoMinor = 1
|
||
req.Close = reqClose
|
||
|
||
res, err := c.Do(req)
|
||
if err != nil {
|
||
t.Fatalf("error in Request.Close=%v, req #%d, Do: %v", reqClose, n, err)
|
||
}
|
||
if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(reqClose); got != want {
|
||
t.Errorf("for Request.Close = %v; handler's X-Saw-Close was %v; want %v",
|
||
reqClose, got, !reqClose)
|
||
}
|
||
body, err := io.ReadAll(res.Body)
|
||
if err != nil {
|
||
t.Fatalf("for Request.Close=%v, on request %v/2: ReadAll: %v", reqClose, n, err)
|
||
}
|
||
return string(body)
|
||
}
|
||
|
||
body1 := fetch(1)
|
||
body2 := fetch(2)
|
||
|
||
got := 1
|
||
if body1 != body2 {
|
||
got++
|
||
}
|
||
want := 1
|
||
if reqClose {
|
||
want = 2
|
||
}
|
||
if got != want {
|
||
t.Errorf("for Request.Close=%v: server saw %v unique connections, wanted %v\n\nbodies were: %q and %q",
|
||
reqClose, got, want, body1, body2)
|
||
}
|
||
|
||
tr.CloseIdleConnections()
|
||
}
|
||
|
||
connSet.check(t)
|
||
}
|
||
|
||
// if the Transport's DisableKeepAlives is set, all requests should
|
||
// send Connection: close.
|
||
// HTTP/1-only (Connection: close doesn't exist in h2)
|
||
func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) {
|
||
run(t, testTransportConnectionCloseOnRequestDisableKeepAlive, []testMode{http1Mode})
|
||
}
|
||
func testTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T, mode testMode) {
|
||
ts := newClientServerTest(t, mode, hostPortHandler).ts
|
||
|
||
c := ts.Client()
|
||
c.Transport.(*Transport).DisableKeepAlives = true
|
||
|
||
res, err := c.Get(ts.URL)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
res.Body.Close()
|
||
if res.Header.Get("X-Saw-Close") != "true" {
|
||
t.Errorf("handler didn't see Connection: close ")
|
||
}
|
||
}
|
||
|
||
// Test that Transport only sends one "Connection: close", regardless of
|
||
// how "close" was indicated.
|
||
func TestTransportRespectRequestWantsClose(t *testing.T) {
|
||
run(t, testTransportRespectRequestWantsClose, []testMode{http1Mode})
|
||
}
|
||
func testTransportRespectRequestWantsClose(t *testing.T, mode testMode) {
|
||
tests := []struct {
|
||
disableKeepAlives bool
|
||
close bool
|
||
}{
|
||
{disableKeepAlives: false, close: false},
|
||
{disableKeepAlives: false, close: true},
|
||
{disableKeepAlives: true, close: false},
|
||
{disableKeepAlives: true, close: true},
|
||
}
|
||
|
||
for _, tc := range tests {
|
||
t.Run(fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", tc.disableKeepAlives, tc.close),
|
||
func(t *testing.T) {
|
||
ts := newClientServerTest(t, mode, hostPortHandler).ts
|
||
|
||
c := ts.Client()
|
||
c.Transport.(*Transport).DisableKeepAlives = tc.disableKeepAlives
|
||
req, err := NewRequest("GET", ts.URL, nil)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
count := 0
|
||
trace := &httptrace.ClientTrace{
|
||
WroteHeaderField: func(key string, field []string) {
|
||
if key != "Connection" {
|
||
return
|
||
}
|
||
if httpguts.HeaderValuesContainsToken(field, "close") {
|
||
count += 1
|
||
}
|
||
},
|
||
}
|
||
req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
|
||
req.Close = tc.close
|
||
res, err := c.Do(req)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
defer res.Body.Close()
|
||
if want := tc.disableKeepAlives || tc.close; count > 1 || (count == 1) != want {
|
||
t.Errorf("expecting want:%v, got 'Connection: close':%d", want, count)
|
||
}
|
||
})
|
||
}
|
||
|
||
}
|
||
|
||
func TestTransportIdleCacheKeys(t *testing.T) {
|
||
run(t, testTransportIdleCacheKeys, []testMode{http1Mode})
|
||
}
|
||
func testTransportIdleCacheKeys(t *testing.T, mode testMode) {
|
||
ts := newClientServerTest(t, mode, hostPortHandler).ts
|
||
c := ts.Client()
|
||
tr := c.Transport.(*Transport)
|
||
|
||
if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
|
||
t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
|
||
}
|
||
|
||
resp, err := c.Get(ts.URL)
|
||
if err != nil {
|
||
t.Error(err)
|
||
}
|
||
io.ReadAll(resp.Body)
|
||
|
||
keys := tr.IdleConnKeysForTesting()
|
||
if e, g := 1, len(keys); e != g {
|
||
t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g)
|
||
}
|
||
|
||
if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e {
|
||
t.Errorf("Expected idle cache key %q; got %q", e, keys[0])
|
||
}
|
||
|
||
tr.CloseIdleConnections()
|
||
if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
|
||
t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
|
||
}
|
||
}
|
||
|
||
// Tests that the HTTP transport re-uses connections when a client
|
||
// reads to the end of a response Body without closing it.
|
||
func TestTransportReadToEndReusesConn(t *testing.T) { run(t, testTransportReadToEndReusesConn) }
|
||
func testTransportReadToEndReusesConn(t *testing.T, mode testMode) {
|
||
const msg = "foobar"
|
||
|
||
var addrSeen map[string]int
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
addrSeen[r.RemoteAddr]++
|
||
if r.URL.Path == "/chunked/" {
|
||
w.WriteHeader(200)
|
||
w.(Flusher).Flush()
|
||
} else {
|
||
w.Header().Set("Content-Length", strconv.Itoa(len(msg)))
|
||
w.WriteHeader(200)
|
||
}
|
||
w.Write([]byte(msg))
|
||
})).ts
|
||
|
||
for pi, path := range []string{"/content-length/", "/chunked/"} {
|
||
wantLen := []int{len(msg), -1}[pi]
|
||
addrSeen = make(map[string]int)
|
||
for i := 0; i < 3; i++ {
|
||
res, err := ts.Client().Get(ts.URL + path)
|
||
if err != nil {
|
||
t.Errorf("Get %s: %v", path, err)
|
||
continue
|
||
}
|
||
// We want to close this body eventually (before the
|
||
// defer afterTest at top runs), but not before the
|
||
// len(addrSeen) check at the bottom of this test,
|
||
// since Closing this early in the loop would risk
|
||
// making connections be re-used for the wrong reason.
|
||
defer res.Body.Close()
|
||
|
||
if res.ContentLength != int64(wantLen) {
|
||
t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen)
|
||
}
|
||
got, err := io.ReadAll(res.Body)
|
||
if string(got) != msg || err != nil {
|
||
t.Errorf("%s ReadAll(Body) = %q, %v; want %q, nil", path, string(got), err, msg)
|
||
}
|
||
}
|
||
if len(addrSeen) != 1 {
|
||
t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen))
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestTransportMaxPerHostIdleConns(t *testing.T) {
|
||
run(t, testTransportMaxPerHostIdleConns, []testMode{http1Mode})
|
||
}
|
||
func testTransportMaxPerHostIdleConns(t *testing.T, mode testMode) {
|
||
stop := make(chan struct{}) // stop marks the exit of main Test goroutine
|
||
defer close(stop)
|
||
|
||
resch := make(chan string)
|
||
gotReq := make(chan bool)
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
gotReq <- true
|
||
var msg string
|
||
select {
|
||
case <-stop:
|
||
return
|
||
case msg = <-resch:
|
||
}
|
||
_, err := w.Write([]byte(msg))
|
||
if err != nil {
|
||
t.Errorf("Write: %v", err)
|
||
return
|
||
}
|
||
})).ts
|
||
|
||
c := ts.Client()
|
||
tr := c.Transport.(*Transport)
|
||
maxIdleConnsPerHost := 2
|
||
tr.MaxIdleConnsPerHost = maxIdleConnsPerHost
|
||
|
||
// Start 3 outstanding requests and wait for the server to get them.
|
||
// Their responses will hang until we write to resch, though.
|
||
donech := make(chan bool)
|
||
doReq := func() {
|
||
defer func() {
|
||
select {
|
||
case <-stop:
|
||
return
|
||
case donech <- t.Failed():
|
||
}
|
||
}()
|
||
resp, err := c.Get(ts.URL)
|
||
if err != nil {
|
||
t.Error(err)
|
||
return
|
||
}
|
||
if _, err := io.ReadAll(resp.Body); err != nil {
|
||
t.Errorf("ReadAll: %v", err)
|
||
return
|
||
}
|
||
}
|
||
go doReq()
|
||
<-gotReq
|
||
go doReq()
|
||
<-gotReq
|
||
go doReq()
|
||
<-gotReq
|
||
|
||
if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
|
||
t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g)
|
||
}
|
||
|
||
resch <- "res1"
|
||
<-donech
|
||
keys := tr.IdleConnKeysForTesting()
|
||
if e, g := 1, len(keys); e != g {
|
||
t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g)
|
||
}
|
||
addr := ts.Listener.Addr().String()
|
||
cacheKey := "|http|" + addr
|
||
if keys[0] != cacheKey {
|
||
t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0])
|
||
}
|
||
if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g {
|
||
t.Errorf("after first response, expected %d idle conns; got %d", e, g)
|
||
}
|
||
|
||
resch <- "res2"
|
||
<-donech
|
||
if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w {
|
||
t.Errorf("after second response, idle conns = %d; want %d", g, w)
|
||
}
|
||
|
||
resch <- "res3"
|
||
<-donech
|
||
if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w {
|
||
t.Errorf("after third response, idle conns = %d; want %d", g, w)
|
||
}
|
||
}
|
||
|
||
func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
|
||
run(t, testTransportMaxConnsPerHostIncludeDialInProgress)
|
||
}
|
||
func testTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T, mode testMode) {
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
_, err := w.Write([]byte("foo"))
|
||
if err != nil {
|
||
t.Fatalf("Write: %v", err)
|
||
}
|
||
})).ts
|
||
c := ts.Client()
|
||
tr := c.Transport.(*Transport)
|
||
dialStarted := make(chan struct{})
|
||
stallDial := make(chan struct{})
|
||
tr.Dial = func(network, addr string) (net.Conn, error) {
|
||
dialStarted <- struct{}{}
|
||
<-stallDial
|
||
return net.Dial(network, addr)
|
||
}
|
||
|
||
tr.DisableKeepAlives = true
|
||
tr.MaxConnsPerHost = 1
|
||
|
||
preDial := make(chan struct{})
|
||
reqComplete := make(chan struct{})
|
||
doReq := func(reqId string) {
|
||
req, _ := NewRequest("GET", ts.URL, nil)
|
||
trace := &httptrace.ClientTrace{
|
||
GetConn: func(hostPort string) {
|
||
preDial <- struct{}{}
|
||
},
|
||
}
|
||
req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
|
||
resp, err := tr.RoundTrip(req)
|
||
if err != nil {
|
||
t.Errorf("unexpected error for request %s: %v", reqId, err)
|
||
}
|
||
_, err = io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
t.Errorf("unexpected error for request %s: %v", reqId, err)
|
||
}
|
||
reqComplete <- struct{}{}
|
||
}
|
||
// get req1 to dial-in-progress
|
||
go doReq("req1")
|
||
<-preDial
|
||
<-dialStarted
|
||
|
||
// get req2 to waiting on conns per host to go down below max
|
||
go doReq("req2")
|
||
<-preDial
|
||
select {
|
||
case <-dialStarted:
|
||
t.Error("req2 dial started while req1 dial in progress")
|
||
return
|
||
default:
|
||
}
|
||
|
||
// let req1 complete
|
||
stallDial <- struct{}{}
|
||
<-reqComplete
|
||
|
||
// let req2 complete
|
||
<-dialStarted
|
||
stallDial <- struct{}{}
|
||
<-reqComplete
|
||
}
|
||
|
||
func TestTransportMaxConnsPerHost(t *testing.T) {
|
||
run(t, testTransportMaxConnsPerHost, []testMode{http1Mode, https1Mode, http2Mode})
|
||
}
|
||
func testTransportMaxConnsPerHost(t *testing.T, mode testMode) {
|
||
CondSkipHTTP2(t)
|
||
|
||
h := HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
_, err := w.Write([]byte("foo"))
|
||
if err != nil {
|
||
t.Fatalf("Write: %v", err)
|
||
}
|
||
})
|
||
|
||
ts := newClientServerTest(t, mode, h).ts
|
||
c := ts.Client()
|
||
tr := c.Transport.(*Transport)
|
||
tr.MaxConnsPerHost = 1
|
||
|
||
mu := sync.Mutex{}
|
||
var conns []net.Conn
|
||
var dialCnt, gotConnCnt, tlsHandshakeCnt int32
|
||
tr.Dial = func(network, addr string) (net.Conn, error) {
|
||
atomic.AddInt32(&dialCnt, 1)
|
||
c, err := net.Dial(network, addr)
|
||
mu.Lock()
|
||
defer mu.Unlock()
|
||
conns = append(conns, c)
|
||
return c, err
|
||
}
|
||
|
||
doReq := func() {
|
||
trace := &httptrace.ClientTrace{
|
||
GotConn: func(connInfo httptrace.GotConnInfo) {
|
||
if !connInfo.Reused {
|
||
atomic.AddInt32(&gotConnCnt, 1)
|
||
}
|
||
},
|
||
TLSHandshakeStart: func() {
|
||
atomic.AddInt32(&tlsHandshakeCnt, 1)
|
||
},
|
||
}
|
||
req, _ := NewRequest("GET", ts.URL, nil)
|
||
req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
|
||
|
||
resp, err := c.Do(req)
|
||
if err != nil {
|
||
t.Fatalf("request failed: %v", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
_, err = io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
t.Fatalf("read body failed: %v", err)
|
||
}
|
||
}
|
||
|
||
wg := sync.WaitGroup{}
|
||
for i := 0; i < 10; i++ {
|
||
wg.Add(1)
|
||
go func() {
|
||
defer wg.Done()
|
||
doReq()
|
||
}()
|
||
}
|
||
wg.Wait()
|
||
|
||
expected := int32(tr.MaxConnsPerHost)
|
||
if dialCnt != expected {
|
||
t.Errorf("round 1: too many dials: %d != %d", dialCnt, expected)
|
||
}
|
||
if gotConnCnt != expected {
|
||
t.Errorf("round 1: too many get connections: %d != %d", gotConnCnt, expected)
|
||
}
|
||
if ts.TLS != nil && tlsHandshakeCnt != expected {
|
||
t.Errorf("round 1: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected)
|
||
}
|
||
|
||
if t.Failed() {
|
||
t.FailNow()
|
||
}
|
||
|
||
mu.Lock()
|
||
for _, c := range conns {
|
||
c.Close()
|
||
}
|
||
conns = nil
|
||
mu.Unlock()
|
||
tr.CloseIdleConnections()
|
||
|
||
doReq()
|
||
expected++
|
||
if dialCnt != expected {
|
||
t.Errorf("round 2: too many dials: %d", dialCnt)
|
||
}
|
||
if gotConnCnt != expected {
|
||
t.Errorf("round 2: too many get connections: %d != %d", gotConnCnt, expected)
|
||
}
|
||
if ts.TLS != nil && tlsHandshakeCnt != expected {
|
||
t.Errorf("round 2: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected)
|
||
}
|
||
}
|
||
|
||
func TestTransportMaxConnsPerHostDialCancellation(t *testing.T) {
|
||
run(t, testTransportMaxConnsPerHostDialCancellation,
|
||
testNotParallel, // because test uses SetPendingDialHooks
|
||
[]testMode{http1Mode, https1Mode, http2Mode},
|
||
)
|
||
}
|
||
|
||
func testTransportMaxConnsPerHostDialCancellation(t *testing.T, mode testMode) {
|
||
CondSkipHTTP2(t)
|
||
|
||
h := HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
_, err := w.Write([]byte("foo"))
|
||
if err != nil {
|
||
t.Fatalf("Write: %v", err)
|
||
}
|
||
})
|
||
|
||
cst := newClientServerTest(t, mode, h)
|
||
defer cst.close()
|
||
ts := cst.ts
|
||
c := ts.Client()
|
||
tr := c.Transport.(*Transport)
|
||
tr.MaxConnsPerHost = 1
|
||
|
||
// This request is canceled when dial is queued, which preempts dialing.
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
defer cancel()
|
||
SetPendingDialHooks(cancel, nil)
|
||
defer SetPendingDialHooks(nil, nil)
|
||
|
||
req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
|
||
_, err := c.Do(req)
|
||
if !errors.Is(err, context.Canceled) {
|
||
t.Errorf("expected error %v, got %v", context.Canceled, err)
|
||
}
|
||
|
||
// This request should succeed.
|
||
SetPendingDialHooks(nil, nil)
|
||
req, _ = NewRequest("GET", ts.URL, nil)
|
||
resp, err := c.Do(req)
|
||
if err != nil {
|
||
t.Fatalf("request failed: %v", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
_, err = io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
t.Fatalf("read body failed: %v", err)
|
||
}
|
||
}
|
||
|
||
func TestTransportRemovesDeadIdleConnections(t *testing.T) {
|
||
run(t, testTransportRemovesDeadIdleConnections, []testMode{http1Mode})
|
||
}
|
||
func testTransportRemovesDeadIdleConnections(t *testing.T, mode testMode) {
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
io.WriteString(w, r.RemoteAddr)
|
||
})).ts
|
||
|
||
c := ts.Client()
|
||
tr := c.Transport.(*Transport)
|
||
|
||
doReq := func(name string) {
|
||
// Do a POST instead of a GET to prevent the Transport's
|
||
// idempotent request retry logic from kicking in...
|
||
res, err := c.Post(ts.URL, "", nil)
|
||
if err != nil {
|
||
t.Fatalf("%s: %v", name, err)
|
||
}
|
||
if res.StatusCode != 200 {
|
||
t.Fatalf("%s: %v", name, res.Status)
|
||
}
|
||
defer res.Body.Close()
|
||
slurp, err := io.ReadAll(res.Body)
|
||
if err != nil {
|
||
t.Fatalf("%s: %v", name, err)
|
||
}
|
||
t.Logf("%s: ok (%q)", name, slurp)
|
||
}
|
||
|
||
doReq("first")
|
||
keys1 := tr.IdleConnKeysForTesting()
|
||
|
||
ts.CloseClientConnections()
|
||
|
||
var keys2 []string
|
||
waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
|
||
keys2 = tr.IdleConnKeysForTesting()
|
||
if len(keys2) != 0 {
|
||
if d > 0 {
|
||
t.Logf("Transport hasn't noticed idle connection's death in %v.\nbefore: %q\n after: %q\n", d, keys1, keys2)
|
||
}
|
||
return false
|
||
}
|
||
return true
|
||
})
|
||
|
||
doReq("second")
|
||
}
|
||
|
||
// Test that the Transport notices when a server hangs up on its
|
||
// unexpectedly (a keep-alive connection is closed).
|
||
func TestTransportServerClosingUnexpectedly(t *testing.T) {
|
||
run(t, testTransportServerClosingUnexpectedly, []testMode{http1Mode})
|
||
}
|
||
func testTransportServerClosingUnexpectedly(t *testing.T, mode testMode) {
|
||
ts := newClientServerTest(t, mode, hostPortHandler).ts
|
||
c := ts.Client()
|
||
|
||
fetch := func(n, retries int) string {
|
||
condFatalf := func(format string, arg ...any) {
|
||
if retries <= 0 {
|
||
t.Fatalf(format, arg...)
|
||
}
|
||
t.Logf("retrying shortly after expected error: "+format, arg...)
|
||
time.Sleep(time.Second / time.Duration(retries))
|
||
}
|
||
for retries >= 0 {
|
||
retries--
|
||
res, err := c.Get(ts.URL)
|
||
if err != nil {
|
||
condFatalf("error in req #%d, GET: %v", n, err)
|
||
continue
|
||
}
|
||
body, err := io.ReadAll(res.Body)
|
||
if err != nil {
|
||
condFatalf("error in req #%d, ReadAll: %v", n, err)
|
||
continue
|
||
}
|
||
res.Body.Close()
|
||
return string(body)
|
||
}
|
||
panic("unreachable")
|
||
}
|
||
|
||
body1 := fetch(1, 0)
|
||
body2 := fetch(2, 0)
|
||
|
||
// Close all the idle connections in a way that's similar to
|
||
// the server hanging up on us. We don't use
|
||
// httptest.Server.CloseClientConnections because it's
|
||
// best-effort and stops blocking after 5 seconds. On a loaded
|
||
// machine running many tests concurrently it's possible for
|
||
// that method to be async and cause the body3 fetch below to
|
||
// run on an old connection. This function is synchronous.
|
||
ExportCloseTransportConnsAbruptly(c.Transport.(*Transport))
|
||
|
||
body3 := fetch(3, 5)
|
||
|
||
if body1 != body2 {
|
||
t.Errorf("expected body1 and body2 to be equal")
|
||
}
|
||
if body2 == body3 {
|
||
t.Errorf("expected body2 and body3 to be different")
|
||
}
|
||
}
|
||
|
||
// Test for https://golang.org/issue/2616 (appropriate issue number)
|
||
// This fails pretty reliably with GOMAXPROCS=100 or something high.
|
||
func TestStressSurpriseServerCloses(t *testing.T) {
|
||
run(t, testStressSurpriseServerCloses, []testMode{http1Mode})
|
||
}
|
||
func testStressSurpriseServerCloses(t *testing.T, mode testMode) {
|
||
if testing.Short() {
|
||
t.Skip("skipping test in short mode")
|
||
}
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
w.Header().Set("Content-Length", "5")
|
||
w.Header().Set("Content-Type", "text/plain")
|
||
w.Write([]byte("Hello"))
|
||
w.(Flusher).Flush()
|
||
conn, buf, _ := w.(Hijacker).Hijack()
|
||
buf.Flush()
|
||
conn.Close()
|
||
})).ts
|
||
c := ts.Client()
|
||
|
||
// Do a bunch of traffic from different goroutines. Send to activityc
|
||
// after each request completes, regardless of whether it failed.
|
||
// If these are too high, OS X exhausts its ephemeral ports
|
||
// and hangs waiting for them to transition TCP states. That's
|
||
// not what we want to test. TODO(bradfitz): use an io.Pipe
|
||
// dialer for this test instead?
|
||
const (
|
||
numClients = 20
|
||
reqsPerClient = 25
|
||
)
|
||
var wg sync.WaitGroup
|
||
wg.Add(numClients * reqsPerClient)
|
||
for i := 0; i < numClients; i++ {
|
||
go func() {
|
||
for i := 0; i < reqsPerClient; i++ {
|
||
res, err := c.Get(ts.URL)
|
||
if err == nil {
|
||
// We expect errors since the server is
|
||
// hanging up on us after telling us to
|
||
// send more requests, so we don't
|
||
// actually care what the error is.
|
||
// But we want to close the body in cases
|
||
// where we won the race.
|
||
res.Body.Close()
|
||
}
|
||
wg.Done()
|
||
}
|
||
}()
|
||
}
|
||
|
||
// Make sure all the request come back, one way or another.
|
||
wg.Wait()
|
||
}
|
||
|
||
// TestTransportHeadResponses verifies that we deal with Content-Lengths
|
||
// with no bodies properly
|
||
func TestTransportHeadResponses(t *testing.T) { run(t, testTransportHeadResponses) }
|
||
func testTransportHeadResponses(t *testing.T, mode testMode) {
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
if r.Method != "HEAD" {
|
||
panic("expected HEAD; got " + r.Method)
|
||
}
|
||
w.Header().Set("Content-Length", "123")
|
||
w.WriteHeader(200)
|
||
})).ts
|
||
c := ts.Client()
|
||
|
||
for i := 0; i < 2; i++ {
|
||
res, err := c.Head(ts.URL)
|
||
if err != nil {
|
||
t.Errorf("error on loop %d: %v", i, err)
|
||
continue
|
||
}
|
||
if e, g := "123", res.Header.Get("Content-Length"); e != g {
|
||
t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g)
|
||
}
|
||
if e, g := int64(123), res.ContentLength; e != g {
|
||
t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
|
||
}
|
||
if all, err := io.ReadAll(res.Body); err != nil {
|
||
t.Errorf("loop %d: Body ReadAll: %v", i, err)
|
||
} else if len(all) != 0 {
|
||
t.Errorf("Bogus body %q", all)
|
||
}
|
||
}
|
||
}
|
||
|
||
// TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding
|
||
// on responses to HEAD requests.
|
||
func TestTransportHeadChunkedResponse(t *testing.T) {
|
||
run(t, testTransportHeadChunkedResponse, []testMode{http1Mode}, testNotParallel)
|
||
}
|
||
func testTransportHeadChunkedResponse(t *testing.T, mode testMode) {
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
if r.Method != "HEAD" {
|
||
panic("expected HEAD; got " + r.Method)
|
||
}
|
||
w.Header().Set("Transfer-Encoding", "chunked") // client should ignore
|
||
w.Header().Set("x-client-ipport", r.RemoteAddr)
|
||
w.WriteHeader(200)
|
||
})).ts
|
||
c := ts.Client()
|
||
|
||
// Ensure that we wait for the readLoop to complete before
|
||
// calling Head again
|
||
didRead := make(chan bool)
|
||
SetReadLoopBeforeNextReadHook(func() { didRead <- true })
|
||
defer SetReadLoopBeforeNextReadHook(nil)
|
||
|
||
res1, err := c.Head(ts.URL)
|
||
<-didRead
|
||
|
||
if err != nil {
|
||
t.Fatalf("request 1 error: %v", err)
|
||
}
|
||
|
||
res2, err := c.Head(ts.URL)
|
||
<-didRead
|
||
|
||
if err != nil {
|
||
t.Fatalf("request 2 error: %v", err)
|
||
}
|
||
if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 {
|
||
t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2)
|
||
}
|
||
}
|
||
|
||
var roundTripTests = []struct {
|
||
accept string
|
||
expectAccept string
|
||
compressed bool
|
||
}{
|
||
// Requests with no accept-encoding header use transparent compression
|
||
{"", "gzip", false},
|
||
// Requests with other accept-encoding should pass through unmodified
|
||
{"foo", "foo", false},
|
||
// Requests with accept-encoding == gzip should be passed through
|
||
{"gzip", "gzip", true},
|
||
}
|
||
|
||
// Test that the modification made to the Request by the RoundTripper is cleaned up
|
||
func TestRoundTripGzip(t *testing.T) { run(t, testRoundTripGzip) }
|
||
func testRoundTripGzip(t *testing.T, mode testMode) {
|
||
const responseBody = "test response body"
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
|
||
accept := req.Header.Get("Accept-Encoding")
|
||
if expect := req.FormValue("expect_accept"); accept != expect {
|
||
t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q",
|
||
req.FormValue("testnum"), accept, expect)
|
||
}
|
||
if accept == "gzip" {
|
||
rw.Header().Set("Content-Encoding", "gzip")
|
||
gz := gzip.NewWriter(rw)
|
||
gz.Write([]byte(responseBody))
|
||
gz.Close()
|
||
} else {
|
||
rw.Header().Set("Content-Encoding", accept)
|
||
rw.Write([]byte(responseBody))
|
||
}
|
||
})).ts
|
||
tr := ts.Client().Transport.(*Transport)
|
||
|
||
for i, test := range roundTripTests {
|
||
// Test basic request (no accept-encoding)
|
||
req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil)
|
||
if test.accept != "" {
|
||
req.Header.Set("Accept-Encoding", test.accept)
|
||
}
|
||
res, err := tr.RoundTrip(req)
|
||
if err != nil {
|
||
t.Errorf("%d. RoundTrip: %v", i, err)
|
||
continue
|
||
}
|
||
var body []byte
|
||
if test.compressed {
|
||
var r *gzip.Reader
|
||
r, err = gzip.NewReader(res.Body)
|
||
if err != nil {
|
||
t.Errorf("%d. gzip NewReader: %v", i, err)
|
||
continue
|
||
}
|
||
body, err = io.ReadAll(r)
|
||
res.Body.Close()
|
||
} else {
|
||
body, err = io.ReadAll(res.Body)
|
||
}
|
||
if err != nil {
|
||
t.Errorf("%d. Error: %q", i, err)
|
||
continue
|
||
}
|
||
if g, e := string(body), responseBody; g != e {
|
||
t.Errorf("%d. body = %q; want %q", i, g, e)
|
||
}
|
||
if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
|
||
t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e)
|
||
}
|
||
if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
|
||
t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
|
||
}
|
||
}
|
||
|
||
}
|
||
|
||
func TestTransportGzip(t *testing.T) { run(t, testTransportGzip) }
|
||
func testTransportGzip(t *testing.T, mode testMode) {
|
||
if mode == http2Mode {
|
||
t.Skip("https://go.dev/issue/56020")
|
||
}
|
||
const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||
const nRandBytes = 1024 * 1024
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
|
||
if req.Method == "HEAD" {
|
||
if g := req.Header.Get("Accept-Encoding"); g != "" {
|
||
t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g)
|
||
}
|
||
return
|
||
}
|
||
if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e {
|
||
t.Errorf("Accept-Encoding = %q, want %q", g, e)
|
||
}
|
||
rw.Header().Set("Content-Encoding", "gzip")
|
||
|
||
var w io.Writer = rw
|
||
var buf bytes.Buffer
|
||
if req.FormValue("chunked") == "0" {
|
||
w = &buf
|
||
defer io.Copy(rw, &buf)
|
||
defer func() {
|
||
rw.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
|
||
}()
|
||
}
|
||
gz := gzip.NewWriter(w)
|
||
gz.Write([]byte(testString))
|
||
if req.FormValue("body") == "large" {
|
||
io.CopyN(gz, rand.Reader, nRandBytes)
|
||
}
|
||
gz.Close()
|
||
})).ts
|
||
c := ts.Client()
|
||
|
||
for _, chunked := range []string{"1", "0"} {
|
||
// First fetch something large, but only read some of it.
|
||
res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked)
|
||
if err != nil {
|
||
t.Fatalf("large get: %v", err)
|
||
}
|
||
buf := make([]byte, len(testString))
|
||
n, err := io.ReadFull(res.Body, buf)
|
||
if err != nil {
|
||
t.Fatalf("partial read of large response: size=%d, %v", n, err)
|
||
}
|
||
if e, g := testString, string(buf); e != g {
|
||
t.Errorf("partial read got %q, expected %q", g, e)
|
||
}
|
||
res.Body.Close()
|
||
// Read on the body, even though it's closed
|
||
n, err = res.Body.Read(buf)
|
||
if n != 0 || err == nil {
|
||
t.Errorf("expected error post-closed large Read; got = %d, %v", n, err)
|
||
}
|
||
|
||
// Then something small.
|
||
res, err = c.Get(ts.URL + "/?chunked=" + chunked)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
body, err := io.ReadAll(res.Body)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
if g, e := string(body), testString; g != e {
|
||
t.Fatalf("body = %q; want %q", g, e)
|
||
}
|
||
if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
|
||
t.Fatalf("Content-Encoding = %q; want %q", g, e)
|
||
}
|
||
|
||
// Read on the body after it's been fully read:
|
||
n, err = res.Body.Read(buf)
|
||
if n != 0 || err == nil {
|
||
t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err)
|
||
}
|
||
res.Body.Close()
|
||
n, err = res.Body.Read(buf)
|
||
if n != 0 || err == nil {
|
||
t.Errorf("expected Read error after Close; got %d, %v", n, err)
|
||
}
|
||
}
|
||
|
||
// And a HEAD request too, because they're always weird.
|
||
res, err := c.Head(ts.URL)
|
||
if err != nil {
|
||
t.Fatalf("Head: %v", err)
|
||
}
|
||
if res.StatusCode != 200 {
|
||
t.Errorf("Head status=%d; want=200", res.StatusCode)
|
||
}
|
||
}
|
||
|
||
// A transport100Continue test exercises Transport behaviors when sending a
|
||
// request with an Expect: 100-continue header.
|
||
type transport100ContinueTest struct {
|
||
t *testing.T
|
||
|
||
reqdone chan struct{}
|
||
resp *Response
|
||
respErr error
|
||
|
||
conn net.Conn
|
||
reader *bufio.Reader
|
||
}
|
||
|
||
const transport100ContinueTestBody = "request body"
|
||
|
||
// newTransport100ContinueTest creates a Transport and sends an Expect: 100-continue
|
||
// request on it.
|
||
func newTransport100ContinueTest(t *testing.T, timeout time.Duration) *transport100ContinueTest {
|
||
ln := newLocalListener(t)
|
||
defer ln.Close()
|
||
|
||
test := &transport100ContinueTest{
|
||
t: t,
|
||
reqdone: make(chan struct{}),
|
||
}
|
||
|
||
tr := &Transport{
|
||
ExpectContinueTimeout: timeout,
|
||
}
|
||
go func() {
|
||
defer close(test.reqdone)
|
||
body := strings.NewReader(transport100ContinueTestBody)
|
||
req, _ := NewRequest("PUT", "http://"+ln.Addr().String(), body)
|
||
req.Header.Set("Expect", "100-continue")
|
||
req.ContentLength = int64(len(transport100ContinueTestBody))
|
||
test.resp, test.respErr = tr.RoundTrip(req)
|
||
test.resp.Body.Close()
|
||
}()
|
||
|
||
c, err := ln.Accept()
|
||
if err != nil {
|
||
t.Fatalf("Accept: %v", err)
|
||
}
|
||
t.Cleanup(func() {
|
||
c.Close()
|
||
})
|
||
br := bufio.NewReader(c)
|
||
_, err = ReadRequest(br)
|
||
if err != nil {
|
||
t.Fatalf("ReadRequest: %v", err)
|
||
}
|
||
test.conn = c
|
||
test.reader = br
|
||
t.Cleanup(func() {
|
||
<-test.reqdone
|
||
tr.CloseIdleConnections()
|
||
got, _ := io.ReadAll(test.reader)
|
||
if len(got) > 0 {
|
||
t.Fatalf("Transport sent unexpected bytes: %q", got)
|
||
}
|
||
})
|
||
|
||
return test
|
||
}
|
||
|
||
// respond sends response lines from the server to the transport.
|
||
func (test *transport100ContinueTest) respond(lines ...string) {
|
||
for _, line := range lines {
|
||
if _, err := test.conn.Write([]byte(line + "\r\n")); err != nil {
|
||
test.t.Fatalf("Write: %v", err)
|
||
}
|
||
}
|
||
if _, err := test.conn.Write([]byte("\r\n")); err != nil {
|
||
test.t.Fatalf("Write: %v", err)
|
||
}
|
||
}
|
||
|
||
// wantBodySent ensures the transport has sent the request body to the server.
|
||
func (test *transport100ContinueTest) wantBodySent() {
|
||
got, err := io.ReadAll(io.LimitReader(test.reader, int64(len(transport100ContinueTestBody))))
|
||
if err != nil {
|
||
test.t.Fatalf("unexpected error reading body: %v", err)
|
||
}
|
||
if got, want := string(got), transport100ContinueTestBody; got != want {
|
||
test.t.Fatalf("unexpected body: got %q, want %q", got, want)
|
||
}
|
||
}
|
||
|
||
// wantRequestDone ensures the Transport.RoundTrip has completed with the expected status.
|
||
func (test *transport100ContinueTest) wantRequestDone(want int) {
|
||
<-test.reqdone
|
||
if test.respErr != nil {
|
||
test.t.Fatalf("unexpected RoundTrip error: %v", test.respErr)
|
||
}
|
||
if got := test.resp.StatusCode; got != want {
|
||
test.t.Fatalf("unexpected response code: got %v, want %v", got, want)
|
||
}
|
||
}
|
||
|
||
func TestTransportExpect100ContinueSent(t *testing.T) {
|
||
test := newTransport100ContinueTest(t, 1*time.Hour)
|
||
// Server sends a 100 Continue response, and the client sends the request body.
|
||
test.respond("HTTP/1.1 100 Continue")
|
||
test.wantBodySent()
|
||
test.respond("HTTP/1.1 200", "Content-Length: 0")
|
||
test.wantRequestDone(200)
|
||
}
|
||
|
||
func TestTransportExpect100Continue200ResponseNoConnClose(t *testing.T) {
|
||
test := newTransport100ContinueTest(t, 1*time.Hour)
|
||
// No 100 Continue response, no Connection: close header.
|
||
test.respond("HTTP/1.1 200", "Content-Length: 0")
|
||
test.wantBodySent()
|
||
test.wantRequestDone(200)
|
||
}
|
||
|
||
func TestTransportExpect100Continue200ResponseWithConnClose(t *testing.T) {
|
||
test := newTransport100ContinueTest(t, 1*time.Hour)
|
||
// No 100 Continue response, Connection: close header set.
|
||
test.respond("HTTP/1.1 200", "Connection: close", "Content-Length: 0")
|
||
test.wantRequestDone(200)
|
||
}
|
||
|
||
func TestTransportExpect100Continue500ResponseNoConnClose(t *testing.T) {
|
||
test := newTransport100ContinueTest(t, 1*time.Hour)
|
||
// No 100 Continue response, no Connection: close header.
|
||
test.respond("HTTP/1.1 500", "Content-Length: 0")
|
||
test.wantBodySent()
|
||
test.wantRequestDone(500)
|
||
}
|
||
|
||
func TestTransportExpect100Continue500ResponseTimeout(t *testing.T) {
|
||
test := newTransport100ContinueTest(t, 5*time.Millisecond) // short timeout
|
||
test.wantBodySent() // after timeout
|
||
test.respond("HTTP/1.1 200", "Content-Length: 0")
|
||
test.wantRequestDone(200)
|
||
}
|
||
|
||
func TestSOCKS5Proxy(t *testing.T) {
|
||
run(t, testSOCKS5Proxy, []testMode{http1Mode, https1Mode, http2Mode})
|
||
}
|
||
func testSOCKS5Proxy(t *testing.T, mode testMode) {
|
||
ch := make(chan string, 1)
|
||
l := newLocalListener(t)
|
||
defer l.Close()
|
||
defer close(ch)
|
||
proxy := func(t *testing.T) {
|
||
s, err := l.Accept()
|
||
if err != nil {
|
||
t.Errorf("socks5 proxy Accept(): %v", err)
|
||
return
|
||
}
|
||
defer s.Close()
|
||
var buf [22]byte
|
||
if _, err := io.ReadFull(s, buf[:3]); err != nil {
|
||
t.Errorf("socks5 proxy initial read: %v", err)
|
||
return
|
||
}
|
||
if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
|
||
t.Errorf("socks5 proxy initial read: got %v, want %v", buf[:3], want)
|
||
return
|
||
}
|
||
if _, err := s.Write([]byte{5, 0}); err != nil {
|
||
t.Errorf("socks5 proxy initial write: %v", err)
|
||
return
|
||
}
|
||
if _, err := io.ReadFull(s, buf[:4]); err != nil {
|
||
t.Errorf("socks5 proxy second read: %v", err)
|
||
return
|
||
}
|
||
if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
|
||
t.Errorf("socks5 proxy second read: got %v, want %v", buf[:3], want)
|
||
return
|
||
}
|
||
var ipLen int
|
||
switch buf[3] {
|
||
case 1:
|
||
ipLen = net.IPv4len
|
||
case 4:
|
||
ipLen = net.IPv6len
|
||
default:
|
||
t.Errorf("socks5 proxy second read: unexpected address type %v", buf[4])
|
||
return
|
||
}
|
||
if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil {
|
||
t.Errorf("socks5 proxy address read: %v", err)
|
||
return
|
||
}
|
||
ip := net.IP(buf[4 : ipLen+4])
|
||
port := binary.BigEndian.Uint16(buf[ipLen+4 : ipLen+6])
|
||
copy(buf[:3], []byte{5, 0, 0})
|
||
if _, err := s.Write(buf[:ipLen+6]); err != nil {
|
||
t.Errorf("socks5 proxy connect write: %v", err)
|
||
return
|
||
}
|
||
ch <- fmt.Sprintf("proxy for %s:%d", ip, port)
|
||
|
||
// Implement proxying.
|
||
targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
|
||
targetConn, err := net.Dial("tcp", targetHost)
|
||
if err != nil {
|
||
t.Errorf("net.Dial failed")
|
||
return
|
||
}
|
||
go io.Copy(targetConn, s)
|
||
io.Copy(s, targetConn) // Wait for the client to close the socket.
|
||
targetConn.Close()
|
||
}
|
||
|
||
pu, err := url.Parse("socks5://" + l.Addr().String())
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
sentinelHeader := "X-Sentinel"
|
||
sentinelValue := "12345"
|
||
h := HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
w.Header().Set(sentinelHeader, sentinelValue)
|
||
})
|
||
for _, useTLS := range []bool{false, true} {
|
||
t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) {
|
||
ts := newClientServerTest(t, mode, h).ts
|
||
go proxy(t)
|
||
c := ts.Client()
|
||
c.Transport.(*Transport).Proxy = ProxyURL(pu)
|
||
r, err := c.Head(ts.URL)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
if r.Header.Get(sentinelHeader) != sentinelValue {
|
||
t.Errorf("Failed to retrieve sentinel value")
|
||
}
|
||
got := <-ch
|
||
ts.Close()
|
||
tsu, err := url.Parse(ts.URL)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
want := "proxy for " + tsu.Host
|
||
if got != want {
|
||
t.Errorf("got %q, want %q", got, want)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestTransportProxy(t *testing.T) {
|
||
defer afterTest(t)
|
||
testCases := []struct{ siteMode, proxyMode testMode }{
|
||
{http1Mode, http1Mode},
|
||
{http1Mode, https1Mode},
|
||
{https1Mode, http1Mode},
|
||
{https1Mode, https1Mode},
|
||
}
|
||
for _, testCase := range testCases {
|
||
siteMode := testCase.siteMode
|
||
proxyMode := testCase.proxyMode
|
||
t.Run(fmt.Sprintf("site=%v/proxy=%v", siteMode, proxyMode), func(t *testing.T) {
|
||
siteCh := make(chan *Request, 1)
|
||
h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
siteCh <- r
|
||
})
|
||
proxyCh := make(chan *Request, 1)
|
||
h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
proxyCh <- r
|
||
// Implement an entire CONNECT proxy
|
||
if r.Method == "CONNECT" {
|
||
hijacker, ok := w.(Hijacker)
|
||
if !ok {
|
||
t.Errorf("hijack not allowed")
|
||
return
|
||
}
|
||
clientConn, _, err := hijacker.Hijack()
|
||
if err != nil {
|
||
t.Errorf("hijacking failed")
|
||
return
|
||
}
|
||
res := &Response{
|
||
StatusCode: StatusOK,
|
||
Proto: "HTTP/1.1",
|
||
ProtoMajor: 1,
|
||
ProtoMinor: 1,
|
||
Header: make(Header),
|
||
}
|
||
|
||
targetConn, err := net.Dial("tcp", r.URL.Host)
|
||
if err != nil {
|
||
t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
|
||
return
|
||
}
|
||
|
||
if err := res.Write(clientConn); err != nil {
|
||
t.Errorf("Writing 200 OK failed: %v", err)
|
||
return
|
||
}
|
||
|
||
go io.Copy(targetConn, clientConn)
|
||
go func() {
|
||
io.Copy(clientConn, targetConn)
|
||
targetConn.Close()
|
||
}()
|
||
}
|
||
})
|
||
ts := newClientServerTest(t, siteMode, h1).ts
|
||
proxy := newClientServerTest(t, proxyMode, h2).ts
|
||
|
||
pu, err := url.Parse(proxy.URL)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
// If neither server is HTTPS or both are, then c may be derived from either.
|
||
// If only one server is HTTPS, c must be derived from that server in order
|
||
// to ensure that it is configured to use the fake root CA from testcert.go.
|
||
c := proxy.Client()
|
||
if siteMode == https1Mode {
|
||
c = ts.Client()
|
||
}
|
||
|
||
c.Transport.(*Transport).Proxy = ProxyURL(pu)
|
||
if _, err := c.Head(ts.URL); err != nil {
|
||
t.Error(err)
|
||
}
|
||
got := <-proxyCh
|
||
c.Transport.(*Transport).CloseIdleConnections()
|
||
ts.Close()
|
||
proxy.Close()
|
||
if siteMode == https1Mode {
|
||
// First message should be a CONNECT, asking for a socket to the real server,
|
||
if got.Method != "CONNECT" {
|
||
t.Errorf("Wrong method for secure proxying: %q", got.Method)
|
||
}
|
||
gotHost := got.URL.Host
|
||
pu, err := url.Parse(ts.URL)
|
||
if err != nil {
|
||
t.Fatal("Invalid site URL")
|
||
}
|
||
if wantHost := pu.Host; gotHost != wantHost {
|
||
t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost)
|
||
}
|
||
|
||
// The next message on the channel should be from the site's server.
|
||
next := <-siteCh
|
||
if next.Method != "HEAD" {
|
||
t.Errorf("Wrong method at destination: %s", next.Method)
|
||
}
|
||
if nextURL := next.URL.String(); nextURL != "/" {
|
||
t.Errorf("Wrong URL at destination: %s", nextURL)
|
||
}
|
||
} else {
|
||
if got.Method != "HEAD" {
|
||
t.Errorf("Wrong method for destination: %q", got.Method)
|
||
}
|
||
gotURL := got.URL.String()
|
||
wantURL := ts.URL + "/"
|
||
if gotURL != wantURL {
|
||
t.Errorf("Got URL %q, want %q", gotURL, wantURL)
|
||
}
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestOnProxyConnectResponse(t *testing.T) {
|
||
|
||
var tcases = []struct {
|
||
proxyStatusCode int
|
||
err error
|
||
}{
|
||
{
|
||
StatusOK,
|
||
nil,
|
||
},
|
||
{
|
||
StatusForbidden,
|
||
errors.New("403"),
|
||
},
|
||
}
|
||
for _, tcase := range tcases {
|
||
h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
|
||
})
|
||
|
||
h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
// Implement an entire CONNECT proxy
|
||
if r.Method == "CONNECT" {
|
||
if tcase.proxyStatusCode != StatusOK {
|
||
w.WriteHeader(tcase.proxyStatusCode)
|
||
return
|
||
}
|
||
hijacker, ok := w.(Hijacker)
|
||
if !ok {
|
||
t.Errorf("hijack not allowed")
|
||
return
|
||
}
|
||
clientConn, _, err := hijacker.Hijack()
|
||
if err != nil {
|
||
t.Errorf("hijacking failed")
|
||
return
|
||
}
|
||
res := &Response{
|
||
StatusCode: StatusOK,
|
||
Proto: "HTTP/1.1",
|
||
ProtoMajor: 1,
|
||
ProtoMinor: 1,
|
||
Header: make(Header),
|
||
}
|
||
|
||
targetConn, err := net.Dial("tcp", r.URL.Host)
|
||
if err != nil {
|
||
t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
|
||
return
|
||
}
|
||
|
||
if err := res.Write(clientConn); err != nil {
|
||
t.Errorf("Writing 200 OK failed: %v", err)
|
||
return
|
||
}
|
||
|
||
go io.Copy(targetConn, clientConn)
|
||
go func() {
|
||
io.Copy(clientConn, targetConn)
|
||
targetConn.Close()
|
||
}()
|
||
}
|
||
})
|
||
ts := newClientServerTest(t, https1Mode, h1).ts
|
||
proxy := newClientServerTest(t, https1Mode, h2).ts
|
||
|
||
pu, err := url.Parse(proxy.URL)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
c := proxy.Client()
|
||
|
||
var (
|
||
dials atomic.Int32
|
||
closes atomic.Int32
|
||
)
|
||
c.Transport.(*Transport).DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||
conn, err := net.Dial(network, addr)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
dials.Add(1)
|
||
return noteCloseConn{
|
||
Conn: conn,
|
||
closeFunc: func() {
|
||
closes.Add(1)
|
||
},
|
||
}, nil
|
||
}
|
||
|
||
c.Transport.(*Transport).Proxy = ProxyURL(pu)
|
||
c.Transport.(*Transport).OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
|
||
if proxyURL.String() != pu.String() {
|
||
t.Errorf("proxy url got %s, want %s", proxyURL, pu)
|
||
}
|
||
|
||
if "https://"+connectReq.URL.String() != ts.URL {
|
||
t.Errorf("connect url got %s, want %s", connectReq.URL, ts.URL)
|
||
}
|
||
return tcase.err
|
||
}
|
||
wantCloses := int32(0)
|
||
if _, err := c.Head(ts.URL); err != nil {
|
||
wantCloses = 1
|
||
if tcase.err != nil && !strings.Contains(err.Error(), tcase.err.Error()) {
|
||
t.Errorf("got %v, want %v", err, tcase.err)
|
||
}
|
||
} else {
|
||
if tcase.err != nil {
|
||
t.Errorf("got %v, want nil", err)
|
||
}
|
||
}
|
||
if got, want := dials.Load(), int32(1); got != want {
|
||
t.Errorf("got %v dials, want %v", got, want)
|
||
}
|
||
// #64804: If OnProxyConnectResponse returns an error, we should close the conn.
|
||
if got, want := closes.Load(), wantCloses; got != want {
|
||
t.Errorf("got %v closes, want %v", got, want)
|
||
}
|
||
}
|
||
}
|
||
|
||
// Issue 28012: verify that the Transport closes its TCP connection to http proxies
|
||
// when they're slow to reply to HTTPS CONNECT responses.
|
||
func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
|
||
cancelc := make(chan struct{})
|
||
SetTestHookProxyConnectTimeout(t, func(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
|
||
ctx, cancel := context.WithCancel(ctx)
|
||
go func() {
|
||
select {
|
||
case <-cancelc:
|
||
case <-ctx.Done():
|
||
}
|
||
cancel()
|
||
}()
|
||
return ctx, cancel
|
||
})
|
||
|
||
defer afterTest(t)
|
||
|
||
ln := newLocalListener(t)
|
||
defer ln.Close()
|
||
listenerDone := make(chan struct{})
|
||
go func() {
|
||
defer close(listenerDone)
|
||
c, err := ln.Accept()
|
||
if err != nil {
|
||
t.Errorf("Accept: %v", err)
|
||
return
|
||
}
|
||
defer c.Close()
|
||
// Read the CONNECT request
|
||
br := bufio.NewReader(c)
|
||
cr, err := ReadRequest(br)
|
||
if err != nil {
|
||
t.Errorf("proxy server failed to read CONNECT request")
|
||
return
|
||
}
|
||
if cr.Method != "CONNECT" {
|
||
t.Errorf("unexpected method %q", cr.Method)
|
||
return
|
||
}
|
||
|
||
// Now hang and never write a response; instead, cancel the request and wait
|
||
// for the client to close.
|
||
// (Prior to Issue 28012 being fixed, we never closed.)
|
||
close(cancelc)
|
||
var buf [1]byte
|
||
_, err = br.Read(buf[:])
|
||
if err != io.EOF {
|
||
t.Errorf("proxy server Read err = %v; want EOF", err)
|
||
}
|
||
return
|
||
}()
|
||
|
||
c := &Client{
|
||
Transport: &Transport{
|
||
Proxy: func(*Request) (*url.URL, error) {
|
||
return url.Parse("http://" + ln.Addr().String())
|
||
},
|
||
},
|
||
}
|
||
req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
_, err = c.Do(req)
|
||
if err == nil {
|
||
t.Errorf("unexpected Get success")
|
||
}
|
||
|
||
// Wait unconditionally for the listener goroutine to exit: this should never
|
||
// hang, so if it does we want a full goroutine dump — and that's exactly what
|
||
// the testing package will give us when the test run times out.
|
||
<-listenerDone
|
||
}
|
||
|
||
// Issue 16997: test transport dial preserves typed errors
|
||
func TestTransportDialPreservesNetOpProxyError(t *testing.T) {
|
||
defer afterTest(t)
|
||
|
||
var errDial = errors.New("some dial error")
|
||
|
||
tr := &Transport{
|
||
Proxy: func(*Request) (*url.URL, error) {
|
||
return url.Parse("http://proxy.fake.tld/")
|
||
},
|
||
Dial: func(string, string) (net.Conn, error) {
|
||
return nil, errDial
|
||
},
|
||
}
|
||
defer tr.CloseIdleConnections()
|
||
|
||
c := &Client{Transport: tr}
|
||
req, _ := NewRequest("GET", "http://fake.tld", nil)
|
||
res, err := c.Do(req)
|
||
if err == nil {
|
||
res.Body.Close()
|
||
t.Fatal("wanted a non-nil error")
|
||
}
|
||
|
||
uerr, ok := err.(*url.Error)
|
||
if !ok {
|
||
t.Fatalf("got %T, want *url.Error", err)
|
||
}
|
||
oe, ok := uerr.Err.(*net.OpError)
|
||
if !ok {
|
||
t.Fatalf("url.Error.Err = %T; want *net.OpError", uerr.Err)
|
||
}
|
||
want := &net.OpError{
|
||
Op: "proxyconnect",
|
||
Net: "tcp",
|
||
Err: errDial, // original error, unwrapped.
|
||
}
|
||
if !reflect.DeepEqual(oe, want) {
|
||
t.Errorf("Got error %#v; want %#v", oe, want)
|
||
}
|
||
}
|
||
|
||
// Issue 36431: calls to RoundTrip should not mutate t.ProxyConnectHeader.
|
||
//
|
||
// (A bug caused dialConn to instead write the per-request Proxy-Authorization
|
||
// header through to the shared Header instance, introducing a data race.)
|
||
func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) {
|
||
run(t, testTransportProxyDialDoesNotMutateProxyConnectHeader)
|
||
}
|
||
func testTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T, mode testMode) {
|
||
proxy := newClientServerTest(t, mode, NotFoundHandler()).ts
|
||
defer proxy.Close()
|
||
c := proxy.Client()
|
||
|
||
tr := c.Transport.(*Transport)
|
||
tr.Proxy = func(*Request) (*url.URL, error) {
|
||
u, _ := url.Parse(proxy.URL)
|
||
u.User = url.UserPassword("aladdin", "opensesame")
|
||
return u, nil
|
||
}
|
||
h := tr.ProxyConnectHeader
|
||
if h == nil {
|
||
h = make(Header)
|
||
}
|
||
tr.ProxyConnectHeader = h.Clone()
|
||
|
||
req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
_, err = c.Do(req)
|
||
if err == nil {
|
||
t.Errorf("unexpected Get success")
|
||
}
|
||
|
||
if !reflect.DeepEqual(tr.ProxyConnectHeader, h) {
|
||
t.Errorf("tr.ProxyConnectHeader = %v; want %v", tr.ProxyConnectHeader, h)
|
||
}
|
||
}
|
||
|
||
// TestTransportGzipRecursive sends a gzip quine and checks that the
|
||
// client gets the same value back. This is more cute than anything,
|
||
// but checks that we don't recurse forever, and checks that
|
||
// Content-Encoding is removed.
|
||
func TestTransportGzipRecursive(t *testing.T) { run(t, testTransportGzipRecursive) }
|
||
func testTransportGzipRecursive(t *testing.T, mode testMode) {
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
w.Header().Set("Content-Encoding", "gzip")
|
||
w.Write(rgz)
|
||
})).ts
|
||
|
||
c := ts.Client()
|
||
res, err := c.Get(ts.URL)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
body, err := io.ReadAll(res.Body)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
if !bytes.Equal(body, rgz) {
|
||
t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x",
|
||
body, rgz)
|
||
}
|
||
if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
|
||
t.Fatalf("Content-Encoding = %q; want %q", g, e)
|
||
}
|
||
}
|
||
|
||
// golang.org/issue/7750: request fails when server replies with
|
||
// a short gzip body
|
||
func TestTransportGzipShort(t *testing.T) { run(t, testTransportGzipShort) }
|
||
func testTransportGzipShort(t *testing.T, mode testMode) {
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
w.Header().Set("Content-Encoding", "gzip")
|
||
w.Write([]byte{0x1f, 0x8b})
|
||
})).ts
|
||
|
||
c := ts.Client()
|
||
res, err := c.Get(ts.URL)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
defer res.Body.Close()
|
||
_, err = io.ReadAll(res.Body)
|
||
if err == nil {
|
||
t.Fatal("Expect an error from reading a body.")
|
||
}
|
||
if err != io.ErrUnexpectedEOF {
|
||
t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err)
|
||
}
|
||
}
|
||
|
||
// Wait until number of goroutines is no greater than nmax, or time out.
|
||
func waitNumGoroutine(nmax int) int {
|
||
nfinal := runtime.NumGoroutine()
|
||
for ntries := 10; ntries > 0 && nfinal > nmax; ntries-- {
|
||
time.Sleep(50 * time.Millisecond)
|
||
runtime.GC()
|
||
nfinal = runtime.NumGoroutine()
|
||
}
|
||
return nfinal
|
||
}
|
||
|
||
// tests that persistent goroutine connections shut down when no longer desired.
|
||
func TestTransportPersistConnLeak(t *testing.T) {
|
||
run(t, testTransportPersistConnLeak, testNotParallel)
|
||
}
|
||
func testTransportPersistConnLeak(t *testing.T, mode testMode) {
|
||
if mode == http2Mode {
|
||
t.Skip("flaky in HTTP/2")
|
||
}
|
||
// Not parallel: counts goroutines
|
||
|
||
const numReq = 25
|
||
gotReqCh := make(chan bool, numReq)
|
||
unblockCh := make(chan bool, numReq)
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
gotReqCh <- true
|
||
<-unblockCh
|
||
w.Header().Set("Content-Length", "0")
|
||
w.WriteHeader(204)
|
||
})).ts
|
||
c := ts.Client()
|
||
tr := c.Transport.(*Transport)
|
||
|
||
n0 := runtime.NumGoroutine()
|
||
|
||
didReqCh := make(chan bool, numReq)
|
||
failed := make(chan bool, numReq)
|
||
for i := 0; i < numReq; i++ {
|
||
go func() {
|
||
res, err := c.Get(ts.URL)
|
||
didReqCh <- true
|
||
if err != nil {
|
||
t.Logf("client fetch error: %v", err)
|
||
failed <- true
|
||
return
|
||
}
|
||
res.Body.Close()
|
||
}()
|
||
}
|
||
|
||
// Wait for all goroutines to be stuck in the Handler.
|
||
for i := 0; i < numReq; i++ {
|
||
select {
|
||
case <-gotReqCh:
|
||
// ok
|
||
case <-failed:
|
||
// Not great but not what we are testing:
|
||
// sometimes an overloaded system will fail to make all the connections.
|
||
}
|
||
}
|
||
|
||
nhigh := runtime.NumGoroutine()
|
||
|
||
// Tell all handlers to unblock and reply.
|
||
close(unblockCh)
|
||
|
||
// Wait for all HTTP clients to be done.
|
||
for i := 0; i < numReq; i++ {
|
||
<-didReqCh
|
||
}
|
||
|
||
tr.CloseIdleConnections()
|
||
nfinal := waitNumGoroutine(n0 + 5)
|
||
|
||
growth := nfinal - n0
|
||
|
||
// We expect 0 or 1 extra goroutine, empirically. Allow up to 5.
|
||
// Previously we were leaking one per numReq.
|
||
if int(growth) > 5 {
|
||
t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
|
||
t.Error("too many new goroutines")
|
||
}
|
||
}
|
||
|
||
// golang.org/issue/4531: Transport leaks goroutines when
|
||
// request.ContentLength is explicitly short
|
||
func TestTransportPersistConnLeakShortBody(t *testing.T) {
|
||
run(t, testTransportPersistConnLeakShortBody, testNotParallel)
|
||
}
|
||
func testTransportPersistConnLeakShortBody(t *testing.T, mode testMode) {
|
||
if mode == http2Mode {
|
||
t.Skip("flaky in HTTP/2")
|
||
}
|
||
|
||
// Not parallel: measures goroutines.
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
})).ts
|
||
c := ts.Client()
|
||
tr := c.Transport.(*Transport)
|
||
|
||
n0 := runtime.NumGoroutine()
|
||
body := []byte("Hello")
|
||
for i := 0; i < 20; i++ {
|
||
req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
req.ContentLength = int64(len(body) - 2) // explicitly short
|
||
_, err = c.Do(req)
|
||
if err == nil {
|
||
t.Fatal("Expect an error from writing too long of a body.")
|
||
}
|
||
}
|
||
nhigh := runtime.NumGoroutine()
|
||
tr.CloseIdleConnections()
|
||
nfinal := waitNumGoroutine(n0 + 5)
|
||
|
||
growth := nfinal - n0
|
||
|
||
// We expect 0 or 1 extra goroutine, empirically. Allow up to 5.
|
||
// Previously we were leaking one per numReq.
|
||
t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
|
||
if int(growth) > 5 {
|
||
t.Error("too many new goroutines")
|
||
}
|
||
}
|
||
|
||
// A countedConn is a net.Conn that decrements an atomic counter when finalized.
|
||
type countedConn struct {
|
||
net.Conn
|
||
}
|
||
|
||
// A countingDialer dials connections and counts the number that remain reachable.
|
||
type countingDialer struct {
|
||
dialer net.Dialer
|
||
mu sync.Mutex
|
||
total, live int64
|
||
}
|
||
|
||
func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||
conn, err := d.dialer.DialContext(ctx, network, address)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
counted := new(countedConn)
|
||
counted.Conn = conn
|
||
|
||
d.mu.Lock()
|
||
defer d.mu.Unlock()
|
||
d.total++
|
||
d.live++
|
||
|
||
runtime.SetFinalizer(counted, d.decrement)
|
||
return counted, nil
|
||
}
|
||
|
||
func (d *countingDialer) decrement(*countedConn) {
|
||
d.mu.Lock()
|
||
defer d.mu.Unlock()
|
||
d.live--
|
||
}
|
||
|
||
func (d *countingDialer) Read() (total, live int64) {
|
||
d.mu.Lock()
|
||
defer d.mu.Unlock()
|
||
return d.total, d.live
|
||
}
|
||
|
||
func TestTransportPersistConnLeakNeverIdle(t *testing.T) {
|
||
run(t, testTransportPersistConnLeakNeverIdle, []testMode{http1Mode})
|
||
}
|
||
func testTransportPersistConnLeakNeverIdle(t *testing.T, mode testMode) {
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
// Close every connection so that it cannot be kept alive.
|
||
conn, _, err := w.(Hijacker).Hijack()
|
||
if err != nil {
|
||
t.Errorf("Hijack failed unexpectedly: %v", err)
|
||
return
|
||
}
|
||
conn.Close()
|
||
})).ts
|
||
|
||
var d countingDialer
|
||
c := ts.Client()
|
||
c.Transport.(*Transport).DialContext = d.DialContext
|
||
|
||
body := []byte("Hello")
|
||
for i := 0; ; i++ {
|
||
total, live := d.Read()
|
||
if live < total {
|
||
break
|
||
}
|
||
if i >= 1<<12 {
|
||
t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, i)
|
||
}
|
||
|
||
req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
_, err = c.Do(req)
|
||
if err == nil {
|
||
t.Fatal("expected broken connection")
|
||
}
|
||
|
||
runtime.GC()
|
||
}
|
||
}
|
||
|
||
type countedContext struct {
|
||
context.Context
|
||
}
|
||
|
||
type contextCounter struct {
|
||
mu sync.Mutex
|
||
live int64
|
||
}
|
||
|
||
func (cc *contextCounter) Track(ctx context.Context) context.Context {
|
||
counted := new(countedContext)
|
||
counted.Context = ctx
|
||
cc.mu.Lock()
|
||
defer cc.mu.Unlock()
|
||
cc.live++
|
||
runtime.SetFinalizer(counted, cc.decrement)
|
||
return counted
|
||
}
|
||
|
||
func (cc *contextCounter) decrement(*countedContext) {
|
||
cc.mu.Lock()
|
||
defer cc.mu.Unlock()
|
||
cc.live--
|
||
}
|
||
|
||
func (cc *contextCounter) Read() (live int64) {
|
||
cc.mu.Lock()
|
||
defer cc.mu.Unlock()
|
||
return cc.live
|
||
}
|
||
|
||
func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) {
|
||
run(t, testTransportPersistConnContextLeakMaxConnsPerHost)
|
||
}
|
||
func testTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T, mode testMode) {
|
||
if mode == http2Mode {
|
||
t.Skip("https://go.dev/issue/56021")
|
||
}
|
||
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
runtime.Gosched()
|
||
w.WriteHeader(StatusOK)
|
||
})).ts
|
||
|
||
c := ts.Client()
|
||
c.Transport.(*Transport).MaxConnsPerHost = 1
|
||
|
||
ctx := context.Background()
|
||
body := []byte("Hello")
|
||
doPosts := func(cc *contextCounter) {
|
||
var wg sync.WaitGroup
|
||
for n := 64; n > 0; n-- {
|
||
wg.Add(1)
|
||
go func() {
|
||
defer wg.Done()
|
||
|
||
ctx := cc.Track(ctx)
|
||
req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
|
||
if err != nil {
|
||
t.Error(err)
|
||
}
|
||
|
||
_, err = c.Do(req.WithContext(ctx))
|
||
if err != nil {
|
||
t.Errorf("Do failed with error: %v", err)
|
||
}
|
||
}()
|
||
}
|
||
wg.Wait()
|
||
}
|
||
|
||
var initialCC contextCounter
|
||
doPosts(&initialCC)
|
||
|
||
// flushCC exists only to put pressure on the GC to finalize the initialCC
|
||
// contexts: the flushCC allocations should eventually displace the initialCC
|
||
// allocations.
|
||
var flushCC contextCounter
|
||
for i := 0; ; i++ {
|
||
live := initialCC.Read()
|
||
if live == 0 {
|
||
break
|
||
}
|
||
if i >= 100 {
|
||
t.Fatalf("%d Contexts still not finalized after %d GC cycles.", live, i)
|
||
}
|
||
doPosts(&flushCC)
|
||
runtime.GC()
|
||
}
|
||
}
|
||
|
||
// This used to crash; https://golang.org/issue/3266
|
||
func TestTransportIdleConnCrash(t *testing.T) { run(t, testTransportIdleConnCrash) }
|
||
func testTransportIdleConnCrash(t *testing.T, mode testMode) {
|
||
var tr *Transport
|
||
|
||
unblockCh := make(chan bool, 1)
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
<-unblockCh
|
||
tr.CloseIdleConnections()
|
||
})).ts
|
||
c := ts.Client()
|
||
tr = c.Transport.(*Transport)
|
||
|
||
didreq := make(chan bool)
|
||
go func() {
|
||
res, err := c.Get(ts.URL)
|
||
if err != nil {
|
||
t.Error(err)
|
||
} else {
|
||
res.Body.Close() // returns idle conn
|
||
}
|
||
didreq <- true
|
||
}()
|
||
unblockCh <- true
|
||
<-didreq
|
||
}
|
||
|
||
// Test that the transport doesn't close the TCP connection early,
|
||
// before the response body has been read. This was a regression
|
||
// which sadly lacked a triggering test. The large response body made
|
||
// the old race easier to trigger.
|
||
func TestIssue3644(t *testing.T) { run(t, testIssue3644) }
|
||
func testIssue3644(t *testing.T, mode testMode) {
|
||
const numFoos = 5000
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
w.Header().Set("Connection", "close")
|
||
for i := 0; i < numFoos; i++ {
|
||
w.Write([]byte("foo "))
|
||
}
|
||
})).ts
|
||
c := ts.Client()
|
||
res, err := c.Get(ts.URL)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
defer res.Body.Close()
|
||
bs, err := io.ReadAll(res.Body)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
if len(bs) != numFoos*len("foo ") {
|
||
t.Errorf("unexpected response length")
|
||
}
|
||
}
|
||
|
||
// Test that a client receives a server's reply, even if the server doesn't read
|
||
// the entire request body.
|
||
func TestIssue3595(t *testing.T) {
|
||
// Not parallel: modifies the global rstAvoidanceDelay.
|
||
run(t, testIssue3595, testNotParallel)
|
||
}
|
||
func testIssue3595(t *testing.T, mode testMode) {
|
||
runTimeSensitiveTest(t, []time.Duration{
|
||
1 * time.Millisecond,
|
||
5 * time.Millisecond,
|
||
10 * time.Millisecond,
|
||
50 * time.Millisecond,
|
||
100 * time.Millisecond,
|
||
500 * time.Millisecond,
|
||
time.Second,
|
||
5 * time.Second,
|
||
}, func(t *testing.T, timeout time.Duration) error {
|
||
SetRSTAvoidanceDelay(t, timeout)
|
||
t.Logf("set RST avoidance delay to %v", timeout)
|
||
|
||
const deniedMsg = "sorry, denied."
|
||
cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
Error(w, deniedMsg, StatusUnauthorized)
|
||
}))
|
||
// We need to close cst explicitly here so that in-flight server
|
||
// requests don't race with the call to SetRSTAvoidanceDelay for a retry.
|
||
defer cst.close()
|
||
ts := cst.ts
|
||
c := ts.Client()
|
||
|
||
res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
|
||
if err != nil {
|
||
return fmt.Errorf("Post: %v", err)
|
||
}
|
||
got, err := io.ReadAll(res.Body)
|
||
if err != nil {
|
||
return fmt.Errorf("Body ReadAll: %v", err)
|
||
}
|
||
t.Logf("server response:\n%s", got)
|
||
if !strings.Contains(string(got), deniedMsg) {
|
||
// If we got an RST packet too early, we should have seen an error
|
||
// from io.ReadAll, not a silently-truncated body.
|
||
t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
|
||
}
|
||
return nil
|
||
})
|
||
}
|
||
|
||
// From https://golang.org/issue/4454 ,
|
||
// "client fails to handle requests with no body and chunked encoding"
|
||
func TestChunkedNoContent(t *testing.T) { run(t, testChunkedNoContent) }
|
||
func testChunkedNoContent(t *testing.T, mode testMode) {
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
w.WriteHeader(StatusNoContent)
|
||
})).ts
|
||
|
||
c := ts.Client()
|
||
for _, closeBody := range []bool{true, false} {
|
||
const n = 4
|
||
for i := 1; i <= n; i++ {
|
||
res, err := c.Get(ts.URL)
|
||
if err != nil {
|
||
t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err)
|
||
} else {
|
||
if closeBody {
|
||
res.Body.Close()
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestTransportConcurrency(t *testing.T) {
|
||
run(t, testTransportConcurrency, testNotParallel, []testMode{http1Mode})
|
||
}
|
||
func testTransportConcurrency(t *testing.T, mode testMode) {
|
||
// Not parallel: uses global test hooks.
|
||
maxProcs, numReqs := 16, 500
|
||
if testing.Short() {
|
||
maxProcs, numReqs = 4, 50
|
||
}
|
||
defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
fmt.Fprintf(w, "%v", r.FormValue("echo"))
|
||
})).ts
|
||
|
||
var wg sync.WaitGroup
|
||
wg.Add(numReqs)
|
||
|
||
// Due to the Transport's "socket late binding" (see
|
||
// idleConnCh in transport.go), the numReqs HTTP requests
|
||
// below can finish with a dial still outstanding. To keep
|
||
// the leak checker happy, keep track of pending dials and
|
||
// wait for them to finish (and be closed or returned to the
|
||
// idle pool) before we close idle connections.
|
||
SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
|
||
defer SetPendingDialHooks(nil, nil)
|
||
|
||
c := ts.Client()
|
||
reqs := make(chan string)
|
||
defer close(reqs)
|
||
|
||
for i := 0; i < maxProcs*2; i++ {
|
||
go func() {
|
||
for req := range reqs {
|
||
res, err := c.Get(ts.URL + "/?echo=" + req)
|
||
if err != nil {
|
||
if runtime.GOOS == "netbsd" && strings.HasSuffix(err.Error(), ": connection reset by peer") {
|
||
// https://go.dev/issue/52168: this test was observed to fail with
|
||
// ECONNRESET errors in Dial on various netbsd builders.
|
||
t.Logf("error on req %s: %v", req, err)
|
||
t.Logf("(see https://go.dev/issue/52168)")
|
||
} else {
|
||
t.Errorf("error on req %s: %v", req, err)
|
||
}
|
||
wg.Done()
|
||
continue
|
||
}
|
||
all, err := io.ReadAll(res.Body)
|
||
if err != nil {
|
||
t.Errorf("read error on req %s: %v", req, err)
|
||
} else if string(all) != req {
|
||
t.Errorf("body of req %s = %q; want %q", req, all, req)
|
||
}
|
||
res.Body.Close()
|
||
wg.Done()
|
||
}
|
||
}()
|
||
}
|
||
for i := 0; i < numReqs; i++ {
|
||
reqs <- fmt.Sprintf("request-%d", i)
|
||
}
|
||
wg.Wait()
|
||
}
|
||
|
||
func TestIssue4191_InfiniteGetTimeout(t *testing.T) { run(t, testIssue4191_InfiniteGetTimeout) }
|
||
func testIssue4191_InfiniteGetTimeout(t *testing.T, mode testMode) {
|
||
mux := NewServeMux()
|
||
mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
|
||
io.Copy(w, neverEnding('a'))
|
||
})
|
||
ts := newClientServerTest(t, mode, mux).ts
|
||
|
||
connc := make(chan net.Conn, 1)
|
||
c := ts.Client()
|
||
c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
|
||
conn, err := net.Dial(n, addr)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
select {
|
||
case connc <- conn:
|
||
default:
|
||
}
|
||
return conn, nil
|
||
}
|
||
|
||
res, err := c.Get(ts.URL + "/get")
|
||
if err != nil {
|
||
t.Fatalf("Error issuing GET: %v", err)
|
||
}
|
||
defer res.Body.Close()
|
||
|
||
conn := <-connc
|
||
conn.SetDeadline(time.Now().Add(1 * time.Millisecond))
|
||
_, err = io.Copy(io.Discard, res.Body)
|
||
if err == nil {
|
||
t.Errorf("Unexpected successful copy")
|
||
}
|
||
}
|
||
|
||
func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
|
||
run(t, testIssue4191_InfiniteGetToPutTimeout, []testMode{http1Mode})
|
||
}
|
||
func testIssue4191_InfiniteGetToPutTimeout(t *testing.T, mode testMode) {
|
||
const debug = false
|
||
mux := NewServeMux()
|
||
mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
|
||
io.Copy(w, neverEnding('a'))
|
||
})
|
||
mux.HandleFunc("/put", func(w ResponseWriter, r *Request) {
|
||
defer r.Body.Close()
|
||
io.Copy(io.Discard, r.Body)
|
||
})
|
||
ts := newClientServerTest(t, mode, mux).ts
|
||
timeout := 100 * time.Millisecond
|
||
|
||
c := ts.Client()
|
||
c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
|
||
conn, err := net.Dial(n, addr)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
conn.SetDeadline(time.Now().Add(timeout))
|
||
if debug {
|
||
conn = NewLoggingConn("client", conn)
|
||
}
|
||
return conn, nil
|
||
}
|
||
|
||
getFailed := false
|
||
nRuns := 5
|
||
if testing.Short() {
|
||
nRuns = 1
|
||
}
|
||
for i := 0; i < nRuns; i++ {
|
||
if debug {
|
||
println("run", i+1, "of", nRuns)
|
||
}
|
||
sres, err := c.Get(ts.URL + "/get")
|
||
if err != nil {
|
||
if !getFailed {
|
||
// Make the timeout longer, once.
|
||
getFailed = true
|
||
t.Logf("increasing timeout")
|
||
i--
|
||
timeout *= 10
|
||
continue
|
||
}
|
||
t.Errorf("Error issuing GET: %v", err)
|
||
break
|
||
}
|
||
req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body)
|
||
_, err = c.Do(req)
|
||
if err == nil {
|
||
sres.Body.Close()
|
||
t.Errorf("Unexpected successful PUT")
|
||
break
|
||
}
|
||
sres.Body.Close()
|
||
}
|
||
if debug {
|
||
println("tests complete; waiting for handlers to finish")
|
||
}
|
||
ts.Close()
|
||
}
|
||
|
||
func TestTransportResponseHeaderTimeout(t *testing.T) { run(t, testTransportResponseHeaderTimeout) }
|
||
func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) {
|
||
if testing.Short() {
|
||
t.Skip("skipping timeout test in -short mode")
|
||
}
|
||
|
||
timeout := 2 * time.Millisecond
|
||
retry := true
|
||
for retry && !t.Failed() {
|
||
var srvWG sync.WaitGroup
|
||
inHandler := make(chan bool, 1)
|
||
mux := NewServeMux()
|
||
mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {
|
||
inHandler <- true
|
||
srvWG.Done()
|
||
})
|
||
mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
|
||
inHandler <- true
|
||
<-r.Context().Done()
|
||
srvWG.Done()
|
||
})
|
||
ts := newClientServerTest(t, mode, mux).ts
|
||
|
||
c := ts.Client()
|
||
c.Transport.(*Transport).ResponseHeaderTimeout = timeout
|
||
|
||
retry = false
|
||
srvWG.Add(3)
|
||
tests := []struct {
|
||
path string
|
||
wantTimeout bool
|
||
}{
|
||
{path: "/fast"},
|
||
{path: "/slow", wantTimeout: true},
|
||
{path: "/fast"},
|
||
}
|
||
for i, tt := range tests {
|
||
req, _ := NewRequest("GET", ts.URL+tt.path, nil)
|
||
req = req.WithT(t)
|
||
res, err := c.Do(req)
|
||
<-inHandler
|
||
if err != nil {
|
||
uerr, ok := err.(*url.Error)
|
||
if !ok {
|
||
t.Errorf("error is not a url.Error; got: %#v", err)
|
||
continue
|
||
}
|
||
nerr, ok := uerr.Err.(net.Error)
|
||
if !ok {
|
||
t.Errorf("error does not satisfy net.Error interface; got: %#v", err)
|
||
continue
|
||
}
|
||
if !nerr.Timeout() {
|
||
t.Errorf("want timeout error; got: %q", nerr)
|
||
continue
|
||
}
|
||
if !tt.wantTimeout {
|
||
if !retry {
|
||
// The timeout may be set too short. Retry with a longer one.
|
||
t.Logf("unexpected timeout for path %q after %v; retrying with longer timeout", tt.path, timeout)
|
||
timeout *= 2
|
||
retry = true
|
||
}
|
||
}
|
||
if !strings.Contains(err.Error(), "timeout awaiting response headers") {
|
||
t.Errorf("%d. unexpected error: %v", i, err)
|
||
}
|
||
continue
|
||
}
|
||
if tt.wantTimeout {
|
||
t.Errorf(`no error for path %q; expected "timeout awaiting response headers"`, tt.path)
|
||
continue
|
||
}
|
||
if res.StatusCode != 200 {
|
||
t.Errorf("%d for path %q status = %d; want 200", i, tt.path, res.StatusCode)
|
||
}
|
||
}
|
||
|
||
srvWG.Wait()
|
||
ts.Close()
|
||
}
|
||
}
|
||
|
||
// A cancelTest is a test of request cancellation.
|
||
type cancelTest struct {
|
||
mode testMode
|
||
newReq func(req *Request) *Request // prepare the request to cancel
|
||
cancel func(tr *Transport, req *Request) // cancel the request
|
||
checkErr func(when string, err error) // verify the expected error
|
||
}
|
||
|
||
// runCancelTestTransport uses Transport.CancelRequest.
|
||
func runCancelTestTransport(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
|
||
t.Run("TransportCancel", func(t *testing.T) {
|
||
f(t, cancelTest{
|
||
mode: mode,
|
||
newReq: func(req *Request) *Request {
|
||
return req
|
||
},
|
||
cancel: func(tr *Transport, req *Request) {
|
||
tr.CancelRequest(req)
|
||
},
|
||
checkErr: func(when string, err error) {
|
||
if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
|
||
t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
|
||
}
|
||
},
|
||
})
|
||
})
|
||
}
|
||
|
||
// runCancelTestChannel uses Request.Cancel.
|
||
func runCancelTestChannel(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
|
||
cancelc := make(chan struct{})
|
||
cancelOnce := sync.OnceFunc(func() { close(cancelc) })
|
||
f(t, cancelTest{
|
||
mode: mode,
|
||
newReq: func(req *Request) *Request {
|
||
req.Cancel = cancelc
|
||
return req
|
||
},
|
||
cancel: func(tr *Transport, req *Request) {
|
||
cancelOnce()
|
||
},
|
||
checkErr: func(when string, err error) {
|
||
if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
|
||
t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
|
||
}
|
||
},
|
||
})
|
||
}
|
||
|
||
// runCancelTestContext uses a request context.
|
||
func runCancelTestContext(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
f(t, cancelTest{
|
||
mode: mode,
|
||
newReq: func(req *Request) *Request {
|
||
return req.WithContext(ctx)
|
||
},
|
||
cancel: func(tr *Transport, req *Request) {
|
||
cancel()
|
||
},
|
||
checkErr: func(when string, err error) {
|
||
if !errors.Is(err, context.Canceled) {
|
||
t.Errorf("%v error = %v, want context.Canceled", when, err)
|
||
}
|
||
},
|
||
})
|
||
}
|
||
|
||
func runCancelTest(t *testing.T, f func(t *testing.T, test cancelTest), opts ...any) {
|
||
run(t, func(t *testing.T, mode testMode) {
|
||
if mode == http1Mode {
|
||
t.Run("TransportCancel", func(t *testing.T) {
|
||
runCancelTestTransport(t, mode, f)
|
||
})
|
||
}
|
||
t.Run("RequestCancel", func(t *testing.T) {
|
||
runCancelTestChannel(t, mode, f)
|
||
})
|
||
t.Run("ContextCancel", func(t *testing.T) {
|
||
runCancelTestContext(t, mode, f)
|
||
})
|
||
}, opts...)
|
||
}
|
||
|
||
func TestTransportCancelRequest(t *testing.T) {
|
||
runCancelTest(t, testTransportCancelRequest)
|
||
}
|
||
func testTransportCancelRequest(t *testing.T, test cancelTest) {
|
||
if testing.Short() {
|
||
t.Skip("skipping test in -short mode")
|
||
}
|
||
|
||
const msg = "Hello"
|
||
unblockc := make(chan bool)
|
||
ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
io.WriteString(w, msg)
|
||
w.(Flusher).Flush() // send headers and some body
|
||
<-unblockc
|
||
})).ts
|
||
defer close(unblockc)
|
||
|
||
c := ts.Client()
|
||
tr := c.Transport.(*Transport)
|
||
|
||
req, _ := NewRequest("GET", ts.URL, nil)
|
||
req = test.newReq(req)
|
||
res, err := c.Do(req)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
body := make([]byte, len(msg))
|
||
n, _ := io.ReadFull(res.Body, body)
|
||
if n != len(body) || !bytes.Equal(body, []byte(msg)) {
|
||
t.Errorf("Body = %q; want %q", body[:n], msg)
|
||
}
|
||
test.cancel(tr, req)
|
||
|
||
tail, err := io.ReadAll(res.Body)
|
||
res.Body.Close()
|
||
test.checkErr("Body.Read", err)
|
||
if len(tail) > 0 {
|
||
t.Errorf("Spurious bytes from Body.Read: %q", tail)
|
||
}
|
||
|
||
// Verify no outstanding requests after readLoop/writeLoop
|
||
// goroutines shut down.
|
||
waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
|
||
n := tr.NumPendingRequestsForTesting()
|
||
if n > 0 {
|
||
if d > 0 {
|
||
t.Logf("pending requests = %d after %v (want 0)", n, d)
|
||
}
|
||
return false
|
||
}
|
||
return true
|
||
})
|
||
}
|
||
|
||
func testTransportCancelRequestInDo(t *testing.T, test cancelTest, body io.Reader) {
|
||
if testing.Short() {
|
||
t.Skip("skipping test in -short mode")
|
||
}
|
||
unblockc := make(chan bool)
|
||
ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
<-unblockc
|
||
})).ts
|
||
defer close(unblockc)
|
||
|
||
c := ts.Client()
|
||
tr := c.Transport.(*Transport)
|
||
|
||
donec := make(chan bool)
|
||
req, _ := NewRequest("GET", ts.URL, body)
|
||
req = test.newReq(req)
|
||
go func() {
|
||
defer close(donec)
|
||
c.Do(req)
|
||
}()
|
||
|
||
unblockc <- true
|
||
waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
|
||
test.cancel(tr, req)
|
||
select {
|
||
case <-donec:
|
||
return true
|
||
default:
|
||
if d > 0 {
|
||
t.Logf("Do of canceled request has not returned after %v", d)
|
||
}
|
||
return false
|
||
}
|
||
})
|
||
}
|
||
|
||
func TestTransportCancelRequestInDo(t *testing.T) {
|
||
runCancelTest(t, func(t *testing.T, test cancelTest) {
|
||
testTransportCancelRequestInDo(t, test, nil)
|
||
})
|
||
}
|
||
|
||
func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
|
||
runCancelTest(t, func(t *testing.T, test cancelTest) {
|
||
testTransportCancelRequestInDo(t, test, bytes.NewBuffer([]byte{0}))
|
||
})
|
||
}
|
||
|
||
func TestTransportCancelRequestInDial(t *testing.T) {
|
||
runCancelTest(t, testTransportCancelRequestInDial)
|
||
}
|
||
func testTransportCancelRequestInDial(t *testing.T, test cancelTest) {
|
||
defer afterTest(t)
|
||
if testing.Short() {
|
||
t.Skip("skipping test in -short mode")
|
||
}
|
||
var logbuf strings.Builder
|
||
eventLog := log.New(&logbuf, "", 0)
|
||
|
||
unblockDial := make(chan bool)
|
||
defer close(unblockDial)
|
||
|
||
inDial := make(chan bool)
|
||
tr := &Transport{
|
||
Dial: func(network, addr string) (net.Conn, error) {
|
||
eventLog.Println("dial: blocking")
|
||
if !<-inDial {
|
||
return nil, errors.New("main Test goroutine exited")
|
||
}
|
||
<-unblockDial
|
||
return nil, errors.New("nope")
|
||
},
|
||
}
|
||
cl := &Client{Transport: tr}
|
||
gotres := make(chan bool)
|
||
req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
|
||
req = test.newReq(req)
|
||
go func() {
|
||
_, err := cl.Do(req)
|
||
eventLog.Printf("Get error = %v", err != nil)
|
||
test.checkErr("Get", err)
|
||
gotres <- true
|
||
}()
|
||
|
||
inDial <- true
|
||
|
||
eventLog.Printf("canceling")
|
||
test.cancel(tr, req)
|
||
test.cancel(tr, req) // used to panic on second call to Transport.Cancel
|
||
|
||
if d, ok := t.Deadline(); ok {
|
||
// When the test's deadline is about to expire, log the pending events for
|
||
// better debugging.
|
||
timeout := time.Until(d) * 19 / 20 // Allow 5% for cleanup.
|
||
timer := time.AfterFunc(timeout, func() {
|
||
panic(fmt.Sprintf("hang in %s. events are: %s", t.Name(), logbuf.String()))
|
||
})
|
||
defer timer.Stop()
|
||
}
|
||
<-gotres
|
||
|
||
got := logbuf.String()
|
||
want := `dial: blocking
|
||
canceling
|
||
Get error = true
|
||
`
|
||
if got != want {
|
||
t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
|
||
}
|
||
}
|
||
|
||
// Issue 51354
|
||
func TestTransportCancelRequestWithBody(t *testing.T) {
|
||
runCancelTest(t, testTransportCancelRequestWithBody)
|
||
}
|
||
func testTransportCancelRequestWithBody(t *testing.T, test cancelTest) {
|
||
if testing.Short() {
|
||
t.Skip("skipping test in -short mode")
|
||
}
|
||
|
||
const msg = "Hello"
|
||
unblockc := make(chan struct{})
|
||
ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
io.WriteString(w, msg)
|
||
w.(Flusher).Flush() // send headers and some body
|
||
<-unblockc
|
||
})).ts
|
||
defer close(unblockc)
|
||
|
||
c := ts.Client()
|
||
tr := c.Transport.(*Transport)
|
||
|
||
req, _ := NewRequest("POST", ts.URL, strings.NewReader("withbody"))
|
||
req = test.newReq(req)
|
||
|
||
res, err := c.Do(req)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
body := make([]byte, len(msg))
|
||
n, _ := io.ReadFull(res.Body, body)
|
||
if n != len(body) || !bytes.Equal(body, []byte(msg)) {
|
||
t.Errorf("Body = %q; want %q", body[:n], msg)
|
||
}
|
||
test.cancel(tr, req)
|
||
|
||
tail, err := io.ReadAll(res.Body)
|
||
res.Body.Close()
|
||
test.checkErr("Body.Read", err)
|
||
if len(tail) > 0 {
|
||
t.Errorf("Spurious bytes from Body.Read: %q", tail)
|
||
}
|
||
|
||
// Verify no outstanding requests after readLoop/writeLoop
|
||
// goroutines shut down.
|
||
waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
|
||
n := tr.NumPendingRequestsForTesting()
|
||
if n > 0 {
|
||
if d > 0 {
|
||
t.Logf("pending requests = %d after %v (want 0)", n, d)
|
||
}
|
||
return false
|
||
}
|
||
return true
|
||
})
|
||
}
|
||
|
||
func TestTransportCancelRequestBeforeDo(t *testing.T) {
|
||
// We can't cancel a request that hasn't started using Transport.CancelRequest.
|
||
run(t, func(t *testing.T, mode testMode) {
|
||
t.Run("RequestCancel", func(t *testing.T) {
|
||
runCancelTestChannel(t, mode, testTransportCancelRequestBeforeDo)
|
||
})
|
||
t.Run("ContextCancel", func(t *testing.T) {
|
||
runCancelTestContext(t, mode, testTransportCancelRequestBeforeDo)
|
||
})
|
||
})
|
||
}
|
||
func testTransportCancelRequestBeforeDo(t *testing.T, test cancelTest) {
|
||
unblockc := make(chan bool)
|
||
cst := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
<-unblockc
|
||
}))
|
||
defer close(unblockc)
|
||
|
||
c := cst.ts.Client()
|
||
|
||
req, _ := NewRequest("GET", cst.ts.URL, nil)
|
||
req = test.newReq(req)
|
||
test.cancel(cst.tr, req)
|
||
|
||
_, err := c.Do(req)
|
||
test.checkErr("Do", err)
|
||
}
|
||
|
||
// Issue 11020. The returned error message should be errRequestCanceled
|
||
func TestTransportCancelRequestBeforeResponseHeaders(t *testing.T) {
|
||
runCancelTest(t, testTransportCancelRequestBeforeResponseHeaders, []testMode{http1Mode})
|
||
}
|
||
func testTransportCancelRequestBeforeResponseHeaders(t *testing.T, test cancelTest) {
|
||
defer afterTest(t)
|
||
|
||
serverConnCh := make(chan net.Conn, 1)
|
||
tr := &Transport{
|
||
Dial: func(network, addr string) (net.Conn, error) {
|
||
cc, sc := net.Pipe()
|
||
serverConnCh <- sc
|
||
return cc, nil
|
||
},
|
||
}
|
||
defer tr.CloseIdleConnections()
|
||
errc := make(chan error, 1)
|
||
req, _ := NewRequest("GET", "http://example.com/", nil)
|
||
req = test.newReq(req)
|
||
go func() {
|
||
_, err := tr.RoundTrip(req)
|
||
errc <- err
|
||
}()
|
||
|
||
sc := <-serverConnCh
|
||
verb := make([]byte, 3)
|
||
if _, err := io.ReadFull(sc, verb); err != nil {
|
||
t.Errorf("Error reading HTTP verb from server: %v", err)
|
||
}
|
||
if string(verb) != "GET" {
|
||
t.Errorf("server received %q; want GET", verb)
|
||
}
|
||
defer sc.Close()
|
||
|
||
test.cancel(tr, req)
|
||
|
||
err := <-errc
|
||
if err == nil {
|
||
t.Fatalf("unexpected success from RoundTrip")
|
||
}
|
||
test.checkErr("RoundTrip", err)
|
||
}
|
||
|
||
// golang.org/issue/3672 -- Client can't close HTTP stream
|
||
// Calling Close on a Response.Body used to just read until EOF.
|
||
// Now it actually closes the TCP connection.
|
||
func TestTransportCloseResponseBody(t *testing.T) { run(t, testTransportCloseResponseBody) }
|
||
func testTransportCloseResponseBody(t *testing.T, mode testMode) {
|
||
writeErr := make(chan error, 1)
|
||
msg := []byte("young\n")
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
for {
|
||
_, err := w.Write(msg)
|
||
if err != nil {
|
||
writeErr <- err
|
||
return
|
||
}
|
||
w.(Flusher).Flush()
|
||
}
|
||
})).ts
|
||
|
||
c := ts.Client()
|
||
tr := c.Transport.(*Transport)
|
||
|
||
req, _ := NewRequest("GET", ts.URL, nil)
|
||
defer tr.CancelRequest(req)
|
||
|
||
res, err := c.Do(req)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
const repeats = 3
|
||
buf := make([]byte, len(msg)*repeats)
|
||
want := bytes.Repeat(msg, repeats)
|
||
|
||
_, err = io.ReadFull(res.Body, buf)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
if !bytes.Equal(buf, want) {
|
||
t.Fatalf("read %q; want %q", buf, want)
|
||
}
|
||
|
||
if err := res.Body.Close(); err != nil {
|
||
t.Errorf("Close = %v", err)
|
||
}
|
||
|
||
if err := <-writeErr; err == nil {
|
||
t.Errorf("expected non-nil write error")
|
||
}
|
||
}
|
||
|
||
type fooProto struct{}
|
||
|
||
func (fooProto) RoundTrip(req *Request) (*Response, error) {
|
||
res := &Response{
|
||
Status: "200 OK",
|
||
StatusCode: 200,
|
||
Header: make(Header),
|
||
Body: io.NopCloser(strings.NewReader("You wanted " + req.URL.String())),
|
||
}
|
||
return res, nil
|
||
}
|
||
|
||
func TestTransportAltProto(t *testing.T) {
|
||
defer afterTest(t)
|
||
tr := &Transport{}
|
||
c := &Client{Transport: tr}
|
||
tr.RegisterProtocol("foo", fooProto{})
|
||
res, err := c.Get("foo://bar.com/path")
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
bodyb, err := io.ReadAll(res.Body)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
body := string(bodyb)
|
||
if e := "You wanted foo://bar.com/path"; body != e {
|
||
t.Errorf("got response %q, want %q", body, e)
|
||
}
|
||
}
|
||
|
||
func TestTransportNoHost(t *testing.T) {
|
||
defer afterTest(t)
|
||
tr := &Transport{}
|
||
_, err := tr.RoundTrip(&Request{
|
||
Header: make(Header),
|
||
URL: &url.URL{
|
||
Scheme: "http",
|
||
},
|
||
})
|
||
want := "http: no Host in request URL"
|
||
if got := fmt.Sprint(err); got != want {
|
||
t.Errorf("error = %v; want %q", err, want)
|
||
}
|
||
}
|
||
|
||
// Issue 13311
|
||
func TestTransportEmptyMethod(t *testing.T) {
|
||
req, _ := NewRequest("GET", "http://foo.com/", nil)
|
||
req.Method = "" // docs say "For client requests an empty string means GET"
|
||
got, err := httputil.DumpRequestOut(req, false) // DumpRequestOut uses Transport
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
if !strings.Contains(string(got), "GET ") {
|
||
t.Fatalf("expected substring 'GET '; got: %s", got)
|
||
}
|
||
}
|
||
|
||
func TestTransportSocketLateBinding(t *testing.T) { run(t, testTransportSocketLateBinding) }
|
||
func testTransportSocketLateBinding(t *testing.T, mode testMode) {
|
||
mux := NewServeMux()
|
||
fooGate := make(chan bool, 1)
|
||
mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) {
|
||
w.Header().Set("foo-ipport", r.RemoteAddr)
|
||
w.(Flusher).Flush()
|
||
<-fooGate
|
||
})
|
||
mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) {
|
||
w.Header().Set("bar-ipport", r.RemoteAddr)
|
||
})
|
||
ts := newClientServerTest(t, mode, mux).ts
|
||
|
||
dialGate := make(chan bool, 1)
|
||
dialing := make(chan bool)
|
||
c := ts.Client()
|
||
c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
|
||
for {
|
||
select {
|
||
case ok := <-dialGate:
|
||
if !ok {
|
||
return nil, errors.New("manually closed")
|
||
}
|
||
return net.Dial(n, addr)
|
||
case dialing <- true:
|
||
}
|
||
}
|
||
}
|
||
defer close(dialGate)
|
||
|
||
dialGate <- true // only allow one dial
|
||
fooRes, err := c.Get(ts.URL + "/foo")
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
fooAddr := fooRes.Header.Get("foo-ipport")
|
||
if fooAddr == "" {
|
||
t.Fatal("No addr on /foo request")
|
||
}
|
||
|
||
fooDone := make(chan struct{})
|
||
go func() {
|
||
// We know that the foo Dial completed and reached the handler because we
|
||
// read its header. Wait for the bar request to block in Dial, then
|
||
// let the foo response finish so we can use its connection for /bar.
|
||
|
||
if mode == http2Mode {
|
||
// In HTTP/2 mode, the second Dial won't happen because the protocol
|
||
// multiplexes the streams by default. Just sleep for an arbitrary time;
|
||
// the test should pass regardless of how far the bar request gets by this
|
||
// point.
|
||
select {
|
||
case <-dialing:
|
||
t.Errorf("unexpected second Dial in HTTP/2 mode")
|
||
case <-time.After(10 * time.Millisecond):
|
||
}
|
||
} else {
|
||
<-dialing
|
||
}
|
||
fooGate <- true
|
||
io.Copy(io.Discard, fooRes.Body)
|
||
fooRes.Body.Close()
|
||
close(fooDone)
|
||
}()
|
||
defer func() {
|
||
<-fooDone
|
||
}()
|
||
|
||
barRes, err := c.Get(ts.URL + "/bar")
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
barAddr := barRes.Header.Get("bar-ipport")
|
||
if barAddr != fooAddr {
|
||
t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr)
|
||
}
|
||
barRes.Body.Close()
|
||
}
|
||
|
||
// Issue 2184
|
||
func TestTransportReading100Continue(t *testing.T) {
|
||
defer afterTest(t)
|
||
|
||
const numReqs = 5
|
||
reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) }
|
||
reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) }
|
||
|
||
send100Response := func(w *io.PipeWriter, r *io.PipeReader) {
|
||
defer w.Close()
|
||
defer r.Close()
|
||
br := bufio.NewReader(r)
|
||
n := 0
|
||
for {
|
||
n++
|
||
req, err := ReadRequest(br)
|
||
if err == io.EOF {
|
||
return
|
||
}
|
||
if err != nil {
|
||
t.Error(err)
|
||
return
|
||
}
|
||
slurp, err := io.ReadAll(req.Body)
|
||
if err != nil {
|
||
t.Errorf("Server request body slurp: %v", err)
|
||
return
|
||
}
|
||
id := req.Header.Get("Request-Id")
|
||
resCode := req.Header.Get("X-Want-Response-Code")
|
||
if resCode == "" {
|
||
resCode = "100 Continue"
|
||
if string(slurp) != reqBody(n) {
|
||
t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n))
|
||
}
|
||
}
|
||
body := fmt.Sprintf("Response number %d", n)
|
||
v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s
|
||
Date: Thu, 28 Feb 2013 17:55:41 GMT
|
||
|
||
HTTP/1.1 200 OK
|
||
Content-Type: text/html
|
||
Echo-Request-Id: %s
|
||
Content-Length: %d
|
||
|
||
%s`, resCode, id, len(body), body), "\n", "\r\n", -1))
|
||
w.Write(v)
|
||
if id == reqID(numReqs) {
|
||
return
|
||
}
|
||
}
|
||
|
||
}
|
||
|
||
tr := &Transport{
|
||
Dial: func(n, addr string) (net.Conn, error) {
|
||
sr, sw := io.Pipe() // server read/write
|
||
cr, cw := io.Pipe() // client read/write
|
||
conn := &rwTestConn{
|
||
Reader: cr,
|
||
Writer: sw,
|
||
closeFunc: func() error {
|
||
sw.Close()
|
||
cw.Close()
|
||
return nil
|
||
},
|
||
}
|
||
go send100Response(cw, sr)
|
||
return conn, nil
|
||
},
|
||
DisableKeepAlives: false,
|
||
}
|
||
defer tr.CloseIdleConnections()
|
||
c := &Client{Transport: tr}
|
||
|
||
testResponse := func(req *Request, name string, wantCode int) {
|
||
t.Helper()
|
||
res, err := c.Do(req)
|
||
if err != nil {
|
||
t.Fatalf("%s: Do: %v", name, err)
|
||
}
|
||
if res.StatusCode != wantCode {
|
||
t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode)
|
||
}
|
||
if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack {
|
||
t.Errorf("%s: response id %q != request id %q", name, idBack, id)
|
||
}
|
||
_, err = io.ReadAll(res.Body)
|
||
if err != nil {
|
||
t.Fatalf("%s: Slurp error: %v", name, err)
|
||
}
|
||
}
|
||
|
||
// Few 100 responses, making sure we're not off-by-one.
|
||
for i := 1; i <= numReqs; i++ {
|
||
req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i)))
|
||
req.Header.Set("Request-Id", reqID(i))
|
||
testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200)
|
||
}
|
||
}
|
||
|
||
// Issue 17739: the HTTP client must ignore any unknown 1xx
|
||
// informational responses before the actual response.
|
||
func TestTransportIgnore1xxResponses(t *testing.T) {
|
||
run(t, testTransportIgnore1xxResponses, []testMode{http1Mode})
|
||
}
|
||
func testTransportIgnore1xxResponses(t *testing.T, mode testMode) {
|
||
cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
conn, buf, _ := w.(Hijacker).Hijack()
|
||
buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello"))
|
||
buf.Flush()
|
||
conn.Close()
|
||
}))
|
||
cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway
|
||
|
||
var got strings.Builder
|
||
|
||
req, _ := NewRequest("GET", cst.ts.URL, nil)
|
||
req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
|
||
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
|
||
fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header)
|
||
return nil
|
||
},
|
||
}))
|
||
res, err := cst.c.Do(req)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
defer res.Body.Close()
|
||
|
||
res.Write(&got)
|
||
want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello"
|
||
if got.String() != want {
|
||
t.Errorf(" got: %q\nwant: %q\n", got.String(), want)
|
||
}
|
||
}
|
||
|
||
func TestTransportLimits1xxResponses(t *testing.T) { run(t, testTransportLimits1xxResponses) }
|
||
func testTransportLimits1xxResponses(t *testing.T, mode testMode) {
|
||
cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
w.Header().Add("X-Header", strings.Repeat("a", 100))
|
||
for i := 0; i < 10; i++ {
|
||
w.WriteHeader(123)
|
||
}
|
||
w.WriteHeader(204)
|
||
}))
|
||
cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway
|
||
cst.tr.MaxResponseHeaderBytes = 1000
|
||
|
||
res, err := cst.c.Get(cst.ts.URL)
|
||
if err == nil {
|
||
res.Body.Close()
|
||
t.Fatalf("RoundTrip succeeded; want error")
|
||
}
|
||
for _, want := range []string{
|
||
"response headers exceeded",
|
||
"too many 1xx",
|
||
} {
|
||
if strings.Contains(err.Error(), want) {
|
||
return
|
||
}
|
||
}
|
||
t.Errorf(`got error %q; want "response headers exceeded" or "too many 1xx"`, err)
|
||
}
|
||
|
||
func TestTransportDoesNotLimitDelivered1xxResponses(t *testing.T) {
|
||
run(t, testTransportDoesNotLimitDelivered1xxResponses)
|
||
}
|
||
func testTransportDoesNotLimitDelivered1xxResponses(t *testing.T, mode testMode) {
|
||
if mode == http2Mode {
|
||
t.Skip("skip until x/net/http2 updated")
|
||
}
|
||
const num1xx = 10
|
||
cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
w.Header().Add("X-Header", strings.Repeat("a", 100))
|
||
for i := 0; i < 10; i++ {
|
||
w.WriteHeader(123)
|
||
}
|
||
w.WriteHeader(204)
|
||
}))
|
||
cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway
|
||
cst.tr.MaxResponseHeaderBytes = 1000
|
||
|
||
got1xx := 0
|
||
ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
|
||
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
|
||
got1xx++
|
||
return nil
|
||
},
|
||
})
|
||
req, _ := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
|
||
res, err := cst.c.Do(req)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
res.Body.Close()
|
||
if got1xx != num1xx {
|
||
t.Errorf("Got %v 1xx responses, want %x", got1xx, num1xx)
|
||
}
|
||
}
|
||
|
||
// Issue 26161: the HTTP client must treat 101 responses
|
||
// as the final response.
|
||
func TestTransportTreat101Terminal(t *testing.T) {
|
||
run(t, testTransportTreat101Terminal, []testMode{http1Mode})
|
||
}
|
||
func testTransportTreat101Terminal(t *testing.T, mode testMode) {
|
||
cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
conn, buf, _ := w.(Hijacker).Hijack()
|
||
buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n"))
|
||
buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
|
||
buf.Flush()
|
||
conn.Close()
|
||
}))
|
||
res, err := cst.c.Get(cst.ts.URL)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
defer res.Body.Close()
|
||
if res.StatusCode != StatusSwitchingProtocols {
|
||
t.Errorf("StatusCode = %v; want 101 Switching Protocols", res.StatusCode)
|
||
}
|
||
}
|
||
|
||
type proxyFromEnvTest struct {
|
||
req string // URL to fetch; blank means "http://example.com"
|
||
|
||
env string // HTTP_PROXY
|
||
httpsenv string // HTTPS_PROXY
|
||
noenv string // NO_PROXY
|
||
reqmeth string // REQUEST_METHOD
|
||
|
||
want string
|
||
wanterr error
|
||
}
|
||
|
||
func (t proxyFromEnvTest) String() string {
|
||
var buf strings.Builder
|
||
space := func() {
|
||
if buf.Len() > 0 {
|
||
buf.WriteByte(' ')
|
||
}
|
||
}
|
||
if t.env != "" {
|
||
fmt.Fprintf(&buf, "http_proxy=%q", t.env)
|
||
}
|
||
if t.httpsenv != "" {
|
||
space()
|
||
fmt.Fprintf(&buf, "https_proxy=%q", t.httpsenv)
|
||
}
|
||
if t.noenv != "" {
|
||
space()
|
||
fmt.Fprintf(&buf, "no_proxy=%q", t.noenv)
|
||
}
|
||
if t.reqmeth != "" {
|
||
space()
|
||
fmt.Fprintf(&buf, "request_method=%q", t.reqmeth)
|
||
}
|
||
req := "http://example.com"
|
||
if t.req != "" {
|
||
req = t.req
|
||
}
|
||
space()
|
||
fmt.Fprintf(&buf, "req=%q", req)
|
||
return strings.TrimSpace(buf.String())
|
||
}
|
||
|
||
var proxyFromEnvTests = []proxyFromEnvTest{
|
||
{env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"},
|
||
{env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"},
|
||
{env: "cache.corp.example.com", want: "http://cache.corp.example.com"},
|
||
{env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"},
|
||
{env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"},
|
||
{env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"},
|
||
{env: "socks5://127.0.0.1", want: "socks5://127.0.0.1"},
|
||
{env: "socks5h://127.0.0.1", want: "socks5h://127.0.0.1"},
|
||
|
||
// Don't use secure for http
|
||
{req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"},
|
||
// Use secure for https.
|
||
{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"},
|
||
{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"},
|
||
|
||
// Issue 16405: don't use HTTP_PROXY in a CGI environment,
|
||
// where HTTP_PROXY can be attacker-controlled.
|
||
{env: "http://10.1.2.3:8080", reqmeth: "POST",
|
||
want: "<nil>",
|
||
wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")},
|
||
|
||
{want: "<nil>"},
|
||
|
||
{noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
|
||
{noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
|
||
{noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
|
||
{noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"},
|
||
{noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
|
||
}
|
||
|
||
func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func(req *Request) (*url.URL, error)) {
|
||
t.Helper()
|
||
reqURL := tt.req
|
||
if reqURL == "" {
|
||
reqURL = "http://example.com"
|
||
}
|
||
req, _ := NewRequest("GET", reqURL, nil)
|
||
url, err := proxyForRequest(req)
|
||
if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e {
|
||
t.Errorf("%v: got error = %q, want %q", tt, g, e)
|
||
return
|
||
}
|
||
if got := fmt.Sprintf("%s", url); got != tt.want {
|
||
t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want)
|
||
}
|
||
}
|
||
|
||
func TestProxyFromEnvironment(t *testing.T) {
|
||
ResetProxyEnv()
|
||
defer ResetProxyEnv()
|
||
for _, tt := range proxyFromEnvTests {
|
||
testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
|
||
os.Setenv("HTTP_PROXY", tt.env)
|
||
os.Setenv("HTTPS_PROXY", tt.httpsenv)
|
||
os.Setenv("NO_PROXY", tt.noenv)
|
||
os.Setenv("REQUEST_METHOD", tt.reqmeth)
|
||
ResetCachedEnvironment()
|
||
return ProxyFromEnvironment(req)
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestProxyFromEnvironmentLowerCase(t *testing.T) {
|
||
ResetProxyEnv()
|
||
defer ResetProxyEnv()
|
||
for _, tt := range proxyFromEnvTests {
|
||
testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
|
||
os.Setenv("http_proxy", tt.env)
|
||
os.Setenv("https_proxy", tt.httpsenv)
|
||
os.Setenv("no_proxy", tt.noenv)
|
||
os.Setenv("REQUEST_METHOD", tt.reqmeth)
|
||
ResetCachedEnvironment()
|
||
return ProxyFromEnvironment(req)
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestIdleConnChannelLeak(t *testing.T) {
|
||
run(t, testIdleConnChannelLeak, []testMode{http1Mode}, testNotParallel)
|
||
}
|
||
func testIdleConnChannelLeak(t *testing.T, mode testMode) {
|
||
// Not parallel: uses global test hooks.
|
||
var mu sync.Mutex
|
||
var n int
|
||
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
mu.Lock()
|
||
n++
|
||
mu.Unlock()
|
||
})).ts
|
||
|
||
const nReqs = 5
|
||
didRead := make(chan bool, nReqs)
|
||
SetReadLoopBeforeNextReadHook(func() { didRead <- true })
|
||
defer SetReadLoopBeforeNextReadHook(nil)
|
||
|
||
c := ts.Client()
|
||
tr := c.Transport.(*Transport)
|
||
tr.Dial = func(netw, addr string) (net.Conn, error) {
|
||
return net.Dial(netw, ts.Listener.Addr().String())
|
||
}
|
||
|
||
// First, without keep-alives.
|
||
for _, disableKeep := range []bool{true, false} {
|
||
tr.DisableKeepAlives = disableKeep
|
||
for i := 0; i < nReqs; i++ {
|
||
_, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i))
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
// Note: no res.Body.Close is needed here, since the
|
||
// response Content-Length is zero. Perhaps the test
|
||
// should be more explicit and use a HEAD, but tests
|
||
// elsewhere guarantee that zero byte responses generate
|
||
// a "Content-Length: 0" instead of chunking.
|
||
}
|
||
|
||
// At this point, each of the 5 Transport.readLoop goroutines
|
||
// are scheduling noting that there are no response bodies (see
|
||
// earlier comment), and are then calling putIdleConn, which
|
||
// decrements this count. Usually that happens quickly, which is
|
||
// why this test has seemed to work for ages. But it's still
|
||
// racey: we have wait for them to finish first. See Issue 10427
|
||
for i := 0; i < nReqs; i++ {
|
||
<-didRead
|
||
}
|
||
|
||
if got := tr.IdleConnWaitMapSizeForTesting(); got != 0 {
|
||
t.Fatalf("for DisableKeepAlives = %v, map size = %d; want 0", disableKeep, got)
|
||
}
|
||
}
|
||
}
|
||
|
||
// Verify the status quo: that the Client.Post function coerces its
|
||
// body into a ReadCloser if it's a Closer, and that the Transport
|
||
// then closes it.
|
||
func TestTransportClosesRequestBody(t *testing.T) {
|
||
run(t, testTransportClosesRequestBody, []testMode{http1Mode})
|
||
}
|
||
func testTransportClosesRequestBody(t *testing.T, mode testMode) {
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
io.Copy(io.Discard, r.Body)
|
||
})).ts
|
||
|
||
c := ts.Client()
|
||
|
||
closes := 0
|
||
|
||
res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
res.Body.Close()
|
||
if closes != 1 {
|
||
t.Errorf("closes = %d; want 1", closes)
|
||
}
|
||
}
|
||
|
||
func TestTransportTLSHandshakeTimeout(t *testing.T) {
|
||
defer afterTest(t)
|
||
if testing.Short() {
|
||
t.Skip("skipping in short mode")
|
||
}
|
||
ln := newLocalListener(t)
|
||
defer ln.Close()
|
||
testdonec := make(chan struct{})
|
||
defer close(testdonec)
|
||
|
||
go func() {
|
||
c, err := ln.Accept()
|
||
if err != nil {
|
||
t.Error(err)
|
||
return
|
||
}
|
||
<-testdonec
|
||
c.Close()
|
||
}()
|
||
|
||
tr := &Transport{
|
||
Dial: func(_, _ string) (net.Conn, error) {
|
||
return net.Dial("tcp", ln.Addr().String())
|
||
},
|
||
TLSHandshakeTimeout: 250 * time.Millisecond,
|
||
}
|
||
cl := &Client{Transport: tr}
|
||
_, err := cl.Get("https://dummy.tld/")
|
||
if err == nil {
|
||
t.Error("expected error")
|
||
return
|
||
}
|
||
ue, ok := err.(*url.Error)
|
||
if !ok {
|
||
t.Errorf("expected url.Error; got %#v", err)
|
||
return
|
||
}
|
||
ne, ok := ue.Err.(net.Error)
|
||
if !ok {
|
||
t.Errorf("expected net.Error; got %#v", err)
|
||
return
|
||
}
|
||
if !ne.Timeout() {
|
||
t.Errorf("expected timeout error; got %v", err)
|
||
}
|
||
if !strings.Contains(err.Error(), "handshake timeout") {
|
||
t.Errorf("expected 'handshake timeout' in error; got %v", err)
|
||
}
|
||
}
|
||
|
||
// Trying to repro golang.org/issue/3514
|
||
func TestTLSServerClosesConnection(t *testing.T) {
|
||
run(t, testTLSServerClosesConnection, []testMode{https1Mode})
|
||
}
|
||
func testTLSServerClosesConnection(t *testing.T, mode testMode) {
|
||
closedc := make(chan bool, 1)
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
if strings.Contains(r.URL.Path, "/keep-alive-then-die") {
|
||
conn, _, _ := w.(Hijacker).Hijack()
|
||
conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))
|
||
conn.Close()
|
||
closedc <- true
|
||
return
|
||
}
|
||
fmt.Fprintf(w, "hello")
|
||
})).ts
|
||
|
||
c := ts.Client()
|
||
tr := c.Transport.(*Transport)
|
||
|
||
var nSuccess = 0
|
||
var errs []error
|
||
const trials = 20
|
||
for i := 0; i < trials; i++ {
|
||
tr.CloseIdleConnections()
|
||
res, err := c.Get(ts.URL + "/keep-alive-then-die")
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
<-closedc
|
||
slurp, err := io.ReadAll(res.Body)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
if string(slurp) != "foo" {
|
||
t.Errorf("Got %q, want foo", slurp)
|
||
}
|
||
|
||
// Now try again and see if we successfully
|
||
// pick a new connection.
|
||
res, err = c.Get(ts.URL + "/")
|
||
if err != nil {
|
||
errs = append(errs, err)
|
||
continue
|
||
}
|
||
slurp, err = io.ReadAll(res.Body)
|
||
if err != nil {
|
||
errs = append(errs, err)
|
||
continue
|
||
}
|
||
nSuccess++
|
||
}
|
||
if nSuccess > 0 {
|
||
t.Logf("successes = %d of %d", nSuccess, trials)
|
||
} else {
|
||
t.Errorf("All runs failed:")
|
||
}
|
||
for _, err := range errs {
|
||
t.Logf(" err: %v", err)
|
||
}
|
||
}
|
||
|
||
// byteFromChanReader is an io.Reader that reads a single byte at a
|
||
// time from the channel. When the channel is closed, the reader
|
||
// returns io.EOF.
|
||
type byteFromChanReader chan byte
|
||
|
||
func (c byteFromChanReader) Read(p []byte) (n int, err error) {
|
||
if len(p) == 0 {
|
||
return
|
||
}
|
||
b, ok := <-c
|
||
if !ok {
|
||
return 0, io.EOF
|
||
}
|
||
p[0] = b
|
||
return 1, nil
|
||
}
|
||
|
||
// Verifies that the Transport doesn't reuse a connection in the case
|
||
// where the server replies before the request has been fully
|
||
// written. We still honor that reply (see TestIssue3595), but don't
|
||
// send future requests on the connection because it's then in a
|
||
// questionable state.
|
||
// golang.org/issue/7569
|
||
func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
|
||
run(t, testTransportNoReuseAfterEarlyResponse, []testMode{http1Mode}, testNotParallel)
|
||
}
|
||
func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) {
|
||
defer func(d time.Duration) {
|
||
*MaxWriteWaitBeforeConnReuse = d
|
||
}(*MaxWriteWaitBeforeConnReuse)
|
||
*MaxWriteWaitBeforeConnReuse = 10 * time.Millisecond
|
||
var sconn struct {
|
||
sync.Mutex
|
||
c net.Conn
|
||
}
|
||
var getOkay bool
|
||
var copying sync.WaitGroup
|
||
closeConn := func() {
|
||
sconn.Lock()
|
||
defer sconn.Unlock()
|
||
if sconn.c != nil {
|
||
sconn.c.Close()
|
||
sconn.c = nil
|
||
if !getOkay {
|
||
t.Logf("Closed server connection")
|
||
}
|
||
}
|
||
}
|
||
defer func() {
|
||
closeConn()
|
||
copying.Wait()
|
||
}()
|
||
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
if r.Method == "GET" {
|
||
io.WriteString(w, "bar")
|
||
return
|
||
}
|
||
conn, _, _ := w.(Hijacker).Hijack()
|
||
sconn.Lock()
|
||
sconn.c = conn
|
||
sconn.Unlock()
|
||
conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive
|
||
|
||
copying.Add(1)
|
||
go func() {
|
||
io.Copy(io.Discard, conn)
|
||
copying.Done()
|
||
}()
|
||
})).ts
|
||
c := ts.Client()
|
||
|
||
const bodySize = 256 << 10
|
||
finalBit := make(byteFromChanReader, 1)
|
||
req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit))
|
||
req.ContentLength = bodySize
|
||
res, err := c.Do(req)
|
||
if err := wantBody(res, err, "foo"); err != nil {
|
||
t.Errorf("POST response: %v", err)
|
||
}
|
||
|
||
res, err = c.Get(ts.URL)
|
||
if err := wantBody(res, err, "bar"); err != nil {
|
||
t.Errorf("GET response: %v", err)
|
||
return
|
||
}
|
||
getOkay = true // suppress test noise
|
||
finalBit <- 'x' // unblock the writeloop of the first Post
|
||
close(finalBit)
|
||
}
|
||
|
||
// Tests that we don't leak Transport persistConn.readLoop goroutines
|
||
// when a server hangs up immediately after saying it would keep-alive.
|
||
func TestTransportIssue10457(t *testing.T) { run(t, testTransportIssue10457, []testMode{http1Mode}) }
|
||
func testTransportIssue10457(t *testing.T, mode testMode) {
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
// Send a response with no body, keep-alive
|
||
// (implicit), and then lie and immediately close the
|
||
// connection. This forces the Transport's readLoop to
|
||
// immediately Peek an io.EOF and get to the point
|
||
// that used to hang.
|
||
conn, _, _ := w.(Hijacker).Hijack()
|
||
conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n")) // keep-alive
|
||
conn.Close()
|
||
})).ts
|
||
c := ts.Client()
|
||
|
||
res, err := c.Get(ts.URL)
|
||
if err != nil {
|
||
t.Fatalf("Get: %v", err)
|
||
}
|
||
defer res.Body.Close()
|
||
|
||
// Just a sanity check that we at least get the response. The real
|
||
// test here is that the "defer afterTest" above doesn't find any
|
||
// leaked goroutines.
|
||
if got, want := res.Header.Get("Foo"), "Bar"; got != want {
|
||
t.Errorf("Foo header = %q; want %q", got, want)
|
||
}
|
||
}
|
||
|
||
type closerFunc func() error
|
||
|
||
func (f closerFunc) Close() error { return f() }
|
||
|
||
type writerFuncConn struct {
|
||
net.Conn
|
||
write func(p []byte) (n int, err error)
|
||
}
|
||
|
||
func (c writerFuncConn) Write(p []byte) (n int, err error) { return c.write(p) }
|
||
|
||
// Issues 4677, 18241, and 17844. If we try to reuse a connection that the
|
||
// server is in the process of closing, we may end up successfully writing out
|
||
// our request (or a portion of our request) only to find a connection error
|
||
// when we try to read from (or finish writing to) the socket.
|
||
//
|
||
// NOTE: we resend a request only if:
|
||
// - we reused a keep-alive connection
|
||
// - we haven't yet received any header data
|
||
// - either we wrote no bytes to the server, or the request is idempotent
|
||
//
|
||
// This automatically prevents an infinite resend loop because we'll run out of
|
||
// the cached keep-alive connections eventually.
|
||
func TestRetryRequestsOnError(t *testing.T) {
|
||
run(t, testRetryRequestsOnError, testNotParallel, []testMode{http1Mode})
|
||
}
|
||
func testRetryRequestsOnError(t *testing.T, mode testMode) {
|
||
newRequest := func(method, urlStr string, body io.Reader) *Request {
|
||
req, err := NewRequest(method, urlStr, body)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
return req
|
||
}
|
||
|
||
testCases := []struct {
|
||
name string
|
||
failureN int
|
||
failureErr error
|
||
// Note that we can't just re-use the Request object across calls to c.Do
|
||
// because we need to rewind Body between calls. (GetBody is only used to
|
||
// rewind Body on failure and redirects, not just because it's done.)
|
||
req func() *Request
|
||
reqString string
|
||
}{
|
||
{
|
||
name: "IdempotentNoBodySomeWritten",
|
||
// Believe that we've written some bytes to the server, so we know we're
|
||
// not just in the "retry when no bytes sent" case".
|
||
failureN: 1,
|
||
// Use the specific error that shouldRetryRequest looks for with idempotent requests.
|
||
failureErr: ExportErrServerClosedIdle,
|
||
req: func() *Request {
|
||
return newRequest("GET", "http://fake.golang", nil)
|
||
},
|
||
reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
|
||
},
|
||
{
|
||
name: "IdempotentGetBodySomeWritten",
|
||
// Believe that we've written some bytes to the server, so we know we're
|
||
// not just in the "retry when no bytes sent" case".
|
||
failureN: 1,
|
||
// Use the specific error that shouldRetryRequest looks for with idempotent requests.
|
||
failureErr: ExportErrServerClosedIdle,
|
||
req: func() *Request {
|
||
return newRequest("GET", "http://fake.golang", strings.NewReader("foo\n"))
|
||
},
|
||
reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
|
||
},
|
||
{
|
||
name: "NothingWrittenNoBody",
|
||
// It's key that we return 0 here -- that's what enables Transport to know
|
||
// that nothing was written, even though this is a non-idempotent request.
|
||
failureN: 0,
|
||
failureErr: errors.New("second write fails"),
|
||
req: func() *Request {
|
||
return newRequest("DELETE", "http://fake.golang", nil)
|
||
},
|
||
reqString: `DELETE / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
|
||
},
|
||
{
|
||
name: "NothingWrittenGetBody",
|
||
// It's key that we return 0 here -- that's what enables Transport to know
|
||
// that nothing was written, even though this is a non-idempotent request.
|
||
failureN: 0,
|
||
failureErr: errors.New("second write fails"),
|
||
// Note that NewRequest will set up GetBody for strings.Reader, which is
|
||
// required for the retry to occur
|
||
req: func() *Request {
|
||
return newRequest("POST", "http://fake.golang", strings.NewReader("foo\n"))
|
||
},
|
||
reqString: `POST / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
|
||
},
|
||
}
|
||
|
||
for _, tc := range testCases {
|
||
t.Run(tc.name, func(t *testing.T) {
|
||
var (
|
||
mu sync.Mutex
|
||
logbuf strings.Builder
|
||
)
|
||
logf := func(format string, args ...any) {
|
||
mu.Lock()
|
||
defer mu.Unlock()
|
||
fmt.Fprintf(&logbuf, format, args...)
|
||
logbuf.WriteByte('\n')
|
||
}
|
||
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
logf("Handler")
|
||
w.Header().Set("X-Status", "ok")
|
||
})).ts
|
||
|
||
var writeNumAtomic int32
|
||
c := ts.Client()
|
||
c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) {
|
||
logf("Dial")
|
||
c, err := net.Dial(network, ts.Listener.Addr().String())
|
||
if err != nil {
|
||
logf("Dial error: %v", err)
|
||
return nil, err
|
||
}
|
||
return &writerFuncConn{
|
||
Conn: c,
|
||
write: func(p []byte) (n int, err error) {
|
||
if atomic.AddInt32(&writeNumAtomic, 1) == 2 {
|
||
logf("intentional write failure")
|
||
return tc.failureN, tc.failureErr
|
||
}
|
||
logf("Write(%q)", p)
|
||
return c.Write(p)
|
||
},
|
||
}, nil
|
||
}
|
||
|
||
SetRoundTripRetried(func() {
|
||
logf("Retried.")
|
||
})
|
||
defer SetRoundTripRetried(nil)
|
||
|
||
for i := 0; i < 3; i++ {
|
||
t0 := time.Now()
|
||
req := tc.req()
|
||
res, err := c.Do(req)
|
||
if err != nil {
|
||
if time.Since(t0) < *MaxWriteWaitBeforeConnReuse/2 {
|
||
mu.Lock()
|
||
got := logbuf.String()
|
||
mu.Unlock()
|
||
t.Fatalf("i=%d: Do = %v; log:\n%s", i, err, got)
|
||
}
|
||
t.Skipf("connection likely wasn't recycled within %d, interfering with actual test; skipping", *MaxWriteWaitBeforeConnReuse)
|
||
}
|
||
res.Body.Close()
|
||
if res.Request != req {
|
||
t.Errorf("Response.Request != original request; want identical Request")
|
||
}
|
||
}
|
||
|
||
mu.Lock()
|
||
got := logbuf.String()
|
||
mu.Unlock()
|
||
want := fmt.Sprintf(`Dial
|
||
Write("%s")
|
||
Handler
|
||
intentional write failure
|
||
Retried.
|
||
Dial
|
||
Write("%s")
|
||
Handler
|
||
Write("%s")
|
||
Handler
|
||
`, tc.reqString, tc.reqString, tc.reqString)
|
||
if got != want {
|
||
t.Errorf("Log of events differs. Got:\n%s\nWant:\n%s", got, want)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// Issue 6981
|
||
func TestTransportClosesBodyOnError(t *testing.T) { run(t, testTransportClosesBodyOnError) }
|
||
func testTransportClosesBodyOnError(t *testing.T, mode testMode) {
|
||
readBody := make(chan error, 1)
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
_, err := io.ReadAll(r.Body)
|
||
readBody <- err
|
||
})).ts
|
||
c := ts.Client()
|
||
fakeErr := errors.New("fake error")
|
||
didClose := make(chan bool, 1)
|
||
req, _ := NewRequest("POST", ts.URL, struct {
|
||
io.Reader
|
||
io.Closer
|
||
}{
|
||
io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), iotest.ErrReader(fakeErr)),
|
||
closerFunc(func() error {
|
||
select {
|
||
case didClose <- true:
|
||
default:
|
||
}
|
||
return nil
|
||
}),
|
||
})
|
||
res, err := c.Do(req)
|
||
if res != nil {
|
||
defer res.Body.Close()
|
||
}
|
||
if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) {
|
||
t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error())
|
||
}
|
||
if err := <-readBody; err == nil {
|
||
t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'")
|
||
}
|
||
select {
|
||
case <-didClose:
|
||
default:
|
||
t.Errorf("didn't see Body.Close")
|
||
}
|
||
}
|
||
|
||
func TestTransportDialTLS(t *testing.T) {
|
||
run(t, testTransportDialTLS, []testMode{https1Mode, http2Mode})
|
||
}
|
||
func testTransportDialTLS(t *testing.T, mode testMode) {
|
||
var mu sync.Mutex // guards following
|
||
var gotReq, didDial bool
|
||
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
mu.Lock()
|
||
gotReq = true
|
||
mu.Unlock()
|
||
})).ts
|
||
c := ts.Client()
|
||
c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) {
|
||
mu.Lock()
|
||
didDial = true
|
||
mu.Unlock()
|
||
c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return c, c.Handshake()
|
||
}
|
||
|
||
res, err := c.Get(ts.URL)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
res.Body.Close()
|
||
mu.Lock()
|
||
if !gotReq {
|
||
t.Error("didn't get request")
|
||
}
|
||
if !didDial {
|
||
t.Error("didn't use dial hook")
|
||
}
|
||
}
|
||
|
||
func TestTransportDialContext(t *testing.T) { run(t, testTransportDialContext) }
|
||
func testTransportDialContext(t *testing.T, mode testMode) {
|
||
ctxKey := "some-key"
|
||
ctxValue := "some-value"
|
||
var (
|
||
mu sync.Mutex // guards following
|
||
gotReq bool
|
||
gotCtxValue any
|
||
)
|
||
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
mu.Lock()
|
||
gotReq = true
|
||
mu.Unlock()
|
||
})).ts
|
||
c := ts.Client()
|
||
c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
|
||
mu.Lock()
|
||
gotCtxValue = ctx.Value(ctxKey)
|
||
mu.Unlock()
|
||
return net.Dial(netw, addr)
|
||
}
|
||
|
||
req, err := NewRequest("GET", ts.URL, nil)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
|
||
res, err := c.Do(req.WithContext(ctx))
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
res.Body.Close()
|
||
mu.Lock()
|
||
if !gotReq {
|
||
t.Error("didn't get request")
|
||
}
|
||
if got, want := gotCtxValue, ctxValue; got != want {
|
||
t.Errorf("got context with value %v, want %v", got, want)
|
||
}
|
||
}
|
||
|
||
func TestTransportDialTLSContext(t *testing.T) {
|
||
run(t, testTransportDialTLSContext, []testMode{https1Mode, http2Mode})
|
||
}
|
||
func testTransportDialTLSContext(t *testing.T, mode testMode) {
|
||
ctxKey := "some-key"
|
||
ctxValue := "some-value"
|
||
var (
|
||
mu sync.Mutex // guards following
|
||
gotReq bool
|
||
gotCtxValue any
|
||
)
|
||
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
mu.Lock()
|
||
gotReq = true
|
||
mu.Unlock()
|
||
})).ts
|
||
c := ts.Client()
|
||
c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
|
||
mu.Lock()
|
||
gotCtxValue = ctx.Value(ctxKey)
|
||
mu.Unlock()
|
||
c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return c, c.HandshakeContext(ctx)
|
||
}
|
||
|
||
req, err := NewRequest("GET", ts.URL, nil)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
|
||
res, err := c.Do(req.WithContext(ctx))
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
res.Body.Close()
|
||
mu.Lock()
|
||
if !gotReq {
|
||
t.Error("didn't get request")
|
||
}
|
||
if got, want := gotCtxValue, ctxValue; got != want {
|
||
t.Errorf("got context with value %v, want %v", got, want)
|
||
}
|
||
}
|
||
|
||
// Test for issue 8755
|
||
// Ensure that if a proxy returns an error, it is exposed by RoundTrip
|
||
func TestRoundTripReturnsProxyError(t *testing.T) {
|
||
badProxy := func(*Request) (*url.URL, error) {
|
||
return nil, errors.New("errorMessage")
|
||
}
|
||
|
||
tr := &Transport{Proxy: badProxy}
|
||
|
||
req, _ := NewRequest("GET", "http://example.com", nil)
|
||
|
||
_, err := tr.RoundTrip(req)
|
||
|
||
if err == nil {
|
||
t.Error("Expected proxy error to be returned by RoundTrip")
|
||
}
|
||
}
|
||
|
||
// tests that putting an idle conn after a call to CloseIdleConns does return it
|
||
func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
|
||
tr := &Transport{}
|
||
wantIdle := func(when string, n int) bool {
|
||
got := tr.IdleConnCountForTesting("http", "example.com") // key used by PutIdleTestConn
|
||
if got == n {
|
||
return true
|
||
}
|
||
t.Errorf("%s: idle conns = %d; want %d", when, got, n)
|
||
return false
|
||
}
|
||
wantIdle("start", 0)
|
||
if !tr.PutIdleTestConn("http", "example.com") {
|
||
t.Fatal("put failed")
|
||
}
|
||
if !tr.PutIdleTestConn("http", "example.com") {
|
||
t.Fatal("second put failed")
|
||
}
|
||
wantIdle("after put", 2)
|
||
tr.CloseIdleConnections()
|
||
if !tr.IsIdleForTesting() {
|
||
t.Error("should be idle after CloseIdleConnections")
|
||
}
|
||
wantIdle("after close idle", 0)
|
||
if tr.PutIdleTestConn("http", "example.com") {
|
||
t.Fatal("put didn't fail")
|
||
}
|
||
wantIdle("after second put", 0)
|
||
|
||
tr.QueueForIdleConnForTesting() // should toggle the transport out of idle mode
|
||
if tr.IsIdleForTesting() {
|
||
t.Error("shouldn't be idle after QueueForIdleConnForTesting")
|
||
}
|
||
if !tr.PutIdleTestConn("http", "example.com") {
|
||
t.Fatal("after re-activation")
|
||
}
|
||
wantIdle("after final put", 1)
|
||
}
|
||
|
||
// Test for issue 34282
|
||
// Ensure that getConn doesn't call the GotConn trace hook on an HTTP/2 idle conn
|
||
func TestTransportTraceGotConnH2IdleConns(t *testing.T) {
|
||
tr := &Transport{}
|
||
wantIdle := func(when string, n int) bool {
|
||
got := tr.IdleConnCountForTesting("https", "example.com:443") // key used by PutIdleTestConnH2
|
||
if got == n {
|
||
return true
|
||
}
|
||
t.Errorf("%s: idle conns = %d; want %d", when, got, n)
|
||
return false
|
||
}
|
||
wantIdle("start", 0)
|
||
alt := funcRoundTripper(func() {})
|
||
if !tr.PutIdleTestConnH2("https", "example.com:443", alt) {
|
||
t.Fatal("put failed")
|
||
}
|
||
wantIdle("after put", 1)
|
||
ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
|
||
GotConn: func(httptrace.GotConnInfo) {
|
||
// tr.getConn should leave it for the HTTP/2 alt to call GotConn.
|
||
t.Error("GotConn called")
|
||
},
|
||
})
|
||
req, _ := NewRequestWithContext(ctx, MethodGet, "https://example.com", nil)
|
||
_, err := tr.RoundTrip(req)
|
||
if err != errFakeRoundTrip {
|
||
t.Errorf("got error: %v; want %q", err, errFakeRoundTrip)
|
||
}
|
||
wantIdle("after round trip", 1)
|
||
}
|
||
|
||
func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) {
|
||
run(t, testTransportRemovesH2ConnsAfterIdle, []testMode{http2Mode})
|
||
}
|
||
func testTransportRemovesH2ConnsAfterIdle(t *testing.T, mode testMode) {
|
||
if testing.Short() {
|
||
t.Skip("skipping in short mode")
|
||
}
|
||
|
||
timeout := 1 * time.Millisecond
|
||
retry := true
|
||
for retry {
|
||
trFunc := func(tr *Transport) {
|
||
tr.MaxConnsPerHost = 1
|
||
tr.MaxIdleConnsPerHost = 1
|
||
tr.IdleConnTimeout = timeout
|
||
}
|
||
cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), trFunc)
|
||
|
||
retry = false
|
||
tooShort := func(err error) bool {
|
||
if err == nil || !strings.Contains(err.Error(), "use of closed network connection") {
|
||
return false
|
||
}
|
||
if !retry {
|
||
t.Helper()
|
||
t.Logf("idle conn timeout %v may be too short; retrying with longer", timeout)
|
||
timeout *= 2
|
||
retry = true
|
||
cst.close()
|
||
}
|
||
return true
|
||
}
|
||
|
||
if _, err := cst.c.Get(cst.ts.URL); err != nil {
|
||
if tooShort(err) {
|
||
continue
|
||
}
|
||
t.Fatalf("got error: %s", err)
|
||
}
|
||
|
||
time.Sleep(10 * timeout)
|
||
if _, err := cst.c.Get(cst.ts.URL); err != nil {
|
||
if tooShort(err) {
|
||
continue
|
||
}
|
||
t.Fatalf("got error: %s", err)
|
||
}
|
||
}
|
||
}
|
||
|
||
// This tests that a client requesting a content range won't also
|
||
// implicitly ask for gzip support. If they want that, they need to do it
|
||
// on their own.
|
||
// golang.org/issue/8923
|
||
func TestTransportRangeAndGzip(t *testing.T) { run(t, testTransportRangeAndGzip) }
|
||
func testTransportRangeAndGzip(t *testing.T, mode testMode) {
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
|
||
t.Error("Transport advertised gzip support in the Accept header")
|
||
}
|
||
if r.Header.Get("Range") == "" {
|
||
t.Error("no Range in request")
|
||
}
|
||
})).ts
|
||
c := ts.Client()
|
||
|
||
req, _ := NewRequest("GET", ts.URL, nil)
|
||
req.Header.Set("Range", "bytes=7-11")
|
||
res, err := c.Do(req)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
res.Body.Close()
|
||
}
|
||
|
||
// Test for issue 10474
|
||
func TestTransportResponseCancelRace(t *testing.T) { run(t, testTransportResponseCancelRace) }
|
||
func testTransportResponseCancelRace(t *testing.T, mode testMode) {
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
// important that this response has a body.
|
||
var b [1024]byte
|
||
w.Write(b[:])
|
||
})).ts
|
||
tr := ts.Client().Transport.(*Transport)
|
||
|
||
req, err := NewRequest("GET", ts.URL, nil)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
res, err := tr.RoundTrip(req)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
// If we do an early close, Transport just throws the connection away and
|
||
// doesn't reuse it. In order to trigger the bug, it has to reuse the connection
|
||
// so read the body
|
||
if _, err := io.Copy(io.Discard, res.Body); err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
req2, err := NewRequest("GET", ts.URL, nil)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
tr.CancelRequest(req)
|
||
res, err = tr.RoundTrip(req2)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
res.Body.Close()
|
||
}
|
||
|
||
// Test for issue 19248: Content-Encoding's value is case insensitive.
|
||
func TestTransportContentEncodingCaseInsensitive(t *testing.T) {
|
||
run(t, testTransportContentEncodingCaseInsensitive)
|
||
}
|
||
func testTransportContentEncodingCaseInsensitive(t *testing.T, mode testMode) {
|
||
for _, ce := range []string{"gzip", "GZIP"} {
|
||
ce := ce
|
||
t.Run(ce, func(t *testing.T) {
|
||
const encodedString = "Hello Gopher"
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
w.Header().Set("Content-Encoding", ce)
|
||
gz := gzip.NewWriter(w)
|
||
gz.Write([]byte(encodedString))
|
||
gz.Close()
|
||
})).ts
|
||
|
||
res, err := ts.Client().Get(ts.URL)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
body, err := io.ReadAll(res.Body)
|
||
res.Body.Close()
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
if string(body) != encodedString {
|
||
t.Fatalf("Expected body %q, got: %q\n", encodedString, string(body))
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// https://go.dev/issue/49621
|
||
func TestConnClosedBeforeRequestIsWritten(t *testing.T) {
|
||
run(t, testConnClosedBeforeRequestIsWritten, testNotParallel, []testMode{http1Mode})
|
||
}
|
||
func testConnClosedBeforeRequestIsWritten(t *testing.T, mode testMode) {
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
|
||
func(tr *Transport) {
|
||
tr.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) {
|
||
// Connection immediately returns errors.
|
||
return &funcConn{
|
||
read: func([]byte) (int, error) {
|
||
return 0, errors.New("error")
|
||
},
|
||
write: func([]byte) (int, error) {
|
||
return 0, errors.New("error")
|
||
},
|
||
}, nil
|
||
}
|
||
},
|
||
).ts
|
||
// Set a short delay in RoundTrip to give the persistConn time to notice
|
||
// the connection is broken. We want to exercise the path where writeLoop exits
|
||
// before it reads the request to send. If this delay is too short, we may instead
|
||
// exercise the path where writeLoop accepts the request and then fails to write it.
|
||
// That's fine, so long as we get the desired path often enough.
|
||
SetEnterRoundTripHook(func() {
|
||
time.Sleep(1 * time.Millisecond)
|
||
})
|
||
defer SetEnterRoundTripHook(nil)
|
||
var closes int
|
||
_, err := ts.Client().Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
|
||
if err == nil {
|
||
t.Fatalf("expected request to fail, but it did not")
|
||
}
|
||
if closes != 1 {
|
||
t.Errorf("after RoundTrip, request body was closed %v times; want 1", closes)
|
||
}
|
||
}
|
||
|
||
// logWritesConn is a net.Conn that logs each Write call to writes
|
||
// and then proxies to w.
|
||
// It proxies Read calls to a reader it receives from rch.
|
||
type logWritesConn struct {
|
||
net.Conn // nil. crash on use.
|
||
|
||
w io.Writer
|
||
|
||
rch <-chan io.Reader
|
||
r io.Reader // nil until received by rch
|
||
|
||
mu sync.Mutex
|
||
writes []string
|
||
}
|
||
|
||
func (c *logWritesConn) Write(p []byte) (n int, err error) {
|
||
c.mu.Lock()
|
||
defer c.mu.Unlock()
|
||
c.writes = append(c.writes, string(p))
|
||
return c.w.Write(p)
|
||
}
|
||
|
||
func (c *logWritesConn) Read(p []byte) (n int, err error) {
|
||
if c.r == nil {
|
||
c.r = <-c.rch
|
||
}
|
||
return c.r.Read(p)
|
||
}
|
||
|
||
func (c *logWritesConn) Close() error { return nil }
|
||
|
||
// Issue 6574
|
||
func TestTransportFlushesBodyChunks(t *testing.T) {
|
||
defer afterTest(t)
|
||
resBody := make(chan io.Reader, 1)
|
||
connr, connw := io.Pipe() // connection pipe pair
|
||
lw := &logWritesConn{
|
||
rch: resBody,
|
||
w: connw,
|
||
}
|
||
tr := &Transport{
|
||
Dial: func(network, addr string) (net.Conn, error) {
|
||
return lw, nil
|
||
},
|
||
}
|
||
bodyr, bodyw := io.Pipe() // body pipe pair
|
||
go func() {
|
||
defer bodyw.Close()
|
||
for i := 0; i < 3; i++ {
|
||
fmt.Fprintf(bodyw, "num%d\n", i)
|
||
}
|
||
}()
|
||
resc := make(chan *Response)
|
||
go func() {
|
||
req, _ := NewRequest("POST", "http://localhost:8080", bodyr)
|
||
req.Header.Set("User-Agent", "x") // known value for test
|
||
res, err := tr.RoundTrip(req)
|
||
if err != nil {
|
||
t.Errorf("RoundTrip: %v", err)
|
||
close(resc)
|
||
return
|
||
}
|
||
resc <- res
|
||
|
||
}()
|
||
// Fully consume the request before checking the Write log vs. want.
|
||
req, err := ReadRequest(bufio.NewReader(connr))
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
io.Copy(io.Discard, req.Body)
|
||
|
||
// Unblock the transport's roundTrip goroutine.
|
||
resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n")
|
||
res, ok := <-resc
|
||
if !ok {
|
||
return
|
||
}
|
||
defer res.Body.Close()
|
||
|
||
want := []string{
|
||
"POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n",
|
||
"5\r\nnum0\n\r\n",
|
||
"5\r\nnum1\n\r\n",
|
||
"5\r\nnum2\n\r\n",
|
||
"0\r\n\r\n",
|
||
}
|
||
if !slices.Equal(lw.writes, want) {
|
||
t.Errorf("Writes differed.\n Got: %q\nWant: %q\n", lw.writes, want)
|
||
}
|
||
}
|
||
|
||
// Issue 22088: flush Transport request headers if we're not sure the body won't block on read.
|
||
func TestTransportFlushesRequestHeader(t *testing.T) { run(t, testTransportFlushesRequestHeader) }
|
||
func testTransportFlushesRequestHeader(t *testing.T, mode testMode) {
|
||
gotReq := make(chan struct{})
|
||
cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
close(gotReq)
|
||
}))
|
||
|
||
pr, pw := io.Pipe()
|
||
req, err := NewRequest("POST", cst.ts.URL, pr)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
gotRes := make(chan struct{})
|
||
go func() {
|
||
defer close(gotRes)
|
||
res, err := cst.tr.RoundTrip(req)
|
||
if err != nil {
|
||
t.Error(err)
|
||
return
|
||
}
|
||
res.Body.Close()
|
||
}()
|
||
|
||
<-gotReq
|
||
pw.Close()
|
||
<-gotRes
|
||
}
|
||
|
||
type wgReadCloser struct {
|
||
io.Reader
|
||
wg *sync.WaitGroup
|
||
closed bool
|
||
}
|
||
|
||
func (c *wgReadCloser) Close() error {
|
||
if c.closed {
|
||
return net.ErrClosed
|
||
}
|
||
c.closed = true
|
||
c.wg.Done()
|
||
return nil
|
||
}
|
||
|
||
// Issue 11745.
|
||
func TestTransportPrefersResponseOverWriteError(t *testing.T) {
|
||
// Not parallel: modifies the global rstAvoidanceDelay.
|
||
run(t, testTransportPrefersResponseOverWriteError, testNotParallel)
|
||
}
|
||
func testTransportPrefersResponseOverWriteError(t *testing.T, mode testMode) {
|
||
if testing.Short() {
|
||
t.Skip("skipping in short mode")
|
||
}
|
||
|
||
runTimeSensitiveTest(t, []time.Duration{
|
||
1 * time.Millisecond,
|
||
5 * time.Millisecond,
|
||
10 * time.Millisecond,
|
||
50 * time.Millisecond,
|
||
100 * time.Millisecond,
|
||
500 * time.Millisecond,
|
||
time.Second,
|
||
5 * time.Second,
|
||
}, func(t *testing.T, timeout time.Duration) error {
|
||
SetRSTAvoidanceDelay(t, timeout)
|
||
t.Logf("set RST avoidance delay to %v", timeout)
|
||
|
||
const contentLengthLimit = 1024 * 1024 // 1MB
|
||
cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
if r.ContentLength >= contentLengthLimit {
|
||
w.WriteHeader(StatusBadRequest)
|
||
r.Body.Close()
|
||
return
|
||
}
|
||
w.WriteHeader(StatusOK)
|
||
}))
|
||
// We need to close cst explicitly here so that in-flight server
|
||
// requests don't race with the call to SetRSTAvoidanceDelay for a retry.
|
||
defer cst.close()
|
||
ts := cst.ts
|
||
c := ts.Client()
|
||
|
||
count := 100
|
||
|
||
bigBody := strings.Repeat("a", contentLengthLimit*2)
|
||
var wg sync.WaitGroup
|
||
defer wg.Wait()
|
||
getBody := func() (io.ReadCloser, error) {
|
||
wg.Add(1)
|
||
body := &wgReadCloser{
|
||
Reader: strings.NewReader(bigBody),
|
||
wg: &wg,
|
||
}
|
||
return body, nil
|
||
}
|
||
|
||
for i := 0; i < count; i++ {
|
||
reqBody, _ := getBody()
|
||
req, err := NewRequest("PUT", ts.URL, reqBody)
|
||
if err != nil {
|
||
reqBody.Close()
|
||
t.Fatal(err)
|
||
}
|
||
req.ContentLength = int64(len(bigBody))
|
||
req.GetBody = getBody
|
||
|
||
resp, err := c.Do(req)
|
||
if err != nil {
|
||
return fmt.Errorf("Do %d: %v", i, err)
|
||
} else {
|
||
resp.Body.Close()
|
||
if resp.StatusCode != 400 {
|
||
t.Errorf("Expected status code 400, got %v", resp.Status)
|
||
}
|
||
}
|
||
}
|
||
return nil
|
||
})
|
||
}
|
||
|
||
func TestTransportAutomaticHTTP2(t *testing.T) {
|
||
testTransportAutoHTTP(t, &Transport{}, true)
|
||
}
|
||
|
||
func TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig(t *testing.T) {
|
||
testTransportAutoHTTP(t, &Transport{
|
||
ForceAttemptHTTP2: true,
|
||
TLSClientConfig: new(tls.Config),
|
||
}, true)
|
||
}
|
||
|
||
// golang.org/issue/14391: also check DefaultTransport
|
||
func TestTransportAutomaticHTTP2_DefaultTransport(t *testing.T) {
|
||
testTransportAutoHTTP(t, DefaultTransport.(*Transport), true)
|
||
}
|
||
|
||
func TestTransportAutomaticHTTP2_TLSNextProto(t *testing.T) {
|
||
testTransportAutoHTTP(t, &Transport{
|
||
TLSNextProto: make(map[string]func(string, *tls.Conn) RoundTripper),
|
||
}, false)
|
||
}
|
||
|
||
func TestTransportAutomaticHTTP2_TLSConfig(t *testing.T) {
|
||
testTransportAutoHTTP(t, &Transport{
|
||
TLSClientConfig: new(tls.Config),
|
||
}, false)
|
||
}
|
||
|
||
func TestTransportAutomaticHTTP2_ExpectContinueTimeout(t *testing.T) {
|
||
testTransportAutoHTTP(t, &Transport{
|
||
ExpectContinueTimeout: 1 * time.Second,
|
||
}, true)
|
||
}
|
||
|
||
func TestTransportAutomaticHTTP2_Dial(t *testing.T) {
|
||
var d net.Dialer
|
||
testTransportAutoHTTP(t, &Transport{
|
||
Dial: d.Dial,
|
||
}, false)
|
||
}
|
||
|
||
func TestTransportAutomaticHTTP2_DialContext(t *testing.T) {
|
||
var d net.Dialer
|
||
testTransportAutoHTTP(t, &Transport{
|
||
DialContext: d.DialContext,
|
||
}, false)
|
||
}
|
||
|
||
func TestTransportAutomaticHTTP2_DialTLS(t *testing.T) {
|
||
testTransportAutoHTTP(t, &Transport{
|
||
DialTLS: func(network, addr string) (net.Conn, error) {
|
||
panic("unused")
|
||
},
|
||
}, false)
|
||
}
|
||
|
||
func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) {
|
||
CondSkipHTTP2(t)
|
||
_, err := tr.RoundTrip(new(Request))
|
||
if err == nil {
|
||
t.Error("expected error from RoundTrip")
|
||
}
|
||
if reg := tr.TLSNextProto["h2"] != nil; reg != wantH2 {
|
||
t.Errorf("HTTP/2 registered = %v; want %v", reg, wantH2)
|
||
}
|
||
}
|
||
|
||
// Issue 13633: there was a race where we returned bodyless responses
|
||
// to callers before recycling the persistent connection, which meant
|
||
// a client doing two subsequent requests could end up on different
|
||
// connections. It's somewhat harmless but enough tests assume it's
|
||
// not true in order to test other things that it's worth fixing.
|
||
// Plus it's nice to be consistent and not have timing-dependent
|
||
// behavior.
|
||
func TestTransportReuseConnEmptyResponseBody(t *testing.T) {
|
||
run(t, testTransportReuseConnEmptyResponseBody)
|
||
}
|
||
func testTransportReuseConnEmptyResponseBody(t *testing.T, mode testMode) {
|
||
cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
w.Header().Set("X-Addr", r.RemoteAddr)
|
||
// Empty response body.
|
||
}))
|
||
n := 100
|
||
if testing.Short() {
|
||
n = 10
|
||
}
|
||
var firstAddr string
|
||
for i := 0; i < n; i++ {
|
||
res, err := cst.c.Get(cst.ts.URL)
|
||
if err != nil {
|
||
log.Fatal(err)
|
||
}
|
||
addr := res.Header.Get("X-Addr")
|
||
if i == 0 {
|
||
firstAddr = addr
|
||
} else if addr != firstAddr {
|
||
t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr)
|
||
}
|
||
res.Body.Close()
|
||
}
|
||
}
|
||
|
||
// Issue 13839
|
||
func TestNoCrashReturningTransportAltConn(t *testing.T) {
|
||
cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
ln := newLocalListener(t)
|
||
defer ln.Close()
|
||
|
||
var wg sync.WaitGroup
|
||
SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
|
||
defer SetPendingDialHooks(nil, nil)
|
||
|
||
testDone := make(chan struct{})
|
||
defer close(testDone)
|
||
go func() {
|
||
tln := tls.NewListener(ln, &tls.Config{
|
||
NextProtos: []string{"foo"},
|
||
Certificates: []tls.Certificate{cert},
|
||
})
|
||
sc, err := tln.Accept()
|
||
if err != nil {
|
||
t.Error(err)
|
||
return
|
||
}
|
||
if err := sc.(*tls.Conn).Handshake(); err != nil {
|
||
t.Error(err)
|
||
return
|
||
}
|
||
<-testDone
|
||
sc.Close()
|
||
}()
|
||
|
||
addr := ln.Addr().String()
|
||
|
||
req, _ := NewRequest("GET", "https://fake.tld/", nil)
|
||
cancel := make(chan struct{})
|
||
req.Cancel = cancel
|
||
|
||
doReturned := make(chan bool, 1)
|
||
madeRoundTripper := make(chan bool, 1)
|
||
|
||
tr := &Transport{
|
||
DisableKeepAlives: true,
|
||
TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
|
||
"foo": func(authority string, c *tls.Conn) RoundTripper {
|
||
madeRoundTripper <- true
|
||
return funcRoundTripper(func() {
|
||
t.Error("foo RoundTripper should not be called")
|
||
})
|
||
},
|
||
},
|
||
Dial: func(_, _ string) (net.Conn, error) {
|
||
panic("shouldn't be called")
|
||
},
|
||
DialTLS: func(_, _ string) (net.Conn, error) {
|
||
tc, err := tls.Dial("tcp", addr, &tls.Config{
|
||
InsecureSkipVerify: true,
|
||
NextProtos: []string{"foo"},
|
||
})
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if err := tc.Handshake(); err != nil {
|
||
return nil, err
|
||
}
|
||
close(cancel)
|
||
<-doReturned
|
||
return tc, nil
|
||
},
|
||
}
|
||
c := &Client{Transport: tr}
|
||
|
||
_, err = c.Do(req)
|
||
if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn {
|
||
t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err)
|
||
}
|
||
|
||
doReturned <- true
|
||
<-madeRoundTripper
|
||
wg.Wait()
|
||
}
|
||
|
||
func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) {
|
||
run(t, func(t *testing.T, mode testMode) {
|
||
testTransportReuseConnection_Gzip(t, mode, true)
|
||
})
|
||
}
|
||
|
||
func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) {
|
||
run(t, func(t *testing.T, mode testMode) {
|
||
testTransportReuseConnection_Gzip(t, mode, false)
|
||
})
|
||
}
|
||
|
||
// Make sure we re-use underlying TCP connection for gzipped responses too.
|
||
func testTransportReuseConnection_Gzip(t *testing.T, mode testMode, chunked bool) {
|
||
addr := make(chan string, 2)
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
addr <- r.RemoteAddr
|
||
w.Header().Set("Content-Encoding", "gzip")
|
||
if chunked {
|
||
w.(Flusher).Flush()
|
||
}
|
||
w.Write(rgz) // arbitrary gzip response
|
||
})).ts
|
||
c := ts.Client()
|
||
|
||
trace := &httptrace.ClientTrace{
|
||
GetConn: func(hostPort string) { t.Logf("GetConn(%q)", hostPort) },
|
||
GotConn: func(ci httptrace.GotConnInfo) { t.Logf("GotConn(%+v)", ci) },
|
||
PutIdleConn: func(err error) { t.Logf("PutIdleConn(%v)", err) },
|
||
ConnectStart: func(network, addr string) { t.Logf("ConnectStart(%q, %q)", network, addr) },
|
||
ConnectDone: func(network, addr string, err error) { t.Logf("ConnectDone(%q, %q, %v)", network, addr, err) },
|
||
}
|
||
ctx := httptrace.WithClientTrace(context.Background(), trace)
|
||
|
||
for i := 0; i < 2; i++ {
|
||
req, _ := NewRequest("GET", ts.URL, nil)
|
||
req = req.WithContext(ctx)
|
||
res, err := c.Do(req)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
buf := make([]byte, len(rgz))
|
||
if n, err := io.ReadFull(res.Body, buf); err != nil {
|
||
t.Errorf("%d. ReadFull = %v, %v", i, n, err)
|
||
}
|
||
// Note: no res.Body.Close call. It should work without it,
|
||
// since the flate.Reader's internal buffering will hit EOF
|
||
// and that should be sufficient.
|
||
}
|
||
a1, a2 := <-addr, <-addr
|
||
if a1 != a2 {
|
||
t.Fatalf("didn't reuse connection")
|
||
}
|
||
}
|
||
|
||
func TestTransportResponseHeaderLength(t *testing.T) { run(t, testTransportResponseHeaderLength) }
|
||
func testTransportResponseHeaderLength(t *testing.T, mode testMode) {
|
||
if mode == http2Mode {
|
||
t.Skip("HTTP/2 Transport doesn't support MaxResponseHeaderBytes")
|
||
}
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
if r.URL.Path == "/long" {
|
||
w.Header().Set("Long", strings.Repeat("a", 1<<20))
|
||
}
|
||
})).ts
|
||
c := ts.Client()
|
||
c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10
|
||
|
||
if res, err := c.Get(ts.URL); err != nil {
|
||
t.Fatal(err)
|
||
} else {
|
||
res.Body.Close()
|
||
}
|
||
|
||
res, err := c.Get(ts.URL + "/long")
|
||
if err == nil {
|
||
defer res.Body.Close()
|
||
var n int64
|
||
for k, vv := range res.Header {
|
||
for _, v := range vv {
|
||
n += int64(len(k)) + int64(len(v))
|
||
}
|
||
}
|
||
t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, n)
|
||
}
|
||
if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) {
|
||
t.Errorf("got error: %v; want %q", err, want)
|
||
}
|
||
}
|
||
|
||
func TestTransportEventTrace(t *testing.T) {
|
||
run(t, func(t *testing.T, mode testMode) {
|
||
testTransportEventTrace(t, mode, false)
|
||
}, testNotParallel)
|
||
}
|
||
|
||
// test a non-nil httptrace.ClientTrace but with all hooks set to zero.
|
||
func TestTransportEventTrace_NoHooks(t *testing.T) {
|
||
run(t, func(t *testing.T, mode testMode) {
|
||
testTransportEventTrace(t, mode, true)
|
||
}, testNotParallel)
|
||
}
|
||
|
||
func testTransportEventTrace(t *testing.T, mode testMode, noHooks bool) {
|
||
const resBody = "some body"
|
||
gotWroteReqEvent := make(chan struct{}, 500)
|
||
cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
if r.Method == "GET" {
|
||
// Do nothing for the second request.
|
||
return
|
||
}
|
||
if _, err := io.ReadAll(r.Body); err != nil {
|
||
t.Error(err)
|
||
}
|
||
if !noHooks {
|
||
<-gotWroteReqEvent
|
||
}
|
||
io.WriteString(w, resBody)
|
||
}), func(tr *Transport) {
|
||
if tr.TLSClientConfig != nil {
|
||
tr.TLSClientConfig.InsecureSkipVerify = true
|
||
}
|
||
})
|
||
defer cst.close()
|
||
|
||
cst.tr.ExpectContinueTimeout = 1 * time.Second
|
||
|
||
var mu sync.Mutex // guards buf
|
||
var buf strings.Builder
|
||
logf := func(format string, args ...any) {
|
||
mu.Lock()
|
||
defer mu.Unlock()
|
||
fmt.Fprintf(&buf, format, args...)
|
||
buf.WriteByte('\n')
|
||
}
|
||
|
||
addrStr := cst.ts.Listener.Addr().String()
|
||
ip, port, err := net.SplitHostPort(addrStr)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
// Install a fake DNS server.
|
||
ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
|
||
if host != "dns-is-faked.golang" {
|
||
t.Errorf("unexpected DNS host lookup for %q/%q", network, host)
|
||
return nil, nil
|
||
}
|
||
return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
|
||
})
|
||
|
||
body := "some body"
|
||
req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body))
|
||
req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"}
|
||
trace := &httptrace.ClientTrace{
|
||
GetConn: func(hostPort string) { logf("Getting conn for %v ...", hostPort) },
|
||
GotConn: func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) },
|
||
GotFirstResponseByte: func() { logf("first response byte") },
|
||
PutIdleConn: func(err error) { logf("PutIdleConn = %v", err) },
|
||
DNSStart: func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) },
|
||
DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) },
|
||
ConnectStart: func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) },
|
||
ConnectDone: func(network, addr string, err error) {
|
||
if err != nil {
|
||
t.Errorf("ConnectDone: %v", err)
|
||
}
|
||
logf("ConnectDone: connected to %s %s = %v", network, addr, err)
|
||
},
|
||
WroteHeaderField: func(key string, value []string) {
|
||
logf("WroteHeaderField: %s: %v", key, value)
|
||
},
|
||
WroteHeaders: func() {
|
||
logf("WroteHeaders")
|
||
},
|
||
Wait100Continue: func() { logf("Wait100Continue") },
|
||
Got100Continue: func() { logf("Got100Continue") },
|
||
WroteRequest: func(e httptrace.WroteRequestInfo) {
|
||
logf("WroteRequest: %+v", e)
|
||
gotWroteReqEvent <- struct{}{}
|
||
},
|
||
}
|
||
if mode == http2Mode {
|
||
trace.TLSHandshakeStart = func() { logf("tls handshake start") }
|
||
trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) {
|
||
logf("tls handshake done. ConnectionState = %v \n err = %v", s, err)
|
||
}
|
||
}
|
||
if noHooks {
|
||
// zero out all func pointers, trying to get some path to crash
|
||
*trace = httptrace.ClientTrace{}
|
||
}
|
||
req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
|
||
|
||
req.Header.Set("Expect", "100-continue")
|
||
res, err := cst.c.Do(req)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
logf("got roundtrip.response")
|
||
slurp, err := io.ReadAll(res.Body)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
logf("consumed body")
|
||
if string(slurp) != resBody || res.StatusCode != 200 {
|
||
t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody)
|
||
}
|
||
res.Body.Close()
|
||
|
||
if noHooks {
|
||
// Done at this point. Just testing a full HTTP
|
||
// requests can happen with a trace pointing to a zero
|
||
// ClientTrace, full of nil func pointers.
|
||
return
|
||
}
|
||
|
||
mu.Lock()
|
||
got := buf.String()
|
||
mu.Unlock()
|
||
|
||
wantOnce := func(sub string) {
|
||
if strings.Count(got, sub) != 1 {
|
||
t.Errorf("expected substring %q exactly once in output.", sub)
|
||
}
|
||
}
|
||
wantOnceOrMore := func(sub string) {
|
||
if strings.Count(got, sub) == 0 {
|
||
t.Errorf("expected substring %q at least once in output.", sub)
|
||
}
|
||
}
|
||
wantOnce("Getting conn for dns-is-faked.golang:" + port)
|
||
wantOnce("DNS start: {Host:dns-is-faked.golang}")
|
||
wantOnce("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err:<nil> Coalesced:false}")
|
||
wantOnce("got conn: {")
|
||
wantOnceOrMore("Connecting to tcp " + addrStr)
|
||
wantOnceOrMore("connected to tcp " + addrStr + " = <nil>")
|
||
wantOnce("Reused:false WasIdle:false IdleTime:0s")
|
||
wantOnce("first response byte")
|
||
if mode == http2Mode {
|
||
wantOnce("tls handshake start")
|
||
wantOnce("tls handshake done")
|
||
} else {
|
||
wantOnce("PutIdleConn = <nil>")
|
||
wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]")
|
||
// TODO(meirf): issue 19761. Make these agnostic to h1/h2. (These are not h1 specific, but the
|
||
// WroteHeaderField hook is not yet implemented in h2.)
|
||
wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port))
|
||
wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body)))
|
||
wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]")
|
||
wantOnce("WroteHeaderField: Accept-Encoding: [gzip]")
|
||
}
|
||
wantOnce("WroteHeaders")
|
||
wantOnce("Wait100Continue")
|
||
wantOnce("Got100Continue")
|
||
wantOnce("WroteRequest: {Err:<nil>}")
|
||
if strings.Contains(got, " to udp ") {
|
||
t.Errorf("should not see UDP (DNS) connections")
|
||
}
|
||
if t.Failed() {
|
||
t.Errorf("Output:\n%s", got)
|
||
}
|
||
|
||
// And do a second request:
|
||
req, _ = NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil)
|
||
req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
|
||
res, err = cst.c.Do(req)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
if res.StatusCode != 200 {
|
||
t.Fatal(res.Status)
|
||
}
|
||
res.Body.Close()
|
||
|
||
mu.Lock()
|
||
got = buf.String()
|
||
mu.Unlock()
|
||
|
||
sub := "Getting conn for dns-is-faked.golang:"
|
||
if gotn, want := strings.Count(got, sub), 2; gotn != want {
|
||
t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got)
|
||
}
|
||
|
||
}
|
||
|
||
func TestTransportEventTraceTLSVerify(t *testing.T) {
|
||
run(t, testTransportEventTraceTLSVerify, []testMode{https1Mode, http2Mode})
|
||
}
|
||
func testTransportEventTraceTLSVerify(t *testing.T, mode testMode) {
|
||
var mu sync.Mutex
|
||
var buf strings.Builder
|
||
logf := func(format string, args ...any) {
|
||
mu.Lock()
|
||
defer mu.Unlock()
|
||
fmt.Fprintf(&buf, format, args...)
|
||
buf.WriteByte('\n')
|
||
}
|
||
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
t.Error("Unexpected request")
|
||
}), func(ts *httptest.Server) {
|
||
ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) {
|
||
logf("%s", p)
|
||
return len(p), nil
|
||
}), "", 0)
|
||
}).ts
|
||
|
||
certpool := x509.NewCertPool()
|
||
certpool.AddCert(ts.Certificate())
|
||
|
||
c := &Client{Transport: &Transport{
|
||
TLSClientConfig: &tls.Config{
|
||
ServerName: "dns-is-faked.golang",
|
||
RootCAs: certpool,
|
||
},
|
||
}}
|
||
|
||
trace := &httptrace.ClientTrace{
|
||
TLSHandshakeStart: func() { logf("TLSHandshakeStart") },
|
||
TLSHandshakeDone: func(s tls.ConnectionState, err error) {
|
||
logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err)
|
||
},
|
||
}
|
||
|
||
req, _ := NewRequest("GET", ts.URL, nil)
|
||
req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
|
||
_, err := c.Do(req)
|
||
if err == nil {
|
||
t.Error("Expected request to fail TLS verification")
|
||
}
|
||
|
||
mu.Lock()
|
||
got := buf.String()
|
||
mu.Unlock()
|
||
|
||
wantOnce := func(sub string) {
|
||
if strings.Count(got, sub) != 1 {
|
||
t.Errorf("expected substring %q exactly once in output.", sub)
|
||
}
|
||
}
|
||
|
||
wantOnce("TLSHandshakeStart")
|
||
wantOnce("TLSHandshakeDone")
|
||
wantOnce("err = tls: failed to verify certificate: x509: certificate is valid for example.com")
|
||
|
||
if t.Failed() {
|
||
t.Errorf("Output:\n%s", got)
|
||
}
|
||
}
|
||
|
||
var isDNSHijacked = sync.OnceValue(func() bool {
|
||
addrs, _ := net.LookupHost("dns-should-not-resolve.golang")
|
||
return len(addrs) != 0
|
||
})
|
||
|
||
func skipIfDNSHijacked(t *testing.T) {
|
||
// Skip this test if the user is using a shady/ISP
|
||
// DNS server hijacking queries.
|
||
// See issues 16732, 16716.
|
||
if isDNSHijacked() {
|
||
t.Skip("skipping; test requires non-hijacking DNS server")
|
||
}
|
||
}
|
||
|
||
func TestTransportEventTraceRealDNS(t *testing.T) {
|
||
skipIfDNSHijacked(t)
|
||
defer afterTest(t)
|
||
tr := &Transport{}
|
||
defer tr.CloseIdleConnections()
|
||
c := &Client{Transport: tr}
|
||
|
||
var mu sync.Mutex // guards buf
|
||
var buf strings.Builder
|
||
logf := func(format string, args ...any) {
|
||
mu.Lock()
|
||
defer mu.Unlock()
|
||
fmt.Fprintf(&buf, format, args...)
|
||
buf.WriteByte('\n')
|
||
}
|
||
|
||
req, _ := NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil)
|
||
trace := &httptrace.ClientTrace{
|
||
DNSStart: func(e httptrace.DNSStartInfo) { logf("DNSStart: %+v", e) },
|
||
DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNSDone: %+v", e) },
|
||
ConnectStart: func(network, addr string) { logf("ConnectStart: %s %s", network, addr) },
|
||
ConnectDone: func(network, addr string, err error) { logf("ConnectDone: %s %s %v", network, addr, err) },
|
||
}
|
||
req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
|
||
|
||
resp, err := c.Do(req)
|
||
if err == nil {
|
||
resp.Body.Close()
|
||
t.Fatal("expected error during DNS lookup")
|
||
}
|
||
|
||
mu.Lock()
|
||
got := buf.String()
|
||
mu.Unlock()
|
||
|
||
wantSub := func(sub string) {
|
||
if !strings.Contains(got, sub) {
|
||
t.Errorf("expected substring %q in output.", sub)
|
||
}
|
||
}
|
||
wantSub("DNSStart: {Host:dns-should-not-resolve.golang}")
|
||
wantSub("DNSDone: {Addrs:[] Err:")
|
||
if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") {
|
||
t.Errorf("should not see Connect events")
|
||
}
|
||
if t.Failed() {
|
||
t.Errorf("Output:\n%s", got)
|
||
}
|
||
}
|
||
|
||
// Issue 14353: port can only contain digits.
|
||
func TestTransportRejectsAlphaPort(t *testing.T) {
|
||
res, err := Get("http://dummy.tld:123foo/bar")
|
||
if err == nil {
|
||
res.Body.Close()
|
||
t.Fatal("unexpected success")
|
||
}
|
||
ue, ok := err.(*url.Error)
|
||
if !ok {
|
||
t.Fatalf("got %#v; want *url.Error", err)
|
||
}
|
||
got := ue.Err.Error()
|
||
want := `invalid port ":123foo" after host`
|
||
if got != want {
|
||
t.Errorf("got error %q; want %q", got, want)
|
||
}
|
||
}
|
||
|
||
// Test the httptrace.TLSHandshake{Start,Done} hooks with an https http1
|
||
// connections. The http2 test is done in TestTransportEventTrace_h2
|
||
func TestTLSHandshakeTrace(t *testing.T) {
|
||
run(t, testTLSHandshakeTrace, []testMode{https1Mode, http2Mode})
|
||
}
|
||
func testTLSHandshakeTrace(t *testing.T, mode testMode) {
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
|
||
|
||
var mu sync.Mutex
|
||
var start, done bool
|
||
trace := &httptrace.ClientTrace{
|
||
TLSHandshakeStart: func() {
|
||
mu.Lock()
|
||
defer mu.Unlock()
|
||
start = true
|
||
},
|
||
TLSHandshakeDone: func(s tls.ConnectionState, err error) {
|
||
mu.Lock()
|
||
defer mu.Unlock()
|
||
done = true
|
||
if err != nil {
|
||
t.Fatal("Expected error to be nil but was:", err)
|
||
}
|
||
},
|
||
}
|
||
|
||
c := ts.Client()
|
||
req, err := NewRequest("GET", ts.URL, nil)
|
||
if err != nil {
|
||
t.Fatal("Unable to construct test request:", err)
|
||
}
|
||
req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
|
||
|
||
r, err := c.Do(req)
|
||
if err != nil {
|
||
t.Fatal("Unexpected error making request:", err)
|
||
}
|
||
r.Body.Close()
|
||
mu.Lock()
|
||
defer mu.Unlock()
|
||
if !start {
|
||
t.Fatal("Expected TLSHandshakeStart to be called, but wasn't")
|
||
}
|
||
if !done {
|
||
t.Fatal("Expected TLSHandshakeDone to be called, but wasn't")
|
||
}
|
||
}
|
||
|
||
func TestTransportMaxIdleConns(t *testing.T) {
|
||
run(t, testTransportMaxIdleConns, []testMode{http1Mode})
|
||
}
|
||
func testTransportMaxIdleConns(t *testing.T, mode testMode) {
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
// No body for convenience.
|
||
})).ts
|
||
c := ts.Client()
|
||
tr := c.Transport.(*Transport)
|
||
tr.MaxIdleConns = 4
|
||
|
||
ip, port, err := net.SplitHostPort(ts.Listener.Addr().String())
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, _, host string) ([]net.IPAddr, error) {
|
||
return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
|
||
})
|
||
|
||
hitHost := func(n int) {
|
||
req, _ := NewRequest("GET", fmt.Sprintf("http://host-%d.dns-is-faked.golang:"+port, n), nil)
|
||
req = req.WithContext(ctx)
|
||
res, err := c.Do(req)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
res.Body.Close()
|
||
}
|
||
for i := 0; i < 4; i++ {
|
||
hitHost(i)
|
||
}
|
||
want := []string{
|
||
"|http|host-0.dns-is-faked.golang:" + port,
|
||
"|http|host-1.dns-is-faked.golang:" + port,
|
||
"|http|host-2.dns-is-faked.golang:" + port,
|
||
"|http|host-3.dns-is-faked.golang:" + port,
|
||
}
|
||
if got := tr.IdleConnKeysForTesting(); !slices.Equal(got, want) {
|
||
t.Fatalf("idle conn keys mismatch.\n got: %q\nwant: %q\n", got, want)
|
||
}
|
||
|
||
// Now hitting the 5th host should kick out the first host:
|
||
hitHost(4)
|
||
want = []string{
|
||
"|http|host-1.dns-is-faked.golang:" + port,
|
||
"|http|host-2.dns-is-faked.golang:" + port,
|
||
"|http|host-3.dns-is-faked.golang:" + port,
|
||
"|http|host-4.dns-is-faked.golang:" + port,
|
||
}
|
||
if got := tr.IdleConnKeysForTesting(); !slices.Equal(got, want) {
|
||
t.Fatalf("idle conn keys mismatch after 5th host.\n got: %q\nwant: %q\n", got, want)
|
||
}
|
||
}
|
||
|
||
func TestTransportIdleConnTimeout(t *testing.T) { run(t, testTransportIdleConnTimeout) }
|
||
func testTransportIdleConnTimeout(t *testing.T, mode testMode) {
|
||
if testing.Short() {
|
||
t.Skip("skipping in short mode")
|
||
}
|
||
|
||
timeout := 1 * time.Millisecond
|
||
timeoutLoop:
|
||
for {
|
||
cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
// No body for convenience.
|
||
}))
|
||
tr := cst.tr
|
||
tr.IdleConnTimeout = timeout
|
||
defer tr.CloseIdleConnections()
|
||
c := &Client{Transport: tr}
|
||
|
||
idleConns := func() []string {
|
||
if mode == http2Mode {
|
||
return tr.IdleConnStrsForTesting_h2()
|
||
} else {
|
||
return tr.IdleConnStrsForTesting()
|
||
}
|
||
}
|
||
|
||
var conn string
|
||
doReq := func(n int) (timeoutOk bool) {
|
||
req, _ := NewRequest("GET", cst.ts.URL, nil)
|
||
req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
|
||
PutIdleConn: func(err error) {
|
||
if err != nil {
|
||
t.Errorf("failed to keep idle conn: %v", err)
|
||
}
|
||
},
|
||
}))
|
||
res, err := c.Do(req)
|
||
if err != nil {
|
||
if strings.Contains(err.Error(), "use of closed network connection") {
|
||
t.Logf("req %v: connection closed prematurely", n)
|
||
return false
|
||
}
|
||
}
|
||
res.Body.Close()
|
||
conns := idleConns()
|
||
if len(conns) != 1 {
|
||
if len(conns) == 0 {
|
||
t.Logf("req %v: no idle conns", n)
|
||
return false
|
||
}
|
||
t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns)
|
||
}
|
||
if conn == "" {
|
||
conn = conns[0]
|
||
}
|
||
if conn != conns[0] {
|
||
t.Logf("req %v: cached connection changed; expected the same one throughout the test", n)
|
||
return false
|
||
}
|
||
return true
|
||
}
|
||
for i := 0; i < 3; i++ {
|
||
if !doReq(i) {
|
||
t.Logf("idle conn timeout %v appears to be too short; retrying with longer", timeout)
|
||
timeout *= 2
|
||
cst.close()
|
||
continue timeoutLoop
|
||
}
|
||
time.Sleep(timeout / 2)
|
||
}
|
||
|
||
waitCondition(t, timeout/2, func(d time.Duration) bool {
|
||
if got := idleConns(); len(got) != 0 {
|
||
if d >= timeout*3/2 {
|
||
t.Logf("after %v, idle conns = %q", d, got)
|
||
}
|
||
return false
|
||
}
|
||
return true
|
||
})
|
||
break
|
||
}
|
||
}
|
||
|
||
// Issue 16208: Go 1.7 crashed after Transport.IdleConnTimeout if an
|
||
// HTTP/2 connection was established but its caller no longer
|
||
// wanted it. (Assuming the connection cache was enabled, which it is
|
||
// by default)
|
||
//
|
||
// This test reproduced the crash by setting the IdleConnTimeout low
|
||
// (to make the test reasonable) and then making a request which is
|
||
// canceled by the DialTLS hook, which then also waits to return the
|
||
// real connection until after the RoundTrip saw the error. Then we
|
||
// know the successful tls.Dial from DialTLS will need to go into the
|
||
// idle pool. Then we give it a of time to explode.
|
||
func TestIdleConnH2Crash(t *testing.T) { run(t, testIdleConnH2Crash, []testMode{http2Mode}) }
|
||
func testIdleConnH2Crash(t *testing.T, mode testMode) {
|
||
cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
// nothing
|
||
}))
|
||
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
defer cancel()
|
||
|
||
sawDoErr := make(chan bool, 1)
|
||
testDone := make(chan struct{})
|
||
defer close(testDone)
|
||
|
||
cst.tr.IdleConnTimeout = 5 * time.Millisecond
|
||
cst.tr.DialTLS = func(network, addr string) (net.Conn, error) {
|
||
c, err := tls.Dial(network, addr, &tls.Config{
|
||
InsecureSkipVerify: true,
|
||
NextProtos: []string{"h2"},
|
||
})
|
||
if err != nil {
|
||
t.Error(err)
|
||
return nil, err
|
||
}
|
||
if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" {
|
||
t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2")
|
||
c.Close()
|
||
return nil, errors.New("bogus")
|
||
}
|
||
|
||
cancel()
|
||
|
||
select {
|
||
case <-sawDoErr:
|
||
case <-testDone:
|
||
}
|
||
return c, nil
|
||
}
|
||
|
||
req, _ := NewRequest("GET", cst.ts.URL, nil)
|
||
req = req.WithContext(ctx)
|
||
res, err := cst.c.Do(req)
|
||
if err == nil {
|
||
res.Body.Close()
|
||
t.Fatal("unexpected success")
|
||
}
|
||
sawDoErr <- true
|
||
|
||
// Wait for the explosion.
|
||
time.Sleep(cst.tr.IdleConnTimeout * 10)
|
||
}
|
||
|
||
type funcConn struct {
|
||
net.Conn
|
||
read func([]byte) (int, error)
|
||
write func([]byte) (int, error)
|
||
}
|
||
|
||
func (c funcConn) Read(p []byte) (int, error) { return c.read(p) }
|
||
func (c funcConn) Write(p []byte) (int, error) { return c.write(p) }
|
||
func (c funcConn) Close() error { return nil }
|
||
|
||
// Issue 16465: Transport.RoundTrip should return the raw net.Conn.Read error from Peek
|
||
// back to the caller.
|
||
func TestTransportReturnsPeekError(t *testing.T) {
|
||
errValue := errors.New("specific error value")
|
||
|
||
wrote := make(chan struct{})
|
||
wroteOnce := sync.OnceFunc(func() { close(wrote) })
|
||
|
||
tr := &Transport{
|
||
Dial: func(network, addr string) (net.Conn, error) {
|
||
c := funcConn{
|
||
read: func([]byte) (int, error) {
|
||
<-wrote
|
||
return 0, errValue
|
||
},
|
||
write: func(p []byte) (int, error) {
|
||
wroteOnce()
|
||
return len(p), nil
|
||
},
|
||
}
|
||
return c, nil
|
||
},
|
||
}
|
||
_, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil))
|
||
if err != errValue {
|
||
t.Errorf("error = %#v; want %v", err, errValue)
|
||
}
|
||
}
|
||
|
||
// Issue 13835: international domain names should work
|
||
func TestTransportIDNA(t *testing.T) { run(t, testTransportIDNA) }
|
||
func testTransportIDNA(t *testing.T, mode testMode) {
|
||
const uniDomain = "гофер.го"
|
||
const punyDomain = "xn--c1ae0ajs.xn--c1aw"
|
||
|
||
var port string
|
||
cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
want := punyDomain + ":" + port
|
||
if r.Host != want {
|
||
t.Errorf("Host header = %q; want %q", r.Host, want)
|
||
}
|
||
if mode == http2Mode {
|
||
if r.TLS == nil {
|
||
t.Errorf("r.TLS == nil")
|
||
} else if r.TLS.ServerName != punyDomain {
|
||
t.Errorf("TLS.ServerName = %q; want %q", r.TLS.ServerName, punyDomain)
|
||
}
|
||
}
|
||
w.Header().Set("Hit-Handler", "1")
|
||
}), func(tr *Transport) {
|
||
if tr.TLSClientConfig != nil {
|
||
tr.TLSClientConfig.InsecureSkipVerify = true
|
||
}
|
||
})
|
||
|
||
ip, port, err := net.SplitHostPort(cst.ts.Listener.Addr().String())
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
// Install a fake DNS server.
|
||
ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
|
||
if host != punyDomain {
|
||
t.Errorf("got DNS host lookup for %q/%q; want %q", network, host, punyDomain)
|
||
return nil, nil
|
||
}
|
||
return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
|
||
})
|
||
|
||
req, _ := NewRequest("GET", cst.scheme()+"://"+uniDomain+":"+port, nil)
|
||
trace := &httptrace.ClientTrace{
|
||
GetConn: func(hostPort string) {
|
||
want := net.JoinHostPort(punyDomain, port)
|
||
if hostPort != want {
|
||
t.Errorf("getting conn for %q; want %q", hostPort, want)
|
||
}
|
||
},
|
||
DNSStart: func(e httptrace.DNSStartInfo) {
|
||
if e.Host != punyDomain {
|
||
t.Errorf("DNSStart Host = %q; want %q", e.Host, punyDomain)
|
||
}
|
||
},
|
||
}
|
||
req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
|
||
|
||
res, err := cst.tr.RoundTrip(req)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
defer res.Body.Close()
|
||
if res.Header.Get("Hit-Handler") != "1" {
|
||
out, err := httputil.DumpResponse(res, true)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
t.Errorf("Response body wasn't from Handler. Got:\n%s\n", out)
|
||
}
|
||
}
|
||
|
||
// Issue 13290: send User-Agent in proxy CONNECT
|
||
func TestTransportProxyConnectHeader(t *testing.T) {
|
||
run(t, testTransportProxyConnectHeader, []testMode{http1Mode})
|
||
}
|
||
func testTransportProxyConnectHeader(t *testing.T, mode testMode) {
|
||
reqc := make(chan *Request, 1)
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
if r.Method != "CONNECT" {
|
||
t.Errorf("method = %q; want CONNECT", r.Method)
|
||
}
|
||
reqc <- r
|
||
c, _, err := w.(Hijacker).Hijack()
|
||
if err != nil {
|
||
t.Errorf("Hijack: %v", err)
|
||
return
|
||
}
|
||
c.Close()
|
||
})).ts
|
||
|
||
c := ts.Client()
|
||
c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
|
||
return url.Parse(ts.URL)
|
||
}
|
||
c.Transport.(*Transport).ProxyConnectHeader = Header{
|
||
"User-Agent": {"foo"},
|
||
"Other": {"bar"},
|
||
}
|
||
|
||
res, err := c.Get("https://dummy.tld/") // https to force a CONNECT
|
||
if err == nil {
|
||
res.Body.Close()
|
||
t.Errorf("unexpected success")
|
||
}
|
||
|
||
r := <-reqc
|
||
if got, want := r.Header.Get("User-Agent"), "foo"; got != want {
|
||
t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
|
||
}
|
||
if got, want := r.Header.Get("Other"), "bar"; got != want {
|
||
t.Errorf("CONNECT request Other = %q; want %q", got, want)
|
||
}
|
||
}
|
||
|
||
func TestTransportProxyGetConnectHeader(t *testing.T) {
|
||
run(t, testTransportProxyGetConnectHeader, []testMode{http1Mode})
|
||
}
|
||
func testTransportProxyGetConnectHeader(t *testing.T, mode testMode) {
|
||
reqc := make(chan *Request, 1)
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
if r.Method != "CONNECT" {
|
||
t.Errorf("method = %q; want CONNECT", r.Method)
|
||
}
|
||
reqc <- r
|
||
c, _, err := w.(Hijacker).Hijack()
|
||
if err != nil {
|
||
t.Errorf("Hijack: %v", err)
|
||
return
|
||
}
|
||
c.Close()
|
||
})).ts
|
||
|
||
c := ts.Client()
|
||
c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
|
||
return url.Parse(ts.URL)
|
||
}
|
||
// These should be ignored:
|
||
c.Transport.(*Transport).ProxyConnectHeader = Header{
|
||
"User-Agent": {"foo"},
|
||
"Other": {"bar"},
|
||
}
|
||
c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) {
|
||
return Header{
|
||
"User-Agent": {"foo2"},
|
||
"Other": {"bar2"},
|
||
}, nil
|
||
}
|
||
|
||
res, err := c.Get("https://dummy.tld/") // https to force a CONNECT
|
||
if err == nil {
|
||
res.Body.Close()
|
||
t.Errorf("unexpected success")
|
||
}
|
||
|
||
r := <-reqc
|
||
if got, want := r.Header.Get("User-Agent"), "foo2"; got != want {
|
||
t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
|
||
}
|
||
if got, want := r.Header.Get("Other"), "bar2"; got != want {
|
||
t.Errorf("CONNECT request Other = %q; want %q", got, want)
|
||
}
|
||
}
|
||
|
||
var errFakeRoundTrip = errors.New("fake roundtrip")
|
||
|
||
type funcRoundTripper func()
|
||
|
||
func (fn funcRoundTripper) RoundTrip(*Request) (*Response, error) {
|
||
fn()
|
||
return nil, errFakeRoundTrip
|
||
}
|
||
|
||
func wantBody(res *Response, err error, want string) error {
|
||
if err != nil {
|
||
return err
|
||
}
|
||
slurp, err := io.ReadAll(res.Body)
|
||
if err != nil {
|
||
return fmt.Errorf("error reading body: %v", err)
|
||
}
|
||
if string(slurp) != want {
|
||
return fmt.Errorf("body = %q; want %q", slurp, want)
|
||
}
|
||
if err := res.Body.Close(); err != nil {
|
||
return fmt.Errorf("body Close = %v", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func newLocalListener(t *testing.T) net.Listener {
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
if err != nil {
|
||
ln, err = net.Listen("tcp6", "[::1]:0")
|
||
}
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
return ln
|
||
}
|
||
|
||
type countCloseReader struct {
|
||
n *int
|
||
io.Reader
|
||
}
|
||
|
||
func (cr countCloseReader) Close() error {
|
||
(*cr.n)++
|
||
return nil
|
||
}
|
||
|
||
// rgz is a gzip quine that uncompresses to itself.
|
||
var rgz = []byte{
|
||
0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00,
|
||
0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73,
|
||
0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0,
|
||
0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2,
|
||
0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17,
|
||
0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60,
|
||
0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2,
|
||
0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00,
|
||
0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00,
|
||
0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16,
|
||
0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05,
|
||
0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff,
|
||
0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00,
|
||
0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00,
|
||
0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
|
||
0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88,
|
||
0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff,
|
||
0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00,
|
||
0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00,
|
||
0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
|
||
0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00,
|
||
0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff,
|
||
0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00,
|
||
0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00,
|
||
0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16,
|
||
0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08,
|
||
0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa,
|
||
0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06,
|
||
0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00,
|
||
0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
|
||
0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
|
||
0x00, 0x00,
|
||
}
|
||
|
||
// Ensure that a missing status doesn't make the server panic
|
||
// See Issue https://golang.org/issues/21701
|
||
func TestMissingStatusNoPanic(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
const want = "unknown status code"
|
||
|
||
ln := newLocalListener(t)
|
||
addr := ln.Addr().String()
|
||
done := make(chan bool)
|
||
fullAddrURL := fmt.Sprintf("http://%s", addr)
|
||
raw := "HTTP/1.1 400\r\n" +
|
||
"Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
|
||
"Content-Type: text/html; charset=utf-8\r\n" +
|
||
"Content-Length: 10\r\n" +
|
||
"Last-Modified: Wed, 30 Aug 2017 19:02:02 GMT\r\n" +
|
||
"Vary: Accept-Encoding\r\n\r\n" +
|
||
"Aloha Olaa"
|
||
|
||
go func() {
|
||
defer close(done)
|
||
|
||
conn, _ := ln.Accept()
|
||
if conn != nil {
|
||
io.WriteString(conn, raw)
|
||
io.ReadAll(conn)
|
||
conn.Close()
|
||
}
|
||
}()
|
||
|
||
proxyURL, err := url.Parse(fullAddrURL)
|
||
if err != nil {
|
||
t.Fatalf("proxyURL: %v", err)
|
||
}
|
||
|
||
tr := &Transport{Proxy: ProxyURL(proxyURL)}
|
||
|
||
req, _ := NewRequest("GET", "https://golang.org/", nil)
|
||
res, err, panicked := doFetchCheckPanic(tr, req)
|
||
if panicked {
|
||
t.Error("panicked, expecting an error")
|
||
}
|
||
if res != nil && res.Body != nil {
|
||
io.Copy(io.Discard, res.Body)
|
||
res.Body.Close()
|
||
}
|
||
|
||
if err == nil || !strings.Contains(err.Error(), want) {
|
||
t.Errorf("got=%v want=%q", err, want)
|
||
}
|
||
|
||
ln.Close()
|
||
<-done
|
||
}
|
||
|
||
func doFetchCheckPanic(tr *Transport, req *Request) (res *Response, err error, panicked bool) {
|
||
defer func() {
|
||
if r := recover(); r != nil {
|
||
panicked = true
|
||
}
|
||
}()
|
||
res, err = tr.RoundTrip(req)
|
||
return
|
||
}
|
||
|
||
// Issue 22330: do not allow the response body to be read when the status code
|
||
// forbids a response body.
|
||
func TestNoBodyOnChunked304Response(t *testing.T) {
|
||
run(t, testNoBodyOnChunked304Response, []testMode{http1Mode})
|
||
}
|
||
func testNoBodyOnChunked304Response(t *testing.T, mode testMode) {
|
||
cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
conn, buf, _ := w.(Hijacker).Hijack()
|
||
buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n"))
|
||
buf.Flush()
|
||
conn.Close()
|
||
}))
|
||
|
||
// Our test server above is sending back bogus data after the
|
||
// response (the "0\r\n\r\n" part), which causes the Transport
|
||
// code to log spam. Disable keep-alives so we never even try
|
||
// to reuse the connection.
|
||
cst.tr.DisableKeepAlives = true
|
||
|
||
res, err := cst.c.Get(cst.ts.URL)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
if res.Body != NoBody {
|
||
t.Errorf("Unexpected body on 304 response")
|
||
}
|
||
}
|
||
|
||
type funcWriter func([]byte) (int, error)
|
||
|
||
func (f funcWriter) Write(p []byte) (int, error) { return f(p) }
|
||
|
||
type doneContext struct {
|
||
context.Context
|
||
err error
|
||
}
|
||
|
||
func (doneContext) Done() <-chan struct{} {
|
||
c := make(chan struct{})
|
||
close(c)
|
||
return c
|
||
}
|
||
|
||
func (d doneContext) Err() error { return d.err }
|
||
|
||
// Issue 25852: Transport should check whether Context is done early.
|
||
func TestTransportCheckContextDoneEarly(t *testing.T) {
|
||
tr := &Transport{}
|
||
req, _ := NewRequest("GET", "http://fake.example/", nil)
|
||
wantErr := errors.New("some error")
|
||
req = req.WithContext(doneContext{context.Background(), wantErr})
|
||
_, err := tr.RoundTrip(req)
|
||
if err != wantErr {
|
||
t.Errorf("error = %v; want %v", err, wantErr)
|
||
}
|
||
}
|
||
|
||
// Issue 23399: verify that if a client request times out, the Transport's
|
||
// conn is closed so that it's not reused.
|
||
//
|
||
// This is the test variant that times out before the server replies with
|
||
// any response headers.
|
||
func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) {
|
||
run(t, testClientTimeoutKillsConn_BeforeHeaders, []testMode{http1Mode})
|
||
}
|
||
func testClientTimeoutKillsConn_BeforeHeaders(t *testing.T, mode testMode) {
|
||
timeout := 1 * time.Millisecond
|
||
for {
|
||
inHandler := make(chan bool)
|
||
cancelHandler := make(chan struct{})
|
||
handlerDone := make(chan bool)
|
||
cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
<-r.Context().Done()
|
||
|
||
select {
|
||
case <-cancelHandler:
|
||
return
|
||
case inHandler <- true:
|
||
}
|
||
defer func() { handlerDone <- true }()
|
||
|
||
// Read from the conn until EOF to verify that it was correctly closed.
|
||
conn, _, err := w.(Hijacker).Hijack()
|
||
if err != nil {
|
||
t.Error(err)
|
||
return
|
||
}
|
||
n, err := conn.Read([]byte{0})
|
||
if n != 0 || err != io.EOF {
|
||
t.Errorf("unexpected Read result: %v, %v", n, err)
|
||
}
|
||
conn.Close()
|
||
}))
|
||
|
||
cst.c.Timeout = timeout
|
||
|
||
_, err := cst.c.Get(cst.ts.URL)
|
||
if err == nil {
|
||
close(cancelHandler)
|
||
t.Fatal("unexpected Get success")
|
||
}
|
||
|
||
tooSlow := time.NewTimer(timeout * 10)
|
||
select {
|
||
case <-tooSlow.C:
|
||
// If we didn't get into the Handler, that probably means the builder was
|
||
// just slow and the Get failed in that time but never made it to the
|
||
// server. That's fine; we'll try again with a longer timeout.
|
||
t.Logf("no handler seen in %v; retrying with longer timeout", timeout)
|
||
close(cancelHandler)
|
||
cst.close()
|
||
timeout *= 2
|
||
continue
|
||
case <-inHandler:
|
||
tooSlow.Stop()
|
||
<-handlerDone
|
||
}
|
||
break
|
||
}
|
||
}
|
||
|
||
// Issue 23399: verify that if a client request times out, the Transport's
|
||
// conn is closed so that it's not reused.
|
||
//
|
||
// This is the test variant that has the server send response headers
|
||
// first, and time out during the write of the response body.
|
||
func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) {
|
||
run(t, testClientTimeoutKillsConn_AfterHeaders, []testMode{http1Mode})
|
||
}
|
||
func testClientTimeoutKillsConn_AfterHeaders(t *testing.T, mode testMode) {
|
||
inHandler := make(chan bool)
|
||
cancelHandler := make(chan struct{})
|
||
handlerDone := make(chan bool)
|
||
cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
w.Header().Set("Content-Length", "100")
|
||
w.(Flusher).Flush()
|
||
|
||
select {
|
||
case <-cancelHandler:
|
||
return
|
||
case inHandler <- true:
|
||
}
|
||
defer func() { handlerDone <- true }()
|
||
|
||
conn, _, err := w.(Hijacker).Hijack()
|
||
if err != nil {
|
||
t.Error(err)
|
||
return
|
||
}
|
||
conn.Write([]byte("foo"))
|
||
|
||
n, err := conn.Read([]byte{0})
|
||
// The error should be io.EOF or "read tcp
|
||
// 127.0.0.1:35827->127.0.0.1:40290: read: connection
|
||
// reset by peer" depending on timing. Really we just
|
||
// care that it returns at all. But if it returns with
|
||
// data, that's weird.
|
||
if n != 0 || err == nil {
|
||
t.Errorf("unexpected Read result: %v, %v", n, err)
|
||
}
|
||
conn.Close()
|
||
}))
|
||
|
||
// Set Timeout to something very long but non-zero to exercise
|
||
// the codepaths that check for it. But rather than wait for it to fire
|
||
// (which would make the test slow), we send on the req.Cancel channel instead,
|
||
// which happens to exercise the same code paths.
|
||
cst.c.Timeout = 24 * time.Hour // just to be non-zero, not to hit it.
|
||
req, _ := NewRequest("GET", cst.ts.URL, nil)
|
||
cancelReq := make(chan struct{})
|
||
req.Cancel = cancelReq
|
||
|
||
res, err := cst.c.Do(req)
|
||
if err != nil {
|
||
close(cancelHandler)
|
||
t.Fatalf("Get error: %v", err)
|
||
}
|
||
|
||
// Cancel the request while the handler is still blocked on sending to the
|
||
// inHandler channel. Then read it until it fails, to verify that the
|
||
// connection is broken before the handler itself closes it.
|
||
close(cancelReq)
|
||
got, err := io.ReadAll(res.Body)
|
||
if err == nil {
|
||
t.Errorf("unexpected success; read %q, nil", got)
|
||
}
|
||
|
||
// Now unblock the handler and wait for it to complete.
|
||
<-inHandler
|
||
<-handlerDone
|
||
}
|
||
|
||
func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) {
|
||
run(t, testTransportResponseBodyWritableOnProtocolSwitch, []testMode{http1Mode})
|
||
}
|
||
func testTransportResponseBodyWritableOnProtocolSwitch(t *testing.T, mode testMode) {
|
||
done := make(chan struct{})
|
||
defer close(done)
|
||
cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
conn, _, err := w.(Hijacker).Hijack()
|
||
if err != nil {
|
||
t.Error(err)
|
||
return
|
||
}
|
||
defer conn.Close()
|
||
io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n")
|
||
bs := bufio.NewScanner(conn)
|
||
bs.Scan()
|
||
fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text()))
|
||
<-done
|
||
}))
|
||
|
||
req, _ := NewRequest("GET", cst.ts.URL, nil)
|
||
req.Header.Set("Upgrade", "foo")
|
||
req.Header.Set("Connection", "upgrade")
|
||
res, err := cst.c.Do(req)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
if res.StatusCode != 101 {
|
||
t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header)
|
||
}
|
||
rwc, ok := res.Body.(io.ReadWriteCloser)
|
||
if !ok {
|
||
t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body)
|
||
}
|
||
defer rwc.Close()
|
||
bs := bufio.NewScanner(rwc)
|
||
if !bs.Scan() {
|
||
t.Fatalf("expected readable input")
|
||
}
|
||
if got, want := bs.Text(), "Some buffered data"; got != want {
|
||
t.Errorf("read %q; want %q", got, want)
|
||
}
|
||
io.WriteString(rwc, "echo\n")
|
||
if !bs.Scan() {
|
||
t.Fatalf("expected another line")
|
||
}
|
||
if got, want := bs.Text(), "ECHO"; got != want {
|
||
t.Errorf("read %q; want %q", got, want)
|
||
}
|
||
}
|
||
|
||
func TestTransportCONNECTBidi(t *testing.T) { run(t, testTransportCONNECTBidi, []testMode{http1Mode}) }
|
||
func testTransportCONNECTBidi(t *testing.T, mode testMode) {
|
||
const target = "backend:443"
|
||
cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
if r.Method != "CONNECT" {
|
||
t.Errorf("unexpected method %q", r.Method)
|
||
w.WriteHeader(500)
|
||
return
|
||
}
|
||
if r.RequestURI != target {
|
||
t.Errorf("unexpected CONNECT target %q", r.RequestURI)
|
||
w.WriteHeader(500)
|
||
return
|
||
}
|
||
nc, brw, err := w.(Hijacker).Hijack()
|
||
if err != nil {
|
||
t.Error(err)
|
||
return
|
||
}
|
||
defer nc.Close()
|
||
nc.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
|
||
// Switch to a little protocol that capitalize its input lines:
|
||
for {
|
||
line, err := brw.ReadString('\n')
|
||
if err != nil {
|
||
if err != io.EOF {
|
||
t.Error(err)
|
||
}
|
||
return
|
||
}
|
||
io.WriteString(brw, strings.ToUpper(line))
|
||
brw.Flush()
|
||
}
|
||
}))
|
||
pr, pw := io.Pipe()
|
||
defer pw.Close()
|
||
req, err := NewRequest("CONNECT", cst.ts.URL, pr)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
req.URL.Opaque = target
|
||
res, err := cst.c.Do(req)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
defer res.Body.Close()
|
||
if res.StatusCode != 200 {
|
||
t.Fatalf("status code = %d; want 200", res.StatusCode)
|
||
}
|
||
br := bufio.NewReader(res.Body)
|
||
for _, str := range []string{"foo", "bar", "baz"} {
|
||
fmt.Fprintf(pw, "%s\n", str)
|
||
got, err := br.ReadString('\n')
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
got = strings.TrimSpace(got)
|
||
want := strings.ToUpper(str)
|
||
if got != want {
|
||
t.Fatalf("got %q; want %q", got, want)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestTransportRequestReplayable(t *testing.T) {
|
||
someBody := io.NopCloser(strings.NewReader(""))
|
||
tests := []struct {
|
||
name string
|
||
req *Request
|
||
want bool
|
||
}{
|
||
{
|
||
name: "GET",
|
||
req: &Request{Method: "GET"},
|
||
want: true,
|
||
},
|
||
{
|
||
name: "GET_http.NoBody",
|
||
req: &Request{Method: "GET", Body: NoBody},
|
||
want: true,
|
||
},
|
||
{
|
||
name: "GET_body",
|
||
req: &Request{Method: "GET", Body: someBody},
|
||
want: false,
|
||
},
|
||
{
|
||
name: "POST",
|
||
req: &Request{Method: "POST"},
|
||
want: false,
|
||
},
|
||
{
|
||
name: "POST_idempotency-key",
|
||
req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}},
|
||
want: true,
|
||
},
|
||
{
|
||
name: "POST_x-idempotency-key",
|
||
req: &Request{Method: "POST", Header: Header{"X-Idempotency-Key": {"x"}}},
|
||
want: true,
|
||
},
|
||
{
|
||
name: "POST_body",
|
||
req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}, Body: someBody},
|
||
want: false,
|
||
},
|
||
}
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
got := tt.req.ExportIsReplayable()
|
||
if got != tt.want {
|
||
t.Errorf("replyable = %v; want %v", got, tt.want)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// testMockTCPConn is a mock TCP connection used to test that
|
||
// ReadFrom is called when sending the request body.
|
||
type testMockTCPConn struct {
|
||
*net.TCPConn
|
||
|
||
ReadFromCalled bool
|
||
}
|
||
|
||
func (c *testMockTCPConn) ReadFrom(r io.Reader) (int64, error) {
|
||
c.ReadFromCalled = true
|
||
return c.TCPConn.ReadFrom(r)
|
||
}
|
||
|
||
func TestTransportRequestWriteRoundTrip(t *testing.T) { run(t, testTransportRequestWriteRoundTrip) }
|
||
func testTransportRequestWriteRoundTrip(t *testing.T, mode testMode) {
|
||
nBytes := int64(1 << 10)
|
||
newFileFunc := func() (r io.Reader, done func(), err error) {
|
||
f, err := os.CreateTemp("", "net-http-newfilefunc")
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
|
||
// Write some bytes to the file to enable reading.
|
||
if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil {
|
||
return nil, nil, fmt.Errorf("failed to write data to file: %v", err)
|
||
}
|
||
if _, err := f.Seek(0, 0); err != nil {
|
||
return nil, nil, fmt.Errorf("failed to seek to front: %v", err)
|
||
}
|
||
|
||
done = func() {
|
||
f.Close()
|
||
os.Remove(f.Name())
|
||
}
|
||
|
||
return f, done, nil
|
||
}
|
||
|
||
newBufferFunc := func() (io.Reader, func(), error) {
|
||
return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil
|
||
}
|
||
|
||
cases := []struct {
|
||
name string
|
||
readerFunc func() (io.Reader, func(), error)
|
||
contentLength int64
|
||
expectedReadFrom bool
|
||
}{
|
||
{
|
||
name: "file, length",
|
||
readerFunc: newFileFunc,
|
||
contentLength: nBytes,
|
||
expectedReadFrom: true,
|
||
},
|
||
{
|
||
name: "file, no length",
|
||
readerFunc: newFileFunc,
|
||
},
|
||
{
|
||
name: "file, negative length",
|
||
readerFunc: newFileFunc,
|
||
contentLength: -1,
|
||
},
|
||
{
|
||
name: "buffer",
|
||
contentLength: nBytes,
|
||
readerFunc: newBufferFunc,
|
||
},
|
||
{
|
||
name: "buffer, no length",
|
||
readerFunc: newBufferFunc,
|
||
},
|
||
{
|
||
name: "buffer, length -1",
|
||
contentLength: -1,
|
||
readerFunc: newBufferFunc,
|
||
},
|
||
}
|
||
|
||
for _, tc := range cases {
|
||
t.Run(tc.name, func(t *testing.T) {
|
||
r, cleanup, err := tc.readerFunc()
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
defer cleanup()
|
||
|
||
tConn := &testMockTCPConn{}
|
||
trFunc := func(tr *Transport) {
|
||
tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||
var d net.Dialer
|
||
conn, err := d.DialContext(ctx, network, addr)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
tcpConn, ok := conn.(*net.TCPConn)
|
||
if !ok {
|
||
return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr)
|
||
}
|
||
|
||
tConn.TCPConn = tcpConn
|
||
return tConn, nil
|
||
}
|
||
}
|
||
|
||
cst := newClientServerTest(
|
||
t,
|
||
mode,
|
||
HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
io.Copy(io.Discard, r.Body)
|
||
r.Body.Close()
|
||
w.WriteHeader(200)
|
||
}),
|
||
trFunc,
|
||
)
|
||
|
||
req, err := NewRequest("PUT", cst.ts.URL, r)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
req.ContentLength = tc.contentLength
|
||
req.Header.Set("Content-Type", "application/octet-stream")
|
||
resp, err := cst.c.Do(req)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
defer resp.Body.Close()
|
||
if resp.StatusCode != 200 {
|
||
t.Fatalf("status code = %d; want 200", resp.StatusCode)
|
||
}
|
||
|
||
expectedReadFrom := tc.expectedReadFrom
|
||
if mode != http1Mode {
|
||
expectedReadFrom = false
|
||
}
|
||
if !tConn.ReadFromCalled && expectedReadFrom {
|
||
t.Fatalf("did not call ReadFrom")
|
||
}
|
||
|
||
if tConn.ReadFromCalled && !expectedReadFrom {
|
||
t.Fatalf("ReadFrom was unexpectedly invoked")
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestTransportClone(t *testing.T) {
|
||
tr := &Transport{
|
||
Proxy: func(*Request) (*url.URL, error) { panic("") },
|
||
OnProxyConnectResponse: func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
|
||
return nil
|
||
},
|
||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
|
||
Dial: func(network, addr string) (net.Conn, error) { panic("") },
|
||
DialTLS: func(network, addr string) (net.Conn, error) { panic("") },
|
||
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
|
||
TLSClientConfig: new(tls.Config),
|
||
TLSHandshakeTimeout: time.Second,
|
||
DisableKeepAlives: true,
|
||
DisableCompression: true,
|
||
MaxIdleConns: 1,
|
||
MaxIdleConnsPerHost: 1,
|
||
MaxConnsPerHost: 1,
|
||
IdleConnTimeout: time.Second,
|
||
ResponseHeaderTimeout: time.Second,
|
||
ExpectContinueTimeout: time.Second,
|
||
ProxyConnectHeader: Header{},
|
||
GetProxyConnectHeader: func(context.Context, *url.URL, string) (Header, error) { return nil, nil },
|
||
MaxResponseHeaderBytes: 1,
|
||
ForceAttemptHTTP2: true,
|
||
HTTP2: &HTTP2Config{MaxConcurrentStreams: 1},
|
||
Protocols: &Protocols{},
|
||
TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{
|
||
"foo": func(authority string, c *tls.Conn) RoundTripper { panic("") },
|
||
},
|
||
ReadBufferSize: 1,
|
||
WriteBufferSize: 1,
|
||
}
|
||
tr.Protocols.SetHTTP1(true)
|
||
tr.Protocols.SetHTTP2(true)
|
||
tr2 := tr.Clone()
|
||
rv := reflect.ValueOf(tr2).Elem()
|
||
rt := rv.Type()
|
||
for i := 0; i < rt.NumField(); i++ {
|
||
sf := rt.Field(i)
|
||
if !token.IsExported(sf.Name) {
|
||
continue
|
||
}
|
||
if rv.Field(i).IsZero() {
|
||
t.Errorf("cloned field t2.%s is zero", sf.Name)
|
||
}
|
||
}
|
||
|
||
if _, ok := tr2.TLSNextProto["foo"]; !ok {
|
||
t.Errorf("cloned Transport lacked TLSNextProto 'foo' key")
|
||
}
|
||
|
||
// But test that a nil TLSNextProto is kept nil:
|
||
tr = new(Transport)
|
||
tr2 = tr.Clone()
|
||
if tr2.TLSNextProto != nil {
|
||
t.Errorf("Transport.TLSNextProto unexpected non-nil")
|
||
}
|
||
}
|
||
|
||
func TestIs408(t *testing.T) {
|
||
tests := []struct {
|
||
in string
|
||
want bool
|
||
}{
|
||
{"HTTP/1.0 408", true},
|
||
{"HTTP/1.1 408", true},
|
||
{"HTTP/1.8 408", true},
|
||
{"HTTP/2.0 408", false}, // maybe h2c would do this? but false for now.
|
||
{"HTTP/1.1 408 ", true},
|
||
{"HTTP/1.1 40", false},
|
||
{"http/1.0 408", false},
|
||
{"HTTP/1-1 408", false},
|
||
}
|
||
for _, tt := range tests {
|
||
if got := Export_is408Message([]byte(tt.in)); got != tt.want {
|
||
t.Errorf("is408Message(%q) = %v; want %v", tt.in, got, tt.want)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestTransportIgnores408(t *testing.T) {
|
||
run(t, testTransportIgnores408, []testMode{http1Mode}, testNotParallel)
|
||
}
|
||
func testTransportIgnores408(t *testing.T, mode testMode) {
|
||
// Not parallel. Relies on mutating the log package's global Output.
|
||
defer log.SetOutput(log.Writer())
|
||
|
||
var logout strings.Builder
|
||
log.SetOutput(&logout)
|
||
|
||
const target = "backend:443"
|
||
|
||
cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
nc, _, err := w.(Hijacker).Hijack()
|
||
if err != nil {
|
||
t.Error(err)
|
||
return
|
||
}
|
||
defer nc.Close()
|
||
nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok"))
|
||
nc.Write([]byte("HTTP/1.1 408 bye\r\n")) // changing 408 to 409 makes test fail
|
||
}))
|
||
req, err := NewRequest("GET", cst.ts.URL, nil)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
res, err := cst.c.Do(req)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
slurp, err := io.ReadAll(res.Body)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
if string(slurp) != "ok" {
|
||
t.Fatalf("got %q; want ok", slurp)
|
||
}
|
||
|
||
waitCondition(t, 1*time.Millisecond, func(d time.Duration) bool {
|
||
if n := cst.tr.IdleConnKeyCountForTesting(); n != 0 {
|
||
if d > 0 {
|
||
t.Logf("%v idle conns still present after %v", n, d)
|
||
}
|
||
return false
|
||
}
|
||
return true
|
||
})
|
||
if got := logout.String(); got != "" {
|
||
t.Fatalf("expected no log output; got: %s", got)
|
||
}
|
||
}
|
||
|
||
func TestInvalidHeaderResponse(t *testing.T) {
|
||
run(t, testInvalidHeaderResponse, []testMode{http1Mode})
|
||
}
|
||
func testInvalidHeaderResponse(t *testing.T, mode testMode) {
|
||
cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
conn, buf, _ := w.(Hijacker).Hijack()
|
||
buf.Write([]byte("HTTP/1.1 200 OK\r\n" +
|
||
"Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
|
||
"Content-Type: text/html; charset=utf-8\r\n" +
|
||
"Content-Length: 0\r\n" +
|
||
"Foo : bar\r\n\r\n"))
|
||
buf.Flush()
|
||
conn.Close()
|
||
}))
|
||
res, err := cst.c.Get(cst.ts.URL)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
defer res.Body.Close()
|
||
if v := res.Header.Get("Foo"); v != "" {
|
||
t.Errorf(`unexpected "Foo" header: %q`, v)
|
||
}
|
||
if v := res.Header.Get("Foo "); v != "bar" {
|
||
t.Errorf(`bad "Foo " header value: %q, want %q`, v, "bar")
|
||
}
|
||
}
|
||
|
||
type bodyCloser bool
|
||
|
||
func (bc *bodyCloser) Close() error {
|
||
*bc = true
|
||
return nil
|
||
}
|
||
func (bc *bodyCloser) Read(b []byte) (n int, err error) {
|
||
return 0, io.EOF
|
||
}
|
||
|
||
// Issue 35015: ensure that Transport closes the body on any error
|
||
// with an invalid request, as promised by Client.Do docs.
|
||
func TestTransportClosesBodyOnInvalidRequests(t *testing.T) {
|
||
run(t, testTransportClosesBodyOnInvalidRequests)
|
||
}
|
||
func testTransportClosesBodyOnInvalidRequests(t *testing.T, mode testMode) {
|
||
cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
t.Errorf("Should not have been invoked")
|
||
})).ts
|
||
|
||
u, _ := url.Parse(cst.URL)
|
||
|
||
tests := []struct {
|
||
name string
|
||
req *Request
|
||
wantErr string
|
||
}{
|
||
{
|
||
name: "invalid method",
|
||
req: &Request{
|
||
Method: " ",
|
||
URL: u,
|
||
},
|
||
wantErr: `invalid method " "`,
|
||
},
|
||
{
|
||
name: "nil URL",
|
||
req: &Request{
|
||
Method: "GET",
|
||
},
|
||
wantErr: `nil Request.URL`,
|
||
},
|
||
{
|
||
name: "invalid header key",
|
||
req: &Request{
|
||
Method: "GET",
|
||
Header: Header{"💡": {"emoji"}},
|
||
URL: u,
|
||
},
|
||
wantErr: `invalid header field name "💡"`,
|
||
},
|
||
{
|
||
name: "invalid header value",
|
||
req: &Request{
|
||
Method: "POST",
|
||
Header: Header{"key": {"\x19"}},
|
||
URL: u,
|
||
},
|
||
wantErr: `invalid header field value for "key"`,
|
||
},
|
||
{
|
||
name: "non HTTP(s) scheme",
|
||
req: &Request{
|
||
Method: "POST",
|
||
URL: &url.URL{Scheme: "faux"},
|
||
},
|
||
wantErr: `unsupported protocol scheme "faux"`,
|
||
},
|
||
{
|
||
name: "no Host in URL",
|
||
req: &Request{
|
||
Method: "POST",
|
||
URL: &url.URL{Scheme: "http"},
|
||
},
|
||
wantErr: `no Host in request URL`,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
var bc bodyCloser
|
||
req := tt.req
|
||
req.Body = &bc
|
||
_, err := cst.Client().Do(tt.req)
|
||
if err == nil {
|
||
t.Fatal("Expected an error")
|
||
}
|
||
if !bc {
|
||
t.Fatal("Expected body to have been closed")
|
||
}
|
||
if g, w := err.Error(), tt.wantErr; !strings.HasSuffix(g, w) {
|
||
t.Fatalf("Error mismatch: %q does not end with %q", g, w)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// breakableConn is a net.Conn wrapper with a Write method
|
||
// that will fail when its brokenState is true.
|
||
type breakableConn struct {
|
||
net.Conn
|
||
*brokenState
|
||
}
|
||
|
||
type brokenState struct {
|
||
sync.Mutex
|
||
broken bool
|
||
}
|
||
|
||
func (w *breakableConn) Write(b []byte) (n int, err error) {
|
||
w.Lock()
|
||
defer w.Unlock()
|
||
if w.broken {
|
||
return 0, errors.New("some write error")
|
||
}
|
||
return w.Conn.Write(b)
|
||
}
|
||
|
||
// Issue 34978: don't cache a broken HTTP/2 connection
|
||
func TestDontCacheBrokenHTTP2Conn(t *testing.T) {
|
||
run(t, testDontCacheBrokenHTTP2Conn, []testMode{http2Mode})
|
||
}
|
||
func testDontCacheBrokenHTTP2Conn(t *testing.T, mode testMode) {
|
||
cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog)
|
||
|
||
var brokenState brokenState
|
||
|
||
const numReqs = 5
|
||
var numDials, gotConns uint32 // atomic
|
||
|
||
cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
|
||
atomic.AddUint32(&numDials, 1)
|
||
c, err := net.Dial(netw, addr)
|
||
if err != nil {
|
||
t.Errorf("unexpected Dial error: %v", err)
|
||
return nil, err
|
||
}
|
||
return &breakableConn{c, &brokenState}, err
|
||
}
|
||
|
||
for i := 1; i <= numReqs; i++ {
|
||
brokenState.Lock()
|
||
brokenState.broken = false
|
||
brokenState.Unlock()
|
||
|
||
// doBreak controls whether we break the TCP connection after the TLS
|
||
// handshake (before the HTTP/2 handshake). We test a few failures
|
||
// in a row followed by a final success.
|
||
doBreak := i != numReqs
|
||
|
||
ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
|
||
GotConn: func(info httptrace.GotConnInfo) {
|
||
t.Logf("got conn: %v, reused=%v, wasIdle=%v, idleTime=%v", info.Conn.LocalAddr(), info.Reused, info.WasIdle, info.IdleTime)
|
||
atomic.AddUint32(&gotConns, 1)
|
||
},
|
||
TLSHandshakeDone: func(cfg tls.ConnectionState, err error) {
|
||
brokenState.Lock()
|
||
defer brokenState.Unlock()
|
||
if doBreak {
|
||
brokenState.broken = true
|
||
}
|
||
},
|
||
})
|
||
req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
_, err = cst.c.Do(req)
|
||
if doBreak != (err != nil) {
|
||
t.Errorf("for iteration %d, doBreak=%v; unexpected error %v", i, doBreak, err)
|
||
}
|
||
}
|
||
if got, want := atomic.LoadUint32(&gotConns), 1; int(got) != want {
|
||
t.Errorf("GotConn calls = %v; want %v", got, want)
|
||
}
|
||
if got, want := atomic.LoadUint32(&numDials), numReqs; int(got) != want {
|
||
t.Errorf("Dials = %v; want %v", got, want)
|
||
}
|
||
}
|
||
|
||
// Issue 34941
|
||
// When the client has too many concurrent requests on a single connection,
|
||
// http.http2noCachedConnError is reported on multiple requests. There should
|
||
// only be one decrement regardless of the number of failures.
|
||
func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) {
|
||
run(t, testTransportDecrementConnWhenIdleConnRemoved, []testMode{http2Mode})
|
||
}
|
||
func testTransportDecrementConnWhenIdleConnRemoved(t *testing.T, mode testMode) {
|
||
CondSkipHTTP2(t)
|
||
|
||
h := HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
_, err := w.Write([]byte("foo"))
|
||
if err != nil {
|
||
t.Fatalf("Write: %v", err)
|
||
}
|
||
})
|
||
|
||
ts := newClientServerTest(t, mode, h).ts
|
||
|
||
c := ts.Client()
|
||
tr := c.Transport.(*Transport)
|
||
tr.MaxConnsPerHost = 1
|
||
|
||
errCh := make(chan error, 300)
|
||
doReq := func() {
|
||
resp, err := c.Get(ts.URL)
|
||
if err != nil {
|
||
errCh <- fmt.Errorf("request failed: %v", err)
|
||
return
|
||
}
|
||
defer resp.Body.Close()
|
||
_, err = io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
errCh <- fmt.Errorf("read body failed: %v", err)
|
||
}
|
||
}
|
||
|
||
var wg sync.WaitGroup
|
||
for i := 0; i < 300; i++ {
|
||
wg.Add(1)
|
||
go func() {
|
||
defer wg.Done()
|
||
doReq()
|
||
}()
|
||
}
|
||
wg.Wait()
|
||
close(errCh)
|
||
|
||
for err := range errCh {
|
||
t.Errorf("error occurred: %v", err)
|
||
}
|
||
}
|
||
|
||
// Issue 36820
|
||
// Test that we use the older backward compatible cancellation protocol
|
||
// when a RoundTripper is registered via RegisterProtocol.
|
||
func TestAltProtoCancellation(t *testing.T) {
|
||
defer afterTest(t)
|
||
tr := &Transport{}
|
||
c := &Client{
|
||
Transport: tr,
|
||
Timeout: time.Millisecond,
|
||
}
|
||
tr.RegisterProtocol("cancel", cancelProto{})
|
||
_, err := c.Get("cancel://bar.com/path")
|
||
if err == nil {
|
||
t.Error("request unexpectedly succeeded")
|
||
} else if !strings.Contains(err.Error(), errCancelProto.Error()) {
|
||
t.Errorf("got error %q, does not contain expected string %q", err, errCancelProto)
|
||
}
|
||
}
|
||
|
||
var errCancelProto = errors.New("canceled as expected")
|
||
|
||
type cancelProto struct{}
|
||
|
||
func (cancelProto) RoundTrip(req *Request) (*Response, error) {
|
||
<-req.Cancel
|
||
return nil, errCancelProto
|
||
}
|
||
|
||
type roundTripFunc func(r *Request) (*Response, error)
|
||
|
||
func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) { return f(r) }
|
||
|
||
// Issue 32441: body is not reset after ErrSkipAltProtocol
|
||
func TestIssue32441(t *testing.T) { run(t, testIssue32441, []testMode{http1Mode}) }
|
||
func testIssue32441(t *testing.T, mode testMode) {
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
|
||
t.Error("body length is zero")
|
||
}
|
||
})).ts
|
||
c := ts.Client()
|
||
c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) {
|
||
// Draining body to trigger failure condition on actual request to server.
|
||
if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
|
||
t.Error("body length is zero during round trip")
|
||
}
|
||
return nil, ErrSkipAltProtocol
|
||
}))
|
||
if _, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data")); err != nil {
|
||
t.Error(err)
|
||
}
|
||
}
|
||
|
||
// Issue 39017. Ensure that HTTP/1 transports reject Content-Length headers
|
||
// that contain a sign (eg. "+3"), per RFC 2616, Section 14.13.
|
||
func TestTransportRejectsSignInContentLength(t *testing.T) {
|
||
run(t, testTransportRejectsSignInContentLength, []testMode{http1Mode})
|
||
}
|
||
func testTransportRejectsSignInContentLength(t *testing.T, mode testMode) {
|
||
cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||
w.Header().Set("Content-Length", "+3")
|
||
w.Write([]byte("abc"))
|
||
})).ts
|
||
|
||
c := cst.Client()
|
||
res, err := c.Get(cst.URL)
|
||
if err == nil || res != nil {
|
||
t.Fatal("Expected a non-nil error and a nil http.Response")
|
||
}
|
||
if got, want := err.Error(), `bad Content-Length "+3"`; !strings.Contains(got, want) {
|
||
t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want)
|
||
}
|
||
}
|
||
|
||
// dumpConn is a net.Conn which writes to Writer and reads from Reader
|
||
type dumpConn struct {
|
||
io.Writer
|
||
io.Reader
|
||
}
|
||
|
||
func (c *dumpConn) Close() error { return nil }
|
||
func (c *dumpConn) LocalAddr() net.Addr { return nil }
|
||
func (c *dumpConn) RemoteAddr() net.Addr { return nil }
|
||
func (c *dumpConn) SetDeadline(t time.Time) error { return nil }
|
||
func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil }
|
||
func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil }
|
||
|
||
// delegateReader is a reader that delegates to another reader,
|
||
// once it arrives on a channel.
|
||
type delegateReader struct {
|
||
c chan io.Reader
|
||
r io.Reader // nil until received from c
|
||
}
|
||
|
||
func (r *delegateReader) Read(p []byte) (int, error) {
|
||
if r.r == nil {
|
||
var ok bool
|
||
if r.r, ok = <-r.c; !ok {
|
||
return 0, errors.New("delegate closed")
|
||
}
|
||
}
|
||
return r.r.Read(p)
|
||
}
|
||
|
||
func testTransportRace(req *Request) {
|
||
save := req.Body
|
||
pr, pw := io.Pipe()
|
||
defer pr.Close()
|
||
defer pw.Close()
|
||
dr := &delegateReader{c: make(chan io.Reader)}
|
||
|
||
t := &Transport{
|
||
Dial: func(net, addr string) (net.Conn, error) {
|
||
return &dumpConn{pw, dr}, nil
|
||
},
|
||
}
|
||
defer t.CloseIdleConnections()
|
||
|
||
quitReadCh := make(chan struct{})
|
||
// Wait for the request before replying with a dummy response:
|
||
go func() {
|
||
defer close(quitReadCh)
|
||
|
||
req, err := ReadRequest(bufio.NewReader(pr))
|
||
if err == nil {
|
||
// Ensure all the body is read; otherwise
|
||
// we'll get a partial dump.
|
||
io.Copy(io.Discard, req.Body)
|
||
req.Body.Close()
|
||
}
|
||
select {
|
||
case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"):
|
||
case quitReadCh <- struct{}{}:
|
||
// Ensure delegate is closed so Read doesn't block forever.
|
||
close(dr.c)
|
||
}
|
||
}()
|
||
|
||
t.RoundTrip(req)
|
||
|
||
// Ensure the reader returns before we reset req.Body to prevent
|
||
// a data race on req.Body.
|
||
pw.Close()
|
||
<-quitReadCh
|
||
|
||
req.Body = save
|
||
}
|
||
|
||
// Issue 37669
|
||
// Test that a cancellation doesn't result in a data race due to the writeLoop
|
||
// goroutine being left running, if the caller mutates the processed Request
|
||
// upon completion.
|
||
func TestErrorWriteLoopRace(t *testing.T) {
|
||
if testing.Short() {
|
||
return
|
||
}
|
||
t.Parallel()
|
||
for i := 0; i < 1000; i++ {
|
||
delay := time.Duration(mrand.Intn(5)) * time.Millisecond
|
||
ctx, cancel := context.WithTimeout(context.Background(), delay)
|
||
defer cancel()
|
||
|
||
r := bytes.NewBuffer(make([]byte, 10000))
|
||
req, err := NewRequestWithContext(ctx, MethodPost, "http://example.com", r)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
testTransportRace(req)
|
||
}
|
||
}
|
||
|
||
// Issue 41600
|
||
// Test that a new request which uses the connection of an active request
|
||
// cannot cause it to be canceled as well.
|
||
func TestCancelRequestWhenSharingConnection(t *testing.T) {
|
||
run(t, testCancelRequestWhenSharingConnection, []testMode{http1Mode})
|
||
}
|
||
func testCancelRequestWhenSharingConnection(t *testing.T, mode testMode) {
|
||
reqc := make(chan chan struct{}, 2)
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
|
||
ch := make(chan struct{}, 1)
|
||
reqc <- ch
|
||
<-ch
|
||
w.Header().Add("Content-Length", "0")
|
||
})).ts
|
||
|
||
client := ts.Client()
|
||
transport := client.Transport.(*Transport)
|
||
transport.MaxIdleConns = 1
|
||
transport.MaxConnsPerHost = 1
|
||
|
||
var wg sync.WaitGroup
|
||
|
||
wg.Add(1)
|
||
putidlec := make(chan chan struct{}, 1)
|
||
reqerrc := make(chan error, 1)
|
||
go func() {
|
||
defer wg.Done()
|
||
ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
|
||
PutIdleConn: func(error) {
|
||
// Signal that the idle conn has been returned to the pool,
|
||
// and wait for the order to proceed.
|
||
ch := make(chan struct{})
|
||
putidlec <- ch
|
||
close(putidlec) // panic if PutIdleConn runs twice for some reason
|
||
<-ch
|
||
},
|
||
})
|
||
req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
|
||
res, err := client.Do(req)
|
||
if err != nil {
|
||
reqerrc <- err
|
||
} else {
|
||
res.Body.Close()
|
||
}
|
||
}()
|
||
|
||
// Wait for the first request to receive a response and return the
|
||
// connection to the idle pool.
|
||
select {
|
||
case err := <-reqerrc:
|
||
t.Fatalf("request 1: got err %v, want nil", err)
|
||
case r1c := <-reqc:
|
||
close(r1c)
|
||
}
|
||
var idlec chan struct{}
|
||
select {
|
||
case err := <-reqerrc:
|
||
t.Fatalf("request 1: got err %v, want nil", err)
|
||
case idlec = <-putidlec:
|
||
}
|
||
|
||
wg.Add(1)
|
||
cancelctx, cancel := context.WithCancel(context.Background())
|
||
go func() {
|
||
defer wg.Done()
|
||
req, _ := NewRequestWithContext(cancelctx, "GET", ts.URL, nil)
|
||
res, err := client.Do(req)
|
||
if err == nil {
|
||
res.Body.Close()
|
||
}
|
||
if !errors.Is(err, context.Canceled) {
|
||
t.Errorf("request 2: got err %v, want Canceled", err)
|
||
}
|
||
|
||
// Unblock the first request.
|
||
close(idlec)
|
||
}()
|
||
|
||
// Wait for the second request to arrive at the server, and then cancel
|
||
// the request context.
|
||
r2c := <-reqc
|
||
cancel()
|
||
|
||
<-idlec
|
||
|
||
close(r2c)
|
||
wg.Wait()
|
||
}
|
||
|
||
func TestHandlerAbortRacesBodyRead(t *testing.T) { run(t, testHandlerAbortRacesBodyRead) }
|
||
func testHandlerAbortRacesBodyRead(t *testing.T, mode testMode) {
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
|
||
go io.Copy(io.Discard, req.Body)
|
||
panic(ErrAbortHandler)
|
||
})).ts
|
||
|
||
var wg sync.WaitGroup
|
||
for i := 0; i < 2; i++ {
|
||
wg.Add(1)
|
||
go func() {
|
||
defer wg.Done()
|
||
for j := 0; j < 10; j++ {
|
||
const reqLen = 6 * 1024 * 1024
|
||
req, _ := NewRequest("POST", ts.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
|
||
req.ContentLength = reqLen
|
||
resp, _ := ts.Client().Transport.RoundTrip(req)
|
||
if resp != nil {
|
||
resp.Body.Close()
|
||
}
|
||
}
|
||
}()
|
||
}
|
||
wg.Wait()
|
||
}
|
||
|
||
func TestRequestSanitization(t *testing.T) { run(t, testRequestSanitization) }
|
||
func testRequestSanitization(t *testing.T, mode testMode) {
|
||
if mode == http2Mode {
|
||
// Remove this after updating x/net.
|
||
t.Skip("https://go.dev/issue/60374 test fails when run with HTTP/2")
|
||
}
|
||
ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
|
||
if h, ok := req.Header["X-Evil"]; ok {
|
||
t.Errorf("request has X-Evil header: %q", h)
|
||
}
|
||
})).ts
|
||
req, _ := NewRequest("GET", ts.URL, nil)
|
||
req.Host = "go.dev\r\nX-Evil:evil"
|
||
resp, _ := ts.Client().Do(req)
|
||
if resp != nil {
|
||
resp.Body.Close()
|
||
}
|
||
}
|
||
|
||
func TestProxyAuthHeader(t *testing.T) {
|
||
// Not parallel: Sets an environment variable.
|
||
run(t, testProxyAuthHeader, []testMode{http1Mode}, testNotParallel)
|
||
}
|
||
func testProxyAuthHeader(t *testing.T, mode testMode) {
|
||
const username = "u"
|
||
const password = "@/?!"
|
||
cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
|
||
// Copy the Proxy-Authorization header to a new Request,
|
||
// since Request.BasicAuth only parses the Authorization header.
|
||
var r2 Request
|
||
r2.Header = Header{
|
||
"Authorization": req.Header["Proxy-Authorization"],
|
||
}
|
||
gotuser, gotpass, ok := r2.BasicAuth()
|
||
if !ok || gotuser != username || gotpass != password {
|
||
t.Errorf("req.BasicAuth() = %q, %q, %v; want %q, %q, true", gotuser, gotpass, ok, username, password)
|
||
}
|
||
}))
|
||
u, err := url.Parse(cst.ts.URL)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
u.User = url.UserPassword(username, password)
|
||
t.Setenv("HTTP_PROXY", u.String())
|
||
cst.tr.Proxy = ProxyURL(u)
|
||
resp, err := cst.c.Get("http://_/")
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
resp.Body.Close()
|
||
}
|
||
|
||
// Issue 61708
|
||
func TestTransportReqCancelerCleanupOnRequestBodyWriteError(t *testing.T) {
|
||
ln := newLocalListener(t)
|
||
addr := ln.Addr().String()
|
||
|
||
done := make(chan struct{})
|
||
go func() {
|
||
conn, err := ln.Accept()
|
||
if err != nil {
|
||
t.Errorf("ln.Accept: %v", err)
|
||
return
|
||
}
|
||
// Start reading request before sending response to avoid
|
||
// "Unsolicited response received on idle HTTP channel" RoundTrip error.
|
||
if _, err := io.ReadFull(conn, make([]byte, 1)); err != nil {
|
||
t.Errorf("conn.Read: %v", err)
|
||
return
|
||
}
|
||
io.WriteString(conn, "HTTP/1.1 200\r\nContent-Length: 3\r\n\r\nfoo")
|
||
<-done
|
||
conn.Close()
|
||
}()
|
||
|
||
didRead := make(chan bool)
|
||
SetReadLoopBeforeNextReadHook(func() { didRead <- true })
|
||
defer SetReadLoopBeforeNextReadHook(nil)
|
||
|
||
tr := &Transport{}
|
||
|
||
// Send a request with a body guaranteed to fail on write.
|
||
req, err := NewRequest("POST", "http://"+addr, io.LimitReader(neverEnding('x'), 1<<30))
|
||
if err != nil {
|
||
t.Fatalf("NewRequest: %v", err)
|
||
}
|
||
|
||
resp, err := tr.RoundTrip(req)
|
||
if err != nil {
|
||
t.Fatalf("tr.RoundTrip: %v", err)
|
||
}
|
||
|
||
close(done)
|
||
|
||
// Before closing response body wait for readLoopDone goroutine
|
||
// to complete due to closed connection by writeLoop.
|
||
<-didRead
|
||
|
||
resp.Body.Close()
|
||
|
||
// Verify no outstanding requests after readLoop/writeLoop
|
||
// goroutines shut down.
|
||
waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
|
||
n := tr.NumPendingRequestsForTesting()
|
||
if n > 0 {
|
||
if d > 0 {
|
||
t.Logf("pending requests = %d after %v (want 0)", n, d)
|
||
}
|
||
return false
|
||
}
|
||
return true
|
||
})
|
||
}
|
||
|
||
func TestValidateClientRequestTrailers(t *testing.T) {
|
||
run(t, testValidateClientRequestTrailers)
|
||
}
|
||
|
||
func testValidateClientRequestTrailers(t *testing.T, mode testMode) {
|
||
cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
|
||
rw.Write([]byte("Hello"))
|
||
})).ts
|
||
|
||
cases := []struct {
|
||
trailer Header
|
||
wantErr string
|
||
}{
|
||
{Header{"Trx": {"x\r\nX-Another-One"}}, `invalid trailer field value for "Trx"`},
|
||
{Header{"\r\nTrx": {"X-Another-One"}}, `invalid trailer field name "\r\nTrx"`},
|
||
}
|
||
|
||
for i, tt := range cases {
|
||
testName := fmt.Sprintf("%s%d", mode, i)
|
||
t.Run(testName, func(t *testing.T) {
|
||
req, err := NewRequest("GET", cst.URL, nil)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
req.Trailer = tt.trailer
|
||
res, err := cst.Client().Do(req)
|
||
if err == nil {
|
||
t.Fatal("Expected an error")
|
||
}
|
||
if g, w := err.Error(), tt.wantErr; !strings.Contains(g, w) {
|
||
t.Fatalf("Mismatched error\n\t%q\ndoes not contain\n\t%q", g, w)
|
||
}
|
||
if res != nil {
|
||
t.Fatal("Unexpected non-nil response")
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestTransportServerProtocols(t *testing.T) {
|
||
CondSkipHTTP2(t)
|
||
DefaultTransport.(*Transport).CloseIdleConnections()
|
||
|
||
cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
leafCert, err := x509.ParseCertificate(cert.Certificate[0])
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
certpool := x509.NewCertPool()
|
||
certpool.AddCert(leafCert)
|
||
|
||
for _, test := range []struct {
|
||
name string
|
||
scheme string
|
||
setup func(t *testing.T)
|
||
transport func(*Transport)
|
||
server func(*Server)
|
||
want string
|
||
}{{
|
||
name: "http default",
|
||
scheme: "http",
|
||
want: "HTTP/1.1",
|
||
}, {
|
||
name: "https default",
|
||
scheme: "https",
|
||
transport: func(tr *Transport) {
|
||
// Transport default is HTTP/1.
|
||
},
|
||
want: "HTTP/1.1",
|
||
}, {
|
||
name: "https transport protocols include HTTP2",
|
||
scheme: "https",
|
||
transport: func(tr *Transport) {
|
||
// Server default is to support HTTP/2, so if the Transport enables
|
||
// HTTP/2 we get it.
|
||
tr.Protocols = &Protocols{}
|
||
tr.Protocols.SetHTTP1(true)
|
||
tr.Protocols.SetHTTP2(true)
|
||
},
|
||
want: "HTTP/2.0",
|
||
}, {
|
||
name: "https transport protocols only include HTTP1",
|
||
scheme: "https",
|
||
transport: func(tr *Transport) {
|
||
// Explicitly enable only HTTP/1.
|
||
tr.Protocols = &Protocols{}
|
||
tr.Protocols.SetHTTP1(true)
|
||
},
|
||
want: "HTTP/1.1",
|
||
}, {
|
||
name: "https transport ForceAttemptHTTP2",
|
||
scheme: "https",
|
||
transport: func(tr *Transport) {
|
||
// Pre-Protocols-field way of enabling HTTP/2.
|
||
tr.ForceAttemptHTTP2 = true
|
||
},
|
||
want: "HTTP/2.0",
|
||
}, {
|
||
name: "https transport protocols override TLSNextProto",
|
||
scheme: "https",
|
||
transport: func(tr *Transport) {
|
||
// Setting TLSNextProto to an empty map is the historical way
|
||
// of disabling HTTP/2. Explicitly enabling HTTP2 in the Protocols
|
||
// field takes precedence.
|
||
tr.Protocols = &Protocols{}
|
||
tr.Protocols.SetHTTP1(true)
|
||
tr.Protocols.SetHTTP2(true)
|
||
tr.TLSNextProto = map[string]func(string, *tls.Conn) RoundTripper{}
|
||
},
|
||
want: "HTTP/2.0",
|
||
}, {
|
||
name: "https server disables HTTP2 with TLSNextProto",
|
||
scheme: "https",
|
||
server: func(srv *Server) {
|
||
// Disable HTTP/2 on the server with TLSNextProto,
|
||
// use default Protocols value.
|
||
srv.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){}
|
||
},
|
||
want: "HTTP/1.1",
|
||
}, {
|
||
name: "https server Protocols overrides empty TLSNextProto",
|
||
scheme: "https",
|
||
server: func(srv *Server) {
|
||
// Explicitly enabling HTTP2 in the Protocols field takes precedence
|
||
// over setting an empty TLSNextProto.
|
||
srv.Protocols = &Protocols{}
|
||
srv.Protocols.SetHTTP1(true)
|
||
srv.Protocols.SetHTTP2(true)
|
||
srv.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){}
|
||
},
|
||
want: "HTTP/2.0",
|
||
}, {
|
||
name: "https server protocols only include HTTP1",
|
||
scheme: "https",
|
||
server: func(srv *Server) {
|
||
srv.Protocols = &Protocols{}
|
||
srv.Protocols.SetHTTP1(true)
|
||
},
|
||
want: "HTTP/1.1",
|
||
}, {
|
||
name: "https server protocols include HTTP2",
|
||
scheme: "https",
|
||
server: func(srv *Server) {
|
||
srv.Protocols = &Protocols{}
|
||
srv.Protocols.SetHTTP1(true)
|
||
srv.Protocols.SetHTTP2(true)
|
||
},
|
||
want: "HTTP/2.0",
|
||
}, {
|
||
name: "GODEBUG disables HTTP2 client",
|
||
scheme: "https",
|
||
setup: func(t *testing.T) {
|
||
t.Setenv("GODEBUG", "http2client=0")
|
||
},
|
||
transport: func(tr *Transport) {
|
||
// Server default is to support HTTP/2, so if the Transport enables
|
||
// HTTP/2 we get it.
|
||
tr.Protocols = &Protocols{}
|
||
tr.Protocols.SetHTTP1(true)
|
||
tr.Protocols.SetHTTP2(true)
|
||
},
|
||
want: "HTTP/1.1",
|
||
}, {
|
||
name: "GODEBUG disables HTTP2 server",
|
||
scheme: "https",
|
||
setup: func(t *testing.T) {
|
||
t.Setenv("GODEBUG", "http2server=0")
|
||
},
|
||
transport: func(tr *Transport) {
|
||
// Server default is to support HTTP/2, so if the Transport enables
|
||
// HTTP/2 we get it.
|
||
tr.Protocols = &Protocols{}
|
||
tr.Protocols.SetHTTP1(true)
|
||
tr.Protocols.SetHTTP2(true)
|
||
},
|
||
want: "HTTP/1.1",
|
||
}} {
|
||
t.Run(test.name, func(t *testing.T) {
|
||
// We don't use httptest here because it makes its own decisions
|
||
// about how to enable/disable HTTP/2.
|
||
srv := &Server{
|
||
TLSConfig: &tls.Config{
|
||
Certificates: []tls.Certificate{cert},
|
||
},
|
||
Handler: HandlerFunc(func(w ResponseWriter, req *Request) {
|
||
w.Header().Set("X-Proto", req.Proto)
|
||
}),
|
||
}
|
||
tr := &Transport{
|
||
TLSClientConfig: &tls.Config{
|
||
RootCAs: certpool,
|
||
},
|
||
}
|
||
|
||
if test.setup != nil {
|
||
test.setup(t)
|
||
}
|
||
if test.server != nil {
|
||
test.server(srv)
|
||
}
|
||
if test.transport != nil {
|
||
test.transport(tr)
|
||
} else {
|
||
tr.Protocols = &Protocols{}
|
||
tr.Protocols.SetHTTP1(true)
|
||
tr.Protocols.SetHTTP2(true)
|
||
}
|
||
|
||
listener := newLocalListener(t)
|
||
srvc := make(chan error, 1)
|
||
go func() {
|
||
switch test.scheme {
|
||
case "http":
|
||
srvc <- srv.Serve(listener)
|
||
case "https":
|
||
srvc <- srv.ServeTLS(listener, "", "")
|
||
}
|
||
}()
|
||
t.Cleanup(func() {
|
||
srv.Close()
|
||
<-srvc
|
||
})
|
||
|
||
client := &Client{Transport: tr}
|
||
resp, err := client.Get(test.scheme + "://" + listener.Addr().String())
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
if got := resp.Header.Get("X-Proto"); got != test.want {
|
||
t.Fatalf("request proto %q, want %q", got, test.want)
|
||
}
|
||
})
|
||
}
|
||
}
|