Skip to content

sso_*: allow simultaneous use of Validators #253

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions internal/auth/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/buzzfeed/sso/internal/auth/providers"
"github.com/buzzfeed/sso/internal/pkg/aead"
log "github.com/buzzfeed/sso/internal/pkg/logging"
"github.com/buzzfeed/sso/internal/pkg/options"
"github.com/buzzfeed/sso/internal/pkg/sessions"
"github.com/buzzfeed/sso/internal/pkg/templates"

Expand All @@ -20,7 +21,7 @@ import (

// Authenticator stores all the information associated with proxying the request.
type Authenticator struct {
Validator func(string) bool
Validators []options.Validator
EmailDomains []string
ProxyRootDomains []string
Host string
Expand Down Expand Up @@ -225,11 +226,16 @@ func (p *Authenticator) authenticate(rw http.ResponseWriter, req *http.Request)
}
}

if !p.Validator(session.Email) {
logger.WithUser(session.Email).Error("invalid email user")
errors := options.RunValidators(p.Validators, session)
if len(errors) == len(p.Validators) {
logger.WithUser(session.Email).Info(
fmt.Sprintf("permission denied: unauthorized: %q", errors))
return nil, ErrUserNotAuthorized
}

logger.WithRemoteAddress(remoteAddr).WithUser(session.Email).Info(
fmt.Sprintf("authentication: user passed validation"))

return session, nil
}

