From f8345769e6b8de7b0d88be2d74bc969f8d8d6de7 Mon Sep 17 00:00:00 2001 From: Chris Hoffman Date: Fri, 23 Feb 2018 16:14:12 -0500 Subject: [PATCH] wrapping all db servers with error santizer --- plugins/database/cassandra/cassandra.go | 20 ++++++----- plugins/database/cassandra/cassandra_test.go | 16 +++------ plugins/database/hana/hana.go | 28 +++++++-------- plugins/database/hana/hana_test.go | 16 +++------ plugins/database/mongodb/mongodb.go | 11 +++--- plugins/database/mongodb/mongodb_test.go | 36 +++++-------------- plugins/database/mssql/mssql.go | 17 +++++---- plugins/database/mssql/mssql_test.go | 14 +++----- plugins/database/mysql/mysql.go | 33 +++++++++-------- plugins/database/mysql/mysql_test.go | 35 +++++------------- plugins/database/postgresql/postgresql.go | 12 ++++--- .../database/postgresql/postgresql_test.go | 20 ++++------- 12 files changed, 104 insertions(+), 154 deletions(-) diff --git a/plugins/database/cassandra/cassandra.go b/plugins/database/cassandra/cassandra.go index 6fd36cf1b4cd..884dd22a36f1 100644 --- a/plugins/database/cassandra/cassandra.go +++ b/plugins/database/cassandra/cassandra.go @@ -12,7 +12,6 @@ import ( "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/plugins" - "github.com/hashicorp/vault/plugins/helper/database/connutil" "github.com/hashicorp/vault/plugins/helper/database/credsutil" "github.com/hashicorp/vault/plugins/helper/database/dbutil" ) @@ -27,12 +26,19 @@ var _ dbplugin.Database = &Cassandra{} // Cassandra is an implementation of Database interface type Cassandra struct { - connutil.ConnectionProducer + *cassandraConnectionProducer credsutil.CredentialsProducer } // New returns a new Cassandra instance func New() (interface{}, error) { + db := new() + dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues) + + return dbType, nil +} + +func new() *Cassandra { connProducer := &cassandraConnectionProducer{} connProducer.Type = cassandraTypeName @@ -43,14 +49,10 @@ func New() (interface{}, error) { Separator: "_", } - db := &Cassandra{ - ConnectionProducer: connProducer, - CredentialsProducer: credsProducer, + return &Cassandra{ + cassandraConnectionProducer: connProducer, + CredentialsProducer: credsProducer, } - - dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, connProducer.secretValues) - - return dbType, nil } // Run instantiates a Cassandra object, and runs the RPC server for the plugin diff --git a/plugins/database/cassandra/cassandra_test.go b/plugins/database/cassandra/cassandra_test.go index cb4dc84b3a55..4ad63f9707a5 100644 --- a/plugins/database/cassandra/cassandra_test.go +++ b/plugins/database/cassandra/cassandra_test.go @@ -87,16 +87,13 @@ func TestCassandra_Initialize(t *testing.T) { "protocol_version": 4, } - dbRaw, _ := New() - db := dbRaw.(*Cassandra) - connProducer := db.ConnectionProducer.(*cassandraConnectionProducer) - + db := new() _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } - if !connProducer.Initialized { + if !db.Initialized { t.Fatal("Database should be initalized") } @@ -135,8 +132,7 @@ func TestCassandra_CreateUser(t *testing.T) { "protocol_version": 4, } - dbRaw, _ := New() - db := dbRaw.(*Cassandra) + db := new() _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) @@ -176,8 +172,7 @@ func TestMyCassandra_RenewUser(t *testing.T) { "protocol_version": 4, } - dbRaw, _ := New() - db := dbRaw.(*Cassandra) + db := new() _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) @@ -222,8 +217,7 @@ func TestCassandra_RevokeUser(t *testing.T) { "protocol_version": 4, } - dbRaw, _ := New() - db := dbRaw.(*Cassandra) + db := new() _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) diff --git a/plugins/database/hana/hana.go b/plugins/database/hana/hana.go index 9087968147bd..6bc4b713cb88 100644 --- a/plugins/database/hana/hana.go +++ b/plugins/database/hana/hana.go @@ -23,7 +23,7 @@ const ( // HANA is an implementation of Database interface type HANA struct { - connutil.ConnectionProducer + *connutil.SQLConnectionProducer credsutil.CredentialsProducer } @@ -31,6 +31,14 @@ var _ dbplugin.Database = &HANA{} // New implements builtinplugins.BuiltinFactory func New() (interface{}, error) { + db := new() + // Wrap the plugin with middleware to sanitize errors + dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues) + + return dbType, nil +} + +func new() *HANA { connProducer := &connutil.SQLConnectionProducer{} connProducer.Type = hanaTypeName @@ -41,15 +49,10 @@ func New() (interface{}, error) { Separator: "_", } - db := &HANA{ - ConnectionProducer: connProducer, - CredentialsProducer: credsProducer, + return &HANA{ + SQLConnectionProducer: connProducer, + CredentialsProducer: credsProducer, } - - // Wrap the plugin with middleware to sanitize errors - dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues) - - return dbType, nil } // Run instantiates a HANA object, and runs the RPC server for the plugin @@ -241,12 +244,7 @@ func (h *HANA) RevokeUser(ctx context.Context, statements dbplugin.Statements, u } } - // Commit the transaction - if err := tx.Commit(); err != nil { - return err - } - - return nil + return tx.Commit() } func (h *HANA) revokeUserDefault(ctx context.Context, username string) error { diff --git a/plugins/database/hana/hana_test.go b/plugins/database/hana/hana_test.go index 7574ad80ceb4..06a457a8f208 100644 --- a/plugins/database/hana/hana_test.go +++ b/plugins/database/hana/hana_test.go @@ -10,7 +10,6 @@ import ( "time" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" - "github.com/hashicorp/vault/plugins/helper/database/connutil" ) func TestHANA_Initialize(t *testing.T) { @@ -23,16 +22,13 @@ func TestHANA_Initialize(t *testing.T) { "connection_url": connURL, } - dbRaw, _ := New() - db := dbRaw.(*HANA) - + db := new() _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } - connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) - if !connProducer.Initialized { + if !db.Initialized { t.Fatal("Database should be initialized") } @@ -53,9 +49,7 @@ func TestHANA_CreateUser(t *testing.T) { "connection_url": connURL, } - dbRaw, _ := New() - db := dbRaw.(*HANA) - + db := new() _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) @@ -96,9 +90,7 @@ func TestHANA_RevokeUser(t *testing.T) { "connection_url": connURL, } - dbRaw, _ := New() - db := dbRaw.(*HANA) - + db := new() _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) diff --git a/plugins/database/mongodb/mongodb.go b/plugins/database/mongodb/mongodb.go index ce473db7f917..20a8f80d315a 100644 --- a/plugins/database/mongodb/mongodb.go +++ b/plugins/database/mongodb/mongodb.go @@ -31,6 +31,12 @@ var _ dbplugin.Database = &MongoDB{} // New returns a new MongoDB instance func New() (interface{}, error) { + db := new() + dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues) + return dbType, nil +} + +func new() *MongoDB { connProducer := &mongoDBConnectionProducer{} connProducer.Type = mongoDBTypeName @@ -41,13 +47,10 @@ func New() (interface{}, error) { Separator: "-", } - db := &MongoDB{ + return &MongoDB{ mongoDBConnectionProducer: connProducer, CredentialsProducer: credsProducer, } - - dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues) - return dbType, nil } // Run instantiates a MongoDB object, and runs the RPC server for the plugin diff --git a/plugins/database/mongodb/mongodb_test.go b/plugins/database/mongodb/mongodb_test.go index 76ed666e6360..d30c77a2729c 100644 --- a/plugins/database/mongodb/mongodb_test.go +++ b/plugins/database/mongodb/mongodb_test.go @@ -73,19 +73,13 @@ func TestMongoDB_Initialize(t *testing.T) { "connection_url": connURL, } - dbRaw, err := New() + db := new() + _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } - db := dbRaw.(*MongoDB) - connProducer := db.ConnectionProducer.(*mongoDBConnectionProducer) - _, err = db.Init(context.Background(), connectionDetails, true) - if err != nil { - t.Fatalf("err: %s", err) - } - - if !connProducer.Initialized { + if !db.Initialized { t.Fatal("Database should be initialized") } @@ -103,12 +97,8 @@ func TestMongoDB_CreateUser(t *testing.T) { "connection_url": connURL, } - dbRaw, err := New() - if err != nil { - t.Fatalf("err: %s", err) - } - db := dbRaw.(*MongoDB) - _, err = db.Init(context.Background(), connectionDetails, true) + db := new() + _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -141,12 +131,8 @@ func TestMongoDB_CreateUser_writeConcern(t *testing.T) { "write_concern": testMongoDBWriteConcern, } - dbRaw, err := New() - if err != nil { - t.Fatalf("err: %s", err) - } - db := dbRaw.(*MongoDB) - _, err = db.Init(context.Background(), connectionDetails, true) + db := new() + _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -178,12 +164,8 @@ func TestMongoDB_RevokeUser(t *testing.T) { "connection_url": connURL, } - dbRaw, err := New() - if err != nil { - t.Fatalf("err: %s", err) - } - db := dbRaw.(*MongoDB) - _, err = db.Init(context.Background(), connectionDetails, true) + db := new() + _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } diff --git a/plugins/database/mssql/mssql.go b/plugins/database/mssql/mssql.go index 63343e3b39fe..910da088f4a3 100644 --- a/plugins/database/mssql/mssql.go +++ b/plugins/database/mssql/mssql.go @@ -30,6 +30,14 @@ type MSSQL struct { } func New() (interface{}, error) { + db := new() + // Wrap the plugin with middleware to sanitize errors + dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues) + + return dbType, nil +} + +func new() *MSSQL { connProducer := &connutil.SQLConnectionProducer{} connProducer.Type = msSQLTypeName @@ -40,15 +48,10 @@ func New() (interface{}, error) { Separator: "-", } - db := &MSSQL{ + return &MSSQL{ SQLConnectionProducer: connProducer, CredentialsProducer: credsProducer, } - - // Wrap the plugin with middleware to sanitize errors - dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues) - - return dbType, nil } // Run instantiates a MSSQL object, and runs the RPC server for the plugin @@ -321,7 +324,7 @@ func (m *MSSQL) RotateRootCredentials(ctx context.Context, statements []string) rotateStatents := statements if len(rotateStatents) == 0 { - rotateStatents = []string{defaultMSSQLRotateRootCredentialsSQL} + rotateStatents = []string{rotateRootCredentialsSQL} } db, err := m.getConnection(ctx) diff --git a/plugins/database/mssql/mssql_test.go b/plugins/database/mssql/mssql_test.go index 493fac547129..fd445d640bc4 100644 --- a/plugins/database/mssql/mssql_test.go +++ b/plugins/database/mssql/mssql_test.go @@ -11,7 +11,6 @@ import ( "time" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" - "github.com/hashicorp/vault/plugins/helper/database/connutil" ) var ( @@ -28,16 +27,13 @@ func TestMSSQL_Initialize(t *testing.T) { "connection_url": connURL, } - dbRaw, _ := New() - db := dbRaw.(*MSSQL) - + db := new() _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } - connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) - if !connProducer.Initialized { + if !db.Initialized { t.Fatal("Database should be initalized") } @@ -68,8 +64,7 @@ func TestMSSQL_CreateUser(t *testing.T) { "connection_url": connURL, } - dbRaw, _ := New() - db := dbRaw.(*MSSQL) + db := new() _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) @@ -110,8 +105,7 @@ func TestMSSQL_RevokeUser(t *testing.T) { "connection_url": connURL, } - dbRaw, _ := New() - db := dbRaw.(*MSSQL) + db := new() _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) diff --git a/plugins/database/mysql/mysql.go b/plugins/database/mysql/mysql.go index 57f5a6d79ea5..ac77e3c74800 100644 --- a/plugins/database/mysql/mysql.go +++ b/plugins/database/mysql/mysql.go @@ -47,21 +47,7 @@ type MySQL struct { // New implements builtinplugins.BuiltinFactory func New(displayNameLen, roleNameLen, usernameLen int) func() (interface{}, error) { return func() (interface{}, error) { - connProducer := &connutil.SQLConnectionProducer{} - connProducer.Type = mySQLTypeName - - credsProducer := &credsutil.SQLCredentialsProducer{ - DisplayNameLen: displayNameLen, - RoleNameLen: roleNameLen, - UsernameLen: usernameLen, - Separator: "-", - } - - db := &MySQL{ - SQLConnectionProducer: connProducer, - CredentialsProducer: credsProducer, - } - + db := new(displayNameLen, roleNameLen, usernameLen) // Wrap the plugin with middleware to sanitize errors dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues) @@ -69,6 +55,23 @@ func New(displayNameLen, roleNameLen, usernameLen int) func() (interface{}, erro } } +func new(displayNameLen, roleNameLen, usernameLen int) *MySQL { + connProducer := &connutil.SQLConnectionProducer{} + connProducer.Type = mySQLTypeName + + credsProducer := &credsutil.SQLCredentialsProducer{ + DisplayNameLen: displayNameLen, + RoleNameLen: roleNameLen, + UsernameLen: usernameLen, + Separator: "-", + } + + return &MySQL{ + SQLConnectionProducer: connProducer, + CredentialsProducer: credsProducer, + } +} + // Run instantiates a MySQL object, and runs the RPC server for the plugin func Run(apiTLSConfig *api.TLSConfig) error { return runCommon(false, apiTLSConfig) diff --git a/plugins/database/mysql/mysql_test.go b/plugins/database/mysql/mysql_test.go index 327f52418b3b..3c5fd5657190 100644 --- a/plugins/database/mysql/mysql_test.go +++ b/plugins/database/mysql/mysql_test.go @@ -9,8 +9,9 @@ import ( "testing" "time" - "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/plugins/helper/database/credsutil" + + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" dockertest "gopkg.in/ory-am/dockertest.v3" ) @@ -103,17 +104,13 @@ func TestMySQL_Initialize(t *testing.T) { "connection_url": connURL, } - f := New(MetadataLen, MetadataLen, UsernameLen) - dbRaw, _ := f() - db := dbRaw.(*MySQL) - connProducer := db.SQLConnectionProducer - + db := new(MetadataLen, MetadataLen, UsernameLen) _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } - if !connProducer.Initialized { + if !db.Initialized { t.Fatal("Database should be initalized") } @@ -142,10 +139,7 @@ func TestMySQL_CreateUser(t *testing.T) { "connection_url": connURL, } - f := New(MetadataLen, MetadataLen, UsernameLen) - dbRaw, _ := f() - db := dbRaw.(*MySQL) - + db := new(MetadataLen, MetadataLen, UsernameLen) _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) @@ -207,10 +201,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) { "connection_url": connURL, } - f := New(credsutil.NoneLength, LegacyMetadataLen, LegacyUsernameLen) - dbRaw, _ := f() - db := dbRaw.(*MySQL) - + db := new(credsutil.NoneLength, LegacyMetadataLen, LegacyUsernameLen) _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) @@ -263,18 +254,13 @@ func TestMySQL_RotateRootCredentials(t *testing.T) { "password": "secret", } - f := New(MetadataLen, MetadataLen, UsernameLen) - dbRaw, _ := f() - db := dbRaw.(*MySQL) - - connProducer := db.SQLConnectionProducer - + db := new(MetadataLen, MetadataLen, UsernameLen) _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } - if !connProducer.Initialized { + if !db.Initialized { t.Fatal("Database should be initalized") } @@ -300,10 +286,7 @@ func TestMySQL_RevokeUser(t *testing.T) { "connection_url": connURL, } - f := New(MetadataLen, MetadataLen, UsernameLen) - dbRaw, _ := f() - db := dbRaw.(*MySQL) - + db := new(MetadataLen, MetadataLen, UsernameLen) _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) diff --git a/plugins/database/postgresql/postgresql.go b/plugins/database/postgresql/postgresql.go index 76cf8b22bbf0..96ac6725b684 100644 --- a/plugins/database/postgresql/postgresql.go +++ b/plugins/database/postgresql/postgresql.go @@ -33,6 +33,13 @@ var _ dbplugin.Database = &PostgreSQL{} // New implements builtinplugins.BuiltinFactory func New() (interface{}, error) { + db := new() + // Wrap the plugin with middleware to sanitize errors + dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues) + return dbType, nil +} + +func new() *PostgreSQL { connProducer := &connutil.SQLConnectionProducer{} connProducer.Type = postgreSQLTypeName @@ -48,10 +55,7 @@ func New() (interface{}, error) { CredentialsProducer: credsProducer, } - // Wrap the plugin with middleware to sanitize errors - dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues) - - return dbType, nil + return db } // Run instantiates a PostgreSQL object, and runs the RPC server for the plugin diff --git a/plugins/database/postgresql/postgresql_test.go b/plugins/database/postgresql/postgresql_test.go index c5ec2fec9323..7cdfc8f51562 100644 --- a/plugins/database/postgresql/postgresql_test.go +++ b/plugins/database/postgresql/postgresql_test.go @@ -67,17 +67,13 @@ func TestPostgreSQL_Initialize(t *testing.T) { "max_open_connections": 5, } - dbRaw, _ := New() - db := dbRaw.(*PostgreSQL) - - connProducer := db.SQLConnectionProducer - + db := new() _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } - if !connProducer.Initialized { + if !db.Initialized { t.Fatal("Database should be initalized") } @@ -107,8 +103,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) { "connection_url": connURL, } - dbRaw, _ := New() - db := dbRaw.(*PostgreSQL) + db := new() _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) @@ -160,8 +155,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) { "connection_url": connURL, } - dbRaw, _ := New() - db := dbRaw.(*PostgreSQL) + db := new() _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) @@ -233,8 +227,7 @@ func TestPostgreSQL_RotateRootCredentials(t *testing.T) { "password": "secret", } - dbRaw, _ := New() - db := dbRaw.(*PostgreSQL) + db := new() connProducer := db.SQLConnectionProducer @@ -269,8 +262,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { "connection_url": connURL, } - dbRaw, _ := New() - db := dbRaw.(*PostgreSQL) + db := new() _, err := db.Init(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err)