Skip to content

Commit a0d4442

Browse files
committed
Fix support for public.ecr.aws
Signed-off-by: Matheus Pimenta <[email protected]>
1 parent 53c7b2d commit a0d4442

File tree

12 files changed

+226
-73
lines changed

12 files changed

+226
-73
lines changed

.github/workflows/integration-aws.yaml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ jobs:
2020
auth-mode:
2121
- node-identity
2222
- workload-identity
23+
fail-fast: false
2324
defaults:
2425
run:
2526
working-directory: ./oci/tests/integration
@@ -38,7 +39,7 @@ jobs:
3839
with:
3940
role-to-assume: arn:aws:iam::${{ secrets.AWS_ACCOUNT_ID }}:role/${{ secrets.OCI_E2E_AWS_ASSUME_ROLE_NAME }}
4041
role-session-name: OCI_GH_Actions
41-
aws-region: ${{ vars.AWS_REGION }}
42+
aws-region: us-east-1
4243
- name: Setup QEMU
4344
uses: docker/setup-qemu-action@29109295f81e9208d7d86ff1c6c12d2833863392 # v3.6.0
4445
- name: Setup Docker Buildx
@@ -56,13 +57,13 @@ jobs:
5657
- name: Run tests
5758
run: . .env && make test-aws
5859
env:
59-
AWS_REGION: ${{ vars.AWS_REGION }}
60-
TF_VAR_cross_region: ${{ vars.OCI_E2E_TF_VAR_cross_region }}
60+
AWS_REGION: us-east-1
61+
TF_VAR_cross_region: us-east-2
6162
TF_VAR_enable_wi: ${{ (matrix.auth-mode == 'workload-identity' && 'true') || 'false' }}
6263
- name: Ensure resource cleanup
6364
if: ${{ always() }}
6465
run: . .env && make destroy-aws
6566
env:
66-
AWS_REGION: ${{ vars.AWS_REGION }}
67-
TF_VAR_cross_region: ${{ vars.OCI_E2E_TF_VAR_cross_region }}
67+
AWS_REGION: us-east-1
68+
TF_VAR_cross_region: us-east-2
6869
TF_VAR_enable_wi: ${{ (matrix.auth-mode == 'workload-identity' && 'true') || 'false' }}

auth/aws/implementation.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,16 @@ import (
2222
"github.com/aws/aws-sdk-go-v2/aws"
2323
"github.com/aws/aws-sdk-go-v2/config"
2424
"github.com/aws/aws-sdk-go-v2/service/ecr"
25+
"github.com/aws/aws-sdk-go-v2/service/ecrpublic"
2526
"github.com/aws/aws-sdk-go-v2/service/sts"
2627
)
2728

2829
// Implementation provides the required methods of the AWS libraries.
2930
type Implementation interface {
3031
LoadDefaultConfig(ctx context.Context, optFns ...func(*config.LoadOptions) error) (aws.Config, error)
3132
AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, options sts.Options) (*sts.AssumeRoleWithWebIdentityOutput, error)
32-
GetAuthorizationToken(ctx context.Context, cfg aws.Config) (*ecr.GetAuthorizationTokenOutput, error)
33+
GetAuthorizationToken(ctx context.Context, cfg aws.Config) (any, error)
34+
GetPublicAuthorizationToken(ctx context.Context, cfg aws.Config) (any, error)
3335
}
3436

3537
type implementation struct{}
@@ -42,6 +44,10 @@ func (implementation) AssumeRoleWithWebIdentity(ctx context.Context, params *sts
4244
return sts.New(options).AssumeRoleWithWebIdentity(ctx, params)
4345
}
4446

