From 5b76ca0934c3163667e1a357a5db40ff10e80d50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=96mer=20Faruk=20IRMAK?= Date: Mon, 22 Jan 2024 18:04:54 +0300 Subject: [PATCH] Allow configuring max number of steps to be executed for starknet_call --- cmd/juno/juno.go | 4 ++++ cmd/juno/juno_test.go | 16 ++++++++++++++++ mocks/mock_vm.go | 8 ++++---- node/node.go | 3 ++- node/throttled_vm.go | 5 +++-- rpc/handlers.go | 10 ++++++++-- rpc/handlers_test.go | 37 ++++++++++++++++++++++++++++++++++++- vm/rust/src/lib.rs | 10 +++++++--- vm/vm.go | 7 ++++--- vm/vm_test.go | 8 ++++---- 10 files changed, 88 insertions(+), 20 deletions(-) diff --git a/cmd/juno/juno.go b/cmd/juno/juno.go index 739511ed20..0279eb1e50 100644 --- a/cmd/juno/juno.go +++ b/cmd/juno/juno.go @@ -75,6 +75,7 @@ const ( cnL2ChainIDF = "cn-l2-chain-id" cnCoreContractAddressF = "cn-core-contract-address" cnUnverifiableRangeF = "cn-unverifiable-range" + callMaxStepsF = "rpc-call-max-steps" defaultConfig = "" defaulHost = "localhost" @@ -105,6 +106,7 @@ const ( defaultCNL1ChainID = "" defaultCNL2ChainID = "" defaultCNCoreContractAddressStr = "" + defaultCallMaxSteps = 4_000_000 configFlagUsage = "The yaml configuration file." logLevelFlagUsage = "Options: debug, info, warn, error." @@ -146,6 +148,7 @@ const ( dbCacheSizeUsage = "Determines the amount of memory (in megabytes) allocated for caching data in the database." dbMaxHandlesUsage = "A soft limit on the number of open files that can be used by the DB" gwAPIKeyUsage = "API key for gateway endpoints to avoid throttling" //nolint: gosec + callMaxStepsUsage = "Maximum number of steps to be executed in starknet_call requests" ) var Version string @@ -320,6 +323,7 @@ func NewCmd(config *node.Config, run func(*cobra.Command, []string) error) *cobr junoCmd.Flags().Int(dbMaxHandlesF, defaultMaxHandles, dbMaxHandlesUsage) junoCmd.MarkFlagsRequiredTogether(cnNameF, cnFeederURLF, cnGatewayURLF, cnL1ChainIDF, cnL2ChainIDF, cnCoreContractAddressF, cnUnverifiableRangeF) //nolint:lll junoCmd.MarkFlagsMutuallyExclusive(networkF, cnNameF) + junoCmd.Flags().Uint(callMaxStepsF, defaultCallMaxSteps, callMaxStepsUsage) return junoCmd } diff --git a/cmd/juno/juno_test.go b/cmd/juno/juno_test.go index 19b229a622..582dd49195 100644 --- a/cmd/juno/juno_test.go +++ b/cmd/juno/juno_test.go @@ -62,6 +62,7 @@ func TestConfigPrecedence(t *testing.T) { defaultRPCMaxBlockScan := uint(math.MaxUint) defaultMaxCacheSize := uint(8) defaultMaxHandles := 1024 + defaultCallMaxSteps := uint(4_000_000) tests := map[string]struct { cfgFile bool @@ -106,6 +107,7 @@ func TestConfigPrecedence(t *testing.T) { RPCMaxBlockScan: defaultRPCMaxBlockScan, DBCacheSize: defaultMaxCacheSize, DBMaxHandles: defaultMaxHandles, + RPCCallMaxSteps: defaultCallMaxSteps, }, }, "custom network config file": { @@ -149,6 +151,7 @@ cn-unverifiable-range: [0,10] RPCMaxBlockScan: defaultRPCMaxBlockScan, DBCacheSize: defaultMaxCacheSize, DBMaxHandles: defaultMaxHandles, + RPCCallMaxSteps: defaultCallMaxSteps, }, }, "default config with no flags": { @@ -179,6 +182,7 @@ cn-unverifiable-range: [0,10] RPCMaxBlockScan: defaultRPCMaxBlockScan, DBCacheSize: defaultMaxCacheSize, DBMaxHandles: defaultMaxHandles, + RPCCallMaxSteps: defaultCallMaxSteps, }, }, "config file path is empty string": { @@ -209,6 +213,7 @@ cn-unverifiable-range: [0,10] RPCMaxBlockScan: defaultRPCMaxBlockScan, DBCacheSize: defaultMaxCacheSize, DBMaxHandles: defaultMaxHandles, + RPCCallMaxSteps: defaultCallMaxSteps, }, }, "config file doesn't exist": { @@ -244,6 +249,7 @@ cn-unverifiable-range: [0,10] RPCMaxBlockScan: defaultRPCMaxBlockScan, DBCacheSize: defaultMaxCacheSize, DBMaxHandles: defaultMaxHandles, + RPCCallMaxSteps: defaultCallMaxSteps, }, }, "config file with all settings but without any other flags": { @@ -281,6 +287,7 @@ pprof: true RPCMaxBlockScan: defaultRPCMaxBlockScan, DBCacheSize: defaultMaxCacheSize, DBMaxHandles: defaultMaxHandles, + RPCCallMaxSteps: defaultCallMaxSteps, }, }, "config file with some settings but without any other flags": { @@ -315,6 +322,7 @@ http-port: 4576 RPCMaxBlockScan: defaultRPCMaxBlockScan, DBCacheSize: defaultMaxCacheSize, DBMaxHandles: defaultMaxHandles, + RPCCallMaxSteps: defaultCallMaxSteps, }, }, "all flags without config file": { @@ -347,6 +355,7 @@ http-port: 4576 RPCMaxBlockScan: defaultRPCMaxBlockScan, DBCacheSize: defaultMaxCacheSize, DBMaxHandles: defaultMaxHandles, + RPCCallMaxSteps: defaultCallMaxSteps, }, }, "some flags without config file": { @@ -380,6 +389,7 @@ http-port: 4576 RPCMaxBlockScan: defaultRPCMaxBlockScan, DBCacheSize: defaultMaxCacheSize, DBMaxHandles: defaultMaxHandles, + RPCCallMaxSteps: defaultCallMaxSteps, }, }, "all setting set in both config file and flags": { @@ -437,6 +447,7 @@ db-cache-size: 8 RPCMaxBlockScan: defaultRPCMaxBlockScan, DBCacheSize: 9, DBMaxHandles: defaultMaxHandles, + RPCCallMaxSteps: defaultCallMaxSteps, }, }, "some setting set in both config file and flags": { @@ -473,6 +484,7 @@ network: goerli RPCMaxBlockScan: defaultRPCMaxBlockScan, DBCacheSize: defaultMaxCacheSize, DBMaxHandles: defaultMaxHandles, + RPCCallMaxSteps: defaultCallMaxSteps, }, }, "some setting set in default, config file and flags": { @@ -505,6 +517,7 @@ network: goerli RPCMaxBlockScan: defaultRPCMaxBlockScan, DBCacheSize: defaultMaxCacheSize, DBMaxHandles: defaultMaxHandles, + RPCCallMaxSteps: defaultCallMaxSteps, }, }, "only set env variables": { @@ -535,6 +548,7 @@ network: goerli RPCMaxBlockScan: defaultRPCMaxBlockScan, DBCacheSize: defaultMaxCacheSize, DBMaxHandles: defaultMaxHandles, + RPCCallMaxSteps: defaultCallMaxSteps, }, }, "some setting set in both env variables and flags": { @@ -566,6 +580,7 @@ network: goerli RPCMaxBlockScan: defaultRPCMaxBlockScan, DBCacheSize: defaultMaxCacheSize, DBMaxHandles: defaultMaxHandles, + RPCCallMaxSteps: defaultCallMaxSteps, }, }, "some setting set in both env variables and config file": { @@ -598,6 +613,7 @@ network: goerli DBCacheSize: defaultMaxCacheSize, GatewayAPIKey: "apikey", DBMaxHandles: defaultMaxHandles, + RPCCallMaxSteps: defaultCallMaxSteps, }, }, } diff --git a/mocks/mock_vm.go b/mocks/mock_vm.go index e6a7de806e..fdadb3ab6e 100644 --- a/mocks/mock_vm.go +++ b/mocks/mock_vm.go @@ -42,18 +42,18 @@ func (m *MockVM) EXPECT() *MockVMMockRecorder { } // Call mocks base method. -func (m *MockVM) Call(arg0, arg1, arg2 *felt.Felt, arg3 []felt.Felt, arg4, arg5 uint64, arg6 core.StateReader, arg7 *utils.Network) ([]*felt.Felt, error) { +func (m *MockVM) Call(arg0, arg1, arg2 *felt.Felt, arg3 []felt.Felt, arg4, arg5 uint64, arg6 core.StateReader, arg7 *utils.Network, arg8 uint64) ([]*felt.Felt, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Call", arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7) + ret := m.ctrl.Call(m, "Call", arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8) ret0, _ := ret[0].([]*felt.Felt) ret1, _ := ret[1].(error) return ret0, ret1 } // Call indicates an expected call of Call. -func (mr *MockVMMockRecorder) Call(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7 any) *gomock.Call { +func (mr *MockVMMockRecorder) Call(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Call", reflect.TypeOf((*MockVM)(nil).Call), arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Call", reflect.TypeOf((*MockVM)(nil).Call), arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8) } // Execute mocks base method. diff --git a/node/node.go b/node/node.go index d53924e26c..0f2e217827 100644 --- a/node/node.go +++ b/node/node.go @@ -76,6 +76,7 @@ type Config struct { MaxVMs uint `mapstructure:"max-vms"` MaxVMQueue uint `mapstructure:"max-vm-queue"` RPCMaxBlockScan uint `mapstructure:"rpc-max-block-scan"` + RPCCallMaxSteps uint `mapstructure:"rpc-call-max-steps"` DBCacheSize uint `mapstructure:"db-cache-size"` DBMaxHandles int `mapstructure:"db-max-handles"` @@ -145,7 +146,7 @@ func New(cfg *Config, version string) (*Node, error) { //nolint:gocyclo,funlen throttledVM := NewThrottledVM(vm.New(log), cfg.MaxVMs, int32(cfg.MaxVMQueue)) rpcHandler := rpc.New(chain, synchronizer, throttledVM, version, log).WithGateway(gatewayClient).WithFeeder(client) - rpcHandler = rpcHandler.WithFilterLimit(cfg.RPCMaxBlockScan) + rpcHandler = rpcHandler.WithFilterLimit(cfg.RPCMaxBlockScan).WithCallMaxSteps(uint64(cfg.RPCCallMaxSteps)) services = append(services, rpcHandler) // to improve RPC throughput we double GOMAXPROCS maxGoroutines := 2 * runtime.GOMAXPROCS(0) diff --git a/node/throttled_vm.go b/node/throttled_vm.go index c20b16ca09..aaec8ec72b 100644 --- a/node/throttled_vm.go +++ b/node/throttled_vm.go @@ -16,13 +16,14 @@ func NewThrottledVM(res vm.VM, concurrenyBudget uint, maxQueueLen int32) *Thrott } func (tvm *ThrottledVM) Call(contractAddr, classHash, selector *felt.Felt, calldata []felt.Felt, blockNumber, - blockTimestamp uint64, state core.StateReader, network *utils.Network, + blockTimestamp uint64, state core.StateReader, network *utils.Network, maxSteps uint64, ) ([]*felt.Felt, error) { var ret []*felt.Felt throttler := (*utils.Throttler[vm.VM])(tvm) return ret, throttler.Do(func(vm *vm.VM) error { var err error - ret, err = (*vm).Call(contractAddr, classHash, selector, calldata, blockNumber, blockTimestamp, state, network) + ret, err = (*vm).Call(contractAddr, classHash, selector, calldata, blockNumber, blockTimestamp, + state, network, maxSteps) return err }) } diff --git a/rpc/handlers.go b/rpc/handlers.go index 1ee77a3c02..d5c8710f7f 100644 --- a/rpc/handlers.go +++ b/rpc/handlers.go @@ -96,7 +96,8 @@ type Handler struct { blockTraceCache *lru.Cache[traceCacheKey, []TracedBlockTransaction] - filterLimit uint + filterLimit uint + callMaxSteps uint64 } type subscription struct { @@ -132,6 +133,11 @@ func (h *Handler) WithFilterLimit(limit uint) *Handler { return h } +func (h *Handler) WithCallMaxSteps(maxSteps uint64) *Handler { + h.callMaxSteps = maxSteps + return h +} + func (h *Handler) WithIDGen(idgen func() uint64) *Handler { h.idgen = idgen return h @@ -1253,7 +1259,7 @@ func (h *Handler) Call(call FunctionCall, id BlockID) ([]*felt.Felt, *jsonrpc.Er } res, err := h.vm.Call(&call.ContractAddress, classHash, &call.EntryPointSelector, - call.Calldata, header.Number, header.Timestamp, state, h.bcReader.Network()) + call.Calldata, header.Number, header.Timestamp, state, h.bcReader.Network(), h.callMaxSteps) if err != nil { if errors.Is(err, utils.ErrResourceBusy) { return nil, ErrInternal.CloneWithData(err.Error()) diff --git a/rpc/handlers_test.go b/rpc/handlers_test.go index 07dca158a8..aee1399e44 100644 --- a/rpc/handlers_test.go +++ b/rpc/handlers_test.go @@ -2973,7 +2973,8 @@ func TestCall(t *testing.T) { t.Cleanup(mockCtrl.Finish) mockReader := mocks.NewMockReader(mockCtrl) - handler := rpc.New(mockReader, nil, nil, "", utils.NewNopZapLogger()) + mockVM := mocks.NewMockVM(mockCtrl) + handler := rpc.New(mockReader, nil, mockVM, "", utils.NewNopZapLogger()) t.Run("empty blockchain", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) @@ -3010,6 +3011,40 @@ func TestCall(t *testing.T) { require.Nil(t, res) assert.Equal(t, rpc.ErrContractNotFound, rpcErr) }) + + t.Run("ok", func(t *testing.T) { + handler = handler.WithCallMaxSteps(1337) + + contractAddr := new(felt.Felt).SetUint64(1) + selector := new(felt.Felt).SetUint64(2) + classHash := new(felt.Felt).SetUint64(3) + calldata := []felt.Felt{ + *new(felt.Felt).SetUint64(4), + *new(felt.Felt).SetUint64(5), + } + expectedRes := []*felt.Felt{ + new(felt.Felt).SetUint64(6), + new(felt.Felt).SetUint64(7), + } + + mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) + mockReader.EXPECT().HeadsHeader().Return(&core.Header{ + Number: 100, + Timestamp: 101, + }, nil) + mockState.EXPECT().ContractClassHash(contractAddr).Return(classHash, nil) + mockReader.EXPECT().Network().Return(&utils.Mainnet) + mockVM.EXPECT().Call(contractAddr, classHash, selector, calldata, uint64(100), + uint64(101), gomock.Any(), &utils.Mainnet, uint64(1337)).Return(expectedRes, nil) + + res, rpcErr := handler.Call(rpc.FunctionCall{ + ContractAddress: *contractAddr, + EntryPointSelector: *selector, + Calldata: calldata, + }, rpc.BlockID{Latest: true}) + require.Nil(t, rpcErr) + require.Equal(t, expectedRes, res) + }) } func TestEstimateMessageFee(t *testing.T) { diff --git a/vm/rust/src/lib.rs b/vm/rust/src/lib.rs index ebfb594e35..907aa55869 100644 --- a/vm/rust/src/lib.rs +++ b/vm/rust/src/lib.rs @@ -9,7 +9,7 @@ use std::{ }; use blockifier::{ - abi::constants::{INITIAL_GAS_COST, N_STEPS_RESOURCE}, + abi::constants::{INITIAL_GAS_COST, N_STEPS_RESOURCE, MAX_STEPS_PER_TX, MAX_VALIDATE_STEPS_PER_TX}, block_context::{BlockContext, GasPrices, FeeTokenAddresses}, execution::{ common_hints::ExecutionMode, @@ -68,6 +68,7 @@ pub extern "C" fn cairoVMCall( block_number: c_ulonglong, block_timestamp: c_ulonglong, chain_id: *const c_char, + max_steps: c_ulonglong, ) { let reader = JunoStateReader::new(reader_handle, block_number); let contract_addr_felt = ptr_to_felt(contract_address); @@ -113,6 +114,7 @@ pub extern "C" fn cairoVMCall( block_timestamp, StarkFelt::default(), GAS_PRICES, + Some(max_steps), ), &AccountTransactionContext::Deprecated(DeprecatedAccountTransactionContext::default()), ExecutionMode::Execute, @@ -204,6 +206,7 @@ pub extern "C" fn cairoVMExecute( eth_l1_gas_price: felt_to_u128(gas_price_wei_felt), strk_l1_gas_price: felt_to_u128(gas_price_strk_felt), }, + None ); let mut state = CachedState::new(reader, GlobalContractCache::default()); let charge_fee = skip_charge_fee == 0; @@ -396,6 +399,7 @@ fn build_block_context( block_timestamp: c_ulonglong, sequencer_address: StarkFelt, gas_prices: GasPrices, + max_steps: Option, ) -> BlockContext { BlockContext { chain_id: ChainId(chain_id_str.into()), @@ -432,8 +436,8 @@ fn build_block_context( (KECCAK_BUILTIN_NAME.to_string(), N_STEPS_FEE_WEIGHT * 2048.0), ]) .into(), - invoke_tx_max_n_steps: 3_000_000, - validate_max_n_steps: 3_000_000, + invoke_tx_max_n_steps: max_steps.unwrap_or(MAX_STEPS_PER_TX as u64).try_into().unwrap(), + validate_max_n_steps: MAX_VALIDATE_STEPS_PER_TX as u32, max_recursion_depth: 50, } } diff --git a/vm/vm.go b/vm/vm.go index 93a52e28ad..9b15778fad 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -5,7 +5,7 @@ package vm //#include // extern void cairoVMCall(char* contract_address, char* class_hash, char* entry_point_selector, char** calldata, // size_t len_calldata, uintptr_t readerHandle, unsigned long long block_number, -// unsigned long long block_timestamp, char* chain_id); +// unsigned long long block_timestamp, char* chain_id, unsigned long long max_steps); // // extern void cairoVMExecute(char* txns_json, char* classes_json, uintptr_t readerHandle, unsigned long long block_number, // unsigned long long block_timestamp, char* chain_id, char* sequencer_address, char* paid_fees_on_l1_json, @@ -31,7 +31,7 @@ import ( //go:generate mockgen -destination=../mocks/mock_vm.go -package=mocks github.com/NethermindEth/juno/vm VM type VM interface { Call(contractAddr, classHash, selector *felt.Felt, calldata []felt.Felt, blockNumber, - blockTimestamp uint64, state core.StateReader, network *utils.Network, + blockTimestamp uint64, state core.StateReader, network *utils.Network, maxSteps uint64, ) ([]*felt.Felt, error) Execute(txns []core.Transaction, declaredClasses []core.Class, blockNumber, blockTimestamp uint64, sequencerAddress *felt.Felt, state core.StateReader, network *utils.Network, paidFeesOnL1 []*felt.Felt, @@ -111,7 +111,7 @@ func makePtrFromFelt(val *felt.Felt) unsafe.Pointer { } func (v *vm) Call(contractAddr, classHash, selector *felt.Felt, calldata []felt.Felt, blockNumber, - blockTimestamp uint64, state core.StateReader, network *utils.Network, + blockTimestamp uint64, state core.StateReader, network *utils.Network, maxSteps uint64, ) ([]*felt.Felt, error) { context := &callContext{ state: state, @@ -149,6 +149,7 @@ func (v *vm) Call(contractAddr, classHash, selector *felt.Felt, calldata []felt. C.ulonglong(blockNumber), C.ulonglong(blockTimestamp), chainID, + C.ulonglong(maxSteps), ) for _, ptr := range calldataPtrs { diff --git a/vm/vm_test.go b/vm/vm_test.go index 1b640b8677..5317f08a45 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -48,7 +48,7 @@ func TestV0Call(t *testing.T) { })) entryPoint := utils.HexToFelt(t, "0x39e11d48192e4333233c7eb19d10ad67c362bb28580c604d67884c85da39695") - ret, err := New(nil).Call(contractAddr, classHash, entryPoint, nil, 0, 0, testState, &utils.Mainnet) + ret, err := New(nil).Call(contractAddr, classHash, entryPoint, nil, 0, 0, testState, &utils.Mainnet, 1_000_000) require.NoError(t, err) assert.Equal(t, []*felt.Felt{&felt.Zero}, ret) @@ -64,7 +64,7 @@ func TestV0Call(t *testing.T) { }, }, nil)) - ret, err = New(nil).Call(contractAddr, classHash, entryPoint, nil, 1, 0, testState, &utils.Mainnet) + ret, err = New(nil).Call(contractAddr, classHash, entryPoint, nil, 1, 0, testState, &utils.Mainnet, 1_000_000) require.NoError(t, err) assert.Equal(t, []*felt.Felt{new(felt.Felt).SetUint64(1337)}, ret) } @@ -108,7 +108,7 @@ func TestV1Call(t *testing.T) { storageLocation := utils.HexToFelt(t, "0x44") ret, err := New(log).Call(contractAddr, nil, entryPoint, []felt.Felt{ *storageLocation, - }, 0, 0, testState, &utils.Goerli) + }, 0, 0, testState, &utils.Goerli, 1_000_000) require.NoError(t, err) assert.Equal(t, []*felt.Felt{&felt.Zero}, ret) @@ -126,7 +126,7 @@ func TestV1Call(t *testing.T) { ret, err = New(log).Call(contractAddr, nil, entryPoint, []felt.Felt{ *storageLocation, - }, 1, 0, testState, &utils.Goerli) + }, 1, 0, testState, &utils.Goerli, 1_000_000) require.NoError(t, err) assert.Equal(t, []*felt.Felt{new(felt.Felt).SetUint64(37)}, ret) }