diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index ac0e0b385c..ce535b9d8c 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -340,6 +340,16 @@ func (hsdb *HSDatabase) nodeSetExpiry(node *types.Node, expiry time.Time) error ) } + node.Expiry = &expiry + + stateSelfUpdate := types.StateUpdate{ + Type: types.StateSelfUpdate, + ChangeNodes: types.Nodes{node}, + } + if stateSelfUpdate.Valid() { + hsdb.notifier.NotifyByMachineKey(stateSelfUpdate, node.MachineKey) + } + stateUpdate := types.StateUpdate{ Type: types.StatePeerChangedPatch, ChangePatches: []*tailcfg.PeerChange{ @@ -350,7 +360,7 @@ func (hsdb *HSDatabase) nodeSetExpiry(node *types.Node, expiry time.Time) error }, } if stateUpdate.Valid() { - hsdb.notifier.NotifyAll(stateUpdate) + hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String()) } return nil @@ -856,7 +866,7 @@ func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time { // checked everything. started := time.Now() - expired := make([]*tailcfg.PeerChange, 0) + expiredNodes := make([]*types.Node, 0) nodes, err := hsdb.listNodes() if err != nil { @@ -872,17 +882,13 @@ func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time { // It will notify about all nodes that has been expired. // It should only notify about expired nodes since _last check_. node.Expiry.After(lastCheck) { - expired = append(expired, &tailcfg.PeerChange{ - NodeID: tailcfg.NodeID(node.ID), - KeyExpiry: node.Expiry, - }) + expiredNodes = append(expiredNodes, &nodes[index]) - now := time.Now() // Do not use setNodeExpiry as that has a notifier hook, which // can cause a deadlock, we are updating all changed nodes later // and there is no point in notifiying twice. if err := hsdb.db.Model(nodes[index]).Updates(types.Node{ - Expiry: &now, + Expiry: &started, }).Error; err != nil { log.Error(). Err(err). @@ -898,6 +904,15 @@ func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time { } } + expired := make([]*tailcfg.PeerChange, len(expiredNodes)) + for idx, node := range expiredNodes { + expired[idx] = &tailcfg.PeerChange{ + NodeID: tailcfg.NodeID(node.ID), + KeyExpiry: &started, + } + } + + // Inform the peers of a node with a lightweight update. stateUpdate := types.StateUpdate{ Type: types.StatePeerChangedPatch, ChangePatches: expired, @@ -906,5 +921,16 @@ func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time { hsdb.notifier.NotifyAll(stateUpdate) } + // Inform the node itself that it has expired. + for _, node := range expiredNodes { + stateSelfUpdate := types.StateUpdate{ + Type: types.StateSelfUpdate, + ChangeNodes: types.Nodes{node}, + } + if stateSelfUpdate.Valid() { + hsdb.notifier.NotifyByMachineKey(stateSelfUpdate, node.MachineKey) + } + } + return started } diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 0af569f26b..d6404ce109 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -21,7 +21,6 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "github.com/klauspost/compress/zstd" "github.com/rs/zerolog/log" - "github.com/samber/lo" "golang.org/x/exp/maps" "tailscale.com/envknob" "tailscale.com/smallzstd" @@ -595,15 +594,6 @@ func nodeMapToList(nodes map[uint64]*types.Node) types.Nodes { return ret } -func filterExpiredAndNotReady(peers types.Nodes) types.Nodes { - return lo.Filter(peers, func(item *types.Node, index int) bool { - // Filter out nodes that are expired OR - // nodes that has no endpoints, this typically means they have - // registered, but are not configured. - return !item.IsExpired() || len(item.Endpoints) > 0 - }) -} - // appendPeerChanges mutates a tailcfg.MapResponse with all the // necessary changes when peers have changed. func appendPeerChanges( @@ -629,9 +619,6 @@ func appendPeerChanges( return err } - // Filter out peers that have expired. - changed = filterExpiredAndNotReady(changed) - // If there are filter rules present, see if there are any nodes that cannot // access eachother at all and remove them from the peers. if len(rules) > 0 { diff --git a/integration/general_test.go b/integration/general_test.go index 2e0f7fe662..d06356c15c 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -560,7 +560,7 @@ func TestExpireNode(t *testing.T) { t.Logf("Node %s with node_key %s has been expired", node.GetName(), expiredNodeKey.String()) - time.Sleep(30 * time.Second) + time.Sleep(2 * time.Minute) now := time.Now() @@ -572,21 +572,33 @@ func TestExpireNode(t *testing.T) { if client.Hostname() != node.GetName() { t.Logf("available peers of %s: %v", client.Hostname(), status.Peers()) - // In addition to marking nodes expired, we filter them out during the map response - // this check ensures that the node is either not present, or that it is expired - // if it is in the map response. + // Ensures that the node is present, and that it is expired. if peerStatus, ok := status.Peer[expiredNodeKey]; ok { assertNotNil(t, peerStatus.Expired) - assert.Truef(t, peerStatus.KeyExpiry.Before(now), "node %s should have a key expire before %s, was %s", peerStatus.HostName, now.String(), peerStatus.KeyExpiry) - assert.Truef(t, peerStatus.Expired, "node %s should be expired, expired is %v", peerStatus.HostName, peerStatus.Expired) + assert.NotNil(t, peerStatus.KeyExpiry) + + t.Logf("node %q should have a key expire before %s, was %s", peerStatus.HostName, now.String(), peerStatus.KeyExpiry) + if peerStatus.KeyExpiry != nil { + assert.Truef(t, peerStatus.KeyExpiry.Before(now), "node %q should have a key expire before %s, was %s", peerStatus.HostName, now.String(), peerStatus.KeyExpiry) + } + + assert.Truef(t, peerStatus.Expired, "node %q should be expired, expired is %v", peerStatus.HostName, peerStatus.Expired) + + _, stderr, _ := client.Execute([]string{"tailscale", "ping", node.GetName()}) + if !strings.Contains(stderr, "node key has expired") { + t.Errorf("expected to be unable to ping expired host %q from %q", node.GetName(), client.Hostname()) + } + } else { + t.Errorf("failed to find node %q with nodekey (%s) in mapresponse, should be present even if it is expired", node.GetName(), expiredNodeKey) + } + } else { + if status.Self.KeyExpiry != nil { + assert.Truef(t, status.Self.KeyExpiry.Before(now), "node %q should have a key expire before %s, was %s", status.Self.HostName, now.String(), status.Self.KeyExpiry) } - // TODO(kradalby): We do not propogate expiry correctly, nodes should be aware - // of their status, and this should be sent directly to the node when its - // expired. This needs a notifier that goes directly to the node (currently we only do peers) - // so fix this in a follow up PR. - // } else { - // assert.True(t, status.Self.Expired) + // NeedsLogin means that the node has understood that it is no longer + // valid. + assert.Equal(t, "NeedsLogin", status.BackendState) } } }