From 83748a855950390ca871372cd37381dcc23d2c3a Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Wed, 7 Aug 2024 12:38:05 -0500 Subject: [PATCH] add tests for postgres inline TLS fields --- .../postgresql/postgresqlhelper.go | 29 +-- .../database/postgresql/postgresql_test.go | 222 +++++++++++++++--- sdk/database/helper/connutil/postgres.go | 4 +- sdk/database/helper/connutil/sql.go | 6 +- 4 files changed, 198 insertions(+), 63 deletions(-) diff --git a/helper/testhelpers/postgresql/postgresqlhelper.go b/helper/testhelpers/postgresql/postgresqlhelper.go index 2cbe3147e8754..7229d2127b9f4 100644 --- a/helper/testhelpers/postgresql/postgresqlhelper.go +++ b/helper/testhelpers/postgresql/postgresqlhelper.go @@ -14,7 +14,6 @@ import ( "time" "github.com/hashicorp/vault/helper/testhelpers/certhelpers" - "github.com/hashicorp/vault/sdk/database/helper/connutil" "github.com/hashicorp/vault/sdk/helper/docker" "github.com/hashicorp/vault/sdk/helper/pluginutil" ) @@ -152,7 +151,7 @@ EOF return svc.Cleanup, svc.Config.URL().String() } - sslConfig, err := connectPostgresSSL( + sslConfig := getPostgresSSLConfig( t, svc.Config.URL().Host, sslMode, @@ -199,20 +198,13 @@ func prepareTestContainer(t *testing.T, runOpts docker.RunOptions, password stri return runner, svc.Cleanup, svc.Config.URL().String(), containerID } -// connectPostgresSSL is used to verify the connection of our test container -// and construct the connection string that is used in tests. -// -// NOTE: The RawQuery component of the url sets the custom sslinline field and -// inlines the certificate material in the sslrootcert, sslcert, and sslkey -// fields. This feature will be removed in a future version of the SDK. -func connectPostgresSSL(t *testing.T, host, sslMode, caCert, clientCert, clientKey string, useFallback bool) (docker.ServiceConfig, error) { +func getPostgresSSLConfig(t *testing.T, host, sslMode, caCert, clientCert, clientKey string, useFallback bool) docker.ServiceConfig { if useFallback { // set the first host to a bad address so we can test the fallback logic host = "localhost:55," + host } u := url.URL{} - db := &sql.DB{} if ok, _ := strconv.ParseBool(os.Getenv(pluginutil.PluginUsePostgresSSLInline)); ok { // TODO: remove this when we remove the underlying feature in a future SDK version @@ -229,12 +221,6 @@ func connectPostgresSSL(t *testing.T, host, sslMode, caCert, clientCert, clientK "sslkey": {clientKey}, }.Encode(), } - var err error - db, err = connutil.OpenPostgres("pgx", u.String()) - if err != nil { - return nil, err - } - defer db.Close() } else { u = url.URL{ Scheme: "postgres", @@ -243,18 +229,9 @@ func connectPostgresSSL(t *testing.T, host, sslMode, caCert, clientCert, clientK Path: "postgres", RawQuery: url.Values{"sslmode": {sslMode}}.Encode(), } - var err error - db, err = sql.Open("pgx", u.String()) - if err != nil { - return nil, err - } - defer db.Close() } - if err := db.Ping(); err != nil { - return nil, err - } - return docker.NewServiceURL(u), nil + return docker.NewServiceURL(u) } func connectPostgres(password, repo string, useFallback bool) docker.ServiceAdapter { diff --git a/plugins/database/postgresql/postgresql_test.go b/plugins/database/postgresql/postgresql_test.go index 26f87fefb064e..e9d4efd20e022 100644 --- a/plugins/database/postgresql/postgresql_test.go +++ b/plugins/database/postgresql/postgresql_test.go @@ -96,23 +96,9 @@ func TestPostgreSQL_InitializeSSLInlineFeatureFlag(t *testing.T) { t.Setenv(pluginutil.PluginUsePostgresSSLInline, "true") // Create certificates for postgres authentication - caCert := certhelpers.NewCert(t, - certhelpers.CommonName("ca"), - certhelpers.IsCA(true), - certhelpers.SelfSign(), - ) - clientCert := certhelpers.NewCert(t, - certhelpers.CommonName("postgres"), - certhelpers.DNS("localhost"), - certhelpers.Parent(caCert), - ) - cleanup, connURL := postgresql.PrepareTestContainerWithSSL( - t, - "verify-ca", - caCert, - clientCert, - false, - ) + caCert := certhelpers.NewCert(t, certhelpers.CommonName("ca"), certhelpers.IsCA(true), certhelpers.SelfSign()) + clientCert := certhelpers.NewCert(t, certhelpers.CommonName("postgres"), certhelpers.DNS("localhost"), certhelpers.Parent(caCert)) + cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, "verify-ca", caCert, clientCert, false) t.Cleanup(cleanup) type testCase struct { @@ -307,23 +293,188 @@ func TestPostgreSQL_InitializeSSLInline(t *testing.T) { t.Parallel() // Create certificates for postgres authentication - caCert := certhelpers.NewCert(t, - certhelpers.CommonName("ca"), - certhelpers.IsCA(true), - certhelpers.SelfSign(), - ) - clientCert := certhelpers.NewCert(t, - certhelpers.CommonName("postgres"), - certhelpers.DNS("localhost"), - certhelpers.Parent(caCert), - ) - cleanup, connURL := postgresql.PrepareTestContainerWithSSL( - t, - test.sslMode, - caCert, - clientCert, - test.useFallback, - ) + caCert := certhelpers.NewCert(t, certhelpers.CommonName("ca"), certhelpers.IsCA(true), certhelpers.SelfSign()) + clientCert := certhelpers.NewCert(t, certhelpers.CommonName("postgres"), certhelpers.DNS("localhost"), certhelpers.Parent(caCert)) + cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, test.sslMode, caCert, clientCert, test.useFallback) + t.Cleanup(cleanup) + + if test.useDSN { + var err error + connURL, err = dbutil.ParseURL(connURL) + if err != nil { + t.Fatal(err) + } + } + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + "max_open_connections": 5, + } + + req := dbplugin.InitializeRequest{ + Config: connectionDetails, + VerifyConnection: true, + } + + db := new() + _, err := dbtesting.VerifyInitialize(t, db, req) + if test.wantErr && err == nil { + t.Fatal("expected error, got nil") + } else if test.wantErr && !strings.Contains(err.Error(), test.expectedError) { + t.Fatalf("got: %s, want: %s", err.Error(), test.expectedError) + } + + if !test.wantErr && !db.Initialized { + t.Fatal("Database should be initialized") + } + + if err := db.Close(); err != nil { + t.Fatalf("err: %s", err) + } + }) + } +} + +// TestPostgreSQL_InitializeSSL tests that we can successfully authenticate +// with a postgres server via ssl with a URL connection string or DSN (key/value) +// for each ssl mode. +func TestPostgreSQL_InitializeSSL(t *testing.T) { + type testCase struct { + sslMode string + useDSN bool + useFallback bool + wantErr bool + expectedError string + } + + tests := map[string]testCase{ + "disable sslmode": { + sslMode: "disable", + wantErr: true, + expectedError: "error verifying connection", + }, + "allow sslmode": { + sslMode: "allow", + wantErr: false, + }, + "prefer sslmode": { + sslMode: "prefer", + wantErr: false, + }, + "require sslmode": { + sslMode: "require", + wantErr: false, + }, + "verify-ca sslmode": { + sslMode: "verify-ca", + wantErr: false, + }, + "verify-full sslmode": { + sslMode: "verify-full", + wantErr: false, + }, + "disable sslmode with DSN": { + sslMode: "disable", + useDSN: true, + wantErr: true, + expectedError: "error verifying connection", + }, + "allow sslmode with DSN": { + sslMode: "allow", + useDSN: true, + wantErr: false, + }, + "prefer sslmode with DSN": { + sslMode: "prefer", + useDSN: true, + wantErr: false, + }, + "require sslmode with DSN": { + sslMode: "require", + useDSN: true, + wantErr: false, + }, + "verify-ca sslmode with DSN": { + sslMode: "verify-ca", + useDSN: true, + wantErr: false, + }, + "verify-full sslmode with DSN": { + sslMode: "verify-full", + useDSN: true, + wantErr: false, + }, + "disable sslmode with fallback": { + sslMode: "disable", + useFallback: true, + wantErr: true, + expectedError: "error verifying connection", + }, + "allow sslmode with fallback": { + sslMode: "allow", + useFallback: true, + }, + "prefer sslmode with fallback": { + sslMode: "prefer", + useFallback: true, + }, + "require sslmode with fallback": { + sslMode: "require", + useFallback: true, + }, + "verify-ca sslmode with fallback": { + sslMode: "verify-ca", + useFallback: true, + }, + "verify-full sslmode with fallback": { + sslMode: "verify-full", + useFallback: true, + }, + "disable sslmode with DSN with fallback": { + sslMode: "disable", + useDSN: true, + useFallback: true, + wantErr: true, + expectedError: "error verifying connection", + }, + "allow sslmode with DSN with fallback": { + sslMode: "allow", + useDSN: true, + useFallback: true, + wantErr: false, + }, + "prefer sslmode with DSN with fallback": { + sslMode: "prefer", + useDSN: true, + useFallback: true, + wantErr: false, + }, + "require sslmode with DSN with fallback": { + sslMode: "require", + useDSN: true, + useFallback: true, + wantErr: false, + }, + "verify-ca sslmode with DSN with fallback": { + sslMode: "verify-ca", + useDSN: true, + useFallback: true, + wantErr: false, + }, + "verify-full sslmode with DSN with fallback": { + sslMode: "verify-full", + useDSN: true, + useFallback: true, + wantErr: false, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + + // Create certificates for postgres authentication + caCert := certhelpers.NewCert(t, certhelpers.CommonName("ca"), certhelpers.IsCA(true), certhelpers.SelfSign()) + clientCert := certhelpers.NewCert(t, certhelpers.CommonName("postgres"), certhelpers.DNS("localhost"), certhelpers.Parent(caCert)) + cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, test.sslMode, caCert, clientCert, test.useFallback) t.Cleanup(cleanup) if test.useDSN { @@ -336,6 +487,9 @@ func TestPostgreSQL_InitializeSSLInline(t *testing.T) { connectionDetails := map[string]interface{}{ "connection_url": connURL, "max_open_connections": 5, + "tls_certificate": string(clientCert.CombinedPEM()), + "tls_private_key": string(clientCert.PrivateKeyPEM()), + "tls_ca": string(caCert.CombinedPEM()), } req := dbplugin.InitializeRequest{ diff --git a/sdk/database/helper/connutil/postgres.go b/sdk/database/helper/connutil/postgres.go index 7d96376bd292e..ebb46a6cae30e 100644 --- a/sdk/database/helper/connutil/postgres.go +++ b/sdk/database/helper/connutil/postgres.go @@ -46,7 +46,7 @@ import ( "github.com/jackc/pgx/v4/stdlib" ) -// OpenPostgres parses the connection string and opens a connection to the database. +// openPostgres parses the connection string and opens a connection to the database. // // If sslinline is set, strips the connection string of all ssl settings and // creates a TLS config based on the settings provided, then uses the @@ -55,7 +55,7 @@ import ( // expects to source ssl material from the file system. // // Deprecated: OpenPostgres will be removed in a future version of the Vault SDK. -func OpenPostgres(driverName, connString string) (*sql.DB, error) { +func openPostgres(driverName, connString string) (*sql.DB, error) { if ok, _ := strconv.ParseBool(os.Getenv(pluginutil.PluginUsePostgresSSLInline)); !ok { return nil, fmt.Errorf("failed to open postgres connection with deprecated funtion, set feature flag to enable") } diff --git a/sdk/database/helper/connutil/sql.go b/sdk/database/helper/connutil/sql.go index 8db60334f844b..245667358e047 100644 --- a/sdk/database/helper/connutil/sql.go +++ b/sdk/database/helper/connutil/sql.go @@ -232,6 +232,10 @@ func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, er if err != nil { return nil, fmt.Errorf("failed to parse config: %w", err) } + if config.TLSConfig == nil { + // handle sslmode=disable + config.TLSConfig = &tls.Config{} + } config.TLSConfig.RootCAs = c.TLSConfig.RootCAs config.TLSConfig.ClientCAs = c.TLSConfig.ClientCAs @@ -249,7 +253,7 @@ func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, er } else if driverName == dbTypePostgres && os.Getenv(pluginutil.PluginUsePostgresSSLInline) != "" { var err error // TODO: remove this deprecated function call in a future SDK version - c.db, err = OpenPostgres(driverName, conn) + c.db, err = openPostgres(driverName, conn) if err != nil { return nil, fmt.Errorf("failed to open connection: %w", err) }