diff --git a/graphql/handler/transport/websocket.go b/graphql/handler/transport/websocket.go index db6542c701..51b1104ccc 100644 --- a/graphql/handler/transport/websocket.go +++ b/graphql/handler/transport/websocket.go @@ -44,12 +44,29 @@ type ( } WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error) - WebsocketErrorFunc func(ctx context.Context, err error, isOnRead bool) + WebsocketErrorFunc func(ctx context.Context, err error) ) var errReadTimeout = errors.New("read timeout") -var _ graphql.Transport = Websocket{} +type WebsocketError struct { + Err error + + // IsReadError flags whether the error occurred on read or write to the websocket + IsReadError bool +} + +func (e WebsocketError) Error() string { + if e.IsReadError { + return fmt.Sprintf("websocket read: %v", e.Err) + } + return fmt.Sprintf("websocket write: %v", e.Err) +} + +var ( + _ graphql.Transport = Websocket{} + _ error = WebsocketError{} +) func (t Websocket) Supports(r *http.Request) bool { return r.Header.Get("Upgrade") != "" @@ -94,9 +111,12 @@ func (t Websocket) Do(w http.ResponseWriter, r *http.Request, exec graphql.Graph conn.run() } -func (c *wsConnection) handlePossibleError(err error, isOnRead bool) { +func (c *wsConnection) handlePossibleError(err error, isReadError bool) { if c.ErrorFunc != nil && err != nil { - c.ErrorFunc(c.ctx, err, isOnRead) + c.ErrorFunc(c.ctx, WebsocketError{ + Err: err, + IsReadError: isReadError, + }) } } diff --git a/graphql/handler/transport/websocket_test.go b/graphql/handler/transport/websocket_test.go index 9fddb28dc0..fb2a07bf8c 100644 --- a/graphql/handler/transport/websocket_test.go +++ b/graphql/handler/transport/websocket_test.go @@ -354,10 +354,11 @@ func TestWebSocketErrorFunc(t *testing.T) { errFuncCalled := make(chan bool, 1) h := testserver.New() h.AddTransport(transport.Websocket{ - ErrorFunc: func(_ context.Context, err error, isOnRead bool) { + ErrorFunc: func(_ context.Context, err error) { require.Error(t, err) - assert.Equal(t, err.Error(), "invalid message received") - assert.True(t, isOnRead) + assert.Equal(t, err.Error(), "websocket read: invalid message received") + assert.IsType(t, transport.WebsocketError{}, err) + assert.True(t, err.(transport.WebsocketError).IsReadError) errFuncCalled <- true }, }) @@ -385,7 +386,7 @@ func TestWebSocketErrorFunc(t *testing.T) { InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, error) { return ctx, errors.New("this is not what we agreed upon") }, - ErrorFunc: func(_ context.Context, err error, isOnRead bool) { + ErrorFunc: func(_ context.Context, err error) { assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error()) }, }) @@ -405,7 +406,7 @@ func TestWebSocketErrorFunc(t *testing.T) { time.AfterFunc(time.Millisecond*5, cancel) return newCtx, nil }, - ErrorFunc: func(_ context.Context, err error, isOnRead bool) { + ErrorFunc: func(_ context.Context, err error) { assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error()) }, }) @@ -427,7 +428,7 @@ func TestWebSocketErrorFunc(t *testing.T) { newCtx, cancel = context.WithDeadline(ctx, time.Now().Add(time.Millisecond*5)) return newCtx, nil }, - ErrorFunc: func(_ context.Context, err error, isOnRead bool) { + ErrorFunc: func(_ context.Context, err error) { assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error()) }, })