Skip to content

Commit f2f28dc

Browse files
authored
Add option to filter certificates by tag before adding it to LB (#658)
* Add option to filter certificates by tag before adding it to LB Signed-off-by: Lucas Thiesen <[email protected]>
1 parent 04144f2 commit f2f28dc

File tree

8 files changed

+310
-30
lines changed

8 files changed

+310
-30
lines changed

aws/acm.go

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package aws
22

33
import (
44
"crypto/x509"
5+
"strings"
56

67
"github.com/aws/aws-sdk-go/aws"
78
"github.com/aws/aws-sdk-go/service/acm"
@@ -10,16 +11,17 @@ import (
1011
)
1112

1213
type acmCertificateProvider struct {
13-
api acmiface.ACMAPI
14+
api acmiface.ACMAPI
15+
filterTag string
1416
}
1517

16-
func newACMCertProvider(api acmiface.ACMAPI) certs.CertificatesProvider {
17-
return &acmCertificateProvider{api: api}
18+
func newACMCertProvider(api acmiface.ACMAPI, certFilterTag string) certs.CertificatesProvider {
19+
return &acmCertificateProvider{api: api, filterTag: certFilterTag}
1820
}
1921

2022
// GetCertificates returns a list of AWS ACM certificates
2123
func (p *acmCertificateProvider) GetCertificates() ([]*certs.CertificateSummary, error) {
22-
acmSummaries, err := getACMCertificateSummaries(p.api)
24+
acmSummaries, err := getACMCertificateSummaries(p.api, p.filterTag)
2325
if err != nil {
2426
return nil, err
2527
}
@@ -34,20 +36,47 @@ func (p *acmCertificateProvider) GetCertificates() ([]*certs.CertificateSummary,
3436
return result, nil
3537
}
3638

37-
func getACMCertificateSummaries(api acmiface.ACMAPI) ([]*acm.CertificateSummary, error) {
39+
func getACMCertificateSummaries(api acmiface.ACMAPI, filterTag string) ([]*acm.CertificateSummary, error) {
3840
params := &acm.ListCertificatesInput{
3941
CertificateStatuses: []*string{
4042
aws.String(acm.CertificateStatusIssued),
4143
},
4244
}
4345
acmSummaries := make([]*acm.CertificateSummary, 0)
46+
4447
err := api.ListCertificatesPages(params, func(page *acm.ListCertificatesOutput, lastPage bool) bool {
4548
acmSummaries = append(acmSummaries, page.CertificateSummaryList...)
4649
return true
4750
})
51+
52+
if tag := strings.Split(filterTag, "="); filterTag != "=" && len(tag) == 2 {
53+
return filterCertificatesByTag(api, acmSummaries, tag[0], tag[1])
54+
}
55+
4856
return acmSummaries, err
4957
}
5058

59+
func filterCertificatesByTag(api acmiface.ACMAPI, allSummaries []*acm.CertificateSummary, key, value string) ([]*acm.CertificateSummary, error) {
60+
prodSummaries := make([]*acm.CertificateSummary, 0)
61+
for _, summary := range allSummaries {
62+
in := &acm.ListTagsForCertificateInput{
63+
CertificateArn: summary.CertificateArn,
64+
}
65+
out, err := api.ListTagsForCertificate(in)
66+
if err != nil {
67+
return nil, err
68+
}
69+
70+
for _, tag := range out.Tags {
71+
if *tag.Key == key && *tag.Value == value {
72+
prodSummaries = append(prodSummaries, summary)
73+
}
74+
}
75+
}
76+
77+
return prodSummaries, nil
78+
}
79+
5180
func getCertificateSummaryFromACM(api acmiface.ACMAPI, arn *string) (*certs.CertificateSummary, error) {
5281
params := &acm.GetCertificateInput{CertificateArn: arn}
5382
resp, err := api.GetCertificate(params)

aws/acm_test.go

Lines changed: 93 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,18 @@ type acmExpect struct {
1515
DomainNames []string
1616
Chain int
1717
Error error
18+
EmptyList bool
1819
}
1920

2021
func TestACM(t *testing.T) {
2122
cert := mustRead("acm.txt")
2223
chain := mustRead("chain.txt")
2324

2425
for _, ti := range []struct {
25-
msg string
26-
api acmiface.ACMAPI
27-
expect acmExpect
26+
msg string
27+
api acmiface.ACMAPI
28+
filterTag string
29+
expect acmExpect
2830
}{
2931
{
3032
msg: "Found ACM Cert foobar and a chain",
@@ -37,9 +39,11 @@ func TestACM(t *testing.T) {
3739
},
3840
},
3941
},
40-
acm.GetCertificateOutput{
41-
Certificate: aws.String(cert),
42-
CertificateChain: aws.String(chain),
42+
map[string]*acm.GetCertificateOutput{
43+
"foobar": {
44+
Certificate: aws.String(cert),
45+
CertificateChain: aws.String(chain),
46+
},
4347
},
4448
),
4549
expect: acmExpect{
@@ -59,19 +63,90 @@ func TestACM(t *testing.T) {
5963
},
6064
},
6165
},
62-
acm.GetCertificateOutput{
63-
Certificate: aws.String(cert),
66+
map[string]*acm.GetCertificateOutput{
67+
"foobar": {
68+
Certificate: aws.String(cert),
69+
},
70+
},
71+
),
72+
expect: acmExpect{
73+
ARN: "foobar",
74+
DomainNames: []string{"foobar.de"},
75+
Error: nil,
76+
},
77+
},
78+
{
79+
msg: "Found one ACM Cert with correct filter tag",
80+
api: fake.NewACMClientWithTags(
81+
acm.ListCertificatesOutput{
82+
CertificateSummaryList: []*acm.CertificateSummary{
83+
{
84+
CertificateArn: aws.String("foobar"),
85+
DomainName: aws.String("foobar.de"),
86+
},
87+
{
88+
CertificateArn: aws.String("foobaz"),
89+
DomainName: aws.String("foobar.de"),
90+
},
91+
},
92+
},
93+
map[string]*acm.GetCertificateOutput{
94+
"foobar": {
95+
Certificate: aws.String(cert),
96+
},
97+
"foobaz": {
98+
Certificate: aws.String(cert),
99+
},
100+
},
101+
map[string]*acm.ListTagsForCertificateOutput{
102+
"foobar": {
103+
Tags: []*acm.Tag{{Key: aws.String("production"), Value: aws.String("true")}},
104+
},
105+
"foobaz": {
106+
Tags: []*acm.Tag{{Key: aws.String("production"), Value: aws.String("false")}},
107+
},
108+
},
109+
),
110+
filterTag: "production=true",
111+
expect: acmExpect{
112+
ARN: "foobar",
113+
DomainNames: []string{"foobar.de"},
114+
Error: nil,
115+
},
116+
},
117+
{
118+
msg: "ACM Cert with incorrect filter tag should not be found",
119+
api: fake.NewACMClientWithTags(
120+
acm.ListCertificatesOutput{
121+
CertificateSummaryList: []*acm.CertificateSummary{
122+
{
123+
CertificateArn: aws.String("foobar"),
124+
DomainName: aws.String("foobar.de"),
125+
},
126+
},
127+
},
128+
map[string]*acm.GetCertificateOutput{
129+
"foobar": {
130+
Certificate: aws.String(cert),
131+
},
132+
},
133+
map[string]*acm.ListTagsForCertificateOutput{
134+
"foobar": {
135+
Tags: []*acm.Tag{{Key: aws.String("production"), Value: aws.String("false")}},
136+
},
64137
},
65138
),
139+
filterTag: "production=true",
66140
expect: acmExpect{
141+
EmptyList: true,
67142
ARN: "foobar",
68143
DomainNames: []string{"foobar.de"},
69144
Error: nil,
70145
},
71146
},
72147
} {
73148
t.Run(ti.msg, func(t *testing.T) {
74-
provider := newACMCertProvider(ti.api)
149+
provider := newACMCertProvider(ti.api, ti.filterTag)
75150
list, err := provider.GetCertificates()
76151

77152
if ti.expect.Error != nil {
@@ -80,11 +155,16 @@ func TestACM(t *testing.T) {
80155
require.NoError(t, err)
81156
}
82157

83-
require.Equal(t, 1, len(list))
158+
if ti.expect.EmptyList {
159+
require.Equal(t, 0, len(list))
84160

85-
cert := list[0]
86-
require.Equal(t, ti.expect.ARN, cert.ID())
87-
require.Equal(t, ti.expect.DomainNames, cert.DomainNames())
161+
} else {
162+
require.Equal(t, 1, len(list))
163+
164+
cert := list[0]
165+
require.Equal(t, ti.expect.ARN, cert.ID())
166+
require.Equal(t, ti.expect.DomainNames, cert.DomainNames())
167+
}
88168
})
89169
}
90170
}

aws/adapter.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -273,12 +273,12 @@ func (a *Adapter) UpdateManifest(clusterID, vpcID string) (*Adapter, error) {
273273
return a, err
274274
}
275275

276-
func (a *Adapter) NewACMCertificateProvider() certs.CertificatesProvider {
277-
return newACMCertProvider(a.acm)
276+
func (a *Adapter) NewACMCertificateProvider(certFilterTag string) certs.CertificatesProvider {
277+
return newACMCertProvider(a.acm, certFilterTag)
278278
}
279279

280-
func (a *Adapter) NewIAMCertificateProvider() certs.CertificatesProvider {
281-
return newIAMCertProvider(a.iam)
280+
func (a *Adapter) NewIAMCertificateProvider(certFilterTag string) certs.CertificatesProvider {
281+
return newIAMCertProvider(a.iam, certFilterTag)
282282
}
283283

284284
// WithHealthCheckPath returns the receiver adapter after changing the health check path that will be used by

aws/fake/acm.go

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
package fake
22

33
import (
4+
"fmt"
5+
46
"github.com/aws/aws-sdk-go/service/acm"
57
"github.com/aws/aws-sdk-go/service/acm/acmiface"
68
)
79

810
type ACMClient struct {
911
acmiface.ACMAPI
1012
output acm.ListCertificatesOutput
11-
cert acm.GetCertificateOutput
13+
cert map[string]*acm.GetCertificateOutput
14+
tags map[string]*acm.ListTagsForCertificateOutput
1215
}
1316

1417
func (m ACMClient) ListCertificates(in *acm.ListCertificatesInput) (*acm.ListCertificatesOutput, error) {
@@ -21,12 +24,32 @@ func (m ACMClient) ListCertificatesPages(input *acm.ListCertificatesInput, fn fu
2124
}
2225

2326
func (m ACMClient) GetCertificate(input *acm.GetCertificateInput) (*acm.GetCertificateOutput, error) {
24-
return &m.cert, nil
27+
return m.cert[*input.CertificateArn], nil
28+
}
29+
30+
func (m ACMClient) ListTagsForCertificate(in *acm.ListTagsForCertificateInput) (*acm.ListTagsForCertificateOutput, error) {
31+
if in.CertificateArn == nil {
32+
return nil, fmt.Errorf("expected a valid CertificateArn, got: nil")
33+
}
34+
arn := *in.CertificateArn
35+
return m.tags[arn], nil
36+
}
37+
38+
func NewACMClient(output acm.ListCertificatesOutput, cert map[string]*acm.GetCertificateOutput) ACMClient {
39+
return ACMClient{
40+
output: output,
41+
cert: cert,
42+
}
2543
}
2644

27-
func NewACMClient(output acm.ListCertificatesOutput, cert acm.GetCertificateOutput) ACMClient {
45+
func NewACMClientWithTags(
46+
output acm.ListCertificatesOutput,
47+
cert map[string]*acm.GetCertificateOutput,
48+
tags map[string]*acm.ListTagsForCertificateOutput,
49+
) ACMClient {
2850
return ACMClient{
2951
output: output,
3052
cert: cert,
53+
tags: tags,
3154
}
3255
}

aws/fake/iam.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package fake
22

33
import (
4+
"fmt"
5+
46
"github.com/aws/aws-sdk-go/service/iam"
57
"github.com/aws/aws-sdk-go/service/iam/iamiface"
68
)
@@ -9,6 +11,7 @@ type IAMClient struct {
911
iamiface.IAMAPI
1012
list iam.ListServerCertificatesOutput
1113
cert iam.GetServerCertificateOutput
14+
tags map[string]*iam.ListServerCertificateTagsOutput
1215
}
1316

1417
func (m IAMClient) ListServerCertificates(*iam.ListServerCertificatesInput) (*iam.ListServerCertificatesOutput, error) {
@@ -20,6 +23,17 @@ func (m IAMClient) ListServerCertificatesPages(input *iam.ListServerCertificates
2023
return nil
2124
}
2225

26+
func (m IAMClient) ListServerCertificateTags(
27+
in *iam.ListServerCertificateTagsInput,
28+
) (*iam.ListServerCertificateTagsOutput, error) {
29+
30+
if in.ServerCertificateName == nil {
31+
return nil, fmt.Errorf("expected a valid CertificateArn, got: nil")
32+
}
33+
name := *in.ServerCertificateName
34+
return m.tags[name], nil
35+
}
36+
2337
func (m IAMClient) GetServerCertificate(*iam.GetServerCertificateInput) (*iam.GetServerCertificateOutput, error) {
2438
return &m.cert, nil
2539
}
@@ -30,3 +44,15 @@ func NewIAMClient(list iam.ListServerCertificatesOutput, cert iam.GetServerCerti
3044
cert: cert,
3145
}
3246
}
47+
48+
func NewIAMClientWithTag(
49+
list iam.ListServerCertificatesOutput,
50+
cert iam.GetServerCertificateOutput,
51+
tags map[string]*iam.ListServerCertificateTagsOutput,
52+
) IAMClient {
53+
return IAMClient{
54+
list: list,
55+
cert: cert,
56+
tags: tags,
57+
}
58+
}

0 commit comments

Comments
 (0)