Skip to content

Commit 9266625

Browse files
L3n41cemicklei
andauthored
use exact matching of allowed domain entries, issue #489 (#493) (#503)
* use exact matching of allowed domain entries, issue #489 * update doc, add testcases from PR conversation * introduce AllowedDomainFunc #489 * more tests, fix doc * lowercase origin before checking cors Co-authored-by: Ernest Micklei <[email protected]>
1 parent d9c71e1 commit 9266625

File tree

2 files changed

+64
-40
lines changed

2 files changed

+64
-40
lines changed

cors_filter.go

+26-38
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,22 @@ import (
1818
// http://enable-cors.org/server.html
1919
// http://www.html5rocks.com/en/tutorials/cors/#toc-handling-a-not-so-simple-request
2020
type CrossOriginResourceSharing struct {
21-
ExposeHeaders []string // list of Header names
22-
AllowedHeaders []string // list of Header names
23-
AllowedDomains []string // list of allowed values for Http Origin. An allowed value can be a regular expression to support subdomain matching. If empty all are allowed.
21+
ExposeHeaders []string // list of Header names
22+
23+
// AllowedHeaders is alist of Header names. Checking is case-insensitive.
24+
// The list may contain the special wildcard string ".*" ; all is allowed
25+
AllowedHeaders []string
26+
27+
// AllowedDomains is a list of allowed values for Http Origin.
28+
// The list may contain the special wildcard string ".*" ; all is allowed
29+
// If empty all are allowed.
30+
AllowedDomains []string
31+
32+
// AllowedDomainFunc is optional and is a function that will do the check
33+
// when the origin is not part of the AllowedDomains and it does not contain the wildcard ".*".
34+
AllowedDomainFunc func(origin string) bool
35+
36+
// AllowedMethods is either empty or has a list of http methods names. Checking is case-insensitive.
2437
AllowedMethods []string
2538
MaxAge int // number of seconds before requiring new Options request
2639
CookiesAllowed bool
@@ -119,36 +132,24 @@ func (c CrossOriginResourceSharing) isOriginAllowed(origin string) bool {
119132
if len(origin) == 0 {
120133
return false
121134
}
135+
lowerOrigin := strings.ToLower(origin)
122136
if len(c.AllowedDomains) == 0 {
137+
if c.AllowedDomainFunc != nil {
138+
return c.AllowedDomainFunc(lowerOrigin)
139+
}
123140
return true
124141
}
125142

126-
allowed := false
143+
// exact match on each allowed domain
127144
for _, domain := range c.AllowedDomains {
128-
if domain == origin {
129-
allowed = true
130-
break
145+
if domain == ".*" || strings.ToLower(domain) == lowerOrigin {
146+
return true
131147
}
132148
}
133-
134-
if !allowed {
135-
if len(c.allowedOriginPatterns) == 0 {
136-
// compile allowed domains to allowed origin patterns
137-
allowedOriginRegexps, err := compileRegexps(c.AllowedDomains)
138-
if err != nil {
139-
return false
140-
}
141-
c.allowedOriginPatterns = allowedOriginRegexps
142-
}
143-
144-
for _, pattern := range c.allowedOriginPatterns {
145-
if allowed = pattern.MatchString(origin); allowed {
146-
break
147-
}
148-
}
149+
if c.AllowedDomainFunc != nil {
150+
return c.AllowedDomainFunc(origin)
149151
}
150-
151-
return allowed
152+
return false
152153
}
153154

154155
func (c CrossOriginResourceSharing) setAllowOriginHeader(req *Request, resp *Response) {
@@ -190,16 +191,3 @@ func (c CrossOriginResourceSharing) isValidAccessControlRequestHeader(header str
190191
}
191192
return false
192193
}
193-
194-
// Take a list of strings and compile them into a list of regular expressions.
195-
func compileRegexps(regexpStrings []string) ([]*regexp.Regexp, error) {
196-
regexps := []*regexp.Regexp{}
197-
for _, regexpStr := range regexpStrings {
198-
r, err := regexp.Compile(regexpStr)
199-
if err != nil {
200-
return regexps, err
201-
}
202-
regexps = append(regexps, r)
203-
}
204-
return regexps, nil
205-
}

cors_filter_test.go

+38-2
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,46 @@ func TestCORSFilter_AllowedDomains(t *testing.T) {
120120
DefaultContainer.Dispatch(httpWriter, httpRequest)
121121
actual := httpWriter.Header().Get(HEADER_AccessControlAllowOrigin)
122122
if actual != each.origin && each.allowed {
123-
t.Fatal("expected to be accepted")
123+
t.Error("expected to be accepted", each)
124124
}
125125
if actual == each.origin && !each.allowed {
126-
t.Fatal("did not expect to be accepted")
126+
t.Error("did not expect to be accepted")
127127
}
128128
}
129129
}
130+
131+
func TestCORSFilter_AllowedDomainFunc(t *testing.T) {
132+
cors := CrossOriginResourceSharing{
133+
AllowedDomains: []string{"here", "there"},
134+
AllowedDomainFunc: func(origin string) bool {
135+
return "where" == origin
136+
},
137+
}
138+
if got, want := cors.isOriginAllowed("here"), true; got != want {
139+
t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want)
140+
}
141+
if got, want := cors.isOriginAllowed("HERE"), true; got != want {
142+
t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want)
143+
}
144+
if got, want := cors.isOriginAllowed("there"), true; got != want {
145+
t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want)
146+
}
147+
if got, want := cors.isOriginAllowed("where"), true; got != want {
148+
t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want)
149+
}
150+
if got, want := cors.isOriginAllowed("nowhere"), false; got != want {
151+
t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want)
152+
}
153+
// just func
154+
cors.AllowedDomains = []string{}
155+
if got, want := cors.isOriginAllowed("here"), false; got != want {
156+
t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want)
157+
}
158+
if got, want := cors.isOriginAllowed("where"), true; got != want {
159+
t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want)
160+
}
161+
// empty domain
162+
if got, want := cors.isOriginAllowed(""), false; got != want {
163+
t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want)
164+
}
165+
}

0 commit comments

Comments
 (0)