Skip to content

Commit d4cbe3e

Browse files
committed
Update aws-sdk-go and change way to get regional sts endpoint
1 parent 712887d commit d4cbe3e

File tree

4 files changed

+50
-16
lines changed

4 files changed

+50
-16
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ go 1.22.5
44

55
require (
66
github.com/aws/amazon-vpc-cni-k8s v1.18.1
7-
github.com/aws/aws-sdk-go v1.51.32
7+
github.com/aws/aws-sdk-go v1.55.5
88
github.com/go-logr/logr v1.4.2
99
github.com/go-logr/zapr v1.3.0
1010
github.com/golang/mock v1.6.0

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ github.com/aws/amazon-vpc-cni-k8s v1.18.1 h1:u/OeBgnUUX6f3PCEOpA4dbG0+iZ71CnY6tE
44
github.com/aws/amazon-vpc-cni-k8s v1.18.1/go.mod h1:m/J5GsxF0Th2iQTOE3ww4W9LFvwdC0tGyA9dIL4h6iQ=
55
github.com/aws/aws-sdk-go v1.51.32 h1:A6mPui7QP4mwmovyzgtdedbRbNur1Iu0/El7hBWNHms=
66
github.com/aws/aws-sdk-go v1.51.32/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk=
7+
github.com/aws/aws-sdk-go v1.55.5 h1:KKUZBfBoyqy5d3swXyiC7Q76ic40rYcbqH7qjh59kzU=
8+
github.com/aws/aws-sdk-go v1.55.5/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU=
79
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
810
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
911
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=

pkg/aws/ec2/api/wrapper.go

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -510,21 +510,21 @@ func (e *ec2Wrapper) getClientUsingAssumedRole(instanceRegion, roleARN, clusterN
510510
}
511511
e.log.Info("created rate limited http client", "qps", qps, "burst", burst)
512512

513-
// Get the regional sts end point
514-
regionalSTSEndpoint, err := endpoints.DefaultResolver().
515-
EndpointFor("sts", aws.StringValue(userStsSession.Config.Region), endpoints.STSRegionalEndpointOption)
516-
if err != nil {
517-
return nil, fmt.Errorf("failed to get the regional sts endoint for region %s: %v",
518-
*userStsSession.Config.Region, err)
519-
}
520-
513+
// GetPartition ID, SourceAccount and SourceARN
521514
roleARN = strings.Trim(roleARN, "\"")
522515

523-
sourceAcct, sourceArn, err := utils.GetSourceAcctAndArn(roleARN, region, clusterName)
516+
sourceAcct, partitionID, sourceArn, err := utils.GetSourceAcctAndArn(roleARN, region, clusterName)
524517
if err != nil {
525518
return nil, err
526519
}
527520

521+
// Get the regional sts end point
522+
regionalSTSEndpoint, err := e.getRegionalStsEndpoint(partitionID, region)
523+
if err != nil {
524+
return nil, fmt.Errorf("failed to get the regional sts endpoint for region %s: %v %v",
525+
*userStsSession.Config.Region, err, partitionID)
526+
}
527+
528528
regionalProvider := &stscreds.AssumeRoleProvider{
529529
Client: e.createSTSClient(userStsSession, client, regionalSTSEndpoint, sourceAcct, sourceArn),
530530
RoleARN: roleARN,
@@ -547,7 +547,7 @@ func (e *ec2Wrapper) getClientUsingAssumedRole(instanceRegion, roleARN, clusterN
547547
// If the regional STS endpoint is different than the global STS endpoint then add the global sts endpoint
548548
if regionalSTSEndpoint.URL != globalSTSEndpoint.URL {
549549
globalProvider := &stscreds.AssumeRoleProvider{
550-
Client: e.createSTSClient(userStsSession, client, regionalSTSEndpoint, sourceAcct, sourceArn),
550+
Client: e.createSTSClient(userStsSession, client, globalSTSEndpoint, sourceAcct, sourceArn),
551551
RoleARN: roleARN,
552552
Duration: time.Minute * 60,
553553
}
@@ -892,3 +892,35 @@ func (e *ec2Wrapper) DisassociateTrunkInterface(input *ec2.DisassociateTrunkInte
892892
}
893893
return err
894894
}
895+
896+
func (e *ec2Wrapper) getRegionalStsEndpoint(partitionID, region string) (endpoints.ResolvedEndpoint, error) {
897+
var partition *endpoints.Partition
898+
var stsServiceID = "sts"
899+
for _, p := range endpoints.DefaultPartitions() {
900+
if partitionID == p.ID() {
901+
partition = &p
902+
break
903+
}
904+
}
905+
if partition == nil {
906+
return endpoints.ResolvedEndpoint{}, fmt.Errorf("partition %s not valid", partitionID)
907+
}
908+
909+
stsSvc, ok := partition.Services()[stsServiceID]
910+
if !ok {
911+
e.log.Info("STS service not found in partition, generating default endpoint.", "Partition:", partitionID)
912+
// Add the host of the current instances region if the service doesn't already exists in the partition
913+
// so we don't fail if the service is not present in the go sdk but matches the instances region.
914+
res, err := partition.EndpointFor(stsServiceID, region, endpoints.STSRegionalEndpointOption, endpoints.ResolveUnknownServiceOption)
915+
if err != nil {
916+
return endpoints.ResolvedEndpoint{}, fmt.Errorf("error resolving endpoint for %s in partition %s. err: %v", region, partition.ID(), err)
917+
}
918+
return res, nil
919+
}
920+
921+
res, err := stsSvc.ResolveEndpoint(region, endpoints.STSRegionalEndpointOption)
922+
if err != nil {
923+
return endpoints.ResolvedEndpoint{}, fmt.Errorf("error resolving endpoint for %s in partition %s. err: %v", region, partition.ID(), err)
924+
}
925+
return res, nil
926+
}

pkg/utils/helper.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -213,22 +213,22 @@ func IsNitroInstance(instanceType string) (bool, error) {
213213
}
214214

215215
// GetSourceAcctAndArn constructs source acct and arn and return them for use
216-
func GetSourceAcctAndArn(roleARN, region, clusterName string) (string, string, error) {
216+
func GetSourceAcctAndArn(roleARN, region, clusterName string) (string, string, string, error) {
217217
// ARN format (https://docs.aws.amazon.com/IAM/latest/UserGuide/reference-arns.html)
218218
// arn:partition:service:region:account-id:resource-type/resource-id
219219
// IAM format, region is always blank
220220
// arn:aws:iam::account:role/role-name-with-path
221221
if !arn.IsARN(roleARN) {
222-
return "", "", fmt.Errorf("incorrect ARN format for role %s", roleARN)
222+
return "", "", "", fmt.Errorf("incorrect ARN format for role %s", roleARN)
223223
} else if region == "" {
224-
return "", "", nil
224+
return "", "", "", nil
225225
}
226226

227227
parsedArn, err := arn.Parse(roleARN)
228228
if err != nil {
229-
return "", "", err
229+
return "", "", "", err
230230
}
231231

232232
sourceArn := fmt.Sprintf("arn:%s:eks:%s:%s:cluster/%s", parsedArn.Partition, region, parsedArn.AccountID, clusterName)
233-
return parsedArn.AccountID, sourceArn, nil
233+
return parsedArn.AccountID, parsedArn.Partition, sourceArn, nil
234234
}

0 commit comments

Comments
 (0)