Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix shard fork choice #1970

Merged
merged 10 commits into from
Jul 29, 2020
54 changes: 36 additions & 18 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from enum import Enum, auto
from setuptools import setup, find_packages, Command
from setuptools.command.build_py import build_py
from distutils import dir_util
Expand All @@ -14,6 +15,13 @@ class SpecObject(NamedTuple):
custom_types: Dict[str, str]
constants: Dict[str, str]
ssz_objects: Dict[str, str]
dataclasses: Dict[str, str]


class CodeBlockType(Enum):
SSZ = auto()
DATACLASS = auto()
FUNCTION = auto()


def get_spec(file_name: str) -> SpecObject:
Expand All @@ -28,8 +36,9 @@ def get_spec(file_name: str) -> SpecObject:
functions: Dict[str, str] = {}
constants: Dict[str, str] = {}
ssz_objects: Dict[str, str] = {}
dataclasses: Dict[str, str] = {}
function_matcher = re.compile(FUNCTION_REGEX)
is_ssz = False
block_type = CodeBlockType.FUNCTION
custom_types: Dict[str, str] = {}
for linenum, line in enumerate(open(file_name).readlines()):
line = line.rstrip()
Expand All @@ -43,20 +52,26 @@ def get_spec(file_name: str) -> SpecObject:
else:
# Handle function definitions & ssz_objects
if pulling_from is not None:
# SSZ Object
if len(line) > 18 and line[:6] == 'class ' and line[-12:] == '(Container):':
name = line[6:-12]
# Check consistency with markdown header
assert name == current_name
is_ssz = True
# function definition
block_type = CodeBlockType.SSZ
elif line[:10] == '@dataclass':
block_type = CodeBlockType.DATACLASS
elif function_matcher.match(line) is not None:
current_name = function_matcher.match(line).group(0)
is_ssz = False
if is_ssz:
block_type = CodeBlockType.FUNCTION

if block_type == CodeBlockType.SSZ:
ssz_objects[current_name] = ssz_objects.get(current_name, '') + line + '\n'
else:
elif block_type == CodeBlockType.DATACLASS:
dataclasses[current_name] = dataclasses.get(current_name, '') + line + '\n'
elif block_type == CodeBlockType.FUNCTION:
functions[current_name] = functions.get(current_name, '') + line + '\n'
else:
pass

# Handle constant and custom types table entries
elif pulling_from is None and len(line) > 0 and line[0] == '|':
row = line[1:].split('|')
Expand All @@ -75,7 +90,7 @@ def get_spec(file_name: str) -> SpecObject:
constants[row[0]] = row[1].replace('**TBD**', '2**32')
elif row[1].startswith('uint') or row[1].startswith('Bytes'):
custom_types[row[0]] = row[1]
return SpecObject(functions, custom_types, constants, ssz_objects)
return SpecObject(functions, custom_types, constants, ssz_objects, dataclasses)


CONFIG_LOADER = '''
Expand Down Expand Up @@ -220,7 +235,7 @@ def wrapper(*args, **kw): # type: ignore
_get_start_shard, lru_size=SLOTS_PER_EPOCH * 3)'''


