Skip to content

Commit

Permalink
Merge pull request #750 from 99designs/ws-connection-param-check
Browse files Browse the repository at this point in the history
[websocket] Add a config to reject initial connection
  • Loading branch information
Eddy Nguyen authored Jun 19, 2019
2 parents 090f0bd + c397be0 commit 726a94f
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 1 deletion.
11 changes: 11 additions & 0 deletions handler/graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ type params struct {
Variables map[string]interface{} `json:"variables"`
}

type websocketInitFunc func(ctx context.Context, initPayload InitPayload) error

type Config struct {
cacheSize int
upgrader websocket.Upgrader
Expand All @@ -40,6 +42,7 @@ type Config struct {
tracer graphql.Tracer
complexityLimit int
complexityLimitFunc graphql.ComplexityLimitFunc
websocketInitFunc websocketInitFunc
disableIntrospection bool
connectionKeepAlivePingInterval time.Duration
uploadMaxMemory int64
Expand Down Expand Up @@ -250,6 +253,14 @@ func (tw *tracerWrapper) EndOperationExecution(ctx context.Context) {
tw.tracer1.EndOperationExecution(ctx)
}

// WebsocketInitFunc is called when the server receives connection init message from the client.
// This can be used to check initial payload to see whether to accept the websocket connection.
func WebsocketInitFunc(websocketInitFunc func(ctx context.Context, initPayload InitPayload) error) Option {
return func(cfg *Config) {
cfg.websocketInitFunc = websocketInitFunc
}
}

// CacheSize sets the maximum size of the query cache.
// If size is less than or equal to 0, the cache is disabled.
func CacheSize(size int) Option {
Expand Down
10 changes: 9 additions & 1 deletion handler/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (

"github.com/99designs/gqlgen/graphql"
"github.com/gorilla/websocket"
"github.com/hashicorp/golang-lru"
lru "github.com/hashicorp/golang-lru"
"github.com/vektah/gqlparser"
"github.com/vektah/gqlparser/ast"
"github.com/vektah/gqlparser/gqlerror"
Expand Down Expand Up @@ -94,6 +94,14 @@ func (c *wsConnection) init() bool {
}
}

if c.cfg.websocketInitFunc != nil {
if err := c.cfg.websocketInitFunc(c.ctx, c.initPayload); err != nil {
c.sendConnectionError(err.Error())
c.close(websocket.CloseNormalClosure, "terminated")
return false
}
}

c.write(&operationMessage{Type: connectionAckMsg})
case connectionTerminateMsg:
c.close(websocket.CloseNormalClosure, "terminated")
Expand Down
51 changes: 51 additions & 0 deletions handler/websocket_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package handler

import (
"context"
"encoding/json"
"errors"
"net/http/httptest"
"strings"
"testing"
Expand Down Expand Up @@ -158,6 +160,55 @@ func TestWebsocketWithKeepAlive(t *testing.T) {
})
}

func TestWebsocketInitFunc(t *testing.T) {
next := make(chan struct{})

t.Run("accept connection if WebsocketInitFunc is NOT provided", func(t *testing.T) {
h := GraphQL(&executableSchemaStub{next})
srv := httptest.NewServer(h)
defer srv.Close()

c := wsConnect(srv.URL)
defer c.Close()

require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))

require.Equal(t, connectionAckMsg, readOp(c).Type)
})

t.Run("accept connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
h := GraphQL(&executableSchemaStub{next}, WebsocketInitFunc(func(ctx context.Context, initPayload InitPayload) error {
return nil
}))
srv := httptest.NewServer(h)
defer srv.Close()

c := wsConnect(srv.URL)
defer c.Close()

require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))

require.Equal(t, connectionAckMsg, readOp(c).Type)
})

t.Run("reject connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
h := GraphQL(&executableSchemaStub{next}, WebsocketInitFunc(func(ctx context.Context, initPayload InitPayload) error {
return errors.New("invalid init payload")
}))
srv := httptest.NewServer(h)
defer srv.Close()

c := wsConnect(srv.URL)
defer c.Close()

require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))

msg := readOp(c)
require.Equal(t, connectionErrorMsg, msg.Type)
require.Equal(t, `{"message":"invalid init payload"}`, string(msg.Payload))
})
}

func wsConnect(url string) *websocket.Conn {
c, _, err := websocket.DefaultDialer.Dial(strings.Replace(url, "http://", "ws://", -1), nil)
if err != nil {
Expand Down

0 comments on commit 726a94f

Please sign in to comment.