mirror of https://github.com/golang/go.git
net/http: add optional Server.ConnState callback
Update #4674 This allows for all sorts of graceful shutdown policies, without picking a policy (e.g. lameduck period) and without adding lots of locking to the server core. That policy and locking can be implemented outside of net/http now. LGTM=adg R=golang-codereviews, josharian, r, adg, dvyukov CC=golang-codereviews https://golang.org/cl/69260044
This commit is contained in:
parent
d9c6ae6ae8
commit
67b8bf3e32
|
|
@ -26,6 +26,7 @@ import (
|
|||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
|
@ -2243,6 +2244,120 @@ func TestAppendTime(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestServerConnState(t *testing.T) {
|
||||
defer afterTest(t)
|
||||
handler := map[string]func(w ResponseWriter, r *Request){
|
||||
"/": func(w ResponseWriter, r *Request) {
|
||||
fmt.Fprintf(w, "Hello.")
|
||||
},
|
||||
"/close": func(w ResponseWriter, r *Request) {
|
||||
w.Header().Set("Connection", "close")
|
||||
fmt.Fprintf(w, "Hello.")
|
||||
},
|
||||
"/hijack": func(w ResponseWriter, r *Request) {
|
||||
c, _, _ := w.(Hijacker).Hijack()
|
||||
c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
|
||||
c.Close()
|
||||
},
|
||||
}
|
||||
ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
|
||||
handler[r.URL.Path](w, r)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
type connIDAndState struct {
|
||||
connID int
|
||||
state ConnState
|
||||
}
|
||||
var mu sync.Mutex // guard stateLog and connID
|
||||
var stateLog []connIDAndState
|
||||
var connID = map[net.Conn]int{}
|
||||
|
||||
ts.Config.ConnState = func(c net.Conn, state ConnState) {
|
||||
if c == nil {
|
||||
t.Error("nil conn seen in state %s", state)
|
||||
return
|
||||
}
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
id, ok := connID[c]
|
||||
if !ok {
|
||||
id = len(connID) + 1
|
||||
connID[c] = id
|
||||
}
|
||||
stateLog = append(stateLog, connIDAndState{id, state})
|
||||
}
|
||||
ts.Start()
|
||||
|
||||
mustGet(t, ts.URL+"/")
|
||||
mustGet(t, ts.URL+"/close")
|
||||
|
||||
mustGet(t, ts.URL+"/")
|
||||
mustGet(t, ts.URL+"/", "Connection", "close")
|
||||
|
||||
mustGet(t, ts.URL+"/hijack")
|
||||
|
||||
want := []connIDAndState{
|
||||
{1, StateNew},
|
||||
{1, StateActive},
|
||||
{1, StateIdle},
|
||||
{1, StateActive},
|
||||
{1, StateClosed},
|
||||
|
||||
{2, StateNew},
|
||||
{2, StateActive},
|
||||
{2, StateIdle},
|
||||
{2, StateActive},
|
||||
{2, StateClosed},
|
||||
|
||||
{3, StateNew},
|
||||
{3, StateActive},
|
||||
{3, StateHijacked},
|
||||
}
|
||||
logString := func(l []connIDAndState) string {
|
||||
var b bytes.Buffer
|
||||
for _, cs := range l {
|
||||
fmt.Fprintf(&b, "[%d %s] ", cs.connID, cs.state)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
time.Sleep(time.Duration(i) * 50 * time.Millisecond)
|
||||
mu.Lock()
|
||||
match := reflect.DeepEqual(stateLog, want)
|
||||
mu.Unlock()
|
||||
if match {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
t.Errorf("Unexpected events.\nGot log: %s\n Want: %s\n", logString(stateLog), logString(want))
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
func mustGet(t *testing.T, url string, headers ...string) {
|
||||
req, err := NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for len(headers) > 0 {
|
||||
req.Header.Add(headers[0], headers[1])
|
||||
headers = headers[2:]
|
||||
}
|
||||
res, err := DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Errorf("Error fetching %s: %v", url, err)
|
||||
return
|
||||
}
|
||||
_, err = ioutil.ReadAll(res.Body)
|
||||
defer res.Body.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Error reading %s: %v", url, err)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkClientServer(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.StopTimer()
|
||||
|
|
|
|||
|
|
@ -1079,8 +1079,16 @@ func validNPN(proto string) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
func (c *conn) setState(nc net.Conn, state ConnState) {
|
||||
if hook := c.server.ConnState; hook != nil {
|
||||
hook(nc, state)
|
||||
}
|
||||
}
|
||||
|
||||
// Serve a new connection.
|
||||
func (c *conn) serve() {
|
||||
origConn := c.rwc // copy it before it's set nil on Close or Hijack
|
||||
c.setState(origConn, StateNew)
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
const size = 64 << 10
|
||||
|
|
@ -1090,6 +1098,7 @@ func (c *conn) serve() {
|
|||
}
|
||||
if !c.hijacked() {
|
||||
c.close()
|
||||
c.setState(origConn, StateClosed)
|
||||
}
|
||||
}()
|
||||
|
||||
|
|
@ -1116,6 +1125,10 @@ func (c *conn) serve() {
|
|||
|
||||
for {
|
||||
w, err := c.readRequest()
|
||||
// TODO(bradfitz): could push this StateActive
|
||||
// earlier, but in practice header will be all in one
|
||||
// packet/Read:
|
||||
c.setState(c.rwc, StateActive)
|
||||
if err != nil {
|
||||
if err == errTooLarge {
|
||||
// Their HTTP client may or may not be
|
||||
|
|
@ -1161,6 +1174,7 @@ func (c *conn) serve() {
|
|||
// in parallel even if their responses need to be serialized.
|
||||
serverHandler{c.server}.ServeHTTP(w, w.req)
|
||||
if c.hijacked() {
|
||||
c.setState(origConn, StateHijacked)
|
||||
return
|
||||
}
|
||||
w.finishRequest()
|
||||
|
|
@ -1170,6 +1184,7 @@ func (c *conn) serve() {
|
|||
}
|
||||
break
|
||||
}
|
||||
c.setState(c.rwc, StateIdle)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1580,6 +1595,58 @@ type Server struct {
|
|||
// and RemoteAddr if not already set. The connection is
|
||||
// automatically closed when the function returns.
|
||||
TLSNextProto map[string]func(*Server, *tls.Conn, Handler)
|
||||
|
||||
// ConnState specifies an optional callback function that is
|
||||
// called when a client connection changes state. See the
|
||||
// ConnState type and associated constants for details.
|
||||
ConnState func(net.Conn, ConnState)
|
||||
}
|
||||
|
||||
// A ConnState represents the state of a client connection to a server.
|
||||
// It's used by the optional Server.ConnState hook.
|
||||
type ConnState int
|
||||
|
||||
const (
|
||||
// StateNew represents a new connection that is expected to
|
||||
// send a request immediately. Connections begin at this
|
||||
// state and then transition to either StateActive or
|
||||
// StateClosed.
|
||||
StateNew ConnState = iota
|
||||
|
||||
// StateActive represents a connection that has read 1 or more
|
||||
// bytes of a request. The Server.ConnState hook for
|
||||
// StateActive fires before the request has entered a handler
|
||||
// and doesn't fire again until the request has been
|
||||
// handled. After the request is handled, the state
|
||||
// transitions to StateClosed, StateHijacked, or StateIdle.
|
||||
StateActive
|
||||
|
||||
// StateIdle represents a connection that has finished
|
||||
// handling a request and is in the keep-alive state, waiting
|
||||
// for a new request. Connections transition from StateIdle
|
||||
// to either StateActive or StateClosed.
|
||||
StateIdle
|
||||
|
||||
// StateHijacked represents a hijacked connection.
|
||||
// This is a terminal state. It does not transition to StateClosed.
|
||||
StateHijacked
|
||||
|
||||
// StateClosed represents a closed connection.
|
||||
// This is a terminal state. Hijacked connections do not
|
||||
// transition to StateClosed.
|
||||
StateClosed
|
||||
)
|
||||
|
||||
var stateName = map[ConnState]string{
|
||||
StateNew: "new",
|
||||
StateActive: "active",
|
||||
StateIdle: "idle",
|
||||
StateHijacked: "hijacked",
|
||||
StateClosed: "closed",
|
||||
}
|
||||
|
||||
func (c ConnState) String() string {
|
||||
return stateName[c]
|
||||
}
|
||||
|
||||
// serverHandler delegates to either the server's Handler or
|
||||
|
|
|
|||
Loading…
Reference in New Issue