Skip to content

Commit

Permalink
Merge pull request #117 from hwwhww/fix-length-boundary
Browse files Browse the repository at this point in the history
Fix homogeneous sequence boundary
  • Loading branch information
hwwhww authored Jun 18, 2020
2 parents 7e9c107 + d5a922a commit 36f3406
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 16 deletions.
8 changes: 0 additions & 8 deletions ssz/sedes/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,14 +172,6 @@ def get_key(self, value: Any) -> str:
class HomogeneousProperCompositeSedes(
ProperCompositeSedes[TSerializable, TDeserialized]
):
def __init__(self, element_sedes: TSedes, max_length: int) -> None:
self.element_sedes = element_sedes
if max_length < 1:
raise ValueError(
f"(Maximum) length of homogenous composites must be at least 1, got {max_length}"
)
self.max_length = max_length

def get_sedes_id(self) -> str:
sedes_name = self.__class__.__name__
return f"{sedes_name}({self.element_sedes.get_sedes_id()},{self.max_length})"
Expand Down
2 changes: 1 addition & 1 deletion ssz/sedes/bitlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
class Bitlist(BitfieldCompositeSedes[BytesOrByteArray, bytes]):
def __init__(self, max_bit_count: int) -> None:
if max_bit_count < 0:
raise TypeError("Max bit count cannot be negative")
raise ValueError(f"Max bit count cannot be negative, got {max_bit_count}")
self.max_bit_count = max_bit_count

#
Expand Down
6 changes: 4 additions & 2 deletions ssz/sedes/bitvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@

class Bitvector(BitfieldCompositeSedes[BytesOrByteArray, bytes]):
def __init__(self, bit_count: int) -> None:
if bit_count <= 0:
raise TypeError("Bit count cannot be zero or negative")
if bit_count < 1:
raise ValueError(
f"Bitvector must have a size of 1 or greater, got {bit_count}"
)
self.bit_count = bit_count

#
Expand Down
10 changes: 9 additions & 1 deletion ssz/sedes/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ssz.exceptions import DeserializationError
from ssz.hashable_list import HashableList
from ssz.hashable_structure import BaseHashableStructure
from ssz.sedes.base import BaseSedes
from ssz.sedes.base import BaseSedes, TSedes
from ssz.sedes.basic import BasicSedes, HomogeneousProperCompositeSedes
from ssz.typing import CacheObj, TDeserialized, TSerializable
from ssz.utils import (
Expand All @@ -30,6 +30,14 @@
class List(
HomogeneousProperCompositeSedes[Sequence[TSerializable], Tuple[TDeserialized, ...]]
):
def __init__(self, element_sedes: TSedes, max_length: int) -> None:
if max_length < 0:
raise ValueError(
f"Lists must have a maximum length of 0 or greater, got {max_length}"
)
self.element_sedes = element_sedes
self.max_length = max_length

#
# Size
#
Expand Down
5 changes: 3 additions & 2 deletions ssz/sedes/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ class Vector(
]
):
def __init__(self, element_sedes: TSedes, length: int) -> None:
if length <= 0:
if length < 1:
raise ValueError(f"Vectors must have a size of 1 or greater, got {length}")
super().__init__(element_sedes, max_length=length)
self.element_sedes = element_sedes
self.max_length = length

@property
def length(self) -> int:
Expand Down
2 changes: 1 addition & 1 deletion tests/sedes/test_bitvector_instantiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@


def test_bitvector_instantiation_bound():
with pytest.raises(TypeError):
with pytest.raises(ValueError):
bit_count = 0
Bitvector(bit_count)
48 changes: 47 additions & 1 deletion tests/sedes/test_composite_sedes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,17 @@
from ssz.exceptions import DeserializationError
from ssz.hashable_list import HashableList
from ssz.hashable_vector import HashableVector
from ssz.sedes import Container, List, UInt, Vector, bytes32, uint8, uint256
from ssz.sedes import (
Bitlist,
Bitvector,
Container,
List,
UInt,
Vector,
bytes32,
uint8,
uint256,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -164,3 +174,39 @@ def test_eq(sedes1, sedes2):
def test_neq(sedes1, sedes2):
assert sedes1 != sedes2
assert hash(sedes1) != hash(sedes2)


@pytest.mark.parametrize(
("sedes_type", "element_type", "length", "is_valid"),
(
(List, uint8, 0, True),
(List, uint8, -1, False),
(Vector, uint8, 1, True),
(Vector, uint8, 0, False),
),
)
def test_homogeneous_sequence_length_boundary(
sedes_type, element_type, length, is_valid
):
if is_valid:
sedes_type(element_type, length)
else:
with pytest.raises(ValueError):
sedes_type(element_type, length)


@pytest.mark.parametrize(
("sedes_type", "length", "is_valid"),
(
(Bitlist, 0, True),
(Bitlist, -1, False),
(Bitvector, 1, True),
(Bitvector, 0, False),
),
)
def test_bitfield_length_boundary(sedes_type, length, is_valid):
if is_valid:
sedes_type(length)
else:
with pytest.raises(ValueError):
sedes_type(length)

0 comments on commit 36f3406

Please sign in to comment.