diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 41d31daada09a..af48f3f498a5b 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -163,11 +163,17 @@ func (s *WebSuite) SetUpTest(c *C) { s.user = u.Username s.clock = clockwork.NewFakeClock() + networkingConfig, err := types.NewClusterNetworkingConfigFromConfigFile(types.ClusterNetworkingConfigSpecV2{ + KeepAliveInterval: types.Duration(10 * time.Second), + }) + c.Assert(err, IsNil) + s.server, err = auth.NewTestServer(auth.TestServerConfig{ Auth: auth.TestAuthServerConfig{ - ClusterName: "localhost", - Dir: c.MkDir(), - Clock: s.clock, + ClusterName: "localhost", + Dir: c.MkDir(), + Clock: s.clock, + ClusterNetworkingConfig: networkingConfig, }, }) c.Assert(err, IsNil) @@ -1037,6 +1043,47 @@ func (s *WebSuite) TestResizeTerminal(c *C) { } } +// TestTerminalPing tests that the server sends continuous ping control messages. +func (s *WebSuite) TestTerminalPing(c *C) { + ws, err := s.makeTerminal(s.authPack(c, "foo")) + c.Assert(err, IsNil) + defer ws.Close() + + closed := false + done := make(chan struct{}) + ws.SetPingHandler(func(message string) error { + if closed == false { + close(done) + closed = true + } + + err := ws.WriteControl(websocket.PongMessage, []byte(message), time.Now().Add(time.Second)) + if err == websocket.ErrCloseSent { + return nil + } else if e, ok := err.(net.Error); ok && e.Temporary() { + return nil + } + return err + }) + + // We need to continuously read incoming messages in order to process ping messages. + // We only care about receiving a ping here so dropping them is fine. + go func() { + for { + _, _, err := ws.ReadMessage() + if err != nil { + return + } + } + }() + + select { + case <-done: + case <-time.After(time.Minute): + c.Fatal("timeout waiting for ping") + } +} + func (s *WebSuite) TestTerminal(c *C) { ws, err := s.makeTerminal(s.authPack(c, "foo")) c.Assert(err, IsNil) diff --git a/lib/web/terminal.go b/lib/web/terminal.go index 2e35cf3981fba..937cfd0e892c9 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -80,6 +80,8 @@ type TerminalRequest struct { InteractiveCommand []string `json:"-"` // KeepAliveInterval is the interval for sending ping frames to web client. + // This value is pulled from the cluster network config and + // guaranteed to be set to a nonzero value as it's enforced by the configuration. KeepAliveInterval time.Duration } @@ -206,6 +208,7 @@ func (t *TerminalHandler) Serve(w http.ResponseWriter, r *http.Request) { return } + ws.SetReadDeadline(deadlineForInterval(t.params.KeepAliveInterval)) t.handler(ws, r) } @@ -224,6 +227,33 @@ func (t *TerminalHandler) Close() error { return nil } +// startPingLoop starts a loop that will continuously send a ping frame through the websocket +// to prevent the connection between web client and teleport proxy from becoming idle. +// Interval is determined by the keep_alive_interval config set by user (or default). +// Loop will terminate when there is an error sending ping frame or when terminal session is closed. +func (t *TerminalHandler) startPingLoop(ws *websocket.Conn) { + t.log.Debugf("Starting websocket ping loop with interval %v.", t.params.KeepAliveInterval) + tickerCh := time.NewTicker(t.params.KeepAliveInterval) + defer tickerCh.Stop() + + for { + select { + case <-tickerCh.C: + // A short deadline is used here to detect a broken connection quickly. + // If this is just a temporary issue, we will retry shortly anyway. + deadline := time.Now().Add(time.Second) + if err := ws.WriteControl(websocket.PingMessage, nil, deadline); err != nil { + t.log.Errorf("Unable to send ping frame to web client: %v.", err) + t.Close() + return + } + case <-t.terminalContext.Done(): + t.log.Debug("Terminating websocket ping loop.") + return + } + } +} + // handler is the main websocket loop. It creates a Teleport client and then // pumps raw events and audit events back to the client until the SSH session // is complete. @@ -247,6 +277,15 @@ func (t *TerminalHandler) handler(ws *websocket.Conn, r *http.Request) { t.log.Debugf("Creating websocket stream for %v.", t.params.SessionID) + // Update the read deadline upon receiving a pong message. + ws.SetPongHandler(func(_ string) error { + ws.SetReadDeadline(deadlineForInterval(t.params.KeepAliveInterval)) + return nil + }) + + // Start sending ping frames through websocket to client. + go t.startPingLoop(ws) + // Pump raw terminal in/out and audit events into the websocket. go t.streamTerminal(ws, tc) go t.streamEvents(ws, tc) @@ -699,12 +738,14 @@ func (w *terminalStream) Read(out []byte) (n int, err error) { return w.terminal.read(out, w.ws) } -// SetReadDeadline sets the network read deadline on the underlying websocket. -func (w *terminalStream) SetReadDeadline(t time.Time) error { - return w.ws.SetReadDeadline(t) -} - // Close the websocket. func (w *terminalStream) Close() error { return w.ws.Close() } + +// deadlineForInterval returns a suitable network read deadline for a given ping interval. +// We chose to take the current time plus twice the interval to allow the timeframe of one interval +// to wait for a returned pong message. +func deadlineForInterval(interval time.Duration) time.Time { + return time.Now().Add(interval * 2) +}