Skip to content

Commit

Permalink
Fix: complete handleRequest modification to allow headers
Browse files Browse the repository at this point in the history
  • Loading branch information
obasekiosa committed Aug 7, 2024
1 parent 6d65ec7 commit ef54d71
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 16 deletions.
29 changes: 22 additions & 7 deletions jsonrpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,18 @@ func (s *Server) registerMethod(method Method) error {
if numArgs != len(method.Params) {
return errors.New("number of non-context function params and param names must match")
}
if handlerT.NumOut() != 2 {
return errors.New("handler must return 2 values")
outSize := handlerT.NumOut()
if outSize < 2 || outSize > 3 {
return errors.New("handler must return 2 or 3 values")
}
if handlerT.Out(1) != reflect.TypeOf(&Error{}) {
return errors.New("second return value must be a *jsonrpc.Error")
if outSize == 2 && handlerT.Out(1) != reflect.TypeOf(&Error{}) {
return errors.New("second return value must be a *jsonrpc.Error for 2 tuple handler")
} else if outSize == 3 && handlerT.Out(2) != reflect.TypeOf(&Error{}) {
return errors.New("third return value must be a *jsonrpc.Error for 3 tuple handler")
}

if outSize == 3 && handlerT.Out(1) != reflect.TypeOf(http.Header{}) {
return errors.New("second return value must be a http.Header for 3 tuple handler")
}

// The method is valid. Mutate the appropriate fields and register on the server.
Expand All @@ -206,7 +213,7 @@ type connection struct {
w io.Writer
activated <-chan struct{}

//todo: guard against this in the code! don't depend on devs finding out about this!
//todo: guard for this in the code! don't depend on devs finding out about this!
// initialErr is not thread-safe. It must be set to its final value before the connection is activated.
initialErr error
}
Expand Down Expand Up @@ -282,7 +289,7 @@ func (s *Server) HandleReader(ctx context.Context, reader io.Reader) ([]byte, ht
Version: "2.0",
}

var header http.Header
header := http.Header{}

dec := json.NewDecoder(bufferedReader)
dec.UseNumber()
Expand Down Expand Up @@ -313,6 +320,10 @@ func (s *Server) HandleReader(ctx context.Context, reader io.Reader) ([]byte, ht
}
}

if res == nil {
return nil, header, nil
}

result, err := json.Marshal(res)
return result, header, err
}
Expand Down Expand Up @@ -376,11 +387,15 @@ func (s *Server) handleBatchRequest(ctx context.Context, batchReq []json.RawMess
wg.Wait()
// according to the spec if there are no response objects server must not return empty array
if len(responses) == 0 {
return nil, nil, nil
return nil, http.Header{}, nil
}

result, err := json.Marshal(responses)

if len(headers) == 0 {
return result, http.Header{}, err

Check warning on line 396 in jsonrpc/server.go

View check run for this annotation

Codecov / codecov/patch

jsonrpc/server.go#L396

Added line #L396 was not covered by tests
}

return result, headers[0], err // todo: fix batch request aggregate header
}

Expand Down
20 changes: 15 additions & 5 deletions jsonrpc/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,27 @@ func TestServer_RegisterMethod(t *testing.T) {
"no return": {
handler: func(param1, param2 int) {},
paramNames: []jsonrpc.Parameter{{Name: "param1"}, {Name: "param2"}},
want: "handler must return 2 values",
want: "handler must return 2 or 3 values",
},
"int return": {
handler: func(param1, param2 int) (int, int) { return 0, 0 },
paramNames: []jsonrpc.Parameter{{Name: "param1"}, {Name: "param2"}},
want: "second return value must be a *jsonrpc.Error",
want: "second return value must be a *jsonrpc.Error for 2 tuple handler",
},
"no error return 3": {
handler: func(param1, param2 int) (int, int, int) { return 0, 0, 0 },
paramNames: []jsonrpc.Parameter{{Name: "param1"}, {Name: "param2"}},
want: "third return value must be a *jsonrpc.Error for 3 tuple handler",
},
"no header return 3": {
handler: func(param1, param2 int) (int, int, *jsonrpc.Error) { return 0, 0, &jsonrpc.Error{} },
paramNames: []jsonrpc.Parameter{{Name: "param1"}, {Name: "param2"}},
want: "second return value must be a http.Header for 3 tuple handler",
},
"no error return": {
handler: func(param1, param2 int) (any, int) { return 0, 0 },
paramNames: []jsonrpc.Parameter{{Name: "param1"}, {Name: "param2"}},
want: "second return value must be a *jsonrpc.Error",
want: "second return value must be a *jsonrpc.Error for 2 tuple handler",
},
}

Expand Down Expand Up @@ -472,9 +482,9 @@ func TestHandle(t *testing.T) {
oldRequestFailedEventCount := len(listener.OnRequestFailedCalls)
oldRequestHandledCalls := len(listener.OnRequestHandledCalls)

res, header, err := server.HandleReader(context.Background(), strings.NewReader(test.req))
res, httpHeader, err := server.HandleReader(context.Background(), strings.NewReader(test.req))
require.NoError(t, err)
require.NotNil(t, header)
require.NotNil(t, httpHeader)

if test.isBatch {
assertBatchResponse(t, test.res, string(res))
Expand Down
18 changes: 14 additions & 4 deletions rpc/estimate_fee_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ func TestEstimateMessageFee(t *testing.T) {
_, httpHeader, err := handler.EstimateMessageFeeV0_6(msg, rpc.BlockID{Latest: true})
require.Equal(t, rpc.ErrBlockNotFound, err)
require.NotNil(t, httpHeader)
require.NotEmpty(t, httpHeader.Get(rpc.ExecutionStepsHeader))
})

latestHeader := &core.Header{
Expand Down Expand Up @@ -80,8 +81,11 @@ func TestEstimateMessageFee(t *testing.T) {
},
)

estimateFee, err := handler.EstimateMessageFeeV0_6(msg, rpc.BlockID{Latest: true})
estimateFee, httpHeader, err := handler.EstimateMessageFeeV0_6(msg, rpc.BlockID{Latest: true})
require.Nil(t, err)
require.NotNil(t, httpHeader)
require.NotEmpty(t, httpHeader.Get(rpc.ExecutionStepsHeader))

feeUnit := rpc.WEI
require.Equal(t, rpc.FeeEstimate{
GasConsumed: expectedGasConsumed,
Expand Down Expand Up @@ -114,16 +118,20 @@ func TestEstimateFee(t *testing.T) {
mockVM.EXPECT().Execute(nil, nil, []*felt.Felt{}, &blockInfo, mockState, n, true, true, false, false).
Return([]*felt.Felt{}, []vm.TransactionTrace{}, nil)

_, err := handler.EstimateFee([]rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{}, rpc.BlockID{Latest: true})
_, httpHeader, err := handler.EstimateFee([]rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{}, rpc.BlockID{Latest: true})
require.NotNil(t, httpHeader)
require.NotEmpty(t, httpHeader.Get(rpc.ExecutionStepsHeader))
require.Nil(t, err)
})

t.Run("ok with zero values, skip validate", func(t *testing.T) {
mockVM.EXPECT().Execute(nil, nil, []*felt.Felt{}, &blockInfo, mockState, n, true, true, false, false).
Return([]*felt.Felt{}, []vm.TransactionTrace{}, nil)

_, err := handler.EstimateFee([]rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipValidateFlag}, rpc.BlockID{Latest: true})
_, httpHeader, err := handler.EstimateFee([]rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipValidateFlag}, rpc.BlockID{Latest: true})
require.Nil(t, err)
require.NotNil(t, httpHeader)
require.NotEmpty(t, httpHeader.Get(rpc.ExecutionStepsHeader))
})

t.Run("transaction execution error", func(t *testing.T) {
Expand All @@ -133,11 +141,13 @@ func TestEstimateFee(t *testing.T) {
Cause: errors.New("oops"),
})

_, err := handler.EstimateFee([]rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipValidateFlag}, rpc.BlockID{Latest: true})
_, httpHeader, err := handler.EstimateFee([]rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipValidateFlag}, rpc.BlockID{Latest: true})
require.Equal(t, rpc.ErrTransactionExecutionError.CloneWithData(rpc.TransactionExecutionErrorData{
TransactionIndex: 44,
ExecutionError: "oops",
}), err)
require.NotNil(t, httpHeader)
require.NotEmpty(t, httpHeader.Get(rpc.ExecutionStepsHeader))

mockVM.EXPECT().Execute(nil, nil, []*felt.Felt{}, &blockInfo, mockState, n, false, true, true, false).
Return(nil, nil, vm.TransactionExecutionError{
Expand Down
2 changes: 2 additions & 0 deletions rpc/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ func TestThrottledVMError(t *testing.T) {
_, httpHeader, rpcErr := handler.SimulateTransactions(rpc.BlockID{Latest: true}, []rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipFeeChargeFlag})
assert.Equal(t, throttledErr, rpcErr.Data)
require.NotNil(t, httpHeader)
require.NotEmpty(t, httpHeader.Get(rpc.ExecutionStepsHeader))
})

t.Run("trace", func(t *testing.T) {
Expand Down Expand Up @@ -99,5 +100,6 @@ func TestThrottledVMError(t *testing.T) {
_, httpHeader, rpcErr := handler.TraceBlockTransactions(context.Background(), rpc.BlockID{Hash: blockHash})
assert.Equal(t, throttledErr, rpcErr.Data)
require.NotNil(t, httpHeader)
require.NotEmpty(t, httpHeader.Get(rpc.ExecutionStepsHeader))
})
}
4 changes: 4 additions & 0 deletions rpc/simulation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ func TestSimulateTransactions(t *testing.T) {
_, httpHeader, err := handler.SimulateTransactions(rpc.BlockID{Latest: true}, []rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipFeeChargeFlag})
require.Nil(t, err)
require.NotNil(t, httpHeader)
require.NotEmpty(t, httpHeader.Get(rpc.ExecutionStepsHeader))
})

