Skip to content

Commit

Permalink
check for nil before returning invalid json in rpc streaming calls (#…
Browse files Browse the repository at this point in the history
…7104)

should handle nil having already been written in any rpc call before
writing it again causing invalid json to be returned.
  • Loading branch information
hexoscott authored Mar 14, 2023
1 parent 2ba3b08 commit 84ec0a0
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 1 deletion.
25 changes: 24 additions & 1 deletion rpc/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ func (h *handler) runMethod(ctx context.Context, msg *jsonrpcMessage, callb *cal
stream.WriteObjectField("result")
_, err := callb.call(ctx, msg.Method, args, stream)
if err != nil {
stream.WriteNil()
writeNilIfNotPresent(stream)
stream.WriteMore()
HandleError(err, stream)
}
Expand All @@ -519,6 +519,29 @@ func (h *handler) runMethod(ctx context.Context, msg *jsonrpcMessage, callb *cal
return nil
}

var nullAsBytes = []byte{110, 117, 108, 108}

// there are many avenues that could lead to an error being handled in runMethod, so we need to check
// if nil has already been written to the stream before writing it again here
func writeNilIfNotPresent(stream *jsoniter.Stream) {
b := stream.Buffer()
hasNil := true
if len(b) >= 4 {
b = b[len(b)-4:]
for i, v := range nullAsBytes {
if v != b[i] {
hasNil = false
break
}
}
} else {
hasNil = false
}
if !hasNil {
stream.WriteNil()
}
}

// unsubscribe is the callback function for all *_unsubscribe calls.
func (h *handler) unsubscribe(ctx context.Context, id ID) (bool, error) {
h.subLock.Lock()
Expand Down
84 changes: 84 additions & 0 deletions rpc/handler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package rpc

import (
"bytes"
"context"
"fmt"
"reflect"
"testing"

jsoniter "github.com/json-iterator/go"
"github.com/stretchr/testify/assert"
)

func TestHandlerDoesNotDoubleWriteNull(t *testing.T) {

tests := map[string]struct {
params []byte
expected string
}{
"error_with_stream_write": {
params: []byte("[1]"),
expected: `{"jsonrpc":"2.0","id":1,"result":null,"error":{"code":-32000,"message":"id 1"}}`,
},
"error_without_stream_write": {
params: []byte("[2]"),
expected: `{"jsonrpc":"2.0","id":1,"result":null,"error":{"code":-32000,"message":"id 2"}}`,
},
"no_error": {
params: []byte("[3]"),
expected: `{"jsonrpc":"2.0","id":1,"result":{}}`,
},
}

for name, testParams := range tests {
t.Run(name, func(t *testing.T) {
msg := jsonrpcMessage{
Version: "2.0",
ID: []byte{49},
Method: "test_test",
Params: testParams.params,
Error: nil,
Result: nil,
}

dummyFunc := func(id int, stream *jsoniter.Stream) error {
if id == 1 {
stream.WriteNil()
return fmt.Errorf("id 1")
}
if id == 2 {
return fmt.Errorf("id 2")
}
stream.WriteEmptyObject()
return nil
}

var arg1 int
cb := &callback{
fn: reflect.ValueOf(dummyFunc),
rcvr: reflect.Value{},
argTypes: []reflect.Type{reflect.TypeOf(arg1)},
hasCtx: false,
errPos: 0,
isSubscribe: false,
streamable: true,
}

args, err := parsePositionalArguments((msg).Params, cb.argTypes)
if err != nil {
t.Fatal(err)
}

var buf bytes.Buffer
stream := jsoniter.NewStream(jsoniter.ConfigDefault, &buf, 4096)

h := handler{}
h.runMethod(context.Background(), &msg, cb, args, stream)

output := buf.String()
assert.Equal(t, testParams.expected, output, "expected output should match")
})
}

}

0 comments on commit 84ec0a0

Please sign in to comment.