diff --git a/src/modules/ejector/ejector.py b/src/modules/ejector/ejector.py index 7fc0f9119..b7c3be5e6 100644 --- a/src/modules/ejector/ejector.py +++ b/src/modules/ejector/ejector.py @@ -306,10 +306,11 @@ def _get_churn_limit(self, blockstamp: ReferenceBlockStamp) -> int: return churn_limit def _get_total_active_validators(self, blockstamp: ReferenceBlockStamp) -> int: - total_active_validators = len([ - is_active_validator(val, blockstamp.ref_epoch) - for val in self.w3.cc.get_validators(blockstamp) - ]) + total_active_validators = reduce( + lambda total, validator: total + int(is_active_validator(validator, blockstamp.ref_epoch)), + self.w3.cc.get_validators(blockstamp), + 0, + ) logger.info({'msg': 'Calculate total active validators.', 'value': total_active_validators}) return total_active_validators diff --git a/tests/factory/no_registry.py b/tests/factory/no_registry.py index 6e23eb701..94775a80c 100644 --- a/tests/factory/no_registry.py +++ b/tests/factory/no_registry.py @@ -53,6 +53,33 @@ def build_with_activation_epoch_bound(cls, max_value: int, **kwargs: Any): validator=ValidatorStateFactory.build(activation_epoch=str(faker.pyint(max_value=max_value - 1))), **kwargs ) + @classmethod + def build_not_active_vals(cls, epoch, **kwargs: Any): + return cls.build( + validator=ValidatorStateFactory.build( + activation_epoch=str(faker.pyint(min_value=epoch, max_value=FAR_FUTURE_EPOCH)), + exit_epoch=str(faker.pyint(min_value=FAR_FUTURE_EPOCH, max_value=FAR_FUTURE_EPOCH)), + ), **kwargs + ) + + @classmethod + def build_active_vals(cls, epoch, **kwargs: Any): + return cls.build( + validator=ValidatorStateFactory.build( + activation_epoch=str(faker.pyint(min_value=0, max_value=epoch - 1)), + exit_epoch=str(faker.pyint(min_value=epoch + 1, max_value=FAR_FUTURE_EPOCH)), + ), **kwargs + ) + + @classmethod + def build_exit_vals(cls, epoch, **kwargs: Any): + return cls.build( + validator=ValidatorStateFactory.build( + activation_epoch=str(faker.pyint(min_value=0, max_value=epoch - 1)), + exit_epoch=str(faker.pyint(min_value=0, max_value=epoch)), + ), **kwargs + ) + class NodeOperatorFactory(Web3Factory): __model__ = NodeOperator diff --git a/tests/modules/ejector/test_ejector.py b/tests/modules/ejector/test_ejector.py index 0b80ba7e9..d63f5693a 100644 --- a/tests/modules/ejector/test_ejector.py +++ b/tests/modules/ejector/test_ejector.py @@ -234,6 +234,19 @@ def test_get_predicted_withdrawable_epoch(ejector: Ejector) -> None: assert result == 3809, "Unexpected predicted withdrawable epoch" +@pytest.mark.unit +def test_get_total_active_validators(ejector: Ejector) -> None: + ref_blockstamp = ReferenceBlockStampFactory.build(ref_epoch=3546) + ejector.w3 = Mock() + ejector.w3.cc.get_validators = Mock(return_value=[ + *[LidoValidatorFactory.build_not_active_vals(ref_blockstamp.ref_epoch) for _ in range(100)], + *[LidoValidatorFactory.build_active_vals(ref_blockstamp.ref_epoch) for _ in range(100)], + *[LidoValidatorFactory.build_exit_vals(ref_blockstamp.ref_epoch) for _ in range(100)], + ]) + + assert ejector._get_total_active_validators(ref_blockstamp) == 100 + + @pytest.mark.unit @pytest.mark.usefixtures("consensus_client", "lido_validators") def test_get_withdrawable_lido_validators(