diff --git a/ssz/__init__.py b/ssz/__init__.py index 35352e9b..765095ac 100644 --- a/ssz/__init__.py +++ b/ssz/__init__.py @@ -21,6 +21,7 @@ Container, List, Serializable, + SignedSerializable, UInt, Vector, boolean, diff --git a/ssz/constants.py b/ssz/constants.py index 64246386..4a005e3d 100644 --- a/ssz/constants.py +++ b/ssz/constants.py @@ -7,3 +7,5 @@ SIZE_PREFIX_SIZE = 4 # named BYTES_PER_LENGTH_PREFIX in the spec MAX_CONTENT_SIZE = 2 ** (SIZE_PREFIX_SIZE * 8) - 1 + +SIGNATURE_FIELD_NAME = "signature" diff --git a/ssz/sedes/__init__.py b/ssz/sedes/__init__.py index fb916277..ccc995a4 100644 --- a/ssz/sedes/__init__.py +++ b/ssz/sedes/__init__.py @@ -37,6 +37,9 @@ from .serializable import ( # noqa: F401 Serializable, ) +from .signed_serializable import ( # noqa: F401 + SignedSerializable, +) from .uint import ( # noqa: F401 UInt, uint8, diff --git a/ssz/sedes/base.py b/ssz/sedes/base.py index 9f6795eb..a8239e81 100644 --- a/ssz/sedes/base.py +++ b/ssz/sedes/base.py @@ -27,7 +27,6 @@ class BaseSedes(ABC, Generic[TSerializable, TDeserialized]): - # # Size # @@ -94,7 +93,6 @@ def hash_tree_root(self, value: TSerializable) -> bytes: class BasicSedes(BaseSedes[TSerializable, TDeserialized]): - def __init__(self, size: int): if size <= 0: raise ValueError("Length must be greater than 0") @@ -144,7 +142,6 @@ def hash_tree_root(self, value: TSerializable) -> bytes: class CompositeSedes(BaseSedes[TSerializable, TDeserialized]): - # # Serialization # diff --git a/ssz/sedes/boolean.py b/ssz/sedes/boolean.py index 02294007..870cf38c 100644 --- a/ssz/sedes/boolean.py +++ b/ssz/sedes/boolean.py @@ -11,7 +11,6 @@ class Boolean(BasicSedes[bool, bool]): - def __init__(self) -> None: super().__init__(size=1) diff --git a/ssz/sedes/byte.py b/ssz/sedes/byte.py index 79fc4e73..7e90aa2b 100644 --- a/ssz/sedes/byte.py +++ b/ssz/sedes/byte.py @@ -4,7 +4,6 @@ class Byte(BasicSedes[bytes, bytes]): - def __init__(self) -> None: super().__init__(1) diff --git a/ssz/sedes/byte_list.py b/ssz/sedes/byte_list.py index 5214299d..411460b1 100644 --- a/ssz/sedes/byte_list.py +++ b/ssz/sedes/byte_list.py @@ -15,7 +15,6 @@ class ByteList(CompositeSedes[BytesOrByteArray, bytes]): - is_static_sized = False def get_static_size(self): diff --git a/ssz/sedes/byte_vector.py b/ssz/sedes/byte_vector.py index beb62b8c..e6a272e3 100644 --- a/ssz/sedes/byte_vector.py +++ b/ssz/sedes/byte_vector.py @@ -17,7 +17,6 @@ class ByteVector(CompositeSedes[BytesOrByteArray, bytes]): - def __init__(self, size: int) -> None: self.size = size diff --git a/ssz/sedes/container.py b/ssz/sedes/container.py index a86e7071..936ee9db 100644 --- a/ssz/sedes/container.py +++ b/ssz/sedes/container.py @@ -32,7 +32,6 @@ class Container(CompositeSedes[TAnyTypedDict, Dict[str, Any]]): - def __init__(self, fields: Sequence[Tuple[str, BaseSedes[Any, Any]]]) -> None: self.fields = fields self.field_names = tuple(field_name for field_name, _ in self.fields) diff --git a/ssz/sedes/list.py b/ssz/sedes/list.py index 46be2e00..5ea8bfcc 100644 --- a/ssz/sedes/list.py +++ b/ssz/sedes/list.py @@ -32,7 +32,6 @@ class List(CompositeSedes[Iterable[TSerializable], Tuple[TDeserialized, ...]]): - def __init__(self, element_sedes: BaseSedes[TSerializable, TDeserialized] = None, empty: bool = False) -> None: diff --git a/ssz/sedes/serializable.py b/ssz/sedes/serializable.py index 1c43ed24..c28c75d8 100644 --- a/ssz/sedes/serializable.py +++ b/ssz/sedes/serializable.py @@ -23,6 +23,7 @@ merge, ) +import ssz from ssz.sedes.base import ( BaseSedes, ) @@ -37,7 +38,6 @@ class Meta(NamedTuple): - has_fields: bool fields: Optional[Tuple[Tuple[str, BaseSedes]]] container_sedes: Optional[Container] @@ -85,7 +85,6 @@ def merge_args_to_kwargs(args, kwargs, arg_names): class BaseSerializable(collections.Sequence): - _cached_ssz = None def __init__(self, *args, **kwargs): @@ -176,6 +175,10 @@ def __copy__(self): def __deepcopy__(self, *args): return self.copy() + @property + def root(self): + return ssz.hash_tree_root(self) + def make_immutable(value): if isinstance(value, list): @@ -247,7 +250,6 @@ def _get_class_namespace(cls): class MetaSerializable(abc.ABCMeta): - def __new__(mcls, name, bases, namespace): fields_attr_name = "fields" declares_fields = fields_attr_name in namespace diff --git a/ssz/sedes/signed_serializable.py b/ssz/sedes/signed_serializable.py new file mode 100644 index 00000000..1ba8c874 --- /dev/null +++ b/ssz/sedes/signed_serializable.py @@ -0,0 +1,68 @@ +from typing import ( + NamedTuple, + Optional, + Tuple, +) + +import ssz +from ssz.constants import ( + SIGNATURE_FIELD_NAME, +) +from ssz.sedes.base import ( + BaseSedes, +) +from ssz.sedes.container import ( + Container, +) +from ssz.sedes.serializable import ( + BaseSerializable, + MetaSerializable, +) + + +class SignedMeta(NamedTuple): + has_fields: bool + fields: Optional[Tuple[Tuple[str, BaseSedes]]] + container_sedes: Optional[Container] + signed_container_sedes: Optional[Container] + field_names: Optional[Tuple[str, ...]] + field_attrs: Optional[Tuple[str, ...]] + + +class MetaSignedSerializable(MetaSerializable): + def __new__(mcls, name, bases, namespace): + cls = super().__new__(mcls, name, bases, namespace) + + if cls._meta.has_fields: + if len(cls._meta.fields) < 2: + raise TypeError(f"Signed serializables need to have at least two fields") + if cls._meta.field_names[-1] != SIGNATURE_FIELD_NAME: + raise TypeError( + f"Last field of signed serializable must be {SIGNATURE_FIELD_NAME}, but is " + f"{cls._meta.field_names[-1]}" + ) + + signed_container_sedes = Container(cls._meta.fields[:-1]) + else: + signed_container_sedes = None + + meta = SignedMeta( + has_fields=cls._meta.has_fields, + fields=cls._meta.fields, + container_sedes=cls._meta.container_sedes, + signed_container_sedes=signed_container_sedes, + field_names=cls._meta.field_names, + field_attrs=cls._meta.field_attrs, + ) + cls._meta = meta + + return cls + + +BaseSedes.register(MetaSignedSerializable) + + +class SignedSerializable(BaseSerializable, metaclass=MetaSignedSerializable): + @property + def signing_root(self): + return ssz.hash_tree_root(self, self._meta.signed_container_sedes) diff --git a/ssz/sedes/uint.py b/ssz/sedes/uint.py index a46eda57..05a69702 100644 --- a/ssz/sedes/uint.py +++ b/ssz/sedes/uint.py @@ -7,7 +7,6 @@ class UInt(BasicSedes[int, int]): - def __init__(self, num_bits: int) -> None: if num_bits % 8 != 0: raise ValueError( diff --git a/ssz/sedes/vector.py b/ssz/sedes/vector.py index 9431a93f..cc902730 100644 --- a/ssz/sedes/vector.py +++ b/ssz/sedes/vector.py @@ -28,7 +28,6 @@ class Vector(CompositeSedes[Sequence[TSerializableElement], Tuple[TDeserializedElement, ...]]): - def __init__(self, element_sedes: BaseSedes[TSerializableElement, TDeserializedElement], length: int) -> None: diff --git a/tests/misc/test_serializable.py b/tests/misc/test_serializable.py index 1383d369..b184ca95 100644 --- a/tests/misc/test_serializable.py +++ b/tests/misc/test_serializable.py @@ -140,3 +140,14 @@ class TestC(ssz.Serializable): assert test_b1 == test_a1 assert test_c1 == test_a1 assert test_c2 != test_a1 + + +def test_root(): + class Test(ssz.Serializable): + fields = ( + ("field1", uint8), + ("field2", uint8), + ) + + test = Test(1, 2) + assert test.root == ssz.hash_tree_root(test, Test) diff --git a/tests/misc/test_signed_serializable.py b/tests/misc/test_signed_serializable.py new file mode 100644 index 00000000..9333d7a9 --- /dev/null +++ b/tests/misc/test_signed_serializable.py @@ -0,0 +1,60 @@ +import pytest + +import ssz +from ssz.sedes import ( + byte_list, + uint8, +) + + +def test_field_number_check(): + with pytest.raises(TypeError): + class TestA(ssz.SignedSerializable): + fields = () + + with pytest.raises(TypeError): + class TestB(ssz.SignedSerializable): + fields = ( + ("signature", byte_list), + ) + + class TestC(ssz.SignedSerializable): + fields = ( + ("field1", uint8), + ("signature", byte_list), + ) + + +def test_field_name_check(): + with pytest.raises(TypeError): + class TestA(ssz.SignedSerializable): + fields = ( + ("field1", uint8), + ("field2", byte_list), + ) + + with pytest.raises(TypeError): + class TestB(ssz.SignedSerializable): + fields = ( + ("signature", uint8), + ("field1", byte_list), + ) + + +def test_signing_root(): + class Signed(ssz.SignedSerializable): + fields = ( + ("field1", uint8), + ("field2", byte_list), + ("signature", byte_list), + ) + + class Unsigned(ssz.Serializable): + fields = ( + ("field1", uint8), + ("field2", byte_list), + ) + + signed = Signed(123, b"\xaa", b"\x00") + unsigned = Unsigned(123, b"\xaa") + assert signed.signing_root == unsigned.root