Skip to content

Commit 2dd80e5

Browse files
backport of commit 8a84d13 (#30417)
Co-authored-by: kpcraig <[email protected]>
1 parent eb14be9 commit 2dd80e5

File tree

3 files changed

+286
-35
lines changed

3 files changed

+286
-35
lines changed

builtin/logical/aws/client.go

Lines changed: 105 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,79 @@ import (
2424
"github.com/hashicorp/vault/sdk/logical"
2525
)
2626

27-
const fallbackEndpoint = "https://sts.amazonaws.com" // this is not regionally distributed; all requests go to us-east-1
27+
// getRootIAMConfig creates an *aws.Config for Vault to connect to IAM.
28+
func (b *backend) getRootIAMConfig(ctx context.Context, s logical.Storage, logger hclog.Logger) (*aws.Config, error) {
29+
credsConfig := &awsutil.CredentialsConfig{}
30+
var endpoint string
31+
var maxRetries int = aws.UseServiceDefaultRetries
32+
33+
entry, err := s.Get(ctx, "config/root")
34+
if err != nil {
35+
return nil, err
36+
}
37+
if entry != nil {
38+
var config rootConfig
39+
if err := entry.DecodeJSON(&config); err != nil {
40+
return nil, fmt.Errorf("error reading root configuration: %w", err)
41+
}
42+
43+
credsConfig.AccessKey = config.AccessKey
44+
credsConfig.SecretKey = config.SecretKey
45+
credsConfig.Region = config.Region
46+
maxRetries = config.MaxRetries
47+
48+
if config.IAMEndpoint != "" {
49+
endpoint = *aws.String(config.IAMEndpoint)
50+
}
51+
52+
if config.IdentityTokenAudience != "" {
53+
ns, err := namespace.FromContext(ctx)
54+
if err != nil {
55+
return nil, fmt.Errorf("failed to get namespace from context: %w", err)
56+
}
57+
58+
fetcher := &PluginIdentityTokenFetcher{
59+
sys: b.System(),
60+
logger: b.Logger(),
61+
ns: ns,
62+
audience: config.IdentityTokenAudience,
63+
ttl: config.IdentityTokenTTL,
64+
}
65+
66+
sessionSuffix := strconv.FormatInt(time.Now().UnixNano(), 10)
67+
credsConfig.RoleSessionName = fmt.Sprintf("vault-aws-secrets-%s", sessionSuffix)
68+
credsConfig.WebIdentityTokenFetcher = fetcher
69+
credsConfig.RoleARN = config.RoleARN
70+
}
71+
}
72+
73+
if credsConfig.Region == "" {
74+
credsConfig.Region = getFallbackRegion()
75+
}
76+
77+
credsConfig.HTTPClient = cleanhttp.DefaultClient()
78+
79+
credsConfig.Logger = logger
80+
81+
creds, err := credsConfig.GenerateCredentialChain()
82+
if err != nil {
83+
return nil, err
84+
}
85+
86+
return &aws.Config{
87+
Credentials: creds,
88+
Region: aws.String(credsConfig.Region),
89+
Endpoint: &endpoint,
90+
HTTPClient: cleanhttp.DefaultClient(),
91+
MaxRetries: aws.Int(maxRetries),
92+
}, nil
93+
}
2894

2995
// Return a slice of *aws.Config, based on descending configuration priority. STS endpoints are the only place this is used.
3096
// NOTE: The caller is required to ensure that b.clientMutex is at least read locked
31-
func (b *backend) getRootConfigs(ctx context.Context, s logical.Storage, clientType string, logger hclog.Logger) ([]*aws.Config, error) {
97+
func (b *backend) getRootSTSConfigs(ctx context.Context, s logical.Storage, logger hclog.Logger) ([]*aws.Config, error) {
3298
// set fallback region (we can overwrite later)
33-
fallbackRegion := os.Getenv("AWS_REGION")
34-
if fallbackRegion == "" {
35-
fallbackRegion = os.Getenv("AWS_DEFAULT_REGION")
36-
}
37-
if fallbackRegion == "" {
38-
fallbackRegion = "us-east-1"
39-
}
99+
fallbackRegion := getFallbackRegion()
40100

41101
maxRetries := aws.UseServiceDefaultRetries
42102

@@ -81,13 +141,16 @@ func (b *backend) getRootConfigs(ctx context.Context, s logical.Storage, clientT
81141
credsConfig.HTTPClient = cleanhttp.DefaultClient()
82142
credsConfig.Logger = logger
83143

144+
if config.Region != "" {
145+
regions = append(regions, config.Region)
146+
}
147+
84148
maxRetries = config.MaxRetries
85-
if clientType == "iam" && config.IAMEndpoint != "" {
86-
endpoints = append(endpoints, config.IAMEndpoint)
87-
} else if clientType == "sts" && config.STSEndpoint != "" {
149+
if config.STSEndpoint != "" {
88150
endpoints = append(endpoints, config.STSEndpoint)
89151
if config.STSRegion != "" {
90-
regions = append(regions, config.STSRegion)
152+
// this retains original logic, where sts region was only used if sts endpoint was set
153+
regions = []string{config.STSRegion} // override to be "only" region if set
91154
}
92155

93156
if len(config.STSFallbackEndpoints) > 0 {
@@ -124,23 +187,22 @@ func (b *backend) getRootConfigs(ctx context.Context, s logical.Storage, clientT
124187
opts = append(opts, awsutil.WithEnvironmentCredentials(false), awsutil.WithSharedCredentials(false))
125188
}
126189

127-
// at this point, in the IAM case, regions contains nothing, and endpoints contains iam_endpoint, if it was set.
128-
// in the sts case, regions contains sts_region, if it was set, then the sts_fallback_regions in order, if they were set.
129-
// endpoints contains sts_endpint, if it wa set, then sts_fallback_endpoints in order, if they were set.
190+
// at this point, in the IAM case,
191+
// - regions contains config.Region, if it was set.
192+
// - endpoints contains iam_endpoint, if it was set.
193+
// in the sts case,
194+
// - regions contains sts_region, if it was set, then sts_fallback_regions in order, if they were set.
195+
// - endpoints contains sts_endpoint, if it was set, then sts_fallback_endpoints in order, if they were set.
130196

131197
// case in which nothing was supplied
132198
if len(regions) == 0 {
133199
// fallback region is in descending order, AWS_REGION, or AWS_DEFAULT_REGION, or us-east-1
134200
regions = append(regions, fallbackRegion)
201+
}
135202

136-
// we also need to set the endpoint based on this region (since we need matched length arrays)
137-
if len(endpoints) == 0 {
138-
switch clientType {
139-
case "sts":
140-
endpoints = append(endpoints, matchingSTSEndpoint(fallbackRegion))
141-
case "iam":
142-
endpoints = append(endpoints, "https://iam.amazonaws.com") // see https://docs.aws.amazon.com/general/latest/gr/iam-service.html
143-
}
203+
if len(endpoints) == 0 {
204+
for _, v := range regions {
205+
endpoints = append(endpoints, matchingSTSEndpoint(v))
144206
}
145207
}
146208

@@ -181,14 +243,10 @@ func (b *backend) nonCachedClientIAM(ctx context.Context, s logical.Storage, log
181243
return nil, fmt.Errorf("failed to assume role %q: %w", entry.AssumeRoleARN, err)
182244
}
183245
} else {
184-
configs, err := b.getRootConfigs(ctx, s, "iam", logger)
246+
awsConfig, err = b.getRootIAMConfig(ctx, s, logger)
185247
if err != nil {
186248
return nil, err
187249
}
188-
if len(configs) != 1 {
189-
return nil, errors.New("could not obtain aws config")
190-
}
191-
awsConfig = configs[0]
192250
}
193251

194252
sess, err := session.NewSession(awsConfig)
@@ -203,7 +261,7 @@ func (b *backend) nonCachedClientIAM(ctx context.Context, s logical.Storage, log
203261
}
204262

205263
func (b *backend) nonCachedClientSTS(ctx context.Context, s logical.Storage, logger hclog.Logger) (*sts.STS, error) {
206-
awsConfig, err := b.getRootConfigs(ctx, s, "sts", logger)
264+
awsConfig, err := b.getRootSTSConfigs(ctx, s, logger)
207265
if err != nil {
208266
return nil, err
209267
}
@@ -238,6 +296,23 @@ func matchingSTSEndpoint(stsRegion string) string {
238296
return fmt.Sprintf("https://sts.%s.amazonaws.com", stsRegion)
239297
}
240298

299+
// getFallbackRegion returns an aws region fallback. It will check in the AWS specified order:
300+
// - AWS_REGION, then
301+
// - AWS_DEFAULT_REGION, then
302+
// - us-east-1
303+
func getFallbackRegion() string {
304+
// set fallback region (we can overwrite later)
305+
fallbackRegion := os.Getenv("AWS_REGION")
306+
if fallbackRegion == "" {
307+
fallbackRegion = os.Getenv("AWS_DEFAULT_REGION")
308+
}
309+
if fallbackRegion == "" {
310+
fallbackRegion = "us-east-1"
311+
}
312+
313+
return fallbackRegion
314+
}
315+
241316
// PluginIdentityTokenFetcher fetches plugin identity tokens from Vault. It is provided
242317
// to the AWS SDK client to keep assumed role credentials refreshed through expiration.
243318
// When the client's STS credentials expire, it will use this interface to fetch a new

0 commit comments

Comments
 (0)