diff --git a/tests/test_agreements_pool.py b/tests/test_agreements_pool.py index 69070853a..340b1cabb 100644 --- a/tests/test_agreements_pool.py +++ b/tests/test_agreements_pool.py @@ -28,6 +28,12 @@ async def create_agreement(): return create_agreement +def get_agreements_pool() -> agreements_pool.AgreementsPool: + return agreements_pool.AgreementsPool( + lambda _event, **kwargs: None, lambda _offer: None, mock.Mock() + ) + + @pytest.mark.asyncio async def test_use_agreement_chooses_max_score(): """Test that a proposal with the largest score is chosen in AgreementsPool.use_agreement().""" @@ -40,7 +46,7 @@ async def test_use_agreement_chooses_max_score(): mock_score = random.random() proposals[n] = (mock_score, mock_proposal) - pool = agreements_pool.AgreementsPool(lambda _event, **kwargs: None, lambda _offer: None) + pool = get_agreements_pool() for score, proposal in proposals.values(): await pool.add_proposal(score, proposal) @@ -76,7 +82,7 @@ async def test_use_agreement_shuffles_proposals(): mock_score = 42.0 if n != 0 else 41.0 proposals.append((mock_score, mock_proposal)) - pool = agreements_pool.AgreementsPool(lambda _event, **kwargs: None, lambda _offer: None) + pool = get_agreements_pool() for score, proposal in proposals: await pool.add_proposal(score, proposal) @@ -95,7 +101,7 @@ def use_agreement_cb(agreement): async def test_use_agreement_no_proposals(): """Test that `AgreementPool.use_agreement()` returns `None` when there are no proposals.""" - pool = agreements_pool.AgreementsPool(lambda _event, **kwargs: None, lambda _offer: None) + pool = get_agreements_pool() def use_agreement_cb(_agreement): assert False, "use_agreement callback called" @@ -120,7 +126,7 @@ async def test_terminate_agreement(multi_activity, simulate_race, event_emitted) events = [] pool = agreements_pool.AgreementsPool( - lambda event, **kwargs: events.append(event), lambda _offer: None # noqa + lambda event, **kwargs: events.append(event), lambda _offer: None, mock.Mock() # noqa ) agreement: BufferedAgreement = BufferedAgreementFactory(has_multi_activity=multi_activity) pool._agreements[agreement.agreement.id] = agreement