Skip to content

Commit

Permalink
Add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
nesmabadr committed May 3, 2024
1 parent 1589c46 commit 256feb1
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 8 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package cacertificatehandler_test

import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"encoding/pem"
"fmt"
"os"
"testing"

"github.com/stretchr/testify/assert"

"github.com/kyma-project/runtime-watcher/skr/internal/cacertificatehandler"
"github.com/kyma-project/runtime-watcher/skr/internal/tlstest"
)

func TestGetCertificatePool(t *testing.T) {

Check failure on line 18 in runtime-watcher/internal/cacertificatehandler/ca_certificate_handler_test.go

View workflow job for this annotation

GitHub Actions / lint-build-test

Function TestGetCertificatePool missing the call to method parallel
certPath := "ca.crt"
err := createCaCertFile(certPath)
assert.NoError(t, err)

Check failure on line 21 in runtime-watcher/internal/cacertificatehandler/ca_certificate_handler_test.go

View workflow job for this annotation

GitHub Actions / lint-build-test

require-error: for error assertions use require (testifylint)
defer deleteCaCertFile(certPath)

Check failure on line 22 in runtime-watcher/internal/cacertificatehandler/ca_certificate_handler_test.go

View workflow job for this annotation

GitHub Actions / lint-build-test

Error return value is not checked (errcheck)

got, err := cacertificatehandler.GetCertificatePool(certPath)
assert.NoError(t, err)

Check failure on line 25 in runtime-watcher/internal/cacertificatehandler/ca_certificate_handler_test.go

View workflow job for this annotation

GitHub Actions / lint-build-test

require-error: for error assertions use require (testifylint)
assert.NotNil(t, got)

assert.Equal(t, 2, len(got.Subjects()))

Check failure on line 28 in runtime-watcher/internal/cacertificatehandler/ca_certificate_handler_test.go

View workflow job for this annotation

GitHub Actions / lint-build-test

SA1019: got.Subjects has been deprecated since Go 1.18: if s was returned by [SystemCertPool], Subjects will not include the system roots. (staticcheck)
assert.Contains(t, string(got.Subjects()[0]), "oldCert")

Check failure on line 29 in runtime-watcher/internal/cacertificatehandler/ca_certificate_handler_test.go

View workflow job for this annotation

GitHub Actions / lint-build-test

SA1019: got.Subjects has been deprecated since Go 1.18: if s was returned by [SystemCertPool], Subjects will not include the system roots. (staticcheck)
assert.Contains(t, string(got.Subjects()[1]), "newCert")

Check failure on line 30 in runtime-watcher/internal/cacertificatehandler/ca_certificate_handler_test.go

View workflow job for this annotation

GitHub Actions / lint-build-test

SA1019: got.Subjects has been deprecated since Go 1.18: if s was returned by [SystemCertPool], Subjects will not include the system roots. (staticcheck)
}

func createCaCertFile(certPath string) error {
certFile, err := os.OpenFile(certPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)

Check failure on line 34 in runtime-watcher/internal/cacertificatehandler/ca_certificate_handler_test.go

View workflow job for this annotation

GitHub Actions / lint-build-test

File is not `gofumpt`-ed (gofumpt)
if err != nil {
return fmt.Errorf("failed to create cert file: %w", err)
}
firstCert, err := createCertificate("oldCert")
if err != nil {
return fmt.Errorf("failed to create certificate: %w", err)
}

secondCert, err := createCertificate("newCert")
if err != nil {
return fmt.Errorf("failed to create certificate: %w", err)
}

firstCertBytes := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: firstCert.Certificate[0],
})

secondCertBytes := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: secondCert.Certificate[0],
})

if _, err = certFile.Write(firstCertBytes); err != nil {
return fmt.Errorf("failed to write cert file: %w", err)
}
if _, err = certFile.Write(secondCertBytes); err != nil {
return fmt.Errorf("failed to write cert file: %w", err)
}

return nil
}

func createCertificate(subjectName string) (*tls.Certificate, error) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, fmt.Errorf("failed to create key: %w", err)
}
certTemplate, err := tlstest.CreateCertTemplate(true)
certTemplate.Subject.CommonName = subjectName
if err != nil {
return nil, fmt.Errorf("failed to create cert template: %w", err)
}
cert, err := tlstest.CreateCert(certTemplate, certTemplate, key, key)
if err != nil {
return nil, fmt.Errorf("failed to create cert: %w", err)
}

return cert, nil
}

func deleteCaCertFile(certPath string) error {
err := os.Remove(certPath)
if err != nil {
return fmt.Errorf("failed to delete cert file: %w", err)
}

return nil
}
16 changes: 8 additions & 8 deletions runtime-watcher/internal/tlstest/certificate_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func (p *CertProvider) CleanUp() error {
return p.removeTempFiles()
}

func createCertTemplate(isCA bool) (*x509.Certificate, error) {
func CreateCertTemplate(isCA bool) (*x509.Certificate, error) {
sn, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), certSerialNumberUpperLimit))
if err != nil {
return nil, fmt.Errorf("serial number generation failed: %w", err)
Expand All @@ -120,7 +120,7 @@ func createCertTemplate(isCA bool) (*x509.Certificate, error) {
return template, nil
}

func createCert(template, parent *x509.Certificate, privateKey *rsa.PrivateKey, rootKey *rsa.PrivateKey) (
func CreateCert(template, parent *x509.Certificate, privateKey *rsa.PrivateKey, rootKey *rsa.PrivateKey) (
*tls.Certificate, error,
) {
certBytes, err := x509.CreateCertificate(rand.Reader, template, parent, &privateKey.PublicKey, rootKey)
Expand All @@ -146,11 +146,11 @@ func (p *CertProvider) GenerateCerts() error {
if err != nil {
return fmt.Errorf("%s: %w", errMsgCreatingPrivateKey, err)
}
rootTemplate, err := createCertTemplate(true)
rootTemplate, err := CreateCertTemplate(true)
if err != nil {
return err
}
p.RootCert, err = createCert(rootTemplate, rootTemplate, rootKey, rootKey)
p.RootCert, err = CreateCert(rootTemplate, rootTemplate, rootKey, rootKey)
if err != nil {
return err
}
Expand All @@ -163,12 +163,12 @@ func (p *CertProvider) GenerateCerts() error {
if err != nil {
return fmt.Errorf("%s: %w", errMsgCreatingPrivateKey, err)
}
serverTemplate, err := createCertTemplate(false)
serverTemplate, err := CreateCertTemplate(false)
if err != nil {
return err
}
serverTemplate.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
p.ServerCert, err = createCert(serverTemplate, rootTemplate, serverKey, rootKey)
p.ServerCert, err = CreateCert(serverTemplate, rootTemplate, serverKey, rootKey)
if err != nil {
return err
}
Expand All @@ -177,12 +177,12 @@ func (p *CertProvider) GenerateCerts() error {
if err != nil {
return fmt.Errorf("%s: %w", errMsgCreatingPrivateKey, err)
}
clientTemplate, err := createCertTemplate(false)
clientTemplate, err := CreateCertTemplate(false)
if err != nil {
return err
}
clientTemplate.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}
clientCert, err := createCert(clientTemplate, rootTemplate, clientKey, rootKey)
clientCert, err := CreateCert(clientTemplate, rootTemplate, clientKey, rootKey)
if err != nil {
return err
}
Expand Down

0 comments on commit 256feb1

Please sign in to comment.