diff --git a/call_opt.go b/call_opt.go index 851baa5..84a04c2 100644 --- a/call_opt.go +++ b/call_opt.go @@ -46,3 +46,12 @@ func StringID() CallOption { return nil }) } + +// OmitNilParams returns a call option that instructs requests to omit params +// values of nil instead of JSON encoding them to null. +func OmitNilParams() CallOption { + return callOptionFunc(func(r *Request) error { + r.OmitNilParams = true + return nil + }) +} diff --git a/call_opt_test.go b/call_opt_test.go index b64f661..aff8003 100644 --- a/call_opt_test.go +++ b/call_opt_test.go @@ -2,7 +2,10 @@ package jsonrpc2_test import ( "context" + "encoding/json" "fmt" + "net" + "sync" "testing" "github.com/sourcegraph/jsonrpc2" @@ -140,3 +143,111 @@ func TestExtraField(t *testing.T) { t.Fatal(err) } } + +func TestOmitNilParams(t *testing.T) { + rawJSONMessage := func(v string) *json.RawMessage { + b := []byte(v) + return (*json.RawMessage)(&b) + } + + type testCase struct { + callOpt jsonrpc2.CallOption + sendParams interface{} + wantParams *json.RawMessage + } + + testCases := []testCase{ + { + sendParams: nil, + wantParams: rawJSONMessage("null"), + }, + { + sendParams: rawJSONMessage("null"), + wantParams: rawJSONMessage("null"), + }, + { + callOpt: jsonrpc2.OmitNilParams(), + sendParams: nil, + wantParams: nil, + }, + { + callOpt: jsonrpc2.OmitNilParams(), + sendParams: rawJSONMessage("null"), + wantParams: rawJSONMessage("null"), + }, + } + + assert := func(got *json.RawMessage, want *json.RawMessage) error { + // Assert pointers. + if got == nil || want == nil { + if got != want { + return fmt.Errorf("got %v, want %v", got, want) + } + return nil + } + { + // If pointers are not nil, then assert values. + got := string(*got) + want := string(*want) + if got != want { + return fmt.Errorf("got %q, want %q", got, want) + } + } + return nil + } + + newClientServer := func(handler jsonrpc2.Handler) (client *jsonrpc2.Conn, server *jsonrpc2.Conn) { + ctx := context.Background() + connA, connB := net.Pipe() + client = jsonrpc2.NewConn( + ctx, + jsonrpc2.NewPlainObjectStream(connA), + noopHandler{}, + ) + server = jsonrpc2.NewConn( + ctx, + jsonrpc2.NewPlainObjectStream(connB), + handler, + ) + return client, server + } + + for i, tc := range testCases { + + t.Run(fmt.Sprintf("test case %v", i), func(t *testing.T) { + t.Run("call", func(t *testing.T) { + handler := jsonrpc2.HandlerWithError(func(ctx context.Context, c *jsonrpc2.Conn, r *jsonrpc2.Request) (result interface{}, err error) { + return nil, assert(r.Params, tc.wantParams) + }) + + client, server := newClientServer(handler) + defer client.Close() + defer server.Close() + + if err := client.Call(context.Background(), "f", tc.sendParams, nil, tc.callOpt); err != nil { + t.Fatal(err) + } + }) + t.Run("notify", func(t *testing.T) { + wg := &sync.WaitGroup{} + handler := handlerFunc(func(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { + err := assert(req.Params, tc.wantParams) + if err != nil { + t.Error(err) + } + wg.Done() + }) + + client, server := newClientServer(handler) + defer client.Close() + defer server.Close() + + wg.Add(1) + if err := client.Notify(context.Background(), "f", tc.sendParams, tc.callOpt); err != nil { + t.Fatal(err) + } + wg.Wait() + }) + }) + } +} diff --git a/jsonrpc2.go b/jsonrpc2.go index 0bfcc71..4885b05 100644 --- a/jsonrpc2.go +++ b/jsonrpc2.go @@ -57,6 +57,9 @@ type Request struct { // NOTE: It is not part of the spec, but there are other protocols based on // JSON-RPC 2 that require it. ExtraFields []RequestField `json:"-"` + // OmitNilParams instructs the SetParams method to not JSON encode a nil + // value and set Params to nil instead. + OmitNilParams bool `json:"-"` } // MarshalJSON implements json.Marshaler and adds the "jsonrpc":"2.0" @@ -159,9 +162,15 @@ func (r *Request) UnmarshalJSON(data []byte) error { return nil } -// SetParams sets r.Params to the JSON representation of v. If JSON -// marshaling fails, it returns an error. +// SetParams sets r.Params to the JSON representation of v. If JSON marshaling +// fails, it returns an error. Beware that the JSON encoding of nil is null. If +// r.OmitNilParams is true and v is nil, then r.Params is set to nil and +// therefore omitted from the JSON-RPC request. func (r *Request) SetParams(v interface{}) error { + if r.OmitNilParams && v == nil { + r.Params = nil + return nil + } b, err := json.Marshal(v) if err != nil { return err @@ -511,9 +520,6 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface // otherwise use Call. func (c *Conn) DispatchCall(ctx context.Context, method string, params interface{}, opts ...CallOption) (Waiter, error) { req := &Request{Method: method} - if err := req.SetParams(params); err != nil { - return Waiter{}, err - } for _, opt := range opts { if opt == nil { continue @@ -522,6 +528,9 @@ func (c *Conn) DispatchCall(ctx context.Context, method string, params interface return Waiter{}, err } } + if err := req.SetParams(params); err != nil { + return Waiter{}, err + } call, err := c.send(ctx, &anyMessage{request: req}, true) if err != nil { return Waiter{}, err @@ -569,9 +578,6 @@ var jsonNull = json.RawMessage("null") // notifications do not have responses). func (c *Conn) Notify(ctx context.Context, method string, params interface{}, opts ...CallOption) error { req := &Request{Method: method, Notif: true} - if err := req.SetParams(params); err != nil { - return err - } for _, opt := range opts { if opt == nil { continue @@ -580,6 +586,9 @@ func (c *Conn) Notify(ctx context.Context, method string, params interface{}, op return err } } + if err := req.SetParams(params); err != nil { + return err + } _, err := c.send(ctx, &anyMessage{request: req}, false) return err }