@@ -29,6 +29,8 @@ import (
29
29
30
30
"github.com/aws/aws-sdk-go-v2/aws"
31
31
"github.com/aws/aws-sdk-go-v2/config"
32
+ "github.com/aws/aws-sdk-go-v2/service/ecr"
33
+ "github.com/aws/aws-sdk-go-v2/service/ecrpublic"
32
34
"github.com/aws/aws-sdk-go-v2/service/sts"
33
35
"github.com/google/go-containerregistry/pkg/authn"
34
36
corev1 "k8s.io/api/core/v1"
@@ -70,8 +72,8 @@ func (p Provider) NewControllerToken(ctx context.Context, opts ...auth.Option) (
70
72
case o .ArtifactRepository != "" :
71
73
// We can safely ignore the error here, auth.GetToken() has already called
72
74
// ParseArtifactRepository() and validated the repository at this point.
73
- ecrRegion , _ := p .ParseArtifactRepository (o .ArtifactRepository )
74
- stsRegion = ecrRegion
75
+ registryInput , _ := p .ParseArtifactRepository (o .ArtifactRepository )
76
+ stsRegion = getECRRegionFromRegistryInput ( registryInput )
75
77
// EKS sets this environment variable automatically if the controller pod is
76
78
// properly configured with IRSA or EKS Pod Identity, so we can rely on this
77
79
// and communicate this to users since this is controller-level configuration.
@@ -140,8 +142,8 @@ func (p Provider) NewTokenForServiceAccount(ctx context.Context, oidcToken strin
140
142
case o .ArtifactRepository != "" :
141
143
// We can safely ignore the error here, auth.GetToken() has already called
142
144
// ParseArtifactRepository() and validated the repository at this point.
143
- ecrRegion , _ := p .ParseArtifactRepository (o .ArtifactRepository )
144
- stsRegion = ecrRegion
145
+ registryInput , _ := p .ParseArtifactRepository (o .ArtifactRepository )
146
+ stsRegion = getECRRegionFromRegistryInput ( registryInput )
145
147
// In this case we can't rely on IRSA or EKS Pod Identity for the controller
146
148
// pod because this is object-level configuration, so we show a different
147
149
// error message.
@@ -204,20 +206,21 @@ func (p Provider) NewTokenForServiceAccount(ctx context.Context, oidcToken strin
204
206
// It covers both public AWS partitions like amazonaws.com, China partitions like amazonaws.com.cn, and non-public partitions.
205
207
const registryPattern = `([0-9+]*).dkr.ecr(?:-fips)?\.([^/.]*)\.(amazonaws\.com[.cn]*|sc2s\.sgov\.gov|c2s\.ic\.gov|cloud\.adc-e\.uk|csp\.hci\.ic\.gov)`
206
208
209
+ const publicECR = "public.ecr.aws"
210
+
207
211
var registryRegex = regexp .MustCompile (registryPattern )
208
212
209
213
// ParseArtifactRepository implements auth.Provider.
210
- // ParseArtifactRepository returns the ECR region.
214
+ // ParseArtifactRepository returns the ECR region, unless the registry
215
+ // is public.ecr.aws, in which case it returns public.ecr.aws.
211
216
func (Provider ) ParseArtifactRepository (artifactRepository string ) (string , error ) {
212
217
registry , err := auth .GetRegistryFromArtifactRepository (artifactRepository )
213
218
if err != nil {
214
219
return "" , err
215
220
}
216
221
217
- // Region is required to be us-east-1 for public.ecr.aws:
218
- // https://docs.aws.amazon.com/AmazonECR/latest/public/public-registry-auth.html#public-registry-auth-token
219
- if registry == "public.ecr.aws" {
220
- return "us-east-1" , nil
222
+ if registry == publicECR {
223
+ return publicECR , nil
221
224
}
222
225
223
226
parts := registryRegex .FindAllStringSubmatch (registry , - 1 )
@@ -231,36 +234,70 @@ func (Provider) ParseArtifactRepository(artifactRepository string) (string, erro
231
234
return ecrRegion , nil
232
235
}
233
236
237
+ func getECRRegionFromRegistryInput (registryInput string ) string {
238
+ if registryInput == publicECR {
239
+ // Region is required to be us-east-1 for public ECR:
240
+ // https://docs.aws.amazon.com/AmazonECR/latest/public/public-registry-auth.html#public-registry-auth-token
241
+ return "us-east-1"
242
+ }
243
+ return registryInput
244
+ }
245
+
234
246
// NewArtifactRegistryCredentials implements auth.Provider.
235
- func (p Provider ) NewArtifactRegistryCredentials (ctx context.Context , ecrRegion string ,
247
+ func (p Provider ) NewArtifactRegistryCredentials (ctx context.Context , registryInput string ,
236
248
accessToken auth.Token , opts ... auth.Option ) (* auth.ArtifactRegistryCredentials , error ) {
237
249
238
250
var o auth.Options
239
251
o .Apply (opts ... )
240
252
253
+ authTokenFunc := p .impl ().GetAuthorizationToken
254
+ if registryInput == publicECR {
255
+ authTokenFunc = p .impl ().GetPublicAuthorizationToken
256
+ }
257
+
241
258
conf := aws.Config {
242
- Region : ecrRegion ,
259
+ Region : getECRRegionFromRegistryInput ( registryInput ) ,
243
260
Credentials : accessToken .(* Token ).CredentialsProvider (),
244
261
}
245
262
246
263
if hc := o .GetHTTPClient (); hc != nil {
247
264
conf .HTTPClient = hc
248
265
}
249
266
250
- resp , err := p . impl (). GetAuthorizationToken (ctx , conf )
267
+ respAny , err := authTokenFunc (ctx , conf )
251
268
if err != nil {
252
269
return nil , err
253
270
}
254
271
255
272
// Parse the authorization token.
256
- if len (resp .AuthorizationData ) == 0 {
257
- return nil , fmt .Errorf ("no authorization data returned" )
258
- }
259
- tokenResp := resp .AuthorizationData [0 ]
260
- if tokenResp .AuthorizationToken == nil {
261
- return nil , fmt .Errorf ("authorization token is nil" )
273
+ var token string
274
+ var expiresAt time.Time
275
+ switch resp := respAny .(type ) {
276
+ case * ecr.GetAuthorizationTokenOutput :
277
+ if len (resp .AuthorizationData ) == 0 {
278
+ return nil , fmt .Errorf ("no authorization data returned" )
279
+ }
280
+ if resp .AuthorizationData [0 ].AuthorizationToken == nil {
281
+ return nil , fmt .Errorf ("authorization token is nil" )
282
+ }
283
+ if resp .AuthorizationData [0 ].ExpiresAt == nil {
284
+ return nil , fmt .Errorf ("authorization token expiration is nil" )
285
+ }
286
+ token = * resp .AuthorizationData [0 ].AuthorizationToken
287
+ expiresAt = * resp .AuthorizationData [0 ].ExpiresAt
288
+ case * ecrpublic.GetAuthorizationTokenOutput :
289
+ if resp .AuthorizationData == nil {
290
+ return nil , fmt .Errorf ("no authorization data returned" )
291
+ }
292
+ if resp .AuthorizationData .AuthorizationToken == nil {
293
+ return nil , fmt .Errorf ("authorization token is nil" )
294
+ }
295
+ if resp .AuthorizationData .ExpiresAt == nil {
296
+ return nil , fmt .Errorf ("authorization token expiration is nil" )
297
+ }
298
+ token = * resp .AuthorizationData .AuthorizationToken
299
+ expiresAt = * resp .AuthorizationData .ExpiresAt
262
300
}
263
- token := * tokenResp .AuthorizationToken
264
301
b , err := base64 .StdEncoding .DecodeString (token )
265
302
if err != nil {
266
303
return nil , fmt .Errorf ("failed to parse authorization token: %w" , err )
@@ -269,10 +306,6 @@ func (p Provider) NewArtifactRegistryCredentials(ctx context.Context, ecrRegion
269
306
if len (s ) != 2 {
270
307
return nil , fmt .Errorf ("invalid authorization token format" )
271
308
}
272
- var expiresAt time.Time
273
- if exp := tokenResp .ExpiresAt ; exp != nil {
274
- expiresAt = * exp
275
- }
276
309
return & auth.ArtifactRegistryCredentials {
277
310
Authenticator : authn .FromConfig (authn.AuthConfig {
278
311
Username : s [0 ],
0 commit comments