From 180f6d0823782a19619feef7e8f762ae809dfa26 Mon Sep 17 00:00:00 2001 From: Hsiao-Wei Wang Date: Thu, 18 Jun 2020 13:31:08 +0800 Subject: [PATCH 1/3] Fix homogeneous sequence boundary --- ssz/sedes/basic.py | 8 ----- ssz/sedes/bitlist.py | 2 +- ssz/sedes/bitvector.py | 4 +-- ssz/sedes/list.py | 8 ++++- ssz/sedes/vector.py | 5 +-- tests/sedes/test_bitvector_instantiation.py | 2 +- tests/sedes/test_composite_sedes.py | 36 ++++++++++++++++++++- 7 files changed, 49 insertions(+), 16 deletions(-) diff --git a/ssz/sedes/basic.py b/ssz/sedes/basic.py index f4928ab2..7ff67bc7 100644 --- a/ssz/sedes/basic.py +++ b/ssz/sedes/basic.py @@ -172,14 +172,6 @@ def get_key(self, value: Any) -> str: class HomogeneousProperCompositeSedes( ProperCompositeSedes[TSerializable, TDeserialized] ): - def __init__(self, element_sedes: TSedes, max_length: int) -> None: - self.element_sedes = element_sedes - if max_length < 1: - raise ValueError( - f"(Maximum) length of homogenous composites must be at least 1, got {max_length}" - ) - self.max_length = max_length - def get_sedes_id(self) -> str: sedes_name = self.__class__.__name__ return f"{sedes_name}({self.element_sedes.get_sedes_id()},{self.max_length})" diff --git a/ssz/sedes/bitlist.py b/ssz/sedes/bitlist.py index 3de87435..7e5aca75 100644 --- a/ssz/sedes/bitlist.py +++ b/ssz/sedes/bitlist.py @@ -20,7 +20,7 @@ class Bitlist(BitfieldCompositeSedes[BytesOrByteArray, bytes]): def __init__(self, max_bit_count: int) -> None: if max_bit_count < 0: - raise TypeError("Max bit count cannot be negative") + raise ValueError(f"Max bit count cannot be negative, got {max_bit_count}") self.max_bit_count = max_bit_count # diff --git a/ssz/sedes/bitvector.py b/ssz/sedes/bitvector.py index 9020cc4e..561151c8 100644 --- a/ssz/sedes/bitvector.py +++ b/ssz/sedes/bitvector.py @@ -18,8 +18,8 @@ class Bitvector(BitfieldCompositeSedes[BytesOrByteArray, bytes]): def __init__(self, bit_count: int) -> None: - if bit_count <= 0: - raise TypeError("Bit count cannot be zero or negative") + if bit_count < 1: + raise ValueError(f"Bitvector must have a size of 1 or greater, got {bit_count}") self.bit_count = bit_count # diff --git a/ssz/sedes/list.py b/ssz/sedes/list.py index 39d1436e..d51821b0 100644 --- a/ssz/sedes/list.py +++ b/ssz/sedes/list.py @@ -12,7 +12,7 @@ from ssz.exceptions import DeserializationError from ssz.hashable_list import HashableList from ssz.hashable_structure import BaseHashableStructure -from ssz.sedes.base import BaseSedes +from ssz.sedes.base import BaseSedes, TSedes from ssz.sedes.basic import BasicSedes, HomogeneousProperCompositeSedes from ssz.typing import CacheObj, TDeserialized, TSerializable from ssz.utils import ( @@ -30,6 +30,12 @@ class List( HomogeneousProperCompositeSedes[Sequence[TSerializable], Tuple[TDeserialized, ...]] ): + def __init__(self, element_sedes: TSedes, max_length: int) -> None: + if max_length < 0: + raise ValueError(f"Lists must have a size of 0 or greater, got {max_length}") + self.element_sedes = element_sedes + self.max_length = max_length + # # Size # diff --git a/ssz/sedes/vector.py b/ssz/sedes/vector.py index cc1da7f5..1052d32b 100644 --- a/ssz/sedes/vector.py +++ b/ssz/sedes/vector.py @@ -28,9 +28,10 @@ class Vector( ] ): def __init__(self, element_sedes: TSedes, length: int) -> None: - if length <= 0: + if length < 1: raise ValueError(f"Vectors must have a size of 1 or greater, got {length}") - super().__init__(element_sedes, max_length=length) + self.element_sedes = element_sedes + self.max_length = length @property def length(self) -> int: diff --git a/tests/sedes/test_bitvector_instantiation.py b/tests/sedes/test_bitvector_instantiation.py index 47d5d20f..e2a06902 100644 --- a/tests/sedes/test_bitvector_instantiation.py +++ b/tests/sedes/test_bitvector_instantiation.py @@ -4,6 +4,6 @@ def test_bitvector_instantiation_bound(): - with pytest.raises(TypeError): + with pytest.raises(ValueError): bit_count = 0 Bitvector(bit_count) diff --git a/tests/sedes/test_composite_sedes.py b/tests/sedes/test_composite_sedes.py index e86ef00f..4d882e9e 100644 --- a/tests/sedes/test_composite_sedes.py +++ b/tests/sedes/test_composite_sedes.py @@ -7,7 +7,7 @@ from ssz.exceptions import DeserializationError from ssz.hashable_list import HashableList from ssz.hashable_vector import HashableVector -from ssz.sedes import Container, List, UInt, Vector, bytes32, uint8, uint256 +from ssz.sedes import Bitlist, Bitvector, Container, List, UInt, Vector, bytes32, uint8, uint256 @pytest.mark.parametrize( @@ -164,3 +164,37 @@ def test_eq(sedes1, sedes2): def test_neq(sedes1, sedes2): assert sedes1 != sedes2 assert hash(sedes1) != hash(sedes2) + + +@pytest.mark.parametrize( + ("sedes_type", "element_type", "length", "is_valid"), + ( + (List, uint8, 0, True), + (List, uint8, -1, False), + (Vector, uint8, 1, True), + (Vector, uint8, 0, False), + ), +) +def test_homogeneous_sequence_length_boundary(sedes_type, element_type, length, is_valid): + if is_valid: + sedes_type(element_type, length) + else: + with pytest.raises(ValueError): + sedes_type(element_type, length) + + +@pytest.mark.parametrize( + ("sedes_type", "length", "is_valid"), + ( + (Bitlist, 0, True), + (Bitlist, -1, False), + (Bitvector, 1, True), + (Bitvector, 0, False), + ), +) +def test_bitfield_length_boundary(sedes_type, length, is_valid): + if is_valid: + sedes_type(length) + else: + with pytest.raises(ValueError): + sedes_type(length) From a5d797f0531e57300e0ce29ea10f177353d8c2db Mon Sep 17 00:00:00 2001 From: Hsiao-Wei Wang Date: Thu, 18 Jun 2020 15:29:10 +0800 Subject: [PATCH 2/3] Update error message wording --- ssz/sedes/list.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ssz/sedes/list.py b/ssz/sedes/list.py index d51821b0..01c3192a 100644 --- a/ssz/sedes/list.py +++ b/ssz/sedes/list.py @@ -32,7 +32,7 @@ class List( ): def __init__(self, element_sedes: TSedes, max_length: int) -> None: if max_length < 0: - raise ValueError(f"Lists must have a size of 0 or greater, got {max_length}") + raise ValueError(f"Lists must have a maximum length of 0 or greater, got {max_length}") self.element_sedes = element_sedes self.max_length = max_length From d5a922a413abf9fd10bc7016c55090d81fed95a5 Mon Sep 17 00:00:00 2001 From: Hsiao-Wei Wang Date: Thu, 18 Jun 2020 15:34:26 +0800 Subject: [PATCH 3/3] Fix linter error --- ssz/sedes/bitvector.py | 4 +++- ssz/sedes/list.py | 4 +++- tests/sedes/test_composite_sedes.py | 16 ++++++++++++++-- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/ssz/sedes/bitvector.py b/ssz/sedes/bitvector.py index 561151c8..8960f8d3 100644 --- a/ssz/sedes/bitvector.py +++ b/ssz/sedes/bitvector.py @@ -19,7 +19,9 @@ class Bitvector(BitfieldCompositeSedes[BytesOrByteArray, bytes]): def __init__(self, bit_count: int) -> None: if bit_count < 1: - raise ValueError(f"Bitvector must have a size of 1 or greater, got {bit_count}") + raise ValueError( + f"Bitvector must have a size of 1 or greater, got {bit_count}" + ) self.bit_count = bit_count # diff --git a/ssz/sedes/list.py b/ssz/sedes/list.py index 01c3192a..b3e7af78 100644 --- a/ssz/sedes/list.py +++ b/ssz/sedes/list.py @@ -32,7 +32,9 @@ class List( ): def __init__(self, element_sedes: TSedes, max_length: int) -> None: if max_length < 0: - raise ValueError(f"Lists must have a maximum length of 0 or greater, got {max_length}") + raise ValueError( + f"Lists must have a maximum length of 0 or greater, got {max_length}" + ) self.element_sedes = element_sedes self.max_length = max_length diff --git a/tests/sedes/test_composite_sedes.py b/tests/sedes/test_composite_sedes.py index 4d882e9e..52e07411 100644 --- a/tests/sedes/test_composite_sedes.py +++ b/tests/sedes/test_composite_sedes.py @@ -7,7 +7,17 @@ from ssz.exceptions import DeserializationError from ssz.hashable_list import HashableList from ssz.hashable_vector import HashableVector -from ssz.sedes import Bitlist, Bitvector, Container, List, UInt, Vector, bytes32, uint8, uint256 +from ssz.sedes import ( + Bitlist, + Bitvector, + Container, + List, + UInt, + Vector, + bytes32, + uint8, + uint256, +) @pytest.mark.parametrize( @@ -175,7 +185,9 @@ def test_neq(sedes1, sedes2): (Vector, uint8, 0, False), ), ) -def test_homogeneous_sequence_length_boundary(sedes_type, element_type, length, is_valid): +def test_homogeneous_sequence_length_boundary( + sedes_type, element_type, length, is_valid +): if is_valid: sedes_type(element_type, length) else: