Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support multiple certificates for the ca certificate #270

Merged
merged 11 commits into from
May 10, 2024
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package cacertificatehandler

import (
"crypto/x509"
"fmt"
"os"
)

func GetCertificatePool(certPath string) (*x509.CertPool, error) {
certBytes, err := getCertBytes(certPath)
if err != nil {
return nil, err
}
rootCertPool := x509.NewCertPool()
ok := rootCertPool.AppendCertsFromPEM(certBytes)
if !ok {
msg := "failed to append certificate to pool"
return nil, fmt.Errorf("%s :%w", msg, err)
}
return rootCertPool, nil
}

func getCertBytes(certPath string) ([]byte, error) {
certBytes, err := os.ReadFile(certPath)
if err != nil {
msg := "could not load CA certificate"
return nil, fmt.Errorf("%s :%w", msg, err)
}

return certBytes, nil
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package cacertificatehandler_test

import (
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net"
"os"
"testing"
"time"

"github.com/stretchr/testify/require"

"github.com/kyma-project/runtime-watcher/skr/internal/cacertificatehandler"
"github.com/kyma-project/runtime-watcher/skr/internal/tlstest"
)

func TestGetCertificatePool1(t *testing.T) {
t.Parallel()
tests := []struct {
name string
certificateCount int
certPath string
}{
{
name: "certificate pool with one certificate",
certificateCount: 1,
certPath: "ca-1.cert",
},
{
name: "certificate pool with two certificates",
certificateCount: 2,
certPath: "ca-2.cert",
},
}
for _, tt := range tests {
testCase := tt
t.Run(testCase.name, func(t *testing.T) {
t.Parallel()
nesmabadr marked this conversation as resolved.
Show resolved Hide resolved
file, err := os.CreateTemp("", testCase.certPath)
require.NoError(t, err)

err = writeCertificatesToFile(file, testCase.certificateCount)
require.NoError(t, err)

nesmabadr marked this conversation as resolved.
Show resolved Hide resolved
got, err := cacertificatehandler.GetCertificatePool(file.Name())
require.NoError(t, err)
require.False(t, got.Equal(x509.NewCertPool()))

certificates, err := getCertificates(file.Name())
require.NoError(t, err)
err = os.Remove(file.Name())
require.NoError(t, err)
expectedCertPool := x509.NewCertPool()
for _, certificate := range certificates {
expectedCertPool.AddCert(certificate)
}
require.True(t, got.Equal(expectedCertPool))
})
}
}

func getCertificates(certPath string) ([]*x509.Certificate, error) {
caCertBytes, err := os.ReadFile(certPath)
if err != nil {
return nil, fmt.Errorf("could not load CA certificate :%w", err)
}
var certs []*x509.Certificate
remainingCert := caCertBytes
for len(remainingCert) > 0 {
var publicPemBlock *pem.Block
publicPemBlock, remainingCert = pem.Decode(remainingCert)
rootPubCrt, errParse := x509.ParseCertificate(publicPemBlock.Bytes)
if errParse != nil {
msg := "failed to parse public key"
return nil, fmt.Errorf("%s :%w", msg, errParse)
}
certs = append(certs, rootPubCrt)
}

return certs, nil
}

func createCertificate() *x509.Certificate {
sn, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
cert := &x509.Certificate{
SerialNumber: sn,
Subject: pkix.Name{
CommonName: "127.0.0.1",
},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour),
IsCA: true,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}

return cert
}

func writeCertificatesToFile(certFile *os.File, certificateCount int) error {
var certs []byte

for i := 0; i < certificateCount; i++ {
rootKey, err := tlstest.GenerateRootKey()
if err != nil {
return fmt.Errorf("failed to generate root key: %w", err)
}

certificate := createCertificate()
cert, err := tlstest.CreateCert(certificate, certificate, rootKey, rootKey)
if err != nil {
return fmt.Errorf("failed to create certificate: %w", err)
}
certBytes := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: cert.Certificate[0],
})
certs = append(certs, certBytes...)
}

if _, err := certFile.Write(certs); err != nil {
return fmt.Errorf("failed to write certificates to file: %w", err)
}

