From 31360245a8075cc34b33be924a95704732066370 Mon Sep 17 00:00:00 2001 From: dima Date: Mon, 25 Apr 2022 16:58:04 -0700 Subject: [PATCH 1/2] Add argument to WebsocketErrorFunc to determine whether the error ocured on read or write to the websocket. --- graphql/handler/transport/websocket.go | 10 +++++----- graphql/handler/transport/websocket_test.go | 9 +++++---- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/graphql/handler/transport/websocket.go b/graphql/handler/transport/websocket.go index a0429b7481..db6542c701 100644 --- a/graphql/handler/transport/websocket.go +++ b/graphql/handler/transport/websocket.go @@ -44,7 +44,7 @@ type ( } WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error) - WebsocketErrorFunc func(ctx context.Context, err error) + WebsocketErrorFunc func(ctx context.Context, err error, isOnRead bool) ) var errReadTimeout = errors.New("read timeout") @@ -94,9 +94,9 @@ func (t Websocket) Do(w http.ResponseWriter, r *http.Request, exec graphql.Graph conn.run() } -func (c *wsConnection) handlePossibleError(err error) { +func (c *wsConnection) handlePossibleError(err error, isOnRead bool) { if c.ErrorFunc != nil && err != nil { - c.ErrorFunc(c.ctx, err) + c.ErrorFunc(c.ctx, err, isOnRead) } } @@ -181,7 +181,7 @@ func (c *wsConnection) init() bool { func (c *wsConnection) write(msg *message) { c.mu.Lock() - c.handlePossibleError(c.me.Send(msg)) + c.handlePossibleError(c.me.Send(msg), false) c.mu.Unlock() } @@ -227,7 +227,7 @@ func (c *wsConnection) run() { if err != nil { // If the connection got closed by us, don't report the error if !errors.Is(err, net.ErrClosed) { - c.handlePossibleError(err) + c.handlePossibleError(err, true) } return } diff --git a/graphql/handler/transport/websocket_test.go b/graphql/handler/transport/websocket_test.go index 7c84f35262..9fddb28dc0 100644 --- a/graphql/handler/transport/websocket_test.go +++ b/graphql/handler/transport/websocket_test.go @@ -354,9 +354,10 @@ func TestWebSocketErrorFunc(t *testing.T) { errFuncCalled := make(chan bool, 1) h := testserver.New() h.AddTransport(transport.Websocket{ - ErrorFunc: func(_ context.Context, err error) { + ErrorFunc: func(_ context.Context, err error, isOnRead bool) { require.Error(t, err) assert.Equal(t, err.Error(), "invalid message received") + assert.True(t, isOnRead) errFuncCalled <- true }, }) @@ -384,7 +385,7 @@ func TestWebSocketErrorFunc(t *testing.T) { InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, error) { return ctx, errors.New("this is not what we agreed upon") }, - ErrorFunc: func(_ context.Context, err error) { + ErrorFunc: func(_ context.Context, err error, isOnRead bool) { assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error()) }, }) @@ -404,7 +405,7 @@ func TestWebSocketErrorFunc(t *testing.T) { time.AfterFunc(time.Millisecond*5, cancel) return newCtx, nil }, - ErrorFunc: func(_ context.Context, err error) { + ErrorFunc: func(_ context.Context, err error, isOnRead bool) { assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error()) }, }) @@ -426,7 +427,7 @@ func TestWebSocketErrorFunc(t *testing.T) { newCtx, cancel = context.WithDeadline(ctx, time.Now().Add(time.Millisecond*5)) return newCtx, nil }, - ErrorFunc: func(_ context.Context, err error) { + ErrorFunc: func(_ context.Context, err error, isOnRead bool) { assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error()) }, }) From 84a01e0a61402fddf9cb93c30b2086634a092229 Mon Sep 17 00:00:00 2001 From: dima Date: Fri, 29 Apr 2022 10:08:25 -0700 Subject: [PATCH 2/2] Wrap websocket error --- graphql/handler/transport/websocket.go | 28 ++++++++++++++++++--- graphql/handler/transport/websocket_test.go | 13 +++++----- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/graphql/handler/transport/websocket.go b/graphql/handler/transport/websocket.go index db6542c701..51b1104ccc 100644 --- a/graphql/handler/transport/websocket.go +++ b/graphql/handler/transport/websocket.go @@ -44,12 +44,29 @@ type ( } WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error) - WebsocketErrorFunc func(ctx context.Context, err error, isOnRead bool) + WebsocketErrorFunc func(ctx context.Context, err error) ) var errReadTimeout = errors.New("read timeout") -var _ graphql.Transport = Websocket{} +type WebsocketError struct { + Err error + + // IsReadError flags whether the error occurred on read or write to the websocket + IsReadError bool +} + +func (e WebsocketError) Error() string { + if e.IsReadError { + return fmt.Sprintf("websocket read: %v", e.Err) + } + return fmt.Sprintf("websocket write: %v", e.Err) +} + +var ( + _ graphql.Transport = Websocket{} + _ error = WebsocketError{} +) func (t Websocket) Supports(r *http.Request) bool { return r.Header.Get("Upgrade") != "" @@ -94,9 +111,12 @@ func (t Websocket) Do(w http.ResponseWriter, r *http.Request, exec graphql.Graph conn.run() } -func (c *wsConnection) handlePossibleError(err error, isOnRead bool) { +func (c *wsConnection) handlePossibleError(err error, isReadError bool) { if c.ErrorFunc != nil && err != nil { - c.ErrorFunc(c.ctx, err, isOnRead) + c.ErrorFunc(c.ctx, WebsocketError{ + Err: err, + IsReadError: isReadError, + }) } } diff --git a/graphql/handler/transport/websocket_test.go b/graphql/handler/transport/websocket_test.go index 9fddb28dc0..fb2a07bf8c 100644 --- a/graphql/handler/transport/websocket_test.go +++ b/graphql/handler/transport/websocket_test.go @@ -354,10 +354,11 @@ func TestWebSocketErrorFunc(t *testing.T) { errFuncCalled := make(chan bool, 1) h := testserver.New() h.AddTransport(transport.Websocket{ - ErrorFunc: func(_ context.Context, err error, isOnRead bool) { + ErrorFunc: func(_ context.Context, err error) { require.Error(t, err) - assert.Equal(t, err.Error(), "invalid message received") - assert.True(t, isOnRead) + assert.Equal(t, err.Error(), "websocket read: invalid message received") + assert.IsType(t, transport.WebsocketError{}, err) + assert.True(t, err.(transport.WebsocketError).IsReadError) errFuncCalled <- true }, }) @@ -385,7 +386,7 @@ func TestWebSocketErrorFunc(t *testing.T) { InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, error) { return ctx, errors.New("this is not what we agreed upon") }, - ErrorFunc: func(_ context.Context, err error, isOnRead bool) { + ErrorFunc: func(_ context.Context, err error) { assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error()) }, }) @@ -405,7 +406,7 @@ func TestWebSocketErrorFunc(t *testing.T) { time.AfterFunc(time.Millisecond*5, cancel) return newCtx, nil }, - ErrorFunc: func(_ context.Context, err error, isOnRead bool) { + ErrorFunc: func(_ context.Context, err error) { assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error()) }, }) @@ -427,7 +428,7 @@ func TestWebSocketErrorFunc(t *testing.T) { newCtx, cancel = context.WithDeadline(ctx, time.Now().Add(time.Millisecond*5)) return newCtx, nil }, - ErrorFunc: func(_ context.Context, err error, isOnRead bool) { + ErrorFunc: func(_ context.Context, err error) { assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error()) }, })