diff --git a/pkg/api/message/service.go b/pkg/api/message/service.go index 6ac125c1..be6cc636 100644 --- a/pkg/api/message/service.go +++ b/pkg/api/message/service.go @@ -9,9 +9,11 @@ import ( "github.com/xmtp/xmtpd/pkg/db" "github.com/xmtp/xmtpd/pkg/db/queries" "github.com/xmtp/xmtpd/pkg/envelopes" + "github.com/xmtp/xmtpd/pkg/mlsvalidate" envelopesProto "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/envelopes" "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api" "github.com/xmtp/xmtpd/pkg/registrant" + "github.com/xmtp/xmtpd/pkg/topic" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" @@ -31,12 +33,13 @@ const ( type Service struct { message_api.UnimplementedReplicationApiServer - ctx context.Context - log *zap.Logger - registrant *registrant.Registrant - store *sql.DB - publishWorker *publishWorker - subscribeWorker *subscribeWorker + ctx context.Context + log *zap.Logger + registrant *registrant.Registrant + store *sql.DB + publishWorker *publishWorker + subscribeWorker *subscribeWorker + validationService mlsvalidate.MLSValidationService } func NewReplicationApiService( @@ -44,7 +47,7 @@ func NewReplicationApiService( log *zap.Logger, registrant *registrant.Registrant, store *sql.DB, - + validationService mlsvalidate.MLSValidationService, ) (*Service, error) { publishWorker, err := startPublishWorker(ctx, log, registrant, store) if err != nil { @@ -56,12 +59,13 @@ func NewReplicationApiService( } return &Service{ - ctx: ctx, - log: log, - registrant: registrant, - store: store, - publishWorker: publishWorker, - subscribeWorker: subscribeWorker, + ctx: ctx, + log: log, + registrant: registrant, + store: store, + publishWorker: publishWorker, + subscribeWorker: subscribeWorker, + validationService: validationService, }, nil } @@ -323,6 +327,13 @@ func (s *Service) PublishPayerEnvelopes( } targetTopic := payerEnv.ClientEnvelope.TargetTopic() + topicKind := targetTopic.Kind() + + if topicKind == topic.TOPIC_KIND_KEY_PACKAGES_V1 { + if err = s.validateKeyPackage(ctx, &payerEnv.ClientEnvelope); err != nil { + return nil, err + } + } stagedEnv, err := queries.New(s.store). InsertStagedOriginatorEnvelope(ctx, queries.InsertStagedOriginatorEnvelopeParams{ @@ -397,6 +408,34 @@ func (s *Service) validatePayerEnvelope( return payerEnv, nil } +func (s *Service) validateKeyPackage( + ctx context.Context, + clientEnv *envelopes.ClientEnvelope, +) error { + payload, ok := clientEnv.Payload().(*envelopesProto.ClientEnvelope_UploadKeyPackage) + if !ok { + return status.Errorf(codes.InvalidArgument, "invalid payload type") + } + + validationResult, err := s.validationService.ValidateKeyPackages( + ctx, + [][]byte{payload.UploadKeyPackage.KeyPackage.KeyPackageTlsSerialized}, + ) + if err != nil { + return status.Errorf(codes.Internal, "could not validate key package: %v", err) + } + + if len(validationResult) == 0 { + return status.Errorf(codes.Internal, "no validation results") + } + + if !validationResult[0].IsOk { + return status.Errorf(codes.InvalidArgument, "key package validation failed") + } + + return nil +} + func (s *Service) validateClientInfo(clientEnv *envelopes.ClientEnvelope) error { aad := clientEnv.Aad() if aad.GetTargetOriginator() != s.registrant.NodeID() { diff --git a/pkg/mlsvalidate/interface.go b/pkg/mlsvalidate/interface.go index d6c7447f..2e65cda2 100644 --- a/pkg/mlsvalidate/interface.go +++ b/pkg/mlsvalidate/interface.go @@ -10,6 +10,7 @@ import ( ) type KeyPackageValidationResult struct { + IsOk bool InstallationKey []byte Credential *identity_proto.MlsCredential Expiration uint64 diff --git a/pkg/mlsvalidate/service.go b/pkg/mlsvalidate/service.go index 009ee954..f45162c2 100644 --- a/pkg/mlsvalidate/service.go +++ b/pkg/mlsvalidate/service.go @@ -118,12 +118,19 @@ func (s *MLSValidationServiceImpl) ValidateKeyPackages( out := make([]KeyPackageValidationResult, len(response.Responses)) for i, response := range response.Responses { if !response.IsOk { - return nil, fmt.Errorf("validation failed with error %s", response.ErrorMessage) - } - out[i] = KeyPackageValidationResult{ - InstallationKey: response.InstallationPublicKey, - Credential: nil, - Expiration: response.Expiration, + out[i] = KeyPackageValidationResult{ + IsOk: false, + InstallationKey: nil, + Credential: nil, + Expiration: 0, + } + } else { + out[i] = KeyPackageValidationResult{ + IsOk: true, + InstallationKey: response.InstallationPublicKey, + Credential: nil, + Expiration: response.Expiration, + } } } return out, nil diff --git a/pkg/server/server.go b/pkg/server/server.go index b33ae0a0..f3bd647b 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -36,14 +36,15 @@ type ReplicationServer struct { apiServer *api.ApiServer syncServer *sync.SyncServer - ctx context.Context - cancel context.CancelFunc - log *zap.Logger - registrant *registrant.Registrant - nodeRegistry registry.NodeRegistry - indx *indexer.Indexer - options config.ServerOptions - metrics *metrics.Server + ctx context.Context + cancel context.CancelFunc + log *zap.Logger + registrant *registrant.Registrant + nodeRegistry registry.NodeRegistry + indx *indexer.Indexer + options config.ServerOptions + metrics *metrics.Server + validationService mlsvalidate.MLSValidationService } func NewReplicationServer( @@ -98,7 +99,7 @@ func NewReplicationServer( } if options.Indexer.Enable { - validationService, err := mlsvalidate.NewMlsValidationService( + s.validationService, err = mlsvalidate.NewMlsValidationService( ctx, log, options.MlsValidation, @@ -111,7 +112,7 @@ func NewReplicationServer( err = s.indx.StartIndexer( writerDB, options.Contracts, - validationService, + s.validationService, ) if err != nil { @@ -168,11 +169,23 @@ func startAPIServer( serviceRegistrationFunc := func(grpcServer *grpc.Server) error { if options.Replication.Enable { + if s.validationService == nil { + s.validationService, err = mlsvalidate.NewMlsValidationService( + ctx, + log, + options.MlsValidation, + ) + if err != nil { + return err + } + } + replicationService, err := message.NewReplicationApiService( ctx, log, s.registrant, writerDB, + s.validationService, ) if err != nil { return err diff --git a/pkg/testutils/api/api.go b/pkg/testutils/api/api.go index 9931abfb..3f6d7347 100644 --- a/pkg/testutils/api/api.go +++ b/pkg/testutils/api/api.go @@ -14,6 +14,7 @@ import ( "github.com/xmtp/xmtpd/pkg/authn" "github.com/xmtp/xmtpd/pkg/db/queries" "github.com/xmtp/xmtpd/pkg/mocks/blockchain" + mlsvalidateMocks "github.com/xmtp/xmtpd/pkg/mocks/mlsvalidate" mocks "github.com/xmtp/xmtpd/pkg/mocks/registry" "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api" "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/payer_api" @@ -78,6 +79,7 @@ func NewTestAPIServer(t *testing.T) (*api.ApiServer, *sql.DB, func()) { registrant, err := registrant.NewRegistrant(ctx, log, queries.New(db), mockRegistry, privKeyStr) require.NoError(t, err) mockMessagePublisher := blockchain.NewMockIBlockchainPublisher(t) + mockValidationService := mlsvalidateMocks.NewMockMLSValidationService(t) jwtVerifier := authn.NewRegistryVerifier(mockRegistry, registrant.NodeID()) @@ -87,6 +89,7 @@ func NewTestAPIServer(t *testing.T) (*api.ApiServer, *sql.DB, func()) { log, registrant, db, + mockValidationService, ) require.NoError(t, err) message_api.RegisterReplicationApiServer(grpcServer, replicationService)