From c90ae64695ac6b0e91913de11d2b8d8d76bc3a25 Mon Sep 17 00:00:00 2001 From: daqingshu <12083415+daqingshu@users.noreply.github.com> Date: Sat, 20 Nov 2021 20:09:20 +0800 Subject: [PATCH] add support custom tls config --- doc.go | 41 ++++++++++++++++++++++++++++++++ ssl.go | 74 +++++++++++++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 106 insertions(+), 9 deletions(-) diff --git a/doc.go b/doc.go index b57184801..5e6a95f56 100644 --- a/doc.go +++ b/doc.go @@ -68,6 +68,47 @@ Valid values for sslmode are: the server was signed by a trusted CA and the server host name matches the one in the certificate) +For support ssl key in memory, we extend sslmode. For example: + + import ( + "crypto/tls" + "crypto/x509" + "io/ioutil" + "log" + + "github.com/lib/pq" + ) + + func main() { + rootCertPool := x509.NewCertPool() + pem, err := ioutil.ReadFile("ca.crt") + if err != nil { + log.Fatal(err) + } + if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { + log.Fatal("Failed to append PEM.") + } + clientCert := make([]tls.Certificate, 0, 1) + certs, err := tls.LoadX509KeyPair("client1.crt", "client1.key") + if err != nil { + log.Fatal(err) + } + clientCert = append(clientCert, certs) + err = pq.RegisterTLSConfig("custom", &tls.Config{ + RootCAs: rootCertPool, + Certificates: clientCert, + ServerName: "pq.example.com", + }) + if err != nil { + log.Fatal(err) + } + connStr := "host=pq.example.com port=5432 user=user1 dbname=pqgotest password=pqgotest sslmode=custom" + db, err := sql.Open("postgres", connStr) + if err != nil { + log.Fatal(err) + } + } + See http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING for more information about connection string parameters. diff --git a/ssl.go b/ssl.go index 36b61ba45..3d16efc7f 100644 --- a/ssl.go +++ b/ssl.go @@ -3,19 +3,74 @@ package pq import ( "crypto/tls" "crypto/x509" + "fmt" "io/ioutil" "net" "os" "os/user" "path/filepath" "strings" + "sync" ) +// Registry for custom tls.Configs +var ( + tlsConfigLock sync.RWMutex + tlsConfigRegistry map[string]*tls.Config +) + +func RegisterTLSConfig(key string, config *tls.Config) error { + if _, isBool := readBool(key); isBool || strings.ToLower(key) == "require" || strings.ToLower(key) == "verify-ca" || strings.ToLower(key) == "verify-full" || strings.ToLower(key) == "disable" { + return fmt.Errorf("key '%s' is reserved", key) + } + + tlsConfigLock.Lock() + if tlsConfigRegistry == nil { + tlsConfigRegistry = make(map[string]*tls.Config) + } + + tlsConfigRegistry[key] = config + tlsConfigLock.Unlock() + return nil +} + +// DeregisterTLSConfig removes the tls.Config associated with key. +func DeregisterTLSConfig(key string) { + tlsConfigLock.Lock() + if tlsConfigRegistry != nil { + delete(tlsConfigRegistry, key) + } + tlsConfigLock.Unlock() +} + +func getTLSConfigClone(key string) (config *tls.Config) { + tlsConfigLock.RLock() + if v, ok := tlsConfigRegistry[key]; ok { + config = v.Clone() + } + tlsConfigLock.RUnlock() + return +} + +// Returns the bool value of the input. +// The 2nd return value indicates if the input was a valid bool value +func readBool(input string) (value bool, valid bool) { + switch input { + case "1", "true", "TRUE", "True": + return true, true + case "0", "false", "FALSE", "False": + return false, true + } + + // Not a valid bool value + return +} + // ssl generates a function to upgrade a net.Conn based on the "sslmode" and // related settings. The function is nil when no upgrade should take place. func ssl(o values) (func(net.Conn) (net.Conn, error), error) { verifyCaOnly := false - tlsConf := tls.Config{} + tlsConf := &tls.Config{} switch mode := o["sslmode"]; mode { // "require" is the default. case "", "require": @@ -48,7 +103,12 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) { case "disable": return nil, nil default: - return nil, fmterrorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode) + { + tlsConf = getTLSConfigClone(mode) + if tlsConf == nil { + return nil, fmterrorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode) + } + } } // Set Server Name Indication (SNI), if enabled by connection parameters. @@ -61,11 +121,7 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) { tlsConf.ServerName = o["host"] } - err := sslClientCertificates(&tlsConf, o) - if err != nil { - return nil, err - } - err = sslCertificateAuthority(&tlsConf, o) + err := sslClientCertificates(tlsConf, o) if err != nil { return nil, err } @@ -78,9 +134,9 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) { tlsConf.Renegotiation = tls.RenegotiateFreelyAsClient return func(conn net.Conn) (net.Conn, error) { - client := tls.Client(conn, &tlsConf) + client := tls.Client(conn, tlsConf) if verifyCaOnly { - err := sslVerifyCertificateAuthority(client, &tlsConf) + err := sslVerifyCertificateAuthority(client, tlsConf) if err != nil { return nil, err }