Skip to content

Commit

Permalink
Fixing more nits, pulled DSN construction into func
Browse files Browse the repository at this point in the history
  • Loading branch information
Lauren Voswinkel committed Jun 16, 2020
1 parent 75acb94 commit 5acc198
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 44 deletions.
60 changes: 17 additions & 43 deletions plugins/database/mysql/connection_producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,19 @@ import (
"database/sql"
"fmt"
"net/url"
"strings"
"sync"
"time"

"github.com/go-sql-driver/mysql"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/sdk/database/dbplugin"
"github.com/hashicorp/vault/sdk/database/helper/connutil"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/hashicorp/vault/sdk/helper/parseutil"
"github.com/mitchellh/mapstructure"
)

// SQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases
// mySQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases
type mySQLConnectionProducer struct {
ConnectionURL string `json:"connection_url" mapstructure:"connection_url" structs:"connection_url"`
MaxOpenConnections int `json:"max_open_connections" mapstructure:"max_open_connections" structs:"max_open_connections"`
Expand All @@ -37,7 +35,6 @@ type mySQLConnectionProducer struct {
// tlsConfigName is a globally unique name that references the TLS config for this instance in the mysql driver
tlsConfigName string

Type string
RawConfig map[string]interface{}
maxConnectionLifetime time.Duration
Initialized bool
Expand Down Expand Up @@ -65,8 +62,6 @@ func (c *mySQLConnectionProducer) Init(ctx context.Context, conf map[string]inte
return nil, fmt.Errorf("connection_url cannot be empty")
}

c.Type = "mysql"

// Don't escape special characters for MySQL password
password := c.Password

Expand Down Expand Up @@ -102,7 +97,7 @@ func (c *mySQLConnectionProducer) Init(ctx context.Context, conf map[string]inte
}

if tlsConfig != nil {
if c.tlsConfigName == "" && tlsConfig != nil {
if c.tlsConfigName == "" {
c.tlsConfigName, err = uuid.GenerateUUID()
if err != nil {
return nil, fmt.Errorf("unable to generate UUID for TLS configuration: %w", err)
Expand Down Expand Up @@ -144,33 +139,9 @@ func (c *mySQLConnectionProducer) Connection(ctx context.Context) (interface{},
c.db.Close()
}

dbType := c.Type

// Otherwise, attempt to make connection
uri, err := url.Parse(c.ConnectionURL)
if err != nil {
return nil, fmt.Errorf("invalid connection URL: %w", err)
}
vals := uri.Query()
if c.tlsConfigName != "" {
vals.Set("tls", c.tlsConfigName)
}
uri.RawQuery = vals.Encode()

// This convoluted piece is to ensure we're not url encoding any username
// or password information
urlPieces := strings.Split(c.ConnectionURL, "?")
connURL := ""
for i, urlFragment := range urlPieces {
if len(urlPieces) == 1 || i != len(urlPieces)-1 {
connURL = connURL + urlFragment
}
}
if len(vals.Encode()) > 0 {
connURL = connURL + "?" + vals.Encode()
}
connURL, err := c.finalizeConnectionURL()

c.db, err = sql.Open(dbType, connURL)
c.db, err = sql.Open("mysql", connURL)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -205,16 +176,6 @@ func (c *mySQLConnectionProducer) Close() error {
return nil
}

// SetCredentials uses provided information to set/create a user in the
// database. Unlike CreateUser, this method requires a username be provided and
// uses the name given, instead of generating a name. This is used for creating
// and setting the password of static accounts, as well as rolling back
// passwords in the database in the event an updated database fails to save in
// Vault's storage.
func (c *mySQLConnectionProducer) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) {
return "", "", dbutil.Unimplemented()
}

func (c *mySQLConnectionProducer) getTLSAuth() (tlsConfig *tls.Config, err error) {
if len(c.TLSCAData) == 0 &&
len(c.TLSCertificateKeyData) == 0 {
Expand Down Expand Up @@ -247,3 +208,16 @@ func (c *mySQLConnectionProducer) getTLSAuth() (tlsConfig *tls.Config, err error

return tlsConfig, nil
}

func (c *mySQLConnectionProducer) finalizeConnectionURL() (connURL string, err error) {
config, err := mysql.ParseDSN(c.ConnectionURL)
if err != nil {
return "", fmt.Errorf("unable to parse connectionURL: %s", err)
}

config.TLSConfig = c.tlsConfigName

connURL = config.FormatDSN()

return connURL, nil
}
32 changes: 32 additions & 0 deletions plugins/database/mysql/connection_producer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,38 @@ import (
"github.com/ory/dockertest"
)

func Test_finalizeConnectionURL(t *testing.T) {
type testCase struct {
rootUrl string
tlsConfigName string
expectedResult string
}

tests := map[string]testCase{
"no tls, no query string": {"user:password@tcp(localhost:3306)/test", "", "user:password@tcp(localhost:3306)/test"},
"tls, no query string": {"user:password@tcp(localhost:3306)/test", "tlsTest101", "user:password@tcp(localhost:3306)/test?tls=tlsTest101"},
"tls, query string": {"user:password@tcp(localhost:3306)/test?foo=bar", "tlsTest101", "user:password@tcp(localhost:3306)/test?tls=tlsTest101&foo=bar"},
"tls, query string, ? in password": {"user:pa?ssword?@tcp(localhost:3306)/test?foo=bar", "tlsTest101", "user:pa?ssword?@tcp(localhost:3306)/test?tls=tlsTest101&foo=bar"},
}

for name, test := range tests {
t.Run(name, func(t *testing.T) {
tCase := mySQLConnectionProducer {
ConnectionURL: test.rootUrl,
tlsConfigName: test.tlsConfigName,
}

actual, err := tCase.finalizeConnectionURL()
if err != nil {
t.Fatalf("error occurred in test: %s", err)
}
if actual != test.expectedResult {
t.Fatalf("generated: %s, expected: %s", actual, test.expectedResult)
}
})
}
}

func TestInit_clientTLS(t *testing.T) {
t.Skip("Skipping this test because CircleCI can't mount the files we need without further investigation: " +
"https://support.circleci.com/hc/en-us/articles/360007324514-How-can-I-mount-volumes-to-docker-containers-")
Expand Down
1 change: 0 additions & 1 deletion plugins/database/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ func New(displayNameLen, roleNameLen, usernameLen int) func() (interface{}, erro

func new(displayNameLen, roleNameLen, usernameLen int) *MySQL {
connProducer := &mySQLConnectionProducer{}
connProducer.Type = mySQLTypeName

credsProducer := &credsutil.SQLCredentialsProducer{
DisplayNameLen: displayNameLen,
Expand Down

0 comments on commit 5acc198

Please sign in to comment.