def objects_to_spec(spec_object: SpecObject, imports: str, fork: str) -> str:
def objects_to_spec(spec_object: SpecObject, imports: str, fork: str, ordered_class_objects: Dict[str, str]) -> str:
"""
Given all the objects that constitute a spec, combine them into a single pyfile.
"""
Expand All @@ -240,15 +255,15 @@ def objects_to_spec(spec_object: SpecObject, imports: str, fork: str) -> str:
if k == "BLS12_381_Q":
spec_object.constants[k] += " # noqa: E501"
constants_spec = '\n'.join(map(lambda x: '%s = %s' % (x, spec_object.constants[x]), spec_object.constants))
ssz_objects_instantiation_spec = '\n\n'.join(spec_object.ssz_objects.values())
ordered_class_objects_spec = '\n\n'.join(ordered_class_objects.values())
spec = (
imports
+ '\n\n' + f"fork = \'{fork}\'\n"
+ '\n\n' + new_type_definitions
+ '\n' + SUNDRY_CONSTANTS_FUNCTIONS
+ '\n\n' + constants_spec
+ '\n\n' + CONFIG_LOADER
+ '\n\n' + ssz_objects_instantiation_spec
+ '\n\n' + ordered_class_objects_spec
+ '\n\n' + functions_spec
+ '\n' + PHASE0_SUNDRY_FUNCTIONS
)
Expand All @@ -274,11 +289,12 @@ def combine_constants(old_constants: Dict[str, str], new_constants: Dict[str, st
'bit', 'boolean', 'Vector', 'List', 'Container', 'BLSPubkey', 'BLSSignature',
'Bytes1', 'Bytes4', 'Bytes32', 'Bytes48', 'Bytes96', 'Bitlist', 'Bitvector',
'uint8', 'uint16', 'uint32', 'uint64', 'uint128', 'uint256',
'bytes', 'byte', 'ByteList', 'ByteVector'
'bytes', 'byte', 'ByteList', 'ByteVector',
'Dict', 'dict', 'field',
]


def dependency_order_ssz_objects(objects: Dict[str, str], custom_types: Dict[str, str]) -> None:
def dependency_order_class_objects(objects: Dict[str, str], custom_types: Dict[str, str]) -> None:
"""
Determines which SSZ Object is dependent on which other and orders them appropriately
"""
Expand Down Expand Up @@ -315,13 +331,14 @@ def combine_spec_objects(spec0: SpecObject, spec1: SpecObject) -> SpecObject:
"""
Takes in two spec variants (as tuples of their objects) and combines them using the appropriate combiner function.
"""
functions0, custom_types0, constants0, ssz_objects0 = spec0
functions1, custom_types1, constants1, ssz_objects1 = spec1
functions0, custom_types0, constants0, ssz_objects0, dataclasses0 = spec0
functions1, custom_types1, constants1, ssz_objects1, dataclasses1 = spec1
functions = combine_functions(functions0, functions1)
custom_types = combine_constants(custom_types0, custom_types1)
constants = combine_constants(constants0, constants1)
ssz_objects = combine_ssz_objects(ssz_objects0, ssz_objects1, custom_types)
return SpecObject(functions, custom_types, constants, ssz_objects)
dataclasses = combine_functions(dataclasses0, dataclasses1)
return SpecObject(functions, custom_types, constants, ssz_objects, dataclasses)


fork_imports = {
Expand All @@ -337,9 +354,10 @@ def build_spec(fork: str, source_files: List[str]) -> str:
for value in all_specs[1:]:
spec_object = combine_spec_objects(spec_object, value)

dependency_order_ssz_objects(spec_object.ssz_objects, spec_object.custom_types)
class_objects = {**spec_object.ssz_objects, **spec_object.dataclasses}
dependency_order_class_objects(class_objects, spec_object.custom_types)

return objects_to_spec(spec_object, fork_imports[fork], fork)
return objects_to_spec(spec_object, fork_imports[fork], fork, class_objects)


class PySpecCommand(Command):
Expand Down
80 changes: 71 additions & 9 deletions specs/phase1/fork-choice.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,13 @@


- [Introduction](#introduction)
- [Helpers](#helpers)
- [Extended `LatestMessage`](#extended-latestmessage)
- [Updated data structures](#updated-data-structures)
- [Extended `Store`](#extended-store)
- [New data structures](#new-data-structures)
- [`ShardLatestMessage`](#shardlatestmessage)
- [`ShardStore`](#shardstore)
- [Updated helpers](#updated-helpers)
- [Updated `get_forkchoice_store`](#updated-get_forkchoice_store)
- [Updated `update_latest_messages`](#updated-update_latest_messages)

<!-- END doctoc generated TOC please keep comment here to allow auto update -->
Expand All @@ -20,17 +25,74 @@

This document is the beacon chain fork choice spec for part of Ethereum 2.0 Phase 1.

### Helpers
### Updated data structures

#### Extended `LatestMessage`
#### Extended `Store`

```python
@dataclass
class Store(object):
time: uint64
genesis_time: uint64
justified_checkpoint: Checkpoint
finalized_checkpoint: Checkpoint
best_justified_checkpoint: Checkpoint
blocks: Dict[Root, BeaconBlock] = field(default_factory=dict)
block_states: Dict[Root, BeaconState] = field(default_factory=dict)
checkpoint_states: Dict[Checkpoint, BeaconState] = field(default_factory=dict)
latest_messages: Dict[ValidatorIndex, LatestMessage] = field(default_factory=dict)
shard_stores: Dict[Shard, ShardStore] = field(default_factory=dict)
```

### New data structures

#### `ShardLatestMessage`

```python
@dataclass(eq=True, frozen=True)
class LatestMessage(object):
class ShardLatestMessage(object):
epoch: Epoch
root: Root
```

#### `ShardStore`

```python
@dataclass
class ShardStore:
shard: Shard
shard_root: Root
signed_blocks: Dict[Root, SignedShardBlock] = field(default_factory=dict)
block_states: Dict[Root, ShardState] = field(default_factory=dict)
latest_messages: Dict[ValidatorIndex, ShardLatestMessage] = field(default_factory=dict)
```

### Updated helpers

#### Updated `get_forkchoice_store`

```python
def get_forkchoice_store(anchor_state: BeaconState) -> Store:
anchor_block_header = anchor_state.latest_block_header.copy()
if anchor_block_header.state_root == Bytes32():
anchor_block_header.state_root = hash_tree_root(anchor_state)
anchor_root = hash_tree_root(anchor_block_header)
anchor_epoch = get_current_epoch(anchor_state)
justified_checkpoint = Checkpoint(epoch=anchor_epoch, root=anchor_root)
finalized_checkpoint = Checkpoint(epoch=anchor_epoch, root=anchor_root)
return Store(
time=anchor_state.genesis_time + SECONDS_PER_SLOT * anchor_state.slot,
genesis_time=anchor_state.genesis_time,
justified_checkpoint=justified_checkpoint,
finalized_checkpoint=finalized_checkpoint,
best_justified_checkpoint=justified_checkpoint,
blocks={anchor_root: anchor_block_header},
block_states={anchor_root: anchor_state.copy()},
checkpoint_states={justified_checkpoint: anchor_state.copy()},
shard_stores={
Shard(shard): get_forkchoice_shard_store(anchor_state, Shard(shard))
for shard in range(get_active_shard_count(anchor_state))
}
)
```

#### Updated `update_latest_messages`
Expand All @@ -43,7 +105,7 @@ def update_latest_messages(store: Store, attesting_indices: Sequence[ValidatorIn
shard = attestation.data.shard
for i in attesting_indices:
if i not in store.latest_messages or target.epoch > store.latest_messages[i].epoch:
store.latest_messages[i] = LatestMessage(
epoch=target.epoch, root=beacon_block_root, shard=shard, shard_root=attestation.data.shard_head_root
)
store.latest_messages[i] = LatestMessage(epoch=target.epoch, root=beacon_block_root)
shard_latest_message = ShardLatestMessage(epoch=target.epoch, root=attestation.data.shard_head_root)
store.shard_stores[shard].latest_messages[i] = shard_latest_message
```
52 changes: 23 additions & 29 deletions specs/phase1/shard-fork-choice.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
- [Introduction](#introduction)
- [Fork choice](#fork-choice)
- [Helpers](#helpers)
- [`ShardStore`](#shardstore)
- [`get_forkchoice_shard_store`](#get_forkchoice_shard_store)
- [`get_shard_latest_attesting_balance`](#get_shard_latest_attesting_balance)
- [`get_shard_head`](#get_shard_head)
Expand All @@ -30,16 +29,6 @@ This document is the shard chain fork choice spec for part of Ethereum 2.0 Phase

### Helpers

#### `ShardStore`

```python
@dataclass
class ShardStore:
shard: Shard
signed_blocks: Dict[Root, SignedShardBlock] = field(default_factory=dict)
block_states: Dict[Root, ShardState] = field(default_factory=dict)
```

#### `get_forkchoice_shard_store`

```python
Expand All @@ -58,18 +47,21 @@ def get_forkchoice_shard_store(anchor_state: BeaconState, shard: Shard) -> Shard
#### `get_shard_latest_attesting_balance`

```python
def get_shard_latest_attesting_balance(store: Store, shard_store: ShardStore, root: Root) -> Gwei:
def get_shard_latest_attesting_balance(store: Store, shard: Shard, root: Root) -> Gwei:
shard_store = store.shard_stores[shard]
state = store.checkpoint_states[store.justified_checkpoint]
active_indices = get_active_validator_indices(state, get_current_epoch(state))
return Gwei(sum(
state.validators[i].effective_balance for i in active_indices
if (
i in store.latest_messages
i in shard_store.latest_messages
# TODO: check the latest message logic: currently, validator's previous vote of another shard
# would be ignored once their newer vote is accepted. Check if it makes sense.
and store.latest_messages[i].shard == shard_store.shard
and get_shard_ancestor(
store, shard_store, store.latest_messages[i].shard_root, shard_store.signed_blocks[root].message.slot
store,
shard,
shard_store.latest_messages[i].root,
shard_store.signed_blocks[root].message.slot,
) == root
)
))
Expand All @@ -78,10 +70,14 @@ def get_shard_latest_attesting_balance(store: Store, shard_store: ShardStore, ro
#### `get_shard_head`

```python
def get_shard_head(store: Store, shard_store: ShardStore) -> Root:
def get_shard_head(store: Store, shard: Shard) -> Root:
# Execute the LMD-GHOST fork choice
"""
Execute the LMD-GHOST fork choice.
"""
shard_store = store.shard_stores[shard]
beacon_head_root = get_head(store)
shard_head_state = store.block_states[beacon_head_root].shard_states[shard_store.shard]
shard_head_state = store.block_states[beacon_head_root].shard_states[shard]
shard_head_root = shard_head_state.latest_block_root
shard_blocks = {
root: signed_shard_block.message for root, signed_shard_block in shard_store.signed_blocks.items()
Expand All @@ -97,17 +93,18 @@ def get_shard_head(store: Store, shard_store: ShardStore) -> Root:
return shard_head_root
# Sort by latest attesting balance with ties broken lexicographically
shard_head_root = max(
children, key=lambda root: (get_shard_latest_attesting_balance(store, shard_store, root), root)
children, key=lambda root: (get_shard_latest_attesting_balance(store, shard, root), root)
)
```

#### `get_shard_ancestor`

```python
def get_shard_ancestor(store: Store, shard_store: ShardStore, root: Root, slot: Slot) -> Root:
def get_shard_ancestor(store: Store, shard: Shard, root: Root, slot: Slot) -> Root:
shard_store = store.shard_stores[shard]
block = shard_store.signed_blocks[root].message
if block.slot > slot:
return get_shard_ancestor(store, shard_store, block.shard_parent_root, slot)
return get_shard_ancestor(store, shard, block.shard_parent_root, slot)
elif block.slot == slot:
return root
else:
Expand All @@ -118,17 +115,17 @@ def get_shard_ancestor(store: Store, shard_store: ShardStore, root: Root, slot:
#### `get_pending_shard_blocks`

```python
def get_pending_shard_blocks(store: Store, shard_store: ShardStore) -> Sequence[SignedShardBlock]:
def get_pending_shard_blocks(store: Store, shard: Shard) -> Sequence[SignedShardBlock]:
"""
Return the canonical shard block branch that has not yet been crosslinked.
"""
shard = shard_store.shard
shard_store = store.shard_stores[shard]

beacon_head_root = get_head(store)
beacon_head_state = store.block_states[beacon_head_root]
latest_shard_block_root = beacon_head_state.shard_states[shard].latest_block_root

shard_head_root = get_shard_head(store, shard_store)
shard_head_root = get_shard_head(store, shard)
root = shard_head_root
signed_shard_blocks = []
while root != latest_shard_block_root:
Expand All @@ -145,13 +142,10 @@ def get_pending_shard_blocks(store: Store, shard_store: ShardStore) -> Sequence[
#### `on_shard_block`

```python
def on_shard_block(store: Store, shard_store: ShardStore, signed_shard_block: SignedShardBlock) -> None:
def on_shard_block(store: Store, signed_shard_block: SignedShardBlock) -> None:
shard_block = signed_shard_block.message
shard = shard_store.shard

# Check shard
# TODO: check it in networking spec
assert shard_block.shard == shard
shard = shard_block.shard
shard_store = store.shard_stores[shard]

# Check shard parent exists
assert shard_block.shard_parent_root in shard_store.block_states
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,12 @@ def run_on_attestation(spec, state, store, attestation, valid=True):
latest_message = spec.LatestMessage(
epoch=attestation.data.target.epoch,
root=attestation.data.beacon_block_root,
shard=attestation.data.shard,
shard_root=attestation.data.shard_head_root,
)
shard_latest_message = spec.ShardLatestMessage(
epoch=attestation.data.target.epoch,
root=attestation.data.shard_head_root,
)
assert store.shard_stores[attestation.data.shard].latest_messages[sample_index] == shard_latest_message

assert (
store.latest_messages[sample_index] == latest_message
Expand Down
Loading