Skip to content

Commit 99e69f8

Browse files
authored
Merge pull request #267 from buzzfeed/jusshersmith-group-validator-bug
sso_proxy: reduce amount of group validations
2 parents a29ba90 + 950db76 commit 99e69f8

File tree

3 files changed

+23
-13
lines changed

3 files changed

+23
-13
lines changed

internal/auth/providers/okta.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ func (p *OktaProvider) oktaRequest(method, endpoint string, params url.Values, t
199199
if resp.StatusCode != http.StatusOK {
200200
p.StatsdClient.Incr("provider.error", tags, 1.0)
201201
logger.WithHTTPStatus(resp.StatusCode).WithEndpoint(stripToken(endpoint)).WithResponseBody(
202-
respBody).Info()
202+
respBody).Error("non-200 response returned from Okta")
203203
switch resp.StatusCode {
204204
case 400:
205205
var response struct {

internal/pkg/options/email_domain_validator.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99
)
1010

1111
var (
12-
_ Validator = &EmailDomainValidator{}
12+
_ Validator = EmailDomainValidator{}
1313

1414
// These error message should be formatted in such a way that is appropriate
1515
// for display to the end user.
@@ -28,7 +28,7 @@ type EmailDomainValidator struct {
2828
// - if the originally passed in list of domains consists only of "*", then all emails
2929
// are considered valid based on their domain.
3030
// If valid, nil is returned in place of an error.
31-
func NewEmailDomainValidator(allowedDomains []string) *EmailDomainValidator {
31+
func NewEmailDomainValidator(allowedDomains []string) EmailDomainValidator {
3232
emailDomains := make([]string, 0, len(allowedDomains))
3333

3434
for _, domain := range allowedDomains {
@@ -39,12 +39,12 @@ func NewEmailDomainValidator(allowedDomains []string) *EmailDomainValidator {
3939
emailDomains = append(emailDomains, emailDomain)
4040
}
4141
}
42-
return &EmailDomainValidator{
42+
return EmailDomainValidator{
4343
AllowedDomains: emailDomains,
4444
}
4545
}
4646

47-
func (v *EmailDomainValidator) Validate(session *sessions.SessionState) error {
47+
func (v EmailDomainValidator) Validate(session *sessions.SessionState) error {
4848
if session.Email == "" {
4949
return ErrInvalidEmailAddress
5050
}
@@ -64,7 +64,7 @@ func (v *EmailDomainValidator) Validate(session *sessions.SessionState) error {
6464
return nil
6565
}
6666

67-
func (v *EmailDomainValidator) validate(session *sessions.SessionState) error {
67+
func (v EmailDomainValidator) validate(session *sessions.SessionState) error {
6868
email := strings.ToLower(session.Email)
6969
for _, domain := range v.AllowedDomains {
7070
if strings.HasSuffix(email, domain) {

internal/proxy/oauthproxy.go

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -781,13 +781,23 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) (er
781781
}
782782
}
783783

784-
errors := options.RunValidators(p.Validators, session)
785-
if len(errors) == len(p.Validators) {
786-
tags = append(tags, "error:validation_failed")
787-
p.StatsdClient.Incr("application_error", tags, 1.0)
788-
logger.WithRemoteAddress(remoteAddr).WithUser(session.Email).Info(
789-
fmt.Sprintf("permission denied: unauthorized: %q", errors))
790-
return ErrUserNotAuthorized
784+
// We revalidate group membership whenever the session is refreshed or revalidated
785+
// just above in the call to ValidateSessionState and RefreshSession.
786+
// To reduce strain on upstream identity providers we only revalidate email domains and
787+
// addresses on each request here.
788+
for _, v := range p.Validators {
789+
_, EmailGroupValidator := v.(options.EmailGroupValidator)
790+
791+
if !EmailGroupValidator {
792+
err := v.Validate(session)
793+
if err != nil {
794+
tags = append(tags, "error:validation_failed")
795+
p.StatsdClient.Incr("application_error", tags, 1.0)
796+
logger.WithRemoteAddress(remoteAddr).WithUser(session.Email).Info(
797+
fmt.Sprintf("permission denied: unauthorized: %q", err))
798+
return ErrUserNotAuthorized
799+
}
800+
}
791801
}
792802

793803
logger.WithRemoteAddress(remoteAddr).WithUser(session.Email).Info(

0 commit comments

Comments
 (0)