Skip to content

Commit 99d5e91

Browse files
authored
Merge pull request #1193 from 99designs/fix-login-with-static-credentials
Fix credential type detection for login
2 parents 5e0a968 + bfa952d commit 99d5e91

8 files changed

+195
-94
lines changed

cli/login.go

Lines changed: 146 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"github.com/aws/aws-sdk-go-v2/aws"
1818
awsconfig "github.com/aws/aws-sdk-go-v2/config"
1919
"github.com/aws/aws-sdk-go-v2/credentials"
20+
"github.com/aws/aws-sdk-go-v2/service/sts"
2021
"github.com/skratchdot/open-golang/open"
2122
)
2223

@@ -74,108 +75,99 @@ func ConfigureLoginCommand(app *kingpin.Application, a *AwsVault) {
7475
return err
7576
}
7677

77-
err = LoginCommand(input, f, keyring)
78+
err = LoginCommand(context.Background(), input, f, keyring)
7879
app.FatalIfError(err, "login")
7980
return nil
8081
})
8182
}
8283

83-
func LoginCommand(input LoginCommandInput, f *vault.ConfigFile, keyring keyring.Keyring) error {
84-
config, err := vault.NewConfigLoader(input.Config, f, input.ProfileName).GetProfileConfig(input.ProfileName)
85-
if err != nil {
86-
return fmt.Errorf("Error loading config: %w", err)
87-
}
88-
89-
var credsProvider aws.CredentialsProvider
90-
84+
func getCredsProvider(input LoginCommandInput, config *vault.ProfileConfig, keyring keyring.Keyring) (credsProvider aws.CredentialsProvider, err error) {
9185
if input.ProfileName == "" {
9286
// When no profile is specified, source credentials from the environment
9387
configFromEnv, err := awsconfig.NewEnvConfig()
9488
if err != nil {
95-
return fmt.Errorf("unable to authenticate to AWS through your environment variables: %w", err)
89+
return nil, fmt.Errorf("unable to authenticate to AWS through your environment variables: %w", err)
9690
}
97-
credsProvider = credentials.StaticCredentialsProvider{Value: configFromEnv.Credentials}
98-
if configFromEnv.Credentials.SessionToken == "" {
99-
credsProvider, err = vault.NewFederationTokenProvider(context.TODO(), credsProvider, config)
100-
if err != nil {
101-
return err
102-
}
91+
92+
if configFromEnv.Credentials.AccessKeyID == "" {
93+
return nil, fmt.Errorf("argument 'profile' not provided, nor any AWS env vars found. Try --help")
10394
}
95+
96+
credsProvider = credentials.StaticCredentialsProvider{Value: configFromEnv.Credentials}
10497
} else {
10598
// Use a profile from the AWS config file
10699
ckr := &vault.CredentialKeyring{Keyring: keyring}
107-
if config.HasRole() || config.HasSSOStartURL() || config.HasCredentialProcess() || config.HasWebIdentity() {
108-
// If AssumeRole or sso.GetRoleCredentials isn't used, GetFederationToken has to be used for IAM credentials
109-
credsProvider, err = vault.NewTempCredentialsProvider(config, ckr, input.NoSession, false)
110-
} else {
111-
credsProvider, err = vault.NewFederationTokenCredentialsProvider(context.TODO(), input.ProfileName, ckr, config)
100+
t := vault.TempCredentialsCreator{
101+
Keyring: ckr,
102+
DisableSessions: input.NoSession,
103+
DisableSessionsForProfile: config.ProfileName,
112104
}
105+
credsProvider, err = t.GetProviderForProfile(config)
113106
if err != nil {
114-
return fmt.Errorf("profile %s: %w", input.ProfileName, err)
107+
return nil, fmt.Errorf("profile %s: %w", input.ProfileName, err)
115108
}
116109
}
117110

118-
creds, err := credsProvider.Retrieve(context.TODO())
111+
return credsProvider, err
112+
}
113+
114+
// LoginCommand creates a login URL for the AWS Management Console using the method described at
115+
// https://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles_providers_enable-console-custom-url.html
116+
func LoginCommand(ctx context.Context, input LoginCommandInput, f *vault.ConfigFile, keyring keyring.Keyring) error {
117+
config, err := vault.NewConfigLoader(input.Config, f, input.ProfileName).GetProfileConfig(input.ProfileName)
119118
if err != nil {
120-
return fmt.Errorf("Failed to get credentials: %w", err)
121-
}
122-
if creds.AccessKeyID == "" && input.ProfileName == "" {
123-
return fmt.Errorf("argument 'profile' not provided, nor any AWS env vars found. Try --help")
119+
return fmt.Errorf("Error loading config: %w", err)
124120
}
125121

126-
jsonBytes, err := json.Marshal(map[string]string{
127-
"sessionId": creds.AccessKeyID,
128-
"sessionKey": creds.SecretAccessKey,
129-
"sessionToken": creds.SessionToken,
130-
})
122+
credsProvider, err := getCredsProvider(input, config, keyring)
131123
if err != nil {
132124
return err
133125
}
134126

135-
loginURLPrefix, destination := generateLoginURL(config.Region, input.Path)
136-
137-
req, err := http.NewRequestWithContext(context.TODO(), "GET", loginURLPrefix, nil)
127+
// if we already know the type of credentials being created, avoid calling isCallerIdentityAssumedRole
128+
canCredsBeUsedInLoginURL, err := canProviderBeUsedForLogin(credsProvider)
138129
if err != nil {
139130
return err
140131
}
141132

142-
if creds.CanExpire {
143-
log.Printf("Creating login token, expires in %s", time.Until(creds.Expires))
144-
}
133+
if !canCredsBeUsedInLoginURL {
134+
// use a static creds provider so that we don't request credentials from AWS more than once
135+
credsProvider, err = createStaticCredentialsProvider(ctx, credsProvider)
136+
if err != nil {
137+
return err
138+
}
145139

146-
q := req.URL.Query()
147-
q.Add("Action", "getSigninToken")
148-
q.Add("Session", string(jsonBytes))
149-
req.URL.RawQuery = q.Encode()
140+
// if the credentials have come from an unknown source like credential_process, check the
141+
// caller identity to see if it's an assumed role
142+
isAssumedRole, err := isCallerIdentityAssumedRole(ctx, credsProvider, config)
143+
if err != nil {
144+
return err
145+
}
150146

151-
resp, err := http.DefaultClient.Do(req)
152-
if err != nil {
153-
return err
147+
if !isAssumedRole {
148+
log.Println("Creating a federated session")
149+
credsProvider, err = vault.NewFederationTokenProvider(ctx, credsProvider, config)
150+
if err != nil {
151+
return err
152+
}
153+
}
154154
}
155155

156-
defer resp.Body.Close()
157-
body, err := io.ReadAll(resp.Body)
156+
creds, err := credsProvider.Retrieve(ctx)
158157
if err != nil {
159158
return err
160159
}
161160

162-
if resp.StatusCode != http.StatusOK {
163-
log.Printf("Response body was %s", body)
164-
return fmt.Errorf("Call to getSigninToken failed with %v", resp.Status)
161+
if creds.CanExpire {
162+
log.Printf("Requesting a signin token for session expiring in %s", time.Until(creds.Expires))
165163
}
166164

167-
var respParsed map[string]string
168-
169-
err = json.Unmarshal(body, &respParsed)
165+
loginURLPrefix, destination := generateLoginURL(config.Region, input.Path)
166+
signinToken, err := requestSigninToken(ctx, creds, loginURLPrefix)
170167
if err != nil {
171168
return err
172169
}
173170

174-
signinToken, ok := respParsed["SigninToken"]
175-
if !ok {
176-
return fmt.Errorf("Expected a response with SigninToken")
177-
}
178-
179171
loginURL := fmt.Sprintf("%s?Action=login&Issuer=aws-vault&Destination=%s&SigninToken=%s",
180172
loginURLPrefix, url.QueryEscape(destination), url.QueryEscape(signinToken))
181173

@@ -212,3 +204,99 @@ func generateLoginURL(region string, path string) (string, string) {
212204
}
213205
return loginURLPrefix, destination
214206
}
207+
208+
func isCallerIdentityAssumedRole(ctx context.Context, credsProvider aws.CredentialsProvider, config *vault.ProfileConfig) (bool, error) {
209+
cfg := vault.NewAwsConfigWithCredsProvider(credsProvider, config.Region, config.STSRegionalEndpoints)
210+
client := sts.NewFromConfig(cfg)
211+
id, err := client.GetCallerIdentity(ctx, nil)
212+
if err != nil {
213+
return false, err
214+
}
215+
arn := aws.ToString(id.Arn)
216+
arnParts := strings.Split(arn, ":")
217+
if len(arnParts) < 6 {
218+
return false, fmt.Errorf("unable to parse ARN: %s", arn)
219+
}
220+
if strings.HasPrefix(arnParts[5], "assumed-role") {
221+
return true, nil
222+
}
223+
return false, nil
224+
}
225+
226+
func createStaticCredentialsProvider(ctx context.Context, credsProvider aws.CredentialsProvider) (sc credentials.StaticCredentialsProvider, err error) {
227+
creds, err := credsProvider.Retrieve(ctx)
228+
if err != nil {
229+
return sc, err
230+
}
231+
return credentials.StaticCredentialsProvider{Value: creds}, nil
232+
}
233+
234+
// canProviderBeUsedForLogin returns true if the credentials produced by the provider is known to be usable by the login URL endpoint
235+
func canProviderBeUsedForLogin(credsProvider aws.CredentialsProvider) (bool, error) {
236+
if _, ok := credsProvider.(*vault.AssumeRoleProvider); ok {
237+
return true, nil
238+
}
239+
if _, ok := credsProvider.(*vault.SSORoleCredentialsProvider); ok {
240+
return true, nil
241+
}
242+
if _, ok := credsProvider.(*vault.AssumeRoleWithWebIdentityProvider); ok {
243+
return true, nil
244+
}
245+
if c, ok := credsProvider.(*vault.CachedSessionProvider); ok {
246+
return canProviderBeUsedForLogin(c.SessionProvider)
247+
}
248+
249+
return false, nil
250+
}
251+
252+
// Create a signin token
253+
func requestSigninToken(ctx context.Context, creds aws.Credentials, loginURLPrefix string) (string, error) {
254+
jsonSession, err := json.Marshal(map[string]string{
255+
"sessionId": creds.AccessKeyID,
256+
"sessionKey": creds.SecretAccessKey,
257+
"sessionToken": creds.SessionToken,
258+
})
259+
if err != nil {
260+
return "", err
261+
}
262+
263+
req, err := http.NewRequestWithContext(ctx, "GET", loginURLPrefix, nil)
264+
if err != nil {
265+
return "", err
266+
}
267+
268+
q := req.URL.Query()
269+
q.Add("Action", "getSigninToken")
270+
q.Add("Session", string(jsonSession))
271+
req.URL.RawQuery = q.Encode()
272+
273+
resp, err := http.DefaultClient.Do(req)
274+
if err != nil {
275+
return "", err
276+
}
277+
278+
defer resp.Body.Close()
279+
body, err := io.ReadAll(resp.Body)
280+
if err != nil {
281+
return "", err
282+
}
283+
284+
if resp.StatusCode != http.StatusOK {
285+
log.Printf("Response body was %s", body)
286+
return "", fmt.Errorf("Call to getSigninToken failed with %v", resp.Status)
287+
}
288+
289+
var respParsed map[string]string
290+
291+
err = json.Unmarshal(body, &respParsed)
292+
if err != nil {
293+
return "", err
294+
}
295+
296+
signinToken, ok := respParsed["SigninToken"]
297+
if !ok {
298+
return "", fmt.Errorf("Expected a response with SigninToken")
299+
}
300+
301+
return signinToken, nil
302+
}

vault/assumeroleprovider.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ type AssumeRoleProvider struct {
2626

2727
// Retrieve generates a new set of temporary credentials using STS AssumeRole
2828
func (p *AssumeRoleProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
29-
role, err := p.assumeRole(ctx)
29+
role, err := p.RetrieveStsCredentials(ctx)
3030
if err != nil {
3131
return aws.Credentials{}, err
3232
}
@@ -49,7 +49,7 @@ func (p *AssumeRoleProvider) roleSessionName() string {
4949
return p.RoleSessionName
5050
}
5151

52-
func (p *AssumeRoleProvider) assumeRole(ctx context.Context) (*ststypes.Credentials, error) {
52+
func (p *AssumeRoleProvider) RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error) {
5353
var err error
5454

5555
input := &sts.AssumeRoleInput{

vault/assumerolewithwebidentityprovider.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ type AssumeRoleWithWebIdentityProvider struct {
2525

2626
// Retrieve generates a new set of temporary credentials using STS AssumeRoleWithWebIdentity
2727
func (p *AssumeRoleWithWebIdentityProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
28-
creds, err := p.assumeRole(ctx)
28+
creds, err := p.RetrieveStsCredentials(ctx)
2929
if err != nil {
3030
return aws.Credentials{}, err
3131
}
@@ -48,7 +48,7 @@ func (p *AssumeRoleWithWebIdentityProvider) roleSessionName() string {
4848
return p.RoleSessionName
4949
}
5050

51-
func (p *AssumeRoleWithWebIdentityProvider) assumeRole(ctx context.Context) (*ststypes.Credentials, error) {
51+
func (p *AssumeRoleWithWebIdentityProvider) RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error) {
5252
var err error
5353

5454
webIdentityToken, err := p.webIdentityToken()

vault/cachedsessionprovider.go

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,34 +9,48 @@ import (
99
ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types"
1010
)
1111

12+
type StsSessionProvider interface {
13+
aws.CredentialsProvider
14+
RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error)
15+
}
16+
1217
// CachedSessionProvider retrieves cached credentials from the keyring, or if no credentials are cached
1318
// retrieves temporary credentials using the CredentialsFunc
1419
type CachedSessionProvider struct {
1520
SessionKey SessionMetadata
16-
CredentialsFunc func(context.Context) (*ststypes.Credentials, error)
21+
SessionProvider StsSessionProvider
1722
Keyring *SessionKeyring
1823
ExpiryWindow time.Duration
1924
}
2025

21-
// Retrieve returns cached credentials from the keyring, or if no credentials are cached
22-
// generates a new set of temporary credentials using the CredentialsFunc
23-
func (p *CachedSessionProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
26+
func (p *CachedSessionProvider) RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error) {
2427
creds, err := p.Keyring.Get(p.SessionKey)
2528

2629
if err != nil || time.Until(*creds.Expiration) < p.ExpiryWindow {
2730
// lookup missed, we need to create a new one.
28-
creds, err = p.CredentialsFunc(ctx)
31+
creds, err = p.SessionProvider.RetrieveStsCredentials(ctx)
2932
if err != nil {
30-
return aws.Credentials{}, err
33+
return nil, err
3134
}
3235
err = p.Keyring.Set(p.SessionKey, creds)
3336
if err != nil {
34-
return aws.Credentials{}, err
37+
return nil, err
3538
}
3639
} else {
3740
log.Printf("Re-using cached credentials %s from %s, expires in %s", FormatKeyForDisplay(*creds.AccessKeyId), p.SessionKey.Type, time.Until(*creds.Expiration).String())
3841
}
3942

43+
return creds, nil
44+
}
45+
46+
// Retrieve returns cached credentials from the keyring, or if no credentials are cached
47+
// generates a new set of temporary credentials using the CredentialsFunc
48+
func (p *CachedSessionProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
49+
creds, err := p.RetrieveStsCredentials(ctx)
50+
if err != nil {
51+
return aws.Credentials{}, err
52+
}
53+
4054
return aws.Credentials{
4155
AccessKeyID: aws.ToString(creds.AccessKeyId),
4256
SecretAccessKey: aws.ToString(creds.SecretAccessKey),

vault/credentialprocessprovider.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ func (p *CredentialProcessProvider) retrieveWith(ctx context.Context, fn func(st
5555
}, nil
5656
}
5757

58-
func (p *CredentialProcessProvider) callCredentialProcess(ctx context.Context) (*ststypes.Credentials, error) {
58+
func (p *CredentialProcessProvider) RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error) {
5959
return p.callCredentialProcessWith(ctx, executeProcess)
6060
}
6161

vault/sessiontokenprovider.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ type SessionTokenProvider struct {
1919

2020
// Retrieve generates a new set of temporary credentials using STS GetSessionToken
2121
func (p *SessionTokenProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
22-
creds, err := p.GetSessionToken(ctx)
22+
creds, err := p.RetrieveStsCredentials(ctx)
2323
if err != nil {
2424
return aws.Credentials{}, err
2525
}
@@ -34,7 +34,7 @@ func (p *SessionTokenProvider) Retrieve(ctx context.Context) (aws.Credentials, e
3434
}
3535

3636
// GetSessionToken generates a new set of temporary credentials using STS GetSessionToken
37-
func (p *SessionTokenProvider) GetSessionToken(ctx context.Context) (*ststypes.Credentials, error) {
37+
func (p *SessionTokenProvider) RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error) {
3838
var err error
3939

4040
input := &sts.GetSessionTokenInput{

vault/ssorolecredentialsprovider.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ func (p *SSORoleCredentialsProvider) getRoleCredentials(ctx context.Context) (*s
9393
return resp.RoleCredentials, nil
9494
}
9595

96+
func (p *SSORoleCredentialsProvider) RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error) {
97+
return p.getRoleCredentialsAsStsCredemtials(ctx)
98+
}
99+
96100
// getRoleCredentialsAsStsCredemtials returns getRoleCredentials as sts.Credentials because sessions.Store expects it
97101
func (p *SSORoleCredentialsProvider) getRoleCredentialsAsStsCredemtials(ctx context.Context) (*ststypes.Credentials, error) {
98102
creds, err := p.getRoleCredentials(ctx)

0 commit comments

Comments
 (0)