From cb8c2fc91a08f0e6449e6fb8e51bcecd9f6e5e46 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 12 Jul 2024 16:45:11 +0200 Subject: [PATCH 1/7] replace ephemeral deletion logic this commit replaces the way we remove ephemeral nodes, currently they are deleted in a loop and we look at last seen time. This time is now only set when a node disconnects and there was a bug (#2006) where nodes that had never disconnected was deleted since they did not have a last seen. The new logic will start an expiry timer when the node disconnects and delete the node from the database when the timer is up. If the node reconnects within the expiry, the timer is cancelled. Fixes #2006 Signed-off-by: Kristoffer Dalby --- .github/workflows/test-integration.yaml | 1 + hscontrol/app.go | 54 ++-------- hscontrol/db/node.go | 125 +++++++++++++++--------- hscontrol/db/node_test.go | 65 ++++++++++++ hscontrol/db/preauth_keys_test.go | 72 -------------- hscontrol/poll.go | 15 +++ integration/general_test.go | 116 ++++++++++++++++++++++ 7 files changed, 286 insertions(+), 162 deletions(-) diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index 9581badaef..2c3d028849 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -40,6 +40,7 @@ jobs: - TestPingAllByIPPublicDERP - TestAuthKeyLogoutAndRelogin - TestEphemeral + - TestEphemeral2006DeletedTooQuickly - TestPingAllByHostname - TestTaildrop - TestResolveMagicDNS diff --git a/hscontrol/app.go b/hscontrol/app.go index 253c2671b1..2862c60e1b 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -90,6 +90,7 @@ type Headscale struct { db *db.HSDatabase ipAlloc *db.IPAllocator noisePrivateKey *key.MachinePrivate + ephemeralGC *db.EphemeralGarbageCollector DERPMap *tailcfg.DERPMap DERPServer *derpServer.DERPServer @@ -152,6 +153,12 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { return nil, err } + app.ephemeralGC = db.NewEphemeralGarbageCollector(func(ni types.NodeID) { + if err := app.db.DeleteEphemeralNode(ni); err != nil { + log.Err(err).Uint64("node.id", ni.Uint64()).Msgf("failed to delete ephemeral node") + } + }) + if cfg.OIDC.Issuer != "" { err = app.initOIDC() if err != nil { @@ -216,47 +223,6 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) { http.Redirect(w, req, target, http.StatusFound) } -// deleteExpireEphemeralNodes deletes ephemeral node records that have not been -// seen for longer than h.cfg.EphemeralNodeInactivityTimeout. -func (h *Headscale) deleteExpireEphemeralNodes(ctx context.Context, every time.Duration) { - ticker := time.NewTicker(every) - - for { - select { - case <-ctx.Done(): - ticker.Stop() - return - case <-ticker.C: - var removed []types.NodeID - var changed []types.NodeID - if err := h.db.Write(func(tx *gorm.DB) error { - removed, changed = db.DeleteExpiredEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout) - - return nil - }); err != nil { - log.Error().Err(err).Msg("database error while expiring ephemeral nodes") - continue - } - - if removed != nil { - ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na") - h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ - Type: types.StatePeerRemoved, - Removed: removed, - }) - } - - if changed != nil { - ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na") - h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: changed, - }) - } - } - } -} - // expireExpiredNodes expires nodes that have an explicit expiry set // after that expiry time has passed. func (h *Headscale) expireExpiredNodes(ctx context.Context, every time.Duration) { @@ -552,9 +518,7 @@ func (h *Headscale) Serve() error { return errEmptyInitialDERPMap } - expireEphemeralCtx, expireEphemeralCancel := context.WithCancel(context.Background()) - defer expireEphemeralCancel() - go h.deleteExpireEphemeralNodes(expireEphemeralCtx, updateInterval) + go h.ephemeralGC.Start() expireNodeCtx, expireNodeCancel := context.WithCancel(context.Background()) defer expireNodeCancel() @@ -810,7 +774,7 @@ func (h *Headscale) Serve() error { Msg("Received signal to stop, shutting down gracefully") expireNodeCancel() - expireEphemeralCancel() + h.ephemeralGC.Close() trace("waiting for netmap stream to close") h.pollNetMapStreamWG.Wait() diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index e36d6ed131..d0f0fa48d9 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -12,6 +12,7 @@ import ( "github.com/patrickmn/go-cache" "github.com/puzpuzpuz/xsync/v3" "github.com/rs/zerolog/log" + "github.com/sasha-s/go-deadlock" "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" @@ -286,6 +287,20 @@ func DeleteNode(tx *gorm.DB, return changed, nil } +// DeleteEphemeralNode deletes a Node from the database, note that this method +// will remove it straight, and not notify any changes or consider any routes. +// It is intended for Ephemeral nodes. +func (hsdb *HSDatabase) DeleteEphemeralNode( + nodeID types.NodeID, +) error { + return hsdb.Write(func(tx *gorm.DB) error { + if err := tx.Unscoped().Delete(&types.Node{}, nodeID).Error; err != nil { + return err + } + return nil + }) +} + // SetLastSeen sets a node's last seen field indicating that we // have recently communicating with this node. func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error { @@ -660,51 +675,6 @@ func GenerateGivenName( return givenName, nil } -func DeleteExpiredEphemeralNodes(tx *gorm.DB, - inactivityThreshold time.Duration, -) ([]types.NodeID, []types.NodeID) { - users, err := ListUsers(tx) - if err != nil { - return nil, nil - } - - var expired []types.NodeID - var changedNodes []types.NodeID - for _, user := range users { - nodes, err := ListNodesByUser(tx, user.Name) - if err != nil { - return nil, nil - } - - for idx, node := range nodes { - if node.IsEphemeral() && node.LastSeen != nil && - time.Now(). - After(node.LastSeen.Add(inactivityThreshold)) { - expired = append(expired, node.ID) - - log.Info(). - Str("node", node.Hostname). - Msg("Ephemeral client removed from database") - - // empty isConnected map as ephemeral nodes are not routes - changed, err := DeleteNode(tx, nodes[idx], nil) - if err != nil { - log.Error(). - Err(err). - Str("node", node.Hostname). - Msg("🤮 Cannot delete ephemeral node from the database") - } - - changedNodes = append(changedNodes, changed...) - } - } - - // TODO(kradalby): needs to be moved out of transaction - } - - return expired, changedNodes -} - func ExpireExpiredNodes(tx *gorm.DB, lastCheck time.Time, ) (time.Time, types.StateUpdate, bool) { @@ -737,3 +707,68 @@ func ExpireExpiredNodes(tx *gorm.DB, return started, types.StateUpdate{}, false } + +type EphemeralGarbageCollector struct { + mu deadlock.Mutex + + deleteFunc func(types.NodeID) + toBeDeleted map[types.NodeID]*time.Timer + + deleteCh chan types.NodeID + cancelCh chan struct{} +} + +func NewEphemeralGarbageCollector(deleteFunc func(types.NodeID)) *EphemeralGarbageCollector { + return &EphemeralGarbageCollector{ + toBeDeleted: make(map[types.NodeID]*time.Timer), + deleteCh: make(chan types.NodeID, 10), + cancelCh: make(chan struct{}), + deleteFunc: deleteFunc, + } +} + +func (e *EphemeralGarbageCollector) Close() { + e.cancelCh <- struct{}{} +} + +func (e *EphemeralGarbageCollector) Schedule(nodeID types.NodeID, expiry time.Duration) { + e.mu.Lock() + defer e.mu.Unlock() + + timer := time.NewTimer(expiry) + e.toBeDeleted[nodeID] = timer + + go func() { + select { + case _, ok := <-timer.C: + if ok { + e.deleteCh <- nodeID + } + } + }() +} + +func (e *EphemeralGarbageCollector) Cancel(nodeID types.NodeID) { + e.mu.Lock() + defer e.mu.Unlock() + + if timer, ok := e.toBeDeleted[nodeID]; ok { + timer.Stop() + delete(e.toBeDeleted, nodeID) + } +} + +func (e *EphemeralGarbageCollector) Start() { + for { + select { + case <-e.cancelCh: + return + case nodeID := <-e.deleteCh: + e.mu.Lock() + delete(e.toBeDeleted, nodeID) + e.mu.Unlock() + + go e.deleteFunc(nodeID) + } + } +} diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index e95ee4ae33..b164e526cf 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -1,13 +1,17 @@ package db import ( + "crypto/rand" "fmt" + "math/big" "net/netip" "regexp" "strconv" + "sync" "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" @@ -599,3 +603,64 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { c.Assert(err, check.IsNil) c.Assert(enabledRoutes, check.HasLen, 4) } + +func TestEphemeralGarbageCollectorOrder(t *testing.T) { + want := []types.NodeID{1, 3} + got := []types.NodeID{} + + e := NewEphemeralGarbageCollector(func(ni types.NodeID) { + got = append(got, ni) + }) + go e.Start() + + e.Schedule(1, 1*time.Second) + e.Schedule(2, 2*time.Second) + e.Schedule(3, 3*time.Second) + e.Schedule(4, 4*time.Second) + e.Cancel(2) + e.Cancel(4) + + time.Sleep(6 * time.Second) + + e.Close() + + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("wrong nodes deleted, unexpected result (-want +got):\n%s", diff) + } +} + +func TestEphemeralGarbageCollectorLoads(t *testing.T) { + var got []types.NodeID + var mu sync.Mutex + + want := 1000 + + e := NewEphemeralGarbageCollector(func(ni types.NodeID) { + defer mu.Unlock() + mu.Lock() + + time.Sleep(time.Duration(generateRandomNumber(t, 3)) * time.Millisecond) + got = append(got, ni) + }) + go e.Start() + + for i := 0; i < want; i++ { + go e.Schedule(types.NodeID(i), 1*time.Second) + } + + time.Sleep(10 * time.Second) + + if len(got) != want { + t.Errorf("expected %d, got %d", want, len(got)) + } +} + +func generateRandomNumber(t *testing.T, max int64) int64 { + t.Helper() + maxB := big.NewInt(max) + n, err := rand.Int(rand.Reader, maxB) + if err != nil { + t.Fatalf("getting random number: %s", err) + } + return n.Int64() + 1 +} diff --git a/hscontrol/db/preauth_keys_test.go b/hscontrol/db/preauth_keys_test.go index 9cdcba8044..88d5b213f5 100644 --- a/hscontrol/db/preauth_keys_test.go +++ b/hscontrol/db/preauth_keys_test.go @@ -6,7 +6,6 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "gopkg.in/check.v1" - "gorm.io/gorm" ) func (*Suite) TestCreatePreAuthKey(c *check.C) { @@ -127,77 +126,6 @@ func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) { c.Assert(key.ID, check.Equals, pak.ID) } -func (*Suite) TestEphemeralKeyReusable(c *check.C) { - user, err := db.CreateUser("test7") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, true, true, nil, nil) - c.Assert(err, check.IsNil) - - now := time.Now().Add(-time.Second * 30) - pakID := uint(pak.ID) - node := types.Node{ - ID: 0, - Hostname: "testest", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - LastSeen: &now, - AuthKeyID: &pakID, - } - trx := db.DB.Save(&node) - c.Assert(trx.Error, check.IsNil) - - _, err = db.ValidatePreAuthKey(pak.Key) - c.Assert(err, check.IsNil) - - _, err = db.getNode("test7", "testest") - c.Assert(err, check.IsNil) - - db.Write(func(tx *gorm.DB) error { - DeleteExpiredEphemeralNodes(tx, time.Second*20) - return nil - }) - - // The machine record should have been deleted - _, err = db.getNode("test7", "testest") - c.Assert(err, check.NotNil) -} - -func (*Suite) TestEphemeralKeyNotReusable(c *check.C) { - user, err := db.CreateUser("test7") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, true, nil, nil) - c.Assert(err, check.IsNil) - - now := time.Now().Add(-time.Second * 30) - pakId := uint(pak.ID) - node := types.Node{ - ID: 0, - Hostname: "testest", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - LastSeen: &now, - AuthKeyID: &pakId, - } - db.DB.Save(&node) - - _, err = db.ValidatePreAuthKey(pak.Key) - c.Assert(err, check.NotNil) - - _, err = db.getNode("test7", "testest") - c.Assert(err, check.IsNil) - - db.Write(func(tx *gorm.DB) error { - DeleteExpiredEphemeralNodes(tx, time.Second*20) - return nil - }) - - // The machine record should have been deleted - _, err = db.getNode("test7", "testest") - c.Assert(err, check.NotNil) -} - func (*Suite) TestExpirePreauthKey(c *check.C) { user, err := db.CreateUser("test3") c.Assert(err, check.IsNil) diff --git a/hscontrol/poll.go b/hscontrol/poll.go index d3c8211769..8122064b6d 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -135,6 +135,18 @@ func (m *mapSession) resetKeepAlive() { m.keepAliveTicker.Reset(m.keepAlive) } +func (m *mapSession) beforeServeLongPoll() { + if m.node.IsEphemeral() { + m.h.ephemeralGC.Cancel(m.node.ID) + } +} + +func (m *mapSession) afterServeLongPoll() { + if m.node.IsEphemeral() { + m.h.ephemeralGC.Schedule(m.node.ID, m.h.cfg.EphemeralNodeInactivityTimeout) + } +} + // serve handles non-streaming requests. func (m *mapSession) serve() { // TODO(kradalby): A set todos to harden: @@ -180,6 +192,8 @@ func (m *mapSession) serve() { // //nolint:gocyclo func (m *mapSession) serveLongPoll() { + m.beforeServeLongPoll() + // Clean up the session when the client disconnects defer func() { m.cancelChMu.Lock() @@ -197,6 +211,7 @@ func (m *mapSession) serveLongPoll() { m.pollFailoverRoutes("node closing connection", m.node) } + m.afterServeLongPoll() m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch) }() diff --git a/integration/general_test.go b/integration/general_test.go index 245e8f096a..c17b977e00 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -297,6 +297,122 @@ func TestEphemeral(t *testing.T) { } } +// TestEphemeral2006DeletedTooQuickly verifies that ephemeral nodes are not +// deleted by accident if they are still online and active. +func TestEphemeral2006DeletedTooQuickly(t *testing.T) { + IntegrationSkip(t) + t.Parallel() + + scenario, err := NewScenario(dockertestMaxWait()) + assertNoErr(t, err) + defer scenario.Shutdown() + + spec := map[string]int{ + "user1": len(MustTestVersions), + "user2": len(MustTestVersions), + } + + headscale, err := scenario.Headscale( + hsic.WithTestName("ephemeral2006"), + hsic.WithConfigEnv(map[string]string{ + "HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "1m6s", + }), + ) + assertNoErrHeadscaleEnv(t, err) + + for userName, clientCount := range spec { + err = scenario.CreateUser(userName) + if err != nil { + t.Fatalf("failed to create user %s: %s", userName, err) + } + + err = scenario.CreateTailscaleNodesInUser(userName, "all", clientCount, []tsic.Option{}...) + if err != nil { + t.Fatalf("failed to create tailscale nodes in user %s: %s", userName, err) + } + + key, err := scenario.CreatePreAuthKey(userName, true, true) + if err != nil { + t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err) + } + + err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey()) + if err != nil { + t.Fatalf("failed to run tailscale up for user %s: %s", userName, err) + } + } + + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + allClients, err := scenario.ListTailscaleClients() + assertNoErrListClients(t, err) + + allIps, err := scenario.ListTailscaleClientsIPs() + assertNoErrListClientIPs(t, err) + + allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { + return x.String() + }) + + // All ephemeral nodes should be online and reachable. + success := pingAllHelper(t, allClients, allAddrs) + t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) + + // Take down all clients, this should start an expiry timer for each. + for _, client := range allClients { + err := client.Down() + if err != nil { + t.Fatalf("failed to take down client %s: %s", client.Hostname(), err) + } + } + + // Wait a bit and bring up the clients again before the expiry + // time of the ephemeral nodes. + // Nodes should be able to reconnect and work fine. + time.Sleep(30 * time.Second) + + for _, client := range allClients { + err := client.Up() + if err != nil { + t.Fatalf("failed to take down client %s: %s", client.Hostname(), err) + } + } + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + success = pingAllHelper(t, allClients, allAddrs) + t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) + + // Take down all clients, this should start an expiry timer for each. + for _, client := range allClients { + err := client.Down() + if err != nil { + t.Fatalf("failed to take down client %s: %s", client.Hostname(), err) + } + } + + // This time wait for all of the nodes to expire and check that they are no longer + // registered. + time.Sleep(3 * time.Minute) + + for userName := range spec { + nodes, err := headscale.ListNodesInUser(userName) + if err != nil { + log.Error(). + Err(err). + Str("user", userName). + Msg("Error listing nodes in user") + + return + } + + if len(nodes) != 0 { + t.Fatalf("expected no nodes, got %d in user %s", len(nodes), userName) + } + } +} + func TestPingAllByHostname(t *testing.T) { IntegrationSkip(t) t.Parallel() From 69f3583dfb7c9611e6a68f99f4c5bd6c81c5dce0 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 15 Jul 2024 10:21:28 +0200 Subject: [PATCH 2/7] use uint64 as authekyid and ptr helper in tests Signed-off-by: Kristoffer Dalby --- hscontrol/auth.go | 8 ++++---- hscontrol/db/node_test.go | 29 +++++++++++------------------ hscontrol/db/preauth_keys.go | 4 ++-- hscontrol/db/preauth_keys_test.go | 7 +++---- hscontrol/db/routes_test.go | 15 ++++++--------- hscontrol/db/users_test.go | 7 +++---- hscontrol/types/node.go | 2 +- 7 files changed, 30 insertions(+), 42 deletions(-) diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 5ee925a6a0..010d15a255 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -16,6 +16,7 @@ import ( "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" + "tailscale.com/types/ptr" ) func logAuthFunc( @@ -314,9 +315,8 @@ func (h *Headscale) handleAuthKey( Msg("node was already registered before, refreshing with new auth key") node.NodeKey = nodeKey - pakID := uint(pak.ID) - if pakID != 0 { - node.AuthKeyID = &pakID + if pak.ID != 0 { + node.AuthKeyID = ptr.To(pak.ID) } node.Expiry = ®isterRequest.Expiry @@ -394,7 +394,7 @@ func (h *Headscale) handleAuthKey( pakID := uint(pak.ID) if pakID != 0 { - nodeToRegister.AuthKeyID = &pakID + nodeToRegister.AuthKeyID = ptr.To(pak.ID) } node, err = h.db.RegisterNode( nodeToRegister, diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index b164e526cf..8f813e4846 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -16,9 +16,11 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/puzpuzpuz/xsync/v3" + "github.com/stretchr/testify/assert" "gopkg.in/check.v1" "tailscale.com/tailcfg" "tailscale.com/types/key" + "tailscale.com/types/ptr" ) func (s *Suite) TestGetNode(c *check.C) { @@ -33,7 +35,6 @@ func (s *Suite) TestGetNode(c *check.C) { nodeKey := key.NewNode() machineKey := key.NewMachine() - pakID := uint(pak.ID) node := &types.Node{ ID: 0, @@ -42,7 +43,7 @@ func (s *Suite) TestGetNode(c *check.C) { Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), } trx := db.DB.Save(node) c.Assert(trx.Error, check.IsNil) @@ -64,7 +65,6 @@ func (s *Suite) TestGetNodeByID(c *check.C) { nodeKey := key.NewNode() machineKey := key.NewMachine() - pakID := uint(pak.ID) node := types.Node{ ID: 0, MachineKey: machineKey.Public(), @@ -72,7 +72,7 @@ func (s *Suite) TestGetNodeByID(c *check.C) { Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), } trx := db.DB.Save(&node) c.Assert(trx.Error, check.IsNil) @@ -96,7 +96,6 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) { machineKey := key.NewMachine() - pakID := uint(pak.ID) node := types.Node{ ID: 0, MachineKey: machineKey.Public(), @@ -104,7 +103,7 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) { Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), } trx := db.DB.Save(&node) c.Assert(trx.Error, check.IsNil) @@ -148,7 +147,6 @@ func (s *Suite) TestListPeers(c *check.C) { _, err = db.GetNodeByID(0) c.Assert(err, check.NotNil) - pakID := uint(pak.ID) for index := 0; index <= 10; index++ { nodeKey := key.NewNode() machineKey := key.NewMachine() @@ -160,7 +158,7 @@ func (s *Suite) TestListPeers(c *check.C) { Hostname: "testnode" + strconv.Itoa(index), UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), } trx := db.DB.Save(&node) c.Assert(trx.Error, check.IsNil) @@ -200,7 +198,6 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { for index := 0; index <= 10; index++ { nodeKey := key.NewNode() machineKey := key.NewMachine() - pakID := uint(stor[index%2].key.ID) v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1))) node := types.Node{ @@ -211,7 +208,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { Hostname: "testnode" + strconv.Itoa(index), UserID: stor[index%2].user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(stor[index%2].key.ID), } trx := db.DB.Save(&node) c.Assert(trx.Error, check.IsNil) @@ -286,7 +283,6 @@ func (s *Suite) TestExpireNode(c *check.C) { nodeKey := key.NewNode() machineKey := key.NewMachine() - pakID := uint(pak.ID) node := &types.Node{ ID: 0, @@ -295,7 +291,7 @@ func (s *Suite) TestExpireNode(c *check.C) { Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), Expiry: &time.Time{}, } db.DB.Save(node) @@ -331,7 +327,6 @@ func (s *Suite) TestGenerateGivenName(c *check.C) { machineKey2 := key.NewMachine() - pakID := uint(pak.ID) node := &types.Node{ ID: 0, MachineKey: machineKey.Public(), @@ -340,7 +335,7 @@ func (s *Suite) TestGenerateGivenName(c *check.C) { GivenName: "hostname-1", UserID: user1.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), } trx := db.DB.Save(node) @@ -375,7 +370,6 @@ func (s *Suite) TestSetTags(c *check.C) { nodeKey := key.NewNode() machineKey := key.NewMachine() - pakID := uint(pak.ID) node := &types.Node{ ID: 0, MachineKey: machineKey.Public(), @@ -383,7 +377,7 @@ func (s *Suite) TestSetTags(c *check.C) { Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), } trx := db.DB.Save(node) @@ -569,7 +563,6 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { route2 := netip.MustParsePrefix("10.11.0.0/24") v4 := netip.MustParseAddr("100.64.0.1") - pakID := uint(pak.ID) node := types.Node{ ID: 0, MachineKey: machineKey.Public(), @@ -577,7 +570,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { Hostname: "test", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:exit"}, RoutableIPs: []netip.Prefix{defaultRouteV4, defaultRouteV6, route1, route2}, diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index adfd289a49..5ea59a9c90 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -10,6 +10,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "gorm.io/gorm" + "tailscale.com/types/ptr" ) var ( @@ -197,10 +198,9 @@ func ValidatePreAuthKey(tx *gorm.DB, k string) (*types.PreAuthKey, error) { } nodes := types.Nodes{} - pakID := uint(pak.ID) if err := tx. Preload("AuthKey"). - Where(&types.Node{AuthKeyID: &pakID}). + Where(&types.Node{AuthKeyID: ptr.To(pak.ID)}). Find(&nodes).Error; err != nil { return nil, err } diff --git a/hscontrol/db/preauth_keys_test.go b/hscontrol/db/preauth_keys_test.go index 88d5b213f5..9dd5b199ec 100644 --- a/hscontrol/db/preauth_keys_test.go +++ b/hscontrol/db/preauth_keys_test.go @@ -6,6 +6,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "gopkg.in/check.v1" + "tailscale.com/types/ptr" ) func (*Suite) TestCreatePreAuthKey(c *check.C) { @@ -75,13 +76,12 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - pakID := uint(pak.ID) node := types.Node{ ID: 0, Hostname: "testest", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), } trx := db.DB.Save(&node) c.Assert(trx.Error, check.IsNil) @@ -98,13 +98,12 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) c.Assert(err, check.IsNil) - pakID := uint(pak.ID) node := types.Node{ ID: 1, Hostname: "testest", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), } trx := db.DB.Save(&node) c.Assert(trx.Error, check.IsNil) diff --git a/hscontrol/db/routes_test.go b/hscontrol/db/routes_test.go index 8bbc594868..122a7ff376 100644 --- a/hscontrol/db/routes_test.go +++ b/hscontrol/db/routes_test.go @@ -14,6 +14,7 @@ import ( "gopkg.in/check.v1" "gorm.io/gorm" "tailscale.com/tailcfg" + "tailscale.com/types/ptr" ) var smap = func(m map[types.NodeID]bool) *xsync.MapOf[types.NodeID, bool] { @@ -43,13 +44,12 @@ func (s *Suite) TestGetRoutes(c *check.C) { RoutableIPs: []netip.Prefix{route}, } - pakID := uint(pak.ID) node := types.Node{ ID: 0, Hostname: "test_get_route_node", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), Hostinfo: &hostInfo, } trx := db.DB.Save(&node) @@ -95,13 +95,12 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { RoutableIPs: []netip.Prefix{route, route2}, } - pakID := uint(pak.ID) node := types.Node{ ID: 0, Hostname: "test_enable_route_node", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), Hostinfo: &hostInfo, } trx := db.DB.Save(&node) @@ -169,13 +168,12 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { hostInfo1 := tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{route, route2}, } - pakID := uint(pak.ID) node1 := types.Node{ ID: 1, Hostname: "test_enable_route_node", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), Hostinfo: &hostInfo1, } trx := db.DB.Save(&node1) @@ -199,7 +197,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { Hostname: "test_enable_route_node", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), Hostinfo: &hostInfo2, } db.DB.Save(&node2) @@ -253,13 +251,12 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { } now := time.Now() - pakID := uint(pak.ID) node1 := types.Node{ ID: 1, Hostname: "test_enable_route_node", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), Hostinfo: &hostInfo1, LastSeen: &now, } diff --git a/hscontrol/db/users_test.go b/hscontrol/db/users_test.go index 98dea6c0dc..0629480c3d 100644 --- a/hscontrol/db/users_test.go +++ b/hscontrol/db/users_test.go @@ -5,6 +5,7 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "gopkg.in/check.v1" "gorm.io/gorm" + "tailscale.com/types/ptr" ) func (s *Suite) TestCreateAndDestroyUser(c *check.C) { @@ -46,13 +47,12 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) { pak, err = db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - pakID := uint(pak.ID) node := types.Node{ ID: 0, Hostname: "testnode", UserID: user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), } trx := db.DB.Save(&node) c.Assert(trx.Error, check.IsNil) @@ -100,13 +100,12 @@ func (s *Suite) TestSetMachineUser(c *check.C) { pak, err := db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - pakID := uint(pak.ID) node := types.Node{ ID: 0, Hostname: "testnode", UserID: oldUser.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, + AuthKeyID: ptr.To(pak.ID), } trx := db.DB.Save(&node) c.Assert(trx.Error, check.IsNil) diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 6bee5c42e7..039c32c5c5 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -119,7 +119,7 @@ type Node struct { ForcedTags StringList // TODO(kradalby): This seems like irrelevant information? - AuthKeyID *uint `sql:"DEFAULT:NULL"` + AuthKeyID *uint64 `sql:"DEFAULT:NULL"` AuthKey *PreAuthKey `gorm:"constraint:OnDelete:SET NULL;"` LastSeen *time.Time From 4ee0b53109ce72423fe58a2cdb9e39ec1f7ea2ca Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 15 Jul 2024 10:21:50 +0200 Subject: [PATCH 3/7] add test db helper Signed-off-by: Kristoffer Dalby --- hscontrol/db/suite_test.go | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/hscontrol/db/suite_test.go b/hscontrol/db/suite_test.go index 1b97ce06a1..d546b33d79 100644 --- a/hscontrol/db/suite_test.go +++ b/hscontrol/db/suite_test.go @@ -36,10 +36,18 @@ func (s *Suite) ResetDB(c *check.C) { // } var err error - tmpDir, err = os.MkdirTemp("", "headscale-db-test-*") + db, err = newTestDB() if err != nil { c.Fatal(err) } +} + +func newTestDB() (*HSDatabase, error) { + var err error + tmpDir, err = os.MkdirTemp("", "headscale-db-test-*") + if err != nil { + return nil, err + } log.Printf("database path: %s", tmpDir+"/headscale_test.db") @@ -53,6 +61,8 @@ func (s *Suite) ResetDB(c *check.C) { "", ) if err != nil { - c.Fatal(err) + return nil, err } + + return db, nil } From b6c042480a53a66d2698e2a263e8aa49128888fd Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 15 Jul 2024 10:22:27 +0200 Subject: [PATCH 4/7] add list ephemeral node func Signed-off-by: Kristoffer Dalby --- hscontrol/db/node.go | 14 ++++++++++ hscontrol/db/node_test.go | 57 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+) diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index d0f0fa48d9..6f7c195cc8 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -79,6 +79,20 @@ func ListNodes(tx *gorm.DB) (types.Nodes, error) { return nodes, nil } +func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { + nodes := types.Nodes{} + if err := rx.Debug().Joins("AuthKey").Where("AuthKey.ephemeral = true").Find(&nodes).Error; err != nil { + return nil, err + } + // if err := rx.Debug().Preload("AuthKey", "ephemeral = 1").Find(&nodes).Error; err != nil { + // return nil, err + // } + + return nodes, nil + }) +} + func listNodesByGivenName(tx *gorm.DB, givenName string) (types.Nodes, error) { nodes := types.Nodes{} if err := tx. diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 8f813e4846..c90d571efa 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -643,6 +643,7 @@ func TestEphemeralGarbageCollectorLoads(t *testing.T) { time.Sleep(10 * time.Second) + e.Close() if len(got) != want { t.Errorf("expected %d, got %d", want, len(got)) } @@ -657,3 +658,59 @@ func generateRandomNumber(t *testing.T, max int64) int64 { } return n.Int64() + 1 } + +func TestListEphemeralNodes(t *testing.T) { + db, err := newTestDB() + if err != nil { + t.Fatalf("creating db: %s", err) + } + + user, err := db.CreateUser("test") + assert.NoError(t, err) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + assert.NoError(t, err) + + pakEph, err := db.CreatePreAuthKey(user.Name, false, true, nil, nil) + assert.NoError(t, err) + + node := types.Node{ + ID: 0, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "test", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: ptr.To(pak.ID), + } + + nodeEph := types.Node{ + ID: 0, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "ephemeral", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: ptr.To(pakEph.ID), + } + + err = db.DB.Save(&node).Error + assert.NoError(t, err) + + err = db.DB.Save(&nodeEph).Error + assert.NoError(t, err) + + nodes, err := db.ListNodes() + assert.NoError(t, err) + + ephemeralNodes, err := db.ListEphemeralNodes() + assert.NoError(t, err) + + assert.Len(t, nodes, 2) + assert.Len(t, ephemeralNodes, 1) + + assert.Equal(t, nodeEph.ID, ephemeralNodes[0].ID) + assert.Equal(t, nodeEph.AuthKeyID, ephemeralNodes[0].AuthKeyID) + assert.Equal(t, nodeEph.UserID, ephemeralNodes[0].UserID) + assert.Equal(t, nodeEph.Hostname, ephemeralNodes[0].Hostname) +} From efd7964c87077c117f4ae8ccafd49494cb75d939 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 15 Jul 2024 10:22:45 +0200 Subject: [PATCH 5/7] schedule ephemeral nodes for removal on startup Signed-off-by: Kristoffer Dalby --- hscontrol/app.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/hscontrol/app.go b/hscontrol/app.go index 2862c60e1b..27c26d62b2 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -518,7 +518,18 @@ func (h *Headscale) Serve() error { return errEmptyInitialDERPMap } + // Start ephemeral node garbage collector and schedule all nodes + // that are already in the database and ephemeral. If they are still + // around between restarts, they will reconnect and the GC will + // be cancelled. go h.ephemeralGC.Start() + ephmNodes, err := h.db.ListEphemeralNodes() + if err != nil { + return fmt.Errorf("failed to list ephemeral nodes: %w", err) + } + for _, node := range ephmNodes { + h.ephemeralGC.Schedule(node.ID, h.cfg.EphemeralNodeInactivityTimeout) + } expireNodeCtx, expireNodeCancel := context.WithCancel(context.Background()) defer expireNodeCancel() From 48731c200b16857edfb41fa5418f3d42ce03fae5 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 16 Jul 2024 09:25:16 +0200 Subject: [PATCH 6/7] fix gorm query for postgres Signed-off-by: Kristoffer Dalby --- hscontrol/db/node.go | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 6f7c195cc8..f4bdfdce9f 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -82,12 +82,9 @@ func ListNodes(tx *gorm.DB) (types.Nodes, error) { func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) { return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { nodes := types.Nodes{} - if err := rx.Debug().Joins("AuthKey").Where("AuthKey.ephemeral = true").Find(&nodes).Error; err != nil { + if err := rx.Joins("AuthKey").Where(`"AuthKey"."ephemeral" = true`).Find(&nodes).Error; err != nil { return nil, err } - // if err := rx.Debug().Preload("AuthKey", "ephemeral = 1").Find(&nodes).Error; err != nil { - // return nil, err - // } return nodes, nil }) From 8edb253542047c88482e91b6a0fedb16a3fab206 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 17 Jul 2024 17:13:26 +0200 Subject: [PATCH 7/7] add godoc Signed-off-by: Kristoffer Dalby --- hscontrol/db/node.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index f4bdfdce9f..a2515ebfc9 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -719,6 +719,10 @@ func ExpireExpiredNodes(tx *gorm.DB, return started, types.StateUpdate{}, false } +// EphemeralGarbageCollector is a garbage collector that will delete nodes after +// a certain amount of time. +// It is used to delete ephemeral nodes that have disconnected and should be +// cleaned up. type EphemeralGarbageCollector struct { mu deadlock.Mutex @@ -729,6 +733,8 @@ type EphemeralGarbageCollector struct { cancelCh chan struct{} } +// NewEphemeralGarbageCollector creates a new EphemeralGarbageCollector, it takes +// a deleteFunc that will be called when a node is scheduled for deletion. func NewEphemeralGarbageCollector(deleteFunc func(types.NodeID)) *EphemeralGarbageCollector { return &EphemeralGarbageCollector{ toBeDeleted: make(map[types.NodeID]*time.Timer), @@ -738,10 +744,12 @@ func NewEphemeralGarbageCollector(deleteFunc func(types.NodeID)) *EphemeralGarba } } +// Close stops the garbage collector. func (e *EphemeralGarbageCollector) Close() { e.cancelCh <- struct{}{} } +// Schedule schedules a node for deletion after the expiry duration. func (e *EphemeralGarbageCollector) Schedule(nodeID types.NodeID, expiry time.Duration) { e.mu.Lock() defer e.mu.Unlock() @@ -759,6 +767,7 @@ func (e *EphemeralGarbageCollector) Schedule(nodeID types.NodeID, expiry time.Du }() } +// Cancel cancels the deletion of a node. func (e *EphemeralGarbageCollector) Cancel(nodeID types.NodeID) { e.mu.Lock() defer e.mu.Unlock() @@ -769,6 +778,7 @@ func (e *EphemeralGarbageCollector) Cancel(nodeID types.NodeID) { } } +// Start starts the garbage collector. func (e *EphemeralGarbageCollector) Start() { for { select {