From 7cf25a1efcd8edb53aa185e70d5da05099926014 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Oscar=20S=C3=B6derlund?= Date: Fri, 7 May 2021 08:35:00 +0200 Subject: [PATCH] feat(iammember): add iammember.ChainResolvers For combining a set of standard resolvers. For example for resolving different types of JWT tokens, or JWT tokens from multiple headers. Context has been added to the resolver API to enable individual resolvers to cache results in the context, for example parsed JWT tokens. --- iamexample/members.go | 8 ++--- iammember/chain.go | 44 +++++++++++++++++++++++ iammember/chain_test.go | 74 +++++++++++++++++++++++++++++++++++++++ iammember/resolver.go | 2 +- iamspanner/server.go | 12 ++----- iamspanner/server_test.go | 56 ++++++++++++++--------------- 6 files changed, 153 insertions(+), 43 deletions(-) create mode 100644 iammember/chain.go create mode 100644 iammember/chain_test.go diff --git a/iamexample/members.go b/iamexample/members.go index 3fc08267..c59e9de1 100644 --- a/iamexample/members.go +++ b/iamexample/members.go @@ -22,16 +22,16 @@ var _ iammember.Resolver = &iamMemberHeaderResolver{} type iamMemberHeaderResolver struct{} // ResolveIAMMembers implements iammember.Resolver. -func (m *iamMemberHeaderResolver) ResolveIAMMembers(ctx context.Context) ([]string, error) { +func (m *iamMemberHeaderResolver) ResolveIAMMembers(ctx context.Context) (context.Context, []string, error) { md, ok := metadata.FromIncomingContext(ctx) if !ok { - return nil, status.Errorf(codes.Unauthenticated, "missing members header: %s", MemberHeader) + return nil, nil, status.Errorf(codes.Unauthenticated, "missing members header: %s", MemberHeader) } values := md.Get(MemberHeader) if len(values) == 0 { - return nil, status.Errorf(codes.Unauthenticated, "missing members header: %s", MemberHeader) + return nil, nil, status.Errorf(codes.Unauthenticated, "missing members header: %s", MemberHeader) } - return values, nil + return ctx, values, nil } // WithOutgoingMembers appends the provided members to the outgoing gRPC context. diff --git a/iammember/chain.go b/iammember/chain.go new file mode 100644 index 00000000..073eb614 --- /dev/null +++ b/iammember/chain.go @@ -0,0 +1,44 @@ +package iammember + +import "context" + +// ChainResolvers creates a single resolver out of a chain of many resolvers. +// +// The resulting resolved members will be the union of the members resolved by each resolver. +// +// Execution is done in left-to-right order, including passing of context. +// For example ChainResolvers(one, two, three) will execute one before two before three, and three +// will see context changes of one and two. +// +// If any resolver returns an error, that error is immediately returned and no further resolvers are called. +func ChainResolvers(resolvers ...Resolver) Resolver { + return chainResolver{resolvers: resolvers} +} + +type chainResolver struct { + resolvers []Resolver +} + +func (c chainResolver) ResolveIAMMembers(ctx context.Context) (context.Context, []string, error) { + var result, members []string + var err error + for _, resolver := range c.resolvers { + ctx, members, err = resolver.ResolveIAMMembers(ctx) + if err != nil { + return nil, nil, err + } + for _, member := range members { + var hasMember bool + for _, resultMember := range result { + if member == resultMember { + hasMember = true + break + } + } + if !hasMember { + result = append(result, member) + } + } + } + return ctx, result, nil +} diff --git a/iammember/chain_test.go b/iammember/chain_test.go new file mode 100644 index 00000000..1c3a4d7e --- /dev/null +++ b/iammember/chain_test.go @@ -0,0 +1,74 @@ +package iammember + +import ( + "context" + "errors" + "testing" + + "gotest.tools/v3/assert" +) + +func TestChainResolvers(t *testing.T) { + t.Run("no resolvers", func(t *testing.T) { + ctx, members, err := ChainResolvers().ResolveIAMMembers(context.Background()) + assert.Equal(t, ctx, context.Background()) + assert.Assert(t, members == nil) + assert.NilError(t, err) + }) + + t.Run("single", func(t *testing.T) { + expected := []string{"foo", "bar"} + ctx, actual, err := ChainResolvers(constantResolver{expected}).ResolveIAMMembers(context.Background()) + assert.Equal(t, ctx, context.Background()) + assert.DeepEqual(t, expected, actual) + assert.NilError(t, err) + }) + + t.Run("multi", func(t *testing.T) { + expected := []string{"foo", "bar", "baz"} + ctx, actual, err := ChainResolvers( + constantResolver{members: []string{"foo", "bar"}}, + constantResolver{members: []string{"baz"}}, + ).ResolveIAMMembers(context.Background()) + assert.Equal(t, ctx, context.Background()) + assert.DeepEqual(t, expected, actual) + assert.NilError(t, err) + }) + + t.Run("multi duplicates", func(t *testing.T) { + expected := []string{"foo", "bar", "baz"} + ctx, actual, err := ChainResolvers( + constantResolver{members: []string{"foo", "bar"}}, + constantResolver{members: []string{"bar", "baz"}}, + ).ResolveIAMMembers(context.Background()) + assert.Equal(t, ctx, context.Background()) + assert.DeepEqual(t, expected, actual) + assert.NilError(t, err) + }) + + t.Run("error", func(t *testing.T) { + ctx, actual, err := ChainResolvers( + constantResolver{members: []string{"foo", "bar"}}, + errorResolver{err: errors.New("boom")}, + ).ResolveIAMMembers(context.Background()) + assert.Assert(t, ctx == nil) + assert.Assert(t, actual == nil) + assert.Error(t, err, "boom") + }) +} + +type constantResolver struct { + members []string +} + +func (c constantResolver) ResolveIAMMembers(ctx context.Context) (context.Context, []string, error) { + return ctx, c.members, nil +} + +type errorResolver struct { + err error +} + +func (e errorResolver) ResolveIAMMembers(ctx context.Context) (context.Context, []string, error) { + return nil, nil, e.err +} diff --git a/iammember/resolver.go b/iammember/resolver.go index 34231d20..59b6aaf5 100644 --- a/iammember/resolver.go +++ b/iammember/resolver.go @@ -4,5 +4,5 @@ import "context" // Resolver resolves the IAM member identifiers for a caller context. type Resolver interface { - ResolveIAMMembers(context.Context) ([]string, error) + ResolveIAMMembers(context.Context) (context.Context, []string, error) } diff --git a/iamspanner/server.go b/iamspanner/server.go index 57ceecdf..7d44cbb2 100644 --- a/iamspanner/server.go +++ b/iamspanner/server.go @@ -110,7 +110,7 @@ func (s *IAMServer) TestIamPermissions( ctx context.Context, request *iam.TestIamPermissionsRequest, ) (*iam.TestIamPermissionsResponse, error) { - members, err := s.resolveMembers(ctx) + ctx, members, err := s.memberResolver.ResolveIAMMembers(ctx) if err != nil { return nil, err } @@ -163,7 +163,7 @@ func (s *IAMServer) TestPermissionOnResources( permission string, resources []string, ) (map[string]bool, error) { - members, err := s.resolveMembers(ctx) + ctx, members, err := s.memberResolver.ResolveIAMMembers(ctx) if err != nil { return nil, err } @@ -450,14 +450,6 @@ func (s *IAMServer) handleStorageError(ctx context.Context, err error) error { } } -func (s *IAMServer) resolveMembers(ctx context.Context) ([]string, error) { - members, err := s.memberResolver.ResolveIAMMembers(ctx) - if err != nil { - return nil, err - } - return members, nil -} - func computeETag(policy *iam.Policy) ([]byte, error) { data, err := proto.Marshal(policy) if err != nil { diff --git a/iamspanner/server_test.go b/iamspanner/server_test.go index bcac0629..537658c5 100644 --- a/iamspanner/server_test.go +++ b/iamspanner/server_test.go @@ -76,8 +76,8 @@ func TestServer(t *testing.T) { server, err := NewServer( newDatabase(), roles, - iamMemberResolver(func(ctx context.Context) ([]string, error) { - return []string{user1}, nil + iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) { + return ctx, []string{user1}, nil }), ServerConfig{ ErrorHook: func(ctx context.Context, err error) { @@ -98,8 +98,8 @@ func TestServer(t *testing.T) { server, err := NewServer( newDatabase(), roles, - iamMemberResolver(func(ctx context.Context) ([]string, error) { - return []string{user1}, nil + iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) { + return ctx, []string{user1}, nil }), ServerConfig{ ErrorHook: func(ctx context.Context, err error) { @@ -133,8 +133,8 @@ func TestServer(t *testing.T) { server, err := NewServer( newDatabase(), roles, - iamMemberResolver(func(ctx context.Context) ([]string, error) { - return []string{user1}, nil + iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) { + return ctx, []string{user1}, nil }), ServerConfig{ ErrorHook: func(ctx context.Context, err error) { @@ -163,8 +163,8 @@ func TestServer(t *testing.T) { server, err := NewServer( newDatabase(), roles, - iamMemberResolver(func(ctx context.Context) ([]string, error) { - return []string{user1}, nil + iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) { + return ctx, []string{user1}, nil }), ServerConfig{ ErrorHook: func(ctx context.Context, err error) { @@ -202,8 +202,8 @@ func TestServer(t *testing.T) { server, err := NewServer( newDatabase(), roles, - iamMemberResolver(func(ctx context.Context) ([]string, error) { - return []string{user1}, nil + iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) { + return ctx, []string{user1}, nil }), ServerConfig{ ErrorHook: func(ctx context.Context, err error) { @@ -244,8 +244,8 @@ func TestServer(t *testing.T) { server, err := NewServer( newDatabase(), roles, - iamMemberResolver(func(ctx context.Context) ([]string, error) { - return []string{user1}, nil + iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) { + return ctx, []string{user1}, nil }), ServerConfig{ ErrorHook: func(ctx context.Context, err error) { @@ -271,8 +271,8 @@ func TestServer(t *testing.T) { server, err := NewServer( newDatabase(), roles, - iamMemberResolver(func(ctx context.Context) ([]string, error) { - return []string{user1}, nil + iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) { + return ctx, []string{user1}, nil }), ServerConfig{ ErrorHook: func(ctx context.Context, err error) { @@ -310,8 +310,8 @@ func TestServer(t *testing.T) { server, err := NewServer( newDatabase(), roles, - iamMemberResolver(func(ctx context.Context) ([]string, error) { - return []string{user1}, nil + iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) { + return ctx, []string{user1}, nil }), ServerConfig{ ErrorHook: func(ctx context.Context, err error) { @@ -350,8 +350,8 @@ func TestServer(t *testing.T) { server, err := NewServer( newDatabase(), roles, - iamMemberResolver(func(ctx context.Context) ([]string, error) { - return []string{user2}, nil + iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) { + return ctx, []string{user2}, nil }), ServerConfig{ ErrorHook: func(ctx context.Context, err error) { @@ -389,8 +389,8 @@ func TestServer(t *testing.T) { server, err := NewServer( newDatabase(), roles, - iamMemberResolver(func(ctx context.Context) ([]string, error) { - return []string{user1}, nil + iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) { + return ctx, []string{user1}, nil }), ServerConfig{ ErrorHook: func(ctx context.Context, err error) { @@ -429,8 +429,8 @@ func TestServer(t *testing.T) { server, err := NewServer( newDatabase(), roles, - iamMemberResolver(func(ctx context.Context) ([]string, error) { - return []string{user1}, nil + iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) { + return ctx, []string{user1}, nil }), ServerConfig{ ErrorHook: func(ctx context.Context, err error) { @@ -469,8 +469,8 @@ func TestServer(t *testing.T) { server, err := NewServer( newDatabase(), roles, - iamMemberResolver(func(ctx context.Context) ([]string, error) { - return []string{user1}, nil + iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) { + return ctx, []string{user1}, nil }), ServerConfig{ ErrorHook: func(ctx context.Context, err error) { @@ -492,8 +492,8 @@ func TestServer(t *testing.T) { server, err := NewServer( newDatabase(), roles, - iamMemberResolver(func(ctx context.Context) ([]string, error) { - return []string{user1}, nil + iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) { + return ctx, []string{user1}, nil }), ServerConfig{ ErrorHook: func(ctx context.Context, err error) { @@ -528,8 +528,8 @@ func TestServer(t *testing.T) { }) } -type iamMemberResolver func(context.Context) ([]string, error) +type iamMemberResolver func(context.Context) (context.Context, []string, error) -func (r iamMemberResolver) ResolveIAMMembers(ctx context.Context) ([]string, error) { +func (r iamMemberResolver) ResolveIAMMembers(ctx context.Context) (context.Context, []string, error) { return r(ctx) }