Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JWT verification #317

Merged
merged 1 commit into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .mockery.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ outpkg: "{{.PackageName}}"
dir: "pkg/mocks/{{.PackageName}}"
filename: "mock_{{.InterfaceName}}.go"
packages:
github.com/xmtp/xmtpd/pkg/authn:
interfaces:
JWTVerifier:
github.com/xmtp/xmtpd/pkg/mlsvalidate:
interfaces:
MLSValidationService:
Expand Down
2 changes: 2 additions & 0 deletions pkg/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ const (
NODE_AUTHORIZATION_HEADER_NAME = "node-authorization"
MAX_BLOCKCHAIN_ORIGINATOR_ID = 100
)

type VerifiedNodeRequestCtxKey struct{}
117 changes: 117 additions & 0 deletions pkg/interceptors/server/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package server

import (
"context"

"github.com/xmtp/xmtpd/pkg/authn"
"github.com/xmtp/xmtpd/pkg/constants"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)

// wrappedServerStream allows us to modify the context of the stream
type wrappedServerStream struct {
grpc.ServerStream
ctx context.Context
}

func (w *wrappedServerStream) Context() context.Context {
return w.ctx
}

// AuthInterceptor validates JWT tokens from other nodes
type AuthInterceptor struct {
verifier authn.JWTVerifier
logger *zap.Logger
}

func NewAuthInterceptor(verifier authn.JWTVerifier, logger *zap.Logger) *AuthInterceptor {
return &AuthInterceptor{
verifier: verifier,
logger: logger,
}
}

// extractToken gets the JWT token from the request metadata
func extractToken(ctx context.Context) (string, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return "", status.Error(codes.Unauthenticated, "missing metadata")
}

values := md.Get(constants.NODE_AUTHORIZATION_HEADER_NAME)
if len(values) == 0 {
return "", status.Error(codes.Unauthenticated, "missing auth token")
}

if len(values) > 1 {
return "", status.Error(codes.Unauthenticated, "multiple auth tokens provided")
}

return values[0], nil
}

// Unary returns a grpc.UnaryServerInterceptor that validates JWT tokens
func (i *AuthInterceptor) Unary() grpc.UnaryServerInterceptor {
return func(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
token, err := extractToken(ctx)
if err != nil {
i.logger.Debug("failed to find auth token. Allowing request to proceed", zap.Error(err))
return handler(ctx, req)
neekolas marked this conversation as resolved.
Show resolved Hide resolved
}
neekolas marked this conversation as resolved.
Show resolved Hide resolved

if err := i.verifier.Verify(token); err != nil {
return nil, status.Errorf(
codes.Unauthenticated,
"invalid auth token: %v",
err,
)
}
ctx = context.WithValue(ctx, constants.VerifiedNodeRequestCtxKey{}, true)

return handler(ctx, req)
}
}

// Stream returns a grpc.StreamServerInterceptor that validates JWT tokens
func (i *AuthInterceptor) Stream() grpc.StreamServerInterceptor {
return func(
srv interface{},
stream grpc.ServerStream,
info *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
token, err := extractToken(stream.Context())
if err != nil {
i.logger.Debug("failed to find auth token. Allowing request to proceed", zap.Error(err))
return handler(srv, stream)
}

if err := i.verifier.Verify(token); err != nil {
return status.Errorf(
codes.Unauthenticated,
"invalid auth token: %v",
err,
)
}

stream = &wrappedServerStream{
ServerStream: stream,
ctx: context.WithValue(
stream.Context(),
constants.VerifiedNodeRequestCtxKey{},
true,
),
}

return handler(srv, stream)
}
}
206 changes: 206 additions & 0 deletions pkg/interceptors/server/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
package server

import (
"context"
"errors"
"testing"

"github.com/stretchr/testify/require"
"github.com/xmtp/xmtpd/pkg/constants"
"github.com/xmtp/xmtpd/pkg/mocks/authn"
"go.uber.org/zap/zaptest"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)

