diff --git a/aws/acm.go b/aws/acm.go index 309b41a2..a6018e41 100644 --- a/aws/acm.go +++ b/aws/acm.go @@ -44,16 +44,18 @@ func getACMCertificateSummaries(api acmiface.ACMAPI, filterTag string) ([]*acm.C } acmSummaries := make([]*acm.CertificateSummary, 0) - err := api.ListCertificatesPages(params, func(page *acm.ListCertificatesOutput, lastPage bool) bool { + if err := api.ListCertificatesPages(params, func(page *acm.ListCertificatesOutput, lastPage bool) bool { acmSummaries = append(acmSummaries, page.CertificateSummaryList...) return true - }) + }); err != nil { + return nil, err + } if tag := strings.Split(filterTag, "="); filterTag != "=" && len(tag) == 2 { return filterCertificatesByTag(api, acmSummaries, tag[0], tag[1]) } - return acmSummaries, err + return acmSummaries, nil } func filterCertificatesByTag(api acmiface.ACMAPI, allSummaries []*acm.CertificateSummary, key, value string) ([]*acm.CertificateSummary, error) { diff --git a/aws/acm_test.go b/aws/acm_test.go index 9a48b1b6..ea5ff813 100644 --- a/aws/acm_test.go +++ b/aws/acm_test.go @@ -1,6 +1,7 @@ package aws import ( + "fmt" "testing" "github.com/aws/aws-sdk-go/aws" @@ -14,7 +15,7 @@ type acmExpect struct { ARN string DomainNames []string Chain int - Error error + Error string EmptyList bool } @@ -45,11 +46,11 @@ func TestACM(t *testing.T) { CertificateChain: aws.String(chain), }, }, + nil, ), expect: acmExpect{ ARN: "foobar", DomainNames: []string{"foobar.de"}, - Error: nil, }, }, { @@ -68,16 +69,16 @@ func TestACM(t *testing.T) { Certificate: aws.String(cert), }, }, + nil, ), expect: acmExpect{ ARN: "foobar", DomainNames: []string{"foobar.de"}, - Error: nil, }, }, { msg: "Found one ACM Cert with correct filter tag", - api: fake.NewACMClientWithTags( + api: fake.NewACMClient( acm.ListCertificatesOutput{ CertificateSummaryList: []*acm.CertificateSummary{ { @@ -111,12 +112,11 @@ func TestACM(t *testing.T) { expect: acmExpect{ ARN: "foobar", DomainNames: []string{"foobar.de"}, - Error: nil, }, }, { msg: "ACM Cert with incorrect filter tag should not be found", - api: fake.NewACMClientWithTags( + api: fake.NewACMClient( acm.ListCertificatesOutput{ CertificateSummaryList: []*acm.CertificateSummary{ { @@ -141,7 +141,18 @@ func TestACM(t *testing.T) { EmptyList: true, ARN: "foobar", DomainNames: []string{"foobar.de"}, - Error: nil, + }, + }, + { + msg: "Fail on ListCertificatesPages error", + api: fake.NewACMClient( + acm.ListCertificatesOutput{}, nil, nil, + ).WithListCertificatesPages(func(input *acm.ListCertificatesInput, fn func(p *acm.ListCertificatesOutput, lastPage bool) (shouldContinue bool)) error { + return fmt.Errorf("ListCertificatesPages error") + }), + filterTag: "production=true", + expect: acmExpect{ + Error: "ListCertificatesPages error", }, }, } { @@ -149,15 +160,15 @@ func TestACM(t *testing.T) { provider := newACMCertProvider(ti.api, ti.filterTag) list, err := provider.GetCertificates() - if ti.expect.Error != nil { - require.Equal(t, ti.expect.Error, err) + if ti.expect.Error != "" { + require.EqualError(t, err, ti.expect.Error) + return } else { require.NoError(t, err) } if ti.expect.EmptyList { require.Equal(t, 0, len(list)) - } else { require.Equal(t, 1, len(list)) diff --git a/aws/fake/acm.go b/aws/fake/acm.go index 18e7f967..c9a01cd0 100644 --- a/aws/fake/acm.go +++ b/aws/fake/acm.go @@ -9,25 +9,21 @@ import ( type ACMClient struct { acmiface.ACMAPI - output acm.ListCertificatesOutput - cert map[string]*acm.GetCertificateOutput - tags map[string]*acm.ListTagsForCertificateOutput -} + cert map[string]*acm.GetCertificateOutput + tags map[string]*acm.ListTagsForCertificateOutput -func (m ACMClient) ListCertificates(in *acm.ListCertificatesInput) (*acm.ListCertificatesOutput, error) { - return &m.output, nil + listCertificatesPages func(input *acm.ListCertificatesInput, fn func(p *acm.ListCertificatesOutput, lastPage bool) (shouldContinue bool)) error } -func (m ACMClient) ListCertificatesPages(input *acm.ListCertificatesInput, fn func(p *acm.ListCertificatesOutput, lastPage bool) (shouldContinue bool)) error { - fn(&m.output, true) - return nil +func (m *ACMClient) ListCertificatesPages(input *acm.ListCertificatesInput, fn func(p *acm.ListCertificatesOutput, lastPage bool) (shouldContinue bool)) error { + return m.listCertificatesPages(input, fn) } -func (m ACMClient) GetCertificate(input *acm.GetCertificateInput) (*acm.GetCertificateOutput, error) { +func (m *ACMClient) GetCertificate(input *acm.GetCertificateInput) (*acm.GetCertificateOutput, error) { return m.cert[*input.CertificateArn], nil } -func (m ACMClient) ListTagsForCertificate(in *acm.ListTagsForCertificateInput) (*acm.ListTagsForCertificateOutput, error) { +func (m *ACMClient) ListTagsForCertificate(in *acm.ListTagsForCertificateInput) (*acm.ListTagsForCertificateOutput, error) { if in.CertificateArn == nil { return nil, fmt.Errorf("expected a valid CertificateArn, got: nil") } @@ -35,21 +31,23 @@ func (m ACMClient) ListTagsForCertificate(in *acm.ListTagsForCertificateInput) ( return m.tags[arn], nil } -func NewACMClient(output acm.ListCertificatesOutput, cert map[string]*acm.GetCertificateOutput) ACMClient { - return ACMClient{ - output: output, - cert: cert, - } +func (m *ACMClient) WithListCertificatesPages(f func(input *acm.ListCertificatesInput, fn func(p *acm.ListCertificatesOutput, lastPage bool) (shouldContinue bool)) error) *ACMClient { + m.listCertificatesPages = f + return m } -func NewACMClientWithTags( +func NewACMClient( output acm.ListCertificatesOutput, cert map[string]*acm.GetCertificateOutput, tags map[string]*acm.ListTagsForCertificateOutput, -) ACMClient { - return ACMClient{ - output: output, - cert: cert, - tags: tags, +) *ACMClient { + c := &ACMClient{ + cert: cert, + tags: tags, } + c.WithListCertificatesPages(func(input *acm.ListCertificatesInput, fn func(p *acm.ListCertificatesOutput, lastPage bool) (shouldContinue bool)) error { + fn(&output, true) + return nil + }) + return c }