diff --git a/iamspanner/server.go b/iamspanner/server.go index 7a7158cd..4dee13ea 100644 --- a/iamspanner/server.go +++ b/iamspanner/server.go @@ -144,6 +144,41 @@ func (s *IAMServer) TestIamPermissions( return response, nil } +// ReadWritePolicy enables the caller to modify a policy in a read-write transaction. +func (s *IAMServer) ReadWritePolicy( + ctx context.Context, + resource string, + fn func(*iam.Policy) (*iam.Policy, error), +) (*iam.Policy, error) { + var result *iam.Policy + if _, err := s.client.ReadWriteTransaction( + ctx, + func(ctx context.Context, tx *spanner.ReadWriteTransaction) error { + policy, err := s.QueryIamPolicyInTransaction(ctx, tx, resource) + if err != nil { + return err + } + policy, err = fn(policy) + if err != nil { + return err + } + result = policy + mutations := []*spanner.Mutation{deleteIAMPolicyMutation(resource)} + mutations = append(mutations, insertIAMPolicyMutations(resource, policy)...) + return tx.BufferWrite(mutations) + }, + ); err != nil { + return nil, s.handleStorageError(ctx, err) + } + result.Etag = nil + etag, err := computeETag(result) + if err != nil { + return nil, err + } + result.Etag = etag + return result, nil +} + // TestPermissionOnResource tests if the caller has the specified permission on the specified resource. func (s *IAMServer) TestPermissionOnResource( ctx context.Context, diff --git a/iamspanner/server_test.go b/iamspanner/server_test.go index 75bf4c73..94f47758 100644 --- a/iamspanner/server_test.go +++ b/iamspanner/server_test.go @@ -7,6 +7,7 @@ import ( "cloud.google.com/go/spanner" "go.einride.tech/iam/iammember" + "go.einride.tech/iam/iampolicy" "go.einride.tech/iam/iamregistry" "go.einride.tech/iam/iamresource" iamv1 "go.einride.tech/iam/proto/gen/einride/iam/v1" @@ -561,4 +562,51 @@ func TestServer(t *testing.T) { } assert.DeepEqual(t, expected, actual, protocmp.Transform()) }) + + t.Run("read+write", func(t *testing.T) { + t.Parallel() + server, err := NewIAMServer( + newDatabase(), + roles, + iammember.FromContextResolver(), + ServerConfig{ + ErrorHook: func(ctx context.Context, err error) { + t.Log(err) + }, + }) + assert.NilError(t, err) + expected := &iam.Policy{ + Bindings: []*iam.Binding{ + { + Role: "roles/test.admin", + Members: []string{"user:user1"}, + }, + }, + } + actual, err := server.ReadWritePolicy(ctx, "resources/test1", func(policy *iam.Policy) (*iam.Policy, error) { + iampolicy.AddBinding(policy, "roles/test.admin", "user:user1") + return policy, nil + }) + assert.NilError(t, err) + assert.DeepEqual(t, expected.Bindings, actual.Bindings, protocmp.Transform()) + expected2 := &iam.Policy{ + Bindings: []*iam.Binding{ + { + Role: "roles/test.admin", + Members: []string{"user:user1"}, + }, + { + Role: "roles/test.user", + Members: []string{"user:user2"}, + }, + }, + } + actual2, err := server.ReadWritePolicy(ctx, "resources/test1", func(policy *iam.Policy) (*iam.Policy, error) { + assert.DeepEqual(t, actual, policy, protocmp.Transform()) + iampolicy.AddBinding(policy, "roles/test.user", "user:user2") + return policy, nil + }) + assert.NilError(t, err) + assert.DeepEqual(t, expected2.Bindings, actual2.Bindings, protocmp.Transform()) + }) }