Skip to content

Commit

Permalink
Merge pull request #58 from jannikluhn/signed-root
Browse files Browse the repository at this point in the history
Root and SignedRoot
  • Loading branch information
jannikluhn authored Apr 23, 2019
2 parents 90d1e5b + b40aa7d commit 97cc1cf
Show file tree
Hide file tree
Showing 16 changed files with 150 additions and 14 deletions.
1 change: 1 addition & 0 deletions ssz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Container,
List,
Serializable,
SignedSerializable,
UInt,
Vector,
boolean,
Expand Down
2 changes: 2 additions & 0 deletions ssz/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@

SIZE_PREFIX_SIZE = 4 # named BYTES_PER_LENGTH_PREFIX in the spec
MAX_CONTENT_SIZE = 2 ** (SIZE_PREFIX_SIZE * 8) - 1

SIGNATURE_FIELD_NAME = "signature"
3 changes: 3 additions & 0 deletions ssz/sedes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
from .serializable import ( # noqa: F401
Serializable,
)
from .signed_serializable import ( # noqa: F401
SignedSerializable,
)
from .uint import ( # noqa: F401
UInt,
uint8,
Expand Down
3 changes: 0 additions & 3 deletions ssz/sedes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@


class BaseSedes(ABC, Generic[TSerializable, TDeserialized]):

#
# Size
#
Expand Down Expand Up @@ -94,7 +93,6 @@ def hash_tree_root(self, value: TSerializable) -> bytes:


class BasicSedes(BaseSedes[TSerializable, TDeserialized]):

def __init__(self, size: int):
if size <= 0:
raise ValueError("Length must be greater than 0")
Expand Down Expand Up @@ -144,7 +142,6 @@ def hash_tree_root(self, value: TSerializable) -> bytes:


class CompositeSedes(BaseSedes[TSerializable, TDeserialized]):

#
# Serialization
#
Expand Down
1 change: 0 additions & 1 deletion ssz/sedes/boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@


class Boolean(BasicSedes[bool, bool]):

def __init__(self) -> None:
super().__init__(size=1)

Expand Down
1 change: 0 additions & 1 deletion ssz/sedes/byte.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@


class Byte(BasicSedes[bytes, bytes]):

def __init__(self) -> None:
super().__init__(1)

Expand Down
1 change: 0 additions & 1 deletion ssz/sedes/byte_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@


class ByteList(CompositeSedes[BytesOrByteArray, bytes]):

is_static_sized = False

def get_static_size(self):
Expand Down
1 change: 0 additions & 1 deletion ssz/sedes/byte_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@


class ByteVector(CompositeSedes[BytesOrByteArray, bytes]):

def __init__(self, size: int) -> None:
self.size = size

Expand Down
1 change: 0 additions & 1 deletion ssz/sedes/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@


class Container(CompositeSedes[TAnyTypedDict, Dict[str, Any]]):

def __init__(self, fields: Sequence[Tuple[str, BaseSedes[Any, Any]]]) -> None:
self.fields = fields
self.field_names = tuple(field_name for field_name, _ in self.fields)
Expand Down
1 change: 0 additions & 1 deletion ssz/sedes/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@


class List(CompositeSedes[Iterable[TSerializable], Tuple[TDeserialized, ...]]):

def __init__(self,
element_sedes: BaseSedes[TSerializable, TDeserialized] = None,
empty: bool = False) -> None:
Expand Down
8 changes: 5 additions & 3 deletions ssz/sedes/serializable.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
merge,
)

import ssz
from ssz.sedes.base import (
BaseSedes,
)
Expand All @@ -37,7 +38,6 @@


class Meta(NamedTuple):

has_fields: bool
fields: Optional[Tuple[Tuple[str, BaseSedes]]]
container_sedes: Optional[Container]
Expand Down Expand Up @@ -85,7 +85,6 @@ def merge_args_to_kwargs(args, kwargs, arg_names):


class BaseSerializable(collections.Sequence):

_cached_ssz = None

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -176,6 +175,10 @@ def __copy__(self):
def __deepcopy__(self, *args):
return self.copy()

@property
def root(self):
return ssz.hash_tree_root(self)


def make_immutable(value):
if isinstance(value, list):
Expand Down Expand Up @@ -247,7 +250,6 @@ def _get_class_namespace(cls):


class MetaSerializable(abc.ABCMeta):

