Skip to content

Commit

Permalink
feat(iammember): add iammember.ChainResolvers
Browse files Browse the repository at this point in the history
For combining a set of standard resolvers. For example for resolving
different types of JWT tokens, or JWT tokens from multiple headers.

Context has been added to the resolver API to enable individual
resolvers to cache results in the context, for example parsed JWT
tokens.
  • Loading branch information
odsod committed May 7, 2021
1 parent 36f28b6 commit 7cf25a1
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 43 deletions.
8 changes: 4 additions & 4 deletions iamexample/members.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@ var _ iammember.Resolver = &iamMemberHeaderResolver{}
type iamMemberHeaderResolver struct{}

// ResolveIAMMembers implements iammember.Resolver.
func (m *iamMemberHeaderResolver) ResolveIAMMembers(ctx context.Context) ([]string, error) {
func (m *iamMemberHeaderResolver) ResolveIAMMembers(ctx context.Context) (context.Context, []string, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, status.Errorf(codes.Unauthenticated, "missing members header: %s", MemberHeader)
return nil, nil, status.Errorf(codes.Unauthenticated, "missing members header: %s", MemberHeader)
}
values := md.Get(MemberHeader)
if len(values) == 0 {
return nil, status.Errorf(codes.Unauthenticated, "missing members header: %s", MemberHeader)
return nil, nil, status.Errorf(codes.Unauthenticated, "missing members header: %s", MemberHeader)
}
return values, nil
return ctx, values, nil
}

// WithOutgoingMembers appends the provided members to the outgoing gRPC context.
Expand Down
44 changes: 44 additions & 0 deletions iammember/chain.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package iammember

import "context"

// ChainResolvers creates a single resolver out of a chain of many resolvers.
//
// The resulting resolved members will be the union of the members resolved by each resolver.
//
// Execution is done in left-to-right order, including passing of context.
// For example ChainResolvers(one, two, three) will execute one before two before three, and three
// will see context changes of one and two.
//
// If any resolver returns an error, that error is immediately returned and no further resolvers are called.
func ChainResolvers(resolvers ...Resolver) Resolver {
return chainResolver{resolvers: resolvers}
}

type chainResolver struct {
resolvers []Resolver
}

func (c chainResolver) ResolveIAMMembers(ctx context.Context) (context.Context, []string, error) {
var result, members []string
var err error
for _, resolver := range c.resolvers {
ctx, members, err = resolver.ResolveIAMMembers(ctx)
if err != nil {
return nil, nil, err
}
for _, member := range members {
var hasMember bool
for _, resultMember := range result {
if member == resultMember {
hasMember = true
break
}
}
if !hasMember {
result = append(result, member)
}
}
}
return ctx, result, nil
}
74 changes: 74 additions & 0 deletions iammember/chain_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package iammember

import (
"context"
"errors"
"testing"

"gotest.tools/v3/assert"
)

func TestChainResolvers(t *testing.T) {
t.Run("no resolvers", func(t *testing.T) {
ctx, members, err := ChainResolvers().ResolveIAMMembers(context.Background())
assert.Equal(t, ctx, context.Background())
assert.Assert(t, members == nil)
assert.NilError(t, err)
})

t.Run("single", func(t *testing.T) {
expected := []string{"foo", "bar"}
ctx, actual, err := ChainResolvers(constantResolver{expected}).ResolveIAMMembers(context.Background())
assert.Equal(t, ctx, context.Background())
assert.DeepEqual(t, expected, actual)
assert.NilError(t, err)
})

t.Run("multi", func(t *testing.T) {
expected := []string{"foo", "bar", "baz"}
ctx, actual, err := ChainResolvers(
constantResolver{members: []string{"foo", "bar"}},
constantResolver{members: []string{"baz"}},
).ResolveIAMMembers(context.Background())
assert.Equal(t, ctx, context.Background())
assert.DeepEqual(t, expected, actual)
assert.NilError(t, err)
})

t.Run("multi duplicates", func(t *testing.T) {
expected := []string{"foo", "bar", "baz"}
ctx, actual, err := ChainResolvers(
constantResolver{members: []string{"foo", "bar"}},
constantResolver{members: []string{"bar", "baz"}},
).ResolveIAMMembers(context.Background())
assert.Equal(t, ctx, context.Background())
assert.DeepEqual(t, expected, actual)
assert.NilError(t, err)
})

t.Run("error", func(t *testing.T) {
ctx, actual, err := ChainResolvers(
constantResolver{members: []string{"foo", "bar"}},
errorResolver{err: errors.New("boom")},
).ResolveIAMMembers(context.Background())
assert.Assert(t, ctx == nil)
assert.Assert(t, actual == nil)
assert.Error(t, err, "boom")
})
}

