Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

EVM-437 Batch calls over websockets not working properly #1588

Merged
merged 4 commits into from
Jun 7, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions jsonrpc/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ type Request struct {
Params json.RawMessage `json:"params,omitempty"`
}

type BatchRequest []Request

// Response is a jsonrpc response interface
type Response interface {
GetID() interface{}
Expand Down
117 changes: 79 additions & 38 deletions jsonrpc/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"math"
"reflect"
"strconv"
"strings"
"unicode"

Expand Down Expand Up @@ -58,6 +59,10 @@ type dispatcherParams struct {
blockRangeLimit uint64
}

func (dp dispatcherParams) isExceedingBatchLengthLimit(value uint64) bool {
return dp.jsonRPCBatchLengthLimit != 0 && value > dp.jsonRPCBatchLengthLimit
}

func newDispatcher(
logger hclog.Logger,
store JSONRPCStore,
Expand Down Expand Up @@ -161,22 +166,23 @@ type wsConn interface {

// as per https://www.jsonrpc.org/specification, the `id` in JSON-RPC 2.0
// can only be a string or a non-decimal integer
func formatFilterResponse(id interface{}, resp string) (string, Error) {
func formatID(id interface{}) (interface{}, Error) {
switch t := id.(type) {
case string:
return fmt.Sprintf(`{"jsonrpc":"2.0","id":"%s","result":"%s"}`, t, resp), nil
return t, nil
case float64:
if t == math.Trunc(t) {
return fmt.Sprintf(`{"jsonrpc":"2.0","id":%d,"result":"%s"}`, int(t), resp), nil
return int(t), nil
} else {
return "", NewInvalidRequestError("Invalid json request")
}
case nil:
return fmt.Sprintf(`{"jsonrpc":"2.0","id":null,"result":"%s"}`, resp), nil
return nil, nil
default:
return "", NewInvalidRequestError("Invalid json request")
}
}

func (d *Dispatcher) handleSubscribe(req Request, conn wsConn) (string, Error) {
var params []interface{}
if err := json.Unmarshal(req.Params, &params); err != nil {
Expand Down Expand Up @@ -231,54 +237,89 @@ func (d *Dispatcher) RemoveFilterByWs(conn wsConn) {
}

func (d *Dispatcher) HandleWs(reqBody []byte, conn wsConn) ([]byte, error) {
var req Request
if err := json.Unmarshal(reqBody, &req); err != nil {
return NewRPCResponse(req.ID, "2.0", nil, NewInvalidRequestError("Invalid json request")).Bytes()
}
const (
openSquareBracket byte = '['
closeSquareBracket byte = ']'
comma byte = ','
)

// if the request method is eth_subscribe we need to create a
// new filter with ws connection
if req.Method == "eth_subscribe" {
filterID, err := d.handleSubscribe(req, conn)
if err != nil {
return NewRPCResponse(req.ID, "2.0", nil, err).Bytes()
}
reqBody = bytes.TrimLeft(reqBody, " \t\r\n")

resp, err := formatFilterResponse(req.ID, filterID)
// if body begins with [ than consider this request as batch request
igorcrevar marked this conversation as resolved.
Show resolved Hide resolved
if len(reqBody) > 0 && reqBody[0] == openSquareBracket {
var batchReq BatchRequest

err := json.Unmarshal(reqBody, &batchReq)
if err != nil {
return NewRPCResponse(req.ID, "2.0", nil, err).Bytes()
return NewRPCResponse(nil, "2.0", nil, NewInvalidRequestError("Invalid json request")).Bytes()
}

return []byte(resp), nil
}

if req.Method == "eth_unsubscribe" {
ok, err := d.handleUnsubscribe(req)
if err != nil {
return nil, err
// if not disabled, avoid handling long batch requests
if d.params.isExceedingBatchLengthLimit(uint64(len(batchReq))) {
return NewRPCResponse(
nil,
Stefan-Ethernal marked this conversation as resolved.
Show resolved Hide resolved
"2.0",
nil,
NewInvalidRequestError("Batch request length too long"),
).Bytes()
}

res := "false"
if ok {
res = "true"
}
responses := make([][]byte, len(batchReq))

resp, err := formatFilterResponse(req.ID, res)
if err != nil {
return NewRPCResponse(req.ID, "2.0", nil, err).Bytes()
for i, req := range batchReq {
responses[i], err = d.handleWs(req, conn).Bytes()
if err != nil {
return nil, err
}
}

return []byte(resp), nil
var buf bytes.Buffer

// batch output should look like:
// [ { "requestId": "1", "status": 200 }, { "requestId": "2", "status": 200 } ]
buf.WriteByte(openSquareBracket) // [
buf.Write(bytes.Join(responses, []byte{comma})) // join responses with the comma separator
buf.WriteByte(closeSquareBracket) // ]

return buf.Bytes(), nil
}

// its a normal query that we handle with the dispatcher
resp, err := d.handleReq(req)
var req Request
if err := json.Unmarshal(reqBody, &req); err != nil {
return NewRPCResponse(req.ID, "2.0", nil, NewInvalidRequestError("Invalid json request")).Bytes()
}

return d.handleWs(req, conn).Bytes()
}

func (d *Dispatcher) handleWs(req Request, conn wsConn) Response {
id, err := formatID(req.ID)
if err != nil {
return nil, err
return NewRPCResponse(nil, "2.0", nil, err)
}

var response []byte

switch req.Method {
case "eth_subscribe":
var filterID string

// if the request method is eth_subscribe we need to create a new filter with ws connection
if filterID, err = d.handleSubscribe(req, conn); err == nil {
response = []byte(fmt.Sprintf("\"%s\"", filterID))
}
case "eth_unsubscribe":
var ok bool

if ok, err = d.handleUnsubscribe(req); err == nil {
response = []byte(strconv.FormatBool(ok))
}
default:
// its a normal query that we handle with the dispatcher
response, err = d.handleReq(req)
}

return NewRPCResponse(req.ID, "2.0", resp, err).Bytes()
return NewRPCResponse(id, "2.0", response, err)
}

func (d *Dispatcher) Handle(reqBody []byte) ([]byte, error) {
Expand All @@ -303,7 +344,7 @@ func (d *Dispatcher) Handle(reqBody []byte) ([]byte, error) {
}

// handle batch requests
var requests []Request
var requests BatchRequest
if err := json.Unmarshal(reqBody, &requests); err != nil {
return NewRPCResponse(
nil,
Expand All @@ -314,7 +355,7 @@ func (d *Dispatcher) Handle(reqBody []byte) ([]byte, error) {
}

// if not disabled, avoid handling long batch requests
if d.params.jsonRPCBatchLengthLimit != 0 && len(requests) > int(d.params.jsonRPCBatchLengthLimit) {
if d.params.isExceedingBatchLengthLimit(uint64(len(requests))) {
return NewRPCResponse(
nil,
"2.0",
Expand Down
106 changes: 92 additions & 14 deletions jsonrpc/dispatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package jsonrpc

import (
"encoding/json"
"fmt"
"math/big"
"reflect"
"testing"
Expand Down Expand Up @@ -102,6 +103,8 @@ func TestDispatcher_HandleWebsocketConnection_EthSubscribe(t *testing.T) {
}

func TestDispatcher_WebsocketConnection_RequestFormats(t *testing.T) {
t.Parallel()

store := newMockStore()
dispatcher := newTestDispatcher(t,
hclog.NewNullLogger(),
Expand Down Expand Up @@ -212,6 +215,8 @@ func (m *mockService) Filter(f LogQuery) (interface{}, error) {
}

func TestDispatcherFuncDecode(t *testing.T) {
t.Parallel()

srv := &mockService{msgCh: make(chan interface{}, 10)}

dispatcher := newTestDispatcher(t,
Expand Down Expand Up @@ -290,20 +295,29 @@ func TestDispatcherFuncDecode(t *testing.T) {
}

func TestDispatcherBatchRequest(t *testing.T) {
handle := func(dispatcher *Dispatcher, reqBody []byte) []byte {
res, _ := dispatcher.Handle(reqBody)

return res
}
t.Parallel()

cases := []struct {
type caseData struct {
name string
desc string
dispatcher *Dispatcher
reqBody []byte
err *ObjectError
batchResponse []*SuccessResponse
}{
}

mock := &mockWsConn{
SetFilterIDFn: func(s string) {
},
GetFilterIDFn: func() string {
return ""
},
WriteMessageFn: func(i int, b []byte) error {
return nil
},
}

cases := []caseData{
{
"leading-whitespace",
"test with leading whitespace (\" \\t\\n\\n\\r\\)",
Expand Down Expand Up @@ -425,36 +439,100 @@ func TestDispatcherBatchRequest(t *testing.T) {
},
}

for _, c := range cases {
res := handle(c.dispatcher, c.reqBody)

check := func(c caseData, res []byte) {
if c.err != nil {
var resp ErrorResponse

assert.NoError(t, expectBatchJSONResult(res, &resp))
assert.Equal(t, resp.Error, c.err)
assert.Equal(t, c.err, resp.Error)
} else {
var batchResp []SuccessResponse
assert.NoError(t, expectBatchJSONResult(res, &batchResp))

if c.name == "leading-whitespace" {
assert.Len(t, batchResp, 4)
for index, resp := range batchResp {
assert.Equal(t, resp.Error, c.batchResponse[index].Error)
assert.Equal(t, c.batchResponse[index].Error, resp.Error)
}
} else if c.name == "valid-batch-req" {
assert.Len(t, batchResp, 6)
for index, resp := range batchResp {
assert.Equal(t, resp.Error, c.batchResponse[index].Error)
assert.Equal(t, c.batchResponse[index].Error, resp.Error)
}
} else if c.name == "no-limits" {
assert.Len(t, batchResp, 12)
for index, resp := range batchResp {
assert.Equal(t, resp.Error, c.batchResponse[index].Error)
assert.Equal(t, c.batchResponse[index].Error, resp.Error)
}
}
}
}

for _, c := range cases {
c := c

t.Run(c.name, func(t *testing.T) {
t.Parallel()

res, _ := c.dispatcher.HandleWs(c.reqBody, mock)

check(c, res)

res, _ = c.dispatcher.Handle(c.reqBody)

check(c, res)
})
}
}

func TestDispatcher_WebsocketConnection_Unsubscribe(t *testing.T) {
t.Parallel()

store := newMockStore()
dispatcher := newTestDispatcher(t,
hclog.NewNullLogger(),
store,
&dispatcherParams{
chainID: 0,
priceLimit: 0,
jsonRPCBatchLengthLimit: 20,
blockRangeLimit: 1000,
},
)
mockConn := &mockWsConn{
SetFilterIDFn: func(s string) {
},
GetFilterIDFn: func() string {
return ""
},
WriteMessageFn: func(i int, b []byte) error {
return nil
},
}

resp := SuccessResponse{}
reqUnsub := func(n string) []byte {
return []byte(fmt.Sprintf(`{"method": "eth_unsubscribe", "params": [%s]}`, n))
}

// non existing subscription
r, err := dispatcher.HandleWs(reqUnsub("\"787832\""), mockConn)
require.NoError(t, err)

require.NoError(t, json.Unmarshal(r, &resp))
assert.Equal(t, "false", string(resp.Result))

r, err = dispatcher.HandleWs([]byte(`{"method": "eth_subscribe", "params": ["newHeads"]}`), mockConn)
require.NoError(t, err)

require.NoError(t, json.Unmarshal(r, &resp))

// existing subscription
r, err = dispatcher.HandleWs(reqUnsub(string(resp.Result)), mockConn)
require.NoError(t, err)

require.NoError(t, json.Unmarshal(r, &resp))
assert.Equal(t, "true", string(resp.Result))
}

func newTestDispatcher(t *testing.T, logger hclog.Logger, store JSONRPCStore, params *dispatcherParams) *Dispatcher {
Expand Down