diff --git a/plugins/database/mysql/mysql.go b/plugins/database/mysql/mysql.go index f8bf307a3057..894fa658a646 100644 --- a/plugins/database/mysql/mysql.go +++ b/plugins/database/mysql/mysql.go @@ -135,6 +135,7 @@ func (m *MySQL) CreateUser(ctx context.Context, statements dbplugin.Statements, queryMap := map[string]string{ "name": username, + "username": username, "password": password, "expiration": expirationStr, } @@ -187,6 +188,7 @@ func (m *MySQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, // 1295: This command is not supported in the prepared statement protocol yet // Reference https://mariadb.com/kb/en/mariadb/prepare-statement/ query = strings.Replace(query, "{{name}}", username, -1) + query = strings.Replace(query, "{{username}}", username, -1) _, err = tx.ExecContext(ctx, query) if err != nil { return err @@ -244,6 +246,7 @@ func (m *MySQL) RotateRootCredentials(ctx context.Context, statements []string) // 1295: This command is not supported in the prepared statement protocol yet // Reference https://mariadb.com/kb/en/mariadb/prepare-statement/ query = strings.Replace(query, "{{username}}", m.Username, -1) + query = strings.Replace(query, "{{name}}", m.Username, -1) query = strings.Replace(query, "{{password}}", password, -1) if _, err := tx.ExecContext(ctx, query); err != nil { @@ -283,10 +286,11 @@ func (m *MySQL) SetCredentials(ctx context.Context, statements dbplugin.Statemen queryMap := map[string]string{ "name": username, + "username": username, "password": password, } - if err := m.executePreparedStatmentsWithMap(ctx, statements.Rotation, queryMap); err != nil { + if err := m.executePreparedStatmentsWithMap(ctx, rotateStatements, queryMap); err != nil { return "", "", err } return username, password, nil diff --git a/plugins/database/mysql/mysql_test.go b/plugins/database/mysql/mysql_test.go index 3c39b44acc36..e73545ffb79c 100644 --- a/plugins/database/mysql/mysql_test.go +++ b/plugins/database/mysql/mysql_test.go @@ -53,153 +53,219 @@ func TestMySQL_Initialize(t *testing.T) { } func TestMySQL_CreateUser(t *testing.T) { - cleanup, connURL := mysqlhelper.PrepareMySQLTestContainer(t, false, "secret") - defer cleanup() - - connectionDetails := map[string]interface{}{ - "connection_url": connURL, - } - - db := new(MetadataLen, MetadataLen, UsernameLen) - _, err := db.Init(context.Background(), connectionDetails, true) - if err != nil { - t.Fatalf("err: %s", err) - } + t.Run("missing creation statements", func(t *testing.T) { + db := new(MetadataLen, MetadataLen, UsernameLen) - usernameConfig := dbplugin.UsernameConfig{ - DisplayName: "test-long-displayname", - RoleName: "test-long-rolename", - } - - // Test with no configured Creation Statement - _, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) - if err == nil { - t.Fatal("Expected error when no creation statement is provided") - } + usernameConfig := dbplugin.UsernameConfig{ + DisplayName: "test-long-displayname", + RoleName: "test-long-rolename", + } - statements := dbplugin.Statements{ - Creation: []string{testMySQLRoleWildCard}, - } + username, password, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) + if err == nil { + t.Fatalf("expected err, got nil") + } + if username != "" { + t.Fatalf("expected empty username, got [%s]", username) + } + if password != "" { + t.Fatalf("expected empty password, got [%s]", password) + } + }) - username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) - if err != nil { - t.Fatalf("err: %s", err) - } + t.Run("non-legacy", func(t *testing.T) { + // Shared test container for speed - there should not be any overlap between the tests + cleanup, connURL := mysqlhelper.PrepareMySQLTestContainer(t, false, "secret") + defer cleanup() - if err := mysqlhelper.TestCredsExist(t, connURL, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } - // Test a second time to make sure usernames don't collide - username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) - if err != nil { - t.Fatalf("err: %s", err) - } + db := new(MetadataLen, MetadataLen, UsernameLen) + _, err := db.Init(context.Background(), connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } - if err := mysqlhelper.TestCredsExist(t, connURL, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } + testCreateUser(t, db, connURL) + }) - // Test with a manually prepare statement - statements.Creation = []string{testMySQLRolePreparedStmt} + t.Run("legacy", func(t *testing.T) { + // Shared test container for speed - there should not be any overlap between the tests + cleanup, connURL := mysqlhelper.PrepareMySQLTestContainer(t, true, "secret") + defer cleanup() - username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) - if err != nil { - t.Fatalf("err: %s", err) - } + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } - if err := mysqlhelper.TestCredsExist(t, connURL, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } + db := new(credsutil.NoneLength, LegacyMetadataLen, LegacyUsernameLen) + _, err := db.Init(context.Background(), connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + testCreateUser(t, db, connURL) + }) } -func TestMySQL_CreateUser_Legacy(t *testing.T) { - cleanup, connURL := mysqlhelper.PrepareMySQLTestContainer(t, true, "secret") - defer cleanup() - - connectionDetails := map[string]interface{}{ - "connection_url": connURL, - } - - db := new(credsutil.NoneLength, LegacyMetadataLen, LegacyUsernameLen) - _, err := db.Init(context.Background(), connectionDetails, true) - if err != nil { - t.Fatalf("err: %s", err) - } - - usernameConfig := dbplugin.UsernameConfig{ - DisplayName: "test-long-displayname", - RoleName: "test-long-rolename", - } - - // Test with no configured Creation Statement - _, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) - if err == nil { - t.Fatal("Expected error when no creation statement is provided") - } +func testCreateUser(t *testing.T, db *MySQL, connURL string) { + type testCase struct { + createStmts []string + } + + tests := map[string]testCase{ + "create name": { + createStmts: []string{` + CREATE USER '{{name}}'@'%' IDENTIFIED BY '{{password}}'; + GRANT SELECT ON *.* TO '{{name}}'@'%';`, + }, + }, + "create username": { + createStmts: []string{` + CREATE USER '{{username}}'@'%' IDENTIFIED BY '{{password}}'; + GRANT SELECT ON *.* TO '{{username}}'@'%';`, + }, + }, + "prepared statement name": { + createStmts: []string{` + CREATE USER '{{name}}'@'%' IDENTIFIED BY '{{password}}'; + set @grants=CONCAT("GRANT SELECT ON ", "*", ".* TO '{{name}}'@'%'"); + PREPARE grantStmt from @grants; + EXECUTE grantStmt; + DEALLOCATE PREPARE grantStmt; + `, + }, + }, + "prepared statement username": { + createStmts: []string{` + CREATE USER '{{username}}'@'%' IDENTIFIED BY '{{password}}'; + set @grants=CONCAT("GRANT SELECT ON ", "*", ".* TO '{{username}}'@'%'"); + PREPARE grantStmt from @grants; + EXECUTE grantStmt; + DEALLOCATE PREPARE grantStmt; + `, + }, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + usernameConfig := dbplugin.UsernameConfig{ + DisplayName: "test-long-displayname", + RoleName: "test-long-rolename", + } - statements := dbplugin.Statements{ - Creation: []string{testMySQLRoleWildCard}, - } + statements := dbplugin.Statements{ + Creation: test.createStmts, + } - username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) - if err != nil { - t.Fatalf("err: %s", err) - } + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } - if err := mysqlhelper.TestCredsExist(t, connURL, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } + if err := mysqlhelper.TestCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } - // Test a second time to make sure usernames don't collide - username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) - if err != nil { - t.Fatalf("err: %s", err) - } + // Test a second time to make sure usernames don't collide + username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } - if err := mysqlhelper.TestCredsExist(t, connURL, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) + if err := mysqlhelper.TestCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + }) } } func TestMySQL_RotateRootCredentials(t *testing.T) { - cleanup, connURL := mysqlhelper.PrepareMySQLTestContainer(t, false, "secret") - defer cleanup() - - connURL = strings.Replace(connURL, "root:secret", `{{username}}:{{password}}`, -1) + type testCase struct { + statements []string + } + + tests := map[string]testCase{ + "empty statements": { + statements: nil, + }, + "default username": { + statements: []string{defaultMySQLRotateCredentialsSQL}, + }, + "default name": { + statements: []string{` + ALTER USER '{{username}}'@'%' IDENTIFIED BY '{{password}}';`, + }, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + cleanup, connURL := mysqlhelper.PrepareMySQLTestContainer(t, false, "secret") + defer cleanup() + + connURL = strings.Replace(connURL, "root:secret", `{{username}}:{{password}}`, -1) + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + "username": "root", + "password": "secret", + } - connectionDetails := map[string]interface{}{ - "connection_url": connURL, - "username": "root", - "password": "secret", - } + // Give a timeout just in case the test decides to be problematic + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() - db := new(MetadataLen, MetadataLen, UsernameLen) - _, err := db.Init(context.Background(), connectionDetails, true) - if err != nil { - t.Fatalf("err: %s", err) - } + db := new(MetadataLen, MetadataLen, UsernameLen) + _, err := db.Init(ctx, connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } - if !db.Initialized { - t.Fatal("Database should be initialized") - } + if !db.Initialized { + t.Fatal("Database should be initialized") + } - newConf, err := db.RotateRootCredentials(context.Background(), nil) - if err != nil { - t.Fatalf("err: %v", err) - } - if newConf["password"] == "secret" { - t.Fatal("password was not updated") - } + newConf, err := db.RotateRootCredentials(ctx, test.statements) + if err != nil { + t.Fatalf("err: %v", err) + } + if newConf["password"] == "secret" { + t.Fatal("password was not updated") + } - err = db.Close() - if err != nil { - t.Fatalf("err: %s", err) + err = db.Close() + if err != nil { + t.Fatalf("err: %s", err) + } + }) } } func TestMySQL_RevokeUser(t *testing.T) { + type testCase struct { + revokeStmts []string + } + + tests := map[string]testCase{ + "empty statements": { + revokeStmts: nil, + }, + "default name": { + revokeStmts: []string{defaultMysqlRevocationStmts}, + }, + "default username": { + revokeStmts: []string{` + REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{username}}'@'%'; + DROP USER '{{username}}'@'%'`, + }, + }, + } + + // Shared test container for speed - there should not be any overlap between the tests cleanup, connURL := mysqlhelper.PrepareMySQLTestContainer(t, false, "secret") defer cleanup() @@ -207,117 +273,142 @@ func TestMySQL_RevokeUser(t *testing.T) { "connection_url": connURL, } - db := new(MetadataLen, MetadataLen, UsernameLen) - _, err := db.Init(context.Background(), connectionDetails, true) - if err != nil { - t.Fatalf("err: %s", err) - } - - statements := dbplugin.Statements{ - Creation: []string{testMySQLRoleWildCard}, - } - - usernameConfig := dbplugin.UsernameConfig{ - DisplayName: "test", - RoleName: "test", - } + // Give a timeout just in case the test decides to be problematic + initCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() - username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) + db := new(MetadataLen, MetadataLen, UsernameLen) + _, err := db.Init(initCtx, connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } - if err := mysqlhelper.TestCredsExist(t, connURL, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + statements := dbplugin.Statements{ + Creation: []string{` + CREATE USER '{{name}}'@'%' IDENTIFIED BY '{{password}}'; + GRANT SELECT ON *.* TO '{{name}}'@'%';`, + }, + Revocation: test.revokeStmts, + } - // Test default revoke statements - err = db.RevokeUser(context.Background(), statements, username) - if err != nil { - t.Fatalf("err: %s", err) - } + usernameConfig := dbplugin.UsernameConfig{ + DisplayName: "test", + RoleName: "test", + } - if err := mysqlhelper.TestCredsExist(t, connURL, username, password); err == nil { - t.Fatal("Credentials were not revoked") - } + // Give a timeout just in case the test decides to be problematic + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() - statements.Creation = []string{testMySQLRoleWildCard} - username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) - if err != nil { - t.Fatalf("err: %s", err) - } + username, password, err := db.CreateUser(ctx, statements, usernameConfig, time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } - if err := mysqlhelper.TestCredsExist(t, connURL, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } + if err := mysqlhelper.TestCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } - // Test custom revoke statements - statements.Revocation = []string{testMySQLRevocationSQL} - err = db.RevokeUser(context.Background(), statements, username) - if err != nil { - t.Fatalf("err: %s", err) - } + err = db.RevokeUser(context.Background(), statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } - if err := mysqlhelper.TestCredsExist(t, connURL, username, password); err == nil { - t.Fatal("Credentials were not revoked") + if err := mysqlhelper.TestCredsExist(t, connURL, username, password); err == nil { + t.Fatal("Credentials were not revoked") + } + }) } } func TestMySQL_SetCredentials(t *testing.T) { - cleanup, connURL := mysqlhelper.PrepareMySQLTestContainer(t, false, "secret") - defer cleanup() - - // create the database user and verify we can access - dbUser := "vaultstatictest" - createTestMySQLUser(t, connURL, dbUser, "password", testRoleStaticCreate) - if err := mysqlhelper.TestCredsExist(t, connURL, dbUser, "password"); err != nil { - t.Fatalf("Could not connect with credentials: %s", err) - } + type testCase struct { + rotateStmts []string + } + + tests := map[string]testCase{ + "empty statements": { + rotateStmts: nil, + }, + "custom statement name": { + rotateStmts: []string{` + ALTER USER '{{name}}'@'%' IDENTIFIED BY '{{password}}';`}, + }, + "custom statement username": { + rotateStmts: []string{` + ALTER USER '{{username}}'@'%' IDENTIFIED BY '{{password}}';`}, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + cleanup, connURL := mysqlhelper.PrepareMySQLTestContainer(t, false, "secret") + defer cleanup() + + // create the database user and verify we can access + dbUser := "vaultstatictest" + initPassword := "password" + + createStatements := ` + CREATE USER '{{name}}'@'%' IDENTIFIED BY '{{password}}'; + GRANT SELECT ON *.* TO '{{name}}'@'%';` + + createTestMySQLUser(t, connURL, dbUser, initPassword, createStatements) + if err := mysqlhelper.TestCredsExist(t, connURL, dbUser, "password"); err != nil { + t.Fatalf("Could not connect with credentials: %s", err) + } - connectionDetails := map[string]interface{}{ - "connection_url": connURL, - } + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } - db := new(MetadataLen, MetadataLen, UsernameLen) - _, err := db.Init(context.Background(), connectionDetails, true) - if err != nil { - t.Fatalf("err: %s", err) - } + // Give a timeout just in case the test decides to be problematic + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() - newPassword, err := db.GenerateCredentials(context.Background()) - if err != nil { - t.Fatal(err) - } + db := new(MetadataLen, MetadataLen, UsernameLen) + _, err := db.Init(ctx, connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } - userConfig := dbplugin.StaticUserConfig{ - Username: dbUser, - Password: newPassword, - } + newPassword, err := db.GenerateCredentials(ctx) + if err != nil { + t.Fatalf("unable to generate password: %s", err) + } - statements := dbplugin.Statements{ - Rotation: []string{testRoleStaticRotate}, - } + userConfig := dbplugin.StaticUserConfig{ + Username: dbUser, + Password: newPassword, + } - _, _, err = db.SetCredentials(context.Background(), statements, userConfig) - if err != nil { - t.Fatalf("err: %s", err) - } + statements := dbplugin.Statements{ + Rotation: test.rotateStmts, + } - // verify new password works - if err := mysqlhelper.TestCredsExist(t, connURL, dbUser, newPassword); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } + username, password, err := db.SetCredentials(ctx, statements, userConfig) + if err != nil { + t.Fatalf("err: %s", err) + } + if username != userConfig.Username { + t.Fatalf("expected username [%s], got [%s]", userConfig.Username, username) + } + if password != userConfig.Password { + t.Fatalf("expected password [%s] got [%s]", userConfig.Password, password) + } - // call SetCredentials again, password will change - newPassword, _ = db.GenerateCredentials(context.Background()) - userConfig.Password = newPassword - _, _, err = db.SetCredentials(context.Background(), statements, userConfig) - if err != nil { - t.Fatalf("err: %s", err) - } + // verify new password works + if err := mysqlhelper.TestCredsExist(t, connURL, dbUser, newPassword); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } - if err := mysqlhelper.TestCredsExist(t, connURL, dbUser, newPassword); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) + // verify old password doesn't work + if err := mysqlhelper.TestCredsExist(t, connURL, dbUser, initPassword); err == nil { + t.Fatalf("Should not be able to connect with initial credentials") + } + }) } } @@ -399,28 +490,3 @@ func createTestMySQLUser(t *testing.T, connURL, username, password, query string stmt.Close() } } - -const testMySQLRolePreparedStmt = ` -CREATE USER '{{name}}'@'%' IDENTIFIED BY '{{password}}'; -set @grants=CONCAT("GRANT SELECT ON ", "*", ".* TO '{{name}}'@'%'"); -PREPARE grantStmt from @grants; -EXECUTE grantStmt; -DEALLOCATE PREPARE grantStmt; -` -const testMySQLRoleWildCard = ` -CREATE USER '{{name}}'@'%' IDENTIFIED BY '{{password}}'; -GRANT SELECT ON *.* TO '{{name}}'@'%'; -` -const testMySQLRevocationSQL = ` -REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%'; -DROP USER '{{name}}'@'%'; -` - -const testRoleStaticCreate = ` -CREATE USER '{{name}}'@'%' IDENTIFIED BY '{{password}}'; -GRANT SELECT ON *.* TO '{{name}}'@'%'; -` - -const testRoleStaticRotate = ` -ALTER USER '{{name}}'@'%' IDENTIFIED BY '{{password}}'; -` diff --git a/plugins/database/postgresql/postgresql.go b/plugins/database/postgresql/postgresql.go index 34493f80ce4f..6e221481e2b8 100644 --- a/plugins/database/postgresql/postgresql.go +++ b/plugins/database/postgresql/postgresql.go @@ -26,10 +26,6 @@ ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}'; ` defaultPostgresRotateRootCredentialsSQL = ` ALTER ROLE "{{username}}" WITH PASSWORD '{{password}}'; -` - - defaultPostgresRotateCredentialsSQL = ` -ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}'; ` ) @@ -149,6 +145,7 @@ func (p *PostgreSQL) SetCredentials(ctx context.Context, statements dbplugin.Sta m := map[string]string{ "name": staticUser.Username, + "username": staticUser.Username, "password": password, } if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { @@ -217,6 +214,7 @@ func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Stateme m := map[string]string{ "name": username, + "username": username, "password": password, "expiration": expirationStr, } @@ -272,6 +270,7 @@ func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statemen m := map[string]string{ "name": username, + "username": username, "expiration": expirationStr, } if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { @@ -319,7 +318,8 @@ func (p *PostgreSQL) customRevokeUser(ctx context.Context, username string, revo } m := map[string]string{ - "name": username, + "name": username, + "username": username, } if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { return err @@ -479,6 +479,7 @@ func (p *PostgreSQL) RotateRootCredentials(ctx context.Context, statements []str continue } m := map[string]string{ + "name": p.Username, "username": p.Username, "password": password, } diff --git a/plugins/database/postgresql/postgresql_test.go b/plugins/database/postgresql/postgresql_test.go index 3fbbb439ebbd..a849cefbb710 100644 --- a/plugins/database/postgresql/postgresql_test.go +++ b/plugins/database/postgresql/postgresql_test.go @@ -6,7 +6,6 @@ import ( "fmt" "os" "strings" - "sync" "testing" "time" @@ -17,10 +16,6 @@ import ( "github.com/ory/dockertest" ) -var ( - testPostgresImagePull sync.Once -) - func preparePostgresTestContainer(t *testing.T) (cleanup func(), retURL string) { if os.Getenv("PG_URL") != "" { return func() {}, os.Getenv("PG_URL") @@ -97,59 +92,73 @@ func TestPostgreSQL_Initialize(t *testing.T) { } -func TestPostgreSQL_CreateUser(t *testing.T) { - cleanup, connURL := preparePostgresTestContainer(t) - defer cleanup() - - connectionDetails := map[string]interface{}{ - "connection_url": connURL, - } - +func TestPostgreSQL_CreateUser_missingArgs(t *testing.T) { db := new() - _, err := db.Init(context.Background(), connectionDetails, true) - if err != nil { - t.Fatalf("err: %s", err) - } usernameConfig := dbplugin.UsernameConfig{ DisplayName: "test", RoleName: "test", } - // Test with no configured Creation Statement - _, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) + username, password, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) if err == nil { - t.Fatal("Expected error when no creation statement is provided") - } - - statements := dbplugin.Statements{ - Creation: []string{testPostgresRole}, - } - - username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) - if err != nil { - t.Fatalf("err: %s", err) - } - - if err = testCredsExist(t, connURL, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) + t.Fatalf("expected err, got nil") } - - statements.Creation = []string{testPostgresReadOnlyRole} - username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) - if err != nil { - t.Fatalf("err: %s", err) + if username != "" { + t.Fatalf("expected empty username, got [%s]", username) } - - // Sleep to make sure we haven't expired if granularity is only down to the second - time.Sleep(2 * time.Second) - - if err = testCredsExist(t, connURL, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) + if password != "" { + t.Fatalf("expected empty password, got [%s]", password) } } -func TestPostgreSQL_RenewUser(t *testing.T) { +func TestPostgreSQL_CreateUser(t *testing.T) { + type testCase struct { + createStmts []string + } + + tests := map[string]testCase{ + "admin name": { + createStmts: []string{` + CREATE ROLE "{{name}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; + GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";`, + }, + }, + "admin username": { + createStmts: []string{` + CREATE ROLE "{{username}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; + GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{username}}";`, + }, + }, + "read only name": { + createStmts: []string{` + CREATE ROLE "{{name}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; + GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; + GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}";`, + }, + }, + "read only username": { + createStmts: []string{` + CREATE ROLE "{{username}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; + GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{username}}"; + GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{username}}";`, + }, + }, + } + + // Shared test container for speed - there should not be any overlap between the tests cleanup, connURL := preparePostgresTestContainer(t) defer cleanup() @@ -163,100 +172,60 @@ func TestPostgreSQL_RenewUser(t *testing.T) { t.Fatalf("err: %s", err) } - statements := dbplugin.Statements{ - Creation: []string{testPostgresRole}, - } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + usernameConfig := dbplugin.UsernameConfig{ + DisplayName: "test", + RoleName: "test", + } - usernameConfig := dbplugin.UsernameConfig{ - DisplayName: "test", - RoleName: "test", - } + statements := dbplugin.Statements{ + Creation: test.createStmts, + } - username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) - if err != nil { - t.Fatalf("err: %s", err) - } + // Give a timeout just in case the test decides to be problematic + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() - if err = testCredsExist(t, connURL, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } + username, password, err := db.CreateUser(ctx, statements, usernameConfig, time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } - err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute)) - if err != nil { - t.Fatalf("err: %s", err) - } + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } - // Sleep longer than the initial expiration time - time.Sleep(2 * time.Second) + // Ensure that the role doesn't expire immediately + time.Sleep(2 * time.Second) - if err = testCredsExist(t, connURL, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + }) } - statements.Renewal = []string{defaultPostgresRenewSQL} - username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) - if err != nil { - t.Fatalf("err: %s", err) - } - - if err = testCredsExist(t, connURL, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } - - err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute)) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Sleep longer than the initial expiration time - time.Sleep(2 * time.Second) - - if err = testCredsExist(t, connURL, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } - } -func TestPostgreSQL_RotateRootCredentials(t *testing.T) { - cleanup, connURL := preparePostgresTestContainer(t) - defer cleanup() - - connURL = strings.Replace(connURL, "postgres:secret", `{{username}}:{{password}}`, -1) - - connectionDetails := map[string]interface{}{ - "connection_url": connURL, - "max_open_connections": 5, - "username": "postgres", - "password": "secret", - } - - db := new() - - connProducer := db.SQLConnectionProducer - - _, err := db.Init(context.Background(), connectionDetails, true) - if err != nil { - t.Fatalf("err: %s", err) - } - - if !connProducer.Initialized { - t.Fatal("Database should be initialized") - } - - newConf, err := db.RotateRootCredentials(context.Background(), nil) - if err != nil { - t.Fatalf("err: %v", err) - } - if newConf["password"] == "secret" { - t.Fatal("password was not updated") +func TestPostgreSQL_RenewUser(t *testing.T) { + type testCase struct { + renewalStmts []string } - err = db.Close() - if err != nil { - t.Fatalf("err: %s", err) + tests := map[string]testCase{ + "empty renewal statements": { + renewalStmts: nil, + }, + "default renewal name": { + renewalStmts: []string{defaultPostgresRenewSQL}, + }, + "default renewal username": { + renewalStmts: []string{` + ALTER ROLE "{{username}}" VALID UNTIL '{{expiration}}';`, + }, + }, } -} -func TestPostgreSQL_RevokeUser(t *testing.T) { + // Shared test container for speed - there should not be any overlap between the tests cleanup, connURL := preparePostgresTestContainer(t) defer cleanup() @@ -265,121 +234,331 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { } db := new() - _, err := db.Init(context.Background(), connectionDetails, true) - if err != nil { - t.Fatalf("err: %s", err) - } - statements := dbplugin.Statements{ - Creation: []string{testPostgresRole}, - } + // Give a timeout just in case the test decides to be problematic + initCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() - usernameConfig := dbplugin.UsernameConfig{ - DisplayName: "test", - RoleName: "test", - } - - username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) + _, err := db.Init(initCtx, connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } - if err = testCredsExist(t, connURL, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + statements := dbplugin.Statements{ + Creation: []string{createAdminUser}, + Renewal: test.renewalStmts, + } - // Test default revoke statements - err = db.RevokeUser(context.Background(), statements, username) - if err != nil { - t.Fatalf("err: %s", err) - } + usernameConfig := dbplugin.UsernameConfig{ + DisplayName: "test", + RoleName: "test", + } - if err := testCredsExist(t, connURL, username, password); err == nil { - t.Fatal("Credentials were not revoked") - } + // Give a timeout just in case the test decides to be problematic + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() - username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) - if err != nil { - t.Fatalf("err: %s", err) - } + username, password, err := db.CreateUser(ctx, statements, usernameConfig, time.Now().Add(2*time.Second)) + if err != nil { + t.Fatalf("err: %s", err) + } - if err = testCredsExist(t, connURL, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } - // Test custom revoke statements - statements.Revocation = []string{defaultPostgresRevocationSQL} - err = db.RevokeUser(context.Background(), statements, username) - if err != nil { - t.Fatalf("err: %s", err) + err = db.RenewUser(ctx, statements, username, time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Sleep longer than the initial expiration time + time.Sleep(2 * time.Second) + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + }) } +} - if err := testCredsExist(t, connURL, username, password); err == nil { - t.Fatal("Credentials were not revoked") +func TestPostgreSQL_RotateRootCredentials(t *testing.T) { + type testCase struct { + statements []string + } + + tests := map[string]testCase{ + "empty statements": { + statements: nil, + }, + "default name": { + statements: []string{` + ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`, + }, + }, + "default username": { + statements: []string{defaultPostgresRotateRootCredentialsSQL}, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + cleanup, connURL := preparePostgresTestContainer(t) + defer cleanup() + + connURL = strings.Replace(connURL, "postgres:secret", `{{username}}:{{password}}`, -1) + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + "max_open_connections": 5, + "username": "postgres", + "password": "secret", + } + + db := new() + + connProducer := db.SQLConnectionProducer + + // Give a timeout just in case the test decides to be problematic + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + _, err := db.Init(ctx, connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + if !connProducer.Initialized { + t.Fatal("Database should be initialized") + } + + newConf, err := db.RotateRootCredentials(ctx, test.statements) + if err != nil { + t.Fatalf("err: %v", err) + } + if newConf["password"] == "secret" { + t.Fatal("password was not updated") + } + + err = db.Close() + if err != nil { + t.Fatalf("failed to close: %s", err) + } + }) } } -func TestPostgresSQL_SetCredentials(t *testing.T) { +func TestPostgreSQL_RevokeUser(t *testing.T) { + type testCase struct { + revokeStmts []string + } + + tests := map[string]testCase{ + "empty statements": { + revokeStmts: nil, + }, + "explicit default name": { + revokeStmts: []string{defaultPostgresRevocationSQL}, + }, + "explicit default username": { + revokeStmts: []string{` + REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{username}}"; + REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{username}}"; + REVOKE USAGE ON SCHEMA public FROM "{{username}}"; + + DROP ROLE IF EXISTS "{{username}}";`, + }, + }, + } + + // Shared test container for speed - there should not be any overlap between the tests cleanup, connURL := preparePostgresTestContainer(t) defer cleanup() - // create the database user - dbUser := "vaultstatictest" - createTestPGUser(t, connURL, dbUser, "password", testRoleStaticCreate) - connectionDetails := map[string]interface{}{ "connection_url": connURL, } db := new() - _, err := db.Init(context.Background(), connectionDetails, true) + + // Give a timeout just in case the test decides to be problematic + initCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + _, err := db.Init(initCtx, connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } - password, err := db.GenerateCredentials(context.Background()) - if err != nil { - t.Fatal(err) - } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + statements := dbplugin.Statements{ + Creation: []string{createAdminUser}, + Revocation: test.revokeStmts, + } - usernameConfig := dbplugin.StaticUserConfig{ - Username: dbUser, - Password: password, - } + usernameConfig := dbplugin.UsernameConfig{ + DisplayName: "test", + RoleName: "test", + } - // Test with no configured Rotation Statement - username, password, err := db.SetCredentials(context.Background(), dbplugin.Statements{}, usernameConfig) - if err == nil { - t.Fatalf("err: %s", err) - } + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) + if err != nil { + t.Fatalf("err: %s", err) + } - statements := dbplugin.Statements{ - Rotation: []string{testPostgresStaticRoleRotate}, - } - // User should not exist, make sure we can create - username, password, err = db.SetCredentials(context.Background(), statements, usernameConfig) - if err != nil { - t.Fatalf("err: %s", err) - } + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } - if err := testCredsExist(t, connURL, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } + // Test default revoke statements + err = db.RevokeUser(context.Background(), statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } - // call SetCredentials again, password will change - newPassword, _ := db.GenerateCredentials(context.Background()) - usernameConfig.Password = newPassword - username, password, err = db.SetCredentials(context.Background(), statements, usernameConfig) - if err != nil { - t.Fatalf("err: %s", err) + if err := testCredsExist(t, connURL, username, password); err == nil { + t.Fatal("Credentials were not revoked") + } + }) } +} - if password != newPassword { - t.Fatal("passwords should have changed") +func TestPostgreSQL_SetCredentials_missingArgs(t *testing.T) { + type testCase struct { + statements dbplugin.Statements + userConfig dbplugin.StaticUserConfig + } + + tests := map[string]testCase{ + "empty rotation statements": { + statements: dbplugin.Statements{ + Rotation: nil, + }, + userConfig: dbplugin.StaticUserConfig{ + Username: "testuser", + Password: "password", + }, + }, + "empty username": { + statements: dbplugin.Statements{ + Rotation: []string{` + ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`, + }, + }, + userConfig: dbplugin.StaticUserConfig{ + Username: "", + Password: "password", + }, + }, + "empty password": { + statements: dbplugin.Statements{ + Rotation: []string{` + ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`, + }, + }, + userConfig: dbplugin.StaticUserConfig{ + Username: "testuser", + Password: "", + }, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + db := new() + + username, password, err := db.SetCredentials(context.Background(), test.statements, test.userConfig) + if err == nil { + t.Fatalf("expected err, got nil") + } + if username != "" { + t.Fatalf("expected empty username, got [%s]", username) + } + if password != "" { + t.Fatalf("expected empty password, got [%s]", password) + } + }) } +} - if err := testCredsExist(t, connURL, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) +func TestPostgresSQL_SetCredentials(t *testing.T) { + type testCase struct { + rotationStmts []string + } + + tests := map[string]testCase{ + "name rotation": { + rotationStmts: []string{` + ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`, + }, + }, + "username rotation": { + rotationStmts: []string{` + ALTER ROLE "{{username}}" WITH PASSWORD '{{password}}';`, + }, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + // Shared test container for speed - there should not be any overlap between the tests + cleanup, connURL := preparePostgresTestContainer(t) + defer cleanup() + + // create the database user + dbUser := "vaultstatictest" + initPassword := "password" + createTestPGUser(t, connURL, dbUser, initPassword, testRoleStaticCreate) + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := new() + + // Give a timeout just in case the test decides to be problematic + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + _, err := db.Init(ctx, connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements := dbplugin.Statements{ + Rotation: test.rotationStmts, + } + + password, err := db.GenerateCredentials(context.Background()) + if err != nil { + t.Fatal(err) + } + + usernameConfig := dbplugin.StaticUserConfig{ + Username: dbUser, + Password: password, + } + + if err := testCredsExist(t, connURL, dbUser, initPassword); err != nil { + t.Fatalf("Could not connect with initial credentials: %s", err) + } + + username, password, err := db.SetCredentials(ctx, statements, usernameConfig) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + if err := testCredsExist(t, connURL, username, initPassword); err == nil { + t.Fatalf("Should not be able to connect with initial credentials") + } + }) } } @@ -395,7 +574,7 @@ func testCredsExist(t testing.TB, connURL, username, password string) error { return db.Ping() } -const testPostgresRole = ` +const createAdminUser = ` CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' @@ -403,37 +582,6 @@ CREATE ROLE "{{name}}" WITH GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; ` -const testPostgresReadOnlyRole = ` -CREATE ROLE "{{name}}" WITH - LOGIN - PASSWORD '{{password}}' - VALID UNTIL '{{expiration}}'; -GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; -GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}"; -` - -const testPostgresBlockStatementRole = ` -DO $$ -BEGIN - IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN - CREATE ROLE "foo-role"; - CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role"; - ALTER ROLE "foo-role" SET search_path = foo; - GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role"; - GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role"; - END IF; -END -$$ - -CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; -GRANT "foo-role" TO "{{name}}"; -ALTER ROLE "{{name}}" SET search_path = foo; -GRANT CONNECT ON DATABASE "postgres" TO "{{name}}"; -` - var testPostgresBlockStatementRoleSlice = []string{ ` DO $$ @@ -465,27 +613,12 @@ REVOKE USAGE ON SCHEMA public FROM "{{name}}"; DROP ROLE IF EXISTS "{{name}}"; ` -const testPostgresStaticRole = ` -CREATE ROLE "{{name}}" WITH - LOGIN - PASSWORD '{{password}}'; -GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; -` - const testRoleStaticCreate = ` CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}'; ` -const testPostgresStaticRoleRotate = ` -ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}'; -` - -const testPostgresStaticRoleGrant = ` -GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; -` - // This is a copy of a test helper method also found in // builtin/logical/database/rotation_test.go , and should be moved into a shared // helper file in the future.