From 706ae5fc50663db380b793359e6c6781ffd9ca8c Mon Sep 17 00:00:00 2001 From: Lior Goldberg Date: Fri, 17 Mar 2023 16:36:20 +0200 Subject: [PATCH] Cairo v0.11.0 (pre2). --- scripts/requirements-gen.txt | 2 +- src/cmake_utils/python_rules.cmake | 27 + src/starkware/cairo/common/CMakeLists.txt | 2 + .../cairo/common/cairo_function_runner.py | 5 +- src/starkware/cairo/common/dict.py | 4 +- src/starkware/cairo/common/patricia.cairo | 30 +- .../cairo/common/patricia_utils.cairo | 23 + .../cairo/common/patricia_with_poseidon.cairo | 34 ++ .../cairo/common/patricia_with_sponge.cairo | 30 +- src/starkware/cairo/lang/VERSION | 2 +- .../cairo/lang/builtins/CMakeLists.txt | 1 + .../lang/builtins/poseidon/instance_def.py | 7 +- .../poseidon/poseidon_builtin_runner.py | 4 +- src/starkware/cairo/lang/compiler/program.py | 17 +- src/starkware/cairo/lang/instances.py | 4 - src/starkware/cairo/lang/tracer/tracer.py | 2 +- src/starkware/cairo/lang/vm/cairo_runner.py | 31 +- src/starkware/cairo/lang/vm/crypto.py | 1 + .../cairo/lang/vm/memory_segments.py | 5 +- .../cairo/lang/vm/virtual_machine_base.py | 5 +- src/starkware/cairo/lang/vm/vm_consts.py | 33 +- src/starkware/cairo/lang/vm/vm_consts_test.py | 8 +- .../air/layouts/all_cairo/public_verify.cairo | 8 + .../air/layouts/dex/public_verify.cairo | 4 + .../air/layouts/recursive/public_verify.cairo | 4 + .../air/layouts/small/public_verify.cairo | 4 + .../air/layouts/starknet/public_verify.cairo | 7 + .../starknet_with_keccak/public_verify.cairo | 8 + src/starkware/eth/eth_test_utils.py | 4 +- src/starkware/starknet/CMakeLists.txt | 1 + .../starknet/builtins/CMakeLists.txt | 1 + .../builtins/segment_arena/CMakeLists.txt | 27 + .../segment_arena/segment_arena.cairo | 90 +++ .../segment_arena_builtin_runner.py | 39 ++ .../segment_arena/segment_arena_test.cairo | 94 +++ .../segment_arena/segment_arena_test.py | 44 ++ .../business_logic/execution/CMakeLists.txt | 1 + .../execution/execute_entry_point.py | 35 +- .../execution/execute_entry_point_base.py | 1 + .../execution/os_resources.json | 58 +- .../fact_state/contract_class_objects.py | 20 +- .../fact_state/patricia_state.py | 6 +- .../business_logic/fact_state/state.py | 11 +- .../starknet/business_logic/state/state.py | 72 ++- .../business_logic/state/state_api.py | 2 +- .../business_logic/transaction/objects.py | 77 +-- .../starknet/business_logic/utils.py | 8 +- src/starkware/starknet/cli/CMakeLists.txt | 2 + src/starkware/starknet/cli/class_hash.py | 28 +- .../starknet/cli/compiled_class_hash.py | 39 ++ src/starkware/starknet/cli/starknet_cli.py | 215 +++++-- .../starknet/cli/starknet_cli_utils.py | 112 +++- .../starknet/common/new_syscalls.cairo | 109 +++- .../starknet/compiler/CMakeLists.txt | 3 + .../compiler/v1/BUILD.cairo-lang-1.0.0 | 7 + .../starknet/compiler/v1/CMakeLists.txt | 108 ++++ src/starkware/starknet/compiler/v1/compile.py | 100 ++++ .../starknet/core/os/block_context.cairo | 17 +- src/starkware/starknet/core/os/builtins.cairo | 6 + .../starknet/core/os/constants.cairo | 9 +- .../core/os/contract_class/CMakeLists.txt | 6 +- .../core/os/contract_class/class_hash.py | 38 +- ...{contract_class.py => class_hash_utils.py} | 26 +- .../os/contract_class/compiled_class.cairo | 14 +- .../os/contract_class/compiled_class_hash.py | 130 +---- .../compiled_class_hash_utils.py | 113 ++++ .../os/contract_class/contract_class.cairo | 14 +- .../contract_class/deprecated_class_hash.py | 5 +- .../starknet/core/os/contract_class/utils.py | 33 ++ .../deprecated_execute_entry_point.cairo | 20 +- .../deprecated_execute_syscalls.cairo | 91 +-- .../os/execution/execute_entry_point.cairo | 69 ++- .../core/os/execution/execute_syscalls.cairo | 299 ++++++++-- .../os/execution/execute_transactions.cairo | 258 ++++++--- .../starknet/core/os/program_hash.json | 2 +- src/starkware/starknet/core/os/state.cairo | 43 +- .../starknet/core/os/syscall_handler.py | 533 +++++++++++++++--- .../starknet/core/os/syscall_utils.py | 25 + .../starknet/core/test_contract/test_utils.py | 13 +- .../starknet/definitions/constants.py | 11 +- .../starknet/definitions/error_codes.py | 6 + src/starkware/starknet/definitions/fields.py | 18 +- .../starknet/definitions/general_config.py | 2 + .../api/contract_class/CMakeLists.txt | 13 + .../api/contract_class/contract_class.py | 145 ++++- .../contract_class/contract_class_utils.py | 60 ++ .../contracts/test_contract_cairo1.cairo | 119 ++++ .../feeder_gateway/feeder_gateway_client.py | 41 +- .../api/feeder_gateway/request_objects.py | 7 +- .../services/api/gateway/transaction.py | 6 +- .../starknet/services/utils/CMakeLists.txt | 2 - .../services/utils/sequencer_api_utils.py | 20 +- .../starknet/storage/starknet_storage.py | 2 +- src/starkware/starknet/wallets/account.py | 26 +- .../starknet/wallets/open_zeppelin.py | 63 ++- .../starknet/wallets/starknet_context.py | 2 +- .../starkware_utils/error_handling.py | 2 +- .../marshmallow_dataclass_fields.py | 11 + 98 files changed, 3078 insertions(+), 868 deletions(-) create mode 100644 src/starkware/cairo/common/patricia_utils.cairo create mode 100644 src/starkware/cairo/common/patricia_with_poseidon.cairo create mode 100644 src/starkware/starknet/builtins/CMakeLists.txt create mode 100644 src/starkware/starknet/builtins/segment_arena/CMakeLists.txt create mode 100644 src/starkware/starknet/builtins/segment_arena/segment_arena.cairo create mode 100644 src/starkware/starknet/builtins/segment_arena/segment_arena_builtin_runner.py create mode 100644 src/starkware/starknet/builtins/segment_arena/segment_arena_test.cairo create mode 100644 src/starkware/starknet/builtins/segment_arena/segment_arena_test.py mode change 100644 => 100755 src/starkware/starknet/cli/class_hash.py create mode 100755 src/starkware/starknet/cli/compiled_class_hash.py create mode 100644 src/starkware/starknet/compiler/v1/BUILD.cairo-lang-1.0.0 create mode 100644 src/starkware/starknet/compiler/v1/CMakeLists.txt create mode 100644 src/starkware/starknet/compiler/v1/compile.py rename src/starkware/starknet/core/os/contract_class/{contract_class.py => class_hash_utils.py} (83%) create mode 100644 src/starkware/starknet/core/os/contract_class/compiled_class_hash_utils.py create mode 100644 src/starkware/starknet/core/os/contract_class/utils.py create mode 100644 src/starkware/starknet/services/api/contract_class/contract_class_utils.py create mode 100644 src/starkware/starknet/services/api/contract_class/contracts/test_contract_cairo1.cairo diff --git a/scripts/requirements-gen.txt b/scripts/requirements-gen.txt index 95ce92e0..27d65ca9 100644 --- a/scripts/requirements-gen.txt +++ b/scripts/requirements-gen.txt @@ -16,6 +16,6 @@ prometheus-client pytest pytest-asyncio PyYAML -typeguard +typeguard<3.0.0 sympy Web3 diff --git a/src/cmake_utils/python_rules.cmake b/src/cmake_utils/python_rules.cmake index d4283faf..4602f7a3 100644 --- a/src/cmake_utils/python_rules.cmake +++ b/src/cmake_utils/python_rules.cmake @@ -303,3 +303,30 @@ function(full_python_test TEST_NAME) ${CODE_COVERAGE_SUPPRESSION_FLAG} ) endfunction() + +function(starknet_contract_v1 NAME) + # Parse arguments. + set(options) + set(oneValueArgs MAIN COMPILED_SIERRA_NAME COMPILED_CASM_NAME) + set(multiValueArgs) + cmake_parse_arguments(ARGS "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + add_custom_command( + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/${ARGS_COMPILED_SIERRA_NAME} + COMMAND + ${CMAKE_BINARY_DIR}/src/starkware/starknet/compiler/v1/cairo/bin/starknet-compile + ${CMAKE_CURRENT_SOURCE_DIR}/${ARGS_MAIN} + ${CMAKE_CURRENT_BINARY_DIR}/${ARGS_COMPILED_SIERRA_NAME} + DEPENDS get_cairo_compiler ${CMAKE_CURRENT_SOURCE_DIR}/${ARGS_MAIN} + ) + add_custom_command( + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/${ARGS_COMPILED_CASM_NAME} + COMMAND + ${CMAKE_BINARY_DIR}/src/starkware/starknet/compiler/v1/cairo/bin/starknet-sierra-compile + --add-pythonic-hints + ${CMAKE_CURRENT_BINARY_DIR}/${ARGS_COMPILED_SIERRA_NAME} + ${CMAKE_CURRENT_BINARY_DIR}/${ARGS_COMPILED_CASM_NAME} + DEPENDS get_cairo_compiler ${CMAKE_CURRENT_BINARY_DIR}/${ARGS_COMPILED_SIERRA_NAME} + ) +add_custom_target(${NAME} ALL DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/${ARGS_COMPILED_SIERRA_NAME} ${CMAKE_CURRENT_BINARY_DIR}/${ARGS_COMPILED_CASM_NAME}) +endfunction() diff --git a/src/starkware/cairo/common/CMakeLists.txt b/src/starkware/cairo/common/CMakeLists.txt index 899896c6..8c96c0e8 100644 --- a/src/starkware/cairo/common/CMakeLists.txt +++ b/src/starkware/cairo/common/CMakeLists.txt @@ -50,6 +50,8 @@ python_lib(cairo_common_lib merkle_update.cairo patricia_utils.py patricia_with_sponge.cairo + patricia_with_poseidon.cairo + patricia_utils.cairo patricia.cairo poseidon_state.cairo pow.cairo diff --git a/src/starkware/cairo/common/cairo_function_runner.py b/src/starkware/cairo/common/cairo_function_runner.py index b353ad55..2a60a2de 100644 --- a/src/starkware/cairo/common/cairo_function_runner.py +++ b/src/starkware/cairo/common/cairo_function_runner.py @@ -1,7 +1,6 @@ from collections.abc import Iterable from typing import Any, Dict, Optional, Tuple, Union, cast -from starkware.cairo.common.poseidon_utils import PoseidonParams from starkware.cairo.common.structs import CairoStructFactory from starkware.cairo.lang.builtins.bitwise.bitwise_builtin_runner import BitwiseBuiltinRunner from starkware.cairo.lang.builtins.bitwise.instance_def import BitwiseInstanceDef @@ -78,7 +77,6 @@ def __init__(self, *args, **kwargs): included=True, instance_def=PoseidonInstanceDef( ratio=1, - params=PoseidonParams.get_default_poseidon_params(), partial_rounds_partition=[64, 22], ), ) @@ -246,6 +244,7 @@ def run_from_entrypoint( verify_secure: Optional[bool] = None, program_segment_size: Optional[int] = None, apply_modulo_to_args: Optional[bool] = None, + allow_tmp_segments: bool = False, ): """ Runs the program from the given entrypoint. @@ -276,7 +275,7 @@ def run_from_entrypoint( self.initialize_vm(hint_locals=hint_locals, static_locals=static_locals) self.run_until_pc(addr=end, run_resources=run_resources) - self.end_run() + self.end_run(allow_tmp_segments=allow_tmp_segments) if verify_secure: verify_secure_runner( diff --git a/src/starkware/cairo/common/dict.py b/src/starkware/cairo/common/dict.py index 3fb352fe..154d9339 100644 --- a/src/starkware/cairo/common/dict.py +++ b/src/starkware/cairo/common/dict.py @@ -41,11 +41,11 @@ def new_dict(self, segments, initial_dict): ) return base - def new_default_dict(self, segments, default_value): + def new_default_dict(self, segments, default_value, temp_segment: bool = False): """ Creates a new Cairo default dictionary. """ - base = segments.add() + base = segments.add_temp_segment() if temp_segment else segments.add() assert base.segment_index not in self.trackers self.trackers[base.segment_index] = DictTracker( data=defaultdict(lambda: default_value), diff --git a/src/starkware/cairo/common/patricia.cairo b/src/starkware/cairo/common/patricia.cairo index a6321eab..cb1f8c47 100644 --- a/src/starkware/cairo/common/patricia.cairo +++ b/src/starkware/cairo/common/patricia.cairo @@ -10,33 +10,15 @@ from starkware.cairo.common.math import ( assert_nn_le, assert_not_zero, ) +from starkware.cairo.common.patricia_utils import ( + MAX_LENGTH, + NodeEdge, + ParticiaGlobals, + PatriciaUpdateConstants, +) // ADDITIONAL_IMPORTS_MACRO() -// Maximum length of an edge. -const MAX_LENGTH = 251; - -// A struct of globals that are passed throughout the algorithm. -struct ParticiaGlobals { - // An array of size MAX_LENGTH, where pow2[i] = 2**i. - pow2: felt*, - // Offset of the relevant value field in DictAccess. - // 1 if the previous tree is traversed and 2 if the new tree is traversed. - access_offset: felt, -} - -// Represents an edge node: a subtree with a path, s.t. all leaves not under that path are 0. -struct NodeEdge { - length: felt, - path: felt, - bottom: felt, -} - -// Holds the constants needed for Patricia updates. -struct PatriciaUpdateConstants { - globals_pow2: felt*, -} - // Given an edge node hash, opens the hash using the preimage hint, and returns a NodeEdge object. func open_edge{hash_ptr: HashBuiltin*, range_check_ptr}(globals: ParticiaGlobals*, node: felt) -> ( edge: NodeEdge* diff --git a/src/starkware/cairo/common/patricia_utils.cairo b/src/starkware/cairo/common/patricia_utils.cairo new file mode 100644 index 00000000..9d2ae313 --- /dev/null +++ b/src/starkware/cairo/common/patricia_utils.cairo @@ -0,0 +1,23 @@ +// Maximum length of an edge. +const MAX_LENGTH = 251; + +// A struct of globals that are passed throughout the algorithm. +struct ParticiaGlobals { + // An array of size MAX_LENGTH, where pow2[i] = 2**i. + pow2: felt*, + // Offset of the relevant value field in DictAccess. + // 1 if the previous tree is traversed and 2 if the new tree is traversed. + access_offset: felt, +} + +// Represents an edge node: a subtree with a path, s.t. all leaves not under that path are 0. +struct NodeEdge { + length: felt, + path: felt, + bottom: felt, +} + +// Holds the constants needed for Patricia updates. +struct PatriciaUpdateConstants { + globals_pow2: felt*, +} diff --git a/src/starkware/cairo/common/patricia_with_poseidon.cairo b/src/starkware/cairo/common/patricia_with_poseidon.cairo new file mode 100644 index 00000000..b9fe140c --- /dev/null +++ b/src/starkware/cairo/common/patricia_with_poseidon.cairo @@ -0,0 +1,34 @@ +from starkware.cairo.common.cairo_builtins import PoseidonBuiltin +from starkware.cairo.common.dict import DictAccess +from starkware.cairo.common.patricia_utils import PatriciaUpdateConstants +from starkware.cairo.common.patricia_with_sponge import ( + patricia_update_using_update_constants as patricia_update_using_update_constants_with_sponge, +) +from starkware.cairo.common.sponge_as_hash import SpongeHashBuiltin + +func patricia_update_using_update_constants{poseidon_ptr: PoseidonBuiltin*, range_check_ptr}( + patricia_update_constants: PatriciaUpdateConstants*, + update_ptr: DictAccess*, + n_updates: felt, + height: felt, + prev_root: felt, + new_root: felt, +) { + let hash_ptr = cast(poseidon_ptr, SpongeHashBuiltin*); + + with hash_ptr { + patricia_update_using_update_constants_with_sponge( + patricia_update_constants=patricia_update_constants, + update_ptr=update_ptr, + n_updates=n_updates, + height=height, + prev_root=prev_root, + new_root=new_root, + ); + } + + // Update poseidon_ptr. + let poseidon_ptr = cast(hash_ptr, PoseidonBuiltin*); + + return (); +} diff --git a/src/starkware/cairo/common/patricia_with_sponge.cairo b/src/starkware/cairo/common/patricia_with_sponge.cairo index bff280de..379767d6 100644 --- a/src/starkware/cairo/common/patricia_with_sponge.cairo +++ b/src/starkware/cairo/common/patricia_with_sponge.cairo @@ -8,33 +8,15 @@ from starkware.cairo.common.math import ( assert_nn_le, assert_not_zero, ) +from starkware.cairo.common.patricia_utils import ( + MAX_LENGTH, + NodeEdge, + ParticiaGlobals, + PatriciaUpdateConstants, +) from starkware.cairo.common.sponge_as_hash import SpongeHashBuiltin as HashBuiltin from starkware.cairo.common.sponge_as_hash import sponge_hash2 as hash2 -// Maximum length of an edge. -const MAX_LENGTH = 251; - -// A struct of globals that are passed throughout the algorithm. -struct ParticiaGlobals { - // An array of size MAX_LENGTH, where pow2[i] = 2**i. - pow2: felt*, - // Offset of the relevant value field in DictAccess. - // 1 if the previous tree is traversed and 2 if the new tree is traversed. - access_offset: felt, -} - -// Represents an edge node: a subtree with a path, s.t. all leaves not under that path are 0. -struct NodeEdge { - length: felt, - path: felt, - bottom: felt, -} - -// Holds the constants needed for Patricia updates. -struct PatriciaUpdateConstants { - globals_pow2: felt*, -} - // Given an edge node hash, opens the hash using the preimage hint, and returns a NodeEdge object. func open_edge{hash_ptr: HashBuiltin*, range_check_ptr}(globals: ParticiaGlobals*, node: felt) -> ( edge: NodeEdge* diff --git a/src/starkware/cairo/lang/VERSION b/src/starkware/cairo/lang/VERSION index d22e31d2..e0cbcd58 100644 --- a/src/starkware/cairo/lang/VERSION +++ b/src/starkware/cairo/lang/VERSION @@ -1 +1 @@ -0.11.0a0 +0.11.0a1 diff --git a/src/starkware/cairo/lang/builtins/CMakeLists.txt b/src/starkware/cairo/lang/builtins/CMakeLists.txt index daa1dd0b..02af1702 100644 --- a/src/starkware/cairo/lang/builtins/CMakeLists.txt +++ b/src/starkware/cairo/lang/builtins/CMakeLists.txt @@ -24,6 +24,7 @@ python_lib(cairo_run_builtins_lib LIBS cairo_common_lib cairo_relocatable_lib + cairo_vm_crypto_lib cairo_vm_lib starkware_python_utils_lib ) diff --git a/src/starkware/cairo/lang/builtins/poseidon/instance_def.py b/src/starkware/cairo/lang/builtins/poseidon/instance_def.py index d0d9dae5..ba91c715 100644 --- a/src/starkware/cairo/lang/builtins/poseidon/instance_def.py +++ b/src/starkware/cairo/lang/builtins/poseidon/instance_def.py @@ -1,7 +1,7 @@ import dataclasses from typing import List, Optional -from starkware.cairo.common.poseidon_utils import PoseidonParams +POSEIDON_M = 3 @dataclasses.dataclass @@ -10,15 +10,12 @@ class PoseidonInstanceDef: # None means dynamic ratio. ratio: Optional[int] - # Defines the Hades permutation. - params: PoseidonParams - # Defines the partition of the partial rounds to virtual columns. partial_rounds_partition: List[int] @property def cells_per_builtin(self): - return 2 * self.params.m + return 2 * POSEIDON_M @property def range_check_units_per_builtin(self): diff --git a/src/starkware/cairo/lang/builtins/poseidon/poseidon_builtin_runner.py b/src/starkware/cairo/lang/builtins/poseidon/poseidon_builtin_runner.py index 013b1ae2..3a766b47 100644 --- a/src/starkware/cairo/lang/builtins/poseidon/poseidon_builtin_runner.py +++ b/src/starkware/cairo/lang/builtins/poseidon/poseidon_builtin_runner.py @@ -1,8 +1,8 @@ from typing import Any, Dict, Set -from starkware.cairo.common.poseidon_utils import hades_permutation from starkware.cairo.lang.builtins.poseidon.instance_def import PoseidonInstanceDef from starkware.cairo.lang.vm.builtin_runner import SimpleBuiltinRunner +from starkware.cairo.lang.vm.crypto import poseidon_perm from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, RelocatableValue from starkware.python.math_utils import safe_div @@ -45,7 +45,7 @@ def rule(vm, addr, verified_addresses): + f"Got: {value}." ) input_state = memory.get_range(first_input_addr, self.n_input_cells) - output_state = hades_permutation(input_state, self.instance_def.params) + output_state = poseidon_perm(*input_state) for i in range(self.n_input_cells): self.cache[first_output_addr + i] = output_state[i] return self.cache[addr] diff --git a/src/starkware/cairo/lang/compiler/program.py b/src/starkware/cairo/lang/compiler/program.py index c82a6f8f..e485b926 100644 --- a/src/starkware/cairo/lang/compiler/program.py +++ b/src/starkware/cairo/lang/compiler/program.py @@ -81,11 +81,7 @@ def run_validity_checks(self): @marshmallow_dataclass.dataclass(repr=False) -class HintedProgram(ProgramBase, SerializableMarshmallowDataclass): - """ - A Serializable Cairo Program with hints. - """ - +class Program(ProgramBase, SerializableMarshmallowDataclass): prime: int = field(metadata=additional_metadata(marshmallow_field=IntAsHex(required=True))) data: List[int] = field( metadata=additional_metadata(marshmallow_field=mfields.List(IntAsHex(), required=True)) @@ -95,17 +91,6 @@ class HintedProgram(ProgramBase, SerializableMarshmallowDataclass): compiler_version: Optional[str] = field( metadata=dict(marshmallow_field=mfields.String(required=False, load_default=None)) ) - - def stripped(self) -> StrippedProgram: - raise NotImplementedError("HintedProgram does not have a main entrypoint.") - - @property - def main(self) -> Optional[int]: # type: ignore - raise NotImplementedError("HintedProgram does not have a main entrypoint.") - - -@marshmallow_dataclass.dataclass(repr=False) -class Program(HintedProgram): main_scope: ScopedName = field( metadata=additional_metadata(marshmallow_field=ScopedNameAsStr()) ) diff --git a/src/starkware/cairo/lang/instances.py b/src/starkware/cairo/lang/instances.py index c4624663..39fb164c 100644 --- a/src/starkware/cairo/lang/instances.py +++ b/src/starkware/cairo/lang/instances.py @@ -2,7 +2,6 @@ from dataclasses import field from typing import Any, Dict, Optional -from starkware.cairo.common.poseidon_utils import PoseidonParams from starkware.cairo.lang.builtins.bitwise.instance_def import BitwiseInstanceDef from starkware.cairo.lang.builtins.ec.instance_def import EcOpInstanceDef from starkware.cairo.lang.builtins.hash.instance_def import PedersenInstanceDef @@ -195,7 +194,6 @@ def build_dynamic_layout(**ratios) -> CairoLayout: ), poseidon=PoseidonInstanceDef( ratio=32, - params=PoseidonParams.get_default_poseidon_params(), partial_rounds_partition=[64, 22], ), ), @@ -247,7 +245,6 @@ def build_dynamic_layout(**ratios) -> CairoLayout: ), poseidon=PoseidonInstanceDef( ratio=32, - params=PoseidonParams.get_default_poseidon_params(), partial_rounds_partition=[64, 22], ), ), @@ -366,7 +363,6 @@ def build_dynamic_layout(**ratios) -> CairoLayout: ), poseidon=PoseidonInstanceDef( ratio=256, - params=PoseidonParams.get_default_poseidon_params(), partial_rounds_partition=[64, 22], ), ), diff --git a/src/starkware/cairo/lang/tracer/tracer.py b/src/starkware/cairo/lang/tracer/tracer.py index a1a812e0..76b4aa20 100755 --- a/src/starkware/cairo/lang/tracer/tracer.py +++ b/src/starkware/cairo/lang/tracer/tracer.py @@ -15,7 +15,7 @@ def trace_runner(runner): runner.vm_memory.relocate_memory() runner.vm_memory.freeze() - runner.segments.compute_effective_sizes(include_tmp_segments=True) + runner.segments.compute_effective_sizes(allow_tmp_segments=True) if not hasattr(runner, "relocated_trace"): runner.relocate() diff --git a/src/starkware/cairo/lang/vm/cairo_runner.py b/src/starkware/cairo/lang/vm/cairo_runner.py index c1c93cff..3b8fb5f8 100644 --- a/src/starkware/cairo/lang/vm/cairo_runner.py +++ b/src/starkware/cairo/lang/vm/cairo_runner.py @@ -1,4 +1,17 @@ -from typing import Any, Dict, List, Mapping, Optional, Sequence, Set, Tuple, Type, TypeVar, Union +from typing import ( + Any, + Callable, + Dict, + List, + Mapping, + Optional, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, +) from starkware.cairo.lang.builtins.bitwise.bitwise_builtin_runner import BitwiseBuiltinRunner from starkware.cairo.lang.builtins.ec.ec_op_builtin_runner import EcOpBuiltinRunner @@ -74,7 +87,13 @@ def __init__( memory: MemoryDict = None, proof_mode: Optional[bool] = None, allow_missing_builtins: Optional[bool] = None, + additional_builtin_factories: Optional[ + Dict[str, Callable[[str, bool], BuiltinRunner]] + ] = None, ): + if additional_builtin_factories is None: + additional_builtin_factories = {} + self.program = program self.layout: CairoLayout if isinstance(layout, CairoLayout): @@ -128,6 +147,7 @@ def __init__( poseidon=lambda name, included: PoseidonBuiltinRunner( included=included, instance_def=self.layout.builtins["poseidon"] ), + **additional_builtin_factories, ) for name in self.layout.builtins: @@ -342,7 +362,12 @@ def run_until_next_power_of_2(self): """ self.run_until_steps(next_power_of_2(self.vm.current_step)) - def end_run(self, disable_trace_padding: bool = True, disable_finalize_all: bool = False): + def end_run( + self, + disable_trace_padding: bool = True, + disable_finalize_all: bool = False, + allow_tmp_segments: bool = False, + ): assert not self._run_ended, "end_run called twice." self.accessed_addresses = { @@ -358,7 +383,7 @@ def end_run(self, disable_trace_padding: bool = True, disable_finalize_all: bool # Freeze to enable caching; No changes in memory should be made from now on. self.vm_memory.freeze() # Deduce the size of each segment from its usage. - self.segments.compute_effective_sizes() + self.segments.compute_effective_sizes(allow_tmp_segments=allow_tmp_segments) if self.proof_mode and not disable_trace_padding: self.run_until_next_power_of_2() diff --git a/src/starkware/cairo/lang/vm/crypto.py b/src/starkware/cairo/lang/vm/crypto.py index 0fffffda..acbbc5c0 100644 --- a/src/starkware/cairo/lang/vm/crypto.py +++ b/src/starkware/cairo/lang/vm/crypto.py @@ -5,6 +5,7 @@ poseidon_hash_func, poseidon_hash_many, poseidon_hash_single, + poseidon_perm, ) from starkware.crypto.signature.fast_pedersen_hash import pedersen_hash # noqa from starkware.crypto.signature.fast_pedersen_hash import pedersen_hash_func # noqa diff --git a/src/starkware/cairo/lang/vm/memory_segments.py b/src/starkware/cairo/lang/vm/memory_segments.py index e2115ce3..6c71a34a 100644 --- a/src/starkware/cairo/lang/vm/memory_segments.py +++ b/src/starkware/cairo/lang/vm/memory_segments.py @@ -72,10 +72,9 @@ def finalize( self.public_memory_offsets[segment_index] = list(public_memory) - def compute_effective_sizes(self, include_tmp_segments: bool = False): + def compute_effective_sizes(self, allow_tmp_segments: bool = False): """ Computes the current used size of the segments, and caches it. - include_tmp_segments should be used for tests only. """ if self._segment_used_sizes is not None: # segment_sizes is already cached. @@ -83,7 +82,7 @@ def compute_effective_sizes(self, include_tmp_segments: bool = False): assert self.memory.is_frozen(), "Memory has to be frozen before calculating effective size." - first_segment_index = -self.n_temp_segments if include_tmp_segments else 0 + first_segment_index = -self.n_temp_segments if allow_tmp_segments else 0 self._segment_used_sizes = { index: 0 for index in range(first_segment_index, self.n_segments) } diff --git a/src/starkware/cairo/lang/vm/virtual_machine_base.py b/src/starkware/cairo/lang/vm/virtual_machine_base.py index 53bea94b..53fc0ec5 100644 --- a/src/starkware/cairo/lang/vm/virtual_machine_base.py +++ b/src/starkware/cairo/lang/vm/virtual_machine_base.py @@ -17,7 +17,7 @@ from starkware.cairo.lang.compiler.preprocessor.flow import FlowTrackingDataActual from starkware.cairo.lang.compiler.preprocessor.preprocessor import AttributeBase, AttributeScope from starkware.cairo.lang.compiler.preprocessor.reg_tracking import RegTrackingData -from starkware.cairo.lang.compiler.program import HintedProgram, Program, ProgramBase +from starkware.cairo.lang.compiler.program import Program, ProgramBase from starkware.cairo.lang.compiler.references import ApDeductionError from starkware.cairo.lang.compiler.scoped_name import ScopedName from starkware.cairo.lang.vm.builtin_runner import BuiltinRunner @@ -170,8 +170,7 @@ def __init__( self.validated_memory = ValidatedMemoryDict(memory=run_context.memory, prime=self.prime) # If program is a StrippedProgram, there are no hints or debug information to load. - if isinstance(program, HintedProgram): - assert isinstance(program, Program), "A bare HintedProgram cannot be loaded." + if isinstance(program, Program): self.load_program(program=program, program_base=program_base) # auto_deduction contains a mapping from a memory segment index to a list of functions diff --git a/src/starkware/cairo/lang/vm/vm_consts.py b/src/starkware/cairo/lang/vm/vm_consts.py index 25be8f6c..d2a4198b 100644 --- a/src/starkware/cairo/lang/vm/vm_consts.py +++ b/src/starkware/cairo/lang/vm/vm_consts.py @@ -265,20 +265,32 @@ def raise_unsupported_error(self, name: ScopedName, identifier_type: str): class VmConstsReference(VmConstsBase): - def __init__(self, *, reference_value, struct_name: ScopedName, **kw): + def __init__( + self, + *, + reference_value, + struct_name: Optional[ScopedName] = None, + struct_definition: Optional[StructDefinition] = None, + **kw, + ): """ Constructs a VmConstsReference which allows accessing a typed reference fields. """ super().__init__(**kw) - object.__setattr__( - self, - "_struct_definition", - get_struct_definition( + if struct_definition is None: + assert ( + struct_name is not None + ), "Exactly one of 'struct_name' and 'struct_definition' must be specified." + struct_definition = get_struct_definition( struct_name=struct_name, identifier_manager=self._context.identifiers - ), - ) + ) + else: + assert ( + struct_name is None + ), "Exactly one of 'struct_name' and 'struct_definition' must be specified." + object.__setattr__(self, "_struct_definition", struct_definition) object.__setattr__(self, "_reference_value", reference_value) object.__setattr__(self, "address_", reference_value) @@ -324,6 +336,13 @@ def get_or_set_value(self, name: str, set_value: Optional[MaybeRelocatable]): reference_value=self._context.memory[addr], ) + def __getitem__(self, idx: int): + return VmConstsReference( + context=self._context, + struct_definition=self._struct_definition, + reference_value=self.address_ + idx * self._struct_definition.size, + ) + def is_simple_type(expr_type: CairoType) -> bool: """ diff --git a/src/starkware/cairo/lang/vm/vm_consts_test.py b/src/starkware/cairo/lang/vm/vm_consts_test.py index dbf5b16f..d4e7571e 100644 --- a/src/starkware/cairo/lang/vm/vm_consts_test.py +++ b/src/starkware/cairo/lang/vm/vm_consts_test.py @@ -157,6 +157,7 @@ def test_references(): my_struct = TypeStruct(scope=scope("MyStruct")) my_struct_star = TypePointer(pointee=my_struct) + MY_STRUCT_SIZE = 20 identifier_values = { scope("x.ref"): ReferenceDefinition( full_name=scope("x.ref"), cairo_type=TypeFelt(), references=[] @@ -179,7 +180,7 @@ def test_references(): "member": MemberDefinition(offset=10, cairo_type=TypeFelt()), "struct": MemberDefinition(offset=11, cairo_type=my_struct), }, - size=11, + size=MY_STRUCT_SIZE, ), } identifiers = IdentifierManager.from_dict(identifier_values) @@ -209,6 +210,9 @@ def test_references(): assert consts.x.ref == memory[(ap - 2) + 1] assert consts.x.typeref.address_ == (ap - 1) + 1 + assert consts.x.typeref[0].address_ == (ap - 1) + 1 + assert consts.x.typeref[2].address_ == (ap - 1) + 1 + MY_STRUCT_SIZE * 2 + assert consts.x.typeref[2].struct.address_ == (ap - 1) + 1 + MY_STRUCT_SIZE * 2 + 11 assert consts.x.typeref.member == memory[(ap - 1) + 1 + 10] with pytest.raises(IdentifierError, match="'abc' is not a member of 'MyStruct'."): consts.x.typeref.abc @@ -223,7 +227,7 @@ def test_references(): with pytest.raises(AssertionError, match="Cannot change the value of a constant."): consts.MyStruct.member = 13 - assert consts.MyStruct.SIZE == 11 + assert consts.MyStruct.SIZE == MY_STRUCT_SIZE with pytest.raises(AssertionError, match="Cannot change the value of a constant."): consts.MyStruct.SIZE = 13 diff --git a/src/starkware/cairo/stark_verifier/air/layouts/all_cairo/public_verify.cairo b/src/starkware/cairo/stark_verifier/air/layouts/all_cairo/public_verify.cairo index e69bd449..65be1128 100644 --- a/src/starkware/cairo/stark_verifier/air/layouts/all_cairo/public_verify.cairo +++ b/src/starkware/cairo/stark_verifier/air/layouts/all_cairo/public_verify.cairo @@ -97,6 +97,7 @@ func public_input_validate{range_check_ptr}( assert_nn(n_output_uses); assert public_input.n_segments = segments.N_SEGMENTS; + tempvar n_pedersen_copies = n_steps / PEDERSEN_BUILTIN_RATIO; tempvar n_pedersen_uses = ( public_input.segments[segments.PEDERSEN].stop_ptr - @@ -110,6 +111,7 @@ func public_input_validate{range_check_ptr}( public_input.segments[segments.RANGE_CHECK].stop_ptr - public_input.segments[segments.RANGE_CHECK].begin_addr ); + // Note that the following call implies that n_steps is divisible by RC_BUILTIN_RATIO. assert_nn_le(n_range_check_uses, n_range_check_copies); tempvar n_ecdsa_copies = n_steps / ECDSA_BUILTIN_RATIO; @@ -117,6 +119,7 @@ func public_input_validate{range_check_ptr}( public_input.segments[segments.ECDSA].stop_ptr - public_input.segments[segments.ECDSA].begin_addr ) / 2; + // Note that the following call implies that n_steps is divisible by ECDSA_BUILTIN_RATIO. assert_nn_le(n_ecdsa_uses, n_ecdsa_copies); tempvar n_bitwise_copies = n_steps / BITWISE__RATIO; @@ -124,6 +127,7 @@ func public_input_validate{range_check_ptr}( public_input.segments[segments.BITWISE].stop_ptr - public_input.segments[segments.BITWISE].begin_addr ) / 5; + // Note that the following call implies that n_steps is divisible by BITWISE__RATIO. assert_nn_le(n_bitwise_uses, n_bitwise_copies); tempvar n_ec_op_copies = n_steps / EC_OP_BUILTIN_RATIO; @@ -131,6 +135,7 @@ func public_input_validate{range_check_ptr}( public_input.segments[segments.EC_OP].stop_ptr - public_input.segments[segments.EC_OP].begin_addr ) / 7; + // Note that the following call implies that n_steps is divisible by EC_OP_BUILTIN_RATIO. assert_nn_le(n_ec_op_uses, n_ec_op_copies); tempvar n_keccak_copies = n_steps / KECCAK__RATIO; @@ -138,6 +143,7 @@ func public_input_validate{range_check_ptr}( public_input.segments[segments.KECCAK].stop_ptr - public_input.segments[segments.KECCAK].begin_addr ) / 16; + // Note that the following call implies that n_steps is divisible by KECCAK__RATIO. assert_nn_le(n_keccak_uses, n_keccak_copies); tempvar n_poseidon_copies = n_steps / POSEIDON__RATIO; @@ -145,6 +151,8 @@ func public_input_validate{range_check_ptr}( public_input.segments[segments.POSEIDON].stop_ptr - public_input.segments[segments.POSEIDON].begin_addr ) / 6; + // Note that the following call implies that n_steps is divisible by POSEIDON__RATIO. assert_nn_le(n_poseidon_uses, n_poseidon_copies); + return (); } diff --git a/src/starkware/cairo/stark_verifier/air/layouts/dex/public_verify.cairo b/src/starkware/cairo/stark_verifier/air/layouts/dex/public_verify.cairo index 711648b3..450d8de0 100644 --- a/src/starkware/cairo/stark_verifier/air/layouts/dex/public_verify.cairo +++ b/src/starkware/cairo/stark_verifier/air/layouts/dex/public_verify.cairo @@ -85,6 +85,7 @@ func public_input_validate{range_check_ptr}( assert_nn(n_output_uses); assert public_input.n_segments = segments.N_SEGMENTS; + tempvar n_pedersen_copies = n_steps / PEDERSEN_BUILTIN_RATIO; tempvar n_pedersen_uses = ( public_input.segments[segments.PEDERSEN].stop_ptr - @@ -98,6 +99,7 @@ func public_input_validate{range_check_ptr}( public_input.segments[segments.RANGE_CHECK].stop_ptr - public_input.segments[segments.RANGE_CHECK].begin_addr ); + // Note that the following call implies that n_steps is divisible by RC_BUILTIN_RATIO. assert_nn_le(n_range_check_uses, n_range_check_copies); tempvar n_ecdsa_copies = n_steps / ECDSA_BUILTIN_RATIO; @@ -105,6 +107,8 @@ func public_input_validate{range_check_ptr}( public_input.segments[segments.ECDSA].stop_ptr - public_input.segments[segments.ECDSA].begin_addr ) / 2; + // Note that the following call implies that n_steps is divisible by ECDSA_BUILTIN_RATIO. assert_nn_le(n_ecdsa_uses, n_ecdsa_copies); + return (); } diff --git a/src/starkware/cairo/stark_verifier/air/layouts/recursive/public_verify.cairo b/src/starkware/cairo/stark_verifier/air/layouts/recursive/public_verify.cairo index aeda2e8d..cedb45a6 100644 --- a/src/starkware/cairo/stark_verifier/air/layouts/recursive/public_verify.cairo +++ b/src/starkware/cairo/stark_verifier/air/layouts/recursive/public_verify.cairo @@ -85,6 +85,7 @@ func public_input_validate{range_check_ptr}( assert_nn(n_output_uses); assert public_input.n_segments = segments.N_SEGMENTS; + tempvar n_pedersen_copies = n_steps / PEDERSEN_BUILTIN_RATIO; tempvar n_pedersen_uses = ( public_input.segments[segments.PEDERSEN].stop_ptr - @@ -98,6 +99,7 @@ func public_input_validate{range_check_ptr}( public_input.segments[segments.RANGE_CHECK].stop_ptr - public_input.segments[segments.RANGE_CHECK].begin_addr ); + // Note that the following call implies that n_steps is divisible by RC_BUILTIN_RATIO. assert_nn_le(n_range_check_uses, n_range_check_copies); tempvar n_bitwise_copies = n_steps / BITWISE__RATIO; @@ -105,6 +107,8 @@ func public_input_validate{range_check_ptr}( public_input.segments[segments.BITWISE].stop_ptr - public_input.segments[segments.BITWISE].begin_addr ) / 5; + // Note that the following call implies that n_steps is divisible by BITWISE__RATIO. assert_nn_le(n_bitwise_uses, n_bitwise_copies); + return (); } diff --git a/src/starkware/cairo/stark_verifier/air/layouts/small/public_verify.cairo b/src/starkware/cairo/stark_verifier/air/layouts/small/public_verify.cairo index 5a694998..7dcedfad 100644 --- a/src/starkware/cairo/stark_verifier/air/layouts/small/public_verify.cairo +++ b/src/starkware/cairo/stark_verifier/air/layouts/small/public_verify.cairo @@ -85,6 +85,7 @@ func public_input_validate{range_check_ptr}( assert_nn(n_output_uses); assert public_input.n_segments = segments.N_SEGMENTS; + tempvar n_pedersen_copies = n_steps / PEDERSEN_BUILTIN_RATIO; tempvar n_pedersen_uses = ( public_input.segments[segments.PEDERSEN].stop_ptr - @@ -98,6 +99,7 @@ func public_input_validate{range_check_ptr}( public_input.segments[segments.RANGE_CHECK].stop_ptr - public_input.segments[segments.RANGE_CHECK].begin_addr ); + // Note that the following call implies that n_steps is divisible by RC_BUILTIN_RATIO. assert_nn_le(n_range_check_uses, n_range_check_copies); tempvar n_ecdsa_copies = n_steps / ECDSA_BUILTIN_RATIO; @@ -105,6 +107,8 @@ func public_input_validate{range_check_ptr}( public_input.segments[segments.ECDSA].stop_ptr - public_input.segments[segments.ECDSA].begin_addr ) / 2; + // Note that the following call implies that n_steps is divisible by ECDSA_BUILTIN_RATIO. assert_nn_le(n_ecdsa_uses, n_ecdsa_copies); + return (); } diff --git a/src/starkware/cairo/stark_verifier/air/layouts/starknet/public_verify.cairo b/src/starkware/cairo/stark_verifier/air/layouts/starknet/public_verify.cairo index a9b4d0ab..c96faf0e 100644 --- a/src/starkware/cairo/stark_verifier/air/layouts/starknet/public_verify.cairo +++ b/src/starkware/cairo/stark_verifier/air/layouts/starknet/public_verify.cairo @@ -94,6 +94,7 @@ func public_input_validate{range_check_ptr}( assert_nn(n_output_uses); assert public_input.n_segments = segments.N_SEGMENTS; + tempvar n_pedersen_copies = n_steps / PEDERSEN_BUILTIN_RATIO; tempvar n_pedersen_uses = ( public_input.segments[segments.PEDERSEN].stop_ptr - @@ -107,6 +108,7 @@ func public_input_validate{range_check_ptr}( public_input.segments[segments.RANGE_CHECK].stop_ptr - public_input.segments[segments.RANGE_CHECK].begin_addr ); + // Note that the following call implies that n_steps is divisible by RC_BUILTIN_RATIO. assert_nn_le(n_range_check_uses, n_range_check_copies); tempvar n_ecdsa_copies = n_steps / ECDSA_BUILTIN_RATIO; @@ -114,6 +116,7 @@ func public_input_validate{range_check_ptr}( public_input.segments[segments.ECDSA].stop_ptr - public_input.segments[segments.ECDSA].begin_addr ) / 2; + // Note that the following call implies that n_steps is divisible by ECDSA_BUILTIN_RATIO. assert_nn_le(n_ecdsa_uses, n_ecdsa_copies); tempvar n_bitwise_copies = n_steps / BITWISE__RATIO; @@ -121,6 +124,7 @@ func public_input_validate{range_check_ptr}( public_input.segments[segments.BITWISE].stop_ptr - public_input.segments[segments.BITWISE].begin_addr ) / 5; + // Note that the following call implies that n_steps is divisible by BITWISE__RATIO. assert_nn_le(n_bitwise_uses, n_bitwise_copies); tempvar n_ec_op_copies = n_steps / EC_OP_BUILTIN_RATIO; @@ -128,6 +132,7 @@ func public_input_validate{range_check_ptr}( public_input.segments[segments.EC_OP].stop_ptr - public_input.segments[segments.EC_OP].begin_addr ) / 7; + // Note that the following call implies that n_steps is divisible by EC_OP_BUILTIN_RATIO. assert_nn_le(n_ec_op_uses, n_ec_op_copies); tempvar n_poseidon_copies = n_steps / POSEIDON__RATIO; @@ -135,6 +140,8 @@ func public_input_validate{range_check_ptr}( public_input.segments[segments.POSEIDON].stop_ptr - public_input.segments[segments.POSEIDON].begin_addr ) / 6; + // Note that the following call implies that n_steps is divisible by POSEIDON__RATIO. assert_nn_le(n_poseidon_uses, n_poseidon_copies); + return (); } diff --git a/src/starkware/cairo/stark_verifier/air/layouts/starknet_with_keccak/public_verify.cairo b/src/starkware/cairo/stark_verifier/air/layouts/starknet_with_keccak/public_verify.cairo index e02e4984..adc19637 100644 --- a/src/starkware/cairo/stark_verifier/air/layouts/starknet_with_keccak/public_verify.cairo +++ b/src/starkware/cairo/stark_verifier/air/layouts/starknet_with_keccak/public_verify.cairo @@ -97,6 +97,7 @@ func public_input_validate{range_check_ptr}( assert_nn(n_output_uses); assert public_input.n_segments = segments.N_SEGMENTS; + tempvar n_pedersen_copies = n_steps / PEDERSEN_BUILTIN_RATIO; tempvar n_pedersen_uses = ( public_input.segments[segments.PEDERSEN].stop_ptr - @@ -110,6 +111,7 @@ func public_input_validate{range_check_ptr}( public_input.segments[segments.RANGE_CHECK].stop_ptr - public_input.segments[segments.RANGE_CHECK].begin_addr ); + // Note that the following call implies that n_steps is divisible by RC_BUILTIN_RATIO. assert_nn_le(n_range_check_uses, n_range_check_copies); tempvar n_ecdsa_copies = n_steps / ECDSA_BUILTIN_RATIO; @@ -117,6 +119,7 @@ func public_input_validate{range_check_ptr}( public_input.segments[segments.ECDSA].stop_ptr - public_input.segments[segments.ECDSA].begin_addr ) / 2; + // Note that the following call implies that n_steps is divisible by ECDSA_BUILTIN_RATIO. assert_nn_le(n_ecdsa_uses, n_ecdsa_copies); tempvar n_bitwise_copies = n_steps / BITWISE__RATIO; @@ -124,6 +127,7 @@ func public_input_validate{range_check_ptr}( public_input.segments[segments.BITWISE].stop_ptr - public_input.segments[segments.BITWISE].begin_addr ) / 5; + // Note that the following call implies that n_steps is divisible by BITWISE__RATIO. assert_nn_le(n_bitwise_uses, n_bitwise_copies); tempvar n_ec_op_copies = n_steps / EC_OP_BUILTIN_RATIO; @@ -131,6 +135,7 @@ func public_input_validate{range_check_ptr}( public_input.segments[segments.EC_OP].stop_ptr - public_input.segments[segments.EC_OP].begin_addr ) / 7; + // Note that the following call implies that n_steps is divisible by EC_OP_BUILTIN_RATIO. assert_nn_le(n_ec_op_uses, n_ec_op_copies); tempvar n_keccak_copies = n_steps / KECCAK__RATIO; @@ -138,6 +143,7 @@ func public_input_validate{range_check_ptr}( public_input.segments[segments.KECCAK].stop_ptr - public_input.segments[segments.KECCAK].begin_addr ) / 16; + // Note that the following call implies that n_steps is divisible by KECCAK__RATIO. assert_nn_le(n_keccak_uses, n_keccak_copies); tempvar n_poseidon_copies = n_steps / POSEIDON__RATIO; @@ -145,6 +151,8 @@ func public_input_validate{range_check_ptr}( public_input.segments[segments.POSEIDON].stop_ptr - public_input.segments[segments.POSEIDON].begin_addr ) / 6; + // Note that the following call implies that n_steps is divisible by POSEIDON__RATIO. assert_nn_le(n_poseidon_uses, n_poseidon_copies); + return (); } diff --git a/src/starkware/eth/eth_test_utils.py b/src/starkware/eth/eth_test_utils.py index cbc7605e..376627d7 100644 --- a/src/starkware/eth/eth_test_utils.py +++ b/src/starkware/eth/eth_test_utils.py @@ -20,7 +20,7 @@ TIMEOUT_FOR_WEB3_REQUESTS = 120 # Seconds. # Max number of attempts to check web3.isConnected(). -GANACHE_MAX_TRIES = 100 +GANACHE_MAX_TRIES = 60 logger = logging.getLogger(__name__) Abi = List[dict] @@ -130,7 +130,7 @@ def __init__(self): ) for i in range(GANACHE_MAX_TRIES): - time.sleep(0.1) + time.sleep(1) if self.w3.isConnected(): break else: diff --git a/src/starkware/starknet/CMakeLists.txt b/src/starkware/starknet/CMakeLists.txt index 2c4afbe8..5d7c76e6 100644 --- a/src/starkware/starknet/CMakeLists.txt +++ b/src/starkware/starknet/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(builtins) add_subdirectory(business_logic) add_subdirectory(cli) add_subdirectory(common) diff --git a/src/starkware/starknet/builtins/CMakeLists.txt b/src/starkware/starknet/builtins/CMakeLists.txt new file mode 100644 index 00000000..ea1e3754 --- /dev/null +++ b/src/starkware/starknet/builtins/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(segment_arena) diff --git a/src/starkware/starknet/builtins/segment_arena/CMakeLists.txt b/src/starkware/starknet/builtins/segment_arena/CMakeLists.txt new file mode 100644 index 00000000..0a6d039b --- /dev/null +++ b/src/starkware/starknet/builtins/segment_arena/CMakeLists.txt @@ -0,0 +1,27 @@ +python_lib(segment_arena_builtin_lib + PREFIX starkware/starknet/builtins/segment_arena + + FILES + segment_arena_builtin_runner.py + + LIBS + cairo_vm_lib +) + +full_python_test(segment_arena_test + PREFIX starkware/starknet/builtins/segment_arena + PYTHON ${PYTHON_COMMAND} + TESTED_MODULES starkware/starknet/builtins/segment_arena + + FILES + segment_arena.cairo + segment_arena_test.cairo + segment_arena_test.py + + LIBS + cairo_common_lib + cairo_function_runner_lib + cairo_constants_lib + cairo_compile_lib + pip_pytest +) diff --git a/src/starkware/starknet/builtins/segment_arena/segment_arena.cairo b/src/starkware/starknet/builtins/segment_arena/segment_arena.cairo new file mode 100644 index 00000000..a23340e7 --- /dev/null +++ b/src/starkware/starknet/builtins/segment_arena/segment_arena.cairo @@ -0,0 +1,90 @@ +from starkware.cairo.common.alloc import alloc +from starkware.cairo.common.segments import relocate_segment + +// The segment arena builtin allows Sierra libfuncs to allocate memory segments and only track their +// ends (rather than both the start and the end). When the segment is finalized, the arena can +// provide the start pointer that corresponds to the segment (given its end). +// +// The builtin should be used as follows: +// * Every segment must be allocated and finalized exactly once. This can be achieved by using +// linear types (the object holding the segment must not be duplicatable nor droppable). +// * Segment allocation: +// * Allocates a new segment and updates `SegmentInfo::start` with the pointer. +// * The segment should be temporary if `n_segments > 0`. +// * Increases `n_segments` by 1. +// * Segment finalization: +// * Guesses the index of the segment, in the range [0, n_segments). +// * Writes the end of the segment in `SegmentInfo::end`. +// * Copies the current value of `n_finalized` to `SegmentInfo::finalization_index`. +// * Increases `n_finalized` by 1. +// * Checks that the segment size is nonnegative (range-check on `end - start`). + +// Represents the information about a single segment allocated by the arena. +struct SegmentInfo { + // A pointer to the first element of this segment. + start: felt*, + // A pointer to the end of this segment (the first unused element). + end: felt*, + // A sequential id, assigned to the segment when it is finalized. + // This value is used to guarantee that 'end' is not assigned twice. + finalization_index: felt, +} + +// Represents the status of the segment arena. +struct SegmentArenaBuiltin { + // A pointer to a list of SegmentInfo. infos[i] contains information about the i-th segment + // (ordered by construction). + // The value is fixed during the execution of an entry point. + infos: SegmentInfo*, + // The number of segments that were created so far. + n_segments: felt, + // The number of segments that were finalized so far. + n_finalized: felt, +} + +// Constructs a new segment for the segment arena builtin and initializes it with an empty instance +// of `SegmentArenaBuiltin`. +func new_arena() -> SegmentArenaBuiltin* { + let (segment_arena: SegmentArenaBuiltin*) = alloc(); + assert segment_arena[0] = SegmentArenaBuiltin( + infos=cast(nondet %{ segments.add() %}, SegmentInfo*), n_segments=0, n_finalized=0 + ); + return &segment_arena[1]; +} + +// Validates the segment arena builtin. +// +// In particular, relocates the temporary segments such that the start of segment i is strictly +// larger than the end of segment i+1. +func validate_segment_arena(segment_arena: SegmentArenaBuiltin*) { + tempvar n_segments = segment_arena.n_segments; + tempvar n_finalized = segment_arena.n_finalized; + // The following line should follow from the fact that every allocated segment + // must be finalized exactly once. + // We keep it both as a sanity check and since Sierra compilation is not proven yet. + assert n_segments = n_finalized; + + if (n_segments == 0) { + return (); + } + + // The following call also implies that n_segments > 0. + _verify_continuity(infos=segment_arena.infos, n_segments_minus_one=n_segments - 1); + return (); +} + +// Helper function for validate_segment_arena. +func _verify_continuity(infos: SegmentInfo*, n_segments_minus_one: felt) { + if (n_segments_minus_one == 0) { + // If there is only one segment left, there is no need to check anything. + return (); + } + + // Enforce an empty cell between two consecutive segments so that the start of a segment + // is strictly bigger than the end of the previous segment. + // This is required for proving the soundness of this construction, in the case where a segment + // has length zero. + relocate_segment(infos[1].start, infos[0].end + 1); + + return _verify_continuity(infos=&infos[1], n_segments_minus_one=n_segments_minus_one - 1); +} diff --git a/src/starkware/starknet/builtins/segment_arena/segment_arena_builtin_runner.py b/src/starkware/starknet/builtins/segment_arena/segment_arena_builtin_runner.py new file mode 100644 index 00000000..05e0f89d --- /dev/null +++ b/src/starkware/starknet/builtins/segment_arena/segment_arena_builtin_runner.py @@ -0,0 +1,39 @@ +from starkware.cairo.lang.vm.builtin_runner import SimpleBuiltinRunner + +ARENA_BUILTIN_SIZE = 3 +# The size of the builtin segment at the time of its creation. +INITIAL_SEGMENT_SIZE = ARENA_BUILTIN_SIZE + + +class SegmentArenaBuiltinRunner(SimpleBuiltinRunner): + def __init__(self, included: bool): + super().__init__( + name="segment_arena", + included=included, + ratio=None, + cells_per_instance=ARENA_BUILTIN_SIZE, + n_input_cells=ARENA_BUILTIN_SIZE, + ) + + def initialize_segments(self, runner): + infos = runner.segments.add() + # Initiate the segment with an empty SegmentArenaBuiltin. + initial_values = [ + infos, + 0, # n_segments. + 0, # n_finalized. + ] + assert len(initial_values) == INITIAL_SEGMENT_SIZE + segment_start = runner.segments.gen_arg(initial_values) + self._base = segment_start + INITIAL_SEGMENT_SIZE + + def get_used_cells(self, runner): + used = runner.segments.get_segment_used_size(self.base.segment_index) + # The value returned from `get_segment_used_size` includes the initial values that were + # written by `initialize_segments`. We reduce it, since the result of the function should + # not include them. + assert used >= INITIAL_SEGMENT_SIZE + return used - INITIAL_SEGMENT_SIZE + + def get_memory_accesses(self, runner): + return {} diff --git a/src/starkware/starknet/builtins/segment_arena/segment_arena_test.cairo b/src/starkware/starknet/builtins/segment_arena/segment_arena_test.cairo new file mode 100644 index 00000000..56a76fcf --- /dev/null +++ b/src/starkware/starknet/builtins/segment_arena/segment_arena_test.cairo @@ -0,0 +1,94 @@ +from starkware.cairo.common.alloc import alloc +from starkware.starknet.builtins.segment_arena.segment_arena import ( + SegmentArenaBuiltin, + SegmentInfo, + new_arena, + validate_segment_arena, +) + +// Creates a new segment using the segment arena. +func new_segment{segment_arena: SegmentArenaBuiltin*}() -> felt* { + let prev_segment_arena = &segment_arena[-1]; + tempvar n_segments = prev_segment_arena.n_segments; + tempvar infos = prev_segment_arena.infos; + + %{ + if 'segment_index_to_arena_index' not in globals(): + # A map from the relocatable value segment index to the index in the arena. + segment_index_to_arena_index = {} + + # The segment is placed at the end of the arena. + index = ids.n_segments + + # Create a segment or a temporary segment. + start = segments.add_temp_segment() if index > 0 else segments.add() + + # Update 'SegmentInfo::start' and 'segment_index_to_arena_index'. + ids.prev_segment_arena.infos[index].start = start + segment_index_to_arena_index[start.segment_index] = index + %} + assert segment_arena[0] = SegmentArenaBuiltin( + infos=infos, n_segments=n_segments + 1, n_finalized=prev_segment_arena.n_finalized + ); + let segment_arena = &segment_arena[1]; + return infos[n_segments].start; +} + +// Finalizes a given segment and returns the corresponding start. +func finalize_segment{segment_arena: SegmentArenaBuiltin*}(segment_end: felt*) -> felt* { + let prev_segment_arena = &segment_arena[-1]; + tempvar n_segments = prev_segment_arena.n_segments; + tempvar n_finalized = prev_segment_arena.n_finalized; + + // Guess the index of the segment. + tempvar index = nondet %{ segment_index_to_arena_index[ids.segment_end.segment_index] %}; + + // Write segment_end in the manager. + tempvar infos: SegmentInfo* = prev_segment_arena.infos; + tempvar segment_info: SegmentInfo* = &infos[index]; + // Writing n_finalized to 'finalization_index' guarantees 'segment_info.end' was not assigned + // a value before. + assert segment_info.finalization_index = n_finalized; + assert segment_info.end = segment_end; + + assert segment_arena[0] = SegmentArenaBuiltin( + infos=infos, n_segments=n_segments, n_finalized=n_finalized + 1 + ); + + let segment_arena = &segment_arena[1]; + return segment_info.start; +} + +func test_segment_arena() -> (felt*, SegmentInfo*) { + alloc_locals; + + local segment_arena_start: SegmentArenaBuiltin* = new_arena(); + let segment_arena = segment_arena_start; + + with segment_arena { + let segment0 = new_segment(); + let segment1 = new_segment(); + let segment2 = new_segment(); + + assert segment0[0] = 1; + assert segment0[1] = 2; + + assert segment1[0] = 3; + assert segment1[1] = 4; + + assert segment2[0] = 5; + + assert finalize_segment(segment0 + 2) = segment0; + assert finalize_segment(segment1 + 2) = segment1; + + let segment3 = new_segment(); + + assert segment3[0] = 6; + assert segment3[1] = 7; + + assert finalize_segment(segment3 + 2) = segment3; + assert finalize_segment(segment2 + 1) = segment2; + } + validate_segment_arena(segment_arena=&segment_arena[-1]); + return (segment0, segment_arena_start[-1].infos); +} diff --git a/src/starkware/starknet/builtins/segment_arena/segment_arena_test.py b/src/starkware/starknet/builtins/segment_arena/segment_arena_test.py new file mode 100644 index 00000000..81cc83c0 --- /dev/null +++ b/src/starkware/starknet/builtins/segment_arena/segment_arena_test.py @@ -0,0 +1,44 @@ +import os + +import pytest + +from starkware.cairo.common.cairo_function_runner import CairoFunctionRunner +from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME +from starkware.cairo.lang.compiler.cairo_compile import compile_cairo_files +from starkware.cairo.lang.compiler.program import Program + +CAIRO_TEST_FILE = os.path.join(os.path.dirname(__file__), "segment_arena_test.cairo") + + +@pytest.fixture(scope="session") +def program() -> Program: + return compile_cairo_files([CAIRO_TEST_FILE], prime=DEFAULT_PRIME, debug_info=True) + + +def test_dict_simple(program): + runner = CairoFunctionRunner(program) + runner.run("test_segment_arena") + (concat_segments, infos) = runner.get_return_values(2) + + concat_segments_data = [runner.vm_memory.get(concat_segments + i) for i in range(10)] + assert concat_segments_data == [1, 2, None, 3, 4, None, 5, None, 6, 7] + + infos_data = [runner.vm_memory.get(infos + i) for i in range(12)] + assert infos_data == [ + # segment0. + concat_segments, + concat_segments + 2, + 0, + # segment1. + concat_segments + 3, + concat_segments + 5, + 1, + # segment2. + concat_segments + 6, + concat_segments + 7, + 3, + # segment3. + concat_segments + 8, + concat_segments + 10, + 2, + ] diff --git a/src/starkware/starknet/business_logic/execution/CMakeLists.txt b/src/starkware/starknet/business_logic/execution/CMakeLists.txt index d669572c..789b4bcb 100644 --- a/src/starkware/starknet/business_logic/execution/CMakeLists.txt +++ b/src/starkware/starknet/business_logic/execution/CMakeLists.txt @@ -49,6 +49,7 @@ python_lib(starknet_execute_entry_point_lib cairo_relocatable_lib cairo_vm_lib everest_definitions_lib + segment_arena_builtin_lib starknet_abi_lib starknet_business_logic_fact_state_lib starknet_business_logic_state_lib diff --git a/src/starkware/starknet/business_logic/execution/execute_entry_point.py b/src/starkware/starknet/business_logic/execution/execute_entry_point.py index 69298a70..69f3505b 100644 --- a/src/starkware/starknet/business_logic/execution/execute_entry_point.py +++ b/src/starkware/starknet/business_logic/execution/execute_entry_point.py @@ -1,4 +1,5 @@ import asyncio +import dataclasses import functools import logging from typing import Any, Dict, List, Optional, Union, cast @@ -11,6 +12,9 @@ from starkware.cairo.lang.vm.utils import ResourcesError, RunResources from starkware.cairo.lang.vm.vm_exceptions import HintException, VmException, VmExceptionBase from starkware.python.utils import as_non_optional +from starkware.starknet.builtins.segment_arena.segment_arena_builtin_runner import ( + SegmentArenaBuiltinRunner, +) from starkware.starknet.business_logic.execution.execute_entry_point_base import ( ExecuteEntryPointBase, ) @@ -37,6 +41,7 @@ DeprecatedBlSyscallHandler, ) from starkware.starknet.definitions import fields +from starkware.starknet.definitions.constants import GasCost from starkware.starknet.definitions.error_codes import StarknetErrorCode from starkware.starknet.definitions.general_config import ( STARKNET_LAYOUT_INSTANCE, @@ -101,7 +106,7 @@ def create_for_testing( entry_point_selector: int, entry_point_type: Optional[EntryPointType] = None, caller_address: int = 0, - initial_gas: int = 0, + initial_gas: int = GasCost.INITIAL.value, call_type: Optional[CallType] = None, class_hash: Optional[int] = None, ): @@ -206,6 +211,7 @@ def execute( resources_manager=resources_manager, general_config=general_config, tx_execution_context=tx_execution_context, + support_reverted=support_reverted, ) if support_reverted or call_info.failure_flag == 0: return call_info @@ -226,10 +232,17 @@ def _execute( resources_manager: ExecutionResourcesManager, general_config: StarknetGeneralConfig, tx_execution_context: TransactionExecutionContext, + support_reverted: bool, ) -> CallInfo: # Fix the current resources usage, in order to calculate the usage of this run at the end. previous_cairo_usage = resources_manager.cairo_usage + # Create a dummy layout. + layout = dataclasses.replace( + STARKNET_LAYOUT_INSTANCE, + builtins={**STARKNET_LAYOUT_INSTANCE.builtins, "segment_arena": {}}, + ) + # Prepare runner. entry_point = self._get_selected_entry_point( compiled_class=compiled_class, class_hash=class_hash @@ -239,7 +252,13 @@ def _execute( ) with wrap_with_stark_exception(code=StarknetErrorCode.SECURITY_ERROR): runner = CairoFunctionRunner( - program=program, layout=STARKNET_LAYOUT_INSTANCE.layout_name + program=program, + layout=layout, + additional_builtin_factories=dict( + segment_arena=lambda name, included: SegmentArenaBuiltinRunner( + included=included + ) + ), ) # Prepare implicit arguments. @@ -249,11 +268,13 @@ def _execute( initial_syscall_ptr = cast(RelocatableValue, implicit_args[-1]) syscall_handler = BusinessLogicSyscallHandler( state=state, + resources_manager=resources_manager, segments=runner.segments, tx_execution_context=tx_execution_context, initial_syscall_ptr=initial_syscall_ptr, - caller_address=self.caller_address, - contract_address=self.contract_address, + general_config=general_config, + entry_point=self, + support_reverted=support_reverted, ) # Load the builtin costs; Cairo 1.0 programs are expected to end with a `ret` opcode @@ -285,6 +306,7 @@ def _execute( hint_locals={"syscall_handler": syscall_handler}, run_resources=tx_execution_context.run_resources, program_segment_size=len(runner.program.data) + len(program_extra_data), + allow_tmp_segments=True, ) # We should not count (possibly) unsued code as holes. @@ -308,7 +330,7 @@ def _execute( result=get_call_result(runner=runner, initial_gas=self.initial_gas), events=syscall_handler.events, l2_to_l1_messages=[], - internal_calls=[], + internal_calls=syscall_handler.internal_calls, ) def _run( @@ -318,6 +340,7 @@ def _run( entry_point_args: EntryPointArgs, hint_locals: Dict[str, Any], run_resources: RunResources, + allow_tmp_segments: bool, program_segment_size: Optional[int] = None, ): """ @@ -340,6 +363,7 @@ def _run( run_resources=run_resources, verify_secure=True, program_segment_size=program_segment_size, + allow_tmp_segments=allow_tmp_segments, ) except VmException as exception: code: ErrorCode = StarknetErrorCode.TRANSACTION_FAILED @@ -545,6 +569,7 @@ def _execute_version0_class( entry_point_args=entry_point_args, hint_locals={"syscall_handler": syscall_handler}, run_resources=tx_execution_context.run_resources, + allow_tmp_segments=False, ) # Complete validations. diff --git a/src/starkware/starknet/business_logic/execution/execute_entry_point_base.py b/src/starkware/starknet/business_logic/execution/execute_entry_point_base.py index 86d16299..e77b1289 100644 --- a/src/starkware/starknet/business_logic/execution/execute_entry_point_base.py +++ b/src/starkware/starknet/business_logic/execution/execute_entry_point_base.py @@ -45,6 +45,7 @@ def execute( resources_manager: ExecutionResourcesManager, tx_execution_context: TransactionExecutionContext, general_config: StarknetGeneralConfig, + support_reverted: bool = False, ) -> CallInfo: """ Executes the entry point. diff --git a/src/starkware/starknet/business_logic/execution/os_resources.json b/src/starkware/starknet/business_logic/execution/os_resources.json index c4114fbb..908475e1 100644 --- a/src/starkware/starknet/business_logic/execution/os_resources.json +++ b/src/starkware/starknet/business_logic/execution/os_resources.json @@ -2,32 +2,32 @@ "execute_syscalls": { "call_contract": { "builtin_instance_counter": { - "range_check_builtin": 18 + "range_check_builtin": 19 }, "n_memory_holes": 0, - "n_steps": 617 + "n_steps": 677 }, "delegate_call": { "builtin_instance_counter": { - "range_check_builtin": 18 + "range_check_builtin": 19 }, "n_memory_holes": 0, - "n_steps": 639 + "n_steps": 699 }, "delegate_l1_handler": { "builtin_instance_counter": { - "range_check_builtin": 14 + "range_check_builtin": 15 }, "n_memory_holes": 0, - "n_steps": 618 + "n_steps": 678 }, "deploy": { "builtin_instance_counter": { "pedersen_builtin": 7, - "range_check_builtin": 17 + "range_check_builtin": 18 }, "n_memory_holes": 0, - "n_steps": 864 + "n_steps": 926 }, "emit_event": { "builtin_instance_counter": {}, @@ -37,27 +37,27 @@ "get_block_number": { "builtin_instance_counter": {}, "n_memory_holes": 0, - "n_steps": 39 + "n_steps": 40 }, "get_block_timestamp": { "builtin_instance_counter": {}, "n_memory_holes": 0, - "n_steps": 37 + "n_steps": 38 }, "get_caller_address": { "builtin_instance_counter": {}, "n_memory_holes": 0, - "n_steps": 31 + "n_steps": 32 }, "get_contract_address": { "builtin_instance_counter": {}, "n_memory_holes": 0, - "n_steps": 35 + "n_steps": 36 }, "get_sequencer_address": { "builtin_instance_counter": {}, "n_memory_holes": 0, - "n_steps": 33 + "n_steps": 34 }, "get_tx_info": { "builtin_instance_counter": {}, @@ -71,47 +71,47 @@ }, "library_call": { "builtin_instance_counter": { - "range_check_builtin": 18 + "range_check_builtin": 19 }, "n_memory_holes": 0, - "n_steps": 607 + "n_steps": 666 }, "library_call_l1_handler": { "builtin_instance_counter": { - "range_check_builtin": 14 + "range_check_builtin": 15 }, "n_memory_holes": 0, - "n_steps": 586 + "n_steps": 645 }, "replace_class": { "builtin_instance_counter": {}, "n_memory_holes": 0, - "n_steps": 72 + "n_steps": 73 }, "send_message_to_l1": { "builtin_instance_counter": {}, "n_memory_holes": 0, - "n_steps": 83 + "n_steps": 84 }, "storage_read": { "builtin_instance_counter": {}, "n_memory_holes": 0, - "n_steps": 43 + "n_steps": 44 }, "storage_write": { "builtin_instance_counter": {}, "n_memory_holes": 0, - "n_steps": 45 + "n_steps": 46 } }, "execute_txs_inner": { "DECLARE": { "builtin_instance_counter": { "pedersen_builtin": 15, - "range_check_builtin": 59 + "range_check_builtin": 63 }, "n_memory_holes": 0, - "n_steps": 2508 + "n_steps": 2676 }, "DEPLOY": { "builtin_instance_counter": {}, @@ -121,26 +121,26 @@ "DEPLOY_ACCOUNT": { "builtin_instance_counter": { "pedersen_builtin": 23, - "range_check_builtin": 77 + "range_check_builtin": 83 }, "n_memory_holes": 0, - "n_steps": 3327 + "n_steps": 3577 }, "INVOKE_FUNCTION": { "builtin_instance_counter": { "pedersen_builtin": 16, - "range_check_builtin": 73 + "range_check_builtin": 80 }, "n_memory_holes": 0, - "n_steps": 3073 + "n_steps": 3323 }, "L1_HANDLER": { "builtin_instance_counter": { "pedersen_builtin": 11, - "range_check_builtin": 14 + "range_check_builtin": 17 }, "n_memory_holes": 0, - "n_steps": 953 + "n_steps": 1054 } } } diff --git a/src/starkware/starknet/business_logic/fact_state/contract_class_objects.py b/src/starkware/starknet/business_logic/fact_state/contract_class_objects.py index e67e1dd6..9689ec3a 100644 --- a/src/starkware/starknet/business_logic/fact_state/contract_class_objects.py +++ b/src/starkware/starknet/business_logic/fact_state/contract_class_objects.py @@ -2,6 +2,7 @@ import marshmallow_dataclass +from starkware.cairo.lang.vm.crypto import poseidon_hash_func from starkware.python.utils import to_bytes from starkware.starknet.core.os.contract_class.class_hash import compute_class_hash from starkware.starknet.core.os.contract_class.compiled_class_hash import ( @@ -11,6 +12,7 @@ compute_deprecated_class_hash, ) from starkware.starknet.definitions import fields +from starkware.starknet.definitions.constants import CONTRACT_CLASS_LEAF_VERSION from starkware.starknet.services.api.contract_class.contract_class import ( CompiledClass, ContractClass, @@ -19,7 +21,17 @@ from starkware.starkware_utils.commitment_tree.leaf_fact import LeafFact from starkware.starkware_utils.commitment_tree.patricia_tree.nodes import EmptyNodeFact from starkware.starkware_utils.validated_dataclass import ValidatedMarshmallowDataclass -from starkware.storage.storage import Fact, HashFunctionType +from starkware.storage.storage import Fact, FactFetchingContext, HashFunctionType + + +def get_ffc_for_contract_class_facts(ffc: FactFetchingContext) -> FactFetchingContext: + """ + Replaces the given FactFetchingContext object with a corresponding one used for· + fetching contract class facts. + """ + return FactFetchingContext( + storage=ffc.storage, hash_func=poseidon_hash_func, n_workers=ffc.n_workers + ) @marshmallow_dataclass.dataclass(frozen=True) @@ -92,9 +104,7 @@ def _hash(self, hash_func: HashFunctionType) -> bytes: if self.is_empty: return EmptyNodeFact.EMPTY_NODE_HASH - CONTRACT_CLASS_HASH_VERSION = b"CONTRACT_CLASS_LEAF_V0" - - # Return H(CONTRACT_CLASS_HASH_VERSION, compiled_class_hash). - hash_value = hash_func(CONTRACT_CLASS_HASH_VERSION, to_bytes(self.compiled_class_hash)) + # Return H(CONTRACT_CLASS_LEAF_VERSION, compiled_class_hash). + hash_value = hash_func(CONTRACT_CLASS_LEAF_VERSION, to_bytes(self.compiled_class_hash)) return hash_value diff --git a/src/starkware/starknet/business_logic/fact_state/patricia_state.py b/src/starkware/starknet/business_logic/fact_state/patricia_state.py index 9e5e629c..41319b74 100644 --- a/src/starkware/starknet/business_logic/fact_state/patricia_state.py +++ b/src/starkware/starknet/business_logic/fact_state/patricia_state.py @@ -1,11 +1,11 @@ from typing import Dict, Optional -from starkware.cairo.lang.vm.crypto import poseidon_hash_func from starkware.python.utils import from_bytes, to_bytes from starkware.starknet.business_logic.fact_state.contract_class_objects import ( CompiledClassFact, ContractClassLeaf, DeprecatedCompiledClassFact, + get_ffc_for_contract_class_facts, ) from starkware.starknet.business_logic.fact_state.contract_state_objects import ContractState from starkware.starknet.business_logic.state.state_api import ( @@ -37,9 +37,7 @@ def __init__( ): # Members related to dynamic retrieval of facts during transaction execution. self.ffc = ffc - self.ffc_for_class_hash = FactFetchingContext( - storage=ffc.storage, hash_func=poseidon_hash_func, n_workers=ffc.n_workers - ) + self.ffc_for_class_hash = get_ffc_for_contract_class_facts(ffc=ffc) self.contract_class_storage = contract_class_storage # Last committed state roots. diff --git a/src/starkware/starknet/business_logic/fact_state/state.py b/src/starkware/starknet/business_logic/fact_state/state.py index 2e3f4bfb..00d11998 100644 --- a/src/starkware/starknet/business_logic/fact_state/state.py +++ b/src/starkware/starknet/business_logic/fact_state/state.py @@ -11,7 +11,7 @@ StateSelectorBase, ) from starkware.cairo.lang.vm.cairo_pie import ExecutionResources -from starkware.cairo.lang.vm.crypto import poseidon_hash_func, poseidon_hash_many +from starkware.cairo.lang.vm.crypto import poseidon_hash_many from starkware.python.utils import ( from_bytes, gather_in_chunks, @@ -19,7 +19,10 @@ subtract_mappings, to_bytes, ) -from starkware.starknet.business_logic.fact_state.contract_class_objects import ContractClassLeaf +from starkware.starknet.business_logic.fact_state.contract_class_objects import ( + ContractClassLeaf, + get_ffc_for_contract_class_facts, +) from starkware.starknet.business_logic.fact_state.contract_state_objects import ( ContractCarriedState, ContractState, @@ -389,9 +392,7 @@ async def apply_updates( ffc=ffc, modifications=list(safe_zip(accessed_addresses, updated_contract_states)) ) - ffc_for_contract_class = FactFetchingContext( - storage=ffc.storage, hash_func=poseidon_hash_func, n_workers=ffc.n_workers - ) + ffc_for_contract_class = get_ffc_for_contract_class_facts(ffc=ffc) updated_contract_classes: Optional[PatriciaTree] = None if self.contract_classes is not None: updated_contract_classes = await self.contract_classes.update( diff --git a/src/starkware/starknet/business_logic/state/state.py b/src/starkware/starknet/business_logic/state/state.py index 4953d255..091569a4 100644 --- a/src/starkware/starknet/business_logic/state/state.py +++ b/src/starkware/starknet/business_logic/state/state.py @@ -497,12 +497,31 @@ def set_class_hash_at(self, contract_address: int, class_hash: int): self.cache._class_hash_writes[contract_address] = class_hash self.state.set_class_hash_at(contract_address=contract_address, class_hash=class_hash) + def get_nonce_at(self, contract_address: int) -> int: + # Delegate the request to the actual state anyway (even if the value is already cached). + nonce = self.state.get_nonce_at(contract_address=contract_address) + if contract_address not in self.cache.address_to_class_hash: + self.cache._nonce_initial_values[contract_address] = nonce + + return nonce + + def increment_nonce(self, contract_address: int): + if contract_address not in self.cache.address_to_nonce: + # First access (read or write) to this cell; cache initial value. + self.cache._nonce_initial_values[contract_address] = self.state.get_nonce_at( + contract_address=contract_address + ) + + self.state.increment_nonce(contract_address=contract_address) + new_nonce = self.state.get_nonce_at(contract_address=contract_address) + self.cache._nonce_writes[contract_address] = new_nonce + @property def block_info(self) -> BlockInfo: return self.state.block_info def update_block_info(self, block_info: BlockInfo): - return self.state.update_block_info(block_info=block_info) + self.state.update_block_info(block_info=block_info) def get_compiled_class(self, compiled_class_hash: int) -> CompiledClassBase: return self.state.get_compiled_class(compiled_class_hash=compiled_class_hash) @@ -510,36 +529,53 @@ def get_compiled_class(self, compiled_class_hash: int) -> CompiledClassBase: def get_compiled_class_hash(self, class_hash: int) -> int: return self.state.get_compiled_class_hash(class_hash=class_hash) - def get_nonce_at(self, contract_address: int) -> int: - return self.state.get_nonce_at(contract_address=contract_address) - def set_compiled_class_hash(self, class_hash: int, compiled_class_hash: int): self.state.set_compiled_class_hash( class_hash=class_hash, compiled_class_hash=compiled_class_hash ) - def increment_nonce(self, contract_address: int): - self.state.increment_nonce(contract_address=contract_address) - - def count_actual_storage_changes(self) -> Tuple[int, int]: + def count_actual_updates(self) -> Tuple[int, int, int, int]: """ - Returns the number of storage changes done through this state, and the number of modified - contracts, where a contract is considered as modified if one or more of its storage cells - has changed. + Returns a tuple of: + 1. The number of modified contracts. + 2. The number of storage updates. + 3. The number of class hash updates. + 4. The number of nonce updates. + + An update is any a change done through this state; A contract is considered + modified if its nonce was updated, if its class hash was updated or + if one of its storage cells has changed. """ + # Storage Update. storage_updates = subtract_mappings( self.cache._storage_writes, self.cache._storage_initial_values ) - modified_contracts = { + contracts_with_modified_storage = { contract_address for (contract_address, _key) in storage_updates.keys() } - return (len(modified_contracts), len(storage_updates)) - def count_actual_class_updates(self) -> int: - """ - Returns the number of class hashes updated through this state. - """ + # Class hash Update. class_hash_updates = subtract_mappings( self.cache._class_hash_writes, self.cache._class_hash_initial_values ) - return len(class_hash_updates) + contracts_with_modified_class_hash = set(class_hash_updates.keys()) + + # Nonce Update. + nonce_updates = subtract_mappings( + self.cache._nonce_writes, self.cache._nonce_initial_values + ) + contracts_with_modified_nonce = set(nonce_updates.keys()) + + # Modified contracts. + modified_contracts = ( + contracts_with_modified_storage + | contracts_with_modified_class_hash + | contracts_with_modified_nonce + ) + + return ( + len(modified_contracts), + len(storage_updates), + len(class_hash_updates), + len(nonce_updates), + ) diff --git a/src/starkware/starknet/business_logic/state/state_api.py b/src/starkware/starknet/business_logic/state/state_api.py index 5c00d7e9..1e32671f 100644 --- a/src/starkware/starknet/business_logic/state/state_api.py +++ b/src/starkware/starknet/business_logic/state/state_api.py @@ -170,7 +170,7 @@ def get_compiled_class_by_class_hash(self, class_hash: int) -> CompiledClassBase compiled_class = self.get_compiled_class(compiled_class_hash=class_hash) assert isinstance( compiled_class, DeprecatedCompiledClass - ), "Class of version > 0 must be committed." + ), "Expected class hash; got compiled class hash." return compiled_class diff --git a/src/starkware/starknet/business_logic/transaction/objects.py b/src/starkware/starknet/business_logic/transaction/objects.py index edbcbf74..2a6de3a7 100644 --- a/src/starkware/starknet/business_logic/transaction/objects.py +++ b/src/starkware/starknet/business_logic/transaction/objects.py @@ -279,6 +279,7 @@ class InternalAccountTransaction(InternalTransaction): # The address of the account contract who sent the transaction. sender_address: int = field(metadata=fields.contract_address_metadata) + # Forbid by default query-version transactions. only_query: ClassVar[bool] = False @property @@ -457,44 +458,46 @@ def verify_version(self): verify_version( version=self.version, expected_version=constants.DECLARE_VERSION, - only_query=False, + only_query=self.only_query, old_supported_versions=[0, 1], ) if self.version not in constants.DEPRECATED_DECLARE_VERSIONS: assert ( self.compiled_class_hash is not None - ), "The compiled_class_hash field must not be None." - - if self.version not in [0, constants.QUERY_VERSION_BASE]: - return + ), "The compiled_class_hash field must not be None for Cairo 1.0 declare." + else: + assert ( + self.compiled_class_hash is None + ), "The compiled_class_hash field must be None for deprecated declare." - stark_assert_eq( - DEFAULT_DECLARE_SENDER_ADDRESS, - self.sender_address, - code=StarknetErrorCode.OUT_OF_RANGE_CONTRACT_ADDRESS, - message=( - "The sender_address field in Declare transactions of version 0 " - f"must be {DEFAULT_DECLARE_SENDER_ADDRESS}." - ), - ) - stark_assert_eq( - 0, - self.max_fee, - code=StarknetErrorCode.OUT_OF_RANGE_FEE, - message="The max_fee field in Declare transactions of version 0 must be 0.", - ) - stark_assert_eq( - 0, - self.nonce, - code=StarknetErrorCode.OUT_OF_RANGE_NONCE, - message="The nonce field in Declare transactions of version 0 must be 0.", - ) - stark_assert_eq( - 0, - len(self.signature), - code=StarknetErrorCode.NON_EMPTY_SIGNATURE, - message="The signature field in Declare transactions must be an empty list.", - ) + if self.version in [0, constants.QUERY_VERSION_BASE]: + stark_assert_eq( + DEFAULT_DECLARE_SENDER_ADDRESS, + self.sender_address, + code=StarknetErrorCode.OUT_OF_RANGE_CONTRACT_ADDRESS, + message=( + "The sender_address field in Declare transactions of version 0 " + f"must be {DEFAULT_DECLARE_SENDER_ADDRESS}." + ), + ) + stark_assert_eq( + 0, + self.max_fee, + code=StarknetErrorCode.OUT_OF_RANGE_FEE, + message="The max_fee field in Declare transactions of version 0 must be 0.", + ) + stark_assert_eq( + 0, + self.nonce, + code=StarknetErrorCode.OUT_OF_RANGE_NONCE, + message="The nonce field in Declare transactions of version 0 must be 0.", + ) + stark_assert_eq( + 0, + len(self.signature), + code=StarknetErrorCode.NON_EMPTY_SIGNATURE, + message="The signature field in Declare transactions must be an empty list.", + ) @classmethod def create( @@ -683,7 +686,7 @@ class InternalDeployAccount(InternalAccountTransaction): contract_address_salt: int = field(metadata=fields.contract_address_salt_metadata) class_hash: int = field(metadata=fields.new_class_hash_metadata) - constructor_calldata: List[int] = field(metadata=fields.call_data_metadata) + constructor_calldata: List[int] = field(metadata=fields.calldata_metadata) version: int = field(metadata=fields.tx_version_metadata) # Repeat `nonce` to narrow its type to non-optional int. nonce: int = field(metadata=fields.nonce_metadata) @@ -713,7 +716,7 @@ def verify_version(self): verify_version( version=self.version, expected_version=constants.TRANSACTION_VERSION, - only_query=False, + only_query=self.only_query, old_supported_versions=[], ) @@ -942,7 +945,7 @@ class InternalDeploy(InternalTransaction): # is accessed as an integer, the property 'class_hash' is used. contract_hash: bytes = field(metadata=fields.non_required_class_hash_metadata) - constructor_calldata: List[int] = field(metadata=fields.call_data_metadata) + constructor_calldata: List[int] = field(metadata=fields.calldata_metadata) # Class variables. tx_type: ClassVar[TransactionType] = TransactionType.DEPLOY @@ -1167,7 +1170,7 @@ class InternalInvokeFunction(InternalAccountTransaction): # The decorator type of the called function. Note that a single function may be decorated with # multiple decorators and this member specifies which one. entry_point_type: EntryPointType - calldata: List[int] = field(metadata=fields.call_data_metadata) + calldata: List[int] = field(metadata=fields.calldata_metadata) # Class variables. tx_type: ClassVar[TransactionType] = TransactionType.INVOKE_FUNCTION @@ -1454,7 +1457,7 @@ class InternalL1Handler(InternalTransaction): contract_address: int = field(metadata=fields.contract_address_metadata) entry_point_selector: int = field(metadata=fields.entry_point_selector_metadata) - calldata: List[int] = field(metadata=fields.call_data_metadata) + calldata: List[int] = field(metadata=fields.calldata_metadata) # A unique nonce, added by the StarkNet core contract on L1. Guarantees a unique # hash_value of transactions. nonce: Optional[int] = field(metadata=fields.optional_nonce_metadata) diff --git a/src/starkware/starknet/business_logic/utils.py b/src/starkware/starknet/business_logic/utils.py index 0059438c..fc55e021 100644 --- a/src/starkware/starknet/business_logic/utils.py +++ b/src/starkware/starknet/business_logic/utils.py @@ -204,8 +204,12 @@ def calculate_tx_resources( Used for transaction fee; calculation is made as if the transaction is the first in batch, for consistency. """ - (n_modified_contracts, n_storage_changes) = state.count_actual_storage_changes() - n_class_updates = state.count_actual_class_updates() + ( + n_modified_contracts, + n_storage_changes, + n_class_updates, + _n_nonce_updates, + ) = state.count_actual_updates() non_optional_call_infos = [call for call in call_infos if call is not None] l2_to_l1_messages = [] diff --git a/src/starkware/starknet/cli/CMakeLists.txt b/src/starkware/starknet/cli/CMakeLists.txt index 1e226cfd..d7d76cf0 100644 --- a/src/starkware/starknet/cli/CMakeLists.txt +++ b/src/starkware/starknet/cli/CMakeLists.txt @@ -26,6 +26,7 @@ python_lib(starknet_cli_lib FILES class_hash.py + compiled_class_hash.py reconstruct_starknet_traceback.py starknet_cli.py @@ -40,6 +41,7 @@ python_lib(starknet_cli_lib starknet_cli_utils_lib starknet_compile_lib starknet_contract_class_lib + starknet_contract_class_utils_lib starknet_definitions_lib starknet_feeder_gateway_client_lib starknet_feeder_gateway_request_objects_lib diff --git a/src/starkware/starknet/cli/class_hash.py b/src/starkware/starknet/cli/class_hash.py old mode 100644 new mode 100755 index f9ed0058..1d2f3806 --- a/src/starkware/starknet/cli/class_hash.py +++ b/src/starkware/starknet/cli/class_hash.py @@ -1,20 +1,27 @@ +#!/usr/bin/env python3 + import argparse +import json from starkware.cairo.lang.version import __version__ from starkware.cairo.lang.vm.crypto import get_crypto_lib_context_manager +from starkware.starknet.core.os.contract_class.class_hash import compute_class_hash from starkware.starknet.core.os.contract_class.deprecated_class_hash import ( compute_deprecated_class_hash, ) from starkware.starknet.services.api.contract_class.contract_class import DeprecatedCompiledClass +from starkware.starknet.services.api.contract_class.contract_class_utils import ( + load_sierra_from_dict, +) def main(): parser = argparse.ArgumentParser( - description="A tool to compute the class hash of a StarkNet contract" + description="A tool to compute the class hash of a Starknet contract." ) parser.add_argument("-v", "--version", action="version", version=f"%(prog)s {__version__}") parser.add_argument( - "compiled_contract", + "contract_class", type=argparse.FileType("r"), help="The name of the contract JSON file.", ) @@ -25,11 +32,24 @@ def main(): choices=["Debug", "Release", "RelWithDebInfo"], help="Build flavor", ) + parser.add_argument( + "--deprecated", + action="store_true", + help=( + "Compute the class hash of a deprecated compiled contract (i.e., a Cairo v0 contract)." + ), + ) args = parser.parse_args() with get_crypto_lib_context_manager(args.flavor): - compiled_contract = DeprecatedCompiledClass.loads(data=args.compiled_contract.read()) - print(hex(compute_deprecated_class_hash(compiled_contract))) + if args.deprecated: + deprecated_compiled_contract = DeprecatedCompiledClass.loads( + data=args.contract_class.read() + ) + print(hex(compute_deprecated_class_hash(contract_class=deprecated_compiled_contract))) + else: + contract_class = load_sierra_from_dict(sierra=json.load(args.contract_class)) + print(hex(compute_class_hash(contract_class=contract_class))) if __name__ == "__main__": diff --git a/src/starkware/starknet/cli/compiled_class_hash.py b/src/starkware/starknet/cli/compiled_class_hash.py new file mode 100755 index 00000000..2c0cc64f --- /dev/null +++ b/src/starkware/starknet/cli/compiled_class_hash.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 + +import argparse +import json + +from starkware.cairo.lang.version import __version__ +from starkware.cairo.lang.vm.crypto import get_crypto_lib_context_manager +from starkware.starknet.core.os.contract_class.compiled_class_hash import ( + compute_compiled_class_hash, +) +from starkware.starknet.services.api.contract_class.contract_class import CompiledClass + + +def main(): + parser = argparse.ArgumentParser( + description="A tool to compute the compiled class hash of a Starknet contract." + ) + parser.add_argument("-v", "--version", action="version", version=f"%(prog)s {__version__}") + parser.add_argument( + "compiled_class", + type=argparse.FileType("r"), + help="The name of the compiled contract JSON file.", + ) + parser.add_argument( + "--flavor", + type=str, + default="Release", + choices=["Debug", "Release", "RelWithDebInfo"], + help="Build flavor", + ) + args = parser.parse_args() + + with get_crypto_lib_context_manager(args.flavor): + compiled_class = CompiledClass.load(json.load(args.compiled_class)) + print(hex(compute_compiled_class_hash(compiled_class=compiled_class))) + + +if __name__ == "__main__": + main() diff --git a/src/starkware/starknet/cli/starknet_cli.py b/src/starkware/starknet/cli/starknet_cli.py index 4d71f9b4..8a933814 100755 --- a/src/starkware/starknet/cli/starknet_cli.py +++ b/src/starkware/starknet/cli/starknet_cli.py @@ -6,7 +6,7 @@ import os import sys import traceback -from typing import Awaitable, Callable, Dict, List, Literal, Optional, Tuple, Union +from typing import Awaitable, Callable, Dict, List, Optional, Tuple from web3 import Web3 @@ -17,6 +17,7 @@ from starkware.starknet.cli.reconstruct_starknet_traceback import reconstruct_starknet_traceback from starkware.starknet.cli.starknet_cli_utils import ( AbiFormatError, + DeclareArgs, DeprecatedDeclareArgs, InvokeFunctionArgs, NetworkData, @@ -24,6 +25,7 @@ compute_max_fee_for_tx, construct_declare_tx, construct_deploy_account_tx, + construct_deprecated_declare_tx, construct_feeder_gateway_client, construct_gateway_client, construct_invoke_tx, @@ -37,12 +39,18 @@ simulate_tx_at_block, simulate_tx_at_pending_block, tx_received, - validate_call_args, +) +from starkware.starknet.core.os.contract_class.compiled_class_hash import ( + compute_compiled_class_hash, ) from starkware.starknet.definitions import fields from starkware.starknet.definitions.general_config import StarknetChainId, StarknetGeneralConfig from starkware.starknet.public.abi import AbiType from starkware.starknet.services.api.contract_class.contract_class import DeprecatedCompiledClass +from starkware.starknet.services.api.contract_class.contract_class_utils import ( + compile_contract_class, + load_sierra_from_dict, +) from starkware.starknet.services.api.feeder_gateway.feeder_gateway_client import FeederGatewayClient from starkware.starknet.services.api.feeder_gateway.request_objects import ( CallFunction, @@ -56,6 +64,7 @@ from starkware.starknet.services.api.gateway.gateway_client import GatewayClient from starkware.starknet.services.api.gateway.transaction import ( AccountTransaction, + Declare, DeployAccount, DeprecatedDeclare, InvokeFunction, @@ -70,18 +79,25 @@ async def declare(args: argparse.Namespace, command_args: List[str]): """ - Creates a Cairo-0 declare transaction and sends it to the gateway. In case a wallet is - provided, the transaction is wrapped and signed by the wallet provider. Otherwise, a sender - address and a valid signature must be provided as arguments. + If the `--deprecated` flag is used, creates a version 1 - Declare transaction (which is used to + declare Cairo 0 contracts) and sends it to the gateway. If the `--deprecated` flag is not used, + creates a version 2 - Declare transaction (which is used to declare Cairo 1.0 contracts) and + sends it to the gateway. In case a wallet is provided, the transaction is wrapped and signed by + the wallet provider. Otherwise, a sender address and a valid signature must be provided as + arguments. """ parser = argparse.ArgumentParser(description="Sends a declare transaction to StarkNet.") add_declare_tx_arguments(parser=parser) parser.parse_args(command_args, namespace=args) - declare_tx_args = parse_declare_tx_args(args=args) + if args.deprecated: + await deprecated_declare(args=args) + return + has_wallet = get_wallet_provider(args=args) is not None + declare_tx_args = parse_declare_tx_args(args=args) - declare_tx_for_simulate: Optional[DeprecatedDeclare] = None + declare_tx_for_simulate: Optional[Declare] = None if need_simulate_tx(args=args, has_wallet=has_wallet): declare_tx_for_simulate = await create_declare_tx( args=args, @@ -94,10 +110,6 @@ async def declare(args: argparse.Namespace, command_args: List[str]): await simulate_or_estimate_fee(args=args, tx=declare_tx_for_simulate) return - assert args.block_hash is None and args.block_number is None, ( - "--block_hash and --block_number should only be passed when either --simulate or " - "--estimate_fee flag are used." - ) max_fee = await compute_max_fee(args=args, tx=declare_tx_for_simulate, has_wallet=has_wallet) tx = await create_declare_tx( @@ -119,6 +131,50 @@ async def declare(args: argparse.Namespace, command_args: List[str]): ) +async def deprecated_declare(args: argparse.Namespace): + """ + Creates a DeprecatedDeclare transaction and sends it to the gateway. In case a wallet is + provided, the transaction is wrapped and signed by the wallet provider. Otherwise, a sender + address and a valid signature must be provided as arguments. + """ + + declare_tx_args = parse_deprecated_declare_tx_args(args=args) + has_wallet = get_wallet_provider(args=args) is not None + + declare_tx_for_simulate: Optional[DeprecatedDeclare] = None + if need_simulate_tx(args=args, has_wallet=has_wallet): + declare_tx_for_simulate = await create_deprecated_declare_tx( + args=args, + declare_tx_args=declare_tx_args, + max_fee=args.max_fee if args.max_fee is not None else 0, + has_wallet=has_wallet, + query=True, + ) + if args.simulate or args.estimate_fee: + await simulate_or_estimate_fee(args=args, tx=declare_tx_for_simulate) + return + + max_fee = await compute_max_fee(args=args, tx=declare_tx_for_simulate, has_wallet=has_wallet) + + tx = await create_deprecated_declare_tx( + args=args, + declare_tx_args=declare_tx_args, + max_fee=max_fee, + has_wallet=has_wallet, + query=False, + ) + gateway_client = get_gateway_client(args) + gateway_response = await gateway_client.add_transaction(tx=tx, token=args.token) + assert_tx_received(gateway_response=gateway_response) + # Don't end sentences with '.', to allow easy double-click copy-pasting of the values. + print( + f"""\ +DeprecatedDeclare transaction was sent. +Contract class hash: {gateway_response['class_hash']} +Transaction hash: {gateway_response['transaction_hash']}""" + ) + + async def deploy(args, command_args): parser = argparse.ArgumentParser(description="Deploys a contract to StarkNet.") parser.add_argument( @@ -233,10 +289,6 @@ async def deploy_account(args: argparse.Namespace, command_args: List[str]): await simulate_or_estimate_fee(args=args, tx=deploy_account_tx_for_simulate) return - assert args.block_hash is None and args.block_number is None, ( - "--block_hash and --block_number should only be passed when either --simulate or " - "--estimate_fee flag are used." - ) max_fee = await compute_max_fee(args=args, tx=deploy_account_tx_for_simulate, has_wallet=True) tx, contract_address = await create_deploy_account_tx( @@ -305,19 +357,10 @@ async def invoke(args: argparse.Namespace, command_args: List[str]): await simulate_or_estimate_fee(args=args, tx=invoke_tx_for_simulate) return - assert args.block_hash is None and args.block_number is None, ( - "--block_hash and --block_number should only be passed when --simulate or " - "--estimate_fee flag is used." - ) - if args.dry_run: assert has_wallet, "--dry_run can only be used for invocation through an account contract." - max_fee = await compute_max_fee( - args=args, - tx=invoke_tx_for_simulate, - has_wallet=has_wallet, - ) + max_fee = await compute_max_fee(args=args, tx=invoke_tx_for_simulate, has_wallet=has_wallet) tx = await create_invoke_tx( args=args, @@ -667,12 +710,12 @@ async def get_storage_at(args, command_args): # Utilities. -def load_abi(args) -> AbiType: +def load_abi(args) -> Optional[AbiType]: """ - Raises an error if ABI doesn't exist / fails to load. + Raises an error if ABI fails to load. Returns None if ABI doesn't exist. """ try: - return json.load(args.abi) + return None if args.abi is None else json.load(args.abi) except Exception as ex: raise AbiFormatError(ex) from ex @@ -815,7 +858,9 @@ async def compute_max_fee( if has_wallet: max_fee = await compute_max_fee_for_tx( - feeder_client=get_feeder_gateway_client(args), tx=as_non_optional(tx) + feeder_client=get_feeder_gateway_client(args), + tx=as_non_optional(tx), + skip_validate=args.skip_validate, ) max_fee_eth = float(Web3.fromWei(max_fee, "ether")) @@ -827,7 +872,22 @@ async def compute_max_fee( def need_simulate_tx(args: argparse.Namespace, has_wallet: bool) -> bool: - return (args.max_fee is None and has_wallet) or args.simulate or args.estimate_fee + """ + Returns whether a simulate is required. + If simulation was not requested, asserts that no other simulation related flags appear. + """ + simulate_requested = args.simulate or args.estimate_fee + if not simulate_requested: + assert args.block_hash is None and args.block_number is None, ( + "--block_hash and --block_number should only be passed when either --simulate or " + "--estimate_fee flag are used." + ) + assert not args.skip_validate, ( + "--skip_validate should only be passed when either --simulate or " + "--estimate_fee flag are used." + ) + + return (args.max_fee is None and has_wallet) or simulate_requested async def load_account_from_args(args) -> Account: @@ -848,35 +908,20 @@ def handle_network_param(args): if network is not None: try: data = NetworkData.from_network_name(network=network) - if args.gateway_url is None: - args.gateway_url = data.gateway_url - if args.feeder_gateway_url is None: - args.feeder_gateway_url = data.feeder_gateway_url - if args.network_id is None: - args.network_id = data.network_id - if args.chain_id is None: - args.chain_id = data.chain_id except NetworkNameError as error: print(str(error), file=sys.stderr) return 1 - return 0 - + if args.gateway_url is None: + args.gateway_url = data.gateway_url + if args.feeder_gateway_url is None: + args.feeder_gateway_url = data.feeder_gateway_url + if args.network_id is None: + args.network_id = data.network_id + if args.chain_id is None: + args.chain_id = data.chain_id -def validate_call_function_args( - args: argparse.Namespace, - abi_entry_type: Union[Literal["function"], Literal["l1_handler"]], - inputs: List[int], -): - """ - Validates that the function name is in the ABI and that the inputs match the required structure. - """ - validate_call_args( - abi=load_abi(args=args), - abi_entry_name=args.function, - abi_entry_type=abi_entry_type, - inputs=inputs, - ) + return 0 def parse_call_function_args(args: argparse.Namespace) -> CallFunction: @@ -911,7 +956,23 @@ def parse_invoke_tx_args(args: argparse.Namespace) -> InvokeFunctionArgs: ) -def parse_declare_tx_args(args: argparse.Namespace) -> DeprecatedDeclareArgs: +def parse_declare_tx_args(args: argparse.Namespace) -> DeclareArgs: + sender = parse_hex_arg(arg=args.sender, arg_name="sender") if args.sender is not None else None + + contract_class = load_sierra_from_dict(sierra=json.load(args.contract)) + + compiled_class = compile_contract_class(contract_class=contract_class) + compiled_class_hash = compute_compiled_class_hash(compiled_class=compiled_class) + + return DeclareArgs( + sender=sender, + signature=cast_to_felts(values=args.signature), + compiled_class_hash=compiled_class_hash, + contract_class=contract_class, + ) + + +def parse_deprecated_declare_tx_args(args: argparse.Namespace) -> DeprecatedDeclareArgs: validate_max_fee(max_fee=args.max_fee) sender = parse_hex_arg(arg=args.sender, arg_name="sender") if args.sender is not None else None return DeprecatedDeclareArgs( @@ -971,6 +1032,28 @@ async def create_invoke_tx( async def create_declare_tx( + args: argparse.Namespace, + declare_tx_args: DeclareArgs, + max_fee: int, + has_wallet: bool, + query: bool, +) -> Declare: + """ + Creates and returns a DeprecatedDeclare transaction with the given parameters. + If a wallet provider was provided in args, that transaction will be wrapped and signed. + """ + return await construct_declare_tx( + feeder_client=get_feeder_gateway_client(args=args), + declare_tx_args=declare_tx_args, + chain_id=get_chain_id(args=args), + max_fee=max_fee, + account=await load_account_from_args(args=args) if has_wallet else None, + explicit_nonce=args.nonce, + simulate=query, + ) + + +async def create_deprecated_declare_tx( args: argparse.Namespace, declare_tx_args: DeprecatedDeclareArgs, max_fee: int, @@ -981,7 +1064,7 @@ async def create_declare_tx( Creates and returns a DeprecatedDeclare transaction with the given parameters. If a wallet provider was provided in args, that transaction will be wrapped and signed. """ - return await construct_declare_tx( + return await construct_deprecated_declare_tx( feeder_client=get_feeder_gateway_client(args=args), declare_tx_args=declare_tx_args, chain_id=get_chain_id(args=args), @@ -1021,14 +1104,18 @@ async def simulate_tx_inner( Returns a TransactionSimulationInfo object. """ feeder_client = get_feeder_gateway_client(args=args) + skip_validate = args.skip_validate if has_block_info: return await simulate_tx_at_block( feeder_client=feeder_client, tx=tx, block_hash=args.block_hash, block_number=args.block_number, + skip_validate=skip_validate, ) - return await simulate_tx_at_pending_block(feeder_client=feeder_client, tx=tx) + return await simulate_tx_at_pending_block( + feeder_client=feeder_client, tx=tx, skip_validate=skip_validate + ) def print_invoke_tx(tx: InvokeFunction, chain_id: int): @@ -1123,6 +1210,11 @@ def add_simulate_tx_arguments(parser: argparse.ArgumentParser): action="store_true", help="Estimates the fee of the transaction.", ) + parser.add_argument( + "--skip_validate", + action="store_true", + help="Skips the validate function on simulate and estimate_fee.", + ) add_block_identifier_arguments( parser=parser, block_role_description="be used as the context for the transaction simulation", @@ -1145,6 +1237,11 @@ def add_declare_tx_arguments(parser: argparse.ArgumentParser): type=str, help="The address of the account contract sending the transaction.", ) + parser.add_argument( + "--deprecated", + action="store_true", + help="Send a deprecated declare transaction (i.e., to declare a Cairo v0 contract).", + ) add_account_tx_arguments(parser=parser) parser.add_argument( "--token", type=str, help="Used for declaring contracts in Alpha MainNet.", required=False @@ -1159,9 +1256,7 @@ def add_call_function_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--address", type=str, required=True, help="The address of the invoked contract." ) - parser.add_argument( - "--abi", type=argparse.FileType("r"), required=True, help="The Cairo contract ABI." - ) + parser.add_argument("--abi", type=argparse.FileType("r"), help="The Cairo contract ABI.") parser.add_argument( "--function", type=str, required=True, help="The name of the invoked function." ) diff --git a/src/starkware/starknet/cli/starknet_cli_utils.py b/src/starkware/starknet/cli/starknet_cli_utils.py index c9df6b6d..6ed159f9 100644 --- a/src/starkware/starknet/cli/starknet_cli_utils.py +++ b/src/starkware/starknet/cli/starknet_cli_utils.py @@ -18,7 +18,10 @@ get_selector_from_name, ) from starkware.starknet.public.abi_structs import identifier_manager_from_abi -from starkware.starknet.services.api.contract_class.contract_class import DeprecatedCompiledClass +from starkware.starknet.services.api.contract_class.contract_class import ( + ContractClass, + DeprecatedCompiledClass, +) from starkware.starknet.services.api.feeder_gateway.feeder_gateway_client import ( CastableToHash, FeederGatewayClient, @@ -35,6 +38,7 @@ from starkware.starknet.services.api.gateway.gateway_client import GatewayClient from starkware.starknet.services.api.gateway.transaction import ( AccountTransaction, + Declare, DeployAccount, DeprecatedDeclare, InvokeFunction, @@ -128,6 +132,14 @@ class DeprecatedDeclareArgs: signature: List[int] +@dataclasses.dataclass +class DeclareArgs: + contract_class: ContractClass + compiled_class_hash: int + sender: Optional[int] + signature: List[int] + + def parse_block_identifiers( block_hash: Optional[CastableToHash], block_number: Optional[BlockIdentifier], @@ -305,20 +317,25 @@ async def load_account( async def simulate_tx_at_pending_block( - feeder_client: FeederGatewayClient, tx: AccountTransaction + feeder_client: FeederGatewayClient, tx: AccountTransaction, skip_validate: bool ) -> TransactionSimulationInfo: """ Simulates a transaction with the given parameters, relative to the state of the latest PENDING block. """ return await simulate_tx_at_block( - feeder_client=feeder_client, tx=tx, block_hash=None, block_number=PENDING_BLOCK_ID + feeder_client=feeder_client, + tx=tx, + block_hash=None, + block_number=PENDING_BLOCK_ID, + skip_validate=skip_validate, ) async def simulate_tx_at_block( feeder_client: FeederGatewayClient, tx: AccountTransaction, + skip_validate: bool, block_hash: Optional[CastableToHash] = None, block_number: Optional[BlockIdentifier] = None, ) -> TransactionSimulationInfo: @@ -327,15 +344,19 @@ async def simulate_tx_at_block( Returns a TransactionSimulationInfo object. """ return await feeder_client.simulate_transaction( - tx=tx, block_hash=block_hash, block_number=block_number + tx=tx, block_hash=block_hash, block_number=block_number, skip_validate=skip_validate ) -async def compute_max_fee_for_tx(feeder_client: FeederGatewayClient, tx: AccountTransaction) -> int: +async def compute_max_fee_for_tx( + feeder_client: FeederGatewayClient, tx: AccountTransaction, skip_validate: bool +) -> int: """ Given a transaction, estimates and returns the max fee. """ - simulate_tx_info = await simulate_tx_at_pending_block(feeder_client=feeder_client, tx=tx) + simulate_tx_info = await simulate_tx_at_pending_block( + feeder_client=feeder_client, tx=tx, skip_validate=skip_validate + ) return math.ceil(simulate_tx_info.fee_estimation.overall_fee * FEE_MARGIN_OF_ESTIMATION) @@ -356,14 +377,15 @@ async def get_nonce(address: int) -> int: def create_call_function( - contract_address: int, abi: AbiType, function_name: str, inputs: List[int] + contract_address: int, abi: Optional[AbiType], function_name: str, inputs: List[int] ) -> CallFunction: """ Constructs a CallFunction object for the given parameters. """ - validate_call_args( - abi=abi, abi_entry_name=function_name, abi_entry_type="function", inputs=inputs - ) + if abi is not None: + validate_call_args( + abi=abi, abi_entry_name=function_name, abi_entry_type="function", inputs=inputs + ) return CallFunction( contract_address=contract_address, @@ -373,17 +395,22 @@ def create_call_function( def create_call_l1_handler( - abi: AbiType, handler_name: str, from_address: int, to_address: int, payload: List[int] + abi: Optional[AbiType], + handler_name: str, + from_address: int, + to_address: int, + payload: List[int], ) -> CallL1Handler: """ Constructs a CallL1Handler object for the given parameters. """ - validate_call_args( - abi=abi, - abi_entry_name=handler_name, - abi_entry_type="l1_handler", - inputs=[from_address] + payload, - ) + if abi is not None: + validate_call_args( + abi=abi, + abi_entry_name=handler_name, + abi_entry_type="l1_handler", + inputs=[from_address] + payload, + ) return CallL1Handler( from_address=from_address, @@ -473,6 +500,57 @@ async def construct_invoke_tx( async def construct_declare_tx( + feeder_client: FeederGatewayClient, + declare_tx_args: DeclareArgs, + chain_id: int, + max_fee: int, + account: Optional[Account], + explicit_nonce: Optional[int], + simulate: bool, +) -> Declare: + """ + Creates and returns a DeprecatedDeclare transaction with the given parameters. + If an account is provided, that transaction will be wrapped and signed. + """ + version = constants.QUERY_DECLARE_VERSION if simulate else constants.DECLARE_VERSION + nonce_callback = construct_nonce_callback( + feeder_client=feeder_client, explicit_nonce=explicit_nonce + ) + if account is None: + # Declare directly. + assert ( + declare_tx_args.sender is not None + ), "Sender must be passed explicitly when making a direct declaration using --no_wallet." + return Declare( + contract_class=declare_tx_args.contract_class, + compiled_class_hash=declare_tx_args.compiled_class_hash, + sender_address=declare_tx_args.sender, + max_fee=max_fee, + version=version, + signature=declare_tx_args.signature, + nonce=await nonce_callback(declare_tx_args.sender), + ) + + # Declare through the account contract. + assert declare_tx_args.sender is None, ( + "Sender cannot be passed explicitly when using an account contract. " + "Consider making a direct declaration using --no_wallet." + ) + assert declare_tx_args.signature == [], ( + "Signature cannot be passed explicitly when using an account contract. " + "Consider making a direct declaration using --no_wallet." + ) + return await account.declare( + contract_class=declare_tx_args.contract_class, + compiled_class_hash=declare_tx_args.compiled_class_hash, + chain_id=chain_id, + max_fee=max_fee, + version=version, + nonce_callback=nonce_callback, + ) + + +async def construct_deprecated_declare_tx( feeder_client: FeederGatewayClient, declare_tx_args: DeprecatedDeclareArgs, chain_id: int, diff --git a/src/starkware/starknet/common/new_syscalls.cairo b/src/starkware/starknet/common/new_syscalls.cairo index 6d1bb4a7..942d8dd0 100644 --- a/src/starkware/starknet/common/new_syscalls.cairo +++ b/src/starkware/starknet/common/new_syscalls.cairo @@ -1,10 +1,60 @@ // Syscall selectors. -const GET_CALLER_ADDRESS_SELECTOR = 'GetCallerAddress'; +const CALL_CONTRACT_SELECTOR = 'CallContract'; +const DEPLOY_SELECTOR = 'Deploy'; const EMIT_EVENT_SELECTOR = 'EmitEvent'; +const GET_EXECUTION_INFO_SELECTOR = 'GetExecutionInfo'; +const LIBRARY_CALL_SELECTOR = 'LibraryCall'; +const REPLACE_CLASS_SELECTOR = 'ReplaceClass'; +const SEND_MESSAGE_TO_L1_SELECTOR = 'SendMessageToL1'; const STORAGE_READ_SELECTOR = 'StorageRead'; const STORAGE_WRITE_SELECTOR = 'StorageWrite'; +// Syscall structs. + +struct ExecutionInfo { + block_info: BlockInfo*, + tx_info: TxInfo*, + + // Entry-point-specific info. + + caller_address: felt, + // The execution is done in the context of the contract at this address. + // It controls the storage being used, messages sent to L1, calling contracts, etc. + contract_address: felt, + // The entry point selector. + selector: felt, +} + +struct BlockInfo { + block_number: felt, + block_timestamp: felt, + // The address of the sequencer that is creating this block. + sequencer_address: felt, +} + +struct TxInfo { + // The version of the transaction. It is fixed in the OS, and should be signed by the account + // contract. + // This field allows invalidating old transactions, whenever the meaning of the other + // transaction fields is changed (in the OS). + version: felt, + // The account contract from which this transaction originates. + account_contract_address: felt, + // The max_fee field of the transaction. + max_fee: felt, + // The signature of the transaction. + signature_start: felt*, + signature_end: felt*, + // The hash of the transaction. + transaction_hash: felt, + // The identifier of the chain. + // This field can be used to prevent replay of testnet transactions on mainnet. + chain_id: felt, + // The transaction's nonce. + nonce: felt, +} + // Shared attributes. struct RequestHeader { @@ -28,9 +78,41 @@ struct FailureReason { // Syscall requests. +struct CallContractRequest { + // The address of the L2 contract to call. + contract_address: felt, + // The selector of the function to call. + selector: felt, + // The calldata. + calldata_start: felt*, + calldata_end: felt*, +} + +struct LibraryCallRequest { + // The hash of the class to run. + class_hash: felt, + // The selector of the function to call. + selector: felt, + // The calldata. + calldata_start: felt*, + calldata_end: felt*, +} + struct EmptyRequest { } +struct DeployRequest { + // The hash of the class to deploy. + class_hash: felt, + // A salt for the new contract address calculation. + contract_address_salt: felt, + // The calldata for the constructor. + constructor_calldata_start: felt*, + constructor_calldata_end: felt*, + // Used for deterministic contract address deployment. + deploy_from_zero: felt, +} + struct StorageReadRequest { reserved: felt, key: felt, @@ -49,12 +131,33 @@ struct EmitEventRequest { data_end: felt*, } +struct ReplaceClassRequest { + class_hash: felt, +} + +struct SendMessageToL1Request { + to_address: felt, + payload_start: felt*, + payload_end: felt*, +} + // Syscall responses. +struct CallContractResponse { + retdata_start: felt*, + retdata_end: felt*, +} + +struct DeployResponse { + contract_address: felt, + constructor_retdata_start: felt*, + constructor_retdata_end: felt*, +} + struct StorageReadResponse { value: felt, } -struct GetCallerAddressResponse { - caller_address: felt, +struct GetExecutionInfoResponse { + execution_info: ExecutionInfo*, } diff --git a/src/starkware/starknet/compiler/CMakeLists.txt b/src/starkware/starknet/compiler/CMakeLists.txt index b36ec9ac..9a5c3fa8 100644 --- a/src/starkware/starknet/compiler/CMakeLists.txt +++ b/src/starkware/starknet/compiler/CMakeLists.txt @@ -1,3 +1,6 @@ +add_subdirectory(v1) + + python_lib(starknet_compile_lib PREFIX starkware/starknet/compiler diff --git a/src/starkware/starknet/compiler/v1/BUILD.cairo-lang-1.0.0 b/src/starkware/starknet/compiler/v1/BUILD.cairo-lang-1.0.0 new file mode 100644 index 00000000..8d99c63c --- /dev/null +++ b/src/starkware/starknet/compiler/v1/BUILD.cairo-lang-1.0.0 @@ -0,0 +1,7 @@ +package(default_visibility = ["//visibility:public"]) + +filegroup( + name = "cairo-lang-1.0.0", + srcs = glob(["**/*"]), + visibility = ["//visibility:public"], +) diff --git a/src/starkware/starknet/compiler/v1/CMakeLists.txt b/src/starkware/starknet/compiler/v1/CMakeLists.txt new file mode 100644 index 00000000..0709225e --- /dev/null +++ b/src/starkware/starknet/compiler/v1/CMakeLists.txt @@ -0,0 +1,108 @@ +set(CAIRO_COMPILER_DUMMY_FILE "${CMAKE_CURRENT_BINARY_DIR}/cairo_compiler_v1") +set(CAIRO_COMPILER_DIR "${CMAKE_CURRENT_BINARY_DIR}/cairo") + +set(CAIRO_COMPILER_FILES + "${CAIRO_COMPILER_DIR}/bin/cairo-compile" + "${CAIRO_COMPILER_DIR}/bin/cairo-format" + "${CAIRO_COMPILER_DIR}/bin/cairo-language-server" + "${CAIRO_COMPILER_DIR}/bin/cairo-run" + "${CAIRO_COMPILER_DIR}/bin/cairo-test" + "${CAIRO_COMPILER_DIR}/bin/sierra-compile" + "${CAIRO_COMPILER_DIR}/bin/starknet-compile" + "${CAIRO_COMPILER_DIR}/bin/starknet-sierra-compile" + "${CAIRO_COMPILER_DIR}/corelib/cairo_project.toml" + "${CAIRO_COMPILER_DIR}/corelib/Scarb.toml" + "${CAIRO_COMPILER_DIR}/corelib/src/array.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/box.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/clone.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/debug.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/dict.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/ec.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/ecdsa.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/gas.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/hash.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/integer.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/internal.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/lib.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/nullable.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/option.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/result.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/serde.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/starknet.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/starknet/class_hash.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/starknet/contract_address.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/starknet/info.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/starknet/storage_access.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/starknet/syscalls.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/starknet/testing.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/test.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/testing.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/traits.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/zeroable.cairo" +) + +set(CAIRO_COMPILER_ARTIFACTS + "${CAIRO_COMPILER_DIR}/bin/cairo-compile bin/cairo-compile" + "${CAIRO_COMPILER_DIR}/bin/cairo-format bin/cairo-format" + "${CAIRO_COMPILER_DIR}/bin/cairo-language-server bin/cairo-language-server" + "${CAIRO_COMPILER_DIR}/bin/cairo-run bin/cairo-run" + "${CAIRO_COMPILER_DIR}/bin/cairo-test bin/cairo-test" + "${CAIRO_COMPILER_DIR}/bin/sierra-compile bin/sierra-compile" + "${CAIRO_COMPILER_DIR}/bin/starknet-compile bin/starknet-compile" + "${CAIRO_COMPILER_DIR}/bin/starknet-sierra-compile bin/starknet-sierra-compile" + "${CAIRO_COMPILER_DIR}/corelib/cairo_project.toml corelib/cairo_project.toml" + "${CAIRO_COMPILER_DIR}/corelib/Scarb.toml corelib/Scarb.toml" + "${CAIRO_COMPILER_DIR}/corelib/src/array.cairo corelib/src/array.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/box.cairo corelib/src/box.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/clone.cairo corelib/src/clone.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/debug.cairo corelib/src/debug.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/dict.cairo corelib/src/dict.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/ec.cairo corelib/src/ec.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/ecdsa.cairo corelib/src/ecdsa.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/gas.cairo corelib/src/gas.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/hash.cairo corelib/src/hash.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/integer.cairo corelib/src/integer.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/internal.cairo corelib/src/internal.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/lib.cairo corelib/src/lib.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/nullable.cairo corelib/src/nullable.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/option.cairo corelib/src/option.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/result.cairo corelib/src/result.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/serde.cairo corelib/src/serde.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/starknet.cairo corelib/src/starknet.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/starknet/class_hash.cairo corelib/src/starknet/class_hash.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/starknet/contract_address.cairo corelib/src/starknet/contract_address.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/starknet/info.cairo corelib/src/starknet/info.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/starknet/storage_access.cairo corelib/src/starknet/storage_access.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/starknet/syscalls.cairo corelib/src/starknet/syscalls.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/starknet/testing.cairo corelib/src/starknet/testing.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/test.cairo corelib/src/test.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/testing.cairo corelib/src/testing.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/traits.cairo corelib/src/traits.cairo" + "${CAIRO_COMPILER_DIR}/corelib/src/zeroable.cairo corelib/src/zeroable.cairo" +) + +add_custom_command( + OUTPUT "${CAIRO_COMPILER_DUMMY_FILE}" "${CAIRO_COMPILER_FILES}" + COMMAND curl -Lo release-x86_64-unknown-linux-musl.tar.gz https://github.com/starkware-libs/cairo/releases/download/v1.0.0-alpha.5/release-x86_64-unknown-linux-musl.tar.gz + COMMAND tar -xf release-x86_64-unknown-linux-musl.tar.gz + COMMAND touch "${CAIRO_COMPILER_DUMMY_FILE}" + COMMENT "Downloading cairo compiler." +) + +add_custom_target(get_cairo_compiler DEPENDS ${CAIRO_COMPILER_DUMMY_FILE}) + +python_lib(starknet_compile_v1_lib + PREFIX starkware/starknet/compiler/v1 + + FILES + compile.py + + ARTIFACTS + "${CAIRO_COMPILER_ARTIFACTS}" + + LIBS + starknet_definitions_lib + starkware_error_handling_lib +) + +add_dependencies(starknet_compile_v1_lib get_cairo_compiler) diff --git a/src/starkware/starknet/compiler/v1/compile.py b/src/starkware/starknet/compiler/v1/compile.py new file mode 100644 index 00000000..1206f2ef --- /dev/null +++ b/src/starkware/starknet/compiler/v1/compile.py @@ -0,0 +1,100 @@ +import json +import os +import subprocess +import tempfile +from typing import Any, Dict, List, Optional + +from starkware.starknet.definitions.error_codes import StarknetErrorCode +from starkware.starkware_utils.error_handling import StarkException + +JsonObject = Dict[str, Any] + +DEFAULT_ALLOWED_LIBFUNCS_ARG: List[str] = [] + +if "RUNFILES_DIR" in os.environ: + from bazel_tools.tools.python.runfiles import runfiles + + r = runfiles.Create() + + STARKNET_SIERRA_COMPILE = r.Rlocation("cairo-lang-1.0.0/bin/starknet-sierra-compile") + STARKNET_COMPILE = r.Rlocation("cairo-lang-1.0.0/bin/starknet-compile") +else: + STARKNET_SIERRA_COMPILE = os.path.join( + os.path.dirname(__file__), "bin", "starknet-sierra-compile" + ) + STARKNET_COMPILE = os.path.join(os.path.dirname(__file__), "bin", "starknet-compile") + + +def compile_cairo_to_sierra( + cairo_path: str, allowed_libfuncs_list_name: Optional[str] = None +) -> JsonObject: + """ + Compiles a Starknet Cairo 1.0 contract; returns the resulting Sierra as json. + """ + additional_args = ( + DEFAULT_ALLOWED_LIBFUNCS_ARG + if allowed_libfuncs_list_name is None + else ["--allowed-libfuncs-list-name", allowed_libfuncs_list_name] + ) + return run_compile_command(command=[STARKNET_COMPILE, cairo_path, *additional_args]) + + +def compile_sierra_to_casm( + sierra_path: str, allowed_libfuncs_list_name: Optional[str] = None +) -> JsonObject: + """ + Compiles a Starknet Sierra contract; returns the resulting Casm as json. + """ + additional_args = ( + DEFAULT_ALLOWED_LIBFUNCS_ARG + if allowed_libfuncs_list_name is None + else ["--allowed-libfuncs-list-name", allowed_libfuncs_list_name] + ) + return run_compile_command( + command=[STARKNET_SIERRA_COMPILE, sierra_path, "--add-pythonic-hints", *additional_args] + ) + + +def compile_cairo_to_casm( + cairo_path: str, allowed_libfuncs_list_name: Optional[str] = None +) -> JsonObject: + """ + Compiles a Starknet Cairo 1.0 contract to Casm; returns the resulting Casm as json. + """ + raw_sierra = compile_cairo_to_sierra( + cairo_path=cairo_path, allowed_libfuncs_list_name=allowed_libfuncs_list_name + ) + with tempfile.NamedTemporaryFile(mode="w") as sierra_file: + json.dump(obj=raw_sierra, fp=sierra_file, indent=2) + sierra_file.flush() + + return compile_sierra_to_casm( + sierra_path=sierra_file.name, allowed_libfuncs_list_name=allowed_libfuncs_list_name + ) + + +def run_compile_command(command: List[str]) -> JsonObject: + try: + result: subprocess.CompletedProcess = subprocess.run(command, capture_output=True) + except subprocess.CalledProcessError: + # The inner command is responsible for printing the error message. No need to print the + # stack trace of this script. + raise StarkException( + code=StarknetErrorCode.COMPILATION_FAILED, + message="Compilation failed. Invalid file path input.", + ) + + if result is None: + raise StarkException( + code=StarknetErrorCode.COMPILATION_FAILED, + message="Compilation failed.", + ) + + if result.returncode != 0: + raise StarkException( + code=StarknetErrorCode.COMPILATION_FAILED, + message=f"Compilation failed. Error: {result.stderr.decode()}", + ) + + # Read and return the compilation result from the output. + return json.loads(result.stdout.decode()) diff --git a/src/starkware/starknet/core/os/block_context.cairo b/src/starkware/starknet/core/os/block_context.cairo index af826e37..8842df25 100644 --- a/src/starkware/starknet/core/os/block_context.cairo +++ b/src/starkware/starknet/core/os/block_context.cairo @@ -1,5 +1,6 @@ from starkware.cairo.common.cairo_builtins import HashBuiltin, PoseidonBuiltin from starkware.cairo.common.registers import get_fp_and_pc +from starkware.starknet.common.new_syscalls import BlockInfo from starkware.starknet.core.os.builtins import BuiltinParams, get_builtin_params from starkware.starknet.core.os.contract_class.compiled_class import ( CompiledClassFact, @@ -11,12 +12,6 @@ from starkware.starknet.core.os.contract_class.deprecated_compiled_class import ) from starkware.starknet.core.os.os_config.os_config import StarknetOsConfig -struct BlockInfo { - // Currently, the block timestamp is not validated. - block_timestamp: felt, - block_number: felt, -} - // Represents information that is the same throughout the block. struct BlockContext { // Parameters for select_builtins. @@ -32,10 +27,8 @@ struct BlockContext { n_deprecated_compiled_class_facts: felt, deprecated_compiled_class_facts: DeprecatedCompiledClassFact*, - // The address of the sequencer that is creating this block. - sequencer_address: felt, // Information about the block. - block_info: BlockInfo, + block_info: BlockInfo*, // StarknetOsConfig instance. starknet_os_config: StarknetOsConfig, // A function pointer to the 'execute_syscalls' function. @@ -62,10 +55,10 @@ func get_block_context{poseidon_ptr: PoseidonBuiltin*, pedersen_ptr: HashBuiltin compiled_class_facts=compiled_class_facts, n_deprecated_compiled_class_facts=n_deprecated_compiled_class_facts, deprecated_compiled_class_facts=deprecated_compiled_class_facts, - sequencer_address=nondet %{ os_input.general_config.sequencer_address %}, - block_info=BlockInfo( - block_timestamp=nondet %{ deprecated_syscall_handler.block_info.block_timestamp %}, + block_info=new BlockInfo( block_number=nondet %{ deprecated_syscall_handler.block_info.block_number %}, + block_timestamp=nondet %{ deprecated_syscall_handler.block_info.block_timestamp %}, + sequencer_address=nondet %{ os_input.general_config.sequencer_address %}, ), starknet_os_config=StarknetOsConfig( chain_id=nondet %{ os_input.general_config.chain_id.value %}, diff --git a/src/starkware/starknet/core/os/builtins.cairo b/src/starkware/starknet/core/os/builtins.cairo index 7fc9dbce..ee0ba1d3 100644 --- a/src/starkware/starknet/core/os/builtins.cairo +++ b/src/starkware/starknet/core/os/builtins.cairo @@ -6,6 +6,7 @@ from starkware.cairo.common.cairo_builtins import ( SignatureBuiltin, ) from starkware.cairo.common.registers import get_fp_and_pc +from starkware.starknet.builtins.segment_arena.segment_arena import SegmentArenaBuiltin struct BuiltinPointers { pedersen: HashBuiltin*, @@ -14,6 +15,7 @@ struct BuiltinPointers { bitwise: felt, ec_op: felt, poseidon: PoseidonBuiltin*, + segment_arena: SegmentArenaBuiltin*, } // A struct containing the ASCII encoding of each builtin. @@ -24,6 +26,7 @@ struct BuiltinEncodings { bitwise: felt, ec_op: felt, poseidon: felt, + segment_arena: felt, } // A struct containing the instance size of each builtin. @@ -34,6 +37,7 @@ struct BuiltinInstanceSizes { bitwise: felt, ec_op: felt, poseidon: felt, + segment_arena: felt, } struct BuiltinParams { @@ -52,6 +56,7 @@ func get_builtin_params() -> (builtin_params: BuiltinParams*) { bitwise='bitwise', ec_op='ec_op', poseidon='poseidon', + segment_arena='segment_arena', ); local builtin_instance_sizes: BuiltinInstanceSizes = BuiltinInstanceSizes( @@ -61,6 +66,7 @@ func get_builtin_params() -> (builtin_params: BuiltinParams*) { bitwise=BitwiseBuiltin.SIZE, ec_op=EcOpBuiltin.SIZE, poseidon=PoseidonBuiltin.SIZE, + segment_arena=SegmentArenaBuiltin.SIZE, ); local builtin_params: BuiltinParams = BuiltinParams( diff --git a/src/starkware/starknet/core/os/constants.cairo b/src/starkware/starknet/core/os/constants.cairo index 0bf7b34d..d7d3add4 100644 --- a/src/starkware/starknet/core/os/constants.cairo +++ b/src/starkware/starknet/core/os/constants.cairo @@ -10,6 +10,8 @@ const DECLARE_VERSION = 2; const TRANSACTION_VERSION = 1; const L1_HANDLER_VERSION = 0; +const SIERRA_ARRAY_LEN_BOUND = 2 ** 32; + // get_selector_from_name('constructor'). const CONSTRUCTOR_ENTRY_POINT_SELECTOR = ( 0x28ffe4ff0f226a9107253e17a904099aa4f63a02a5621de0576e5aa71bc5194 @@ -69,10 +71,15 @@ const TRANSACTION_GAS_COST = (2 * ENTRY_POINT_GAS_COST) + FEE_TRANSFER_GAS_COST 100 * STEP_GAS_COST ); // Syscall gas costs. +const CALL_CONTRACT_GAS_COST = SYSCALL_BASE_GAS_COST + 10 * STEP_GAS_COST + ENTRY_POINT_GAS_COST; +const DEPLOY_GAS_COST = SYSCALL_BASE_GAS_COST + 200 * STEP_GAS_COST + ENTRY_POINT_GAS_COST; +const GET_EXECUTION_INFO_GAS_COST = SYSCALL_BASE_GAS_COST + 10 * STEP_GAS_COST; +const LIBRARY_CALL_GAS_COST = CALL_CONTRACT_GAS_COST; +const REPLACE_CLASS_GAS_COST = SYSCALL_BASE_GAS_COST + 50 * STEP_GAS_COST; const STORAGE_READ_GAS_COST = SYSCALL_BASE_GAS_COST + 50 * STEP_GAS_COST; const STORAGE_WRITE_GAS_COST = SYSCALL_BASE_GAS_COST + 50 * STEP_GAS_COST; -const GET_CALLER_ADDRESS_GAS_COST = SYSCALL_BASE_GAS_COST + 10 * STEP_GAS_COST; const EMIT_EVENT_GAS_COST = SYSCALL_BASE_GAS_COST + 10 * STEP_GAS_COST; +const SEND_MESSAGE_TO_L1_GAS_COST = SYSCALL_BASE_GAS_COST + 50 * STEP_GAS_COST; // Cairo 1.0 error codes. const ERROR_OUT_OF_GAS = 'Out of gas'; diff --git a/src/starkware/starknet/core/os/contract_class/CMakeLists.txt b/src/starkware/starknet/core/os/contract_class/CMakeLists.txt index 476bc026..a6a427d5 100644 --- a/src/starkware/starknet/core/os/contract_class/CMakeLists.txt +++ b/src/starkware/starknet/core/os/contract_class/CMakeLists.txt @@ -3,12 +3,14 @@ python_lib(starknet_os_abi_lib FILES class_hash.py + class_hash_utils.py compiled_class.cairo compiled_class_hash.py + compiled_class_hash_utils.py contract_class.cairo - contract_class.py deprecated_class_hash.py deprecated_compiled_class.cairo + utils.py LIBS cairo_common_lib @@ -20,6 +22,8 @@ python_lib(starknet_os_abi_lib poseidon_utils_lib starknet_abi_lib starknet_contract_class_lib + starknet_definitions_lib + starkware_error_handling_lib starkware_python_utils_lib pip_cachetools ) diff --git a/src/starkware/starknet/core/os/contract_class/class_hash.py b/src/starkware/starknet/core/os/contract_class/class_hash.py index 5c823ebe..a6104fd5 100644 --- a/src/starkware/starknet/core/os/contract_class/class_hash.py +++ b/src/starkware/starknet/core/os/contract_class/class_hash.py @@ -1,46 +1,13 @@ -import contextlib -from contextvars import ContextVar -from enum import Enum, auto -from typing import Any, Optional, Tuple - -import cachetools - from starkware.cairo.common.cairo_function_runner import CairoFunctionRunner -from starkware.starknet.core.os.contract_class.contract_class import ( +from starkware.starknet.core.os.contract_class.class_hash_utils import ( get_contract_class_struct, load_contract_class_cairo_program, ) +from starkware.starknet.core.os.contract_class.utils import ClassHashType, class_hash_cache_ctx_var from starkware.starknet.public.abi import starknet_keccak from starkware.starknet.services.api.contract_class.contract_class import ContractClass -class ClassHashType(Enum): - CONTRACT_CLASS = 0 - COMPILED_CLASS = auto() - DEPRECATED_COMPILED_CLASS = auto() - - -ClassHashCacheKeyType = Tuple[ClassHashType, Any] - -class_hash_cache_ctx_var: ContextVar[ - Optional[cachetools.LRUCache[ClassHashCacheKeyType, int]] -] = ContextVar("class_hash_cache", default=None) - - -@contextlib.contextmanager -def set_class_hash_cache(cache: cachetools.LRUCache[ClassHashCacheKeyType, int]): - """ - Sets a cache to be used by compute_class_hash(). - """ - assert class_hash_cache_ctx_var.get() is None, "Cannot replace an existing class_hash_cache." - - token = class_hash_cache_ctx_var.set(cache) - try: - yield - finally: - class_hash_cache_ctx_var.reset(token) - - def compute_class_hash(contract_class: ContractClass) -> int: cache = class_hash_cache_ctx_var.get() if cache is None: @@ -65,6 +32,7 @@ def _compute_class_hash_inner(contract_class: ContractClass) -> int: runner.run( "starkware.starknet.core.os.contract_class.contract_class.class_hash", poseidon_ptr=runner.poseidon_builtin.base, + range_check_ptr=runner.range_check_builtin.base, contract_class=contract_class_struct, use_full_name=True, verify_secure=False, diff --git a/src/starkware/starknet/core/os/contract_class/contract_class.py b/src/starkware/starknet/core/os/contract_class/class_hash_utils.py similarity index 83% rename from src/starkware/starknet/core/os/contract_class/contract_class.py rename to src/starkware/starknet/core/os/contract_class/class_hash_utils.py index db1179bd..ebaaa1a3 100644 --- a/src/starkware/starknet/core/os/contract_class/contract_class.py +++ b/src/starkware/starknet/core/os/contract_class/class_hash_utils.py @@ -11,11 +11,13 @@ from starkware.cairo.lang.compiler.program import Program from starkware.cairo.lang.compiler.scoped_name import ScopedName from starkware.python.utils import from_bytes, to_bytes +from starkware.starknet.definitions.error_codes import StarknetErrorCode from starkware.starknet.public.abi import starknet_keccak from starkware.starknet.services.api.contract_class.contract_class import ( ContractClass, EntryPointType, ) +from starkware.starkware_utils.error_handling import StarkException CAIRO_FILE = os.path.join(os.path.dirname(__file__), "contract_class.cairo") CONTRACT_CLASS_MODULE = "starkware.starknet.core.os.contract_class.contract_class" @@ -26,9 +28,7 @@ def load_contract_class_cairo_program() -> Program: return compile_cairo_files( [CAIRO_FILE], prime=DEFAULT_PRIME, - main_scope=ScopedName.from_string( - "starkware.starknet.core.os.contract_class.contract_class" - ), + main_scope=ScopedName.from_string(CONTRACT_CLASS_MODULE), ) @@ -71,14 +71,20 @@ def get_contract_class_struct( ) assert isinstance(CONTRACT_CLASS_VERSION_IDENT, ConstDefinition) - contract_class_version_ident_str = to_bytes(CONTRACT_CLASS_VERSION_IDENT.value).decode("ascii") - assert CONTRACT_CLASS_VERSION_IDENT.value == from_bytes( + if CONTRACT_CLASS_VERSION_IDENT.value != from_bytes( ("CONTRACT_CLASS_V" + contract_class.contract_class_version).encode("ascii") - ), ( - "Unexpected contract class version. " - f"Expected {contract_class_version_ident_str}; " - f"got CONTRACT_CLASS_V{contract_class.contract_class_version}." - ) + ): + contract_class_version_ident_str = to_bytes(CONTRACT_CLASS_VERSION_IDENT.value).decode( + "ascii" + ) + raise StarkException( + code=StarknetErrorCode.INVALID_CONTRACT_CLASS_VERSION, + message=( + "Unexpected contract class version. " + f"Expected {contract_class_version_ident_str}; " + f"got CONTRACT_CLASS_V{contract_class.contract_class_version}." + ), + ) external_functions, l1_handlers, constructors = ( _get_contract_entry_points( diff --git a/src/starkware/starknet/core/os/contract_class/compiled_class.cairo b/src/starkware/starknet/core/os/contract_class/compiled_class.cairo index 7f27b765..289f10dc 100644 --- a/src/starkware/starknet/core/os/contract_class/compiled_class.cairo +++ b/src/starkware/starknet/core/os/contract_class/compiled_class.cairo @@ -39,13 +39,6 @@ struct CompiledClass { n_constructors: felt, constructors: CompiledClassEntryPoint*, - // The hinted_compiled_class_hash field should be set to the starknet_keccak of the - // contract program, including its hints. However the OS does not validate that. - // This field may be used by the operator to differentiate between contract classes that - // differ only in the hints. - // This field is included in the hash of the CompiledClass to simplify the implementation. - hinted_compiled_class_hash: felt, - // The length and pointer of the bytecode. bytecode_length: felt, bytecode_ptr: felt*, @@ -86,6 +79,8 @@ func validate_entry_points_inner{range_check_ptr}( func compiled_class_hash{poseidon_ptr: PoseidonBuiltin*}(compiled_class: CompiledClass*) -> ( hash: felt ) { + assert compiled_class.compiled_class_version = COMPILED_CLASS_VERSION; + let hash_state: HashState = hash_init(); with hash_state { hash_update_single(item=compiled_class.compiled_class_version); @@ -106,9 +101,6 @@ func compiled_class_hash{poseidon_ptr: PoseidonBuiltin*}(compiled_class: Compile entry_points=compiled_class.constructors, n_entry_points=compiled_class.n_constructors ); - // Hash hinted_compiled_class_hash. - hash_update_single(item=compiled_class.hinted_compiled_class_hash); - // Hash bytecode. hash_update_with_nested_hash( data_ptr=compiled_class.bytecode_ptr, data_length=compiled_class.bytecode_length @@ -221,8 +213,6 @@ func load_compiled_class_facts_inner{poseidon_ptr: PoseidonBuiltin*, range_check ids.compiled_class = segments.gen_arg(cairo_contract) %} - assert compiled_class.compiled_class_version = COMPILED_CLASS_VERSION; - validate_entry_points( n_entry_points=compiled_class.n_external_functions, entry_points=compiled_class.external_functions, diff --git a/src/starkware/starknet/core/os/contract_class/compiled_class_hash.py b/src/starkware/starknet/core/os/contract_class/compiled_class_hash.py index ed7d2d80..d321b067 100644 --- a/src/starkware/starknet/core/os/contract_class/compiled_class_hash.py +++ b/src/starkware/starknet/core/os/contract_class/compiled_class_hash.py @@ -1,40 +1,11 @@ -import itertools -import json -import os -from functools import lru_cache -from typing import List - from starkware.cairo.common.cairo_function_runner import CairoFunctionRunner -from starkware.cairo.common.structs import CairoStructFactory, CairoStructProxy -from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME -from starkware.cairo.lang.compiler.cairo_compile import compile_cairo_files -from starkware.cairo.lang.compiler.identifier_definition import ConstDefinition -from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager -from starkware.cairo.lang.compiler.program import Program -from starkware.cairo.lang.compiler.scoped_name import ScopedName -from starkware.python.utils import as_non_optional, from_bytes -from starkware.starknet.core.os.contract_class.class_hash import ( - ClassHashType, - class_hash_cache_ctx_var, +from starkware.starknet.core.os.contract_class.compiled_class_hash_utils import ( + get_compiled_class_struct, + load_compiled_class_cairo_program, ) +from starkware.starknet.core.os.contract_class.utils import ClassHashType, class_hash_cache_ctx_var from starkware.starknet.public.abi import starknet_keccak -from starkware.starknet.services.api.contract_class.contract_class import ( - CompiledClass, - EntryPointType, -) - -CAIRO_FILE = os.path.join(os.path.dirname(__file__), "compiled_class.cairo") - - -@lru_cache() -def load_program() -> Program: - return compile_cairo_files( - [CAIRO_FILE], - prime=DEFAULT_PRIME, - main_scope=ScopedName.from_string( - "starkware.starknet.core.os.contract_class.compiled_class" - ), - ) +from starkware.starknet.services.api.contract_class.contract_class import CompiledClass def compute_compiled_class_hash(compiled_class: CompiledClass) -> int: @@ -55,7 +26,7 @@ def compute_compiled_class_hash(compiled_class: CompiledClass) -> int: def _compute_compiled_class_hash_inner(compiled_class: CompiledClass) -> int: - program = load_program() + program = load_compiled_class_cairo_program() compiled_class_struct = get_compiled_class_struct( identifiers=program.identifiers, compiled_class=compiled_class ) @@ -70,92 +41,3 @@ def _compute_compiled_class_hash_inner(compiled_class: CompiledClass) -> int: ) _, class_hash = runner.get_return_values(2) return class_hash - - -def _compute_hinted_compiled_class_hash(compiled_class: CompiledClass) -> int: - """ - Computes the hash of the compiled class, including hints. - """ - input_to_hash = dict(program=compiled_class.program.dump()) - return starknet_keccak(data=json.dumps(input_to_hash, sort_keys=True).encode()) - - -def _get_contract_entry_points( - structs: CairoStructProxy, - compiled_class: CompiledClass, - entry_point_type: EntryPointType, -) -> List[CairoStructProxy]: - # Check validity of entry points. - program_length = len(compiled_class.program.data) - entry_points = compiled_class.entry_points_by_type[entry_point_type] - for entry_point in entry_points: - assert ( - 0 <= entry_point.offset < program_length - ), f"Invalid entry point offset {entry_point.offset}, len(program_data)={program_length}." - - return [ - structs.CompiledClassEntryPoint( - selector=entry_point.selector, - offset=entry_point.offset, - n_builtins=len(as_non_optional(entry_point.builtins)), - builtin_list=[ - from_bytes(builtin.encode("ascii")) - for builtin in as_non_optional(entry_point.builtins) - ], - ) - for entry_point in entry_points - ] - - -def get_compiled_class_struct( - identifiers: IdentifierManager, compiled_class: CompiledClass -) -> CairoStructProxy: - """ - Returns the serialization of a compiled class as a list of field elements. - """ - structs = CairoStructFactory( - identifiers=identifiers, - additional_imports=[ - "starkware.starknet.core.os.contract_class.compiled_class.CompiledClass", - "starkware.starknet.core.os.contract_class.compiled_class.CompiledClassEntryPoint", - ], - ).structs - - COMPILED_CLASS_VERSION_IDENT = identifiers.get_by_full_name( - ScopedName.from_string( - "starkware.starknet.core.os.contract_class.compiled_class.COMPILED_CLASS_VERSION" - ) - ) - assert isinstance(COMPILED_CLASS_VERSION_IDENT, ConstDefinition) - - external_functions, l1_handlers, constructors = ( - _get_contract_entry_points( - structs=structs, - compiled_class=compiled_class, - entry_point_type=entry_point_type, - ) - for entry_point_type in ( - EntryPointType.EXTERNAL, - EntryPointType.L1_HANDLER, - EntryPointType.CONSTRUCTOR, - ) - ) - flat_external_functions, flat_l1_handlers, flat_constructors = ( - list(itertools.chain.from_iterable(entry_points)) - for entry_points in (external_functions, l1_handlers, constructors) - ) - - return structs.CompiledClass( - compiled_class_version=COMPILED_CLASS_VERSION_IDENT.value, - n_external_functions=len(external_functions), - external_functions=flat_external_functions, - n_l1_handlers=len(l1_handlers), - l1_handlers=flat_l1_handlers, - n_constructors=len(constructors), - constructors=flat_constructors, - hinted_compiled_class_hash=_compute_hinted_compiled_class_hash( - compiled_class=compiled_class - ), - bytecode_length=len(compiled_class.program.data), - bytecode_ptr=compiled_class.program.data, - ) diff --git a/src/starkware/starknet/core/os/contract_class/compiled_class_hash_utils.py b/src/starkware/starknet/core/os/contract_class/compiled_class_hash_utils.py new file mode 100644 index 00000000..d5274b25 --- /dev/null +++ b/src/starkware/starknet/core/os/contract_class/compiled_class_hash_utils.py @@ -0,0 +1,113 @@ +import itertools +import os +from functools import lru_cache +from typing import List + +from starkware.cairo.common.structs import CairoStructFactory, CairoStructProxy +from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME +from starkware.cairo.lang.compiler.cairo_compile import compile_cairo_files +from starkware.cairo.lang.compiler.identifier_definition import ConstDefinition +from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager +from starkware.cairo.lang.compiler.program import Program +from starkware.cairo.lang.compiler.scoped_name import ScopedName +from starkware.python.utils import as_non_optional, from_bytes +from starkware.starknet.services.api.contract_class.contract_class import ( + CompiledClass, + EntryPointType, +) + +CAIRO_FILE = os.path.join(os.path.dirname(__file__), "compiled_class.cairo") +COMPILED_CLASS_MODULE = "starkware.starknet.core.os.contract_class.compiled_class" + + +@lru_cache() +def load_compiled_class_cairo_program() -> Program: + return compile_cairo_files( + [CAIRO_FILE], + prime=DEFAULT_PRIME, + main_scope=ScopedName.from_string(COMPILED_CLASS_MODULE), + ) + + +@lru_cache() +def _get_empty_compiled_class_structs() -> CairoStructProxy: + program = load_compiled_class_cairo_program() + return CairoStructFactory( + identifiers=program.identifiers, + additional_imports=[ + f"{COMPILED_CLASS_MODULE}.CompiledClass", + f"{COMPILED_CLASS_MODULE}.CompiledClassEntryPoint", + ], + ).structs + + +def _get_contract_entry_points( + structs: CairoStructProxy, + compiled_class: CompiledClass, + entry_point_type: EntryPointType, +) -> List[CairoStructProxy]: + # Check validity of entry points. + program_length = len(compiled_class.bytecode) + entry_points = compiled_class.entry_points_by_type[entry_point_type] + for entry_point in entry_points: + assert ( + 0 <= entry_point.offset < program_length + ), f"Invalid entry point offset {entry_point.offset}, len(program_data)={program_length}." + + return [ + structs.CompiledClassEntryPoint( + selector=entry_point.selector, + offset=entry_point.offset, + n_builtins=len(as_non_optional(entry_point.builtins)), + builtin_list=[ + from_bytes(builtin.encode("ascii")) + for builtin in as_non_optional(entry_point.builtins) + ], + ) + for entry_point in entry_points + ] + + +def get_compiled_class_struct( + identifiers: IdentifierManager, compiled_class: CompiledClass +) -> CairoStructProxy: + """ + Returns the serialization of a compiled class as a list of field elements. + """ + structs = _get_empty_compiled_class_structs() + + COMPILED_CLASS_VERSION_IDENT = identifiers.get_by_full_name( + ScopedName.from_string( + "starkware.starknet.core.os.contract_class.compiled_class.COMPILED_CLASS_VERSION" + ) + ) + assert isinstance(COMPILED_CLASS_VERSION_IDENT, ConstDefinition) + + external_functions, l1_handlers, constructors = ( + _get_contract_entry_points( + structs=structs, + compiled_class=compiled_class, + entry_point_type=entry_point_type, + ) + for entry_point_type in ( + EntryPointType.EXTERNAL, + EntryPointType.L1_HANDLER, + EntryPointType.CONSTRUCTOR, + ) + ) + flat_external_functions, flat_l1_handlers, flat_constructors = ( + list(itertools.chain.from_iterable(entry_points)) + for entry_points in (external_functions, l1_handlers, constructors) + ) + + return structs.CompiledClass( + compiled_class_version=COMPILED_CLASS_VERSION_IDENT.value, + n_external_functions=len(external_functions), + external_functions=flat_external_functions, + n_l1_handlers=len(l1_handlers), + l1_handlers=flat_l1_handlers, + n_constructors=len(constructors), + constructors=flat_constructors, + bytecode_length=len(compiled_class.bytecode), + bytecode_ptr=compiled_class.bytecode, + ) diff --git a/src/starkware/starknet/core/os/contract_class/contract_class.cairo b/src/starkware/starknet/core/os/contract_class/contract_class.cairo index 485e4f8e..9a8763b4 100644 --- a/src/starkware/starknet/core/os/contract_class/contract_class.cairo +++ b/src/starkware/starknet/core/os/contract_class/contract_class.cairo @@ -6,6 +6,7 @@ from starkware.cairo.common.hash_state_poseidon import ( hash_update_single, hash_update_with_nested_hash, ) +from starkware.starknet.common.storage import normalize_address const CONTRACT_CLASS_VERSION = 'CONTRACT_CLASS_V0.1.0'; @@ -39,7 +40,11 @@ struct ContractClass { sierra_program_ptr: felt*, } -func class_hash{poseidon_ptr: PoseidonBuiltin*}(contract_class: ContractClass*) -> (hash: felt) { +func class_hash{poseidon_ptr: PoseidonBuiltin*, range_check_ptr: felt}( + contract_class: ContractClass* +) -> (hash: felt) { + assert contract_class.contract_class_version = CONTRACT_CLASS_VERSION; + let hash_state: HashState = hash_init(); with hash_state { hash_update_single(item=contract_class.contract_class_version); @@ -70,8 +75,9 @@ func class_hash{poseidon_ptr: PoseidonBuiltin*}(contract_class: ContractClass*) data_ptr=contract_class.sierra_program_ptr, data_length=contract_class.sierra_program_length, ); - - let hash: felt = hash_finalize(hash_state=hash_state); } - return (hash=hash); + + let hash: felt = hash_finalize(hash_state=hash_state); + let (normalized_hash) = normalize_address(addr=hash); + return (hash=normalized_hash); } diff --git a/src/starkware/starknet/core/os/contract_class/deprecated_class_hash.py b/src/starkware/starknet/core/os/contract_class/deprecated_class_hash.py index 5df16685..1b5922e0 100644 --- a/src/starkware/starknet/core/os/contract_class/deprecated_class_hash.py +++ b/src/starkware/starknet/core/os/contract_class/deprecated_class_hash.py @@ -17,10 +17,7 @@ from starkware.cairo.lang.compiler.scoped_name import ScopedName from starkware.cairo.lang.vm.crypto import pedersen_hash from starkware.python.utils import from_bytes -from starkware.starknet.core.os.contract_class.class_hash import ( - ClassHashType, - class_hash_cache_ctx_var, -) +from starkware.starknet.core.os.contract_class.utils import ClassHashType, class_hash_cache_ctx_var from starkware.starknet.public.abi import starknet_keccak from starkware.starknet.services.api.contract_class.contract_class import ( DeprecatedCompiledClass, diff --git a/src/starkware/starknet/core/os/contract_class/utils.py b/src/starkware/starknet/core/os/contract_class/utils.py new file mode 100644 index 00000000..41a741fe --- /dev/null +++ b/src/starkware/starknet/core/os/contract_class/utils.py @@ -0,0 +1,33 @@ +import contextlib +from contextvars import ContextVar +from enum import Enum, auto +from typing import Any, Optional, Tuple + +import cachetools + + +class ClassHashType(Enum): + CONTRACT_CLASS = 0 + COMPILED_CLASS = auto() + DEPRECATED_COMPILED_CLASS = auto() + + +ClassHashCacheKeyType = Tuple[ClassHashType, Any] + +class_hash_cache_ctx_var: ContextVar[ + Optional[cachetools.LRUCache[ClassHashCacheKeyType, int]] +] = ContextVar("class_hash_cache", default=None) + + +@contextlib.contextmanager +def set_class_hash_cache(cache: cachetools.LRUCache[ClassHashCacheKeyType, int]): + """ + Sets a cache to be used by compute_class_hash(). + """ + assert class_hash_cache_ctx_var.get() is None, "Cannot replace an existing class_hash_cache." + + token = class_hash_cache_ctx_var.set(cache) + try: + yield + finally: + class_hash_cache_ctx_var.reset(token) diff --git a/src/starkware/starknet/core/os/execution/deprecated_execute_entry_point.cairo b/src/starkware/starknet/core/os/execution/deprecated_execute_entry_point.cairo index b71f06d6..b42543fa 100644 --- a/src/starkware/starknet/core/os/execution/deprecated_execute_entry_point.cairo +++ b/src/starkware/starknet/core/os/execution/deprecated_execute_entry_point.cairo @@ -59,7 +59,7 @@ func get_entry_point_offset{range_check_ptr}( array_ptr=cast(entry_points, felt*), elm_size=DeprecatedContractEntryPoint.SIZE, n_elms=n_entry_points, - key=execution_context.selector, + key=execution_context.execution_info.selector, ); if (success != 0) { return (entry_point_offset=entry_point_desc.offset); @@ -90,7 +90,7 @@ func call_execute_deprecated_syscalls{ // Executes an entry point in a contract. // The contract entry point is selected based on execution_context.entry_point_type -// and execution_context.selector. +// and execution_context.execution_info.selector. // // Arguments: // block_context - a global context that is fixed throughout the block. @@ -120,11 +120,10 @@ func deprecated_execute_entry_point{ compiled_class=compiled_class, execution_context=execution_context ); - %{ execution_helper.enter_call() %} if (entry_point_offset == NOP_ENTRY_POINT_OFFSET) { // Assert that there is no call data in the case of NOP entry point. assert execution_context.calldata_size = 0; - %{ execution_helper.exit_call() %} + %{ execution_helper.skip_call() %} return (retdata_size=0, retdata=cast(0, felt*)); } @@ -152,13 +151,20 @@ func deprecated_execute_entry_point{ ); // Use tempvar to pass arguments to contract_entry_point(). - tempvar selector = execution_context.selector; + tempvar selector = execution_context.execution_info.selector; tempvar context = os_context; tempvar calldata_size = execution_context.calldata_size; tempvar calldata = execution_context.calldata; + + %{ + execution_helper.enter_call( + execution_info_ptr=ids.execution_context.execution_info.address_) + %} %{ vm_enter_scope({'syscall_handler': deprecated_syscall_handler}) %} call abs contract_entry_point; %{ vm_exit_scope() %} + %{ execution_helper.exit_call() %} + // Retrieve returned_builtin_ptrs_subset. // Note that returned_builtin_ptrs_subset cannot be set in a hint because doing so will allow a // malicious prover to lie about the storage changes of a valid contract. @@ -201,6 +207,9 @@ func deprecated_execute_entry_point{ n_builtins=n_builtins, ); + // Validate that segment_arena builtin was not used. + assert builtin_ptrs.segment_arena = return_builtin_ptrs.segment_arena; + let syscall_end = cast([returned_builtin_ptrs_subset - 1], felt*); let builtin_ptrs = return_builtin_ptrs; @@ -211,7 +220,6 @@ func deprecated_execute_entry_point{ syscall_ptr=syscall_ptr, ); - %{ execution_helper.exit_call() %} return (retdata_size=retdata_size, retdata=retdata); } diff --git a/src/starkware/starknet/core/os/execution/deprecated_execute_syscalls.cairo b/src/starkware/starknet/core/os/execution/deprecated_execute_syscalls.cairo index 3fa3f917..6280e6cb 100644 --- a/src/starkware/starknet/core/os/execution/deprecated_execute_syscalls.cairo +++ b/src/starkware/starknet/core/os/execution/deprecated_execute_syscalls.cairo @@ -6,6 +6,7 @@ from starkware.cairo.common.math import assert_not_zero from starkware.cairo.common.memcpy import memcpy from starkware.cairo.common.segments import relocate_segment from starkware.starknet.common.constants import ORIGIN_ADDRESS +from starkware.starknet.common.new_syscalls import ExecutionInfo from starkware.starknet.common.syscalls import ( CALL_CONTRACT_SELECTOR, DELEGATE_CALL_SELECTOR, @@ -49,7 +50,6 @@ from starkware.starknet.common.syscalls import ( SendMessageToL1SysCall, StorageRead, StorageWrite, - TxInfo, ) from starkware.starknet.core.os.block_context import BlockContext from starkware.starknet.core.os.builtins import BuiltinPointers @@ -118,7 +118,7 @@ func execute_contract_call_syscall{ contract_address: felt, caller_address: felt, entry_point_type: felt, - original_tx_info: TxInfo*, + caller_execution_context: ExecutionContext*, syscall_ptr: CallContract*, ) { alloc_locals; @@ -129,15 +129,20 @@ func execute_contract_call_syscall{ key=call_req.contract_address ); + tempvar caller_execution_info = caller_execution_context.execution_info; local execution_context: ExecutionContext* = new ExecutionContext( entry_point_type=entry_point_type, - caller_address=caller_address, - contract_address=contract_address, class_hash=state_entry.class_hash, - selector=call_req.function_selector, calldata_size=call_req.calldata_size, calldata=call_req.calldata, - original_tx_info=original_tx_info, + execution_info=new ExecutionInfo( + block_info=caller_execution_info.block_info, + tx_info=caller_execution_info.tx_info, + caller_address=caller_address, + contract_address=contract_address, + selector=call_req.function_selector, + ), + deprecated_tx_info=caller_execution_context.deprecated_tx_info, ); return contract_call_helper( @@ -164,15 +169,20 @@ func execute_library_call_syscall{ let call_req = syscall_ptr.request; + tempvar caller_execution_info = caller_execution_context.execution_info; local execution_context: ExecutionContext* = new ExecutionContext( entry_point_type=entry_point_type, - caller_address=caller_execution_context.caller_address, - contract_address=caller_execution_context.contract_address, class_hash=call_req.class_hash, - selector=call_req.function_selector, calldata_size=call_req.calldata_size, calldata=call_req.calldata, - original_tx_info=caller_execution_context.original_tx_info, + execution_info=new ExecutionInfo( + block_info=caller_execution_info.block_info, + tx_info=caller_execution_info.tx_info, + caller_address=caller_execution_info.caller_address, + contract_address=caller_execution_info.contract_address, + selector=call_req.function_selector, + ), + deprecated_tx_info=caller_execution_context.deprecated_tx_info, ); return contract_call_helper( @@ -189,13 +199,15 @@ func execute_deploy_syscall{ contract_class_changes: DictAccess*, outputs: OsCarriedOutputs*, }(block_context: BlockContext*, caller_execution_context: ExecutionContext*, syscall_ptr: Deploy*) { + alloc_locals; + local caller_execution_info: ExecutionInfo* = caller_execution_context.execution_info; + local caller_address = caller_execution_info.contract_address; + let request = syscall_ptr.request; // Verify deploy_from_zero is either 0 (FALSE) or 1 (TRUE). assert request.deploy_from_zero * (request.deploy_from_zero - 1) = 0; // Set deployer_address to 0 if request.deploy_from_zero is TRUE. - let deployer_address = ( - (1 - request.deploy_from_zero) * caller_execution_context.contract_address - ); + let deployer_address = (1 - request.deploy_from_zero) * caller_address; let hash_ptr = builtin_ptrs.pedersen; with hash_ptr { @@ -214,6 +226,7 @@ func execute_deploy_syscall{ bitwise=builtin_ptrs.bitwise, ec_op=builtin_ptrs.ec_op, poseidon=builtin_ptrs.poseidon, + segment_arena=builtin_ptrs.segment_arena, ); // Fill the syscall response, before contract_address is revoked. @@ -225,13 +238,17 @@ func execute_deploy_syscall{ tempvar constructor_execution_context = new ExecutionContext( entry_point_type=ENTRY_POINT_TYPE_CONSTRUCTOR, - caller_address=caller_execution_context.contract_address, - contract_address=contract_address, class_hash=request.class_hash, - selector=CONSTRUCTOR_ENTRY_POINT_SELECTOR, calldata_size=request.constructor_calldata_size, calldata=request.constructor_calldata, - original_tx_info=caller_execution_context.original_tx_info, + execution_info=new ExecutionInfo( + block_info=caller_execution_info.block_info, + tx_info=caller_execution_info.tx_info, + caller_address=caller_address, + contract_address=contract_address, + selector=CONSTRUCTOR_ENTRY_POINT_SELECTOR, + ), + deprecated_tx_info=caller_execution_context.deprecated_tx_info, ); // Set enough gas for this call to succeed; see the comment in 'contract_call_helper'. @@ -377,7 +394,7 @@ func execute_deprecated_syscalls{ if (selector == STORAGE_READ_SELECTOR) { execute_storage_read( - contract_address=execution_context.contract_address, + contract_address=execution_context.execution_info.contract_address, syscall_ptr=cast(syscall_ptr, StorageRead*), ); return execute_deprecated_syscalls( @@ -390,7 +407,7 @@ func execute_deprecated_syscalls{ if (selector == STORAGE_WRITE_SELECTOR) { execute_storage_write( - contract_address=execution_context.contract_address, + contract_address=execution_context.execution_info.contract_address, syscall_ptr=cast(syscall_ptr, StorageWrite*), ); return execute_deprecated_syscalls( @@ -416,9 +433,9 @@ func execute_deprecated_syscalls{ execute_contract_call_syscall( block_context=block_context, contract_address=call_contract_syscall.request.contract_address, - caller_address=execution_context.contract_address, + caller_address=execution_context.execution_info.contract_address, entry_point_type=ENTRY_POINT_TYPE_EXTERNAL, - original_tx_info=execution_context.original_tx_info, + caller_execution_context=execution_context, syscall_ptr=call_contract_syscall, ); return execute_deprecated_syscalls( @@ -461,7 +478,7 @@ func execute_deprecated_syscalls{ if (selector == GET_TX_INFO_SELECTOR) { assert cast(syscall_ptr, GetTxInfo*).response = GetTxInfoResponse( - tx_info=execution_context.original_tx_info + tx_info=execution_context.deprecated_tx_info ); return execute_deprecated_syscalls( block_context=block_context, @@ -473,7 +490,7 @@ func execute_deprecated_syscalls{ if (selector == GET_CALLER_ADDRESS_SELECTOR) { assert [cast(syscall_ptr, GetCallerAddress*)].response = GetCallerAddressResponse( - caller_address=execution_context.caller_address + caller_address=execution_context.execution_info.caller_address ); return execute_deprecated_syscalls( block_context=block_context, @@ -485,7 +502,7 @@ func execute_deprecated_syscalls{ if (selector == GET_SEQUENCER_ADDRESS_SELECTOR) { assert [cast(syscall_ptr, GetSequencerAddress*)].response = GetSequencerAddressResponse( - sequencer_address=block_context.sequencer_address + sequencer_address=block_context.block_info.sequencer_address ); return execute_deprecated_syscalls( block_context=block_context, @@ -497,7 +514,7 @@ func execute_deprecated_syscalls{ if (selector == GET_CONTRACT_ADDRESS_SELECTOR) { assert [cast(syscall_ptr, GetContractAddress*)].response = GetContractAddressResponse( - contract_address=execution_context.contract_address + contract_address=execution_context.execution_info.contract_address ); return execute_deprecated_syscalls( block_context=block_context, @@ -532,9 +549,9 @@ func execute_deprecated_syscalls{ } if (selector == GET_TX_SIGNATURE_SELECTOR) { - tempvar original_tx_info: TxInfo* = execution_context.original_tx_info; + tempvar deprecated_tx_info = execution_context.deprecated_tx_info; assert [cast(syscall_ptr, GetTxSignature*)].response = GetTxSignatureResponse( - signature_len=original_tx_info.signature_len, signature=original_tx_info.signature + signature_len=deprecated_tx_info.signature_len, signature=deprecated_tx_info.signature ); return execute_deprecated_syscalls( block_context=block_context, @@ -560,12 +577,13 @@ func execute_deprecated_syscalls{ // DEPRECATED. if (selector == DELEGATE_CALL_SELECTOR) { + tempvar execution_info = execution_context.execution_info; execute_contract_call_syscall( block_context=block_context, - contract_address=execution_context.contract_address, - caller_address=execution_context.caller_address, + contract_address=execution_info.contract_address, + caller_address=execution_info.caller_address, entry_point_type=ENTRY_POINT_TYPE_EXTERNAL, - original_tx_info=execution_context.original_tx_info, + caller_execution_context=execution_context, syscall_ptr=cast(syscall_ptr, CallContract*), ); return execute_deprecated_syscalls( @@ -578,12 +596,13 @@ func execute_deprecated_syscalls{ // DEPRECATED. if (selector == DELEGATE_L1_HANDLER_SELECTOR) { + tempvar execution_info = execution_context.execution_info; execute_contract_call_syscall( block_context=block_context, - contract_address=execution_context.contract_address, - caller_address=execution_context.caller_address, + contract_address=execution_info.contract_address, + caller_address=execution_info.caller_address, entry_point_type=ENTRY_POINT_TYPE_L1_HANDLER, - original_tx_info=execution_context.original_tx_info, + caller_execution_context=execution_context, syscall_ptr=cast(syscall_ptr, CallContract*), ); return execute_deprecated_syscalls( @@ -596,7 +615,7 @@ func execute_deprecated_syscalls{ if (selector == REPLACE_CLASS_SELECTOR) { execute_replace_class( - contract_address=execution_context.contract_address, + contract_address=execution_context.execution_info.contract_address, syscall_ptr=cast(syscall_ptr, ReplaceClass*), ); return execute_deprecated_syscalls( @@ -613,7 +632,7 @@ func execute_deprecated_syscalls{ let syscall = [cast(syscall_ptr, SendMessageToL1SysCall*)]; assert [outputs.messages_to_l1] = MessageToL1Header( - from_address=execution_context.contract_address, + from_address=execution_context.execution_info.contract_address, to_address=syscall.to_address, payload_size=syscall.payload_size, ); @@ -650,7 +669,7 @@ func deploy_contract{ }(block_context: BlockContext*, constructor_execution_context: ExecutionContext*) { alloc_locals; - local contract_address = constructor_execution_context.contract_address; + local contract_address = constructor_execution_context.execution_info.contract_address; // Assert that we don't deploy to ORIGIN_ADDRESS. assert_not_zero(contract_address - ORIGIN_ADDRESS); diff --git a/src/starkware/starknet/core/os/execution/execute_entry_point.cairo b/src/starkware/starknet/core/os/execution/execute_entry_point.cairo index b884d756..5f4df5b2 100644 --- a/src/starkware/starknet/core/os/execution/execute_entry_point.cairo +++ b/src/starkware/starknet/core/os/execution/execute_entry_point.cairo @@ -7,7 +7,12 @@ from starkware.cairo.common.dict_access import DictAccess from starkware.cairo.common.find_element import find_element, search_sorted from starkware.cairo.common.math import assert_not_zero from starkware.cairo.common.registers import get_ap -from starkware.starknet.common.syscalls import TxInfo +from starkware.starknet.builtins.segment_arena.segment_arena import ( + SegmentArenaBuiltin, + validate_segment_arena, +) +from starkware.starknet.common.new_syscalls import ExecutionInfo +from starkware.starknet.common.syscalls import TxInfo as DeprecatedTxInfo from starkware.starknet.core.os.block_context import BlockContext from starkware.starknet.core.os.builtins import BuiltinEncodings, BuiltinParams, BuiltinPointers from starkware.starknet.core.os.constants import ( @@ -28,17 +33,14 @@ from starkware.starknet.core.os.output import OsCarriedOutputs // Represents the execution context during the execution of contract code. struct ExecutionContext { entry_point_type: felt, - caller_address: felt, - // The execution is done in the context of the contract at 'contract_address'. - // This address controls the storage being used, messages sent to L1, calling contracts, etc. - contract_address: felt, // The hash of the contract class to execute. class_hash: felt, - selector: felt, calldata_size: felt, calldata: felt*, + // Additional information about the execution. + execution_info: ExecutionInfo*, // Information about the transaction that triggered the execution. - original_tx_info: TxInfo*, + deprecated_tx_info: DeprecatedTxInfo*, } // Represents the arguments pushed to the stack before calling an entry point. @@ -106,7 +108,7 @@ func get_entry_point{range_check_ptr}( array_ptr=cast(entry_points, felt*), elm_size=CompiledClassEntryPoint.SIZE, n_elms=n_entry_points, - key=execution_context.selector, + key=execution_context.execution_info.selector, ); if (success != 0) { return (entry_point=entry_point_desc); @@ -121,7 +123,7 @@ func get_entry_point{range_check_ptr}( // Executes an entry point in a contract. // The contract entry point is selected based on execution_context.entry_point_type -// and execution_context.selector. +// and execution_context.execution_info.selector. // // Arguments: // block_context - a global context that is fixed throughout the block. @@ -137,11 +139,10 @@ func execute_entry_point{ retdata_size: felt, retdata: felt* ) { alloc_locals; - %{ execution_helper.enter_call() %} - let (compiled_class_hash: felt) = dict_read{dict_ptr=contract_class_changes}( key=execution_context.class_hash ); + // The key must be at offset 0. static_assert CompiledClassFact.hash == 0; let (compiled_class_fact: CompiledClassFact*) = find_element( @@ -158,7 +159,7 @@ func execute_entry_point{ if (compiled_class_entry_point == cast(0, CompiledClassEntryPoint*)) { // Assert that there is no call data in the case of NOP entry point. assert execution_context.calldata_size = 0; - %{ execution_helper.exit_call() %} + %{ execution_helper.skip_call() %} return (retdata_size=0, retdata=cast(0, felt*)); } @@ -177,6 +178,8 @@ func execute_entry_point{ %} assert [os_context] = cast(syscall_ptr, felt); + let builtin_ptrs: BuiltinPointers* = prepare_builtin_ptrs_for_execute(builtin_ptrs); + let n_builtins = BuiltinEncodings.SIZE; local builtin_params: BuiltinParams* = block_context.builtin_params; local calldata_start: felt* = execution_context.calldata; @@ -202,9 +205,14 @@ func execute_entry_point{ ); static_assert ap == current_ap + EntryPointCallArguments.SIZE; + %{ + execution_helper.enter_call( + execution_info_ptr=ids.execution_context.execution_info.address_) + %} %{ vm_enter_scope({'syscall_handler': syscall_handler}) %} call abs contract_entry_point; %{ vm_exit_scope() %} + // Retrieve returned_builtin_ptrs_subset. // Note that returned_builtin_ptrs_subset cannot be set in a hint because doing so will allow a // malicious prover to lie about the storage changes of a valid contract. @@ -218,6 +226,7 @@ func execute_entry_point{ syscall_handler.validate_and_discard_syscall_ptr( syscall_ptr_end=ids.entry_point_return_values.syscall_ptr ) + execution_helper.exit_call() %} // Check that the execution was successful. @@ -261,6 +270,14 @@ func execute_entry_point{ n_builtins=n_builtins, ); + // Validate the segment_arena builtin. + // Note that as the segment_arena pointer points to the first unused element, we need to + // take segment_arena[-1] to get the actual values. + tempvar prev_segment_arena = &builtin_ptrs.segment_arena[-1]; + tempvar current_segment_arena = &return_builtin_ptrs.segment_arena[-1]; + assert prev_segment_arena.infos = current_segment_arena.infos; + validate_segment_arena(segment_arena=current_segment_arena); + let builtin_ptrs = return_builtin_ptrs; with syscall_ptr { call_execute_syscalls( @@ -270,6 +287,32 @@ func execute_entry_point{ ); } - %{ execution_helper.exit_call() %} return (retdata_size=retdata_end - retdata_start, retdata=retdata_start); } + +// Prepares the builtin pointer for the execution of an entry point. +// In particular, restarts the SegmentArenaBuiltin struct if it was previously used. +func prepare_builtin_ptrs_for_execute(builtin_ptrs: BuiltinPointers*) -> BuiltinPointers* { + tempvar segment_arena_ptr = builtin_ptrs.segment_arena; + tempvar prev_segment_arena = &segment_arena_ptr[-1]; + + // If no segment was allocated, we don't need to restart the struct. + tempvar prev_n_segments = prev_segment_arena.n_segments; + if (prev_n_segments == 0) { + return builtin_ptrs; + } + + assert segment_arena_ptr[0] = SegmentArenaBuiltin( + infos=&prev_segment_arena.infos[prev_n_segments], n_segments=0, n_finalized=0 + ); + let segment_arena_ptr = &segment_arena_ptr[1]; + return new BuiltinPointers( + pedersen=builtin_ptrs.pedersen, + range_check=builtin_ptrs.range_check, + ecdsa=builtin_ptrs.ecdsa, + bitwise=builtin_ptrs.bitwise, + ec_op=builtin_ptrs.ec_op, + poseidon=builtin_ptrs.poseidon, + segment_arena=segment_arena_ptr, + ); +} diff --git a/src/starkware/starknet/core/os/execution/execute_syscalls.cairo b/src/starkware/starknet/core/os/execution/execute_syscalls.cairo index 9157ab6c..f3767b8e 100644 --- a/src/starkware/starknet/core/os/execution/execute_syscalls.cairo +++ b/src/starkware/starknet/core/os/execution/execute_syscalls.cairo @@ -1,14 +1,21 @@ from starkware.cairo.common.dict import dict_read, dict_update from starkware.cairo.common.dict_access import DictAccess from starkware.cairo.common.math import assert_lt, assert_nn, assert_not_zero +from starkware.cairo.common.segments import relocate_segment from starkware.starknet.common.new_syscalls import ( + CALL_CONTRACT_SELECTOR, EMIT_EVENT_SELECTOR, - GET_CALLER_ADDRESS_SELECTOR, + GET_EXECUTION_INFO_SELECTOR, + LIBRARY_CALL_SELECTOR, STORAGE_READ_SELECTOR, STORAGE_WRITE_SELECTOR, + CallContractRequest, + CallContractResponse, EmitEventRequest, + ExecutionInfo, FailureReason, - GetCallerAddressResponse, + GetExecutionInfoResponse, + LibraryCallRequest, RequestHeader, ResponseHeader, StorageReadRequest, @@ -18,13 +25,19 @@ from starkware.starknet.common.new_syscalls import ( from starkware.starknet.core.os.block_context import BlockContext from starkware.starknet.core.os.builtins import BuiltinPointers from starkware.starknet.core.os.constants import ( + CALL_CONTRACT_GAS_COST, EMIT_EVENT_GAS_COST, + ENTRY_POINT_TYPE_EXTERNAL, ERROR_OUT_OF_GAS, - GET_CALLER_ADDRESS_GAS_COST, + GET_EXECUTION_INFO_GAS_COST, + LIBRARY_CALL_GAS_COST, STORAGE_READ_GAS_COST, STORAGE_WRITE_GAS_COST, SYSCALL_BASE_GAS_COST, ) +from starkware.starknet.core.os.execution.deprecated_execute_entry_point import ( + select_execute_entry_point_func, +) from starkware.starknet.core.os.execution.execute_entry_point import ExecutionContext from starkware.starknet.core.os.output import OsCarriedOutputs from starkware.starknet.core.os.state import StateEntry @@ -50,7 +63,7 @@ func execute_syscalls{ tempvar selector = [syscall_ptr]; if (selector == STORAGE_READ_SELECTOR) { - execute_storage_read(contract_address=execution_context.contract_address); + execute_storage_read(contract_address=execution_context.execution_info.contract_address); return execute_syscalls( block_context=block_context, execution_context=execution_context, @@ -59,7 +72,7 @@ func execute_syscalls{ } if (selector == STORAGE_WRITE_SELECTOR) { - execute_storage_write(contract_address=execution_context.contract_address); + execute_storage_write(contract_address=execution_context.execution_info.contract_address); return execute_syscalls( block_context=block_context, execution_context=execution_context, @@ -67,8 +80,8 @@ func execute_syscalls{ ); } - if (selector == EMIT_EVENT_SELECTOR) { - reduce_syscall_gas(gas_cost=EMIT_EVENT_GAS_COST, request_size=EmitEventRequest.SIZE); + if (selector == GET_EXECUTION_INFO_SELECTOR) { + execute_get_execution_info(execution_info=execution_context.execution_info); return execute_syscalls( block_context=block_context, execution_context=execution_context, @@ -76,8 +89,32 @@ func execute_syscalls{ ); } - assert selector = GET_CALLER_ADDRESS_SELECTOR; - execute_get_caller_address(caller_address=execution_context.caller_address); + if (selector == CALL_CONTRACT_SELECTOR) { + execute_call_contract( + block_context=block_context, caller_execution_context=execution_context + ); + return execute_syscalls( + block_context=block_context, + execution_context=execution_context, + syscall_ptr_end=syscall_ptr_end, + ); + } + + if (selector == LIBRARY_CALL_SELECTOR) { + execute_library_call( + block_context=block_context, caller_execution_context=execution_context + ); + return execute_syscalls( + block_context=block_context, + execution_context=execution_context, + syscall_ptr_end=syscall_ptr_end, + ); + } + + assert selector = EMIT_EVENT_SELECTOR; + reduce_syscall_gas_and_write_response_header( + total_gas_cost=EMIT_EVENT_GAS_COST, request_struct_size=EmitEventRequest.SIZE + ); return execute_syscalls( block_context=block_context, execution_context=execution_context, @@ -85,6 +122,144 @@ func execute_syscalls{ ); } +// Executes a syscall that calls another contract. +func execute_call_contract{ + range_check_ptr, + syscall_ptr: felt*, + builtin_ptrs: BuiltinPointers*, + contract_state_changes: DictAccess*, + contract_class_changes: DictAccess*, + outputs: OsCarriedOutputs*, +}(block_context: BlockContext*, caller_execution_context: ExecutionContext*) { + let request = cast(syscall_ptr + RequestHeader.SIZE, CallContractRequest*); + let (success, remaining_gas) = reduce_syscall_base_gas( + specific_base_gas_cost=CALL_CONTRACT_GAS_COST, request_struct_size=CallContractRequest.SIZE + ); + if (success == 0) { + // Not enough gas to execute the syscall. + return (); + } + + tempvar contract_address = request.contract_address; + let (state_entry: StateEntry*) = dict_read{dict_ptr=contract_state_changes}( + key=contract_address + ); + + // Prepare execution context. + tempvar calldata_start = request.calldata_start; + tempvar caller_execution_info = caller_execution_context.execution_info; + tempvar execution_context: ExecutionContext* = new ExecutionContext( + entry_point_type=ENTRY_POINT_TYPE_EXTERNAL, + class_hash=state_entry.class_hash, + calldata_size=request.calldata_end - calldata_start, + calldata=calldata_start, + execution_info=new ExecutionInfo( + block_info=caller_execution_info.block_info, + tx_info=caller_execution_info.tx_info, + caller_address=caller_execution_info.contract_address, + contract_address=contract_address, + selector=request.selector, + ), + deprecated_tx_info=caller_execution_context.deprecated_tx_info, + ); + + return contract_call_helper( + remaining_gas=remaining_gas, + block_context=block_context, + execution_context=execution_context, + ); +} + +// Implements the library_call syscall. +func execute_library_call{ + range_check_ptr, + syscall_ptr: felt*, + builtin_ptrs: BuiltinPointers*, + contract_state_changes: DictAccess*, + contract_class_changes: DictAccess*, + outputs: OsCarriedOutputs*, +}(block_context: BlockContext*, caller_execution_context: ExecutionContext*) { + let request = cast(syscall_ptr + RequestHeader.SIZE, LibraryCallRequest*); + let (success, remaining_gas) = reduce_syscall_base_gas( + specific_base_gas_cost=LIBRARY_CALL_GAS_COST, request_struct_size=LibraryCallRequest.SIZE + ); + if (success == 0) { + // Not enough gas to execute the syscall. + return (); + } + + // Prepare execution context. + tempvar calldata_start = request.calldata_start; + tempvar caller_execution_info = caller_execution_context.execution_info; + tempvar execution_context: ExecutionContext* = new ExecutionContext( + entry_point_type=ENTRY_POINT_TYPE_EXTERNAL, + class_hash=request.class_hash, + calldata_size=request.calldata_end - calldata_start, + calldata=calldata_start, + execution_info=new ExecutionInfo( + block_info=caller_execution_info.block_info, + tx_info=caller_execution_info.tx_info, + caller_address=caller_execution_info.caller_address, + contract_address=caller_execution_info.contract_address, + selector=request.selector, + ), + deprecated_tx_info=caller_execution_context.deprecated_tx_info, + ); + + return contract_call_helper( + remaining_gas=remaining_gas, + block_context=block_context, + execution_context=execution_context, + ); +} + +// Executes the entry point and writes the corresponding response to the syscall_ptr. +// Assumes that syscall_ptr points at the response header. +func contract_call_helper{ + range_check_ptr, + syscall_ptr: felt*, + builtin_ptrs: BuiltinPointers*, + contract_state_changes: DictAccess*, + contract_class_changes: DictAccess*, + outputs: OsCarriedOutputs*, +}(remaining_gas: felt, block_context: BlockContext*, execution_context: ExecutionContext*) { + with remaining_gas { + let (retdata_size, retdata) = select_execute_entry_point_func( + block_context=block_context, execution_context=execution_context + ); + } + + let response_header = cast(syscall_ptr, ResponseHeader*); + // Advance syscall pointer to the response body. + let syscall_ptr = syscall_ptr + ResponseHeader.SIZE; + + // Write the response header. + assert [response_header] = ResponseHeader(gas=remaining_gas, failure_flag=0); + + let response = cast(syscall_ptr, CallContractResponse*); + // Advance syscall pointer to the next syscall. + let syscall_ptr = syscall_ptr + CallContractResponse.SIZE; + + %{ + # Check that the actual return value matches the expected one. + expected = memory.get_range( + addr=ids.response.retdata_start, + size=ids.response.retdata_end - ids.response.retdata_start, + ) + actual = memory.get_range(addr=ids.retdata, size=ids.retdata_size) + + assert expected == actual, f'Return value mismatch expected={expected}, actual={actual}.' + %} + + // Write the response. + relocate_segment(src_ptr=response.retdata_start, dest_ptr=retdata); + assert [response] = CallContractResponse( + retdata_start=retdata, retdata_end=retdata + retdata_size + ); + + return (); +} + // Reads a value from the current contract's storage. func execute_storage_read{range_check_ptr, syscall_ptr: felt*, contract_state_changes: DictAccess*}( contract_address @@ -93,8 +268,8 @@ func execute_storage_read{range_check_ptr, syscall_ptr: felt*, contract_state_ch let request = cast(syscall_ptr + RequestHeader.SIZE, StorageReadRequest*); // Reduce gas. - let success = reduce_syscall_gas( - gas_cost=STORAGE_READ_GAS_COST, request_size=StorageReadRequest.SIZE + let success = reduce_syscall_gas_and_write_response_header( + total_gas_cost=STORAGE_READ_GAS_COST, request_struct_size=StorageReadRequest.SIZE ); if (success == 0) { // Not enough gas to execute the syscall. @@ -142,8 +317,8 @@ func execute_storage_write{ let request = cast(syscall_ptr + RequestHeader.SIZE, StorageWriteRequest*); // Reduce gas. - let success = reduce_syscall_gas( - gas_cost=STORAGE_WRITE_GAS_COST, request_size=StorageWriteRequest.SIZE + let success = reduce_syscall_gas_and_write_response_header( + total_gas_cost=STORAGE_WRITE_GAS_COST, request_struct_size=StorageWriteRequest.SIZE ); if (success == 0) { // Not enough gas to execute the syscall. @@ -185,65 +360,89 @@ func execute_storage_write{ return (); } -// Gets the address of the caller contract. -func execute_get_caller_address{ - range_check_ptr, syscall_ptr: felt*, contract_state_changes: DictAccess* -}(caller_address) { +// Gets the execution info. +func execute_get_execution_info{range_check_ptr, syscall_ptr: felt*}( + execution_info: ExecutionInfo* +) { // Reduce gas. - let success = reduce_syscall_gas(gas_cost=GET_CALLER_ADDRESS_GAS_COST, request_size=0); + let success = reduce_syscall_gas_and_write_response_header( + total_gas_cost=GET_EXECUTION_INFO_GAS_COST, request_struct_size=0 + ); if (success == 0) { // Not enough gas to execute the syscall. return (); } - assert [cast(syscall_ptr, GetCallerAddressResponse*)] = GetCallerAddressResponse( - caller_address=caller_address + assert [cast(syscall_ptr, GetExecutionInfoResponse*)] = GetExecutionInfoResponse( + execution_info=execution_info ); // Advance syscall pointer to the next syscall. - let syscall_ptr = syscall_ptr + GetCallerAddressResponse.SIZE; + let syscall_ptr = syscall_ptr + GetExecutionInfoResponse.SIZE; return (); } -// Reduces the required amount of gas for the current syscall and writes the response header. +// Reduces the total amount of gas required for the current syscall and writes the response header. // In case of out-of-gas failure, writes the FailureReason object to syscall_ptr. // Returns 1 if the gas reduction succeeded and 0 otherwise. -func reduce_syscall_gas{range_check_ptr, syscall_ptr: felt*}( - gas_cost: felt, request_size: felt +func reduce_syscall_gas_and_write_response_header{range_check_ptr, syscall_ptr: felt*}( + total_gas_cost: felt, request_struct_size: felt ) -> felt { + let (success, remaining_gas) = reduce_syscall_base_gas( + specific_base_gas_cost=total_gas_cost, request_struct_size=request_struct_size + ); + if (success != 0) { + // Reduction has succeded; write the response header. + tempvar response_header = cast(syscall_ptr, ResponseHeader*); + // Advance syscall pointer to the response body. + let syscall_ptr = syscall_ptr + ResponseHeader.SIZE; + assert [response_header] = ResponseHeader(gas=remaining_gas, failure_flag=0); + + return 1; + } + + // Reduction has failed; in that case, 'reduce_syscall_base_gas' already wrote the response + // objects and advanced the syscall pointer. + return 0; +} + +// Reduces the base amount of gas for the current syscall. +// In case of out-of-gas failure, writes the corresponding ResponseHeader and FailureReason +// objects to syscall_ptr. +// Returns 1 if the gas reduction succeeded and 0 otherwise, along with the remaining gas. +func reduce_syscall_base_gas{range_check_ptr, syscall_ptr: felt*}( + specific_base_gas_cost: felt, request_struct_size: felt +) -> (success: felt, remaining_gas: felt) { let request_header = cast(syscall_ptr, RequestHeader*); // Advance syscall pointer to the response header. - tempvar syscall_ptr = syscall_ptr + RequestHeader.SIZE + request_size; + tempvar syscall_ptr = syscall_ptr + RequestHeader.SIZE + request_struct_size; + + // Refund the pre-charged base gas. + let required_gas = specific_base_gas_cost - SYSCALL_BASE_GAS_COST; + tempvar initial_gas = request_header.gas; + if (nondet %{ ids.initial_gas >= ids.required_gas %} != 0) { + tempvar remaining_gas = initial_gas - required_gas; + assert_nn(remaining_gas); + return (success=1, remaining_gas=remaining_gas); + } + // Handle out-of-gas. + assert_lt(initial_gas, required_gas); tempvar response_header = cast(syscall_ptr, ResponseHeader*); // Advance syscall pointer to the response body. let syscall_ptr = syscall_ptr + ResponseHeader.SIZE; - // Refund the pre-charged base gas. - let required_gas = gas_cost - SYSCALL_BASE_GAS_COST; - if (response_header.failure_flag != 0) { - // Verify that there was not enough gas to invoke the syscall. - tempvar initial_gas = request_header.gas; - assert_lt(initial_gas, required_gas); - assert [response_header] = ResponseHeader(gas=initial_gas, failure_flag=1); - - // Write the failure reason. - let failure_reason: FailureReason* = cast(syscall_ptr, FailureReason*); - // Advance syscall pointer to the next syscall. - let syscall_ptr = syscall_ptr + FailureReason.SIZE; - - tempvar start = failure_reason.start; - assert start[0] = ERROR_OUT_OF_GAS; - assert failure_reason.end = start + 1; - - return 0; - } + // Write the response header. + assert [response_header] = ResponseHeader(gas=initial_gas, failure_flag=1); - // Handle valid syscall. - tempvar remaining_gas = request_header.gas - required_gas; - assert [response_header] = ResponseHeader(gas=remaining_gas, failure_flag=0); - // Check that the remaining gas is non-negative. - assert_nn(remaining_gas); + let failure_reason: FailureReason* = cast(syscall_ptr, FailureReason*); + // Advance syscall pointer to the next syscall. + let syscall_ptr = syscall_ptr + FailureReason.SIZE; + + // Write the failure reason. + tempvar start = failure_reason.start; + assert start[0] = ERROR_OUT_OF_GAS; + assert failure_reason.end = start + 1; - return 1; + return (success=0, remaining_gas=initial_gas); } diff --git a/src/starkware/starknet/core/os/execution/execute_transactions.cairo b/src/starkware/starknet/core/os/execution/execute_transactions.cairo index 0c765285..a265a4db 100644 --- a/src/starkware/starknet/core/os/execution/execute_transactions.cairo +++ b/src/starkware/starknet/core/os/execution/execute_transactions.cairo @@ -7,6 +7,7 @@ from starkware.cairo.common.memcpy import memcpy from starkware.cairo.common.registers import get_fp_and_pc from starkware.cairo.common.segments import relocate_segment from starkware.cairo.common.uint256 import Uint256 +from starkware.starknet.builtins.segment_arena.segment_arena import new_arena from starkware.starknet.common.constants import ( DECLARE_HASH_PREFIX, DEPLOY_ACCOUNT_HASH_PREFIX, @@ -15,7 +16,9 @@ from starkware.starknet.common.constants import ( L1_HANDLER_HASH_PREFIX, ORIGIN_ADDRESS, ) -from starkware.starknet.common.syscalls import Deploy, TxInfo +from starkware.starknet.common.new_syscalls import ExecutionInfo, TxInfo +from starkware.starknet.common.syscalls import Deploy +from starkware.starknet.common.syscalls import TxInfo as DeprecatedTxInfo from starkware.starknet.core.os.block_context import BlockContext from starkware.starknet.core.os.builtins import BuiltinPointers from starkware.starknet.core.os.constants import ( @@ -28,6 +31,7 @@ from starkware.starknet.core.os.constants import ( EXECUTE_ENTRY_POINT_SELECTOR, INITIAL_GAS_COST, L1_HANDLER_VERSION, + SIERRA_ARRAY_LEN_BOUND, TRANSACTION_GAS_COST, TRANSACTION_VERSION, TRANSFER_ENTRY_POINT_SELECTOR, @@ -95,6 +99,8 @@ func execute_transactions{ // A dictionary from class hash to compiled class hash (casm). let (local contract_class_changes: DictAccess*) = dict_new(); + let segment_arena_ptr = new_arena(); + let (__fp__, _) = get_fp_and_pc(); local local_builtin_ptrs: BuiltinPointers = BuiltinPointers( pedersen=pedersen_ptr, @@ -103,6 +109,7 @@ func execute_transactions{ bitwise=bitwise_ptr, ec_op=ec_op_ptr, poseidon=poseidon_ptr, + segment_arena=segment_arena_ptr, ); let builtin_ptrs = &local_builtin_ptrs; @@ -233,20 +240,21 @@ func charge_fee{ outputs: OsCarriedOutputs*, }(block_context: BlockContext*, tx_execution_context: ExecutionContext*) { alloc_locals; - local original_tx_info: TxInfo* = tx_execution_context.original_tx_info; - local max_fee = original_tx_info.max_fee; + local execution_info: ExecutionInfo* = tx_execution_context.execution_info; + local tx_info: TxInfo* = execution_info.tx_info; + local max_fee = tx_info.max_fee; if (max_fee == 0) { return (); } // Transactions with fee should go through an account contract. - tempvar selector = tx_execution_context.selector; + tempvar selector = execution_info.selector; assert (selector - EXECUTE_ENTRY_POINT_SELECTOR) * ( selector - VALIDATE_DECLARE_ENTRY_POINT_SELECTOR ) * (selector - VALIDATE_DEPLOY_ENTRY_POINT_SELECTOR) = 0; local calldata: TransferCallData = TransferCallData( - recipient=block_context.sequencer_address, + recipient=block_context.block_info.sequencer_address, amount=Uint256(low=nondet %{ execution_helper.tx_execution_info.actual_fee %}, high=0), ); @@ -260,13 +268,17 @@ func charge_fee{ let (__fp__, _) = get_fp_and_pc(); local execution_context: ExecutionContext = ExecutionContext( entry_point_type=ENTRY_POINT_TYPE_EXTERNAL, - caller_address=original_tx_info.account_contract_address, - contract_address=fee_token_address, class_hash=fee_state_entry.class_hash, - selector=TRANSFER_ENTRY_POINT_SELECTOR, calldata_size=TransferCallData.SIZE, calldata=&calldata, - original_tx_info=original_tx_info, + execution_info=new ExecutionInfo( + block_info=execution_info.block_info, + tx_info=tx_info, + caller_address=tx_info.account_contract_address, + contract_address=fee_token_address, + selector=TRANSFER_ENTRY_POINT_SELECTOR, + ), + deprecated_tx_info=tx_execution_context.deprecated_tx_info, ); let remaining_gas = INITIAL_GAS_COST; @@ -312,8 +324,9 @@ func execute_invoke_function_transaction{ alloc_locals; let (local tx_execution_context: ExecutionContext*) = get_invoke_tx_execution_context( - entry_point_type=ENTRY_POINT_TYPE_EXTERNAL + block_context=block_context, entry_point_type=ENTRY_POINT_TYPE_EXTERNAL ); + local tx_execution_info: ExecutionInfo* = tx_execution_context.execution_info; // Guess tx version and make sure it's valid. local tx_version = nondet %{ tx.version %}; @@ -324,11 +337,11 @@ func execute_invoke_function_transaction{ let (__fp__, _) = get_fp_and_pc(); if (tx_version == 0) { - tempvar entry_point_selector_field = tx_execution_context.selector; + tempvar entry_point_selector_field = tx_execution_info.selector; tempvar additional_data_size = 0; tempvar additional_data = cast(0, felt*); } else { - assert tx_execution_context.selector = EXECUTE_ENTRY_POINT_SELECTOR; + assert tx_execution_info.selector = EXECUTE_ENTRY_POINT_SELECTOR; tempvar entry_point_selector_field = 0; tempvar additional_data_size = 1; tempvar additional_data = &nonce; @@ -346,20 +359,33 @@ func execute_invoke_function_transaction{ additional_data=additional_data, ); - assert [tx_execution_context.original_tx_info] = TxInfo( + // Write the transaction info and complete the ExecutionInfo struct. + tempvar tx_info = tx_execution_info.tx_info; + local signature_start: felt*; + local signature_len: felt; + %{ + ids.signature_start = segments.gen_arg(arg=tx.signature) + ids.signature_len = len(tx.signature) + %} + assert_nn_le(signature_len, SIERRA_ARRAY_LEN_BOUND - 1); + assert [tx_info] = TxInfo( version=tx_version, - account_contract_address=tx_execution_context.contract_address, + account_contract_address=tx_execution_info.contract_address, max_fee=max_fee, - signature_len=nondet %{ len(tx.signature) %}, - signature=cast(nondet %{ segments.gen_arg(arg=tx.signature) %}, felt*), + signature_start=signature_start, + signature_end=signature_start + signature_len, transaction_hash=transaction_hash, chain_id=chain_id, nonce=nonce, ); + fill_deprecated_tx_info(tx_info=tx_info, dst=tx_execution_context.deprecated_tx_info); check_and_increment_nonce(execution_context=tx_execution_context, nonce=nonce); - %{ execution_helper.start_tx(tx_info_ptr=ids.tx_execution_context.original_tx_info.address_) %} + %{ + tx_info_ptr = ids.tx_execution_context.deprecated_tx_info.address_ + execution_helper.start_tx(tx_info_ptr=tx_info_ptr) + %} run_validate(block_context=block_context, tx_execution_context=tx_execution_context); select_execute_entry_point_func( @@ -389,8 +415,9 @@ func execute_l1_handler_transaction{ alloc_locals; let (local tx_execution_context: ExecutionContext*) = get_invoke_tx_execution_context( - entry_point_type=ENTRY_POINT_TYPE_L1_HANDLER + block_context=block_context, entry_point_type=ENTRY_POINT_TYPE_L1_HANDLER ); + local tx_execution_info: ExecutionInfo* = tx_execution_context.execution_info; local nonce = nondet %{ tx.nonce %}; local chain_id = block_context.starknet_os_config.chain_id; @@ -400,28 +427,34 @@ func execute_l1_handler_transaction{ tx_hash_prefix=L1_HANDLER_HASH_PREFIX, version=L1_HANDLER_VERSION, execution_context=tx_execution_context, - entry_point_selector_field=tx_execution_context.selector, + entry_point_selector_field=tx_execution_info.selector, max_fee=0, chain_id=chain_id, additional_data_size=1, additional_data=&nonce, ); - assert [tx_execution_context.original_tx_info] = TxInfo( + // Write the transaction info and complete the ExecutionInfo struct. + tempvar tx_info = tx_execution_info.tx_info; + assert [tx_info] = TxInfo( version=L1_HANDLER_VERSION, - account_contract_address=tx_execution_context.contract_address, + account_contract_address=tx_execution_info.contract_address, max_fee=0, - signature_len=0, - signature=cast(0, felt*), + signature_start=cast(0, felt*), + signature_end=cast(0, felt*), transaction_hash=transaction_hash, chain_id=chain_id, nonce=nonce, ); + fill_deprecated_tx_info(tx_info=tx_info, dst=tx_execution_context.deprecated_tx_info); // Consume L1-to-L2 message. consume_l1_to_l2_message(execution_context=tx_execution_context, nonce=nonce); - %{ execution_helper.start_tx(tx_info_ptr=ids.tx_execution_context.original_tx_info.address_) %} + %{ + tx_info_ptr = ids.tx_execution_context.deprecated_tx_info.address_ + execution_helper.start_tx(tx_info_ptr=tx_info_ptr) + %} select_execute_entry_point_func( block_context=block_context, execution_context=tx_execution_context ); @@ -431,9 +464,10 @@ func execute_l1_handler_transaction{ } // Guess the execution context of an invoke transaction (either invoke function or L1 handler). -// Leaves 'original_tx_info' empty - should be filled later on. -func get_invoke_tx_execution_context{contract_state_changes: DictAccess*}( - entry_point_type: felt +// Leaves 'execution_info.tx_info' and 'deprecated_tx_info' empty - should be +// filled later on. +func get_invoke_tx_execution_context{range_check_ptr, contract_state_changes: DictAccess*}( + block_context: BlockContext*, entry_point_type: felt ) -> (tx_execution_context: ExecutionContext*) { alloc_locals; local contract_address; @@ -448,32 +482,53 @@ func get_invoke_tx_execution_context{contract_state_changes: DictAccess*}( ); local tx_execution_context: ExecutionContext* = new ExecutionContext( entry_point_type=entry_point_type, - caller_address=ORIGIN_ADDRESS, - contract_address=contract_address, class_hash=state_entry.class_hash, - selector=nondet %{ tx.entry_point_selector %}, calldata_size=nondet %{ len(tx.calldata) %}, calldata=cast(nondet %{ segments.gen_arg(tx.calldata) %}, felt*), - original_tx_info=cast(nondet %{ segments.add() %}, TxInfo*), + execution_info=new ExecutionInfo( + block_info=block_context.block_info, + tx_info=cast(nondet %{ segments.add() %}, TxInfo*), + caller_address=ORIGIN_ADDRESS, + contract_address=contract_address, + selector=nondet %{ tx.entry_point_selector %}, + ), + deprecated_tx_info=cast(nondet %{ segments.add() %}, DeprecatedTxInfo*), ); + assert_nn_le(tx_execution_context.calldata_size, SIERRA_ARRAY_LEN_BOUND - 1); return (tx_execution_context=tx_execution_context); } +// Initializes the given DeprecatedTxInfo (dst) based on the given TxInfo. +func fill_deprecated_tx_info(tx_info: TxInfo*, dst: DeprecatedTxInfo*) { + tempvar signature_start = tx_info.signature_start; + assert [dst] = DeprecatedTxInfo( + version=tx_info.version, + account_contract_address=tx_info.account_contract_address, + max_fee=tx_info.max_fee, + signature_len=tx_info.signature_end - signature_start, + signature=signature_start, + transaction_hash=tx_info.transaction_hash, + chain_id=tx_info.chain_id, + nonce=tx_info.nonce, + ); + return (); +} + // Verifies that the transaction's nonce matches the contract's nonce and increments the // latter. func check_and_increment_nonce{contract_state_changes: DictAccess*}( execution_context: ExecutionContext*, nonce: felt ) -> () { alloc_locals; + local execution_info: ExecutionInfo* = execution_context.execution_info; // Do not handle nonce for version 0. - local tx_version = execution_context.original_tx_info.version; - if (tx_version == 0) { + if (execution_info.tx_info.version == 0) { return (); } - tempvar contract_address = execution_context.contract_address; + tempvar contract_address = execution_info.contract_address; local state_entry: StateEntry*; %{ # Fetch a state_entry in this hint and validate it in the update that comes next. @@ -516,22 +571,27 @@ func run_validate{ outputs: OsCarriedOutputs*, }(block_context: BlockContext*, tx_execution_context: ExecutionContext*) { alloc_locals; + local tx_execution_info: ExecutionInfo* = tx_execution_context.execution_info; // Do not run "__validate__" for version 0. - if (tx_execution_context.original_tx_info.version == 0) { + if (tx_execution_info.tx_info.version == 0) { return (); } // "__validate__" is expected to get the same calldata as "__execute__". local validate_execution_context: ExecutionContext* = new ExecutionContext( entry_point_type=ENTRY_POINT_TYPE_EXTERNAL, - caller_address=ORIGIN_ADDRESS, - contract_address=tx_execution_context.contract_address, class_hash=tx_execution_context.class_hash, - selector=VALIDATE_ENTRY_POINT_SELECTOR, calldata_size=tx_execution_context.calldata_size, calldata=tx_execution_context.calldata, - original_tx_info=tx_execution_context.original_tx_info, + execution_info=new ExecutionInfo( + block_info=tx_execution_info.block_info, + tx_info=tx_execution_info.tx_info, + caller_address=tx_execution_info.caller_address, + contract_address=tx_execution_info.contract_address, + selector=VALIDATE_ENTRY_POINT_SELECTOR, + ), + deprecated_tx_info=tx_execution_context.deprecated_tx_info, ); select_execute_entry_point_func( @@ -549,12 +609,14 @@ func consume_l1_to_l2_message{outputs: OsCarriedOutputs*}( let payload: felt* = execution_context.calldata + 1; tempvar payload_size = execution_context.calldata_size - 1; + tempvar execution_info = execution_context.execution_info; + // Write the given transaction to the output. assert [outputs.messages_to_l2] = MessageToL2Header( from_address=[execution_context.calldata], - to_address=execution_context.contract_address, + to_address=execution_info.contract_address, nonce=nonce, - selector=execution_context.selector, + selector=execution_info.selector, payload_size=payload_size, ); @@ -570,10 +632,10 @@ func consume_l1_to_l2_message{outputs: OsCarriedOutputs*}( } // Prepares a constructor execution context based on the 'tx' hint variable. -// Leaves 'original_tx_info' empty - should be filled later on. -func prepare_constructor_execution_context{range_check_ptr, builtin_ptrs: BuiltinPointers*}() -> ( - constructor_execution_context: ExecutionContext*, salt: felt -) { +// Leaves 'execution_info.tx_info' and 'deprecated_tx_info' empty - should be filled later on. +func prepare_constructor_execution_context{range_check_ptr, builtin_ptrs: BuiltinPointers*}( + block_context: BlockContext* +) -> (constructor_execution_context: ExecutionContext*, salt: felt) { alloc_locals; local contract_address_salt; @@ -586,7 +648,7 @@ func prepare_constructor_execution_context{range_check_ptr, builtin_ptrs: Builti ids.constructor_calldata_size = len(tx.constructor_calldata) ids.constructor_calldata = segments.gen_arg(arg=tx.constructor_calldata) %} - assert_nn(constructor_calldata_size); + assert_nn_le(constructor_calldata_size, SIERRA_ARRAY_LEN_BOUND - 1); let hash_ptr = builtin_ptrs.pedersen; with hash_ptr { @@ -605,17 +667,22 @@ func prepare_constructor_execution_context{range_check_ptr, builtin_ptrs: Builti bitwise=builtin_ptrs.bitwise, ec_op=builtin_ptrs.ec_op, poseidon=builtin_ptrs.poseidon, + segment_arena=builtin_ptrs.segment_arena, ); tempvar constructor_execution_context = new ExecutionContext( entry_point_type=ENTRY_POINT_TYPE_CONSTRUCTOR, - caller_address=ORIGIN_ADDRESS, - contract_address=contract_address, class_hash=class_hash, - selector=CONSTRUCTOR_ENTRY_POINT_SELECTOR, calldata_size=constructor_calldata_size, calldata=constructor_calldata, - original_tx_info=cast(nondet %{ segments.add() %}, TxInfo*), + execution_info=new ExecutionInfo( + block_info=block_context.block_info, + tx_info=cast(nondet %{ segments.add() %}, TxInfo*), + caller_address=ORIGIN_ADDRESS, + contract_address=contract_address, + selector=CONSTRUCTOR_ENTRY_POINT_SELECTOR, + ), + deprecated_tx_info=cast(nondet %{ segments.add() %}, DeprecatedTxInfo*), ); return ( @@ -636,7 +703,8 @@ func execute_deploy_account_transaction{ // Calculate address and prepare constructor execution context. let ( local constructor_execution_context: ExecutionContext*, local salt - ) = prepare_constructor_execution_context(); + ) = prepare_constructor_execution_context(block_context=block_context); + local constructor_execution_info: ExecutionInfo* = constructor_execution_context.execution_info; // Prepare validate_deploy calldata. let (validate_deploy_calldata: felt*) = alloc(); @@ -648,17 +716,22 @@ func execute_deploy_account_transaction{ len=constructor_execution_context.calldata_size, ); - // Note that the members of original_tx_info are not initialized at this point. - local original_tx_info: TxInfo* = constructor_execution_context.original_tx_info; + // Note that the members of execution_info.tx_info are not initialized at this point. + local tx_info: TxInfo* = constructor_execution_info.tx_info; + local deprecated_tx_info: DeprecatedTxInfo* = constructor_execution_context.deprecated_tx_info; local validate_deploy_execution_context: ExecutionContext* = new ExecutionContext( entry_point_type=ENTRY_POINT_TYPE_EXTERNAL, - caller_address=ORIGIN_ADDRESS, - contract_address=constructor_execution_context.contract_address, class_hash=constructor_execution_context.class_hash, - selector=VALIDATE_DEPLOY_ENTRY_POINT_SELECTOR, calldata_size=constructor_execution_context.calldata_size + 2, calldata=validate_deploy_calldata, - original_tx_info=original_tx_info, + execution_info=new ExecutionInfo( + block_info=constructor_execution_info.block_info, + tx_info=tx_info, + caller_address=constructor_execution_info.caller_address, + contract_address=constructor_execution_info.contract_address, + selector=VALIDATE_DEPLOY_ENTRY_POINT_SELECTOR, + ), + deprecated_tx_info=deprecated_tx_info, ); // Compute transaction hash and prepare transaction info. @@ -679,18 +752,26 @@ func execute_deploy_account_transaction{ // Assign the transaction info to both calls. // Note that both constructor_execution_context and // validate_deploy_execution_context hold this pointer. - assert [original_tx_info] = TxInfo( + local signature_start: felt*; + local signature_len: felt; + %{ + ids.signature_start = segments.gen_arg(arg=tx.signature) + ids.signature_len = len(tx.signature) + %} + assert_nn_le(signature_len, SIERRA_ARRAY_LEN_BOUND - 1); + assert [tx_info] = TxInfo( version=tx_version, - account_contract_address=validate_deploy_execution_context.contract_address, + account_contract_address=constructor_execution_info.contract_address, max_fee=max_fee, - signature_len=nondet %{ len(tx.signature) %}, - signature=cast(nondet %{ segments.gen_arg(arg=tx.signature) %}, felt*), + signature_start=signature_start, + signature_end=signature_start + signature_len, transaction_hash=transaction_hash, chain_id=block_context.starknet_os_config.chain_id, nonce=[nonce_ptr], ); + fill_deprecated_tx_info(tx_info=tx_info, dst=deprecated_tx_info); - %{ execution_helper.start_tx(tx_info_ptr=ids.original_tx_info.address_) %} + %{ execution_helper.start_tx(tx_info_ptr=ids.deprecated_tx_info.address_) %} deploy_contract( block_context=block_context, constructor_execution_context=constructor_execution_context @@ -724,7 +805,7 @@ func execute_deploy_transaction{ let ( local constructor_execution_context: ExecutionContext*, _ - ) = prepare_constructor_execution_context(); + ) = prepare_constructor_execution_context(block_context=block_context); // Guess tx version and make sure it's valid. local tx_version = nondet %{ tx.version %}; @@ -743,20 +824,23 @@ func execute_deploy_transaction{ additional_data=nullptr, ); - assert [constructor_execution_context.original_tx_info] = TxInfo( + // Write the transaction info and complete the ExecutionInfo struct. + tempvar tx_info = constructor_execution_context.execution_info.tx_info; + assert [tx_info] = TxInfo( version=tx_version, account_contract_address=ORIGIN_ADDRESS, max_fee=0, - signature_len=0, - signature=nullptr, + signature_start=nullptr, + signature_end=nullptr, transaction_hash=transaction_hash, chain_id=chain_id, nonce=0, ); + fill_deprecated_tx_info(tx_info=tx_info, dst=constructor_execution_context.deprecated_tx_info); %{ execution_helper.start_tx( - tx_info_ptr=ids.constructor_execution_context.original_tx_info.address_ + tx_info_ptr=ids.constructor_execution_context.deprecated_tx_info.address_ ) %} @@ -832,13 +916,17 @@ func execute_declare_transaction{ let (state_entry: StateEntry*) = dict_read{dict_ptr=contract_state_changes}(key=sender_address); local validate_declare_execution_context: ExecutionContext* = new ExecutionContext( entry_point_type=ENTRY_POINT_TYPE_EXTERNAL, - caller_address=ORIGIN_ADDRESS, - contract_address=sender_address, class_hash=state_entry.class_hash, - selector=VALIDATE_DECLARE_ENTRY_POINT_SELECTOR, calldata_size=1, calldata=calldata, - original_tx_info=cast(nondet %{ segments.add() %}, TxInfo*), + execution_info=new ExecutionInfo( + block_info=block_context.block_info, + tx_info=cast(nondet %{ segments.add() %}, TxInfo*), + caller_address=ORIGIN_ADDRESS, + contract_address=sender_address, + selector=VALIDATE_DECLARE_ENTRY_POINT_SELECTOR, + ), + deprecated_tx_info=cast(nondet %{ segments.add() %}, DeprecatedTxInfo*), ); let (transaction_hash) = compute_transaction_hash( @@ -851,22 +939,35 @@ func execute_declare_transaction{ additional_data_size=additional_data_size, additional_data=additional_data, ); - assert [validate_declare_execution_context.original_tx_info] = TxInfo( + + // Write the transaction info and complete the ExecutionInfo struct. + tempvar tx_info = validate_declare_execution_context.execution_info.tx_info; + local signature_start: felt*; + local signature_len: felt; + %{ + ids.signature_start = segments.gen_arg(arg=tx.signature) + ids.signature_len = len(tx.signature) + %} + assert_nn_le(signature_len, SIERRA_ARRAY_LEN_BOUND - 1); + assert [tx_info] = TxInfo( version=tx_version, account_contract_address=sender_address, max_fee=max_fee, - signature_len=nondet %{ len(tx.signature) %}, - signature=cast(nondet %{ segments.gen_arg(arg=tx.signature) %}, felt*), + signature_start=signature_start, + signature_end=signature_start + signature_len, transaction_hash=transaction_hash, chain_id=chain_id, nonce=nonce, ); + fill_deprecated_tx_info( + tx_info=tx_info, dst=validate_declare_execution_context.deprecated_tx_info + ); check_and_increment_nonce(execution_context=validate_declare_execution_context, nonce=nonce); %{ execution_helper.start_tx( - tx_info_ptr=ids.validate_declare_execution_context.original_tx_info.address_ + tx_info_ptr=ids.validate_declare_execution_context.deprecated_tx_info.address_ ) %} @@ -885,8 +986,8 @@ func execute_declare_transaction{ // Computes the hash of the transaction. // -// Note that execution_context.original_tx_info is uninitialized when this function is called. -// In particular, this field is not used in this function. +// Note that 'execution_context.execution_info.tx_info' and 'deprecated_tx_info' are uninitialized +// when this function is called. In particular, these fields are not used in this function. func compute_transaction_hash{builtin_ptrs: BuiltinPointers*}( tx_hash_prefix: felt, version: felt, @@ -902,7 +1003,7 @@ func compute_transaction_hash{builtin_ptrs: BuiltinPointers*}( let (transaction_hash) = get_transaction_hash( tx_hash_prefix=tx_hash_prefix, version=version, - contract_address=execution_context.contract_address, + contract_address=execution_context.execution_info.contract_address, entry_point_selector=entry_point_selector_field, calldata_size=execution_context.calldata_size, calldata=execution_context.calldata, @@ -926,6 +1027,7 @@ func compute_transaction_hash{builtin_ptrs: BuiltinPointers*}( bitwise=builtin_ptrs.bitwise, ec_op=builtin_ptrs.ec_op, poseidon=builtin_ptrs.poseidon, + segment_arena=builtin_ptrs.segment_arena, ); return (transaction_hash=transaction_hash); diff --git a/src/starkware/starknet/core/os/program_hash.json b/src/starkware/starknet/core/os/program_hash.json index cfd0b722..e2538571 100644 --- a/src/starkware/starknet/core/os/program_hash.json +++ b/src/starkware/starknet/core/os/program_hash.json @@ -1,3 +1,3 @@ { - "program_hash": "0x79ded6c61f7f046dc7a354dc2c046efbc9d6fed0f860ce0362749db79dc8f12" + "program_hash": "0x1c3dd8ba84f2895f749e8286cf26d813e7a51d7b65565a75d8a8da33f1aeae" } diff --git a/src/starkware/starknet/core/os/state.cairo b/src/starkware/starknet/core/os/state.cairo index 453d5cfc..7a66a0bc 100644 --- a/src/starkware/starknet/core/os/state.cairo +++ b/src/starkware/starknet/core/os/state.cairo @@ -6,22 +6,19 @@ from starkware.cairo.common.hash import hash2 from starkware.cairo.common.math import assert_nn_le from starkware.cairo.common.math_cmp import is_not_zero from starkware.cairo.common.patricia import ( - PatriciaUpdateConstants, patricia_update_constants_new, patricia_update_using_update_constants, ) -from starkware.cairo.common.patricia_with_sponge import ( - PatriciaUpdateConstants as PatriciaUpdateConstantsWithSponge, -) -from starkware.cairo.common.patricia_with_sponge import ( - patricia_update_using_update_constants as patricia_update_using_update_constants_with_sponge, +from starkware.cairo.common.patricia_utils import PatriciaUpdateConstants +from starkware.cairo.common.patricia_with_poseidon import ( + patricia_update_using_update_constants as patricia_update_using_update_constants_with_poseidon, ) from starkware.cairo.common.segments import relocate_segment -from starkware.cairo.common.sponge_as_hash import SpongeHashBuiltin const MERKLE_HEIGHT = 251; // PRIME.bit_length() - 1. const UNINITIALIZED_CLASS_HASH = 0; const GLOBAL_STATE_VERSION = 'STARKNET_STATE_V0'; +const CONTRACT_CLASS_LEAF_VERSION = 'CONTRACT_CLASS_LEAF_V0'; // The on-chain data for contract state changes has the following format: // @@ -241,27 +238,16 @@ func contract_class_update{ hashed_class_changes=hashed_class_changes, ); - // Perform casts to work with sponge hashes. - let patricia_update_constants_with_sponge = cast( - patricia_update_constants, PatriciaUpdateConstantsWithSponge* - ); - let hash_ptr = cast(poseidon_ptr, SpongeHashBuiltin*); - // Call patricia_update_using_update_constants() instead of patricia_update() // in order not to repeat globals_pow2 calculation. - with hash_ptr { - patricia_update_using_update_constants_with_sponge( - patricia_update_constants=patricia_update_constants_with_sponge, - update_ptr=hashed_class_changes, - n_updates=n_class_updates, - height=MERKLE_HEIGHT, - prev_root=initial_root, - new_root=final_root, - ); - } - - // Update poseidon_ptr. - let poseidon_ptr = cast(hash_ptr, PoseidonBuiltin*); + patricia_update_using_update_constants_with_poseidon( + patricia_update_constants=patricia_update_constants, + update_ptr=hashed_class_changes, + n_updates=n_class_updates, + height=MERKLE_HEIGHT, + prev_root=initial_root, + new_root=final_root, + ); serialize_contract_class_da_changes(update_ptr=squashed_dict, n_updates=n_class_updates); @@ -305,13 +291,12 @@ func hash_class_changes{poseidon_ptr: PoseidonBuiltin*}( func get_contract_class_leaf_hash{poseidon_ptr: PoseidonBuiltin*}(compiled_class_hash: felt) -> ( hash: felt ) { - const CONTRACT_CLASS_HASH_VERSION = 'CONTRACT_CLASS_LEAF_V0'; if (compiled_class_hash == UNINITIALIZED_CLASS_HASH) { return (hash=0); } - // Return H(CONTRACT_CLASS_HASH_VERSION, compiled_class_hash). - let (hash_value) = poseidon_hash(CONTRACT_CLASS_HASH_VERSION, compiled_class_hash); + // Return H(CONTRACT_CLASS_LEAF_VERSION, compiled_class_hash). + let (hash_value) = poseidon_hash(CONTRACT_CLASS_LEAF_VERSION, compiled_class_hash); return (hash=hash_value); } diff --git a/src/starkware/starknet/core/os/syscall_handler.py b/src/starkware/starknet/core/os/syscall_handler.py index f4620b3c..5600275a 100644 --- a/src/starkware/starknet/core/os/syscall_handler.py +++ b/src/starkware/starknet/core/os/syscall_handler.py @@ -1,8 +1,19 @@ -import contextlib import dataclasses import functools from abc import ABC, abstractmethod -from typing import Callable, Dict, Iterable, Iterator, List, Mapping, Optional, Tuple, Type, cast +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Mapping, + Optional, + Tuple, + Type, + cast, +) import cachetools @@ -10,7 +21,7 @@ from starkware.cairo.common.structs import CairoStructProxy from starkware.cairo.lang.vm.memory_segments import MemorySegmentManager from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, RelocatableValue -from starkware.python.utils import assert_exhausted, camel_to_snake_case +from starkware.python.utils import as_non_optional, assert_exhausted, camel_to_snake_case from starkware.starknet.business_logic.execution.execute_entry_point_base import ( ExecuteEntryPointBase, ) @@ -32,13 +43,13 @@ ) from starkware.starknet.core.os.syscall_utils import ( STARKNET_SYSCALLS_COMPILED_PATH, - HandlerException, cast_to_int, get_deprecated_syscall_structs_and_info, get_selector_from_program, get_syscall_structs, load_program, validate_runtime_request_type, + wrap_with_handler_exception, ) from starkware.starknet.definitions.constants import GasCost from starkware.starknet.definitions.error_codes import CairoErrorCode, StarknetErrorCode @@ -46,7 +57,7 @@ from starkware.starknet.public.abi import CONSTRUCTOR_ENTRY_POINT_SELECTOR from starkware.starknet.services.api.contract_class.contract_class import EntryPointType from starkware.starknet.storage.starknet_storage import OsSingleStarknetStorage -from starkware.starkware_utils.error_handling import StarkException, stark_assert +from starkware.starkware_utils.error_handling import stark_assert SyscallFullResponse = Tuple[tuple, tuple] # Response header + specific syscall response. ExecuteSyscallCallback = Callable[ @@ -85,6 +96,26 @@ def get_selector_to_syscall_info(cls) -> Dict[int, SyscallInfo]: get_selector_from_program, syscalls_program=syscalls_program ) return { + get_selector("call_contract"): SyscallInfo( + name="call_contract", + execute_callback=cls.call_contract, + request_struct=structs.CallContractRequest, + ), + get_selector("deploy"): SyscallInfo( + name="deploy", + execute_callback=cls.deploy, + request_struct=structs.DeployRequest, + ), + get_selector("get_execution_info"): SyscallInfo( + name="get_execution_info", + execute_callback=cls.get_execution_info, + request_struct=structs.EmptyRequest, + ), + get_selector("library_call"): SyscallInfo( + name="library_call", + execute_callback=cls.library_call, + request_struct=structs.LibraryCallRequest, + ), get_selector("storage_read"): SyscallInfo( name="storage_read", execute_callback=cls.storage_read, @@ -95,16 +126,21 @@ def get_selector_to_syscall_info(cls) -> Dict[int, SyscallInfo]: execute_callback=cls.storage_write, request_struct=structs.StorageWriteRequest, ), - get_selector("get_caller_address"): SyscallInfo( - name="get_caller_address", - execute_callback=cls.get_caller_address, - request_struct=structs.EmptyRequest, - ), get_selector("emit_event"): SyscallInfo( name="emit_event", execute_callback=cls.emit_event, request_struct=structs.EmitEventRequest, ), + get_selector("replace_class"): SyscallInfo( + name="replace_class", + execute_callback=cls.replace_class, + request_struct=structs.ReplaceClassRequest, + ), + get_selector("send_message_to_l1"): SyscallInfo( + name="send_message_to_l1", + execute_callback=cls.send_message_to_l1, + request_struct=structs.SendMessageToL1Request, + ), } @property @@ -149,8 +185,69 @@ def syscall(self, syscall_ptr: RelocatableValue): # Syscalls. + def call_contract(self, remaining_gas: int, request: CairoStructProxy) -> SyscallFullResponse: + return self.call_contract_helper( + remaining_gas=remaining_gas, request=request, syscall_name="call_contract" + ) + + def library_call(self, remaining_gas: int, request: CairoStructProxy) -> SyscallFullResponse: + return self.call_contract_helper( + remaining_gas=remaining_gas, request=request, syscall_name="library_call" + ) + + def call_contract_helper( + self, remaining_gas: int, request: CairoStructProxy, syscall_name: str + ) -> SyscallFullResponse: + result = self._call_contract_helper( + remaining_gas=remaining_gas, request=request, syscall_name=syscall_name + ) + + remaining_gas -= result.gas_consumed + response_header = self.structs.ResponseHeader( + gas=remaining_gas, failure_flag=result.failure_flag + ) + retdata_start = self._allocate_segment_for_retdata(retdata=result.retdata) + retdata_end = retdata_start + len(result.retdata) + if response_header.failure_flag == 0: + response = self.structs.CallContractResponse( + retdata_start=retdata_start, retdata_end=retdata_end + ) + else: + response = self.structs.FailureReason(start=retdata_start, end=retdata_end) + + return response_header, response + + def deploy(self, remaining_gas: int, request: CairoStructProxy) -> SyscallFullResponse: + contract_address, result = self._deploy(remaining_gas=remaining_gas, request=request) + + remaining_gas -= result.gas_consumed + response_header = self.structs.ResponseHeader( + gas=remaining_gas, failure_flag=result.failure_flag + ) + retdata_start = self._allocate_segment_for_retdata(retdata=result.retdata) + retdata_end = retdata_start + len(result.retdata) + if response_header.failure_flag == 0: + response = self.structs.DeployResponse( + contract_address=contract_address, + constructor_retdata_start=retdata_start, + constructor_retdata_end=retdata_end, + ) + else: + response = self.structs.FailureReason(start=retdata_start, end=retdata_end) + + return response_header, response + + def get_execution_info( + self, remaining_gas: int, request: CairoStructProxy + ) -> SyscallFullResponse: + execution_info_ptr = self._get_execution_info_ptr() + + response_header = self.structs.ResponseHeader(gas=remaining_gas, failure_flag=0) + response = self.structs.GetExecutionInfoResponse(execution_info=execution_info_ptr) + return response_header, response + def storage_read(self, remaining_gas: int, request: CairoStructProxy) -> SyscallFullResponse: - assert request.reserved == 0, "Unexpected reserved value." + assert request.reserved == 0, f"Unsupported address domain: {request.reserved}." value = self._storage_read(key=cast_to_int(request.key)) response_header = self.structs.ResponseHeader(gas=remaining_gas, failure_flag=0) @@ -158,37 +255,64 @@ def storage_read(self, remaining_gas: int, request: CairoStructProxy) -> Syscall return response_header, response def storage_write(self, remaining_gas: int, request: CairoStructProxy) -> SyscallFullResponse: - assert request.reserved == 0, "Unexpected reserved value." + assert request.reserved == 0, f"Unsupported address domain: {request.reserved}." self._storage_write(key=cast_to_int(request.key), value=cast_to_int(request.value)) response_header = self.structs.ResponseHeader(gas=remaining_gas, failure_flag=0) return response_header, tuple() - def get_caller_address( - self, remaining_gas: int, request: CairoStructProxy - ) -> SyscallFullResponse: - caller_address = self._get_caller_address() + def emit_event(self, remaining_gas: int, request: CairoStructProxy) -> SyscallFullResponse: + keys = self._get_felt_range(start_addr=request.keys_start, end_addr=request.keys_end) + data = self._get_felt_range(start_addr=request.data_start, end_addr=request.data_end) + self._emit_event(keys=keys, data=data) response_header = self.structs.ResponseHeader(gas=remaining_gas, failure_flag=0) - response = self.structs.GetCallerAddressResponse(caller_address=caller_address) - return response_header, response + return response_header, tuple() - def emit_event(self, remaining_gas: int, request: CairoStructProxy) -> SyscallFullResponse: - keys = self._get_felt_range( - start_addr=cast(RelocatableValue, request.keys_start), - end_addr=cast(RelocatableValue, request.keys_end), - ) - data = self._get_felt_range( - start_addr=cast(RelocatableValue, request.data_start), - end_addr=cast(RelocatableValue, request.data_end), + def replace_class(self, remaining_gas: int, request: CairoStructProxy) -> SyscallFullResponse: + self._replace_class(class_hash=cast_to_int(request.class_hash)) + + response_header = self.structs.ResponseHeader(gas=remaining_gas, failure_flag=0) + return response_header, tuple() + + def send_message_to_l1( + self, remaining_gas: int, request: CairoStructProxy + ) -> SyscallFullResponse: + payload = self._get_felt_range( + start_addr=cast(RelocatableValue, request.payload_start), + end_addr=cast(RelocatableValue, request.payload_end), ) - self._emit_event(keys=keys, data=data) + self._send_message_to_l1(to_address=cast_to_int(request.to_address), payload=payload) response_header = self.structs.ResponseHeader(gas=remaining_gas, failure_flag=0) return response_header, tuple() # Application-specific syscall implementation. + @abstractmethod + def _call_contract_helper( + self, remaining_gas: int, request: CairoStructProxy, syscall_name: str + ) -> CallResult: + """ + Returns the call's result. + + syscall_name can be "call_contract" or "library_call". + """ + + @abstractmethod + def _deploy(self, remaining_gas: int, request: CairoStructProxy) -> Tuple[int, CallResult]: + """ + Returns the address of the newly deployed contract and the constructor call's result. + Note that the result may contain failures that preceded the constructor invocation, such + as undeclared class. + """ + + @abstractmethod + def _get_execution_info_ptr(self) -> RelocatableValue: + """ + Returns a pointer to the ExecutionInfo struct. + """ + @abstractmethod def _storage_read(self, key: int) -> int: """ @@ -202,15 +326,21 @@ def _storage_write(self, key: int, value: int): """ @abstractmethod - def _get_caller_address(self) -> int: + def _emit_event(self, keys: List[int], data: List[int]): """ - Returns the address of the caller contract. + Specific implementation of the emit_event syscall. """ @abstractmethod - def _emit_event(self, keys: List[int], data: List[int]): + def _replace_class(self, class_hash: int): """ - Specific implementation of the emit_event syscall. + Specific implementation of the replace_class syscall. + """ + + @abstractmethod + def _send_message_to_l1(self, to_address: int, payload: List[int]): + """ + Specific implementation of the send_message_to_l1 syscall. """ # Internal utilities. @@ -231,9 +361,9 @@ def _handle_out_of_gas(self, initial_gas: int) -> SyscallFullResponse: return response_header, failure_reason - def _get_felt_range( - self, start_addr: RelocatableValue, end_addr: RelocatableValue - ) -> List[int]: + def _get_felt_range(self, start_addr: Any, end_addr: Any) -> List[int]: + assert isinstance(start_addr, RelocatableValue) + assert isinstance(end_addr, RelocatableValue) assert start_addr.segment_index == end_addr.segment_index, ( "Inconsistent start and end segment indices " f"({start_addr.segment_index} != {end_addr.segment_index})." @@ -255,6 +385,12 @@ def allocate_segment(self, data: Iterable[MaybeRelocatable]) -> RelocatableValue recursive input - call allocate_segment for the inner items if needed. """ + @abstractmethod + def _allocate_segment_for_retdata(self, retdata: Iterable[int]) -> RelocatableValue: + """ + Allocates and returns a new (read-only) segment with the given retdata. + """ + def _validate_syscall_ptr(self, actual_syscall_ptr: RelocatableValue): assert ( actual_syscall_ptr == self.syscall_ptr @@ -284,37 +420,151 @@ class BusinessLogicSyscallHandler(SyscallHandlerBase): def __init__( self, state: SyncState, + resources_manager: ExecutionResourcesManager, segments: MemorySegmentManager, tx_execution_context: TransactionExecutionContext, initial_syscall_ptr: RelocatableValue, - caller_address: int, - contract_address: int, + general_config: StarknetGeneralConfig, + entry_point: ExecuteEntryPointBase, + support_reverted: bool, ): super().__init__(segments=segments, initial_syscall_ptr=initial_syscall_ptr) + # Entry point info. + self.entry_point = entry_point + self.execute_entry_point_cls: Type[ExecuteEntryPointBase] = type(entry_point) + + # Configuration objects. + self.general_config = general_config + + # Execution-related objects. self.tx_execution_context = tx_execution_context - self.caller_address = caller_address + self.resources_manager = resources_manager + self.state = state + self.support_reverted = support_reverted # The storage which the current call acts on. - self.storage = ContractStorageState(state=state, contract_address=contract_address) + self.storage = ContractStorageState( + state=state, contract_address=self.entry_point.contract_address + ) # A list of dynamically allocated segments that are expected to be read-only. self.read_only_segments: List[Tuple[RelocatableValue, int]] = [] + # Internal calls executed by the current contract call. + self.internal_calls: List[CallInfo] = [] + # Events emitted by the current contract call. self.events: List[OrderedEvent] = [] + # Messages sent by the current contract call to L1. + self.l2_to_l1_messages: List[OrderedL2ToL1Message] = [] + + # A pointer to the Cairo ExecutionInfo struct. + self._execution_info_ptr: Optional[RelocatableValue] = None + # Syscalls. + def _call_contract_helper( + self, remaining_gas: int, request: CairoStructProxy, syscall_name: str + ) -> CallResult: + calldata = self._get_felt_range( + start_addr=request.calldata_start, end_addr=request.calldata_end + ) + class_hash: Optional[int] = None + if syscall_name == "call_contract": + contract_address = cast_to_int(request.contract_address) + caller_address = self.entry_point.contract_address + call_type = CallType.CALL + elif syscall_name == "library_call": + contract_address = self.entry_point.contract_address + caller_address = self.entry_point.caller_address + call_type = CallType.DELEGATE + class_hash = cast_to_int(request.class_hash) + else: + raise NotImplementedError(f"Unsupported call type {syscall_name}.") + + call = self.execute_entry_point_cls( + call_type=call_type, + contract_address=contract_address, + entry_point_selector=cast_to_int(request.selector), + entry_point_type=EntryPointType.EXTERNAL, + calldata=calldata, + caller_address=caller_address, + initial_gas=remaining_gas, + class_hash=class_hash, + code_address=None, + ) + + return self.execute_entry_point(call=call) + + def _deploy(self, remaining_gas: int, request: CairoStructProxy) -> Tuple[int, CallResult]: + assert request.deploy_from_zero in [0, 1], "The deploy_from_zero field must be 0 or 1." + constructor_calldata = self._get_felt_range( + start_addr=request.constructor_calldata_start, end_addr=request.constructor_calldata_end + ) + class_hash = cast_to_int(request.class_hash) + + # Calculate contract address. + deployer_address = self.entry_point.contract_address if request.deploy_from_zero == 0 else 0 + contract_address = calculate_contract_address_from_hash( + salt=cast_to_int(request.contract_address_salt), + class_hash=class_hash, + constructor_calldata=constructor_calldata, + deployer_address=deployer_address, + ) + # Instantiate the contract (may raise UNDECLARED_CLASS and CONTRACT_ADDRESS_UNAVAILABLE). + self.state.deploy_contract(contract_address=contract_address, class_hash=class_hash) + + # Invoke constructor. + result = self.execute_constructor_entry_point( + contract_address=contract_address, + class_hash=class_hash, + constructor_calldata=constructor_calldata, + remaining_gas=remaining_gas, + ) + return contract_address, result + + def _get_execution_info_ptr(self) -> RelocatableValue: + if self._execution_info_ptr is None: + # Prepare block info. + python_block_info = self.storage.state.block_info + block_info = self.structs.BlockInfo( + block_number=python_block_info.block_number, + block_timestamp=python_block_info.block_timestamp, + sequencer_address=as_non_optional(python_block_info.sequencer_address), + ) + # Prepare transaction info. + signature = self.tx_execution_context.signature + signature_start = self.allocate_segment(data=signature) + tx_info = self.structs.TxInfo( + version=self.tx_execution_context.version, + account_contract_address=self.tx_execution_context.account_contract_address, + max_fee=self.tx_execution_context.max_fee, + signature_start=signature_start, + signature_end=signature_start + len(signature), + transaction_hash=self.tx_execution_context.transaction_hash, + chain_id=self.general_config.chain_id.value, + nonce=self.tx_execution_context.nonce, + ) + # Gather all info. + execution_info = self.structs.ExecutionInfo( + block_info=self.allocate_segment(data=block_info), + tx_info=self.allocate_segment(data=tx_info), + caller_address=self.entry_point.caller_address, + contract_address=self.entry_point.contract_address, + selector=self.entry_point.entry_point_selector, + ) + self._execution_info_ptr = self.allocate_segment(data=execution_info) + + return self._execution_info_ptr + def _storage_read(self, key: int) -> int: return self.storage.read(address=key) def _storage_write(self, key: int, value: int): self.storage.write(address=key, value=value) - def _get_caller_address(self) -> int: - return self.caller_address - def _emit_event(self, keys: List[int], data: List[int]): self.events.append( OrderedEvent(order=self.tx_execution_context.n_emitted_events, keys=keys, data=data) @@ -323,14 +573,95 @@ def _emit_event(self, keys: List[int], data: List[int]): # Update events count. self.tx_execution_context.n_emitted_events += 1 + def _replace_class(self, class_hash: int): + compiled_class_hash = self.storage.state.get_compiled_class_hash(class_hash=class_hash) + stark_assert( + compiled_class_hash != 0, + code=StarknetErrorCode.UNDECLARED_CLASS, + message=f"Class with hash {class_hash} is not declared.", + ) + + # Replace the class. + self.state.set_class_hash_at( + contract_address=self.entry_point.contract_address, class_hash=class_hash + ) + + def _send_message_to_l1(self, to_address: int, payload: List[int]): + self.l2_to_l1_messages.append( + # Note that the constructor of OrderedL2ToL1Message might fail as it is + # more restrictive than the Cairo code. + OrderedL2ToL1Message( + order=self.tx_execution_context.n_sent_messages, + to_address=to_address, + payload=payload, + ) + ) + + # Update messages count. + self.tx_execution_context.n_sent_messages += 1 + # Utilities. + def execute_entry_point(self, call: ExecuteEntryPointBase) -> CallResult: + with wrap_with_handler_exception(call=call): + call_info = call.execute( + state=self.state, + resources_manager=self.resources_manager, + tx_execution_context=self.tx_execution_context, + general_config=self.general_config, + support_reverted=self.support_reverted, + ) + + self.internal_calls.append(call_info) + return call_info.result() + + def execute_constructor_entry_point( + self, + contract_address: int, + class_hash: int, + constructor_calldata: List[int], + remaining_gas: int, + ) -> CallResult: + contract_class = self.state.get_compiled_class_by_class_hash(class_hash=class_hash) + constructor_entry_points = contract_class.entry_points_by_type[EntryPointType.CONSTRUCTOR] + if len(constructor_entry_points) == 0: + # Contract has no constructor. + assert ( + len(constructor_calldata) == 0 + ), "Cannot pass calldata to a contract with no constructor." + + call_info = CallInfo.empty_constructor_call( + contract_address=contract_address, + caller_address=self.entry_point.contract_address, + class_hash=class_hash, + ) + self.internal_calls.append(call_info) + + return call_info.result() + + call = self.execute_entry_point_cls( + call_type=CallType.CALL, + contract_address=contract_address, + entry_point_selector=CONSTRUCTOR_ENTRY_POINT_SELECTOR, + entry_point_type=EntryPointType.CONSTRUCTOR, + calldata=constructor_calldata, + caller_address=self.entry_point.contract_address, + initial_gas=remaining_gas, + class_hash=None, + code_address=None, + ) + + return self.execute_entry_point(call=call) + def allocate_segment(self, data: Iterable[MaybeRelocatable]) -> RelocatableValue: segment_start = self.segments.add() segment_end = self.segments.write_arg(ptr=segment_start, arg=data) self.read_only_segments.append((segment_start, segment_end - segment_start)) return segment_start + def _allocate_segment_for_retdata(self, retdata: Iterable[int]) -> RelocatableValue: + return self.allocate_segment(data=retdata) + def post_run(self, runner: CairoFunctionRunner, syscall_end_ptr: MaybeRelocatable): """ Performs post-run syscall-related tasks. @@ -379,10 +710,8 @@ def __init__( ) self.call_iterator: Iterator[CallInfo] = iter([]) - # A stack that keeps track of the state of the calls being executed now. - # The last item is the state of the current call; the one before it, is the - # state of the caller (the call the called the current call); and so on. - self.call_stack: List[CallInfo] = [] + # The CallInfo for the call currently being executed. + self._call_info: Optional[CallInfo] = None # An iterator over contract addresses that were deployed during that call. self.deployed_contracts_iterator: Iterator[int] = iter([]) @@ -394,19 +723,30 @@ def __init__( # code is executed. self.execute_code_read_iterator: Iterator[int] = iter([]) - # A pointer to the Cairo TxInfo struct. - # This pointer needs to match the TxInfo pointer that is going to be used during the system - # call validation by the StarkNet OS. - # Set during enter_tx. - self.tx_info_ptr: Optional[RelocatableValue] = None - # The TransactionExecutionInfo for the transaction currently being executed. self.tx_execution_info: Optional[TransactionExecutionInfo] = None - # StarkNet storage-related members. + # Starknet storage-related members. self.storage_by_address = storage_by_address - def start_tx(self, tx_info_ptr: RelocatableValue): + # A pointer to the Cairo (deprecated) TxInfo struct. + # This pointer needs to match the DeprecatedTxInfo pointer that is going to be used during + # the system call validation by the Starknet OS. + # Set during enter_tx. + self.tx_info_ptr: Optional[RelocatableValue] = None + + # A pointer to the Cairo ExecutionInfo struct of the current call. + # This pointer needs to match the ExecutionInfo pointer that is going to be used during the + # system call validation by the StarkNet OS. + # Set during enter_call. + self.call_execution_info_ptr: Optional[RelocatableValue] = None + + @property + def call_info(self) -> CallInfo: + assert self._call_info is not None + return self._call_info + + def start_tx(self, tx_info_ptr: Optional[RelocatableValue]): """ Called when starting the execution of a transaction. @@ -424,7 +764,6 @@ def end_tx(self): Called after the execution of the current transaction complete. """ assert_exhausted(iterator=self.call_iterator) - assert self.tx_info_ptr is not None self.tx_info_ptr = None assert self.tx_execution_info is not None self.tx_execution_info = None @@ -434,30 +773,45 @@ def assert_interators_exhausted(self): assert_exhausted(iterator=self.result_iterator) assert_exhausted(iterator=self.execute_code_read_iterator) - def enter_call(self): + def enter_call(self, execution_info_ptr: Optional[RelocatableValue]): + assert self.call_execution_info_ptr is None + self.call_execution_info_ptr = execution_info_ptr + self.assert_interators_exhausted() - call_info = next(self.call_iterator) - self.call_stack.append(call_info) + assert self._call_info is None + self._call_info = next(self.call_iterator) self.deployed_contracts_iterator = ( call.contract_address - for call in call_info.internal_calls + for call in self.call_info.internal_calls if call.entry_point_type is EntryPointType.CONSTRUCTOR ) - self.result_iterator = (call.result() for call in call_info.internal_calls) - self.execute_code_read_iterator = iter(call_info.storage_read_values) + self.result_iterator = (call.result() for call in self.call_info.internal_calls) + self.execute_code_read_iterator = iter(self.call_info.storage_read_values) def exit_call(self): + self.call_execution_info_ptr = None + self.assert_interators_exhausted() - self.call_stack.pop() + assert self._call_info is not None + self._call_info = None + + def skip_call(self): + """ + Called when skipping the execution of a call. + It replaces a call to enter_call and exit_call. + """ + self.enter_call(execution_info_ptr=None) + self.exit_call() def skip_tx(self): """ Called when skipping the execution of a transaction. It replaces a call to start_tx and end_tx. """ - next(self.tx_execution_info_iterator) + self.start_tx(tx_info_ptr=None) + self.end_tx() class OsSyscallHandler(SyscallHandlerBase): @@ -491,20 +845,44 @@ def allocate_segment(self, data: Iterable[MaybeRelocatable]) -> RelocatableValue self.segments.write_arg(ptr=segment_start, arg=data) return segment_start + def _allocate_segment_for_retdata(self, retdata: Iterable[int]) -> RelocatableValue: + segment_start = self.segments.add_temp_segment() + self.segments.write_arg(ptr=segment_start, arg=retdata) + return segment_start + # Syscalls. + def _call_contract_helper( + self, remaining_gas: int, request: CairoStructProxy, syscall_name: str + ) -> CallResult: + return next(self.execution_helper.result_iterator) + + def _deploy(self, remaining_gas: int, request: CairoStructProxy) -> Tuple[int, CallResult]: + constructor_result = next(self.execution_helper.result_iterator) + contract_address = next(self.execution_helper.deployed_contracts_iterator) + return contract_address, constructor_result + + def _get_execution_info_ptr(self) -> RelocatableValue: + assert ( + self.execution_helper.call_execution_info_ptr is not None + ), "ExecutionInfo pointer is not set." + return self.execution_helper.call_execution_info_ptr + def _storage_read(self, key: int) -> int: return next(self.execution_helper.execute_code_read_iterator) def _storage_write(self, key: int, value: int): return - def _get_caller_address(self) -> int: - return self.execution_helper.call_stack[-1].caller_address - def _emit_event(self, keys: List[int], data: List[int]): return + def _replace_class(self, class_hash: int): + return + + def _send_message_to_l1(self, to_address: int, payload: List[int]): + return + # Deprecated handlers. @@ -1041,7 +1419,7 @@ def execute_constructor_entry_point( self.execute_entry_point(call=call) def execute_entry_point(self, call: ExecuteEntryPointBase) -> CallResult: - with self.entry_point_execution_context(call=call): + with wrap_with_handler_exception(call=call): # Execute contract call. call_info = call.execute( state=self.sync_state, @@ -1055,25 +1433,6 @@ def execute_entry_point(self, call: ExecuteEntryPointBase) -> CallResult: return call_info.result() - @contextlib.contextmanager - def entry_point_execution_context(self, call: ExecuteEntryPointBase): - try: - yield - except StarkException as exception: - raise HandlerException( - called_contract_address=call.contract_address, stark_exception=exception - ) from exception - except Exception as exception: - # Exceptions caught here that are not StarkException, are necessarily caused due to - # security issues, since every exception raised from a Cairo run (in _run) is already - # wrapped with StarkException. - stark_exception = StarkException( - code=StarknetErrorCode.SECURITY_ERROR, message=str(exception) - ) - raise HandlerException( - called_contract_address=call.contract_address, stark_exception=stark_exception - ) from exception - def emit_event(self, segments: MemorySegmentManager, syscall_ptr: RelocatableValue): """ Handles the emit_event system call. @@ -1240,10 +1599,10 @@ def _deploy(self, syscall_ptr: RelocatableValue) -> int: return next(self.execution_helper.deployed_contracts_iterator) def _get_caller_address(self, syscall_ptr: RelocatableValue) -> int: - return self.execution_helper.call_stack[-1].caller_address + return self.execution_helper.call_info.caller_address def _get_contract_address(self, syscall_ptr: RelocatableValue) -> int: - return self.execution_helper.call_stack[-1].contract_address + return self.execution_helper.call_info.contract_address def _get_tx_info_ptr(self) -> RelocatableValue: assert self.execution_helper.tx_info_ptr is not None diff --git a/src/starkware/starknet/core/os/syscall_utils.py b/src/starkware/starknet/core/os/syscall_utils.py index 45d45795..80bebc35 100644 --- a/src/starkware/starknet/core/os/syscall_utils.py +++ b/src/starkware/starknet/core/os/syscall_utils.py @@ -1,3 +1,4 @@ +import contextlib import dataclasses import functools import os @@ -11,6 +12,10 @@ from starkware.cairo.lang.compiler.program import Program from starkware.cairo.lang.vm.relocatable import RelocatableValue from starkware.python.utils import safe_zip +from starkware.starknet.business_logic.execution.execute_entry_point_base import ( + ExecuteEntryPointBase, +) +from starkware.starknet.definitions.error_codes import StarknetErrorCode from starkware.starkware_utils.error_handling import StarkException STARKNET_OLD_SYSCALLS_COMPILED_PATH = os.path.join( @@ -161,6 +166,26 @@ class HandlerException(Exception): stark_exception: StarkException +@contextlib.contextmanager +def wrap_with_handler_exception(call: ExecuteEntryPointBase): + try: + yield + except StarkException as exception: + raise HandlerException( + called_contract_address=call.contract_address, stark_exception=exception + ) from exception + except Exception as exception: + # Exceptions caught here that are not StarkException, are necessarily caused due to + # security issues, since every exception raised from a Cairo run (in _run) is already + # wrapped with StarkException. + stark_exception = StarkException( + code=StarknetErrorCode.SECURITY_ERROR, message=str(exception) + ) + raise HandlerException( + called_contract_address=call.contract_address, stark_exception=stark_exception + ) from exception + + def get_selector_from_program(syscall_name: str, syscalls_program: Program) -> int: return syscalls_program.get_const( name=f"__main__.{syscall_name.upper()}_SELECTOR", full_name_lookup=True diff --git a/src/starkware/starknet/core/test_contract/test_utils.py b/src/starkware/starknet/core/test_contract/test_utils.py index 91f98e1f..d170ba5d 100644 --- a/src/starkware/starknet/core/test_contract/test_utils.py +++ b/src/starkware/starknet/core/test_contract/test_utils.py @@ -2,7 +2,6 @@ import os from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME -from starkware.cairo.lang.compiler.program import HintedProgram from starkware.starknet.services.api.contract_class.contract_class import ( CompiledClass, CompiledClassEntryPoint, @@ -26,13 +25,11 @@ def get_test_deprecated_compiled_class() -> DeprecatedCompiledClass: def get_test_compiled_class() -> CompiledClass: return CompiledClass( - program=HintedProgram( - prime=DEFAULT_PRIME, - data=[1, 2, 3], - builtins=[], - hints={}, - compiler_version="", - ), + prime=DEFAULT_PRIME, + bytecode=[1, 2, 3], + hints=[], + pythonic_hints={}, + compiler_version="", entry_points_by_type={ EntryPointType.EXTERNAL: [ CompiledClassEntryPoint(selector=1, offset=1, builtins=["237"]) diff --git a/src/starkware/starknet/definitions/constants.py b/src/starkware/starknet/definitions/constants.py index 80c86d91..0ac178cf 100644 --- a/src/starkware/starknet/definitions/constants.py +++ b/src/starkware/starknet/definitions/constants.py @@ -62,6 +62,7 @@ DECLARE_VERSION = 2 QUERY_VERSION_BASE = 2**128 QUERY_VERSION = QUERY_VERSION_BASE + TRANSACTION_VERSION +QUERY_DECLARE_VERSION = QUERY_VERSION_BASE + DECLARE_VERSION DEPRECATED_DECLARE_VERSIONS = ( 0, 1, @@ -69,6 +70,9 @@ QUERY_VERSION_BASE + 1, ) +# The version of contract class leaf. +CONTRACT_CLASS_LEAF_VERSION: bytes = b"CONTRACT_CLASS_LEAF_V0" + # The version of the Starknet global state. GLOBAL_STATE_VERSION = from_bytes(b"STARKNET_STATE_V0") @@ -111,10 +115,15 @@ class GasCost(Enum): FEE_TRANSFER = ENTRY_POINT + 100 * STEP TRANSACTION = (2 * ENTRY_POINT) + FEE_TRANSFER + (100 * STEP) # Syscall cas costs. + CALL_CONTRACT = SYSCALL_BASE + 10 * STEP + ENTRY_POINT + DEPLOY = SYSCALL_BASE + 200 * STEP + ENTRY_POINT + GET_EXECUTION_INFO = SYSCALL_BASE + 10 * STEP + LIBRARY_CALL = CALL_CONTRACT + REPLACE_CLASS = SYSCALL_BASE + 50 * STEP STORAGE_READ = SYSCALL_BASE + 50 * STEP STORAGE_WRITE = SYSCALL_BASE + 50 * STEP - GET_CALLER_ADDRESS = SYSCALL_BASE + 10 * STEP EMIT_EVENT = SYSCALL_BASE + 10 * STEP + SEND_MESSAGE_TO_L1 = SYSCALL_BASE + 50 * STEP @property def int_value(self) -> int: diff --git a/src/starkware/starknet/definitions/error_codes.py b/src/starkware/starknet/definitions/error_codes.py index 8c0e80e0..9e4c7a60 100644 --- a/src/starkware/starknet/definitions/error_codes.py +++ b/src/starkware/starknet/definitions/error_codes.py @@ -8,6 +8,7 @@ class StarknetErrorCode(ErrorCode): BLOCK_NOT_FOUND = 0 CLASS_ALREADY_DECLARED = auto() + COMPILATION_FAILED = auto() CONTRACT_ADDRESS_UNAVAILABLE = auto() CONTRACT_BYTECODE_SIZE_TOO_LARGE = auto() CONTRACT_CLASS_OBJECT_SIZE_TOO_LARGE = auto() @@ -19,6 +20,7 @@ class StarknetErrorCode(ErrorCode): INVALID_BLOCK_TIMESTAMP = auto() INVALID_COMPILED_CLASS_HASH = auto() INVALID_CONTRACT_CLASS = auto() + INVALID_CONTRACT_CLASS_VERSION = auto() INVALID_PROGRAM = auto() INVALID_RETURN_DATA = auto() INVALID_STATUS_MODE = auto() @@ -63,6 +65,7 @@ class StarknetErrorCode(ErrorCode): UNDECLARED_CLASS = auto() UNEXPECTED_FAILURE = auto() UNINITIALIZED_CONTRACT = auto() + UNSUPPORTED_TRANSACTION = auto() # Errors that are raised by the gateways and caused by wrong usage of the user. @@ -87,8 +90,10 @@ class StarknetErrorCode(ErrorCode): StarknetErrorCode.MISSING_ENTRY_POINT_FOR_INVOKE, StarknetErrorCode.UNAUTHORIZED_ENTRY_POINT_FOR_INVOKE, # Contract class validation. + StarknetErrorCode.COMPILATION_FAILED, StarknetErrorCode.INVALID_COMPILED_CLASS_HASH, StarknetErrorCode.INVALID_CONTRACT_CLASS, + StarknetErrorCode.INVALID_CONTRACT_CLASS_VERSION, # Validate execution. StarknetErrorCode.UNAUTHORIZED_ACTION_ON_VALIDATE, ] @@ -143,6 +148,7 @@ class StarknetErrorCode(ErrorCode): StarknetErrorCode.OUT_OF_RANGE_CONTRACT_STORAGE_KEY, StarknetErrorCode.OUT_OF_RANGE_TRANSACTION_HASH, StarknetErrorCode.OUT_OF_RANGE_TRANSACTION_ID, + StarknetErrorCode.UNSUPPORTED_TRANSACTION, ] ) diff --git a/src/starkware/starknet/definitions/fields.py b/src/starkware/starknet/definitions/fields.py index 10989968..bb42c80a 100644 --- a/src/starkware/starknet/definitions/fields.py +++ b/src/starkware/starknet/definitions/fields.py @@ -41,6 +41,8 @@ ) ) +calldata_metadata = felt_as_hex_or_str_list_metadata + felt_list_metadata = dict( marshmallow_field=mfields.List(IntAsStr(validate=everest_fields.FeltField.validate)) ) @@ -155,7 +157,6 @@ def address_metadata(name: str, error_code: StarknetErrorCode) -> Dict[str, Any] # InvokeFunction. -call_data_metadata = felt_list_metadata call_data_as_hex_metadata = felt_as_hex_list_metadata signature_as_hex_metadata = felt_as_hex_or_str_list_metadata signature_metadata = felt_list_metadata @@ -166,6 +167,19 @@ def address_metadata(name: str, error_code: StarknetErrorCode) -> Dict[str, Any] payload_metadata = felt_as_hex_list_metadata +# Used in the CallL1Handler, to solve compatibility bug. +FromAddressEthAddressField = BackwardCompatibleIntAsHex( + allow_decimal_loading=True, + allow_bytes_hex_loading=False, + allow_int_loading=True, + required=True, + validate=everest_fields.EthAddressIntField.validate, +) + +from_address_field_metadata = dict( + marshmallow_field=FromAddressEthAddressField, field_name="from_address" +) + # Contract address. L2AddressField = RangeValidatedField( @@ -312,7 +326,7 @@ def validate_optional_new_class_hash(class_hash: Optional[int]): upper_bound=constants.ENTRY_POINT_OFFSET_UPPER_BOUND, name="Entry point offset", error_code=StarknetErrorCode.OUT_OF_RANGE_ENTRY_POINT_OFFSET, - formatter=hex, + formatter=None, ) entry_point_offset_metadata = EntryPointOffsetField.metadata() diff --git a/src/starkware/starknet/definitions/general_config.py b/src/starkware/starknet/definitions/general_config.py index 237a6af7..08e74bda 100644 --- a/src/starkware/starknet/definitions/general_config.py +++ b/src/starkware/starknet/definitions/general_config.py @@ -71,6 +71,7 @@ class StarknetChainId(Enum): DEFAULT_CAIRO_RESOURCE_FEE_WEIGHTS = { N_STEPS_RESOURCE: 1.0, **{builtin: 0.0 for builtin in ALL_BUILTINS.except_for(KECCAK_BUILTIN).with_suffix()}, + "segment_arena_builtin": 0.0, } @@ -223,5 +224,6 @@ def build_general_config(raw_general_config: Dict[str, Any]) -> StarknetGeneralC }, } ) + cairo_resource_fee_weights["segment_arena_builtin"] = n_steps_weight * 10 return StarknetGeneralConfig.load(data=raw_general_config) diff --git a/src/starkware/starknet/services/api/contract_class/CMakeLists.txt b/src/starkware/starknet/services/api/contract_class/CMakeLists.txt index 61aea46a..5e85deb0 100644 --- a/src/starkware/starknet/services/api/contract_class/CMakeLists.txt +++ b/src/starkware/starknet/services/api/contract_class/CMakeLists.txt @@ -16,3 +16,16 @@ python_lib(starknet_contract_class_lib pip_marshmallow pip_marshmallow_dataclass ) + +python_lib(starknet_contract_class_utils_lib + PREFIX starkware/starknet/services/api/contract_class + PYTHON ${PYTHON_COMMAND} + + FILES + contract_class_utils.py + + LIBS + cairo_compile_lib + starknet_compile_v1_lib + starknet_contract_class_lib +) diff --git a/src/starkware/starknet/services/api/contract_class/contract_class.py b/src/starkware/starknet/services/api/contract_class/contract_class.py index 0e962e27..2f1183bf 100644 --- a/src/starkware/starknet/services/api/contract_class/contract_class.py +++ b/src/starkware/starknet/services/api/contract_class/contract_class.py @@ -1,21 +1,25 @@ import dataclasses +import re +from abc import abstractmethod from dataclasses import field from enum import Enum, auto from typing import Any, Dict, List, Optional import marshmallow +import marshmallow.fields as mfields import marshmallow_dataclass from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager from starkware.cairo.lang.compiler.preprocessor.flow import ReferenceManager -from starkware.cairo.lang.compiler.program import HintedProgram, Program +from starkware.cairo.lang.compiler.program import CairoHint, Program from starkware.cairo.lang.compiler.scoped_name import ScopedName from starkware.python.utils import as_non_optional from starkware.starknet.definitions import fields from starkware.starknet.definitions.error_codes import StarknetErrorCode from starkware.starknet.public.abi import AbiType from starkware.starkware_utils.error_handling import stark_assert +from starkware.starkware_utils.marshmallow_dataclass_fields import IntAsHex, additional_metadata from starkware.starkware_utils.subsequence import is_subsequence from starkware.starkware_utils.validated_dataclass import ( ValidatedDataclass, @@ -23,7 +27,15 @@ ) # An ordered list of the supported builtins. -SUPPORTED_BUILTINS = ["pedersen", "range_check", "ecdsa", "bitwise", "ec_op", "poseidon"] +SUPPORTED_BUILTINS = [ + "pedersen", + "range_check", + "ecdsa", + "bitwise", + "ec_op", + "poseidon", + "segment_arena", +] # Utilites. @@ -67,7 +79,7 @@ class ContractClass(ValidatedMarshmallowDataclass): abi: str -@dataclasses.dataclass(frozen=True) +@marshmallow_dataclass.dataclass(frozen=True) class CompiledClassEntryPoint(ValidatedDataclass): # A field element that encodes the signature of the called function. selector: int = field(metadata=fields.entry_point_selector_metadata) @@ -76,12 +88,44 @@ class CompiledClassEntryPoint(ValidatedDataclass): # Builtins used by the entry point. builtins: Optional[List[str]] + @marshmallow.decorators.pre_load + def load_offset_formatted_as_hex( + self, data: Dict[str, Any], many: bool, **kwargs + ) -> Dict[str, Any]: + offset = data["offset"] + if isinstance(offset, str): + assert ( + re.match("^0x[0-9a-f]+$", offset) is not None + ), f"offset field is of unexpected format: {offset}." + data["offset"] = int(offset, 16) -@marshmallow_dataclass.dataclass(frozen=True) + return data + + +# Mypy has a problem with dataclasses that contain unimplemented abstract methods. +# See https://github.com/python/mypy/issues/5374 for details on this problem. +@marshmallow_dataclass.dataclass(frozen=True) # type: ignore[misc] class CompiledClassBase(ValidatedMarshmallowDataclass): - program: HintedProgram entry_points_by_type: Dict[EntryPointType, List[CompiledClassEntryPoint]] + @abstractmethod + def get_builtins(self) -> List[str]: + """ + Returns the "builtins" attribute of the compiled class. + """ + + @abstractmethod + def get_prime(self) -> int: + """ + Returns the "prime" attribute of the compiled class. + """ + + @abstractmethod + def get_bytecode(self) -> List[int]: + """ + Returns the "bytecode" attribute of the compiled class. + """ + def __post_init__(self): super().__post_init__() @@ -109,16 +153,16 @@ def __post_init__(self): ) def validate(self): - validate_builtins(builtins=self.program.builtins) + validate_builtins(builtins=self.get_builtins()) for entry_points in self.entry_points_by_type.values(): for entry_point in entry_points: validate_builtins(builtins=entry_point.builtins) stark_assert( - self.program.prime == DEFAULT_PRIME, + self.get_prime() == DEFAULT_PRIME, code=StarknetErrorCode.INVALID_CONTRACT_CLASS, message=( - f"Invalid value for field prime: {self.program.prime}. Expected: {DEFAULT_PRIME}." + f"Invalid value for field prime: {self.get_prime()}. Expected: {DEFAULT_PRIME}." ), ) @@ -137,15 +181,71 @@ class CompiledClass(CompiledClassBase): Represents a compiled contract class in the StarkNet network. """ - def __post_init__(self): - super().__post_init__() + prime: int = field(metadata=additional_metadata(marshmallow_field=IntAsHex(required=True))) + bytecode: List[int] = field( + metadata=additional_metadata(marshmallow_field=mfields.List(IntAsHex(), required=True)) + ) + # Rust hints. + hints: List[Any] + pythonic_hints: Dict[int, List[CairoHint]] + compiler_version: str = field( + metadata=dict(marshmallow_field=mfields.String(required=False, load_default=None)) + ) - stark_assert( - len(self.program.builtins) == 0, - code=StarknetErrorCode.INVALID_CONTRACT_CLASS, - message="Builtins should be specified per entry point.", + def get_builtins(self) -> List[str]: + return [] + + def get_prime(self) -> int: + return self.prime + + def get_bytecode(self) -> List[int]: + return self.bytecode + + @marshmallow.decorators.pre_load + def parse_pythonic_hints(self, data: Dict[str, Any], many: bool, **kwargs) -> Dict[str, Any]: + """ + Parses Cairo 1.0 casm hints. + Each hint comprises a two-item List: an ID (int) and a List of hint codes (strings). + The returned CairoHint object takes empty "accessible_scopes" and "flow_tracking_data" + values as these are only relevant to Cairo 0 programs. + """ + assert "program" not in data, ( + "Unsupported compiled class format. " + "Cairo 1.0 compiled class must not contain the attribute `program`." ) + pythonic_hints = data["pythonic_hints"] + empty_accessible_scope: List = [] + empty_flow_tracking_data: Dict[str, Any] = { + "ap_tracking": {"group": 0, "offset": 0}, + "reference_ids": {}, + } + + data["pythonic_hints"] = { + hint_id: [ + { + "code": hint_code, + "accessible_scopes": empty_accessible_scope, + "flow_tracking_data": empty_flow_tracking_data, + } + for hint_code in hint_codes + ] + for hint_id, hint_codes in pythonic_hints + } + + return data + + @marshmallow.decorators.post_dump + def dump_pythonic_hints(self, data: Dict[str, Any], many: bool, **kwargs) -> Dict[str, Any]: + data["pythonic_hints"] = [ + [hint_id, [hint_obj["code"] for hint_obj in hint_obj_list]] + for hint_id, hint_obj_list in data["pythonic_hints"].items() + ] + return data + + def __post_init__(self): + super().__post_init__() + for entry_points in self.entry_points_by_type.values(): for entry_point in entry_points: stark_assert( @@ -159,12 +259,12 @@ def get_runnable_program(self, entrypoint_builtins: List[str]) -> Program: Converts the HintedProgram into a Program object that can be run by the Python CairoRunner. """ return Program( - prime=self.program.prime, - data=self.program.data, + prime=self.prime, + data=self.bytecode, # Buitlins for the entrypoint to execute. builtins=entrypoint_builtins, - hints=self.program.hints, - compiler_version=self.program.compiler_version, + hints=self.pythonic_hints, + compiler_version=self.compiler_version, # Fill missing fields with empty values. main_scope=ScopedName(), identifiers=IdentifierManager(), @@ -183,6 +283,15 @@ class DeprecatedCompiledClass(CompiledClassBase): program: Program abi: Optional[AbiType] = None + def get_builtins(self) -> List[str]: + return self.program.builtins + + def get_prime(self) -> int: + return self.program.prime + + def get_bytecode(self) -> List[int]: + return self.program.data + @marshmallow.decorators.post_dump def remove_none_builtins(self, data: Dict[str, Any], many: bool, **kwargs) -> Dict[str, Any]: """ diff --git a/src/starkware/starknet/services/api/contract_class/contract_class_utils.py b/src/starkware/starknet/services/api/contract_class/contract_class_utils.py new file mode 100644 index 00000000..a25085fd --- /dev/null +++ b/src/starkware/starknet/services/api/contract_class/contract_class_utils.py @@ -0,0 +1,60 @@ +import json +import tempfile +from typing import Any, Dict, Optional + +from starkware.starknet.compiler.v1.compile import JsonObject, compile_sierra_to_casm +from starkware.starknet.services.api.contract_class.contract_class import ( + CompiledClass, + ContractClass, +) + + +def compile_contract_class( + contract_class: ContractClass, allowed_libfuncs_list_name: Optional[str] = None +) -> CompiledClass: + """ + Compiles a contract class to a compiled class. + """ + # Extract Sierra representation from the contact class. + sierra = contract_class.dump() + sierra.pop("abi", None) + + # Create a temporary Sierra file. + with tempfile.NamedTemporaryFile(mode="w") as sierra_file: + # Obtain temp file name. + temp_sierra_file_name = sierra_file.name + # Write the contract class Sierra representation to the sierra file. + json.dump(obj=sierra, fp=sierra_file, indent=2) + + # Flush the Sierra content to the file. + sierra_file.flush() + + # Compile the Sierra file. + casm_from_compiled_sierra = compile_sierra_to_casm( + sierra_path=temp_sierra_file_name, allowed_libfuncs_list_name=allowed_libfuncs_list_name + ) + + # Parse the resultant Casm file. + return CompiledClass.load(data=casm_from_compiled_sierra) + + +def load_file(path: str) -> JsonObject: + with open(path, "r") as fp: + return json.load(fp) + + +def load_sierra(sierra_path: str) -> ContractClass: + sierra = load_file(path=sierra_path) + return load_sierra_from_dict(sierra=sierra) + + +def load_sierra_from_dict(sierra: JsonObject) -> ContractClass: + sierra.pop("sierra_program_debug_info", None) + convert_sierra_program_abi_to_string(sierra_program=sierra) + return ContractClass.load(data=sierra) + + +def convert_sierra_program_abi_to_string(sierra_program: Dict[str, Any]): + if "abi" in sierra_program: + assert isinstance(sierra_program["abi"], list), "Unexpected ABI type." + sierra_program["abi"] = json.dumps(obj=sierra_program["abi"]) diff --git a/src/starkware/starknet/services/api/contract_class/contracts/test_contract_cairo1.cairo b/src/starkware/starknet/services/api/contract_class/contracts/test_contract_cairo1.cairo new file mode 100644 index 00000000..d864c1cc --- /dev/null +++ b/src/starkware/starknet/services/api/contract_class/contracts/test_contract_cairo1.cairo @@ -0,0 +1,119 @@ +#[contract] +mod TestContract { + use starknet::storage_read_syscall; + use starknet::storage_write_syscall; + use starknet::syscalls::emit_event_syscall; + use starknet::StorageAddress; + use starknet::ContractAddress; + use starknet::storage_access::storage_base_address_from_felt252; + use starknet::storage_access::storage_address_from_base_and_offset; + use starknet::class_hash::ClassHash; + use starknet::class_hash::ClassHashSerde; + use starknet::ContractAddressIntoFelt252; + use traits::Into; + use array::SpanTrait; + use array::ArrayTrait; + use box::BoxTrait; + + const UNEXPECTED_ERROR: felt252 = 'UNEXPECTED ERROR'; + + struct Storage { + my_storage_var: felt252 + } + + #[external] + fn test(ref arg: felt252, arg1: felt252, arg2: felt252) -> felt252 { + let x = my_storage_var::read(); + my_storage_var::write(x + 1); + x + 1 + } + + #[external] + fn test_storage_read(address: felt252) -> felt252 { + let domain_address = 0_u32; // Only address_domain 0 is currently supported. + let storage_address = storage_address_from_base_and_offset( + storage_base_address_from_felt252(address), 0_u8 + ); + storage_read_syscall(domain_address, storage_address).unwrap_syscall() + } + + #[external] + fn test_storage_write(address: felt252, value: felt252) { + let domain_address = 0_u32; // Only address_domain 0 is currently supported. + let storage_address = storage_address_from_base_and_offset( + storage_base_address_from_felt252(address), 0_u8 + ); + storage_write_syscall(domain_address, storage_address, value).unwrap_syscall(); + } + + #[external] + fn test_get_execution_info( + // Expected block info. + block_number: felt252, + block_timestamp: felt252, + sequencer_address: felt252, + // Expected transaction info. + version: felt252, + account_address: felt252, + max_fee: felt252, + chain_id: felt252, + nonce: felt252, + // Expected call info. + caller_address: felt252, + contract_address: felt252, + entry_point_selector: felt252, + ) { + let execution_info = starknet::get_execution_info().unbox(); + let block_info = execution_info.block_info.unbox(); + assert(block_info.block_number.into() == block_number, UNEXPECTED_ERROR); + assert(block_info.block_timestamp.into() == block_timestamp, UNEXPECTED_ERROR); + assert(block_info.sequencer_address.into() == sequencer_address, UNEXPECTED_ERROR); + + let tx_info = execution_info.tx_info.unbox(); + assert(tx_info.version == version, UNEXPECTED_ERROR); + assert(tx_info.account_contract_address.into() == account_address, UNEXPECTED_ERROR); + assert(tx_info.max_fee.into() == max_fee, UNEXPECTED_ERROR); + assert(tx_info.signature.len() == 1_u32, UNEXPECTED_ERROR); + let transaction_hash = *tx_info.signature.at(0_u32); + assert(tx_info.transaction_hash == transaction_hash, UNEXPECTED_ERROR); + assert(tx_info.chain_id == chain_id, UNEXPECTED_ERROR); + assert(tx_info.nonce == nonce, UNEXPECTED_ERROR); + + assert(execution_info.caller_address.into() == caller_address, UNEXPECTED_ERROR); + assert(execution_info.contract_address.into() == contract_address, UNEXPECTED_ERROR); + assert( + execution_info.entry_point_selector == entry_point_selector, UNEXPECTED_ERROR + ); + } + + #[external] + fn test_emit_event(keys: Array::, data: Array::) { + emit_event_syscall(keys.span(), data.span()).unwrap_syscall(); + } + + #[external] + fn test_call_contract( + contract_address: ContractAddress, entry_point_selector: felt252, calldata: Array:: + ) { + starknet::call_contract_syscall( + contract_address, entry_point_selector, calldata.span() + ).unwrap_syscall(); + } + + #[external] + fn test_library_call( + class_hash: ClassHash, entry_point_selector: felt252, calldata: Array:: + ) { + starknet::syscalls::library_call_syscall( + class_hash, entry_point_selector, calldata.span() + ).unwrap_syscall(); + } + + #[external] + fn assert_eq(x: felt252, y: felt252) -> felt252{ + assert(x == y, 'x != y'); + 'success' + } + +} + diff --git a/src/starkware/starknet/services/api/feeder_gateway/feeder_gateway_client.py b/src/starkware/starknet/services/api/feeder_gateway/feeder_gateway_client.py index f667f498..3ab8e866 100644 --- a/src/starkware/starknet/services/api/feeder_gateway/feeder_gateway_client.py +++ b/src/starkware/starknet/services/api/feeder_gateway/feeder_gateway_client.py @@ -24,6 +24,10 @@ CastableToHash = Union[int, str] +# Simulation-related. +SKIP_VALIDATE = "skipValidate" + + class FeederGatewayClient(EverestFeederGatewayClient): """ A client class for the StarkNet FeederGateway. @@ -82,13 +86,14 @@ async def estimate_fee( tx: AccountTransaction, block_hash: Optional[CastableToHash] = None, block_number: Optional[BlockIdentifier] = None, + skip_validate: bool = False, ) -> FeeEstimationInfo: - formatted_block_named_argument = get_formatted_block_named_argument( - block_hash=block_hash, block_number=block_number + formatted_simulate_tx_arguments = get_formatted_simulate_tx_arguments( + block_hash=block_hash, block_number=block_number, skip_validate=skip_validate ) raw_response = await self._send_request( send_method="POST", - uri=f"/estimate_fee?{formatted_block_named_argument}", + uri=f"/estimate_fee?{formatted_simulate_tx_arguments}", data=AccountTransaction.Schema().dumps(obj=tx), ) return FeeEstimationInfo.loads(data=raw_response) @@ -98,13 +103,14 @@ async def estimate_fee_bulk( txs: List[AccountTransaction], block_hash: Optional[CastableToHash] = None, block_number: Optional[BlockIdentifier] = None, + skip_validate: bool = False, ) -> List[FeeEstimationInfo]: - formatted_block_named_argument = get_formatted_block_named_argument( - block_hash=block_hash, block_number=block_number + formatted_simulate_tx_arguments = get_formatted_simulate_tx_arguments( + block_hash=block_hash, block_number=block_number, skip_validate=skip_validate ) raw_response = await self._send_request( send_method="POST", - uri=f"/estimate_fee_bulk?{formatted_block_named_argument}", + uri=f"/estimate_fee_bulk?{formatted_simulate_tx_arguments}", data=AccountTransaction.Schema().dumps(obj=txs, many=True), ) return FeeEstimationInfo.Schema().loads(json_data=raw_response, many=True) @@ -130,13 +136,14 @@ async def simulate_transaction( tx: AccountTransaction, block_hash: Optional[CastableToHash] = None, block_number: Optional[BlockIdentifier] = None, + skip_validate: bool = False, ) -> TransactionSimulationInfo: - formatted_block_named_argument = get_formatted_block_named_argument( - block_hash=block_hash, block_number=block_number + formatted_simulate_tx_arguments = get_formatted_simulate_tx_arguments( + block_hash=block_hash, block_number=block_number, skip_validate=skip_validate ) raw_response = await self._send_request( send_method="POST", - uri=f"/simulate_transaction?{formatted_block_named_argument}", + uri=f"/simulate_transaction?{formatted_simulate_tx_arguments}", data=AccountTransaction.Schema().dumps(obj=tx), ) return TransactionSimulationInfo.loads(data=raw_response) @@ -374,3 +381,19 @@ def get_formatted_block_named_argument( return f"blockNumber={block_number_str}" else: return f"blockHash={format_hash(hash_value=block_hash, hash_field=fields.BlockHashField)}" + + +def get_formatted_simulate_tx_arguments( + block_hash: Optional[CastableToHash], + block_number: Optional[BlockIdentifier], + skip_validate: bool, +) -> str: + """ + Returns formatted simulate transaction arguments, corresponding to the request's arguments. + """ + formatted_block_named_argument = get_formatted_block_named_argument( + block_hash=block_hash, block_number=block_number + ) + formatted_simulation_flags = f"{SKIP_VALIDATE}={json.dumps(skip_validate)}" + + return "&".join([formatted_block_named_argument, formatted_simulation_flags]) diff --git a/src/starkware/starknet/services/api/feeder_gateway/request_objects.py b/src/starkware/starknet/services/api/feeder_gateway/request_objects.py index ca084ca0..1918315a 100644 --- a/src/starkware/starknet/services/api/feeder_gateway/request_objects.py +++ b/src/starkware/starknet/services/api/feeder_gateway/request_objects.py @@ -3,7 +3,6 @@ import marshmallow_dataclass -from services.everest.definitions import fields as everest_fields from starkware.starknet.business_logic.execution.execute_entry_point import ExecuteEntryPoint from starkware.starknet.business_logic.transaction.objects import InternalL1Handler from starkware.starknet.definitions import fields @@ -19,7 +18,7 @@ class CallFunction(ValidatedMarshmallowDataclass): contract_address: int = field(metadata=fields.contract_address_metadata) # A field element that encodes the invoked method. entry_point_selector: int = field(metadata=fields.entry_point_selector_metadata) - calldata: List[int] = field(metadata=fields.call_data_metadata) + calldata: List[int] = field(metadata=fields.calldata_metadata) def to_entry_point(self) -> ExecuteEntryPoint: return ExecuteEntryPoint.create_for_testing( @@ -35,9 +34,7 @@ class CallL1Handler(ValidatedMarshmallowDataclass): Represents an L1 handler call in the StarkNet network. """ - from_address: int = field( - metadata=everest_fields.EthAddressIntField.metadata(field_name="from_address") - ) + from_address: int = field(metadata=fields.from_address_field_metadata) to_address: int = field(metadata=fields.contract_address_metadata) entry_point_selector: int = field(metadata=fields.entry_point_selector_metadata) payload: List[int] = field(metadata=fields.payload_metadata) diff --git a/src/starkware/starknet/services/api/gateway/transaction.py b/src/starkware/starknet/services/api/gateway/transaction.py index 651e5345..eba7d3d0 100644 --- a/src/starkware/starknet/services/api/gateway/transaction.py +++ b/src/starkware/starknet/services/api/gateway/transaction.py @@ -173,7 +173,7 @@ class Deploy(Transaction): contract_address_salt: int = field(metadata=fields.contract_address_salt_metadata) contract_definition: DeprecatedCompiledClass - constructor_calldata: List[int] = field(metadata=fields.call_data_metadata) + constructor_calldata: List[int] = field(metadata=fields.calldata_metadata) # Class variables. tx_type: ClassVar[TransactionType] = TransactionType.DEPLOY @@ -213,7 +213,7 @@ class DeployAccount(AccountTransaction): class_hash: int = field(metadata=fields.ClassHashIntField.metadata()) contract_address_salt: int = field(metadata=fields.contract_address_salt_metadata) - constructor_calldata: List[int] = field(metadata=fields.call_data_metadata) + constructor_calldata: List[int] = field(metadata=fields.calldata_metadata) version: int = field(metadata=fields.tx_version_metadata) # Repeat `nonce` to narrow its type to non-optional int. nonce: int = field(metadata=fields.nonce_metadata) @@ -251,7 +251,7 @@ class InvokeFunction(AccountTransaction): """ sender_address: int = field(metadata=fields.contract_address_metadata) - calldata: List[int] = field(metadata=fields.call_data_metadata) + calldata: List[int] = field(metadata=fields.calldata_metadata) # Class variables. tx_type: ClassVar[TransactionType] = TransactionType.INVOKE_FUNCTION diff --git a/src/starkware/starknet/services/utils/CMakeLists.txt b/src/starkware/starknet/services/utils/CMakeLists.txt index 5afc4fe0..9692326f 100644 --- a/src/starkware/starknet/services/utils/CMakeLists.txt +++ b/src/starkware/starknet/services/utils/CMakeLists.txt @@ -8,8 +8,6 @@ python_lib(starknet_sequencer_api_utils_lib everest_transaction_lib starknet_business_logic_fact_state_lib starknet_business_logic_state_lib - starknet_business_logic_utils_lib - starknet_definitions_lib starknet_feeder_gateway_response_objects_lib starknet_general_config_lib starknet_transaction_execution_objects_lib diff --git a/src/starkware/starknet/services/utils/sequencer_api_utils.py b/src/starkware/starknet/services/utils/sequencer_api_utils.py index 210a6d5d..1c1d8201 100644 --- a/src/starkware/starknet/services/utils/sequencer_api_utils.py +++ b/src/starkware/starknet/services/utils/sequencer_api_utils.py @@ -14,12 +14,11 @@ InternalTransaction, ) from starkware.starknet.business_logic.transaction.state_objects import FeeInfo -from starkware.starknet.business_logic.utils import verify_version -from starkware.starknet.definitions import constants from starkware.starknet.definitions.general_config import StarknetGeneralConfig from starkware.starknet.services.api.feeder_gateway.response_objects import FeeEstimationInfo from starkware.starknet.services.api.gateway.transaction import ( AccountTransaction, + Declare, DeployAccount, DeprecatedDeclare, InvokeFunction, @@ -42,6 +41,8 @@ class InternalAccountTransactionForSimulate(InternalAccountTransaction): # Simulation flags; should be replaced with actual values after construction. skip_validate: Optional[bool] = None + # Override InternalAccountTransaction flag; enable query-version transactions to be created and + # executed. only_query: ClassVar[bool] = True @classmethod @@ -70,7 +71,7 @@ def _from_external( internal_cls: Type[InternalAccountTransactionForSimulate] if isinstance(external_tx, InvokeFunction): internal_cls = InternalInvokeFunctionForSimulate - elif isinstance(external_tx, DeprecatedDeclare): + elif isinstance(external_tx, (Declare, DeprecatedDeclare)): internal_cls = InternalDeclareForSimulate elif isinstance(external_tx, DeployAccount): internal_cls = InternalDeployAccountForSimulate @@ -138,16 +139,3 @@ class InternalDeployAccountForSimulate( """ Represents an internal deploy account in the StarkNet network for the simulate transaction API. """ - - def verify_version(self): - expected_transaction_version_constant = 1 - assert constants.TRANSACTION_VERSION == expected_transaction_version_constant, ( - f"Unexpected constant value. Expected {expected_transaction_version_constant}; " - f"got {constants.TRANSACTION_VERSION}." - ) - verify_version( - version=self.version, - expected_version=constants.TRANSACTION_VERSION, - only_query=self.only_query, - old_supported_versions=[], - ) diff --git a/src/starkware/starknet/storage/starknet_storage.py b/src/starkware/starknet/storage/starknet_storage.py index a948a7ec..52772df0 100644 --- a/src/starkware/starknet/storage/starknet_storage.py +++ b/src/starkware/starknet/storage/starknet_storage.py @@ -16,7 +16,7 @@ class StorageLeaf(FeltLeaf): """ - Represents a commitment tree leaf in a StarkNet contract storage. + Represents a commitment tree leaf in a Starknet contract storage. """ @classmethod diff --git a/src/starkware/starknet/wallets/account.py b/src/starkware/starknet/wallets/account.py index 2075844e..edb84504 100644 --- a/src/starkware/starknet/wallets/account.py +++ b/src/starkware/starknet/wallets/account.py @@ -1,8 +1,12 @@ from abc import ABC, abstractmethod from typing import Awaitable, Callable, List, Tuple -from starkware.starknet.services.api.contract_class.contract_class import DeprecatedCompiledClass +from starkware.starknet.services.api.contract_class.contract_class import ( + ContractClass, + DeprecatedCompiledClass, +) from starkware.starknet.services.api.gateway.transaction import ( + Declare, DeployAccount, DeprecatedDeclare, InvokeFunction, @@ -79,6 +83,22 @@ async def deploy_contract( Returns the signed transaction and the deployed contract address. """ + @abstractmethod + async def declare( + self, + contract_class: ContractClass, + compiled_class_hash: int, + chain_id: int, + max_fee: int, + version: int, + nonce_callback: Callable[[int], Awaitable[int]], + dry_run: bool = False, + ) -> Declare: + """ + Prepares the required information for declaring a contract class through the account + contract. + """ + @abstractmethod async def deprecated_declare( self, @@ -90,6 +110,6 @@ async def deprecated_declare( dry_run: bool = False, ) -> DeprecatedDeclare: """ - Prepares the required information for declaring a contract class through the account - contract. + Prepares the required information for declaring a deprecated contract class through the + account contract. """ diff --git a/src/starkware/starknet/wallets/open_zeppelin.py b/src/starkware/starknet/wallets/open_zeppelin.py index 341b4826..5ae2f7ac 100644 --- a/src/starkware/starknet/wallets/open_zeppelin.py +++ b/src/starkware/starknet/wallets/open_zeppelin.py @@ -14,14 +14,19 @@ ) from starkware.starknet.core.os.transaction_hash.transaction_hash import ( TransactionHashPrefix, + calculate_declare_transaction_hash, calculate_deploy_account_transaction_hash, calculate_deprecated_declare_transaction_hash, calculate_transaction_hash_common, ) from starkware.starknet.definitions import fields from starkware.starknet.public.abi import get_selector_from_name -from starkware.starknet.services.api.contract_class.contract_class import DeprecatedCompiledClass +from starkware.starknet.services.api.contract_class.contract_class import ( + ContractClass, + DeprecatedCompiledClass, +) from starkware.starknet.services.api.gateway.transaction import ( + Declare, DeployAccount, DeprecatedDeclare, InvokeFunction, @@ -32,7 +37,6 @@ ACCOUNT_FILE_NAME = "starknet_open_zeppelin_accounts.json" DEPLOY_CONTRACT_SELECTOR = get_selector_from_name("deploy_contract") -GET_NONCE_SELECTOR = get_selector_from_name("get_nonce") class AccountNotFoundException(Exception): @@ -54,6 +58,28 @@ def account_file(self): os.path.expanduser(self.starknet_context.account_dir), ACCOUNT_FILE_NAME ) + async def declare( + self, + contract_class: ContractClass, + compiled_class_hash: int, + chain_id: int, + max_fee: int, + version: int, + nonce_callback: Callable[[int], Awaitable[int]], + dry_run: bool = False, + ) -> Declare: + account_address, private_key = self._get_account_address_and_private_key(dry_run=dry_run) + return sign_declare_tx( + contract_class=contract_class, + compiled_class_hash=compiled_class_hash, + private_key=private_key, + sender_address=account_address, + chain_id=chain_id, + max_fee=max_fee, + version=version, + nonce=await nonce_callback(account_address), + ) + async def deprecated_declare( self, contract_class: DeprecatedCompiledClass, @@ -275,6 +301,39 @@ def _get_account_address_and_private_key(self, dry_run: bool) -> Tuple[int, Opti return account_address, private_key +def sign_declare_tx( + contract_class: ContractClass, + private_key: Optional[int], + sender_address: int, + chain_id: int, + compiled_class_hash: int, + max_fee: int, + version: int, + nonce: int, +) -> Declare: + hash_value = calculate_declare_transaction_hash( + contract_class=contract_class, + compiled_class_hash=compiled_class_hash, + chain_id=chain_id, + sender_address=sender_address, + max_fee=max_fee, + version=version, + nonce=nonce, + ) + + return Declare( + contract_class=contract_class, + compiled_class_hash=compiled_class_hash, + sender_address=sender_address, + max_fee=max_fee, + signature=( + [] if private_key is None else list(sign(msg_hash=hash_value, priv_key=private_key)) + ), + nonce=nonce, + version=version, + ) + + def sign_deprecated_declare_tx( contract_class: DeprecatedCompiledClass, private_key: Optional[int], diff --git a/src/starkware/starknet/wallets/starknet_context.py b/src/starkware/starknet/wallets/starknet_context.py index 912c20d0..110a256f 100644 --- a/src/starkware/starknet/wallets/starknet_context.py +++ b/src/starkware/starknet/wallets/starknet_context.py @@ -3,7 +3,7 @@ @dataclasses.dataclass class StarknetContext: - # A textual identifier used to distinguish between different StarkNet networks. + # A textual identifier used to distinguish between different Starknet networks. network_id: str # The directory which contains the account information files. account_dir: str diff --git a/src/starkware/starkware_utils/error_handling.py b/src/starkware/starkware_utils/error_handling.py index 0ec0dfa6..a4567d37 100644 --- a/src/starkware/starkware_utils/error_handling.py +++ b/src/starkware/starkware_utils/error_handling.py @@ -22,7 +22,7 @@ class ErrorCode(Enum): class StarkErrorCode(ErrorCode): #: Api function temporarily disabled. API_FUNCTION_TEMPORARILY_DISABLED = 0 - #: Bach was aborted. + #: Batch was aborted. BATCH_ABORTED = auto() #: Batch creation failure; batch currently cannot be created. BATCH_CREATION_FAILURE = auto() diff --git a/src/starkware/starkware_utils/marshmallow_dataclass_fields.py b/src/starkware/starkware_utils/marshmallow_dataclass_fields.py index e9a54fa3..0bf74c4a 100644 --- a/src/starkware/starkware_utils/marshmallow_dataclass_fields.py +++ b/src/starkware/starkware_utils/marshmallow_dataclass_fields.py @@ -98,6 +98,9 @@ def __init__(self, **kwargs): super().__init__(**kwargs) def _serialize(self, value, attr, obj, **kwargs): + """ + Used during dump. + """ if value is None: return None assert isinstance(value, int) @@ -105,6 +108,9 @@ def _serialize(self, value, attr, obj, **kwargs): return hex(value) def _deserialize(self, value, attr, data, **kwargs): + """ + Used during load. + """ if re.match("^0x[0-9a-f]+$", value) is not None: return int(value, 16) @@ -121,6 +127,7 @@ def __init__( self, allow_decimal_loading: bool = False, allow_bytes_hex_loading: bool = False, + allow_int_loading: bool = False, **kwargs, ): super().__init__(**kwargs) @@ -130,8 +137,12 @@ def __init__( ) self._allow_decimal_loading = allow_decimal_loading self._allow_bytes_hex_loading = allow_bytes_hex_loading + self._allow_int_loading = allow_int_loading def _deserialize(self, value, attr, data, **kwargs): + if self._allow_int_loading and isinstance(value, int): + return value + if self._allow_decimal_loading and re.match("^[0-9]+$", value) is not None: # Load non-negative int string. return int(value)