Skip to content

Commit

Permalink
Add some go tests to the chat app
Browse files Browse the repository at this point in the history
  • Loading branch information
vektah committed Feb 17, 2018
1 parent ec2916d commit d514b82
Show file tree
Hide file tree
Showing 8 changed files with 181 additions and 14 deletions.
8 changes: 7 additions & 1 deletion client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (p *Client) MustPost(query string, response interface{}, options ...Option)
}
}

func (p *Client) Post(query string, response interface{}, options ...Option) error {
func (p *Client) mkRequest(query string, options ...Option) Request {
r := Request{
Query: query,
}
Expand All @@ -71,6 +71,11 @@ func (p *Client) Post(query string, response interface{}, options ...Option) err
option(&r)
}

return r
}

func (p *Client) Post(query string, response interface{}, options ...Option) error {
r := p.mkRequest(query, options...)
requestBody, err := json.Marshal(r)
if err != nil {
return fmt.Errorf("encode: %s", err.Error())
Expand Down Expand Up @@ -120,6 +125,7 @@ func unpack(data interface{}, into interface{}) error {
Result: into,
TagName: "json",
ErrorUnused: true,
ZeroFields: true,
})
if err != nil {
return fmt.Errorf("mapstructure: %s", err.Error())
Expand Down
104 changes: 104 additions & 0 deletions client/websocket.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package client

import (
"encoding/json"
"fmt"
"strings"

"github.com/gorilla/websocket"
"github.com/vektah/gqlgen/neelance/errors"
)

const (
connectionInitMsg = "connection_init" // Client -> Server
connectionTerminateMsg = "connection_terminate" // Client -> Server
startMsg = "start" // Client -> Server
stopMsg = "stop" // Client -> Server
connectionAckMsg = "connection_ack" // Server -> Client
connectionErrorMsg = "connection_error" // Server -> Client
connectionKeepAliveMsg = "ka" // Server -> Client
dataMsg = "data" // Server -> Client
errorMsg = "error" // Server -> Client
completeMsg = "complete" // Server -> Client
)

type operationMessage struct {
Payload json.RawMessage `json:"payload,omitempty"`
ID string `json:"id,omitempty"`
Type string `json:"type"`
}

type Subscription struct {
Close func() error
Next func(response interface{}) error
}

func errorSubscription(err error) *Subscription {
return &Subscription{
Close: func() error { return nil },
Next: func(response interface{}) error {
return err
},
}
}

func (p *Client) Websocket(query string, options ...Option) *Subscription {
r := p.mkRequest(query, options...)
requestBody, err := json.Marshal(r)
if err != nil {
return errorSubscription(fmt.Errorf("encode: %s", err.Error()))
}

url := strings.Replace(p.url, "http://", "ws://", -1)
url = strings.Replace(url, "https://", "wss://", -1)

c, _, err := websocket.DefaultDialer.Dial(url, nil)
if err != nil {
return errorSubscription(fmt.Errorf("dial: %s", err.Error()))
}

if err = c.WriteJSON(operationMessage{Type: connectionInitMsg}); err != nil {
return errorSubscription(fmt.Errorf("init: %s", err.Error()))
}

var ack operationMessage
if err := c.ReadJSON(&ack); err != nil {
return errorSubscription(fmt.Errorf("ack: %s", err.Error()))
}
if ack.Type != connectionAckMsg {
return errorSubscription(fmt.Errorf("expected ack message, got %#v", ack))
}

if err = c.WriteJSON(operationMessage{Type: startMsg, ID: "1", Payload: requestBody}); err != nil {
return errorSubscription(fmt.Errorf("start: %s", err.Error()))
}

return &Subscription{
Close: c.Close,
Next: func(response interface{}) error {
var op operationMessage
c.ReadJSON(&op)
if op.Type != dataMsg {
return fmt.Errorf("expected data message, got %#v", op)
}

respDataRaw := map[string]interface{}{}
err = json.Unmarshal(op.Payload, &respDataRaw)
if err != nil {
return fmt.Errorf("decode: %s", err.Error())
}

if respDataRaw["errors"] != nil {
var errs []*errors.QueryError
if err := unpack(respDataRaw["errors"], errs); err != nil {
return err
}
if len(errs) > 0 {
return fmt.Errorf("errors: %s", errs)
}
}

return unpack(respDataRaw["data"], response)
},
}
}
52 changes: 52 additions & 0 deletions example/chat/chat_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package chat

import (
"net/http/httptest"
"sync"
"testing"

"github.com/stretchr/testify/require"
"github.com/vektah/gqlgen/client"
"github.com/vektah/gqlgen/handler"
)

func TestChat(t *testing.T) {
srv := httptest.NewServer(handler.GraphQL(MakeExecutableSchema(New())))
c := client.New(srv.URL)
var wg sync.WaitGroup
wg.Add(1)

t.Run("subscribe to chat events", func(t *testing.T) {
t.Parallel()

sub := c.Websocket(`subscription { messageAdded(roomName:"#gophers") { text createdBy } }`)
defer sub.Close()

wg.Done()
var resp struct {
MessageAdded struct {
Text string
CreatedBy string
}
}
require.NoError(t, sub.Next(&resp))
require.Equal(t, "Hello!", resp.MessageAdded.Text)
require.Equal(t, "vektah", resp.MessageAdded.CreatedBy)

require.NoError(t, sub.Next(&resp))
require.Equal(t, "Whats up?", resp.MessageAdded.Text)
require.Equal(t, "vektah", resp.MessageAdded.CreatedBy)
})

t.Run("post two messages", func(t *testing.T) {
t.Parallel()

wg.Wait()
var resp interface{}
c.MustPost(`mutation {
a:post(text:"Hello!", roomName:"#gophers", username:"vektah") { id }
b:post(text:"Whats up?", roomName:"#gophers", username:"vektah") { id }
}`, &resp)
})

}
3 changes: 1 addition & 2 deletions example/chat/generated.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ package chat

import (
context "context"
fmt "fmt"
strconv "strconv"
sync "sync"
time "time"
Expand Down Expand Up @@ -344,7 +343,7 @@ func (ec *executionContext) _subscription(sel []query.Selection, it *interface{}
fields := graphql.CollectFields(ec.doc, sel, subscriptionImplementors, ec.variables)

if len(fields) != 1 {
fmt.Errorf("must subscribe to exactly one stream")
ec.Errorf("must subscribe to exactly one stream")
return nil
}

Expand Down
4 changes: 1 addition & 3 deletions example/chat/models.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
package chat

import (
"sync"
"time"
)

type Chatroom struct {
Name string
Messages []Message
Observers map[string]chan Message
mu sync.Mutex
}

func (c Chatroom) ID() string { return "C" + c.Name }
func (c *Chatroom) ID() string { return "C" + c.Name }

type Message struct {
ID string
Expand Down
20 changes: 14 additions & 6 deletions example/chat/resolvers.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ package chat
import (
context "context"
"math/rand"
"sync"
"time"
)

type resolvers struct {
Rooms map[string]*Chatroom
mu sync.Mutex
}

func New() *resolvers {
Expand All @@ -19,11 +21,13 @@ func New() *resolvers {
}

func (r *resolvers) Mutation_post(ctx context.Context, text string, userName string, roomName string) (Message, error) {
r.mu.Lock()
room := r.Rooms[roomName]
if room == nil {
room = &Chatroom{Name: roomName, Observers: map[string]chan Message{}}
r.Rooms[roomName] = room
}
r.mu.Unlock()

message := Message{
ID: randString(8),
Expand All @@ -33,44 +37,48 @@ func (r *resolvers) Mutation_post(ctx context.Context, text string, userName str
}

room.Messages = append(room.Messages, message)
room.mu.Lock()
r.mu.Lock()
for _, observer := range room.Observers {
observer <- message
}
room.mu.Unlock()
r.mu.Unlock()
return message, nil
}

func (r *resolvers) Query_room(ctx context.Context, name string) (*Chatroom, error) {
r.mu.Lock()
room := r.Rooms[name]
if room == nil {
room = &Chatroom{Name: name, Observers: map[string]chan Message{}}
r.Rooms[name] = room
}
r.mu.Unlock()

return room, nil
}

func (r *resolvers) Subscription_messageAdded(ctx context.Context, roomName string) (<-chan Message, error) {
r.mu.Lock()
room := r.Rooms[roomName]
if room == nil {
room = &Chatroom{Name: roomName, Observers: map[string]chan Message{}}
r.Rooms[roomName] = room
}
r.mu.Unlock()

id := randString(8)
events := make(chan Message, 1)

go func() {
<-ctx.Done()
room.mu.Lock()
r.mu.Lock()
delete(room.Observers, id)
room.mu.Unlock()
r.mu.Unlock()
}()

room.mu.Lock()
r.mu.Lock()
room.Observers[id] = events
room.mu.Unlock()
r.mu.Unlock()

return events, nil
}
Expand Down
2 changes: 1 addition & 1 deletion handler/graphql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func TestHandlerPOST(t *testing.T) {
t.Run("decode failure", func(t *testing.T) {
resp := doRequest(h, "POST", "/graphql", "notjson")
assert.Equal(t, http.StatusBadRequest, resp.Code)
assert.Equal(t, `{"data":null,"errors":[{"message":"json body could not be decoded"}]}`, resp.Body.String())
assert.Equal(t, `{"data":null,"errors":[{"message":"json body could not be decoded: invalid character 'o' in literal null (expecting 'u')"}]}`, resp.Body.String())
})

t.Run("parse failure", func(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion templates/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ func (ec *executionContext) _{{$object.GQLType|lcFirst}}(sel []query.Selection,
fields := graphql.CollectFields(ec.doc, sel, {{$object.GQLType|lcFirst}}Implementors, ec.variables)
if len(fields) != 1 {
fmt.Errorf("must subscribe to exactly one stream")
ec.Errorf("must subscribe to exactly one stream")
return nil
}
Expand Down

0 comments on commit d514b82

Please sign in to comment.