Skip to content

Commit 5819e69

Browse files
author
Justin Hines
authored
Merge pull request #218 from buzzfeed/sso-transition-idps
proxy: transition idps ux flow
2 parents da6efae + 0441b45 commit 5819e69

File tree

5 files changed

+247
-43
lines changed

5 files changed

+247
-43
lines changed

internal/pkg/sessions/session_state.go

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@ package sessions
22

33
import (
44
"errors"
5-
"fmt"
6-
"strconv"
7-
"strings"
85
"time"
96

107
"github.com/buzzfeed/sso/internal/pkg/aead"
@@ -17,6 +14,9 @@ var (
1714

1815
// SessionState is our object that keeps track of a user's session state
1916
type SessionState struct {
17+
ProviderSlug string `json:"slug"`
18+
ProviderType string `json:"type"`
19+
2020
AccessToken string `json:"access_token"`
2121
RefreshToken string `json:"refresh_token"`
2222

@@ -73,26 +73,3 @@ func UnmarshalSession(value string, c aead.Cipher) (*SessionState, error) {
7373
func ExtendDeadline(ttl time.Duration) time.Time {
7474
return time.Now().Add(ttl).Truncate(time.Second)
7575
}
76-
77-
// NewSessionState creates a new session state
78-
// TODO: remove this file when we transition out of backup using the payloads encryption
79-
func NewSessionState(value string, lifetimeTTL time.Duration) (*SessionState, error) {
80-
parts := strings.Split(value, "|")
81-
if len(parts) != 4 {
82-
err := fmt.Errorf("invalid number of fields (got %d expected 4)", len(parts))
83-
return nil, err
84-
}
85-
86-
ts, err := strconv.Atoi(parts[2])
87-
if err != nil {
88-
return nil, err
89-
}
90-
91-
return &SessionState{
92-
Email: parts[0],
93-
AccessToken: parts[1],
94-
RefreshDeadline: time.Unix(int64(ts), 0),
95-
RefreshToken: parts[3],
96-
LifetimeDeadline: ExtendDeadline(lifetimeTTL),
97-
}, nil
98-
}

internal/pkg/sessions/session_state_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ func TestSessionStateSerialization(t *testing.T) {
1616
}
1717

1818
want := &SessionState{
19+
ProviderSlug: "slug",
20+
ProviderType: "sso",
21+
1922
AccessToken: "token1234",
2023
RefreshToken: "refresh4321",
2124

internal/proxy/oauthproxy.go

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,9 @@ var SignatureHeaders = []string{
3838

3939
// Errors
4040
var (
41-
ErrLifetimeExpired = errors.New("user lifetime expired")
42-
ErrUserNotAuthorized = errors.New("user not authorized")
41+
ErrLifetimeExpired = errors.New("user lifetime expired")
42+
ErrUserNotAuthorized = errors.New("user not authorized")
43+
ErrWrongIdentityProvider = errors.New("user authenticated with wrong identity provider")
4344
)
4445

4546
type ErrOAuthProxyMisconfigured struct {
@@ -655,23 +656,29 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
655656
// No cookie is set, start the oauth flow
656657
p.OAuthStart(rw, req, tags)
657658
return
658-
case ErrUserNotAuthorized:
659-
tags = append(tags, "error:user_unauthorized")
660-
p.StatsdClient.Incr("application_error", tags, 1.0)
661-
// We know the user is not authorized for the request, we show them a forbidden page
662-
p.ErrorPage(rw, req, http.StatusForbidden, "Forbidden", "You're not authorized to view this page")
663-
return
664659
case ErrLifetimeExpired:
665660
// User's lifetime expired, we trigger the start of the oauth flow
666661
p.OAuthStart(rw, req, tags)
667662
return
663+
case ErrWrongIdentityProvider:
664+
// User is authenticated with the incorrect provider. This most common non-malicious
665+
// case occurs when an upstream has been transitioned to a different provider but
666+
// the user has a stale sesssion.
667+
p.OAuthStart(rw, req, tags)
668+
return
668669
case sessions.ErrInvalidSession:
669670
// The user session is invalid and we can't decode it.
670671
// This can happen for a variety of reasons but the most common non-malicious
671672
// case occurs when the session encoding schema changes. We manage this ux
672673
// by triggering the start of the oauth flow.
673674
p.OAuthStart(rw, req, tags)
674675
return
676+
case ErrUserNotAuthorized:
677+
tags = append(tags, "error:user_unauthorized")
678+
p.StatsdClient.Incr("application_error", tags, 1.0)
679+
// We know the user is not authorized for the request, we show them a forbidden page
680+
p.ErrorPage(rw, req, http.StatusForbidden, "Forbidden", "You're not authorized to view this page")
681+
return
675682
default:
676683
logger.Error(err, "unknown error authenticating user")
677684
tags = append(tags, "error:internal_error")
@@ -709,6 +716,15 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) (er
709716
return err
710717
}
711718

719+
// check if this session belongs to the correct identity provider application.
720+
// this case exists primarly to allow us to gracefully manage a clean ux during
721+
// transitions from one provider to another by gracefully restarting the authentication process.
722+
if session.ProviderSlug != p.provider.Data().ProviderSlug {
723+
logger.WithUser(session.Email).Info(
724+
"authenticated with incorrect identity provider; restarting authentication")
725+
return ErrWrongIdentityProvider
726+
}
727+
712728
// Lifetime period is the entire duration in which the session is valid.
713729
// This should be set to something like 14 to 30 days.
714730
if session.LifetimePeriodExpired() {

internal/proxy/oauthproxy_test.go

Lines changed: 214 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -263,13 +263,15 @@ func TestAuthOnlyEndpoint(t *testing.T) {
263263

264264
for _, tc := range testCases {
265265
t.Run(tc.name, func(t *testing.T) {
266+
providerURL, _ := url.Parse("http://localhost/")
267+
tp := providers.NewTestProvider(providerURL, "")
268+
tp.RefreshSessionFunc = func(*sessions.SessionState, []string) (bool, error) { return true, nil }
269+
tp.ValidateSessionFunc = func(*sessions.SessionState, []string) bool { return true }
270+
266271
proxy, close := testNewOAuthProxy(t,
267272
setSessionStore(tc.sessionStore),
268273
setValidator(func(_ string) bool { return tc.validEmail }),
269-
SetProvider(&providers.TestProvider{
270-
RefreshSessionFunc: func(*sessions.SessionState, []string) (bool, error) { return true, nil },
271-
ValidateSessionFunc: func(*sessions.SessionState, []string) bool { return true },
272-
}),
274+
SetProvider(tp),
273275
)
274276
defer close()
275277

@@ -571,16 +573,31 @@ func TestAuthenticate(t *testing.T) {
571573
CookieExpectation: NewCookie,
572574
ValidateSessionFunc: func(s *sessions.SessionState, g []string) bool { return true },
573575
},
576+
{
577+
Name: "wrong identity provider, user OK, do not authenticate",
578+
SessionStore: &sessions.MockSessionStore{
579+
Session: &sessions.SessionState{
580+
ProviderSlug: "example",
581+
582+
AccessToken: "my_access_token",
583+
LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour),
584+
RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour),
585+
ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute),
586+
},
587+
},
588+
ExpectedErr: ErrWrongIdentityProvider,
589+
CookieExpectation: ClearCookie,
590+
},
574591
}
575592
for _, tc := range testCases {
576593
t.Run(tc.Name, func(t *testing.T) {
577-
provider := &providers.TestProvider{
578-
RefreshSessionFunc: tc.RefreshSessionFunc,
579-
ValidateSessionFunc: tc.ValidateSessionFunc,
580-
}
594+
providerURL, _ := url.Parse("http://localhost/")
595+
tp := providers.NewTestProvider(providerURL, "")
596+
tp.RefreshSessionFunc = tc.RefreshSessionFunc
597+
tp.ValidateSessionFunc = tc.ValidateSessionFunc
581598

582599
proxy, close := testNewOAuthProxy(t,
583-
SetProvider(provider),
600+
SetProvider(tp),
584601
setSessionStore(tc.SessionStore),
585602
)
586603
defer close()
@@ -608,6 +625,194 @@ func TestAuthenticate(t *testing.T) {
608625
}
609626
}
610627

628+
func TestAuthenticationUXFlows(t *testing.T) {
629+
var (
630+
ErrRefreshFailed = errors.New("refresh failed")
631+
LoadCookieFailed = errors.New("load cookie fail")
632+
SaveCookieFailed = errors.New("save cookie fail")
633+
)
634+
testCases := []struct {
635+
Name string
636+
637+
SessionStore *sessions.MockSessionStore
638+
RefreshSessionFunc func(*sessions.SessionState, []string) (bool, error)
639+
ValidateSessionFunc func(*sessions.SessionState, []string) bool
640+
641+
ExpectStatusCode int
642+
}{
643+
{
644+
Name: "missing deadlines, redirect to sign-in",
645+
SessionStore: &sessions.MockSessionStore{
646+
Session: &sessions.SessionState{
647+
648+
AccessToken: "my_access_token",
649+
},
650+
},
651+
ExpectStatusCode: http.StatusFound,
652+
},
653+
{
654+
Name: "session unmarshaling fails, show error",
655+
SessionStore: &sessions.MockSessionStore{
656+
Session: &sessions.SessionState{},
657+
LoadError: LoadCookieFailed,
658+
},
659+
ExpectStatusCode: http.StatusInternalServerError,
660+
},
661+
{
662+
Name: "authenticate successfully, expect ok",
663+
SessionStore: &sessions.MockSessionStore{
664+
Session: &sessions.SessionState{
665+
666+
AccessToken: "my_access_token",
667+
LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour),
668+
RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour),
669+
ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute),
670+
},
671+
},
672+
ExpectStatusCode: http.StatusOK,
673+
},
674+
{
675+
Name: "lifetime expired, redirect to sign-in",
676+
SessionStore: &sessions.MockSessionStore{
677+
Session: &sessions.SessionState{
678+
679+
AccessToken: "my_access_token",
680+
LifetimeDeadline: time.Now().Add(time.Duration(-24) * time.Hour),
681+
RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour),
682+
ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute),
683+
},
684+
},
685+
ExpectStatusCode: http.StatusFound,
686+
},
687+
{
688+
Name: "refresh expired, refresh fails, show error",
689+
SessionStore: &sessions.MockSessionStore{
690+
Session: &sessions.SessionState{
691+
692+
AccessToken: "my_access_token",
693+
LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour),
694+
RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour),
695+
ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute),
696+
},
697+
},
698+
RefreshSessionFunc: func(s *sessions.SessionState, g []string) (bool, error) { return false, ErrRefreshFailed },
699+
ExpectStatusCode: http.StatusInternalServerError,
700+
},
701+
{
702+
Name: "refresh expired, user not OK, deny",
703+
SessionStore: &sessions.MockSessionStore{
704+
Session: &sessions.SessionState{
705+
706+
AccessToken: "my_access_token",
707+
LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour),
708+
RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour),
709+
ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute),
710+
},
711+
},
712+
RefreshSessionFunc: func(s *sessions.SessionState, g []string) (bool, error) { return false, nil },
713+
ExpectStatusCode: http.StatusForbidden,
714+
},
715+
{
716+
Name: "refresh expired, user OK, expect ok",
717+
SessionStore: &sessions.MockSessionStore{
718+
Session: &sessions.SessionState{
719+
720+
AccessToken: "my_access_token",
721+
LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour),
722+
RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour),
723+
ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute),
724+
},
725+
},
726+
RefreshSessionFunc: func(s *sessions.SessionState, g []string) (bool, error) { return true, nil },
727+
ExpectStatusCode: http.StatusOK,
728+
},
729+
{
730+
Name: "refresh expired, refresh and user OK, error saving session, show error",
731+
SessionStore: &sessions.MockSessionStore{
732+
Session: &sessions.SessionState{
733+
734+
AccessToken: "my_access_token",
735+
LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour),
736+
RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour),
737+
ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute),
738+
},
739+
SaveError: SaveCookieFailed,
740+
},
741+
RefreshSessionFunc: func(s *sessions.SessionState, g []string) (bool, error) { return true, nil },
742+
ExpectStatusCode: http.StatusInternalServerError,
743+
},
744+
{
745+
Name: "validation expired, user not OK, deny",
746+
SessionStore: &sessions.MockSessionStore{
747+
Session: &sessions.SessionState{
748+
749+
AccessToken: "my_access_token",
750+
LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour),
751+
RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour),
752+
ValidDeadline: time.Now().Add(time.Duration(-1) * time.Minute),
753+
},
754+
},
755+
ValidateSessionFunc: func(s *sessions.SessionState, g []string) bool { return false },
756+
ExpectStatusCode: http.StatusForbidden,
757+
},
758+
{
759+
Name: "validation expired, user OK, expect ok",
760+
SessionStore: &sessions.MockSessionStore{
761+
Session: &sessions.SessionState{
762+
763+
AccessToken: "my_access_token",
764+
LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour),
765+
RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour),
766+
ValidDeadline: time.Now().Add(time.Duration(-1) * time.Minute),
767+
},
768+
},
769+
ValidateSessionFunc: func(s *sessions.SessionState, g []string) bool { return true },
770+
ExpectStatusCode: http.StatusOK,
771+
},
772+
{
773+
Name: "wrong identity provider, redirect to sign-in",
774+
SessionStore: &sessions.MockSessionStore{
775+
Session: &sessions.SessionState{
776+
ProviderSlug: "example",
777+
778+
AccessToken: "my_access_token",
779+
LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour),
780+
RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour),
781+
ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute),
782+
},
783+
},
784+
ExpectStatusCode: http.StatusFound,
785+
},
786+
}
787+
for _, tc := range testCases {
788+
t.Run(tc.Name, func(t *testing.T) {
789+
providerURL, _ := url.Parse("http://localhost/")
790+
tp := providers.NewTestProvider(providerURL, "")
791+
tp.RefreshSessionFunc = tc.RefreshSessionFunc
792+
tp.ValidateSessionFunc = tc.ValidateSessionFunc
793+
794+
proxy, close := testNewOAuthProxy(t,
795+
SetProvider(tp),
796+
setSessionStore(tc.SessionStore),
797+
)
798+
defer close()
799+
800+
req := httptest.NewRequest("GET", "https://localhost", nil)
801+
rw := httptest.NewRecorder()
802+
803+
proxy.Proxy(rw, req)
804+
805+
res := rw.Result()
806+
807+
if tc.ExpectStatusCode != res.StatusCode {
808+
t.Errorf("have: %v", res.StatusCode)
809+
t.Errorf("want: %v", tc.ExpectStatusCode)
810+
t.Fatalf("expected status codes to be equal")
811+
}
812+
})
813+
}
814+
}
815+
611816
func TestProxyXHRErrorHandling(t *testing.T) {
612817
testCases := []struct {
613818
Name string

internal/proxy/providers/sso.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,9 @@ func (p *SSOProvider) Redeem(redirectURL, code string) (*sessions.SessionState,
158158

159159
user := strings.Split(jsonResponse.Email, "@")[0]
160160
return &sessions.SessionState{
161+
ProviderSlug: p.ProviderData.ProviderSlug,
162+
ProviderType: "sso",
163+
161164
AccessToken: jsonResponse.AccessToken,
162165
RefreshToken: jsonResponse.RefreshToken,
163166

0 commit comments

Comments
 (0)