diff --git a/jsonrpc/http.go b/jsonrpc/http.go index 68dccd642f..e9f5a0bb1b 100644 --- a/jsonrpc/http.go +++ b/jsonrpc/http.go @@ -1,6 +1,7 @@ package jsonrpc import ( + "maps" "net/http" "github.com/NethermindEth/juno/utils" @@ -46,8 +47,11 @@ func (h *HTTP) ServeHTTP(writer http.ResponseWriter, req *http.Request) { req.Body = http.MaxBytesReader(writer, req.Body, MaxRequestBodySize) h.listener.OnNewRequest("any") - resp, err := h.rpc.HandleReader(req.Context(), req.Body) + resp, header, err := h.rpc.HandleReader(req.Context(), req.Body) + writer.Header().Set("Content-Type", "application/json") + maps.Copy(writer.Header(), header) // overwrites duplicate headers + if err != nil { h.log.Errorw("Handler failure", "err", err) writer.WriteHeader(http.StatusInternalServerError) diff --git a/jsonrpc/server.go b/jsonrpc/server.go index b75aeae6d0..c63f15c849 100644 --- a/jsonrpc/server.go +++ b/jsonrpc/server.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "io" + "net/http" "reflect" "strings" "sync" @@ -183,11 +184,18 @@ func (s *Server) registerMethod(method Method) error { if numArgs != len(method.Params) { return errors.New("number of non-context function params and param names must match") } - if handlerT.NumOut() != 2 { - return errors.New("handler must return 2 values") + outSize := handlerT.NumOut() + if outSize < 2 || outSize > 3 { + return errors.New("handler must return 2 or 3 values") } - if handlerT.Out(1) != reflect.TypeOf(&Error{}) { - return errors.New("second return value must be a *jsonrpc.Error") + if outSize == 2 && handlerT.Out(1) != reflect.TypeOf(&Error{}) { + return errors.New("second return value must be a *jsonrpc.Error for 2 tuple handler") + } else if outSize == 3 && handlerT.Out(2) != reflect.TypeOf(&Error{}) { + return errors.New("third return value must be a *jsonrpc.Error for 3 tuple handler") + } + + if outSize == 3 && handlerT.Out(1) != reflect.TypeOf(http.Header{}) { + return errors.New("second return value must be a http.Header for 3 tuple handler") } // The method is valid. Mutate the appropriate fields and register on the server. @@ -255,7 +263,8 @@ func (s *Server) HandleReadWriter(ctx context.Context, rw io.ReadWriter) error { activated: activated, } msgCtx := context.WithValue(ctx, ConnKey{}, conn) - resp, err := s.HandleReader(msgCtx, rw) + // header is unnecessary for read-writer(websocket) + resp, _, err := s.HandleReader(msgCtx, rw) if err != nil { conn.initialErr = err return err @@ -272,13 +281,15 @@ func (s *Server) HandleReadWriter(ctx context.Context, rw io.ReadWriter) error { // HandleReader processes a request to the server // It returns the response in a byte array, only returns an // error if it can not create the response byte array -func (s *Server) HandleReader(ctx context.Context, reader io.Reader) ([]byte, error) { +func (s *Server) HandleReader(ctx context.Context, reader io.Reader) ([]byte, http.Header, error) { bufferedReader := bufio.NewReaderSize(reader, bufferSize) requestIsBatch := isBatch(bufferedReader) res := &response{ Version: "2.0", } + header := http.Header{} + dec := json.NewDecoder(bufferedReader) dec.UseNumber() @@ -286,13 +297,15 @@ func (s *Server) HandleReader(ctx context.Context, reader io.Reader) ([]byte, er req := new(Request) if jsonErr := dec.Decode(req); jsonErr != nil { res.Error = Err(InvalidJSON, jsonErr.Error()) - } else if resObject, handleErr := s.handleRequest(ctx, req); handleErr != nil { + } else if resObject, httpHeader, handleErr := s.handleRequest(ctx, req); handleErr != nil { if !errors.Is(handleErr, ErrInvalidID) { res.ID = req.ID } res.Error = Err(InvalidRequest, handleErr.Error()) + header = httpHeader } else { res = resObject + header = httpHeader } } else { var batchReq []json.RawMessage @@ -307,23 +320,27 @@ func (s *Server) HandleReader(ctx context.Context, reader io.Reader) ([]byte, er } if res == nil { - return nil, nil + return nil, header, nil } - return json.Marshal(res) + + result, err := json.Marshal(res) + return result, header, err } -func (s *Server) handleBatchRequest(ctx context.Context, batchReq []json.RawMessage) ([]byte, error) { +func (s *Server) handleBatchRequest(ctx context.Context, batchReq []json.RawMessage) ([]byte, http.Header, error) { var ( - responses []json.RawMessage mutex sync.Mutex + responses []json.RawMessage + headers []http.Header ) - addResponse := func(response any) { + addResponse := func(response any, header http.Header) { if responseJSON, err := json.Marshal(response); err != nil { s.log.Errorw("failed to marshal response", "err", err) } else { mutex.Lock() responses = append(responses, responseJSON) + headers = append(headers, header) mutex.Unlock() } } @@ -341,7 +358,7 @@ func (s *Server) handleBatchRequest(ctx context.Context, batchReq []json.RawMess addResponse(&response{ Version: "2.0", Error: Err(InvalidRequest, err.Error()), - }) + }, http.Header{}) continue } @@ -349,7 +366,7 @@ func (s *Server) handleBatchRequest(ctx context.Context, batchReq []json.RawMess s.pool.Go(func() { defer wg.Done() - resp, err := s.handleRequest(ctx, req) + resp, header, err := s.handleRequest(ctx, req) if err != nil { resp = &response{ Version: "2.0", @@ -359,20 +376,33 @@ func (s *Server) handleBatchRequest(ctx context.Context, batchReq []json.RawMess resp.ID = req.ID } } - // for notification request response is nil + // for notification request response is nil and header is irrelevant for now if resp != nil { - addResponse(resp) + addResponse(resp, header) } }) } wg.Wait() + + // merge headers + finalHeaders := http.Header{} + for _, header := range headers { + for k, v := range header { + for _, e := range v { + finalHeaders.Add(k, e) + } + } + } + // according to the spec if there are no response objects server must not return empty array if len(responses) == 0 { - return nil, nil + return nil, finalHeaders, nil } - return json.Marshal(responses) + result, err := json.Marshal(responses) + + return result, finalHeaders, err // todo: fix batch request aggregate header } func isBatch(reader *bufio.Reader) bool { @@ -396,11 +426,13 @@ func isNil(i any) bool { return i == nil || reflect.ValueOf(i).IsNil() } -func (s *Server) handleRequest(ctx context.Context, req *Request) (*response, error) { +func (s *Server) handleRequest(ctx context.Context, req *Request) (*response, http.Header, error) { s.log.Tracew("Received request", "req", req) + + header := http.Header{} if err := req.isSane(); err != nil { s.log.Tracew("Request sanity check failed", "err", err) - return nil, err + return nil, header, err } res := &response{ @@ -412,7 +444,7 @@ func (s *Server) handleRequest(ctx context.Context, req *Request) (*response, er if !found { res.Error = Err(MethodNotFound, nil) s.log.Tracew("Method not found in request", "method", req.Method) - return res, nil + return res, header, nil } handlerTimer := time.Now() @@ -421,7 +453,7 @@ func (s *Server) handleRequest(ctx context.Context, req *Request) (*response, er if err != nil { res.Error = Err(InvalidParams, err.Error()) s.log.Tracew("Error building arguments for RPC call", "err", err) - return res, nil + return res, header, nil } defer func() { s.listener.OnRequestHandled(req.Method, time.Since(handlerTimer)) @@ -430,10 +462,16 @@ func (s *Server) handleRequest(ctx context.Context, req *Request) (*response, er tuple := reflect.ValueOf(calledMethod.Handler).Call(args) if res.ID == nil { // notification s.log.Tracew("Notification received, no response expected") - return nil, nil + return nil, header, nil + } + + errorIndex := 1 + if len(tuple) == 3 { + errorIndex = 2 + header = (tuple[1].Interface()).(http.Header) } - if errAny := tuple[1].Interface(); !isNil(errAny) { + if errAny := tuple[errorIndex].Interface(); !isNil(errAny) { res.Error = errAny.(*Error) if res.Error.Code == InternalError { s.listener.OnRequestFailed(req.Method, res.Error) @@ -441,11 +479,11 @@ func (s *Server) handleRequest(ctx context.Context, req *Request) (*response, er errJSON, _ := json.Marshal(res.Error) s.log.Debugw("Failed handing RPC request", "req", string(reqJSON), "res", string(errJSON)) } - return res, nil + return res, header, nil } res.Result = tuple[0].Interface() - return res, nil + return res, header, nil } func (s *Server) buildArguments(ctx context.Context, params any, method Method) ([]reflect.Value, error) { diff --git a/jsonrpc/server_test.go b/jsonrpc/server_test.go index 36794456ed..196aec253e 100644 --- a/jsonrpc/server_test.go +++ b/jsonrpc/server_test.go @@ -40,17 +40,27 @@ func TestServer_RegisterMethod(t *testing.T) { "no return": { handler: func(param1, param2 int) {}, paramNames: []jsonrpc.Parameter{{Name: "param1"}, {Name: "param2"}}, - want: "handler must return 2 values", + want: "handler must return 2 or 3 values", }, "int return": { handler: func(param1, param2 int) (int, int) { return 0, 0 }, paramNames: []jsonrpc.Parameter{{Name: "param1"}, {Name: "param2"}}, - want: "second return value must be a *jsonrpc.Error", + want: "second return value must be a *jsonrpc.Error for 2 tuple handler", + }, + "no error return 3": { + handler: func(param1, param2 int) (int, int, int) { return 0, 0, 0 }, + paramNames: []jsonrpc.Parameter{{Name: "param1"}, {Name: "param2"}}, + want: "third return value must be a *jsonrpc.Error for 3 tuple handler", + }, + "no header return 3": { + handler: func(param1, param2 int) (int, int, *jsonrpc.Error) { return 0, 0, &jsonrpc.Error{} }, + paramNames: []jsonrpc.Parameter{{Name: "param1"}, {Name: "param2"}}, + want: "second return value must be a http.Header for 3 tuple handler", }, "no error return": { handler: func(param1, param2 int) (any, int) { return 0, 0 }, paramNames: []jsonrpc.Parameter{{Name: "param1"}, {Name: "param2"}}, - want: "second return value must be a *jsonrpc.Error", + want: "second return value must be a *jsonrpc.Error for 2 tuple handler", }, } @@ -472,8 +482,9 @@ func TestHandle(t *testing.T) { oldRequestFailedEventCount := len(listener.OnRequestFailedCalls) oldRequestHandledCalls := len(listener.OnRequestHandledCalls) - res, err := server.HandleReader(context.Background(), strings.NewReader(test.req)) + res, httpHeader, err := server.HandleReader(context.Background(), strings.NewReader(test.req)) require.NoError(t, err) + assert.NotNil(t, httpHeader) if test.isBatch { assertBatchResponse(t, test.res, string(res)) @@ -515,8 +526,9 @@ func BenchmarkHandle(b *testing.B) { const request = `{"jsonrpc":"2.0","id":1,"method":"test"}` for i := 0; i < b.N; i++ { - _, err := server.HandleReader(context.Background(), strings.NewReader(request)) + _, header, err := server.HandleReader(context.Background(), strings.NewReader(request)) require.NoError(b, err) + require.NotNil(b, header) } } @@ -531,9 +543,10 @@ func TestCannotWriteToConnInHandler(t *testing.T) { return 0, nil }, })) - res, err := server.HandleReader(context.Background(), strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"test"}`)) + res, header, err := server.HandleReader(context.Background(), strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"test"}`)) require.NoError(t, err) require.Equal(t, `{"jsonrpc":"2.0","result":0,"id":1}`, string(res)) + require.NotNil(t, header) } type fakeConn struct{} diff --git a/mocks/mock_vm.go b/mocks/mock_vm.go index 134c8bb952..733ef7c20c 100644 --- a/mocks/mock_vm.go +++ b/mocks/mock_vm.go @@ -58,14 +58,15 @@ func (mr *MockVMMockRecorder) Call(arg0, arg1, arg2, arg3, arg4, arg5 any) *gomo } // Execute mocks base method. -func (m *MockVM) Execute(arg0 []core.Transaction, arg1 []core.Class, arg2 []*felt.Felt, arg3 *vm.BlockInfo, arg4 core.StateReader, arg5 *utils.Network, arg6, arg7, arg8, arg9 bool) ([]*felt.Felt, []*felt.Felt, []vm.TransactionTrace, error) { +func (m *MockVM) Execute(arg0 []core.Transaction, arg1 []core.Class, arg2 []*felt.Felt, arg3 *vm.BlockInfo, arg4 core.StateReader, arg5 *utils.Network, arg6, arg7, arg8, arg9 bool) ([]*felt.Felt, []*felt.Felt, []vm.TransactionTrace, uint64, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Execute", arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9) ret0, _ := ret[0].([]*felt.Felt) ret1, _ := ret[1].([]*felt.Felt) ret2, _ := ret[2].([]vm.TransactionTrace) - ret3, _ := ret[3].(error) - return ret0, ret1, ret2, ret3 + ret3, _ := ret[3].(uint64) + ret4, _ := ret[4].(error) + return ret0, ret1, ret2, ret3, ret4 } // Execute indicates an expected call of Execute. diff --git a/node/throttled_vm.go b/node/throttled_vm.go index 7151867c00..6ae72af728 100644 --- a/node/throttled_vm.go +++ b/node/throttled_vm.go @@ -33,13 +33,14 @@ func (tvm *ThrottledVM) Call(callInfo *vm.CallInfo, blockInfo *vm.BlockInfo, sta func (tvm *ThrottledVM) Execute(txns []core.Transaction, declaredClasses []core.Class, paidFeesOnL1 []*felt.Felt, blockInfo *vm.BlockInfo, state core.StateReader, network *utils.Network, skipChargeFee, skipValidate, errOnRevert, useBlobData bool, -) ([]*felt.Felt, []*felt.Felt, []vm.TransactionTrace, error) { +) ([]*felt.Felt, []*felt.Felt, []vm.TransactionTrace, uint64, error) { var ret []*felt.Felt var traces []vm.TransactionTrace var dataGasConsumed []*felt.Felt - return ret, dataGasConsumed, traces, tvm.Do(func(vm *vm.VM) error { + var numSteps uint64 + return ret, dataGasConsumed, traces, numSteps, tvm.Do(func(vm *vm.VM) error { var err error - ret, dataGasConsumed, traces, err = (*vm).Execute(txns, declaredClasses, paidFeesOnL1, blockInfo, state, network, + ret, dataGasConsumed, traces, numSteps, err = (*vm).Execute(txns, declaredClasses, paidFeesOnL1, blockInfo, state, network, skipChargeFee, skipValidate, errOnRevert, useBlobData) return err }) diff --git a/rpc/estimate_fee.go b/rpc/estimate_fee.go index c89599abac..55aebb5a4f 100644 --- a/rpc/estimate_fee.go +++ b/rpc/estimate_fee.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/jsonrpc" @@ -64,52 +65,55 @@ func (f FeeEstimate) MarshalJSON() ([]byte, error) { func (h *Handler) EstimateFee(broadcastedTxns []BroadcastedTransaction, simulationFlags []SimulationFlag, id BlockID, -) ([]FeeEstimate, *jsonrpc.Error) { - result, err := h.simulateTransactions(id, broadcastedTxns, append(simulationFlags, SkipFeeChargeFlag), false, true) +) ([]FeeEstimate, http.Header, *jsonrpc.Error) { + result, httpHeader, err := h.simulateTransactions(id, broadcastedTxns, append(simulationFlags, SkipFeeChargeFlag), false, true) if err != nil { - return nil, err + return nil, httpHeader, err } return utils.Map(result, func(tx SimulatedTransaction) FeeEstimate { return tx.FeeEstimation - }), nil + }), httpHeader, nil } func (h *Handler) EstimateFeeV0_6(broadcastedTxns []BroadcastedTransaction, simulationFlags []SimulationFlag, id BlockID, -) ([]FeeEstimate, *jsonrpc.Error) { - result, err := h.simulateTransactions(id, broadcastedTxns, append(simulationFlags, SkipFeeChargeFlag), true, true) +) ([]FeeEstimate, http.Header, *jsonrpc.Error) { + result, httpHeader, err := h.simulateTransactions(id, broadcastedTxns, append(simulationFlags, SkipFeeChargeFlag), true, true) if err != nil { - return nil, err + return nil, httpHeader, err } return utils.Map(result, func(tx SimulatedTransaction) FeeEstimate { return tx.FeeEstimation - }), nil + }), httpHeader, nil } -func (h *Handler) EstimateMessageFee(msg MsgFromL1, id BlockID) (*FeeEstimate, *jsonrpc.Error) { //nolint:gocritic +func (h *Handler) EstimateMessageFee(msg MsgFromL1, id BlockID) (*FeeEstimate, http.Header, *jsonrpc.Error) { //nolint:gocritic return h.estimateMessageFee(msg, id, h.EstimateFee) } -func (h *Handler) EstimateMessageFeeV0_6(msg MsgFromL1, id BlockID) (*FeeEstimate, *jsonrpc.Error) { //nolint:gocritic - feeEstimate, rpcErr := h.estimateMessageFee(msg, id, h.EstimateFeeV0_6) +func (h *Handler) EstimateMessageFeeV0_6(msg MsgFromL1, id BlockID) (*FeeEstimate, http.Header, *jsonrpc.Error) { //nolint:gocritic + feeEstimate, httpHeader, rpcErr := h.estimateMessageFee(msg, id, h.EstimateFeeV0_6) if rpcErr != nil { - return nil, rpcErr + return nil, httpHeader, rpcErr } feeEstimate.v0_6Response = true feeEstimate.DataGasPrice = nil feeEstimate.DataGasConsumed = nil - return feeEstimate, nil + return feeEstimate, httpHeader, nil } type estimateFeeHandler func(broadcastedTxns []BroadcastedTransaction, simulationFlags []SimulationFlag, id BlockID, -) ([]FeeEstimate, *jsonrpc.Error) +) ([]FeeEstimate, http.Header, *jsonrpc.Error) -func (h *Handler) estimateMessageFee(msg MsgFromL1, id BlockID, f estimateFeeHandler) (*FeeEstimate, *jsonrpc.Error) { //nolint:gocritic +//nolint:gocritic +func (h *Handler) estimateMessageFee(msg MsgFromL1, id BlockID, f estimateFeeHandler) (*FeeEstimate, + http.Header, *jsonrpc.Error, +) { calldata := make([]*felt.Felt, 0, len(msg.Payload)+1) // The order of the calldata parameters matters. msg.From must be prepended. calldata = append(calldata, new(felt.Felt).SetBytes(msg.From.Bytes())) @@ -129,15 +133,15 @@ func (h *Handler) estimateMessageFee(msg MsgFromL1, id BlockID, f estimateFeeHan // Must be greater than zero to successfully execute transaction. PaidFeeOnL1: new(felt.Felt).SetUint64(1), } - estimates, rpcErr := f([]BroadcastedTransaction{tx}, nil, id) + estimates, httpHeader, rpcErr := f([]BroadcastedTransaction{tx}, nil, id) if rpcErr != nil { if rpcErr.Code == ErrTransactionExecutionError.Code { data := rpcErr.Data.(TransactionExecutionErrorData) - return nil, makeContractError(errors.New(data.ExecutionError)) + return nil, httpHeader, makeContractError(errors.New(data.ExecutionError)) } - return nil, rpcErr + return nil, httpHeader, rpcErr } - return &estimates[0], nil + return &estimates[0], httpHeader, nil } type ContractErrorData struct { diff --git a/rpc/estimate_fee_test.go b/rpc/estimate_fee_test.go index 0520a806f2..660d3874d3 100644 --- a/rpc/estimate_fee_test.go +++ b/rpc/estimate_fee_test.go @@ -38,8 +38,9 @@ func TestEstimateMessageFeeV0_6(t *testing.T) { t.Run("block not found", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) - _, err := handler.EstimateMessageFeeV0_6(msg, rpc.BlockID{Latest: true}) + _, httpHeader, err := handler.EstimateMessageFeeV0_6(msg, rpc.BlockID{Latest: true}) require.Equal(t, rpc.ErrBlockNotFound, err) + require.NotEmpty(t, httpHeader.Get(rpc.ExecutionStepsHeader)) }) latestHeader := &core.Header{ @@ -58,7 +59,7 @@ func TestEstimateMessageFeeV0_6(t *testing.T) { }, gomock.Any(), &utils.Mainnet, gomock.Any(), false, true, false).DoAndReturn( func(txns []core.Transaction, declaredClasses []core.Class, paidFeesOnL1 []*felt.Felt, blockInfo *vm.BlockInfo, state core.StateReader, network *utils.Network, skipChargeFee, skipValidate, errOnRevert, useBlobData bool, - ) ([]*felt.Felt, []*felt.Felt, []vm.TransactionTrace, error) { + ) ([]*felt.Felt, []*felt.Felt, []vm.TransactionTrace, uint64, error) { require.Len(t, txns, 1) assert.NotNil(t, txns[0].(*core.L1HandlerTransaction)) @@ -75,11 +76,11 @@ func TestEstimateMessageFeeV0_6(t *testing.T) { DeclaredClasses: []vm.DeclaredClass{}, ReplacedClasses: []vm.ReplacedClass{}, }, - }}, nil + }}, uint64(123), nil }, ) - estimateFee, err := handler.EstimateMessageFeeV0_6(msg, rpc.BlockID{Latest: true}) + estimateFee, httpHeader, err := handler.EstimateMessageFeeV0_6(msg, rpc.BlockID{Latest: true}) require.Nil(t, err) expectedJSON := fmt.Sprintf( `{"gas_consumed":%q,"gas_price":%q,"overall_fee":%q,"unit":"WEI"}`, @@ -93,6 +94,7 @@ func TestEstimateMessageFeeV0_6(t *testing.T) { estimateFeeJSON, jsonErr := json.Marshal(estimateFee) require.NoError(t, jsonErr) require.Equal(t, expectedJSON, string(estimateFeeJSON)) + require.NotEmpty(t, httpHeader.Get(rpc.ExecutionStepsHeader)) } func TestEstimateFee(t *testing.T) { @@ -114,32 +116,35 @@ func TestEstimateFee(t *testing.T) { blockInfo := vm.BlockInfo{Header: &core.Header{}} t.Run("ok with zero values", func(t *testing.T) { mockVM.EXPECT().Execute([]core.Transaction{}, nil, []*felt.Felt{}, &blockInfo, mockState, n, true, false, true, true). - Return([]*felt.Felt{}, []*felt.Felt{}, []vm.TransactionTrace{}, nil) + Return([]*felt.Felt{}, []*felt.Felt{}, []vm.TransactionTrace{}, uint64(123), nil) - _, err := handler.EstimateFee([]rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{}, rpc.BlockID{Latest: true}) + _, httpHeader, err := handler.EstimateFee([]rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{}, rpc.BlockID{Latest: true}) require.Nil(t, err) + assert.Equal(t, httpHeader.Get(rpc.ExecutionStepsHeader), "123") }) t.Run("ok with zero values, skip validate", func(t *testing.T) { mockVM.EXPECT().Execute([]core.Transaction{}, nil, []*felt.Felt{}, &blockInfo, mockState, n, true, true, true, true). - Return([]*felt.Felt{}, []*felt.Felt{}, []vm.TransactionTrace{}, nil) + Return([]*felt.Felt{}, []*felt.Felt{}, []vm.TransactionTrace{}, uint64(123), nil) - _, err := handler.EstimateFee([]rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipValidateFlag}, rpc.BlockID{Latest: true}) + _, httpHeader, err := handler.EstimateFee([]rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipValidateFlag}, rpc.BlockID{Latest: true}) require.Nil(t, err) + assert.Equal(t, httpHeader.Get(rpc.ExecutionStepsHeader), "123") }) t.Run("transaction execution error", func(t *testing.T) { mockVM.EXPECT().Execute([]core.Transaction{}, nil, []*felt.Felt{}, &blockInfo, mockState, n, true, true, true, true). - Return(nil, nil, nil, vm.TransactionExecutionError{ + Return(nil, nil, nil, uint64(0), vm.TransactionExecutionError{ Index: 44, Cause: errors.New("oops"), }) - _, err := handler.EstimateFee([]rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipValidateFlag}, rpc.BlockID{Latest: true}) + _, httpHeader, err := handler.EstimateFee([]rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipValidateFlag}, rpc.BlockID{Latest: true}) require.Equal(t, rpc.ErrTransactionExecutionError.CloneWithData(rpc.TransactionExecutionErrorData{ TransactionIndex: 44, ExecutionError: "oops", }), err) + require.Equal(t, httpHeader.Get(rpc.ExecutionStepsHeader), "0") }) } diff --git a/rpc/handlers_test.go b/rpc/handlers_test.go index 48374d20d8..c34b7727dd 100644 --- a/rpc/handlers_test.go +++ b/rpc/handlers_test.go @@ -60,8 +60,9 @@ func TestThrottledVMError(t *testing.T) { t.Run("simulate", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) mockReader.EXPECT().HeadsHeader().Return(&core.Header{}, nil) - _, rpcErr := handler.SimulateTransactions(rpc.BlockID{Latest: true}, []rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipFeeChargeFlag}) + _, httpHeader, rpcErr := handler.SimulateTransactions(rpc.BlockID{Latest: true}, []rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipFeeChargeFlag}) assert.Equal(t, throttledErr, rpcErr.Data) + assert.NotEmpty(t, httpHeader.Get(rpc.ExecutionStepsHeader)) }) t.Run("trace", func(t *testing.T) { @@ -95,7 +96,8 @@ func TestThrottledVMError(t *testing.T) { headState := mocks.NewMockStateHistoryReader(mockCtrl) headState.EXPECT().Class(declareTx.ClassHash).Return(declaredClass, nil) mockReader.EXPECT().PendingState().Return(headState, nopCloser, nil) - _, rpcErr := handler.TraceBlockTransactions(context.Background(), rpc.BlockID{Hash: blockHash}) + _, httpHeader, rpcErr := handler.TraceBlockTransactions(context.Background(), rpc.BlockID{Hash: blockHash}) assert.Equal(t, throttledErr, rpcErr.Data) + assert.NotEmpty(t, httpHeader.Get(rpc.ExecutionStepsHeader)) }) } diff --git a/rpc/simulation.go b/rpc/simulation.go index fb58b230b4..63aee580ff 100644 --- a/rpc/simulation.go +++ b/rpc/simulation.go @@ -3,7 +3,9 @@ package rpc import ( "errors" "fmt" + "net/http" "slices" + "strconv" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" @@ -19,6 +21,8 @@ const ( SkipFeeChargeFlag ) +const ExecutionStepsHeader string = "X-Cairo-Steps" + func (s *SimulationFlag) UnmarshalJSON(bytes []byte) (err error) { switch flag := string(bytes); flag { case `"SKIP_VALIDATE"`: @@ -48,14 +52,14 @@ type TracedBlockTransaction struct { func (h *Handler) SimulateTransactions(id BlockID, transactions []BroadcastedTransaction, simulationFlags []SimulationFlag, -) ([]SimulatedTransaction, *jsonrpc.Error) { +) ([]SimulatedTransaction, http.Header, *jsonrpc.Error) { return h.simulateTransactions(id, transactions, simulationFlags, false, false) } // pre 13.1 func (h *Handler) SimulateTransactionsV0_6(id BlockID, transactions []BroadcastedTransaction, simulationFlags []SimulationFlag, -) ([]SimulatedTransaction, *jsonrpc.Error) { +) ([]SimulatedTransaction, http.Header, *jsonrpc.Error) { // todo double check errOnRevert = false return h.simulateTransactions(id, transactions, simulationFlags, true, false) } @@ -63,19 +67,22 @@ func (h *Handler) SimulateTransactionsV0_6(id BlockID, transactions []Broadcaste //nolint:funlen,gocyclo func (h *Handler) simulateTransactions(id BlockID, transactions []BroadcastedTransaction, simulationFlags []SimulationFlag, v0_6Response, errOnRevert bool, -) ([]SimulatedTransaction, *jsonrpc.Error) { +) ([]SimulatedTransaction, http.Header, *jsonrpc.Error) { skipFeeCharge := slices.Contains(simulationFlags, SkipFeeChargeFlag) skipValidate := slices.Contains(simulationFlags, SkipValidateFlag) + httpHeader := http.Header{} + httpHeader.Set(ExecutionStepsHeader, "0") + state, closer, rpcErr := h.stateByBlockID(&id) if rpcErr != nil { - return nil, rpcErr + return nil, httpHeader, rpcErr } defer h.callAndLogErr(closer, "Failed to close state in starknet_estimateFee") header, rpcErr := h.blockHeaderByID(&id) if rpcErr != nil { - return nil, rpcErr + return nil, httpHeader, rpcErr } txns := make([]core.Transaction, 0, len(transactions)) @@ -85,7 +92,7 @@ func (h *Handler) simulateTransactions(id BlockID, transactions []BroadcastedTra for idx := range transactions { txn, declaredClass, paidFeeOnL1, aErr := adaptBroadcastedTransaction(&transactions[idx], h.bcReader.Network()) if aErr != nil { - return nil, jsonrpc.Err(jsonrpc.InvalidParams, aErr.Error()) + return nil, httpHeader, jsonrpc.Err(jsonrpc.InvalidParams, aErr.Error()) } if paidFeeOnL1 != nil { @@ -100,24 +107,27 @@ func (h *Handler) simulateTransactions(id BlockID, transactions []BroadcastedTra blockHashToBeRevealed, err := h.getRevealedBlockHash(header.Number) if err != nil { - return nil, ErrInternal.CloneWithData(err) + return nil, httpHeader, ErrInternal.CloneWithData(err) } blockInfo := vm.BlockInfo{ Header: header, BlockHashToBeRevealed: blockHashToBeRevealed, } useBlobData := !v0_6Response - overallFees, dataGasConsumed, traces, err := h.vm.Execute(txns, classes, paidFeesOnL1, &blockInfo, + overallFees, dataGasConsumed, traces, numSteps, err := h.vm.Execute(txns, classes, paidFeesOnL1, &blockInfo, state, h.bcReader.Network(), skipFeeCharge, skipValidate, errOnRevert, useBlobData) + + httpHeader.Set(ExecutionStepsHeader, strconv.FormatUint(numSteps, 10)) + if err != nil { if errors.Is(err, utils.ErrResourceBusy) { - return nil, ErrInternal.CloneWithData(throttledVMErr) + return nil, httpHeader, ErrInternal.CloneWithData(throttledVMErr) } var txnExecutionError vm.TransactionExecutionError if errors.As(err, &txnExecutionError) { - return nil, makeTransactionExecutionError(&txnExecutionError) + return nil, httpHeader, makeTransactionExecutionError(&txnExecutionError) } - return nil, ErrUnexpectedError.CloneWithData(err.Error()) + return nil, httpHeader, ErrUnexpectedError.CloneWithData(err.Error()) } result := make([]SimulatedTransaction, 0, len(overallFees)) @@ -173,7 +183,7 @@ func (h *Handler) simulateTransactions(id BlockID, transactions []BroadcastedTra }) } - return result, nil + return result, httpHeader, nil } type TransactionExecutionErrorData struct { diff --git a/rpc/simulation_test.go b/rpc/simulation_test.go index 8a35aee4e0..c17929cfa7 100644 --- a/rpc/simulation_test.go +++ b/rpc/simulation_test.go @@ -10,6 +10,7 @@ import ( "github.com/NethermindEth/juno/rpc" "github.com/NethermindEth/juno/utils" "github.com/NethermindEth/juno/vm" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) @@ -34,23 +35,27 @@ func TestSimulateTransactionsV0_6(t *testing.T) { mockReader.EXPECT().HeadsHeader().Return(headsHeader, nil).AnyTimes() t.Run("ok with zero values, skip fee", func(t *testing.T) { + stepsUsed := uint64(123) mockVM.EXPECT().Execute([]core.Transaction{}, nil, []*felt.Felt{}, &vm.BlockInfo{ Header: headsHeader, }, mockState, n, true, false, false, false). - Return([]*felt.Felt{}, []*felt.Felt{}, []vm.TransactionTrace{}, nil) + Return([]*felt.Felt{}, []*felt.Felt{}, []vm.TransactionTrace{}, stepsUsed, nil) - _, err := handler.SimulateTransactionsV0_6(rpc.BlockID{Latest: true}, []rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipFeeChargeFlag}) + _, httpHeader, err := handler.SimulateTransactionsV0_6(rpc.BlockID{Latest: true}, []rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipFeeChargeFlag}) require.Nil(t, err) + assert.Equal(t, httpHeader.Get(rpc.ExecutionStepsHeader), "123") }) t.Run("ok with zero values, skip validate", func(t *testing.T) { + stepsUsed := uint64(123) mockVM.EXPECT().Execute([]core.Transaction{}, nil, []*felt.Felt{}, &vm.BlockInfo{ Header: headsHeader, }, mockState, n, false, true, false, false). - Return([]*felt.Felt{}, []*felt.Felt{}, []vm.TransactionTrace{}, nil) + Return([]*felt.Felt{}, []*felt.Felt{}, []vm.TransactionTrace{}, stepsUsed, nil) - _, err := handler.SimulateTransactionsV0_6(rpc.BlockID{Latest: true}, []rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipValidateFlag}) + _, httpHeader, err := handler.SimulateTransactionsV0_6(rpc.BlockID{Latest: true}, []rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipValidateFlag}) require.Nil(t, err) + assert.Equal(t, httpHeader.Get(rpc.ExecutionStepsHeader), "123") }) t.Run("transaction execution error", func(t *testing.T) { @@ -58,31 +63,33 @@ func TestSimulateTransactionsV0_6(t *testing.T) { mockVM.EXPECT().Execute([]core.Transaction{}, nil, []*felt.Felt{}, &vm.BlockInfo{ Header: headsHeader, }, mockState, n, false, true, false, false). - Return(nil, nil, nil, vm.TransactionExecutionError{ + Return(nil, nil, nil, uint64(0), vm.TransactionExecutionError{ Index: 44, Cause: errors.New("oops"), }) - _, err := handler.SimulateTransactionsV0_6(rpc.BlockID{Latest: true}, []rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipValidateFlag}) + _, httpHeader, err := handler.SimulateTransactionsV0_6(rpc.BlockID{Latest: true}, []rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipValidateFlag}) require.Equal(t, rpc.ErrTransactionExecutionError.CloneWithData(rpc.TransactionExecutionErrorData{ TransactionIndex: 44, ExecutionError: "oops", }), err) + require.Equal(t, httpHeader.Get(rpc.ExecutionStepsHeader), "0") }) t.Run("v0_7", func(t *testing.T) { //nolint:dupl mockVM.EXPECT().Execute([]core.Transaction{}, nil, []*felt.Felt{}, &vm.BlockInfo{ Header: headsHeader, }, mockState, n, false, true, false, true). - Return(nil, nil, nil, vm.TransactionExecutionError{ + Return(nil, nil, nil, uint64(0), vm.TransactionExecutionError{ Index: 44, Cause: errors.New("oops"), }) - _, err := handler.SimulateTransactions(rpc.BlockID{Latest: true}, []rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipValidateFlag}) + _, httpHeader, err := handler.SimulateTransactions(rpc.BlockID{Latest: true}, []rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipValidateFlag}) require.Equal(t, rpc.ErrTransactionExecutionError.CloneWithData(rpc.TransactionExecutionErrorData{ TransactionIndex: 44, ExecutionError: "oops", }), err) + require.Equal(t, httpHeader.Get(rpc.ExecutionStepsHeader), "0") }) }) } diff --git a/rpc/trace.go b/rpc/trace.go index 809a8481de..5cf6f7c3cb 100644 --- a/rpc/trace.go +++ b/rpc/trace.go @@ -3,7 +3,9 @@ package rpc import ( "context" "errors" + "net/http" "slices" + "strconv" "github.com/Masterminds/semver/v3" "github.com/NethermindEth/juno/blockchain" @@ -129,18 +131,23 @@ func adaptFeederExecutionResources(resources *starknet.ExecutionResources) *vm.E // // It follows the specification defined here: // https://github.com/starkware-libs/starknet-specs/blob/1ae810e0137cc5d175ace4554892a4f43052be56/api/starknet_trace_api_openrpc.json#L11 -func (h *Handler) TraceTransaction(ctx context.Context, hash felt.Felt) (*vm.TransactionTrace, *jsonrpc.Error) { +func (h *Handler) TraceTransaction(ctx context.Context, hash felt.Felt) (*vm.TransactionTrace, http.Header, *jsonrpc.Error) { return h.traceTransaction(ctx, &hash, false) } -func (h *Handler) TraceTransactionV0_6(ctx context.Context, hash felt.Felt) (*vm.TransactionTrace, *jsonrpc.Error) { +func (h *Handler) TraceTransactionV0_6(ctx context.Context, hash felt.Felt) (*vm.TransactionTrace, http.Header, *jsonrpc.Error) { return h.traceTransaction(ctx, &hash, true) } -func (h *Handler) traceTransaction(ctx context.Context, hash *felt.Felt, v0_6Response bool) (*vm.TransactionTrace, *jsonrpc.Error) { +func (h *Handler) traceTransaction(ctx context.Context, hash *felt.Felt, v0_6Response bool) (*vm.TransactionTrace, + http.Header, *jsonrpc.Error, +) { _, blockHash, _, err := h.bcReader.Receipt(hash) + httpHeader := http.Header{} + httpHeader.Set(ExecutionStepsHeader, "0") + if err != nil { - return nil, ErrTxnHashNotFound + return nil, httpHeader, ErrTxnHashNotFound } var block *core.Block @@ -150,14 +157,14 @@ func (h *Handler) traceTransaction(ctx context.Context, hash *felt.Felt, v0_6Res pending, err = h.bcReader.Pending() if err != nil { // for traceTransaction handlers there is no block not found error - return nil, ErrTxnHashNotFound + return nil, httpHeader, ErrTxnHashNotFound } block = pending.Block } else { block, err = h.bcReader.BlockByHash(blockHash) if err != nil { // for traceTransaction handlers there is no block not found error - return nil, ErrTxnHashNotFound + return nil, httpHeader, ErrTxnHashNotFound } } @@ -165,57 +172,63 @@ func (h *Handler) traceTransaction(ctx context.Context, hash *felt.Felt, v0_6Res return tx.Hash().Equal(hash) }) if txIndex == -1 { - return nil, ErrTxnHashNotFound + return nil, httpHeader, ErrTxnHashNotFound } - traceResults, traceBlockErr := h.traceBlockTransactions(ctx, block, v0_6Response) + traceResults, header, traceBlockErr := h.traceBlockTransactions(ctx, block, v0_6Response) if traceBlockErr != nil { - return nil, traceBlockErr + return nil, header, traceBlockErr } - return traceResults[txIndex].TraceRoot, nil + return traceResults[txIndex].TraceRoot, header, nil } -func (h *Handler) TraceBlockTransactions(ctx context.Context, id BlockID) ([]TracedBlockTransaction, *jsonrpc.Error) { +func (h *Handler) TraceBlockTransactions(ctx context.Context, id BlockID) ([]TracedBlockTransaction, http.Header, *jsonrpc.Error) { block, rpcErr := h.blockByID(&id) if rpcErr != nil { - return nil, rpcErr + httpHeader := http.Header{} + httpHeader.Set(ExecutionStepsHeader, "0") + return nil, httpHeader, rpcErr } return h.traceBlockTransactions(ctx, block, false) } -func (h *Handler) TraceBlockTransactionsV0_6(ctx context.Context, id BlockID) ([]TracedBlockTransaction, *jsonrpc.Error) { +func (h *Handler) TraceBlockTransactionsV0_6(ctx context.Context, id BlockID) ([]TracedBlockTransaction, http.Header, *jsonrpc.Error) { block, rpcErr := h.blockByID(&id) if rpcErr != nil { - return nil, rpcErr + return nil, nil, rpcErr } return h.traceBlockTransactions(ctx, block, true) } func (h *Handler) traceBlockTransactions(ctx context.Context, block *core.Block, v0_6Response bool, //nolint: gocyclo, funlen -) ([]TracedBlockTransaction, *jsonrpc.Error) { +) ([]TracedBlockTransaction, http.Header, *jsonrpc.Error) { + httpHeader := http.Header{} + httpHeader.Set(ExecutionStepsHeader, "0") + isPending := block.Hash == nil if !isPending { if blockVer, err := core.ParseBlockVersion(block.ProtocolVersion); err != nil { - return nil, ErrUnexpectedError.CloneWithData(err.Error()) + return nil, httpHeader, ErrUnexpectedError.CloneWithData(err.Error()) } else if blockVer.Compare(traceFallbackVersion) != 1 && block.ProtocolVersion != excludedVersion { // version <= 0.13.1 and not 0.13.1.1 fetch blocks from feeder gateway - return h.fetchTraces(ctx, block.Hash) + result, err := h.fetchTraces(ctx, block.Hash) + return result, httpHeader, err } if trace, hit := h.blockTraceCache.Get(traceCacheKey{ blockHash: *block.Hash, v0_6Response: v0_6Response, }); hit { - return trace, nil + return trace, httpHeader, nil } } state, closer, err := h.bcReader.StateAtBlockHash(block.ParentHash) if err != nil { - return nil, ErrBlockNotFound + return nil, httpHeader, ErrBlockNotFound } defer h.callAndLogErr(closer, "Failed to close state in traceBlockTransactions") @@ -229,7 +242,7 @@ func (h *Handler) traceBlockTransactions(ctx context.Context, block *core.Block, headState, headStateCloser, err = h.bcReader.HeadState() } if err != nil { - return nil, jsonrpc.Err(jsonrpc.InternalError, err.Error()) + return nil, httpHeader, jsonrpc.Err(jsonrpc.InternalError, err.Error()) } defer h.callAndLogErr(headStateCloser, "Failed to close head state in traceBlockTransactions") @@ -241,7 +254,7 @@ func (h *Handler) traceBlockTransactions(ctx context.Context, block *core.Block, case *core.DeclareTransaction: class, stateErr := headState.Class(tx.ClassHash) if stateErr != nil { - return nil, jsonrpc.Err(jsonrpc.InternalError, stateErr.Error()) + return nil, httpHeader, jsonrpc.Err(jsonrpc.InternalError, stateErr.Error()) } classes = append(classes, class.Class) case *core.L1HandlerTransaction: @@ -252,7 +265,7 @@ func (h *Handler) traceBlockTransactions(ctx context.Context, block *core.Block, blockHashToBeRevealed, err := h.getRevealedBlockHash(block.Number) if err != nil { - return nil, ErrInternal.CloneWithData(err) + return nil, httpHeader, ErrInternal.CloneWithData(err) } network := h.bcReader.Network() header := block.Header @@ -262,15 +275,18 @@ func (h *Handler) traceBlockTransactions(ctx context.Context, block *core.Block, } useBlobData := !v0_6Response - overallFees, dataGasConsumed, traces, err := h.vm.Execute(block.Transactions, classes, paidFeesOnL1, &blockInfo, state, network, false, - false, false, useBlobData) + overallFees, dataGasConsumed, traces, numSteps, err := h.vm.Execute(block.Transactions, classes, paidFeesOnL1, + &blockInfo, state, network, false, false, false, useBlobData) + + httpHeader.Set(ExecutionStepsHeader, strconv.FormatUint(numSteps, 10)) + if err != nil { if errors.Is(err, utils.ErrResourceBusy) { - return nil, ErrInternal.CloneWithData(throttledVMErr) + return nil, httpHeader, ErrInternal.CloneWithData(throttledVMErr) } // Since we are tracing an existing block, we know that there should be no errors during execution. If we encounter any, // report them as unexpected errors - return nil, ErrUnexpectedError.CloneWithData(err.Error()) + return nil, httpHeader, ErrUnexpectedError.CloneWithData(err.Error()) } result := make([]TracedBlockTransaction, 0, len(traces)) @@ -317,7 +333,7 @@ func (h *Handler) traceBlockTransactions(ctx context.Context, block *core.Block, }, result) } - return result, nil + return result, httpHeader, nil } func (h *Handler) fetchTraces(ctx context.Context, blockHash *felt.Felt) ([]TracedBlockTransaction, *jsonrpc.Error) { diff --git a/rpc/trace_test.go b/rpc/trace_test.go index 3721788a68..a9fc4a47d5 100644 --- a/rpc/trace_test.go +++ b/rpc/trace_test.go @@ -59,12 +59,14 @@ func TestTraceFallback(t *testing.T) { return mockReader.BlockByNumber(test.blockNumber) }).Times(2) handler := rpc.New(mockReader, nil, nil, "", nil) - _, jErr := handler.TraceBlockTransactions(context.Background(), rpc.BlockID{Number: test.blockNumber}) + _, httpHeader, jErr := handler.TraceBlockTransactions(context.Background(), rpc.BlockID{Number: test.blockNumber}) require.Equal(t, rpc.ErrInternal.Code, jErr.Code) + assert.Equal(t, httpHeader.Get(rpc.ExecutionStepsHeader), "0") handler = handler.WithFeeder(client) - trace, jErr := handler.TraceBlockTransactions(context.Background(), rpc.BlockID{Number: test.blockNumber}) + trace, httpHeader, jErr := handler.TraceBlockTransactions(context.Background(), rpc.BlockID{Number: test.blockNumber}) require.Nil(t, jErr) + assert.Equal(t, httpHeader.Get(rpc.ExecutionStepsHeader), "0") jsonStr, err := json.Marshal(trace) require.NoError(t, err) assert.JSONEq(t, test.want, string(jsonStr)) @@ -86,9 +88,10 @@ func TestTraceTransaction(t *testing.T) { // Receipt() returns error related to db mockReader.EXPECT().Receipt(hash).Return(nil, nil, uint64(0), db.ErrKeyNotFound) - trace, err := handler.TraceTransaction(context.Background(), *hash) + trace, httpHeader, err := handler.TraceTransaction(context.Background(), *hash) assert.Nil(t, trace) assert.Equal(t, rpc.ErrTxnHashNotFound, err) + assert.Equal(t, httpHeader.Get(rpc.ExecutionStepsHeader), "0") }) t.Run("ok", func(t *testing.T) { hash := utils.HexToFelt(t, "0x37b244ea7dc6b3f9735fba02d183ef0d6807a572dd91a63cc1b14b923c1ac0") @@ -157,11 +160,15 @@ func TestTraceTransaction(t *testing.T) { require.NoError(t, json.Unmarshal(json.RawMessage(vmTraceJSON), vmTrace)) consumedGas := []*felt.Felt{new(felt.Felt).SetUint64(1)} overallFee := []*felt.Felt{new(felt.Felt).SetUint64(1)} + stepsUsed := uint64(123) + stepsUsedStr := "123" mockVM.EXPECT().Execute([]core.Transaction{tx}, []core.Class{declaredClass.Class}, []*felt.Felt{}, - &vm.BlockInfo{Header: header}, gomock.Any(), &utils.Mainnet, false, false, false, true).Return(overallFee, consumedGas, []vm.TransactionTrace{*vmTrace}, nil) + &vm.BlockInfo{Header: header}, gomock.Any(), &utils.Mainnet, false, false, + false, true).Return(overallFee, consumedGas, []vm.TransactionTrace{*vmTrace}, stepsUsed, nil) - trace, err := handler.TraceTransaction(context.Background(), *hash) + trace, httpHeader, err := handler.TraceTransaction(context.Background(), *hash) require.Nil(t, err) + assert.Equal(t, httpHeader.Get(rpc.ExecutionStepsHeader), stepsUsedStr) vmTrace.ExecutionResources = &vm.ExecutionResources{ ComputationResources: vm.ComputationResources{ @@ -244,12 +251,15 @@ func TestTraceTransaction(t *testing.T) { require.NoError(t, json.Unmarshal(json.RawMessage(vmTraceJSON), vmTrace)) consumedGas := []*felt.Felt{new(felt.Felt).SetUint64(1)} overallFee := []*felt.Felt{new(felt.Felt).SetUint64(1)} + stepsUsed := uint64(123) + stepsUsedStr := "123" mockVM.EXPECT().Execute([]core.Transaction{tx}, []core.Class{declaredClass.Class}, []*felt.Felt{}, &vm.BlockInfo{Header: header}, gomock.Any(), &utils.Mainnet, false, false, false, true). - Return(overallFee, consumedGas, []vm.TransactionTrace{*vmTrace}, nil) + Return(overallFee, consumedGas, []vm.TransactionTrace{*vmTrace}, stepsUsed, nil) - trace, err := handler.TraceTransaction(context.Background(), *hash) + trace, httpHeader, err := handler.TraceTransaction(context.Background(), *hash) require.Nil(t, err) + assert.Equal(t, httpHeader.Get(rpc.ExecutionStepsHeader), stepsUsedStr) vmTrace.ExecutionResources = &vm.ExecutionResources{ // other of fields are zero @@ -275,8 +285,9 @@ func TestTraceTransactionV0_6(t *testing.T) { // Receipt() returns error related to db mockReader.EXPECT().Receipt(hash).Return(nil, nil, uint64(0), db.ErrKeyNotFound) - trace, err := handler.TraceTransactionV0_6(context.Background(), *hash) + trace, httpHeader, err := handler.TraceTransactionV0_6(context.Background(), *hash) assert.Nil(t, trace) + assert.Equal(t, httpHeader.Get(rpc.ExecutionStepsHeader), "0") assert.Equal(t, rpc.ErrTxnHashNotFound, err) }) t.Run("ok", func(t *testing.T) { @@ -323,13 +334,16 @@ func TestTraceTransactionV0_6(t *testing.T) { } }`) vmTrace := new(vm.TransactionTrace) + stepsUsed := uint64(123) + stepsUsedStr := "123" require.NoError(t, json.Unmarshal(vmTraceJSON, vmTrace)) mockVM.EXPECT().Execute([]core.Transaction{tx}, []core.Class{declaredClass.Class}, []*felt.Felt{}, &vm.BlockInfo{Header: header}, gomock.Any(), &utils.Mainnet, false, false, false, false). - Return(nil, nil, []vm.TransactionTrace{*vmTrace}, nil) + Return(nil, nil, []vm.TransactionTrace{*vmTrace}, stepsUsed, nil) - trace, err := handler.TraceTransactionV0_6(context.Background(), *hash) + trace, httpHeader, err := handler.TraceTransactionV0_6(context.Background(), *hash) require.Nil(t, err) + assert.Equal(t, httpHeader.Get(rpc.ExecutionStepsHeader), stepsUsedStr) assert.Equal(t, vmTrace, trace) }) t.Run("pending block", func(t *testing.T) { @@ -379,13 +393,16 @@ func TestTraceTransactionV0_6(t *testing.T) { } }`) vmTrace := new(vm.TransactionTrace) + stepsUsed := uint64(123) + stepsUsedStr := "123" require.NoError(t, json.Unmarshal(vmTraceJSON, vmTrace)) mockVM.EXPECT().Execute([]core.Transaction{tx}, []core.Class{declaredClass.Class}, []*felt.Felt{}, &vm.BlockInfo{Header: header}, gomock.Any(), &utils.Mainnet, false, false, false, false). - Return(nil, nil, []vm.TransactionTrace{*vmTrace}, nil) + Return(nil, nil, []vm.TransactionTrace{*vmTrace}, stepsUsed, nil) - trace, err := handler.TraceTransactionV0_6(context.Background(), *hash) + trace, httpHeader, err := handler.TraceTransactionV0_6(context.Background(), *hash) require.Nil(t, err) + assert.Equal(t, httpHeader.Get(rpc.ExecutionStepsHeader), stepsUsedStr) assert.Equal(t, vmTrace, trace) }) } @@ -405,8 +422,9 @@ func TestTraceBlockTransactionsV0_6(t *testing.T) { chain := blockchain.New(pebble.NewMemTest(t), n) handler := rpc.New(chain, nil, nil, "", log) - update, rpcErr := handler.TraceBlockTransactions(context.Background(), id) + update, httpHeader, rpcErr := handler.TraceBlockTransactions(context.Background(), id) assert.Nil(t, update) + assert.Equal(t, httpHeader.Get(rpc.ExecutionStepsHeader), "0") assert.Equal(t, rpc.ErrBlockNotFound, rpcErr) }) } @@ -469,13 +487,16 @@ func TestTraceBlockTransactionsV0_6(t *testing.T) { } }`) vmTrace := vm.TransactionTrace{} + stepsUsed := uint64(123) + stepsUsedStr := "123" require.NoError(t, json.Unmarshal(vmTraceJSON, &vmTrace)) mockVM.EXPECT().Execute(block.Transactions, []core.Class{declaredClass.Class}, paidL1Fees, &vm.BlockInfo{Header: header}, gomock.Any(), n, false, false, false, false). - Return(nil, nil, []vm.TransactionTrace{vmTrace, vmTrace}, nil) + Return(nil, nil, []vm.TransactionTrace{vmTrace, vmTrace}, stepsUsed, nil) - result, err := handler.TraceBlockTransactionsV0_6(context.Background(), rpc.BlockID{Hash: blockHash}) + result, httpHeader, err := handler.TraceBlockTransactionsV0_6(context.Background(), rpc.BlockID{Hash: blockHash}) require.Nil(t, err) + assert.Equal(t, httpHeader.Get(rpc.ExecutionStepsHeader), stepsUsedStr) assert.Equal(t, &vm.TransactionTrace{ ValidateInvocation: &vm.FunctionInvocation{}, ExecuteInvocation: &vm.ExecuteInvocation{}, @@ -536,10 +557,12 @@ func TestTraceBlockTransactionsV0_6(t *testing.T) { } }`) vmTrace := vm.TransactionTrace{} + stepsUsed := uint64(123) + stepsUsedStr := "123" require.NoError(t, json.Unmarshal(vmTraceJSON, &vmTrace)) mockVM.EXPECT().Execute([]core.Transaction{tx}, []core.Class{declaredClass.Class}, []*felt.Felt{}, &vm.BlockInfo{Header: header}, gomock.Any(), n, false, false, false, false). - Return(nil, nil, []vm.TransactionTrace{vmTrace}, nil) + Return(nil, nil, []vm.TransactionTrace{vmTrace}, stepsUsed, nil) expectedResult := []rpc.TracedBlockTransaction{ { @@ -547,8 +570,9 @@ func TestTraceBlockTransactionsV0_6(t *testing.T) { TraceRoot: &vmTrace, }, } - result, err := handler.TraceBlockTransactionsV0_6(context.Background(), rpc.BlockID{Hash: blockHash}) + result, httpHeader, err := handler.TraceBlockTransactionsV0_6(context.Background(), rpc.BlockID{Hash: blockHash}) require.Nil(t, err) + assert.Equal(t, httpHeader.Get(rpc.ExecutionStepsHeader), stepsUsedStr) assert.Equal(t, expectedResult, result) }) } diff --git a/vm/rust/src/lib.rs b/vm/rust/src/lib.rs index 6ead8c1d0c..f1aefd2b80 100644 --- a/vm/rust/src/lib.rs +++ b/vm/rust/src/lib.rs @@ -61,6 +61,7 @@ extern "C" { fn JunoAppendResponse(reader_handle: usize, ptr: *const c_uchar); fn JunoAppendActualFee(reader_handle: usize, ptr: *const c_uchar); fn JunoAppendDataGasConsumed(reader_handle: usize, ptr: *const c_uchar); + fn JunoAddExecutionSteps(reader_handle: usize, execSteps: c_ulonglong); } #[repr(C)] @@ -350,6 +351,13 @@ pub extern "C" fn cairoVMExecute( let actual_fee = t.transaction_receipt.fee.0.into(); let data_gas_consumed = t.transaction_receipt.da_gas.l1_data_gas.into(); + let execution_steps = t + .transaction_receipt + .resources + .vm_resources + .n_steps + .try_into() + .unwrap_or(u64::MAX); let trace = jsonrpc::new_transaction_trace(&txn_and_query_bit.txn, t, &mut txn_state); @@ -368,6 +376,7 @@ pub extern "C" fn cairoVMExecute( reader_handle, felt_to_byte_array(&data_gas_consumed).as_ptr(), ); + JunoAddExecutionSteps(reader_handle, execution_steps) } append_trace(reader_handle, trace.as_ref().unwrap(), &mut trace_buffer); } diff --git a/vm/vm.go b/vm/vm.go index 85be244dbf..36de6c4088 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -65,7 +65,7 @@ type VM interface { maxSteps uint64, useBlobData bool) ([]*felt.Felt, error) Execute(txns []core.Transaction, declaredClasses []core.Class, paidFeesOnL1 []*felt.Felt, blockInfo *BlockInfo, state core.StateReader, network *utils.Network, skipChargeFee, skipValidate, errOnRevert, useBlobData bool, - ) ([]*felt.Felt, []*felt.Felt, []TransactionTrace, error) + ) ([]*felt.Felt, []*felt.Felt, []TransactionTrace, uint64, error) } type vm struct { @@ -95,6 +95,7 @@ type callContext struct { actualFees []*felt.Felt traces []json.RawMessage dataGasConsumed []*felt.Felt + executionSteps uint64 } func unwrapContext(readerHandle C.uintptr_t) *callContext { @@ -138,6 +139,12 @@ func JunoAppendDataGasConsumed(readerHandle C.uintptr_t, ptr unsafe.Pointer) { context.dataGasConsumed = append(context.dataGasConsumed, makeFeltFromPtr(ptr)) } +//export JunoAddExecutionSteps +func JunoAddExecutionSteps(readerHandle C.uintptr_t, execSteps C.ulonglong) { + context := unwrapContext(readerHandle) + context.executionSteps += uint64(execSteps) +} + func makeFeltFromPtr(ptr unsafe.Pointer) *felt.Felt { return new(felt.Felt).SetBytes(C.GoBytes(ptr, felt.Bytes)) } @@ -260,7 +267,7 @@ func (v *vm) Call(callInfo *CallInfo, blockInfo *BlockInfo, state core.StateRead func (v *vm) Execute(txns []core.Transaction, declaredClasses []core.Class, paidFeesOnL1 []*felt.Felt, blockInfo *BlockInfo, state core.StateReader, network *utils.Network, skipChargeFee, skipValidate, errOnRevert, useBlobData bool, -) ([]*felt.Felt, []*felt.Felt, []TransactionTrace, error) { +) ([]*felt.Felt, []*felt.Felt, []TransactionTrace, uint64, error) { context := &callContext{ state: state, log: v.log, @@ -270,12 +277,12 @@ func (v *vm) Execute(txns []core.Transaction, declaredClasses []core.Class, paid txnsJSON, classesJSON, err := marshalTxnsAndDeclaredClasses(txns, declaredClasses) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, 0, err } paidFeesOnL1Bytes, err := json.Marshal(paidFeesOnL1) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, 0, err } paidFeesOnL1CStr := cstring(paidFeesOnL1Bytes) @@ -324,22 +331,22 @@ func (v *vm) Execute(txns []core.Transaction, declaredClasses []core.Class, paid if context.err != "" { if context.errTxnIndex >= 0 { - return nil, nil, nil, TransactionExecutionError{ + return nil, nil, nil, 0, TransactionExecutionError{ Index: uint64(context.errTxnIndex), Cause: errors.New(context.err), } } - return nil, nil, nil, errors.New(context.err) + return nil, nil, nil, 0, errors.New(context.err) } traces := make([]TransactionTrace, len(context.traces)) for index, traceJSON := range context.traces { if err := json.Unmarshal(traceJSON, &traces[index]); err != nil { - return nil, nil, nil, fmt.Errorf("unmarshal trace: %v", err) + return nil, nil, nil, 0, fmt.Errorf("unmarshal trace: %v", err) } } - return context.actualFees, context.dataGasConsumed, traces, nil + return context.actualFees, context.dataGasConsumed, traces, context.executionSteps, nil } func marshalTxnsAndDeclaredClasses(txns []core.Transaction, declaredClasses []core.Class) (json.RawMessage, json.RawMessage, error) { //nolint:lll diff --git a/vm/vm_test.go b/vm/vm_test.go index 63c830debb..86e4cbe35f 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -203,7 +203,7 @@ func TestExecute(t *testing.T) { state := core.NewState(txn) t.Run("empty transaction list", func(t *testing.T) { - _, _, _, err := New(false, nil).Execute([]core.Transaction{}, []core.Class{}, []*felt.Felt{}, &BlockInfo{ + _, _, _, _, err := New(false, nil).Execute([]core.Transaction{}, []core.Class{}, []*felt.Felt{}, &BlockInfo{ Header: &core.Header{ Timestamp: 1666877926, SequencerAddress: utils.HexToFelt(t, "0x46a89ae102987331d369645031b49c27738ed096f2789c24449966da4c6de6b"), @@ -215,7 +215,7 @@ func TestExecute(t *testing.T) { require.NoError(t, err) }) t.Run("zero data", func(t *testing.T) { - _, _, _, err := New(false, nil).Execute(nil, nil, []*felt.Felt{}, &BlockInfo{ + _, _, _, _, err := New(false, nil).Execute(nil, nil, []*felt.Felt{}, &BlockInfo{ Header: &core.Header{ SequencerAddress: &felt.Zero, GasPrice: &felt.Zero,