Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix homogeneous sequence boundary #117

Merged
merged 3 commits into from
Jun 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions ssz/sedes/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,14 +172,6 @@ def get_key(self, value: Any) -> str:
class HomogeneousProperCompositeSedes(
ProperCompositeSedes[TSerializable, TDeserialized]
):
def __init__(self, element_sedes: TSedes, max_length: int) -> None:
self.element_sedes = element_sedes
if max_length < 1:
raise ValueError(
f"(Maximum) length of homogenous composites must be at least 1, got {max_length}"
)
self.max_length = max_length

def get_sedes_id(self) -> str:
sedes_name = self.__class__.__name__
return f"{sedes_name}({self.element_sedes.get_sedes_id()},{self.max_length})"
Expand Down
2 changes: 1 addition & 1 deletion ssz/sedes/bitlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
class Bitlist(BitfieldCompositeSedes[BytesOrByteArray, bytes]):
def __init__(self, max_bit_count: int) -> None:
if max_bit_count < 0:
raise TypeError("Max bit count cannot be negative")
raise ValueError(f"Max bit count cannot be negative, got {max_bit_count}")
self.max_bit_count = max_bit_count

#
Expand Down
6 changes: 4 additions & 2 deletions ssz/sedes/bitvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@

class Bitvector(BitfieldCompositeSedes[BytesOrByteArray, bytes]):
def __init__(self, bit_count: int) -> None:
if bit_count <= 0:
raise TypeError("Bit count cannot be zero or negative")
if bit_count < 1:
raise ValueError(
f"Bitvector must have a size of 1 or greater, got {bit_count}"
)
self.bit_count = bit_count

#
Expand Down
10 changes: 9 additions & 1 deletion ssz/sedes/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ssz.exceptions import DeserializationError
from ssz.hashable_list import HashableList
from ssz.hashable_structure import BaseHashableStructure
from ssz.sedes.base import BaseSedes
from ssz.sedes.base import BaseSedes, TSedes
from ssz.sedes.basic import BasicSedes, HomogeneousProperCompositeSedes
from ssz.typing import CacheObj, TDeserialized, TSerializable
from ssz.utils import (
Expand All @@ -30,6 +30,14 @@
class List(
HomogeneousProperCompositeSedes[Sequence[TSerializable], Tuple[TDeserialized, ...]]
):
def __init__(self, element_sedes: TSedes, max_length: int) -> None:
if max_length < 0:
raise ValueError(
f"Lists must have a maximum length of 0 or greater, got {max_length}"
)
self.element_sedes = element_sedes
self.max_length = max_length

#
# Size
#
Expand Down
5 changes: 3 additions & 2 deletions ssz/sedes/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ class Vector(
]
):
def __init__(self, element_sedes: TSedes, length: int) -> None:
if length <= 0:
if length < 1:
Copy link

@saltiniroberto saltiniroberto Jun 18, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the change here? Given that we are dealing with int, <= 0 is equivalent to < 1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's equivalent. I wanted to make it more clear and follow the error message more tightly.

raise ValueError(f"Vectors must have a size of 1 or greater, got {length}")
super().__init__(element_sedes, max_length=length)
self.element_sedes = element_sedes
self.max_length = length

@property
def length(self) -> int:
Expand Down
2 changes: 1 addition & 1 deletion tests/sedes/test_bitvector_instantiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@


def test_bitvector_instantiation_bound():
with pytest.raises(TypeError):
with pytest.raises(ValueError):
bit_count = 0
Bitvector(bit_count)
48 changes: 47 additions & 1 deletion tests/sedes/test_composite_sedes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,17 @@
from ssz.exceptions import DeserializationError
from ssz.hashable_list import HashableList
from ssz.hashable_vector import HashableVector
from ssz.sedes import Container, List, UInt, Vector, bytes32, uint8, uint256
from ssz.sedes import (
Bitlist,
Bitvector,
Container,
List,
UInt,
Vector,
bytes32,
uint8,
uint256,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -164,3 +174,39 @@ def test_eq(sedes1, sedes2):
def test_neq(sedes1, sedes2):
assert sedes1 != sedes2
assert hash(sedes1) != hash(sedes2)


@pytest.mark.parametrize(
("sedes_type", "element_type", "length", "is_valid"),
(
(List, uint8, 0, True),
(List, uint8, -1, False),
(Vector, uint8, 1, True),
(Vector, uint8, 0, False),
),
)
def test_homogeneous_sequence_length_boundary(
sedes_type, element_type, length, is_valid
):
if is_valid:
sedes_type(element_type, length)
else:
with pytest.raises(ValueError):
sedes_type(element_type, length)


@pytest.mark.parametrize(
("sedes_type", "length", "is_valid"),
(
(Bitlist, 0, True),
(Bitlist, -1, False),
(Bitvector, 1, True),
(Bitvector, 0, False),
),
)
def test_bitfield_length_boundary(sedes_type, length, is_valid):
if is_valid:
sedes_type(length)
else:
with pytest.raises(ValueError):
sedes_type(length)