Skip to content

Commit 8bd38d5

Browse files
jaydeokaryash97
authored andcommitted
Update aws-sdk-go and change way to get regional sts endpoint (aws#466)
1 parent a1f7176 commit 8bd38d5

File tree

6 files changed

+128
-26
lines changed

6 files changed

+128
-26
lines changed

go.mod

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ module github.com/aws/amazon-vpc-resource-controller-k8s
33
go 1.21
44

55
require (
6-
github.com/aws/amazon-vpc-cni-k8s v1.17.1
7-
github.com/aws/aws-sdk-go v1.51.12
8-
github.com/go-logr/logr v1.4.1
6+
github.com/aws/amazon-vpc-cni-k8s v1.18.1
7+
github.com/aws/aws-sdk-go v1.55.5
8+
github.com/go-logr/logr v1.4.2
99
github.com/go-logr/zapr v1.3.0
1010
github.com/golang/mock v1.6.0
1111
github.com/google/uuid v1.6.0

go.sum

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
22
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
3-
github.com/aws/amazon-vpc-cni-k8s v1.17.1 h1:pF+AmlGbgK8/e58LbtOsLUzDy2hqI8Ug/D8Xxx7+Sis=
4-
github.com/aws/amazon-vpc-cni-k8s v1.17.1/go.mod h1:fNfKsEUNrAj+046SGML0UQWLcsF7hAsKRqnvwIcflvw=
5-
github.com/aws/aws-sdk-go v1.51.12 h1:DvuhIHZXwnjaR1/Gu19gUe1EGPw4J0qSJw4Qs/5PA8g=
6-
github.com/aws/aws-sdk-go v1.51.12/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk=
3+
github.com/aws/amazon-vpc-cni-k8s v1.18.1 h1:u/OeBgnUUX6f3PCEOpA4dbG0+iZ71CnY6tEljjrl3iw=
4+
github.com/aws/amazon-vpc-cni-k8s v1.18.1/go.mod h1:m/J5GsxF0Th2iQTOE3ww4W9LFvwdC0tGyA9dIL4h6iQ=
5+
github.com/aws/aws-sdk-go v1.55.5 h1:KKUZBfBoyqy5d3swXyiC7Q76ic40rYcbqH7qjh59kzU=
6+
github.com/aws/aws-sdk-go v1.55.5/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU=
77
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
88
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
99
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=

pkg/aws/ec2/api/wrapper.go

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -455,21 +455,21 @@ func (e *ec2Wrapper) getClientUsingAssumedRole(instanceRegion, roleARN, clusterN
455455
}
456456
e.log.Info("created rate limited http client", "qps", qps, "burst", burst)
457457

458-
// Get the regional sts end point
459-
regionalSTSEndpoint, err := endpoints.DefaultResolver().
460-
EndpointFor("sts", aws.StringValue(userStsSession.Config.Region), endpoints.STSRegionalEndpointOption)
461-
if err != nil {
462-
return nil, fmt.Errorf("failed to get the regional sts endoint for region %s: %v",
463-
*userStsSession.Config.Region, err)
464-
}
465-
458+
// GetPartition ID, SourceAccount and SourceARN
466459
roleARN = strings.Trim(roleARN, "\"")
467460

468-
sourceAcct, sourceArn, err := utils.GetSourceAcctAndArn(roleARN, region, clusterName)
461+
sourceAcct, partitionID, sourceArn, err := utils.GetSourceAcctAndArn(roleARN, region, clusterName)
469462
if err != nil {
470463
return nil, err
471464
}
472465

466+
// Get the regional sts end point
467+
regionalSTSEndpoint, err := e.getRegionalStsEndpoint(partitionID, region)
468+
if err != nil {
469+
return nil, fmt.Errorf("failed to get the regional sts endpoint for region %s: %v %v",
470+
*userStsSession.Config.Region, err, partitionID)
471+
}
472+
473473
regionalProvider := &stscreds.AssumeRoleProvider{
474474
Client: e.createSTSClient(userStsSession, client, regionalSTSEndpoint, sourceAcct, sourceArn),
475475
RoleARN: roleARN,
@@ -789,3 +789,35 @@ func (e *ec2Wrapper) CreateNetworkInterfacePermission(input *ec2.CreateNetworkIn
789789

790790
return output, err
791791
}
792+
793+
func (e *ec2Wrapper) getRegionalStsEndpoint(partitionID, region string) (endpoints.ResolvedEndpoint, error) {
794+
var partition *endpoints.Partition
795+
var stsServiceID = "sts"
796+
for _, p := range endpoints.DefaultPartitions() {
797+
if partitionID == p.ID() {
798+
partition = &p
799+
break
800+
}
801+
}
802+
if partition == nil {
803+
return endpoints.ResolvedEndpoint{}, fmt.Errorf("partition %s not valid", partitionID)
804+
}
805+
806+
stsSvc, ok := partition.Services()[stsServiceID]
807+
if !ok {
808+
e.log.Info("STS service not found in partition, generating default endpoint.", "Partition:", partitionID)
809+
// Add the host of the current instances region if the service doesn't already exists in the partition
810+
// so we don't fail if the service is not present in the go sdk but matches the instances region.
811+
res, err := partition.EndpointFor(stsServiceID, region, endpoints.STSRegionalEndpointOption, endpoints.ResolveUnknownServiceOption)
812+
if err != nil {
813+
return endpoints.ResolvedEndpoint{}, fmt.Errorf("error resolving endpoint for %s in partition %s. err: %v", region, partition.ID(), err)
814+
}
815+
return res, nil
816+
}
817+
818+
res, err := stsSvc.ResolveEndpoint(region, endpoints.STSRegionalEndpointOption)
819+
if err != nil {
820+
return endpoints.ResolvedEndpoint{}, fmt.Errorf("error resolving endpoint for %s in partition %s. err: %v", region, partition.ID(), err)
821+
}
822+
return res, nil
823+
}

pkg/aws/ec2/api/wrapper_test.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package api
2+
3+
import (
4+
"testing"
5+
)
6+
7+
func getMockEC2Wrapper() ec2Wrapper {
8+
return ec2Wrapper{}
9+
}
10+
func Test_getRegionalStsEndpoint(t *testing.T) {
11+
12+
ec2Wapper := getMockEC2Wrapper()
13+
14+
type args struct {
15+
partitionID string
16+
region string
17+
}
18+
19+
tests := []struct {
20+
name string
21+
args args
22+
want string
23+
wantErr bool
24+
}{
25+
{
26+
name: "service doesn't exist in partition",
27+
args: args{
28+
partitionID: "aws-iso-f",
29+
region: "testregions",
30+
},
31+
want: "https://sts.testregions.csp.hci.ic.gov",
32+
wantErr: false,
33+
},
34+
{
35+
name: "region doesn't exist in partition",
36+
args: args{
37+
partitionID: "aws",
38+
region: "us-test-2",
39+
},
40+
want: "https://sts.us-test-2.amazonaws.com",
41+
wantErr: false,
42+
},
43+
{
44+
name: "region and service exist in partition",
45+
args: args{
46+
partitionID: "aws",
47+
region: "us-west-2",
48+
},
49+
want: "https://sts.us-west-2.amazonaws.com",
50+
wantErr: false,
51+
},
52+
}
53+
for _, tt := range tests {
54+
t.Run(tt.name, func(t *testing.T) {
55+
got, err := ec2Wapper.getRegionalStsEndpoint(tt.args.partitionID, tt.args.region)
56+
if (err != nil) != tt.wantErr {
57+
t.Errorf("getRegionalStsEndpoint() error = %v, wantErr %v", err, tt.wantErr)
58+
return
59+
}
60+
if got.URL != tt.want {
61+
t.Errorf("getRegionalStsEndpoint() = %v, want %v", got, tt.want)
62+
}
63+
})
64+
}
65+
}

pkg/utils/helper.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -213,22 +213,22 @@ func IsNitroInstance(instanceType string) (bool, error) {
213213
}
214214

215215
// GetSourceAcctAndArn constructs source acct and arn and return them for use
216-
func GetSourceAcctAndArn(roleARN, region, clusterName string) (string, string, error) {
216+
func GetSourceAcctAndArn(roleARN, region, clusterName string) (string, string, string, error) {
217217
// ARN format (https://docs.aws.amazon.com/IAM/latest/UserGuide/reference-arns.html)
218218
// arn:partition:service:region:account-id:resource-type/resource-id
219219
// IAM format, region is always blank
220220
// arn:aws:iam::account:role/role-name-with-path
221221
if !arn.IsARN(roleARN) {
222-
return "", "", fmt.Errorf("incorrect ARN format for role %s", roleARN)
222+
return "", "", "", fmt.Errorf("incorrect ARN format for role %s", roleARN)
223223
} else if region == "" {
224-
return "", "", nil
224+
return "", "", "", nil
225225
}
226226

227227
parsedArn, err := arn.Parse(roleARN)
228228
if err != nil {
229-
return "", "", err
229+
return "", "", "", err
230230
}
231231

232232
sourceArn := fmt.Sprintf("arn:%s:eks:%s:%s:cluster/%s", parsedArn.Partition, region, parsedArn.AccountID, clusterName)
233-
return parsedArn.AccountID, sourceArn, nil
233+
return parsedArn.AccountID, parsedArn.Partition, sourceArn, nil
234234
}

pkg/utils/helper_test.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -538,26 +538,29 @@ func TestGetSourceAcctAndArn(t *testing.T) {
538538
clusterName := "test-cluster"
539539
region := "us-west-2"
540540
clusterARN := "arn:aws:eks:us-west-2:123456789876:cluster/test-cluster"
541-
541+
partition := "aws"
542542
roleARN := "arn:aws:iam::123456789876:role/test-cluster"
543543

544544
// test correct inputs
545-
acct, arn, err := GetSourceAcctAndArn(roleARN, region, clusterName)
545+
acct, part, arn, err := GetSourceAcctAndArn(roleARN, region, clusterName)
546546
assert.NoError(t, err, "no error should be returned with accurate role arn")
547+
assert.Equal(t, partition, part, "correct partition should be retrieved")
547548
assert.Equal(t, accountID, acct, "correct account ID should be retrieved")
548549
assert.Equal(t, clusterARN, arn, "correct cluster arn should be retrieved")
549550

550551
region = "us-gov-west-1"
551552
roleARN = "arn:aws-us-gov:iam::123456789876:role/test-cluster"
552553
clusterARN = "arn:aws-us-gov:eks:us-gov-west-1:123456789876:cluster/test-cluster"
553-
acct, arn, err = GetSourceAcctAndArn(roleARN, region, clusterName)
554+
partition = "aws-us-gov"
555+
acct, part, arn, err = GetSourceAcctAndArn(roleARN, region, clusterName)
554556
assert.NoError(t, err, "no error should be returned with accurate aws-us-gov partition role arn")
555557
assert.Equal(t, accountID, acct, "correct account ID should be retrieved")
558+
assert.Equal(t, partition, part, "correct patition should be retrieved")
556559
assert.Equal(t, clusterARN, arn, "correct gov partition cluster arn should be retrieved")
557560

558561
// test error handling
559562
roleARN = "arn:aws:iam::123456789876"
560-
_, _, err = GetSourceAcctAndArn(roleARN, region, clusterName)
563+
_, _, _, err = GetSourceAcctAndArn(roleARN, region, clusterName)
561564
assert.Error(t, err, "error should be returned with inaccurate role arn is given")
562565
}
563566

@@ -569,8 +572,10 @@ func TestGetSourceAcctAndArn_NoRegion(t *testing.T) {
569572
roleARN := "arn:aws:iam::123456789876:role/test-cluster"
570573

571574
// test correct inputs
572-
acct, arn, err := GetSourceAcctAndArn(roleARN, region, clusterName)
575+
acct, part, arn, err := GetSourceAcctAndArn(roleARN, region, clusterName)
573576
assert.NoError(t, err, "no error should be returned with accurate role arn")
574577
assert.Equal(t, "", acct, "correct account ID should be retrieved")
575578
assert.Equal(t, "", arn, "correct cluster arn should be retrieved")
579+
assert.Equal(t, "", part, "correct partiton should be retrieved")
580+
576581
}

0 commit comments

Comments
 (0)