Skip to content

Commit

Permalink
wrapping all the things in middleware, implementing templating for mo…
Browse files Browse the repository at this point in the history
…ngodb
  • Loading branch information
Chris Hoffman committed Feb 23, 2018
1 parent 837b4f8 commit a65adb0
Show file tree
Hide file tree
Showing 12 changed files with 90 additions and 45 deletions.
47 changes: 31 additions & 16 deletions builtin/logical/database/dbplugin/databasemiddleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"net/url"
"strings"
"sync"
"time"

"github.com/hashicorp/errwrap"
Expand Down Expand Up @@ -205,59 +206,73 @@ func (mw *databaseMetricsMiddleware) Close() (err error) {

// ---- Error Sanitizer Middleware Domain ----

// databaseErrorSanitizerMiddleware wraps an implementation of Databases and
// DatabaseErrorSanitizerMiddleware wraps an implementation of Databases and
// sanitizes returned error messages
type databaseErrorSanitizerMiddleware struct {
next Database
type DatabaseErrorSanitizerMiddleware struct {
l sync.RWMutex
next Database
secretsFn func() []string
}

func NewDatabaseErrorSanitizerMiddleware(next Database, secretsFn func() []string) *DatabaseErrorSanitizerMiddleware {
return &DatabaseErrorSanitizerMiddleware{
next: next,
secretsFn: secretsFn,
}
}

func (mw *databaseErrorSanitizerMiddleware) Type() (string, error) {
func (mw *DatabaseErrorSanitizerMiddleware) Type() (string, error) {
dbType, err := mw.next.Type()
return dbType, mw.sanitize(err)
}

func (mw *databaseErrorSanitizerMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
func (mw *DatabaseErrorSanitizerMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
username, password, err = mw.next.CreateUser(ctx, statements, usernameConfig, expiration)
return username, password, mw.sanitize(err)
}

func (mw *databaseErrorSanitizerMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) {
func (mw *DatabaseErrorSanitizerMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) {
return mw.sanitize(mw.next.RenewUser(ctx, statements, username, expiration))
}

func (mw *databaseErrorSanitizerMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) {
func (mw *DatabaseErrorSanitizerMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) {
return mw.sanitize(mw.next.RevokeUser(ctx, statements, username))
}

func (mw *databaseErrorSanitizerMiddleware) RotateRootCredentials(ctx context.Context, statements []string) (conf map[string]interface{}, err error) {
func (mw *DatabaseErrorSanitizerMiddleware) RotateRootCredentials(ctx context.Context, statements []string) (conf map[string]interface{}, err error) {
conf, err = mw.next.RotateRootCredentials(ctx, statements)
return conf, mw.sanitize(err)
}

func (mw *databaseErrorSanitizerMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
func (mw *DatabaseErrorSanitizerMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
_, err := mw.Init(ctx, conf, verifyConnection)
return err
}

func (mw *databaseErrorSanitizerMiddleware) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
func (mw *DatabaseErrorSanitizerMiddleware) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
saveConf, err = mw.next.Init(ctx, conf, verifyConnection)
return saveConf, mw.sanitize(err)
}

func (mw *databaseErrorSanitizerMiddleware) Close() (err error) {
func (mw *DatabaseErrorSanitizerMiddleware) Close() (err error) {
return mw.sanitize(mw.next.Close())
}

// sanitize
func (mw *databaseErrorSanitizerMiddleware) sanitize(err error) error {
func (mw *DatabaseErrorSanitizerMiddleware) sanitize(err error) error {
if err == nil {
return nil
}
errStr := err.Error()
if errwrap.ContainsType(err, new(url.Error)) ||
strings.Contains(errStr, "//") ||
strings.Contains(errStr, "@") {
if errwrap.ContainsType(err, new(url.Error)) {
return errors.New("unable to parse connection url")
}
if mw.secretsFn != nil {
for _, secret := range mw.secretsFn() {
if secret == "" {
continue
}
err = errors.New(strings.Replace(err.Error(), secret, "*****", -1))
}
}
return err
}
4 changes: 2 additions & 2 deletions builtin/logical/database/dbplugin/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ type DatabasePlugin struct {
}

func (d DatabasePlugin) Server(*plugin.MuxBroker) (interface{}, error) {
impl := &databaseErrorSanitizerMiddleware{
impl := &DatabaseErrorSanitizerMiddleware{
next: d.impl,
}

Expand All @@ -131,7 +131,7 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e
}

func (d DatabasePlugin) GRPCServer(_ *plugin.GRPCBroker, s *grpc.Server) error {
impl := &databaseErrorSanitizerMiddleware{
impl := &DatabaseErrorSanitizerMiddleware{
next: d.impl,
}

Expand Down
4 changes: 3 additions & 1 deletion plugins/database/cassandra/cassandra.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,13 @@ func New() (interface{}, error) {
Separator: "_",
}

dbType := &Cassandra{
db := &Cassandra{
ConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}

dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, connProducer.secretValues)

return dbType, nil
}

Expand Down
4 changes: 4 additions & 0 deletions plugins/database/cassandra/connection_producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,3 +242,7 @@ func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) {

return session, nil
}

func (c *cassandraConnectionProducer) secretValues() []string {
return []string{c.Password, c.PemBundle, c.PemJSON}
}
7 changes: 5 additions & 2 deletions plugins/database/hana/hana.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,14 @@ func New() (interface{}, error) {
Separator: "_",
}

dbType := &HANA{
db := &HANA{
ConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}

// Wrap the plugin with middleware to sanitize errors
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues)

return dbType, nil
}

Expand All @@ -56,7 +59,7 @@ func Run(apiTLSConfig *api.TLSConfig) error {
return err
}

plugins.Serve(dbType.(*HANA), apiTLSConfig)
plugins.Serve(dbType.(dbplugin.Database), apiTLSConfig)

return nil
}
Expand Down
15 changes: 15 additions & 0 deletions plugins/database/mongodb/connection_producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/plugins/helper/database/connutil"
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
"github.com/mitchellh/mapstructure"

"gopkg.in/mgo.v2"
Expand All @@ -26,8 +27,11 @@ import (
type mongoDBConnectionProducer struct {
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
WriteConcern string `json:"write_concern" structs:"write_concern" mapstructure:"write_concern"`
Username string `json:"username" structs:"username" mapstructure:"username"`
Password string `json:"password" structs:"password" mapstructure:"password"`

Initialized bool
RawConfig map[string]interface{}
Type string
session *mgo.Session
safe *mgo.Safe
Expand All @@ -44,6 +48,8 @@ func (c *mongoDBConnectionProducer) Init(ctx context.Context, conf map[string]in
c.Lock()
defer c.Unlock()

c.RawConfig = conf

err := mapstructure.WeakDecode(conf, c)
if err != nil {
return nil, err
Expand All @@ -53,6 +59,11 @@ func (c *mongoDBConnectionProducer) Init(ctx context.Context, conf map[string]in
return nil, fmt.Errorf("connection_url cannot be empty")
}

c.ConnectionURL = dbutil.QueryHelper(c.ConnectionURL, map[string]string{
"username": c.Username,
"password": c.Password,
})

if c.WriteConcern != "" {
input := c.WriteConcern

Expand Down Expand Up @@ -209,3 +220,7 @@ func parseMongoURL(rawURL string) (*mgo.DialInfo, error) {

return &info, nil
}

func (c *mongoDBConnectionProducer) secretValues() []string {
return []string{c.Password}
}
13 changes: 7 additions & 6 deletions plugins/database/mongodb/mongodb.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/plugins"
"github.com/hashicorp/vault/plugins/helper/database/connutil"
"github.com/hashicorp/vault/plugins/helper/database/credsutil"
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
"gopkg.in/mgo.v2"
Expand All @@ -24,7 +23,7 @@ const mongoDBTypeName = "mongodb"

// MongoDB is an implementation of Database interface
type MongoDB struct {
connutil.ConnectionProducer
*mongoDBConnectionProducer
credsutil.CredentialsProducer
}

Expand All @@ -42,10 +41,12 @@ func New() (interface{}, error) {
Separator: "-",
}

dbType := &MongoDB{
ConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
db := &MongoDB{
mongoDBConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}

dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues)
return dbType, nil
}

Expand Down Expand Up @@ -191,7 +192,7 @@ func (m *MongoDB) RevokeUser(ctx context.Context, statements dbplugin.Statements
switch {
case err == nil, err == mgo.ErrNotFound:
case err == io.EOF, strings.Contains(err.Error(), "EOF"):
if err := m.ConnectionProducer.Close(); err != nil {
if err := m.Close(); err != nil {
return errwrap.Wrapf("error closing EOF'd mongo connection: {{err}}", err)
}
session, err := m.getConnection(ctx)
Expand Down
7 changes: 5 additions & 2 deletions plugins/database/mssql/mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,14 @@ func New() (interface{}, error) {
Separator: "-",
}

dbType := &MSSQL{
db := &MSSQL{
SQLConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}

// Wrap the plugin with middleware to sanitize errors
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues)

return dbType, nil
}

Expand All @@ -55,7 +58,7 @@ func Run(apiTLSConfig *api.TLSConfig) error {
return err
}

plugins.Serve(dbType.(*MSSQL), apiTLSConfig)
plugins.Serve(dbType.(dbplugin.Database), apiTLSConfig)

return nil
}
Expand Down
7 changes: 5 additions & 2 deletions plugins/database/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,14 @@ func New(displayNameLen, roleNameLen, usernameLen int) func() (interface{}, erro
Separator: "-",
}

dbType := &MySQL{
db := &MySQL{
SQLConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}

// Wrap the plugin with middleware to sanitize errors
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues)

return dbType, nil
}
}
Expand All @@ -88,7 +91,7 @@ func runCommon(legacy bool, apiTLSConfig *api.TLSConfig) error {
return err
}

plugins.Serve(dbType.(*MySQL), apiTLSConfig)
plugins.Serve(dbType.(dbplugin.Database), apiTLSConfig)

return nil
}
Expand Down
19 changes: 7 additions & 12 deletions plugins/database/postgresql/postgresql.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,14 @@ func New() (interface{}, error) {
Separator: "-",
}

dbType := &PostgreSQL{
db := &PostgreSQL{
SQLConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}

// Wrap the plugin with middleware to sanitize errors
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues)

return dbType, nil
}

Expand All @@ -58,7 +61,7 @@ func Run(apiTLSConfig *api.TLSConfig) error {
return err
}

plugins.Serve(dbType.(*PostgreSQL), apiTLSConfig)
plugins.Serve(dbType.(dbplugin.Database), apiTLSConfig)

return nil
}
Expand Down Expand Up @@ -202,11 +205,7 @@ func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statemen
}
}

if err := tx.Commit(); err != nil {
return err
}

return nil
return tx.Commit()
}

func (p *PostgreSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
Expand Down Expand Up @@ -256,11 +255,7 @@ func (p *PostgreSQL) customRevokeUser(ctx context.Context, username string, revo
}
}

if err := tx.Commit(); err != nil {
return err
}

return nil
return tx.Commit()
}

func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) error {
Expand Down
4 changes: 2 additions & 2 deletions plugins/helper/database/connutil/connutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ type ConnectionProducer interface {
Init(context.Context, map[string]interface{}, bool) (map[string]interface{}, error)
Connection(context.Context) (interface{}, error)

sync.Locker

// DEPRECATED, will be removed in 0.12
Initialize(context.Context, map[string]interface{}, bool) error

sync.Locker
}
4 changes: 4 additions & 0 deletions plugins/helper/database/connutil/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, er
return c.db, nil
}

func (c *SQLConnectionProducer) SecretValues() []string {
return []string{c.Password}
}

// Close attempts to close the connection
func (c *SQLConnectionProducer) Close() error {
// Grab the write lock
Expand Down

0 comments on commit a65adb0

Please sign in to comment.