diff --git a/share/tunnel/tunnel.go b/share/tunnel/tunnel.go index d59804f6..6e6dae10 100644 --- a/share/tunnel/tunnel.go +++ b/share/tunnel/tunnel.go @@ -166,18 +166,44 @@ func (t *Tunnel) BindRemotes(ctx context.Context, remotes []*settings.Remote) er } func (t *Tunnel) keepAliveLoop(sshConn ssh.Conn) { - //ping forever + // ping forever with a timeout +PingCheckOLoop: for { time.Sleep(t.Config.KeepAlive) - _, b, err := sshConn.SendRequest("ping", true, nil) - if err != nil { - break - } - if len(b) > 0 && !bytes.Equal(b, []byte("pong")) { - t.Debugf("strange ping response") - break + + ctx, cancel := context.WithTimeout(context.Background(), t.Config.KeepAlive) + defer cancel() + + responseCh := make(chan []byte, 1) + errCh := make(chan error, 1) + + // Asynchronously send a 'ping' request via SSH + go func() { + _, b, err := sshConn.SendRequest("ping", true, nil) + if err != nil { + errCh <- err + return + } + responseCh <- b + }() + + // Wait for a response, error, or timeout from the asynchronous 'ping' request + select { + case response := <-responseCh: + if len(response) > 0 && !bytes.Equal(response, []byte("pong")) { + t.Debugf("Unexpected ping response: %s", response) + break PingCheckOLoop + } + case err := <-errCh: + if err != nil { + t.Debugf("Failed to send ping: %s", err) + break PingCheckOLoop + } + case <-ctx.Done(): + t.Debugf("Ping timed out") + break PingCheckOLoop } } - //close ssh connection on abnormal ping + // Close the SSH connection on abnormal ping sshConn.Close() }