From ebf301a44a4b07b094957766bfbc283094f9d595 Mon Sep 17 00:00:00 2001 From: Osakpolor Obaseki Date: Wed, 31 Jul 2024 07:21:33 +0100 Subject: [PATCH] Update vm execute interface to allow return of execution steps --- node/throttled_vm.go | 7 ++++--- rpc/simulation.go | 2 +- rpc/trace.go | 2 +- vm/vm.go | 16 ++++++++-------- vm/vm_test.go | 4 ++-- 5 files changed, 16 insertions(+), 15 deletions(-) diff --git a/node/throttled_vm.go b/node/throttled_vm.go index 7151867c00..6ae72af728 100644 --- a/node/throttled_vm.go +++ b/node/throttled_vm.go @@ -33,13 +33,14 @@ func (tvm *ThrottledVM) Call(callInfo *vm.CallInfo, blockInfo *vm.BlockInfo, sta func (tvm *ThrottledVM) Execute(txns []core.Transaction, declaredClasses []core.Class, paidFeesOnL1 []*felt.Felt, blockInfo *vm.BlockInfo, state core.StateReader, network *utils.Network, skipChargeFee, skipValidate, errOnRevert, useBlobData bool, -) ([]*felt.Felt, []*felt.Felt, []vm.TransactionTrace, error) { +) ([]*felt.Felt, []*felt.Felt, []vm.TransactionTrace, uint64, error) { var ret []*felt.Felt var traces []vm.TransactionTrace var dataGasConsumed []*felt.Felt - return ret, dataGasConsumed, traces, tvm.Do(func(vm *vm.VM) error { + var numSteps uint64 + return ret, dataGasConsumed, traces, numSteps, tvm.Do(func(vm *vm.VM) error { var err error - ret, dataGasConsumed, traces, err = (*vm).Execute(txns, declaredClasses, paidFeesOnL1, blockInfo, state, network, + ret, dataGasConsumed, traces, numSteps, err = (*vm).Execute(txns, declaredClasses, paidFeesOnL1, blockInfo, state, network, skipChargeFee, skipValidate, errOnRevert, useBlobData) return err }) diff --git a/rpc/simulation.go b/rpc/simulation.go index 4948c062ac..a9628a30af 100644 --- a/rpc/simulation.go +++ b/rpc/simulation.go @@ -106,7 +106,7 @@ func (h *Handler) simulateTransactions(id BlockID, transactions []BroadcastedTra BlockHashToBeRevealed: blockHashToBeRevealed, } useBlobData := !v0_6Response - overallFees, dataGasConsumed, traces, err := h.vm.Execute(txns, classes, paidFeesOnL1, &blockInfo, + overallFees, dataGasConsumed, traces, numSteps, err := h.vm.Execute(txns, classes, paidFeesOnL1, &blockInfo, state, h.bcReader.Network(), skipFeeCharge, skipValidate, errOnRevert, useBlobData) if err != nil { if errors.Is(err, utils.ErrResourceBusy) { diff --git a/rpc/trace.go b/rpc/trace.go index 809a8481de..a4552afbf2 100644 --- a/rpc/trace.go +++ b/rpc/trace.go @@ -262,7 +262,7 @@ func (h *Handler) traceBlockTransactions(ctx context.Context, block *core.Block, } useBlobData := !v0_6Response - overallFees, dataGasConsumed, traces, err := h.vm.Execute(block.Transactions, classes, paidFeesOnL1, &blockInfo, state, network, false, + overallFees, dataGasConsumed, traces, numSteps, err := h.vm.Execute(block.Transactions, classes, paidFeesOnL1, &blockInfo, state, network, false, false, false, useBlobData) if err != nil { if errors.Is(err, utils.ErrResourceBusy) { diff --git a/vm/vm.go b/vm/vm.go index 24818d5d3e..03ab327946 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -65,7 +65,7 @@ type VM interface { maxSteps uint64, useBlobData bool) ([]*felt.Felt, error) Execute(txns []core.Transaction, declaredClasses []core.Class, paidFeesOnL1 []*felt.Felt, blockInfo *BlockInfo, state core.StateReader, network *utils.Network, skipChargeFee, skipValidate, errOnRevert, useBlobData bool, - ) ([]*felt.Felt, []*felt.Felt, []TransactionTrace, error) + ) ([]*felt.Felt, []*felt.Felt, []TransactionTrace, uint64, error) } type vm struct { @@ -267,7 +267,7 @@ func (v *vm) Call(callInfo *CallInfo, blockInfo *BlockInfo, state core.StateRead func (v *vm) Execute(txns []core.Transaction, declaredClasses []core.Class, paidFeesOnL1 []*felt.Felt, blockInfo *BlockInfo, state core.StateReader, network *utils.Network, skipChargeFee, skipValidate, errOnRevert, useBlobData bool, -) ([]*felt.Felt, []*felt.Felt, []TransactionTrace, error) { +) ([]*felt.Felt, []*felt.Felt, []TransactionTrace, uint64, error) { context := &callContext{ state: state, log: v.log, @@ -277,12 +277,12 @@ func (v *vm) Execute(txns []core.Transaction, declaredClasses []core.Class, paid txnsJSON, classesJSON, err := marshalTxnsAndDeclaredClasses(txns, declaredClasses) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, 0, err } paidFeesOnL1Bytes, err := json.Marshal(paidFeesOnL1) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, 0, err } paidFeesOnL1CStr := cstring(paidFeesOnL1Bytes) @@ -331,23 +331,23 @@ func (v *vm) Execute(txns []core.Transaction, declaredClasses []core.Class, paid if context.err != "" { if context.errTxnIndex >= 0 { - return nil, nil, nil, TransactionExecutionError{ + return nil, nil, nil, context.executionSteps, TransactionExecutionError{ Index: uint64(context.errTxnIndex), Cause: errors.New(context.err), } } - return nil, nil, nil, errors.New(context.err) + return nil, nil, nil, context.executionSteps, errors.New(context.err) } traces := make([]TransactionTrace, len(context.traces)) for index, traceJSON := range context.traces { if err := json.Unmarshal(traceJSON, &traces[index]); err != nil { - return nil, nil, nil, fmt.Errorf("unmarshal trace: %v", err) + return nil, nil, nil, context.executionSteps, fmt.Errorf("unmarshal trace: %v", err) } // } - return context.actualFees, context.dataGasConsumed, traces, nil + return context.actualFees, context.dataGasConsumed, traces, context.executionSteps, nil } func marshalTxnsAndDeclaredClasses(txns []core.Transaction, declaredClasses []core.Class) (json.RawMessage, json.RawMessage, error) { //nolint:lll diff --git a/vm/vm_test.go b/vm/vm_test.go index 63c830debb..86e4cbe35f 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -203,7 +203,7 @@ func TestExecute(t *testing.T) { state := core.NewState(txn) t.Run("empty transaction list", func(t *testing.T) { - _, _, _, err := New(false, nil).Execute([]core.Transaction{}, []core.Class{}, []*felt.Felt{}, &BlockInfo{ + _, _, _, _, err := New(false, nil).Execute([]core.Transaction{}, []core.Class{}, []*felt.Felt{}, &BlockInfo{ Header: &core.Header{ Timestamp: 1666877926, SequencerAddress: utils.HexToFelt(t, "0x46a89ae102987331d369645031b49c27738ed096f2789c24449966da4c6de6b"), @@ -215,7 +215,7 @@ func TestExecute(t *testing.T) { require.NoError(t, err) }) t.Run("zero data", func(t *testing.T) { - _, _, _, err := New(false, nil).Execute(nil, nil, []*felt.Felt{}, &BlockInfo{ + _, _, _, _, err := New(false, nil).Execute(nil, nil, []*felt.Felt{}, &BlockInfo{ Header: &core.Header{ SequencerAddress: &felt.Zero, GasPrice: &felt.Zero,