Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
hwwhww committed Apr 30, 2020
1 parent 8c21735 commit b1b739d
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 43 deletions.
56 changes: 29 additions & 27 deletions specs/phase1/shard-fork-choice.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,13 @@ This document is the shard chain fork choice spec for part of Ethereum 2.0 Phase

```python
@dataclass
class Store(object):
class Store:

@dataclass
class ShardStore:
blocks: Dict[Root, ShardBlock] = field(default_factory=dict)
block_states: Dict[Root, ShardState] = field(default_factory=dict)

time: uint64
genesis_time: uint64
justified_checkpoint: Checkpoint
Expand All @@ -46,17 +52,13 @@ class Store(object):
checkpoint_states: Dict[Checkpoint, BeaconState] = field(default_factory=dict)
latest_messages: Dict[ValidatorIndex, LatestMessage] = field(default_factory=dict)
# shard chain
shard_init_slots: Dict[Shard, Slot] = field(default_factory=dict)
shard_blocks: Dict[Shard, Dict[Root, ShardBlock]] = field(default_factory=dict)
shard_block_states: Dict[Shard, Dict[Root, ShardState]] = field(default_factory=dict)
shards: Dict[Shard, ShardStore] = field(default_factory=dict) # noqa: F821
```

#### Updated `get_forkchoice_store`

```python
def get_forkchoice_store(anchor_state: BeaconState,
shard_init_slots: Dict[Shard, Slot],
anchor_state_shard_blocks: Dict[Shard, Dict[Root, ShardBlock]]) -> Store:
def get_forkchoice_store(anchor_state: BeaconState) -> Store:
shard_count = len(anchor_state.shard_states)
anchor_block_header = anchor_state.latest_block_header.copy()
if anchor_block_header.state_root == Bytes32():
Expand All @@ -65,6 +67,14 @@ def get_forkchoice_store(anchor_state: BeaconState,
anchor_epoch = get_current_epoch(anchor_state)
justified_checkpoint = Checkpoint(epoch=anchor_epoch, root=anchor_root)
finalized_checkpoint = Checkpoint(epoch=anchor_epoch, root=anchor_root)

shard_stores = {}
for shard in map(Shard, range(shard_count)):
shard_stores[shard] = Store.ShardStore(
blocks={anchor_state.shard_states[shard].latest_block_root: ShardBlock(slot=anchor_state.slot)},
block_states={anchor_state.shard_states[shard].latest_block_root: anchor_state.copy().shard_states[shard]},
)

return Store(
time=anchor_state.genesis_time + SECONDS_PER_SLOT * anchor_state.slot,
genesis_time=anchor_state.genesis_time,
Expand All @@ -75,14 +85,7 @@ def get_forkchoice_store(anchor_state: BeaconState,
block_states={anchor_root: anchor_state.copy()},
checkpoint_states={justified_checkpoint: anchor_state.copy()},
# shard chain
shard_init_slots=shard_init_slots,
shard_blocks=anchor_state_shard_blocks,
shard_block_states={
shard: {
anchor_state.shard_states[shard].latest_block_root: anchor_state.copy().shard_states[shard]
}
for shard in map(Shard, range(shard_count))
},
shards=shard_stores,
)
```

Expand All @@ -96,7 +99,7 @@ def get_shard_latest_attesting_balance(store: Store, shard: Shard, root: Root) -
state.validators[i].effective_balance for i in active_indices
if (
i in store.latest_messages and get_shard_ancestor(
store, shard, store.latest_messages[i].root, store.shard_blocks[shard][root].slot
store, shard, store.latest_messages[i].root, store.shards[shard].blocks[root].slot
) == root
)
))
Expand Down Expand Up @@ -127,7 +130,7 @@ def get_shard_head(store: Store, shard: Shard) -> Root:

