diff --git a/client/client.go b/client/client.go index 6354010adf1..9d2271c08be 100644 --- a/client/client.go +++ b/client/client.go @@ -62,7 +62,7 @@ func (p *Client) MustPost(query string, response interface{}, options ...Option) } } -func (p *Client) Post(query string, response interface{}, options ...Option) error { +func (p *Client) mkRequest(query string, options ...Option) Request { r := Request{ Query: query, } @@ -71,6 +71,11 @@ func (p *Client) Post(query string, response interface{}, options ...Option) err option(&r) } + return r +} + +func (p *Client) Post(query string, response interface{}, options ...Option) error { + r := p.mkRequest(query, options...) requestBody, err := json.Marshal(r) if err != nil { return fmt.Errorf("encode: %s", err.Error()) @@ -120,6 +125,7 @@ func unpack(data interface{}, into interface{}) error { Result: into, TagName: "json", ErrorUnused: true, + ZeroFields: true, }) if err != nil { return fmt.Errorf("mapstructure: %s", err.Error()) diff --git a/client/websocket.go b/client/websocket.go new file mode 100644 index 00000000000..8bd7382a0a8 --- /dev/null +++ b/client/websocket.go @@ -0,0 +1,104 @@ +package client + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/gorilla/websocket" + "github.com/vektah/gqlgen/neelance/errors" +) + +const ( + connectionInitMsg = "connection_init" // Client -> Server + connectionTerminateMsg = "connection_terminate" // Client -> Server + startMsg = "start" // Client -> Server + stopMsg = "stop" // Client -> Server + connectionAckMsg = "connection_ack" // Server -> Client + connectionErrorMsg = "connection_error" // Server -> Client + connectionKeepAliveMsg = "ka" // Server -> Client + dataMsg = "data" // Server -> Client + errorMsg = "error" // Server -> Client + completeMsg = "complete" // Server -> Client +) + +type operationMessage struct { + Payload json.RawMessage `json:"payload,omitempty"` + ID string `json:"id,omitempty"` + Type string `json:"type"` +} + +type Subscription struct { + Close func() error + Next func(response interface{}) error +} + +func errorSubscription(err error) *Subscription { + return &Subscription{ + Close: func() error { return nil }, + Next: func(response interface{}) error { + return err + }, + } +} + +func (p *Client) Websocket(query string, options ...Option) *Subscription { + r := p.mkRequest(query, options...) + requestBody, err := json.Marshal(r) + if err != nil { + return errorSubscription(fmt.Errorf("encode: %s", err.Error())) + } + + url := strings.Replace(p.url, "http://", "ws://", -1) + url = strings.Replace(url, "https://", "wss://", -1) + + c, _, err := websocket.DefaultDialer.Dial(url, nil) + if err != nil { + return errorSubscription(fmt.Errorf("dial: %s", err.Error())) + } + + if err = c.WriteJSON(operationMessage{Type: connectionInitMsg}); err != nil { + return errorSubscription(fmt.Errorf("init: %s", err.Error())) + } + + var ack operationMessage + if err := c.ReadJSON(&ack); err != nil { + return errorSubscription(fmt.Errorf("ack: %s", err.Error())) + } + if ack.Type != connectionAckMsg { + return errorSubscription(fmt.Errorf("expected ack message, got %#v", ack)) + } + + if err = c.WriteJSON(operationMessage{Type: startMsg, ID: "1", Payload: requestBody}); err != nil { + return errorSubscription(fmt.Errorf("start: %s", err.Error())) + } + + return &Subscription{ + Close: c.Close, + Next: func(response interface{}) error { + var op operationMessage + c.ReadJSON(&op) + if op.Type != dataMsg { + return fmt.Errorf("expected data message, got %#v", op) + } + + respDataRaw := map[string]interface{}{} + err = json.Unmarshal(op.Payload, &respDataRaw) + if err != nil { + return fmt.Errorf("decode: %s", err.Error()) + } + + if respDataRaw["errors"] != nil { + var errs []*errors.QueryError + if err := unpack(respDataRaw["errors"], errs); err != nil { + return err + } + if len(errs) > 0 { + return fmt.Errorf("errors: %s", errs) + } + } + + return unpack(respDataRaw["data"], response) + }, + } +} diff --git a/example/chat/chat_test.go b/example/chat/chat_test.go new file mode 100644 index 00000000000..44734c2e969 --- /dev/null +++ b/example/chat/chat_test.go @@ -0,0 +1,52 @@ +package chat + +import ( + "net/http/httptest" + "sync" + "testing" + + "github.com/stretchr/testify/require" + "github.com/vektah/gqlgen/client" + "github.com/vektah/gqlgen/handler" +) + +func TestChat(t *testing.T) { + srv := httptest.NewServer(handler.GraphQL(MakeExecutableSchema(New()))) + c := client.New(srv.URL) + var wg sync.WaitGroup + wg.Add(1) + + t.Run("subscribe to chat events", func(t *testing.T) { + t.Parallel() + + sub := c.Websocket(`subscription { messageAdded(roomName:"#gophers") { text createdBy } }`) + defer sub.Close() + + wg.Done() + var resp struct { + MessageAdded struct { + Text string + CreatedBy string + } + } + require.NoError(t, sub.Next(&resp)) + require.Equal(t, "Hello!", resp.MessageAdded.Text) + require.Equal(t, "vektah", resp.MessageAdded.CreatedBy) + + require.NoError(t, sub.Next(&resp)) + require.Equal(t, "Whats up?", resp.MessageAdded.Text) + require.Equal(t, "vektah", resp.MessageAdded.CreatedBy) + }) + + t.Run("post two messages", func(t *testing.T) { + t.Parallel() + + wg.Wait() + var resp interface{} + c.MustPost(`mutation { + a:post(text:"Hello!", roomName:"#gophers", username:"vektah") { id } + b:post(text:"Whats up?", roomName:"#gophers", username:"vektah") { id } + }`, &resp) + }) + +} diff --git a/example/chat/generated.go b/example/chat/generated.go index e5cc8747109..f89b74d95df 100644 --- a/example/chat/generated.go +++ b/example/chat/generated.go @@ -4,7 +4,6 @@ package chat import ( context "context" - fmt "fmt" strconv "strconv" sync "sync" time "time" @@ -344,7 +343,7 @@ func (ec *executionContext) _subscription(sel []query.Selection, it *interface{} fields := graphql.CollectFields(ec.doc, sel, subscriptionImplementors, ec.variables) if len(fields) != 1 { - fmt.Errorf("must subscribe to exactly one stream") + ec.Errorf("must subscribe to exactly one stream") return nil } diff --git a/example/chat/models.go b/example/chat/models.go index 9397a0cfcd2..50a9cc092f1 100644 --- a/example/chat/models.go +++ b/example/chat/models.go @@ -1,7 +1,6 @@ package chat import ( - "sync" "time" ) @@ -9,10 +8,9 @@ type Chatroom struct { Name string Messages []Message Observers map[string]chan Message - mu sync.Mutex } -func (c Chatroom) ID() string { return "C" + c.Name } +func (c *Chatroom) ID() string { return "C" + c.Name } type Message struct { ID string diff --git a/example/chat/resolvers.go b/example/chat/resolvers.go index 9fee0076bb0..9b3dfd4463a 100644 --- a/example/chat/resolvers.go +++ b/example/chat/resolvers.go @@ -5,11 +5,13 @@ package chat import ( context "context" "math/rand" + "sync" "time" ) type resolvers struct { Rooms map[string]*Chatroom + mu sync.Mutex } func New() *resolvers { @@ -19,11 +21,13 @@ func New() *resolvers { } func (r *resolvers) Mutation_post(ctx context.Context, text string, userName string, roomName string) (Message, error) { + r.mu.Lock() room := r.Rooms[roomName] if room == nil { room = &Chatroom{Name: roomName, Observers: map[string]chan Message{}} r.Rooms[roomName] = room } + r.mu.Unlock() message := Message{ ID: randString(8), @@ -33,44 +37,48 @@ func (r *resolvers) Mutation_post(ctx context.Context, text string, userName str } room.Messages = append(room.Messages, message) - room.mu.Lock() + r.mu.Lock() for _, observer := range room.Observers { observer <- message } - room.mu.Unlock() + r.mu.Unlock() return message, nil } func (r *resolvers) Query_room(ctx context.Context, name string) (*Chatroom, error) { + r.mu.Lock() room := r.Rooms[name] if room == nil { room = &Chatroom{Name: name, Observers: map[string]chan Message{}} r.Rooms[name] = room } + r.mu.Unlock() return room, nil } func (r *resolvers) Subscription_messageAdded(ctx context.Context, roomName string) (<-chan Message, error) { + r.mu.Lock() room := r.Rooms[roomName] if room == nil { room = &Chatroom{Name: roomName, Observers: map[string]chan Message{}} r.Rooms[roomName] = room } + r.mu.Unlock() id := randString(8) events := make(chan Message, 1) go func() { <-ctx.Done() - room.mu.Lock() + r.mu.Lock() delete(room.Observers, id) - room.mu.Unlock() + r.mu.Unlock() }() - room.mu.Lock() + r.mu.Lock() room.Observers[id] = events - room.mu.Unlock() + r.mu.Unlock() return events, nil } diff --git a/handler/graphql_test.go b/handler/graphql_test.go index ed1e59f5b6e..5aa1f238cb9 100644 --- a/handler/graphql_test.go +++ b/handler/graphql_test.go @@ -21,7 +21,7 @@ func TestHandlerPOST(t *testing.T) { t.Run("decode failure", func(t *testing.T) { resp := doRequest(h, "POST", "/graphql", "notjson") assert.Equal(t, http.StatusBadRequest, resp.Code) - assert.Equal(t, `{"data":null,"errors":[{"message":"json body could not be decoded"}]}`, resp.Body.String()) + assert.Equal(t, `{"data":null,"errors":[{"message":"json body could not be decoded: invalid character 'o' in literal null (expecting 'u')"}]}`, resp.Body.String()) }) t.Run("parse failure", func(t *testing.T) { diff --git a/templates/object.go b/templates/object.go index 290419a0607..2b3cf34c749 100644 --- a/templates/object.go +++ b/templates/object.go @@ -12,7 +12,7 @@ func (ec *executionContext) _{{$object.GQLType|lcFirst}}(sel []query.Selection, fields := graphql.CollectFields(ec.doc, sel, {{$object.GQLType|lcFirst}}Implementors, ec.variables) if len(fields) != 1 { - fmt.Errorf("must subscribe to exactly one stream") + ec.Errorf("must subscribe to exactly one stream") return nil }