@@ -24,19 +24,79 @@ import (
24
24
"github.com/hashicorp/vault/sdk/logical"
25
25
)
26
26
27
- const fallbackEndpoint = "https://sts.amazonaws.com" // this is not regionally distributed; all requests go to us-east-1
27
+ // getRootIAMConfig creates an *aws.Config for Vault to connect to IAM.
28
+ func (b * backend ) getRootIAMConfig (ctx context.Context , s logical.Storage , logger hclog.Logger ) (* aws.Config , error ) {
29
+ credsConfig := & awsutil.CredentialsConfig {}
30
+ var endpoint string
31
+ var maxRetries int = aws .UseServiceDefaultRetries
32
+
33
+ entry , err := s .Get (ctx , "config/root" )
34
+ if err != nil {
35
+ return nil , err
36
+ }
37
+ if entry != nil {
38
+ var config rootConfig
39
+ if err := entry .DecodeJSON (& config ); err != nil {
40
+ return nil , fmt .Errorf ("error reading root configuration: %w" , err )
41
+ }
42
+
43
+ credsConfig .AccessKey = config .AccessKey
44
+ credsConfig .SecretKey = config .SecretKey
45
+ credsConfig .Region = config .Region
46
+ maxRetries = config .MaxRetries
47
+
48
+ if config .IAMEndpoint != "" {
49
+ endpoint = * aws .String (config .IAMEndpoint )
50
+ }
51
+
52
+ if config .IdentityTokenAudience != "" {
53
+ ns , err := namespace .FromContext (ctx )
54
+ if err != nil {
55
+ return nil , fmt .Errorf ("failed to get namespace from context: %w" , err )
56
+ }
57
+
58
+ fetcher := & PluginIdentityTokenFetcher {
59
+ sys : b .System (),
60
+ logger : b .Logger (),
61
+ ns : ns ,
62
+ audience : config .IdentityTokenAudience ,
63
+ ttl : config .IdentityTokenTTL ,
64
+ }
65
+
66
+ sessionSuffix := strconv .FormatInt (time .Now ().UnixNano (), 10 )
67
+ credsConfig .RoleSessionName = fmt .Sprintf ("vault-aws-secrets-%s" , sessionSuffix )
68
+ credsConfig .WebIdentityTokenFetcher = fetcher
69
+ credsConfig .RoleARN = config .RoleARN
70
+ }
71
+ }
72
+
73
+ if credsConfig .Region == "" {
74
+ credsConfig .Region = getFallbackRegion ()
75
+ }
76
+
77
+ credsConfig .HTTPClient = cleanhttp .DefaultClient ()
78
+
79
+ credsConfig .Logger = logger
80
+
81
+ creds , err := credsConfig .GenerateCredentialChain ()
82
+ if err != nil {
83
+ return nil , err
84
+ }
85
+
86
+ return & aws.Config {
87
+ Credentials : creds ,
88
+ Region : aws .String (credsConfig .Region ),
89
+ Endpoint : & endpoint ,
90
+ HTTPClient : cleanhttp .DefaultClient (),
91
+ MaxRetries : aws .Int (maxRetries ),
92
+ }, nil
93
+ }
28
94
29
95
// Return a slice of *aws.Config, based on descending configuration priority. STS endpoints are the only place this is used.
30
96
// NOTE: The caller is required to ensure that b.clientMutex is at least read locked
31
- func (b * backend ) getRootConfigs (ctx context.Context , s logical.Storage , clientType string , logger hclog.Logger ) ([]* aws.Config , error ) {
97
+ func (b * backend ) getRootSTSConfigs (ctx context.Context , s logical.Storage , logger hclog.Logger ) ([]* aws.Config , error ) {
32
98
// set fallback region (we can overwrite later)
33
- fallbackRegion := os .Getenv ("AWS_REGION" )
34
- if fallbackRegion == "" {
35
- fallbackRegion = os .Getenv ("AWS_DEFAULT_REGION" )
36
- }
37
- if fallbackRegion == "" {
38
- fallbackRegion = "us-east-1"
39
- }
99
+ fallbackRegion := getFallbackRegion ()
40
100
41
101
maxRetries := aws .UseServiceDefaultRetries
42
102
@@ -81,13 +141,16 @@ func (b *backend) getRootConfigs(ctx context.Context, s logical.Storage, clientT
81
141
credsConfig .HTTPClient = cleanhttp .DefaultClient ()
82
142
credsConfig .Logger = logger
83
143
144
+ if config .Region != "" {
145
+ regions = append (regions , config .Region )
146
+ }
147
+
84
148
maxRetries = config .MaxRetries
85
- if clientType == "iam" && config .IAMEndpoint != "" {
86
- endpoints = append (endpoints , config .IAMEndpoint )
87
- } else if clientType == "sts" && config .STSEndpoint != "" {
149
+ if config .STSEndpoint != "" {
88
150
endpoints = append (endpoints , config .STSEndpoint )
89
151
if config .STSRegion != "" {
90
- regions = append (regions , config .STSRegion )
152
+ // this retains original logic, where sts region was only used if sts endpoint was set
153
+ regions = []string {config .STSRegion } // override to be "only" region if set
91
154
}
92
155
93
156
if len (config .STSFallbackEndpoints ) > 0 {
@@ -124,23 +187,22 @@ func (b *backend) getRootConfigs(ctx context.Context, s logical.Storage, clientT
124
187
opts = append (opts , awsutil .WithEnvironmentCredentials (false ), awsutil .WithSharedCredentials (false ))
125
188
}
126
189
127
- // at this point, in the IAM case, regions contains nothing, and endpoints contains iam_endpoint, if it was set.
128
- // in the sts case, regions contains sts_region, if it was set, then the sts_fallback_regions in order, if they were set.
129
- // endpoints contains sts_endpint, if it wa set, then sts_fallback_endpoints in order, if they were set.
190
+ // at this point, in the IAM case,
191
+ // - regions contains config.Region, if it was set.
192
+ // - endpoints contains iam_endpoint, if it was set.
193
+ // in the sts case,
194
+ // - regions contains sts_region, if it was set, then sts_fallback_regions in order, if they were set.
195
+ // - endpoints contains sts_endpoint, if it was set, then sts_fallback_endpoints in order, if they were set.
130
196
131
197
// case in which nothing was supplied
132
198
if len (regions ) == 0 {
133
199
// fallback region is in descending order, AWS_REGION, or AWS_DEFAULT_REGION, or us-east-1
134
200
regions = append (regions , fallbackRegion )
201
+ }
135
202
136
- // we also need to set the endpoint based on this region (since we need matched length arrays)
137
- if len (endpoints ) == 0 {
138
- switch clientType {
139
- case "sts" :
140
- endpoints = append (endpoints , matchingSTSEndpoint (fallbackRegion ))
141
- case "iam" :
142
- endpoints = append (endpoints , "https://iam.amazonaws.com" ) // see https://docs.aws.amazon.com/general/latest/gr/iam-service.html
143
- }
203
+ if len (endpoints ) == 0 {
204
+ for _ , v := range regions {
205
+ endpoints = append (endpoints , matchingSTSEndpoint (v ))
144
206
}
145
207
}
146
208
@@ -181,14 +243,10 @@ func (b *backend) nonCachedClientIAM(ctx context.Context, s logical.Storage, log
181
243
return nil , fmt .Errorf ("failed to assume role %q: %w" , entry .AssumeRoleARN , err )
182
244
}
183
245
} else {
184
- configs , err : = b .getRootConfigs (ctx , s , "iam" , logger )
246
+ awsConfig , err = b .getRootIAMConfig (ctx , s , logger )
185
247
if err != nil {
186
248
return nil , err
187
249
}
188
- if len (configs ) != 1 {
189
- return nil , errors .New ("could not obtain aws config" )
190
- }
191
- awsConfig = configs [0 ]
192
250
}
193
251
194
252
sess , err := session .NewSession (awsConfig )
@@ -203,7 +261,7 @@ func (b *backend) nonCachedClientIAM(ctx context.Context, s logical.Storage, log
203
261
}
204
262
205
263
func (b * backend ) nonCachedClientSTS (ctx context.Context , s logical.Storage , logger hclog.Logger ) (* sts.STS , error ) {
206
- awsConfig , err := b .getRootConfigs (ctx , s , "sts" , logger )
264
+ awsConfig , err := b .getRootSTSConfigs (ctx , s , logger )
207
265
if err != nil {
208
266
return nil , err
209
267
}
@@ -238,6 +296,23 @@ func matchingSTSEndpoint(stsRegion string) string {
238
296
return fmt .Sprintf ("https://sts.%s.amazonaws.com" , stsRegion )
239
297
}
240
298
299
+ // getFallbackRegion returns an aws region fallback. It will check in the AWS specified order:
300
+ // - AWS_REGION, then
301
+ // - AWS_DEFAULT_REGION, then
302
+ // - us-east-1
303
+ func getFallbackRegion () string {
304
+ // set fallback region (we can overwrite later)
305
+ fallbackRegion := os .Getenv ("AWS_REGION" )
306
+ if fallbackRegion == "" {
307
+ fallbackRegion = os .Getenv ("AWS_DEFAULT_REGION" )
308
+ }
309
+ if fallbackRegion == "" {
310
+ fallbackRegion = "us-east-1"
311
+ }
312
+
313
+ return fallbackRegion
314
+ }
315
+
241
316
// PluginIdentityTokenFetcher fetches plugin identity tokens from Vault. It is provided
242
317
// to the AWS SDK client to keep assumed role credentials refreshed through expiration.
243
318
// When the client's STS credentials expire, it will use this interface to fetch a new
0 commit comments