mirror of https://github.com/golang/go.git
net/http: add Server BaseContext & ConnContext fields to control early context
Fixes golang/go#30694 Change-Id: I12a0a870e4aee6576e879d88a4868666ef448298 Reviewed-on: https://go-review.googlesource.com/c/go/+/167681 Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org> Reviewed-by: JP Sugarbroad <jpsugar@google.com> Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
parent
dec5d99b71
commit
2c802e9980
|
|
@ -6034,6 +6034,43 @@ func TestStripPortFromHost(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServerContexts(t *testing.T) {
|
||||||
|
setParallel(t)
|
||||||
|
defer afterTest(t)
|
||||||
|
type baseKey struct{}
|
||||||
|
type connKey struct{}
|
||||||
|
ch := make(chan context.Context, 1)
|
||||||
|
ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
|
||||||
|
ch <- r.Context()
|
||||||
|
}))
|
||||||
|
ts.Config.BaseContext = func(ln net.Listener) context.Context {
|
||||||
|
if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") {
|
||||||
|
t.Errorf("unexpected onceClose listener type %T", ln)
|
||||||
|
}
|
||||||
|
return context.WithValue(context.Background(), baseKey{}, "base")
|
||||||
|
}
|
||||||
|
ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
|
||||||
|
if got, want := ctx.Value(baseKey{}), "base"; got != want {
|
||||||
|
t.Errorf("in ConnContext, base context key = %#v; want %q", got, want)
|
||||||
|
}
|
||||||
|
return context.WithValue(ctx, connKey{}, "conn")
|
||||||
|
}
|
||||||
|
ts.Start()
|
||||||
|
defer ts.Close()
|
||||||
|
res, err := ts.Client().Get(ts.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
res.Body.Close()
|
||||||
|
ctx := <-ch
|
||||||
|
if got, want := ctx.Value(baseKey{}), "base"; got != want {
|
||||||
|
t.Errorf("base context key = %#v; want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := ctx.Value(connKey{}), "conn"; got != want {
|
||||||
|
t.Errorf("conn context key = %#v; want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkResponseStatusLine(b *testing.B) {
|
func BenchmarkResponseStatusLine(b *testing.B) {
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
b.RunParallel(func(pb *testing.PB) {
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
|
|
||||||
|
|
@ -2542,6 +2542,20 @@ type Server struct {
|
||||||
// If nil, logging is done via the log package's standard logger.
|
// If nil, logging is done via the log package's standard logger.
|
||||||
ErrorLog *log.Logger
|
ErrorLog *log.Logger
|
||||||
|
|
||||||
|
// BaseContext optionally specifies a function that returns
|
||||||
|
// the base context for incoming requests on this server.
|
||||||
|
// The provided Listener is the specific Listener that's
|
||||||
|
// about to start accepting requests.
|
||||||
|
// If BaseContext is nil, the default is context.Background().
|
||||||
|
// If non-nil, it must return a non-nil context.
|
||||||
|
BaseContext func(net.Listener) context.Context
|
||||||
|
|
||||||
|
// ConnContext optionally specifies a function that modifies
|
||||||
|
// the context used for a newly connection c. The provided ctx
|
||||||
|
// is derived from the base context and has a ServerContextKey
|
||||||
|
// value.
|
||||||
|
ConnContext func(ctx context.Context, c net.Conn) context.Context
|
||||||
|
|
||||||
disableKeepAlives int32 // accessed atomically.
|
disableKeepAlives int32 // accessed atomically.
|
||||||
inShutdown int32 // accessed atomically (non-zero means we're in Shutdown)
|
inShutdown int32 // accessed atomically (non-zero means we're in Shutdown)
|
||||||
nextProtoOnce sync.Once // guards setupHTTP2_* init
|
nextProtoOnce sync.Once // guards setupHTTP2_* init
|
||||||
|
|
@ -2838,6 +2852,7 @@ func (srv *Server) Serve(l net.Listener) error {
|
||||||
fn(srv, l) // call hook with unwrapped listener
|
fn(srv, l) // call hook with unwrapped listener
|
||||||
}
|
}
|
||||||
|
|
||||||
|
origListener := l
|
||||||
l = &onceCloseListener{Listener: l}
|
l = &onceCloseListener{Listener: l}
|
||||||
defer l.Close()
|
defer l.Close()
|
||||||
|
|
||||||
|
|
@ -2850,8 +2865,16 @@ func (srv *Server) Serve(l net.Listener) error {
|
||||||
}
|
}
|
||||||
defer srv.trackListener(&l, false)
|
defer srv.trackListener(&l, false)
|
||||||
|
|
||||||
var tempDelay time.Duration // how long to sleep on accept failure
|
var tempDelay time.Duration // how long to sleep on accept failure
|
||||||
baseCtx := context.Background() // base is always background, per Issue 16220
|
|
||||||
|
baseCtx := context.Background()
|
||||||
|
if srv.BaseContext != nil {
|
||||||
|
baseCtx = srv.BaseContext(origListener)
|
||||||
|
if baseCtx == nil {
|
||||||
|
panic("BaseContext returned a nil context")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ctx := context.WithValue(baseCtx, ServerContextKey, srv)
|
ctx := context.WithValue(baseCtx, ServerContextKey, srv)
|
||||||
for {
|
for {
|
||||||
rw, e := l.Accept()
|
rw, e := l.Accept()
|
||||||
|
|
@ -2876,6 +2899,12 @@ func (srv *Server) Serve(l net.Listener) error {
|
||||||
}
|
}
|
||||||
return e
|
return e
|
||||||
}
|
}
|
||||||
|
if cc := srv.ConnContext; cc != nil {
|
||||||
|
ctx = cc(ctx, rw)
|
||||||
|
if ctx == nil {
|
||||||
|
panic("ConnContext returned nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
tempDelay = 0
|
tempDelay = 0
|
||||||
c := srv.newConn(rw)
|
c := srv.newConn(rw)
|
||||||
c.setState(c.rwc, StateNew) // before Serve can return
|
c.setState(c.rwc, StateNew) // before Serve can return
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue