diff --git a/connection_manager.go b/connection_manager.go index db5827415..373ac0268 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -5,6 +5,7 @@ import ( "time" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" ) // TODO: incount and outcount are intended as a shortcut to locking the mutexes for every single packet @@ -153,6 +154,26 @@ func (n *connectionManager) Run() { } } +// Check if peer's certificate is not expired or invalid. +func (n *connectionManager) checkToDisconnect(hostinfo *HostInfo) (bool, *cert.NebulaCertificate) { + if !n.intf.disconnectInvalid { + return false, nil + } + + if hostinfo == nil { + return false, nil + } + + if remoteCert := hostinfo.GetCert(); remoteCert != nil { + valid, _ := remoteCert.Verify(time.Now(), n.intf.caPool) + if !valid { + return true, remoteCert + } + } + + return false, nil +} + func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte) { n.TrafficTimer.advance(now) for { @@ -166,25 +187,40 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte) // Check for traffic coming back in from this host. traf := n.CheckIn(vpnIP) - // If we saw incoming packets from this ip, just return + hostinfo, err := n.hostMap.QueryVpnIP(vpnIP) + + disconnect_invalid, remoteCert := n.checkToDisconnect(hostinfo) + if disconnect_invalid { + n.l.WithField("vpnIp", IntIp(vpnIP)). + WithField("certName", remoteCert.Details.Name). + Debug("Invalid certificate status") + } + + // If we saw an incoming packets from this ip and peer's certificate is not + // expired, just ignore. if traf { if n.l.Level >= logrus.DebugLevel { n.l.WithField("vpnIp", IntIp(vpnIP)). WithField("tunnelCheck", m{"state": "alive", "method": "passive"}). Debug("Tunnel status") } - n.ClearIP(vpnIP) - n.ClearPendingDeletion(vpnIP) - continue + + if !disconnect_invalid { + n.ClearIP(vpnIP) + n.ClearPendingDeletion(vpnIP) + continue + } } // If we didn't we may need to probe or destroy the conn - hostinfo, err := n.hostMap.QueryVpnIP(vpnIP) if err != nil { n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP)) - n.ClearIP(vpnIP) - n.ClearPendingDeletion(vpnIP) - continue + + if !disconnect_invalid { + n.ClearIP(vpnIP) + n.ClearPendingDeletion(vpnIP) + continue + } } hostinfo.logger(n.l). @@ -213,18 +249,23 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) { vpnIP := ep.(uint32) + hostinfo, err := n.hostMap.QueryVpnIP(vpnIP) + // If we saw incoming packets from this ip, just return traf := n.CheckIn(vpnIP) if traf { n.l.WithField("vpnIp", IntIp(vpnIP)). WithField("tunnelCheck", m{"state": "alive", "method": "active"}). Debug("Tunnel status") - n.ClearIP(vpnIP) - n.ClearPendingDeletion(vpnIP) - continue + + disconnect_invalid, _ := n.checkToDisconnect(hostinfo) + if !disconnect_invalid { + n.ClearIP(vpnIP) + n.ClearPendingDeletion(vpnIP) + continue + } } - hostinfo, err := n.hostMap.QueryVpnIP(vpnIP) if err != nil { n.ClearIP(vpnIP) n.ClearPendingDeletion(vpnIP) diff --git a/examples/config.yml b/examples/config.yml index 7d4cf2373..317e0c45d 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -10,6 +10,8 @@ pki: #blocklist is a list of certificate fingerprints that we will refuse to talk to #blocklist: # - c99d4e650533b92061b09918e838a5a0a6aaee21eed1d12fd937682865936c72 + #disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid. + #disconnect_invalid: false # The static host map defines a set of hosts with fixed IP addresses on the internet (or any network). # A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel. diff --git a/interface.go b/interface.go index 108ca05b6..9ea6c3b2b 100644 --- a/interface.go +++ b/interface.go @@ -43,6 +43,7 @@ type InterfaceConfig struct { MessageMetrics *MessageMetrics version string caPool *cert.NebulaCAPool + disconnectInvalid bool ConntrackCacheTimeout time.Duration l *logrus.Logger @@ -67,6 +68,7 @@ type Interface struct { udpBatchSize int routines int caPool *cert.NebulaCAPool + disconnectInvalid bool // rebindCount is used to decide if an active tunnel should trigger a punch notification through a lighthouse rebindCount int8 @@ -118,6 +120,7 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) { writers: make([]*udpConn, c.routines), readers: make([]io.ReadWriteCloser, c.routines), caPool: c.caPool, + disconnectInvalid: c.disconnectInvalid, myVpnIp: ip2int(c.certState.certificate.Details.Ips[0].IP), conntrackCacheTimeout: c.ConntrackCacheTimeout, diff --git a/main.go b/main.go index a77559926..74f06b155 100644 --- a/main.go +++ b/main.go @@ -371,6 +371,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L MessageMetrics: messageMetrics, version: buildVersion, caPool: caPool, + disconnectInvalid: config.GetBool("pki.disconnect_invalid", false), ConntrackCacheTimeout: conntrackCacheTimeout, l: l,