diff --git a/conn.go b/conn.go new file mode 100644 index 0000000..c2016bf --- /dev/null +++ b/conn.go @@ -0,0 +1,460 @@ +package jsonrpc2 + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "log" + "os" + "strconv" + "sync" +) + +// Conn is a JSON-RPC client/server connection. The JSON-RPC protocol +// is symmetric, so a Conn runs on both ends of a client-server +// connection. +type Conn struct { + stream ObjectStream + + h Handler + + mu sync.Mutex + closed bool + seq uint64 + pending map[ID]*call + + sending sync.Mutex + + disconnect chan struct{} + + logger Logger + + // Set by ConnOpt funcs. + onRecv []func(*Request, *Response) + onSend []func(*Request, *Response) +} + +var _ JSONRPC2 = (*Conn)(nil) + +// NewConn creates a new JSON-RPC client/server connection using the +// given ReadWriteCloser (typically a TCP connection or stdio). The +// JSON-RPC protocol is symmetric, so a Conn runs on both ends of a +// client-server connection. +// +// NewClient consumes conn, so you should call Close on the returned +// client not on the given conn. +func NewConn(ctx context.Context, stream ObjectStream, h Handler, opts ...ConnOpt) *Conn { + c := &Conn{ + stream: stream, + h: h, + pending: map[ID]*call{}, + disconnect: make(chan struct{}), + logger: log.New(os.Stderr, "", log.LstdFlags), + } + for _, opt := range opts { + if opt == nil { + continue + } + opt(c) + } + go c.readMessages(ctx) + return c +} + +// Close closes the JSON-RPC connection. The connection may not be +// used after it has been closed. +func (c *Conn) Close() error { + return c.close(nil) +} + +// Call initiates a JSON-RPC call using the specified method and +// params, and waits for the response. If the response is successful, +// its result is stored in result (a pointer to a value that can be +// JSON-unmarshaled into); otherwise, a non-nil error is returned. +func (c *Conn) Call(ctx context.Context, method string, params, result interface{}, opts ...CallOption) error { + call, err := c.DispatchCall(ctx, method, params, opts...) + if err != nil { + return err + } + return call.Wait(ctx, result) +} + +// DisconnectNotify returns a channel that is closed when the +// underlying connection is disconnected. +func (c *Conn) DisconnectNotify() <-chan struct{} { + return c.disconnect +} + +// DispatchCall dispatches a JSON-RPC call using the specified method +// and params, and returns a call proxy or an error. Call Wait() +// on the returned proxy to receive the response. Only use this +// function if you need to do work after dispatching the request, +// otherwise use Call. +func (c *Conn) DispatchCall(ctx context.Context, method string, params interface{}, opts ...CallOption) (Waiter, error) { + req := &Request{Method: method} + for _, opt := range opts { + if opt == nil { + continue + } + if err := opt.apply(req); err != nil { + return Waiter{}, err + } + } + if err := req.SetParams(params); err != nil { + return Waiter{}, err + } + call, err := c.send(ctx, &anyMessage{request: req}, true) + if err != nil { + return Waiter{}, err + } + return Waiter{call: call}, nil +} + +// Notify is like Call, but it returns when the notification request +// is sent (without waiting for a response, because JSON-RPC +// notifications do not have responses). +func (c *Conn) Notify(ctx context.Context, method string, params interface{}, opts ...CallOption) error { + req := &Request{Method: method, Notif: true} + for _, opt := range opts { + if opt == nil { + continue + } + if err := opt.apply(req); err != nil { + return err + } + } + if err := req.SetParams(params); err != nil { + return err + } + _, err := c.send(ctx, &anyMessage{request: req}, false) + return err +} + +// Reply sends a successful response with a result. +func (c *Conn) Reply(ctx context.Context, id ID, result interface{}) error { + resp := &Response{ID: id} + if err := resp.SetResult(result); err != nil { + return err + } + _, err := c.send(ctx, &anyMessage{response: resp}, false) + return err +} + +// ReplyWithError sends a response with an error. +func (c *Conn) ReplyWithError(ctx context.Context, id ID, respErr *Error) error { + _, err := c.send(ctx, &anyMessage{response: &Response{ID: id, Error: respErr}}, false) + return err +} + +// SendResponse sends resp to the peer. It is lower level than (*Conn).Reply. +func (c *Conn) SendResponse(ctx context.Context, resp *Response) error { + _, err := c.send(ctx, &anyMessage{response: resp}, false) + return err +} + +func (c *Conn) close(cause error) error { + c.sending.Lock() + c.mu.Lock() + defer c.sending.Unlock() + defer c.mu.Unlock() + + if c.closed { + return ErrClosed + } + + for _, call := range c.pending { + close(call.done) + } + + if cause != nil && cause != io.EOF && cause != io.ErrUnexpectedEOF { + c.logger.Printf("jsonrpc2: protocol error: %v\n", cause) + } + + close(c.disconnect) + c.closed = true + return c.stream.Close() +} + +func (c *Conn) readMessages(ctx context.Context) { + var err error + for err == nil { + var m anyMessage + err = c.stream.ReadObject(&m) + if err != nil { + break + } + + switch { + case m.request != nil: + for _, onRecv := range c.onRecv { + onRecv(m.request, nil) + } + c.h.Handle(ctx, c, m.request) + + case m.response != nil: + resp := m.response + if resp != nil { + id := resp.ID + c.mu.Lock() + call := c.pending[id] + delete(c.pending, id) + c.mu.Unlock() + + if call != nil { + call.response = resp + } + + if len(c.onRecv) > 0 { + var req *Request + if call != nil { + req = call.request + } + for _, onRecv := range c.onRecv { + onRecv(req, resp) + } + } + + switch { + case call == nil: + c.logger.Printf("jsonrpc2: ignoring response #%s with no corresponding request\n", id) + + case resp.Error != nil: + call.done <- resp.Error + close(call.done) + + default: + call.done <- nil + close(call.done) + } + } + } + } + c.close(err) +} + +func (c *Conn) send(_ context.Context, m *anyMessage, wait bool) (cc *call, err error) { + c.sending.Lock() + defer c.sending.Unlock() + + // m.request.ID could be changed, so we store a copy to correctly + // clean up pending + var id ID + + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return nil, ErrClosed + } + + // Assign a default id if not set + if m.request != nil && wait { + cc = &call{request: m.request, seq: c.seq, done: make(chan error, 1)} + + isIDUnset := len(m.request.ID.Str) == 0 && m.request.ID.Num == 0 + if isIDUnset { + if m.request.ID.IsString { + m.request.ID.Str = strconv.FormatUint(c.seq, 10) + } else { + m.request.ID.Num = c.seq + } + } + c.seq++ + } + c.mu.Unlock() + + if len(c.onSend) > 0 { + var ( + req *Request + resp *Response + ) + switch { + case m.request != nil: + req = m.request + case m.response != nil: + resp = m.response + } + for _, onSend := range c.onSend { + onSend(req, resp) + } + } + + // Store requests so we can later associate them with incoming + // responses. + if m.request != nil && wait { + c.mu.Lock() + id = m.request.ID + c.pending[id] = cc + c.mu.Unlock() + } + + // From here on, if we fail to send this, then we need to remove + // this from the pending map so we don't block on it or pile up + // pending entries for unsent messages. + defer func() { + if err != nil { + if cc != nil { + c.mu.Lock() + delete(c.pending, id) + c.mu.Unlock() + } + } + }() + + if err := c.stream.WriteObject(m); err != nil { + return nil, err + } + return cc, nil +} + +// Waiter proxies an ongoing JSON-RPC call. +type Waiter struct { + *call +} + +// Wait for the result of an ongoing JSON-RPC call. If the response +// is successful, its result is stored in result (a pointer to a +// value that can be JSON-unmarshaled into); otherwise, a non-nil +// error is returned. +func (w Waiter) Wait(ctx context.Context, result interface{}) error { + select { + case err, ok := <-w.call.done: + if !ok { + err = ErrClosed + } + if err != nil { + return err + } + if result != nil { + if w.call.response.Result == nil { + w.call.response.Result = &jsonNull + } + if err := json.Unmarshal(*w.call.response.Result, result); err != nil { + return err + } + } + return nil + + case <-ctx.Done(): + return ctx.Err() + } +} + +// call represents a JSON-RPC call over its entire lifecycle. +type call struct { + request *Request + response *Response + seq uint64 // the seq of the request + done chan error +} + +// anyMessage represents either a JSON Request or Response. +type anyMessage struct { + request *Request + response *Response +} + +func (m anyMessage) MarshalJSON() ([]byte, error) { + var v interface{} + switch { + case m.request != nil && m.response == nil: + v = m.request + case m.request == nil && m.response != nil: + v = m.response + } + if v != nil { + return json.Marshal(v) + } + return nil, errors.New("jsonrpc2: message must have exactly one of the request or response fields set") +} + +func (m *anyMessage) UnmarshalJSON(data []byte) error { + // The presence of these fields distinguishes between the 2 + // message types. + type msg struct { + ID interface{} `json:"id"` + Method *string `json:"method"` + Result anyValueWithExplicitNull `json:"result"` + Error interface{} `json:"error"` + } + + var isRequest, isResponse bool + checkType := func(m *msg) error { + mIsRequest := m.Method != nil + mIsResponse := m.Result.null || m.Result.value != nil || m.Error != nil + if (!mIsRequest && !mIsResponse) || (mIsRequest && mIsResponse) { + return errors.New("jsonrpc2: unable to determine message type (request or response)") + } + if (mIsRequest && isResponse) || (mIsResponse && isRequest) { + return errors.New("jsonrpc2: batch message type mismatch (must be all requests or all responses)") + } + isRequest = mIsRequest + isResponse = mIsResponse + return nil + } + + if isArray := len(data) > 0 && data[0] == '['; isArray { + var msgs []msg + if err := json.Unmarshal(data, &msgs); err != nil { + return err + } + if len(msgs) == 0 { + return errors.New("jsonrpc2: invalid empty batch") + } + for i := range msgs { + if err := checkType(&msg{ + ID: msgs[i].ID, + Method: msgs[i].Method, + Result: msgs[i].Result, + Error: msgs[i].Error, + }); err != nil { + return err + } + } + } else { + var m msg + if err := json.Unmarshal(data, &m); err != nil { + return err + } + if err := checkType(&m); err != nil { + return err + } + } + + var v interface{} + switch { + case isRequest && !isResponse: + v = &m.request + case !isRequest && isResponse: + v = &m.response + } + if err := json.Unmarshal(data, v); err != nil { + return err + } + if !isRequest && isResponse && m.response.Error == nil && m.response.Result == nil { + m.response.Result = &jsonNull + } + return nil +} + +// anyValueWithExplicitNull is used to distinguish {} from +// {"result":null} by anyMessage's JSON unmarshaler. +type anyValueWithExplicitNull struct { + null bool // JSON "null" + value interface{} +} + +func (v anyValueWithExplicitNull) MarshalJSON() ([]byte, error) { + return json.Marshal(v.value) +} + +func (v *anyValueWithExplicitNull) UnmarshalJSON(data []byte) error { + data = bytes.TrimSpace(data) + if string(data) == "null" { + *v = anyValueWithExplicitNull{null: true} + return nil + } + *v = anyValueWithExplicitNull{} + return json.Unmarshal(data, &v.value) +} diff --git a/conn_test.go b/conn_test.go new file mode 100644 index 0000000..773c5cb --- /dev/null +++ b/conn_test.go @@ -0,0 +1,103 @@ +package jsonrpc2_test + +import ( + "context" + "io" + "net" + "testing" + "time" + + "github.com/sourcegraph/jsonrpc2" +) + +func TestConn_DisconnectNotify(t *testing.T) { + + t.Run("EOF", func(t *testing.T) { + connA, connB := net.Pipe() + c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewPlainObjectStream(connB), nil) + // By closing connA, connB receives io.EOF + if err := connA.Close(); err != nil { + t.Error(err) + } + assertDisconnect(t, c, connB) + }) + + t.Run("Close", func(t *testing.T) { + _, connB := net.Pipe() + c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewPlainObjectStream(connB), nil) + if err := c.Close(); err != nil { + t.Error(err) + } + assertDisconnect(t, c, connB) + }) + + t.Run("Close async", func(t *testing.T) { + done := make(chan struct{}) + _, connB := net.Pipe() + c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewPlainObjectStream(connB), nil) + go func() { + if err := c.Close(); err != nil && err != jsonrpc2.ErrClosed { + t.Error(err) + } + close(done) + }() + assertDisconnect(t, c, connB) + <-done + }) + + t.Run("protocol error", func(t *testing.T) { + connA, connB := net.Pipe() + c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewPlainObjectStream(connB), nil) + connA.Write([]byte("invalid json")) + assertDisconnect(t, c, connB) + }) +} + +func TestConn_Close(t *testing.T) { + t.Run("waiting for response", func(t *testing.T) { + connA, connB := net.Pipe() + nodeA := jsonrpc2.NewConn( + context.Background(), + jsonrpc2.NewPlainObjectStream(connA), noopHandler{}, + ) + defer nodeA.Close() + nodeB := jsonrpc2.NewConn( + context.Background(), + jsonrpc2.NewPlainObjectStream(connB), + noopHandler{}, + ) + defer nodeB.Close() + + ready := make(chan struct{}) + done := make(chan struct{}) + go func() { + close(ready) + err := nodeB.Call(context.Background(), "m", nil, nil) + if err != jsonrpc2.ErrClosed { + t.Errorf("got error %v, want %v", err, jsonrpc2.ErrClosed) + } + close(done) + }() + // Wait for the request to be sent before we close the connection. + <-ready + if err := nodeB.Close(); err != nil && err != jsonrpc2.ErrClosed { + t.Error(err) + } + assertDisconnect(t, nodeB, connB) + <-done + }) +} + +func assertDisconnect(t *testing.T, c *jsonrpc2.Conn, conn io.Writer) { + select { + case <-c.DisconnectNotify(): + case <-time.After(200 * time.Millisecond): + t.Fatal("no disconnect notification") + } + // Assert that conn is closed by trying to write to it. + _, got := conn.Write(nil) + want := io.ErrClosedPipe + if got != want { + t.Fatalf("got %q, want %q", got, want) + } +} diff --git a/internal_test.go b/internal_test.go new file mode 100644 index 0000000..990fb24 --- /dev/null +++ b/internal_test.go @@ -0,0 +1,35 @@ +package jsonrpc2 + +import ( + "encoding/json" + "testing" +) + +func TestAnyMessage(t *testing.T) { + tests := map[string]struct { + request, response, invalid bool + }{ + // Single messages + `{}`: {invalid: true}, + `{"foo":"bar"}`: {invalid: true}, + `{"method":"m"}`: {request: true}, + `{"result":123}`: {response: true}, + `{"result":null}`: {response: true}, + `{"error":{"code":456,"message":"m"}}`: {response: true}, + } + for s, want := range tests { + var m anyMessage + if err := json.Unmarshal([]byte(s), &m); err != nil { + if !want.invalid { + t.Errorf("%s: error: %v", s, err) + } + continue + } + if (m.request != nil) != want.request { + t.Errorf("%s: got request %v, want %v", s, m.request != nil, want.request) + } + if (m.response != nil) != want.response { + t.Errorf("%s: got response %v, want %v", s, m.response != nil, want.response) + } + } +} diff --git a/jsonrpc2.go b/jsonrpc2.go index 32bc98c..17e1a59 100644 --- a/jsonrpc2.go +++ b/jsonrpc2.go @@ -3,16 +3,11 @@ package jsonrpc2 import ( - "bytes" "context" "encoding/json" "errors" "fmt" - "io" - "log" - "os" "strconv" - "sync" ) // JSONRPC2 describes an interface for issuing requests that speak the @@ -30,247 +25,6 @@ type JSONRPC2 interface { Close() error } -// RequestField is a top-level field that can be added to the JSON-RPC request. -type RequestField struct { - Name string - Value interface{} -} - -// Request represents a JSON-RPC request or -// notification. See -// http://www.jsonrpc.org/specification#request_object and -// http://www.jsonrpc.org/specification#notification. -type Request struct { - Method string `json:"method"` - Params *json.RawMessage `json:"params,omitempty"` - ID ID `json:"id"` - Notif bool `json:"-"` - - // Meta optionally provides metadata to include in the request. - // - // NOTE: It is not part of spec. However, it is useful for propogating - // tracing context, etc. - Meta *json.RawMessage `json:"meta,omitempty"` - - // ExtraFields optionally adds fields to the root of the JSON-RPC request. - // - // NOTE: It is not part of the spec, but there are other protocols based on - // JSON-RPC 2 that require it. - ExtraFields []RequestField `json:"-"` - // OmitNilParams instructs the SetParams method to not JSON encode a nil - // value and set Params to nil instead. - OmitNilParams bool `json:"-"` -} - -// MarshalJSON implements json.Marshaler and adds the "jsonrpc":"2.0" -// property. -func (r Request) MarshalJSON() ([]byte, error) { - r2 := map[string]interface{}{ - "jsonrpc": "2.0", - "method": r.Method, - } - for _, field := range r.ExtraFields { - r2[field.Name] = field.Value - } - if !r.Notif { - r2["id"] = &r.ID - } - if r.Params != nil { - r2["params"] = r.Params - } - if r.Meta != nil { - r2["meta"] = r.Meta - } - return json.Marshal(r2) -} - -// UnmarshalJSON implements json.Unmarshaler. -func (r *Request) UnmarshalJSON(data []byte) error { - r2 := make(map[string]interface{}) - - // Detect if the "params" or "meta" fields are JSON "null" or just not - // present by seeing if the field gets overwritten to nil. - emptyParams := &json.RawMessage{} - r2["params"] = emptyParams - emptyMeta := &json.RawMessage{} - r2["meta"] = emptyMeta - - decoder := json.NewDecoder(bytes.NewReader(data)) - decoder.UseNumber() - if err := decoder.Decode(&r2); err != nil { - return err - } - var ok bool - r.Method, ok = r2["method"].(string) - if !ok { - return errors.New("missing method field") - } - switch { - case r2["params"] == nil: - r.Params = &jsonNull - case r2["params"] == emptyParams: - r.Params = nil - default: - b, err := json.Marshal(r2["params"]) - if err != nil { - return fmt.Errorf("failed to marshal params: %w", err) - } - r.Params = (*json.RawMessage)(&b) - } - switch { - case r2["meta"] == nil: - r.Meta = &jsonNull - case r2["meta"] == emptyMeta: - r.Meta = nil - default: - b, err := json.Marshal(r2["meta"]) - if err != nil { - return fmt.Errorf("failed to marshal Meta: %w", err) - } - r.Meta = (*json.RawMessage)(&b) - } - switch rawID := r2["id"].(type) { - case nil: - r.ID = ID{} - r.Notif = true - case string: - r.ID = ID{Str: rawID, IsString: true} - r.Notif = false - case json.Number: - id, err := rawID.Int64() - if err != nil { - return fmt.Errorf("failed to unmarshal ID: %w", err) - } - r.ID = ID{Num: uint64(id)} - r.Notif = false - default: - return fmt.Errorf("unexpected ID type: %T", rawID) - } - - // Clear the extra fields before populating them again. - r.ExtraFields = nil - for name, value := range r2 { - switch name { - case "id", "jsonrpc", "meta", "method", "params": - continue - } - r.ExtraFields = append(r.ExtraFields, RequestField{ - Name: name, - Value: value, - }) - } - return nil -} - -// SetParams sets r.Params to the JSON representation of v. If JSON marshaling -// fails, it returns an error. Beware that the JSON encoding of nil is null. If -// r.OmitNilParams is true and v is nil, then r.Params is set to nil and -// therefore omitted from the JSON-RPC request. -func (r *Request) SetParams(v interface{}) error { - if r.OmitNilParams && v == nil { - r.Params = nil - return nil - } - b, err := json.Marshal(v) - if err != nil { - return err - } - r.Params = (*json.RawMessage)(&b) - return nil -} - -// SetMeta sets r.Meta to the JSON representation of v. If JSON -// marshaling fails, it returns an error. -func (r *Request) SetMeta(v interface{}) error { - b, err := json.Marshal(v) - if err != nil { - return err - } - r.Meta = (*json.RawMessage)(&b) - return nil -} - -// SetExtraField adds an entry to r.ExtraFields, so that it is added to the -// JSON representation of the request, as a way to add arbitrary extensions to -// JSON RPC 2.0. If JSON marshaling fails, it returns an error. -func (r *Request) SetExtraField(name string, v interface{}) error { - switch name { - case "id", "jsonrpc", "meta", "method", "params": - return fmt.Errorf("invalid extra field %q", name) - } - r.ExtraFields = append(r.ExtraFields, RequestField{ - Name: name, - Value: v, - }) - return nil -} - -// Response represents a JSON-RPC response. See -// http://www.jsonrpc.org/specification#response_object. -type Response struct { - ID ID `json:"id"` - Result *json.RawMessage `json:"result,omitempty"` - Error *Error `json:"error,omitempty"` - - // Meta optionally provides metadata to include in the response. - // - // NOTE: It is not part of spec. However, it is useful for propogating - // tracing context, etc. - Meta *json.RawMessage `json:"meta,omitempty"` - - // SPEC NOTE: The spec says "If there was an error in detecting - // the id in the Request object (e.g. Parse error/Invalid - // Request), it MUST be Null." If we made the ID field nullable, - // then we'd have to make it a pointer type. For simplicity, we're - // ignoring the case where there was an error in detecting the ID - // in the Request object. -} - -// MarshalJSON implements json.Marshaler and adds the "jsonrpc":"2.0" -// property. -func (r Response) MarshalJSON() ([]byte, error) { - if (r.Result == nil || len(*r.Result) == 0) && r.Error == nil { - return nil, errors.New("can't marshal *jsonrpc2.Response (must have result or error)") - } - type tmpType Response // avoid infinite MarshalJSON recursion - b, err := json.Marshal(tmpType(r)) - if err != nil { - return nil, err - } - b = append(b[:len(b)-1], []byte(`,"jsonrpc":"2.0"}`)...) - return b, nil -} - -// UnmarshalJSON implements json.Unmarshaler. -func (r *Response) UnmarshalJSON(data []byte) error { - type tmpType Response - - // Detect if the "result" field is JSON "null" or just not present - // by seeing if the field gets overwritten to nil. - *r = Response{Result: &json.RawMessage{}} - - if err := json.Unmarshal(data, (*tmpType)(r)); err != nil { - return err - } - if r.Result == nil { // JSON "null" - r.Result = &jsonNull - } else if len(*r.Result) == 0 { - r.Result = nil - } - return nil -} - -// SetResult sets r.Result to the JSON representation of v. If JSON -// marshaling fails, it returns an error. -func (r *Response) SetResult(v interface{}) error { - b, err := json.Marshal(v) - if err != nil { - return err - } - r.Result = (*json.RawMessage)(&b) - return nil -} - // Error represents a JSON-RPC response error. type Error struct { Code int64 `json:"code"` @@ -358,455 +112,8 @@ func (id *ID) UnmarshalJSON(data []byte) error { return nil } -// Conn is a JSON-RPC client/server connection. The JSON-RPC protocol -// is symmetric, so a Conn runs on both ends of a client-server -// connection. -type Conn struct { - stream ObjectStream - - h Handler - - mu sync.Mutex - closed bool - seq uint64 - pending map[ID]*call - - sending sync.Mutex - - disconnect chan struct{} - - logger Logger - - // Set by ConnOpt funcs. - onRecv []func(*Request, *Response) - onSend []func(*Request, *Response) -} - -var _ JSONRPC2 = (*Conn)(nil) - // ErrClosed indicates that the JSON-RPC connection is closed (or in // the process of closing). var ErrClosed = errors.New("jsonrpc2: connection is closed") -// NewConn creates a new JSON-RPC client/server connection using the -// given ReadWriteCloser (typically a TCP connection or stdio). The -// JSON-RPC protocol is symmetric, so a Conn runs on both ends of a -// client-server connection. -// -// NewClient consumes conn, so you should call Close on the returned -// client not on the given conn. -func NewConn(ctx context.Context, stream ObjectStream, h Handler, opts ...ConnOpt) *Conn { - c := &Conn{ - stream: stream, - h: h, - pending: map[ID]*call{}, - disconnect: make(chan struct{}), - logger: log.New(os.Stderr, "", log.LstdFlags), - } - for _, opt := range opts { - if opt == nil { - continue - } - opt(c) - } - go c.readMessages(ctx) - return c -} - -// Close closes the JSON-RPC connection. The connection may not be -// used after it has been closed. -func (c *Conn) Close() error { - return c.close(nil) -} - -func (c *Conn) close(cause error) error { - c.sending.Lock() - c.mu.Lock() - defer c.sending.Unlock() - defer c.mu.Unlock() - - if c.closed { - return ErrClosed - } - - for _, call := range c.pending { - close(call.done) - } - - if cause != nil && cause != io.EOF && cause != io.ErrUnexpectedEOF { - c.logger.Printf("jsonrpc2: protocol error: %v\n", cause) - } - - close(c.disconnect) - c.closed = true - return c.stream.Close() -} - -func (c *Conn) send(_ context.Context, m *anyMessage, wait bool) (cc *call, err error) { - c.sending.Lock() - defer c.sending.Unlock() - - // m.request.ID could be changed, so we store a copy to correctly - // clean up pending - var id ID - - c.mu.Lock() - if c.closed { - c.mu.Unlock() - return nil, ErrClosed - } - - // Assign a default id if not set - if m.request != nil && wait { - cc = &call{request: m.request, seq: c.seq, done: make(chan error, 1)} - - isIDUnset := len(m.request.ID.Str) == 0 && m.request.ID.Num == 0 - if isIDUnset { - if m.request.ID.IsString { - m.request.ID.Str = strconv.FormatUint(c.seq, 10) - } else { - m.request.ID.Num = c.seq - } - } - c.seq++ - } - c.mu.Unlock() - - if len(c.onSend) > 0 { - var ( - req *Request - resp *Response - ) - switch { - case m.request != nil: - req = m.request - case m.response != nil: - resp = m.response - } - for _, onSend := range c.onSend { - onSend(req, resp) - } - } - - // Store requests so we can later associate them with incoming - // responses. - if m.request != nil && wait { - c.mu.Lock() - id = m.request.ID - c.pending[id] = cc - c.mu.Unlock() - } - - // From here on, if we fail to send this, then we need to remove - // this from the pending map so we don't block on it or pile up - // pending entries for unsent messages. - defer func() { - if err != nil { - if cc != nil { - c.mu.Lock() - delete(c.pending, id) - c.mu.Unlock() - } - } - }() - - if err := c.stream.WriteObject(m); err != nil { - return nil, err - } - return cc, nil -} - -// Call initiates a JSON-RPC call using the specified method and -// params, and waits for the response. If the response is successful, -// its result is stored in result (a pointer to a value that can be -// JSON-unmarshaled into); otherwise, a non-nil error is returned. -func (c *Conn) Call(ctx context.Context, method string, params, result interface{}, opts ...CallOption) error { - call, err := c.DispatchCall(ctx, method, params, opts...) - if err != nil { - return err - } - return call.Wait(ctx, result) -} - -// DispatchCall dispatches a JSON-RPC call using the specified method -// and params, and returns a call proxy or an error. Call Wait() -// on the returned proxy to receive the response. Only use this -// function if you need to do work after dispatching the request, -// otherwise use Call. -func (c *Conn) DispatchCall(ctx context.Context, method string, params interface{}, opts ...CallOption) (Waiter, error) { - req := &Request{Method: method} - for _, opt := range opts { - if opt == nil { - continue - } - if err := opt.apply(req); err != nil { - return Waiter{}, err - } - } - if err := req.SetParams(params); err != nil { - return Waiter{}, err - } - call, err := c.send(ctx, &anyMessage{request: req}, true) - if err != nil { - return Waiter{}, err - } - return Waiter{call: call}, nil -} - -// Waiter proxies an ongoing JSON-RPC call. -type Waiter struct { - *call -} - -// Wait for the result of an ongoing JSON-RPC call. If the response -// is successful, its result is stored in result (a pointer to a -// value that can be JSON-unmarshaled into); otherwise, a non-nil -// error is returned. -func (w Waiter) Wait(ctx context.Context, result interface{}) error { - select { - case err, ok := <-w.call.done: - if !ok { - err = ErrClosed - } - if err != nil { - return err - } - if result != nil { - if w.call.response.Result == nil { - w.call.response.Result = &jsonNull - } - if err := json.Unmarshal(*w.call.response.Result, result); err != nil { - return err - } - } - return nil - - case <-ctx.Done(): - return ctx.Err() - } -} - var jsonNull = json.RawMessage("null") - -// Notify is like Call, but it returns when the notification request -// is sent (without waiting for a response, because JSON-RPC -// notifications do not have responses). -func (c *Conn) Notify(ctx context.Context, method string, params interface{}, opts ...CallOption) error { - req := &Request{Method: method, Notif: true} - for _, opt := range opts { - if opt == nil { - continue - } - if err := opt.apply(req); err != nil { - return err - } - } - if err := req.SetParams(params); err != nil { - return err - } - _, err := c.send(ctx, &anyMessage{request: req}, false) - return err -} - -// Reply sends a successful response with a result. -func (c *Conn) Reply(ctx context.Context, id ID, result interface{}) error { - resp := &Response{ID: id} - if err := resp.SetResult(result); err != nil { - return err - } - _, err := c.send(ctx, &anyMessage{response: resp}, false) - return err -} - -// ReplyWithError sends a response with an error. -func (c *Conn) ReplyWithError(ctx context.Context, id ID, respErr *Error) error { - _, err := c.send(ctx, &anyMessage{response: &Response{ID: id, Error: respErr}}, false) - return err -} - -// SendResponse sends resp to the peer. It is lower level than (*Conn).Reply. -func (c *Conn) SendResponse(ctx context.Context, resp *Response) error { - _, err := c.send(ctx, &anyMessage{response: resp}, false) - return err -} - -// DisconnectNotify returns a channel that is closed when the -// underlying connection is disconnected. -func (c *Conn) DisconnectNotify() <-chan struct{} { - return c.disconnect -} - -func (c *Conn) readMessages(ctx context.Context) { - var err error - for err == nil { - var m anyMessage - err = c.stream.ReadObject(&m) - if err != nil { - break - } - - switch { - case m.request != nil: - for _, onRecv := range c.onRecv { - onRecv(m.request, nil) - } - c.h.Handle(ctx, c, m.request) - - case m.response != nil: - resp := m.response - if resp != nil { - id := resp.ID - c.mu.Lock() - call := c.pending[id] - delete(c.pending, id) - c.mu.Unlock() - - if call != nil { - call.response = resp - } - - if len(c.onRecv) > 0 { - var req *Request - if call != nil { - req = call.request - } - for _, onRecv := range c.onRecv { - onRecv(req, resp) - } - } - - switch { - case call == nil: - c.logger.Printf("jsonrpc2: ignoring response #%s with no corresponding request\n", id) - - case resp.Error != nil: - call.done <- resp.Error - close(call.done) - - default: - call.done <- nil - close(call.done) - } - } - } - } - c.close(err) -} - -// call represents a JSON-RPC call over its entire lifecycle. -type call struct { - request *Request - response *Response - seq uint64 // the seq of the request - done chan error -} - -// anyMessage represents either a JSON Request or Response. -type anyMessage struct { - request *Request - response *Response -} - -func (m anyMessage) MarshalJSON() ([]byte, error) { - var v interface{} - switch { - case m.request != nil && m.response == nil: - v = m.request - case m.request == nil && m.response != nil: - v = m.response - } - if v != nil { - return json.Marshal(v) - } - return nil, errors.New("jsonrpc2: message must have exactly one of the request or response fields set") -} - -func (m *anyMessage) UnmarshalJSON(data []byte) error { - // The presence of these fields distinguishes between the 2 - // message types. - type msg struct { - ID interface{} `json:"id"` - Method *string `json:"method"` - Result anyValueWithExplicitNull `json:"result"` - Error interface{} `json:"error"` - } - - var isRequest, isResponse bool - checkType := func(m *msg) error { - mIsRequest := m.Method != nil - mIsResponse := m.Result.null || m.Result.value != nil || m.Error != nil - if (!mIsRequest && !mIsResponse) || (mIsRequest && mIsResponse) { - return errors.New("jsonrpc2: unable to determine message type (request or response)") - } - if (mIsRequest && isResponse) || (mIsResponse && isRequest) { - return errors.New("jsonrpc2: batch message type mismatch (must be all requests or all responses)") - } - isRequest = mIsRequest - isResponse = mIsResponse - return nil - } - - if isArray := len(data) > 0 && data[0] == '['; isArray { - var msgs []msg - if err := json.Unmarshal(data, &msgs); err != nil { - return err - } - if len(msgs) == 0 { - return errors.New("jsonrpc2: invalid empty batch") - } - for i := range msgs { - if err := checkType(&msg{ - ID: msgs[i].ID, - Method: msgs[i].Method, - Result: msgs[i].Result, - Error: msgs[i].Error, - }); err != nil { - return err - } - } - } else { - var m msg - if err := json.Unmarshal(data, &m); err != nil { - return err - } - if err := checkType(&m); err != nil { - return err - } - } - - var v interface{} - switch { - case isRequest && !isResponse: - v = &m.request - case !isRequest && isResponse: - v = &m.response - } - if err := json.Unmarshal(data, v); err != nil { - return err - } - if !isRequest && isResponse && m.response.Error == nil && m.response.Result == nil { - m.response.Result = &jsonNull - } - return nil -} - -// anyValueWithExplicitNull is used to distinguish {} from -// {"result":null} by anyMessage's JSON unmarshaler. -type anyValueWithExplicitNull struct { - null bool // JSON "null" - value interface{} -} - -func (v anyValueWithExplicitNull) MarshalJSON() ([]byte, error) { - return json.Marshal(v.value) -} - -func (v *anyValueWithExplicitNull) UnmarshalJSON(data []byte) error { - data = bytes.TrimSpace(data) - if string(data) == "null" { - *v = anyValueWithExplicitNull{null: true} - return nil - } - *v = anyValueWithExplicitNull{} - return json.Unmarshal(data, &v.value) -} diff --git a/jsonrpc2_test.go b/jsonrpc2_test.go index b68f3c1..1a73744 100644 --- a/jsonrpc2_test.go +++ b/jsonrpc2_test.go @@ -1,7 +1,6 @@ package jsonrpc2_test import ( - "bytes" "context" "encoding/json" "fmt" @@ -19,66 +18,6 @@ import ( websocketjsonrpc2 "github.com/sourcegraph/jsonrpc2/websocket" ) -func TestRequest_MarshalJSON_jsonrpc(t *testing.T) { - b, err := json.Marshal(&jsonrpc2.Request{}) - if err != nil { - t.Fatal(err) - } - if want := `{"id":0,"jsonrpc":"2.0","method":""}`; string(b) != want { - t.Errorf("got %q, want %q", b, want) - } -} - -func TestResponse_MarshalJSON_jsonrpc(t *testing.T) { - null := json.RawMessage("null") - b, err := json.Marshal(&jsonrpc2.Response{Result: &null}) - if err != nil { - t.Fatal(err) - } - if want := `{"id":0,"result":null,"jsonrpc":"2.0"}`; string(b) != want { - t.Errorf("got %q, want %q", b, want) - } -} - -func TestResponseMarshalJSON_Notif(t *testing.T) { - tests := map[*jsonrpc2.Request]bool{ - {ID: jsonrpc2.ID{Num: 0}}: true, - {ID: jsonrpc2.ID{Num: 1}}: true, - {ID: jsonrpc2.ID{Str: "", IsString: true}}: true, - {ID: jsonrpc2.ID{Str: "a", IsString: true}}: true, - {Notif: true}: false, - } - for r, wantIDKey := range tests { - b, err := json.Marshal(r) - if err != nil { - t.Fatal(err) - } - hasIDKey := bytes.Contains(b, []byte(`"id"`)) - if hasIDKey != wantIDKey { - t.Errorf("got %s, want contain id key: %v", b, wantIDKey) - } - } -} - -func TestResponseUnmarshalJSON_Notif(t *testing.T) { - tests := map[string]bool{ - `{"method":"f","id":0}`: false, - `{"method":"f","id":1}`: false, - `{"method":"f","id":"a"}`: false, - `{"method":"f","id":""}`: false, - `{"method":"f"}`: true, - } - for s, want := range tests { - var r jsonrpc2.Request - if err := json.Unmarshal([]byte(s), &r); err != nil { - t.Fatal(err) - } - if r.Notif != want { - t.Errorf("%s: got %v, want %v", s, r.Notif, want) - } - } -} - // testHandlerA is the "server" handler. type testHandlerA struct{ t *testing.T } @@ -314,84 +253,6 @@ type noopHandler struct{} func (noopHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) {} -func TestConn_DisconnectNotify(t *testing.T) { - - t.Run("EOF", func(t *testing.T) { - connA, connB := net.Pipe() - c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewPlainObjectStream(connB), nil) - // By closing connA, connB receives io.EOF - if err := connA.Close(); err != nil { - t.Error(err) - } - assertDisconnect(t, c, connB) - }) - - t.Run("Close", func(t *testing.T) { - _, connB := net.Pipe() - c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewPlainObjectStream(connB), nil) - if err := c.Close(); err != nil { - t.Error(err) - } - assertDisconnect(t, c, connB) - }) - - t.Run("Close async", func(t *testing.T) { - done := make(chan struct{}) - _, connB := net.Pipe() - c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewPlainObjectStream(connB), nil) - go func() { - if err := c.Close(); err != nil && err != jsonrpc2.ErrClosed { - t.Error(err) - } - close(done) - }() - assertDisconnect(t, c, connB) - <-done - }) - - t.Run("protocol error", func(t *testing.T) { - connA, connB := net.Pipe() - c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewPlainObjectStream(connB), nil) - connA.Write([]byte("invalid json")) - assertDisconnect(t, c, connB) - }) -} - -func TestConn_Close(t *testing.T) { - t.Run("waiting for response", func(t *testing.T) { - connA, connB := net.Pipe() - nodeA := jsonrpc2.NewConn( - context.Background(), - jsonrpc2.NewPlainObjectStream(connA), noopHandler{}, - ) - defer nodeA.Close() - nodeB := jsonrpc2.NewConn( - context.Background(), - jsonrpc2.NewPlainObjectStream(connB), - noopHandler{}, - ) - defer nodeB.Close() - - ready := make(chan struct{}) - done := make(chan struct{}) - go func() { - close(ready) - err := nodeB.Call(context.Background(), "m", nil, nil) - if err != jsonrpc2.ErrClosed { - t.Errorf("got error %v, want %v", err, jsonrpc2.ErrClosed) - } - close(done) - }() - // Wait for the request to be sent before we close the connection. - <-ready - if err := nodeB.Close(); err != nil && err != jsonrpc2.ErrClosed { - t.Error(err) - } - assertDisconnect(t, nodeB, connB) - <-done - }) -} - func serve(ctx context.Context, lis net.Listener, h jsonrpc2.Handler, streamMaker streamMaker, opts ...jsonrpc2.ConnOpt) error { for { conn, err := lis.Accept() @@ -401,17 +262,3 @@ func serve(ctx context.Context, lis net.Listener, h jsonrpc2.Handler, streamMake jsonrpc2.NewConn(ctx, streamMaker(conn), h, opts...) } } - -func assertDisconnect(t *testing.T, c *jsonrpc2.Conn, conn io.Writer) { - select { - case <-c.DisconnectNotify(): - case <-time.After(200 * time.Millisecond): - t.Fatal("no disconnect notification") - } - // Assert that conn is closed by trying to write to it. - _, got := conn.Write(nil) - want := io.ErrClosedPipe - if got != want { - t.Fatalf("got %q, want %q", got, want) - } -} diff --git a/object_test.go b/object_test.go deleted file mode 100644 index cfa5b00..0000000 --- a/object_test.go +++ /dev/null @@ -1,132 +0,0 @@ -package jsonrpc2 - -import ( - "bytes" - "encoding/json" - "reflect" - "testing" -) - -func TestAnyMessage(t *testing.T) { - tests := map[string]struct { - request, response, invalid bool - }{ - // Single messages - `{}`: {invalid: true}, - `{"foo":"bar"}`: {invalid: true}, - `{"method":"m"}`: {request: true}, - `{"result":123}`: {response: true}, - `{"result":null}`: {response: true}, - `{"error":{"code":456,"message":"m"}}`: {response: true}, - } - for s, want := range tests { - var m anyMessage - if err := json.Unmarshal([]byte(s), &m); err != nil { - if !want.invalid { - t.Errorf("%s: error: %v", s, err) - } - continue - } - if (m.request != nil) != want.request { - t.Errorf("%s: got request %v, want %v", s, m.request != nil, want.request) - } - if (m.response != nil) != want.response { - t.Errorf("%s: got response %v, want %v", s, m.response != nil, want.response) - } - } -} - -func TestRequest_MarshalUnmarshalJSON(t *testing.T) { - null := json.RawMessage("null") - obj := json.RawMessage(`{"foo":"bar"}`) - tests := []struct { - data []byte - want Request - }{ - { - data: []byte(`{"id":123,"jsonrpc":"2.0","method":"m","params":{"foo":"bar"}}`), - want: Request{ID: ID{Num: 123}, Method: "m", Params: &obj}, - }, - { - data: []byte(`{"id":123,"jsonrpc":"2.0","method":"m","params":null}`), - want: Request{ID: ID{Num: 123}, Method: "m", Params: &null}, - }, - { - data: []byte(`{"id":123,"jsonrpc":"2.0","method":"m"}`), - want: Request{ID: ID{Num: 123}, Method: "m", Params: nil}, - }, - { - data: []byte(`{"id":123,"jsonrpc":"2.0","method":"m","sessionId":"session"}`), - want: Request{ID: ID{Num: 123}, Method: "m", Params: nil, ExtraFields: []RequestField{{Name: "sessionId", Value: "session"}}}, - }, - } - for _, test := range tests { - var got Request - if err := json.Unmarshal(test.data, &got); err != nil { - t.Error(err) - continue - } - if !reflect.DeepEqual(got, test.want) { - t.Errorf("%q: got %+v, want %+v", test.data, got, test.want) - continue - } - data, err := json.Marshal(got) - if err != nil { - t.Error(err) - continue - } - if !bytes.Equal(data, test.data) { - t.Errorf("got JSON %q, want %q", data, test.data) - } - } -} - -func TestResponse_MarshalUnmarshalJSON(t *testing.T) { - null := json.RawMessage("null") - obj := json.RawMessage(`{"foo":"bar"}`) - tests := []struct { - data []byte - want Response - error bool - }{ - { - data: []byte(`{"id":123,"result":{"foo":"bar"},"jsonrpc":"2.0"}`), - want: Response{ID: ID{Num: 123}, Result: &obj}, - }, - { - data: []byte(`{"id":123,"result":null,"jsonrpc":"2.0"}`), - want: Response{ID: ID{Num: 123}, Result: &null}, - }, - { - data: []byte(`{"id":123,"jsonrpc":"2.0"}`), - want: Response{ID: ID{Num: 123}, Result: nil}, - error: true, // either result or error field must be set - }, - } - for _, test := range tests { - var got Response - if err := json.Unmarshal(test.data, &got); err != nil { - t.Error(err) - continue - } - if !reflect.DeepEqual(got, test.want) { - t.Errorf("%q: got %+v, want %+v", test.data, got, test.want) - continue - } - data, err := json.Marshal(got) - if err != nil { - if test.error { - continue - } - t.Error(err) - continue - } - if test.error { - t.Errorf("%q: expected error", test.data) - continue - } - if !bytes.Equal(data, test.data) { - t.Errorf("got JSON %q, want %q", data, test.data) - } - } -} diff --git a/request.go b/request.go new file mode 100644 index 0000000..666573b --- /dev/null +++ b/request.go @@ -0,0 +1,183 @@ +package jsonrpc2 + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" +) + +// Request represents a JSON-RPC request or +// notification. See +// http://www.jsonrpc.org/specification#request_object and +// http://www.jsonrpc.org/specification#notification. +type Request struct { + Method string `json:"method"` + Params *json.RawMessage `json:"params,omitempty"` + ID ID `json:"id"` + Notif bool `json:"-"` + + // Meta optionally provides metadata to include in the request. + // + // NOTE: It is not part of spec. However, it is useful for propagating + // tracing context, etc. + Meta *json.RawMessage `json:"meta,omitempty"` + + // ExtraFields optionally adds fields to the root of the JSON-RPC request. + // + // NOTE: It is not part of the spec, but there are other protocols based on + // JSON-RPC 2 that require it. + ExtraFields []RequestField `json:"-"` + // OmitNilParams instructs the SetParams method to not JSON encode a nil + // value and set Params to nil instead. + OmitNilParams bool `json:"-"` +} + +// MarshalJSON implements json.Marshaler and adds the "jsonrpc":"2.0" +// property. +func (r Request) MarshalJSON() ([]byte, error) { + r2 := map[string]interface{}{ + "jsonrpc": "2.0", + "method": r.Method, + } + for _, field := range r.ExtraFields { + r2[field.Name] = field.Value + } + if !r.Notif { + r2["id"] = &r.ID + } + if r.Params != nil { + r2["params"] = r.Params + } + if r.Meta != nil { + r2["meta"] = r.Meta + } + return json.Marshal(r2) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (r *Request) UnmarshalJSON(data []byte) error { + r2 := make(map[string]interface{}) + + // Detect if the "params" or "meta" fields are JSON "null" or just not + // present by seeing if the field gets overwritten to nil. + emptyParams := &json.RawMessage{} + r2["params"] = emptyParams + emptyMeta := &json.RawMessage{} + r2["meta"] = emptyMeta + + decoder := json.NewDecoder(bytes.NewReader(data)) + decoder.UseNumber() + if err := decoder.Decode(&r2); err != nil { + return err + } + var ok bool + r.Method, ok = r2["method"].(string) + if !ok { + return errors.New("missing method field") + } + switch { + case r2["params"] == nil: + r.Params = &jsonNull + case r2["params"] == emptyParams: + r.Params = nil + default: + b, err := json.Marshal(r2["params"]) + if err != nil { + return fmt.Errorf("failed to marshal params: %w", err) + } + r.Params = (*json.RawMessage)(&b) + } + switch { + case r2["meta"] == nil: + r.Meta = &jsonNull + case r2["meta"] == emptyMeta: + r.Meta = nil + default: + b, err := json.Marshal(r2["meta"]) + if err != nil { + return fmt.Errorf("failed to marshal Meta: %w", err) + } + r.Meta = (*json.RawMessage)(&b) + } + switch rawID := r2["id"].(type) { + case nil: + r.ID = ID{} + r.Notif = true + case string: + r.ID = ID{Str: rawID, IsString: true} + r.Notif = false + case json.Number: + id, err := rawID.Int64() + if err != nil { + return fmt.Errorf("failed to unmarshal ID: %w", err) + } + r.ID = ID{Num: uint64(id)} + r.Notif = false + default: + return fmt.Errorf("unexpected ID type: %T", rawID) + } + + // Clear the extra fields before populating them again. + r.ExtraFields = nil + for name, value := range r2 { + switch name { + case "id", "jsonrpc", "meta", "method", "params": + continue + } + r.ExtraFields = append(r.ExtraFields, RequestField{ + Name: name, + Value: value, + }) + } + return nil +} + +// SetParams sets r.Params to the JSON representation of v. If JSON marshaling +// fails, it returns an error. Beware that the JSON encoding of nil is null. If +// r.OmitNilParams is true and v is nil, then r.Params is set to nil and +// therefore omitted from the JSON-RPC request. +func (r *Request) SetParams(v interface{}) error { + if r.OmitNilParams && v == nil { + r.Params = nil + return nil + } + b, err := json.Marshal(v) + if err != nil { + return err + } + r.Params = (*json.RawMessage)(&b) + return nil +} + +// SetMeta sets r.Meta to the JSON representation of v. If JSON +// marshaling fails, it returns an error. +func (r *Request) SetMeta(v interface{}) error { + b, err := json.Marshal(v) + if err != nil { + return err + } + r.Meta = (*json.RawMessage)(&b) + return nil +} + +// SetExtraField adds an entry to r.ExtraFields, so that it is added to the +// JSON representation of the request, as a way to add arbitrary extensions to +// JSON RPC 2.0. If JSON marshaling fails, it returns an error. +func (r *Request) SetExtraField(name string, v interface{}) error { + switch name { + case "id", "jsonrpc", "meta", "method", "params": + return fmt.Errorf("invalid extra field %q", name) + } + r.ExtraFields = append(r.ExtraFields, RequestField{ + Name: name, + Value: v, + }) + return nil +} + +// RequestField is a top-level field that can be added to the JSON-RPC request. +type RequestField struct { + Name string + Value interface{} +} diff --git a/request_test.go b/request_test.go new file mode 100644 index 0000000..9d6c243 --- /dev/null +++ b/request_test.go @@ -0,0 +1,65 @@ +package jsonrpc2_test + +import ( + "bytes" + "encoding/json" + "reflect" + "testing" + + "github.com/sourcegraph/jsonrpc2" +) + +func TestRequest_MarshalJSON_jsonrpc(t *testing.T) { + b, err := json.Marshal(&jsonrpc2.Request{}) + if err != nil { + t.Fatal(err) + } + if want := `{"id":0,"jsonrpc":"2.0","method":""}`; string(b) != want { + t.Errorf("got %q, want %q", b, want) + } +} + +func TestRequest_MarshalUnmarshalJSON(t *testing.T) { + null := json.RawMessage("null") + obj := json.RawMessage(`{"foo":"bar"}`) + tests := []struct { + data []byte + want jsonrpc2.Request + }{ + { + data: []byte(`{"id":123,"jsonrpc":"2.0","method":"m","params":{"foo":"bar"}}`), + want: jsonrpc2.Request{ID: jsonrpc2.ID{Num: 123}, Method: "m", Params: &obj}, + }, + { + data: []byte(`{"id":123,"jsonrpc":"2.0","method":"m","params":null}`), + want: jsonrpc2.Request{ID: jsonrpc2.ID{Num: 123}, Method: "m", Params: &null}, + }, + { + data: []byte(`{"id":123,"jsonrpc":"2.0","method":"m"}`), + want: jsonrpc2.Request{ID: jsonrpc2.ID{Num: 123}, Method: "m", Params: nil}, + }, + { + data: []byte(`{"id":123,"jsonrpc":"2.0","method":"m","sessionId":"session"}`), + want: jsonrpc2.Request{ID: jsonrpc2.ID{Num: 123}, Method: "m", Params: nil, ExtraFields: []jsonrpc2.RequestField{{Name: "sessionId", Value: "session"}}}, + }, + } + for _, test := range tests { + var got jsonrpc2.Request + if err := json.Unmarshal(test.data, &got); err != nil { + t.Error(err) + continue + } + if !reflect.DeepEqual(got, test.want) { + t.Errorf("%q: got %+v, want %+v", test.data, got, test.want) + continue + } + data, err := json.Marshal(got) + if err != nil { + t.Error(err) + continue + } + if !bytes.Equal(data, test.data) { + t.Errorf("got JSON %q, want %q", data, test.data) + } + } +} diff --git a/response.go b/response.go new file mode 100644 index 0000000..c9a0bfe --- /dev/null +++ b/response.go @@ -0,0 +1,72 @@ +package jsonrpc2 + +import ( + "encoding/json" + "errors" +) + +// Response represents a JSON-RPC response. See +// http://www.jsonrpc.org/specification#response_object. +type Response struct { + ID ID `json:"id"` + Result *json.RawMessage `json:"result,omitempty"` + Error *Error `json:"error,omitempty"` + + // Meta optionally provides metadata to include in the response. + // + // NOTE: It is not part of spec. However, it is useful for propagating + // tracing context, etc. + Meta *json.RawMessage `json:"meta,omitempty"` + + // SPEC NOTE: The spec says "If there was an error in detecting + // the id in the Request object (e.g. Parse error/Invalid + // Request), it MUST be Null." If we made the ID field nullable, + // then we'd have to make it a pointer type. For simplicity, we're + // ignoring the case where there was an error in detecting the ID + // in the Request object. +} + +// MarshalJSON implements json.Marshaler and adds the "jsonrpc":"2.0" +// property. +func (r Response) MarshalJSON() ([]byte, error) { + if (r.Result == nil || len(*r.Result) == 0) && r.Error == nil { + return nil, errors.New("can't marshal *jsonrpc2.Response (must have result or error)") + } + type tmpType Response // avoid infinite MarshalJSON recursion + b, err := json.Marshal(tmpType(r)) + if err != nil { + return nil, err + } + b = append(b[:len(b)-1], []byte(`,"jsonrpc":"2.0"}`)...) + return b, nil +} + +// UnmarshalJSON implements json.Unmarshaler. +func (r *Response) UnmarshalJSON(data []byte) error { + type tmpType Response + + // Detect if the "result" field is JSON "null" or just not present + // by seeing if the field gets overwritten to nil. + *r = Response{Result: &json.RawMessage{}} + + if err := json.Unmarshal(data, (*tmpType)(r)); err != nil { + return err + } + if r.Result == nil { // JSON "null" + r.Result = &jsonNull + } else if len(*r.Result) == 0 { + r.Result = nil + } + return nil +} + +// SetResult sets r.Result to the JSON representation of v. If JSON +// marshaling fails, it returns an error. +func (r *Response) SetResult(v interface{}) error { + b, err := json.Marshal(v) + if err != nil { + return err + } + r.Result = (*json.RawMessage)(&b) + return nil +} diff --git a/response_test.go b/response_test.go new file mode 100644 index 0000000..38eeb52 --- /dev/null +++ b/response_test.go @@ -0,0 +1,110 @@ +package jsonrpc2_test + +import ( + "bytes" + "encoding/json" + "reflect" + "testing" + + "github.com/sourcegraph/jsonrpc2" +) + +func TestResponse_MarshalJSON_jsonrpc(t *testing.T) { + null := json.RawMessage("null") + b, err := json.Marshal(&jsonrpc2.Response{Result: &null}) + if err != nil { + t.Fatal(err) + } + if want := `{"id":0,"result":null,"jsonrpc":"2.0"}`; string(b) != want { + t.Errorf("got %q, want %q", b, want) + } +} + +func TestResponseMarshalJSON_Notif(t *testing.T) { + tests := map[*jsonrpc2.Request]bool{ + {ID: jsonrpc2.ID{Num: 0}}: true, + {ID: jsonrpc2.ID{Num: 1}}: true, + {ID: jsonrpc2.ID{Str: "", IsString: true}}: true, + {ID: jsonrpc2.ID{Str: "a", IsString: true}}: true, + {Notif: true}: false, + } + for r, wantIDKey := range tests { + b, err := json.Marshal(r) + if err != nil { + t.Fatal(err) + } + hasIDKey := bytes.Contains(b, []byte(`"id"`)) + if hasIDKey != wantIDKey { + t.Errorf("got %s, want contain id key: %v", b, wantIDKey) + } + } +} + +func TestResponseUnmarshalJSON_Notif(t *testing.T) { + tests := map[string]bool{ + `{"method":"f","id":0}`: false, + `{"method":"f","id":1}`: false, + `{"method":"f","id":"a"}`: false, + `{"method":"f","id":""}`: false, + `{"method":"f"}`: true, + } + for s, want := range tests { + var r jsonrpc2.Request + if err := json.Unmarshal([]byte(s), &r); err != nil { + t.Fatal(err) + } + if r.Notif != want { + t.Errorf("%s: got %v, want %v", s, r.Notif, want) + } + } +} + +func TestResponse_MarshalUnmarshalJSON(t *testing.T) { + null := json.RawMessage("null") + obj := json.RawMessage(`{"foo":"bar"}`) + tests := []struct { + data []byte + want jsonrpc2.Response + error bool + }{ + { + data: []byte(`{"id":123,"result":{"foo":"bar"},"jsonrpc":"2.0"}`), + want: jsonrpc2.Response{ID: jsonrpc2.ID{Num: 123}, Result: &obj}, + }, + { + data: []byte(`{"id":123,"result":null,"jsonrpc":"2.0"}`), + want: jsonrpc2.Response{ID: jsonrpc2.ID{Num: 123}, Result: &null}, + }, + { + data: []byte(`{"id":123,"jsonrpc":"2.0"}`), + want: jsonrpc2.Response{ID: jsonrpc2.ID{Num: 123}, Result: nil}, + error: true, // either result or error field must be set + }, + } + for _, test := range tests { + var got jsonrpc2.Response + if err := json.Unmarshal(test.data, &got); err != nil { + t.Error(err) + continue + } + if !reflect.DeepEqual(got, test.want) { + t.Errorf("%q: got %+v, want %+v", test.data, got, test.want) + continue + } + data, err := json.Marshal(got) + if err != nil { + if test.error { + continue + } + t.Error(err) + continue + } + if test.error { + t.Errorf("%q: expected error", test.data) + continue + } + if !bytes.Equal(data, test.data) { + t.Errorf("got JSON %q, want %q", data, test.data) + } + } +}