mirror of https://github.com/golang/go.git
net/http: add CrossOriginProtection
Fixes #73626 Change-Id: I6a6a4656862e7a38acb65c4815fb7a1e04896172 Reviewed-on: https://go-review.googlesource.com/c/go/+/674936 Reviewed-by: Damien Neil <dneil@google.com> Auto-Submit: Filippo Valsorda <filippo@golang.org> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: David Chase <drchase@google.com>
This commit is contained in:
parent
ef20ccc10b
commit
1881d680b0
|
|
@ -0,0 +1,7 @@
|
||||||
|
pkg net/http, func NewCrossOriginProtection() *CrossOriginProtection #73626
|
||||||
|
pkg net/http, method (*CrossOriginProtection) AddInsecureBypassPattern(string) #73626
|
||||||
|
pkg net/http, method (*CrossOriginProtection) AddTrustedOrigin(string) error #73626
|
||||||
|
pkg net/http, method (*CrossOriginProtection) Check(*Request) error #73626
|
||||||
|
pkg net/http, method (*CrossOriginProtection) Handler(Handler) Handler #73626
|
||||||
|
pkg net/http, method (*CrossOriginProtection) SetDenyHandler(Handler) #73626
|
||||||
|
pkg net/http, type CrossOriginProtection struct #73626
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
The new [CrossOriginProtection] implements protections against [Cross-Site
|
||||||
|
Request Forgery (CSRF)][] by rejecting non-safe cross-origin browser requests.
|
||||||
|
It uses [modern browser Fetch metadata][Sec-Fetch-Site], doesn't require tokens
|
||||||
|
or cookies, and supports origin-based and pattern-based bypasses.
|
||||||
|
|
||||||
|
[Sec-Fetch-Site]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-Fetch-Site
|
||||||
|
[Cross-Site Request Forgery (CSRF)]: https://developer.mozilla.org/en-US/docs/Web/Security/Attacks/CSRF
|
||||||
|
|
@ -0,0 +1,182 @@
|
||||||
|
// Copyright 2025 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package http
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CrossOriginProtection implements protections against [Cross-Site Request
|
||||||
|
// Forgery (CSRF)] by rejecting non-safe cross-origin browser requests.
|
||||||
|
//
|
||||||
|
// Cross-origin requests are currently detected with the [Sec-Fetch-Site]
|
||||||
|
// header, available in all browsers since 2023, or by comparing the hostname of
|
||||||
|
// the [Origin] header with the Host header.
|
||||||
|
//
|
||||||
|
// The GET, HEAD, and OPTIONS methods are [safe methods] and are always allowed.
|
||||||
|
// It's important that applications do not perform any state changing actions
|
||||||
|
// due to requests with safe methods.
|
||||||
|
//
|
||||||
|
// Requests without Sec-Fetch-Site or Origin headers are currently assumed to be
|
||||||
|
// either same-origin or non-browser requests, and are allowed.
|
||||||
|
//
|
||||||
|
// [Sec-Fetch-Site]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-Fetch-Site
|
||||||
|
// [Origin]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin
|
||||||
|
// [Cross-Site Request Forgery (CSRF)]: https://developer.mozilla.org/en-US/docs/Web/Security/Attacks/CSRF
|
||||||
|
// [safe methods]: https://developer.mozilla.org/en-US/docs/Glossary/Safe/HTTP
|
||||||
|
type CrossOriginProtection struct {
|
||||||
|
bypass *ServeMux
|
||||||
|
trustedMu sync.RWMutex
|
||||||
|
trusted map[string]bool
|
||||||
|
deny atomic.Pointer[Handler]
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCrossOriginProtection returns a new [CrossOriginProtection] value.
|
||||||
|
func NewCrossOriginProtection() *CrossOriginProtection {
|
||||||
|
return &CrossOriginProtection{
|
||||||
|
bypass: NewServeMux(),
|
||||||
|
trusted: make(map[string]bool),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddTrustedOrigin allows all requests with an [Origin] header
|
||||||
|
// which exactly matches the given value.
|
||||||
|
//
|
||||||
|
// Origin header values are of the form "scheme://host[:port]".
|
||||||
|
//
|
||||||
|
// AddTrustedOrigin can be called concurrently with other methods
|
||||||
|
// or request handling, and applies to future requests.
|
||||||
|
//
|
||||||
|
// [Origin]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin
|
||||||
|
func (c *CrossOriginProtection) AddTrustedOrigin(origin string) error {
|
||||||
|
u, err := url.Parse(origin)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid origin %q: %w", origin, err)
|
||||||
|
}
|
||||||
|
if u.Scheme == "" {
|
||||||
|
return fmt.Errorf("invalid origin %q: scheme is required", origin)
|
||||||
|
}
|
||||||
|
if u.Host == "" {
|
||||||
|
return fmt.Errorf("invalid origin %q: host is required", origin)
|
||||||
|
}
|
||||||
|
if u.Path != "" || u.RawQuery != "" || u.Fragment != "" {
|
||||||
|
return fmt.Errorf("invalid origin %q: path, query, and fragment are not allowed", origin)
|
||||||
|
}
|
||||||
|
c.trustedMu.Lock()
|
||||||
|
defer c.trustedMu.Unlock()
|
||||||
|
c.trusted[origin] = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var noopHandler = HandlerFunc(func(w ResponseWriter, r *Request) {})
|
||||||
|
|
||||||
|
// AddInsecureBypassPattern permits all requests that match the given pattern.
|
||||||
|
// The pattern syntax and precedence rules are the same as [ServeMux].
|
||||||
|
//
|
||||||
|
// AddInsecureBypassPattern can be called concurrently with other methods
|
||||||
|
// or request handling, and applies to future requests.
|
||||||
|
func (c *CrossOriginProtection) AddInsecureBypassPattern(pattern string) {
|
||||||
|
c.bypass.Handle(pattern, noopHandler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDenyHandler sets a handler to invoke when a request is rejected.
|
||||||
|
// The default error handler responds with a 403 Forbidden status.
|
||||||
|
//
|
||||||
|
// SetDenyHandler can be called concurrently with other methods
|
||||||
|
// or request handling, and applies to future requests.
|
||||||
|
//
|
||||||
|
// Check does not call the error handler.
|
||||||
|
func (c *CrossOriginProtection) SetDenyHandler(h Handler) {
|
||||||
|
if h == nil {
|
||||||
|
c.deny.Store(nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.deny.Store(&h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check applies cross-origin checks to a request.
|
||||||
|
// It returns an error if the request should be rejected.
|
||||||
|
func (c *CrossOriginProtection) Check(req *Request) error {
|
||||||
|
switch req.Method {
|
||||||
|
case "GET", "HEAD", "OPTIONS":
|
||||||
|
// Safe methods are always allowed.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch req.Header.Get("Sec-Fetch-Site") {
|
||||||
|
case "":
|
||||||
|
// No Sec-Fetch-Site header is present.
|
||||||
|
// Fallthrough to check the Origin header.
|
||||||
|
case "same-origin", "none":
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
if c.isRequestExempt(req) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return errors.New("cross-origin request detected from Sec-Fetch-Site header")
|
||||||
|
}
|
||||||
|
|
||||||
|
origin := req.Header.Get("Origin")
|
||||||
|
if origin == "" {
|
||||||
|
// Neither Sec-Fetch-Site nor Origin headers are present.
|
||||||
|
// Either the request is same-origin or not a browser request.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if o, err := url.Parse(origin); err == nil && o.Host == req.Host {
|
||||||
|
// The Origin header matches the Host header. Note that the Host header
|
||||||
|
// doesn't include the scheme, so we don't know if this might be an
|
||||||
|
// HTTP→HTTPS cross-origin request. We fail open, since all modern
|
||||||
|
// browsers support Sec-Fetch-Site since 2023, and running an older
|
||||||
|
// browser makes a clear security trade-off already. Sites can mitigate
|
||||||
|
// this with HTTP Strict Transport Security (HSTS).
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.isRequestExempt(req) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return errors.New("cross-origin request detected, and/or browser is out of date: " +
|
||||||
|
"Sec-Fetch-Site is missing, and Origin does not match Host")
|
||||||
|
}
|
||||||
|
|
||||||
|
// isRequestExempt checks the bypasses which require taking a lock, and should
|
||||||
|
// be deferred until the last moment.
|
||||||
|
func (c *CrossOriginProtection) isRequestExempt(req *Request) bool {
|
||||||
|
if _, pattern := c.bypass.Handler(req); pattern != "" {
|
||||||
|
// The request matches a bypass pattern.
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
c.trustedMu.RLock()
|
||||||
|
defer c.trustedMu.RUnlock()
|
||||||
|
origin := req.Header.Get("Origin")
|
||||||
|
// The request matches a trusted origin.
|
||||||
|
return origin != "" && c.trusted[origin]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler returns a handler that applies cross-origin checks
|
||||||
|
// before invoking the handler h.
|
||||||
|
//
|
||||||
|
// If a request fails cross-origin checks, the request is rejected
|
||||||
|
// with a 403 Forbidden status or handled with the handler passed
|
||||||
|
// to [CrossOriginProtection.SetDenyHandler].
|
||||||
|
func (c *CrossOriginProtection) Handler(h Handler) Handler {
|
||||||
|
return HandlerFunc(func(w ResponseWriter, r *Request) {
|
||||||
|
if err := c.Check(r); err != nil {
|
||||||
|
if deny := c.deny.Load(); deny != nil {
|
||||||
|
(*deny).ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
Error(w, err.Error(), StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,330 @@
|
||||||
|
// Copyright 2025 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package http_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// httptestNewRequest works around https://go.dev/issue/73151.
|
||||||
|
func httptestNewRequest(method, target string) *http.Request {
|
||||||
|
req := httptest.NewRequest(method, target, nil)
|
||||||
|
req.URL.Scheme = ""
|
||||||
|
req.URL.Host = ""
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
var okHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
func TestCrossOriginProtectionSecFetchSite(t *testing.T) {
|
||||||
|
protection := http.NewCrossOriginProtection()
|
||||||
|
handler := protection.Handler(okHandler)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
method string
|
||||||
|
secFetchSite string
|
||||||
|
origin string
|
||||||
|
expectedStatus int
|
||||||
|
}{
|
||||||
|
{"same-origin allowed", "POST", "same-origin", "", http.StatusOK},
|
||||||
|
{"none allowed", "POST", "none", "", http.StatusOK},
|
||||||
|
{"cross-site blocked", "POST", "cross-site", "", http.StatusForbidden},
|
||||||
|
{"same-site blocked", "POST", "same-site", "", http.StatusForbidden},
|
||||||
|
|
||||||
|
{"no header with no origin", "POST", "", "", http.StatusOK},
|
||||||
|
{"no header with matching origin", "POST", "", "https://example.com", http.StatusOK},
|
||||||
|
{"no header with mismatched origin", "POST", "", "https://attacker.example", http.StatusForbidden},
|
||||||
|
{"no header with null origin", "POST", "", "null", http.StatusForbidden},
|
||||||
|
|
||||||
|
{"GET allowed", "GET", "cross-site", "", http.StatusOK},
|
||||||
|
{"HEAD allowed", "HEAD", "cross-site", "", http.StatusOK},
|
||||||
|
{"OPTIONS allowed", "OPTIONS", "cross-site", "", http.StatusOK},
|
||||||
|
{"PUT blocked", "PUT", "cross-site", "", http.StatusForbidden},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
req := httptestNewRequest(tc.method, "https://example.com/")
|
||||||
|
if tc.secFetchSite != "" {
|
||||||
|
req.Header.Set("Sec-Fetch-Site", tc.secFetchSite)
|
||||||
|
}
|
||||||
|
if tc.origin != "" {
|
||||||
|
req.Header.Set("Origin", tc.origin)
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != tc.expectedStatus {
|
||||||
|
t.Errorf("got status %d, want %d", w.Code, tc.expectedStatus)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCrossOriginProtectionTrustedOriginBypass(t *testing.T) {
|
||||||
|
protection := http.NewCrossOriginProtection()
|
||||||
|
err := protection.AddTrustedOrigin("https://trusted.example")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AddTrustedOrigin: %v", err)
|
||||||
|
}
|
||||||
|
handler := protection.Handler(okHandler)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
origin string
|
||||||
|
secFetchSite string
|
||||||
|
expectedStatus int
|
||||||
|
}{
|
||||||
|
{"trusted origin without sec-fetch-site", "https://trusted.example", "", http.StatusOK},
|
||||||
|
{"trusted origin with cross-site", "https://trusted.example", "cross-site", http.StatusOK},
|
||||||
|
{"untrusted origin without sec-fetch-site", "https://attacker.example", "", http.StatusForbidden},
|
||||||
|
{"untrusted origin with cross-site", "https://attacker.example", "cross-site", http.StatusForbidden},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
req := httptestNewRequest("POST", "https://example.com/")
|
||||||
|
req.Header.Set("Origin", tc.origin)
|
||||||
|
if tc.secFetchSite != "" {
|
||||||
|
req.Header.Set("Sec-Fetch-Site", tc.secFetchSite)
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != tc.expectedStatus {
|
||||||
|
t.Errorf("got status %d, want %d", w.Code, tc.expectedStatus)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCrossOriginProtectionPatternBypass(t *testing.T) {
|
||||||
|
protection := http.NewCrossOriginProtection()
|
||||||
|
protection.AddInsecureBypassPattern("/bypass/")
|
||||||
|
protection.AddInsecureBypassPattern("/only/{foo}")
|
||||||
|
handler := protection.Handler(okHandler)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
path string
|
||||||
|
secFetchSite string
|
||||||
|
expectedStatus int
|
||||||
|
}{
|
||||||
|
{"bypass path without sec-fetch-site", "/bypass/", "", http.StatusOK},
|
||||||
|
{"bypass path with cross-site", "/bypass/", "cross-site", http.StatusOK},
|
||||||
|
{"non-bypass path without sec-fetch-site", "/api/", "", http.StatusForbidden},
|
||||||
|
{"non-bypass path with cross-site", "/api/", "cross-site", http.StatusForbidden},
|
||||||
|
|
||||||
|
{"redirect to bypass path without ..", "/foo/../bypass/bar", "", http.StatusOK},
|
||||||
|
{"redirect to bypass path with trailing slash", "/bypass", "", http.StatusOK},
|
||||||
|
{"redirect to non-bypass path with ..", "/foo/../api/bar", "", http.StatusForbidden},
|
||||||
|
{"redirect to non-bypass path with trailing slash", "/api", "", http.StatusForbidden},
|
||||||
|
|
||||||
|
{"wildcard bypass", "/only/123", "", http.StatusOK},
|
||||||
|
{"non-wildcard", "/only/123/foo", "", http.StatusForbidden},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
req := httptestNewRequest("POST", "https://example.com"+tc.path)
|
||||||
|
req.Header.Set("Origin", "https://attacker.example")
|
||||||
|
if tc.secFetchSite != "" {
|
||||||
|
req.Header.Set("Sec-Fetch-Site", tc.secFetchSite)
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != tc.expectedStatus {
|
||||||
|
t.Errorf("got status %d, want %d", w.Code, tc.expectedStatus)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCrossOriginProtectionSetDenyHandler(t *testing.T) {
|
||||||
|
protection := http.NewCrossOriginProtection()
|
||||||
|
|
||||||
|
handler := protection.Handler(okHandler)
|
||||||
|
|
||||||
|
req := httptestNewRequest("POST", "https://example.com/")
|
||||||
|
req.Header.Set("Sec-Fetch-Site", "cross-site")
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusForbidden {
|
||||||
|
t.Errorf("got status %d, want %d", w.Code, http.StatusForbidden)
|
||||||
|
}
|
||||||
|
|
||||||
|
customErrHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusTeapot)
|
||||||
|
io.WriteString(w, "custom error")
|
||||||
|
})
|
||||||
|
protection.SetDenyHandler(customErrHandler)
|
||||||
|
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusTeapot {
|
||||||
|
t.Errorf("got status %d, want %d", w.Code, http.StatusTeapot)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(w.Body.String(), "custom error") {
|
||||||
|
t.Errorf("expected custom error message, got: %q", w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
req = httptestNewRequest("GET", "https://example.com/")
|
||||||
|
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("got status %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
protection.SetDenyHandler(nil)
|
||||||
|
|
||||||
|
req = httptestNewRequest("POST", "https://example.com/")
|
||||||
|
req.Header.Set("Sec-Fetch-Site", "cross-site")
|
||||||
|
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusForbidden {
|
||||||
|
t.Errorf("got status %d, want %d", w.Code, http.StatusForbidden)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCrossOriginProtectionAddTrustedOriginErrors(t *testing.T) {
|
||||||
|
protection := http.NewCrossOriginProtection()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
origin string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"valid origin", "https://example.com", false},
|
||||||
|
{"valid origin with port", "https://example.com:8080", false},
|
||||||
|
{"http origin", "http://example.com", false},
|
||||||
|
{"missing scheme", "example.com", true},
|
||||||
|
{"missing host", "https://", true},
|
||||||
|
{"trailing slash", "https://example.com/", true},
|
||||||
|
{"with path", "https://example.com/path", true},
|
||||||
|
{"with query", "https://example.com?query=value", true},
|
||||||
|
{"with fragment", "https://example.com#fragment", true},
|
||||||
|
{"invalid url", "https://ex ample.com", true},
|
||||||
|
{"empty string", "", true},
|
||||||
|
{"null", "null", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
err := protection.AddTrustedOrigin(tc.origin)
|
||||||
|
if (err != nil) != tc.wantErr {
|
||||||
|
t.Errorf("AddTrustedOrigin(%q) error = %v, wantErr %v", tc.origin, err, tc.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCrossOriginProtectionAddingBypassesConcurrently(t *testing.T) {
|
||||||
|
protection := http.NewCrossOriginProtection()
|
||||||
|
handler := protection.Handler(okHandler)
|
||||||
|
|
||||||
|
req := httptestNewRequest("POST", "https://example.com/")
|
||||||
|
req.Header.Set("Origin", "https://concurrent.example")
|
||||||
|
req.Header.Set("Sec-Fetch-Site", "cross-site")
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusForbidden {
|
||||||
|
t.Errorf("got status %d, want %d", w.Code, http.StatusForbidden)
|
||||||
|
}
|
||||||
|
|
||||||
|
start := make(chan struct{})
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
close(start)
|
||||||
|
defer close(done)
|
||||||
|
for range 10 {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Add bypasses while the requests are in flight.
|
||||||
|
<-start
|
||||||
|
protection.AddTrustedOrigin("https://concurrent.example")
|
||||||
|
protection.AddInsecureBypassPattern("/foo/")
|
||||||
|
<-done
|
||||||
|
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("After concurrent bypass addition, got status %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCrossOriginProtectionServer(t *testing.T) {
|
||||||
|
protection := http.NewCrossOriginProtection()
|
||||||
|
protection.AddTrustedOrigin("https://trusted.example")
|
||||||
|
protection.AddInsecureBypassPattern("/bypass/")
|
||||||
|
handler := protection.Handler(okHandler)
|
||||||
|
|
||||||
|
ts := httptest.NewServer(handler)
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
method string
|
||||||
|
url string
|
||||||
|
origin string
|
||||||
|
secFetchSite string
|
||||||
|
expectedStatus int
|
||||||
|
}{
|
||||||
|
{"cross-site", "POST", ts.URL, "https://attacker.example", "cross-site", http.StatusForbidden},
|
||||||
|
{"same-origin", "POST", ts.URL, "", "same-origin", http.StatusOK},
|
||||||
|
{"origin matches host", "POST", ts.URL, ts.URL, "", http.StatusOK},
|
||||||
|
{"trusted origin", "POST", ts.URL, "https://trusted.example", "", http.StatusOK},
|
||||||
|
{"untrusted origin", "POST", ts.URL, "https://attacker.example", "", http.StatusForbidden},
|
||||||
|
{"bypass path", "POST", ts.URL + "/bypass/", "https://attacker.example", "", http.StatusOK},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
req, err := http.NewRequest(tc.method, tc.url, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRequest: %v", err)
|
||||||
|
}
|
||||||
|
if tc.origin != "" {
|
||||||
|
req.Header.Set("Origin", tc.origin)
|
||||||
|
}
|
||||||
|
if tc.secFetchSite != "" {
|
||||||
|
req.Header.Set("Sec-Fetch-Site", tc.secFetchSite)
|
||||||
|
}
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != tc.expectedStatus {
|
||||||
|
t.Errorf("got status %d, want %d", resp.StatusCode, tc.expectedStatus)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue