diff --git a/neo3/__init__.py b/neo3/__init__.py index cf558d90..d2e688c4 100644 --- a/neo3/__init__.py +++ b/neo3/__init__.py @@ -85,6 +85,9 @@ class Settings(IndexableNamespace): } } }, + 'policy': { + 'max_tx_per_block': 512 + }, 'native_contract_activation': {} } diff --git a/neo3/blockchain.py b/neo3/blockchain.py index bed6693f..091bf746 100644 --- a/neo3/blockchain.py +++ b/neo3/blockchain.py @@ -36,20 +36,19 @@ def height(self): @staticmethod def _create_genesis_block() -> payloads.Block: - b = payloads.Block( + h = payloads.Header( version=0, prev_hash=types.UInt256.zero(), timestamp=int(datetime(2016, 7, 15, 15, 8, 21, 0, timezone.utc).timestamp() * 1000), index=0, + primary_index=0, next_consensus=contracts.Contract.get_consensus_address(settings.standby_validators), witness=payloads.Witness( invocation_script=b'', verification_script=b'\x11' # (OpCode.PUSH1) ), - consensus_data=payloads.ConsensusData(primary_index=0, nonce=2083236893), - transactions=[] ) - return b + return payloads.Block(header=h, transactions=[]) def persist(self, block: payloads.Block): with self.backend.get_snapshotview() as snapshot: @@ -77,6 +76,7 @@ def persist(self, block: payloads.Block): cloned_snapshot.commit() else: cloned_snapshot = snapshot.clone() + engine = contracts.ApplicationEngine(contracts.TriggerType.POST_PERSIST, None, snapshot, 0, True) # type: ignore engine.load_script(vm.Script(self.native_postpersist_script)) @@ -90,6 +90,8 @@ def persist(self, block: payloads.Block): Therefore we wait with persisting the block until here """ snapshot.blocks.put(block) + snapshot.best_block_height = block.index + snapshot.commit() self._current_snapshot = snapshot msgrouter.on_block_persisted(block) diff --git a/neo3/contracts/__init__.py b/neo3/contracts/__init__.py index 5aa1c01d..100a1db9 100644 --- a/neo3/contracts/__init__.py +++ b/neo3/contracts/__init__.py @@ -1,5 +1,6 @@ from __future__ import annotations import hashlib +import typing from .callflags import CallFlags from .contracttypes import (TriggerType) from .descriptor import (ContractPermissionDescriptor) @@ -30,13 +31,21 @@ NonFungibleToken, NFTState, NameService, - LedgerContract) + LedgerContract, + CryptoContract, + StdLibContract) def syscall_name_to_int(name: str) -> int: return int.from_bytes(hashlib.sha256(name.encode()).digest()[:4], 'little', signed=False) +def validate_type(obj: object, type_: typing.Type): + if type(obj) != type_: + raise ValueError(f"Expected type '{type_}' , got '{type(obj)}' instead") + return obj + + __all__ = ['ContractParameterType', 'TriggerType', 'ContractMethodDescriptor', @@ -58,4 +67,6 @@ def syscall_name_to_int(name: str) -> int: 'LedgerContract', 'NEF', 'MethodToken', - 'FindOptions'] + 'FindOptions', + 'CryptoContract', + 'StdLibContract'] diff --git a/neo3/contracts/abi.py b/neo3/contracts/abi.py index 7c5101bf..8b956833 100644 --- a/neo3/contracts/abi.py +++ b/neo3/contracts/abi.py @@ -2,7 +2,7 @@ import enum from typing import List, Optional, Type, Union, cast from enum import IntEnum -from neo3.core import types, IJson, IInteroperable, serialization +from neo3.core import types, IJson, IInteroperable, serialization, cryptography from neo3 import contracts, vm @@ -18,13 +18,13 @@ class ContractParameterType(IntEnum): SIGNATURE = 0x17, ARRAY = 0x20, MAP = 0x22, - INTEROP_INTERFACE = 0x30, + INTEROPINTERFACE = 0x30, VOID = 0xff def PascalCase(self) -> str: if self == ContractParameterType.BYTEARRAY: return "ByteArray" - elif self == ContractParameterType.INTEROP_INTERFACE: + elif self == ContractParameterType.INTEROPINTERFACE: return "InteropInterface" elif self == ContractParameterType.PUBLICKEY: return "PublicKey" @@ -41,6 +41,8 @@ def from_type(cls, class_type: Optional[Type[object]]) -> ContractParameterType: return ContractParameterType.INTEGER elif class_type in [bytes, bytearray, vm.BufferStackItem, vm.ByteStringStackItem]: return ContractParameterType.BYTEARRAY + elif class_type == cryptography.ECPoint: + return ContractParameterType.PUBLICKEY elif hasattr(class_type, '__origin__'): if class_type.__origin__ == list: # type: ignore return ContractParameterType.ARRAY @@ -64,10 +66,12 @@ def from_type(cls, class_type: Optional[Type[object]]) -> ContractParameterType: return ContractParameterType.ARRAY elif issubclass(class_type, IInteroperable): return ContractParameterType.ARRAY + elif class_type == vm.StackItem: + return ContractParameterType.ANY elif issubclass(class_type, enum.Enum): return ContractParameterType.INTEGER else: - return ContractParameterType.ANY + return ContractParameterType.INTEROPINTERFACE class ContractParameterDefinition(IJson): @@ -112,8 +116,8 @@ def from_json(cls, json: dict) -> ContractParameterDefinition: ValueError: if the type is VOID. """ c = cls( - name=json['name'], - type=contracts.ContractParameterType[json['type'].upper()] + name=contracts.validate_type(json['name'], str), + type=contracts.ContractParameterType[contracts.validate_type(json['type'], str).upper()] ) if c.name is None or len(c.name) == 0: raise ValueError("Format error - invalid 'name'") @@ -169,7 +173,7 @@ def from_json(cls, json: dict) -> ContractEventDescriptor: ValueError: if the 'name' property has an incorrect format """ c = cls( - name=json['name'], + name=contracts.validate_type(json['name'], str), parameters=list(map(lambda p: ContractParameterDefinition.from_json(p), json['parameters'])) ) if c.name is None or len(c.name) == 0: @@ -241,11 +245,11 @@ def from_json(cls, json: dict) -> ContractMethodDescriptor: ValueError: if the offset is negative. """ c = cls( - name=json['name'], - offset=json['offset'], + name=contracts.validate_type(json['name'], str), + offset=contracts.validate_type(json['offset'], int), parameters=list(map(lambda p: contracts.ContractParameterDefinition.from_json(p), json['parameters'])), - return_type=contracts.ContractParameterType[json['returntype'].upper()], - safe=json['safe'] + return_type=contracts.ContractParameterType[contracts.validate_type(json['returntype'], str).upper()], + safe=contracts.validate_type(json['safe'], bool) ) if c.name is None or len(c.name) == 0: raise ValueError("Format error - invalid 'name'") diff --git a/neo3/contracts/applicationengine.py b/neo3/contracts/applicationengine.py index 04bdffef..0789ac43 100644 --- a/neo3/contracts/applicationengine.py +++ b/neo3/contracts/applicationengine.py @@ -48,10 +48,45 @@ def __init__(self, self.exec_fee_factor = contracts.PolicyContract().get_exec_fee_factor(snapshot) self.STORAGE_PRICE = contracts.PolicyContract().get_storage_price(snapshot) - self._context_state: Dict[vm.ExecutionContext, contracts.ContractState] = {} from neo3.contracts import interop self.interop = interop + @property + def current_scripthash(self) -> types.UInt160: + """ + Get the script hash of the current executing smart contract + + Note: a smart contract can call other smart contracts. + """ + if len(self.current_context.scripthash_bytes) == 0: + return to_script_hash(self.current_context.script._value) + return types.UInt160(self.current_context.scripthash_bytes) + + @property + def calling_scripthash(self) -> types.UInt160: + """ + Get the script hash of the smart contract that called the current executing smart contract. + + Note: a smart contract can call other smart contracts. + + Raises: + ValueError: if the current executing contract has not been called by another contract. + """ + if len(self.current_context.calling_scripthash_bytes) == 0: + raise ValueError("Cannot retrieve calling script_hash - current context has not yet been called") + return types.UInt160(self.current_context.calling_scripthash_bytes) + + @property + def entry_scripthash(self) -> types.UInt160: + """ + Get the script hash of the first smart contract loaded into the engine + + Note: a smart contract can call other smart contracts. + """ + if len(self.entry_context.scripthash_bytes) == 0: + return to_script_hash(self.entry_context.script._value) + return types.UInt160(self.entry_context.scripthash_bytes) + def checkwitness(self, hash_: types.UInt160) -> bool: """ Check if the hash is a valid witness for the engines script_container @@ -93,9 +128,7 @@ def checkwitness(self, hash_: types.UInt160) -> bool: return True if payloads.WitnessScope.CUSTOM_GROUPS in signer.scope: - if contracts.CallFlags.READ_STATES not in \ - contracts.CallFlags(self.current_context.call_flags): - raise ValueError("Context requires callflags ALLOW_STATES") + self._validate_callflags(contracts.CallFlags.READ_STATES) contract = contracts.ManagementContract().get_contract(self.snapshot, self.calling_scripthash) if contract is None: @@ -105,116 +138,22 @@ def checkwitness(self, hash_: types.UInt160) -> bool: return True return False - if contracts.CallFlags.READ_STATES not in \ - contracts.CallFlags(self.current_context.call_flags): - raise ValueError("Context requires callflags ALLOW_STATES") + self._validate_callflags(contracts.CallFlags.READ_STATES) # for other IVerifiable types like Block hashes_for_verifying = self.script_container.get_script_hashes_for_verifying(self.snapshot) return hash_ in hashes_for_verifying - def _stackitem_to_native(self, stack_item: vm.StackItem, target_type: Type[object]): - # checks for type annotations like `List[bytes]` (similar to byte[][] in C#) - if hasattr(target_type, '__origin__') and target_type.__origin__ == list: # type: ignore - element_type = target_type.__args__[0] # type: ignore - array = [] - if isinstance(stack_item, vm.ArrayStackItem): - for e in stack_item: - array.append(self._convert(e, element_type)) - else: - count = stack_item.to_biginteger() - if count > self.MAX_STACK_SIZE: - raise ValueError - - # mypy bug: https://github.com/python/mypy/issues/9755 - for e in range(count): # type: ignore - array.append(self._convert(self.pop(), element_type)) - return array - else: - try: - return self._convert(stack_item, target_type) - except ValueError: - if isinstance(stack_item, vm.InteropStackItem): - return stack_item.get_object() - else: - raise - - def _convert(self, stack_item: vm.StackItem, class_type: Type[object]) -> object: - """ - convert VM type to native - """ - if class_type in [vm.StackItem, vm.PointerStackItem, vm.ArrayStackItem, vm.InteropStackItem]: - return stack_item - elif class_type == int: - return int(stack_item.to_biginteger()) - elif class_type == vm.BigInteger: - return stack_item.to_biginteger() - # mypy bug? https://github.com/python/mypy/issues/9756 - elif class_type in [bytes, bytearray]: # type: ignore - return stack_item.to_array() - elif class_type == bool: - return stack_item.to_boolean() - elif class_type == types.UInt160: - return types.UInt160(data=stack_item.to_array()) - elif class_type == types.UInt256: - return types.UInt256(data=stack_item.to_array()) - elif class_type == str: - if stack_item == vm.NullStackItem(): - return "" - return stack_item.to_array().decode() - elif class_type == cryptography.ECPoint: - return cryptography.ECPoint.deserialize_from_bytes(stack_item.to_array()) - elif issubclass(class_type, enum.Enum): - if stack_item.get_type() == vm.StackItemType.INTEGER: - stack_item = cast(vm.IntegerStackItem, stack_item) - # mypy seems to have trouble understanding types that support __int__ - return class_type(int(stack_item)) # type: ignore - elif stack_item.get_type() == vm.StackItemType.BYTESTRING: - stack_item = cast(vm.ByteStringStackItem, stack_item) - return class_type(int(stack_item.to_biginteger())) # type: ignore - raise ValueError(f"Unknown class type, don't know how to convert: {class_type}") - - def _native_to_stackitem(self, value, native_type) -> vm.StackItem: - """ - Convert native type to VM type - - Note: order of checking matters. - e.g. a Transaction should be treated as IInteropable, while its also ISerializable - """ - if isinstance(value, vm.StackItem): - return value - elif value is None: - return vm.NullStackItem() - elif native_type in [int, vm.BigInteger]: - return vm.IntegerStackItem(value) - elif issubclass(native_type, IInteroperable): - value_ = cast(IInteroperable, value) - return value_.to_stack_item(self.reference_counter) - elif issubclass(native_type, serialization.ISerializable): - serializable_value = cast(serialization.ISerializable, value) - return vm.ByteStringStackItem(serializable_value.to_array()) - # mypy bug? https://github.com/python/mypy/issues/9756 - elif native_type in [bytes, bytearray]: # type: ignore - return vm.ByteStringStackItem(value) - elif native_type == str: - return vm.ByteStringStackItem(bytes(value, 'utf-8')) - elif native_type == bool: - return vm.BooleanStackItem(value) - elif issubclass(native_type, (enum.IntFlag, enum.IntEnum)): - return self._native_to_stackitem(value.value, int) - elif hasattr(native_type, '__origin__') and native_type.__origin__ == Union: # type: ignore - # handle typing.Optional[type], Optional is an alias for Union[x, None] - # only support specifying 1 type - if len(native_type.__args__) != 2: - raise ValueError(f"Don't know how to convert native type {native_type} to stackitem") - for i in native_type.__args__: - if i is None: - continue - return self._native_to_stackitem(value, native_type) - else: - raise ValueError # shouldn't be possible, but silences mypy - else: - return vm.StackItem.from_interface(value) + def call_from_native(self, + calling_scripthash: types.UInt160, + hash_: types.UInt160, + method: str, + args: List[vm.StackItem]) -> None: + ctx = self.current_context + self._contract_call_internal(hash_, method, contracts.CallFlags.ALL, False, args) + self.current_context.calling_scripthash_bytes = calling_scripthash.to_array() + while self.current_context != ctx: + self.step_out() def on_syscall(self, method_id: int) -> Any: """ @@ -235,9 +174,7 @@ def on_syscall(self, method_id: int) -> Any: if descriptor is None: raise KeyError(f"Requested interop {method_id} is not valid") - if descriptor.required_call_flags not in contracts.CallFlags(self.current_context.call_flags): - raise ValueError(f"Cannot call {descriptor.method} with {self.current_context.call_flags}") - + self._validate_callflags(descriptor.required_call_flags) self.add_gas(descriptor.price * self.exec_fee_factor) parameters = [] @@ -269,42 +206,6 @@ def invoke_syscall_by_name(self, method: str) -> Any: """ return self.on_syscall(contracts.syscall_name_to_int(method)) - @property - def current_scripthash(self) -> types.UInt160: - """ - Get the script hash of the current executing smart contract - - Note: a smart contract can call other smart contracts. - """ - if len(self.current_context.scripthash_bytes) == 0: - return to_script_hash(self.current_context.script._value) - return types.UInt160(self.current_context.scripthash_bytes) - - @property - def calling_scripthash(self) -> types.UInt160: - """ - Get the script hash of the smart contract that called the current executing smart contract. - - Note: a smart contract can call other smart contracts. - - Raises: - ValueError: if the current executing contract has not been called by another contract. - """ - if len(self.current_context.calling_scripthash_bytes) == 0: - raise ValueError("Cannot retrieve calling script_hash - current context has not yet been called") - return types.UInt160(self.current_context.calling_scripthash_bytes) - - @property - def entry_scripthash(self) -> types.UInt160: - """ - Get the script hash of the first smart contract loaded into the engine - - Note: a smart contract can call other smart contracts. - """ - if len(self.entry_context.scripthash_bytes) == 0: - return to_script_hash(self.entry_context.script._value) - return types.UInt160(self.entry_context.scripthash_bytes) - def get_invocation_counter(self) -> int: """ Get the number of times the current contract has been called during this execute() run. @@ -320,35 +221,14 @@ def get_invocation_counter(self) -> int: counter = 1 return counter - def load_script_with_callflags(self, - script: vm.Script, - call_flags: contracts.CallFlags, - initial_position: int = 0, - rvcount: int = -1, - contract_state: Optional[contracts.ContractState] = None): - context = super(ApplicationEngine, self).load_script(script, rvcount, initial_position) - context.call_flags = int(call_flags) - if contract_state is not None: - self._context_state.update({context: contract_state}) - return context - - def call_from_native(self, - calling_scripthash: types.UInt160, - hash_: types.UInt160, - method: str, - args: List[vm.StackItem]) -> None: - ctx = self.current_context - self._contract_call_internal(hash_, method, contracts.CallFlags.ALL, False, args) - self.current_context.calling_scripthash_bytes = calling_scripthash.to_array() - while self.current_context != ctx: - self.step_out() + def load_context(self, context: vm.ExecutionContext) -> None: + if len(context.scripthash_bytes) == 0: + context.scripthash_bytes = to_script_hash(context.script._value).to_array() + contract_hash = types.UInt160(data=context.scripthash_bytes) + counter = self._invocation_counter.get(contract_hash, 0) + self._invocation_counter.update({contract_hash: counter + 1}) - def step_out(self) -> None: - c = len(self.invocation_stack) - while self.state != vm.VMState.HALT and self.state != vm.VMState.FAULT and len(self.invocation_stack) >= c: - self._execute_next() - if self.state == vm.VMState.FAULT: - raise ValueError(f"Call from native contract failed: {self.exception_message}") + super(ApplicationEngine, self).load_context(context) def load_contract(self, contract: contracts.ContractState, @@ -359,11 +239,11 @@ def load_contract(self, context = self.load_script_with_callflags(vm.Script(contract.script), flags, method_descriptor.offset, - rvcount, - contract) + rvcount) # configure state context.call_flags = int(flags) context.scripthash_bytes = contract.hash.to_array() + context.nef_bytes = contract.nef.to_array() init = contract.manifest.abi.get_method("_initialize", 0) if init is not None: @@ -371,13 +251,14 @@ def load_contract(self, return context def load_token(self, token_id: int) -> vm.ExecutionContext: - contract = self._context_state.get(self.current_context, None) - if contract is None: - raise ValueError("Current context has no contract state") - if token_id >= len(contract.nef.tokens): + self._validate_callflags(contracts.CallFlags.READ_STATES | contracts.CallFlags.ALLOW_CALL) + if len(self.current_context.nef_bytes) == 0: + raise ValueError("Current context has no NEF state") + nef = contracts.NEF.deserialize_from_bytes(self.current_context.nef_bytes) + if token_id >= len(nef.tokens): raise ValueError("token_id exceeds available tokens") - token = contract.nef.tokens[token_id] + token = nef.tokens[token_id] if token.parameters_count > len(self.current_context.evaluation_stack): raise ValueError("Token count exceeds available paremeters on evaluation stack") args: List[vm.StackItem] = [] @@ -385,14 +266,21 @@ def load_token(self, token_id: int) -> vm.ExecutionContext: args.append(self.pop()) return self._contract_call_internal(token.hash, token.method, token.call_flags, token.has_return_value, args) - def call_native(self, name: str) -> None: - contract = contracts.ManagementContract().get_contract_by_name(name) - if contract is None or contract.active_block_index > self.snapshot.persisting_block.index: - raise ValueError - contract.invoke(self) + def load_script_with_callflags(self, + script: vm.Script, + call_flags: contracts.CallFlags, + initial_position: int = 0, + rvcount: int = -1): + context = super(ApplicationEngine, self).load_script(script, rvcount, initial_position) + context.call_flags = int(call_flags) + return context - def context_unloaded(self, context: vm.ExecutionContext) -> None: - self._context_state.pop(context, None) + def step_out(self) -> None: + c = len(self.invocation_stack) + while self.state != vm.VMState.HALT and self.state != vm.VMState.FAULT and len(self.invocation_stack) >= c: + self._execute_next() + if self.state == vm.VMState.FAULT: + raise ValueError(f"Call from native contract failed: {self.exception_message}") def _contract_call_internal(self, contract_hash: types.UInt160, @@ -410,15 +298,6 @@ def _contract_call_internal(self, f"target contract") return self._contract_call_internal2(target_contract, method_descriptor, flags, has_return_value, args) - def load_context(self, context: vm.ExecutionContext) -> None: - if len(context.scripthash_bytes) == 0: - context.scripthash_bytes = to_script_hash(context.script._value).to_array() - contract_hash = types.UInt160(data=context.scripthash_bytes) - counter = self._invocation_counter.get(contract_hash, 0) - self._invocation_counter.update({contract_hash: counter + 1}) - - super(ApplicationEngine, self).load_context(context) - def _contract_call_internal2(self, target_contract: contracts.ContractState, method_descriptor: contracts.ContractMethodDescriptor, @@ -426,7 +305,7 @@ def _contract_call_internal2(self, has_return_value: bool, args: List[vm.StackItem]): if method_descriptor.safe: - flags &= ~contracts.CallFlags.WRITE_STATES + flags &= ~(contracts.CallFlags.WRITE_STATES | contracts.CallFlags.ALLOW_NOTIFY) else: current_contract = contracts.ManagementContract().get_contract(self.snapshot, self.current_scripthash) if current_contract and not current_contract.can_call(target_contract, method_descriptor.name): @@ -460,6 +339,111 @@ def _contract_call_internal2(self, for item in reversed(args): context_new.evaluation_stack.push(item) - if contracts.NativeContract.is_native(target_contract.hash): - context_new.evaluation_stack.push(vm.ByteStringStackItem(method_descriptor.name.encode('utf-8'))) return context_new + + def _stackitem_to_native(self, stack_item: vm.StackItem, target_type: Type[object]): + # checks for type annotations like `List[bytes]` (similar to byte[][] in C#) + if hasattr(target_type, '__origin__') and target_type.__origin__ == list: # type: ignore + element_type = target_type.__args__[0] # type: ignore + array = [] + if isinstance(stack_item, vm.ArrayStackItem): + for e in stack_item: + array.append(self._convert(e, element_type)) + else: + count = stack_item.to_biginteger() + if count > self.MAX_STACK_SIZE: + raise ValueError + + # mypy bug: https://github.com/python/mypy/issues/9755 + for e in range(count): # type: ignore + array.append(self._convert(self.pop(), element_type)) + return array + else: + try: + return self._convert(stack_item, target_type) + except ValueError: + if isinstance(stack_item, vm.InteropStackItem): + return stack_item.get_object() + else: + raise + + def _validate_callflags(self, callflags: contracts.CallFlags) -> None: + if callflags not in contracts.CallFlags(self.current_context.call_flags): + raise ValueError(f"Context requires callflags {callflags}") + + def _convert(self, stack_item: vm.StackItem, class_type: Type[object]) -> object: + """ + convert VM type to native + """ + if class_type in [vm.StackItem, vm.PointerStackItem, vm.ArrayStackItem, vm.InteropStackItem]: + return stack_item + elif class_type == int: + return int(stack_item.to_biginteger()) + elif class_type == vm.BigInteger: + return stack_item.to_biginteger() + # mypy bug? https://github.com/python/mypy/issues/9756 + elif class_type in [bytes, bytearray]: # type: ignore + return stack_item.to_array() + elif class_type == bool: + return stack_item.to_boolean() + elif class_type == types.UInt160: + return types.UInt160(data=stack_item.to_array()) + elif class_type == types.UInt256: + return types.UInt256(data=stack_item.to_array()) + elif class_type == str: + if stack_item == vm.NullStackItem(): + return "" + return stack_item.to_array().decode() + elif class_type == cryptography.ECPoint: + return cryptography.ECPoint.deserialize_from_bytes(stack_item.to_array()) + elif issubclass(class_type, enum.Enum): + if stack_item.get_type() == vm.StackItemType.INTEGER: + stack_item = cast(vm.IntegerStackItem, stack_item) + # mypy seems to have trouble understanding types that support __int__ + return class_type(int(stack_item)) # type: ignore + elif stack_item.get_type() == vm.StackItemType.BYTESTRING: + stack_item = cast(vm.ByteStringStackItem, stack_item) + return class_type(int(stack_item.to_biginteger())) # type: ignore + raise ValueError(f"Unknown class type, don't know how to convert: {class_type}") + + def _native_to_stackitem(self, value, native_type) -> vm.StackItem: + """ + Convert native type to VM type + + Note: order of checking matters. + e.g. a Transaction should be treated as IInteropable, while its also ISerializable + """ + if isinstance(value, vm.StackItem): + return value + elif value is None: + return vm.NullStackItem() + elif native_type in [int, vm.BigInteger]: + return vm.IntegerStackItem(value) + elif issubclass(native_type, IInteroperable): + value_ = cast(IInteroperable, value) + return value_.to_stack_item(self.reference_counter) + elif issubclass(native_type, serialization.ISerializable): + serializable_value = cast(serialization.ISerializable, value) + return vm.ByteStringStackItem(serializable_value.to_array()) + # mypy bug? https://github.com/python/mypy/issues/9756 + elif native_type in [bytes, bytearray]: # type: ignore + return vm.ByteStringStackItem(value) + elif native_type == str: + return vm.ByteStringStackItem(bytes(value, 'utf-8')) + elif native_type == bool: + return vm.BooleanStackItem(value) + elif issubclass(native_type, (enum.IntFlag, enum.IntEnum)): + return self._native_to_stackitem(value.value, int) + elif hasattr(native_type, '__origin__') and native_type.__origin__ == Union: # type: ignore + # handle typing.Optional[type], Optional is an alias for Union[x, None] + # only support specifying 1 type + if len(native_type.__args__) != 2: + raise ValueError(f"Don't know how to convert native type {native_type} to stackitem") + for i in native_type.__args__: + if i is None: + continue + return self._native_to_stackitem(value, native_type) + else: + raise ValueError # shouldn't be possible, but silences mypy + else: + return vm.StackItem.from_interface(value) diff --git a/neo3/contracts/binaryserializer.py b/neo3/contracts/binaryserializer.py index cb3e07a9..8e10f935 100644 --- a/neo3/contracts/binaryserializer.py +++ b/neo3/contracts/binaryserializer.py @@ -65,7 +65,6 @@ def serialize(stack_item: vm.StackItem, max_size: int) -> bytes: @staticmethod def deserialize(data: bytes, max_size: int, - max_item_size: int, reference_counter: vm.ReferenceCounter) -> vm.StackItem: """ Deserialize data into a stack item. @@ -73,7 +72,6 @@ def deserialize(data: bytes, Args: data: byte array of a serialized stack item. max_size: data reading limit for Array, Struct and Map types. - max_item_size: data reading limit for ByteString or Buffer types. reference_counter: a valid reference counter instance. Get's passed into reference stack items. """ if len(data) == 0: @@ -96,9 +94,9 @@ def deserialize(data: bytes, )) ) elif item_type == vm.StackItemType.BYTESTRING: - deserialized.append(vm.ByteStringStackItem(reader.read_var_bytes(max_item_size))) + deserialized.append(vm.ByteStringStackItem(reader.read_var_bytes(len(data)))) elif item_type == vm.StackItemType.BUFFER: - deserialized.append(vm.BufferStackItem(reader.read_var_bytes(max_item_size))) + deserialized.append(vm.BufferStackItem(reader.read_var_bytes(len(data)))) elif item_type in [vm.StackItemType.ARRAY, vm.StackItemType.STRUCT]: count = reader.read_var_int(max_size) deserialized.append(PlaceHolder(item_type, count)) diff --git a/neo3/contracts/contract.py b/neo3/contracts/contract.py index 34163a0d..174bc7d5 100644 --- a/neo3/contracts/contract.py +++ b/neo3/contracts/contract.py @@ -66,8 +66,7 @@ def create_multisig_redeemscript(m: int, public_keys: List[cryptography.ECPoint] sb.emit_push(key.encode_point(True)) sb.emit_push(len(public_keys)) - sb.emit(vm.OpCode.PUSHNULL) - sb.emit_syscall(contracts.syscall_name_to_int("Neo.Crypto.CheckMultisigWithECDsaSecp256r1")) + sb.emit_syscall(contracts.syscall_name_to_int("Neo.Crypto.CheckMultisig")) return sb.to_array() @classmethod @@ -95,8 +94,7 @@ def create_signature_redeemscript(public_key: cryptography.ECPoint) -> bytes: """ sb = vm.ScriptBuilder() sb.emit_push(public_key.encode_point(True)) - sb.emit(vm.OpCode.PUSHNULL) - sb.emit_syscall(contracts.syscall_name_to_int("Neo.Crypto.VerifyWithECDsaSecp256r1")) + sb.emit_syscall(contracts.syscall_name_to_int("Neo.Crypto.CheckSig")) return sb.to_array() @staticmethod @@ -107,15 +105,14 @@ def is_signature_contract(script: bytes) -> bool: Args: script: contract script. """ - if len(script) != 41: + if len(script) != 40: return False if (script[0] != vm.OpCode.PUSHDATA1 or script[1] != 33 - or script[35] != vm.OpCode.PUSHNULL - or script[36] != vm.OpCode.SYSCALL - or int.from_bytes(script[37:41], 'little') != contracts.syscall_name_to_int( - "Neo.Crypto.VerifyWithECDsaSecp256r1")): + or script[35] != vm.OpCode.SYSCALL + or int.from_bytes(script[36:40], 'little') != contracts.syscall_name_to_int( + "Neo.Crypto.CheckSig")): return False return True @@ -129,7 +126,7 @@ def is_multisig_contract(script: bytes) -> bool: """ len_script = len(script) - if len_script < 43: + if len_script < 42: return False # read signature length, which is encoded as variable_length @@ -180,17 +177,15 @@ def is_multisig_contract(script: bytes) -> bool: else: return False - if len_script != i + 6: + if len_script != i + 5: return False - if script[i] != int(vm.OpCode.PUSHNULL): + if script[i] != int(vm.OpCode.SYSCALL): return False - if script[i + 1] != int(vm.OpCode.SYSCALL): - return False - i += 2 + i += 1 syscall_num = int.from_bytes(script[i:i + 4], 'little') - if syscall_num != contracts.syscall_name_to_int("Neo.Crypto.CheckMultisigWithECDsaSecp256r1"): + if syscall_num != contracts.syscall_name_to_int("Neo.Crypto.CheckMultisig"): return False return True diff --git a/neo3/contracts/descriptor.py b/neo3/contracts/descriptor.py index 61cd1dc6..348c7e59 100644 --- a/neo3/contracts/descriptor.py +++ b/neo3/contracts/descriptor.py @@ -53,7 +53,7 @@ def to_json(self) -> dict: """ # NEO C# deviates here. They return a string if self.contract_hash: - val = str(self.contract_hash) + val = "0x" + str(self.contract_hash) elif self.group: val = str(self.group) else: @@ -77,8 +77,8 @@ def from_json(cls, json: dict) -> ContractPermissionDescriptor: if value is None: raise ValueError(f"Invalid JSON - Cannot deduce permission type from None") - if len(value) == 40: - return cls(contract_hash=types.UInt160.from_string(value)) + if len(value) == 42: + return cls(contract_hash=types.UInt160.from_string(value[2:])) if len(value) == 66: ecpoint = cryptography.ECPoint.deserialize_from_bytes(binascii.unhexlify(value)) return cls(group=ecpoint) diff --git a/neo3/contracts/interop/__init__.py b/neo3/contracts/interop/__init__.py index 07600393..7c59431a 100644 --- a/neo3/contracts/interop/__init__.py +++ b/neo3/contracts/interop/__init__.py @@ -5,10 +5,8 @@ from .decorator import register # the __name__ imports are just to trigger module loading, # which in turn executes the decorators to register the SYSCALLS -from .binary import __name__ from .contract import __name__ from .crypto import __name__ -from .json import __name__ from .enumerator import IIterator, StorageIterator, ArrayWrapper, ByteArrayWrapper from .runtime import __name__ -from .storage import _storage_put_internal, MAX_STORAGE_VALUE_SIZE, MAX_STORAGE_KEY_SIZE +from .storage import storage_put, MAX_STORAGE_VALUE_SIZE, MAX_STORAGE_KEY_SIZE diff --git a/neo3/contracts/interop/binary.py b/neo3/contracts/interop/binary.py deleted file mode 100644 index d99e0f22..00000000 --- a/neo3/contracts/interop/binary.py +++ /dev/null @@ -1,56 +0,0 @@ -from __future__ import annotations -import base64 -import base58 # type: ignore -from neo3 import vm, contracts -from neo3.contracts.interop import register - - -@register("System.Binary.Serialize", 1 << 12, contracts.CallFlags.NONE) -def binary_serialize(engine: contracts.ApplicationEngine, stack_item: vm.StackItem) -> bytes: - return contracts.BinarySerializer.serialize(stack_item, engine.MAX_ITEM_SIZE) - - -@register("System.Binary.Deserialize", 1 << 14, contracts.CallFlags.NONE) -def binary_derialize(engine: contracts.ApplicationEngine, data: bytes) -> vm.StackItem: - return contracts.BinarySerializer.deserialize(data, - engine.MAX_STACK_SIZE, - engine.MAX_ITEM_SIZE, - engine.reference_counter) - - -@register("System.Binary.Base64Encode", 1 << 12, contracts.CallFlags.NONE) -def base64_encode(engine: contracts.ApplicationEngine, data: bytes) -> str: - return base64.b64encode(data).decode() - - -@register("System.Binary.Base64Decode", 1 << 12, contracts.CallFlags.NONE) -def base64_decode(engine: contracts.ApplicationEngine, data: bytes) -> bytes: - return base64.b64decode(data) - - -@register("System.Binary.Base58Encode", 1 << 12, contracts.CallFlags.NONE) -def base58_encode(engine: contracts.ApplicationEngine, data: bytes) -> str: - return base58.b58encode(data).decode() - - -@register("System.Binary.Base58Decode", 1 << 12, contracts.CallFlags.NONE) -def base58_decode(engine: contracts.ApplicationEngine, data: bytes) -> bytes: - return base58.b58decode(data) - - -@register("System.Binary.Itoa", 1 << 12, contracts.CallFlags.NONE) -def do_itoa(engine: contracts.ApplicationEngine, value: vm.BigInteger, base: int) -> str: - if base == 10: - return str(value) - elif base == 16: - return hex(int(value))[2:] - else: - raise ValueError("Invalid base specified") - - -@register("System.Binary.Atoi", 1 << 12, contracts.CallFlags.NONE) -def do_atoi(engine: contracts.ApplicationEngine, value: str, base: int) -> int: - if base != 10 and base != 16: - raise ValueError("Invalid base specified") - else: - return int(value, base) diff --git a/neo3/contracts/interop/contract.py b/neo3/contracts/interop/contract.py index ffd3d5e6..65517566 100644 --- a/neo3/contracts/interop/contract.py +++ b/neo3/contracts/interop/contract.py @@ -7,7 +7,7 @@ from neo3.contracts.interop import register -@register("System.Contract.Call", 1 << 15, contracts.CallFlags.ALLOW_CALL) +@register("System.Contract.Call", 1 << 15, contracts.CallFlags.READ_STATES | contracts.CallFlags.ALLOW_CALL) def contract_call(engine: contracts.ApplicationEngine, contract_hash: types.UInt160, method: str, @@ -31,21 +31,6 @@ def contract_call(engine: contracts.ApplicationEngine, engine._contract_call_internal2(target_contract, method_descriptor, call_flags, has_return_value, list(args)) -@register("System.Contract.IsStandard", 1 << 10, contracts.CallFlags.READ_STATES) -def contract_is_standard(engine: contracts.ApplicationEngine, hash_: types.UInt160) -> bool: - contract = contracts.ManagementContract().get_contract(engine.snapshot, hash_) - if contract: - return (contracts.Contract.is_signature_contract(contract.script) - or contracts.Contract.is_multisig_contract(contract.script)) - - if isinstance(engine.script_container, payloads.Transaction): - for witness in engine.script_container.witnesses: - if witness.script_hash() == hash_: - return contracts.Contract.is_signature_contract(witness.verification_script) - - return False - - @register("System.Contract.GetCallFlags", 1 << 10, contracts.CallFlags.NONE) def get_callflags(engine: contracts.ApplicationEngine) -> contracts.CallFlags: return contracts.CallFlags(engine.current_context.call_flags) @@ -57,7 +42,14 @@ def contract_create_standard_account(engine: contracts.ApplicationEngine, return to_script_hash(contracts.Contract.create_signature_redeemscript(public_key)) -@register("System.Contract.NativeOnPersist", 0, contracts.CallFlags.WRITE_STATES) +@register("System.Contract.CreateMultisigAccount", 1 << 8, contracts.CallFlags.NONE) +def contract_create_multisigaccount(engine: contracts.ApplicationEngine, + m: int, + public_keys: List[cryptography.ECPoint]) -> types.UInt160: + return to_script_hash(contracts.Contract.create_multisig_redeemscript(m, public_keys)) + + +@register("System.Contract.NativeOnPersist", 0, contracts.CallFlags.STATES) def native_on_persist(engine: contracts.ApplicationEngine) -> None: if engine.trigger != contracts.TriggerType.ON_PERSIST: raise SystemError() @@ -70,7 +62,7 @@ def native_on_persist(engine: contracts.ApplicationEngine) -> None: contract.on_persist(engine) -@register("System.Contract.NativePostPersist", 0, contracts.CallFlags.WRITE_STATES) +@register("System.Contract.NativePostPersist", 0, contracts.CallFlags.STATES) def native_post_persist(engine: contracts.ApplicationEngine) -> None: if engine.trigger != contracts.TriggerType.POST_PERSIST: raise SystemError() @@ -80,11 +72,10 @@ def native_post_persist(engine: contracts.ApplicationEngine) -> None: @register("System.Contract.CallNative", 0, contracts.CallFlags.NONE) -def call_native(engine: contracts.ApplicationEngine, contract_id: int) -> None: - contract = contracts.NativeContract.get_contract_by_id(contract_id) +def call_native(engine: contracts.ApplicationEngine, version: int) -> None: + contract = contracts.NativeContract.get_contract_by_hash(engine.current_scripthash) if contract is None: - raise ValueError(f"Can't find native contract with id {contract_id}") - - if contract.active_block_index > engine.snapshot.best_block_height: - raise ValueError(f"Native contract is not active until blockheight {contract.active_block_index}") - contract.invoke(engine) + raise ValueError(f"It is not allowed to use \"System.Contract.CallNative\" directly") + if contract.active_block_index > engine.snapshot.persisting_block.index: + raise ValueError(f"The native contract {contract.service_name()} is not active") + contract.invoke(engine, version) diff --git a/neo3/contracts/interop/crypto.py b/neo3/contracts/interop/crypto.py index 525614bc..7c083003 100644 --- a/neo3/contracts/interop/crypto.py +++ b/neo3/contracts/interop/crypto.py @@ -1,61 +1,27 @@ from __future__ import annotations -import hashlib from neo3 import vm, contracts, settings -from neo3.network import payloads from neo3.core import cryptography from neo3.contracts.interop import register -from typing import cast, List +from typing import List -def stackitem_to_hash_data(engine: contracts.ApplicationEngine, stack_item: vm.StackItem) -> bytes: - if isinstance(stack_item, vm.InteropStackItem): - item = stack_item.get_object() - if not issubclass(type(item), payloads.IVerifiable): - raise ValueError("Invalid type") - item = cast(payloads.IVerifiable, item) - value = item.get_hash_data(settings.network.magic) - elif isinstance(stack_item, vm.NullStackItem): - value = engine.script_container.get_hash_data(settings.network.magic) - else: - value = stack_item.to_array() - return value +CHECKSIG_PRICE = 1 << 15 -@register("Neo.Crypto.RIPEMD160", 1 << 15, contracts.CallFlags.NONE) -def do_ripemd160(engine: contracts.ApplicationEngine, stack_item: vm.StackItem) -> bytes: - value = stackitem_to_hash_data(engine, stack_item) - return hashlib.new('ripemd160', value).digest() - - -@register("Neo.Crypto.SHA256", 1 << 15, contracts.CallFlags.NONE) -def do_sha256(engine: contracts.ApplicationEngine, stack_item: vm.StackItem) -> bytes: - value = stackitem_to_hash_data(engine, stack_item) - return hashlib.sha256(value).digest() - - -@register("Neo.Crypto.VerifyWithECDsaSecp256r1", 1 << 15, contracts.CallFlags.NONE) +@register("Neo.Crypto.CheckSig", CHECKSIG_PRICE, contracts.CallFlags.NONE) def verify_with_ECDSA_Secp256r1(engine: contracts.ApplicationEngine, - stack_item: vm.StackItem, - public_key: bytes, - signature: bytes) -> bool: - value = stackitem_to_hash_data(engine, stack_item) - return cryptography.verify_signature(value, signature, public_key, cryptography.ECCCurve.SECP256R1) - - -@register("Neo.Crypto.VerifyWithECDsaSecp256k1", 1 << 15, contracts.CallFlags.NONE) -def verify_with_ECDSA_Secp256k1(engine: contracts.ApplicationEngine, - stack_item: vm.StackItem, public_key: bytes, signature: bytes) -> bool: - value = stackitem_to_hash_data(engine, stack_item) - return cryptography.verify_signature(value, signature, public_key, cryptography.ECCCurve.SECP256K1) + return cryptography.verify_signature(engine.script_container.get_hash_data(settings.network.magic), + signature, + public_key, + cryptography.ECCCurve.SECP256R1) -def _check_multisig(engine: contracts.ApplicationEngine, - stack_item: vm.StackItem, - public_keys: List[bytes], - signatures: List[bytes], - curve: cryptography.ECCCurve) -> bool: +@register("Neo.Crypto.CheckMultisig", 0, contracts.CallFlags.NONE) +def check_multisig_with_ECDSA_Secp256r1(engine: contracts.ApplicationEngine, + public_keys: List[bytes], + signatures: List[bytes]) -> bool: len_pub_keys = len(public_keys) len_sigs = len(signatures) if len_sigs == 0: @@ -65,15 +31,15 @@ def _check_multisig(engine: contracts.ApplicationEngine, if len_sigs > len_pub_keys: raise ValueError(f"Verification requires {len_sigs} public keys, got only {len_pub_keys}") - message = stackitem_to_hash_data(engine, stack_item) + message = engine.script_container.get_hash_data(settings.network.magic) - engine.add_gas(len_pub_keys * (1 << 15) * engine.exec_fee_factor) + engine.add_gas(len_pub_keys * CHECKSIG_PRICE * engine.exec_fee_factor) i = 0 j = 0 try: while i < len_sigs and j < len_pub_keys: - if cryptography.verify_signature(message, signatures[i], public_keys[j], curve): + if cryptography.verify_signature(message, signatures[i], public_keys[j], cryptography.ECCCurve.SECP256R1): i += 1 j += 1 @@ -82,19 +48,3 @@ def _check_multisig(engine: contracts.ApplicationEngine, except cryptography.ECCException as e: return False return True - - -@register("Neo.Crypto.CheckMultisigWithECDsaSecp256r1", 0, contracts.CallFlags.NONE) -def check_multisig_with_ECDSA_Secp256r1(engine: contracts.ApplicationEngine, - stack_item: vm.StackItem, - public_keys: List[bytes], - signatures: List[bytes]) -> bool: - return _check_multisig(engine, stack_item, public_keys, signatures, cryptography.ECCCurve.SECP256R1) - - -@register("Neo.Crypto.CheckMultisigWithECDsaSecp256k1", 0, contracts.CallFlags.NONE) -def check_multisig_with_ECDSA_Secp256k1(engine: contracts.ApplicationEngine, - stack_item: vm.StackItem, - public_keys: List[bytes], - signatures: List[bytes]) -> bool: - return _check_multisig(engine, stack_item, public_keys, signatures, cryptography.ECCCurve.SECP256K1) diff --git a/neo3/contracts/interop/enumerator.py b/neo3/contracts/interop/enumerator.py index cda8fc57..54c39c13 100644 --- a/neo3/contracts/interop/enumerator.py +++ b/neo3/contracts/interop/enumerator.py @@ -99,7 +99,7 @@ def value(self) -> vm.StackItem: key = key[1:] if contracts.FindOptions.DESERIALIZE_VALUES in self.options: - item: vm.StackItem = contracts.BinarySerializer.deserialize(value, 1024, len(value), self.reference_counter) + item: vm.StackItem = contracts.BinarySerializer.deserialize(value, 1024, self.reference_counter) else: item = vm.ByteStringStackItem(value) diff --git a/neo3/contracts/interop/json.py b/neo3/contracts/interop/json.py deleted file mode 100644 index 25565c16..00000000 --- a/neo3/contracts/interop/json.py +++ /dev/null @@ -1,13 +0,0 @@ -from __future__ import annotations -from neo3 import vm, contracts -from neo3.contracts.interop import register - - -@register("System.Json.Serialize", 1 << 12, contracts.CallFlags.NONE) -def json_serialize(engine: contracts.ApplicationEngine, stack_item: vm.StackItem) -> bytes: - return bytes(contracts.JSONSerializer.serialize(stack_item, engine.MAX_ITEM_SIZE), 'utf-8') - - -@register("System.Json.Deserialize", 1 << 14, contracts.CallFlags.NONE) -def json_deserialize(engine: contracts.ApplicationEngine, data: bytes) -> vm.StackItem: - return contracts.JSONSerializer.deserialize(data.decode(), engine.reference_counter) diff --git a/neo3/contracts/interop/runtime.py b/neo3/contracts/interop/runtime.py index 2ce0594f..4de66fa4 100644 --- a/neo3/contracts/interop/runtime.py +++ b/neo3/contracts/interop/runtime.py @@ -22,11 +22,10 @@ def get_time(engine: contracts.ApplicationEngine) -> int: @register("System.Runtime.GetScriptContainer", 1 << 3, contracts.CallFlags.NONE) -def get_scriptcontainer(engine: contracts.ApplicationEngine) -> IInteroperable: +def get_scriptcontainer(engine: contracts.ApplicationEngine) -> vm.StackItem: if not isinstance(engine.script_container, IInteroperable): raise ValueError("script container is not a valid IInteroperable type") - container = cast(IInteroperable, engine.script_container) - return container + return engine.script_container.to_stack_item(engine.reference_counter) @register("System.Runtime.GetExecutingScriptHash", 1 << 4, contracts.CallFlags.NONE) diff --git a/neo3/contracts/interop/storage.py b/neo3/contracts/interop/storage.py index c785e93c..4de05803 100644 --- a/neo3/contracts/interop/storage.py +++ b/neo3/contracts/interop/storage.py @@ -1,7 +1,6 @@ from __future__ import annotations from typing import Optional from neo3 import contracts, storage -from neo3.core import types from neo3.contracts.interop import register, IIterator, StorageIterator MAX_STORAGE_KEY_SIZE = 64 @@ -66,11 +65,8 @@ def storage_find(engine: contracts.ApplicationEngine, return it -def _storage_put_internal(engine: contracts.ApplicationEngine, - context: storage.StorageContext, - key: bytes, - value: bytes, - flags: storage.StorageFlags) -> None: +@register("System.Storage.Put", 1 << 15, contracts.CallFlags.WRITE_STATES) +def storage_put(engine: contracts.ApplicationEngine, context: storage.StorageContext, key: bytes, value: bytes) -> None: if len(key) > MAX_STORAGE_KEY_SIZE: raise ValueError(f"Storage key length exceeds maximum of {MAX_STORAGE_KEY_SIZE}") if len(value) > MAX_STORAGE_VALUE_SIZE: @@ -81,50 +77,27 @@ def _storage_put_internal(engine: contracts.ApplicationEngine, storage_key = storage.StorageKey(context.id, key) item = engine.snapshot.storages.try_get(storage_key, read_only=False) - is_constant = storage.StorageFlags.CONSTANT in flags if item is None: new_data_len = len(key) + len(value) - item = storage.StorageItem(b'', is_constant) + item = storage.StorageItem(b'') engine.snapshot.storages.put(storage_key, item) else: - if item.is_constant: - raise ValueError("StorageItem is marked as constant") if len(value) == 0: - new_data_len = 1 + new_data_len = 0 elif len(value) <= len(item.value): new_data_len = (len(value) - 1) // 4 + 1 + elif len(item.value) == 0: + new_data_len = len(value) else: new_data_len = (len(item.value) - 1) // 4 + 1 + len(value) - len(item.value) engine.add_gas(new_data_len * engine.STORAGE_PRICE) item.value = value - item.is_constant = is_constant -@register("System.Storage.Put", 0, contracts.CallFlags.WRITE_STATES) -def storage_put(engine: contracts.ApplicationEngine, - context: storage.StorageContext, - key: bytes, - value: bytes) -> None: - _storage_put_internal(engine, context, key, value, storage.StorageFlags.NONE) - - -@register("System.Storage.PutEx", 0, contracts.CallFlags.WRITE_STATES) -def storage_put_ex(engine: contracts.ApplicationEngine, - context: storage.StorageContext, - key: bytes, - value: bytes, - flags: storage.StorageFlags) -> None: - _storage_put_internal(engine, context, key, value, flags) - - -@register("System.Storage.Delete", 0, contracts.CallFlags.WRITE_STATES) +@register("System.Storage.Delete", 1 << 15, contracts.CallFlags.WRITE_STATES) def storage_delete(engine: contracts.ApplicationEngine, context: storage.StorageContext, key: bytes) -> None: if context.is_read_only: raise ValueError("Cannot delete from read-only storage context") - engine.add_gas(engine.STORAGE_PRICE) storage_key = storage.StorageKey(context.id, key) - item = engine.snapshot.storages.try_get(storage_key) - if item and item.is_constant: - raise ValueError("Cannot delete a storage item that is marked constant") engine.snapshot.storages.delete(storage_key) diff --git a/neo3/contracts/manifest.py b/neo3/contracts/manifest.py index 6627b7db..42d886c8 100644 --- a/neo3/contracts/manifest.py +++ b/neo3/contracts/manifest.py @@ -57,9 +57,10 @@ def from_json(cls, json: dict) -> ContractGroup: KeyError: if the data supplied does not contain the necessary keys. ValueError: if the signature length is not 64. """ + pubkey = contracts.validate_type(json['pubkey'], str) c = cls( - public_key=cryptography.ECPoint.deserialize_from_bytes(binascii.unhexlify(json['pubkey'])), - signature=base64.b64decode(json['signature'].encode('utf8')) + public_key=cryptography.ECPoint.deserialize_from_bytes(binascii.unhexlify(pubkey)), + signature=base64.b64decode(contracts.validate_type(json['signature'], str).encode('utf8')) ) if len(c.signature) != 64: raise ValueError("Format error - invalid signature length") @@ -341,17 +342,18 @@ def deserialize(self, reader: BinaryReader) -> None: self._deserialize_from_json(json.loads(reader.read_var_string(self.MAX_LENGTH))) def _deserialize_from_json(self, json: dict) -> None: - self.name = json['name'] - if self.name is None: + if json['name'] is None: self.name = "" + else: + self.name = contracts.validate_type(json['name'], str) self.abi = contracts.ContractABI.from_json(json['abi']) self.groups = list(map(lambda g: ContractGroup.from_json(g), json['groups'])) - self.supported_standards = json['supportedstandards'] + self.supported_standards = list(map(lambda ss: contracts.validate_type(ss, str), json['supportedstandards'])) self.permissions = list(map(lambda p: ContractPermission.from_json(p), json['permissions'])) self.trusts = WildcardContainer.from_json_as_type( {'wildcard': json['trusts']}, - lambda t: types.UInt160.from_string(t)) + lambda t: types.UInt160.from_string(contracts.validate_type(t, str))) # converting json key/value back to default WildcardContainer format self.extra = json['extra'] @@ -372,7 +374,7 @@ def to_json(self) -> dict: } return json - def to_stack_item(self, reference_counter: vm.ReferenceCounter): + def to_stack_item(self, reference_counter: vm.ReferenceCounter) -> vm.StackItem: struct = vm.StructStackItem(reference_counter) struct.append(vm.ByteStringStackItem(self.name)) struct.append(vm.ArrayStackItem(reference_counter, diff --git a/neo3/contracts/native/__init__.py b/neo3/contracts/native/__init__.py index 67f2edcf..c22fa15b 100644 --- a/neo3/contracts/native/__init__.py +++ b/neo3/contracts/native/__init__.py @@ -1,3 +1,4 @@ +from .decorator import register from .nativecontract import NativeContract from .fungible import (FungibleToken, NeoToken, GasToken, FungibleTokenStorageState) from .policy import PolicyContract @@ -7,6 +8,8 @@ from .nonfungible import NonFungibleToken, NFTState from .nameservice import NameService from .ledger import LedgerContract +from .crypto import CryptoContract +from .stdlib import StdLibContract __all__ = ['NativeContract', 'PolicyContract', @@ -17,5 +20,8 @@ 'ManagementContract', 'NameService', 'LedgerContract', - 'FungibleToken' + 'FungibleToken', + 'CryptoContract', + 'StdLibContract', + 'register' ] diff --git a/neo3/contracts/native/crypto.py b/neo3/contracts/native/crypto.py new file mode 100644 index 00000000..d2702a15 --- /dev/null +++ b/neo3/contracts/native/crypto.py @@ -0,0 +1,37 @@ +from __future__ import annotations +import hashlib +import enum +from . import NativeContract, register +from neo3 import contracts +from neo3.core import cryptography + + +class NamedCurve(enum.IntEnum): + SECP256K1 = 22 + SECP256R1 = 23 + + +class CryptoContract(NativeContract): + + _service_name = "CryptoLib" + _id = -3 + + curves = { + NamedCurve.SECP256K1: cryptography.ECCCurve.SECP256K1, + NamedCurve.SECP256R1: cryptography.ECCCurve.SECP256R1 + } + + def init(self): + super(CryptoContract, self).init() + + @register("ripemd160", contracts.CallFlags.NONE, cpu_price=1 << 15) + def ripemd160(self, data: bytes) -> bytes: + return hashlib.new('ripemd160', data).digest() + + @register("sha256", contracts.CallFlags.NONE, cpu_price=1 << 15) + def sha256(self, data: bytes) -> bytes: + return hashlib.sha256(data).digest() + + @register("verifyWithECDsa", contracts.CallFlags.NONE, cpu_price=1 << 15) + def verify_with_ecdsa(self, message: bytes, public_key: bytes, signature: bytes, curve: NamedCurve) -> bool: + return cryptography.verify_signature(message, signature, public_key, self.curves.get(curve)) diff --git a/neo3/contracts/native/decorator.py b/neo3/contracts/native/decorator.py new file mode 100644 index 00000000..86764dbe --- /dev/null +++ b/neo3/contracts/native/decorator.py @@ -0,0 +1,26 @@ +from __future__ import annotations +from neo3 import contracts + + +def register(method: str, + flags: contracts.CallFlags, + *, + cpu_price: int = 0, + storage_price: int = 0): + """ + Register a publicly callable method on a native contract + + Args: + method: name of call. + cpu_price: the computational price of calling the handler. + storage_price: the storage price of calling the handler. + flags: ExecutionContext rights needed. + """ + def inner_func(func): + func.native_call = True + func.name = method + func.cpu_price = cpu_price + func.storage_price = storage_price + func.flags = flags + return func + return inner_func diff --git a/neo3/contracts/native/designate.py b/neo3/contracts/native/designate.py index 4cc54c5a..31d771e1 100644 --- a/neo3/contracts/native/designate.py +++ b/neo3/contracts/native/designate.py @@ -2,7 +2,7 @@ import struct from enum import IntEnum from typing import List -from . import NativeContract +from . import NativeContract, register from neo3 import storage, contracts, cryptography, vm from neo3.core import serialization @@ -10,29 +10,17 @@ class DesignateRole(IntEnum): STATE_VALIDATOR = 4 ORACLE = 8 + NEO_FS_ALPHABET_NODE = 16 class DesignationContract(NativeContract): - _id = -6 + _id = -8 _service_name = "RoleManagement" def init(self): super(DesignationContract, self).init() - self._register_contract_method(self.get_designated_by_role, - "getDesignatedByRole", - 1000000, - parameter_names=["role", "index"], - call_flags=contracts.CallFlags.READ_STATES) - - self._register_contract_method(self.designate_as_role, - "designateAsRole", - 0, - parameter_names=["role", "nodes"], - call_flags=contracts.CallFlags.WRITE_STATES) - - def _to_uint32(self, value: int) -> bytes: - return struct.pack(">I", value) + @register("getDesignatedByRole", contracts.CallFlags.READ_STATES, cpu_price=1 << 15) def get_designated_by_role(self, snapshot: storage.Snapshot, role: DesignateRole, @@ -48,6 +36,7 @@ def get_designated_by_role(self, else: return [] + @register("designateAsRole", contracts.CallFlags.STATES, cpu_price=1 << 15) def designate_as_role(self, engine: contracts.ApplicationEngine, role: DesignateRole, @@ -71,3 +60,6 @@ def designate_as_role(self, writer.write_serializable_list(nodes) storage_item = storage.StorageItem(writer.to_array()) engine.snapshot.storages.update(storage_key, storage_item) + + def _to_uint32(self, value: int) -> bytes: + return struct.pack(">I", value) diff --git a/neo3/contracts/native/fungible.py b/neo3/contracts/native/fungible.py index 004826b4..62f90460 100644 --- a/neo3/contracts/native/fungible.py +++ b/neo3/contracts/native/fungible.py @@ -1,9 +1,10 @@ from __future__ import annotations import struct from .nativecontract import NativeContract +from .decorator import register from neo3 import storage, contracts, vm, settings from neo3.core import types, msgrouter, cryptography, serialization, to_script_hash, Size as s, IInteroperable -from typing import Tuple, List, Dict, Sequence, cast, Optional +from typing import Tuple, List, Dict, Sequence, cast class FungibleTokenStorageState(IInteroperable, serialization.ISerializable): @@ -25,9 +26,9 @@ def deserialize(self, reader: serialization.BinaryReader) -> None: self.balance = vm.BigInteger(reader.read_var_bytes()) def to_stack_item(self, reference_counter: vm.ReferenceCounter) -> vm.StackItem: - struct = vm.StructStackItem(reference_counter) - struct.append(vm.IntegerStackItem(self.balance)) - return struct + struct_ = vm.StructStackItem(reference_counter) + struct_.append(vm.IntegerStackItem(self.balance)) + return struct_ @classmethod def from_stack_item(cls, stack_item: vm.StackItem): @@ -63,35 +64,108 @@ def init(self): ] self.factor = pow(vm.BigInteger(10), vm.BigInteger(self._decimals)) - self._register_contract_method(self.total_supply, - "totalSupply", - 1000000, - call_flags=contracts.CallFlags.READ_STATES) - self._register_contract_method(self.balance_of, - "balanceOf", - 1000000, - parameter_names=["account"], - call_flags=contracts.CallFlags.READ_STATES) - self._register_contract_method(self.transfer, - "transfer", - 9000000, - parameter_names=["account_from", "account_to", "amount", "data"], - call_flags=(contracts.CallFlags.WRITE_STATES - | contracts.CallFlags.ALLOW_CALL - | contracts.CallFlags.ALLOW_NOTIFY)) - self._register_contract_method(self.symbol, - "symbol", - 0, - call_flags=contracts.CallFlags.READ_STATES) - self._register_contract_method(self.on_persist, - "onPersist", - 0, - call_flags=contracts.CallFlags.WRITE_STATES) + @register("decimals", contracts.CallFlags.READ_STATES) + def decimals(self) -> int: + return self._decimals + @register("symbol", contracts.CallFlags.READ_STATES) def symbol(self) -> str: """ Token symbol. """ return self._symbol + @register("totalSupply", contracts.CallFlags.READ_STATES, cpu_price=1 << 15) + def total_supply(self, snapshot: storage.Snapshot) -> vm.BigInteger: + """ Get the total deployed tokens. """ + storage_item = snapshot.storages.try_get(self.key_total_supply) + if storage_item is None: + return vm.BigInteger.zero() + else: + return vm.BigInteger(storage_item.value) + + @register("balanceOf", contracts.CallFlags.READ_STATES, cpu_price=1 << 15) + def balance_of(self, snapshot: storage.Snapshot, account: types.UInt160) -> vm.BigInteger: + """ + Get the balance of an account. + + Args: + snapshot: snapshot of the storage + account: script hash of the account to obtain the balance of + + Returns: + amount of balance. + + Note: The returned value is still in internal format. Divide the results by the contract's `decimals` + """ + storage_item = snapshot.storages.try_get(self.key_account + account) + if storage_item is None: + return vm.BigInteger.zero() + else: + state = self._state.deserialize_from_bytes(storage_item.value) + return state.balance + + @register("transfer", + (contracts.CallFlags.STATES | contracts.CallFlags.ALLOW_CALL | contracts.CallFlags.ALLOW_NOTIFY), + cpu_price=1 << 17, storage_price=50) + def transfer(self, + engine: contracts.ApplicationEngine, + account_from: types.UInt160, + account_to: types.UInt160, + amount: vm.BigInteger, + data: vm.StackItem + ) -> bool: + """ + Transfer tokens from one account to another. + + Raises: + ValueError: if the requested amount is negative. + + Returns: + True on success. False otherwise. + """ + if amount.sign < 0: + raise ValueError("Can't transfer a negative amount") + + # transfer from an account not owned by the smart contract that is requesting the transfer + # and there is no signature that approves we are allowed todo so + if account_from != engine.calling_scripthash and not engine.checkwitness(account_from): + return False + + storage_key_from = self.key_account + account_from + storage_item_from = engine.snapshot.storages.try_get(storage_key_from, read_only=False) + + if storage_item_from is None: + return False + + state_from = storage_item_from.get(self._state) + if amount == vm.BigInteger.zero(): + self.on_balance_changing(engine, account_from, state_from, amount) + else: + if state_from.balance < amount: + return False + + if account_from == account_to: + self.on_balance_changing(engine, account_from, state_from, vm.BigInteger.zero()) + else: + self.on_balance_changing(engine, account_from, state_from, -amount) + if state_from.balance == amount: + engine.snapshot.storages.delete(storage_key_from) + else: + state_from.balance -= amount + + storage_key_to = self.key_account + account_to + storage_item_to = engine.snapshot.storages.try_get(storage_key_to, read_only=False) + if storage_item_to is None: + storage_item_to = storage.StorageItem(self._state().to_array()) + engine.snapshot.storages.put(storage_key_to, storage_item_to) + + state_to = storage_item_to.get(self._state) + + self.on_balance_changing(engine, account_to, state_to, amount) + state_to.balance += amount + + self._post_transfer(engine, account_from, account_to, amount, data, True) + return True + def mint(self, engine: contracts.ApplicationEngine, account: types.UInt160, @@ -167,33 +241,14 @@ def burn(self, engine: contracts.ApplicationEngine, account: types.UInt160, amou storage_item.value = new_value.to_array() self._post_transfer(engine, account, types.UInt160.zero(), amount, vm.NullStackItem(), False) - def total_supply(self, snapshot: storage.Snapshot) -> vm.BigInteger: - """ Get the total deployed tokens. """ - storage_item = snapshot.storages.try_get(self.key_total_supply) - if storage_item is None: - return vm.BigInteger.zero() - else: - return vm.BigInteger(storage_item.value) - - def balance_of(self, snapshot: storage.Snapshot, account: types.UInt160) -> vm.BigInteger: - """ - Get the balance of an account. - - Args: - snapshot: snapshot of the storage - account: script hash of the account to obtain the balance of - - Returns: - amount of balance. + def on_balance_changing(self, engine: contracts.ApplicationEngine, + account: types.UInt160, + state, + amount: vm.BigInteger) -> None: + pass - Note: The returned value is still in internal format. Divide the results by the contract's `decimals` - """ - storage_item = snapshot.storages.try_get(self.key_account + account) - if storage_item is None: - return vm.BigInteger.zero() - else: - state = self._state.deserialize_from_bytes(storage_item.value) - return state.balance + def on_persist(self, engine: contracts.ApplicationEngine) -> None: + pass def _post_transfer(self, engine: contracts.ApplicationEngine, @@ -227,72 +282,6 @@ def _post_transfer(self, from_ = vm.ByteStringStackItem(account_from.to_array()) engine.call_from_native(self.hash, account_to, "onNEP17Payment", [from_, vm.IntegerStackItem(amount), data]) - def transfer(self, - engine: contracts.ApplicationEngine, - account_from: types.UInt160, - account_to: types.UInt160, - amount: vm.BigInteger, - data: vm.StackItem - ) -> bool: - """ - Transfer tokens from one account to another. - - Raises: - ValueError: if the requested amount is negative. - - Returns: - True on success. False otherwise. - """ - if amount.sign < 0: - raise ValueError("Can't transfer a negative amount") - - # transfer from an account not owned by the smart contract that is requesting the transfer - # and there is no signature that approves we are allowed todo so - if account_from != engine.calling_scripthash and not engine.checkwitness(account_from): - return False - - storage_key_from = self.key_account + account_from - storage_item_from = engine.snapshot.storages.try_get(storage_key_from, read_only=False) - - if storage_item_from is None: - return False - - state_from = storage_item_from.get(self._state) - if amount == vm.BigInteger.zero(): - self.on_balance_changing(engine, account_from, state_from, amount) - else: - if state_from.balance < amount: - return False - - if account_from == account_to: - self.on_balance_changing(engine, account_from, state_from, vm.BigInteger.zero()) - else: - self.on_balance_changing(engine, account_from, state_from, -amount) - if state_from.balance == amount: - engine.snapshot.storages.delete(storage_key_from) - else: - state_from.balance -= amount - - storage_key_to = self.key_account + account_to - storage_item_to = engine.snapshot.storages.try_get(storage_key_to, read_only=False) - if storage_item_to is None: - storage_item_to = storage.StorageItem(self._state().to_array()) - engine.snapshot.storages.put(storage_key_to, storage_item_to) - - state_to = storage_item_to.get(self._state) - - self.on_balance_changing(engine, account_to, state_to, amount) - state_to.balance += amount - - self._post_transfer(engine, account_from, account_to, amount, data, True) - return True - - def on_balance_changing(self, engine: contracts.ApplicationEngine, - account: types.UInt160, - state, - amount: vm.BigInteger) -> None: - pass - class _NeoTokenStorageState(FungibleTokenStorageState): """ @@ -447,7 +436,7 @@ def deserialize(self, reader: serialization.BinaryReader) -> None: class NeoToken(FungibleToken): - _id: int = -3 + _id: int = -5 _decimals: int = 0 key_committee = storage.StorageKey(_id, b'\x0e') @@ -455,6 +444,7 @@ class NeoToken(FungibleToken): key_voters_count = storage.StorageKey(_id, b'\x01') key_gas_per_block = storage.StorageKey(_id, b'\x29') key_voter_reward_per_committee = storage.StorageKey(_id, b'\x17') + key_register_price = storage.StorageKey(_id, b'\x0D') _NEO_HOLDER_REWARD_RATIO = 10 _COMMITTEE_REWARD_RATIO = 10 @@ -464,126 +454,10 @@ class NeoToken(FungibleToken): _candidates_dirty = True _candidates: List[Tuple[cryptography.ECPoint, vm.BigInteger]] = [] - def _to_uint32(self, value: int) -> bytes: - return struct.pack(">I", value) - - def _calculate_bonus(self, - snapshot: storage.Snapshot, - vote: cryptography.ECPoint, - value: vm.BigInteger, - start: int, - end: int) -> vm.BigInteger: - if value == vm.BigInteger.zero() or start >= end: - return vm.BigInteger.zero() - - if value.sign < 0: - raise ValueError("Can't calculate bonus over negative balance") - - neo_holder_reward = self._calculate_neo_holder_reward(snapshot, value, start, end) - if vote.is_zero(): - return neo_holder_reward - border = (self.key_voter_reward_per_committee + vote).to_array() - start_bytes = self._to_uint32(start) - key_start = (self.key_voter_reward_per_committee + vote + start_bytes).to_array() - - try: - pair = next(snapshot.storages.find_range(self.hash, key_start, border, "reverse")) - start_reward_per_neo = vm.BigInteger(pair[1].value) # first pair returned, StorageItem - except StopIteration: - start_reward_per_neo = vm.BigInteger.zero() - - end_bytes = self._to_uint32(end) - key_end = (self.key_voter_reward_per_committee + vote + end_bytes).to_array() - - try: - pair = next(snapshot.storages.find_range(self.hash, key_end, border, "reverse")) - end_reward_per_neo = vm.BigInteger(pair[1].value) # first pair returned, StorageItem - except StopIteration: - end_reward_per_neo = vm.BigInteger.zero() - - return neo_holder_reward + value * (end_reward_per_neo - start_reward_per_neo) / 100000000 - - def _calculate_neo_holder_reward(self, - snapshot: storage.Snapshot, - value: vm.BigInteger, - start: int, - end: int) -> vm.BigInteger: - gas_bonus_state = GasBonusState.from_snapshot(snapshot, read_only=True) - gas_sum = 0 - for pair in reversed(gas_bonus_state): # type: _GasRecord - cur_idx = pair.index - if cur_idx >= end: - continue - if cur_idx > start: - gas_sum += pair.gas_per_block * (end - cur_idx) - end = cur_idx - else: - gas_sum += pair.gas_per_block * (end - start) - break - return value * gas_sum * self._NEO_HOLDER_REWARD_RATIO / 100 / self.total_amount - - def _should_refresh_committee(self, height: int) -> bool: - return height % len(settings.standby_committee) == 0 - - def _check_candidate(self, - snapshot: storage.Snapshot, - public_key: cryptography.ECPoint, - candidate: _CandidateState) -> None: - if not candidate.registered and candidate.votes == 0: - for k, v in snapshot.storages.find((self.key_voter_reward_per_committee + public_key).to_array()): - snapshot.storages.delete(k) - snapshot.storages.delete(self.key_candidate + public_key) - def init(self): super(NeoToken, self).init() # singleton init, similar to __init__ but called only once self.total_amount = self.factor * 100_000_000 - - self._register_contract_method(self.register_candidate, - "registerCandidate", - 1000_00000000, - parameter_names=["public_key"], - call_flags=contracts.CallFlags.WRITE_STATES) - - self._register_contract_method(self.unregister_candidate, - "unregisterCandidate", - 5000000, - parameter_names=["public_key"], - call_flags=contracts.CallFlags.WRITE_STATES) - - self._register_contract_method(self.vote, - "vote", - 5000000, - parameter_names=["account", "public_key"], - call_flags=contracts.CallFlags.WRITE_STATES) - self._register_contract_method(self._set_gas_per_block, - "setGasPerBlock", - 5000000, - parameter_names=["gas_per_block"], - call_flags=contracts.CallFlags.WRITE_STATES - ) - self._register_contract_method(self.get_gas_per_block, - "getGasPerBlock", - 1000000, - call_flags=contracts.CallFlags.READ_STATES - ) - self._register_contract_method(self.get_committee, - "getCommittee", - 100000000, - call_flags=contracts.CallFlags.READ_STATES - ) - - self._register_contract_method(self.get_candidates, - "getCandidates", - 100000000, - call_flags=contracts.CallFlags.READ_STATES - ) - - self._register_contract_method(self.get_next_block_validators, - "getNextBlockValidators", - 100000000, - call_flags=contracts.CallFlags.READ_STATES - ) self._committee_state = None def _initialize(self, engine: contracts.ApplicationEngine) -> None: @@ -595,81 +469,14 @@ def _initialize(self, engine: contracts.ApplicationEngine) -> None: gas_bonus_state = GasBonusState(_GasRecord(0, GasToken().factor * 5)) engine.snapshot.storages.put(self.key_gas_per_block, storage.StorageItem(gas_bonus_state.to_array())) + engine.snapshot.storages.put(self.key_register_price, + storage.StorageItem((GasToken().factor * 1000).to_array())) self.mint(engine, contracts.Contract.get_consensus_address(settings.standby_validators), self.total_amount, False) - def total_supply(self, snapshot: storage.Snapshot) -> vm.BigInteger: - """ Get the total deployed tokens. """ - return self.total_amount - - def on_balance_changing(self, engine: contracts.ApplicationEngine, - account: types.UInt160, - state, - amount: vm.BigInteger) -> None: - self._distribute_gas(engine, account, state) - - if amount == vm.BigInteger.zero(): - return - - if state.vote_to.is_zero(): - return - - si_voters_count = engine.snapshot.storages.get(self.key_voters_count, read_only=False) - new_value = vm.BigInteger(si_voters_count.value) + amount - si_voters_count.value = new_value.to_array() - - si_candidate = engine.snapshot.storages.get(self.key_candidate + state.vote_to, read_only=False) - candidate_state = si_candidate.get(_CandidateState) - candidate_state.votes += amount - self._candidates_dirty = True - self._check_candidate(engine.snapshot, state.vote_to, candidate_state) - - def on_persist(self, engine: contracts.ApplicationEngine) -> None: - super(NeoToken, self).on_persist(engine) - - # set next committee - if self._should_refresh_committee(engine.snapshot.persisting_block.index): - validators = self._compute_committee_members(engine.snapshot) - if self._committee_state is None: - self._committee_state = _CommitteeState.from_snapshot(engine.snapshot) - self._committee_state._validators = validators - self._committee_state.persist(engine.snapshot) - - def post_persist(self, engine: contracts.ApplicationEngine): - super(NeoToken, self).post_persist(engine) - # distribute GAS for committee - m = len(settings.standby_committee) - n = settings.network.validators_count - index = engine.snapshot.persisting_block.index % m - gas_per_block = self.get_gas_per_block(engine.snapshot) - committee = self.get_committee_from_cache(engine.snapshot) - pubkey = committee[index] - account = to_script_hash(contracts.Contract.create_signature_redeemscript(pubkey)) - GasToken().mint(engine, account, gas_per_block * self._COMMITTEE_REWARD_RATIO / 100, False) - - if self._should_refresh_committee(engine.snapshot.persisting_block.index): - voter_reward_of_each_committee = gas_per_block * self._VOTER_REWARD_RATIO * 100000000 * m / (m + n) / 100 - for i, member in enumerate(committee): - factor = 2 if i < n else 1 - member_votes = self._committee_state[member] - if member_votes > 0: - voter_sum_reward_per_neo = voter_reward_of_each_committee * factor / member_votes - voter_reward_key = (self.key_voter_reward_per_committee - + member - + self._to_uint32(engine.snapshot.persisting_block.index + 1) - ) - border = (self.key_voter_reward_per_committee + member).to_array() - try: - pair = next(engine.snapshot.storages.find_range(voter_reward_key.to_array(), border, "reverse")) - result = vm.BigInteger(pair[1].value) - except StopIteration: - result = vm.BigInteger.zero() - voter_sum_reward_per_neo += result - engine.snapshot.storages.put(voter_reward_key, - storage.StorageItem(voter_sum_reward_per_neo.to_array())) - + @register("unclaimedGas", contracts.CallFlags.READ_STATES, cpu_price=1 << 17) def unclaimed_gas(self, snapshot: storage.Snapshot, account: types.UInt160, end: int) -> vm.BigInteger: """ Return the available bonus GAS for an account. @@ -687,6 +494,7 @@ def unclaimed_gas(self, snapshot: storage.Snapshot, account: types.UInt160, end: state = storage_item.get(self._state) return self._calculate_bonus(snapshot, state.vote_to, state.balance, state.balance_height, end) + @register("registerCandidate", contracts.CallFlags.STATES) def register_candidate(self, engine: contracts.ApplicationEngine, public_key: cryptography.ECPoint) -> bool: @@ -704,6 +512,7 @@ def register_candidate(self, if not engine.checkwitness(script_hash): return False + engine.add_gas(self.get_register_price(engine.snapshot)) storage_key = self.key_candidate + public_key storage_item = engine.snapshot.storages.try_get(storage_key, read_only=False) if storage_item is None: @@ -717,6 +526,7 @@ def register_candidate(self, self._candidates_dirty = True return True + @register("unregisterCandidate", contracts.CallFlags.STATES, cpu_price=1 << 16) def unregister_candidate(self, engine: contracts.ApplicationEngine, public_key: cryptography.ECPoint) -> bool: @@ -747,6 +557,7 @@ def unregister_candidate(self, self._candidates_dirty = True return True + @register("vote", contracts.CallFlags.STATES, cpu_price=1 << 16) def vote(self, engine: contracts.ApplicationEngine, account: types.UInt160, @@ -772,7 +583,7 @@ def vote(self, storage_key_candidate = self.key_candidate + vote_to storage_item_candidate = engine.snapshot.storages.try_get(storage_key_candidate, read_only=False) - if storage_key_candidate is None: + if storage_item_candidate is None: return False candidate_state = storage_item_candidate.get(_CandidateState) @@ -803,21 +614,7 @@ def vote(self, return True - def _get_candidates(self, - snapshot: storage.Snapshot) -> \ - List[Tuple[cryptography.ECPoint, vm.BigInteger]]: - if self._candidates_dirty: - self._candidates = [] - for k, v in snapshot.storages.find(self.key_candidate.to_array()): - candidate = _CandidateState.deserialize_from_bytes(v.value) - if candidate.registered: - # take of the CANDIDATE prefix - point = cryptography.ECPoint.deserialize_from_bytes(k.key[1:]) - self._candidates.append((point, candidate.votes)) - self._candidates_dirty = False - - return self._candidates - + @register("getCandidates", contracts.CallFlags.READ_STATES, cpu_price=1 << 22) def get_candidates(self, engine: contracts.ApplicationEngine) -> None: array = vm.ArrayStackItem(engine.reference_counter) for k, v in self._get_candidates(engine.snapshot): @@ -827,19 +624,141 @@ def get_candidates(self, engine: contracts.ApplicationEngine) -> None: array.append(struct) engine.push(array) + @register("getNextBlockValidators", contracts.CallFlags.READ_STATES, cpu_price=1 << 16) def get_next_block_validators(self, snapshot: storage.Snapshot) -> List[cryptography.ECPoint]: keys = self.get_committee_from_cache(snapshot)[:settings.network.validators_count] keys.sort() return keys + @register("setGasPerBlock", contracts.CallFlags.STATES, cpu_price=1 << 15) + def _set_gas_per_block(self, engine: contracts.ApplicationEngine, gas_per_block: vm.BigInteger) -> None: + if gas_per_block > 0 or gas_per_block > 10 * self._gas.factor: + raise ValueError("new gas per block value exceeds limits") + + if not self._check_committee(engine): + raise ValueError("Check committee failed") + + index = engine.snapshot.persisting_block.index + 1 + gas_bonus_state = GasBonusState.from_snapshot(engine.snapshot, read_only=False) + if gas_bonus_state[-1].index == index: + gas_bonus_state[-1] = _GasRecord(index, gas_per_block) + else: + gas_bonus_state.append(_GasRecord(index, gas_per_block)) + + @register("getGasPerBlock", contracts.CallFlags.READ_STATES, cpu_price=1 << 15) + def get_gas_per_block(self, snapshot: storage.Snapshot) -> vm.BigInteger: + index = snapshot.best_block_height + 1 + gas_bonus_state = GasBonusState.from_snapshot(snapshot, read_only=True) + for record in reversed(gas_bonus_state): # type: _GasRecord + if record.index <= index: + return record.gas_per_block + else: + raise ValueError + + @register("setRegisterPrice", contracts.CallFlags.STATES, cpu_price=1 << 15) + def set_register_price(self, engine: contracts.ApplicationEngine, register_price: int) -> None: + if register_price <= 0: + raise ValueError("Register price cannot be negative or zero") + if not self._check_committee(engine): + raise ValueError("CheckCommittee failed for setRegisterPrice") + item = engine.snapshot.storages.get(self.key_register_price, read_only=False) + item.value = vm.BigInteger(register_price).to_array() + + @register("getRegisterPrice", contracts.CallFlags.READ_STATES, cpu_price=1 << 15) + def get_register_price(self, snapshot: storage.Snapshot) -> int: + return int(vm.BigInteger(snapshot.storages.get(self.key_register_price, read_only=True).value)) + + def _distribute_gas(self, + engine: contracts.ApplicationEngine, + account: types.UInt160, + state: _NeoTokenStorageState) -> None: + if engine.snapshot.persisting_block is None: + return + + gas = self._calculate_bonus(engine.snapshot, state.vote_to, state.balance, state.balance_height, + engine.snapshot.persisting_block.index) + state.balance_height = engine.snapshot.persisting_block.index + GasToken().mint(engine, account, gas, True) + + @register("getCommittee", contracts.CallFlags.READ_STATES, cpu_price=1 << 16) + def get_committee(self, snapshot: storage.Snapshot) -> List[cryptography.ECPoint]: + return sorted(self.get_committee_from_cache(snapshot)) + + def total_supply(self, snapshot: storage.Snapshot) -> vm.BigInteger: + """ Get the total deployed tokens. """ + return self.total_amount + + def on_balance_changing(self, engine: contracts.ApplicationEngine, + account: types.UInt160, + state, + amount: vm.BigInteger) -> None: + self._distribute_gas(engine, account, state) + + if amount == vm.BigInteger.zero(): + return + + if state.vote_to.is_zero(): + return + + si_voters_count = engine.snapshot.storages.get(self.key_voters_count, read_only=False) + new_value = vm.BigInteger(si_voters_count.value) + amount + si_voters_count.value = new_value.to_array() + + si_candidate = engine.snapshot.storages.get(self.key_candidate + state.vote_to, read_only=False) + candidate_state = si_candidate.get(_CandidateState) + candidate_state.votes += amount + self._candidates_dirty = True + self._check_candidate(engine.snapshot, state.vote_to, candidate_state) + + def on_persist(self, engine: contracts.ApplicationEngine) -> None: + super(NeoToken, self).on_persist(engine) + + # set next committee + if self._should_refresh_committee(engine.snapshot.persisting_block.index): + validators = self._compute_committee_members(engine.snapshot) + if self._committee_state is None: + self._committee_state = _CommitteeState.from_snapshot(engine.snapshot) + self._committee_state._validators = validators + self._committee_state.persist(engine.snapshot) + + def post_persist(self, engine: contracts.ApplicationEngine): + super(NeoToken, self).post_persist(engine) + # distribute GAS for committee + m = len(settings.standby_committee) + n = settings.network.validators_count + index = engine.snapshot.persisting_block.index % m + gas_per_block = self.get_gas_per_block(engine.snapshot) + committee = self.get_committee_from_cache(engine.snapshot) + pubkey = committee[index] + account = to_script_hash(contracts.Contract.create_signature_redeemscript(pubkey)) + GasToken().mint(engine, account, gas_per_block * self._COMMITTEE_REWARD_RATIO / 100, False) + + if self._should_refresh_committee(engine.snapshot.persisting_block.index): + voter_reward_of_each_committee = gas_per_block * self._VOTER_REWARD_RATIO * 100000000 * m / (m + n) / 100 + for i, member in enumerate(committee): + factor = 2 if i < n else 1 + member_votes = self._committee_state[member] + if member_votes > 0: + voter_sum_reward_per_neo = voter_reward_of_each_committee * factor / member_votes + voter_reward_key = (self.key_voter_reward_per_committee + + member + + self._to_uint32(engine.snapshot.persisting_block.index + 1) + ) + border = (self.key_voter_reward_per_committee + member).to_array() + try: + pair = next(engine.snapshot.storages.find_range(voter_reward_key.to_array(), border, "reverse")) + result = vm.BigInteger(pair[1].value) + except StopIteration: + result = vm.BigInteger.zero() + voter_sum_reward_per_neo += result + engine.snapshot.storages.put(voter_reward_key, + storage.StorageItem(voter_sum_reward_per_neo.to_array())) + def get_committee_from_cache(self, snapshot: storage.Snapshot) -> List[cryptography.ECPoint]: if self._committee_state is None: self._committee_state = _CommitteeState.from_snapshot(snapshot) return self._committee_state.validators - def get_committee(self, snapshot: storage.Snapshot) -> List[cryptography.ECPoint]: - return sorted(self.get_committee_from_cache(snapshot)) - def get_committee_address(self, snapshot: storage.Snapshot) -> types.UInt160: comittees = self.get_committee(snapshot) return to_script_hash( @@ -848,6 +767,70 @@ def get_committee_address(self, snapshot: storage.Snapshot) -> types.UInt160: comittees) ) + def _calculate_bonus(self, + snapshot: storage.Snapshot, + vote: cryptography.ECPoint, + value: vm.BigInteger, + start: int, + end: int) -> vm.BigInteger: + if value == vm.BigInteger.zero() or start >= end: + return vm.BigInteger.zero() + + if value.sign < 0: + raise ValueError("Can't calculate bonus over negative balance") + + neo_holder_reward = self._calculate_neo_holder_reward(snapshot, value, start, end) + if vote.is_zero(): + return neo_holder_reward + border = (self.key_voter_reward_per_committee + vote).to_array() + start_bytes = self._to_uint32(start) + key_start = (self.key_voter_reward_per_committee + vote + start_bytes).to_array() + + try: + pair = next(snapshot.storages.find_range(key_start, border, "reverse")) + start_reward_per_neo = vm.BigInteger(pair[1].value) # first pair returned, StorageItem + except StopIteration: + start_reward_per_neo = vm.BigInteger.zero() + + end_bytes = self._to_uint32(end) + key_end = (self.key_voter_reward_per_committee + vote + end_bytes).to_array() + + try: + pair = next(snapshot.storages.find_range(key_end, border, "reverse")) + end_reward_per_neo = vm.BigInteger(pair[1].value) # first pair returned, StorageItem + except StopIteration: + end_reward_per_neo = vm.BigInteger.zero() + + return neo_holder_reward + value * (end_reward_per_neo - start_reward_per_neo) / 100000000 + + def _calculate_neo_holder_reward(self, + snapshot: storage.Snapshot, + value: vm.BigInteger, + start: int, + end: int) -> vm.BigInteger: + gas_bonus_state = GasBonusState.from_snapshot(snapshot, read_only=True) + gas_sum = 0 + for pair in reversed(gas_bonus_state): # type: _GasRecord + cur_idx = pair.index + if cur_idx >= end: + continue + if cur_idx > start: + gas_sum += pair.gas_per_block * (end - cur_idx) + end = cur_idx + else: + gas_sum += pair.gas_per_block * (end - start) + break + return value * gas_sum * self._NEO_HOLDER_REWARD_RATIO / 100 / self.total_amount + + def _check_candidate(self, + snapshot: storage.Snapshot, + public_key: cryptography.ECPoint, + candidate: _CandidateState) -> None: + if not candidate.registered and candidate.votes == 0: + for k, v in snapshot.storages.find((self.key_voter_reward_per_committee + public_key).to_array()): + snapshot.storages.delete(k) + snapshot.storages.delete(self.key_candidate + public_key) + def _compute_committee_members(self, snapshot: storage.Snapshot) -> Dict[cryptography.ECPoint, vm.BigInteger]: storage_item = snapshot.storages.get(self.key_voters_count, read_only=True) voters_count = int(vm.BigInteger(storage_item.value)) @@ -857,7 +840,12 @@ def _compute_committee_members(self, snapshot: storage.Snapshot) -> Dict[cryptog if voter_turnout < 0.2 or len(candidates) < len(settings.standby_committee): results = {} for key in settings.standby_committee: - results.update({key: self._committee_state[key]}) + for pair in candidates: + if pair[0] == key: + results.update({key: pair[1]}) + break + else: + results.update({key: vm.BigInteger.zero()}) return results # first sort by votes descending, then by ECPoint ascending # we negate the value of the votes (c[1]) such that they get sorted in descending order @@ -868,44 +856,30 @@ def _compute_committee_members(self, snapshot: storage.Snapshot) -> Dict[cryptog results.update({candidate[0]: candidate[1]}) return results - def _set_gas_per_block(self, engine: contracts.ApplicationEngine, gas_per_block: vm.BigInteger) -> None: - if gas_per_block > 0 or gas_per_block > 10 * self._gas.factor: - raise ValueError("new gas per block value exceeds limits") - - if not self._check_committee(engine): - raise ValueError("Check committee failed") - - index = engine.snapshot.persisting_block.index + 1 - gas_bonus_state = GasBonusState.from_snapshot(engine.snapshot, read_only=False) - if gas_bonus_state[-1].index == index: - gas_bonus_state[-1] = _GasRecord(index, gas_per_block) - else: - gas_bonus_state.append(_GasRecord(index, gas_per_block)) + def _get_candidates(self, + snapshot: storage.Snapshot) -> \ + List[Tuple[cryptography.ECPoint, vm.BigInteger]]: + if self._candidates_dirty: + self._candidates = [] + for k, v in snapshot.storages.find(self.key_candidate.to_array()): + candidate = _CandidateState.deserialize_from_bytes(v.value) + if candidate.registered: + # take of the CANDIDATE prefix + point = cryptography.ECPoint.deserialize_from_bytes(k.key[1:]) + self._candidates.append((point, candidate.votes)) + self._candidates_dirty = False - def get_gas_per_block(self, snapshot: storage.Snapshot) -> vm.BigInteger: - index = snapshot.best_block_height + 1 - gas_bonus_state = GasBonusState.from_snapshot(snapshot, read_only=True) - for record in reversed(gas_bonus_state): # type: _GasRecord - if record.index <= index: - return record.gas_per_block - else: - raise ValueError + return self._candidates - def _distribute_gas(self, - engine: contracts.ApplicationEngine, - account: types.UInt160, - state: _NeoTokenStorageState) -> None: - if engine.snapshot.persisting_block is None: - return + def _should_refresh_committee(self, height: int) -> bool: + return height % len(settings.standby_committee) == 0 - gas = self._calculate_bonus(engine.snapshot, state.vote_to, state.balance, state.balance_height, - engine.snapshot.persisting_block.index) - state.balance_height = engine.snapshot.persisting_block.index - GasToken().mint(engine, account, gas, True) + def _to_uint32(self, value: int) -> bytes: + return struct.pack(">I", value) class GasToken(FungibleToken): - _id: int = -4 + _id: int = -6 _decimals: int = 8 _state = FungibleTokenStorageState @@ -922,6 +896,6 @@ def on_persist(self, engine: contracts.ApplicationEngine) -> None: total_network_fee += tx.network_fee self.burn(engine, tx.sender, vm.BigInteger(tx.system_fee + tx.network_fee)) pub_keys = NeoToken().get_next_block_validators(engine.snapshot) - primary = pub_keys[engine.snapshot.persisting_block.consensus_data.primary_index] + primary = pub_keys[engine.snapshot.persisting_block.primary_index] script_hash = to_script_hash(contracts.Contract.create_signature_redeemscript(primary)) self.mint(engine, script_hash, vm.BigInteger(total_network_fee), False) diff --git a/neo3/contracts/native/ledger.py b/neo3/contracts/native/ledger.py index e4865be3..8e1c5fea 100644 --- a/neo3/contracts/native/ledger.py +++ b/neo3/contracts/native/ledger.py @@ -1,57 +1,26 @@ from __future__ import annotations from typing import Optional -from .nativecontract import NativeContract +from . import register, NativeContract from neo3 import storage, contracts, vm from neo3.core import types from neo3.network import payloads class LedgerContract(NativeContract): - _id = -2 + _id = -4 def init(self): super(LedgerContract, self).init() - self._register_contract_method(self.current_hash, - "currentHash", - 1000000, - call_flags=contracts.CallFlags.READ_STATES - ) - self._register_contract_method(self.current_index, - "currentIndex", - 1000000, - call_flags=contracts.CallFlags.READ_STATES) - self._register_contract_method(self.get_block, - "getBlock", - 1000000, - parameter_names=["block_index_or_hash"], - call_flags=contracts.CallFlags.READ_STATES) - self._register_contract_method(self.get_tx_for_contract, - "getTransaction", - 1000000, - parameter_names=["tx_hash"], - call_flags=contracts.CallFlags.READ_STATES) - self._register_contract_method(self.get_tx_height, - "getTransactionheight", - 1000000, - parameter_names=["tx_hash"], - call_flags=contracts.CallFlags.READ_STATES) - self._register_contract_method(self.get_tx_from_block, - "getTransactionFromBlock", - 2000000, - parameter_names=["block_index_or_hash", "tx_index"], - call_flags=contracts.CallFlags.READ_STATES) - - def on_persist(self, engine: contracts.ApplicationEngine) -> None: - # Unlike C# the current block or its transactions are not persisted here, - # it is still done in the Blockchain class in the persist() function - pass + @register("currentHash", contracts.CallFlags.READ_STATES, cpu_price=1 << 15) def current_hash(self, snapshot: storage.Snapshot) -> types.UInt256: return snapshot.persisting_block.hash() - def current_index(self, snapshot) -> int: + @register("currentIndex", contracts.CallFlags.READ_STATES, cpu_price=1 << 15) + def current_index(self, snapshot: storage.Snapshot) -> int: return snapshot.best_block_height + @register("getBlock", contracts.CallFlags.READ_STATES, cpu_price=1 << 15) def get_block(self, snapshot: storage.Snapshot, index_or_hash: bytes) -> Optional[payloads.TrimmedBlock]: if len(index_or_hash) < types.UInt256._BYTE_LEN: height = vm.BigInteger(index_or_hash) @@ -68,18 +37,21 @@ def get_block(self, snapshot: storage.Snapshot, index_or_hash: bytes) -> Optiona block = None return block.trim() + @register("getTransaction", contracts.CallFlags.READ_STATES, cpu_price=1 << 15) def get_tx_for_contract(self, snapshot: storage.Snapshot, hash_: types.UInt256) -> Optional[payloads.Transaction]: tx = snapshot.transactions.try_get(hash_, read_only=True) if tx is None or not self._is_traceable_block(snapshot, tx.block_height): return None return tx + @register("getTransactionheight", contracts.CallFlags.READ_STATES, cpu_price=1 << 15) def get_tx_height(self, snapshot: storage.Snapshot, hash_: types.UInt256) -> int: tx = snapshot.transactions.try_get(hash_, read_only=True) if tx is None or not self._is_traceable_block(snapshot, tx.block_height): return -1 return tx.block_height + @register("getTransactionFromBlock", contracts.CallFlags.READ_STATES, cpu_price=1 << 16) def get_tx_from_block(self, snapshot: storage.Snapshot, block_index_or_hash: bytes, @@ -97,10 +69,15 @@ def get_tx_from_block(self, if block and not self._is_traceable_block(snapshot, block.index): block = None - if tx_index < 0 or tx_index > len(block.transactions) - 1: + if tx_index < 0 or tx_index >= len(block.transactions): raise ValueError("Transaction index out of range") return block.transactions[tx_index] + def on_persist(self, engine: contracts.ApplicationEngine) -> None: + # Unlike C# the current block or its transactions are not persisted here, + # it is still done in the Blockchain class in the persist() function + pass + def _is_traceable_block(self, snapshot: storage.Snapshot, index: int) -> bool: current_idx = self.current_index(snapshot) if index > current_idx: diff --git a/neo3/contracts/native/management.py b/neo3/contracts/native/management.py index 492be9c4..7d4d9e2e 100644 --- a/neo3/contracts/native/management.py +++ b/neo3/contracts/native/management.py @@ -1,6 +1,6 @@ from __future__ import annotations import json -from . import NativeContract +from . import NativeContract, register from typing import Optional from neo3 import storage, contracts, vm from neo3.core import to_script_hash, types, msgrouter @@ -17,48 +17,6 @@ class ManagementContract(NativeContract): def init(self): super(ManagementContract, self).init() - self._register_contract_method(self.get_contract, - "getContract", - 1000000, - parameter_names=["contract_hash"], - call_flags=contracts.CallFlags.READ_STATES) - self._register_contract_method(self.contract_create, - "deploy", - 0, - parameter_names=["nef_file", "manifest"], - call_flags=(contracts.CallFlags.WRITE_STATES - | contracts.CallFlags.ALLOW_NOTIFY) - ) - self._register_contract_method(self.contract_create_with_data, - "deploy", - 0, - parameter_names=["nef_file", "manifest", "data"], - call_flags=(contracts.CallFlags.WRITE_STATES - | contracts.CallFlags.ALLOW_NOTIFY) - ) - self._register_contract_method(self.contract_update, - "update", - 0, - parameter_names=["nef_file", "manifest", "data"], - call_flags=(contracts.CallFlags.WRITE_STATES - | contracts.CallFlags.ALLOW_NOTIFY) - ) - self._register_contract_method(self.contract_destroy, - "destroy", - 1000000, - call_flags=(contracts.CallFlags.WRITE_STATES - | contracts.CallFlags.ALLOW_NOTIFY) - ) - self._register_contract_method(self.get_minimum_deployment_fee, - "getMinimumDeploymentFee", - 1000000, - call_flags=contracts.CallFlags.READ_STATES) - self._register_contract_method(self._set_minimum_deployment_fee, - "setMinimumDeploymentFee", - 3000000, - parameter_names=["new_fee"], - call_flags=contracts.CallFlags.WRITE_STATES) - self.manifest.abi.events = [ contracts.ContractEventDescriptor( "Deploy", @@ -87,33 +45,18 @@ def _initialize(self, engine: contracts.ApplicationEngine) -> None: ) engine.snapshot.storages.put(self.key_next_id, storage.StorageItem(vm.BigInteger(1).to_array())) - def get_next_available_id(self, snapshot: storage.Snapshot) -> int: - si = snapshot.storages.get(self.key_next_id, read_only=False) - value = vm.BigInteger(si.value) - si.value = (value + 1).to_array() - return int(value) - - def on_persist(self, engine: contracts.ApplicationEngine) -> None: - # NEO implicitely expects a certain order of contract initialization - # Native contracts have negative values for `id`, so we reverse the results - sorted_contracts = sorted(self.registered_contracts, key=lambda contract: contract.id, reverse=True) - for contract in sorted_contracts: - if contract.active_block_index != engine.snapshot.persisting_block.index: - continue - engine.snapshot.contracts.put( - contracts.ContractState(contract.id, contract.nef, contract.manifest, 0, contract.hash) - ) - contract._initialize(engine) - + @register("getContract", contracts.CallFlags.READ_STATES, cpu_price=1 << 15) def get_contract(self, snapshot: storage.Snapshot, hash_: types.UInt160) -> Optional[contracts.ContractState]: return snapshot.contracts.try_get(hash_, read_only=True) + @register("deploy", contracts.CallFlags.STATES | contracts.CallFlags.ALLOW_NOTIFY) def contract_create(self, engine: contracts.ApplicationEngine, nef_file: bytes, manifest: bytes) -> contracts.ContractState: return self.contract_create_with_data(engine, nef_file, manifest, vm.NullStackItem()) + @register("deploy", contracts.CallFlags.STATES | contracts.CallFlags.ALLOW_NOTIFY) def contract_create_with_data(self, engine: contracts.ApplicationEngine, nef_file: bytes, @@ -167,17 +110,25 @@ def contract_create_with_data(self, ) return contract + @register("update", contracts.CallFlags.STATES | contracts.CallFlags.ALLOW_NOTIFY) def contract_update(self, engine: contracts.ApplicationEngine, nef_file: bytes, - manifest: bytes, - data: vm.StackItem) -> None: + manifest: bytes) -> None: + self.contract_update_with_data(engine, nef_file, manifest, vm.NullStackItem()) + + @register("update", contracts.CallFlags.STATES | contracts.CallFlags.ALLOW_NOTIFY) + def contract_update_with_data(self, + engine: contracts.ApplicationEngine, + nef_file: bytes, + manifest: bytes, + data: vm.StackItem) -> None: nef_len = len(nef_file) manifest_len = len(manifest) engine.add_gas(engine.STORAGE_PRICE * (nef_len + manifest_len)) - contract = engine.snapshot.contracts.try_get(engine.current_scripthash, read_only=False) + contract = engine.snapshot.contracts.try_get(engine.calling_scripthash, read_only=False) if contract is None: raise ValueError("Can't find contract to update") @@ -216,16 +167,17 @@ def contract_update(self, ) ) + @register("destroy", contracts.CallFlags.STATES | contracts.CallFlags.ALLOW_NOTIFY, cpu_price=1 << 15) def contract_destroy(self, engine: contracts.ApplicationEngine) -> None: - hash_ = engine.current_scripthash + hash_ = engine.calling_scripthash contract = engine.snapshot.contracts.try_get(hash_) if contract is None: return - engine.snapshot.storages.delete(hash_) + engine.snapshot.contracts.delete(hash_) - for key, _ in engine.snapshot.storages.find(contract.id.to_bytes(4, 'little', signed=True), b''): + for key, _ in engine.snapshot.storages.find(contract.id.to_bytes(4, 'little', signed=True)): engine.snapshot.storages.delete(key) msgrouter.interop_notify(self.hash, @@ -235,9 +187,11 @@ def contract_destroy(self, engine: contracts.ApplicationEngine) -> None: ) ) + @register("getMinimumDeploymentFee", contracts.CallFlags.READ_STATES, cpu_price=1 << 15) def get_minimum_deployment_fee(self, snapshot: storage.Snapshot) -> int: - return int.from_bytes(snapshot.storages[self.key_min_deploy_fee].value, 'little') + return int.from_bytes(snapshot.storages.get(self.key_min_deploy_fee, read_only=True).value, 'little') + @register("setMinimumDeploymentFee", contracts.CallFlags.STATES, cpu_price=1 << 15) def _set_minimum_deployment_fee(self, engine: contracts.ApplicationEngine, value: int) -> None: if value < 0: raise ValueError("Can't set deployment fee to a negative value") @@ -245,6 +199,24 @@ def _set_minimum_deployment_fee(self, engine: contracts.ApplicationEngine, value raise ValueError engine.snapshot.storages.update(self.key_min_deploy_fee, storage.StorageItem(vm.BigInteger(value).to_array())) + def get_next_available_id(self, snapshot: storage.Snapshot) -> int: + si = snapshot.storages.get(self.key_next_id, read_only=False) + value = vm.BigInteger(si.value) + si.value = (value + 1).to_array() + return int(value) + + def on_persist(self, engine: contracts.ApplicationEngine) -> None: + # NEO implicitely expects a certain order of contract initialization + # Native contracts have negative values for `id`, so we reverse the results + sorted_contracts = sorted(self.registered_contracts, key=lambda contract: contract.id, reverse=True) + for contract in sorted_contracts: + if contract.active_block_index != engine.snapshot.persisting_block.index: + continue + engine.snapshot.contracts.put( + contracts.ContractState(contract.id, contract.nef, contract.manifest, 0, contract.hash) + ) + contract._initialize(engine) + def validate(self, script: bytes, abi: contracts.ContractABI): s = vm.Script(script, True) for method in abi.methods: diff --git a/neo3/contracts/native/nameservice.py b/neo3/contracts/native/nameservice.py index 1faa1344..a24ad57a 100644 --- a/neo3/contracts/native/nameservice.py +++ b/neo3/contracts/native/nameservice.py @@ -4,6 +4,7 @@ import ipaddress from enum import IntEnum from .nonfungible import NFTState, NonFungibleToken +from . import register from typing import Optional, Iterator, Tuple from neo3 import contracts, storage, vm from neo3.core import serialization, types @@ -20,10 +21,9 @@ class NameState(NFTState): def __init__(self, owner: types.UInt160, name: str, - description: str, expiration: int, admin: Optional[types.UInt160] = None): - super(NameState, self).__init__(owner, name, description) + super(NameState, self).__init__(owner, name) self.expiration = expiration self.admin = admin if admin else types.UInt160.zero() self.id = name.encode() @@ -43,7 +43,7 @@ def deserialize(self, reader: serialization.BinaryReader) -> None: @classmethod def _serializable_init(cls): - return cls(types.UInt160.zero(), "", "", 0, types.UInt160.zero()) + return cls(types.UInt160.zero(), "", 0, types.UInt160.zero()) class StringList(list, serialization.ISerializable): @@ -61,7 +61,7 @@ def deserialize(self, reader: serialization.BinaryReader) -> None: class NameService(NonFungibleToken): - _id = -8 + _id = -10 _symbol = "NNS" _service_name = None @@ -76,42 +76,13 @@ class NameService(NonFungibleToken): def init(self): super(NameService, self).init() - self._register_contract_method(self.add_root, - "addRoot", - 3000000, - parameter_names=["root"], - call_flags=contracts.CallFlags.WRITE_STATES - ) - self._register_contract_method(self.set_price, - "setPrice", - 3000000, - parameter_names=["price"], - call_flags=contracts.CallFlags.WRITE_STATES) - self._register_contract_method(self.register, - "register", - 1000000, - parameter_names=["name", "owner"], - call_flags=contracts.CallFlags.WRITE_STATES - ) def _initialize(self, engine: contracts.ApplicationEngine) -> None: super(NameService, self)._initialize(engine) engine.snapshot.storages.put(self.key_domain_price, storage.StorageItem(vm.BigInteger(1000000000).to_array())) engine.snapshot.storages.put(self.key_roots, storage.StorageItem(b'\x00')) - def on_persist(self, engine: contracts.ApplicationEngine) -> None: - now = (engine.snapshot.persisting_block.timestamp // 1000) + 1 - start = (self.key_expiration + self._to_uint32(0)).to_array() - end = (self.key_expiration + self._to_uint32(now)).to_array() - for key, _ in engine.snapshot.storages.find_range(start, end): - engine.snapshot.storages.delete(key) - for key2, _ in engine.snapshot.storages.find(self.key_record + key.key[5:]).to_array(): - engine.snapshot.storages.delete(key2) - self.burn(engine, self.key_token + key.key[5:]) - - def on_transferred(self, engine: contracts.ApplicationEngine, from_account: types.UInt160, token: NFTState) -> None: - token.owner = types.UInt160.zero() - + @register("addRoot", contracts.CallFlags.STATES, cpu_price=1 << 15) def add_root(self, engine: contracts.ApplicationEngine, root: str) -> None: if not self.REGEX_ROOT.match(root): raise ValueError("Regex failure - root not found") @@ -123,6 +94,7 @@ def add_root(self, engine: contracts.ApplicationEngine, root: str) -> None: raise ValueError("The name already exists") roots.append(root) + @register("setPrice", contracts.CallFlags.STATES, cpu_price=1 << 15) def set_price(self, engine: contracts.ApplicationEngine, price: int) -> None: if price <= 0 or price > 10000_00000000: raise ValueError(f"New price '{price}' exceeds limits") @@ -131,9 +103,11 @@ def set_price(self, engine: contracts.ApplicationEngine, price: int) -> None: storage_item = engine.snapshot.storages.get(self.key_domain_price, read_only=False) storage_item.value = price.to_bytes(8, 'little') + @register("getPrice", contracts.CallFlags.READ_STATES, cpu_price=1 << 15) def get_price(self, snapshot: storage.Snapshot) -> int: return int.from_bytes(snapshot.storages.get(self.key_domain_price, read_only=True).value, 'little') + @register("isAvailable", contracts.CallFlags.READ_STATES, cpu_price=1 << 15) def is_available(self, snapshot: storage.Snapshot, name: str) -> bool: if not self.REGEX_NAME.match(name): raise ValueError("Regex failure - name is not valid") @@ -149,7 +123,8 @@ def is_available(self, snapshot: storage.Snapshot, name: str) -> bool: raise ValueError(f"'{names[1]}' is not a registered root") return True - def register(self, engine: contracts.ApplicationEngine, name: str, owner: types.UInt160) -> bool: + @register("register", contracts.CallFlags.STATES, cpu_price=1 << 15) + def do_register(self, engine: contracts.ApplicationEngine, name: str, owner: types.UInt160) -> bool: if not self.is_available(engine.snapshot, name): raise ValueError(f"Registration failure - '{name}' is not available") @@ -157,7 +132,7 @@ def register(self, engine: contracts.ApplicationEngine, name: str, owner: types. raise ValueError("CheckWitness failed") engine.add_gas(self.get_price(engine.snapshot)) - state = NameState(owner, name, "", (engine.snapshot.persisting_block.timestamp // 1000) + self.ONE_YEAR) + state = NameState(owner, name, (engine.snapshot.persisting_block.timestamp // 1000) + self.ONE_YEAR) self.mint(engine, state) engine.snapshot.storages.put( self.key_expiration + state.expiration.to_bytes(4, 'big') + name.encode(), @@ -165,6 +140,7 @@ def register(self, engine: contracts.ApplicationEngine, name: str, owner: types. ) return True + @register("renew", contracts.CallFlags.STATES, cpu_price=1 << 15) def renew(self, engine: contracts.ApplicationEngine, name: str) -> int: if not self.REGEX_NAME.match(name): raise ValueError("Regex failure - name is not valid") @@ -176,6 +152,7 @@ def renew(self, engine: contracts.ApplicationEngine, name: str) -> int: state.expiration += self.ONE_YEAR return state.expiration + @register("setAdmin", contracts.CallFlags.STATES, cpu_price=1 << 15, storage_price=20) def set_admin(self, engine: contracts.ApplicationEngine, name: str, admin: types.UInt160) -> None: if not self.REGEX_NAME.match(name): raise ValueError("Regex failure - name is not valid") @@ -194,6 +171,7 @@ def set_admin(self, engine: contracts.ApplicationEngine, name: str, admin: types state.admin = admin + @register("setRecord", contracts.CallFlags.STATES, cpu_price=1 << 15, storage_price=200) def set_record(self, engine: contracts.ApplicationEngine, name: str, record_type: RecordType, data: str) -> None: if not self.REGEX_NAME.match(name): raise ValueError("Regex failure - name is not valid") @@ -220,6 +198,7 @@ def set_record(self, engine: contracts.ApplicationEngine, name: str, record_type storage_key_record = self.key_record + domain.encode() + name.encode() + record_type.to_bytes(1, 'little') engine.snapshot.storages.update(storage_key_record, storage.StorageItem(data.encode())) + @register("getRecord", contracts.CallFlags.READ_STATES, cpu_price=1 << 15) def get_record(self, snapshot: storage.Snapshot, name: str, record_type: RecordType) -> Optional[str]: if not self.REGEX_NAME.match(name): raise ValueError("Regex failure - name is not valid") @@ -230,15 +209,7 @@ def get_record(self, snapshot: storage.Snapshot, name: str, record_type: RecordT return None return storage_item.value.decode() - def get_records(self, snapshot: storage.Snapshot, name: str) -> Iterator[Tuple[RecordType, str]]: - if not self.REGEX_NAME.match(name): - raise ValueError("Regex failure - name is not valid") - domain = '.'.join(name.split('.')[2:]) - storage_key = self.key_record + domain.encode() + name.encode() - for key, value in snapshot.storages.find(storage_key.to_array()): - record_type = RecordType(int.from_bytes(key.key[-1], 'little')) - yield record_type, value.value.decode() - + @register("deleteRecord", contracts.CallFlags.STATES, cpu_price=1 << 15) def delete_record(self, engine: contracts.ApplicationEngine, name: str, record_type: RecordType) -> None: if not self.REGEX_NAME.match(name): raise ValueError("Regex failure - name is not valid") @@ -252,6 +223,7 @@ def delete_record(self, engine: contracts.ApplicationEngine, name: str, record_t storage_key_record = self.key_record + domain.encode() + name.encode() + record_type.to_bytes(1, 'little') engine.snapshot.storages.delete(storage_key_record) + @register("resolve", contracts.CallFlags.READ_STATES, cpu_price=1 << 17) def resolve(self, snapshot: storage.Snapshot, name: str, @@ -269,6 +241,28 @@ def resolve(self, return None return self.resolve(snapshot, data, record_type, redirect_count - 1) + def get_records(self, snapshot: storage.Snapshot, name: str) -> Iterator[Tuple[RecordType, str]]: + if not self.REGEX_NAME.match(name): + raise ValueError("Regex failure - name is not valid") + domain = '.'.join(name.split('.')[2:]) + storage_key = self.key_record + domain.encode() + name.encode() + for key, value in snapshot.storages.find(storage_key.to_array()): + record_type = RecordType(int.from_bytes(key.key[-1], 'little')) + yield record_type, value.value.decode() + + def on_persist(self, engine: contracts.ApplicationEngine) -> None: + now = (engine.snapshot.persisting_block.timestamp // 1000) + 1 + start = (self.key_expiration + self._to_uint32(0)).to_array() + end = (self.key_expiration + self._to_uint32(now)).to_array() + for key, _ in engine.snapshot.storages.find_range(start, end): + engine.snapshot.storages.delete(key) + for key2, _ in engine.snapshot.storages.find(self.key_record + key.key[5:]).to_array(): + engine.snapshot.storages.delete(key2) + self.burn(engine, self.key_token + key.key[5:]) + + def on_transferred(self, engine: contracts.ApplicationEngine, from_account: types.UInt160, token: NFTState) -> None: + token.owner = types.UInt160.zero() + def _check_admin(self, engine: contracts.ApplicationEngine, state: NameState) -> bool: if engine.checkwitness(state.owner): return True diff --git a/neo3/contracts/native/nativecontract.py b/neo3/contracts/native/nativecontract.py index c1cdb783..67e296db 100644 --- a/neo3/contracts/native/nativecontract.py +++ b/neo3/contracts/native/nativecontract.py @@ -1,30 +1,52 @@ from __future__ import annotations -from typing import List, Callable, Dict, Tuple, Any, Optional, get_type_hints +import inspect +from typing import List, Callable, Dict, Any, Optional, get_type_hints from neo3 import contracts, vm, storage, settings from neo3.core import types, to_script_hash from neo3.network import convenience -class _ContractMethodMetadata: - """ - Internal helper class containing meta data that helps in translating VM Stack Items to the arguments types of the - handling function. Applies to native contracts only. - """ - - def __init__(self, handler: Callable[..., None], - price: int, - required_flags: contracts.CallFlags, - add_engine: bool, - add_snapshot: bool, - return_type, - parameter_types=None): - self.handler = handler - self.price = price - self.return_type = return_type - self.parameters = parameter_types if parameter_types else [] - self.required_flag = required_flags - self.add_engine = add_engine - self.add_snapshot = add_snapshot +class _NativeMethodMeta: + def __init__(self, func: Callable): + self.handler = func + self.name: str = func.name # type:ignore + self.cpu_price: int = func.cpu_price # type: ignore + self.storage_price: int = func.storage_price # type: ignore + self.required_flags: contracts.CallFlags = func.flags # type: ignore + self.add_engine = False + self.add_snapshot = False + self.return_type = None + + parameter_types = [] + parameter_names = [] + for k, v in get_type_hints(func).items(): + if k == 'return': + if v != type(None): + self.return_type = v + continue + if v == contracts.ApplicationEngine: + self.add_engine = True + continue + elif v == storage.Snapshot: + self.add_snapshot = True + continue + parameter_types.append(v) + parameter_names.append(k) + + params = [] + for t, n in zip(parameter_types, parameter_names): + params.append(contracts.ContractParameterDefinition( + name=n, + type=contracts.ContractParameterType.from_type(t) + )) + self.parameter_types = parameter_types + self.descriptor = contracts.ContractMethodDescriptor( + name=func.name, # type: ignore + offset=0, + return_type=contracts.ContractParameterType.from_type(self.return_type), + parameters=params, + safe=(func.flags & ~contracts.CallFlags.READ_ONLY) == 0 # type: ignore + ) class NativeContract(convenience._Singleton): @@ -36,12 +58,14 @@ class NativeContract(convenience._Singleton): #: A dictionary for accessing a native contract by its hash _contract_hashes: Dict[types.UInt160, NativeContract] = {} + #: Allows for overriding the contract name in the ABI. Otherwise the name equals the class name. _service_name: Optional[str] = None + #: The block index at which the native contract becomes active. active_block_index = 0 def init(self): - self._methods: Dict[Tuple[str, int], _ContractMethodMetadata] = {} + self._methods: Dict[int, _NativeMethodMeta] = {} # offset, meta self._management = contracts.ManagementContract() self._neo = contracts.NeoToken() @@ -51,22 +75,39 @@ def init(self): self._oracle = contracts.OracleContract() self._ledger = contracts.LedgerContract() self._role = contracts.DesignationContract() + self._crypto = contracts.CryptoContract() + self._stdlib = contracts.StdLibContract() + + # Find all methods that have been augmented by the @register decorator + # and turn them into methods that can be called by VM scripts + methods_meta = [] + for pair in inspect.getmembers(self, lambda m: hasattr(m, "native_call")): + methods_meta.append(_NativeMethodMeta(pair[1])) + + methods_meta.sort(key=lambda x: (x.descriptor.name, len(x.descriptor.parameters))) sb = vm.ScriptBuilder() - sb.emit_push(self.id) - sb.emit_syscall(1736177434) # "System.Contract.CallNative" + for meta in methods_meta: + meta.descriptor.offset = len(sb) + sb.emit_push(0) + self._methods.update({len(sb): meta}) + sb.emit_syscall(1736177434) # "System.Contract.CallNative" + sb.emit(vm.OpCode.RET) + self._script: bytes = sb.to_array() self.nef = contracts.NEF("neo-core-v3.0", self._script) + sender = types.UInt160.zero() # OpCode.PUSH1 sb = vm.ScriptBuilder() sb.emit(vm.OpCode.ABORT) sb.emit_push(sender.to_array()) - sb.emit_push(self.nef.checksum) + sb.emit_push(0) sb.emit_push(self.service_name()) self._hash: types.UInt160 = to_script_hash(sb.to_array()) self._manifest: contracts.ContractManifest = contracts.ContractManifest() self._manifest.name = self.service_name() - self._manifest.abi.methods = [] + self._manifest.abi.methods = list(map(lambda m: m.descriptor, methods_meta)) + if self._id != NativeContract._id: self._contracts.update({self.service_name(): self}) self._contract_hashes.update({self._hash: self}) @@ -86,76 +127,17 @@ def get_contract_by_name(cls, name: str) -> Optional[NativeContract]: @classmethod def get_contract_by_id(cls, contract_id: int) -> Optional[NativeContract]: + """ Get the native contract by its service id """ for contract in cls._contracts.values(): if contract_id == contract.id: return contract else: return None - def _register_contract_method(self, - func: Callable, - func_name: str, - price: int, - parameter_names: List[str] = None, - call_flags: contracts.CallFlags = contracts.CallFlags.NONE - ) -> None: - """ - Registers a native contract method into the manifest - - Args: - func: func pointer. - func_name: the name of the callable function. - price: the cost of calling the function. - parameter_names: the function argument names. - """ - return_type = None - parameter_types = [] - for k, v in get_type_hints(func).items(): - if k == 'return': - if v != type(None): - return_type = v - continue - parameter_types.append(v) - - add_engine = False - add_snapshot = False - if len(parameter_types) > 0: - if parameter_types[0] == contracts.ApplicationEngine: - add_engine = True - parameter_types = parameter_types[1:] - elif parameter_types[0] == storage.Snapshot: - add_snapshot = True - parameter_types = parameter_types[1:] - - params = [] - - if parameter_types and parameter_names is None: - raise ValueError(f"Found parameters types but missing parameter names for {self} {func_name}") - - if parameter_names: - if len(parameter_types) != len(parameter_names): - raise ValueError(f"Parameter types count must match parameter names count! " - f"{len(parameter_types)}!={len(parameter_names)}") - - for t, n in zip(parameter_types, parameter_names): - params.append(contracts.ContractParameterDefinition( - name=n, - type=contracts.ContractParameterType.from_type(t) - )) - - self._manifest.abi.methods.append( - contracts.ContractMethodDescriptor( - name=func_name, - offset=0, - return_type=contracts.ContractParameterType.from_type(return_type), - parameters=params, - safe=(call_flags & ~contracts.CallFlags.READ_ONLY) == 0 - ) - ) - - self._methods.update({(func_name, len(params)): _ContractMethodMetadata( - func, price, call_flags, add_engine, add_snapshot, return_type, parameter_types) - }) + @classmethod + def get_contract_by_hash(cls, contract_hash: types.UInt160) -> Optional[NativeContract]: + """ Get the native contract by its contract hash """ + return cls._contract_hashes.get(contract_hash, None) @property def registered_contract_names(self) -> List[str]: @@ -194,15 +176,7 @@ def manifest(self) -> contracts.ContractManifest: """ The associated contract manifest. """ return self._manifest - def _initialize(self, engine: contracts.ApplicationEngine) -> None: - """ - Called once when a native contract is deployed - - Args: - engine: ApplicationEngine - """ - - def invoke(self, engine: contracts.ApplicationEngine) -> None: + def invoke(self, engine: contracts.ApplicationEngine, version: int) -> None: """ Calls a contract function @@ -210,26 +184,28 @@ def invoke(self, engine: contracts.ApplicationEngine) -> None: Args: engine: the engine executing the smart contract + version: which version of the smart contract to load Raises: - SystemError: if not called via `System.Contract.Call` + ValueError: if the request contract version is not ValueError: if the function to be called does not exist on the contract ValueError: if trying to call a function without having the correct CallFlags """ - if engine.current_scripthash != self.hash: - raise SystemError("It is not allowed to use Neo.Native.Call directly, use System.Contract.Call") + if version != 0: + raise ValueError(f"Native contract version {version} is not active") # type: ignore context = engine.current_context - operation = context.evaluation_stack.pop().to_array().decode() - flags = contracts.CallFlags(context.call_flags) - method = self._methods.get((operation, len(context.evaluation_stack)), None) + method = self._methods.get(context.ip, None) if method is None: - raise ValueError(f"Method \"{operation}\" does not exist on contract {self.service_name()}") - if method.required_flag not in flags: - raise ValueError(f"Method requires call flag: {method.required_flag} received: {flags}") + raise ValueError(f"Method at IP \"{context.ip}\" does not exist on contract {self.service_name()}") + if method.required_flags not in flags: + raise ValueError(f"Method requires call flag: {method.required_flags} received: {flags}") - engine.add_gas(method.price) + engine.add_gas(method.cpu_price + * contracts.PolicyContract().get_exec_fee_factor(engine.snapshot) + + method.storage_price + * contracts.PolicyContract().get_storage_price(engine.snapshot)) params: List[Any] = [] if method.add_engine: @@ -238,8 +214,8 @@ def invoke(self, engine: contracts.ApplicationEngine) -> None: if method.add_snapshot: params.append(engine.snapshot) - for i in range(len(method.parameters)): - params.append(engine._stackitem_to_native(context.evaluation_stack.pop(), method.parameters[i])) + for t in method.parameter_types: + params.append(engine._stackitem_to_native(context.evaluation_stack.pop(), t)) if len(params) > 0: return_value = method.handler(*params) @@ -273,9 +249,23 @@ def on_persist(self, engine: contracts.ApplicationEngine) -> None: def post_persist(self, engine: contracts.ApplicationEngine): pass + def create_key(self, prefix: bytes) -> storage.StorageKey: + """ + Helper to create a storage key for the contract + + Args: + prefix: the storage prefix to be used + """ + return storage.StorageKey(self._id, prefix) + + def _initialize(self, engine: contracts.ApplicationEngine) -> None: + """ + Called once when a native contract is deployed + + Args: + engine: ApplicationEngine + """ + def _check_committee(self, engine: contracts.ApplicationEngine) -> bool: addr = contracts.NeoToken().get_committee_address(engine.snapshot) return engine.checkwitness(addr) - - def create_key(self, prefix: bytes) -> storage.StorageKey: - return storage.StorageKey(self._id, prefix) diff --git a/neo3/contracts/native/nonfungible.py b/neo3/contracts/native/nonfungible.py index 63993ccb..f1da776c 100644 --- a/neo3/contracts/native/nonfungible.py +++ b/neo3/contracts/native/nonfungible.py @@ -1,17 +1,16 @@ from __future__ import annotations from contextlib import suppress from typing import List, cast, Optional -from . import NativeContract, FungibleTokenStorageState +from . import NativeContract, FungibleTokenStorageState, register from neo3 import storage, contracts, vm from neo3.core import serialization, IInteroperable, types, msgrouter from neo3.contracts import interop class NFTState(IInteroperable, serialization.ISerializable): - def __init__(self, owner: types.UInt160, name: str, description: str): + def __init__(self, owner: types.UInt160, name: str): self.owner = owner self.name = name - self.description = description # I don't understand where this ID is coming from as its abstract in C# and not overridden # we'll probably figure out once we implement the name service in a later PR self.id: bytes = b'' @@ -22,31 +21,28 @@ def from_stack_item(cls, stack_item: vm.StackItem): owner = types.UInt160(stack_item[0].to_array()) name = stack_item[1].to_array().decode() description = stack_item[2].to_array().decode() - return cls(owner, name, description) + return cls(owner, name) def to_stack_item(self, reference_counter: vm.ReferenceCounter) -> vm.StackItem: return vm.StructStackItem(reference_counter, [ vm.ByteStringStackItem(self.owner.to_array()), vm.ByteStringStackItem(self.name), - vm.ByteStringStackItem(self.description) ]) def serialize(self, writer: serialization.BinaryWriter) -> None: writer.write_serializable(self.owner) writer.write_var_string(self.name) - writer.write_var_string(self.description) def deserialize(self, reader: serialization.BinaryReader) -> None: self.owner = reader.read_serializable(types.UInt160) self.name = reader.read_var_string() - self.description = reader.read_var_string() def to_json(self) -> dict: - return {"name": self.name, "description": self.description} + return {"name": self.name} @classmethod def _serializable_init(cls): - return cls(types.UInt160.zero(), "", "") + return cls(types.UInt160.zero(), "") class NFTAccountState(FungibleTokenStorageState): @@ -100,104 +96,38 @@ def init(self): ) ] - self._register_contract_method(self.total_supply, - "totalSupply", - 1000000, - call_flags=contracts.CallFlags.READ_STATES) - - self._register_contract_method(self.owner_of, - "ownerOf", - 1000000, - parameter_names=["token_id"], - call_flags=contracts.CallFlags.READ_STATES) - - self._register_contract_method(self.properties, - "properties", - 1000000, - parameter_names=["token_id"], - call_flags=contracts.CallFlags.READ_STATES) - - self._register_contract_method(self.balance_of, - "balanceOf", - 1000000, - parameter_names=["owner"], - call_flags=contracts.CallFlags.READ_STATES) - self._register_contract_method(self.transfer, - "transfer", - 9000000, - parameter_names=["to", "tokenId"], - call_flags=(contracts.CallFlags.WRITE_STATES - | contracts.CallFlags.ALLOW_NOTIFY)) - self._register_contract_method(self.tokens, - "tokens", - 1000000, - call_flags=contracts.CallFlags.READ_STATES) - self._register_contract_method(self.tokens_of, - "tokensOf", - 1000000, - parameter_names=["owner"], - call_flags=contracts.CallFlags.READ_STATES) - def _initialize(self, engine: contracts.ApplicationEngine) -> None: engine.snapshot.storages.put(self.key_total_suppply, storage.StorageItem(b'\x00')) - def mint(self, engine: contracts.ApplicationEngine, token: NFTState) -> None: - engine.snapshot.storages.put(self.key_token + token.id, storage.StorageItem(token.to_array())) - sk_account = self.key_account + token.id - si_account = engine.snapshot.storages.try_get(sk_account, read_only=False) - - if si_account is None: - si_account = storage.StorageItem(NFTAccountState().to_array()) - engine.snapshot.storages.put(sk_account, si_account) - - account = si_account.get(NFTAccountState) - account.add(token.id) - - si_total_supply = engine.snapshot.storages.get(self.key_total_suppply, read_only=False) - new_value = vm.BigInteger(si_total_supply.value) + 1 - si_total_supply.value = new_value.to_array() - - self._post_transfer(engine, types.UInt160.zero(), token.owner, token.id) - - def burn(self, engine: contracts.ApplicationEngine, token_id: bytes) -> None: - key_token = self.key_token + token_id - si_token = engine.snapshot.storages.try_get(key_token, read_only=True) - if si_token is None: - raise ValueError("Token cannot be found") - token = NFTState.deserialize_from_bytes(si_token.value) - engine.snapshot.storages.delete(key_token) - - key_account = self.key_account + token.owner.to_array() - account_state = engine.snapshot.storages.get(key_account).get(NFTAccountState) - account_state.remove(token_id) - - if account_state.balance == 0: - engine.snapshot.storages.delete(key_account) - - si_total_supply = engine.snapshot.storages.get(self.key_total_suppply) - new_value = vm.BigInteger(si_total_supply.value) + 1 - si_total_supply.value = new_value.to_array() - - self._post_transfer(engine, token.owner, types.UInt160.zero(), token_id) - + @register("totalSupply", contracts.CallFlags.READ_STATES, cpu_price=1 << 15) def total_supply(self, snapshot: storage.Snapshot) -> vm.BigInteger: storage_item = snapshot.storages.get(self.key_total_suppply) return vm.BigInteger(storage_item.value) + @register("ownerOf", contracts.CallFlags.READ_STATES, cpu_price=1 << 15) def owner_of(self, snapshot: storage.Snapshot, token_id: bytes) -> types.UInt160: storage_item = snapshot.storages.get(self.key_token + token_id, read_only=True) return NFTState.from_stack_item(storage_item).owner - def properties(self, snapshot: storage.Snapshot, token_id: bytes) -> dict: - storage_item = snapshot.storages.get(self.key_token + token_id, read_only=True) - return NFTState.deserialize_from_bytes(storage_item.value).to_json() + @register("properties", contracts.CallFlags.READ_STATES, cpu_price=1 << 15) + def properties(self, engine: contracts.ApplicationEngine, token_id: bytes) -> vm.MapStackItem: + storage_item = engine.snapshot.storages.get(self.key_token + token_id, read_only=True) + map_ = vm.MapStackItem(engine.reference_counter) + for k, v in NFTState.deserialize_from_bytes(storage_item.value).to_json(): + map_[k] = v + return map_ + @register("balanceOf", contracts.CallFlags.READ_STATES, cpu_price=1 << 15) def balance_of(self, snapshot: storage.Snapshot, owner: types.UInt160) -> vm.BigInteger: storage_item = snapshot.storages.try_get(self.key_account + owner.to_array(), read_only=True) if storage_item is None: return vm.BigInteger.zero() return NFTAccountState.deserialize_from_bytes(storage_item.value).balance + @register("transfer", + contracts.CallFlags.STATES | contracts.CallFlags.ALLOW_CALL | contracts.CallFlags.ALLOW_NOTIFY, + cpu_price=1 << 17, + storage_price=50) def transfer(self, engine: contracts.ApplicationEngine, account_to: types.UInt160, token_id: bytes) -> bool: if account_to == types.UInt160.zero(): raise ValueError("To account can't be zero") @@ -228,6 +158,7 @@ def transfer(self, engine: contracts.ApplicationEngine, account_to: types.UInt16 self._post_transfer(engine, token_state.owner, account_to, token_id) return True + @register("tokens", contracts.CallFlags.READ_STATES, cpu_price=1 << 15) def tokens(self, snapshot: storage.Snapshot) -> interop.IIterator: result = snapshot.storages.find(self.key_token.to_array()) options = contracts.FindOptions @@ -237,6 +168,7 @@ def tokens(self, snapshot: storage.Snapshot) -> interop.IIterator: options.VALUES_ONLY | options.DESERIALIZE_VALUES | options.PICK_FIELD1, reference_counter) + @register("tokensOf", contracts.CallFlags.READ_STATES, cpu_price=1 << 15) def tokens_of(self, snapshot: storage.Snapshot, owner: types.UInt160) -> interop.IIterator: storage_item_account = snapshot.storages.try_get(self.key_account + owner.to_array(), read_only=True) reference_counter = vm.ReferenceCounter() @@ -246,6 +178,45 @@ def tokens_of(self, snapshot: storage.Snapshot, owner: types.UInt160) -> interop tokens: List[vm.StackItem] = list(map(lambda t: vm.ByteStringStackItem(t), account.tokens)) return interop.ArrayWrapper(vm.ArrayStackItem(reference_counter, tokens)) + def mint(self, engine: contracts.ApplicationEngine, token: NFTState) -> None: + engine.snapshot.storages.put(self.key_token + token.id, storage.StorageItem(token.to_array())) + sk_account = self.key_account + token.id + si_account = engine.snapshot.storages.try_get(sk_account, read_only=False) + + if si_account is None: + si_account = storage.StorageItem(NFTAccountState().to_array()) + engine.snapshot.storages.put(sk_account, si_account) + + account = si_account.get(NFTAccountState) + account.add(token.id) + + si_total_supply = engine.snapshot.storages.get(self.key_total_suppply, read_only=False) + new_value = vm.BigInteger(si_total_supply.value) + 1 + si_total_supply.value = new_value.to_array() + + self._post_transfer(engine, types.UInt160.zero(), token.owner, token.id) + + def burn(self, engine: contracts.ApplicationEngine, token_id: bytes) -> None: + key_token = self.key_token + token_id + si_token = engine.snapshot.storages.try_get(key_token, read_only=True) + if si_token is None: + raise ValueError("Token cannot be found") + token = NFTState.deserialize_from_bytes(si_token.value) + engine.snapshot.storages.delete(key_token) + + key_account = self.key_account + token.owner.to_array() + account_state = engine.snapshot.storages.get(key_account).get(NFTAccountState) + account_state.remove(token_id) + + if account_state.balance == 0: + engine.snapshot.storages.delete(key_account) + + si_total_supply = engine.snapshot.storages.get(self.key_total_suppply) + new_value = vm.BigInteger(si_total_supply.value) + 1 + si_total_supply.value = new_value.to_array() + + self._post_transfer(engine, token.owner, types.UInt160.zero(), token_id) + def on_transferred(self, engine: contracts.ApplicationEngine, from_account: types.UInt160, token: NFTState) -> None: pass diff --git a/neo3/contracts/native/oracle.py b/neo3/contracts/native/oracle.py index e12c5f98..eb572e54 100644 --- a/neo3/contracts/native/oracle.py +++ b/neo3/contracts/native/oracle.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import Optional, cast, List -from . import NativeContract +from . import NativeContract, register from neo3 import contracts, storage, vm from neo3.core import types, cryptography, serialization, to_script_hash, msgrouter from neo3.network import payloads @@ -55,13 +55,12 @@ class OracleContract(NativeContract): _MAX_FILTER_LEN = 128 _MAX_CALLBACK_LEN = 32 _MAX_USER_DATA_LEN = 512 - _id = -7 + _id = -9 key_request_id = storage.StorageKey(_id, b'\x09') key_request = storage.StorageKey(_id, b'\x07') key_id_list = storage.StorageKey(_id, b'\x06') - - _ORACLE_REQUEST_PRICE = 50000000 + key_price = storage.StorageKey(_id, b'\x05') def init(self): super(OracleContract, self).init() @@ -84,27 +83,25 @@ def init(self): ) ] - self._register_contract_method(self.finish, - "finish", - 0, - call_flags=(contracts.CallFlags.WRITE_STATES - | contracts.CallFlags.ALLOW_CALL - | contracts.CallFlags.ALLOW_NOTIFY)) - - self._register_contract_method(self._request, - "request", - self._ORACLE_REQUEST_PRICE, - parameter_names=["url", "filter", "callback", "userdata", "gas_for_response"], - call_flags=contracts.CallFlags.WRITE_STATES | contracts.CallFlags.ALLOW_NOTIFY) - - self._register_contract_method(self._verify, - "verify", - 1000000, - call_flags=contracts.CallFlags.NONE) - def _initialize(self, engine: contracts.ApplicationEngine) -> None: engine.snapshot.storages.put(self.key_request_id, storage.StorageItem(vm.BigInteger.zero().to_array())) - + engine.snapshot.storages.put(self.key_price, storage.StorageItem(vm.BigInteger(50000000).to_array())) + + @register("setPrice", contracts.CallFlags.STATES, cpu_price=1 << 15) + def set_price(self, engine: contracts.ApplicationEngine, price: int) -> None: + if price <= 0: + raise ValueError("Oracle->setPrice value cannot be negative or zero") + if not self._check_committee(engine): + raise ValueError("Oracle->setPrice check committee failed") + item = engine.snapshot.storages.get(self.key_price) + item.value = vm.BigInteger(price).to_array() + + @register("getPrice", contracts.CallFlags.READ_STATES, cpu_price=1 << 15) + def get_price(self, snapshot: storage.Snapshot) -> int: + return int(vm.BigInteger(snapshot.storages.get(self.key_price, read_only=True).value)) + + @register("finish", + (contracts.CallFlags.STATES | contracts.CallFlags.ALLOW_CALL | contracts.CallFlags.ALLOW_NOTIFY)) def finish(self, engine: contracts.ApplicationEngine) -> None: tx = engine.script_container tx = cast(payloads.Transaction, tx) @@ -127,7 +124,6 @@ def finish(self, engine: contracts.ApplicationEngine) -> None: user_data = contracts.BinarySerializer.deserialize(request.user_data, engine.MAX_STACK_SIZE, - engine.MAX_ITEM_SIZE, engine.reference_counter) args: List[vm.StackItem] = [vm.ByteStringStackItem(request.url.encode()), user_data, @@ -136,27 +132,7 @@ def finish(self, engine: contracts.ApplicationEngine) -> None: engine.call_from_native(self.hash, request.callback_contract, request.callback_method, args) - def get_request(self, snapshot: storage.Snapshot, id: int) -> Optional[OracleRequest]: - id_bytes = id.to_bytes(8, 'little', signed=False) - storage_item = snapshot.storages.try_get(self.key_request + id_bytes) - if storage_item is None: - return None - - return OracleRequest.deserialize_from_bytes(storage_item.value) - - def _get_url_hash(self, url: str) -> bytes: - return to_script_hash(url.encode('utf-8')).to_array() - - def _get_original_txid(self, engine: contracts.ApplicationEngine) -> types.UInt256: - tx = cast(payloads.Transaction, engine.script_container) - response = tx.try_get_attribute(payloads.OracleResponse) - if response is None: - return tx.hash() - request = self.get_request(engine.snapshot, response.id) - if request is None: - raise ValueError # C# will throw null pointer access exception - return request.original_tx_id - + @register("request", contracts.CallFlags.STATES | contracts.CallFlags.ALLOW_NOTIFY) def _request(self, engine: contracts.ApplicationEngine, url: str, @@ -171,6 +147,7 @@ def _request(self, gas_for_response < 10000000: raise ValueError + engine.add_gas(self.get_price(engine.snapshot)) engine.add_gas(gas_for_response) self._gas.mint(engine, self.hash, vm.BigInteger(gas_for_response), False) @@ -225,12 +202,21 @@ def _request(self, msgrouter.interop_notify(self.hash, "OracleRequest", state) + @register("verify", contracts.CallFlags.READ_ONLY, cpu_price=1 << 15) def _verify(self, engine: contracts.ApplicationEngine) -> bool: tx = engine.script_container if not isinstance(tx, payloads.Transaction): return False return bool(tx.try_get_attribute(payloads.OracleResponse)) + def get_request(self, snapshot: storage.Snapshot, id: int) -> Optional[OracleRequest]: + id_bytes = id.to_bytes(8, 'little', signed=False) + storage_item = snapshot.storages.try_get(self.key_request + id_bytes) + if storage_item is None: + return None + + return OracleRequest.deserialize_from_bytes(storage_item.value) + def post_persist(self, engine: contracts.ApplicationEngine) -> None: super(OracleContract, self).post_persist(engine) nodes = [] @@ -253,12 +239,7 @@ def post_persist(self, engine: contracts.ApplicationEngine) -> None: if si_id_list is None: si_id_list = storage.StorageItem(b'\x00') - with serialization.BinaryReader(si_id_list.value) as reader: - count = reader.read_var_int() - id_list = [] - for _ in range(count): - id_list.append(reader.read_uint64()) - + id_list = si_id_list.get(_IdList) id_list.remove(response.id) if len(id_list) == 0: engine.snapshot.storages.delete(sk_id_list) @@ -277,8 +258,35 @@ def post_persist(self, engine: contracts.ApplicationEngine) -> None: if len(nodes) > 0: idx = response.id % len(nodes) # mypy can't figure out that the second item is a BigInteger - nodes[idx][1] += self._ORACLE_REQUEST_PRICE # type: ignore + nodes[idx][1] += self.get_price(engine.snapshot) # type: ignore for pair in nodes: if pair[1].sign > 0: # type: ignore self._gas.mint(engine, pair[0], pair[1], False) + + def _get_url_hash(self, url: str) -> bytes: + return to_script_hash(url.encode('utf-8')).to_array() + + def _get_original_txid(self, engine: contracts.ApplicationEngine) -> types.UInt256: + tx = cast(payloads.Transaction, engine.script_container) + response = tx.try_get_attribute(payloads.OracleResponse) + if response is None: + return tx.hash() + request = self.get_request(engine.snapshot, response.id) + if request is None: + raise ValueError # C# will throw null pointer access exception + return request.original_tx_id + + +class _IdList(list, serialization.ISerializable): + """ + Helper class to get an IdList from storage and deal with caching. + """ + def serialize(self, writer: serialization.BinaryWriter) -> None: + for item in self: + writer.write_uint64(item) + + def deserialize(self, reader: serialization.BinaryReader) -> None: + count = reader.read_var_int() + for _ in range(count): + self.append(reader.read_uint64()) diff --git a/neo3/contracts/native/policy.py b/neo3/contracts/native/policy.py index 96b4657f..0ea8f845 100644 --- a/neo3/contracts/native/policy.py +++ b/neo3/contracts/native/policy.py @@ -1,23 +1,21 @@ from __future__ import annotations -from . import NativeContract +from . import NativeContract, register from neo3.core import types from neo3 import storage, contracts, vm from neo3.network import message class PolicyContract(NativeContract): - _id: int = -5 + _id: int = -7 DEFAULT_EXEC_FEE_FACTOR = 30 MAX_EXEC_FEE_FACTOR = 1000 + DEFAULT_FEE_PER_BYTE = 1000 DEFAULT_STORAGE_PRICE = 100000 MAX_STORAGE_PRICE = 10000000 - key_max_transactions_per_block = storage.StorageKey(_id, b'\x17') key_fee_per_byte = storage.StorageKey(_id, b'\x0A') key_blocked_account = storage.StorageKey(_id, b'\x0F') - key_max_block_size = storage.StorageKey(_id, b'\x0C') - key_max_block_system_fee = storage.StorageKey(_id, b'\x11') key_exec_fee_factor = storage.StorageKey(_id, b'\x12') key_storage_price = storage.StorageKey(_id, b'\x13') @@ -26,132 +24,7 @@ class PolicyContract(NativeContract): def init(self): super(PolicyContract, self).init() - self._register_contract_method(self.get_max_block_size, - "getMaxBlockSize", - 1000000, - call_flags=contracts.CallFlags.READ_STATES, - ) - self._register_contract_method(self.get_max_transactions_per_block, - "getMaxTransactionsPerBlock", - 1000000, - call_flags=contracts.CallFlags.READ_STATES, - ) - self._register_contract_method(self.get_max_block_system_fee, - "getMaxBlockSystemFee", - 1000000, - call_flags=contracts.CallFlags.READ_STATES, - ) - self._register_contract_method(self.get_fee_per_byte, - "getFeePerByte", - 1000000, - call_flags=contracts.CallFlags.READ_STATES, - ) - self._register_contract_method(self.is_blocked, - "isBlocked", - 1000000, - parameter_names=["account"], - call_flags=contracts.CallFlags.READ_STATES, - ) - self._register_contract_method(self._block_account, - "blockAccount", - 3000000, - parameter_names=["account"], - call_flags=contracts.CallFlags.WRITE_STATES) - self._register_contract_method(self._unblock_account, - "unblockAccount", - 3000000, - parameter_names=["account"], - call_flags=contracts.CallFlags.WRITE_STATES) - self._register_contract_method(self._set_max_block_size, - "setMaxBlockSize", - 3000000, - parameter_names=["value"], - call_flags=contracts.CallFlags.WRITE_STATES) - self._register_contract_method(self._set_max_transactions_per_block, - "setMaxTransactionsPerBlock", - 3000000, - parameter_names=["value"], - call_flags=contracts.CallFlags.WRITE_STATES) - self._register_contract_method(self._set_max_block_system_fee, - "setMaxBlockSystemFee", - 3000000, - parameter_names=["value"], - call_flags=contracts.CallFlags.WRITE_STATES) - self._register_contract_method(self._set_fee_per_byte, - "setFeePerByte", - 3000000, - parameter_names=["value"], - call_flags=contracts.CallFlags.WRITE_STATES) - self._register_contract_method(self.get_exec_fee_factor, - "getExecFeeFactor", - 1000000, - call_flags=contracts.CallFlags.READ_STATES - ) - self._register_contract_method(self.get_storage_price, - "getStoragePrice", - 1000000, - call_flags=contracts.CallFlags.READ_STATES - ) - self._register_contract_method(self._set_exec_fee_factor, - "setExecFeeFactor", - 3000000, - parameter_names=["value"], - call_flags=contracts.CallFlags.WRITE_STATES - ) - self._register_contract_method(self._set_storage_price, - "setStoragePrice", - 3000000, - parameter_names=["value"], - call_flags=contracts.CallFlags.WRITE_STATES - ) - - def _int_to_bytes(self, value: int) -> bytes: - return value.to_bytes((value.bit_length() + 7 + 1) // 8, 'little', signed=True) # +1 for signed - - def _initialize(self, engine: contracts.ApplicationEngine) -> None: - def _to_si(value: int) -> storage.StorageItem: - return storage.StorageItem(self._int_to_bytes(value)) - - engine.snapshot.storages.put(self.key_max_transactions_per_block, _to_si(512)) - engine.snapshot.storages.put(self.key_fee_per_byte, _to_si(1000)) - engine.snapshot.storages.put(self.key_max_block_size, _to_si(1024 * 256)) - engine.snapshot.storages.put(self.key_max_block_system_fee, _to_si(int(contracts.GasToken().factor * 9000))) - engine.snapshot.storages.put(self.key_exec_fee_factor, _to_si(self.DEFAULT_EXEC_FEE_FACTOR)) - engine.snapshot.storages.put(self.key_storage_price, _to_si(self.DEFAULT_STORAGE_PRICE)) - - def get_max_block_size(self, snapshot: storage.Snapshot) -> int: - """ - Retrieve the configured maximum size of a Block. - - Returns: - int: maximum number of bytes. - """ - data = snapshot.storages.get( - self.key_max_block_size, - read_only=True - ) - return int.from_bytes(data.value, 'little', signed=True) - - def get_max_transactions_per_block(self, snapshot: storage.Snapshot) -> int: - """ - Retrieve the configured maximum number of transaction in a Block. - - Returns: - int: maximum number of transaction. - """ - data = snapshot.storages.get(self.key_max_transactions_per_block, read_only=True) - return int.from_bytes(data.value, 'little', signed=True) - - def get_max_block_system_fee(self, snapshot: storage.Snapshot) -> int: - """ - Retrieve the configured maximum system fee of a Block. - - Returns: - int: maximum system fee. - """ - data = snapshot.storages.get(self.key_max_block_system_fee, read_only=True) - return int.from_bytes(data.value, 'little', signed=True) - + @register("getFeePerByte", contracts.CallFlags.READ_STATES, cpu_price=1 << 15) def get_fee_per_byte(self, snapshot: storage.Snapshot) -> int: """ Retrieve the configured maximum fee per byte of storage. @@ -162,6 +35,7 @@ def get_fee_per_byte(self, snapshot: storage.Snapshot) -> int: data = snapshot.storages.get(self.key_fee_per_byte, read_only=True) return int.from_bytes(data.value, 'little', signed=True) + @register("isBlocked", contracts.CallFlags.READ_STATES, cpu_price=1 << 15) def is_blocked(self, snapshot: storage.Snapshot, account: types.UInt160) -> bool: """ Check if the account is blocked @@ -175,46 +49,7 @@ def is_blocked(self, snapshot: storage.Snapshot, account: types.UInt160) -> bool else: return True - def _set_max_block_size(self, engine: contracts.ApplicationEngine, value: int) -> None: - """ - Should only be called through syscalls - """ - if value >= message.Message.PAYLOAD_MAX_SIZE: - raise ValueError("New blocksize exceeds PAYLOAD_MAX_SIZE") - - if not self._check_committee(engine): - raise ValueError("Check committee failed") - - storage_item = engine.snapshot.storages.get(self.key_max_block_size, read_only=False) - storage_item.value = self._int_to_bytes(value) - - def _set_max_transactions_per_block(self, engine: contracts.ApplicationEngine, value: int) -> None: - """ - Should only be called through syscalls - """ - if value > 0xFFFE: # MaxTransactionsPerBlock - raise ValueError("New value exceeds MAX_TRANSACTIONS_PER_BLOCK") - - if not self._check_committee(engine): - raise ValueError("Check committee failed") - - storage_item = engine.snapshot.storages.get(self.key_max_transactions_per_block, read_only=False) - storage_item.value = self._int_to_bytes(value) - - def _set_max_block_system_fee(self, engine: contracts.ApplicationEngine, value: int) -> None: - """ - Should only be called through syscalls - """ - # unknown magic value - if value <= 4007600: - raise ValueError("Invalid new system fee") - - if not self._check_committee(engine): - raise ValueError("Check committee failed") - - storage_item = engine.snapshot.storages.get(self.key_max_block_system_fee, read_only=False) - storage_item.value = self._int_to_bytes(value) - + @register("setFeePerByte", contracts.CallFlags.STATES, cpu_price=1 << 15) def _set_fee_per_byte(self, engine: contracts.ApplicationEngine, value: int) -> None: """ Should only be called through syscalls @@ -228,6 +63,7 @@ def _set_fee_per_byte(self, engine: contracts.ApplicationEngine, value: int) -> storage_item = engine.snapshot.storages.get(self.key_fee_per_byte, read_only=False) storage_item.value = self._int_to_bytes(value) + @register("blockAccount", contracts.CallFlags.STATES, cpu_price=1 << 15) def _block_account(self, engine: contracts.ApplicationEngine, account: types.UInt160) -> bool: """ Should only be called through syscalls @@ -244,6 +80,7 @@ def _block_account(self, engine: contracts.ApplicationEngine, account: types.UIn return True + @register("unblockAccount", contracts.CallFlags.STATES, cpu_price=1 << 15) def _unblock_account(self, engine: contracts.ApplicationEngine, account: types.UInt160) -> bool: """ Should only be called through syscalls @@ -258,10 +95,12 @@ def _unblock_account(self, engine: contracts.ApplicationEngine, account: types.U engine.snapshot.storages.delete(storage_key) return True + @register("getExecFeeFactor", contracts.CallFlags.READ_STATES, cpu_price=1 << 15) def get_exec_fee_factor(self, snapshot: storage.Snapshot) -> int: storage_item = snapshot.storages.get(self.key_exec_fee_factor, read_only=True) return int(vm.BigInteger(storage_item.value)) + @register("getStoragePrice", contracts.CallFlags.READ_STATES, cpu_price=1 << 15) def get_storage_price(self, snapshot: storage.Snapshot) -> int: if self._storage_price: return self._storage_price @@ -269,6 +108,7 @@ def get_storage_price(self, snapshot: storage.Snapshot) -> int: storage_item = snapshot.storages.get(self.key_storage_price, read_only=True) return int(vm.BigInteger(storage_item.value)) + @register("setExecFeeFactor", contracts.CallFlags.STATES, cpu_price=1 << 15) def _set_exec_fee_factor(self, engine: contracts.ApplicationEngine, value: int) -> None: if value == 0 or value > self.MAX_EXEC_FEE_FACTOR: raise ValueError("New exec fee value out of range") @@ -277,6 +117,7 @@ def _set_exec_fee_factor(self, engine: contracts.ApplicationEngine, value: int) storage_item = engine.snapshot.storages.get(self.key_exec_fee_factor, read_only=False) storage_item.value = vm.BigInteger(value).to_array() + @register("setStoragePrice", contracts.CallFlags.STATES, cpu_price=1 << 15) def _set_storage_price(self, engine: contracts.ApplicationEngine, value: int) -> None: if value == 0 or value > self.MAX_STORAGE_PRICE: raise ValueError("New storage price value out of range") @@ -286,3 +127,14 @@ def _set_storage_price(self, engine: contracts.ApplicationEngine, value: int) -> storage_item.value = vm.BigInteger(value).to_array() self._storage_price = value + + def _int_to_bytes(self, value: int) -> bytes: + return value.to_bytes((value.bit_length() + 7 + 1) // 8, 'little', signed=True) # +1 for signed + + def _initialize(self, engine: contracts.ApplicationEngine) -> None: + def _to_si(value: int) -> storage.StorageItem: + return storage.StorageItem(self._int_to_bytes(value)) + + engine.snapshot.storages.put(self.key_fee_per_byte, _to_si(self.DEFAULT_FEE_PER_BYTE)) + engine.snapshot.storages.put(self.key_exec_fee_factor, _to_si(self.DEFAULT_EXEC_FEE_FACTOR)) + engine.snapshot.storages.put(self.key_storage_price, _to_si(self.DEFAULT_STORAGE_PRICE)) diff --git a/neo3/contracts/native/stdlib.py b/neo3/contracts/native/stdlib.py new file mode 100644 index 00000000..ab96f473 --- /dev/null +++ b/neo3/contracts/native/stdlib.py @@ -0,0 +1,62 @@ +from __future__ import annotations +import base64 +import base58 # type: ignore +from . import NativeContract, register +from neo3 import contracts, vm + + +class StdLibContract(NativeContract): + + _service_name = "StdLib" + _id = -2 + + def init(self): + super(StdLibContract, self).init() + + @register("serialize", contracts.CallFlags.NONE, cpu_price=1 << 12) + def binary_serialize(self, engine: contracts.ApplicationEngine, stack_item: vm.StackItem) -> bytes: + return contracts.BinarySerializer.serialize(stack_item, engine.MAX_ITEM_SIZE) + + @register("deserialize", contracts.CallFlags.NONE, cpu_price=1 << 14) + def binary_deserialize(self, engine: contracts.ApplicationEngine, data: bytes) -> vm.StackItem: + return contracts.BinarySerializer.deserialize(data, engine.MAX_ITEM_SIZE, engine.reference_counter) + + @register("jsonSerialize", contracts.CallFlags.NONE, cpu_price=1 << 12) + def json_serialize(self, engine: contracts.ApplicationEngine, stack_item: vm.StackItem) -> bytes: + return bytes(contracts.JSONSerializer.serialize(stack_item, engine.MAX_ITEM_SIZE), 'utf-8') + + @register("jsonDeserialize", contracts.CallFlags.NONE, cpu_price=1 << 14) + def json_deserialize(self, engine: contracts.ApplicationEngine, data: bytes) -> vm.StackItem: + return contracts.JSONSerializer.deserialize(data.decode(), engine.reference_counter) + + @register("itoa", contracts.CallFlags.NONE, cpu_price=1 << 12) + def do_itoa(self, value: vm.BigInteger, base: int) -> str: + if base == 10: + return str(value) + elif base == 16: + return hex(int(value))[2:] + else: + raise ValueError("Invalid base specified") + + @register("atoi", contracts.CallFlags.NONE, cpu_price=1 << 12) + def do_atoi(self, value: str, base: int) -> int: + if base != 10 and base != 16: + raise ValueError("Invalid base specified") + else: + return int(value, base) + + @register("base64Encode", contracts.CallFlags.NONE, cpu_price=1 << 12) + def base64_encode(self, data: bytes) -> str: + return base64.b64encode(data).decode() + + @register("base64Decode", contracts.CallFlags.NONE, cpu_price=1 << 12) + def base64_decode(self, data: bytes) -> bytes: + return base64.b64decode(data) + + @register("base58Encode", contracts.CallFlags.NONE, cpu_price=1 << 12) + def base58_encode(self, data: bytes) -> str: + return base58.b58encode(data).decode() + + @register("base58Decode", contracts.CallFlags.NONE, cpu_price=1 << 12) + def base58_decode(self, data: bytes) -> bytes: + return base58.b58decode(data) diff --git a/neo3/contracts/nef.py b/neo3/contracts/nef.py index 399ad5f4..403e0b65 100644 --- a/neo3/contracts/nef.py +++ b/neo3/contracts/nef.py @@ -49,7 +49,7 @@ def __eq__(self, other): and self.checksum == other.checksum) @property - def checksum(self): + def checksum(self) -> int: if self._checksum == 0: self._checksum = self.compute_checksum() return self._checksum diff --git a/neo3/core/cryptography/merkletree.py b/neo3/core/cryptography/merkletree.py index 51e559e4..4e21a04a 100644 --- a/neo3/core/cryptography/merkletree.py +++ b/neo3/core/cryptography/merkletree.py @@ -34,7 +34,10 @@ def __init__(self, hashes: List[types.UInt256]): ValueError: if the `hashes` list is empty. """ if len(hashes) == 0: - raise ValueError("Hashes list can't empty") + self.has_root = False + return + else: + self.has_root = True self.root = self._build(leaves=[_MerkleTreeNode(h) for h in hashes]) _depth = 1 @@ -51,7 +54,8 @@ def to_hash_array(self) -> List[types.UInt256]: Note: does not include the Merkle root hash. """ hashes: List[types.UInt256] = [] - MerkleTree._depth_first_search(self.root, hashes) + if self.has_root: + MerkleTree._depth_first_search(self.root, hashes) return hashes @staticmethod @@ -101,6 +105,8 @@ def compute_root(hashes: List[types.UInt256]) -> types.UInt256: Raises: ValueError: if the `hashes` list is empty. """ + if len(hashes) == 0: + return types.UInt256.zero() if len(hashes) == 1: return hashes[0] tree = MerkleTree(hashes) diff --git a/neo3/core/serialization.py b/neo3/core/serialization.py index 5b228cce..44b5d41f 100644 --- a/neo3/core/serialization.py +++ b/neo3/core/serialization.py @@ -357,10 +357,13 @@ def read_serializable_list(self, obj_type: Type[ISerializable_T], max: int = Non if max and count > max: count = max - for _ in range(count): - obj = obj_type._serializable_init() - obj.deserialize(self) - obj_array.append(obj) + try: + for _ in range(count): + obj = obj_type._serializable_init() + obj.deserialize(self) + obj_array.append(obj) + except Exception as e: + raise ValueError(f"Insufficient data - {str(e)}") return obj_array def close(self) -> None: diff --git a/neo3/network/payloads/__init__.py b/neo3/network/payloads/__init__.py index 0385d5e5..0ce3b65a 100644 --- a/neo3/network/payloads/__init__.py +++ b/neo3/network/payloads/__init__.py @@ -4,7 +4,6 @@ from .address import NetworkAddress, AddrPayload, AddressState, DisconnectReason from .ping import PingPayload from .verification import Witness, WitnessScope, Signer, IVerifiable -from .consensus import ConsensusData, ConsensusPayload from .transaction import Transaction, TransactionAttribute, TransactionAttributeType from .block import (Header, Block, @@ -19,7 +18,7 @@ __all__ = ['EmptyPayload', 'InventoryPayload', 'InventoryType', 'VersionPayload', 'NetworkAddress', 'AddrPayload', 'PingPayload', 'Witness', 'WitnessScope', 'Header', 'Block', 'MerkleBlockPayload', - 'HeadersPayload', 'ConsensusData', 'ConsensusPayload', 'Transaction', 'TransactionAttribute', + 'HeadersPayload', 'Transaction', 'TransactionAttribute', 'TransactionAttributeType', 'Signer', 'GetBlocksPayload', 'GetBlockByIndexPayload', 'FilterAddPayload', 'FilterLoadPayload', 'TrimmedBlock', 'IVerifiable', 'OracleReponseCode', 'OracleReponseCode', 'ExtensiblePayload'] diff --git a/neo3/network/payloads/block.py b/neo3/network/payloads/block.py index 3b481e27..34329b64 100644 --- a/neo3/network/payloads/block.py +++ b/neo3/network/payloads/block.py @@ -2,7 +2,6 @@ import hashlib import struct from typing import List - from neo3 import vm, storage, settings from neo3.core import Size as s, serialization, types, utils, cryptography as crypto, IClonable, IInteroperable from neo3.network import payloads @@ -11,27 +10,49 @@ from .verification import IVerifiable -class _BlockBase(IVerifiable): - def __init__(self, - version: int, - prev_hash: types.UInt256, - timestamp: int, - index: int, - next_consensus: types.UInt160, - witness: payloads.Witness, - merkle_root: types.UInt256 = None, - ): +class Header(IVerifiable): + """ + A Block header only object. + Does not contain any consensus data or transactions. + + See also: + :class:`~neo3.network.payloads.block.TrimmedBlock` + """ + def __init__(self, version: int, prev_hash: types.UInt256, timestamp: int, index: int, primary_index: int, + next_consensus: types.UInt160, witness: payloads.Witness, merkle_root: types.UInt256 = None, *args, + **kwargs): + super().__init__(*args, **kwargs) self.version = version self.prev_hash = prev_hash self.merkle_root = merkle_root if merkle_root else types.UInt256.zero() self.timestamp = timestamp self.index = index + self.primary_index = primary_index self.next_consensus = next_consensus self.witness = witness def __len__(self): - return s.uint32 + s.uint256 + s.uint256 + s.uint64 + s.uint32 + s.uint160 + 1 + len(self.witness) + return s.uint32 + s.uint256 + s.uint256 + s.uint64 + s.uint32 + s.uint8 + s.uint160 + 1 + len(self.witness) + + def __eq__(self, other): + if other is None: + return False + if type(self) != type(other): + return False + if self.hash() != other.hash(): + return False + return True + + def hash(self) -> types.UInt256: + """ + Get a unique identifier based on the unsigned data portion of the object. + """ + with serialization.BinaryWriter() as bw: + self.serialize_unsigned(bw) + data_to_hash = bytearray(bw._stream.getvalue()) + data = hashlib.sha256(data_to_hash).digest() + return types.UInt256(data=data) def serialize(self, writer: serialization.BinaryWriter) -> None: """ @@ -41,8 +62,7 @@ def serialize(self, writer: serialization.BinaryWriter) -> None: writer: instance. """ self.serialize_unsigned(writer) - writer.write_uint8(1) - writer.write_serializable(self.witness) + writer.write_serializable_list([self.witness]) def serialize_unsigned(self, writer: serialization.BinaryWriter) -> None: writer.write_uint32(self.version) @@ -50,6 +70,7 @@ def serialize_unsigned(self, writer: serialization.BinaryWriter) -> None: writer.write_serializable(self.merkle_root) writer.write_uint64(self.timestamp) writer.write_uint32(self.index) + writer.write_uint8(self.primary_index) writer.write_serializable(self.next_consensus) def deserialize(self, reader: serialization.BinaryReader) -> None: @@ -60,13 +81,13 @@ def deserialize(self, reader: serialization.BinaryReader) -> None: reader: instance. Raises: - ValueError: if no witnesses are found. + ValueError: if the check byte does not equal. """ self.deserialize_unsigned(reader) - witness_obj_count = reader.read_uint8() - if witness_obj_count != 1: - raise ValueError(f"Deserialization error - Witness object count is {witness_obj_count} must be 1") - self.witness = reader.read_serializable(payloads.Witness) + witnesses = reader.read_serializable_list(payloads.Witness, max=1) + if len(witnesses) != 1: + raise ValueError(f"Deserialization error - Witness object count is {len(witnesses)} must be 1") + self.witness = witnesses[0] def deserialize_unsigned(self, reader: serialization.BinaryReader) -> None: (self.version, @@ -74,22 +95,15 @@ def deserialize_unsigned(self, reader: serialization.BinaryReader) -> None: merkleroot, self.timestamp, self.index, - consensus) = struct.unpack("= len(settings.standby_validators): + raise ValueError(f"Deserialization error - primary index {self.primary_index} exceeds validator count " + f"{len(settings.standby_validators)}") self.prev_hash = types.UInt256(prev_hash) self.merkle_root = types.UInt256(merkleroot) self.next_consensus = types.UInt160(consensus) - def hash(self) -> types.UInt256: - """ - Get a unique block identifier based on the unsigned data portion of the object. - """ - with serialization.BinaryWriter() as bw: - bw.write_uint32(settings.network.magic) - self.serialize_unsigned(bw) - data_to_hash = bytearray(bw._stream.getvalue()) - data = hashlib.sha256(hashlib.sha256(data_to_hash).digest()).digest() - return types.UInt256(data=data) - def get_script_hashes_for_verifying(self, snapshot: storage.Snapshot) -> List[types.UInt160]: if self.prev_hash == types.UInt256.zero(): return [self.witness.script_hash()] @@ -98,103 +112,37 @@ def get_script_hashes_for_verifying(self, snapshot: storage.Snapshot) -> List[ty raise ValueError("Can't get next_consensus hash from previous block. Block does not exist") return [prev_block.next_consensus] - -class Header(_BlockBase): - """ - A Block header only object. - - Does not contain any consensus data or transactions. - - See also: - :class:`~neo3.network.payloads.block.TrimmedBlock` - """ - def __init__(self, - version: int, - prev_hash: types.UInt256, - timestamp: int, - index: int, - next_consensus: types.UInt160, - witness: payloads.Witness, - merkle_root: types.UInt256 = None - ): - super(Header, self).__init__(version, prev_hash, timestamp, index, next_consensus, witness, merkle_root) - - def __len__(self): - return super(Header, self).__len__() + 1 - - def __eq__(self, other): - if other is None: - return False - if type(self) != type(other): - return False - if self.hash() != other.hash(): - return False - return True - - def serialize(self, writer: serialization.BinaryWriter) -> None: - """ - Serialize the object into a binary stream. - - Args: - writer: instance. - """ - super(Header, self).serialize(writer) - writer.write_uint8(0) - - def deserialize(self, reader: serialization.BinaryReader) -> None: - """ - Deserialize the object from a binary stream. - - Args: - reader: instance. - - Raises: - ValueError: if the check byte does not equal. - """ - super(Header, self).deserialize(reader) - tmp = reader.read_uint8() - if tmp != 0: - raise ValueError("Deserialization error") - @classmethod def _serializable_init(cls): return cls(0, types.UInt256.zero(), 0, 0, + 0, types.UInt160.zero(), payloads.Witness(b'', b'')) -class Block(_BlockBase, payloads.IInventory): +class Block(payloads.IInventory): """ The famous Block. I transfer chain state. """ - #: The maximum item count per block. Consensus data and Transactions are considered items. - MAX_CONTENTS_PER_BLOCK = 65535 - #: The maximum number of transactions allowed to be included in a block. - MAX_TX_PER_BLOCK = MAX_CONTENTS_PER_BLOCK - 1 def __init__(self, - version: int, - prev_hash: types.UInt256, - timestamp: int, - index: int, - next_consensus: types.UInt160, - witness: payloads.Witness, - consensus_data: payloads.ConsensusData, + header: Header, transactions: List[payloads.Transaction] = None, - merkle_root: types.UInt256 = None, + *args, + **kwargs ): - super(Block, self).__init__(version, prev_hash, timestamp, index, next_consensus, witness, merkle_root) - self.consensus_data = consensus_data + super(Block, self).__init__(*args, **kwargs) + self.header = header self.transactions = [] if transactions is None else transactions def __len__(self): # calculate the varint length that needs to be inserted before the transaction objects. magic_len = utils.get_var_size(len(self.transactions)) txs_len = sum([len(t) for t in self.transactions]) - return super(Block, self).__len__() + magic_len + len(self.consensus_data) + txs_len + return len(self.header) + magic_len + txs_len def __eq__(self, other): if other is None: @@ -205,6 +153,44 @@ def __eq__(self, other): return False return True + @property + def version(self) -> int: + return self.header.version + + @property + def prev_hash(self) -> types.UInt256: + return self.header.prev_hash + + @property + def merkle_root(self) -> types.UInt256: + return self.header.merkle_root + + @property + def timestamp(self) -> int: + return self.header.timestamp + + @property + def index(self) -> int: + return self.header.index + + @property + def primary_index(self) -> int: + return self.header.primary_index + + @property + def next_consensus(self) -> types.UInt160: + return self.header.next_consensus + + @property + def witness(self) -> payloads.Witness: + return self.header.witness + + def hash(self) -> types.UInt256: + return self.header.hash() + + def get_script_hashes_for_verifying(self, snapshot: storage.Snapshot) -> List[types.UInt160]: + return self.header.get_script_hashes_for_verifying(snapshot) + @property def inventory_type(self) -> payloads.InventoryType: """ @@ -219,12 +205,14 @@ def serialize(self, writer: serialization.BinaryWriter) -> None: Args: writer: instance. """ - super(Block, self).serialize(writer) - writer.write_var_int(len(self.transactions) + 1) - writer.write_serializable(self.consensus_data) + writer.write_serializable(self.header) + writer.write_var_int(len(self.transactions)) for tx in self.transactions: writer.write_serializable(tx) + def serialize_unsigned(self, writer: serialization.BinaryWriter) -> None: + self.header.serialize_unsigned(writer) + def deserialize(self, reader: serialization.BinaryReader) -> None: """ Deserialize the object from a binary stream. @@ -236,164 +224,93 @@ def deserialize(self, reader: serialization.BinaryReader) -> None: ValueError: if the content count of the block is zero, or if there is a duplicate transaction in the list, or if the merkle root does not included the calculated root. """ - super(Block, self).deserialize(reader) - content_count = reader.read_var_int(max=self.MAX_CONTENTS_PER_BLOCK) - if content_count == 0: - raise ValueError("Deserialization error - no contents") - - self.consensus_data = reader.read_serializable(payloads.ConsensusData) - tx_count = content_count - 1 - for _ in range(tx_count): - self.transactions.append(reader.read_serializable(payloads.Transaction)) + self.header = reader.read_serializable(Header) + self.transactions = reader.read_serializable_list(payloads.Transaction, max=0xFFFF) - if len(set(self.transactions)) != tx_count: + if len(set(self.transactions)) != len(self.transactions): raise ValueError("Deserialization error - block contains duplicate transaction") hashes = [t.hash() for t in self.transactions] - if Block.calculate_merkle_root(self.consensus_data.hash(), hashes) != self.merkle_root: + if crypto.MerkleTree.compute_root(hashes) != self.header.merkle_root: raise ValueError("Deserialization error - merkle root mismatch") + def deserialize_unsigned(self, reader: serialization.BinaryReader) -> None: + raise NotImplementedError + def rebuild_merkle_root(self) -> None: """ Recalculates the Merkle root. """ - self.merkle_root = Block.calculate_merkle_root(self.consensus_data.hash(), - [t.hash() for t in self.transactions]) + self.header.merkle_root = crypto.MerkleTree.compute_root([t.hash() for t in self.transactions]) def trim(self) -> TrimmedBlock: """ - Reduce a block in size by replacing the consensus data and transaction objects with their identifying hashes. + Reduce a block in size by replacing the transaction objects with their identifying hashes. """ - hashes = [self.consensus_data.hash()] + [t.hash() for t in self.transactions] - return TrimmedBlock(version=self.version, - prev_hash=self.prev_hash, - merkle_root=self.merkle_root, - timestamp=self.timestamp, - index=self.index, - next_consensus=self.next_consensus, - witness=self.witness, - hashes=hashes, - consensus_data=self.consensus_data - ) - - @staticmethod - def calculate_merkle_root(consensus_data_hash: types.UInt256, - transaction_hashes: List[types.UInt256]) -> types.UInt256: - """ - Calculate a Merkle root. - - Args: - consensus_data_hash: - transaction_hashes: - """ - hashes = [consensus_data_hash] + transaction_hashes - return crypto.MerkleTree.compute_root(hashes) + return TrimmedBlock(self.header, [t.hash() for t in self.transactions]) def from_replica(self, replica: Block) -> None: """ Shallow copy attributes from a reference object. """ - self.version = replica.version - self.prev_hash = replica.prev_hash - self.merkle_root = replica.merkle_root - self.timestamp = replica.timestamp - self.index = replica.index - self.next_consensus = replica.next_consensus - self.witness = replica.witness - self.consensus_data = replica.consensus_data + self.header = replica.header self.transactions = replica.transactions @classmethod def _serializable_init(cls): - return cls(0, - types.UInt256.zero(), - 0, - 0, - types.UInt160.zero(), - payloads.Witness(b'', b''), - payloads.ConsensusData()) + return cls(Header._serializable_init(), []) -class TrimmedBlock(_BlockBase): +class TrimmedBlock(serialization.ISerializable): """ A size reduced Block instance. Contains consensus data and transactions hashes instead of their full objects. """ - def __init__(self, - version: int, - prev_hash: types.UInt256, - timestamp: int, - index: int, - next_consensus: types.UInt160, - witness: payloads.Witness, - hashes: List[types.UInt256], - consensus_data: payloads.ConsensusData, - merkle_root: types.UInt256 = None): - super(TrimmedBlock, self).__init__(version, prev_hash, timestamp, index, next_consensus, witness, merkle_root) + def __init__(self, header: Header, hashes: List[types.UInt256]): + super(TrimmedBlock, self).__init__() + self.header = header self.hashes = hashes - self.consensus_data = consensus_data def __len__(self): - size = super(TrimmedBlock, self).__len__() - size += utils.get_var_size(self.hashes) - if self.consensus_data: - size += len(self.consensus_data) - return size + return len(self.header) + utils.get_var_size(self.hashes) def __deepcopy__(self, memodict={}): # not the best, but faster than letting deepcopy() do introspection - return TrimmedBlock.deserialize_from_bytes(self.to_array()) + return self.__class__.deserialize_from_bytes(self.to_array()) + + def hash(self): + return self.header.hash() + + @property + def index(self): + return self.header.index def serialize(self, writer: serialization.BinaryWriter) -> None: - super(TrimmedBlock, self).serialize(writer) + writer.write_serializable(self.header) writer.write_serializable_list(self.hashes) - if len(self.hashes) > 0: - writer.write_serializable(self.consensus_data) def deserialize(self, reader: serialization.BinaryReader) -> None: - super(TrimmedBlock, self).deserialize(reader) - self.hashes = reader.read_serializable_list(types.UInt256) - if len(self.hashes) > 0: - self.consensus_data = reader.read_serializable(payloads.ConsensusData) + self.header = reader.read_serializable(Header) + self.hashes = reader.read_serializable_list(types.UInt256, max=0xFFFF) @classmethod def _serializable_init(cls): - return cls(0, - types.UInt256.zero(), - 0, - 0, - types.UInt160.zero(), - payloads.Witness(b'', b''), - [], - payloads.ConsensusData()) + return cls(Header._serializable_init(), []) -class MerkleBlockPayload(_BlockBase): +class MerkleBlockPayload(serialization.ISerializable): def __init__(self, block: Block, flags: bitarray): - super(MerkleBlockPayload, self).__init__(block.version, - block.prev_hash, - block.timestamp, - block.index, - block.next_consensus, - block.witness, - block.merkle_root) - hashes = [block.consensus_data.hash()] + [t.hash() for t in block.transactions] + hashes = [t.hash() for t in block.transactions] tree = crypto.MerkleTree(hashes) self.flags = flags.tobytes() - self.content_count = len(hashes) + self.tx_count = len(hashes) self.hashes = tree.to_hash_array() + self.header = block.header def __len__(self): - return super(MerkleBlockPayload, self).__len__() + s.uint32 + utils.get_var_size(self.hashes) + \ - utils.get_var_size(self.flags) - - @classmethod - def _serializable_init(cls): - block = payloads.Block._serializable_init() - flags = bitarray() - return cls(block, flags) + return len(self.header) + s.uint32 + utils.get_var_size(self.hashes) + utils.get_var_size(self.flags) def serialize(self, writer: serialization.BinaryWriter) -> None: """ @@ -402,8 +319,8 @@ def serialize(self, writer: serialization.BinaryWriter) -> None: Args: writer: instance. """ - super(MerkleBlockPayload, self).serialize(writer) - writer.write_var_int(self.content_count) + writer.write_serializable(self.header) + writer.write_var_int(self.tx_count) writer.write_serializable_list(self.hashes) writer.write_var_bytes(self.flags) @@ -414,10 +331,16 @@ def deserialize(self, reader: serialization.BinaryReader) -> None: Args: reader: instance. """ - super(MerkleBlockPayload, self).deserialize(reader) - self.content_count = reader.read_var_int() - self.hashes = reader.read_serializable_list(types.UInt256) - self.flags = reader.read_var_bytes() + self.header = reader.read_serializable(Header) + self.tx_count = reader.read_var_int(max=0xFFFF) + self.hashes = reader.read_serializable_list(types.UInt256, max=self.tx_count) + self.flags = reader.read_var_bytes(max=(max(self.tx_count, 1) + 7) // 8) + + @classmethod + def _serializable_init(cls): + block = payloads.Block._serializable_init() + flags = bitarray() + return cls(block, flags) class HeadersPayload(serialization.ISerializable): diff --git a/neo3/network/payloads/consensus.py b/neo3/network/payloads/consensus.py index d9e88465..5d77aac7 100644 --- a/neo3/network/payloads/consensus.py +++ b/neo3/network/payloads/consensus.py @@ -56,135 +56,3 @@ def deserialize_specialization_from_bytes(self, data: bytearray) -> ConsensusMes @classmethod def _serializable_init(cls): return cls(ConsensusMessageType.CHANGE_VIEW) - - -class ConsensusPayload(payloads.IInventory): - def __init__(self, version: int, - prev_hash: types.UInt256, - block_index: int, - validator_index: int, - data: bytes, - witness: payloads.Witness): - self.version = version - self.prev_hash = prev_hash - self.block_index = block_index - self.validator_index = validator_index - self.data = data - self.witness = witness - - def __len__(self): - return (s.uint32 + len(self.prev_hash) + s.uint32 + s.uint8 + utils.get_var_size(self.data) + 1 - + len(self.witness)) - - def hash(self) -> types.UInt256: - with serialization.BinaryWriter() as bw: - bw.write_uint32(settings.network.magic) - self.serialize_unsigned(bw) - data_to_hash = bytearray(bw._stream.getvalue()) - data = hashlib.sha256(hashlib.sha256(data_to_hash).digest()).digest() - return types.UInt256(data=data) - - @property - def inventory_type(self): - return payloads.InventoryType.CONSENSUS - - def serialize(self, writer: serialization.BinaryWriter) -> None: - """ - Serialize the object into a binary stream. - - Args: - writer: instance. - """ - self.serialize_unsigned(writer) - writer.write_uint8(1) - writer.write_serializable(self.witness) - - def serialize_unsigned(self, writer: serialization.BinaryWriter) -> None: - """ - Serialize the object into a binary stream excluding the validation byte + witness. - - Args: - writer: instance. - """ - writer.write_uint32(self.version) - writer.write_serializable(self.prev_hash) - writer.write_uint32(self.block_index) - writer.write_uint8(self.validator_index) - writer.write_var_bytes(self.data) - - def deserialize(self, reader: serialization.BinaryReader) -> None: - """ - Deserialize the object from a binary stream. - - Args: - reader: instance. - - Raises: - ValueError: if the validation byte is not 1 - """ - self.deserialize_unsigned(reader) - if reader.read_uint8() != 1: - raise ValueError("Deserialization error - validation byte not 1") - self.witness = reader.read_serializable(payloads.Witness) - - def deserialize_unsigned(self, reader: serialization.BinaryReader) -> None: - """ - Deserialize the object from a binary stream excluding the validation byte + witness. - - Args: - reader: instance. - """ - self.version = reader.read_uint32() - self.prev_hash = reader.read_serializable(types.UInt256) - self.block_index = reader.read_uint32() - self.validator_index = reader.read_uint8() - if self.validator_index >= settings.network.validators_count: - raise ValueError("Deserialization error - validator index exceeds validator count") - self.data = reader.read_var_bytes() - - def get_script_hashes_for_verifying(self, snapshot: storage.Snapshot) -> List[types.UInt160]: - validators = contracts.NeoToken().get_next_block_validators(snapshot) - if len(validators) < self.validator_index: - raise ValueError("Validator index is out of range") - return [to_script_hash( - contracts.Contract.create_signature_redeemscript(validators[self.validator_index]) - )] - - @classmethod - def _serializable_init(cls): - return cls(0, types.UInt256.zero(), 0, 0, b'', payloads.Witness(b'', b'')) - - -class ConsensusData(serialization.ISerializable): - def __init__(self, primary_index: int = 0, nonce: int = 0): - self.primary_index = primary_index - self.nonce = nonce - - def __len__(self): - return s.uint8 + s.uint64 - - def hash(self) -> types.UInt256: - data = hashlib.sha256(hashlib.sha256(self.to_array()).digest()).digest() - return types.UInt256(data=data) - - def serialize(self, writer: serialization.BinaryWriter) -> None: - """ - Serialize the object into a binary stream. - - Args: - writer: instance. - """ - writer.write_uint8(self.primary_index) - writer.write_uint64(self.nonce) - - def deserialize(self, reader: serialization.BinaryReader) -> None: - """ - Deserialize the object from a binary stream. - - Args: - reader: instance. - """ - self.primary_index = reader.read_uint8() - if self.primary_index >= settings.network.validators_count: - raise ValueError("Deserialization error - primary index exceeds validator count") - self.nonce = reader.read_uint64() diff --git a/neo3/network/payloads/extensible.py b/neo3/network/payloads/extensible.py index a81ec8c3..df5a72a3 100644 --- a/neo3/network/payloads/extensible.py +++ b/neo3/network/payloads/extensible.py @@ -3,7 +3,7 @@ from typing import List from neo3 import storage, settings from neo3.core import types, serialization, Size as s, utils -from neo3.network import payloads +from neo3.network import payloads, message from neo3.network.payloads import InventoryType @@ -56,12 +56,14 @@ def deserialize_unsigned(self, reader: serialization.BinaryReader) -> None: if self.valid_block_start >= self.valid_block_end: raise ValueError("Deserialization error - valid_block_starts is bigger than valid_block_end") self.sender = reader.read_serializable(types.UInt160) - self.data = reader.read_var_bytes(0xFFFF) + self.data = reader.read_var_bytes(message.Message.PAYLOAD_MAX_SIZE) def hash(self) -> types.UInt256: - intermediate_data = hashlib.sha256(self.get_hash_data(settings.network.magic)).digest() - data = hashlib.sha256(intermediate_data).digest() - return types.UInt256(data) + with serialization.BinaryWriter() as bw: + self.serialize_unsigned(bw) + data_to_hash = bytearray(bw._stream.getvalue()) + data = hashlib.sha256(data_to_hash).digest() + return types.UInt256(data=data) @property def inventory_type(self) -> InventoryType: diff --git a/neo3/network/payloads/transaction.py b/neo3/network/payloads/transaction.py index ecbba83c..f575643e 100644 --- a/neo3/network/payloads/transaction.py +++ b/neo3/network/payloads/transaction.py @@ -192,10 +192,9 @@ def hash(self) -> types.UInt256: Get a unique block identifier based on the unsigned data portion of the object. """ with serialization.BinaryWriter() as bw: - bw.write_uint32(self.protocol_magic) self.serialize_unsigned(bw) data_to_hash = bytearray(bw._stream.getvalue()) - data = hashlib.sha256(hashlib.sha256(data_to_hash).digest()).digest() + data = hashlib.sha256(data_to_hash).digest() return types.UInt256(data=data) @property @@ -242,7 +241,9 @@ def deserialize(self, reader: serialization.BinaryReader) -> None: reader: instance. """ self.deserialize_unsigned(reader) - self.witnesses = reader.read_serializable_list(payloads.Witness) + self.witnesses = reader.read_serializable_list(payloads.Witness, max=len(self.signers)) + if len(self.witnesses) != len(self.signers): + raise ValueError("Deserialization error - witness length does not match signers length") def deserialize_unsigned(self, reader: serialization.BinaryReader) -> None: (self.version, diff --git a/neo3/storage/__init__.py b/neo3/storage/__init__.py index 34aa847e..4a40d19d 100644 --- a/neo3/storage/__init__.py +++ b/neo3/storage/__init__.py @@ -1,6 +1,6 @@ from .base import (IDBImplementation, StorageContext) -from .storageitem import StorageItem, StorageFlags +from .storageitem import StorageItem from .storagekey import StorageKey from .cache import (Trackable, TrackState, diff --git a/neo3/storage/storageitem.py b/neo3/storage/storageitem.py index bd8f445c..3b7266ad 100644 --- a/neo3/storage/storageitem.py +++ b/neo3/storage/storageitem.py @@ -1,22 +1,15 @@ from __future__ import annotations from typing import Type, Optional -from enum import IntFlag from neo3.core import serialization, utils, IClonable, Size as s -class StorageFlags(IntFlag): - NONE = 0 - CONSTANT = 0x1 - - class StorageItem(serialization.ISerializable, IClonable): - def __init__(self, value: bytes, is_constant=False): + def __init__(self, value: bytes): self._value = value - self.is_constant = is_constant self._cache: Optional[serialization.ISerializable] = None def __len__(self): - return utils.get_var_size(self.value) + s.uint8 + return utils.get_var_size(self.value) def __eq__(self, other): if not isinstance(other, type(self)): @@ -35,19 +28,17 @@ def value(self, new_value: bytes) -> None: self._cache = None def serialize(self, writer: serialization.BinaryWriter) -> None: - writer.write_var_bytes(self.value) - writer.write_bool(self.is_constant) + writer.write_bytes(self.value) def deserialize(self, reader: serialization.BinaryReader) -> None: - self.value = reader.read_var_bytes() - self.is_constant = reader.read_bool() + remaining_stream_size = len(reader) - reader._stream.tell() + self.value = reader.read_bytes(remaining_stream_size) def clone(self) -> StorageItem: - return StorageItem(self.value, self.is_constant) + return StorageItem(self.value) def from_replica(self, replica: StorageItem) -> None: self.value = replica.value - self.is_constant = replica.is_constant def get(self, type_: Type[serialization.ISerializable]): if self._cache and type(self._cache) == type_: diff --git a/requirements.txt b/requirements.txt index 78776803..e8529d38 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,8 +6,8 @@ coverage>=5.0.2 Events==0.3 lz4==2.2.1 neo3crypto==0.2 -neo3vm>=0.5.1 -neo3vm-stubs>=0.5.1 +neo3vm>=0.6 +neo3vm-stubs>=0.6 mmh3==2.5.1 mypy>=0.782 mypy-extensions==0.4.3 diff --git a/tests/contracts/interop/test_binary.py b/tests/contracts/interop/test_binary.py deleted file mode 100644 index 06140dd8..00000000 --- a/tests/contracts/interop/test_binary.py +++ /dev/null @@ -1,113 +0,0 @@ -import unittest -from neo3 import vm -from tests.contracts.interop.utils import test_engine - - -class BinaryInteropTestCase(unittest.TestCase): - def test_serialization(self): - engine = test_engine() - engine.push(vm.IntegerStackItem(100)) - engine.invoke_syscall_by_name("System.Binary.Serialize") - item = engine.pop() - self.assertIsInstance(item, vm.ByteStringStackItem) - self.assertEqual(b'\x21\x01\x64', item.to_array()) - - # Create an item with data larger than engine.MAX_ITEM_SIZE - # this should fail in the BinarySerializer class - engine.push(vm.ByteStringStackItem(b'\x01' * (1024 * 1024 * 2))) - with self.assertRaises(ValueError) as context: - engine.invoke_syscall_by_name("System.Binary.Serialize") - self.assertEqual("Output length exceeds max size", str(context.exception)) - - def test_deserialization(self): - engine = test_engine() - original_item = vm.IntegerStackItem(100) - engine.push(original_item) - engine.invoke_syscall_by_name("System.Binary.Serialize") - engine.invoke_syscall_by_name("System.Binary.Deserialize") - item = engine.pop() - self.assertEqual(original_item, item) - - engine.push(vm.ByteStringStackItem(b'\xfa\x01')) - with self.assertRaises(ValueError) as context: - engine.invoke_syscall_by_name("System.Binary.Deserialize") - self.assertEqual("Invalid format", str(context.exception)) - - def test_base64(self): - engine = test_engine() - original_item = vm.IntegerStackItem(100) - engine.push(original_item) - engine.invoke_syscall_by_name("System.Binary.Base64Encode") - item = engine.pop() - self.assertEqual('ZA==', item.to_array().decode()) - - engine.push(item) - engine.invoke_syscall_by_name("System.Binary.Base64Decode") - item = engine.pop() - self.assertEqual(original_item, vm.IntegerStackItem(item.to_array())) - - def test_base58(self): - engine = test_engine() - original_item = vm.IntegerStackItem(100) - engine.push(original_item) - engine.invoke_syscall_by_name("System.Binary.Base58Encode") - item = engine.pop() - self.assertEqual('2j', item.to_array().decode()) - - engine.push(item) - engine.invoke_syscall_by_name("System.Binary.Base58Decode") - item = engine.pop() - self.assertEqual(original_item, vm.IntegerStackItem(item.to_array())) - - def test_itoa(self): - engine = test_engine() - original_item = vm.IntegerStackItem(100) - base = vm.IntegerStackItem(10) - engine.push(base) - engine.push(original_item) - engine.invoke_syscall_by_name("System.Binary.Itoa") - item = engine.pop() - self.assertEqual('100', item.to_array().decode('utf-8')) - - engine = test_engine() - base = vm.IntegerStackItem(16) - engine.push(base) - engine.push(original_item) - engine.invoke_syscall_by_name("System.Binary.Itoa") - item = engine.pop() - self.assertEqual('64', item.to_array().decode('utf-8')) - - engine = test_engine() - invalid_base = vm.IntegerStackItem(2) - engine.push(invalid_base) - engine.push(original_item) - with self.assertRaises(ValueError) as context: - engine.invoke_syscall_by_name("System.Binary.Itoa") - self.assertIn("Invalid base specified", str(context.exception)) - - def test_atoi(self): - engine = test_engine() - original_item = vm.ByteStringStackItem(b'100') - base = vm.IntegerStackItem(10) - engine.push(base) - engine.push(original_item) - engine.invoke_syscall_by_name("System.Binary.Atoi") - item = engine.pop() - self.assertEqual(vm.IntegerStackItem(100), item) - - engine = test_engine() - original_item = vm.ByteStringStackItem(b'64') - base = vm.IntegerStackItem(16) - engine.push(base) - engine.push(original_item) - engine.invoke_syscall_by_name("System.Binary.Atoi") - item = engine.pop() - self.assertEqual(vm.IntegerStackItem(100), item) - - engine = test_engine() - invalid_base = vm.IntegerStackItem(2) - engine.push(invalid_base) - engine.push(original_item) - with self.assertRaises(ValueError) as context: - engine.invoke_syscall_by_name("System.Binary.Atoi") - self.assertIn("Invalid base specified", str(context.exception)) diff --git a/tests/contracts/interop/test_contract_interop.py b/tests/contracts/interop/test_contract_interop.py index 509ad303..d681f219 100644 --- a/tests/contracts/interop/test_contract_interop.py +++ b/tests/contracts/interop/test_contract_interop.py @@ -131,13 +131,6 @@ def test_func2(value: int) -> int: contract3_manifest = contracts.ContractManifest.from_json(raw_contract3_manifest) -def to_contract_hash(sender: types.UInt160, script: bytes): - sb = vm.ScriptBuilder() - sb.emit(vm.OpCode.ABORT) - sb.emit_push(sender.to_array()) - sb.emit_push(script) - return to_script_hash(sb.to_array()) - class RuntimeInteropTestCase(unittest.TestCase): def shortDescription(self): # disable docstring printing in test runner @@ -257,55 +250,6 @@ def test_contract_call_exceptions(self): engine._contract_call_internal(target_contract.hash, "test_func", contracts.CallFlags.ALL, False, array) self.assertEqual("[System.Contract.Call] Method 'test_func' with 2 arguments does not exist on target contract", str(context.exception)) - @unittest.SkipTest - def test_contract_is_standard_ok(self): - keypair = cryptography.KeyPair(b'\x01' * 32) - sig_contract = contracts.Contract.create_signature_contract(keypair.public_key) - - engine = test_engine(has_snapshot=True) - contract = contracts.ContractState(sig_contract.script, contracts.ContractManifest(sig_contract.script_hash)) - engine.snapshot.contracts.put(contract) - engine.push(vm.ByteStringStackItem(contract.script_hash().to_array())) - engine.invoke_syscall_by_name("System.Contract.IsStandard") - engine.execute() - self.assertEqual(True, engine.result_stack.pop().to_boolean()) - - def test_contract_is_standard_fail(self): - # can't find contract - engine = test_engine(has_snapshot=True) - engine.push(vm.ByteStringStackItem(types.UInt160.zero().to_array())) - engine.invoke_syscall_by_name("System.Contract.IsStandard") - engine.execute() - self.assertEqual(False, engine.result_stack.pop().to_boolean()) - - @unittest.SkipTest - def test_contract_is_standard_fail2(self): - # can find contract, but is not a signature contract - engine = test_engine(has_snapshot=True) - - # create a non-standard contract - script = b'\x01\x02\x03' - script_hash = to_script_hash(script) - manifest = contracts.ContractManifest(script_hash) - contract = contracts.ContractState(script, manifest) - engine.snapshot.contracts.put(contract) - - # push function argument and call - engine.push(vm.ByteStringStackItem(script_hash.to_array())) - engine.invoke_syscall_by_name("System.Contract.IsStandard") - engine.execute() - self.assertEqual(False, engine.result_stack.pop().to_boolean()) - - def test_contract_is_standard_fail3(self): - # test on witnesses of a transaction - engine = test_engine(has_container=True, has_snapshot=True) - witness = payloads.Witness(invocation_script=b'\x01', verification_script=b'\x02') - engine.script_container.witnesses = [witness] - engine.push(vm.ByteStringStackItem(witness.script_hash().to_array())) - engine.invoke_syscall_by_name("System.Contract.IsStandard") - engine.execute() - self.assertEqual(False, engine.result_stack.pop().to_boolean()) - def test_contract_call_flags(self): engine = test_engine() engine.invoke_syscall_by_name("System.Contract.GetCallFlags") diff --git a/tests/contracts/interop/test_crypto.py b/tests/contracts/interop/test_crypto.py deleted file mode 100644 index b358ec61..00000000 --- a/tests/contracts/interop/test_crypto.py +++ /dev/null @@ -1,403 +0,0 @@ -import unittest -import binascii -from typing import List -from neo3.network import payloads -from neo3 import vm, storage, settings -from neo3.core import types, serialization, cryptography -from neo3.core.serialization import BinaryReader, BinaryWriter -from tests.contracts.interop.utils import test_engine, syscall_name_to_int -from neo3.contracts.interop.crypto import _check_multisig - - -class TestVerifiable(payloads.IVerifiable): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.teststr = "testStr" - - def serialize_unsigned(self, writer: serialization.BinaryWriter) -> None: - writer.write_uint8(len(self.teststr)) - writer.write_bytes(self.teststr.encode()) - - def deserialize_unsigned(self, reader: serialization.BinaryReader) -> None: - raise NotImplementedError() - - def get_script_hashes_for_verifying(self, snapshot: storage.Snapshot) -> List[types.UInt160]: - raise NotImplementedError() - - def serialize(self, writer: BinaryWriter) -> None: - raise NotImplementedError() - - def deserialize(self, reader: BinaryReader) -> None: - raise NotImplementedError() - - def __len__(self): - pass - - -class CryptoInteropTestCase(unittest.TestCase): - def shortDescription(self): - # disable docstring printing in test runner - return None - - def test_ripemd160_interop_type(self): - """ - using var script = new ScriptBuilder(); - script.EmitSysCall(ApplicationEngine.Neo_Crypto_RIPEMD160); // Syscall - var engine = ApplicationEngine.Create(TriggerType.Application, null, null, 100_000_000, false); - engine.LoadScript(script.ToArray()); - engine.Push(new InteropInterface(new TestVerifiable())); - Assert.AreEqual(engine.Execute(), VMState.HALT); - Assert.AreEqual(1, engine.ResultStack.Count); - var item = engine.ResultStack.Pop(); - Console.WriteLine($"{item.GetSpan().ToHexString()}"); - """ - # we have to set the network magic number, because that is serialized as part of the "get_hash_data()" call - settings.network.magic = 0x4F454E - - sb = vm.ScriptBuilder() - sb.emit_syscall(syscall_name_to_int("Neo.Crypto.RIPEMD160")) - - engine = test_engine() - script = vm.Script(sb.to_array()) - engine.load_script(script) - - # first test with an invalid interop item. They must be IVerifiable - engine.push(vm.InteropStackItem(object())) - engine.execute() - self.assertEqual(vm.VMState.FAULT, engine.state) - self.assertIn("Invalid type", engine.exception_message) - - engine = test_engine() - engine.load_script(script) - engine.push(vm.InteropStackItem(TestVerifiable())) - engine.execute() - self.assertEqual(vm.VMState.HALT, engine.state) - self.assertEqual(1, len(engine.result_stack)) - # captured from C# - expected = '72543eb0fa0ca623a95647f15dd55f52a327c77e' - self.assertEqual(expected, str(engine.result_stack.pop())) - - def test_ripemd160_null(self): - """ - var tx = new Neo.Network.P2P.Payloads.Transaction - { - Version = 0, - Nonce = 0, - SystemFee = 0, - NetworkFee = 0, - ValidUntilBlock = 99999, - Attributes = new TransactionAttribute[0], - Script = new byte[0], - Signers = new Signer[] { new Signer { Account = UInt160.Zero, Scopes = WitnessScope.FeeOnly}} - }; - using var script = new ScriptBuilder(); - script.EmitSysCall(ApplicationEngine.Neo_Crypto_RIPEMD160); // Syscall - var engine = ApplicationEngine.Create(TriggerType.Application, tx, null, 100_000_000, false); - engine.LoadScript(script.ToArray()); - engine.Push(StackItem.Null); - Assert.AreEqual(engine.Execute(), VMState.HALT); - Assert.AreEqual(1, engine.ResultStack.Count); - var item = engine.ResultStack.Pop(); - Console.WriteLine($"{item.GetSpan().ToHexString()}"); - """ - # we have to set the network magic number, because that is serialized as part of the "get_hash_data()" call - settings.network.magic = 0x4F454E - - sb = vm.ScriptBuilder() - sb.emit_syscall(syscall_name_to_int("Neo.Crypto.RIPEMD160")) - - engine = test_engine(has_container=True) - engine.script_container.signers = [payloads.Signer(types.UInt160.zero())] - script = vm.Script(sb.to_array()) - engine.load_script(script) - - engine.push(vm.NullStackItem()) - engine.execute() - self.assertEqual(vm.VMState.HALT, engine.state) - self.assertEqual(1, len(engine.result_stack)) - # captured from C# - expected = '0892b2402eb78d878a4c60fc799d879b672a5aa5' - self.assertEqual(expected, str(engine.result_stack.pop())) - - def test_ripemd160_other_types(self): - """ - using var script = new ScriptBuilder(); - script.EmitPush(new byte[] {0x1, 0x2, 0x3, 0x4}); - script.EmitSysCall(ApplicationEngine.Neo_Crypto_RIPEMD160); // Syscall - var engine = ApplicationEngine.Create(TriggerType.Application, null, null, 100_000_000, false); - engine.LoadScript(script.ToArray()); - Assert.AreEqual(engine.Execute(), VMState.HALT); - Assert.AreEqual(1, engine.ResultStack.Count); - var item = engine.ResultStack.Pop(); - Console.WriteLine($"{item.GetSpan().ToHexString()}"); - """ - sb = vm.ScriptBuilder() - sb.emit_push(b'\x01\x02\x03\x04') - sb.emit_syscall(syscall_name_to_int("Neo.Crypto.RIPEMD160")) - - engine = test_engine() - script = vm.Script(sb.to_array()) - engine.load_script(script) - engine.execute() - self.assertEqual(vm.VMState.HALT, engine.state) - self.assertEqual(1, len(engine.result_stack)) - # captured from C# - expected = '179bb366e5e224b8bf4ce302cefc5744961839c5' - self.assertEqual(expected, str(engine.result_stack.pop())) - - def test_sha256(self): - """ - using var script = new ScriptBuilder(); - script.EmitPush(new byte[] {0x1, 0x2, 0x3, 0x4}); - script.EmitSysCall(ApplicationEngine.Neo_Crypto_SHA256); // Syscall - var engine = ApplicationEngine.Create(TriggerType.Application, null, null, 100_000_000, false); - engine.LoadScript(script.ToArray()); - Assert.AreEqual(engine.Execute(), VMState.HALT); - Assert.AreEqual(1, engine.ResultStack.Count); - var item = engine.ResultStack.Pop(); - Console.WriteLine($"{item.GetSpan().ToHexString()}"); - """ - sb = vm.ScriptBuilder() - sb.emit_push(b'\x01\x02\x03\x04') - sb.emit_syscall(syscall_name_to_int("Neo.Crypto.SHA256")) - - engine = test_engine() - script = vm.Script(sb.to_array()) - engine.load_script(script) - engine.execute() - self.assertEqual(vm.VMState.HALT, engine.state) - self.assertEqual(1, len(engine.result_stack)) - # captured from C# - expected = '9f64a747e1b97f131fabb6b447296c9b6f0201e79fb3c5356e6c77e89b6a806a' - self.assertEqual(expected, str(engine.result_stack.pop())) - - def test_verify_secp256r1(self): - """ - var privkey = new byte[] - { - 2, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 1 - }; - var message = new byte[] - { - 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1 - }; - var signature = new byte[] { 56,70,104,22,234,182,23,161,111,25,71,188,12,5,54,28,99,189,8,47,4,82,62,150,57,216,25,130,217,25,123,118,89,149,217,130,12,109,34,125,176,189,142,119,154,140,116,16,32,209,214,87,178,248,214,39,248,29,214,10,205,153,146,111}; - var kp = new KeyPair(privkey); - Console.WriteLine(Crypto.VerifySignature(message, signature, kp.PublicKey.EncodePoint(false), ECCurve.Secp256r1)); - """ - message = b'\x01' * 32 - priv_key = b'\x02' + b'\x00' * 30 + b'\x01' - sig = cryptography.sign(message, priv_key) - - # from ecdsa import VerifyingKey, SigningKey, curves as ecdsa_curves - # import hashlib - # sk = SigningKey.from_string(priv_key, curve=ecdsa_curves.NIST256p, hashfunc=hashlib.sha256) - # sig = sk.sign(message, hashfunc=hashlib.sha256) - - kp = cryptography.KeyPair(priv_key) - - sb = vm.ScriptBuilder() - sb.emit_push(sig) - sb.emit_push(kp.public_key.encode_point(False)) - sb.emit_push(message) - sb.emit_syscall(syscall_name_to_int("Neo.Crypto.VerifyWithECDsaSecp256r1")) - - engine = test_engine() - script = vm.Script(sb.to_array()) - engine.load_script(script) - - # first test with an invalid interop item. They must be IVerifiable - engine.execute() - self.assertEqual(vm.VMState.HALT, engine.state) - self.assertEqual(1, len(engine.result_stack)) - self.assertEqual(vm.BooleanStackItem(True), engine.result_stack.pop()) - - def test_verify_secp256k1(self): - """ - byte[] message = System.Text.Encoding.Default.GetBytes("hello"); - byte[] signature = "5331be791532d157df5b5620620d938bcb622ad02c81cfc184c460efdad18e695480d77440c511e9ad02ea30d773cb54e88f8cbb069644aefa283957085f38b5".HexToBytes(); - byte[] pubKey = "03ea01cb94bdaf0cd1c01b159d474f9604f4af35a3e2196f6bdfdb33b2aa4961fa".HexToBytes(); - - Crypto.VerifySignature(message, signature, pubKey, Neo.Cryptography.ECC.ECCurve.Secp256k1).Should().BeTrue(); - """ - message = b'hello' - signature = binascii.unhexlify(b'5331be791532d157df5b5620620d938bcb622ad02c81cfc184c460efdad18e695480d77440c511e9ad02ea30d773cb54e88f8cbb069644aefa283957085f38b5') - public_key = binascii.unhexlify(b'03ea01cb94bdaf0cd1c01b159d474f9604f4af35a3e2196f6bdfdb33b2aa4961fa') - self.assertTrue(cryptography.verify_signature(message, signature, public_key, cryptography.ECCCurve.SECP256K1)) - - sb = vm.ScriptBuilder() - sb.emit_push(signature) - sb.emit_push(public_key) - sb.emit_push(message) - sb.emit_syscall(syscall_name_to_int("Neo.Crypto.VerifyWithECDsaSecp256k1")) - - engine = test_engine() - script = vm.Script(sb.to_array()) - engine.load_script(script) - - engine.execute() - self.assertEqual(vm.VMState.HALT, engine.state) - self.assertEqual(1, len(engine.result_stack)) - self.assertEqual(vm.BooleanStackItem(True), engine.result_stack.pop()) - - # again with bad signature - bad_signature = b'\xFF' + signature[1:] - sb = vm.ScriptBuilder() - sb.emit_push(bad_signature) - sb.emit_push(public_key) - sb.emit_push(message) - sb.emit_syscall(syscall_name_to_int("Neo.Crypto.VerifyWithECDsaSecp256k1")) - - engine = test_engine() - script = vm.Script(sb.to_array()) - engine.load_script(script) - - engine.execute() - self.assertEqual(vm.VMState.HALT, engine.state) - self.assertEqual(1, len(engine.result_stack)) - self.assertEqual(vm.BooleanStackItem(False), engine.result_stack.pop()) - - def test_multisig_verify_helper_bounds(self): - engine = None - message = vm.ByteStringStackItem(b'') - public_keys = [object()] - signatures = [] - - with self.assertRaises(ValueError) as context: - _check_multisig(engine, message, public_keys, signatures, cryptography.ECCCurve.SECP256R1) - self.assertEqual("No signatures supplied", str(context.exception)) - - public_keys = [] - signatures = [object()] - with self.assertRaises(ValueError) as context: - _check_multisig(engine, message, public_keys, signatures, cryptography.ECCCurve.SECP256R1) - self.assertEqual("No public keys supplied", str(context.exception)) - - public_keys = [object()] - signatures = [object(), object()] - with self.assertRaises(ValueError) as context: - _check_multisig(engine, message, public_keys, signatures, cryptography.ECCCurve.SECP256R1) - self.assertEqual("Verification requires 2 public keys, got only 1", str(context.exception)) - - def test_multisig_verify_helper_verification(self): - engine = test_engine() - message = vm.ByteStringStackItem(b'hello') - kp1 = cryptography.KeyPair(private_key=b'\x01' * 32) - kp2 = cryptography.KeyPair(private_key=b'\x02' * 32) - sig1 = cryptography.sign(message.to_array(), kp1.private_key) - sig2 = cryptography.sign(message.to_array(), kp2.private_key) - - # quick pre-check the verify_signature function actually passes - self.assertTrue(cryptography.verify_signature(message.to_array(), - sig1, - kp1.public_key.encode_point(False), - cryptography.ECCCurve.SECP256R1)) - self.assertTrue(cryptography.verify_signature(message.to_array(), - sig2, - kp2.public_key.encode_point(False), - cryptography.ECCCurve.SECP256R1)) - - # first do a check on regular data (meaning; check sig1 with pub_key1, sig2 with pub_key2) - public_keys = [kp1.public_key.encode_point(False), kp2.public_key.encode_point(False)] - signatures = [sig1, sig2] - self.assertTrue(_check_multisig(engine, message, public_keys, signatures, cryptography.ECCCurve.SECP256R1)) - - # same as previous, but supplying the keys out of order - public_keys = [kp2.public_key.encode_point(False), kp1.public_key.encode_point(False)] - signatures = [sig1, sig2] - self.assertFalse(_check_multisig(engine, message, public_keys, signatures, cryptography.ECCCurve.SECP256R1)) - - # now validate it will try all available public keys for a given signature (for 1-of-2, 3-of-5 like contracts) - public_keys = [kp2.public_key.encode_point(False), kp1.public_key.encode_point(False)] - signatures = [sig1] - self.assertTrue(_check_multisig(engine, message, public_keys, signatures, cryptography.ECCCurve.SECP256R1)) - - # test handling an exception caused by an invalid public key - public_keys = [b''] - signatures = [sig1] - self.assertFalse(_check_multisig(engine, message, public_keys, signatures, cryptography.ECCCurve.SECP256R1)) - - def test_check_multisig_with_ECDSA_Secp256r1_valid(self): - engine = test_engine() - message = vm.ByteStringStackItem(b'hello') - kp1 = cryptography.KeyPair(private_key=b'\x01' * 32) - sig1 = cryptography.sign(message.to_array(), kp1.private_key) - - signatures = vm.ArrayStackItem(engine.reference_counter) - signatures.append(vm.ByteStringStackItem(sig1)) - - public_keys = vm.ArrayStackItem(engine.reference_counter) - public_keys.append(vm.ByteStringStackItem(kp1.public_key.encode_point(False))) - - sb = vm.ScriptBuilder() - sb.emit_syscall(syscall_name_to_int("Neo.Crypto.CheckMultisigWithECDsaSecp256r1")) - script = vm.Script(sb.to_array()) - engine.load_script(script) - - # setup the stack for the syscall - engine.push(signatures) - engine.push(public_keys) - engine.push(message) - engine.execute() - self.assertEqual(vm.VMState.HALT, engine.state) - self.assertEqual(1, len(engine.result_stack)) - self.assertEqual(vm.BooleanStackItem(True), engine.result_stack.pop()) - - def test_check_multisig_with_ECDSA_Secp256r1_invalid(self): - engine = test_engine() - message = vm.ByteStringStackItem(b'hello') - bad_message = vm.ByteStringStackItem(b'badmessage') - - kp1 = cryptography.KeyPair(private_key=b'\x01' * 32) - sig1 = cryptography.sign(message.to_array(), kp1.private_key) - - signatures = vm.ArrayStackItem(engine.reference_counter) - signatures.append(vm.ByteStringStackItem(sig1)) - - public_keys = vm.ArrayStackItem(engine.reference_counter) - public_keys.append(vm.ByteStringStackItem(kp1.public_key.encode_point(False))) - - sb = vm.ScriptBuilder() - sb.emit_syscall(syscall_name_to_int("Neo.Crypto.CheckMultisigWithECDsaSecp256r1")) - script = vm.Script(sb.to_array()) - engine.load_script(script) - - # setup the stack for the syscall using a different message such that verification should fail - engine.push(signatures) - engine.push(public_keys) - engine.push(bad_message) - engine.execute() - self.assertEqual(vm.VMState.HALT, engine.state) - self.assertEqual(1, len(engine.result_stack)) - self.assertEqual(vm.BooleanStackItem(False), engine.result_stack.pop()) - - def test_check_multisig_with_ECDSA_Secp256k1(self): - # values taken from test_verify_secp256k1() - engine = test_engine() - message = vm.ByteStringStackItem(b'hello') - signature = vm.ByteStringStackItem(binascii.unhexlify(b'5331be791532d157df5b5620620d938bcb622ad02c81cfc184c460efdad18e695480d77440c511e9ad02ea30d773cb54e88f8cbb069644aefa283957085f38b5')) - signatures = vm.ArrayStackItem(engine.reference_counter) - signatures.append(signature) - - public_keys = vm.ArrayStackItem(engine.reference_counter) - public_key = vm.ByteStringStackItem(binascii.unhexlify(b'03ea01cb94bdaf0cd1c01b159d474f9604f4af35a3e2196f6bdfdb33b2aa4961fa')) - public_keys.append(public_key) - - sb = vm.ScriptBuilder() - sb.emit_syscall(syscall_name_to_int("Neo.Crypto.CheckMultisigWithECDsaSecp256k1")) - script = vm.Script(sb.to_array()) - engine.load_script(script) - - engine.push(signatures) - engine.push(public_keys) - engine.push(message) - engine.execute() - self.assertEqual(vm.VMState.HALT, engine.state) - self.assertEqual(1, len(engine.result_stack)) - self.assertEqual(vm.BooleanStackItem(True), engine.result_stack.pop()) diff --git a/tests/contracts/interop/test_json_interop.py b/tests/contracts/interop/test_json_interop.py deleted file mode 100644 index 37c1dcc8..00000000 --- a/tests/contracts/interop/test_json_interop.py +++ /dev/null @@ -1,58 +0,0 @@ -import unittest -import binascii -from neo3 import vm -from neo3 import contracts -from tests.contracts.interop.utils import syscall_name_to_int - - -class JSONInteropTestCase(unittest.TestCase): - def test_serialization(self): - script = vm.ScriptBuilder() - script.emit_push(5) - script.emit_syscall(syscall_name_to_int("System.Json.Serialize")) - script.emit(vm.OpCode.PUSH0) - script.emit(vm.OpCode.NOT) - script.emit_syscall(syscall_name_to_int("System.Json.Serialize")) - script.emit_push("test") - script.emit_syscall(syscall_name_to_int("System.Json.Serialize")) - script.emit(vm.OpCode.PUSHNULL) - script.emit_syscall(syscall_name_to_int("System.Json.Serialize")) - script.emit(vm.OpCode.NEWMAP) - script.emit(vm.OpCode.DUP) - script.emit_push("key") - script.emit_push("value") - script.emit(vm.OpCode.SETITEM) - script.emit_syscall(syscall_name_to_int("System.Json.Serialize")) - - data = script.to_array() - - engine = contracts.ApplicationEngine(contracts.TriggerType.APPLICATION, None, None, 0, True) - engine.load_script(vm.Script(data)) - engine.execute() - self.assertEqual(vm.VMState.HALT, engine.state) - - def pop_to_human_readable(): - return binascii.unhexlify(str(engine.result_stack.pop()).encode()).decode() - - self.assertEqual('{"key":"value"}', pop_to_human_readable()) - self.assertEqual('null', pop_to_human_readable()) - self.assertEqual('"test"', pop_to_human_readable()) - self.assertEqual('true', pop_to_human_readable()) - self.assertEqual('5', pop_to_human_readable()) - - def test_deserialization(self): - script = vm.ScriptBuilder() - script.emit_push(123) - script.emit_syscall(syscall_name_to_int("System.Json.Deserialize")) - script.emit_push("null") - script.emit_syscall(syscall_name_to_int("System.Json.Deserialize")) - - data = script.to_array() - - engine = contracts.ApplicationEngine(contracts.TriggerType.APPLICATION, None, None, 0, True) - engine.load_script(vm.Script(data)) - engine.execute() - self.assertEqual(vm.VMState.HALT, engine.state) - self.assertEqual(2, len(engine.result_stack._items)) - self.assertIsInstance(engine.result_stack.pop(), vm.NullStackItem) - self.assertEqual(vm.BigInteger(123), engine.result_stack.pop().to_biginteger()) diff --git a/tests/contracts/interop/test_storage.py b/tests/contracts/interop/test_storage.py index fc1309ec..86a40a6e 100644 --- a/tests/contracts/interop/test_storage.py +++ b/tests/contracts/interop/test_storage.py @@ -140,10 +140,10 @@ def test_storage_find(self): engine.snapshot.contracts.put(self.contract) storage_key1 = storage.StorageKey(self.contract.id, b'\x01') - storage_item1 = storage.StorageItem(b'\x11', is_constant=False) + storage_item1 = storage.StorageItem(b'\x11') engine.snapshot.storages.put(storage_key1, storage_item1) storage_key2 = storage.StorageKey(self.contract.id, b'\x02') - storage_item2 = storage.StorageItem(b'\x22', is_constant=False) + storage_item2 = storage.StorageItem(b'\x22') engine.snapshot.storages.put(storage_key2, storage_item2) ctx = engine.invoke_syscall_by_name("System.Storage.GetContext") @@ -166,79 +166,60 @@ def test_storage_find(self): def test_storage_put_helper_parameter_validation(self): with self.assertRaises(ValueError) as context: key = (b'\x01' * contracts.interop.MAX_STORAGE_KEY_SIZE) + b'\x01' - contracts.interop._storage_put_internal(None, None, key, b'', None) + contracts.interop.storage_put(None, None, key, b'') self.assertEqual(f"Storage key length exceeds maximum of {contracts.interop.MAX_STORAGE_KEY_SIZE}", str(context.exception)) with self.assertRaises(ValueError) as context: value = (b'\x01' * contracts.interop.MAX_STORAGE_VALUE_SIZE) + b'\x01' - contracts.interop._storage_put_internal(None, None, b'', value, None) + contracts.interop.storage_put(None, None, b'', value) self.assertEqual(f"Storage value length exceeds maximum of {contracts.interop.MAX_STORAGE_VALUE_SIZE}", str(context.exception)) with self.assertRaises(ValueError) as context: ctx = storage.StorageContext(None, is_read_only=True) - contracts.interop._storage_put_internal(None, ctx, b'', b'', None) + contracts.interop.storage_put(None, ctx, b'', b'') self.assertEqual("Cannot persist to read-only storage context", str(context.exception)) - # finaly make sure it fails if we try to modify an item that is marked constant + def test_storage_put_new(self): + # see `test_storage_get_key_not_found()` for a description on why the storage is setup with a script as is + + # setup engine = test_engine(has_snapshot=True) - key = storage.StorageKey(1, b'\x01') - item = storage.StorageItem(b'', is_constant=True) - engine.snapshot.storages.put(key, item) + script = vm.ScriptBuilder() + script.emit(vm.OpCode.PUSH2) # storage put value + script.emit(vm.OpCode.PUSH1) # storage put key + script.emit_syscall(syscall_name_to_int("System.Storage.GetContext")) + script.emit_syscall(syscall_name_to_int("System.Storage.Put")) + engine.load_script(vm.Script(script.to_array())) - with self.assertRaises(ValueError) as context: - ctx = storage.StorageContext(1, is_read_only=False) - contracts.interop._storage_put_internal(engine, ctx, b'\x01', b'', storage.StorageFlags.NONE) - self.assertEqual("StorageItem is marked as constant", str(context.exception)) + nef = contracts.NEF(script=script.to_array()) + manifest = contracts.ContractManifest(f"contractname1") + manifest.abi.methods = [ + contracts.ContractMethodDescriptor("test_func", 0, [], contracts.ContractParameterType.ANY, True) + ] + hash_ = to_script_hash(nef.script) - def test_storage_put_new(self): - # see `test_storage_get_key_not_found()` for a description on why the storage is setup with a script as is + contract = contracts.ContractState(1, nef, manifest, 0, hash_) + engine.snapshot.contracts.put(contract) - for i in range(2): - # setup - engine = test_engine(has_snapshot=True) - script = vm.ScriptBuilder() - if i == 0: - script.emit(vm.OpCode.PUSH2) # storage put value - script.emit(vm.OpCode.PUSH1) # storage put key - script.emit_syscall(syscall_name_to_int("System.Storage.GetContext")) - script.emit_syscall(syscall_name_to_int("System.Storage.Put")) - else: - script.emit(vm.OpCode.PUSH0) # storage put call flags - script.emit(vm.OpCode.PUSH2) # storage put value - script.emit(vm.OpCode.PUSH1) # storage put key - script.emit_syscall(syscall_name_to_int("System.Storage.GetContext")) - script.emit_syscall(syscall_name_to_int("System.Storage.PutEx")) - engine.load_script(vm.Script(script.to_array())) - - nef = contracts.NEF(script=script.to_array()) - manifest = contracts.ContractManifest(f"contractname{i}") - manifest.abi.methods = [ - contracts.ContractMethodDescriptor("test_func", 0, [], contracts.ContractParameterType.ANY, True) - ] - hash_ = to_script_hash(nef.script) - - contract = contracts.ContractState(i, nef, manifest, 0, hash_) - engine.snapshot.contracts.put(contract) - - engine.execute() - - self.assertEqual(vm.VMState.HALT, engine.state) - storage_key = storage.StorageKey(i, b'\x01') - item = engine.snapshot.storages.try_get(storage_key) - self.assertIsNotNone(item) - self.assertEqual(b'\x02', item.value) + engine.execute() + + self.assertEqual(vm.VMState.HALT, engine.state) + storage_key = storage.StorageKey(1, b'\x01') + item = engine.snapshot.storages.try_get(storage_key) + self.assertIsNotNone(item) + self.assertEqual(b'\x02', item.value) def test_storage_put_overwrite(self): # test with new data being shorter than the old data engine = test_engine(has_snapshot=True) key = b'\x01' storage_key = storage.StorageKey(1, key) - storage_item = storage.StorageItem(b'\x11\x22\x33', is_constant=False) + storage_item = storage.StorageItem(b'\x11\x22\x33') engine.snapshot.storages.put(storage_key, storage_item) ctx = storage.StorageContext(1, is_read_only=False) new_item_value = b'\x11\x22' - contracts.interop._storage_put_internal(engine, ctx, key, new_item_value, storage.StorageFlags.NONE) + contracts.interop.storage_put(engine, ctx, key, new_item_value) item = engine.snapshot.storages.get(storage_key) self.assertIsNotNone(item) @@ -246,7 +227,7 @@ def test_storage_put_overwrite(self): # now test with data being longer than before longer_item_value = b'\x11\x22\x33\x44' - contracts.interop._storage_put_internal(engine, ctx, key, longer_item_value, storage.StorageFlags.NONE) + contracts.interop.storage_put(engine, ctx, key, longer_item_value) item = engine.snapshot.storages.get(storage_key) self.assertIsNotNone(item) @@ -269,28 +250,12 @@ def test_storage_delete_readonly_context(self): engine.invoke_syscall_by_name("System.Storage.Delete") self.assertEqual("Cannot delete from read-only storage context", str(context.exception)) - def test_storage_delete_constant_item(self): - engine = test_engine(has_snapshot=True) - engine.snapshot.contracts.put(self.contract) - - storage_key = storage.StorageKey(self.contract.id, b'\x01') - storage_item = storage.StorageItem(b'\x11', is_constant=True) - engine.snapshot.storages.put(storage_key, storage_item) - - ctx = engine.invoke_syscall_by_name("System.Storage.GetContext") - engine.push(vm.ByteStringStackItem(storage_key.key)) - engine.push(vm.StackItem.from_interface(ctx)) - - with self.assertRaises(ValueError) as context: - engine.invoke_syscall_by_name("System.Storage.Delete") - self.assertEqual("Cannot delete a storage item that is marked constant", str(context.exception)) - def test_delete_ok(self): engine = test_engine(has_snapshot=True) engine.snapshot.contracts.put(self.contract) storage_key = storage.StorageKey(self.contract.id, b'\x01') - storage_item = storage.StorageItem(b'\x11', is_constant=False) + storage_item = storage.StorageItem(b'\x11') engine.snapshot.storages.put(storage_key, storage_item) ctx = engine.invoke_syscall_by_name("System.Storage.GetContext") diff --git a/tests/contracts/interop/utils.py b/tests/contracts/interop/utils.py index 23f16d7d..a0003aab 100644 --- a/tests/contracts/interop/utils.py +++ b/tests/contracts/interop/utils.py @@ -61,15 +61,16 @@ def test_tx(with_block_height=1, signers: List[types.UInt160]=None) -> payloads. def test_block(with_index=1) -> payloads.Block: tx = test_tx(with_index) - block1 = payloads.Block(version=0, - prev_hash=types.UInt256.from_string( - "f782c7fbb2eef6afe629b96c0d53fb525eda64ce5345057caf975ac3c2b9ae0a"), - timestamp=123, - index=with_index, - next_consensus=types.UInt160.from_string("d7678dd97c000be3f33e9362e673101bac4ca654"), - witness=payloads.Witness(invocation_script=b'', verification_script=b'\x55'), - consensus_data=payloads.ConsensusData(primary_index=1, nonce=123), - transactions=[tx]) + header1 = payloads.Header( + version=0, + prev_hash=types.UInt256.from_string("f782c7fbb2eef6afe629b96c0d53fb525eda64ce5345057caf975ac3c2b9ae0a"), + timestamp=123, + index=with_index, + primary_index=0, + next_consensus=types.UInt160.from_string("d7678dd97c000be3f33e9362e673101bac4ca654"), + witness=payloads.Witness(invocation_script=b'', verification_script=b'\x55') + ) + block1 = payloads.Block(header1, transactions=[tx]) block1.rebuild_merkle_root() return block1 diff --git a/tests/contracts/native/test_nameservice.py b/tests/contracts/native/test_nameservice.py index 570039ce..57e792b3 100644 --- a/tests/contracts/native/test_nameservice.py +++ b/tests/contracts/native/test_nameservice.py @@ -109,12 +109,12 @@ def test_register(self): # not signed with self.assertRaises(ValueError) as context: - nameservice.register(engine, "coz.org", types.UInt160.zero()) + nameservice.do_register(engine, "coz.org", types.UInt160.zero()) self.assertEqual("CheckWitness failed", str(context.exception)) - self.assertTrue(nameservice.register(engine, "coz.org", tx.sender)) + self.assertTrue(nameservice.do_register(engine, "coz.org", tx.sender)) # already registered with self.assertRaises(ValueError) as context: - nameservice.register(engine, "coz.org", tx.sender) + nameservice.do_register(engine, "coz.org", tx.sender) self.assertEqual("Registration failure - 'coz.org' is not available", str(context.exception)) diff --git a/tests/contracts/native/test_native.py b/tests/contracts/native/test_native.py index 0ff4f16d..94c4eeac 100644 --- a/tests/contracts/native/test_native.py +++ b/tests/contracts/native/test_native.py @@ -25,18 +25,6 @@ def test_requesting_contract_by_name(self): self.assertIsNone(contracts.NativeContract.get_contract_by_name("bogus_contract")) self.assertIsInstance(contracts.NativeContract.get_contract_by_name("PolicyContract"), contracts.PolicyContract) - def test_parameter_types_matched_parameter_names(self): - def dummy_func_no_params(): - pass - - class NativeTestContract(contracts.NativeContract): - def init(self): - self._register_contract_method(dummy_func_no_params, "dummy_func", 0, parameter_names=["error"]) - - with self.assertRaises(ValueError) as context: - NativeTestContract() - self.assertEqual("Parameter types count must match parameter names count! 0!=1", str(context.exception)) - def test_various(self): native = contracts.NativeContract() known_contracts = native.registered_contracts diff --git a/tests/contracts/native/test_nep17.py b/tests/contracts/native/test_nep17.py index 42fa8951..1f0a2ad7 100644 --- a/tests/contracts/native/test_nep17.py +++ b/tests/contracts/native/test_nep17.py @@ -109,7 +109,7 @@ def test_on_persist(self): engine.snapshot.persisting_block.transactions[0].signers = [mock_signer] # our consensus_data is not setup in a realistic way, so we have to correct for that here # or we fail to get the account of primary consensus node - engine.snapshot.persisting_block.consensus_data.primary_index = settings.network.validators_count - 1 + engine.snapshot.persisting_block.header.primary_index = settings.network.validators_count - 1 gas.on_persist(engine) diff --git a/tests/contracts/native/test_policy.py b/tests/contracts/native/test_policy.py index 755d3ccf..c8981ef2 100644 --- a/tests/contracts/native/test_policy.py +++ b/tests/contracts/native/test_policy.py @@ -50,33 +50,9 @@ def tearDownClass(cls) -> None: def test_basics(self): policy = contracts.PolicyContract() - self.assertEqual(-5, policy.id) + self.assertEqual(-7, policy.id) self.assertEqual("PolicyContract", contracts.PolicyContract().service_name()) - def test_policy_defaul_get_max_tx_per_block(self): - engine = test_native_contract(contracts.PolicyContract().hash, "getMaxTransactionsPerBlock") - engine.execute() - self.assertEqual(vm.VMState.HALT, engine.state) - self.assertEqual(1, len(engine.result_stack)) - item = engine.result_stack.pop() - self.assertEqual(vm.IntegerStackItem(512), item) - - def test_policy_default_get_max_block_size(self): - engine = test_native_contract(contracts.PolicyContract().hash, "getMaxBlockSize") - engine.execute() - self.assertEqual(vm.VMState.HALT, engine.state) - self.assertEqual(1, len(engine.result_stack)) - item = engine.result_stack.pop() - self.assertEqual(vm.IntegerStackItem(262144), item) - - def test_policy_default_get_max_block_system_fee(self): - engine = test_native_contract(contracts.PolicyContract().hash, "getMaxBlockSystemFee") - engine.execute() - self.assertEqual(vm.VMState.HALT, engine.state) - self.assertEqual(1, len(engine.result_stack)) - item = engine.result_stack.pop() - self.assertEqual(vm.IntegerStackItem(900000000000), item) - def test_policy_default_get_fee_per_byte(self): engine = test_native_contract(contracts.PolicyContract().hash, "getFeePerByte") engine.execute() @@ -166,18 +142,6 @@ def test_policy_setters_fail_without_signatures(self): engine.snapshot.persisting_block = block engine.script_container = TestIVerifiable() - with self.assertRaises(ValueError) as context: - policy._set_max_block_size(engine, 0) - self.assertEqual("Check committee failed", str(context.exception)) - - with self.assertRaises(ValueError) as context: - policy._set_max_transactions_per_block(engine, 0) - self.assertEqual("Check committee failed", str(context.exception)) - - with self.assertRaises(ValueError) as context: - policy._set_max_block_system_fee(engine, 5000000) - self.assertEqual("Check committee failed", str(context.exception)) - with self.assertRaises(ValueError) as context: policy._set_fee_per_byte(engine, 0) self.assertEqual("Check committee failed", str(context.exception)) diff --git a/tests/contracts/native/test_rolemanagement.py b/tests/contracts/native/test_rolemanagement.py index 7d3cc608..75fd1b95 100644 --- a/tests/contracts/native/test_rolemanagement.py +++ b/tests/contracts/native/test_rolemanagement.py @@ -23,7 +23,7 @@ def test_assign_and_get_role(self): engine = test_engine(has_snapshot=True, has_container=True) # set signers list to our committee to pass check_committee() validation engine.script_container.signers = [payloads.Signer( - types.UInt160.from_string("d8a929a05b368db8c6878e850d03fcec8a5cc3b2"), + types.UInt160.from_string("54166e586e86b9d653bf96f61e6568df7a8ecb50"), payloads.WitnessScope.GLOBAL )] public_key1 = cryptography.KeyPair(b'\x01' * 32).public_key diff --git a/tests/contracts/native/test_stdlib.py b/tests/contracts/native/test_stdlib.py new file mode 100644 index 00000000..3d767be0 --- /dev/null +++ b/tests/contracts/native/test_stdlib.py @@ -0,0 +1,237 @@ +import unittest +import binascii +from neo3 import vm, contracts +from neo3.core import syscall_name_to_int +from tests.contracts.interop.utils import test_engine + + +class StdLibTestCase(unittest.TestCase): + def test_binary_serialization(self): + engine = test_engine(has_snapshot=True) + original_item = 100 + sb = vm.ScriptBuilder() + sb.emit_dynamic_call_with_args(contracts.StdLibContract().hash, "serialize", [original_item]) + engine.load_script(vm.Script(sb.to_array())) + engine.execute() + self.assertEqual(vm.VMState.HALT, engine.state) + item = engine.result_stack.pop() + self.assertIsInstance(item, vm.ByteStringStackItem) + self.assertEqual(b'\x21\x01\x64', item.to_array()) + + def test_binary_deserialization(self): + engine = test_engine(has_snapshot=True) + original_item = 100 + sb = vm.ScriptBuilder() + sb.emit_dynamic_call_with_args(contracts.StdLibContract().hash, "serialize", [original_item]) + # now take the results of the "serialize" call and call "deserialize" on the StdLib contract + sb.emit_push(1) + sb.emit(vm.OpCode.PACK) # pack the results of "serialize" as the arguments for the next call + sb.emit_push(0xF) # CallFlags.ALL + sb.emit_push("deserialize") + sb.emit_push(contracts.StdLibContract().hash.to_array()) + sb.emit_syscall(syscall_name_to_int("System.Contract.Call")) + engine.load_script(vm.Script(sb.to_array())) + engine.execute() + self.assertEqual(vm.VMState.HALT, engine.state) + item = engine.result_stack.pop() + self.assertEqual(original_item, int(item)) + + def test_json_serialization(self): + sb = vm.ScriptBuilder() + sb.emit_push(5) + sb.emit_push(1) # arg len + sb.emit(vm.OpCode.PACK) + sb.emit_push(0xF) # CallFlags.ALL + sb.emit_push("jsonSerialize") + sb.emit_push(contracts.StdLibContract().hash.to_array()) + sb.emit_syscall(syscall_name_to_int("System.Contract.Call")) + + sb.emit(vm.OpCode.PUSH0) + sb.emit(vm.OpCode.NOT) + sb.emit_push(1) # arg len + sb.emit(vm.OpCode.PACK) + sb.emit_push(0xF) # CallFlags.ALL + sb.emit_push("jsonSerialize") + sb.emit_push(contracts.StdLibContract().hash.to_array()) + sb.emit_syscall(syscall_name_to_int("System.Contract.Call")) + + sb.emit_push("test") + sb.emit_push(1) # arg len + sb.emit(vm.OpCode.PACK) + sb.emit_push(0xF) # CallFlags.ALL + sb.emit_push("jsonSerialize") + sb.emit_push(contracts.StdLibContract().hash.to_array()) + sb.emit_syscall(syscall_name_to_int("System.Contract.Call")) + + sb.emit(vm.OpCode.PUSHNULL) + sb.emit_push(1) # arg len + sb.emit(vm.OpCode.PACK) + sb.emit_push(0xF) # CallFlags.ALL + sb.emit_push("jsonSerialize") + sb.emit_push(contracts.StdLibContract().hash.to_array()) + sb.emit_syscall(syscall_name_to_int("System.Contract.Call")) + + sb.emit(vm.OpCode.NEWMAP) + sb.emit(vm.OpCode.DUP) + sb.emit_push("key") + sb.emit_push("value") + sb.emit(vm.OpCode.SETITEM) + sb.emit_push(1) # arg len + sb.emit(vm.OpCode.PACK) + sb.emit_push(0xF) # CallFlags.ALL + sb.emit_push("jsonSerialize") + sb.emit_push(contracts.StdLibContract().hash.to_array()) + sb.emit_syscall(syscall_name_to_int("System.Contract.Call")) + + data = sb.to_array() + + engine = test_engine(has_snapshot=True) + engine.load_script(vm.Script(data)) + engine.execute() + self.assertEqual(vm.VMState.HALT, engine.state) + + def pop_to_human_readable(): + return binascii.unhexlify(str(engine.result_stack.pop()).encode()).decode() + + self.assertEqual('{"key":"value"}', pop_to_human_readable()) + self.assertEqual('null', pop_to_human_readable()) + self.assertEqual('"test"', pop_to_human_readable()) + self.assertEqual('true', pop_to_human_readable()) + self.assertEqual('5', pop_to_human_readable()) + + def test_json_deserialization(self): + script = vm.ScriptBuilder() + script.emit_dynamic_call_with_args(contracts.StdLibContract().hash, "jsonDeserialize", [123]) + script.emit_dynamic_call_with_args(contracts.StdLibContract().hash, "jsonDeserialize", ["null"]) + + engine = test_engine(has_snapshot=True) + data = script.to_array() + engine.load_script(vm.Script(data)) + engine.execute() + self.assertEqual(vm.VMState.HALT, engine.state) + self.assertEqual(2, len(engine.result_stack._items)) + self.assertIsInstance(engine.result_stack.pop(), vm.NullStackItem) + self.assertEqual(vm.BigInteger(123), engine.result_stack.pop().to_biginteger()) + + def test_atoi(self): + engine = test_engine(has_snapshot=True) + original_item = b'100' + base = 10 + sb = vm.ScriptBuilder() + sb.emit_dynamic_call_with_args(contracts.StdLibContract().hash, "atoi", [original_item, base]) + engine.load_script(vm.Script(sb.to_array())) + engine.execute() + self.assertEqual(vm.VMState.HALT, engine.state) + item = engine.result_stack.pop() + self.assertEqual(vm.IntegerStackItem(100), item) + + engine = test_engine(has_snapshot=True) + original_item = b'64' + base = 16 + sb = vm.ScriptBuilder() + sb.emit_dynamic_call_with_args(contracts.StdLibContract().hash, "atoi", [original_item, base]) + engine.load_script(vm.Script(sb.to_array())) + engine.execute() + self.assertEqual(vm.VMState.HALT, engine.state) + item = engine.result_stack.pop() + self.assertEqual(vm.IntegerStackItem(100), item) + + engine = test_engine(has_snapshot=True) + invalid_base = 2 + sb = vm.ScriptBuilder() + sb.emit_dynamic_call_with_args(contracts.StdLibContract().hash, "atoi", [original_item, invalid_base]) + engine.load_script(vm.Script(sb.to_array())) + engine.execute() + + self.assertEqual(vm.VMState.FAULT, engine.state) + self.assertIn("Invalid base specified", engine.exception_message) + + def test_itoa(self): + engine = test_engine(has_snapshot=True) + original_item = 100 + base = 10 + sb = vm.ScriptBuilder() + sb.emit_dynamic_call_with_args(contracts.StdLibContract().hash, "itoa", [original_item, base]) + engine.load_script(vm.Script(sb.to_array())) + engine.execute() + self.assertEqual(vm.VMState.HALT, engine.state) + item = engine.result_stack.pop() + self.assertEqual('100', item.to_array().decode('utf-8')) + + engine = test_engine(has_snapshot=True) + base = 16 + sb = vm.ScriptBuilder() + sb.emit_dynamic_call_with_args(contracts.StdLibContract().hash, "itoa", [original_item, base]) + engine.load_script(vm.Script(sb.to_array())) + engine.execute() + self.assertEqual(vm.VMState.HALT, engine.state) + item = engine.result_stack.pop() + self.assertEqual('64', item.to_array().decode('utf-8')) + + engine = test_engine(has_snapshot=True) + invalid_base = 2 + sb = vm.ScriptBuilder() + sb.emit_dynamic_call_with_args(contracts.StdLibContract().hash, "itoa", [original_item, invalid_base]) + engine.load_script(vm.Script(sb.to_array())) + engine.execute() + + self.assertEqual(vm.VMState.FAULT, engine.state) + self.assertIn("Invalid base specified", engine.exception_message) + + def test_base64_encode(self): + engine = test_engine(has_snapshot=True) + original_item = 100 + sb = vm.ScriptBuilder() + sb.emit_dynamic_call_with_args(contracts.StdLibContract().hash, "base64Encode", [original_item]) + engine.load_script(vm.Script(sb.to_array())) + engine.execute() + self.assertEqual(vm.VMState.HALT, engine.state) + item = engine.result_stack.peek() + self.assertEqual('ZA==', item.to_array().decode()) + + def test_base64_decode(self): + engine = test_engine(has_snapshot=True) + original_item = 100 + sb = vm.ScriptBuilder() + sb.emit_dynamic_call_with_args(contracts.StdLibContract().hash, "base64Encode", [original_item]) + # now take the results of the "base64Encode" call and call "base64Decode" on the StdLib contract + sb.emit_push(1) # arg len + sb.emit(vm.OpCode.PACK) # pack the results of "base64Encode" as the arguments for the next call + sb.emit_push(0xF) # CallFlags.ALL + sb.emit_push("base64Decode") + sb.emit_push(contracts.StdLibContract().hash.to_array()) + sb.emit_syscall(syscall_name_to_int("System.Contract.Call")) + engine.load_script(vm.Script(sb.to_array())) + engine.execute() + self.assertEqual(vm.VMState.HALT, engine.state) + item = engine.result_stack.pop() + self.assertEqual(original_item, int(item.to_biginteger())) + + def test_base58_encode(self): + engine = test_engine(has_snapshot=True) + original_item = 100 + sb = vm.ScriptBuilder() + sb.emit_dynamic_call_with_args(contracts.StdLibContract().hash, "base58Encode", [original_item]) + engine.load_script(vm.Script(sb.to_array())) + engine.execute() + self.assertEqual(vm.VMState.HALT, engine.state) + item = engine.result_stack.peek() + self.assertEqual('2j', item.to_array().decode()) + + def test_base58_decode(self): + engine = test_engine(has_snapshot=True) + original_item = 100 + sb = vm.ScriptBuilder() + sb.emit_dynamic_call_with_args(contracts.StdLibContract().hash, "base58Encode", [original_item]) + # now take the results of the "base58Encode" call and call "base64Decode" on the StdLib contract + sb.emit_push(1) + sb.emit(vm.OpCode.PACK) # pack the results of "base58Encode" as the arguments for the next call + sb.emit_push(0xF) # CallFlags.ALL + sb.emit_push("base58Decode") + sb.emit_push(contracts.StdLibContract().hash.to_array()) + sb.emit_syscall(syscall_name_to_int("System.Contract.Call")) + engine.load_script(vm.Script(sb.to_array())) + engine.execute() + self.assertEqual(vm.VMState.HALT, engine.state) + item = engine.result_stack.pop() + self.assertEqual(original_item, int(item.to_biginteger())) diff --git a/tests/contracts/test_binaryserializer.py b/tests/contracts/test_binaryserializer.py index e30253c1..874fb62e 100644 --- a/tests/contracts/test_binaryserializer.py +++ b/tests/contracts/test_binaryserializer.py @@ -41,7 +41,7 @@ def test_with_map(self): # now test deserialization m[i] = b # restore m[i] to original content - new_m = contracts.BinarySerializer.deserialize(out, 2048, len(out), self.reference_counter) + new_m = contracts.BinarySerializer.deserialize(out, 2048, self.reference_counter) self.assertEqual(len(m), len(new_m)) self.assertEqual(m.keys(), new_m.keys()) self.assertEqual(m.values(), new_m.values()) @@ -74,7 +74,7 @@ def test_with_array(self): # now test deserialization # first remove the reference to self a.remove(len(a) - 1) - new_a = contracts.BinarySerializer.deserialize(out, 2048, len(out), self.reference_counter) + new_a = contracts.BinarySerializer.deserialize(out, 2048, self.reference_counter) self.assertIsInstance(a, vm.ArrayStackItem) self.assertEqual(a._items, new_a._items) @@ -83,7 +83,7 @@ def test_with_null(self): x = contracts.BinarySerializer.serialize(n, 999) self.assertEqual(b'\x00', x) - new_n = contracts.BinarySerializer.deserialize(x, 1, 1, self.reference_counter) + new_n = contracts.BinarySerializer.deserialize(x, 1, self.reference_counter) self.assertIsInstance(new_n, vm.NullStackItem) def test_serialize_with_invalid_stackitem(self): @@ -102,14 +102,14 @@ def test_serialize_with_exceeding_max_length(self): def test_deserialize_invalid_data(self): with self.assertRaises(ValueError) as context: - contracts.BinarySerializer.deserialize(b'', 1, 1, self.reference_counter) + contracts.BinarySerializer.deserialize(b'', 1, self.reference_counter) self.assertEqual("Nothing to deserialize", str(context.exception)) def test_deserialize_bytestring(self): data = b'\x01\x02' b = vm.ByteStringStackItem(data) b_serialized = contracts.BinarySerializer.serialize(b, 999) - new_b = contracts.BinarySerializer.deserialize(b_serialized, 999, len(data), self.reference_counter) + new_b = contracts.BinarySerializer.deserialize(b_serialized, 999, self.reference_counter) self.assertIsInstance(new_b, vm.ByteStringStackItem) self.assertEqual(new_b, b) @@ -117,7 +117,7 @@ def test_deserialize_buffer(self): data = b'\x01\x02' b = vm.BufferStackItem(data) b_serialized = contracts.BinarySerializer.serialize(b, 999) - new_b = contracts.BinarySerializer.deserialize(b_serialized, 999, len(data), self.reference_counter) + new_b = contracts.BinarySerializer.deserialize(b_serialized, 999, self.reference_counter) self.assertIsInstance(new_b, vm.BufferStackItem) self.assertEqual(new_b.to_array(), b.to_array()) @@ -127,7 +127,7 @@ def test_deserialize_struct(self): bool2 = vm.BooleanStackItem(False) s.append([bool1, bool2]) s_serialized = contracts.BinarySerializer.serialize(s, 999) - new_s = contracts.BinarySerializer.deserialize(s_serialized, 999, 2, self.reference_counter) + new_s = contracts.BinarySerializer.deserialize(s_serialized, 999, self.reference_counter) self.assertIsInstance(new_s, vm.StructStackItem) for l, r in zip(new_s._items, s._items): self.assertEqual(l, r) @@ -138,5 +138,5 @@ def test_deserialize_with_invalid_format(self): b_serialized = bytearray(contracts.BinarySerializer.serialize(b, 999)) b_serialized[0] = 0xFF # non existing stackitem type with self.assertRaises(ValueError) as context: - contracts.BinarySerializer.deserialize(b_serialized, 999, len(data), self.reference_counter) + contracts.BinarySerializer.deserialize(b_serialized, 999, self.reference_counter) self.assertEqual("Invalid format", str(context.exception)) \ No newline at end of file diff --git a/tests/contracts/test_contract.py b/tests/contracts/test_contract.py index 780ce6c7..0f9f6130 100644 --- a/tests/contracts/test_contract.py +++ b/tests/contracts/test_contract.py @@ -22,8 +22,7 @@ def test_create_signature_contract(self): var c = Contract.CreateSignatureContract(kp1.PublicKey); Console.WriteLine(c.Script.ToHexString()); """ - - expected = binascii.unhexlify(b'0c21026ff03b949241ce1dadd43519e6960e0a85b41a69a05c328103aa2bce1594ca160b4195440d78') + expected = binascii.unhexlify(b'0c21026ff03b949241ce1dadd43519e6960e0a85b41a69a05c328103aa2bce1594ca1641747476aa') keypair = cryptography.KeyPair(private_key=b'\x01' * 32) contract = contracts.Contract.create_signature_contract(keypair.public_key) self.assertEqual(expected, contract.script) @@ -47,15 +46,16 @@ def test_create_multisignature_contract(self): var kp1 = new KeyPair(priv_key1); var kp2 = new KeyPair(priv_key2); var c = Contract.CreateMultiSigContract(1, new ECPoint[] {kp1.PublicKey, kp2.PublicKey}); + Console.WriteLine(c.Script.ToHexString()); + Console.WriteLine(c.ScriptHash.ToArray().ToHexString()); """ - expected_script = binascii.unhexlify(b'110c2102550f471003f3df97c3df506ac797f6721fb1a1fb7b8f6f83d224498a65c88e240c21026ff03b949241ce1dadd43519e6960e0a85b41a69a05c328103aa2bce1594ca16120b41138defaf') - expected_script_hash = types.UInt160(binascii.unhexlify(b'205bc1a9d199eecb30ab0c1ff027456ce7998e1f')) + expected_script = binascii.unhexlify(b'110c2102550f471003f3df97c3df506ac797f6721fb1a1fb7b8f6f83d224498a65c88e240c21026ff03b949241ce1dadd43519e6960e0a85b41a69a05c328103aa2bce1594ca1612417bce6ca5') + expected_script_hash = types.UInt160(binascii.unhexlify(b'2514b406a154dd2b1148a79f33a8d6926e4a30c3')) keypair1 = cryptography.KeyPair(private_key=b'\x01' * 32) keypair2 = cryptography.KeyPair(private_key=b'\x02' * 32) contract = contracts.Contract.create_multisig_contract(1, [keypair1.public_key, keypair2.public_key]) self.assertEqual(expected_script, contract.script) - self.assertEqual(expected_script_hash, contract.script_hash) - + self.assertEqual(str(expected_script_hash), str(contract.script_hash)) def test_create_multisignature_redeemscript_invalid_arguments(self): with self.assertRaises(ValueError) as context: @@ -76,43 +76,36 @@ def test_is_signature_contract(self): - PUSHDATA1 (0xC) - LEN PUBLIC KEY (33) - PUBLIC KEY data - - PUSHNULL (0xB) - SYSCALL (0x41) - - "Neo.Crypto.VerifyWithECDsaSecp256r1" identifier + - "Neo.Crypto.CheckSig" identifier """ incorrect_script_len = b'\x01' * 10 self.assertFalse(contracts.Contract.is_signature_contract(incorrect_script_len)) # first byte should be PUSHDATA1 (0xC) - incorrect_script_start_byte = b'\x01' * 41 + incorrect_script_start_byte = b'\x01' * 40 self.assertFalse(contracts.Contract.is_signature_contract(incorrect_script_start_byte)) # second byte should be 33 - incorrect_second_byte = bytearray(b'\x01' * 41) + incorrect_second_byte = bytearray(b'\x01' * 40) incorrect_second_byte[0] = int(vm.OpCode.PUSHDATA1) self.assertFalse(contracts.Contract.is_signature_contract(incorrect_second_byte)) - # index 35 should be PUSHNULL - incorrect_idx_35 = bytearray([0xc, 33]) + b'\01' * 39 - self.assertFalse(contracts.Contract.is_signature_contract(incorrect_idx_35)) + # index 35 should be SYSCALL + incorrect_idx_35 = bytearray([0xc, 33]) + b'\01' * 38 + incorrect_idx_35[35] = int(vm.OpCode.PUSHNULL) + self.assertFalse(contracts.Contract.is_signature_contract(incorrect_idx_35)) # index 35 should be SYSCALL - # index 36 should be SYSCALL - incorrect_idx_36 = bytearray([0xc, 33]) + b'\01' * 39 - incorrect_idx_36[35] = int(vm.OpCode.PUSHNULL) - self.assertFalse(contracts.Contract.is_signature_contract(incorrect_idx_36)) # index 36 should be SYSCALL - - # the last 4 bytes should be the "Neo.Crypto.VerifyWithECDsaSecp256r1" SYSCALL - incorrect_syscall_number = bytearray([0xc, 33]) + b'\01' * 39 - incorrect_syscall_number[35] = int(vm.OpCode.PUSHNULL) - incorrect_syscall_number[36] = int(vm.OpCode.SYSCALL) + # the last 4 bytes should be the "Neo.Crypto.CheckSig" SYSCALL + incorrect_syscall_number = bytearray([0xc, 33]) + b'\01' * 38 + incorrect_syscall_number[35] = int(vm.OpCode.SYSCALL) self.assertFalse(contracts.Contract.is_signature_contract(incorrect_syscall_number)) # and finally a contract that matches the correct format - correct = bytearray([0xc, 33]) + b'\01' * 39 - correct[35] = int(vm.OpCode.PUSHNULL) - correct[36] = int(vm.OpCode.SYSCALL) - correct[37:41] = contracts.syscall_name_to_int("Neo.Crypto.VerifyWithECDsaSecp256r1").to_bytes(4, 'little') + correct = bytearray([0xc, 33]) + b'\01' * 38 + correct[35] = int(vm.OpCode.SYSCALL) + correct[36:40] = contracts.syscall_name_to_int("Neo.Crypto.CheckSig").to_bytes(4, 'little') self.assertTrue(contracts.Contract.is_signature_contract(correct)) def test_is_multisig_contract_too_short(self): @@ -185,18 +178,14 @@ def test_script_invalid_tail(self): # and assert we don't have enough data left for the remainder of the checks self.assertFalse(contracts.Contract.is_multisig_contract(script)) - # now we extend with 6 bytes to give enough data - # the first should be PUSHNULL (0xB) but isn't - script += b'\x00' * 6 - self.assertFalse(contracts.Contract.is_multisig_contract(script)) - - # we fix the PUSHNULL, and the next should be SYSCALL, but isn't - script[-6] = int(vm.OpCode.PUSHNULL) + # now we extend with 5 bytes to give enough data + script += b'\x00' * 5 + # the first byte should be a SYSCALL, but it isn't self.assertFalse(contracts.Contract.is_multisig_contract(script)) - # finally test the last 4 bytes should be "Neo.Crypto.VerifyWithECDsaSecp256r1" syscall number - # all we have to do is fix the syscall opcode + # we fix the SYSCALL byte script[-5] = int(vm.OpCode.SYSCALL) + # finally test the last 4 bytes should be "Neo.Crypto.CheckSig" syscall number self.assertFalse(contracts.Contract.is_multisig_contract(script)) def test_is_multsig_contract_ok(self): @@ -222,9 +211,8 @@ def test_is_multisig_contract_256_pubkeys(self): # now we correct the public key count in the script and make it valid by adding the expected tail script[-2] = 2 - script += bytearray([int(vm.OpCode.PUSHNULL)]) script += bytearray([int(vm.OpCode.SYSCALL)]) - script += contracts.syscall_name_to_int("Neo.Crypto.CheckMultisigWithECDsaSecp256r1").to_bytes(4, 'little') + script += contracts.syscall_name_to_int("Neo.Crypto.CheckMultisig").to_bytes(4, 'little') self.assertTrue(contracts.Contract.is_multisig_contract(script)) def test_is_multsig_contract_invalid_pubkey_count(self): @@ -235,11 +223,11 @@ def test_is_multsig_contract_invalid_pubkey_count(self): contract = contracts.Contract.create_multisig_contract(1, [keypair.public_key]) # and modify the claimed public key length from 1 to 255 data = bytearray(contract.script) - data[-7] = int(vm.OpCode.PUSH16) - 1 + data[-6] = int(vm.OpCode.PUSH16) - 1 self.assertFalse(contracts.Contract.is_multisig_contract(data)) # and finally we try it again but this time the public key length is in an invalid range - data[-7] = int(vm.OpCode.PUSH16) + 1 + data[-6] = int(vm.OpCode.PUSH16) + 1 self.assertFalse(contracts.Contract.is_multisig_contract(data)) def test_is_multisig_contract_invalid_sig_counts(self): diff --git a/tests/contracts/test_descriptors.py b/tests/contracts/test_descriptors.py index 8bb55689..498e26d2 100644 --- a/tests/contracts/test_descriptors.py +++ b/tests/contracts/test_descriptors.py @@ -22,7 +22,7 @@ def test_group(self): self.assertTrue(cpd.is_hash) self.assertFalse(cpd.is_group) self.assertFalse(cpd.is_wildcard) - self.assertDictEqual({'contract': '0000000000000000000000000000000000000000'}, cpd.to_json()) + self.assertDictEqual({'contract': '0x0000000000000000000000000000000000000000'}, cpd.to_json()) cpd_from_json = contracts.ContractPermissionDescriptor.from_json(cpd.to_json()) self.assertEqual(cpd.contract_hash, cpd_from_json.contract_hash) self.assertEqual(cpd.group, cpd_from_json.group) diff --git a/tests/core/test_cryptography.py b/tests/core/test_cryptography.py index bc70b986..9b784840 100644 --- a/tests/core/test_cryptography.py +++ b/tests/core/test_cryptography.py @@ -4,6 +4,7 @@ from neo3.core import types from neo3.core import cryptography as crypto + class MerkleTreeTestCase(unittest.TestCase): def test_compute_root_single_hash(self): data = binascii.unhexlify(b'aa' * 32) @@ -23,9 +24,7 @@ def test_compute_root_multiple_hashes(self): self.assertEqual(expected_hash, root.to_array()) def test_computer_root_no_input(self): - with self.assertRaises(ValueError) as context: - crypto.MerkleTree.compute_root([]) - self.assertIn("Hashes list can't empty",str(context.exception)) + self.assertEqual(types.UInt256.zero(), crypto.MerkleTree.compute_root([])) def test_build_no_leaves(self): with self.assertRaises(ValueError) as context: diff --git a/tests/network/test_node.py b/tests/network/test_node.py index 95abca8e..759c6176 100644 --- a/tests/network/test_node.py +++ b/tests/network/test_node.py @@ -3,7 +3,6 @@ import logging import asyncio import binascii -from functools import partial from neo3.network import node, message, payloads, capabilities, ipfilter, protocol, encode_base62 from neo3 import settings, network_logger @@ -25,8 +24,8 @@ def __init__(self, loop, hostaddr: str, port: int): capabilities.ServerCapability(n_type=capabilities.NodeCapabilityType.TCPSERVER, port=10333)] self.m_send_version = message.Message(msg_type=message.MessageType.VERSION, payload=payloads.VersionPayload(nonce=123, - user_agent="NEO3-MOCK-CLIENT", - capabilities=caps)) + user_agent="NEO3-MOCK-CLIENT", + capabilities=caps)) self.m_verack = message.Message(msg_type=message.MessageType.VERACK) def _recv_data(self): @@ -377,7 +376,7 @@ async def test_processing_messages(self): m_mempool = message.Message(msg_type=message.MessageType.MEMPOOL, payload=payloads.EmptyPayload()) # taken from the Headers testcase in `test_payloads` - raw_headers_payload = binascii.unhexlify(b'020000000001FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00A402FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00A400000000000000007B0000008A2B438EACA8B4B2AB6B4524B5A69A45D920C35101020102020304000000000001FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00A402FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00A400000000000000007B0000008A2B438EACA8B4B2AB6B4524B5A69A45D920C3510102010202030400') + raw_headers_payload = binascii.unhexlify(b'0000000001FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00A402FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00A400000000000000007B00000000F7B4D00143932F3B6243CFC06CB4A68F22C739E201020102020304') m_headers = message.Message(msg_type=message.MessageType.HEADERS, payload=payloads.HeadersPayload.deserialize_from_bytes(raw_headers_payload)) m_ping = message.Message(msg_type=message.MessageType.PING, payload=payloads.PingPayload(0)) diff --git a/tests/network/test_payloads.py b/tests/network/test_payloads.py index b0e8a9fc..ce84a008 100644 --- a/tests/network/test_payloads.py +++ b/tests/network/test_payloads.py @@ -122,33 +122,43 @@ class BlockTestCase(unittest.TestCase): @classmethod def setUpClass(cls) -> None: """ - Transaction tx = new Transaction(); - tx.Nonce = 123; - tx.SystemFee = 456; - tx.NetworkFee = 789; - tx.ValidUntilBlock = 1; - tx.Attributes = new TransactionAttribute[0]; - tx.Signers = new Signer[] { new Signer() { Account = UInt160.Parse("0xe239c7228fa6b46cc0cf43623b2f934301d0b4f7")}}; - tx.Script = new byte[] { 0x1 }; - tx.Witnesses = new Witness[0]; - - - - Block b = new Block(); - b.Version = 0; - b.PrevHash = UInt256.Parse("0xf782c7fbb2eef6afe629b96c0d53fb525eda64ce5345057caf975ac3c2b9ae0a"); - b.Timestamp = 123; - b.Index = 1; - b.NextConsensus = UInt160.Parse("0xd7678dd97c000be3f33e9362e673101bac4ca654"); - b.Witness = new Witness { InvocationScript = new byte[0], VerificationScript = new byte[] { 0x55 } }; - b.ConsensusData = new ConsensusData { Nonce = 123, PrimaryIndex = 0 }; - b.Transactions = new Transaction[] { tx }; - b.RebuildMerkleRoot(); + Transaction tx = new Transaction(); + tx.Nonce = 123; + tx.SystemFee = 456; + tx.NetworkFee = 789; + tx.ValidUntilBlock = 1; + tx.Attributes = new TransactionAttribute[0]; + tx.Signers = new Signer[] { new Signer() { Account = UInt160.Parse("0xe239c7228fa6b46cc0cf43623b2f934301d0b4f7")}}; + tx.Script = new byte[] { 0x1 }; + tx.Witnesses = new Witness[] {new Witness { InvocationScript = new byte[0], VerificationScript = new byte[] { 0x55 } }}; + + Block b = new Block + { + Header = new Header + { + Version = 0, + PrevHash = UInt256.Parse("0xf782c7fbb2eef6afe629b96c0d53fb525eda64ce5345057caf975ac3c2b9ae0a"), + Timestamp = 123, + Index = 1, + PrimaryIndex = 0, + NextConsensus = UInt160.Parse("0xd7678dd97c000be3f33e9362e673101bac4ca654"), + Witness = new Witness { InvocationScript = new byte[0], VerificationScript = new byte[] { 0x55 } }, + MerkleRoot = UInt256.Zero + }, + Transactions = new Transaction[] { tx } + }; + b.Header.MerkleRoot = MerkleTree.ComputeRoot(b.Transactions.Select(p => p.Hash).ToArray()); + Console.WriteLine($"{b.Size}"); + Console.WriteLine($"{BitConverter.ToString(b.ToArray()).Replace("-", "")}"); + + var trimmedBlock = new TrimmedBlock + { + Header = b.Header, + Hashes = b.Transactions.Select(p => p.Hash).ToArray() + }; + Console.WriteLine($"{trimmedBlock.Size}"); + Console.WriteLine($"{BitConverter.ToString(trimmedBlock.ToArray()).Replace("-", "")}"); - Console.WriteLine($"{b.Size}"); - Console.WriteLine($"{BitConverter.ToString(b.ToArray()).Replace("-", "")}"); - Console.WriteLine($"{b.Trim().Size}"); - Console.WriteLine($"{BitConverter.ToString(b.Trim().ToArray()).Replace("-", "")}"); """ cls.tx = payloads.Transaction(version=0, nonce=123, @@ -158,22 +168,22 @@ def setUpClass(cls) -> None: attributes=[], signers=[payloads.Signer(types.UInt160.from_string("e239c7228fa6b46cc0cf43623b2f934301d0b4f7"))], script=b'\x01', - witnesses=[]) - - cls.block = payloads.Block(version=0, - prev_hash=types.UInt256.from_string("f782c7fbb2eef6afe629b96c0d53fb525eda64ce5345057caf975ac3c2b9ae0a"), - timestamp=123, - index=1, - next_consensus=types.UInt160.from_string("d7678dd97c000be3f33e9362e673101bac4ca654"), - witness=payloads.Witness(invocation_script=b'', verification_script=b'\x55'), - consensus_data=payloads.ConsensusData(primary_index=0, nonce=123), + witnesses=[payloads.Witness(invocation_script=b'', verification_script=b'\x55')]) + + cls.header = payloads.Header(version=0, + prev_hash=types.UInt256.from_string("f782c7fbb2eef6afe629b96c0d53fb525eda64ce5345057caf975ac3c2b9ae0a"), + timestamp=123, + index=1, + primary_index=0, + next_consensus=types.UInt160.from_string("d7678dd97c000be3f33e9362e673101bac4ca654"), + witness=payloads.Witness(invocation_script=b'', verification_script=b'\x55')) + cls.block = payloads.Block(cls.header, transactions=[cls.tx]) cls.block.rebuild_merkle_root() - cls.trimmed_block = cls.block.trim() def test_len(self): # captured from C#, see setUpClass() for the capture code - expected_len = 165 + expected_len = 160 self.assertEqual(expected_len, len(self.block)) def test_equals(self): @@ -182,13 +192,13 @@ def test_equals(self): # test different hashes modified_block = deepcopy(self.block) - modified_block.timestamp = 1 + modified_block.header.timestamp = 1 self.assertFalse(self.block == modified_block) self.assertTrue(self.block == self.block) def test_serialization(self): # captured from C#, see setUpClass() for the capture code - expected_data = binascii.unhexlify("000000000AAEB9C2C35A97AF7C054553CE64DA5E52FB530D6CB929E6AFF6EEB2FBC782F749A149B5AEED7B7DBD753FE54F9FCC4A0B368221EF06F76DC4ABB0317972BEE07B000000000000000100000054A64CAC1B1073E662933EF3E30B007CD98D67D70100015502007B00000000000000007B000000C80100000000000015030000000000000100000001F7B4D00143932F3B6243CFC06CB4A68F22C739E20000010100") + expected_data = binascii.unhexlify("000000000AAEB9C2C35A97AF7C054553CE64DA5E52FB530D6CB929E6AFF6EEB2FBC782F75618ADD6F91FAD691D6A4D430DB27CE5CA607296863E73A23FC3622A415CDD407B00000000000000010000000054A64CAC1B1073E662933EF3E30B007CD98D67D70100015501007B000000C80100000000000015030000000000000100000001F7B4D00143932F3B6243CFC06CB4A68F22C739E20000010101000155") self.assertEqual(expected_data, self.block.to_array()) def test_deserialization(self): @@ -198,24 +208,12 @@ def test_deserialization(self): self.assertEqual(self.block.prev_hash, deserialized_block.prev_hash) self.assertEqual(self.block.timestamp, deserialized_block.timestamp) self.assertEqual(self.block.index, deserialized_block.index) + self.assertEqual(self.block.primary_index, deserialized_block.primary_index) self.assertEqual(self.block.next_consensus, deserialized_block.next_consensus) self.assertEqual(self.block.witness.invocation_script, deserialized_block.witness.invocation_script) self.assertEqual(self.block.witness.verification_script, deserialized_block.witness.verification_script) - self.assertEqual(self.block.consensus_data.primary_index, deserialized_block.consensus_data.primary_index) - self.assertEqual(self.block.consensus_data.nonce, deserialized_block.consensus_data.nonce) self.assertEqual(1, len(deserialized_block.transactions)) - def test_deserialization_zero_contents(self): - # a block can't have 0 contents - block_contents_length_index = 104 - raw_data = bytearray(self.block.to_array()) - - # we force the block contents length to 0 - raw_data[block_contents_length_index] = 0 - with self.assertRaises(ValueError) as context: - payloads.Block.deserialize_from_bytes(raw_data) - self.assertIn("Deserialization error - no contents", str(context.exception)) - def test_deserialization_no_duplicate_transactions(self): # A block should not have duplicate transactions block_copy = deepcopy(self.block) @@ -226,7 +224,7 @@ def test_deserialization_no_duplicate_transactions(self): def test_deserialization_wrong_merkle_root(self): block_copy = deepcopy(self.block) - block_copy.merkle_root = types.UInt256.zero() + block_copy.header.merkle_root = types.UInt256.zero() with self.assertRaises(ValueError) as context: payloads.Block.deserialize_from_bytes(block_copy.to_array()) self.assertIn("Deserialization error - merkle root mismatch", str(context.exception)) @@ -238,130 +236,16 @@ def test_trim(self): trimmed_block = self.block.trim() self.assertIsInstance(trimmed_block, payloads.TrimmedBlock) # captured from C#, see setUpClass() for the capture code - expected_len = 178 + expected_len = 138 self.assertEqual(expected_len, len(trimmed_block)) - expected_data = binascii.unhexlify('000000000AAEB9C2C35A97AF7C054553CE64DA5E52FB530D6CB929E6AFF6EEB2FBC782F749A149B5AEED7B7DBD753FE54F9FCC4A0B368221EF06F76DC4ABB0317972BEE07B000000000000000100000054A64CAC1B1073E662933EF3E30B007CD98D67D70100015502CEBBAA303E74D5CAC6DF34823B7484B7760460295EC1E97845FDF138F9A87A62DBB73FBF82438E317ABA947D8853907AB259BDCEB8A5771AF394371492BD7D88007B00000000000000') + expected_data = binascii.unhexlify('000000000AAEB9C2C35A97AF7C054553CE64DA5E52FB530D6CB929E6AFF6EEB2FBC782F75618ADD6F91FAD691D6A4D430DB27CE5CA607296863E73A23FC3622A415CDD407B00000000000000010000000054A64CAC1B1073E662933EF3E30B007CD98D67D701000155015618ADD6F91FAD691D6A4D430DB27CE5CA607296863E73A23FC3622A415CDD40') self.assertEqual(expected_data, trimmed_block.to_array()) deserialized_trimmed_block = payloads.TrimmedBlock.deserialize_from_bytes(trimmed_block.to_array()) - self.assertEqual(trimmed_block.version, deserialized_trimmed_block.version) - self.assertEqual(trimmed_block.prev_hash, deserialized_trimmed_block.prev_hash) - self.assertEqual(trimmed_block.timestamp, deserialized_trimmed_block.timestamp) - self.assertEqual(trimmed_block.index, deserialized_trimmed_block.index) - self.assertEqual(trimmed_block.next_consensus, deserialized_trimmed_block.next_consensus) - self.assertEqual(trimmed_block.witness.invocation_script, deserialized_trimmed_block.witness.invocation_script) - self.assertEqual(trimmed_block.witness.verification_script, deserialized_trimmed_block.witness.verification_script) - self.assertEqual(trimmed_block.consensus_data.primary_index, deserialized_trimmed_block.consensus_data.primary_index) - self.assertEqual(trimmed_block.consensus_data.nonce, deserialized_trimmed_block.consensus_data.nonce) + self.assertEqual(trimmed_block.header, deserialized_trimmed_block.header) self.assertEqual(trimmed_block.hashes, deserialized_trimmed_block.hashes) - self.assertEqual(2, len(deserialized_trimmed_block.hashes)) - - -class ConsensusDataTestCase(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - """ - ConsensusData cd = new ConsensusData(); - cd.PrimaryIndex = 0; - cd.Nonce = 456; - Console.WriteLine(cd.Size); - Console.WriteLine(cd.Hash); - Console.WriteLine($"b\'{BitConverter.ToString(cd.ToArray()).Replace("-", "")}\'"); - """ - cd = payloads.ConsensusData() - cd.primary_index = 0 - cd.nonce = 456 - cls.consensus_data = cd - - def test_len_and_hash(self): - # captured from C#, see setUpClass() for the capture code - expected_len = 9 - expected_hash = types.UInt256.from_string('b616ce734d5d6bfb0c5b3c9fe890b29299f5338c1af9156342e4df9d5828a303') - self.assertEqual(expected_len, len(self.consensus_data)) - self.assertEqual(expected_hash, self.consensus_data.hash()) - - def test_serialization(self): - # captured from C#, see setUpClass() for the capture code - expected_data = binascii.unhexlify(b'00C801000000000000') - self.assertEqual(expected_data, self.consensus_data.to_array()) - - def test_deserialization(self): - # if the serialization() test for this class passes, we can use that as a reference to test deserialization against - deserialized_consensus = payloads.ConsensusData.deserialize_from_bytes(self.consensus_data.to_array()) - self.assertEqual(self.consensus_data.primary_index, deserialized_consensus.primary_index) - self.assertEqual(self.consensus_data.nonce, deserialized_consensus.nonce) - - -class ConsensusPayloadTestCase(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - """ - ConsensusPayload cp = new ConsensusPayload - { - Version = 1, - PrevHash = UInt256.Parse("f782c7fbb2eef6afe629b96c0d53fb525eda64ce5345057caf975ac3c2b9ae0a"), - BlockIndex = 2, - ValidatorIndex = 0, - Witness = new Witness - { - InvocationScript = new byte[0], - VerificationScript = new byte[0] - }, - Data = new byte[] { - 0x0, /* ConsensusMessageType.CHANGEVIEW */ - 0x1, /* View number */ - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, /* ChangeView.Timestamp */ - 0x0 /* ChangeViewReason.Timeout */ - } - }; - Console.WriteLine(cp.Size); - Console.WriteLine(cp.Hash); - Console.WriteLine($"b\'{BitConverter.ToString(cp.ToArray()).Replace("-", "")}\'"); - """ - cls.payload = payloads.ConsensusPayload( - version=1, - prev_hash=types.UInt256.from_string("f782c7fbb2eef6afe629b96c0d53fb525eda64ce5345057caf975ac3c2b9ae0a"), - block_index=2, - validator_index=0, - data=b'\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00', - witness=payloads.Witness(bytearray(), bytearray()) - ) - - def test_len_and_hash(self): - # captured from C#, see setUpClass() for the capture code - expected_len = 56 - expected_hash = types.UInt256.from_string('44d22b68b530cfc7f1c1586e7e516368227bffd95a912413af7ea424f5605633') - self.assertEqual(expected_len, len(self.payload)) - self.assertEqual(expected_hash, self.payload.hash()) - - def test_serialization(self): - # captured from C#, see setUpClass() for the capture code - expected_data = binascii.unhexlify(b'010000000AAEB9C2C35A97AF7C054553CE64DA5E52FB530D6CB929E6AFF6EEB2FBC782F702000000000B0001000000000000000000010000') - self.assertEqual(expected_data, self.payload.to_array()) - - def test_deserialization(self): - # if the serialization() test for this class passes, we can use that as a reference to test deserialization against - deserialized_consensus_payload = payloads.ConsensusPayload.deserialize_from_bytes(self.payload.to_array()) - self.assertEqual(self.payload.version, deserialized_consensus_payload.version) - self.assertEqual(self.payload.prev_hash, deserialized_consensus_payload.prev_hash) - self.assertEqual(self.payload.block_index, deserialized_consensus_payload.block_index) - self.assertEqual(self.payload.validator_index, deserialized_consensus_payload.validator_index) - self.assertEqual(self.payload.data, deserialized_consensus_payload.data) - self.assertEqual(self.payload.witness.invocation_script, deserialized_consensus_payload.witness.invocation_script) - self.assertEqual(self.payload.witness.verification_script, deserialized_consensus_payload.witness.verification_script) - - def test_deserialization_error(self): - # an exception should be thrown if the validation byte is wrong - payload_data = bytearray(self.payload.to_array()) - # modify validation byte - payload_data[-3] = 0xEE - with self.assertRaises(ValueError) as context: - payloads.ConsensusPayload.deserialize_from_bytes(payload_data) - self.assertIn("Deserialization error - validation byte not 1", str(context.exception)) - - def test_inventory_type(self): - self.assertEqual(payloads.InventoryType.CONSENSUS, self.payload.inventory_type) + self.assertEqual(1, len(deserialized_trimmed_block.hashes)) class SignerTestCase(unittest.TestCase): @@ -572,6 +456,7 @@ def setUpClass(cls) -> None: MerkleRoot = UInt256.Parse("a400ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff02"), Timestamp = 0, Index = 123, + PrimaryIndex = 0, NextConsensus = UInt160.Parse("0xe239c7228fa6b46cc0cf43623b2f934301d0b4f7"), Witness = new Witness { @@ -588,21 +473,22 @@ def setUpClass(cls) -> None: merkleroot = types.UInt256.from_string("a400ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff02") timestamp = 0 index = 123 + primary_index = 0 next_consensus = types.UInt160.from_string("e239c7228fa6b46cc0cf43623b2f934301d0b4f7") witness = payloads.Witness(invocation_script=b'\x01\x02', verification_script=b'\x03\x04') - cls.header = payloads.Header(version, previous_hash, timestamp, index, next_consensus, witness, merkleroot) + cls.header = payloads.Header(version, previous_hash, timestamp, index, primary_index, next_consensus, witness, merkleroot) def test_len_and_hash(self): # captured from C#, see setUpClass() for the capture code expected_len = 108 - expected_hash = types.UInt256.from_string('8672a4dbd51bb911d1988d633f539ccf05e46cf614160be15472b9ae10e43a88') + expected_hash = types.UInt256.from_string('47a455f322441b6c3b4dffd039df34ca724fac222686e515043f01c494687868') self.assertEqual(expected_len, len(self.header)) self.assertEqual(expected_hash, self.header.hash()) def test_serialization(self): # captured from C#, see setUpClass() for the capture code - expected_data = binascii.unhexlify(b'0000000001FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00A402FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00A400000000000000007B000000F7B4D00143932F3B6243CFC06CB4A68F22C739E20102010202030400') + expected_data = binascii.unhexlify(b'0000000001FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00A402FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00A400000000000000007B00000000F7B4D00143932F3B6243CFC06CB4A68F22C739E201020102020304') self.assertEqual(expected_data, self.header.to_array()) def test_deserialization(self): @@ -621,26 +507,14 @@ def test_deserialization(self): def test_deserialization_failure1(self): # there should be a 1 byte witness object count (fixed to value 1) before the actual witness object. # see https://github.com/neo-project/neo/issues/1128 - raw_data = binascii.unhexlify(b'0000000001FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00A402FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00A400000000000000007B0000008A2B438EACA8B4B2AB6B4524B5A69A45D920C351FF02010202030400') - deserialized_header = payloads.Header._serializable_init() - - with self.assertRaises(ValueError) as context: - with serialization.BinaryReader(raw_data) as br: - deserialized_header.deserialize(br) - self.assertIn("Deserialization error", str(context.exception)) - self.assertIn("Witness object count is 255 must be 1", str(context.exception)) - - - def test_deserialization_failure2(self): - # the last byte in the stream should always be 0, this is to differentiate between blocks and headers according to - # https://github.com/neo-project/neo/pull/1129#issuecomment-537102207 - raw_data = binascii.unhexlify(b'0000000001FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00A402FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00A400000000000000007B0000008A2B438EACA8B4B2AB6B4524B5A69A45D920C3510102010202030411') + raw_data = binascii.unhexlify(b'0000000001FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00A402FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00A400000000000000007B00000000F7B4D00143932F3B6243CFC06CB4A68F22C739E200020102020304') deserialized_header = payloads.Header._serializable_init() with self.assertRaises(ValueError) as context: with serialization.BinaryReader(raw_data) as br: deserialized_header.deserialize(br) self.assertIn("Deserialization error", str(context.exception)) + self.assertIn("Witness object count is 0 must be 1", str(context.exception)) def test_equals(self): self.assertFalse(None == self.header) @@ -657,26 +531,28 @@ class HeadersPayloadTestCase(unittest.TestCase): @classmethod def setUpClass(cls) -> None: """ - Neo.IO.Json.JObject json = new Neo.IO.Json.JObject + Header h1 = new Header { - ["version"] = 0, - ["previousblockhash"] = "a400ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff01", - ["merkleroot"] = "a400ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff02", - ["time"] = 0, - ["index"] = 123, - ["nextconsensus"] = "AUNSizuErA3dv1a2ag2ozvikkQS7hhPY1X", - ["witnesses"] = new Neo.IO.Json.JArray - { - new Neo.IO.Json.JObject { - ["invocation"] = "0102", - ["verification"] = "0304" + Version = 0, + PrevHash = UInt256.Parse("a400ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff01"), + MerkleRoot = UInt256.Parse("a400ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff02"), + Timestamp = 0, + Index = 123, + PrimaryIndex = 0, + NextConsensus = UInt160.Parse("8a2b438eaca8b4b2ab6b4524b5a69a45d920c351"), + Witness = new Witness + { + InvocationScript = new byte[] {0x1, 0x2}, + VerificationScript = new byte[] {0x3, 0x4} } - } }; + Header h2; + using (BinaryReader reader = new BinaryReader(new MemoryStream(h1.ToArray()))) + { + h2 = reader.ReadSerializable
(); + } - Header h1 = Header.FromJson(json); - Header h2 = Header.FromJson(json); - HeadersPayload hp = HeadersPayload.Create(new List
{ h1, h2 }); + HeadersPayload hp = HeadersPayload.Create(new Header[] { h1, h2 }); Console.WriteLine(hp.Size); Console.WriteLine($"b\'{BitConverter.ToString(hp.ToArray()).Replace("-", "")}\'"); """ @@ -685,12 +561,12 @@ def setUpClass(cls) -> None: merkleroot = types.UInt256.from_string("a400ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff00ff02") timestamp = 0 index = 123 - addr_data = base58.b58decode_check('AUNSizuErA3dv1a2ag2ozvikkQS7hhPY1X')[1:] - next_consensus = types.UInt160(data=addr_data) + primary_index = 0 + next_consensus = types.UInt160.from_string("8a2b438eaca8b4b2ab6b4524b5a69a45d920c351") witness = payloads.Witness(invocation_script=b'\x01\x02', verification_script=b'\x03\x04') - h1 = payloads.Header(version, previous_hash, timestamp, index, next_consensus, witness, merkleroot) - h2 = payloads.Header(version, previous_hash, timestamp, index, next_consensus, witness, merkleroot) + h1 = payloads.Header(version, previous_hash, timestamp, index, primary_index, next_consensus, witness, merkleroot) + h2 = payloads.Header(version, previous_hash, timestamp, index, primary_index, next_consensus, witness, merkleroot) cls.payload = payloads.HeadersPayload.create([h1, h2]) def test_len(self): @@ -700,7 +576,7 @@ def test_len(self): def test_serialization(self): # captured from C#, see setUpClass() for the capture code - expected_data = binascii.unhexlify(b'020000000001FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00A402FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00A400000000000000007B0000008A2B438EACA8B4B2AB6B4524B5A69A45D920C35101020102020304000000000001FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00A402FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00A400000000000000007B0000008A2B438EACA8B4B2AB6B4524B5A69A45D920C3510102010202030400') + expected_data = binascii.unhexlify(b'020000000001FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00A402FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00A400000000000000007B0000000051C320D9459AA6B524456BABB2B4A8AC8E432B8A010201020203040000000001FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00A402FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00A400000000000000007B0000000051C320D9459AA6B524456BABB2B4A8AC8E432B8A01020102020304') self.assertEqual(expected_data, self.payload.to_array()) def test_deserialization(self): @@ -758,24 +634,29 @@ def setUpClass(cls) -> None: tx.Script = new byte[] { 0x1 }; tx.Witnesses = new Witness[0]; + Block b = new Block + { + Header = new Header + { + Version = 0, + PrevHash = UInt256.Parse("0xf782c7fbb2eef6afe629b96c0d53fb525eda64ce5345057caf975ac3c2b9ae0a"), + Timestamp = 123, + Index = 1, + PrimaryIndex = 0, + NextConsensus = UInt160.Parse("0xd7678dd97c000be3f33e9362e673101bac4ca654"), + Witness = new Witness { InvocationScript = new byte[0], VerificationScript = new byte[] { 0x55 } }, + MerkleRoot = UInt256.Zero + }, + Transactions = new Transaction[] { tx } + }; + b.Header.MerkleRoot = MerkleTree.ComputeRoot(b.Transactions.Select(p => p.Hash).ToArray()); - - Block b = new Block(); - b.Version = 0; - b.PrevHash = UInt256.Parse("0xf782c7fbb2eef6afe629b96c0d53fb525eda64ce5345057caf975ac3c2b9ae0a"); - b.Timestamp = 123; - b.Index = 1; - b.NextConsensus = UInt160.Parse("0xd7678dd97c000be3f33e9362e673101bac4ca654"); - b.Witness = new Witness { InvocationScript = new byte[0], VerificationScript = new byte[] { 0x55 } }; - b.ConsensusData = new ConsensusData { Nonce = 123, PrimaryIndex = 1 }; - b.Transactions = new Transaction[] { tx }; - b.RebuildMerkleRoot(); - - byte[] bytes = { 0x1, 0x2 }; + byte[] bytes = { 0x1 }; BitArray flags = new BitArray(bytes); MerkleBlockPayload mbp = MerkleBlockPayload.Create(b, flags); Console.WriteLine($"b\'{BitConverter.ToString(mbp.ToArray()).Replace("-", "")}\'"); """ + cls.tx = payloads.Transaction(version=0, nonce=123, system_fee=456, @@ -784,38 +665,38 @@ def setUpClass(cls) -> None: attributes=[], signers=[payloads.Signer(types.UInt160.from_string("e239c7228fa6b46cc0cf43623b2f934301d0b4f7"))], script=b'\x01', - witnesses=[]) - - cls.block = payloads.Block(version=0, - prev_hash=types.UInt256.from_string("f782c7fbb2eef6afe629b96c0d53fb525eda64ce5345057caf975ac3c2b9ae0a"), - timestamp=123, - index=1, - next_consensus=types.UInt160.from_string("d7678dd97c000be3f33e9362e673101bac4ca654"), - witness=payloads.Witness(invocation_script=b'', verification_script=b'\x55'), - consensus_data=payloads.ConsensusData(primary_index=1, nonce=123), + witnesses=[payloads.Witness(invocation_script=b'', verification_script=b'\x55')]) + + cls.header = payloads.Header(version=0, + prev_hash=types.UInt256.from_string("f782c7fbb2eef6afe629b96c0d53fb525eda64ce5345057caf975ac3c2b9ae0a"), + timestamp=123, + index=1, + primary_index=0, + next_consensus=types.UInt160.from_string("d7678dd97c000be3f33e9362e673101bac4ca654"), + witness=payloads.Witness(invocation_script=b'', verification_script=b'\x55')) + cls.block = payloads.Block(cls.header, transactions=[cls.tx]) cls.block.rebuild_merkle_root() flags = bitarray() - flags.frombytes(b'\x01\x02') + flags.frombytes(b'\x01') cls.merkle_payload = payloads.MerkleBlockPayload(cls.block, flags) def test_len(self): # captured from C#, see setUpClass() for the capture code - expected_len = 176 + expected_len = 144 self.assertEqual(expected_len, len(self.merkle_payload)) def test_serialization(self): # captured from C#, see setUpClass() for the capture code - expected_data = binascii.unhexlify(b'000000000AAEB9C2C35A97AF7C054553CE64DA5E52FB530D6CB929E6AFF6EEB2FBC782F72F9B61E3B410EF24D86B2BAFD9F2611AD8F43A9F7167FC58C3FCCC80BBFD40A67B000000000000000100000054A64CAC1B1073E662933EF3E30B007CD98D67D7010001550202FCAF61CDF5BEF2AB0FFAC66D846D14EDF06C84A0FD852264918E2F1E2E0A546CDBB73FBF82438E317ABA947D8853907AB259BDCEB8A5771AF394371492BD7D88020102') + expected_data = binascii.unhexlify(b'000000000AAEB9C2C35A97AF7C054553CE64DA5E52FB530D6CB929E6AFF6EEB2FBC782F75618ADD6F91FAD691D6A4D430DB27CE5CA607296863E73A23FC3622A415CDD407B00000000000000010000000054A64CAC1B1073E662933EF3E30B007CD98D67D70100015501015618ADD6F91FAD691D6A4D430DB27CE5CA607296863E73A23FC3622A415CDD400101') self.assertEqual(expected_data, self.merkle_payload.to_array()) def test_deserialization(self): # if the serialization() test for this class passes, we can use that as a reference to test deserialization against deserialized_merkle_payload = payloads.MerkleBlockPayload.deserialize_from_bytes(self.merkle_payload.to_array()) # not testing all properties again. It re-uses the same block as created in the Block test case - self.assertEqual(self.merkle_payload.prev_hash, deserialized_merkle_payload.prev_hash) # only testing new properties - self.assertEqual(self.merkle_payload.content_count, deserialized_merkle_payload.content_count) + self.assertEqual(self.merkle_payload.tx_count, deserialized_merkle_payload.tx_count) self.assertEqual(len(self.merkle_payload.hashes), len(deserialized_merkle_payload.hashes)) self.assertEqual(self.merkle_payload.hashes[0], deserialized_merkle_payload.hashes[0]) self.assertEqual(self.merkle_payload.flags, deserialized_merkle_payload.flags) @@ -962,7 +843,7 @@ def setUpClass(cls) -> None: Signer co = new Signer(); co.Account = UInt160.Parse("0xd7678dd97c000be3f33e9362e673101bac4ca654"); - co.Scopes = WitnessScope.FeeOnly; + co.Scopes = WitnessScope.None; tx.Signers = new Signer[] { co }; tx.Script = new byte[] { 0x1, 0x2 }; @@ -994,7 +875,7 @@ def tearDown(cls) -> None: def test_len_and_hash(self): # captured from C#, see setUpClass() for the capture code expected_len = 55 - expected_hash = types.UInt256.from_string('175cdc35664fc27e09b1970f190b6dce41d82c5409882e74c395f57de5c84ecd') + expected_hash = types.UInt256.from_string('da0343daadf88f95ece657fed6c20e05256d37f0d68757034da1d34f534d2c2c') self.assertEqual(expected_len, len(self.tx)) self.assertEqual(expected_hash, self.tx.hash()) diff --git a/tests/storage/storagetest.py b/tests/storage/storagetest.py index 1853f0ac..317d377d 100644 --- a/tests/storage/storagetest.py +++ b/tests/storage/storagetest.py @@ -41,21 +41,21 @@ def setUp(self) -> None: attributes=[], signers=[signer], script=b'\x01', - witnesses=[]) - - self.block1 = payloads.Block(version=0, - prev_hash=types.UInt256.from_string("f782c7fbb2eef6afe629b96c0d53fb525eda64ce5345057caf975ac3c2b9ae0a"), - timestamp=123, - index=1, - next_consensus=types.UInt160.from_string("d7678dd97c000be3f33e9362e673101bac4ca654"), - witness=payloads.Witness(invocation_script=b'', verification_script=b'\x55'), - consensus_data=payloads.ConsensusData(primary_index=0, nonce=123), - transactions=[tx]) + witnesses=[payloads.Witness(invocation_script=b'', verification_script=b'\x55')]) + + self.header = payloads.Header(version=0, + prev_hash=types.UInt256.from_string("f782c7fbb2eef6afe629b96c0d53fb525eda64ce5345057caf975ac3c2b9ae0a"), + timestamp=123, + index=1, + primary_index=0, + next_consensus=types.UInt160.from_string("d7678dd97c000be3f33e9362e673101bac4ca654"), + witness=payloads.Witness(invocation_script=b'', verification_script=b'\x55'),) + self.block1 = payloads.Block(self.header, transactions=[tx]) self.block1.rebuild_merkle_root() self.block1_hash = self.block1.hash() self.block2 = deepcopy(self.block1) - self.block2.index = 2 + self.block2.header.index = 2 self.block2_hash = self.block2.hash() def test_raw(self): @@ -189,7 +189,7 @@ def test_snapshot_get_various(self): raw_view = self.db.get_rawview() raw_view.blocks.put(self.block1) block = snapshot_view.blocks.get(self.block1_hash, read_only=True) - block.index = 123 + block.header.index = 123 block_again = snapshot_view.blocks.get(self.block1_hash, read_only=True) # We validate the hash of the original with the hash of the block we retrieved. @@ -199,7 +199,7 @@ def test_snapshot_get_various(self): # same as above but test read_only for get_by_height() block = snapshot_view.blocks.get_by_height(self.block1.index, read_only=True) - block.index = 123 + block.header.index = 123 block_again = snapshot_view.blocks.get(self.block1_hash, read_only=True) self.assertEqual(self.block1_hash, block_again.hash()) @@ -272,15 +272,15 @@ def test_all(self): snapshot_view = self.db.get_snapshotview() # get() a block to fill the cache so we can test sorting and readonly behaviour - # block2's hash comes before block1 when sorting. So we cache that first as the all() function internals - # collect the results from the backend (=block1) before results from the cache (=block2). - # Therefore if block2 is found in the first position of the all() results, we can + # block1's hash comes before block2 when sorting. So we cache that first as the all() function internals + # collect the results from the backend (=block2) before results from the cache (=block1). + # Therefore if block1 is found in the first position of the all() results, we can # conclude that the sort() happened correctly. - snapshot_view.blocks.get(self.block2_hash) + snapshot_view.blocks.get(self.block1_hash) blocks = list(snapshot_view.blocks.all()) self.assertEqual(2, len(blocks)) - self.assertEqual(self.block2, blocks[0]) - self.assertEqual(self.block1, blocks[1]) + self.assertEqual(self.block1, blocks[0]) + self.assertEqual(self.block2, blocks[1]) # ensure all() results are readonly blocks[0].transactions.append(payloads.Transaction._serializable_init()) @@ -293,7 +293,7 @@ def test_all(self): # test clone all() block3 = deepcopy(self.block1) - block3.index = 3 + block3.header.index = 3 clone_view = snapshot_view.clone() clone_view.blocks.put(block3) @@ -301,8 +301,8 @@ def test_all(self): self.assertEqual(3, len(blocks)) self.assertEqual(2, len(list(snapshot_view.blocks.all()))) self.assertEqual(self.block1, blocks[1]) - self.assertEqual(self.block2, blocks[0]) - self.assertEqual(block3, blocks[2]) + self.assertEqual(self.block2, blocks[2]) + self.assertEqual(block3, blocks[0]) def test_snapshot_bestblockheight(self): snapshot_view = self.db.get_snapshotview() @@ -1471,5 +1471,5 @@ def test_all(self): self.assertEqual(3, len(txs)) self.assertEqual(2, len(list(snapshot_view.transactions.all()))) self.assertEqual(self.tx1, txs[2]) - self.assertEqual(self.tx2, txs[1]) - self.assertEqual(tx3, txs[0]) + self.assertEqual(self.tx2, txs[0]) + self.assertEqual(tx3, txs[1]) diff --git a/tests/storage/test_item_and_key.py b/tests/storage/test_item_and_key.py index ee33622d..0a74b8ff 100644 --- a/tests/storage/test_item_and_key.py +++ b/tests/storage/test_item_and_key.py @@ -68,15 +68,13 @@ def test_eq(self): def test_len(self): si = storage.StorageItem(b'\x01') - self.assertEqual(3, len(si)) + self.assertEqual(2, len(si)) def test_serialization(self): si_data = b'\x01\x02\x03' - si = storage.StorageItem(si_data, False) - length_indicator = b'\x03' - bool_false = b'\x00' - self.assertEqual(length_indicator + si_data + bool_false, si.to_array()) - self.assertEqual(si, storage.StorageItem.deserialize_from_bytes(length_indicator + si_data + bool_false)) + si = storage.StorageItem(si_data) + self.assertEqual(si_data, si.to_array()) + self.assertEqual(si, storage.StorageItem.deserialize_from_bytes(si_data)) def test_clone_from_replica(self): si_data = b'\x01\x02\x03'