func TestUnaryInterceptor(t *testing.T) {
mockVerifier := authn.NewMockJWTVerifier(t)
logger := zaptest.NewLogger(t)
interceptor := NewAuthInterceptor(mockVerifier, logger)

tests := []struct {
name string
setupContext func() context.Context
setupVerifier func()
wantError error
wantVerifiedNode bool
}{
{
name: "valid token",
setupContext: func() context.Context {
md := metadata.New(map[string]string{
constants.NODE_AUTHORIZATION_HEADER_NAME: "valid_token",
})
return metadata.NewIncomingContext(context.Background(), md)
},
setupVerifier: func() {
mockVerifier.EXPECT().Verify("valid_token").Return(nil)
},
wantError: nil,
wantVerifiedNode: true,
},
{
name: "missing metadata",
setupContext: func() context.Context {
return context.Background()
},
setupVerifier: func() {},
wantError: nil,
wantVerifiedNode: false,
},
{
name: "missing token",
setupContext: func() context.Context {
md := metadata.New(map[string]string{})
return metadata.NewIncomingContext(context.Background(), md)
},
setupVerifier: func() {},
wantError: nil,
wantVerifiedNode: false,
},
{
name: "invalid token",
setupContext: func() context.Context {
md := metadata.New(map[string]string{
constants.NODE_AUTHORIZATION_HEADER_NAME: "invalid_token",
})
return metadata.NewIncomingContext(context.Background(), md)
},
setupVerifier: func() {
mockVerifier.EXPECT().
Verify("invalid_token").
Return(errors.New("invalid signature"))
},
wantError: status.Error(
codes.Unauthenticated,
"invalid auth token: invalid signature",
),
wantVerifiedNode: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setupVerifier()

ctx := tt.setupContext()
var handlerCtx context.Context
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
handlerCtx = ctx
return "ok", nil
}

_, err := interceptor.Unary()(ctx, nil, &grpc.UnaryServerInfo{}, handler)

if tt.wantError != nil {
require.Error(t, err)
require.Equal(t, tt.wantError.Error(), err.Error())
} else {
require.NoError(t, err)
isVerified, hasContextValue := handlerCtx.Value(constants.VerifiedNodeRequestCtxKey{}).(bool)
if tt.wantVerifiedNode {
require.True(t, isVerified)
} else {
require.False(t, hasContextValue)
}
}
})
}
}

func TestStreamInterceptor(t *testing.T) {
mockVerifier := authn.NewMockJWTVerifier(t)
logger := zaptest.NewLogger(t)
interceptor := NewAuthInterceptor(mockVerifier, logger)

tests := []struct {
name string
setupContext func() context.Context
setupVerifier func()
wantError error
wantVerifiedNode bool
}{
{
name: "valid token",
setupContext: func() context.Context {
md := metadata.New(map[string]string{
constants.NODE_AUTHORIZATION_HEADER_NAME: "valid_token",
})
return metadata.NewIncomingContext(context.Background(), md)
},
setupVerifier: func() {
mockVerifier.EXPECT().Verify("valid_token").Return(nil)
},
wantError: nil,
wantVerifiedNode: true,
},
{
name: "missing metadata",
setupContext: func() context.Context {
return context.Background()
},
setupVerifier: func() {},
wantError: nil,
wantVerifiedNode: false,
},
{
name: "invalid token",
setupContext: func() context.Context {
md := metadata.New(map[string]string{
constants.NODE_AUTHORIZATION_HEADER_NAME: "invalid_token",
})
return metadata.NewIncomingContext(context.Background(), md)
},
setupVerifier: func() {
mockVerifier.EXPECT().
Verify("invalid_token").
Return(errors.New("invalid signature"))
},
wantError: status.Error(
codes.Unauthenticated,
"invalid auth token: invalid signature",
),
wantVerifiedNode: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setupVerifier()

ctx := tt.setupContext()
var handlerStream grpc.ServerStream
stream := &mockServerStreamWithCtx{ctx: ctx}
handler := func(srv interface{}, stream grpc.ServerStream) error {
handlerStream = stream
return nil
}

err := interceptor.Stream()(nil, stream, &grpc.StreamServerInfo{}, handler)

if tt.wantError != nil {
require.Error(t, err)
require.Equal(t, tt.wantError.Error(), err.Error())
} else {
require.NoError(t, err)
isVerified, hasContextValue := handlerStream.Context().Value(constants.VerifiedNodeRequestCtxKey{}).(bool)
if tt.wantVerifiedNode {
require.True(t, isVerified)
} else {
require.False(t, hasContextValue)
}
}
})
}
}

type mockServerStreamWithCtx struct {
grpc.ServerStream
ctx context.Context
}

func (s *mockServerStreamWithCtx) Context() context.Context {
return s.ctx
}
Loading
Loading