diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index d874a534d..dbbf3f751 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -12,7 +12,7 @@ jobs: container: python:${{ matrix.python }}-slim strategy: matrix: - python: ['3.6', '3.7', '3.8', '3.9', '3.10'] + python: ['3.8', '3.9', '3.10'] steps: - run: python3 --version - name: Check out code diff --git a/pyteal/__init__.pyi b/pyteal/__init__.pyi index 1591faeea..de3cf3e62 100644 --- a/pyteal/__init__.pyi +++ b/pyteal/__init__.pyi @@ -170,5 +170,6 @@ __all__ = [ "UnaryExpr", "While", "WideRatio", + "abi", "compileTeal", ] diff --git a/pyteal/ast/abi/__init__.py b/pyteal/ast/abi/__init__.py index df4a814f4..b922d0b06 100644 --- a/pyteal/ast/abi/__init__.py +++ b/pyteal/ast/abi/__init__.py @@ -1,20 +1,68 @@ -from .type import Type, ComputedType -from .bool import Bool -from .uint import Uint, Byte, Uint8, Uint16, Uint32, Uint64 -from .tuple import Tuple -from .array import StaticArray, DynamicArray +from .type import TypeSpec, BaseType, ComputedType +from .bool import BoolTypeSpec, Bool +from .uint import ( + UintTypeSpec, + Uint, + ByteTypeSpec, + Byte, + Uint8TypeSpec, + Uint8, + Uint16TypeSpec, + Uint16, + Uint32TypeSpec, + Uint32, + Uint64TypeSpec, + Uint64, +) +from .tuple import ( + TupleTypeSpec, + Tuple, + TupleElement, + Tuple0, + Tuple1, + Tuple2, + Tuple3, + Tuple4, + Tuple5, +) +from .array_base import ArrayTypeSpec, Array, ArrayElement +from .array_static import StaticArrayTypeSpec, StaticArray +from .array_dynamic import DynamicArrayTypeSpec, DynamicArray +from .util import type_spec_from_annotation __all__ = [ - "Type", + "TypeSpec", + "BaseType", "ComputedType", + "BoolTypeSpec", "Bool", + "UintTypeSpec", "Uint", + "ByteTypeSpec", "Byte", + "Uint8TypeSpec", "Uint8", + "Uint16TypeSpec", "Uint16", + "Uint32TypeSpec", "Uint32", + "Uint64TypeSpec", "Uint64", + "TupleTypeSpec", "Tuple", + "TupleElement", + "Tuple0", + "Tuple1", + "Tuple2", + "Tuple3", + "Tuple4", + "Tuple5", + "ArrayTypeSpec", + "Array", + "ArrayElement", + "StaticArrayTypeSpec", "StaticArray", + "DynamicArrayTypeSpec", "DynamicArray", + "type_spec_from_annotation", ] diff --git a/pyteal/ast/abi/array.py b/pyteal/ast/abi/array.py deleted file mode 100644 index d96e5e778..000000000 --- a/pyteal/ast/abi/array.py +++ /dev/null @@ -1,476 +0,0 @@ -from typing import ( - Union, - Sequence, - TypeVar, - Generic, - Optional, - cast, -) -from abc import abstractmethod - -from ...types import TealType, require_type -from ...errors import TealInputError -from ..expr import Expr -from ..seq import Seq -from ..int import Int -from ..if_ import If -from ..unaryexpr import Len -from ..binaryexpr import ExtractUint16 -from ..naryexpr import Concat - -from .type import Type, ComputedType, substringForDecoding -from .tuple import encodeTuple -from .bool import Bool, boolSequenceLength -from .uint import Uint16 - - -T = TypeVar("T", bound=Type) - - -class Array(Type, Generic[T]): - """The base class for both ABI static array and ABI dynamic array. - - This class contains - * both underlying array element ABI type and an optional array length - * a basic implementation for inherited ABI static and dynamic array types, including - * basic array elements setting method - * string representation for ABI array - * basic encoding and decoding of ABI array - * item retrieving by index (expression or integer) - """ - - def __init__(self, valueType: T, staticLength: Optional[int]) -> None: - """Creates a new ABI array. - - This function determines distinct storage format over byte string by inferring if the array - element type is static: - * if it is static, then the stride is the static byte length of the element. - * otherwise the stride is 2 bytes, which is the size of a Uint16. - - This function also stores the static array length of an instance of ABI array, if it exists. - - Args: - valueType: The ABI type of the array element. - staticLength (optional): An integer representing the static length of the array. - """ - super().__init__(TealType.bytes) - self._valueType = valueType - - self._has_offsets = valueType.is_dynamic() - if self._has_offsets: - self._stride = Uint16().byte_length_static() - else: - self._stride = valueType.byte_length_static() - self._static_length = staticLength - - def decode( - self, - encoded: Expr, - *, - startIndex: Expr = None, - endIndex: Expr = None, - length: Expr = None - ) -> Expr: - """Decode a substring of the passed in encoded byte string and set it as this type's value. - - Args: - encoded: An expression containing the bytes to decode. Must evaluate to TealType.bytes. - startIndex (optional): An expression containing the index to start decoding. Must - evaluate to TealType.uint64. Defaults to None. - endIndex (optional): An expression containing the index to stop decoding. Must evaluate - to TealType.uint64. Defaults to None. - length (optional): An expression containing the length of the substring to decode. Must - evaluate to TealType.uint64. Defaults to None. - - Returns: - An expression that partitions the needed parts from given byte strings and stores into - the scratch variable. - """ - extracted = substringForDecoding( - encoded, startIndex=startIndex, endIndex=endIndex, length=length - ) - return self.stored_value.store(extracted) - - def set(self, values: Sequence[T]) -> Expr: - """Set the ABI array with a sequence of ABI type variables. - - The function first type-check the argument `values` to make sure the sequence of ABI type - variables before storing them to the underlying ScratchVar. If any of the input element does - not match expected array element type, error would be raised about type-mismatch. - - If static length of array is not available, this function would - * infer the array length from the sequence element number. - * store the inferred array length in uint16 format. - * concatenate the encoded array length at the beginning of array encoding. - - Args: - values: The sequence of ABI type variables to store in ABI array. - - Returns: - A PyTeal expression that stores encoded sequence of ABI values in its internal - ScratchVar. - """ - for index, value in enumerate(values): - if not self._valueType.has_same_type_as(value): - raise TealInputError( - "Cannot assign type {} at index {} to {}".format( - value, index, self._valueType - ) - ) - - encoded = encodeTuple(values) - - if self._static_length is None: - length_tmp = Uint16() - length_prefix = Seq(length_tmp.set(len(values)), length_tmp.encode()) - encoded = Concat(length_prefix, encoded) - - return self.stored_value.store(encoded) - - def encode(self) -> Expr: - """Encode the ABI array to be a byte string. - - Returns: - A PyTeal expression that encodes this ABI array to a byte string. - """ - return self.stored_value.load() - - @abstractmethod - def length(self) -> Expr: - """Get the element number of this ABI array. - - Returns: - A PyTeal expression that represents the array length. - """ - pass - - def __getitem__(self, index: Union[int, Expr]) -> "ArrayElement[T]": - """Retrieve an ABI array element by an index (either a PyTeal expression or an integer). - - If the static array length is available and the argument index is integer, the function - checks if the index is in [0, static_length - 1]. - - Args: - index: either an integer or a PyTeal expression that evaluates to a uint64. - - Returns: - An ArrayElement that represents the ABI array element at the index. - """ - if type(index) is int: - if index < 0 or ( - self._static_length is not None and index >= self._static_length - ): - raise TealInputError("Index out of bounds: {}".format(index)) - index = Int(index) - return ArrayElement(self, cast(Expr, index)) - - def __str__(self) -> str: - """Get the string representation of ABI array type, for creating method signatures.""" - return self._valueType.__str__() + ( - "[]" if self._static_length is None else "[{}]".format(self._static_length) - ) - - -Array.__module__ = "pyteal" - -# until something like https://github.com/python/mypy/issues/3345 is added, we can't make the size of the array a generic parameter -class StaticArray(Array[T]): - """The class that represents ABI static array type. - - This class requires static length on initialization. - """ - - def __init__(self, valueType: T, length: int) -> None: - """Create a new static array. - - Checks if the static array length is non-negative, if it is negative throw error. - - Args: - valueType: The ABI type of the array element. - length: An integer representing the static length of the array. - """ - if length < 0: - raise TealInputError( - "Static array length cannot be negative. Got {}".format(length) - ) - super().__init__(valueType, length) - - def has_same_type_as(self, other: Type) -> bool: - """Check if this type is considered equal to the other ABI type, irrespective of their - values. - - For static array, this is determined by: - * the equivalance of static array type. - * the underlying array element type equivalance. - * the static array length equivalance. - - Args: - other: The ABI type to compare to. - - Returns: - True if and only if self and other can store the same ABI value. - """ - return ( - type(other) is StaticArray - and self._valueType.has_same_type_as(other._valueType) - and self.length_static() == other.length_static() - ) - - def new_instance(self) -> "StaticArray[T]": - """Create a new instance of this ABI type. - - The value of this type will not be applied to the new type. - """ - return StaticArray(self._valueType, self.length_static()) - - def is_dynamic(self) -> bool: - """Check if this ABI type is dynamic. - - Whether this ABI static array is dymamic is decided by its array elements' ABI type. - """ - return self._valueType.is_dynamic() - - def byte_length_static(self) -> int: - """Get the byte length of this ABI static array's encoding. - - Only valid when array elements' type is static. - """ - if self.is_dynamic(): - raise ValueError("Type is dynamic") - if type(self._valueType) is Bool: - return boolSequenceLength(self.length_static()) - return self.length_static() * self._valueType.byte_length_static() - - def set(self, values: Union[Sequence[T], "StaticArray[T]"]) -> Expr: - """Set the ABI static array with a sequence of ABI type variables, or another ABI static - array. - - This function determines if the argument `values` is an ABI static array: - * if so: - * checks whether `values` is same type as this ABI staic array. - * stores the encoding of `values`. - * if not: - * checks whether static array length matches sequence length. - * calls the inherited `set` function and stores `values`. - - Args: - values: either a sequence of ABI typed values, or an ABI static array. - - Returns: - A PyTeal expression that stores encoded `values` in its internal ScratchVar. - """ - if isinstance(values, Type): - if not self.has_same_type_as(values): - raise TealInputError("Cannot assign type {} to {}".format(values, self)) - return self.stored_value.store(cast(StaticArray[T], values).encode()) - - if self.length_static() != len(values): - raise TealInputError( - "Incorrect length for values. Expected {}, got {}".format( - self.length_static(), len(values) - ) - ) - return super().set(values) - - def length_static(self) -> int: - """Get the element number of this static ABI array. - - Returns: - A Python integer that represents the static array length. - """ - return cast(int, self._static_length) - - def length(self) -> Expr: - """Get the element number of this ABI static array. - - Returns: - A PyTeal expression that represents the static array length. - """ - return Int(self.length_static()) - - -StaticArray.__module__ = "pyteal" - - -class DynamicArray(Array[T]): - """The class that represents ABI dynamic array type.""" - - def __init__(self, valueType: T) -> None: - """Creates a new dynamic array. - - Args: - valueType: The ABI type of the array element. - """ - super().__init__(valueType, None) - - def has_same_type_as(self, other: Type) -> bool: - """Check if this type is considered equal to the other ABI type, irrespective of their - values. - - For dynamic array, this is determined by - * the equivalance of dynamic array type. - * the underlying array element type equivalence. - - Args: - other: The ABI type to compare to. - - Returns: - True if and only if self and other can store the same ABI value. - """ - return type(other) is DynamicArray and self._valueType.has_same_type_as( - other._valueType - ) - - def new_instance(self) -> "DynamicArray[T]": - """Create a new instance of this ABI type. - - The value of this type will not be applied to the new type. - """ - return DynamicArray(self._valueType) - - def is_dynamic(self) -> bool: - """Check if this ABI type is dynamic. - - An ABI dynamic array is always dynamic. - """ - return True - - def byte_length_static(self) -> int: - """Get the byte length of this ABI dynamic array's encoding. - - Always raise error for this method is only valid for static ABI types. - """ - raise ValueError("Type is dynamic") - - def set(self, values: Union[Sequence[T], "DynamicArray[T]"]) -> Expr: - """Set the ABI dynamic array with a sequence of ABI type variables, or another ABI dynamic - array. - - This function determines if the argument `values` is an ABI dynamic array: - * if so: - * checks whether `values` is same type as this ABI dynamic array. - * stores the encoding of `values`. - * if not: - * calls the inherited `set` function and stores `values`. - - Args: - values: either a sequence of ABI typed values, or an ABI dynamic array. - - Returns: - A PyTeal expression that stores encoded `values` in its internal ScratchVar. - """ - if isinstance(values, Type): - if not self.has_same_type_as(values): - raise TealInputError("Cannot assign type {} to {}".format(values, self)) - return self.stored_value.store(cast(DynamicArray[T], values).encode()) - return super().set(values) - - def length(self) -> Expr: - """Get the element number of this ABI dynamic array. - - The array length (element number) is encoded in the first 2 bytes of the byte encoding. - - Returns: - A PyTeal expression that represents the dynamic array length. - """ - output = Uint16() - return Seq( - output.decode(self.encode()), - output.get(), - ) - - -DynamicArray.__module__ = "pyteal" - - -class ArrayElement(ComputedType[T]): - """The class that represents an ABI array element. - - This class requires a reference to the array that the array element belongs to, and a PyTeal - expression (required to be TealType.uint64) which stands for array index. - """ - - def __init__(self, array: Array[T], index: Expr) -> None: - """Creates a new ArrayElement. - - Args: - array: The ABI array that the array element belongs to. - index: A PyTeal expression (required to be TealType.uint64) stands for array index. - """ - super().__init__(array._valueType) - require_type(index, TealType.uint64) - self.array = array - self.index = index - - def store_into(self, output: T) -> Expr: - """Partitions the byte string of the given ABI array and stores the byte string of array - element in the ABI value output. - - The function first checks if the output type matches with array element type, and throw - error if type-mismatch. - - Args: - output: An ABI typed value that the array element byte string stores into. - - Returns: - An expression that stores the byte string of the array element into value `output`. - """ - if not self.array._valueType.has_same_type_as(output): - raise TealInputError("Output type does not match value type") - - encodedArray = self.array.encode() - - # If the array element type is Bool, we compute the bit index - # (if array is dynamic we add 16 to bit index for dynamic array length uint16 prefix) - # and decode bit with given array encoding and the bit index for boolean bit. - if type(output) is Bool: - bitIndex = self.index - if self.array.is_dynamic(): - bitIndex = bitIndex + Int(Uint16().bits()) - return cast(Bool, output).decodeBit(encodedArray, bitIndex) - - # Compute the byteIndex (first byte indicating the element encoding) - # (If the array is dynamic, add 2 to byte index for dynamic array length uint16 prefix) - byteIndex = Int(self.array._stride) * self.index - if self.array._static_length is None: - byteIndex = byteIndex + Int(Uint16().byte_length_static()) - - arrayLength = self.array.length() - - # Handling case for array elements are dynamic: - # * `byteIndex` is pointing at the uint16 byte encoding indicating the beginning offset of - # the array element byte encoding. - # - # * `valueStart` is extracted from the uint16 bytes pointed by `byteIndex`. - # - # * If `index == arrayLength - 1` (last element in array), `valueEnd` is pointing at the - # end of the array byte encoding. - # - # * otherwise, `valueEnd` is inferred from `nextValueStart`, which is the beginning offset - # of the next array element byte encoding. - if self.array._valueType.is_dynamic(): - valueStart = ExtractUint16(encodedArray, byteIndex) - nextValueStart = ExtractUint16( - encodedArray, byteIndex + Int(Uint16().byte_length_static()) - ) - if self.array._static_length is None: - valueStart = valueStart + Int(Uint16().byte_length_static()) - nextValueStart = nextValueStart + Int(Uint16().byte_length_static()) - - valueEnd = ( - If(self.index + Int(1) == arrayLength) - .Then(Len(encodedArray)) - .Else(nextValueStart) - ) - - return output.decode(encodedArray, startIndex=valueStart, endIndex=valueEnd) - - # Handling case for array elements are static: - # since array._stride is element's static byte length - # we partition the substring for array element. - valueStart = byteIndex - valueLength = Int(self.array._stride) - return output.decode(encodedArray, startIndex=valueStart, length=valueLength) - - -ArrayElement.__module__ = "pyteal" diff --git a/pyteal/ast/abi/array_base.py b/pyteal/ast/abi/array_base.py new file mode 100644 index 000000000..676146d75 --- /dev/null +++ b/pyteal/ast/abi/array_base.py @@ -0,0 +1,281 @@ +from typing import ( + Union, + Sequence, + TypeVar, + Generic, + Final, + cast, +) +from abc import abstractmethod + +from ...types import TealType, require_type +from ...errors import TealInputError +from ..expr import Expr +from ..seq import Seq +from ..int import Int +from ..if_ import If +from ..unaryexpr import Len +from ..binaryexpr import ExtractUint16 +from ..naryexpr import Concat + +from .type import TypeSpec, BaseType, ComputedType +from .tuple import encodeTuple +from .bool import Bool, BoolTypeSpec +from .uint import Uint16, Uint16TypeSpec +from .util import substringForDecoding + +T = TypeVar("T", bound=BaseType) + + +class ArrayTypeSpec(TypeSpec, Generic[T]): + """The abstract base class for both static and dynamic array TypeSpecs.""" + + def __init__(self, value_type_spec: TypeSpec) -> None: + super().__init__() + self.value_spec: Final = value_type_spec + + def value_type_spec(self) -> TypeSpec: + """Get the TypeSpec of the value type this array can hold.""" + return self.value_spec + + def storage_type(self) -> TealType: + return TealType.bytes + + @abstractmethod + def is_length_dynamic(self) -> bool: + """Check if this array has a dynamic or static length.""" + pass + + def _stride(self) -> int: + """Get the "stride" of this array. + + The stride is defined as the byte length of each element in the array's encoded "head" + portion. + + If the underlying value type is static, then the stride is the static byte length of that + type. Otherwise, the stride is the static byte length of a Uint16 (2 bytes). + """ + if self.value_spec.is_dynamic(): + return Uint16TypeSpec().byte_length_static() + return self.value_spec.byte_length_static() + + +ArrayTypeSpec.__module__ = "pyteal" + + +class Array(BaseType, Generic[T]): + """The abstract base class for both ABI static and dynamic array instances. + + This class contains basic implementations of ABI array methods, including: + * basic array elements setting method + * basic encoding and decoding of ABI array + * item retrieving by index (expression or integer) + """ + + def __init__(self, spec: ArrayTypeSpec) -> None: + super().__init__(spec) + + def type_spec(self) -> ArrayTypeSpec[T]: + return cast(ArrayTypeSpec, super().type_spec()) + + def decode( + self, + encoded: Expr, + *, + startIndex: Expr = None, + endIndex: Expr = None, + length: Expr = None + ) -> Expr: + """Decode a substring of the passed in encoded byte string and set it as this type's value. + + Args: + encoded: An expression containing the bytes to decode. Must evaluate to TealType.bytes. + startIndex (optional): An expression containing the index to start decoding. Must + evaluate to TealType.uint64. Defaults to None. + endIndex (optional): An expression containing the index to stop decoding. Must evaluate + to TealType.uint64. Defaults to None. + length (optional): An expression containing the length of the substring to decode. Must + evaluate to TealType.uint64. Defaults to None. + + Returns: + An expression that partitions the needed parts from given byte strings and stores into + the scratch variable. + """ + extracted = substringForDecoding( + encoded, startIndex=startIndex, endIndex=endIndex, length=length + ) + return self.stored_value.store(extracted) + + def set(self, values: Sequence[T]) -> Expr: + """Set the ABI array with a sequence of ABI type variables. + + The function first type-check the argument `values` to make sure the sequence of ABI type + variables before storing them to the underlying ScratchVar. If any of the input element does + not match expected array element type, error would be raised about type-mismatch. + + If static length of array is not available, this function would + * infer the array length from the sequence element number. + * store the inferred array length in uint16 format. + * concatenate the encoded array length at the beginning of array encoding. + + Args: + values: The sequence of ABI type variables to store in ABI array. + + Returns: + A PyTeal expression that stores encoded sequence of ABI values in its internal + ScratchVar. + """ + for index, value in enumerate(values): + if self.type_spec().value_type_spec() != value.type_spec(): + raise TealInputError( + "Cannot assign type {} at index {} to {}".format( + value.type_spec(), + index, + self.type_spec().value_type_spec(), + ) + ) + + encoded = encodeTuple(values) + + if self.type_spec().is_length_dynamic(): + length_tmp = Uint16() + length_prefix = Seq(length_tmp.set(len(values)), length_tmp.encode()) + encoded = Concat(length_prefix, encoded) + + return self.stored_value.store(encoded) + + def encode(self) -> Expr: + """Encode the ABI array to be a byte string. + + Returns: + A PyTeal expression that encodes this ABI array to a byte string. + """ + return self.stored_value.load() + + @abstractmethod + def length(self) -> Expr: + """Get the element number of this ABI array. + + Returns: + A PyTeal expression that represents the array length. + """ + pass + + def __getitem__(self, index: Union[int, Expr]) -> "ArrayElement[T]": + """Retrieve an ABI array element by an index (either a PyTeal expression or an integer). + + If the argument index is integer, the function will raise an error if the index is negative. + + Args: + index: either an integer or a PyTeal expression that evaluates to a uint64. + + Returns: + An ArrayElement that represents the ABI array element at the index. + """ + if type(index) is int: + if index < 0: + raise TealInputError("Index out of bounds: {}".format(index)) + index = Int(index) + return ArrayElement(self, cast(Expr, index)) + + +Array.__module__ = "pyteal" + + +class ArrayElement(ComputedType[T]): + """The class that represents an ABI array element. + + This class requires a reference to the array that the array element belongs to, and a PyTeal + expression (required to be TealType.uint64) which contains the array index. + """ + + def __init__(self, array: Array[T], index: Expr) -> None: + """Creates a new ArrayElement. + + Args: + array: The ABI array that the array element belongs to. + index: A PyTeal expression (required to be TealType.uint64) stands for array index. + """ + super().__init__() + require_type(index, TealType.uint64) + self.array = array + self.index = index + + def produced_type_spec(self) -> TypeSpec: + return self.array.type_spec().value_type_spec() + + def store_into(self, output: T) -> Expr: + """Partitions the byte string of the given ABI array and stores the byte string of array + element in the ABI value output. + + The function first checks if the output type matches with array element type, and throw + error if type-mismatch. + + Args: + output: An ABI typed value that the array element byte string stores into. + + Returns: + An expression that stores the byte string of the array element into value `output`. + """ + if output.type_spec() != self.produced_type_spec(): + raise TealInputError("Output type does not match value type") + + encodedArray = self.array.encode() + arrayType = self.array.type_spec() + + # If the array element type is Bool, we compute the bit index + # (if array is dynamic we add 16 to bit index for dynamic array length uint16 prefix) + # and decode bit with given array encoding and the bit index for boolean bit. + if output.type_spec() == BoolTypeSpec(): + bitIndex = self.index + if arrayType.is_dynamic(): + bitIndex = bitIndex + Int(Uint16TypeSpec().bit_size()) + return cast(Bool, output).decodeBit(encodedArray, bitIndex) + + # Compute the byteIndex (first byte indicating the element encoding) + # (If the array is dynamic, add 2 to byte index for dynamic array length uint16 prefix) + byteIndex = Int(arrayType._stride()) * self.index + if arrayType.is_length_dynamic(): + byteIndex = byteIndex + Int(Uint16TypeSpec().byte_length_static()) + + arrayLength = self.array.length() + + # Handling case for array elements are dynamic: + # * `byteIndex` is pointing at the uint16 byte encoding indicating the beginning offset of + # the array element byte encoding. + # + # * `valueStart` is extracted from the uint16 bytes pointed by `byteIndex`. + # + # * If `index == arrayLength - 1` (last element in array), `valueEnd` is pointing at the + # end of the array byte encoding. + # + # * otherwise, `valueEnd` is inferred from `nextValueStart`, which is the beginning offset + # of the next array element byte encoding. + if arrayType.value_type_spec().is_dynamic(): + valueStart = ExtractUint16(encodedArray, byteIndex) + nextValueStart = ExtractUint16( + encodedArray, byteIndex + Int(Uint16TypeSpec().byte_length_static()) + ) + if arrayType.is_length_dynamic(): + valueStart = valueStart + Int(Uint16TypeSpec().byte_length_static()) + nextValueStart = nextValueStart + Int( + Uint16TypeSpec().byte_length_static() + ) + + valueEnd = ( + If(self.index + Int(1) == arrayLength) + .Then(Len(encodedArray)) + .Else(nextValueStart) + ) + + return output.decode(encodedArray, startIndex=valueStart, endIndex=valueEnd) + + # Handling case for array elements are static: + # since array._stride() is element's static byte length + # we partition the substring for array element. + valueStart = byteIndex + valueLength = Int(arrayType._stride()) + return output.decode(encodedArray, startIndex=valueStart, length=valueLength) + + +ArrayElement.__module__ = "pyteal" diff --git a/pyteal/ast/abi/array_base_test.py b/pyteal/ast/abi/array_base_test.py new file mode 100644 index 000000000..d1120e5e4 --- /dev/null +++ b/pyteal/ast/abi/array_base_test.py @@ -0,0 +1,158 @@ +from typing import List, cast +import pytest + +from ... import * + +options = CompileOptions(version=5) + +STATIC_TYPES: List[abi.TypeSpec] = [ + abi.BoolTypeSpec(), + abi.Uint8TypeSpec(), + abi.Uint16TypeSpec(), + abi.Uint32TypeSpec(), + abi.Uint64TypeSpec(), + abi.TupleTypeSpec(), + abi.TupleTypeSpec(abi.BoolTypeSpec(), abi.BoolTypeSpec(), abi.Uint64TypeSpec()), + abi.StaticArrayTypeSpec(abi.BoolTypeSpec(), 10), + abi.StaticArrayTypeSpec(abi.Uint8TypeSpec(), 10), + abi.StaticArrayTypeSpec(abi.Uint16TypeSpec(), 10), + abi.StaticArrayTypeSpec(abi.Uint32TypeSpec(), 10), + abi.StaticArrayTypeSpec(abi.Uint64TypeSpec(), 10), + abi.StaticArrayTypeSpec( + abi.TupleTypeSpec(abi.BoolTypeSpec(), abi.BoolTypeSpec(), abi.Uint64TypeSpec()), + 10, + ), +] + +DYNAMIC_TYPES: List[abi.TypeSpec] = [ + abi.DynamicArrayTypeSpec(abi.BoolTypeSpec()), + abi.DynamicArrayTypeSpec(abi.Uint8TypeSpec()), + abi.DynamicArrayTypeSpec(abi.Uint16TypeSpec()), + abi.DynamicArrayTypeSpec(abi.Uint32TypeSpec()), + abi.DynamicArrayTypeSpec(abi.Uint64TypeSpec()), + abi.DynamicArrayTypeSpec(abi.TupleTypeSpec()), + abi.DynamicArrayTypeSpec( + abi.TupleTypeSpec(abi.BoolTypeSpec(), abi.BoolTypeSpec(), abi.Uint64TypeSpec()) + ), + abi.DynamicArrayTypeSpec(abi.StaticArrayTypeSpec(abi.BoolTypeSpec(), 10)), + abi.DynamicArrayTypeSpec(abi.StaticArrayTypeSpec(abi.Uint8TypeSpec(), 10)), + abi.DynamicArrayTypeSpec(abi.StaticArrayTypeSpec(abi.Uint16TypeSpec(), 10)), + abi.DynamicArrayTypeSpec(abi.StaticArrayTypeSpec(abi.Uint32TypeSpec(), 10)), + abi.DynamicArrayTypeSpec(abi.StaticArrayTypeSpec(abi.Uint64TypeSpec(), 10)), + abi.DynamicArrayTypeSpec( + abi.StaticArrayTypeSpec( + abi.TupleTypeSpec( + abi.BoolTypeSpec(), abi.BoolTypeSpec(), abi.Uint64TypeSpec() + ), + 10, + ) + ), +] + + +def test_ArrayElement_init(): + dynamicArrayType = abi.DynamicArrayTypeSpec(abi.Uint64TypeSpec()) + array = dynamicArrayType.new_instance() + index = Int(6) + + element = abi.ArrayElement(array, index) + assert element.array is array + assert element.index is index + + with pytest.raises(TealTypeError): + abi.ArrayElement(array, Bytes("abc")) + + with pytest.raises(TealTypeError): + abi.ArrayElement(array, Assert(index)) + + +def test_ArrayElement_store_into(): + for elementType in STATIC_TYPES + DYNAMIC_TYPES: + staticArrayType = abi.StaticArrayTypeSpec(elementType, 100) + staticArray = staticArrayType.new_instance() + index = Int(9) + + element = abi.ArrayElement(staticArray, index) + output = elementType.new_instance() + expr = element.store_into(output) + + encoded = staticArray.encode() + stride = Int(staticArray.type_spec()._stride()) + expectedLength = staticArray.length() + if elementType == abi.BoolTypeSpec(): + expectedExpr = cast(abi.Bool, output).decodeBit(encoded, index) + elif not elementType.is_dynamic(): + expectedExpr = output.decode( + encoded, startIndex=stride * index, length=stride + ) + else: + expectedExpr = output.decode( + encoded, + startIndex=ExtractUint16(encoded, stride * index), + endIndex=If(index + Int(1) == expectedLength) + .Then(Len(encoded)) + .Else(ExtractUint16(encoded, stride * index + Int(2))), + ) + + expected, _ = expectedExpr.__teal__(options) + expected.addIncoming() + expected = TealBlock.NormalizeBlocks(expected) + + actual, _ = expr.__teal__(options) + actual.addIncoming() + actual = TealBlock.NormalizeBlocks(actual) + + with TealComponent.Context.ignoreExprEquality(): + assert actual == expected + + with pytest.raises(TealInputError): + element.store_into(abi.Tuple(elementType)) + + for elementType in STATIC_TYPES + DYNAMIC_TYPES: + dynamicArrayType = abi.DynamicArrayTypeSpec(elementType) + dynamicArray = dynamicArrayType.new_instance() + index = Int(9) + + element = abi.ArrayElement(dynamicArray, index) + output = elementType.new_instance() + expr = element.store_into(output) + + encoded = dynamicArray.encode() + stride = Int(dynamicArray.type_spec()._stride()) + expectedLength = dynamicArray.length() + if elementType == abi.BoolTypeSpec(): + expectedExpr = cast(abi.Bool, output).decodeBit(encoded, index + Int(16)) + elif not elementType.is_dynamic(): + expectedExpr = output.decode( + encoded, startIndex=stride * index + Int(2), length=stride + ) + else: + expectedExpr = output.decode( + encoded, + startIndex=ExtractUint16(encoded, stride * index + Int(2)) + Int(2), + endIndex=If(index + Int(1) == expectedLength) + .Then(Len(encoded)) + .Else( + ExtractUint16(encoded, stride * index + Int(2) + Int(2)) + Int(2) + ), + ) + + expected, _ = expectedExpr.__teal__(options) + expected.addIncoming() + expected = TealBlock.NormalizeBlocks(expected) + + actual, _ = expr.__teal__(options) + actual.addIncoming() + actual = TealBlock.NormalizeBlocks(actual) + + with TealComponent.Context.ignoreExprEquality(): + with TealComponent.Context.ignoreScratchSlotEquality(): + assert actual == expected + + assert TealBlock.MatchScratchSlotReferences( + TealBlock.GetReferencedScratchSlots(actual), + TealBlock.GetReferencedScratchSlots(expected), + ) + + with pytest.raises(TealInputError): + element.store_into(abi.Tuple(elementType)) diff --git a/pyteal/ast/abi/array_dynamic.py b/pyteal/ast/abi/array_dynamic.py new file mode 100644 index 000000000..8af860818 --- /dev/null +++ b/pyteal/ast/abi/array_dynamic.py @@ -0,0 +1,98 @@ +from typing import ( + Union, + Sequence, + TypeVar, + cast, +) + + +from ...errors import TealInputError +from ..expr import Expr +from ..seq import Seq + +from .type import TypeSpec, BaseType +from .uint import Uint16 +from .array_base import ArrayTypeSpec, Array + + +T = TypeVar("T", bound=BaseType) + + +class DynamicArrayTypeSpec(ArrayTypeSpec[T]): + def new_instance(self) -> "DynamicArray[T]": + return DynamicArray(self.value_type_spec()) + + def is_length_dynamic(self) -> bool: + return True + + def is_dynamic(self) -> bool: + return True + + def byte_length_static(self) -> int: + raise ValueError("Type is dynamic") + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, DynamicArrayTypeSpec) + and self.value_type_spec() == other.value_type_spec() + ) + + def __str__(self) -> str: + return "{}[]".format(self.value_type_spec()) + + +DynamicArrayTypeSpec.__module__ = "pyteal" + + +class DynamicArray(Array[T]): + """The class that represents ABI dynamic array type.""" + + def __init__(self, value_type_spec: TypeSpec) -> None: + super().__init__(DynamicArrayTypeSpec(value_type_spec)) + + def type_spec(self) -> DynamicArrayTypeSpec[T]: + return cast(DynamicArrayTypeSpec[T], super().type_spec()) + + def set(self, values: Union[Sequence[T], "DynamicArray[T]"]) -> Expr: + """Set the ABI dynamic array with a sequence of ABI type variables, or another ABI dynamic + array. + + This function determines if the argument `values` is an ABI dynamic array: + * if so: + * checks whether `values` is same type as this ABI dynamic array. + * stores the encoding of `values`. + * if not: + * calls the inherited `set` function and stores `values`. + + Args: + values: either a sequence of ABI typed values, or an ABI dynamic array. + + Returns: + A PyTeal expression that stores encoded `values` in its internal ScratchVar. + """ + if isinstance(values, BaseType): + if self.type_spec() != values.type_spec(): + raise TealInputError( + "Cannot assign type {} to {}".format( + values.type_spec(), self.type_spec() + ) + ) + return self.stored_value.store(values.encode()) + return super().set(values) + + def length(self) -> Expr: + """Get the element number of this ABI dynamic array. + + The array length (element number) is encoded in the first 2 bytes of the byte encoding. + + Returns: + A PyTeal expression that represents the dynamic array length. + """ + output = Uint16() + return Seq( + output.decode(self.encode()), + output.get(), + ) + + +DynamicArray.__module__ = "pyteal" diff --git a/pyteal/ast/abi/array_dynamic_test.py b/pyteal/ast/abi/array_dynamic_test.py new file mode 100644 index 000000000..7b8704fa4 --- /dev/null +++ b/pyteal/ast/abi/array_dynamic_test.py @@ -0,0 +1,235 @@ +from typing import List +import pytest + +from ... import * +from .util import substringForDecoding +from .tuple import encodeTuple +from .array_base_test import STATIC_TYPES, DYNAMIC_TYPES + +options = CompileOptions(version=5) + + +def test_DynamicArrayTypeSpec_init(): + for elementType in STATIC_TYPES: + dynamicArrayType = abi.DynamicArrayTypeSpec(elementType) + assert dynamicArrayType.value_type_spec() is elementType + assert dynamicArrayType.is_length_dynamic() + assert dynamicArrayType._stride() == elementType.byte_length_static() + + for elementType in DYNAMIC_TYPES: + dynamicArrayType = abi.DynamicArrayTypeSpec(elementType) + assert dynamicArrayType.value_type_spec() is elementType + assert dynamicArrayType.is_length_dynamic() + assert dynamicArrayType._stride() == 2 + + +def test_DynamicArrayTypeSpec_str(): + for elementType in STATIC_TYPES + DYNAMIC_TYPES: + dynamicArrayType = abi.DynamicArrayTypeSpec(elementType) + assert str(dynamicArrayType) == "{}[]".format(elementType) + + +def test_DynamicArrayTypeSpec_new_instance(): + for elementType in STATIC_TYPES + DYNAMIC_TYPES: + dynamicArrayType = abi.DynamicArrayTypeSpec(elementType) + instance = dynamicArrayType.new_instance() + assert isinstance(instance, abi.DynamicArray) + assert instance.type_spec() == dynamicArrayType + + +def test_DynamicArrayTypeSpec_eq(): + for elementType in STATIC_TYPES + DYNAMIC_TYPES: + dynamicArrayType = abi.DynamicArrayTypeSpec(elementType) + assert dynamicArrayType == dynamicArrayType + assert dynamicArrayType != abi.TupleTypeSpec(dynamicArrayType) + + +def test_DynamicArrayTypeSpec_is_dynamic(): + for elementType in STATIC_TYPES + DYNAMIC_TYPES: + dynamicArrayType = abi.DynamicArrayTypeSpec(elementType) + assert dynamicArrayType.is_dynamic() + + +def test_DynamicArrayTypeSpec_byte_length_static(): + for elementType in STATIC_TYPES + DYNAMIC_TYPES: + dynamicArrayType = abi.DynamicArrayTypeSpec(elementType) + with pytest.raises(ValueError): + dynamicArrayType.byte_length_static() + + +def test_DynamicArray_decode(): + encoded = Bytes("encoded") + dynamicArrayType = abi.DynamicArrayTypeSpec(abi.Uint64TypeSpec()) + for startIndex in (None, Int(1)): + for endIndex in (None, Int(2)): + for length in (None, Int(3)): + value = dynamicArrayType.new_instance() + + if endIndex is not None and length is not None: + with pytest.raises(TealInputError): + value.decode( + encoded, + startIndex=startIndex, + endIndex=endIndex, + length=length, + ) + continue + + expr = value.decode( + encoded, startIndex=startIndex, endIndex=endIndex, length=length + ) + assert expr.type_of() == TealType.none + assert not expr.has_return() + + expectedExpr = value.stored_value.store( + substringForDecoding( + encoded, startIndex=startIndex, endIndex=endIndex, length=length + ) + ) + expected, _ = expectedExpr.__teal__(options) + expected.addIncoming() + expected = TealBlock.NormalizeBlocks(expected) + + actual, _ = expr.__teal__(options) + actual.addIncoming() + actual = TealBlock.NormalizeBlocks(actual) + + with TealComponent.Context.ignoreExprEquality(): + assert actual == expected + + +def test_DynamicArray_set_values(): + valuesToSet: List[abi.Uint64] = [ + [], + [abi.Uint64()], + [abi.Uint64() for _ in range(10)], + ] + + dynamicArrayType = abi.DynamicArrayTypeSpec(abi.Uint64TypeSpec()) + for values in valuesToSet: + value = dynamicArrayType.new_instance() + expr = value.set(values) + assert expr.type_of() == TealType.none + assert not expr.has_return() + + length_tmp = abi.Uint16() + expectedExpr = value.stored_value.store( + Concat( + Seq(length_tmp.set(len(values)), length_tmp.encode()), + encodeTuple(values), + ) + ) + expected, _ = expectedExpr.__teal__(options) + expected.addIncoming() + expected = TealBlock.NormalizeBlocks(expected) + + actual, _ = expr.__teal__(options) + actual.addIncoming() + actual = TealBlock.NormalizeBlocks(actual) + + with TealComponent.Context.ignoreExprEquality(): + with TealComponent.Context.ignoreScratchSlotEquality(): + assert actual == expected + + assert TealBlock.MatchScratchSlotReferences( + TealBlock.GetReferencedScratchSlots(actual), + TealBlock.GetReferencedScratchSlots(expected), + ) + + +def test_DynamicArray_set_copy(): + dynamicArrayType = abi.DynamicArrayTypeSpec(abi.Uint64TypeSpec()) + value = dynamicArrayType.new_instance() + otherArray = dynamicArrayType.new_instance() + + with pytest.raises(TealInputError): + value.set(abi.DynamicArray(abi.DynamicArrayTypeSpec(abi.Uint8TypeSpec()))) + + with pytest.raises(TealInputError): + value.set(abi.Uint64()) + + expr = value.set(otherArray) + assert expr.type_of() == TealType.none + assert not expr.has_return() + + expected = TealSimpleBlock( + [ + TealOp(None, Op.load, otherArray.stored_value.slot), + TealOp(None, Op.store, value.stored_value.slot), + ] + ) + + actual, _ = expr.__teal__(options) + actual.addIncoming() + actual = TealBlock.NormalizeBlocks(actual) + + with TealComponent.Context.ignoreExprEquality(): + assert actual == expected + + +def test_DynamicArray_encode(): + dynamicArrayType = abi.DynamicArrayTypeSpec(abi.Uint64TypeSpec()) + value = dynamicArrayType.new_instance() + expr = value.encode() + assert expr.type_of() == TealType.bytes + assert not expr.has_return() + + expected = TealSimpleBlock([TealOp(None, Op.load, value.stored_value.slot)]) + + actual, _ = expr.__teal__(options) + actual.addIncoming() + actual = TealBlock.NormalizeBlocks(actual) + + with TealComponent.Context.ignoreExprEquality(): + assert actual == expected + + +def test_DynamicArray_length(): + dynamicArrayType = abi.DynamicArrayTypeSpec(abi.Uint64TypeSpec()) + value = dynamicArrayType.new_instance() + expr = value.length() + assert expr.type_of() == TealType.uint64 + assert not expr.has_return() + + length_tmp = abi.Uint16() + expectedExpr = Seq(length_tmp.decode(value.encode()), length_tmp.get()) + expected, _ = expectedExpr.__teal__(options) + expected.addIncoming() + expected = TealBlock.NormalizeBlocks(expected) + + actual, _ = expr.__teal__(options) + actual.addIncoming() + actual = TealBlock.NormalizeBlocks(actual) + + with TealComponent.Context.ignoreExprEquality(): + with TealComponent.Context.ignoreScratchSlotEquality(): + assert actual == expected + + assert TealBlock.MatchScratchSlotReferences( + TealBlock.GetReferencedScratchSlots(actual), + TealBlock.GetReferencedScratchSlots(expected), + ) + + +def test_DynamicArray_getitem(): + dynamicArrayType = abi.DynamicArrayTypeSpec(abi.Uint64TypeSpec()) + value = dynamicArrayType.new_instance() + + for index in (0, 1, 2, 3, 1000): + # dynamic indexes + indexExpr = Int(index) + element = value[indexExpr] + assert type(element) is abi.ArrayElement + assert element.array is value + assert element.index is indexExpr + + for index in (0, 1, 2, 3, 1000): + # static indexes + element = value[index] + assert type(element) is abi.ArrayElement + assert element.array is value + assert type(element.index) is Int + assert element.index.value == index + + with pytest.raises(TealInputError): + value[-1] diff --git a/pyteal/ast/abi/array_static.py b/pyteal/ast/abi/array_static.py new file mode 100644 index 000000000..5396a43ad --- /dev/null +++ b/pyteal/ast/abi/array_static.py @@ -0,0 +1,130 @@ +from typing import ( + Union, + Sequence, + TypeVar, + Generic, + Final, + cast, +) + +from ...errors import TealInputError +from ..expr import Expr +from ..int import Int + +from .type import TypeSpec, BaseType +from .bool import BoolTypeSpec, boolSequenceLength +from .array_base import ArrayTypeSpec, Array, ArrayElement + + +T = TypeVar("T", bound=BaseType) +N = TypeVar("N", bound=int) + + +class StaticArrayTypeSpec(ArrayTypeSpec[T], Generic[T, N]): + def __init__(self, value_type_spec: TypeSpec, array_length: int) -> None: + super().__init__(value_type_spec) + if not isinstance(array_length, int) or array_length < 0: + raise TypeError("Unsupported StaticArray length: {}".format(array_length)) + self.array_length: Final = array_length + + def new_instance(self) -> "StaticArray[T, N]": + return StaticArray(self.value_type_spec(), self.length_static()) + + def length_static(self) -> int: + """Get the size of this static array type. + + Returns: + A Python integer that represents the static array length. + """ + return self.array_length + + def is_length_dynamic(self) -> bool: + return False + + def is_dynamic(self) -> bool: + return self.value_type_spec().is_dynamic() + + def byte_length_static(self) -> int: + if self.is_dynamic(): + raise ValueError("Type is dynamic") + + value_type = self.value_type_spec() + length = self.length_static() + + if value_type == BoolTypeSpec(): + return boolSequenceLength(length) + return length * value_type.byte_length_static() + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, StaticArrayTypeSpec) + and self.value_type_spec() == other.value_type_spec() + and self.length_static() == other.length_static() + ) + + def __str__(self) -> str: + return "{}[{}]".format(self.value_type_spec(), self.length_static()) + + +StaticArrayTypeSpec.__module__ = "pyteal" + + +class StaticArray(Array[T], Generic[T, N]): + """The class that represents ABI static array type.""" + + def __init__(self, value_type_spec: TypeSpec, array_length: int) -> None: + super().__init__(StaticArrayTypeSpec(value_type_spec, array_length)) + + def type_spec(self) -> StaticArrayTypeSpec[T, N]: + return cast(StaticArrayTypeSpec[T, N], super().type_spec()) + + def set(self, values: Union[Sequence[T], "StaticArray[T, N]"]) -> Expr: + """Set the ABI static array with a sequence of ABI type variables, or another ABI static + array. + + This function determines if the argument `values` is an ABI static array: + * if so: + * checks whether `values` is same type as this ABI staic array. + * stores the encoding of `values`. + * if not: + * checks whether static array length matches sequence length. + * calls the inherited `set` function and stores `values`. + + Args: + values: either a sequence of ABI typed values, or an ABI static array. + + Returns: + A PyTeal expression that stores encoded `values` in its internal ScratchVar. + """ + if isinstance(values, BaseType): + if self.type_spec() != values.type_spec(): + raise TealInputError( + "Cannot assign type {} to {}".format( + values.type_spec(), self.type_spec() + ) + ) + return self.stored_value.store(values.encode()) + + if self.type_spec().length_static() != len(values): + raise TealInputError( + "Incorrect length for values. Expected {}, got {}".format( + self.type_spec().length_static(), len(values) + ) + ) + return super().set(values) + + def length(self) -> Expr: + """Get the element number of this ABI static array. + + Returns: + A PyTeal expression that represents the static array length. + """ + return Int(self.type_spec().length_static()) + + def __getitem__(self, index: Union[int, Expr]) -> "ArrayElement[T]": + if type(index) is int and index >= self.type_spec().length_static(): + raise TealInputError("Index out of bounds: {}".format(index)) + return super().__getitem__(index) + + +StaticArray.__module__ = "pyteal" diff --git a/pyteal/ast/abi/array_static_test.py b/pyteal/ast/abi/array_static_test.py new file mode 100644 index 000000000..0e7618c2f --- /dev/null +++ b/pyteal/ast/abi/array_static_test.py @@ -0,0 +1,265 @@ +import pytest + +from ... import * +from .util import substringForDecoding +from .tuple import encodeTuple +from .bool import boolSequenceLength +from .array_base_test import STATIC_TYPES, DYNAMIC_TYPES + +options = CompileOptions(version=5) + + +def test_StaticArrayTypeSpec_init(): + for elementType in STATIC_TYPES: + for length in range(256): + staticArrayType = abi.StaticArrayTypeSpec(elementType, length) + assert staticArrayType.value_type_spec() is elementType + assert not staticArrayType.is_length_dynamic() + assert staticArrayType._stride() == elementType.byte_length_static() + assert staticArrayType.length_static() == length + + with pytest.raises(TypeError): + abi.StaticArrayTypeSpec(elementType, -1) + + for elementType in DYNAMIC_TYPES: + for length in range(256): + staticArrayType = abi.StaticArrayTypeSpec(elementType, length) + assert staticArrayType.value_type_spec() is elementType + assert not staticArrayType.is_length_dynamic() + assert staticArrayType._stride() == 2 + assert staticArrayType.length_static() == length + + with pytest.raises(TypeError): + abi.StaticArrayTypeSpec(elementType, -1) + + +def test_StaticArrayTypeSpec_str(): + for elementType in STATIC_TYPES + DYNAMIC_TYPES: + for length in range(256): + staticArrayType = abi.StaticArrayTypeSpec(elementType, length) + assert str(staticArrayType) == "{}[{}]".format(elementType, length) + + +def test_StaticArrayTypeSpec_new_instance(): + for elementType in STATIC_TYPES + DYNAMIC_TYPES: + for length in range(256): + staticArrayType = abi.StaticArrayTypeSpec(elementType, length) + instance = staticArrayType.new_instance() + assert isinstance( + instance, + abi.StaticArray, + ) + assert instance.type_spec() == staticArrayType + + +def test_StaticArrayTypeSpec_eq(): + for elementType in STATIC_TYPES + DYNAMIC_TYPES: + for length in range(256): + staticArrayType = abi.StaticArrayTypeSpec(elementType, length) + assert staticArrayType == staticArrayType + assert staticArrayType != abi.StaticArrayTypeSpec(elementType, length + 1) + assert staticArrayType != abi.StaticArrayTypeSpec( + abi.TupleTypeSpec(elementType), length + ) + + +def test_StaticArrayTypeSpec_is_dynamic(): + for elementType in STATIC_TYPES: + for length in range(256): + staticArrayType = abi.StaticArrayTypeSpec(elementType, length) + assert not staticArrayType.is_dynamic() + + for elementType in DYNAMIC_TYPES: + for length in range(256): + staticArrayType = abi.StaticArrayTypeSpec(elementType, length) + assert staticArrayType.is_dynamic() + + +def test_StaticArrayTypeSpec_byte_length_static(): + for elementType in STATIC_TYPES: + for length in range(256): + staticArrayType = abi.StaticArrayTypeSpec(elementType, length) + actual = staticArrayType.byte_length_static() + + if elementType == abi.BoolTypeSpec(): + expected = boolSequenceLength(length) + else: + expected = elementType.byte_length_static() * length + + assert ( + actual == expected + ), "failed with element type {} and length {}".format(elementType, length) + + for elementType in DYNAMIC_TYPES: + for length in range(256): + staticArrayType = abi.StaticArrayTypeSpec(elementType, length) + with pytest.raises(ValueError): + staticArrayType.byte_length_static() + + +def test_StaticArray_decode(): + encoded = Bytes("encoded") + for startIndex in (None, Int(1)): + for endIndex in (None, Int(2)): + for length in (None, Int(3)): + value = abi.StaticArray(abi.Uint64TypeSpec(), 10) + + if endIndex is not None and length is not None: + with pytest.raises(TealInputError): + value.decode( + encoded, + startIndex=startIndex, + endIndex=endIndex, + length=length, + ) + continue + + expr = value.decode( + encoded, startIndex=startIndex, endIndex=endIndex, length=length + ) + assert expr.type_of() == TealType.none + assert not expr.has_return() + + expectedExpr = value.stored_value.store( + substringForDecoding( + encoded, startIndex=startIndex, endIndex=endIndex, length=length + ) + ) + expected, _ = expectedExpr.__teal__(options) + expected.addIncoming() + expected = TealBlock.NormalizeBlocks(expected) + + actual, _ = expr.__teal__(options) + actual.addIncoming() + actual = TealBlock.NormalizeBlocks(actual) + + with TealComponent.Context.ignoreExprEquality(): + assert actual == expected + + +def test_StaticArray_set_values(): + value = abi.StaticArray(abi.Uint64TypeSpec(), 10) + + with pytest.raises(TealInputError): + value.set([]) + + with pytest.raises(TealInputError): + value.set([abi.Uint64()] * 9) + + with pytest.raises(TealInputError): + value.set([abi.Uint64()] * 11) + + with pytest.raises(TealInputError): + value.set([abi.Uint16()] * 10) + + with pytest.raises(TealInputError): + value.set([abi.Uint64()] * 9 + [abi.Uint16()]) + + values = [abi.Uint64() for _ in range(10)] + expr = value.set(values) + assert expr.type_of() == TealType.none + assert not expr.has_return() + + expectedExpr = value.stored_value.store(encodeTuple(values)) + expected, _ = expectedExpr.__teal__(options) + expected.addIncoming() + expected = TealBlock.NormalizeBlocks(expected) + + actual, _ = expr.__teal__(options) + actual.addIncoming() + actual = TealBlock.NormalizeBlocks(actual) + + with TealComponent.Context.ignoreExprEquality(): + assert actual == expected + + +def test_StaticArray_set_copy(): + value = abi.StaticArray(abi.Uint64TypeSpec(), 10) + otherArray = abi.StaticArray(abi.Uint64TypeSpec(), 10) + + with pytest.raises(TealInputError): + value.set(abi.StaticArray(abi.Uint64TypeSpec(), 11)) + + with pytest.raises(TealInputError): + value.set(abi.StaticArray(abi.Uint8TypeSpec(), 10)) + + with pytest.raises(TealInputError): + value.set(abi.Uint64()) + + expr = value.set(otherArray) + assert expr.type_of() == TealType.none + assert not expr.has_return() + + expected = TealSimpleBlock( + [ + TealOp(None, Op.load, otherArray.stored_value.slot), + TealOp(None, Op.store, value.stored_value.slot), + ] + ) + + actual, _ = expr.__teal__(options) + actual.addIncoming() + actual = TealBlock.NormalizeBlocks(actual) + + with TealComponent.Context.ignoreExprEquality(): + assert actual == expected + + +def test_StaticArray_encode(): + value = abi.StaticArray(abi.Uint64TypeSpec(), 10) + expr = value.encode() + assert expr.type_of() == TealType.bytes + assert not expr.has_return() + + expected = TealSimpleBlock([TealOp(None, Op.load, value.stored_value.slot)]) + + actual, _ = expr.__teal__(options) + actual.addIncoming() + actual = TealBlock.NormalizeBlocks(actual) + + with TealComponent.Context.ignoreExprEquality(): + assert actual == expected + + +def test_StaticArray_length(): + for length in (0, 1, 2, 3, 1000): + value = abi.StaticArray(abi.Uint64TypeSpec(), length) + expr = value.length() + assert expr.type_of() == TealType.uint64 + assert not expr.has_return() + + expected = TealSimpleBlock([TealOp(None, Op.int, length)]) + + actual, _ = expr.__teal__(options) + actual.addIncoming() + actual = TealBlock.NormalizeBlocks(actual) + + with TealComponent.Context.ignoreExprEquality(): + assert actual == expected + + +def test_StaticArray_getitem(): + for length in (0, 1, 2, 3, 1000): + value = abi.StaticArray(abi.Uint64TypeSpec(), length) + + for index in range(length): + # dynamic indexes + indexExpr = Int(index) + element = value[indexExpr] + assert type(element) is abi.ArrayElement + assert element.array is value + assert element.index is indexExpr + + for index in range(length): + # static indexes + element = value[index] + assert type(element) is abi.ArrayElement + assert element.array is value + assert type(element.index) is Int + assert element.index.value == index + + with pytest.raises(TealInputError): + value[-1] + + with pytest.raises(TealInputError): + value[length] diff --git a/pyteal/ast/abi/array_test.py b/pyteal/ast/abi/array_test.py deleted file mode 100644 index 797103db9..000000000 --- a/pyteal/ast/abi/array_test.py +++ /dev/null @@ -1,646 +0,0 @@ -from typing import List, cast -import pytest - -from ... import * -from .type import substringForDecoding -from .tuple import encodeTuple -from .bool import boolSequenceLength -from .array import ArrayElement - -# this is not necessary but mypy complains if it's not included -from ... import CompileOptions - -options = CompileOptions(version=5) - -STATIC_TYPES: List[abi.Type] = [ - abi.Bool(), - abi.Uint8(), - abi.Uint16(), - abi.Uint32(), - abi.Uint64(), - abi.Tuple(), - abi.Tuple(abi.Bool(), abi.Bool(), abi.Uint64()), - abi.StaticArray(abi.Bool(), 10), - abi.StaticArray(abi.Uint8(), 10), - abi.StaticArray(abi.Uint16(), 10), - abi.StaticArray(abi.Uint32(), 10), - abi.StaticArray(abi.Uint64(), 10), - abi.StaticArray(abi.Tuple(abi.Bool(), abi.Bool(), abi.Uint64()), 10), -] - -DYNAMIC_TYPES: List[abi.Type] = [ - abi.DynamicArray(abi.Bool()), - abi.DynamicArray(abi.Uint8()), - abi.DynamicArray(abi.Uint16()), - abi.DynamicArray(abi.Uint32()), - abi.DynamicArray(abi.Uint64()), - abi.DynamicArray(abi.Tuple()), - abi.DynamicArray(abi.Tuple(abi.Bool(), abi.Bool(), abi.Uint64())), - abi.DynamicArray(abi.StaticArray(abi.Bool(), 10)), - abi.DynamicArray(abi.StaticArray(abi.Uint8(), 10)), - abi.DynamicArray(abi.StaticArray(abi.Uint16(), 10)), - abi.DynamicArray(abi.StaticArray(abi.Uint32(), 10)), - abi.DynamicArray(abi.StaticArray(abi.Uint64(), 10)), - abi.DynamicArray( - abi.StaticArray(abi.Tuple(abi.Bool(), abi.Bool(), abi.Uint64()), 10) - ), -] - - -def test_StaticArray_init(): - for elementType in STATIC_TYPES: - for length in range(256): - staticArrayType = abi.StaticArray(elementType, length) - assert staticArrayType._valueType is elementType - assert not staticArrayType._has_offsets - assert staticArrayType._stride == elementType.byte_length_static() - assert staticArrayType._static_length == length - - with pytest.raises(TealInputError): - abi.StaticArray(elementType, -1) - - for elementType in DYNAMIC_TYPES: - for length in range(256): - staticArrayType = abi.StaticArray(elementType, length) - assert staticArrayType._valueType is elementType - assert staticArrayType._has_offsets - assert staticArrayType._stride == 2 - assert staticArrayType._static_length == length - - with pytest.raises(TealInputError): - abi.StaticArray(elementType, -1) - - -def test_StaticArray_str(): - for elementType in STATIC_TYPES + DYNAMIC_TYPES: - for length in range(256): - staticArrayType = abi.StaticArray(elementType, length) - assert str(staticArrayType) == "{}[{}]".format(elementType, length) - - -def test_StaticArray_new_instance(): - for elementType in STATIC_TYPES + DYNAMIC_TYPES: - for length in range(256): - staticArrayType = abi.StaticArray(elementType, length) - newInstance = staticArrayType.new_instance() - assert type(newInstance) is abi.StaticArray - assert newInstance._valueType is elementType - assert newInstance._has_offsets == staticArrayType._has_offsets - assert newInstance._stride == staticArrayType._stride - assert newInstance._static_length == length - - -def test_StaticArray_has_same_type_as(): - for elementType in STATIC_TYPES + DYNAMIC_TYPES: - for length in range(256): - staticArrayType = abi.StaticArray(elementType, length) - assert staticArrayType.has_same_type_as(staticArrayType) - assert not staticArrayType.has_same_type_as( - abi.StaticArray(elementType, length + 1) - ) - assert not staticArrayType.has_same_type_as( - abi.StaticArray(abi.Tuple(elementType), length) - ) - - -def test_StaticArray_is_dynamic(): - for elementType in STATIC_TYPES: - for length in range(256): - staticArrayType = abi.StaticArray(elementType, length) - assert not staticArrayType.is_dynamic() - - for elementType in DYNAMIC_TYPES: - for length in range(256): - staticArrayType = abi.StaticArray(elementType, length) - assert staticArrayType.is_dynamic() - - -def test_StaticArray_byte_length_static(): - for elementType in STATIC_TYPES: - for length in range(256): - staticArrayType = abi.StaticArray(elementType, length) - actual = staticArrayType.byte_length_static() - - if type(elementType) is abi.Bool: - expected = boolSequenceLength(length) - else: - expected = elementType.byte_length_static() * length - - assert actual == expected - - for elementType in DYNAMIC_TYPES: - for length in range(256): - staticArrayType = abi.StaticArray(elementType, length) - with pytest.raises(ValueError): - staticArrayType.byte_length_static() - - -def test_StaticArray_decode(): - staticArrayType = abi.StaticArray(abi.Uint64(), 10) - for startIndex in (None, Int(1)): - for endIndex in (None, Int(2)): - for length in (None, Int(3)): - encoded = Bytes("encoded") - - if endIndex is not None and length is not None: - with pytest.raises(TealInputError): - staticArrayType.decode( - encoded, - startIndex=startIndex, - endIndex=endIndex, - length=length, - ) - continue - - expr = staticArrayType.decode( - encoded, startIndex=startIndex, endIndex=endIndex, length=length - ) - assert expr.type_of() == TealType.none - assert not expr.has_return() - - expectedExpr = staticArrayType.stored_value.store( - substringForDecoding( - encoded, startIndex=startIndex, endIndex=endIndex, length=length - ) - ) - expected, _ = expectedExpr.__teal__(options) - expected.addIncoming() - expected = TealBlock.NormalizeBlocks(expected) - - actual, _ = expr.__teal__(options) - actual.addIncoming() - actual = TealBlock.NormalizeBlocks(actual) - - with TealComponent.Context.ignoreExprEquality(): - assert actual == expected - - -def test_StaticArray_set_values(): - staticArrayType = abi.StaticArray(abi.Uint64(), 10) - - with pytest.raises(TealInputError): - staticArrayType.set([]) - - with pytest.raises(TealInputError): - staticArrayType.set([abi.Uint64()] * 9) - - with pytest.raises(TealInputError): - staticArrayType.set([abi.Uint64()] * 11) - - with pytest.raises(TealInputError): - staticArrayType.set([abi.Uint16()] * 10) - - with pytest.raises(TealInputError): - staticArrayType.set([abi.Uint64()] * 9 + [abi.Uint16()]) - - values = [abi.Uint64() for _ in range(10)] - expr = staticArrayType.set(values) - assert expr.type_of() == TealType.none - assert not expr.has_return() - - expectedExpr = staticArrayType.stored_value.store(encodeTuple(values)) - expected, _ = expectedExpr.__teal__(options) - expected.addIncoming() - expected = TealBlock.NormalizeBlocks(expected) - - actual, _ = expr.__teal__(options) - actual.addIncoming() - actual = TealBlock.NormalizeBlocks(actual) - - with TealComponent.Context.ignoreExprEquality(): - assert actual == expected - - -def test_StaticArray_set_copy(): - staticArrayType = abi.StaticArray(abi.Uint64(), 10) - otherArray = abi.StaticArray(abi.Uint64(), 10) - - with pytest.raises(TealInputError): - staticArrayType.set(abi.StaticArray(abi.Uint64(), 11)) - - with pytest.raises(TealInputError): - staticArrayType.set(abi.StaticArray(abi.Uint8(), 10)) - - with pytest.raises(TealInputError): - staticArrayType.set(abi.Uint64()) - - expr = staticArrayType.set(otherArray) - assert expr.type_of() == TealType.none - assert not expr.has_return() - - expected = TealSimpleBlock( - [ - TealOp(None, Op.load, otherArray.stored_value.slot), - TealOp(None, Op.store, staticArrayType.stored_value.slot), - ] - ) - - actual, _ = expr.__teal__(options) - actual.addIncoming() - actual = TealBlock.NormalizeBlocks(actual) - - with TealComponent.Context.ignoreExprEquality(): - assert actual == expected - - -def test_StaticArray_encode(): - staticArrayType = abi.StaticArray(abi.Uint64(), 10) - expr = staticArrayType.encode() - assert expr.type_of() == TealType.bytes - assert not expr.has_return() - - expected = TealSimpleBlock( - [TealOp(None, Op.load, staticArrayType.stored_value.slot)] - ) - - actual, _ = expr.__teal__(options) - actual.addIncoming() - actual = TealBlock.NormalizeBlocks(actual) - - with TealComponent.Context.ignoreExprEquality(): - assert actual == expected - - -def test_StaticArray_length_static(): - for length in (0, 1, 2, 3, 1000): - staticArrayType = abi.StaticArray(abi.Uint64(), length) - assert staticArrayType.length_static() == length - - -def test_StaticArray_length(): - for length in (0, 1, 2, 3, 1000): - staticArrayType = abi.StaticArray(abi.Uint64(), length) - expr = staticArrayType.length() - assert expr.type_of() == TealType.uint64 - assert not expr.has_return() - - expected = TealSimpleBlock([TealOp(None, Op.int, length)]) - - actual, _ = expr.__teal__(options) - actual.addIncoming() - actual = TealBlock.NormalizeBlocks(actual) - - with TealComponent.Context.ignoreExprEquality(): - assert actual == expected - - -def test_StaticArray_getitem(): - for length in (0, 1, 2, 3, 1000): - staticArrayType = abi.StaticArray(abi.Uint64(), length) - - for index in range(length): - # dynamic indexes - indexExpr = Int(index) - element = staticArrayType[indexExpr] - assert type(element) is ArrayElement - assert element.array is staticArrayType - assert element.index is indexExpr - - for index in range(length): - # static indexes - element = staticArrayType[index] - assert type(element) is ArrayElement - assert element.array is staticArrayType - assert type(element.index) is Int - assert element.index.value == index - - with pytest.raises(TealInputError): - staticArrayType[-1] - - with pytest.raises(TealInputError): - staticArrayType[length] - - -def test_DynamicArray_init(): - for elementType in STATIC_TYPES: - dynamicArrayType = abi.DynamicArray(elementType) - assert dynamicArrayType._valueType is elementType - assert not dynamicArrayType._has_offsets - assert dynamicArrayType._stride == elementType.byte_length_static() - assert dynamicArrayType._static_length is None - - for elementType in DYNAMIC_TYPES: - dynamicArrayType = abi.DynamicArray(elementType) - assert dynamicArrayType._valueType is elementType - assert dynamicArrayType._has_offsets - assert dynamicArrayType._stride == 2 - assert dynamicArrayType._static_length is None - - -def test_DynamicArray_str(): - for elementType in STATIC_TYPES + DYNAMIC_TYPES: - dynamicArrayType = abi.DynamicArray(elementType) - assert str(dynamicArrayType) == "{}[]".format(elementType) - - -def test_DynamicArray_new_instance(): - for elementType in STATIC_TYPES + DYNAMIC_TYPES: - dynamicArrayType = abi.DynamicArray(elementType) - newInstance = dynamicArrayType.new_instance() - assert type(newInstance) is abi.DynamicArray - assert newInstance._valueType is elementType - assert newInstance._has_offsets == dynamicArrayType._has_offsets - assert newInstance._stride == dynamicArrayType._stride - assert newInstance._static_length is None - - -def test_DynamicArray_has_same_type_as(): - for elementType in STATIC_TYPES + DYNAMIC_TYPES: - dynamicArrayType = abi.DynamicArray(elementType) - assert dynamicArrayType.has_same_type_as(dynamicArrayType) - assert not dynamicArrayType.has_same_type_as( - abi.DynamicArray(abi.Tuple(elementType)) - ) - - -def test_DynamicArray_is_dynamic(): - for elementType in STATIC_TYPES + DYNAMIC_TYPES: - dynamicArrayType = abi.DynamicArray(elementType) - assert dynamicArrayType.is_dynamic() - - -def test_DynamicArray_byte_length_static(): - for elementType in STATIC_TYPES + DYNAMIC_TYPES: - dynamicArrayType = abi.DynamicArray(elementType) - with pytest.raises(ValueError): - dynamicArrayType.byte_length_static() - - -def test_DynamicArray_decode(): - dynamicArrayType = abi.DynamicArray(abi.Uint64()) - for startIndex in (None, Int(1)): - for endIndex in (None, Int(2)): - for length in (None, Int(3)): - encoded = Bytes("encoded") - - if endIndex is not None and length is not None: - with pytest.raises(TealInputError): - dynamicArrayType.decode( - encoded, - startIndex=startIndex, - endIndex=endIndex, - length=length, - ) - continue - - expr = dynamicArrayType.decode( - encoded, startIndex=startIndex, endIndex=endIndex, length=length - ) - assert expr.type_of() == TealType.none - assert not expr.has_return() - - expectedExpr = dynamicArrayType.stored_value.store( - substringForDecoding( - encoded, startIndex=startIndex, endIndex=endIndex, length=length - ) - ) - expected, _ = expectedExpr.__teal__(options) - expected.addIncoming() - expected = TealBlock.NormalizeBlocks(expected) - - actual, _ = expr.__teal__(options) - actual.addIncoming() - actual = TealBlock.NormalizeBlocks(actual) - - with TealComponent.Context.ignoreExprEquality(): - assert actual == expected - - -def test_DynamicArray_set_values(): - dynamicArrayType = abi.DynamicArray(abi.Uint64()) - - valuesToSet: List[abi.Uint64] = [ - [], - [abi.Uint64()], - [abi.Uint64() for _ in range(10)], - ] - - for values in valuesToSet: - expr = dynamicArrayType.set(values) - assert expr.type_of() == TealType.none - assert not expr.has_return() - - length_tmp = abi.Uint16() - expectedExpr = dynamicArrayType.stored_value.store( - Concat( - Seq(length_tmp.set(len(values)), length_tmp.encode()), - encodeTuple(values), - ) - ) - expected, _ = expectedExpr.__teal__(options) - expected.addIncoming() - expected = TealBlock.NormalizeBlocks(expected) - - actual, _ = expr.__teal__(options) - actual.addIncoming() - actual = TealBlock.NormalizeBlocks(actual) - - with TealComponent.Context.ignoreExprEquality(): - with TealComponent.Context.ignoreScratchSlotEquality(): - assert actual == expected - - assert TealBlock.MatchScratchSlotReferences( - TealBlock.GetReferencedScratchSlots(actual), - TealBlock.GetReferencedScratchSlots(expected), - ) - - -def test_DynamicArray_set_copy(): - dynamicArrayType = abi.DynamicArray(abi.Uint64()) - otherArray = abi.DynamicArray(abi.Uint64()) - - with pytest.raises(TealInputError): - dynamicArrayType.set(abi.DynamicArray(abi.Uint8())) - - with pytest.raises(TealInputError): - dynamicArrayType.set(abi.Uint64()) - - expr = dynamicArrayType.set(otherArray) - assert expr.type_of() == TealType.none - assert not expr.has_return() - - expected = TealSimpleBlock( - [ - TealOp(None, Op.load, otherArray.stored_value.slot), - TealOp(None, Op.store, dynamicArrayType.stored_value.slot), - ] - ) - - actual, _ = expr.__teal__(options) - actual.addIncoming() - actual = TealBlock.NormalizeBlocks(actual) - - with TealComponent.Context.ignoreExprEquality(): - assert actual == expected - - -def test_DynamicArray_encode(): - dynamicArrayType = abi.DynamicArray(abi.Uint64()) - expr = dynamicArrayType.encode() - assert expr.type_of() == TealType.bytes - assert not expr.has_return() - - expected = TealSimpleBlock( - [TealOp(None, Op.load, dynamicArrayType.stored_value.slot)] - ) - - actual, _ = expr.__teal__(options) - actual.addIncoming() - actual = TealBlock.NormalizeBlocks(actual) - - with TealComponent.Context.ignoreExprEquality(): - assert actual == expected - - -def test_DynamicArray_length(): - dynamicArrayType = abi.DynamicArray(abi.Uint64()) - expr = dynamicArrayType.length() - assert expr.type_of() == TealType.uint64 - assert not expr.has_return() - - length_tmp = abi.Uint16() - expectedExpr = Seq(length_tmp.decode(dynamicArrayType.encode()), length_tmp.get()) - expected, _ = expectedExpr.__teal__(options) - expected.addIncoming() - expected = TealBlock.NormalizeBlocks(expected) - - actual, _ = expr.__teal__(options) - actual.addIncoming() - actual = TealBlock.NormalizeBlocks(actual) - - with TealComponent.Context.ignoreExprEquality(): - with TealComponent.Context.ignoreScratchSlotEquality(): - assert actual == expected - - assert TealBlock.MatchScratchSlotReferences( - TealBlock.GetReferencedScratchSlots(actual), - TealBlock.GetReferencedScratchSlots(expected), - ) - - -def test_DynamicArray_getitem(): - dynamicArrayType = abi.DynamicArray(abi.Uint64()) - - for index in (0, 1, 2, 3, 1000): - # dynamic indexes - indexExpr = Int(index) - element = dynamicArrayType[indexExpr] - assert type(element) is ArrayElement - assert element.array is dynamicArrayType - assert element.index is indexExpr - - for index in (0, 1, 2, 3, 1000): - # static indexes - element = dynamicArrayType[index] - assert type(element) is ArrayElement - assert element.array is dynamicArrayType - assert type(element.index) is Int - assert element.index.value == index - - with pytest.raises(TealInputError): - dynamicArrayType[-1] - - -def test_ArrayElement_init(): - array = abi.DynamicArray(abi.Uint64()) - index = Int(6) - - element = ArrayElement(array, index) - assert element.array is array - assert element.index is index - - with pytest.raises(TealTypeError): - ArrayElement(array, Bytes("abc")) - - with pytest.raises(TealTypeError): - ArrayElement(array, Assert(index)) - - -def test_ArrayElement_store_into(): - for elementType in STATIC_TYPES + DYNAMIC_TYPES: - staticArrayType = abi.StaticArray(elementType, 100) - index = Int(9) - - element = ArrayElement(staticArrayType, index) - output = elementType.new_instance() - expr = element.store_into(output) - - encoded = staticArrayType.encode() - stride = Int(staticArrayType._stride) - expectedLength = staticArrayType.length() - if type(elementType) is abi.Bool: - expectedExpr = cast(abi.Bool, output).decodeBit(encoded, index) - elif not elementType.is_dynamic(): - expectedExpr = output.decode( - encoded, startIndex=stride * index, length=stride - ) - else: - expectedExpr = output.decode( - encoded, - startIndex=ExtractUint16(encoded, stride * index), - endIndex=If(index + Int(1) == expectedLength) - .Then(Len(encoded)) - .Else(ExtractUint16(encoded, stride * index + Int(2))), - ) - - expected, _ = expectedExpr.__teal__(options) - expected.addIncoming() - expected = TealBlock.NormalizeBlocks(expected) - - actual, _ = expr.__teal__(options) - actual.addIncoming() - actual = TealBlock.NormalizeBlocks(actual) - - with TealComponent.Context.ignoreExprEquality(): - assert actual == expected - - with pytest.raises(TealInputError): - element.store_into(abi.Tuple(output)) - - for elementType in STATIC_TYPES + DYNAMIC_TYPES: - dynamicArrayType = abi.DynamicArray(elementType) - index = Int(9) - - element = ArrayElement(dynamicArrayType, index) - output = elementType.new_instance() - expr = element.store_into(output) - - encoded = dynamicArrayType.encode() - stride = Int(dynamicArrayType._stride) - expectedLength = dynamicArrayType.length() - if type(elementType) is abi.Bool: - expectedExpr = cast(abi.Bool, output).decodeBit(encoded, index + Int(16)) - elif not elementType.is_dynamic(): - expectedExpr = output.decode( - encoded, startIndex=stride * index + Int(2), length=stride - ) - else: - expectedExpr = output.decode( - encoded, - startIndex=ExtractUint16(encoded, stride * index + Int(2)) + Int(2), - endIndex=If(index + Int(1) == expectedLength) - .Then(Len(encoded)) - .Else( - ExtractUint16(encoded, stride * index + Int(2) + Int(2)) + Int(2) - ), - ) - - expected, _ = expectedExpr.__teal__(options) - expected.addIncoming() - expected = TealBlock.NormalizeBlocks(expected) - - actual, _ = expr.__teal__(options) - actual.addIncoming() - actual = TealBlock.NormalizeBlocks(actual) - - with TealComponent.Context.ignoreExprEquality(): - with TealComponent.Context.ignoreScratchSlotEquality(): - assert actual == expected - - assert TealBlock.MatchScratchSlotReferences( - TealBlock.GetReferencedScratchSlots(actual), - TealBlock.GetReferencedScratchSlots(expected), - ) - - with pytest.raises(TealInputError): - element.store_into(abi.Tuple(output)) diff --git a/pyteal/ast/abi/bool.py b/pyteal/ast/abi/bool.py index 8794e4370..3080662b2 100644 --- a/pyteal/ast/abi/bool.py +++ b/pyteal/ast/abi/bool.py @@ -1,6 +1,7 @@ -from typing import Union, cast, Sequence +from typing import TypeVar, Union, cast, Sequence, Callable from ...types import TealType +from ...errors import TealInputError from ..expr import Expr from ..seq import Seq from ..assert_ import Assert @@ -8,27 +9,38 @@ from ..bytes import Bytes from ..binaryexpr import GetBit from ..ternaryexpr import SetBit -from .type import Type +from .type import TypeSpec, BaseType from .uint import NUM_BITS_IN_BYTE -class Bool(Type): - def __init__(self) -> None: - super().__init__(TealType.uint64) - +class BoolTypeSpec(TypeSpec): def new_instance(self) -> "Bool": return Bool() - def has_same_type_as(self, other: Type) -> bool: - return type(other) is Bool - def is_dynamic(self) -> bool: + # Only accurate if this value is alone, since up to 8 consecutive bools will fit into a single byte return False def byte_length_static(self) -> int: - # Not completely accurate since up to 8 consecutive bools will fit into a single byte return 1 + def storage_type(self) -> TealType: + return TealType.uint64 + + def __eq__(self, other: object) -> bool: + return isinstance(other, BoolTypeSpec) + + def __str__(self) -> str: + return "bool" + + +BoolTypeSpec.__module__ = "pyteal" + + +class Bool(BaseType): + def __init__(self) -> None: + super().__init__(BoolTypeSpec()) + def get(self) -> Expr: return self.stored_value.load() @@ -38,15 +50,19 @@ def set(self, value: Union[bool, Expr, "Bool"]) -> Expr: value = Int(1 if value else 0) checked = True - if type(value) is Bool: + if isinstance(value, BaseType): + if value.type_spec() != self.type_spec(): + raise TealInputError( + "Cannot set type bool to {}".format(value.type_spec()) + ) value = value.get() checked = True if checked: - return self.stored_value.store(cast(Expr, value)) + return self.stored_value.store(value) return Seq( - self.stored_value.store(cast(Expr, value)), + self.stored_value.store(value), Assert(self.stored_value.load() < Int(2)), ) @@ -68,19 +84,16 @@ def decodeBit(self, encoded, bitIndex: Expr) -> Expr: def encode(self) -> Expr: return SetBit(Bytes(b"\x00"), Int(0), self.get()) - def __str__(self) -> str: - return "bool" - -def boolAwareStaticByteLength(types: Sequence[Type]) -> int: +def boolAwareStaticByteLength(types: Sequence[TypeSpec]) -> int: length = 0 ignoreNext = 0 for i, t in enumerate(types): if ignoreNext > 0: ignoreNext -= 1 continue - if type(t) is Bool: - numBools = consecutiveBoolNum(types, i) + if t == BoolTypeSpec(): + numBools = consecutiveBoolTypeSpecNum(types, i) ignoreNext = numBools - 1 length += boolSequenceLength(numBools) continue @@ -88,20 +101,50 @@ def boolAwareStaticByteLength(types: Sequence[Type]) -> int: return length -def consecutiveBoolNum(types: Sequence[Type], startIndex: int) -> int: - numConsecutiveBools = 0 - for t in types[startIndex:]: - if type(t) is not Bool: +T = TypeVar("T") + + +def consecutiveThingNum( + things: Sequence[T], startIndex: int, condition: Callable[[T], bool] +) -> int: + numConsecutiveThings = 0 + for t in things[startIndex:]: + if not condition(t): break - numConsecutiveBools += 1 - return numConsecutiveBools + numConsecutiveThings += 1 + return numConsecutiveThings + + +def consecutiveBoolTypeSpecNum(types: Sequence[TypeSpec], startIndex: int) -> int: + if len(types) != 0 and not isinstance(types[0], TypeSpec): + raise TypeError("Sequence of types expected") + return consecutiveThingNum(types, startIndex, lambda t: t == BoolTypeSpec()) + + +def consecutiveBoolInstanceNum(values: Sequence[BaseType], startIndex: int) -> int: + if len(values) != 0 and not isinstance(values[0], BaseType): + raise TypeError( + "Sequence of types expected, but got {}".format(type(values[0])) + ) + return consecutiveThingNum( + values, startIndex, lambda t: t.type_spec() == BoolTypeSpec() + ) def boolSequenceLength(num_bools: int) -> int: + """Get the length in bytes of an encoding of `num_bools` consecutive booleans values.""" return (num_bools + NUM_BITS_IN_BYTE - 1) // NUM_BITS_IN_BYTE def encodeBoolSequence(values: Sequence[Bool]) -> Expr: + """Encoding a sequences of boolean values into a byte string. + + Args: + values: The values to encode. Each must be an instance of Bool. + + Returns: + An expression which creates an encoded byte string with the input boolean values. + """ length = boolSequenceLength(len(values)) expr: Expr = Bytes(b"\x00" * length) diff --git a/pyteal/ast/abi/bool_test.py b/pyteal/ast/abi/bool_test.py index 9883b2dad..222d53c59 100644 --- a/pyteal/ast/abi/bool_test.py +++ b/pyteal/ast/abi/bool_test.py @@ -1,10 +1,11 @@ -from typing import NamedTuple, List +from typing import NamedTuple, List, Type import pytest from ... import * from .bool import ( boolAwareStaticByteLength, - consecutiveBoolNum, + consecutiveBoolInstanceNum, + consecutiveBoolTypeSpecNum, boolSequenceLength, encodeBoolSequence, ) @@ -15,45 +16,45 @@ options = CompileOptions(version=5) -def test_Bool_str(): - boolType = abi.Bool() - assert str(boolType) == "bool" +def test_BoolTypeSpec_str(): + assert str(abi.BoolTypeSpec()) == "bool" -def test_Bool_is_dynamic(): - boolType = abi.Bool() - assert not boolType.is_dynamic() +def test_BoolTypeSpec_is_dynamic(): + assert not abi.BoolTypeSpec().is_dynamic() -def test_Bool_has_same_type_as(): - boolType = abi.Bool() - assert boolType.has_same_type_as(abi.Bool()) +def test_BoolTypeSpec_byte_length_static(): + assert abi.BoolTypeSpec().byte_length_static() == 1 - for otherType in ( - abi.Byte(), - abi.Uint64(), - abi.StaticArray(boolType, 1), - abi.DynamicArray(boolType), - ): - assert not boolType.has_same_type_as(otherType) + +def test_BoolTypeSpec_new_instance(): + assert isinstance(abi.BoolTypeSpec().new_instance(), abi.Bool) -def test_Bool_new_instance(): - boolType = abi.Bool() - assert type(boolType.new_instance()) is abi.Bool +def test_BoolTypeSpec_eq(): + assert abi.BoolTypeSpec() == abi.BoolTypeSpec() + + for otherType in ( + abi.ByteTypeSpec, + abi.Uint64TypeSpec, + abi.StaticArrayTypeSpec(abi.BoolTypeSpec(), 1), + abi.DynamicArrayTypeSpec(abi.BoolTypeSpec()), + ): + assert abi.BoolTypeSpec() != otherType def test_Bool_set_static(): - boolType = abi.Bool() - for value in (True, False): - expr = boolType.set(value) + value = abi.Bool() + for value_to_set in (True, False): + expr = value.set(value_to_set) assert expr.type_of() == TealType.none assert not expr.has_return() expected = TealSimpleBlock( [ - TealOp(None, Op.int, 1 if value else 0), - TealOp(None, Op.store, boolType.stored_value.slot), + TealOp(None, Op.int, 1 if value_to_set else 0), + TealOp(None, Op.store, value.stored_value.slot), ] ) @@ -66,8 +67,8 @@ def test_Bool_set_static(): def test_Bool_set_expr(): - boolType = abi.Bool() - expr = boolType.set(Int(0).Or(Int(1))) + value = abi.Bool() + expr = value.set(Int(0).Or(Int(1))) assert expr.type_of() == TealType.none assert not expr.has_return() @@ -76,8 +77,8 @@ def test_Bool_set_expr(): TealOp(None, Op.int, 0), TealOp(None, Op.int, 1), TealOp(None, Op.logic_or), - TealOp(None, Op.store, boolType.stored_value.slot), - TealOp(None, Op.load, boolType.stored_value.slot), + TealOp(None, Op.store, value.stored_value.slot), + TealOp(None, Op.load, value.stored_value.slot), TealOp(None, Op.int, 2), TealOp(None, Op.lt), TealOp(None, Op.assert_), @@ -94,15 +95,15 @@ def test_Bool_set_expr(): def test_Bool_set_copy(): other = abi.Bool() - boolType = abi.Bool() - expr = boolType.set(other) + value = abi.Bool() + expr = value.set(other) assert expr.type_of() == TealType.none assert not expr.has_return() expected = TealSimpleBlock( [ TealOp(None, Op.load, other.stored_value.slot), - TealOp(None, Op.store, boolType.stored_value.slot), + TealOp(None, Op.store, value.stored_value.slot), ] ) @@ -113,14 +114,17 @@ def test_Bool_set_copy(): with TealComponent.Context.ignoreExprEquality(): assert actual == expected + with pytest.raises(TealInputError): + value.set(abi.Uint16()) + def test_Bool_get(): - boolType = abi.Bool() - expr = boolType.get() + value = abi.Bool() + expr = value.get() assert expr.type_of() == TealType.uint64 assert not expr.has_return() - expected = TealSimpleBlock([TealOp(expr, Op.load, boolType.stored_value.slot)]) + expected = TealSimpleBlock([TealOp(expr, Op.load, value.stored_value.slot)]) actual, _ = expr.__teal__(options) @@ -128,12 +132,12 @@ def test_Bool_get(): def test_Bool_decode(): - boolType = abi.Bool() + value = abi.Bool() encoded = Bytes("encoded") for startIndex in (None, Int(1)): for endIndex in (None, Int(2)): for length in (None, Int(3)): - expr = boolType.decode( + expr = value.decode( encoded, startIndex=startIndex, endIndex=endIndex, length=length ) assert expr.type_of() == TealType.none @@ -146,7 +150,7 @@ def test_Bool_decode(): TealOp(None, Op.int, 8), TealOp(None, Op.mul), TealOp(None, Op.getbit), - TealOp(None, Op.store, boolType.stored_value.slot), + TealOp(None, Op.store, value.stored_value.slot), ] ) @@ -159,10 +163,10 @@ def test_Bool_decode(): def test_Bool_decodeBit(): - boolType = abi.Bool() + value = abi.Bool() bitIndex = Int(17) encoded = Bytes("encoded") - expr = boolType.decodeBit(encoded, bitIndex) + expr = value.decodeBit(encoded, bitIndex) assert expr.type_of() == TealType.none assert not expr.has_return() @@ -171,7 +175,7 @@ def test_Bool_decodeBit(): TealOp(None, Op.byte, '"encoded"'), TealOp(None, Op.int, 17), TealOp(None, Op.getbit), - TealOp(None, Op.store, boolType.stored_value.slot), + TealOp(None, Op.store, value.stored_value.slot), ] ) @@ -184,8 +188,8 @@ def test_Bool_decodeBit(): def test_Bool_encode(): - boolType = abi.Bool() - expr = boolType.encode() + value = abi.Bool() + expr = value.encode() assert expr.type_of() == TealType.bytes assert not expr.has_return() @@ -193,7 +197,7 @@ def test_Bool_encode(): [ TealOp(None, Op.byte, "0x00"), TealOp(None, Op.int, 0), - TealOp(None, Op.load, boolType.stored_value.slot), + TealOp(None, Op.load, value.stored_value.slot), TealOp(None, Op.setbit), ] ) @@ -208,25 +212,35 @@ def test_Bool_encode(): def test_boolAwareStaticByteLength(): class ByteLengthTest(NamedTuple): - types: List[abi.Type] + types: List[abi.TypeSpec] expectedLength: int tests: List[ByteLengthTest] = [ ByteLengthTest(types=[], expectedLength=0), - ByteLengthTest(types=[abi.Uint64()], expectedLength=8), - ByteLengthTest(types=[abi.Bool()], expectedLength=1), - ByteLengthTest(types=[abi.Bool()] * 8, expectedLength=1), - ByteLengthTest(types=[abi.Bool()] * 9, expectedLength=2), - ByteLengthTest(types=[abi.Bool()] * 16, expectedLength=2), - ByteLengthTest(types=[abi.Bool()] * 17, expectedLength=3), - ByteLengthTest(types=[abi.Bool()] * 100, expectedLength=13), - ByteLengthTest(types=[abi.Bool(), abi.Byte(), abi.Bool()], expectedLength=3), + ByteLengthTest(types=[abi.Uint64TypeSpec()], expectedLength=8), + ByteLengthTest(types=[abi.BoolTypeSpec()], expectedLength=1), + ByteLengthTest(types=[abi.BoolTypeSpec()] * 8, expectedLength=1), + ByteLengthTest(types=[abi.BoolTypeSpec()] * 9, expectedLength=2), + ByteLengthTest(types=[abi.BoolTypeSpec()] * 16, expectedLength=2), + ByteLengthTest(types=[abi.BoolTypeSpec()] * 17, expectedLength=3), + ByteLengthTest(types=[abi.BoolTypeSpec()] * 100, expectedLength=13), ByteLengthTest( - types=[abi.Bool(), abi.Bool(), abi.Byte(), abi.Bool(), abi.Bool()], + types=[abi.BoolTypeSpec(), abi.ByteTypeSpec(), abi.BoolTypeSpec()], expectedLength=3, ), ByteLengthTest( - types=[abi.Bool()] * 16 + [abi.Byte(), abi.Bool(), abi.Bool()], + types=[ + abi.BoolTypeSpec(), + abi.BoolTypeSpec(), + abi.ByteTypeSpec(), + abi.BoolTypeSpec(), + abi.BoolTypeSpec(), + ], + expectedLength=3, + ), + ByteLengthTest( + types=[abi.BoolTypeSpec()] * 16 + + [abi.ByteTypeSpec(), abi.BoolTypeSpec(), abi.BoolTypeSpec()], expectedLength=4, ), ] @@ -236,40 +250,88 @@ class ByteLengthTest(NamedTuple): assert actual == test.expectedLength, "Test at index {} failed".format(i) -def test_consecutiveBoolNum(): +def test_consecutiveBool(): class ConsecutiveTest(NamedTuple): - types: List[abi.Type] + types: List[abi.TypeSpec] start: int expected: int tests: List[ConsecutiveTest] = [ ConsecutiveTest(types=[], start=0, expected=0), - ConsecutiveTest(types=[abi.Uint16()], start=0, expected=0), - ConsecutiveTest(types=[abi.Bool()], start=0, expected=1), - ConsecutiveTest(types=[abi.Bool()], start=1, expected=0), - ConsecutiveTest(types=[abi.Bool(), abi.Bool()], start=0, expected=2), - ConsecutiveTest(types=[abi.Bool(), abi.Bool()], start=1, expected=1), - ConsecutiveTest(types=[abi.Bool(), abi.Bool()], start=2, expected=0), - ConsecutiveTest(types=[abi.Bool() for _ in range(10)], start=0, expected=10), + ConsecutiveTest(types=[abi.Uint16TypeSpec()], start=0, expected=0), + ConsecutiveTest(types=[abi.BoolTypeSpec()], start=0, expected=1), + ConsecutiveTest(types=[abi.BoolTypeSpec()], start=1, expected=0), ConsecutiveTest( - types=[abi.Bool(), abi.Bool(), abi.Byte(), abi.Bool()], start=0, expected=2 + types=[abi.BoolTypeSpec(), abi.BoolTypeSpec()], start=0, expected=2 ), ConsecutiveTest( - types=[abi.Bool(), abi.Bool(), abi.Byte(), abi.Bool()], start=2, expected=0 + types=[abi.BoolTypeSpec(), abi.BoolTypeSpec()], start=1, expected=1 ), ConsecutiveTest( - types=[abi.Bool(), abi.Bool(), abi.Byte(), abi.Bool()], start=3, expected=1 + types=[abi.BoolTypeSpec(), abi.BoolTypeSpec()], start=2, expected=0 ), ConsecutiveTest( - types=[abi.Byte(), abi.Bool(), abi.Bool(), abi.Byte()], start=0, expected=0 + types=[abi.BoolTypeSpec() for _ in range(10)], start=0, expected=10 ), ConsecutiveTest( - types=[abi.Byte(), abi.Bool(), abi.Bool(), abi.Byte()], start=1, expected=2 + types=[ + abi.BoolTypeSpec(), + abi.BoolTypeSpec(), + abi.ByteTypeSpec(), + abi.BoolTypeSpec(), + ], + start=0, + expected=2, + ), + ConsecutiveTest( + types=[ + abi.BoolTypeSpec(), + abi.BoolTypeSpec(), + abi.ByteTypeSpec(), + abi.BoolTypeSpec(), + ], + start=2, + expected=0, + ), + ConsecutiveTest( + types=[ + abi.BoolTypeSpec(), + abi.BoolTypeSpec(), + abi.ByteTypeSpec(), + abi.BoolTypeSpec(), + ], + start=3, + expected=1, + ), + ConsecutiveTest( + types=[ + abi.ByteTypeSpec(), + abi.BoolTypeSpec(), + abi.BoolTypeSpec(), + abi.ByteTypeSpec(), + ], + start=0, + expected=0, + ), + ConsecutiveTest( + types=[ + abi.ByteTypeSpec(), + abi.BoolTypeSpec(), + abi.BoolTypeSpec(), + abi.ByteTypeSpec(), + ], + start=1, + expected=2, ), ] for i, test in enumerate(tests): - actual = consecutiveBoolNum(test.types, test.start) + actual = consecutiveBoolTypeSpecNum(test.types, test.start) + assert actual == test.expected, "Test at index {} failed".format(i) + + actual = consecutiveBoolInstanceNum( + [t.new_instance() for t in test.types], test.start + ) assert actual == test.expected, "Test at index {} failed".format(i) diff --git a/pyteal/ast/abi/tuple.py b/pyteal/ast/abi/tuple.py index 404b58cbb..c10ecfea5 100644 --- a/pyteal/ast/abi/tuple.py +++ b/pyteal/ast/abi/tuple.py @@ -1,9 +1,4 @@ -from typing import ( - List, - Sequence, - Dict, - cast, -) +from typing import List, Sequence, Dict, Generic, TypeVar, cast from ...types import TealType from ...errors import TealInputError @@ -16,18 +11,21 @@ from ..naryexpr import Concat from ..scratchvar import ScratchVar -from .type import Type, ComputedType, substringForDecoding +from .type import TypeSpec, BaseType, ComputedType from .bool import ( Bool, - consecutiveBoolNum, + BoolTypeSpec, + consecutiveBoolInstanceNum, + consecutiveBoolTypeSpecNum, boolSequenceLength, encodeBoolSequence, boolAwareStaticByteLength, ) from .uint import NUM_BITS_IN_BYTE, Uint16 +from .util import substringForDecoding -def encodeTuple(values: Sequence[Type]) -> Expr: +def encodeTuple(values: Sequence[BaseType]) -> Expr: heads: List[Expr] = [] head_length_static: int = 0 @@ -38,8 +36,10 @@ def encodeTuple(values: Sequence[Type]) -> Expr: ignoreNext -= 1 continue - if type(elem) is Bool: - numBools = consecutiveBoolNum(values, i) + elemType = elem.type_spec() + + if elemType == BoolTypeSpec(): + numBools = consecutiveBoolInstanceNum(values, i) ignoreNext = numBools - 1 head_length_static += boolSequenceLength(numBools) heads.append( @@ -47,13 +47,13 @@ def encodeTuple(values: Sequence[Type]) -> Expr: ) continue - if elem.is_dynamic(): + if elemType.is_dynamic(): head_length_static += 2 dynamicValueIndexToHeadIndex[i] = len(heads) heads.append(Seq()) # a placeholder continue - head_length_static += elem.byte_length_static() + head_length_static += elemType.byte_length_static() heads.append(elem.encode()) tail_offset = Uint16() @@ -63,7 +63,7 @@ def encodeTuple(values: Sequence[Type]) -> Expr: firstDynamicTail = True for i, elem in enumerate(values): - if elem.is_dynamic(): + if elem.type_spec().is_dynamic(): if firstDynamicTail: firstDynamicTail = False updateVars = Seq( @@ -77,7 +77,7 @@ def encodeTuple(values: Sequence[Type]) -> Expr: ) notLastDynamicValue = any( - [nextValue.is_dynamic() for nextValue in values[i + 1 :]] + [nextValue.type_spec().is_dynamic() for nextValue in values[i + 1 :]] ) if notLastDynamicValue: updateAccumulator = tail_offset_accumulator.set( @@ -104,7 +104,7 @@ def encodeTuple(values: Sequence[Type]) -> Expr: def indexTuple( - valueTypes: Sequence[Type], encoded: Expr, index: int, output: Type + valueTypes: Sequence[TypeSpec], encoded: Expr, index: int, output: BaseType ) -> Expr: if not (0 <= index < len(valueTypes)): raise ValueError("Index outside of range") @@ -118,9 +118,9 @@ def indexTuple( ignoreNext -= 1 continue - if type(typeBefore) is Bool: + if typeBefore == BoolTypeSpec(): lastBoolStart = offset - lastBoolLength = consecutiveBoolNum(valueTypes, i) + lastBoolLength = consecutiveBoolTypeSpecNum(valueTypes, i) offset += boolSequenceLength(lastBoolLength) ignoreNext = lastBoolLength - 1 continue @@ -132,7 +132,7 @@ def indexTuple( offset += typeBefore.byte_length_static() valueType = valueTypes[index] - if not valueType.has_same_type_as(output): + if output.type_spec() != valueType: raise TypeError("Output type does not match value type") if type(output) is Bool: @@ -154,8 +154,8 @@ def indexTuple( ignoreNext -= 1 continue - if type(typeAfter) is Bool: - boolLength = consecutiveBoolNum(valueTypes, i) + if type(typeAfter) is BoolTypeSpec: + boolLength = consecutiveBoolTypeSpecNum(valueTypes, i) nextDynamicValueOffset += boolSequenceLength(boolLength) ignoreNext = boolLength - 1 continue @@ -196,31 +196,52 @@ def indexTuple( return output.decode(encoded, startIndex=startIndex, length=length) -class Tuple(Type): - def __init__(self, *valueTypes: Type) -> None: - super().__init__(TealType.bytes) - self.valueTypes = list(valueTypes) +class TupleTypeSpec(TypeSpec): + def __init__(self, *value_type_specs: TypeSpec) -> None: + super().__init__() + self.value_specs = list(value_type_specs) - def has_same_type_as(self, other: Type) -> bool: - return ( - type(other) is Tuple - and len(self.valueTypes) == len(other.valueTypes) - and all( - self.valueTypes[i].has_same_type_as(other.valueTypes[i]) - for i in range(len(self.valueTypes)) - ) - ) + def value_type_specs(self) -> List[TypeSpec]: + """Get the TypeSpecs for the values of this tuple.""" + return self.value_specs + + def length_static(self) -> int: + """Get the number of values this tuple holds.""" + return len(self.value_specs) def new_instance(self) -> "Tuple": - return Tuple(*self.valueTypes) + return Tuple(*self.value_specs) def is_dynamic(self) -> bool: - return any(valueType.is_dynamic() for valueType in self.valueTypes) + return any(type_spec.is_dynamic() for type_spec in self.value_type_specs()) def byte_length_static(self) -> int: if self.is_dynamic(): raise ValueError("Type is dynamic") - return boolAwareStaticByteLength(self.valueTypes) + return boolAwareStaticByteLength(self.value_type_specs()) + + def storage_type(self) -> TealType: + return TealType.bytes + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, TupleTypeSpec) + and self.value_type_specs() == other.value_type_specs() + ) + + def __str__(self) -> str: + return "({})".format(",".join(map(str, self.value_type_specs()))) + + +TupleTypeSpec.__module__ = "pyteal" + + +class Tuple(BaseType): + def __init__(self, *value_type_specs: TypeSpec) -> None: + super().__init__(TupleTypeSpec(*value_type_specs)) + + def type_spec(self) -> TupleTypeSpec: + return cast(TupleTypeSpec, super().type_spec()) def decode( self, @@ -235,51 +256,153 @@ def decode( ) return self.stored_value.store(extracted) - def set(self, *values: Type) -> Expr: - if len(self.valueTypes) != len(values): + def set(self, *values: BaseType) -> Expr: + myTypes = self.type_spec().value_type_specs() + if len(myTypes) != len(values): raise TealInputError( "Incorrect length for values. Expected {}, got {}".format( - len(self.valueTypes), len(values) + len(myTypes), len(values) ) ) - if not all( - self.valueTypes[i].has_same_type_as(values[i]) - for i in range(len(self.valueTypes)) - ): + if not all(myTypes[i] == values[i].type_spec() for i in range(len(myTypes))): raise TealInputError("Input values do not match type") return self.stored_value.store(encodeTuple(values)) def encode(self) -> Expr: return self.stored_value.load() - def length_static(self) -> int: - return len(self.valueTypes) - def length(self) -> Expr: - return Int(self.length_static()) + """Get the number of values this tuple holds as an Expr.""" + return Int(self.type_spec().length_static()) def __getitem__(self, index: int) -> "TupleElement": - if not (0 <= index < self.length_static()): + if not (0 <= index < self.type_spec().length_static()): raise TealInputError("Index out of bounds") return TupleElement(self, index) - def __str__(self) -> str: - return "({})".format(",".join(map(lambda x: x.__str__(), self.valueTypes))) - Tuple.__module__ = "pyteal" -class TupleElement(ComputedType[Type]): +class TupleElement(ComputedType[BaseType]): + """Represents the extraction of a specific element from a Tuple.""" + def __init__(self, tuple: Tuple, index: int) -> None: - super().__init__(tuple.valueTypes[index]) + super().__init__() self.tuple = tuple self.index = index - def store_into(self, output: Type) -> Expr: + def produced_type_spec(self) -> TypeSpec: + return self.tuple.type_spec().value_type_specs()[self.index] + + def store_into(self, output: BaseType) -> Expr: return indexTuple( - self.tuple.valueTypes, self.tuple.encode(), self.index, output + self.tuple.type_spec().value_type_specs(), + self.tuple.encode(), + self.index, + output, ) TupleElement.__module__ = "pyteal" + +# Until Python 3.11 is released with support for PEP 646 -- Variadic Generics, it's not possible for +# the Tuple class to take an arbitrary number of template parameters. As a workaround, we define the +# following classes for specifically sized Tuples. If needed, more classes can be added for larger +# sizes. + + +class Tuple0(Tuple): + """A Tuple with 0 values.""" + + def __init__(self) -> None: + super().__init__() + + +Tuple0.__module__ = "pyteal" + +T1 = TypeVar("T1", bound=BaseType) + + +class Tuple1(Tuple, Generic[T1]): + """A Tuple with 1 value.""" + + def __init__(self, value1_type_spec: TypeSpec) -> None: + super().__init__(value1_type_spec) + + +Tuple1.__module__ = "pyteal" + +T2 = TypeVar("T2", bound=BaseType) + + +class Tuple2(Tuple, Generic[T1, T2]): + """A Tuple with 2 values.""" + + def __init__(self, value1_type_spec: TypeSpec, value2_type_spec: TypeSpec) -> None: + super().__init__(value1_type_spec, value2_type_spec) + + +Tuple2.__module__ = "pyteal" + +T3 = TypeVar("T3", bound=BaseType) + + +class Tuple3(Tuple, Generic[T1, T2, T3]): + """A Tuple with 3 values.""" + + def __init__( + self, + value1_type_spec: TypeSpec, + value2_type_spec: TypeSpec, + value3_type_spec: TypeSpec, + ) -> None: + super().__init__(value1_type_spec, value2_type_spec, value3_type_spec) + + +Tuple3.__module__ = "pyteal" + +T4 = TypeVar("T4", bound=BaseType) + + +class Tuple4(Tuple, Generic[T1, T2, T3, T4]): + """A Tuple with 4 values.""" + + def __init__( + self, + value1_type_spec: TypeSpec, + value2_type_spec: TypeSpec, + value3_type_spec: TypeSpec, + value4_type_spec: TypeSpec, + ) -> None: + super().__init__( + value1_type_spec, value2_type_spec, value3_type_spec, value4_type_spec + ) + + +Tuple4.__module__ = "pyteal" + +T5 = TypeVar("T5", bound=BaseType) + + +class Tuple5(Tuple, Generic[T1, T2, T3, T4, T5]): + """A Tuple with 5 values.""" + + def __init__( + self, + value1_type_spec: TypeSpec, + value2_type_spec: TypeSpec, + value3_type_spec: TypeSpec, + value4_type_spec: TypeSpec, + value5_type_spec: TypeSpec, + ) -> None: + super().__init__( + value1_type_spec, + value2_type_spec, + value3_type_spec, + value4_type_spec, + value5_type_spec, + ) + + +Tuple5.__module__ = "pyteal" diff --git a/pyteal/ast/abi/tuple_test.py b/pyteal/ast/abi/tuple_test.py index 5af924b8c..01011842e 100644 --- a/pyteal/ast/abi/tuple_test.py +++ b/pyteal/ast/abi/tuple_test.py @@ -4,17 +4,14 @@ from ... import * from .tuple import encodeTuple, indexTuple, TupleElement from .bool import encodeBoolSequence -from .type import substringForDecoding - -# this is not necessary but mypy complains if it's not included -from ... import CompileOptions +from .util import substringForDecoding options = CompileOptions(version=5) def test_encodeTuple(): class EncodeTest(NamedTuple): - types: List[abi.Type] + types: List[abi.BaseType] expected: Expr # variables used to construct the tests @@ -24,10 +21,10 @@ class EncodeTest(NamedTuple): uint16_b = abi.Uint16() bool_a = abi.Bool() bool_b = abi.Bool() - tuple_a = abi.Tuple(abi.Bool(), abi.Bool()) - dynamic_array_a = abi.DynamicArray(abi.Uint64()) - dynamic_array_b = abi.DynamicArray(abi.Uint16()) - dynamic_array_c = abi.DynamicArray(abi.Bool()) + tuple_a = abi.Tuple(abi.BoolTypeSpec(), abi.BoolTypeSpec()) + dynamic_array_a = abi.DynamicArray(abi.Uint64TypeSpec()) + dynamic_array_b = abi.DynamicArray(abi.Uint16TypeSpec()) + dynamic_array_c = abi.DynamicArray(abi.BoolTypeSpec()) tail_holder = ScratchVar() encoded_tail = ScratchVar() @@ -199,7 +196,7 @@ class EncodeTest(NamedTuple): actual.addIncoming() actual = TealBlock.NormalizeBlocks(actual) - if any(t.is_dynamic() for t in test.types): + if any(t.type_spec().is_dynamic() for t in test.types): with TealComponent.Context.ignoreExprEquality(): with TealComponent.Context.ignoreScratchSlotEquality(): assert actual == expected, "Test at index {} failed".format(i) @@ -216,159 +213,157 @@ class EncodeTest(NamedTuple): def test_indexTuple(): class IndexTest(NamedTuple): - types: List[abi.Type] + types: List[abi.TypeSpec] typeIndex: int - expected: Callable[[abi.Type], Expr] + expected: Callable[[abi.BaseType], Expr] # variables used to construct the tests - uint64_a = abi.Uint64() - uint64_b = abi.Uint64() - byte_a = abi.Byte() - bool_a = abi.Bool() - bool_b = abi.Bool() - tuple_a = abi.Tuple(abi.Bool(), abi.Bool()) - dynamic_array_a = abi.DynamicArray(abi.Uint64()) - dynamic_array_b = abi.DynamicArray(abi.Uint16()) + uint64_t = abi.Uint64TypeSpec() + byte_t = abi.ByteTypeSpec() + bool_t = abi.BoolTypeSpec() + tuple_t = abi.TupleTypeSpec(abi.BoolTypeSpec(), abi.BoolTypeSpec()) + dynamic_array_t1 = abi.DynamicArrayTypeSpec(abi.Uint64TypeSpec()) + dynamic_array_t2 = abi.DynamicArrayTypeSpec(abi.Uint16TypeSpec()) encoded = Bytes("encoded") tests: List[IndexTest] = [ IndexTest( - types=[uint64_a], + types=[uint64_t], typeIndex=0, expected=lambda output: output.decode(encoded), ), IndexTest( - types=[uint64_a, uint64_b], + types=[uint64_t, uint64_t], typeIndex=0, expected=lambda output: output.decode(encoded, length=Int(8)), ), IndexTest( - types=[uint64_a, uint64_b], + types=[uint64_t, uint64_t], typeIndex=1, expected=lambda output: output.decode(encoded, startIndex=Int(8)), ), IndexTest( - types=[uint64_a, byte_a, uint64_b], + types=[uint64_t, byte_t, uint64_t], typeIndex=1, expected=lambda output: output.decode( encoded, startIndex=Int(8), length=Int(1) ), ), IndexTest( - types=[uint64_a, byte_a, uint64_b], + types=[uint64_t, byte_t, uint64_t], typeIndex=2, expected=lambda output: output.decode( encoded, startIndex=Int(9), length=Int(8) ), ), IndexTest( - types=[bool_a], + types=[bool_t], typeIndex=0, expected=lambda output: output.decodeBit(encoded, Int(0)), ), IndexTest( - types=[bool_a, bool_b], + types=[bool_t, bool_t], typeIndex=0, expected=lambda output: output.decodeBit(encoded, Int(0)), ), IndexTest( - types=[bool_a, bool_b], + types=[bool_t, bool_t], typeIndex=1, expected=lambda output: output.decodeBit(encoded, Int(1)), ), IndexTest( - types=[uint64_a, bool_a], + types=[uint64_t, bool_t], typeIndex=1, expected=lambda output: output.decodeBit(encoded, Int(8 * 8)), ), IndexTest( - types=[uint64_a, bool_a, bool_b], + types=[uint64_t, bool_t, bool_t], typeIndex=1, expected=lambda output: output.decodeBit(encoded, Int(8 * 8)), ), IndexTest( - types=[uint64_a, bool_a, bool_b], + types=[uint64_t, bool_t, bool_t], typeIndex=2, expected=lambda output: output.decodeBit(encoded, Int(8 * 8 + 1)), ), IndexTest( - types=[bool_a, uint64_a], + types=[bool_t, uint64_t], typeIndex=0, expected=lambda output: output.decodeBit(encoded, Int(0)), ), IndexTest( - types=[bool_a, uint64_a], + types=[bool_t, uint64_t], typeIndex=1, expected=lambda output: output.decode(encoded, startIndex=Int(1)), ), IndexTest( - types=[bool_a, bool_b, uint64_a], + types=[bool_t, bool_t, uint64_t], typeIndex=0, expected=lambda output: output.decodeBit(encoded, Int(0)), ), IndexTest( - types=[bool_a, bool_b, uint64_a], + types=[bool_t, bool_t, uint64_t], typeIndex=1, expected=lambda output: output.decodeBit(encoded, Int(1)), ), IndexTest( - types=[bool_a, bool_b, uint64_a], + types=[bool_t, bool_t, uint64_t], typeIndex=2, expected=lambda output: output.decode(encoded, startIndex=Int(1)), ), IndexTest( - types=[tuple_a], typeIndex=0, expected=lambda output: output.decode(encoded) + types=[tuple_t], typeIndex=0, expected=lambda output: output.decode(encoded) ), IndexTest( - types=[byte_a, tuple_a], + types=[byte_t, tuple_t], typeIndex=1, expected=lambda output: output.decode(encoded, startIndex=Int(1)), ), IndexTest( - types=[tuple_a, byte_a], + types=[tuple_t, byte_t], typeIndex=0, expected=lambda output: output.decode( - encoded, startIndex=Int(0), length=Int(tuple_a.byte_length_static()) + encoded, startIndex=Int(0), length=Int(tuple_t.byte_length_static()) ), ), IndexTest( - types=[byte_a, tuple_a, byte_a], + types=[byte_t, tuple_t, byte_t], typeIndex=1, expected=lambda output: output.decode( - encoded, startIndex=Int(1), length=Int(tuple_a.byte_length_static()) + encoded, startIndex=Int(1), length=Int(tuple_t.byte_length_static()) ), ), IndexTest( - types=[dynamic_array_a], + types=[dynamic_array_t1], typeIndex=0, expected=lambda output: output.decode( encoded, startIndex=ExtractUint16(encoded, Int(0)) ), ), IndexTest( - types=[byte_a, dynamic_array_a], + types=[byte_t, dynamic_array_t1], typeIndex=1, expected=lambda output: output.decode( encoded, startIndex=ExtractUint16(encoded, Int(1)) ), ), IndexTest( - types=[dynamic_array_a, byte_a], + types=[dynamic_array_t1, byte_t], typeIndex=0, expected=lambda output: output.decode( encoded, startIndex=ExtractUint16(encoded, Int(0)) ), ), IndexTest( - types=[byte_a, dynamic_array_a, byte_a], + types=[byte_t, dynamic_array_t1, byte_t], typeIndex=1, expected=lambda output: output.decode( encoded, startIndex=ExtractUint16(encoded, Int(1)) ), ), IndexTest( - types=[byte_a, dynamic_array_a, byte_a, dynamic_array_b], + types=[byte_t, dynamic_array_t1, byte_t, dynamic_array_t2], typeIndex=1, expected=lambda output: output.decode( encoded, @@ -377,14 +372,14 @@ class IndexTest(NamedTuple): ), ), IndexTest( - types=[byte_a, dynamic_array_a, byte_a, dynamic_array_b], + types=[byte_t, dynamic_array_t1, byte_t, dynamic_array_t2], typeIndex=3, expected=lambda output: output.decode( encoded, startIndex=ExtractUint16(encoded, Int(4)) ), ), IndexTest( - types=[byte_a, dynamic_array_a, tuple_a, dynamic_array_b], + types=[byte_t, dynamic_array_t1, tuple_t, dynamic_array_t2], typeIndex=1, expected=lambda output: output.decode( encoded, @@ -393,14 +388,14 @@ class IndexTest(NamedTuple): ), ), IndexTest( - types=[byte_a, dynamic_array_a, tuple_a, dynamic_array_b], + types=[byte_t, dynamic_array_t1, tuple_t, dynamic_array_t2], typeIndex=3, expected=lambda output: output.decode( encoded, startIndex=ExtractUint16(encoded, Int(4)) ), ), IndexTest( - types=[byte_a, dynamic_array_a, bool_a, bool_b, dynamic_array_b], + types=[byte_t, dynamic_array_t2, bool_t, bool_t, dynamic_array_t2], typeIndex=1, expected=lambda output: output.decode( encoded, @@ -409,7 +404,7 @@ class IndexTest(NamedTuple): ), ), IndexTest( - types=[byte_a, dynamic_array_a, bool_a, bool_b, dynamic_array_b], + types=[byte_t, dynamic_array_t1, bool_t, bool_t, dynamic_array_t2], typeIndex=4, expected=lambda output: output.decode( encoded, startIndex=ExtractUint16(encoded, Int(4)) @@ -440,105 +435,165 @@ class IndexTest(NamedTuple): with pytest.raises(ValueError): indexTuple(test.types, encoded, -1, output) - otherType = abi.Uint64 - if output.has_same_type_as(otherType): - otherType = abi.Uint16 + otherType = abi.Uint64() + if output.type_spec() == otherType.type_spec(): + otherType = abi.Uint16() + with pytest.raises(TypeError): - indexTuple(test.types, encoded, test.types, otherType) + indexTuple(test.types, encoded, test.typeIndex, otherType) + +def test_TupleTypeSpec_eq(): + tupleA = abi.TupleTypeSpec( + abi.Uint64TypeSpec(), abi.Uint32TypeSpec(), abi.BoolTypeSpec() + ) + tupleB = abi.TupleTypeSpec( + abi.Uint64TypeSpec(), abi.Uint32TypeSpec(), abi.BoolTypeSpec() + ) + tupleC = abi.TupleTypeSpec( + abi.BoolTypeSpec(), abi.Uint64TypeSpec(), abi.Uint32TypeSpec() + ) + assert tupleA == tupleA + assert tupleA == tupleB + assert tupleA != tupleC + + +def test_TupleTypeSpec_value_type_specs(): + assert abi.TupleTypeSpec( + abi.Uint64TypeSpec(), abi.Uint32TypeSpec(), abi.BoolTypeSpec() + ).value_type_specs() == [ + abi.Uint64TypeSpec(), + abi.Uint32TypeSpec(), + abi.BoolTypeSpec(), + ] + + +def test_TupleTypeSpec_length_static(): + tests: List[List[abi.TypeSpec]] = [ + [], + [abi.Uint64TypeSpec()], + [ + abi.TupleTypeSpec(abi.Uint64TypeSpec(), abi.Uint64TypeSpec()), + abi.Uint64TypeSpec(), + ], + [abi.BoolTypeSpec()] * 8, + ] -def test_Tuple_has_same_type_as(): - tupleA = abi.Tuple(abi.Uint64(), abi.Uint32(), abi.Bool()) - tupleB = abi.Tuple(abi.Uint64(), abi.Uint32(), abi.Bool()) - tupleC = abi.Tuple(abi.Bool(), abi.Uint64(), abi.Uint32()) - assert tupleA.has_same_type_as(tupleA) - assert tupleA.has_same_type_as(tupleB) - assert not tupleA.has_same_type_as(tupleC) + for i, test in enumerate(tests): + actual = abi.TupleTypeSpec(*test).length_static() + expected = len(test) + assert actual == expected, "Test at index {} failed".format(i) -def test_Tuple_new_instance(): - tupleA = abi.Tuple(abi.Uint64(), abi.Uint32(), abi.Bool()) - newTuple = tupleA.new_instance() - assert type(newTuple) is abi.Tuple - assert newTuple.valueTypes == tupleA.valueTypes +def test_TupleTypeSpec_new_instance(): + assert isinstance( + abi.TupleTypeSpec( + abi.Uint64TypeSpec(), abi.Uint32TypeSpec(), abi.BoolTypeSpec() + ).new_instance(), + abi.Tuple, + ) -def test_Tuple_is_dynamic(): - assert not abi.Tuple().is_dynamic() - assert not abi.Tuple(abi.Uint64(), abi.Uint32(), abi.Bool()).is_dynamic() - assert abi.Tuple(abi.Uint16(), abi.DynamicArray(abi.Uint8())).is_dynamic() +def test_TupleTypeSpec_is_dynamic(): + assert not abi.TupleTypeSpec().is_dynamic() + assert not abi.TupleTypeSpec( + abi.Uint64TypeSpec(), abi.Uint32TypeSpec(), abi.BoolTypeSpec() + ).is_dynamic() + assert abi.TupleTypeSpec( + abi.Uint16TypeSpec(), abi.DynamicArrayTypeSpec(abi.Uint8TypeSpec()) + ).is_dynamic() -def test_tuple_str(): - assert str(abi.Tuple()) == "()" - assert str(abi.Tuple(abi.Tuple())) == "(())" - assert str(abi.Tuple(abi.Tuple(), abi.Tuple())) == "((),())" +def test_TupleTypeSpec_str(): + assert str(abi.TupleTypeSpec()) == "()" + assert str(abi.TupleTypeSpec(abi.TupleTypeSpec())) == "(())" + assert str(abi.TupleTypeSpec(abi.TupleTypeSpec(), abi.TupleTypeSpec())) == "((),())" assert ( - str(abi.Tuple(abi.Uint64(), abi.Uint32(), abi.Bool())) == "(uint64,uint32,bool)" + str( + abi.TupleTypeSpec( + abi.Uint64TypeSpec(), abi.Uint32TypeSpec(), abi.BoolTypeSpec() + ) + ) + == "(uint64,uint32,bool)" ) assert ( - str(abi.Tuple(abi.Bool(), abi.Uint64(), abi.Uint32())) == "(bool,uint64,uint32)" + str( + abi.TupleTypeSpec( + abi.BoolTypeSpec(), abi.Uint64TypeSpec(), abi.Uint32TypeSpec() + ) + ) + == "(bool,uint64,uint32)" ) assert ( - str(abi.Tuple(abi.Uint16(), abi.DynamicArray(abi.Uint8()))) + str( + abi.TupleTypeSpec( + abi.Uint16TypeSpec(), abi.DynamicArrayTypeSpec(abi.Uint8TypeSpec()) + ) + ) == "(uint16,uint8[])" ) -def test_Tuple_byte_length_static(): - assert abi.Tuple().byte_length_static() == 0 - assert abi.Tuple(abi.Tuple()).byte_length_static() == 0 - assert abi.Tuple(abi.Tuple(), abi.Tuple()).byte_length_static() == 0 +def test_TupleTypeSpec_byte_length_static(): + assert abi.TupleTypeSpec().byte_length_static() == 0 + assert abi.TupleTypeSpec(abi.TupleTypeSpec()).byte_length_static() == 0 + assert ( + abi.TupleTypeSpec(abi.TupleTypeSpec(), abi.TupleTypeSpec()).byte_length_static() + == 0 + ) assert ( - abi.Tuple(abi.Uint64(), abi.Uint32(), abi.Bool()).byte_length_static() + abi.TupleTypeSpec( + abi.Uint64TypeSpec(), abi.Uint32TypeSpec(), abi.BoolTypeSpec() + ).byte_length_static() == 8 + 4 + 1 ) assert ( - abi.Tuple( - abi.Uint64(), - abi.Uint32(), - abi.Bool(), - abi.Bool(), - abi.Bool(), - abi.Bool(), - abi.Bool(), - abi.Bool(), - abi.Bool(), - abi.Bool(), + abi.TupleTypeSpec( + abi.Uint64TypeSpec(), + abi.Uint32TypeSpec(), + abi.BoolTypeSpec(), + abi.BoolTypeSpec(), + abi.BoolTypeSpec(), + abi.BoolTypeSpec(), + abi.BoolTypeSpec(), + abi.BoolTypeSpec(), + abi.BoolTypeSpec(), + abi.BoolTypeSpec(), ).byte_length_static() == 8 + 4 + 1 ) assert ( - abi.Tuple( - abi.Uint64(), - abi.Uint32(), - abi.Bool(), - abi.Bool(), - abi.Bool(), - abi.Bool(), - abi.Bool(), - abi.Bool(), - abi.Bool(), - abi.Bool(), - abi.Bool(), + abi.TupleTypeSpec( + abi.Uint64TypeSpec(), + abi.Uint32TypeSpec(), + abi.BoolTypeSpec(), + abi.BoolTypeSpec(), + abi.BoolTypeSpec(), + abi.BoolTypeSpec(), + abi.BoolTypeSpec(), + abi.BoolTypeSpec(), + abi.BoolTypeSpec(), + abi.BoolTypeSpec(), + abi.BoolTypeSpec(), ).byte_length_static() == 8 + 4 + 2 ) with pytest.raises(ValueError): - abi.Tuple(abi.Uint16(), abi.DynamicArray(abi.Uint8())).byte_length_static() + abi.TupleTypeSpec( + abi.Uint16TypeSpec(), abi.DynamicArrayTypeSpec(abi.Uint8TypeSpec()) + ).byte_length_static() def test_Tuple_decode(): - tupleType = abi.Tuple() + encoded = Bytes("encoded") + tupleValue = abi.Tuple(abi.Uint64TypeSpec()) for startIndex in (None, Int(1)): for endIndex in (None, Int(2)): for length in (None, Int(3)): - encoded = Bytes("encoded") - if endIndex is not None and length is not None: with pytest.raises(TealInputError): - tupleType.decode( + tupleValue.decode( encoded, startIndex=startIndex, endIndex=endIndex, @@ -546,13 +601,13 @@ def test_Tuple_decode(): ) continue - expr = tupleType.decode( + expr = tupleValue.decode( encoded, startIndex=startIndex, endIndex=endIndex, length=length ) assert expr.type_of() == TealType.none assert not expr.has_return() - expectedExpr = tupleType.stored_value.store( + expectedExpr = tupleValue.stored_value.store( substringForDecoding( encoded, startIndex=startIndex, endIndex=endIndex, length=length ) @@ -570,31 +625,33 @@ def test_Tuple_decode(): def test_Tuple_set(): - tupleType = abi.Tuple(abi.Uint8(), abi.Uint16(), abi.Uint32()) + tupleValue = abi.Tuple( + abi.Uint8TypeSpec(), abi.Uint16TypeSpec(), abi.Uint32TypeSpec() + ) uint8 = abi.Uint8() uint16 = abi.Uint16() uint32 = abi.Uint32() with pytest.raises(TealInputError): - tupleType.set() + tupleValue.set() with pytest.raises(TealInputError): - tupleType.set(uint8, uint16) + tupleValue.set(uint8, uint16) with pytest.raises(TealInputError): - tupleType.set(uint8, uint16, uint32, uint32) + tupleValue.set(uint8, uint16, uint32, uint32) with pytest.raises(TealInputError): - tupleType.set(uint8, uint32, uint16) + tupleValue.set(uint8, uint32, uint16) with pytest.raises(TealInputError): - tupleType.set(uint8, uint16, uint16) + tupleValue.set(uint8, uint16, uint16) - expr = tupleType.set(uint8, uint16, uint32) + expr = tupleValue.set(uint8, uint16, uint32) assert expr.type_of() == TealType.none assert not expr.has_return() - expectedExpr = tupleType.stored_value.store(encodeTuple([uint8, uint16, uint32])) + expectedExpr = tupleValue.stored_value.store(encodeTuple([uint8, uint16, uint32])) expected, _ = expectedExpr.__teal__(options) expected.addIncoming() expected = TealBlock.NormalizeBlocks(expected) @@ -608,12 +665,12 @@ def test_Tuple_set(): def test_Tuple_encode(): - tupleType = abi.Tuple() - expr = tupleType.encode() + tupleValue = abi.Tuple(abi.Uint64TypeSpec()) + expr = tupleValue.encode() assert expr.type_of() == TealType.bytes assert not expr.has_return() - expected = TealSimpleBlock([TealOp(None, Op.load, tupleType.stored_value.slot)]) + expected = TealSimpleBlock([TealOp(None, Op.load, tupleValue.stored_value.slot)]) actual, _ = expr.__teal__(options) actual.addIncoming() @@ -623,32 +680,20 @@ def test_Tuple_encode(): assert actual == expected -def test_Tuple_length_static(): - tests: List[List[abi.Type]] = [ - [], - [abi.Uint64()], - [abi.Tuple(abi.Uint64(), abi.Uint64()), abi.Uint64()], - [abi.Bool()] * 8, - ] - - for i, test in enumerate(tests): - tupleType = abi.Tuple(*test) - actual = tupleType.length_static() - expected = len(test) - assert actual == expected, "Test at index {} failed".format(i) - - def test_Tuple_length(): - tests: List[List[abi.Type]] = [ + tests: List[List[abi.TypeSpec]] = [ [], - [abi.Uint64()], - [abi.Tuple(abi.Uint64(), abi.Uint64()), abi.Uint64()], - [abi.Bool()] * 8, + [abi.Uint64TypeSpec()], + [ + abi.TupleTypeSpec(abi.Uint64TypeSpec(), abi.Uint64TypeSpec()), + abi.Uint64TypeSpec(), + ], + [abi.BoolTypeSpec()] * 8, ] for i, test in enumerate(tests): - tupleType = abi.Tuple(*test) - expr = tupleType.length() + tupleValue = abi.Tuple(*test) + expr = tupleValue.length() assert expr.type_of() == TealType.uint64 assert not expr.has_return() @@ -664,49 +709,53 @@ def test_Tuple_length(): def test_Tuple_getitem(): - tests: List[List[abi.Type]] = [ + tests: List[List[abi.TypeSpec]] = [ [], - [abi.Uint64()], - [abi.Tuple(abi.Uint64(), abi.Uint64()), abi.Uint64()], - [abi.Bool()] * 8, + [abi.Uint64TypeSpec()], + [ + abi.TupleTypeSpec(abi.Uint64TypeSpec(), abi.Uint64TypeSpec()), + abi.Uint64TypeSpec(), + ], + [abi.BoolTypeSpec()] * 8, ] for i, test in enumerate(tests): - tupleType = abi.Tuple(*test) + tupleValue = abi.Tuple(*test) for j in range(len(test)): - element = tupleType[j] + element = tupleValue[j] assert type(element) is TupleElement, "Test at index {} failed".format(i) - assert element.tuple is tupleType, "Test at index {} failed".format(i) + assert element.tuple is tupleValue, "Test at index {} failed".format(i) assert element.index == j, "Test at index {} failed".format(i) with pytest.raises(TealInputError): - tupleType[-1] + tupleValue[-1] with pytest.raises(TealInputError): - tupleType[len(test)] + tupleValue[len(test)] def test_TupleElement_store_into(): - tests: List[List[abi.Type]] = [ + tests: List[List[abi.TypeSpec]] = [ [], - [abi.Uint64()], - [abi.Tuple(abi.Uint64(), abi.Uint64()), abi.Uint64()], - [abi.Bool()] * 8, + [abi.Uint64TypeSpec()], + [ + abi.TupleTypeSpec(abi.Uint64TypeSpec(), abi.Uint64TypeSpec()), + abi.Uint64TypeSpec(), + ], + [abi.BoolTypeSpec()] * 8, ] for i, test in enumerate(tests): - tupleType = abi.Tuple(*test) + tupleValue = abi.Tuple(*test) for j in range(len(test)): - element = TupleElement(tupleType, j) - output = tupleType.valueTypes[j] + element = TupleElement(tupleValue, j) + output = test[j].new_instance() expr = element.store_into(output) assert expr.type_of() == TealType.none assert not expr.has_return() - expectedExpr = indexTuple( - tupleType.valueTypes, tupleType.encode(), j, output - ) + expectedExpr = indexTuple(test, tupleValue.encode(), j, output) expected, _ = expectedExpr.__teal__(options) expected.addIncoming() expected = TealBlock.NormalizeBlocks(expected) diff --git a/pyteal/ast/abi/type.py b/pyteal/ast/abi/type.py index 842bd8c77..67706a4e5 100644 --- a/pyteal/ast/abi/type.py +++ b/pyteal/ast/abi/type.py @@ -1,71 +1,82 @@ -from typing import TypeVar, Generic, Callable +from typing import TypeVar, Generic, Callable, Final, cast from abc import ABC, abstractmethod from ...types import TealType -from ...errors import TealInputError from ..expr import Expr from ..scratchvar import ScratchVar from ..seq import Seq -from ..int import Int -from ..substring import Extract, Substring, Suffix -T = TypeVar("T", bound="Type") +class TypeSpec(ABC): + """TypeSpec represents a specification for an ABI type. -class Type(ABC): - """The abstract base class for all ABI types. - - This class contains both information about an ABI type, and a value that conforms to that type. - The value is contained in a unique ScratchVar that only the type has access to. As a result, the - value of an ABI type is mutable and can be efficiently referenced multiple times without needing - to recompute it. + Essentially this is a factory that can produce specific instances of ABI types. """ - def __init__(self, valueType: TealType) -> None: - """Create a new Type. - - Args: - valueType: The TealType (uint64 or bytes) that this ABI type will store in its internal - ScratchVar. - """ - super().__init__() - self.stored_value = ScratchVar(valueType) - @abstractmethod - def has_same_type_as(self, other: "Type") -> bool: - """Check if this type is considered equal to the other ABI type, irrespective of their - values. + def new_instance(self) -> "BaseType": + """Create a new instance of the specified type.""" + pass - Args: - other: The ABI type to compare to. + @abstractmethod + def is_dynamic(self) -> bool: + """Check if this ABI type is dynamic. - Returns: - True if and only if self and other can store the same ABI value. + If a type is dynamic, the length of its encoding depends on its value. Otherwise, the type + is considered static (not dynamic). """ pass @abstractmethod - def new_instance(self: T) -> T: - """Create a new instance of this ABI type. + def byte_length_static(self) -> int: + """Get the byte length of this ABI type's encoding. Only valid for static types.""" + pass - The value of this type will not be applied to the new type. - """ + @abstractmethod + def storage_type(self) -> TealType: + """Get the TealType that the underlying ScratchVar should hold for this type.""" pass @abstractmethod - def is_dynamic(self) -> bool: - """Check if this ABI type is dynamic. + def __eq__(self, other: object) -> bool: + """Check if this type is considered equal to another ABI type. - If a type is dynamic, the length of its encoding depends on its value. Otherwise, the type - is considered static (not dynamic). + Args: + other: The object to compare to. If this is not a TypeSpec, this method will never + return true. + + Returns: + True if and only if self and other represent the same ABI type. """ pass @abstractmethod - def byte_length_static(self) -> int: - """Get the byte length of this ABI type's encoding. Only valid for static types.""" + def __str__(self) -> str: + """Get the string representation of this ABI type, used for creating method signatures.""" pass + +TypeSpec.__module__ = "pyteal" + + +class BaseType(ABC): + """The abstract base class for all ABI type instances. + + The value of the type is contained in a unique ScratchVar that only this instance has access to. + As a result, the value of an ABI type is mutable and can be efficiently referenced multiple + times without needing to recompute it. + """ + + def __init__(self, spec: TypeSpec) -> None: + """Create a new BaseType.""" + super().__init__() + self._type_spec: Final = spec + self.stored_value: Final = ScratchVar(spec.storage_type()) + + def type_spec(self) -> TypeSpec: + """Get the TypeSpec for this ABI type instance.""" + return self._type_spec + @abstractmethod def encode(self) -> Expr: """Encode this ABI type to a byte string. @@ -117,62 +128,48 @@ def decode( """ pass - @abstractmethod - def __str__(self) -> str: - """Get the string representation of this ABI type, used for creating method signatures.""" - pass +BaseType.__module__ = "pyteal" -Type.__module__ = "pyteal" +T = TypeVar("T", bound=BaseType) class ComputedType(ABC, Generic[T]): - def __init__(self, producedType: T) -> None: - super().__init__() - self._producedType = producedType + """Represents an ABI Type whose value must be computed by an expression.""" @abstractmethod - def store_into(self, output: T) -> Expr: + def produced_type_spec(cls) -> TypeSpec: + """Get the ABI TypeSpec that this object produces.""" pass - def use(self, action: Callable[[T], Expr]) -> Expr: - newInstance = self._producedType.new_instance() - return Seq(self.store_into(newInstance), action(newInstance)) - - -ComputedType.__module__ = "pyteal" - + @abstractmethod + def store_into(self, output: T) -> Expr: + """Store the value of this computed type into an existing ABI type instance. -def substringForDecoding( - encoded: Expr, - *, - startIndex: Expr = None, - endIndex: Expr = None, - length: Expr = None -) -> Expr: - """A helper function for getting the substring to decode according to the rules of Type.decode.""" - if length is not None and endIndex is not None: - raise TealInputError("length and endIndex are mutually exclusive arguments") + Args: + output: The object where the computed value will be stored. This object must have the + same type as this class's produced type. - if startIndex is not None: - if length is not None: - # substring from startIndex to startIndex + length - return Extract(encoded, startIndex, length) + Returns: + An expression which stores the computed value represented by this class into the output + object. + """ + pass - if endIndex is not None: - # substring from startIndex to endIndex - return Substring(encoded, startIndex, endIndex) + def use(self, action: Callable[[T], Expr]) -> Expr: + """Use the computed value represented by this class in a function or lambda expression. - # substring from startIndex to end of string - return Suffix(encoded, startIndex) + Args: + action: A callable object that will receive an instance of this class's produced type + with the computed value. The callable object may use that value as it sees fit, but + it must return an Expr to be included in the program's AST. - if length is not None: - # substring from 0 to length - return Extract(encoded, Int(0), length) + Returns: + An expression which contains the returned expression from invoking action with the + computed value. + """ + newInstance = cast(T, self.produced_type_spec().new_instance()) + return Seq(self.store_into(newInstance), action(newInstance)) - if endIndex is not None: - # substring from 0 to endIndex - return Substring(encoded, Int(0), endIndex) - # the entire string - return encoded +ComputedType.__module__ = "pyteal" diff --git a/pyteal/ast/abi/type_test.py b/pyteal/ast/abi/type_test.py index 2d57107dd..4c46b3613 100644 --- a/pyteal/ast/abi/type_test.py +++ b/pyteal/ast/abi/type_test.py @@ -1,20 +1,18 @@ -from typing import NamedTuple, List, Optional, Union, Any, cast import pytest from ... import * -from .type import ComputedType, substringForDecoding - -# this is not necessary but mypy complains if it's not included -from ... import CompileOptions options = CompileOptions(version=5) -class DummyComputedType(ComputedType[abi.Uint64]): +class DummyComputedType(abi.ComputedType[abi.Uint64]): def __init__(self, value: int) -> None: - super().__init__(abi.Uint64()) + super().__init__() self._value = value + def produced_type_spec(self) -> abi.Uint64TypeSpec: + return abi.Uint64TypeSpec() + def store_into(self, output: abi.Uint64) -> Expr: return output.set(self._value) @@ -47,84 +45,3 @@ def test_ComputedType_use(): with TealComponent.Context.ignoreExprEquality(): assert actual == expected - - -def test_substringForDecoding(): - class SubstringTest(NamedTuple): - startIndex: Optional[Expr] - endIndex: Optional[Expr] - length: Optional[Expr] - expected: Union[Expr, Any] - - encoded = Bytes("encoded") - - tests: List[SubstringTest] = [ - SubstringTest(startIndex=None, endIndex=None, length=None, expected=encoded), - SubstringTest( - startIndex=None, - endIndex=None, - length=Int(4), - expected=Extract(encoded, Int(0), Int(4)), - ), - SubstringTest( - startIndex=None, - endIndex=Int(4), - length=None, - expected=Substring(encoded, Int(0), Int(4)), - ), - SubstringTest( - startIndex=None, endIndex=Int(4), length=Int(5), expected=TealInputError - ), - SubstringTest( - startIndex=Int(4), - endIndex=None, - length=None, - expected=Suffix(encoded, Int(4)), - ), - SubstringTest( - startIndex=Int(4), - endIndex=None, - length=Int(5), - expected=Extract(encoded, Int(4), Int(5)), - ), - SubstringTest( - startIndex=Int(4), - endIndex=Int(5), - length=None, - expected=Substring(encoded, Int(4), Int(5)), - ), - SubstringTest( - startIndex=Int(4), endIndex=Int(5), length=Int(6), expected=TealInputError - ), - ] - - for i, test in enumerate(tests): - if not isinstance(test.expected, Expr): - with pytest.raises(test.expected): - substringForDecoding( - encoded, - startIndex=test.startIndex, - endIndex=test.endIndex, - length=test.length, - ) - continue - - expr = substringForDecoding( - encoded, - startIndex=test.startIndex, - endIndex=test.endIndex, - length=test.length, - ) - assert expr.type_of() == TealType.bytes - assert not expr.has_return() - - expected, _ = cast(Expr, test.expected).__teal__(options) - expected.addIncoming() - expected = TealBlock.NormalizeBlocks(expected) - - actual, _ = expr.__teal__(options) - actual.addIncoming() - actual = TealBlock.NormalizeBlocks(actual) - - with TealComponent.Context.ignoreExprEquality(): - assert actual == expected, "Test at index {} failed".format(i) diff --git a/pyteal/ast/abi/uint.py b/pyteal/ast/abi/uint.py index cd5e190e2..dd12ed2ee 100644 --- a/pyteal/ast/abi/uint.py +++ b/pyteal/ast/abi/uint.py @@ -1,8 +1,15 @@ -from typing import Union, cast +from typing import ( + TypeVar, + Union, + Optional, + Final, + cast, +) from abc import abstractmethod from ...types import TealType from ...errors import TealInputError +from ..scratchvar import ScratchVar from ..expr import Expr from ..seq import Seq from ..assert_ import Assert @@ -12,95 +19,140 @@ from ..unaryexpr import Itob, Btoi from ..binaryexpr import GetByte, ExtractUint16, ExtractUint32, ExtractUint64 from ..ternaryexpr import SetByte -from .type import Type +from .type import TypeSpec, BaseType NUM_BITS_IN_BYTE = 8 +SUPPORTED_UINT_SIZES = (8, 16, 32, 64) -class Uint(Type): - def __init__(self, bit_size: int) -> None: - valueType = TealType.uint64 if bit_size <= 64 else TealType.bytes - super().__init__(valueType) - self.bit_size = bit_size - def has_same_type_as(self, other: Type) -> bool: - return isinstance(other, Uint) and self.bit_size == other.bit_size +def uint_storage_type(size: int) -> TealType: + if size <= 64: + return TealType.uint64 + return TealType.bytes - def is_dynamic(self) -> bool: - return False - def bits(self) -> int: - return self.bit_size +def uint_set(size: int, uintVar: ScratchVar, value: Union[int, Expr, "Uint"]) -> Expr: + if size > 64: + raise NotImplementedError( + "Uint operations have not yet been implemented for bit sizes larger than 64" + ) - def byte_length_static(self) -> int: - return self.bit_size // NUM_BITS_IN_BYTE + checked = False + if type(value) is int: + if value >= 2 ** size: + raise TealInputError("Value exceeds uint{} maximum: {}".format(size, value)) + value = Int(value) + checked = True + + if isinstance(value, Uint): + value = value.get() + checked = True + + if checked or size == 64: + return uintVar.store(cast(Expr, value)) + + return Seq( + uintVar.store(cast(Expr, value)), + Assert(uintVar.load() < Int(2 ** size)), + ) + + +def uint_decode( + size: int, + uintVar: ScratchVar, + encoded: Expr, + startIndex: Optional[Expr], + endIndex: Optional[Expr], + length: Optional[Expr], +) -> Expr: + if size > 64: + raise NotImplementedError( + "Uint operations have not yet been implemented for bit sizes larger than 64" + ) - def get(self) -> Expr: - return self.stored_value.load() + if size == 64: + if startIndex is None: + if endIndex is None and length is None: + return uintVar.store(Btoi(encoded)) + startIndex = Int(0) + return uintVar.store(ExtractUint64(encoded, startIndex)) - @abstractmethod - def set(self, value: Union[int, Expr]) -> Expr: - pass + if startIndex is None: + startIndex = Int(0) - def __str__(self) -> str: - return "uint{}".format(self.bit_size) + if size == 8: + return uintVar.store(GetByte(encoded, startIndex)) + if size == 16: + return uintVar.store(ExtractUint16(encoded, startIndex)) + if size == 32: + return uintVar.store(ExtractUint32(encoded, startIndex)) + raise ValueError("Unsupported uint size: {}".format(size)) -class Uint8(Uint): - def __init__( - self, - ) -> None: - super().__init__(8) - def new_instance(self) -> "Uint8": - return Uint8() +def uint_encode(size: int, uintVar: ScratchVar) -> Expr: + if size > 64: + raise NotImplementedError( + "Uint operations have not yet been implemented for bit sizes larger than 64" + ) - def set(self, value: Union[int, Expr, "Uint8", "Byte"]) -> Expr: - checked = False - if type(value) is int: - if value >= 2 ** self.bit_size: - raise TealInputError( - "Value exceeds {} maximum: {}".format( - self.__class__.__name__, value - ) - ) - value = Int(value) - checked = True + if size == 8: + return SetByte(Bytes(b"\x00"), Int(0), uintVar.load()) + if size == 16: + return Suffix(Itob(uintVar.load()), Int(6)) + if size == 32: + return Suffix(Itob(uintVar.load()), Int(4)) + if size == 64: + return Itob(uintVar.load()) - if type(value) is Uint8 or type(value) is Byte: - value = value.get() - checked = True + raise ValueError("Unsupported uint size: {}".format(size)) - if checked: - return self.stored_value.store(cast(Expr, value)) - return Seq( - self.stored_value.store(cast(Expr, value)), - Assert(self.stored_value.load() < Int(2 ** self.bit_size)), - ) +N = TypeVar("N", bound=int) - def decode( - self, - encoded: Expr, - *, - startIndex: Expr = None, - endIndex: Expr = None, - length: Expr = None - ) -> Expr: - if startIndex is None: - startIndex = Int(0) - return self.stored_value.store(GetByte(encoded, startIndex)) - def encode(self) -> Expr: - return SetByte(Bytes(b"\x00"), Int(0), self.get()) +class UintTypeSpec(TypeSpec): + def __init__(self, bit_size: int) -> None: + super().__init__() + if bit_size not in SUPPORTED_UINT_SIZES: + raise TypeError("Unsupported uint size: {}".format(bit_size)) + self.size: Final = bit_size + @abstractmethod + def new_instance(self) -> "Uint": + pass -Uint8.__module__ = "pyteal" + def bit_size(self) -> int: + """Get the bit size of this uint type""" + return self.size + + def is_dynamic(self) -> bool: + return False + + def byte_length_static(self) -> int: + return self.bit_size() // NUM_BITS_IN_BYTE + + def storage_type(self) -> TealType: + return uint_storage_type(self.bit_size()) + + def __eq__(self, other: object) -> bool: + # NOTE: by this implementation, ByteTypeSpec() != Uint8TypeSpec() + return ( + type(self) is type(other) + and self.bit_size() == cast(UintTypeSpec, other).bit_size() + ) + + def __str__(self) -> str: + return "uint{}".format(self.bit_size()) + + +UintTypeSpec.__module__ = "pyteal" -class Byte(Uint8): +class ByteTypeSpec(UintTypeSpec): def __init__(self) -> None: - super().__init__() + super().__init__(8) def new_instance(self) -> "Byte": return Byte() @@ -109,85 +161,75 @@ def __str__(self) -> str: return "byte" -Byte.__module__ = "pyteal" +ByteTypeSpec.__module__ = "pyteal" -class Uint16(Uint): - def __init__( - self, - ) -> None: +class Uint8TypeSpec(UintTypeSpec): + def __init__(self) -> None: + super().__init__(8) + + def new_instance(self) -> "Uint8": + return Uint8() + + +Uint8TypeSpec.__module__ = "pyteal" + + +class Uint16TypeSpec(UintTypeSpec): + def __init__(self) -> None: super().__init__(16) def new_instance(self) -> "Uint16": return Uint16() - def set(self, value: Union[int, Expr, "Uint16"]) -> Expr: - checked = False - if type(value) is int: - if value >= 2 ** self.bit_size: - raise TealInputError("Value exceeds Uint16 maximum: {}".format(value)) - value = Int(value) - checked = True - if type(value) is Uint16: - value = value.get() - checked = True +Uint16TypeSpec.__module__ = "pyteal" - if checked: - return self.stored_value.store(cast(Expr, value)) - return Seq( - self.stored_value.store(cast(Expr, value)), - Assert(self.stored_value.load() < Int(2 ** self.bit_size)), - ) +class Uint32TypeSpec(UintTypeSpec): + def __init__(self) -> None: + super().__init__(32) - def decode( - self, - encoded: Expr, - *, - startIndex: Expr = None, - endIndex: Expr = None, - length: Expr = None - ) -> Expr: - if startIndex is None: - startIndex = Int(0) - return self.stored_value.store(ExtractUint16(encoded, startIndex)) + def new_instance(self) -> "Uint32": + return Uint32() - def encode(self) -> Expr: - return Suffix(Itob(self.get()), Int(6)) +Uint32TypeSpec.__module__ = "pyteal" -Uint16.__module__ = "pyteal" +class Uint64TypeSpec(UintTypeSpec): + def __init__(self) -> None: + super().__init__(64) -class Uint32(Uint): - def __init__( - self, - ) -> None: - super().__init__(32) + def new_instance(self) -> "Uint64": + return Uint64() - def new_instance(self) -> "Uint32": - return Uint32() - def set(self, value: Union[int, Expr, "Uint32"]) -> Expr: - checked = False - if type(value) is int: - if value >= 2 ** self.bit_size: - raise TealInputError("Value exceeds Uint32 maximum: {}".format(value)) - value = Int(value) - checked = True +Uint32TypeSpec.__module__ = "pyteal" - if type(value) is Uint32: - value = value.get() - checked = True - if checked: - return self.stored_value.store(cast(Expr, value)) +class Uint(BaseType): + @abstractmethod + def __init__(self, spec: UintTypeSpec) -> None: + super().__init__(spec) - return Seq( - self.stored_value.store(cast(Expr, value)), - Assert(self.stored_value.load() < Int(2 ** self.bit_size)), - ) + def type_spec(self) -> UintTypeSpec: + return cast(UintTypeSpec, super().type_spec()) + + def get(self) -> Expr: + return self.stored_value.load() + + def set(self, value: Union[int, Expr, "Uint"]) -> Expr: + if isinstance(value, BaseType) and not ( + isinstance(value.type_spec(), UintTypeSpec) + and self.type_spec().bit_size() == value.type_spec().bit_size() + ): + raise TealInputError( + "Type {} is not assignable to type {}".format( + value.type_spec(), self.type_spec() + ) + ) + return uint_set(self.type_spec().bit_size(), self.stored_value, value) def decode( self, @@ -197,47 +239,57 @@ def decode( endIndex: Expr = None, length: Expr = None ) -> Expr: - if startIndex is None: - startIndex = Int(0) - return self.stored_value.store(ExtractUint32(encoded, startIndex)) + return uint_decode( + self.type_spec().bit_size(), + self.stored_value, + encoded, + startIndex, + endIndex, + length, + ) def encode(self) -> Expr: - return Suffix(Itob(self.get()), Int(4)) + return uint_encode(self.type_spec().bit_size(), self.stored_value) -Uint32.__module__ = "pyteal" +Uint.__module__ = "pyteal" -class Uint64(Uint): +class Byte(Uint): def __init__(self) -> None: - super().__init__(64) + super().__init__(ByteTypeSpec()) - def new_instance(self) -> "Uint64": - return Uint64() - def set(self, value: Union[int, Expr, "Uint64"]) -> Expr: - if type(value) is int: - value = Int(value) - if type(value) is Uint64: - value = value.get() - return self.stored_value.store(cast(Expr, value)) +Byte.__module__ = "pyteal" - def decode( - self, - encoded: Expr, - *, - startIndex: Expr = None, - endIndex: Expr = None, - length: Expr = None - ) -> Expr: - if startIndex is None: - if endIndex is None and length is None: - return self.stored_value.store(Btoi(encoded)) - startIndex = Int(0) - return self.stored_value.store(ExtractUint64(encoded, startIndex)) - def encode(self) -> Expr: - return Itob(self.get()) +class Uint8(Uint): + def __init__(self) -> None: + super().__init__(Uint8TypeSpec()) + + +Uint8.__module__ = "pyteal" + + +class Uint16(Uint): + def __init__(self) -> None: + super().__init__(Uint16TypeSpec()) + + +Uint16.__module__ = "pyteal" + + +class Uint32(Uint): + def __init__(self) -> None: + super().__init__(Uint32TypeSpec()) + + +Uint32.__module__ = "pyteal" + + +class Uint64(Uint): + def __init__(self) -> None: + super().__init__(Uint64TypeSpec()) Uint64.__module__ = "pyteal" diff --git a/pyteal/ast/abi/uint_test.py b/pyteal/ast/abi/uint_test.py index 1ea1f665d..e0e9d10c9 100644 --- a/pyteal/ast/abi/uint_test.py +++ b/pyteal/ast/abi/uint_test.py @@ -1,16 +1,16 @@ -from typing import NamedTuple, Callable, Union, Optional +from typing import List, Tuple, NamedTuple, Callable, Union, Optional import pytest -from ... import * +from pyteal.ast.abi.uint import UintTypeSpec -# this is not necessary but mypy complains if it's not included -from ... import CompileOptions +from ... import * options = CompileOptions(version=5) class UintTestData(NamedTuple): - uintType: abi.Uint + uintType: abi.UintTypeSpec + instanceType: type expectedBits: int maxValue: int checkUpperBound: bool @@ -28,7 +28,8 @@ def noneToInt0(value: Union[None, Expr]): testData = [ UintTestData( - uintType=abi.Uint8(), + uintType=abi.Uint8TypeSpec(), + instanceType=abi.Uint8, expectedBits=8, maxValue=2 ** 8 - 1, checkUpperBound=True, @@ -40,7 +41,8 @@ def noneToInt0(value: Union[None, Expr]): ), ), UintTestData( - uintType=abi.Uint16(), + uintType=abi.Uint16TypeSpec(), + instanceType=abi.Uint16, expectedBits=16, maxValue=2 ** 16 - 1, checkUpperBound=True, @@ -50,7 +52,8 @@ def noneToInt0(value: Union[None, Expr]): expectedEncoding=lambda uintType: Suffix(Itob(uintType.get()), Int(6)), ), UintTestData( - uintType=abi.Uint32(), + uintType=abi.Uint32TypeSpec(), + instanceType=abi.Uint32, expectedBits=32, maxValue=2 ** 32 - 1, checkUpperBound=True, @@ -60,7 +63,8 @@ def noneToInt0(value: Union[None, Expr]): expectedEncoding=lambda uintType: Suffix(Itob(uintType.get()), Int(4)), ), UintTestData( - uintType=abi.Uint64(), + uintType=abi.Uint64TypeSpec(), + instanceType=abi.Uint64, expectedBits=64, maxValue=2 ** 64 - 1, checkUpperBound=False, @@ -72,56 +76,68 @@ def noneToInt0(value: Union[None, Expr]): ] -def test_Uint_bits(): +def test_UintTypeSpec_bits(): for test in testData: - assert test.uintType.bits() == test.expectedBits + assert test.uintType.bit_size() == test.expectedBits assert test.uintType.byte_length_static() * 8 == test.expectedBits -def test_Uint_str(): +def test_UintTypeSpec_str(): for test in testData: assert str(test.uintType) == "uint{}".format(test.expectedBits) - assert str(abi.Byte()) == "byte" + assert str(abi.ByteTypeSpec()) == "byte" -def test_Uint_is_dynamic(): +def test_UintTypeSpec_is_dynamic(): for test in testData: assert not test.uintType.is_dynamic() + assert not abi.ByteTypeSpec().is_dynamic() -def test_Uint_has_same_type_as(): +def test_UintTypeSpec_eq(): for i, test in enumerate(testData): - assert test.uintType.has_same_type_as(test.uintType) + assert test.uintType == test.uintType for j, otherTest in enumerate(testData): if i == j: continue - assert not test.uintType.has_same_type_as(otherTest.uintType) + assert test.uintType != otherTest.uintType for otherType in ( - abi.Bool(), - abi.StaticArray(test.uintType, 1), - abi.DynamicArray(test.uintType), + abi.BoolTypeSpec(), + abi.StaticArrayTypeSpec(test.uintType, 1), + abi.DynamicArrayTypeSpec(test.uintType), ): - assert not test.uintType.has_same_type_as(otherType) + assert test.uintType != otherType + assert abi.ByteTypeSpec() != abi.Uint8TypeSpec() + assert abi.Uint8TypeSpec() != abi.ByteTypeSpec() -def test_Uint_new_instance(): + +def test_UintTypeSpec_storage_type(): for test in testData: - assert type(test.uintType.new_instance()) is type(test.uintType) + assert test.uintType.storage_type() == TealType.uint64 + assert abi.BoolTypeSpec().storage_type() == TealType.uint64 + + +def test_UintTypeSpec_new_instance(): + for test in testData: + assert isinstance(test.uintType.new_instance(), test.instanceType) + assert isinstance(abi.ByteTypeSpec().new_instance(), abi.Byte) def test_Uint_set_static(): for test in testData: - for value in (0, 1, 100, test.maxValue): - expr = test.uintType.set(value) + for value_to_set in (0, 1, 100, test.maxValue): + value = test.uintType.new_instance() + expr = value.set(value_to_set) assert expr.type_of() == TealType.none assert not expr.has_return() expected = TealSimpleBlock( [ - TealOp(None, Op.int, value), - TealOp(None, Op.store, test.uintType.stored_value.slot), + TealOp(None, Op.int, value_to_set), + TealOp(None, Op.store, value.stored_value.slot), ] ) @@ -133,22 +149,23 @@ def test_Uint_set_static(): assert actual == expected with pytest.raises(TealInputError): - test.uintType.set(test.maxValue + 1) + value.set(test.maxValue + 1) with pytest.raises(TealInputError): - test.uintType.set(-1) + value.set(-1) def test_Uint_set_expr(): for test in testData: - expr = test.uintType.set(Int(10) + Int(1)) + value = test.uintType.new_instance() + expr = value.set(Int(10) + Int(1)) assert expr.type_of() == TealType.none assert not expr.has_return() upperBoundCheck = [] if test.checkUpperBound: upperBoundCheck = [ - TealOp(None, Op.load, test.uintType.stored_value.slot), + TealOp(None, Op.load, value.stored_value.slot), TealOp(None, Op.int, test.maxValue + 1), TealOp(None, Op.lt), TealOp(None, Op.assert_), @@ -159,7 +176,7 @@ def test_Uint_set_expr(): TealOp(None, Op.int, 10), TealOp(None, Op.int, 1), TealOp(None, Op.add), - TealOp(None, Op.store, test.uintType.stored_value.slot), + TealOp(None, Op.store, value.stored_value.slot), ] + upperBoundCheck ) @@ -174,15 +191,16 @@ def test_Uint_set_expr(): def test_Uint_set_copy(): for test in testData: + value = test.uintType.new_instance() other = test.uintType.new_instance() - expr = test.uintType.set(other) + expr = value.set(other) assert expr.type_of() == TealType.none assert not expr.has_return() expected = TealSimpleBlock( [ TealOp(None, Op.load, other.stored_value.slot), - TealOp(None, Op.store, test.uintType.stored_value.slot), + TealOp(None, Op.store, value.stored_value.slot), ] ) @@ -193,16 +211,18 @@ def test_Uint_set_copy(): with TealComponent.Context.ignoreExprEquality(): assert actual == expected + with pytest.raises(TealInputError): + value.set(abi.Bool()) + def test_Uint_get(): for test in testData: - expr = test.uintType.get() + value = test.uintType.new_instance() + expr = value.get() assert expr.type_of() == TealType.uint64 assert not expr.has_return() - expected = TealSimpleBlock( - [TealOp(expr, Op.load, test.uintType.stored_value.slot)] - ) + expected = TealSimpleBlock([TealOp(expr, Op.load, value.stored_value.slot)]) actual, _ = expr.__teal__(options) @@ -210,18 +230,19 @@ def test_Uint_get(): def test_Uint_decode(): + encoded = Bytes("encoded") for test in testData: for startIndex in (None, Int(1)): for endIndex in (None, Int(2)): for length in (None, Int(3)): - encoded = Bytes("encoded") - expr = test.uintType.decode( + value = test.uintType.new_instance() + expr = value.decode( encoded, startIndex=startIndex, endIndex=endIndex, length=length ) assert expr.type_of() == TealType.none assert not expr.has_return() - expectedDecoding = test.uintType.stored_value.store( + expectedDecoding = value.stored_value.store( test.expectedDecoding(encoded, startIndex, endIndex, length) ) expected, _ = expectedDecoding.__teal__(options) @@ -238,11 +259,12 @@ def test_Uint_decode(): def test_Uint_encode(): for test in testData: - expr = test.uintType.encode() + value = test.uintType.new_instance() + expr = value.encode() assert expr.type_of() == TealType.bytes assert not expr.has_return() - expected, _ = test.expectedEncoding(test.uintType).__teal__(options) + expected, _ = test.expectedEncoding(value).__teal__(options) expected.addIncoming() expected = TealBlock.NormalizeBlocks(expected) @@ -254,20 +276,14 @@ def test_Uint_encode(): assert actual == expected -def test_ByteUint8_set_error(): - with pytest.raises(TealInputError) as uint8_err_msg: - abi.Uint8().set(256) - assert "Uint8" in uint8_err_msg.__str__() - - with pytest.raises(TealInputError) as byte_err_msg: - abi.Byte().set(256) - assert "Byte" in byte_err_msg.__str__() - - def test_ByteUint8_mutual_conversion(): - for type_a, type_b in [(abi.Uint8, abi.Byte), (abi.Byte, abi.Uint8)]: - type_b_instance = type_b() - other = type_a() + cases: List[Tuple[UintTypeSpec, UintTypeSpec]] = [ + (abi.Uint8TypeSpec(), abi.ByteTypeSpec()), + (abi.ByteTypeSpec(), abi.Uint8TypeSpec()), + ] + for type_a, type_b in cases: + type_b_instance = type_b.new_instance() + other = type_a.new_instance() expr = type_b_instance.set(other) assert expr.type_of() == TealType.none diff --git a/pyteal/ast/abi/util.py b/pyteal/ast/abi/util.py new file mode 100644 index 000000000..e9077ded6 --- /dev/null +++ b/pyteal/ast/abi/util.py @@ -0,0 +1,195 @@ +from typing import Any, Literal, get_origin, get_args + +from ...errors import TealInputError +from ..expr import Expr +from ..int import Int +from ..substring import Extract, Substring, Suffix +from .type import TypeSpec + + +def substringForDecoding( + encoded: Expr, + *, + startIndex: Expr = None, + endIndex: Expr = None, + length: Expr = None +) -> Expr: + """A helper function for getting the substring to decode according to the rules of BaseType.decode.""" + if length is not None and endIndex is not None: + raise TealInputError("length and endIndex are mutually exclusive arguments") + + if startIndex is not None: + if length is not None: + # substring from startIndex to startIndex + length + return Extract(encoded, startIndex, length) + + if endIndex is not None: + # substring from startIndex to endIndex + return Substring(encoded, startIndex, endIndex) + + # substring from startIndex to end of string + return Suffix(encoded, startIndex) + + if length is not None: + # substring from 0 to length + return Extract(encoded, Int(0), length) + + if endIndex is not None: + # substring from 0 to endIndex + return Substring(encoded, Int(0), endIndex) + + # the entire string + return encoded + + +def int_literal_from_annotation(annotation: Any) -> int: + """Extract an integer from a Literal type annotation. + + Args: + annotation: A Literal type annotation. E.g., `Literal[4]`. This must contain only a single + integer value. + + Returns: + The integer that the Literal represents. + """ + origin = get_origin(annotation) + args = get_args(annotation) + + if origin is not Literal: + raise TypeError("Expected literal for argument. Got: {}".format(origin)) + + if len(args) != 1 or type(args[0]) is not int: + raise TypeError( + "Expected single integer argument for Literal. Got: {}".format(args) + ) + + return args[0] + + +def type_spec_from_annotation(annotation: Any) -> TypeSpec: + """Convert an ABI type annotation into the corresponding TypeSpec. + + For example, calling this function with the input `abi.StaticArray[abi.Bool, Literal[5]]` would + return `abi.StaticArrayTypeSpec(abi.BoolTypeSpec(), 5)`. + + Args: + annotation. An annotation representing an ABI type instance. + + Raises: + TypeError: if the input annotation does not represent a valid ABI type instance or its + arguments are invalid. + + Returns: + The TypeSpec that corresponds to the input annotation. + """ + from .bool import BoolTypeSpec, Bool + from .uint import ( + ByteTypeSpec, + Byte, + Uint8TypeSpec, + Uint8, + Uint16TypeSpec, + Uint16, + Uint32TypeSpec, + Uint32, + Uint64TypeSpec, + Uint64, + ) + from .array_dynamic import DynamicArrayTypeSpec, DynamicArray + from .array_static import StaticArrayTypeSpec, StaticArray + from .tuple import ( + TupleTypeSpec, + Tuple, + Tuple0, + Tuple1, + Tuple2, + Tuple3, + Tuple4, + Tuple5, + ) + + origin = get_origin(annotation) + if origin is None: + origin = annotation + + args = get_args(annotation) + + if origin is Bool: + if len(args) != 0: + raise TypeError("Bool expects 0 type arguments. Got: {}".format(args)) + return BoolTypeSpec() + + if origin is Byte: + if len(args) != 0: + raise TypeError("Byte expects 0 type arguments. Got: {}".format(args)) + return ByteTypeSpec() + + if origin is Uint8: + if len(args) != 0: + raise TypeError("Uint8 expects 0 type arguments. Got: {}".format(args)) + return Uint8TypeSpec() + + if origin is Uint16: + if len(args) != 0: + raise TypeError("Uint16 expects 0 type arguments. Got: {}".format(args)) + return Uint16TypeSpec() + + if origin is Uint32: + if len(args) != 0: + raise TypeError("Uint32 expects 0 type arguments. Got: {}".format(args)) + return Uint32TypeSpec() + + if origin is Uint64: + if len(args) != 0: + raise TypeError("Uint64 expects 0 type arguments. Got: {}".format(args)) + return Uint64TypeSpec() + + if origin is DynamicArray: + if len(args) != 1: + raise TypeError( + "DynamicArray expects 1 type argument. Got: {}".format(args) + ) + value_type_spec = type_spec_from_annotation(args[0]) + return DynamicArrayTypeSpec(value_type_spec) + + if origin is StaticArray: + if len(args) != 2: + raise TypeError("StaticArray expects 1 type argument. Got: {}".format(args)) + value_type_spec = type_spec_from_annotation(args[0]) + array_length = int_literal_from_annotation(args[1]) + return StaticArrayTypeSpec(value_type_spec, array_length) + + if origin is Tuple: + return TupleTypeSpec(*(type_spec_from_annotation(arg) for arg in args)) + + if origin is Tuple0: + if len(args) != 0: + raise TypeError("Tuple0 expects 0 type arguments. Got: {}".format(args)) + return TupleTypeSpec() + + if origin is Tuple1: + if len(args) != 1: + raise TypeError("Tuple1 expects 1 type argument. Got: {}".format(args)) + return TupleTypeSpec(*(type_spec_from_annotation(arg) for arg in args)) + + if origin is Tuple2: + if len(args) != 2: + raise TypeError("Tuple2 expects 2 type arguments. Got: {}".format(args)) + return TupleTypeSpec(*(type_spec_from_annotation(arg) for arg in args)) + + if origin is Tuple3: + if len(args) != 3: + raise TypeError("Tuple3 expects 3 type arguments. Got: {}".format(args)) + return TupleTypeSpec(*(type_spec_from_annotation(arg) for arg in args)) + + if origin is Tuple4: + if len(args) != 4: + raise TypeError("Tuple4 expects 4 type arguments. Got: {}".format(args)) + return TupleTypeSpec(*(type_spec_from_annotation(arg) for arg in args)) + + if origin is Tuple5: + if len(args) != 5: + raise TypeError("Tuple5 expects 5 type arguments. Got: {}".format(args)) + return TupleTypeSpec(*(type_spec_from_annotation(arg) for arg in args)) + + raise TypeError("Unknown annotation origin: {}".format(origin)) diff --git a/pyteal/ast/abi/util_test.py b/pyteal/ast/abi/util_test.py new file mode 100644 index 000000000..b4026ee0a --- /dev/null +++ b/pyteal/ast/abi/util_test.py @@ -0,0 +1,298 @@ +from typing import NamedTuple, List, Literal, Optional, Union, Any, cast +from inspect import isabstract +import pytest + +from ... import * +from .util import ( + substringForDecoding, + int_literal_from_annotation, + type_spec_from_annotation, +) + +options = CompileOptions(version=5) + + +def test_substringForDecoding(): + class SubstringTest(NamedTuple): + startIndex: Optional[Expr] + endIndex: Optional[Expr] + length: Optional[Expr] + expected: Union[Expr, Any] + + encoded = Bytes("encoded") + + tests: List[SubstringTest] = [ + SubstringTest(startIndex=None, endIndex=None, length=None, expected=encoded), + SubstringTest( + startIndex=None, + endIndex=None, + length=Int(4), + expected=Extract(encoded, Int(0), Int(4)), + ), + SubstringTest( + startIndex=None, + endIndex=Int(4), + length=None, + expected=Substring(encoded, Int(0), Int(4)), + ), + SubstringTest( + startIndex=None, endIndex=Int(4), length=Int(5), expected=TealInputError + ), + SubstringTest( + startIndex=Int(4), + endIndex=None, + length=None, + expected=Suffix(encoded, Int(4)), + ), + SubstringTest( + startIndex=Int(4), + endIndex=None, + length=Int(5), + expected=Extract(encoded, Int(4), Int(5)), + ), + SubstringTest( + startIndex=Int(4), + endIndex=Int(5), + length=None, + expected=Substring(encoded, Int(4), Int(5)), + ), + SubstringTest( + startIndex=Int(4), endIndex=Int(5), length=Int(6), expected=TealInputError + ), + ] + + for i, test in enumerate(tests): + if not isinstance(test.expected, Expr): + with pytest.raises(test.expected): + substringForDecoding( + encoded, + startIndex=test.startIndex, + endIndex=test.endIndex, + length=test.length, + ) + continue + + expr = substringForDecoding( + encoded, + startIndex=test.startIndex, + endIndex=test.endIndex, + length=test.length, + ) + assert expr.type_of() == TealType.bytes + assert not expr.has_return() + + expected, _ = cast(Expr, test.expected).__teal__(options) + expected.addIncoming() + expected = TealBlock.NormalizeBlocks(expected) + + actual, _ = expr.__teal__(options) + actual.addIncoming() + actual = TealBlock.NormalizeBlocks(actual) + + with TealComponent.Context.ignoreExprEquality(): + assert actual == expected, "Test at index {} failed".format(i) + + +def test_int_literal_from_annotation(): + class IntAnnotationTest(NamedTuple): + annotation: Any + expected: Union[int, Any] + + tests: List[IntAnnotationTest] = [ + IntAnnotationTest(annotation=Literal[0], expected=0), + IntAnnotationTest(annotation=Literal[1], expected=1), + IntAnnotationTest(annotation=Literal[10], expected=10), + # In Python 3.8, Literal[True] == Litearl[1], so the below test fails. + # It's not crucial, so I've commented it out until we no longer support 3.8 + # IntAnnotationTest(annotation=Literal[True], expected=TypeError), + IntAnnotationTest(annotation=Literal["test"], expected=TypeError), + IntAnnotationTest(annotation=Literal[b"test"], expected=TypeError), + IntAnnotationTest(annotation=Literal[None], expected=TypeError), + IntAnnotationTest(annotation=Literal[0, 1], expected=TypeError), + IntAnnotationTest(annotation=Literal, expected=TypeError), + ] + + for i, test in enumerate(tests): + if type(test.expected) is not int: + with pytest.raises(test.expected): + int_literal_from_annotation(test.annotation) + continue + + actual = int_literal_from_annotation(test.annotation) + assert actual == test.expected, "Test at index {} failed".format(i) + + +def test_type_spec_from_annotation(): + class TypeAnnotationTest(NamedTuple): + annotation: Any + expected: Union[abi.TypeSpec, Any] + + tests: List[TypeAnnotationTest] = [ + TypeAnnotationTest(annotation=abi.Bool, expected=abi.BoolTypeSpec()), + TypeAnnotationTest(annotation=abi.Byte, expected=abi.ByteTypeSpec()), + TypeAnnotationTest(annotation=abi.Uint8, expected=abi.Uint8TypeSpec()), + TypeAnnotationTest(annotation=abi.Uint16, expected=abi.Uint16TypeSpec()), + TypeAnnotationTest(annotation=abi.Uint32, expected=abi.Uint32TypeSpec()), + TypeAnnotationTest(annotation=abi.Uint64, expected=abi.Uint64TypeSpec()), + TypeAnnotationTest( + annotation=abi.DynamicArray[abi.Uint32], + expected=abi.DynamicArrayTypeSpec(abi.Uint32TypeSpec()), + ), + TypeAnnotationTest( + annotation=abi.DynamicArray[abi.Uint64], + expected=abi.DynamicArrayTypeSpec(abi.Uint64TypeSpec()), + ), + TypeAnnotationTest( + annotation=abi.DynamicArray[abi.DynamicArray[abi.Uint32]], + expected=abi.DynamicArrayTypeSpec( + abi.DynamicArrayTypeSpec(abi.Uint32TypeSpec()) + ), + ), + TypeAnnotationTest( + annotation=abi.DynamicArray, + expected=TypeError, + ), + TypeAnnotationTest( + annotation=abi.StaticArray[abi.Uint32, Literal[0]], + expected=abi.StaticArrayTypeSpec(abi.Uint32TypeSpec(), 0), + ), + TypeAnnotationTest( + annotation=abi.StaticArray[abi.Uint32, Literal[10]], + expected=abi.StaticArrayTypeSpec(abi.Uint32TypeSpec(), 10), + ), + TypeAnnotationTest( + annotation=abi.StaticArray[abi.Bool, Literal[500]], + expected=abi.StaticArrayTypeSpec(abi.BoolTypeSpec(), 500), + ), + TypeAnnotationTest( + annotation=abi.StaticArray[abi.Bool, Literal[-1]], + expected=TypeError, + ), + TypeAnnotationTest( + annotation=abi.StaticArray[abi.Bool, int], + expected=TypeError, + ), + TypeAnnotationTest( + annotation=abi.StaticArray, + expected=TypeError, + ), + TypeAnnotationTest( + annotation=abi.StaticArray[ + abi.StaticArray[abi.Bool, Literal[500]], Literal[5] + ], + expected=abi.StaticArrayTypeSpec( + abi.StaticArrayTypeSpec(abi.BoolTypeSpec(), 500), 5 + ), + ), + TypeAnnotationTest( + annotation=abi.DynamicArray[abi.StaticArray[abi.Bool, Literal[500]]], + expected=abi.DynamicArrayTypeSpec( + abi.StaticArrayTypeSpec(abi.BoolTypeSpec(), 500) + ), + ), + TypeAnnotationTest(annotation=abi.Tuple, expected=abi.TupleTypeSpec()), + TypeAnnotationTest(annotation=abi.Tuple0, expected=abi.TupleTypeSpec()), + TypeAnnotationTest( + annotation=abi.Tuple1[abi.Uint32], + expected=abi.TupleTypeSpec(abi.Uint32TypeSpec()), + ), + TypeAnnotationTest( + annotation=abi.Tuple1, + expected=TypeError, + ), + TypeAnnotationTest( + annotation=abi.Tuple2[abi.Uint32, abi.Uint16], + expected=abi.TupleTypeSpec(abi.Uint32TypeSpec(), abi.Uint16TypeSpec()), + ), + TypeAnnotationTest( + annotation=abi.Tuple2, + expected=TypeError, + ), + TypeAnnotationTest( + annotation=abi.Tuple3[abi.Uint32, abi.Uint16, abi.Byte], + expected=abi.TupleTypeSpec( + abi.Uint32TypeSpec(), abi.Uint16TypeSpec(), abi.ByteTypeSpec() + ), + ), + TypeAnnotationTest( + annotation=abi.Tuple3, + expected=TypeError, + ), + TypeAnnotationTest( + annotation=abi.Tuple3[ + abi.Tuple1[abi.Uint32], + abi.StaticArray[abi.Bool, Literal[55]], + abi.Tuple2[abi.Uint32, abi.Uint16], + ], + expected=abi.TupleTypeSpec( + abi.TupleTypeSpec(abi.Uint32TypeSpec()), + abi.StaticArrayTypeSpec(abi.BoolTypeSpec(), 55), + abi.TupleTypeSpec(abi.Uint32TypeSpec(), abi.Uint16TypeSpec()), + ), + ), + TypeAnnotationTest( + annotation=abi.Tuple4[abi.Uint32, abi.Uint16, abi.Byte, abi.Bool], + expected=abi.TupleTypeSpec( + abi.Uint32TypeSpec(), + abi.Uint16TypeSpec(), + abi.ByteTypeSpec(), + abi.BoolTypeSpec(), + ), + ), + TypeAnnotationTest( + annotation=abi.Tuple4, + expected=TypeError, + ), + TypeAnnotationTest( + annotation=abi.Tuple5[ + abi.Uint32, abi.Uint16, abi.Byte, abi.Bool, abi.Tuple0 + ], + expected=abi.TupleTypeSpec( + abi.Uint32TypeSpec(), + abi.Uint16TypeSpec(), + abi.ByteTypeSpec(), + abi.BoolTypeSpec(), + abi.TupleTypeSpec(), + ), + ), + TypeAnnotationTest( + annotation=abi.Tuple5, + expected=TypeError, + ), + TypeAnnotationTest( + annotation=List[abi.Uint16], + expected=TypeError, + ), + ] + + for i, test in enumerate(tests): + if not isinstance(test.expected, abi.TypeSpec): + with pytest.raises(test.expected): + type_spec_from_annotation(test.annotation) + continue + + actual = type_spec_from_annotation(test.annotation) + assert actual == test.expected, "Test at index {} failed".format(i) + + +def test_type_spec_from_annotation_is_exhaustive(): + # This test is to make sure there are no new subclasses of BaseType that type_spec_from_annotation + # is not aware of. + + subclasses = abi.BaseType.__subclasses__() + while len(subclasses) > 0: + subclass = subclasses.pop() + subclasses += subclass.__subclasses__() + + if isabstract(subclass): + # abstract class type annotations should not be supported + with pytest.raises(TypeError, match=r"^Unknown annotation origin"): + type_spec_from_annotation(subclass) + continue + + try: + # if subclass is not generic, this will succeed + type_spec_from_annotation(subclass) + except TypeError as e: + # if subclass is generic, we should get an error that is NOT "Unknown annotation origin" + assert "Unknown annotation origin" not in str(e)