diff --git a/spec_pythonizer/sanity_check.py b/spec_pythonizer/sanity_check.py index df84719e..8371883d 100644 --- a/spec_pythonizer/sanity_check.py +++ b/spec_pythonizer/sanity_check.py @@ -1,5 +1,7 @@ from copy import deepcopy +import time +import spec from spec import ( # constants BLS_WITHDRAWAL_PREFIX_BYTE, @@ -15,6 +17,9 @@ SLOTS_PER_HISTORICAL_ROOT, ZERO_HASH, # SSZ + Bytes32, + List, + Epoch, Attestation, AttestationData, BeaconBlockHeader, @@ -25,6 +30,7 @@ Transfer, ProposerSlashing, Validator, + ValidatorIndex, VoluntaryExit, # functions int_to_bytes48, @@ -40,6 +46,8 @@ state_transition, cache_state, verify_merkle_branch, + hash_tree_root, + hash ) from utils.merkle_minimal import ( calc_merkle_tree_from_leaves, @@ -49,9 +57,18 @@ from hashlib import sha256 +def timeit(method): + def timed(*args, **kw): + ts = time.time() + result = method(*args, **kw) + te = time.time() -def hash(x): return sha256(x).digest() + print('%r %2.2f ms' % \ + (method.__name__, (te - ts) * 1000)) + return result + + return timed pubkeys = [int_to_bytes48(i) for i in range(10000)] all_deposit_data_leaves = list() @@ -173,6 +190,7 @@ def build_attestation_data(state, slot, shard): ) +@timeit def test_slot_transition(state): test_state = deepcopy(state) cache_state(test_state) @@ -182,6 +200,7 @@ def test_slot_transition(state): return test_state +@timeit def test_empty_block_transition(state): test_state = deepcopy(state) @@ -194,6 +213,7 @@ def test_empty_block_transition(state): return [block], test_state +@timeit def test_skipped_slots(state): test_state = deepcopy(state) block = construct_empty_block_for_next_slot(test_state) @@ -206,6 +226,7 @@ def test_skipped_slots(state): assert get_block_root(test_state, slot) == block.previous_block_root +@timeit def test_empty_epoch_transition(state): test_state = deepcopy(state) block = construct_empty_block_for_next_slot(test_state) @@ -218,6 +239,7 @@ def test_empty_epoch_transition(state): assert get_block_root(test_state, slot) == block.previous_block_root +@timeit def test_empty_epoch_transition_not_finalizing(state): test_state = deepcopy(state) block = construct_empty_block_for_next_slot(test_state) @@ -229,6 +251,7 @@ def test_empty_epoch_transition_not_finalizing(state): assert test_state.finalized_epoch < get_current_epoch(test_state) - 4 +@timeit def test_proposer_slashing(state): test_state = deepcopy(state) current_epoch = get_current_epoch(test_state) @@ -269,6 +292,7 @@ def test_proposer_slashing(state): assert test_state.validator_balances[validator_index] < state.validator_balances[validator_index] +@timeit def test_deposit_in_block(state): test_state = deepcopy(state) test_deposit_data_leaves = deepcopy(all_deposit_data_leaves) @@ -309,6 +333,7 @@ def test_deposit_in_block(state): assert test_state.validator_registry[index].pubkey == pubkeys[index] +@timeit def test_attestation(state): test_state = deepcopy(state) current_epoch = get_current_epoch(test_state) @@ -353,6 +378,7 @@ def test_attestation(state): assert test_state.previous_epoch_attestations == pre_current_epoch_attestations +@timeit def test_voluntary_exit(state): test_state = deepcopy(state) current_epoch = get_current_epoch(test_state) @@ -391,6 +417,7 @@ def test_voluntary_exit(state): assert test_state.validator_registry[validator_index].exit_epoch < FAR_FUTURE_EPOCH +@timeit def test_transfer(state): test_state = deepcopy(state) current_epoch = get_current_epoch(test_state) @@ -429,6 +456,7 @@ def test_transfer(state): assert recipient_balance == pre_transfer_recipient_balance + amount +@timeit def test_ejection(state): test_state = deepcopy(state) @@ -451,6 +479,7 @@ def test_ejection(state): assert test_state.validator_registry[validator_index].exit_epoch < FAR_FUTURE_EPOCH +@timeit def test_historical_batch(state): test_state = deepcopy(state) @@ -464,39 +493,70 @@ def test_historical_batch(state): assert len(test_state.historical_roots) == len(state.historical_roots) + 1 +@timeit def sanity_tests(): print("Buidling state with 100 validators...") genesis_state = create_genesis_state(num_validators=100) print("done!") print() - print("Running some sanity check tests...") + print("Running some sanity check tests...\n") test_slot_transition(genesis_state) - print("Passed slot transition test") + print("Passed slot transition test\n") test_empty_block_transition(genesis_state) - print("Passed empty block transition test") + print("Passed empty block transition test\n") test_skipped_slots(genesis_state) - print("Passed skipped slot test") + print("Passed skipped slot test\n") test_empty_epoch_transition(genesis_state) - print("Passed empty epoch transition test") + print("Passed empty epoch transition test\n") test_empty_epoch_transition_not_finalizing(genesis_state) - print("Passed non-finalizing epoch test") + print("Passed non-finalizing epoch test\n") test_proposer_slashing(genesis_state) - print("Passed proposer slashing test") + print("Passed proposer slashing test\n") test_attestation(genesis_state) - print("Passed attestation test") + print("Passed attestation test\n") test_deposit_in_block(genesis_state) - print("Passed deposit test") + print("Passed deposit test\n") test_voluntary_exit(genesis_state) - print("Passed voluntary exit test") + print("Passed voluntary exit test\n") test_transfer(genesis_state) - print("Passed transfer test") + print("Passed transfer test\n") test_ejection(genesis_state) - print("Passed ejection test") + print("Passed ejection test\n") test_historical_batch(genesis_state) - print("Passed historical batch test") + print("Passed historical batch test\n") print("done!") +# Monkey patch validator shuffling cache +_get_shuffling = spec.get_shuffling +shuffling_cache = {} +def get_shuffling(seed: Bytes32, + validators: List[Validator], + epoch: Epoch) -> List[List[ValidatorIndex]]: + + param_hash = (seed, hash_tree_root(validators, [Validator]), epoch) + + if param_hash in shuffling_cache: + #print("Cache hit, epoch={0}".format(epoch)) + return shuffling_cache[param_hash] + else: + #print("Cache miss, epoch={0}".format(epoch)) + ret = _get_shuffling(seed, validators, epoch) + shuffling_cache[param_hash] = ret + return ret + +spec.get_shuffling = get_shuffling + +hash_cache = {} +def hash(x): + if x in hash_cache: + return hash_cache[x] + else: + ret = sha256(x).digest() + hash_cache[x] = ret + return ret + +spec.hash = hash if __name__ == "__main__": sanity_tests()