Skip to content

Update aws-sdk-go and change way to get regional sts endpoint #466

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ go 1.22.5

require (
github.com/aws/amazon-vpc-cni-k8s v1.18.1
github.com/aws/aws-sdk-go v1.51.32
github.com/aws/aws-sdk-go v1.55.5
github.com/go-logr/logr v1.4.2
github.com/go-logr/zapr v1.3.0
github.com/golang/mock v1.6.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPd
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
github.com/aws/amazon-vpc-cni-k8s v1.18.1 h1:u/OeBgnUUX6f3PCEOpA4dbG0+iZ71CnY6tEljjrl3iw=
github.com/aws/amazon-vpc-cni-k8s v1.18.1/go.mod h1:m/J5GsxF0Th2iQTOE3ww4W9LFvwdC0tGyA9dIL4h6iQ=
github.com/aws/aws-sdk-go v1.51.32 h1:A6mPui7QP4mwmovyzgtdedbRbNur1Iu0/El7hBWNHms=
github.com/aws/aws-sdk-go v1.51.32/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk=
github.com/aws/aws-sdk-go v1.55.5 h1:KKUZBfBoyqy5d3swXyiC7Q76ic40rYcbqH7qjh59kzU=
github.com/aws/aws-sdk-go v1.55.5/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
Expand Down
50 changes: 41 additions & 9 deletions pkg/aws/ec2/api/wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -510,21 +510,21 @@ func (e *ec2Wrapper) getClientUsingAssumedRole(instanceRegion, roleARN, clusterN
}
e.log.Info("created rate limited http client", "qps", qps, "burst", burst)

// Get the regional sts end point
regionalSTSEndpoint, err := endpoints.DefaultResolver().
EndpointFor("sts", aws.StringValue(userStsSession.Config.Region), endpoints.STSRegionalEndpointOption)
if err != nil {
return nil, fmt.Errorf("failed to get the regional sts endoint for region %s: %v",
*userStsSession.Config.Region, err)
}

// GetPartition ID, SourceAccount and SourceARN
roleARN = strings.Trim(roleARN, "\"")

sourceAcct, sourceArn, err := utils.GetSourceAcctAndArn(roleARN, region, clusterName)
sourceAcct, partitionID, sourceArn, err := utils.GetSourceAcctAndArn(roleARN, region, clusterName)
if err != nil {
return nil, err
}

// Get the regional sts end point
regionalSTSEndpoint, err := e.getRegionalStsEndpoint(partitionID, region)
if err != nil {
return nil, fmt.Errorf("failed to get the regional sts endpoint for region %s: %v %v",
*userStsSession.Config.Region, err, partitionID)
}

regionalProvider := &stscreds.AssumeRoleProvider{
Client: e.createSTSClient(userStsSession, client, regionalSTSEndpoint, sourceAcct, sourceArn),
RoleARN: roleARN,
Expand Down Expand Up @@ -892,3 +892,35 @@ func (e *ec2Wrapper) DisassociateTrunkInterface(input *ec2.DisassociateTrunkInte
}
return err
}

func (e *ec2Wrapper) getRegionalStsEndpoint(partitionID, region string) (endpoints.ResolvedEndpoint, error) {
var partition *endpoints.Partition
var stsServiceID = "sts"
for _, p := range endpoints.DefaultPartitions() {
if partitionID == p.ID() {
partition = &p
break
}
}
if partition == nil {
return endpoints.ResolvedEndpoint{}, fmt.Errorf("partition %s not valid", partitionID)
}

stsSvc, ok := partition.Services()[stsServiceID]
if !ok {
e.log.Info("STS service not found in partition, generating default endpoint.", "Partition:", partitionID)
// Add the host of the current instances region if the service doesn't already exists in the partition
// so we don't fail if the service is not present in the go sdk but matches the instances region.
res, err := partition.EndpointFor(stsServiceID, region, endpoints.STSRegionalEndpointOption, endpoints.ResolveUnknownServiceOption)
if err != nil {
return endpoints.ResolvedEndpoint{}, fmt.Errorf("error resolving endpoint for %s in partition %s. err: %v", region, partition.ID(), err)
}
return res, nil
}

res, err := stsSvc.ResolveEndpoint(region, endpoints.STSRegionalEndpointOption)
if err != nil {
return endpoints.ResolvedEndpoint{}, fmt.Errorf("error resolving endpoint for %s in partition %s. err: %v", region, partition.ID(), err)
}
return res, nil
}
65 changes: 65 additions & 0 deletions pkg/aws/ec2/api/wrapper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package api

import (
"testing"
)

func getMockEC2Wrapper() ec2Wrapper {
return ec2Wrapper{}
}
func Test_getRegionalStsEndpoint(t *testing.T) {

ec2Wapper := getMockEC2Wrapper()

type args struct {
partitionID string
region string
}

tests := []struct {
name string
args args
want string
wantErr bool
}{
{
name: "service doesn't exist in partition",
args: args{
partitionID: "aws-iso-f",
region: "testregions",
},
want: "https://sts.testregions.csp.hci.ic.gov",
wantErr: false,
},
{
name: "region doesn't exist in partition",
args: args{
partitionID: "aws",
region: "us-test-2",
},
want: "https://sts.us-test-2.amazonaws.com",
wantErr: false,
},
{
name: "region and service exist in partition",
args: args{
partitionID: "aws",
region: "us-west-2",
},
want: "https://sts.us-west-2.amazonaws.com",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ec2Wapper.getRegionalStsEndpoint(tt.args.partitionID, tt.args.region)
if (err != nil) != tt.wantErr {
t.Errorf("getRegionalStsEndpoint() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got.URL != tt.want {
t.Errorf("getRegionalStsEndpoint() = %v, want %v", got, tt.want)
}
})
}
}
10 changes: 5 additions & 5 deletions pkg/utils/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,22 +213,22 @@ func IsNitroInstance(instanceType string) (bool, error) {
}

// GetSourceAcctAndArn constructs source acct and arn and return them for use
func GetSourceAcctAndArn(roleARN, region, clusterName string) (string, string, error) {
func GetSourceAcctAndArn(roleARN, region, clusterName string) (string, string, string, error) {
// ARN format (https://docs.aws.amazon.com/IAM/latest/UserGuide/reference-arns.html)
// arn:partition:service:region:account-id:resource-type/resource-id
// IAM format, region is always blank
// arn:aws:iam::account:role/role-name-with-path
if !arn.IsARN(roleARN) {
return "", "", fmt.Errorf("incorrect ARN format for role %s", roleARN)
return "", "", "", fmt.Errorf("incorrect ARN format for role %s", roleARN)
} else if region == "" {
return "", "", nil
return "", "", "", nil
}

parsedArn, err := arn.Parse(roleARN)
if err != nil {
return "", "", err
return "", "", "", err
}

sourceArn := fmt.Sprintf("arn:%s:eks:%s:%s:cluster/%s", parsedArn.Partition, region, parsedArn.AccountID, clusterName)
return parsedArn.AccountID, sourceArn, nil
return parsedArn.AccountID, parsedArn.Partition, sourceArn, nil
}
15 changes: 10 additions & 5 deletions pkg/utils/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -538,26 +538,29 @@ func TestGetSourceAcctAndArn(t *testing.T) {
clusterName := "test-cluster"
region := "us-west-2"
clusterARN := "arn:aws:eks:us-west-2:123456789876:cluster/test-cluster"

partition := "aws"
roleARN := "arn:aws:iam::123456789876:role/test-cluster"

// test correct inputs
acct, arn, err := GetSourceAcctAndArn(roleARN, region, clusterName)
acct, part, arn, err := GetSourceAcctAndArn(roleARN, region, clusterName)
assert.NoError(t, err, "no error should be returned with accurate role arn")
assert.Equal(t, partition, part, "correct partition should be retrieved")
assert.Equal(t, accountID, acct, "correct account ID should be retrieved")
assert.Equal(t, clusterARN, arn, "correct cluster arn should be retrieved")

region = "us-gov-west-1"
roleARN = "arn:aws-us-gov:iam::123456789876:role/test-cluster"
clusterARN = "arn:aws-us-gov:eks:us-gov-west-1:123456789876:cluster/test-cluster"
acct, arn, err = GetSourceAcctAndArn(roleARN, region, clusterName)
partition = "aws-us-gov"
acct, part, arn, err = GetSourceAcctAndArn(roleARN, region, clusterName)
assert.NoError(t, err, "no error should be returned with accurate aws-us-gov partition role arn")
assert.Equal(t, accountID, acct, "correct account ID should be retrieved")
assert.Equal(t, partition, part, "correct patition should be retrieved")
assert.Equal(t, clusterARN, arn, "correct gov partition cluster arn should be retrieved")

// test error handling
roleARN = "arn:aws:iam::123456789876"
_, _, err = GetSourceAcctAndArn(roleARN, region, clusterName)
_, _, _, err = GetSourceAcctAndArn(roleARN, region, clusterName)
assert.Error(t, err, "error should be returned with inaccurate role arn is given")
}

Expand All @@ -569,8 +572,10 @@ func TestGetSourceAcctAndArn_NoRegion(t *testing.T) {
roleARN := "arn:aws:iam::123456789876:role/test-cluster"

// test correct inputs
acct, arn, err := GetSourceAcctAndArn(roleARN, region, clusterName)
acct, part, arn, err := GetSourceAcctAndArn(roleARN, region, clusterName)
assert.NoError(t, err, "no error should be returned with accurate role arn")
assert.Equal(t, "", acct, "correct account ID should be retrieved")
assert.Equal(t, "", arn, "correct cluster arn should be retrieved")
assert.Equal(t, "", part, "correct partiton should be retrieved")

}
Loading