def __new__(mcls, name, bases, namespace):
fields_attr_name = "fields"
declares_fields = fields_attr_name in namespace
Expand Down
68 changes: 68 additions & 0 deletions ssz/sedes/signed_serializable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import (
NamedTuple,
Optional,
Tuple,
)

import ssz
from ssz.constants import (
SIGNATURE_FIELD_NAME,
)
from ssz.sedes.base import (
BaseSedes,
)
from ssz.sedes.container import (
Container,
)
from ssz.sedes.serializable import (
BaseSerializable,
MetaSerializable,
)


class SignedMeta(NamedTuple):
has_fields: bool
fields: Optional[Tuple[Tuple[str, BaseSedes]]]
container_sedes: Optional[Container]
signed_container_sedes: Optional[Container]
field_names: Optional[Tuple[str, ...]]
field_attrs: Optional[Tuple[str, ...]]


class MetaSignedSerializable(MetaSerializable):
def __new__(mcls, name, bases, namespace):
cls = super().__new__(mcls, name, bases, namespace)

if cls._meta.has_fields:
if len(cls._meta.fields) < 2:
raise TypeError(f"Signed serializables need to have at least two fields")
if cls._meta.field_names[-1] != SIGNATURE_FIELD_NAME:
raise TypeError(
f"Last field of signed serializable must be {SIGNATURE_FIELD_NAME}, but is "
f"{cls._meta.field_names[-1]}"
)

signed_container_sedes = Container(cls._meta.fields[:-1])
else:
signed_container_sedes = None

meta = SignedMeta(
has_fields=cls._meta.has_fields,
fields=cls._meta.fields,
container_sedes=cls._meta.container_sedes,
signed_container_sedes=signed_container_sedes,
field_names=cls._meta.field_names,
field_attrs=cls._meta.field_attrs,
)
cls._meta = meta

return cls


BaseSedes.register(MetaSignedSerializable)


class SignedSerializable(BaseSerializable, metaclass=MetaSignedSerializable):
@property
def signing_root(self):
return ssz.hash_tree_root(self, self._meta.signed_container_sedes)
1 change: 0 additions & 1 deletion ssz/sedes/uint.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@


class UInt(BasicSedes[int, int]):

def __init__(self, num_bits: int) -> None:
if num_bits % 8 != 0:
raise ValueError(
Expand Down
1 change: 0 additions & 1 deletion ssz/sedes/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@


class Vector(CompositeSedes[Sequence[TSerializableElement], Tuple[TDeserializedElement, ...]]):

def __init__(self,
element_sedes: BaseSedes[TSerializableElement, TDeserializedElement],
length: int) -> None:
Expand Down
11 changes: 11 additions & 0 deletions tests/misc/test_serializable.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,14 @@ class TestC(ssz.Serializable):
assert test_b1 == test_a1
assert test_c1 == test_a1
assert test_c2 != test_a1


def test_root():
class Test(ssz.Serializable):
fields = (
("field1", uint8),
("field2", uint8),
)

test = Test(1, 2)
assert test.root == ssz.hash_tree_root(test, Test)
60 changes: 60 additions & 0 deletions tests/misc/test_signed_serializable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import pytest

import ssz
from ssz.sedes import (
byte_list,
uint8,
)


def test_field_number_check():
with pytest.raises(TypeError):
class TestA(ssz.SignedSerializable):
fields = ()

with pytest.raises(TypeError):
class TestB(ssz.SignedSerializable):
fields = (
("signature", byte_list),
)

class TestC(ssz.SignedSerializable):
fields = (
("field1", uint8),
("signature", byte_list),
)


def test_field_name_check():
with pytest.raises(TypeError):
class TestA(ssz.SignedSerializable):
fields = (
("field1", uint8),
("field2", byte_list),
)

with pytest.raises(TypeError):
class TestB(ssz.SignedSerializable):
fields = (
("signature", uint8),
("field1", byte_list),
)


def test_signing_root():
class Signed(ssz.SignedSerializable):
fields = (
("field1", uint8),
("field2", byte_list),
("signature", byte_list),
)

class Unsigned(ssz.Serializable):
fields = (
("field1", uint8),
("field2", byte_list),
)

signed = Signed(123, b"\xaa", b"\x00")
unsigned = Unsigned(123, b"\xaa")
assert signed.signing_root == unsigned.root

0 comments on commit 97cc1cf

Please sign in to comment.