From 1caee37b888b395c225cd2ec5d7b31f17e39fa88 Mon Sep 17 00:00:00 2001 From: Bertrand Paquet Date: Tue, 5 Dec 2023 15:45:14 +0100 Subject: [PATCH] Fix #321: Always use a single transaction when changing grant (#9) Fix https://github.com/cyrilgdn/terraform-provider-postgresql/issues/321 --- postgresql/resource_postgresql_grant.go | 125 +++++++++++-------- postgresql/resource_postgresql_grant_test.go | 5 +- 2 files changed, 76 insertions(+), 54 deletions(-) diff --git a/postgresql/resource_postgresql_grant.go b/postgresql/resource_postgresql_grant.go index f4a5a6cb..e037d032 100644 --- a/postgresql/resource_postgresql_grant.go +++ b/postgresql/resource_postgresql_grant.go @@ -37,9 +37,7 @@ var objectTypes = map[string]string{ func resourcePostgreSQLGrant() *schema.Resource { return &schema.Resource{ Create: PGResourceFunc(resourcePostgreSQLGrantCreate), - // Since all of this resource's arguments force a recreation - // there's no need for an Update function - // Update: + Update: PGResourceFunc(resourcePostgreSQLGrantUpdate), Read: PGResourceFunc(resourcePostgreSQLGrantRead), Delete: PGResourceFunc(resourcePostgreSQLGrantDelete), @@ -57,46 +55,46 @@ func resourcePostgreSQLGrant() *schema.Resource { Description: "The database to grant privileges on for this role", }, "schema": { - Type: schema.TypeString, - Optional: true, - ForceNew: true, + Type: schema.TypeString, + Optional: true, + // ForceNew: true, Description: "The database schema to grant privileges on for this role", }, "object_type": { - Type: schema.TypeString, - Required: true, - ForceNew: true, + Type: schema.TypeString, + Required: true, + // ForceNew: true, ValidateFunc: validation.StringInSlice(allowedObjectTypes, false), Description: "The PostgreSQL object type to grant the privileges on (one of: " + strings.Join(allowedObjectTypes, ", ") + ")", }, "objects": { - Type: schema.TypeSet, - Optional: true, - ForceNew: true, + Type: schema.TypeSet, + Optional: true, + // ForceNew: true, Elem: &schema.Schema{Type: schema.TypeString}, Set: schema.HashString, Description: "The specific objects to grant privileges on for this role (empty means all objects of the requested type)", }, "columns": { - Type: schema.TypeSet, - Optional: true, - ForceNew: true, + Type: schema.TypeSet, + Optional: true, + // ForceNew: true, Elem: &schema.Schema{Type: schema.TypeString}, Set: schema.HashString, Description: "The specific columns to grant privileges on for this role", }, "privileges": { - Type: schema.TypeSet, - Required: true, - ForceNew: true, + Type: schema.TypeSet, + Required: true, + // ForceNew: true, Elem: &schema.Schema{Type: schema.TypeString}, Set: schema.HashString, Description: "The list of privileges to grant", }, "with_grant_option": { - Type: schema.TypeBool, - Optional: true, - ForceNew: true, + Type: schema.TypeBool, + Optional: true, + // ForceNew: true, Default: false, Description: "Permit the grant recipient to grant it to others", }, @@ -129,6 +127,10 @@ func resourcePostgreSQLGrantRead(db *DBConnection, d *schema.ResourceData) error } func resourcePostgreSQLGrantCreate(db *DBConnection, d *schema.ResourceData) error { + return resourcePostgreSQLGrantCreateOrUpdate(db, d, false) +} + +func resourcePostgreSQLGrantCreateOrUpdate(db *DBConnection, d *schema.ResourceData, usePreviousForRevoke bool) error { if err := validateFeatureSupport(db, d); err != nil { return fmt.Errorf("feature is not supported: %v", err) } @@ -187,7 +189,7 @@ func resourcePostgreSQLGrantCreate(db *DBConnection, d *schema.ResourceData) err // Revoke all privileges before granting otherwise reducing privileges will not work. // We just have to revoke them in the same transaction so the role will not lost its // privileges between the revoke and grant statements. - if err := revokeRolePrivileges(txn, d); err != nil { + if err := revokeRolePrivileges(txn, d, usePreviousForRevoke); err != nil { return err } if err := grantRolePrivileges(txn, d); err != nil { @@ -213,6 +215,10 @@ func resourcePostgreSQLGrantCreate(db *DBConnection, d *schema.ResourceData) err return readRolePrivileges(txn, d) } +func resourcePostgreSQLGrantUpdate(db *DBConnection, d *schema.ResourceData) error { + return resourcePostgreSQLGrantCreateOrUpdate(db, d, true) +} + func resourcePostgreSQLGrantDelete(db *DBConnection, d *schema.ResourceData) error { if err := validateFeatureSupport(db, d); err != nil { return fmt.Errorf("feature is not supported: %v", err) @@ -243,7 +249,7 @@ func resourcePostgreSQLGrantDelete(db *DBConnection, d *schema.ResourceData) err } if err := withRolesGranted(txn, owners, func() error { - return revokeRolePrivileges(txn, d) + return revokeRolePrivileges(txn, d, false) }); err != nil { return err } @@ -589,40 +595,42 @@ func createGrantQuery(d *schema.ResourceData, privileges []string) string { return query } -func createRevokeQuery(d *schema.ResourceData) string { +type ResourceSchemGetter func(string) interface{} + +func createRevokeQuery(getter ResourceSchemGetter) string { var query string - switch strings.ToUpper(d.Get("object_type").(string)) { + switch strings.ToUpper(getter("object_type").(string)) { case "DATABASE": query = fmt.Sprintf( "REVOKE ALL PRIVILEGES ON DATABASE %s FROM %s", - pq.QuoteIdentifier(d.Get("database").(string)), - pq.QuoteIdentifier(d.Get("role").(string)), + pq.QuoteIdentifier(getter("database").(string)), + pq.QuoteIdentifier(getter("role").(string)), ) case "SCHEMA": query = fmt.Sprintf( "REVOKE ALL PRIVILEGES ON SCHEMA %s FROM %s", - pq.QuoteIdentifier(d.Get("schema").(string)), - pq.QuoteIdentifier(d.Get("role").(string)), + pq.QuoteIdentifier(getter("schema").(string)), + pq.QuoteIdentifier(getter("role").(string)), ) case "FOREIGN_DATA_WRAPPER": - fdwName := d.Get("objects").(*schema.Set).List()[0] + fdwName := getter("objects").(*schema.Set).List()[0] query = fmt.Sprintf( "REVOKE ALL PRIVILEGES ON FOREIGN DATA WRAPPER %s FROM %s", pq.QuoteIdentifier(fdwName.(string)), - pq.QuoteIdentifier(d.Get("role").(string)), + pq.QuoteIdentifier(getter("role").(string)), ) case "FOREIGN_SERVER": - srvName := d.Get("objects").(*schema.Set).List()[0] + srvName := getter("objects").(*schema.Set).List()[0] query = fmt.Sprintf( "REVOKE ALL PRIVILEGES ON FOREIGN SERVER %s FROM %s", pq.QuoteIdentifier(srvName.(string)), - pq.QuoteIdentifier(d.Get("role").(string)), + pq.QuoteIdentifier(getter("role").(string)), ) case "COLUMN": - objects := d.Get("objects").(*schema.Set) - columns := d.Get("columns").(*schema.Set) - privileges := d.Get("privileges").(*schema.Set) + objects := getter("objects").(*schema.Set) + columns := getter("columns").(*schema.Set) + privileges := getter("privileges").(*schema.Set) if privileges.Len() == 0 || columns.Len() == 0 { // No privileges to revoke, so don't revoke anything query = "SELECT NULL" @@ -631,13 +639,13 @@ func createRevokeQuery(d *schema.ResourceData) string { "REVOKE %s (%s) ON TABLE %s FROM %s", setToPgIdentSimpleList(privileges), setToPgIdentListWithoutSchema(columns), - setToPgIdentList(d.Get("schema").(string), objects), - pq.QuoteIdentifier(d.Get("role").(string)), + setToPgIdentList(getter("schema").(string), objects), + pq.QuoteIdentifier(getter("role").(string)), ) } case "TABLE", "SEQUENCE", "FUNCTION", "PROCEDURE", "ROUTINE": - objects := d.Get("objects").(*schema.Set) - privileges := d.Get("privileges").(*schema.Set) + objects := getter("objects").(*schema.Set) + privileges := getter("privileges").(*schema.Set) if objects.Len() > 0 { if privileges.Len() > 0 { // Revoking specific privileges instead of all privileges @@ -645,24 +653,24 @@ func createRevokeQuery(d *schema.ResourceData) string { query = fmt.Sprintf( "REVOKE %s ON %s %s FROM %s", setToPgIdentSimpleList(privileges), - strings.ToUpper(d.Get("object_type").(string)), - setToPgIdentList(d.Get("schema").(string), objects), - pq.QuoteIdentifier(d.Get("role").(string)), + strings.ToUpper(getter("object_type").(string)), + setToPgIdentList(getter("schema").(string), objects), + pq.QuoteIdentifier(getter("role").(string)), ) } else { query = fmt.Sprintf( "REVOKE ALL PRIVILEGES ON %s %s FROM %s", - strings.ToUpper(d.Get("object_type").(string)), - setToPgIdentList(d.Get("schema").(string), objects), - pq.QuoteIdentifier(d.Get("role").(string)), + strings.ToUpper(getter("object_type").(string)), + setToPgIdentList(getter("schema").(string), objects), + pq.QuoteIdentifier(getter("role").(string)), ) } } else { query = fmt.Sprintf( "REVOKE ALL PRIVILEGES ON ALL %sS IN SCHEMA %s FROM %s", - strings.ToUpper(d.Get("object_type").(string)), - pq.QuoteIdentifier(d.Get("schema").(string)), - pq.QuoteIdentifier(d.Get("role").(string)), + strings.ToUpper(getter("object_type").(string)), + pq.QuoteIdentifier(getter("schema").(string)), + pq.QuoteIdentifier(getter("role").(string)), ) } } @@ -675,24 +683,35 @@ func grantRolePrivileges(txn *sql.Tx, d *schema.ResourceData) error { for _, priv := range d.Get("privileges").(*schema.Set).List() { privileges = append(privileges, priv.(string)) } - if len(privileges) == 0 { log.Printf("[DEBUG] no privileges to grant for role %s in database: %s,", d.Get("role").(string), d.Get("database")) return nil } - query := createGrantQuery(d, privileges) + log.Printf("[INFO] executing %s", query) _, err := txn.Exec(query) return err } -func revokeRolePrivileges(txn *sql.Tx, d *schema.ResourceData) error { - query := createRevokeQuery(d) +func revokeRolePrivileges(txn *sql.Tx, d *schema.ResourceData, usePrevious bool) error { + var getter ResourceSchemGetter + if usePrevious { + getter = func(name string) interface{} { + old, _ := d.GetChange(name) + return old + } + } else { + getter = func(name string) interface{} { + return d.Get(name) + } + } + query := createRevokeQuery(getter) if len(query) == 0 { // Query is empty, don't run anything return nil } + log.Printf("[INFO] executing %s", query) if _, err := txn.Exec(query); err != nil { return fmt.Errorf("could not execute revoke query: %w", err) } diff --git a/postgresql/resource_postgresql_grant_test.go b/postgresql/resource_postgresql_grant_test.go index a81c95bc..f4de40dd 100644 --- a/postgresql/resource_postgresql_grant_test.go +++ b/postgresql/resource_postgresql_grant_test.go @@ -293,7 +293,10 @@ func TestCreateRevokeQuery(t *testing.T) { } for _, c := range cases { - out := createRevokeQuery(c.resource) + getter := func(name string) interface{} { + return c.resource.Get(name) + } + out := createRevokeQuery(getter) if out != c.expected { t.Fatalf("Error matching output and expected: %#v vs %#v", out, c.expected) }