Skip to content

Commit

Permalink
Add manual websocket pingloop (#11765)
Browse files Browse the repository at this point in the history
  • Loading branch information
xacrimon committed Apr 13, 2022
1 parent 7eaf9f1 commit 398d9e9
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 8 deletions.
53 changes: 50 additions & 3 deletions lib/web/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,17 @@ func (s *WebSuite) SetUpTest(c *C) {
s.user = u.Username
s.clock = clockwork.NewFakeClock()

networkingConfig, err := types.NewClusterNetworkingConfigFromConfigFile(types.ClusterNetworkingConfigSpecV2{
KeepAliveInterval: types.Duration(10 * time.Second),
})
c.Assert(err, IsNil)

s.server, err = auth.NewTestServer(auth.TestServerConfig{
Auth: auth.TestAuthServerConfig{
ClusterName: "localhost",
Dir: c.MkDir(),
Clock: s.clock,
ClusterName: "localhost",
Dir: c.MkDir(),
Clock: s.clock,
ClusterNetworkingConfig: networkingConfig,
},
})
c.Assert(err, IsNil)
Expand Down Expand Up @@ -1037,6 +1043,47 @@ func (s *WebSuite) TestResizeTerminal(c *C) {
}
}

// TestTerminalPing tests that the server sends continuous ping control messages.
func (s *WebSuite) TestTerminalPing(c *C) {
ws, err := s.makeTerminal(s.authPack(c, "foo"))
c.Assert(err, IsNil)
defer ws.Close()

closed := false
done := make(chan struct{})
ws.SetPingHandler(func(message string) error {
if closed == false {
close(done)
closed = true
}

err := ws.WriteControl(websocket.PongMessage, []byte(message), time.Now().Add(time.Second))
if err == websocket.ErrCloseSent {
return nil
} else if e, ok := err.(net.Error); ok && e.Temporary() {
return nil
}
return err
})

// We need to continuously read incoming messages in order to process ping messages.
// We only care about receiving a ping here so dropping them is fine.
go func() {
for {
_, _, err := ws.ReadMessage()
if err != nil {
return
}
}
}()

select {
case <-done:
case <-time.After(time.Minute):
c.Fatal("timeout waiting for ping")
}
}

func (s *WebSuite) TestTerminal(c *C) {
ws, err := s.makeTerminal(s.authPack(c, "foo"))
c.Assert(err, IsNil)
Expand Down
51 changes: 46 additions & 5 deletions lib/web/terminal.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ type TerminalRequest struct {
InteractiveCommand []string `json:"-"`

// KeepAliveInterval is the interval for sending ping frames to web client.
// This value is pulled from the cluster network config and
// guaranteed to be set to a nonzero value as it's enforced by the configuration.
KeepAliveInterval time.Duration
}

Expand Down Expand Up @@ -206,6 +208,7 @@ func (t *TerminalHandler) Serve(w http.ResponseWriter, r *http.Request) {
return
}

ws.SetReadDeadline(deadlineForInterval(t.params.KeepAliveInterval))
t.handler(ws, r)
}

Expand All @@ -224,6 +227,33 @@ func (t *TerminalHandler) Close() error {
return nil
}

// startPingLoop starts a loop that will continuously send a ping frame through the websocket
// to prevent the connection between web client and teleport proxy from becoming idle.
// Interval is determined by the keep_alive_interval config set by user (or default).
// Loop will terminate when there is an error sending ping frame or when terminal session is closed.
func (t *TerminalHandler) startPingLoop(ws *websocket.Conn) {
t.log.Debugf("Starting websocket ping loop with interval %v.", t.params.KeepAliveInterval)
tickerCh := time.NewTicker(t.params.KeepAliveInterval)
defer tickerCh.Stop()

for {
select {
case <-tickerCh.C:
// A short deadline is used here to detect a broken connection quickly.
// If this is just a temporary issue, we will retry shortly anyway.
deadline := time.Now().Add(time.Second)
if err := ws.WriteControl(websocket.PingMessage, nil, deadline); err != nil {
t.log.Errorf("Unable to send ping frame to web client: %v.", err)
t.Close()
return
}
case <-t.terminalContext.Done():
t.log.Debug("Terminating websocket ping loop.")
return
}
}
}

// handler is the main websocket loop. It creates a Teleport client and then
// pumps raw events and audit events back to the client until the SSH session
// is complete.
Expand All @@ -247,6 +277,15 @@ func (t *TerminalHandler) handler(ws *websocket.Conn, r *http.Request) {

t.log.Debugf("Creating websocket stream for %v.", t.params.SessionID)

// Update the read deadline upon receiving a pong message.
ws.SetPongHandler(func(_ string) error {
ws.SetReadDeadline(deadlineForInterval(t.params.KeepAliveInterval))
return nil
})

// Start sending ping frames through websocket to client.
go t.startPingLoop(ws)

// Pump raw terminal in/out and audit events into the websocket.
go t.streamTerminal(ws, tc)
go t.streamEvents(ws, tc)
Expand Down Expand Up @@ -699,12 +738,14 @@ func (w *terminalStream) Read(out []byte) (n int, err error) {
return w.terminal.read(out, w.ws)
}

// SetReadDeadline sets the network read deadline on the underlying websocket.
func (w *terminalStream) SetReadDeadline(t time.Time) error {
return w.ws.SetReadDeadline(t)
}

// Close the websocket.
func (w *terminalStream) Close() error {
return w.ws.Close()
}

// deadlineForInterval returns a suitable network read deadline for a given ping interval.
// We chose to take the current time plus twice the interval to allow the timeframe of one interval
// to wait for a returned pong message.
func deadlineForInterval(interval time.Duration) time.Time {
return time.Now().Add(interval * 2)
}

0 comments on commit 398d9e9

Please sign in to comment.