Skip to content

Commit 6ebe3fb

Browse files
authored
Merge pull request #1196 from 99designs/fix-source-profile
Prioritise source_profile over sso config
2 parents 99d5e91 + 9dc5bca commit 6ebe3fb

File tree

2 files changed

+102
-29
lines changed

2 files changed

+102
-29
lines changed

vault/vault.go

+42-29
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ func FormatKeyForDisplay(k string) string {
4141
return fmt.Sprintf("****************%s", k[len(k)-4:])
4242
}
4343

44+
func isMasterCredentialsProvider(credsProvider aws.CredentialsProvider) bool {
45+
_, ok := credsProvider.(*KeyringProvider)
46+
return ok
47+
}
48+
4449
// NewMasterCredentialsProvider creates a provider for the master credentials
4550
func NewMasterCredentialsProvider(k *CredentialKeyring, credentialsName string) *KeyringProvider {
4651
return &KeyringProvider{k, credentialsName}
@@ -243,52 +248,60 @@ func (t *TempCredentialsCreator) getSourceCreds(config *ProfileConfig, hasStored
243248
return nil, fmt.Errorf("profile %s: credentials missing", config.ProfileName)
244249
}
245250

246-
func (t *TempCredentialsCreator) GetProviderForProfile(config *ProfileConfig) (aws.CredentialsProvider, error) {
247-
hasStoredCredentials, err := t.Keyring.Has(config.ProfileName)
251+
func (t *TempCredentialsCreator) getSourceCredWithSession(config *ProfileConfig, hasStoredCredentials bool) (sourcecredsProvider aws.CredentialsProvider, err error) {
252+
sourcecredsProvider, err = t.getSourceCreds(config, hasStoredCredentials)
248253
if err != nil {
249254
return nil, err
250255
}
251256

252-
if !hasStoredCredentials {
253-
if config.HasSSOStartURL() {
254-
log.Printf("profile %s: using SSO role credentials", config.ProfileName)
255-
return NewSSORoleCredentialsProvider(t.Keyring.Keyring, config, !t.DisableCache)
256-
}
257-
258-
if config.HasWebIdentity() {
259-
log.Printf("profile %s: using web identity", config.ProfileName)
260-
return NewAssumeRoleWithWebIdentityProvider(t.Keyring.Keyring, config, !t.DisableCache)
257+
if config.HasRole() {
258+
isMfaChained := config.MfaSerial != "" && config.MfaSerial == t.chainedMfa
259+
if isMfaChained {
260+
config.MfaSerial = ""
261261
}
262+
log.Printf("profile %s: using AssumeRole %s", config.ProfileName, mfaDetails(isMfaChained, config))
263+
return NewAssumeRoleProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache)
264+
}
262265

263-
if config.HasCredentialProcess() {
264-
log.Printf("profile %s: using credential process", config.ProfileName)
265-
return NewCredentialProcessProvider(t.Keyring.Keyring, config, !t.DisableCache)
266+
if isMasterCredentialsProvider(sourcecredsProvider) {
267+
canUseGetSessionToken, reason := t.canUseGetSessionToken(config)
268+
if canUseGetSessionToken {
269+
t.chainedMfa = config.MfaSerial
270+
log.Printf("profile %s: using GetSessionToken %s", config.ProfileName, mfaDetails(false, config))
271+
return NewSessionTokenProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache)
266272
}
273+
log.Printf("profile %s: skipping GetSessionToken because %s", config.ProfileName, reason)
267274
}
268275

269-
sourcecredsProvider, err := t.getSourceCreds(config, hasStoredCredentials)
276+
return sourcecredsProvider, nil
277+
}
278+
279+
func (t *TempCredentialsCreator) GetProviderForProfile(config *ProfileConfig) (aws.CredentialsProvider, error) {
280+
hasStoredCredentials, err := t.Keyring.Has(config.ProfileName)
270281
if err != nil {
271282
return nil, err
272283
}
273284

274-
if config.HasRole() {
275-
isMfaChained := config.MfaSerial != "" && config.MfaSerial == t.chainedMfa
276-
if isMfaChained {
277-
config.MfaSerial = ""
278-
}
279-
log.Printf("profile %s: using AssumeRole %s", config.ProfileName, mfaDetails(isMfaChained, config))
280-
return NewAssumeRoleProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache)
285+
if hasStoredCredentials || config.HasSourceProfile() {
286+
return t.getSourceCredWithSession(config, hasStoredCredentials)
281287
}
282288

283-
canUseGetSessionToken, reason := t.canUseGetSessionToken(config)
284-
if canUseGetSessionToken {
285-
t.chainedMfa = config.MfaSerial
286-
log.Printf("profile %s: using GetSessionToken %s", config.ProfileName, mfaDetails(false, config))
287-
return NewSessionTokenProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache)
289+
if config.HasSSOStartURL() {
290+
log.Printf("profile %s: using SSO role credentials", config.ProfileName)
291+
return NewSSORoleCredentialsProvider(t.Keyring.Keyring, config, !t.DisableCache)
288292
}
289293

290-
log.Printf("profile %s: skipping GetSessionToken because %s", config.ProfileName, reason)
291-
return sourcecredsProvider, nil
294+
if config.HasWebIdentity() {
295+
log.Printf("profile %s: using web identity", config.ProfileName)
296+
return NewAssumeRoleWithWebIdentityProvider(t.Keyring.Keyring, config, !t.DisableCache)
297+
}
298+
299+
if config.HasCredentialProcess() {
300+
log.Printf("profile %s: using credential process", config.ProfileName)
301+
return NewCredentialProcessProvider(t.Keyring.Keyring, config, !t.DisableCache)
302+
}
303+
304+
return nil, fmt.Errorf("profile %s: credentials missing", config.ProfileName)
292305
}
293306

294307
// canUseGetSessionToken determines if GetSessionToken should be used, and if not returns a reason

vault/vault_test.go

+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package vault_test
2+
3+
import (
4+
"os"
5+
"testing"
6+
7+
"github.com/99designs/aws-vault/v7/vault"
8+
"github.com/99designs/keyring"
9+
)
10+
11+
func TestIssue1195(t *testing.T) {
12+
f := newConfigFile(t, []byte(`
13+
[profile test]
14+
source_profile=dev
15+
region=ap-northeast-2
16+
17+
[profile dev]
18+
sso_session=common
19+
sso_account_id=2160xxxx
20+
sso_role_name=AdministratorAccess
21+
region=ap-northeast-2
22+
output=json
23+
24+
[default]
25+
sso_session=common
26+
sso_account_id=3701xxxx
27+
sso_role_name=AdministratorAccess
28+
region=ap-northeast-2
29+
output=json
30+
31+
[sso-session common]
32+
sso_start_url=https://xxxx.awsapps.com/start
33+
sso_region=ap-northeast-2
34+
sso_registration_scopes=sso:account:access
35+
`))
36+
defer os.Remove(f)
37+
configFile, err := vault.LoadConfig(f)
38+
if err != nil {
39+
t.Fatal(err)
40+
}
41+
configLoader := &vault.ConfigLoader{File: configFile, ActiveProfile: "test"}
42+
config, err := configLoader.GetProfileConfig("test")
43+
if err != nil {
44+
t.Fatalf("Should have found a profile: %v", err)
45+
}
46+
47+
ckr := &vault.CredentialKeyring{Keyring: keyring.NewArrayKeyring([]keyring.Item{})}
48+
p, err := vault.NewTempCredentialsProvider(config, ckr, true, true)
49+
if err != nil {
50+
t.Fatal(err)
51+
}
52+
53+
ssoProvider, ok := p.(*vault.SSORoleCredentialsProvider)
54+
if !ok {
55+
t.Fatalf("Expected SSORoleCredentialsProvider, got %T", p)
56+
}
57+
if ssoProvider.AccountID != "2160xxxx" {
58+
t.Fatalf("Expected AccountID to be 2160xxxx, got %s", ssoProvider.AccountID)
59+
}
60+
}

0 commit comments

Comments
 (0)