Skip to content

Commit 19ed9ef

Browse files
authored
Update aws-sdk-go and change way to get regional sts endpoint (#466)
1 parent 712887d commit 19ed9ef

File tree

6 files changed

+124
-22
lines changed

6 files changed

+124
-22
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ go 1.22.5
44

55
require (
66
github.com/aws/amazon-vpc-cni-k8s v1.18.1
7-
github.com/aws/aws-sdk-go v1.51.32
7+
github.com/aws/aws-sdk-go v1.55.5
88
github.com/go-logr/logr v1.4.2
99
github.com/go-logr/zapr v1.3.0
1010
github.com/golang/mock v1.6.0

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPd
22
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
33
github.com/aws/amazon-vpc-cni-k8s v1.18.1 h1:u/OeBgnUUX6f3PCEOpA4dbG0+iZ71CnY6tEljjrl3iw=
44
github.com/aws/amazon-vpc-cni-k8s v1.18.1/go.mod h1:m/J5GsxF0Th2iQTOE3ww4W9LFvwdC0tGyA9dIL4h6iQ=
5-
github.com/aws/aws-sdk-go v1.51.32 h1:A6mPui7QP4mwmovyzgtdedbRbNur1Iu0/El7hBWNHms=
6-
github.com/aws/aws-sdk-go v1.51.32/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk=
5+
github.com/aws/aws-sdk-go v1.55.5 h1:KKUZBfBoyqy5d3swXyiC7Q76ic40rYcbqH7qjh59kzU=
6+
github.com/aws/aws-sdk-go v1.55.5/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU=
77
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
88
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
99
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=

pkg/aws/ec2/api/wrapper.go

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -510,21 +510,21 @@ func (e *ec2Wrapper) getClientUsingAssumedRole(instanceRegion, roleARN, clusterN
510510
}
511511
e.log.Info("created rate limited http client", "qps", qps, "burst", burst)
512512

513-
// Get the regional sts end point
514-
regionalSTSEndpoint, err := endpoints.DefaultResolver().
515-
EndpointFor("sts", aws.StringValue(userStsSession.Config.Region), endpoints.STSRegionalEndpointOption)
516-
if err != nil {
517-
return nil, fmt.Errorf("failed to get the regional sts endoint for region %s: %v",
518-
*userStsSession.Config.Region, err)
519-
}
520-
513+
// GetPartition ID, SourceAccount and SourceARN
521514
roleARN = strings.Trim(roleARN, "\"")
522515

523-
sourceAcct, sourceArn, err := utils.GetSourceAcctAndArn(roleARN, region, clusterName)
516+
sourceAcct, partitionID, sourceArn, err := utils.GetSourceAcctAndArn(roleARN, region, clusterName)
524517
if err != nil {
525518
return nil, err
526519
}
527520

521+
// Get the regional sts end point
522+
regionalSTSEndpoint, err := e.getRegionalStsEndpoint(partitionID, region)
523+
if err != nil {
524+
return nil, fmt.Errorf("failed to get the regional sts endpoint for region %s: %v %v",
525+
*userStsSession.Config.Region, err, partitionID)
526+
}
527+
528528
regionalProvider := &stscreds.AssumeRoleProvider{
529529
Client: e.createSTSClient(userStsSession, client, regionalSTSEndpoint, sourceAcct, sourceArn),
530530
RoleARN: roleARN,
@@ -892,3 +892,35 @@ func (e *ec2Wrapper) DisassociateTrunkInterface(input *ec2.DisassociateTrunkInte
892892
}
893893
return err
894894
}
895+
896+
func (e *ec2Wrapper) getRegionalStsEndpoint(partitionID, region string) (endpoints.ResolvedEndpoint, error) {
897+
var partition *endpoints.Partition
898+
var stsServiceID = "sts"
899+
for _, p := range endpoints.DefaultPartitions() {
900+
if partitionID == p.ID() {
901+
partition = &p
902+
break
903+
}
904+
}
905+
if partition == nil {
906+
return endpoints.ResolvedEndpoint{}, fmt.Errorf("partition %s not valid", partitionID)
907+
}
908+
909+
stsSvc, ok := partition.Services()[stsServiceID]
910+
if !ok {
911+
e.log.Info("STS service not found in partition, generating default endpoint.", "Partition:", partitionID)
912+
// Add the host of the current instances region if the service doesn't already exists in the partition
913+
// so we don't fail if the service is not present in the go sdk but matches the instances region.
914+
res, err := partition.EndpointFor(stsServiceID, region, endpoints.STSRegionalEndpointOption, endpoints.ResolveUnknownServiceOption)
915+
if err != nil {
916+
return endpoints.ResolvedEndpoint{}, fmt.Errorf("error resolving endpoint for %s in partition %s. err: %v", region, partition.ID(), err)
917+
}
918+
return res, nil
919+
}
920+
921+
res, err := stsSvc.ResolveEndpoint(region, endpoints.STSRegionalEndpointOption)
922+
if err != nil {
923+
return endpoints.ResolvedEndpoint{}, fmt.Errorf("error resolving endpoint for %s in partition %s. err: %v", region, partition.ID(), err)
924+
}
925+
return res, nil
926+
}

pkg/aws/ec2/api/wrapper_test.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package api
2+
3+
import (
4+
"testing"
5+
)
6+
7+
func getMockEC2Wrapper() ec2Wrapper {
8+
return ec2Wrapper{}
9+
}
10+
func Test_getRegionalStsEndpoint(t *testing.T) {
11+
12+
ec2Wapper := getMockEC2Wrapper()
13+
14+
type args struct {
15+
partitionID string
16+
region string
17+
}
18+
19+
tests := []struct {
20+
name string
21+
args args
22+
want string
23+
wantErr bool
24+
}{
25+
{
26+
name: "service doesn't exist in partition",
27+
args: args{
28+
partitionID: "aws-iso-f",
29+
region: "testregions",
30+
},
31+
want: "https://sts.testregions.csp.hci.ic.gov",
32+
wantErr: false,
33+
},
34+
{
35+
name: "region doesn't exist in partition",
36+
args: args{
37+
partitionID: "aws",
38+
region: "us-test-2",
39+
},
40+
want: "https://sts.us-test-2.amazonaws.com",
41+
wantErr: false,
42+
},
43+
{
44+
name: "region and service exist in partition",
45+
args: args{
46+
partitionID: "aws",
47+
region: "us-west-2",
48+
},
49+
want: "https://sts.us-west-2.amazonaws.com",
50+
wantErr: false,
51+
},
52+
}
53+
for _, tt := range tests {
54+
t.Run(tt.name, func(t *testing.T) {
55+
got, err := ec2Wapper.getRegionalStsEndpoint(tt.args.partitionID, tt.args.region)
56+
if (err != nil) != tt.wantErr {
57+
t.Errorf("getRegionalStsEndpoint() error = %v, wantErr %v", err, tt.wantErr)
58+
return
59+
}
60+
if got.URL != tt.want {
61+
t.Errorf("getRegionalStsEndpoint() = %v, want %v", got, tt.want)
62+
}
63+
})
64+
}
65+
}

pkg/utils/helper.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -213,22 +213,22 @@ func IsNitroInstance(instanceType string) (bool, error) {
213213
}
214214

215215
// GetSourceAcctAndArn constructs source acct and arn and return them for use
216-
func GetSourceAcctAndArn(roleARN, region, clusterName string) (string, string, error) {
216+
func GetSourceAcctAndArn(roleARN, region, clusterName string) (string, string, string, error) {
217217
// ARN format (https://docs.aws.amazon.com/IAM/latest/UserGuide/reference-arns.html)
218218
// arn:partition:service:region:account-id:resource-type/resource-id
219219
// IAM format, region is always blank
220220
// arn:aws:iam::account:role/role-name-with-path
221221
if !arn.IsARN(roleARN) {
222-
return "", "", fmt.Errorf("incorrect ARN format for role %s", roleARN)
222+
return "", "", "", fmt.Errorf("incorrect ARN format for role %s", roleARN)
223223
} else if region == "" {
224-
return "", "", nil
224+
return "", "", "", nil
225225
}
226226

227227
parsedArn, err := arn.Parse(roleARN)
228228
if err != nil {
229-
return "", "", err
229+
return "", "", "", err
230230
}
231231

232232
sourceArn := fmt.Sprintf("arn:%s:eks:%s:%s:cluster/%s", parsedArn.Partition, region, parsedArn.AccountID, clusterName)
233-
return parsedArn.AccountID, sourceArn, nil
233+
return parsedArn.AccountID, parsedArn.Partition, sourceArn, nil
234234
}

pkg/utils/helper_test.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -538,26 +538,29 @@ func TestGetSourceAcctAndArn(t *testing.T) {
538538
clusterName := "test-cluster"
539539
region := "us-west-2"
540540
clusterARN := "arn:aws:eks:us-west-2:123456789876:cluster/test-cluster"
541-
541+
partition := "aws"
542542
roleARN := "arn:aws:iam::123456789876:role/test-cluster"
543543

544544
// test correct inputs
545-
acct, arn, err := GetSourceAcctAndArn(roleARN, region, clusterName)
545+
acct, part, arn, err := GetSourceAcctAndArn(roleARN, region, clusterName)
546546
assert.NoError(t, err, "no error should be returned with accurate role arn")
547+
assert.Equal(t, partition, part, "correct partition should be retrieved")
547548
assert.Equal(t, accountID, acct, "correct account ID should be retrieved")
548549
assert.Equal(t, clusterARN, arn, "correct cluster arn should be retrieved")
549550

550551
region = "us-gov-west-1"
551552
roleARN = "arn:aws-us-gov:iam::123456789876:role/test-cluster"
552553
clusterARN = "arn:aws-us-gov:eks:us-gov-west-1:123456789876:cluster/test-cluster"
553-
acct, arn, err = GetSourceAcctAndArn(roleARN, region, clusterName)
554+
partition = "aws-us-gov"
555+
acct, part, arn, err = GetSourceAcctAndArn(roleARN, region, clusterName)
554556
assert.NoError(t, err, "no error should be returned with accurate aws-us-gov partition role arn")
555557
assert.Equal(t, accountID, acct, "correct account ID should be retrieved")
558+
assert.Equal(t, partition, part, "correct patition should be retrieved")
556559
assert.Equal(t, clusterARN, arn, "correct gov partition cluster arn should be retrieved")
557560

558561
// test error handling
559562
roleARN = "arn:aws:iam::123456789876"
560-
_, _, err = GetSourceAcctAndArn(roleARN, region, clusterName)
563+
_, _, _, err = GetSourceAcctAndArn(roleARN, region, clusterName)
561564
assert.Error(t, err, "error should be returned with inaccurate role arn is given")
562565
}
563566

@@ -569,8 +572,10 @@ func TestGetSourceAcctAndArn_NoRegion(t *testing.T) {
569572
roleARN := "arn:aws:iam::123456789876:role/test-cluster"
570573

571574
// test correct inputs
572-
acct, arn, err := GetSourceAcctAndArn(roleARN, region, clusterName)
575+
acct, part, arn, err := GetSourceAcctAndArn(roleARN, region, clusterName)
573576
assert.NoError(t, err, "no error should be returned with accurate role arn")
574577
assert.Equal(t, "", acct, "correct account ID should be retrieved")
575578
assert.Equal(t, "", arn, "correct cluster arn should be retrieved")
579+
assert.Equal(t, "", part, "correct partiton should be retrieved")
580+
576581
}

0 commit comments

Comments
 (0)