Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cyrilgdn#321 replaces postgresql_grant all the time. #476

Merged
merged 2 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 54 additions & 34 deletions postgresql/resource_postgresql_grant.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ var objectTypes = map[string]string{
"schema": "n",
}

type ResourceSchemeGetter func(string) interface{}

func resourcePostgreSQLGrant() *schema.Resource {
return &schema.Resource{
Create: PGResourceFunc(resourcePostgreSQLGrantCreate),
// Since all of this resource's arguments force a recreation
// there's no need for an Update function
// Update:
Update: PGResourceFunc(resourcePostgreSQLGrantUpdate),
Read: PGResourceFunc(resourcePostgreSQLGrantRead),
Delete: PGResourceFunc(resourcePostgreSQLGrantDelete),

Expand Down Expand Up @@ -88,7 +88,6 @@ func resourcePostgreSQLGrant() *schema.Resource {
"privileges": {
Type: schema.TypeSet,
Required: true,
ForceNew: true,
Elem: &schema.Schema{Type: schema.TypeString},
Set: schema.HashString,
Description: "The list of privileges to grant",
Expand Down Expand Up @@ -129,6 +128,14 @@ func resourcePostgreSQLGrantRead(db *DBConnection, d *schema.ResourceData) error
}

func resourcePostgreSQLGrantCreate(db *DBConnection, d *schema.ResourceData) error {
return resourcePostgreSQLGrantCreateOrUpdate(db, d, false)
}

func resourcePostgreSQLGrantUpdate(db *DBConnection, d *schema.ResourceData) error {
return resourcePostgreSQLGrantCreateOrUpdate(db, d, true)
}

func resourcePostgreSQLGrantCreateOrUpdate(db *DBConnection, d *schema.ResourceData, usePrevious bool) error {
if err := validateFeatureSupport(db, d); err != nil {
return fmt.Errorf("feature is not supported: %v", err)
}
Expand Down Expand Up @@ -187,7 +194,7 @@ func resourcePostgreSQLGrantCreate(db *DBConnection, d *schema.ResourceData) err
// Revoke all privileges before granting otherwise reducing privileges will not work.
// We just have to revoke them in the same transaction so the role will not lost its
// privileges between the revoke and grant statements.
if err := revokeRolePrivileges(txn, d); err != nil {
if err := revokeRolePrivileges(txn, d, usePrevious); err != nil {
return err
}
if err := grantRolePrivileges(txn, d); err != nil {
Expand Down Expand Up @@ -243,7 +250,7 @@ func resourcePostgreSQLGrantDelete(db *DBConnection, d *schema.ResourceData) err
}

if err := withRolesGranted(txn, owners, func() error {
return revokeRolePrivileges(txn, d)
return revokeRolePrivileges(txn, d, false)
}); err != nil {
return err
}
Expand Down Expand Up @@ -589,40 +596,40 @@ func createGrantQuery(d *schema.ResourceData, privileges []string) string {
return query
}

func createRevokeQuery(d *schema.ResourceData) string {
func createRevokeQuery(getter ResourceSchemeGetter) string {
var query string

switch strings.ToUpper(d.Get("object_type").(string)) {
switch strings.ToUpper(getter("object_type").(string)) {
case "DATABASE":
query = fmt.Sprintf(
"REVOKE ALL PRIVILEGES ON DATABASE %s FROM %s",
pq.QuoteIdentifier(d.Get("database").(string)),
pq.QuoteIdentifier(d.Get("role").(string)),
pq.QuoteIdentifier(getter("database").(string)),
pq.QuoteIdentifier(getter("role").(string)),
)
case "SCHEMA":
query = fmt.Sprintf(
"REVOKE ALL PRIVILEGES ON SCHEMA %s FROM %s",
pq.QuoteIdentifier(d.Get("schema").(string)),
pq.QuoteIdentifier(d.Get("role").(string)),
pq.QuoteIdentifier(getter("schema").(string)),
pq.QuoteIdentifier(getter("role").(string)),
)
case "FOREIGN_DATA_WRAPPER":
fdwName := d.Get("objects").(*schema.Set).List()[0]
fdwName := getter("objects").(*schema.Set).List()[0]
query = fmt.Sprintf(
"REVOKE ALL PRIVILEGES ON FOREIGN DATA WRAPPER %s FROM %s",
pq.QuoteIdentifier(fdwName.(string)),
pq.QuoteIdentifier(d.Get("role").(string)),
pq.QuoteIdentifier(getter("role").(string)),
)
case "FOREIGN_SERVER":
srvName := d.Get("objects").(*schema.Set).List()[0]
srvName := getter("objects").(*schema.Set).List()[0]
query = fmt.Sprintf(
"REVOKE ALL PRIVILEGES ON FOREIGN SERVER %s FROM %s",
pq.QuoteIdentifier(srvName.(string)),
pq.QuoteIdentifier(d.Get("role").(string)),
pq.QuoteIdentifier(getter("role").(string)),
)
case "COLUMN":
objects := d.Get("objects").(*schema.Set)
columns := d.Get("columns").(*schema.Set)
privileges := d.Get("privileges").(*schema.Set)
objects := getter("objects").(*schema.Set)
columns := getter("columns").(*schema.Set)
privileges := getter("privileges").(*schema.Set)
if privileges.Len() == 0 || columns.Len() == 0 {
// No privileges to revoke, so don't revoke anything
query = "SELECT NULL"
Expand All @@ -631,38 +638,38 @@ func createRevokeQuery(d *schema.ResourceData) string {
"REVOKE %s (%s) ON TABLE %s FROM %s",
setToPgIdentSimpleList(privileges),
setToPgIdentListWithoutSchema(columns),
setToPgIdentList(d.Get("schema").(string), objects),
pq.QuoteIdentifier(d.Get("role").(string)),
setToPgIdentList(getter("schema").(string), objects),
pq.QuoteIdentifier(getter("role").(string)),
)
}
case "TABLE", "SEQUENCE", "FUNCTION", "PROCEDURE", "ROUTINE":
objects := d.Get("objects").(*schema.Set)
privileges := d.Get("privileges").(*schema.Set)
objects := getter("objects").(*schema.Set)
privileges := getter("privileges").(*schema.Set)
if objects.Len() > 0 {
if privileges.Len() > 0 {
// Revoking specific privileges instead of all privileges
// to avoid messing with column level grants
query = fmt.Sprintf(
"REVOKE %s ON %s %s FROM %s",
setToPgIdentSimpleList(privileges),
strings.ToUpper(d.Get("object_type").(string)),
setToPgIdentList(d.Get("schema").(string), objects),
pq.QuoteIdentifier(d.Get("role").(string)),
strings.ToUpper(getter("object_type").(string)),
setToPgIdentList(getter("schema").(string), objects),
pq.QuoteIdentifier(getter("role").(string)),
)
} else {
query = fmt.Sprintf(
"REVOKE ALL PRIVILEGES ON %s %s FROM %s",
strings.ToUpper(d.Get("object_type").(string)),
setToPgIdentList(d.Get("schema").(string), objects),
pq.QuoteIdentifier(d.Get("role").(string)),
strings.ToUpper(getter("object_type").(string)),
setToPgIdentList(getter("schema").(string), objects),
pq.QuoteIdentifier(getter("role").(string)),
)
}
} else {
query = fmt.Sprintf(
"REVOKE ALL PRIVILEGES ON ALL %sS IN SCHEMA %s FROM %s",
strings.ToUpper(d.Get("object_type").(string)),
pq.QuoteIdentifier(d.Get("schema").(string)),
pq.QuoteIdentifier(d.Get("role").(string)),
strings.ToUpper(getter("object_type").(string)),
pq.QuoteIdentifier(getter("schema").(string)),
pq.QuoteIdentifier(getter("role").(string)),
)
}
}
Expand All @@ -687,8 +694,21 @@ func grantRolePrivileges(txn *sql.Tx, d *schema.ResourceData) error {
return err
}

func revokeRolePrivileges(txn *sql.Tx, d *schema.ResourceData) error {
query := createRevokeQuery(d)
func revokeRolePrivileges(txn *sql.Tx, d *schema.ResourceData, usePrevious bool) error {
getter := d.Get

if usePrevious {
getter = func(name string) interface{} {
if d.HasChange(name) {
old, _ := d.GetChange(name)
return old
}

return d.Get(name)
}
}

query := createRevokeQuery(getter)
if len(query) == 0 {
// Query is empty, don't run anything
return nil
Expand Down
2 changes: 1 addition & 1 deletion postgresql/resource_postgresql_grant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ func TestCreateRevokeQuery(t *testing.T) {
}

for _, c := range cases {
out := createRevokeQuery(c.resource)
out := createRevokeQuery(c.resource.Get)
if out != c.expected {
t.Fatalf("Error matching output and expected: %#v vs %#v", out, c.expected)
}
Expand Down
Loading