Skip to content

Commit

Permalink
Add monkey patch for validator shuffling cache
Browse files Browse the repository at this point in the history
  • Loading branch information
dankrad committed Mar 12, 2019
1 parent 812d961 commit f17fcbf
Showing 1 changed file with 74 additions and 14 deletions.
88 changes: 74 additions & 14 deletions spec_pythonizer/sanity_check.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from copy import deepcopy
import time

import spec
from spec import (
# constants
BLS_WITHDRAWAL_PREFIX_BYTE,
Expand All @@ -15,6 +17,9 @@
SLOTS_PER_HISTORICAL_ROOT,
ZERO_HASH,
# SSZ
Bytes32,
List,
Epoch,
Attestation,
AttestationData,
BeaconBlockHeader,
Expand All @@ -25,6 +30,7 @@
Transfer,
ProposerSlashing,
Validator,
ValidatorIndex,
VoluntaryExit,
# functions
int_to_bytes48,
Expand All @@ -40,6 +46,8 @@
state_transition,
cache_state,
verify_merkle_branch,
hash_tree_root,
hash
)
from utils.merkle_minimal import (
calc_merkle_tree_from_leaves,
Expand All @@ -49,9 +57,18 @@

from hashlib import sha256

def timeit(method):
def timed(*args, **kw):
ts = time.time()
result = method(*args, **kw)
te = time.time()

def hash(x): return sha256(x).digest()
print('%r %2.2f ms' % \
(method.__name__, (te - ts) * 1000))

return result

return timed

pubkeys = [int_to_bytes48(i) for i in range(10000)]
all_deposit_data_leaves = list()
Expand Down Expand Up @@ -173,6 +190,7 @@ def build_attestation_data(state, slot, shard):
)


@timeit
def test_slot_transition(state):
test_state = deepcopy(state)
cache_state(test_state)
Expand All @@ -182,6 +200,7 @@ def test_slot_transition(state):
return test_state


@timeit
def test_empty_block_transition(state):
test_state = deepcopy(state)

Expand All @@ -194,6 +213,7 @@ def test_empty_block_transition(state):
return [block], test_state


@timeit
def test_skipped_slots(state):
test_state = deepcopy(state)
block = construct_empty_block_for_next_slot(test_state)
Expand All @@ -206,6 +226,7 @@ def test_skipped_slots(state):
assert get_block_root(test_state, slot) == block.previous_block_root


@timeit
def test_empty_epoch_transition(state):
test_state = deepcopy(state)
block = construct_empty_block_for_next_slot(test_state)
Expand All @@ -218,6 +239,7 @@ def test_empty_epoch_transition(state):
assert get_block_root(test_state, slot) == block.previous_block_root


@timeit
def test_empty_epoch_transition_not_finalizing(state):
test_state = deepcopy(state)
block = construct_empty_block_for_next_slot(test_state)
Expand All @@ -229,6 +251,7 @@ def test_empty_epoch_transition_not_finalizing(state):
assert test_state.finalized_epoch < get_current_epoch(test_state) - 4


@timeit
def test_proposer_slashing(state):
test_state = deepcopy(state)
current_epoch = get_current_epoch(test_state)
Expand Down Expand Up @@ -269,6 +292,7 @@ def test_proposer_slashing(state):
assert test_state.validator_balances[validator_index] < state.validator_balances[validator_index]


@timeit
def test_deposit_in_block(state):
test_state = deepcopy(state)
test_deposit_data_leaves = deepcopy(all_deposit_data_leaves)
Expand Down Expand Up @@ -309,6 +333,7 @@ def test_deposit_in_block(state):
assert test_state.validator_registry[index].pubkey == pubkeys[index]


@timeit
def test_attestation(state):
test_state = deepcopy(state)
current_epoch = get_current_epoch(test_state)
Expand Down Expand Up @@ -353,6 +378,7 @@ def test_attestation(state):
assert test_state.previous_epoch_attestations == pre_current_epoch_attestations


@timeit
def test_voluntary_exit(state):
test_state = deepcopy(state)
current_epoch = get_current_epoch(test_state)
Expand Down Expand Up @@ -391,6 +417,7 @@ def test_voluntary_exit(state):
assert test_state.validator_registry[validator_index].exit_epoch < FAR_FUTURE_EPOCH


@timeit
def test_transfer(state):
test_state = deepcopy(state)
current_epoch = get_current_epoch(test_state)
Expand Down Expand Up @@ -429,6 +456,7 @@ def test_transfer(state):
assert recipient_balance == pre_transfer_recipient_balance + amount


@timeit
def test_ejection(state):
test_state = deepcopy(state)

Expand All @@ -451,6 +479,7 @@ def test_ejection(state):
assert test_state.validator_registry[validator_index].exit_epoch < FAR_FUTURE_EPOCH


@timeit
def test_historical_batch(state):
test_state = deepcopy(state)

Expand All @@ -464,39 +493,70 @@ def test_historical_batch(state):
assert len(test_state.historical_roots) == len(state.historical_roots) + 1


@timeit
def sanity_tests():
print("Buidling state with 100 validators...")
genesis_state = create_genesis_state(num_validators=100)
print("done!")
print()

print("Running some sanity check tests...")
print("Running some sanity check tests...\n")
test_slot_transition(genesis_state)
print("Passed slot transition test")
print("Passed slot transition test\n")
test_empty_block_transition(genesis_state)
print("Passed empty block transition test")
print("Passed empty block transition test\n")
test_skipped_slots(genesis_state)
print("Passed skipped slot test")
print("Passed skipped slot test\n")
test_empty_epoch_transition(genesis_state)
print("Passed empty epoch transition test")
print("Passed empty epoch transition test\n")
test_empty_epoch_transition_not_finalizing(genesis_state)
print("Passed non-finalizing epoch test")
print("Passed non-finalizing epoch test\n")
test_proposer_slashing(genesis_state)
print("Passed proposer slashing test")
print("Passed proposer slashing test\n")
test_attestation(genesis_state)
print("Passed attestation test")
print("Passed attestation test\n")
test_deposit_in_block(genesis_state)
print("Passed deposit test")
print("Passed deposit test\n")
test_voluntary_exit(genesis_state)
print("Passed voluntary exit test")
print("Passed voluntary exit test\n")
test_transfer(genesis_state)
print("Passed transfer test")
print("Passed transfer test\n")
test_ejection(genesis_state)
print("Passed ejection test")
print("Passed ejection test\n")
test_historical_batch(genesis_state)
print("Passed historical batch test")
print("Passed historical batch test\n")
print("done!")

# Monkey patch validator shuffling cache
_get_shuffling = spec.get_shuffling
shuffling_cache = {}
def get_shuffling(seed: Bytes32,
validators: List[Validator],
epoch: Epoch) -> List[List[ValidatorIndex]]:

param_hash = (seed, hash_tree_root(validators, [Validator]), epoch)

if param_hash in shuffling_cache:
#print("Cache hit, epoch={0}".format(epoch))
return shuffling_cache[param_hash]
else:
#print("Cache miss, epoch={0}".format(epoch))
ret = _get_shuffling(seed, validators, epoch)
shuffling_cache[param_hash] = ret
return ret

spec.get_shuffling = get_shuffling

hash_cache = {}
def hash(x):
if x in hash_cache:
return hash_cache[x]
else:
ret = sha256(x).digest()
hash_cache[x] = ret
return ret

spec.hash = hash

if __name__ == "__main__":
sanity_tests()

0 comments on commit f17fcbf

Please sign in to comment.