```python
def get_shard_ancestor(store: Store, shard: Shard, root: Root, slot: Slot) -> Root:
block = store.shard_blocks[shard][root]
block = store.shards[shard].blocks[root]
if block.slot > slot:
return get_shard_ancestor(store, shard, block.shard_parent_root, slot)
elif block.slot == slot:
Expand All @@ -141,13 +144,11 @@ def get_shard_ancestor(store: Store, shard: Shard, root: Root, slot: Slot) -> Ro

```python
def filter_shard_block_tree(store: Store, shard: Shard, block_root: Root, blocks: Dict[Root, ShardBlock]) -> bool:
block = store.shard_blocks[shard][block_root]
shard_store = store.shards[shard]
block = shard_store.blocks[block_root]
children = [
root for root in store.shard_blocks[shard].keys()
if (
store.shard_blocks[shard][root].shard_parent_root == block_root
and store.shard_blocks[shard][root].slot != store.shard_init_slots[shard]
)
root for root in shard_store.blocks.keys()
if shard_store.blocks[root].shard_parent_root == block_root
]

if any(children):
Expand Down Expand Up @@ -178,10 +179,11 @@ def get_filtered_shard_block_tree(store: Store, shard: Shard) -> Dict[Root, Shar
```python
def on_shard_block(store: Store, shard: Shard, signed_shard_block: SignedShardBlock) -> None:
shard_block = signed_shard_block.message
shard_store = store.shards[shard]

# 1. Check shard parent exists
assert shard_block.shard_parent_root in store.shard_block_states[shard]
pre_shard_state = store.shard_block_states[shard][shard_block.shard_parent_root]
assert shard_block.shard_parent_root in shard_store.block_states
pre_shard_state = shard_store.block_states[shard_block.shard_parent_root]

# 2. Check beacon parent exists
assert shard_block.beacon_parent_root in store.block_states
Expand All @@ -198,12 +200,12 @@ def on_shard_block(store: Store, shard: Shard, signed_shard_block: SignedShardBl
)

# Add new block to the store
store.shard_blocks[shard][hash_tree_root(shard_block)] = shard_block
shard_store.blocks[hash_tree_root(shard_block)] = shard_block

# Check the block is valid and compute the post-state
verify_shard_block_message(beacon_state, pre_shard_state, shard_block, shard_block.slot, shard)
verify_shard_block_signature(beacon_state, signed_shard_block)
post_state = get_post_shard_state(beacon_state, pre_shard_state, shard_block)
# Add new state for this block to the store
store.shard_block_states[shard][hash_tree_root(shard_block)] = post_state
shard_store.block_states[hash_tree_root(shard_block)] = post_state
```
18 changes: 2 additions & 16 deletions tests/core/pyspec/eth2spec/test/fork_choice/test_on_shard_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def run_on_shard_block(spec, store, shard, signed_block, valid=True):
assert False

spec.on_shard_block(store, shard, signed_block)
assert store.shard_blocks[shard][hash_tree_root(signed_block.message)] == signed_block.message
assert store.shards[shard].blocks[hash_tree_root(signed_block.message)] == signed_block.message


def run_apply_shard_and_beacon(spec, state, store, shard, committee_index):
Expand Down Expand Up @@ -72,21 +72,7 @@ def test_basic(spec, state):
next_slot(spec, state)

# Initialization
shard_count = len(state.shard_states)
# Genesis shard blocks
anchor_shard_blocks = {
shard: {
state.shard_states[shard].latest_block_root: spec.ShardBlock(
slot=state.slot,
)
}
for shard in map(spec.Shard, range(shard_count))
}
shard_init_slots = {
shard: state.slot
for shard in map(spec.Shard, range(shard_count))
}
store = spec.get_forkchoice_store(state, shard_init_slots, anchor_shard_blocks)
store = spec.get_forkchoice_store(state)
anchor_root = get_anchor_root(spec, state)
assert spec.get_head(store) == anchor_root

Expand Down

0 comments on commit b1b739d

Please sign in to comment.