This commit is contained in:
Weidi Deng 2022-11-02 09:04:56 +08:00
parent 81efd7b347
commit 62d7a8064e
2 changed files with 59 additions and 0 deletions

View File

@ -1433,6 +1433,53 @@ func testTLSServer(t *testing.T, mode testMode) {
}
}
type fakeConnectionStateConn struct {
net.Conn
}
func (fcsc *fakeConnectionStateConn) ConnectionState() tls.ConnectionState {
return tls.ConnectionState{
ServerName: "example.com",
}
}
func TestTLSServerWithoutTLSConn(t *testing.T) {
//set up
pr, pw := net.Pipe()
c := make(chan int)
listener := &oneConnListener{&fakeConnectionStateConn{pr}}
server := &Server{
Handler: HandlerFunc(func(writer ResponseWriter, request *Request) {
if request.TLS == nil {
t.Fatal("request.TLS is nil, expected not nil")
}
if request.TLS.ServerName != "example.com" {
t.Fatalf("request.TLS.ServerName is %s, expected %s", request.TLS.ServerName, "example.com")
}
writer.Header().Set("X-TLS-ServerName", "example.com")
}),
}
// write request and read response
go func() {
req, _ := NewRequest(MethodGet, "https://example.com", nil)
req.Write(pw)
resp, _ := ReadResponse(bufio.NewReader(pw), req)
if hdr := resp.Header.Get("X-TLS-ServerName"); hdr != "example.com" {
t.Errorf("response header X-TLS-ServerName is %s, expected %s", hdr, "example.com")
}
close(c)
pw.Close()
}()
server.Serve(listener)
// oneConnListener returns error after one accept, wait util response is read
<-c
pr.Close()
}
func TestServeTLS(t *testing.T) {
CondSkipHTTP2(t)
// Not parallel: uses global test hooks.

View File

@ -1826,6 +1826,10 @@ func isCommonNetReadError(err error) bool {
return false
}
type connectionStater interface {
ConnectionState() tls.ConnectionState
}
// Serve a new connection.
func (c *conn) serve(ctx context.Context) {
c.remoteAddr = c.rwc.RemoteAddr().String()
@ -1892,6 +1896,14 @@ func (c *conn) serve(ctx context.Context) {
// HTTP/1.x from here on.
// Set Request.TLS if the conn is not a *tls.Conn, but implements ConnectionState.
if c.tlsState == nil {
if tc, ok := c.rwc.(connectionStater); ok {
c.tlsState = new(tls.ConnectionState)
*c.tlsState = tc.ConnectionState()
}
}
ctx, cancelCtx := context.WithCancel(ctx)
c.cancelCtx = cancelCtx
defer cancelCtx()