Expand Down Expand Up @@ -575,13 +581,24 @@ func (p *Authenticator) getOAuthCallback(rw http.ResponseWriter, req *http.Reque
// Set cookie, or deny: The authenticator validates the session email and group
// - for p.Validator see validator.go#newValidatorImpl for more info
// - for p.provider.ValidateGroup see providers/google.go#ValidateGroup for more info
if !p.Validator(session.Email) {

errors := options.RunValidators(p.Validators, session)
if len(errors) == len(p.Validators) {
tags := append(tags, "error:invalid_email")
p.StatsdClient.Incr("application_error", tags, 1.0)
logger.WithRemoteAddress(remoteAddr).WithUser(session.Email).Error(
"invalid_email", "permission denied; unauthorized user")
return "", HTTPError{Code: http.StatusForbidden, Message: "Invalid Account"}
logger.WithRemoteAddress(remoteAddr).WithUser(session.Email).Info(
fmt.Sprintf("oauth callback: unauthorized: %q", errors))

formattedErrors := make([]string, 0, len(errors))
for _, err := range errors {
formattedErrors = append(formattedErrors, err.Error())
}
errorMsg := fmt.Sprintf("We ran into some issues while validating your account: \"%s\"",
strings.Join(formattedErrors, ", "))
return "", HTTPError{Code: http.StatusForbidden, Message: errorMsg}
}
logger.WithRemoteAddress(remoteAddr).WithUser(session.Email).Info(
fmt.Sprintf("oauth callback: user passed validation"))

logger.WithRemoteAddress(remoteAddr).WithUser(session.Email).Info("authentication complete")
err = p.sessionStore.SaveSession(rw, req, session)
Expand Down
23 changes: 9 additions & 14 deletions internal/auth/authenticator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (

"github.com/buzzfeed/sso/internal/auth/providers"
"github.com/buzzfeed/sso/internal/pkg/aead"
"github.com/buzzfeed/sso/internal/pkg/options"
"github.com/buzzfeed/sso/internal/pkg/sessions"
"github.com/buzzfeed/sso/internal/pkg/templates"
"github.com/buzzfeed/sso/internal/pkg/testutil"
Expand Down Expand Up @@ -66,13 +67,6 @@ func setTestProvider(provider *providers.TestProvider) func(*Authenticator) erro
}
}

func setMockValidator(response bool) func(*Authenticator) error {
return func(a *Authenticator) error {
a.Validator = func(string) bool { return response }
return nil
}
}

func setRedirectURL(redirectURL *url.URL) func(*Authenticator) error {
return func(a *Authenticator) error {
a.redirectURL = redirectURL
Expand Down Expand Up @@ -424,7 +418,7 @@ func TestSignIn(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
config := testConfiguration(t)
auth, err := NewAuthenticator(config,
setMockValidator(tc.validEmail),
SetValidators([]options.Validator{options.NewMockValidator(tc.validEmail)}),
setMockSessionStore(tc.mockSessionStore),
setMockTempl(),
setMockRedirectURL(),
Expand Down Expand Up @@ -571,7 +565,7 @@ func TestSignOutPage(t *testing.T) {
provider.RevokeError = tc.RevokeError

p, _ := NewAuthenticator(config,
setMockValidator(true),
SetValidators([]options.Validator{options.NewMockValidator(true)}),
setMockSessionStore(tc.mockSessionStore),
setMockTempl(),
setTestProvider(provider),
Expand Down Expand Up @@ -948,7 +942,7 @@ func TestGetProfile(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
config := testConfiguration(t)
p, _ := NewAuthenticator(config,
setMockValidator(true),
SetValidators([]options.Validator{options.NewMockValidator(true)}),
)
u, _ := url.Parse("http://example.com")
testProvider := providers.NewTestProvider(u)
Expand Down Expand Up @@ -1050,7 +1044,7 @@ func TestRedeemCode(t *testing.T) {
config := testConfiguration(t)

proxy, _ := NewAuthenticator(config,
setMockValidator(true),
SetValidators([]options.Validator{options.NewMockValidator(true)}),
)

testURL, err := url.Parse("example.com")
Expand Down Expand Up @@ -1357,7 +1351,8 @@ func TestOAuthCallback(t *testing.T) {
Value: "state",
},
},
expectedError: HTTPError{Code: http.StatusForbidden, Message: "Invalid Account"},
expectedError: HTTPError{Code: http.StatusForbidden,
Message: "We ran into some issues while validating your account: \"MockValidator error\""},
},
{
name: "valid email, invalid redirect",
Expand Down Expand Up @@ -1438,7 +1433,7 @@ func TestOAuthCallback(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
config := testConfiguration(t)
proxy, _ := NewAuthenticator(config,
setMockValidator(tc.validEmail),
SetValidators([]options.Validator{options.NewMockValidator(tc.validEmail)}),
setMockCSRFStore(tc.csrfResp),
setMockSessionStore(tc.sessionStore),
)
Expand Down Expand Up @@ -1559,7 +1554,7 @@ func TestOAuthStart(t *testing.T) {
provider := providers.NewTestProvider(nil)
proxy, _ := NewAuthenticator(config,
setTestProvider(provider),
setMockValidator(true),
SetValidators([]options.Validator{options.NewMockValidator(true)}),
setMockRedirectURL(),
setMockCSRFStore(&sessions.MockCSRFStore{}),
)
Expand Down
9 changes: 4 additions & 5 deletions internal/auth/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@ type AuthenticatorMux struct {

func NewAuthenticatorMux(config Configuration, statsdClient *statsd.Client) (*AuthenticatorMux, error) {
logger := log.NewLogEntry()

var validator func(string) bool
validators := []options.Validator{}
if len(config.AuthorizeConfig.EmailConfig.Addresses) != 0 {
validator = options.NewEmailAddressValidator(config.AuthorizeConfig.EmailConfig.Addresses)
validators = append(validators, options.NewEmailAddressValidator(config.AuthorizeConfig.EmailConfig.Addresses))
} else {
validator = options.NewEmailDomainValidator(config.AuthorizeConfig.EmailConfig.Domains)
validators = append(validators, options.NewEmailDomainValidator(config.AuthorizeConfig.EmailConfig.Domains))
}

authenticators := []*Authenticator{}
Expand All @@ -38,7 +37,7 @@ func NewAuthenticatorMux(config Configuration, statsdClient *statsd.Client) (*Au

idpSlug := idp.Data().ProviderSlug
authenticator, err := NewAuthenticator(config,
SetValidator(validator),
SetValidators(validators),
SetProvider(idp),
SetCookieStore(config.SessionConfig, idpSlug),
SetStatsdClient(statsdClient),
Expand Down
5 changes: 3 additions & 2 deletions internal/auth/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/buzzfeed/sso/internal/auth/providers"
"github.com/buzzfeed/sso/internal/pkg/aead"
"github.com/buzzfeed/sso/internal/pkg/groups"
"github.com/buzzfeed/sso/internal/pkg/options"
"github.com/buzzfeed/sso/internal/pkg/sessions"

"github.com/datadog/datadog-go/statsd"
Expand Down Expand Up @@ -96,9 +97,9 @@ func SetRedirectURL(serverConfig ServerConfig, slug string) func(*Authenticator)
}

// SetValidator sets the email validator
func SetValidator(validator func(string) bool) func(*Authenticator) error {
func SetValidators(validators []options.Validator) func(*Authenticator) error {
return func(a *Authenticator) error {
a.Validator = validator
a.Validators = validators
return nil
}
}
Expand Down
74 changes: 54 additions & 20 deletions internal/pkg/options/email_address_validator.go
Original file line number Diff line number Diff line change
@@ -1,38 +1,72 @@
package options

import (
"errors"
"fmt"
"strings"

"github.com/buzzfeed/sso/internal/pkg/sessions"
)

// NewEmailAddressValidator returns a function that checks whether a given email is valid based on a list
// of email addresses. The address "*" is a wild card that matches any non-empty email.
func NewEmailAddressValidator(emails []string) func(string) bool {
allowAll := false
var (
_ Validator = EmailAddressValidator{}

// These error message should be formatted in such a way that is appropriate
// for display to the end user.
ErrEmailAddressDenied = errors.New("Unauthorized Email Address")
)

type EmailAddressValidator struct {
AllowedEmails []string
}

// NewEmailAddressValidator takes in a list of email addresses and returns a Validator object.
// The validator can be used to validate that the session.Email:
// - is non-empty
// - matches one of the originally passed in email addresses
// (case insensitive)
// - if the originally passed in list of emails consists only of "*", then all emails
// are considered valid based on their domain.
// If valid, nil is returned in place of an error.
func NewEmailAddressValidator(allowedEmails []string) EmailAddressValidator {
var emailAddresses []string

for _, email := range emails {
if email == "*" {
allowAll = true
}
for _, email := range allowedEmails {
emailAddress := fmt.Sprintf("%s", strings.ToLower(email))
emailAddresses = append(emailAddresses, emailAddress)
}

if allowAll {
return func(email string) bool { return email != "" }
return EmailAddressValidator{
AllowedEmails: emailAddresses,
}
}

return func(email string) bool {
if email == "" {
return false
}
email = strings.ToLower(email)
for _, emailItem := range emailAddresses {
if email == emailItem {
return true
}
func (v EmailAddressValidator) Validate(session *sessions.SessionState) error {
if session.Email == "" {
return ErrInvalidEmailAddress
}

if len(v.AllowedEmails) == 0 {
return ErrEmailAddressDenied
}

if len(v.AllowedEmails) == 1 && v.AllowedEmails[0] == "*" {
return nil
}

err := v.validate(session)
if err != nil {
return err
}
return nil
}

func (v EmailAddressValidator) validate(session *sessions.SessionState) error {
email := strings.ToLower(session.Email)
for _, emailItem := range v.AllowedEmails {
if email == emailItem {
return nil
}
return false
}
return ErrEmailAddressDenied
}
Loading