diff --git a/internal/certs/certs.go b/internal/certs/certs.go index 506b8860ce..5b51cf8ae0 100644 --- a/internal/certs/certs.go +++ b/internal/certs/certs.go @@ -5,7 +5,6 @@ import ( "crypto/rand" "crypto/rsa" "crypto/sha256" - "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/pem" @@ -14,11 +13,6 @@ import ( "net" "strings" "time" - - "go.uber.org/zap" - "golang.org/x/crypto/acme/autocert" - - "github.com/infrahq/infra/internal/logging" ) func GenerateCertificate(hosts []string, caCert *x509.Certificate, caKey crypto.PrivateKey) (certPEM []byte, keyPEM []byte, err error) { @@ -63,64 +57,6 @@ func GenerateCertificate(hosts []string, caCert *x509.Certificate, caKey crypto. return PEMEncodeCertificate(certBytes), keyBytes, nil } -func SelfSignedOrLetsEncryptCert(manager *autocert.Manager) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { - return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { - ctx := hello.Context() - cert, err := manager.GetCertificate(hello) - if err == nil { - return cert, nil - } - - serverName := hello.ServerName - - if serverName == "" { - serverName = hello.Conn.LocalAddr().String() - } - - certBytes, err := manager.Cache.Get(ctx, serverName+".crt") - if err != nil { - logging.S.Warnf("cert: %s", err) - } - - keyBytes, err := manager.Cache.Get(ctx, serverName+".key") - if err != nil { - logging.S.Warnf("key: %s", err) - } - - // if either cert or key is missing, create it - if certBytes == nil || keyBytes == nil { - ca, caPrivKey, err := newCA() - if err != nil { - return nil, err - } - - certBytes, keyBytes, err = GenerateCertificate([]string{serverName}, ca, caPrivKey) - if err != nil { - return nil, err - } - - if err := manager.Cache.Put(ctx, serverName+".crt", certBytes); err != nil { - return nil, err - } - - if err := manager.Cache.Put(ctx, serverName+".key", keyBytes); err != nil { - return nil, err - } - - logging.L.Info("new server certificate", - zap.String("Server name", serverName), - zap.String("SHA256 fingerprint", Fingerprint(pemDecode(certBytes)))) - } - - keypair, err := tls.X509KeyPair(certBytes, keyBytes) - if err != nil { - return nil, err - } - - return &keypair, nil - } -} - // Fingerprint returns a sha256 checksum of the certificate formatted as // hex pairs separated by colons. This is a common format used by browsers. // The bytes must be the ASN.1 DER form of the x509.Certificate. @@ -130,11 +66,6 @@ func Fingerprint(raw []byte) string { return strings.ToUpper(s) } -func pemDecode(raw []byte) []byte { - block, _ := pem.Decode(raw) - return block.Bytes -} - // PEMEncodeCertificate accepts the bytes of a x509 certificate in ASN.1 DER form // and returns a PEM encoded representation of that certificate. func PEMEncodeCertificate(raw []byte) []byte { @@ -146,39 +77,3 @@ func PEMEncodeCertificate(raw []byte) []byte { func pemEncodePrivateKey(raw []byte) []byte { return pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: raw}) } - -func newCA() (*x509.Certificate, *rsa.PrivateKey, error) { - // Generate a CA to sign self-signed certificates - serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) - if err != nil { - return nil, nil, err - } - - caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096) - if err != nil { - return nil, nil, err - } - - ca := &x509.Certificate{ - SerialNumber: serialNumber, - Subject: pkix.Name{ - Organization: []string{"Infra"}, - }, - NotBefore: time.Now().Add(-5 * time.Minute).UTC(), - NotAfter: time.Now().AddDate(0, 0, 365).UTC(), - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageCertSign, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - IsCA: true, - BasicConstraintsValid: true, - } - - caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey) - if err != nil { - return nil, nil, err - } - - // TODO: is there really no other way to get the Raw field populated? - ca, _ = x509.ParseCertificate(caBytes) - - return ca, caPrivKey, nil -} diff --git a/internal/cmd/server_test.go b/internal/cmd/server_test.go index 107e773494..3d6a7962a0 100644 --- a/internal/cmd/server_test.go +++ b/internal/cmd/server_test.go @@ -142,6 +142,7 @@ tls: caPrivateKey: file:ca.key certificate: testdata/server.crt privateKey: file:server.key + ACME: true keys: - kind: vault @@ -223,6 +224,7 @@ users: CAPrivateKey: "file:ca.key", Certificate: "-----BEGIN CERTIFICATE-----\nnot a real server certificate\n-----END CERTIFICATE-----\n", PrivateKey: "file:server.key", + ACME: true, }, Keys: []server.KeyProvider{ diff --git a/internal/server/server.go b/internal/server/server.go index fd79141148..e1b43b4ee7 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -2,8 +2,6 @@ package server import ( "context" - "crypto/tls" - "crypto/x509" "embed" "errors" "fmt" @@ -11,7 +9,6 @@ import ( "net" "net/http" "net/http/httputil" - "os" "strings" "time" @@ -19,12 +16,10 @@ import ( "github.com/gin-contrib/static" "github.com/gin-gonic/gin" "github.com/infrahq/secrets" - "golang.org/x/crypto/acme/autocert" "golang.org/x/sync/errgroup" "gorm.io/gorm" "github.com/infrahq/infra/internal" - "github.com/infrahq/infra/internal/certs" "github.com/infrahq/infra/internal/cmd/types" "github.com/infrahq/infra/internal/ginutil" "github.com/infrahq/infra/internal/logging" @@ -83,6 +78,11 @@ type TLSOptions struct { CAPrivateKey string Certificate types.StringOrFile PrivateKey string + + // ACME enables automated certificate manangement. When set to true a TLS + // certificate will be requested from Let's Encrypt, which will be cached + // in the TLSCache. + ACME bool } type Server struct { @@ -299,63 +299,6 @@ type routine struct { stop func() } -func tlsConfigFromOptions( - storage map[string]secrets.SecretStorage, - tlsCacheDir string, - opts TLSOptions, -) (*tls.Config, error) { - // TODO: print CA fingerprint when the client can trust that fingerprint - - if opts.Certificate != "" && opts.PrivateKey != "" { - roots, err := x509.SystemCertPool() - if err != nil { - logging.S.Warnf("failed to load TLS roots from system: %v", err) - roots = x509.NewCertPool() - } - - if opts.CA != "" { - if !roots.AppendCertsFromPEM([]byte(opts.CA)) { - logging.S.Warnf("failed to load TLS CA, invalid PEM") - } - } - - key, err := secrets.GetSecret(opts.PrivateKey, storage) - if err != nil { - return nil, fmt.Errorf("failed to load TLS private key: %w", err) - } - - cert, err := tls.X509KeyPair([]byte(opts.Certificate), []byte(key)) - if err != nil { - return nil, fmt.Errorf("failed to load TLS key pair: %w", err) - } - - return &tls.Config{ - MinVersion: tls.VersionTLS12, - // enable HTTP/2 - NextProtos: []string{"h2", "http/1.1"}, - Certificates: []tls.Certificate{cert}, - // enabled optional mTLS - ClientAuth: tls.VerifyClientCertIfGiven, - ClientCAs: roots, - }, nil - } - - if err := os.MkdirAll(tlsCacheDir, 0o700); err != nil { - return nil, fmt.Errorf("create tls cache: %w", err) - } - - manager := &autocert.Manager{ - Prompt: autocert.AcceptTOS, - Cache: autocert.DirCache(tlsCacheDir), - } - tlsConfig := manager.TLSConfig() - tlsConfig.MinVersion = tls.VersionTLS12 - // TODO: enabled optional mTLS when opts.CA is set - tlsConfig.GetCertificate = certs.SelfSignedOrLetsEncryptCert(manager) - - return tlsConfig, nil -} - func (s *Server) getDatabaseDriver() (gorm.Dialector, error) { postgres, err := s.getPostgresConnectionString() if err != nil { diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 131b9b9c0e..4d3a8e2794 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "crypto/tls" - "crypto/x509" "encoding/json" "io/ioutil" "net/http" @@ -204,6 +203,10 @@ func TestServer_Run_UIProxy(t *testing.T) { DBFile: filepath.Join(dir, "sqlite3.db"), UI: UIOptions{Enabled: true}, EnableSignup: true, + TLS: TLSOptions{ + CA: types.StringOrFile(golden.Get(t, "pki/ca.crt")), + CAPrivateKey: string(golden.Get(t, "pki/ca.key")), + }, } assert.NilError(t, opts.UI.ProxyURL.Set(uiSrv.URL)) @@ -489,42 +492,3 @@ func TestServer_PersistSignupUser(t *testing.T) { // retry the authenticated endpoint checkAuthenticated() } - -func TestTLSConfigFromOptions(t *testing.T) { - storage := map[string]secrets.SecretStorage{ - "plaintext": &secrets.PlainSecretProvider{}, - "file": &secrets.FileSecretProvider{}, - } - - ca := golden.Get(t, "pki/ca.crt") - t.Run("user provided certificate", func(t *testing.T) { - opts := TLSOptions{ - CA: types.StringOrFile(ca), - Certificate: types.StringOrFile(golden.Get(t, "pki/localhost.crt")), - PrivateKey: "file:testdata/pki/localhost.key", - } - config, err := tlsConfigFromOptions(storage, t.TempDir(), opts) - assert.NilError(t, err) - - srv := httptest.NewUnstartedServer(noopHandler) - srv.TLS = config - srv.StartTLS() - t.Cleanup(srv.Close) - - roots := x509.NewCertPool() - roots.AppendCertsFromPEM(ca) - client := &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{RootCAs: roots, MinVersion: tls.VersionTLS12}, - }, - } - - resp, err := client.Get(srv.URL) - assert.NilError(t, err) - assert.Equal(t, resp.StatusCode, http.StatusOK) - }) -} - -var noopHandler = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusOK) -}) diff --git a/internal/server/tls.go b/internal/server/tls.go new file mode 100644 index 0000000000..a589b0abac --- /dev/null +++ b/internal/server/tls.go @@ -0,0 +1,171 @@ +package server + +import ( + "context" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "fmt" + "os" + "sync" + + "github.com/infrahq/secrets" + "go.uber.org/zap" + "golang.org/x/crypto/acme/autocert" + + "github.com/infrahq/infra/internal/certs" + "github.com/infrahq/infra/internal/logging" +) + +func tlsConfigFromOptions( + storage map[string]secrets.SecretStorage, + tlsCacheDir string, + opts TLSOptions, +) (*tls.Config, error) { + // TODO: how can we test this? + if opts.ACME { + if err := os.MkdirAll(tlsCacheDir, 0o700); err != nil { + return nil, fmt.Errorf("create tls cache: %w", err) + } + + manager := &autocert.Manager{ + Prompt: autocert.AcceptTOS, + Cache: autocert.DirCache(tlsCacheDir), + // TODO: according to the docs HostPolicy should be set to prevent + // a DoS attack on certificate requests. + } + tlsConfig := manager.TLSConfig() + tlsConfig.MinVersion = tls.VersionTLS12 + return tlsConfig, nil + } + + // TODO: print CA fingerprint when the client can trust that fingerprint + + roots, err := x509.SystemCertPool() + if err != nil { + logging.S.Warnf("failed to load TLS roots from system: %v", err) + roots = x509.NewCertPool() + } + + if opts.CA != "" { + if !roots.AppendCertsFromPEM([]byte(opts.CA)) { + logging.S.Warnf("failed to load TLS CA, invalid PEM") + } + } + + cfg := &tls.Config{ + MinVersion: tls.VersionTLS12, + // enable HTTP/2 + NextProtos: []string{"h2", "http/1.1"}, + // enabled optional mTLS + ClientAuth: tls.VerifyClientCertIfGiven, + ClientCAs: roots, + } + + if opts.Certificate != "" && opts.PrivateKey != "" { + key, err := secrets.GetSecret(opts.PrivateKey, storage) + if err != nil { + return nil, fmt.Errorf("failed to load TLS private key: %w", err) + } + + cert, err := tls.X509KeyPair([]byte(opts.Certificate), []byte(key)) + if err != nil { + return nil, fmt.Errorf("failed to load TLS key pair: %w", err) + } + + cfg.Certificates = []tls.Certificate{cert} + return cfg, nil + } + + if opts.CA == "" || opts.CAPrivateKey == "" { + return nil, fmt.Errorf("either a TLS certificate and key or a TLS CA and key is required") + } + + cfg.GetCertificate = getCertificate(autocert.DirCache(tlsCacheDir), opts) + return cfg, nil +} + +func getCertificate(cache autocert.Cache, opts TLSOptions) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + var lock sync.RWMutex + + getKeyPair := func(ctx context.Context, serverName string) (cert, key []byte) { + certBytes, err := cache.Get(ctx, serverName+".crt") + if err != nil { + logging.S.Warnf("cert: %s", err) + } + + keyBytes, err := cache.Get(ctx, serverName+".key") + if err != nil { + logging.S.Warnf("key: %s", err) + } + return certBytes, keyBytes + } + + return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + ctx := hello.Context() + serverName := hello.ServerName + + if serverName == "" { + serverName = hello.Conn.LocalAddr().String() + } + + lock.RLock() + certBytes, keyBytes := getKeyPair(ctx, serverName) + lock.RUnlock() + + // if either cert or key is missing, create it + if certBytes == nil || keyBytes == nil { + lock.Lock() + // must check again after acquire + certBytes, keyBytes := getKeyPair(ctx, serverName) + if certBytes != nil && keyBytes != nil { + keypair, err := tls.X509KeyPair(certBytes, keyBytes) + lock.Unlock() + if err != nil { + return nil, err + } + return &keypair, nil + } + defer lock.Unlock() + + ca, err := tls.X509KeyPair([]byte(opts.CA), []byte(opts.CAPrivateKey)) + if err != nil { + return nil, err + } + + caCert, err := x509.ParseCertificate(ca.Certificate[0]) + if err != nil { + return nil, err + } + + certBytes, keyBytes, err = certs.GenerateCertificate([]string{serverName}, caCert, ca.PrivateKey) + if err != nil { + return nil, err + } + + if err := cache.Put(ctx, serverName+".crt", certBytes); err != nil { + return nil, err + } + + if err := cache.Put(ctx, serverName+".key", keyBytes); err != nil { + return nil, err + } + + logging.L.Info("new server certificate", + zap.String("Server name", serverName), + zap.String("SHA256 fingerprint", certs.Fingerprint(pemDecode(certBytes)))) + } + + keypair, err := tls.X509KeyPair(certBytes, keyBytes) + if err != nil { + return nil, err + } + + return &keypair, nil + } +} + +func pemDecode(raw []byte) []byte { + block, _ := pem.Decode(raw) + return block.Bytes +} diff --git a/internal/server/tls_test.go b/internal/server/tls_test.go new file mode 100644 index 0000000000..e4d1a764e4 --- /dev/null +++ b/internal/server/tls_test.go @@ -0,0 +1,56 @@ +package server + +import ( + "crypto/tls" + "crypto/x509" + "net/http" + "net/http/httptest" + "testing" + + "github.com/infrahq/secrets" + "gotest.tools/v3/assert" + "gotest.tools/v3/golden" + + "github.com/infrahq/infra/internal/cmd/types" +) + +func TestTLSConfigFromOptions(t *testing.T) { + storage := map[string]secrets.SecretStorage{ + "plaintext": &secrets.PlainSecretProvider{}, + "file": &secrets.FileSecretProvider{}, + } + + ca := golden.Get(t, "pki/ca.crt") + t.Run("user provided certificate", func(t *testing.T) { + opts := TLSOptions{ + CA: types.StringOrFile(ca), + Certificate: types.StringOrFile(golden.Get(t, "pki/localhost.crt")), + PrivateKey: "file:testdata/pki/localhost.key", + } + config, err := tlsConfigFromOptions(storage, t.TempDir(), opts) + assert.NilError(t, err) + + srv := httptest.NewUnstartedServer(noopHandler) + srv.TLS = config + srv.StartTLS() + t.Cleanup(srv.Close) + + roots := x509.NewCertPool() + roots.AppendCertsFromPEM(ca) + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{RootCAs: roots, MinVersion: tls.VersionTLS12}, + }, + } + + resp, err := client.Get(srv.URL) + assert.NilError(t, err) + assert.Equal(t, resp.StatusCode, http.StatusOK) + }) + + // TODO: test cert generated from CA +} + +var noopHandler = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) +})