t.Run("ok with zero values, skip validate", func(t *testing.T) {
Expand All @@ -53,6 +54,7 @@ func TestSimulateTransactions(t *testing.T) {
_, httpHeader, err := handler.SimulateTransactions(rpc.BlockID{Latest: true}, []rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipValidateFlag})
require.Nil(t, err)
require.NotNil(t, httpHeader)
require.NotEmpty(t, httpHeader.Get(rpc.ExecutionStepsHeader))
})

t.Run("transaction execution error", func(t *testing.T) {
Expand All @@ -70,6 +72,7 @@ func TestSimulateTransactions(t *testing.T) {
ExecutionError: "oops",
}), err)
require.NotNil(t, httpHeader)
require.NotEmpty(t, httpHeader.Get(rpc.ExecutionStepsHeader))

mockVM.EXPECT().Execute(nil, nil, []*felt.Felt{}, &vm.BlockInfo{
Header: headsHeader,
Expand All @@ -84,5 +87,6 @@ func TestSimulateTransactions(t *testing.T) {
RevertError: "oops",
}), err)
require.NotNil(t, httpHeader)
require.NotEmpty(t, httpHeader.Get(rpc.ExecutionStepsHeader))
})
}
11 changes: 11 additions & 0 deletions rpc/trace_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,13 @@ func TestTraceFallback(t *testing.T) {
_, httpHeader, jErr := handler.TraceBlockTransactions(context.Background(), rpc.BlockID{Number: test.blockNumber})
require.Equal(t, rpc.ErrInternal.Code, jErr.Code)
require.NotNil(t, httpHeader)
require.NotEmpty(t, httpHeader.Get(rpc.ExecutionStepsHeader))

handler = handler.WithFeeder(client)
trace, httpHeader, jErr := handler.TraceBlockTransactions(context.Background(), rpc.BlockID{Number: test.blockNumber})
require.Nil(t, jErr)
require.NotNil(t, httpHeader)
require.NotEmpty(t, httpHeader.Get(rpc.ExecutionStepsHeader))
jsonStr, err := json.Marshal(trace)
require.NoError(t, err)
assert.JSONEq(t, test.want, string(jsonStr))
Expand All @@ -92,6 +94,7 @@ func TestTraceTransaction(t *testing.T) {
assert.Nil(t, trace)
assert.Equal(t, rpc.ErrTxnHashNotFound, err)
require.NotNil(t, httpHeader)
require.NotEmpty(t, httpHeader.Get(rpc.ExecutionStepsHeader))
})
t.Run("ok", func(t *testing.T) {
hash := utils.HexToFelt(t, "0x37b244ea7dc6b3f9735fba02d183ef0d6807a572dd91a63cc1b14b923c1ac0")
Expand Down Expand Up @@ -168,6 +171,7 @@ func TestTraceTransaction(t *testing.T) {
trace, httpHeader, err := handler.TraceTransaction(context.Background(), *hash)
require.Nil(t, err)
require.NotNil(t, httpHeader)
require.NotEmpty(t, httpHeader.Get(rpc.ExecutionStepsHeader))

vmTrace.ExecutionResources = &vm.ExecutionResources{
ComputationResources: vm.ComputationResources{
Expand Down Expand Up @@ -257,6 +261,7 @@ func TestTraceTransaction(t *testing.T) {
trace, httpHeader, err := handler.TraceTransaction(context.Background(), *hash)
require.Nil(t, err)
require.NotNil(t, httpHeader)
require.NotEmpty(t, httpHeader.Get(rpc.ExecutionStepsHeader))

vmTrace.ExecutionResources = &vm.ExecutionResources{
// other of fields are zero
Expand Down Expand Up @@ -285,6 +290,7 @@ func TestTraceTransactionV0_6(t *testing.T) {
trace, httpHeader, err := handler.TraceTransactionV0_6(context.Background(), *hash)
assert.Nil(t, trace)
require.NotNil(t, httpHeader)
require.NotEmpty(t, httpHeader.Get(rpc.ExecutionStepsHeader))
assert.Equal(t, rpc.ErrTxnHashNotFound, err)
})
t.Run("ok", func(t *testing.T) {
Expand Down Expand Up @@ -339,6 +345,7 @@ func TestTraceTransactionV0_6(t *testing.T) {
trace, httpHeader, err := handler.TraceTransactionV0_6(context.Background(), *hash)
require.Nil(t, err)
require.NotNil(t, httpHeader)
require.NotEmpty(t, httpHeader.Get(rpc.ExecutionStepsHeader))
assert.Equal(t, vmTrace, trace)
})
t.Run("pending block", func(t *testing.T) {
Expand Down Expand Up @@ -396,6 +403,7 @@ func TestTraceTransactionV0_6(t *testing.T) {
trace, httpHeader, err := handler.TraceTransactionV0_6(context.Background(), *hash)
require.Nil(t, err)
require.NotNil(t, httpHeader)
require.NotEmpty(t, httpHeader.Get(rpc.ExecutionStepsHeader))
assert.Equal(t, vmTrace, trace)
})
}
Expand All @@ -420,6 +428,7 @@ func TestTraceBlockTransactions(t *testing.T) {
update, httpHeader, rpcErr := handler.TraceBlockTransactions(context.Background(), id)
assert.Nil(t, update)
require.NotNil(t, httpHeader)
require.NotEmpty(t, httpHeader.Get(rpc.ExecutionStepsHeader))
assert.Equal(t, rpc.ErrBlockNotFound, rpcErr)
})
}
Expand Down Expand Up @@ -490,6 +499,7 @@ func TestTraceBlockTransactions(t *testing.T) {
result, httpHeader, err := handler.TraceBlockTransactions(context.Background(), rpc.BlockID{Hash: blockHash})
require.Nil(t, err)
require.NotNil(t, httpHeader)
require.NotEmpty(t, httpHeader.Get(rpc.ExecutionStepsHeader))
assert.Equal(t, &vm.TransactionTrace{
ValidateInvocation: &vm.FunctionInvocation{},
ExecuteInvocation: &vm.ExecuteInvocation{},
Expand Down Expand Up @@ -564,6 +574,7 @@ func TestTraceBlockTransactions(t *testing.T) {
result, httpHeader, err := handler.TraceBlockTransactions(context.Background(), rpc.BlockID{Hash: blockHash})
require.Nil(t, err)
require.NotNil(t, httpHeader)
require.NotEmpty(t, httpHeader.Get(rpc.ExecutionStepsHeader))
assert.Equal(t, expectedResult, result)
})
}
Expand Down

0 comments on commit ef54d71

Please sign in to comment.