Skip to content

Commit 6d26a94

Browse files
committed
some cleanup and fix group validation
1 parent 5bbd5bf commit 6d26a94

File tree

1 file changed

+51
-115
lines changed

1 file changed

+51
-115
lines changed

internal/auth/providers/okta.go

Lines changed: 51 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,14 @@ import (
2424
type OktaProvider struct {
2525
*ProviderData
2626
StatsdClient *statsd.Client
27-
// AdminService AdminService
28-
cb *circuit.Breaker
29-
GroupsCache groups.MemberSetCache
27+
cb *circuit.Breaker
28+
GroupsCache groups.MemberSetCache
29+
}
30+
31+
type UserInfoResponse struct {
32+
EmailAddress string `json:"email"`
33+
EmailVerified bool `json:"email_verified"`
34+
Groups []string `json:"groups"`
3035
}
3136

3237
// NewOktaProvider returns a new OktaProvider and sets the provider url endpoints.
@@ -80,7 +85,7 @@ func NewOktaProvider(p *ProviderData, orgName string) (*OktaProvider, error) {
8085
}
8186
if p.Scope == "" {
8287
// https://developer.okta.com/docs/api/resources/oidc/#authorize
83-
p.Scope = "openid profile email email_verified"
88+
p.Scope = "openid profile email groups"
8489
}
8590

8691
oktaProvider := &OktaProvider{
@@ -100,27 +105,6 @@ func NewOktaProvider(p *ProviderData, orgName string) (*OktaProvider, error) {
100105
return oktaProvider, nil
101106
}
102107

103-
// SetStatsdClient sets the okta provider and admin service statsd client
104-
//func (p *OktaProvider) SetStatsdClient(statsdClient *statsd.Client) {
105-
// logger := log.NewLogEntry()
106-
//
107-
// p.StatsdClient = statsdClient
108-
//
109-
// switch s := p.AdminService.(type) {
110-
// case *OktaAdminService:
111-
// s.StatsdClient = statsdClient
112-
// default:
113-
// logger.Info("admin service does not have statsd client")
114-
// }
115-
116-
// switch g := p.GroupsCache.(type) {
117-
// case *groups.FillCache:
118-
// g.StatsdClient = statsdClient
119-
// default:
120-
// logger.Info("groups cache does not have statsd client")
121-
// }
122-
//}
123-
124108
// ValidateSessionState attempts to validate the session state's access token.
125109
func (p *OktaProvider) ValidateSessionState(s *sessions.SessionState) bool {
126110
return validateToken(p, s.AccessToken, p.ClientID, p.ClientSecret, nil)
@@ -132,9 +116,6 @@ func (p *OktaProvider) GetSignInURL(redirectURI, state string) string {
132116
a = *p.SignInURL
133117
params, _ := url.ParseQuery(a.RawQuery)
134118
params.Set("redirect_uri", redirectURI)
135-
// https://developer.okta.com/docs/api/resources/oidc/#parameter-details
136-
//params.Set("prompt", p.ApprovalPrompt)
137-
// params.Add("nonce", nonce)
138119
params.Add("scope", p.Scope)
139120
params.Set("client_id", p.ClientID)
140121
params.Add("response_mode", "query")
@@ -226,36 +207,6 @@ func (p *OktaProvider) oktaRequest(method, endpoint string, params url.Values, t
226207
return nil
227208
}
228209

229-
func emailFromIDTokenOkta(idToken []string) (string, error) {
230-
231-
// id_token is a base64 encode ID token payload
232-
// https://developers.okta.com/accounts/docs/OAuth2Login#obtainuserinfo
233-
//jwt := strings.Split(idToken, ".")
234-
//b, err := jwtDecodeSegmentOkta(jwt[1])
235-
//if err != nil {
236-
// return "", err
237-
//}
238-
239-
//var email struct {
240-
// Email string `json:"email"`
241-
// EmailVerified bool `json:"email_verified"`
242-
//}
243-
//err = json.Unmarshal(b, &email)
244-
//if err != nil {
245-
// return "", err
246-
//}
247-
//if email.Email == "" {
248-
// return "", errors.New("missing email")
249-
//}
250-
// TESTING: added for test purposes
251-
//added 'b' output as well for debugging
252-
//if !email.EmailVerified {
253-
// return email.Email, nil
254-
// return "", fmt.Errorf("email %s not listed as verified: %s", email.Email, b)
255-
//}
256-
return "test", nil
257-
}
258-
259210
func jwtDecodeSegmentOkta(seg string) ([]byte, error) {
260211
if l := len(seg) % 4; l > 0 {
261212
seg += strings.Repeat("=", 4-l)
@@ -313,7 +264,7 @@ func (p *OktaProvider) Redeem(redirectURL, code string) (*sessions.SessionState,
313264
return nil, err
314265
}
315266
var email string
316-
email, err = p.UserInfo(response.AccessToken)
267+
email, err = p.verifyEmailWithAccessToken(response.AccessToken)
317268
if err != nil {
318269
return nil, err
319270
}
@@ -327,83 +278,68 @@ func (p *OktaProvider) Redeem(redirectURL, code string) (*sessions.SessionState,
327278
}, nil
328279
}
329280

330-
func (p *OktaProvider) UserInfo(access_token string) (string, error) {
331-
if access_token == "" {
281+
func (p *OktaProvider) verifyEmailWithAccessToken(AccessToken string) (string, error) {
282+
if AccessToken == "" {
332283
return "", ErrBadRequest
333284
}
334-
var bearer = "Bearer " + access_token
335-
header := http.Header{}
336-
header.Set("Authorization", bearer)
337285

338-
var response struct {
339-
EmailAddress string `json:"email"`
340-
EmailVerified bool `json:"email_verified"`
341-
}
342-
err := p.oktaRequest("GET", p.UserInfoURL.String(), nil, []string{"tags", "test"}, header, &response)
286+
userinfo, err := p.UserInfo(AccessToken)
343287
if err != nil {
344288
return "", err
345289
}
346-
if response.EmailAddress == "" {
290+
if userinfo.EmailAddress == "" {
347291
return "", errors.New("missing email")
348292
}
349-
if !response.EmailVerified {
350-
return "", fmt.Errorf("email %s not listed as verified", response.EmailAddress)
293+
if !userinfo.EmailVerified {
294+
return "", fmt.Errorf("email %s not listed as verified", userinfo.EmailAddress)
351295
}
352296

353-
return response.EmailAddress, nil
297+
return userinfo.EmailAddress, nil
354298
}
355299

356-
// PopulateMembers is the fill function for the groups cache
357-
//func (p *OktaProvider) PopulateMembers(group string) (groups.MemberSet, error) {
358-
// members, err := p.AdminService.ListMemberships(group, 4)
359-
// if err != nil {
360-
// return nil, err
361-
// }
362-
// memberSet := map[string]struct{}{}
363-
// for _, member := range members {
364-
// memberSet[member] = struct{}{}
365-
// }
366-
// return memberSet, nil
367-
//}
368-
369-
// ValidateGroupMembership takes in an email and the allowed groups and returns the groups that the email is part of in that list.
370-
// If `allGroups` is an empty list, returns an empty list.
371-
//
372-
373-
func (p *OktaProvider) ValidateGroupMembership(email string, allGroups []string) ([]string, error) {
374-
logger := log.NewLogEntry()
375-
376-
groups := []string{}
377-
378-
//var useGroupsResource bool
300+
func (p *OktaProvider) ValidateGroupMembership(email string, allowedGroups []string, accessToken string) ([]string, error) {
301+
if accessToken == "" {
302+
return nil, ErrBadRequest
303+
}
379304

380-
// if `allGroups` is empty, we return an empty list
381-
if len(allGroups) == 0 {
305+
if len(allowedGroups) == 0 {
382306
return []string{}, nil
383307
}
384308

385-
// iterate over the groups, if a set isn't populated only call the GroupsResource once and check all groups
386-
for _, group := range allGroups {
387-
memberSet, ok := p.GroupsCache.Get(group)
388-
if !ok {
389-
//useGroupsResource = true
390-
if started := p.GroupsCache.RefreshLoop(group); started {
391-
logger.WithUserGroup(group).Info(
392-
"no member set cached for group; refresh loops started")
393-
p.StatsdClient.Incr("cache_refresh_loop", []string{"action:profile", fmt.Sprintf("group:%s", group)}, 1.0)
309+
userinfo, err := p.UserInfo(accessToken)
310+
if err != nil {
311+
return nil, err
312+
}
313+
if len(userinfo.Groups) == 0 {
314+
return nil, fmt.Errorf("no groups found")
315+
}
316+
317+
matchingGroups := []string{}
318+
for _, x := range allowedGroups {
319+
for _, y := range userinfo.Groups {
320+
if x == y {
321+
matchingGroups = append(matchingGroups, x)
394322
}
395323
}
396-
if _, exists := memberSet[email]; exists {
397-
groups = append(groups, group)
398-
}
399324
}
325+
return matchingGroups, nil
326+
}
400327

401-
// // if a cached member set was not populated, use the groups resource to get all the groups and filter out the ones that are in `allGroups`
402-
// if useGroupsResource {
403-
// return p.AdminService.CheckMemberships(allGroups, email)
404-
// }
405-
return groups, nil
328+
func (p *OktaProvider) UserInfo(AccessToken string) (*UserInfoResponse, error) {
329+
response := &UserInfoResponse{}
330+
if AccessToken == "" {
331+
return nil, ErrBadRequest
332+
}
333+
var bearer = "Bearer " + AccessToken
334+
header := http.Header{}
335+
header.Set("Authorization", bearer)
336+
337+
err := p.oktaRequest("GET", p.UserInfoURL.String(), nil, []string{"tags", "test"}, header, &response)
338+
if err != nil {
339+
return nil, err
340+
}
406341

342+
return response, nil
407343
}
408344

409345
// RefreshSessionIfNeeded takes in a SessionState and

0 commit comments

Comments
 (0)