diff --git a/iamexample/members.go b/iamexample/members.go index 3fc08267..c59e9de1 100644 --- a/iamexample/members.go +++ b/iamexample/members.go @@ -22,16 +22,16 @@ var _ iammember.Resolver = &iamMemberHeaderResolver{} type iamMemberHeaderResolver struct{} // ResolveIAMMembers implements iammember.Resolver. -func (m *iamMemberHeaderResolver) ResolveIAMMembers(ctx context.Context) ([]string, error) { +func (m *iamMemberHeaderResolver) ResolveIAMMembers(ctx context.Context) (context.Context, []string, error) { md, ok := metadata.FromIncomingContext(ctx) if !ok { - return nil, status.Errorf(codes.Unauthenticated, "missing members header: %s", MemberHeader) + return nil, nil, status.Errorf(codes.Unauthenticated, "missing members header: %s", MemberHeader) } values := md.Get(MemberHeader) if len(values) == 0 { - return nil, status.Errorf(codes.Unauthenticated, "missing members header: %s", MemberHeader) + return nil, nil, status.Errorf(codes.Unauthenticated, "missing members header: %s", MemberHeader) } - return values, nil + return ctx, values, nil } // WithOutgoingMembers appends the provided members to the outgoing gRPC context. diff --git a/iammember/chain.go b/iammember/chain.go new file mode 100644 index 00000000..073eb614 --- /dev/null +++ b/iammember/chain.go @@ -0,0 +1,44 @@ +package iammember + +import "context" + +// ChainResolvers creates a single resolver out of a chain of many resolvers. +// +// The resulting resolved members will be the union of the members resolved by each resolver. +// +// Execution is done in left-to-right order, including passing of context. +// For example ChainResolvers(one, two, three) will execute one before two before three, and three +// will see context changes of one and two. +// +// If any resolver returns an error, that error is immediately returned and no further resolvers are called. +func ChainResolvers(resolvers ...Resolver) Resolver { + return chainResolver{resolvers: resolvers} +} + +type chainResolver struct { + resolvers []Resolver +} + +func (c chainResolver) ResolveIAMMembers(ctx context.Context) (context.Context, []string, error) { + var result, members []string + var err error + for _, resolver := range c.resolvers { + ctx, members, err = resolver.ResolveIAMMembers(ctx) + if err != nil { + return nil, nil, err + } + for _, member := range members { + var hasMember bool + for _, resultMember := range result { + if member == resultMember { + hasMember = true + break + } + } + if !hasMember { + result = append(result, member) + } + } + } + return ctx, result, nil +} diff --git a/iammember/chain_test.go b/iammember/chain_test.go new file mode 100644 index 00000000..1c3a4d7e --- /dev/null +++ b/iammember/chain_test.go @@ -0,0 +1,74 @@ +package iammember + +import ( + "context" + "errors" + "testing" + + "gotest.tools/v3/assert" +) + +func TestChainResolvers(t *testing.T) { + t.Run("no resolvers", func(t *testing.T) { + ctx, members, err := ChainResolvers().ResolveIAMMembers(context.Background()) + assert.Equal(t, ctx, context.Background()) + assert.Assert(t, members == nil) + assert.NilError(t, err) + }) + + t.Run("single", func(t *testing.T) { + expected := []string{"foo", "bar"} + ctx, actual, err := ChainResolvers(constantResolver{expected}).ResolveIAMMembers(context.Background()) + assert.Equal(t, ctx, context.Background()) + assert.DeepEqual(t, expected, actual) + assert.NilError(t, err) + }) + + t.Run("multi", func(t *testing.T) { + expected := []string{"foo", "bar", "baz"} + ctx, actual, err := ChainResolvers( + constantResolver{members: []string{"foo", "bar"}}, + constantResolver{members: []string{"baz"}}, + ).ResolveIAMMembers(context.Background()) + assert.Equal(t, ctx, context.Background()) + assert.DeepEqual(t, expected, actual) + assert.NilError(t, err) + }) + + t.Run("multi duplicates", func(t *testing.T) { + expected := []string{"foo", "bar", "baz"} + ctx, actual, err := ChainResolvers( + constantResolver{members: []string{"foo", "bar"}}, + constantResolver{members: []string{"bar", "baz"}}, + ).ResolveIAMMembers(context.Background()) + assert.Equal(t, ctx, context.Background()) + assert.DeepEqual(t, expected, actual) + assert.NilError(t, err) + }) + + t.Run("error", func(t *testing.T) { + ctx, actual, err := ChainResolvers( + constantResolver{members: []string{"foo", "bar"}}, + errorResolver{err: errors.New("boom")}, + ).ResolveIAMMembers(context.Background()) + assert.Assert(t, ctx == nil) + assert.Assert(t, actual == nil) + assert.Error(t, err, "boom") + }) +} + +type constantResolver struct { + members []string +} + +func (c constantResolver) ResolveIAMMembers(ctx context.Context) (context.Context, []string, error) { + return ctx, c.members, nil +} + +type errorResolver struct { + err error +} + +func (e errorResolver) ResolveIAMMembers(ctx context.Context) (context.Context, []string, error) { + return nil, nil, e.err +} diff --git a/iammember/resolver.go b/iammember/resolver.go index 34231d20..59b6aaf5 100644 --- a/iammember/resolver.go +++ b/iammember/resolver.go @@ -4,5 +4,5 @@ import "context" // Resolver resolves the IAM member identifiers for a caller context. type Resolver interface { - ResolveIAMMembers(context.Context) ([]string, error) + ResolveIAMMembers(context.Context) (context.Context, []string, error) } diff --git a/iamspanner/server.go b/iamspanner/server.go index 57ceecdf..7d44cbb2 100644 --- a/iamspanner/server.go +++ b/iamspanner/server.go @@ -110,7 +110,7 @@ func (s *IAMServer) TestIamPermissions( ctx context.Context, request *iam.TestIamPermissionsRequest, ) (*iam.TestIamPermissionsResponse, error) { - members, err := s.resolveMembers(ctx) + ctx, members, err := s.memberResolver.ResolveIAMMembers(ctx) if err != nil { return nil, err } @@ -163,7 +163,7 @@ func (s *IAMServer) TestPermissionOnResources( permission string, resources []string, ) (map[string]bool, error) { - members, err := s.resolveMembers(ctx) + ctx, members, err := s.memberResolver.ResolveIAMMembers(ctx) if err != nil { return nil, err } @@ -450,14 +450,6 @@ func (s *IAMServer) handleStorageError(ctx context.Context, err error) error { } } -func (s *IAMServer) resolveMembers(ctx context.Context) ([]string, error) { - members, err := s.memberResolver.ResolveIAMMembers(ctx) - if err != nil { - return nil, err - } - return members, nil -} - func computeETag(policy *iam.Policy) ([]byte, error) { data, err := proto.Marshal(policy) if err != nil { diff --git a/iamspanner/server_test.go b/iamspanner/server_test.go index bcac0629..537658c5 100644 --- a/iamspanner/server_test.go +++ b/iamspanner/server_test.go @@ -76,8 +76,8 @@ func TestServer(t *testing.T) { server, err := NewServer( newDatabase(), roles, - iamMemberResolver(func(ctx context.Context) ([]string, error) { - return []string{user1}, nil + iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) { + return ctx, []string{user1}, nil }), ServerConfig{ ErrorHook: func(ctx context.Context, err error) { @@ -98,8 +98,8 @@ func TestServer(t *testing.T) { server, err := NewServer( newDatabase(), roles, - iamMemberResolver(func(ctx context.Context) ([]string, error) { - return []string{user1}, nil + iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) { + return ctx, []string{user1}, nil }), ServerConfig{ ErrorHook: func(ctx context.Context, err error) { @@ -133,8 +133,8 @@ func TestServer(t *testing.T) { server, err := NewServer( newDatabase(), roles, - iamMemberResolver(func(ctx context.Context) ([]string, error) { - return []string{user1}, nil + iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) { + return ctx, []string{user1}, nil }), ServerConfig{ ErrorHook: func(ctx context.Context, err error) { @@ -163,8 +163,8 @@ func TestServer(t *testing.T) { server, err := NewServer( newDatabase(), roles, - iamMemberResolver(func(ctx context.Context) ([]string, error) { - return []string{user1}, nil + iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) { + return ctx, []string{user1}, nil }), ServerConfig{ ErrorHook: func(ctx context.Context, err error) { @@ -202,8 +202,8 @@ func TestServer(t *testing.T) { server, err := NewServer( newDatabase(), roles, - iamMemberResolver(func(ctx context.Context) ([]string, error) { - return []string{user1}, nil + iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) { + return ctx, []string{user1}, nil }), ServerConfig{ ErrorHook: func(ctx context.Context, err error) { @@ -244,8 +244,8 @@ func TestServer(t *testing.T) { server, err := NewServer( newDatabase(), roles, - iamMemberResolver(func(ctx context.Context) ([]string, error) { - return []string{user1}, nil + iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) { + return ctx, []string{user1}, nil }), ServerConfig{ ErrorHook: func(ctx context.Context, err error) { @@ -271,8 +271,8 @@ func TestServer(t *testing.T) { server, err := NewServer( newDatabase(), roles, - iamMemberResolver(func(ctx context.Context) ([]string, error) { - return []string{user1}, nil + iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) { + return ctx, []string{user1}, nil }), ServerConfig{ ErrorHook: func(ctx context.Context, err error) { @@ -310,8 +310,8 @@ func TestServer(t *testing.T) { server, err := NewServer( newDatabase(), roles, - iamMemberResolver(func(ctx context.Context) ([]string, error) { - return []string{user1}, nil + iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) { + return ctx, []string{user1}, nil }), ServerConfig{ ErrorHook: func(ctx context.Context, err error) { @@ -350,8 +350,8 @@ func TestServer(t *testing.T) { server, err := NewServer( newDatabase(), roles, - iamMemberResolver(func(ctx context.Context) ([]string, error) { - return []string{user2}, nil + iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) { + return ctx, []string{user2}, nil }), ServerConfig{ ErrorHook: func(ctx context.Context, err error) { @@ -389,8 +389,8 @@ func TestServer(t *testing.T) { server, err := NewServer( newDatabase(), roles, - iamMemberResolver(func(ctx context.Context) ([]string, error) { - return []string{user1}, nil + iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) { + return ctx, []string{user1}, nil }), ServerConfig{ ErrorHook: func(ctx context.Context, err error) { @@ -429,8 +429,8 @@ func TestServer(t *testing.T) { server, err := NewServer( newDatabase(), roles, - iamMemberResolver(func(ctx context.Context) ([]string, error) { - return []string{user1}, nil + iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) { + return ctx, []string{user1}, nil }), ServerConfig{ ErrorHook: func(ctx context.Context, err error) { @@ -469,8 +469,8 @@ func TestServer(t *testing.T) { server, err := NewServer( newDatabase(), roles, - iamMemberResolver(func(ctx context.Context) ([]string, error) { - return []string{user1}, nil + iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) { + return ctx, []string{user1}, nil }), ServerConfig{ ErrorHook: func(ctx context.Context, err error) { @@ -492,8 +492,8 @@ func TestServer(t *testing.T) { server, err := NewServer( newDatabase(), roles, - iamMemberResolver(func(ctx context.Context) ([]string, error) { - return []string{user1}, nil + iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) { + return ctx, []string{user1}, nil }), ServerConfig{ ErrorHook: func(ctx context.Context, err error) { @@ -528,8 +528,8 @@ func TestServer(t *testing.T) { }) } -type iamMemberResolver func(context.Context) ([]string, error) +type iamMemberResolver func(context.Context) (context.Context, []string, error) -func (r iamMemberResolver) ResolveIAMMembers(ctx context.Context) ([]string, error) { +func (r iamMemberResolver) ResolveIAMMembers(ctx context.Context) (context.Context, []string, error) { return r(ctx) }