diff --git a/pkg/aws/service.go b/pkg/aws/service.go index 206a6324..f77011a5 100644 --- a/pkg/aws/service.go +++ b/pkg/aws/service.go @@ -22,6 +22,7 @@ type awsClient interface { Sign(input *kms.SignInput) (*kms.SignOutput, error) GetPublicKey(input *kms.GetPublicKeyInput) (*kms.GetPublicKeyOutput, error) Verify(input *kms.VerifyInput) (*kms.VerifyOutput, error) + ListKeys(input *kms.ListKeysInput) (*kms.ListKeysOutput, error) } type metricsProvider interface { @@ -86,6 +87,23 @@ func (s *Service) Get(keyID string) (interface{}, error) { return keyID, nil } +// HealthCheck check kms. +func (s *Service) HealthCheck() error { + var limit int64 + limit = 1 + + result, err := s.client.ListKeys(&kms.ListKeysInput{Limit: &limit}) + if err != nil { + return err + } + + if len(result.Keys) == 0 { + return fmt.Errorf("list of keys are empty") + } + + return nil +} + // ExportPubKeyBytes export public key. func (s *Service) ExportPubKeyBytes(keyURI string) ([]byte, arieskms.KeyType, error) { startTime := time.Now() diff --git a/pkg/aws/service_test.go b/pkg/aws/service_test.go index 9a2dcb4a..1fd32686 100644 --- a/pkg/aws/service_test.go +++ b/pkg/aws/service_test.go @@ -83,6 +83,71 @@ func TestSign(t *testing.T) { }) } +func TestHealthCheck(t *testing.T) { + t.Run("success", func(t *testing.T) { + endpoint := localhost + awsSession, err := session.NewSession(&aws.Config{ + Endpoint: &endpoint, + Region: aws.String("ca"), + CredentialsChainVerboseErrors: aws.Bool(true), + }) + require.NoError(t, err) + + svc := New(awsSession, &mockMetrics{}) + + keyID := "key1" + + svc.client = &mockAWSClient{listKeysFunc: func(input *kms.ListKeysInput) (*kms.ListKeysOutput, error) { + return &kms.ListKeysOutput{ + Keys: []*kms.KeyListEntry{{KeyId: &keyID}}, + }, nil + }} + + err = svc.HealthCheck() + require.NoError(t, err) + }) + + t.Run("failed to list keys", func(t *testing.T) { + endpoint := localhost + awsSession, err := session.NewSession(&aws.Config{ + Endpoint: &endpoint, + Region: aws.String("ca"), + CredentialsChainVerboseErrors: aws.Bool(true), + }) + require.NoError(t, err) + + svc := New(awsSession, &mockMetrics{}) + + svc.client = &mockAWSClient{listKeysFunc: func(input *kms.ListKeysInput) (*kms.ListKeysOutput, error) { + return nil, fmt.Errorf("failed to list keys") + }} + + err = svc.HealthCheck() + require.Error(t, err) + require.Contains(t, err.Error(), "failed to list keys") + }) + + t.Run("empty keys", func(t *testing.T) { + endpoint := localhost + awsSession, err := session.NewSession(&aws.Config{ + Endpoint: &endpoint, + Region: aws.String("ca"), + CredentialsChainVerboseErrors: aws.Bool(true), + }) + require.NoError(t, err) + + svc := New(awsSession, &mockMetrics{}) + + svc.client = &mockAWSClient{listKeysFunc: func(input *kms.ListKeysInput) (*kms.ListKeysOutput, error) { + return &kms.ListKeysOutput{}, nil + }} + + err = svc.HealthCheck() + require.Error(t, err) + require.Contains(t, err.Error(), "list of keys are empty") + }) +} + func TestGet(t *testing.T) { t.Run("success", func(t *testing.T) { endpoint := localhost @@ -230,6 +295,7 @@ type mockAWSClient struct { signFunc func(input *kms.SignInput) (*kms.SignOutput, error) getPublicKeyFunc func(input *kms.GetPublicKeyInput) (*kms.GetPublicKeyOutput, error) verifyFunc func(input *kms.VerifyInput) (*kms.VerifyOutput, error) + listKeysFunc func(input *kms.ListKeysInput) (*kms.ListKeysOutput, error) } func (m *mockAWSClient) Sign(input *kms.SignInput) (*kms.SignOutput, error) { @@ -256,6 +322,14 @@ func (m *mockAWSClient) Verify(input *kms.VerifyInput) (*kms.VerifyOutput, error return nil, nil } +func (m *mockAWSClient) ListKeys(input *kms.ListKeysInput) (*kms.ListKeysOutput, error) { + if m.listKeysFunc != nil { + return m.listKeysFunc(input) + } + + return nil, nil +} + type mockMetrics struct{} func (m *mockMetrics) SignCount() {