-
Notifications
You must be signed in to change notification settings - Fork 830
Implement OAuth2 PKCE in SSORoleCredentialsProvider #1258
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -2,10 +2,17 @@ package vault | |||||||
|
||||||||
import ( | ||||||||
"context" | ||||||||
crand "crypto/rand" | ||||||||
"crypto/sha256" | ||||||||
"crypto/subtle" | ||||||||
"encoding/base64" | ||||||||
"errors" | ||||||||
"fmt" | ||||||||
"io" | ||||||||
"log" | ||||||||
"net" | ||||||||
"net/http" | ||||||||
"net/url" | ||||||||
"os" | ||||||||
"time" | ||||||||
|
||||||||
|
@@ -122,7 +129,19 @@ func (p *SSORoleCredentialsProvider) getOIDCToken(ctx context.Context) (token *s | |||||||
return token, true, nil | ||||||||
} | ||||||||
} | ||||||||
token, err = p.newOIDCToken(ctx) | ||||||||
|
||||||||
// if we must use stdout (either by user choice or because we have determined we are in an SSH session where we cannot open a browser) | ||||||||
// then we use the "device code" grant flow, and print URLs to stdout for the user to manually copy/paste into their browser. | ||||||||
// | ||||||||
// Otherwise we use the "authorization code" grant flow with Proof Key for | ||||||||
// Code Exchange (PKCE) and open the browser automatically. The latter flow | ||||||||
// is more user-friendly and secure because the step comparing the challenge | ||||||||
// code between the CLI and browser is automated entirely. | ||||||||
if p.UseStdout { | ||||||||
token, err = p.newOIDCToken(ctx) | ||||||||
} else { | ||||||||
token, err = p.newOIDCTokenPKCE(ctx) | ||||||||
} | ||||||||
if err != nil { | ||||||||
return nil, false, err | ||||||||
} | ||||||||
|
@@ -155,16 +174,7 @@ func (p *SSORoleCredentialsProvider) newOIDCToken(ctx context.Context) (*ssooidc | |||||||
return nil, err | ||||||||
} | ||||||||
log.Printf("Created OIDC device code for %s (expires in: %ds)", p.StartURL, deviceCreds.ExpiresIn) | ||||||||
|
||||||||
if p.UseStdout { | ||||||||
fmt.Fprintf(os.Stderr, "Open the SSO authorization page in a browser (use Ctrl-C to abort)\n%s\n", aws.ToString(deviceCreds.VerificationUriComplete)) | ||||||||
} else { | ||||||||
log.Println("Opening SSO authorization page in browser") | ||||||||
fmt.Fprintf(os.Stderr, "Opening the SSO authorization page in your default browser (use Ctrl-C to abort)\n%s\n", aws.ToString(deviceCreds.VerificationUriComplete)) | ||||||||
if err := open.Run(aws.ToString(deviceCreds.VerificationUriComplete)); err != nil { | ||||||||
log.Printf("Failed to open browser: %s", err) | ||||||||
} | ||||||||
} | ||||||||
fmt.Fprintf(os.Stderr, "Open the SSO authorization page in a browser (use Ctrl-C to abort)\n%s\n", aws.ToString(deviceCreds.VerificationUriComplete)) | ||||||||
|
||||||||
// These are the default values defined in the following RFC: | ||||||||
// https://tools.ietf.org/html/draft-ietf-oauth-device-flow-15#section-3.5 | ||||||||
|
@@ -201,3 +211,191 @@ func (p *SSORoleCredentialsProvider) newOIDCToken(ctx context.Context) (*ssooidc | |||||||
return t, nil | ||||||||
} | ||||||||
} | ||||||||
|
||||||||
// newOIDCTokenPKCE generates a new OIDC token using the "Authorization Code Grant" flow with PKCE. | ||||||||
func (p *SSORoleCredentialsProvider) newOIDCTokenPKCE(ctx context.Context) (*ssooidc.CreateTokenOutput, error) { | ||||||||
// ref: https://datatracker.ietf.org/doc/html/rfc7636 | ||||||||
|
||||||||
// generate a random 32 byte code verifier; base64 encode it | ||||||||
codeVerifierBytes := make([]byte, 32) | ||||||||
n, err := crand.Read(codeVerifierBytes) | ||||||||
if err != nil || n != 32 { | ||||||||
return nil, fmt.Errorf("failed to generate PKCE verifier: %w", err) | ||||||||
} | ||||||||
codeVerifier := base64.RawURLEncoding.EncodeToString(codeVerifierBytes) | ||||||||
|
||||||||
// generate the code challenge: base64(sha256(codeVerifier)) | ||||||||
codeChallengeBytes := sha256.Sum256([]byte(codeVerifier)) | ||||||||
codeChallenge := base64.RawURLEncoding.EncodeToString(codeChallengeBytes[:]) | ||||||||
log.Printf("Generated PKCE code_challenge: %q", codeChallenge) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think this necessarily needs to be displayed by default There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was confused at first also at the use of Lines 166 to 168 in d4706c8
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah, good point |
||||||||
|
||||||||
clientCreds, err := p.OIDCClient.RegisterClient(ctx, &ssooidc.RegisterClientInput{ | ||||||||
ClientName: aws.String("aws-vault"), | ||||||||
ClientType: aws.String("public"), | ||||||||
GrantTypes: []string{"authorization_code", "refresh_token"}, | ||||||||
Scopes: []string{"sso:account:access"}, | ||||||||
IssuerUrl: aws.String(p.StartURL), | ||||||||
RedirectUris: []string{"http://127.0.0.1/oauth/callback"}, | ||||||||
}) | ||||||||
if err != nil { | ||||||||
return nil, err | ||||||||
} | ||||||||
log.Printf("Created new OIDC client (expires at: %s)", time.Unix(clientCreds.ClientSecretExpiresAt, 0)) | ||||||||
|
||||||||
// start the callback server | ||||||||
cbServer, err := newOauthCallbackServer() | ||||||||
if err != nil { | ||||||||
return nil, fmt.Errorf("failed to create oauthCallbackServer: %w", err) | ||||||||
} | ||||||||
log.Printf("oauthCallbackServer callback endpoint: %s", cbServer.redirectURI()) | ||||||||
go func() { | ||||||||
if err := cbServer.Serve(); err != nil && !errors.Is(err, http.ErrServerClosed) { | ||||||||
log.Printf("Failed to run oauthCallbackServer: %s", err) | ||||||||
} | ||||||||
}() | ||||||||
// keep a copy of the redirectURI for the CreateToken call (as the server will be closed after the code is received) | ||||||||
redirectURI := cbServer.redirectURI() | ||||||||
|
||||||||
// construct the authorize URL with the client and PKCE parameters | ||||||||
args := url.Values{ | ||||||||
"client_id": {aws.ToString(clientCreds.ClientId)}, | ||||||||
"response_type": {"code"}, | ||||||||
"redirect_uri": {redirectURI}, | ||||||||
"state": {cbServer.state}, | ||||||||
"code_challenge_method": {"S256"}, | ||||||||
"code_challenge": {codeChallenge}, | ||||||||
"scopes": {"sso:account:access"}, | ||||||||
} | ||||||||
// prefer the base endpoint from client options, otherwise use a default | ||||||||
var host string | ||||||||
if p.OIDCClient.Options().BaseEndpoint != nil && *p.OIDCClient.Options().BaseEndpoint != "" { | ||||||||
host = *p.OIDCClient.Options().BaseEndpoint | ||||||||
} else { | ||||||||
host = "oidc.us-east-1.amazonaws.com" | ||||||||
} | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hardcoding the region to To get this to work I used |
||||||||
authorizeURL := url.URL{ | ||||||||
Scheme: "https", | ||||||||
Host: host, | ||||||||
Path: "/authorize", | ||||||||
RawQuery: args.Encode(), | ||||||||
} | ||||||||
log.Printf("Authorize URL: %s", authorizeURL.String()) | ||||||||
|
||||||||
// redirect user to the authorize URL | ||||||||
log.Println("Opening SSO authorization page in browser") | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (related to my first comment of this PR, I think this should rather honor the |
||||||||
fmt.Fprintf(os.Stderr, "Opening the SSO authorization page in your default browser (use Ctrl-C to abort)\n%s\n", authorizeURL.String()) | ||||||||
if err := open.Run(authorizeURL.String()); err != nil { | ||||||||
log.Printf("Failed to open browser: %s", err) | ||||||||
} | ||||||||
|
||||||||
// await the authorization code from the callback server once the user has completed the flow. | ||||||||
var code string | ||||||||
timeout := time.After(1 * time.Minute) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FWIW the device code implementation here waits indefinitely (modulo a timeout on the AWS side), so perhaps it's worth increasing it? |
||||||||
select { | ||||||||
case <-timeout: | ||||||||
return nil, errors.New("timed out waiting for authorization code") | ||||||||
case code = <-cbServer.code: | ||||||||
log.Printf("Received authorization code: %s", code) | ||||||||
if err := cbServer.h.Close(); err != nil { | ||||||||
log.Printf("Failed to close oauthCallbackServer: %s", err) | ||||||||
} | ||||||||
} | ||||||||
|
||||||||
// create the OIDC token using the authorization code received from the callback server | ||||||||
tok, err := p.OIDCClient.CreateToken(ctx, &ssooidc.CreateTokenInput{ | ||||||||
ClientId: clientCreds.ClientId, | ||||||||
ClientSecret: clientCreds.ClientSecret, | ||||||||
Code: aws.String(code), | ||||||||
CodeVerifier: aws.String(codeVerifier), | ||||||||
GrantType: aws.String("authorization_code"), | ||||||||
RedirectUri: aws.String(redirectURI), | ||||||||
}) | ||||||||
if err != nil { | ||||||||
return nil, err | ||||||||
} | ||||||||
|
||||||||
log.Printf("Created new OIDC access token for %s (expires in: %ds)", p.StartURL, tok.ExpiresIn) | ||||||||
return tok, nil | ||||||||
} | ||||||||
|
||||||||
// newOauthCallbackServer creates a HTTP server listening on a random localhost | ||||||||
// port to serve the OAuth2 callback. It serves a single oauth callback endpoint | ||||||||
// and sends the authorization code received via a channel. | ||||||||
func newOauthCallbackServer() (*oauthCallbackServer, error) { | ||||||||
// select a random port for the callback server | ||||||||
ln, err := net.Listen("tcp", ":0") | ||||||||
if err != nil { | ||||||||
return nil, fmt.Errorf("failed to create listener: %w", err) | ||||||||
} | ||||||||
log.Printf("oauthCallbackListener listening on %s", ln.Addr().String()) | ||||||||
|
||||||||
// create a 32 byte state for CSRF protection | ||||||||
state := make([]byte, 32) | ||||||||
n, err := crand.Read(state) | ||||||||
if err != nil || n != 32 { | ||||||||
return nil, fmt.Errorf("failed to generate state: %w", err) | ||||||||
} | ||||||||
|
||||||||
oauth := &oauthCallbackServer{ | ||||||||
state: base64.RawURLEncoding.EncodeToString(state), | ||||||||
code: make(chan string), | ||||||||
ln: ln, | ||||||||
} | ||||||||
oauth.h = &http.Server{ | ||||||||
Handler: http.HandlerFunc(oauth.handleCallback), | ||||||||
} | ||||||||
|
||||||||
return oauth, nil | ||||||||
} | ||||||||
|
||||||||
// handleCallback handles the OAuth2 callback request and sends the authorization code to the server channel. | ||||||||
func (s *oauthCallbackServer) handleCallback(w http.ResponseWriter, r *http.Request) { | ||||||||
// only respond to GET requests on the callback | ||||||||
if r.Method != http.MethodGet { | ||||||||
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) | ||||||||
return | ||||||||
} | ||||||||
if r.URL.Path != "/oauth/callback" { | ||||||||
http.Error(w, "Not Found", http.StatusNotFound) | ||||||||
return | ||||||||
} | ||||||||
|
||||||||
// constant time string comparison of want vs got state | ||||||||
state := r.URL.Query().Get("state") | ||||||||
if subtle.ConstantTimeCompare([]byte(state), []byte(s.state)) != 1 { | ||||||||
http.Error(w, "Invalid state", http.StatusBadRequest) | ||||||||
return | ||||||||
} | ||||||||
|
||||||||
// send the authorization code to the channel | ||||||||
code := r.URL.Query().Get("code") | ||||||||
s.code <- code | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. perhaps it's worth returning early if there's an error rather than letting the consumer timeout? In which case s.code could be a channel of a struct that can hold either an code, err := <- server.code
if err != nil {
// error
} |
||||||||
|
||||||||
// respond with a success message | ||||||||
io.WriteString(w, "Authorization code received, you can close this tab now.") | ||||||||
} | ||||||||
|
||||||||
// redirectURI returns the URL for the OAuth callback endpoint with the server's port included in the address. | ||||||||
func (s *oauthCallbackServer) redirectURI() string { | ||||||||
// AWS requires that the callback be a 127.0.0.1 v4 address | ||||||||
u := url.URL{ | ||||||||
Scheme: "http", | ||||||||
Host: fmt.Sprintf("127.0.0.1:%d", s.ln.Addr().(*net.TCPAddr).Port), | ||||||||
Path: "/oauth/callback", | ||||||||
} | ||||||||
return u.String() | ||||||||
} | ||||||||
|
||||||||
type oauthCallbackServer struct { | ||||||||
ln net.Listener | ||||||||
h *http.Server | ||||||||
|
||||||||
// secret used to prevent CSRF attacks | ||||||||
state string | ||||||||
// channel to send authorization code after successful callback | ||||||||
code chan string | ||||||||
} | ||||||||
|
||||||||
func (s *oauthCallbackServer) Serve() error { | ||||||||
return s.h.Serve(s.ln) | ||||||||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -141,13 +141,20 @@ func NewAssumeRoleWithWebIdentityProvider(k keyring.Keyring, config *ProfileConf | |
func NewSSORoleCredentialsProvider(k keyring.Keyring, config *ProfileConfig, useSessionCache bool) (aws.CredentialsProvider, error) { | ||
cfg := NewAwsConfig(config.SSORegion, config.STSRegionalEndpoints) | ||
|
||
// If we're in an SSH session, we can't use the browser for SSO so we print | ||
// the URLs to stdout instead. | ||
useStdout := config.SSOUseStdout | ||
if os.Getenv("SSH_CONNECTION") != "" { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this seems like an addition that's independant from the PKCE flow? Overall I wouldn't necessarily expect this "magic" behavior from aws-vault, as a user I prefer a consistent behavior and explicitely passing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Happy to back this out and have it be a separate discussion. My thinking was that given we now have login flows that cater to browser and non-browser setups to direct the user to the one that makes sense in context but I can understand how that would be confusing. A number of us at Tailscale use |
||
useStdout = true | ||
} | ||
|
||
ssoRoleCredentialsProvider := &SSORoleCredentialsProvider{ | ||
OIDCClient: ssooidc.NewFromConfig(cfg), | ||
StartURL: config.SSOStartURL, | ||
SSOClient: sso.NewFromConfig(cfg), | ||
AccountID: config.SSOAccountID, | ||
RoleName: config.SSORoleName, | ||
UseStdout: config.SSOUseStdout, | ||
UseStdout: useStdout, | ||
} | ||
|
||
if useSessionCache { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I commonly use
--stdout
even when I'm on the same machine as whereaws-vault
is running, for various purposes (e.g. wanting to login in a non-default browser). I'm not 100% convinced thatUseStdout
should be the basis for chosing between the device code and PKCE flow, how about instead following the same behavior as the AWS CLI? It seems to have a flag--use-device-code
flag aws/aws-cli@130005a#diff-e07a10a6eb1a677e905b0498651e672137b5dff19d357933bd7bc3e36f845a3bL42 which defaults to falseThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great shout. I hadn't considered the same device non-default browser path that users might wish to take. I was primarily thinking about remote workstations and in that context printing the URL out to be opened locally would not make sense as the eventual redirect destination would be inaccessible to the browser.