Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Work-in-progress on SSZ partials #1184

Closed
wants to merge 13 commits into from
104 changes: 104 additions & 0 deletions test_libs/pyspec/eth2spec/test_ssz_partials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from utils.ssz.ssz_typing import *
from utils.ssz.ssz_impl import *
from utils.ssz.ssz_partials import *
import os, random

class Person(Container):
is_male: bool
age: uint64
name: Bytes[32]

class City(Container):
coords: Vector[uint64, 2]
people: List[Person, 20]

people = List[Person, 20](
Person(is_male=True, age=uint64(45), name=Bytes[32](b"Alex")),
Person(is_male=True, age=uint64(47), name=Bytes[32](b"Bob")),
Person(is_male=True, age=uint64(49), name=Bytes[32](b"Carl")),
Person(is_male=True, age=uint64(51), name=Bytes[32](b"Danny")),
Person(is_male=True, age=uint64(53), name=Bytes[32](b"Evan")),
Person(is_male=False, age=uint64(55), name=Bytes[32](b"Fae")),
Person(is_male=False, age=uint64(57), name=Bytes[32](b"Ginny")),
Person(is_male=False, age=uint64(59), name=Bytes[32](b"Heather")),
Person(is_male=False, age=uint64(61), name=Bytes[32](b"Ingrid")),
Person(is_male=False, age=uint64(63), name=Bytes[32](b"Kane")),
)

city = City(coords=Vector[uint64, 2](uint64(45), uint64(90)), people=people)

paths = [
["coords", 0],
["people", 4, "name", 1],
["people", 8, "is_male"],
["people", 9],
["people", 7],
["people", 1],
]

x = ssz_full(city)
full = ssz_partial(City, ssz_full(city))
print(full.objects.keys())
for path in paths:
print(path, list(full.access_partial(path).objects.keys()))
# print(path, get_nodes_along_path(full, path, typ=City).keys())
p = merge(*[full.access_partial(path) for path in paths])
# p = SSZPartial(infer_type(city), branch2)
assert p.coords[0] == city.coords[0] == extract_value_at_path(p.objects, City, ['coords', 0])
assert p.coords[1] == city.coords[1]
assert len(p.coords) == len(city.coords)
assert p.coords.hash_tree_root() == hash_tree_root(city.coords)
assert p.people[4].name[1] == city.people[4].name[1] == extract_value_at_path(p.objects, City, ['people', 4, 'name', 1])
assert len(p.people[4].name) == len(city.people[4].name) == 4
assert p.people[8].is_male == city.people[8].is_male
assert p.people[7].is_male == city.people[7].is_male
assert p.people[7].age == city.people[7].age
assert p.people[7].name[0] == city.people[7].name[0]
assert str(p.people[7].name) == str(city.people[7].name)
assert str(p.people[1]) == str(city.people[1]), (str(p.people[1]), str(city.people[1]))
assert p.people[1].name.hash_tree_root() == hash_tree_root(city.people[1].name)
assert p.people[1].hash_tree_root() == hash_tree_root(city.people[1])
assert p.coords.hash_tree_root() == hash_tree_root(city.coords)
assert p.people.hash_tree_root() == hash_tree_root(city.people), (p.people.hash_tree_root(), hash_tree_root(city.people))
assert p.hash_tree_root() == hash_tree_root(city)
print(hash_tree_root(city))
print("Reading tests passed")
p.coords[0] = 65
assert p.coords[0] == 65
assert p.coords.hash_tree_root() == hash_tree_root(Vector[uint64, 2](uint64(65), uint64(90)))
p.people[7].name[0] = byte('F')
assert p.people[7].name[0] == ord('F')
assert p.people[7].name == Bytes[32](b"Feather")
p.people[9].is_male = False
assert p.people[9].is_male is False
p.people[1].name = Bytes[32](b"Ashley")
assert p.people[1].name.full_value() == Bytes[32](b"Ashley")
p.people[1].age += 100
assert p.people[1].hash_tree_root() == hash_tree_root(Person(is_male=True, age=uint64(147), name=Bytes[32](b"Ashley")))
print("Writing tests passed")
p = merge(*[full.access_partial(path) for path in paths])
object_keys = sorted(list(p.objects.keys()))[::-1]
print(object_keys)
pre_hash_root = p.hash_tree_root()
for i in range(10):
p.people.append(Person(is_male=False, age=uint64(i), name=Bytes[32](b"z" * i)))
city.people.append(Person(is_male=False, age=uint64(i), name=Bytes[32](b"z" * i)))
p.people[7].name.append(byte('!'))
city.people[7].name.append(byte('!'))
assert p.hash_tree_root() == city.hash_tree_root()
print(i)
for i in range(10):
p.people.pop()
city.people.pop()
p.people[7].name.pop()
city.people[7].name.pop()
assert p.hash_tree_root() == city.hash_tree_root()
print(i)
assert p.hash_tree_root() == pre_hash_root
print("Append and pop tests passed")
encoded = p.encode()
print(encoded)
print(serialize(encoded))
assert encoded.to_ssz(City).hash_tree_root() == p.hash_tree_root()
# print('extras', list([k for k in p.objects if k not in encoded.to_ssz(City).objects]))
print("Encoded partial tests passed")
37 changes: 29 additions & 8 deletions test_libs/pyspec/eth2spec/utils/merkle_minimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,32 @@ def next_power_of_two(v: int) -> int:
return 1 << (v - 1).bit_length()


