From 5acc198e720acc01b8ac1f67c07f1c4cee98cd11 Mon Sep 17 00:00:00 2001 From: Lauren Voswinkel Date: Tue, 16 Jun 2020 15:12:53 -0700 Subject: [PATCH] Fixing more nits, pulled DSN construction into func --- plugins/database/mysql/connection_producer.go | 60 ++++++------------- .../mysql/connection_producer_test.go | 32 ++++++++++ plugins/database/mysql/mysql.go | 1 - 3 files changed, 49 insertions(+), 44 deletions(-) diff --git a/plugins/database/mysql/connection_producer.go b/plugins/database/mysql/connection_producer.go index 484da8fd871d..40f377b60af6 100644 --- a/plugins/database/mysql/connection_producer.go +++ b/plugins/database/mysql/connection_producer.go @@ -7,21 +7,19 @@ import ( "database/sql" "fmt" "net/url" - "strings" "sync" "time" "github.com/go-sql-driver/mysql" "github.com/hashicorp/errwrap" "github.com/hashicorp/go-uuid" - "github.com/hashicorp/vault/sdk/database/dbplugin" "github.com/hashicorp/vault/sdk/database/helper/connutil" "github.com/hashicorp/vault/sdk/database/helper/dbutil" "github.com/hashicorp/vault/sdk/helper/parseutil" "github.com/mitchellh/mapstructure" ) -// SQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases +// mySQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases type mySQLConnectionProducer struct { ConnectionURL string `json:"connection_url" mapstructure:"connection_url" structs:"connection_url"` MaxOpenConnections int `json:"max_open_connections" mapstructure:"max_open_connections" structs:"max_open_connections"` @@ -37,7 +35,6 @@ type mySQLConnectionProducer struct { // tlsConfigName is a globally unique name that references the TLS config for this instance in the mysql driver tlsConfigName string - Type string RawConfig map[string]interface{} maxConnectionLifetime time.Duration Initialized bool @@ -65,8 +62,6 @@ func (c *mySQLConnectionProducer) Init(ctx context.Context, conf map[string]inte return nil, fmt.Errorf("connection_url cannot be empty") } - c.Type = "mysql" - // Don't escape special characters for MySQL password password := c.Password @@ -102,7 +97,7 @@ func (c *mySQLConnectionProducer) Init(ctx context.Context, conf map[string]inte } if tlsConfig != nil { - if c.tlsConfigName == "" && tlsConfig != nil { + if c.tlsConfigName == "" { c.tlsConfigName, err = uuid.GenerateUUID() if err != nil { return nil, fmt.Errorf("unable to generate UUID for TLS configuration: %w", err) @@ -144,33 +139,9 @@ func (c *mySQLConnectionProducer) Connection(ctx context.Context) (interface{}, c.db.Close() } - dbType := c.Type - - // Otherwise, attempt to make connection - uri, err := url.Parse(c.ConnectionURL) - if err != nil { - return nil, fmt.Errorf("invalid connection URL: %w", err) - } - vals := uri.Query() - if c.tlsConfigName != "" { - vals.Set("tls", c.tlsConfigName) - } - uri.RawQuery = vals.Encode() - - // This convoluted piece is to ensure we're not url encoding any username - // or password information - urlPieces := strings.Split(c.ConnectionURL, "?") - connURL := "" - for i, urlFragment := range urlPieces { - if len(urlPieces) == 1 || i != len(urlPieces)-1 { - connURL = connURL + urlFragment - } - } - if len(vals.Encode()) > 0 { - connURL = connURL + "?" + vals.Encode() - } + connURL, err := c.finalizeConnectionURL() - c.db, err = sql.Open(dbType, connURL) + c.db, err = sql.Open("mysql", connURL) if err != nil { return nil, err } @@ -205,16 +176,6 @@ func (c *mySQLConnectionProducer) Close() error { return nil } -// SetCredentials uses provided information to set/create a user in the -// database. Unlike CreateUser, this method requires a username be provided and -// uses the name given, instead of generating a name. This is used for creating -// and setting the password of static accounts, as well as rolling back -// passwords in the database in the event an updated database fails to save in -// Vault's storage. -func (c *mySQLConnectionProducer) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) { - return "", "", dbutil.Unimplemented() -} - func (c *mySQLConnectionProducer) getTLSAuth() (tlsConfig *tls.Config, err error) { if len(c.TLSCAData) == 0 && len(c.TLSCertificateKeyData) == 0 { @@ -247,3 +208,16 @@ func (c *mySQLConnectionProducer) getTLSAuth() (tlsConfig *tls.Config, err error return tlsConfig, nil } + +func (c *mySQLConnectionProducer) finalizeConnectionURL() (connURL string, err error) { + config, err := mysql.ParseDSN(c.ConnectionURL) + if err != nil { + return "", fmt.Errorf("unable to parse connectionURL: %s", err) + } + + config.TLSConfig = c.tlsConfigName + + connURL = config.FormatDSN() + + return connURL, nil +} diff --git a/plugins/database/mysql/connection_producer_test.go b/plugins/database/mysql/connection_producer_test.go index 6f09b71c7afa..82e934f14dc8 100644 --- a/plugins/database/mysql/connection_producer_test.go +++ b/plugins/database/mysql/connection_producer_test.go @@ -17,6 +17,38 @@ import ( "github.com/ory/dockertest" ) +func Test_finalizeConnectionURL(t *testing.T) { + type testCase struct { + rootUrl string + tlsConfigName string + expectedResult string + } + + tests := map[string]testCase{ + "no tls, no query string": {"user:password@tcp(localhost:3306)/test", "", "user:password@tcp(localhost:3306)/test"}, + "tls, no query string": {"user:password@tcp(localhost:3306)/test", "tlsTest101", "user:password@tcp(localhost:3306)/test?tls=tlsTest101"}, + "tls, query string": {"user:password@tcp(localhost:3306)/test?foo=bar", "tlsTest101", "user:password@tcp(localhost:3306)/test?tls=tlsTest101&foo=bar"}, + "tls, query string, ? in password": {"user:pa?ssword?@tcp(localhost:3306)/test?foo=bar", "tlsTest101", "user:pa?ssword?@tcp(localhost:3306)/test?tls=tlsTest101&foo=bar"}, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + tCase := mySQLConnectionProducer { + ConnectionURL: test.rootUrl, + tlsConfigName: test.tlsConfigName, + } + + actual, err := tCase.finalizeConnectionURL() + if err != nil { + t.Fatalf("error occurred in test: %s", err) + } + if actual != test.expectedResult { + t.Fatalf("generated: %s, expected: %s", actual, test.expectedResult) + } + }) + } +} + func TestInit_clientTLS(t *testing.T) { t.Skip("Skipping this test because CircleCI can't mount the files we need without further investigation: " + "https://support.circleci.com/hc/en-us/articles/360007324514-How-can-I-mount-volumes-to-docker-containers-") diff --git a/plugins/database/mysql/mysql.go b/plugins/database/mysql/mysql.go index 9cb0adb62cfc..381492d7360f 100644 --- a/plugins/database/mysql/mysql.go +++ b/plugins/database/mysql/mysql.go @@ -55,7 +55,6 @@ func New(displayNameLen, roleNameLen, usernameLen int) func() (interface{}, erro func new(displayNameLen, roleNameLen, usernameLen int) *MySQL { connProducer := &mySQLConnectionProducer{} - connProducer.Type = mySQLTypeName credsProducer := &credsutil.SQLCredentialsProducer{ DisplayNameLen: displayNameLen,