diff --git a/pkg/api/server.go b/pkg/api/server.go index 9bfc8396..31b6d246 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -2,12 +2,14 @@ package api import ( "context" - "github.com/xmtp/xmtpd/pkg/interceptors/server" "net" "strings" "sync" "time" + "github.com/xmtp/xmtpd/pkg/authn" + "github.com/xmtp/xmtpd/pkg/interceptors/server" + "google.golang.org/grpc/reflection" prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" @@ -41,6 +43,7 @@ func NewAPIServer( listenAddress string, enableReflection bool, registrationFunc RegistrationFunc, + jwtVerifier authn.JWTVerifier, ) (*ApiServer, error) { grpcListener, err := net.Listen("tcp", listenAddress) @@ -67,8 +70,18 @@ func NewAPIServer( return nil, err } - unary := []grpc.UnaryServerInterceptor{prometheus.UnaryServerInterceptor} - stream := []grpc.StreamServerInterceptor{prometheus.StreamServerInterceptor} + unary := []grpc.UnaryServerInterceptor{ + prometheus.UnaryServerInterceptor, + } + stream := []grpc.StreamServerInterceptor{ + prometheus.StreamServerInterceptor, + } + + if jwtVerifier != nil { + interceptor := server.NewAuthInterceptor(jwtVerifier, log) + unary = append(unary, interceptor.Unary()) + stream = append(stream, interceptor.Stream()) + } options := []grpc.ServerOption{ grpc.ChainUnaryInterceptor(unary...), diff --git a/pkg/server/server.go b/pkg/server/server.go index b38749cb..b33ae0a0 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -3,12 +3,14 @@ package server import ( "context" "database/sql" - "github.com/xmtp/xmtpd/pkg/mlsvalidate" "net" "os" "os/signal" "syscall" + "github.com/xmtp/xmtpd/pkg/authn" + "github.com/xmtp/xmtpd/pkg/mlsvalidate" + "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/collectors" "github.com/xmtp/xmtpd/pkg/api/message" @@ -203,12 +205,19 @@ func startAPIServer( return nil } + var jwtVerifier authn.JWTVerifier + + if s.nodeRegistry != nil && s.registrant != nil { + jwtVerifier = authn.NewRegistryVerifier(s.nodeRegistry, s.registrant.NodeID()) + } + s.apiServer, err = api.NewAPIServer( s.ctx, log, listenAddress, options.Reflection.Enable, serviceRegistrationFunc, + jwtVerifier, ) if err != nil { return err diff --git a/pkg/testutils/api/api.go b/pkg/testutils/api/api.go index 35d4cc28..9931abfb 100644 --- a/pkg/testutils/api/api.go +++ b/pkg/testutils/api/api.go @@ -11,6 +11,7 @@ import ( "github.com/xmtp/xmtpd/pkg/api" "github.com/xmtp/xmtpd/pkg/api/message" "github.com/xmtp/xmtpd/pkg/api/payer" + "github.com/xmtp/xmtpd/pkg/authn" "github.com/xmtp/xmtpd/pkg/db/queries" "github.com/xmtp/xmtpd/pkg/mocks/blockchain" mocks "github.com/xmtp/xmtpd/pkg/mocks/registry" @@ -78,6 +79,8 @@ func NewTestAPIServer(t *testing.T) (*api.ApiServer, *sql.DB, func()) { require.NoError(t, err) mockMessagePublisher := blockchain.NewMockIBlockchainPublisher(t) + jwtVerifier := authn.NewRegistryVerifier(mockRegistry, registrant.NodeID()) + serviceRegistrationFunc := func(grpcServer *grpc.Server) error { replicationService, err := message.NewReplicationApiService( ctx, @@ -107,6 +110,7 @@ func NewTestAPIServer(t *testing.T) (*api.ApiServer, *sql.DB, func()) { "localhost:0", /*listenAddress*/ true, /*enableReflection*/ serviceRegistrationFunc, + jwtVerifier, ) require.NoError(t, err)