45-
func (implementation) GetAuthorizationToken(ctx context.Context, cfg aws.Config) (*ecr.GetAuthorizationTokenOutput, error) {
47+
func (implementation) GetAuthorizationToken(ctx context.Context, cfg aws.Config) (any, error) {
4648
return ecr.NewFromConfig(cfg).GetAuthorizationToken(ctx, &ecr.GetAuthorizationTokenInput{})
4749
}
50+
51+
func (implementation) GetPublicAuthorizationToken(ctx context.Context, cfg aws.Config) (any, error) {
52+
return ecrpublic.NewFromConfig(cfg).GetAuthorizationToken(ctx, &ecrpublic.GetAuthorizationTokenInput{})
53+
}

auth/aws/implementation_test.go

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ import (
2727
"github.com/aws/aws-sdk-go-v2/config"
2828
"github.com/aws/aws-sdk-go-v2/service/ecr"
2929
ecrtypes "github.com/aws/aws-sdk-go-v2/service/ecr/types"
30+
"github.com/aws/aws-sdk-go-v2/service/ecrpublic"
31+
ecrpublictypes "github.com/aws/aws-sdk-go-v2/service/ecrpublic/types"
3032
"github.com/aws/aws-sdk-go-v2/service/sts"
3133
ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types"
3234
. "github.com/onsi/gomega"
@@ -35,6 +37,8 @@ import (
3537
type mockImplementation struct {
3638
t *testing.T
3739

40+
publicECR bool
41+
3842
argRoleARN string
3943
argRoleSessionName string
4044
argOIDCToken string
@@ -101,7 +105,33 @@ func (m *mockImplementation) AssumeRoleWithWebIdentity(ctx context.Context, para
101105
}, nil
102106
}
103107

104-
func (m *mockImplementation) GetAuthorizationToken(ctx context.Context, cfg aws.Config) (*ecr.GetAuthorizationTokenOutput, error) {
108+
func (m *mockImplementation) GetAuthorizationToken(ctx context.Context, cfg aws.Config) (any, error) {
109+
m.t.Helper()
110+
g := NewWithT(m.t)
111+
g.Expect(m.publicECR).To(BeFalse())
112+
m.checkGetAuthorizationToken(ctx, cfg)
113+
return &ecr.GetAuthorizationTokenOutput{
114+
AuthorizationData: []ecrtypes.AuthorizationData{{
115+
AuthorizationToken: aws.String(base64.StdEncoding.EncodeToString([]byte(m.returnUsername + ":" + m.returnPassword))),
116+
ExpiresAt: aws.Time(m.returnCreds.Expires),
117+
}},
118+
}, nil
119+
}
120+
121+
func (m *mockImplementation) GetPublicAuthorizationToken(ctx context.Context, cfg aws.Config) (any, error) {
122+
m.t.Helper()
123+
g := NewWithT(m.t)
124+
g.Expect(m.publicECR).To(BeTrue())
125+
m.checkGetAuthorizationToken(ctx, cfg)
126+
return &ecrpublic.GetAuthorizationTokenOutput{
127+
AuthorizationData: &ecrpublictypes.AuthorizationData{
128+
AuthorizationToken: aws.String(base64.StdEncoding.EncodeToString([]byte(m.returnUsername + ":" + m.returnPassword))),
129+
ExpiresAt: aws.Time(m.returnCreds.Expires),
130+
},
131+
}, nil
132+
}
133+
134+
func (m *mockImplementation) checkGetAuthorizationToken(ctx context.Context, cfg aws.Config) {
105135
m.t.Helper()
106136
g := NewWithT(m.t)
107137
g.Expect(cfg.Region).To(Equal(m.argRegion))
@@ -114,11 +144,6 @@ func (m *mockImplementation) GetAuthorizationToken(ctx context.Context, cfg aws.
114144
proxyURL, err := cfg.HTTPClient.(*http.Client).Transport.(*http.Transport).Proxy(nil)
115145
g.Expect(err).NotTo(HaveOccurred())
116146
g.Expect(proxyURL).To(Equal(m.argProxyURL))
117-
return &ecr.GetAuthorizationTokenOutput{
118-
AuthorizationData: []ecrtypes.AuthorizationData{{
119-
AuthorizationToken: aws.String(base64.StdEncoding.EncodeToString([]byte(m.returnUsername + ":" + m.returnPassword))),
120-
}},
121-
}, nil
122147
}
123148

124149
func (m *mockCredentialsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {

auth/aws/provider.go

Lines changed: 56 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ import (
2929

3030
"github.com/aws/aws-sdk-go-v2/aws"
3131
"github.com/aws/aws-sdk-go-v2/config"
32+
"github.com/aws/aws-sdk-go-v2/service/ecr"
33+
"github.com/aws/aws-sdk-go-v2/service/ecrpublic"
3234
"github.com/aws/aws-sdk-go-v2/service/sts"
3335
"github.com/google/go-containerregistry/pkg/authn"
3436
corev1 "k8s.io/api/core/v1"
@@ -70,8 +72,8 @@ func (p Provider) NewControllerToken(ctx context.Context, opts ...auth.Option) (
7072
case o.ArtifactRepository != "":
7173
// We can safely ignore the error here, auth.GetToken() has already called
7274
// ParseArtifactRepository() and validated the repository at this point.
73-
ecrRegion, _ := p.ParseArtifactRepository(o.ArtifactRepository)
74-
stsRegion = ecrRegion
75+
registryInput, _ := p.ParseArtifactRepository(o.ArtifactRepository)
76+
stsRegion = getECRRegionFromRegistryInput(registryInput)
7577
// EKS sets this environment variable automatically if the controller pod is
7678
// properly configured with IRSA or EKS Pod Identity, so we can rely on this
7779
// and communicate this to users since this is controller-level configuration.
@@ -140,8 +142,8 @@ func (p Provider) NewTokenForServiceAccount(ctx context.Context, oidcToken strin
140142
case o.ArtifactRepository != "":
141143
// We can safely ignore the error here, auth.GetToken() has already called
142144
// ParseArtifactRepository() and validated the repository at this point.
143-
ecrRegion, _ := p.ParseArtifactRepository(o.ArtifactRepository)
144-
stsRegion = ecrRegion
145+
registryInput, _ := p.ParseArtifactRepository(o.ArtifactRepository)
146+
stsRegion = getECRRegionFromRegistryInput(registryInput)
145147
// In this case we can't rely on IRSA or EKS Pod Identity for the controller
146148
// pod because this is object-level configuration, so we show a different
147149
// error message.
@@ -204,20 +206,21 @@ func (p Provider) NewTokenForServiceAccount(ctx context.Context, oidcToken strin
204206
// It covers both public AWS partitions like amazonaws.com, China partitions like amazonaws.com.cn, and non-public partitions.
205207
const registryPattern = `([0-9+]*).dkr.ecr(?:-fips)?\.([^/.]*)\.(amazonaws\.com[.cn]*|sc2s\.sgov\.gov|c2s\.ic\.gov|cloud\.adc-e\.uk|csp\.hci\.ic\.gov)`
206208

209+
const publicECR = "public.ecr.aws"
210+
207211
var registryRegex = regexp.MustCompile(registryPattern)
208212

209213
// ParseArtifactRepository implements auth.Provider.
210-
// ParseArtifactRepository returns the ECR region.
214+
// ParseArtifactRepository returns the ECR region, unless the registry
215+
// is public.ecr.aws, in which case it returns public.ecr.aws.
211216
func (Provider) ParseArtifactRepository(artifactRepository string) (string, error) {
212217
registry, err := auth.GetRegistryFromArtifactRepository(artifactRepository)
213218
if err != nil {
214219
return "", err
215220
}
216221

217-
// Region is required to be us-east-1 for public.ecr.aws:
218-
// https://docs.aws.amazon.com/AmazonECR/latest/public/public-registry-auth.html#public-registry-auth-token
219-
if registry == "public.ecr.aws" {
220-
return "us-east-1", nil
222+
if registry == publicECR {
223+
return publicECR, nil
221224
}
222225

223226
parts := registryRegex.FindAllStringSubmatch(registry, -1)
@@ -231,36 +234,70 @@ func (Provider) ParseArtifactRepository(artifactRepository string) (string, erro
231234
return ecrRegion, nil
232235
}
233236

237+
func getECRRegionFromRegistryInput(registryInput string) string {
238+
if registryInput == publicECR {
239+
// Region is required to be us-east-1 for public ECR:
240+
// https://docs.aws.amazon.com/AmazonECR/latest/public/public-registry-auth.html#public-registry-auth-token
241+
return "us-east-1"
242+
}
243+
return registryInput
244+
}
245+
234246
// NewArtifactRegistryCredentials implements auth.Provider.
235-
func (p Provider) NewArtifactRegistryCredentials(ctx context.Context, ecrRegion string,
247+
func (p Provider) NewArtifactRegistryCredentials(ctx context.Context, registryInput string,
236248
accessToken auth.Token, opts ...auth.Option) (*auth.ArtifactRegistryCredentials, error) {
237249

238250
var o auth.Options
239251
o.Apply(opts...)
240252

253+
authTokenFunc := p.impl().GetAuthorizationToken
254+
if registryInput == publicECR {
255+
authTokenFunc = p.impl().GetPublicAuthorizationToken
256+
}
257+
241258
conf := aws.Config{
242-
Region: ecrRegion,
259+
Region: getECRRegionFromRegistryInput(registryInput),
243260
Credentials: accessToken.(*Token).CredentialsProvider(),
244261
}
245262

246263
if hc := o.GetHTTPClient(); hc != nil {
247264
conf.HTTPClient = hc
248265
}
249266

250-
resp, err := p.impl().GetAuthorizationToken(ctx, conf)
267+
respAny, err := authTokenFunc(ctx, conf)
251268
if err != nil {
252269
return nil, err
253270
}
254271

255272
// Parse the authorization token.
256-
if len(resp.AuthorizationData) == 0 {
257-
return nil, fmt.Errorf("no authorization data returned")
258-
}
259-
tokenResp := resp.AuthorizationData[0]
260-
if tokenResp.AuthorizationToken == nil {
261-
return nil, fmt.Errorf("authorization token is nil")
273+
var token string
274+
var expiresAt time.Time
275+
switch resp := respAny.(type) {
276+
case *ecr.GetAuthorizationTokenOutput:
277+
if len(resp.AuthorizationData) == 0 {
278+
return nil, fmt.Errorf("no authorization data returned")
279+
}
280+
if resp.AuthorizationData[0].AuthorizationToken == nil {
281+
return nil, fmt.Errorf("authorization token is nil")
282+
}
283+
if resp.AuthorizationData[0].ExpiresAt == nil {
284+
return nil, fmt.Errorf("authorization token expiration is nil")
285+
}
286+
token = *resp.AuthorizationData[0].AuthorizationToken
287+
expiresAt = *resp.AuthorizationData[0].ExpiresAt
288+
case *ecrpublic.GetAuthorizationTokenOutput:
289+
if resp.AuthorizationData == nil {
290+
return nil, fmt.Errorf("no authorization data returned")
291+
}
292+
if resp.AuthorizationData.AuthorizationToken == nil {
293+
return nil, fmt.Errorf("authorization token is nil")
294+
}
295+
if resp.AuthorizationData.ExpiresAt == nil {
296+
return nil, fmt.Errorf("authorization token expiration is nil")
297+
}
298+
token = *resp.AuthorizationData.AuthorizationToken
299+
expiresAt = *resp.AuthorizationData.ExpiresAt
262300
}
263-
token := *tokenResp.AuthorizationToken
264301
b, err := base64.StdEncoding.DecodeString(token)
265302
if err != nil {
266303
return nil, fmt.Errorf("failed to parse authorization token: %w", err)
@@ -269,10 +306,6 @@ func (p Provider) NewArtifactRegistryCredentials(ctx context.Context, ecrRegion
269306
if len(s) != 2 {
270307
return nil, fmt.Errorf("invalid authorization token format")
271308
}
272-
var expiresAt time.Time
273-
if exp := tokenResp.ExpiresAt; exp != nil {
274-
expiresAt = *exp
275-
}
276309
return &auth.ArtifactRegistryCredentials{
277310
Authenticator: authn.FromConfig(authn.AuthConfig{
278311
Username: s[0],

auth/aws/provider_test.go

Lines changed: 59 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -225,39 +225,67 @@ func TestProvider_GetIdentity(t *testing.T) {
225225
}
226226

227227
func TestProvider_NewArtifactRegistryCredentials(t *testing.T) {
228-
g := NewWithT(t)
228+
for _, tt := range []struct {
229+
name string
230+
registryInput string
231+
expectedPublicECR bool
232+
expectedRegion string
233+
}{
234+
{
235+
name: "non public ECR",
236+
registryInput: "us-east-1",
237+
expectedRegion: "us-east-1",
238+
expectedPublicECR: false,
239+
},
240+
{
241+
name: "non public ECR",
242+
registryInput: "us-west-2",
243+
expectedRegion: "us-west-2",
244+
expectedPublicECR: false,
245+
},
246+
{
247+
name: "public ECR",
248+
registryInput: "public.ecr.aws",
249+
expectedRegion: "us-east-1", // Public ECR is always us-east-1
250+
expectedPublicECR: true,
251+
},
252+
} {
253+
t.Run(tt.name, func(t *testing.T) {
254+
g := NewWithT(t)
229255

230-
impl := &mockImplementation{
231-
t: t,
232-
argRegion: "us-east-1",
233-
argProxyURL: &url.URL{Scheme: "http", Host: "proxy.example.com"},
234-
argCredsProvider: credentials.NewStaticCredentialsProvider("access-key-id", "secret-access-key", "session-token"),
235-
returnUsername: "username",
236-
returnPassword: "password",
237-
}
256+
impl := &mockImplementation{
257+
t: t,
258+
publicECR: tt.expectedPublicECR,
259+
argRegion: tt.expectedRegion,
260+
argProxyURL: &url.URL{Scheme: "http", Host: "proxy.example.com"},
261+
argCredsProvider: credentials.NewStaticCredentialsProvider("access-key-id", "secret-access-key", "session-token"),
262+
returnUsername: "username",
263+
returnPassword: "password",
264+
}
238265

239-
ecrRegion := "us-east-1"
240-
accessToken := &aws.Token{
241-
Credentials: types.Credentials{
242-
AccessKeyId: awssdk.String("access-key-id"),
243-
SecretAccessKey: awssdk.String("secret-access-key"),
244-
SessionToken: awssdk.String("session-token"),
245-
},
246-
}
247-
opts := []auth.Option{
248-
auth.WithProxyURL(url.URL{Scheme: "http", Host: "proxy.example.com"}),
249-
}
266+
accessToken := &aws.Token{
267+
Credentials: types.Credentials{
268+
AccessKeyId: awssdk.String("access-key-id"),
269+
SecretAccessKey: awssdk.String("secret-access-key"),
270+
SessionToken: awssdk.String("session-token"),
271+
},
272+
}
273+
opts := []auth.Option{
274+
auth.WithProxyURL(url.URL{Scheme: "http", Host: "proxy.example.com"}),
275+
}
250276

251-
provider := aws.Provider{Implementation: impl}
252-
creds, err := provider.NewArtifactRegistryCredentials(
253-
context.Background(), ecrRegion, accessToken, opts...)
254-
g.Expect(err).NotTo(HaveOccurred())
255-
g.Expect(creds).To(Equal(&auth.ArtifactRegistryCredentials{
256-
Authenticator: authn.FromConfig(authn.AuthConfig{
257-
Username: "username",
258-
Password: "password",
259-
}),
260-
}))
277+
provider := aws.Provider{Implementation: impl}
278+
creds, err := provider.NewArtifactRegistryCredentials(
279+
context.Background(), tt.registryInput, accessToken, opts...)
280+
g.Expect(err).NotTo(HaveOccurred())
281+
g.Expect(creds).To(Equal(&auth.ArtifactRegistryCredentials{
282+
Authenticator: authn.FromConfig(authn.AuthConfig{
283+
Username: "username",
284+
Password: "password",
285+
}),
286+
}))
287+
})
288+
}
261289
}
262290

263291
func TestProvider_ParseArtifactRepository(t *testing.T) {
@@ -322,7 +350,7 @@ func TestProvider_ParseArtifactRepository(t *testing.T) {
322350
},
323351
{
324352
artifactRepository: "public.ecr.aws/foo/bar",
325-
expectedRegion: "us-east-1",
353+
expectedRegion: "public.ecr.aws",
326354
expectValid: true,
327355
},
328356
}

auth/go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ require (
1212
github.com/aws/aws-sdk-go-v2/config v1.29.14
1313
github.com/aws/aws-sdk-go-v2/credentials v1.17.67
1414
github.com/aws/aws-sdk-go-v2/service/ecr v1.43.3
15+
github.com/aws/aws-sdk-go-v2/service/ecrpublic v1.33.0
1516
github.com/aws/aws-sdk-go-v2/service/sts v1.33.19
1617
github.com/coreos/go-oidc/v3 v3.14.1
1718
github.com/fluxcd/pkg/cache v0.9.0

0 commit comments

Comments
 (0)