Skip to content
This repository has been archived by the owner on Jul 1, 2021. It is now read-only.

Add get_permuted_index and new swap-or-not shuffle #363

Merged
merged 4 commits into from
Mar 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 79 additions & 50 deletions eth2/beacon/_utils/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,78 +11,107 @@
)
from eth_utils import (
to_tuple,
ValidationError,
)

from eth2.beacon._utils.hash import (
hash_eth2,
)
from eth2.beacon.constants import (
RAND_BYTES,
RAND_MAX,
POWER_OF_TWO_NUMBERS,
MAX_LIST_SIZE,
)


TItem = TypeVar('TItem')


def get_permuted_index(index: int,
                       list_size: int,
                       seed: Hash32,
                       shuffle_round_count: int) -> int:
    """
    Return `p(index)` in a pseudorandom permutation `p` of `0...list_size-1`
    with ``seed`` as entropy.

    Utilizes 'swap or not' shuffling found in
    https://link.springer.com/content/pdf/10.1007%2F978-3-642-32009-5_1.pdf
    See the 'generalized domain' algorithm on page 3.

    Every round derives a pivot from the hash of ``seed`` plus the round
    number, pairs the current index with ``(pivot - index) % list_size`` and
    moves to that paired position when a hash-derived decision bit is set.

    Raise ``ValidationError`` if ``index`` is not less than ``list_size`` or if
    ``list_size`` is greater than ``MAX_LIST_SIZE``.
    """
    if index >= list_size:
        raise ValidationError(
            f"The given `index` ({index}) should be less than `list_size` ({list_size})"
        )

    if list_size > MAX_LIST_SIZE:
        raise ValidationError(
            f"The given `list_size` ({list_size}) should be equal to or less than "
            f"`MAX_LIST_SIZE` ({MAX_LIST_SIZE})"
        )

    new_index = index
    for round_number in range(shuffle_round_count):
        # The round number is encoded in a single byte, so `int.to_bytes`
        # raises `OverflowError` once the round number reaches 256.
        round_bytes = round_number.to_bytes(1, 'little')

        pivot = int.from_bytes(
            hash_eth2(seed + round_bytes)[0:8],
            'little',
        ) % list_size

        # `flip` is the partner of `new_index` for this round. Both members of
        # a pair take the larger position as `hash_pos`, so they read the same
        # decision bit.
        flip = (pivot - new_index) % list_size
        hash_pos = max(new_index, flip)

        # One hash is computed per group of 256 positions; within it the
        # position selects a byte (`// 8`) and then a bit (`% 8`).
        # NOTE(review): this presumes `hash_eth2` returns at least 32 bytes -- confirm.
        h = hash_eth2(seed + round_bytes + (hash_pos // 256).to_bytes(4, 'little'))
        byte = h[(hash_pos % 256) // 8]
        bit = (byte >> (hash_pos % 8)) % 2
        new_index = flip if bit else new_index

    return new_index


@to_tuple
def shuffle(values: Sequence[TItem],
seed: Hash32) -> Iterable[TItem]:
seed: Hash32,
shuffle_round_count: int=90) -> Iterable[TItem]:
"""
Return the shuffled ``values`` with ``seed`` as entropy.
Mainly for shuffling active validators in-protocol.
Return shuffled indices in a pseudorandom permutation `0...list_size-1` with
``seed`` as entropy.

Spec: https://github.com/ethereum/eth2.0-specs/blob/70cef14a08de70e7bd0455d75cf380eb69694bfb/specs/core/0_beacon-chain.md#helper-functions # noqa: E501
Utilizes 'swap or not' shuffling found in
https://link.springer.com/content/pdf/10.1007%2F978-3-642-32009-5_1.pdf
See the 'generalized domain' algorithm on page 3.
"""
values_count = len(values)

# The range of the RNG places an upper-bound on the size of the list that
# may be shuffled. It is a logic error to supply an oversized list.
if values_count >= RAND_MAX:
raise ValueError(
"values_count (%s) should less than RAND_MAX (%s)." %
(values_count, RAND_MAX)
list_size = len(values)

# Reject oversized inputs up front, before any hashing work is done.
if list_size > MAX_LIST_SIZE:
    raise ValidationError(
        f"The `list_size` ({list_size}) should be equal to or less than "
        f"`MAX_LIST_SIZE` ({MAX_LIST_SIZE})"
    )

indices = list(range(list_size))
for round in range(shuffle_round_count):
hash_bytes = b''.join(
[
hash_eth2(seed + round.to_bytes(1, 'little') + i.to_bytes(4, 'little'))
for i in range((list_size + 255) // 256)
]
)

output = [x for x in values]
source = seed
index = 0
while index < values_count - 1:
# Re-hash the `source` to obtain a new pattern of bytes.
source = hash_eth2(source)

# Iterate through the `source` bytes in 3-byte chunks.
for position in range(0, 32 - (32 % RAND_BYTES), RAND_BYTES):
# Determine the number of indices remaining in `values` and exit
# once the last index is reached.
remaining = values_count - index
if remaining == 1:
break

# Read 3-bytes of `source` as a 24-bit little-endian integer.
sample_from_source = int.from_bytes(
source[position:position + RAND_BYTES], 'little'
)

# Sample values greater than or equal to `sample_max` will cause
# modulo bias when mapped into the `remaining` range.
sample_max = RAND_MAX - RAND_MAX % remaining

# Perform a swap if the consumed entropy will not cause modulo bias.
if sample_from_source < sample_max:
# Select a replacement index for the current index.
replacement_position = (sample_from_source % remaining) + index
# Swap the current index with the replacement index.
(output[index], output[replacement_position]) = (
output[replacement_position],
output[index]
)
index += 1
pivot = int.from_bytes(
hash_eth2(seed + round.to_bytes(1, 'little'))[:8],
'little',
) % list_size
for i in range(list_size):
flip = (pivot - indices[i]) % list_size
hash_position = indices[i] if indices[i] > flip else flip
byte = hash_bytes[hash_position // 8]
mask = POWER_OF_TWO_NUMBERS[hash_position % 8]
if byte & mask:
indices[i] = flip
else:
# The sample causes modulo bias. A new sample should be read.
# not swap
pass

return output
for i in indices:
yield values[i]


def split(values: Sequence[TItem], split_count: int) -> Tuple[Iterable[TItem], ...]:
Expand Down
20 changes: 12 additions & 8 deletions eth2/beacon/committee_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,7 @@ def get_shuffling(*,
seed: Hash32,
validators: Sequence['ValidatorRecord'],
epoch: Epoch,
slots_per_epoch: int,
target_committee_size: int,
shard_count: int) -> Tuple[Iterable[ValidatorIndex], ...]:
committee_config: CommitteeConfig) -> Tuple[Iterable[ValidatorIndex], ...]:
"""
Shuffle ``validators`` into crosslink committees seeded by ``seed`` and ``epoch``.
Return a list of ``committee_per_epoch`` committees where each
Expand All @@ -86,6 +84,11 @@ def get_shuffling(*,
of ``validators`` forever in phase 0, and until the ~1 year deletion delay in phase 2
and in the future.
"""
slots_per_epoch = committee_config.SLOTS_PER_EPOCH
target_committee_size = committee_config.TARGET_COMMITTEE_SIZE
shard_count = committee_config.SHARD_COUNT
shuffle_round_count = committee_config.SHUFFLE_ROUND_COUNT

active_validator_indices = get_active_validator_indices(validators, epoch)

committees_per_epoch = get_epoch_committee_count(
Expand All @@ -96,7 +99,11 @@ def get_shuffling(*,
)

# Shuffle
shuffled_active_validator_indices = shuffle(active_validator_indices, seed)
shuffled_active_validator_indices = shuffle(
active_validator_indices,
seed,
shuffle_round_count=shuffle_round_count,
)

# Split the shuffled list into committees_per_epoch pieces
return tuple(
Expand Down Expand Up @@ -282,7 +289,6 @@ def get_crosslink_committees_at_slot(
genesis_epoch = committee_config.GENESIS_EPOCH
shard_count = committee_config.SHARD_COUNT
slots_per_epoch = committee_config.SLOTS_PER_EPOCH
target_committee_size = committee_config.TARGET_COMMITTEE_SIZE

epoch = slot_to_epoch(slot, slots_per_epoch)
current_epoch = state.current_epoch(slots_per_epoch)
Expand Down Expand Up @@ -324,9 +330,7 @@ def get_crosslink_committees_at_slot(
seed=shuffling_context.seed,
validators=state.validator_registry,
epoch=shuffling_context.shuffling_epoch,
slots_per_epoch=slots_per_epoch,
target_committee_size=target_committee_size,
shard_count=shard_count,
committee_config=committee_config,
)
offset = slot % slots_per_epoch
committees_per_slot = shuffling_context.committees_per_epoch // slots_per_epoch
Expand Down
4 changes: 4 additions & 0 deletions eth2/beacon/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
('MAX_BALANCE_CHURN_QUOTIENT', int),
('BEACON_CHAIN_SHARD_NUMBER', Shard),
('MAX_INDICES_PER_SLASHABLE_VOTE', int),
('MAX_EXIT_DEQUEUES_PER_EPOCH', int),
('SHUFFLE_ROUND_COUNT', int),
# State list lengths
('LATEST_BLOCK_ROOTS_LENGTH', int),
('LATEST_ACTIVE_INDEX_ROOTS_LENGTH', int),
('LATEST_RANDAO_MIXES_LENGTH', int),
Expand Down Expand Up @@ -73,6 +76,7 @@ def __init__(self, config: BeaconConfig):
self.SHARD_COUNT = config.SHARD_COUNT
self.SLOTS_PER_EPOCH = config.SLOTS_PER_EPOCH
self.TARGET_COMMITTEE_SIZE = config.TARGET_COMMITTEE_SIZE
self.SHUFFLE_ROUND_COUNT = config.SHUFFLE_ROUND_COUNT

# For seed
self.MIN_SEED_LOOKAHEAD = config.MIN_SEED_LOOKAHEAD
Expand Down
20 changes: 7 additions & 13 deletions eth2/beacon/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,15 @@
)


#
# shuffle function
#

# The size of 3 bytes in integer
# sample_range = 2 ** (3 * 8) = 2 ** 24 = 16777216
# sample_range = 16777216

# Entropy is consumed from the seed in 3-byte (24 bit) chunks.
RAND_BYTES = 3
# The highest possible result of the RNG.
RAND_MAX = 2 ** (RAND_BYTES * 8) - 1

EMPTY_SIGNATURE = BLSSignature(b'\x00' * 96)
GWEI_PER_ETH = 10**9
FAR_FUTURE_EPOCH = Epoch(2**64 - 1)

GENESIS_PARENT_ROOT = ZERO_HASH32

#
# shuffle function
#

POWER_OF_TWO_NUMBERS = [1, 2, 4, 8, 16, 32, 64, 128]
MAX_LIST_SIZE = 2**40
3 changes: 3 additions & 0 deletions eth2/beacon/state_machines/forks/serenity/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
MAX_BALANCE_CHURN_QUOTIENT=2**5, # (= 32)
BEACON_CHAIN_SHARD_NUMBER=Shard(2**64 - 1),
MAX_INDICES_PER_SLASHABLE_VOTE=2**12, # (= 4,096) votes
MAX_EXIT_DEQUEUES_PER_EPOCH=2**2, # (= 4)
SHUFFLE_ROUND_COUNT=90,
# State list lengths
LATEST_BLOCK_ROOTS_LENGTH=2**13, # (= 8,192) slots
LATEST_ACTIVE_INDEX_ROOTS_LENGTH=2**13, # (= 8,192) epochs
LATEST_RANDAO_MIXES_LENGTH=2**13, # (= 8,192) epochs
Expand Down
38 changes: 30 additions & 8 deletions tests/eth2/beacon/_utils/test_random.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,49 @@
import pytest

from eth_utils import (
ValidationError,
)

from eth2.beacon._utils.random import (
get_permuted_index,
shuffle,
)


def slow_shuffle(items, seed, shuffle_round_count):
    """
    Reference shuffle built from ``get_permuted_index`` one position at a time.

    Output position ``i`` receives ``items[get_permuted_index(i, ...)]``; the
    result is returned as a tuple.
    """
    item_count = len(items)
    shuffled = []
    for position in range(item_count):
        source_position = get_permuted_index(
            position,
            item_count,
            seed,
            shuffle_round_count,
        )
        shuffled.append(items[source_position])
    return tuple(shuffled)


@pytest.mark.parametrize(
(
'values,seed,expect'
'values',
'seed',
'shuffle_round_count',
),
[
(
tuple(range(12)),
b'\x23' * 32,
(8, 3, 9, 0, 1, 11, 2, 4, 6, 7, 10, 5),
90,
),
(
tuple(range(2**6))[10:],
b'\x67' * 32,
20,
),
],
)
def test_shuffle_consistent(values, seed, expect):
assert shuffle(values, seed) == expect
def test_shuffle_consistent(values, seed, shuffle_round_count):
expect = slow_shuffle(values, seed, shuffle_round_count)
assert shuffle(values, seed, shuffle_round_count) == expect


def test_shuffle_out_of_bound():
values = [i for i in range(2**24 + 1)]
with pytest.raises(ValueError):
shuffle(values, b'hello')
def test_get_permuted_index_invalid(shuffle_round_count):
    """``get_permuted_index`` rejects an index equal to the list size."""
    list_size = 2
    out_of_range_index = 2
    seed = b'\x12' * 32

    with pytest.raises(ValidationError):
        get_permuted_index(out_of_range_index, list_size, seed, shuffle_round_count)
14 changes: 14 additions & 0 deletions tests/eth2/beacon/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,16 @@ def max_indices_per_slashable_vote():
return SERENITY_CONFIG.MAX_INDICES_PER_SLASHABLE_VOTE


@pytest.fixture
def max_exit_dequeues_per_epoch():
    """Provide ``MAX_EXIT_DEQUEUES_PER_EPOCH`` as configured in ``SERENITY_CONFIG``."""
    configured_value = SERENITY_CONFIG.MAX_EXIT_DEQUEUES_PER_EPOCH
    return configured_value


@pytest.fixture
def shuffle_round_count():
    """Provide ``SHUFFLE_ROUND_COUNT`` as configured in ``SERENITY_CONFIG``."""
    configured_value = SERENITY_CONFIG.SHUFFLE_ROUND_COUNT
    return configured_value


@pytest.fixture
def latest_block_roots_length():
return SERENITY_CONFIG.LATEST_BLOCK_ROOTS_LENGTH
Expand Down Expand Up @@ -687,6 +697,8 @@ def config(
max_balance_churn_quotient,
beacon_chain_shard_number,
max_indices_per_slashable_vote,
max_exit_dequeues_per_epoch,
shuffle_round_count,
latest_block_roots_length,
latest_active_index_roots_length,
latest_randao_mixes_length,
Expand Down Expand Up @@ -726,6 +738,8 @@ def config(
MAX_BALANCE_CHURN_QUOTIENT=max_balance_churn_quotient,
BEACON_CHAIN_SHARD_NUMBER=beacon_chain_shard_number,
MAX_INDICES_PER_SLASHABLE_VOTE=max_indices_per_slashable_vote,
MAX_EXIT_DEQUEUES_PER_EPOCH=max_exit_dequeues_per_epoch,
SHUFFLE_ROUND_COUNT=shuffle_round_count,
LATEST_BLOCK_ROOTS_LENGTH=latest_block_roots_length,
LATEST_ACTIVE_INDEX_ROOTS_LENGTH=latest_active_index_roots_length,
LATEST_RANDAO_MIXES_LENGTH=latest_randao_mixes_length,
Expand Down
Loading