Skip to content

Commit 4c32059

Browse files
authored
Normalize allowed request headers and store them in a sorted set (fixes #170) (#171)
1 parent 8d33ca4 commit 4c32059

File tree

7 files changed

+295
-192
lines changed

7 files changed

+295
-192
lines changed

bench_test.go

+17-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package cors
22

33
import (
44
"net/http"
5+
"strings"
56
"testing"
67
)
78

@@ -87,7 +88,22 @@ func BenchmarkPreflightHeader(b *testing.B) {
8788
req, _ := http.NewRequest(http.MethodOptions, dummyEndpoint, nil)
8889
req.Header.Add(headerOrigin, dummyOrigin)
8990
req.Header.Add(headerACRM, http.MethodGet)
90-
req.Header.Add(headerACRH, "Accept")
91+
req.Header.Add(headerACRH, "accept")
92+
handler := Default().Handler(testHandler)
93+
94+
b.ReportAllocs()
95+
b.ResetTimer()
96+
for i := 0; i < b.N; i++ {
97+
handler.ServeHTTP(resps[i], req)
98+
}
99+
}
100+
101+
func BenchmarkPreflightAdversarialACRH(b *testing.B) {
102+
resps := makeFakeResponses(b.N)
103+
req, _ := http.NewRequest(http.MethodOptions, dummyEndpoint, nil)
104+
req.Header.Add(headerOrigin, dummyOrigin)
105+
req.Header.Add(headerACRM, http.MethodGet)
106+
req.Header.Add(headerACRH, strings.Repeat(",", 1024))
91107
handler := Default().Handler(testHandler)
92108

93109
b.ReportAllocs()

cors.go

+22-35
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ import (
2626
"os"
2727
"strconv"
2828
"strings"
29+
30+
"github.com/rs/cors/internal"
2931
)
3032

3133
var headerVaryOrigin = []string{"Origin"}
@@ -111,7 +113,11 @@ type Cors struct {
111113
// Optional origin validator function
112114
allowOriginFunc func(r *http.Request, origin string) (bool, []string)
113115
// Normalized list of allowed headers
114-
allowedHeaders []string
116+
// Note: the Fetch standard guarantees that CORS-unsafe request-header names
117+
// (i.e. the values listed in the Access-Control-Request-Headers header)
118+
// are unique and sorted;
119+
// see https://fetch.spec.whatwg.org/#cors-unsafe-request-header-names.
120+
allowedHeaders internal.SortedSet
115121
// Normalized list of allowed methods
116122
allowedMethods []string
117123
// Pre-computed normalized list of exposed headers
@@ -183,15 +189,19 @@ func New(options Options) *Cors {
183189
}
184190

185191
// Allowed Headers
192+
// Note: the Fetch standard guarantees that CORS-unsafe request-header names
193+
// (i.e. the values listed in the Access-Control-Request-Headers header)
194+
// are lowercase; see https://fetch.spec.whatwg.org/#cors-unsafe-request-header-names.
186195
if len(options.AllowedHeaders) == 0 {
187196
// Use sensible defaults
188-
c.allowedHeaders = []string{"Accept", "Content-Type", "X-Requested-With"}
197+
c.allowedHeaders = internal.NewSortedSet("accept", "content-type", "x-requested-with")
189198
} else {
190-
c.allowedHeaders = convert(options.AllowedHeaders, http.CanonicalHeaderKey)
199+
normalized := convert(options.AllowedHeaders, strings.ToLower)
200+
c.allowedHeaders = internal.NewSortedSet(normalized...)
191201
for _, h := range options.AllowedHeaders {
192202
if h == "*" {
193203
c.allowedHeadersAll = true
194-
c.allowedHeaders = nil
204+
c.allowedHeaders = internal.SortedSet{}
195205
break
196206
}
197207
}
@@ -351,10 +361,12 @@ func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) {
351361
c.logf(" Preflight aborted: method '%s' not allowed", reqMethod)
352362
return
353363
}
354-
reqHeadersRaw := r.Header["Access-Control-Request-Headers"]
355-
reqHeaders, reqHeadersEdited := convertDidCopy(splitHeaderValues(reqHeadersRaw), http.CanonicalHeaderKey)
356-
if !c.areHeadersAllowed(reqHeaders) {
357-
c.logf(" Preflight aborted: headers '%v' not allowed", reqHeaders)
364+
// Note: the Fetch standard guarantees that at most one
365+
// Access-Control-Request-Headers header is present in the preflight request;
366+
// see step 5.2 in https://fetch.spec.whatwg.org/#cors-preflight-fetch-0.
367+
reqHeaders, found := first(r.Header, "Access-Control-Request-Headers")
368+
if found && !c.allowedHeadersAll && !c.allowedHeaders.Subsumes(reqHeaders[0]) {
369+
c.logf(" Preflight aborted: headers '%v' not allowed", reqHeaders[0])
358370
return
359371
}
360372
if c.allowedOriginsAll {
@@ -365,14 +377,10 @@ func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) {
365377
// Spec says: Since the list of methods can be unbounded, simply returning the method indicated
366378
// by Access-Control-Request-Method (if supported) can be enough
367379
headers["Access-Control-Allow-Methods"] = r.Header["Access-Control-Request-Method"]
368-
if len(reqHeaders) > 0 {
380+
if found && len(reqHeaders[0]) > 0 {
369381
// Spec says: Since the list of headers can be unbounded, simply returning supported headers
370382
// from Access-Control-Request-Headers can be enough
371-
if reqHeadersEdited || len(reqHeaders) != len(reqHeadersRaw) {
372-
headers.Set("Access-Control-Allow-Headers", strings.Join(reqHeaders, ", "))
373-
} else {
374-
headers["Access-Control-Allow-Headers"] = reqHeadersRaw
375-
}
383+
headers["Access-Control-Allow-Headers"] = reqHeaders
376384
}
377385
if c.allowCredentials {
378386
headers["Access-Control-Allow-Credentials"] = headerTrue
@@ -492,24 +500,3 @@ func (c *Cors) isMethodAllowed(method string) bool {
492500
}
493501
return false
494502
}
495-
496-
// areHeadersAllowed checks if a given list of headers are allowed to used within
497-
// a cross-domain request.
498-
func (c *Cors) areHeadersAllowed(requestedHeaders []string) bool {
499-
if c.allowedHeadersAll || len(requestedHeaders) == 0 {
500-
return true
501-
}
502-
for _, header := range requestedHeaders {
503-
found := false
504-
for _, h := range c.allowedHeaders {
505-
if h == header {
506-
found = true
507-
break
508-
}
509-
}
510-
if !found {
511-
return false
512-
}
513-
}
514-
return true
515-
}

cors_test.go

+10-68
Original file line numberDiff line numberDiff line change
@@ -303,19 +303,19 @@ func TestSpec(t *testing.T) {
303303
"AllowedHeaders",
304304
Options{
305305
AllowedOrigins: []string{"http://foobar.com"},
306-
AllowedHeaders: []string{"X-Header-1", "x-header-2"},
306+
AllowedHeaders: []string{"X-Header-1", "x-header-2", "X-HEADER-3"},
307307
},
308308
"OPTIONS",
309309
map[string]string{
310310
"Origin": "http://foobar.com",
311311
"Access-Control-Request-Method": "GET",
312-
"Access-Control-Request-Headers": "X-Header-2, X-HEADER-1",
312+
"Access-Control-Request-Headers": "x-header-1,x-header-2",
313313
},
314314
map[string]string{
315315
"Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers",
316316
"Access-Control-Allow-Origin": "http://foobar.com",
317317
"Access-Control-Allow-Methods": "GET",
318-
"Access-Control-Allow-Headers": "X-Header-2, X-Header-1",
318+
"Access-Control-Allow-Headers": "x-header-1,x-header-2",
319319
},
320320
true,
321321
},
@@ -329,13 +329,13 @@ func TestSpec(t *testing.T) {
329329
map[string]string{
330330
"Origin": "http://foobar.com",
331331
"Access-Control-Request-Method": "GET",
332-
"Access-Control-Request-Headers": "X-Requested-With",
332+
"Access-Control-Request-Headers": "x-requested-with",
333333
},
334334
map[string]string{
335335
"Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers",
336336
"Access-Control-Allow-Origin": "http://foobar.com",
337337
"Access-Control-Allow-Methods": "GET",
338-
"Access-Control-Allow-Headers": "X-Requested-With",
338+
"Access-Control-Allow-Headers": "x-requested-with",
339339
},
340340
true,
341341
},
@@ -349,13 +349,13 @@ func TestSpec(t *testing.T) {
349349
map[string]string{
350350
"Origin": "http://foobar.com",
351351
"Access-Control-Request-Method": "GET",
352-
"Access-Control-Request-Headers": "X-Header-2, X-HEADER-1",
352+
"Access-Control-Request-Headers": "x-header-1,x-header-2",
353353
},
354354
map[string]string{
355355
"Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers",
356356
"Access-Control-Allow-Origin": "http://foobar.com",
357357
"Access-Control-Allow-Methods": "GET",
358-
"Access-Control-Allow-Headers": "X-Header-2, X-Header-1",
358+
"Access-Control-Allow-Headers": "x-header-1,x-header-2",
359359
},
360360
true,
361361
},
@@ -369,7 +369,7 @@ func TestSpec(t *testing.T) {
369369
map[string]string{
370370
"Origin": "http://foobar.com",
371371
"Access-Control-Request-Method": "GET",
372-
"Access-Control-Request-Headers": "X-Header-3, X-Header-1",
372+
"Access-Control-Request-Headers": "x-header-1,x-header-3",
373373
},
374374
map[string]string{
375375
"Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers",
@@ -577,8 +577,8 @@ func TestDefault(t *testing.T) {
577577
if !s.allowedOriginsAll {
578578
t.Error("c.allowedOriginsAll should be true when Default")
579579
}
580-
if s.allowedHeaders == nil {
581-
t.Error("c.allowedHeaders should be nil when Default")
580+
if s.allowedHeaders.Size() == 0 {
581+
t.Error("c.allowedHeaders should be empty when Default")
582582
}
583583
if s.allowedMethods == nil {
584584
t.Error("c.allowedMethods should be nil when Default")
@@ -712,64 +712,6 @@ func TestOptionsSuccessStatusCodeOverride(t *testing.T) {
712712
})
713713
}
714714

715-
func TestCorsAreHeadersAllowed(t *testing.T) {
716-
cases := []struct {
717-
name string
718-
allowedHeaders []string
719-
requestedHeaders []string
720-
want bool
721-
}{
722-
{
723-
name: "nil allowedHeaders",
724-
allowedHeaders: nil,
725-
requestedHeaders: []string{"X-PINGOTHER, Content-Type"},
726-
want: false,
727-
},
728-
{
729-
name: "star allowedHeaders",
730-
allowedHeaders: []string{"*"},
731-
requestedHeaders: []string{"X-PINGOTHER, Content-Type"},
732-
want: true,
733-
},
734-
{
735-
name: "empty reqHeader",
736-
allowedHeaders: nil,
737-
requestedHeaders: []string{},
738-
want: true,
739-
},
740-
{
741-
name: "match allowedHeaders",
742-
allowedHeaders: []string{"Content-Type", "X-PINGOTHER", "X-APP-KEY"},
743-
requestedHeaders: []string{"X-PINGOTHER, Content-Type"},
744-
want: true,
745-
},
746-
{
747-
name: "not matched allowedHeaders",
748-
allowedHeaders: []string{"X-PINGOTHER"},
749-
requestedHeaders: []string{"X-API-KEY, Content-Type"},
750-
want: false,
751-
},
752-
{
753-
name: "allowedHeaders should be a superset of requestedHeaders",
754-
allowedHeaders: []string{"X-PINGOTHER"},
755-
requestedHeaders: []string{"X-PINGOTHER, Content-Type"},
756-
want: false,
757-
},
758-
}
759-
760-
for _, tt := range cases {
761-
tt := tt
762-
763-
t.Run(tt.name, func(t *testing.T) {
764-
c := New(Options{AllowedHeaders: tt.allowedHeaders})
765-
have := c.areHeadersAllowed(convert(splitHeaderValues(tt.requestedHeaders), http.CanonicalHeaderKey))
766-
if have != tt.want {
767-
t.Errorf("Cors.areHeadersAllowed() have: %t want: %t", have, tt.want)
768-
}
769-
})
770-
}
771-
}
772-
773715
func TestAccessControlExposeHeadersPresence(t *testing.T) {
774716
cases := []struct {
775717
name string

internal/sortedset.go

+113
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
// adapted from github.com/jub0bs/cors
2+
package internal
3+
4+
import (
5+
"sort"
6+
"strings"
7+
)
8+
9+
// A SortedSet represents a mathematical set of strings sorted in
10+
// lexicographical order.
11+
// Each element has a unique position ranging from 0 (inclusive)
12+
// to the set's cardinality (exclusive).
13+
// The zero value represents an empty set.
14+
type SortedSet struct {
15+
m map[string]int
16+
maxLen int
17+
}
18+
19+
// NewSortedSet returns a SortedSet that contains all of elems,
20+
// but no other elements.
21+
func NewSortedSet(elems ...string) SortedSet {
22+
sort.Strings(elems)
23+
m := make(map[string]int)
24+
var maxLen int
25+
i := 0
26+
for _, s := range elems {
27+
if _, exists := m[s]; exists {
28+
continue
29+
}
30+
m[s] = i
31+
i++
32+
maxLen = max(maxLen, len(s))
33+
}
34+
return SortedSet{
35+
m: m,
36+
maxLen: maxLen,
37+
}
38+
}
39+
40+
// Size returns the cardinality of set.
41+
func (set SortedSet) Size() int {
42+
return len(set.m)
43+
}
44+
45+
// String sorts joins the elements of set (in lexicographical order)
46+
// with a comma and returns the resulting string.
47+
func (set SortedSet) String() string {
48+
elems := make([]string, len(set.m))
49+
for elem, i := range set.m {
50+
elems[i] = elem // safe indexing, by construction of SortedSet
51+
}
52+
return strings.Join(elems, ",")
53+
}
54+
55+
// Subsumes reports whether csv is a sequence of comma-separated names that are
56+
// - all elements of set,
57+
// - sorted in lexicographically order,
58+
// - unique.
59+
func (set SortedSet) Subsumes(csv string) bool {
60+
if csv == "" {
61+
return true
62+
}
63+
posOfLastNameSeen := -1
64+
chunkSize := set.maxLen + 1 // (to accommodate for at least one comma)
65+
for {
66+
// As a defense against maliciously long names in csv,
67+
// we only process at most chunkSize bytes per iteration.
68+
end := min(len(csv), chunkSize)
69+
comma := strings.IndexByte(csv[:end], ',')
70+
var name string
71+
if comma == -1 {
72+
name = csv
73+
} else {
74+
name = csv[:comma]
75+
}
76+
pos, found := set.m[name]
77+
if !found {
78+
return false
79+
}
80+
// The names in csv are expected to be sorted in lexicographical order
81+
// and appear at most once in csv.
82+
// Therefore, the positions (in set) of the names that
83+
// appear in csv should form a strictly increasing sequence.
84+
// If that's not actually the case, bail out.
85+
if pos <= posOfLastNameSeen {
86+
return false
87+
}
88+
posOfLastNameSeen = pos
89+
if comma < 0 { // We've now processed all the names in csv.
90+
break
91+
}
92+
csv = csv[comma+1:]
93+
}
94+
return true
95+
}
96+
97+
// TODO: when updating go directive to 1.21 or later,
98+
// use min builtin instead.
99+
func min(a, b int) int {
100+
if a < b {
101+
return a
102+
}
103+
return b
104+
}
105+
106+
// TODO: when updating go directive to 1.21 or later,
107+
// use max builtin instead.
108+
func max(a, b int) int {
109+
if a > b {
110+
return a
111+
}
112+
return b
113+
}

0 commit comments

Comments
 (0)