diff --git a/README.md b/README.md index b0638c5..ba83af7 100644 --- a/README.md +++ b/README.md @@ -12,16 +12,12 @@ To use this library with [github.com/graph-gophers/graphql-go](https://github.co package main import ( - "context" - "encoding/json" "fmt" "net/http" - "github.com/gorilla/websocket" graphql "github.com/graph-gophers/graphql-go" "github.com/graph-gophers/graphql-go/relay" "github.com/graph-gophers/graphql-transport-ws/graphqlws" - "github.com/graph-gophers/graphql-transport-ws/graphqlws/event" ) const schema = ` @@ -46,7 +42,7 @@ func main() { } // graphQL handler - graphQLHandler := newHandler(s, &relay.Handler{Schema: s}) + graphQLHandler := graphqlws.NewHandlerFunc(s, &relay.Handler{Schema: s}) http.HandleFunc("/graphql", graphQLHandler) // start HTTP server @@ -54,65 +50,6 @@ func main() { panic(err) } } - -func newHandler(s *graphql.Schema, httpHandler http.Handler) http.HandlerFunc { - wsHandler := graphqlws.NewHandler(&defaultCallback{schema: s}) - return func(w http.ResponseWriter, r *http.Request) { - for _, subprotocol := range websocket.Subprotocols(r) { - if subprotocol == "graphql-ws" { - wsHandler.ServeHTTP(w, r) - return - } - } - httpHandler.ServeHTTP(w, r) - } -} - -type defaultCallback struct { - schema *graphql.Schema -} - -func (h *defaultCallback) OnOperation(ctx context.Context, args *event.OnOperationArgs) (json.RawMessage, func(), error) { - b, err := json.Marshal(args.StartMessage.Variables) - if err != nil { - return nil, nil, err - } - - variables := map[string]interface{}{} - err = json.Unmarshal(b, &variables) - if err != nil { - return nil, nil, err - } - - ctx, cancel := context.WithCancel(ctx) - c, err := h.schema.Subscribe(ctx, args.StartMessage.Query, args.StartMessage.OperationName, variables) - if err != nil { - cancel() - return nil, nil, err - } - - go func() { - defer cancel() - for { - select { - case <-ctx.Done(): - return - case response, more := <-c: - if !more { - return - } - responseJSON, err := json.Marshal(response) - if err != nil { - args.Send(json.RawMessage(`{"errors":["internal error: can't marshal response into json"]}`)) - continue - } - args.Send(responseJSON) - } - } - }() - - return nil, cancel, nil -} ``` For a more in depth example see [this repo](https://github.com/matiasanaya/go-graphql-subscription-example). diff --git a/graphqlws/event/event.go b/graphqlws/event/event.go deleted file mode 100644 index 8bed39a..0000000 --- a/graphqlws/event/event.go +++ /dev/null @@ -1,22 +0,0 @@ -package event - -import ( - "context" - "encoding/json" -) - -// Handler handles graphqlws events -type Handler interface { - OnOperation(ctx context.Context, args *OnOperationArgs) (payload json.RawMessage, onDone func(), err error) -} - -// OnOperationArgs are the inputs available to the OnOperation event handler -type OnOperationArgs struct { - ID string - Send func(payload json.RawMessage) - Payload struct { - OperationName string `json:"operationName"` - Query string `json:"query"` - Variables map[string]json.RawMessage `json:"variables"` - } -} diff --git a/graphqlws/http.go b/graphqlws/http.go index b908144..8351893 100644 --- a/graphqlws/http.go +++ b/graphqlws/http.go @@ -5,7 +5,6 @@ import ( "github.com/gorilla/websocket" - "github.com/graph-gophers/graphql-transport-ws/graphqlws/event" "github.com/graph-gophers/graphql-transport-ws/graphqlws/internal/connection" ) @@ -16,28 +15,27 @@ var upgrader = websocket.Upgrader{ Subprotocols: []string{protocolGraphQLWS}, } -// Handler is a GraphQL websocket subscription handler -type Handler struct { - eventsHandler event.Handler -} - -// NewHandler returns a new Handler -func NewHandler(eh event.Handler) *Handler { - return &Handler{eventsHandler: eh} -} - -func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ws, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return +// NewHandlerFunc returns an http.HandlerFunc that supports GraphQL over websockets +func NewHandlerFunc(svc connection.GraphQLService, httpHandler http.Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + for _, subprotocol := range websocket.Subprotocols(r) { + if subprotocol == "graphql-ws" { + ws, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + + if ws.Subprotocol() != protocolGraphQLWS { + ws.Close() + return + } + + go connection.Connect(ws, svc) + return + } + } + + // Fallback to HTTP + httpHandler.ServeHTTP(w, r) } - - if ws.Subprotocol() != protocolGraphQLWS { - ws.Close() - return - } - - go connection.Connect(ws, h.eventsHandler) - - return } diff --git a/graphqlws/internal/connection/connection.go b/graphqlws/internal/connection/connection.go index 6e89005..94c9167 100644 --- a/graphqlws/internal/connection/connection.go +++ b/graphqlws/internal/connection/connection.go @@ -6,8 +6,6 @@ import ( "errors" "fmt" "time" - - "github.com/graph-gophers/graphql-transport-ws/graphqlws/event" ) type operationMessageType string @@ -43,11 +41,22 @@ type operationMessage struct { Type operationMessageType `json:"type"` } +type startMessagePayload struct { + OperationName string `json:"operationName"` + Query string `json:"query"` + Variables map[string]interface{} `json:"variables"` +} + type initMessagePayload struct{} +// GraphQLService interface +type GraphQLService interface { + Subscribe(ctx context.Context, document string, operationName string, variableValues map[string]interface{}) (payloads <-chan interface{}, err error) +} + type connection struct { cancel func() - handler event.Handler + service GraphQLService writeTimeout time.Duration ws wsConnection } @@ -68,9 +77,9 @@ func WriteTimeout(d time.Duration) func(conn *connection) { // Connect implements the apollographql subscriptions-transport-ws protocol@v0.9.4 // https://github.com/apollographql/subscriptions-transport-ws/blob/v0.9.4/PROTOCOL.md -func Connect(ws wsConnection, handler event.Handler, options ...func(conn *connection)) func() { +func Connect(ws wsConnection, service GraphQLService, options ...func(conn *connection)) func() { conn := &connection{ - handler: handler, + service: service, ws: ws, } @@ -166,42 +175,46 @@ func (conn *connection) readLoop(ctx context.Context, send sendFunc) { continue } - args := &event.OnOperationArgs{ID: msg.ID} - if err := json.Unmarshal(msg.Payload, &args.Payload); err != nil { + var osp startMessagePayload + if err := json.Unmarshal(msg.Payload, &osp); err != nil { ep := errPayload(fmt.Errorf("invalid payload for type: %s", msg.Type)) send(msg.ID, typeConnectionError, ep) continue } - // TODO: ensure args.Send doesn't work after typeStop or onDone() - args.Send = func(payload json.RawMessage) { - send(msg.ID, typeData, payload) - } + opCtx, cancel := context.WithCancel(ctx) // TODO: timeout this call, to guard against poor clients - payload, onDone, err := conn.handler.OnOperation(ctx, args) - // query or mutation - if err != nil || payload != nil { - func() { - defer func() { - if onDone != nil { - onDone() - } - send(msg.ID, typeComplete, nil) - }() - - if err != nil { - send(msg.ID, typeError, errPayload(err)) - return - } - send(msg.ID, typeData, payload) - }() + c, err := conn.service.Subscribe(opCtx, osp.Query, osp.OperationName, osp.Variables) + if err != nil { + cancel() + send(msg.ID, typeError, errPayload(err)) + send(msg.ID, typeComplete, nil) continue } - // subscription - if onDone != nil { - opDone[msg.ID] = onDone - } + opDone[msg.ID] = cancel + + go func() { + defer cancel() + for { + select { + case <-opCtx.Done(): + return + case payload, more := <-c: + if !more { + send(msg.ID, typeComplete, nil) + return + } + + jsonPayload, err := json.Marshal(payload) + if err != nil { + send(msg.ID, typeError, errPayload(err)) + continue + } + send(msg.ID, typeData, jsonPayload) + } + } + }() case typeStop: onDone, ok := opDone[msg.ID] diff --git a/graphqlws/internal/connection/connection_test.go b/graphqlws/internal/connection/connection_test.go index 073c9fa..8ad2334 100644 --- a/graphqlws/internal/connection/connection_test.go +++ b/graphqlws/internal/connection/connection_test.go @@ -8,7 +8,6 @@ import ( "testing" "time" - "github.com/graph-gophers/graphql-transport-ws/graphqlws/event" "github.com/graph-gophers/graphql-transport-ws/graphqlws/internal/connection" ) @@ -30,13 +29,12 @@ type message struct { func TestConnect(t *testing.T) { testTable := []struct { - name string - callbacks *callbacksHandler - messages []message + name string + svc *gqlService + messages []message }{ { - name: "connection_init_ok", - callbacks: &callbacksHandler{}, + name: "connection_init_ok", messages: []message{ { intention: clientSends, @@ -52,8 +50,7 @@ func TestConnect(t *testing.T) { }, }, { - name: "connection_init_error", - callbacks: &callbacksHandler{}, + name: "connection_init_error", messages: []message{ { intention: clientSends, @@ -74,10 +71,8 @@ func TestConnect(t *testing.T) { }, }, { - name: "start_query_ok", - callbacks: &callbacksHandler{ - payload: json.RawMessage(`{"data":{},"errors":null}`), - }, + name: "start_ok", + svc: newGQLService(`{"data":{},"errors":null}`), messages: []message{ { intention: clientSends, @@ -109,9 +104,7 @@ func TestConnect(t *testing.T) { }, { name: "start_query_data_error", - callbacks: &callbacksHandler{ - payload: json.RawMessage(`{"data":null,"errors":[{"message":"a error"}]}`), - }, + svc: newGQLService(`{"data":null,"errors":[{"message":"a error"}]}`), messages: []message{ { intention: clientSends, @@ -144,7 +137,7 @@ func TestConnect(t *testing.T) { }, { name: "start_query_error", - callbacks: &callbacksHandler{ + svc: &gqlService{ err: errors.New("some error"), }, messages: []message{ @@ -179,20 +172,29 @@ func TestConnect(t *testing.T) { for _, tt := range testTable { t.Run(tt.name, func(t *testing.T) { ws := newConnection() - go connection.Connect(ws, tt.callbacks) + go connection.Connect(ws, tt.svc) ws.test(t, tt.messages) }) } } -type callbacksHandler struct { - payload json.RawMessage - cancel func() - err error +type gqlService struct { + payloads <-chan interface{} + err error +} + +func newGQLService(pp ...string) *gqlService { + c := make(chan interface{}, len(pp)) + for _, p := range pp { + c <- json.RawMessage(p) + } + close(c) + + return &gqlService{payloads: c} } -func (h *callbacksHandler) OnOperation(ctx context.Context, args *event.OnOperationArgs) (json.RawMessage, func(), error) { - return h.payload, h.cancel, h.err +func (h *gqlService) Subscribe(ctx context.Context, document string, operationName string, variableValues map[string]interface{}) (payloads <-chan interface{}, err error) { + return h.payloads, h.err } func newConnection() *wsConnection {