diff --git a/handler/graphql.go b/handler/graphql.go index a22542225fc..8c3882ce9c2 100644 --- a/handler/graphql.go +++ b/handler/graphql.go @@ -30,6 +30,8 @@ type params struct { Variables map[string]interface{} `json:"variables"` } +type websocketInitFunc func(ctx context.Context, initPayload InitPayload) error + type Config struct { cacheSize int upgrader websocket.Upgrader @@ -40,6 +42,7 @@ type Config struct { tracer graphql.Tracer complexityLimit int complexityLimitFunc graphql.ComplexityLimitFunc + websocketInitFunc websocketInitFunc disableIntrospection bool connectionKeepAlivePingInterval time.Duration uploadMaxMemory int64 @@ -250,6 +253,14 @@ func (tw *tracerWrapper) EndOperationExecution(ctx context.Context) { tw.tracer1.EndOperationExecution(ctx) } +// WebsocketInitFunc is called when the server receives connection init message from the client. +// This can be used to check initial payload to see whether to accept the websocket connection. +func WebsocketInitFunc(websocketInitFunc func(ctx context.Context, initPayload InitPayload) error) Option { + return func(cfg *Config) { + cfg.websocketInitFunc = websocketInitFunc + } +} + // CacheSize sets the maximum size of the query cache. // If size is less than or equal to 0, the cache is disabled. func CacheSize(size int) Option { diff --git a/handler/websocket.go b/handler/websocket.go index 58f38e5d48d..07a1a8c2dd8 100644 --- a/handler/websocket.go +++ b/handler/websocket.go @@ -12,7 +12,7 @@ import ( "github.com/99designs/gqlgen/graphql" "github.com/gorilla/websocket" - "github.com/hashicorp/golang-lru" + lru "github.com/hashicorp/golang-lru" "github.com/vektah/gqlparser" "github.com/vektah/gqlparser/ast" "github.com/vektah/gqlparser/gqlerror" @@ -94,6 +94,14 @@ func (c *wsConnection) init() bool { } } + if c.cfg.websocketInitFunc != nil { + if err := c.cfg.websocketInitFunc(c.ctx, c.initPayload); err != nil { + c.sendConnectionError(err.Error()) + c.close(websocket.CloseNormalClosure, "terminated") + return false + } + } + c.write(&operationMessage{Type: connectionAckMsg}) case connectionTerminateMsg: c.close(websocket.CloseNormalClosure, "terminated") diff --git a/handler/websocket_test.go b/handler/websocket_test.go index f8675475c94..dc3e656e5fe 100644 --- a/handler/websocket_test.go +++ b/handler/websocket_test.go @@ -1,7 +1,9 @@ package handler import ( + "context" "encoding/json" + "errors" "net/http/httptest" "strings" "testing" @@ -158,6 +160,55 @@ func TestWebsocketWithKeepAlive(t *testing.T) { }) } +func TestWebsocketInitFunc(t *testing.T) { + next := make(chan struct{}) + + t.Run("accept connection if WebsocketInitFunc is NOT provided", func(t *testing.T) { + h := GraphQL(&executableSchemaStub{next}) + srv := httptest.NewServer(h) + defer srv.Close() + + c := wsConnect(srv.URL) + defer c.Close() + + require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) + + require.Equal(t, connectionAckMsg, readOp(c).Type) + }) + + t.Run("accept connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) { + h := GraphQL(&executableSchemaStub{next}, WebsocketInitFunc(func(ctx context.Context, initPayload InitPayload) error { + return nil + })) + srv := httptest.NewServer(h) + defer srv.Close() + + c := wsConnect(srv.URL) + defer c.Close() + + require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) + + require.Equal(t, connectionAckMsg, readOp(c).Type) + }) + + t.Run("reject connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) { + h := GraphQL(&executableSchemaStub{next}, WebsocketInitFunc(func(ctx context.Context, initPayload InitPayload) error { + return errors.New("invalid init payload") + })) + srv := httptest.NewServer(h) + defer srv.Close() + + c := wsConnect(srv.URL) + defer c.Close() + + require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) + + msg := readOp(c) + require.Equal(t, connectionErrorMsg, msg.Type) + require.Equal(t, `{"message":"invalid init payload"}`, string(msg.Payload)) + }) +} + func wsConnect(url string) *websocket.Conn { c, _, err := websocket.DefaultDialer.Dial(strings.Replace(url, "http://", "ws://", -1), nil) if err != nil {