Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup interaction with GraphQL service #2

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 1 addition & 64 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,12 @@ To use this library with [github.com/graph-gophers/graphql-go](https://github.co
package main

import (
"context"
"encoding/json"
"fmt"
"net/http"

"github.com/gorilla/websocket"
graphql "github.com/graph-gophers/graphql-go"
"github.com/graph-gophers/graphql-go/relay"
"github.com/graph-gophers/graphql-transport-ws/graphqlws"
"github.com/graph-gophers/graphql-transport-ws/graphqlws/event"
)

const schema = `
Expand All @@ -46,73 +42,14 @@ func main() {
}

// graphQL handler
graphQLHandler := newHandler(s, &relay.Handler{Schema: s})
graphQLHandler := graphqlws.NewHandlerFunc(s, &relay.Handler{Schema: s})
http.HandleFunc("/graphql", graphQLHandler)

// start HTTP server
if err := http.ListenAndServe(fmt.Sprintf(":%d", 8080), nil); err != nil {
panic(err)
}
}

func newHandler(s *graphql.Schema, httpHandler http.Handler) http.HandlerFunc {
wsHandler := graphqlws.NewHandler(&defaultCallback{schema: s})
return func(w http.ResponseWriter, r *http.Request) {
for _, subprotocol := range websocket.Subprotocols(r) {
if subprotocol == "graphql-ws" {
wsHandler.ServeHTTP(w, r)
return
}
}
httpHandler.ServeHTTP(w, r)
}
}

type defaultCallback struct {
schema *graphql.Schema
}

func (h *defaultCallback) OnOperation(ctx context.Context, args *event.OnOperationArgs) (json.RawMessage, func(), error) {
b, err := json.Marshal(args.StartMessage.Variables)
if err != nil {
return nil, nil, err
}

variables := map[string]interface{}{}
err = json.Unmarshal(b, &variables)
if err != nil {
return nil, nil, err
}

ctx, cancel := context.WithCancel(ctx)
c, err := h.schema.Subscribe(ctx, args.StartMessage.Query, args.StartMessage.OperationName, variables)
if err != nil {
cancel()
return nil, nil, err
}

go func() {
defer cancel()
for {
select {
case <-ctx.Done():
return
case response, more := <-c:
if !more {
return
}
responseJSON, err := json.Marshal(response)
if err != nil {
args.Send(json.RawMessage(`{"errors":["internal error: can't marshal response into json"]}`))
continue
}
args.Send(responseJSON)
}
}
}()

return nil, cancel, nil
}
```

For a more in depth example see [this repo](https://github.com/matiasanaya/go-graphql-subscription-example).
Expand Down
22 changes: 0 additions & 22 deletions graphqlws/event/event.go

This file was deleted.

46 changes: 22 additions & 24 deletions graphqlws/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (

"github.com/gorilla/websocket"

"github.com/graph-gophers/graphql-transport-ws/graphqlws/event"
"github.com/graph-gophers/graphql-transport-ws/graphqlws/internal/connection"
)

Expand All @@ -16,28 +15,27 @@ var upgrader = websocket.Upgrader{
Subprotocols: []string{protocolGraphQLWS},
}

// Handler is a GraphQL websocket subscription handler
type Handler struct {
eventsHandler event.Handler
}

// NewHandler returns a new Handler
func NewHandler(eh event.Handler) *Handler {
return &Handler{eventsHandler: eh}
}

func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ws, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
// NewHandlerFunc returns an http.HandlerFunc that supports GraphQL over websockets
func NewHandlerFunc(svc connection.GraphQLService, httpHandler http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
for _, subprotocol := range websocket.Subprotocols(r) {
if subprotocol == "graphql-ws" {
ws, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}

if ws.Subprotocol() != protocolGraphQLWS {
ws.Close()
return
}

go connection.Connect(ws, svc)
return
}
}

// Fallback to HTTP
httpHandler.ServeHTTP(w, r)
}

if ws.Subprotocol() != protocolGraphQLWS {
ws.Close()
return
}

go connection.Connect(ws, h.eventsHandler)

return
}
77 changes: 45 additions & 32 deletions graphqlws/internal/connection/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import (
"errors"
"fmt"
"time"

"github.com/graph-gophers/graphql-transport-ws/graphqlws/event"
)

type operationMessageType string
Expand Down Expand Up @@ -43,11 +41,22 @@ type operationMessage struct {
Type operationMessageType `json:"type"`
}

type startMessagePayload struct {
OperationName string `json:"operationName"`
Query string `json:"query"`
Variables map[string]interface{} `json:"variables"`
}

type initMessagePayload struct{}

// GraphQLService interface
type GraphQLService interface {
Subscribe(ctx context.Context, document string, operationName string, variableValues map[string]interface{}) (payloads <-chan interface{}, err error)
}