type constantResolver struct {
members []string
}

func (c constantResolver) ResolveIAMMembers(ctx context.Context) (context.Context, []string, error) {
return ctx, c.members, nil
}

type errorResolver struct {
err error
}

func (e errorResolver) ResolveIAMMembers(ctx context.Context) (context.Context, []string, error) {
return nil, nil, e.err
}
2 changes: 1 addition & 1 deletion iammember/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ import "context"

// Resolver resolves the IAM member identifiers for a caller context.
type Resolver interface {
ResolveIAMMembers(context.Context) ([]string, error)
ResolveIAMMembers(context.Context) (context.Context, []string, error)
}
12 changes: 2 additions & 10 deletions iamspanner/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (s *IAMServer) TestIamPermissions(
ctx context.Context,
request *iam.TestIamPermissionsRequest,
) (*iam.TestIamPermissionsResponse, error) {
members, err := s.resolveMembers(ctx)
ctx, members, err := s.memberResolver.ResolveIAMMembers(ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -163,7 +163,7 @@ func (s *IAMServer) TestPermissionOnResources(
permission string,
resources []string,
) (map[string]bool, error) {
members, err := s.resolveMembers(ctx)
ctx, members, err := s.memberResolver.ResolveIAMMembers(ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -450,14 +450,6 @@ func (s *IAMServer) handleStorageError(ctx context.Context, err error) error {
}
}

func (s *IAMServer) resolveMembers(ctx context.Context) ([]string, error) {
members, err := s.memberResolver.ResolveIAMMembers(ctx)
if err != nil {
return nil, err
}
return members, nil
}

func computeETag(policy *iam.Policy) ([]byte, error) {
data, err := proto.Marshal(policy)
if err != nil {
Expand Down
56 changes: 28 additions & 28 deletions iamspanner/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ func TestServer(t *testing.T) {
server, err := NewServer(
newDatabase(),
roles,
iamMemberResolver(func(ctx context.Context) ([]string, error) {
return []string{user1}, nil
iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) {
return ctx, []string{user1}, nil
}),
ServerConfig{
ErrorHook: func(ctx context.Context, err error) {
Expand All @@ -98,8 +98,8 @@ func TestServer(t *testing.T) {
server, err := NewServer(
newDatabase(),
roles,
iamMemberResolver(func(ctx context.Context) ([]string, error) {
return []string{user1}, nil
iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) {
return ctx, []string{user1}, nil
}),
ServerConfig{
ErrorHook: func(ctx context.Context, err error) {
Expand Down Expand Up @@ -133,8 +133,8 @@ func TestServer(t *testing.T) {
server, err := NewServer(
newDatabase(),
roles,
iamMemberResolver(func(ctx context.Context) ([]string, error) {
return []string{user1}, nil
iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) {
return ctx, []string{user1}, nil
}),
ServerConfig{
ErrorHook: func(ctx context.Context, err error) {
Expand Down Expand Up @@ -163,8 +163,8 @@ func TestServer(t *testing.T) {
server, err := NewServer(
newDatabase(),
roles,
iamMemberResolver(func(ctx context.Context) ([]string, error) {
return []string{user1}, nil
iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) {
return ctx, []string{user1}, nil
}),
ServerConfig{
ErrorHook: func(ctx context.Context, err error) {
Expand Down Expand Up @@ -202,8 +202,8 @@ func TestServer(t *testing.T) {
server, err := NewServer(
newDatabase(),
roles,
iamMemberResolver(func(ctx context.Context) ([]string, error) {
return []string{user1}, nil
iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) {
return ctx, []string{user1}, nil
}),
ServerConfig{
ErrorHook: func(ctx context.Context, err error) {
Expand Down Expand Up @@ -244,8 +244,8 @@ func TestServer(t *testing.T) {
server, err := NewServer(
newDatabase(),
roles,
iamMemberResolver(func(ctx context.Context) ([]string, error) {
return []string{user1}, nil
iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) {
return ctx, []string{user1}, nil
}),
ServerConfig{
ErrorHook: func(ctx context.Context, err error) {
Expand All @@ -271,8 +271,8 @@ func TestServer(t *testing.T) {
server, err := NewServer(
newDatabase(),
roles,
iamMemberResolver(func(ctx context.Context) ([]string, error) {
return []string{user1}, nil
iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) {
return ctx, []string{user1}, nil
}),
ServerConfig{
ErrorHook: func(ctx context.Context, err error) {
Expand Down Expand Up @@ -310,8 +310,8 @@ func TestServer(t *testing.T) {
server, err := NewServer(
newDatabase(),
roles,
iamMemberResolver(func(ctx context.Context) ([]string, error) {
return []string{user1}, nil
iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) {
return ctx, []string{user1}, nil
}),
ServerConfig{
ErrorHook: func(ctx context.Context, err error) {
Expand Down Expand Up @@ -350,8 +350,8 @@ func TestServer(t *testing.T) {
server, err := NewServer(
newDatabase(),
roles,
iamMemberResolver(func(ctx context.Context) ([]string, error) {
return []string{user2}, nil
iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) {
return ctx, []string{user2}, nil
}),
ServerConfig{
ErrorHook: func(ctx context.Context, err error) {
Expand Down Expand Up @@ -389,8 +389,8 @@ func TestServer(t *testing.T) {
server, err := NewServer(
newDatabase(),
roles,
iamMemberResolver(func(ctx context.Context) ([]string, error) {
return []string{user1}, nil
iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) {
return ctx, []string{user1}, nil
}),
ServerConfig{
ErrorHook: func(ctx context.Context, err error) {
Expand Down Expand Up @@ -429,8 +429,8 @@ func TestServer(t *testing.T) {
server, err := NewServer(
newDatabase(),
roles,
iamMemberResolver(func(ctx context.Context) ([]string, error) {
return []string{user1}, nil
iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) {
return ctx, []string{user1}, nil
}),
ServerConfig{
ErrorHook: func(ctx context.Context, err error) {
Expand Down Expand Up @@ -469,8 +469,8 @@ func TestServer(t *testing.T) {
server, err := NewServer(
newDatabase(),
roles,
iamMemberResolver(func(ctx context.Context) ([]string, error) {
return []string{user1}, nil
iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) {
return ctx, []string{user1}, nil
}),
ServerConfig{
ErrorHook: func(ctx context.Context, err error) {
Expand All @@ -492,8 +492,8 @@ func TestServer(t *testing.T) {
server, err := NewServer(
newDatabase(),
roles,
iamMemberResolver(func(ctx context.Context) ([]string, error) {
return []string{user1}, nil
iamMemberResolver(func(ctx context.Context) (context.Context, []string, error) {
return ctx, []string{user1}, nil
}),
ServerConfig{
ErrorHook: func(ctx context.Context, err error) {
Expand Down Expand Up @@ -528,8 +528,8 @@ func TestServer(t *testing.T) {
})
}

type iamMemberResolver func(context.Context) ([]string, error)
type iamMemberResolver func(context.Context) (context.Context, []string, error)

func (r iamMemberResolver) ResolveIAMMembers(ctx context.Context) ([]string, error) {
func (r iamMemberResolver) ResolveIAMMembers(ctx context.Context) (context.Context, []string, error) {
return r(ctx)
}

0 comments on commit 7cf25a1

Please sign in to comment.