Skip to content

Commit

Permalink
Add key package validation
Browse files Browse the repository at this point in the history
  • Loading branch information
neekolas committed Jan 6, 2025
1 parent ab9d996 commit b4f6507
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 29 deletions.
65 changes: 52 additions & 13 deletions pkg/api/message/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ import (
"github.com/xmtp/xmtpd/pkg/db"
"github.com/xmtp/xmtpd/pkg/db/queries"
"github.com/xmtp/xmtpd/pkg/envelopes"
"github.com/xmtp/xmtpd/pkg/mlsvalidate"
envelopesProto "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/envelopes"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api"
"github.com/xmtp/xmtpd/pkg/registrant"
"github.com/xmtp/xmtpd/pkg/topic"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
Expand All @@ -31,20 +33,21 @@ const (
type Service struct {
message_api.UnimplementedReplicationApiServer

ctx context.Context
log *zap.Logger
registrant *registrant.Registrant
store *sql.DB
publishWorker *publishWorker
subscribeWorker *subscribeWorker
ctx context.Context
log *zap.Logger
registrant *registrant.Registrant
store *sql.DB
publishWorker *publishWorker
subscribeWorker *subscribeWorker
validationService mlsvalidate.MLSValidationService
}

func NewReplicationApiService(
ctx context.Context,
log *zap.Logger,
registrant *registrant.Registrant,
store *sql.DB,

validationService mlsvalidate.MLSValidationService,
) (*Service, error) {
publishWorker, err := startPublishWorker(ctx, log, registrant, store)
if err != nil {
Expand All @@ -56,12 +59,13 @@ func NewReplicationApiService(
}

return &Service{
ctx: ctx,
log: log,
registrant: registrant,
store: store,
publishWorker: publishWorker,
subscribeWorker: subscribeWorker,
ctx: ctx,
log: log,
registrant: registrant,
store: store,
publishWorker: publishWorker,
subscribeWorker: subscribeWorker,
validationService: validationService,
}, nil
}

Expand Down Expand Up @@ -323,6 +327,13 @@ func (s *Service) PublishPayerEnvelopes(
}

targetTopic := payerEnv.ClientEnvelope.TargetTopic()
topicKind := targetTopic.Kind()

if topicKind == topic.TOPIC_KIND_KEY_PACKAGES_V1 {
if err = s.validateKeyPackage(ctx, &payerEnv.ClientEnvelope); err != nil {
return nil, err
}
}

stagedEnv, err := queries.New(s.store).
InsertStagedOriginatorEnvelope(ctx, queries.InsertStagedOriginatorEnvelopeParams{
Expand Down Expand Up @@ -397,6 +408,34 @@ func (s *Service) validatePayerEnvelope(
return payerEnv, nil
}

func (s *Service) validateKeyPackage(
ctx context.Context,
clientEnv *envelopes.ClientEnvelope,
) error {
payload, ok := clientEnv.Payload().(*envelopesProto.ClientEnvelope_UploadKeyPackage)
if !ok {
return status.Errorf(codes.InvalidArgument, "invalid payload type")
}

validationResult, err := s.validationService.ValidateKeyPackages(
ctx,
[][]byte{payload.UploadKeyPackage.KeyPackage.KeyPackageTlsSerialized},
)
if err != nil {
return status.Errorf(codes.Internal, "could not validate key package: %v", err)
}

if len(validationResult) == 0 {
return status.Errorf(codes.Internal, "no validation results")
}

if !validationResult[0].IsOk {
return status.Errorf(codes.InvalidArgument, "key package validation failed")
}

return nil
}

func (s *Service) validateClientInfo(clientEnv *envelopes.ClientEnvelope) error {
aad := clientEnv.Aad()
if aad.GetTargetOriginator() != s.registrant.NodeID() {
Expand Down
1 change: 1 addition & 0 deletions pkg/mlsvalidate/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
)

type KeyPackageValidationResult struct {
IsOk bool
InstallationKey []byte
Credential *identity_proto.MlsCredential
Expiration uint64
Expand Down
19 changes: 13 additions & 6 deletions pkg/mlsvalidate/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,19 @@ func (s *MLSValidationServiceImpl) ValidateKeyPackages(
out := make([]KeyPackageValidationResult, len(response.Responses))
for i, response := range response.Responses {
if !response.IsOk {
return nil, fmt.Errorf("validation failed with error %s", response.ErrorMessage)
}
out[i] = KeyPackageValidationResult{
InstallationKey: response.InstallationPublicKey,
Credential: nil,
Expiration: response.Expiration,
out[i] = KeyPackageValidationResult{
IsOk: false,
InstallationKey: nil,
Credential: nil,
Expiration: 0,
}
} else {
out[i] = KeyPackageValidationResult{
IsOk: true,
InstallationKey: response.InstallationPublicKey,
Credential: nil,
Expiration: response.Expiration,
}
}
}
return out, nil
Expand Down
33 changes: 23 additions & 10 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,15 @@ type ReplicationServer struct {
apiServer *api.ApiServer
syncServer *sync.SyncServer

ctx context.Context
cancel context.CancelFunc
log *zap.Logger
registrant *registrant.Registrant
nodeRegistry registry.NodeRegistry
indx *indexer.Indexer
options config.ServerOptions
metrics *metrics.Server
ctx context.Context
cancel context.CancelFunc
log *zap.Logger
registrant *registrant.Registrant
nodeRegistry registry.NodeRegistry
indx *indexer.Indexer
options config.ServerOptions
metrics *metrics.Server
validationService mlsvalidate.MLSValidationService
}

func NewReplicationServer(
Expand Down Expand Up @@ -98,7 +99,7 @@ func NewReplicationServer(
}

if options.Indexer.Enable {
validationService, err := mlsvalidate.NewMlsValidationService(
s.validationService, err = mlsvalidate.NewMlsValidationService(
ctx,
log,
options.MlsValidation,
Expand All @@ -111,7 +112,7 @@ func NewReplicationServer(
err = s.indx.StartIndexer(
writerDB,
options.Contracts,
validationService,
s.validationService,
)

if err != nil {
Expand Down Expand Up @@ -168,11 +169,23 @@ func startAPIServer(

serviceRegistrationFunc := func(grpcServer *grpc.Server) error {
if options.Replication.Enable {
if s.validationService == nil {
s.validationService, err = mlsvalidate.NewMlsValidationService(
ctx,
log,
options.MlsValidation,
)
if err != nil {
return err
}
}

replicationService, err := message.NewReplicationApiService(
ctx,
log,
s.registrant,
writerDB,
s.validationService,
)
if err != nil {
return err
Expand Down
3 changes: 3 additions & 0 deletions pkg/testutils/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/xmtp/xmtpd/pkg/authn"
"github.com/xmtp/xmtpd/pkg/db/queries"
"github.com/xmtp/xmtpd/pkg/mocks/blockchain"
mlsvalidateMocks "github.com/xmtp/xmtpd/pkg/mocks/mlsvalidate"
mocks "github.com/xmtp/xmtpd/pkg/mocks/registry"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/payer_api"
Expand Down Expand Up @@ -78,6 +79,7 @@ func NewTestAPIServer(t *testing.T) (*api.ApiServer, *sql.DB, func()) {
registrant, err := registrant.NewRegistrant(ctx, log, queries.New(db), mockRegistry, privKeyStr)
require.NoError(t, err)
mockMessagePublisher := blockchain.NewMockIBlockchainPublisher(t)
mockValidationService := mlsvalidateMocks.NewMockMLSValidationService(t)

jwtVerifier := authn.NewRegistryVerifier(mockRegistry, registrant.NodeID())

Expand All @@ -87,6 +89,7 @@ func NewTestAPIServer(t *testing.T) (*api.ApiServer, *sql.DB, func()) {
log,
registrant,
db,
mockValidationService,
)
require.NoError(t, err)
message_api.RegisterReplicationApiServer(grpcServer, replicationService)
Expand Down

0 comments on commit b4f6507

Please sign in to comment.