type connection struct {
cancel func()
handler event.Handler
service GraphQLService
writeTimeout time.Duration
ws wsConnection
}
Expand All @@ -68,9 +77,9 @@ func WriteTimeout(d time.Duration) func(conn *connection) {

// Connect implements the apollographql subscriptions-transport-ws [email protected]
// https://github.com/apollographql/subscriptions-transport-ws/blob/v0.9.4/PROTOCOL.md
func Connect(ws wsConnection, handler event.Handler, options ...func(conn *connection)) func() {
func Connect(ws wsConnection, service GraphQLService, options ...func(conn *connection)) func() {
conn := &connection{
handler: handler,
service: service,
ws: ws,
}

Expand Down Expand Up @@ -166,42 +175,46 @@ func (conn *connection) readLoop(ctx context.Context, send sendFunc) {
continue
}

args := &event.OnOperationArgs{ID: msg.ID}
if err := json.Unmarshal(msg.Payload, &args.Payload); err != nil {
var osp startMessagePayload
if err := json.Unmarshal(msg.Payload, &osp); err != nil {
ep := errPayload(fmt.Errorf("invalid payload for type: %s", msg.Type))
send(msg.ID, typeConnectionError, ep)
continue
}

// TODO: ensure args.Send doesn't work after typeStop or onDone()
args.Send = func(payload json.RawMessage) {
send(msg.ID, typeData, payload)
}
opCtx, cancel := context.WithCancel(ctx)
// TODO: timeout this call, to guard against poor clients
payload, onDone, err := conn.handler.OnOperation(ctx, args)
// query or mutation
if err != nil || payload != nil {
func() {
defer func() {
if onDone != nil {
onDone()
}
send(msg.ID, typeComplete, nil)
}()

if err != nil {
send(msg.ID, typeError, errPayload(err))
return
}
send(msg.ID, typeData, payload)
}()
c, err := conn.service.Subscribe(opCtx, osp.Query, osp.OperationName, osp.Variables)
if err != nil {
cancel()
send(msg.ID, typeError, errPayload(err))
send(msg.ID, typeComplete, nil)
continue
}

// subscription
if onDone != nil {
opDone[msg.ID] = onDone
}
opDone[msg.ID] = cancel

go func() {
defer cancel()
for {
select {
case <-opCtx.Done():
return
case payload, more := <-c:
if !more {
send(msg.ID, typeComplete, nil)
return
}

jsonPayload, err := json.Marshal(payload)
if err != nil {
send(msg.ID, typeError, errPayload(err))
continue
}
send(msg.ID, typeData, jsonPayload)
}
}
}()

case typeStop:
onDone, ok := opDone[msg.ID]
Expand Down
48 changes: 25 additions & 23 deletions graphqlws/internal/connection/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"testing"
"time"

"github.com/graph-gophers/graphql-transport-ws/graphqlws/event"
"github.com/graph-gophers/graphql-transport-ws/graphqlws/internal/connection"
)

Expand All @@ -30,13 +29,12 @@ type message struct {

func TestConnect(t *testing.T) {
testTable := []struct {
name string
callbacks *callbacksHandler
messages []message
name string
svc *gqlService
messages []message
}{
{
name: "connection_init_ok",
callbacks: &callbacksHandler{},
name: "connection_init_ok",
messages: []message{
{
intention: clientSends,
Expand All @@ -52,8 +50,7 @@ func TestConnect(t *testing.T) {
},
},
{
name: "connection_init_error",
callbacks: &callbacksHandler{},
name: "connection_init_error",
messages: []message{
{
intention: clientSends,
Expand All @@ -74,10 +71,8 @@ func TestConnect(t *testing.T) {
},
},
{
name: "start_query_ok",
callbacks: &callbacksHandler{
payload: json.RawMessage(`{"data":{},"errors":null}`),
},
name: "start_ok",
svc: newGQLService(`{"data":{},"errors":null}`),
messages: []message{
{
intention: clientSends,
Expand Down Expand Up @@ -109,9 +104,7 @@ func TestConnect(t *testing.T) {
},
{
name: "start_query_data_error",
callbacks: &callbacksHandler{
payload: json.RawMessage(`{"data":null,"errors":[{"message":"a error"}]}`),
},
svc: newGQLService(`{"data":null,"errors":[{"message":"a error"}]}`),
messages: []message{
{
intention: clientSends,
Expand Down Expand Up @@ -144,7 +137,7 @@ func TestConnect(t *testing.T) {
},
{
name: "start_query_error",
callbacks: &callbacksHandler{
svc: &gqlService{
err: errors.New("some error"),
},
messages: []message{
Expand Down Expand Up @@ -179,20 +172,29 @@ func TestConnect(t *testing.T) {
for _, tt := range testTable {
t.Run(tt.name, func(t *testing.T) {
ws := newConnection()
go connection.Connect(ws, tt.callbacks)
go connection.Connect(ws, tt.svc)
ws.test(t, tt.messages)
})
}
}

type callbacksHandler struct {
payload json.RawMessage
cancel func()
err error
type gqlService struct {
payloads <-chan interface{}
err error
}

func newGQLService(pp ...string) *gqlService {
c := make(chan interface{}, len(pp))
for _, p := range pp {
c <- json.RawMessage(p)
}
close(c)

return &gqlService{payloads: c}
}

func (h *callbacksHandler) OnOperation(ctx context.Context, args *event.OnOperationArgs) (json.RawMessage, func(), error) {
return h.payload, h.cancel, h.err
func (h *gqlService) Subscribe(ctx context.Context, document string, operationName string, variableValues map[string]interface{}) (payloads <-chan interface{}, err error) {
return h.payloads, h.err
}

func newConnection() *wsConnection {
Expand Down