Skip to content

Commit

Permalink
Connection Manager: Disconnect and reconnect on failure (#273)
Browse files Browse the repository at this point in the history
Co-authored-by: Oliver Geiselhardt-Herms <[email protected]>
  • Loading branch information
taktv6 and Oliver Geiselhardt-Herms authored Dec 18, 2024
1 parent 6cdaa0c commit bfdc266
Show file tree
Hide file tree
Showing 5 changed files with 191 additions and 135 deletions.
2 changes: 1 addition & 1 deletion junos_collector.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func deviceInterfaceRegex(cfg *config.Config, host string) *regexp.Regexp {
}

func clientForDevice(device *connector.Device, connManager *connector.SSHConnectionManager) (*rpc.Client, error) {
conn, err := connManager.Connect(device)
conn, err := connManager.GetSSHConnection(device)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func initChannels(ctx context.Context) {

// shutdown logs that device connections are being closed, closes them through
// connManager and then exits the process with status 0.
//
// NOTE(review): reinitialize checks connManager for nil before closing it;
// this function does not — confirm connManager is always set when shutdown runs.
// NOTE(review): both Close and CloseAll are called here; this looks like the
// pre- and post-change line of a diff rendering — confirm against the repository.
func shutdown() {
log.Infoln("Closing connections to devices")
connManager.Close()
connManager.CloseAll()
os.Exit(0)
}

Expand Down Expand Up @@ -198,7 +198,7 @@ func reinitialize() error {
defer configMu.Unlock()

if connManager != nil {
connManager.Close()
connManager.CloseAll()
connManager = nil
}

Expand Down
184 changes: 155 additions & 29 deletions pkg/connector/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,34 +10,83 @@ import (
"time"

"github.com/pkg/errors"
log "github.com/sirupsen/logrus"

"golang.org/x/crypto/ssh"
)

// SSHConnection encapsulates the connection to the device
//
// NOTE(review): this span lists some fields twice under different names
// (client/sshClient, conn/tcpConn) and repeats device, lastUsed, mu and done;
// it looks like pre- and post-change lines from a diff rendering — confirm
// the real field set against the repository.
type SSHConnection struct {
device *Device
client *ssh.Client
conn net.Conn
lastUsed time.Time
mu sync.Mutex
done chan struct{}
// device is the target device; its Host is used for dialing and in log messages.
device *Device
// sshClient and tcpConn are set by connect and cleared by Stop.
sshClient *ssh.Client
tcpConn net.Conn
// isConnected is set by connect and cleared by Stop.
isConnected bool
mu sync.RWMutex // protects sshClient, tcpConn and isConnected
// lastUsed is updated by RunCommand through setLastUsed; guarded by lastUsedMu.
lastUsed time.Time
lastUsedMu sync.RWMutex
// done is closed by Stop, which ends the keepalive loop.
done chan struct{}
// keepAliveInterval is the wait between keepalive checks; keepAliveTimeout
// is added to the current time to form the deadline set on the TCP
// connection before each check.
keepAliveInterval time.Duration
keepAliveTimeout time.Duration
}

// RunCommand runs a command against the device
func (c *SSHConnection) RunCommand(cmd string) ([]byte, error) {
// NewSSHConnection returns an SSHConnection for device that is not yet
// connected. It records the keepalive interval and timeout and allocates the
// done channel; no network activity happens here.
func NewSSHConnection(device *Device, keepAliveInterval time.Duration, keepAliveTimeout time.Duration) *SSHConnection {
	conn := new(SSHConnection)
	conn.device = device
	conn.done = make(chan struct{})
	conn.keepAliveInterval = keepAliveInterval
	conn.keepAliveTimeout = keepAliveTimeout

	return conn
}

// Start connects to the device and, on success, launches the keepalive loop
// in a separate goroutine. expiredConnectionTimeout is handed to that loop.
// If connecting fails the error is returned and no goroutine is started.
func (c *SSHConnection) Start(expiredConnectionTimeout time.Duration) error {
	if connectErr := c.connect(); connectErr != nil {
		return connectErr
	}

	go c.keepalive(expiredConnectionTimeout)

	return nil
}

// Stop tears down the connection to the device. It logs the device host and
// the reason (err), then takes the write lock. If the connection is not marked
// connected it returns; otherwise it closes the done channel, closes and
// clears the SSH client and the TCP connection, and clears the connected flag.
//
// NOTE(review): the statements in this span that use c.client, c.conn or
// c.lastUsed, or that return two values, do not fit this function's
// signature; they look like pre-change lines from a diff rendering — confirm
// against the repository.
func (c *SSHConnection) Stop(err error) {
log.Infof("Stopping SSH connection with %s (reason: %v)", c.device.Host, err)

c.mu.Lock()
defer c.mu.Unlock()

c.lastUsed = time.Now()
if !c.isConnected {
return
}

// Closing done makes the keepalive loop's receive on c.done succeed, so it returns.
close(c.done)

if c.client == nil {
return nil, errors.New(fmt.Sprintf("not connected with %s", c.conn.RemoteAddr().String()))
if c.sshClient != nil {
c.sshClient.Close()
c.sshClient = nil
}

session, err := c.client.NewSession()
if c.tcpConn != nil {
c.tcpConn.Close()
c.tcpConn = nil
}

c.isConnected = false
}

// RunCommand runs a command against the device
func (c *SSHConnection) RunCommand(cmd string) ([]byte, error) {
c.setLastUsed(time.Now())

sshClient := c.getSSHClient()
if sshClient == nil {
c.Stop(fmt.Errorf("No ssh client"))
return nil, errors.New(fmt.Sprintf("no SSH client to %s", c.device.Host))
}

session, err := c.sshClient.NewSession()
if err != nil {
return nil, errors.Wrapf(err, "could not open session with %s", c.conn.RemoteAddr().String())
c.Stop(fmt.Errorf("SSH session failure"))
return nil, errors.Wrapf(err, "could not open session with %s", c.device.Host)
}
defer session.Close()

Expand All @@ -46,37 +95,114 @@ func (c *SSHConnection) RunCommand(cmd string) ([]byte, error) {

err = session.Run(cmd)
if err != nil {
return nil, errors.Wrapf(err, "could not run command %q on %s", cmd, c.conn.RemoteAddr().String())
c.Stop(fmt.Errorf("failed running command"))
return nil, errors.Wrapf(err, "could not run command %q on %s", cmd, c.device.Host)
}

return b.Bytes(), nil
}

func (c *SSHConnection) isConnected() bool {
return c.conn != nil
// keepalive periodically checks the connection until it is stopped. Start
// launches it in its own goroutine.
//
// After every keepAliveInterval it:
//  1. stops the connection and returns if it has been unused for longer than
//     expiredConnectionTimeout;
//  2. sets a deadline of now + keepAliveTimeout on the TCP connection;
//  3. sends an SSH keepalive via testSSHClient and returns if that fails.
//
// The loop also returns once c.done is closed.
func (c *SSHConnection) keepalive(expiredConnectionTimeout time.Duration) {
	for {
		select {
		case <-time.After(c.keepAliveInterval):
			if c.terminateIfLifetimeExpired(expiredConnectionTimeout) {
				return
			}

			// Stop can run concurrently (e.g. after a failed command) and sets
			// tcpConn to nil, so take a snapshot under the lock that protects
			// it instead of dereferencing the field directly.
			c.mu.RLock()
			tcpConn := c.tcpConn
			c.mu.RUnlock()

			if tcpConn == nil {
				// The connection was torn down while we were waiting.
				return
			}

			// Best effort: if setting the deadline fails, the keepalive
			// request below surfaces the broken connection.
			// NOTE(review): the deadline applies to all I/O on the connection
			// and is not cleared here — confirm keepAliveTimeout is meant to
			// exceed keepAliveInterval.
			_ = tcpConn.SetDeadline(time.Now().Add(c.keepAliveTimeout))

			if !c.testSSHClient() {
				return
			}
		case <-c.done:
			return
		}
	}
}

func (c *SSHConnection) terminate() {
c.mu.Lock()
defer c.mu.Unlock()
// terminateIfLifetimeExpired stops the connection when the time since it was
// last used is greater than expiredConnectionTimeout. It reports whether the
// connection was stopped.
func (c *SSHConnection) terminateIfLifetimeExpired(expiredConnectionTimeout time.Duration) bool {
	idleFor := time.Since(c.GetLastUsed())
	if idleFor <= expiredConnectionTimeout {
		return false
	}

	c.Stop(fmt.Errorf("lifetime expired"))

	return true
}

// testSSHClient sends a keepalive request that asks for a reply over the SSH
// client and reports whether it succeeded. On failure it logs the error,
// stops the connection via Stop and returns false.
//
// NOTE(review): getSSHClient returns nil once Stop has cleared sshClient, and
// there is no nil check before SendRequest — confirm this cannot race with Stop.
// NOTE(review): the log call formats c.device with %s, while Stop logs
// c.device.Host — confirm which is intended.
// NOTE(review): the statements using c.conn and c.client do not match the
// fields used by the rest of this function and look like pre-change lines from
// a diff rendering — confirm against the repository.
func (c *SSHConnection) testSSHClient() bool {
sshClient := c.getSSHClient()

c.conn.Close()
_, _, err := sshClient.SendRequest("[email protected]", true, nil)
if err != nil {
log.Infof("SSH keepalive request to %s failed: %v", c.device, err)
c.Stop(fmt.Errorf("keepalive failed"))
return false
}

c.client = nil
c.conn = nil
return true
}

func (c *SSHConnection) close() {
// connect dials the device over TCP, performs the SSH handshake on that
// connection and, on success, stores the TCP connection and a new SSH client
// on c and marks it connected. It returns a wrapped error if the dial or the
// handshake fails; on a handshake failure the TCP connection is closed first.
//
// NOTE(review): the block that closes c.client does not match the fields set
// at the end of this function and looks like a pre-change line group from a
// diff rendering — confirm against the repository.
func (c *SSHConnection) connect() error {
// NOTE(review): InsecureIgnoreHostKey accepts any host key, so the device's
// identity is not verified.
// timeoutInSeconds is declared outside this view.
cfg := &ssh.ClientConfig{
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: timeoutInSeconds * time.Second,
}

// presumably fills in the authentication settings on cfg — Auth is defined outside this view
c.device.Auth(cfg)

host := tcpAddressForHost(c.device.Host)
log.Infof("Establishing TCP connection with %s", host)

tcpConn, err := net.DialTimeout("tcp", host, cfg.Timeout)
if err != nil {
return fmt.Errorf("could not open tcp connection: %w", err)
}

sshConn, chans, reqs, err := ssh.NewClientConn(tcpConn, host, cfg)
if err != nil {
// Do not leak the TCP connection when the handshake fails.
tcpConn.Close()
return fmt.Errorf("could not connect to device: %w", err)
}

// Publish the new connection state under the write lock.
c.mu.Lock()
defer c.mu.Unlock()

if c.client != nil {
c.client.Close()
}
c.tcpConn = tcpConn
c.sshClient = ssh.NewClient(sshConn, chans, reqs)
c.isConnected = true

return nil
}

// setLastUsed records t as the time the connection was last used, holding
// lastUsedMu for the duration of the write.
func (c *SSHConnection) setLastUsed(t time.Time) {
	c.lastUsedMu.Lock()
	c.lastUsed = t
	c.lastUsedMu.Unlock()
}

// GetLastUsed returns the time the connection was last used, read under the
// lastUsedMu read lock.
func (c *SSHConnection) GetLastUsed() time.Time {
	c.lastUsedMu.RLock()
	last := c.lastUsed
	c.lastUsedMu.RUnlock()

	return last
}

// getSSHClient returns the current SSH client, read under the mu read lock.
// The result is nil when no client is set.
func (c *SSHConnection) getSSHClient() *ssh.Client {
	c.mu.RLock()
	client := c.sshClient
	c.mu.RUnlock()

	return client
}

// IsConnected reports the value of the isConnected flag, read under the mu
// read lock.
//
// NOTE(review): the statements ahead of the return send on c.done and clear
// c.conn and c.client, fields the rest of this function does not use; they
// look like pre-change lines from a diff rendering — confirm against the
// repository.
func (c *SSHConnection) IsConnected() bool {
c.mu.RLock()
defer c.mu.RUnlock()

c.done <- struct{}{}
c.conn = nil
c.client = nil
return c.isConnected
}

// Host returns the hostname of the connected device
Expand Down
Loading

0 comments on commit bfdc266

Please sign in to comment.