diff --git a/services/horizon/internal/db2/history/liquidity_pools.go b/services/horizon/internal/db2/history/liquidity_pools.go index 94ad9bdbda..0c79dc890c 100644 --- a/services/horizon/internal/db2/history/liquidity_pools.go +++ b/services/horizon/internal/db2/history/liquidity_pools.go @@ -9,6 +9,7 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/guregu/null" + "github.com/jmoiron/sqlx" "github.com/stellar/go/services/horizon/internal/db2" "github.com/stellar/go/support/errors" "github.com/stellar/go/xdr" @@ -85,7 +86,7 @@ func (lpar *LiquidityPoolAssetReserve) UnmarshalJSON(data []byte) error { type QLiquidityPools interface { UpsertLiquidityPools(ctx context.Context, lps []LiquidityPool) error GetLiquidityPoolsByID(ctx context.Context, poolIDs []string) ([]LiquidityPool, error) - GetAllLiquidityPools(ctx context.Context) ([]LiquidityPool, error) + StreamAllLiquidityPools(ctx context.Context, callback func(LiquidityPool) error) error CountLiquidityPools(ctx context.Context) (int, error) FindLiquidityPoolByID(ctx context.Context, liquidityPoolID string) (LiquidityPool, error) GetUpdatedLiquidityPools(ctx context.Context, newerThanSequence uint32) ([]LiquidityPool, error) @@ -186,13 +187,27 @@ func (q *Q) GetLiquidityPools(ctx context.Context, query LiquidityPoolsQuery) ([ return results, nil } -func (q *Q) GetAllLiquidityPools(ctx context.Context) ([]LiquidityPool, error) { - var results []LiquidityPool - if err := q.Select(ctx, &results, selectLiquidityPools.Where("deleted = ?", false)); err != nil { - return nil, errors.Wrap(err, "could not run select query") +func (q *Q) StreamAllLiquidityPools(ctx context.Context, callback func(LiquidityPool) error) error { + var rows *sqlx.Rows + var err error + + if rows, err = q.Query(ctx, selectLiquidityPools.Where("deleted = ?", false)); err != nil { + return errors.Wrap(err, "could not run all liquidity pools select query") } - return results, nil + defer rows.Close() + liquidityPool := LiquidityPool{} + + for rows.Next() { + if err = rows.StructScan(&liquidityPool); err != nil { + return errors.Wrap(err, "could not scan row into liquidity pool struct") + } + if err = callback(liquidityPool); err != nil { + return err + } + } + + return rows.Err() } // GetUpdatedLiquidityPools returns all liquidity pools created, updated, or deleted after the given ledger sequence. diff --git a/services/horizon/internal/db2/history/liquidity_pools_test.go b/services/horizon/internal/db2/history/liquidity_pools_test.go index e7bed878c1..eb95e35a3e 100644 --- a/services/horizon/internal/db2/history/liquidity_pools_test.go +++ b/services/horizon/internal/db2/history/liquidity_pools_test.go @@ -111,7 +111,11 @@ func TestFindLiquidityPoolsByAssets(t *testing.T) { tt.Assert.Len(lps, 1) pool := lps[0] - lps, err = q.GetAllLiquidityPools(tt.Ctx) + lps = nil + err = q.StreamAllLiquidityPools(tt.Ctx, func(liqudityPool LiquidityPool) error { + lps = append(lps, liqudityPool) + return nil + }) tt.Assert.NoError(err) tt.Assert.Len(lps, 1) tt.Assert.Equal(pool, lps[0]) @@ -205,7 +209,12 @@ func TestLiquidityPoolCompaction(t *testing.T) { tt.Assert.NoError(err) tt.Assert.Len(lps, 0) - lps, err = q.GetAllLiquidityPools(tt.Ctx) + lps = nil + err = q.StreamAllLiquidityPools(tt.Ctx, func(liqudityPool LiquidityPool) error { + lps = append(lps, liqudityPool) + return nil + }) + tt.Assert.NoError(err) tt.Assert.Len(lps, 0) diff --git a/services/horizon/internal/db2/history/mock_q_liquidity_pools.go b/services/horizon/internal/db2/history/mock_q_liquidity_pools.go index e42122865d..7b64b24126 100644 --- a/services/horizon/internal/db2/history/mock_q_liquidity_pools.go +++ b/services/horizon/internal/db2/history/mock_q_liquidity_pools.go @@ -31,9 +31,9 @@ func (m *MockQLiquidityPools) FindLiquidityPoolByID(ctx context.Context, liquidi return a.Get(0).(LiquidityPool), a.Error(1) } -func (m *MockQLiquidityPools) GetAllLiquidityPools(ctx context.Context) ([]LiquidityPool, error) { - a := m.Called(ctx) - return a.Get(0).([]LiquidityPool), a.Error(1) +func (m *MockQLiquidityPools) StreamAllLiquidityPools(ctx context.Context, callback func(LiquidityPool) error) error { + a := m.Called(ctx, callback) + return a.Error(0) } func (m *MockQLiquidityPools) GetUpdatedLiquidityPools(ctx context.Context, sequence uint32) ([]LiquidityPool, error) { diff --git a/services/horizon/internal/db2/history/mock_q_offers.go b/services/horizon/internal/db2/history/mock_q_offers.go index a2ee8efdd0..0c4bc5e9bb 100644 --- a/services/horizon/internal/db2/history/mock_q_offers.go +++ b/services/horizon/internal/db2/history/mock_q_offers.go @@ -11,9 +11,9 @@ type MockQOffers struct { mock.Mock } -func (m *MockQOffers) GetAllOffers(ctx context.Context) ([]Offer, error) { - a := m.Called(ctx) - return a.Get(0).([]Offer), a.Error(1) +func (m *MockQOffers) StreamAllOffers(ctx context.Context, callback func(Offer) error) error { + a := m.Called(ctx, callback) + return a.Error(0) } func (m *MockQOffers) GetOffersByIDs(ctx context.Context, ids []int64) ([]Offer, error) { diff --git a/services/horizon/internal/db2/history/offers.go b/services/horizon/internal/db2/history/offers.go index 0338d4474d..1d10b1bcde 100644 --- a/services/horizon/internal/db2/history/offers.go +++ b/services/horizon/internal/db2/history/offers.go @@ -4,13 +4,14 @@ import ( "context" sq "github.com/Masterminds/squirrel" + "github.com/jmoiron/sqlx" "github.com/stellar/go/support/errors" ) // QOffers defines offer related queries. type QOffers interface { - GetAllOffers(ctx context.Context) ([]Offer, error) + StreamAllOffers(ctx context.Context, callback func(Offer) error) error GetOffersByIDs(ctx context.Context, ids []int64) ([]Offer, error) CountOffers(ctx context.Context) (int, error) GetUpdatedOffers(ctx context.Context, newerThanSequence uint32) ([]Offer, error) @@ -80,11 +81,30 @@ func (q *Q) GetOffers(ctx context.Context, query OffersQuery) ([]Offer, error) { return offers, nil } -// GetAllOffers loads all non deleted offers -func (q *Q) GetAllOffers(ctx context.Context) ([]Offer, error) { - var offers []Offer - err := q.Select(ctx, &offers, selectOffers.Where("deleted = ?", false)) - return offers, err +// StreamAllOffers loads all non deleted offers +func (q *Q) StreamAllOffers(ctx context.Context, callback func(Offer) error) error { + var rows *sqlx.Rows + var err error + + if rows, err = q.Query(ctx, selectOffers.Where("deleted = ?", false)); err != nil { + return errors.Wrap(err, "could not run all offers select query") + } + + defer rows.Close() + offer := Offer{} + + for rows.Next() { + if err = rows.StructScan(&offer); err != nil { + return errors.Wrap(err, "could not scan row into offer struct") + } + + if err = callback(offer); err != nil { + return err + } + } + + return rows.Err() + } // GetUpdatedOffers returns all offers created, updated, or deleted after the given ledger sequence. diff --git a/services/horizon/internal/db2/history/offers_test.go b/services/horizon/internal/db2/history/offers_test.go index 0be7559515..85951bfa1c 100644 --- a/services/horizon/internal/db2/history/offers_test.go +++ b/services/horizon/internal/db2/history/offers_test.go @@ -96,7 +96,12 @@ func TestQueryEmptyOffers(t *testing.T) { test.ResetHorizonDB(t, tt.HorizonDB) q := &Q{tt.HorizonSession()} - offers, err := q.GetAllOffers(tt.Ctx) + var offers []Offer + err := q.StreamAllOffers(tt.Ctx, func(offer Offer) error { + offers = append(offers, offer) + return nil + }) + tt.Assert.NoError(err) tt.Assert.Len(offers, 0) @@ -127,7 +132,11 @@ func TestInsertOffers(t *testing.T) { err = insertOffer(tt, q, twoEurOffer) tt.Assert.NoError(err) - offers, err := q.GetAllOffers(tt.Ctx) + var offers []Offer + err = q.StreamAllOffers(tt.Ctx, func(offer Offer) error { + offers = append(offers, offer) + return nil + }) tt.Assert.NoError(err) tt.Assert.Len(offers, 2) @@ -154,7 +163,11 @@ func TestInsertOffers(t *testing.T) { tt.Assert.NoError(err) tt.Assert.Equal(2, afterCompactionCount) - afterCompactionOffers, err := q.GetAllOffers(tt.Ctx) + var afterCompactionOffers []Offer + err = q.StreamAllOffers(tt.Ctx, func(offer Offer) error { + afterCompactionOffers = append(afterCompactionOffers, offer) + return nil + }) tt.Assert.NoError(err) tt.Assert.Len(afterCompactionOffers, 2) } @@ -168,7 +181,11 @@ func TestUpdateOffer(t *testing.T) { err := insertOffer(tt, q, eurOffer) tt.Assert.NoError(err) - offers, err := q.GetAllOffers(tt.Ctx) + var offers []Offer + err = q.StreamAllOffers(tt.Ctx, func(offer Offer) error { + offers = append(offers, offer) + return nil + }) tt.Assert.NoError(err) tt.Assert.Len(offers, 1) @@ -192,7 +209,11 @@ func TestUpdateOffer(t *testing.T) { err = q.UpsertOffers(tt.Ctx, []Offer{modifiedEurOffer}) tt.Assert.NoError(err) - offers, err = q.GetAllOffers(tt.Ctx) + offers = nil + err = q.StreamAllOffers(tt.Ctx, func(offer Offer) error { + offers = append(offers, offer) + return nil + }) tt.Assert.NoError(err) tt.Assert.Len(offers, 1) @@ -215,7 +236,11 @@ func TestRemoveOffer(t *testing.T) { err := insertOffer(tt, q, eurOffer) tt.Assert.NoError(err) - offers, err := q.GetAllOffers(tt.Ctx) + var offers []Offer + err = q.StreamAllOffers(tt.Ctx, func(offer Offer) error { + offers = append(offers, offer) + return nil + }) tt.Assert.NoError(err) tt.Assert.Len(offers, 1) tt.Assert.Equal(offers[0], eurOffer) @@ -229,7 +254,11 @@ func TestRemoveOffer(t *testing.T) { expectedUpdates[0].LastModifiedLedger = 1236 expectedUpdates[0].Deleted = true - offers, err = q.GetAllOffers(tt.Ctx) + offers = nil + err = q.StreamAllOffers(tt.Ctx, func(offer Offer) error { + offers = append(offers, offer) + return nil + }) tt.Assert.NoError(err) tt.Assert.Len(offers, 0) diff --git a/services/horizon/internal/ingest/main_test.go b/services/horizon/internal/ingest/main_test.go index 375b09558e..63c0f19f87 100644 --- a/services/horizon/internal/ingest/main_test.go +++ b/services/horizon/internal/ingest/main_test.go @@ -323,9 +323,9 @@ func (m *mockDBQ) GetExpStateInvalid(ctx context.Context) (bool, error) { return args.Get(0).(bool), args.Error(1) } -func (m *mockDBQ) GetAllOffers(ctx context.Context) ([]history.Offer, error) { - args := m.Called(ctx) - return args.Get(0).([]history.Offer), args.Error(1) +func (m *mockDBQ) StreamAllOffers(ctx context.Context, callback func(history.Offer) error) error { + a := m.Called(ctx, callback) + return a.Error(0) } func (m *mockDBQ) GetLatestHistoryLedger(ctx context.Context) (uint32, error) { diff --git a/services/horizon/internal/ingest/orderbook.go b/services/horizon/internal/ingest/orderbook.go index 2700a42759..9577cf5f6d 100644 --- a/services/horizon/internal/ingest/orderbook.go +++ b/services/horizon/internal/ingest/orderbook.go @@ -132,26 +132,26 @@ func (o *OrderBookStream) update(ctx context.Context, status ingestionStatus) (b defer o.graph.Discard() - offers, err := o.historyQ.GetAllOffers(ctx) - if err != nil { - return true, errors.Wrap(err, "Error from GetAllOffers") - } + err := o.historyQ.StreamAllOffers(ctx, func(offer history.Offer) error { + o.graph.AddOffers(offerToXDR(offer)) + return nil + }) - liquidityPools, err := o.historyQ.GetAllLiquidityPools(ctx) if err != nil { - return true, errors.Wrap(err, "Error from GetAllLiquidityPools") + return true, errors.Wrap(err, "Error loading offers into orderbook") } - for _, offer := range offers { - o.graph.AddOffers(offerToXDR(offer)) - } - - for _, liquidityPool := range liquidityPools { - liquidityPoolXDR, err := liquidityPoolToXDR(liquidityPool) - if err != nil { - return true, errors.Wrap(err, "Invalid liquidity pool row") + err = o.historyQ.StreamAllLiquidityPools(ctx, func(liquidityPool history.LiquidityPool) error { + if liquidityPoolXDR, liquidityPoolErr := liquidityPoolToXDR(liquidityPool); liquidityPoolErr != nil { + return errors.Wrapf(liquidityPoolErr, "Invalid liquidity pool row %v, unable to marshal to xdr", liquidityPool) + } else { + o.graph.AddLiquidityPools(liquidityPoolXDR) + return nil } - o.graph.AddLiquidityPools(liquidityPoolXDR) + }) + + if err != nil { + return true, errors.Wrap(err, "Error loading liquidity pools into orderbook") } if err := o.graph.Apply(status.LastIngestedLedger); err != nil { @@ -209,9 +209,14 @@ func (o *OrderBookStream) update(ctx context.Context, status ingestionStatus) (b } func (o *OrderBookStream) verifyAllOffers(ctx context.Context, offers []xdr.OfferEntry) (bool, error) { - ingestionOffers, err := o.historyQ.GetAllOffers(ctx) + var ingestionOffers []history.Offer + err := o.historyQ.StreamAllOffers(ctx, func(offer history.Offer) error { + ingestionOffers = append(ingestionOffers, offer) + return nil + }) + if err != nil { - return false, errors.Wrap(err, "Error from GetAllOffers") + return false, errors.Wrap(err, "Error loading all offers for orderbook verification") } mismatch := len(offers) != len(ingestionOffers) @@ -253,9 +258,15 @@ func (o *OrderBookStream) verifyAllOffers(ctx context.Context, offers []xdr.Offe } func (o *OrderBookStream) verifyAllLiquidityPools(ctx context.Context, liquidityPools []xdr.LiquidityPoolEntry) (bool, error) { - ingestionLiquidityPools, err := o.historyQ.GetAllLiquidityPools(ctx) + var ingestionLiquidityPools []history.LiquidityPool + + err := o.historyQ.StreamAllLiquidityPools(ctx, func(liquidityPool history.LiquidityPool) error { + ingestionLiquidityPools = append(ingestionLiquidityPools, liquidityPool) + return nil + }) + if err != nil { - return false, errors.Wrap(err, "Error from GetAllLiquidityPools") + return false, errors.Wrap(err, "Error loading all liquidity pools for orderbook verification") } mismatch := len(liquidityPools) != len(ingestionLiquidityPools) @@ -322,7 +333,7 @@ func (o *OrderBookStream) Update(ctx context.Context) error { } // add 15 minute jitter so that not all horizon nodes are calling - // historyQ.GetAllOffers at the same time + // historyQ.StreamAllOffers at the same time jitter := time.Duration(rand.Int63n(int64(15 * time.Minute))) requiresVerification := o.lastLedger > 0 && time.Since(o.lastVerification) >= verificationFrequency+jitter diff --git a/services/horizon/internal/ingest/orderbook_test.go b/services/horizon/internal/ingest/orderbook_test.go index 161bf44bd5..870a3490ab 100644 --- a/services/horizon/internal/ingest/orderbook_test.go +++ b/services/horizon/internal/ingest/orderbook_test.go @@ -11,6 +11,7 @@ import ( "github.com/stellar/go/services/horizon/internal/ingest/processors" "github.com/stellar/go/xdr" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" ) @@ -238,7 +239,7 @@ func (t *UpdateOrderBookStreamTestSuite) TearDownTest() { t.graph.AssertExpectations(t.T()) } -func (t *UpdateOrderBookStreamTestSuite) TestGetAllOffersError() { +func (t *UpdateOrderBookStreamTestSuite) TestStreamAllOffersError() { status := ingestionStatus{ HistoryConsistentWithState: true, StateInvalid: false, @@ -248,13 +249,13 @@ func (t *UpdateOrderBookStreamTestSuite) TestGetAllOffersError() { } t.graph.On("Clear").Return().Once() t.graph.On("Discard").Return().Once() - t.historyQ.On("GetAllOffers", t.ctx). - Return([]history.Offer{}, fmt.Errorf("offers error")). + t.historyQ.On("StreamAllOffers", t.ctx, mock.Anything). + Return(fmt.Errorf("offers error")). Once() t.stream.lastLedger = 300 _, err := t.stream.update(t.ctx, status) - t.Assert().EqualError(err, "Error from GetAllOffers: offers error") + t.Assert().EqualError(err, "Error loading offers into orderbook: offers error") t.Assert().Equal(uint32(0), t.stream.lastLedger) } @@ -280,12 +281,17 @@ func (t *UpdateOrderBookStreamTestSuite) TestResetApplyError() { SellerId: xdr.MustAddress(sellerID), OfferId: 20, }} - t.historyQ.On("GetAllOffers", t.ctx). - Return([]history.Offer{offer, otherOffer}, nil). + t.historyQ.On("StreamAllOffers", t.ctx, mock.Anything). + Return(nil). + Run(func(args mock.Arguments) { + callback := args.Get(1).(func(offer history.Offer) error) + callback(offer) + callback(otherOffer) + }). Once() - t.historyQ.MockQLiquidityPools.On("GetAllLiquidityPools", t.ctx). - Return([]history.LiquidityPool{}, nil). + t.historyQ.MockQLiquidityPools.On("StreamAllLiquidityPools", t.ctx, mock.Anything). + Return(nil). Once() t.graph.On("AddOffers", offerEntry).Return().Once() @@ -317,12 +323,18 @@ func (t *UpdateOrderBookStreamTestSuite) mockReset(status ingestionStatus) { OfferId: 20, }} offers := []history.Offer{offer, otherOffer} - t.historyQ.On("GetAllOffers", t.ctx). - Return(offers, nil). + t.historyQ.On("StreamAllOffers", t.ctx, mock.Anything). + Return(nil). + Run(func(args mock.Arguments) { + callback := args.Get(1).(func(offer history.Offer) error) + for idx := range offers { + callback(offers[idx]) + } + }). Once() - t.historyQ.MockQLiquidityPools.On("GetAllLiquidityPools", t.ctx). - Return([]history.LiquidityPool{}, nil). + t.historyQ.MockQLiquidityPools.On("StreamAllLiquidityPools", t.ctx, mock.Anything). + Return(nil). Once() t.graph.On("AddOffers", offerEntry).Return().Once() @@ -636,19 +648,18 @@ func (t *VerifyOffersStreamTestSuite) TearDownTest() { t.graph.AssertExpectations(t.T()) } -func (t *VerifyOffersStreamTestSuite) TestGetAllOffersError() { - t.historyQ.On("GetAllOffers", t.ctx). - Return([]history.Offer{}, fmt.Errorf("offers error")). +func (t *VerifyOffersStreamTestSuite) TestStreamAllOffersError() { + t.historyQ.On("StreamAllOffers", t.ctx, mock.Anything). + Return(fmt.Errorf("offers error")). Once() offersOk, err := t.stream.verifyAllOffers(t.ctx, t.graph.Offers()) - t.Assert().EqualError(err, "Error from GetAllOffers: offers error") + t.Assert().EqualError(err, "Error loading all offers for orderbook verification: offers error") t.Assert().False(offersOk) } func (t *VerifyOffersStreamTestSuite) TestEmptyDBOffers() { - var offers []history.Offer - t.historyQ.On("GetAllOffers", t.ctx).Return(offers, nil).Once() + t.historyQ.On("StreamAllOffers", t.ctx, mock.Anything).Return(nil).Once() offersOk, err := t.stream.verifyAllOffers(t.ctx, t.graph.Offers()) t.Assert().NoError(err) @@ -671,7 +682,15 @@ func (t *VerifyOffersStreamTestSuite) TestLengthMismatch() { LastModifiedLedger: 1, }, } - t.historyQ.On("GetAllOffers", t.ctx).Return(offers, nil).Once() + t.historyQ.On("StreamAllOffers", t.ctx, mock.Anything). + Return(nil). + Run(func(args mock.Arguments) { + callback := args.Get(1).(func(offer history.Offer) error) + for idx := range offers { + callback(offers[idx]) + } + }). + Once() offersOk, err := t.stream.verifyAllOffers(t.ctx, t.graph.Offers()) t.Assert().NoError(err) @@ -707,7 +726,16 @@ func (t *VerifyOffersStreamTestSuite) TestContentMismatch() { LastModifiedLedger: 1, }, } - t.historyQ.On("GetAllOffers", t.ctx).Return(offers, nil).Once() + + t.historyQ.On("StreamAllOffers", t.ctx, mock.Anything). + Return(nil). + Run(func(args mock.Arguments) { + callback := args.Get(1).(func(offer history.Offer) error) + for idx := range offers { + callback(offers[idx]) + } + }). + Once() t.stream.lastLedger = 300 offersOk, err := t.stream.verifyAllOffers(t.ctx, t.graph.Offers()) @@ -744,7 +772,15 @@ func (t *VerifyOffersStreamTestSuite) TestSuccess() { LastModifiedLedger: 1, }, } - t.historyQ.On("GetAllOffers", t.ctx).Return(offers, nil).Once() + t.historyQ.On("StreamAllOffers", t.ctx, mock.Anything). + Return(nil). + Run(func(args mock.Arguments) { + callback := args.Get(1).(func(offer history.Offer) error) + for idx := range offers { + callback(offers[idx]) + } + }). + Once() offersOk, err := t.stream.verifyAllOffers(t.ctx, t.graph.Offers()) t.Assert().NoError(err) @@ -812,19 +848,19 @@ func (t *VerifyLiquidityPoolsStreamTestSuite) TearDownTest() { t.graph.AssertExpectations(t.T()) } -func (t *VerifyLiquidityPoolsStreamTestSuite) TestGetAllLiquidityPoolsError() { - t.historyQ.MockQLiquidityPools.On("GetAllLiquidityPools", t.ctx). - Return([]history.LiquidityPool{}, fmt.Errorf("liquidity pools error")). +func (t *VerifyLiquidityPoolsStreamTestSuite) TestStreamAllLiquidityPoolsError() { + t.historyQ.MockQLiquidityPools.On("StreamAllLiquidityPools", t.ctx, mock.Anything). + Return(fmt.Errorf("liquidity pools error")). Once() liquidityPoolsOk, err := t.stream.verifyAllLiquidityPools(t.ctx, t.graph.LiquidityPools()) - t.Assert().EqualError(err, "Error from GetAllLiquidityPools: liquidity pools error") + t.Assert().EqualError(err, "Error loading all liquidity pools for orderbook verification: liquidity pools error") t.Assert().False(liquidityPoolsOk) } func (t *VerifyLiquidityPoolsStreamTestSuite) TestEmptyDBOffers() { - t.historyQ.MockQLiquidityPools.On("GetAllLiquidityPools", t.ctx). - Return([]history.LiquidityPool{}, nil). + t.historyQ.MockQLiquidityPools.On("StreamAllLiquidityPools", t.ctx, mock.Anything). + Return(nil). Once() liquidityPoolsOk, err := t.stream.verifyAllLiquidityPools(t.ctx, t.graph.LiquidityPools()) @@ -855,8 +891,14 @@ func (t *VerifyLiquidityPoolsStreamTestSuite) TestLengthMismatch() { }, } - t.historyQ.MockQLiquidityPools.On("GetAllLiquidityPools", t.ctx). - Return(liquidityPools, nil). + t.historyQ.MockQLiquidityPools.On("StreamAllLiquidityPools", t.ctx, mock.Anything). + Return(nil). + Run(func(args mock.Arguments) { + callback := args.Get(1).(func(offer history.LiquidityPool) error) + for idx := range liquidityPools { + callback(liquidityPools[idx]) + } + }). Once() liquidityPoolsOk, err := t.stream.verifyAllLiquidityPools(t.ctx, t.graph.LiquidityPools()) @@ -905,8 +947,14 @@ func (t *VerifyLiquidityPoolsStreamTestSuite) TestContentMismatch() { Deleted: false, }, } - t.historyQ.MockQLiquidityPools.On("GetAllLiquidityPools", t.ctx). - Return(liquidityPools, nil). + t.historyQ.MockQLiquidityPools.On("StreamAllLiquidityPools", t.ctx, mock.Anything). + Return(nil). + Run(func(args mock.Arguments) { + callback := args.Get(1).(func(offer history.LiquidityPool) error) + for idx := range liquidityPools { + callback(liquidityPools[idx]) + } + }). Once() liquidityPoolsOk, err := t.stream.verifyAllLiquidityPools(t.ctx, t.graph.LiquidityPools()) @@ -955,8 +1003,14 @@ func (t *VerifyLiquidityPoolsStreamTestSuite) TestSuccess() { Deleted: false, }, } - t.historyQ.MockQLiquidityPools.On("GetAllLiquidityPools", t.ctx). - Return(liquidityPools, nil). + t.historyQ.MockQLiquidityPools.On("StreamAllLiquidityPools", t.ctx, mock.Anything). + Return(nil). + Run(func(args mock.Arguments) { + callback := args.Get(1).(func(history.LiquidityPool) error) + for idx := range liquidityPools { + callback(liquidityPools[idx]) + } + }). Once() offersOk, err := t.stream.verifyAllLiquidityPools(t.ctx, t.graph.LiquidityPools()) diff --git a/support/db/main.go b/support/db/main.go index ed6e65285d..5a316899c5 100644 --- a/support/db/main.go +++ b/support/db/main.go @@ -127,6 +127,8 @@ type SessionInterface interface { GetRaw(ctx context.Context, dest interface{}, query string, args ...interface{}) error Select(ctx context.Context, dest interface{}, query squirrel.Sqlizer) error SelectRaw(ctx context.Context, dest interface{}, query string, args ...interface{}) error + Query(ctx context.Context, query squirrel.Sqlizer) (*sqlx.Rows, error) + QueryRaw(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) GetTable(name string) *Table Exec(ctx context.Context, query squirrel.Sqlizer) (sql.Result, error) ExecRaw(ctx context.Context, query string, args ...interface{}) (sql.Result, error) diff --git a/support/db/mock_session.go b/support/db/mock_session.go index 570afeaca3..9c3c4e7861 100644 --- a/support/db/mock_session.go +++ b/support/db/mock_session.go @@ -72,6 +72,16 @@ func (m *MockSession) GetRaw(ctx context.Context, dest interface{}, query string return argss.Error(0) } +func (m *MockSession) Query(ctx context.Context, query squirrel.Sqlizer) (*sqlx.Rows, error) { + args := m.Called(ctx, query) + return args.Get(0).(*sqlx.Rows), args.Error(1) +} + +func (m *MockSession) QueryRaw(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) { + argss := m.Called(ctx, query, args) + return argss.Get(0).(*sqlx.Rows), argss.Error(1) +} + func (m *MockSession) Select(ctx context.Context, dest interface{}, query squirrel.Sqlizer) error { argss := m.Called(ctx, dest, query) return argss.Error(0)