Rework SSH connection handling: every SSHConnection now owns its TCP/SSH
lifecycle and keepalive loop (Start/Stop), and SSHConnectionManager guards its
connection map with a single RWMutex instead of a per-device lock map.
SSHConnectionManager.Connect is renamed to GetSSHConnection and Close to
CloseAll; the automatic reconnect loop is removed in favour of creating a
fresh connection on the next request.

diff --git a/junos_collector.go b/junos_collector.go
index 755d663..d9e7024 100644
--- a/junos_collector.go
+++ b/junos_collector.go
@@ -77,7 +77,7 @@ func deviceInterfaceRegex(cfg *config.Config, host string) *regexp.Regexp {
 }
 
 func clientForDevice(device *connector.Device, connManager *connector.SSHConnectionManager) (*rpc.Client, error) {
-	conn, err := connManager.Connect(device)
+	conn, err := connManager.GetSSHConnection(device)
 	if err != nil {
 		return nil, err
 	}
diff --git a/main.go b/main.go
index b1ab6af..65b6d70 100644
--- a/main.go
+++ b/main.go
@@ -165,7 +165,7 @@ func initChannels(ctx context.Context) {
 
 func shutdown() {
 	log.Infoln("Closing connections to devices")
-	connManager.Close()
+	connManager.CloseAll()
 	os.Exit(0)
 }
 
@@ -198,7 +198,7 @@ func reinitialize() error {
 	defer configMu.Unlock()
 
 	if connManager != nil {
-		connManager.Close()
+		connManager.CloseAll()
 		connManager = nil
 	}
 
diff --git a/pkg/connector/connection.go b/pkg/connector/connection.go
index 5ef3de7..7891a4c 100644
--- a/pkg/connector/connection.go
+++ b/pkg/connector/connection.go
@@ -10,34 +10,93 @@ import (
 	"time"
 
 	"github.com/pkg/errors"
+	log "github.com/sirupsen/logrus"
 
 	"golang.org/x/crypto/ssh"
 )
 
 // SSHConnection encapsulates the connection to the device
 type SSHConnection struct {
-	device   *Device
-	client   *ssh.Client
-	conn     net.Conn
-	lastUsed time.Time
-	mu       sync.Mutex
-	done     chan struct{}
+	device            *Device
+	sshClient         *ssh.Client
+	tcpConn           net.Conn
+	isConnected       bool
+	mu                sync.RWMutex // protects sshClient, tcpConn and isConnected
+	lastUsed          time.Time
+	lastUsedMu        sync.RWMutex
+	done              chan struct{}
+	keepAliveInterval time.Duration
+	keepAliveTimeout  time.Duration
 }
 
-// RunCommand runs a command against the device
-func (c *SSHConnection) RunCommand(cmd string) ([]byte, error) {
+// NewSSHConnection returns an SSHConnection for device that is not connected yet;
+// call Start to establish the connection and launch its keepalive loop.
+func NewSSHConnection(device *Device, keepAliveInterval time.Duration, keepAliveTimeout time.Duration) *SSHConnection {
+	return &SSHConnection{
+		device:            device,
+		keepAliveInterval: keepAliveInterval,
+		keepAliveTimeout:  keepAliveTimeout,
+		done:              make(chan struct{}),
+	}
+}
+
+// Start connects to the device and, on success, launches the keepalive goroutine that
+// stops the connection once it was unused for longer than expiredConnectionTimeout.
+func (c *SSHConnection) Start(expiredConnectionTimeout time.Duration) error {
+	err := c.connect()
+	if err != nil {
+		return err
+	}
+
+	// Count the fresh connection as used so the first keepalive tick does not expire it.
+	c.setLastUsed(time.Now())
+
+	go c.keepalive(expiredConnectionTimeout)
+	return nil
+}
+
+// Stop closes the SSH client and the TCP connection and signals the keepalive goroutine
+// to exit. On an unconnected SSHConnection it only logs; err is the logged reason.
+func (c *SSHConnection) Stop(err error) {
+	log.Infof("Stopping SSH connection with %s (reason: %v)", c.device.Host, err)
+
 	c.mu.Lock()
 	defer c.mu.Unlock()
 
-	c.lastUsed = time.Now()
+	if !c.isConnected {
+		return
+	}
+
+	close(c.done)
 
-	if c.client == nil {
-		return nil, errors.New(fmt.Sprintf("not connected with %s", c.conn.RemoteAddr().String()))
+	if c.sshClient != nil {
+		c.sshClient.Close()
+		c.sshClient = nil
 	}
 
-	session, err := c.client.NewSession()
+	if c.tcpConn != nil {
+		c.tcpConn.Close()
+		c.tcpConn = nil
+	}
+
+	c.isConnected = false
+}
+
+// RunCommand runs a command against the device
+func (c *SSHConnection) RunCommand(cmd string) ([]byte, error) {
+	c.setLastUsed(time.Now())
+
+	sshClient := c.getSSHClient()
+	if sshClient == nil {
+		c.Stop(fmt.Errorf("no SSH client"))
+		return nil, errors.Errorf("no SSH client to %s", c.device.Host)
+	}
+
+	// Use the snapshot taken above: c.sshClient may be reset to nil by a concurrent Stop.
+	session, err := sshClient.NewSession()
 	if err != nil {
-		return nil, errors.Wrapf(err, "could not open session with %s", c.conn.RemoteAddr().String())
+		c.Stop(fmt.Errorf("SSH session failure"))
+		return nil, errors.Wrapf(err, "could not open session with %s", c.device.Host)
 	}
 	defer session.Close()
 
@@ -46,37 +105,144 @@ func (c *SSHConnection) RunCommand(cmd string) ([]byte, error) {
 
 	err = session.Run(cmd)
 	if err != nil {
-		return nil, errors.Wrapf(err, "could not run command %q on %s", cmd, c.conn.RemoteAddr().String())
+		c.Stop(fmt.Errorf("failed running command"))
+		return nil, errors.Wrapf(err, "could not run command %q on %s", cmd, c.device.Host)
 	}
 
 	return b.Bytes(), nil
 }
 
-func (c *SSHConnection) isConnected() bool {
-	return c.conn != nil
+// keepalive probes the device every keepAliveInterval and returns when the connection is
+// stopped, was unused for longer than expiredConnectionTimeout, or a probe fails.
+func (c *SSHConnection) keepalive(expiredConnectionTimeout time.Duration) {
+	for {
+		select {
+		case <-time.After(c.keepAliveInterval):
+			terminated := c.terminateIfLifetimeExpired(expiredConnectionTimeout)
+			if terminated {
+				return
+			}
+
+			tcpConn := c.getTCPConn()
+			if tcpConn == nil {
+				// Stop ran concurrently and already released the TCP connection.
+				return
+			}
+
+			// Limit how long the probe below may block; the result is deliberately ignored.
+			_ = tcpConn.SetDeadline(time.Now().Add(c.keepAliveTimeout))
+
+			ok := c.testSSHClient()
+			if !ok {
+				return
+			}
+		case <-c.done:
+			return
+		}
+	}
 }
 
-func (c *SSHConnection) terminate() {
-	c.mu.Lock()
-	defer c.mu.Unlock()
+// terminateIfLifetimeExpired stops the connection and returns true when it has not been
+// used for longer than expiredConnectionTimeout; otherwise it returns false.
+func (c *SSHConnection) terminateIfLifetimeExpired(expiredConnectionTimeout time.Duration) bool {
+	if time.Since(c.GetLastUsed()) > expiredConnectionTimeout {
+		c.Stop(fmt.Errorf("lifetime expired"))
+		return true
+	}
+
+	return false
+}
+
+// testSSHClient sends a keepalive request over the SSH client. It returns false when no
+// client is available or the request fails; a failed request also stops the connection.
+func (c *SSHConnection) testSSHClient() bool {
+	sshClient := c.getSSHClient()
+	if sshClient == nil {
+		return false
+	}
 
-	c.conn.Close()
+	_, _, err := sshClient.SendRequest("keepalive@golang.org", true, nil)
+	if err != nil {
+		log.Infof("SSH keepalive request to %s failed: %v", c.device.Host, err)
+		c.Stop(fmt.Errorf("keepalive failed"))
+		return false
+	}
 
-	c.client = nil
-	c.conn = nil
+	return true
 }
 
-func (c *SSHConnection) close() {
+// connect dials the device, performs the SSH handshake and, on success, stores the TCP
+// connection and SSH client and marks the connection as connected.
+func (c *SSHConnection) connect() error {
+	cfg := &ssh.ClientConfig{
+		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
+		Timeout:         timeoutInSeconds * time.Second,
+	}
+
+	c.device.Auth(cfg)
+
+	host := tcpAddressForHost(c.device.Host)
+	log.Infof("Establishing TCP connection with %s", host)
+
+	tcpConn, err := net.DialTimeout("tcp", host, cfg.Timeout)
+	if err != nil {
+		return fmt.Errorf("could not open tcp connection: %w", err)
+	}
+
+	sshConn, chans, reqs, err := ssh.NewClientConn(tcpConn, host, cfg)
+	if err != nil {
+		tcpConn.Close()
+		return fmt.Errorf("could not connect to device: %w", err)
+	}
+
 	c.mu.Lock()
 	defer c.mu.Unlock()
 
-	if c.client != nil {
-		c.client.Close()
-	}
+	c.tcpConn = tcpConn
+	c.sshClient = ssh.NewClient(sshConn, chans, reqs)
+	c.isConnected = true
+
+	return nil
+}
+
+// setLastUsed records t as the time the connection was last used.
+func (c *SSHConnection) setLastUsed(t time.Time) {
+	c.lastUsedMu.Lock()
+	defer c.lastUsedMu.Unlock()
+
+	c.lastUsed = t
+}
+
+// GetLastUsed returns the time the connection was last used.
+func (c *SSHConnection) GetLastUsed() time.Time {
+	c.lastUsedMu.RLock()
+	defer c.lastUsedMu.RUnlock()
+
+	return c.lastUsed
+}
+
+// getSSHClient returns the current SSH client, or nil once the connection was stopped.
+func (c *SSHConnection) getSSHClient() *ssh.Client {
+	c.mu.RLock()
+	defer c.mu.RUnlock()
+
+	return c.sshClient
+}
+
+// getTCPConn returns the current TCP connection, or nil once the connection was stopped.
+func (c *SSHConnection) getTCPConn() net.Conn {
+	c.mu.RLock()
+	defer c.mu.RUnlock()
+
+	return c.tcpConn
+}
+
+// IsConnected reports whether the connection is currently established.
+func (c *SSHConnection) IsConnected() bool {
+	c.mu.RLock()
+	defer c.mu.RUnlock()
 
-	c.done <- struct{}{}
-	c.conn = nil
-	c.client = nil
+	return c.isConnected
 }
 
 // Host returns the hostname of the connected device
diff --git a/pkg/connector/connection_manager.go b/pkg/connector/connection_manager.go
index 7806a57..ad7d943 100644
--- a/pkg/connector/connection_manager.go
+++ b/pkg/connector/connection_manager.go
@@ -3,14 +3,12 @@
 package connector
 
 import (
+	"fmt"
 	"net"
 	"strings"
 	"sync"
 	"time"
 
-	"github.com/pkg/errors"
-	"golang.org/x/crypto/ssh"
-
 	log "github.com/sirupsen/logrus"
 )
 
@@ -51,11 +49,11 @@ func WithExpiredConnectionTimeout(d time.Duration) Option {
 // SSHConnectionManager manages SSH connections to different devices
 type SSHConnectionManager struct {
 	connections              map[string]*SSHConnection
+	connectionsMu            sync.RWMutex
 	reconnectInterval        time.Duration
 	keepAliveInterval        time.Duration
 	keepAliveTimeout         time.Duration
 	expiredConnectionTimeout time.Duration
-	locks                    map[string]*sync.Mutex
 }
 
 // NewConnectionManager creates a new connection manager
@@ -65,7 +63,6 @@ func NewConnectionManager(opts ...Option) *SSHConnectionManager {
 		reconnectInterval: 30 * time.Second,
 		keepAliveInterval: 10 * time.Second,
 		keepAliveTimeout:  15 * time.Second,
-		locks:             make(map[string]*sync.Mutex),
 	}
 
 	for _, opt := range opts {
@@ -75,80 +72,51 @@ func NewConnectionManager(opts ...Option) *SSHConnectionManager {
 	return m
 }
 
-func (m *SSHConnectionManager) lockForDevice(device *Device) *sync.Mutex {
-	if mu, exists := m.locks[device.Host]; exists {
-		return mu
+// GetSSHConnection gets a cached SSHConnection or creates a fresh one, if necessary
+func (m *SSHConnectionManager) GetSSHConnection(device *Device) (*SSHConnection, error) {
+	connection := m.getExistingConnection(device)
+	if connection != nil {
+		log.Infof("Re-using existing connection with %s", device.Host)
+		return connection, nil
 	}
 
-	mu := &sync.Mutex{}
-	m.locks[device.Host] = mu
-	return mu
+	return m.connect(device)
 }
 
-// Connect connects to a device or returns an long living connection
-func (m *SSHConnectionManager) Connect(device *Device) (*SSHConnection, error) {
-	if connection, found := m.connections[device.Host]; found {
-		if connection.isConnected() {
-			return connection, nil
-		}
-	}
-
-	mu := m.lockForDevice(device)
-	mu.Lock()
-	defer mu.Unlock()
+func (m *SSHConnectionManager) getExistingConnection(device *Device) *SSHConnection {
+	m.connectionsMu.RLock()
+	defer m.connectionsMu.RUnlock()
 
 	if connection, found := m.connections[device.Host]; found {
-		if connection.isConnected() {
-			return connection, nil
+		if connection.IsConnected() {
+			return connection
 		}
 	}
 
-	return m.connect(device)
+	return nil
 }
 
 func (m *SSHConnectionManager) connect(device *Device) (*SSHConnection, error) {
-	client, conn, err := m.connectToDevice(device)
+	log.Infof("Creating SSH connection with %s", device.Host)
+	c := NewSSHConnection(device, m.keepAliveInterval, m.keepAliveTimeout)
+	err := c.Start(m.expiredConnectionTimeout)
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("unable to get new SSH connection: %w", err)
 	}
 
-	c := &SSHConnection{
-		conn:   conn,
-		client: client,
-		device: device,
-		done:   make(chan struct{}),
+	m.connectionsMu.Lock()
+	defer m.connectionsMu.Unlock()
+
+	if existingCon, exists := m.connections[device.Host]; exists && existingCon.IsConnected() {
+		c.Stop(fmt.Errorf("connection conflict"))
+		return existingCon, nil
 	}
-	go m.keepAlive(c)
 
 	m.connections[device.Host] = c
-
 	return c, nil
 }
 
-func (m *SSHConnectionManager) connectToDevice(device *Device) (*ssh.Client, net.Conn, error) {
-	cfg := &ssh.ClientConfig{
-		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
-		Timeout:         timeoutInSeconds * time.Second,
-	}
-
-	device.Auth(cfg)
-
-	host := m.tcpAddressForHost(device.Host)
-
-	conn, err := net.DialTimeout("tcp", host, cfg.Timeout)
-	if err != nil {
-		return nil, nil, errors.Wrap(err, "could not open tcp connection")
-	}
-
-	c, chans, reqs, err := ssh.NewClientConn(conn, host, cfg)
-	if err != nil {
-		return nil, nil, errors.Wrap(err, "could not connect to device")
-	}
-
-	return ssh.NewClient(c, chans, reqs), conn, nil
-}
-
-func (m *SSHConnectionManager) tcpAddressForHost(host string) string {
+func tcpAddressForHost(host string) string {
 	colonCount := strings.Count(host, ":")
 
 	if colonCount == 0 {
@@ -158,13 +126,13 @@ func (m *SSHConnectionManager) tcpAddressForHost(host string) string {
 
 	h, p, err := net.SplitHostPort(host)
 	if err == nil {
-		return m.formatHost(h) + ":" + p
+		return formatHost(h) + ":" + p
 	}
 
-	return m.formatHost(host) + ":" + defaultPort
+	return formatHost(host) + ":" + defaultPort
 }
 
-func (m *SSHConnectionManager) formatHost(host string) string {
+func formatHost(host string) string {
 	ip := net.ParseIP(host)
 
 	if ip == nil || ip.To4() != nil {
@@ -175,47 +143,13 @@ func (m *SSHConnectionManager) formatHost(host string) string {
 	return "[" + host + "]"
 }
 
-func (m *SSHConnectionManager) keepAlive(connection *SSHConnection) {
-	for {
-		select {
-		case <-time.After(m.keepAliveInterval):
-			if time.Since(connection.lastUsed) > m.expiredConnectionTimeout {
-				connection.terminate()
-				return
-			}
-
-			log.Debugf("Sending keepalive for ")
-			connection.conn.SetDeadline(time.Now().Add(m.keepAliveTimeout))
-			_, _, err := connection.client.SendRequest("keepalive@golang.org", true, nil)
-			if err != nil {
-				log.Infof("Lost connection to %s (%v). Trying to reconnect...", connection.device, err)
-				connection.terminate()
-				m.reconnect(connection)
-			}
-		case <-connection.done:
-			return
-		}
-	}
-}
-
-func (m *SSHConnectionManager) reconnect(connection *SSHConnection) {
-	for {
-		client, conn, err := m.connectToDevice(connection.device)
-		if err == nil {
-			connection.client = client
-			connection.conn = conn
-			return
-		}
-
-		log.Infof("Reconnect to %s failed: %v", connection.device, err)
-		time.Sleep(m.reconnectInterval)
-	}
-}
-
-// Close closes all TCP connections and stop keep alives
-func (m *SSHConnectionManager) Close() error {
+// CloseAll closes all TCP connections and stops keep alives
+func (m *SSHConnectionManager) CloseAll() error {
+	m.connectionsMu.RLock()
+	defer m.connectionsMu.RUnlock()
+
 	for _, c := range m.connections {
-		c.close()
+		c.Stop(fmt.Errorf("end of world"))
 	}
 
 	return nil
diff --git a/pkg/connector/connection_manager_test.go b/pkg/connector/connection_manager_test.go
index 005021c..435ef57 100644
--- a/pkg/connector/connection_manager_test.go
+++ b/pkg/connector/connection_manager_test.go
@@ -55,8 +55,7 @@ func TestTCPAddressForHost(t *testing.T) {
 
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
-			m := NewConnectionManager()
-			assert.Equal(t, test.expected, m.tcpAddressForHost(test.host))
+			assert.Equal(t, test.expected, tcpAddressForHost(test.host))
 		})
 	}
 }