Skip to content

Add option to filter certificates by tag before adding it to LB #658

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 34 additions & 5 deletions aws/acm.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package aws

import (
"crypto/x509"
"strings"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/acm"
Expand All @@ -10,16 +11,17 @@ import (
)

type acmCertificateProvider struct {
api acmiface.ACMAPI
api acmiface.ACMAPI
filterTag string
}

func newACMCertProvider(api acmiface.ACMAPI) certs.CertificatesProvider {
return &acmCertificateProvider{api: api}
func newACMCertProvider(api acmiface.ACMAPI, certFilterTag string) certs.CertificatesProvider {
return &acmCertificateProvider{api: api, filterTag: certFilterTag}
}

// GetCertificates returns a list of AWS ACM certificates
func (p *acmCertificateProvider) GetCertificates() ([]*certs.CertificateSummary, error) {
acmSummaries, err := getACMCertificateSummaries(p.api)
acmSummaries, err := getACMCertificateSummaries(p.api, p.filterTag)
if err != nil {
return nil, err
}
Expand All @@ -34,20 +36,47 @@ func (p *acmCertificateProvider) GetCertificates() ([]*certs.CertificateSummary,
return result, nil
}

func getACMCertificateSummaries(api acmiface.ACMAPI) ([]*acm.CertificateSummary, error) {
func getACMCertificateSummaries(api acmiface.ACMAPI, filterTag string) ([]*acm.CertificateSummary, error) {
params := &acm.ListCertificatesInput{
CertificateStatuses: []*string{
aws.String(acm.CertificateStatusIssued),
},
}
acmSummaries := make([]*acm.CertificateSummary, 0)

err := api.ListCertificatesPages(params, func(page *acm.ListCertificatesOutput, lastPage bool) bool {
acmSummaries = append(acmSummaries, page.CertificateSummaryList...)
return true
})

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we should add a

if filterTag = "" {
return acmSummaries
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could but I believe this case is covered in the if in line 52, because if filterTag == "" then len(tag) != 2, and in this case we will just return acmSummaries in line 56, which is just after the if.

if tag := strings.Split(filterTag, "="); len(tag) == 2 {
return filterCertificatesByTag(api, acmSummaries, tag[0], tag[1])
}

return acmSummaries, err
}

func filterCertificatesByTag(api acmiface.ACMAPI, allSummaries []*acm.CertificateSummary, key, value string) ([]*acm.CertificateSummary, error) {
prodSummaries := make([]*acm.CertificateSummary, 0)
for _, summary := range allSummaries {
in := &acm.ListTagsForCertificateInput{
CertificateArn: summary.CertificateArn,
}
out, err := api.ListTagsForCertificate(in)
if err != nil {
return nil, err
}

for _, tag := range out.Tags {
if *tag.Key == key && *tag.Value == value {
prodSummaries = append(prodSummaries, summary)
}
}
}

return prodSummaries, nil
}

func getCertificateSummaryFromACM(api acmiface.ACMAPI, arn *string) (*certs.CertificateSummary, error) {
params := &acm.GetCertificateInput{CertificateArn: arn}
resp, err := api.GetCertificate(params)
Expand Down
78 changes: 70 additions & 8 deletions aws/acm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,18 @@ type acmExpect struct {
DomainNames []string
Chain int
Error error
EmptyList bool
}

func TestACM(t *testing.T) {
cert := mustRead("acm.txt")
chain := mustRead("chain.txt")

for _, ti := range []struct {
msg string
api acmiface.ACMAPI
expect acmExpect
msg string
api acmiface.ACMAPI
filterTag string
expect acmExpect
}{
{
msg: "Found ACM Cert foobar and a chain",
Expand Down Expand Up @@ -69,9 +71,64 @@ func TestACM(t *testing.T) {
Error: nil,
},
},
{
msg: "Found ACM Cert with correct filter tag",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you like to add tests with several certificates in the list.
For example, "2 certs with corresponding tag" or "1 cert with corresponding tag + 1 cert with not corresponding tag"?

api: fake.NewACMClientWithTags(
acm.ListCertificatesOutput{
CertificateSummaryList: []*acm.CertificateSummary{
{
CertificateArn: aws.String("foobar"),
DomainName: aws.String("foobar.de"),
},
},
},
acm.GetCertificateOutput{
Certificate: aws.String(cert),
},
map[string]*acm.ListTagsForCertificateOutput{
"foobar": {
Tags: []*acm.Tag{{Key: aws.String("production"), Value: aws.String("true")}},
},
},
),
filterTag: "production=true",
expect: acmExpect{
ARN: "foobar",
DomainNames: []string{"foobar.de"},
Error: nil,
},
},
{
msg: "Found ACM Cert with incorrect filter tag",
api: fake.NewACMClientWithTags(
acm.ListCertificatesOutput{
CertificateSummaryList: []*acm.CertificateSummary{
{
CertificateArn: aws.String("foobar"),
DomainName: aws.String("foobar.de"),
},
},
},
acm.GetCertificateOutput{
Certificate: aws.String(cert),
},
map[string]*acm.ListTagsForCertificateOutput{
"foobar": {
Tags: []*acm.Tag{{Key: aws.String("production"), Value: aws.String("false")}},
},
},
),
filterTag: "production=true",
expect: acmExpect{
EmptyList: true,
ARN: "foobar",
DomainNames: []string{"foobar.de"},
Error: nil,
},
},
} {
t.Run(ti.msg, func(t *testing.T) {
provider := newACMCertProvider(ti.api)
provider := newACMCertProvider(ti.api, ti.filterTag)
list, err := provider.GetCertificates()

if ti.expect.Error != nil {
Expand All @@ -80,11 +137,16 @@ func TestACM(t *testing.T) {
require.NoError(t, err)
}

require.Equal(t, 1, len(list))
if ti.expect.EmptyList {
require.Equal(t, 0, len(list))

cert := list[0]
require.Equal(t, ti.expect.ARN, cert.ID())
require.Equal(t, ti.expect.DomainNames, cert.DomainNames())
} else {
require.Equal(t, 1, len(list))

cert := list[0]
require.Equal(t, ti.expect.ARN, cert.ID())
require.Equal(t, ti.expect.DomainNames, cert.DomainNames())
}
})
}
}
8 changes: 4 additions & 4 deletions aws/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,12 +273,12 @@ func (a *Adapter) UpdateManifest(clusterID, vpcID string) (*Adapter, error) {
return a, err
}

func (a *Adapter) NewACMCertificateProvider() certs.CertificatesProvider {
return newACMCertProvider(a.acm)
func (a *Adapter) NewACMCertificateProvider(certFilterTag string) certs.CertificatesProvider {
return newACMCertProvider(a.acm, certFilterTag)
}

func (a *Adapter) NewIAMCertificateProvider() certs.CertificatesProvider {
return newIAMCertProvider(a.iam)
func (a *Adapter) NewIAMCertificateProvider(certFilterTag string) certs.CertificatesProvider {
return newIAMCertProvider(a.iam, certFilterTag)
}

// WithHealthCheckPath returns the receiver adapter after changing the health check path that will be used by
Expand Down
23 changes: 23 additions & 0 deletions aws/fake/acm.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package fake

import (
"fmt"

"github.com/aws/aws-sdk-go/service/acm"
"github.com/aws/aws-sdk-go/service/acm/acmiface"
)
Expand All @@ -9,6 +11,7 @@ type ACMClient struct {
acmiface.ACMAPI
output acm.ListCertificatesOutput
cert acm.GetCertificateOutput
tags map[string]*acm.ListTagsForCertificateOutput
}

func (m ACMClient) ListCertificates(in *acm.ListCertificatesInput) (*acm.ListCertificatesOutput, error) {
Expand All @@ -24,9 +27,29 @@ func (m ACMClient) GetCertificate(input *acm.GetCertificateInput) (*acm.GetCerti
return &m.cert, nil
}

func (m ACMClient) ListTagsForCertificate(in *acm.ListTagsForCertificateInput) (*acm.ListTagsForCertificateOutput, error) {
if in.CertificateArn == nil {
return nil, fmt.Errorf("expected a valid CertificateArn, got: nil")
}
arn := *in.CertificateArn
return m.tags[arn], nil
}

func NewACMClient(output acm.ListCertificatesOutput, cert acm.GetCertificateOutput) ACMClient {
return ACMClient{
output: output,
cert: cert,
}
}

func NewACMClientWithTags(
output acm.ListCertificatesOutput,
cert acm.GetCertificateOutput,
tags map[string]*acm.ListTagsForCertificateOutput,
) ACMClient {
return ACMClient{
output: output,
cert: cert,
tags: tags,
}
}
26 changes: 26 additions & 0 deletions aws/fake/iam.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package fake

import (
"fmt"

"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/iam/iamiface"
)
Expand All @@ -9,6 +11,7 @@ type IAMClient struct {
iamiface.IAMAPI
list iam.ListServerCertificatesOutput
cert iam.GetServerCertificateOutput
tags map[string]*iam.ListServerCertificateTagsOutput
}

func (m IAMClient) ListServerCertificates(*iam.ListServerCertificatesInput) (*iam.ListServerCertificatesOutput, error) {
Expand All @@ -20,6 +23,17 @@ func (m IAMClient) ListServerCertificatesPages(input *iam.ListServerCertificates
return nil
}

func (m IAMClient) ListServerCertificateTags(
in *iam.ListServerCertificateTagsInput,
) (*iam.ListServerCertificateTagsOutput, error) {

if in.ServerCertificateName == nil {
return nil, fmt.Errorf("expected a valid CertificateArn, got: nil")
}
name := *in.ServerCertificateName
return m.tags[name], nil
}

func (m IAMClient) GetServerCertificate(*iam.GetServerCertificateInput) (*iam.GetServerCertificateOutput, error) {
return &m.cert, nil
}
Expand All @@ -30,3 +44,15 @@ func NewIAMClient(list iam.ListServerCertificatesOutput, cert iam.GetServerCerti
cert: cert,
}
}

func NewIAMClientWithTag(
list iam.ListServerCertificatesOutput,
cert iam.GetServerCertificateOutput,
tags map[string]*iam.ListServerCertificateTagsOutput,
) IAMClient {
return IAMClient{
list: list,
cert: cert,
tags: tags,
}
}
35 changes: 32 additions & 3 deletions aws/iam.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package aws

import (
"crypto/x509"
"strings"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/iam"
Expand All @@ -10,11 +11,12 @@ import (
)

type iamCertificateProvider struct {
api iamiface.IAMAPI
api iamiface.IAMAPI
filterTag string
}

func newIAMCertProvider(api iamiface.IAMAPI) certs.CertificatesProvider {
return &iamCertificateProvider{api: api}
func newIAMCertProvider(api iamiface.IAMAPI, filterTag string) certs.CertificatesProvider {
return &iamCertificateProvider{api: api, filterTag: filterTag}
}

// GetCertificates returns a list of AWS IAM certificates
Expand All @@ -25,6 +27,17 @@ func (p *iamCertificateProvider) GetCertificates() ([]*certs.CertificateSummary,
}
list := make([]*certs.CertificateSummary, 0)
for _, o := range serverCertificatesMetadata {
if kv := strings.Split(p.filterTag, "="); len(kv) == 2 {
hasTag, err := certHasTag(p.api, *o.ServerCertificateName, kv[0], kv[1])
if err != nil {
return nil, err
}

if !hasTag {
continue
}
}

certDetail, err := getCertificateSummaryFromIAM(p.api, aws.StringValue(o.ServerCertificateName))
if err != nil {
return nil, err
Expand All @@ -34,6 +47,22 @@ func (p *iamCertificateProvider) GetCertificates() ([]*certs.CertificateSummary,
return list, nil
}

func certHasTag(api iamiface.IAMAPI, certName, key, value string) (bool, error) {
t, err := api.ListServerCertificateTags(&iam.ListServerCertificateTagsInput{
ServerCertificateName: &certName,
})
if err != nil {
return false, err
}
for _, tag := range t.Tags {
if *tag.Key == key && *tag.Value == value {
return true, nil
}
}

return false, nil
}

func getIAMServerCertificateMetadata(api iamiface.IAMAPI) ([]*iam.ServerCertificateMetadata, error) {
params := &iam.ListServerCertificatesInput{
PathPrefix: aws.String("/"),
Expand Down
Loading