Skip to content

Commit

Permalink
refactor PrepareTestContainerWithSSL
Browse files Browse the repository at this point in the history
  • Loading branch information
fairclothjm committed Aug 6, 2024
1 parent 09ede87 commit 035330f
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 36 deletions.
82 changes: 52 additions & 30 deletions helper/testhelpers/postgresql/postgresqlhelper.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@ import (
"fmt"
"net/url"
"os"
"strconv"
"testing"
"time"

"github.com/hashicorp/vault/helper/testhelpers/certhelpers"
"github.com/hashicorp/vault/sdk/database/helper/connutil"
"github.com/hashicorp/vault/sdk/helper/docker"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
)

const (
Expand Down Expand Up @@ -68,7 +71,13 @@ func PrepareTestContainerWithVaultUser(t *testing.T, ctx context.Context) (func(

// PrepareTestContainerWithSSL will setup a test container with SSL enabled so
// that we can test client certificate authentication.
func PrepareTestContainerWithSSL(t *testing.T, ctx context.Context, sslMode string, useFallback bool) (func(), string) {
func PrepareTestContainerWithSSL(
t *testing.T,
sslMode string,
caCert certhelpers.Certificate,
clientCert certhelpers.Certificate,
useFallback bool,
) (func(), string) {
runOpts := defaultRunOpts(t)
runner, err := docker.NewServiceRunner(runOpts)
if err != nil {
Expand All @@ -82,21 +91,11 @@ func PrepareTestContainerWithSSL(t *testing.T, ctx context.Context, sslMode stri
}

// Create certificates for postgres authentication
caCert := certhelpers.NewCert(t,
certhelpers.CommonName("ca"),
certhelpers.IsCA(true),
certhelpers.SelfSign(),
)
serverCert := certhelpers.NewCert(t,
certhelpers.CommonName("server"),
certhelpers.DNS("localhost"),
certhelpers.Parent(caCert),
)
clientCert := certhelpers.NewCert(t,
certhelpers.CommonName("postgres"),
certhelpers.DNS("localhost"),
certhelpers.Parent(caCert),
)

bCtx := docker.NewBuildContext()
bCtx["ca.crt"] = docker.PathContentsFromBytes(caCert.CombinedPEM())
Expand Down Expand Up @@ -133,6 +132,9 @@ EOF
t.Fatalf("failed to copy to container: %v", err)
}

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

// overwrite the postgresql.conf config file with our ssl settings
mustRunCommand(t, ctx, runner, id,
[]string{"bash", "/var/lib/postgresql/pg-conf.sh"})
Expand Down Expand Up @@ -208,28 +210,48 @@ func connectPostgresSSL(t *testing.T, host, sslMode, caCert, clientCert, clientK
// set the first host to a bad address so we can test the fallback logic
host = "localhost:55," + host
}
u := url.URL{
Scheme: "postgres",
User: url.User("postgres"),
Host: host,
Path: "postgres",
RawQuery: url.Values{
"sslmode": {sslMode},
"sslinline": {"true"},
"sslrootcert": {caCert},
"sslcert": {clientCert},
"sslkey": {clientKey},
}.Encode(),
}

// TODO: remove this deprecated function call in a future SDK version
db, err := connutil.OpenPostgres("pgx", u.String())
if err != nil {
return nil, err
u := url.URL{}
db := &sql.DB{}

if ok, _ := strconv.ParseBool(os.Getenv(pluginutil.PluginUsePostgresSSLInline)); ok {
// TODO: remove this when we remove the underlying feature in a future SDK version
u = url.URL{
Scheme: "postgres",
User: url.User("postgres"),
Host: host,
Path: "postgres",
RawQuery: url.Values{
"sslmode": {sslMode},
"sslinline": {"true"},
"sslrootcert": {caCert},
"sslcert": {clientCert},
"sslkey": {clientKey},
}.Encode(),
}
var err error
db, err = connutil.OpenPostgres("pgx", u.String())
if err != nil {
return nil, err
}
defer db.Close()
} else {
u = url.URL{
Scheme: "postgres",
User: url.User("postgres"),
Host: host,
Path: "postgres",
RawQuery: url.Values{"sslmode": {sslMode}}.Encode(),
}
var err error
db, err = sql.Open("pgx", u.String())
if err != nil {
return nil, err
}
defer db.Close()
}
defer db.Close()

if err = db.Ping(); err != nil {
if err := db.Ping(); err != nil {
return nil, err
}
return docker.NewServiceURL(u), nil
Expand Down
48 changes: 42 additions & 6 deletions plugins/database/postgresql/postgresql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"testing"
"time"

"github.com/hashicorp/vault/helper/testhelpers/certhelpers"
"github.com/hashicorp/vault/helper/testhelpers/postgresql"
"github.com/hashicorp/vault/sdk/database/dbplugin/v5"
dbtesting "github.com/hashicorp/vault/sdk/database/dbplugin/v5/testing"
Expand Down Expand Up @@ -86,15 +87,32 @@ func TestPostgreSQL_InitializeMultiHost(t *testing.T) {
}
}

// TestPostgreSQL_InitializeSSLFeatureFlag tests that the VAULT_PLUGIN_USE_POSTGRES_SSLINLINE
// TestPostgreSQL_InitializeSSLInlineFeatureFlag tests that the VAULT_PLUGIN_USE_POSTGRES_SSLINLINE
// flag guards against unwanted usage of the deprecated SSL client authentication path.
// TODO: remove this when we remove the underlying feature in a future SDK version
func TestPostgreSQL_InitializeSSLFeatureFlag(t *testing.T) {
func TestPostgreSQL_InitializeSSLInlineFeatureFlag(t *testing.T) {
// set the flag to true so we can call PrepareTestContainerWithSSL
// which does a validation check on the connection
t.Setenv(pluginutil.PluginUsePostgresSSLInline, "true")

cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, context.Background(), "verify-ca", false)
// Create certificates for postgres authentication
caCert := certhelpers.NewCert(t,
certhelpers.CommonName("ca"),
certhelpers.IsCA(true),
certhelpers.SelfSign(),
)
clientCert := certhelpers.NewCert(t,
certhelpers.CommonName("postgres"),
certhelpers.DNS("localhost"),
certhelpers.Parent(caCert),
)
cleanup, connURL := postgresql.PrepareTestContainerWithSSL(
t,
"verify-ca",
caCert,
clientCert,
false,
)
t.Cleanup(cleanup)

type testCase struct {
Expand Down Expand Up @@ -166,11 +184,11 @@ func TestPostgreSQL_InitializeSSLFeatureFlag(t *testing.T) {
}
}

// TestPostgreSQL_InitializeSSL tests that we can successfully authenticate
// TestPostgreSQL_InitializeSSLInline tests that we can successfully authenticate
// with a postgres server via ssl with a URL connection string or DSN (key/value)
// for each ssl mode.
// TODO: remove this when we remove the underlying feature in a future SDK version
func TestPostgreSQL_InitializeSSL(t *testing.T) {
func TestPostgreSQL_InitializeSSLInline(t *testing.T) {
// required to enable the sslinline custom parsing
t.Setenv(pluginutil.PluginUsePostgresSSLInline, "true")

Expand Down Expand Up @@ -287,7 +305,25 @@ func TestPostgreSQL_InitializeSSL(t *testing.T) {
for name, test := range tests {
t.Run(name, func(t *testing.T) {
t.Parallel()
cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, context.Background(), test.sslMode, test.useFallback)

// Create certificates for postgres authentication
caCert := certhelpers.NewCert(t,
certhelpers.CommonName("ca"),
certhelpers.IsCA(true),
certhelpers.SelfSign(),
)
clientCert := certhelpers.NewCert(t,
certhelpers.CommonName("postgres"),
certhelpers.DNS("localhost"),
certhelpers.Parent(caCert),
)
cleanup, connURL := postgresql.PrepareTestContainerWithSSL(
t,
test.sslMode,
caCert,
clientCert,
test.useFallback,
)
t.Cleanup(cleanup)

if test.useDSN {
Expand Down

0 comments on commit 035330f

Please sign in to comment.