Skip to content

Commit

Permalink
Pass mocks back to test API
Browse files Browse the repository at this point in the history
  • Loading branch information
neekolas committed Jan 6, 2025
1 parent b4f6507 commit 0ef8fb1
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 30 deletions.
89 changes: 85 additions & 4 deletions pkg/api/message/publish_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ import (
"testing"
"time"

"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/xmtp/xmtpd/pkg/db/queries"
"github.com/xmtp/xmtpd/pkg/mlsvalidate"
apiv1 "github.com/xmtp/xmtpd/pkg/proto/mls/api/v1"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/envelopes"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api"
apiTestUtils "github.com/xmtp/xmtpd/pkg/testutils/api"
Expand All @@ -16,7 +19,7 @@ import (
)

func TestPublishEnvelope(t *testing.T) {
api, db, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
api, db, _, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
defer cleanup()

payerEnvelope := envelopeTestUtils.CreatePayerEnvelope(t)
Expand Down Expand Up @@ -64,7 +67,7 @@ func TestPublishEnvelope(t *testing.T) {
}

func TestUnmarshalErrorOnPublish(t *testing.T) {
api, _, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
api, _, _, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
defer cleanup()

envelope := envelopeTestUtils.CreatePayerEnvelope(t)
Expand All @@ -79,7 +82,7 @@ func TestUnmarshalErrorOnPublish(t *testing.T) {
}

func TestMismatchingOriginatorOnPublish(t *testing.T) {
api, _, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
api, _, _, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
defer cleanup()

clientEnv := envelopeTestUtils.CreateClientEnvelope()
Expand All @@ -96,7 +99,7 @@ func TestMismatchingOriginatorOnPublish(t *testing.T) {
}

func TestMissingTopicOnPublish(t *testing.T) {
api, _, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
api, _, _, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
defer cleanup()

clientEnv := envelopeTestUtils.CreateClientEnvelope()
Expand All @@ -111,3 +114,81 @@ func TestMissingTopicOnPublish(t *testing.T) {
)
require.ErrorContains(t, err, "topic")
}

func TestKeyPackageValidationSuccess(t *testing.T) {
api, _, apiMocks, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
defer cleanup()

clientEnv := envelopeTestUtils.CreateClientEnvelope(&envelopes.AuthenticatedData{
TargetTopic: topic.NewTopic(topic.TOPIC_KIND_KEY_PACKAGES_V1, []byte{1, 2, 3}).Bytes(),
TargetOriginator: 100,
LastSeen: &envelopes.VectorClock{},
})
clientEnv.Payload = &envelopes.ClientEnvelope_UploadKeyPackage{
UploadKeyPackage: &apiv1.UploadKeyPackageRequest{
KeyPackage: &apiv1.KeyPackageUpload{
KeyPackageTlsSerialized: []byte{1, 2, 3},
},
},
}

apiMocks.MockValidationService.EXPECT().
ValidateKeyPackages(mock.Anything, mock.Anything).
Return(
[]mlsvalidate.KeyPackageValidationResult{
{
IsOk: true,
},
},
nil,
)

_, err := api.PublishPayerEnvelopes(
context.Background(),
&message_api.PublishPayerEnvelopesRequest{
PayerEnvelopes: []*envelopes.PayerEnvelope{
envelopeTestUtils.CreatePayerEnvelope(t, clientEnv),
},
},
)
require.Nil(t, err)
}

func TestKeyPackageValidationFail(t *testing.T) {
api, _, apiMocks, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
defer cleanup()

clientEnv := envelopeTestUtils.CreateClientEnvelope(&envelopes.AuthenticatedData{
TargetTopic: topic.NewTopic(topic.TOPIC_KIND_KEY_PACKAGES_V1, []byte{1, 2, 3}).Bytes(),
TargetOriginator: 100,
LastSeen: &envelopes.VectorClock{},
})
clientEnv.Payload = &envelopes.ClientEnvelope_UploadKeyPackage{
UploadKeyPackage: &apiv1.UploadKeyPackageRequest{
KeyPackage: &apiv1.KeyPackageUpload{
KeyPackageTlsSerialized: []byte{1, 2, 3},
},
},
}

apiMocks.MockValidationService.EXPECT().
ValidateKeyPackages(mock.Anything, mock.Anything).
Return(
[]mlsvalidate.KeyPackageValidationResult{
{
IsOk: false,
},
},
nil,
)

_, err := api.PublishPayerEnvelopes(
context.Background(),
&message_api.PublishPayerEnvelopesRequest{
PayerEnvelopes: []*envelopes.PayerEnvelope{
envelopeTestUtils.CreatePayerEnvelope(t, clientEnv),
},
},
)
require.Error(t, err)
}
18 changes: 10 additions & 8 deletions pkg/api/message/subscribe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ var (
)
var allRows []queries.InsertGatewayEnvelopeParams

func setupTest(t *testing.T) (message_api.ReplicationApiClient, *sql.DB, func()) {
func setupTest(
t *testing.T,
) (message_api.ReplicationApiClient, *sql.DB, testUtilsApi.ApiServerMocks, func()) {
allRows = []queries.InsertGatewayEnvelopeParams{
// Initial rows
{
Expand Down Expand Up @@ -115,7 +117,7 @@ func validateUpdates(
}

func TestSubscribeEnvelopesAll(t *testing.T) {
client, db, cleanup := setupTest(t)
client, db, _, cleanup := setupTest(t)
defer cleanup()
insertInitialRows(t, db)

Expand All @@ -136,7 +138,7 @@ func TestSubscribeEnvelopesAll(t *testing.T) {
}

func TestSubscribeEnvelopesByTopic(t *testing.T) {
client, store, cleanup := setupTest(t)
client, store, _, cleanup := setupTest(t)
defer cleanup()
insertInitialRows(t, store)

Expand All @@ -158,7 +160,7 @@ func TestSubscribeEnvelopesByTopic(t *testing.T) {
}

func TestSubscribeEnvelopesByOriginator(t *testing.T) {
client, db, cleanup := setupTest(t)
client, db, _, cleanup := setupTest(t)
defer cleanup()
insertInitialRows(t, db)

Expand All @@ -180,7 +182,7 @@ func TestSubscribeEnvelopesByOriginator(t *testing.T) {
}

func TestSimultaneousSubscriptions(t *testing.T) {
client, store, cleanup := setupTest(t)
client, store, _, cleanup := setupTest(t)
defer cleanup()
insertInitialRows(t, store)

Expand Down Expand Up @@ -223,7 +225,7 @@ func TestSimultaneousSubscriptions(t *testing.T) {
}

func TestSubscribeEnvelopesFromCursor(t *testing.T) {
client, store, cleanup := setupTest(t)
client, store, _, cleanup := setupTest(t)
defer cleanup()
insertInitialRows(t, store)

Expand All @@ -245,7 +247,7 @@ func TestSubscribeEnvelopesFromCursor(t *testing.T) {
}

func TestSubscribeEnvelopesFromEmptyCursor(t *testing.T) {
client, store, cleanup := setupTest(t)
client, store, _, cleanup := setupTest(t)
defer cleanup()
insertInitialRows(t, store)

Expand All @@ -267,7 +269,7 @@ func TestSubscribeEnvelopesFromEmptyCursor(t *testing.T) {
}

func TestSubscribeEnvelopesInvalidRequest(t *testing.T) {
client, _, cleanup := setupTest(t)
client, _, _, cleanup := setupTest(t)
defer cleanup()

stream, err := client.SubscribeEnvelopes(
Expand Down
4 changes: 2 additions & 2 deletions pkg/api/payer/clientManager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ func formatAddress(addr string) string {
}

func TestClientManager(t *testing.T) {
server1, _, cleanup1 := apiTestUtils.NewTestAPIServer(t)
server1, _, _, cleanup1 := apiTestUtils.NewTestAPIServer(t)
defer cleanup1()
server2, _, cleanup2 := apiTestUtils.NewTestAPIServer(t)
server2, _, _, cleanup2 := apiTestUtils.NewTestAPIServer(t)
defer cleanup2()

nodeRegistry := registry.NewFixedNodeRegistry([]registry.Node{
Expand Down
2 changes: 1 addition & 1 deletion pkg/api/payer/publish_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func TestPublishIdentityUpdate(t *testing.T) {
}

func TestPublishToNodes(t *testing.T) {
originatorServer, _, originatorCleanup := apiTestUtils.NewTestAPIServer(t)
originatorServer, _, _, originatorCleanup := apiTestUtils.NewTestAPIServer(t)
defer originatorCleanup()

ctx := context.Background()
Expand Down
20 changes: 10 additions & 10 deletions pkg/api/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopePar
}

func TestQueryAllEnvelopes(t *testing.T) {
api, db, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
api, db, _, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
defer cleanup()
db_rows := setupQueryTest(t, db)

Expand All @@ -91,7 +91,7 @@ func TestQueryAllEnvelopes(t *testing.T) {
}

func TestQueryPagedEnvelopes(t *testing.T) {
api, db, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
api, db, _, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
defer cleanup()
db_rows := setupQueryTest(t, db)

Expand All @@ -107,7 +107,7 @@ func TestQueryPagedEnvelopes(t *testing.T) {
}

func TestQueryEnvelopesByOriginator(t *testing.T) {
api, db, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
api, db, _, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
defer cleanup()
db_rows := setupQueryTest(t, db)

Expand All @@ -126,7 +126,7 @@ func TestQueryEnvelopesByOriginator(t *testing.T) {
}

func TestQueryEnvelopesByTopic(t *testing.T) {
api, store, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
api, store, _, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
defer cleanup()
db_rows := setupQueryTest(t, store)

Expand All @@ -145,7 +145,7 @@ func TestQueryEnvelopesByTopic(t *testing.T) {
}

func TestQueryEnvelopesFromLastSeen(t *testing.T) {
api, db, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
api, db, _, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
defer cleanup()
db_rows := setupQueryTest(t, db)

Expand All @@ -163,7 +163,7 @@ func TestQueryEnvelopesFromLastSeen(t *testing.T) {
}

func TestQueryTopicFromLastSeen(t *testing.T) {
api, store, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
api, store, _, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
defer cleanup()
db_rows := setupQueryTest(t, store)

Expand All @@ -184,7 +184,7 @@ func TestQueryTopicFromLastSeen(t *testing.T) {
}

func TestQueryMultipleTopicsFromLastSeen(t *testing.T) {
api, store, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
api, store, _, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
defer cleanup()
db_rows := setupQueryTest(t, store)

Expand All @@ -205,7 +205,7 @@ func TestQueryMultipleTopicsFromLastSeen(t *testing.T) {
}

func TestQueryMultipleOriginatorsFromLastSeen(t *testing.T) {
api, store, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
api, store, _, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
defer cleanup()
db_rows := setupQueryTest(t, store)

Expand All @@ -226,7 +226,7 @@ func TestQueryMultipleOriginatorsFromLastSeen(t *testing.T) {
}

func TestQueryEnvelopesWithEmptyResult(t *testing.T) {
api, store, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
api, store, _, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
defer cleanup()
db_rows := setupQueryTest(t, store)

Expand All @@ -244,7 +244,7 @@ func TestQueryEnvelopesWithEmptyResult(t *testing.T) {
}

func TestInvalidQuery(t *testing.T) {
api, store, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
api, store, _, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
defer cleanup()
_ = setupQueryTest(t, store)

Expand Down
24 changes: 19 additions & 5 deletions pkg/testutils/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,13 @@ func NewPayerAPIClient(
}
}

func NewTestAPIServer(t *testing.T) (*api.ApiServer, *sql.DB, func()) {
type ApiServerMocks struct {
MockRegistry *mocks.MockNodeRegistry
MockValidationService *mlsvalidateMocks.MockMLSValidationService
MockMessagePublisher *blockchain.MockIBlockchainPublisher
}

func NewTestAPIServer(t *testing.T) (*api.ApiServer, *sql.DB, ApiServerMocks, func()) {
ctx, cancel := context.WithCancel(context.Background())
log := testutils.NewLog(t)
db, _, dbCleanup := testutils.NewDB(t, ctx)
Expand Down Expand Up @@ -117,17 +123,25 @@ func NewTestAPIServer(t *testing.T) (*api.ApiServer, *sql.DB, func()) {
)
require.NoError(t, err)

return svr, db, func() {
allMocks := ApiServerMocks{
MockRegistry: mockRegistry,
MockValidationService: mockValidationService,
MockMessagePublisher: mockMessagePublisher,
}

return svr, db, allMocks, func() {
cancel()
svr.Close()
dbCleanup()
}
}

func NewTestReplicationAPIClient(t *testing.T) (message_api.ReplicationApiClient, *sql.DB, func()) {
svc, db, svcCleanup := NewTestAPIServer(t)
func NewTestReplicationAPIClient(
t *testing.T,
) (message_api.ReplicationApiClient, *sql.DB, ApiServerMocks, func()) {
svc, db, allMocks, svcCleanup := NewTestAPIServer(t)
client, clientCleanup := NewReplicationAPIClient(t, context.Background(), svc.Addr().String())
return client, db, func() {
return client, db, allMocks, func() {
clientCleanup()
svcCleanup()
}
Expand Down

0 comments on commit 0ef8fb1

Please sign in to comment.