Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support custom tls config #1066

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,47 @@ Valid values for sslmode are:
the server was signed by a trusted CA and the server host name
matches the one in the certificate)
For support ssl key in memory, we extend sslmode. For example:
import (
"crypto/tls"
"crypto/x509"
"io/ioutil"
"log"
"github.com/lib/pq"
)
func main() {
rootCertPool := x509.NewCertPool()
pem, err := ioutil.ReadFile("ca.crt")
if err != nil {
log.Fatal(err)
}
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
log.Fatal("Failed to append PEM.")
}
clientCert := make([]tls.Certificate, 0, 1)
certs, err := tls.LoadX509KeyPair("client1.crt", "client1.key")
if err != nil {
log.Fatal(err)
}
clientCert = append(clientCert, certs)
err = pq.RegisterTLSConfig("custom", &tls.Config{
RootCAs: rootCertPool,
Certificates: clientCert,
ServerName: "pq.example.com",
})
if err != nil {
log.Fatal(err)
}
connStr := "host=pq.example.com port=5432 user=user1 dbname=pqgotest password=pqgotest sslmode=custom"
db, err := sql.Open("postgres", connStr)
if err != nil {
log.Fatal(err)
}
}
See http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING
for more information about connection string parameters.
Expand Down
74 changes: 65 additions & 9 deletions ssl.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,74 @@ package pq
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net"
"os"
"os/user"
"path/filepath"
"strings"
"sync"
)

// Registry for custom tls.Configs
var (
tlsConfigLock sync.RWMutex
tlsConfigRegistry map[string]*tls.Config
)

func RegisterTLSConfig(key string, config *tls.Config) error {
if _, isBool := readBool(key); isBool || strings.ToLower(key) == "require" || strings.ToLower(key) == "verify-ca" || strings.ToLower(key) == "verify-full" || strings.ToLower(key) == "disable" {
return fmt.Errorf("key '%s' is reserved", key)
}

tlsConfigLock.Lock()
if tlsConfigRegistry == nil {
tlsConfigRegistry = make(map[string]*tls.Config)
}

tlsConfigRegistry[key] = config
tlsConfigLock.Unlock()
return nil
}

// DeregisterTLSConfig removes the tls.Config associated with key.
func DeregisterTLSConfig(key string) {
tlsConfigLock.Lock()
if tlsConfigRegistry != nil {
delete(tlsConfigRegistry, key)
}
tlsConfigLock.Unlock()
}

func getTLSConfigClone(key string) (config *tls.Config) {
tlsConfigLock.RLock()
if v, ok := tlsConfigRegistry[key]; ok {
config = v.Clone()
}
tlsConfigLock.RUnlock()
return
}

// Returns the bool value of the input.
// The 2nd return value indicates if the input was a valid bool value
func readBool(input string) (value bool, valid bool) {
switch input {
case "1", "true", "TRUE", "True":
return true, true
case "0", "false", "FALSE", "False":
return false, true
}

// Not a valid bool value
return
}

// ssl generates a function to upgrade a net.Conn based on the "sslmode" and
// related settings. The function is nil when no upgrade should take place.
func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
verifyCaOnly := false
tlsConf := tls.Config{}
tlsConf := &tls.Config{}
switch mode := o["sslmode"]; mode {
// "require" is the default.
case "", "require":
Expand Down Expand Up @@ -48,7 +103,12 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
case "disable":
return nil, nil
default:
return nil, fmterrorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode)
{
tlsConf = getTLSConfigClone(mode)
if tlsConf == nil {
return nil, fmterrorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode)
}
}
}

// Set Server Name Indication (SNI), if enabled by connection parameters.
Expand All @@ -61,11 +121,7 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
tlsConf.ServerName = o["host"]
}

err := sslClientCertificates(&tlsConf, o)
if err != nil {
return nil, err
}
err = sslCertificateAuthority(&tlsConf, o)
err := sslClientCertificates(tlsConf, o)
if err != nil {
return nil, err
}
Expand All @@ -78,9 +134,9 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
tlsConf.Renegotiation = tls.RenegotiateFreelyAsClient

return func(conn net.Conn) (net.Conn, error) {
client := tls.Client(conn, &tlsConf)
client := tls.Client(conn, tlsConf)
if verifyCaOnly {
err := sslVerifyCertificateAuthority(client, &tlsConf)
err := sslVerifyCertificateAuthority(client, tlsConf)
if err != nil {
return nil, err
}
Expand Down