def merkleize_chunks(chunks):
tree = chunks[::]
margin = next_power_of_two(len(chunks)) - len(chunks)
tree.extend([ZERO_BYTES32] * margin)
tree = [ZERO_BYTES32] * len(tree) + tree
for i in range(len(tree) // 2 - 1, 0, -1):
tree[i] = hash(tree[i * 2] + tree[i * 2 + 1])
return tree[1]
def merkleize_chunks(chunks, pad_to: int = None):
count = len(chunks)
depth = max(count - 1, 0).bit_length()
max_depth = max(depth, (pad_to - 1).bit_length())
tmp = [None for _ in range(max_depth + 1)]

def merge(h, i):
j = 0
while True:
if i & (1 << j) == 0:
if i == count and j < depth:
h = hash(h + zerohashes[j])
else:
break
else:
h = hash(tmp[j] + h)
j += 1
tmp[j] = h

for i in range(count):
merge(chunks[i], i)

if count < (1 << (count - 1).bit_length()):
merge(zerohashes[0], count)

for j in range(depth, max_depth):
tmp[j + 1] = hash(tmp[j] + zerohashes[j])

return tmp[max_depth]
81 changes: 52 additions & 29 deletions test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
from ..merkle_minimal import merkleize_chunks, hash
from eth2spec.utils.ssz.ssz_typing import (
is_uint_type, is_bool_type, is_container_type,
is_list_kind, is_vector_kind,
read_vector_elem_type, read_elem_type,
uint_byte_size,
infer_input_type,
get_zero_value,
from ..merkle_minimal import merkleize_chunks, ZERO_BYTES32
from ..hash_function import hash
from .ssz_typing import (
get_zero_value, Container, List, Vector, Bytes, BytesN, uint, infer_input_type
)

# SSZ Serialization
Expand All @@ -15,13 +11,13 @@


def is_basic_type(typ):
return is_uint_type(typ) or is_bool_type(typ)

return typ == bool or issubclass(typ, uint)

@infer_input_type
def serialize_basic(value, typ):
if is_uint_type(typ):
return value.to_bytes(uint_byte_size(typ), 'little')
elif is_bool_type(typ):
if issubclass(typ, uint):
return value.to_bytes(typ.byte_len, 'little')
elif issubclass(typ, bool):
if value:
return b'\x01'
else:
Expand All @@ -31,22 +27,34 @@ def serialize_basic(value, typ):


def deserialize_basic(value, typ):
if is_uint_type(typ):
if issubclass(typ, uint):
return typ(int.from_bytes(value, 'little'))
elif is_bool_type(typ):
elif issubclass(typ, bool):
assert value in (b'\x00', b'\x01')
return True if value == b'\x01' else False
else:
raise Exception("Type not supported: {}".format(typ))


def is_list_kind(typ):
return issubclass(typ, (List, Bytes))


def is_vector_kind(typ):
return issubclass(typ, (Vector, BytesN))


def is_container_type(typ):
return issubclass(typ, Container)


def is_fixed_size(typ):
if is_basic_type(typ):
return True
elif is_list_kind(typ):
return False
elif is_vector_kind(typ):
return is_fixed_size(read_vector_elem_type(typ))
return is_fixed_size(typ.elem_type)
elif is_container_type(typ):
return all(is_fixed_size(t) for t in typ.get_field_types())
else:
Expand All @@ -58,11 +66,11 @@ def is_empty(obj):


@infer_input_type
def serialize(obj, typ=None):
def serialize(obj, typ):
if is_basic_type(typ):
return serialize_basic(obj, typ)
elif is_list_kind(typ) or is_vector_kind(typ):
return encode_series(obj, [read_elem_type(typ)] * len(obj))
return encode_series(obj, [typ.elem_type] * len(obj))
elif is_container_type(typ):
return encode_series(obj.get_field_values(), typ.get_field_types())
else:
Expand Down Expand Up @@ -126,36 +134,51 @@ def mix_in_length(root, length):
def is_bottom_layer_kind(typ):
return (
is_basic_type(typ) or
(is_list_kind(typ) or is_vector_kind(typ)) and is_basic_type(read_elem_type(typ))
(is_list_kind(typ) or is_vector_kind(typ)) and is_basic_type(typ.elem_type)
)


@infer_input_type
def get_typed_values(obj, typ=None):
def get_typed_values(obj, typ):
if is_container_type(typ):
return obj.get_typed_values()
elif is_list_kind(typ) or is_vector_kind(typ):
elem_type = read_elem_type(typ)
return list(zip(obj, [elem_type] * len(obj)))
return list(zip(obj, [typ.elem_type] * len(obj)))
else:
raise Exception("Invalid type")
raise Exception("Invalid type", obj, typ)


def item_length(typ):
if typ == bool:
return 1
elif issubclass(typ, uint):
return typ.byte_len
else:
return 32


def get_chunk_count(typ):
if is_basic_type(typ):
return 1
elif is_list_kind(typ) or is_vector_kind(typ):
return (typ.length * item_length(typ.elem_type) + 31) // 32
else:
return len(typ.get_fields())


@infer_input_type
def hash_tree_root(obj, typ=None):
def hash_tree_root(obj, typ):
if is_bottom_layer_kind(typ):
data = serialize_basic(obj, typ) if is_basic_type(typ) else pack(obj, read_elem_type(typ))
data = serialize_basic(obj, typ) if is_basic_type(typ) else pack(obj, typ.elem_type)
leaves = chunkify(data)
else:
fields = get_typed_values(obj, typ=typ)
leaves = [hash_tree_root(field_value, typ=field_typ) for field_value, field_typ in fields]
if is_list_kind(typ):
return mix_in_length(merkleize_chunks(leaves), len(obj))
return mix_in_length(merkleize_chunks(leaves, pad_to=get_chunk_count(typ)), len(obj))
else:
return merkleize_chunks(leaves)
return merkleize_chunks(leaves, pad_to=get_chunk_count(typ))


@infer_input_type
def signing_root(obj, typ):
assert is_container_type(typ)
# ignore last field
Expand Down
Loading