mirror of https://github.com/golang/go.git
net/http: add Client.Timeout for end-to-end timeouts
Fixes #3362 LGTM=josharian R=golang-codereviews, josharian CC=adg, dsymonds, golang-codereviews, n13m3y3r https://golang.org/cl/70120045
This commit is contained in:
parent
92d5483391
commit
2ad72ecf34
|
|
@ -17,6 +17,8 @@ import (
|
|||
"log"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// A Client is an HTTP client. Its zero value (DefaultClient) is a
|
||||
|
|
@ -52,6 +54,21 @@ type Client struct {
|
|||
// If Jar is nil, cookies are not sent in requests and ignored
|
||||
// in responses.
|
||||
Jar CookieJar
|
||||
|
||||
// Timeout specifies the end-to-end timeout for requests made
|
||||
// via this Client. The timeout includes connection time, any
|
||||
// redirects, and reading the response body. The timeout
|
||||
// remains running once Get, Head, Post, or Do returns and
|
||||
// will interrupt the read of the Response.Body if EOF hasn't
|
||||
// been reached.
|
||||
//
|
||||
// A Timeout of zero means no timeout.
|
||||
//
|
||||
// The Client's Transport must support the CancelRequest
|
||||
// method or Client will return errors when attempting to make
|
||||
// a request with Get, Head, Post, or Do. Client's default
|
||||
// Transport (DefaultTransport) supports CancelRequest.
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// DefaultClient is the default Client and is used by Get, Head, and Post.
|
||||
|
|
@ -97,7 +114,7 @@ func (c *Client) send(req *Request) (*Response, error) {
|
|||
req.AddCookie(cookie)
|
||||
}
|
||||
}
|
||||
resp, err := send(req, c.Transport)
|
||||
resp, err := send(req, c.transport())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -134,15 +151,18 @@ func (c *Client) Do(req *Request) (resp *Response, err error) {
|
|||
return c.send(req)
|
||||
}
|
||||
|
||||
func (c *Client) transport() RoundTripper {
|
||||
if c.Transport != nil {
|
||||
return c.Transport
|
||||
}
|
||||
return DefaultTransport
|
||||
}
|
||||
|
||||
// send issues an HTTP request.
|
||||
// Caller should close resp.Body when done reading from it.
|
||||
func send(req *Request, t RoundTripper) (resp *Response, err error) {
|
||||
if t == nil {
|
||||
t = DefaultTransport
|
||||
if t == nil {
|
||||
err = errors.New("http: no Client.Transport or DefaultTransport")
|
||||
return
|
||||
}
|
||||
return nil, errors.New("http: no Client.Transport or DefaultTransport")
|
||||
}
|
||||
|
||||
if req.URL == nil {
|
||||
|
|
@ -260,18 +280,36 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo
|
|||
return nil, errors.New("http: nil Request.URL")
|
||||
}
|
||||
|
||||
var reqmu sync.Mutex // guards req
|
||||
req := ireq
|
||||
|
||||
var timer *time.Timer
|
||||
if c.Timeout > 0 {
|
||||
type canceler interface {
|
||||
CancelRequest(*Request)
|
||||
}
|
||||
tr, ok := c.transport().(canceler)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("net/http: Client Transport of type %T doesn't support CancelRequest; Timeout not supported", c.transport())
|
||||
}
|
||||
timer = time.AfterFunc(c.Timeout, func() {
|
||||
reqmu.Lock()
|
||||
defer reqmu.Unlock()
|
||||
tr.CancelRequest(req)
|
||||
})
|
||||
}
|
||||
|
||||
urlStr := "" // next relative or absolute URL to fetch (after first request)
|
||||
redirectFailed := false
|
||||
for redirect := 0; ; redirect++ {
|
||||
if redirect != 0 {
|
||||
req = new(Request)
|
||||
req.Method = ireq.Method
|
||||
nreq := new(Request)
|
||||
nreq.Method = ireq.Method
|
||||
if ireq.Method == "POST" || ireq.Method == "PUT" {
|
||||
req.Method = "GET"
|
||||
nreq.Method = "GET"
|
||||
}
|
||||
req.Header = make(Header)
|
||||
req.URL, err = base.Parse(urlStr)
|
||||
nreq.Header = make(Header)
|
||||
nreq.URL, err = base.Parse(urlStr)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
|
@ -279,15 +317,18 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo
|
|||
// Add the Referer header.
|
||||
lastReq := via[len(via)-1]
|
||||
if lastReq.URL.Scheme != "https" {
|
||||
req.Header.Set("Referer", lastReq.URL.String())
|
||||
nreq.Header.Set("Referer", lastReq.URL.String())
|
||||
}
|
||||
|
||||
err = redirectChecker(req, via)
|
||||
err = redirectChecker(nreq, via)
|
||||
if err != nil {
|
||||
redirectFailed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
reqmu.Lock()
|
||||
req = nreq
|
||||
reqmu.Unlock()
|
||||
}
|
||||
|
||||
urlStr = req.URL.String()
|
||||
|
|
@ -305,7 +346,10 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo
|
|||
via = append(via, req)
|
||||
continue
|
||||
}
|
||||
return
|
||||
if timer != nil {
|
||||
resp.Body = &cancelTimerBody{timer, resp.Body}
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
method := ireq.Method
|
||||
|
|
@ -408,3 +452,22 @@ func (c *Client) Head(url string) (resp *Response, err error) {
|
|||
}
|
||||
return c.doFollowingRedirects(req, shouldRedirectGet)
|
||||
}
|
||||
|
||||
type cancelTimerBody struct {
|
||||
t *time.Timer
|
||||
rc io.ReadCloser
|
||||
}
|
||||
|
||||
func (b *cancelTimerBody) Read(p []byte) (n int, err error) {
|
||||
n, err = b.rc.Read(p)
|
||||
if err == io.EOF {
|
||||
b.t.Stop()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (b *cancelTimerBody) Close() error {
|
||||
err := b.rc.Close()
|
||||
b.t.Stop()
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -812,3 +812,70 @@ func TestBasicAuth(t *testing.T) {
|
|||
t.Errorf("Invalid auth %q", auth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientTimeout(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping in short mode")
|
||||
}
|
||||
defer afterTest(t)
|
||||
sawRoot := make(chan bool, 1)
|
||||
sawSlow := make(chan bool, 1)
|
||||
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
|
||||
if r.URL.Path == "/" {
|
||||
sawRoot <- true
|
||||
Redirect(w, r, "/slow", StatusFound)
|
||||
return
|
||||
}
|
||||
if r.URL.Path == "/slow" {
|
||||
w.Write([]byte("Hello"))
|
||||
w.(Flusher).Flush()
|
||||
sawSlow <- true
|
||||
time.Sleep(2 * time.Second)
|
||||
return
|
||||
}
|
||||
}))
|
||||
defer ts.Close()
|
||||
const timeout = 500 * time.Millisecond
|
||||
c := &Client{
|
||||
Timeout: timeout,
|
||||
}
|
||||
|
||||
res, err := c.Get(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-sawRoot:
|
||||
// good.
|
||||
default:
|
||||
t.Fatal("handler never got / request")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-sawSlow:
|
||||
// good.
|
||||
default:
|
||||
t.Fatal("handler never got /slow request")
|
||||
}
|
||||
|
||||
var all []byte
|
||||
errc := make(chan error, 1)
|
||||
go func() {
|
||||
var err error
|
||||
all, err = ioutil.ReadAll(res.Body)
|
||||
errc <- err
|
||||
res.Body.Close()
|
||||
}()
|
||||
|
||||
const failTime = timeout * 2
|
||||
select {
|
||||
case err := <-errc:
|
||||
if err == nil {
|
||||
t.Error("expected error from ReadAll")
|
||||
}
|
||||
t.Logf("Got expected ReadAll error of %v after reading body %q", err, all)
|
||||
case <-time.After(failTime):
|
||||
t.Errorf("timeout after %v waiting for timeout of %v", failTime, timeout)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue