Skip to content

Commit

Permalink
Cassandra: Refactor PEM parsing logic (#11861) (#11921)
Browse files Browse the repository at this point in the history
* Refactor TLS parsing

The ParsePEMBundle and ParsePKIJSON functions in the certutil package assumes
both a client certificate and a custom CA are specified. Cassandra needs to
allow for either a client certificate, a custom CA, or both. This revamps the
parsing of pem_json and pem_bundle to accomodate for any of these configurations
  • Loading branch information
pcman312 authored Jun 23, 2021
1 parent 6ef3e6e commit 6a403a6
Show file tree
Hide file tree
Showing 17 changed files with 458 additions and 1,461 deletions.
3 changes: 3 additions & 0 deletions changelog/11861.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:bug
secrets/database/cassandra: Fixed issue where the PEM parsing logic of `pem_bundle` and `pem_json` didn't work for CA-only configurations
```
53 changes: 42 additions & 11 deletions helper/testhelpers/cassandra/cassandrahelper.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,34 @@ import (
)

type containerConfig struct {
version string
copyFromTo map[string]string
sslOpts *gocql.SslOptions
containerName string
imageName string
version string
copyFromTo map[string]string
env []string

sslOpts *gocql.SslOptions
}

type ContainerOpt func(*containerConfig)

func ContainerName(name string) ContainerOpt {
return func(cfg *containerConfig) {
cfg.containerName = name
}
}

func Image(imageName string, version string) ContainerOpt {
return func(cfg *containerConfig) {
cfg.imageName = imageName
cfg.version = version

// Reset the environment because there's a very good chance the default environment doesn't apply to the
// non-default image being used
cfg.env = nil
}
}

func Version(version string) ContainerOpt {
return func(cfg *containerConfig) {
cfg.version = version
Expand All @@ -33,6 +54,12 @@ func CopyFromTo(copyFromTo map[string]string) ContainerOpt {
}
}

func Env(keyValue string) ContainerOpt {
return func(cfg *containerConfig) {
cfg.env = append(cfg.env, keyValue)
}
}

func SslOpts(sslOpts *gocql.SslOptions) ContainerOpt {
return func(cfg *containerConfig) {
cfg.sslOpts = sslOpts
Expand Down Expand Up @@ -63,7 +90,9 @@ func PrepareTestContainer(t *testing.T, opts ...ContainerOpt) (Host, func()) {
}

containerCfg := &containerConfig{
version: "3.11",
imageName: "cassandra",
version: "3.11",
env: []string{"CASSANDRA_BROADCAST_ADDRESS=127.0.0.1"},
}

for _, opt := range opts {
Expand All @@ -79,13 +108,15 @@ func PrepareTestContainer(t *testing.T, opts ...ContainerOpt) (Host, func()) {
copyFromTo[absFrom] = to
}

runner, err := docker.NewServiceRunner(docker.RunOptions{
ImageRepo: "cassandra",
ImageTag: containerCfg.version,
Ports: []string{"9042/tcp"},
CopyFromTo: copyFromTo,
Env: []string{"CASSANDRA_BROADCAST_ADDRESS=127.0.0.1"},
})
runOpts := docker.RunOptions{
ContainerName: containerCfg.containerName,
ImageRepo: containerCfg.imageName,
ImageTag: containerCfg.version,
Ports: []string{"9042/tcp"},
CopyFromTo: copyFromTo,
Env: containerCfg.env,
}
runner, err := docker.NewServiceRunner(runOpts)
if err != nil {
t.Fatalf("Could not start docker cassandra: %s", err)
}
Expand Down
50 changes: 31 additions & 19 deletions plugins/database/cassandra/cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,27 @@ func getCassandra(t *testing.T, protocolVersion interface{}) (*Cassandra, func()
}

func TestInitialize(t *testing.T) {
db, cleanup := getCassandra(t, 4)
defer cleanup()
t.Run("integer protocol version", func(t *testing.T) {
// getCassandra performs an Initialize call
db, cleanup := getCassandra(t, 4)
t.Cleanup(cleanup)

err := db.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
err := db.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
})

db, cleanup = getCassandra(t, "4")
defer cleanup()
t.Run("string protocol version", func(t *testing.T) {
// getCassandra performs an Initialize call
db, cleanup := getCassandra(t, "4")
t.Cleanup(cleanup)

err := db.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
})
}

func TestCreateUser(t *testing.T) {
Expand All @@ -74,7 +85,7 @@ func TestCreateUser(t *testing.T) {
newUserReq dbplugin.NewUserRequest
expectErr bool
expectedUsernameRegex string
assertCreds func(t testing.TB, address string, port int, username, password string, timeout time.Duration)
assertCreds func(t testing.TB, address string, port int, username, password string, sslOpts *gocql.SslOptions, timeout time.Duration)
}

tests := map[string]testCase{
Expand Down Expand Up @@ -160,7 +171,7 @@ func TestCreateUser(t *testing.T) {
t.Fatalf("no error expected, got: %s", err)
}
require.Regexp(t, test.expectedUsernameRegex, newUserResp.Username)
test.assertCreds(t, db.Hosts, db.Port, newUserResp.Username, test.newUserReq.Password, 5*time.Second)
test.assertCreds(t, db.Hosts, db.Port, newUserResp.Username, test.newUserReq.Password, nil, 5*time.Second)
})
}
}
Expand All @@ -184,7 +195,7 @@ func TestUpdateUserPassword(t *testing.T) {

createResp := dbtesting.AssertNewUser(t, db, createReq)

assertCreds(t, db.Hosts, db.Port, createResp.Username, password, 5*time.Second)
assertCreds(t, db.Hosts, db.Port, createResp.Username, password, nil, 5*time.Second)

newPassword := "somenewpassword"
updateReq := dbplugin.UpdateUserRequest{
Expand All @@ -198,7 +209,7 @@ func TestUpdateUserPassword(t *testing.T) {

dbtesting.AssertUpdateUser(t, db, updateReq)

assertCreds(t, db.Hosts, db.Port, createResp.Username, newPassword, 5*time.Second)
assertCreds(t, db.Hosts, db.Port, createResp.Username, newPassword, nil, 5*time.Second)
}

func TestDeleteUser(t *testing.T) {
Expand All @@ -220,21 +231,21 @@ func TestDeleteUser(t *testing.T) {

createResp := dbtesting.AssertNewUser(t, db, createReq)

assertCreds(t, db.Hosts, db.Port, createResp.Username, password, 5*time.Second)
assertCreds(t, db.Hosts, db.Port, createResp.Username, password, nil, 5*time.Second)

deleteReq := dbplugin.DeleteUserRequest{
Username: createResp.Username,
}

dbtesting.AssertDeleteUser(t, db, deleteReq)

assertNoCreds(t, db.Hosts, db.Port, createResp.Username, password, 5*time.Second)
assertNoCreds(t, db.Hosts, db.Port, createResp.Username, password, nil, 5*time.Second)
}

func assertCreds(t testing.TB, address string, port int, username, password string, timeout time.Duration) {
func assertCreds(t testing.TB, address string, port int, username, password string, sslOpts *gocql.SslOptions, timeout time.Duration) {
t.Helper()
op := func() error {
return connect(t, address, port, username, password)
return connect(t, address, port, username, password, sslOpts)
}
bo := backoff.NewExponentialBackOff()
bo.MaxElapsedTime = timeout
Expand All @@ -248,7 +259,7 @@ func assertCreds(t testing.TB, address string, port int, username, password stri
}
}

func connect(t testing.TB, address string, port int, username, password string) error {
func connect(t testing.TB, address string, port int, username, password string, sslOpts *gocql.SslOptions) error {
t.Helper()
clusterConfig := gocql.NewCluster(address)
clusterConfig.Authenticator = gocql.PasswordAuthenticator{
Expand All @@ -257,6 +268,7 @@ func connect(t testing.TB, address string, port int, username, password string)
}
clusterConfig.ProtoVersion = 4
clusterConfig.Port = port
clusterConfig.SslOpts = sslOpts

session, err := clusterConfig.CreateSession()
if err != nil {
Expand All @@ -266,12 +278,12 @@ func connect(t testing.TB, address string, port int, username, password string)
return nil
}

func assertNoCreds(t testing.TB, address string, port int, username, password string, timeout time.Duration) {
func assertNoCreds(t testing.TB, address string, port int, username, password string, sslOpts *gocql.SslOptions, timeout time.Duration) {
t.Helper()

op := func() error {
// "Invert" the error so the backoff logic sees a failure to connect as a success
err := connect(t, address, port, username, password)
err := connect(t, address, port, username, password, sslOpts)
if err != nil {
return nil
}
Expand Down
100 changes: 27 additions & 73 deletions plugins/database/cassandra/connection_producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/hashicorp/vault/sdk/database/helper/connutil"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/parseutil"
"github.com/hashicorp/vault/sdk/helper/tlsutil"
"github.com/mitchellh/mapstructure"
Expand Down Expand Up @@ -40,7 +39,7 @@ type cassandraConnectionProducer struct {

connectTimeout time.Duration
socketKeepAlive time.Duration
certBundle *certutil.CertBundle
sslOpts *gocql.SslOptions
rawConfig map[string]interface{}

Initialized bool
Expand Down Expand Up @@ -83,38 +82,46 @@ func (c *cassandraConnectionProducer) Initialize(ctx context.Context, req dbplug
return fmt.Errorf("username cannot be empty")
case len(c.Password) == 0:
return fmt.Errorf("password cannot be empty")
case len(c.PemJSON) > 0 && len(c.PemBundle) > 0:
return fmt.Errorf("cannot specify both pem_json and pem_bundle")
}

var tlsMinVersion uint16 = tls.VersionTLS12
if c.TLSMinVersion != "" {
ver, exists := tlsutil.TLSLookup[c.TLSMinVersion]
if !exists {
return fmt.Errorf("unrecognized TLS version [%s]", c.TLSMinVersion)
}
tlsMinVersion = ver
}

var certBundle *certutil.CertBundle
var parsedCertBundle *certutil.ParsedCertBundle
switch {
case len(c.PemJSON) != 0:
parsedCertBundle, err = certutil.ParsePKIJSON([]byte(c.PemJSON))
cfg, err := jsonBundleToTLSConfig(c.PemJSON, tlsMinVersion, c.TLSServerName, c.InsecureTLS)
if err != nil {
return fmt.Errorf("could not parse given JSON; it must be in the format of the output of the PKI backend certificate issuing command: %w", err)
return fmt.Errorf("failed to parse pem_json: %w", err)
}
certBundle, err = parsedCertBundle.ToCertBundle()
if err != nil {
return fmt.Errorf("error marshaling PEM information: %w", err)
c.sslOpts = &gocql.SslOptions{
Config: cfg,
EnableHostVerification: !cfg.InsecureSkipVerify,
}
c.certBundle = certBundle
c.TLS = true

case len(c.PemBundle) != 0:
parsedCertBundle, err = certutil.ParsePEMBundle(c.PemBundle)
cfg, err := pemBundleToTLSConfig(c.PemBundle, tlsMinVersion, c.TLSServerName, c.InsecureTLS)
if err != nil {
return fmt.Errorf("error parsing the given PEM information: %w", err)
return fmt.Errorf("failed to parse pem_bundle: %w", err)
}
certBundle, err = parsedCertBundle.ToCertBundle()
if err != nil {
return fmt.Errorf("error marshaling PEM information: %w", err)
c.sslOpts = &gocql.SslOptions{
Config: cfg,
EnableHostVerification: !cfg.InsecureSkipVerify,
}
c.certBundle = certBundle
c.TLS = true
}

if c.InsecureTLS {
c.TLS = true
case c.InsecureTLS:
c.sslOpts = &gocql.SslOptions{
EnableHostVerification: !c.InsecureTLS,
}
}

// Set initialized to true at this point since all fields are set,
Expand Down Expand Up @@ -183,14 +190,7 @@ func (c *cassandraConnectionProducer) createSession(ctx context.Context) (*gocql

clusterConfig.Timeout = c.connectTimeout
clusterConfig.SocketKeepalive = c.socketKeepAlive

if c.TLS {
sslOpts, err := getSslOpts(c.certBundle, c.TLSMinVersion, c.TLSServerName, c.InsecureTLS)
if err != nil {
return nil, err
}
clusterConfig.SslOpts = sslOpts
}
clusterConfig.SslOpts = c.sslOpts

if c.LocalDatacenter != "" {
clusterConfig.PoolConfig.HostSelectionPolicy = gocql.DCAwareRoundRobinPolicy(c.LocalDatacenter)
Expand Down Expand Up @@ -231,52 +231,6 @@ func (c *cassandraConnectionProducer) createSession(ctx context.Context) (*gocql
return session, nil
}

func getSslOpts(certBundle *certutil.CertBundle, minTLSVersion, serverName string, insecureSkipVerify bool) (*gocql.SslOptions, error) {
tlsConfig := &tls.Config{}
if certBundle != nil {
if certBundle.Certificate == "" && certBundle.PrivateKey != "" {
return nil, fmt.Errorf("found private key for TLS authentication but no certificate")
}
if certBundle.Certificate != "" && certBundle.PrivateKey == "" {
return nil, fmt.Errorf("found certificate for TLS authentication but no private key")
}

parsedCertBundle, err := certBundle.ToParsedCertBundle()
if err != nil {
return nil, fmt.Errorf("failed to parse certificate bundle: %w", err)
}

tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
if err != nil {
return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%w", tlsConfig, err)
}
}

tlsConfig.InsecureSkipVerify = insecureSkipVerify

if serverName != "" {
tlsConfig.ServerName = serverName
}

if minTLSVersion != "" {
var ok bool
tlsConfig.MinVersion, ok = tlsutil.TLSLookup[minTLSVersion]
if !ok {
return nil, fmt.Errorf("invalid 'tls_min_version' in config")
}
} else {
// MinVersion was not being set earlier. Reset it to
// zero to gracefully handle upgrades.
tlsConfig.MinVersion = 0
}

opts := &gocql.SslOptions{
Config: tlsConfig,
EnableHostVerification: !insecureSkipVerify,
}
return opts, nil
}

func (c *cassandraConnectionProducer) secretValues() map[string]string {
return map[string]string{
c.Password: "[password]",
Expand Down
Loading

0 comments on commit 6a403a6

Please sign in to comment.