From dd2fb3c1e358cf733e3d2e8b71dfd1622aaf1c78 Mon Sep 17 00:00:00 2001 From: Cyril Goust Date: Tue, 13 Apr 2021 07:35:59 +0200 Subject: [PATCH] handle ping pong interval & better handle active subscription close --- graphql/handler/transport/websocket.go | 39 ++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/graphql/handler/transport/websocket.go b/graphql/handler/transport/websocket.go index 3089a877999..fb989657515 100644 --- a/graphql/handler/transport/websocket.go +++ b/graphql/handler/transport/websocket.go @@ -34,6 +34,7 @@ type ( Upgrader websocket.Upgrader InitFunc WebsocketInitFunc KeepAlivePingInterval time.Duration + PingPongInterval time.Duration } wsConnection struct { Websocket @@ -42,6 +43,7 @@ type ( active map[string]context.CancelFunc mu sync.Mutex keepAliveTicker *time.Ticker + pingPongTicker *time.Ticker exec graphql.GraphExecutor initPayload InitPayload @@ -138,7 +140,6 @@ func (c *wsConnection) run() { ctx, cancel := context.WithCancel(c.ctx) defer func() { cancel() - c.close(websocket.CloseAbnormalClosure, "unexpected closure") }() // Create a timer that will fire every interval to keep the connection alive. @@ -146,14 +147,31 @@ func (c *wsConnection) run() { c.mu.Lock() c.keepAliveTicker = time.NewTicker(c.KeepAlivePingInterval) c.mu.Unlock() - go c.keepAlive(ctx) } + // Create a timer that will fire every interval a ping message that should + // receive a pong (SetPongHandler in init() function) + if c.PingPongInterval != 0 { + + pongWait := 2 * c.PingPongInterval + c.conn.SetReadDeadline(time.Now().Add(pongWait)) + c.conn.SetPongHandler(func(string) error { + return c.conn.SetReadDeadline(time.Now().UTC().Add(pongWait)) + }) + + c.mu.Lock() + c.pingPongTicker = time.NewTicker(c.PingPongInterval) + c.mu.Unlock() + + go c.ping(ctx) + } + for { start := graphql.Now() message := c.readOp() if message == nil { + c.close(websocket.CloseAbnormalClosure, "unexpected closure") return } @@ -190,6 +208,20 @@ func (c *wsConnection) keepAlive(ctx context.Context) { } } +func (c *wsConnection) ping(ctx context.Context) { + for { + select { + case <-ctx.Done(): + c.pingPongTicker.Stop() + return + case <-c.pingPongTicker.C: + c.mu.Lock() + c.conn.WriteMessage(websocket.PingMessage, nil) + c.mu.Unlock() + } + } +} + func (c *wsConnection) subscribe(start time.Time, message *operationMessage) { ctx := graphql.StartOperationTrace(c.ctx) var params *graphql.RawParams @@ -311,6 +343,9 @@ func (c *wsConnection) readOp() *operationMessage { func (c *wsConnection) close(closeCode int, message string) { c.mu.Lock() _ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message)) + for key, closer := range c.active { + closer() + } c.mu.Unlock() _ = c.conn.Close() }