@@ -17,6 +17,7 @@ import (
17
17
"github.com/aws/aws-sdk-go-v2/aws"
18
18
awsconfig "github.com/aws/aws-sdk-go-v2/config"
19
19
"github.com/aws/aws-sdk-go-v2/credentials"
20
+ "github.com/aws/aws-sdk-go-v2/service/sts"
20
21
"github.com/skratchdot/open-golang/open"
21
22
)
22
23
@@ -74,108 +75,99 @@ func ConfigureLoginCommand(app *kingpin.Application, a *AwsVault) {
74
75
return err
75
76
}
76
77
77
- err = LoginCommand (input , f , keyring )
78
+ err = LoginCommand (context . Background (), input , f , keyring )
78
79
app .FatalIfError (err , "login" )
79
80
return nil
80
81
})
81
82
}
82
83
83
- func LoginCommand (input LoginCommandInput , f * vault.ConfigFile , keyring keyring.Keyring ) error {
84
- config , err := vault .NewConfigLoader (input .Config , f , input .ProfileName ).GetProfileConfig (input .ProfileName )
85
- if err != nil {
86
- return fmt .Errorf ("Error loading config: %w" , err )
87
- }
88
-
89
- var credsProvider aws.CredentialsProvider
90
-
84
+ func getCredsProvider (input LoginCommandInput , config * vault.ProfileConfig , keyring keyring.Keyring ) (credsProvider aws.CredentialsProvider , err error ) {
91
85
if input .ProfileName == "" {
92
86
// When no profile is specified, source credentials from the environment
93
87
configFromEnv , err := awsconfig .NewEnvConfig ()
94
88
if err != nil {
95
- return fmt .Errorf ("unable to authenticate to AWS through your environment variables: %w" , err )
89
+ return nil , fmt .Errorf ("unable to authenticate to AWS through your environment variables: %w" , err )
96
90
}
97
- credsProvider = credentials.StaticCredentialsProvider {Value : configFromEnv .Credentials }
98
- if configFromEnv .Credentials .SessionToken == "" {
99
- credsProvider , err = vault .NewFederationTokenProvider (context .TODO (), credsProvider , config )
100
- if err != nil {
101
- return err
102
- }
91
+
92
+ if configFromEnv .Credentials .AccessKeyID == "" {
93
+ return nil , fmt .Errorf ("argument 'profile' not provided, nor any AWS env vars found. Try --help" )
103
94
}
95
+
96
+ credsProvider = credentials.StaticCredentialsProvider {Value : configFromEnv .Credentials }
104
97
} else {
105
98
// Use a profile from the AWS config file
106
99
ckr := & vault.CredentialKeyring {Keyring : keyring }
107
- if config .HasRole () || config .HasSSOStartURL () || config .HasCredentialProcess () || config .HasWebIdentity () {
108
- // If AssumeRole or sso.GetRoleCredentials isn't used, GetFederationToken has to be used for IAM credentials
109
- credsProvider , err = vault .NewTempCredentialsProvider (config , ckr , input .NoSession , false )
110
- } else {
111
- credsProvider , err = vault .NewFederationTokenCredentialsProvider (context .TODO (), input .ProfileName , ckr , config )
100
+ t := vault.TempCredentialsCreator {
101
+ Keyring : ckr ,
102
+ DisableSessions : input .NoSession ,
103
+ DisableSessionsForProfile : config .ProfileName ,
112
104
}
105
+ credsProvider , err = t .GetProviderForProfile (config )
113
106
if err != nil {
114
- return fmt .Errorf ("profile %s: %w" , input .ProfileName , err )
107
+ return nil , fmt .Errorf ("profile %s: %w" , input .ProfileName , err )
115
108
}
116
109
}
117
110
118
- creds , err := credsProvider .Retrieve (context .TODO ())
111
+ return credsProvider , err
112
+ }
113
+
114
+ // LoginCommand creates a login URL for the AWS Management Console using the method described at
115
+ // https://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles_providers_enable-console-custom-url.html
116
+ func LoginCommand (ctx context.Context , input LoginCommandInput , f * vault.ConfigFile , keyring keyring.Keyring ) error {
117
+ config , err := vault .NewConfigLoader (input .Config , f , input .ProfileName ).GetProfileConfig (input .ProfileName )
119
118
if err != nil {
120
- return fmt .Errorf ("Failed to get credentials: %w" , err )
121
- }
122
- if creds .AccessKeyID == "" && input .ProfileName == "" {
123
- return fmt .Errorf ("argument 'profile' not provided, nor any AWS env vars found. Try --help" )
119
+ return fmt .Errorf ("Error loading config: %w" , err )
124
120
}
125
121
126
- jsonBytes , err := json .Marshal (map [string ]string {
127
- "sessionId" : creds .AccessKeyID ,
128
- "sessionKey" : creds .SecretAccessKey ,
129
- "sessionToken" : creds .SessionToken ,
130
- })
122
+ credsProvider , err := getCredsProvider (input , config , keyring )
131
123
if err != nil {
132
124
return err
133
125
}
134
126
135
- loginURLPrefix , destination := generateLoginURL (config .Region , input .Path )
136
-
137
- req , err := http .NewRequestWithContext (context .TODO (), "GET" , loginURLPrefix , nil )
127
+ // if we already know the type of credentials being created, avoid calling isCallerIdentityAssumedRole
128
+ canCredsBeUsedInLoginURL , err := canProviderBeUsedForLogin (credsProvider )
138
129
if err != nil {
139
130
return err
140
131
}
141
132
142
- if creds .CanExpire {
143
- log .Printf ("Creating login token, expires in %s" , time .Until (creds .Expires ))
144
- }
133
+ if ! canCredsBeUsedInLoginURL {
134
+ // use a static creds provider so that we don't request credentials from AWS more than once
135
+ credsProvider , err = createStaticCredentialsProvider (ctx , credsProvider )
136
+ if err != nil {
137
+ return err
138
+ }
145
139
146
- q := req .URL .Query ()
147
- q .Add ("Action" , "getSigninToken" )
148
- q .Add ("Session" , string (jsonBytes ))
149
- req .URL .RawQuery = q .Encode ()
140
+ // if the credentials have come from an unknown source like credential_process, check the
141
+ // caller identity to see if it's an assumed role
142
+ isAssumedRole , err := isCallerIdentityAssumedRole (ctx , credsProvider , config )
143
+ if err != nil {
144
+ return err
145
+ }
150
146
151
- resp , err := http .DefaultClient .Do (req )
152
- if err != nil {
153
- return err
147
+ if ! isAssumedRole {
148
+ log .Println ("Creating a federated session" )
149
+ credsProvider , err = vault .NewFederationTokenProvider (ctx , credsProvider , config )
150
+ if err != nil {
151
+ return err
152
+ }
153
+ }
154
154
}
155
155
156
- defer resp .Body .Close ()
157
- body , err := io .ReadAll (resp .Body )
156
+ creds , err := credsProvider .Retrieve (ctx )
158
157
if err != nil {
159
158
return err
160
159
}
161
160
162
- if resp .StatusCode != http .StatusOK {
163
- log .Printf ("Response body was %s" , body )
164
- return fmt .Errorf ("Call to getSigninToken failed with %v" , resp .Status )
161
+ if creds .CanExpire {
162
+ log .Printf ("Requesting a signin token for session expiring in %s" , time .Until (creds .Expires ))
165
163
}
166
164
167
- var respParsed map [string ]string
168
-
169
- err = json .Unmarshal (body , & respParsed )
165
+ loginURLPrefix , destination := generateLoginURL (config .Region , input .Path )
166
+ signinToken , err := requestSigninToken (ctx , creds , loginURLPrefix )
170
167
if err != nil {
171
168
return err
172
169
}
173
170
174
- signinToken , ok := respParsed ["SigninToken" ]
175
- if ! ok {
176
- return fmt .Errorf ("Expected a response with SigninToken" )
177
- }
178
-
179
171
loginURL := fmt .Sprintf ("%s?Action=login&Issuer=aws-vault&Destination=%s&SigninToken=%s" ,
180
172
loginURLPrefix , url .QueryEscape (destination ), url .QueryEscape (signinToken ))
181
173
@@ -212,3 +204,99 @@ func generateLoginURL(region string, path string) (string, string) {
212
204
}
213
205
return loginURLPrefix , destination
214
206
}
207
+
208
+ func isCallerIdentityAssumedRole (ctx context.Context , credsProvider aws.CredentialsProvider , config * vault.ProfileConfig ) (bool , error ) {
209
+ cfg := vault .NewAwsConfigWithCredsProvider (credsProvider , config .Region , config .STSRegionalEndpoints )
210
+ client := sts .NewFromConfig (cfg )
211
+ id , err := client .GetCallerIdentity (ctx , nil )
212
+ if err != nil {
213
+ return false , err
214
+ }
215
+ arn := aws .ToString (id .Arn )
216
+ arnParts := strings .Split (arn , ":" )
217
+ if len (arnParts ) < 6 {
218
+ return false , fmt .Errorf ("unable to parse ARN: %s" , arn )
219
+ }
220
+ if strings .HasPrefix (arnParts [5 ], "assumed-role" ) {
221
+ return true , nil
222
+ }
223
+ return false , nil
224
+ }
225
+
226
+ func createStaticCredentialsProvider (ctx context.Context , credsProvider aws.CredentialsProvider ) (sc credentials.StaticCredentialsProvider , err error ) {
227
+ creds , err := credsProvider .Retrieve (ctx )
228
+ if err != nil {
229
+ return sc , err
230
+ }
231
+ return credentials.StaticCredentialsProvider {Value : creds }, nil
232
+ }
233
+
234
+ // canProviderBeUsedForLogin returns true if the credentials produced by the provider is known to be usable by the login URL endpoint
235
+ func canProviderBeUsedForLogin (credsProvider aws.CredentialsProvider ) (bool , error ) {
236
+ if _ , ok := credsProvider .(* vault.AssumeRoleProvider ); ok {
237
+ return true , nil
238
+ }
239
+ if _ , ok := credsProvider .(* vault.SSORoleCredentialsProvider ); ok {
240
+ return true , nil
241
+ }
242
+ if _ , ok := credsProvider .(* vault.AssumeRoleWithWebIdentityProvider ); ok {
243
+ return true , nil
244
+ }
245
+ if c , ok := credsProvider .(* vault.CachedSessionProvider ); ok {
246
+ return canProviderBeUsedForLogin (c .SessionProvider )
247
+ }
248
+
249
+ return false , nil
250
+ }
251
+
252
+ // Create a signin token
253
+ func requestSigninToken (ctx context.Context , creds aws.Credentials , loginURLPrefix string ) (string , error ) {
254
+ jsonSession , err := json .Marshal (map [string ]string {
255
+ "sessionId" : creds .AccessKeyID ,
256
+ "sessionKey" : creds .SecretAccessKey ,
257
+ "sessionToken" : creds .SessionToken ,
258
+ })
259
+ if err != nil {
260
+ return "" , err
261
+ }
262
+
263
+ req , err := http .NewRequestWithContext (ctx , "GET" , loginURLPrefix , nil )
264
+ if err != nil {
265
+ return "" , err
266
+ }
267
+
268
+ q := req .URL .Query ()
269
+ q .Add ("Action" , "getSigninToken" )
270
+ q .Add ("Session" , string (jsonSession ))
271
+ req .URL .RawQuery = q .Encode ()
272
+
273
+ resp , err := http .DefaultClient .Do (req )
274
+ if err != nil {
275
+ return "" , err
276
+ }
277
+
278
+ defer resp .Body .Close ()
279
+ body , err := io .ReadAll (resp .Body )
280
+ if err != nil {
281
+ return "" , err
282
+ }
283
+
284
+ if resp .StatusCode != http .StatusOK {
285
+ log .Printf ("Response body was %s" , body )
286
+ return "" , fmt .Errorf ("Call to getSigninToken failed with %v" , resp .Status )
287
+ }
288
+
289
+ var respParsed map [string ]string
290
+
291
+ err = json .Unmarshal (body , & respParsed )
292
+ if err != nil {
293
+ return "" , err
294
+ }
295
+
296
+ signinToken , ok := respParsed ["SigninToken" ]
297
+ if ! ok {
298
+ return "" , fmt .Errorf ("Expected a response with SigninToken" )
299
+ }
300
+
301
+ return signinToken , nil
302
+ }
0 commit comments