Skip to content

Commit

Permalink
Add on-close handler for websockets. (#2612)
Browse files Browse the repository at this point in the history
* working without test

* test
  • Loading branch information
szgupta authored Apr 8, 2023
1 parent 4548815 commit 8b38c0e
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 0 deletions.
8 changes: 8 additions & 0 deletions graphql/handler/transport/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type (
InitFunc WebsocketInitFunc
InitTimeout time.Duration
ErrorFunc WebsocketErrorFunc
CloseFunc WebsocketCloseFunc
KeepAlivePingInterval time.Duration
PingPongInterval time.Duration

Expand All @@ -45,6 +46,9 @@ type (

WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error)
WebsocketErrorFunc func(ctx context.Context, err error)

// Callback called when websocket is closed.
WebsocketCloseFunc func(ctx context.Context, closeCode int)
)

var errReadTimeout = errors.New("read timeout")
Expand Down Expand Up @@ -433,4 +437,8 @@ func (c *wsConnection) close(closeCode int, message string) {
}
c.mu.Unlock()
_ = c.conn.Close()

if c.CloseFunc != nil {
c.CloseFunc(c.ctx, closeCode)
}
}
52 changes: 52 additions & 0 deletions graphql/handler/transport/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,58 @@ func TestWebSocketErrorFunc(t *testing.T) {
})
}

func TestWebSocketCloseFunc(t *testing.T) {
t.Run("the on close handler gets called when the websocket is closed", func(t *testing.T) {
closeFuncCalled := make(chan bool, 1)
h := testserver.New()
h.AddTransport(transport.Websocket{
CloseFunc: func(_ context.Context, _closeCode int) {
closeFuncCalled <- true
},
})

srv := httptest.NewServer(h)
defer srv.Close()

c := wsConnect(srv.URL)
require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
assert.Equal(t, connectionAckMsg, readOp(c).Type)
assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg}))

select {
case res := <-closeFuncCalled:
assert.True(t, res)
case <-time.NewTimer(time.Millisecond * 20).C:
assert.Fail(t, "The close handler was not called in time")
}
})

t.Run("init func errors call the close handler", func(t *testing.T) {
h := testserver.New()
closeFuncCalled := make(chan bool, 1)
h.AddTransport(transport.Websocket{
InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, error) {
return ctx, errors.New("error during init")
},
CloseFunc: func(_ context.Context, _closeCode int) {
closeFuncCalled <- true
},
})
srv := httptest.NewServer(h)
defer srv.Close()

c := wsConnect(srv.URL)
require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
select {
case res := <-closeFuncCalled:
assert.True(t, res)
case <-time.NewTimer(time.Millisecond * 20).C:
assert.Fail(t, "The close handler was not called in time")
}
})
}

func TestWebsocketGraphqltransportwsSubprotocol(t *testing.T) {
initialize := func(ws transport.Websocket) (*testserver.TestServer, *httptest.Server) {
h := testserver.New()
Expand Down

0 comments on commit 8b38c0e

Please sign in to comment.