From 524c0b5b356193587491ad493a29487c82636812 Mon Sep 17 00:00:00 2001 From: "Robin C. Pel" Date: Mon, 22 Nov 2021 17:36:19 +0100 Subject: [PATCH 1/4] Added code to the web socket so it closes when the context is cancelled (with an optional close reason). --- graphql/handler/transport/websocket.go | 13 ++++++++++++ .../transport/websocket_close_reason.go | 21 +++++++++++++++++++ 2 files changed, 34 insertions(+) create mode 100644 graphql/handler/transport/websocket_close_reason.go diff --git a/graphql/handler/transport/websocket.go b/graphql/handler/transport/websocket.go index d6dada03a7a..c91d9cc0d43 100644 --- a/graphql/handler/transport/websocket.go +++ b/graphql/handler/transport/websocket.go @@ -170,6 +170,10 @@ func (c *wsConnection) run() { go c.ping(ctx) } + // Close the connection when the context is cancelled. + // Will optionally send a "close reason" that is retrieved from the context. + go c.closeOnCancel(ctx) + for { start := graphql.Now() m, err := c.me.NextMessage() @@ -227,6 +231,15 @@ func (c *wsConnection) ping(ctx context.Context) { } } +func (c *wsConnection) closeOnCancel(ctx context.Context) { + <-ctx.Done() + + if r := closeReasonForContext(ctx); r != "" { + c.sendConnectionError(r) + } + c.close(websocket.CloseNormalClosure, "terminated") +} + func (c *wsConnection) subscribe(start time.Time, msg *message) { ctx := graphql.StartOperationTrace(c.ctx) var params *graphql.RawParams diff --git a/graphql/handler/transport/websocket_close_reason.go b/graphql/handler/transport/websocket_close_reason.go new file mode 100644 index 00000000000..121791b33d3 --- /dev/null +++ b/graphql/handler/transport/websocket_close_reason.go @@ -0,0 +1,21 @@ +package transport + +import ( + "context" +) + +// A private key for context that only this package can access. This is important +// to prevent collisions between different context uses +var closeReasonCtxKey = &wsCloseReasonContextKey{"close-reason"} +type wsCloseReasonContextKey struct { + name string +} + +func AppendCloseReason(ctx context.Context, reason string) context.Context { + return context.WithValue(ctx, closeReasonCtxKey, reason) +} + +func closeReasonForContext(ctx context.Context) string { + reason, _ := ctx.Value(closeReasonCtxKey).(string) + return reason +} From 7838459b70caf2db7236fb6b81e65068da8722ec Mon Sep 17 00:00:00 2001 From: "Robin C. Pel" Date: Thu, 25 Nov 2021 16:08:52 +0100 Subject: [PATCH 2/4] Added a test. --- graphql/handler/transport/websocket_test.go | 22 +++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/graphql/handler/transport/websocket_test.go b/graphql/handler/transport/websocket_test.go index af82c616fda..587d64b48e7 100644 --- a/graphql/handler/transport/websocket_test.go +++ b/graphql/handler/transport/websocket_test.go @@ -265,6 +265,28 @@ func TestWebsocketInitFunc(t *testing.T) { require.NoError(t, err) assert.Equal(t, "ok", resp.Empty) }) + + t.Run("can set a deadline on a websocket connection and close it with a reason", func(t *testing.T) { + h := testserver.New() + h.AddTransport(transport.Websocket{ + InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, error) { + newCtx, _ := context.WithTimeout(transport.AppendCloseReason(ctx, "beep boop"), time.Millisecond*5) + return newCtx, nil + }, + }) + srv := httptest.NewServer(h) + defer srv.Close() + + c := wsConnect(srv.URL) + require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) + assert.Equal(t, connectionAckMsg, readOp(c).Type) + assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) + + time.Sleep(time.Millisecond*10) + m := readOp(c) + assert.Equal(t, m.Type, connectionErrorMsg) + assert.Equal(t, string(m.Payload), `{"message":"beep boop"}`) + }) } func TestWebsocketGraphqltransportwsSubprotocol(t *testing.T) { From 574107eaf8533ae95ff10d633e8995edd8d7df4b Mon Sep 17 00:00:00 2001 From: Steve Coffman Date: Thu, 25 Nov 2021 11:45:08 -0500 Subject: [PATCH 3/4] go fmt Signed-off-by: Steve Coffman --- .../transport/websocket_close_reason.go | 1 + .../websocket_graphql_transport_ws.go | 22 +++++++++---------- graphql/handler/transport/websocket_test.go | 2 +- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/graphql/handler/transport/websocket_close_reason.go b/graphql/handler/transport/websocket_close_reason.go index 121791b33d3..c8217debe74 100644 --- a/graphql/handler/transport/websocket_close_reason.go +++ b/graphql/handler/transport/websocket_close_reason.go @@ -7,6 +7,7 @@ import ( // A private key for context that only this package can access. This is important // to prevent collisions between different context uses var closeReasonCtxKey = &wsCloseReasonContextKey{"close-reason"} + type wsCloseReasonContextKey struct { name string } diff --git a/graphql/handler/transport/websocket_graphql_transport_ws.go b/graphql/handler/transport/websocket_graphql_transport_ws.go index de04404226b..a5b6e3a9b92 100644 --- a/graphql/handler/transport/websocket_graphql_transport_ws.go +++ b/graphql/handler/transport/websocket_graphql_transport_ws.go @@ -21,18 +21,16 @@ const ( graphqltransportwsPongMsg = graphqltransportwsMessageType("pong") ) -var ( - allGraphqltransportwsMessageTypes = []graphqltransportwsMessageType{ - graphqltransportwsConnectionInitMsg, - graphqltransportwsConnectionAckMsg, - graphqltransportwsSubscribeMsg, - graphqltransportwsNextMsg, - graphqltransportwsErrorMsg, - graphqltransportwsCompleteMsg, - graphqltransportwsPingMsg, - graphqltransportwsPongMsg, - } -) +var allGraphqltransportwsMessageTypes = []graphqltransportwsMessageType{ + graphqltransportwsConnectionInitMsg, + graphqltransportwsConnectionAckMsg, + graphqltransportwsSubscribeMsg, + graphqltransportwsNextMsg, + graphqltransportwsErrorMsg, + graphqltransportwsCompleteMsg, + graphqltransportwsPingMsg, + graphqltransportwsPongMsg, +} type ( graphqltransportwsMessageExchanger struct { diff --git a/graphql/handler/transport/websocket_test.go b/graphql/handler/transport/websocket_test.go index 587d64b48e7..5dec80ffb89 100644 --- a/graphql/handler/transport/websocket_test.go +++ b/graphql/handler/transport/websocket_test.go @@ -282,7 +282,7 @@ func TestWebsocketInitFunc(t *testing.T) { assert.Equal(t, connectionAckMsg, readOp(c).Type) assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) - time.Sleep(time.Millisecond*10) + time.Sleep(time.Millisecond * 10) m := readOp(c) assert.Equal(t, m.Type, connectionErrorMsg) assert.Equal(t, string(m.Payload), `{"message":"beep boop"}`) From f74739b1abc78daa1359b8196eb20f0592b17d95 Mon Sep 17 00:00:00 2001 From: "Robin C. Pel" Date: Fri, 26 Nov 2021 09:36:00 +0100 Subject: [PATCH 4/4] Fix linter issues about the cancel function being thrown away. --- graphql/handler/transport/websocket_test.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/graphql/handler/transport/websocket_test.go b/graphql/handler/transport/websocket_test.go index 5dec80ffb89..44e78e18f16 100644 --- a/graphql/handler/transport/websocket_test.go +++ b/graphql/handler/transport/websocket_test.go @@ -268,10 +268,11 @@ func TestWebsocketInitFunc(t *testing.T) { t.Run("can set a deadline on a websocket connection and close it with a reason", func(t *testing.T) { h := testserver.New() + var cancel func() h.AddTransport(transport.Websocket{ - InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, error) { - newCtx, _ := context.WithTimeout(transport.AppendCloseReason(ctx, "beep boop"), time.Millisecond*5) - return newCtx, nil + InitFunc: func(ctx context.Context, _ transport.InitPayload) (newCtx context.Context, _ error) { + newCtx, cancel = context.WithTimeout(transport.AppendCloseReason(ctx, "beep boop"), time.Millisecond*5) + return }, }) srv := httptest.NewServer(h) @@ -282,6 +283,9 @@ func TestWebsocketInitFunc(t *testing.T) { assert.Equal(t, connectionAckMsg, readOp(c).Type) assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) + // Cancel should contain an actual value now, so let's call it when we exit this scope (to make the linter happy) + defer cancel() + time.Sleep(time.Millisecond * 10) m := readOp(c) assert.Equal(t, m.Type, connectionErrorMsg)