diff --git a/helper/testhelpers/certhelpers/cert_helpers.go b/helper/testhelpers/certhelpers/cert_helpers.go new file mode 100644 index 000000000000..b84bbf961e5a --- /dev/null +++ b/helper/testhelpers/certhelpers/cert_helpers.go @@ -0,0 +1,244 @@ +package certhelpers + +import ( + "bytes" + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net" + "strings" + "testing" + "time" +) + +type CertBuilder struct { + tmpl *x509.Certificate + parentTmpl *x509.Certificate + + selfSign bool + parentKey *rsa.PrivateKey + + isCA bool +} + +type CertOpt func(*CertBuilder) error + +func CommonName(cn string) CertOpt { + return func(builder *CertBuilder) error { + builder.tmpl.Subject.CommonName = cn + return nil + } +} + +func Parent(parent Certificate) CertOpt { + return func(builder *CertBuilder) error { + builder.parentKey = parent.PrivKey.PrivKey + builder.parentTmpl = parent.Template + return nil + } +} + +func IsCA(isCA bool) CertOpt { + return func(builder *CertBuilder) error { + builder.isCA = isCA + return nil + } +} + +func SelfSign() CertOpt { + return func(builder *CertBuilder) error { + builder.selfSign = true + return nil + } +} + +func IP(ip ...string) CertOpt { + return func(builder *CertBuilder) error { + for _, addr := range ip { + if ipAddr := net.ParseIP(addr); ipAddr != nil { + builder.tmpl.IPAddresses = append(builder.tmpl.IPAddresses, ipAddr) + } + } + return nil + } +} + +func DNS(dns ...string) CertOpt { + return func(builder *CertBuilder) error { + builder.tmpl.DNSNames = dns + return nil + } +} + +func NewCert(t *testing.T, opts ...CertOpt) (cert Certificate) { + t.Helper() + + builder := CertBuilder{ + tmpl: &x509.Certificate{ + SerialNumber: makeSerial(t), + Subject: pkix.Name{ + CommonName: makeCommonName(), + }, + NotBefore: time.Now().Add(-1 * time.Hour), + NotAfter: time.Now().Add(1 * time.Hour), + IsCA: false, + KeyUsage: x509.KeyUsageDigitalSignature | + x509.KeyUsageKeyEncipherment | + x509.KeyUsageKeyAgreement, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + }, + } + + for _, opt := range opts { + err := opt(&builder) + if err != nil { + t.Fatalf("Failed to set up certificate builder: %s", err) + } + } + + key := NewPrivateKey(t) + + builder.tmpl.SubjectKeyId = getSubjKeyID(t, key.PrivKey) + + tmpl := builder.tmpl + parent := builder.parentTmpl + publicKey := key.PrivKey.Public() + signingKey := builder.parentKey + + if builder.selfSign { + parent = tmpl + signingKey = key.PrivKey + } + + if builder.isCA { + tmpl.IsCA = true + tmpl.KeyUsage = x509.KeyUsageCertSign | x509.KeyUsageCRLSign + tmpl.ExtKeyUsage = nil + } else { + tmpl.KeyUsage = x509.KeyUsageDigitalSignature | + x509.KeyUsageKeyEncipherment | + x509.KeyUsageKeyAgreement + tmpl.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth} + } + + certBytes, err := x509.CreateCertificate(rand.Reader, tmpl, parent, publicKey, signingKey) + if err != nil { + t.Fatalf("Unable to generate certificate: %s", err) + } + certPem := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: certBytes, + }) + + tlsCert, err := tls.X509KeyPair(certPem, key.Pem) + if err != nil { + t.Fatalf("Unable to parse X509 key pair: %s", err) + } + + return Certificate{ + Template: tmpl, + PrivKey: key, + TLSCert: tlsCert, + RawCert: certBytes, + Pem: certPem, + IsCA: builder.isCA, + } +} + +// //////////////////////////////////////////////////////////////////////////// +// Private Key +// //////////////////////////////////////////////////////////////////////////// +type KeyWrapper struct { + PrivKey *rsa.PrivateKey + Pem []byte +} + +func NewPrivateKey(t *testing.T) (key KeyWrapper) { + t.Helper() + + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("Unable to generate key for cert: %s", err) + } + + privKeyPem := pem.EncodeToMemory( + &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privKey), + }, + ) + + key = KeyWrapper{ + PrivKey: privKey, + Pem: privKeyPem, + } + + return key +} + +// //////////////////////////////////////////////////////////////////////////// +// Certificate +// //////////////////////////////////////////////////////////////////////////// +type Certificate struct { + PrivKey KeyWrapper + Template *x509.Certificate + TLSCert tls.Certificate + RawCert []byte + Pem []byte + IsCA bool +} + +func (cert Certificate) CombinedPEM() []byte { + if cert.IsCA { + return cert.Pem + } + return bytes.Join([][]byte{cert.PrivKey.Pem, cert.Pem}, []byte{'\n'}) +} + +func (cert Certificate) PrivateKeyPEM() []byte { + return cert.PrivKey.Pem +} + +// //////////////////////////////////////////////////////////////////////////// +// Helpers +// //////////////////////////////////////////////////////////////////////////// +func makeSerial(t *testing.T) *big.Int { + t.Helper() + + v := &big.Int{} + serialNumberLimit := v.Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + t.Fatalf("Unable to generate serial number: %s", err) + } + return serialNumber +} + +// Pulled from sdk/helper/certutil & slightly modified for test usage +func getSubjKeyID(t *testing.T, privateKey crypto.Signer) []byte { + t.Helper() + + if privateKey == nil { + t.Fatalf("passed-in private key is nil") + } + + marshaledKey, err := x509.MarshalPKIXPublicKey(privateKey.Public()) + if err != nil { + t.Fatalf("error marshalling public key: %s", err) + } + + subjKeyID := sha1.Sum(marshaledKey) + + return subjKeyID[:] +} + +func makeCommonName() (cn string) { + return strings.ReplaceAll(time.Now().Format("20060102T150405.000"), ".", "") +} diff --git a/plugins/database/mongodb/cert_helpers_test.go b/plugins/database/mongodb/cert_helpers_test.go index 200f997c3abd..deb04ab9c4e4 100644 --- a/plugins/database/mongodb/cert_helpers_test.go +++ b/plugins/database/mongodb/cert_helpers_test.go @@ -10,9 +10,7 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" - "io/ioutil" "math/big" - "os" "strings" "testing" "time" @@ -192,18 +190,6 @@ func (cert certificate) CombinedPEM() []byte { return bytes.Join([][]byte{cert.privKey.pem, cert.pem}, []byte{'\n'}) } -// //////////////////////////////////////////////////////////////////////////// -// Writing to file -// //////////////////////////////////////////////////////////////////////////// -func writeFile(t *testing.T, filename string, data []byte, perms os.FileMode) { - t.Helper() - - err := ioutil.WriteFile(filename, data, perms) - if err != nil { - t.Fatalf("Unable to write to file [%s]: %s", filename, err) - } -} - // //////////////////////////////////////////////////////////////////////////// // Helpers // //////////////////////////////////////////////////////////////////////////// diff --git a/plugins/database/mongodb/connection_producer_test.go b/plugins/database/mongodb/connection_producer_test.go index 087d556891fd..b96486da77e0 100644 --- a/plugins/database/mongodb/connection_producer_test.go +++ b/plugins/database/mongodb/connection_producer_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/hashicorp/vault/helper/testhelpers/certhelpers" "github.com/hashicorp/vault/helper/testhelpers/mongodb" "github.com/ory/dockertest" "go.mongodb.org/mongo-driver/mongo" @@ -30,20 +31,20 @@ func TestInit_clientTLS(t *testing.T) { defer os.RemoveAll(confDir) // Create certificates for Mongo authentication - caCert := newCert(t, - commonName("test certificate authority"), - isCA(true), - selfSign(), + caCert := certhelpers.NewCert(t, + certhelpers.CommonName("test certificate authority"), + certhelpers.IsCA(true), + certhelpers.SelfSign(), ) - serverCert := newCert(t, - commonName("server"), - dns("localhost"), - parent(caCert), + serverCert := certhelpers.NewCert(t, + certhelpers.CommonName("server"), + certhelpers.DNS("localhost"), + certhelpers.Parent(caCert), ) - clientCert := newCert(t, - commonName("client"), - dns("client"), - parent(caCert), + clientCert := certhelpers.NewCert(t, + certhelpers.CommonName("client"), + certhelpers.DNS("client"), + certhelpers.Parent(caCert), ) writeFile(t, paths.Join(confDir, "ca.pem"), caCert.CombinedPEM(), 0644) @@ -81,7 +82,7 @@ net: "connection_url": retURL, "allowed_roles": "*", "tls_certificate_key": clientCert.CombinedPEM(), - "tls_ca": caCert.pem, + "tls_ca": caCert.Pem, } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -111,7 +112,7 @@ net: AuthInfo: authInfo{ AuthenticatedUsers: []user{ { - User: fmt.Sprintf("CN=%s", clientCert.template.Subject.CommonName), + User: fmt.Sprintf("CN=%s", clientCert.Template.Subject.CommonName), DB: "$external", }, }, @@ -249,11 +250,11 @@ func connect(t *testing.T, uri string) (client *mongo.Client) { return client } -func setUpX509User(t *testing.T, client *mongo.Client, cert certificate) { +func setUpX509User(t *testing.T, client *mongo.Client, cert certhelpers.Certificate) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - username := fmt.Sprintf("CN=%s", cert.template.Subject.CommonName) + username := fmt.Sprintf("CN=%s", cert.Template.Subject.CommonName) cmd := &createUserCommand{ Username: username, @@ -301,3 +302,16 @@ type roles []role func (r roles) Len() int { return len(r) } func (r roles) Less(i, j int) bool { return r[i].Role < r[j].Role } func (r roles) Swap(i, j int) { r[i], r[j] = r[j], r[i] } + +// //////////////////////////////////////////////////////////////////////////// +// Writing to file +// //////////////////////////////////////////////////////////////////////////// +func writeFile(t *testing.T, filename string, data []byte, perms os.FileMode) { + t.Helper() + + err := ioutil.WriteFile(filename, data, perms) + if err != nil { + t.Fatalf("Unable to write to file [%s]: %s", filename, err) + } +} + diff --git a/plugins/database/mongodb/mongodb_test.go b/plugins/database/mongodb/mongodb_test.go index 503279a693ac..15ff76753597 100644 --- a/plugins/database/mongodb/mongodb_test.go +++ b/plugins/database/mongodb/mongodb_test.go @@ -11,6 +11,7 @@ import ( "testing" "time" + "github.com/hashicorp/vault/helper/testhelpers/certhelpers" "github.com/hashicorp/vault/helper/testhelpers/mongodb" "github.com/hashicorp/vault/sdk/database/dbplugin" "go.mongodb.org/mongo-driver/mongo" @@ -239,14 +240,14 @@ func testCreateDBUser(t testing.TB, connURL, db, username, password string) { } func TestGetTLSAuth(t *testing.T) { - ca := newCert(t, - commonName("certificate authority"), - isCA(true), - selfSign(), + ca := certhelpers.NewCert(t, + certhelpers.CommonName("certificate authority"), + certhelpers.IsCA(true), + certhelpers.SelfSign(), ) - cert := newCert(t, - commonName("test cert"), - parent(ca), + cert := certhelpers.NewCert(t, + certhelpers.CommonName("test cert"), + certhelpers.Parent(ca), ) type testCase struct { @@ -276,12 +277,12 @@ func TestGetTLSAuth(t *testing.T) { expectErr: true, }, "good ca": { - tlsCAData: cert.pem, + tlsCAData: cert.Pem, expectOpts: options.Client(). SetTLSConfig( &tls.Config{ - RootCAs: appendToCertPool(t, x509.NewCertPool(), cert.pem), + RootCAs: appendToCertPool(t, x509.NewCertPool(), cert.Pem), }, ), expectErr: false, @@ -293,7 +294,7 @@ func TestGetTLSAuth(t *testing.T) { expectOpts: options.Client(). SetTLSConfig( &tls.Config{ - Certificates: []tls.Certificate{cert.tlsCert}, + Certificates: []tls.Certificate{cert.TLSCert}, }, ). SetAuth(options.Credential{ diff --git a/plugins/database/mysql/connection_producer.go b/plugins/database/mysql/connection_producer.go new file mode 100644 index 000000000000..4e34372005cd --- /dev/null +++ b/plugins/database/mysql/connection_producer.go @@ -0,0 +1,226 @@ +package mysql + +import ( + "context" + "crypto/tls" + "crypto/x509" + "database/sql" + "fmt" + "net/url" + "sync" + "time" + + "github.com/go-sql-driver/mysql" + "github.com/hashicorp/errwrap" + "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/sdk/database/helper/connutil" + "github.com/hashicorp/vault/sdk/database/helper/dbutil" + "github.com/hashicorp/vault/sdk/helper/parseutil" + "github.com/mitchellh/mapstructure" +) + +// mySQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases +type mySQLConnectionProducer struct { + ConnectionURL string `json:"connection_url" mapstructure:"connection_url" structs:"connection_url"` + MaxOpenConnections int `json:"max_open_connections" mapstructure:"max_open_connections" structs:"max_open_connections"` + MaxIdleConnections int `json:"max_idle_connections" mapstructure:"max_idle_connections" structs:"max_idle_connections"` + MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" mapstructure:"max_connection_lifetime" structs:"max_connection_lifetime"` + + Username string `json:"username" mapstructure:"username" structs:"username"` + Password string `json:"password" mapstructure:"password" structs:"password"` + + TLSCertificateKeyData []byte `json:"tls_certificate_key" mapstructure:"tls_certificate_key" structs:"-"` + TLSCAData []byte `json:"tls_ca" mapstructure:"tls_ca" structs:"-"` + + // tlsConfigName is a globally unique name that references the TLS config for this instance in the mysql driver + tlsConfigName string + + RawConfig map[string]interface{} + maxConnectionLifetime time.Duration + Initialized bool + db *sql.DB + sync.Mutex +} + +func (c *mySQLConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error { + _, err := c.Init(ctx, conf, verifyConnection) + return err +} + +func (c *mySQLConnectionProducer) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) { + c.Lock() + defer c.Unlock() + + c.RawConfig = conf + + err := mapstructure.WeakDecode(conf, &c) + if err != nil { + return nil, err + } + + if len(c.ConnectionURL) == 0 { + return nil, fmt.Errorf("connection_url cannot be empty") + } + + // Don't escape special characters for MySQL password + password := c.Password + + // QueryHelper doesn't do any SQL escaping, but if it starts to do so + // then maybe we won't be able to use it to do URL substitution any more. + c.ConnectionURL = dbutil.QueryHelper(c.ConnectionURL, map[string]string{ + "username": url.PathEscape(c.Username), + "password": password, + }) + + if c.MaxOpenConnections == 0 { + c.MaxOpenConnections = 4 + } + + if c.MaxIdleConnections == 0 { + c.MaxIdleConnections = c.MaxOpenConnections + } + if c.MaxIdleConnections > c.MaxOpenConnections { + c.MaxIdleConnections = c.MaxOpenConnections + } + if c.MaxConnectionLifetimeRaw == nil { + c.MaxConnectionLifetimeRaw = "0s" + } + + c.maxConnectionLifetime, err = parseutil.ParseDurationSecond(c.MaxConnectionLifetimeRaw) + if err != nil { + return nil, errwrap.Wrapf("invalid max_connection_lifetime: {{err}}", err) + } + + tlsConfig, err := c.getTLSAuth() + if err != nil { + return nil, err + } + + if tlsConfig != nil { + if c.tlsConfigName == "" { + c.tlsConfigName, err = uuid.GenerateUUID() + if err != nil { + return nil, fmt.Errorf("unable to generate UUID for TLS configuration: %w", err) + } + } + + mysql.RegisterTLSConfig(c.tlsConfigName, tlsConfig) + } + + // Set initialized to true at this point since all fields are set, + // and the connection can be established at a later time. + c.Initialized = true + + if verifyConnection { + if _, err := c.Connection(ctx); err != nil { + return nil, errwrap.Wrapf("error verifying connection: {{err}}", err) + } + + if err := c.db.PingContext(ctx); err != nil { + return nil, errwrap.Wrapf("error verifying connection: {{err}}", err) + } + } + + return c.RawConfig, nil +} + +func (c *mySQLConnectionProducer) Connection(ctx context.Context) (interface{}, error) { + if !c.Initialized { + return nil, connutil.ErrNotInitialized + } + + // If we already have a DB, test it and return + if c.db != nil { + if err := c.db.PingContext(ctx); err == nil { + return c.db, nil + } + // If the ping was unsuccessful, close it and ignore errors as we'll be + // reestablishing anyways + c.db.Close() + } + + connURL, err := c.addTLStoDSN() + if err != nil { + return nil, err + } + + c.db, err = sql.Open("mysql", connURL) + if err != nil { + return nil, err + } + + // Set some connection pool settings. We don't need much of this, + // since the request rate shouldn't be high. + c.db.SetMaxOpenConns(c.MaxOpenConnections) + c.db.SetMaxIdleConns(c.MaxIdleConnections) + c.db.SetConnMaxLifetime(c.maxConnectionLifetime) + + return c.db, nil +} + +func (c *mySQLConnectionProducer) SecretValues() map[string]interface{} { + return map[string]interface{}{ + c.Password: "[password]", + } +} + +// Close attempts to close the connection +func (c *mySQLConnectionProducer) Close() error { + // Grab the write lock + c.Lock() + defer c.Unlock() + + if c.db != nil { + c.db.Close() + } + + c.db = nil + + return nil +} + +func (c *mySQLConnectionProducer) getTLSAuth() (tlsConfig *tls.Config, err error) { + if len(c.TLSCAData) == 0 && + len(c.TLSCertificateKeyData) == 0 { + return nil, nil + } + + rootCertPool := x509.NewCertPool() + if len(c.TLSCAData) > 0 { + ok := rootCertPool.AppendCertsFromPEM(c.TLSCAData) + if !ok { + return nil, fmt.Errorf("failed to append CA to client options") + } + } + + clientCert := make([]tls.Certificate, 0, 1) + + if len(c.TLSCertificateKeyData) > 0 { + certificate, err := tls.X509KeyPair(c.TLSCertificateKeyData, c.TLSCertificateKeyData) + if err != nil { + return nil, fmt.Errorf("unable to load tls_certificate_key_data: %w", err) + } + + clientCert = append(clientCert, certificate) + } + + tlsConfig = &tls.Config{ + RootCAs: rootCertPool, + Certificates: clientCert, + } + + return tlsConfig, nil +} + +func (c *mySQLConnectionProducer) addTLStoDSN() (connURL string, err error) { + config, err := mysql.ParseDSN(c.ConnectionURL) + if err != nil { + return "", fmt.Errorf("unable to parse connectionURL: %s", err) + } + + config.TLSConfig = c.tlsConfigName + + connURL = config.FormatDSN() + + return connURL, nil +} diff --git a/plugins/database/mysql/connection_producer_test.go b/plugins/database/mysql/connection_producer_test.go new file mode 100644 index 000000000000..7da5a05200d3 --- /dev/null +++ b/plugins/database/mysql/connection_producer_test.go @@ -0,0 +1,311 @@ +package mysql + +import ( + "context" + "database/sql" + "fmt" + "io/ioutil" + "os" + paths "path" + "path/filepath" + "reflect" + "testing" + "time" + + "github.com/hashicorp/vault/helper/testhelpers/certhelpers" + "github.com/hashicorp/vault/sdk/database/helper/dbutil" + "github.com/ory/dockertest" +) + +func Test_addTLStoDSN(t *testing.T) { + type testCase struct { + rootUrl string + tlsConfigName string + expectedResult string + } + + tests := map[string]testCase{ + "no tls, no query string": { + rootUrl: "user:password@tcp(localhost:3306)/test", + tlsConfigName: "", + expectedResult: "user:password@tcp(localhost:3306)/test", + }, + "tls, no query string": { + rootUrl: "user:password@tcp(localhost:3306)/test", + tlsConfigName: "tlsTest101", + expectedResult: "user:password@tcp(localhost:3306)/test?tls=tlsTest101", + }, + "tls, query string": { + rootUrl: "user:password@tcp(localhost:3306)/test?foo=bar", + tlsConfigName: "tlsTest101", + expectedResult: "user:password@tcp(localhost:3306)/test?tls=tlsTest101&foo=bar", + }, + "tls, query string, ? in password": { + rootUrl: "user:pa?ssword?@tcp(localhost:3306)/test?foo=bar", + tlsConfigName: "tlsTest101", + expectedResult: "user:pa?ssword?@tcp(localhost:3306)/test?tls=tlsTest101&foo=bar", + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + tCase := mySQLConnectionProducer{ + ConnectionURL: test.rootUrl, + tlsConfigName: test.tlsConfigName, + } + + actual, err := tCase.addTLStoDSN() + if err != nil { + t.Fatalf("error occurred in test: %s", err) + } + if actual != test.expectedResult { + t.Fatalf("generated: %s, expected: %s", actual, test.expectedResult) + } + }) + } +} + +func TestInit_clientTLS(t *testing.T) { + t.Skip("Skipping this test because CircleCI can't mount the files we need without further investigation: " + + "https://support.circleci.com/hc/en-us/articles/360007324514-How-can-I-mount-volumes-to-docker-containers-") + + // Set up temp directory so we can mount it to the docker container + confDir := makeTempDir(t) + defer os.RemoveAll(confDir) + + // Create certificates for MySQL authentication + caCert := certhelpers.NewCert(t, + certhelpers.CommonName("test certificate authority"), + certhelpers.IsCA(true), + certhelpers.SelfSign(), + ) + serverCert := certhelpers.NewCert(t, + certhelpers.CommonName("server"), + certhelpers.DNS("localhost"), + certhelpers.Parent(caCert), + ) + clientCert := certhelpers.NewCert(t, + certhelpers.CommonName("client"), + certhelpers.DNS("client"), + certhelpers.Parent(caCert), + ) + + writeFile(t, paths.Join(confDir, "ca.pem"), caCert.CombinedPEM(), 0644) + writeFile(t, paths.Join(confDir, "server-cert.pem"), serverCert.Pem, 0644) + writeFile(t, paths.Join(confDir, "server-key.pem"), serverCert.PrivateKeyPEM(), 0644) + writeFile(t, paths.Join(confDir, "client.pem"), clientCert.CombinedPEM(), 0644) + + // ////////////////////////////////////////////////////// + // Set up MySQL config file + rawConf := ` +[mysqld] +ssl +ssl-ca=/etc/mysql/ca.pem +ssl-cert=/etc/mysql/server-cert.pem +ssl-key=/etc/mysql/server-key.pem` + + writeFile(t, paths.Join(confDir, "my.cnf"), []byte(rawConf), 0644) + + // ////////////////////////////////////////////////////// + // Start MySQL container + retURL, cleanup := startMySQLWithTLS(t, "5.7", confDir) + defer cleanup() + + // ////////////////////////////////////////////////////// + // Set up x509 user + mClient := connect(t, retURL) + + username := setUpX509User(t, mClient, clientCert) + + // ////////////////////////////////////////////////////// + // Test + mysql := new(25, 25, 25) + + conf := map[string]interface{}{ + "connection_url": retURL, + "username": username, + "tls_certificate_key": clientCert.CombinedPEM(), + "tls_ca": caCert.Pem, + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + _, err := mysql.Init(ctx, conf, true) + if err != nil { + t.Fatalf("Unable to initialize mysql engine: %s", err) + } + + // Initialization complete. The connection was established, but we need to ensure + // that we're connected as the right user + whoamiCmd := "SELECT CURRENT_USER()" + + client, err := mysql.getConnection(ctx) + if err != nil { + t.Fatalf("Unable to make connection to MySQL: %s", err) + } + stmt, err := client.Prepare(whoamiCmd) + if err != nil { + t.Fatalf("Unable to prepare MySQL statementL %s", err) + } + + results := stmt.QueryRow() + + expected := fmt.Sprintf("%s@%%", username) + + var result string + if err := results.Scan(&result); err != nil { + t.Fatalf("result could not be scanned from result set: %s", err) + } + + if !reflect.DeepEqual(result, expected) { + t.Fatalf("Actual:%#v\nExpected:\n%#v", result, expected) + } +} + +func makeTempDir(t *testing.T) (confDir string) { + confDir, err := ioutil.TempDir(".", "mysql-test-data") + if err != nil { + t.Fatalf("Unable to make temp directory: %s", err) + } + // Convert the directory to an absolute path because docker needs it when mounting + confDir, err = filepath.Abs(filepath.Clean(confDir)) + if err != nil { + t.Fatalf("Unable to determine where temp directory is on absolute path: %s", err) + } + return confDir +} + +func startMySQLWithTLS(t *testing.T, version string, confDir string) (retURL string, cleanup func()) { + if os.Getenv("MYSQL_URL") != "" { + return os.Getenv("MYSQL_URL"), func() {} + } + + pool, err := dockertest.NewPool("") + if err != nil { + t.Fatalf("Failed to connect to docker: %s", err) + } + pool.MaxWait = 30 * time.Second + + containerName := "mysql-unit-test" + + // Remove previously running container if it is still running because cleanup failed + err = pool.RemoveContainerByName(containerName) + if err != nil { + t.Fatalf("Unable to remove old running containers: %s", err) + } + + username := "root" + password := "x509test" + + runOpts := &dockertest.RunOptions{ + Name: containerName, + Repository: "mysql", + Tag: version, + Cmd: []string{"--defaults-extra-file=/etc/mysql/my.cnf", "--auto-generate-certs=OFF"}, + Env: []string{fmt.Sprintf("MYSQL_ROOT_PASSWORD=%s", password)}, + // Mount the directory from local filesystem into the container + Mounts: []string{ + fmt.Sprintf("%s:/etc/mysql", confDir), + }, + } + + resource, err := pool.RunWithOptions(runOpts) + if err != nil { + t.Fatalf("Could not start local mysql docker container: %s", err) + } + resource.Expire(30) + + cleanup = func() { + err := pool.Purge(resource) + if err != nil { + t.Fatalf("Failed to cleanup local container: %s", err) + } + } + + dsn := fmt.Sprintf("{{username}}:{{password}}@tcp(localhost:%s)/mysql", resource.GetPort("3306/tcp")) + + url := dbutil.QueryHelper(dsn, map[string]string{ + "username": username, + "password": password, + }) + // exponential backoff-retry + err = pool.Retry(func() error { + var err error + + db, err := sql.Open("mysql", url) + if err != nil { + t.Logf("err: %s", err) + return err + } + defer db.Close() + return db.Ping() + }) + if err != nil { + cleanup() + t.Fatalf("Could not connect to mysql docker container: %s", err) + } + + return dsn, cleanup +} + +func connect(t *testing.T, dsn string) (db *sql.DB) { + url := dbutil.QueryHelper(dsn, map[string]string{ + "username": "root", + "password": "x509test", + }) + + db, err := sql.Open("mysql", url) + if err != nil { + t.Fatalf("Unable to make connection to MySQL: %s", err) + } + + err = db.Ping() + if err != nil { + t.Fatalf("Failed to ping MySQL server: %s", err) + } + + return db +} + +func setUpX509User(t *testing.T, db *sql.DB, cert certhelpers.Certificate) (username string) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + username = cert.Template.Subject.CommonName + + cmds := []string{ + fmt.Sprintf("CREATE USER %s IDENTIFIED BY '' REQUIRE X509", username), + fmt.Sprintf("GRANT ALL ON mysql.* TO '%s'@'%s' REQUIRE X509", username, "%"), + } + + for _, cmd := range cmds { + stmt, err := db.PrepareContext(ctx, cmd) + if err != nil { + t.Fatalf("Failed to prepare query: %s", err) + } + + _, err = stmt.ExecContext(ctx) + if err != nil { + t.Fatalf("Failed to create x509 user in database: %s", err) + } + err = stmt.Close() + if err != nil { + t.Fatalf("Failed to close prepared statement: %s", err) + } + } + + return username +} + +// //////////////////////////////////////////////////////////////////////////// +// Writing to file +// //////////////////////////////////////////////////////////////////////////// +func writeFile(t *testing.T, filename string, data []byte, perms os.FileMode) { + t.Helper() + + err := ioutil.WriteFile(filename, data, perms) + if err != nil { + t.Fatalf("Unable to write to file [%s]: %s", filename, err) + } +} diff --git a/plugins/database/mysql/mysql.go b/plugins/database/mysql/mysql.go index 349676643a1c..381492d7360f 100644 --- a/plugins/database/mysql/mysql.go +++ b/plugins/database/mysql/mysql.go @@ -10,7 +10,6 @@ import ( stdmysql "github.com/go-sql-driver/mysql" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/sdk/database/dbplugin" - "github.com/hashicorp/vault/sdk/database/helper/connutil" "github.com/hashicorp/vault/sdk/database/helper/credsutil" "github.com/hashicorp/vault/sdk/database/helper/dbutil" "github.com/hashicorp/vault/sdk/helper/strutil" @@ -39,7 +38,7 @@ var ( var _ dbplugin.Database = (*MySQL)(nil) type MySQL struct { - *connutil.SQLConnectionProducer + *mySQLConnectionProducer credsutil.CredentialsProducer } @@ -55,8 +54,7 @@ func New(displayNameLen, roleNameLen, usernameLen int) func() (interface{}, erro } func new(displayNameLen, roleNameLen, usernameLen int) *MySQL { - connProducer := &connutil.SQLConnectionProducer{} - connProducer.Type = mySQLTypeName + connProducer := &mySQLConnectionProducer{} credsProducer := &credsutil.SQLCredentialsProducer{ DisplayNameLen: displayNameLen, @@ -66,8 +64,8 @@ func new(displayNameLen, roleNameLen, usernameLen int) *MySQL { } return &MySQL{ - SQLConnectionProducer: connProducer, - CredentialsProducer: credsProducer, + mySQLConnectionProducer: connProducer, + CredentialsProducer: credsProducer, } }