return nil
}
18 changes: 4 additions & 14 deletions runtime-watcher/internal/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,11 @@ package internal
import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"io"
"net/http"
"os"
"reflect"
"strings"
"time"
Expand All @@ -26,6 +23,7 @@ import (
"github.com/go-logr/logr"
listenerTypes "github.com/kyma-project/runtime-watcher/listener/pkg/types"

"github.com/kyma-project/runtime-watcher/skr/internal/cacertificatehandler"
"github.com/kyma-project/runtime-watcher/skr/internal/requestparser"
"github.com/kyma-project/runtime-watcher/skr/internal/serverconfig"
"github.com/kyma-project/runtime-watcher/skr/internal/watchermetrics"
Expand Down Expand Up @@ -324,19 +322,11 @@ func (h *Handler) getHTTPSClient() (*http.Client, error) {
msg := "could not load tls certificate"
return nil, fmt.Errorf("%s :%w", msg, err)
}
caCertBytes, err := os.ReadFile(h.config.CACertPath)

rootCertPool, err := cacertificatehandler.GetCertificatePool(h.config.CACertPath)
if err != nil {
msg := "could not load CA certificate"
return nil, fmt.Errorf("%s :%w", msg, err)
}
publicPemBlock, _ := pem.Decode(caCertBytes)
rootPubCrt, errParse := x509.ParseCertificate(publicPemBlock.Bytes)
if errParse != nil {
msg := "failed to parse public key"
return nil, fmt.Errorf("%s :%w", msg, errParse)
return nil, fmt.Errorf("failed to get certificate pool:%w", err)
}
rootCertPool := x509.NewCertPool()
rootCertPool.AddCert(rootPubCrt)

httpsClient.Timeout = HTTPTimeout
//nolint:gosec
Expand Down
20 changes: 14 additions & 6 deletions runtime-watcher/internal/tlstest/certificate_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func createCertTemplate(isCA bool) (*x509.Certificate, error) {
return template, nil
}

func createCert(template, parent *x509.Certificate, privateKey *rsa.PrivateKey, rootKey *rsa.PrivateKey) (
func CreateCert(template, parent *x509.Certificate, privateKey *rsa.PrivateKey, rootKey *rsa.PrivateKey) (
*tls.Certificate, error,
) {
certBytes, err := x509.CreateCertificate(rand.Reader, template, parent, &privateKey.PublicKey, rootKey)
Expand All @@ -141,16 +141,24 @@ func createCert(template, parent *x509.Certificate, privateKey *rsa.PrivateKey,
return &cert, nil
}

func (p *CertProvider) GenerateCerts() error {
func GenerateRootKey() (*rsa.PrivateKey, error) {
rootKey, err := rsa.GenerateKey(rand.Reader, privateKeyBits)
if err != nil {
return fmt.Errorf("%s: %w", errMsgCreatingPrivateKey, err)
return nil, fmt.Errorf("%s: %w", errMsgCreatingPrivateKey, err)
}
return rootKey, nil
}

func (p *CertProvider) GenerateCerts() error {
rootKey, err := GenerateRootKey()
if err != nil {
return err
}
rootTemplate, err := createCertTemplate(true)
if err != nil {
return err
}
p.RootCert, err = createCert(rootTemplate, rootTemplate, rootKey, rootKey)
p.RootCert, err = CreateCert(rootTemplate, rootTemplate, rootKey, rootKey)
if err != nil {
return err
}
Expand All @@ -168,7 +176,7 @@ func (p *CertProvider) GenerateCerts() error {
return err
}
serverTemplate.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
p.ServerCert, err = createCert(serverTemplate, rootTemplate, serverKey, rootKey)
p.ServerCert, err = CreateCert(serverTemplate, rootTemplate, serverKey, rootKey)
if err != nil {
return err
}
Expand All @@ -182,7 +190,7 @@ func (p *CertProvider) GenerateCerts() error {
return err
}
clientTemplate.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}
clientCert, err := createCert(clientTemplate, rootTemplate, clientKey, rootKey)
clientCert, err := CreateCert(clientTemplate, rootTemplate, clientKey, rootKey)
if err != nil {
return err
}
Expand Down
Loading