Skip to content

Commit

Permalink
feat: add iamauthz with auth-checking middleware
Browse files Browse the repository at this point in the history
As a belt-and-suspenders approach to ensure that authorization has been
performed on all RPC methods in a service, this middleware requires
authorization implementations to call `iamauthz.Authorize(ctx)` to mark
the request as authorized.

Requests that return without being marked as authorized are stopped by
the middleware, instead returning INTERNAL (signaling that the
implementer of the RPC method has not correctly implemented
authorization).
  • Loading branch information
odsod committed May 12, 2021
1 parent b71b4b8 commit ea6a6ed
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 2 deletions.
9 changes: 8 additions & 1 deletion cmd/iamctl/internal/examplecmd/exampleservercmd/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net"

"cloud.google.com/go/spanner"
"go.einride.tech/iam/iamauthz"
"go.einride.tech/iam/iamexample"
"go.einride.tech/iam/iamexample/iamexampledata"
"go.einride.tech/iam/iammember"
Expand Down Expand Up @@ -81,7 +82,13 @@ func (g googleUserInfoMemberResolver) ResolveIAMMembersFromGoogleUserInfo(
}

func runServer(ctx context.Context, server iamexamplev1.FreightServiceServer, address string) error {
grpcServer := grpc.NewServer(grpc.UnaryInterceptor(logUnary))
grpcServer := grpc.NewServer(
grpc.ChainUnaryInterceptor(
logUnary,
iamauthz.RequireUnaryAuthorization,
),
grpc.StreamInterceptor(iamauthz.RequireStreamAuthorization),
)
iam.RegisterIAMPolicyServer(grpcServer, server)
if adminServer, ok := server.(admin.IAMServer); ok {
admin.RegisterIAMServer(grpcServer, adminServer)
Expand Down
2 changes: 2 additions & 0 deletions iamauthz/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
// Package iamauthz provides primitives for performing IAM request authorization.
package iamauthz
79 changes: 79 additions & 0 deletions iamauthz/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package iamauthz

import (
"context"
"sync"

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

// Authorize marks the current request as processed by an authorization check.
// WithAuthorization must have been called on the context for the call to be effective.
//
// Authorize should be called at the start of an authorization check, to ensure that any errors resulting from the
// authorization check itself are forwarded to the caller.
func Authorize(ctx context.Context) {
if value, ok := ctx.Value(contextKey{}).(*contextValue); ok {
value.mu.Lock()
value.authorized = true
value.mu.Unlock()
}
}

// RequireUnaryAuthorization is a grpc.UnaryServerInterceptor that requires authorization
// to be performed on all incoming requests.
//
// To mark the request as processed by authorization checks, the method implementing authorization should call
// Authorize on the request context as soon as authorization starts.
func RequireUnaryAuthorization(
ctx context.Context,
req interface{},
_ *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
ctx = WithAuthorization(ctx)
resp, err := handler(ctx, req)
if code := status.Code(err); code == codes.Unauthenticated || code == codes.PermissionDenied {
return nil, err
}
value := ctx.Value(contextKey{}).(*contextValue)
value.mu.Lock()
authorized := value.authorized
value.mu.Unlock()
if !authorized {
return nil, status.Error(codes.Internal, "server did not perform authorization")
}
return resp, err
}

var _ grpc.UnaryServerInterceptor = RequireUnaryAuthorization

// RequireStreamAuthorization is a grpc.StreamServerInterceptor that aborts all incoming streams, pending implementation
// of stream support in this package.
func RequireStreamAuthorization(
_ interface{},
_ grpc.ServerStream,
_ *grpc.StreamServerInfo,
_ grpc.StreamHandler,
) error {
return status.Error(codes.Internal, "server has not implemented stream authorization")
}

var _ grpc.StreamServerInterceptor = RequireStreamAuthorization

// WithAuthorization adds authorization to the current request context.
func WithAuthorization(ctx context.Context) context.Context {
if _, ok := ctx.Value(contextKey{}).(*contextValue); ok {
return ctx
}
return context.WithValue(ctx, contextKey{}, &contextValue{})
}

type contextKey struct{}

type contextValue struct {
mu sync.Mutex
authorized bool
}
119 changes: 119 additions & 0 deletions iamauthz/middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package iamauthz

import (
"context"
"net"
"testing"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"gotest.tools/v3/assert"

"google.golang.org/grpc"
healthpb "google.golang.org/grpc/health/grpc_health_v1"
)

func TestRequireUnaryAuthorization(t *testing.T) {
t.Run("authorized", func(t *testing.T) {
lis, err := net.Listen("tcp", "localhost:0")
assert.NilError(t, err)
grpcServer := grpc.NewServer(grpc.UnaryInterceptor(RequireUnaryAuthorization))
healthpb.RegisterHealthServer(grpcServer, &authorizedHealthServer{})
errChan := make(chan error)
go func() {
if err := grpcServer.Serve(lis); err != nil && err != grpc.ErrServerStopped {
errChan <- err
return
}
errChan <- nil
}()
t.Cleanup(func() {
assert.NilError(t, <-errChan)
})
t.Cleanup(func() {
grpcServer.GracefulStop()
})
ctx := withTestDeadline(context.Background(), t)
conn, err := grpc.DialContext(ctx, lis.Addr().String(), grpc.WithInsecure(), grpc.WithBlock())
assert.NilError(t, err)
client := healthpb.NewHealthClient(conn)
response, err := client.Check(ctx, &healthpb.HealthCheckRequest{})
assert.NilError(t, err)
assert.Equal(t, healthpb.HealthCheckResponse_SERVING, response.GetStatus())
})

t.Run("not authorized", func(t *testing.T) {
lis, err := net.Listen("tcp", "localhost:0")
assert.NilError(t, err)
grpcServer := grpc.NewServer(grpc.UnaryInterceptor(RequireUnaryAuthorization))
healthpb.RegisterHealthServer(grpcServer, &healthServer{})
errChan := make(chan error)
go func() {
if err := grpcServer.Serve(lis); err != nil && err != grpc.ErrServerStopped {
errChan <- err
return
}
errChan <- nil
}()
t.Cleanup(func() {
assert.NilError(t, <-errChan)
})
t.Cleanup(func() {
grpcServer.GracefulStop()
})
ctx := withTestDeadline(context.Background(), t)
conn, err := grpc.DialContext(ctx, lis.Addr().String(), grpc.WithInsecure(), grpc.WithBlock())
assert.NilError(t, err)
client := healthpb.NewHealthClient(conn)
response, err := client.Check(ctx, &healthpb.HealthCheckRequest{})
assert.Assert(t, response == nil)
assert.Equal(t, codes.Internal, status.Code(err))
})
}

type authorizedHealthServer struct {
healthServer
}

func (s *authorizedHealthServer) Check(
ctx context.Context,
request *healthpb.HealthCheckRequest,
) (*healthpb.HealthCheckResponse, error) {
Authorize(ctx)
return s.healthServer.Check(ctx, request)
}

func (s *authorizedHealthServer) Watch(
request *healthpb.HealthCheckRequest,
server healthpb.Health_WatchServer,
) error {
Authorize(server.Context())
return s.healthServer.Watch(request, server)
}

type healthServer struct{}

func (s *healthServer) Check(
_ context.Context,
_ *healthpb.HealthCheckRequest,
) (*healthpb.HealthCheckResponse, error) {
return &healthpb.HealthCheckResponse{Status: healthpb.HealthCheckResponse_SERVING}, nil
}

func (s *healthServer) Watch(
_ *healthpb.HealthCheckRequest,
_ healthpb.Health_WatchServer,
) error {
return nil
}

func withTestDeadline(ctx context.Context, t *testing.T) context.Context {
deadline, ok := t.Deadline()
if !ok {
return ctx
}
ctx, cancel := context.WithDeadline(ctx, deadline)
t.Cleanup(cancel)
return ctx
}
Loading

0 comments on commit ea6a6ed

Please sign in to comment.