Skip to content

Commit

Permalink
add tests for postgres inline TLS fields
Browse files Browse the repository at this point in the history
  • Loading branch information
fairclothjm committed Aug 7, 2024
1 parent 035330f commit 83748a8
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 63 deletions.
29 changes: 3 additions & 26 deletions helper/testhelpers/postgresql/postgresqlhelper.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"time"

"github.com/hashicorp/vault/helper/testhelpers/certhelpers"
"github.com/hashicorp/vault/sdk/database/helper/connutil"
"github.com/hashicorp/vault/sdk/helper/docker"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
)
Expand Down Expand Up @@ -152,7 +151,7 @@ EOF
return svc.Cleanup, svc.Config.URL().String()
}

sslConfig, err := connectPostgresSSL(
sslConfig := getPostgresSSLConfig(
t,
svc.Config.URL().Host,
sslMode,
Expand Down Expand Up @@ -199,20 +198,13 @@ func prepareTestContainer(t *testing.T, runOpts docker.RunOptions, password stri
return runner, svc.Cleanup, svc.Config.URL().String(), containerID
}

// connectPostgresSSL is used to verify the connection of our test container
// and construct the connection string that is used in tests.
//
// NOTE: The RawQuery component of the url sets the custom sslinline field and
// inlines the certificate material in the sslrootcert, sslcert, and sslkey
// fields. This feature will be removed in a future version of the SDK.
func connectPostgresSSL(t *testing.T, host, sslMode, caCert, clientCert, clientKey string, useFallback bool) (docker.ServiceConfig, error) {
func getPostgresSSLConfig(t *testing.T, host, sslMode, caCert, clientCert, clientKey string, useFallback bool) docker.ServiceConfig {
if useFallback {
// set the first host to a bad address so we can test the fallback logic
host = "localhost:55," + host
}

u := url.URL{}
db := &sql.DB{}

if ok, _ := strconv.ParseBool(os.Getenv(pluginutil.PluginUsePostgresSSLInline)); ok {
// TODO: remove this when we remove the underlying feature in a future SDK version
Expand All @@ -229,12 +221,6 @@ func connectPostgresSSL(t *testing.T, host, sslMode, caCert, clientCert, clientK
"sslkey": {clientKey},
}.Encode(),
}
var err error
db, err = connutil.OpenPostgres("pgx", u.String())
if err != nil {
return nil, err
}
defer db.Close()
} else {
u = url.URL{
Scheme: "postgres",
Expand All @@ -243,18 +229,9 @@ func connectPostgresSSL(t *testing.T, host, sslMode, caCert, clientCert, clientK
Path: "postgres",
RawQuery: url.Values{"sslmode": {sslMode}}.Encode(),
}
var err error
db, err = sql.Open("pgx", u.String())
if err != nil {
return nil, err
}
defer db.Close()
}

if err := db.Ping(); err != nil {
return nil, err
}
return docker.NewServiceURL(u), nil
return docker.NewServiceURL(u)
}

func connectPostgres(password, repo string, useFallback bool) docker.ServiceAdapter {
Expand Down
222 changes: 188 additions & 34 deletions plugins/database/postgresql/postgresql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,23 +96,9 @@ func TestPostgreSQL_InitializeSSLInlineFeatureFlag(t *testing.T) {
t.Setenv(pluginutil.PluginUsePostgresSSLInline, "true")

// Create certificates for postgres authentication
caCert := certhelpers.NewCert(t,
certhelpers.CommonName("ca"),
certhelpers.IsCA(true),
certhelpers.SelfSign(),
)
clientCert := certhelpers.NewCert(t,
certhelpers.CommonName("postgres"),
certhelpers.DNS("localhost"),
certhelpers.Parent(caCert),
)
cleanup, connURL := postgresql.PrepareTestContainerWithSSL(
t,
"verify-ca",
caCert,
clientCert,
false,
)
caCert := certhelpers.NewCert(t, certhelpers.CommonName("ca"), certhelpers.IsCA(true), certhelpers.SelfSign())
clientCert := certhelpers.NewCert(t, certhelpers.CommonName("postgres"), certhelpers.DNS("localhost"), certhelpers.Parent(caCert))
cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, "verify-ca", caCert, clientCert, false)
t.Cleanup(cleanup)

type testCase struct {
Expand Down Expand Up @@ -307,23 +293,188 @@ func TestPostgreSQL_InitializeSSLInline(t *testing.T) {
t.Parallel()

// Create certificates for postgres authentication
caCert := certhelpers.NewCert(t,
certhelpers.CommonName("ca"),
certhelpers.IsCA(true),
certhelpers.SelfSign(),
)
clientCert := certhelpers.NewCert(t,
certhelpers.CommonName("postgres"),
certhelpers.DNS("localhost"),
certhelpers.Parent(caCert),
)
cleanup, connURL := postgresql.PrepareTestContainerWithSSL(
t,
test.sslMode,
caCert,
clientCert,
test.useFallback,
)
caCert := certhelpers.NewCert(t, certhelpers.CommonName("ca"), certhelpers.IsCA(true), certhelpers.SelfSign())
clientCert := certhelpers.NewCert(t, certhelpers.CommonName("postgres"), certhelpers.DNS("localhost"), certhelpers.Parent(caCert))
cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, test.sslMode, caCert, clientCert, test.useFallback)
t.Cleanup(cleanup)

if test.useDSN {
var err error
connURL, err = dbutil.ParseURL(connURL)
if err != nil {
t.Fatal(err)
}
}
connectionDetails := map[string]interface{}{
"connection_url": connURL,
"max_open_connections": 5,
}

req := dbplugin.InitializeRequest{
Config: connectionDetails,
VerifyConnection: true,
}

db := new()
_, err := dbtesting.VerifyInitialize(t, db, req)
if test.wantErr && err == nil {
t.Fatal("expected error, got nil")
} else if test.wantErr && !strings.Contains(err.Error(), test.expectedError) {
t.Fatalf("got: %s, want: %s", err.Error(), test.expectedError)
}

if !test.wantErr && !db.Initialized {
t.Fatal("Database should be initialized")
}

if err := db.Close(); err != nil {
t.Fatalf("err: %s", err)
}
})
}
}

// TestPostgreSQL_InitializeSSL tests that we can successfully authenticate
// with a postgres server via ssl with a URL connection string or DSN (key/value)
// for each ssl mode.
func TestPostgreSQL_InitializeSSL(t *testing.T) {
type testCase struct {
sslMode string
useDSN bool
useFallback bool
wantErr bool
expectedError string
}

tests := map[string]testCase{
"disable sslmode": {
sslMode: "disable",
wantErr: true,
expectedError: "error verifying connection",
},
"allow sslmode": {
sslMode: "allow",
wantErr: false,
},
"prefer sslmode": {
sslMode: "prefer",
wantErr: false,
},
"require sslmode": {
sslMode: "require",
wantErr: false,
},
"verify-ca sslmode": {
sslMode: "verify-ca",
wantErr: false,
},
"verify-full sslmode": {
sslMode: "verify-full",
wantErr: false,
},
"disable sslmode with DSN": {
sslMode: "disable",
useDSN: true,
wantErr: true,
expectedError: "error verifying connection",
},
"allow sslmode with DSN": {
sslMode: "allow",
useDSN: true,
wantErr: false,
},
"prefer sslmode with DSN": {
sslMode: "prefer",
useDSN: true,
wantErr: false,
},
"require sslmode with DSN": {
sslMode: "require",
useDSN: true,
wantErr: false,
},
"verify-ca sslmode with DSN": {
sslMode: "verify-ca",
useDSN: true,
wantErr: false,
},
"verify-full sslmode with DSN": {
sslMode: "verify-full",
useDSN: true,
wantErr: false,
},
"disable sslmode with fallback": {
sslMode: "disable",
useFallback: true,
wantErr: true,
expectedError: "error verifying connection",
},
"allow sslmode with fallback": {
sslMode: "allow",
useFallback: true,
},
"prefer sslmode with fallback": {
sslMode: "prefer",
useFallback: true,
},
"require sslmode with fallback": {
sslMode: "require",
useFallback: true,
},
"verify-ca sslmode with fallback": {
sslMode: "verify-ca",
useFallback: true,
},
"verify-full sslmode with fallback": {
sslMode: "verify-full",
useFallback: true,
},
"disable sslmode with DSN with fallback": {
sslMode: "disable",
useDSN: true,
useFallback: true,
wantErr: true,
expectedError: "error verifying connection",
},
"allow sslmode with DSN with fallback": {
sslMode: "allow",
useDSN: true,
useFallback: true,
wantErr: false,
},
"prefer sslmode with DSN with fallback": {
sslMode: "prefer",
useDSN: true,
useFallback: true,
wantErr: false,
},
"require sslmode with DSN with fallback": {
sslMode: "require",
useDSN: true,
useFallback: true,
wantErr: false,
},
"verify-ca sslmode with DSN with fallback": {
sslMode: "verify-ca",
useDSN: true,
useFallback: true,
wantErr: false,
},
"verify-full sslmode with DSN with fallback": {
sslMode: "verify-full",
useDSN: true,
useFallback: true,
wantErr: false,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
t.Parallel()

// Create certificates for postgres authentication
caCert := certhelpers.NewCert(t, certhelpers.CommonName("ca"), certhelpers.IsCA(true), certhelpers.SelfSign())
clientCert := certhelpers.NewCert(t, certhelpers.CommonName("postgres"), certhelpers.DNS("localhost"), certhelpers.Parent(caCert))
cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, test.sslMode, caCert, clientCert, test.useFallback)
t.Cleanup(cleanup)

if test.useDSN {
Expand All @@ -336,6 +487,9 @@ func TestPostgreSQL_InitializeSSLInline(t *testing.T) {
connectionDetails := map[string]interface{}{
"connection_url": connURL,
"max_open_connections": 5,
"tls_certificate": string(clientCert.CombinedPEM()),
"tls_private_key": string(clientCert.PrivateKeyPEM()),
"tls_ca": string(caCert.CombinedPEM()),
}

req := dbplugin.InitializeRequest{
Expand Down
4 changes: 2 additions & 2 deletions sdk/database/helper/connutil/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ import (
"github.com/jackc/pgx/v4/stdlib"
)

// OpenPostgres parses the connection string and opens a connection to the database.
// openPostgres parses the connection string and opens a connection to the database.
//
// If sslinline is set, strips the connection string of all ssl settings and
// creates a TLS config based on the settings provided, then uses the
Expand All @@ -55,7 +55,7 @@ import (
// expects to source ssl material from the file system.
//
// Deprecated: OpenPostgres will be removed in a future version of the Vault SDK.
func OpenPostgres(driverName, connString string) (*sql.DB, error) {
func openPostgres(driverName, connString string) (*sql.DB, error) {
if ok, _ := strconv.ParseBool(os.Getenv(pluginutil.PluginUsePostgresSSLInline)); !ok {
return nil, fmt.Errorf("failed to open postgres connection with deprecated funtion, set feature flag to enable")
}
Expand Down
6 changes: 5 additions & 1 deletion sdk/database/helper/connutil/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,10 @@ func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, er
if err != nil {
return nil, fmt.Errorf("failed to parse config: %w", err)
}
if config.TLSConfig == nil {
// handle sslmode=disable
config.TLSConfig = &tls.Config{}
}

config.TLSConfig.RootCAs = c.TLSConfig.RootCAs
config.TLSConfig.ClientCAs = c.TLSConfig.ClientCAs
Expand All @@ -249,7 +253,7 @@ func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, er
} else if driverName == dbTypePostgres && os.Getenv(pluginutil.PluginUsePostgresSSLInline) != "" {
var err error
// TODO: remove this deprecated function call in a future SDK version
c.db, err = OpenPostgres(driverName, conn)
c.db, err = openPostgres(driverName, conn)
if err != nil {
return nil, fmt.Errorf("failed to open connection: %w", err)
}
Expand Down

0 comments on commit 83748a8

Please sign in to comment.