diff --git a/test_libs/pyspec/eth2spec/test_ssz_partials.py b/test_libs/pyspec/eth2spec/test_ssz_partials.py index 34be754950..3858423b28 100644 --- a/test_libs/pyspec/eth2spec/test_ssz_partials.py +++ b/test_libs/pyspec/eth2spec/test_ssz_partials.py @@ -1,17 +1,24 @@ -from utils.ssz.ssz_typing import * -from utils.ssz.ssz_impl import * -from utils.ssz.ssz_partials import * -import os, random +from eth2spec.utils.ssz.ssz_impl import serialize, hash_tree_root +from eth2spec.utils.ssz.ssz_typing import ( + Bit, Bytes, uint64, Container, Vector, List +) + +from eth2spec.utils.ssz.ssz_partials import ( + ssz_full, ssz_partial, extract_value_at_path, merge +) + class Person(Container): - is_male: bool + is_male: Bit age: uint64 name: Bytes[32] + class City(Container): coords: Vector[uint64, 2] people: List[Person, 20] + people = List[Person, 20]( Person(is_male=True, age=uint64(45), name=Bytes[32](b"Alex")), Person(is_male=True, age=uint64(47), name=Bytes[32](b"Bob")), @@ -59,7 +66,8 @@ class City(Container): assert p.people[1].name.hash_tree_root() == hash_tree_root(city.people[1].name) assert p.people[1].hash_tree_root() == hash_tree_root(city.people[1]) assert p.coords.hash_tree_root() == hash_tree_root(city.coords) -assert p.people.hash_tree_root() == hash_tree_root(city.people), (p.people.hash_tree_root(), hash_tree_root(city.people)) +assert p.people.hash_tree_root() == hash_tree_root(city.people), ( + p.people.hash_tree_root(), hash_tree_root(city.people)) assert p.hash_tree_root() == hash_tree_root(city) print(hash_tree_root(city)) print("Reading tests passed") diff --git a/test_libs/pyspec/eth2spec/utils/ssz/ssz_partials.py b/test_libs/pyspec/eth2spec/utils/ssz/ssz_partials.py index d5d4404ccb..79dfca68e5 100644 --- a/test_libs/pyspec/eth2spec/utils/ssz/ssz_partials.py +++ b/test_libs/pyspec/eth2spec/utils/ssz/ssz_partials.py @@ -1,17 +1,12 @@ from ..merkle_minimal import hash, next_power_of_two from .ssz_typing import ( - get_zero_value, Container, List, Vector, Bytes, BytesN, uint, uint64, infer_input_type + Container, List, Vector, Bytes, BytesN, uint64, byte, BasicValue, SSZValue, coerce_type_maybe ) from .ssz_impl import ( chunkify, deserialize_basic, - get_typed_values, - is_basic_type, is_bottom_layer_kind, - is_list_kind, - is_vector_kind, - is_container_type, item_length, get_chunk_count, pack, @@ -23,7 +18,7 @@ def is_generalized_index_child(c, a): - return False if c < a else True if c == a else is_generalized_index_child(c//2, a) + return False if c < a else True if c == a else is_generalized_index_child(c // 2, a) def empty_sisters(starting_position, list_size, max_list_size): @@ -49,12 +44,13 @@ def empty_sisters(starting_position, list_size, max_list_size): return o -def ssz_leaves(obj, typ=None, root=1): +def ssz_leaves(obj: SSZValue, root=1): """ Converts an object into a {generalized_index: chunk} map. Leaves only; does not fill intermediate chunks or compute the root. """ - if is_list_kind(typ): + typ = obj.type() + if isinstance(obj, (List, Bytes)): o = {root * 2 + 1: len(obj).to_bytes(32, 'little')} base = root * 2 else: @@ -62,18 +58,16 @@ def ssz_leaves(obj, typ=None, root=1): base = root if is_bottom_layer_kind(typ): starting_index = base * next_power_of_two(get_chunk_count(typ)) - data = chunkify(serialize_basic(obj, typ) if is_basic_type(typ) else pack(obj, typ.elem_type)) + data = chunkify(serialize_basic(obj) if isinstance(obj, BasicValue) else pack(obj)) return { **o, **{starting_index + i: data[i] for i in range(len(data))}, **empty_sisters(starting_index, len(data), get_chunk_count(typ)) } else: - fields = get_typed_values(obj, typ=typ) - starting_index = base * next_power_of_two(get_chunk_count(typ)) - for i, (elem, elem_type) in enumerate(fields): - o = {**o, **ssz_leaves(elem, typ=elem_type, root=starting_index + i)} - return {**o, **empty_sisters(starting_index, len(fields), get_chunk_count(typ))} + for i, elem in enumerate(obj): + o = {**o, **ssz_leaves(elem, root=starting_index + i)} + return {**o, **empty_sisters(starting_index, len(obj), get_chunk_count(typ))} def fill(objects): @@ -93,9 +87,8 @@ def fill(objects): return objects -@infer_input_type -def ssz_full(obj, typ=None): - return fill(ssz_leaves(obj, typ=typ)) +def ssz_full(obj): + return fill(ssz_leaves(obj)) def get_item_position(typ, index): @@ -104,11 +97,11 @@ def get_item_position(typ, index): represented, (ii) the starting byte position, (iii) the ending byte position. For example for a 6-item list of uint64 values, index=2 will return (0, 16, 24), index=5 will return (1, 8, 16) """ - if is_list_kind(typ) or is_vector_kind(typ): + if issubclass(typ, (List, Bytes)) or issubclass(typ, (Vector, BytesN)): start = index * item_length(typ.elem_type) return start // 32, start % 32, start % 32 + item_length(typ.elem_type) - elif is_container_type(typ): - return typ.get_field_names().index(index), 0, item_length(get_elem_type(typ, index)) + elif issubclass(typ, Container): + return list(typ.get_fields().keys()).index(index), 0, item_length(get_elem_type(typ, index)) else: raise Exception("Only lists/vectors/containers supported") @@ -118,7 +111,7 @@ def get_elem_type(typ, index): Returns the type of the element of an object of the given type with the given index or member variable name (eg. `7` for `x[7]`, `"foo"` for `x.foo`) """ - return typ.get_fields_dict()[index] if is_container_type(typ) else typ.elem_type + return typ.get_fields()[index] if issubclass(typ, Container) else typ.elem_type def get_generalized_index(typ, root, path): @@ -127,12 +120,12 @@ def get_generalized_index(typ, root, path): `len(x[12].bar)`) into the generalized index representing its position in the Merkle tree. """ for p in path: - assert not is_basic_type(typ) # If we descend to a basic type, the path cannot continue further + assert not issubclass(typ, BasicValue) # If we descend to a basic type, the path cannot continue further if p == '__len__': - typ, root = uint64, root * 2 + 1 if is_list_kind(typ) else None + typ, root = uint64, root * 2 + 1 if issubclass(typ, list) else None else: pos, _, _ = get_item_position(typ, p) - root = root * (2 if is_list_kind(typ) else 1) * next_power_of_two(get_chunk_count(typ)) + pos + root = root * (2 if issubclass(typ, (List, Bytes)) else 1) * next_power_of_two(get_chunk_count(typ)) + pos typ = get_elem_type(typ, p) return root @@ -146,10 +139,10 @@ def extract_value_at_path(chunks, typ, path): for p in path: if p == '__len__': return deserialize_basic(chunks[root * 2 + 1][:8], uint64) - if is_list_kind(typ): + if issubclass(typ, (List, Bytes)): assert 0 <= p < deserialize_basic(chunks[root * 2 + 1][:8], uint64) pos, start, end = get_item_position(typ, p) - root = root * (2 if is_list_kind(typ) else 1) * next_power_of_two(get_chunk_count(typ)) + pos + root = root * (2 if issubclass(typ, (List, Bytes)) else 1) * next_power_of_two(get_chunk_count(typ)) + pos typ = get_elem_type(typ, p) return deserialize_basic(chunks[root][start: end], typ) @@ -159,12 +152,12 @@ def get_generalized_index_correspondence(typ, root=1, path=None): Prints the path corresponding to every leaf-level generalized index in an SSZ object """ path = path or [] - if is_basic_type(typ): + if issubclass(typ, BasicValue): return {root: path} - if is_list_kind(typ) or is_vector_kind(typ): + if issubclass(typ, (List, Vector, Bytes, BytesN)): fields = list(range(typ.length)) else: - fields = typ.get_field_names() + fields = list(typ.get_fields().keys()) o = {} for f in fields: o = { @@ -176,7 +169,7 @@ def get_generalized_index_correspondence(typ, root=1, path=None): ) } o[root] = path - if is_list_kind(typ): + if issubclass(typ, (List, Bytes)): o[root * 2 + 1] = path + ['__len__'] return o @@ -191,6 +184,7 @@ def get_branch_indices(tree_index): o.append((o[-1] // 2) ^ 1) return o[:-1] + def expand_indices(indices): """ Get the generalized indices of all chunks in the tree needed to prove the chunks with the given @@ -199,7 +193,7 @@ def expand_indices(indices): branches = set() for index in indices: branches = branches.union(set(get_branch_indices(index) + [index])) - return sorted(list([x for x in branches if x*2 not in branches or x*2+1 not in branches]))[::-1] + return sorted(list([x for x in branches if x * 2 not in branches or x * 2 + 1 not in branches]))[::-1] def merge(*args): @@ -215,14 +209,14 @@ class OutOfRangeException(Exception): class SSZPartial(): def __init__(self, typ, objects, root=1): - assert not is_basic_type(typ) + assert not issubclass(typ, BasicValue) self.objects = objects self.typ = typ self.root = root def getter(self, index): - base = self.root * 2 if is_list_kind(self.typ) else self.root - if is_basic_type(get_elem_type(self.typ, index)): + base = self.root * 2 if issubclass(self.typ, (List, Bytes)) else self.root + if issubclass(get_elem_type(self.typ, index), BasicValue): pos, start, end = get_item_position(self.typ, index) tree_index = base * next_power_of_two(get_chunk_count(self.typ)) + pos if tree_index not in self.objects: @@ -230,7 +224,7 @@ def getter(self, index): else: return deserialize_basic( self.objects[tree_index][start:end], - self.typ if is_basic_type(self.typ) else get_elem_type(self.typ, index) + self.typ if issubclass(self.typ, BasicValue) else get_elem_type(self.typ, index) ) else: return ssz_partial( @@ -250,27 +244,28 @@ def renew_branch(self, tree_index): tree_index //= 2 def setter(self, index, value, renew=True): - base = self.root * 2 if is_list_kind(self.typ) else self.root + base = self.root * 2 if issubclass(self.typ, (List, Bytes)) else self.root elem_type = get_elem_type(self.typ, index) - if is_basic_type(elem_type): + value = coerce_type_maybe(value, elem_type, strict=True) + if issubclass(elem_type, BasicValue): pos, start, end = get_item_position(self.typ, index) tree_index = base * next_power_of_two(get_chunk_count(self.typ)) + pos if tree_index not in self.objects: raise OutOfRangeException("Do not have required data") else: chunk = self.objects[tree_index] - self.objects[tree_index] = chunk[:start] + serialize_basic(value, elem_type) + chunk[end:] + self.objects[tree_index] = chunk[:start] + serialize_basic(value) + chunk[end:] assert len(self.objects[tree_index]) == 32 else: tree_index = get_generalized_index(self.typ, self.root, [index]) self.clear_subtree(tree_index) - for k, v in fill(ssz_leaves(value, elem_type, tree_index)).items(): + for k, v in fill(ssz_leaves(value, tree_index)).items(): self.objects[k] = v if renew: self.renew_branch(tree_index) def append_or_pop(self, appending, value): - assert is_list_kind(self.typ) + assert issubclass(self.typ, (List, Bytes)) old_length = len(self) new_length = old_length + (1 if appending else -1) if new_length < 0: @@ -286,17 +281,17 @@ def append_or_pop(self, appending, value): for k, v in empty_sisters(start_pos, old_chunk_count, get_chunk_count(self.typ)).items(): del self.objects[k] if not appending: - if is_basic_type(elem_type): - self.setter(old_length-1, get_zero_value(elem_type), renew=False) - else: - self.clear_subtree(get_generalized_index(self.typ, self.root, [old_length-1])) + if issubclass(elem_type, BasicValue): + self.setter(old_length - 1, elem_type.default(), renew=False) + else: + self.clear_subtree(get_generalized_index(self.typ, self.root, [old_length - 1])) else: - self.setter(new_length-1, value, renew=False) + self.setter(new_length - 1, value, renew=False) if new_chunk_count != old_chunk_count: for k, v in empty_sisters(start_pos, new_chunk_count, get_chunk_count(self.typ)).items(): self.clear_subtree(k) self.objects[k] = v - self.renew_branch(get_generalized_index(self.typ, self.root, [new_length-1])) + self.renew_branch(get_generalized_index(self.typ, self.root, [new_length - 1])) def append(self, value): return self.append_or_pop(True, value) @@ -311,7 +306,8 @@ def access_partial(self, path): gindex = get_generalized_index(self.typ, self.root, path) branch_keys = get_branch_indices(gindex) item_keys = [k for k in self.objects if is_generalized_index_child(k, gindex)] - return ssz_partial(self.typ, {i: self.objects[i] for i in self.objects if i in branch_keys+item_keys}, self.root) + return ssz_partial(self.typ, {i: self.objects[i] for i in self.objects if i in branch_keys + item_keys}, + self.root) def __getitem__(self, index): return self.getter(index) @@ -320,28 +316,28 @@ def __iter__(self): return (self[i] for i in range(len(self))) def __len__(self): - if is_list_kind(self.typ): + if issubclass(self.typ, (List, Bytes)): if self.root * 2 + 1 not in self.objects: raise OutOfRangeException("Do not have required data: {}".format(self.root * 2 + 1)) return int.from_bytes(self.objects[self.root * 2 + 1], 'little') - elif is_vector_kind(self.typ): + elif issubclass(self.typ, (Vector, BytesN)): return self.typ.length - elif is_container_type(self.typ): + elif issubclass(self.typ, Container): return len(self.typ.get_fields()) else: raise Exception("Unsupported type: {}".format(self.typ)) def full_value(self): - if issubclass(self.typ, Bytes) or issubclass(self.typ, BytesN): + if issubclass(self.typ, (Bytes, BytesN)): return self.typ(bytes([self.getter(i) for i in range(len(self))])) - elif is_list_kind(self.typ) or is_vector_kind(self.typ): + elif issubclass(self.typ, (List, Vector)): return self.typ(*(self[i] for i in range(len(self)))) - elif is_container_type(self.typ): + elif issubclass(self.typ, Container): def full_value(x): return x.full_value() if hasattr(x, 'full_value') else x - return self.typ(**{field: full_value(self.getter(field)) for field in self.typ.get_field_names()}) - elif is_basic_type(self.typ): + return self.typ(**{field: full_value(self.getter(field)) for field in self.typ.get_fields().keys()}) + elif issubclass(self.typ, BasicValue): return self.getter(0) else: raise Exception("Unsupported type: {}".format(self.typ)) @@ -368,18 +364,18 @@ def __eq__(self, other): return self.full_value() == other def minimal_indices(self): - if is_bottom_layer_kind(self.typ) and is_basic_type(get_elem_type(self.typ, None)): - if is_list_kind(self.typ) and self.root*2+1 not in self.objects: + if is_bottom_layer_kind(self.typ) and issubclass(get_elem_type(self.typ, None), BasicValue): + if issubclass(self.typ, (List, Bytes)) and self.root * 2 + 1 not in self.objects: return [] o = list(range( get_generalized_index(self.typ, self.root, [0]), - get_generalized_index(self.typ, self.root, [len(self)-1]) + 1 + get_generalized_index(self.typ, self.root, [len(self) - 1]) + 1 )) return [x for x in o if x in self.objects] - elif is_container_type(self.typ): + elif issubclass(self.typ, Container): o = [] - for field, elem_type in self.typ.get_fields(): - if is_basic_type(elem_type): + for field, elem_type in self.typ.get_fields().items(): + if issubclass(elem_type, BasicValue): gindex = get_generalized_index(self.typ, self.root, [field]) if gindex in self.objects: o.append(gindex) @@ -396,7 +392,7 @@ def encode(self): indices = self.minimal_indices() chunks = [self.objects[o] for o in expand_indices(indices)] return EncodedPartial(indices=indices, chunks=chunks) - + def ssz_partial(typ, objects, root=1): """ @@ -408,9 +404,9 @@ def ssz_partial(typ, objects, root=1): class Partial(SSZPartial, ssz_type): pass - if is_container_type(typ): + if issubclass(typ, Container): Partial.__annotations__ = typ.__annotations__ - for field in typ.get_field_names(): + for field in typ.get_fields().keys(): setattr(Partial, field, property( (lambda f: (lambda self: self.getter(f)))(field), (lambda f: (lambda self, v: self.setter(f, v)))(field) @@ -418,16 +414,17 @@ class Partial(SSZPartial, ssz_type): o = Partial(typ, objects, root) return o + class EncodedPartial(Container): - indices: List[uint64, 2**32] - chunks: List[BytesN[32], 2**32] + indices: List[uint64, 2 ** 32] + chunks: List[BytesN[32], 2 ** 32] def to_ssz(self, typ): """ Convert to an SSZ partial representing the given type. """ expanded_indices = expand_indices(self.indices) - o = ssz_partial(typ, fill({e:c for e,c in zip(expanded_indices, self.chunks)})) + o = ssz_partial(typ, fill({e: bytes(c) for e, c in zip(expanded_indices, self.chunks)})) for k, v in o.objects.items(): if k > 1: assert hash(o.objects[k & -2] + o.objects[k | 1]) == o.objects[k // 2]