@@ -510,21 +510,21 @@ func (e *ec2Wrapper) getClientUsingAssumedRole(instanceRegion, roleARN, clusterN
510
510
}
511
511
e .log .Info ("created rate limited http client" , "qps" , qps , "burst" , burst )
512
512
513
- // Get the regional sts end point
514
- regionalSTSEndpoint , err := endpoints .DefaultResolver ().
515
- EndpointFor ("sts" , aws .StringValue (userStsSession .Config .Region ), endpoints .STSRegionalEndpointOption )
516
- if err != nil {
517
- return nil , fmt .Errorf ("failed to get the regional sts endoint for region %s: %v" ,
518
- * userStsSession .Config .Region , err )
519
- }
520
-
513
+ // GetPartition ID, SourceAccount and SourceARN
521
514
roleARN = strings .Trim (roleARN , "\" " )
522
515
523
- sourceAcct , sourceArn , err := utils .GetSourceAcctAndArn (roleARN , region , clusterName )
516
+ sourceAcct , partitionID , sourceArn , err := utils .GetSourceAcctAndArn (roleARN , region , clusterName )
524
517
if err != nil {
525
518
return nil , err
526
519
}
527
520
521
+ // Get the regional sts end point
522
+ regionalSTSEndpoint , err := e .getRegionalStsEndpoint (partitionID , region )
523
+ if err != nil {
524
+ return nil , fmt .Errorf ("failed to get the regional sts endpoint for region %s: %v %v" ,
525
+ * userStsSession .Config .Region , err , partitionID )
526
+ }
527
+
528
528
regionalProvider := & stscreds.AssumeRoleProvider {
529
529
Client : e .createSTSClient (userStsSession , client , regionalSTSEndpoint , sourceAcct , sourceArn ),
530
530
RoleARN : roleARN ,
@@ -547,7 +547,7 @@ func (e *ec2Wrapper) getClientUsingAssumedRole(instanceRegion, roleARN, clusterN
547
547
// If the regional STS endpoint is different than the global STS endpoint then add the global sts endpoint
548
548
if regionalSTSEndpoint .URL != globalSTSEndpoint .URL {
549
549
globalProvider := & stscreds.AssumeRoleProvider {
550
- Client : e .createSTSClient (userStsSession , client , regionalSTSEndpoint , sourceAcct , sourceArn ),
550
+ Client : e .createSTSClient (userStsSession , client , globalSTSEndpoint , sourceAcct , sourceArn ),
551
551
RoleARN : roleARN ,
552
552
Duration : time .Minute * 60 ,
553
553
}
@@ -892,3 +892,35 @@ func (e *ec2Wrapper) DisassociateTrunkInterface(input *ec2.DisassociateTrunkInte
892
892
}
893
893
return err
894
894
}
895
+
896
+ func (e * ec2Wrapper ) getRegionalStsEndpoint (partitionID , region string ) (endpoints.ResolvedEndpoint , error ) {
897
+ var partition * endpoints.Partition
898
+ var stsServiceID = "sts"
899
+ for _ , p := range endpoints .DefaultPartitions () {
900
+ if partitionID == p .ID () {
901
+ partition = & p
902
+ break
903
+ }
904
+ }
905
+ if partition == nil {
906
+ return endpoints.ResolvedEndpoint {}, fmt .Errorf ("partition %s not valid" , partitionID )
907
+ }
908
+
909
+ stsSvc , ok := partition .Services ()[stsServiceID ]
910
+ if ! ok {
911
+ e .log .Info ("STS service not found in partition, generating default endpoint." , "Partition:" , partitionID )
912
+ // Add the host of the current instances region if the service doesn't already exists in the partition
913
+ // so we don't fail if the service is not present in the go sdk but matches the instances region.
914
+ res , err := partition .EndpointFor (stsServiceID , region , endpoints .STSRegionalEndpointOption , endpoints .ResolveUnknownServiceOption )
915
+ if err != nil {
916
+ return endpoints.ResolvedEndpoint {}, fmt .Errorf ("error resolving endpoint for %s in partition %s. err: %v" , region , partition .ID (), err )
917
+ }
918
+ return res , nil
919
+ }
920
+
921
+ res , err := stsSvc .ResolveEndpoint (region , endpoints .STSRegionalEndpointOption )
922
+ if err != nil {
923
+ return endpoints.ResolvedEndpoint {}, fmt .Errorf ("error resolving endpoint for %s in partition %s. err: %v" , region , partition .ID (), err )
924
+ }
925
+ return res , nil
926
+ }
0 commit comments