Skip to content

Commit 95e36f1

Browse files
committed
proxy: add idp transition ux flow
1 parent 8acd64a commit 95e36f1

File tree

5 files changed

+59
-43
lines changed

5 files changed

+59
-43
lines changed

internal/pkg/sessions/session_state.go

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@ package sessions
22

33
import (
44
"errors"
5-
"fmt"
6-
"strconv"
7-
"strings"
85
"time"
96

107
"github.com/buzzfeed/sso/internal/pkg/aead"
@@ -17,6 +14,9 @@ var (
1714

1815
// SessionState is our object that keeps track of a user's session state
1916
type SessionState struct {
17+
ProviderSlug string `json:"slug"`
18+
ProviderType string `json:"type"`
19+
2020
AccessToken string `json:"access_token"`
2121
RefreshToken string `json:"refresh_token"`
2222

@@ -73,26 +73,3 @@ func UnmarshalSession(value string, c aead.Cipher) (*SessionState, error) {
7373
func ExtendDeadline(ttl time.Duration) time.Time {
7474
return time.Now().Add(ttl).Truncate(time.Second)
7575
}
76-
77-
// NewSessionState creates a new session state
78-
// TODO: remove this file when we transition out of backup using the payloads encryption
79-
func NewSessionState(value string, lifetimeTTL time.Duration) (*SessionState, error) {
80-
parts := strings.Split(value, "|")
81-
if len(parts) != 4 {
82-
err := fmt.Errorf("invalid number of fields (got %d expected 4)", len(parts))
83-
return nil, err
84-
}
85-
86-
ts, err := strconv.Atoi(parts[2])
87-
if err != nil {
88-
return nil, err
89-
}
90-
91-
return &SessionState{
92-
Email: parts[0],
93-
AccessToken: parts[1],
94-
RefreshDeadline: time.Unix(int64(ts), 0),
95-
RefreshToken: parts[3],
96-
LifetimeDeadline: ExtendDeadline(lifetimeTTL),
97-
}, nil
98-
}

internal/pkg/sessions/session_state_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ func TestSessionStateSerialization(t *testing.T) {
1616
}
1717

1818
want := &SessionState{
19+
ProviderSlug: "slug",
20+
ProviderType: "sso",
21+
1922
AccessToken: "token1234",
2023
RefreshToken: "refresh4321",
2124

internal/proxy/oauthproxy.go

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,9 @@ var SignatureHeaders = []string{
3838

3939
// Errors
4040
var (
41-
ErrLifetimeExpired = errors.New("user lifetime expired")
42-
ErrUserNotAuthorized = errors.New("user not authorized")
41+
ErrLifetimeExpired = errors.New("user lifetime expired")
42+
ErrUserNotAuthorized = errors.New("user not authorized")
43+
ErrWrongIdentityProvider = errors.New("user authenticated with wrong identity provider")
4344
)
4445

4546
type ErrOAuthProxyMisconfigured struct {
@@ -655,23 +656,29 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
655656
// No cookie is set, start the oauth flow
656657
p.OAuthStart(rw, req, tags)
657658
return
658-
case ErrUserNotAuthorized:
659-
tags = append(tags, "error:user_unauthorized")
660-
p.StatsdClient.Incr("application_error", tags, 1.0)
661-
// We know the user is not authorized for the request, we show them a forbidden page
662-
p.ErrorPage(rw, req, http.StatusForbidden, "Forbidden", "You're not authorized to view this page")
663-
return
664659
case ErrLifetimeExpired:
665660
// User's lifetime expired, we trigger the start of the oauth flow
666661
p.OAuthStart(rw, req, tags)
667662
return
663+
case ErrWrongIdentityProvider:
664+
// User is authenticated with the incorrect provider. This most common non-malicious
665+
// case occurs when an upstream has been transitioned to a different provider but
666+
// the user has a stale sesssion.
667+
p.OAuthStart(rw, req, tags)
668+
return
668669
case sessions.ErrInvalidSession:
669670
// The user session is invalid and we can't decode it.
670671
// This can happen for a variety of reasons but the most common non-malicious
671672
// case occurs when the session encoding schema changes. We manage this ux
672673
// by triggering the start of the oauth flow.
673674
p.OAuthStart(rw, req, tags)
674675
return
676+
case ErrUserNotAuthorized:
677+
tags = append(tags, "error:user_unauthorized")
678+
p.StatsdClient.Incr("application_error", tags, 1.0)
679+
// We know the user is not authorized for the request, we show them a forbidden page
680+
p.ErrorPage(rw, req, http.StatusForbidden, "Forbidden", "You're not authorized to view this page")
681+
return
675682
default:
676683
logger.Error(err, "unknown error authenticating user")
677684
tags = append(tags, "error:internal_error")
@@ -709,6 +716,15 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) (er
709716
return err
710717
}
711718

719+
// check if this session belongs to the correct identity provider application.
720+
// this case exists primarly to allow us to gracefully manage a clean ux during
721+
// transitions from one provider to another by gracefully restarting the authentication process.
722+
if session.ProviderSlug != p.provider.Data().ProviderSlug {
723+
logger.WithUser(session.Email).Info(
724+
"authenticated with incorrect identity provider; restarting authentication")
725+
return ErrWrongIdentityProvider
726+
}
727+
712728
// Lifetime period is the entire duration in which the session is valid.
713729
// This should be set to something like 14 to 30 days.
714730
if session.LifetimePeriodExpired() {

internal/proxy/oauthproxy_test.go

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -263,13 +263,15 @@ func TestAuthOnlyEndpoint(t *testing.T) {
263263

264264
for _, tc := range testCases {
265265
t.Run(tc.name, func(t *testing.T) {
266+
providerURL, _ := url.Parse("http://localhost/")
267+
tp := providers.NewTestProvider(providerURL, "")
268+
tp.RefreshSessionFunc = func(*sessions.SessionState, []string) (bool, error) { return true, nil }
269+
tp.ValidateSessionFunc = func(*sessions.SessionState, []string) bool { return true }
270+
266271
proxy, close := testNewOAuthProxy(t,
267272
setSessionStore(tc.sessionStore),
268273
setValidator(func(_ string) bool { return tc.validEmail }),
269-
SetProvider(&providers.TestProvider{
270-
RefreshSessionFunc: func(*sessions.SessionState, []string) (bool, error) { return true, nil },
271-
ValidateSessionFunc: func(*sessions.SessionState, []string) bool { return true },
272-
}),
274+
SetProvider(tp),
273275
)
274276
defer close()
275277

@@ -571,16 +573,31 @@ func TestAuthenticate(t *testing.T) {
571573
CookieExpectation: NewCookie,
572574
ValidateSessionFunc: func(s *sessions.SessionState, g []string) bool { return true },
573575
},
576+
{
577+
Name: "wrong identity provider, user OK, do not authenticate",
578+
SessionStore: &sessions.MockSessionStore{
579+
Session: &sessions.SessionState{
580+
ProviderSlug: "example",
581+
582+
AccessToken: "my_access_token",
583+
LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour),
584+
RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour),
585+
ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute),
586+
},
587+
},
588+
ExpectedErr: ErrWrongIdentityProvider,
589+
CookieExpectation: ClearCookie,
590+
},
574591
}
575592
for _, tc := range testCases {
576593
t.Run(tc.Name, func(t *testing.T) {
577-
provider := &providers.TestProvider{
578-
RefreshSessionFunc: tc.RefreshSessionFunc,
579-
ValidateSessionFunc: tc.ValidateSessionFunc,
580-
}
594+
providerURL, _ := url.Parse("http://localhost/")
595+
tp := providers.NewTestProvider(providerURL, "")
596+
tp.RefreshSessionFunc = tc.RefreshSessionFunc
597+
tp.ValidateSessionFunc = tc.ValidateSessionFunc
581598

582599
proxy, close := testNewOAuthProxy(t,
583-
SetProvider(provider),
600+
SetProvider(tp),
584601
setSessionStore(tc.SessionStore),
585602
)
586603
defer close()

internal/proxy/providers/sso.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,9 @@ func (p *SSOProvider) Redeem(redirectURL, code string) (*sessions.SessionState,
158158

159159
user := strings.Split(jsonResponse.Email, "@")[0]
160160
return &sessions.SessionState{
161+
ProviderSlug: p.ProviderData.ProviderSlug,
162+
ProviderType: "sso",
163+
161164
AccessToken: jsonResponse.AccessToken,
162165
RefreshToken: jsonResponse.RefreshToken,
163166

0 commit comments

Comments
 (0)