Skip to content

Commit

Permalink
wrapping all db servers with error santizer
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris Hoffman committed Feb 23, 2018
1 parent a65adb0 commit f834576
Show file tree
Hide file tree
Showing 12 changed files with 104 additions and 154 deletions.
20 changes: 11 additions & 9 deletions plugins/database/cassandra/cassandra.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/plugins"
"github.com/hashicorp/vault/plugins/helper/database/connutil"
"github.com/hashicorp/vault/plugins/helper/database/credsutil"
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
)
Expand All @@ -27,12 +26,19 @@ var _ dbplugin.Database = &Cassandra{}

// Cassandra is an implementation of Database interface
type Cassandra struct {
connutil.ConnectionProducer
*cassandraConnectionProducer
credsutil.CredentialsProducer
}

// New returns a new Cassandra instance
func New() (interface{}, error) {
db := new()
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues)

return dbType, nil
}

func new() *Cassandra {
connProducer := &cassandraConnectionProducer{}
connProducer.Type = cassandraTypeName

Expand All @@ -43,14 +49,10 @@ func New() (interface{}, error) {
Separator: "_",
}

db := &Cassandra{
ConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
return &Cassandra{
cassandraConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}

dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, connProducer.secretValues)

return dbType, nil
}

// Run instantiates a Cassandra object, and runs the RPC server for the plugin
Expand Down
16 changes: 5 additions & 11 deletions plugins/database/cassandra/cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,13 @@ func TestCassandra_Initialize(t *testing.T) {
"protocol_version": 4,
}

dbRaw, _ := New()
db := dbRaw.(*Cassandra)
connProducer := db.ConnectionProducer.(*cassandraConnectionProducer)

db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}

if !connProducer.Initialized {
if !db.Initialized {
t.Fatal("Database should be initalized")
}

Expand Down Expand Up @@ -135,8 +132,7 @@ func TestCassandra_CreateUser(t *testing.T) {
"protocol_version": 4,
}

dbRaw, _ := New()
db := dbRaw.(*Cassandra)
db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
Expand Down Expand Up @@ -176,8 +172,7 @@ func TestMyCassandra_RenewUser(t *testing.T) {
"protocol_version": 4,
}

dbRaw, _ := New()
db := dbRaw.(*Cassandra)
db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
Expand Down Expand Up @@ -222,8 +217,7 @@ func TestCassandra_RevokeUser(t *testing.T) {
"protocol_version": 4,
}

dbRaw, _ := New()
db := dbRaw.(*Cassandra)
db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
Expand Down
28 changes: 13 additions & 15 deletions plugins/database/hana/hana.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,22 @@ const (

// HANA is an implementation of Database interface
type HANA struct {
connutil.ConnectionProducer
*connutil.SQLConnectionProducer
credsutil.CredentialsProducer
}

var _ dbplugin.Database = &HANA{}

// New implements builtinplugins.BuiltinFactory
func New() (interface{}, error) {
db := new()
// Wrap the plugin with middleware to sanitize errors
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues)

return dbType, nil
}

func new() *HANA {
connProducer := &connutil.SQLConnectionProducer{}
connProducer.Type = hanaTypeName

Expand All @@ -41,15 +49,10 @@ func New() (interface{}, error) {
Separator: "_",
}

db := &HANA{
ConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
return &HANA{
SQLConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}

// Wrap the plugin with middleware to sanitize errors
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues)

return dbType, nil
}

// Run instantiates a HANA object, and runs the RPC server for the plugin
Expand Down Expand Up @@ -241,12 +244,7 @@ func (h *HANA) RevokeUser(ctx context.Context, statements dbplugin.Statements, u
}
}

// Commit the transaction
if err := tx.Commit(); err != nil {
return err
}

return nil
return tx.Commit()
}

func (h *HANA) revokeUserDefault(ctx context.Context, username string) error {
Expand Down
16 changes: 4 additions & 12 deletions plugins/database/hana/hana_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"time"

"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/plugins/helper/database/connutil"
)

func TestHANA_Initialize(t *testing.T) {
Expand All @@ -23,16 +22,13 @@ func TestHANA_Initialize(t *testing.T) {
"connection_url": connURL,
}

dbRaw, _ := New()
db := dbRaw.(*HANA)

db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}

connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
if !connProducer.Initialized {
if !db.Initialized {
t.Fatal("Database should be initialized")
}

Expand All @@ -53,9 +49,7 @@ func TestHANA_CreateUser(t *testing.T) {
"connection_url": connURL,
}

dbRaw, _ := New()
db := dbRaw.(*HANA)

db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
Expand Down Expand Up @@ -96,9 +90,7 @@ func TestHANA_RevokeUser(t *testing.T) {
"connection_url": connURL,
}

dbRaw, _ := New()
db := dbRaw.(*HANA)

db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
Expand Down
11 changes: 7 additions & 4 deletions plugins/database/mongodb/mongodb.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ var _ dbplugin.Database = &MongoDB{}

// New returns a new MongoDB instance
func New() (interface{}, error) {
db := new()
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues)
return dbType, nil
}

func new() *MongoDB {
connProducer := &mongoDBConnectionProducer{}
connProducer.Type = mongoDBTypeName

Expand All @@ -41,13 +47,10 @@ func New() (interface{}, error) {
Separator: "-",
}

db := &MongoDB{
return &MongoDB{
mongoDBConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}

dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues)
return dbType, nil
}

// Run instantiates a MongoDB object, and runs the RPC server for the plugin
Expand Down
36 changes: 9 additions & 27 deletions plugins/database/mongodb/mongodb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,13 @@ func TestMongoDB_Initialize(t *testing.T) {
"connection_url": connURL,
}

dbRaw, err := New()
db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
db := dbRaw.(*MongoDB)
connProducer := db.ConnectionProducer.(*mongoDBConnectionProducer)

_, err = db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}

if !connProducer.Initialized {
if !db.Initialized {
t.Fatal("Database should be initialized")
}

Expand All @@ -103,12 +97,8 @@ func TestMongoDB_CreateUser(t *testing.T) {
"connection_url": connURL,
}

dbRaw, err := New()
if err != nil {
t.Fatalf("err: %s", err)
}
db := dbRaw.(*MongoDB)
_, err = db.Init(context.Background(), connectionDetails, true)
db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
Expand Down Expand Up @@ -141,12 +131,8 @@ func TestMongoDB_CreateUser_writeConcern(t *testing.T) {
"write_concern": testMongoDBWriteConcern,
}

dbRaw, err := New()
if err != nil {
t.Fatalf("err: %s", err)
}
db := dbRaw.(*MongoDB)
_, err = db.Init(context.Background(), connectionDetails, true)
db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
Expand Down Expand Up @@ -178,12 +164,8 @@ func TestMongoDB_RevokeUser(t *testing.T) {
"connection_url": connURL,
}

dbRaw, err := New()
if err != nil {
t.Fatalf("err: %s", err)
}
db := dbRaw.(*MongoDB)
_, err = db.Init(context.Background(), connectionDetails, true)
db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
Expand Down
17 changes: 10 additions & 7 deletions plugins/database/mssql/mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ type MSSQL struct {
}

func New() (interface{}, error) {
db := new()
// Wrap the plugin with middleware to sanitize errors
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues)

return dbType, nil
}

func new() *MSSQL {
connProducer := &connutil.SQLConnectionProducer{}
connProducer.Type = msSQLTypeName

Expand All @@ -40,15 +48,10 @@ func New() (interface{}, error) {
Separator: "-",
}

db := &MSSQL{
return &MSSQL{
SQLConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}

// Wrap the plugin with middleware to sanitize errors
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues)

return dbType, nil
}

// Run instantiates a MSSQL object, and runs the RPC server for the plugin
Expand Down Expand Up @@ -321,7 +324,7 @@ func (m *MSSQL) RotateRootCredentials(ctx context.Context, statements []string)

rotateStatents := statements
if len(rotateStatents) == 0 {
rotateStatents = []string{defaultMSSQLRotateRootCredentialsSQL}
rotateStatents = []string{rotateRootCredentialsSQL}
}

db, err := m.getConnection(ctx)
Expand Down
14 changes: 4 additions & 10 deletions plugins/database/mssql/mssql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"time"

"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/plugins/helper/database/connutil"
)

var (
Expand All @@ -28,16 +27,13 @@ func TestMSSQL_Initialize(t *testing.T) {
"connection_url": connURL,
}

dbRaw, _ := New()
db := dbRaw.(*MSSQL)

db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}

connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
if !connProducer.Initialized {
if !db.Initialized {
t.Fatal("Database should be initalized")
}

Expand Down Expand Up @@ -68,8 +64,7 @@ func TestMSSQL_CreateUser(t *testing.T) {
"connection_url": connURL,
}

dbRaw, _ := New()
db := dbRaw.(*MSSQL)
db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
Expand Down Expand Up @@ -110,8 +105,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
"connection_url": connURL,
}

dbRaw, _ := New()
db := dbRaw.(*MSSQL)
db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
Expand Down
Loading

0 comments on commit f834576

Please sign in to comment.