diff --git a/src/net/http/csrf.go b/src/net/http/csrf.go index a46071f806..8812a508ae 100644 --- a/src/net/http/csrf.go +++ b/src/net/http/csrf.go @@ -26,12 +26,15 @@ import ( // Requests without Sec-Fetch-Site or Origin headers are currently assumed to be // either same-origin or non-browser requests, and are allowed. // +// The zero value of CrossOriginProtection is valid and has no trusted origins +// or bypass patterns. +// // [Sec-Fetch-Site]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-Fetch-Site // [Origin]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin // [Cross-Site Request Forgery (CSRF)]: https://developer.mozilla.org/en-US/docs/Web/Security/Attacks/CSRF // [safe methods]: https://developer.mozilla.org/en-US/docs/Glossary/Safe/HTTP type CrossOriginProtection struct { - bypass *ServeMux + bypass atomic.Pointer[ServeMux] trustedMu sync.RWMutex trusted map[string]bool deny atomic.Pointer[Handler] @@ -39,10 +42,7 @@ type CrossOriginProtection struct { // NewCrossOriginProtection returns a new [CrossOriginProtection] value. func NewCrossOriginProtection() *CrossOriginProtection { - return &CrossOriginProtection{ - bypass: NewServeMux(), - trusted: make(map[string]bool), - } + return &CrossOriginProtection{} } // AddTrustedOrigin allows all requests with an [Origin] header @@ -70,6 +70,9 @@ func (c *CrossOriginProtection) AddTrustedOrigin(origin string) error { } c.trustedMu.Lock() defer c.trustedMu.Unlock() + if c.trusted == nil { + c.trusted = make(map[string]bool) + } c.trusted[origin] = true return nil } @@ -82,7 +85,21 @@ var noopHandler = HandlerFunc(func(w ResponseWriter, r *Request) {}) // AddInsecureBypassPattern can be called concurrently with other methods // or request handling, and applies to future requests. func (c *CrossOriginProtection) AddInsecureBypassPattern(pattern string) { - c.bypass.Handle(pattern, noopHandler) + var bypass *ServeMux + + // Lazily initialize c.bypass + for { + bypass = c.bypass.Load() + if bypass != nil { + break + } + bypass = NewServeMux() + if c.bypass.CompareAndSwap(nil, bypass) { + break + } + } + + bypass.Handle(pattern, noopHandler) } // SetDenyHandler sets a handler to invoke when a request is rejected. @@ -149,9 +166,11 @@ func (c *CrossOriginProtection) Check(req *Request) error { // isRequestExempt checks the bypasses which require taking a lock, and should // be deferred until the last moment. func (c *CrossOriginProtection) isRequestExempt(req *Request) bool { - if _, pattern := c.bypass.Handler(req); pattern != "" { - // The request matches a bypass pattern. - return true + if bypass := c.bypass.Load(); bypass != nil { + if _, pattern := bypass.Handler(req); pattern != "" { + // The request matches a bypass pattern. + return true + } } c.trustedMu.RLock()