-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add iamauthz with auth-checking middleware
As a belt-and-suspenders approach to ensure that authorization has been performed on all RPC methods in a service, this middleware requires authorization implementations to call `iamauthz.Authorize(ctx)` to mark the request as authorized. Requests that return without being marked as authorized are stopped by the middleware, instead returning INTERNAL (signaling that the implementer of the RPC method has not correctly implemented authorization).
- Loading branch information
Showing
6 changed files
with
232 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
// Package iamauthz provides primitives for performing IAM request authorization. | ||
package iamauthz |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
package iamauthz | ||
|
||
import ( | ||
"context" | ||
"sync" | ||
|
||
"google.golang.org/grpc" | ||
"google.golang.org/grpc/codes" | ||
"google.golang.org/grpc/status" | ||
) | ||
|
||
// Authorize marks the current request as processed by an authorization check. | ||
// WithAuthorization must have been called on the context for the call to be effective. | ||
// | ||
// Authorize should be called at the start of an authorization check, to ensure that any errors resulting from the | ||
// authorization check itself are forwarded to the caller. | ||
func Authorize(ctx context.Context) { | ||
if value, ok := ctx.Value(contextKey{}).(*contextValue); ok { | ||
value.mu.Lock() | ||
value.authorized = true | ||
value.mu.Unlock() | ||
} | ||
} | ||
|
||
// RequireUnaryAuthorization is a grpc.UnaryServerInterceptor that requires authorization | ||
// to be performed on all incoming requests. | ||
// | ||
// To mark the request as processed by authorization checks, the method implementing authorization should call | ||
// Authorize on the request context as soon as authorization starts. | ||
func RequireUnaryAuthorization( | ||
ctx context.Context, | ||
req interface{}, | ||
_ *grpc.UnaryServerInfo, | ||
handler grpc.UnaryHandler, | ||
) (interface{}, error) { | ||
ctx = WithAuthorization(ctx) | ||
resp, err := handler(ctx, req) | ||
if code := status.Code(err); code == codes.Unauthenticated || code == codes.PermissionDenied { | ||
return nil, err | ||
} | ||
value := ctx.Value(contextKey{}).(*contextValue) | ||
value.mu.Lock() | ||
authorized := value.authorized | ||
value.mu.Unlock() | ||
if !authorized { | ||
return nil, status.Error(codes.Internal, "server did not perform authorization") | ||
} | ||
return resp, err | ||
} | ||
|
||
var _ grpc.UnaryServerInterceptor = RequireUnaryAuthorization | ||
|
||
// RequireStreamAuthorization is a grpc.StreamServerInterceptor that aborts all incoming streams, pending implementation | ||
// of stream support in this package. | ||
func RequireStreamAuthorization( | ||
_ interface{}, | ||
_ grpc.ServerStream, | ||
_ *grpc.StreamServerInfo, | ||
_ grpc.StreamHandler, | ||
) error { | ||
return status.Error(codes.Internal, "server has not implemented stream authorization") | ||
} | ||
|
||
var _ grpc.StreamServerInterceptor = RequireStreamAuthorization | ||
|
||
// WithAuthorization adds authorization to the current request context. | ||
func WithAuthorization(ctx context.Context) context.Context { | ||
if _, ok := ctx.Value(contextKey{}).(*contextValue); ok { | ||
return ctx | ||
} | ||
return context.WithValue(ctx, contextKey{}, &contextValue{}) | ||
} | ||
|
||
type contextKey struct{} | ||
|
||
type contextValue struct { | ||
mu sync.Mutex | ||
authorized bool | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
package iamauthz | ||
|
||
import ( | ||
"context" | ||
"net" | ||
"testing" | ||
|
||
"google.golang.org/grpc/codes" | ||
"google.golang.org/grpc/status" | ||
|
||
"gotest.tools/v3/assert" | ||
|
||
"google.golang.org/grpc" | ||
healthpb "google.golang.org/grpc/health/grpc_health_v1" | ||
) | ||
|
||
func TestRequireUnaryAuthorization(t *testing.T) { | ||
t.Run("authorized", func(t *testing.T) { | ||
lis, err := net.Listen("tcp", "localhost:0") | ||
assert.NilError(t, err) | ||
grpcServer := grpc.NewServer(grpc.UnaryInterceptor(RequireUnaryAuthorization)) | ||
healthpb.RegisterHealthServer(grpcServer, &authorizedHealthServer{}) | ||
errChan := make(chan error) | ||
go func() { | ||
if err := grpcServer.Serve(lis); err != nil && err != grpc.ErrServerStopped { | ||
errChan <- err | ||
return | ||
} | ||
errChan <- nil | ||
}() | ||
t.Cleanup(func() { | ||
assert.NilError(t, <-errChan) | ||
}) | ||
t.Cleanup(func() { | ||
grpcServer.GracefulStop() | ||
}) | ||
ctx := withTestDeadline(context.Background(), t) | ||
conn, err := grpc.DialContext(ctx, lis.Addr().String(), grpc.WithInsecure(), grpc.WithBlock()) | ||
assert.NilError(t, err) | ||
client := healthpb.NewHealthClient(conn) | ||
response, err := client.Check(ctx, &healthpb.HealthCheckRequest{}) | ||
assert.NilError(t, err) | ||
assert.Equal(t, healthpb.HealthCheckResponse_SERVING, response.GetStatus()) | ||
}) | ||
|
||
t.Run("not authorized", func(t *testing.T) { | ||
lis, err := net.Listen("tcp", "localhost:0") | ||
assert.NilError(t, err) | ||
grpcServer := grpc.NewServer(grpc.UnaryInterceptor(RequireUnaryAuthorization)) | ||
healthpb.RegisterHealthServer(grpcServer, &healthServer{}) | ||
errChan := make(chan error) | ||
go func() { | ||
if err := grpcServer.Serve(lis); err != nil && err != grpc.ErrServerStopped { | ||
errChan <- err | ||
return | ||
} | ||
errChan <- nil | ||
}() | ||
t.Cleanup(func() { | ||
assert.NilError(t, <-errChan) | ||
}) | ||
t.Cleanup(func() { | ||
grpcServer.GracefulStop() | ||
}) | ||
ctx := withTestDeadline(context.Background(), t) | ||
conn, err := grpc.DialContext(ctx, lis.Addr().String(), grpc.WithInsecure(), grpc.WithBlock()) | ||
assert.NilError(t, err) | ||
client := healthpb.NewHealthClient(conn) | ||
response, err := client.Check(ctx, &healthpb.HealthCheckRequest{}) | ||
assert.Assert(t, response == nil) | ||
assert.Equal(t, codes.Internal, status.Code(err)) | ||
}) | ||
} | ||
|
||
type authorizedHealthServer struct { | ||
healthServer | ||
} | ||
|
||
func (s *authorizedHealthServer) Check( | ||
ctx context.Context, | ||
request *healthpb.HealthCheckRequest, | ||
) (*healthpb.HealthCheckResponse, error) { | ||
Authorize(ctx) | ||
return s.healthServer.Check(ctx, request) | ||
} | ||
|
||
func (s *authorizedHealthServer) Watch( | ||
request *healthpb.HealthCheckRequest, | ||
server healthpb.Health_WatchServer, | ||
) error { | ||
Authorize(server.Context()) | ||
return s.healthServer.Watch(request, server) | ||
} | ||
|
||
type healthServer struct{} | ||
|
||
func (s *healthServer) Check( | ||
_ context.Context, | ||
_ *healthpb.HealthCheckRequest, | ||
) (*healthpb.HealthCheckResponse, error) { | ||
return &healthpb.HealthCheckResponse{Status: healthpb.HealthCheckResponse_SERVING}, nil | ||
} | ||
|
||
func (s *healthServer) Watch( | ||
_ *healthpb.HealthCheckRequest, | ||
_ healthpb.Health_WatchServer, | ||
) error { | ||
return nil | ||
} | ||
|
||
func withTestDeadline(ctx context.Context, t *testing.T) context.Context { | ||
deadline, ok := t.Deadline() | ||
if !ok { | ||
return ctx | ||
} | ||
ctx, cancel := context.WithDeadline(ctx, deadline) | ||
t.Cleanup(cancel) | ||
return ctx | ||
} |
Oops, something went wrong.